# STAR / utils / torch_utilities.py
# Yixuan Li
# first commit
# 4853fdc
import logging
import math
from typing import Callable
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
# Module-level logger. It is obtained through `logging.getLogger` (rather than
# instantiating `logging.Logger` directly) so that it is registered in the
# logging hierarchy and inherits the handlers and levels configured by the
# application on the root logger.
logger = logging.getLogger(__name__)
def remove_key_prefix_factory(prefix: str = "module."):
    """
    Build a state-dict processing function that strips `prefix` from keys.

    Args:
        prefix:
            Leading string to remove from the parameter names.

    Returns:
        Callable:
            A function `(model_dict, state_dict) -> dict` returning a new
            dict that contains only the entries of `state_dict` whose key
            starts with `prefix`, with the prefix removed from the key.
            `model_dict` is accepted to match the processing-function
            signature and is not used.
    """
    prefix_length = len(prefix)

    def func(
        model_dict: dict[str, torch.Tensor],
        state_dict: dict[str, torch.Tensor],
    ) -> dict[str, torch.Tensor]:
        stripped = {}
        for name, tensor in state_dict.items():
            # Entries that do not carry the prefix are dropped, not copied.
            if not name.startswith(prefix):
                continue
            stripped[name[prefix_length:]] = tensor
        return stripped

    return func
