Commit
·
123200e
1
Parent(s):
77656cf
feat: download custom weights from HF Hub at runtime using hf_hub_download
Browse files- infer_full.py +47 -4
infer_full.py
CHANGED
|
@@ -13,6 +13,39 @@ from ref_encoder.reference_unet import ref_unet
|
|
| 13 |
from utils.pipeline import StableHairPipeline
|
| 14 |
from utils.pipeline_cn import StableDiffusionControlNetPipeline
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
def concatenate_images(image_files, output_file, type="pil"):
|
| 17 |
if type == "np":
|
| 18 |
image_files = [Image.fromarray(img) for img in image_files]
|
|
@@ -36,7 +69,8 @@ class StableHair:
|
|
| 36 |
### Load controlnet
|
| 37 |
unet = UNet2DConditionModel.from_pretrained(self.config.pretrained_model_path, subfolder="unet").to(device)
|
| 38 |
controlnet = ControlNetModel.from_unet(unet).to(device)
|
| 39 |
-
|
|
|
|
| 40 |
controlnet.load_state_dict(_state_dict, strict=False)
|
| 41 |
controlnet.to(weight_dtype)
|
| 42 |
|
|
@@ -51,15 +85,24 @@ class StableHair:
|
|
| 51 |
|
| 52 |
### load Hair encoder/adapter
|
| 53 |
self.hair_encoder = ref_unet.from_pretrained(self.config.pretrained_model_path, subfolder="unet").to(device)
|
| 54 |
-
|
|
|
|
| 55 |
self.hair_encoder.load_state_dict(_state_dict, strict=False)
|
| 56 |
self.hair_adapter = adapter_injection(self.pipeline.unet, device=self.device, dtype=torch.float16, use_resampler=False)
|
| 57 |
-
|
|
|
|
| 58 |
self.hair_adapter.load_state_dict(_state_dict, strict=False)
|
| 59 |
|
| 60 |
### load bald converter
|
| 61 |
bald_converter = ControlNetModel.from_unet(unet).to(device)
|
| 62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
bald_converter.load_state_dict(_state_dict, strict=False)
|
| 64 |
bald_converter.to(dtype=weight_dtype)
|
| 65 |
del unet
|
|
|
|
| 13 |
from utils.pipeline import StableHairPipeline
|
| 14 |
from utils.pipeline_cn import StableDiffusionControlNetPipeline
|
| 15 |
|
| 16 |
+
def _resolve_weight(prefix_path: str, filename: str) -> str:
    """Resolve a weight path, downloading from Hugging Face Hub if needed.

    ``prefix_path`` can be either a local directory (e.g. ``./models/stage2``)
    or a hub path like ``Org/Repo/subfolder``. When it looks like a hub path,
    the file is downloaded via ``hf_hub_download`` using repo_id ``Org/Repo``
    and the remaining segments as the subfolder.

    Args:
        prefix_path: Local directory or ``Org/Repo[/subfolder...]`` hub path.
        filename: Name of the weight file to locate.

    Returns:
        Path to the weight file on local disk (either the pre-existing local
        file or the hub download cache location).

    Raises:
        FileNotFoundError: The file is absent locally and ``prefix_path`` is
            not a valid hub path (fewer than two ``/``-separated segments).
        RuntimeError: A hub download was attempted but failed.
    """
    # Try local first — a file already on disk wins over any hub lookup.
    local_path = os.path.join(prefix_path, filename)
    if os.path.exists(local_path):
        return local_path

    # A hub path needs at least "Org/Repo"; anything shorter cannot be
    # downloaded, so fail fast with the accurate exception type instead of
    # letting the broad handler below disguise it as a download failure.
    parts = prefix_path.strip("/").split("/")
    if len(parts) < 2:
        raise FileNotFoundError(
            f"Weight not found locally and not a valid hub path: {prefix_path}/{filename}"
        )

    repo_id = "/".join(parts[:2])
    subfolder = "/".join(parts[2:]) if len(parts) > 2 else None
    try:
        from huggingface_hub import hf_hub_download

        return hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            subfolder=subfolder,
            # hf_hub_download also honors HF_TOKEN on its own; keep the
            # explicit env var this project already uses.
            token=os.environ.get("HUGGINGFACEHUB_API_TOKEN"),
        )
    except Exception as exc:  # noqa: WPS440 — re-raised with context below
        # Chain the cause so the original hub/import error stays visible,
        # and name the actual file instead of a placeholder.
        raise RuntimeError(
            f"Failed to fetch {filename} from hub ({prefix_path}): {exc}"
        ) from exc
|
| 48 |
+
|
| 49 |
def concatenate_images(image_files, output_file, type="pil"):
|
| 50 |
if type == "np":
|
| 51 |
image_files = [Image.fromarray(img) for img in image_files]
|
|
|
|
| 69 |
### Load controlnet
|
| 70 |
unet = UNet2DConditionModel.from_pretrained(self.config.pretrained_model_path, subfolder="unet").to(device)
|
| 71 |
controlnet = ControlNetModel.from_unet(unet).to(device)
|
| 72 |
+
controlnet_weight_path = _resolve_weight(self.config.pretrained_folder, self.config.controlnet_path)
|
| 73 |
+
_state_dict = torch.load(controlnet_weight_path, map_location="cpu")
|
| 74 |
controlnet.load_state_dict(_state_dict, strict=False)
|
| 75 |
controlnet.to(weight_dtype)
|
| 76 |
|
|
|
|
| 85 |
|
| 86 |
### load Hair encoder/adapter
|
| 87 |
self.hair_encoder = ref_unet.from_pretrained(self.config.pretrained_model_path, subfolder="unet").to(device)
|
| 88 |
+
encoder_weight_path = _resolve_weight(self.config.pretrained_folder, self.config.encoder_path)
|
| 89 |
+
_state_dict = torch.load(encoder_weight_path, map_location="cpu")
|
| 90 |
self.hair_encoder.load_state_dict(_state_dict, strict=False)
|
| 91 |
self.hair_adapter = adapter_injection(self.pipeline.unet, device=self.device, dtype=torch.float16, use_resampler=False)
|
| 92 |
+
adapter_weight_path = _resolve_weight(self.config.pretrained_folder, self.config.adapter_path)
|
| 93 |
+
_state_dict = torch.load(adapter_weight_path, map_location="cpu")
|
| 94 |
self.hair_adapter.load_state_dict(_state_dict, strict=False)
|
| 95 |
|
| 96 |
### load bald converter
|
| 97 |
bald_converter = ControlNetModel.from_unet(unet).to(device)
|
| 98 |
+
# bald_converter_path may be a local full path or a hub-like path
|
| 99 |
+
if os.path.exists(self.config.bald_converter_path):
|
| 100 |
+
bald_weight_path = self.config.bald_converter_path
|
| 101 |
+
else:
|
| 102 |
+
prefix, filename = os.path.split(self.config.bald_converter_path)
|
| 103 |
+
bald_weight_path = _resolve_weight(prefix, filename)
|
| 104 |
+
|
| 105 |
+
_state_dict = torch.load(bald_weight_path, map_location="cpu")
|
| 106 |
bald_converter.load_state_dict(_state_dict, strict=False)
|
| 107 |
bald_converter.to(dtype=weight_dtype)
|
| 108 |
del unet
|