Spaces:
Build error
Build error
Create gradio_ootd.py
Browse files- run/gradio_ootd.py +118 -0
run/gradio_ootd.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from flask import Flask, request, jsonify ,send_file
|
2 |
+
from PIL import Image
|
3 |
+
import base64
|
4 |
+
import io
|
5 |
+
import random
|
6 |
+
import uuid
|
7 |
+
import numpy as np
|
8 |
+
import spaces
|
9 |
+
import torch
|
10 |
+
|
11 |
+
from utils_ootd import get_mask_location
|
12 |
+
|
13 |
+
PROJECT_ROOT = Path(__file__).absolute().parents[1].absolute()
|
14 |
+
sys.path.insert(0, str(PROJECT_ROOT))
|
15 |
+
|
16 |
+
from preprocess.openpose.run_openpose import OpenPose
|
17 |
+
from preprocess.humanparsing.run_parsing import Parsing
|
18 |
+
from ootd.inference_ootd_hd import OOTDiffusionHD
|
19 |
+
from ootd.inference_ootd_dc import OOTDiffusionDC
|
20 |
+
|
21 |
+
|
22 |
+
openpose_model_hd = OpenPose(0)
|
23 |
+
parsing_model_hd = Parsing(0)
|
24 |
+
ootd_model_hd = OOTDiffusionHD(0)
|
25 |
+
|
26 |
+
openpose_model_dc = OpenPose(1)
|
27 |
+
parsing_model_dc = Parsing(1)
|
28 |
+
ootd_model_dc = OOTDiffusionDC(1)
|
29 |
+
|
30 |
+
|
31 |
+
category_dict = ['upperbody', 'lowerbody', 'dress']
|
32 |
+
category_dict_utils = ['upper_body', 'lower_body', 'dresses']
|
33 |
+
|
34 |
+
|
35 |
+
|
36 |
+
# Créer une instance FastAPI
|
37 |
+
app = Flask(__name__)
|
38 |
+
|
39 |
+
def save_image(img):
|
40 |
+
unique_name = str(uuid.uuid4()) + ".png"
|
41 |
+
img.save(unique_name)
|
42 |
+
return unique_name
|
43 |
+
|
44 |
+
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
|
45 |
+
if randomize_seed:
|
46 |
+
seed = random.randint(0, MAX_SEED)
|
47 |
+
return seed
|
48 |
+
|
49 |
+
|
50 |
+
@spaces.GPU
|
51 |
+
def process_hd(vton_img, garm_img,category, n_samples, n_steps, image_scale, seed):
|
52 |
+
model_type = 'hd'
|
53 |
+
with torch.no_grad():
|
54 |
+
openpose_model_hd.preprocessor.body_estimation.model.to('cuda')
|
55 |
+
ootd_model_hd.pipe.to('cuda')
|
56 |
+
ootd_model_hd.image_encoder.to('cuda')
|
57 |
+
ootd_model_hd.text_encoder.to('cuda')
|
58 |
+
|
59 |
+
garm_img = Image.open(garm_img).resize((768, 1024))
|
60 |
+
vton_img = Image.open(vton_img).resize((768, 1024))
|
61 |
+
keypoints = openpose_model_hd(vton_img.resize((384, 512)))
|
62 |
+
model_parse, _ = parsing_model_hd(vton_img.resize((384, 512)))
|
63 |
+
|
64 |
+
mask, mask_gray = get_mask_location(model_type, category_dict_utils[category], model_parse, keypoints)
|
65 |
+
mask = mask.resize((768, 1024), Image.NEAREST)
|
66 |
+
mask_gray = mask_gray.resize((768, 1024), Image.NEAREST)
|
67 |
+
|
68 |
+
masked_vton_img = Image.composite(mask_gray, vton_img, mask)
|
69 |
+
|
70 |
+
images = ootd_model_hd(
|
71 |
+
model_type=model_type,
|
72 |
+
category=category_dict[category],
|
73 |
+
image_garm=garm_img,
|
74 |
+
image_vton=masked_vton_img,
|
75 |
+
mask=mask,
|
76 |
+
image_ori=vton_img,
|
77 |
+
num_samples=n_samples,
|
78 |
+
num_steps=n_steps,
|
79 |
+
image_scale=image_scale,
|
80 |
+
seed=seed,
|
81 |
+
)
|
82 |
+
|
83 |
+
return images
|
84 |
+
|
85 |
+
|
86 |
+
@app.get("/")
|
87 |
+
def root():
|
88 |
+
return "Welcome to the Fashion OOTDiffusion API "
|
89 |
+
|
90 |
+
# Route pour récupérer l'image générée
|
91 |
+
@app.route('/api/get_image/<image_id>', methods=['GET'])
|
92 |
+
def get_image(image_id):
|
93 |
+
# Construire le chemin complet de l'image
|
94 |
+
image_path = image_id # Assurez-vous que le nom de fichier correspond à celui que vous avez utilisé lors de la sauvegarde
|
95 |
+
|
96 |
+
# Renvoyer l'image
|
97 |
+
try:
|
98 |
+
return send_file(image_path, mimetype='image/png')
|
99 |
+
except FileNotFoundError:
|
100 |
+
return jsonify({'error': 'Image not found'}), 404
|
101 |
+
|
102 |
+
# Route pour l'API REST
|
103 |
+
@app.route('/api/run', methods=['POST'])
|
104 |
+
def run():
|
105 |
+
data = request.json
|
106 |
+
print(data)
|
107 |
+
vton_img = data['vton_img']
|
108 |
+
garm_img = data['garm_img']
|
109 |
+
category = data['category']
|
110 |
+
n_samples = data['n_samples']
|
111 |
+
n_steps = data['n_steps']
|
112 |
+
image_scale = data['image_scale']
|
113 |
+
seed = data['seed']
|
114 |
+
result = process_hd(vton_img, garm_img,category, n_samples, n_steps, image_scale, seed)
|
115 |
+
return jsonify({'out': result})
|
116 |
+
|
117 |
+
if __name__ == "__main__":
|
118 |
+
app.run(host="0.0.0.0", port=7860)
|