Saad0KH commited on
Commit
9024e34
1 Parent(s): 30ffa26

Update run/gradio_ootd.py

Browse files
Files changed (1) hide show
  1. run/gradio_ootd.py +57 -63
run/gradio_ootd.py CHANGED
@@ -1,18 +1,10 @@
1
- from flask import Flask, request, jsonify ,send_file
2
- import base64
3
- import io
4
- import random
5
- import uuid
6
- import numpy as np
7
- import spaces
8
- import torch
9
  import os
10
  from pathlib import Path
11
  import sys
12
  import torch
13
  from PIL import Image, ImageOps
14
 
15
-
16
  from utils_ootd import get_mask_location
17
 
18
  PROJECT_ROOT = Path(__file__).absolute().parents[1].absolute()
@@ -36,43 +28,74 @@ ootd_model_dc = OOTDiffusionDC(1)
36
  category_dict = ['upperbody', 'lowerbody', 'dress']
37
  category_dict_utils = ['upper_body', 'lower_body', 'dresses']
38
 
39
- torch.cuda.empty_cache()
40
 
 
 
 
 
 
41
 
42
- # Créer une instance FastAPI
43
- app = Flask(__name__)
44
 
45
- def save_image(img):
46
- unique_name = str(uuid.uuid4()) + ".png"
47
- img.save(unique_name)
48
- return unique_name
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
- def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
51
- if randomize_seed:
52
- seed = random.randint(0, MAX_SEED)
53
- return seed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
- # Fonction pour décoder une image encodée en base64 en objet PIL.Image.Image
56
- def decode_image_from_base64(image_data):
57
- image_data = base64.b64decode(image_data)
58
- image = Image.open(io.BytesIO(image_data))
59
- return image
60
 
61
  @spaces.GPU
62
- def process_hd(vton_img, garm_img,category, n_samples, n_steps, image_scale, seed):
63
  model_type = 'dc'
 
 
 
 
 
 
 
64
  with torch.no_grad():
65
  openpose_model_dc.preprocessor.body_estimation.model.to('cuda')
66
  ootd_model_dc.pipe.to('cuda')
67
  ootd_model_dc.image_encoder.to('cuda')
68
  ootd_model_dc.text_encoder.to('cuda')
69
-
70
- garm_img = decode_image_from_base64(garm_img).resize((768, 1024))
71
- vton_img = decode_image_from_base64(vton_img).resize((768, 1024))
72
  keypoints = openpose_model_dc(vton_img.resize((384, 512)))
73
  model_parse, _ = parsing_model_dc(vton_img.resize((384, 512)))
74
 
75
-
76
  mask, mask_gray = get_mask_location(model_type, category_dict_utils[category], model_parse, keypoints)
77
  mask = mask.resize((768, 1024), Image.NEAREST)
78
  mask_gray = mask_gray.resize((768, 1024), Image.NEAREST)
