# tc5-exp / tc5 / model.py
# Last commit by JacobLinCool (812b01c):
#   "Implement TaikoConformer7 model, loss function, preprocessing, and training pipeline"
import torch
import torch.nn as nn
from torchaudio.models import Conformer
from huggingface_hub import PyTorchModelHubMixin
from .config import (
N_MELS,
CNN_CH,
N_HEADS,
D_MODEL,
FF_DIM,
N_LAYERS,
DROPOUT,
DEPTHWISE_CONV_KERNEL_SIZE,
HIDDEN_DIM,
DEVICE,
)
class TaikoConformer5(nn.Module, PyTorchModelHubMixin):
    """Frame-wise note presence regressor built on a torchaudio Conformer.

    Pipeline:
        1. Two-layer CNN frontend that halves only the frequency axis per layer
           (stride ``(2, 1)``), leaving the time axis at full resolution.
        2. Linear projection of the flattened (channel x frequency) features
           to ``D_MODEL``.
        3. FiLM conditioning on a per-example notes-per-second value.
        4. Conformer encoder.
        5. MLP head with a sigmoid, producing three values in [0, 1] per frame
           (Don, Ka, DrumRoll energy).
    """

    def __init__(self):
        super().__init__()
        # 1) CNN frontend: frequency-only pooling. Input is (B, 1, N_MELS, T), so
        #    stride (2, 1) downsamples N_MELS and keeps T unchanged.
        self.cnn = nn.Sequential(
            nn.Conv2d(1, CNN_CH, 3, stride=(2, 1), padding=1),
            nn.BatchNorm2d(CNN_CH),
            nn.GELU(),
            nn.Dropout2d(DROPOUT),
            nn.Conv2d(CNN_CH, CNN_CH, 3, stride=(2, 1), padding=1),
            nn.BatchNorm2d(CNN_CH),
            nn.GELU(),
            nn.Dropout2d(DROPOUT),
        )
        # A kernel-3 / stride-2 / padding-1 conv maps F -> ceil(F / 2); two of them
        # give ceil(N_MELS / 4). Plain `N_MELS // 4` under-counts whenever N_MELS is
        # not a multiple of 4, which would make `self.proj` reject its input.
        # (Identical to the old value when N_MELS % 4 == 0.)
        freq_out = (N_MELS + 3) // 4
        feat_dim = CNN_CH * freq_out
        # 2) Linear projection to model dimension
        self.proj = nn.Linear(feat_dim, D_MODEL)
        # 3) FiLM conditioning for notes_per_second: one scalar -> (gamma, beta)
        self.film = nn.Linear(1, 2 * D_MODEL)
        # 4) Conformer encoder
        self.encoder = Conformer(
            input_dim=D_MODEL,
            num_heads=N_HEADS,
            ffn_dim=FF_DIM,
            num_layers=N_LAYERS,
            depthwise_conv_kernel_size=DEPTHWISE_CONV_KERNEL_SIZE,
            dropout=DROPOUT,
            use_group_norm=False,
            convolution_first=False,
        )
        # 5) Presence regressor head
        self.presence_regressor = nn.Sequential(
            nn.Dropout(DROPOUT),
            nn.Linear(D_MODEL, HIDDEN_DIM),
            nn.GELU(),
            nn.Dropout(DROPOUT),
            nn.Linear(HIDDEN_DIM, 3),  # Don, Ka, DrumRoll energy
            nn.Sigmoid(),  # Output between 0 and 1
        )
        # 6) Initialize weights. This loop visits every submodule, including the
        #    Linear layers inside the Conformer encoder.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(
        self, mel: torch.Tensor, lengths: torch.Tensor, notes_per_second: torch.Tensor
    ):
        """Predict per-frame presence values.

        Args:
            mel: (B, 1, N_MELS, T_mel) mel spectrogram batch.
            lengths: (B,) valid frame counts, passed to the Conformer encoder.
                The CNN does not downsample time, so these equal the mel frame
                counts.
            notes_per_second: (B,) or (B, 1) conditioning value per example. A
                single-element tensor is broadcast to the whole batch. It is cast
                to the activations' dtype/device before use.

        Returns:
            Dict with:
                'presence': (B, T_mel, 3) sigmoid outputs (Don, Ka, DrumRoll).
                'lengths': the input ``lengths``, unchanged.
        """
        # CNN frontend
        x = self.cnn(mel)  # (B, C, F, T)
        B, C, F, T = x.size()
        # Move time to dim 1 and flatten channel x frequency into features.
        x = x.permute(0, 3, 1, 2).reshape(B, T, C * F)
        # Project to model dimension
        x = self.proj(x)  # (B, T, D_MODEL)
        # FiLM conditioning. Flatten to (N, 1) so both (B,) and (B, 1) inputs give
        # gamma/beta of shape (N, D_MODEL). Without this, a (B, 1) input would
        # broadcast against x into (B, B, T, D_MODEL). Match dtype/device so integer
        # or double inputs work with the Linear layer.
        nps = notes_per_second.reshape(-1, 1).to(device=x.device, dtype=x.dtype)
        gamma, beta = self.film(nps).chunk(2, dim=-1)  # each (N, D_MODEL)
        x = gamma.unsqueeze(1) * x + beta.unsqueeze(1)
        # Conformer encoder; it builds its padding mask from `lengths`.
        x, _ = self.encoder(x, lengths=lengths)
        # Presence prediction
        presence = self.presence_regressor(x)  # (B, T, 3)
        return {"presence": presence, "lengths": lengths}
if __name__ == "__main__":
    # Build the model and report per-tensor and total trainable parameter counts.
    model = TaikoConformer5().to(device=DEVICE)
    print(model)
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"{name}: {param.numel():,}")
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {params / 1e6:.2f}M")

    # Smoke-test one forward pass on random input. The CNN keeps the time axis,
    # so the Conformer lengths equal the mel frame count.
    batch_size = 4
    mel_time_steps = 1024
    input_mel = torch.randn(batch_size, 1, N_MELS, mel_time_steps).to(DEVICE)
    conformer_lengths = torch.full(
        (batch_size,), mel_time_steps, dtype=torch.long, device=DEVICE
    )
    notes_per_second_input = torch.full(
        (batch_size,), 10.0, dtype=torch.float32, device=DEVICE
    )
    # Only output shapes are inspected. Eval mode disables dropout and
    # BatchNorm running-stat updates, and no_grad avoids building an autograd
    # graph for the whole batch.
    model.eval()
    with torch.no_grad():
        output = model(input_mel, conformer_lengths, notes_per_second_input)
    print("Output shapes:")
    for key, value in output.items():
        print(f"{key}: {value.shape}")