Saad0KH commited on
Commit
30ffa26
1 Parent(s): 50515cb

Create inference_ootd.py

Browse files
Files changed (1) hide show
  1. ootd/inference_ootd.py +125 -0
ootd/inference_ootd.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import os
from pathlib import Path
import sys
import torch
from PIL import Image, ImageOps

from utils_ootd import get_mask_location

# Make the repository root importable so the `preprocess` and `ootd` packages
# below resolve regardless of the current working directory.
PROJECT_ROOT = Path(__file__).absolute().parents[1].absolute()
sys.path.insert(0, str(PROJECT_ROOT))

from preprocess.openpose.run_openpose import OpenPose
from preprocess.humanparsing.run_parsing import Parsing
from ootd.inference_ootd_hd import OOTDiffusionHD
from ootd.inference_ootd_dc import OOTDiffusionDC


# Half-body ("HD") pipeline: pose estimator, human parser, and diffusion model.
# The constructor argument is presumably a GPU/device index — TODO confirm
# against the OpenPose/Parsing/OOTDiffusion* signatures.
openpose_model_hd = OpenPose(0)
parsing_model_hd = Parsing(0)
ootd_model_hd = OOTDiffusionHD(0)

# Full-body ("DC") pipeline, built on a second device index.
openpose_model_dc = OpenPose(1)
parsing_model_dc = Parsing(1)
ootd_model_dc = OOTDiffusionDC(1)


# Category names as spelled by the OOTD models (index-aligned) ...
category_dict = ['upperbody', 'lowerbody', 'dress']
# ... and as spelled by utils_ootd.get_mask_location.
category_dict_utils = ['upper_body', 'lower_body', 'dresses']


# Bundled example images shipped next to this module, used as demo defaults.
example_path = os.path.join(os.path.dirname(__file__), 'examples')
model_hd = os.path.join(example_path, 'model/model_1.png')
garment_hd = os.path.join(example_path, 'garment/03244_00.jpg')
model_dc = os.path.join(example_path, 'model/model_8.png')
garment_dc = os.path.join(example_path, 'garment/048554_1.jpg')


# NOTE(review): mid-file import — `spaces` supplies the @spaces.GPU decorator
# (Hugging Face Spaces ZeroGPU); conventionally this belongs with the imports
# at the top of the file.
import spaces
41
@spaces.GPU
def process_hd(vton_img, garm_img, n_samples, n_steps, image_scale, seed):
    """Run the half-body (HD) try-on pipeline for an upper-body garment.

    Args:
        vton_img: path to the person ("virtual try-on") image.
        garm_img: path to the garment image.
        n_samples: number of output samples to generate.
        n_steps: number of diffusion steps.
        image_scale: guidance scale passed to the diffusion model.
        seed: random seed for reproducibility.

    Returns:
        The list of generated images from the HD OOTDiffusion model.
    """
    model_type = 'hd'
    category = 0  # HD pipeline is fixed to 'upperbody' (see category_dict)
    full_res = (768, 1024)
    half_res = (384, 512)

    with torch.no_grad():
        # Bring the HD models onto the GPU for this request (ZeroGPU pattern).
        for module in (
            openpose_model_hd.preprocessor.body_estimation.model,
            ootd_model_hd.pipe,
            ootd_model_hd.image_encoder,
            ootd_model_hd.text_encoder,
        ):
            module.to('cuda')

        garment = Image.open(garm_img).resize(full_res)
        person = Image.open(vton_img).resize(full_res)

        # Pose and parsing run at half resolution.
        keypoints = openpose_model_hd(person.resize(half_res))
        model_parse, _ = parsing_model_hd(person.resize(half_res))

        # Build the inpainting mask for the chosen garment category and
        # upsample it back to full resolution.
        mask, mask_gray = get_mask_location(
            model_type, category_dict_utils[category], model_parse, keypoints)
        mask = mask.resize(full_res, Image.NEAREST)
        mask_gray = mask_gray.resize(full_res, Image.NEAREST)

        # Grey out the masked region of the person image.
        masked_vton_img = Image.composite(mask_gray, person, mask)

        images = ootd_model_hd(
            model_type=model_type,
            category=category_dict[category],
            image_garm=garment,
            image_vton=masked_vton_img,
            mask=mask,
            image_ori=person,
            num_samples=n_samples,
            num_steps=n_steps,
            image_scale=image_scale,
            seed=seed,
        )

    return images
