Spaces:
Sleeping
Sleeping
import importlib | |
from typing import List | |
import gradio as gr | |
import numpy as np | |
import torch | |
from diffusers import StableDiffusionPipeline | |
from torchmetrics import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure | |
from image_utils import make_grid, numpy_to_pil | |
from metrics_utils import compute_main_metrics, compute_psnr_or_ssim | |
from report_utils import add_psnr_ssim_to_report, prepare_report | |
# Seed for the shared initial latents so every scheduler starts from identical noise.
SEED = 0
# Precision used both for the pipeline weights and for the initial latents.
WEIGHT_DTYPE = torch.float16
TITLE = "Evaluate Schedulers with StableDiffusionPipeline 🧨"
ABSTRACT = """
This Space allows you to quantitatively compare [different noise schedulers](https://huggingface.co/docs/diffusers/using-diffusers/schedulers) with a [`StableDiffusionPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview).
One of the applications of this Space could be to evaluate different schedulers for a certain Stable Diffusion checkpoint for a fixed number of inference steps.
"""
# Markdown rendered below the app. Fixed typos/grammar relative to the original text.
DESCRIPTION = """
#### How does it work?
* The evaluator first sets a seed and then generates the initial noise which is passed as the initial latent to start the image generation process. This is done to ensure a fair comparison.
* This initial latent is used every time the pipeline is run (with different schedulers).
* To quantify the quality of the generated images we use:
    * [Inception Score](https://en.wikipedia.org/wiki/Inception_score)
    * [CLIP Score](https://arxiv.org/abs/2104.08718)
#### Notes
* When selecting a model checkpoint, if you select "Other" you will have the option to provide a custom Stable Diffusion checkpoint.
* The default scheduler associated with the provided checkpoint is always used for reporting the scores.
* Increasing both the number of images per prompt and the number of inference steps could quickly build up the inference queue and thus result in slowdowns.
"""
# Module-level metric objects shared by every `run` call for the optional
# PSNR / SSIM comparisons (constructed left to right: PSNR first, then SSIM).
psnr_fn, ssim_fn = PeakSignalNoiseRatio(), StructuralSimilarityIndexMeasure()
def initialize_pipeline(checkpoint: str):
    """Load a StableDiffusionPipeline in WEIGHT_DTYPE and move it to CUDA.

    Args:
        checkpoint: model id or path accepted by
            ``StableDiffusionPipeline.from_pretrained``.

    Returns:
        A ``(pipeline, scheduler_config)`` tuple, where ``scheduler_config`` is
        the config of the scheduler the checkpoint ships with. It is read after
        the device move, before any scheduler swapping.
    """
    pipeline = StableDiffusionPipeline.from_pretrained(
        checkpoint, torch_dtype=WEIGHT_DTYPE
    ).to("cuda")
    return pipeline, pipeline.scheduler.config
def get_scheduler(scheduler_name: str, module_name: str = "diffusers"):
    """Return the attribute ``scheduler_name`` of module ``module_name``.

    Used to resolve a scheduler class (e.g. ``"EulerDiscreteScheduler"``) from
    its name, looked up as a top-level attribute of ``diffusers`` by default.

    Args:
        scheduler_name: name of the attribute (scheduler class) to fetch.
        module_name: absolute module name to search; defaults to "diffusers".

    Returns:
        The looked-up object (the class itself, not an instance).

    Raises:
        ModuleNotFoundError: if ``module_name`` cannot be imported.
        AttributeError: if the module has no attribute ``scheduler_name``.
    """
    # The original call passed package="schedulers", but importlib only uses
    # `package` to anchor *relative* names (leading "."), so it had no effect.
    module = importlib.import_module(module_name)
    return getattr(module, scheduler_name)
