File size: 6,002 Bytes
4a285f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import os
import subprocess
import sys
import time

import gradio as gr
import numpy as np
from PIL import Image, ImageOps
from rembg import remove

import u2net_load
import u2net_run
from predict_pose import generate_pose_keypoints

# Make directories. os.makedirs(exist_ok=True) also creates missing parents
# (e.g. ./saved_models) and does not complain when a directory already exists,
# so restarting the app is quiet and idempotent.
for _directory in (
    "./Data_preprocessing",
    "./Data_preprocessing/test_color",
    "./Data_preprocessing/test_colormask",
    "./Data_preprocessing/test_edge",
    "./Data_preprocessing/test_img",
    "./Data_preprocessing/test_label",
    "./Data_preprocessing/test_mask",
    "./Data_preprocessing/test_pose",
    "./inputs",
    "./inputs/img",
    "./inputs/cloth",
    "./saved_models/u2net",
    "./saved_models/u2netp",
    "./pose",
    "./checkpoints",
):
    os.makedirs(_directory, exist_ok=True)


def _run_best_effort(cmd):
    """Run ``cmd`` (an argv list, no shell) and report failures without raising.

    Returns True when the command ran and exited with status 0. Setup stays
    best-effort, like the ``os.system`` calls this replaces: a failure is
    printed, not raised.
    """
    try:
        returncode = subprocess.run(cmd).returncode
    except OSError as err:  # e.g. the executable is not installed
        print("Warning: could not run {}: {}".format(cmd[0], err))
        return False
    if returncode != 0:
        print("Warning: {!r} exited with status {}".format(" ".join(cmd), returncode))
    return returncode == 0


def _download(url, dest):
    """Fetch ``url`` into ``dest`` with wget, skipping files that already exist.

    The download goes to ``dest + '.part'`` and is moved into place only when
    wget succeeds. An interrupted download is therefore retried on the next
    start rather than mistaken for a complete file. (A plain ``wget URL`` would
    instead store a duplicate ``name.1`` copy on every restart.)
    """
    if os.path.exists(dest):
        return
    partial = dest + ".part"
    if _run_best_effort(["wget", "-O", partial, url]):
        os.replace(partial, dest)
    elif os.path.exists(partial):
        os.remove(partial)


_FIFA_RELEASE = "https://github.com/hasibzunair/fifa-demo/releases/download/v1.0/"
_VTON_RELEASE = "https://github.com/hasibzunair/vton-demo/releases/download/v1.0/"

# Get pose model
_download(_FIFA_RELEASE + "pose_deploy_linevec.prototxt", "./pose/pose_deploy_linevec.prototxt")
_download(_FIFA_RELEASE + "pose_iter_440000.caffemodel", "./pose/pose_iter_440000.caffemodel")

# For segmentation mask generation; stored in the working directory because the
# human-parsing step is invoked with --model-restore 'lip_final.pth'.
_download(_FIFA_RELEASE + "lip_final.pth", "./lip_final.pth")

# Get U-2-Net weights
_download(_FIFA_RELEASE + "u2netp.pth", "./saved_models/u2netp/u2netp.pth")
_download(_FIFA_RELEASE + "u2net.pth", "./saved_models/u2net/u2net.pth")

# Get model checkpoints
_CHECKPOINT_ZIP = "./checkpoints/decavtonfifapretrain.zip"
_download(_VTON_RELEASE + "decavtonfifapretrain.zip", _CHECKPOINT_ZIP)
if os.path.exists(_CHECKPOINT_ZIP):
    # -o overwrites without prompting; otherwise unzip asks "replace ...?" for
    # every file once the checkpoint has already been extracted.
    _run_best_effort(["unzip", "-o", _CHECKPOINT_ZIP, "-d", "./checkpoints/"])

print("########################Setup done!########################")

