File size: 4,422 Bytes
3b3fc8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
# Model section. `target` is a dotted class path; `params` presumably holds the
# keyword arguments handed to that class — confirm against the config loader.
model:
  pretrained_checkpoint: checkpoints/dynamicrafter_512_v1/model.ckpt
  base_learning_rate: 1.0e-05
  scale_lr: false
  target: lvdm.models.ddpm3d.LatentVisualDiffusion
  params:
    # Diffusion schedule / parameterization values.
    rescale_betas_zero_snr: true
    parameterization: 'v'
    linear_start: 0.00085
    linear_end: 0.012
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000

    # Keys and flags for the inputs and conditioning.
    first_stage_key: video
    cond_stage_key: caption
    cond_stage_trainable: false
    image_proj_model_trainable: true
    conditioning_key: hybrid

    # NOTE(review): [40, 64] equals the data resolution [320, 512] divided by 8;
    # presumably these must stay in sync — confirm before changing either.
    image_size: [40, 64]
    channels: 4
    scale_by_std: false
    scale_factor: 0.18215
    use_ema: false

    uncond_prob: 0.05
    uncond_type: 'empty_seq'
    rand_cond_frame: true
    use_dynamic_rescale: true
    base_scale: 0.7
    fps_condition_type: 'fps'
    perframe_ae: true

    # Settings for the class named in `target`; `params` presumably maps onto
    # its constructor arguments — confirm against the config loader.
    unet_config:
      target: lvdm.modules.networks.openaimodel3d.UNetModel
      params:
        in_channels: 8
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2, 1]
        num_res_blocks: 2
        channel_mult: [1, 2, 4, 4]
        dropout: 0.1
        num_head_channels: 64
        transformer_depth: 1
        context_dim: 1024
        use_linear: true
        use_checkpoint: true

        # Temporal options.
        temporal_conv: true
        temporal_attention: true
        temporal_selfatt_only: true
        use_relative_position: false
        use_causal_attention: false
        # Same value as `video_length: 16` used elsewhere in this file.
        temporal_length: 16

        addition_attention: true
        image_cross_attention: true
        default_fs: 10
        fs_condition: true

    # First-stage autoencoder settings: `ddconfig` and `lossconfig` are nested
    # under `params` of the class named in `target`.
    first_stage_config:
      target: lvdm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          # Explicit empty list (a bare `attn_resolutions:` would be null).
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    # Conditioning encoder for the `cond_stage_key` input; both encoders in this
    # group are configured with `freeze: true`.
    cond_stage_config:
      target: lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder
      params:
        freeze: true
        layer: 'penultimate'

    # Image conditioning encoder.
    img_cond_stage_config:
      target: lvdm.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
      params:
        freeze: true

    # Resampler settings for the image projection stage.
    # NOTE(review): `output_dim: 1024` equals `context_dim: 1024` in unet_config
    # and `video_length: 16` equals `temporal_length: 16` there — presumably
    # these must agree; confirm before changing.
    image_proj_stage_config:
      target: lvdm.modules.encoders.resampler.Resampler
      params:
        dim: 1024
        depth: 4
        dim_head: 64
        heads: 12
        num_queries: 16
        embedding_dim: 1280
        output_dim: 1024
        ff_mult: 4
        video_length: 16

# Data section: the data module class plus the training dataset definition.
data:
  target: utils_data.DataModuleFromConfig
  params:
    batch_size: 2
    num_workers: 12
    wrap: false
    train:
      target: lvdm.data.webvid.WebVid
      params:
        # Placeholders: replace both values with real locations before use.
        data_dir: '<WebVid10M DATA>'
        meta_path: '<.csv FILE>'
        video_length: 16
        frame_stride: 6
        load_raw_resolution: true
        resolution: [320, 512]
        spatial_transform: resize_center_crop
        # If true, fs is sampled uniformly with max_fs = frame_stride (above).
        random_fs: true

# Trainer-side settings: precision plus the options listed under `trainer`.
lightning:
  precision: 16
  # Disabled option, kept for reference:
  # strategy: deepspeed_stage_2
  trainer:
    benchmark: true
    # With `batch_size: 2` in the data section this gives 4 samples per
    # optimizer step (presumably per device — confirm).
    accumulate_grad_batches: 2
    max_steps: 100000

    # Logging.
    log_every_n_steps: 50

    # Validation.
    val_check_interval: 0.5

    # Gradient clipping.
    gradient_clip_algorithm: 'norm'
    gradient_clip_val: 0.5
  # Callbacks: two ModelCheckpoint entries with different step intervals and an
  # image logger.
  callbacks:
    model_checkpoint:
      target: pytorch_lightning.callbacks.ModelCheckpoint
      params:
        # Trailing "1000" in the original comment is presumably an earlier value.
        every_n_train_steps: 9000  # 1000
        filename: '{epoch}-{step}'
        save_weights_only: true

    metrics_over_trainsteps_checkpoint:
      target: pytorch_lightning.callbacks.ModelCheckpoint
      params:
        filename: '{epoch}-{step}'
        save_weights_only: true
        # Original note read "#20000 # 3s/step*2w=" and was left unfinished;
        # presumably a time estimate (3 s/step x 20000 steps) — confirm.
        every_n_train_steps: 10000  # 20000

    batch_logger:
      target: callbacks.ImageLogger
      params:
        batch_frequency: 500
        to_local: false
        max_images: 8
        # Sampling options nested under `log_images_kwargs`.
        log_images_kwargs:
          ddim_steps: 50
          unconditional_guidance_scale: 7.5
          timestep_spacing: uniform_trailing
          guidance_rescale: 0.7