# V3D/configs/example_training/autoencoder/kl-f4/imagenet-kl_f8_8chn.yaml
# Hosting-page metadata: commit "init" (cfb7702) by heheyas; 2.52 kB; virus scan: clean.
# Autoencoder model definition.
# Each `target` is a dotted path and the sibling `params` mapping holds its
# arguments (presumably an instantiate-from-config convention; confirm against
# the loader in the training entry point).
model:
  base_learning_rate: 4.5e-6
  target: sgm.models.autoencoder.AutoencodingEngine
  params:
    input_key: jpg
    monitor: val/loss/rec
    disc_start_iter: 0

    encoder_config:
      target: sgm.modules.diffusionmodules.model.Encoder
      params:
        attn_type: vanilla-xformers
        double_z: true
        z_channels: 8
        resolution: 256
        in_channels: 3
        out_ch: 3
        ch: 128
        ch_mult: [1, 2, 4, 4]
        num_res_blocks: 2
        attn_resolutions: []
        dropout: 0.0

    decoder_config:
      target: sgm.modules.diffusionmodules.model.Decoder
      # `${...}` interpolation: the decoder reuses the encoder's `params`
      # mapping above, so the two cannot drift apart.
      params: ${model.params.encoder_config.params}

    regularizer_config:
      target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer

    loss_config:
      target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator
      params:
        perceptual_weight: 0.25
        disc_start: 20001
        disc_weight: 0.5
        learn_logvar: true

        regularization_weights:
          kl_loss: 1.0
# Training data module: where samples come from, how they are decoded and
# post-processed, and the loader settings.
data:
  target: sgm.data.dataset.StableDataModuleFromConfig
  params:
    train:
      datapipeline:
        urls:
          # Placeholder; replace with the real dataset location(s).
          - DATA-PATH
        pipeline_config:
          shardshuffle: 10000
          sample_shuffle: 10000

        decoders:
          - pil

        postprocessors:
          - target: sdata.mappers.TorchVisionImageTransforms
            params:
              key: jpg
              transforms:
                - target: torchvision.transforms.Resize
                  params:
                    size: 256
                    # NOTE(review): 3 is presumably the PIL BICUBIC
                    # resampling code; confirm against the installed
                    # torchvision/PIL.
                    interpolation: 3
                - target: torchvision.transforms.ToTensor
          - target: sdata.mappers.Rescaler
          - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
            params:
              h_key: height
              w_key: width

      loader:
        batch_size: 8
        num_workers: 4
# Training-loop settings: distributed strategy, checkpoint cadence, logging
# callbacks, and the arguments listed under `trainer`.
lightning:
  strategy:
    target: pytorch_lightning.strategies.DDPStrategy
    params:
      find_unused_parameters: true

  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    metrics_over_trainsteps_checkpoint:
      params:
        every_n_train_steps: 50000

    image_logger:
      target: main.ImageLogger
      params:
        enable_autocast: false
        batch_frequency: 1000
        max_images: 8
        increase_log_steps: true

  trainer:
    # Quoted on purpose: the trailing comma keeps this the string "0,", not
    # the integer 0. Presumably read as "device index 0" rather than a device
    # count; confirm against the Trainer `devices` documentation.
    devices: "0,"
    limit_val_batches: 50
    benchmark: true
    accumulate_grad_batches: 1
    val_check_interval: 10000