How is this achieved?

opened by gkorepanov

Hi, can you share the details of how you have fixed the VAE in fp16?

Thank you, this is a nice work!

gkorepanov changed discussion status to closed

@madebyollin would you be willing to share the code (however nice or not it might be) of how you did it? I'm interested in learning how something like this is done.

@Kubuxu No code, sorry, too messy (+too much of it changed during training).

Some notes on fine-tuning process:

  • I mostly trained in bfloat16 to avoid OOM

  • I watched activation-map magnitudes + output deltas on a test image and occasionally rebalanced the match-original-output and make-activation-maps-smaller losses by hand (a rough sketch of such a two-term objective is below). [screenshots: activation-map magnitudes and output deltas]

  • I froze the weight matrices and only fine-tuned biases / normalization layers / a single scalar for each weight matrix (see the parameter-freezing sketch after this list; screenshot for decoder).
    [screenshot: decoder's trainable parameters]
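
Not their actual code, but a minimal sketch of what such a two-term objective could look like, assuming a diffusers-style AutoencoderKL; the names `vae_ft`, `vae_orig`, `act_weight`, and the hook helper are all hypothetical:

```python
import torch
import torch.nn.functional as F

# Hypothetical names throughout: vae_ft is the model being fine-tuned,
# vae_orig a frozen copy of the original VAE, act_weight the hand-tuned knob.

def track_peak_activations(vae, store):
    """Attach forward hooks that record the peak |activation| of each conv."""
    hooks = []
    for name, module in vae.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            hooks.append(module.register_forward_hook(
                lambda mod, inp, out, n=name: store.update({n: out.abs().max()})))
    return hooks

def loss_fn(vae_ft, vae_orig, x, act_weight=1e-4):
    acts = {}
    hooks = track_peak_activations(vae_ft, acts)
    out_ft = vae_ft(x).sample                 # diffusers AutoencoderKL-style call
    with torch.no_grad():
        out_orig = vae_orig(x).sample
    for h in hooks:
        h.remove()
    match_loss = F.mse_loss(out_ft, out_orig)            # stay close to the original
    act_loss = torch.stack(list(acts.values())).mean()   # push peak activations down
    return match_loss + act_weight * act_loss
```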
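
And a guess at how the frozen-weights / per-matrix-scalar setup could be expressed, here via PyTorch's parametrization utility (the class and function names are invented; the real approach may have differed):

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

# Hypothetical sketch: freeze every weight matrix and train only biases,
# normalization parameters, and one scalar multiplier per frozen matrix.

class ScaleOnly(nn.Module):
    """Reparameterize a frozen weight as weight * (single learned scalar)."""
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(()))

    def forward(self, weight):
        return weight * self.scale

def freeze_all_but_biases_norms_scales(vae):
    for p in vae.parameters():                      # 1) freeze everything
        p.requires_grad_(False)
    for module in vae.modules():                    # 2) unfreeze biases + norms
        if isinstance(module, nn.GroupNorm):
            for p in module.parameters():
                p.requires_grad_(True)
        bias = getattr(module, "bias", None)
        if isinstance(bias, nn.Parameter):
            bias.requires_grad_(True)
    for module in vae.modules():                    # 3) one scalar per weight matrix
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            parametrize.register_parametrization(module, "weight", ScaleOnly())
    return [p for p in vae.parameters() if p.requires_grad]
```

Only the parameters returned here would go to the optimizer; everything else stays at its original value (up to the learned scale).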

Some speculation on what might have happened to the original VAE:

  • I think Stability zero-initialized final convs in UNet resblocks, but not VAE resblocks, which I think leads to variance increasing with VAE depth (per FixUp / ReZero papers)
  • I think Stability initialized up / down convs with too-big weights (default PyTorch initialization assumes a ReLU-like nonlinearity afterwards, but these convs have no nonlinearity), which I think will increase variance after each of these convs (per the He initialization paper; see the toy check after this list)
  • If you compare the original / fixed weights, the resblock final convs and up / down convs are mostly what shrank, which seems like weak evidence in favor of these weights being too large initially. [screenshot: original vs. fixed weight comparison]
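
A quick toy check of the variance intuition behind the first two bullets (purely illustrative; the real VAE has norms and SiLU activations, and its actual initialization is exactly what's being speculated about):

```python
import torch
import torch.nn as nn

# Toy check, not the SD VAE: stacked residual blocks y = x + f(x),
# with and without a zero-initialized final conv.

def resblock(ch, zero_init_last):
    conv1 = nn.Conv2d(ch, ch, 3, padding=1)
    conv2 = nn.Conv2d(ch, ch, 3, padding=1)
    nn.init.kaiming_normal_(conv1.weight, nonlinearity="relu")
    nn.init.kaiming_normal_(conv2.weight, nonlinearity="relu")
    if zero_init_last:
        nn.init.zeros_(conv2.weight)      # FixUp / ReZero: block starts as identity
    return nn.Sequential(conv1, nn.ReLU(), conv2)

@torch.no_grad()
def std_vs_depth(zero_init_last, depth=16, ch=64):
    x = torch.randn(1, ch, 32, 32)
    for _ in range(depth):
        x = x + resblock(ch, zero_init_last)(x)   # residual connection
    return x.std().item()

print("no zero-init :", std_vs_depth(False))   # std blows up with depth
print("zero-init    :", std_vs_depth(True))    # std stays near 1

# Same flavor for the up / down convs: ReLU-gain (He) init with no nonlinearity
# afterwards inflates the variance by roughly 2x per conv.
with torch.no_grad():
    conv = nn.Conv2d(64, 64, 3, padding=1, bias=False)
    nn.init.kaiming_normal_(conv.weight, nonlinearity="relu")
    x = torch.randn(1, 64, 32, 32)
    print("std after one no-nonlinearity conv:", conv(x).std().item())  # ~1.4, not ~1.0
```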

That is all speculation though - I don't thoroughly understand the issue yet. I just threw some code together and happened to get it working :)
