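# Hugging Face Hub page residue (not part of the config), kept as comments:
#   uploader: multimodalart
#   commit: 0366b8b (verified) - "Upload folder using huggingface_hub"
#   file size: 4.42 kB; scan status: no virus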
# Model section.
# NOTE(review): the nesting below was reconstructed from a flattened copy in
# which every key sat at column 0 (making `target`/`params` duplicate keys).
# Each `target` looks like a dotted class path paired with a `params` mapping
# of its arguments - confirm against the config loader.
model:
  pretrained_checkpoint: checkpoints/dynamicrafter_512_v1/model.ckpt
  base_learning_rate: 1.0e-05
  scale_lr: false
  target: lvdm.models.ddpm3d.LatentVisualDiffusion
  params:
    # Noise schedule / parameterization.
    rescale_betas_zero_snr: true
    parameterization: 'v'
    linear_start: 0.00085
    linear_end: 0.012
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000

    # Batch keys and conditioning.
    first_stage_key: video
    cond_stage_key: caption
    cond_stage_trainable: false
    image_proj_model_trainable: true
    conditioning_key: hybrid

    # NOTE(review): [40, 64] equals the dataset resolution [320, 512] divided
    # by 8 - presumably the latent size; confirm.
    image_size: [40, 64]
    channels: 4
    scale_by_std: false
    scale_factor: 0.18215
    use_ema: false
    uncond_prob: 0.05
    uncond_type: 'empty_seq'
    rand_cond_frame: true
    use_dynamic_rescale: true
    base_scale: 0.7
    fps_condition_type: 'fps'
    perframe_ae: true

    unet_config:
      target: lvdm.modules.networks.openaimodel3d.UNetModel
      params:
        in_channels: 8
        out_channels: 4
        model_channels: 320
        attention_resolutions:
          - 4
          - 2
          - 1
        num_res_blocks: 2
        channel_mult:
          - 1
          - 2
          - 4
          - 4
        dropout: 0.1
        num_head_channels: 64
        transformer_depth: 1
        context_dim: 1024
        use_linear: true
        use_checkpoint: true
        temporal_conv: true
        temporal_attention: true
        temporal_selfatt_only: true
        use_relative_position: false
        use_causal_attention: false
        # Same value (16) as the `video_length` entries elsewhere in this file.
        temporal_length: 16
        addition_attention: true
        image_cross_attention: true
        default_fs: 10
        fs_condition: true

    first_stage_config:
      target: lvdm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
            - 1
            - 2
            - 4
            - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder
      params:
        freeze: true
        layer: 'penultimate'

    img_cond_stage_config:
      target: lvdm.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
      params:
        freeze: true

    image_proj_stage_config:
      target: lvdm.modules.encoders.resampler.Resampler
      params:
        dim: 1024
        depth: 4
        dim_head: 64
        heads: 12
        num_queries: 16
        embedding_dim: 1280
        output_dim: 1024
        ff_mult: 4
        video_length: 16

# Data section.
# NOTE(review): nesting reconstructed from a flattened copy (all keys were at
# column 0, so `target`/`params` collided) - confirm against the data loader.
data:
  target: utils_data.DataModuleFromConfig
  params:
    batch_size: 2
    num_workers: 12
    wrap: false
    train:
      target: lvdm.data.webvid.WebVid
      params:
        # Angle-bracket values are placeholders and must be replaced with
        # real paths before use.
        data_dir: <WebVid10M DATA>
        meta_path: <.csv FILE>
        video_length: 16
        frame_stride: 6
        load_raw_resolution: true
        resolution: [320, 512]
        spatial_transform: resize_center_crop
        # If true, we uniformly sample fs with max_fs=frame_stride (above).
        random_fs: true

# Trainer / callback section.
# NOTE(review): nesting reconstructed from a flattened copy (all keys were at
# column 0, so the three callbacks' `target`/`params` collided) - confirm
# against the training entry point.
lightning:
  precision: 16
  # strategy: deepspeed_stage_2
  trainer:
    benchmark: true
    accumulate_grad_batches: 2
    max_steps: 100000
    # logger
    log_every_n_steps: 50
    # val
    val_check_interval: 0.5
    gradient_clip_algorithm: 'norm'
    gradient_clip_val: 0.5
  callbacks:
    # Two ModelCheckpoint callbacks are configured with different step
    # intervals (9000 and 10000); both save weights only.
    model_checkpoint:
      target: pytorch_lightning.callbacks.ModelCheckpoint
      params:
        every_n_train_steps: 9000  # 1000
        filename: '{epoch}-{step}'
        save_weights_only: true
    metrics_over_trainsteps_checkpoint:
      target: pytorch_lightning.callbacks.ModelCheckpoint
      params:
        filename: '{epoch}-{step}'
        save_weights_only: true
        every_n_train_steps: 10000  # 20000 # 3s/step*2w=
    batch_logger:
      target: callbacks.ImageLogger
      params:
        batch_frequency: 500
        to_local: false
        max_images: 8
        log_images_kwargs:
          ddim_steps: 50
          unconditional_guidance_scale: 7.5
          timestep_spacing: uniform_trailing
          guidance_rescale: 0.7