V-Prediction Loss Weighting Test

Notice

This repository contains personal experimental records. No guarantees are made regarding accuracy or reproducibility.
These models are for verification purposes only and are not intended for general use.

Overview

This repository is a test project comparing different loss weighting schemes for Stable Diffusion v-prediction training.

In this test, we focus on two loss weighting schemes that performed particularly well in the previous test:

  1. edm2_weight (edm2_new)
  2. normal_weight

The goal of this test is to evaluate their performance in a practical fine-tuning scenario, specifically:

  • Learning new concepts
  • Strengthening existing concepts

This allows us to assess their effectiveness in common additional training tasks.

Environment

  • sd-scripts dev branch
    • Commit hash: 6adb69b (with local modifications)

Test Cases

This repository includes test models using the following weighting schemes:

  1. FT_test_normal

    • Baseline model using standard weighting
  2. FT_test_edm2

    • EDM2 loss weighting scheme
    • Adaptive per-timestep loss weights, learned by an MLP (lossweightMLP.py)
    • Implementation by madman404
    • Note: The initial implementation contained errors, which have been corrected in the updated version below (FT_test_edm2_fixed). See "Additional Notes on FT_test_edm2_fixed" at the bottom for details.
  3. FT_test_edm2_fixed

    • Corrected implementation of EDM2 loss weighting

    • Modified MLP hyperparameters:

      • Learning rate: 1.5e-2
      • Optimizer: Adafactor
      • Learning rate scheduler: Similar to inverse square root decay
    • Added debugging tools in sdxl_train.py:

      • TensorBoard logging for loss_scaled and MLP learning rate (MLP_lr)
      • Step-by-step saving of loss_weights_MLP = 1 / torch.exp(adaptive_loss_weights)
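For readers unfamiliar with the scheme: in the EDM2 paper (Karras et al., 2024), each noise level gets a learned log-variance u, and the per-sample loss is replaced by loss / exp(u) + u. Minimizing over u drives exp(u) toward the expected loss at that noise level, so the effective weight 1 / exp(u), which is what loss_weights_MLP records, approaches 1 / E[loss]. The pure-Python sketch below illustrates that dynamic under clearly hypothetical simplifications: a lookup table of timestep buckets stands in for the MLP, plain SGD stands in for Adafactor, the losses are synthetic, and the learning rate follows the EDM2-style inverse square root decay base_lr / sqrt(max(step / ref_steps, 1)), which may differ from the "similar" schedule used here. It is not the code in sdxl_train.py or lossweightMLP.py.

```python
import math
import random


def inverse_sqrt_lr(step: int, base_lr: float, ref_steps: int) -> float:
    """EDM2-style decay: constant for ref_steps, then proportional to 1/sqrt(step)."""
    return base_lr / math.sqrt(max(step / ref_steps, 1.0))


def scaled_loss(loss: float, u: float) -> float:
    """Uncertainty weighting: loss / exp(u) + u, with u a learned log-variance."""
    return loss * math.exp(-u) + u


def train_log_variances(mean_losses, steps=2000, base_lr=0.1, ref_steps=100, seed=0):
    """Fit one log-variance u per timestep bucket by SGD on scaled_loss.

    Hypothetical simplification: a lookup table replaces the MLP and plain SGD
    replaces Adafactor. Returns the per-step history of weights 1 / exp(u).
    """
    rng = random.Random(seed)
    u = [0.0] * len(mean_losses)
    history = []
    for step in range(steps):
        lr = inverse_sqrt_lr(step, base_lr, ref_steps)
        for i, mean in enumerate(mean_losses):
            # Synthetic noisy per-sample loss whose expectation is `mean`.
            loss = mean * rng.uniform(0.8, 1.2)
            # d/du [loss * exp(-u) + u] = 1 - loss * exp(-u)
            grad = 1.0 - loss * math.exp(-u[i])
            u[i] -= lr * grad
        # Analogue of saving loss_weights_MLP = 1 / exp(adaptive_loss_weights) each step.
        history.append([math.exp(-ui) for ui in u])
    return history


mean_losses = [0.8, 0.3, 0.1, 0.05]
history = train_log_variances(mean_losses)
for mean, w in zip(mean_losses, history[-1]):
    print(f"mean loss {mean:.2f} -> weight {w:.2f} (1/mean = {1 / mean:.2f})")
```

Buckets with a larger typical loss end up with a smaller weight, which is the balancing effect this weighting is meant to provide.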

Training Parameters

  • Base model: noobaiXLNAIXL_vPred06
  • Dataset:
    • 50k samples from danbooru2023
    • 4k samples for verification (training targets, num_repeats=2)
  • Epochs: 4
  • U-Net only
  • Learning rate: 3.5e-6
  • Effective batch size: 16
  • Optimizer: Adafactor (stochastic rounding)
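As a rough sanity check on training length, the optimizer step count follows from the dataset size, repeats, and effective batch size. This sketch assumes the 50k subset uses num_repeats=1 (the card only states the value for the 4k subset) and ignores aspect-ratio bucketing, which can add a few partial batches per epoch, so the real count may be slightly higher. The function name is hypothetical.

```python
import math


def estimate_total_steps(subsets, effective_batch_size, epochs):
    """Rough optimizer-step count from (image_count, num_repeats) pairs.

    Ignores aspect-ratio bucketing, which can add partial batches per epoch.
    """
    images_per_epoch = sum(count * repeats for count, repeats in subsets)
    steps_per_epoch = math.ceil(images_per_epoch / effective_batch_size)
    return images_per_epoch, steps_per_epoch, steps_per_epoch * epochs


# Assumption: the 50k danbooru2023 subset uses num_repeats=1.
subsets = [(50_000, 1), (4_000, 2)]
print(estimate_total_steps(subsets, effective_batch_size=16, epochs=4))
# (58000, 3625, 14500)
```

Under the same assumption, the 4k training-target subset accounts for 8,000 of the 58,000 samples seen per epoch, about 14%.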

Each model was trained with the sdxl_train.py in its own model directory; the edm2_weight models additionally use lossweightMLP.py.

For detailed parameters, please refer to the .toml files in each model directory.

Dataset Information

The dataset used for testing consists of:

  • 50,000 images extracted from danbooru2023
  • 4,000 carefully selected images for verification (training targets, num_repeats=2)
    • Includes various styles and pure black/white images (~20 each)

Learning Targets

  • For the list of training targets, refer to training_targets.txt.
  • For the list of art styles, you can also refer to wildcard_artstyle.txt.

Tag Format

The training follows the tag format from Kohaku-XL-Epsilon:
<1girl/1boy/1other/...>, <character>, <series>, <artists>, <general tags>, <quality tags>, <year tags>, <meta tags>, <rating tags>
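When preparing captions in this format, a small helper that concatenates tag groups in the fixed order helps avoid ordering mistakes. The function, group names, and sample tags below are hypothetical illustrations, not part of the training scripts; the sample tags are placeholders and do not claim to be the exact vocabulary Kohaku-XL-Epsilon uses for quality, year, or rating tags.

```python
# Group order from the Kohaku-XL-Epsilon tag format; group names are hypothetical labels.
TAG_ORDER = [
    "subject",  # 1girl / 1boy / 1other / ...
    "character",
    "series",
    "artists",
    "general",
    "quality",
    "year",
    "meta",
    "rating",
]


def build_caption(groups):
    """Join tag groups in TAG_ORDER, skipping groups that are missing or empty."""
    unknown = set(groups) - set(TAG_ORDER)
    if unknown:
        raise ValueError(f"unknown tag groups: {sorted(unknown)}")
    tags = []
    for key in TAG_ORDER:
        tags.extend(groups.get(key, []))
    return ", ".join(tags)


# Placeholder tags; the order of keys in the input dict does not matter.
example = {
    "rating": ["general"],
    "subject": ["1girl"],
    "character": ["hatsune miku"],
    "series": ["vocaloid"],
    "general": ["solo", "long hair", "twintails"],
    "quality": ["masterpiece"],
    "year": ["newest"],
}
print(build_caption(example))
# 1girl, hatsune miku, vocaloid, solo, long hair, twintails, masterpiece, newest, general
```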

Additional Notes on FT_test_edm2_fixed

The original FT_test_edm2 implementation contained a bug that made training unstable: the learned loss weights failed to converge. This is corrected in FT_test_edm2_fixed, which also changes the MLP hyperparameters and adds debugging tools (TensorBoard logging and per-step saving of the loss weights) to verify that the weights behave as expected.

  • Before the fix (FT_test_edm2): [GIF]
  • After the fix (FT_test_edm2_fixed): [GIF]
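Because the fixed version saves loss_weights_MLP at every step, convergence can be checked mechanically rather than only by eye. A minimal sketch, assuming the saved weights have already been loaded as a list of per-step weight vectors (the card does not specify the file format); the function name, window, and tolerance are arbitrary choices for illustration. Since the weights keep fluctuating with per-batch noise, rel_tol should be set above that noise level.

```python
import math


def weights_converged(history, window=100, rel_tol=0.05):
    """Return True if every weight stayed within rel_tol (relative to its mean)
    over the last `window` steps.

    history: list of per-step weight vectors, e.g. saved loss_weights_MLP values.
    """
    if len(history) < window:
        return False  # not enough steps to judge
    recent = history[-window:]
    for column in zip(*recent):  # one column per weight index (timestep)
        if not all(math.isfinite(w) for w in column):
            return False  # NaN or inf means the weights blew up
        mean = sum(column) / len(column)
        if mean <= 0 or (max(column) - min(column)) / mean > rel_tol:
            return False
    return True


steady = [[2.0, 10.0]] * 200
drifting = [[2.0, 10.0 * 1.01 ** t] for t in range(200)]
print(weights_converged(steady))    # True
print(weights_converged(drifting))  # False
```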

Acknowledgments

I would like to extend my gratitude to everyone in the ArtiWaifu Discord community for their invaluable support in testing, implementation, and debugging. Special thanks to a and madman404 for providing the implementations that made this project possible.


This model card was written with the assistance of Claude 3.5 Sonnet.

Model tree for kawaimasa/test_FT_loss_weight