| import torch |
| import functools |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Any, Dict, Optional, Tuple |
|
|
| from src.models.audio_jepa_module import AudioJEPAModule |
| from src.models.components.random_projection_quantizer import RandomProjectionQuantizer |
|
|
|
|
class RQAJEPAModule(AudioJEPAModule):
    """
    RQA-JEPA Lightning Module.

    Extends AudioJEPAModule with a Random Projection Quantization (RQ)
    auxiliary objective: alongside the JEPA regression loss against teacher
    embeddings, raw predictor outputs are projected to a discrete vocabulary
    and trained with a classification loss against codes produced by a frozen
    random-projection quantizer.

    Args:
        optimizer (torch.optim.Optimizer): Optimizer configuration.
        net (Dict[str, Any]): Configuration for sub-modules.
        warmup_pct (float): Percentage of total steps for warmup.
        final_lr_ratio (float): Ratio of final learning rate to initial learning rate.
        ema_decay (float): Initial EMA decay rate.
        ema_end_decay (float): Final EMA decay rate.
        ema_anneal_end_step (int): Step at which EMA decay reaches ema_end_decay.
        spectrogram_adjustment_mode (str): 'pad' or 'truncate' for spectrogram time dimension.
        jepa_criterion (torch.nn.Module): Loss function for JEPA (defaults to MSELoss).
        rq_criterion: Loss for RQ. May be an nn.Module instance, a class or
            functools.partial (instantiated with no arguments), or a plain
            callable loss function (used as-is). Defaults to CrossEntropyLoss.
        rq_lambda (float): Weight for JEPA loss (1 - rq_lambda is used for RQ loss).
        codebook_dim (int): Codebook dimension for RandomProjectionQuantizer.
        vocab_size (int): Vocabulary size for RandomProjectionQuantizer.
        rq_input_type (str): 'teacher' or 'spectrogram'. Source for quantization targets.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        net: Dict[str, Any],
        warmup_pct: float = 0.1,
        final_lr_ratio: float = 0.001,
        ema_decay: float = 0.996,
        ema_end_decay: float = 1.0,
        ema_anneal_end_step: Optional[int] = None,
        spectrogram_adjustment_mode: str = "pad",
        jepa_criterion: Optional[torch.nn.Module] = None,
        rq_criterion: Optional[torch.nn.Module] = None,
        rq_lambda: float = 0.5,
        codebook_dim: int = 16,
        vocab_size: int = 8192,
        rq_input_type: str = "teacher",
    ):
        super().__init__(
            optimizer=optimizer,
            net=net,
            warmup_pct=warmup_pct,
            final_lr_ratio=final_lr_ratio,
            ema_decay=ema_decay,
            ema_end_decay=ema_end_decay,
            ema_anneal_end_step=ema_anneal_end_step,
            spectrogram_adjustment_mode=spectrogram_adjustment_mode,
            criterion=jepa_criterion,
        )
        self.save_hyperparameters(
            logger=False, ignore=["jepa_criterion", "rq_criterion", "net", "optimizer"]
        )

        self.rq_lambda = rq_lambda

        # Resolve the RQ criterion. A class or functools.partial is a factory
        # and gets instantiated; anything else (an nn.Module instance or a
        # plain callable loss function such as F.cross_entropy) is used as-is.
        # BUGFIX: the previous condition also *called* plain callables with
        # zero arguments, which raised TypeError for function-style losses.
        if rq_criterion is None:
            self.rq_criterion = nn.CrossEntropyLoss()
        elif isinstance(rq_criterion, (type, functools.partial)):
            self.rq_criterion = rq_criterion()
        else:
            self.rq_criterion = rq_criterion

        self.rq_input_type = rq_input_type
        if self.rq_input_type not in ("teacher", "spectrogram"):
            raise ValueError(
                f"rq_input_type must be 'teacher' or 'spectrogram', got {self.rq_input_type}"
            )

        # The quantizer's input dimensionality depends on what it consumes:
        # teacher embeddings, or raw (flattened) spectrogram patches.
        if self.rq_input_type == "teacher":
            quantizer_input_dim = net.get("encoder", {}).get("embed_dim", 768)
        else:
            # patch_embed is provided by the AudioJEPAModule base class.
            patch_size = self.patch_embed.patch_size
            in_chans = self.patch_embed.in_chans
            quantizer_input_dim = patch_size[0] * patch_size[1] * in_chans

        self.quantizer = RandomProjectionQuantizer(
            input_dim=quantizer_input_dim, cb_dim=codebook_dim, cb_vocab=vocab_size
        )
        # The quantizer is a fixed random projection and must never be trained.
        for p in self.quantizer.parameters():
            p.requires_grad = False

        # Maps raw predictor outputs to RQ vocabulary logits.
        predictor_config = net.get("predictor", {})
        predictor_embed_dim = predictor_config.get("embed_dim", 768)
        self.rq_proj = nn.Linear(predictor_embed_dim, vocab_size)

    def _calculate_combined_loss(
        self,
        predictions_raw: torch.Tensor,
        teacher_targets: torch.Tensor,
        rq_logits: torch.Tensor,
        rq_targets: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Compute the JEPA and RQ losses and their weighted combination.

        Args:
            predictions_raw: Predictor outputs before the output projection.
            teacher_targets: Teacher embeddings at the masked positions.
            rq_logits: Vocabulary logits for the masked positions.
            rq_targets: Integer code targets for the masked positions.

        Returns:
            Tuple of (total_loss, jepa_loss, rq_loss).
        """
        # JEPA regression loss on the projected predictions.
        predictions_jepa = self.predictor_output_proj(predictions_raw)
        jepa_loss = self.criterion(predictions_jepa, teacher_targets)

        # RQ classification loss over the flattened (batch * positions) axis.
        rq_loss = self.rq_criterion(
            rq_logits.reshape(-1, self.hparams.vocab_size), rq_targets.reshape(-1)
        )

        # rq_lambda weights the JEPA term; (1 - rq_lambda) weights the RQ term.
        total_loss = self.rq_lambda * jepa_loss + (1 - self.rq_lambda) * rq_loss

        return total_loss, jepa_loss, rq_loss

    def _get_raw_patches(self, spec: torch.Tensor) -> torch.Tensor:
        """
        Extract raw, non-overlapping patches from a spectrogram.

        Args:
            spec (torch.Tensor): Adjusted spectrogram [B, C, F, T].

        Returns:
            torch.Tensor: Flattened patches [B, N, patch_dim].
        """
        patch_size = self.patch_embed.patch_size

        # stride == kernel_size -> non-overlapping patches, matching the
        # patchification performed by patch_embed.
        patches = F.unfold(spec, kernel_size=patch_size, stride=patch_size)
        # [B, patch_dim, N] -> [B, N, patch_dim]
        patches = patches.transpose(1, 2)

        return patches

    def _get_rq_targets_input(
        self, spec: torch.Tensor, teacher_full: torch.Tensor, mask_indices: torch.Tensor
    ) -> torch.Tensor:
        """
        Get the quantizer input (teacher embeddings or raw patches),
        restricted to the MASKED locations only.

        Args:
            spec: Adjusted spectrogram, or None when rq_input_type is 'teacher'.
            teacher_full: Full teacher embeddings [B, N, D].
            mask_indices: 1-D tensor of masked token indices.

        Returns:
            torch.Tensor: Quantizer input at the masked positions.

        Raises:
            ValueError: If spec is None while rq_input_type is 'spectrogram'.
        """
        if self.rq_input_type == "teacher":
            return teacher_full[:, mask_indices, :]
        else:
            if spec is None:
                raise ValueError(
                    "Spectrogram cannot be None when rq_input_type is 'spectrogram'"
                )
            raw_patches = self._get_raw_patches(spec)
            return raw_patches[:, mask_indices, :]

    def _compute_losses(
        self, batch: Dict[str, Any], update_teacher: bool
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
        """
        Shared forward pass for training and validation.

        Args:
            batch: Batch dict containing a "waveform" tensor.
            update_teacher: Whether to apply the EMA teacher update
                (True only during training).

        Returns:
            Tuple of (total_loss, jepa_loss, rq_loss, batch_size).
        """
        waveform = batch["waveform"]

        patches, current_grid_size = self._process_audio(waveform)
        B = patches.shape[0]

        # A single mask is generated and shared across the batch.
        mask = self.mask_generator(1, device=self.device, grid_size=current_grid_size)
        mask = mask.expand(B, -1)

        student_out = self.compute_student(patches, mask, current_grid_size)
        predictions_raw = self.compute_predictor(student_out, mask, current_grid_size)

        if update_teacher:
            # EMA update happens before the teacher forward pass.
            self._update_teacher()

        with torch.no_grad():
            teacher_full = self.teacher(patches, grid_size=current_grid_size)

        # Indices of the masked tokens (same for all batch items by construction).
        mask_indices = torch.nonzero(mask[0]).flatten()

        # JEPA targets: teacher embeddings at the masked positions.
        teacher_targets = teacher_full[:, mask_indices, :]

        # RQ targets are fixed labels — no gradients flow through them.
        with torch.no_grad():
            spec = None
            if self.rq_input_type == "spectrogram":
                spec = self.spectrogram(waveform)
                spec = self._adjust_spectrogram(spec)

            rq_targets_input = self._get_rq_targets_input(
                spec, teacher_full, mask_indices
            )
            rq_targets = self.quantizer(rq_targets_input)

        # Logits from the raw (unprojected) predictor outputs.
        rq_logits = self.rq_proj(predictions_raw)

        loss, jepa_loss, rq_loss = self._calculate_combined_loss(
            predictions_raw, teacher_targets, rq_logits, rq_targets
        )
        return loss, jepa_loss, rq_loss, B

    def training_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor:
        """Run one training step: forward pass, EMA update, loss logging."""
        loss, jepa_loss, rq_loss, B = self._compute_losses(batch, update_teacher=True)

        self.log(
            "train/loss", loss, on_step=True, on_epoch=True, prog_bar=True, batch_size=B
        )
        self.log(
            "train/jepa_loss", jepa_loss, on_step=True, on_epoch=True, batch_size=B
        )
        self.log("train/rq_loss", rq_loss, on_step=True, on_epoch=True, batch_size=B)
        return loss

    def validation_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor:
        """Run one validation step: forward pass (no EMA update) and logging."""
        loss, jepa_loss, rq_loss, B = self._compute_losses(batch, update_teacher=False)

        self.log(
            "val/loss", loss, on_step=False, on_epoch=True, prog_bar=True, batch_size=B
        )
        self.log("val/jepa_loss", jepa_loss, on_step=False, on_epoch=True, batch_size=B)
        self.log("val/rq_loss", rq_loss, on_step=False, on_epoch=True, batch_size=B)
        return loss
|
|