cheeseman182
Increase inference steps to 4
97fd923 verified
import gradio as gr
import torch
from diffusers import AutoPipelineForText2Image
import time
import numpy as np # The only library we need for this fix

# --- 1. Load the SD-Turbo Model (Optimized for CPU) ---
# (No changes here)
print("Loading the SD-Turbo model for CPU...")
pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sd-turbo")
device = "cpu"
pipe = pipe.to(device)
print("Model loaded successfully on CPU!")

# --- 2. Pre-warm the Model ---
# (No changes here)
print("Pre-warming the pipeline...")
_ = pipe(prompt="A photo of a cat", width=512, height=512, num_inference_steps=1).images[0]
print("Pipeline is warmed up and ready!")

# --- 3. The NumPy Array Solution ---
def generate_and_return_numpy(prompt, seed, width, height):
    """
    Generates an image and returns it as a raw NumPy array. This is the most
    stable method to avoid Gradio/Windows bugs, though a UI delay will exist.
    """
    start_time = time.time()
    try:
        width = int(width)
        height = int(height)
        generator = torch.Generator(device=pipe.device).manual_seed(int(seed))

        # The model generates the PIL Image
        pil_image = pipe(
            prompt=prompt,
            width=width,
            height=height,
            num_inference_steps=4,  # SD-Turbo is distilled to work well in 1-4 steps
            guidance_scale=0.0,  # SD-Turbo is meant to run without classifier-free guidance
            generator=generator,
        ).images[0]

        # --- THE FIX: Convert to NumPy Array ---
        numpy_array = np.array(pil_image)
        # --- END OF FIX ---

        end_time = time.time()
        # This time will be the FAST backend time. The UI will take longer.
        generation_time = f"Backend generation time: {end_time - start_time:.2f} seconds"

        # We return the raw array. Gradio will handle the slow encoding now.
        return numpy_array, generation_time, None

    except Exception as e:
        print(f"An error occurred: {e}")
        return None, "Generation failed", str(e)

# --- 4. Create the Gradio Interface ---
# The UI code is identical. gr.Image can handle NumPy arrays.
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown(
        """
        # 💯 Stable CPU Generator (NumPy Version) 💯
        ### This is the most robust version to prevent crashes on Windows.
        """
    )
    with gr.Row():
        with gr.Column(scale=3):
            prompt_input = gr.Textbox(
                label="Prompt", placeholder="A wizard casting a spell", lines=3,
            )
            with gr.Row():
                width_slider = gr.Slider(
                    label="Width", minimum=256, maximum=768, value=512, step=64,
                )
                height_slider = gr.Slider(
                    label="Height", minimum=256, maximum=768, value=512, step=64,
                )
            seed_input = gr.Number(label="Seed", value=100)
            generate_button = gr.Button("Generate", variant="primary")
        with gr.Column(scale=2):
            image_output = gr.Image(label="Generated Image", show_label=False)
            info_output = gr.Textbox(label="Status", show_label=False, interactive=False)
            error_output = gr.Textbox(label="Error", visible=False)

    generate_button.click(
        fn=generate_and_return_numpy,
        inputs=[prompt_input, seed_input, width_slider, height_slider],
        outputs=[image_output, info_output, error_output],
    )

# --- 5. Launch the App ---
app.launch(share=True)
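# Note: share=True asks Gradio to open a temporary public link in addition to the
# local server; remove it if you only want the app available on localhost.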