File size: 4,380 Bytes
2fa7ad5
bd46aed
 
 
b390b12
 
 
 
 
 
2fa7ad5
b390b12
eb387cb
b390b12
 
 
 
 
 
 
 
eb387cb
 
 
b390b12
 
 
2fa7ad5
 
 
 
 
 
 
 
 
b390b12
9024e34
 
 
eb387cb
9024e34
eb387cb
 
 
 
9024e34
 
 
 
 
b390b12
9024e34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b390b12
2fa7ad5
 
 
 
 
 
 
 
 
 
 
 
b390b12
9c8052b
9024e34
 
 
 
 
2fa7ad5
9024e34
eb387cb
b390b12
eb387cb
 
 
 
9024e34
 
 
122ecd4
 
b390b12
 
 
 
 
 
 
122ecd4
b390b12
 
 
 
 
 
 
 
 
 
 
 
2fa7ad5
b390b12
2fa7ad5
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from flask import Flask, request, jsonify
import torch
from PIL import Image, ImageOps

from utils_ootd import get_mask_location
from preprocess.openpose.run_openpose import OpenPose
from preprocess.humanparsing.run_parsing import Parsing
from ootd.inference_ootd_hd import OOTDiffusionHD
from ootd.inference_ootd_dc import OOTDiffusionDC

app = Flask(__name__)

# Load the models once at application start-up so each request only pays
# inference cost, not construction cost.
# NOTE(review): the HD pipeline is pinned to device index 0 and the DC
# pipeline to device index 1 — this assumes at least two CUDA devices are
# available; confirm against the deployment hardware.
openpose_model_hd = OpenPose(0)
parsing_model_hd = Parsing(0)
ootd_model_hd = OOTDiffusionHD(0)

openpose_model_dc = OpenPose(1)
parsing_model_dc = Parsing(1)
ootd_model_dc = OOTDiffusionDC(1)

# Compute device used by the request handlers when moving weights.
# NOTE(review): this falls back to CPU, but the model constructors above take
# fixed GPU indices — the CPU fallback path is likely untested; verify.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Garment category names as the OOTD pipelines expect them.
category_dict = ['upperbody', 'lowerbody', 'dress']
# The same categories in the naming used by get_mask_location().
category_dict_utils = ['upper_body', 'lower_body', 'dresses']

@app.route("/process_hd", methods=["POST"])
def process_hd():
    """Run the HD (upper-body) try-on pipeline on an uploaded image pair.

    Expects multipart form-data with two files — 'vton_img' (the person)
    and 'garm_img' (the garment) — plus form fields 'n_samples', 'n_steps',
    'image_scale' and 'seed'.

    Returns:
        JSON ``{"result": [<base64-encoded PNG>, ...]}``, one entry per
        generated sample.
    """
    data = request.files
    vton_img = data['vton_img']
    garm_img = data['garm_img']
    n_samples = int(request.form['n_samples'])
    n_steps = int(request.form['n_steps'])
    image_scale = float(request.form['image_scale'])
    seed = int(request.form['seed'])

    model_type = 'hd'
    category = 0 # 0:upperbody; 1:lowerbody; 2:dress — HD handles upper-body only

    # Move the model weights onto the compute device for inference.
    with torch.no_grad():
        openpose_model_hd.preprocessor.body_estimation.model.to(device)
        ootd_model_hd.pipe.to(device)
        ootd_model_hd.image_encoder.to(device)
        ootd_model_hd.text_encoder.to(device)

        garm_img = Image.open(garm_img).resize((768, 1024))
        vton_img = Image.open(vton_img).resize((768, 1024))
        # Pose estimation and human parsing run at half resolution (384x512).
        keypoints = openpose_model_hd(vton_img.resize((384, 512)))
        model_parse, _ = parsing_model_hd(vton_img.resize((384, 512)))

        mask, mask_gray = get_mask_location(model_type, category_dict_utils[category], model_parse, keypoints)
        # NEAREST keeps the mask hard-edged when scaling back to full size.
        mask = mask.resize((768, 1024), Image.NEAREST)
        mask_gray = mask_gray.resize((768, 1024), Image.NEAREST)

        masked_vton_img = Image.composite(mask_gray, vton_img, mask)

        images = ootd_model_hd(
            model_type=model_type,
            category=category_dict[category],
            image_garm=garm_img,
            image_vton=masked_vton_img,
            mask=mask,
            image_ori=vton_img,
            num_samples=n_samples,
            num_steps=n_steps,
            image_scale=image_scale,
            seed=seed,
        )

    # BUG FIX: the pipeline output (presumably a list of PIL images — confirm
    # against OOTDiffusionHD) is not JSON-serializable, so the original
    # `jsonify(result=images)` raised a TypeError at response time. Encode
    # each image as a base64 PNG string instead.
    encoded = []
    for img in images:
        buf = io.BytesIO()
        img.save(buf, format='PNG')
        encoded.append(base64.b64encode(buf.getvalue()).decode('ascii'))

    return jsonify(result=encoded)

