revi13 commited on
Commit
74415da
·
verified ·
1 Parent(s): afb5de2

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +10 -4
pipeline.py CHANGED
@@ -58,18 +58,24 @@ class IPFacePlusV2Pipeline:
58
 
59
  # Convert face embedding to torch tensor and match dtype
60
  face_embedding = torch.tensor(faces[0].normed_embedding).to(self.device)
61
- if self.pipe.dtype == torch.float16:
62
- face_embedding = face_embedding.half()
 
 
 
 
 
 
63
 
64
  # Generate image
65
  image = self.ip_adapter.generate(
66
  prompt=prompt,
67
  scale=scale,
68
  faceid_embeds=face_embedding.unsqueeze(0) # バッチ次元を追加
69
- )[0] # 最初の1枚を取得(必要に応じて変更)
70
 
71
  # Convert PIL image to base64
72
  buffered = io.BytesIO()
73
  image.save(buffered, format="PNG")
74
  encoded_img = base64.b64encode(buffered.getvalue()).decode("utf-8")
75
- return f"data:image/png;base64,{encoded_img}"
 
58
 
59
  # Convert face embedding to torch tensor and match dtype
60
  face_embedding = torch.tensor(faces[0].normed_embedding).to(self.device)
61
+
62
+ # 安全な dtype チェック(pipe 自体に dtype が無ければ unet から取得)
63
+ pipe_dtype = getattr(self.pipe, "dtype", getattr(self.pipe, "unet", None).dtype)
64
+
65
+ if pipe_dtype == torch.float16:
66
+ face_embedding = face_embedding.to(dtype=torch.float16)
67
+ else:
68
+ face_embedding = face_embedding.to(dtype=torch.float32)
69
 
70
  # Generate image
71
  image = self.ip_adapter.generate(
72
  prompt=prompt,
73
  scale=scale,
74
  faceid_embeds=face_embedding.unsqueeze(0) # バッチ次元を追加
75
+ )[0] # 最初の1枚を取得
76
 
77
  # Convert PIL image to base64
78
  buffered = io.BytesIO()
79
  image.save(buffered, format="PNG")
80
  encoded_img = base64.b64encode(buffered.getvalue()).decode("utf-8")
81
+ return f"data:image/png;base64,{encoded_img}"