File size: 7,698 Bytes
291e480
 
 
 
 
 
 
61f0b81
291e480
 
 
 
61f0b81
291e480
 
 
 
 
 
 
 
61f0b81
 
 
 
291e480
 
61f0b81
291e480
 
 
 
61f0b81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291e480
 
 
 
 
61f0b81
 
 
 
291e480
 
61f0b81
 
291e480
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61f0b81
 
 
 
 
 
 
 
 
 
 
 
291e480
 
 
61f0b81
 
291e480
 
 
56d0f44
61f0b81
 
 
 
 
 
 
 
 
 
 
 
 
 
56d0f44
291e480
 
61f0b81
 
56d0f44
 
 
291e480
 
61f0b81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291e480
 
 
 
61f0b81
 
 
 
 
291e480
 
 
 
61f0b81
 
291e480
 
56d0f44
 
 
61f0b81
56d0f44
 
291e480
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import gradio as gr
import asyncio
import fal_client
from dotenv import load_dotenv
import os
from pathlib import Path
import time
import json

load_dotenv()

# Mirror FAL_API_KEY (for example from the .env file loaded above) into
# FAL_KEY, the variable fal_client presumably reads for credentials
# (NOTE(review): confirm the name against the fal_client documentation).
# os.environ only accepts strings, so assigning the None returned for an unset
# FAL_API_KEY would raise TypeError at import time. Copy only when a value is
# present so that an already-exported FAL_KEY keeps working.
_fal_api_key = os.getenv("FAL_API_KEY")
if _fal_api_key:
    os.environ["FAL_KEY"] = _fal_api_key

def _describe_hair(length, style, color):
    """Build the comma-terminated hair clause substituted for ``{hair_feature}``.

    "Bald" as the length short-circuits to ``"bald,"``. Empty values and the
    literal "None" choice are skipped so they never leak into the prompt
    (e.g. "Short None Black hair,"). Returns an empty string when nothing
    usable is left.
    """
    if length == "Bald":
        return "bald,"
    parts = [str(part) for part in (length, style, color) if part and part != "None"]
    if not parts:
        return ""
    return " ".join(parts) + " hair,"


def _collect_image_urls(result, output_id):
    """Return the image URLs stored under ``result["outputs"][output_id]``.

    Missing keys yield an empty list instead of raising, so a partial
    response still lets the caller return whatever was produced.
    """
    outputs = result.get("outputs") if isinstance(result, dict) else None
    node = (outputs or {}).get(output_id) or {}
    return [img["url"] for img in node.get("images") or [] if "url" in img]


