Saad0KH commited on
Commit
b390b12
1 Parent(s): bb797a5

Create gradio_ootd.py

Browse files
Files changed (1) hide show
  1. run/gradio_ootd.py +118 -0
run/gradio_ootd.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify ,send_file
2
+ from PIL import Image
3
+ import base64
4
+ import io
5
+ import random
6
+ import uuid
7
+ import numpy as np
8
+ import spaces
9
+ import torch
10
+
11
+ from utils_ootd import get_mask_location
12
+
13
+ PROJECT_ROOT = Path(__file__).absolute().parents[1].absolute()
14
+ sys.path.insert(0, str(PROJECT_ROOT))
15
+
16
+ from preprocess.openpose.run_openpose import OpenPose
17
+ from preprocess.humanparsing.run_parsing import Parsing
18
+ from ootd.inference_ootd_hd import OOTDiffusionHD
19
+ from ootd.inference_ootd_dc import OOTDiffusionDC
20
+
21
+
22
+ openpose_model_hd = OpenPose(0)
23
+ parsing_model_hd = Parsing(0)
24
+ ootd_model_hd = OOTDiffusionHD(0)
25
+
26
+ openpose_model_dc = OpenPose(1)
27
+ parsing_model_dc = Parsing(1)
28
+ ootd_model_dc = OOTDiffusionDC(1)
29
+
30
+
31
+ category_dict = ['upperbody', 'lowerbody', 'dress']
32
+ category_dict_utils = ['upper_body', 'lower_body', 'dresses']
33
+
34
+
35
+
36
+ # Créer une instance FastAPI
37
+ app = Flask(__name__)
38
+
39
+ def save_image(img):
40
+ unique_name = str(uuid.uuid4()) + ".png"
41
+ img.save(unique_name)
42
+ return unique_name
43
+
44
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
45
+ if randomize_seed:
46
+ seed = random.randint(0, MAX_SEED)
47
+ return seed
48
+
49
+
50
+ @spaces.GPU
51
+ def process_hd(vton_img, garm_img,category, n_samples, n_steps, image_scale, seed):
52
+ model_type = 'hd'
53
+ with torch.no_grad():
54
+ openpose_model_hd.preprocessor.body_estimation.model.to('cuda')
55
+ ootd_model_hd.pipe.to('cuda')
56
+ ootd_model_hd.image_encoder.to('cuda')
57
+ ootd_model_hd.text_encoder.to('cuda')
58
+
59
+ garm_img = Image.open(garm_img).resize((768, 1024))
60
+ vton_img = Image.open(vton_img).resize((768, 1024))
61
+ keypoints = openpose_model_hd(vton_img.resize((384, 512)))
62
+ model_parse, _ = parsing_model_hd(vton_img.resize((384, 512)))
63
+
64
+ mask, mask_gray = get_mask_location(model_type, category_dict_utils[category], model_parse, keypoints)
65
+ mask = mask.resize((768, 1024), Image.NEAREST)
66
+ mask_gray = mask_gray.resize((768, 1024), Image.NEAREST)
67
+
68
+ masked_vton_img = Image.composite(mask_gray, vton_img, mask)
69
+
70
+ images = ootd_model_hd(
71
+ model_type=model_type,
72
+ category=category_dict[category],
73
+ image_garm=garm_img,
74
+ image_vton=masked_vton_img,
75
+ mask=mask,
76
+ image_ori=vton_img,
77
+ num_samples=n_samples,
78
+ num_steps=n_steps,
79
+ image_scale=image_scale,
80
+ seed=seed,
81
+ )
82
+
83
+ return images
84
+
85
+
86
+ @app.get("/")
87
+ def root():
88
+ return "Welcome to the Fashion OOTDiffusion API "
89
+
90
+ # Route pour récupérer l'image générée
91
+ @app.route('/api/get_image/<image_id>', methods=['GET'])
92
+ def get_image(image_id):
93
+ # Construire le chemin complet de l'image
94
+ image_path = image_id # Assurez-vous que le nom de fichier correspond à celui que vous avez utilisé lors de la sauvegarde
95
+
96
+ # Renvoyer l'image
97
+ try:
98
+ return send_file(image_path, mimetype='image/png')
99
+ except FileNotFoundError:
100
+ return jsonify({'error': 'Image not found'}), 404
101
+
102
+ # Route pour l'API REST
103
+ @app.route('/api/run', methods=['POST'])
104
+ def run():
105
+ data = request.json
106
+ print(data)
107
+ vton_img = data['vton_img']
108
+ garm_img = data['garm_img']
109
+ category = data['category']
110
+ n_samples = data['n_samples']
111
+ n_steps = data['n_steps']
112
+ image_scale = data['image_scale']
113
+ seed = data['seed']
114
+ result = process_hd(vton_img, garm_img,category, n_samples, n_steps, image_scale, seed)
115
+ return jsonify({'out': result})
116
+
117
+ if __name__ == "__main__":
118
+ app.run(host="0.0.0.0", port=7860)