KrutikaBM commited on
Commit
c766955
1 Parent(s): 4b7336f

Upload train_tuneavideo.py

Browse files
Files changed (1) hide show
  1. train_tuneavideo.py +367 -0
train_tuneavideo.py ADDED
@@ -0,0 +1,367 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import datetime
3
+ import logging
4
+ import inspect
5
+ import math
6
+ import os
7
+ from typing import Dict, Optional, Tuple
8
+ from omegaconf import OmegaConf
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ import torch.utils.checkpoint
13
+
14
+ import diffusers
15
+ import transformers
16
+ from accelerate import Accelerator
17
+ from accelerate.logging import get_logger
18
+ from accelerate.utils import set_seed
19
+ from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler
20
+ from diffusers.optimization import get_scheduler
21
+ from diffusers.utils import check_min_version
22
+ from diffusers.utils.import_utils import is_xformers_available
23
+ from tqdm.auto import tqdm
24
+ from transformers import CLIPTextModel, CLIPTokenizer
25
+
26
+ from tuneavideo.models.unet import UNet3DConditionModel
27
+ from tuneavideo.data.dataset import TuneAVideoDataset
28
+ from tuneavideo.pipelines.pipeline_tuneavideo import TuneAVideoPipeline
29
+ from tuneavideo.util import save_videos_grid, ddim_inversion
30
+ from einops import rearrange
31
+
32
+
33
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
34
+ check_min_version("0.10.0.dev0")
35
+
36
+ logger = get_logger(__name__, log_level="INFO")
37
+
38
+
39
+ def main(
40
+ pretrained_model_path: str,
41
+ output_dir: str,
42
+ train_data: Dict,
43
+ validation_data: Dict,
44
+ validation_steps: int = 100,
45
+ trainable_modules: Tuple[str] = (
46
+ "attn1.to_q",
47
+ "attn2.to_q",
48
+ "attn_temp",
49
+ ),
50
+ train_batch_size: int = 1,
51
+ max_train_steps: int = 500,
52
+ learning_rate: float = 3e-5,
53
+ scale_lr: bool = False,
54
+ lr_scheduler: str = "constant",
55
+ lr_warmup_steps: int = 0,
56
+ adam_beta1: float = 0.9,
57
+ adam_beta2: float = 0.999,
58
+ adam_weight_decay: float = 1e-2,
59
+ adam_epsilon: float = 1e-08,
60
+ max_grad_norm: float = 1.0,
61
+ gradient_accumulation_steps: int = 1,
62
+ gradient_checkpointing: bool = True,
63
+ checkpointing_steps: int = 500,
64
+ resume_from_checkpoint: Optional[str] = None,
65
+ mixed_precision: Optional[str] = "fp16",
66
+ use_8bit_adam: bool = False,
67
+ enable_xformers_memory_efficient_attention: bool = True,
68
+ seed: Optional[int] = None,
69
+ ):
70
+ *_, config = inspect.getargvalues(inspect.currentframe())
71
+
72
+ accelerator = Accelerator(
73
+ gradient_accumulation_steps=gradient_accumulation_steps,
74
+ mixed_precision=mixed_precision,
75
+ )
76
+
77
+ # Make one log on every process with the configuration for debugging.
78
+ logging.basicConfig(
79
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
80
+ datefmt="%m/%d/%Y %H:%M:%S",
81
+ level=logging.INFO,
82
+ )
83
+ logger.info(accelerator.state, main_process_only=False)
84
+ if accelerator.is_local_main_process:
85
+ transformers.utils.logging.set_verbosity_warning()
86
+ diffusers.utils.logging.set_verbosity_info()
87
+ else:
88
+ transformers.utils.logging.set_verbosity_error()
89
+ diffusers.utils.logging.set_verbosity_error()
90
+
91
+ # If passed along, set the training seed now.
92
+ if seed is not None:
93
+ set_seed(seed)
94
+
95
+ # Handle the output folder creation
96
+ if accelerator.is_main_process:
97
+ # now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
98
+ # output_dir = os.path.join(output_dir, now)
99
+ os.makedirs(output_dir, exist_ok=True)
100
+ os.makedirs(f"{output_dir}/samples", exist_ok=True)
101
+ os.makedirs(f"{output_dir}/inv_latents", exist_ok=True)
102
+ OmegaConf.save(config, os.path.join(output_dir, 'config.yaml'))
103
+
104
+ # Load scheduler, tokenizer and models.
105
+ noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
106
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
107
+ text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
108
+ vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
109
+ unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet")
110
+
111
+ # Freeze vae and text_encoder
112
+ vae.requires_grad_(False)
113
+ text_encoder.requires_grad_(False)
114
+
115
+ unet.requires_grad_(False)
116
+ for name, module in unet.named_modules():
117
+ if name.endswith(tuple(trainable_modules)):
118
+ for params in module.parameters():
119
+ params.requires_grad = True
120
+
121
+ if enable_xformers_memory_efficient_attention:
122
+ if is_xformers_available():
123
+ unet.enable_xformers_memory_efficient_attention()
124
+ else:
125
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
126
+
127
+ if gradient_checkpointing:
128
+ unet.enable_gradient_checkpointing()
129
+
130
+ if scale_lr:
131
+ learning_rate = (
132
+ learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes
133
+ )
134
+
135
+ # Initialize the optimizer
136
+ if use_8bit_adam:
137
+ try:
138
+ import bitsandbytes as bnb
139
+ except ImportError:
140
+ raise ImportError(
141
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
142
+ )
143
+
144
+ optimizer_cls = bnb.optim.AdamW8bit
145
+ else:
146
+ optimizer_cls = torch.optim.AdamW
147
+
148
+ optimizer = optimizer_cls(
149
+ unet.parameters(),
150
+ lr=learning_rate,
151
+ betas=(adam_beta1, adam_beta2),
152
+ weight_decay=adam_weight_decay,
153
+ eps=adam_epsilon,
154
+ )
155
+
156
+ # Get the training dataset
157
+ train_dataset = TuneAVideoDataset(**train_data)
158
+
159
+ # Preprocessing the dataset
160
+ train_dataset.prompt_ids = tokenizer(
161
+ train_dataset.prompt, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
162
+ ).input_ids[0]
163
+
164
+ # DataLoaders creation:
165
+ train_dataloader = torch.utils.data.DataLoader(
166
+ train_dataset, batch_size=train_batch_size
167
+ )
168
+
169
+ # Get the validation pipeline
170
+ validation_pipeline = TuneAVideoPipeline(
171
+ vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
172
+ scheduler=DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
173
+ )
174
+ validation_pipeline.enable_vae_slicing()
175
+ ddim_inv_scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder='scheduler')
176
+ ddim_inv_scheduler.set_timesteps(validation_data.num_inv_steps)
177
+
178
+ # Scheduler
179
+ lr_scheduler = get_scheduler(
180
+ lr_scheduler,
181
+ optimizer=optimizer,
182
+ num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
183
+ num_training_steps=max_train_steps * gradient_accumulation_steps,
184
+ )
185
+
186
+ # Prepare everything with our `accelerator`.
187
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
188
+ unet, optimizer, train_dataloader, lr_scheduler
189
+ )
190
+
191
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
192
+ # as these models are only used for inference, keeping weights in full precision is not required.
193
+ weight_dtype = torch.float32
194
+ if accelerator.mixed_precision == "fp16":
195
+ weight_dtype = torch.float16
196
+ elif accelerator.mixed_precision == "bf16":
197
+ weight_dtype = torch.bfloat16
198
+
199
+ # Move text_encode and vae to gpu and cast to weight_dtype
200
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
201
+ vae.to(accelerator.device, dtype=weight_dtype)
202
+
203
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
204
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
205
+ # Afterwards we recalculate our number of training epochs
206
+ num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
207
+
208
+ # We need to initialize the trackers we use, and also store our configuration.
209
+ # The trackers initializes automatically on the main process.
210
+ if accelerator.is_main_process:
211
+ accelerator.init_trackers("text2video-fine-tune")
212
+
213
+ # Train!
214
+ total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps
215
+
216
+ logger.info("***** Running training *****")
217
+ logger.info(f" Num examples = {len(train_dataset)}")
218
+ logger.info(f" Num Epochs = {num_train_epochs}")
219
+ logger.info(f" Instantaneous batch size per device = {train_batch_size}")
220
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
221
+ logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
222
+ logger.info(f" Total optimization steps = {max_train_steps}")
223
+ global_step = 0
224
+ first_epoch = 0
225
+
226
+ # Potentially load in the weights and states from a previous save
227
+ if resume_from_checkpoint:
228
+ if resume_from_checkpoint != "latest":
229
+ path = os.path.basename(resume_from_checkpoint)
230
+ else:
231
+ # Get the most recent checkpoint
232
+ dirs = os.listdir(output_dir)
233
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
234
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
235
+ path = dirs[-1]
236
+ accelerator.print(f"Resuming from checkpoint {path}")
237
+ accelerator.load_state(os.path.join(output_dir, path))
238
+ global_step = int(path.split("-")[1])
239
+
240
+ first_epoch = global_step // num_update_steps_per_epoch
241
+ resume_step = global_step % num_update_steps_per_epoch
242
+
243
+ # Only show the progress bar once on each machine.
244
+ progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)
245
+ progress_bar.set_description("Steps")
246
+
247
+ for epoch in range(first_epoch, num_train_epochs):
248
+ unet.train()
249
+ train_loss = 0.0
250
+ for step, batch in enumerate(train_dataloader):
251
+ # Skip steps until we reach the resumed step
252
+ if resume_from_checkpoint and epoch == first_epoch and step < resume_step:
253
+ if step % gradient_accumulation_steps == 0:
254
+ progress_bar.update(1)
255
+ continue
256
+
257
+ with accelerator.accumulate(unet):
258
+ # Convert videos to latent space
259
+ pixel_values = batch["pixel_values"].to(weight_dtype)
260
+ video_length = pixel_values.shape[1]
261
+ pixel_values = rearrange(pixel_values, "b f c h w -> (b f) c h w")
262
+ latents = vae.encode(pixel_values).latent_dist.sample()
263
+ latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)
264
+ latents = latents * 0.18215
265
+
266
+ # Sample noise that we'll add to the latents
267
+ noise = torch.randn_like(latents)
268
+ bsz = latents.shape[0]
269
+ # Sample a random timestep for each video
270
+ timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
271
+ timesteps = timesteps.long()
272
+
273
+ # Add noise to the latents according to the noise magnitude at each timestep
274
+ # (this is the forward diffusion process)
275
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
276
+
277
+ # Get the text embedding for conditioning
278
+ encoder_hidden_states = text_encoder(batch["prompt_ids"])[0]
279
+
280
+ # Get the target for loss depending on the prediction type
281
+ if noise_scheduler.prediction_type == "epsilon":
282
+ target = noise
283
+ elif noise_scheduler.prediction_type == "v_prediction":
284
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
285
+ else:
286
+ raise ValueError(f"Unknown prediction type {noise_scheduler.prediction_type}")
287
+
288
+ # Predict the noise residual and compute loss
289
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
290
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
291
+
292
+ # Gather the losses across all processes for logging (if we use distributed training).
293
+ avg_loss = accelerator.gather(loss.repeat(train_batch_size)).mean()
294
+ train_loss += avg_loss.item() / gradient_accumulation_steps
295
+
296
+ # Backpropagate
297
+ accelerator.backward(loss)
298
+ if accelerator.sync_gradients:
299
+ accelerator.clip_grad_norm_(unet.parameters(), max_grad_norm)
300
+ optimizer.step()
301
+ lr_scheduler.step()
302
+ optimizer.zero_grad()
303
+
304
+ # Checks if the accelerator has performed an optimization step behind the scenes
305
+ if accelerator.sync_gradients:
306
+ progress_bar.update(1)
307
+ global_step += 1
308
+ accelerator.log({"train_loss": train_loss}, step=global_step)
309
+ train_loss = 0.0
310
+
311
+ if global_step % checkpointing_steps == 0:
312
+ if accelerator.is_main_process:
313
+ save_path = os.path.join(output_dir, f"checkpoint-{global_step}")
314
+ accelerator.save_state(save_path)
315
+ logger.info(f"Saved state to {save_path}")
316
+
317
+ if global_step % validation_steps == 0:
318
+ if accelerator.is_main_process:
319
+ samples = []
320
+ generator = torch.Generator(device=latents.device)
321
+ generator.manual_seed(seed)
322
+
323
+ ddim_inv_latent = None
324
+ if validation_data.use_inv_latent:
325
+ inv_latents_path = os.path.join(output_dir, f"inv_latents/ddim_latent-{global_step}.pt")
326
+ ddim_inv_latent = ddim_inversion(
327
+ validation_pipeline, ddim_inv_scheduler, video_latent=latents,
328
+ num_inv_steps=validation_data.num_inv_steps, prompt="")[-1].to(weight_dtype)
329
+ torch.save(ddim_inv_latent, inv_latents_path)
330
+
331
+ for idx, prompt in enumerate(validation_data.prompts):
332
+ sample = validation_pipeline(prompt, generator=generator, latents=ddim_inv_latent,
333
+ **validation_data).videos
334
+ save_videos_grid(sample, f"{output_dir}/samples/sample-{global_step}/{prompt}.gif")
335
+ samples.append(sample)
336
+ samples = torch.concat(samples)
337
+ save_path = f"{output_dir}/samples/sample-{global_step}.gif"
338
+ save_videos_grid(samples, save_path)
339
+ logger.info(f"Saved samples to {save_path}")
340
+
341
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
342
+ progress_bar.set_postfix(**logs)
343
+
344
+ if global_step >= max_train_steps:
345
+ break
346
+
347
+ # Create the pipeline using the trained modules and save it.
348
+ accelerator.wait_for_everyone()
349
+ if accelerator.is_main_process:
350
+ unet = accelerator.unwrap_model(unet)
351
+ pipeline = TuneAVideoPipeline.from_pretrained(
352
+ pretrained_model_path,
353
+ text_encoder=text_encoder,
354
+ vae=vae,
355
+ unet=unet,
356
+ )
357
+ pipeline.save_pretrained(output_dir)
358
+
359
+ accelerator.end_training()
360
+
361
+
362
+ if __name__ == "__main__":
363
+ parser = argparse.ArgumentParser()
364
+ parser.add_argument("--config", type=str, default="./configs/tuneavideo.yaml")
365
+ args = parser.parse_args()
366
+
367
+ main(**OmegaConf.load(args.config))