def get_latents(
    num_images_per_prompt: int,
    seed=SEED,
    latent_shape=(4, 64, 64),
    device="cuda",
):
    """Create deterministic initial latents shared by every scheduler run.

    Args:
        num_images_per_prompt: batch size (leading dimension of the result).
        seed: seed for both torch's global RNG and the NumPy RandomState that
            draws the noise.
        latent_shape: per-image latent shape appended after the batch
            dimension; the default ``(4, 64, 64)`` matches the original
            hard-coded shape.
        device: device the tensor is moved to (original behavior: "cuda").

    Returns:
        A tensor of shape ``(num_images_per_prompt, *latent_shape)`` with
        standard-normal values, cast to ``WEIGHT_DTYPE``.
    """
    # The original bound the returned generator to an unused name; only the
    # global seeding side effect is kept, since later torch randomness
    # (e.g. inside the pipeline/schedulers) may depend on it.
    torch.manual_seed(seed)
    noise = np.random.RandomState(seed).standard_normal(
        (num_images_per_prompt, *latent_shape)
    )
    return torch.from_numpy(noise).to(device=device, dtype=WEIGHT_DTYPE)
def run(
    prompt: str,
    num_images_per_prompt: int,
    num_inference_steps: int,
    checkpoint: str,
    other_finedtuned_checkpoints: str = None,
    schedulers_to_test: List[str] = None,
    ssim: bool = False,
    psnr: bool = False,
    progress=gr.Progress(),
):
    """Generate images with several schedulers and return a markdown report.

    Every scheduler starts from the same latents (see ``get_latents``). The
    checkpoint's own scheduler is always evaluated last.

    Args:
        prompt: text prompt, repeated ``num_images_per_prompt`` times.
        num_images_per_prompt: number of images per scheduler.
        num_inference_steps: denoising steps passed to the pipeline.
        checkpoint: model id, or "Other" to use ``other_finedtuned_checkpoints``.
        other_finedtuned_checkpoints: custom model id used when
            ``checkpoint == "Other"``.
        schedulers_to_test: scheduler class names resolved via
            ``get_scheduler``. The list is not modified; None means no extra
            schedulers.
        ssim: also report SSIM scores (only when more than one scheduler ran).
        psnr: also report PSNR scores (only when more than one scheduler ran).
        progress: Gradio progress tracker.

    Returns:
        The report string, or an error string when "Other" is selected without
        a usable checkpoint id.
    """
    progress(0, desc="Starting...")
    if checkpoint == "Other":
        custom_checkpoint = (other_finedtuned_checkpoints or "").strip()
        if not custom_checkpoint:
            return "❌ No legit checkpoint provided ❌"
        checkpoint = custom_checkpoint

    # Cast defensively in case the sliders deliver floats; list repetition
    # below requires an int.
    num_images_per_prompt = int(num_images_per_prompt)
    num_inference_steps = int(num_inference_steps)

    all_images = {}
    scheduler_images = {}

    # Set up the pipeline.
    sd_pipeline, original_scheduler_config = initialize_pipeline(checkpoint)
    sd_pipeline.set_progress_bar_config(disable=True)
    # Keep the checkpoint's own scheduler so it can be put back before its run.
    # Previously it was never restored, so the "default" entry was generated
    # with whichever scheduler had been swapped in last.
    original_scheduler = sd_pipeline.scheduler

    # Prepare latents to start generation and the prompts.
    latents = get_latents(num_images_per_prompt)
    prompts = [prompt] * num_images_per_prompt

    original_scheduler_name = original_scheduler_config._class_name
    # Work on a de-duplicated copy (order kept), so the caller's list is not
    # mutated. The default scheduler is excluded here and always appended last,
    # so it runs exactly once even if the user also selected it.
    schedulers = [
        name
        for name in dict.fromkeys(schedulers_to_test or [])
        if name != original_scheduler_name
    ]
    schedulers.append(original_scheduler_name)

    # Start generating the images and computing their scores.
    for scheduler_name in progress.tqdm(schedulers):
        if scheduler_name == original_scheduler_name:
            sd_pipeline.scheduler = original_scheduler
        else:
            scheduler_cls = get_scheduler(scheduler_name)
            sd_pipeline.scheduler = scheduler_cls.from_config(
                original_scheduler_config
            )
        cur_scheduler_images = sd_pipeline(
            prompts,
            latents=latents,
            num_inference_steps=num_inference_steps,
            output_type="numpy",
        ).images
        all_images[scheduler_name] = {
            "images": make_grid(
                numpy_to_pil(cur_scheduler_images), 1, num_images_per_prompt
            ),
            "scores": compute_main_metrics(cur_scheduler_images, prompts),
        }
        scheduler_images[scheduler_name] = cur_scheduler_images
        torch.cuda.empty_cache()

    # Prepare output report.
    output_str = "".join(
        prepare_report(scheduler_name, all_images[scheduler_name])
        for scheduler_name in all_images
    )

    # Append PSNR or SSIM if needed; they need at least one non-default run.
    if len(schedulers) > 1:
        ssim_scores = psnr_scores = None
        if ssim:
            ssim_scores = compute_psnr_or_ssim(
                ssim_fn, scheduler_images, original_scheduler_name
            )
        if psnr:
            psnr_scores = compute_psnr_or_ssim(
                psnr_fn, scheduler_images, original_scheduler_name
            )
        ssim_psnr_str = add_psnr_ssim_to_report(
            original_scheduler_name, ssim_scores, psnr_scores
        )
        if ssim_psnr_str:
            output_str += ssim_psnr_str
    return output_str
