import gradio as gr
import torch
from diffusers import DiffusionPipeline
import gc
from pipeline import Flex2Pipeline
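# NOTE: `pipeline` above refers to the local pipeline.py assumed to sit next to
# this script (it defines the custom Flex2Pipeline class); it is not a PyPI package.
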
# Global variable to store the pipeline
pipe = None


def load_model(model_id="ostris/Flex.2-preview", device=None):
    """Load and cache the model to avoid reloading it for each inference."""
    global pipe
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if pipe is None:
        print(f"Loading {model_id}...")
        try:
            # Load the model components directly using DiffusionPipeline.
            # This avoids the custom_pipeline mechanism, which causes issues here.
            components = DiffusionPipeline.from_pretrained(
                model_id,
                torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            ).components
            # Assemble our custom Flex2Pipeline from those components
            pipe = Flex2Pipeline(
                scheduler=components["scheduler"],
                vae=components["vae"],
                text_encoder=components["text_encoder"],
                tokenizer=components["tokenizer"],
                text_encoder_2=components["text_encoder_2"],
                tokenizer_2=components["tokenizer_2"],
                transformer=components["transformer"],
            )
            # Move to the target device
            pipe = pipe.to(device)
            print("Model loaded successfully!")
        except Exception as e:
            print(f"Error loading model: {e}")
            return None
        # Enable TF32 matmuls where available: faster on Ampere+ GPUs
        # at slightly reduced precision
        if torch.cuda.is_available():
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
    return pipe


def clear_gpu_memory():
    """Clear GPU memory"""
    global pipe
    if pipe is not None:
        del pipe
        pipe = None
    gc.collect()
    if torch.cuda.is_available():
        # empty_cache() only returns CUDA blocks with no live tensor references,
        # so the `del pipe` above is what actually releases the model weights
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    return "GPU memory cleared"


def generate_image(
    prompt,
    prompt_2=None,
    inpaint_image=None,
    inpaint_mask=None,
    control_image=None,
    control_strength=1.0,
    control_stop=1.0,
    height=1024,
    width=1024,
    num_inference_steps=28,
    guidance_scale=3.5,
    seed=-1,
    progress=gr.Progress(),
):
"""Generate image using Flex2Pipeline"""
global pipe
# Load model if not already loaded
pipe = load_model()
if pipe is None:
return None, "Error: Failed to load the model. Please check logs."
# Prepare generator for deterministic output
generator = None
if seed != -1:
generator = torch.Generator("cuda" if torch.cuda.is_available() else "cpu").manual_seed(seed)
else:
# Generate a random seed
seed = torch.randint(0, 2**32 - 1, (1,)).item()
generator = torch.Generator("cuda" if torch.cuda.is_available() else "cpu").manual_seed(seed)
# Create callback for progress updates
def callback_on_step_end(pipe, i, t, callback_kwargs):
progress((i + 1) / pipe._num_timesteps)
return callback_kwargs
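    # Note: this follows the diffusers step-end callback signature
    # (pipeline, step_index, timestep, callback_kwargs). _num_timesteps is a
    # private attribute that diffusers pipelines set during __call__, so it
    # may change across diffusers releases.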

    try:
        # Run the pipeline
        output = pipe(
            prompt=prompt,
            prompt_2=prompt_2 if prompt_2 and prompt_2.strip() else None,
            inpaint_image=inpaint_image,
            inpaint_mask=inpaint_mask,
            control_image=control_image,
            control_strength=float(control_strength),
            control_stop=float(control_stop),
            # Gradio sliders can deliver floats, so coerce the integer parameters
            height=int(height),
            width=int(width),
            num_inference_steps=int(num_inference_steps),
            guidance_scale=float(guidance_scale),
            generator=generator,
            callback_on_step_end=callback_on_step_end,
        )
        # Return the generated image and a success message
        return output.images[0], f"Successfully generated image with seed: {seed}"
    except Exception as e:
        error_message = f"Error during image generation: {str(e)}"
        print(error_message)
        return None, error_message


# Create Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# Flex.2 Image Generator")
    with gr.Row():
        with gr.Column(scale=1):
            # Input parameters
            prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...", lines=3)
            prompt_2 = gr.Textbox(label="Secondary Prompt (Optional)", placeholder="Optional secondary prompt...", lines=2)
            with gr.Accordion("Image Settings", open=True):
                with gr.Row():
                    height = gr.Slider(minimum=256, maximum=1536, value=1024, step=64, label="Height")
                    width = gr.Slider(minimum=256, maximum=1536, value=1024, step=64, label="Width")
                with gr.Row():
                    num_inference_steps = gr.Slider(minimum=1, maximum=100, value=28, step=1, label="Inference Steps")
                    guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, value=3.5, step=0.1, label="Guidance Scale")
                seed = gr.Number(label="Seed (-1 for random)", value=-1)
            with gr.Accordion("Control Settings", open=False):
                control_image = gr.Image(label="Control Image (Optional)", type="pil")
                with gr.Row():
                    control_strength = gr.Slider(minimum=0.0, maximum=2.0, value=1.0, step=0.05, label="Control Strength")
                    control_stop = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, step=0.05, label="Control Stop")
            with gr.Accordion("Inpainting Settings", open=False):
                inpaint_image = gr.Image(label="Initial Image for Inpainting", type="pil")
                inpaint_mask = gr.Image(label="Mask Image (White areas will be inpainted)", type="pil")
            # Generate button
            generate_button = gr.Button("Generate Image", variant="primary")
            # Clear GPU memory button
            clear_button = gr.Button("Clear GPU Memory")
            # Status message
            status_message = gr.Textbox(label="Status", interactive=False)
        with gr.Column(scale=1):
            # Output image
            output_image = gr.Image(label="Generated Image")

    # Connect buttons to functions
    generate_button.click(
        fn=generate_image,
        inputs=[
            prompt, prompt_2, inpaint_image, inpaint_mask, control_image,
            control_strength, control_stop, height, width,
            num_inference_steps, guidance_scale, seed
        ],
        outputs=[output_image, status_message]
    )
    clear_button.click(fn=clear_gpu_memory, outputs=status_message)

    # Examples
    gr.Examples(
        [
            ["A beautiful landscape with mountains and a lake", None, None, None, None, 1.0, 1.0, 1024, 1024, 28, 3.5, 42],
            ["A cyberpunk cityscape at night with neon lights", "High quality, detailed", None, None, None, 1.0, 1.0, 1024, 1024, 28, 7.0, 1234],
        ],
        fn=generate_image,
        inputs=[
            prompt, prompt_2, inpaint_image, inpaint_mask, control_image,
            control_strength, control_stop, height, width,
            num_inference_steps, guidance_scale, seed
        ],
        outputs=[output_image, status_message],
    )
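
    # Note: with the default cache_examples=False, clicking an example only
    # fills in the inputs; fn is executed on click only if run_on_click=True
    # is passed or the examples are cached.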


# Launch the app with the queue enabled
if __name__ == "__main__":
    # concurrency_count is Gradio 3.x API; on Gradio 4+,
    # use demo.queue(default_concurrency_limit=1) instead
    demo.queue(concurrency_count=1).launch(share=False)
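
# Usage sketch (assumptions: this file is saved as app.py with the custom
# pipeline.py beside it; the package list is a starting point, not pinned):
#   pip install gradio torch diffusers transformers accelerate sentencepiece protobuf
#   python app.py
# Pass share=True to launch() if you need a public Gradio link.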