import importlib
from functools import partial
from typing import List

import gradio as gr
import numpy as np
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image
from torchmetrics.functional.multimodal import clip_score
from torchmetrics.image.inception import InceptionScore
SEED = 0
WEIGHT_DTYPE = torch.float16

TITLE = "Evaluate Schedulers with StableDiffusionPipeline 🧨"
DESCRIPTION = """
This Space lets you quantitatively compare [different noise schedulers](https://huggingface.co/docs/diffusers/using-diffusers/schedulers) with a [`StableDiffusionPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview).
One application of this Space is to evaluate different schedulers for a given Stable Diffusion checkpoint at a fixed number of inference steps.

Here's how it works:

* The evaluator first sets a seed and generates the initial noise, which is passed as the initial latent to start the image generation process. This is done to ensure a fair comparison.
* This initial latent is used every time the pipeline is run (with different schedulers).
* To quantify the quality of the generated images, we use:
    * [Inception Score](https://en.wikipedia.org/wiki/Inception_score)
    * [CLIP Score](https://arxiv.org/abs/2104.08718)

**Notes**:

* The default scheduler associated with the provided checkpoint is always used for reporting the scores.
* Increasing both the number of images per prompt and the number of inference steps can quickly build up the inference queue and thus result in slowdowns.
"""
inception_score_fn = InceptionScore(normalize=True)
torch.manual_seed(SEED)
clip_score_fn = partial(clip_score, model_name_or_path="openai/clip-vit-base-patch16")


def make_grid(images, rows, cols):
    w, h = images[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, image in enumerate(images):
        grid.paste(image, box=(i % cols * w, i // cols * h))
    return grid
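# Example: make_grid(pil_images, rows=1, cols=3) lays three 512x512 images
# side by side on a single 1536x512 canvas.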


# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_utils.py#L814
def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a PIL image.
    """
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # special case for grayscale (single channel) images
        pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
    else:
        pil_images = [Image.fromarray(image) for image in images]
    return pil_images


def prepare_report(scheduler_name: str, results: dict):
    image_grid = results["images"]
    scores = results["scores"]

    image_name = f"{scheduler_name}_images.png"
    image_grid.save(image_name)
    img_str = f"![img_grid_{scheduler_name}](/file=./{image_name})\n"

    report_str = f"""
\n\n## {scheduler_name}

### Sample images

{img_str}

### Scores

{scores}
\n\n
"""
    return report_str
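# Note: WEIGHT_DTYPE is float16, so the pipeline below must live on a CUDA
# device; running on CPU would require switching to torch.float32.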


def initialize_pipeline(checkpoint: str):
    sd_pipe = StableDiffusionPipeline.from_pretrained(
        checkpoint, torch_dtype=WEIGHT_DTYPE
    )
    sd_pipe = sd_pipe.to("cuda")
    original_scheduler_config = sd_pipe.scheduler.config
    return sd_pipe, original_scheduler_config


def get_scheduler(scheduler_name: str):
    # Resolve the scheduler class by name from the top-level diffusers namespace.
    schedulers_lib = importlib.import_module("diffusers")
    scheduler_cls = getattr(schedulers_lib, scheduler_name)
    return scheduler_cls


def get_latents(num_images_per_prompt: int, seed=SEED):
    # Sample the initial latents with a fixed NumPy RNG so every scheduler
    # starts from exactly the same noise.
    latents = np.random.RandomState(seed).standard_normal(
        (num_images_per_prompt, 4, 64, 64)
    )
    latents = torch.from_numpy(latents).to(device="cuda", dtype=WEIGHT_DTYPE)
    return latents
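# The (4, 64, 64) latent shape corresponds to 512x512 pixel outputs: Stable
# Diffusion's VAE downsamples by a factor of 8 and uses 4 latent channels.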


def compute_metrics(images: np.ndarray, prompts: List[str]):
    # Reset so that images scored for a previous scheduler don't leak into
    # this computation (torchmetrics metrics accumulate state across calls).
    inception_score_fn.reset()
    inception_score_fn.update(torch.from_numpy(images).permute(0, 3, 1, 2))
    inception_score = inception_score_fn.compute()

    images_int = (images * 255).astype("uint8")
    sd_clip_score = clip_score_fn(
        torch.from_numpy(images_int).permute(0, 3, 1, 2), prompts
    ).detach()
    return {
        "inception_score (⬆️)": {
            "mean": round(float(inception_score[0]), 4),
            "std": round(float(inception_score[1]), 4),
        },
        "clip_score (⬆️)": round(float(sd_clip_score), 4),
    }
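# Higher is better for both metrics: Inception Score tracks image quality and
# diversity, while CLIP score measures how well images match their prompts.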


def run(
    prompt: str,
    num_images_per_prompt: int,
    num_inference_steps: int,
    checkpoint: str,
    schedulers_to_test: List[str],
):
    all_images = {}
    sd_pipeline, original_scheduler_config = initialize_pipeline(checkpoint)
    latents = get_latents(num_images_per_prompt)
    prompts = [prompt] * num_images_per_prompt

    # First pass: the checkpoint's default scheduler.
    images = sd_pipeline(
        prompts,
        latents=latents,
        num_inference_steps=num_inference_steps,
        output_type="numpy",
    ).images
    original_scheduler_name = original_scheduler_config._class_name
    all_images.update(
        {
            original_scheduler_name: {
                "images": make_grid(numpy_to_pil(images), 1, num_images_per_prompt),
                "scores": compute_metrics(images, prompts),
            }
        }
    )

    # Subsequent passes: each requested scheduler, reusing the same latents.
    for scheduler_name in schedulers_to_test:
        if scheduler_name == original_scheduler_name:
            continue
        scheduler_cls = get_scheduler(scheduler_name)
        sd_pipeline.scheduler = scheduler_cls.from_config(original_scheduler_config)
        cur_scheduler_images = sd_pipeline(
            prompts,
            latents=latents,
            num_inference_steps=num_inference_steps,
            output_type="numpy",
        ).images
        all_images.update(
            {
                scheduler_name: {
                    "images": make_grid(
                        numpy_to_pil(cur_scheduler_images), 1, num_images_per_prompt
                    ),
                    "scores": compute_metrics(cur_scheduler_images, prompts),
                }
            }
        )

    output_str = ""
    for scheduler_name in all_images:
        output_str += prepare_report(scheduler_name, all_images[scheduler_name])
    return output_str
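# Example invocation (hypothetical values):
#   run("a painting of a dog", 3, 50, "CompVis/stable-diffusion-v1-4",
#       ["DDIMScheduler", "EulerDiscreteScheduler"])
# returns a Markdown string with one report section per scheduler.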


with gr.Blocks() as demo:
    gr.HTML(f"<div align='center'>{TITLE}</div>")
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            prompt = gr.Text(
                max_lines=1, placeholder="a painting of a dog", label="Prompt"
            )
            num_images_per_prompt = gr.Slider(
                3, 10, value=3, step=1, label="Number of images per prompt"
            )
            num_inference_steps = gr.Slider(
                10, 100, value=50, step=1, label="Number of inference steps"
            )
            model_ckpt = gr.Dropdown(
                [
                    "CompVis/stable-diffusion-v1-4",
                    "runwayml/stable-diffusion-v1-5",
                    "stabilityai/stable-diffusion-2-base",
                    "Other",
                ],
                value="CompVis/stable-diffusion-v1-4",
                multiselect=False,
                interactive=True,
                label="Model checkpoint",
            )
            other_finetuned_checkpoints = gr.Text(
                visible=False,
                placeholder="valhalla/sd-pokemon-model",
                label="Other finetuned checkpoint",
            )
            model_ckpt.change(
                lambda x: gr.update(visible=x == "Other"),
                model_ckpt,
                other_finetuned_checkpoints,
            )
            schedulers_to_test = gr.Dropdown(
                [
                    "EulerDiscreteScheduler",
                    "PNDMScheduler",
                    "LMSDiscreteScheduler",
                    "DPMSolverMultistepScheduler",
                    "DDIMScheduler",
                ],
                value=["LMSDiscreteScheduler"],
                multiselect=True,
                label="Schedulers to test",
            )
            evaluation_button = gr.Button(value="Submit")
        with gr.Column():
            report = gr.Markdown(label="Evaluation Report")

    def resolve_and_run(
        prompt, num_images_per_prompt, num_inference_steps, ckpt, other_ckpt, schedulers
    ):
        # Prefer the free-text checkpoint when "Other" is selected.
        checkpoint = other_ckpt if ckpt == "Other" else ckpt
        return run(
            prompt, num_images_per_prompt, num_inference_steps, checkpoint, schedulers
        )

    evaluation_button.click(
        fn=resolve_and_run,
        inputs=[
            prompt,
            num_images_per_prompt,
            num_inference_steps,
            model_ckpt,
            other_finetuned_checkpoints,
            schedulers_to_test,
        ],
        outputs=report,
    )

demo.launch()
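# To run this app locally (the package list is an assumption, not pinned by
# the original Space):
#   pip install torch diffusers transformers accelerate gradio torchmetrics torch-fidelity
#   python app.py
# A CUDA GPU is required since the pipeline is loaded in float16.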