innat committed on
Commit
36e374f
·
1 Parent(s): 93ca8bb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -11
app.py CHANGED
@@ -10,6 +10,18 @@ from utils import read_video, frame_sampling, denormalize, reconstrunction
10
  from utils import IMAGENET_MEAN, IMAGENET_STD, num_frames, patch_size, input_size
11
  from labels import K400_label_map, SSv2_label_map, UCF_label_map
12
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  MODELS = {
14
  'K400': [
15
  './TFVideoMAE_S_K400_16x224_FT',
@@ -48,9 +60,16 @@ def tube_mask_generator(mask_ratio):
48
  return bool_masked_pos_tf
49
 
50
 
51
- def get_model(data_type):
52
- ft_model = keras.models.load_model(MODELS[data_type][0])
53
- pt_model = keras.models.load_model(MODELS[data_type][1])
 
 
 
 
 
 
 
54
 
55
  label_map = LABEL_MAPS.get(data_type)
56
  label_map = K400_label_map
@@ -59,14 +78,14 @@ def get_model(data_type):
59
  return ft_model, pt_model, label_map
60
 
61
 
62
- def inference(video_file, data_type, mask_ratio):
63
  # get sample data
64
  container = read_video(video_file)
65
  frames = frame_sampling(container, num_frames=num_frames)
66
 
67
  # get models
68
  bool_masked_pos_tf = tube_mask_generator(mask_ratio)
69
- ft_model, pt_model, label_map = get_model(data_type)
70
  ft_model.trainable = False
71
  pt_model.trainable = False
72
 
@@ -110,12 +129,17 @@ def main():
110
  fn=inference,
111
  inputs=[
112
  gr.Video(type="file", label="Input Video"),
113
- gr.Radio(
114
- datasets,
115
- type='value',
116
- default=datasets[0],
117
- label='Dataset',
118
- ),
 
 
 
 
 
119
  gr.Slider(
120
  0.5,
121
  1.0,
 
10
  from utils import IMAGENET_MEAN, IMAGENET_STD, num_frames, patch_size, input_size
11
  from labels import K400_label_map, SSv2_label_map, UCF_label_map
12
 
13
+
14
+ def available_models():
15
+ ALL_MODELS = [
16
+ 'TFVideoMAE_S_K400_16x224',
17
+ 'TFVideoMAE_B_K400_16x224',
18
+ 'TFVideoMAE_L_K400_16x224',
19
+ 'TFVideoMAE_S_SSv2_16x224',
20
+ 'TFVideoMAE_B_SSv2_16x224',
21
+ 'TFVideoMAE_B_UCF_16x224',
22
+ ]
23
+ return ALL_MODELS
24
+
25
  MODELS = {
26
  'K400': [
27
  './TFVideoMAE_S_K400_16x224_FT',
 
60
  return bool_masked_pos_tf
61
 
62
 
63
+ def get_model(model_type):
64
+ ft_model = keras.models.load_model(model_type + '_FT')
65
+ pt_model = keras.models.load_model(model_type + '_PT')
66
+
67
+ if 'K400' in model_type:
68
+ data_type = 'K400'
69
+ elif 'SSv2' in model_type:
70
+ data_type = 'SSv2'
71
+ else:
72
+ data_type = 'UCF'
73
 
74
  label_map = LABEL_MAPS.get(data_type)
75
  label_map = K400_label_map
 
78
  return ft_model, pt_model, label_map
79
 
80
 
81
+ def inference(video_file, model_type, mask_ratio):
82
  # get sample data
83
  container = read_video(video_file)
84
  frames = frame_sampling(container, num_frames=num_frames)
85
 
86
  # get models
87
  bool_masked_pos_tf = tube_mask_generator(mask_ratio)
88
+ ft_model, pt_model, label_map = get_model(model_type)
89
  ft_model.trainable = False
90
  pt_model.trainable = False
91
 
 
129
  fn=inference,
130
  inputs=[
131
  gr.Video(type="file", label="Input Video"),
132
+ gr.Dropdown(
133
+ choices=available_models(),
134
+ value="TFVideoMAE_S_K400_16x224",
135
+ label="Model"
136
+ )
137
+ # gr.Radio(
138
+ # datasets,
139
+ # type='value',
140
+ # default=datasets[0],
141
+ # label='Dataset',
142
+ # ),
143
  gr.Slider(
144
  0.5,
145
  1.0,