import logging import math from typing import Callable from pathlib import Path import numpy as np import torch import torch.nn as nn logger = logging.Logger(__file__) def remove_key_prefix_factory(prefix: str = "module."): def func( model_dict: dict[str, torch.Tensor], state_dict: dict[str, torch.Tensor] ) -> dict[str, torch.Tensor]: state_dict = { key[len(prefix):]: value for key, value in state_dict.items() if key.startswith(prefix) } return state_dict return func def merge_matched_keys( model_dict: dict[str, torch.Tensor], state_dict: dict[str, torch.Tensor] ) -> dict[str, torch.Tensor]: """ Args: model_dict: The state dict of the current model, which is going to load pretrained parameters state_dict: A dictionary of parameters from a pre-trained model. Returns: dict[str, torch.Tensor]: The updated state dict, where parameters with matched keys and shape are updated with values in `state_dict`. """ pretrained_dict = {} mismatch_keys = [] for key, value in state_dict.items(): if key in model_dict and model_dict[key].shape == value.shape: pretrained_dict[key] = value else: mismatch_keys.append(key) logger.info( f"Loading pre-trained model, with mismatched keys {mismatch_keys}" ) model_dict.update(pretrained_dict) return model_dict def load_pretrained_model( model: nn.Module, ckpt_or_state_dict: str | Path | dict[str, torch.Tensor], state_dict_process_fn: Callable = merge_matched_keys ) -> None: state_dict = ckpt_or_state_dict if not isinstance(state_dict, dict): state_dict = torch.load(ckpt_or_state_dict, "cpu") model_dict = model.state_dict() state_dict = state_dict_process_fn(model_dict, state_dict) model.load_state_dict(state_dict) def create_mask_from_length( lengths: torch.Tensor, max_length: int | None = None ): if max_length is None: max_length = max(lengths) idxs = torch.arange(max_length).reshape(1, -1) # (1, max_length) mask = idxs.to(lengths.device) < lengths.view(-1, 1) # (1, max_length) < (batch_size, 1) -> (batch_size, max_length) return mask def loss_with_mask( loss: torch.Tensor, mask: torch.Tensor, reduce: bool = True ) -> torch.Tensor: """ Apply a mask to the loss tensor and optionally reduce it. Args: loss: Tensor of shape (b, t, ...) representing the loss values. mask: Tensor of shape (b, t) where 1 indicates valid positions and 0 indicates masked positions. reduce: If True, return a single scalar value; otherwise, return a tensor of shape (b,). Returns: torch.Tensor: A scalar if reduce is True, otherwise a tensor of shape (b,). """ expanded_mask = mask[(..., ) + (None, ) * (loss.ndim - mask.ndim)] expanded_mask = expanded_mask.expand_as(loss) masked_loss = loss * expanded_mask sum_dims = tuple(range(1, loss.ndim)) loss_sum = masked_loss.sum(dim=sum_dims) mask_sum = expanded_mask.sum(dim=sum_dims) loss = loss_sum / mask_sum if reduce: return loss.mean() else: return loss def convert_pad_shape(pad_shape: list[list[int]]): l = pad_shape[::-1] pad_shape = [item for sublist in l for item in sublist] return pad_shape def create_alignment_path(duration: torch.Tensor, mask: torch.Tensor): device = duration.device b, t_x, t_y = mask.shape cum_duration = torch.cumsum(duration, 1) cum_duration_flat = cum_duration.view(b * t_x) path = create_mask_from_length(cum_duration_flat, t_y).float() path = path.view(b, t_x, t_y) # take the diff on the `t_x` axis path = path - torch.nn.functional.pad( path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]) )[:, :-1] path = path * mask return path def trim_or_pad_length(x: torch.Tensor, target_length: int, length_dim: int): """ Adjusts the size of the specified dimension of tensor x to match `target_length`. Args: x: Input tensor. target_length: Desired size of the specified dimension. length_dim: The dimension to modify. Returns: torch.Tensor: The adjusted tensor. """ current_length = x.shape[length_dim] if current_length > target_length: # Truncate the tensor slices = [slice(None)] * x.ndim slices[length_dim] = slice(0, target_length) return x[tuple(slices)] elif current_length < target_length: # Pad the tensor with zeros pad_shape = list(x.shape) pad_length = target_length - current_length pad_shape[length_dim] = pad_length # Shape for left padding padding = torch.zeros(pad_shape, dtype=x.dtype, device=x.device) return torch.cat([x, padding], dim=length_dim) return x def concat_non_padding( seq1: torch.Tensor, mask1: torch.BoolTensor, seq2: torch.Tensor, mask2: torch.BoolTensor ): """ Args seq1 : Tensor (B, L1, E) First sequence. mask1 : BoolTensor (B, L1) True for valid tokens in seq1, False for padding. seq2 : Tensor (B, L2, E) Second sequence. mask2 : BoolTensor (B, L2) True for valid tokens in seq2, False for padding. Returns concat_seq : Tensor (B, L1+L2, E) Both sequences concatenated; valid tokens are left-aligned, padding on the right is 0. concat_mask: BoolTensor (B, L1+L2) Mask for the concatenated sequence. perm : LongTensor (B, L1+L2) Permutation that maps **original indices → new indices**. Needed for restoring the original sequences. """ mask1, mask2 = mask1.bool(), mask2.bool() B, L1, E = seq1.shape L2 = seq2.size(1) L = L1 + L2 seq_cat = torch.cat([seq1, seq2], dim=1) # (B, L, E) mask_cat = torch.cat([mask1, mask2], dim=1) # (B, L) # ----- Key step: stable sort so that all valid tokens move to the left ----- # Padding positions get +L, guaranteeing the largest “score” → sorted to the end. positions = torch.arange(L, device=seq_cat.device).unsqueeze(0) # (1, L) sort_score = positions + (~mask_cat) * L perm = sort_score.argsort(dim=1, stable=True) # (B, L) # Build concatenated sequence & mask gather_idx = perm.unsqueeze(-1).expand(-1, -1, E) # (B, L, E) concat_seq = seq_cat.gather(1, gather_idx) concat_mask = mask_cat.gather(1, perm) # Explicitly zero out the right-hand padding region for safety concat_seq = concat_seq * concat_mask.unsqueeze(-1) return concat_seq, concat_mask, perm def restore_from_concat( concat_seq: torch.Tensor, mask1: torch.BoolTensor, mask2: torch.BoolTensor, perm: torch.LongTensor ): """ Restore (seq1, seq2) from the concatenated sequence produced by `concat_non_padding`, using the returned permutation `perm`. Fully vectorised — no Python loops. """ mask1, mask2 = mask1.bool(), mask2.bool() B, L1 = mask1.shape L2 = mask2.size(1) E = concat_seq.size(-1) # Inverse permutation: maps **new_idx → old_idx** inv_perm = torch.empty_like(perm) inv_perm.scatter_( 1, perm, torch.arange(L1 + L2, device=perm.device).unsqueeze(0).expand(B, -1) ) # Bring tokens back to their original order gather_idx = inv_perm.unsqueeze(-1).expand(-1, -1, E) seq_cat_rec = concat_seq.gather(1, gather_idx) # (B, L1+L2, E) # Split back into the two sequences and mask out padding positions seq1_restore, seq2_restore = seq_cat_rec.split([L1, L2], dim=1) seq1_restore = seq1_restore * mask1.unsqueeze(-1) seq2_restore = seq2_restore * mask2.unsqueeze(-1) return seq1_restore, seq2_restore def contains_nan(data): """check if data contains NaN""" if isinstance(data, torch.Tensor): return torch.isnan(data).any().item() elif isinstance(data, np.ndarray): return np.isnan(data).any() elif isinstance(data, float): return math.isnan(data) elif isinstance(data, (list, tuple)): return any(contains_nan(x) for x in data) elif isinstance(data, dict): return any(contains_nan(v) for v in data.values()) return False def check_nan_in_batch(batch): """check if batch contains NaN and return nan audio ids""" assert type(batch)==dict,"batch type error" nan_audio_ids=[] audio_ids=batch["audio_id"] audio_id2content={} for idx,audio_id in enumerate(audio_ids): content=[] for k,v in batch.items(): if k=="audio_id": continue content.append(v[idx]) audio_id2content[audio_id]=content for audio_id,content in audio_id2content.items(): if contains_nan(content): nan_audio_ids.append(audio_id) print(f"{audio_id} contains NaN") return nan_audio_ids