def merge_matched_keys(
    model_dict: dict[str, torch.Tensor], state_dict: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
    """
    Overwrite entries of `model_dict` with the compatible entries of
    `state_dict`.

    An entry is compatible when its key exists in `model_dict` and both
    tensors have the same shape. All other keys of `state_dict` are reported
    through the module logger and ignored.

    Args:
        model_dict:
            The state dict of the current model. It is updated in place.
        state_dict:
            A dictionary of parameters from a pre-trained model.

    Returns:
        dict[str, torch.Tensor]:
            The same `model_dict` object, with the compatible entries
            replaced by the values from `state_dict`.
    """
    pretrained_dict = {
        name: tensor
        for name, tensor in state_dict.items()
        if name in model_dict and model_dict[name].shape == tensor.shape
    }
    # Everything that was not accepted above, in the order of `state_dict`.
    mismatch_keys = [name for name in state_dict if name not in pretrained_dict]
    logger.info(
        f"Loading pre-trained model, with mismatched keys {mismatch_keys}"
    )
    model_dict.update(pretrained_dict)
    return model_dict
def load_pretrained_model(
    model: nn.Module,
    ckpt_or_state_dict: str | Path | dict[str, torch.Tensor],
    state_dict_process_fn: Callable = merge_matched_keys
) -> None:
    """
    Load pre-trained parameters into `model`.

    Args:
        model:
            The module that receives the parameters.
        ckpt_or_state_dict:
            Either a state dict, or anything else (e.g. a checkpoint path),
            which is then read with `torch.load` onto the CPU.
        state_dict_process_fn:
            Function called as `fn(model.state_dict(), state_dict)`; its
            return value is what is passed to `model.load_state_dict`.
    """
    if isinstance(ckpt_or_state_dict, dict):
        state_dict = ckpt_or_state_dict
    else:
        # NOTE: `torch.load` unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(ckpt_or_state_dict, "cpu")

    current_params = model.state_dict()
    processed = state_dict_process_fn(current_params, state_dict)
    model.load_state_dict(processed)
def create_mask_from_length(
lengths: torch.Tensor, max_length: int | None = None
):
if max_length is None:
max_length = max(lengths)
idxs = torch.arange(max_length).reshape(1, -1) # (1, max_length)
mask = idxs.to(lengths.device) < lengths.view(-1, 1)
# (1, max_length) < (batch_size, 1) -> (batch_size, max_length)
return mask
def loss_with_mask(
    loss: torch.Tensor,
    mask: torch.Tensor,
    reduce: bool = True
) -> torch.Tensor:
    """
    Apply a mask to the loss tensor and optionally reduce it.

    The loss is averaged over the valid positions of each sample first; with
    `reduce=True` the per-sample averages are then averaged over the batch.

    Args:
        loss: Tensor of shape (b, t, ...) representing the loss values.
        mask: Tensor of shape (b, t) where 1 indicates valid positions and 0
            indicates masked positions. Trailing singleton dimensions are
            appended so that it broadcasts against `loss`.
        reduce: If True, return a single scalar value; otherwise, return a
            tensor of shape (b,).

    Returns:
        torch.Tensor: A scalar if reduce is True, otherwise a tensor of shape
        (b,). A sample whose mask is entirely zero contributes 0 instead of
        NaN.
    """
    expanded_mask = mask[(..., ) + (None, ) * (loss.ndim - mask.ndim)]
    expanded_mask = expanded_mask.expand_as(loss)

    # Reduce over every dimension except the batch dimension.
    sum_dims = tuple(range(1, loss.ndim))
    loss_sum = (loss * expanded_mask).sum(dim=sum_dims)
    mask_sum = expanded_mask.sum(dim=sum_dims)

    # A fully masked sample would give 0 / 0 = NaN and poison the batch mean.
    # Only exact zeros are replaced, so every other denominator is unchanged.
    safe_mask_sum = torch.where(
        mask_sum == 0, torch.ones_like(mask_sum), mask_sum
    )
    per_sample_loss = loss_sum / safe_mask_sum

    if reduce:
        return per_sample_loss.mean()
    return per_sample_loss
def convert_pad_shape(pad_shape: list[list[int]]):
    """
    Flatten a list of padding pairs into one flat list, last pair first.

    Args:
        pad_shape:
            Nested list of padding amounts, one inner list per entry.

    Returns:
        list[int]:
            The inner lists concatenated in reverse order of the outer list;
            the order inside each inner list is kept.
    """
    flattened = []
    for pair in reversed(pad_shape):
        flattened.extend(pair)
    return flattened
def create_alignment_path(
    duration: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
    """
    Expand per-step durations into an alignment path.

    Args:
        duration:
            Durations, accumulated along dim 1. Must hold b * t_x elements.
        mask:
            Tensor of shape (b, t_x, t_y); the result is multiplied by it.

    Returns:
        torch.Tensor:
            Float tensor of shape (b, t_x, t_y). Before masking, row `i` is
            1 at the columns `j` with cum[i - 1] <= j < cum[i], where `cum`
            is the cumulative duration (cum[-1] taken as 0).
    """
    b, t_x, t_y = mask.shape
    cum_duration = torch.cumsum(duration, 1)
    cum_duration_flat = cum_duration.view(b * t_x)
    # Row i is 1 for every column below the cumulative duration of step i.
    path = create_mask_from_length(cum_duration_flat, t_y).float()
    path = path.view(b, t_x, t_y)
    # take the diff on the `t_x` axis: subtract the path shifted down by one
    # step (zero row prepended, last row dropped)
    path = path - torch.nn.functional.pad(
        path, convert_pad_shape([[0, 0], [1, 0], [0, 0]])
    )[:, :-1]
    path = path * mask
    return path
def trim_or_pad_length(x: torch.Tensor, target_length: int, length_dim: int):
    """
    Make dimension `length_dim` of `x` exactly `target_length` long.

    Args:
        x:
            Input tensor.
        target_length:
            Desired size of the specified dimension.
        length_dim:
            The dimension to modify.

    Returns:
        torch.Tensor:
            `x` itself when the size already matches; the leading
            `target_length` entries when `x` is longer; otherwise `x` with
            zeros appended at the end of that dimension.
    """
    current_length = x.shape[length_dim]
    if current_length == target_length:
        return x

    if current_length > target_length:
        # Too long: keep only the leading part along `length_dim`.
        index = [slice(None) for _ in range(x.ndim)]
        index[length_dim] = slice(0, target_length)
        return x[tuple(index)]

    # Too short: append a zero block covering the missing length.
    filler_shape = list(x.shape)
    filler_shape[length_dim] = target_length - current_length
    filler = torch.zeros(filler_shape, dtype=x.dtype, device=x.device)
    return torch.cat([x, filler], dim=length_dim)
def concat_non_padding(
    seq1: torch.Tensor, mask1: torch.BoolTensor, seq2: torch.Tensor,
    mask2: torch.BoolTensor
):
    """
    Concatenate two padded sequences so that all valid tokens come first.

    Args
        seq1 : Tensor (B, L1, E)
            First sequence.
        mask1 : BoolTensor (B, L1)
            True for valid tokens in seq1, False for padding.
        seq2 : Tensor (B, L2, E)
            Second sequence.
        mask2 : BoolTensor (B, L2)
            True for valid tokens in seq2, False for padding.
    Returns
        concat_seq : Tensor (B, L1+L2, E)
            Valid tokens of seq1 followed by valid tokens of seq2, in their
            original order; the remaining positions are multiplied by 0.
        concat_mask: BoolTensor (B, L1+L2)
            Mask for the concatenated sequence.
        perm : LongTensor (B, L1+L2)
            `perm[b, j]` is the index, in the plain concatenation
            `[seq1; seq2]`, of the token placed at position `j` of
            `concat_seq`. Needed for restoring the original sequences.
    """
    mask1 = mask1.bool()
    mask2 = mask2.bool()
    B, L1, E = seq1.shape
    total = L1 + seq2.size(1)

    stacked_seq = torch.cat([seq1, seq2], dim=1)  # (B, total, E)
    stacked_mask = torch.cat([mask1, mask2], dim=1)  # (B, total)

    # Sort key: a valid token keeps its own index, a padding token gets its
    # index shifted by `total`, so every padding key exceeds every valid key.
    # The stable sort therefore moves valid tokens to the front while
    # preserving the relative order inside both groups.
    index = torch.arange(total, device=stacked_seq.device).unsqueeze(0)
    sort_key = torch.where(stacked_mask, index, index + total)  # (B, total)
    perm = sort_key.argsort(dim=1, stable=True)

    # Reorder tokens and mask with the same permutation.
    token_index = perm[:, :, None].expand(-1, -1, E)  # (B, total, E)
    concat_seq = torch.gather(stacked_seq, 1, token_index)
    concat_mask = torch.gather(stacked_mask, 1, perm)

    # Multiply the trailing padding region by 0 so it carries no stale values.
    concat_seq = concat_seq * concat_mask.unsqueeze(-1)
    return concat_seq, concat_mask, perm
def restore_from_concat(
    concat_seq: torch.Tensor, mask1: torch.BoolTensor, mask2: torch.BoolTensor,
    perm: torch.LongTensor
):
    """
    Undo `concat_non_padding`: split the concatenated sequence back into
    (seq1, seq2) using the permutation `perm` it returned.

    Args
        concat_seq : Tensor (B, L1+L2, E)
            Sequence in the left-aligned order.
        mask1 : BoolTensor (B, L1)
            Validity mask of the first sequence.
        mask2 : BoolTensor (B, L2)
            Validity mask of the second sequence.
        perm : LongTensor (B, L1+L2)
            `perm[b, j]` is the original index of the token at position `j`.
    Returns
        seq1_restore : Tensor (B, L1, E)
        seq2_restore : Tensor (B, L2, E)
            Tokens at their original positions; padding positions are
            multiplied by 0.
    """
    mask1 = mask1.bool()
    mask2 = mask2.bool()
    B, L1 = mask1.shape
    L2 = mask2.size(1)
    E = concat_seq.size(-1)
    total = L1 + L2

    # Inverse permutation: inv_perm[b, original_index] = position in
    # `concat_seq`, obtained by scattering each position to the original
    # index stored in `perm`.
    positions = torch.arange(total, device=perm.device).expand(B, total)
    inv_perm = torch.empty_like(perm).scatter_(1, perm, positions)

    # Fetch every token from its concatenated position -> original order.
    token_index = inv_perm[:, :, None].expand(-1, -1, E)
    seq_cat_rec = torch.gather(concat_seq, 1, token_index)  # (B, L1+L2, E)

    # Cut at the seq1 / seq2 boundary and re-apply the padding masks.
    first, second = torch.split(seq_cat_rec, [L1, L2], dim=1)
    seq1_restore = first * mask1.unsqueeze(-1)
    seq2_restore = second * mask2.unsqueeze(-1)
    return seq1_restore, seq2_restore
def contains_nan(data) -> bool:
    """
    Check whether `data` contains a NaN value.

    Tensors, NumPy arrays, Python and NumPy float scalars are inspected
    directly; lists, tuples and dict values are searched recursively. Any
    other type is reported as NaN-free.

    Args:
        data:
            The object to inspect.

    Returns:
        bool:
            True if a NaN was found, False otherwise (always a Python bool).
    """
    if isinstance(data, torch.Tensor):
        return bool(torch.isnan(data).any().item())
    if isinstance(data, np.ndarray):
        try:
            return bool(np.isnan(data).any())
        except TypeError:
            # `np.isnan` rejects non-numeric dtypes. Object arrays may still
            # hold float NaNs, so look at their elements one by one; other
            # dtypes (e.g. strings) cannot contain NaN.
            if data.dtype.kind == "O":
                return any(contains_nan(item) for item in data.flat)
            return False
    # `np.floating` covers NumPy scalars such as np.float32, which are not
    # subclasses of the builtin float.
    if isinstance(data, (float, np.floating)):
        return math.isnan(data)
    if isinstance(data, (list, tuple)):
        return any(contains_nan(item) for item in data)
    if isinstance(data, dict):
        return any(contains_nan(value) for value in data.values())
    return False
def check_nan_in_batch(batch):
    """
    Find the samples of a batch that contain NaN.

    Every value of `batch` other than `batch["audio_id"]` is indexed with the
    sample position, and the collected per-sample values are checked with
    `contains_nan`. Each offending audio id is printed.

    Args:
        batch:
            Dict with an "audio_id" entry listing one id per sample; all
            other values must be indexable by sample position.

    Returns:
        list:
            The audio ids of the samples that contain NaN, in batch order.

    Raises:
        AssertionError: If `batch` is not a dict.
    """
    # isinstance also accepts dict subclasses (OrderedDict, defaultdict, ...).
    assert isinstance(batch, dict), "batch type error"
    fields = [value for key, value in batch.items() if key != "audio_id"]

    nan_audio_ids = []
    # Samples are checked one by one instead of being grouped in a dict keyed
    # by audio id, so samples sharing an id are all inspected.
    for idx, audio_id in enumerate(batch["audio_id"]):
        content = [field[idx] for field in fields]
        if contains_nan(content):
            nan_audio_ids.append(audio_id)
            print(f"{audio_id} contains NaN")
    return nan_audio_ids