# fifa-tryon-demo / app.py
# Last commit by hasibzunair: "update app.py with correct checkpoints" (1c6d4d2)
# File size: 8.17 kB
import os
import subprocess
import sys
import time
import urllib.request
import zipfile

import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageOps
from rembg import remove

import u2net_load
import u2net_run
from predict_pose import generate_pose_keypoints
# Pick the compute device once at start-up: CUDA when torch can see a GPU,
# otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"####Using {device}.#####")
# Make directories
# Working folders used by the pre-processing steps, the uploaded inputs and
# the downloaded model weights.
REQUIRED_DIRS = (
    "./Data_preprocessing",
    "./Data_preprocessing/test_color",
    "./Data_preprocessing/test_colormask",
    "./Data_preprocessing/test_edge",
    "./Data_preprocessing/test_img",
    "./Data_preprocessing/test_label",
    "./Data_preprocessing/test_mask",
    "./Data_preprocessing/test_pose",
    "./inputs",
    "./inputs/img",
    "./inputs/cloth",
    "./saved_models/",
    "./saved_models/u2net",
    "./saved_models/u2netp",
    "./pose",
    "./checkpoints",
)

# os.makedirs(..., exist_ok=True) needs no shell, works on every platform and
# is silent when the directory is already there, so restarting the app no
# longer prints one "File exists" error per folder.
for _required_dir in REQUIRED_DIRS:
    os.makedirs(_required_dir, exist_ok=True)
def _fetch(url, dest):
    """Download *url* to *dest*, replacing any existing file.

    The payload is first written to ``dest + ".part"`` and renamed only after
    the transfer has finished. A failed or interrupted download therefore
    raises and never leaves a truncated file at *dest* that a later start-up
    would mistake for a valid model.

    Raises:
        urllib.error.URLError: the request failed (includes HTTP errors).
        urllib.error.ContentTooShortError: fewer bytes arrived than the
            server announced.
    """
    parent = os.path.dirname(dest)
    if parent:
        os.makedirs(parent, exist_ok=True)
    partial = dest + ".part"
    try:
        urllib.request.urlretrieve(url, partial)
        os.replace(partial, dest)
    finally:
        # Only present here if the download or the rename failed.
        if os.path.exists(partial):
            os.remove(partial)


# (destination, URL) for every single-file model asset:
# pose model, segmentation-mask model and the two U-2-Net weight files.
_MODEL_FILES = (
    ("./pose/pose_deploy_linevec.prototxt",
     "https://github.com/hasibzunair/fifa-demo/releases/download/v1.0/pose_deploy_linevec.prototxt"),
    ("./pose/pose_iter_440000.caffemodel",
     "https://github.com/hasibzunair/fifa-demo/releases/download/v1.0/pose_iter_440000.caffemodel"),
    ("lip_final.pth",
     "https://github.com/hasibzunair/fifa-demo/releases/download/v1.0/lip_final.pth"),
    ("saved_models/u2netp/u2netp.pth",
     "https://github.com/hasibzunair/fifa-demo/releases/download/v1.0/u2netp.pth"),
    ("saved_models/u2net/u2net.pth",
     "https://github.com/hasibzunair/fifa-demo/releases/download/v1.0/u2net.pth"),
)

for _dest, _url in _MODEL_FILES:
    if not os.path.exists(_dest):
        _fetch(_url, _dest)

# Get model checkpoints
# The archive is fetched again whenever the extracted folder is missing, even
# if an old .zip is still lying around.
if not os.path.exists("./checkpoints/decavtonfifapretrain/"):
    _fetch(
        "https://github.com/hasibzunair/vton-demo/releases/download/v1.0/decavtonfifapretrain.zip",
        "./checkpoints/decavtonfifapretrain.zip",
    )
    with zipfile.ZipFile("./checkpoints/decavtonfifapretrain.zip") as _archive:
        _archive.extractall("./checkpoints/")

print("########################Setup done!########################")
# Build the U-2-Net model once at start-up; `inference` reuses this
# module-level instance to produce the clothing mask on every request.
print(f"####Using {device}.#####")
_U2NET_MODEL_NAME = 'u2netp'
u2net = u2net_load.model(model_name=_U2NET_MODEL_NAME)
def composite_background(img_mask, person_image_path, tryon_image_path,
                         output_path="results/test/try-on/tryon_with_bg.png"):
    """Put the original background back behind the person after try-on.

    Pixels where the mask is non-zero are taken from the try-on image; all
    other pixels are taken from the person image.

    Args:
        img_mask: Foreground mask of the person (PIL image or array-like).
            Any value > 0 counts as "person". Assumed to be a single-channel
            2-D mask with the same height/width as both images.
        person_image_path: Path to the person image that still has its
            background.
        tryon_image_path: Path to the generated try-on image.
        output_path: Where the composite PNG is written. The default is the
            location `inference` returns to the UI.

    Returns:
        The path the composite image was saved to.
    """
    # Force both images to 3-channel RGB so an RGBA or greyscale file cannot
    # break the broadcast against the mask below.
    with Image.open(person_image_path) as person_img:
        person = np.array(person_img.convert("RGB"))
    with Image.open(tryon_image_path) as tryon_img:
        tryon = np.array(tryon_img.convert("RGB"))

    # Person image mask from rembg -> boolean foreground mask.
    is_person = np.array(img_mask) > 0

    # Composite: try-on pixels on the person, original pixels elsewhere.
    tryon_with_bg = np.where(is_person[..., np.newaxis], tryon, person)

    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    Image.fromarray(tryon_with_bg.astype(np.uint8)).save(output_path)
    return output_path
