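# HR-VITON virtual try-on demo, served as a Streamlit app.
# Hedged usage note: assuming this script is saved as app.py inside the Space
# (the original entry-point name is not shown here), it would be launched with
#   streamlit run app.py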
from multiprocessing import set_start_method
#set_start_method("fork")
import sys
#sys.path.insert(0, "../HR-VITON-main")
# star import supplies ConditionGenerator, SPADEGenerator, load_checkpoint,
# load_checkpoint_G, CPDatasetTest, CPDataLoader, make_grid, remove_overlap,
# and the nn / F / np / tgm / time / transforms names used below
from test_generator import *
import re
import inspect
from dataclasses import dataclass, field
from tqdm import tqdm
import pandas as pd
import os
import torch
import gradio as gr
import streamlit as st
from io import BytesIO
from PIL import Image
#### pip install streamlit-image-select
from streamlit_image_select import image_select
demo_image_dir = "demo_images_dir"
assert os.path.exists(demo_image_dir)
demo_images = [
    os.path.join(demo_image_dir, name)
    for name in os.listdir(demo_image_dir)
    if name.endswith((".png", ".jpeg", ".jpg"))
]
assert demo_images
#### styling workaround for streamlit-image-select, see:
#### https://github.com/jrieke/streamlit-image-select/issues/10
#.image-box {
#    border: 1px solid rgba(49, 51, 63, 0.2);
#    border-radius: 0.25rem;
#    padding: calc(0.25rem + 1px);
#    height: 10rem;
#    min-width: 10rem;
#}
demo_images = [Image.open(path).resize((256, 256)) for path in demo_images]
class OPT:
    #### ConditionGenerator
    out_layer = None
    warp_feature = None
    #### SPADEGenerator
    semantic_nc = None
    fine_height = None
    fine_width = None
    ngf = None
    num_upsampling_layers = None
    norm_G = None
    gen_semantic_nc = None
    #### weight load
    tocg_checkpoint = None
    gen_checkpoint = None
    cuda = False
    data_list = None
    datamode = None
    dataroot = None
    batch_size = None
    shuffle = False
    workers = None
    clothmask_composition = None
    occlusion = False
    datasetting = None

opt = OPT()
opt.out_layer = "relu"
opt.warp_feature = "T1"
input1_nc = 4  # cloth + cloth-mask
nc = 13
input2_nc = nc + 3  # parse_agnostic + densepose
output_nc = nc
tocg = ConditionGenerator(opt,
                          input1_nc=input1_nc,
                          input2_nc=input2_nc,
                          output_nc=output_nc,
                          ngf=96,
                          norm_layer=nn.BatchNorm2d)
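# A minimal shape sanity check (a sketch, not part of the original app): the
# condition generator is fed at the 256x192 working resolution used by the
# downsampling inside single_pred_slim_func below.
#   dummy_cloth = torch.zeros(1, input1_nc, 256, 192)   # cloth + cloth mask
#   dummy_pose  = torch.zeros(1, input2_nc, 256, 192)   # parse-agnostic + densepose
#   flows, segmap, warped, warped_mask = tocg(opt, dummy_cloth, dummy_pose)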
#### SPADEResBlock
from network_generator import SPADEResBlock
opt.semantic_nc = 7
opt.fine_height = 1024
opt.fine_width = 768
opt.ngf = 64
opt.num_upsampling_layers = "most"
opt.norm_G = "spectralaliasinstance"
opt.gen_semantic_nc = 7
generator = SPADEGenerator(opt, 3 + 3 + 3)
generator.print_network()
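# The 3 + 3 + 3 input channels match the generator call further down:
# torch.cat((agnostic, densepose, warped_cloth), dim=1), i.e. three RGB maps.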
#### https://drive.google.com/open?id=1XJTCdRBOPVgVTmqzhVGFAgMm2NLkw5uQ&authuser=0
opt.tocg_checkpoint = "mtviton.pth"
#### https://drive.google.com/open?id=1T5_YDUhYSSKPC_nZMk2NeC-XXUFoYeNy&authuser=0
opt.gen_checkpoint = "gen.pth"
opt.cuda = False
load_checkpoint(tocg, opt.tocg_checkpoint, opt)
load_checkpoint_G(generator, opt.gen_checkpoint, opt)
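# With opt.cuda = False both checkpoints are loaded for CPU inference; the
# .pth files referenced above are expected to have been downloaded from the
# linked Google Drive URLs into the working directory beforehand.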
#### def test scope
tocg.eval()
generator.eval()
opt.data_list = "test_pairs.txt"
opt.datamode = "test"
opt.dataroot = "zalando-hd-resized"
opt.batch_size = 1
opt.shuffle = False
opt.workers = 1
opt.semantic_nc = 13
test_dataset = CPDatasetTest(opt)
test_loader = CPDataLoader(opt, test_dataset)
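# Note: opt.semantic_nc switches from 7 (generator construction) to 13
# (dataset parsing labels) here; the 13 classes are collapsed back into
# 7 groups inside single_pred_slim_func below.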
def construct_images(img_tensors, img_names=(None,)):
    # Convert a batch of generator outputs from [-1, 1] back to PIL images;
    # only the first image of the batch is returned.
    for img_tensor, img_name in zip(img_tensors, img_names):
        tensor = (img_tensor.clone() + 1) * 0.5 * 255
        # detach() is safe whether or not the tensor requires grad
        tensor = tensor.detach().cpu().clamp(0, 255)
        array = tensor.numpy().astype('uint8')
        if array.shape[0] == 1:
            array = array.squeeze(0)
        elif array.shape[0] == 3:
            array = array.swapaxes(0, 1).swapaxes(1, 2)  # CHW -> HWC
        im = Image.fromarray(array)
        return im  # intentionally returns after the first image
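# Example usage (a sketch with dummy data):
#   fake = torch.rand(1, 3, 1024, 768) * 2 - 1   # values in [-1, 1]
#   construct_images(fake)                        # -> 768x1024 (w x h) RGB PIL.Image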
def single_pred_slim_func(opt, inputs, tocg=tocg, generator=generator):
    gauss = tgm.image.GaussianBlur((15, 15), (3, 3))
    if opt.cuda:
        gauss = gauss.cuda()

    # Model
    if opt.cuda:
        tocg.cuda()
    tocg.eval()
    generator.eval()
    num = 0
    iter_start_time = time.time()
    with torch.no_grad():
        for inputs in [inputs]:
            if opt.cuda:
                #pose_map = inputs['pose'].cuda()
                pre_clothes_mask = inputs['cloth_mask'][opt.datasetting].cuda()
                #label = inputs['parse']
                parse_agnostic = inputs['parse_agnostic']
                agnostic = inputs['agnostic'].cuda()
                clothes = inputs['cloth'][opt.datasetting].cuda()  # target cloth
                densepose = inputs['densepose'].cuda()
                #im = inputs['image']
                #input_label, input_parse_agnostic = label.cuda(), parse_agnostic.cuda()
                input_parse_agnostic = parse_agnostic.cuda()
                # np.float was removed from NumPy 1.24+; use np.float32 instead
                pre_clothes_mask = torch.FloatTensor((pre_clothes_mask.detach().cpu().numpy() > 0.5).astype(np.float32)).cuda()
            else:
                #pose_map = inputs['pose']
                pre_clothes_mask = inputs['cloth_mask'][opt.datasetting]
                #label = inputs['parse']
                parse_agnostic = inputs['parse_agnostic']
                agnostic = inputs['agnostic']
                clothes = inputs['cloth'][opt.datasetting]  # target cloth
                densepose = inputs['densepose']
                #im = inputs['image']
                #input_label, input_parse_agnostic = label, parse_agnostic
                input_parse_agnostic = parse_agnostic
                pre_clothes_mask = torch.FloatTensor((pre_clothes_mask.detach().cpu().numpy() > 0.5).astype(np.float32))
            # downsample to the 256x192 working resolution of the condition generator
            #pose_map_down = F.interpolate(pose_map, size=(256, 192), mode='bilinear')
            pre_clothes_mask_down = F.interpolate(pre_clothes_mask, size=(256, 192), mode='nearest')
            #input_label_down = F.interpolate(input_label, size=(256, 192), mode='bilinear')
            input_parse_agnostic_down = F.interpolate(input_parse_agnostic, size=(256, 192), mode='nearest')
            #agnostic_down = F.interpolate(agnostic, size=(256, 192), mode='nearest')
            clothes_down = F.interpolate(clothes, size=(256, 192), mode='bilinear')
            densepose_down = F.interpolate(densepose, size=(256, 192), mode='bilinear')
            shape = pre_clothes_mask.shape

            # multi-task inputs
            input1 = torch.cat([clothes_down, pre_clothes_mask_down], 1)
            input2 = torch.cat([input_parse_agnostic_down, densepose_down], 1)

            # forward
            flow_list, fake_segmap, warped_cloth_paired, warped_clothmask_paired = tocg(opt, input1, input2)
            # warped cloth mask one hot
            if opt.cuda:
                warped_cm_onehot = torch.FloatTensor((warped_clothmask_paired.detach().cpu().numpy() > 0.5).astype(np.float32)).cuda()
            else:
                warped_cm_onehot = torch.FloatTensor((warped_clothmask_paired.detach().cpu().numpy() > 0.5).astype(np.float32))
            if opt.clothmask_composition != 'no_composition':
                if opt.clothmask_composition == 'detach':
                    cloth_mask = torch.ones_like(fake_segmap)
                    cloth_mask[:, 3:4, :, :] = warped_cm_onehot
                    fake_segmap = fake_segmap * cloth_mask
                if opt.clothmask_composition == 'warp_grad':
                    cloth_mask = torch.ones_like(fake_segmap)
                    cloth_mask[:, 3:4, :, :] = warped_clothmask_paired
                    fake_segmap = fake_segmap * cloth_mask
            # make generator input parse map
            fake_parse_gauss = gauss(F.interpolate(fake_segmap, size=(opt.fine_height, opt.fine_width), mode='bilinear'))
            fake_parse = fake_parse_gauss.argmax(dim=1)[:, None]
            if opt.cuda:
                old_parse = torch.FloatTensor(fake_parse.size(0), 13, opt.fine_height, opt.fine_width).zero_().cuda()
            else:
                old_parse = torch.FloatTensor(fake_parse.size(0), 13, opt.fine_height, opt.fine_width).zero_()
            old_parse.scatter_(1, fake_parse, 1.0)

            # collapse the 13 parsing labels into the 7 groups the generator expects
            labels = {
                0: ['background', [0]],
                1: ['paste', [2, 4, 7, 8, 9, 10, 11]],
                2: ['upper', [3]],
                3: ['hair', [1]],
                4: ['left_arm', [5]],
                5: ['right_arm', [6]],
                6: ['noise', [12]],
            }
            if opt.cuda:
                parse = torch.FloatTensor(fake_parse.size(0), 7, opt.fine_height, opt.fine_width).zero_().cuda()
            else:
                parse = torch.FloatTensor(fake_parse.size(0), 7, opt.fine_height, opt.fine_width).zero_()
            for i in range(len(labels)):
                for label in labels[i][1]:
                    parse[:, i] += old_parse[:, label]
            # warped cloth: upsample the last flow field and warp at full resolution
            N, _, iH, iW = clothes.shape
            flow = F.interpolate(flow_list[-1].permute(0, 3, 1, 2), size=(iH, iW), mode='bilinear').permute(0, 2, 3, 1)
            flow_norm = torch.cat([flow[:, :, :, 0:1] / ((96 - 1.0) / 2.0), flow[:, :, :, 1:2] / ((128 - 1.0) / 2.0)], 3)
            grid = make_grid(N, iH, iW, opt)
            warped_grid = grid + flow_norm
            warped_cloth = F.grid_sample(clothes, warped_grid, padding_mode='border')
            warped_clothmask = F.grid_sample(pre_clothes_mask, warped_grid, padding_mode='border')
            if opt.occlusion:
                warped_clothmask = remove_overlap(F.softmax(fake_parse_gauss, dim=1), warped_clothmask)
                warped_cloth = warped_cloth * warped_clothmask + torch.ones_like(warped_cloth) * (1 - warped_clothmask)
            output = generator(torch.cat((agnostic, densepose, warped_cloth), dim=1), parse)
            return output
            #save_images(output, unpaired_names, output_dir)
            #num += shape[0]
            #print(num)
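# End-to-end flow of single_pred_slim_func: tocg predicts a flow field and a
# segmentation map at 256x192, the flow warps the cloth at full resolution,
# and the SPADE generator composites agnostic person + densepose + warped
# cloth under the predicted 7-channel parse map.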
opt.clothmask_composition = "warp_grad"
opt.occlusion = False
opt.datasetting = "unpaired"
def read_img_and_trans(dataset, opt, img_path):
    if isinstance(img_path, str):
        im = Image.open(img_path)
    else:
        im = img_path
    # interpolation=2 is PIL bilinear; newer torchvision prefers
    # transforms.InterpolationMode.BILINEAR
    im = transforms.Resize(opt.fine_width, interpolation=2)(im)
    im = dataset.transform(im)
    return im
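# Example (a sketch): normalize an arbitrary cloth photo into the tensor
# format the dataset pipeline produces.
#   cloth_tensor = read_img_and_trans(test_dataset, opt, "tmp_cloth.jpg")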
# sys, os, PIL.Image and gradio are already imported above
sys.path.insert(0, "fashion-eye-try-on")
from cloth_segmentation import generate_cloth_mask
def generate_cloth_mask_and_display(cloth_img):
    path = 'fashion-eye-try-on/cloth/cloth.jpg'
    if os.path.exists(path):
        os.remove(path)
    cloth_img.save(path)
    try:
        # os.system('./cloth_segmentation/generate_cloth_mask.py')
        generate_cloth_mask()
    except Exception as e:
        print(e)
        return None
    cloth_mask_img = Image.open("fashion-eye-try-on/cloth_mask/cloth.jpg")
    return cloth_mask_img
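# generate_cloth_mask() appears to communicate via fixed file paths: it reads
# the image written to fashion-eye-try-on/cloth/ and leaves the binary mask in
# fashion-eye-try-on/cloth_mask/. The fixed names make this non-reentrant, so
# only one request should run at a time.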
def take_human_feature_from_dataset(dataset, idx):
    # wrap a single sample in a DataLoader to get a batch dimension of 1
    inputs_upper = list(torch.utils.data.DataLoader([dataset[idx]], batch_size=1))[0]
    return {
        "parse_agnostic": inputs_upper["parse_agnostic"],
        "agnostic": inputs_upper["agnostic"],
        "densepose": inputs_upper["densepose"],
    }
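# A simpler equivalent sketch, assuming the three selected values are tensors:
#   sample = dataset[idx]
#   return {k: sample[k].unsqueeze(0)
#           for k in ("parse_agnostic", "agnostic", "densepose")}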
def take_all_feature_with_dataset(cloth_img_path, idx, opt=opt, dataset=test_dataset, only_show_human=False):
    if not isinstance(cloth_img_path, str):
        assert hasattr(cloth_img_path, "save")
        cloth_img_path.save("tmp_cloth.jpg")
        cloth_img_path = "tmp_cloth.jpg"
    assert isinstance(cloth_img_path, str)
    inputs_upper_dict = take_human_feature_from_dataset(dataset, idx)
    if only_show_human:
        return Image.fromarray((inputs_upper_dict["densepose"][0].numpy().transpose((1, 2, 0)) * 255).astype(np.uint8))
    cloth_readed = read_img_and_trans(dataset, opt, cloth_img_path)
    #assert ((cloth_readed - inputs_upper["cloth"][opt.datasetting][0]) ** 2).sum().numpy() < 1e-15
    cloth_input = {
        opt.datasetting: cloth_readed[None, :]
    }
    mask_img = generate_cloth_mask_and_display(Image.open(cloth_img_path))
    cloth_mask_input = {
        opt.datasetting:
            torch.Tensor(np.asarray(mask_img) / 255)[None, None, :]
    }
    inputs_upper_dict["cloth"] = cloth_input
    inputs_upper_dict["cloth_mask"] = cloth_mask_input
    return inputs_upper_dict
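# This function assembles the same input dict CPDatasetTest would yield: the
# human features (parse_agnostic / agnostic / densepose) come from the dataset
# sample at index idx, while cloth and cloth_mask come from the user's image.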
def pred_func(cloth_img, pidx):
    idx = int(pidx)
    im = cloth_img
    #### actual model input
    inputs_upper_dict = take_all_feature_with_dataset(im, idx, only_show_human=False)
    output_slim = single_pred_slim_func(opt, inputs_upper_dict)
    output_img = construct_images(output_slim)
    return output_img
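# Example (a sketch; the file name is hypothetical, 44 is a dataset pose index):
#   result = pred_func(Image.open("some_cloth.jpg"), 44)
#   result.save("tryon_result.png")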
option = st.selectbox(
    "Choose a cloth image or upload one",
    ("Choose", "Upload")
)
if not isinstance(option, str):
    option = "Choose"
img = None
uploaded_file = None
if option == "Upload":
    # read the uploaded file as bytes
    uploaded_file = st.file_uploader("Upload img")
    if uploaded_file is not None:
        bytes_data = uploaded_file.getvalue()
        img = Image.open(BytesIO(bytes_data))
        cloth_img = img.convert("RGB").resize((256 + 128, 512))
        st.image(cloth_img)
        uploaded_file = st.selectbox(
            "Confirm the chosen image",
            ("Wait", "Done")
        )
else:
    img = image_select("Choose img", demo_images)
    #img = Image.open(img)
    cloth_img = img.convert("RGB").resize((256 + 128, 512))
    st.image(cloth_img)
    uploaded_file = st.selectbox(
        "Confirm the chosen image",
        ("Wait", "Done")
    )
if img is not None and uploaded_file not in (None, "Wait"):
    cloth_img = img.convert("RGB").resize((768, 1024))
    #pidx = 44
    pidx_index_list = [44, 84, 67]
    poses = []
    for idx in range(len(pidx_index_list)):
        poses.append(
            take_all_feature_with_dataset(cloth_img, pidx_index_list[idx], only_show_human=True)
        )
    col1, col2, col3 = st.columns(3)
    with col1:
        st.header("Pose 0")
        pose_img = poses[0]
        st.image(pose_img)
        b = pred_func(cloth_img, pidx_index_list[0])
        st.image(b)
    with col2:
        st.header("Pose 1")
        pose_img = poses[1]
        st.image(pose_img)
        b = pred_func(cloth_img, pidx_index_list[1])
        st.image(b)
    with col3:
        st.header("Pose 2")
        pose_img = poses[2]
        st.image(pose_img)
        b = pred_func(cloth_img, pidx_index_list[2])
        st.image(b)