from functools import partial
from typing import Any, Optional

import gradio as gr
from PIL import Image

from inference import generate_image

# On-screen size (height and width) of the coordinate selector and the output
# image. Also the fallback normalization size when the clicked image's real
# pixel dimensions are unavailable.
COORD_SELECTOR_SIZE = 400


def _image_size(image: Any) -> tuple[int, int]:
    """Return ``(width, height)`` in pixels of a Gradio image value.

    Accepts a PIL image, or an array-like whose ``shape`` starts with
    ``(height, width)`` (Gradio's default numpy image format). Falls back to
    ``(COORD_SELECTOR_SIZE, COORD_SELECTOR_SIZE)`` when the size cannot be
    determined or is degenerate, so callers never divide by zero.
    """
    if isinstance(image, Image.Image):
        width, height = image.size
    else:
        shape = getattr(image, "shape", None)
        if shape is None or len(shape) < 2:
            return COORD_SELECTOR_SIZE, COORD_SELECTOR_SIZE
        height, width = shape[0], shape[1]
    if width <= 0 or height <= 0:
        return COORD_SELECTOR_SIZE, COORD_SELECTOR_SIZE
    return int(width), int(height)


def process_coord_click(
    image_idx: int, evt: gr.SelectData, heatmap: Any = None
) -> Image.Image:
    """
    Process a click on the coordinate selector.

    ``evt.index`` holds the clicked ``[x, y]`` pixel. NOTE(review): Gradio
    reports it in the image's own resolution rather than its on-screen size;
    confirm for the installed Gradio version. The pixel is divided by the
    clicked image's width/height, so the coordinates handed to
    ``generate_image`` are normalized to roughly [0, 1] whatever the heatmap
    file's resolution.

    Args:
        image_idx: Index of the currently selected reference image.
        evt: Gradio select event carrying the clicked pixel in ``evt.index``.
        heatmap: Current value of the coordinate-selector image. It is used
            only to read its dimensions. When omitted or unreadable, a
            ``COORD_SELECTOR_SIZE`` x ``COORD_SELECTOR_SIZE`` image is assumed
            (the previous hard-coded behavior).

    Returns:
        The result of ``generate_image(image_idx, x, y)``.
    """
    width, height = _image_size(heatmap)
    x = evt.index[0] / width
    y = evt.index[1] / height
    print(f"Clicked at coordinates: ({x:.3f}, {y:.3f})")
    return generate_image(image_idx, x, y)


def process_image_select(
    evt: Optional[gr.SelectData] = None, idx: int = 0
) -> tuple[int, str]:
    """
    Process a click on one of the reference images.

    Args:
        evt: Gradio select event; unused. It defaults to ``None`` because this
            handler is registered via ``functools.partial``, and Gradio may
            not be able to read a ``SelectData`` type hint through a partial
            object. NOTE(review): if it cannot, it calls the partial with no
            arguments, which a required ``evt`` would reject.
        idx: Index of the clicked reference image.

    Returns:
        ``(idx, path)``: the new selected index, and the path of the matching
        heatmap to show as the coordinate selector's background.
    """
    return idx, f"imgs/heatmap_{idx}.png"


with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Interactive Image Generation
        Click on a reference image to select it, then click on the coordinate selector to generate a new image.
        """
    )

    with gr.Row():
        # Left column: interactive reference images.
        with gr.Column(scale=1):
            # Index of the currently selected reference image.
            selected_idx = gr.State(value=0)

            # Two separate Image components for the reference images.
            with gr.Column():
                image_0 = gr.Image(
                    value="imgs/pattern_0.png",
                    label="Task 1",
                    show_label=False,
                    interactive=True,
                    height=300,
                    width=450,
                )
                image_1 = gr.Image(
                    value="imgs/pattern_1.png",
                    label="Task 2",
                    show_label=False,
                    interactive=True,
                    height=300,
                    width=450,
                )

        # Right column: coordinate selector and output image.
        with gr.Column(scale=1):
            # Coordinate selector whose background heatmap follows the
            # selected reference image.
            coord_selector = gr.Image(
                value="imgs/heatmap_0.png",  # Initial background
                label="Click to select (x, y) coordinates",
                show_label=True,
                interactive=True,
                height=COORD_SELECTOR_SIZE,
                width=COORD_SELECTOR_SIZE,
            )

            # Generated image output.
            output_image = gr.Image(
                label="Generated Output",
                height=COORD_SELECTOR_SIZE,
                width=COORD_SELECTOR_SIZE,
            )

    # Selecting a reference image updates the selected index and swaps the
    # coordinate selector's background heatmap.
    image_0.select(
        partial(process_image_select, idx=0),
        outputs=[selected_idx, coord_selector],
    )
    image_1.select(
        partial(process_image_select, idx=1),
        outputs=[selected_idx, coord_selector],
    )

    # Clicking the coordinate selector generates an image. The selector itself
    # is passed as an input only so the handler can normalize the click by the
    # heatmap's real pixel size.
    coord_selector.select(
        process_coord_click,
        inputs=[selected_idx, coord_selector],
        outputs=output_image,
        trigger_mode="multiple",
    )

demo.launch()