Curious about the methodology of finetuning

#15
by zwd125 - opened

Thank you for the great work!
I'm still a little curious about the methodology. How did you fine-tune the weights while keeping the latent distribution unchanged? Were the encoder and decoder fine-tuned separately? Thanks very much.

Yeah, separately, something like this (the fixed encoder's outputs are never sent to the fixed decoder):

```python
import torch as th
import torch.nn.functional as F

# run_model_and_capture_features(model, x) -> (output, features) is not shown in this thread.

def compute_loss(ims, ref_vae, fix_vae, train_encoder=True, train_decoder=True):
    ref_latents, _ref_features = run_model_and_capture_features(ref_vae.encoder, ims)
    loss = 0
    if train_encoder:
        # Match the original latents; penalize activations with magnitude above 100.
        fix_latents, fix_features = run_model_and_capture_features(fix_vae.encoder, ims)
        loss = loss + F.l1_loss(fix_latents, ref_latents) + 0.0001 * F.relu(th.abs(fix_features) - 100).mean()
    if train_decoder:
        # Both decoders decode the *reference* latents, so the two stages stay decoupled.
        ref_ims, _ref_features = run_model_and_capture_features(ref_vae.decoder, ref_latents)
        fix_ims, fix_features = run_model_and_capture_features(fix_vae.decoder, ref_latents)
        loss = loss + F.l1_loss(fix_ims, ref_ims) + 0.0001 * F.relu(th.abs(fix_features) - 100).mean()
    return loss
```

I initially trained the encoder and decoder in the same run, but later switched to training them individually at a larger batch size.
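The body of `run_model_and_capture_features` isn't shown in the thread, so what exactly `fix_features` contains is unknown. Below is a hypothetical sketch, assuming "features" means the output of every `nn.Conv2d`, flattened and concatenated into one 1-D tensor so that `th.abs(fix_features)` works directly. The real helper may capture different layers. `activation_penalty` is also a hypothetical name for the `0.0001 * F.relu(|x| - 100).mean()` term.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def run_model_and_capture_features(model: nn.Module, x: torch.Tensor):
    """Hypothetical helper: run `model` on `x` and capture every Conv2d output.

    Assumption: "features" = all Conv2d activations, flattened and concatenated
    into one 1-D tensor, so the magnitude penalty can be applied directly.
    """
    captured = []

    def hook(_module, _inputs, output):
        captured.append(output.flatten())

    handles = [m.register_forward_hook(hook) for m in model.modules() if isinstance(m, nn.Conv2d)]
    try:
        out = model(x)
    finally:
        for h in handles:  # always detach the hooks, even if the forward pass fails
            h.remove()
    return out, torch.cat(captured)


def activation_penalty(features: torch.Tensor, threshold: float = 100.0, weight: float = 1e-4):
    # Only magnitudes above `threshold` are penalized; smaller activations cost nothing.
    return weight * F.relu(features.abs() - threshold).mean()


if __name__ == "__main__":
    net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.SiLU(), nn.Conv2d(8, 4, 3, padding=1))
    out, feats = run_model_and_capture_features(net, torch.randn(2, 3, 16, 16))
    print(out.shape, feats.numel())  # torch.Size([2, 4, 16, 16]) 6144
    print(activation_penalty(feats).item())
```

Taking one mean over the concatenation weights every activation equally, so large high-resolution layers dominate; summing a per-layer mean is an equally plausible alternative, and the coefficient (0.0001 here) is what gets rebalanced against the L1 matching loss.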

Got it. Thank you very much for the explanation!

You wrote: "I watched activation-map magnitudes + output deltas on a test image and manually rebalanced the match-original-output and make-activation-maps-smaller losses occasionally."
Is fix_features the concatenation of all intermediate layers' outputs? And did you rebalance the two loss terms by adjusting the 0.0001 scale above?
Thank you very much!

Hey @madebyollin, thanks for your great work.

Could you explain further how you did the "scaling down weights and biases within the network", beyond the F.relu(th.abs(fix_features) - 100).mean() penalty? (Please correct me if I have anything wrong.)

@xyzhang626 I was only training on a small amount of data (~10k images or something), so to make sure I wasn't breaking anything, I froze quant_conv and post_quant_conv as well as all of the Conv2d / Linear weight matrices, then gave each weight matrix a single trainable scale parameter (initialized to 1). So the finetuning process only used 37802 trainable decoder parameters (and a similarly small number in the encoder).
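One way to express "frozen weight matrix plus a single trainable scale" in PyTorch is a wrapper module. This is a hypothetical sketch, not the author's code: `ScaledConv2d`, `wrap_convs`, `unwrap_convs` and `count_trainable` are illustrative names, only `nn.Conv2d` is handled (a `Linear` version would be analogous), biases and GroupNorm parameters stay trainable as described in the thread, and the skip list assumes `quant_conv` / `post_quant_conv` are direct children of the VAE module. The Adam optimizer is an assumption; the thread only gives the learning rate (3e-4).

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledConv2d(nn.Module):
    """Hypothetical wrapper: frozen pretrained weight x one trainable scalar; bias stays trainable."""

    def __init__(self, conv: nn.Conv2d):
        super().__init__()
        assert conv.padding_mode == "zeros", "sketch only handles zero padding"
        self.conv = conv
        self.conv.weight.requires_grad_(False)  # freeze the weight matrix itself
        self.scale = nn.Parameter(torch.ones(()))  # single scale per weight matrix, init 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        c = self.conv
        return F.conv2d(x, c.weight * self.scale, c.bias, c.stride, c.padding, c.dilation, c.groups)


def wrap_convs(module: nn.Module, skip=("quant_conv", "post_quant_conv")) -> None:
    """Recursively swap each nn.Conv2d for a ScaledConv2d; fully freeze children named in `skip`."""
    for name, child in list(module.named_children()):
        if name in skip:
            child.requires_grad_(False)
        elif isinstance(child, nn.Conv2d):
            setattr(module, name, ScaledConv2d(child))
        else:
            wrap_convs(child, skip)


@torch.no_grad()
def unwrap_convs(module: nn.Module) -> None:
    """Fold each learned scale back into its weight, restoring plain nn.Conv2d layers."""
    for name, child in list(module.named_children()):
        if isinstance(child, ScaledConv2d):
            child.conv.weight.mul_(child.scale)
            child.conv.weight.requires_grad_(True)
            setattr(module, name, child.conv)
        else:
            unwrap_convs(child)


def count_trainable(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


if __name__ == "__main__":
    net = nn.Sequential(nn.Conv2d(4, 8, 3, padding=1), nn.GroupNorm(4, 8), nn.SiLU(), nn.Conv2d(8, 3, 3, padding=1))
    ref = copy.deepcopy(net)
    wrap_convs(net)
    print(count_trainable(net))  # 29 = (8 + 1) + 16 + (3 + 1)
    x = torch.randn(1, 4, 8, 8)
    print(torch.allclose(net(x), ref(x)))  # True: every scale starts at 1
    # Assumption: Adam; the thread only mentions lr 3e-4.
    opt = torch.optim.Adam([p for p in net.parameters() if p.requires_grad], lr=3e-4)
```

Because each scale starts at 1, the wrapped model initially reproduces the original outputs exactly, and `unwrap_convs` folds the learned scales back in so the result can be saved with the original architecture.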

Thanks for your explanation. Really helpful!

@madebyollin Thanks for your great work.
I trained a decoder from scratch with a frozen encoder, but the activations are too large to run in fp16. Luckily, I found your solution, and I'd like some details. First, does each Conv2d weight matrix ([Cout, Cin, kh, kw]) get a single trainable scale? If so, why are there "37802 trainable decoder parameters"? And are the GroupNorm weights and biases trainable? Second, you said "I was only training on a small amount of data (~10k images or something)"; what number of epochs and learning rate did you use?

@Escapist

First, does each Conv2d weight matrix ([Cout, Cin, kh, kw]) get a single trainable scale? If so, why "37802 trainable decoder parameters"? And are the GroupNorm weights and biases trainable?

Yeah, single scale per conv weight matrix. The norms are all trainable.

Trainable params by layer
Module                                   Trainable Parameters
-------------------------------------------------------------
vae                                      37802           
vae.decoder                              37802           
vae.decoder.conv_in                      513             
vae.decoder.mid                          9224            
vae.decoder.mid.block_1                  3074            
vae.decoder.mid.block_1.norm1            1024            
vae.decoder.mid.block_1.conv1            513             
vae.decoder.mid.block_1.norm2            1024            
vae.decoder.mid.block_1.conv2            513             
vae.decoder.mid.attn_1                   3076            
vae.decoder.mid.attn_1.norm              1024            
vae.decoder.mid.attn_1.q                 513             
vae.decoder.mid.attn_1.k                 513             
vae.decoder.mid.attn_1.v                 513             
vae.decoder.mid.attn_1.proj_out          513             
vae.decoder.mid.block_2                  3074            
vae.decoder.mid.block_2.norm1            1024            
vae.decoder.mid.block_2.conv1            513             
vae.decoder.mid.block_2.norm2            1024            
vae.decoder.mid.block_2.conv2            513             
vae.decoder.up                           27805           
vae.decoder.up.0                         2695            
vae.decoder.up.0.block                   2695            
vae.decoder.up.0.block.0                 1155            
vae.decoder.up.0.block.0.norm1           512             
vae.decoder.up.0.block.0.conv1           129             
vae.decoder.up.0.block.0.norm2           256             
vae.decoder.up.0.block.0.conv2           129             
vae.decoder.up.0.block.0.nin_shortcut    129             
vae.decoder.up.0.block.1                 770             
vae.decoder.up.0.block.1.norm1           256             
vae.decoder.up.0.block.1.conv1           129             
vae.decoder.up.0.block.1.norm2           256             
vae.decoder.up.0.block.1.conv2           129             
vae.decoder.up.0.block.2                 770             
vae.decoder.up.0.block.2.norm1           256             
vae.decoder.up.0.block.2.conv1           129             
vae.decoder.up.0.block.2.norm2           256             
vae.decoder.up.0.block.2.conv2           129             
vae.decoder.up.1                         5640            
vae.decoder.up.1.block                   5383            
vae.decoder.up.1.block.0                 2307            
vae.decoder.up.1.block.0.norm1           1024            
vae.decoder.up.1.block.0.conv1           257             
vae.decoder.up.1.block.0.norm2           512             
vae.decoder.up.1.block.0.conv2           257             
vae.decoder.up.1.block.0.nin_shortcut    257             
vae.decoder.up.1.block.1                 1538            
vae.decoder.up.1.block.1.norm1           512             
vae.decoder.up.1.block.1.conv1           257             
vae.decoder.up.1.block.1.norm2           512             
vae.decoder.up.1.block.1.conv2           257             
vae.decoder.up.1.block.2                 1538            
vae.decoder.up.1.block.2.norm1           512             
vae.decoder.up.1.block.2.conv1           257             
vae.decoder.up.1.block.2.norm2           512             
vae.decoder.up.1.block.2.conv2           257             
vae.decoder.up.1.upsample                257             
vae.decoder.up.1.upsample.conv           257             
vae.decoder.up.2                         9735            
vae.decoder.up.2.block                   9222            
vae.decoder.up.2.block.0                 3074            
vae.decoder.up.2.block.0.norm1           1024            
vae.decoder.up.2.block.0.conv1           513             
vae.decoder.up.2.block.0.norm2           1024            
vae.decoder.up.2.block.0.conv2           513             
vae.decoder.up.2.block.1                 3074            
vae.decoder.up.2.block.1.norm1           1024            
vae.decoder.up.2.block.1.conv1           513             
vae.decoder.up.2.block.1.norm2           1024            
vae.decoder.up.2.block.1.conv2           513             
vae.decoder.up.2.block.2                 3074            
vae.decoder.up.2.block.2.norm1           1024            
vae.decoder.up.2.block.2.conv1           513             
vae.decoder.up.2.block.2.norm2           1024            
vae.decoder.up.2.block.2.conv2           513             
vae.decoder.up.2.upsample                513             
vae.decoder.up.2.upsample.conv           513             
vae.decoder.up.3                         9735            
vae.decoder.up.3.block                   9222            
vae.decoder.up.3.block.0                 3074            
vae.decoder.up.3.block.0.norm1           1024            
vae.decoder.up.3.block.0.conv1           513             
vae.decoder.up.3.block.0.norm2           1024            
vae.decoder.up.3.block.0.conv2           513             
vae.decoder.up.3.block.1                 3074            
vae.decoder.up.3.block.1.norm1           1024            
vae.decoder.up.3.block.1.conv1           513             
vae.decoder.up.3.block.1.norm2           1024            
vae.decoder.up.3.block.1.conv2           513             
vae.decoder.up.3.block.2                 3074            
vae.decoder.up.3.block.2.norm1           1024            
vae.decoder.up.3.block.2.conv1           513             
vae.decoder.up.3.block.2.norm2           1024            
vae.decoder.up.3.block.2.conv2           513             
vae.decoder.up.3.upsample                513             
vae.decoder.up.3.upsample.conv           513             
vae.decoder.norm_out                     256             
vae.decoder.conv_out                     4     
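The per-layer numbers above follow from two rules: each conv contributes its out_channels biases plus one weight scale, and each GroupNorm contributes 2 × channels (affine weight and bias). A small stdlib-only sketch that re-derives the 37802 total, assuming the standard SD-style decoder config (ch=128, ch_mult=(1, 2, 4, 4), num_res_blocks=2, attention only in the mid block); that config is inferred from the table, not stated in the thread, and the function names are illustrative.

```python
def conv(c_out: int) -> int:
    # c_out trainable biases + 1 trainable scale for the frozen weight matrix
    return c_out + 1


def norm(c: int) -> int:
    # GroupNorm affine weight + bias, both trainable
    return 2 * c


def resblock(c_in: int, c_out: int) -> int:
    n = norm(c_in) + conv(c_out) + norm(c_out) + conv(c_out)
    if c_in != c_out:
        n += conv(c_out)  # nin_shortcut (1x1 conv) when channel count changes
    return n


def decoder_trainable(ch=128, ch_mult=(1, 2, 4, 4), num_res_blocks=2, out_ch=3) -> int:
    block_in = ch * ch_mult[-1]
    total = conv(block_in)  # conv_in
    # mid: block_1, attn_1 (norm + q, k, v, proj_out), block_2
    total += resblock(block_in, block_in) + norm(block_in) + 4 * conv(block_in) + resblock(block_in, block_in)
    for level in reversed(range(len(ch_mult))):  # up.3, up.2, up.1, up.0
        block_out = ch * ch_mult[level]
        for _ in range(num_res_blocks + 1):  # decoder levels have num_res_blocks + 1 blocks
            total += resblock(block_in, block_out)
            block_in = block_out
        if level != 0:
            total += conv(block_in)  # upsample conv
    total += norm(block_in) + conv(out_ch)  # norm_out, conv_out
    return total


print(decoder_trainable())  # 37802
```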

You can verify which parameters were changed, and by how much, by comparing the original and fine-tuned weights using a script like this one.
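A comparison along those lines could look like the sketch below. It is hypothetical: `compare_state_dicts` is an illustrative name, it assumes both state dicts have the same keys and shapes, and it assumes the learned scales were folded back into the weights before export. Under that assumption, each changed weight tensor should be explained almost exactly by one scalar, which the sketch estimates by least squares.

```python
import torch


def compare_state_dicts(original: dict, finetuned: dict, atol: float = 0.0) -> dict:
    """Report tensors that differ between two state dicts (same keys and shapes assumed).

    For weight tensors (ndim > 1), also report the least-squares scalar
    s = <new, old> / <old, old> and the largest residual |new - s * old|.
    """
    report = {}
    for name, old in original.items():
        new = finetuned[name]
        old, new = old.float(), new.float()
        if old.numel() == 0:
            continue
        max_delta = (new - old).abs().max().item()
        if max_delta <= atol:
            continue  # unchanged (within tolerance)
        entry = {"max_abs_delta": max_delta}
        if old.ndim > 1:
            s = (new * old).sum() / (old * old).sum()
            entry["implied_scale"] = s.item()
            entry["residual"] = (new - s * old).abs().max().item()
        report[name] = entry
    return report
```

The state dicts could be loaded with, for example, `safetensors.torch.load_file(path)` on each checkpoint (paths left to the reader). A small residual for a weight means only its scale changed; biases and norm parameters show up with just `max_abs_delta`.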

What number of epochs and learning rate did you use?

lr 3e-4. Not sure about epochs (training was split across a couple of separate fine-tuning runs with different settings; maybe about 100 epochs total?)

@madebyollin Thanks for your explanation, really helpful! I've checked "Trainable params by layer". For example, does "vae.decoder.mid.block_1.conv1 513" mean that the conv weight matrix has one trainable scale and the bias has 512 trainable values, for a total of 513? Or does the conv weight matrix have Cout trainable scales and the bias a single trainable scale?

Yeah, 512 trainable bias parameters and 1 trainable scale for the weight matrix

Sign up or log in to comment