# Prompt2Picture / app.py
# Uploaded by gautamraj8044 (commit f58e466, "Upload 2 files", verified).
import streamlit as st
import torch
from diffusers import StableDiffusionPipeline
from transformers import pipeline, set_seed
from PIL import Image
# Global configuration for the text-to-image app
class TTI:
    """Class-level settings namespace for the text-to-image app.

    Nothing here is instantiated; the attributes are read directly as
    ``TTI.<name>`` and some are reassigned by the UI before generating.
    """

    # Hugging Face model identifiers.
    image_gen_model_id = "stabilityai/stable-diffusion-2"
    prompt_gen_model_id = "gpt2"

    # Use the GPU when CUDA is available, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # Fixed seed so identical settings produce identical images; the generator
    # lives on the same device the pipeline runs on.
    seed = 42
    generator = torch.Generator(device)
    generator.manual_seed(seed)

    # Default generation parameters (overridden from the UI controls).
    image_gen_steps = 35
    image_gen_size = (400, 400)  # (width, height) of the returned image
    # NOTE: "guidence" is a misspelling of "guidance", kept because other code
    # reads and assigns this exact attribute name.
    image_gen_guidence_scale = 9
# Load Stable Diffusion Model
@st.cache_resource
def load_image_gen_model():
    """Load the Stable Diffusion pipeline and move it to ``TTI.device``.

    Decorated with ``st.cache_resource`` so the loaded pipeline is reused
    across Streamlit reruns instead of being reloaded on every interaction.

    Returns:
        The ``StableDiffusionPipeline`` for ``TTI.image_gen_model_id``, placed
        on ``TTI.device``.
    """
    # Half precision only on CUDA: fp16 inference on CPU may fail with
    # "not implemented for 'Half'" errors (or be extremely slow), which would
    # break the app on CPU-only hosts.
    on_gpu = str(TTI.device).startswith("cuda")
    dtype = torch.float16 if on_gpu else torch.float32
    model = StableDiffusionPipeline.from_pretrained(
        TTI.image_gen_model_id,
        torch_dtype=dtype,
        # NOTE(review): recent diffusers releases deprecate revision="fp16" in
        # favour of variant="fp16"; left unchanged for compatibility with the
        # installed diffusers version. With dtype=float32 the loaded weights
        # are cast to float32.
        revision="fp16",
    )
    return model.to(TTI.device)


image_gen_model = load_image_gen_model()
# Function to Generate Images
def generate_image(prompt, model):
    """Run *model* on *prompt* and return the first image, resized.

    Generation settings are read from ``TTI`` at call time
    (``image_gen_steps``, ``image_gen_guidence_scale``, the seeded
    ``generator`` and ``image_gen_size``), so callers adjust them by
    assigning to ``TTI`` before calling.

    Args:
        prompt: Text prompt passed to the pipeline.
        model: Diffusion pipeline callable whose result exposes ``.images``
            (e.g. the pipeline returned by ``load_image_gen_model``).

    Returns:
        The first generated image resized to ``TTI.image_gen_size``
        (width, height).
    """
    result = model(
        prompt,
        num_inference_steps=TTI.image_gen_steps,
        generator=TTI.generator,
        guidance_scale=TTI.image_gen_guidence_scale,
    )
    image = result.images[0]
    # Image.ANTIALIAS was removed in Pillow 10 (it was an alias of LANCZOS).
    # Image.Resampling exists since Pillow 9.1; older versions expose LANCZOS
    # directly on the Image module.
    lanczos = getattr(Image, "Resampling", Image).LANCZOS
    # No height/width is passed to the pipeline, so it renders at its default
    # resolution; a non-square target size will change the aspect ratio.
    return image.resize(TTI.image_gen_size, lanczos)
# Streamlit UI
st.title("Text-to-Image Generator")
st.write("Generate images from text prompts using Stable Diffusion.")

# User Input: Prompt
prompt = st.text_input("Enter a text prompt", value="A monkey on a tree")

# User Input: Inference Steps
image_gen_steps = st.slider(
    "Number of inference steps (Higher = Better quality but slower)",
    min_value=10,
    max_value=100,
    value=TTI.image_gen_steps,
    step=5,
)

# User Input: Guidance Scale
guidance_scale = st.slider(
    "Guidance scale (Higher = Closer to prompt, but less creative)",
    min_value=1.0,
    max_value=20.0,
    value=float(TTI.image_gen_guidence_scale),  # float so the slider is float-typed like min/max/step
    step=0.5,
)

# User Input: Image Size
image_width = st.number_input("Image Width", min_value=64, max_value=1024, value=TTI.image_gen_size[0], step=64)
image_height = st.number_input("Image Height", min_value=64, max_value=1024, value=TTI.image_gen_size[1], step=64)

# Generate Image Button
if st.button("Generate Image"):
    if not prompt.strip():
        # Without this guard an empty prompt still runs the full (slow)
        # diffusion, producing an image with no text guidance.
        st.warning("Please enter a text prompt before generating.")
    else:
        # generate_image reads its settings from TTI, so push the UI values
        # there first. Streamlit re-executes this script (redefining TTI) on
        # each interaction, so these assignments only affect the current run.
        TTI.image_gen_steps = image_gen_steps
        TTI.image_gen_guidence_scale = guidance_scale
        TTI.image_gen_size = (image_width, image_height)
        with st.spinner("Generating image..."):
            image = generate_image(prompt, image_gen_model)
        st.image(image, caption=f"Generated Image for Prompt: '{prompt}'", use_column_width=True)

st.write("Adjust parameters to customize the image generation!")