Spaces:
Running
on
Zero
Running
on
Zero
import spaces | |
import gradio as gr | |
from PIL import Image | |
from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline | |
from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref | |
from src.unet_hacked_tryon import UNet2DConditionModel | |
from transformers import ( | |
CLIPImageProcessor, | |
CLIPVisionModelWithProjection, | |
CLIPTextModel, | |
CLIPTextModelWithProjection, | |
) | |
from diffusers import DDPMScheduler,AutoencoderKL | |
from typing import List | |
import torch | |
import os | |
from transformers import AutoTokenizer | |
import numpy as np | |
from utils_mask import get_mask_location | |
from torchvision import transforms | |
import apply_net | |
from preprocess.humanparsing.run_parsing import Parsing | |
from preprocess.openpose.run_openpose import OpenPose | |
from detectron2.data.detection_utils import convert_PIL_to_numpy,_apply_exif_orientation | |
from torchvision.transforms.functional import to_pil_image | |
def pil_to_binary_mask(pil_image, threshold=0): | |
np_image = np.array(pil_image) | |
grayscale_image = Image.fromarray(np_image).convert("L") | |
binary_mask = np.array(grayscale_image) > threshold | |
mask = np.zeros(binary_mask.shape, dtype=np.uint8) | |
for i in range(binary_mask.shape[0]): | |
for j in range(binary_mask.shape[1]): | |
if binary_mask[i,j] == True : | |
mask[i,j] = 1 | |
mask = (mask*255).astype(np.uint8) | |
output_mask = Image.fromarray(mask) | |
return output_mask | |
base_path = 'yisol/IDM-VTON' | |
example_path = os.path.join(os.path.dirname(__file__), 'example') | |
unet = UNet2DConditionModel.from_pretrained( | |
base_path, | |
subfolder="unet", | |
torch_dtype=torch.float16, | |
) | |
unet.requires_grad_(False) | |
tokenizer_one = AutoTokenizer.from_pretrained( | |
base_path, | |
subfolder="tokenizer", | |
revision=None, | |
use_fast=False, | |
) | |
tokenizer_two = AutoTokenizer.from_pretrained( | |
base_path, | |
subfolder="tokenizer_2", | |
revision=None, | |
use_fast=False, | |
) | |
noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler") | |
text_encoder_one = CLIPTextModel.from_pretrained( | |
base_path, | |
subfolder="text_encoder", | |
torch_dtype=torch.float16, | |
) | |
text_encoder_two = CLIPTextModelWithProjection.from_pretrained( | |
base_path, | |
subfolder="text_encoder_2", | |
torch_dtype=torch.float16, | |
) | |
image_encoder = CLIPVisionModelWithProjection.from_pretrained( | |
base_path, | |
subfolder="image_encoder", | |
torch_dtype=torch.float16, | |
) | |
vae = AutoencoderKL.from_pretrained(base_path, | |
subfolder="vae", | |
torch_dtype=torch.float16, | |
) | |
# "stabilityai/stable-diffusion-xl-base-1.0", | |
UNet_Encoder = UNet2DConditionModel_ref.from_pretrained( | |
base_path, | |
subfolder="unet_encoder", | |
torch_dtype=torch.float16, | |
) | |
parsing_model = Parsing(0) | |
openpose_model = OpenPose(0) | |
UNet_Encoder.requires_grad_(False) | |
image_encoder.requires_grad_(False) | |
vae.requires_grad_(False) | |
unet.requires_grad_(False) | |
text_encoder_one.requires_grad_(False) | |
text_encoder_two.requires_grad_(False) | |
tensor_transfrom = transforms.Compose( | |
[ | |
transforms.ToTensor(), | |
transforms.Normalize([0.5], [0.5]), | |
] | |
) | |
pipe = TryonPipeline.from_pretrained( | |
base_path, | |
unet=unet, | |
vae=vae, | |
feature_extractor= CLIPImageProcessor(), | |
text_encoder = text_encoder_one, | |
text_encoder_2 = text_encoder_two, | |
tokenizer = tokenizer_one, | |
tokenizer_2 = tokenizer_two, | |
scheduler = noise_scheduler, | |
image_encoder=image_encoder, | |
torch_dtype=torch.float16, | |
) | |
pipe.unet_encoder = UNet_Encoder | |
def start_tryon(dict,garm_img,garment_des,is_checked,is_checked_crop,denoise_steps,seed, category): | |
device = "cuda" | |
category = int(category) | |
if category==0: | |
category='upper_body' | |
elif category==1: | |
category='lower_body' | |
else: | |
category='dresses' | |
openpose_model.preprocessor.body_estimation.model.to(device) | |
pipe.to(device) | |
pipe.unet_encoder.to(device) | |
garm_img= garm_img.convert("RGB").resize((768,1024)) | |
human_img_orig = dict["background"].convert("RGB") | |
if is_checked_crop: | |
width, height = human_img_orig.size | |
aspect_ratio = width / height | |
if not (0.45 < aspect_ratio < 0.46): | |
target_width = int(min(width, height * (3 / 4))) | |
target_height = int(min(height, width * (4 / 3))) | |
left = (width - target_width) / 2 | |
top = (height - target_height) / 2 | |
right = (width + target_width) / 2 | |
bottom = (height + target_height) / 2 | |
cropped_img = human_img_orig.crop((left, top, right, bottom)) | |
crop_size = cropped_img.size | |
human_img = cropped_img.resize((768, 1024)) | |
else: | |
human_img = human_img_orig.resize((768,1024)) | |
else: | |
human_img = human_img_orig.resize((768,1024)) | |
if is_checked: | |
keypoints = openpose_model(human_img.resize((384,512))) | |
model_parse, _ = parsing_model(human_img.resize((384,512))) | |
mask, mask_gray = get_mask_location('hd', category, model_parse, keypoints) | |
mask = mask.resize((768,1024)) | |
else: | |
mask = pil_to_binary_mask(dict['layers'][0].convert("RGB").resize((768, 1024))) | |
# mask = transforms.ToTensor()(mask) | |
# mask = mask.unsqueeze(0) | |
mask_gray = (1-transforms.ToTensor()(mask)) * tensor_transfrom(human_img) | |
mask_gray = to_pil_image((mask_gray+1.0)/2.0) | |
human_img_arg = _apply_exif_orientation(human_img.resize((384,512))) | |
human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR") | |
args = apply_net.create_argument_parser().parse_args(('show', './configs/densepose_rcnn_R_50_FPN_s1x.yaml', './ckpt/densepose/model_final_162be9.pkl', 'dp_segm', '-v', '--opts', 'MODEL.DEVICE', 'cuda')) | |
# verbosity = getattr(args, "verbosity", None) | |
pose_img = args.func(args,human_img_arg) | |
pose_img = pose_img[:,:,::-1] | |
pose_img = Image.fromarray(pose_img).resize((768,1024)) | |
with torch.no_grad(): | |
# Extract the images | |
with torch.cuda.amp.autocast(): | |
with torch.no_grad(): | |
prompt = "model is wearing " + garment_des | |
negative_prompt = "monochrome, lowres, bad anatomy, bad fingers, worst quality, low quality, contortionist, amputee, polydactyly, deformed, distorted, misshapen, malformed, abnormal, mutant, defaced, shapeless, unreal, missing arms, three hands, bad face, extra fingers, cartoon, fused face, cg, ugly fingers, three legs, bad hands, fused feet, worst face, extra eyes, long fingers, three feet, missing legs, cloned face, worst feet, extra crus, huge eyes, fused crus, three thigh, bad anatomy, disconnected limbs, animate, 3d, worst thigh, extra thigh, fused thigh, missing fingers, amputation, poorly drawn face, three crus, horn, 2girl, bad arms" | |
with torch.inference_mode(): | |
( | |
prompt_embeds, | |
negative_prompt_embeds, | |
pooled_prompt_embeds, | |
negative_pooled_prompt_embeds, | |
) = pipe.encode_prompt( | |
prompt, | |
num_images_per_prompt=1, | |
do_classifier_free_guidance=True, | |
negative_prompt=negative_prompt, | |
) | |
prompt = "a photo of " + garment_des | |
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, bad fingers, low quality, contortionist, amputee, polydactyly, deformed, distorted, misshapen, malformed, abnormal, mutant, defaced, shapeless, unreal, missing arms, three hands, bad face, extra fingers, cartoon, fused face, cg, ugly fingers, three legs, bad hands, fused feet, worst face, extra eyes, long fingers, three feet, missing legs, cloned face, worst feet, extra crus, huge eyes, fused crus, three thigh, bad anatomy, disconnected limbs, animate, 3d, worst thigh, extra thigh, fused thigh, missing fingers, amputation, poorly drawn face, three crus, horn, 2girl, bad arms" | |
if not isinstance(prompt, List): | |
prompt = [prompt] * 1 | |
if not isinstance(negative_prompt, List): | |
negative_prompt = [negative_prompt] * 1 | |
with torch.inference_mode(): | |
( | |
prompt_embeds_c, | |
_, | |
_, | |
_, | |
) = pipe.encode_prompt( | |
prompt, | |
num_images_per_prompt=1, | |
do_classifier_free_guidance=False, | |
negative_prompt=negative_prompt, | |
) | |
pose_img = tensor_transfrom(pose_img).unsqueeze(0).to(device,torch.float16) | |
garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device,torch.float16) | |
generator = torch.Generator(device).manual_seed(seed) if seed is not None else None | |
images = pipe( | |
prompt_embeds=prompt_embeds.to(device,torch.float16), | |
negative_prompt_embeds=negative_prompt_embeds.to(device,torch.float16), | |
pooled_prompt_embeds=pooled_prompt_embeds.to(device,torch.float16), | |
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device,torch.float16), | |
num_inference_steps=denoise_steps, | |
generator=generator, | |
strength = 1.0, | |
pose_img = pose_img.to(device,torch.float16), | |
text_embeds_cloth=prompt_embeds_c.to(device,torch.float16), | |
cloth = garm_tensor.to(device,torch.float16), | |
mask_image=mask, | |
image=human_img, | |
height=1024, | |
width=768, | |
ip_adapter_image = garm_img.resize((768,1024)), | |
guidance_scale=2.0, | |
)[0] | |
if is_checked_crop: | |
if not (0.45 < aspect_ratio < 0.46): | |
out_img =images[0].resize(crop_size) | |
human_img_orig.paste(out_img, (int(left), int(top))) | |
else: | |
return images[0], mask_gray | |
return human_img_orig, mask_gray | |
else: | |
return images[0], mask_gray | |
# return images[0], mask_gray | |
garm_list = os.listdir(os.path.join(example_path,"cloth")) | |
garm_list_path = [os.path.join(example_path,"cloth",garm) for garm in garm_list] | |
human_list = os.listdir(os.path.join(example_path,"human")) | |
human_list_path = [os.path.join(example_path,"human",human) for human in human_list] | |
human_ex_list = [] | |
for ex_human in human_list_path: | |
ex_dict= {} | |
ex_dict['background'] = ex_human | |
ex_dict['layers'] = None | |
ex_dict['composite'] = None | |
human_ex_list.append(ex_dict) | |
##default human | |
image_blocks = gr.Blocks().queue() | |
with image_blocks as demo: | |
gr.Markdown("## DressFit") | |
gr.Markdown("DressFit Demo") | |
with gr.Row(): | |
with gr.Column(): | |
imgs = gr.ImageEditor(sources='upload', type="pil", label='Human. Mask with pen or use auto-masking', interactive=True) | |
with gr.Row(): | |
is_checked = gr.Checkbox(label="Yes", info="Use auto-generated mask (Takes 5 seconds)",value=True) | |
with gr.Row(): | |
is_checked_crop = gr.Checkbox(label="Yes", info="Use auto-crop & resizing",value=False) | |
with gr.Row(): | |
category = gr.Textbox(placeholder="Description of garment ex) Short Sleeve Round Neck T-shirts", show_label=False, elem_id="prompt") | |
example = gr.Examples( | |
inputs=imgs, | |
examples_per_page=10, | |
examples=human_ex_list | |
) | |
with gr.Column(): | |
garm_img = gr.Image(label="Garment", sources='upload', type="pil") | |
with gr.Row(elem_id="prompt-container"): | |
with gr.Row(): | |
prompt = gr.Textbox(placeholder="Description of garment ex) Short Sleeve Round Neck T-shirts", show_label=False, elem_id="prompt") | |
example = gr.Examples( | |
inputs=garm_img, | |
examples_per_page=8, | |
examples=garm_list_path) | |
with gr.Column(): | |
# image_out = gr.Image(label="Output", elem_id="output-img", height=400) | |
masked_img = gr.Image(label="Masked image output", elem_id="masked-img",show_share_button=False) | |
with gr.Column(): | |
# image_out = gr.Image(label="Output", elem_id="output-img", height=400) | |
image_out = gr.Image(label="Output", elem_id="output-img",show_share_button=False) | |
with gr.Column(): | |
try_button = gr.Button(value="Try-on") | |
with gr.Accordion(label="Advanced Settings", open=False): | |
with gr.Row(): | |
denoise_steps = gr.Number(label="Denoising Steps", minimum=20, maximum=40, value=30, step=1) | |
seed = gr.Number(label="Seed", minimum=-1, maximum=2147483647, step=1, value=42) | |
try_button.click(fn=start_tryon, inputs=[imgs, garm_img, prompt, is_checked,is_checked_crop, denoise_steps, seed, category], outputs=[image_out,masked_img], api_name='tryon') | |
image_blocks.launch() | |