mrtuandao committed
Commit 62e05d1
1 Parent(s): 1fd8aab

Upload folder using huggingface_hub

.ipynb_checkpoints/app-checkpoint.py CHANGED
@@ -309,5 +309,5 @@ with image_blocks as demo:
 
 
 
-image_blocks.launch(share=True)
+image_blocks.launch()
 
.ipynb_checkpoints/call_api-checkpoint.py ADDED
@@ -0,0 +1,41 @@
+import runpod
+import json
+import os
+import requests
+import time
+
+# json test file name
+json_test_file = "test_input.json"
+
+# open test file
+with open(json_test_file) as f:
+    input = json.load(f)
+
+# load runpod api key and serverless model id
+runpod.api_key = ""
+RUNPOD_MODEL_ID = ""
+
+endpoint = runpod.Endpoint(RUNPOD_MODEL_ID)
+
+start = time.time()
+
+# First way to call serverless api
+run_request = endpoint.run_sync(input)
+print(run_request)
+
+end = time.time()
+print('Time taken: ', end-start)
+
+# Second way to call serverless api
+url = f'https://api.runpod.ai/v2/{RUNPOD_MODEL_ID}/run_sync'  # or change to runsync
+headers = {
+    'accept': 'application/json',
+    'Content-Type': 'application/json',
+    'Authorization': f'Bearer {runpod.api_key}'
+}
+start = time.time()
+response = requests.post(url, headers=headers, data=json.dumps(input))
+print('\n')
+print(response.json())
+end = time.time()
+print('Time taken: ', end-start)
.ipynb_checkpoints/handler-checkpoint.py ADDED
@@ -0,0 +1,25 @@
+import runpod
+from helpers import prepare_pipeline, get_result, b64_to_pil
+import base64
+from PIL import Image
+
+def handler(job):
+    human_img_b64 = job['input']['human_img_b64']
+    human_img = b64_to_pil(human_img_b64)
+
+    garm_img_b64 = job['input']['garm_img_b64']
+    garm_img = b64_to_pil(garm_img_b64)
+
+    denoise_steps = job['input'].get('denoise_steps') if job['input'].get('denoise_steps') else 30
+
+    seed = job['input'].get('seed') if job['input'].get('seed') else 42
+
+    is_checked_crop = job['input'].get('is_checked_crop') if job['input'].get('is_checked_crop') else False
+
+    garment_des = job['input'].get('garment_des') if job['input'].get('garment_des') else ""
+
+    result = get_result(PIPE, human_img, garm_img, denoise_steps, seed, is_checked_crop, garment_des)
+    return pil_to_b64(result)
+
+PIPE = prepare_pipeline()
+runpod.serverless.start({"handler": handler})
.ipynb_checkpoints/helpers-checkpoint.py ADDED
@@ -0,0 +1,236 @@
+import base64
+from io import BytesIO
+from PIL import Image
+from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
+from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
+from src.unet_hacked_tryon import UNet2DConditionModel
+from transformers import (
+    CLIPImageProcessor,
+    CLIPVisionModelWithProjection,
+    CLIPTextModel,
+    CLIPTextModelWithProjection,
+)
+from diffusers import DDPMScheduler, AutoencoderKL
+from typing import List
+
+import torch
+import os
+from transformers import AutoTokenizer
+import numpy as np
+from utils_mask import get_mask_location
+from torchvision import transforms
+import apply_net
+from preprocess.humanparsing.run_parsing import Parsing
+from preprocess.openpose.run_openpose import OpenPose
+from detectron2.data.detection_utils import convert_PIL_to_numpy, _apply_exif_orientation
+from torchvision.transforms.functional import to_pil_image
+
+
+def b64_to_pil(base64_string):
+    # Decode the base64 string
+    image_data = base64.b64decode(base64_string)
+
+    # Create a PIL Image object from the decoded image data
+    image = Image.open(BytesIO(image_data))
+    return image
+
+def prepare_pipeline():
+    pass
+base_path = 'yisol/IDM-VTON'
+example_path = os.path.join(os.path.dirname(__file__), 'example')
+
+unet = UNet2DConditionModel.from_pretrained(
+    base_path,
+    subfolder="unet",
+    torch_dtype=torch.float16,
+)
+unet.requires_grad_(False)
+tokenizer_one = AutoTokenizer.from_pretrained(
+    base_path,
+    subfolder="tokenizer",
+    revision=None,
+    use_fast=False,
+)
+tokenizer_two = AutoTokenizer.from_pretrained(
+    base_path,
+    subfolder="tokenizer_2",
+    revision=None,
+    use_fast=False,
+)
+noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
+
+text_encoder_one = CLIPTextModel.from_pretrained(
+    base_path,
+    subfolder="text_encoder",
+    torch_dtype=torch.float16,
+)
+text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
+    base_path,
+    subfolder="text_encoder_2",
+    torch_dtype=torch.float16,
+)
+image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+    base_path,
+    subfolder="image_encoder",
+    torch_dtype=torch.float16,
+)
+vae = AutoencoderKL.from_pretrained(base_path,
+                                    subfolder="vae",
+                                    torch_dtype=torch.float16,
+                                    )
+
+# "stabilityai/stable-diffusion-xl-base-1.0",
+UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
+    base_path,
+    subfolder="unet_encoder",
+    torch_dtype=torch.float16,
+)
+
+parsing_model = Parsing(0)
+openpose_model = OpenPose(0)
+
+UNet_Encoder.requires_grad_(False)
+image_encoder.requires_grad_(False)
+vae.requires_grad_(False)
+unet.requires_grad_(False)
+text_encoder_one.requires_grad_(False)
+text_encoder_two.requires_grad_(False)
+tensor_transfrom = transforms.Compose(
+    [
+        transforms.ToTensor(),
+        transforms.Normalize([0.5], [0.5]),
+    ]
+)
+
+pipe = TryonPipeline.from_pretrained(
+    base_path,
+    unet=unet,
+    vae=vae,
+    feature_extractor=CLIPImageProcessor(),
+    text_encoder=text_encoder_one,
+    text_encoder_2=text_encoder_two,
+    tokenizer=tokenizer_one,
+    tokenizer_2=tokenizer_two,
+    scheduler=noise_scheduler,
+    image_encoder=image_encoder,
+    torch_dtype=torch.float16,
+)
+pipe.unet_encoder = UNet_Encoder
+
+def get_result(human_img, garm_img, body_part="upper_body", denoise_steps=30, seed=42, is_checked_crop=False, garment_des=""):
+    device = "cuda"
+
+    openpose_model.preprocessor.body_estimation.model.to(device)
+    pipe.to(device)
+    pipe.unet_encoder.to(device)
+
+    garm_img = garm_img.convert("RGB").resize((768, 1024))
+    human_img_orig = human_img
+
+    if is_checked_crop:
+        width, height = human_img_orig.size
+        target_width = int(min(width, height * (3 / 4)))
+        target_height = int(min(height, width * (4 / 3)))
+        left = (width - target_width) / 2
+        top = (height - target_height) / 2
+        right = (width + target_width) / 2
+        bottom = (height + target_height) / 2
+        cropped_img = human_img_orig.crop((left, top, right, bottom))
+        crop_size = cropped_img.size
+        human_img = cropped_img.resize((768, 1024))
+    else:
+        human_img = human_img_orig.resize((768, 1024))
+
+
+    keypoints = openpose_model(human_img.resize((384, 512)))
+    model_parse, _ = parsing_model(human_img.resize((384, 512)))
+    mask, mask_gray = get_mask_location('hd', body_part, model_parse, keypoints)
+    mask = mask.resize((768, 1024))
+
+    mask_gray = (1 - transforms.ToTensor()(mask)) * tensor_transfrom(human_img)
+    mask_gray = to_pil_image((mask_gray + 1.0) / 2.0)
+
+
+    human_img_arg = _apply_exif_orientation(human_img.resize((384, 512)))
+    human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")
+
+
+
+    args = apply_net.create_argument_parser().parse_args(('show', './configs/densepose_rcnn_R_50_FPN_s1x.yaml', './ckpt/densepose/model_final_162be9.pkl', 'dp_segm', '-v', '--opts', 'MODEL.DEVICE', 'cuda'))
+    # verbosity = getattr(args, "verbosity", None)
+    pose_img = args.func(args, human_img_arg)
+    pose_img = pose_img[:, :, ::-1]
+    pose_img = Image.fromarray(pose_img).resize((768, 1024))
+
+    with torch.no_grad():
+        # Extract the images
+        with torch.cuda.amp.autocast():
+            with torch.no_grad():
+                prompt = "model is wearing " + garment_des
+                negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+                with torch.inference_mode():
+                    (
+                        prompt_embeds,
+                        negative_prompt_embeds,
+                        pooled_prompt_embeds,
+                        negative_pooled_prompt_embeds,
+                    ) = pipe.encode_prompt(
+                        prompt,
+                        num_images_per_prompt=1,
+                        do_classifier_free_guidance=True,
+                        negative_prompt=negative_prompt,
+                    )
+
+                    prompt = "a photo of " + garment_des
+                    negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+                    if not isinstance(prompt, List):
+                        prompt = [prompt] * 1
+                    if not isinstance(negative_prompt, List):
+                        negative_prompt = [negative_prompt] * 1
+                    with torch.inference_mode():
+                        (
+                            prompt_embeds_c,
+                            _,
+                            _,
+                            _,
+                        ) = pipe.encode_prompt(
+                            prompt,
+                            num_images_per_prompt=1,
+                            do_classifier_free_guidance=False,
+                            negative_prompt=negative_prompt,
+                        )
+
+
+
+                        pose_img = tensor_transfrom(pose_img).unsqueeze(0).to(device, torch.float16)
+                        garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device, torch.float16)
+                        generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
+                        images = pipe(
+                            prompt_embeds=prompt_embeds.to(device, torch.float16),
+                            negative_prompt_embeds=negative_prompt_embeds.to(device, torch.float16),
+                            pooled_prompt_embeds=pooled_prompt_embeds.to(device, torch.float16),
+                            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device, torch.float16),
+                            num_inference_steps=denoise_steps,
+                            generator=generator,
+                            strength=1.0,
+                            pose_img=pose_img.to(device, torch.float16),
+                            text_embeds_cloth=prompt_embeds_c.to(device, torch.float16),
+                            cloth=garm_tensor.to(device, torch.float16),
+                            mask_image=mask,
+                            image=human_img,
+                            height=1024,
+                            width=768,
+                            ip_adapter_image=garm_img.resize((768, 1024)),
+                            guidance_scale=2.0,
+                        )[0]
+
+    if is_checked_crop:
+        out_img = images[0].resize(crop_size)
+        human_img_orig.paste(out_img, (int(left), int(top)))
+        return human_img_orig, mask_gray
+    else:
+        return images[0], mask_gray
+        # return images[0], mask_gray
+
+
+
.ipynb_checkpoints/streamlit_code-checkpoint.py ADDED
@@ -0,0 +1,81 @@
+import streamlit as st
+from streamlit_image_select import image_select
+from helpers import get_result
+from PIL import Image
+import numpy as np
+
+# UI configurations
+st.set_page_config(page_title="AIFR - Demo",
+                   page_icon=":bridge_at_night:",
+                   layout="wide")
+st.markdown("# :rainbow[AIFR - Demo]")
+# 4 columns
+col1, col2, col3, col4 = st.columns(4)
+
+with col1:
+    st.header("User Image")
+
+    user_image_holder = st.empty()
+    # upload file
+    user_image = st.file_uploader("Upload User Image")
+    if user_image is not None:
+        img = None
+        user_image_holder.image(user_image, use_column_width=True)
+
+    # st.write("Examples")
+    # img1 = image_select(
+    #     label="Select a cat",
+    #     images=[
+    #         "example1.jpg",
+    #         "example2.jpg"
+    #     ],
+    #     captions=["A cat", "Another cat"],
+    # )
+    # if img1 and user_image is None:
+    #     user_image = img1
+    #     user_image_holder.image(user_image, use_column_width=True)
+
+with col2:
+    st.header("Clothes Image")
+
+    clothes_image_holder = st.empty()
+    # upload file
+    clothes_image = st.file_uploader("Upload Clothes Image")
+    if clothes_image is not None:
+        clothes_image_holder.image(clothes_image, use_column_width=True)
+
+    # st.write("Examples")
+    # img2 = image_select(
+    #     label="Select a dress",
+    #     images=[
+    #         "https://bagongkia.github.io/react-image-picker/0759b6e526e3c6d72569894e58329d89.jpg",
+    #         "https://bagongkia.github.io/react-image-picker/0759b6e526e3c6d72569894e58329d89.jpg"
+    #     ],
+    #     captions=["A dress", "Another dress"],
+    # )
+
+    # if img2 and clothes_image is None:
+    #     clothes_image = img2
+    #     clothes_image_holder.image(clothes_image, use_column_width=True)
+    body_part = st.selectbox(
+        "Choose your body part",
+        ("dresses", "upper_body", "lower_body"))
+    submitted = st.button("Get result", use_container_width=True, type="primary")
+    output_image = mask_image = None
+
+if submitted:
+    user_image = Image.open(user_image)
+    clothes_image = Image.open(clothes_image)
+    output_image, mask_image = get_result(user_image, clothes_image, body_part=body_part)
+
+with col3:
+    st.header("Masked Image output")
+    if submitted:
+        if mask_image is not None:
+            st.image(mask_image, use_column_width=True)
+
+with col4:
+    st.header("Output")
+    if submitted:
+        if output_image is not None:
+            st.image(output_image, use_column_width=True)
app.py CHANGED
@@ -309,5 +309,5 @@ with image_blocks as demo:
 
 
 
-image_blocks.launch(share=True)
+image_blocks.launch()
 
call_api.py ADDED
@@ -0,0 +1,41 @@
+import runpod
+import json
+import os
+import requests
+import time
+
+# json test file name
+json_test_file = "test_input.json"
+
+# open test file
+with open(json_test_file) as f:
+    input = json.load(f)
+
+# load runpod api key and serverless model id
+runpod.api_key = ""
+RUNPOD_MODEL_ID = ""
+
+endpoint = runpod.Endpoint(RUNPOD_MODEL_ID)
+
+start = time.time()
+
+# First way to call serverless api
+run_request = endpoint.run_sync(input)
+print(run_request)
+
+end = time.time()
+print('Time taken: ', end-start)
+
+# Second way to call serverless api
+url = f'https://api.runpod.ai/v2/{RUNPOD_MODEL_ID}/run_sync'  # or change to runsync
+headers = {
+    'accept': 'application/json',
+    'Content-Type': 'application/json',
+    'Authorization': f'Bearer {runpod.api_key}'
+}
+start = time.time()
+response = requests.post(url, headers=headers, data=json.dumps(input))
+print('\n')
+print(response.json())
+end = time.time()
+print('Time taken: ', end-start)
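
call_api.py reads its request body from test_input.json, which is not included in this commit. Below is a minimal sketch of how such a payload could be generated, assuming the input keys that handler.py reads (human_img_b64, garm_img_b64, denoise_steps, seed, is_checked_crop, garment_des); the image paths and garment description are placeholders, not files from this repository.

import base64
import json

def encode_image(path):
    # Read a local image file and return its contents as a base64 string.
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")

payload = {
    "input": {
        "human_img_b64": encode_image("person.jpg"),    # placeholder path
        "garm_img_b64": encode_image("garment.jpg"),    # placeholder path
        "denoise_steps": 30,
        "seed": 42,
        "is_checked_crop": False,
        "garment_des": "a red cotton t-shirt",          # placeholder description
    }
}

with open("test_input.json", "w") as f:
    json.dump(payload, f)

The outer "input" key matches how handler.py indexes job['input'].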
handler.py ADDED
@@ -0,0 +1,25 @@
+import runpod
+from helpers import prepare_pipeline, get_result, b64_to_pil
+import base64
+from PIL import Image
+
+def handler(job):
+    human_img_b64 = job['input']['human_img_b64']
+    human_img = b64_to_pil(human_img_b64)
+
+    garm_img_b64 = job['input']['garm_img_b64']
+    garm_img = b64_to_pil(garm_img_b64)
+
+    denoise_steps = job['input'].get('denoise_steps') if job['input'].get('denoise_steps') else 30
+
+    seed = job['input'].get('seed') if job['input'].get('seed') else 42
+
+    is_checked_crop = job['input'].get('is_checked_crop') if job['input'].get('is_checked_crop') else False
+
+    garment_des = job['input'].get('garment_des') if job['input'].get('garment_des') else ""
+
+    result = get_result(PIPE, human_img, garm_img, denoise_steps, seed, is_checked_crop, garment_des)
+    return pil_to_b64(result)
+
+PIPE = prepare_pipeline()
+runpod.serverless.start({"handler": handler})
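
handler.py returns pil_to_b64(result), but no pil_to_b64 is imported or defined anywhere in this commit (helpers.py only provides b64_to_pil), and helpers.get_result returns an (image, mask) tuple rather than a single image, so the result would also need to be unpacked before encoding. A minimal sketch of the missing helper, assuming a PNG round-trip is acceptable:

import base64
from io import BytesIO
from PIL import Image

def pil_to_b64(image: Image.Image) -> str:
    # Serialize a PIL image to PNG bytes and return them as a base64 string.
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")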
helpers.py ADDED
@@ -0,0 +1,236 @@
+import base64
+from io import BytesIO
+from PIL import Image
+from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
+from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
+from src.unet_hacked_tryon import UNet2DConditionModel
+from transformers import (
+    CLIPImageProcessor,
+    CLIPVisionModelWithProjection,
+    CLIPTextModel,
+    CLIPTextModelWithProjection,
+)
+from diffusers import DDPMScheduler, AutoencoderKL
+from typing import List
+
+import torch
+import os
+from transformers import AutoTokenizer
+import numpy as np
+from utils_mask import get_mask_location
+from torchvision import transforms
+import apply_net
+from preprocess.humanparsing.run_parsing import Parsing
+from preprocess.openpose.run_openpose import OpenPose
+from detectron2.data.detection_utils import convert_PIL_to_numpy, _apply_exif_orientation
+from torchvision.transforms.functional import to_pil_image
+
+
+def b64_to_pil(base64_string):
+    # Decode the base64 string
+    image_data = base64.b64decode(base64_string)
+
+    # Create a PIL Image object from the decoded image data
+    image = Image.open(BytesIO(image_data))
+    return image
+
+def prepare_pipeline():
+    pass
+base_path = 'yisol/IDM-VTON'
+example_path = os.path.join(os.path.dirname(__file__), 'example')
+
+unet = UNet2DConditionModel.from_pretrained(
+    base_path,
+    subfolder="unet",
+    torch_dtype=torch.float16,
+)
+unet.requires_grad_(False)
+tokenizer_one = AutoTokenizer.from_pretrained(
+    base_path,
+    subfolder="tokenizer",
+    revision=None,
+    use_fast=False,
+)
+tokenizer_two = AutoTokenizer.from_pretrained(
+    base_path,
+    subfolder="tokenizer_2",
+    revision=None,
+    use_fast=False,
+)
+noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
+
+text_encoder_one = CLIPTextModel.from_pretrained(
+    base_path,
+    subfolder="text_encoder",
+    torch_dtype=torch.float16,
+)
+text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
+    base_path,
+    subfolder="text_encoder_2",
+    torch_dtype=torch.float16,
+)
+image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+    base_path,
+    subfolder="image_encoder",
+    torch_dtype=torch.float16,
+)
+vae = AutoencoderKL.from_pretrained(base_path,
+                                    subfolder="vae",
+                                    torch_dtype=torch.float16,
+                                    )
+
+# "stabilityai/stable-diffusion-xl-base-1.0",
+UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
+    base_path,
+    subfolder="unet_encoder",
+    torch_dtype=torch.float16,
+)
+
+parsing_model = Parsing(0)
+openpose_model = OpenPose(0)
+
+UNet_Encoder.requires_grad_(False)
+image_encoder.requires_grad_(False)
+vae.requires_grad_(False)
+unet.requires_grad_(False)
+text_encoder_one.requires_grad_(False)
+text_encoder_two.requires_grad_(False)
+tensor_transfrom = transforms.Compose(
+    [
+        transforms.ToTensor(),
+        transforms.Normalize([0.5], [0.5]),
+    ]
+)
+
+pipe = TryonPipeline.from_pretrained(
+    base_path,
+    unet=unet,
+    vae=vae,
+    feature_extractor=CLIPImageProcessor(),
+    text_encoder=text_encoder_one,
+    text_encoder_2=text_encoder_two,
+    tokenizer=tokenizer_one,
+    tokenizer_2=tokenizer_two,
+    scheduler=noise_scheduler,
+    image_encoder=image_encoder,
+    torch_dtype=torch.float16,
+)
+pipe.unet_encoder = UNet_Encoder
+
+def get_result(human_img, garm_img, body_part="upper_body", denoise_steps=30, seed=42, is_checked_crop=False, garment_des=""):
+    device = "cuda"
+
+    openpose_model.preprocessor.body_estimation.model.to(device)
+    pipe.to(device)
+    pipe.unet_encoder.to(device)
+
+    garm_img = garm_img.convert("RGB").resize((768, 1024))
+    human_img_orig = human_img
+
+    if is_checked_crop:
+        width, height = human_img_orig.size
+        target_width = int(min(width, height * (3 / 4)))
+        target_height = int(min(height, width * (4 / 3)))
+        left = (width - target_width) / 2
+        top = (height - target_height) / 2
+        right = (width + target_width) / 2
+        bottom = (height + target_height) / 2
+        cropped_img = human_img_orig.crop((left, top, right, bottom))
+        crop_size = cropped_img.size
+        human_img = cropped_img.resize((768, 1024))
+    else:
+        human_img = human_img_orig.resize((768, 1024))
+
+
+    keypoints = openpose_model(human_img.resize((384, 512)))
+    model_parse, _ = parsing_model(human_img.resize((384, 512)))
+    mask, mask_gray = get_mask_location('hd', body_part, model_parse, keypoints)
+    mask = mask.resize((768, 1024))
+
+    mask_gray = (1 - transforms.ToTensor()(mask)) * tensor_transfrom(human_img)
+    mask_gray = to_pil_image((mask_gray + 1.0) / 2.0)
+
+
+    human_img_arg = _apply_exif_orientation(human_img.resize((384, 512)))
+    human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")
+
+
+
+    args = apply_net.create_argument_parser().parse_args(('show', './configs/densepose_rcnn_R_50_FPN_s1x.yaml', './ckpt/densepose/model_final_162be9.pkl', 'dp_segm', '-v', '--opts', 'MODEL.DEVICE', 'cuda'))
+    # verbosity = getattr(args, "verbosity", None)
+    pose_img = args.func(args, human_img_arg)
+    pose_img = pose_img[:, :, ::-1]
+    pose_img = Image.fromarray(pose_img).resize((768, 1024))
+
+    with torch.no_grad():
+        # Extract the images
+        with torch.cuda.amp.autocast():
+            with torch.no_grad():
+                prompt = "model is wearing " + garment_des
+                negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+                with torch.inference_mode():
+                    (
+                        prompt_embeds,
+                        negative_prompt_embeds,
+                        pooled_prompt_embeds,
+                        negative_pooled_prompt_embeds,
+                    ) = pipe.encode_prompt(
+                        prompt,
+                        num_images_per_prompt=1,
+                        do_classifier_free_guidance=True,
+                        negative_prompt=negative_prompt,
+                    )
+
+                    prompt = "a photo of " + garment_des
+                    negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+                    if not isinstance(prompt, List):
+                        prompt = [prompt] * 1
+                    if not isinstance(negative_prompt, List):
+                        negative_prompt = [negative_prompt] * 1
+                    with torch.inference_mode():
+                        (
+                            prompt_embeds_c,
+                            _,
+                            _,
+                            _,
+                        ) = pipe.encode_prompt(
+                            prompt,
+                            num_images_per_prompt=1,
+                            do_classifier_free_guidance=False,
+                            negative_prompt=negative_prompt,
+                        )
+
+
+
+                        pose_img = tensor_transfrom(pose_img).unsqueeze(0).to(device, torch.float16)
+                        garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device, torch.float16)
+                        generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
+                        images = pipe(
+                            prompt_embeds=prompt_embeds.to(device, torch.float16),
+                            negative_prompt_embeds=negative_prompt_embeds.to(device, torch.float16),
+                            pooled_prompt_embeds=pooled_prompt_embeds.to(device, torch.float16),
+                            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device, torch.float16),
+                            num_inference_steps=denoise_steps,
+                            generator=generator,
+                            strength=1.0,
+                            pose_img=pose_img.to(device, torch.float16),
+                            text_embeds_cloth=prompt_embeds_c.to(device, torch.float16),
+                            cloth=garm_tensor.to(device, torch.float16),
+                            mask_image=mask,
+                            image=human_img,
+                            height=1024,
+                            width=768,
+                            ip_adapter_image=garm_img.resize((768, 1024)),
+                            guidance_scale=2.0,
+                        )[0]
+
+    if is_checked_crop:
+        out_img = images[0].resize(crop_size)
+        human_img_orig.paste(out_img, (int(left), int(top)))
+        return human_img_orig, mask_gray
+    else:
+        return images[0], mask_gray
+        # return images[0], mask_gray
+
+
+
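
helpers.py builds the full IDM-VTON pipeline at import time, so get_result can be exercised directly without the RunPod wrapper. A minimal local smoke test, assuming the yisol/IDM-VTON weights can be downloaded and a CUDA GPU is available; the image paths below are placeholders, not files from this commit.

from PIL import Image

from helpers import get_result

human = Image.open("person.jpg")      # placeholder path
garment = Image.open("garment.jpg")   # placeholder path

# get_result returns the try-on image and a grayscale preview of the inpainting mask.
result_img, mask_preview = get_result(
    human,
    garment,
    body_part="upper_body",
    denoise_steps=30,
    seed=42,
    is_checked_crop=False,
    garment_des="a red cotton t-shirt",  # placeholder description
)
result_img.save("tryon_result.png")
mask_preview.save("mask_preview.png")

Because every model component is loaded at module import, even this small script needs the complete set of checkpoints on disk, including the DensePose weights referenced at ./ckpt/densepose/model_final_162be9.pkl.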
maskask.webp ADDED
output.webp ADDED
streamlit_code.py ADDED
@@ -0,0 +1,81 @@
+import streamlit as st
+from streamlit_image_select import image_select
+from helpers import get_result
+from PIL import Image
+import numpy as np
+
+# UI configurations
+st.set_page_config(page_title="AIFR - Demo",
+                   page_icon=":bridge_at_night:",
+                   layout="wide")
+st.markdown("# :rainbow[AIFR - Demo]")
+# 4 columns
+col1, col2, col3, col4 = st.columns(4)
+
+with col1:
+    st.header("User Image")
+
+    user_image_holder = st.empty()
+    # upload file
+    user_image = st.file_uploader("Upload User Image")
+    if user_image is not None:
+        img = None
+        user_image_holder.image(user_image, use_column_width=True)
+
+    # st.write("Examples")
+    # img1 = image_select(
+    #     label="Select a cat",
+    #     images=[
+    #         "example1.jpg",
+    #         "example2.jpg"
+    #     ],
+    #     captions=["A cat", "Another cat"],
+    # )
+    # if img1 and user_image is None:
+    #     user_image = img1
+    #     user_image_holder.image(user_image, use_column_width=True)
+
+with col2:
+    st.header("Clothes Image")
+
+    clothes_image_holder = st.empty()
+    # upload file
+    clothes_image = st.file_uploader("Upload Clothes Image")
+    if clothes_image is not None:
+        clothes_image_holder.image(clothes_image, use_column_width=True)
+
+    # st.write("Examples")
+    # img2 = image_select(
+    #     label="Select a dress",
+    #     images=[
+    #         "https://bagongkia.github.io/react-image-picker/0759b6e526e3c6d72569894e58329d89.jpg",
+    #         "https://bagongkia.github.io/react-image-picker/0759b6e526e3c6d72569894e58329d89.jpg"
+    #     ],
+    #     captions=["A dress", "Another dress"],
+    # )
+
+    # if img2 and clothes_image is None:
+    #     clothes_image = img2
+    #     clothes_image_holder.image(clothes_image, use_column_width=True)
+    body_part = st.selectbox(
+        "Choose your body part",
+        ("dresses", "upper_body", "lower_body"))
+    submitted = st.button("Get result", use_container_width=True, type="primary")
+    output_image = mask_image = None
+
+if submitted:
+    user_image = Image.open(user_image)
+    clothes_image = Image.open(clothes_image)
+    output_image, mask_image = get_result(user_image, clothes_image, body_part=body_part)
+
+with col3:
+    st.header("Masked Image output")
+    if submitted:
+        if mask_image is not None:
+            st.image(mask_image, use_column_width=True)
+
+with col4:
+    st.header("Output")
+    if submitted:
+        if output_image is not None:
+            st.image(output_image, use_column_width=True)