Fucius committed on
Commit
f8624d0
1 Parent(s): ace3e37

Upload 2 files

Files changed (2)
  1. app.py +584 -0
  2. download.py +83 -0
app.py ADDED
@@ -0,0 +1,584 @@
+ import sys
+ sys.path.append('./')
+ import gradio as gr
+ import random
+ import numpy as np
+ from gradio_demo.character_template import character_man, lorapath_man
+ from gradio_demo.character_template import character_woman, lorapath_woman
+ from gradio_demo.character_template import styles, lorapath_styles
+ import torch
+ import os
+ from typing import Tuple, List
+ import copy
+ import argparse
+ from diffusers.utils import load_image
+ import cv2
+ from PIL import Image, ImageOps
+ from transformers import DPTFeatureExtractor, DPTForDepthEstimation
+ from controlnet_aux import OpenposeDetector
+ from controlnet_aux.open_pose.body import Body
+
+ try:
+     from inference.models import YOLOWorld
+     from src.efficientvit.models.efficientvit.sam import EfficientViTSamPredictor
+     from src.efficientvit.sam_model_zoo import create_sam_model
+     import supervision as sv
+ except Exception:
+     print("YOLOWorld could not be loaded")
+
+ try:
+     from groundingdino.models import build_model
+     from groundingdino.util import box_ops
+     from groundingdino.util.slconfig import SLConfig
+     from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
+     from groundingdino.util.inference import annotate, predict
+     from segment_anything import build_sam, SamPredictor
+     import groundingdino.datasets.transforms as T
+ except Exception:
+     print("GroundingDINO could not be loaded")
+
+ from src.pipelines.lora_pipeline import LoraMultiConceptPipeline
+ from src.prompt_attention.p2p_attention import AttentionReplace
+ from diffusers import ControlNetModel, StableDiffusionXLPipeline
+ from src.pipelines.lora_pipeline import revise_regionally_controlnet_forward
+
+ from download import OMG_download
+
+ CHARACTER_MAN_NAMES = list(character_man.keys())
+ CHARACTER_WOMAN_NAMES = list(character_woman.keys())
+ STYLE_NAMES = list(styles.keys())
+ MAX_SEED = np.iinfo(np.int32).max
+
+ ### Description
+ title = r"""
+ <h1 align="center">OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models</h1>
+ """
+
+ description = r"""
+ <b>Official 🤗 Gradio demo</b> for <a href='https://github.com/' target='_blank'><b>OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models</b></a>.<br>
+
+ How to use:<br>
+ 1. Select two characters.
+ 2. Enter a text prompt as you would for a normal text-to-image model.
+ 3. Click the <b>Submit</b> button to start customizing.
+ 4. Enjoy the generated image 😊!
+ """
+
+ article = r"""
+ ---
+ 📝 **Citation**
+ <br>
+ If our work is helpful for your research or applications, please cite us via:
+ ```bibtex
+ @article{,
+ title={OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models},
+ author={},
+ journal={},
+ year={}
+ }
+ ```
+ """
+
+ tips = r"""
+ ### Usage tips of OMG
+ 1. Input text prompts that describe a man and a woman.
+ """
+
+ css = '''
+ .gradio-container {width: 85% !important}
+ '''
+
+ # Run the multi-concept LoRA pipeline and return its generated images.
+ def sample_image(pipe,
+                  input_prompt,
+                  input_neg_prompt=None,
+                  generator=None,
+                  concept_models=None,
+                  num_inference_steps=50,
+                  guidance_scale=7.5,
+                  controller=None,
+                  stage=None,
+                  region_masks=None,
+                  lora_list=None,
+                  styleL=None,
+                  **extra_kargs
+                  ):
+
+     spatial_condition = extra_kargs.pop('spatial_condition')
+     if spatial_condition is not None:
+         spatial_condition_input = [spatial_condition] * len(input_prompt)
+     else:
+         spatial_condition_input = None
+
+     images = pipe(
+         prompt=input_prompt,
+         concept_models=concept_models,
+         negative_prompt=input_neg_prompt,
+         generator=generator,
+         guidance_scale=guidance_scale,
+         num_inference_steps=num_inference_steps,
+         cross_attention_kwargs={"scale": 0.8},
+         controller=controller,
+         stage=stage,
+         region_masks=region_masks,
+         lora_list=lora_list,
+         styleL=styleL,
+         image=spatial_condition_input,
+         **extra_kargs).images
+
+     return images
+
+ def load_image_yoloworld(image_source) -> np.ndarray:
+     image = np.asarray(image_source)
+     return image
+
+ def load_image_dino(image_source) -> Tuple[np.ndarray, torch.Tensor]:
+     transform = T.Compose(
+         [
+             T.RandomResize([800], max_size=1333),
+             T.ToTensor(),
+             T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+         ]
+     )
+     image = np.asarray(image_source)
+     image_transformed, _ = transform(image_source, None)
+     return image, image_transformed
+
+ # Predict a segmentation mask for TEXT_PROMPT with either GroundingDINO + SAM or YOLO-World + EfficientViT-SAM.
+ def predict_mask(segmentmodel, sam, image, TEXT_PROMPT, segmentType, confidence=0.2, threshold=0.5):
+     if segmentType == 'GroundingDINO':
+         image_source, image = load_image_dino(image)
+         boxes, logits, phrases = predict(
+             model=segmentmodel,
+             image=image,
+             caption=TEXT_PROMPT,
+             box_threshold=0.3,
+             text_threshold=0.25
+         )
+         sam.set_image(image_source)
+         H, W, _ = image_source.shape
+         boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])
+
+         transformed_boxes = sam.transform.apply_boxes_torch(boxes_xyxy, image_source.shape[:2]).cuda()
+         masks, _, _ = sam.predict_torch(
+             point_coords=None,
+             point_labels=None,
+             boxes=transformed_boxes,
+             multimask_output=False,
+         )
+         masks = masks[0].squeeze(0)
+     else:
+         image_source = load_image_yoloworld(image)
+         segmentmodel.set_classes([TEXT_PROMPT])
+         results = segmentmodel.infer(image_source, confidence=confidence)
+         detections = sv.Detections.from_inference(results).with_nms(
+             class_agnostic=True, threshold=threshold
+         )
+         masks = None
+         if len(detections) != 0:
+             print(TEXT_PROMPT + " detected!")
+             sam.set_image(image_source, image_format="RGB")
+             masks, _, _ = sam.predict(box=detections.xyxy[0], multimask_output=False)
+             masks = torch.from_numpy(masks.squeeze())
+
+     return masks
+
+ def prepare_text(prompt, region_prompts):
+     '''
+     Args:
+         prompt: the global text prompt
+         region_prompts: [region prompt 1]-*-[region negative prompt 1]|[region prompt 2]-*-[region negative prompt 2]|...
+     Returns:
+         (prompt, region_collection): the global prompt together with a list of
+         (region_prompt, region_negative_prompt) tuples, one per region
+     '''
+     region_collection = []
+
+     regions = region_prompts.split('|')
+
+     for region in regions:
+         if region == '':
+             break
+         prompt_region, neg_prompt_region = region.split('-*-')
+         prompt_region = prompt_region.replace('[', '').replace(']', '')
+         neg_prompt_region = neg_prompt_region.replace('[', '').replace(']', '')
+
+         region_collection.append((prompt_region, neg_prompt_region))
+     return (prompt, region_collection)
+
+
+ # Build the base SDXL + ControlNet pipeline, its attention controller, and the per-concept SDXL pipeline.
+ def build_model_sd(pretrained_model, controlnet_path, device, prompts):
+     controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16).to(device)
+     pipe = LoraMultiConceptPipeline.from_pretrained(
+         pretrained_model, controlnet=controlnet, torch_dtype=torch.float16, variant="fp16").to(device)
+     controller = AttentionReplace(prompts, 50, cross_replace_steps={"default_": 1.}, self_replace_steps=0.4, tokenizer=pipe.tokenizer, device=device, dtype=torch.float16, width=1024//32, height=1024//32)
+     revise_regionally_controlnet_forward(pipe.unet, controller)
+     pipe_concept = StableDiffusionXLPipeline.from_pretrained(pretrained_model, torch_dtype=torch.float16,
+                                                              variant="fp16").to(device)
+     return pipe, controller, pipe_concept
+
+ # Swap in the ControlNet matching the selected condition and load the character/style LoRAs.
+ # Note: relies on the module-level `device` set in the __main__ block.
+ def build_model_lora(pipe_concept, lora_paths, style_path, condition, args, pipe):
+     pipe_list = []
+     if condition == "Human pose":
+         controlnet = ControlNetModel.from_pretrained(args.openpose_checkpoint, torch_dtype=torch.float16).to(device)
+         pipe.controlnet = controlnet
+     elif condition == "Canny Edge":
+         controlnet = ControlNetModel.from_pretrained(args.canny_checkpoint, torch_dtype=torch.float16, variant="fp16").to(device)
+         pipe.controlnet = controlnet
+     elif condition == "Depth":
+         controlnet = ControlNetModel.from_pretrained(args.depth_checkpoint, torch_dtype=torch.float16).to(device)
+         pipe.controlnet = controlnet
+
+     if style_path is not None and os.path.exists(style_path):
+         pipe_concept.load_lora_weights(style_path, weight_name="pytorch_lora_weights.safetensors", adapter_name='style')
+         pipe.load_lora_weights(style_path, weight_name="pytorch_lora_weights.safetensors", adapter_name='style')
+
+     for lora_path in lora_paths.split('|'):
+         adapter_name = lora_path.split('/')[-1].split('.')[0]
+         pipe_concept.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name=adapter_name)
+         pipe_concept.enable_xformers_memory_efficient_attention()
+         pipe_list.append(adapter_name)
+     return pipe_list
+
+ def build_yolo_segment_model(sam_path, device):
+     yolo_world = YOLOWorld(model_id="yolo_world/l")
+     sam = EfficientViTSamPredictor(
+         create_sam_model(name="xl1", weight_url=sam_path).to(device).eval()
+     )
+     return yolo_world, sam
+
+ def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
+     args = SLConfig.fromfile(ckpt_config_filename)
+     model = build_model(args)
+     args.device = device
+
+     checkpoint = torch.load(os.path.join(repo_id, filename), map_location='cpu')
+     log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
+     print("Model loaded from {} \n => {}".format(filename, log))
+     _ = model.eval()
+     return model
+
+ def build_dino_segment_model(ckpt_repo_id, sam_checkpoint):
+     ckpt_filename = "groundingdino_swinb_cogcoor.pth"
+     ckpt_config_filename = os.path.join(ckpt_repo_id, "GroundingDINO_SwinB.cfg.py")
+     groundingdino_model = load_model_hf(ckpt_repo_id, ckpt_filename, ckpt_config_filename)
+     sam = build_sam(checkpoint=sam_checkpoint)
+     sam.cuda()
+     sam_predictor = SamPredictor(sam)
+     return groundingdino_model, sam_predictor
+
+ def resize_and_center_crop(image, output_size=(1024, 576)):
+     width, height = image.size
+     aspect_ratio = width / height
+     new_height = output_size[1]
+     new_width = int(aspect_ratio * new_height)
+
+     resized_image = image.resize((new_width, new_height), Image.LANCZOS)
+
+     if new_width < output_size[0] or new_height < output_size[1]:
+         padding_color = "gray"
+         resized_image = ImageOps.expand(resized_image,
+                                         ((output_size[0] - new_width) // 2,
+                                          (output_size[1] - new_height) // 2,
+                                          (output_size[0] - new_width + 1) // 2,
+                                          (output_size[1] - new_height + 1) // 2),
+                                         fill=padding_color)
+
+     left = (resized_image.width - output_size[0]) / 2
+     top = (resized_image.height - output_size[1]) / 2
+     right = (resized_image.width + output_size[0]) / 2
+     bottom = (resized_image.height + output_size[1]) / 2
+
+     cropped_image = resized_image.crop((left, top, right, bottom))
+
+     return cropped_image
+
+ def main(device, segment_type):
+     pipe, controller, pipe_concept = build_model_sd(args.pretrained_sdxl_model, args.openpose_checkpoint, device, prompts_tmp)
+
+     if segment_type == 'GroundingDINO':
+         detect_model, sam = build_dino_segment_model(args.dino_checkpoint, args.sam_checkpoint)
+     else:
+         detect_model, sam = build_yolo_segment_model(args.efficientViT_checkpoint, device)
+
+     resolution_list = ["1440*728",
+                        "1344*768",
+                        "1216*832",
+                        "1152*896",
+                        "1024*1024",
+                        "896*1152",
+                        "832*1216",
+                        "768*1344",
+                        "728*1440"]
+     ratio_list = [1440 / 728, 1344 / 768, 1216 / 832, 1152 / 896, 1024 / 1024, 896 / 1152, 832 / 1216, 768 / 1344,
+                   728 / 1440]
+     condition_list = ["None",
+                       "Human pose",
+                       "Canny Edge",
+                       "Depth"]
+
+     depth_estimator = DPTForDepthEstimation.from_pretrained(args.dpt_checkpoint).to("cuda")
+     feature_extractor = DPTFeatureExtractor.from_pretrained(args.dpt_checkpoint)
+     body_model = Body(args.pose_detector_checkpoint)
+     openpose = OpenposeDetector(body_model)
+
+     def remove_tips():
+         return gr.update(visible=False)
+
+     def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+         if randomize_seed:
+             seed = random.randint(0, MAX_SEED)
+         return seed
+
+     def get_humanpose(img):
+         openpose_image = openpose(img)
+         return openpose_image
+
+     def get_cannyedge(image):
+         image = np.array(image)
+         image = cv2.Canny(image, 100, 200)
+         image = image[:, :, None]
+         image = np.concatenate([image, image, image], axis=2)
+         canny_image = Image.fromarray(image)
+         return canny_image
+
+     def get_depth(image):
+         image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
+         with torch.no_grad(), torch.autocast("cuda"):
+             depth_map = depth_estimator(image).predicted_depth
+
+         depth_map = torch.nn.functional.interpolate(
+             depth_map.unsqueeze(1),
+             size=(1024, 1024),
+             mode="bicubic",
+             align_corners=False,
+         )
+         depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
+         depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
+         depth_map = (depth_map - depth_min) / (depth_max - depth_min)
+         image = torch.cat([depth_map] * 3, dim=1)
+         image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
+         image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
+         return image
+
+     def generate_image(prompt1, negative_prompt, man, woman, resolution, local_prompt1, local_prompt2, seed, condition, condition_img1, style):
+         try:
+             path1 = lorapath_man[man]
+             path2 = lorapath_woman[woman]
+             pipe_concept.unload_lora_weights()
+             pipe.unload_lora_weights()
+             pipe_list = build_model_lora(pipe_concept, path1 + "|" + path2, lorapath_styles[style], condition, args, pipe)
+
+             if lorapath_styles[style] is not None and os.path.exists(lorapath_styles[style]):
+                 styleL = True
+             else:
+                 styleL = False
+
+             input_list = [prompt1]
+             condition_list = [condition_img1]
+             output_list = []
+
+             width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
+
+             kwargs = {
+                 'height': height,
+                 'width': width,
+             }
+
+             for prompt, condition_img in zip(input_list, condition_list):
+                 if prompt != '':
+                     input_prompt = []
+                     p = '{prompt}, 35mm photograph, film, professional, 4k, highly detailed.'
+                     if styleL:
+                         p = styles[style] + p
+                     input_prompt.append([p.replace("{prompt}", prompt), p.replace("{prompt}", prompt)])
+                     if styleL:
+                         input_prompt.append([(styles[style] + local_prompt1, character_man.get(man)[1]),
+                                              (styles[style] + local_prompt2, character_woman.get(woman)[1])])
+                     else:
+                         input_prompt.append([(local_prompt1, character_man.get(man)[1]),
+                                              (local_prompt2, character_woman.get(woman)[1])])
+
+                     # Prepare the spatial condition (pose, canny, or depth map) at the aspect ratio closest to the input image.
+                     if condition == 'Human pose' and condition_img is not None:
+                         index = ratio_list.index(
+                             min(ratio_list, key=lambda x: abs(x - condition_img.shape[1] / condition_img.shape[0])))
+                         resolution = resolution_list[index]
+                         width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
+                         kwargs['height'] = height
+                         kwargs['width'] = width
+                         condition_img = resize_and_center_crop(Image.fromarray(condition_img), (width, height))
+                         spatial_condition = get_humanpose(condition_img)
+                     elif condition == 'Canny Edge' and condition_img is not None:
+                         index = ratio_list.index(
+                             min(ratio_list, key=lambda x: abs(x - condition_img.shape[1] / condition_img.shape[0])))
+                         resolution = resolution_list[index]
+                         width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
+                         kwargs['height'] = height
+                         kwargs['width'] = width
+                         condition_img = resize_and_center_crop(Image.fromarray(condition_img), (width, height))
+                         spatial_condition = get_cannyedge(condition_img)
+                     elif condition == 'Depth' and condition_img is not None:
+                         index = ratio_list.index(
+                             min(ratio_list, key=lambda x: abs(x - condition_img.shape[1] / condition_img.shape[0])))
+                         resolution = resolution_list[index]
+                         width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
+                         kwargs['height'] = height
+                         kwargs['width'] = width
+                         condition_img = resize_and_center_crop(Image.fromarray(condition_img), (width, height))
+                         spatial_condition = get_depth(condition_img)
+                     else:
+                         spatial_condition = None
+
+                     kwargs['spatial_condition'] = spatial_condition
+                     controller.reset()
+                     # Stage 1: generate the full image from the merged prompt.
+                     image = sample_image(
+                         pipe,
+                         input_prompt=input_prompt,
+                         concept_models=pipe_concept,
+                         input_neg_prompt=[negative_prompt] * len(input_prompt),
+                         generator=torch.Generator(device).manual_seed(seed),
+                         controller=controller,
+                         stage=1,
+                         lora_list=pipe_list,
+                         styleL=styleL,
+                         **kwargs)
+
+                     controller.reset()
+                     # Segment the man/woman regions in the stage-1 image to obtain region masks.
+                     if pipe.tokenizer("man")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]:
+                         mask1 = predict_mask(detect_model, sam, image[0], 'man', args.segment_type, confidence=0.15,
+                                              threshold=0.5)
+                     else:
+                         mask1 = None
+
+                     if pipe.tokenizer("woman")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]:
+                         mask2 = predict_mask(detect_model, sam, image[0], 'woman', args.segment_type, confidence=0.15,
+                                              threshold=0.5)
+                     else:
+                         mask2 = None
+
+                     if mask1 is None and mask2 is None:
+                         output_list.append(image[1])
+                     else:
+                         # Stage 2: regenerate with region masks so each character LoRA only affects its own region.
+                         image = sample_image(
+                             pipe,
+                             input_prompt=input_prompt,
+                             concept_models=pipe_concept,
+                             input_neg_prompt=[negative_prompt] * len(input_prompt),
+                             generator=torch.Generator(device).manual_seed(seed),
+                             controller=controller,
+                             stage=2,
+                             region_masks=[mask1, mask2],
+                             lora_list=pipe_list,
+                             styleL=styleL,
+                             **kwargs)
+                         output_list.append(image[1])
+                 else:
+                     output_list.append(None)
+             output_list.append(spatial_condition)
+             return output_list
+         except Exception as e:
+             print("error:", e)
+             return
+
+     def get_local_value_man(input):
+         return character_man[input][0]
+
+     def get_local_value_woman(input):
+         return character_woman[input][0]
+
+
+     with gr.Blocks(css=css) as demo:
+         # description
+         gr.Markdown(title)
+         gr.Markdown(description)
+
+         with gr.Row():
+             gallery = gr.Image(label="Generated Images", height=512, width=512)
+             gen_condition = gr.Image(label="Spatial Condition", height=512, width=512)
+             usage_tips = gr.Markdown(label="Usage tips of OMG", value=tips, visible=False)
+
+         with gr.Row():
+             condition_img1 = gr.Image(label="Input an RGB image for condition", height=128, width=128)
+
+         # character selection
+         with gr.Row():
+             man = gr.Dropdown(label="Character 1 selection", choices=CHARACTER_MAN_NAMES, value="Chris Evans (identifier: Chris Evans)")
+             woman = gr.Dropdown(label="Character 2 selection", choices=CHARACTER_WOMAN_NAMES, value="Taylor Swift (identifier: TaylorSwift)")
+             resolution = gr.Dropdown(label="Image Resolution (width*height)", choices=resolution_list, value="1024*1024")
+             condition = gr.Dropdown(label="Input condition type", choices=condition_list, value="None")
+             style = gr.Dropdown(label="style", choices=STYLE_NAMES, value="None")
+
+         with gr.Row():
+             local_prompt1 = gr.Textbox(label="Character1_prompt",
+                                        info="Describe Character 1; this prompt should include the identifier of Character 1",
+                                        value="Close-up photo of the Chris Evans, 35mm photograph, film, professional, 4k, highly detailed.")
+             local_prompt2 = gr.Textbox(label="Character2_prompt",
+                                        info="Describe Character 2; this prompt should include the identifier of Character 2",
+                                        value="Close-up photo of the TaylorSwift, 35mm photograph, film, professional, 4k, highly detailed.")
+
+         man.change(get_local_value_man, man, local_prompt1)
+         woman.change(get_local_value_woman, woman, local_prompt2)
+
+         # prompt
+         with gr.Column():
+             prompt = gr.Textbox(label="Prompt 1",
+                                 info="Give a simple prompt to describe the image content",
+                                 placeholder="Required",
+                                 value="close-up shot, photography, a man and a woman on the street, facing the camera smiling")
+
+
+         with gr.Accordion(open=False, label="Advanced Options"):
+             seed = gr.Slider(
+                 label="Seed",
+                 minimum=0,
+                 maximum=MAX_SEED,
+                 step=1,
+                 value=42,
+             )
+             negative_prompt = gr.Textbox(label="Negative Prompt",
+                                          placeholder="noisy, blurry, soft, deformed, ugly",
+                                          value="noisy, blurry, soft, deformed, ugly")
+             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+
+         submit = gr.Button("Submit", variant="primary")
+
+         submit.click(
+             fn=remove_tips,
+             outputs=usage_tips,
+         ).then(
+             fn=randomize_seed_fn,
+             inputs=[seed, randomize_seed],
+             outputs=seed,
+             queue=False,
+             api_name=False,
+         ).then(
+             fn=generate_image,
+             inputs=[prompt, negative_prompt, man, woman, resolution, local_prompt1, local_prompt2, seed, condition, condition_img1, style],
+             outputs=[gallery, gen_condition]
+         )
+     demo.launch(share=True)
+
+ def parse_args():
+     parser = argparse.ArgumentParser('', add_help=False)
+     parser.add_argument('--pretrained_sdxl_model', default='Fucius/stable-diffusion-xl-base-1.0', type=str)
+     parser.add_argument('--openpose_checkpoint', default='thibaud/controlnet-openpose-sdxl-1.0', type=str)
+     parser.add_argument('--canny_checkpoint', default='diffusers/controlnet-canny-sdxl-1.0', type=str)
+     parser.add_argument('--depth_checkpoint', default='diffusers/controlnet-depth-sdxl-1.0', type=str)
+     parser.add_argument('--efficientViT_checkpoint', default='./checkpoint/sam/xl1.pt', type=str)
+     parser.add_argument('--dino_checkpoint', default='./checkpoint/GroundingDINO', type=str)
+     parser.add_argument('--sam_checkpoint', default='./checkpoint/sam/sam_vit_h_4b8939.pth', type=str)
+     parser.add_argument('--dpt_checkpoint', default='Intel/dpt-hybrid-midas', type=str)
+     parser.add_argument('--pose_detector_checkpoint', default='./checkpoint/ControlNet/annotator/ckpts/body_pose_model.pth', type=str)
+     parser.add_argument('--prompt', default='Close-up photo of the cool man and beautiful woman in surprised expressions as they accidentally discover a mysterious island while on vacation by the sea, 35mm photograph, film, professional, 4k, highly detailed.', type=str)
+     parser.add_argument('--negative_prompt', default='noisy, blurry, soft, deformed, ugly', type=str)
+     parser.add_argument('--seed', default=22, type=int)
+     parser.add_argument('--suffix', default='', type=str)
+     parser.add_argument('--segment_type', default='yoloworld', help='GroundingDINO or yoloworld', type=str)
+     return parser.parse_args()
+
+ if __name__ == '__main__':
+     args = parse_args()
+
+     prompts = [args.prompt] * 2
+     prompts_tmp = copy.deepcopy(prompts)
+     device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+     download = OMG_download()
+     main(device, args.segment_type)
download.py ADDED
@@ -0,0 +1,83 @@
+ import os
+ import time
+ import torch
+ os.environ['HF_ENDPOINT'] = "https://hf-mirror.com"
+ from huggingface_hub import hf_hub_download
+
+ # Download the checkpoints required by the OMG demo from the Hugging Face Hub
+ # (HF_ENDPOINT above points at the hf-mirror.com mirror) if they are not already present locally.
+ class OMG_download():
+     def __init__(self) -> None:
+         self.download_model_sam()
+         print("downloaded EfficientViT-SAM")
+
+         self.download_model_yoloworld()
+         print("downloaded YOLO-World")
+
+         self.download_controlNet()
+         print("downloaded ControlNet annotator")
+
+         self.download_characters()
+         print("downloaded character LoRAs")
+
+         self.download_styles()
+         print("downloaded style LoRAs")
+
+
+     def download_model_sam(self):
+         REPO_ID = 'han-cai/efficientvit-sam'
+         filename_list = ['xl1.pt']
+         if not os.path.exists('./checkpoints/sam/'):
+             os.makedirs('./checkpoints/sam/')
+         for filename in filename_list:
+             local_file = os.path.join('./checkpoints/sam/', filename)
+
+             if not os.path.exists(local_file):
+                 hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./checkpoints/sam/', local_dir_use_symlinks=False)
+
+     def download_model_yoloworld(self):
+         REPO_ID = 'Fucius/OMG'
+         filename_list = ['yolo-world.pt']
+         if not os.path.exists('/tmp/cache/yolo_world/l/'):
+             os.makedirs('/tmp/cache/yolo_world/l/')
+         for filename in filename_list:
+             local_file = os.path.join('/tmp/cache/yolo_world/l/', filename)
+             if not os.path.exists(local_file):
+                 hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='/tmp/cache/yolo_world/l/', local_dir_use_symlinks=False)
+
+     def download_controlNet(self):
+         REPO_ID = 'lllyasviel/ControlNet'
+         filename_list = ['annotator/ckpts/body_pose_model.pth']
+         if not os.path.exists('./checkpoints/ControlNet/'):
+             os.makedirs('./checkpoints/ControlNet/')
+         for filename in filename_list:
+             local_file = os.path.join('./checkpoints/ControlNet/', filename)
+
+             if not os.path.exists(local_file):
+                 hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./checkpoints/ControlNet/',
+                                 local_dir_use_symlinks=False)
+
+     def download_characters(self):
+         REPO_ID = 'Fucius/OMG'
+         filename_list = ['lora/chris-evans.safetensors', 'lora/Harry_Potter.safetensors', 'lora/Hermione_Granger.safetensors', 'lora/jordan_torres_v2_xl.safetensors', 'lora/keira_lora_sdxl_v1-000008.safetensors', 'lora/lawrence_dh128_v1-step00012000.safetensors', 'lora/Gleb-Savchenko_Liam-Hemsworth.safetensors', 'lora/TaylorSwiftSDXL.safetensors']
+         if not os.path.exists('./checkpoints/lora/'):
+             os.makedirs('./checkpoints/lora/')
+         for filename in filename_list:
+             local_file = os.path.join('./checkpoints/lora/', filename)
+
+             if not os.path.exists(local_file):
+                 hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./checkpoints/lora/',
+                                 local_dir_use_symlinks=False)
+
+     def download_styles(self):
+         REPO_ID = 'Fucius/OMG'
+         filename_list = ['style/EldritchPaletteKnife.safetensors', 'style/Cinematic Hollywood Film.safetensors', 'style/Anime_Sketch_SDXL.safetensors']
+         if not os.path.exists('./checkpoints/style/'):
+             os.makedirs('./checkpoints/style/')
+         for filename in filename_list:
+             local_file = os.path.join('./checkpoints/style/', filename)
+
+             if not os.path.exists(local_file):
+                 hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./checkpoints/style/',
+                                 local_dir_use_symlinks=False)
+
+ if __name__ == '__main__':
+     down = OMG_download()
+     print("finished downloading")