LogicGoInfotechSpaces commited on
Commit
dfc30a3
·
1 Parent(s): 8f6f449

Improve SDXL auth error logging

Browse files
Files changed (1) hide show
  1. app/colorize_model.py +23 -22
app/colorize_model.py CHANGED
@@ -156,41 +156,42 @@ class ColorizeModel:
156
 
157
  def _load_pipeline(self) -> None:
158
  controlnet_path = self._download_controlnet()
 
159
 
160
  logger.info("Loading SDXL components...")
161
- vae = AutoencoderKL.from_pretrained(
162
- self.BASE_MODEL,
163
- subfolder="vae",
164
- torch_dtype=self.dtype,
165
- token=self.hf_token,
166
- )
167
  unet = UNet2DConditionModel.from_config(
168
  self.BASE_MODEL,
169
  subfolder="unet",
170
- token=self.hf_token,
171
  )
172
  lightning_path = hf_hub_download(
173
  repo_id=self.LIGHTNING_REPO,
174
  filename=self.LIGHTNING_WEIGHTS,
175
- token=self.hf_token,
176
  )
177
  unet.load_state_dict(load_file(lightning_path))
178
 
179
- controlnet = ControlNetModel.from_pretrained(
180
- controlnet_path,
181
- torch_dtype=self.dtype,
182
- )
183
 
184
- self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
185
- self.BASE_MODEL,
186
- vae=vae,
187
- unet=unet,
188
- controlnet=controlnet,
189
- torch_dtype=self.dtype,
190
- safety_checker=None,
191
- requires_safety_checker=False,
192
- token=self.hf_token,
193
- )
 
 
 
 
 
 
 
 
194
  self.pipe.set_progress_bar_config(disable=True)
195
 
196
  if self.device.type == "cuda":
 
156
 
157
  def _load_pipeline(self) -> None:
158
  controlnet_path = self._download_controlnet()
159
+ base_kwargs = {"use_auth_token": self.hf_token} if self.hf_token else {}
160
 
161
  logger.info("Loading SDXL components...")
162
+ vae = AutoencoderKL.from_pretrained(self.BASE_MODEL, subfolder="vae", torch_dtype=self.dtype, token=self.hf_token)
 
 
 
 
 
163
  unet = UNet2DConditionModel.from_config(
164
  self.BASE_MODEL,
165
  subfolder="unet",
166
+ token=self.hf_token if self.hf_token else None,
167
  )
168
  lightning_path = hf_hub_download(
169
  repo_id=self.LIGHTNING_REPO,
170
  filename=self.LIGHTNING_WEIGHTS,
171
+ token=self.hf_token if self.hf_token else None,
172
  )
173
  unet.load_state_dict(load_file(lightning_path))
174
 
175
+ controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=self.dtype)
 
 
 
176
 
177
+ try:
178
+ self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
179
+ self.BASE_MODEL,
180
+ vae=vae,
181
+ unet=unet,
182
+ controlnet=controlnet,
183
+ torch_dtype=self.dtype,
184
+ safety_checker=None,
185
+ requires_safety_checker=False,
186
+ token=self.hf_token if self.hf_token else None,
187
+ )
188
+ except Exception as exc:
189
+ logger.error("Failed to load base SDXL model: %s", exc)
190
+ logger.error(
191
+ "Ensure the account associated with HUGGINGFACE_HUB_TOKEN has accepted "
192
+ "the license for %s and that the token has access.", self.BASE_MODEL
193
+ )
194
+ raise
195
  self.pipe.set_progress_bar_config(disable=True)
196
 
197
  if self.device.type == "cuda":