Spaces:
Runtime error
Runtime error
File size: 4,380 Bytes
2fa7ad5 bd46aed b390b12 2fa7ad5 b390b12 eb387cb b390b12 eb387cb b390b12 2fa7ad5 b390b12 9024e34 eb387cb 9024e34 eb387cb 9024e34 b390b12 9024e34 b390b12 2fa7ad5 b390b12 9c8052b 9024e34 2fa7ad5 9024e34 eb387cb b390b12 eb387cb 9024e34 122ecd4 b390b12 122ecd4 b390b12 2fa7ad5 b390b12 2fa7ad5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import base64
import io

from flask import Flask, request, jsonify
import torch
from PIL import Image, ImageOps

from utils_ootd import get_mask_location
from preprocess.openpose.run_openpose import OpenPose
from preprocess.humanparsing.run_parsing import Parsing
from ootd.inference_ootd_hd import OOTDiffusionHD
from ootd.inference_ootd_dc import OOTDiffusionDC
# Flask application object; the route decorators below register on it.
app = Flask(__name__)
# Load the models only once, when the application starts, so that requests
# reuse the same instances instead of rebuilding them.
# NOTE(review): the integer argument (0 for the 'hd' set, 1 for the 'dc' set)
# is presumably a GPU index -- confirm against the model constructors.
openpose_model_hd = OpenPose(0)
parsing_model_hd = Parsing(0)
ootd_model_hd = OOTDiffusionHD(0)
openpose_model_dc = OpenPose(1)
parsing_model_dc = Parsing(1)
ootd_model_dc = OOTDiffusionDC(1)
# Define the GPU configuration: use CUDA when it is available, otherwise
# fall back to the CPU. The request handlers move model components to this
# device before running inference.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Category names indexed by the numeric category id used in the handlers
# (0: upper body, 1: lower body, 2: dress). `category_dict` holds the
# spelling passed to the diffusion models; `category_dict_utils` holds the
# spelling passed to get_mask_location.
category_dict = ['upperbody', 'lowerbody', 'dress']
category_dict_utils = ['upper_body', 'lower_body', 'dresses']
@app.route("/process_hd", methods=["POST"])
def process_hd():
    """Run the 'hd' try-on pipeline on an uploaded image pair.

    Expects a multipart POST request containing:

    * files ``vton_img`` (the image that is pose-estimated, parsed and
      masked) and ``garm_img`` (passed to the model as ``image_garm``);
    * form fields ``n_samples`` (int), ``n_steps`` (int),
      ``image_scale`` (float) and ``seed`` (int).

    The category is fixed to 0 (upper body) for this endpoint.

    Returns:
        A JSON object ``{"result": [...]}``. Every PIL image in the model
        output is encoded as a base64 PNG string so that the payload is
        JSON-serialisable; any other item is passed through unchanged.
        When a field is missing, a number cannot be parsed, or an upload
        cannot be read as an image, ``{"error": "..."}`` is returned with
        HTTP status 400.
    """
    model_type = 'hd'
    category = 0  # 0:upperbody; 1:lowerbody; 2:dress

    # Validate the request before touching the models so that bad input
    # fails fast with a 400 instead of an unhandled exception (HTTP 500).
    missing = [name for name in ('vton_img', 'garm_img')
               if name not in request.files]
    missing += [name for name in ('n_samples', 'n_steps', 'image_scale', 'seed')
                if name not in request.form]
    if missing:
        return jsonify(error=f"Missing request fields: {', '.join(missing)}"), 400

    try:
        n_samples = int(request.form['n_samples'])
        n_steps = int(request.form['n_steps'])
        image_scale = float(request.form['image_scale'])
        seed = int(request.form['seed'])
    except ValueError as exc:
        return jsonify(error=f"Invalid numeric field: {exc}"), 400

    try:
        garm_img = Image.open(request.files['garm_img']).resize((768, 1024))
        vton_img = Image.open(request.files['vton_img']).resize((768, 1024))
    except OSError as exc:
        # PIL raises OSError subclasses for unreadable or truncated images.
        return jsonify(error=f"Could not read uploaded image: {exc}"), 400

    # Inference only: no gradients are needed.
    with torch.no_grad():
        # Load the models into device (GPU when available) memory.
        openpose_model_hd.preprocessor.body_estimation.model.to(device)
        ootd_model_hd.pipe.to(device)
        ootd_model_hd.image_encoder.to(device)
        ootd_model_hd.text_encoder.to(device)

        # Pose estimation and human parsing both run on a half-resolution copy.
        vton_small = vton_img.resize((384, 512))
        keypoints = openpose_model_hd(vton_small)
        model_parse, _ = parsing_model_hd(vton_small)

        mask, mask_gray = get_mask_location(
            model_type, category_dict_utils[category], model_parse, keypoints
        )
        # Scale the masks back up to the working resolution; NEAREST keeps
        # the mask values unblended.
        mask = mask.resize((768, 1024), Image.NEAREST)
        mask_gray = mask_gray.resize((768, 1024), Image.NEAREST)

        # Replace the masked region of the input image with the gray mask.
        masked_vton_img = Image.composite(mask_gray, vton_img, mask)

        images = ootd_model_hd(
            model_type=model_type,
            category=category_dict[category],
            image_garm=garm_img,
            image_vton=masked_vton_img,
            mask=mask,
            image_ori=vton_img,
            num_samples=n_samples,
            num_steps=n_steps,
            image_scale=image_scale,
            seed=seed,
        )

    # jsonify cannot serialise PIL images, so encode them as base64 PNG.
    # NOTE(review): the pipeline presumably returns a list of PIL images --
    # confirm against OOTDiffusionHD.__call__.
    encoded = []
    for image in images:
        if isinstance(image, Image.Image):
            buffer = io.BytesIO()
            image.save(buffer, format="PNG")
            image = base64.b64encode(buffer.getvalue()).decode("ascii")
        encoded.append(image)
    return jsonify(result=encoded)