# Load U-2-Net model
u2net = u2net_load.model(model_name='u2netp')
# Main inference function
def inference(clothing_image, person_image, remove_bg):
    """
    Do try-on!

    Pre-processes the clothing and person images into Data_preprocessing/, runs
    the human-parsing and try-on scripts, and returns the generated image path.

    Args:
        clothing_image: path of the uploaded clothing image (the Gradio input is
            declared with ``type='filepath'``).
        person_image: path of the uploaded person image.
        remove_bg: "yes" to remove the person image background with rembg
            before resizing; any other value leaves the image unchanged.

    Returns:
        The path "results/test/try-on/person.png" written by test.py.

    Raises:
        subprocess.CalledProcessError: if the human-parsing script or test.py
            exits with a non-zero status. Previously such failures were ignored
            and an earlier request's result could be returned.
    """
    cloth_name = 'cloth.png'
    img_name = 'person.png'
    cloth_input_path = "inputs/cloth/cloth.png"
    img_input_path = "inputs/img/person.png"
    result_path = "results/test/try-on/person.png"

    # Save cloth and person images in "inputs" folder (re-encoded as PNG).
    # `with` closes the file handles that Image.open keeps open lazily.
    with Image.open(clothing_image) as cloth_upload:
        cloth_upload.save(cloth_input_path)
    with Image.open(person_image) as person_upload:
        person_upload.save(img_input_path)

    ############## Clothing image pre-processing
    # Read back the file just written, not the first entry of a directory
    # listing, which could be a stale or unrelated file.
    with Image.open(cloth_input_path) as cloth_src:
        # Resize (crop to aspect, then scale) to 192x256
        cloth = ImageOps.fit(cloth_src, (192, 256), Image.BICUBIC).convert("RGB")
    # Save resized cloth image
    cloth.save(os.path.join('Data_preprocessing/test_color', cloth_name))
    # 1. Get binary mask for clothing image
    u2net_run.infer(u2net, 'Data_preprocessing/test_color', 'Data_preprocessing/test_edge')

    ############## Person image pre-processing
    start_time = time.time()
    with Image.open(img_input_path) as person_src:
        img = person_src
        if remove_bg == "yes":
            print("Removing background from person image..")
            img = remove(img, alpha_matting=True, alpha_matting_erode_size=15)
        img = ImageOps.fit(img, (192, 256), Image.BICUBIC).convert("RGB")
    img_path = os.path.join('Data_preprocessing/test_img', img_name)
    img.save(img_path)
    resize_time = time.time()
    print('Resized image in {}s'.format(resize_time-start_time))

    # 2. Get parsed person image (test_label), uses person image
    subprocess.run(
        [
            "python3", "Self-Correction-Human-Parsing-for-ACGPN/simple_extractor.py",
            "--dataset", "lip",
            "--model-restore", "lip_final.pth",
            "--input-dir", "Data_preprocessing/test_img",
            "--output-dir", "Data_preprocessing/test_label",
        ],
        check=True,
    )
    parse_time = time.time()
    print('Parsing generated in {}s'.format(parse_time-resize_time))

    # 3. Get pose map from person image
    pose_path = os.path.join('Data_preprocessing/test_pose', img_name.replace('.png', '_keypoints.json'))
    generate_pose_keypoints(img_path, pose_path)
    pose_time = time.time()
    print('Pose map generated in {}s'.format(pose_time-parse_time))

    # Format: person, cloth image. Mode 'w' truncates any previous file, so no
    # separate delete is needed.
    with open('Data_preprocessing/test_pairs.txt', 'w') as f:
        f.write('{} {}'.format(img_name, cloth_name))

    # Do try-on. Remove the previous result first so a run that writes nothing
    # cannot silently return an earlier request's image.
    if os.path.exists(result_path):
        os.remove(result_path)
    subprocess.run(["python", "test.py", "--name", "decavtonfifapretrain"], check=True)
    with Image.open(result_path) as tryon_image:
        print("Size of image is: ", tryon_image.size)
    return result_path


title = "Virtual Dressing Room"
description = "This is a demo for image based virtual try-on. It generates a synthetic image of a person wearing a target clothing item. To use it, simply upload your clothing item and person images, or click one of the examples to load them. Read more at the links below."
# NOTE(review): arXiv 1512.03385 appears to be the ResNet paper ("Deep Residual
# Learning"), not "Fill in Fabrics" -- confirm the intended paper ID and update.
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/1512.03385' target='_blank'>Fill in Fabrics: Body-Aware Self-Supervised Inpainting for Image-Based Virtual Try-On (Under Review!)</a> | <a href='https://github.com/dktunited/fifa_demo' target='_blank'>Github</a></p>"
thumbnail = None # "./pathtothumbnail.png"

# Build the demo UI around `inference`. Each example row supplies one value per
# input component: clothing image, person image, and the remove-background
# choice ("no", the Radio's default).
demo = gr.Interface(
    inference,
    [gr.inputs.Image(type='filepath', label="Clothing Image"),
     gr.inputs.Image(type='filepath', label="Person Image"),
     gr.inputs.Radio(choices=["yes", "no"], default="no", label="Remove background from person image?")],
    gr.outputs.Image(type="filepath", label="Predicted Output"),
    examples=[["./examples/1/cloth.jpg", "./examples/1/person.jpg", "no"],
              ["./examples/2/cloth.jpg", "./examples/2/person.jpg", "no"]],
    title=title,
    description=description,
    article=article,
    allow_flagging=False,
    analytics_enabled=False,
    thumbnail=thumbnail,
)
demo.launch(debug=True, enable_queue=True)