franciszzj committed
Commit 68f6086 • 1 Parent(s): 9b1ec91

update app

Files changed (2)
  1. app.py +151 -108
  2. utils/utils.py +12 -0
app.py CHANGED
@@ -6,90 +6,111 @@ from leffa.model import LeffaModel
 from leffa.inference import LeffaInference
 from utils.garment_agnostic_mask_predictor import AutoMasker
 from utils.densepose_predictor import DensePosePredictor
-from utils.utils import resize_and_center
+from utils.utils import resize_and_center, list_dir
 
 import gradio as gr
 
 # Download checkpoints
 snapshot_download(repo_id="franciszzj/Leffa", local_dir="./ckpts")
 
-mask_predictor = AutoMasker(
-    densepose_path="./ckpts/densepose",
-    schp_path="./ckpts/schp",
-)
-
-densepose_predictor = DensePosePredictor(
-    config_path="./ckpts/densepose/densepose_rcnn_R_50_FPN_s1x.yaml",
-    weights_path="./ckpts/densepose/model_final_162be9.pkl",
-)
-
-vt_model = LeffaModel(
-    pretrained_model_name_or_path="./ckpts/stable-diffusion-inpainting",
-    pretrained_model="./ckpts/virtual_tryon.pth",
-)
-vt_inference = LeffaInference(model=vt_model)
-
-pt_model = LeffaModel(
-    pretrained_model_name_or_path="./ckpts/stable-diffusion-xl-1.0-inpainting-0.1",
-    pretrained_model="./ckpts/pose_transfer.pth",
-)
-pt_inference = LeffaInference(model=pt_model)
-
-
-def leffa_predict(src_image_path, ref_image_path, control_type):
-    assert control_type in [
-        "virtual_tryon", "pose_transfer"], "Invalid control type: {}".format(control_type)
-    src_image = Image.open(src_image_path)
-    ref_image = Image.open(ref_image_path)
-    src_image = resize_and_center(src_image, 768, 1024)
-    ref_image = resize_and_center(ref_image, 768, 1024)
-
-    src_image_array = np.array(src_image)
-    ref_image_array = np.array(ref_image)
-
-    # Mask
-    if control_type == "virtual_tryon":
-        src_image = src_image.convert("RGB")
-        mask = mask_predictor(src_image, "upper")["mask"]
-    elif control_type == "pose_transfer":
-        mask = Image.fromarray(np.ones_like(src_image_array) * 255)
-
-    # DensePose
-    src_image_iuv_array = densepose_predictor.predict_iuv(src_image_array)
-    src_image_seg_array = densepose_predictor.predict_seg(src_image_array)
-    src_image_iuv = Image.fromarray(src_image_iuv_array)
-    src_image_seg = Image.fromarray(src_image_seg_array)
-    if control_type == "virtual_tryon":
-        densepose = src_image_seg
-    elif control_type == "pose_transfer":
-        densepose = src_image_iuv
-
-    # Leffa
-    transform = LeffaTransform()
-
-    data = {
-        "src_image": [src_image],
-        "ref_image": [ref_image],
-        "mask": [mask],
-        "densepose": [densepose],
-    }
-    data = transform(data)
-    if control_type == "virtual_tryon":
-        inference = vt_inference
-    elif control_type == "pose_transfer":
-        inference = pt_inference
-    output = inference(data)
-    gen_image = output["generated_image"][0]
-    # gen_image.save("gen_image.png")
-    return np.array(gen_image)
-
-
-def leffa_predict_vt(src_image_path, ref_image_path):
-    return leffa_predict(src_image_path, ref_image_path, "virtual_tryon")
-
-
-def leffa_predict_pt(src_image_path, ref_image_path):
-    return leffa_predict(src_image_path, ref_image_path, "pose_transfer")
+
+class LeffaPredictor(object):
+    def __init__(self):
+        self.mask_predictor = AutoMasker(
+            densepose_path="./ckpts/densepose",
+            schp_path="./ckpts/schp",
+        )
+
+        self.densepose_predictor = DensePosePredictor(
+            config_path="./ckpts/densepose/densepose_rcnn_R_50_FPN_s1x.yaml",
+            weights_path="./ckpts/densepose/model_final_162be9.pkl",
+        )
+
+        vt_model = LeffaModel(
+            pretrained_model_name_or_path="./ckpts/stable-diffusion-inpainting",
+            pretrained_model="./ckpts/virtual_tryon.pth",
+        )
+        self.vt_inference = LeffaInference(model=vt_model)
+        self.vt_model_type = "viton_hd"
+
+        pt_model = LeffaModel(
+            pretrained_model_name_or_path="./ckpts/stable-diffusion-xl-1.0-inpainting-0.1",
+            pretrained_model="./ckpts/pose_transfer.pth",
+        )
+        self.pt_inference = LeffaInference(model=pt_model)
+
+    def change_vt_model(self, vt_model_type):
+        if vt_model_type == self.vt_model_type:
+            return
+        if vt_model_type == "viton_hd":
+            pretrained_model = "./ckpts/virtual_tryon.pth"
+        elif vt_model_type == "dress_code":
+            pretrained_model = "./ckpts/virtual_tryon_dc.pth"
+        vt_model = LeffaModel(
+            pretrained_model_name_or_path="./ckpts/stable-diffusion-inpainting",
+            pretrained_model=pretrained_model,
+        )
+        self.vt_inference = LeffaInference(model=vt_model)
+        self.vt_model_type = vt_model_type
+
+    def leffa_predict(self, src_image_path, ref_image_path, control_type, step=50, scale=2.5, seed=42):
+        assert control_type in [
+            "virtual_tryon", "pose_transfer"], "Invalid control type: {}".format(control_type)
+        src_image = Image.open(src_image_path)
+        ref_image = Image.open(ref_image_path)
+        src_image = resize_and_center(src_image, 768, 1024)
+        ref_image = resize_and_center(ref_image, 768, 1024)
+
+        src_image_array = np.array(src_image)
+
+        # Mask
+        if control_type == "virtual_tryon":
+            src_image = src_image.convert("RGB")
+            mask = self.mask_predictor(src_image, "upper")["mask"]
+        elif control_type == "pose_transfer":
+            mask = Image.fromarray(np.ones_like(src_image_array) * 255)
+
+        # DensePose
+        if control_type == "virtual_tryon":
+            src_image_seg_array = self.densepose_predictor.predict_seg(
+                src_image_array)
+            src_image_seg = Image.fromarray(src_image_seg_array)
+            densepose = src_image_seg
+        elif control_type == "pose_transfer":
+            src_image_iuv_array = self.densepose_predictor.predict_iuv(
+                src_image_array)
+            src_image_iuv = Image.fromarray(src_image_iuv_array)
+            densepose = src_image_iuv
+
+        # Leffa
+        transform = LeffaTransform()
+
+        data = {
+            "src_image": [src_image],
+            "ref_image": [ref_image],
+            "mask": [mask],
+            "densepose": [densepose],
+        }
+        data = transform(data)
+        if control_type == "virtual_tryon":
+            inference = self.vt_inference
+        elif control_type == "pose_transfer":
+            inference = self.pt_inference
+        output = inference(
+            data,
+            num_inference_steps=step,
+            guidance_scale=scale,
+            seed=seed,)
+        gen_image = output["generated_image"][0]
+        # gen_image.save("gen_image.png")
+        return np.array(gen_image)
+
+    def leffa_predict_vt(self, src_image_path, ref_image_path, step, scale, seed, vt_model_type="viton_hd"):
+        self.change_vt_model(vt_model_type)
+        return self.leffa_predict(src_image_path, ref_image_path, "virtual_tryon", step, scale, seed)
+
+    def leffa_predict_pt(self, src_image_path, ref_image_path, step, scale, seed):
+        return self.leffa_predict(src_image_path, ref_image_path, "pose_transfer", step, scale, seed)
 
 
 if __name__ == "__main__":
@@ -100,14 +121,26 @@ if __name__ == "__main__":
     # control_type = sys.argv[3]
     # leffa_predict(src_image_path, ref_image_path, control_type)
 
+    leffa_predictor = LeffaPredictor()
+    example_dir = "./ckpts/examples"
+    person1_images = list_dir(f"{example_dir}/person1")
+    person2_images = list_dir(f"{example_dir}/person2")
+    garment_images = list_dir(f"{example_dir}/garment")
+
     title = "## Leffa: Learning Flow Fields in Attention for Controllable Person Image Generation"
-    link = "[📚 Paper](https://arxiv.org/abs/2412.08486) - [🔥 Demo](https://huggingface.co/spaces/franciszzj/Leffa) - [🤗 Model](https://huggingface.co/franciszzj/Leffa)"
+    link = "[📚 Paper](https://arxiv.org/abs/2412.08486) - [🤖 Code](https://github.com/franciszzj/Leffa) - [🔥 Demo](https://huggingface.co/spaces/franciszzj/Leffa) - [🤗 Model](https://huggingface.co/franciszzj/Leffa)"
+    news = """## News
+    - 16/Dec/2024, the virtual try-on [model](https://huggingface.co/franciszzj/Leffa/blob/main/virtual_tryon_dc.pth) trained on DressCode is released.
+    - 12/Dec/2024, the HuggingFace [demo](https://huggingface.co/spaces/franciszzj/Leffa) and [models](https://huggingface.co/franciszzj/Leffa) (virtual try-on model trained on VITON-HD and pose transfer model trained on DeepFashion) are released.
+    - 11/Dec/2024, the [arXiv](https://arxiv.org/abs/2412.08486) version of the paper is released.
+    """
     description = "Leffa is a unified framework for controllable person image generation that enables precise manipulation of both appearance (i.e., virtual try-on) and pose (i.e., pose transfer)."
-    note = "Note: The models used in the demo are trained solely on academic datasets. Virtual try-on uses VITON-HD, and pose transfer uses DeepFashion."
+    note = "Note: The models used in the demo are trained solely on academic datasets. Virtual try-on uses VITON-HD/DressCode, and pose transfer uses DeepFashion."
 
     with gr.Blocks(theme=gr.themes.Default(primary_hue=gr.themes.colors.pink, secondary_hue=gr.themes.colors.red)).queue() as demo:
         gr.Markdown(title)
         gr.Markdown(link)
+        gr.Markdown(news)
         gr.Markdown(description)
 
         with gr.Tab("Control Appearance (Virtual Try-on)"):
@@ -124,12 +157,8 @@ if __name__ == "__main__":
 
                     gr.Examples(
                         inputs=vt_src_image,
-                        examples_per_page=5,
-                        examples=["./ckpts/examples/person1/01350_00.jpg",
-                                  "./ckpts/examples/person1/01376_00.jpg",
-                                  "./ckpts/examples/person1/01416_00.jpg",
-                                  "./ckpts/examples/person1/05976_00.jpg",
-                                  "./ckpts/examples/person1/06094_00.jpg",],
+                        examples_per_page=10,
+                        examples=person1_images,
                     )
 
                 with gr.Column():
@@ -144,12 +173,8 @@ if __name__ == "__main__":
 
                    gr.Examples(
                        inputs=vt_ref_image,
-                        examples_per_page=5,
-                        examples=["./ckpts/examples/garment/01449_00.jpg",
-                                  "./ckpts/examples/garment/01486_00.jpg",
-                                  "./ckpts/examples/garment/01853_00.jpg",
-                                  "./ckpts/examples/garment/02070_00.jpg",
-                                  "./ckpts/examples/garment/03553_00.jpg",],
+                        examples_per_page=10,
+                        examples=garment_images,
                     )
 
                 with gr.Column():
@@ -163,8 +188,24 @@ if __name__ == "__main__":
                     with gr.Row():
                         vt_gen_button = gr.Button("Generate")
 
-            vt_gen_button.click(fn=leffa_predict_vt, inputs=[
-                                vt_src_image, vt_ref_image], outputs=[vt_gen_image])
+                    with gr.Accordion("Advanced Options", open=False):
+                        vt_step = gr.Number(
+                            label="Inference Steps", minimum=30, maximum=100, step=1, value=50)
+
+                        vt_scale = gr.Number(
+                            label="Guidance Scale", minimum=0.1, maximum=5.0, step=0.1, value=2.5)
+
+                        vt_seed = gr.Number(
+                            label="Random Seed", minimum=-1, maximum=2147483647, step=1, value=42)
+
+                        vt_model_type = gr.Radio(
+                            choices=["viton_hd", "dress_code"],
+                            value="viton_hd",
+                            label="Model Type",
+                        )
+
+            vt_gen_button.click(fn=leffa_predictor.leffa_predict_vt, inputs=[
+                vt_src_image, vt_ref_image, vt_step, vt_scale, vt_seed, vt_model_type], outputs=[vt_gen_image])
 
         with gr.Tab("Control Pose (Pose Transfer)"):
             with gr.Row():
@@ -180,12 +221,8 @@ if __name__ == "__main__":
 
                     gr.Examples(
                         inputs=pt_ref_image,
-                        examples_per_page=5,
-                        examples=["./ckpts/examples/person1/01350_00.jpg",
-                                  "./ckpts/examples/person1/01376_00.jpg",
-                                  "./ckpts/examples/person1/01416_00.jpg",
-                                  "./ckpts/examples/person1/05976_00.jpg",
-                                  "./ckpts/examples/person1/06094_00.jpg",],
+                        examples_per_page=10,
+                        examples=person1_images,
                     )
 
                 with gr.Column():
@@ -200,12 +237,8 @@ if __name__ == "__main__":
 
                    gr.Examples(
                        inputs=pt_src_image,
-                        examples_per_page=5,
-                        examples=["./ckpts/examples/person2/01850_00.jpg",
-                                  "./ckpts/examples/person2/01875_00.jpg",
-                                  "./ckpts/examples/person2/02532_00.jpg",
-                                  "./ckpts/examples/person2/02902_00.jpg",
-                                  "./ckpts/examples/person2/05346_00.jpg",],
+                        examples_per_page=10,
+                        examples=person2_images,
                     )
 
                 with gr.Column():
@@ -219,8 +252,18 @@ if __name__ == "__main__":
                     with gr.Row():
                         pose_transfer_gen_button = gr.Button("Generate")
 
-            pose_transfer_gen_button.click(fn=leffa_predict_pt, inputs=[
-                                           pt_src_image, pt_ref_image], outputs=[pt_gen_image])
+                    with gr.Accordion("Advanced Options", open=False):
+                        pt_step = gr.Number(
+                            label="Inference Steps", minimum=30, maximum=100, step=1, value=50)
+
+                        pt_scale = gr.Number(
+                            label="Guidance Scale", minimum=0.1, maximum=5.0, step=0.1, value=2.5)
+
+                        pt_seed = gr.Number(
+                            label="Random Seed", minimum=-1, maximum=2147483647, step=1, value=42)
+
+            pose_transfer_gen_button.click(fn=leffa_predictor.leffa_predict_pt, inputs=[
+                pt_src_image, pt_ref_image, pt_step, pt_scale, pt_seed], outputs=[pt_gen_image])
 
         gr.Markdown(note)
 
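For context, a minimal sketch of driving the refactored LeffaPredictor outside the Gradio UI. This assumes the Space's working directory (importing app also runs the module-level snapshot_download, so checkpoints land under ./ckpts); the example image paths are taken from the hard-coded lists this commit removes.

from PIL import Image
from app import LeffaPredictor  # importing app triggers the checkpoint download

predictor = LeffaPredictor()

# Virtual try-on: the first "dress_code" call makes change_vt_model swap in
# virtual_tryon_dc.pth; later calls with the same type reuse the cached pipeline.
vt_array = predictor.leffa_predict_vt(
    "./ckpts/examples/person1/01350_00.jpg",
    "./ckpts/examples/garment/01449_00.jpg",
    step=50, scale=2.5, seed=42,
    vt_model_type="dress_code",
)

# Pose transfer exposes the same step/scale/seed knobs.
pt_array = predictor.leffa_predict_pt(
    "./ckpts/examples/person2/01850_00.jpg",
    "./ckpts/examples/person1/01350_00.jpg",
    step=50, scale=2.5, seed=42,
)

# Both methods return uint8 numpy arrays.
Image.fromarray(vt_array).save("vt_result.png")
Image.fromarray(pt_array).save("pt_result.png")

Keeping the pipelines on a class instance is what lets change_vt_model replace only the virtual try-on weights while the pose transfer pipeline stays loaded.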
utils/utils.py CHANGED
@@ -1,3 +1,4 @@
+import os
 import cv2
 import numpy as np
 from PIL import Image
@@ -29,3 +30,14 @@ def resize_and_center(image, target_width, target_height):
     padded_img[top:top + new_height, left:left + new_width] = resized_img
 
     return Image.fromarray(padded_img)
+
+
+def list_dir(folder_path):
+    # Collect all file paths within the directory
+    file_paths = []
+    for root, _, files in os.walk(folder_path):
+        for file in files:
+            file_paths.append(os.path.join(root, file))
+
+    file_paths = sorted(file_paths)
+    return file_paths
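A quick sketch of what the new helper returns, assuming the example folders downloaded to ./ckpts:

from utils.utils import list_dir

# os.walk recurses, so files in nested subfolders are collected too;
# sorted() keeps the gr.Examples ordering deterministic across runs.
person1_images = list_dir("./ckpts/examples/person1")
print(person1_images[0])  # e.g. ./ckpts/examples/person1/01350_00.jpg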