HairStable Bot
commited on
Commit
·
43b5bad
1
Parent(s):
82b3d02
chore: log schedulers; enforce DDIM in get_bald each call
Browse files
Hair_stable_new_fresh/infer_full.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import torch
|
|
|
|
| 2 |
from PIL import Image
|
| 3 |
import numpy as np
|
| 4 |
from PIL import Image
|
|
@@ -63,6 +64,7 @@ def concatenate_images(image_files, output_file, type="pil"):
|
|
| 63 |
class StableHair:
|
| 64 |
def __init__(self, config="./configs/hair_transfer.yaml", device="cuda", weight_dtype=torch.float32) -> None:
|
| 65 |
print("Initializing Stable Hair Pipeline...")
|
|
|
|
| 66 |
self.config = OmegaConf.load(config)
|
| 67 |
self.device = device
|
| 68 |
|
|
@@ -176,6 +178,12 @@ class StableHair:
|
|
| 176 |
self.remove_hair_pipeline.scheduler = _DDIM.from_config(self.remove_hair_pipeline.scheduler.config)
|
| 177 |
except Exception:
|
| 178 |
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
image = self.remove_hair_pipeline(
|
| 180 |
prompt="",
|
| 181 |
negative_prompt="",
|
|
|
|
| 1 |
import torch
|
| 2 |
+
import logging
|
| 3 |
from PIL import Image
|
| 4 |
import numpy as np
|
| 5 |
from PIL import Image
|
|
|
|
| 64 |
class StableHair:
|
| 65 |
def __init__(self, config="./configs/hair_transfer.yaml", device="cuda", weight_dtype=torch.float32) -> None:
|
| 66 |
print("Initializing Stable Hair Pipeline...")
|
| 67 |
+
self.logger = logging.getLogger("hair_model")
|
| 68 |
self.config = OmegaConf.load(config)
|
| 69 |
self.device = device
|
| 70 |
|
|
|
|
| 178 |
self.remove_hair_pipeline.scheduler = _DDIM.from_config(self.remove_hair_pipeline.scheduler.config)
|
| 179 |
except Exception:
|
| 180 |
pass
|
| 181 |
+
# Log scheduler to confirm
|
| 182 |
+
try:
|
| 183 |
+
sched_name = type(self.remove_hair_pipeline.scheduler).__name__
|
| 184 |
+
self.logger.info(f"remove_hair_pipeline scheduler: {sched_name}")
|
| 185 |
+
except Exception:
|
| 186 |
+
pass
|
| 187 |
image = self.remove_hair_pipeline(
|
| 188 |
prompt="",
|
| 189 |
negative_prompt="",
|
Hair_stable_new_fresh/server.py
CHANGED
|
@@ -180,6 +180,13 @@ def get_hairswap(req: HairSwapRequest, _=Depends(verify_bearer)):
|
|
| 180 |
# Perform hair transfer with error handling
|
| 181 |
try:
|
| 182 |
LOGGER.info("Starting hair transfer...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
id_np, out_np, bald_np, ref_np = model.Hair_Transfer(
|
| 184 |
source_image=source_path,
|
| 185 |
reference_image=reference_path,
|
|
|
|
| 180 |
# Perform hair transfer with error handling
|
| 181 |
try:
|
| 182 |
LOGGER.info("Starting hair transfer...")
|
| 183 |
+
# Log current schedulers for visibility
|
| 184 |
+
try:
|
| 185 |
+
sched_main = type(model.pipeline.scheduler).__name__ if hasattr(model, "pipeline") else None
|
| 186 |
+
sched_bald = type(model.remove_hair_pipeline.scheduler).__name__ if hasattr(model, "remove_hair_pipeline") else None
|
| 187 |
+
LOGGER.info(f"Schedulers -> main: {sched_main}, remove_hair: {sched_bald}")
|
| 188 |
+
except Exception:
|
| 189 |
+
pass
|
| 190 |
id_np, out_np, bald_np, ref_np = model.Hair_Transfer(
|
| 191 |
source_image=source_path,
|
| 192 |
reference_image=reference_path,
|