def _run_python_script(*args):
    """Run a helper script with the current interpreter; raise if it fails."""
    # List form (no shell) avoids quoting issues; check=True turns a non-zero
    # exit status into subprocess.CalledProcessError.
    subprocess.run([sys.executable, *args], check=True)


# Main inference function
def inference(clothing_image, person_image, remove_bg, retrieve_bg):
    """Do try-on: dress the person image in the clothing image.

    Args:
        clothing_image: File path of the uploaded clothing image.
        person_image: File path of the uploaded person image.
        remove_bg: "yes" to strip the background from the person image
            before try-on; any other value leaves it untouched.
        retrieve_bg: "yes" to paste the person image's background back onto
            the try-on result; any other value returns the raw result.

    Returns:
        File path of the resulting image.

    Raises:
        subprocess.CalledProcessError: the human-parsing or try-on script
            exited with an error. Raising here prevents a stale result from
            an earlier request being returned as if it were new.
    """
    cloth_name = 'cloth.png'
    img_name = 'person.png'
    cloth_input_path = os.path.join('inputs/cloth', cloth_name)
    person_input_path = os.path.join('inputs/img', img_name)
    tryon_path = "results/test/try-on/person.png"

    # Save cloth and person images in the "inputs" folder
    Image.open(clothing_image).save(cloth_input_path)
    Image.open(person_image).save(person_input_path)

    ############## Clothing image pre-processing
    # Open exactly the file written above rather than "the first entry of the
    # folder", which could be an unrelated stray file.
    cloth = Image.open(cloth_input_path)
    # Resize cloth image
    cloth = ImageOps.fit(cloth, (192, 256), Image.BICUBIC).convert("RGB")
    # Save resized cloth image
    cloth.save(os.path.join('Data_preprocessing/test_color', cloth_name))

    # 1. Get binary mask for clothing image
    u2net_run.infer(u2net, 'Data_preprocessing/test_color', 'Data_preprocessing/test_edge')

    ############## Person image pre-processing
    start_time = time.time()
    img = Image.open(person_input_path)
    if remove_bg == "yes":
        # Remove background
        img = remove(img, alpha_matting=True, alpha_matting_erode_size=15)
        print("Removing background from person image..")
    img = ImageOps.fit(img, (192, 256), Image.BICUBIC).convert("RGB")
    img_path = os.path.join('Data_preprocessing/test_img', img_name)
    img.save(img_path)
    resize_time = time.time()
    print(f'Resized image in {resize_time - start_time}s')

    # 2. Get parsed person image (test_label), uses person image
    _run_python_script(
        "Self-Correction-Human-Parsing-for-ACGPN/simple_extractor.py",
        "--dataset", "lip",
        "--model-restore", "lip_final.pth",
        "--input-dir", "Data_preprocessing/test_img",
        "--output-dir", "Data_preprocessing/test_label",
    )
    parse_time = time.time()
    print(f'Parsing generated in {parse_time - resize_time}s')

    # 3. Get pose map from person image
    pose_path = os.path.join('Data_preprocessing/test_pose', img_name.replace('.png', '_keypoints.json'))
    generate_pose_keypoints(img_path, pose_path)
    pose_time = time.time()
    print(f'Pose map generated in {pose_time - parse_time}s')

    # Format: person, cloth image. Mode 'w' truncates any previous file, so
    # no separate delete is needed.
    with open('Data_preprocessing/test_pairs.txt', 'w') as f:
        f.write(f'{img_name} {cloth_name}')

    # Do try-on
    _run_python_script("test.py", "--name", "decavtonfifapretrain")
    with Image.open(tryon_path) as tryon_image:
        print("Size of image is: ", tryon_image.size)

    # Return only try-on result
    if retrieve_bg != "yes":
        return tryon_path

    # Return try-on with background added back on the person image.
    # The person mask is only needed on this path, so the alpha-matting pass
    # is skipped when the background is not restored.
    img_mask = remove(img, alpha_matting=True, alpha_matting_erode_size=15, only_mask=True)
    composite_background(img_mask, img_path, tryon_path)
    return "results/test/try-on/tryon_with_bg.png"
# Text shown on the demo page.
title = "Virtual Dressing Room"
description = "This is a demo for image based virtual try-on. It generates a synthetic image of a person wearing a target clothing item. To use it, simply upload your clothing item and person images, or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='will_be_added' target='_blank'>Fill in Fabrics: Body-Aware Self-Supervised Inpainting for Image-Based Virtual Try-On (Under Review!)</a> | <a href='https://github.com/dktunited/fifa_demo' target='_blank'>Github</a></p>"
thumbnail = None  # "./pathtothumbnail.png"

# Input widgets, in the order `inference` expects its arguments.
demo_inputs = [
    gr.inputs.Image(type='filepath', label="Clothing Image"),
    gr.inputs.Image(type='filepath', label="Person Image"),
    gr.inputs.Radio(
        choices=["yes", "no"],
        default="no",
        label="Remove background from the person image?",
    ),
    gr.inputs.Radio(
        choices=["yes", "no"],
        default="no",
        label="Retrieve original background from the person image?",
    ),
]

# `inference` returns a file path, so the output widget reads from disk.
demo_output = gr.outputs.Image(type="filepath", label="Predicted Output")

# Each example supplies the clothing image and the person image.
demo_examples = [
    ["./examples/1/cloth.jpg", "./examples/1/person.jpg"],
    ["./examples/2/cloth.jpg", "./examples/2/person.jpg"],
]

demo = gr.Interface(
    fn=inference,
    inputs=demo_inputs,
    outputs=demo_output,
    examples=demo_examples,
    title=title,
    description=description,
    article=article,
    allow_flagging=False,
    analytics_enabled=False,
    thumbnail=thumbnail,
)
demo.launch(debug=True, enable_queue=True)