import gc

import gradio as gr
import torch
from diffusers import DiffusionPipeline

from pipeline import Flex2Pipeline

# Global variable that caches the loaded pipeline between calls
pipe = None


def load_model(model_id="ostris/Flex.2-preview", device=None):
    """Load and cache the model to avoid reloading for each inference."""
    global pipe
    # Fall back to CPU when CUDA is unavailable rather than failing on .to("cuda")
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if pipe is None:
        print(f"Loading {model_id}...")
        try:
            # Load the model components directly using DiffusionPipeline.
            # This avoids relying on custom_pipeline, which was causing issues.
            components = DiffusionPipeline.from_pretrained(
                model_id,
                torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            ).components

            # Build our custom Flex2Pipeline from the loaded components
            pipe = Flex2Pipeline(
                scheduler=components["scheduler"],
                vae=components["vae"],
                text_encoder=components["text_encoder"],
                tokenizer=components["tokenizer"],
                text_encoder_2=components["text_encoder_2"],
                tokenizer_2=components["tokenizer_2"],
                transformer=components["transformer"],
            )

            # Move to the target device
            pipe = pipe.to(device)
            print("Model loaded successfully!")
        except Exception as e:
            print(f"Error loading model: {e}")
            return None

        # Enable TF32 precision if available
        if torch.cuda.is_available():
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True

    return pipe


def clear_gpu_memory():
    """Release the pipeline and free cached GPU memory."""
    global pipe
    if pipe is not None:
        del pipe
        pipe = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    return "GPU memory cleared"


def generate_image(
    prompt,
    prompt_2=None,
    inpaint_image=None,
    inpaint_mask=None,
    control_image=None,
    control_strength=1.0,
    control_stop=1.0,
    height=1024,
    width=1024,
    num_inference_steps=28,
    guidance_scale=3.5,
    seed=-1,
    progress=gr.Progress(),
):
    """Generate an image using the Flex2Pipeline."""
    global pipe

    # Load the model if it is not already cached
    pipe = load_model()
    if pipe is None:
        return None, "Error: Failed to load the model. Please check logs."
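    # Gradio sliders and Number inputs can deliver floats, while the pipeline and
    # torch.Generator expect ints, so coerce defensively here. (Added as a
    # defensive assumption about how values arrive from the UI.)
    height = int(height)
    width = int(width)
    num_inference_steps = int(num_inference_steps)
    seed = int(seed)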
    # Prepare a generator for deterministic output
    if seed != -1:
        generator = torch.Generator("cuda" if torch.cuda.is_available() else "cpu").manual_seed(seed)
    else:
        # Draw a random seed so it can be reported back to the user
        seed = torch.randint(0, 2**32 - 1, (1,)).item()
        generator = torch.Generator("cuda" if torch.cuda.is_available() else "cpu").manual_seed(seed)

    # Callback for per-step progress updates in the UI; diffusers-style pipelines
    # record _num_timesteps on themselves during __call__
    def callback_on_step_end(pipe, i, t, callback_kwargs):
        progress((i + 1) / pipe._num_timesteps)
        return callback_kwargs

    try:
        # Run the pipeline
        output = pipe(
            prompt=prompt,
            prompt_2=prompt_2 if prompt_2 and prompt_2.strip() else None,
            inpaint_image=inpaint_image,
            inpaint_mask=inpaint_mask,
            control_image=control_image,
            control_strength=float(control_strength),
            control_stop=float(control_stop),
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator,
            callback_on_step_end=callback_on_step_end,
        )
        # Return the generated image and a success message
        return output.images[0], f"Successfully generated image with seed: {seed}"
    except Exception as e:
        error_message = f"Error during image generation: {str(e)}"
        print(error_message)
        return None, error_message


# Build the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Flex.2 Image Generator")

    with gr.Row():
        with gr.Column(scale=1):
            # Input parameters
            prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...", lines=3)
            prompt_2 = gr.Textbox(label="Secondary Prompt (Optional)", placeholder="Optional secondary prompt...", lines=2)

            with gr.Accordion("Image Settings", open=True):
                with gr.Row():
                    height = gr.Slider(minimum=256, maximum=1536, value=1024, step=64, label="Height")
                    width = gr.Slider(minimum=256, maximum=1536, value=1024, step=64, label="Width")
                with gr.Row():
                    num_inference_steps = gr.Slider(minimum=1, maximum=100, value=28, step=1, label="Inference Steps")
                    guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, value=3.5, step=0.1, label="Guidance Scale")
                seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)

            with gr.Accordion("Control Settings", open=False):
                control_image = gr.Image(label="Control Image (Optional)", type="pil")
                with gr.Row():
                    control_strength = gr.Slider(minimum=0.0, maximum=2.0, value=1.0, step=0.05, label="Control Strength")
                    control_stop = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, step=0.05, label="Control Stop")

            with gr.Accordion("Inpainting Settings", open=False):
                inpaint_image = gr.Image(label="Initial Image for Inpainting", type="pil")
                inpaint_mask = gr.Image(label="Mask Image (White areas will be inpainted)", type="pil")

            # Generate button
            generate_button = gr.Button("Generate Image", variant="primary")
            # Clear GPU memory button
            clear_button = gr.Button("Clear GPU Memory")
            # Status message
            status_message = gr.Textbox(label="Status", interactive=False)

        with gr.Column(scale=1):
            # Output image
            output_image = gr.Image(label="Generated Image")

    # Wire the buttons to their handlers
    generate_button.click(
        fn=generate_image,
        inputs=[
            prompt, prompt_2, inpaint_image, inpaint_mask, control_image,
            control_strength, control_stop, height, width,
            num_inference_steps, guidance_scale, seed,
        ],
        outputs=[output_image, status_message],
    )
    clear_button.click(fn=clear_gpu_memory, outputs=status_message)

    # Example prompts
    gr.Examples(
        [
            ["A beautiful landscape with mountains and a lake", None, None, None, None, 1.0, 1.0, 1024, 1024, 28, 3.5, 42],
            ["A cyberpunk cityscape at night with neon lights", "High quality, detailed", None, None, None, 1.0, 1.0, 1024, 1024, 28, 7.0, 1234],
        ],
        fn=generate_image,
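        # Examples only populate the inputs when clicked; with cache_examples=False
        # (added here as an explicit choice) no outputs are precomputed at startup,
        # which avoids running the heavy pipeline during launch.
        cache_examples=False,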
        inputs=[
            prompt, prompt_2, inpaint_image, inpaint_mask, control_image,
            control_strength, control_stop, height, width,
            num_inference_steps, guidance_scale, seed,
        ],
        outputs=[output_image, status_message],
    )

# Launch the app with the request queue enabled so generations run one at a time
if __name__ == "__main__":
    # Note: concurrency_count is the Gradio 3.x argument; on Gradio 4.x use
    # demo.queue(default_concurrency_limit=1) instead.
    demo.queue(concurrency_count=1).launch(share=False)
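# Usage (assuming this file is saved as app.py next to pipeline.py, which
# provides Flex2Pipeline):
#   python app.py
# The UI is then served at Gradio's default address, http://127.0.0.1:7860.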