Spaces:
Runtime error
Runtime error
File size: 4,308 Bytes
30ffa26 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
import gradio as gr
import os
from pathlib import Path
import sys
import torch
from PIL import Image, ImageOps
from utils_ootd import get_mask_location
PROJECT_ROOT = Path(__file__).absolute().parents[1].absolute()
sys.path.insert(0, str(PROJECT_ROOT))
from preprocess.openpose.run_openpose import OpenPose
from preprocess.humanparsing.run_parsing import Parsing
from ootd.inference_ootd_hd import OOTDiffusionHD
from ootd.inference_ootd_dc import OOTDiffusionDC
# Preprocessors and try-on model for the HD demo, all built at import time.
# The argument 0 is presumably a GPU index — TODO confirm against the
# OpenPose / Parsing / OOTDiffusionHD constructors.
openpose_model_hd = OpenPose(0)
parsing_model_hd = Parsing(0)
ootd_model_hd = OOTDiffusionHD(0)
# Same stack for the DC demo, built with argument 1 (presumably a second GPU
# index — TODO confirm; NOTE(review): a single-GPU host may not have device 1).
openpose_model_dc = OpenPose(1)
parsing_model_dc = Parsing(1)
ootd_model_dc = OOTDiffusionDC(1)
# Parallel lookup tables indexed by the integer category id
# (0: upper body, 1: lower body, 2: dress).
# category_dict holds the names passed to the OOTDiffusion models as `category=`;
# category_dict_utils holds the spellings passed to get_mask_location.
category_dict = ['upperbody', 'lowerbody', 'dress']
category_dict_utils = ['upper_body', 'lower_body', 'dresses']
# Resolve the examples directory from an absolute path (as PROJECT_ROOT does)
# so it does not silently depend on the current working directory.
example_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'examples')
model_hd = os.path.join(example_path, 'model/model_1.png')
garment_hd = os.path.join(example_path, 'garment/03244_00.jpg')
model_dc = os.path.join(example_path, 'model/model_8.png')
garment_dc = os.path.join(example_path, 'garment/048554_1.jpg')
import spaces
@spaces.GPU
def process_hd(vton_img, garm_img, n_samples, n_steps, image_scale, seed):
    """Run the HD try-on pipeline, dressing the person image in an upper-body garment.

    Args:
        vton_img: Person image. May be a file path, a file object, a PIL image,
            or an image array (what Gradio's default ``gr.Image`` delivers).
        garm_img: Garment image; same accepted types as ``vton_img``.
        n_samples: Number of samples requested; cast to ``int`` because
            ``gr.Number`` delivers floats.
        n_steps: Number of inference steps; cast to ``int``.
        image_scale: Image guidance scale forwarded to the model; cast to ``float``.
        seed: Seed forwarded to the model; cast to ``int``.

    Returns:
        Whatever ``ootd_model_hd`` returns (presumably a list of PIL images,
        one per sample — confirm against OOTDiffusionHD).

    Raises:
        gr.Error: If an image is missing or a numeric input is not a number,
            so the UI shows a readable message instead of a traceback.
    """
    model_type = 'hd'
    category = 0  # 0: upperbody; 1: lowerbody; 2: dress -- this demo is upper-body only

    if vton_img is None or garm_img is None:
        raise gr.Error('Please provide both a model image and a garment image.')
    try:
        n_samples, n_steps, seed = int(n_samples), int(n_steps), int(seed)
        image_scale = float(image_scale)
    except (TypeError, ValueError) as err:
        raise gr.Error('Samples, steps, guidance scale and seed must all be numbers.') from err

    def _as_pil(img):
        # Image.open() only understands paths and file objects; Gradio may hand
        # us an array or an already-decoded PIL image instead.
        if isinstance(img, Image.Image):
            return img
        if isinstance(img, (str, os.PathLike)) or hasattr(img, 'read'):
            return Image.open(img)
        return Image.fromarray(img)

    with torch.no_grad():
        # Move the CUDA-bound submodules onto the GPU inside the decorated call.
        openpose_model_hd.preprocessor.body_estimation.model.to('cuda')
        ootd_model_hd.pipe.to('cuda')
        ootd_model_hd.image_encoder.to('cuda')
        ootd_model_hd.text_encoder.to('cuda')

        garm_img = _as_pil(garm_img).resize((768, 1024))
        vton_img = _as_pil(vton_img).resize((768, 1024))

        # Pose keypoints and human parsing run at half resolution.
        small_vton = vton_img.resize((384, 512))
        keypoints = openpose_model_hd(small_vton)
        model_parse, _ = parsing_model_hd(small_vton)

        mask, mask_gray = get_mask_location(model_type, category_dict_utils[category], model_parse, keypoints)
        # NEAREST keeps the mask values hard-edged when upscaling back to full size.
        mask = mask.resize((768, 1024), Image.NEAREST)
        mask_gray = mask_gray.resize((768, 1024), Image.NEAREST)
        # Where `mask` is set, take mask_gray; elsewhere keep the original person pixels.
        masked_vton_img = Image.composite(mask_gray, vton_img, mask)

        images = ootd_model_hd(
            model_type=model_type,
            category=category_dict[category],
            image_garm=garm_img,
            image_vton=masked_vton_img,
            mask=mask,
            image_ori=vton_img,
            num_samples=n_samples,
            num_steps=n_steps,
            image_scale=image_scale,
            seed=seed,
        )
    return images
