hysts commited on
Commit
6be65e1
1 Parent(s): 0e88f89
Files changed (1) hide show
  1. model.py +8 -9
model.py CHANGED
@@ -71,22 +71,21 @@ class Model:
71
  self.model_names = LIGHTWEIGHT_MODEL_NAMES
72
  self.weight_root = LIGHTWEIGHT_WEIGHT_ROOT
73
  base_model_url = 'https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors'
74
- self.download_base_model(base_model_url)
75
- base_model_path = self.model_dir / base_model_url.split('/')[-1]
76
- self.load_base_model(base_model_path)
77
  else:
78
  self.model_names = ORIGINAL_MODEL_NAMES
79
  self.weight_root = ORIGINAL_WEIGHT_ROOT
80
  self.download_models()
81
 
82
- def download_base_model(self, base_model_url: str) -> None:
83
- model_name = base_model_url.split('/')[-1]
84
  out_path = self.model_dir / model_name
85
- if out_path.exists():
86
- return
87
- subprocess.run(shlex.split(f'wget {base_model_url} -O {out_path}'))
88
 
89
- def load_base_model(self, model_path: pathlib.Path) -> None:
 
90
  self.model.load_state_dict(load_state_dict(model_path,
91
  location=self.device.type),
92
  strict=False)
 
71
  self.model_names = LIGHTWEIGHT_MODEL_NAMES
72
  self.weight_root = LIGHTWEIGHT_WEIGHT_ROOT
73
  base_model_url = 'https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors'
74
+ self.load_base_model(base_model_url)
 
 
75
  else:
76
  self.model_names = ORIGINAL_MODEL_NAMES
77
  self.weight_root = ORIGINAL_WEIGHT_ROOT
78
  self.download_models()
79
 
80
+ def download_base_model(self, model_url: str) -> pathlib.Path:
81
+ model_name = model_url.split('/')[-1]
82
  out_path = self.model_dir / model_name
83
+ if not out_path.exists():
84
+ subprocess.run(shlex.split(f'wget {model_url} -O {out_path}'))
85
+ return out_path
86
 
87
+ def load_base_model(self, model_url: str) -> None:
88
+ model_path = self.download_base_model(model_url)
89
  self.model.load_state_dict(load_state_dict(model_path,
90
  location=self.device.type),
91
  strict=False)