# fifa-tryon-demo / app.py
# Uploaded by hasibzunair -- commit 4a285f6 ("added files")
import numpy as np
import os
import time
import sys
import gradio as gr
import u2net_load
import u2net_run
from rembg import remove
from PIL import Image, ImageOps
from predict_pose import generate_pose_keypoints
# Make directories
# os.makedirs(..., exist_ok=True) is used instead of shelling out to `mkdir`:
# it creates missing parent directories (a plain `mkdir ./saved_models/u2net`
# fails when ./saved_models is absent) and is a silent no-op when the
# directory already exists, e.g. on a restart.
for directory in (
    "./Data_preprocessing",
    "./Data_preprocessing/test_color",
    "./Data_preprocessing/test_colormask",
    "./Data_preprocessing/test_edge",
    "./Data_preprocessing/test_img",
    "./Data_preprocessing/test_label",
    "./Data_preprocessing/test_mask",
    "./Data_preprocessing/test_pose",
    "./inputs",
    "./inputs/img",
    "./inputs/cloth",
    "./saved_models/u2net",
    "./saved_models/u2netp",
    "./pose",
    "./checkpoints",
):
    os.makedirs(directory, exist_ok=True)

# Get pose model
os.system("wget -O ./pose/pose_deploy_linevec.prototxt https://github.com/hasibzunair/fifa-demo/releases/download/v1.0/pose_deploy_linevec.prototxt")
os.system("wget -O ./pose/pose_iter_440000.caffemodel https://github.com/hasibzunair/fifa-demo/releases/download/v1.0/pose_iter_440000.caffemodel")

# For segmentation mask generation
os.system("wget https://github.com/hasibzunair/fifa-demo/releases/download/v1.0/lip_final.pth")

# Get U-2-Net weights
os.system("wget -P saved_models/u2netp/ https://github.com/hasibzunair/fifa-demo/releases/download/v1.0/u2netp.pth")
os.system("wget -P saved_models/u2net/ https://github.com/hasibzunair/fifa-demo/releases/download/v1.0/u2net.pth")

# Get model checkpoints
os.system("wget -O ./checkpoints/decavtonfifapretrain.zip https://github.com/hasibzunair/vton-demo/releases/download/v1.0/decavtonfifapretrain.zip")
os.system("unzip ./checkpoints/decavtonfifapretrain.zip -d ./checkpoints/")

print("########################Setup done!########################")

# Load U-2-Net model
# Loaded once at import time and reused by every call to inference().
u2net = u2net_load.model(model_name = 'u2netp')
# Main inference function
def _run_step(command, step_name):
    """Run a shell command and raise RuntimeError if it exits non-zero.

    os.system() does not raise when the child process fails, so the exit
    status is checked here to stop the pipeline instead of carrying on with
    missing or stale intermediate files.
    """
    status = os.system(command)
    if status != 0:
        raise RuntimeError("{} failed (exit status {})".format(step_name, status))


