import spaces
import torch
torch.jit.script = lambda f: f
from download import OMG_download
download = OMG_download()
import sys
sys.path.append('./')
import argparse
import hashlib
import json
import os.path
import numpy as np
from typing import Tuple, List
from diffusers import DPMSolverMultistepScheduler
from diffusers.models import T2IAdapter
from PIL import Image
import copy
from diffusers import ControlNetModel, StableDiffusionXLPipeline

try:
    from insightface.app import FaceAnalysis
except:
    print("insightface cannot be loaded")

import gradio as gr
import random
from PIL import Image, ImageOps
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
from controlnet_aux import OpenposeDetector
from controlnet_aux.open_pose.body import Body

try:
    from inference.models import YOLOWorld
    from src.efficientvit.models.efficientvit.sam import EfficientViTSamPredictor
    from src.efficientvit.sam_model_zoo import create_sam_model
    import supervision as sv
except:
    print("YOLO-World cannot be loaded")

try:
    from groundingdino.models import build_model
    from groundingdino.util import box_ops
    from groundingdino.util.slconfig import SLConfig
    from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
    from groundingdino.util.inference import annotate, predict
    from segment_anything import build_sam, SamPredictor
    import groundingdino.datasets.transforms as T
except:
    print("GroundingDINO cannot be loaded")

from src.pipelines.instantid_pipeline import InstantidMultiConceptPipeline
from src.pipelines.instantid_single_pieline import InstantidSingleConceptPipeline
from src.prompt_attention.p2p_attention import AttentionReplace
from src.pipelines.instantid_pipeline import revise_regionally_controlnet_forward
import cv2
import math
import PIL.Image

from gradio_demo.character_template import styles, lorapath_styles

STYLE_NAMES = list(styles.keys())
MAX_SEED = np.iinfo(np.int32).max

title = r"""

OMG + InstantID

""" description = r""" Official 🤗 Gradio demo for OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models.

[Project][Code][Arxiv]

❗️Related demos: OMG + LoRAs ❗️

How to use:
1. Upload two character images of a man and a woman.
2. Enter a text prompt as you would for a normal text-to-image model.
3. Click the Submit button to begin customization.
4. Enjoy the generated image 😊!
"""

article = r"""
---
📝 **Citation**
If our work is helpful for your research or applications, please cite us via:
```bibtex
@article{kong2024omg,
  title={OMG: Occlusion-friendly Personalized Multi-concept Generation in Diffusion Models},
  author={Kong, Zhe and Zhang, Yong and Yang, Tianyu and Wang, Tao and Zhang, Kaihao and Wu, Bizhu and Chen, Guanying and Liu, Wei and Luo, Wenhan},
  journal={arXiv preprint arXiv:2403.10983},
  year={2024}
}
```
"""

tips = r"""
### Usage tips of OMG
1. Input text prompts to describe a man and a woman
"""

css = '''
.gradio-container {width: 85% !important}
'''


def build_dino_segment_model(ckpt_repo_id, sam_checkpoint):
    ckpt_filename = "groundingdino_swinb_cogcoor.pth"
    ckpt_config_filename = os.path.join(ckpt_repo_id, "GroundingDINO_SwinB.cfg.py")
    groundingdino_model = load_model_hf(ckpt_repo_id, ckpt_filename, ckpt_config_filename)
    sam = build_sam(checkpoint=sam_checkpoint)
    sam.cuda()
    sam_predictor = SamPredictor(sam)
    return groundingdino_model, sam_predictor


def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
    args = SLConfig.fromfile(ckpt_config_filename)
    model = build_model(args)
    args.device = device

    checkpoint = torch.load(os.path.join(repo_id, filename), map_location='cpu')
    log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
    print("Model loaded from {} \n => {}".format(filename, log))
    _ = model.eval()
    return model


def build_yolo_segment_model(sam_path, device):
    yolo_world = YOLOWorld(model_id="yolo_world/l")
    sam = EfficientViTSamPredictor(
        create_sam_model(name="xl1", weight_url=sam_path).to(device).eval()
    )
    return yolo_world, sam


def sample_image(pipe,
                 input_prompt,
                 input_neg_prompt=None,
                 generator=None,
                 concept_models=None,
                 num_inference_steps=50,
                 guidance_scale=3.0,
                 controller=None,
                 face_app=None,
                 image=None,
                 stage=None,
                 region_masks=None,
                 controlnet_conditioning_scale=None,
                 **extra_kargs):

    if image is not None:
        image_condition = [image]
    else:
        image_condition = None

    images = pipe(
        prompt=input_prompt,
        concept_models=concept_models,
        negative_prompt=input_neg_prompt,
        generator=generator,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        cross_attention_kwargs={"scale": 0.8},
        controller=controller,
        image=image_condition,
        face_app=face_app,
        stage=stage,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        region_masks=region_masks,
        **extra_kargs).images
    return images


def load_image_yoloworld(image_source) -> Tuple[np.array, torch.Tensor]:
    image = np.asarray(image_source)
    return image


def load_image_dino(image_source) -> Tuple[np.array, torch.Tensor]:
    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image = np.asarray(image_source)
    image_transformed, _ = transform(image_source, None)
    return image, image_transformed


def draw_kps_multi(image_pil, kps_list,
                   color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]):
    stickwidth = 4
    limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])

    w, h = image_pil.size
    out_img = np.zeros([h, w, 3])

    for kps in kps_list:
        kps = np.array(kps)
        for i in range(len(limbSeq)):
            index = limbSeq[i]
            color = color_list[index[0]]

            x = kps[index][:, 0]
            y = kps[index][:, 1]
            length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
            angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
            polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))),
                                       (int(length / 2), stickwidth), int(angle), 0, 360, 1)
            out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
        out_img = (out_img * 0.6).astype(np.uint8)

        for idx_kp, kp in enumerate(kps):
            color = color_list[idx_kp]
            x, y = kp
            out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)

    out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8))
    return out_img_pil
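
# predict_mask: locate the region described by TEXT_PROMPT (e.g. 'man' or 'woman') in the
# generated image and return its segmentation mask. With segmentType='GroundingDINO' the
# boxes come from GroundingDINO and are refined by SAM; otherwise YOLO-World proposes a box
# and EfficientViT-SAM produces the mask, returning None when nothing is detected.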
def predict_mask(segmentmodel, sam, image, TEXT_PROMPT, segmentType, confidence=0.2, threshold=0.5):
    if segmentType == 'GroundingDINO':
        image_source, image = load_image_dino(image)
        boxes, logits, phrases = predict(
            model=segmentmodel,
            image=image,
            caption=TEXT_PROMPT,
            box_threshold=0.3,
            text_threshold=0.25
        )
        sam.set_image(image_source)
        H, W, _ = image_source.shape
        boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])

        transformed_boxes = sam.transform.apply_boxes_torch(boxes_xyxy, image_source.shape[:2]).cuda()
        masks, _, _ = sam.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
        )
        masks = masks[0].squeeze(0)
    else:
        image_source = load_image_yoloworld(image)
        segmentmodel.set_classes([TEXT_PROMPT])
        results = segmentmodel.infer(image_source, confidence=confidence)
        detections = sv.Detections.from_inference(results).with_nms(
            class_agnostic=True, threshold=threshold
        )
        masks = None
        if len(detections) != 0:
            print(TEXT_PROMPT + " detected!")
            sam.set_image(image_source, image_format="RGB")
            masks, _, _ = sam.predict(box=detections.xyxy[0], multimask_output=False)
            masks = torch.from_numpy(masks.squeeze())

    return masks


def build_model_sd(pretrained_model, controlnet_path, face_adapter, device, prompts, antelopev2_path, width, height,
                   style_lora):
    controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16).to(device)
    pipe = InstantidMultiConceptPipeline.from_pretrained(
        pretrained_model, controlnet=controlnet, torch_dtype=torch.float16).to(device)

    controller = AttentionReplace(prompts, 50, cross_replace_steps={"default_": 1.}, self_replace_steps=0.4,
                                  tokenizer=pipe.tokenizer, device=device, width=width, height=height,
                                  dtype=torch.float16)
    revise_regionally_controlnet_forward(pipe.unet, controller)

    controlnet_concept = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
    pipe_concept = InstantidSingleConceptPipeline.from_pretrained(
        pretrained_model,
        controlnet=controlnet_concept,
        torch_dtype=torch.float16
    )
    pipe_concept.load_ip_adapter_instantid(face_adapter)
    pipe_concept.set_ip_adapter_scale(0.8)
    pipe_concept.to(device)
    pipe_concept.image_proj_model.to(pipe_concept._execution_device)

    if style_lora is not None and os.path.exists(style_lora):
        pipe.load_lora_weights(style_lora, weight_name="pytorch_lora_weights.safetensors", adapter_name='style')
        pipe_concept.load_lora_weights(style_lora, weight_name="pytorch_lora_weights.safetensors", adapter_name='style')

    # modify
    app = FaceAnalysis(name='antelopev2', root=antelopev2_path, providers=['CPUExecutionProvider'])
    app.prepare(ctx_id=0, det_size=(640, 640))

    return pipe, controller, pipe_concept, app


def prepare_text(prompt, region_prompts):
    '''
    Args:
        prompt_entity: [subject1]-*-[attribute1]-*-[Location1]|[subject2]-*-[attribute2]-*-[Location2]|[global text]
    Returns:
        full_prompt: subject1, attribute1 and subject2, attribute2, global text
        context_prompt: subject1 and subject2, global text
        entity_collection: [(subject1, attribute1), Location1]
    '''
    region_collection = []

    regions = region_prompts.split('|')

    for region in regions:
        if region == '':
            break
        prompt_region, neg_prompt_region, ref_img = region.split('-*-')
        prompt_region = prompt_region.replace('[', '').replace(']', '')
        neg_prompt_region = neg_prompt_region.replace('[', '').replace(']', '')

        region_collection.append((prompt_region, neg_prompt_region, ref_img))
    return (prompt, region_collection)
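
# Example of the region-prompt string consumed by prepare_text (this mirrors the
# --prompt_rewrite default in parse_args; the reference-image paths are illustrative):
#   '[Close-up photo of a man, ...]-*-[noisy, blurry, ...]-*-../example/chris-evans.jpg|'
#   '[Close-up photo of a woman, ...]-*-[noisy, blurry, ...]-*-../example/TaylorSwift.png'
# It returns (prompt, [(region_prompt, region_negative_prompt, ref_img_path), ...]).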

def build_model_lora(pipe, pipe_concept, style_path, condition, condition_img):
    if condition == "Human pose" and condition_img is not None:
        controlnet = ControlNetModel.from_pretrained(args.openpose_checkpoint, torch_dtype=torch.float16).to(device)
        pipe.controlnet2 = controlnet
    elif condition == "Canny Edge" and condition_img is not None:
        controlnet = ControlNetModel.from_pretrained(args.canny_checkpoint, torch_dtype=torch.float16,
                                                     variant="fp16").to(device)
        pipe.controlnet2 = controlnet
    elif condition == "Depth" and condition_img is not None:
        controlnet = ControlNetModel.from_pretrained(args.depth_checkpoint, torch_dtype=torch.float16).to(device)
        pipe.controlnet2 = controlnet

    if style_path is not None and os.path.exists(style_path):
        pipe_concept.load_lora_weights(style_path, weight_name="pytorch_lora_weights.safetensors", adapter_name='style')
        pipe.load_lora_weights(style_path, weight_name="pytorch_lora_weights.safetensors", adapter_name='style')


def resize_and_center_crop(image, output_size=(1024, 576)):
    width, height = image.size
    aspect_ratio = width / height

    new_height = output_size[1]
    new_width = int(aspect_ratio * new_height)

    resized_image = image.resize((new_width, new_height), Image.LANCZOS)

    if new_width < output_size[0] or new_height < output_size[1]:
        padding_color = "gray"
        resized_image = ImageOps.expand(resized_image,
                                        ((output_size[0] - new_width) // 2,
                                         (output_size[1] - new_height) // 2,
                                         (output_size[0] - new_width + 1) // 2,
                                         (output_size[1] - new_height + 1) // 2),
                                        fill=padding_color)

    left = (resized_image.width - output_size[0]) / 2
    top = (resized_image.height - output_size[1]) / 2
    right = (resized_image.width + output_size[0]) / 2
    bottom = (resized_image.height + output_size[1]) / 2

    cropped_image = resized_image.crop((left, top, right, bottom))

    return cropped_image


def main(device, segment_type):
    pipe, controller, pipe_concepts, face_app = build_model_sd(args.pretrained_model, args.controlnet_path,
                                                               args.face_adapter_path, device, prompts_tmp,
                                                               args.antelopev2_path, width // 32, height // 32,
                                                               args.style_lora)

    if segment_type == 'GroundingDINO':
        detect_model, sam = build_dino_segment_model(args.dino_checkpoint, args.sam_checkpoint)
    else:
        detect_model, sam = build_yolo_segment_model(args.efficientViT_checkpoint, device)

    resolution_list = ["1440*728", "1344*768", "1216*832", "1152*896", "1024*1024",
                       "896*1152", "832*1216", "768*1344", "728*1440"]
    ratio_list = [1440 / 728, 1344 / 768, 1216 / 832, 1152 / 896, 1024 / 1024,
                  896 / 1152, 832 / 1216, 768 / 1344, 728 / 1440]
    condition_list = ["None", "Human pose", "Canny Edge", "Depth"]

    depth_estimator = DPTForDepthEstimation.from_pretrained(args.dpt_checkpoint).to("cuda")
    feature_extractor = DPTFeatureExtractor.from_pretrained(args.dpt_checkpoint)
    body_model = Body(args.pose_detector_checkpoint)
    openpose = OpenposeDetector(body_model)

    prompts_rewrite = [args.prompt_rewrite]
    input_prompt_test = [prepare_text(p, p_w) for p, p_w in zip(prompts, prompts_rewrite)]
    input_prompt_test = [prompts, input_prompt_test[0][1]]

    def remove_tips():
        return gr.update(visible=False)

    def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
        if randomize_seed:
            seed = random.randint(0, MAX_SEED)
        return seed

    def get_humanpose(img):
        openpose_image = openpose(img)
        return openpose_image

    def get_cannyedge(image):
        image = np.array(image)
        image = cv2.Canny(image, 100, 200)
        image = image[:, :, None]
        image = np.concatenate([image, image, image], axis=2)
        canny_image = Image.fromarray(image)
        return canny_image
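
    # get_depth: run the DPT (MiDaS) depth estimator, resize the depth map to the target
    # resolution, normalize it to [0, 1] and replicate it to 3 channels so it can be fed
    # to the depth ControlNet like a regular RGB image.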
    def get_depth(image, height, width):
        image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
        with torch.no_grad(), torch.autocast("cuda"):
            depth_map = depth_estimator(image).predicted_depth

        depth_map = torch.nn.functional.interpolate(
            depth_map.unsqueeze(1),
            size=(height, width),
            mode="bicubic",
            align_corners=False,
        )
        depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
        depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
        depth_map = (depth_map - depth_min) / (depth_max - depth_min)
        image = torch.cat([depth_map] * 3, dim=1)
        image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
        image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
        return image

    @spaces.GPU(duration=180)
    def generate_image(prompt1, negative_prompt, reference_1, reference_2, resolution, local_prompt1, local_prompt2,
                       seed, style, identitynet_strength_ratio, adapter_strength_ratio, condition, condition_img,
                       controlnet_ratio, cfg_scale):
        identitynet_strength_ratio = float(identitynet_strength_ratio)
        adapter_strength_ratio = float(adapter_strength_ratio)
        controlnet_ratio = float(controlnet_ratio)
        cfg_scale = float(cfg_scale)

        if lorapath_styles[style] is not None and os.path.exists(lorapath_styles[style]):
            styleL = True
        else:
            styleL = False

        width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
        kwargs = {
            'height': height,
            'width': width,
            't2i_controlnet_conditioning_scale': controlnet_ratio,
        }

        if condition == 'Human pose' and condition_img is not None:
            index = ratio_list.index(
                min(ratio_list, key=lambda x: abs(x - condition_img.shape[1] / condition_img.shape[0])))
            resolution = resolution_list[index]
            width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
            kwargs['height'] = height
            kwargs['width'] = width
            condition_img = resize_and_center_crop(Image.fromarray(condition_img), (width, height))
            spatial_condition = get_humanpose(condition_img)
        elif condition == 'Canny Edge' and condition_img is not None:
            index = ratio_list.index(
                min(ratio_list, key=lambda x: abs(x - condition_img.shape[1] / condition_img.shape[0])))
            resolution = resolution_list[index]
            width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
            kwargs['height'] = height
            kwargs['width'] = width
            condition_img = resize_and_center_crop(Image.fromarray(condition_img), (width, height))
            spatial_condition = get_cannyedge(condition_img)
        elif condition == 'Depth' and condition_img is not None:
            index = ratio_list.index(
                min(ratio_list, key=lambda x: abs(x - condition_img.shape[1] / condition_img.shape[0])))
            resolution = resolution_list[index]
            width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
            kwargs['height'] = height
            kwargs['width'] = width
            condition_img = resize_and_center_crop(Image.fromarray(condition_img), (width, height))
            spatial_condition = get_depth(condition_img, height, width)
        else:
            spatial_condition = None

        kwargs['t2i_image'] = spatial_condition
        pipe.unload_lora_weights()
        pipe_concepts.unload_lora_weights()
        build_model_lora(pipe, pipe_concepts, lorapath_styles[style], condition, condition_img)
        pipe_concepts.set_ip_adapter_scale(adapter_strength_ratio)
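
        # Two-stage OMG sampling: stage 1 generates the multi-concept layout image from the
        # global prompt; the man/woman regions are then segmented, their face keypoints are
        # extracted, and stage 2 regenerates the image while injecting each identity
        # (InstantID) only within its own region mask.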
        input_list = [prompt1]

        for prompt in input_list:
            if prompt != '':
                input_prompt = []
                p = '{prompt}, 35mm photograph, film, professional, 4k, highly detailed.'
                if styleL:
                    p = styles[style] + p
                input_prompt.append([p.replace('{prompt}', prompt), p.replace("{prompt}", prompt)])
                if styleL:
                    input_prompt.append([(styles[style] + local_prompt1, 'noisy, blurry, soft, deformed, ugly',
                                          PIL.Image.fromarray(reference_1)),
                                         (styles[style] + local_prompt2, 'noisy, blurry, soft, deformed, ugly',
                                          PIL.Image.fromarray(reference_2))])
                else:
                    input_prompt.append(
                        [(local_prompt1, 'noisy, blurry, soft, deformed, ugly', PIL.Image.fromarray(reference_1)),
                         (local_prompt2, 'noisy, blurry, soft, deformed, ugly', PIL.Image.fromarray(reference_2))])

                controller.reset()
                image = sample_image(
                    pipe,
                    input_prompt=input_prompt,
                    concept_models=pipe_concepts,
                    input_neg_prompt=[negative_prompt] * len(input_prompt),
                    generator=torch.Generator(device).manual_seed(seed),
                    controller=controller,
                    face_app=face_app,
                    controlnet_conditioning_scale=identitynet_strength_ratio,
                    stage=1,
                    guidance_scale=cfg_scale,
                    **kwargs)

                controller.reset()

                if pipe.tokenizer("man")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]:
                    mask1 = predict_mask(detect_model, sam, image[0], 'man', args.segment_type, confidence=0.05,
                                         threshold=0.5)
                else:
                    mask1 = None

                if pipe.tokenizer("woman")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]:
                    mask2 = predict_mask(detect_model, sam, image[0], 'woman', args.segment_type, confidence=0.05,
                                         threshold=0.5)
                else:
                    mask2 = None

                if mask1 is not None or mask2 is not None:
                    face_info = face_app.get(cv2.cvtColor(np.array(image[0]), cv2.COLOR_RGB2BGR))
                    face_kps = draw_kps_multi(image[0], [face['kps'] for face in face_info])

                    image = sample_image(
                        pipe,
                        input_prompt=input_prompt,
                        concept_models=pipe_concepts,
                        input_neg_prompt=[negative_prompt] * len(input_prompt),
                        generator=torch.Generator(device).manual_seed(seed),
                        controller=controller,
                        face_app=face_app,
                        image=face_kps,
                        stage=2,
                        controlnet_conditioning_scale=identitynet_strength_ratio,
                        region_masks=[mask1, mask2],
                        guidance_scale=cfg_scale,
                        **kwargs)

                return [image[1], spatial_condition]
                # return image

    with gr.Blocks(css=css) as demo:
        # description
        gr.Markdown(title)
        gr.Markdown(description)

        with gr.Row():
            gallery = gr.Image(label="Generated Images", height=512, width=512)
            gallery1 = gr.Image(label="Input Condition", height=512, width=512)

        usage_tips = gr.Markdown(label="Usage tips of OMG", value=tips, visible=False)

        with gr.Row():
            reference_1 = gr.Image(label="Input an RGB image for Character man", height=128, width=128)
            reference_2 = gr.Image(label="Input an RGB image for Character woman", height=128, width=128)
            condition_img1 = gr.Image(label="Input an RGB image for condition (Optional)", height=128, width=128)

        with gr.Row():
            local_prompt1 = gr.Textbox(label="Character1_prompt",
                                       info="Describe Character 1",
                                       value="Close-up photo of a man, 35mm photograph, professional, 4k, highly detailed.")
            local_prompt2 = gr.Textbox(label="Character2_prompt",
                                       info="Describe Character 2",
                                       value="Close-up photo of a woman, 35mm photograph, professional, 4k, highly detailed.")

        with gr.Row():
            identitynet_strength_ratio = gr.Slider(
                label="IdentityNet strength (for fidelity)",
                minimum=0,
                maximum=1.5,
                step=0.05,
                value=0.80,
            )
            adapter_strength_ratio = gr.Slider(
                label="Image adapter strength (for detail)",
                minimum=0,
                maximum=1.5,
                step=0.05,
                value=0.80,
            )
            controlnet_ratio = gr.Slider(
                label="ControlNet strength",
                minimum=0,
                maximum=1.5,
                step=0.05,
                value=1,
            )
            cfg_ratio = gr.Slider(
                label="CFG scale",
                minimum=0.5,
                maximum=10,
                step=0.5,
                value=3.0,
            )
            resolution = gr.Dropdown(label="Image Resolution (width*height)", choices=resolution_list,
                                     value="1024*1024")
(width*height)", choices=resolution_list, value="1024*1024") style = gr.Dropdown(label="style", choices=STYLE_NAMES, value="None") condition = gr.Dropdown(label="Input condition type", choices=condition_list, value="None") # prompt with gr.Column(): prompt = gr.Textbox(label="Prompt 1", info="Give a simple prompt to describe the first image content", placeholder="Required", value="close-up shot, photography, a man and a woman on the street, facing the camera smiling") with gr.Accordion(open=False, label="Advanced Options"): seed = gr.Slider( label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42, ) negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="noisy, blurry, soft, deformed, ugly", value="noisy, blurry, soft, deformed, ugly") randomize_seed = gr.Checkbox(label="Randomize seed", value=True) submit = gr.Button("Submit", variant="primary") submit.click( fn=remove_tips, outputs=usage_tips, ).then( fn=randomize_seed_fn, inputs=[seed, randomize_seed], outputs=seed, queue=False, api_name=False, ).then( fn=generate_image, inputs=[prompt, negative_prompt, reference_1, reference_2, resolution, local_prompt1, local_prompt2, seed, style, identitynet_strength_ratio, adapter_strength_ratio, condition, condition_img1, controlnet_ratio, cfg_ratio], outputs=[gallery, gallery1] ) gr.Markdown(article) demo.launch(share=True) def parse_args(): parser = argparse.ArgumentParser('', add_help=False) parser.add_argument('--pretrained_model', default='wangqixun/YamerMIX_v8', type=str) parser.add_argument('--controlnet_path', default='/home/user/app/checkpoint/InstantID/ControlNetModel', type=str) parser.add_argument('--face_adapter_path', default='/home/user/app/checkpoint/InstantID/ip-adapter.bin', type=str) parser.add_argument('--openpose_checkpoint', default='thibaud/controlnet-openpose-sdxl-1.0', type=str) parser.add_argument('--canny_checkpoint', default='diffusers/controlnet-canny-sdxl-1.0', type=str) parser.add_argument('--depth_checkpoint', default='diffusers/controlnet-depth-sdxl-1.0', type=str) parser.add_argument('--dpt_checkpoint', default='Intel/dpt-hybrid-midas', type=str) parser.add_argument('--pose_detector_checkpoint', default='/home/user/app/checkpoint/ControlNet/annotator/ckpts/body_pose_model.pth', type=str) parser.add_argument('--efficientViT_checkpoint', default='/home/user/app/checkpoint/sam/xl1.pt', type=str) parser.add_argument('--dino_checkpoint', default='/home/user/app/checkpoint/GroundingDINO', type=str) parser.add_argument('--sam_checkpoint', default='/home/user/app/checkpoint/sam/sam_vit_h_4b8939.pth', type=str) parser.add_argument('--antelopev2_path', default='/home/user/app/checkpoint/antelopev2', type=str) parser.add_argument('--save_dir', default='results/instantID', type=str) parser.add_argument('--prompt', default='Close-up photo of the cool man and beautiful woman as they accidentally discover a mysterious island while on vacation by the sea, facing the camera smiling, 35mm photograph, film, professional, 4k, highly detailed.', type=str) parser.add_argument('--negative_prompt', default='noisy, blurry, soft, deformed, ugly', type=str) parser.add_argument('--prompt_rewrite', default='[Close-up photo of a man, 35mm photograph, professional, 4k, highly detailed.]-*' '-[noisy, blurry, soft, deformed, ugly]-*-' '../example/chris-evans.jpg|' '[Close-up photo of a woman, 35mm photograph, professional, 4k, highly detailed.]-' '*-[noisy, blurry, soft, deformed, ugly]-*-' '../example/TaylorSwift.png', type=str) parser.add_argument('--seed', default=0, type=int) 

def parse_args():
    parser = argparse.ArgumentParser('', add_help=False)
    parser.add_argument('--pretrained_model', default='wangqixun/YamerMIX_v8', type=str)
    parser.add_argument('--controlnet_path', default='/home/user/app/checkpoint/InstantID/ControlNetModel', type=str)
    parser.add_argument('--face_adapter_path', default='/home/user/app/checkpoint/InstantID/ip-adapter.bin', type=str)
    parser.add_argument('--openpose_checkpoint', default='thibaud/controlnet-openpose-sdxl-1.0', type=str)
    parser.add_argument('--canny_checkpoint', default='diffusers/controlnet-canny-sdxl-1.0', type=str)
    parser.add_argument('--depth_checkpoint', default='diffusers/controlnet-depth-sdxl-1.0', type=str)
    parser.add_argument('--dpt_checkpoint', default='Intel/dpt-hybrid-midas', type=str)
    parser.add_argument('--pose_detector_checkpoint',
                        default='/home/user/app/checkpoint/ControlNet/annotator/ckpts/body_pose_model.pth', type=str)
    parser.add_argument('--efficientViT_checkpoint', default='/home/user/app/checkpoint/sam/xl1.pt', type=str)
    parser.add_argument('--dino_checkpoint', default='/home/user/app/checkpoint/GroundingDINO', type=str)
    parser.add_argument('--sam_checkpoint', default='/home/user/app/checkpoint/sam/sam_vit_h_4b8939.pth', type=str)
    parser.add_argument('--antelopev2_path', default='/home/user/app/checkpoint/antelopev2', type=str)
    parser.add_argument('--save_dir', default='results/instantID', type=str)
    parser.add_argument('--prompt', default='Close-up photo of the cool man and beautiful woman as they accidentally discover a mysterious island while on vacation by the sea, facing the camera smiling, 35mm photograph, film, professional, 4k, highly detailed.', type=str)
    parser.add_argument('--negative_prompt', default='noisy, blurry, soft, deformed, ugly', type=str)
    parser.add_argument('--prompt_rewrite',
                        default='[Close-up photo of a man, 35mm photograph, professional, 4k, highly detailed.]-*'
                                '-[noisy, blurry, soft, deformed, ugly]-*-'
                                '../example/chris-evans.jpg|'
                                '[Close-up photo of a woman, 35mm photograph, professional, 4k, highly detailed.]-'
                                '*-[noisy, blurry, soft, deformed, ugly]-*-'
                                '../example/TaylorSwift.png',
                        type=str)
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--suffix', default='', type=str)
    parser.add_argument('--segment_type', default='yoloworld', help='GroundingDINO or yoloworld', type=str)
    parser.add_argument('--style_lora', default='', type=str)
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()

    prompts = [args.prompt] * 2
    prompts_tmp = copy.deepcopy(prompts)

    width, height = 1024, 1024
    kwargs = {
        'height': height,
        'width': width,
    }

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    main(device, args.segment_type)