| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch.utils.data import DataLoader |
| | from transformers import LongT5ForConditionalGeneration, T5ForConditionalGeneration, T5Tokenizer |
| | from accelerate import Accelerator |
| | from accelerate.utils import set_seed |
| | from concurrent.futures import ThreadPoolExecutor |
| | import numpy as np |
| | from pathlib import Path |
| | import yaml |
| | from tqdm import tqdm |
| | from typing import Dict, List, Tuple, Optional |
| | import argparse |
| | import os |
| | import re |
| | import warnings |
| | from collections import defaultdict |
| | import time |
| | from datetime import datetime |
| | import sys |
| | import matplotlib |
| | matplotlib.use("Agg") |
| | import matplotlib.pyplot as plt |
| |
|
| | |
| | SCRIPT_DIR = Path(__file__).resolve().parent |
| | WAVEGEN_ROOT = SCRIPT_DIR.parent |
| | if str(WAVEGEN_ROOT) not in sys.path: |
| | sys.path.insert(0, str(WAVEGEN_ROOT)) |
| |
|
| | |
| | warnings.filterwarnings("ignore", message="Passing a tuple of `past_key_values` is deprecated") |
| |
|
| | from data.movi_dataset import create_dataloader |
| | from utils.save_generation_results import save_generation_results |
| |
|
| |
|
class Text2WaveModel(nn.Module):
    """Text to Superquadric Wave Parameters Model.

    Wraps a (Long)T5 encoder-decoder. The encoder consumes a natural-language
    prompt; the decoder is driven one step per output frame with a custom
    input embedding (relative-time + history + optional noise) instead of
    token embeddings. Each decoder step is projected into:
      - per-object superquadric parameters (``object_param_dim`` values each),
      - per-frame world parameters (8 values),
      - per-object physics parameters (3 values each).
    Decoding is autoregressive: each predicted frame is appended to a rolling
    history buffer that conditions subsequent frames.
    """

    def __init__(
        self,
        model_name: str = "google/long-t5-tglobal-base",
        max_objects: int = 10,
        num_frames: int = 24,
        max_history_frames: int = 3,
        random_history_sampling: bool = True,
        decoder_noise_std: float = 0.0,
    ):
        """
        Args:
            model_name: HuggingFace checkpoint name; if it contains
                "long-t5", LongT5ForConditionalGeneration is loaded,
                otherwise plain T5ForConditionalGeneration.
            max_objects: Maximum number of objects predicted per frame.
            num_frames: Nominal clip length (stored; not read in this class).
            max_history_frames: Number of past frames that may condition a step.
            random_history_sampling: If True, each decoding step conditions on
                a uniformly random number of history frames (0..available),
                which makes the model robust to sparse history at inference.
            decoder_noise_std: Std of optional Gaussian noise added to the
                decoder input embedding; 0 disables noise sampling entirely.
        """
        super().__init__()

        self.max_objects = max_objects
        self.num_frames = num_frames
        self.max_history_frames = max_history_frames
        self.random_history_sampling = random_history_sampling
        self.decoder_noise_std = float(decoder_noise_std)

        # Per-object parameter width. The slicing used elsewhere in this file
        # suggests the layout [0]=presence, [1:3]=shape, [3:6]=scale,
        # [6:9]=position, [9:12]=rotation, [12:15]=velocity — inferred from
        # the trainer's indexing; confirm against the dataset definition.
        self.object_param_dim = 15

        # Backbone selection: LongT5 for long prompts, vanilla T5 otherwise.
        self.model_name = model_name
        self.is_longt5 = "long-t5" in model_name.lower()
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        if self.is_longt5:
            self.t5_model = LongT5ForConditionalGeneration.from_pretrained(model_name)
        else:
            self.t5_model = T5ForConditionalGeneration.from_pretrained(model_name)

        # Keep the embedding table in sync with the tokenizer vocab.
        # NOTE(review): T5 checkpoints commonly pad the vocab (e.g. 32128 vs
        # the tokenizer's 32100), so this can resize *down* and drop trailing
        # embedding rows — assumed intentional; verify.
        if self.tokenizer.vocab_size != self.t5_model.config.vocab_size:
            self.t5_model.resize_token_embeddings(self.tokenizer.vocab_size)

        # Width of the shared decoder/encoder hidden state.
        self.hidden_size = self.t5_model.config.d_model

        # Output heads mapping one decoder hidden state to frame predictions.
        # Per-object superquadric parameters, flattened.
        self.object_proj = nn.Linear(self.hidden_size, max_objects * self.object_param_dim)

        # Per-frame world parameters (8 values: camera/scale/time per the
        # loss-key naming in the trainer).
        self.world_proj = nn.Linear(self.hidden_size, 8)

        # Per-object physics parameters (3 values each: mass/friction/
        # restitution per the trainer's loss-key naming).
        self.physics_proj = nn.Linear(self.hidden_size, max_objects * 3)

        # Scalar relative time -> hidden-size embedding.
        self.time_embed = nn.Linear(1, self.hidden_size)

        # History conditioning: max_history_frames blocks of
        # (flattened objects + world) plus one flattened physics block,
        # projected into the decoder input space.
        history_feature_dim = max_history_frames * (max_objects * self.object_param_dim + 8) + max_objects * 3
        self.history_feature_dim = history_feature_dim
        self.history_proj = nn.Linear(history_feature_dim, self.hidden_size)

        # Only the newly added heads are (re)initialized; T5 keeps its
        # pretrained weights.
        self._init_weights()

    def _init_weights(self):
        """Initialize the custom projection heads with small-std normals.

        Pretrained T5 parameters are untouched; only the layers added in
        ``__init__`` are reset (std=0.02 weights, zero biases) for stability.
        """
        for module in [self.object_proj, self.world_proj, self.physics_proj]:
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            nn.init.zeros_(module.bias)

        nn.init.normal_(self.time_embed.weight, mean=0.0, std=0.02)
        nn.init.zeros_(self.time_embed.bias)

        nn.init.normal_(self.history_proj.weight, mean=0.0, std=0.02)
        nn.init.zeros_(self.history_proj.bias)

    def _initialize_history_state(
        self,
        history_frames: Optional[Dict[str, torch.Tensor]],
        batch_size: int,
        device: torch.device,
    ) -> Tuple[List[Dict[str, torch.Tensor]], torch.Tensor]:
        """Prepare history buffer and physics state for autoregressive decoding.

        Args:
            history_frames: Optional dict with keys 'objects'
                (batch, hist, objects, >=object_param_dim), 'world'
                (batch, hist, >=8) and 'physics' (batch, objects, 3); any key
                may be absent.
            batch_size: Number of sequences in the batch.
            device: Device on which history tensors are materialized.

        Returns:
            A (history_buffer, physics_state) pair. The buffer is a list of
            {'objects', 'world'} dicts (most recent last), truncated to
            ``max_history_frames`` and never empty — a single all-zero frame
            is used when no history is supplied.
        """
        history_buffer: List[Dict[str, torch.Tensor]] = []

        # Default physics state is all zeros until history provides one.
        physics_state = torch.zeros(
            batch_size,
            self.max_objects,
            3,
            device=device,
            dtype=torch.float32,
        )

        if history_frames is not None:
            objects_hist = history_frames.get('objects')
            world_hist = history_frames.get('world')
            physics_hist = history_frames.get('physics')

            if physics_hist is not None:
                physics_state = physics_hist.to(device=device, dtype=torch.float32)

            # Objects and world must both be present to form usable frames.
            if objects_hist is not None and world_hist is not None:
                history_len = objects_hist.shape[1]
                for idx in range(history_len):
                    history_buffer.append({
                        'objects': objects_hist[:, idx, :, :self.object_param_dim].to(device=device, dtype=torch.float32),
                        'world': world_hist[:, idx, :8].to(device=device, dtype=torch.float32),
                    })

        # Guarantee at least one (zero) frame so downstream indexing is safe.
        if len(history_buffer) == 0:
            history_buffer.append({
                'objects': torch.zeros(batch_size, self.max_objects, self.object_param_dim, device=device),
                'world': torch.zeros(batch_size, 8, device=device),
            })

        # Keep only the most recent frames.
        history_buffer = history_buffer[-self.max_history_frames:]

        return history_buffer, physics_state

    def sample_decoder_noise(self, batch_size: int, device: torch.device) -> Optional[torch.Tensor]:
        """Sample a (batch, hidden_size) Gaussian noise embedding, or None
        when ``decoder_noise_std`` is not positive."""
        if self.decoder_noise_std <= 0:
            return None
        noise = torch.randn(batch_size, self.hidden_size, device=device)
        return noise * self.decoder_noise_std

    def _build_history_embedding(
        self,
        history_buffer: List[Dict[str, torch.Tensor]],
        physics_state: torch.Tensor,
        use_frames: int,
    ) -> torch.Tensor:
        """Convert most recent history frames into conditioning embedding.

        The last ``use_frames`` buffer entries are flattened into fixed slots
        of a zero-padded feature vector (unused slots stay zero), the physics
        state is appended, and the result is projected to hidden_size.
        """
        batch_size = physics_state.shape[0]
        device = physics_state.device

        # One slot per history frame: flattened objects followed by world.
        frame_dim = self.max_objects * self.object_param_dim + 8
        history_tensor = torch.zeros(
            batch_size,
            self.max_history_frames * frame_dim,
            device=device,
        )

        use_frames = min(use_frames, self.max_history_frames)
        recent_frames = history_buffer[-use_frames:] if use_frames > 0 else []
        for slot, frame in enumerate(recent_frames):
            offset = slot * frame_dim
            obj_flat = frame['objects'].reshape(batch_size, -1)
            world_feat = frame['world']
            history_tensor[:, offset:offset + obj_flat.shape[1]] = obj_flat
            history_tensor[:, offset + obj_flat.shape[1]:offset + frame_dim] = world_feat

        physics_flat = physics_state.reshape(batch_size, -1)
        history_features = torch.cat([history_tensor, physics_flat], dim=-1)
        return self.history_proj(history_features)

    def forward(
        self,
        input_text: List[str],
        target_frames: torch.Tensor,
        history_frames: Optional[Dict[str, torch.Tensor]] = None,
        relative_times: Optional[torch.Tensor] = None,
        static_object_params: Optional[torch.Tensor] = None,
        noise: Optional[torch.Tensor] = None,
    ):
        """
        Forward pass for text to wave parameter generation.

        Args:
            input_text: List of text descriptions (length = batch size).
            target_frames: (batch, num_target_frames) frame indices; only the
                second dimension (count) and device are read here.
            history_frames: Optional history frames for conditioning; see
                ``_initialize_history_state`` for the expected keys.
            relative_times: Optional (batch, num_target_frames) relative time
                positions in [-1, 1]; zeros-equivalent embedding when None.
            static_object_params: Optional (batch, max_objects, >=6) tensor;
                when given, its first 6 channels overwrite the predicted
                static parameters of every frame.
            noise: Optional (batch, hidden_size) embedding added to every
                decoder input step.

        Returns:
            List of per-frame dicts with keys 'objects'
            (batch, max_objects, object_param_dim), 'world' (batch, 8) and
            'physics' (batch, max_objects, 3).
        """
        batch_size = len(input_text)
        num_target_frames = target_frames.shape[1]

        # T5-style task prefix.
        formatted_text = [f"translate to wave: {text}" for text in input_text]

        text_inputs = self.tokenizer(
            formatted_text,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt"
        ).to(target_frames.device)

        try:
            # A single pad token starts the decoder (T5 convention: the pad
            # token doubles as the decoder start token); the full model is run
            # once only to obtain the encoder hidden states.
            decoder_start_token_id = self.t5_model.config.pad_token_id
            decoder_input_ids = torch.full(
                (batch_size, 1),
                decoder_start_token_id,
                dtype=torch.long,
                device=text_inputs.input_ids.device
            )

            outputs = self.t5_model(
                input_ids=text_inputs.input_ids,
                attention_mask=text_inputs.attention_mask,
                decoder_input_ids=decoder_input_ids,
                return_dict=True,
                output_hidden_states=True
            )

            encoder_outputs = outputs.encoder_last_hidden_state
        except Exception as e:
            # Prefer a module-level log_message if the importing script
            # defines one; otherwise fall back to stdout, then re-raise.
            if 'log_message' in globals():
                log_message(f"ERROR in encoder: {e}")
            else:
                print(f"ERROR in encoder: {e}")
            raise

        history_buffer, physics_state = self._initialize_history_state(
            history_frames,
            batch_size,
            target_frames.device,
        )

        if static_object_params is not None:
            static_object_params = static_object_params.to(
                device=target_frames.device,
                dtype=torch.float32,
            )

        if noise is not None:
            noise = noise.to(device=encoder_outputs.device, dtype=encoder_outputs.dtype)

        # Rebinds the name from the model output above to the result list.
        outputs = []

        for f in range(num_target_frames):
            # Choose how many history frames condition this step.
            if self.random_history_sampling:
                max_available = min(len(history_buffer), self.max_history_frames)
                if max_available > 0:
                    # Uniform over 0..max_available (inclusive) — 0 means
                    # "no history", training robustness to cold starts.
                    use_history = int(torch.randint(
                        low=0,
                        high=max_available + 1,
                        size=(1,),
                        device=encoder_outputs.device,
                    ).item())
                else:
                    use_history = 0
            else:
                use_history = min(len(history_buffer), self.max_history_frames)

            # Time conditioning for this frame (zeros when no times given).
            if relative_times is not None:
                time_input = relative_times[:, f:f+1].unsqueeze(-1)
                time_embed = self.time_embed(time_input).squeeze(1)
            else:
                time_embed = torch.zeros(
                    batch_size,
                    self.hidden_size,
                    device=encoder_outputs.device,
                )

            # Decoder input = time + history (+ optional noise) embeddings.
            history_embed = self._build_history_embedding(history_buffer, physics_state, use_history)
            decoder_embed = time_embed + history_embed
            if noise is not None:
                decoder_embed = decoder_embed + noise

            # One decoder step with a single input position, cross-attending
            # to the cached encoder states.
            decoder_output = self.t5_model.decoder(
                inputs_embeds=decoder_embed.unsqueeze(1),
                encoder_hidden_states=encoder_outputs,
                encoder_attention_mask=text_inputs.attention_mask,
            )

            hidden = decoder_output.last_hidden_state[:, 0]

            object_params = self.object_proj(hidden).view(batch_size, self.max_objects, self.object_param_dim)
            if static_object_params is not None:
                # Overwrite the first 6 (static) channels with the anchor's
                # values; defensive zero-padding if fewer than 6 are given.
                static_slice = static_object_params[:, :, :6]
                if static_slice.shape[-1] < 6:
                    pad_width = 6 - static_slice.shape[-1]
                    pad = torch.zeros(*static_slice.shape[:-1], pad_width, device=object_params.device)
                    static_slice = torch.cat([static_slice, pad], dim=-1)
                # clone() keeps the in-place slice write autograd-safe.
                object_params = object_params.clone()
                object_params[:, :, :6] = static_slice
            world_params = self.world_proj(hidden)
            physics_params = self.physics_proj(hidden).view(batch_size, self.max_objects, 3)

            outputs.append({
                'objects': object_params,
                'world': world_params,
                'physics': physics_params,
            })

            # Feed the prediction back as history for the next step.
            history_buffer.append({
                'objects': object_params,
                'world': world_params,
            })
            if len(history_buffer) > self.max_history_frames:
                history_buffer = history_buffer[-self.max_history_frames:]

            physics_state = physics_params

        return outputs