@@ -95,37 +118,8 @@ def process_hd(vton_img, garm_img,category, n_samples, n_steps, image_scale, see
95
  return images
96
 
97
 
98
- @app.get("/")
99
- def root():
100
- return "Welcome to the Fashion OOTDiffusion API "
101
-
102
- # Route pour récupérer l'image générée
103
- @app.route('/api/get_image/<image_id>', methods=['GET'])
104
- def get_image(image_id):
105
- # Construire le chemin complet de l'image
106
- image_path = image_id # Assurez-vous que le nom de fichier correspond à celui que vous avez utilisé lors de la sauvegarde
107
 
108
- # Renvoyer l'image
109
- try:
110
- return send_file(image_path, mimetype='image/png')
111
- except FileNotFoundError:
112
- return jsonify({'error': 'Image not found'}), 404
113
-
114
- # Route pour l'API REST
115
- @spaces.GPU
116
- @app.route('/api/run', methods=['POST'])
117
- def run():
118
- data = request.json
119
- print(data)
120
- vton_img = data['vton_img']
121
- garm_img = data['garm_img']
122
- category = data['category']
123
- n_samples = data['n_samples']
124
- n_steps = data['n_steps']
125
- image_scale = data['image_scale']
126
- seed = data['seed']
127
- result = process_hd(vton_img, garm_img,category, n_samples, n_steps, image_scale, seed)
128
- return jsonify({'out': result})
129
-
130
- if __name__ == "__main__":
131
- app.run(host="0.0.0.0", port=7860)
 
1
+ import gradio as gr
 
 
 
 
 
 
 
2
  import os
3
  from pathlib import Path
4
  import sys
5
  import torch
6
  from PIL import Image, ImageOps
7
 
 
8
  from utils_ootd import get_mask_location
9
 
10
  PROJECT_ROOT = Path(__file__).absolute().parents[1].absolute()
 
28
  category_dict = ['upperbody', 'lowerbody', 'dress']
29
  category_dict_utils = ['upper_body', 'lower_body', 'dresses']
30
 
 
31
 
32
+ example_path = os.path.join(os.path.dirname(__file__), 'examples')
33
+ model_hd = os.path.join(example_path, 'model/model_1.png')
34
+ garment_hd = os.path.join(example_path, 'garment/03244_00.jpg')
35
+ model_dc = os.path.join(example_path, 'model/model_8.png')
36
+ garment_dc = os.path.join(example_path, 'garment/048554_1.jpg')
37
 
 
 
38
 
39
+ import spaces
40
+
41
+ @spaces.GPU
42
+ def process_hd(vton_img, garm_img, n_samples, n_steps, image_scale, seed):
43
+ model_type = 'hd'
44
+ category = 0 # 0:upperbody; 1:lowerbody; 2:dress
45
+
46
+ with torch.no_grad():
47
+ openpose_model_hd.preprocessor.body_estimation.model.to('cuda')
48
+ ootd_model_hd.pipe.to('cuda')
49
+ ootd_model_hd.image_encoder.to('cuda')
50
+ ootd_model_hd.text_encoder.to('cuda')
51
+
52
+ garm_img = Image.open(garm_img).resize((768, 1024))
53
+ vton_img = Image.open(vton_img).resize((768, 1024))
54
+ keypoints = openpose_model_hd(vton_img.resize((384, 512)))
55
+ model_parse, _ = parsing_model_hd(vton_img.resize((384, 512)))
56
 
57
+ mask, mask_gray = get_mask_location(model_type, category_dict_utils[category], model_parse, keypoints)
58
+ mask = mask.resize((768, 1024), Image.NEAREST)
59
+ mask_gray = mask_gray.resize((768, 1024), Image.NEAREST)
60
+
61
+ masked_vton_img = Image.composite(mask_gray, vton_img, mask)
62
+
63
+ images = ootd_model_hd(
64
+ model_type=model_type,
65
+ category=category_dict[category],
66
+ image_garm=garm_img,
67
+ image_vton=masked_vton_img,
68
+ mask=mask,
69
+ image_ori=vton_img,
70
+ num_samples=n_samples,
71
+ num_steps=n_steps,
72
+ image_scale=image_scale,
73
+ seed=seed,
74
+ )
75
 
76
+ return images
 
 
 
 
77
 
78
  @spaces.GPU
79
+ def process_dc(vton_img, garm_img, category, n_samples, n_steps, image_scale, seed):
80
  model_type = 'dc'
81
+ if category == 'Upper-body':
82
+ category = 0
83
+ elif category == 'Lower-body':
84
+ category = 1
85
+ else:
86
+ category =2
87
+
88
  with torch.no_grad():
89
  openpose_model_dc.preprocessor.body_estimation.model.to('cuda')
90
  ootd_model_dc.pipe.to('cuda')
91
  ootd_model_dc.image_encoder.to('cuda')
92
  ootd_model_dc.text_encoder.to('cuda')
93
+
94
+ garm_img = Image.open(garm_img).resize((768, 1024))
95
+ vton_img = Image.open(vton_img).resize((768, 1024))
96
  keypoints = openpose_model_dc(vton_img.resize((384, 512)))
97
  model_parse, _ = parsing_model_dc(vton_img.resize((384, 512)))
98
 
 
99
  mask, mask_gray = get_mask_location(model_type, category_dict_utils[category], model_parse, keypoints)
100
  mask = mask.resize((768, 1024), Image.NEAREST)
101
  mask_gray = mask_gray.resize((768, 1024), Image.NEAREST)
 
118
  return images
119
 
120
 
121
+ block = gr.Interface(fn=process_hd, inputs=["image", "image", "number", "number", "number", "number"], outputs="image", title="OOTDiffusion Demo HD")
122
+ block.launch()
 
 
 
 
 
 
 
123
 
124
+ block_dc = gr.Interface(fn=process_dc, inputs=["image", "image", "dropdown", "number", "number", "number", "number"], outputs="image", title="OOTDiffusion Demo DC")
125
+ block_dc.launch(api_name='generate')