Spaces: Running on Zero
slslslrhfem committed · be8ccdd
Parent(s): 216b804

commit for ICASSP 2026 github repository

Files changed:
- inference.py +6 -23
- model.py +8 -218
inference.py
CHANGED
@@ -18,22 +18,6 @@ from preprocess import get_segments_from_wav, find_optimal_segment_length



-def highpass_filter(y, sr, cutoff=1000, order=5):
-    if isinstance(sr, np.ndarray):
-        sr = np.mean(sr)
-    if not isinstance(sr, (int, float)):
-        raise ValueError(f"sr must be a number, but got {type(sr)}: {sr}")
-
-    nyquist = 0.5 * sr
-    if cutoff <= 0 or cutoff >= nyquist:
-        cutoff = max(10, min(cutoff, nyquist - 1))
-
-    normal_cutoff = cutoff / nyquist
-    b, a = signal.butter(order, normal_cutoff, btype='high', analog=False)
-    y_filtered = signal.lfilter(b, a, y)
-    return y_filtered
-
-
 def load_audio(audio_path: str, sr: int = 24000) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Load an audio file and split it into segments.
@@ -216,11 +200,11 @@ def inference(audio_path):
     segments = segments.to('cuda').to(torch.float32)
     padding_mask = padding_mask.to('cuda').unsqueeze(0)
     logits,embedding = backbone_model(segments.squeeze(1))
-    test_dataset = FakeMusicCapsDataset([audio_path], [0], target_duration=10.0)
-    test_data, test_target = test_dataset[0]
-    test_data = test_data.to('cuda').to(torch.float32)
-    test_target = test_target.to('cuda')
-    output, _ = backbone_model(test_data.unsqueeze(0))
+    # test_dataset = FakeMusicCapsDataset([audio_path], [0], target_duration=10.0)
+    # test_data, test_target = test_dataset[0]
+    # test_data = test_data.to('cuda').to(torch.float32)
+    # test_target = test_target.to('cuda')
+    # output, _ = backbone_model(test_data.unsqueeze(0))



@@ -230,7 +214,6 @@ def inference(audio_path):
         input_dim=input_dim,
         #emb_model=backbone_model
         is_emb = True,
-        #mode = 'both'
     )


@@ -247,5 +230,5 @@ def inference(audio_path):
     return results

 if __name__ == "__main__":
-
+    inference("some path")

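For context, a minimal usage sketch of the entrypoint added in this commit. The import path, the wav filename, and the print of the returned value are assumptions for illustration only; the diff shows just that inference() takes an audio path, moves tensors to 'cuda' internally, and returns results.

# Hypothetical usage sketch (not part of the commit); mirrors the new __main__ block.
# Assumes inference.py is importable and a CUDA device is available.
from inference import inference

results = inference("example.wav")  # placeholder path
print(results)                      # structure of `results` is not shown in this diff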
model.py
CHANGED
@@ -36,22 +36,12 @@ class MusicAudioClassifier(pl.LightningModule):
                 hidden_dim=hidden_dim,
                 num_classes=num_classes
             )
-        elif backbone == 'guided_segment_transformer':
-            self.model = GuidedSegmentTransformer(
-                input_dim=input_dim,
-                hidden_dim=hidden_dim,
-                num_classes=num_classes
-            )
-        elif backbone == 'ultra_segment_processor':
-            self.model = UltraModernSegmentProcessor(
-                input_dim=input_dim,
-                hidden_dim=hidden_dim,
-                num_classes=num_classes
-            )
-        self.emb_model = emb_model
-        self.learning_rate = learning_rate
-        self.is_emb = is_emb
-        self.num_classes = num_classes
+        # elif backbone == 'guided_segment_transformer':
+        #     self.model = GuidedSegmentTransformer(
+        #         input_dim=input_dim,
+        #         hidden_dim=hidden_dim,
+        #         num_classes=num_classes
+        #     )

     def _process_audio_batch(self, x: torch.Tensor) -> torch.Tensor:
         B, S = x.shape[:2] # [B, S, C, M, T] or [B, S, C, T] for wav, [B, S, 1?, embsize] for emb
@@ -529,6 +519,7 @@ class MultiScaleAdaptivePooler(nn.Module):
         Args:
             x: (batch, seq_len, hidden_dim) - sequence features
             padding_mask: (batch, seq_len) - padding mask
+            actually not better than avg pooling haha
         """
         batch_size = x.size(0)

@@ -838,205 +829,4 @@ class FusionSegmentTransformer(nn.Module):
         pooled = pooled.half()
         return self.classification_head(pooled)

-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import numpy as np
-from typing import Optional
-import math
-
-class RMSNorm(nn.Module):
-    """RMS Normalization - stable"""
-    def __init__(self, dim: int, eps: float = 1e-6):
-        super().__init__()
-        self.eps = eps
-        self.weight = nn.Parameter(torch.ones(dim))
-
-    def forward(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
-
-class SwiGLU(nn.Module):
-    """SwiGLU activation - simple version"""
-    def __init__(self, dim: int):
-        super().__init__()
-        self.w1 = nn.Linear(dim, dim * 2, bias=False)
-        self.w2 = nn.Linear(dim, dim, bias=False)
-
-    def forward(self, x):
-        return self.w2(F.silu(self.w1(x)[:, :, :x.size(-1)])) # match the dimension
-
-class GroupedQueryAttention(nn.Module):
-    """Simple GQA - avoids errors"""
-    def __init__(self, d_model: int, num_heads: int = 8):
-        super().__init__()
-        assert d_model % num_heads == 0
-
-        self.d_model = d_model
-        self.num_heads = num_heads
-        self.head_dim = d_model // num_heads
-
-        # all projections use the same dimension
-        self.q_proj = nn.Linear(d_model, d_model, bias=False)
-        self.k_proj = nn.Linear(d_model, d_model, bias=False)
-        self.v_proj = nn.Linear(d_model, d_model, bias=False)
-        self.o_proj = nn.Linear(d_model, d_model, bias=False)
-
-        self.scale = 1.0 / math.sqrt(self.head_dim)
-
-    def forward(self, x, pairwise_matrix=None, padding_mask=None):
-        B, L, D = x.shape
-
-        Q = self.q_proj(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
-        K = self.k_proj(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
-        V = self.v_proj(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
-
-        scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
-
-        if pairwise_matrix is not None:
-            scores = scores + pairwise_matrix.unsqueeze(1)
-
-        if padding_mask is not None:
-            mask_4d = padding_mask.unsqueeze(1).unsqueeze(1).expand(-1, self.num_heads, L, -1)
-            scores = scores.masked_fill(mask_4d, float('-inf'))
-
-        attn_weights = F.softmax(scores, dim=-1)
-        attn_output = torch.matmul(attn_weights, V)
-
-        attn_output = attn_output.transpose(1, 2).contiguous().view(B, L, D)
-        return self.o_proj(attn_output)
-
-class SimpleModernLayer(nn.Module):
-    """Simple and safe modern layer"""
-    def __init__(self, d_model: int, num_heads: int = 8):
-        super().__init__()
-
-        # RMSNorm
-        self.norm1 = RMSNorm(d_model)
-        self.norm2 = RMSNorm(d_model)
-
-        # Attention
-        self.attention = GroupedQueryAttention(d_model, num_heads)
-
-        # Feed forward
-        self.ffn = SwiGLU(d_model)
-
-    def forward(self, x, pairwise_matrix=None, padding_mask=None):
-        # Attention with residual
-        normed_x = self.norm1(x)
-        attn_out = self.attention(normed_x, pairwise_matrix, padding_mask)
-        x = x + attn_out
-
-        # FFN with residual
-        normed_x2 = self.norm2(x)
-        ffn_out = self.ffn(normed_x2)
-        x = x + ffn_out
-
-        return x
-
-class SimpleQuantumPooling(nn.Module):
-    """Simple attention pooling"""
-    def __init__(self, d_model: int):
-        super().__init__()
-
-        # three pooling methods
-        self.attention_pool = nn.MultiheadAttention(d_model, 8, batch_first=True)
-        self.query_token = nn.Parameter(torch.randn(1, 1, d_model) * 0.02)
-
-        # combine
-        self.final_proj = nn.Linear(d_model * 3, d_model, bias=False)
-
-    def forward(self, x, padding_mask=None):
-        batch_size = x.size(0)
-
-        # 1. Average pooling
-        if padding_mask is not None:
-            mask_expanded = (~padding_mask).float().unsqueeze(-1)
-            avg_pooled = (x * mask_expanded).sum(dim=1) / mask_expanded.sum(dim=1)
-        else:
-            avg_pooled = x.mean(dim=1)
-
-        # 2. Max pooling
-        if padding_mask is not None:
-            x_masked = x.clone()
-            x_masked[padding_mask] = float('-inf')
-            max_pooled = x_masked.max(dim=1)[0]
-        else:
-            max_pooled = x.max(dim=1)[0]
-
-        # 3. Attention pooling
-        query = self.query_token.expand(batch_size, -1, -1)
-        attn_pooled, _ = self.attention_pool(
-            query, x, x, key_padding_mask=padding_mask
-        )
-        attn_pooled = attn_pooled.squeeze(1)
-
-        # combine
-        combined = torch.cat([avg_pooled, max_pooled, attn_pooled], dim=-1).half()
-        return self.final_proj(combined)
-
-class UltraModernSegmentProcessor(nn.Module):
-    """Error-free simple version ✅"""
-    def __init__(self,
-                 input_dim: int,
-                 hidden_dim: int = 512,
-                 num_heads: int = 8,
-                 num_layers: int = 6,
-                 dropout: float = 0.1,
-                 max_sequence_length: int = 1000,
-                 num_classes: int = 2):
-        super().__init__()
-
-        assert hidden_dim % num_heads == 0
-
-        self.hidden_dim = hidden_dim
-        self.input_projection = nn.Linear(input_dim, hidden_dim, bias=False)
-
-        # modern layers
-        self.layers = nn.ModuleList([
-            SimpleModernLayer(hidden_dim, num_heads)
-            for _ in range(num_layers)
-        ])
-
-        # simple pooling
-        self.pooler = SimpleQuantumPooling(hidden_dim)
-
-        # classification head
-        output_dim = 1 if num_classes == 2 else num_classes
-
-        self.classifier = nn.Sequential(
-            nn.Linear(hidden_dim, hidden_dim // 2, bias=False),
-            RMSNorm(hidden_dim // 2),
-            nn.SiLU(),
-            nn.Dropout(dropout),
-            nn.Linear(hidden_dim // 2, hidden_dim // 4, bias=False),
-            RMSNorm(hidden_dim // 4),
-            nn.SiLU(),
-            nn.Dropout(dropout),
-            nn.Linear(hidden_dim // 4, output_dim, bias=False)
-        )
-
-    def forward(self, x: torch.Tensor, padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-        # Input projection
-        x_emb = self.input_projection(x)
-
-        # compute the pairwise matrix
-        x_expanded = x.unsqueeze(2)
-        x_transposed = x.unsqueeze(1)
-
-        # use only Euclidean distance (keep it simple)
-        distances = torch.mean((x_expanded - x_transposed) ** 2, dim=-1)
-        pairwise_matrix = torch.exp(-distances)
-
-        if padding_mask is not None:
-            pairwise_mask = padding_mask.unsqueeze(1) | padding_mask.unsqueeze(2)
-            pairwise_matrix = pairwise_matrix.masked_fill(pairwise_mask, 0.0)
-
-        # pass through the layers
-        for layer in self.layers:
-            x_emb = layer(x_emb, pairwise_matrix, padding_mask)
-
-        # pooling
-        pooled = self.pooler(x_emb, padding_mask)
-
-        # classification
-        return self.classifier(pooled)
+import torch
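For context, a minimal construction sketch consistent with the call kept in inference.py (input_dim=..., is_emb=True) and the MusicAudioClassifier constructor shown in the first model.py hunk. The concrete input_dim value is an assumption, and any constructor arguments not visible in these hunks (e.g. hidden_dim, num_classes, backbone) may also need to be supplied.

# Hypothetical sketch (not part of the commit).
# input_dim=768 is an assumed embedding size; only input_dim and is_emb
# appear in the hunks above.
from model import MusicAudioClassifier

clf = MusicAudioClassifier(
    input_dim=768,
    is_emb=True,
)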