-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathResnetV2 - DistractedDriver.py
More file actions
65 lines (47 loc) · 2.27 KB
/
ResnetV2 - DistractedDriver.py
File metadata and controls
65 lines (47 loc) · 2.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# coding: utf-8
# In[1]:
import DistractedDriver
import Shared
import tensorflow as tf
from nets import resnet_v2
slim = tf.contrib.slim

# Command-line interface: Shared.define_parser supplies the common options;
# the flags below are specific to this ResNet-V2 script.
parser = Shared.define_parser(klass='ResnetV2')
parser.add_argument('--depth', nargs='*', help='Special Deep Logits Architecture?')
parser.add_argument('--which', default='original', help='Train on segmented')
# Restrict --resnet to the depths ResnetV2 can actually build, so a typo fails
# immediately with an argparse usage error instead of during graph construction.
parser.add_argument('--resnet', default='50', choices=['50', '101', '152'],
                    help='Which Resnet? 50? 101? 152?')
# parse_known_args collects unrecognised flags into unknown_args instead of
# rejecting them.
args, unknown_args = parser.parse_known_args()

# Configure the shared harness for this dataset: input dimension 299
# (presumably 299x299 images -- confirm in Shared) and 10 target classes.
Shared.DIM = 299
Shared.N_CLASSES = 10
# Deferred loader: the data is only read when Shared invokes this callable.
Shared.load_training_data = lambda: DistractedDriver.load_data(progressBar=True, which=args.which)
class ResnetV2:
    """ResNet-V2 (TF-Slim) classifier for the Distracted Driver dataset.

    The network depth is selected by the module-level ``--resnet`` CLI
    argument (``args.resnet``: '50', '101' or '152'). Graph definition,
    training, checkpoint loading and prediction are delegated to helpers in
    the project ``Shared`` module.
    """

    def __init__(self, model_name, isTesting=False):
        """Define the model graph through ``Shared.define_model``.

        Args:
            model_name: Name forwarded to ``Shared.define_model``.
            isTesting: Accepted for interface compatibility; not read by
                this class.
        """
        Shared.define_model(self, model_name, self.__model)

    def __get_init_fn(self):
        """Return the init function built by ``Shared.get_init_fn``.

        Uses the checkpoint file ``resnet_v2_<depth>.ckpt`` and the scope
        list ``["resnet_v2_<depth>/logits"]``, where <depth> is
        ``args.resnet``. NOTE(review): the scope list is presumably excluded
        from restoration, because the logits layer is re-trained for this
        dataset's classes -- confirm in Shared.get_init_fn.
        """
        return Shared.get_init_fn('resnet_v2_{}.ckpt'.format(args.resnet), [
            "resnet_v2_{}/logits".format(args.resnet),
        ])

    def __model(self):
        """Build the ResNet-V2 graph on ``self.X_Norm``.

        Sets ``self.logits`` (reshaped to ``[-1, n_classes]``) and
        ``self.end_points``.

        Raises:
            ValueError: if ``args.resnet`` is not '50', '101' or '152'.
        """
        n_classes = len(DistractedDriver.CLASSES)
        builders = {
            '50': resnet_v2.resnet_v2_50,
            '101': resnet_v2.resnet_v2_101,
            '152': resnet_v2.resnet_v2_152,
        }
        if args.resnet not in builders:
            # The original raised a bare string, which Python rejects with a
            # TypeError; raise a real exception carrying the same message.
            raise ValueError("--resnet argument has to be either 50, 101 or 152")
        builder = builders[args.resnet]
        # Create the model; the default arg scope configures the batch-norm
        # parameters.
        with slim.arg_scope(resnet_v2.resnet_arg_scope()):
            # NOTE(review): is_training is always True, even though __init__
            # accepts isTesting; confirm whether Shared compensates for this.
            self.logits, self.end_points = builder(self.X_Norm, n_classes, is_training=True)
            # Collapse any extra dimensions so logits are [batch, n_classes].
            self.logits = tf.reshape(self.logits, [-1, n_classes])

    def train(self, sess, X, y, val_X, val_y, epochs=30, minibatch_size=50, optimizer=None):
        """Train the model with ``Shared.train_model`` and return its result.

        First stores the checkpoint init function on ``self.init_fn``.
        """
        self.init_fn = self.__get_init_fn()
        return Shared.train_model(self, sess, X, y, val_X, val_y, epochs, minibatch_size, optimizer)

    def load_model(self, sess):
        """Delegate to ``Shared.load_model`` and return its result."""
        return Shared.load_model(self, sess)

    def predict_proba(self, sess, X, step=10):
        """Delegate to ``Shared.predict_proba`` and return its result."""
        return Shared.predict_proba(self, sess, X, step)
# Script entry point (skipped on import): hand the model class and the parsed
# CLI arguments to the shared driver.
if '__main__' == __name__:
    Shared.main(ResnetV2, args)