Fucius committed on
Commit b4054cc
1 Parent(s): 234c811

Update app.py

Files changed (1)
  1. app.py +255 -425
app.py CHANGED
@@ -1,37 +1,42 @@
1
  import spaces
2
  import sys
3
  import os
4
- import torch
5
- torch.jit.script = lambda f: f
6
 
7
- import sys
8
  sys.path.append('./')
9
- import argparse
10
- import hashlib
11
- import json
12
- import os.path
13
  import numpy as np
 
 
 
14
  import torch
 
15
  from typing import Tuple, List
16
- from diffusers import DPMSolverMultistepScheduler
17
- from diffusers.models import T2IAdapter
18
- from PIL import Image
19
  import copy
20
- from diffusers import ControlNetModel, StableDiffusionXLPipeline
21
- from insightface.app import FaceAnalysis
22
- import gradio as gr
23
- import random
24
  from PIL import Image, ImageOps
25
  from transformers import DPTFeatureExtractor, DPTForDepthEstimation
26
- from controlnet_aux import OpenposeDetector
27
- from controlnet_aux.open_pose.body import Body
28
- from download import OMG_download
29
-
30
- from inference.models import YOLOWorld
31
- from src.efficientvit.models.efficientvit.sam import EfficientViTSamPredictor
32
- from src.efficientvit.sam_model_zoo import create_sam_model
33
- import supervision as sv
34
35
 
36
  try:
37
  from groundingdino.models import build_model
@@ -44,28 +49,25 @@ try:
44
  except:
45
  print("groundingdino can not be load")
46
 
47
- from src.pipelines.instantid_pipeline import InstantidMultiConceptPipeline
48
- from src.pipelines.instantid_single_pieline import InstantidSingleConceptPipeline
49
  from src.prompt_attention.p2p_attention import AttentionReplace
50
- from src.pipelines.instantid_pipeline import revise_regionally_controlnet_forward
51
- import cv2
52
- import math
53
- import PIL.Image
54
-
55
- from gradio_demo.character_template import styles, lorapath_styles
56
- STYLE_NAMES = list(styles.keys())
57
-
58
 
 
59
 
 
 
 
60
  MAX_SEED = np.iinfo(np.int32).max
61
 
 
62
  title = r"""
63
- <h1 align="center">OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models (OMG + InstantID)</h1>
64
  """
65
 
66
  description = r"""
67
- <b>Official 🤗 Gradio demo</b> for <a href='https://github.com/kongzhecn/OMG/' target='_blank'><b>OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models</b></a>.<be>.<br>
68
- <a href='https://kongzhecn.github.io/omg-project/' target='_blank'><b>[Project]</b></a>.<a href='https://github.com/kongzhecn/OMG/' target='_blank'><b>[Code]</b></a>.<a href='https://arxiv.org/abs/2403.10983/' target='_blank'><b>[Arxiv]</b></a>.<br>
69
  How to use:<br>
70
  1. Select two characters.
71
  2. Enter a text prompt as done in normal text-to-image models.
@@ -97,56 +99,26 @@ css = '''
97
  .gradio-container {width: 85% !important}
98
  '''
99
 
100
-
101
-
102
- def build_dino_segment_model(ckpt_repo_id, sam_checkpoint):
103
- ckpt_filenmae = "groundingdino_swinb_cogcoor.pth"
104
- ckpt_config_filename = os.path.join(ckpt_repo_id, "GroundingDINO_SwinB.cfg.py")
105
- groundingdino_model = load_model_hf(ckpt_repo_id, ckpt_filenmae, ckpt_config_filename)
106
- sam = build_sam(checkpoint=sam_checkpoint)
107
- sam.cuda()
108
- sam_predictor = SamPredictor(sam)
109
- return groundingdino_model, sam_predictor
110
-
111
- def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
112
- args = SLConfig.fromfile(ckpt_config_filename)
113
- model = build_model(args)
114
- args.device = device
115
-
116
- checkpoint = torch.load(os.path.join(repo_id, filename), map_location='cpu')
117
- log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
118
- print("Model loaded from {} \n => {}".format(filename, log))
119
- _ = model.eval()
120
- return model
121
-
122
- def build_yolo_segment_model(sam_path, device):
123
- yolo_world = YOLOWorld(model_id="yolo_world/l")
124
- sam = EfficientViTSamPredictor(
125
- create_sam_model(name="xl1", weight_url=sam_path).to(device).eval()
126
- )
127
- return yolo_world, sam
128
-
129
  def sample_image(pipe,
130
  input_prompt,
131
  input_neg_prompt=None,
132
  generator=None,
133
  concept_models=None,
134
  num_inference_steps=50,
135
- guidance_scale=3.0,
136
  controller=None,
137
- face_app=None,
138
- image=None,
139
  stage=None,
140
  region_masks=None,
141
- controlnet_conditioning_scale=None,
 
142
  **extra_kargs
143
  ):
144
 
145
- if image is not None:
146
- image_condition = [image]
 
147
  else:
148
- image_condition = None
149
-
150
 
151
  images = pipe(
152
  prompt=input_prompt,
@@ -157,12 +129,13 @@ def sample_image(pipe,
157
  num_inference_steps=num_inference_steps,
158
  cross_attention_kwargs={"scale": 0.8},
159
  controller=controller,
160
- image=image_condition,
161
- face_app=face_app,
162
  stage=stage,
163
- controlnet_conditioning_scale = controlnet_conditioning_scale,
164
  region_masks=region_masks,
 
 
 
165
  **extra_kargs).images
 
166
  return images
167
 
168
  def load_image_yoloworld(image_source) -> Tuple[np.array, torch.Tensor]:
@@ -181,37 +154,6 @@ def load_image_dino(image_source) -> Tuple[np.array, torch.Tensor]:
181
  image_transformed, _ = transform(image_source, None)
182
  return image, image_transformed
183
 
184
- def draw_kps_multi(image_pil, kps_list, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]):
185
- stickwidth = 4
186
- limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
187
-
188
-
189
- w, h = image_pil.size
190
- out_img = np.zeros([h, w, 3])
191
-
192
- for kps in kps_list:
193
- kps = np.array(kps)
194
- for i in range(len(limbSeq)):
195
- index = limbSeq[i]
196
- color = color_list[index[0]]
197
-
198
- x = kps[index][:, 0]
199
- y = kps[index][:, 1]
200
- length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
201
- angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
202
- polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0,
203
- 360, 1)
204
- out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
205
- out_img = (out_img * 0.6).astype(np.uint8)
206
-
207
- for idx_kp, kp in enumerate(kps):
208
- color = color_list[idx_kp]
209
- x, y = kp
210
- out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)
211
-
212
- out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8))
213
- return out_img_pil
214
-
215
  def predict_mask(segmentmodel, sam, image, TEXT_PROMPT, segmentType, confidence = 0.2, threshold = 0.5):
216
  if segmentType=='GroundingDINO':
217
  image_source, image = load_image_dino(image)
@@ -236,68 +178,19 @@ def predict_mask(segmentmodel, sam, image, TEXT_PROMPT, segmentType, confidence
236
  masks=masks[0].squeeze(0)
237
  else:
238
  image_source = load_image_yoloworld(image)
239
- segmentmodel.set_classes(TEXT_PROMPT)
240
  results = segmentmodel.infer(image_source, confidence=confidence)
241
  detections = sv.Detections.from_inference(results).with_nms(
242
  class_agnostic=True, threshold=threshold
243
  )
244
 
245
- masks_list = []
246
- sam.set_image(image_source, image_format="RGB")
247
- for xyxy in detections.xyxy:
248
- mask, _, _ = sam.predict(box=xyxy, multimask_output=False)
249
- masks_list.append(mask.squeeze())
250
- detections.mask = np.array(masks_list)
251
-
252
- mask_1 = []
253
- mask_2 = []
254
- for i, (class_id, confidence) in enumerate(zip(detections.class_id, detections.confidence)):
255
- if class_id==0:
256
- mask_1.append(torch.from_numpy(detections.mask[i]))
257
- if class_id==1:
258
- mask_2.append(torch.from_numpy(detections.mask[i]))
259
- if len(mask_1)==0:
260
- mask_1.append(None)
261
- if len(mask_2)==0:
262
- mask_2.append(None)
263
- if len(TEXT_PROMPT)==2:
264
- return mask_1[0], mask_2[0]
265
-
266
- return mask_1[0]
267
-
268
- def build_model_sd(pretrained_model, controlnet_path, face_adapter, device, prompts, antelopev2_path, width, height, style_lora):
269
- controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
270
- pipe = InstantidMultiConceptPipeline.from_pretrained(
271
- pretrained_model, controlnet=controlnet, torch_dtype=torch.float16, variant="fp16").to(device)
272
-
273
- controller = AttentionReplace(prompts, 50, cross_replace_steps={"default_": 1.},
274
- self_replace_steps=0.4, tokenizer=pipe.tokenizer, device=device, width=width, height=height,
275
- dtype=torch.float16)
276
- revise_regionally_controlnet_forward(pipe.unet, controller)
277
-
278
- controlnet_concept = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
279
- pipe_concept = InstantidSingleConceptPipeline.from_pretrained(
280
- pretrained_model,
281
- controlnet=controlnet_concept,
282
- torch_dtype=torch.float16
283
- )
284
- pipe_concept.load_ip_adapter_instantid(face_adapter)
285
- pipe_concept.set_ip_adapter_scale(0.8)
286
- pipe_concept.to(device)
287
- pipe_concept.image_proj_model.to(pipe_concept._execution_device)
288
-
289
- if style_lora is not None and os.path.exists(style_lora):
290
- pipe.load_lora_weights(style_lora, weight_name="pytorch_lora_weights.safetensors", adapter_name='style')
291
- pipe_concept.load_lora_weights(style_lora, weight_name="pytorch_lora_weights.safetensors", adapter_name='style')
292
-
293
-
294
- # modify
295
- app = FaceAnalysis(name='antelopev2', root=antelopev2_path,
296
- providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
297
- app.prepare(ctx_id=0, det_size=(640, 640))
298
-
299
- return pipe, controller, pipe_concept, app
300
-
301
 
302
  def prepare_text(prompt, region_prompts):
303
  '''
@@ -315,28 +208,74 @@ def prepare_text(prompt, region_prompts):
315
  for region in regions:
316
  if region == '':
317
  break
318
- prompt_region, neg_prompt_region, ref_img = region.split('-*-')
319
  prompt_region = prompt_region.replace('[', '').replace(']', '')
320
  neg_prompt_region = neg_prompt_region.replace('[', '').replace(']', '')
321
 
322
- region_collection.append((prompt_region, neg_prompt_region, ref_img))
323
  return (prompt, region_collection)
324
 
325
- def build_model_lora(pipe, pipe_concept, style_path, condition, condition_img):
326
- if condition == "Human pose" and condition_img is not None:
 
327
  controlnet = ControlNetModel.from_pretrained(args.openpose_checkpoint, torch_dtype=torch.float16).to(device)
328
- pipe.controlnet2 = controlnet
329
- elif condition == "Canny Edge" and condition_img is not None:
330
  controlnet = ControlNetModel.from_pretrained(args.canny_checkpoint, torch_dtype=torch.float16, variant="fp16").to(device)
331
- pipe.controlnet2 = controlnet
332
- elif condition == "Depth" and condition_img is not None:
333
  controlnet = ControlNetModel.from_pretrained(args.depth_checkpoint, torch_dtype=torch.float16).to(device)
334
- pipe.controlnet2 = controlnet
335
 
336
  if style_path is not None and os.path.exists(style_path):
337
  pipe_concept.load_lora_weights(style_path, weight_name="pytorch_lora_weights.safetensors", adapter_name='style')
338
  pipe.load_lora_weights(style_path, weight_name="pytorch_lora_weights.safetensors", adapter_name='style')
339
 
 
 
 
 
 
 
340
  def resize_and_center_crop(image, output_size=(1024, 576)):
341
  width, height = image.size
342
  aspect_ratio = width / height
@@ -364,14 +303,12 @@ def resize_and_center_crop(image, output_size=(1024, 576)):
364
  return cropped_image
365
 
366
  def main(device, segment_type):
367
- pipe, controller, pipe_concepts, face_app = build_model_sd(args.pretrained_model, args.controlnet_path,
368
- args.face_adapter_path, device, prompts_tmp,
369
- args.antelopev2_path, width // 32, height // 32,
370
- args.style_lora)
371
- if segment_type == 'GroundingDINO':
372
- detect_model, sam = build_dino_segment_model(args.dino_checkpoint, args.sam_checkpoint)
373
- else:
374
- detect_model, sam = build_yolo_segment_model(args.efficientViT_checkpoint, device)
375
 
376
  resolution_list = ["1440*728",
377
  "1344*768",
@@ -391,12 +328,8 @@ def main(device, segment_type):
391
 
392
  depth_estimator = DPTForDepthEstimation.from_pretrained(args.dpt_checkpoint).to("cuda")
393
  feature_extractor = DPTFeatureExtractor.from_pretrained(args.dpt_checkpoint)
394
- body_model = Body(args.pose_detector_checkpoint)
395
- openpose = OpenposeDetector(body_model)
396
-
397
- prompts_rewrite = [args.prompt_rewrite]
398
- input_prompt_test = [prepare_text(p, p_w) for p, p_w in zip(prompts, prompts_rewrite)]
399
- input_prompt_test = [prompts, input_prompt_test[0][1]]
400
 
401
  def remove_tips():
402
  return gr.update(visible=False)
@@ -438,270 +371,167 @@ def main(device, segment_type):
438
  return image
439
 
440
  @spaces.GPU
441
- def generate_image(prompt1, negative_prompt, reference_1, reference_2, resolution, local_prompt1, local_prompt2, seed, style, identitynet_strength_ratio, adapter_strength_ratio, condition, condition_img, controlnet_ratio):
442
- identitynet_strength_ratio = float(identitynet_strength_ratio)
443
- adapter_strength_ratio = float(adapter_strength_ratio)
444
- controlnet_ratio = float(controlnet_ratio)
445
- if lorapath_styles[style] is not None and os.path.exists(lorapath_styles[style]):
446
- styleL = True
447
- else:
448
- styleL = False
449
-
450
- width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
451
- kwargs = {
452
- 'height': height,
453
- 'width': width,
454
- 't2i_controlnet_conditioning_scale': controlnet_ratio,
455
- }
456
-
457
- if condition == 'Human pose' and condition_img is not None:
458
- index = ratio_list.index(
459
- min(ratio_list, key=lambda x: abs(x - condition_img.shape[1] / condition_img.shape[0])))
460
- resolution = resolution_list[index]
461
- width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
462
- kwargs['height'] = height
463
- kwargs['width'] = width
464
- condition_img = resize_and_center_crop(Image.fromarray(condition_img), (width, height))
465
- spatial_condition = get_humanpose(condition_img)
466
- elif condition == 'Canny Edge' and condition_img is not None:
467
- index = ratio_list.index(
468
- min(ratio_list, key=lambda x: abs(x - condition_img.shape[1] / condition_img.shape[0])))
469
- resolution = resolution_list[index]
470
- width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
471
- kwargs['height'] = height
472
- kwargs['width'] = width
473
- condition_img = resize_and_center_crop(Image.fromarray(condition_img), (width, height))
474
- spatial_condition = get_cannyedge(condition_img)
475
- elif condition == 'Depth' and condition_img is not None:
476
- index = ratio_list.index(
477
- min(ratio_list, key=lambda x: abs(x - condition_img.shape[1] / condition_img.shape[0])))
478
- resolution = resolution_list[index]
479
- width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
480
- kwargs['height'] = height
481
- kwargs['width'] = width
482
- condition_img = resize_and_center_crop(Image.fromarray(condition_img), (width, height))
483
- spatial_condition = get_depth(condition_img)
484
- else:
485
- spatial_condition = None
486
-
487
- kwargs['t2i_image'] = spatial_condition
488
- pipe.unload_lora_weights()
489
- pipe_concepts.unload_lora_weights()
490
- build_model_lora(pipe, pipe_concepts, lorapath_styles[style], condition, condition_img)
491
- pipe_concepts.set_ip_adapter_scale(adapter_strength_ratio)
492
-
493
- input_list = [prompt1]
494
-
495
-
496
- for prompt in input_list:
497
- if prompt != '':
498
- input_prompt = []
499
- p = '{prompt}, 35mm photograph, film, professional, 4k, highly detailed.'
500
- if styleL:
501
- p = styles[style] + p
502
- input_prompt.append([p.replace('{prompt}', prompt), p.replace("{prompt}", prompt)])
503
- if styleL:
504
- input_prompt.append([(styles[style] + local_prompt1, 'noisy, blurry, soft, deformed, ugly',
505
- PIL.Image.fromarray(reference_1)),
506
- (styles[style] + local_prompt2, 'noisy, blurry, soft, deformed, ugly',
507
- PIL.Image.fromarray(reference_2))])
508
- else:
509
- input_prompt.append(
510
- [(local_prompt1, 'noisy, blurry, soft, deformed, ugly', PIL.Image.fromarray(reference_1)),
511
- (local_prompt2, 'noisy, blurry, soft, deformed, ugly', PIL.Image.fromarray(reference_2))])
512
-
513
-
514
- controller.reset()
515
- image = sample_image(
516
- pipe,
517
- input_prompt=input_prompt,
518
- concept_models=pipe_concepts,
519
- input_neg_prompt=[negative_prompt] * len(input_prompt),
520
- generator=torch.Generator(device).manual_seed(seed),
521
- controller=controller,
522
- face_app=face_app,
523
- controlnet_conditioning_scale=identitynet_strength_ratio,
524
- stage=1,
525
- **kwargs)
526
-
527
- controller.reset()
528
-
529
- if (pipe.tokenizer("man")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]) and (
530
- pipe.tokenizer("woman")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]):
531
- mask1, mask2 = predict_mask(detect_model, sam, image[0], ['man', 'woman'], args.segment_type, confidence=0.05,
532
- threshold=0.5)
533
-
534
- elif pipe.tokenizer("man")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]:
535
- mask1 = predict_mask(detect_model, sam, image[0], ['man'], args.segment_type, confidence=0.05,
536
- threshold=0.5)
537
- mask2 = None
538
-
539
- elif pipe.tokenizer("woman")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]:
540
- mask2 = predict_mask(detect_model, sam, image[0], ['woman'], args.segment_type, confidence=0.05,
541
- threshold=0.5)
542
- mask1 = None
543
- else:
544
- mask1 = mask2 = None
545
 
546
- if mask1 is not None or mask2 is not None:
547
- face_info = face_app.get(cv2.cvtColor(np.array(image[0]), cv2.COLOR_RGB2BGR))
548
- face_kps = draw_kps_multi(image[0], [face['kps'] for face in face_info])
549
 
 
 
550
  image = sample_image(
551
  pipe,
552
  input_prompt=input_prompt,
553
- concept_models=pipe_concepts,
554
  input_neg_prompt=[negative_prompt] * len(input_prompt),
555
  generator=torch.Generator(device).manual_seed(seed),
556
  controller=controller,
557
- face_app=face_app,
558
- image=face_kps,
559
- stage=2,
560
- controlnet_conditioning_scale=identitynet_strength_ratio,
561
- region_masks=[mask1, mask2],
562
  **kwargs)
563
 
564
- # return [image[1], spatial_condition]
565
- return image
566
-
567
- with gr.Blocks(css=css) as demo:
568
- # description
569
- gr.Markdown(title)
570
- gr.Markdown(description)
571
-
572
- with gr.Row():
573
- gallery = gr.Image(label="Generated Images", height=512, width=512)
574
- gallery1 = gr.Image(label="Generated Images", height=512, width=512)
575
- usage_tips = gr.Markdown(label="Usage tips of OMG", value=tips, visible=False)
576
-
577
-
578
- with gr.Row():
579
- reference_1 = gr.Image(label="Input an RGB image for Character man", height=128, width=128)
580
- reference_2 = gr.Image(label="Input an RGB image for Character woman", height=128, width=128)
581
- condition_img1 = gr.Image(label="Input an RGB image for condition (Optional)", height=128, width=128)
582
-
583
-
584
-
585
-
586
- with gr.Row():
587
- local_prompt1 = gr.Textbox(label="Character1_prompt",
588
- info="Describe the Character 1",
589
- value="Close-up photo of the a man, 35mm photograph, professional, 4k, highly detailed.")
590
- local_prompt2 = gr.Textbox(label="Character2_prompt",
591
- info="Describe the Character 2",
592
- value="Close-up photo of the a woman, 35mm photograph, professional, 4k, highly detailed.")
593
- with gr.Row():
594
- identitynet_strength_ratio = gr.Slider(
595
- label="IdentityNet strength (for fidelity)",
596
- minimum=0,
597
- maximum=1.5,
598
- step=0.05,
599
- value=0.80,
600
- )
601
- adapter_strength_ratio = gr.Slider(
602
- label="Image adapter strength (for detail)",
603
- minimum=0,
604
- maximum=1.5,
605
- step=0.05,
606
- value=0.80,
607
- )
608
- controlnet_ratio = gr.Slider(
609
- label="ControlNet strength",
610
- minimum=0,
611
- maximum=1.5,
612
- step=0.05,
613
- value=1,
614
- )
615
- resolution = gr.Dropdown(label="Image Resolution (width*height)", choices=resolution_list,
616
- value="1024*1024")
617
- style = gr.Dropdown(label="style", choices=STYLE_NAMES, value="None")
618
- condition = gr.Dropdown(label="Input condition type", choices=condition_list, value="None")
619
-
620
-
621
- # prompt
622
- with gr.Column():
623
- prompt = gr.Textbox(label="Prompt 1",
624
- info="Give a simple prompt to describe the first image content",
625
- placeholder="Required",
626
- value="close-up shot, photography, a man and a woman on the street, facing the camera smiling")
627
-
628
-
629
- with gr.Accordion(open=False, label="Advanced Options"):
630
- seed = gr.Slider(
631
- label="Seed",
632
- minimum=0,
633
- maximum=MAX_SEED,
634
- step=1,
635
- value=42,
636
- )
637
- negative_prompt = gr.Textbox(label="Negative Prompt",
638
- placeholder="noisy, blurry, soft, deformed, ugly",
639
- value="noisy, blurry, soft, deformed, ugly")
640
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
641
-
642
- submit = gr.Button("Submit", variant="primary")
643
-
644
- submit.click(
645
- fn=remove_tips,
646
- outputs=usage_tips,
647
- ).then(
648
- fn=randomize_seed_fn,
649
- inputs=[seed, randomize_seed],
650
- outputs=seed,
651
- queue=False,
652
- api_name=False,
653
- ).then(
654
- fn=generate_image,
655
- inputs=[prompt, negative_prompt, reference_1, reference_2, resolution, local_prompt1, local_prompt2, seed, style, identitynet_strength_ratio, adapter_strength_ratio, condition, condition_img1, controlnet_ratio],
656
- outputs=[gallery, gallery1]
657
- )
658
- demo.launch(server_name='0.0.0.0',server_port=7861, debug=True)
659
 
660
  def parse_args():
661
  parser = argparse.ArgumentParser('', add_help=False)
662
- parser.add_argument('--pretrained_model', default='stablediffusionapi/realism-engine-sdxl-v30', type=str)
663
- parser.add_argument('--controlnet_path', default='./checkpoint/InstantID/ControlNetModel', type=str)
664
- parser.add_argument('--face_adapter_path', default='./checkpoint/InstantID/ip-adapter.bin', type=str)
665
  parser.add_argument('--openpose_checkpoint', default='thibaud/controlnet-openpose-sdxl-1.0', type=str)
666
  parser.add_argument('--canny_checkpoint', default='diffusers/controlnet-canny-sdxl-1.0', type=str)
667
  parser.add_argument('--depth_checkpoint', default='diffusers/controlnet-depth-sdxl-1.0', type=str)
668
- parser.add_argument('--dpt_checkpoint', default='Intel/dpt-hybrid-midas', type=str)
669
- parser.add_argument('--pose_detector_checkpoint',
670
- default='./checkpoint/ControlNet/annotator/ckpts/body_pose_model.pth', type=str)
671
- parser.add_argument('--efficientViT_checkpoint', default='./checkpoint/sam/xl1.pt', type=str)
672
  parser.add_argument('--dino_checkpoint', default='./checkpoint/GroundingDINO', type=str)
673
  parser.add_argument('--sam_checkpoint', default='./checkpoint/sam/sam_vit_h_4b8939.pth', type=str)
674
- parser.add_argument('--antelopev2_path', default='./checkpoint/antelopev2', type=str)
675
- parser.add_argument('--save_dir', default='results/instantID', type=str)
676
- parser.add_argument('--prompt', default='Close-up photo of the cool man and beautiful woman as they accidentally discover a mysterious island while on vacation by the sea, facing the camera smiling, 35mm photograph, film, professional, 4k, highly detailed.', type=str)
677
  parser.add_argument('--negative_prompt', default='noisy, blurry, soft, deformed, ugly', type=str)
678
- parser.add_argument('--prompt_rewrite',
679
- default='[Close-up photo of a man, 35mm photograph, professional, 4k, highly detailed.]-*'
680
- '-[noisy, blurry, soft, deformed, ugly]-*-'
681
- '../example/chris-evans.jpg|'
682
- '[Close-up photo of a woman, 35mm photograph, professional, 4k, highly detailed.]-'
683
- '*-[noisy, blurry, soft, deformed, ugly]-*-'
684
- '../example/TaylorSwift.png',
685
- type=str)
686
- parser.add_argument('--seed', default=0, type=int)
687
  parser.add_argument('--suffix', default='', type=str)
688
  parser.add_argument('--segment_type', default='yoloworld', help='GroundingDINO or yoloworld', type=str)
689
- parser.add_argument('--style_lora', default='', type=str)
690
  return parser.parse_args()
691
 
692
  if __name__ == '__main__':
693
  args = parse_args()
694
 
695
- prompts = [args.prompt] * 2
696
-
697
  prompts_tmp = copy.deepcopy(prompts)
698
-
699
- width, height = 1024, 1024
700
- kwargs = {
701
- 'height': height,
702
- 'width': width,
703
- }
704
-
705
  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
706
- main(device, args.segment_type)
707
-
 
1
  import spaces
2
  import sys
3
  import os
 
 
4
 
5
+
6
+
7
+ # os.system(f"git clone https://github.com/Curt-Park/yolo-world-with-efficientvit-sam.git")
8
+ # cwd0 = os.getcwd()
9
+ # cwd1 = os.path.join(cwd0, "yolo-world-with-efficientvit-sam")
10
+ # os.chdir(cwd1)
11
+ # os.system("make setup")
12
+ # os.system(f"cd /home/user/app")
13
+
14
  sys.path.append('./')
15
+ import gradio as gr
16
+ import random
 
 
17
  import numpy as np
18
+ from gradio_demo.character_template import character_man, lorapath_man
19
+ from gradio_demo.character_template import character_woman, lorapath_woman
20
+ from gradio_demo.character_template import styles, lorapath_styles
21
  import torch
22
+ import os
23
  from typing import Tuple, List
 
 
 
24
  import copy
25
+ import argparse
26
+ from diffusers.utils import load_image
27
+ import cv2
 
28
  from PIL import Image, ImageOps
29
  from transformers import DPTFeatureExtractor, DPTForDepthEstimation
30
+ # from controlnet_aux import OpenposeDetector
31
+ # from controlnet_aux.open_pose.body import Body
 
 
 
 
 
 
32
 
33
+ try:
34
+ from inference.models import YOLOWorld
35
+ from src.efficientvit.models.efficientvit.sam import EfficientViTSamPredictor
36
+ from src.efficientvit.sam_model_zoo import create_sam_model
37
+ import supervision as sv
38
+ except:
39
+ print("YoloWorld can not be load")
40
 
41
  try:
42
  from groundingdino.models import build_model
 
49
  except:
50
  print("groundingdino can not be load")
51
 
52
+ from src.pipelines.lora_pipeline import LoraMultiConceptPipeline
 
53
  from src.prompt_attention.p2p_attention import AttentionReplace
54
+ from diffusers import ControlNetModel, StableDiffusionXLPipeline
55
+ from src.pipelines.lora_pipeline import revise_regionally_controlnet_forward
 
 
 
 
 
 
56
 
57
+ from download import OMG_download
58
 
59
+ CHARACTER_MAN_NAMES = list(character_man.keys())
60
+ CHARACTER_WOMAN_NAMES = list(character_woman.keys())
61
+ STYLE_NAMES = list(styles.keys())
62
  MAX_SEED = np.iinfo(np.int32).max
63
 
64
+ ### Description
65
  title = r"""
66
+ <h1 align="center">OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models</h1>
67
  """
68
 
69
  description = r"""
70
+ <b>Official 🤗 Gradio demo</b> for <a href='https://github.com/' target='_blank'><b>OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models</b></a>.<br>
 
71
  How to use:<br>
72
  1. Select two characters.
73
  2. Enter a text prompt as done in normal text-to-image models.
 
99
  .gradio-container {width: 85% !important}
100
  '''
101
 
 
102
  def sample_image(pipe,
103
  input_prompt,
104
  input_neg_prompt=None,
105
  generator=None,
106
  concept_models=None,
107
  num_inference_steps=50,
108
+ guidance_scale=7.5,
109
  controller=None,
 
 
110
  stage=None,
111
  region_masks=None,
112
+ lora_list = None,
113
+ styleL=None,
114
  **extra_kargs
115
  ):
116
 
117
+ spatial_condition = extra_kargs.pop('spatial_condition')
118
+ if spatial_condition is not None:
119
+ spatial_condition_input = [spatial_condition] * len(input_prompt)
120
  else:
121
+ spatial_condition_input = None
 
122
 
123
  images = pipe(
124
  prompt=input_prompt,
 
129
  num_inference_steps=num_inference_steps,
130
  cross_attention_kwargs={"scale": 0.8},
131
  controller=controller,
 
 
132
  stage=stage,
 
133
  region_masks=region_masks,
134
+ lora_list=lora_list,
135
+ styleL=styleL,
136
+ image=spatial_condition_input,
137
  **extra_kargs).images
138
+
139
  return images
140
 
141
  def load_image_yoloworld(image_source) -> Tuple[np.array, torch.Tensor]:
 
154
  image_transformed, _ = transform(image_source, None)
155
  return image, image_transformed
156
 
 
 
157
  def predict_mask(segmentmodel, sam, image, TEXT_PROMPT, segmentType, confidence = 0.2, threshold = 0.5):
158
  if segmentType=='GroundingDINO':
159
  image_source, image = load_image_dino(image)
 
178
  masks=masks[0].squeeze(0)
179
  else:
180
  image_source = load_image_yoloworld(image)
181
+ segmentmodel.set_classes([TEXT_PROMPT])
182
  results = segmentmodel.infer(image_source, confidence=confidence)
183
  detections = sv.Detections.from_inference(results).with_nms(
184
  class_agnostic=True, threshold=threshold
185
  )
186
+ masks = None
187
+ if len(detections) != 0:
188
+ print(TEXT_PROMPT + " detected!")
189
+ sam.set_image(image_source, image_format="RGB")
190
+ masks, _, _ = sam.predict(box=detections.xyxy[0], multimask_output=False)
191
+ masks = torch.from_numpy(masks.squeeze())
192
 
193
+ return masks
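
Note: in this revision the YOLO-World branch of predict_mask takes a single class string and returns a single mask tensor (or None when nothing is detected), rather than the old man/woman mask pair. A minimal usage sketch, assuming detect_model and sam come from build_yolo_segment_model defined later in this file and image is the PIL output of the stage-1 pipeline (these names are illustrative, not part of the commit):

# Hypothetical call of the updated predict_mask defined above; any segmentType
# other than 'GroundingDINO' routes to the YOLO-World + EfficientViT-SAM branch.
mask_man = predict_mask(detect_model, sam, image, 'man', 'yoloworld',
                        confidence=0.15, threshold=0.5)
if mask_man is not None:
    print(mask_man.shape)  # torch tensor mask at the image resolution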
 
 
 
 
 
 
194
 
195
  def prepare_text(prompt, region_prompts):
196
  '''
 
208
  for region in regions:
209
  if region == '':
210
  break
211
+ prompt_region, neg_prompt_region = region.split('-*-')
212
  prompt_region = prompt_region.replace('[', '').replace(']', '')
213
  neg_prompt_region = neg_prompt_region.replace('[', '').replace(']', '')
214
 
215
+ region_collection.append((prompt_region, neg_prompt_region))
216
  return (prompt, region_collection)
217
 
218
+
219
+ def build_model_sd(pretrained_model, controlnet_path, device, prompts):
220
+ controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16).to(device)
221
+ pipe = LoraMultiConceptPipeline.from_pretrained(
222
+ pretrained_model, controlnet=controlnet, torch_dtype=torch.float16, variant="fp16").to(device)
223
+ controller = AttentionReplace(prompts, 50, cross_replace_steps={"default_": 1.}, self_replace_steps=0.4, tokenizer=pipe.tokenizer, device=device, dtype=torch.float16, width=1024//32, height=1024//32)
224
+ revise_regionally_controlnet_forward(pipe.unet, controller)
225
+ pipe_concept = StableDiffusionXLPipeline.from_pretrained(pretrained_model, torch_dtype=torch.float16,
226
+ variant="fp16").to(device)
227
+ return pipe, controller, pipe_concept
228
+
229
+ def build_model_lora(pipe_concept, lora_paths, style_path, condition, args, pipe):
230
+ pipe_list = []
231
+ if condition == "Human pose":
232
  controlnet = ControlNetModel.from_pretrained(args.openpose_checkpoint, torch_dtype=torch.float16).to(device)
233
+ pipe.controlnet = controlnet
234
+ elif condition == "Canny Edge":
235
  controlnet = ControlNetModel.from_pretrained(args.canny_checkpoint, torch_dtype=torch.float16, variant="fp16").to(device)
236
+ pipe.controlnet = controlnet
237
+ elif condition == "Depth":
238
  controlnet = ControlNetModel.from_pretrained(args.depth_checkpoint, torch_dtype=torch.float16).to(device)
239
+ pipe.controlnet = controlnet
240
 
241
  if style_path is not None and os.path.exists(style_path):
242
  pipe_concept.load_lora_weights(style_path, weight_name="pytorch_lora_weights.safetensors", adapter_name='style')
243
  pipe.load_lora_weights(style_path, weight_name="pytorch_lora_weights.safetensors", adapter_name='style')
244
 
245
+ for lora_path in lora_paths.split('|'):
246
+ adapter_name = lora_path.split('/')[-1].split('.')[0]
247
+ pipe_concept.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name=adapter_name)
248
+ pipe_concept.enable_xformers_memory_efficient_attention()
249
+ pipe_list.append(adapter_name)
250
+ return pipe_list
251
+
252
+ def build_yolo_segment_model(sam_path, device):
253
+ yolo_world = YOLOWorld(model_id="yolo_world/l")
254
+ sam = EfficientViTSamPredictor(
255
+ create_sam_model(name="xl1", weight_url=sam_path).to(device).eval()
256
+ )
257
+ return yolo_world, sam
258
+
259
+ def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
260
+ args = SLConfig.fromfile(ckpt_config_filename)
261
+ model = build_model(args)
262
+ args.device = device
263
+
264
+ checkpoint = torch.load(os.path.join(repo_id, filename), map_location='cpu')
265
+ log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
266
+ print("Model loaded from {} \n => {}".format(filename, log))
267
+ _ = model.eval()
268
+ return model
269
+
270
+ def build_dino_segment_model(ckpt_repo_id, sam_checkpoint):
271
+ ckpt_filenmae = "groundingdino_swinb_cogcoor.pth"
272
+ ckpt_config_filename = os.path.join(ckpt_repo_id, "GroundingDINO_SwinB.cfg.py")
273
+ groundingdino_model = load_model_hf(ckpt_repo_id, ckpt_filenmae, ckpt_config_filename)
274
+ sam = build_sam(checkpoint=sam_checkpoint)
275
+ sam.cuda()
276
+ sam_predictor = SamPredictor(sam)
277
+ return groundingdino_model, sam_predictor
278
+
279
  def resize_and_center_crop(image, output_size=(1024, 576)):
280
  width, height = image.size
281
  aspect_ratio = width / height
 
303
  return cropped_image
304
 
305
  def main(device, segment_type):
306
+ pipe, controller, pipe_concept = build_model_sd(args.pretrained_sdxl_model, args.openpose_checkpoint, device, prompts_tmp)
307
+
308
+ # if segment_type == 'GroundingDINO':
309
+ # detect_model, sam = build_dino_segment_model(args.dino_checkpoint, args.sam_checkpoint)
310
+ # else:
311
+ # detect_model, sam = build_yolo_segment_model(args.efficientViT_checkpoint, device)
 
 
312
 
313
  resolution_list = ["1440*728",
314
  "1344*768",
 
328
 
329
  depth_estimator = DPTForDepthEstimation.from_pretrained(args.dpt_checkpoint).to("cuda")
330
  feature_extractor = DPTFeatureExtractor.from_pretrained(args.dpt_checkpoint)
331
+ # body_model = Body(args.pose_detector_checkpoint)
332
+ # openpose = OpenposeDetector(body_model)
 
 
 
 
333
 
334
  def remove_tips():
335
  return gr.update(visible=False)
 
371
  return image
372
 
373
  @spaces.GPU
374
+ def generate_image(prompt1, negative_prompt, man, woman, resolution, local_prompt1, local_prompt2, seed, condition, condition_img1, style):
375
+ try:
376
+ path1 = lorapath_man[man]
377
+ path2 = lorapath_woman[woman]
378
+ pipe_concept.unload_lora_weights()
379
+ pipe.unload_lora_weights()
380
+ pipe_list = build_model_lora(pipe_concept, path1 + "|" + path2, lorapath_styles[style], condition, args, pipe)
381
+
382
+ if lorapath_styles[style] is not None and os.path.exists(lorapath_styles[style]):
383
+ styleL = True
384
+ else:
385
+ styleL = False
386
+
387
+ input_list = [prompt1]
388
+ condition_list = [condition_img1]
389
+ output_list = []
 
 
 
390
 
391
+ width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
 
 
392
 
393
+ kwargs = {
394
+ 'height': height,
395
+ 'width': width,
396
+ }
397
+
398
+ for prompt, condition_img in zip(input_list, condition_list):
399
+ if prompt!='':
400
+ input_prompt = []
401
+ p = '{prompt}, 35mm photograph, film, professional, 4k, highly detailed.'
402
+ if styleL:
403
+ p = styles[style] + p
404
+ input_prompt.append([p.replace("{prompt}", prompt), p.replace("{prompt}", prompt)])
405
+ if styleL:
406
+ input_prompt.append([(styles[style] + local_prompt1, character_man.get(man)[1]),
407
+ (styles[style] + local_prompt2, character_woman.get(woman)[1])])
408
+ else:
409
+ input_prompt.append([(local_prompt1, character_man.get(man)[1]),
410
+ (local_prompt2, character_woman.get(woman)[1])])
411
+
412
+ if condition == 'Human pose' and condition_img is not None:
413
+ index = ratio_list.index(
414
+ min(ratio_list, key=lambda x: abs(x - condition_img.shape[1] / condition_img.shape[0])))
415
+ resolution = resolution_list[index]
416
+ width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
417
+ kwargs['height'] = height
418
+ kwargs['width'] = width
419
+ condition_img = resize_and_center_crop(Image.fromarray(condition_img), (width, height))
420
+ spatial_condition = get_humanpose(condition_img)
421
+ elif condition == 'Canny Edge' and condition_img is not None:
422
+ index = ratio_list.index(
423
+ min(ratio_list, key=lambda x: abs(x - condition_img.shape[1] / condition_img.shape[0])))
424
+ resolution = resolution_list[index]
425
+ width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
426
+ kwargs['height'] = height
427
+ kwargs['width'] = width
428
+ condition_img = resize_and_center_crop(Image.fromarray(condition_img), (width, height))
429
+ spatial_condition = get_cannyedge(condition_img)
430
+ elif condition == 'Depth' and condition_img is not None:
431
+ index = ratio_list.index(
432
+ min(ratio_list, key=lambda x: abs(x - condition_img.shape[1] / condition_img.shape[0])))
433
+ resolution = resolution_list[index]
434
+ width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
435
+ kwargs['height'] = height
436
+ kwargs['width'] = width
437
+ condition_img = resize_and_center_crop(Image.fromarray(condition_img), (width, height))
438
+ spatial_condition = get_depth(condition_img)
439
+ else:
440
+ spatial_condition = None
441
+
442
+ kwargs['spatial_condition'] = spatial_condition
443
+ controller.reset()
444
  image = sample_image(
445
  pipe,
446
  input_prompt=input_prompt,
447
+ concept_models=pipe_concept,
448
  input_neg_prompt=[negative_prompt] * len(input_prompt),
449
  generator=torch.Generator(device).manual_seed(seed),
450
  controller=controller,
451
+ stage=1,
452
+ lora_list=pipe_list,
453
+ styleL=styleL,
 
 
454
  **kwargs)
455
 
456
+ controller.reset()
457
+ if pipe.tokenizer("man")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]:
458
+ mask1 = predict_mask(detect_model, sam, image[0], 'man', args.segment_type, confidence=0.15,
459
+ threshold=0.5)
460
+ else:
461
+ mask1 = None
462
+
463
+ if pipe.tokenizer("woman")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]:
464
+ mask2 = predict_mask(detect_model, sam, image[0], 'woman', args.segment_type, confidence=0.15,
465
+ threshold=0.5)
466
+ else:
467
+ mask2 = None
468
+
469
+ if mask1 is None and mask2 is None:
470
+ output_list.append(image[1])
471
+ else:
472
+ image = sample_image(
473
+ pipe,
474
+ input_prompt=input_prompt,
475
+ concept_models=pipe_concept,
476
+ input_neg_prompt=[negative_prompt] * len(input_prompt),
477
+ generator=torch.Generator(device).manual_seed(seed),
478
+ controller=controller,
479
+ stage=2,
480
+ region_masks=[mask1, mask2],
481
+ lora_list=pipe_list,
482
+ styleL=styleL,
483
+ **kwargs)
484
+ output_list.append(image[1])
485
+ else:
486
+ output_list.append(None)
487
+ output_list.append(spatial_condition)
488
+ return output_list
489
+ except:
490
+ print("error")
491
+ return
492
+
493
+ def get_local_value_man(input):
494
+ return character_man[input][0]
495
+
496
+ def get_local_value_woman(input):
497
+ return character_woman[input][0]
498
+
499
+ @spaces.GPU
500
+ def generate(prompt):
501
+ print(os.system(prompt))
502
+ return prompt
503
+
504
+ gr.Interface(
505
+ fn=generate,
506
+ inputs=gr.Text(),
507
+ outputs=gr.Gallery(),
508
+ ).launch()
509
+
510
+
 
 
511
 
512
  def parse_args():
513
  parser = argparse.ArgumentParser('', add_help=False)
514
+ parser.add_argument('--pretrained_sdxl_model', default='Fucius/stable-diffusion-xl-base-1.0', type=str)
 
 
515
  parser.add_argument('--openpose_checkpoint', default='thibaud/controlnet-openpose-sdxl-1.0', type=str)
516
  parser.add_argument('--canny_checkpoint', default='diffusers/controlnet-canny-sdxl-1.0', type=str)
517
  parser.add_argument('--depth_checkpoint', default='diffusers/controlnet-depth-sdxl-1.0', type=str)
518
+ parser.add_argument('--efficientViT_checkpoint', default='../checkpoint/sam/xl1.pt', type=str)
 
 
 
519
  parser.add_argument('--dino_checkpoint', default='./checkpoint/GroundingDINO', type=str)
520
  parser.add_argument('--sam_checkpoint', default='./checkpoint/sam/sam_vit_h_4b8939.pth', type=str)
521
+ parser.add_argument('--dpt_checkpoint', default='Intel/dpt-hybrid-midas', type=str)
522
+ parser.add_argument('--pose_detector_checkpoint', default='../checkpoint/ControlNet/annotator/ckpts/body_pose_model.pth', type=str)
523
+ parser.add_argument('--prompt', default='Close-up photo of the cool man and beautiful woman in surprised expressions as they accidentally discover a mysterious island while on vacation by the sea, 35mm photograph, film, professional, 4k, highly detailed.', type=str)
524
  parser.add_argument('--negative_prompt', default='noisy, blurry, soft, deformed, ugly', type=str)
525
+ parser.add_argument('--seed', default=22, type=int)
 
526
  parser.add_argument('--suffix', default='', type=str)
527
  parser.add_argument('--segment_type', default='yoloworld', help='GroundingDINO or yoloworld', type=str)
 
528
  return parser.parse_args()
529
 
530
  if __name__ == '__main__':
531
  args = parse_args()
532
 
533
+ prompts = [args.prompt]*2
 
534
  prompts_tmp = copy.deepcopy(prompts)
 
 
535
  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
536
+ download = OMG_download()
537
+ main(device, args.segment_type)
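
Note on the prepare_text change: the per-region reference image is gone, so a region rewrite string is now just '[positive prompt]-*-[negative prompt]' for each region, with regions joined by '|' (the same separators used by the removed --prompt_rewrite default above). A minimal standalone sketch of the parsing, mirroring the updated function; the helper name and example string are illustrative only:

# Standalone sketch of the region-prompt format accepted by the updated prepare_text.
def parse_region_prompts(region_prompts):
    region_collection = []
    for region in region_prompts.split('|'):  # assumed region separator, as in --prompt_rewrite
        if region == '':
            break
        prompt_region, neg_prompt_region = region.split('-*-')
        prompt_region = prompt_region.replace('[', '').replace(']', '')
        neg_prompt_region = neg_prompt_region.replace('[', '').replace(']', '')
        region_collection.append((prompt_region, neg_prompt_region))
    return region_collection

example = ('[Close-up photo of a man, 35mm photograph, professional, 4k, highly detailed.]'
           '-*-[noisy, blurry, soft, deformed, ugly]|'
           '[Close-up photo of a woman, 35mm photograph, professional, 4k, highly detailed.]'
           '-*-[noisy, blurry, soft, deformed, ugly]')
print(parse_region_prompts(example))  # -> two (prompt, negative prompt) tuples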