Commit
·
dfc30a3
1
Parent(s):
8f6f449
Improve SDXL auth error logging
Browse files- app/colorize_model.py +23 -22
app/colorize_model.py
CHANGED
|
@@ -156,41 +156,42 @@ class ColorizeModel:
|
|
| 156 |
|
| 157 |
def _load_pipeline(self) -> None:
|
| 158 |
controlnet_path = self._download_controlnet()
|
|
|
|
| 159 |
|
| 160 |
logger.info("Loading SDXL components...")
|
| 161 |
-
vae = AutoencoderKL.from_pretrained(
|
| 162 |
-
self.BASE_MODEL,
|
| 163 |
-
subfolder="vae",
|
| 164 |
-
torch_dtype=self.dtype,
|
| 165 |
-
token=self.hf_token,
|
| 166 |
-
)
|
| 167 |
unet = UNet2DConditionModel.from_config(
|
| 168 |
self.BASE_MODEL,
|
| 169 |
subfolder="unet",
|
| 170 |
-
token=self.hf_token,
|
| 171 |
)
|
| 172 |
lightning_path = hf_hub_download(
|
| 173 |
repo_id=self.LIGHTNING_REPO,
|
| 174 |
filename=self.LIGHTNING_WEIGHTS,
|
| 175 |
-
token=self.hf_token,
|
| 176 |
)
|
| 177 |
unet.load_state_dict(load_file(lightning_path))
|
| 178 |
|
| 179 |
-
controlnet = ControlNetModel.from_pretrained(
|
| 180 |
-
controlnet_path,
|
| 181 |
-
torch_dtype=self.dtype,
|
| 182 |
-
)
|
| 183 |
|
| 184 |
-
|
| 185 |
-
self.
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
self.pipe.set_progress_bar_config(disable=True)
|
| 195 |
|
| 196 |
if self.device.type == "cuda":
|
|
|
|
| 156 |
|
| 157 |
def _load_pipeline(self) -> None:
|
| 158 |
controlnet_path = self._download_controlnet()
|
| 159 |
+
base_kwargs = {"use_auth_token": self.hf_token} if self.hf_token else {}
|
| 160 |
|
| 161 |
logger.info("Loading SDXL components...")
|
| 162 |
+
vae = AutoencoderKL.from_pretrained(self.BASE_MODEL, subfolder="vae", torch_dtype=self.dtype, token=self.hf_token)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
unet = UNet2DConditionModel.from_config(
|
| 164 |
self.BASE_MODEL,
|
| 165 |
subfolder="unet",
|
| 166 |
+
token=self.hf_token if self.hf_token else None,
|
| 167 |
)
|
| 168 |
lightning_path = hf_hub_download(
|
| 169 |
repo_id=self.LIGHTNING_REPO,
|
| 170 |
filename=self.LIGHTNING_WEIGHTS,
|
| 171 |
+
token=self.hf_token if self.hf_token else None,
|
| 172 |
)
|
| 173 |
unet.load_state_dict(load_file(lightning_path))
|
| 174 |
|
| 175 |
+
controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=self.dtype)
|
|
|
|
|
|
|
|
|
|
| 176 |
|
| 177 |
+
try:
|
| 178 |
+
self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
|
| 179 |
+
self.BASE_MODEL,
|
| 180 |
+
vae=vae,
|
| 181 |
+
unet=unet,
|
| 182 |
+
controlnet=controlnet,
|
| 183 |
+
torch_dtype=self.dtype,
|
| 184 |
+
safety_checker=None,
|
| 185 |
+
requires_safety_checker=False,
|
| 186 |
+
token=self.hf_token if self.hf_token else None,
|
| 187 |
+
)
|
| 188 |
+
except Exception as exc:
|
| 189 |
+
logger.error("Failed to load base SDXL model: %s", exc)
|
| 190 |
+
logger.error(
|
| 191 |
+
"Ensure the account associated with HUGGINGFACE_HUB_TOKEN has accepted "
|
| 192 |
+
"the license for %s and that the token has access.", self.BASE_MODEL
|
| 193 |
+
)
|
| 194 |
+
raise
|
| 195 |
self.pipe.set_progress_bar_config(disable=True)
|
| 196 |
|
| 197 |
if self.device.type == "cuda":
|