| |
|
| |
|
| | class BidirectionalTrainer: |
| | """Trainer for bidirectional prediction from middle frame""" |
| | |
| | def __init__( |
| | self, |
| | model: Text2WaveModel, |
| | config: Dict, |
| | accelerator: Accelerator, |
| | ): |
| | self.model = model |
| | self.config = config |
| | self.accelerator = accelerator |
| | base_model = accelerator.unwrap_model(model) if hasattr(accelerator, "unwrap_model") else model |
| | self.object_param_dim = getattr(base_model, "object_param_dim", 12) |
| | self.freeze_static_params = bool(config['training'].get('freeze_static_from_anchor', True)) |
| | self.base_model = base_model |
| | self.sample_attempts = int(config['training'].get('multi_sample_attempts', 1)) |
| | self.sample_attempts = max(1, self.sample_attempts) |
| | |
| | |
| | self.world_loss_fn = nn.MSELoss() |
| | self.physics_loss_fn = nn.MSELoss() |
| | |
| | |
| | loss_weights_config = config.get('loss', {}).get('weights', {}) |
| | self.loss_weights = { |
| | 'wave_loss(superquadric)': loss_weights_config.get('wave_loss', 1.0), |
| | 'wave_contrastive_loss': loss_weights_config.get('wave_contrastive_loss', 2.0), |
| | 'world_info_loss(camera,scale,time)': loss_weights_config.get('world_info_loss', 0.5), |
| | 'controllable_info_loss(mass,friction,restitution)': loss_weights_config.get('controllable_info_loss', 0.1), |
| | 'pla_loss': loss_weights_config.get('pla_loss', 3.0), |
| | } |
| |
|
| | physics_cfg = config.get('physics', {}) |
| | self.gravity = float(physics_cfg.get('gravity', 9.81)) |
| | self.collision_buffer = float(physics_cfg.get('collision_buffer', 1.05)) |
| |
|
| | |
| | self.frame_rate = float(config['training'].get('frame_rate', 8.0)) |
| | self.frame_rate = max(self.frame_rate, 1e-6) |
| |
|
| | presence_cfg = config.get('loss', {}).get('wave_presence', {}) |
| | self.wave_count_weight = float(presence_cfg.get('count_weight', 0.2)) |
| | self.wave_presence_threshold = float(presence_cfg.get('scale_threshold', 0.1)) |
| | self.wave_presence_temperature = float(presence_cfg.get('temperature', 0.1)) |
| | contrastive_cfg = config.get('loss', {}).get('wave_contrastive', {}) |
| | self.wave_contrastive_temperature = float(contrastive_cfg.get('temperature', 0.2)) |
| |
|
| | |
| | self.velocity_slice = slice(max(self.object_param_dim - 3, 0), self.object_param_dim) |
| |
|
    def compute_loss(
        self,
        predictions: List[Dict],
        targets: Dict[str, torch.Tensor],
        frame_indices: List[int],
    ) -> Dict[str, torch.Tensor]:
        """Compute all loss terms for a set of predicted frames.

        Args:
            predictions: Per-frame dicts with 'objects', 'world', 'physics'
                tensors as produced by the model.
            targets: Dict with 'objects' (batch, frames, objects, params),
                'world' (batch, frames, >=8) and 'physics'
                (batch, objects, 3) ground truth.
            frame_indices: Target frame index for each prediction, aligned
                with ``predictions``.

        Returns:
            Dict of scalar loss tensors, including the weighted 'total'.
        """
        losses = {
            'wave_loss(superquadric)': 0.0,
            'wave_contrastive_loss': 0.0,
            'world_info_loss(camera,scale,time)': 0.0,
            'controllable_info_loss(mass,friction,restitution)': 0.0,
            'pla_loss': 0.0,
            'wave_count_mse': 0.0,
            'total': 0.0,
        }

        # Frames collected for the physics-plausibility regularizer, plus
        # per-frame masked summaries for the contrastive term.
        pla_entries = []
        pred_summaries: List[torch.Tensor] = []
        target_summaries: List[torch.Tensor] = []

        for i, (pred, frame_idx) in enumerate(zip(predictions, frame_indices)):
            # Zero-pad targets up to object_param_dim if the dataset stores
            # fewer channels than the model predicts.
            target_objects = targets['objects'][:, frame_idx]
            if target_objects.shape[-1] < self.object_param_dim:
                pad_width = self.object_param_dim - target_objects.shape[-1]
                pad = target_objects.new_zeros(*target_objects.shape[:-1], pad_width)
                target_objects = torch.cat([target_objects, pad], dim=-1)
            pred_objects = pred['objects']

            # Channel 0 of the target is a presence flag.
            exists_mask = target_objects[:, :, 0] > 0.5

            target_core = target_objects[:, :, :self.object_param_dim]

            # Masked L1 reconstruction (with extra weight on velocity).
            object_loss = self._wave_reconstruction_loss(pred_objects, target_core, exists_mask)
            losses['wave_loss(superquadric)'] += object_loss

            # Soft object-count supervision: an object "exists" in the
            # prediction when its scale norm exceeds the presence threshold
            # (sigmoid-relaxed so it stays differentiable).
            target_presence = target_objects[:, :, 0].float()
            pred_scale_norm = torch.linalg.norm(pred_objects[:, :, 3:6], dim=-1)
            presence_input = (pred_scale_norm - self.wave_presence_threshold) / max(self.wave_presence_temperature, 1e-6)
            pred_presence = torch.sigmoid(presence_input)
            pred_count = pred_presence.sum(dim=-1)
            target_count = target_presence.sum(dim=-1)
            count_mse = F.mse_loss(pred_count, target_count)
            losses['wave_count_mse'] += count_mse
            losses['wave_loss(superquadric)'] += self.wave_count_weight * count_mse

            pla_entries.append({
                'frame_idx': frame_idx,
                'pred_objects': pred_objects,
                'exists_mask': exists_mask,
            })

            # Existence-masked mean over objects -> one summary vector per
            # sample, used later for the clip-level contrastive loss.
            mask = exists_mask.float().unsqueeze(-1)
            denom = mask.sum(dim=1).clamp_min(1.0)
            pred_summary = (pred_objects * mask).sum(dim=1) / denom
            target_summary = (target_core * mask).sum(dim=1) / denom
            pred_summaries.append(pred_summary)
            target_summaries.append(target_summary)

            target_world = targets['world'][:, frame_idx]
            pred_world = pred['world']

            world_loss = self.world_loss_fn(
                pred_world,
                target_world[:, :8]
            )
            losses['world_info_loss(camera,scale,time)'] += world_loss

            # Physics parameters are frame-invariant, so they are supervised
            # once, on the first predicted frame only (note '=' not '+=').
            if i == 0:
                target_physics = targets['physics']
                pred_physics = pred['physics']

                physics_loss = self.physics_loss_fn(
                    pred_physics[exists_mask],
                    target_physics[exists_mask]
                )
                losses['controllable_info_loss(mass,friction,restitution)'] = physics_loss

        # Average the accumulated per-frame terms.
        num_frames = len(predictions)
        losses['wave_loss(superquadric)'] /= num_frames
        losses['world_info_loss(camera,scale,time)'] /= num_frames
        losses['wave_count_mse'] /= num_frames

        # Add the ground-truth middle frame (detached) as an extra anchor for
        # the physics regularizer.
        total_frames = targets['objects'].shape[1]
        middle_idx = total_frames // 2
        anchor_objects = targets['objects'][:, middle_idx]
        anchor_exists = anchor_objects[:, :, 0] > 0.5
        pla_entries.append({
            'frame_idx': middle_idx,
            'pred_objects': anchor_objects[:, :, :self.object_param_dim].detach(),
            'exists_mask': anchor_exists,
        })

        pla_loss = self._compute_pla_regularizer(pla_entries)
        losses['pla_loss'] = pla_loss

        # Clip-level contrastive loss on time-averaged summaries.
        if pred_summaries:
            pred_stack = torch.stack(pred_summaries, dim=0).mean(dim=0)
            target_stack = torch.stack(target_summaries, dim=0).mean(dim=0)
            losses['wave_contrastive_loss'] = self._contrastive_clip_loss(pred_stack, target_stack)
        else:
            device = targets['objects'].device
            losses['wave_contrastive_loss'] = torch.zeros((), device=device)

        # Weighted sum; keys without a configured weight (e.g. the pure
        # monitoring entry 'wave_count_mse') are excluded from 'total'.
        for key, weight in self.loss_weights.items():
            if key in losses:
                losses['total'] += weight * losses[key]

        return losses
