import streamlit as st
from io import BytesIO
from typing import Literal
from diffusers import StableDiffusionPipeline
import torch
import time

seed = 42
# Seeds torch's global RNG at import time; kept for backward compatibility.
# Image generation itself uses a dedicated generator (see text2image) so the
# result does not depend on how much global randomness model loading consumed.
generator = torch.manual_seed(seed)

NUM_ITERS_TO_RUN = 2
NUM_INFERENCE_STEPS = 20
NUM_IMAGES_PER_PROMPT = 1


def _load_pipeline(repo_id: str) -> StableDiffusionPipeline:
    """Load ``repo_id`` as a StableDiffusionPipeline on the best available device.

    Uses float16 weights on CUDA when a GPU is available, float32 on CPU otherwise.
    """
    if torch.cuda.is_available():
        print("Using GPU")
        return StableDiffusionPipeline.from_pretrained(
            repo_id,
            torch_dtype=torch.float16,
            use_safetensors=True,
        ).to("cuda")
    print("Using CPU")
    return StableDiffusionPipeline.from_pretrained(
        repo_id,
        torch_dtype=torch.float32,
        use_safetensors=True,
    )


# Streamlit re-executes this script on every interaction, so without caching the
# model weights would be reloaded on every submission. Keep only the most recently
# used model (max_entries=1) to bound memory use. st.cache_resource requires
# Streamlit >= 1.18; on older versions the loader simply stays uncached.
if hasattr(st, "cache_resource"):
    _load_pipeline = st.cache_resource(max_entries=1, show_spinner=False)(
        _load_pipeline
    )


def text2image(
    prompt: str,
    repo_id: Literal[
        "dreamlike-art/dreamlike-photoreal-2.0",
        "hakurei/waifu-diffusion",
        "prompthero/openjourney",
        "stabilityai/stable-diffusion-2-1",
        "runwayml/stable-diffusion-v1-5",
        "nota-ai/bk-sdm-small",
        "CompVis/stable-diffusion-v1-4",
    ],
):
    """Generate an image for ``prompt`` with the Stable Diffusion model ``repo_id``.

    The pipeline is run ``NUM_ITERS_TO_RUN`` times (at least once) with
    ``NUM_INFERENCE_STEPS`` denoising steps each. The runs share one generator
    seeded from ``seed``, so repeated calls with the same prompt and model
    produce the same result.

    Returns:
        ``(image, start, end)``: ``image`` is the first image of the final run;
        ``start`` and ``end`` are ``time.time()`` timestamps taken before the
        model is obtained and after the last run.
    """
    start = time.time()
    pipeline = _load_pipeline(repo_id)
    # Dedicated CPU generator, re-seeded per call, instead of the shared global RNG.
    rng = torch.Generator(device="cpu").manual_seed(seed)
    # Always run at least once so `images` is bound even if NUM_ITERS_TO_RUN <= 0.
    for _ in range(max(1, NUM_ITERS_TO_RUN)):
        images = pipeline(
            prompt,
            num_inference_steps=NUM_INFERENCE_STEPS,
            generator=rng,
            num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
        ).images
    end = time.time()
    return images[0], start, end


def app():
    """Render the text-to-image Streamlit page and handle prompt submissions."""
    st.header("Text-to-image Web App")
    st.subheader("Powered by Hugging Face")
    user_input = st.text_area(
        "Enter your text prompt below and click the button to submit."
    )
    option = st.selectbox(
        "Select model (in order of processing time)",
        (
            "nota-ai/bk-sdm-small",
            "CompVis/stable-diffusion-v1-4",
            "runwayml/stable-diffusion-v1-5",
            "prompthero/openjourney",
            "hakurei/waifu-diffusion",
            "stabilityai/stable-diffusion-2-1",
            "dreamlike-art/dreamlike-photoreal-2.0",
        ),
    )
    with st.form("my_form"):
        submit = st.form_submit_button(label="Submit text prompt")

    # Results are rendered outside the form: st.download_button is not allowed
    # inside an st.form.
    if not submit:
        return
    if not user_input.strip():
        # Avoid a multi-minute generation for an empty prompt.
        st.warning("Please enter a text prompt first.")
        return

    with st.spinner(text="Generating image ... It may take up to 20 minutes."):
        im, start, end = text2image(prompt=user_input, repo_id=option)

    buf = BytesIO()
    im.save(buf, format="PNG")
    byte_im = buf.getvalue()

    hours, rem = divmod(end - start, 3600)
    minutes, seconds = divmod(rem, 60)
    st.success(
        "Processing time: {:0>2}:{:0>2}:{:05.2f}.".format(
            int(hours), int(minutes), seconds
        )
    )
    st.image(im)
    st.download_button(
        label="Click here to download",
        data=byte_im,
        file_name="generated_image.png",
        mime="image/png",
    )


if __name__ == "__main__":
    app()