# imagen2/app_hf.py
import gradio as gr
import spaces
import torch
from diffusers import DiffusionPipeline
import gc
from pipeline import Flex2Pipeline
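# Flex2Pipeline is assumed to come from a local pipeline.py shipped alongside
# this file; it follows the diffusers pipeline interface and adds the
# control/inpaint arguments used below.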
# Global variable to store the pipeline
pipe = None
@spaces.GPU
def load_model(model_id="ostris/Flex.2-preview", device="cuda"):
    """Load and cache the model to avoid reloading it for each inference."""
    global pipe
    if pipe is None:
        print(f"Loading {model_id}...")
        try:
            # Load the model components directly using DiffusionPipeline.
            # This avoids relying on custom_pipeline, which was causing issues.
            components = DiffusionPipeline.from_pretrained(
                model_id,
                torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            ).components

            # Create our custom Flex2Pipeline from the loaded components
            pipe = Flex2Pipeline(
                scheduler=components["scheduler"],
                vae=components["vae"],
                text_encoder=components["text_encoder"],
                tokenizer=components["tokenizer"],
                text_encoder_2=components["text_encoder_2"],
                tokenizer_2=components["tokenizer_2"],
                transformer=components["transformer"],
            )

            # Move the pipeline to the target device
            pipe = pipe.to(device)
            print("Model loaded successfully!")
        except Exception as e:
            print(f"Error loading model: {e}")
            return None

    # Enable TF32 precision if a CUDA device is available
    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    return pipe
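# Note: on ZeroGPU Spaces, functions decorated with @spaces.GPU only have a CUDA
# device available while they execute, which is why the pipeline is loaded
# lazily inside the decorated functions rather than at import time.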
def clear_gpu_memory():
    """Release the cached pipeline and clear GPU memory."""
    global pipe
    if pipe is not None:
        del pipe
        pipe = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    return "GPU memory cleared"
@spaces.GPU
def generate_image(
    prompt,
    prompt_2=None,
    inpaint_image=None,
    inpaint_mask=None,
    control_image=None,
    control_strength=1.0,
    control_stop=1.0,
    height=1024,
    width=1024,
    num_inference_steps=28,
    guidance_scale=3.5,
    seed=-1,
    progress=gr.Progress()
):
    """Generate an image using Flex2Pipeline."""
    global pipe

    # Load the model if it is not already loaded
    pipe = load_model()
    if pipe is None:
        return None, "Error: Failed to load the model. Please check the logs."

    # Gradio components may deliver numeric values as floats; cast to int
    # where the pipeline and torch expect integers.
    height = int(height)
    width = int(width)
    num_inference_steps = int(num_inference_steps)
    seed = int(seed)

    # Prepare a generator for deterministic output
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if seed == -1:
        # Draw a random seed so it can be reported back to the user
        seed = torch.randint(0, 2**32 - 1, (1,)).item()
    generator = torch.Generator(device).manual_seed(seed)

    # Callback for progress updates in the Gradio UI
    def callback_on_step_end(pipe, i, t, callback_kwargs):
        progress((i + 1) / pipe._num_timesteps)
        return callback_kwargs

    try:
        # Run the pipeline
        output = pipe(
            prompt=prompt,
            prompt_2=prompt_2 if prompt_2 and prompt_2.strip() else None,
            inpaint_image=inpaint_image,
            inpaint_mask=inpaint_mask,
            control_image=control_image,
            control_strength=float(control_strength),
            control_stop=float(control_stop),
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator,
            callback_on_step_end=callback_on_step_end,
        )
        # Return the generated image and a success message
        return output.images[0], f"Successfully generated image with seed: {seed}"
    except Exception as e:
        error_message = f"Error during image generation: {str(e)}"
        print(error_message)
        return None, error_message
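# Example (sketch): calling the generation function directly, without the UI.
# This assumes the default gr.Progress() is tolerated outside a Gradio event
# and that the optional control/inpaint inputs can simply be omitted.
#
#   image, status = generate_image(
#       prompt="A beautiful landscape with mountains and a lake",
#       height=1024, width=1024,
#       num_inference_steps=28, guidance_scale=3.5, seed=42,
#   )
#   image.save("landscape.png")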
# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Flex.2 generator by [JustLab.ai](https://justlab.ai)")

    with gr.Row():
        with gr.Column(scale=1):
            # Input parameters
            prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...", lines=3)
            prompt_2 = gr.Textbox(label="Secondary Prompt (Optional)", placeholder="Optional secondary prompt...", lines=2)

            with gr.Accordion("Image Settings", open=True):
                with gr.Row():
                    height = gr.Slider(minimum=256, maximum=1536, value=1024, step=64, label="Height")
                    width = gr.Slider(minimum=256, maximum=1536, value=1024, step=64, label="Width")
                with gr.Row():
                    num_inference_steps = gr.Slider(minimum=1, maximum=100, value=28, step=1, label="Inference Steps")
                    guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, value=3.5, step=0.1, label="Guidance Scale")
                seed = gr.Number(label="Seed (-1 for random)", value=-1)

            with gr.Accordion("Control Settings (You can use LoRA generated images.)", open=False):
                control_image = gr.Image(label="Control Image (Optional)", type="pil")
                with gr.Row():
                    control_strength = gr.Slider(minimum=0.0, maximum=2.0, value=1.0, step=0.05, label="Control Strength")
                    control_stop = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, step=0.05, label="Control Stop")

            with gr.Accordion("Inpainting Settings", open=False):
                inpaint_image = gr.Image(label="Initial Image for Inpainting", type="pil")
                inpaint_mask = gr.Image(label="Mask Image (White areas will be inpainted)", type="pil")

            # Generate button
            generate_button = gr.Button("Generate Image", variant="primary")

            # Status message
            status_message = gr.Textbox(label="Status", interactive=False)

        with gr.Column(scale=1):
            # Output image
            output_image = gr.Image(label="Generated Image")
    # Connect the button to the generation function
    generate_button.click(
        fn=generate_image,
        inputs=[
            prompt, prompt_2, inpaint_image, inpaint_mask, control_image,
            control_strength, control_stop, height, width,
            num_inference_steps, guidance_scale, seed
        ],
        outputs=[output_image, status_message]
    )

    # Examples
    gr.Examples(
        [
            ["A beautiful landscape with mountains and a lake", None, None, None, None, 1.0, 1.0, 1024, 1024, 28, 3.5, 42],
            ["A cyberpunk cityscape at night with neon lights", "High quality, detailed", None, None, None, 1.0, 1.0, 1024, 1024, 28, 7.0, 1234],
        ],
        fn=generate_image,
        inputs=[
            prompt, prompt_2, inpaint_image, inpaint_mask, control_image,
            control_strength, control_stop, height, width,
            num_inference_steps, guidance_scale, seed
        ],
        outputs=[output_image, status_message],
    )
# Configure for HF Spaces
# Disable API endpoints
demo.queue().launch(show_api=False)
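# To run this demo outside of HF Spaces, launch it locally (assuming gradio,
# spaces, torch, diffusers, and the local pipeline.py are available):
#
#   python app_hf.py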