| |
|
| | def _wave_reconstruction_loss( |
| | self, |
| | pred_objects: torch.Tensor, |
| | target_objects: torch.Tensor, |
| | exists_mask: torch.Tensor, |
| | ) -> torch.Tensor: |
| | """Velocity-aware reconstruction loss combining position L1 and velocity L1.""" |
| | device = pred_objects.device |
| | dtype = pred_objects.dtype |
| | if not exists_mask.any(): |
| | return torch.zeros((), device=device, dtype=dtype) |
| |
|
| | pred_active = pred_objects[exists_mask] |
| | target_active = target_objects[exists_mask] |
| |
|
| | base_l1 = F.l1_loss(pred_active, target_active, reduction='mean') |
| |
|
| | if self.velocity_slice.start >= self.velocity_slice.stop: |
| | velocity_l1 = torch.zeros((), device=device, dtype=dtype) |
| | else: |
| | pred_velocity = pred_active[..., self.velocity_slice] |
| | target_velocity = target_active[..., self.velocity_slice] |
| | velocity_l1 = F.l1_loss(pred_velocity, target_velocity, reduction='mean') |
| |
|
| | return 0.5 * base_l1 + 0.5 * velocity_l1 |
| |
|
| | def _contrastive_clip_loss( |
| | self, |
| | pred_summary: torch.Tensor, |
| | target_summary: torch.Tensor, |
| | ) -> torch.Tensor: |
| | """InfoNCE-style contrastive loss between predicted and target clip summaries.""" |
| | device = pred_summary.device |
| | dtype = pred_summary.dtype |
| | batch = pred_summary.size(0) |
| | if batch <= 1: |
| | return torch.zeros((), device=device, dtype=dtype) |
| |
|
| | dim = min(pred_summary.size(-1), target_summary.size(-1)) |
| | if dim == 0: |
| | return torch.zeros((), device=device, dtype=dtype) |
| | if pred_summary.size(-1) != dim: |
| | pred_summary = pred_summary[..., :dim] |
| | if target_summary.size(-1) != dim: |
| | target_summary = target_summary[..., :dim] |
| |
|
| | temperature = max(self.wave_contrastive_temperature, 1e-6) |
| | pred_norm = F.normalize(pred_summary, dim=-1) |
| | target_norm = F.normalize(target_summary, dim=-1) |
| | dim_post = min(pred_norm.size(-1), target_norm.size(-1)) |
| | if dim_post == 0: |
| | return torch.zeros((), device=device, dtype=dtype) |
| | if pred_norm.size(-1) != dim_post: |
| | pred_norm = pred_norm[..., :dim_post] |
| | if target_norm.size(-1) != dim_post: |
| | target_norm = target_norm[..., :dim_post] |
| | logits = pred_norm @ target_norm.transpose(0, 1) |
| | logits = logits / temperature |
| |
|
| | labels = torch.arange(batch, device=device) |
| | loss_forward = F.cross_entropy(logits, labels) |
| | loss_backward = F.cross_entropy(logits.transpose(0, 1), labels) |
| |
|
| | return 0.5 * (loss_forward + loss_backward) |
| |
|
    def _compute_pla_regularizer(self, entries: List[Dict[str, torch.Tensor]]) -> torch.Tensor:
        """Encourage rigid-body consistency, free-fall dynamics, and collision plausibility.

        Args:
            entries: Dicts with 'frame_idx' (int), 'pred_objects'
                (batch, objects, >=15) and 'exists_mask' (batch, objects);
                sorted by frame index before stacking.

        Returns:
            Scalar sum of: shape/scale temporal-variance penalties,
            free-fall acceleration residual, rotation smoothness,
            collision penetration penalty, and velocity/finite-difference
            consistency. Zero when fewer than 2 frames or no objects exist.
        """
        model_device = next(self.model.parameters()).device
        if not entries:
            return torch.tensor(0.0, device=model_device)

        # Stack frames in temporal order: (frames, batch, objects, params).
        sorted_entries = sorted(entries, key=lambda x: x['frame_idx'])

        device = sorted_entries[0]['pred_objects'].device
        dtype = sorted_entries[0]['pred_objects'].dtype

        preds = torch.stack([item['pred_objects'] for item in sorted_entries], dim=0)
        exists = torch.stack([item['exists_mask'].float() for item in sorted_entries], dim=0)

        frame_count, batch_size, max_objects, _ = preds.shape

        if frame_count <= 1:
            return torch.tensor(0.0, device=device, dtype=dtype)

        exists_expanded = exists.unsqueeze(-1)
        exists_total = exists_expanded.sum()
        # NOTE: .item() forces a host sync; acceptable for an early exit.
        if exists_total.item() == 0:
            return torch.tensor(0.0, device=device, dtype=dtype)

        # Rigid-body consistency: shape [1:3] and scale [3:6] of each object
        # should not vary across frames -> penalize variance around the
        # existence-weighted temporal mean.
        shape_params = preds[..., 1:3]
        scale_params = preds[..., 3:6]

        shape_mean = (shape_params * exists_expanded).sum(dim=0) / exists_expanded.sum(dim=0).clamp_min(1.0)
        scale_mean = (scale_params * exists_expanded).sum(dim=0) / exists_expanded.sum(dim=0).clamp_min(1.0)

        shape_loss = ((shape_params - shape_mean) ** 2 * exists_expanded).sum() / exists_expanded.sum().clamp_min(1.0)
        scale_loss = ((scale_params - scale_mean) ** 2 * exists_expanded).sum() / exists_expanded.sum().clamp_min(1.0)

        freefall_loss = torch.tensor(0.0, device=device, dtype=dtype)
        rotation_loss = torch.tensor(0.0, device=device, dtype=dtype)
        collision_penalty = torch.tensor(0.0, device=device, dtype=dtype)
        velocity_loss = torch.tensor(0.0, device=device, dtype=dtype)

        # Object centers, channels [6:9].
        positions = preds[..., 6:9]

        if frame_count >= 3:
            # Bounding radius per object from the scale vector norm.
            radii = torch.linalg.norm(preds[..., 3:6], dim=-1)

            # Second-order central finite difference = acceleration (in
            # frame units; gravity is applied in the same units below —
            # assumes frame-spacing of 1; TODO confirm).
            accel = positions[2:] - 2 * positions[1:-1] + positions[:-2]

            # Object must exist in all three frames of each triplet.
            exists_triplet = exists[1:-1] * exists[:-2] * exists[2:]
            exists_triplet_expanded = exists_triplet.unsqueeze(-1)

            # Flatten (frames-2, batch) -> one batch for pairwise distances.
            center_positions = positions[1:-1].reshape(-1, max_objects, 3)
            center_exists = exists[1:-1].reshape(-1, max_objects)
            center_radii = radii[1:-1].reshape(-1, max_objects)

            if center_positions.numel() > 0:
                dist = torch.cdist(center_positions, center_positions, p=2)
                # Pairwise sum of radii, inflated by the collision buffer.
                radius_sum = (center_radii.unsqueeze(-1) + center_radii.unsqueeze(-2)) * self.collision_buffer
                exists_pair = center_exists.unsqueeze(-1) * center_exists.unsqueeze(-2)

                # Ignore self-distances on the diagonal.
                eye = torch.eye(max_objects, device=device).unsqueeze(0)
                non_diag = (1 - eye)

                # Positive where spheres overlap more than allowed.
                penetration = torch.relu((radius_sum - dist) * non_diag) * exists_pair
                collision_penalty = penetration.pow(2).sum() / (non_diag * exists_pair).sum().clamp_min(1.0)

                # Objects currently in contact are exempt from free-fall.
                contact_any = (penetration > 0).any(dim=-1).view(frame_count - 2, batch_size, max_objects)
            else:
                contact_any = torch.zeros(frame_count - 2, batch_size, max_objects, device=device, dtype=torch.bool)

            contact_mask = contact_any.float()

            # Non-contact objects should accelerate at -g along z.
            gravity_vec = torch.tensor([0.0, 0.0, -self.gravity], device=device, dtype=dtype).view(1, 1, 1, 3)
            residual = accel + gravity_vec

            freefall_mask = exists_triplet_expanded * (1.0 - contact_mask.unsqueeze(-1))
            valid_count = freefall_mask.sum().clamp_min(1.0)
            freefall_loss = (residual.pow(2) * freefall_mask).sum() / valid_count

            # Rotation smoothness in sin/cos space (avoids angle wrap-around):
            # penalize the second difference of the embedded angles for
            # non-contact objects.
            rotations = preds[..., 9:12]
            rot_sin = torch.sin(rotations)
            rot_cos = torch.cos(rotations)
            rot_features = torch.cat([rot_sin, rot_cos], dim=-1)
            rot_acc = rot_features[2:] - 2 * rot_features[1:-1] + rot_features[:-2]

            rot_mask = exists_triplet_expanded * (1.0 - contact_mask.unsqueeze(-1))
            rot_valid = rot_mask.sum().clamp_min(1.0)
            rotation_loss = (rot_acc.pow(2) * rot_mask).sum() / rot_valid

        if frame_count >= 2:
            # Predicted velocity channels [12:15] should match the
            # finite-difference velocity of the predicted positions.
            velocities = preds[..., 12:15]
            diff = (positions[1:] - positions[:-1]) * self.frame_rate
            exists_pair = exists[1:] * exists[:-1]
            diff_expanded = exists_pair.unsqueeze(-1)

            velocity_residual = (velocities[1:] - diff).pow(2) * diff_expanded
            valid_velocity = diff_expanded.sum()
            velocity_loss = velocity_residual.sum()

            # The first frame's velocity is matched against the first forward
            # difference (it has no preceding frame of its own).
            first_pair = (exists[0] * exists[1]).unsqueeze(-1)
            velocity_loss += ((velocities[0] - diff[0]) ** 2 * first_pair).sum()
            valid_velocity += first_pair.sum()

            velocity_loss = velocity_loss / valid_velocity.clamp_min(1.0)

        pla_loss = (
            shape_loss
            + scale_loss
            + freefall_loss
            + rotation_loss
            + collision_penalty
            + velocity_loss
        )
        return pla_loss
