sayakpaul HF staff commited on
Commit
ca4cd1a
1 Parent(s): e0896ab

add files.

Browse files
Files changed (3) hide show
  1. app.py +261 -102
  2. constants.py +0 -123
  3. requirements.txt +5 -0
app.py CHANGED
@@ -1,103 +1,262 @@
1
- from constants import css
2
-
3
- import gradio as gr
4
-
5
-
6
- block = gr.Blocks(css=css)
7
-
8
- with block:
9
- gr.HTML(
10
- """
11
- <div style="text-align: center; margin: 0 auto;">
12
- <div
13
- style="
14
- display: inline-flex;
15
- align-items: center;
16
- gap: 0.8rem;
17
- font-size: 1.75rem;
18
- "
19
- >
20
- <svg
21
- width="0.65em"
22
- height="0.65em"
23
- viewBox="0 0 115 115"
24
- fill="none"
25
- xmlns="http://www.w3.org/2000/svg"
26
- >
27
- <rect width="23" height="23" fill="white"></rect>
28
- <rect y="69" width="23" height="23" fill="white"></rect>
29
- <rect x="23" width="23" height="23" fill="#AEAEAE"></rect>
30
- <rect x="23" y="69" width="23" height="23" fill="#AEAEAE"></rect>
31
- <rect x="46" width="23" height="23" fill="white"></rect>
32
- <rect x="46" y="69" width="23" height="23" fill="white"></rect>
33
- <rect x="69" width="23" height="23" fill="black"></rect>
34
- <rect x="69" y="69" width="23" height="23" fill="black"></rect>
35
- <rect x="92" width="23" height="23" fill="#D9D9D9"></rect>
36
- <rect x="92" y="69" width="23" height="23" fill="#AEAEAE"></rect>
37
- <rect x="115" y="46" width="23" height="23" fill="white"></rect>
38
- <rect x="115" y="115" width="23" height="23" fill="white"></rect>
39
- <rect x="115" y="69" width="23" height="23" fill="#D9D9D9"></rect>
40
- <rect x="92" y="46" width="23" height="23" fill="#AEAEAE"></rect>
41
- <rect x="92" y="115" width="23" height="23" fill="#AEAEAE"></rect>
42
- <rect x="92" y="69" width="23" height="23" fill="white"></rect>
43
- <rect x="69" y="46" width="23" height="23" fill="white"></rect>
44
- <rect x="69" y="115" width="23" height="23" fill="white"></rect>
45
- <rect x="69" y="69" width="23" height="23" fill="#D9D9D9"></rect>
46
- <rect x="46" y="46" width="23" height="23" fill="black"></rect>
47
- <rect x="46" y="115" width="23" height="23" fill="black"></rect>
48
- <rect x="46" y="69" width="23" height="23" fill="black"></rect>
49
- <rect x="23" y="46" width="23" height="23" fill="#D9D9D9"></rect>
50
- <rect x="23" y="115" width="23" height="23" fill="#AEAEAE"></rect>
51
- <rect x="23" y="69" width="23" height="23" fill="black"></rect>
52
- </svg>
53
- <h1 style="font-weight: 900; margin-bottom: 7px;margin-top:5px">
54
- Stable Diffusion 2.1 Demo
55
- </h1>
56
- </div>
57
- <p style="margin-bottom: 10px; font-size: 94%; line-height: 23px;">
58
- Stable Diffusion 2.1 is the latest text-to-image model from StabilityAI. <a style="text-decoration: underline;" href="https://huggingface.co/spaces/stabilityai/stable-diffusion-1">Access Stable Diffusion 1 Space here</a><br>For faster generation and API
59
- access you can try
60
- <a
61
- href="http://beta.dreamstudio.ai/"
62
- style="text-decoration: underline;"
63
- target="_blank"
64
- >DreamStudio Beta</a
65
- >.</a>
66
- </p>
67
- </div>
68
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  )
70
- with gr.Group():
71
- with gr.Box():
72
- with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
73
- with gr.Column():
74
- text = gr.Textbox(
75
- label="Enter your prompt",
76
- show_label=False,
77
- max_lines=1,
78
- placeholder="Enter your prompt",
79
- elem_id="prompt-text-input",
80
- ).style(
81
- border=(True, False, True, True),
82
- rounded=(True, False, False, True),
83
- container=False,
84
- )
85
- negative = gr.Textbox(
86
- label="Enter your negative prompt",
87
- show_label=False,
88
- max_lines=1,
89
- placeholder="Enter a negative prompt",
90
- elem_id="negative-prompt-text-input",
91
- ).style(
92
- border=(True, False, True, True),
93
- rounded=(True, False, False, True),
94
- container=False,
95
- )
96
- btn = gr.Button("Generate image").style(
97
- margin=False,
98
- rounded=(False, True, True, False),
99
- full_width=False,
100
- )
101
-
102
-
103
- block.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ from functools import partial
3
+ from typing import List
4
+
5
+ import gradio as gr
6
+ import numpy as np
7
+ import torch
8
+ from diffusers import StableDiffusionPipeline
9
+ from PIL import Image
10
+ from torchmetrics.functional.multimodal import clip_score
11
+ from torchmetrics.image.inception import InceptionScore
12
+
13
+ SEED = 0
14
+ WEIGHT_DTYPE = torch.float16
15
+
16
+ TITLE = "Evaluate Schedulers with StableDiffusionPipeline 🧨"
17
+ DESCRIPTION = """
18
+ This Space allows you to quantitatively compare [different noise schedulers](https://huggingface.co/docs/diffusers/using-diffusers/schedulers) with a [`StableDiffusionPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview).
19
+
20
+ One of the applications of this Space could be to evaluate different schedulers for a certain Stable Diffusion checkpoint for a fixed number of inference steps.
21
+
22
+ Here's how it works:
23
+
24
+ * The evaluator first sets a seed and then generates the initial noise which is passed as the initial latent to start the image generation process. It is done to ensure fair comparison.
25
+ * This initial latent is used every time the pipeline is run (with different schedulers).
26
+ * To quantify the quality of the generated images we use:
27
+ * [Inception Score](https://en.wikipedia.org/wiki/Inception_score)
28
+ * [Clip Score](https://arxiv.org/abs/2104.08718)
29
+
30
+ **Notes**:
31
+
32
+ * The default scheduler associated with the provided checkpoint is always used for reporting the scores.
33
+ * Increasing both the number of images per prompt and the number of inference steps could quickly build up the inference queue and thus
34
+ resulting in slowdowns.
35
+ """
36
+
37
+
38
+ inception_score_fn = InceptionScore(normalize=True)
39
+ torch.manual_seed(SEED)
40
+ clip_score_fn = partial(clip_score, model_name_or_path="openai/clip-vit-base-patch16")
41
+
42
+
43
+ def make_grid(images, rows, cols):
44
+ w, h = images[0].size
45
+ grid = Image.new("RGB", size=(cols * w, rows * h))
46
+ for i, image in enumerate(images):
47
+ grid.paste(image, box=(i % cols * w, i // cols * h))
48
+ return grid
49
+
50
+
51
+ # Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_utils.py#L814
52
+ def numpy_to_pil(images):
53
+ """
54
+ Convert a numpy image or a batch of images to a PIL image.
55
+ """
56
+ if images.ndim == 3:
57
+ images = images[None, ...]
58
+ images = (images * 255).round().astype("uint8")
59
+ if images.shape[-1] == 1:
60
+ # special case for grayscale (single channel) images
61
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
62
+ else:
63
+ pil_images = [Image.fromarray(image) for image in images]
64
+
65
+ return pil_images
66
+
67
+
68
+ def prepare_report(scheduler_name: str, results: dict):
69
+ image_grid = results["images"]
70
+ scores = results["scores"]
71
+ img_str = ""
72
+
73
+ image_name = f"{scheduler_name}_images.png"
74
+ image_grid.save(image_name)
75
+ img_str = img_str = f"![img_grid_{scheduler_name}](/file=./{image_name})\n"
76
+
77
+ report_str = f"""
78
+ \n\n## {scheduler_name}
79
+
80
+ ### Sample images
81
+
82
+ {img_str}
83
+
84
+ ### Scores
85
+
86
+ {scores}
87
+ \n\n
88
+ """
89
+
90
+ return report_str
91
+
92
+
93
+ def initialize_pipeline(checkpoint: str):
94
+ sd_pipe = StableDiffusionPipeline.from_pretrained(
95
+ checkpoint, torch_dtype=WEIGHT_DTYPE
96
+ )
97
+ sd_pipe = sd_pipe.to("cuda")
98
+ original_scheduler_config = sd_pipe.scheduler.config
99
+ return sd_pipe, original_scheduler_config
100
+
101
+
102
+ def get_scheduler(scheduler_name: str):
103
+ schedulers_lib = importlib.import_module("diffusers", package="schedulers")
104
+ scheduler_abs = getattr(schedulers_lib, scheduler_name)
105
+
106
+ return scheduler_abs
107
+
108
+
109
+ def get_latents(num_images_per_prompt: int, seed=SEED):
110
+ generator = torch.manual_seed(seed)
111
+ latents = np.random.RandomState(seed).standard_normal(
112
+ (num_images_per_prompt, 4, 64, 64)
113
  )
114
+ latents = torch.from_numpy(latents).to(device="cuda", dtype=WEIGHT_DTYPE)
115
+ return latents
116
+
117
+
118
+ def compute_metrics(images: np.ndarray, prompts: List[str]):
119
+ inception_score_fn.update(torch.from_numpy(images).permute(0, 3, 1, 2))
120
+ inception_score = inception_score_fn.compute()
121
+
122
+ images_int = (images * 255).astype("uint8")
123
+ clip_score = clip_score_fn(
124
+ torch.from_numpy(images_int).permute(0, 3, 1, 2), prompts
125
+ ).detach()
126
+ return {
127
+ "inception_score (⬆️)": {
128
+ "mean": round(float(inception_score[0]), 4),
129
+ "std": round(float(inception_score[1]), 4),
130
+ },
131
+ "clip_score (⬆️)": round(float(clip_score), 4),
132
+ }
133
+
134
+
135
+ def run(
136
+ prompt: str,
137
+ num_images_per_prompt: int,
138
+ num_inference_steps: int,
139
+ checkpoint: str,
140
+ schedulers_to_test: List[str],
141
+ ):
142
+ all_images = {}
143
+
144
+ sd_pipeline, original_scheduler_config = initialize_pipeline(checkpoint)
145
+ latents = get_latents(num_images_per_prompt)
146
+ prompts = [prompt] * num_images_per_prompt
147
+
148
+ images = sd_pipeline(
149
+ prompts,
150
+ latents=latents,
151
+ num_inference_steps=num_inference_steps,
152
+ output_type="numpy",
153
+ ).images
154
+ original_scheduler_name = original_scheduler_config._class_name
155
+
156
+ all_images.update(
157
+ {
158
+ original_scheduler_name: {
159
+ "images": make_grid(numpy_to_pil(images), 1, num_images_per_prompt),
160
+ "scores": compute_metrics(images, prompts),
161
+ }
162
+ }
163
+ )
164
+ # print("First scheduler complete.")
165
+
166
+ for scheduler_name in schedulers_to_test:
167
+ if scheduler_name == original_scheduler_name:
168
+ continue
169
+ scheduler_cls = get_scheduler(scheduler_name)
170
+ current_scheduler = scheduler_cls.from_config(original_scheduler_config)
171
+ sd_pipeline.scheduler = current_scheduler
172
+
173
+ cur_scheduler_images = sd_pipeline(
174
+ prompts, num_inference_steps=num_inference_steps, output_type="numpy"
175
+ ).images
176
+ all_images.update(
177
+ {
178
+ scheduler_name: {
179
+ "images": make_grid(
180
+ numpy_to_pil(cur_scheduler_images), 1, num_images_per_prompt
181
+ ),
182
+ "scores": compute_metrics(cur_scheduler_images, prompts),
183
+ }
184
+ }
185
+ )
186
+ # print(f"{scheduler_name} complete.")
187
+
188
+ output_str = ""
189
+ for scheduler_name in all_images:
190
+ # print(f"scheduler_name: {scheduler_name}")
191
+ output_str += prepare_report(scheduler_name, all_images[scheduler_name])
192
+ # print(output_str)
193
+ return output_str
194
+
195
+
196
+ demo = gr.Interface(
197
+ run,
198
+ inputs=[
199
+ gr.Text(max_lines=1, placeholder="a painting of a dog"),
200
+ gr.Slider(3, 10, value=3, step=1),
201
+ gr.Slider(10, 100, value=50, step=1),
202
+ gr.Dropdown(
203
+ [
204
+ "CompVis/stable-diffusion-v1-4",
205
+ "runwayml/stable-diffusion-v1-5",
206
+ "stabilityai/stable-diffusion-2-base",
207
+ ],
208
+ value="CompVis/stable-diffusion-v1-4",
209
+ multiselect=False,
210
+ interactive=True,
211
+ ),
212
+ gr.Dropdown(
213
+ [
214
+ "EulerDiscreteScheduler",
215
+ "PNDMScheduler",
216
+ "LMSDiscreteScheduler",
217
+ "DPMSolverMultistepScheduler",
218
+ "DDIMScheduler",
219
+ ],
220
+ value=["LMSDiscreteScheduler"],
221
+ multiselect=True,
222
+ ),
223
+ ],
224
+ outputs=[gr.Markdown().style()],
225
+ title=TITLE,
226
+ description=DESCRIPTION,
227
+ allow_flagging=False,
228
+ )
229
+ demo.launch()
230
+
231
+ with gr.Blocks() as demo:
232
+ with gr.Row():
233
+ with gr.Column():
234
+ prompt = gr.Text(max_lines=1, placeholder="a painting of a dog")
235
+ num_images_per_prompt = gr.Slider(3, 10, value=3, step=1)
236
+ num_inference_steps = gr.Slider(10, 100, value=50, step=1)
237
+ model_ckpt = gr.Dropdown(
238
+ [
239
+ "CompVis/stable-diffusion-v1-4",
240
+ "runwayml/stable-diffusion-v1-5",
241
+ "stabilityai/stable-diffusion-2-base",
242
+ "Other"
243
+ ],
244
+ value="CompVis/stable-diffusion-v1-4",
245
+ multiselect=False,
246
+ interactive=True,
247
+ )
248
+ other_finedtuned_checkpoints = gr.Textbox(visible=False)
249
+ model_ckpt.change(lambda x: gr.Dropdown.update(visible=x=="Other"), model_ckpt, other_finedtuned_checkpoints)
250
+ schedulers_to_test = gr.Dropdown(
251
+ [
252
+ "EulerDiscreteScheduler",
253
+ "PNDMScheduler",
254
+ "LMSDiscreteScheduler",
255
+ "DPMSolverMultistepScheduler",
256
+ "DDIMScheduler",
257
+ ],
258
+ value=["LMSDiscreteScheduler"],
259
+ multiselect=True,
260
+ )
261
+
262
+ demo.launch()
constants.py DELETED
@@ -1,123 +0,0 @@
1
- css = """
2
- .gradio-container {
3
- font-family: 'IBM Plex Sans', sans-serif;
4
- }
5
- .gr-button {
6
- color: white;
7
- border-color: black;
8
- background: black;
9
- }
10
- input[type='range'] {
11
- accent-color: black;
12
- }
13
- .dark input[type='range'] {
14
- accent-color: #dfdfdf;
15
- }
16
- .container {
17
- max-width: 730px;
18
- margin: auto;
19
- padding-top: 1.5rem;
20
- }
21
- #gallery {
22
- min-height: 22rem;
23
- margin-bottom: 15px;
24
- margin-left: auto;
25
- margin-right: auto;
26
- border-bottom-right-radius: .5rem !important;
27
- border-bottom-left-radius: .5rem !important;
28
- }
29
- #gallery>div>.h-full {
30
- min-height: 20rem;
31
- }
32
- .details:hover {
33
- text-decoration: underline;
34
- }
35
- .gr-button {
36
- white-space: nowrap;
37
- }
38
- .gr-button:focus {
39
- border-color: rgb(147 197 253 / var(--tw-border-opacity));
40
- outline: none;
41
- box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
42
- --tw-border-opacity: 1;
43
- --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);
44
- --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px var(--tw-ring-offset-width)) var(--tw-ring-color);
45
- --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
46
- --tw-ring-opacity: .5;
47
- }
48
- #advanced-btn {
49
- font-size: .7rem !important;
50
- line-height: 19px;
51
- margin-top: 12px;
52
- margin-bottom: 12px;
53
- padding: 2px 8px;
54
- border-radius: 14px !important;
55
- }
56
- #advanced-options {
57
- display: none;
58
- margin-bottom: 20px;
59
- }
60
- .footer {
61
- margin-bottom: 45px;
62
- margin-top: 35px;
63
- text-align: center;
64
- border-bottom: 1px solid #e5e5e5;
65
- }
66
- .footer>p {
67
- font-size: .8rem;
68
- display: inline-block;
69
- padding: 0 10px;
70
- transform: translateY(10px);
71
- background: white;
72
- }
73
- .dark .footer {
74
- border-color: #303030;
75
- }
76
- .dark .footer>p {
77
- background: #0b0f19;
78
- }
79
- .acknowledgments h4{
80
- margin: 1.25em 0 .25em 0;
81
- font-weight: bold;
82
- font-size: 115%;
83
- }
84
- .animate-spin {
85
- animation: spin 1s linear infinite;
86
- }
87
- @keyframes spin {
88
- from {
89
- transform: rotate(0deg);
90
- }
91
- to {
92
- transform: rotate(360deg);
93
- }
94
- }
95
- #share-btn-container {
96
- display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
97
- margin-top: 10px;
98
- margin-left: auto;
99
- }
100
- #share-btn {
101
- all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;right:0;
102
- }
103
- #share-btn * {
104
- all: unset;
105
- }
106
- #share-btn-container div:nth-child(-n+2){
107
- width: auto !important;
108
- min-height: 0px !important;
109
- }
110
- #share-btn-container .wrap {
111
- display: none !important;
112
- }
113
-
114
- .gr-form{
115
- flex: 1 1 50%; border-top-right-radius: 0; border-bottom-right-radius: 0;
116
- }
117
- #prompt-container{
118
- gap: 0;
119
- }
120
- #prompt-text-input, #negative-prompt-text-input{padding: .45rem 0.625rem}
121
- #component-16{border-top-width: 1px!important;margin-top: 1em}
122
- .image_duplication{position: absolute; width: 100px; left: 50px}
123
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torchmetrics[image]
2
+ transformers
3
+ diffusers
4
+ accelerate
5
+ numpy