Gazoche commited on
Commit
4469af0
1 Parent(s): b0d1af9
Files changed (2) hide show
  1. app.py +218 -0
  2. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from contextlib import nullcontext
2
+ import torch
3
+ from torch import autocast
4
+ from diffusers import StableDiffusionPipeline
5
+ import gradio as gr
6
+
7
+
8
+ CHECKPOINTS = [
9
+ "epoch-000025",
10
+ #"epoch-000081"
11
+ ]
12
+
13
+ device = "cuda" if torch.cuda.is_available() else "cpu"
14
+ context = autocast if device == "cuda" else nullcontext
15
+ dtype = torch.float16 if device == "cuda" else torch.float32
16
+
17
+ def load_pipe(checkpoint):
18
+
19
+ pipe = StableDiffusionPipeline.from_pretrained("Gazoche/sd-gundam-diffusers", revision=checkpoint, torch_dtype=dtype)
20
+ pipe = pipe.to(device)
21
+
22
+ # Disabling the NSFW filter as it's getting confused by the generated images
23
+ def null_safety(images, **kwargs):
24
+ return images, False
25
+ pipe.safety_checker = null_safety
26
+
27
+ return pipe
28
+
29
+ pipes = {
30
+ checkpoint: load_pipe(checkpoint)
31
+ for checkpoint in CHECKPOINTS
32
+ }
33
+
34
+ def infer(prompt, n_samples, steps, scale, model):
35
+
36
+ checkpoint = "epoch-000025" if model == "normal" else "epoch-000081"
37
+
38
+ in_prompt = ""
39
+ guidance_scale = 0.0
40
+ if prompt is not None:
41
+ in_prompt = prompt
42
+ guidance_scale = scale
43
+
44
+ with context("cuda"):
45
+ images = pipes[checkpoint](
46
+ n_samples * [in_prompt],
47
+ guidance_scale=guidance_scale,
48
+ num_inference_steps=steps
49
+ ).images
50
+
51
+ return images
52
+
53
+
54
+ def infer_random(n_samples, steps, scale, model):
55
+ return infer(None, n_samples, steps, scale, model)
56
+
57
+ css = """
58
+ a {
59
+ color: inherit;
60
+ text-decoration: underline;
61
+ }
62
+ .gradio-container {
63
+ font-family: 'IBM Plex Sans', sans-serif;
64
+ }
65
+ .gr-button {
66
+ color: white;
67
+ border-color: #9d66e5;
68
+ background: #9d66e5;
69
+ }
70
+ input[type='range'] {
71
+ accent-color: #9d66e5;
72
+ }
73
+ .dark input[type='range'] {
74
+ accent-color: #dfdfdf;
75
+ }
76
+ .container {
77
+ max-width: 730px;
78
+ margin: auto;
79
+ padding-top: 1.5rem;
80
+ }
81
+ #gallery {
82
+ min-height: 22rem;
83
+ margin-bottom: 15px;
84
+ margin-left: auto;
85
+ margin-right: auto;
86
+ border-bottom-right-radius: .5rem !important;
87
+ border-bottom-left-radius: .5rem !important;
88
+ }
89
+ #gallery>div>.h-full {
90
+ min-height: 20rem;
91
+ }
92
+ .details:hover {
93
+ text-decoration: underline;
94
+ }
95
+ .gr-button {
96
+ white-space: nowrap;
97
+ }
98
+ .gr-button:focus {
99
+ border-color: rgb(147 197 253 / var(--tw-border-opacity));
100
+ outline: none;
101
+ box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
102
+ --tw-border-opacity: 1;
103
+ --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);
104
+ --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px var(--tw-ring-offset-width)) var(--tw-ring-color);
105
+ --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
106
+ --tw-ring-opacity: .5;
107
+ }
108
+ #advanced-options {
109
+ margin-bottom: 20px;
110
+ }
111
+ .footer {
112
+ margin-bottom: 45px;
113
+ margin-top: 35px;
114
+ text-align: center;
115
+ border-bottom: 1px solid #e5e5e5;
116
+ }
117
+ .footer>p {
118
+ font-size: .8rem;
119
+ display: inline-block;
120
+ padding: 0 10px;
121
+ transform: translateY(10px);
122
+ background: white;
123
+ }
124
+ .dark .logo{ filter: invert(1); }
125
+ .dark .footer {
126
+ border-color: #303030;
127
+ }
128
+ .dark .footer>p {
129
+ background: #0b0f19;
130
+ }
131
+ .acknowledgments h4{
132
+ margin: 1.25em 0 .25em 0;
133
+ font-weight: bold;
134
+ font-size: 115%;
135
+ }
136
+ """
137
+
138
+ block = gr.Blocks(css=css)
139
+
140
+ with block:
141
+ gr.HTML(
142
+ """
143
+ <div style="text-align: center; max-width: 650px; margin: 0 auto;">
144
+ <div>
145
+ <h1 style="font-weight: 900; font-size: 3rem;">
146
+ Gundam text to image
147
+ </h1>
148
+ </div>
149
+ <p style="margin-bottom: 10px; font-size: 94%">
150
+ From a text description, generate a mecha from the anime franchise Mobile Suit Gundam
151
+ </p>
152
+ <p style="margin-bottom: 10px; font-size: 94%">
153
+ Github: <a href="https://github.com/Askannz/gundam-stable-diffusion">https://github.com/Askannz/gundam-stable-diffusion</a>
154
+ </p>
155
+ <ul>
156
+ <li>More steps generally means less visual noise but fewer details</li>
157
+ <li>Text guidance controls how much the prompt influences the generation</li>
158
+ <li>The overfitted model gives cleaner but less original results</li>
159
+ </ul>
160
+ </div>
161
+ """
162
+ )
163
+ with gr.Group():
164
+
165
+ with gr.Box():
166
+ with gr.Row().style(mobile_collapse=False, equal_height=True):
167
+ text = gr.Textbox(
168
+ label="Enter your prompt",
169
+ show_label=False,
170
+ max_lines=1,
171
+ placeholder="Enter your prompt",
172
+ ).style(
173
+ border=(True, False, True, True),
174
+ rounded=(True, False, False, True),
175
+ container=False,
176
+ )
177
+ btn = gr.Button("Generate from prompt").style(
178
+ margin=False,
179
+ rounded=(False, True, True, False),
180
+ )
181
+
182
+ with gr.Box():
183
+ with gr.Row().style(mobile_collapse=False, equal_height=True):
184
+ btn_rand = gr.Button("Random").style(
185
+ margin=False,
186
+ rounded=(False, True, True, False),
187
+ )
188
+
189
+ gallery = gr.Gallery(
190
+ label="Generated images", show_label=False, elem_id="gallery"
191
+ ).style(grid=[2], height="auto")
192
+
193
+
194
+ with gr.Row(elem_id="advanced-options"):
195
+ samples = gr.Slider(label="Images", minimum=1, maximum=4, value=1, step=1)
196
+ steps = gr.Slider(label="Steps", minimum=5, maximum=50, value=25, step=5)
197
+ scale = gr.Slider(
198
+ label="Text Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
199
+ )
200
+
201
+ with gr.Row(elem_id="checkpoint"):
202
+ #model = gr.Radio(label="Model", choices=["normal", "overfitted"], value="normal")
203
+ model = gr.Radio(label="Model", choices=["normal"], value="normal")
204
+
205
+
206
+ text.submit(infer, inputs=[text, samples, steps, scale, model], outputs=gallery)
207
+ btn.click(infer, inputs=[text, samples, steps, scale, model], outputs=gallery)
208
+ btn_rand.click(infer_random, inputs=[samples, steps, scale, model], outputs=gallery)
209
+ gr.HTML(
210
+ """
211
+ <div class="footer">
212
+ <p> Gradio Demo by 🤗 Hugging Face and Gazoche
213
+ </p>
214
+ </div>
215
+ """
216
+ )
217
+
218
+ block.launch()
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ diffusers==0.3.0
2
+ transformers
3
+ scipy
4
+ ftfy
5
+ gradio
6
+ datasets
7
+ fastapi
8
+ uvicorn
9
+ requests