Fucius committed on
Commit
83de213
1 Parent(s): 821f875

Upload 4 files

Browse files
gradio_demo/app.py ADDED
@@ -0,0 +1,594 @@
1
+ import sys
2
+ sys.path.append('./')
3
+ import gradio as gr
4
+ import random
5
+ import numpy as np
6
+ from gradio_demo.character_template import character_man, lorapath_man
7
+ from gradio_demo.character_template import character_woman, lorapath_woman
8
+ from gradio_demo.character_template import styles, lorapath_styles
9
+ import torch
10
+ import os
11
+ from typing import Tuple, List
12
+ import copy
13
+ import argparse
14
+ from diffusers.utils import load_image
15
+ import cv2
16
+ from PIL import Image, ImageOps
17
+ from transformers import DPTFeatureExtractor, DPTForDepthEstimation
18
+ from controlnet_aux import OpenposeDetector
19
+ from controlnet_aux.open_pose.body import Body
20
+
21
+ try:
22
+ from inference.models import YOLOWorld
23
+ from src.efficientvit.models.efficientvit.sam import EfficientViTSamPredictor
24
+ from src.efficientvit.sam_model_zoo import create_sam_model
25
+ import supervision as sv
26
+ except Exception:
27
+ print("YOLOWorld cannot be loaded")
28
+
29
+ try:
30
+ from groundingdino.models import build_model
31
+ from groundingdino.util import box_ops
32
+ from groundingdino.util.slconfig import SLConfig
33
+ from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
34
+ from groundingdino.util.inference import annotate, predict
35
+ from segment_anything import build_sam, SamPredictor
36
+ import groundingdino.datasets.transforms as T
37
+ except Exception:
38
+ print("GroundingDINO cannot be loaded")
39
+
40
+ from src.pipelines.lora_pipeline import LoraMultiConceptPipeline
41
+ from src.prompt_attention.p2p_attention import AttentionReplace
42
+ from diffusers import ControlNetModel, StableDiffusionXLPipeline
43
+ from src.pipelines.lora_pipeline import revise_regionally_controlnet_forward
44
+
45
+ CHARACTER_MAN_NAMES = list(character_man.keys())
46
+ CHARACTER_WOMAN_NAMES = list(character_woman.keys())
47
+ STYLE_NAMES = list(styles.keys())
48
+ MAX_SEED = np.iinfo(np.int32).max
49
+
50
+ ### Description
51
+ title = r"""
52
+ <h1 align="center">OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models</h1>
53
+ """
54
+
55
+ description = r"""
56
+ <b>Official 🤗 Gradio demo</b> for <a href='https://github.com/' target='_blank'><b>OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models</b></a>.<br>
57
+
58
+ How to use:<br>
59
+ 1. Select two characters.
60
+ 2. Enter a text prompt as done in normal text-to-image models.
61
+ 3. Click the <b>Submit</b> button to start customizing.
62
+ 4. Enjoy the generated image😊!
63
+ """
64
+
65
+ article = r"""
66
+ ---
67
+ 📝 **Citation**
68
+ <br>
69
+ If our work is helpful for your research or applications, please cite us via:
70
+ ```bibtex
71
+ @article{,
72
+ title={OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models},
73
+ author={},
74
+ journal={},
75
+ year={}
76
+ }
77
+ ```
78
+ """
79
+
80
+ tips = r"""
81
+ ### Usage tips of OMG
82
+ 1. Input text prompts to describe a man and a woman
83
+ """
84
+
85
+ css = '''
86
+ .gradio-container {width: 85% !important}
87
+ '''
88
+
89
+ def sample_image(pipe,
90
+ input_prompt,
91
+ input_neg_prompt=None,
92
+ generator=None,
93
+ concept_models=None,
94
+ num_inference_steps=50,
95
+ guidance_scale=7.5,
96
+ controller=None,
97
+ stage=None,
98
+ region_masks=None,
99
+ lora_list = None,
100
+ styleL=None,
101
+ **extra_kargs
102
+ ):
103
+
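+ # Helper that forwards all arguments to the multi-concept pipeline; when a spatial
+ # condition (pose/canny/depth map) is given, it is repeated once per prompt.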
104
+ spatial_condition = extra_kargs.pop('spatial_condition')
105
+ if spatial_condition is not None:
106
+ spatial_condition_input = [spatial_condition] * len(input_prompt)
107
+ else:
108
+ spatial_condition_input = None
109
+
110
+ images = pipe(
111
+ prompt=input_prompt,
112
+ concept_models=concept_models,
113
+ negative_prompt=input_neg_prompt,
114
+ generator=generator,
115
+ guidance_scale=guidance_scale,
116
+ num_inference_steps=num_inference_steps,
117
+ cross_attention_kwargs={"scale": 0.8},
118
+ controller=controller,
119
+ stage=stage,
120
+ region_masks=region_masks,
121
+ lora_list=lora_list,
122
+ styleL=styleL,
123
+ image=spatial_condition_input,
124
+ **extra_kargs).images
125
+
126
+ return images
127
+
128
+ def load_image_yoloworld(image_source) -> Tuple[np.array, torch.Tensor]:
129
+ image = np.asarray(image_source)
130
+ return image
131
+
132
+ def load_image_dino(image_source) -> Tuple[np.array, torch.Tensor]:
133
+ transform = T.Compose(
134
+ [
135
+ T.RandomResize([800], max_size=1333),
136
+ T.ToTensor(),
137
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
138
+ ]
139
+ )
140
+ image = np.asarray(image_source)
141
+ image_transformed, _ = transform(image_source, None)
142
+ return image, image_transformed
143
+
144
+ def predict_mask(segmentmodel, sam, image, TEXT_PROMPT, segmentType, confidence = 0.2, threshold = 0.5):
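+ # Returns a binary mask for TEXT_PROMPT using either GroundingDINO + SAM
+ # (text-prompted boxes refined into masks) or YOLO-World + EfficientViT-SAM,
+ # depending on segmentType.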
145
+ if segmentType=='GroundingDINO':
146
+ image_source, image = load_image_dino(image)
147
+ boxes, logits, phrases = predict(
148
+ model=segmentmodel,
149
+ image=image,
150
+ caption=TEXT_PROMPT,
151
+ box_threshold=0.3,
152
+ text_threshold=0.25
153
+ )
154
+ sam.set_image(image_source)
155
+ H, W, _ = image_source.shape
156
+ boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])
157
+
158
+ transformed_boxes = sam.transform.apply_boxes_torch(boxes_xyxy, image_source.shape[:2]).cuda()
159
+ masks, _, _ = sam.predict_torch(
160
+ point_coords=None,
161
+ point_labels=None,
162
+ boxes=transformed_boxes,
163
+ multimask_output=False,
164
+ )
165
+ masks=masks[0].squeeze(0)
166
+ else:
167
+ image_source = load_image_yoloworld(image)
168
+ segmentmodel.set_classes([TEXT_PROMPT])
169
+ results = segmentmodel.infer(image_source, confidence=confidence)
170
+ detections = sv.Detections.from_inference(results).with_nms(
171
+ class_agnostic=True, threshold=threshold
172
+ )
173
+ masks = None
174
+ if len(detections) != 0:
175
+ print(TEXT_PROMPT + " detected!")
176
+ sam.set_image(image_source, image_format="RGB")
177
+ masks, _, _ = sam.predict(box=detections.xyxy[0], multimask_output=False)
178
+ masks = torch.from_numpy(masks.squeeze())
179
+
180
+ return masks
181
+
182
+ def prepare_text(prompt, region_prompts):
183
+ '''
184
+ Args:
185
+ prompt_entity: [subject1]-*-[attribute1]-*-[Location1]|[subject2]-*-[attribute2]-*-[Location2]|[global text]
186
+ Returns:
187
+ full_prompt: subject1, attribute1 and subject2, attribute2, global text
188
+ context_prompt: subject1 and subject2, global text
189
+ entity_collection: [(subject1, attribute1), Location1]
190
+ '''
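+ # Illustrative example (values are hypothetical): region_prompts = "[a man]-*-[ugly face]|[a woman]-*-[bad hands]"
+ # yields region_collection = [('a man', 'ugly face'), ('a woman', 'bad hands')].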
191
+ region_collection = []
192
+
193
+ regions = region_prompts.split('|')
194
+
195
+ for region in regions:
196
+ if region == '':
197
+ break
198
+ prompt_region, neg_prompt_region = region.split('-*-')
199
+ prompt_region = prompt_region.replace('[', '').replace(']', '')
200
+ neg_prompt_region = neg_prompt_region.replace('[', '').replace(']', '')
201
+
202
+ region_collection.append((prompt_region, neg_prompt_region))
203
+ return (prompt, region_collection)
204
+
205
+
206
+ def build_model_sd(pretrained_model, controlnet_path, device, prompts):
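+ # Builds the ControlNet-guided multi-concept pipeline, hooks the prompt-to-prompt
+ # attention controller into its UNet, and creates a plain SDXL pipeline for per-concept sampling.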
207
+ controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16).to(device)
208
+ pipe = LoraMultiConceptPipeline.from_pretrained(
209
+ pretrained_model, controlnet=controlnet, torch_dtype=torch.float16, variant="fp16").to(device)
210
+ controller = AttentionReplace(prompts, 50, cross_replace_steps={"default_": 1.}, self_replace_steps=0.4, tokenizer=pipe.tokenizer, device=device, dtype=torch.float16, width=1024//32, height=1024//32)
211
+ revise_regionally_controlnet_forward(pipe.unet, controller)
212
+ pipe_concept = StableDiffusionXLPipeline.from_pretrained(pretrained_model, torch_dtype=torch.float16,
213
+ variant="fp16").to(device)
214
+ return pipe, controller, pipe_concept
215
+
216
+ def build_model_lora(pipe_concept, lora_paths, style_path, condition, args, pipe):
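+ # Swaps the pipeline's ControlNet to match the chosen condition type, loads the optional
+ # style LoRA, and loads one LoRA adapter per character; returns the adapter names.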
217
+ pipe_list = []
218
+ if condition == "Human pose":
219
+ controlnet = ControlNetModel.from_pretrained(args.openpose_checkpoint, torch_dtype=torch.float16).to(device)
220
+ pipe.controlnet = controlnet
221
+ elif condition == "Canny Edge":
222
+ controlnet = ControlNetModel.from_pretrained(args.canny_checkpoint, torch_dtype=torch.float16).to(device)
223
+ pipe.controlnet = controlnet
224
+ elif condition == "Depth":
225
+ controlnet = ControlNetModel.from_pretrained(args.depth_checkpoint, torch_dtype=torch.float16).to(device)
226
+ pipe.controlnet = controlnet
227
+
228
+ if style_path is not None and os.path.exists(style_path):
229
+ pipe_concept.load_lora_weights(style_path, weight_name="pytorch_lora_weights.safetensors", adapter_name='style')
230
+ pipe.load_lora_weights(style_path, weight_name="pytorch_lora_weights.safetensors", adapter_name='style')
231
+
232
+ for lora_path in lora_paths.split('|'):
233
+ adapter_name = lora_path.split('/')[-1].split('.')[0]
234
+ pipe_concept.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name=adapter_name)
235
+ pipe_concept.enable_xformers_memory_efficient_attention()
236
+ pipe_list.append(adapter_name)
237
+ return pipe_list
238
+
239
+ def build_yolo_segment_model(sam_path, device):
240
+ yolo_world = YOLOWorld(model_id="yolo_world/l")
241
+ sam = EfficientViTSamPredictor(
242
+ create_sam_model(name="xl1", weight_url=sam_path).to(device).eval()
243
+ )
244
+ return yolo_world, sam
245
+
246
+ def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
247
+ args = SLConfig.fromfile(ckpt_config_filename)
248
+ model = build_model(args)
249
+ args.device = device
250
+
251
+ checkpoint = torch.load(os.path.join(repo_id, filename), map_location='cpu')
252
+ log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
253
+ print("Model loaded from {} \n => {}".format(filename, log))
254
+ _ = model.eval()
255
+ return model
256
+
257
+ def build_dino_segment_model(ckpt_repo_id, sam_checkpoint):
258
+ ckpt_filename = "groundingdino_swinb_cogcoor.pth"
259
+ ckpt_config_filename = os.path.join(ckpt_repo_id, "GroundingDINO_SwinB.cfg.py")
260
+ groundingdino_model = load_model_hf(ckpt_repo_id, ckpt_filename, ckpt_config_filename)
261
+ sam = build_sam(checkpoint=sam_checkpoint)
262
+ sam.cuda()
263
+ sam_predictor = SamPredictor(sam)
264
+ return groundingdino_model, sam_predictor
265
+
266
+ def resize_and_center_crop(image, output_size=(1024, 576)):
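+ # Resizes to the target height while keeping the aspect ratio, pads with gray if the
+ # result is smaller than output_size, then center-crops to exactly output_size.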
267
+ width, height = image.size
268
+ aspect_ratio = width / height
269
+ new_height = output_size[1]
270
+ new_width = int(aspect_ratio * new_height)
271
+
272
+ resized_image = image.resize((new_width, new_height), Image.LANCZOS)
273
+
274
+ if new_width < output_size[0] or new_height < output_size[1]:
275
+ padding_color = "gray"
276
+ resized_image = ImageOps.expand(resized_image,
277
+ ((output_size[0] - new_width) // 2,
278
+ (output_size[1] - new_height) // 2,
279
+ (output_size[0] - new_width + 1) // 2,
280
+ (output_size[1] - new_height + 1) // 2),
281
+ fill=padding_color)
282
+
283
+ left = (resized_image.width - output_size[0]) / 2
284
+ top = (resized_image.height - output_size[1]) / 2
285
+ right = (resized_image.width + output_size[0]) / 2
286
+ bottom = (resized_image.height + output_size[1]) / 2
287
+
288
+ cropped_image = resized_image.crop((left, top, right, bottom))
289
+
290
+ return cropped_image
291
+
292
+ def main(device, segment_type):
293
+ pipe, controller, pipe_concept = build_model_sd(args.pretrained_sdxl_model, args.openpose_checkpoint, device, prompts_tmp)
294
+
295
+ if segment_type == 'GroundingDINO':
296
+ detect_model, sam = build_dino_segment_model(args.dino_checkpoint, args.sam_checkpoint)
297
+ else:
298
+ detect_model, sam = build_yolo_segment_model(args.efficientViT_checkpoint, device)
299
+
300
+ resolution_list = ["1440*728",
301
+ "1344*768",
302
+ "1216*832",
303
+ "1152*896",
304
+ "1024*1024",
305
+ "896*1152",
306
+ "832*1216",
307
+ "768*1344",
308
+ "728*1440"]
309
+ ratio_list = [1440/728, 1344/768, 1216/832, 1152/896, 1024/1024, 896/1152, 832/1216, 768/1344, 728/1440]
310
+ condition_list = ["None",
311
+ "Human pose",
312
+ "Canny Edge",
313
+ "Depth"]
314
+
315
+ depth_estimator = DPTForDepthEstimation.from_pretrained(args.dpt_checkpoint).to("cuda")
316
+ feature_extractor = DPTFeatureExtractor.from_pretrained(args.dpt_checkpoint)
317
+ body_model = Body(args.pose_detector_checkpoint)
318
+ openpose = OpenposeDetector(body_model)
319
+
320
+ def remove_tips():
321
+ return gr.update(visible=False)
322
+
323
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
324
+ if randomize_seed:
325
+ seed = random.randint(0, MAX_SEED)
326
+ return seed
327
+
328
+ def get_humanpose(img):
329
+ openpose_image = openpose(img)
330
+ return openpose_image
331
+
332
+ def get_cannyedge(image):
333
+ image = np.array(image)
334
+ image = cv2.Canny(image, 100, 200)
335
+ image = image[:, :, None]
336
+ image = np.concatenate([image, image, image], axis=2)
337
+ canny_image = Image.fromarray(image)
338
+ return canny_image
339
+
340
+ def get_depth(image):
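+ # Estimates depth with DPT, normalizes it to [0, 1], and returns a 3-channel PIL image
+ # suitable as input for the depth ControlNet.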
341
+ image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
342
+ with torch.no_grad(), torch.autocast("cuda"):
343
+ depth_map = depth_estimator(image).predicted_depth
344
+
345
+ depth_map = torch.nn.functional.interpolate(
346
+ depth_map.unsqueeze(1),
347
+ size=(1024, 1024),
348
+ mode="bicubic",
349
+ align_corners=False,
350
+ )
351
+ depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
352
+ depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
353
+ depth_map = (depth_map - depth_min) / (depth_max - depth_min)
354
+ image = torch.cat([depth_map] * 3, dim=1)
355
+ image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
356
+ image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
357
+ return image
358
+
359
+ def generate_image(prompt1, prompt2, prompt3, prompt4, negative_prompt, man, woman, resolution, local_prompt1, local_prompt2, seed, condition, condition_img1, condition_img2, condition_img3, condition_img4, style):
360
+ try:
361
+ path1 = lorapath_man[man]
362
+ path2 = lorapath_woman[woman]
363
+ pipe_concept.unload_lora_weights()
364
+ pipe.unload_lora_weights()
365
+ pipe_list = build_model_lora(pipe_concept, path1 + "|" + path2, lorapath_styles[style], condition, args, pipe)
366
+
367
+ if lorapath_styles[style] is not None and os.path.exists(lorapath_styles[style]):
368
+ styleL = True
369
+ else:
370
+ styleL = False
371
+
372
+ input_list = [prompt1, prompt2, prompt3, prompt4]
373
+ condition_list = [condition_img1, condition_img2, condition_img3, condition_img4]
374
+ output_list = []
375
+
376
+ width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
377
+
378
+ kwargs = {
379
+ 'height': height,
380
+ 'width': width,
381
+ }
382
+
383
+ for prompt, condition_img in zip(input_list, condition_list):
384
+ if prompt!='':
385
+ input_prompt = []
386
+ p = '{prompt}, 35mm photograph, film, professional, 4k, highly detailed.'
387
+ if styleL:
388
+ p = styles[style] + p
389
+ input_prompt.append([p.replace("{prompt}", prompt), p.replace("{prompt}", prompt)])
390
+ if styleL:
391
+ input_prompt.append([(styles[style] + local_prompt1, character_man.get(man)[1]), (styles[style] + local_prompt2, character_woman.get(woman)[1])])
392
+ else:
393
+ input_prompt.append([(local_prompt1, character_man.get(man)[1]), (local_prompt2, character_woman.get(woman)[1])])
394
+
395
+ if condition == 'Human pose' and condition_img is not None:
396
+ index = ratio_list.index(min(ratio_list, key=lambda x: abs(x-condition_img.shape[1]/condition_img.shape[0])))
397
+ resolution = resolution_list[index]
398
+ width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
399
+ kwargs['height'] = height
400
+ kwargs['width'] = width
401
+ condition_img = resize_and_center_crop(Image.fromarray(condition_img), (width, height))
402
+ spatial_condition = get_humanpose(condition_img)
403
+ elif condition == 'Canny Edge' and condition_img is not None:
404
+ index = ratio_list.index(
405
+ min(ratio_list, key=lambda x: abs(x - condition_img.shape[1] / condition_img.shape[0])))
406
+ resolution = resolution_list[index]
407
+ width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
408
+ kwargs['height'] = height
409
+ kwargs['width'] = width
410
+ condition_img = resize_and_center_crop(Image.fromarray(condition_img), (width, height))
411
+ spatial_condition = get_cannyedge(condition_img)
412
+ elif condition == 'Depth' and condition_img is not None:
413
+ index = ratio_list.index(
414
+ min(ratio_list, key=lambda x: abs(x - condition_img.shape[1] / condition_img.shape[0])))
415
+ resolution = resolution_list[index]
416
+ width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
417
+ kwargs['height'] = height
418
+ kwargs['width'] = width
419
+ condition_img = resize_and_center_crop(Image.fromarray(condition_img), (width, height))
420
+ spatial_condition = get_depth(condition_img)
421
+ else:
422
+ spatial_condition = None
423
+
424
+ kwargs['spatial_condition'] = spatial_condition
425
+
426
+ controller.reset()
427
+ image = sample_image(
428
+ pipe,
429
+ input_prompt=input_prompt,
430
+ concept_models=pipe_concept,
431
+ input_neg_prompt=[negative_prompt] * len(input_prompt),
432
+ generator=torch.Generator(device).manual_seed(seed),
433
+ controller=controller,
434
+ stage=1,
435
+ lora_list=pipe_list,
436
+ styleL=styleL,
437
+ **kwargs)
438
+
439
+ controller.reset()
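+ # Stage 2 preparation: locate the 'man'/'woman' regions in the stage-1 image so that
+ # each character's LoRA is applied within its detected mask during the second pass.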
440
+ if pipe.tokenizer("man")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]:
441
+ mask1 = predict_mask(detect_model, sam, image[0], 'man', args.segment_type, confidence=0.15,
442
+ threshold=0.5)
443
+ else:
444
+ mask1 = None
445
+
446
+ if pipe.tokenizer("woman")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]:
447
+ mask2 = predict_mask(detect_model, sam, image[0], 'woman', args.segment_type, confidence=0.15,
448
+ threshold=0.5)
449
+ else:
450
+ mask2 = None
451
+
452
+ if mask1 is None and mask2 is None:
453
+ output_list.append(image[1])
454
+ else:
455
+ image = sample_image(
456
+ pipe,
457
+ input_prompt=input_prompt,
458
+ concept_models=pipe_concept,
459
+ input_neg_prompt=[negative_prompt] * len(input_prompt),
460
+ generator=torch.Generator(device).manual_seed(seed),
461
+ controller=controller,
462
+ stage=2,
463
+ region_masks=[mask1, mask2],
464
+ lora_list=pipe_list,
465
+ styleL=styleL,
466
+ **kwargs)
467
+ output_list.append(image[1])
468
+ else:
469
+ output_list.append(None)
470
+ return output_list
471
+ except Exception as e:
472
+ print(f"Generation failed: {e}")
473
+ return None, None, None, None
474
+
475
+ def get_local_value_man(input):
476
+ return character_man[input][0]
477
+
478
+ def get_local_value_woman(input):
479
+ return character_woman[input][0]
480
+
481
+
482
+ with gr.Blocks(css=css) as demo:
483
+ # description
484
+ gr.Markdown(title)
485
+ gr.Markdown(description)
486
+
487
+ with gr.Row():
488
+ gallery = gr.Image(label="Generated Images", height=512, width=512)
489
+ gallery2 = gr.Image(label="Generated Images", height=512, width=512)
490
+ gallery3 = gr.Image(label="Generated Images", height=512, width=512)
491
+ gallery4 = gr.Image(label="Generated Images", height=512, width=512)
492
+ usage_tips = gr.Markdown(label="Usage tips of OMG", value=tips, visible=False)
493
+
494
+ with gr.Row():
495
+ condition_img1 = gr.Image(label="Input condition", height=128, width=128)
496
+ condition_img2 = gr.Image(label="Input condition", height=128, width=128)
497
+ condition_img3 = gr.Image(label="Input condition", height=128, width=128)
498
+ condition_img4 = gr.Image(label="Input condition", height=128, width=128)
499
+
500
+ # character choose
501
+ with gr.Row():
502
+ man = gr.Dropdown(label="Character 1 selection", choices=CHARACTER_MAN_NAMES, value="Chris Evans (identifier: Chris Evans)")
503
+ woman = gr.Dropdown(label="Character 2 selection", choices=CHARACTER_WOMAN_NAMES, value="Taylor Swift (identifier: TaylorSwift)")
504
+ resolution = gr.Dropdown(label="Image Resolution (width*height)", choices=resolution_list, value="1024*1024")
505
+ condition = gr.Dropdown(label="Input condition type", choices=condition_list, value="None")
506
+ style = gr.Dropdown(label="style", choices=STYLE_NAMES, value="None")
507
+
508
+ with gr.Row():
509
+ local_prompt1 = gr.Textbox(label="Character1_prompt",
510
+ info="Describe the Character 1, this prompt should include the identifier of character 1",
511
+ value="Close-up photo of the Chris Evans, 35mm photograph, film, professional, 4k, highly detailed.")
512
+ local_prompt2 = gr.Textbox(label="Character2_prompt",
513
+ info="Describe the Character 2, this prompt should include the identifier of character2",
514
+ value="Close-up photo of the TaylorSwift, 35mm photograph, film, professional, 4k, highly detailed.")
515
+
516
+ man.change(get_local_value_man, man, local_prompt1)
517
+ woman.change(get_local_value_woman, woman, local_prompt2)
518
+
519
+ # prompt
520
+ with gr.Column():
521
+ prompt = gr.Textbox(label="Prompt 1",
522
+ info="Give a simple prompt to describe the first image content",
523
+ placeholder="Required",
524
+ value="close-up shot, photography, a man and a woman on the street, facing the camera smiling")
525
+ prompt2 = gr.Textbox(label="Prompt 2",
526
+ info="Give a simple prompt to describe the second image content",
527
+ placeholder="optional",
528
+ value="")
529
+ prompt3 = gr.Textbox(label="Prompt 3",
530
+ info="Give a simple prompt to describe the third image content",
531
+ placeholder="optional",
532
+ value="")
533
+ prompt4 = gr.Textbox(label="Prompt 4",
534
+ info="Give a simple prompt to describe the fourth image content",
535
+ placeholder="optional",
536
+ value="")
537
+
538
+ with gr.Accordion(open=False, label="Advanced Options"):
539
+ seed = gr.Slider(
540
+ label="Seed",
541
+ minimum=0,
542
+ maximum=MAX_SEED,
543
+ step=1,
544
+ value=42,
545
+ )
546
+ negative_prompt = gr.Textbox(label="Negative Prompt",
547
+ placeholder="noisy, blurry, soft, deformed, ugly",
548
+ value="noisy, blurry, soft, deformed, ugly")
549
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
550
+
551
+ submit = gr.Button("Submit", variant="primary")
552
+
553
+ submit.click(
554
+ fn=remove_tips,
555
+ outputs=usage_tips,
556
+ ).then(
557
+ fn=randomize_seed_fn,
558
+ inputs=[seed, randomize_seed],
559
+ outputs=seed,
560
+ queue=False,
561
+ api_name=False,
562
+ ).then(
563
+ fn=generate_image,
564
+ inputs=[prompt, prompt2, prompt3, prompt4, negative_prompt, man, woman, resolution, local_prompt1, local_prompt2, seed, condition, condition_img1, condition_img2, condition_img3, condition_img4, style],
565
+ outputs=[gallery, gallery2, gallery3, gallery4]
566
+ )
567
+ demo.launch(server_name='0.0.0.0',server_port=7861, debug=True)
568
+
569
+ def parse_args():
570
+ parser = argparse.ArgumentParser('', add_help=False)
571
+ parser.add_argument('--pretrained_sdxl_model', default='./checkpoint/stable-diffusion-xl-base-1.0', type=str)
572
+ parser.add_argument('--openpose_checkpoint', default='./checkpoint/controlnet-openpose-sdxl-1.0', type=str)
573
+ parser.add_argument('--canny_checkpoint', default='./checkpoint/controlnet-canny-sdxl-1.0', type=str)
574
+ parser.add_argument('--depth_checkpoint', default='./checkpoint/controlnet-depth-sdxl-1.0', type=str)
575
+ parser.add_argument('--efficientViT_checkpoint', default='./checkpoint/sam/xl1.pt', type=str)
576
+ parser.add_argument('--dino_checkpoint', default='./checkpoint/GroundingDINO', type=str)
577
+ parser.add_argument('--sam_checkpoint', default='./checkpoint/sam/sam_vit_h_4b8939.pth', type=str)
578
+ parser.add_argument('--dpt_checkpoint', default='./checkpoint/dpt-hybrid-midas', type=str)
579
+ parser.add_argument('--pose_detector_checkpoint', default='./checkpoint/ControlNet/annotator/ckpts/body_pose_model.pth', type=str)
580
+ parser.add_argument('--prompt', default='Close-up photo of the cool man and beautiful woman in surprised expressions as they accidentally discover a mysterious island while on vacation by the sea, 35mm photograph, film, professional, 4k, highly detailed.', type=str)
581
+ parser.add_argument('--negative_prompt', default='noisy, blurry, soft, deformed, ugly', type=str)
582
+ parser.add_argument('--seed', default=22, type=int)
583
+ parser.add_argument('--suffix', default='', type=str)
584
+ parser.add_argument('--segment_type', default='yoloworld', help='GroundingDINO or yoloworld', type=str)
585
+ return parser.parse_args()
586
+
587
+ if __name__ == '__main__':
588
+ args = parse_args()
589
+
590
+ prompts = [args.prompt]*2
591
+ prompts_tmp = copy.deepcopy(prompts)
592
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
593
+
594
+ main(device, args.segment_type)
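
The resolution-matching step inside `generate_image` above (choosing the closest supported SDXL resolution for an uploaded condition image) can be shown in isolation. The snippet below is a minimal, self-contained sketch of that selection logic only; it is not part of the committed file, and the helper name `closest_resolution` is illustrative.

```python
# Sketch of the resolution-selection step used in generate_image: pick the
# supported "width*height" preset whose aspect ratio is closest to that of the
# uploaded condition image (NumPy shape = (H, W, C)).
resolution_list = ["1440*728", "1344*768", "1216*832", "1152*896", "1024*1024",
                   "896*1152", "832*1216", "768*1344", "728*1440"]
ratio_list = [int(r.split("*")[0]) / int(r.split("*")[1]) for r in resolution_list]

def closest_resolution(condition_shape):
    h, w = condition_shape[:2]
    target = w / h  # width/height ratio of the condition image
    index = ratio_list.index(min(ratio_list, key=lambda x: abs(x - target)))
    width, height = map(int, resolution_list[index].split("*"))
    return width, height

# Example: a 720x1280 (H x W) landscape image maps to the 1344*768 preset.
print(closest_resolution((720, 1280, 3)))  # -> (1344, 768)
```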
gradio_demo/app_generateOne.py ADDED
@@ -0,0 +1,582 @@
1
+ import sys
2
+ sys.path.append('./')
3
+ import gradio as gr
4
+ import random
5
+ import numpy as np
6
+ from gradio_demo.character_template import character_man, lorapath_man
7
+ from gradio_demo.character_template import character_woman, lorapath_woman
8
+ from gradio_demo.character_template import styles, lorapath_styles
9
+ import torch
10
+ import os
11
+ from typing import Tuple, List
12
+ import copy
13
+ import argparse
14
+ from diffusers.utils import load_image
15
+ import cv2
16
+ from PIL import Image, ImageOps
17
+ from transformers import DPTFeatureExtractor, DPTForDepthEstimation
18
+ from controlnet_aux import OpenposeDetector
19
+ from controlnet_aux.open_pose.body import Body
20
+
21
+ try:
22
+ from inference.models import YOLOWorld
23
+ from src.efficientvit.models.efficientvit.sam import EfficientViTSamPredictor
24
+ from src.efficientvit.sam_model_zoo import create_sam_model
25
+ import supervision as sv
26
+ except:
27
+ print("YoloWorld can not be load")
28
+
29
+ try:
30
+ from groundingdino.models import build_model
31
+ from groundingdino.util import box_ops
32
+ from groundingdino.util.slconfig import SLConfig
33
+ from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
34
+ from groundingdino.util.inference import annotate, predict
35
+ from segment_anything import build_sam, SamPredictor
36
+ import groundingdino.datasets.transforms as T
37
+ except:
38
+ print("groundingdino can not be load")
39
+
40
+ from src.pipelines.lora_pipeline import LoraMultiConceptPipeline
41
+ from src.prompt_attention.p2p_attention import AttentionReplace
42
+ from diffusers import ControlNetModel, StableDiffusionXLPipeline
43
+ from src.pipelines.lora_pipeline import revise_regionally_controlnet_forward
44
+
45
+ CHARACTER_MAN_NAMES = list(character_man.keys())
46
+ CHARACTER_WOMAN_NAMES = list(character_woman.keys())
47
+ STYLE_NAMES = list(styles.keys())
48
+ MAX_SEED = np.iinfo(np.int32).max
49
+
50
+ ### Description
51
+ title = r"""
52
+ <h1 align="center">OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models</h1>
53
+ """
54
+
55
+ description = r"""
56
+ <b>Official 🤗 Gradio demo</b> for <a href='https://github.com/' target='_blank'><b>OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models</b></a>.<br>
57
+
58
+ How to use:<br>
59
+ 1. Select two characters.
60
+ 2. Enter a text prompt as done in normal text-to-image models.
61
+ 3. Click the <b>Submit</b> button to start customizing.
62
+ 4. Enjoy the generated image😊!
63
+ """
64
+
65
+ article = r"""
66
+ ---
67
+ 📝 **Citation**
68
+ <br>
69
+ If our work is helpful for your research or applications, please cite us via:
70
+ ```bibtex
71
+ @article{,
72
+ title={OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models},
73
+ author={},
74
+ journal={},
75
+ year={}
76
+ }
77
+ ```
78
+ """
79
+
80
+ tips = r"""
81
+ ### Usage tips of OMG
82
+ 1. Input text prompts to describe a man and a woman
83
+ """
84
+
85
+ css = '''
86
+ .gradio-container {width: 85% !important}
87
+ '''
88
+
89
+ def sample_image(pipe,
90
+ input_prompt,
91
+ input_neg_prompt=None,
92
+ generator=None,
93
+ concept_models=None,
94
+ num_inference_steps=50,
95
+ guidance_scale=7.5,
96
+ controller=None,
97
+ stage=None,
98
+ region_masks=None,
99
+ lora_list = None,
100
+ styleL=None,
101
+ **extra_kargs
102
+ ):
103
+
104
+ spatial_condition = extra_kargs.pop('spatial_condition')
105
+ if spatial_condition is not None:
106
+ spatial_condition_input = [spatial_condition] * len(input_prompt)
107
+ else:
108
+ spatial_condition_input = None
109
+
110
+ images = pipe(
111
+ prompt=input_prompt,
112
+ concept_models=concept_models,
113
+ negative_prompt=input_neg_prompt,
114
+ generator=generator,
115
+ guidance_scale=guidance_scale,
116
+ num_inference_steps=num_inference_steps,
117
+ cross_attention_kwargs={"scale": 0.8},
118
+ controller=controller,
119
+ stage=stage,
120
+ region_masks=region_masks,
121
+ lora_list=lora_list,
122
+ styleL=styleL,
123
+ image=spatial_condition_input,
124
+ **extra_kargs).images
125
+
126
+ return images
127
+
128
+ def load_image_yoloworld(image_source) -> Tuple[np.array, torch.Tensor]:
129
+ image = np.asarray(image_source)
130
+ return image
131
+
132
+ def load_image_dino(image_source) -> Tuple[np.array, torch.Tensor]:
133
+ transform = T.Compose(
134
+ [
135
+ T.RandomResize([800], max_size=1333),
136
+ T.ToTensor(),
137
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
138
+ ]
139
+ )
140
+ image = np.asarray(image_source)
141
+ image_transformed, _ = transform(image_source, None)
142
+ return image, image_transformed
143
+
144
+ def predict_mask(segmentmodel, sam, image, TEXT_PROMPT, segmentType, confidence = 0.2, threshold = 0.5):
145
+ if segmentType=='GroundingDINO':
146
+ image_source, image = load_image_dino(image)
147
+ boxes, logits, phrases = predict(
148
+ model=segmentmodel,
149
+ image=image,
150
+ caption=TEXT_PROMPT,
151
+ box_threshold=0.3,
152
+ text_threshold=0.25
153
+ )
154
+ sam.set_image(image_source)
155
+ H, W, _ = image_source.shape
156
+ boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])
157
+
158
+ transformed_boxes = sam.transform.apply_boxes_torch(boxes_xyxy, image_source.shape[:2]).cuda()
159
+ masks, _, _ = sam.predict_torch(
160
+ point_coords=None,
161
+ point_labels=None,
162
+ boxes=transformed_boxes,
163
+ multimask_output=False,
164
+ )
165
+ masks=masks[0].squeeze(0)
166
+ else:
167
+ image_source = load_image_yoloworld(image)
168
+ segmentmodel.set_classes([TEXT_PROMPT])
169
+ results = segmentmodel.infer(image_source, confidence=confidence)
170
+ detections = sv.Detections.from_inference(results).with_nms(
171
+ class_agnostic=True, threshold=threshold
172
+ )
173
+ masks = None
174
+ if len(detections) != 0:
175
+ print(TEXT_PROMPT + " detected!")
176
+ sam.set_image(image_source, image_format="RGB")
177
+ masks, _, _ = sam.predict(box=detections.xyxy[0], multimask_output=False)
178
+ masks = torch.from_numpy(masks.squeeze())
179
+
180
+ return masks
181
+
182
+ def prepare_text(prompt, region_prompts):
183
+ '''
184
+ Args:
185
+ prompt_entity: [subject1]-*-[attribute1]-*-[Location1]|[subject2]-*-[attribute2]-*-[Location2]|[global text]
186
+ Returns:
187
+ full_prompt: subject1, attribute1 and subject2, attribute2, global text
188
+ context_prompt: subject1 and subject2, global text
189
+ entity_collection: [(subject1, attribute1), Location1]
190
+ '''
191
+ region_collection = []
192
+
193
+ regions = region_prompts.split('|')
194
+
195
+ for region in regions:
196
+ if region == '':
197
+ break
198
+ prompt_region, neg_prompt_region = region.split('-*-')
199
+ prompt_region = prompt_region.replace('[', '').replace(']', '')
200
+ neg_prompt_region = neg_prompt_region.replace('[', '').replace(']', '')
201
+
202
+ region_collection.append((prompt_region, neg_prompt_region))
203
+ return (prompt, region_collection)
204
+
205
+
206
+ def build_model_sd(pretrained_model, controlnet_path, device, prompts):
207
+ controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16).to(device)
208
+ pipe = LoraMultiConceptPipeline.from_pretrained(
209
+ pretrained_model, controlnet=controlnet, torch_dtype=torch.float16, variant="fp16").to(device)
210
+ controller = AttentionReplace(prompts, 50, cross_replace_steps={"default_": 1.}, self_replace_steps=0.4, tokenizer=pipe.tokenizer, device=device, dtype=torch.float16, width=1024//32, height=1024//32)
211
+ revise_regionally_controlnet_forward(pipe.unet, controller)
212
+ pipe_concept = StableDiffusionXLPipeline.from_pretrained(pretrained_model, torch_dtype=torch.float16,
213
+ variant="fp16").to(device)
214
+ return pipe, controller, pipe_concept
215
+
216
+ def build_model_lora(pipe_concept, lora_paths, style_path, condition, args, pipe):
217
+ pipe_list = []
218
+ if condition == "Human pose":
219
+ controlnet = ControlNetModel.from_pretrained(args.openpose_checkpoint, torch_dtype=torch.float16).to(device)
220
+ pipe.controlnet = controlnet
221
+ elif condition == "Canny Edge":
222
+ controlnet = ControlNetModel.from_pretrained(args.canny_checkpoint, torch_dtype=torch.float16, variant="fp16").to(device)
223
+ pipe.controlnet = controlnet
224
+ elif condition == "Depth":
225
+ controlnet = ControlNetModel.from_pretrained(args.depth_checkpoint, torch_dtype=torch.float16).to(device)
226
+ pipe.controlnet = controlnet
227
+
228
+ if style_path is not None and os.path.exists(style_path):
229
+ pipe_concept.load_lora_weights(style_path, weight_name="pytorch_lora_weights.safetensors", adapter_name='style')
230
+ pipe.load_lora_weights(style_path, weight_name="pytorch_lora_weights.safetensors", adapter_name='style')
231
+
232
+ for lora_path in lora_paths.split('|'):
233
+ adapter_name = lora_path.split('/')[-1].split('.')[0]
234
+ pipe_concept.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name=adapter_name)
235
+ pipe_concept.enable_xformers_memory_efficient_attention()
236
+ pipe_list.append(adapter_name)
237
+ return pipe_list
238
+
239
+ def build_yolo_segment_model(sam_path, device):
240
+ yolo_world = YOLOWorld(model_id="yolo_world/l")
241
+ sam = EfficientViTSamPredictor(
242
+ create_sam_model(name="xl1", weight_url=sam_path).to(device).eval()
243
+ )
244
+ return yolo_world, sam
245
+
246
+ def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
247
+ args = SLConfig.fromfile(ckpt_config_filename)
248
+ model = build_model(args)
249
+ args.device = device
250
+
251
+ checkpoint = torch.load(os.path.join(repo_id, filename), map_location='cpu')
252
+ log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
253
+ print("Model loaded from {} \n => {}".format(filename, log))
254
+ _ = model.eval()
255
+ return model
256
+
257
+ def build_dino_segment_model(ckpt_repo_id, sam_checkpoint):
258
+ ckpt_filenmae = "groundingdino_swinb_cogcoor.pth"
259
+ ckpt_config_filename = os.path.join(ckpt_repo_id, "GroundingDINO_SwinB.cfg.py")
260
+ groundingdino_model = load_model_hf(ckpt_repo_id, ckpt_filenmae, ckpt_config_filename)
261
+ sam = build_sam(checkpoint=sam_checkpoint)
262
+ sam.cuda()
263
+ sam_predictor = SamPredictor(sam)
264
+ return groundingdino_model, sam_predictor
265
+
266
+ def resize_and_center_crop(image, output_size=(1024, 576)):
267
+ width, height = image.size
268
+ aspect_ratio = width / height
269
+ new_height = output_size[1]
270
+ new_width = int(aspect_ratio * new_height)
271
+
272
+ resized_image = image.resize((new_width, new_height), Image.LANCZOS)
273
+
274
+ if new_width < output_size[0] or new_height < output_size[1]:
275
+ padding_color = "gray"
276
+ resized_image = ImageOps.expand(resized_image,
277
+ ((output_size[0] - new_width) // 2,
278
+ (output_size[1] - new_height) // 2,
279
+ (output_size[0] - new_width + 1) // 2,
280
+ (output_size[1] - new_height + 1) // 2),
281
+ fill=padding_color)
282
+
283
+ left = (resized_image.width - output_size[0]) / 2
284
+ top = (resized_image.height - output_size[1]) / 2
285
+ right = (resized_image.width + output_size[0]) / 2
286
+ bottom = (resized_image.height + output_size[1]) / 2
287
+
288
+ cropped_image = resized_image.crop((left, top, right, bottom))
289
+
290
+ return cropped_image
291
+
292
+ def main(device, segment_type):
293
+ pipe, controller, pipe_concept = build_model_sd(args.pretrained_sdxl_model, args.openpose_checkpoint, device, prompts_tmp)
294
+
295
+ if segment_type == 'GroundingDINO':
296
+ detect_model, sam = build_dino_segment_model(args.dino_checkpoint, args.sam_checkpoint)
297
+ else:
298
+ detect_model, sam = build_yolo_segment_model(args.efficientViT_checkpoint, device)
299
+
300
+ resolution_list = ["1440*728",
301
+ "1344*768",
302
+ "1216*832",
303
+ "1152*896",
304
+ "1024*1024",
305
+ "896*1152",
306
+ "832*1216",
307
+ "768*1344",
308
+ "728*1440"]
309
+ ratio_list = [1440 / 728, 1344 / 768, 1216 / 832, 1152 / 896, 1024 / 1024, 896 / 1152, 832 / 1216, 768 / 1344,
310
+ 728 / 1440]
311
+ condition_list = ["None",
312
+ "Human pose",
313
+ "Canny Edge",
314
+ "Depth"]
315
+
316
+ depth_estimator = DPTForDepthEstimation.from_pretrained(args.dpt_checkpoint).to("cuda")
317
+ feature_extractor = DPTFeatureExtractor.from_pretrained(args.dpt_checkpoint)
318
+ body_model = Body(args.pose_detector_checkpoint)
319
+ openpose = OpenposeDetector(body_model)
320
+
321
+ def remove_tips():
322
+ return gr.update(visible=False)
323
+
324
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
325
+ if randomize_seed:
326
+ seed = random.randint(0, MAX_SEED)
327
+ return seed
328
+
329
+ def get_humanpose(img):
330
+ openpose_image = openpose(img)
331
+ return openpose_image
332
+
333
+ def get_cannyedge(image):
334
+ image = np.array(image)
335
+ image = cv2.Canny(image, 100, 200)
336
+ image = image[:, :, None]
337
+ image = np.concatenate([image, image, image], axis=2)
338
+ canny_image = Image.fromarray(image)
339
+ return canny_image
340
+
341
+ def get_depth(image):
342
+ image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
343
+ with torch.no_grad(), torch.autocast("cuda"):
344
+ depth_map = depth_estimator(image).predicted_depth
345
+
346
+ depth_map = torch.nn.functional.interpolate(
347
+ depth_map.unsqueeze(1),
348
+ size=(1024, 1024),
349
+ mode="bicubic",
350
+ align_corners=False,
351
+ )
352
+ depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
353
+ depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
354
+ depth_map = (depth_map - depth_min) / (depth_max - depth_min)
355
+ image = torch.cat([depth_map] * 3, dim=1)
356
+ image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
357
+ image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
358
+ return image
359
+
360
+ def generate_image(prompt1, negative_prompt, man, woman, resolution, local_prompt1, local_prompt2, seed, condition, condition_img1, style):
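+ # Single-image variant of app.py's generate_image: one prompt, one optional condition
+ # image, and the spatial condition is appended to the outputs so it can be displayed.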
361
+ try:
362
+ path1 = lorapath_man[man]
363
+ path2 = lorapath_woman[woman]
364
+ pipe_concept.unload_lora_weights()
365
+ pipe.unload_lora_weights()
366
+ pipe_list = build_model_lora(pipe_concept, path1 + "|" + path2, lorapath_styles[style], condition, args, pipe)
367
+
368
+ if lorapath_styles[style] is not None and os.path.exists(lorapath_styles[style]):
369
+ styleL = True
370
+ else:
371
+ styleL = False
372
+
373
+ input_list = [prompt1]
374
+ condition_list = [condition_img1]
375
+ output_list = []
376
+
377
+ width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
378
+
379
+ kwargs = {
380
+ 'height': height,
381
+ 'width': width,
382
+ }
383
+
384
+ for prompt, condition_img in zip(input_list, condition_list):
385
+ if prompt!='':
386
+ input_prompt = []
387
+ p = '{prompt}, 35mm photograph, film, professional, 4k, highly detailed.'
388
+ if styleL:
389
+ p = styles[style] + p
390
+ input_prompt.append([p.replace("{prompt}", prompt), p.replace("{prompt}", prompt)])
391
+ if styleL:
392
+ input_prompt.append([(styles[style] + local_prompt1, character_man.get(man)[1]),
393
+ (styles[style] + local_prompt2, character_woman.get(woman)[1])])
394
+ else:
395
+ input_prompt.append([(local_prompt1, character_man.get(man)[1]),
396
+ (local_prompt2, character_woman.get(woman)[1])])
397
+
398
+ if condition == 'Human pose' and condition_img is not None:
399
+ index = ratio_list.index(
400
+ min(ratio_list, key=lambda x: abs(x - condition_img.shape[1] / condition_img.shape[0])))
401
+ resolution = resolution_list[index]
402
+ width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
403
+ kwargs['height'] = height
404
+ kwargs['width'] = width
405
+ condition_img = resize_and_center_crop(Image.fromarray(condition_img), (width, height))
406
+ spatial_condition = get_humanpose(condition_img)
407
+ elif condition == 'Canny Edge' and condition_img is not None:
408
+ index = ratio_list.index(
409
+ min(ratio_list, key=lambda x: abs(x - condition_img.shape[1] / condition_img.shape[0])))
410
+ resolution = resolution_list[index]
411
+ width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
412
+ kwargs['height'] = height
413
+ kwargs['width'] = width
414
+ condition_img = resize_and_center_crop(Image.fromarray(condition_img), (width, height))
415
+ spatial_condition = get_cannyedge(condition_img)
416
+ elif condition == 'Depth' and condition_img is not None:
417
+ index = ratio_list.index(
418
+ min(ratio_list, key=lambda x: abs(x - condition_img.shape[1] / condition_img.shape[0])))
419
+ resolution = resolution_list[index]
420
+ width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
421
+ kwargs['height'] = height
422
+ kwargs['width'] = width
423
+ condition_img = resize_and_center_crop(Image.fromarray(condition_img), (width, height))
424
+ spatial_condition = get_depth(condition_img)
425
+ else:
426
+ spatial_condition = None
427
+
428
+ kwargs['spatial_condition'] = spatial_condition
429
+ controller.reset()
430
+ image = sample_image(
431
+ pipe,
432
+ input_prompt=input_prompt,
433
+ concept_models=pipe_concept,
434
+ input_neg_prompt=[negative_prompt] * len(input_prompt),
435
+ generator=torch.Generator(device).manual_seed(seed),
436
+ controller=controller,
437
+ stage=1,
438
+ lora_list=pipe_list,
439
+ styleL=styleL,
440
+ **kwargs)
441
+
442
+ controller.reset()
443
+ if pipe.tokenizer("man")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]:
444
+ mask1 = predict_mask(detect_model, sam, image[0], 'man', args.segment_type, confidence=0.15,
445
+ threshold=0.5)
446
+ else:
447
+ mask1 = None
448
+
449
+ if pipe.tokenizer("woman")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]:
450
+ mask2 = predict_mask(detect_model, sam, image[0], 'woman', args.segment_type, confidence=0.15,
451
+ threshold=0.5)
452
+ else:
453
+ mask2 = None
454
+
455
+ if mask1 is None and mask2 is None:
456
+ output_list.append(image[1])
457
+ else:
458
+ image = sample_image(
459
+ pipe,
460
+ input_prompt=input_prompt,
461
+ concept_models=pipe_concept,
462
+ input_neg_prompt=[negative_prompt] * len(input_prompt),
463
+ generator=torch.Generator(device).manual_seed(seed),
464
+ controller=controller,
465
+ stage=2,
466
+ region_masks=[mask1, mask2],
467
+ lora_list=pipe_list,
468
+ styleL=styleL,
469
+ **kwargs)
470
+ output_list.append(image[1])
471
+ else:
472
+ output_list.append(None)
473
+ output_list.append(spatial_condition)
474
+ return output_list
475
+ except Exception as e:
476
+ print(f"Generation failed: {e}")
477
+ return None, None
478
+
479
+ def get_local_value_man(input):
480
+ return character_man[input][0]
481
+
482
+ def get_local_value_woman(input):
483
+ return character_woman[input][0]
484
+
485
+
486
+ with gr.Blocks(css=css) as demo:
487
+ # description
488
+ gr.Markdown(title)
489
+ gr.Markdown(description)
490
+
491
+ with gr.Row():
492
+ gallery = gr.Image(label="Generated Images", height=512, width=512)
493
+ gen_condition = gr.Image(label="Spatial Condition", height=512, width=512)
494
+ usage_tips = gr.Markdown(label="Usage tips of OMG", value=tips, visible=False)
495
+
496
+ with gr.Row():
497
+ condition_img1 = gr.Image(label="Input an RGB image for condition", height=128, width=128)
498
+
499
+ # character choose
500
+ with gr.Row():
501
+ man = gr.Dropdown(label="Character 1 selection", choices=CHARACTER_MAN_NAMES, value="Chris Evans (identifier: Chris Evans)")
502
+ woman = gr.Dropdown(label="Character 2 selection", choices=CHARACTER_WOMAN_NAMES, value="Taylor Swift (identifier: TaylorSwift)")
503
+ resolution = gr.Dropdown(label="Image Resolution (width*height)", choices=resolution_list, value="1024*1024")
504
+ condition = gr.Dropdown(label="Input condition type", choices=condition_list, value="None")
505
+ style = gr.Dropdown(label="style", choices=STYLE_NAMES, value="None")
506
+
507
+ with gr.Row():
508
+ local_prompt1 = gr.Textbox(label="Character1_prompt",
509
+ info="Describe the Character 1, this prompt should include the identifier of character 1",
510
+ value="Close-up photo of the Chris Evans, 35mm photograph, film, professional, 4k, highly detailed.")
511
+ local_prompt2 = gr.Textbox(label="Character2_prompt",
512
+ info="Describe the Character 2, this prompt should include the identifier of character2",
513
+ value="Close-up photo of the TaylorSwift, 35mm photograph, film, professional, 4k, highly detailed.")
514
+
515
+ man.change(get_local_value_man, man, local_prompt1)
516
+ woman.change(get_local_value_woman, woman, local_prompt2)
517
+
518
+ # prompt
519
+ with gr.Column():
520
+ prompt = gr.Textbox(label="Prompt 1",
521
+ info="Give a simple prompt to describe the first image content",
522
+ placeholder="Required",
523
+ value="close-up shot, photography, a man and a woman on the street, facing the camera smiling")
524
+
525
+
526
+ with gr.Accordion(open=False, label="Advanced Options"):
527
+ seed = gr.Slider(
528
+ label="Seed",
529
+ minimum=0,
530
+ maximum=MAX_SEED,
531
+ step=1,
532
+ value=42,
533
+ )
534
+ negative_prompt = gr.Textbox(label="Negative Prompt",
535
+ placeholder="noisy, blurry, soft, deformed, ugly",
536
+ value="noisy, blurry, soft, deformed, ugly")
537
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
538
+
539
+ submit = gr.Button("Submit", variant="primary")
540
+
541
+ submit.click(
542
+ fn=remove_tips,
543
+ outputs=usage_tips,
544
+ ).then(
545
+ fn=randomize_seed_fn,
546
+ inputs=[seed, randomize_seed],
547
+ outputs=seed,
548
+ queue=False,
549
+ api_name=False,
550
+ ).then(
551
+ fn=generate_image,
552
+ inputs=[prompt, negative_prompt, man, woman, resolution, local_prompt1, local_prompt2, seed, condition, condition_img1, style],
553
+ outputs=[gallery, gen_condition]
554
+ )
555
+ demo.launch(server_name='0.0.0.0',server_port=7861, debug=True)
556
+
557
+ def parse_args():
558
+ parser = argparse.ArgumentParser('', add_help=False)
559
+ parser.add_argument('--pretrained_sdxl_model', default='./checkpoint/stable-diffusion-xl-base-1.0', type=str)
560
+ parser.add_argument('--openpose_checkpoint', default='./checkpoint/controlnet-openpose-sdxl-1.0', type=str)
561
+ parser.add_argument('--canny_checkpoint', default='./checkpoint/controlnet-canny-sdxl-1.0', type=str)
562
+ parser.add_argument('--depth_checkpoint', default='./checkpoint/controlnet-depth-sdxl-1.0', type=str)
563
+ parser.add_argument('--efficientViT_checkpoint', default='./checkpoint/sam/xl1.pt', type=str)
564
+ parser.add_argument('--dino_checkpoint', default='./checkpoint/GroundingDINO', type=str)
565
+ parser.add_argument('--sam_checkpoint', default='./checkpoint/sam/sam_vit_h_4b8939.pth', type=str)
566
+ parser.add_argument('--dpt_checkpoint', default='./checkpoint/dpt-hybrid-midas', type=str)
567
+ parser.add_argument('--pose_detector_checkpoint', default='./checkpoint/ControlNet/annotator/ckpts/body_pose_model.pth', type=str)
568
+ parser.add_argument('--prompt', default='Close-up photo of the cool man and beautiful woman in surprised expressions as they accidentally discover a mysterious island while on vacation by the sea, 35mm photograph, film, professional, 4k, highly detailed.', type=str)
569
+ parser.add_argument('--negative_prompt', default='noisy, blurry, soft, deformed, ugly', type=str)
570
+ parser.add_argument('--seed', default=22, type=int)
571
+ parser.add_argument('--suffix', default='', type=str)
572
+ parser.add_argument('--segment_type', default='yoloworld', help='GroundingDINO or yoloworld', type=str)
573
+ return parser.parse_args()
574
+
575
+ if __name__ == '__main__':
576
+ args = parse_args()
577
+
578
+ prompts = [args.prompt]*2
579
+ prompts_tmp = copy.deepcopy(prompts)
580
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
581
+
582
+ main(device, args.segment_type)
gradio_demo/app_instantID.py ADDED
@@ -0,0 +1,701 @@
1
+ import sys
2
+ sys.path.append('./')
3
+ import argparse
4
+ import hashlib
5
+ import json
6
+ import os.path
7
+ import numpy as np
8
+ import torch
9
+ from typing import Tuple, List
10
+ from diffusers import DPMSolverMultistepScheduler
11
+ from diffusers.models import T2IAdapter
12
+ from PIL import Image
13
+ import copy
14
+ from diffusers import ControlNetModel, StableDiffusionXLPipeline
15
+ from insightface.app import FaceAnalysis
16
+ import gradio as gr
17
+ import random
18
+ from PIL import Image, ImageOps
19
+ from transformers import DPTFeatureExtractor, DPTForDepthEstimation
20
+ from controlnet_aux import OpenposeDetector
21
+ from controlnet_aux.open_pose.body import Body
22
+
23
+ try:
24
+ from inference.models import YOLOWorld
25
+ from src.efficientvit.models.efficientvit.sam import EfficientViTSamPredictor
26
+ from src.efficientvit.sam_model_zoo import create_sam_model
27
+ import supervision as sv
28
+ except:
29
+ print("YoloWorld can not be load")
30
+
31
+ try:
32
+ from groundingdino.models import build_model
33
+ from groundingdino.util import box_ops
34
+ from groundingdino.util.slconfig import SLConfig
35
+ from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
36
+ from groundingdino.util.inference import annotate, predict
37
+ from segment_anything import build_sam, SamPredictor
38
+ import groundingdino.datasets.transforms as T
39
+ except:
40
+ print("groundingdino can not be load")
41
+
42
+ from src.pipelines.instantid_pipeline import InstantidMultiConceptPipeline
43
+ from src.pipelines.instantid_single_pieline import InstantidSingleConceptPipeline
44
+ from src.prompt_attention.p2p_attention import AttentionReplace
45
+ from src.pipelines.instantid_pipeline import revise_regionally_controlnet_forward
46
+ import cv2
47
+ import math
48
+ import PIL.Image
49
+
50
+ from gradio_demo.character_template import styles, lorapath_styles
51
+ STYLE_NAMES = list(styles.keys())
52
+
53
+
54
+
55
+ MAX_SEED = np.iinfo(np.int32).max
56
+
57
+ title = r"""
58
+ <h1 align="center">OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models (OMG + InstantID)</h1>
59
+ """
60
+
61
+ description = r"""
62
+ <b>Official 🤗 Gradio demo</b> for <a href='https://github.com/kongzhecn/OMG/' target='_blank'><b>OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models</b></a>.<br>
63
+ <a href='https://kongzhecn.github.io/omg-project/' target='_blank'><b>[Project]</b></a>.<a href='https://github.com/kongzhecn/OMG/' target='_blank'><b>[Code]</b></a>.<a href='https://arxiv.org/abs/2403.10983/' target='_blank'><b>[Arxiv]</b></a>.<br>
64
+ How to use:<br>
65
+ 1. Select two characters.
66
+ 2. Enter a text prompt as done in normal text-to-image models.
67
+ 3. Click the <b>Submit</b> button to start customizing.
68
+ 4. Enjoy the generated image😊!
69
+ """
70
+
71
+ article = r"""
72
+ ---
73
+ 📝 **Citation**
74
+ <br>
75
+ If our work is helpful for your research or applications, please cite us via:
76
+ ```bibtex
77
+ @article{,
78
+ title={OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models},
79
+ author={},
80
+ journal={},
81
+ year={}
82
+ }
83
+ ```
84
+ """
85
+
86
+ tips = r"""
87
+ ### Usage tips of OMG
88
+ 1. Input text prompts to describe a man and a woman
89
+ """
90
+
91
+ css = '''
92
+ .gradio-container {width: 85% !important}
93
+ '''
94
+
95
+
96
+
97
+ def build_dino_segment_model(ckpt_repo_id, sam_checkpoint):
98
+ ckpt_filenmae = "groundingdino_swinb_cogcoor.pth"
99
+ ckpt_config_filename = os.path.join(ckpt_repo_id, "GroundingDINO_SwinB.cfg.py")
100
+ groundingdino_model = load_model_hf(ckpt_repo_id, ckpt_filenmae, ckpt_config_filename)
101
+ sam = build_sam(checkpoint=sam_checkpoint)
102
+ sam.cuda()
103
+ sam_predictor = SamPredictor(sam)
104
+ return groundingdino_model, sam_predictor
105
+
106
+ def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
107
+ args = SLConfig.fromfile(ckpt_config_filename)
108
+ model = build_model(args)
109
+ args.device = device
110
+
111
+ checkpoint = torch.load(os.path.join(repo_id, filename), map_location='cpu')
112
+ log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
113
+ print("Model loaded from {} \n => {}".format(filename, log))
114
+ _ = model.eval()
115
+ return model
116
+
117
+ def build_yolo_segment_model(sam_path, device):
118
+ yolo_world = YOLOWorld(model_id="yolo_world/l")
119
+ sam = EfficientViTSamPredictor(
120
+ create_sam_model(name="xl1", weight_url=sam_path).to(device).eval()
121
+ )
122
+ return yolo_world, sam
123
+
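+ # sample_image wraps the multi-concept pipeline call. stage=1 renders the base
+ # composition from the text prompt (plus any optional spatial condition); stage=2
+ # re-runs denoising with the face-keypoint image and per-character region masks so
+ # each identity is injected only inside its own mask.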
124
+ def sample_image(pipe,
125
+ input_prompt,
126
+ input_neg_prompt=None,
127
+ generator=None,
128
+ concept_models=None,
129
+ num_inference_steps=50,
130
+ guidance_scale=7.5,
131
+ controller=None,
132
+ face_app=None,
133
+ image=None,
134
+ stage=None,
135
+ region_masks=None,
136
+ controlnet_conditioning_scale=None,
137
+ **extra_kargs
138
+ ):
139
+
140
+ if image is not None:
141
+ image_condition = [image]
142
+ else:
143
+ image_condition = None
144
+
145
+
146
+ images = pipe(
147
+ prompt=input_prompt,
148
+ concept_models=concept_models,
149
+ negative_prompt=input_neg_prompt,
150
+ generator=generator,
151
+ guidance_scale=guidance_scale,
152
+ num_inference_steps=num_inference_steps,
153
+ cross_attention_kwargs={"scale": 0.8},
154
+ controller=controller,
155
+ image=image_condition,
156
+ face_app=face_app,
157
+ stage=stage,
158
+ controlnet_conditioning_scale = controlnet_conditioning_scale,
159
+ region_masks=region_masks,
160
+ **extra_kargs).images
161
+ return images
162
+
163
+ def load_image_yoloworld(image_source) -> np.ndarray:
164
+ image = np.asarray(image_source)
165
+ return image
166
+
167
+ def load_image_dino(image_source) -> Tuple[np.array, torch.Tensor]:
168
+ transform = T.Compose(
169
+ [
170
+ T.RandomResize([800], max_size=1333),
171
+ T.ToTensor(),
172
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
173
+ ]
174
+ )
175
+ image = np.asarray(image_source)
176
+ image_transformed, _ = transform(image_source, None)
177
+ return image, image_transformed
178
+
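+ # draw_kps_multi rasterizes InstantID-style facial keypoints for several faces onto a
+ # single black canvas: limbs as filled ellipses (dimmed to 60%), joints as colored circles.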
179
+ def draw_kps_multi(image_pil, kps_list, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]):
180
+ stickwidth = 4
181
+ limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
182
+
183
+
184
+ w, h = image_pil.size
185
+ out_img = np.zeros([h, w, 3])
186
+
187
+ for kps in kps_list:
188
+ kps = np.array(kps)
189
+ for i in range(len(limbSeq)):
190
+ index = limbSeq[i]
191
+ color = color_list[index[0]]
192
+
193
+ x = kps[index][:, 0]
194
+ y = kps[index][:, 1]
195
+ length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
196
+ angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
197
+ polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0,
198
+ 360, 1)
199
+ out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
200
+ out_img = (out_img * 0.6).astype(np.uint8)
201
+
202
+ for idx_kp, kp in enumerate(kps):
203
+ color = color_list[idx_kp]
204
+ x, y = kp
205
+ out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)
206
+
207
+ out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8))
208
+ return out_img_pil
209
+
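+ # predict_mask segments the requested classes with either GroundingDINO + SAM or
+ # YOLO-World + EfficientViT-SAM and returns one boolean mask per class (None when a
+ # class is not detected). Illustrative call, mirroring the usage in generate_image below:
+ #   mask_man, mask_woman = predict_mask(detect_model, sam, image, ['man', 'woman'],
+ #                                       'yoloworld', confidence=0.3, threshold=0.5)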
210
+ def predict_mask(segmentmodel, sam, image, TEXT_PROMPT, segmentType, confidence = 0.2, threshold = 0.5):
211
+ if segmentType=='GroundingDINO':
212
+ image_source, image = load_image_dino(image)
213
+ boxes, logits, phrases = predict(
214
+ model=segmentmodel,
215
+ image=image,
216
+ caption=TEXT_PROMPT,
217
+ box_threshold=0.3,
218
+ text_threshold=0.25
219
+ )
220
+ sam.set_image(image_source)
221
+ H, W, _ = image_source.shape
222
+ boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])
223
+
224
+ transformed_boxes = sam.transform.apply_boxes_torch(boxes_xyxy, image_source.shape[:2]).cuda()
225
+ masks, _, _ = sam.predict_torch(
226
+ point_coords=None,
227
+ point_labels=None,
228
+ boxes=transformed_boxes,
229
+ multimask_output=False,
230
+ )
231
+ masks=masks[0].squeeze(0)
232
+ else:
233
+ image_source = load_image_yoloworld(image)
234
+ segmentmodel.set_classes(TEXT_PROMPT)
235
+ results = segmentmodel.infer(image_source, confidence=confidence)
236
+ detections = sv.Detections.from_inference(results).with_nms(
237
+ class_agnostic=True, threshold=threshold
238
+ )
239
+
240
+ masks_list = []
241
+ sam.set_image(image_source, image_format="RGB")
242
+ for xyxy in detections.xyxy:
243
+ mask, _, _ = sam.predict(box=xyxy, multimask_output=False)
244
+ masks_list.append(mask.squeeze())
245
+ detections.mask = np.array(masks_list)
246
+
247
+ mask_1 = []
248
+ mask_2 = []
249
+ for i, (class_id, confidence) in enumerate(zip(detections.class_id, detections.confidence)):
250
+ if class_id==0:
251
+ mask_1.append(torch.from_numpy(detections.mask[i]))
252
+ if class_id==1:
253
+ mask_2.append(torch.from_numpy(detections.mask[i]))
254
+ if len(mask_1)==0:
255
+ mask_1.append(None)
256
+ if len(mask_2)==0:
257
+ mask_2.append(None)
258
+ if len(TEXT_PROMPT)==2:
259
+ return mask_1[0], mask_2[0]
260
+
261
+ return mask_1[0]
262
+
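+ # build_model_sd assembles the global InstantID multi-concept pipeline and a
+ # single-concept pipeline (both with ControlNets loaded from the same checkpoint),
+ # hooks the prompt-to-prompt AttentionReplace controller into the UNet, loads the
+ # InstantID face adapter, optionally attaches a style LoRA, and prepares the
+ # antelopev2 face-analysis app used for face detection.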
263
+ def build_model_sd(pretrained_model, controlnet_path, face_adapter, device, prompts, antelopev2_path, width, height, style_lora):
264
+ controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
265
+ pipe = InstantidMultiConceptPipeline.from_pretrained(
266
+ pretrained_model, controlnet=controlnet, torch_dtype=torch.float16, variant="fp16").to(device)
267
+
268
+ controller = AttentionReplace(prompts, 50, cross_replace_steps={"default_": 1.},
269
+ self_replace_steps=0.4, tokenizer=pipe.tokenizer, device=device, width=width, height=height,
270
+ dtype=torch.float16)
271
+ revise_regionally_controlnet_forward(pipe.unet, controller)
272
+
273
+ controlnet_concept = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
274
+ pipe_concept = InstantidSingleConceptPipeline.from_pretrained(
275
+ pretrained_model,
276
+ controlnet=controlnet_concept,
277
+ torch_dtype=torch.float16
278
+ )
279
+ pipe_concept.load_ip_adapter_instantid(face_adapter)
280
+ pipe_concept.set_ip_adapter_scale(0.8)
281
+ pipe_concept.to(device)
282
+ pipe_concept.image_proj_model.to(pipe_concept._execution_device)
283
+
284
+ if style_lora is not None and os.path.exists(style_lora):
285
+ pipe.load_lora_weights(style_lora, weight_name="pytorch_lora_weights.safetensors", adapter_name='style')
286
+ pipe_concept.load_lora_weights(style_lora, weight_name="pytorch_lora_weights.safetensors", adapter_name='style')
287
+
288
+
289
+ # modify
290
+ app = FaceAnalysis(name='antelopev2', root=antelopev2_path,
291
+ providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
292
+ app.prepare(ctx_id=0, det_size=(640, 640))
293
+
294
+ return pipe, controller, pipe_concept, app
295
+
296
+
297
+ def prepare_text(prompt, region_prompts):
298
+ '''
299
+ Args:
300
+ region_prompts: '[prompt1]-*-[neg_prompt1]-*-ref_img1|[prompt2]-*-[neg_prompt2]-*-ref_img2'
301
+ Returns:
302
+ (prompt, region_collection), where region_collection is a list of
303
+ (prompt_region, neg_prompt_region, ref_img) tuples, one per '|'-separated region,
304
+ with the surrounding square brackets stripped from the prompt and negative prompt.
305
+ '''
306
+ region_collection = []
307
+
308
+ regions = region_prompts.split('|')
309
+
310
+ for region in regions:
311
+ if region == '':
312
+ break
313
+ prompt_region, neg_prompt_region, ref_img = region.split('-*-')
314
+ prompt_region = prompt_region.replace('[', '').replace(']', '')
315
+ neg_prompt_region = neg_prompt_region.replace('[', '').replace(']', '')
316
+
317
+ region_collection.append((prompt_region, neg_prompt_region, ref_img))
318
+ return (prompt, region_collection)
319
+
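+ # prepare_text expects region_prompts in the same layout as --prompt_rewrite, e.g.:
+ #   '[Close-up photo of a man, ...]-*-[noisy, blurry, soft, deformed, ugly]-*-../example/chris-evans.jpg|'
+ #   '[Close-up photo of a woman, ...]-*-[noisy, blurry, soft, deformed, ugly]-*-../example/TaylorSwift.png'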
320
+ def build_model_lora(pipe, pipe_concept, style_path, condition, condition_img):
321
+ if condition == "Human pose" and condition_img is not None:
322
+ controlnet = ControlNetModel.from_pretrained(args.openpose_checkpoint, torch_dtype=torch.float16).to(device)
323
+ pipe.controlnet2 = controlnet
324
+ elif condition == "Canny Edge" and condition_img is not None:
325
+ controlnet = ControlNetModel.from_pretrained(args.canny_checkpoint, torch_dtype=torch.float16, variant="fp16").to(device)
326
+ pipe.controlnet2 = controlnet
327
+ elif condition == "Depth" and condition_img is not None:
328
+ controlnet = ControlNetModel.from_pretrained(args.depth_checkpoint, torch_dtype=torch.float16).to(device)
329
+ pipe.controlnet2 = controlnet
330
+
331
+ if style_path is not None and os.path.exists(style_path):
332
+ pipe_concept.load_lora_weights(style_path, weight_name="pytorch_lora_weights.safetensors", adapter_name='style')
333
+ pipe.load_lora_weights(style_path, weight_name="pytorch_lora_weights.safetensors", adapter_name='style')
334
+
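+ # build_model_lora above attaches the spatial ControlNet chosen by the condition type as
+ # pipe.controlnet2; resize_and_center_crop below scales the image to the target height,
+ # pads with gray when it is narrower than the target width, then center-crops to size.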
335
+ def resize_and_center_crop(image, output_size=(1024, 576)):
336
+ width, height = image.size
337
+ aspect_ratio = width / height
338
+ new_height = output_size[1]
339
+ new_width = int(aspect_ratio * new_height)
340
+
341
+ resized_image = image.resize((new_width, new_height), Image.LANCZOS)
342
+
343
+ if new_width < output_size[0] or new_height < output_size[1]:
344
+ padding_color = "gray"
345
+ resized_image = ImageOps.expand(resized_image,
346
+ ((output_size[0] - new_width) // 2,
347
+ (output_size[1] - new_height) // 2,
348
+ (output_size[0] - new_width + 1) // 2,
349
+ (output_size[1] - new_height + 1) // 2),
350
+ fill=padding_color)
351
+
352
+ left = (resized_image.width - output_size[0]) / 2
353
+ top = (resized_image.height - output_size[1]) / 2
354
+ right = (resized_image.width + output_size[0]) / 2
355
+ bottom = (resized_image.height + output_size[1]) / 2
356
+
357
+ cropped_image = resized_image.crop((left, top, right, bottom))
358
+
359
+ return cropped_image
360
+
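+ # main builds the diffusion pipelines, the selected segmentation stack, and the spatial
+ # condition preprocessors (DPT depth, OpenPose), then defines and launches the Gradio UI.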
361
+ def main(device, segment_type):
362
+ pipe, controller, pipe_concepts, face_app = build_model_sd(args.pretrained_model, args.controlnet_path,
363
+ args.face_adapter_path, device, prompts_tmp,
364
+ args.antelopev2_path, width // 32, height // 32,
365
+ args.style_lora)
366
+ if segment_type == 'GroundingDINO':
367
+ detect_model, sam = build_dino_segment_model(args.dino_checkpoint, args.sam_checkpoint)
368
+ else:
369
+ detect_model, sam = build_yolo_segment_model(args.efficientViT_checkpoint, device)
370
+
371
+ resolution_list = ["1440*728",
372
+ "1344*768",
373
+ "1216*832",
374
+ "1152*896",
375
+ "1024*1024",
376
+ "896*1152",
377
+ "832*1216",
378
+ "768*1344",
379
+ "728*1440"]
380
+ ratio_list = [1440 / 728, 1344 / 768, 1216 / 832, 1152 / 896, 1024 / 1024, 896 / 1152, 832 / 1216, 768 / 1344,
381
+ 728 / 1440]
382
+ condition_list = ["None",
383
+ "Human pose",
384
+ "Canny Edge",
385
+ "Depth"]
386
+
387
+ depth_estimator = DPTForDepthEstimation.from_pretrained(args.dpt_checkpoint).to("cuda")
388
+ feature_extractor = DPTFeatureExtractor.from_pretrained(args.dpt_checkpoint)
389
+ body_model = Body(args.pose_detector_checkpoint)
390
+ openpose = OpenposeDetector(body_model)
391
+
392
+ prompts_rewrite = [args.prompt_rewrite]
393
+ input_prompt_test = [prepare_text(p, p_w) for p, p_w in zip(prompts, prompts_rewrite)]
394
+ input_prompt_test = [prompts, input_prompt_test[0][1]]
395
+
396
+ def remove_tips():
397
+ return gr.update(visible=False)
398
+
399
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
400
+ if randomize_seed:
401
+ seed = random.randint(0, MAX_SEED)
402
+ return seed
403
+
404
+ def get_humanpose(img):
405
+ openpose_image = openpose(img)
406
+ return openpose_image
407
+
408
+ def get_cannyedge(image):
409
+ image = np.array(image)
410
+ image = cv2.Canny(image, 100, 200)
411
+ image = image[:, :, None]
412
+ image = np.concatenate([image, image, image], axis=2)
413
+ canny_image = Image.fromarray(image)
414
+ return canny_image
415
+
416
+ def get_depth(image):
417
+ image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
418
+ with torch.no_grad(), torch.autocast("cuda"):
419
+ depth_map = depth_estimator(image).predicted_depth
420
+
421
+ depth_map = torch.nn.functional.interpolate(
422
+ depth_map.unsqueeze(1),
423
+ size=(1024, 1024),
424
+ mode="bicubic",
425
+ align_corners=False,
426
+ )
427
+ depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
428
+ depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
429
+ depth_map = (depth_map - depth_min) / (depth_max - depth_min)
430
+ image = torch.cat([depth_map] * 3, dim=1)
431
+ image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
432
+ image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
433
+ return image
434
+
435
+ def generate_image(prompt1, negative_prompt, reference_1, reference_2, resolution, local_prompt1, local_prompt2, seed, style, identitynet_strength_ratio, adapter_strength_ratio, condition, condition_img, controlnet_ratio):
436
+ identitynet_strength_ratio = float(identitynet_strength_ratio)
437
+ adapter_strength_ratio = float(adapter_strength_ratio)
438
+ controlnet_ratio = float(controlnet_ratio)
439
+ if lorapath_styles[style] is not None and os.path.exists(lorapath_styles[style]):
440
+ styleL = True
441
+ else:
442
+ styleL = False
443
+
444
+ width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
445
+ kwargs = {
446
+ 'height': height,
447
+ 'width': width,
448
+ 't2i_controlnet_conditioning_scale': controlnet_ratio,
449
+ }
450
+
451
+ if condition == 'Human pose' and condition_img is not None:
452
+ index = ratio_list.index(
453
+ min(ratio_list, key=lambda x: abs(x - condition_img.shape[1] / condition_img.shape[0])))
454
+ resolution = resolution_list[index]
455
+ width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
456
+ kwargs['height'] = height
457
+ kwargs['width'] = width
458
+ condition_img = resize_and_center_crop(Image.fromarray(condition_img), (width, height))
459
+ spatial_condition = get_humanpose(condition_img)
460
+ elif condition == 'Canny Edge' and condition_img is not None:
461
+ index = ratio_list.index(
462
+ min(ratio_list, key=lambda x: abs(x - condition_img.shape[1] / condition_img.shape[0])))
463
+ resolution = resolution_list[index]
464
+ width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
465
+ kwargs['height'] = height
466
+ kwargs['width'] = width
467
+ condition_img = resize_and_center_crop(Image.fromarray(condition_img), (width, height))
468
+ spatial_condition = get_cannyedge(condition_img)
469
+ elif condition == 'Depth' and condition_img is not None:
470
+ index = ratio_list.index(
471
+ min(ratio_list, key=lambda x: abs(x - condition_img.shape[1] / condition_img.shape[0])))
472
+ resolution = resolution_list[index]
473
+ width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
474
+ kwargs['height'] = height
475
+ kwargs['width'] = width
476
+ condition_img = resize_and_center_crop(Image.fromarray(condition_img), (width, height))
477
+ spatial_condition = get_depth(condition_img)
478
+ else:
479
+ spatial_condition = None
480
+
481
+ kwargs['t2i_image'] = spatial_condition
482
+ pipe.unload_lora_weights()
483
+ pipe_concepts.unload_lora_weights()
484
+ build_model_lora(pipe, pipe_concepts, lorapath_styles[style], condition, condition_img)
485
+ pipe_concepts.set_ip_adapter_scale(adapter_strength_ratio)
486
+
487
+ input_list = [prompt1]
488
+
489
+
490
+ for prompt in input_list:
491
+ if prompt != '':
492
+ input_prompt = []
493
+ p = '{prompt}, 35mm photograph, film, professional, 4k, highly detailed.'
494
+ if styleL:
495
+ p = styles[style] + p
496
+ input_prompt.append([p.replace('{prompt}', prompt), p.replace("{prompt}", prompt)])
497
+ if styleL:
498
+ input_prompt.append([(styles[style] + local_prompt1, 'noisy, blurry, soft, deformed, ugly',
499
+ PIL.Image.fromarray(reference_1)),
500
+ (styles[style] + local_prompt2, 'noisy, blurry, soft, deformed, ugly',
501
+ PIL.Image.fromarray(reference_2))])
502
+ else:
503
+ input_prompt.append(
504
+ [(local_prompt1, 'noisy, blurry, soft, deformed, ugly', PIL.Image.fromarray(reference_1)),
505
+ (local_prompt2, 'noisy, blurry, soft, deformed, ugly', PIL.Image.fromarray(reference_2))])
506
+
507
+
508
+ controller.reset()
509
+ image = sample_image(
510
+ pipe,
511
+ input_prompt=input_prompt,
512
+ concept_models=pipe_concepts,
513
+ input_neg_prompt=[negative_prompt] * len(input_prompt),
514
+ generator=torch.Generator(device).manual_seed(seed),
515
+ controller=controller,
516
+ face_app=face_app,
517
+ controlnet_conditioning_scale=identitynet_strength_ratio,
518
+ stage=1,
519
+ **kwargs)
520
+
521
+ controller.reset()
522
+
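+ # Stage 2: locate the man/woman regions in the stage-1 image, extract facial keypoints,
+ # then re-denoise with per-region identity conditioning (IdentityNet + region masks).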
523
+ if (pipe.tokenizer("man")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]) and (
524
+ pipe.tokenizer("woman")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]):
525
+ mask1, mask2 = predict_mask(detect_model, sam, image[0], ['man', 'woman'], args.segment_type, confidence=0.3,
526
+ threshold=0.5)
527
+
528
+ elif pipe.tokenizer("man")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]:
529
+ mask1 = predict_mask(detect_model, sam, image[0], ['man'], args.segment_type, confidence=0.3,
530
+ threshold=0.5)
531
+ mask2 = None
532
+
533
+ elif pipe.tokenizer("woman")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]:
534
+ mask2 = predict_mask(detect_model, sam, image[0], ['woman'], args.segment_type, confidence=0.3,
535
+ threshold=0.5)
536
+ mask1 = None
537
+ else:
538
+ mask1 = mask2 = None
539
+
540
+ if mask1 is not None or mask2 is not None:
541
+ face_info = face_app.get(cv2.cvtColor(np.array(image[0]), cv2.COLOR_RGB2BGR))
542
+ face_kps = draw_kps_multi(image[0], [face['kps'] for face in face_info])
543
+
544
+ image = sample_image(
545
+ pipe,
546
+ input_prompt=input_prompt,
547
+ concept_models=pipe_concepts,
548
+ input_neg_prompt=[negative_prompt] * len(input_prompt),
549
+ generator=torch.Generator(device).manual_seed(seed),
550
+ controller=controller,
551
+ face_app=face_app,
552
+ image=face_kps,
553
+ stage=2,
554
+ controlnet_conditioning_scale=identitynet_strength_ratio,
555
+ region_masks=[mask1, mask2],
556
+ **kwargs)
557
+
558
+ # return [image[1], spatial_condition]
559
+ return image
560
+
561
+ with gr.Blocks(css=css) as demo:
562
+ # description
563
+ gr.Markdown(title)
564
+ gr.Markdown(description)
565
+
566
+ with gr.Row():
567
+ gallery = gr.Image(label="Generated Images", height=512, width=512)
568
+ gallery1 = gr.Image(label="Generated Images", height=512, width=512)
569
+ usage_tips = gr.Markdown(label="Usage tips of OMG", value=tips, visible=False)
570
+
571
+
572
+ with gr.Row():
573
+ reference_1 = gr.Image(label="Input an RGB image for Character man", height=128, width=128)
574
+ reference_2 = gr.Image(label="Input an RGB image for Character woman", height=128, width=128)
575
+ condition_img1 = gr.Image(label="Input an RGB image for condition (Optional)", height=128, width=128)
576
+
577
+
578
+
579
+
580
+ with gr.Row():
581
+ local_prompt1 = gr.Textbox(label="Character1_prompt",
582
+ info="Describe the Character 1",
583
+ value="Close-up photo of the a man, 35mm photograph, professional, 4k, highly detailed.")
584
+ local_prompt2 = gr.Textbox(label="Character2_prompt",
585
+ info="Describe the Character 2",
586
+ value="Close-up photo of the a woman, 35mm photograph, professional, 4k, highly detailed.")
587
+ with gr.Row():
588
+ identitynet_strength_ratio = gr.Slider(
589
+ label="IdentityNet strength (for fidelity)",
590
+ minimum=0,
591
+ maximum=1.5,
592
+ step=0.05,
593
+ value=0.80,
594
+ )
595
+ adapter_strength_ratio = gr.Slider(
596
+ label="Image adapter strength (for detail)",
597
+ minimum=0,
598
+ maximum=1.5,
599
+ step=0.05,
600
+ value=0.80,
601
+ )
602
+ controlnet_ratio = gr.Slider(
603
+ label="ControlNet strength",
604
+ minimum=0,
605
+ maximum=1.5,
606
+ step=0.05,
607
+ value=1,
608
+ )
609
+ resolution = gr.Dropdown(label="Image Resolution (width*height)", choices=resolution_list,
610
+ value="1024*1024")
611
+ style = gr.Dropdown(label="style", choices=STYLE_NAMES, value="None")
612
+ condition = gr.Dropdown(label="Input condition type", choices=condition_list, value="None")
613
+
614
+
615
+ # prompt
616
+ with gr.Column():
617
+ prompt = gr.Textbox(label="Prompt 1",
618
+ info="Give a simple prompt to describe the first image content",
619
+ placeholder="Required",
620
+ value="close-up shot, photography, a man and a woman on the street, facing the camera smiling")
621
+
622
+
623
+ with gr.Accordion(open=False, label="Advanced Options"):
624
+ seed = gr.Slider(
625
+ label="Seed",
626
+ minimum=0,
627
+ maximum=MAX_SEED,
628
+ step=1,
629
+ value=42,
630
+ )
631
+ negative_prompt = gr.Textbox(label="Negative Prompt",
632
+ placeholder="noisy, blurry, soft, deformed, ugly",
633
+ value="noisy, blurry, soft, deformed, ugly")
634
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
635
+
636
+ submit = gr.Button("Submit", variant="primary")
637
+
638
+ submit.click(
639
+ fn=remove_tips,
640
+ outputs=usage_tips,
641
+ ).then(
642
+ fn=randomize_seed_fn,
643
+ inputs=[seed, randomize_seed],
644
+ outputs=seed,
645
+ queue=False,
646
+ api_name=False,
647
+ ).then(
648
+ fn=generate_image,
649
+ inputs=[prompt, negative_prompt, reference_1, reference_2, resolution, local_prompt1, local_prompt2, seed, style, identitynet_strength_ratio, adapter_strength_ratio, condition, condition_img1, controlnet_ratio],
650
+ outputs=[gallery, gallery1]
651
+ )
652
+ demo.launch(server_name='0.0.0.0',server_port=7861, debug=True)
653
+
654
+ def parse_args():
655
+ parser = argparse.ArgumentParser('', add_help=False)
656
+ parser.add_argument('--pretrained_model', default='/home/data1/kz_dir/checkpoint/YamerMIX_v8', type=str)
657
+ parser.add_argument('--controlnet_path', default='../checkpoint/InstantID/ControlNetModel', type=str)
658
+ parser.add_argument('--face_adapter_path', default='../checkpoint/InstantID/ip-adapter.bin', type=str)
659
+ parser.add_argument('--openpose_checkpoint', default='../checkpoint/controlnet-openpose-sdxl-1.0', type=str)
660
+ parser.add_argument('--canny_checkpoint', default='../checkpoint/controlnet-canny-sdxl-1.0', type=str)
661
+ parser.add_argument('--depth_checkpoint', default='../checkpoint/controlnet-depth-sdxl-1.0', type=str)
662
+ parser.add_argument('--dpt_checkpoint', default='../checkpoint/dpt-hybrid-midas', type=str)
663
+ parser.add_argument('--pose_detector_checkpoint',
664
+ default='../checkpoint/ControlNet/annotator/ckpts/body_pose_model.pth', type=str)
665
+ parser.add_argument('--efficientViT_checkpoint', default='../checkpoint/sam/xl1.pt', type=str)
666
+ parser.add_argument('--dino_checkpoint', default='../checkpoint/GroundingDINO', type=str)
667
+ parser.add_argument('--sam_checkpoint', default='../checkpoint/sam/sam_vit_h_4b8939.pth', type=str)
668
+ parser.add_argument('--antelopev2_path', default='../checkpoint/antelopev2', type=str)
669
+ parser.add_argument('--save_dir', default='results/instantID', type=str)
670
+ parser.add_argument('--prompt', default='Close-up photo of the cool man and beautiful woman as they accidentally discover a mysterious island while on vacation by the sea, facing the camera smiling, 35mm photograph, film, professional, 4k, highly detailed.', type=str)
671
+ parser.add_argument('--negative_prompt', default='noisy, blurry, soft, deformed, ugly', type=str)
672
+ parser.add_argument('--prompt_rewrite',
673
+ default='[Close-up photo of a man, 35mm photograph, professional, 4k, highly detailed.]-*'
674
+ '-[noisy, blurry, soft, deformed, ugly]-*-'
675
+ '../example/chris-evans.jpg|'
676
+ '[Close-up photo of a woman, 35mm photograph, professional, 4k, highly detailed.]-'
677
+ '*-[noisy, blurry, soft, deformed, ugly]-*-'
678
+ '../example/TaylorSwift.png',
679
+ type=str)
680
+ parser.add_argument('--seed', default=0, type=int)
681
+ parser.add_argument('--suffix', default='', type=str)
682
+ parser.add_argument('--segment_type', default='yoloworld', help='GroundingDINO or yoloworld', type=str)
683
+ parser.add_argument('--style_lora', default='', type=str)
684
+ return parser.parse_args()
685
+
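+ # Illustrative launch command (the script filename is an assumption; the relative
+ # checkpoint defaults above must point at your local layout):
+ #   python app_instantid.py --segment_type yoloworld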
686
+ if __name__ == '__main__':
687
+ args = parse_args()
688
+
689
+ prompts = [args.prompt] * 2
690
+
691
+ prompts_tmp = copy.deepcopy(prompts)
692
+
693
+ width, height = 1024, 1024
694
+ kwargs = {
695
+ 'height': height,
696
+ 'width': width,
697
+ }
698
+
699
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
700
+ main(device, args.segment_type)
701
+
gradio_demo/character_template.py ADDED
@@ -0,0 +1,84 @@
1
+ character_list_man = [
2
+ {
3
+ "name": "Chris Evans (identifier: Chris Evans)",
4
+ "prompt": "Close-up photo of the Chris Evans, 35mm photograph, film, professional, 4k, highly detailed.",
5
+ "negative_prompt": "noisy, blurry, soft, deformed, ugly",
6
+ "path": "./checkpoint/lora/chris-evans.safetensors",
7
+ },
8
+ {
9
+ "name": "Harry Potter (identifier: Harry Potter)",
10
+ "prompt": "Close-up photo of the Harry Potter, 35mm photograph, film, professional, 4k, highly detailed.",
11
+ "negative_prompt": "noisy, blurry, soft, deformed, ugly",
12
+ "path": "./checkpoint/lora/Harry_Potter.safetensors",
13
+ },
14
+ {
15
+ "name": "Jordan Torres (identifier: jordan_torres)",
16
+ "prompt": "Close-up photo of the jordan_torres man, 35mm photograph, film, professional, 4k, highly detailed.",
17
+ "negative_prompt": "noisy, blurry, soft, deformed, ugly",
18
+ "path": "./checkpoint/lora/jordan_torres_v2_xl.safetensors",
19
+ },
20
+ {
21
+ "name": "Gleb Savchenko (identifier: Liam Hemsworth a man)",
22
+ "prompt": "Close-up photo of Liam Hemsworth a man, 35mm photograph, film, professional, 4k, highly detailed.",
23
+ "negative_prompt": "noisy, blurry, soft, deformed, ugly",
24
+ "path": "./checkpoint/lora/gleb_savchenko_sdxl.safetensors",
25
+ },
26
+ ]
27
+
28
+ character_list_woman = [
29
+ {
30
+ "name": "Taylor Swift (identifier: TaylorSwift)",
31
+ "prompt": "Close-up photo of the TaylorSwift, 35mm photograph, film, professional, 4k, highly detailed.",
32
+ "negative_prompt": "noisy, blurry, soft, deformed, ugly",
33
+ "path": "./checkpoint/lora/TaylorSwiftSDXL.safetensors",
34
+ },
35
+ {
36
+ "name": "Hermione Granger (identifier: Hermione Granger)",
37
+ "prompt": "Close-up photo of the Hermione Granger, 35mm photograph, film, professional, 4k, highly detailed.",
38
+ "negative_prompt": "noisy, blurry, soft, deformed, ugly",
39
+ "path": "./checkpoint/lora/Hermione_Granger.safetensors",
40
+ },
41
+ {
42
+ "name": "Keira Knightley (identifier: ohwx woman)",
43
+ "prompt": "Close-up photo of the ohwx woman, 35mm photograph, film, professional, 4k, highly detailed.",
44
+ "negative_prompt": "noisy, blurry, soft, deformed, ugly",
45
+ "path": "./checkpoint/lora/keira_lora_sdxl_v1-000008.safetensors",
46
+ },
47
+ {
48
+ "name": "Jennifer Lawrence (identifier: Jennifer Lawrence WOMAN)",
49
+ "prompt": "Close-up photo of the Jennifer Lawrence WOMAN, 35mm photograph, film, professional, 4k, highly detailed.",
50
+ "negative_prompt": "noisy, blurry, soft, deformed, ugly",
51
+ "path": "./checkpoint/lora/lawrence_dh128_v1-step00012000.safetensors",
52
+ },
53
+ ]
54
+
55
+ style_list = [
56
+ {
57
+ "name": "None",
58
+ "prompt": "",
59
+ "path": "",
60
+ },
61
+ {
62
+ "name": "Cinematic Photography Style",
63
+ "prompt": "Cinematic Hollywood Film Style, ",
64
+ "path": "../checkpoint/style/Cinematic Hollywood Film.safetensors",
65
+ },
66
+ {
67
+ "name": "Anime sketch style",
68
+ "prompt": "Pencil_Sketch:1.2, messy lines, greyscale, traditional media, sketch, ",
69
+ "path": "../checkpoint/style/Anime_Sketch_SDXL.safetensors",
70
+ },
71
+ {
72
+ "name": "Oil Painting Style",
73
+ "prompt": "palette knife painting, ",
74
+ "path": "../checkpoint/style/EldritchPaletteKnife.safetensors",
75
+ }
76
+ ]
77
+
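+ # Lookup tables consumed by the demos: character name -> (prompt, negative_prompt),
+ # style name -> prompt prefix, and name -> LoRA / style checkpoint path.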
78
+ character_man = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in character_list_man}
79
+ character_woman = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in character_list_woman}
80
+ styles = {k["name"]: (k["prompt"]) for k in style_list}
81
+
82
+ lorapath_man = {k["name"]: (k["path"]) for k in character_list_man}
83
+ lorapath_woman = {k["name"]: (k["path"]) for k in character_list_woman}
84
+ lorapath_styles = {k["name"]: (k["path"]) for k in style_list}