Update app.py
Browse files
app.py
CHANGED
@@ -10,6 +10,18 @@ from utils import read_video, frame_sampling, denormalize, reconstrunction
|
|
10 |
from utils import IMAGENET_MEAN, IMAGENET_STD, num_frames, patch_size, input_size
|
11 |
from labels import K400_label_map, SSv2_label_map, UCF_label_map
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
MODELS = {
|
14 |
'K400': [
|
15 |
'./TFVideoMAE_S_K400_16x224_FT',
|
@@ -48,9 +60,16 @@ def tube_mask_generator(mask_ratio):
|
|
48 |
return bool_masked_pos_tf
|
49 |
|
50 |
|
51 |
-
def get_model(
|
52 |
-
ft_model = keras.models.load_model(
|
53 |
-
pt_model = keras.models.load_model(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
label_map = LABEL_MAPS.get(data_type)
|
56 |
label_map = K400_label_map
|
@@ -59,14 +78,14 @@ def get_model(data_type):
|
|
59 |
return ft_model, pt_model, label_map
|
60 |
|
61 |
|
62 |
-
def inference(video_file,
|
63 |
# get sample data
|
64 |
container = read_video(video_file)
|
65 |
frames = frame_sampling(container, num_frames=num_frames)
|
66 |
|
67 |
# get models
|
68 |
bool_masked_pos_tf = tube_mask_generator(mask_ratio)
|
69 |
-
ft_model, pt_model, label_map = get_model(
|
70 |
ft_model.trainable = False
|
71 |
pt_model.trainable = False
|
72 |
|
@@ -110,12 +129,17 @@ def main():
|
|
110 |
fn=inference,
|
111 |
inputs=[
|
112 |
gr.Video(type="file", label="Input Video"),
|
113 |
-
gr.
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
|
|
|
|
|
|
|
|
|
|
119 |
gr.Slider(
|
120 |
0.5,
|
121 |
1.0,
|
|
|
10 |
from utils import IMAGENET_MEAN, IMAGENET_STD, num_frames, patch_size, input_size
|
11 |
from labels import K400_label_map, SSv2_label_map, UCF_label_map
|
12 |
|
13 |
+
|
14 |
+
def available_models():
|
15 |
+
ALL_MODELS = [
|
16 |
+
'TFVideoMAE_S_K400_16x224',
|
17 |
+
'TFVideoMAE_B_K400_16x224',
|
18 |
+
'TFVideoMAE_L_K400_16x224',
|
19 |
+
'TFVideoMAE_S_SSv2_16x224',
|
20 |
+
'TFVideoMAE_B_SSv2_16x224',
|
21 |
+
'TFVideoMAE_B_UCF_16x224',
|
22 |
+
]
|
23 |
+
return ALL_MODELS
|
24 |
+
|
25 |
MODELS = {
|
26 |
'K400': [
|
27 |
'./TFVideoMAE_S_K400_16x224_FT',
|
|
|
60 |
return bool_masked_pos_tf
|
61 |
|
62 |
|
63 |
+
def get_model(model_type):
|
64 |
+
ft_model = keras.models.load_model(model_type + '_FT')
|
65 |
+
pt_model = keras.models.load_model(model_type + '_PT')
|
66 |
+
|
67 |
+
if 'K400' in model_type:
|
68 |
+
data_type = 'K400'
|
69 |
+
elif 'SSv2' in model_type:
|
70 |
+
data_type = 'SSv2'
|
71 |
+
else:
|
72 |
+
data_type = 'UCF'
|
73 |
|
74 |
label_map = LABEL_MAPS.get(data_type)
|
75 |
label_map = K400_label_map
|
|
|
78 |
return ft_model, pt_model, label_map
|
79 |
|
80 |
|
81 |
+
def inference(video_file, model_type, mask_ratio):
|
82 |
# get sample data
|
83 |
container = read_video(video_file)
|
84 |
frames = frame_sampling(container, num_frames=num_frames)
|
85 |
|
86 |
# get models
|
87 |
bool_masked_pos_tf = tube_mask_generator(mask_ratio)
|
88 |
+
ft_model, pt_model, label_map = get_model(model_type)
|
89 |
ft_model.trainable = False
|
90 |
pt_model.trainable = False
|
91 |
|
|
|
129 |
fn=inference,
|
130 |
inputs=[
|
131 |
gr.Video(type="file", label="Input Video"),
|
132 |
+
gr.Dropdown(
|
133 |
+
choices=available_models(),
|
134 |
+
value="TFVideoMAE_S_K400_16x224",
|
135 |
+
label="Model"
|
136 |
+
)
|
137 |
+
# gr.Radio(
|
138 |
+
# datasets,
|
139 |
+
# type='value',
|
140 |
+
# default=datasets[0],
|
141 |
+
# label='Dataset',
|
142 |
+
# ),
|
143 |
gr.Slider(
|
144 |
0.5,
|
145 |
1.0,
|