# Text-to-Image / app.py
# Source snapshot: commit f51843b ("Update app.py") by sdafd.
import torch
from diffusers import FluxPipeline
import gradio as gr
import threading
import os
# Use every available CPU core for torch/OpenMP work. os.cpu_count() returns
# None when the core count cannot be determined, which would otherwise produce
# OMP_NUM_THREADS="None" and make torch.set_num_threads(None) raise.
_NUM_THREADS = os.cpu_count() or 1
# NOTE(review): this is set after `import torch`; an OpenMP runtime that is
# already initialized may ignore it. torch.set_num_threads below still applies.
os.environ["OMP_NUM_THREADS"] = str(_NUM_THREADS)
torch.set_num_threads(_NUM_THREADS)

# Initialize Flux pipeline (bfloat16 weights).
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
)
# NOTE(review): model CPU offload generally targets an accelerator device;
# confirm this call behaves as intended on the deployment hardware.
pipe.enable_model_cpu_offload()

# Shared flag: stop_generation() sets it, generate_images() clears and polls it.
stop_event = threading.Event()
def generate_images(
prompt,
height,
width,
guidance_scale,
num_inference_steps,
max_sequence_length,
seed,
randomize_seed
):
stop_event.clear()
results = []
for i in range(3):
if stop_event.is_set():
return [None] * 3
# Handle seed randomization
if randomize_seed:
current_seed = torch.randint(0, 2**32 - 1, (1,)).item()
else:
current_seed = seed + i
generator = torch.Generator(device="cpu").manual_seed(current_seed)
# Generate image with current parameters
image = pipe(
prompt=prompt,
height=int(height),
width=int(width),
guidance_scale=guidance_scale,
num_inference_steps=int(num_inference_steps),
max_sequence_length=int(max_sequence_length),
generator=generator
).images[0]
results.append(image)
return results
def stop_generation():
    """Request that the running generate_images loop stop, and blank the outputs.

    Returns:
        A three-element list of ``None``, one per image output slot.
    """
    stop_event.set()
    cleared = [None, None, None]
    return cleared
# Gradio UI: prompt box, collapsible generation parameters, Generate/Stop
# buttons, and three image slots filled by generate_images().
with gr.Blocks() as interface:
    gr.Markdown("""
### FLUX Image Generation
Adjust parameters below to control the image generation process
""")
    with gr.Row():
        text_input = gr.Textbox(
            label="Prompt",
            placeholder="Describe what you want to generate...",
            scale=3,
        )
    with gr.Accordion("Generation Parameters", open=False):
        with gr.Row():
            height = gr.Number(
                label="Height",
                value=1024,
                minimum=512,
                maximum=4096,
                step=64,
                precision=0,
            )
            width = gr.Number(
                label="Width",
                value=1024,
                minimum=512,
                maximum=4096,
                step=64,
                precision=0,
            )
        # NOTE(review): FLUX.1-dev examples typically use a guidance scale
        # around 3.5; 7.0 may over-saturate results. Confirm the intended default.
        guidance_scale = gr.Slider(
            label="Guidance Scale",
            minimum=0.0,
            maximum=20.0,
            value=7.0,
            step=0.5,
        )
        num_inference_steps = gr.Slider(
            label="Inference Steps",
            minimum=10,
            maximum=150,
            value=50,
            step=1,
        )
        # FluxPipeline rejects max_sequence_length > 512 with a ValueError,
        # so the former 768/1024 choices always failed. Offer only valid values.
        max_sequence_length = gr.Dropdown(
            label="Max Sequence Length",
            choices=[128, 256, 512],
            value=512,
        )
        with gr.Row():
            seed = gr.Number(
                label="Seed",
                value=42,
                precision=0,
            )
            randomize_seed = gr.Checkbox(
                label="Randomize Seed",
                value=True,
            )
    with gr.Row():
        generate_btn = gr.Button("Generate", variant="primary")
        stop_btn = gr.Button("Stop Generation")
    with gr.Row():
        output1 = gr.Image(label="Output 1", type="pil")
        output2 = gr.Image(label="Output 2", type="pil")
        output3 = gr.Image(label="Output 3", type="pil")

    generate_btn.click(
        generate_images,
        inputs=[
            text_input,
            height,
            width,
            guidance_scale,
            num_inference_steps,
            max_sequence_length,
            seed,
            randomize_seed,
        ],
        outputs=[output1, output2, output3],
    )
    # Stop sets the shared stop flag and immediately clears the three outputs.
    stop_btn.click(
        stop_generation,
        inputs=[],
        outputs=[output1, output2, output3],
    )

interface.launch()