|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import math |
|
|
from huggingface_hub import PyTorchModelHubMixin |
|
|
|
|
|
|
|
|
class ResBlock1D(nn.Module):
    """
    Residual block for extracting rhythmic features from audio spectrograms.

    Applies two dilated ``Conv1d`` + ``BatchNorm1d`` stages and adds the input
    back before the final GELU. Symmetric padding keeps the temporal length
    unchanged while the dilation widens the receptive field.

    Args:
        channels: Number of input and output channels (the block never
            changes the channel count, so the skip connection needs no
            projection).
        kernel_size: Width of both convolution kernels.
        dilation: Dilation applied to both convolutions.

    Raises:
        ValueError: If ``(kernel_size - 1) * dilation`` is odd. Symmetric
            padding cannot preserve the sequence length in that case, so the
            residual addition in ``forward`` would fail (or silently
            broadcast) on mismatched lengths.
    """

    def __init__(self, channels: int, kernel_size: int = 3, dilation: int = 1) -> None:
        super().__init__()

        # Total number of timesteps a dilated kernel spans beyond its centre.
        span = (kernel_size - 1) * dilation
        if span % 2 != 0:
            raise ValueError(
                "ResBlock1D requires (kernel_size - 1) * dilation to be even so "
                "that symmetric padding preserves the sequence length; got "
                f"kernel_size={kernel_size}, dilation={dilation}"
            )
        # Half the span on each side => output length == input length.
        padding = span // 2

        self.conv1 = nn.Conv1d(
            channels, channels, kernel_size, padding=padding, dilation=dilation
        )
        self.bn1 = nn.BatchNorm1d(channels)
        self.conv2 = nn.Conv1d(
            channels, channels, kernel_size, padding=padding, dilation=dilation
        )
        self.bn2 = nn.BatchNorm1d(channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the block on ``x`` of shape (Batch, channels, Time).

        Returns:
            Tensor with the same shape as ``x``.
        """
        res = x
        x = F.gelu(self.bn1(self.conv1(x)))
        # No activation before the skip connection; GELU is applied to the sum.
        x = self.bn2(self.conv2(x))
        return F.gelu(x + res)
|
|
|
|
|
|
|
|
class GameChartEvaluator(nn.Module, PyTorchModelHubMixin):
    """
    Scores a chart against its music from a pair of mel-spectrogram inputs.

    The two inputs are concatenated along the channel axis, projected to
    ``d_model`` channels, passed through a stack of dilated residual blocks,
    and mapped to one sigmoid quality score per timestep. ``forward`` pools
    those per-timestep scores into a single score per batch item;
    ``predict_trace`` returns the per-timestep scores themselves.

    Args:
        input_dim: Channel count of EACH input (music and chart); the first
            convolution therefore takes ``input_dim * 2`` channels.
        d_model: Channel width used throughout the encoder.
        n_layers: Currently unused. The encoder is fixed at four residual
            blocks with dilations 1, 2, 4 and 8; the parameter is kept so the
            constructor signature stays stable.
    """

    def __init__(self, input_dim: int = 80, d_model: int = 128, n_layers: int = 4) -> None:
        super().__init__()

        # Music and chart are stacked channel-wise, hence input_dim * 2.
        # kernel_size=3 / padding=1 / stride=1 keeps the time length unchanged.
        self.input_proj = nn.Conv1d(
            input_dim * 2, d_model, kernel_size=3, stride=1, padding=1
        )

        # Dilation doubles per block to grow the receptive field without
        # reducing temporal resolution.
        self.encoder = nn.Sequential(
            ResBlock1D(d_model, kernel_size=3, dilation=1),
            ResBlock1D(d_model, kernel_size=3, dilation=2),
            ResBlock1D(d_model, kernel_size=3, dilation=4),
            ResBlock1D(d_model, kernel_size=3, dilation=8),
        )

        # Per-timestep quality head: d_model features -> one logit.
        self.quality_proj = nn.Linear(d_model, 1)

        # Unconstrained mixing parameter; sigmoid(raw_severity) weights the
        # "worst segment" score against the average. 0.0 => initial weight 0.5.
        self.raw_severity = nn.Parameter(torch.tensor(0.0))

    def _local_scores(self, music_mels: torch.Tensor, chart_mels: torch.Tensor) -> torch.Tensor:
        """Return per-timestep quality scores in (0, 1), shape (Batch, Time, 1)."""
        x = torch.cat([music_mels, chart_mels], dim=1)
        x = F.gelu(self.input_proj(x))
        x = self.encoder(x)
        # (Batch, d_model, Time) -> (Batch, Time, d_model) for the linear head.
        x = x.permute(0, 2, 1)
        return torch.sigmoid(self.quality_proj(x))

    def forward(self, music_mels: torch.Tensor, chart_mels: torch.Tensor) -> torch.Tensor:
        """
        Compute one overall quality score per batch item.

        Args:
            music_mels: (Batch, input_dim, Time)
            chart_mels: (Batch, input_dim, Time)

        Returns:
            Tensor of shape (Batch,): a learned blend of the mean per-timestep
            score and the mean of the lowest-scoring 10% of timesteps.
        """
        local_scores = self._local_scores(music_mels, chart_mels)

        avg_score = local_scores.mean(dim=1)

        # Average of the worst 10% of timesteps (at least one timestep).
        k = max(1, int(local_scores.size(1) * 0.1))
        min_vals, _ = torch.topk(local_scores, k, dim=1, largest=False)
        worst_score = min_vals.mean(dim=1)

        # alpha in (0, 1): how strongly the worst segments drag the score down.
        alpha = torch.sigmoid(self.raw_severity)
        final_score = (alpha * worst_score) + ((1 - alpha) * avg_score)

        return final_score.squeeze(1)

    def predict_trace(self, music_mels: torch.Tensor, chart_mels: torch.Tensor) -> torch.Tensor:
        """
        Explainability method: return the per-timestep quality curve.

        Runs without gradient tracking and with the module temporarily in
        evaluation mode, so BatchNorm layers use their running statistics and
        the call does not modify them. Each submodule's previous
        training/eval flag is restored afterwards.

        Args:
            music_mels: (Batch, input_dim, Time)
            chart_mels: (Batch, input_dim, Time)

        Returns:
            local_scores: (Batch, Time) - the quality score at every timestep.
        """
        # Remember each submodule's own flag so a partially-frozen model is
        # restored exactly, not blanket-reset with self.train(...).
        previous_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            with torch.no_grad():
                local_scores = self._local_scores(music_mels, chart_mels)
                return local_scores.squeeze(2)
        finally:
            for module, was_training in previous_modes:
                module.training = was_training
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the model with default settings, push one random
    # batch through both entry points, and print a layer summary.
    from torchinfo import summary

    model = GameChartEvaluator()
    print(
        f"Model initialized. Learnable Severity: {torch.sigmoid(model.raw_severity).item():.2f}"
    )

    # Random stand-ins for a (music, chart) mel pair.
    batch_size, n_mels, n_frames = 2, 80, 1000
    music_batch = torch.randn(batch_size, n_mels, n_frames)
    chart_batch = torch.randn(batch_size, n_mels, n_frames)

    # Pooled score: one value per batch item.
    output = model(music_batch, chart_batch)
    print(f"Output shape: {output.shape}")
    print(f"Scores: {output}")

    # Per-timestep score curve.
    trace = model.predict_trace(music_batch, chart_batch)
    print(f"Trace shape: {trace.shape}")

    summary(model, input_data=[music_batch, chart_batch])
|
|
|