Maikou commited on
Commit
77d5485
1 Parent(s): c6ea4a6
Files changed (1) hide show
  1. app.py +11 -1
app.py CHANGED
@@ -15,6 +15,8 @@ from michelangelo.utils.visualizers import html_util
15
 
16
  import gradio as gr
17
 
 
 
18
  from huggingface_hub import snapshot_download
19
 
20
  gradio_cached_dir = "./gradio_cached_dir"
@@ -113,7 +115,15 @@ def load_model(model_name: str, model_config_dict: dict, inference_model: Infere
113
 
114
  config_ckpt_path = model_config_dict[model_name]
115
 
116
- model_config = get_config_from_file(config_ckpt_path["config"])
 
 
 
 
 
 
 
 
117
  if hasattr(model_config, "model"):
118
  model_config = model_config.model
119
 
 
15
 
16
  import gradio as gr
17
 
18
+ from omegaconf import OmegaConf
19
+
20
  from huggingface_hub import snapshot_download
21
 
22
  gradio_cached_dir = "./gradio_cached_dir"
 
115
 
116
  config_ckpt_path = model_config_dict[model_name]
117
 
118
+ raw_config_file = config_ckpt_path["config"]
119
+ raw_config = OmegaConf.load(raw_config_file)
120
+ raw_clip_ckpt_path = raw_config['model']['params']['first_stage_config']['params']['aligned_module_cfg']['params']['clip_model_version']
121
+ clip_ckpt_path = os.path.join(model_path, raw_clip_ckpt_path)
122
+ raw_config['model']['params']['first_stage_config']['params']['aligned_module_cfg']['params']['clip_model_version'] = clip_ckpt_path
123
+ raw_config['model']['params']['cond_stage_config']['params']['version'] = clip_ckpt_path
124
+ OmegaConf.save(raw_config, 'current_config.yaml')
125
+
126
+ model_config = get_config_from_file('current_config.yaml')
127
  if hasattr(model_config, "model"):
128
  model_config = model_config.model
129