sayakpaul HF staff commited on
Commit
4eea26f
1 Parent(s): 3e190bd

add: initial files.

Browse files
Files changed (3) hide show
  1. README.md +2 -2
  2. app.py +205 -0
  3. requirements.txt +5 -0
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
- title: Evaluate Sd Schedulers
3
- emoji: 👁
4
  colorFrom: gray
5
  colorTo: indigo
6
  sdk: gradio
 
1
  ---
2
+ title: Evaluate StableDiffusionPipeline with Different Schedulers
3
+ emoji:
4
  colorFrom: gray
5
  colorTo: indigo
6
  sdk: gradio
app.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ from functools import partial
3
+ from typing import List
4
+
5
+ import gradio as gr
6
+ import numpy as np
7
+ import torch
8
+ from diffusers import StableDiffusionPipeline
9
+ from PIL import Image
10
+ from torchmetrics.functional.multimodal import clip_score
11
+ from torchmetrics.image.inception import InceptionScore
12
+
13
+ SEED = 0
14
+ WEIGHT_DTYPE = torch.float16
15
+
16
+
17
+ inception_score_fn = InceptionScore(normalize=True)
18
+ torch.manual_seed(SEED)
19
+ clip_score_fn = partial(clip_score, model_name_or_path="openai/clip-vit-base-patch16")
20
+
21
+
22
+ def make_grid(images, rows, cols):
23
+ w, h = images[0].size
24
+ grid = Image.new("RGB", size=(cols * w, rows * h))
25
+ for i, image in enumerate(images):
26
+ grid.paste(image, box=(i % cols * w, i // cols * h))
27
+ return grid
28
+
29
+
30
+ # Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_utils.py#L814
31
+ def numpy_to_pil(images):
32
+ """
33
+ Convert a numpy image or a batch of images to a PIL image.
34
+ """
35
+ if images.ndim == 3:
36
+ images = images[None, ...]
37
+ images = (images * 255).round().astype("uint8")
38
+ if images.shape[-1] == 1:
39
+ # special case for grayscale (single channel) images
40
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
41
+ else:
42
+ pil_images = [Image.fromarray(image) for image in images]
43
+
44
+ return pil_images
45
+
46
+
47
+ def prepare_report(scheduler_name: str, results: dict):
48
+ image_grid = results["images"]
49
+ scores = results["scores"]
50
+ img_str = ""
51
+
52
+ image_name = f"{scheduler_name}_images.png"
53
+ image_grid.save(image_name)
54
+ img_str = f"![img_grid_{scheduler_name}](./{image_name})\n"
55
+
56
+ report_str = f"""
57
+ \n\n## {scheduler_name}
58
+
59
+ ### Sample images
60
+
61
+ {img_str}
62
+
63
+ ### Scores
64
+
65
+ {scores}
66
+ \n\n
67
+ """
68
+
69
+ return report_str
70
+
71
+
72
+ def initialize_pipeline(checkpoint: str):
73
+ sd_pipe = StableDiffusionPipeline.from_pretrained(
74
+ checkpoint, torch_dtype=WEIGHT_DTYPE
75
+ )
76
+ sd_pipe = sd_pipe.to("cuda")
77
+ original_scheduler_config = sd_pipe.scheduler.config
78
+ return sd_pipe, original_scheduler_config
79
+
80
+
81
+ def get_scheduler(scheduler_name):
82
+ schedulers_lib = importlib.import_module("diffusers", package="schedulers")
83
+ scheduler_abs = getattr(schedulers_lib, scheduler_name)
84
+
85
+ return scheduler_abs
86
+
87
+
88
+ def get_latents(num_images_per_prompt: int, seed=SEED):
89
+ generator = torch.manual_seed(seed)
90
+ latents = np.random.RandomState(seed).standard_normal(
91
+ (num_images_per_prompt, 4, 64, 64)
92
+ )
93
+ latents = torch.from_numpy(latents).to(device="cuda", dtype=WEIGHT_DTYPE)
94
+ return latents
95
+
96
+
97
+ def compute_metrics(images: np.ndarray, prompts: List[str]):
98
+ inception_score_fn.update(torch.from_numpy(images).permute(0, 3, 1, 2))
99
+ inception_score = inception_score_fn.compute()
100
+
101
+ images_int = (images * 255).astype("uint8")
102
+ clip_score = clip_score_fn(
103
+ torch.from_numpy(images_int).permute(0, 3, 1, 2), prompts
104
+ ).detach()
105
+ return {
106
+ "inception_score (⬆️)": {
107
+ "mean": round(float(inception_score[0]), 4),
108
+ "std": round(float(inception_score[1]), 4),
109
+ },
110
+ "clip_score (⬆️)": round(float(clip_score), 4),
111
+ }
112
+
113
+
114
+ def run(
115
+ prompt: str,
116
+ num_images_per_prompt: int,
117
+ num_inference_steps: int,
118
+ checkpoint: str,
119
+ schedulers_to_test: List[str],
120
+ ):
121
+ all_images = {}
122
+
123
+ sd_pipeline, original_scheduler_config = initialize_pipeline(checkpoint)
124
+ latents = get_latents(num_images_per_prompt)
125
+ prompts = [prompt] * num_images_per_prompt
126
+
127
+ images = sd_pipeline(
128
+ prompts,
129
+ latents=latents,
130
+ num_inference_steps=num_inference_steps,
131
+ output_type="numpy",
132
+ ).images
133
+ original_scheduler_name = original_scheduler_config._class_name
134
+
135
+ all_images.update(
136
+ {
137
+ original_scheduler_name: {
138
+ "images": make_grid(numpy_to_pil(images), 1, num_images_per_prompt),
139
+ "scores": compute_metrics(images, prompts),
140
+ }
141
+ }
142
+ )
143
+ print("First scheduler complete.")
144
+
145
+ for scheduler_name in schedulers_to_test:
146
+ if scheduler_name == original_scheduler_name:
147
+ continue
148
+ scheduler_cls = get_scheduler(scheduler_name)
149
+ current_scheduler = scheduler_cls.from_config(original_scheduler_config)
150
+ sd_pipeline.scheduler = current_scheduler
151
+
152
+ cur_scheduler_images = sd_pipeline(
153
+ prompts, num_inference_steps=num_inference_steps, output_type="numpy"
154
+ ).images
155
+ all_images.update(
156
+ {
157
+ scheduler_name: {
158
+ "images": make_grid(
159
+ numpy_to_pil(cur_scheduler_images), 1, num_images_per_prompt
160
+ ),
161
+ "scores": compute_metrics(cur_scheduler_images, prompts),
162
+ }
163
+ }
164
+ )
165
+ print(f"{scheduler_name} complete.")
166
+
167
+ output_str = ""
168
+ for scheduler_name in all_images:
169
+ print(f"scheduler_name: {scheduler_name}")
170
+ output_str += prepare_report(scheduler_name, all_images[scheduler_name])
171
+ print(output_str)
172
+ return output_str
173
+
174
+
175
+ demo = gr.Interface(
176
+ run,
177
+ inputs=[
178
+ gr.Text(max_lines=1, placeholder="a painting of a dog"),
179
+ gr.Slider(3, 10, value=3),
180
+ gr.Slider(10, 100, value=50),
181
+ gr.Dropdown(
182
+ [
183
+ "CompVis/stable-diffusion-v1-4",
184
+ "runwayml/stable-diffusion-v1-5",
185
+ "stabilityai/stable-diffusion-2-base",
186
+ ],
187
+ value="CompVis/stable-diffusion-v1-4",
188
+ multiselect=False,
189
+ ),
190
+ gr.Dropdown(
191
+ [
192
+ "EulerDiscreteScheduler",
193
+ "PNDMScheduler",
194
+ "LMSDiscreteScheduler",
195
+ "DPMSolverMultistepScheduler",
196
+ "DDIMScheduler",
197
+ ],
198
+ value=["LMSDiscreteScheduler"],
199
+ multiselect=True,
200
+ ),
201
+ ],
202
+ outputs=[gr.Markdown().style()],
203
+ allow_flagging=False,
204
+ )
205
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torchmetrics[image]
2
+ transformers
3
+ diffusers
4
+ accelerate
5
+ numpy