| | """Train streaming motion generation model (MotionStreamer) with llama blocks, Two-Forward strategy and QK-Norm, using the motion latents encoded by the Causal TAE (trained in the first stage).""" |
| |
|
| | import os |
| | import torch |
| | import numpy as np |
| | import random |
| | from torch.utils.tensorboard import SummaryWriter |
| | import json |
| | from accelerate import Accelerator |
| | from models.llama_model import LLaMAHF, LLaMAHFConfig |
| | import options.option_transformer as option_trans |
| | import utils.utils_model as utils_model |
| | import warnings |
| | from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR |
| | warnings.filterwarnings('ignore') |
| |
|
# Disable the HuggingFace tokenizers' internal parallelism (set before the text encoder is loaded).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

args = option_trans.get_args_parser()

# Seed every RNG the script draws from. torch drives model init, dropout and the
# torch.randperm calls in the Two-Forward loss. Python's `random` drives the caption
# dropout in the training loop, and was previously left unseeded. numpy is seeded as
# well in case the data pipeline uses it.
torch.manual_seed(args.seed)
random.seed(args.seed)
np.random.seed(args.seed)
| |
|
def unwrap(m):
    """Return ``m.module`` when *m* carries a ``module`` attribute (e.g. a DDP wrapper), else *m* itself."""
    return getattr(m, 'module', m)
| |
|
| | |
class WarmupCosineDecayScheduler:
    """Linear warm-up followed by cosine annealing, stepped with an explicit iteration index.

    While ``current_iter < warmup_iters`` the learning rate is scaled linearly from 0 toward
    the optimizer's base lr (``LambdaLR``). After that, ``CosineAnnealingLR`` anneals it
    toward ``min_lr`` over ``total_iters - warmup_iters`` steps, clamped to at least 1.

    Args:
        optimizer: optimizer whose ``param_groups`` learning rates are driven.
        warmup_iters: number of warm-up iterations.
        total_iters: total number of iterations, warm-up included.
        min_lr: floor of the cosine phase (``eta_min``).
    """

    def __init__(self, optimizer, warmup_iters, total_iters, min_lr=0):
        self.optimizer = optimizer
        self.warmup_iters = warmup_iters
        self.total_iters = total_iters
        self.min_lr = min_lr
        # LambdaLR's constructor already applies warmup_lambda(0) == 0, so training
        # starts at lr = 0.
        self.warmup_scheduler = LambdaLR(optimizer, lr_lambda=self.warmup_lambda)
        # CosineAnnealingLR updates the lr recursively from its current value, so it
        # continues from wherever the warm-up left it.
        self.cosine_scheduler = CosineAnnealingLR(optimizer, T_max=self._cosine_span(), eta_min=min_lr)

    def _cosine_span(self):
        # Length of the cosine phase. Clamp to >= 1: with T_max == 0,
        # CosineAnnealingLR.step() performs a modulo by zero.
        return max(1, self.total_iters - self.warmup_iters)

    def warmup_lambda(self, current_iter):
        """LR multiplier: ``current_iter / warmup_iters`` during warm-up, 1.0 afterwards."""
        if current_iter < self.warmup_iters:
            return float(current_iter) / float(max(1, self.warmup_iters))
        return 1.0

    def step(self, current_iter):
        """Advance the warm-up scheduler while ``current_iter < warmup_iters``, else the cosine one."""
        if current_iter < self.warmup_iters:
            self.warmup_scheduler.step()
        else:
            self.cosine_scheduler.step()

    def state_dict(self):
        """Return the hyper-parameters plus the progress of both inner schedulers."""
        return {
            'warmup_iters': self.warmup_iters,
            'total_iters': self.total_iters,
            'min_lr': self.min_lr,
            'warmup_scheduler': self.warmup_scheduler.state_dict(),
            'cosine_scheduler': self.cosine_scheduler.state_dict(),
        }

    def load_state_dict(self, state_dict):
        """Restore a dict produced by :meth:`state_dict`.

        Older dicts that hold only the three hyper-parameters are still accepted. In that
        case the cosine phase length and floor are re-synchronised, but the schedule
        position is not restored. The optimizer's current lr is not touched here, so
        restore the optimizer's own state alongside this one.
        """
        self.warmup_iters = state_dict['warmup_iters']
        self.total_iters = state_dict['total_iters']
        self.min_lr = state_dict['min_lr']
        if 'warmup_scheduler' in state_dict:
            self.warmup_scheduler.load_state_dict(state_dict['warmup_scheduler'])
        if 'cosine_scheduler' in state_dict:
            self.cosine_scheduler.load_state_dict(state_dict['cosine_scheduler'])
        else:
            self.cosine_scheduler.T_max = self._cosine_span()
            self.cosine_scheduler.eta_min = self.min_lr
