Spaces:
Runtime error
Runtime error
import os | |
import torch | |
import gradio as gr | |
from torch import autocast | |
from diffusers import StableDiffusionPipeline | |
# Hugging Face access token, read from the AUTH_TOKEN environment variable
# (os.getenv returns None when it is not set).
TOKEN_KEY = os.getenv("AUTH_TOKEN")

# Load Stable Diffusion v1.4 from the "fp16" revision with float16 weights,
# authenticate with TOKEN_KEY, and place the whole pipeline on the GPU.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="fp16",
    torch_dtype=torch.float16,
    use_auth_token=TOKEN_KEY,
).to("cuda")
# Gradio callback: turns the three UI inputs into one generated image.
def generate(prompt: str, seed: int, guidance: float):
    """Generate a single image for *prompt* using the module-level ``pipe``.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        seed: Seed for a CUDA ``torch.Generator``; coerced with ``int()``
            (presumably because ``gr.Number`` may deliver a float).
        guidance: Value forwarded to the pipeline as ``guidance_scale``.

    Returns:
        The first entry of the pipeline output's ``images`` list.
    """
    # A freshly seeded generator makes the output reproducible per seed.
    generator = torch.Generator("cuda").manual_seed(int(seed))
    with autocast("cuda"):
        # The step-count keyword of StableDiffusionPipeline.__call__ is
        # ``num_inference_steps``; an unknown ``steps=`` keyword is rejected.
        result = pipe(
            prompt=prompt,
            generator=generator,
            guidance_scale=guidance,
            num_inference_steps=50,
        )
    return result.images[0]
# Build the Gradio UI; the three inputs map positionally onto
# generate(prompt, seed, guidance).
demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="Prompt", placeholder="castle on a mountain"),
        gr.Number(label="Seed", value=123456),
        # Start at 7.5 (the diffusers default) rather than the slider minimum
        # of 0: diffusers only applies classifier-free guidance when
        # guidance_scale > 1, so a 0 default effectively ignores guidance.
        gr.Slider(0, 10, value=7.5, label="Guidance scale"),
    ],
    outputs="image",
    allow_flagging="never",
)
# Enable queueing of incoming requests; concurrency_count=3 lets up to three
# queued requests be processed at the same time.
demo.queue(concurrency_count=3)
# Start the web server.
demo.launch()