# Hugging Face Spaces page header (scrape residue), kept as a comment:
# Status: Running
# pip install transformers gradio scipy ftfy "ipywidgets>=7,<8" datasets diffusers | |
import random | |
import gradio as gr | |
import torch | |
from torch import autocast | |
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline | |
# Model identifier handed to StableDiffusionPipeline.from_pretrained below.
model_id = "hakurei/waifu-diffusion"
# Optional cuDNN autotuning; left disabled.
# torch.backends.cudnn.benchmark = True
# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def __def_helper():
    """Reference ``StableDiffusionPipeline.__call__`` without invoking it.

    Not called anywhere in this script's visible code; presumably it exists
    only so an editor can jump to the pipeline's ``__call__`` definition.
    Calling ``StableDiffusionPipeline.__call__()`` on the class with no
    instance would raise ``TypeError`` (missing ``self``), so the unbound
    method is returned instead of being called.

    Returns:
        The unbound ``StableDiffusionPipeline.__call__`` function.
    """
    return StableDiffusionPipeline.__call__
# Load the pipeline and move it to the selected device.
# Half precision is only requested on CUDA: `device` may fall back to "cpu",
# and many float16 kernels are not implemented for CPU in PyTorch, so the CPU
# path loads the (fp16-branch) weights upcast to float32 instead.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    resume_download=True,  # resume interrupted model-file downloads
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    revision="fp16",
)
pipe = pipe.to(device)
def _resolve_seed(seed):
    """Turn the seed textbox value into an int, drawing a random one if unset.

    Args:
        seed: Value of the seed input -- normally a string from the textbox,
            but ``None`` or an int are accepted as well.

    Returns:
        An int seed. ``None``, a blank string, ``-1`` or ``"-1"`` produce a
        fresh random seed drawn from ``range(4294967294)``.

    Raises:
        ValueError: if ``seed`` is set but does not parse as an integer.
    """
    text = "" if seed is None else str(seed).strip()
    if text in ("", "-1"):
        return random.randrange(4294967294)
    try:
        return int(text)
    except ValueError as err:
        raise ValueError(f"seed must be an integer, got {seed!r}") from err


def infer(prompt, width, height, nums, steps, guidance_scale, seed):
    """Generate ``nums`` images for ``prompt`` with the module-level ``pipe``.

    Args:
        prompt: Text prompt; when ``None`` or empty nothing is generated.
        width: Image width passed to the pipeline (cast to int).
        height: Image height passed to the pipeline (cast to int).
        nums: How many images to generate for the prompt (cast to int).
        steps: ``num_inference_steps`` for the pipeline (cast to int).
        guidance_scale: ``guidance_scale`` for the pipeline (cast to float).
        seed: Seed input; empty, ``None`` or ``-1`` means "pick at random".

    Returns:
        The images produced by the pipeline, or ``None`` for an empty prompt.

    Raises:
        ValueError: if ``seed`` is set but is not an integer.
    """
    import contextlib

    print(prompt)
    print(width, height, nums, steps, guidance_scale, seed)
    if prompt is None or prompt == "":
        return None

    seed = _resolve_seed(seed)
    # autocast("cuda") and a CUDA generator only work with a GPU; on the CPU
    # fallback run without autocast and seed a CPU generator instead.
    if device == "cuda":
        precision = autocast("cuda")
    else:
        precision = contextlib.nullcontext()
    with precision:
        generator = torch.Generator(device).manual_seed(seed)
        # Gradio sliders may deliver floats; the pipeline and list repetition
        # need ints.
        output = pipe([prompt] * int(nums),
                      height=int(height),
                      width=int(width),
                      num_inference_steps=int(steps),
                      generator=generator,
                      guidance_scale=float(guidance_scale))
    # Older diffusers releases return a plain dict keyed by "sample"; newer
    # ones return an output object exposing the images as `.images`.
    images = getattr(output, "images", None)
    if images is None:
        images = output["sample"]
    return images
# Markdown-style help text given to gr.Interface(description=...). In English:
#   "prompt material: https://lexica.art"
#   "seed: leave empty to use a random seed"
description = (
    "\n"
    "prompt 素材:[https://lexica.art](https://lexica.art) \n"
    "\n"
    "seed:为空会使用随机seed\n"
)
# with block as demo: | |
def run():
    """Assemble the Gradio interface for the Waifu Diffusion demo.

    The seven input widgets map positionally onto ``infer``'s parameters
    (prompt, width, height, nums, steps, guidance_scale, seed); the
    generated images are shown in a gallery.

    Returns:
        The configured ``gr.Interface``; launching is left to the caller.
    """
    # English: "Number of denoising steps. More denoising steps usually give
    # higher-quality images but slow down inference."
    steps_label = ("num_inference_steps:\n"
                   "去噪步骤的数量。更多的去噪步骤通常会导致更高质量的图像,但会降低推理速度。")
    # English: "A higher guidance scale encourages images closely tied to the
    # text prompt, usually at the cost of image quality."
    guidance_label = ("guidance_scale:\n"
                      "较高的引导比例鼓励生成与文本“提示”密切相关的图像,通常以降低图像质量为代价")

    # Order matters: Gradio passes these to infer() positionally.
    widgets = [
        gr.Textbox(label="输入 prompt"),  # "enter prompt"
        gr.Slider(minimum=512, maximum=1024, value=512, step=64,
                  label="width"),
        gr.Slider(minimum=512, maximum=1024, value=512, step=64,
                  label="height"),
        gr.Slider(minimum=1, maximum=4, value=1, step=1,
                  label="Number of Images"),
        gr.Slider(minimum=10, maximum=150, value=50, step=1,
                  label=steps_label),
        gr.Slider(minimum=0, maximum=20, value=7.5, step=0.5,
                  label=guidance_label),
        gr.Textbox(label="随机 seed",  # "random seed"
                   placeholder="Random Seed",
                   lines=1),
    ]
    gallery = gr.Gallery(label="Generated images")

    return gr.Interface(
        fn=infer,
        inputs=widgets,
        outputs=[gallery],
        title="Waifu Diffusion",
        description=description,
    )


# Build the interface when the script runs and start serving it with debug=True.
app = run()
app.launch(debug=True)