hysts HF staff commited on
Commit
b106068
1 Parent(s): 0add3fc

Apply formatter

Browse files
Files changed (1) hide show
  1. app.py +187 -114
app.py CHANGED
@@ -1,25 +1,22 @@
1
- import os
2
- import cv2
3
  import math
4
- import spaces
5
- import torch
6
  import random
7
- import numpy as np
8
-
9
- import PIL
10
- from PIL import Image
11
 
 
12
  import diffusers
13
- from diffusers.utils import load_image
14
- from diffusers.models import ControlNetModel
15
-
16
  import insightface
 
 
 
 
 
 
17
  from insightface.app import FaceAnalysis
 
18
 
19
- from style_template import styles
20
  from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline
21
-
22
- import gradio as gr
23
 
24
  # global variable
25
  MAX_SEED = np.iinfo(np.int32).max
@@ -29,22 +26,27 @@ DEFAULT_STYLE_NAME = "Watercolor"
29
 
30
  # download checkpoints
31
  from huggingface_hub import hf_hub_download
 
32
  hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/config.json", local_dir="./checkpoints")
33
- hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/diffusion_pytorch_model.safetensors", local_dir="./checkpoints")
 
 
 
 
34
  hf_hub_download(repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir="./checkpoints")
35
 
36
  # Load face encoder
37
- app = FaceAnalysis(name='antelopev2', root='./', providers=['CPUExecutionProvider'])
38
  app.prepare(ctx_id=0, det_size=(640, 640))
39
 
40
  # Path to InstantID models
41
- face_adapter = f'./checkpoints/ip-adapter.bin'
42
- controlnet_path = f'./checkpoints/ControlNetModel'
43
 
44
  # Load pipeline
45
  controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
46
 
47
- base_model_path = 'wangqixun/YamerMIX_v8'
48
 
49
  pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
50
  base_model_path,
@@ -55,54 +57,68 @@ pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
55
  )
56
  pipe.cuda()
57
  pipe.load_ip_adapter_instantid(face_adapter)
58
- pipe.image_proj_model.to('cuda')
59
- pipe.unet.to('cuda')
 
60
 
61
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
62
  if randomize_seed:
63
  seed = random.randint(0, MAX_SEED)
64
  return seed
65
 
 
66
  def swap_to_gallery(images):
67
- return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(visible=False)
 
 
 
 
 
68
 
69
  def upload_example_to_gallery(images, prompt, style, negative_prompt):
70
- return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(visible=False)
 
 
 
 
 
71
 
72
  def remove_back_to_files():
73
  return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
74
 
 
75
  def remove_tips():
76
  return gr.update(visible=False)
77
 
 
78
  def get_example():
79
  case = [
80
  [
81
- ['./examples/yann-lecun_resize.jpg'],
82
  "a man",
83
  "Snow",
84
  "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
85
  ],
86
  [
87
- ['./examples/musk_resize.jpeg'],
88
  "a man",
89
  "Mars",
90
  "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
91
  ],
92
  [
93
- ['./examples/sam_resize.png'],
94
  "a man",
95
  "Jungle",
96
  "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, gree",
97
  ],
98
  [
99
- ['./examples/schmidhuber_resize.png'],
100
  "a man",
101
  "Neon",
102
  "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
103
  ],
104
  [
105
- ['./examples/kaifu_resize.png'],
106
  "a man",
107
  "Vibrant Color",
108
  "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
@@ -110,16 +126,20 @@ def get_example():
110
  ]
111
  return case
112
 
 
113
  def run_for_examples(face_files, prompt, style, negative_prompt):
114
  return generate_image(face_files, None, prompt, negative_prompt, style, True, 30, 0.8, 0.8, 5, 42)
115
 
 
116
  def convert_from_cv2_to_image(img: np.ndarray) -> Image:
117
  return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
118
 
 
119
  def convert_from_image_to_cv2(img: Image) -> np.ndarray:
120
  return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
121
 
122
- def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]):
 
123
  stickwidth = 4
124
  limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
125
  kps = np.array(kps)
@@ -135,7 +155,9 @@ def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,2
135
  y = kps[index][:, 1]
136
  length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
137
  angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
138
- polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
 
 
139
  out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
140
  out_img = (out_img * 0.6).astype(np.uint8)
141
 
@@ -147,89 +169,115 @@ def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,2
147
  out_img_pil = Image.fromarray(out_img.astype(np.uint8))
