LogicGoInfotechSpaces commited on
Commit
123200e
·
1 Parent(s): 77656cf

feat: download custom weights from HF Hub at runtime using hf_hub_download

Browse files
Files changed (1) hide show
  1. infer_full.py +47 -4
infer_full.py CHANGED
@@ -13,6 +13,39 @@ from ref_encoder.reference_unet import ref_unet
13
  from utils.pipeline import StableHairPipeline
14
  from utils.pipeline_cn import StableDiffusionControlNetPipeline
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  def concatenate_images(image_files, output_file, type="pil"):
17
  if type == "np":
18
  image_files = [Image.fromarray(img) for img in image_files]
@@ -36,7 +69,8 @@ class StableHair:
36
  ### Load controlnet
37
  unet = UNet2DConditionModel.from_pretrained(self.config.pretrained_model_path, subfolder="unet").to(device)
38
  controlnet = ControlNetModel.from_unet(unet).to(device)
39
- _state_dict = torch.load(os.path.join(self.config.pretrained_folder, self.config.controlnet_path))
 
40
  controlnet.load_state_dict(_state_dict, strict=False)
41
  controlnet.to(weight_dtype)
42
 
@@ -51,15 +85,24 @@ class StableHair:
51
 
52
  ### load Hair encoder/adapter
53
  self.hair_encoder = ref_unet.from_pretrained(self.config.pretrained_model_path, subfolder="unet").to(device)
54
- _state_dict = torch.load(os.path.join(self.config.pretrained_folder, self.config.encoder_path))
 
55
  self.hair_encoder.load_state_dict(_state_dict, strict=False)
56
  self.hair_adapter = adapter_injection(self.pipeline.unet, device=self.device, dtype=torch.float16, use_resampler=False)
57
- _state_dict = torch.load(os.path.join(self.config.pretrained_folder, self.config.adapter_path))
 
58
  self.hair_adapter.load_state_dict(_state_dict, strict=False)
59
 
60
  ### load bald converter
61
  bald_converter = ControlNetModel.from_unet(unet).to(device)
62
- _state_dict = torch.load(self.config.bald_converter_path)
 
 
 
 
 
 
 
63
  bald_converter.load_state_dict(_state_dict, strict=False)
64
  bald_converter.to(dtype=weight_dtype)
65
  del unet
 
13
  from utils.pipeline import StableHairPipeline
14
  from utils.pipeline_cn import StableDiffusionControlNetPipeline
15
 
16
+ def _resolve_weight(prefix_path: str, filename: str) -> str:
17
+ """Resolve a weight path, downloading from Hugging Face Hub if needed.
18
+
19
+ prefix_path can be either a local directory (e.g., ./models/stage2)
20
+ or a hub path like Org/Repo/subfolder. When it looks like a hub path,
21
+ we download the file via hf_hub_download using repo_id Org/Repo and
22
+ subfolder the remaining segments.
23
+ """
24
+ # Try local first
25
+ local_path = os.path.join(prefix_path, filename)
26
+ if os.path.exists(local_path):
27
+ return local_path
28
+
29
+ # Attempt Hub download
30
+ try:
31
+ from huggingface_hub import hf_hub_download
32
+
33
+ parts = prefix_path.strip("/").split("/")
34
+ if len(parts) >= 2:
35
+ repo_id = "/".join(parts[:2])
36
+ subfolder = "/".join(parts[2:]) if len(parts) > 2 else None
37
+ downloaded = hf_hub_download(
38
+ repo_id=repo_id,
39
+ filename=filename,
40
+ subfolder=subfolder,
41
+ token=os.environ.get("HUGGINGFACEHUB_API_TOKEN"),
42
+ )
43
+ return downloaded
44
+ except Exception as exc: # noqa: WPS440
45
+ raise RuntimeError(f"Failed to fetch {filename} from hub ({prefix_path}): {exc}")
46
+
47
+ raise FileNotFoundError(f"Weight not found locally and not a valid hub path: {prefix_path}/{filename}")
48
+
49
  def concatenate_images(image_files, output_file, type="pil"):
50
  if type == "np":
51
  image_files = [Image.fromarray(img) for img in image_files]
 
69
  ### Load controlnet
70
  unet = UNet2DConditionModel.from_pretrained(self.config.pretrained_model_path, subfolder="unet").to(device)
71
  controlnet = ControlNetModel.from_unet(unet).to(device)
72
+ controlnet_weight_path = _resolve_weight(self.config.pretrained_folder, self.config.controlnet_path)
73
+ _state_dict = torch.load(controlnet_weight_path, map_location="cpu")
74
  controlnet.load_state_dict(_state_dict, strict=False)
75
  controlnet.to(weight_dtype)
76
 
 
85
 
86
  ### load Hair encoder/adapter
87
  self.hair_encoder = ref_unet.from_pretrained(self.config.pretrained_model_path, subfolder="unet").to(device)
88
+ encoder_weight_path = _resolve_weight(self.config.pretrained_folder, self.config.encoder_path)
89
+ _state_dict = torch.load(encoder_weight_path, map_location="cpu")
90
  self.hair_encoder.load_state_dict(_state_dict, strict=False)
91
  self.hair_adapter = adapter_injection(self.pipeline.unet, device=self.device, dtype=torch.float16, use_resampler=False)
92
+ adapter_weight_path = _resolve_weight(self.config.pretrained_folder, self.config.adapter_path)
93
+ _state_dict = torch.load(adapter_weight_path, map_location="cpu")
94
  self.hair_adapter.load_state_dict(_state_dict, strict=False)
95
 
96
  ### load bald converter
97
  bald_converter = ControlNetModel.from_unet(unet).to(device)
98
+ # bald_converter_path may be a local full path or a hub-like path
99
+ if os.path.exists(self.config.bald_converter_path):
100
+ bald_weight_path = self.config.bald_converter_path
101
+ else:
102
+ prefix, filename = os.path.split(self.config.bald_converter_path)
103
+ bald_weight_path = _resolve_weight(prefix, filename)
104
+
105
+ _state_dict = torch.load(bald_weight_path, map_location="cpu")
106
  bald_converter.load_state_dict(_state_dict, strict=False)
107
  bald_converter.to(dtype=weight_dtype)
108
  del unet