@app.route("/process_dc", methods=["POST"])
def process_dc():
    """Run the DC try-on pipeline on an uploaded image pair.

    Expects multipart form-data with two files — 'vton_img' (the person)
    and 'garm_img' (the garment) — plus form fields 'category'
    ('Upper-body', 'Lower-body', or anything else for dress), 'n_samples',
    'n_steps', 'image_scale' and 'seed'.

    Returns:
        JSON ``{"result": [<base64-encoded PNG>, ...]}``, one entry per
        generated sample.
    """
    data = request.files
    vton_img = data['vton_img']
    garm_img = data['garm_img']
    category = request.form['category']
    n_samples = int(request.form['n_samples'])
    n_steps = int(request.form['n_steps'])
    image_scale = float(request.form['image_scale'])
    seed = int(request.form['seed'])

    model_type = 'dc'
    # Map the form label to a category index; any unrecognized label falls
    # back to 2 (dress), matching the original if/elif/else behavior.
    category = {'Upper-body': 0, 'Lower-body': 1}.get(category, 2)

    # Move the model weights onto the compute device for inference.
    with torch.no_grad():
        openpose_model_dc.preprocessor.body_estimation.model.to(device)
        ootd_model_dc.pipe.to(device)
        ootd_model_dc.image_encoder.to(device)
        ootd_model_dc.text_encoder.to(device)

        garm_img = Image.open(garm_img).resize((768, 1024))
        vton_img = Image.open(vton_img).resize((768, 1024))
        # Pose estimation and human parsing run at half resolution (384x512).
        keypoints = openpose_model_dc(vton_img.resize((384, 512)))
        model_parse, _ = parsing_model_dc(vton_img.resize((384, 512)))

        mask, mask_gray = get_mask_location(model_type, category_dict_utils[category], model_parse, keypoints)
        # NEAREST keeps the mask hard-edged when scaling back to full size.
        mask = mask.resize((768, 1024), Image.NEAREST)
        mask_gray = mask_gray.resize((768, 1024), Image.NEAREST)

        masked_vton_img = Image.composite(mask_gray, vton_img, mask)

        images = ootd_model_dc(
            model_type=model_type,
            category=category_dict[category],
            image_garm=garm_img,
            image_vton=masked_vton_img,
            mask=mask,
            image_ori=vton_img,
            num_samples=n_samples,
            num_steps=n_steps,
            image_scale=image_scale,
            seed=seed,
        )

    # BUG FIX: the pipeline output (presumably a list of PIL images — confirm
    # against OOTDiffusionDC) is not JSON-serializable, so the original
    # `jsonify(result=images)` raised a TypeError at response time. Encode
    # each image as a base64 PNG string instead.
    encoded = []
    for img in images:
        buf = io.BytesIO()
        img.save(buf, format='PNG')
        encoded.append(base64.b64encode(buf.getvalue()).decode('ascii'))

    return jsonify(result=encoded)

# Start the development server, reachable from other hosts on port 7860.
# NOTE(review): 0.0.0.0 exposes the API to the whole network and Flask's
# built-in server is not production-grade — front with gunicorn/uwsgi
# behind a reverse proxy for real deployments.
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)