77
+
78
@spaces.GPU
def process_dc(vton_img, garm_img, category, n_samples, n_steps, image_scale, seed):
    """Run the full-body (DC) try-on pipeline for an arbitrary garment type.

    Args:
        vton_img: path to the person ("virtual try-on") image.
        garm_img: path to the garment image.
        category: garment type label — 'Upper-body', 'Lower-body', or
            anything else (treated as dress).
        n_samples: number of output samples to generate.
        n_steps: number of diffusion steps.
        image_scale: guidance scale passed to the diffusion model.
        seed: random seed for reproducibility.

    Returns:
        The list of generated images from the DC OOTDiffusion model.
    """
    model_type = 'dc'
    # Map the UI label to a category index; any unrecognized label falls
    # back to 2 (dress), matching the original if/elif/else chain.
    category = {'Upper-body': 0, 'Lower-body': 1}.get(category, 2)
    full_res = (768, 1024)
    half_res = (384, 512)

    with torch.no_grad():
        # Bring the DC models onto the GPU for this request (ZeroGPU pattern).
        for module in (
            openpose_model_dc.preprocessor.body_estimation.model,
            ootd_model_dc.pipe,
            ootd_model_dc.image_encoder,
            ootd_model_dc.text_encoder,
        ):
            module.to('cuda')

        garment = Image.open(garm_img).resize(full_res)
        person = Image.open(vton_img).resize(full_res)

        # Pose and parsing run at half resolution.
        keypoints = openpose_model_dc(person.resize(half_res))
        model_parse, _ = parsing_model_dc(person.resize(half_res))

        # Build the inpainting mask for the chosen garment category and
        # upsample it back to full resolution.
        mask, mask_gray = get_mask_location(
            model_type, category_dict_utils[category], model_parse, keypoints)
        mask = mask.resize(full_res, Image.NEAREST)
        mask_gray = mask_gray.resize(full_res, Image.NEAREST)

        # Grey out the masked region of the person image.
        masked_vton_img = Image.composite(mask_gray, person, mask)

        images = ootd_model_dc(
            model_type=model_type,
            category=category_dict[category],
            image_garm=garment,
            image_vton=masked_vton_img,
            mask=mask,
            image_ori=person,
            num_samples=n_samples,
            num_steps=n_steps,
            image_scale=image_scale,
            seed=seed,
        )

    return images
119
+
120
+
121
# Build both demo UIs and serve them from ONE app.
#
# Bug fixes vs. the original:
#   * `block.launch()` blocks the main thread, so the DC interface was never
#     constructed or launched — both interfaces now live in a TabbedInterface
#     behind a single launch().
#   * `launch(api_name='generate')` raised TypeError: `launch()` has no
#     `api_name` parameter; it belongs on `gr.Interface` itself.
#   * The DC category dropdown had no choices, so process_dc could never
#     receive 'Upper-body'/'Lower-body' — choices are now populated.
block = gr.Interface(
    fn=process_hd,
    inputs=["image", "image", "number", "number", "number", "number"],
    outputs="image",
    title="OOTDiffusion Demo HD",
)

block_dc = gr.Interface(
    fn=process_dc,
    inputs=[
        "image",
        "image",
        gr.Dropdown(choices=["Upper-body", "Lower-body", "Dress"], value="Upper-body"),
        "number",
        "number",
        "number",
        "number",
    ],
    outputs="image",
    title="OOTDiffusion Demo DC",
    api_name="generate",
)

demo = gr.TabbedInterface([block, block_dc], ["HD", "DC"])
demo.launch()