DongfuJiang commited on
Commit
e3fdca8
1 Parent(s): c2d6fd1
Files changed (2) hide show
  1. app.py +1 -1
  2. model/model_manager.py +3 -3
app.py CHANGED
@@ -96,7 +96,7 @@ if __name__ == "__main__":
96
  server_port = int(SERVER_PORT)
97
  root_path = ROOT_PATH
98
  elo_results_dir = ELO_RESULTS_DIR
99
- models = ModelManager(enable_nsfw=True, pre_download=True, debug_packages=True)
100
  # models = ModelManager(enable_nsfw=False, pre_download=False, debug_packages=False)
101
 
102
  elo_results_file, leaderboard_table_file = load_elo_results(elo_results_dir)
 
96
  server_port = int(SERVER_PORT)
97
  root_path = ROOT_PATH
98
  elo_results_dir = ELO_RESULTS_DIR
99
+ models = ModelManager(enable_nsfw=False, do_pre_download=True, do_debug_packages=True)
100
  # models = ModelManager(enable_nsfw=False, pre_download=False, debug_packages=False)
101
 
102
  elo_results_file, leaderboard_table_file = load_elo_results(elo_results_dir)
model/model_manager.py CHANGED
@@ -19,7 +19,7 @@ def debug_packages():
19
  print(f"{package.key}=={package.version}")
20
 
21
  class ModelManager:
22
- def __init__(self, enable_nsfw=False, pre_download=False, debug_packages=False):
23
  self.model_ig_list = IMAGE_GENERATION_MODELS
24
  self.model_ie_list = IMAGE_EDITION_MODELS
25
  self.model_vg_list = VIDEO_GENERATION_MODELS
@@ -28,9 +28,9 @@ class ModelManager:
28
  self.enable_nsfw = enable_nsfw
29
  self.load_guard(enable_nsfw)
30
  self.loaded_models = {}
31
- if pre_download:
32
  pre_download_all_models()
33
- if debug_packages:
34
  debug_packages()
35
 
36
  def load_model_pipe(self, model_name):
 
19
  print(f"{package.key}=={package.version}")
20
 
21
  class ModelManager:
22
+ def __init__(self, enable_nsfw=False, do_pre_download=False, do_debug_packages=False):
23
  self.model_ig_list = IMAGE_GENERATION_MODELS
24
  self.model_ie_list = IMAGE_EDITION_MODELS
25
  self.model_vg_list = VIDEO_GENERATION_MODELS
 
28
  self.enable_nsfw = enable_nsfw
29
  self.load_guard(enable_nsfw)
30
  self.loaded_models = {}
31
+ if do_pre_download:
32
  pre_download_all_models()
33
+ if do_debug_packages:
34
  debug_packages()
35
 
36
  def load_model_pipe(self, model_name):