# src/spectttra/spectttra.py
import torch
import torch.nn as nn
from pathlib import Path
from .transformer import Transformer
from .tokenizer import STTokenizer
from .feature import FeatureExtractor


class SpecTTTra(nn.Module):
"""
SpecTTTra: A Spectro-Temporal Transformer model for audio representation learning.
This model first tokenizes the input spectrogram into temporal and spectral tokens,
then processes them with a Transformer encoder to capture spectro-temporal dependencies.
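    Example (a minimal sketch; the dimensions below are illustrative, assume F and T
    are divisible by the clip sizes, and are not the pretrained checkpoint's values):

        model = SpecTTTra(
            input_spec_dim=128, input_temp_dim=512, embed_dim=384,
            t_clip=4, f_clip=4, num_heads=6, num_layers=12,
        )  # illustrative hyperparameters, not the checkpoint configuration
        x = torch.randn(2, 1, 128, 512)  # (B, 1, F, T)
        tokens = model(x)  # (B, T/t_clip + F/f_clip, embed_dim) == (2, 160, 384)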
"""
def __init__(
self,
input_spec_dim,
input_temp_dim,
embed_dim,
t_clip,
f_clip,
num_heads,
num_layers,
pre_norm=False,
pe_learnable=False,
pos_drop_rate=0.0,
attn_drop_rate=0.0,
proj_drop_rate=0.0,
mlp_ratio=4.0,
):
"""
Initialize the SpecTTTra model.
Args:
input_spec_dim (int): Input spectrogram frequency dimension (F).
input_temp_dim (int): Input spectrogram temporal dimension (T).
embed_dim (int): Embedding dimension for tokens.
t_clip (int): Temporal clip size for tokenization.
f_clip (int): Spectral clip size for tokenization.
num_heads (int): Number of attention heads in the transformer.
num_layers (int): Number of transformer layers.
pre_norm (bool, optional): Whether to apply pre-normalization. Defaults to False.
pe_learnable (bool, optional): If True, use learnable positional embeddings. Defaults to False.
pos_drop_rate (float, optional): Dropout rate for positional embeddings. Defaults to 0.0.
attn_drop_rate (float, optional): Dropout rate for attention. Defaults to 0.0.
proj_drop_rate (float, optional): Dropout rate for projection layers. Defaults to 0.0.
mlp_ratio (float, optional): Expansion ratio for MLP hidden dimension. Defaults to 4.0.
"""
        super().__init__()
self.input_spec_dim = input_spec_dim
self.input_temp_dim = input_temp_dim
self.embed_dim = embed_dim
self.t_clip = t_clip
self.f_clip = f_clip
self.num_heads = num_heads
self.num_layers = num_layers
        # Pre-norm is applied after tokenization, before the transformer (as in CLIP)
        self.pre_norm = pre_norm
self.pe_learnable = pe_learnable # Learned positional encoding
self.pos_drop_rate = pos_drop_rate
self.attn_drop_rate = attn_drop_rate
self.proj_drop_rate = proj_drop_rate
self.mlp_ratio = mlp_ratio
# Tokenizer for spectro-temporal features
self.st_tokenizer = STTokenizer(
input_spec_dim,
input_temp_dim,
t_clip,
f_clip,
embed_dim,
pre_norm=pre_norm,
pe_learnable=pe_learnable,
)
# Dropout applied after tokenization
self.pos_drop = nn.Dropout(p=pos_drop_rate)
# Transformer encoder
self.transformer = Transformer(
embed_dim,
num_heads,
num_layers,
attn_drop=self.attn_drop_rate,
proj_drop=self.proj_drop_rate,
mlp_ratio=self.mlp_ratio,
)

    def forward(self, x):
"""
Forward pass of SpecTTTra.
Args:
x (torch.Tensor): Input spectrogram of shape
- (B, 1, F, T) if channel dimension exists
- (B, F, T) otherwise
Returns:
torch.Tensor: Transformer-encoded spectro-temporal tokens of shape
                (B, T/t_clip + F/f_clip, embed_dim).
"""
# Squeeze the channel dimension if it exists
if x.dim() == 4:
x = x.squeeze(1)
# Spectro-temporal tokenization
spectro_temporal_tokens = self.st_tokenizer(x)
# Positional dropout
spectro_temporal_tokens = self.pos_drop(spectro_temporal_tokens)
# Transformer
        output = self.transformer(spectro_temporal_tokens)  # (B, T/t_clip + F/f_clip, embed_dim)
return output


def build_spectttra_from_cfg(cfg, device):
"""
Constructs the SpecTTTra model and its associated FeatureExtractor from a given configuration.
Args:
cfg (SimpleNamespace): Configuration object containing model and feature extraction parameters. Expected attributes include:
- cfg.melspec.n_mels: Number of mel frequency bins.
- cfg.model: Model-specific parameters (e.g., embed_dim, t_clip, f_clip, etc.).
device (torch.device): The device on which the model and feature extractor will be allocated (e.g., 'cpu' or 'cuda').
Returns:
tuple:
FeatureExtractor: Initialized feature extraction module moved to the specified device.
SpecTTTra: Constructed SpecTTTra model moved to the specified device.
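
    Example (illustrative only; the attribute values below are assumptions, and
    FeatureExtractor may require additional cfg fields beyond those shown):

        from types import SimpleNamespace
        cfg = SimpleNamespace(
            melspec=SimpleNamespace(n_mels=128),  # assumption: more melspec fields may be needed
            model=SimpleNamespace(
                embed_dim=384, t_clip=4, f_clip=4, num_heads=6, num_layers=12,
                pre_norm=False, pe_learnable=False, pos_drop_rate=0.0,
                attn_drop_rate=0.0, proj_drop_rate=0.0, mlp_ratio=4.0,
            ),
        )
        feat_ext, model = build_spectttra_from_cfg(cfg, torch.device("cpu"))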
"""
feat_ext = FeatureExtractor(cfg).to(device)
    # The pretrained checkpoint expects specific, fixed input dimensions, so they
    # are hardcoded here to keep the architecture exactly aligned with the
    # checkpoint weights. The expected frame count was read from the size-mismatch
    # RuntimeError raised when loading with the wrong temporal dimension.
    n_mels = cfg.melspec.n_mels  # expected to be 128 for the pretrained checkpoint
    n_frames = 3744  # matches the checkpoint's expected temporal dimension
    print(
        f"[INFO] Initializing SpecTTTra with fixed dimensions: "
        f"n_mels={n_mels}, n_frames={n_frames}"
    )
model_cfg = cfg.model
model = SpecTTTra(
input_spec_dim=n_mels,
input_temp_dim=n_frames,
embed_dim=model_cfg.embed_dim,
t_clip=model_cfg.t_clip,
f_clip=model_cfg.f_clip,
num_heads=model_cfg.num_heads,
num_layers=model_cfg.num_layers,
pre_norm=model_cfg.pre_norm,
pe_learnable=model_cfg.pe_learnable,
pos_drop_rate=model_cfg.pos_drop_rate,
attn_drop_rate=model_cfg.attn_drop_rate,
proj_drop_rate=model_cfg.proj_drop_rate,
mlp_ratio=model_cfg.mlp_ratio,
).to(device)
return feat_ext, model


def load_frozen_spectttra(model, ckpt_path, device):
"""
Loads pretrained SpecTTTra weights from a frozen checkpoint file.
Args:
model (torch.nn.Module): An initialized SpecTTTra model instance to load weights into.
ckpt_path (str or Path): Path to the pretrained model checkpoint file (e.g., 'spectttra_frozen.pth').
device (torch.device): The device to map the loaded weights to (e.g., 'cpu' or 'cuda').
Returns:
model (torch.nn.Module): The SpecTTTra model with loaded pretrained weights, set to evaluation mode.
Raises:
FileNotFoundError: If the specified checkpoint file does not exist at `ckpt_path`.
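
    Example (sketch; the checkpoint path below is an assumption, not a
    repository convention):

        feat_ext, model = build_spectttra_from_cfg(cfg, device)
        model = load_frozen_spectttra(model, "weights/spectttra_frozen.pth", device)  # path assumed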
"""
ckpt_path = Path(ckpt_path)
if not ckpt_path.exists():
        raise FileNotFoundError(
            f"Pre-trained model not found at {ckpt_path}. "
            "Please download 'pytorch_model.bin', rename it to 'spectttra_frozen.pth', "
            "and place it in the correct directory."
        )
print(f"[INFO] Found SpecTTTra checkpoint at {ckpt_path}. Loading weights...")
state = torch.load(ckpt_path, map_location=device)
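    # Checkpoint keys may carry an "encoder." prefix (presumably because SpecTTTra
    # was saved as a submodule of a larger model); strip it so the keys line up
    # with this standalone module's parameter names.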
new_state_dict = {}
for k, v in state.items():
if k.startswith("encoder."):
new_key = k[len("encoder."):]
new_state_dict[new_key] = v
else:
new_state_dict[k] = v
    # With the input dimensions fixed above, the tensor shapes match and this
    # loads without size-mismatch errors; strict=False tolerates key differences.
    missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
    if missing_keys:
        # A few missing keys are possible if this SpecTTTra class differs
        # slightly from the training-time definition, but the core should load.
        print(f"[WARNING] Missing keys in model: {missing_keys}")
    if unexpected_keys:
        # Extra checkpoint keys (e.g. 'classifier' or 'ft_extractor') are
        # expected here and safe to ignore.
        print(f"[INFO] Unused keys in checkpoint: {unexpected_keys}")
print("[INFO] Successfully loaded pre-trained SpecTTTra weights.")
model.eval()
return model