# Source: gce4/gce4.py (uploaded by JacobLinCool via huggingface_hub, commit d69b6f3)
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from huggingface_hub import PyTorchModelHubMixin
class ResBlock1D(nn.Module):
    """
    Residual block of two dilated 1-D convolutions with batch norm and GELU.

    Intended for extracting rhythmic features from audio spectrograms. Both
    convolutions use stride 1 and "same" padding, so the output keeps the
    input's number of time frames while ``dilation`` widens the receptive
    field.

    Args:
        channels: Number of input and output channels (they must match for
            the residual addition).
        kernel_size: Convolution kernel length.
        dilation: Spacing between kernel taps.
    """

    def __init__(self, channels, kernel_size=3, dilation=1):
        super().__init__()
        # padding="same" preserves the time length for every kernel_size /
        # dilation combination. Symmetric integer padding of
        # (kernel_size - 1) * dilation // 2 loses one frame whenever
        # (kernel_size - 1) * dilation is odd, which would break the residual
        # add below. When that product is even, both schemes pad identically.
        self.conv1 = nn.Conv1d(
            channels, channels, kernel_size, padding="same", dilation=dilation
        )
        self.bn1 = nn.BatchNorm1d(channels)
        self.conv2 = nn.Conv1d(
            channels, channels, kernel_size, padding="same", dilation=dilation
        )
        self.bn2 = nn.BatchNorm1d(channels)

    def forward(self, x):
        """
        Apply conv-BN-GELU, then conv-BN, add the input back, and apply GELU.

        Args:
            x: Tensor of shape (Batch, channels, Time).

        Returns:
            Tensor with the same shape as ``x``.
        """
        residual = x
        out = F.gelu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.gelu(out + residual)
class GameChartEvaluator(nn.Module, PyTorchModelHubMixin):
    """
    Scores how well a game chart matches its music.

    Pipeline:
      1. Early fusion: music and chart features are concatenated along the
         channel axis, giving ``2 * input_dim`` channels.
      2. A stride-1 dilated residual encoder (no pooling) keeps one feature
         vector per input frame.
      3. A linear head followed by a sigmoid yields a per-frame quality score.
      4. Error-sensitive pooling blends the mean frame score with the mean of
         the worst-scoring frames, using a learnable weight.
    """

    def __init__(self, input_dim=80, d_model=128, n_layers=4, worst_fraction=0.1):
        """
        Args:
            input_dim: Feature channels in each of the two inputs; the fused
                input has ``2 * input_dim`` channels.
            d_model: Hidden channel width of the encoder.
            n_layers: Number of residual encoder blocks. Block ``i`` uses
                dilation ``2 ** i``, so the default of 4 gives 1, 2, 4, 8.
            worst_fraction: Fraction of frames (at least one) averaged as the
                "worst" frames during pooling.

        Raises:
            ValueError: If ``n_layers`` is negative or ``worst_fraction`` is
                not in (0, 1].
        """
        super().__init__()
        if n_layers < 0:
            raise ValueError(f"n_layers must be non-negative, got {n_layers}")
        if not 0.0 < worst_fraction <= 1.0:
            raise ValueError(
                f"worst_fraction must be in (0, 1], got {worst_fraction}"
            )
        self.worst_fraction = worst_fraction

        # --- Early fusion ---
        # Music (input_dim) + chart (input_dim) channels -> d_model channels.
        self.input_proj = nn.Conv1d(
            input_dim * 2, d_model, kernel_size=3, stride=1, padding=1
        )

        # --- Strict temporal encoder ---
        # No pooling and stride 1, so temporal resolution is preserved.
        # Dilation doubles per block to widen context without downsampling;
        # increase n_layers for wider context.
        self.encoder = nn.Sequential(
            *[
                ResBlock1D(d_model, kernel_size=3, dilation=2**i)
                for i in range(n_layers)
            ]
        )

        # --- Scoring head: per-frame projection to a scalar ---
        self.quality_proj = nn.Linear(d_model, 1)

        # Learnable pooling mix. sigmoid(0.0) = 0.5, so the worst-frame and
        # average scores are weighted equally at initialization.
        self.raw_severity = nn.Parameter(torch.tensor(0.0))

    def _frame_scores(self, music_mels, chart_mels):
        """Fuse inputs, encode, and return per-frame scores of shape (Batch, Time, 1)."""
        # Early fusion: (Batch, 2 * input_dim, Time).
        x = torch.cat([music_mels, chart_mels], dim=1)
        x = F.gelu(self.input_proj(x))
        x = self.encoder(x)
        # (Batch, Dim, Time) -> (Batch, Time, Dim) so the Linear acts per frame.
        x = x.permute(0, 2, 1)
        return torch.sigmoid(self.quality_proj(x))

    def forward(self, music_mels, chart_mels):
        """
        Score each (music, chart) pair.

        Args:
            music_mels: (Batch, input_dim, Time) music features.
            chart_mels: (Batch, input_dim, Time) chart features; Time must
                match ``music_mels``.

        Returns:
            (Batch,) scores in [0, 1]. Each score blends the mean per-frame
            score with the mean of the lowest
            ``max(1, int(Time * worst_fraction))`` per-frame scores. The blend
            weight is ``sigmoid(raw_severity)`` on the worst-frame term.
        """
        local_scores = self._frame_scores(music_mels, chart_mels)  # (B, T, 1)

        # Error-sensitive pooling: a few bad frames pull the score down
        # more than a plain average would.
        avg_score = local_scores.mean(dim=1)
        k = max(1, int(local_scores.size(1) * self.worst_fraction))
        worst_vals, _ = torch.topk(local_scores, k, dim=1, largest=False)
        worst_score = worst_vals.mean(dim=1)

        alpha = torch.sigmoid(self.raw_severity)
        final_score = alpha * worst_score + (1 - alpha) * avg_score
        return final_score.squeeze(1)

    def predict_trace(self, music_mels, chart_mels):
        """
        Explainability method: return the frame-by-frame quality curve.

        The computation runs under ``torch.no_grad()``, but the module's
        train/eval mode is left unchanged. In training mode BatchNorm uses
        batch statistics and updates its running statistics, so call
        ``.eval()`` first for a stable trace.

        Args:
            music_mels: (Batch, input_dim, Time) music features.
            chart_mels: (Batch, input_dim, Time) chart features.

        Returns:
            local_scores: (Batch, Time), the quality score at every input
            frame. The encoder never downsamples.
        """
        with torch.no_grad():
            return self._frame_scores(music_mels, chart_mels).squeeze(2)
if __name__ == "__main__":
    # Sanity check: build a model, run a forward pass and a trace on random
    # inputs, and print a layer summary.
    from torchinfo import summary

    model = GameChartEvaluator()
    print(
        f"Model initialized. Learnable Severity: {torch.sigmoid(model.raw_severity).item():.2f}"
    )

    # Dummy data (Batch=2, Freq=80, Time=1000)
    batch, n_mels, n_frames = 2, 80, 1000
    m = torch.randn(batch, n_mels, n_frames)
    c = torch.randn(batch, n_mels, n_frames)

    output = model(m, c)
    print(f"Output shape: {output.shape}")  # torch.Size([2])
    print(f"Scores: {output}")
    assert output.shape == (batch,)

    # Trace check: the encoder uses stride 1 and no pooling, so the trace
    # keeps one score per input frame, i.e. torch.Size([2, 1000]).
    trace = model.predict_trace(m, c)
    print(f"Trace shape: {trace.shape}")
    assert trace.shape == (batch, n_frames)

    summary(model, input_data=[m, c])