vinesmsuic commited on
Commit
8fc3c7f
·
1 Parent(s): 0b406ee

trying preload

Browse files
Files changed (2) hide show
  1. app.py +3 -0
  2. model/pre_download.py +77 -0
app.py CHANGED
@@ -7,6 +7,7 @@ from serve.leaderboard import build_leaderboard_tab
7
  from model.model_manager import ModelManager
8
  from pathlib import Path
9
  from serve.constants import SERVER_PORT, ROOT_PATH, ELO_RESULTS_DIR
 
10
 
11
  def build_combine_demo(models, elo_results_file, leaderboard_table_file):
12
 
@@ -97,6 +98,8 @@ if __name__ == "__main__":
97
  root_path = ROOT_PATH
98
  elo_results_dir = ELO_RESULTS_DIR
99
  models = ModelManager()
 
 
100
 
101
  elo_results_file, leaderboard_table_file = load_elo_results(elo_results_dir)
102
  demo = build_combine_demo(models, elo_results_file, leaderboard_table_file)
 
7
  from model.model_manager import ModelManager
8
  from pathlib import Path
9
  from serve.constants import SERVER_PORT, ROOT_PATH, ELO_RESULTS_DIR
10
+ from model.pre_download import pre_download_all_models
11
 
12
  def build_combine_demo(models, elo_results_file, leaderboard_table_file):
13
 
 
98
  root_path = ROOT_PATH
99
  elo_results_dir = ELO_RESULTS_DIR
100
  models = ModelManager()
101
+
102
+ pre_download_all_models()
103
 
104
  elo_results_file, leaderboard_table_file = load_elo_results(elo_results_dir)
105
  demo = build_combine_demo(models, elo_results_file, leaderboard_table_file)
model/pre_download.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .models import IMAGE_GENERATION_MODELS, IMAGE_EDITION_MODELS, VIDEO_GENERATION_MODELS
2
+
3
+ def pre_download_all_models():
4
+ """
5
+ Pre-download all models to avoid download delay during the first user request
6
+ """
7
+ imagen_dl_error = pre_download_image_models()
8
+ imagedit_dl_error = pre_download_image_models()
9
+ videogen_dl_error = pre_download_video_models()
10
+ print("All models downloaded.")
11
+ print("Models that encountered download error:", "Image Generation:", imagen_dl_error, "Image Edition:", imagedit_dl_error, "Video Generation:", videogen_dl_error)
12
+
13
+ def pre_download_image_models():
14
+ """
15
+ Pre-download image models to avoid download delay during the first user request
16
+ """
17
+ import imagen_hub
18
+ errored_models = []
19
+ for model_string in IMAGE_GENERATION_MODELS:
20
+ print("Loading image generation model:", model_name)
21
+ model_lib, model_name, model_type = model_string.split("_")
22
+
23
+ if model_lib == "imagenhub":
24
+ try:
25
+ temp_model = imagen_hub.get_model(model_name) # Forcing model to download weight files
26
+ del temp_model
27
+ except Exception as e:
28
+ print(f"Failed to load model {model_name} \n {e}")
29
+ errored_models.append(model_string)
30
+ continue
31
+ else:
32
+ pass
33
+ return errored_models
34
+
35
+ def pre_download_image_models():
36
+ """
37
+ Pre-download image models to avoid download delay during the first user request
38
+ """
39
+ import imagen_hub
40
+ errored_models = []
41
+ for model_string in IMAGE_EDITION_MODELS:
42
+ print("Loading image edition model:", model_name)
43
+ model_lib, model_name, model_type = model_string.split("_")
44
+
45
+ if model_lib == "imagenhub":
46
+ try:
47
+ temp_model = imagen_hub.get_model(model_name) # Forcing model to download weight files
48
+ del temp_model
49
+ except Exception as e:
50
+ print(f"Failed to load model {model_name} \n {e}")
51
+ errored_models.append(model_string)
52
+ continue
53
+ else:
54
+ pass
55
+ return errored_models
56
+
57
+ def pre_download_video_models():
58
+ """
59
+ Pre-download video models to avoid download delay during the first user request
60
+ """
61
+ import videogen_hub
62
+ errored_models = []
63
+ for model_string in VIDEO_GENERATION_MODELS:
64
+ print("Loading video generation model:", model_name)
65
+ model_lib, model_name, model_type = model_string.split("_")
66
+
67
+ if model_lib == "videogenhub":
68
+ try:
69
+ temp_model = videogen_hub.get_model(model_name) # Forcing model to download weight files
70
+ del temp_model
71
+ except Exception as e:
72
+ print(f"Failed to load model {model_name} \n {e}")
73
+ errored_models.append(model_string)
74
+ continue
75
+ else:
76
+ pass
77
+ return errored_models