async def generate_paris_images(product_name: str, image1_path: str, image2_path: str, woman_prompt: str, man_prompt: str, girl_name: str, girl_hair_length: str, girl_hair_style: str, girl_hair_color: str, boy_name: str, boy_hair_length: str, boy_hair_style: str, boy_hair_color: str, batch_size: int, progress=gr.Progress()):
    """Upload the input images, run the "comfy/LVE/paris-couple" job and collect its image URLs.

    Args:
        product_name: Not read by this function; kept so the signature stays
            compatible with existing callers.
        image1_path: Local path of the first image (sent as ``loadimage_1``).
        image2_path: Local path of the second image (sent as ``loadimage_2``).
        woman_prompt: Prompt text; every ``{hair_feature}`` placeholder is
            replaced with the girl's hair description.
        man_prompt: Prompt text; every ``{hair_feature}`` placeholder is
            replaced with the boy's hair description.
        girl_name: Forwarded unchanged as the ``girl_name`` argument.
        girl_hair_length: Hair length used for the woman's description.
        girl_hair_style: Hair style used for the woman's description.
        girl_hair_color: Hair color used for the woman's description.
        boy_name: Forwarded unchanged as the ``boy_name`` argument.
        boy_hair_length: Hair length for the man; "Bald" yields "bald,".
        boy_hair_style: Hair style for the man; "None" is omitted.
        boy_hair_color: Hair color for the man; "None" is omitted.
        batch_size: Forwarded unchanged as the ``batch_size`` argument.
        progress: Progress callback, invoked with a fraction and a description.

    Returns:
        A 3-tuple: the URLs found under output "215", the URLs found under
        output "818", and a "Processing time: N seconds" string.

    Raises:
        gr.Error: If either input image path is missing.
    """
    # A cleared image input arrives as None; str(None) would otherwise be
    # uploaded as the literal path "None" and fail with a confusing error.
    if not image1_path or not image2_path:
        raise gr.Error("Please provide both the woman image and the man image before generating.")

    start_time = time.time()
    print("Progress: 5% - Starting Paris image generation...")
    progress(0.05, desc="Starting Paris image generation...")

    # Upload the two user images and the four bundled templates concurrently.
    # The template paths are relative to the current working directory.
    (
        image1_url,
        image2_url,
        man_pose_img,
        woman_pose_img,
        woman_clip_mask,
        man_clip_mask,
    ) = await asyncio.gather(
        fal_client.upload_file_async(str(image1_path)),
        fal_client.upload_file_async(str(image2_path)),
        fal_client.upload_file_async("template/man_pose.png"),
        fal_client.upload_file_async("template/woman_pose.png"),
        fal_client.upload_file_async("template/woman_clip_mask.png"),
        fal_client.upload_file_async("template/man_clip_mask.png"),
    )

    print("Progress: 40% - Uploaded all images")
    progress(0.4, desc="Uploaded all images")

    # Replace the {hair_feature} placeholders with the selected hair descriptions.
    woman_hair_desc = _describe_hair(girl_hair_length, girl_hair_style, girl_hair_color)
    print(f"Final woman hair description: {woman_hair_desc}")
    man_hair_desc = _describe_hair(boy_hair_length, boy_hair_style, boy_hair_color)
    print(f"Final man hair description: {man_hair_desc}")

    woman_prompt = woman_prompt.replace("{hair_feature}", woman_hair_desc)
    man_prompt = man_prompt.replace("{hair_feature}", man_hair_desc)

    print(f"Final woman prompt: {woman_prompt}")
    print(f"Final man prompt: {man_prompt}")

    handler = await fal_client.submit_async(
        "comfy/LVE/paris-couple",
        arguments={
            "loadimage_1": image1_url,
            "loadimage_2": image2_url,
            "loadimage_3": woman_pose_img,
            "loadimage_4": woman_clip_mask,
            "loadimage_5": man_clip_mask,
            "loadimage_6": man_pose_img,
            "woman_prompt": woman_prompt,
            "man_prompt": man_prompt,
            "girl_name": girl_name,
            "boy_name": boy_name,
            "batch_size": batch_size,
        },
    )

    print("Progress: 60% - Processing images...")
    progress(0.6, desc="Processing images...")

    result = await handler.get()
    print(result)

    processing_time = time.time() - start_time
    print(f"Progress: 100% - Generation completed in {processing_time:.2f} seconds")
    progress(1.0, desc=f"Generation completed in {processing_time:.2f} seconds")

    # "215" and "818" are the output keys this job is expected to return
    # (presumably workflow node ids -- TODO confirm against the workflow).
    image_215 = _collect_image_urls(result, "215")
    image_818 = _collect_image_urls(result, "818")

    print(f"Image 215: {image_215}")
    print(f"Image 818: {image_818}")

    return (
        image_215,
        image_818,
        f"Processing time: {processing_time:.2f} seconds",
    )

def change_product_preview(product_name):
    """Look up the preview image path and prompts for a product.

    Reads ``prompt.json`` from the current working directory on every call.
    The file is expected to hold a list of objects with ``title``, ``woman``
    and ``man`` keys; the first entry whose ``title`` equals *product_name*
    is used.

    Args:
        product_name: Title to look up in ``prompt.json``.

    Returns:
        ``(thumbnail_path, woman_prompt, man_prompt)`` for a matching entry,
        where ``thumbnail_path`` is ``"thumbnail/<product_name>.png"``, or
        ``(None, "", "")`` when no entry matches.
    """
    # Decode explicitly as UTF-8 so prompt text with non-ASCII characters
    # does not depend on the platform's default encoding.
    with open('prompt.json', 'r', encoding='utf-8') as f:
        prompts = json.load(f)

    # First entry with a matching title, or None when there is none.
    prompt_data = next((item for item in prompts if item['title'] == product_name), None)

    if prompt_data is None:
        return None, "", ""

    return (
        f"thumbnail/{product_name}.png",
        prompt_data['woman'],
        prompt_data['man'],
    )