def inference(clothing_image, person_image, remove_bg):
    """
    Do try-on!

    Pre-processes the clothing and person images into the
    ``Data_preprocessing`` folders, runs human parsing, pose estimation and
    the try-on model, and returns the path of the generated image.

    Args:
        clothing_image: Path of the clothing image file.
        person_image: Path of the person image file.
        remove_bg: "yes" to strip the background from the person image
            before pre-processing; any other value leaves it untouched.

    Returns:
        Path of the generated try-on image
        ("results/test/try-on/person.png").

    Raises:
        RuntimeError: If the human-parsing or try-on subprocess exits with
            a non-zero status.
    """
    cloth_name = 'cloth.png'
    img_name = 'person.png'
    cloth_input_path = os.path.join('inputs/cloth', cloth_name)
    person_input_path = os.path.join('inputs/img', img_name)
    result_path = "results/test/try-on/person.png"

    # Save cloth and person images in "inputs" folder
    with Image.open(clothing_image) as uploaded_cloth:
        uploaded_cloth.save(cloth_input_path)
    with Image.open(person_image) as uploaded_person:
        uploaded_person.save(person_input_path)

    ############## Clothing image pre-processing
    # Re-open exactly the file saved above rather than "the first entry of
    # the directory listing", which could be an unrelated stray file.
    with Image.open(cloth_input_path) as cloth:
        # Resize cloth image
        cloth_resized = ImageOps.fit(cloth, (192, 256), Image.BICUBIC).convert("RGB")
    # Save resized cloth image
    cloth_resized.save(os.path.join('Data_preprocessing/test_color', cloth_name))

    # 1. Get binary mask for clothing image
    u2net_run.infer(u2net, 'Data_preprocessing/test_color', 'Data_preprocessing/test_edge')

    ############## Person image pre-processing
    start_time = time.time()
    with Image.open(person_input_path) as person:
        if remove_bg == "yes":
            # Remove background
            print("Removing background from person image..")
            person_source = remove(person, alpha_matting=True, alpha_matting_erode_size=15)
        else:
            person_source = person
        img = ImageOps.fit(person_source, (192, 256), Image.BICUBIC).convert("RGB")
    img_path = os.path.join('Data_preprocessing/test_img', img_name)
    img.save(img_path)
    resize_time = time.time()
    print('Resized image in {}s'.format(resize_time-start_time))

    # 2. Get parsed person image (test_label), uses person image
    _run_step(
        "python3 Self-Correction-Human-Parsing-for-ACGPN/simple_extractor.py --dataset 'lip' --model-restore 'lip_final.pth' --input-dir 'Data_preprocessing/test_img' --output-dir 'Data_preprocessing/test_label'",
        "Human parsing",
    )
    parse_time = time.time()
    print('Parsing generated in {}s'.format(parse_time-resize_time))

    # 3. Get pose map from person image
    pose_path = os.path.join('Data_preprocessing/test_pose', img_name.replace('.png', '_keypoints.json'))
    generate_pose_keypoints(img_path, pose_path)
    pose_time = time.time()
    print('Pose map generated in {}s'.format(pose_time-parse_time))

    # Format: person, cloth image
    # Mode 'w' truncates any previous pairs file, so no explicit delete is needed.
    with open('Data_preprocessing/test_pairs.txt','w') as f:
        f.write('{} {}'.format(img_name, cloth_name))

    # Drop the result of any earlier request so a run that produces no output
    # cannot hand back a stale image (presumably test.py writes result_path --
    # TODO confirm against test.py).
    if os.path.exists(result_path):
        os.remove(result_path)

    # Do try-on
    _run_step("python test.py --name decavtonfifapretrain", "Try-on")

    with Image.open(result_path) as tryon_image:
        print("Size of image is: ", tryon_image.size)
    return result_path
# ---- Text shown on the demo page ----
title = "Virtual Dressing Room"
description = "This is a demo for image based virtual try-on. It generates a synthetic image of a person wearing a target clothing item. To use it, simply upload your clothing item and person images, or click one of the examples to load them. Read more at the links below."
# NOTE(review): the arXiv id linked below does not look like it belongs to the
# paper named in the link text -- confirm both URLs before publishing.
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/1512.03385' target='_blank'>Fill in Fabrics: Body-Aware Self-Supervised Inpainting for Image-Based Virtual Try-On (Under Review!)</a> | <a href='https://github.com/dktunited/fifa_demo' target='_blank'>Github</a></p>"
thumbnail = None  # "./pathtothumbnail.png"

# ---- Input / output widgets ----
# The order of the inputs matches the parameter order of inference():
# clothing image, person image, background-removal choice.
input_components = [
    gr.inputs.Image(type='filepath', label="Clothing Image"),
    gr.inputs.Image(type='filepath', label="Person Image"),
    gr.inputs.Radio(choices=["yes","no"], default="no", label="Remove background from person image?"),
]
output_component = gr.outputs.Image(type="filepath", label="Predicted Output")

# Each example row supplies a clothing image and a person image.
# NOTE(review): rows carry two values while there are three inputs -- confirm
# the installed gradio version fills the radio with its default.
example_rows = [
    ["./examples/1/cloth.jpg", "./examples/1/person.jpg"],
    ["./examples/2/cloth.jpg", "./examples/2/person.jpg"],
]

# ---- Build the interface and start serving ----
demo = gr.Interface(
    fn=inference,
    inputs=input_components,
    outputs=output_component,
    examples=example_rows,
    title=title,
    description=description,
    article=article,
    allow_flagging=False,
    analytics_enabled=False,
    thumbnail=thumbnail,
)
demo.launch(debug=True, enable_queue=True)