@spaces.GPU
def process_dc(vton_img, garm_img, category, n_samples, n_steps, image_scale, seed):
    """Run the DC try-on pipeline for the selected garment category.

    Args:
        vton_img: Person image. May be a file path, a file object, a PIL image,
            or an image array (what Gradio's default ``gr.Image`` delivers).
        garm_img: Garment image; same accepted types as ``vton_img``.
        category: ``'Upper-body'`` -> 0, ``'Lower-body'`` -> 1; any other value
            (including ``None`` from an unconfigured dropdown) -> 2 (dress).
        n_samples: Number of samples requested; cast to ``int`` because
            ``gr.Number`` delivers floats.
        n_steps: Number of inference steps; cast to ``int``.
        image_scale: Image guidance scale forwarded to the model; cast to ``float``.
        seed: Seed forwarded to the model; cast to ``int``.

    Returns:
        Whatever ``ootd_model_dc`` returns (presumably a list of PIL images,
        one per sample — confirm against OOTDiffusionDC).

    Raises:
        gr.Error: If an image is missing or a numeric input is not a number.
    """
    model_type = 'dc'
    # Map the UI label to the integer index used by category_dict(_utils).
    category = {'Upper-body': 0, 'Lower-body': 1}.get(category, 2)

    if vton_img is None or garm_img is None:
        raise gr.Error('Please provide both a model image and a garment image.')
    try:
        n_samples, n_steps, seed = int(n_samples), int(n_steps), int(seed)
        image_scale = float(image_scale)
    except (TypeError, ValueError) as err:
        raise gr.Error('Samples, steps, guidance scale and seed must all be numbers.') from err

    def _as_pil(img):
        # Image.open() only understands paths and file objects; Gradio may hand
        # us an array or an already-decoded PIL image instead.
        if isinstance(img, Image.Image):
            return img
        if isinstance(img, (str, os.PathLike)) or hasattr(img, 'read'):
            return Image.open(img)
        return Image.fromarray(img)

    with torch.no_grad():
        # Move the CUDA-bound submodules onto the GPU inside the decorated call.
        openpose_model_dc.preprocessor.body_estimation.model.to('cuda')
        ootd_model_dc.pipe.to('cuda')
        ootd_model_dc.image_encoder.to('cuda')
        ootd_model_dc.text_encoder.to('cuda')

        garm_img = _as_pil(garm_img).resize((768, 1024))
        vton_img = _as_pil(vton_img).resize((768, 1024))

        # Pose keypoints and human parsing run at half resolution.
        small_vton = vton_img.resize((384, 512))
        keypoints = openpose_model_dc(small_vton)
        model_parse, _ = parsing_model_dc(small_vton)

        mask, mask_gray = get_mask_location(model_type, category_dict_utils[category], model_parse, keypoints)
        # NEAREST keeps the mask values hard-edged when upscaling back to full size.
        mask = mask.resize((768, 1024), Image.NEAREST)
        mask_gray = mask_gray.resize((768, 1024), Image.NEAREST)
        # Where `mask` is set, take mask_gray; elsewhere keep the original person pixels.
        masked_vton_img = Image.composite(mask_gray, vton_img, mask)

        images = ootd_model_dc(
            model_type=model_type,
            category=category_dict[category],
            image_garm=garm_img,
            image_vton=masked_vton_img,
            mask=mask,
            image_ori=vton_img,
            num_samples=n_samples,
            num_steps=n_steps,
            image_scale=image_scale,
            seed=seed,
        )
    return images
def _sampling_inputs():
    """Return fresh numeric input components (a component cannot be rendered in two interfaces)."""
    return [
        gr.Number(label='Samples', value=1, precision=0),
        gr.Number(label='Steps', value=20, precision=0),
        gr.Number(label='Guidance scale', value=2.0),
        gr.Number(label='Seed', value=42, precision=0),
    ]


block = gr.Interface(
    fn=process_hd,
    # type='filepath' matches the Image.open()-based loading in the handlers
    # (the default type delivers a numpy array).
    inputs=[
        gr.Image(type='filepath', label='Model'),
        gr.Image(type='filepath', label='Garment'),
        *_sampling_inputs(),
    ],
    # The handler generates `n_samples` images, so show them as a gallery.
    outputs=gr.Gallery(label='Output'),
    title="OOTDiffusion Demo HD",
)
block_dc = gr.Interface(
    fn=process_dc,
    inputs=[
        gr.Image(type='filepath', label='Model'),
        gr.Image(type='filepath', label='Garment'),
        # Labels must match the strings process_dc maps to category ids.
        gr.Dropdown(choices=['Upper-body', 'Lower-body', 'Dress'], value='Upper-body', label='Garment category'),
        *_sampling_inputs(),
    ],
    outputs=gr.Gallery(label='Output'),
    title="OOTDiffusion Demo DC",
    # api_name is an Interface argument; launch() does not accept it.
    api_name='generate',
)
# launch() blocks the calling thread, so two sequential launch() calls never
# start the second app. Serve both demos from one tabbed app instead.
demo = gr.TabbedInterface([block, block_dc], ['HD', 'DC'])
demo.launch()