148
  return out_img_pil
149
 
150
- def resize_img(input_image, max_side=1280, min_side=1024, size=None,
151
- pad_to_max_side=False, mode=PIL.Image.BILINEAR, base_pixel_number=64):
152
-
153
- w, h = input_image.size
154
- if size is not None:
155
- w_resize_new, h_resize_new = size
156
- else:
157
- ratio = min_side / min(h, w)
158
- w, h = round(ratio*w), round(ratio*h)
159
- ratio = max_side / max(h, w)
160
- input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode)
161
- w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
162
- h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
163
- input_image = input_image.resize([w_resize_new, h_resize_new], mode)
164
-
165
- if pad_to_max_side:
166
- res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
167
- offset_x = (max_side - w_resize_new) // 2
168
- offset_y = (max_side - h_resize_new) // 2
169
- res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image)
170
- input_image = Image.fromarray(res)
171
- return input_image
 
 
 
 
 
 
 
 
172
 
173
  def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
174
  p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
175
- return p.replace("{prompt}", positive), n + ' ' + negative
176
 
177
- @spaces.GPU
178
- def generate_image(face_image, pose_image, prompt, negative_prompt, style_name, enhance_face_region, num_steps, identitynet_strength_ratio, adapter_strength_ratio, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
179
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  if face_image is None:
181
  raise gr.Error(f"Cannot find any input face image! Please upload the face image")
182
-
183
  if prompt is None:
184
  prompt = "a person"
185
-
186
  # apply the style template
187
  prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
188
-
189
  face_image = load_image(face_image[0])
190
  face_image = resize_img(face_image)
191
  face_image_cv2 = convert_from_image_to_cv2(face_image)
192
  height, width, _ = face_image_cv2.shape
193
-
194
  # Extract face features
195
  face_info = app.get(face_image_cv2)
196
-
197
  if len(face_info) == 0:
198
  raise gr.Error(f"Cannot find any face in the image! Please upload another person image")
199
-
200
- face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*x['bbox'][3]-x['bbox'][1])[-1] # only use the maximum face
201
- face_emb = face_info['embedding']
202
- face_kps = draw_kps(convert_from_cv2_to_image(face_image_cv2), face_info['kps'])
203
-
 
 
 
 
 
204
  if pose_image is not None:
205
  pose_image = load_image(pose_image[0])
206
  pose_image = resize_img(pose_image)
207
  pose_image_cv2 = convert_from_image_to_cv2(pose_image)
208
-
209
  face_info = app.get(pose_image_cv2)
210
-
211
  if len(face_info) == 0:
212
  raise gr.Error(f"Cannot find any face in the reference image! Please upload another person image")
213
-
214
  face_info = face_info[-1]
215
- face_kps = draw_kps(pose_image, face_info['kps'])
216
-
217
  width, height = face_kps.size
218
-
219
  if enhance_face_region:
220
  control_mask = np.zeros([height, width, 3])
221
- x1, y1, x2, y2 = face_info['bbox']
222
  x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
223
  control_mask[y1:y2, x1:x2] = 255
224
  control_mask = Image.fromarray(control_mask.astype(np.uint8))
225
  else:
226
  control_mask = None
227
-
228
  generator = torch.Generator(device=device).manual_seed(seed)
229
-
230
  print("Start inference...")
231
  print(f"[Debug] Prompt: {prompt}, \n[Debug] Neg Prompt: {negative_prompt}")
232
-
233
  pipe.set_ip_adapter_scale(adapter_strength_ratio)
