import torch
from torch import nn
from tqdm import tqdm

from wm.model.interface import DIT_CLASS_MAP, VAE_CLASS_MAP


class DiffusionForcing_WM(nn.Module):
    def __init__(self, model_name, model_config):
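        """
        Args:
            model_name: key into DIT_CLASS_MAP selecting the DiT backbone class.
            model_config: dict of DiT hyperparameters plus 'vae_name', optional
                'vae_config', a 'scheduler' (name or pre-built instance), and
                'training_timesteps'.
        """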
        super().__init__()
        self.model_name = model_name
        self.model_config = model_config

        # Keep only the keys the DiT backbone constructor accepts.
        dit_keys = [
            'in_channels', 'patch_size', 'dim', 'num_layers', 'num_heads',
            'action_dim', 'action_compress_rate', 'max_frames',
            'rope_config', 'action_dropout_prob', 'temporal_causal'
        ]
        dit_config = {k: v for k, v in model_config.items() if k in dit_keys}

        # Diffusion forcing needs a temporally causal backbone, so override
        # whatever the config says.
        dit_config['temporal_causal'] = True

        self.model = DIT_CLASS_MAP[model_name](**dit_config)
        self.vae = VAE_CLASS_MAP[model_config['vae_name']](*model_config.get('vae_config', []))

        # The scheduler is given either by name or as a pre-built instance.
        scheduler_config = model_config.get('scheduler')
        if isinstance(scheduler_config, str):
            if scheduler_config == "FlowMatch":
                from wm.model.diffusion.flow_matching import FlowMatchScheduler
                self.scheduler = FlowMatchScheduler()
            else:
                raise ValueError(f"Unknown scheduler type: {scheduler_config}")
        else:
            self.scheduler = scheduler_config

        # Start on the training timestep grid; generate() temporarily switches
        # to an inference grid and restores this afterwards.
        self.scheduler.set_timesteps(model_config['training_timesteps'], training=True)

    def encode_obs(self, o):
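        """Encode pixel observations into VAE latents.

        Args:
            o: video in [0, 1], shaped (B, T, H, W, C) or (B, T, C, H, W).

        Returns:
            Channels-last latents of shape (B, T', H', W', D).
        """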
        with torch.no_grad():
            # Map pixels from [0, 1] to the [-1, 1] range the VAE expects.
            o = o * 2.0 - 1.0

            # Accept either channels-last or channels-first video.
            if o.shape[-1] == 3:
                # (B, T, H, W, C) -> (B, T, C, H, W)
                o = o.permute(0, 1, 4, 2, 3).contiguous()
            elif o.shape[2] == 3:
                # Already (B, T, C, H, W); nothing to do.
                pass

            latent = self.vae.encode(o)
            # (B, T', D, H', W') -> channels-last (B, T', H', W', D) for the DiT.
            latent = latent.permute(0, 1, 3, 4, 2).contiguous()
        return latent

    def training_loss(self, z, a):
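        """Diffusion-forcing loss: every frame gets an independently sampled
        noise level, so the model learns to denoise under per-frame timesteps.

        Args:
            z: latents of shape (B, T, H', W', D).
            a: action sequence conditioning the backbone.
        """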
        B, T = z.shape[0], z.shape[1]

        # Sample an independent timestep for every frame (the core of
        # diffusion forcing), rather than one timestep per sequence.
        t_indices = torch.randint(0, self.scheduler.timesteps.shape[0], (B, T), device=z.device)
        t_values = self.scheduler.timesteps[t_indices]

        # Noise each frame to its own level.
        z_t, eps = self.scheduler.add_independent_noise(z, t_values)

        v_pred = self.model(z_t, t_values, a)
        v_target = self.scheduler.training_target(z, eps, t_values)

        # Per-timestep loss weights, broadcast over the spatial and channel dims.
        weights = self.scheduler.training_weight(t_values)
        loss = (weights.view(B, T, 1, 1, 1) * (v_pred - v_target) ** 2).mean()
        return loss

    def full_train_loss(self, o_t, a):
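        """End-to-end training loss from pixels: encode with the VAE (under
        no_grad), then apply the diffusion-forcing loss in latent space.

        Args:
            o_t: video in [0, 1].
            a: action sequence aligned with the frames.
        """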
        # Zero out the final action; with actions aligned to frame transitions,
        # the last one only affects frames past the end of the clip.
        a = a.clone()
        a[:, -1, :] = 0

        z = self.encode_obs(o_t)

        loss = self.training_loss(z, a)
        return loss

    def generate(self, o_0, a, num_inference_steps=50, noise_level=0.0, mode="autoregressive"):
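        """Roll out a video from one conditioning frame and an action sequence.

        Args:
            o_0: first frame, (B, H, W, C) in [0, 1].
            a: actions of shape (B, T_pixel, action_dim).
            num_inference_steps: number of denoising steps.
            noise_level: if > 0, the conditioning frame is held at this noise
                level during denoising instead of being kept exactly clean.
            mode: "parallel" denoises all latent frames jointly;
                "autoregressive" generates one latent frame at a time.

        Returns:
            Decoded video in [0, 1], channels-last.
        """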
        B = o_0.shape[0]
        T_pixel = a.shape[1]
        device = o_0.device

        # Encode the single conditioning frame (add a time dimension first).
        z_0 = self.encode_obs(o_0.unsqueeze(1))

        # Number of latent frames under the VAE's 4x temporal compression:
        # the first frame maps to itself, then every 4 pixel frames map to 1.
        T_latent = (T_pixel - 1) // 4 + 1
        H_prime, W_prime = z_0.shape[2], z_0.shape[3]
        D = z_0.shape[4]

        # Switch the scheduler to an inference grid; restored below.
        old_training = self.scheduler.training
        self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, training=False)

        if mode == "parallel":
            # Denoise all latent frames jointly from pure noise.
            z = torch.randn(B, T_latent, H_prime, W_prime, D, device=device)

            # Pin the conditioning frame, optionally lightly noised.
            if noise_level > 0:
                t_val_0 = torch.full((B, 1), noise_level, device=device)
                z_0_noisy, _ = self.scheduler.add_independent_noise(z_0, t_val_0)
                z[:, 0] = z_0_noisy.squeeze(1)
            else:
                z[:, 0] = z_0.squeeze(1)

            for i in tqdm(range(len(self.scheduler.timesteps)), desc="Denoising (Parallel)"):
                t_val = self.scheduler.timesteps[i]
                t = torch.full((B, T_latent), t_val, device=device)

                # The conditioning frame's timestep never exceeds noise_level.
                if noise_level > 0:
                    t[:, 0] = torch.where(t_val > noise_level, torch.tensor(noise_level, device=device), t_val)
                else:
                    t[:, 0] = 0

                with torch.no_grad():
                    v_pred = self.model(z, t, a)
                z = self.scheduler.step(v_pred, t, z)

            # Put back the exactly clean conditioning frame.
            if noise_level == 0:
                z[:, 0] = z_0.squeeze(1)

        elif mode == "autoregressive":
            # Generate one latent frame at a time, conditioned on all
            # previously generated clean frames.
            z_all = z_0.clone()

            for t_idx in range(1, T_latent):
                # Append a fresh noise frame to the clean context.
                z_next = torch.randn(B, 1, H_prime, W_prime, D, device=device)
                z_curr = torch.cat([z_all, z_next], dim=1)

                for i in range(len(self.scheduler.timesteps)):
                    t_val = self.scheduler.timesteps[i]
                    # Only the newest frame carries noise; context stays at t = 0.
                    t_seq = torch.zeros(B, t_idx + 1, device=device)
                    t_seq[:, -1] = t_val

                    # Actions available up to and including the current latent frame.
                    L_curr = self.model.action_compress_rate * t_idx + 1
                    a_curr = a[:, :L_curr]

                    with torch.no_grad():
                        v_pred = self.model(z_curr, t_seq, a_curr)
                    z_curr = self.scheduler.step(v_pred, t_seq, z_curr)

                    # Re-anchor the clean context after every step so the
                    # scheduler cannot perturb the conditioning frames.
                    z_curr[:, :-1] = z_all

                # Keep only the newly denoised final frame.
                z_all = torch.cat([z_all, z_curr[:, -1:]], dim=1)

            z = z_all
        else:
            raise ValueError(f"Unknown generation mode: {mode}")

        # Restore the training timestep grid if we started in training mode.
        if old_training:
            self.scheduler.set_timesteps(self.model_config['training_timesteps'], training=True)

        # Decode latents back to pixels and map from [-1, 1] to [0, 1].
        with torch.no_grad():
            z_for_vae = z.permute(0, 1, 4, 2, 3).contiguous()
            video_recon = self.vae.decode_to_pixel(z_for_vae)
            video_recon = (video_recon + 1.0) / 2.0
            video_recon = video_recon.permute(0, 1, 3, 4, 2).contiguous().clamp(0, 1)

        return video_recon


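# Minimal usage sketch, illustrative only: builds the world model and runs one
# training-loss computation on random data. Every config value below is an
# assumption for demonstration; the real DIT_CLASS_MAP / VAE_CLASS_MAP keys and
# a working configuration live in the repo's config files, not here.
if __name__ == "__main__":
    _config = {
        # DiT hyperparameters (illustrative values).
        'in_channels': 16, 'patch_size': 2, 'dim': 256, 'num_layers': 4,
        'num_heads': 4, 'action_dim': 7, 'action_compress_rate': 4,
        'max_frames': 16, 'rope_config': None, 'action_dropout_prob': 0.1,
        'vae_name': 'demo_vae',  # hypothetical VAE_CLASS_MAP key
        'scheduler': 'FlowMatch',
        'training_timesteps': 1000,
    }
    wm = DiffusionForcing_WM('demo_dit', _config)  # hypothetical DIT_CLASS_MAP key
    o = torch.rand(1, 9, 64, 64, 3)  # (B, T, H, W, C) video in [0, 1]
    a = torch.rand(1, 9, 7)          # (B, T, action_dim)
    print("loss:", wm.full_train_loss(o, a).item())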