Spaces:
Runtime error
import streamlit as st | |
from diffusers import UNet2DConditionModel, DiffusionPipeline, LCMScheduler | |
import torch | |
from PIL import Image | |
@st.cache_resource(show_spinner=False)
def _load_pipeline(
    unet_path="path/to/fine-tuned/weight",
    base_model_id="runwayml/stable-diffusion-v1-5",
    device="cuda",
):
    """Build the LCM-scheduled diffusion pipeline once and reuse it across reruns.

    Streamlit re-executes the whole script on every widget interaction, so
    without ``st.cache_resource`` the UNet and the base pipeline would be
    reloaded (and re-allocated on the device) on every button click.

    Args:
        unet_path: Location of the fine-tuned UNet weights.
            TODO: this default is a placeholder path; point it at the real
            checkpoint, otherwise loading fails.
        base_model_id: Hub id of the base pipeline whose remaining components
            are reused around the fine-tuned UNet.
            NOTE(review): the ``runwayml`` repo may no longer be hosted on the
            Hub; confirm it resolves, or switch to a maintained mirror.
        device: Torch device the pipeline is moved to.

    Returns:
        The ready-to-call diffusion pipeline.
    """
    # fp16 inference is only reliable on CUDA; use fp32 on CPU.
    dtype = torch.float16 if device == "cuda" else torch.float32
    unet = UNet2DConditionModel.from_pretrained(
        unet_path, torch_dtype=dtype, variant="fp16"
    )
    pipeline = DiffusionPipeline.from_pretrained(
        base_model_id,
        unet=unet,
        torch_dtype=dtype,
        variant="fp16",
        # Do not load the safety checker at all, rather than loading it and
        # throwing it away afterwards.
        safety_checker=None,
        requires_safety_checker=False,
    )
    # Keep the base scheduler's configuration but swap in the LCM scheduler.
    pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)
    pipeline.to(device)
    return pipeline


def generate_and_display_image(prompt, *, num_inference_steps=4, guidance_scale=2):
    """Generate one image for *prompt* and render it in the Streamlit page.

    The pipeline is loaded lazily on first use and cached by ``_load_pipeline``.
    It runs on CUDA when available and otherwise falls back to CPU instead of
    crashing.

    Args:
        prompt: Text prompt passed to the diffusion pipeline.
        num_inference_steps: Number of denoising steps (default 4).
        guidance_scale: Guidance strength passed to the pipeline (default 2).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # The first call includes model loading, which can take a long time; show
    # progress so the app does not look frozen.
    with st.spinner("Generating image..."):
        pipeline = _load_pipeline(device=device)
        image = pipeline(
            prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        ).images[0]
    # Display at a fixed 512x512 regardless of the pipeline's output size.
    image = image.resize((512, 512))
    st.image(image, caption="Generated Image", use_column_width=True)
def main():
    """Render the prompt UI and generate an image when the button is pressed.

    An empty or whitespace-only prompt shows a warning instead of starting
    a generation.
    """
    st.title("Image Generation with Diffusion Models")
    prompt = st.text_input("Enter your prompt")
    # The button must be rendered on every run; only act when it was clicked.
    if not st.button("Generate Image"):
        return
    # Whitespace-only input counts as empty, so it reaches the warning
    # instead of running an expensive generation with a blank prompt.
    if prompt and prompt.strip():
        generate_and_display_image(prompt)
    else:
        st.warning("Please provide a prompt.")
if __name__ == "__main__":
    # Build and run the app only when this file is executed as the main script.
    main()