# Build the Gradio UI: inputs in the left column, the markdown report on the right.
with gr.Blocks(title="Scheduler Evaluation") as demo:
    gr.Markdown(f"## {TITLE}\n\n\n\n{ABSTRACT}")
    with gr.Row():
        with gr.Column():
            prompt = gr.Text(
                max_lines=1, placeholder="a painting of a dog", label="prompt"
            )
            num_images_per_prompt = gr.Slider(
                3, 10, value=3, step=1, label="num_images_per_prompt"
            )
            num_inference_steps = gr.Slider(
                10, 100, value=50, step=1, label="num_inference_steps"
            )
            model_ckpt = gr.Dropdown(
                [
                    "CompVis/stable-diffusion-v1-4",
                    "runwayml/stable-diffusion-v1-5",
                    "stabilityai/stable-diffusion-2-base",
                    "Other",
                ],
                value="CompVis/stable-diffusion-v1-4",
                multiselect=False,
                interactive=True,
                label="model_ckpt",
            )
            # Hidden until "Other" is chosen in `model_ckpt`.
            other_finedtuned_checkpoints = gr.Textbox(
                visible=False,
                interactive=True,
                placeholder="valhalla/sd-pokemon-model",
                label="custom_checkpoint",
            )
            # The target is a Textbox, so use the component-agnostic
            # `gr.update` rather than `gr.Dropdown.update`. The per-component
            # `.update` helpers no longer exist in Gradio 4.
            model_ckpt.change(
                lambda x: gr.update(visible=x == "Other"),
                model_ckpt,
                other_finedtuned_checkpoints,
            )
            schedulers_to_test = gr.Dropdown(
                [
                    "EulerDiscreteScheduler",
                    "PNDMScheduler",
                    "LMSDiscreteScheduler",
                    "DPMSolverMultistepScheduler",
                    "DDIMScheduler",
                ],
                value=["LMSDiscreteScheduler"],
                multiselect=True,
                label="schedulers_to_test",
            )
            ssim = gr.Checkbox(label="Compute SSIM")
            psnr = gr.Checkbox(label="Compute PSNR")
            evaluation_button = gr.Button(value="Submit")
        with gr.Column():
            # The original's argument-less `.style()` was a no-op in Gradio 3
            # and does not exist in Gradio 4, so it is dropped.
            report = gr.Markdown(label="Evaluation Report")
    evaluation_button.click(
        run,
        inputs=[
            prompt,
            num_images_per_prompt,
            num_inference_steps,
            model_ckpt,
            other_finedtuned_checkpoints,
            schedulers_to_test,
            ssim,
            psnr,
        ],
        outputs=report,
    )
    gr.Markdown(DESCRIPTION)
demo.queue().launch(debug=True)