@app.route("/process_dc", methods=["POST"])
def process_dc():
    """Run the 'dc' try-on pipeline on an uploaded image pair.

    Expects a multipart POST request containing:

    * files ``vton_img`` (the image that is pose-estimated, parsed and
      masked) and ``garm_img`` (passed to the model as ``image_garm``);
    * form fields ``category`` (``'Upper-body'``, ``'Lower-body'``; any
      other value selects the dress category), ``n_samples`` (int),
      ``n_steps`` (int), ``image_scale`` (float) and ``seed`` (int).

    Returns:
        A JSON object ``{"result": [...]}``. Every PIL image in the model
        output is encoded as a base64 PNG string so that the payload is
        JSON-serialisable; any other item is passed through unchanged.
        When a field is missing, a number cannot be parsed, or an upload
        cannot be read as an image, ``{"error": "..."}`` is returned with
        HTTP status 400.
    """
    model_type = 'dc'

    # Validate the request before touching the models so that bad input
    # fails fast with a 400 instead of an unhandled exception (HTTP 500).
    missing = [name for name in ('vton_img', 'garm_img')
               if name not in request.files]
    missing += [
        name
        for name in ('category', 'n_samples', 'n_steps', 'image_scale', 'seed')
        if name not in request.form
    ]
    if missing:
        return jsonify(error=f"Missing request fields: {', '.join(missing)}"), 400

    try:
        n_samples = int(request.form['n_samples'])
        n_steps = int(request.form['n_steps'])
        image_scale = float(request.form['image_scale'])
        seed = int(request.form['seed'])
    except ValueError as exc:
        return jsonify(error=f"Invalid numeric field: {exc}"), 400

    # Map the category label to its numeric id; unknown labels fall back
    # to 2 (dress), as before.
    category = {'Upper-body': 0, 'Lower-body': 1}.get(request.form['category'], 2)

    try:
        garm_img = Image.open(request.files['garm_img']).resize((768, 1024))
        vton_img = Image.open(request.files['vton_img']).resize((768, 1024))
    except OSError as exc:
        # PIL raises OSError subclasses for unreadable or truncated images.
        return jsonify(error=f"Could not read uploaded image: {exc}"), 400

    # Inference only: no gradients are needed.
    with torch.no_grad():
        # Load the models into device (GPU when available) memory.
        openpose_model_dc.preprocessor.body_estimation.model.to(device)
        ootd_model_dc.pipe.to(device)
        ootd_model_dc.image_encoder.to(device)
        ootd_model_dc.text_encoder.to(device)

        # Pose estimation and human parsing both run on a half-resolution copy.
        vton_small = vton_img.resize((384, 512))
        keypoints = openpose_model_dc(vton_small)
        model_parse, _ = parsing_model_dc(vton_small)

        mask, mask_gray = get_mask_location(
            model_type, category_dict_utils[category], model_parse, keypoints
        )
        # Scale the masks back up to the working resolution; NEAREST keeps
        # the mask values unblended.
        mask = mask.resize((768, 1024), Image.NEAREST)
        mask_gray = mask_gray.resize((768, 1024), Image.NEAREST)

        # Replace the masked region of the input image with the gray mask.
        masked_vton_img = Image.composite(mask_gray, vton_img, mask)

        images = ootd_model_dc(
            model_type=model_type,
            category=category_dict[category],
            image_garm=garm_img,
            image_vton=masked_vton_img,
            mask=mask,
            image_ori=vton_img,
            num_samples=n_samples,
            num_steps=n_steps,
            image_scale=image_scale,
            seed=seed,
        )

    # jsonify cannot serialise PIL images, so encode them as base64 PNG.
    # NOTE(review): the pipeline presumably returns a list of PIL images --
    # confirm against OOTDiffusionDC.__call__.
    encoded = []
    for image in images:
        if isinstance(image, Image.Image):
            buffer = io.BytesIO()
            image.save(buffer, format="PNG")
            image = base64.b64encode(buffer.getvalue()).decode("ascii")
        encoded.append(image)
    return jsonify(result=encoded)
if __name__ == "__main__":
    # When executed directly as a script, start Flask's built-in server
    # bound to every interface on port 7860.
    server_options = {"host": "0.0.0.0", "port": 7860}
    app.run(**server_options)
|