slslslrhfem committed
Commit be8ccdd
1 Parent(s): 216b804

commit for ICASSP 2026 github repository

Files changed (2)
  1. inference.py +6 -23
  2. model.py +8 -218
inference.py CHANGED
@@ -18,22 +18,6 @@ from preprocess import get_segments_from_wav, find_optimal_segment_length
 
 
 
-def highpass_filter(y, sr, cutoff=1000, order=5):
-    if isinstance(sr, np.ndarray):
-        sr = np.mean(sr)
-    if not isinstance(sr, (int, float)):
-        raise ValueError(f"sr must be a number, but got {type(sr)}: {sr}")
-
-    nyquist = 0.5 * sr
-    if cutoff <= 0 or cutoff >= nyquist:
-        cutoff = max(10, min(cutoff, nyquist - 1))
-
-    normal_cutoff = cutoff / nyquist
-    b, a = signal.butter(order, normal_cutoff, btype='high', analog=False)
-    y_filtered = signal.lfilter(b, a, y)
-    return y_filtered
-
-
 def load_audio(audio_path: str, sr: int = 24000) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Loads an audio file and splits it into segments.
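For reference, the deleted `highpass_filter` was a plain Butterworth high-pass built on `scipy.signal`. Below is a self-contained sketch of the same helper with the imports it needs (defaults as in the removed code: 1 kHz cutoff, 5th-order filter); it mirrors what this hunk removes and is not new functionality in the repo.

import numpy as np
from scipy import signal

def highpass_filter(y, sr, cutoff=1000, order=5):
    """Butterworth high-pass; mirrors the helper removed above."""
    if isinstance(sr, np.ndarray):                 # tolerate an array-valued sample rate
        sr = float(np.mean(sr))
    nyquist = 0.5 * sr
    cutoff = max(10, min(cutoff, nyquist - 1))     # clamp the cutoff into (0, Nyquist)
    b, a = signal.butter(order, cutoff / nyquist, btype='high', analog=False)
    return signal.lfilter(b, a, y)

Typical use would be y_hp = highpass_filter(y, sr=24000), matching load_audio's 24 kHz default.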
@@ -216,11 +200,11 @@ def inference(audio_path):
     segments = segments.to('cuda').to(torch.float32)
     padding_mask = padding_mask.to('cuda').unsqueeze(0)
     logits,embedding = backbone_model(segments.squeeze(1))
-    test_dataset = FakeMusicCapsDataset([audio_path], [0], target_duration=10.0)
-    test_data, test_target = test_dataset[0]
-    test_data = test_data.to('cuda').to(torch.float32)
-    test_target = test_target.to('cuda')
-    output, _ = backbone_model(test_data.unsqueeze(0))
+    # test_dataset = FakeMusicCapsDataset([audio_path], [0], target_duration=10.0)
+    # test_data, test_target = test_dataset[0]
+    # test_data = test_data.to('cuda').to(torch.float32)
+    # test_target = test_target.to('cuda')
+    # output, _ = backbone_model(test_data.unsqueeze(0))
 
 
 
@@ -230,7 +214,6 @@ def inference(audio_path):
         input_dim=input_dim,
         #emb_model=backbone_model
         is_emb = True,
-        #mode = 'both'
     )
 
 
@@ -247,5 +230,5 @@ def inference(audio_path):
     return results
 
 if __name__ == "__main__":
-    main()
+    inference("some path")
 
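With `main()` removed, the `__main__` guard now calls `inference()` directly on a hard-coded path. A minimal usage sketch, assuming a CUDA device is available (the function moves tensors to 'cuda') and that `inference()` returns the `results` object it builds:

if __name__ == "__main__":
    results = inference("some path")   # replace "some path" with an actual audio file path
    print(results)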
model.py CHANGED
@@ -36,22 +36,12 @@ class MusicAudioClassifier(pl.LightningModule):
                 hidden_dim=hidden_dim,
                 num_classes=num_classes
             )
-        elif backbone == 'guided_segment_transformer':
-            self.model = GuidedSegmentTransformer(
-                input_dim=input_dim,
-                hidden_dim=hidden_dim,
-                num_classes=num_classes
-            )
-        elif backbone == 'ultra_segment_processor':
-            self.model = UltraModernSegmentProcessor(
-                input_dim=input_dim,
-                hidden_dim=hidden_dim,
-                num_classes=num_classes
-            )
-        self.emb_model = emb_model
-        self.learning_rate = learning_rate
-        self.is_emb = is_emb
-        self.num_classes = num_classes
+        # elif backbone == 'guided_segment_transformer':
+        #     self.model = GuidedSegmentTransformer(
+        #         input_dim=input_dim,
+        #         hidden_dim=hidden_dim,
+        #         num_classes=num_classes
+        #     )
 
     def _process_audio_batch(self, x: torch.Tensor) -> torch.Tensor:
         B, S = x.shape[:2]  # [B, S, C, M, T] or [B, S, C, T] for wav, [B, S, 1?, embsize] for emb
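Only the remaining backbone branch is selectable after this change. For orientation, inference.py in this same commit constructs the classifier for pre-computed embeddings; a minimal sketch of that call, assuming the constructor still accepts the keyword arguments visible there (`input_dim`, `is_emb`); the embedding size below is a placeholder:

# Sketch of the inference.py call site; 768 is an assumed embedding dimension.
classifier = MusicAudioClassifier(
    input_dim=768,   # assumed: size of the backbone embeddings, i.e. the [B, S, 1?, embsize] case above
    is_emb=True,     # inputs arrive as embeddings rather than raw waveforms
)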
@@ -529,6 +519,7 @@ class MultiScaleAdaptivePooler(nn.Module):
         Args:
             x: (batch, seq_len, hidden_dim) - sequence features
             padding_mask: (batch, seq_len) - padding mask
+        actually not better than avg pooling haha
         """
         batch_size = x.size(0)
 
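The added docstring line concedes that this pooler is not better than plain average pooling. For comparison, the masked average-pooling baseline it alludes to takes only a few lines; this sketch follows the masking convention used elsewhere in model.py, where `padding_mask` is True on padded positions:

import torch

def masked_avg_pool(x: torch.Tensor, padding_mask: torch.Tensor = None) -> torch.Tensor:
    # x: (batch, seq_len, hidden_dim); padding_mask: (batch, seq_len), True where padded.
    if padding_mask is None:
        return x.mean(dim=1)
    valid = (~padding_mask).float().unsqueeze(-1)                      # (batch, seq_len, 1)
    return (x * valid).sum(dim=1) / valid.sum(dim=1).clamp(min=1.0)    # clamp avoids divide-by-zero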
@@ -838,205 +829,4 @@ class FusionSegmentTransformer(nn.Module):
         pooled = pooled.half()
         return self.classification_head(pooled)
 
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import numpy as np
-from typing import Optional
-import math
-
-class RMSNorm(nn.Module):
-    """RMS Normalization - stable"""
-    def __init__(self, dim: int, eps: float = 1e-6):
-        super().__init__()
-        self.eps = eps
-        self.weight = nn.Parameter(torch.ones(dim))
-
-    def forward(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
-
-class SwiGLU(nn.Module):
-    """SwiGLU Activation - simplified version"""
-    def __init__(self, dim: int):
-        super().__init__()
-        self.w1 = nn.Linear(dim, dim * 2, bias=False)
-        self.w2 = nn.Linear(dim, dim, bias=False)
-
-    def forward(self, x):
-        return self.w2(F.silu(self.w1(x)[:, :, :x.size(-1)]))  # match the dimension
-
-class GroupedQueryAttention(nn.Module):
-    """Simple GQA - avoids errors"""
-    def __init__(self, d_model: int, num_heads: int = 8):
-        super().__init__()
-        assert d_model % num_heads == 0
-
-        self.d_model = d_model
-        self.num_heads = num_heads
-        self.head_dim = d_model // num_heads
-
-        # use the same dimension for all projections
-        self.q_proj = nn.Linear(d_model, d_model, bias=False)
-        self.k_proj = nn.Linear(d_model, d_model, bias=False)
-        self.v_proj = nn.Linear(d_model, d_model, bias=False)
-        self.o_proj = nn.Linear(d_model, d_model, bias=False)
-
-        self.scale = 1.0 / math.sqrt(self.head_dim)
-
-    def forward(self, x, pairwise_matrix=None, padding_mask=None):
-        B, L, D = x.shape
-
-        Q = self.q_proj(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
-        K = self.k_proj(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
-        V = self.v_proj(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
-
-        scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
-
-        if pairwise_matrix is not None:
-            scores = scores + pairwise_matrix.unsqueeze(1)
-
-        if padding_mask is not None:
-            mask_4d = padding_mask.unsqueeze(1).unsqueeze(1).expand(-1, self.num_heads, L, -1)
-            scores = scores.masked_fill(mask_4d, float('-inf'))
-
-        attn_weights = F.softmax(scores, dim=-1)
-        attn_output = torch.matmul(attn_weights, V)
-
-        attn_output = attn_output.transpose(1, 2).contiguous().view(B, L, D)
-        return self.o_proj(attn_output)
-
-class SimpleModernLayer(nn.Module):
-    """Simple and safe modern layer"""
-    def __init__(self, d_model: int, num_heads: int = 8):
-        super().__init__()
-
-        # RMSNorm
-        self.norm1 = RMSNorm(d_model)
-        self.norm2 = RMSNorm(d_model)
-
-        # Attention
-        self.attention = GroupedQueryAttention(d_model, num_heads)
-
-        # Feed forward
-        self.ffn = SwiGLU(d_model)
-
-    def forward(self, x, pairwise_matrix=None, padding_mask=None):
-        # Attention with residual
-        normed_x = self.norm1(x)
-        attn_out = self.attention(normed_x, pairwise_matrix, padding_mask)
-        x = x + attn_out
-
-        # FFN with residual
-        normed_x2 = self.norm2(x)
-        ffn_out = self.ffn(normed_x2)
-        x = x + ffn_out
-
-        return x
-
-class SimpleQuantumPooling(nn.Module):
-    """Simple attention pooling"""
-    def __init__(self, d_model: int):
-        super().__init__()
-
-        # three pooling methods
-        self.attention_pool = nn.MultiheadAttention(d_model, 8, batch_first=True)
-        self.query_token = nn.Parameter(torch.randn(1, 1, d_model) * 0.02)
-
-        # combine
-        self.final_proj = nn.Linear(d_model * 3, d_model, bias=False)
-
-    def forward(self, x, padding_mask=None):
-        batch_size = x.size(0)
-
-        # 1. Average pooling
-        if padding_mask is not None:
-            mask_expanded = (~padding_mask).float().unsqueeze(-1)
-            avg_pooled = (x * mask_expanded).sum(dim=1) / mask_expanded.sum(dim=1)
-        else:
-            avg_pooled = x.mean(dim=1)
-
-        # 2. Max pooling
-        if padding_mask is not None:
-            x_masked = x.clone()
-            x_masked[padding_mask] = float('-inf')
-            max_pooled = x_masked.max(dim=1)[0]
-        else:
-            max_pooled = x.max(dim=1)[0]
-
-        # 3. Attention pooling
-        query = self.query_token.expand(batch_size, -1, -1)
-        attn_pooled, _ = self.attention_pool(
-            query, x, x, key_padding_mask=padding_mask
-        )
-        attn_pooled = attn_pooled.squeeze(1)
-
-        # combine
-        combined = torch.cat([avg_pooled, max_pooled, attn_pooled], dim=-1).half()
-        return self.final_proj(combined)
-
-class UltraModernSegmentProcessor(nn.Module):
-    """Error-free simple version ✅"""
-    def __init__(self,
-                 input_dim: int,
-                 hidden_dim: int = 512,
-                 num_heads: int = 8,
-                 num_layers: int = 6,
-                 dropout: float = 0.1,
-                 max_sequence_length: int = 1000,
-                 num_classes: int = 2):
-        super().__init__()
-
-        assert hidden_dim % num_heads == 0
-
-        self.hidden_dim = hidden_dim
-        self.input_projection = nn.Linear(input_dim, hidden_dim, bias=False)
-
-        # modern layers
-        self.layers = nn.ModuleList([
-            SimpleModernLayer(hidden_dim, num_heads)
-            for _ in range(num_layers)
-        ])
-
-        # simple pooling
-        self.pooler = SimpleQuantumPooling(hidden_dim)
-
-        # classification head
-        output_dim = 1 if num_classes == 2 else num_classes
-
-        self.classifier = nn.Sequential(
-            nn.Linear(hidden_dim, hidden_dim // 2, bias=False),
-            RMSNorm(hidden_dim // 2),
-            nn.SiLU(),
-            nn.Dropout(dropout),
-            nn.Linear(hidden_dim // 2, hidden_dim // 4, bias=False),
-            RMSNorm(hidden_dim // 4),
-            nn.SiLU(),
-            nn.Dropout(dropout),
-            nn.Linear(hidden_dim // 4, output_dim, bias=False)
-        )
-
-    def forward(self, x: torch.Tensor, padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-        # Input projection
-        x_emb = self.input_projection(x)
-
-        # compute the pairwise matrix
-        x_expanded = x.unsqueeze(2)
-        x_transposed = x.unsqueeze(1)
-
-        # use Euclidean distance only (keep it simple)
-        distances = torch.mean((x_expanded - x_transposed) ** 2, dim=-1)
-        pairwise_matrix = torch.exp(-distances)
-
-        if padding_mask is not None:
-            pairwise_mask = padding_mask.unsqueeze(1) | padding_mask.unsqueeze(2)
-            pairwise_matrix = pairwise_matrix.masked_fill(pairwise_mask, 0.0)
-
-        # pass through the layers
-        for layer in self.layers:
-            x_emb = layer(x_emb, pairwise_matrix, padding_mask)
-
-        # pooling
-        pooled = self.pooler(x_emb, padding_mask)
-
-        # classification
-        return self.classifier(pooled)
+import torch
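One note on the removed `SwiGLU` class above: its forward pass sliced the doubled projection and applied SiLU to the slice, which drops the gating that defines SwiGLU. For comparison, a standard gated SwiGLU feed-forward looks like the sketch below (generic reference code, not code from this repository):

import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedSwiGLU(nn.Module):
    # SwiGLU FFN: out = W_out( SiLU(x W_gate) * (x W_value) )
    def __init__(self, dim: int, hidden: int = None):
        super().__init__()
        hidden = hidden or dim * 2
        self.w_gate = nn.Linear(dim, hidden, bias=False)
        self.w_value = nn.Linear(dim, hidden, bias=False)
        self.w_out = nn.Linear(hidden, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w_out(F.silu(self.w_gate(x)) * self.w_value(x))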