| |
|
# All run outputs (log, TensorBoard events, checkpoints) go to <out_dir>/<exp_name>.
args.out_dir = os.path.join(args.out_dir, f'{args.exp_name}')
os.makedirs(args.out_dir, exist_ok=True)

accelerator = Accelerator()
comp_device = accelerator.device

logger = utils_model.get_logger(args.out_dir)
# Scalars are only ever written by the main process (the training loop guards every
# writer call with accelerator.is_main_process). So only that rank gets a TensorBoard
# writer and dumps the configuration; other ranks would otherwise each create their
# own event file and repeat the config dump.
writer = SummaryWriter(args.out_dir) if accelerator.is_main_process else None
if accelerator.is_main_process:
    logger.info(json.dumps(vars(args), indent=4, sort_keys=True))
| |
|
| | |
# Training data: motion-latent sequences read from args.latent_dir (latents produced by the
# stage-1 Causal TAE, per the module docstring).
from humanml3d_272 import dataset_TM_train_motionstreamer
train_loader = dataset_TM_train_motionstreamer.DATALoader(
    args.dataname,
    args.batch_size,
    unit_length=2 ** args.down_t,
    latent_dir=args.latent_dir,
)

# Frozen sentence-T5 text encoder: fp16 weights, eval mode, gradients disabled on every parameter.
from sentence_transformers import SentenceTransformer
t5_model = SentenceTransformer("sentence-t5-xl", device=comp_device)
t5_model.half()
t5_model.eval()
t5_model.requires_grad_(False)
| |
|
| | |
# Transformer backbone preset plus diffusion head, sized by the CLI options.
config = LLaMAHFConfig.from_name('Normal_size')

trans_encoder = LLaMAHF(
    config=config,
    num_diffusion_head_layers=args.num_diffusion_head_layers,
    input_token_dim=args.latent_dim,
    device=comp_device,
)

# Optional warm start from a checkpoint laid out as {'trans': state_dict}. A leading
# 'module' component, left by DDP-style wrappers, is stripped from each key.
if args.resume_trans is not None:
    print(f'loading transformer checkpoint from {args.resume_trans}')
    ckpt = torch.load(args.resume_trans, map_location='cpu')
    new_ckpt_trans = {
        (key.partition('.')[2] if key.partition('.')[0] == 'module' else key): value
        for key, value in ckpt['trans'].items()
    }
    trans_encoder.load_state_dict(new_ckpt_trans, strict=True)