# Gradio UI definition. Components are created in display order; the two
# event bindings at the end connect them to the handler functions above.
with gr.Blocks() as demo:
    # Product selector next to its preview thumbnail.
    with gr.Row():
        product_name = gr.Dropdown(
            choices=["Winter", "Classy", "Night Out", "Romantic"],
            value="Winter",
            label="Product Name",
        )
        product_preview = gr.Image(
            value="thumbnail/Winter.png",
            type="filepath",
            label="Product Preview",
            height=500,
            width=500,
        )

    # Source photos, passed to the handler as file paths.
    with gr.Row():
        image1_input = gr.Image(value="user3-f.jpg", type="filepath", label="Upload Woman Image")
        image2_input = gr.Image(value="user3-m.jpg", type="filepath", label="Upload Man Image")

    # Per-person prompt, name and hair controls, one column each.
    with gr.Row():
        with gr.Column():
            woman_prompt = gr.Textbox(
                value="Close-up, portrait photo, a woman, {hair_feature} wearing a cream-colored wool coat, chunky knit scarf, and matching earmuffs, standing on the same snow-dusted cobblestone street, illuminated Eiffel Tower glowing golden in the background, snowflakes sparkling in the warm streetlight glow, looking at camera with gentle smile.",
                label="Woman Prompt",
            )
            girl_name = gr.Textbox(value="julie delpy", label="Girl Name")
            girl_hair_length = gr.Dropdown(
                choices=["Short", "Medium", "Long"],
                value="Long",
                label="Girl Hair Length",
            )
            girl_hair_style = gr.Dropdown(
                choices=["Straight", "Wavy", "Curly"],
                value="Straight",
                label="Girl Hair Style",
            )
            girl_hair_color = gr.Dropdown(
                choices=["Blonde", "Brown", "Black", "Brunette", "Redhead", "Bronde"],
                value="Bronde",
                label="Girl Hair Color",
            )
        with gr.Column():
            man_prompt = gr.Textbox(
                value="Close-up, portrait photo, a man, {hair_feature} wearing a dark navy wool peacoat, cashmere scarf, and leather gloves, standing on a snow-dusted cobblestone street, illuminated Eiffel Tower in the background glowing golden against the night sky, gentle snowflakes catching the warm glow of vintage streetlamps, looking at camera with confident expression.",
                label="Man Prompt",
            )
            boy_name = gr.Textbox(value="ethan hawke", label="Boy Name")
            boy_hair_length = gr.Dropdown(
                choices=["Short", "Medium", "Long", "Bald"],
                value="Short",
                label="Boy Hair Length",
            )
            boy_hair_style = gr.Dropdown(
                choices=["None", "Undercut", "Mullet", "French Crop", "Slicked Back", "Fade", "Buzz Cut"],
                value="Undercut",
                label="Boy Hair Style",
            )
            boy_hair_color = gr.Dropdown(
                choices=["None", "Blonde", "Brown", "Black", "Brunette", "Redhead"],
                value="Black",
                label="Boy Hair Color",
            )

    batch_size = gr.Slider(label="Batch Size", minimum=1, maximum=8, step=1, value=4)

    generate_btn = gr.Button("Generate")

    # Two galleries side by side, followed by the timing read-out.
    with gr.Row():
        image_output = gr.Gallery(label="Generated Image Raw")
        image_output_processed = gr.Gallery(label="Generated Image Final")

    time_output = gr.Textbox(label="Processing Time")

    # Input order must match the parameter order of generate_paris_images.
    generation_inputs = [
        product_name,
        image1_input,
        image2_input,
        woman_prompt,
        man_prompt,
        girl_name,
        girl_hair_length,
        girl_hair_style,
        girl_hair_color,
        boy_name,
        boy_hair_length,
        boy_hair_style,
        boy_hair_color,
        batch_size,
    ]
    generate_btn.click(
        fn=generate_paris_images,
        inputs=generation_inputs,
        outputs=[image_output, image_output_processed, time_output],
    )

    # Changing the product refreshes the preview and both prompt textboxes.
    product_name.change(
        fn=change_product_preview,
        inputs=[product_name],
        outputs=[product_preview, woman_prompt, man_prompt],
    )

# Launch the interface only when the file is run as a script.
if __name__ == "__main__":
    print("Starting Gradio interface...")
    demo.launch()