"""Gradio demo that generates images from text with Stable Diffusion v1-4."""

import os

import torch
from diffusers import StableDiffusionPipeline

# Hugging Face access token, read from the environment so it never lives in
# source control. ``None`` when HF_TOKEN is unset.
auth_token = os.getenv("HF_TOKEN")

# Run on the GPU when one is available. Half precision is only used on CUDA;
# CPU runs stay in full float32 because PyTorch's CPU kernels have
# historically lacked float16 implementations for several ops.
device = "cuda" if torch.cuda.is_available() else "cpu"
use_fp16 = device == "cuda"
torch_dtype = torch.float16 if use_fp16 else torch.float32

model_id = "CompVis/stable-diffusion-v1-4"

# Only fetch the half-precision "fp16" branch of the model repo when we will
# actually compute in fp16; otherwise load the default (float32) weights.
pretrained_kwargs = {"torch_dtype": torch_dtype}
if use_fp16:
    pretrained_kwargs["revision"] = "fp16"

# ``use_auth_token`` is the keyword diffusers recognises for the hub token
# (``auth_token`` is not, and would be silently ignored). NOTE(review): newer
# diffusers releases rename this to ``token=`` — confirm against the pinned
# diffusers version.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    use_auth_token=auth_token,
    **pretrained_kwargs,
).to(device)
def predict(prompt):
    """Generate an image for a text prompt.

    Runs the module-level ``pipe`` on ``prompt`` and returns the first
    entry of the output's ``images`` list.
    """
    generation = pipe(prompt)
    first_image = generation.images[0]
    return first_image
import gradio as gr

# Single multi-line text box that feeds the prompt into ``predict``.
_prompt_input = gr.Textbox(lines=2, label="Paste some text here")

# Wire prompt -> predict -> image, with one clickable example prompt.
gradio_ui = gr.Interface(
    fn=predict,
    inputs=[_prompt_input],
    outputs=["image"],
    examples=[["a photograph of an astronaut riding a horse"]],
    title="Stable Diffusion Demo",
    description="Enter a description of an image you'd like to generate!",
)

# Start the Gradio web app for the demo.
gradio_ui.launch()