import gradio as gr
import spaces
import torch
from diffusers import DiffusionPipeline
import gc
from pipeline import Flex2Pipeline

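# Cached pipeline instance, shared across requests so the model is only loaded once.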
pipe = None


@spaces.GPU
def load_model(model_id="ostris/Flex.2-preview", device="cuda"):
    """Load and cache the model to avoid reloading for each inference"""
    global pipe

    if pipe is None:
        print(f"Loading {model_id}...")
        try:
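            # Load the base pipeline once and reuse its components to assemble the Flex2Pipeline.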
            components = DiffusionPipeline.from_pretrained(
                model_id,
                torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
            ).components

            pipe = Flex2Pipeline(
                scheduler=components["scheduler"],
                vae=components["vae"],
                text_encoder=components["text_encoder"],
                tokenizer=components["tokenizer"],
                text_encoder_2=components["text_encoder_2"],
                tokenizer_2=components["tokenizer_2"],
                transformer=components["transformer"],
            )

            pipe = pipe.to(device)
            print("Model loaded successfully!")
        except Exception as e:
            print(f"Error loading model: {e}")
            return None

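        # Allow TF32 matmul/cuDNN kernels for faster inference on Ampere and newer GPUs.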
        if torch.cuda.is_available():
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True

    return pipe


def clear_gpu_memory():
    """Clear GPU memory"""
    global pipe
    if pipe is not None:
        del pipe
        pipe = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    return "GPU memory cleared"


@spaces.GPU
def generate_image(
    prompt,
    prompt_2=None,
    inpaint_image=None,
    inpaint_mask=None,
    control_image=None,
    control_strength=1.0,
    control_stop=1.0,
    height=1024,
    width=1024,
    num_inference_steps=28,
    guidance_scale=3.5,
    seed=-1,
    progress=gr.Progress()
):
    """Generate image using Flex2Pipeline"""
    global pipe

    pipe = load_model()
    if pipe is None:
        return None, "Error: Failed to load the model. Please check logs."

    # The seed may arrive as a float from the UI, so cast to int before seeding the generator.
    seed = int(seed)
    if seed == -1:
        # No seed supplied: draw a random one so it can still be reported back to the user.
        seed = torch.randint(0, 2**32 - 1, (1,)).item()
    generator = torch.Generator("cuda" if torch.cuda.is_available() else "cpu").manual_seed(seed)

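    # Report denoising progress to the Gradio progress bar after each step.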
    def callback_on_step_end(pipe, i, t, callback_kwargs):
        progress((i + 1) / pipe._num_timesteps)
        return callback_kwargs

    try:
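        # Run the Flex.2 pipeline; control and inpainting inputs are optional and may be None.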
        output = pipe(
            prompt=prompt,
            prompt_2=prompt_2 if prompt_2 and prompt_2.strip() else None,
            inpaint_image=inpaint_image,
            inpaint_mask=inpaint_mask,
            control_image=control_image,
            control_strength=float(control_strength),
            control_stop=float(control_stop),
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator,
            callback_on_step_end=callback_on_step_end,
        )

        return output.images[0], f"Successfully generated image with seed: {seed}"

    except Exception as e:
        error_message = f"Error during image generation: {str(e)}"
        print(error_message)
        return None, error_message


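# Gradio UI: prompt inputs, generation settings, optional control/inpainting images, and the output view.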
with gr.Blocks() as demo:
    gr.Markdown("# Flex.2 generator by [JustLab.ai](https://justlab.ai)")

    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...", lines=3)
            prompt_2 = gr.Textbox(label="Secondary Prompt (Optional)", placeholder="Optional secondary prompt...", lines=2)

            with gr.Accordion("Image Settings", open=True):
                with gr.Row():
                    height = gr.Slider(minimum=256, maximum=1536, value=1024, step=64, label="Height")
                    width = gr.Slider(minimum=256, maximum=1536, value=1024, step=64, label="Width")
                with gr.Row():
                    num_inference_steps = gr.Slider(minimum=1, maximum=100, value=28, step=1, label="Inference Steps")
                    guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, value=3.5, step=0.1, label="Guidance Scale")
                seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)

            with gr.Accordion("Control Settings (You can use LoRA-generated images)", open=False):
                control_image = gr.Image(label="Control Image (Optional)", type="pil")
                with gr.Row():
                    control_strength = gr.Slider(minimum=0.0, maximum=2.0, value=1.0, step=0.05, label="Control Strength")
                    control_stop = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, step=0.05, label="Control Stop")

            with gr.Accordion("Inpainting Settings", open=False):
                inpaint_image = gr.Image(label="Initial Image for Inpainting", type="pil")
                inpaint_mask = gr.Image(label="Mask Image (White areas will be inpainted)", type="pil")

            generate_button = gr.Button("Generate Image", variant="primary")
            status_message = gr.Textbox(label="Status", interactive=False)

        with gr.Column(scale=1):
            output_image = gr.Image(label="Generated Image")

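    # Wire the generate button to the inference function.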
    generate_button.click(
        fn=generate_image,
        inputs=[
            prompt, prompt_2, inpaint_image, inpaint_mask, control_image,
            control_strength, control_stop, height, width,
            num_inference_steps, guidance_scale, seed
        ],
        outputs=[output_image, status_message]
    )

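    # Prefilled example prompts; values follow the same order as the inputs list above.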
    gr.Examples(
        [
            ["A beautiful landscape with mountains and a lake", None, None, None, None, 1.0, 1.0, 1024, 1024, 28, 3.5, 42],
            ["A cyberpunk cityscape at night with neon lights", "High quality, detailed", None, None, None, 1.0, 1.0, 1024, 1024, 28, 7.0, 1234],
        ],
        fn=generate_image,
        inputs=[
            prompt, prompt_2, inpaint_image, inpaint_mask, control_image,
            control_strength, control_stop, height, width,
            num_inference_steps, guidance_scale, seed
        ],
        outputs=[output_image, status_message],
    )


demo.queue().launch(show_api=False)