g-ronimo committed
Commit 8e2bb3f
1 parent: b1caf2e

Upload tryon_inference.py

Files changed (1):
  tryon_inference.py +124 -0
tryon_inference.py ADDED
@@ -0,0 +1,124 @@
+ import argparse
+
+ import torch
+ from diffusers import FluxFillPipeline, FluxTransformer2DModel
+ from diffusers.utils import check_min_version, load_image
+ from torchvision import transforms
+
+ def run_inference(
+     image_path,
+     mask_path,
+     garment_path,
+     size=(576, 768),
+     num_steps=50,
+     guidance_scale=30,
+     seed=42,
+     pipe=None,
+ ):
+     # Build the pipeline unless the caller passed one in
+     if pipe is None:
+         # CatVTON-FLUX transformer swapped into the stock FLUX.1 Fill pipeline
+         transformer = FluxTransformer2DModel.from_pretrained(
+             "xiaozaa/catvton-flux-alpha",
+             torch_dtype=torch.bfloat16,
+         )
+         pipe = FluxFillPipeline.from_pretrained(
+             "black-forest-labs/FLUX.1-dev",
+             transformer=transformer,
+             torch_dtype=torch.bfloat16,
+         ).to("cuda")
+     else:
+         pipe.to("cuda")
+
+     pipe.transformer.to(torch.bfloat16)
+
+     # Preprocessing: RGB images are normalized to [-1, 1]; the mask stays in [0, 1]
+     transform = transforms.Compose([
+         transforms.ToTensor(),
+         transforms.Normalize([0.5], [0.5]),
+     ])
+     mask_transform = transforms.Compose([
+         transforms.ToTensor(),
+     ])
+
+     # Load the person image, agnostic mask, and garment, resized to a common size
+     image = load_image(image_path).convert("RGB").resize(size)
+     mask = load_image(mask_path).convert("RGB").resize(size)
+     garment = load_image(garment_path).convert("RGB").resize(size)
+
+     image_tensor = transform(image)
+     mask_tensor = mask_transform(mask)[:1]  # keep only the first channel
+     garment_tensor = transform(garment)
+
+     # Concatenate garment and person side by side along the width; the
+     # garment half of the mask is all zeros, so that half is left untouched
+     inpaint_image = torch.cat([garment_tensor, image_tensor], dim=2)
+     garment_mask = torch.zeros_like(mask_tensor)
+     extended_mask = torch.cat([garment_mask, mask_tensor], dim=2)
+
+     prompt = (
+         "The pair of images highlights a clothing and its styling on a model, "
+         "high resolution, 4K, 8K; "
+         "[IMAGE1] Detailed product shot of a clothing; "
+         "[IMAGE2] The same cloth is worn by a model in a lifestyle setting."
+     )
+
+     generator = torch.Generator(device="cuda").manual_seed(seed)
+
+     result = pipe(
+         height=size[1],
+         width=size[0] * 2,  # double width: garment half + try-on half
+         image=inpaint_image,
+         mask_image=extended_mask,
+         num_inference_steps=num_steps,
+         generator=generator,
+         max_sequence_length=512,
+         guidance_scale=guidance_scale,
+         prompt=prompt,
+     ).images[0]
+
+     # Split the side-by-side result back into its two halves
+     width = size[0]
+     garment_result = result.crop((0, 0, width, size[1]))
+     tryon_result = result.crop((width, 0, width * 2, size[1]))
+
+     return garment_result, tryon_result
+
+ def main():
+     parser = argparse.ArgumentParser(description='Run FLUX virtual try-on inference')
+     parser.add_argument('--image', required=True, help='Path to the model image')
+     parser.add_argument('--mask', required=True, help='Path to the agnostic mask')
+     parser.add_argument('--garment', required=True, help='Path to the garment image')
+     parser.add_argument('--output-garment', default='flux_inpaint_garment.png', help='Output path for the garment result')
+     parser.add_argument('--output-tryon', default='flux_inpaint_tryon.png', help='Output path for the try-on result')
+     parser.add_argument('--steps', type=int, default=50, help='Number of inference steps')
+     parser.add_argument('--guidance-scale', type=float, default=30, help='Guidance scale')
+     parser.add_argument('--seed', type=int, default=0, help='Random seed')
+     parser.add_argument('--width', type=int, default=768, help='Width of each half image')
+     parser.add_argument('--height', type=int, default=576, help='Image height')
+
+     args = parser.parse_args()
+
+     check_min_version("0.30.2")
+
+     garment_result, tryon_result = run_inference(
+         image_path=args.image,
+         mask_path=args.mask,
+         garment_path=args.garment,
+         num_steps=args.steps,
+         guidance_scale=args.guidance_scale,
+         seed=args.seed,
+         size=(args.width, args.height),
+     )
+
+     garment_result.save(args.output_garment)
+     tryon_result.save(args.output_tryon)
+     print("Successfully saved garment and try-on images")
+
+ if __name__ == "__main__":
+     main()
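
Usage note: since run_inference takes an optional pipe argument, a caller importing this module can build the FLUX fill pipeline once and reuse it across garments instead of reloading the weights on every call. A minimal sketch under that assumption; the image paths below are hypothetical placeholders:

    import torch
    from diffusers import FluxFillPipeline, FluxTransformer2DModel
    from tryon_inference import run_inference

    # Build the pipeline once, using the same checkpoints the script loads
    transformer = FluxTransformer2DModel.from_pretrained(
        "xiaozaa/catvton-flux-alpha", torch_dtype=torch.bfloat16
    )
    pipe = FluxFillPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
    )

    # person.jpg, agnostic_mask.png, and the garment files are placeholder paths
    for i, garment in enumerate(["shirt.jpg", "jacket.jpg"]):
        _, tryon = run_inference(
            image_path="person.jpg",
            mask_path="agnostic_mask.png",
            garment_path=garment,
            pipe=pipe,  # reuse the already-loaded pipeline
        )
        tryon.save(f"tryon_{i}.png")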