SD 2.1 unet NaNs in output

#5
by ptx0 - opened

Hey Ollin, I'm sorry if this isn't the best place to begin a discussion on the issue, but I was wondering if you have any idea why SD 2.1-v's unet produces NaNs in some positions on each iteration in Diffusers, Auto, and ComfyUI?

I did some mucking-about and discovered that running the VAE in fp32 doesn't help avoid black outputs there, because each iteration step operates on bad latents with many NaN positions. Doing stupid things like setting these positions to zero or 1e-10 doesn't fix anything; it only reveals that the model is returning pure noise.
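For reference, the NaN check and the zero-fill workaround described above could look roughly like this. The `nan_report` helper and the random stand-in tensor are made up for illustration; real latents would come from the pipeline's denoising loop. Zero-filling only hides the NaNs, it doesn't recover what the unet should have produced.

```python
import torch


def nan_report(latents: torch.Tensor) -> dict:
    """Count NaN/inf positions in a tensor (hypothetical helper, not part of diffusers)."""
    return {
        "nan": int(torch.isnan(latents).sum()),
        "inf": int(torch.isinf(latents).sum()),
        "total": latents.numel(),
    }


# Stand-in for a latent tensor; the shape is an assumption (768px / 8 = 96).
latents = torch.randn(1, 4, 96, 96)
latents[0, 0, :2, :3] = float("nan")  # poison 6 positions

report = nan_report(latents)

# The "set these positions to zero" workaround: removes the NaNs,
# but the surrounding values are whatever the model returned.
patched = torch.nan_to_num(latents, nan=0.0)
patched_report = nan_report(patched)

print(report)
print(patched_report)
```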

2.1-base (512px) seems unaffected by the problem, and other v-prediction models (e.g. ptx0/terminus-xl-gamma-v2-1) do not exhibit it at fp16 precision either, although that model uses OpenCLIP-G/14 and CLIP-L/14 instead of OpenCLIP-H/14.

Hmm, do you have an example notebook/script reproducing the issue? If it's a model problem (like the SDXL VAE NaNs), I have some helper code for logging activations that could narrow things down.

Essentially, manually casting the unet to fp16 for stabilityai/stable-diffusion-2-1 will reproduce the issue:

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained('stabilityai/stable-diffusion-2-1')
pipe.unet = pipe.unet.to(torch.float16)

The vae and text encoder can run at fp16 just fine.

Also, we then discovered that the single-file, non-EMA pruned 2.1-v ckpt works correctly at fp16.

Digging around, it looks like Patrick summarized the sd2.1-768 fp16 issue and mitigations here https://github.com/huggingface/diffusers/issues/1614#issuecomment-1385093065. I haven't figured out how to repro the NaNs in latest diffusers yet (I guess the mitigations are working)

In my quick test, it looks like the overflowing values are probably in the attn1 of the last two up_blocks. So maybe you only need to upcast those? But I still haven't figured out how to disable upcast and repro NaNs :p
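As a rough illustration of why an overflowing attention score ends up as NaN (and why upcasting helps), here's a minimal sketch. The numbers are made up, this is not the diffusers attention code, and real attention also scales the scores. The point is just that fp16 tops out at 65504, and the max-subtraction step of a softmax then does inf - inf.

```python
import torch

# Largest finite fp16 value; anything bigger becomes inf.
fp16_max = torch.finfo(torch.float16).max  # 65504.0

# Made-up query/key vectors with largish activations.
q = torch.full((64,), 40.0)
k = torch.full((64,), 40.0)
score = (q * k).sum()  # 64 * 1600 = 102400, fine in fp32

logits_fp32 = torch.stack([score, score])
logits_fp16 = logits_fp32.to(torch.float16)  # both entries overflow to inf


def shift_by_max(logits):
    # Softmax implementations typically subtract the max for stability.
    return logits - logits.max()


shifted_fp16 = shift_by_max(logits_fp16)  # inf - inf -> NaN
shifted_fp32 = shift_by_max(logits_fp32)  # 0, 0 -> softmax is well defined

print(fp16_max, logits_fp16, shifted_fp16, shifted_fp32)
```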

import torch as th
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=th.float16)
pipe.unet = pipe.unet.to(th.float16)
pipe = pipe.to("cuda")

def summarize_tensor(x):
    # ANSI color by magnitude: green < 10, yellow < 100, red otherwise (including inf/NaN)
    peak = x.abs().max().item()
    color = next((c for max_val, c in ((10, "2"), (100, "3")) if peak < max_val), "1")
    return f"\033[3{color}m (min {x.min().item():04.4f} / mean {x.mean().item():04.4f} / max {x.max().item():04.4f}\033[0m)"


class ModelActivationPrinter:
    def __init__(self, module, submodules_to_log):
        self.id_to_name = {
            id(submodule): str(name) for name, submodule in module.named_modules()
        }
        self.submodules = submodules_to_log
        self.hooks = []

    def __enter__(self, *args, **kwargs):
        def log_activations(m, m_in, m_out):
            label = self.id_to_name.get(id(m), "(unnamed)") + " output"
            if isinstance(m_out, (tuple, list)):
                m_out = m_out[0]
                label += "[0]"
            print(label.ljust(96) + summarize_tensor(m_out))

        for m in self.submodules:
            self.hooks.append(m.register_forward_hook(log_activations))
        return self

    def __exit__(self, *args, **kwargs):
        for hook in self.hooks:
            hook.remove()

def select_modules(model):
    modules = []
    for m in model.modules():
        if hasattr(m, "to_q"):
            modules.append(m.to_q)
        if hasattr(m, "to_k"):
            modules.append(m.to_k)
    return modules

with ModelActivationPrinter(pipe.unet, select_modules(pipe.unet)):
    image = pipe("slice of delicious New York-style berry cheesecake", num_inference_steps=15).images[0]
    display(image)
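If the goal is just to find where things first go bad, rather than printing every activation, a variation on the same forward-hook idea could record the first submodule whose output contains NaN or inf. This is only a sketch: `find_first_nonfinite` is a hypothetical helper, and the toy `nn.Sequential` below stands in for the unet so the example runs without downloading anything.

```python
import torch
from torch import nn


def find_first_nonfinite(model: nn.Module, *inputs):
    """Run one forward pass; return the name of the first submodule
    (in execution order) whose output has NaN/inf, or None."""
    first = []
    hooks = []

    def make_hook(name):
        def hook(module, args, output):
            out = output[0] if isinstance(output, (tuple, list)) else output
            if not first and torch.is_tensor(out) and not torch.isfinite(out).all():
                first.append(name)
        return hook

    for name, module in model.named_modules():
        if name:  # skip the root module itself
            hooks.append(module.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        # Always detach the hooks, even if the forward pass raises.
        for h in hooks:
            h.remove()
    return first[0] if first else None


# Toy stand-in for the unet: poison the second layer's weights with NaN.
toy = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4), nn.ReLU())
clean_result = find_first_nonfinite(toy, torch.randn(2, 4))
with torch.no_grad():
    toy[1].weight.fill_(float("nan"))
bad_result = find_first_nonfinite(toy, torch.randn(2, 4))

print(clean_result, bad_result)
```

With the real pipeline you would pass the unet's actual inputs, or register the hooks and let `pipe(...)` drive the forward passes as in the script above.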