234
  images = pipe(
235
  prompt=prompt,
@@ -242,11 +290,12 @@ def generate_image(face_image, pose_image, prompt, negative_prompt, style_name,
242
  guidance_scale=guidance_scale,
243
  height=height,
244
  width=width,
245
- generator=generator
246
  ).images
247
 
248
  return images, gr.update(visible=True)
249
 
 
250
  ### Description
251
  title = r"""
252
  <h1 align="center">InstantID: Zero-shot Identity-Preserving Generation in Seconds</h1>
@@ -289,46 +338,44 @@ tips = r"""
289
  4. Find a good base model always makes a difference.
290
  """
291
 
292
- css = '''
293
  .gradio-container {width: 85% !important}
294
- '''
295
  with gr.Blocks(css=css) as demo:
296
-
297
  # description
298
  gr.Markdown(title)
299
  gr.Markdown(description)
300
 
301
  with gr.Row():
302
  with gr.Column():
303
-
304
  # upload face image
305
- face_files = gr.Files(
306
- label="Upload a photo of your face",
307
- file_types=["image"]
308
- )
309
  uploaded_faces = gr.Gallery(label="Your images", visible=False, columns=1, rows=1, height=512)
310
  with gr.Column(visible=False) as clear_button_face:
311
- remove_and_reupload_faces = gr.ClearButton(value="Remove and upload new ones", components=face_files, size="sm")
312
-
 
 
313
  # optional: upload a reference pose image
314
- pose_files = gr.Files(
315
- label="Upload a reference pose image (optional)",
316
- file_types=["image"]
317
- )
318
  uploaded_poses = gr.Gallery(label="Your images", visible=False, columns=1, rows=1, height=512)
319
  with gr.Column(visible=False) as clear_button_pose:
320
- remove_and_reupload_poses = gr.ClearButton(value="Remove and upload new ones", components=pose_files, size="sm")
321
-
 
 
322
  # prompt
323
- prompt = gr.Textbox(label="Prompt",
324
- info="Give simple prompt is enough to achieve good face fedility",
325
- placeholder="A photo of a person",
326
- value="")
327
-
 
 
328
  submit = gr.Button("Submit", variant="primary")
329
-
330
  style = gr.Dropdown(label="Style template", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
331
-
332
  # strength
333
  identitynet_strength_ratio = gr.Slider(
334
  label="IdentityNet strength (for fedility)",
@@ -344,14 +391,14 @@ with gr.Blocks(css=css) as demo:
344
  step=0.05,
345
  value=0.80,
346
  )
347
-
348
  with gr.Accordion(open=False, label="Advanced Options"):
349
  negative_prompt = gr.Textbox(
350
- label="Negative Prompt",
351
  placeholder="low quality",
352
  value="(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
353
  )
354
- num_steps = gr.Slider(
355
  label="Number of sample steps",
356
  minimum=20,
357
  maximum=100,
@@ -377,17 +424,31 @@ with gr.Blocks(css=css) as demo:
377
 
378
  with gr.Column():
379
  gallery = gr.Gallery(label="Generated Images")
380
- usage_tips = gr.Markdown(label="Usage tips of InstantID", value=tips ,visible=False)
381
 
382
- face_files.upload(fn=swap_to_gallery, inputs=face_files, outputs=[uploaded_faces, clear_button_face, face_files])
383
- pose_files.upload(fn=swap_to_gallery, inputs=pose_files, outputs=[uploaded_poses, clear_button_pose, pose_files])
 
 
 
 
 
 
 
 
384
 
385
- remove_and_reupload_faces.click(fn=remove_back_to_files, outputs=[uploaded_faces, clear_button_face, face_files])
386
- remove_and_reupload_poses.click(fn=remove_back_to_files, outputs=[uploaded_poses, clear_button_pose, pose_files])
 
 
 
 
 
 
387
 
388
  submit.click(
389
  fn=remove_tips,
390
- outputs=usage_tips,
391
  ).then(
392
  fn=randomize_seed_fn,
393
  inputs=[seed, randomize_seed],
@@ -396,21 +457,33 @@ with gr.Blocks(css=css) as demo:
396
  api_name=False,
397
  ).then(
398
  fn=generate_image,
399
- inputs=[face_files, pose_files, prompt, negative_prompt, style, enhance_face_region, num_steps, identitynet_strength_ratio, adapter_strength_ratio, guidance_scale, seed],
400
- outputs=[gallery, usage_tips]
 
 
 
 
 
 
 
 
 
 
 
 
401
  )
402
-
403
  gr.Examples(
404
  examples=get_example(),
405
  inputs=[face_files, prompt, style, negative_prompt],
406
  run_on_click=True,
407
  fn=upload_example_to_gallery,
408
  outputs=[uploaded_faces, clear_button_face, face_files],
409
- cache_examples=True
410
  )
411
-
412
  gr.Markdown(article)
413
 
414
 
415
  demo.queue(api_open=False)
416
- demo.launch()
 
 
 
1
  import math
2
+ import os
 
3
  import random
 
 
 
 
4
 
5
+ import cv2
6
  import diffusers
7
+ import gradio as gr
 
 
8
  import insightface
9
+ import numpy as np
10
+ import PIL
11
+ import spaces
12
+ import torch
13
+ from diffusers.models import ControlNetModel
14
+ from diffusers.utils import load_image
15
  from insightface.app import FaceAnalysis
16
+ from PIL import Image
17
 
 
18
  from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline
19
+ from style_template import styles
 
20
 
21
  # global variable
22
  MAX_SEED = np.iinfo(np.int32).max
 
26
 
27
  # download checkpoints
28
  from huggingface_hub import hf_hub_download
29
+
30
  hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/config.json", local_dir="./checkpoints")
31
+ hf_hub_download(
32
+ repo_id="InstantX/InstantID",
33
+ filename="ControlNetModel/diffusion_pytorch_model.safetensors",
34
+ local_dir="./checkpoints",
35
+ )
36
  hf_hub_download(repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir="./checkpoints")
37
 
38
  # Load face encoder
39
+ app = FaceAnalysis(name="antelopev2", root="./", providers=["CPUExecutionProvider"])
40
  app.prepare(ctx_id=0, det_size=(640, 640))
41
 
42
  # Path to InstantID models
43
+ face_adapter = f"./checkpoints/ip-adapter.bin"
44
+ controlnet_path = f"./checkpoints/ControlNetModel"
45
 
46
  # Load pipeline
47
  controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
48
 
49
+ base_model_path = "wangqixun/YamerMIX_v8"
50
 
51
  pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
52
  base_model_path,
 
57
  )
58
  pipe.cuda()
59
  pipe.load_ip_adapter_instantid(face_adapter)
60
+ pipe.image_proj_model.to("cuda")
61
+ pipe.unet.to("cuda")
62
+
63
 
64
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
65
  if randomize_seed:
66
  seed = random.randint(0, MAX_SEED)
67
  return seed
68
 
69
+
70
  def swap_to_gallery(images):
71
+ return (
72
+ gr.update(value=images, visible=True),
73
+ gr.update(visible=True),
74
+ gr.update(visible=False),
75
+ )
76
+
77
 
78
  def upload_example_to_gallery(images, prompt, style, negative_prompt):
79
+ return (
80
+ gr.update(value=images, visible=True),
81
+ gr.update(visible=True),
82
+ gr.update(visible=False),
83
+ )
84
+
85
 
86
  def remove_back_to_files():
87
  return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
88
 
89
+
90
  def remove_tips():
91
  return gr.update(visible=False)
92
 
93
+
94
  def get_example():
95
  case = [
96
  [
97
+ ["./examples/yann-lecun_resize.jpg"],
98
  "a man",
99
  "Snow",
100
  "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
101
  ],
102
  [
103
+ ["./examples/musk_resize.jpeg"],
104
  "a man",
105
  "Mars",
106
  "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
107
  ],
108
  [
109
+ ["./examples/sam_resize.png"],
110
  "a man",
111
  "Jungle",
112
  "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, gree",
113
  ],
114
  [
115
+ ["./examples/schmidhuber_resize.png"],
116
  "a man",
117
  "Neon",
118
  "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
119
  ],
120
  [
121
+ ["./examples/kaifu_resize.png"],
122
  "a man",
123
  "Vibrant Color",
124
  "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
 
126
  ]
127
  return case
128
 
129
+
130
  def run_for_examples(face_files, prompt, style, negative_prompt):
131
  return generate_image(face_files, None, prompt, negative_prompt, style, True, 30, 0.8, 0.8, 5, 42)
132
 
133
+
134
  def convert_from_cv2_to_image(img: np.ndarray) -> Image:
135
  return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
136
 
137
+
138
  def convert_from_image_to_cv2(img: Image) -> np.ndarray:
139
  return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
140
 
141
+
142
+ def draw_kps(image_pil, kps, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]):
143
  stickwidth = 4
144
  limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
145
  kps = np.array(kps)
 
155
  y = kps[index][:, 1]
156
  length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
157
  angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
158
+ polygon = cv2.ellipse2Poly(
159
+ (int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1
160
+ )
161
  out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
162
  out_img = (out_img * 0.6).astype(np.uint8)
163
 
 
169
  out_img_pil = Image.fromarray(out_img.astype(np.uint8))
170
  return out_img_pil
171
 
172
+
173
+ def resize_img(
174
+ input_image,
175
+ max_side=1280,
176
+ min_side=1024,
177
+ size=None,
178
+ pad_to_max_side=False,
179
+ mode=PIL.Image.BILINEAR,
180
+ base_pixel_number=64,
181
+ ):
182
+ w, h = input_image.size
183
+ if size is not None:
184
+ w_resize_new, h_resize_new = size
185
+ else:
186
+ ratio = min_side / min(h, w)
187
+ w, h = round(ratio * w), round(ratio * h)
188
+ ratio = max_side / max(h, w)
189
+ input_image = input_image.resize([round(ratio * w), round(ratio * h)], mode)
190
+ w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
191
+ h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
192
+ input_image = input_image.resize([w_resize_new, h_resize_new], mode)
193
+
194
+ if pad_to_max_side:
195
+ res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
196
+ offset_x = (max_side - w_resize_new) // 2
197
+ offset_y = (max_side - h_resize_new) // 2
198
+ res[offset_y : offset_y + h_resize_new, offset_x : offset_x + w_resize_new] = np.array(input_image)
199
+ input_image = Image.fromarray(res)
200
+ return input_image
201
+
202
 
203
  def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
204
  p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
205
+ return p.replace("{prompt}", positive), n + " " + negative
206
 
 
 
207
 
208
+ @spaces.GPU
209
+ def generate_image(
210
+ face_image,
211
+ pose_image,
212
+ prompt,
213
+ negative_prompt,
214
+ style_name,
215
+ enhance_face_region,
216
+ num_steps,
217
+ identitynet_strength_ratio,
218
+ adapter_strength_ratio,
219
+ guidance_scale,
220
+ seed,
221
+ progress=gr.Progress(track_tqdm=True),
222
+ ):
223
  if face_image is None:
224
  raise gr.Error(f"Cannot find any input face image! Please upload the face image")
225
+
226
  if prompt is None:
227
  prompt = "a person"
228
+
229
  # apply the style template
230
  prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
231
+
232
  face_image = load_image(face_image[0])
233
  face_image = resize_img(face_image)
234
  face_image_cv2 = convert_from_image_to_cv2(face_image)
235
  height, width, _ = face_image_cv2.shape
236
+
237
  # Extract face features
238
  face_info = app.get(face_image_cv2)
239
+
240
  if len(face_info) == 0:
241
  raise gr.Error(f"Cannot find any face in the image! Please upload another person image")
242
+
243
+ face_info = sorted(
244
+ face_info,
245
+ key=lambda x: (x["bbox"][2] - x["bbox"][0]) * x["bbox"][3] - x["bbox"][1],
246
+ )[
247
+ -1
248
+ ] # only use the maximum face
249
+ face_emb = face_info["embedding"]
250
+ face_kps = draw_kps(convert_from_cv2_to_image(face_image_cv2), face_info["kps"])
251
+
252
  if pose_image is not None:
253
  pose_image = load_image(pose_image[0])
254
  pose_image = resize_img(pose_image)
255
  pose_image_cv2 = convert_from_image_to_cv2(pose_image)
256
+
257
  face_info = app.get(pose_image_cv2)
258
+
259
  if len(face_info) == 0:
260
  raise gr.Error(f"Cannot find any face in the reference image! Please upload another person image")
261
+
262
  face_info = face_info[-1]
263
+ face_kps = draw_kps(pose_image, face_info["kps"])
264
+
265
  width, height = face_kps.size
266
+
267
  if enhance_face_region:
268
  control_mask = np.zeros([height, width, 3])
269
+ x1, y1, x2, y2 = face_info["bbox"]
270
  x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
271
  control_mask[y1:y2, x1:x2] = 255
272
  control_mask = Image.fromarray(control_mask.astype(np.uint8))
273
  else:
274
  control_mask = None
275
+
276
  generator = torch.Generator(device=device).manual_seed(seed)
277
+
278
  print("Start inference...")
279
  print(f"[Debug] Prompt: {prompt}, \n[Debug] Neg Prompt: {negative_prompt}")
280
+
281
  pipe.set_ip_adapter_scale(adapter_strength_ratio)
282
  images = pipe(
283
  prompt=prompt,
 
290
  guidance_scale=guidance_scale,
291
  height=height,
292
  width=width,
293
+ generator=generator,
294
  ).images
295
 
296
  return images, gr.update(visible=True)
297
 
298
+
299
  ### Description
300
  title = r"""
301
  <h1 align="center">InstantID: Zero-shot Identity-Preserving Generation in Seconds</h1>
 
338
  4. Find a good base model always makes a difference.
339
  """
340
 
341
+ css = """
342
  .gradio-container {width: 85% !important}
343
+ """
344
  with gr.Blocks(css=css) as demo:
 
345
  # description
346
  gr.Markdown(title)
347
  gr.Markdown(description)
348
 
349
  with gr.Row():
350
  with gr.Column():
 
351
  # upload face image
352
+ face_files = gr.Files(label="Upload a photo of your face", file_types=["image"])
 
 
 
353
  uploaded_faces = gr.Gallery(label="Your images", visible=False, columns=1, rows=1, height=512)
354
  with gr.Column(visible=False) as clear_button_face:
355
+ remove_and_reupload_faces = gr.ClearButton(
356
+ value="Remove and upload new ones", components=face_files, size="sm"
357
+ )
358
+
359
  # optional: upload a reference pose image
360
+ pose_files = gr.Files(label="Upload a reference pose image (optional)", file_types=["image"])
 
 
 
361
  uploaded_poses = gr.Gallery(label="Your images", visible=False, columns=1, rows=1, height=512)
362
  with gr.Column(visible=False) as clear_button_pose:
363
+ remove_and_reupload_poses = gr.ClearButton(
364
+ value="Remove and upload new ones", components=pose_files, size="sm"
365
+ )
366
+
367
  # prompt
368
+ prompt = gr.Textbox(
369
+ label="Prompt",
370
+ info="Give simple prompt is enough to achieve good face fedility",
371
+ placeholder="A photo of a person",
372
+ value="",
373
+ )
374
+
375
  submit = gr.Button("Submit", variant="primary")
376
+
377
  style = gr.Dropdown(label="Style template", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
378
+
379
  # strength
380
  identitynet_strength_ratio = gr.Slider(
381
  label="IdentityNet strength (for fedility)",
 
391
  step=0.05,
392
  value=0.80,
393
  )
394
+
395
  with gr.Accordion(open=False, label="Advanced Options"):
396
  negative_prompt = gr.Textbox(
397
+ label="Negative Prompt",
398
  placeholder="low quality",
399
  value="(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
400
  )
401
+ num_steps = gr.Slider(
402
  label="Number of sample steps",
403
  minimum=20,
404
  maximum=100,
 
424
 
425
  with gr.Column():
426
  gallery = gr.Gallery(label="Generated Images")
427
+ usage_tips = gr.Markdown(label="Usage tips of InstantID", value=tips, visible=False)
428
 
429
+ face_files.upload(
430
+ fn=swap_to_gallery,
431
+ inputs=face_files,
432
+ outputs=[uploaded_faces, clear_button_face, face_files],
433
+ )
434
+ pose_files.upload(
435
+ fn=swap_to_gallery,
436
+ inputs=pose_files,
437
+ outputs=[uploaded_poses, clear_button_pose, pose_files],
438
+ )
439
 
440
+ remove_and_reupload_faces.click(
441
+ fn=remove_back_to_files,
442
+ outputs=[uploaded_faces, clear_button_face, face_files],
443
+ )
444
+ remove_and_reupload_poses.click(
445
+ fn=remove_back_to_files,
446
+ outputs=[uploaded_poses, clear_button_pose, pose_files],
447
+ )
448
 
449
  submit.click(
450
  fn=remove_tips,
451
+ outputs=usage_tips,
452
  ).then(
453
  fn=randomize_seed_fn,
454
  inputs=[seed, randomize_seed],
 
457
  api_name=False,
458
  ).then(
459
  fn=generate_image,
460
+ inputs=[
461
+ face_files,
462
+ pose_files,
463
+ prompt,
464
+ negative_prompt,
465
+ style,
466
+ enhance_face_region,
467
+ num_steps,
468
+ identitynet_strength_ratio,
469
+ adapter_strength_ratio,
470
+ guidance_scale,
471
+ seed,
472
+ ],
473
+ outputs=[gallery, usage_tips],
474
  )
475
+
476
  gr.Examples(
477
  examples=get_example(),
478
  inputs=[face_files, prompt, style, negative_prompt],
479
  run_on_click=True,
480
  fn=upload_example_to_gallery,
481
  outputs=[uploaded_faces, clear_button_face, face_files],
482
+ cache_examples=True,
483
  )
484
+
485
  gr.Markdown(article)
486
 
487
 
488
  demo.queue(api_open=False)
489
+ demo.launch()