SuCicada commited on
Commit
8ad30ba
1 Parent(s): 98e3f08
Files changed (1) hide show
  1. app.py +66 -4
app.py CHANGED
@@ -1,7 +1,69 @@
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
1
+ # pip install transformers gradio scipy ftfy "ipywidgets>=7,<8" datasets diffusers
2
+
3
  import gradio as gr
4
+ import torch
5
+ from torch import autocast
6
+ from diffusers import StableDiffusionPipeline
7
+
8
+ model_id = "hakurei/waifu-diffusion"
9
+ device = "cpu"
10
+
11
+ # pipe = StableDiffusionPipeline.from_pretrained(model_id,
12
+ # resume_download=True, # 模型文件断点续传
13
+ # torch_dtype=torch.float16,
14
+ # revision='fp16')
15
+ # pipe = pipe.to(device)
16
+
17
+ block = gr.Blocks(css=".container { max-width: 800px; margin: auto; }")
18
+
19
+ num_samples = 2
20
+
21
+
22
+ def infer(prompt):
23
+ with autocast("cuda"):
24
+ images = pipe([prompt] * num_samples,
25
+ hight=111,
26
+ width=100,
27
+
28
+ guidance_scale=7.5)["sample"]
29
+
30
+ return images
31
+
32
+
33
+ with block as demo:
34
+ gr.Markdown("<h1><center>Waifu Diffusion</center></h1>")
35
+ gr.Markdown(
36
+ "waifu-diffusion is a latent text-to-image diffusion model that has been conditioned on high-quality anime images through fine-tuning."
37
+ )
38
+ with gr.Group():
39
+ with gr.Box():
40
+ with gr.Row().style(mobile_collapse=False, equal_height=True):
41
+ text = gr.Textbox(
42
+ label="Enter your prompt", show_label=False, max_lines=1
43
+ ).style(
44
+ border=(True, False, True, True),
45
+ rounded=(True, False, False, True),
46
+ container=False,
47
+ )
48
+ btn = gr.Button("Run").style(
49
+ margin=False,
50
+ rounded=(False, True, True, False),
51
+ )
52
+
53
+ gallery = gr.Gallery(label="Generated images", show_label=False).style(
54
+ grid=[2], height="auto"
55
+ )
56
+ text.submit(infer, inputs=[text], outputs=gallery)
57
+ btn.click(infer, inputs=[text], outputs=gallery)
58
 
59
+ gr.Markdown(
60
+ """___
61
+ <p style='text-align: center
62
+ '>
63
+ Created by https://huggingface.co/hakurei
64
+ <br/>
65
+ </p>"""
66
+ )
67
 
68
+ demo.launch(debug=True,
69
+ server_port=7860)