HairStable Bot commited on
Commit
a942552
·
1 Parent(s): f61e151

fix(cpu): replace CUDA-only calls; use DDIM for remove_hair on CPU/GPU

Browse files
Hair_stable_new_fresh/infer_full.py CHANGED
@@ -119,7 +119,11 @@ class StableHair:
119
  safety_checker=None,
120
  torch_dtype=weight_dtype,
121
  )
122
- self.remove_hair_pipeline.scheduler = UniPCMultistepScheduler.from_config(self.remove_hair_pipeline.scheduler.config)
 
 
 
 
123
  self.remove_hair_pipeline = self.remove_hair_pipeline.to(device)
124
 
125
  ### move to fp16
 
119
  safety_checker=None,
120
  torch_dtype=weight_dtype,
121
  )
122
+ # UniPC can throw 'stack expects a non-empty TensorList' or order-related
123
+ # assertions on some CPU builds. Use DDIM for wider compatibility.
124
+ self.remove_hair_pipeline.scheduler = DDIMScheduler.from_config(
125
+ self.remove_hair_pipeline.scheduler.config
126
+ )
127
  self.remove_hair_pipeline = self.remove_hair_pipeline.to(device)
128
 
129
  ### move to fp16
utils/pipeline.py CHANGED
@@ -275,7 +275,7 @@ class StableHairPipeline(DiffusionPipeline):
275
  condition = condition
276
  elif isinstance(condition, np.ndarray):
277
  # suppose input is [0, 255]
278
- condition = self.images2latents(condition, dtype).cuda()
279
  if do_classifier_free_guidance:
280
  condition_pad = torch.ones_like(condition) * -1
281
  condition = torch.cat([condition_pad, condition])
@@ -423,12 +423,13 @@ class StableHairPipeline(DiffusionPipeline):
423
  num_actual_inference_steps = num_inference_steps
424
 
425
  if isinstance(ref_image, str):
426
- ref_image_latents = self.images2latents(np.array(Image.open(ref_image).resize((width, height))),
427
- latents_dtype).cuda()
 
428
  elif isinstance(ref_image, np.ndarray):
429
- ref_image_latents = self.images2latents(ref_image, latents_dtype).cuda()
430
  elif isinstance(ref_image, torch.Tensor):
431
- ref_image_latents = self.images2latents(ref_image, latents_dtype).cuda()
432
 
433
  ref_padding_latents = torch.ones_like(ref_image_latents) * -1
434
  ref_image_latents = torch.cat([ref_padding_latents, ref_image_latents]) if do_classifier_free_guidance else ref_image_latents
 
275
  condition = condition
276
  elif isinstance(condition, np.ndarray):
277
  # suppose input is [0, 255]
278
+ condition = self.images2latents(condition, dtype).to(device)
279
  if do_classifier_free_guidance:
280
  condition_pad = torch.ones_like(condition) * -1
281
  condition = torch.cat([condition_pad, condition])
 
423
  num_actual_inference_steps = num_inference_steps
424
 
425
  if isinstance(ref_image, str):
426
+ ref_image_latents = self.images2latents(
427
+ np.array(Image.open(ref_image).resize((width, height))), latents_dtype
428
+ ).to(device)
429
  elif isinstance(ref_image, np.ndarray):
430
+ ref_image_latents = self.images2latents(ref_image, latents_dtype).to(device)
431
  elif isinstance(ref_image, torch.Tensor):
432
+ ref_image_latents = self.images2latents(ref_image, latents_dtype).to(device)
433
 
434
  ref_padding_latents = torch.ones_like(ref_image_latents) * -1
435
  ref_image_latents = torch.cat([ref_padding_latents, ref_image_latents]) if do_classifier_free_guidance else ref_image_latents