| |
|
| | def _select_anchor_frame(self, num_frames: int) -> int: |
| | """Determine which frame should serve as the initial anchor.""" |
| | cfg = self.config['training'].get('initial_frame', {}) |
| | strategy = cfg.get('strategy', 'middle') |
| |
|
| | if strategy == 'random': |
| | base_idx = int(torch.randint(low=0, high=num_frames, size=(1,), device=torch.device('cpu')).item()) |
| | elif strategy == 'fixed': |
| | base_idx = int(cfg.get('index', num_frames // 2)) |
| | else: |
| | base_idx = num_frames // 2 |
| |
|
| | offset = int(cfg.get('offset', 0)) |
| | anchor_idx = base_idx + offset |
| | anchor_idx = max(0, min(num_frames - 1, anchor_idx)) |
| | return anchor_idx |
| |
|
    def _generate_full_sequence(
        self,
        text: List[str],
        objects: torch.Tensor,
        world: torch.Tensor,
        physics: torch.Tensor,
        teacher_prob: float,
        anchor_idx: Optional[int] = None,
        use_noise: bool = False,
    ) -> Tuple[List[Dict[str, torch.Tensor]], List[int], float]:
        """Generate a full sequence of predictions given an anchor frame.

        The anchor frame is predicted first without history; it is then
        (per-sample, with probability ``teacher_prob``) replaced by the
        ground truth and used as the history seed for two autoregressive
        passes: backward to frame 0 and forward to the last frame.

        Args:
            text: Batch of text prompts.
            objects: (batch, frames, objects, params) ground-truth objects.
            world: (batch, frames, >=8) ground-truth world parameters.
            physics: (batch, objects, 3) ground-truth physics parameters.
            teacher_prob: Per-sample probability of teacher-forcing the anchor.
            anchor_idx: Anchor frame index; selected via strategy when None.
            use_noise: Whether to draw decoder noise for each model call.

        Returns:
            (predictions ordered 0..frames-1, their indices, total seconds
            spent inside model forward calls).
        """
        batch_size, num_frames = objects.shape[:2]
        if anchor_idx is None:
            anchor_idx = self._select_anchor_frame(num_frames)

        # Optionally pin the static (first 6) object channels to the anchor's
        # ground truth for every generated frame.
        static_object_params = None
        if self.freeze_static_params:
            anchor_static = objects[:, anchor_idx, :, :6]
            static_object_params = anchor_static

        # Per-sample teacher-forcing decision (1.0 = use ground truth anchor).
        if teacher_prob > 0.0:
            teacher_mask = (torch.rand(batch_size, device=objects.device) < teacher_prob).float()
        else:
            teacher_mask = torch.zeros(batch_size, device=objects.device, dtype=torch.float32)

        def sample_noise():
            # Fresh noise per model call; None when disabled.
            return self.base_model.sample_decoder_noise(batch_size, objects.device) if use_noise else None

        # Relative times are scaled so the clip ends map to roughly +/-1.
        half_span = max(num_frames - 1, 1) / 2.0
        inference_time = 0.0
        predictions_by_idx: Dict[int, Dict[str, torch.Tensor]] = {}

        # The anchor is predicted with zero relative time and no history.
        anchor_rel_times = torch.zeros(
            (batch_size, 1), dtype=torch.float32, device=objects.device
        )
        anchor_targets = torch.full(
            (batch_size, 1), anchor_idx, dtype=torch.long, device=objects.device
        )

        start = time.time()
        anchor_preds = self.model(
            input_text=text,
            target_frames=anchor_targets,
            history_frames=None,
            relative_times=anchor_rel_times,
            static_object_params=static_object_params,
            noise=sample_noise(),
        )
        inference_time += time.time() - start
        anchor_pred = anchor_preds[0]
        predictions_by_idx[anchor_idx] = anchor_pred

        # Ground-truth anchor, zero-padded to the model's parameter width.
        anchor_gt_objects = objects[:, anchor_idx, :, :self.object_param_dim]
        if anchor_gt_objects.shape[-1] < self.object_param_dim:
            pad_width = self.object_param_dim - anchor_gt_objects.shape[-1]
            pad = anchor_gt_objects.new_zeros(*anchor_gt_objects.shape[:-1], pad_width)
            anchor_gt_objects = torch.cat([anchor_gt_objects, pad], dim=-1)
        anchor_gt_world = world[:, anchor_idx, :8]
        anchor_pred_objects = anchor_pred['objects']
        if static_object_params is not None:
            # In-place overwrite of the static channels on the prediction
            # that is also stored in predictions_by_idx.
            anchor_pred_objects[:, :, :6] = static_object_params[:, :, :6]
        anchor_pred_world = anchor_pred['world']

        teacher_mask_objs = teacher_mask.view(batch_size, 1, 1)
        teacher_mask_world = teacher_mask.view(batch_size, 1)

        # Per-sample blend: teacher-forced samples take the ground truth,
        # the rest keep the model's own anchor prediction.
        blended_objects = anchor_pred_objects * (1.0 - teacher_mask_objs) + anchor_gt_objects * teacher_mask_objs
        blended_world = anchor_pred_world * (1.0 - teacher_mask_world) + anchor_gt_world * teacher_mask_world

        # History seed of length 1 (the anchor frame).
        history_objects = blended_objects.unsqueeze(1)
        history_world = blended_world.unsqueeze(1)
        history_physics = physics.clone()

        def make_history_seed():
            # Fresh clones so each direction starts from the same seed.
            return {
                'objects': history_objects.clone(),
                'world': history_world.clone(),
                'physics': history_physics.clone(),
            }

        # Frames before the anchor are generated in reverse order.
        backward_indices = list(range(anchor_idx - 1, -1, -1))
        forward_indices = list(range(anchor_idx + 1, num_frames))

        def run_direction(target_indices: List[int]):
            # One autoregressive model call covering all frames in this
            # direction; accumulates wall time into the enclosing scope.
            nonlocal inference_time
            if not target_indices:
                return

            rel_times = torch.tensor(
                [(idx - anchor_idx) / half_span for idx in target_indices],
                dtype=torch.float32,
                device=objects.device,
            ).unsqueeze(0).repeat(batch_size, 1)

            target_tensor = torch.tensor(
                target_indices,
                dtype=torch.long,
                device=objects.device,
            ).unsqueeze(0).repeat(batch_size, 1)

            history_frames = make_history_seed()

            start_time = time.time()
            preds = self.model(
                input_text=text,
                target_frames=target_tensor,
                history_frames=history_frames,
                relative_times=rel_times,
                static_object_params=static_object_params,
                noise=sample_noise(),
            )
            inference_time += time.time() - start_time

            for idx, pred in zip(target_indices, preds):
                if static_object_params is not None:
                    # Same in-place static-channel overwrite as for the anchor.
                    pred['objects'][:, :, :6] = static_object_params[:, :, :6]
                predictions_by_idx[idx] = pred

        run_direction(backward_indices)
        run_direction(forward_indices)

        # Reassemble chronological order 0..num_frames-1.
        ordered_indices = list(range(num_frames))
        predictions = [predictions_by_idx[idx] for idx in ordered_indices]
        return predictions, ordered_indices, inference_time
| |
|
    def _compute_losses(
        self,
        batch: Dict[str, torch.Tensor],
    ) -> Tuple[Dict[str, torch.Tensor], float, int]:
        """Shared logic for computing losses and metadata.

        Generates the full sequence up to ``sample_attempts`` times (noise
        enabled only when more than one attempt is made), keeps the attempt
        with the lowest total loss, and returns its losses plus throughput
        metadata.

        Args:
            batch: Dict with 'text' (list of prompts) and 'objects', 'world',
                'physics' tensors as described in ``_generate_full_sequence``.

        Returns:
            (losses dict, frames-per-second of the kept attempt, number of
            predicted frames).
        """
        text = batch['text']
        objects = batch['objects']
        world = batch['world']
        physics = batch['physics']

        batch_size, num_frames = objects.shape[:2]
        anchor_idx = self._select_anchor_frame(num_frames)
        teacher_prob = float(self.config['training'].get('initial_teacher_forcing_prob', 0.5))

        targets = {
            'objects': objects,
            'world': world,
            'physics': physics,
        }

        # Multi-sample search only during training; eval is deterministic.
        attempts = self.sample_attempts if self.model.training else 1
        use_noise = attempts > 1
        best_losses: Optional[Dict[str, torch.Tensor]] = None
        best_predictions: Optional[List[Dict[str, torch.Tensor]]] = None
        best_frame_indices: Optional[List[int]] = None
        best_inference_time: float = 0.0
        best_total_value: Optional[float] = None

        for attempt in range(attempts):
            predictions, frame_indices, inference_time = self._generate_full_sequence(
                text=text,
                objects=objects,
                world=world,
                physics=physics,
                teacher_prob=teacher_prob,
                anchor_idx=anchor_idx,
                use_noise=use_noise,
            )

            losses = self.compute_loss(predictions, targets, frame_indices)
            total_value = float(losses['total'].detach())
            if best_total_value is None or total_value < best_total_value:
                # Drop the previous best's graph before keeping the new one,
                # so at most one attempt's autograd graph stays alive.
                if best_losses is not None:
                    del best_losses
                if best_predictions is not None:
                    del best_predictions
                best_total_value = total_value
                best_losses = losses
                best_predictions = predictions
                best_frame_indices = frame_indices
                best_inference_time = inference_time
            else:
                # Worse attempt: free it and reclaim GPU memory eagerly.
                del losses
                del predictions
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

        # attempts >= 1, so a best attempt always exists.
        assert best_losses is not None and best_predictions is not None and best_frame_indices is not None
        num_predicted_frames = len(best_predictions)
        frames_per_second = num_predicted_frames / best_inference_time if best_inference_time > 0 else 0.0

        return best_losses, frames_per_second, num_predicted_frames
| | |
| | def train_step( |
| | self, |
| | batch: Dict[str, torch.Tensor], |
| | step: int, |
| | ) -> Dict[str, float]: |
| | """Single training step with bidirectional prediction""" |
| | self.model.train() |
| |
|
| | losses, frames_per_second, num_predicted_frames = self._compute_losses(batch) |
| |
|
| | self.accelerator.backward(losses['total']) |
| |
|
| | loss_dict = {k: v.item() if torch.is_tensor(v) else float(v) for k, v in losses.items()} |
| | loss_dict['inference_fps'] = frames_per_second |
| | loss_dict['frames_predicted'] = num_predicted_frames |
| |
|
| | return loss_dict |
| |
|
| | def evaluate_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]: |
| | """Compute losses without gradient updates.""" |
| | was_training = self.model.training |
| | self.model.eval() |
| | with torch.no_grad(): |
| | losses, frames_per_second, num_predicted_frames = self._compute_losses(batch) |
| | if was_training: |
| | self.model.train() |
| |
|
| | loss_dict = {k: v.item() if torch.is_tensor(v) else float(v) for k, v in losses.items()} |
| | loss_dict['inference_fps'] = frames_per_second |
| | loss_dict['frames_predicted'] = num_predicted_frames |
| | return loss_dict |
| |
|
| |
|
| |
|
| |
|
| |
|
def main() -> None:
    """Entry point: configure, build, and run the text2wave training loop.

    Responsibilities (in order): parse CLI args, load YAML config, set up
    Accelerate (DDP), build model/optimizer/dataloaders, resume from a
    checkpoint if available, then run the step-based training loop with
    periodic validation, checkpointing, loss plotting, and asynchronous
    generation-result saving.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_config', type=str, default='configs/default.yaml',
                        help='Training configuration file')
    parser.add_argument('--data_root', type=str,
                        default='../data/movi_a_128x128',
                        help='Root directory of MOVi dataset')
    parser.add_argument('--output_dir', type=str, default='core_space',
                        help='Directory to save checkpoints and generation results')
    parser.add_argument('--resume_step', type=int, default=None,
                        help='Resume training from specific step')
    args = parser.parse_args()

    # Load the training configuration (trusted local file; safe_load is used).
    with open(args.train_config, 'r') as f:
        config = yaml.safe_load(f)

    from accelerate import DistributedDataParallelKwargs

    # find_unused_parameters=True tolerates parameters that receive no
    # gradient in a given step; broadcast_buffers=False skips per-step
    # buffer sync across ranks.
    ddp_kwargs = DistributedDataParallelKwargs(
        find_unused_parameters=True,
        broadcast_buffers=False
    )

    accelerator = Accelerator(
        gradient_accumulation_steps=1,
        kwargs_handlers=[ddp_kwargs]
    )

    # Fixed seed for reproducibility across all processes.
    set_seed(42)

    model_name = config.get('text2wave_model', {}).get('model_name', "google/t5-v1_1-small")
    model = Text2WaveModel(
        model_name=model_name,
        max_objects=10,
        num_frames=24,
        max_history_frames=config['training']['max_history_frames'],
        random_history_sampling=config['training'].get('random_history_sampling', True),
        decoder_noise_std=config['training'].get('decoder_noise_std', 0.0),
    )

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config['training']['learning_rate'],
        weight_decay=0.01,
    )

    train_dataloader = create_dataloader(
        data_root=args.data_root,
        split='train',
        batch_size=config['training']['batch_size'],
        num_workers=config['data']['num_workers'],
        shuffle=True,
        max_samples=config['data'].get('max_sequences', -1),
    )

    # Validation is deliberately capped at 10 samples to keep it cheap.
    val_dataloader = create_dataloader(
        data_root=args.data_root,
        split='validation',
        batch_size=config['training']['batch_size'],
        num_workers=config['data']['num_workers'],
        shuffle=False,
        max_samples=10,
    )

    # Wrap everything for distributed / mixed-precision execution.
    model, optimizer, train_dataloader, val_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader, val_dataloader
    )

    checkpoint_dir = Path("checkpoints_text2wave")
    if accelerator.is_main_process:
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

    log_file_path = checkpoint_dir / "training_log.txt"

    def log_message(message: str):
        """Log to stdout and append to training_log.txt from main process."""
        if not accelerator.is_main_process:
            return
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        formatted = f"{timestamp} {message}"
        print(formatted)
        try:
            with open(log_file_path, 'a') as fp:
                fp.write(formatted + "\n")
        except Exception:
            # Best-effort file logging: never let a logging failure kill training.
            pass

    # Remove a legacy metrics file from older runs; stats now live in
    # training_stats.npz.
    best_metrics_path = checkpoint_dir / "best_metrics.json"
    if best_metrics_path.exists():
        try:
            best_metrics_path.unlink()
        except OSError as exc:
            log_message(f"Warning: failed to remove legacy best_metrics.json due to {exc}")

    best_train_loss = float('inf')
    best_val_loss = float('inf')

    evaluation_cfg = config['training'].get('evaluation', {})
    eval_max_batches = evaluation_cfg.get('max_batches', 5)

    # Restore best-loss values and loss-history curves from a previous run
    # so plots continue seamlessly across restarts.
    training_stats_path = checkpoint_dir / "training_stats.npz"
    loaded_step_history: Optional[List[int]] = None
    loaded_loss_history: Dict[str, List[float]] = {}
    if training_stats_path.exists():
        try:
            stats = np.load(training_stats_path, allow_pickle=True)
            best_train_loss = float(stats.get('best_train_loss', best_train_loss))
            best_val_loss = float(stats.get('best_val_loss', best_val_loss))
            if 'step_history' in stats:
                loaded_step_history = stats['step_history'].tolist()
            if 'loss_history_keys' in stats and 'loss_history_values' in stats:
                keys = stats['loss_history_keys'].tolist()
                values = stats['loss_history_values'].tolist()
                for key, value in zip(keys, values):
                    loaded_loss_history[str(key)] = list(np.asarray(value, dtype=float))
        except Exception as exc:
            log_message(f"Warning: failed to load training_stats.npz due to {exc}")

    # Single-worker executor: checkpoint/generation writes happen off the
    # training thread but are serialized with respect to each other.
    executor = ThreadPoolExecutor(max_workers=1)
    pending_futures: List = []

    def cleanup_futures():
        # Drop completed futures so the pending list does not grow unboundedly.
        pending_futures[:] = [f for f in pending_futures if not f.done()]

    def submit_task(fn, *args, **kwargs):
        cleanup_futures()
        future = executor.submit(fn, *args, **kwargs)
        pending_futures.append(future)
        return future

    def recursive_to_cpu(obj):
        # Detach + move tensors to CPU (recursively through containers) so
        # payloads handed to the background writer are safe to serialize.
        if isinstance(obj, torch.Tensor):
            return obj.detach().cpu()
        if isinstance(obj, dict):
            return {k: recursive_to_cpu(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [recursive_to_cpu(v) for v in obj]
        if isinstance(obj, tuple):
            return tuple(recursive_to_cpu(v) for v in obj)
        return obj

    def save_checkpoint_async(path: Path, payload: Dict):
        # Fire-and-forget torch.save on the background executor.
        def _task():
            torch.save(payload, path)
        submit_task(_task)

    def save_generation_async(predictions: List[Dict], targets: Dict[str, torch.Tensor], texts: List[str], step: int, save_config: Dict, metadata: Dict, batch_data: Dict, data_root: str, data_split: str):
        # Background write of generation results via the shared executor.
        def _task():
            save_generation_results(
                predictions=predictions,
                targets=targets,
                texts=texts,
                step=step,
                output_dir=args.output_dir,
                save_config=save_config,
                metadata=metadata,
                batch_data=batch_data,
                data_root=data_root,
                data_split=data_split
            )
        submit_task(_task)

    def compute_validation_loss(max_batches: Optional[int]) -> Optional[float]:
        """Average 'total' loss over up to max_batches validation batches.

        max_batches semantics: None -> all batches, 0 -> skip validation
        entirely (returns None), positive -> cap at that many batches.
        Uses `trainer`, which is assigned later in main() but before any call.
        """
        limit = -1 if max_batches is None else max_batches
        if limit == 0:
            return None
        total = 0.0
        count = 0
        for batch_idx, val_batch in enumerate(val_dataloader):
            val_losses = trainer.evaluate_batch(val_batch)
            total += val_losses['total']
            count += 1
            if limit > 0 and (batch_idx + 1) >= limit:
                break
        if count == 0:
            return None
        return total / count

    trainer = BidirectionalTrainer(model, config, accelerator)

    max_steps = config['training']['max_steps']

    if accelerator.is_main_process:
        steps_per_epoch = len(train_dataloader)
        # NOTE(review): assumes a non-empty train dataloader; an empty one
        # would raise ZeroDivisionError here — confirm upstream guarantees.
        total_epochs = max_steps / steps_per_epoch
        log_message("=" * 60)
        log_message("Dataset Information:")
        log_message(f"- Training samples: {len(train_dataloader.dataset) if hasattr(train_dataloader, 'dataset') else 'N/A'}")
        log_message(f"- Batch size: {config['training']['batch_size']}")
        log_message(f"- Steps per epoch (full dataset): {steps_per_epoch}")
        log_message(f"- Total training steps: {max_steps}")
        log_message(f"- Will traverse dataset: {total_epochs:.2f} times")
        log_message("=" * 60)

    # Resume priority: explicit --resume_step > latest.pt > best.pt.
    # A corrupt latest.pt is renamed aside so the next run does not retry it.
    start_step = 0
    resumed_from = None
    if args.resume_step is not None:
        checkpoint_path = checkpoint_dir / f"step{args.resume_step}.pt"
        if checkpoint_path.exists():
            log_message(f"Resuming from checkpoint step {args.resume_step}")
            checkpoint = torch.load(checkpoint_path, map_location='cpu')
            accelerator.unwrap_model(model).load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            start_step = checkpoint.get('step', args.resume_step)
            resumed_from = checkpoint_path
        else:
            log_message(f"Warning: Checkpoint for step {args.resume_step} not found, starting from scratch")
    else:
        latest_checkpoint_path = checkpoint_dir / "latest.pt"
        if latest_checkpoint_path.exists():
            try:
                log_message("Resuming from latest checkpoint")
                checkpoint = torch.load(latest_checkpoint_path, map_location='cpu')
                accelerator.unwrap_model(model).load_state_dict(checkpoint['model_state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                start_step = checkpoint.get('step', 0)
                resumed_from = latest_checkpoint_path
            except Exception as exc:
                log_message(f"Warning: failed to load latest checkpoint due to {exc}; attempting best checkpoint")
                try:
                    corrupt_path = latest_checkpoint_path.with_suffix(latest_checkpoint_path.suffix + ".corrupt")
                    latest_checkpoint_path.rename(corrupt_path)
                    log_message(f"Renamed corrupt latest checkpoint to {corrupt_path.name}")
                except Exception as rename_exc:
                    log_message(f"Warning: could not rename corrupt latest checkpoint: {rename_exc}")
        if resumed_from is None:
            best_checkpoint_path = checkpoint_dir / "best.pt"
            if best_checkpoint_path.exists():
                try:
                    log_message("Resuming from best checkpoint")
                    checkpoint = torch.load(best_checkpoint_path, map_location='cpu')
                    accelerator.unwrap_model(model).load_state_dict(checkpoint['model_state_dict'])
                    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                    start_step = checkpoint.get('step', 0)
                    resumed_from = best_checkpoint_path
                except Exception as exc:
                    log_message(f"Warning: failed to load best checkpoint due to {exc}; starting from scratch")

    # Seed the in-memory loss history from any previously saved run.
    log_dir = checkpoint_dir
    loss_history = defaultdict(list)
    step_history: List[int] = []
    if loaded_step_history:
        step_history.extend(int(s) for s in loaded_step_history)
    if loaded_loss_history:
        for key, values in loaded_loss_history.items():
            loss_history[key].extend(values)
    last_plot_time = time.time()
    plot_path = log_dir / "losses.png"

    def save_training_stats():
        """Persist best losses and loss curves to training_stats.npz (main process only)."""
        if not accelerator.is_main_process:
            return
        keys = sorted(loss_history.keys())
        loss_arrays = [np.array(loss_history[k], dtype=np.float32) for k in keys]
        np.savez(
            training_stats_path,
            best_train_loss=best_train_loss,
            best_val_loss=best_val_loss,
            step_history=np.array(step_history, dtype=np.int64),
            loss_history_keys=np.array(keys, dtype=object),
            loss_history_values=np.array(loss_arrays, dtype=object),
        )

    def update_loss_plot():
        """Render losses.png (one combined axis + one axis per loss key) and save stats."""
        if not accelerator.is_main_process or not step_history:
            return
        x_values = np.array(step_history, dtype=np.int64)
        keys = [k for k, v in sorted(loss_history.items()) if v]
        if not keys:
            return

        def align_series(series: List[float]) -> np.ndarray:
            # Series restored from disk may be shorter/longer than the step
            # axis; left-pad with NaN or truncate from the front to align.
            y_vals = np.array(series, dtype=np.float32)
            if len(y_vals) > len(x_values):
                y_vals = y_vals[-len(x_values):]
            elif len(y_vals) < len(x_values):
                pad = np.full(len(x_values) - len(y_vals), np.nan, dtype=np.float32)
                y_vals = np.concatenate([pad, y_vals])
            return y_vals

        fig_height = 3 * (len(keys) + 1)
        fig, axes = plt.subplots(len(keys) + 1, 1, figsize=(10, fig_height), sharex=True)
        if not isinstance(axes, np.ndarray):
            # A single subplot returns a bare Axes; normalize to an array.
            axes = np.array([axes])

        cmap = plt.get_cmap('tab10', len(keys))

        aggregated_ax = axes[0]
        aggregated_ax.set_title("Training Losses (all)")
        aggregated_ax.set_ylabel("Loss")
        aggregated_ax.grid(True, alpha=0.3)

        for idx, key in enumerate(keys):
            y_aligned = align_series(loss_history[key])
            if np.all(np.isnan(y_aligned)):
                continue
            color = cmap(idx % cmap.N)
            aggregated_ax.plot(x_values, y_aligned, label=key, color=color)
            ax = axes[idx + 1]
            ax.plot(x_values, y_aligned, color=color)
            ax.set_ylabel(key)
            ax.grid(True, alpha=0.3)

        axes[-1].set_xlabel("Step")
        aggregated_ax.legend()
        fig.tight_layout()
        fig.savefig(plot_path)
        plt.close(fig)
        save_training_stats()

    # Emit an initial plot immediately when resuming with existing history.
    if accelerator.is_main_process and step_history:
        update_loss_plot()

    global_step = start_step

    with tqdm(total=max_steps, initial=start_step, disable=not accelerator.is_local_main_process, position=0, leave=True) as pbar:
        while global_step < max_steps:
            for batch in train_dataloader:
                losses = trainer.train_step(batch, global_step)

                if accelerator.is_local_main_process:
                    pbar.update(1)

                    display_losses = losses.copy()
                    display_losses['fps'] = losses['inference_fps']
                    pbar.set_postfix(display_losses)

                    loss_str = f"Step {global_step}: "
                    for k, v in losses.items():
                        if k not in ['inference_fps', 'frames_predicted']:
                            loss_str += f"{k}={v:.4f} "
                    loss_str += f"| {losses['frames_predicted']} frames @ {losses['inference_fps']:.1f} fps (training speed, inference faster)"
                    tqdm.write(loss_str)

                # Record loss history and refresh the plot at most every 10s.
                if accelerator.is_main_process:
                    step_history.append(global_step)
                    for k, v in losses.items():
                        if k in ['inference_fps', 'frames_predicted']:
                            continue
                        loss_history[k].append(v)
                    current_time = time.time()
                    if current_time - last_plot_time >= 10:
                        update_loss_plot()
                        last_plot_time = current_time

                # Checkpoint/validate early at step 5, then at a fixed interval.
                save_condition = (global_step == 5) or (global_step > 0 and global_step % config['training']['save_generation']['save_interval'] == 0)
                if save_condition:
                    # NOTE(review): validation and generation run only on the
                    # main process; other ranks keep iterating. Confirm no
                    # collective ops occur inside evaluate_batch in multi-GPU
                    # runs, or ranks could desynchronize.
                    if accelerator.is_main_process:
                        generation_save_dir = Path(args.output_dir)
                        generation_save_dir.mkdir(parents=True, exist_ok=True)

                        current_train_loss = losses['total']
                        val_loss = compute_validation_loss(eval_max_batches)

                        # CPU copies so the background writer never touches
                        # live GPU tensors.
                        model_state = recursive_to_cpu(accelerator.get_state_dict(model))
                        optimizer_state = recursive_to_cpu(optimizer.state_dict())
                        payload = {
                            'step': global_step,
                            'model_state_dict': model_state,
                            'optimizer_state_dict': optimizer_state,
                            'config': config,
                        }

                        latest_checkpoint_path = checkpoint_dir / "latest.pt"
                        save_checkpoint_async(latest_checkpoint_path, dict(payload))
                        save_training_stats()

                        # Best checkpoint criterion: validation loss when
                        # available, otherwise training loss.
                        is_new_best = False
                        if val_loss is not None:
                            if val_loss < best_val_loss:
                                best_val_loss = val_loss
                                best_train_loss = min(best_train_loss, current_train_loss)
                                is_new_best = True
                        else:
                            if current_train_loss < best_train_loss:
                                best_train_loss = current_train_loss
                                is_new_best = True

                        if is_new_best:
                            best_checkpoint_path = checkpoint_dir / "best.pt"
                            save_checkpoint_async(best_checkpoint_path, dict(payload))
                            save_training_stats()
                            if val_loss is not None:
                                log_message(f"New best checkpoint at step {global_step}: train_loss={current_train_loss:.6f}, val_loss={val_loss:.6f}")
                            else:
                                log_message(f"New best checkpoint at step {global_step}: train_loss={current_train_loss:.6f}")

                            # Generate and save qualitative samples from up to
                            # 5 validation sequences (no teacher forcing).
                            if config['training']['save_generation']['enabled']:
                                with torch.no_grad():
                                    val_batch = next(iter(val_dataloader))
                                    texts = val_batch['text'][:5]
                                    val_objects = val_batch['objects'][:5]
                                    val_world = val_batch['world'][:5]
                                    val_physics = val_batch.get('physics')
                                    if val_physics is not None:
                                        val_physics = val_physics[:5]
                                    else:
                                        # Fall back to zero physics matching the object layout.
                                        val_physics = torch.zeros_like(val_objects[:, 0, :, :3])
                                    val_device = val_objects.device
                                    val_batch_size, val_num_frames = val_objects.shape[:2]
                                    anchor_idx = trainer._select_anchor_frame(val_num_frames)
                                    predictions, generated_indices, _ = trainer._generate_full_sequence(
                                        text=texts,
                                        objects=val_objects,
                                        world=val_world,
                                        physics=val_physics,
                                        teacher_prob=0.0,
                                        anchor_idx=anchor_idx,
                                    )

                                    val_objects_cpu = val_objects.detach().cpu()
                                    val_world_cpu = val_world.detach().cpu()
                                    val_physics_cpu = val_physics.detach().cpu()
                                    val_batch_cpu = recursive_to_cpu(val_batch)
                                    predictions_cpu = [{
                                        'objects': pred['objects'].detach().cpu(),
                                        'world': pred['world'].detach().cpu(),
                                        'physics': pred['physics'].detach().cpu(),
                                    } for pred in predictions]
                                    targets_cpu = {
                                        'objects': val_objects_cpu,
                                        'world': val_world_cpu,
                                        'physics': val_physics_cpu,
                                    }
                                    metadata = {
                                        'sequence_names': val_batch.get('sequence_names', None)[:5] if 'sequence_names' in val_batch else None,
                                        'generated_indices': generated_indices,
                                    }
                                    save_generation_async(
                                        predictions=predictions_cpu,
                                        targets=targets_cpu,
                                        texts=list(texts),
                                        step=global_step,
                                        save_config=config['training']['save_generation'],
                                        metadata=metadata,
                                        batch_data=val_batch_cpu,
                                        data_root=args.data_root,
                                        data_split='validation'
                                    )
                        else:
                            msg = f"No improvement at step {global_step}: train_loss={current_train_loss:.6f}"
                            if val_loss is not None:
                                msg += f", val_loss={val_loss:.6f}"
                            log_message(msg)

                # Clip gradients only when they are synced (always true with
                # gradient_accumulation_steps=1), then apply the update.
                if accelerator.sync_gradients:
                    clip_val = config['training'].get('gradient_clip_val', 1.0)
                    accelerator.clip_grad_norm_(model.parameters(), max_norm=clip_val)

                optimizer.step()
                optimizer.zero_grad()

                global_step += 1

                if global_step >= max_steps:
                    break

    # Final plot refresh after training ends.
    if accelerator.is_main_process:
        update_loss_plot()

    # Block until all queued checkpoint/generation writes complete.
    executor.shutdown(wait=True)

    if accelerator.is_main_process:
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        final_checkpoint_path = checkpoint_dir / f"step{global_step}_final.pt"

        torch.save({
            'step': global_step,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'config': config,
        }, final_checkpoint_path)

        # Point best.pt at the final checkpoint (relative symlink within
        # checkpoint_dir); is_symlink() also catches a dangling link.
        best_path = checkpoint_dir / "best.pt"
        if best_path.exists() or best_path.is_symlink():
            best_path.unlink()
        best_path.symlink_to(final_checkpoint_path.name)

        log_message(f"Saved final checkpoint: {final_checkpoint_path}")
| | |
| | |
| |
|
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()
| |
|