DilshanIRU committed on
Commit 7c62aee · verified · 1 Parent(s): 1eb4732

Update app.py

Files changed (1)
  1. app.py +778 -778
app.py CHANGED
@@ -1,778 +1,778 @@
import argparse
import os
os.environ['CUDA_HOME'] = '/usr/local/cuda'
os.environ['PATH'] = os.environ['PATH'] + ':/usr/local/cuda/bin'
from datetime import datetime

import gradio as gr
import spaces
import numpy as np
import torch
from diffusers.image_processor import VaeImageProcessor
from huggingface_hub import snapshot_download
from PIL import Image
torch.jit.script = lambda f: f
from model.cloth_masker import AutoMasker, vis_mask
from model.pipeline import CatVTONPipeline, CatVTONPix2PixPipeline
from model.flux.pipeline_flux_tryon import FluxTryOnPipeline
from utils import init_weight_dtype, resize_and_crop, resize_and_padding


def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--base_model_path",
        type=str,
        default="booksforcharlie/stable-diffusion-inpainting",
        help=(
            "The path to the base model to use for evaluation. This can be a local path or a model identifier from the Model Hub."
        ),
    )
    parser.add_argument(
        "--p2p_base_model_path",
        type=str,
        default="timbrooks/instruct-pix2pix",
        help=(
            "The path to the base model to use for evaluation. This can be a local path or a model identifier from the Model Hub."
        ),
    )
    parser.add_argument(
        "--resume_path",
        type=str,
        default="zhengchong/CatVTON",
        help=(
            "The path to the checkpoint of the trained try-on model."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="resource/demo/output",
        help="The output directory where the model predictions will be written.",
    )

    parser.add_argument(
        "--width",
        type=int,
        default=768,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--height",
        type=int,
        default=1024,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--repaint",
        action="store_true",
        help="Whether to repaint the result image with the original background."
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        default=True,
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="bf16",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="For distributed launchers; the LOCAL_RANK environment override below expects this attribute.",
    )

    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args


def image_grid(imgs, rows, cols):
    assert len(imgs) == rows * cols

    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid


args = parse_args()

# Mask-based CatVTON
catvton_repo = "zhengchong/CatVTON"
repo_path = snapshot_download(repo_id=catvton_repo)
# Pipeline
pipeline = CatVTONPipeline(
    base_ckpt=args.base_model_path,
    attn_ckpt=repo_path,
    attn_ckpt_version="mix",
    weight_dtype=init_weight_dtype(args.mixed_precision),
    use_tf32=args.allow_tf32,
    device='cuda'
)
# AutoMasker
mask_processor = VaeImageProcessor(vae_scale_factor=8, do_normalize=False, do_binarize=True, do_convert_grayscale=True)
automasker = AutoMasker(
    densepose_ckpt=os.path.join(repo_path, "DensePose"),
    schp_ckpt=os.path.join(repo_path, "SCHP"),
    device='cuda',
)


# Flux-based CatVTON
access_token = os.getenv("HUGGING_FACE_HUB_TOKEN")
flux_repo = "black-forest-labs/FLUX.1-Fill-dev"
pipeline_flux = FluxTryOnPipeline.from_pretrained(flux_repo, use_auth_token=access_token)
pipeline_flux.load_lora_weights(
    os.path.join(repo_path, "flux-lora"),
    weight_name='pytorch_lora_weights.safetensors'
)
pipeline_flux.to("cuda", init_weight_dtype(args.mixed_precision))


# Mask-free CatVTON
catvton_mf_repo = "zhengchong/CatVTON-MaskFree"
repo_path_mf = snapshot_download(repo_id=catvton_mf_repo, use_auth_token=access_token)
pipeline_p2p = CatVTONPix2PixPipeline(
    base_ckpt=args.p2p_base_model_path,
    attn_ckpt=repo_path_mf,
    attn_ckpt_version="mix-48k-1024",
    weight_dtype=init_weight_dtype(args.mixed_precision),
    use_tf32=args.allow_tf32,
    device='cuda'
)


@spaces.GPU(duration=120)
def submit_function(
    person_image,
    cloth_image,
    cloth_type,
    num_inference_steps,
    guidance_scale,
    seed,
    show_type
):
    person_image, mask = person_image["background"], person_image["layers"][0]
    mask = Image.open(mask).convert("L")
    if len(np.unique(np.array(mask))) == 1:
        mask = None
    else:
        mask = np.array(mask)
        mask[mask > 0] = 255
        mask = Image.fromarray(mask)

    tmp_folder = args.output_dir
    date_str = datetime.now().strftime("%Y%m%d%H%M%S")
    result_save_path = os.path.join(tmp_folder, date_str[:8], date_str[8:] + ".png")
    if not os.path.exists(os.path.join(tmp_folder, date_str[:8])):
        os.makedirs(os.path.join(tmp_folder, date_str[:8]))

    generator = None
    if seed != -1:
        generator = torch.Generator(device='cuda').manual_seed(seed)

    person_image = Image.open(person_image).convert("RGB")
    cloth_image = Image.open(cloth_image).convert("RGB")
    person_image = resize_and_crop(person_image, (args.width, args.height))
    cloth_image = resize_and_padding(cloth_image, (args.width, args.height))

    # Process mask
    if mask is not None:
        mask = resize_and_crop(mask, (args.width, args.height))
    else:
        mask = automasker(
            person_image,
            cloth_type
        )['mask']
    mask = mask_processor.blur(mask, blur_factor=9)

    # Inference
    # try:
    result_image = pipeline(
        image=person_image,
        condition_image=cloth_image,
        mask=mask,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator
    )[0]
    # except Exception as e:
    #     raise gr.Error(
    #         "An error occurred. Please try again later: {}".format(e)
    #     )

    # Post-process
    masked_person = vis_mask(person_image, mask)
    save_result_image = image_grid([person_image, masked_person, cloth_image, result_image], 1, 4)
    save_result_image.save(result_save_path)
    if show_type == "result only":
        return result_image
    else:
        width, height = person_image.size
        if show_type == "input & result":
            condition_width = width // 2
            conditions = image_grid([person_image, cloth_image], 2, 1)
        else:
            condition_width = width // 3
            conditions = image_grid([person_image, masked_person, cloth_image], 3, 1)
        conditions = conditions.resize((condition_width, height), Image.NEAREST)
        new_result_image = Image.new("RGB", (width + condition_width + 5, height))
        new_result_image.paste(conditions, (0, 0))
        new_result_image.paste(result_image, (condition_width + 5, 0))
        return new_result_image


@spaces.GPU(duration=120)
def submit_function_p2p(
    person_image,
    cloth_image,
    num_inference_steps,
    guidance_scale,
    seed):
    person_image = person_image["background"]

    tmp_folder = args.output_dir
    date_str = datetime.now().strftime("%Y%m%d%H%M%S")
    result_save_path = os.path.join(tmp_folder, date_str[:8], date_str[8:] + ".png")
    if not os.path.exists(os.path.join(tmp_folder, date_str[:8])):
        os.makedirs(os.path.join(tmp_folder, date_str[:8]))

    generator = None
    if seed != -1:
        generator = torch.Generator(device='cuda').manual_seed(seed)

    person_image = Image.open(person_image).convert("RGB")
    cloth_image = Image.open(cloth_image).convert("RGB")
    person_image = resize_and_crop(person_image, (args.width, args.height))
    cloth_image = resize_and_padding(cloth_image, (args.width, args.height))

    # Inference
    try:
        result_image = pipeline_p2p(
            image=person_image,
            condition_image=cloth_image,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator
        )[0]
    except Exception as e:
        raise gr.Error(
            "An error occurred. Please try again later: {}".format(e)
        )

    # Post-process
    save_result_image = image_grid([person_image, cloth_image, result_image], 1, 3)
    save_result_image.save(result_save_path)
    return result_image


@spaces.GPU(duration=120)
def submit_function_flux(
    person_image,
    cloth_image,
    cloth_type,
    num_inference_steps,
    guidance_scale,
    seed,
    show_type
):

    # Process image editor input
    person_image, mask = person_image["background"], person_image["layers"][0]
    mask = Image.open(mask).convert("L")
    if len(np.unique(np.array(mask))) == 1:
        mask = None
    else:
        mask = np.array(mask)
        mask[mask > 0] = 255
        mask = Image.fromarray(mask)

    # Set random seed
    generator = None
    if seed != -1:
        generator = torch.Generator(device='cuda').manual_seed(seed)

    # Process input images
    person_image = Image.open(person_image).convert("RGB")
    cloth_image = Image.open(cloth_image).convert("RGB")

    # Adjust image sizes
    person_image = resize_and_crop(person_image, (args.width, args.height))
    cloth_image = resize_and_padding(cloth_image, (args.width, args.height))

    # Process mask
    if mask is not None:
        mask = resize_and_crop(mask, (args.width, args.height))
    else:
        mask = automasker(
            person_image,
            cloth_type
        )['mask']
    mask = mask_processor.blur(mask, blur_factor=9)

    # Inference
    result_image = pipeline_flux(
        image=person_image,
        condition_image=cloth_image,
        mask_image=mask,
        width=args.width,
        height=args.height,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator
    ).images[0]

    # Post-processing
    masked_person = vis_mask(person_image, mask)

    # Return result based on show type
    if show_type == "result only":
        return result_image
    else:
        width, height = person_image.size
        if show_type == "input & result":
            condition_width = width // 2
            conditions = image_grid([person_image, cloth_image], 2, 1)
        else:
            condition_width = width // 3
            conditions = image_grid([person_image, masked_person, cloth_image], 3, 1)

        conditions = conditions.resize((condition_width, height), Image.NEAREST)
        new_result_image = Image.new("RGB", (width + condition_width + 5, height))
        new_result_image.paste(conditions, (0, 0))
        new_result_image.paste(result_image, (condition_width + 5, 0))
        return new_result_image


def person_example_fn(image_path):
    return image_path


HEADER = """
<h1 style="text-align: center;"> 🐈 CatVTON: Concatenation Is All You Need for Virtual Try-On with Diffusion Models </h1>
<div style="display: flex; justify-content: center; align-items: center;">
  <a href="http://arxiv.org/abs/2407.15886" style="margin: 0 2px;">
    <img src='https://img.shields.io/badge/arXiv-2407.15886-red?style=flat&logo=arXiv&logoColor=red' alt='arxiv'>
  </a>
  <a href='https://huggingface.co/zhengchong/CatVTON' style="margin: 0 2px;">
    <img src='https://img.shields.io/badge/Hugging Face-ckpts-orange?style=flat&logo=HuggingFace&logoColor=orange' alt='huggingface'>
  </a>
  <a href="https://github.com/Zheng-Chong/CatVTON" style="margin: 0 2px;">
    <img src='https://img.shields.io/badge/GitHub-Repo-blue?style=flat&logo=GitHub' alt='GitHub'>
  </a>
  <a href="http://120.76.142.206:8888" style="margin: 0 2px;">
    <img src='https://img.shields.io/badge/Demo-Gradio-gold?style=flat&logo=Gradio&logoColor=red' alt='Demo'>
  </a>
  <a href="https://huggingface.co/spaces/zhengchong/CatVTON" style="margin: 0 2px;">
    <img src='https://img.shields.io/badge/Space-ZeroGPU-orange?style=flat&logo=Gradio&logoColor=red' alt='Demo'>
  </a>
  <a href='https://zheng-chong.github.io/CatVTON/' style="margin: 0 2px;">
    <img src='https://img.shields.io/badge/Webpage-Project-silver?style=flat&logo=&logoColor=orange' alt='webpage'>
  </a>
  <a href="https://github.com/Zheng-Chong/CatVTON/LICENCE" style="margin: 0 2px;">
    <img src='https://img.shields.io/badge/License-CC BY--NC--SA--4.0-lightgreen?style=flat&logo=Lisence' alt='License'>
  </a>
</div>
<br>
· This demo and our weights are only for Non-commercial Use. <br>
· Thanks to <a href="https://huggingface.co/zero-gpu-explorers">ZeroGPU</a> for providing A100 for our <a href="https://huggingface.co/spaces/zhengchong/CatVTON">HuggingFace Space</a>. <br>
· SafetyChecker is set to filter NSFW content, but it may block normal results too. Please adjust the <span>`seed`</span> for normal outcomes.<br>
"""


def app_gradio():
    with gr.Blocks(title="CatVTON") as demo:
        gr.Markdown(HEADER)
        with gr.Tab("Mask-based & SD1.5"):
            with gr.Row():
                with gr.Column(scale=1, min_width=350):
                    with gr.Row():
                        image_path = gr.Image(
                            type="filepath",
                            interactive=True,
                            visible=False,
                        )
                        person_image = gr.ImageEditor(
                            interactive=True, label="Person Image", type="filepath"
                        )

                    with gr.Row():
                        with gr.Column(scale=1, min_width=230):
                            cloth_image = gr.Image(
                                interactive=True, label="Condition Image", type="filepath"
                            )
                        with gr.Column(scale=1, min_width=120):
                            gr.Markdown(
                                '<span style="color: #808080; font-size: small;">Two ways to provide Mask:<br>1. Upload the person image and use the `🖌️` above to draw the Mask (higher priority)<br>2. Select the `Try-On Cloth Type` to generate automatically </span>'
                            )
                            cloth_type = gr.Radio(
                                label="Try-On Cloth Type",
                                choices=["upper", "lower", "overall"],
                                value="upper",
                            )

                    submit = gr.Button("Submit")
                    gr.Markdown(
                        '<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
                    )

                    gr.Markdown(
                        '<span style="color: #808080; font-size: small;">Advanced options can adjust details:<br>1. `Inference Step` may enhance details;<br>2. `CFG` is highly correlated with saturation;<br>3. `Random seed` may improve pseudo-shadow.</span>'
                    )
                    with gr.Accordion("Advanced Options", open=False):
                        num_inference_steps = gr.Slider(
                            label="Inference Step", minimum=10, maximum=100, step=5, value=50
                        )
                        # Guidance Scale
                        guidance_scale = gr.Slider(
                            label="CFG Strength", minimum=0.0, maximum=7.5, step=0.5, value=2.5
                        )
                        # Random Seed
                        seed = gr.Slider(
                            label="Seed", minimum=-1, maximum=10000, step=1, value=42
                        )
                        show_type = gr.Radio(
                            label="Show Type",
                            choices=["result only", "input & result", "input & mask & result"],
                            value="input & mask & result",
                        )

                with gr.Column(scale=2, min_width=500):
                    result_image = gr.Image(interactive=False, label="Result")
                    with gr.Row():
                        # Photo Examples
                        root_path = "resource/demo/example"
                        with gr.Column():
                            men_exm = gr.Examples(
                                examples=[
                                    os.path.join(root_path, "person", "men", _)
                                    for _ in os.listdir(os.path.join(root_path, "person", "men"))
                                ],
                                examples_per_page=4,
                                inputs=image_path,
                                label="Person Examples ①",
                            )
                            women_exm = gr.Examples(
                                examples=[
                                    os.path.join(root_path, "person", "women", _)
                                    for _ in os.listdir(os.path.join(root_path, "person", "women"))
                                ],
                                examples_per_page=4,
                                inputs=image_path,
                                label="Person Examples ②",
                            )
                            gr.Markdown(
                                '<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
                            )
                        with gr.Column():
                            condition_upper_exm = gr.Examples(
                                examples=[
                                    os.path.join(root_path, "condition", "upper", _)
                                    for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
                                ],
                                examples_per_page=4,
                                inputs=cloth_image,
                                label="Condition Upper Examples",
                            )
                            condition_overall_exm = gr.Examples(
                                examples=[
                                    os.path.join(root_path, "condition", "overall", _)
                                    for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
                                ],
                                examples_per_page=4,
                                inputs=cloth_image,
                                label="Condition Overall Examples",
                            )
                            condition_person_exm = gr.Examples(
                                examples=[
                                    os.path.join(root_path, "condition", "person", _)
                                    for _ in os.listdir(os.path.join(root_path, "condition", "person"))
                                ],
                                examples_per_page=4,
                                inputs=cloth_image,
                                label="Condition Reference Person Examples",
                            )
                            gr.Markdown(
                                '<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
                            )

            image_path.change(
                person_example_fn, inputs=image_path, outputs=person_image
            )

            submit.click(
                submit_function,
                [
                    person_image,
                    cloth_image,
                    cloth_type,
                    num_inference_steps,
                    guidance_scale,
                    seed,
                    show_type,
                ],
                result_image,
            )

        with gr.Tab("Mask-based & Flux.1 Fill Dev"):
            with gr.Row():
                with gr.Column(scale=1, min_width=350):
                    with gr.Row():
                        image_path_flux = gr.Image(
                            type="filepath",
                            interactive=True,
                            visible=False,
                        )
                        person_image_flux = gr.ImageEditor(
                            interactive=True, label="Person Image", type="filepath"
                        )

                    with gr.Row():
                        with gr.Column(scale=1, min_width=230):
                            cloth_image_flux = gr.Image(
                                interactive=True, label="Condition Image", type="filepath"
                            )
                        with gr.Column(scale=1, min_width=120):
                            gr.Markdown(
                                '<span style="color: #808080; font-size: small;">Two ways to provide Mask:<br>1. Upload the person image and use the `🖌️` above to draw the Mask (higher priority)<br>2. Select the `Try-On Cloth Type` to generate automatically </span>'
                            )
                            cloth_type = gr.Radio(
                                label="Try-On Cloth Type",
                                choices=["upper", "lower", "overall"],
                                value="upper",
                            )

                    submit_flux = gr.Button("Submit")
                    gr.Markdown(
                        '<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
                    )

                    with gr.Accordion("Advanced Options", open=False):
                        num_inference_steps_flux = gr.Slider(
                            label="Inference Step", minimum=10, maximum=100, step=5, value=50
                        )
                        # Guidance Scale
                        guidance_scale_flux = gr.Slider(
                            label="CFG Strength", minimum=0.0, maximum=50, step=0.5, value=30
                        )
                        # Random Seed
                        seed_flux = gr.Slider(
                            label="Seed", minimum=-1, maximum=10000, step=1, value=42
                        )
                        show_type = gr.Radio(
                            label="Show Type",
                            choices=["result only", "input & result", "input & mask & result"],
                            value="input & mask & result",
                        )

                with gr.Column(scale=2, min_width=500):
                    result_image_flux = gr.Image(interactive=False, label="Result")
                    with gr.Row():
                        # Photo Examples
                        root_path = "resource/demo/example"
                        with gr.Column():
                            gr.Examples(
                                examples=[
                                    os.path.join(root_path, "person", "men", _)
                                    for _ in os.listdir(os.path.join(root_path, "person", "men"))
                                ],
                                examples_per_page=4,
                                inputs=image_path_flux,
                                label="Person Examples ①",
                            )
                            gr.Examples(
                                examples=[
                                    os.path.join(root_path, "person", "women", _)
                                    for _ in os.listdir(os.path.join(root_path, "person", "women"))
                                ],
                                examples_per_page=4,
                                inputs=image_path_flux,
                                label="Person Examples ②",
                            )
                            gr.Markdown(
                                '<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
                            )
                        with gr.Column():
                            gr.Examples(
                                examples=[
                                    os.path.join(root_path, "condition", "upper", _)
                                    for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
                                ],
                                examples_per_page=4,
                                inputs=cloth_image_flux,
                                label="Condition Upper Examples",
                            )
                            gr.Examples(
                                examples=[
                                    os.path.join(root_path, "condition", "overall", _)
                                    for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
                                ],
                                examples_per_page=4,
                                inputs=cloth_image_flux,
                                label="Condition Overall Examples",
                            )
                            condition_person_exm = gr.Examples(
                                examples=[
                                    os.path.join(root_path, "condition", "person", _)
                                    for _ in os.listdir(os.path.join(root_path, "condition", "person"))
                                ],
                                examples_per_page=4,
                                inputs=cloth_image_flux,
                                label="Condition Reference Person Examples",
                            )
                            gr.Markdown(
                                '<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
                            )

            image_path_flux.change(
                person_example_fn, inputs=image_path_flux, outputs=person_image_flux
            )

            submit_flux.click(
                submit_function_flux,
                [person_image_flux, cloth_image_flux, cloth_type, num_inference_steps_flux, guidance_scale_flux, seed_flux, show_type],
                result_image_flux,
            )

        with gr.Tab("Mask-free & SD1.5"):
            with gr.Row():
                with gr.Column(scale=1, min_width=350):
                    with gr.Row():
                        image_path_p2p = gr.Image(
                            type="filepath",
                            interactive=True,
                            visible=False,
                        )
                        person_image_p2p = gr.ImageEditor(
                            interactive=True, label="Person Image", type="filepath"
                        )

                    with gr.Row():
                        with gr.Column(scale=1, min_width=230):
                            cloth_image_p2p = gr.Image(
                                interactive=True, label="Condition Image", type="filepath"
                            )

                    submit_p2p = gr.Button("Submit")
                    gr.Markdown(
                        '<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
                    )

                    gr.Markdown(
                        '<span style="color: #808080; font-size: small;">Advanced options can adjust details:<br>1. `Inference Step` may enhance details;<br>2. `CFG` is highly correlated with saturation;<br>3. `Random seed` may improve pseudo-shadow.</span>'
                    )
                    with gr.Accordion("Advanced Options", open=False):
                        num_inference_steps_p2p = gr.Slider(
                            label="Inference Step", minimum=10, maximum=100, step=5, value=50
                        )
                        # Guidance Scale
                        guidance_scale_p2p = gr.Slider(
                            label="CFG Strength", minimum=0.0, maximum=7.5, step=0.5, value=2.5
                        )
                        # Random Seed
                        seed_p2p = gr.Slider(
                            label="Seed", minimum=-1, maximum=10000, step=1, value=42
                        )
                        # show_type = gr.Radio(
                        #     label="Show Type",
                        #     choices=["result only", "input & result", "input & mask & result"],
                        #     value="input & mask & result",
                        # )

                with gr.Column(scale=2, min_width=500):
                    result_image_p2p = gr.Image(interactive=False, label="Result")
                    with gr.Row():
                        # Photo Examples
                        root_path = "resource/demo/example"
                        with gr.Column():
                            gr.Examples(
                                examples=[
                                    os.path.join(root_path, "person", "men", _)
                                    for _ in os.listdir(os.path.join(root_path, "person", "men"))
                                ],
                                examples_per_page=4,
                                inputs=image_path_p2p,
                                label="Person Examples ①",
                            )
                            gr.Examples(
                                examples=[
                                    os.path.join(root_path, "person", "women", _)
                                    for _ in os.listdir(os.path.join(root_path, "person", "women"))
                                ],
                                examples_per_page=4,
                                inputs=image_path_p2p,
                                label="Person Examples ②",
                            )
                            gr.Markdown(
                                '<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
                            )
                        with gr.Column():
                            gr.Examples(
                                examples=[
                                    os.path.join(root_path, "condition", "upper", _)
                                    for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
                                ],
                                examples_per_page=4,
                                inputs=cloth_image_p2p,
                                label="Condition Upper Examples",
                            )
                            gr.Examples(
                                examples=[
                                    os.path.join(root_path, "condition", "overall", _)
                                    for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
                                ],
                                examples_per_page=4,
                                inputs=cloth_image_p2p,
                                label="Condition Overall Examples",
                            )
                            condition_person_exm = gr.Examples(
                                examples=[
                                    os.path.join(root_path, "condition", "person", _)
                                    for _ in os.listdir(os.path.join(root_path, "condition", "person"))
                                ],
                                examples_per_page=4,
                                inputs=cloth_image_p2p,
                                label="Condition Reference Person Examples",
                            )
                            gr.Markdown(
                                '<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
                            )

            image_path_p2p.change(
                person_example_fn, inputs=image_path_p2p, outputs=person_image_p2p
            )

            submit_p2p.click(
                submit_function_p2p,
                [
                    person_image_p2p,
                    cloth_image_p2p,
                    num_inference_steps_p2p,
                    guidance_scale_p2p,
                    seed_p2p,
                ],
                result_image_p2p,
            )

    demo.queue().launch(share=True, show_error=True)


if __name__ == "__main__":
    app_gradio()
 