Update app.py
Browse files
app.py
CHANGED
@@ -47,12 +47,19 @@ def tube_mask_generator(mask_ratio):
|
|
47 |
|
48 |
|
49 |
def get_model(model_type):
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
ft_model = keras.models.load_model(model_type + '_FT')
|
57 |
pt_model = keras.models.load_model(model_type + '_PT')
|
58 |
|
|
|
47 |
|
48 |
|
49 |
def get_model(model_type):
|
50 |
+
ft_path = keras.utils.get_file(
|
51 |
+
origin=f'https://github.com/innat/VideoMAE/releases/download/v1.1/{model_type}_FT.zip',
|
52 |
+
)
|
53 |
+
pt_path = keras.utils.get_file(
|
54 |
+
origin=f'https://github.com/innat/VideoMAE/releases/download/v1.1/{model_type}_PT.zip',
|
55 |
+
)
|
56 |
+
|
57 |
+
with zipfile.ZipFile(ft_path, 'r') as zip_ref:
|
58 |
+
zip_ref.extractall('./')
|
59 |
+
|
60 |
+
with zipfile.ZipFile(pt_path, 'r') as zip_ref:
|
61 |
+
zip_ref.extractall('./')
|
62 |
+
|
63 |
ft_model = keras.models.load_model(model_type + '_FT')
|
64 |
pt_model = keras.models.load_model(model_type + '_PT')
|
65 |
|