# Temporary hack: upgrade to a PyTorch 2.8 pre-release (and the spaces
# package) at import time, before torch is imported below.
import os
os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 "torch<2.9" spaces')

# Actual demo code
import random

import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image

from diffusers import FluxKontextPipeline

# from optimization import optimize_pipeline_

MAX_SEED = np.iinfo(np.int32).max

pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("ovi054/virtual-tryon-kontext-lora")
pipe.fuse_lora()
# optimize_pipeline_(pipe, image=Image.new("RGB", (512, 512)), prompt='prompt')
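# If GPU memory were tight, CPU offload could replace the .to("cuda") call
# above (a sketch using the standard diffusers API; not needed here):
# pipe.enable_model_cpu_offload()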


EXAMPLES_DIR = "examples"
BASE_EXAMPLES = [os.path.join(EXAMPLES_DIR, "base", f) for f in sorted(os.listdir(os.path.join(EXAMPLES_DIR, "base")))]
FACE_EXAMPLES = [os.path.join(EXAMPLES_DIR, "face", f) for f in sorted(os.listdir(os.path.join(EXAMPLES_DIR, "face")))]
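
# The listings above assume examples/base and examples/face hold only image
# files; a stricter variant (sketch) would filter by extension:
# IMAGE_EXTS = (".png", ".jpg", ".jpeg", ".webp")
# BASE_EXAMPLES = [
#     os.path.join(EXAMPLES_DIR, "base", f)
#     for f in sorted(os.listdir(os.path.join(EXAMPLES_DIR, "base")))
#     if f.lower().endswith(IMAGE_EXTS)
# ]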


# def add_overlay(base_img, overlay_img, margin=20):
#     """
#     Pastes an overlay image onto the top-right corner of a base image.

#     The overlay is resized to be 1/5th of the width of the base image,
#     maintaining its aspect ratio.

#     Args:
#         base_img (PIL.Image.Image): The main image.
#         overlay_img (PIL.Image.Image): The image to place on top.
#         margin (int, optional): The pixel margin from the top and right edges. Defaults to 20.

#     Returns:
#         PIL.Image.Image: The combined image.
#     """
#     if base_img is None or overlay_img is None:
#         return base_img
    
#     base = base_img.convert("RGBA")
#     overlay = overlay_img.convert("RGBA")
    
#     # --- MODIFICATION ---
#     # Calculate the target width to be 1/5th of the base image's width
#     target_width = base.width // 5
    
#     # Keep aspect ratio, resize overlay to the newly calculated target width
#     w, h = overlay.size
    
#     # Add a check to prevent division by zero if the overlay image has no width
#     if w == 0:
#         return base
        
#     new_height = int(h * (target_width / w))
#     overlay = overlay.resize((target_width, new_height), Image.LANCZOS)

#     # Position: top-right corner with a margin
#     x = base.width - overlay.width - margin
#     y = margin

#     # Paste the resized overlay onto the base image using its alpha channel for transparency
#     base.paste(overlay, (x, y), overlay)
#     return base



@spaces.GPU(duration=45)
def infer(input_image_upload, prompt="wear it", seed=42, randomize_seed=False, guidance_scale=2.5, steps=28, progress=gr.Progress(track_tqdm=True)):
    """
    Perform image editing using the FLUX.1 Kontext pipeline.
    
    This function takes an input image and a text prompt to generate a modified version
    of the image based on the provided instructions. It uses the FLUX.1 Kontext model
    for contextual image editing tasks.
    
    Args:
        input_image_upload (PIL.Image.Image): The image from the gr.Image upload component (may be None).
        prompt (str): Text description of the desired edit to apply to the image.
        seed (int, optional): Random seed for reproducible generation.
        randomize_seed (bool, optional): If True, replaces `seed` with a random one.
        guidance_scale (float, optional): Controls how closely the model follows the prompt.
        steps (int, optional): Number of denoising steps to run.
        progress (gr.Progress, optional): Gradio progress tracker.
    
    Returns:
        tuple: A 3-tuple containing the result image, the seed used, and a gr.Button update that hides the reuse button.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # The gr.Image upload component yields either a PIL image or None, so it
    # can be used directly as the working input image.
    processed_input_image = input_image_upload

    if processed_input_image is not None:
            
        processed_input_image = processed_input_image.convert("RGB")
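
        # FLUX pipelines expect width/height divisible by 16 and warn and
        # round otherwise; a stricter variant (sketch) would snap the upload
        # size up front:
        # w, h = (dim - dim % 16 for dim in processed_input_image.size)
        # processed_input_image = processed_input_image.resize((w, h))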
        image = pipe(
            image=processed_input_image,
            prompt=prompt,
            guidance_scale=guidance_scale,
            width=processed_input_image.size[0],
            height=processed_input_image.size[1],
            num_inference_steps=steps,
            generator=torch.Generator().manual_seed(seed),
        ).images[0]
    else:
        # Handle the text-to-image case where no input image was provided.
        image = pipe(
            prompt=prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=steps,
            generator=torch.Generator().manual_seed(seed),
        ).images[0]
        
    return image, seed, gr.Button(visible=False)
    
@spaces.GPU
def infer_example(input_image, prompt):
    image, seed, _ = infer(input_image, prompt)
    return image, seed

# css="""
# #col-container {
#     margin: 0 auto;
#     max-width: 960px;
# }
# """

css=""

with gr.Blocks(css=css) as demo:
    
    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"""# FLUX.1 Kontext [dev] + [Virtual Try-On LoRA](https://huggingface.co/ovi054/virtual-tryon-kontext-lora)
        """)
        with gr.Row():
            with gr.Column():
                gr.Markdown("""Step 1. Select/Upload the combined model and garment image ⬇️<br>
Place the garment onto the model image as an overlay using [this tool](https://v0-image-editor-app-eight.vercel.app/).
""")
                # input_image = gr.Image(label="Upload Image", type="pil")
                with gr.Row():
                    input_image_upload = gr.Image(label="Upload Image", type="pil")
                gr.Examples(
                    examples=[[img] for img in BASE_EXAMPLES],
                    inputs=[input_image_upload],
                )

            # with gr.Column():
            #     gr.Markdown("Step 2.  Select/Upload a face photo ⬇️")
            #     with gr.Row():
            #         overlay_image = gr.Image(label="Upload face photo", type="pil")
            #     gr.Examples(
            #         examples=[[img] for img in FACE_EXAMPLES],
            #         inputs=[overlay_image],
            #     )
                    
            with gr.Column():
                gr.Markdown("Step 2.  Press “Run” to get results ⬇️")
                with gr.Row():
                    run_button = gr.Button("Run")
                with gr.Accordion("Advanced Settings", open=False):

                    prompt = gr.Text(
                        label="Prompt",
                        max_lines=1,
                        value = "wear it",
                        placeholder="Enter your prompt for editing (e.g., 'Remove glasses', 'Add a hat')",
                        container=False,
                    )
                    
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )
                    
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                    
                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=1,
                        maximum=10,
                        step=0.1,
                        value=2.5,
                    )       
                    
                    steps = gr.Slider(
                        label="Steps",
                        minimum=1,
                        maximum=30,
                        value=28,
                        step=1
                    )
                result = gr.Image(label="Result", show_label=False, interactive=False)
                result_input = gr.Image(label="Result", visible=False, show_label=False, interactive=False)
                reuse_button = gr.Button("Reuse this image", visible=False)
        
            
        # examples = gr.Examples(
        #     examples=[
        #         ["flowers.png", "turn the flowers into sunflowers"],
        #         ["monster.png", "make this monster ride a skateboard on the beach"],
        #         ["cat.png", "make this cat happy"]
        #     ],
        #     inputs=[input_image_upload, prompt],
        #     outputs=[result, seed],
        #     fn=infer_example,
        #     cache_examples="lazy"
        # )
            
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[input_image_upload, prompt, seed, randomize_seed, guidance_scale, steps],
        outputs=[result, seed, reuse_button]
    )
    # reuse_button.click(
    #     fn = lambda image: image,
    #     inputs = [result],
    #     outputs = [input_image]
    # )

demo.launch(mcp_server=True)
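
# Note: mcp_server=True additionally exposes the app's functions over the
# Model Context Protocol, which assumes a Gradio build with MCP support
# installed (the gradio[mcp] extra).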