HairStable Bot commited on
Commit
43b5bad
·
1 Parent(s): 82b3d02

chore: log schedulers; enforce DDIM in get_bald each call

Browse files
Hair_stable_new_fresh/infer_full.py CHANGED
@@ -1,4 +1,5 @@
1
  import torch
 
2
  from PIL import Image
3
  import numpy as np
4
  from PIL import Image
@@ -63,6 +64,7 @@ def concatenate_images(image_files, output_file, type="pil"):
63
  class StableHair:
64
  def __init__(self, config="./configs/hair_transfer.yaml", device="cuda", weight_dtype=torch.float32) -> None:
65
  print("Initializing Stable Hair Pipeline...")
 
66
  self.config = OmegaConf.load(config)
67
  self.device = device
68
 
@@ -176,6 +178,12 @@ class StableHair:
176
  self.remove_hair_pipeline.scheduler = _DDIM.from_config(self.remove_hair_pipeline.scheduler.config)
177
  except Exception:
178
  pass
 
 
 
 
 
 
179
  image = self.remove_hair_pipeline(
180
  prompt="",
181
  negative_prompt="",
 
1
  import torch
2
+ import logging
3
  from PIL import Image
4
  import numpy as np
5
  from PIL import Image
 
64
  class StableHair:
65
  def __init__(self, config="./configs/hair_transfer.yaml", device="cuda", weight_dtype=torch.float32) -> None:
66
  print("Initializing Stable Hair Pipeline...")
67
+ self.logger = logging.getLogger("hair_model")
68
  self.config = OmegaConf.load(config)
69
  self.device = device
70
 
 
178
  self.remove_hair_pipeline.scheduler = _DDIM.from_config(self.remove_hair_pipeline.scheduler.config)
179
  except Exception:
180
  pass
181
+ # Log scheduler to confirm
182
+ try:
183
+ sched_name = type(self.remove_hair_pipeline.scheduler).__name__
184
+ self.logger.info(f"remove_hair_pipeline scheduler: {sched_name}")
185
+ except Exception:
186
+ pass
187
  image = self.remove_hair_pipeline(
188
  prompt="",
189
  negative_prompt="",
Hair_stable_new_fresh/server.py CHANGED
@@ -180,6 +180,13 @@ def get_hairswap(req: HairSwapRequest, _=Depends(verify_bearer)):
180
  # Perform hair transfer with error handling
181
  try:
182
  LOGGER.info("Starting hair transfer...")
 
 
 
 
 
 
 
183
  id_np, out_np, bald_np, ref_np = model.Hair_Transfer(
184
  source_image=source_path,
185
  reference_image=reference_path,
 
180
  # Perform hair transfer with error handling
181
  try:
182
  LOGGER.info("Starting hair transfer...")
183
+ # Log current schedulers for visibility
184
+ try:
185
+ sched_main = type(model.pipeline.scheduler).__name__ if hasattr(model, "pipeline") else None
186
+ sched_bald = type(model.remove_hair_pipeline.scheduler).__name__ if hasattr(model, "remove_hair_pipeline") else None
187
+ LOGGER.info(f"Schedulers -> main: {sched_main}, remove_hair: {sched_bald}")
188
+ except Exception:
189
+ pass
190
  id_np, out_np, bald_np, ref_np = model.Hair_Transfer(
191
  source_image=source_path,
192
  reference_image=reference_path,