trans_encoder.train()
trans_encoder.to(comp_device)
| |
|
| | |
# Optimizer from the project helper. LR schedule: linear warm-up over the first
# args.total_iter // 10 iterations, then cosine annealing up to args.total_iter.
optimizer = utils_model.initial_optim(args.decay_option, args.lr, args.weight_decay, trans_encoder, args.optimizer)
# NOTE(review): the scheduler holds the optimizer object from *before* accelerator.prepare().
# This relies on Accelerate's optimizer wrapper sharing param_groups with it; confirm.
scheduler = WarmupCosineDecayScheduler(optimizer, args.total_iter//10, args.total_iter)

# Let Accelerate wrap the models, optimizer and dataloader for the current device / process setup.
t5_model, trans_encoder, optimizer, train_loader = accelerator.prepare(
    t5_model, trans_encoder, optimizer, train_loader
)
# Unwrapped model. The loss function calls its set_prompt() and diff_loss directly.
base = accelerator.unwrap_model(trans_encoder)
# Iterator drawn from with next() in the training loop (presumably endless, per `cycle`; confirm).
train_loader_iter = dataset_TM_train_motionstreamer.cycle(train_loader)

# Diffusion-head window size used by the Two-Forward loss (hard-coded here).
args.dit_window = 2
| |
|
def lengths_to_mask(lengths, max_len):
    """Return a ``(len(lengths), max_len)`` bool mask that is True where ``position < lengths[row]``."""
    positions = torch.arange(max_len, device=lengths.device)
    return positions.unsqueeze(0) < lengths.unsqueeze(1)
| |
|
| | import math |
def cosine_decay(step, total_steps, start_value=1.0, end_value=0.0):
    """Half-cosine interpolation, returned as a 0-d float32 tensor.

    With ``c = 0.5 * (1 + cos(pi * step / total_steps))`` the result is
    ``start_value + (end_value - start_value) * c``. Despite the name, this equals
    *end_value* at ``step == 0`` and *start_value* at ``step == total_steps`` (up to
    float32 rounding). With the defaults it therefore rises from 0 to 1.
    """
    t = torch.tensor(step, dtype=torch.float32)
    t_total = torch.tensor(total_steps, dtype=torch.float32)
    weight = 0.5 * (1 + torch.cos(torch.pi * t / t_total))
    span = end_value - start_value
    return start_value + span * weight
| |
|
def replace_with_pred(latents, pred_xstart, step, total_steps):
    """Return a copy of *latents* in which a random set of time steps is taken from *pred_xstart*.

    The same time-step indices are replaced for every batch row. Their count is
    ``int(L * cosine_decay(step, total_steps))``, which grows from 0 at step 0 to L at
    ``total_steps``. Expects *latents* to unpack as ``(B, L, D)`` and *pred_xstart* to be
    indexable by the same ``(B, L)`` boolean mask.
    """
    device = latents.device
    ratio = cosine_decay(step, total_steps).to(device)
    batch, length, _ = latents.shape
    n_swap = int(length * ratio)
    chosen = torch.randperm(length, device=device)[:n_swap]
    swap = torch.zeros(batch, length, dtype=torch.bool, device=device)
    swap[:, chosen] = True
    out = latents.clone()
    out[swap] = pred_xstart[swap]
    return out
| |
|
| | |
def _next_token_loss_mask(m_lens, A_token_length, B, L_eff, device):
    """Return a (B, L_eff) bool mask of next-token positions that count toward the loss.

    Position j is kept iff ``max(0, A_token_length - 1) <= j < max(0, m_lens - 1)``. This
    drops padding and each sample's leading ``A_token_length - 1`` positions, matching the
    former per-sample loop without a device sync per row.
    """
    positions = torch.arange(L_eff, device=device).unsqueeze(0).expand(B, L_eff)
    eff_lens = (m_lens - 1).clamp(min=0)
    a_excl = (A_token_length - 1).clamp(min=0).reshape(-1, 1)
    return (positions < eff_lens.unsqueeze(1)) & (positions >= a_excl)


def _restore_prefix(pred_xstart, target, A_token_length, tail_start):
    """Put ground truth back into window slots lying inside each sample's first ``A_token_length - 1`` positions.

    Window slot i corresponds to next-token position ``tail_start + i``. It keeps the
    ground truth iff that position is ``< max(0, A_token_length - 1)``.
    """
    W = pred_xstart.shape[1]
    a_excl = (A_token_length - 1).clamp(min=0).reshape(-1, 1)
    slot_pos = tail_start + torch.arange(W, device=pred_xstart.device).unsqueeze(0)
    keep_gt = (slot_pos < a_excl).unsqueeze(-1)
    # Cast target to the prediction dtype, as the former in-place slice assignment did.
    return torch.where(keep_gt, target.to(pred_xstart.dtype), pred_xstart)


def _num_replaced(step, total_steps, W, device):
    """Number of window slots fed back with predictions: ``int(W * 0.5 * (1 + cos(pi * step / total_steps)))``.

    The cosine is evaluated in float32 on *device*, exactly as before, so the integer
    cut-off is unchanged at boundary steps.
    """
    step_t = torch.tensor(step, dtype=torch.float32, device=device)
    total_t = torch.tensor(total_steps, dtype=torch.float32, device=device)
    decay_ratio = 0.5 * (1.0 + torch.cos(torch.pi * step_t / total_t)).item()
    return int(W * decay_ratio)


def forward_loss_withmask_2_forward_streaming(latents, trans, m_lens, feat_text,
                                              step, total_steps, A_token_length, K=None):
    """Two-Forward training loss with a *windowed* diffusion head.

    - The autoregressive transformer sees the full sequence.
    - The diffusion head is only supervised on the last ``W = min(K, L - 1)`` next-token
      positions of the batch tensor.

    The first pass runs without gradients and predicts the W window latents. A random
    subset of those predictions replaces the ground-truth latents they stand for in the
    input. Window slots inside a sample's first ``A_token_length - 1`` positions always
    keep the ground truth. The second pass computes the masked diffusion loss on that
    modified input.

    Args:
        latents: input latent sequence; must unpack as ``(B, L, D)``.
        trans: the (possibly Accelerate-wrapped) transformer, called as ``trans(x, feature=None)``.
        m_lens: per-sample lengths (B,); next-token positions ``>= m_lens - 1`` are masked out.
        feat_text: text embeddings handed to ``base.set_prompt``.
        step, total_steps: training progress; sets how many window slots are replaced.
        A_token_length: per-sample counts (B,); the first ``A_token_length - 1`` positions
            are excluded from the loss and never replaced.
        K: window size; when falsy, falls back to ``args.dit_window`` (default 2).

    Returns:
        The loss returned by ``base.diff_loss`` for the second pass.

    Raises:
        ValueError: if ``L - 1 <= 0``.
    """
    K = K or getattr(args, "dit_window", 2)

    latents = latents.to(comp_device)
    feat_text = feat_text.to(comp_device)
    A_token_length = A_token_length.to(comp_device)

    B, L, D = latents.shape
    L_eff = L - 1
    if L_eff <= 0:
        raise ValueError("Sequence too short for next-token training.")

    base.set_prompt(feat_text)

    # Next-token targets: position j predicts latents[:, j + 1].
    target_full = latents[:, 1:, :]
    full_mask = _next_token_loss_mask(m_lens, A_token_length, B, L_eff, latents.device)

    # Window = the last W next-token positions of the batch tensor.
    # NOTE(review): the window is cut from the end of the (padded) tensor, not from each
    # sample's own end. A sample with m_lens - 1 <= tail_start gets an all-False mask here.
    # Confirm the loader returns equal-length sequences, or switch to per-sample windows.
    W = min(K, L_eff)
    tail_start = L_eff - W
    target = target_full[:, tail_start:, :]
    mask_flat = full_mask[:, tail_start:].reshape(B * W).float()

    base.diff_loss.set_sequence_layout(B, W)

    # Pass 1: predictions only. Nothing computed here reaches the returned loss, so the
    # transformer forward also runs under no_grad. Otherwise its activation graph would
    # stay alive through the whole second pass.
    with torch.no_grad():
        conditions = trans(latents, feature=None)
        z = conditions[:, 1:-1, :][:, tail_start:, :]
        _, pred_xstart_full = base.diff_loss(
            target=target.reshape(B * W, D),
            z=z.reshape(B * W, -1),
            mask=None,
        )
    pred_xstart = pred_xstart_full.reshape(B, W, D)
    pred_xstart = _restore_prefix(pred_xstart, target, A_token_length, tail_start)

    # NOTE(review): the replaced fraction is 1 at step 0 and falls to 0 at total_steps.
    # That is the opposite direction to cosine_decay()/replace_with_pred() in this file,
    # which rise from 0 to 1. Confirm which schedule is intended.
    k = _num_replaced(step, total_steps, W, latents.device)
    updated_latents = latents.clone()
    if k > 0:
        replace_idx = torch.randperm(W, device=latents.device)[:k]
        # Window slot i holds the prediction of latents[:, 1 + tail_start + i].
        updated_latents[:, 1 + tail_start + replace_idx, :] = pred_xstart[:, replace_idx, :]

    # Pass 2: with gradients, on the partially self-predicted input.
    updated_conditions = trans(updated_latents, feature=None)
    updated_z = updated_conditions[:, 1:-1, :][:, tail_start:, :]
    updated_loss, _ = base.diff_loss(
        target=target.reshape(B * W, D),
        z=updated_z.reshape(B * W, -1),
        mask=mask_flat,
    )
    return updated_loss
| |
|
| | |
nb_iter, avg_loss_cls = 0, 0.0

# Logging / checkpoint cadence. These hard-coded values override whatever the option
# parser put on `args`; they are set once instead of on every iteration.
args.print_iter = 100
args.save_iter = 10000

# Run exactly args.total_iter optimisation steps (nb_iter = 0 .. total_iter - 1), matching
# the schedule length given to WarmupCosineDecayScheduler. A `<=` bound would add one
# wasted step at the schedule's final (minimum) learning rate.
while nb_iter < args.total_iter:
    caption, m_tokens, m_tokens_len, A_token_length = next(train_loader_iter)
    caption = list(caption)
    m_tokens = m_tokens.to(comp_device)
    m_tokens_len = m_tokens_len.to(comp_device)
    A_token_length = A_token_length.to(comp_device)

    # Caption dropout: blank a random 10% of the batch's captions (presumably so the
    # model also learns an empty-prompt mode; confirm).
    bs = len(caption)
    num_masked = int(bs * 0.1)
    if num_masked > 0:
        for idx in random.sample(range(bs), num_masked):
            caption[idx] = ''

    # Frozen sentence-T5 embeddings of the (possibly blanked) captions.
    feat_text = torch.from_numpy(t5_model.encode(caption)).float().to(comp_device)

    # Drop the last latent step of every sequence before next-token training.
    input_latent = m_tokens[:, :-1, :]

    loss_cls = forward_loss_withmask_2_forward_streaming(
        latents=input_latent,
        trans=trans_encoder,
        m_lens=m_tokens_len,
        feat_text=feat_text,
        step=nb_iter,
        total_steps=args.total_iter,
        A_token_length=A_token_length,
        K=args.dit_window,
    )

    optimizer.zero_grad()
    accelerator.backward(loss_cls)
    optimizer.step()
    scheduler.step(nb_iter)

    avg_loss_cls += loss_cls.item()
    nb_iter += 1

    if nb_iter % args.print_iter == 0:
        if accelerator.is_main_process:
            avg_loss_cls = avg_loss_cls / args.print_iter
            writer.add_scalar('./Loss/train', avg_loss_cls, nb_iter)
            writer.add_scalar('./LR/train', optimizer.param_groups[0]['lr'], nb_iter)
            logger.info(f"Train. Iter {nb_iter} : Loss. {avg_loss_cls:.5f}")
        # Reset on every rank, not only the main one, so no running sum grows forever.
        avg_loss_cls = 0.0

    # Checkpoint periodically, and always after the final step.
    if nb_iter % args.save_iter == 0 or nb_iter == args.total_iter:
        if accelerator.is_main_process:
            torch.save({'trans': unwrap(trans_encoder).state_dict()},
                       os.path.join(args.out_dir, f'latest.pth'))

    # Synchronise all processes at the end of every iteration.
    accelerator.wait_for_everyone()
| |
|