Spaces:
Sleeping
Sleeping
import streamlit as st | |
import torch | |
from diffusers import StableDiffusionPipeline | |
from transformers import pipeline, set_seed | |
from PIL import Image | |
# TTI Class Definition
class TTI:
    """Shared settings for the text-to-image app.

    All values are class attributes; ``TTI`` is used as a plain namespace
    and is never instantiated in this file. The Streamlit UI overwrites the
    step count, guidance scale and output size right before generating.
    """

    # Compute device: CUDA when torch reports it available, else the CPU.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # Seeded RNG passed to the diffusion pipeline so sampling is repeatable.
    seed: int = 42
    generator: torch.Generator = torch.Generator(device).manual_seed(seed)

    # Sampling settings (replaced by the UI widget values before each run).
    image_gen_steps: int = 35
    image_gen_size: tuple = (400, 400)  # (width, height) of the final resize
    # sic: the misspelled name is what the rest of the file reads and writes
    image_gen_guidence_scale = 9

    # Model identifiers handed to ``from_pretrained``.
    image_gen_model_id: str = "stabilityai/stable-diffusion-2"
    # NOTE(review): not referenced anywhere in the visible code.
    prompt_gen_model_id: str = "gpt2"
# Load Stable Diffusion Model
@st.cache_resource
def load_image_gen_model():
    """Load the Stable Diffusion pipeline and move it onto ``TTI.device``.

    Streamlit re-executes this whole script on every widget interaction;
    ``st.cache_resource`` makes the (large) pipeline load once per server
    process and reuses it on later reruns instead of reloading the weights.

    Returns:
        The ``StableDiffusionPipeline`` for ``TTI.image_gen_model_id``,
        placed on ``TTI.device``.
    """
    # float16 is only practical on CUDA: many CPU kernels have no Half
    # implementation, so inference on CPU would fail or crawl. Use float32
    # there (the fp16 weights are upcast on load).
    dtype = torch.float16 if TTI.device == "cuda" else torch.float32
    model = StableDiffusionPipeline.from_pretrained(
        TTI.image_gen_model_id,
        torch_dtype=dtype,
        revision="fp16",
    )
    return model.to(TTI.device)


image_gen_model = load_image_gen_model()
# Function to Generate Images
def generate_image(prompt, model):
    """Generate one image for *prompt* and resize it to ``TTI.image_gen_size``.

    The sampling settings (step count, guidance scale, RNG) are read from
    the ``TTI`` class rather than passed in, so callers must update ``TTI``
    before calling.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        model: Callable diffusion pipeline whose result exposes ``.images``
            (the ``StableDiffusionPipeline`` loaded by this app).

    Returns:
        The first image of the pipeline output (a PIL image with the
        pipeline's default output type), resized with Lanczos resampling.
    """
    result = model(
        prompt,
        num_inference_steps=TTI.image_gen_steps,
        generator=TTI.generator,
        guidance_scale=TTI.image_gen_guidence_scale,
    )
    image = result.images[0]
    # Image.ANTIALIAS was removed in Pillow 10 (AttributeError). LANCZOS is
    # the same filter under a name every Pillow release accepts.
    return image.resize(TTI.image_gen_size, Image.LANCZOS)
# Streamlit UI
st.title("Text-to-Image Generator")
st.write("Generate images from text prompts using Stable Diffusion.")

# User Input: Prompt
prompt = st.text_input("Enter a text prompt", value="A monkey on a tree")

# User Input: Inference Steps
image_gen_steps = st.slider(
    "Number of inference steps (Higher = Better quality but slower)",
    min_value=10,
    max_value=100,
    value=TTI.image_gen_steps,
    step=5,
)

# User Input: Guidance Scale
guidance_scale = st.slider(
    "Guidance scale (Higher = Closer to prompt, but less creative)",
    min_value=1.0,
    max_value=20.0,
    # The default is stored as an int; convert it so its type matches the
    # float min/max/step of this slider.
    value=float(TTI.image_gen_guidence_scale),
    step=0.5,
)

# User Input: Image Size (the generated image is resized to this)
image_width = st.number_input("Image Width", min_value=64, max_value=1024, value=TTI.image_gen_size[0], step=64)
image_height = st.number_input("Image Height", min_value=64, max_value=1024, value=TTI.image_gen_size[1], step=64)

# Generate Image Button
if st.button("Generate Image"):
    if not prompt.strip():
        # Don't send an empty/whitespace-only prompt to the pipeline.
        st.warning("Please enter a text prompt before generating an image.")
    else:
        # generate_image reads its settings from TTI, so push the current
        # widget values there before calling it.
        TTI.image_gen_steps = image_gen_steps
        TTI.image_gen_guidence_scale = guidance_scale
        TTI.image_gen_size = (image_width, image_height)
        with st.spinner("Generating image..."):
            image = generate_image(prompt, image_gen_model)
        st.image(image, caption=f"Generated Image for Prompt: '{prompt}'", use_column_width=True)

st.write("Adjust parameters to customize the image generation!")