# Hugging Face Space status: Running on Zero (ZeroGPU)
import os
import subprocess
import sys

# Force a clean reinstall of torchvision BEFORE torch/torchvision are imported
# below, so the freshly installed build is the one that gets loaded.
# NOTE(review): presumably this works around a torchvision build that does not
# match the preinstalled torch on this Space — confirm it is still needed.
# NOTE(review): `--force-reinstall` also reinstalls torchvision's dependencies
# (including torch itself), which can change the torch version — confirm.
# `sys.executable -m pip` targets the interpreter running this app; a bare
# `pip` on PATH may belong to a different Python. Failures stay non-fatal
# (check=False), matching the previous os.system behaviour.
subprocess.run(
    [sys.executable, "-m", "pip", "uninstall", "torchvision", "-y"],
    check=False,
)
subprocess.run(
    [
        sys.executable, "-m", "pip", "install", "torchvision",
        "--force-reinstall", "--no-cache-dir",
    ],
    check=False,
)

import torch
from diffusers import AutoPipelineForText2Image
import gradio as gr
from PIL import Image
import spaces

# Load the Flex.2-preview pipeline once at import time, using the repository's
# custom pipeline code, in bfloat16, and move it to the CUDA device.
pipe = AutoPipelineForText2Image.from_pretrained(
    "ostris/Flex.2-preview",
    custom_pipeline="pipeline.py",
    torch_dtype=torch.bfloat16,
).to("cuda")
def _to_rgb(img: "Image.Image | None") -> "Image.Image | None":
    """Convert an optional PIL image to RGB, passing None through unchanged."""
    return None if img is None else img.convert("RGB")


# On ZeroGPU hardware a GPU is only attached while a @spaces.GPU-decorated
# function runs; the decorator has no effect on other hardware.
@spaces.GPU
def generate_image(
    prompt: str,
    inpaint_img: Image.Image,
    inpaint_mask: Image.Image,
    control_img: Image.Image,
    height: int,
    width: int,
    guidance_scale: float,
    num_inference_steps: int,
    seed: int,
    control_strength: float,
    control_stop: float,
):
    """Run the Flex.2-preview pipeline once and return the first image.

    Args:
        prompt: Text prompt.
        inpaint_img: Image to inpaint, or None if not provided.
        inpaint_mask: Inpainting mask, or None if not provided.
        control_img: Control/conditioning image, or None if not provided.
        height: Output height in pixels (coerced to int).
        width: Output width in pixels (coerced to int).
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of denoising steps (coerced to int).
        seed: RNG seed for reproducibility; None picks a non-deterministic seed.
        control_strength: Forwarded to the pipeline's ``control_strength``.
        control_stop: Forwarded to the pipeline's ``control_stop``.

    Returns:
        The first entry of the pipeline result's ``images``.
    """
    gen = torch.Generator(device="cuda")
    if seed is None:
        # The seed field can be cleared in the UI; don't crash on manual_seed(None).
        gen.seed()
    else:
        # UI numbers may arrive as floats; manual_seed requires an int.
        gen.manual_seed(int(seed))

    # NOTE(review): assumes the custom Flex.2 pipeline accepts None for the
    # inpaint/control inputs when they are not supplied — confirm in pipeline.py.
    result = pipe(
        prompt=prompt,
        inpaint_image=_to_rgb(inpaint_img),
        inpaint_mask=_to_rgb(inpaint_mask),
        control_image=_to_rgb(control_img),
        control_strength=control_strength,
        control_stop=control_stop,
        height=int(height),
        width=int(width),
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),
        generator=gen,
    )
    return result.images[0]
# Gradio UI. Components render in creation order, so the order of the
# statements below defines the page layout.
# NOTE(review): the original indentation was lost; the nesting of `seed` and
# the control sliders relative to the third Row is reconstructed — confirm.
with gr.Blocks(title="Flex.2-preview Image Generator") as demo:
    # Header and the text prompt.
    gr.Markdown(value="# Flex.2-preview Text→Image Generator")
    prompt = gr.Textbox(
        label="Prompt",
        placeholder="Enter your prompt...",
        lines=2,
    )

    # Conditioning images, delivered to the handler as PIL images.
    with gr.Row():
        inpaint_img = gr.Image(label="Inpaint Image", type="pil")
        inpaint_mask = gr.Image(label="Inpaint Mask", type="pil")
        control_img = gr.Image(label="Control Image", type="pil")

    # Output resolution, in 64-pixel increments.
    with gr.Row():
        height = gr.Slider(minimum=64, maximum=2048, value=512, step=64, label="Height")
        width = gr.Slider(minimum=64, maximum=2048, value=512, step=64, label="Width")

    # Sampling parameters.
    with gr.Row():
        guidance_scale = gr.Slider(
            minimum=0.0, maximum=20.0, value=3.5, step=0.1, label="Guidance Scale"
        )
        num_inference_steps = gr.Slider(
            minimum=1, maximum=100, value=50, step=1, label="Inference Steps"
        )
        seed = gr.Number(label="Random Seed", value=42, precision=0)

    # Control-image conditioning strength and stop point.
    control_strength = gr.Slider(
        minimum=0.0, maximum=1.0, value=0.0, step=0.01, label="Control Strength"
    )
    control_stop = gr.Slider(
        minimum=0.0, maximum=1.0, value=1.0, step=0.01, label="Control Stop"
    )

    # Must list components in the same order as generate_image's parameters.
    handler_inputs = [
        prompt,
        inpaint_img,
        inpaint_mask,
        control_img,
        height,
        width,
        guidance_scale,
        num_inference_steps,
        seed,
        control_strength,
        control_stop,
    ]

    generate_btn = gr.Button(value="Generate")
    output = gr.Image(label="Generated Image", type="pil")
    generate_btn.click(fn=generate_image, inputs=handler_inputs, outputs=[output])

if __name__ == "__main__":
    # share=True asks Gradio for a public share link.
    # NOTE(review): Gradio ignores share=True on Hugging Face Spaces — confirm.
    demo.launch(share=True)