Seonghyeon Go committed on
Commit 0ede85b · 1 Parent(s): c3c908f

initial commit for AIGM

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+ checkpoints/*.ckpt filter=lfs diff=lfs merge=lfs -text
ISMIR_2025/MERT/__pycache__/networks.cpython-312.pyc CHANGED
Binary files a/ISMIR_2025/MERT/__pycache__/networks.cpython-312.pyc and b/ISMIR_2025/MERT/__pycache__/networks.cpython-312.pyc differ
 
app.py CHANGED
@@ -9,6 +9,7 @@ def detect_ai_audio(audio_file):
    Detect whether the uploaded audio file was generated by AI
    """
    result = inference(audio_file)
+   print(result)

    # Format result with better styling
    if "AI" in str(result).upper() or "artificial" in str(result).lower():
@@ -167,7 +168,7 @@ demo = gr.Interface(
    """,
    examples=[
        ["example-ncs-light it up(human).mp3"],
-       ["example-Fading Memories(suno v3.5).wav"]
+       ["example-Strumming Heartbeats(suno v4).mp3"]
    ],
    css=custom_css,
    theme=gr.themes.Soft(
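
For context, a minimal sketch of how the changed pieces of app.py plug into Gradio. This is hedged: the real app.py also passes custom_css and a fuller gr.themes.Soft(...) configuration that this hunk only hints at, and outputs="text" is an assumption.

# Hypothetical minimal wiring, assuming detect_ai_audio() returns a displayable string.
import gradio as gr
from inference import inference

def detect_ai_audio(audio_file):
    result = inference(audio_file)
    print(result)  # debug print added in this commit
    return str(result)

demo = gr.Interface(
    fn=detect_ai_audio,
    inputs=gr.Audio(type="filepath"),
    outputs="text",
    examples=[
        ["example-ncs-light it up(human).mp3"],
        ["example-Strumming Heartbeats(suno v4).mp3"],
    ],
    theme=gr.themes.Soft(),
)

if __name__ == "__main__":
    demo.launch()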
dataset_f.py CHANGED
@@ -4,13 +4,9 @@ import torch
import torchaudio
import librosa
import numpy as np
- from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
- from imblearn.over_sampling import RandomOverSampler
- from transformers import Wav2Vec2Processor
import torch
import torchaudio
- from torch.nn.utils.rnn import pad_sequence
from transformers import Wav2Vec2FeatureExtractor
import scipy.signal as signal
import scipy.signal
inference.py CHANGED
@@ -12,7 +12,7 @@ import torchaudio
import scipy.signal as signal
from typing import Dict, List
from dataset_f import FakeMusicCapsDataset
-
+ from networks import MERT_AudioCNN
from preprocess import get_segments_from_wav, find_optimal_segment_length


@@ -149,7 +149,7 @@ def run_inference(model, audio_segments: torch.Tensor, padding_mask: torch.Tensor
    # Convert the data to half precision
    if padding_mask.dim() == 1:
        padding_mask = padding_mask.unsqueeze(0)  # [48] -> [1, 48]
-   audio_segments = audio_segments.to(device).half()
+   audio_segments = audio_segments.to(device)

    mask = padding_mask.to(device)

@@ -189,14 +189,14 @@ def scaled_sigmoid(x, scale_factor=0.2, linear_property=0.3):
def get_model(model_type, device):
    """Load the specified model."""
    if model_type == "MERT":
-       from ISMIR_2025.MERT.networks import CCV
        #from model import MusicAudioClassifier
-
-       model = CCV(embed_dim=768, num_heads=8, num_layers=6, num_classes=2, freeze_feature_extractor=True).to(device)
        #model = MusicAudioClassifier(input_dim=768, is_emb=True, mode = 'both', share_parameter = False).to(device)
-       ckpt_file = 'mert_finetune_10.pth'
-       model.load_state_dict(torch.load(ckpt_file, map_location=device))
+       ckpt_file = 'checkpoints/step=007000-val_loss=0.1831-val_acc=0.9278.ckpt'  # 'mert_finetune_10.pth'
+       model = MERT_AudioCNN.load_from_checkpoint(ckpt_file).to(device)
+       model.eval()
+       # model.load_state_dict(torch.load(ckpt_file, map_location=device))
        embed_dim = 768
+
    elif model_type == "pure_MERT":
        from ISMIR_2025.MERT.networks import MERTFeatureExtractor
        model = MERTFeatureExtractor().to(device)
@@ -211,33 +211,22 @@ def get_model(model_type, device):


def inference(audio_path):
-   parser = argparse.ArgumentParser(description="Music classifier inference")
-   parser.add_argument("--model_type", type=str, required=True, choices=["MERT", "AudioCNN"], help="Type of model")
-   parser.add_argument("--checkpoint_path", type=str, required=True, help="Path to model checkpoint")
-   parser.add_argument("--output_path", type=str, default=None, help="Path to save results (default: print to console)")
-   parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to run inference on")
-   args = parser.parse_args()
-   audio_path = "The Chainsmokers & Coldplay - Something Just Like This (Lyric).mp3"
-
-
-   # Note: Model loading would be handled by your code
-   print(f"Loading model of type {args.model_type} from {args.checkpoint_path}")
-
    backbone_model, input_dim = get_model('MERT', 'cuda')
    segments, padding_mask = load_audio(audio_path, sr=24000)
-   segments = segments.to(args.device).to(torch.float32)
-   padding_mask = padding_mask.to(args.device).unsqueeze(0)
+   segments = segments.to('cuda').to(torch.float32)
+   padding_mask = padding_mask.to('cuda').unsqueeze(0)
    logits, embedding = backbone_model(segments.squeeze(1))
    test_dataset = FakeMusicCapsDataset([audio_path], [0], target_duration=10.0)
    test_data, test_target = test_dataset[0]
-   test_data = test_data.to(args.device).to(torch.float32)
-   test_target = test_target.to(args.device)
+   test_data = test_data.to('cuda').to(torch.float32)
+   test_target = test_target.to('cuda')
    output, _ = backbone_model(test_data.unsqueeze(0))
+


    # Load the classifier model from checkpoint
    model = MusicAudioClassifier.load_from_checkpoint(
-       args.checkpoint_path,
+       checkpoint_path='checkpoints/EmbeddingModel_MERT_768-epoch=0073-val_loss=0.1058-val_acc=0.9585-val_f1=0.9366-val_precision=0.9936-val_recall=0.8857.ckpt',
        input_dim=input_dim,
        #emb_model=backbone_model
        is_emb = True,
@@ -248,16 +237,13 @@ def inference(audio_path):
    # Run inference
    print(f"Segments shape: {segments.shape}")
    print("Running inference...")
-   results = run_inference(model, embedding, padding_mask, device=args.device)
+   results = run_inference(model, embedding, padding_mask, 'cuda')

    # Print the results
    print(f"Results: {results}")

-   # Save the results
-   if args.output_path:
-       with open(args.output_path, 'w') as f:
-           json.dump(results, f, indent=4)
-       print(f"Results saved to {args.output_path}")
+

    return results

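After this change, inference() no longer parses command-line arguments and can be called directly with a file path. A minimal usage sketch, assuming the two checkpoints referenced above exist under checkpoints/ and a CUDA device is available (the example file name is one of the demo clips bundled with app.py):

# Hypothetical usage; inference() hard-codes 'cuda' and its checkpoint paths internally.
import torch
from inference import inference

assert torch.cuda.is_available(), "inference() moves tensors to 'cuda' unconditionally"

results = inference("example-Strumming Heartbeats(suno v4).mp3")
print(results)
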
model.py ADDED
@@ -0,0 +1,1042 @@
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import pytorch_lightning as pl
6
+ from torch.utils.data import Dataset, DataLoader
7
+ from typing import List, Tuple, Optional
8
+ import numpy as np
9
+ from pathlib import Path
10
+ import math
11
+ # from deepspeed.ops.adam import FusedAdam # 호환성 문제로 비활성화
12
+
13
+
14
+ class MusicAudioClassifier(pl.LightningModule):
15
+ def __init__(self,
16
+ input_dim: int,
17
+ hidden_dim: int = 256,
18
+ learning_rate: float = 1e-4,
19
+ emb_model: Optional[nn.Module] = None,
20
+ is_emb: bool = False,
21
+ backbone: str = 'segment_transformer',
22
+ num_classes: int = 2):
23
+ super().__init__()
24
+ self.save_hyperparameters()
25
+
26
+ if backbone == 'segment_transformer':
27
+ self.model = SegmentTransformer(
28
+ input_dim=input_dim,
29
+ hidden_dim=hidden_dim,
30
+ num_classes=num_classes,
31
+ mode = 'both'
32
+ )
33
+ elif backbone == 'fusion_segment_transformer':
34
+ self.model = FusionSegmentTransformer(
35
+ input_dim=input_dim,
36
+ hidden_dim=hidden_dim,
37
+ num_classes=num_classes
38
+ )
39
+ elif backbone == 'guided_segment_transformer':
40
+ self.model = GuidedSegmentTransformer(
41
+ input_dim=input_dim,
42
+ hidden_dim=hidden_dim,
43
+ num_classes=num_classes
44
+ )
45
+ elif backbone == 'ultra_segment_processor':
46
+ self.model = UltraModernSegmentProcessor(
47
+ input_dim=input_dim,
48
+ hidden_dim=hidden_dim,
49
+ num_classes=num_classes
50
+ )
51
+ self.emb_model = emb_model
52
+ self.learning_rate = learning_rate
53
+ self.is_emb = is_emb
54
+ self.num_classes = num_classes
55
+
56
+ def _process_audio_batch(self, x: torch.Tensor) -> torch.Tensor:
57
+ B, S = x.shape[:2] # [B, S, C, M, T] or [B, S, C, T] for wav, [B, S, 1?, embsize] for emb
58
+ x = x.view(B*S, *x.shape[2:]) # [B*S, C, M, T]
59
+ if self.is_emb == False:
60
+ _, embeddings = self.emb_model(x) # [B*S, emb_dim]
61
+ else:
62
+ embeddings = x
63
+ if embeddings.dim() == 3:
64
+ pooled_features = embeddings.mean(dim=1) # transformer
65
+ else:
66
+ pooled_features = embeddings # CCV..? no need to pooling
67
+ return pooled_features.view(B, S, -1) # [B, S, emb_dim]
68
+
69
+ def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
70
+ x = self._process_audio_batch(x) # 이걸 freeze하고 쓰는게 사실상 윗버전임
71
+ x = x.half()
72
+ return self.model(x, mask)
73
+
74
+ def _compute_loss_and_probs(self, y_hat: torch.Tensor, y: torch.Tensor):
75
+ """Compute loss and probabilities based on number of classes"""
76
+ if y_hat.size(0) == 1:
77
+ y_hat_flat = y_hat.flatten()
78
+ y_flat = y.flatten()
79
+ else:
80
+ y_hat_flat = y_hat.squeeze() if self.num_classes == 2 else y_hat
81
+ y_flat = y
82
+
83
+ if self.num_classes == 2:
84
+ loss = F.binary_cross_entropy_with_logits(y_hat_flat, y_flat.float())
85
+ probs = torch.sigmoid(y_hat_flat)
86
+ preds = (probs > 0.5).long()
87
+ else:
88
+ loss = F.cross_entropy(y_hat_flat, y_flat.long())
89
+ probs = F.softmax(y_hat_flat, dim=-1)
90
+ preds = torch.argmax(y_hat_flat, dim=-1)
91
+
92
+ return loss, probs, preds, y_flat.long()
93
+
94
+ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
95
+ x, y, mask = batch
96
+ x = x.half()
97
+ y_hat = self(x, mask)
98
+
99
+ loss, probs, preds, y_true = self._compute_loss_and_probs(y_hat, y)
100
+
101
+ # 간단한 배치 손실만 로깅 (step 수준)
102
+ self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
103
+
104
+ # 전체 에폭에 대한 메트릭 계산을 위해 예측과 실제값 저장
105
+ if self.num_classes == 2:
106
+ self.training_step_outputs.append({'preds': probs, 'targets': y_true, 'binary_preds': preds})
107
+ else:
108
+ self.training_step_outputs.append({'probs': probs, 'preds': preds, 'targets': y_true})
109
+
110
+ return loss
111
+
112
+ def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int) -> None:
113
+ x, y, mask = batch
114
+ x = x.half()
115
+ y_hat = self(x, mask)
116
+
117
+ loss, probs, preds, y_true = self._compute_loss_and_probs(y_hat, y)
118
+
119
+ # 간단한 배치 손실만 로깅 (step 수준)
120
+ self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
121
+
122
+ # 전체 에폭에 대한 메트릭 계산을 위해 예측과 실제값 저장
123
+ if self.num_classes == 2:
124
+ self.validation_step_outputs.append({'preds': probs, 'targets': y_true, 'binary_preds': preds})
125
+ else:
126
+ self.validation_step_outputs.append({'probs': probs, 'preds': preds, 'targets': y_true})
127
+
128
+ def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int) -> None:
129
+ x, y, mask = batch
130
+ x = x.half()
131
+ y_hat = self(x, mask)
132
+
133
+ loss, probs, preds, y_true = self._compute_loss_and_probs(y_hat, y)
134
+
135
+ # 간단한 배치 손실만 로깅 (step 수준)
136
+ self.log('test_loss', loss, on_epoch=True, prog_bar=True)
137
+
138
+ # 전체 에폭에 대한 메트릭 계산을 위해 예측과 실제값 저장
139
+ if self.num_classes == 2:
140
+ self.test_step_outputs.append({'preds': probs, 'targets': y_true, 'binary_preds': preds})
141
+ else:
142
+ self.test_step_outputs.append({'probs': probs, 'preds': preds, 'targets': y_true})
143
+
144
+ def on_train_epoch_start(self):
145
+ # 에폭 시작 시 결과 저장용 리스트 초기화
146
+ self.training_step_outputs = []
147
+
148
+ def on_validation_epoch_start(self):
149
+ # 에폭 시작 시 결과 저장용 리스트 초기화
150
+ self.validation_step_outputs = []
151
+
152
+ def on_test_epoch_start(self):
153
+ # 에폭 시작 시 결과 저장용 리스트 초기화
154
+ self.test_step_outputs = []
155
+
156
+ def _compute_binary_metrics(self, outputs, prefix):
157
+ """Binary classification metrics computation"""
158
+ all_preds = torch.cat([x['preds'] for x in outputs])
159
+ all_targets = torch.cat([x['targets'] for x in outputs])
160
+ binary_preds = torch.cat([x['binary_preds'] for x in outputs])
161
+
162
+ # 정확도 계산
163
+ acc = (binary_preds == all_targets).float().mean()
164
+
165
+ # 혼동 행렬 요소 계산
166
+ tp = torch.sum((binary_preds == 1) & (all_targets == 1)).float()
167
+ fp = torch.sum((binary_preds == 1) & (all_targets == 0)).float()
168
+ tn = torch.sum((binary_preds == 0) & (all_targets == 0)).float()
169
+ fn = torch.sum((binary_preds == 0) & (all_targets == 1)).float()
170
+
171
+ # 메트릭 계산
172
+ precision = tp / (tp + fp) if (tp + fp) > 0 else torch.tensor(0.0).to(tp.device)
173
+ recall = tp / (tp + fn) if (tp + fn) > 0 else torch.tensor(0.0).to(tp.device)
174
+ f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else torch.tensor(0.0).to(tp.device)
175
+ specificity = tn / (tn + fp) if (tn + fp) > 0 else torch.tensor(0.0).to(tn.device)
176
+
177
+ # 로깅
178
+ self.log(f'{prefix}_acc', acc, on_epoch=True, prog_bar=True, sync_dist=True)
179
+ self.log(f'{prefix}_precision', precision, on_epoch=True, sync_dist=True)
180
+ self.log(f'{prefix}_recall', recall, on_epoch=True, sync_dist=True)
181
+ self.log(f'{prefix}_f1', f1, on_epoch=True, prog_bar=True, sync_dist=True)
182
+ self.log(f'{prefix}_specificity', specificity, on_epoch=True, sync_dist=True)
183
+
184
+ if prefix in ['val', 'test']:
185
+ # ROC-AUC 계산 (간단한 근사)
186
+ sorted_indices = torch.argsort(all_preds, descending=True)
187
+ sorted_targets = all_targets[sorted_indices]
188
+
189
+ n_pos = torch.sum(all_targets)
190
+ n_neg = len(all_targets) - n_pos
191
+
192
+ if n_pos > 0 and n_neg > 0:
193
+ tpr_curve = torch.cumsum(sorted_targets, dim=0) / n_pos
194
+ fpr_curve = torch.cumsum(1 - sorted_targets, dim=0) / n_neg
195
+
196
+ width = fpr_curve[1:] - fpr_curve[:-1]
197
+ height = (tpr_curve[1:] + tpr_curve[:-1]) / 2
198
+ auc_approx = torch.sum(width * height)
199
+
200
+ self.log(f'{prefix}_auc', auc_approx, on_epoch=True, sync_dist=True)
201
+
202
+ if prefix == 'test':
203
+ balanced_acc = (recall + specificity) / 2
204
+ self.log('test_balanced_acc', balanced_acc, on_epoch=True)
205
+
206
+ def _compute_multiclass_metrics(self, outputs, prefix):
207
+ """Multi-class classification metrics computation"""
208
+ all_probs = torch.cat([x['probs'] for x in outputs])
209
+ all_preds = torch.cat([x['preds'] for x in outputs])
210
+ all_targets = torch.cat([x['targets'] for x in outputs])
211
+
212
+ # 전체 정확도
213
+ acc = (all_preds == all_targets).float().mean()
214
+ self.log(f'{prefix}_acc', acc, on_epoch=True, prog_bar=True, sync_dist=True)
215
+
216
+ # 클래스별 메트릭 계산
217
+ for class_idx in range(self.num_classes):
218
+ # 각 클래스에 대한 이진 분류 메트릭
219
+ class_targets = (all_targets == class_idx).long()
220
+ class_preds = (all_preds == class_idx).long()
221
+
222
+ tp = torch.sum((class_preds == 1) & (class_targets == 1)).float()
223
+ fp = torch.sum((class_preds == 1) & (class_targets == 0)).float()
224
+ tn = torch.sum((class_preds == 0) & (class_targets == 0)).float()
225
+ fn = torch.sum((class_preds == 0) & (class_targets == 1)).float()
226
+
227
+ precision = tp / (tp + fp) if (tp + fp) > 0 else torch.tensor(0.0).to(tp.device)
228
+ recall = tp / (tp + fn) if (tp + fn) > 0 else torch.tensor(0.0).to(tp.device)
229
+ f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else torch.tensor(0.0).to(tp.device)
230
+
231
+ self.log(f'{prefix}_class_{class_idx}_precision', precision, on_epoch=True)
232
+ self.log(f'{prefix}_class_{class_idx}_recall', recall, on_epoch=True)
233
+ self.log(f'{prefix}_class_{class_idx}_f1', f1, on_epoch=True)
234
+
235
+ # 매크로 평균 F1 스코어
236
+ class_f1_scores = []
237
+ for class_idx in range(self.num_classes):
238
+ class_targets = (all_targets == class_idx).long()
239
+ class_preds = (all_preds == class_idx).long()
240
+
241
+ tp = torch.sum((class_preds == 1) & (class_targets == 1)).float()
242
+ fp = torch.sum((class_preds == 1) & (class_targets == 0)).float()
243
+ fn = torch.sum((class_preds == 0) & (class_targets == 1)).float()
244
+
245
+ precision = tp / (tp + fp) if (tp + fp) > 0 else torch.tensor(0.0).to(tp.device)
246
+ recall = tp / (tp + fn) if (tp + fn) > 0 else torch.tensor(0.0).to(tp.device)
247
+ f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else torch.tensor(0.0).to(tp.device)
248
+
249
+ class_f1_scores.append(f1)
250
+
251
+ macro_f1 = torch.stack(class_f1_scores).mean()
252
+ self.log(f'{prefix}_macro_f1', macro_f1, on_epoch=True, prog_bar=True, sync_dist=True)
253
+
254
+ def on_train_epoch_end(self):
255
+ if not hasattr(self, 'training_step_outputs') or not self.training_step_outputs:
256
+ return
257
+
258
+ if self.num_classes == 2:
259
+ self._compute_binary_metrics(self.training_step_outputs, 'train')
260
+ else:
261
+ self._compute_multiclass_metrics(self.training_step_outputs, 'train')
262
+
263
+ def on_validation_epoch_end(self):
264
+ if not hasattr(self, 'validation_step_outputs') or not self.validation_step_outputs:
265
+ return
266
+
267
+ if self.num_classes == 2:
268
+ self._compute_binary_metrics(self.validation_step_outputs, 'val')
269
+ else:
270
+ self._compute_multiclass_metrics(self.validation_step_outputs, 'val')
271
+
272
+ def on_test_epoch_end(self):
273
+ if not hasattr(self, 'test_step_outputs') or not self.test_step_outputs:
274
+ return
275
+
276
+ if self.num_classes == 2:
277
+ self._compute_binary_metrics(self.test_step_outputs, 'test')
278
+ else:
279
+ self._compute_multiclass_metrics(self.test_step_outputs, 'test')
280
+
281
+ def configure_optimizers(self):
282
+ # FusedAdam 대신 일반 AdamW 사용 (GLIBC 호환성 문제 해결)
283
+ optimizer = torch.optim.AdamW(
284
+ self.parameters(),
285
+ lr=self.learning_rate,
286
+ weight_decay=0.01
287
+ )
288
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
289
+ optimizer,
290
+ T_max=100, # Adjust based on your training epochs
291
+ eta_min=1e-6
292
+ )
293
+
294
+ return {
295
+ 'optimizer': optimizer,
296
+ 'lr_scheduler': scheduler,
297
+ 'monitor': 'val_loss',
298
+ }
299
+
300
+
301
+ def pad_sequence_with_mask(batch: List[Tuple[torch.Tensor, int]]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
302
+ """Collate function for DataLoader that creates padded sequences and attention masks with fixed length (48)."""
303
+ embeddings, labels = zip(*batch)
304
+ fixed_len = 48 # 고정 길이
305
+
306
+ batch_size = len(embeddings)
307
+ feat_dim = embeddings[0].shape[-1]
308
+
309
+ padded = torch.zeros((batch_size, fixed_len, feat_dim)) # 고정 길이로 패딩된 텐서
310
+ mask = torch.ones((batch_size, fixed_len), dtype=torch.bool) # True는 padding을 의미
311
+
312
+ for i, emb in enumerate(embeddings):
313
+ length = emb.shape[0]
314
+
315
+ # 길이가 고정 길이보다 길면 자르고, 짧으면 패딩
316
+ if length > fixed_len:
317
+ padded[i, :] = emb[:fixed_len] # fixed_len보다 긴 부분을 잘라서 채운다.
318
+ mask[i, :] = False
319
+ else:
320
+ padded[i, :length] = emb # 실제 데이터 길이에 맞게 채운다.
321
+ mask[i, :length] = False # 패딩이 아닌 부분은 False로 설정
322
+
323
+ return padded, torch.tensor(labels), mask
324
+
325
+
326
+ class SegmentTransformer(nn.Module):
327
+ def __init__(self,
328
+ input_dim: int,
329
+ hidden_dim: int = 256,
330
+ num_heads: int = 8,
331
+ num_layers: int = 4,
332
+ dropout: float = 0.1,
333
+ max_sequence_length: int = 1000,
334
+ mode: str = 'both',
335
+ share_parameter: bool = False,
336
+ num_classes: int = 2):
337
+ super().__init__()
338
+
339
+ # Original sequence processing
340
+ self.input_projection = nn.Linear(input_dim, hidden_dim)
341
+ self.mode = mode
342
+ self.share_parameter = share_parameter
343
+ self.num_classes = num_classes
344
+
345
+ # Positional encoding
346
+ position = torch.arange(max_sequence_length).unsqueeze(1)
347
+ div_term = torch.exp(torch.arange(0, hidden_dim, 2) * (-np.log(10000.0) / hidden_dim))
348
+ pos_encoding = torch.zeros(max_sequence_length, hidden_dim)
349
+ pos_encoding[:, 0::2] = torch.sin(position * div_term)
350
+ pos_encoding[:, 1::2] = torch.cos(position * div_term)
351
+ self.register_buffer('pos_encoding', pos_encoding)
352
+
353
+ # Transformer for original sequence
354
+ encoder_layer = nn.TransformerEncoderLayer(
355
+ d_model=hidden_dim,
356
+ nhead=num_heads,
357
+ dim_feedforward=hidden_dim * 4,
358
+ dropout=dropout,
359
+ batch_first=True
360
+ )
361
+ self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
362
+ self.sim_transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
363
+
364
+ # Self-similarity stream processing
365
+ self.similarity_projection = nn.Sequential(
366
+ nn.Conv1d(1, hidden_dim // 2, kernel_size=3, padding=1),
367
+ nn.ReLU(),
368
+ nn.Conv1d(hidden_dim // 2, hidden_dim, kernel_size=3, padding=1),
369
+ nn.ReLU(),
370
+ nn.Dropout(dropout)
371
+ )
372
+
373
+ # Transformer for similarity stream
374
+ self.similarity_transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
375
+
376
+ # Final classification head
377
+ self.classification_head_dim = hidden_dim * 2 if mode == 'both' else hidden_dim
378
+
379
+ # Output dimension based on number of classes
380
+ output_dim = 1 if num_classes == 2 else num_classes
381
+
382
+ self.classification_head = nn.Sequential(
383
+ nn.Linear(self.classification_head_dim, hidden_dim),
384
+ nn.LayerNorm(hidden_dim),
385
+ nn.ReLU(),
386
+ nn.Dropout(dropout),
387
+ nn.Linear(hidden_dim, hidden_dim // 2),
388
+ nn.LayerNorm(hidden_dim // 2),
389
+ nn.ReLU(),
390
+ nn.Dropout(dropout),
391
+ nn.Linear(hidden_dim // 2, output_dim)
392
+ )
393
+
394
+ def forward(self, x: torch.Tensor, padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
395
+ batch_size, seq_len, _ = x.shape
396
+
397
+ # 1. Process original sequence
398
+ x = x.half()
399
+ x1 = self.input_projection(x)
400
+ x1 = x1 + self.pos_encoding[:seq_len].unsqueeze(0)
401
+ x1 = self.transformer(x1, src_key_padding_mask=padding_mask) # padding_mask 사용
402
+
403
+ # 2. Calculate and process self-similarity
404
+ x_expanded = x.unsqueeze(2)
405
+ x_transposed = x.unsqueeze(1)
406
+ distances = torch.mean((x_expanded - x_transposed) ** 2, dim=-1)
407
+ similarity_matrix = torch.exp(-distances) # (batch_size, seq_len, seq_len)
408
+
409
+ # 자기 유사도 마스크 생성 및 적용 (각 시점에 대한 마스크 개별 적용)
410
+ if padding_mask is not None:
411
+ similarity_mask = padding_mask.unsqueeze(1) | padding_mask.unsqueeze(2) # (batch_size, seq_len, seq_len)
412
+ similarity_matrix = similarity_matrix.masked_fill(similarity_mask, 0.0)
413
+
414
+ # Process similarity matrix row by row using Conv1d
415
+ x2 = similarity_matrix.unsqueeze(1) # (batch_size, 1, seq_len, seq_len)
416
+ x2 = x2.view(batch_size * seq_len, 1, seq_len) # Reshape for Conv1d
417
+ x2 = self.similarity_projection(x2) # (batch_size * seq_len, hidden_dim, seq_len)
418
+ x2 = x2.mean(dim=2) # Pool across sequence dimension
419
+ x2 = x2.view(batch_size, seq_len, -1) # Reshape back
420
+
421
+ x2 = x2 + self.pos_encoding[:seq_len].unsqueeze(0)
422
+ if self.share_parameter:
423
+ x2 = self.transformer(x2, src_key_padding_mask=padding_mask)
424
+ else:
425
+ x2 = self.sim_transformer(x2, src_key_padding_mask=padding_mask) # padding_mask 사용
426
+
427
+ # 3. Global average pooling for both streams
428
+ if padding_mask is not None:
429
+ mask_expanded = (~padding_mask).float().unsqueeze(-1)
430
+ x1 = (x1 * mask_expanded).sum(dim=1) / mask_expanded.sum(dim=1)
431
+ x2 = (x2 * mask_expanded).sum(dim=1) / mask_expanded.sum(dim=1)
432
+ else:
433
+ x1 = x1.mean(dim=1)
434
+ x2 = x2.mean(dim=1)
435
+
436
+ # 4. Combine both streams and classify
437
+ if self.mode == 'only_emb':
438
+ x = x1
439
+ elif self.mode == 'only_structure':
440
+ x = x2
441
+ elif self.mode == 'both':
442
+ x = torch.cat([x1, x2], dim=-1)
443
+ x= x.half()
444
+ return self.classification_head(x)
445
+
446
+
447
+ class PairwiseGuidedTransformer(nn.Module):
448
+ """Pairwise similarity matrix를 활용한 범용 transformer layer
449
+
450
+ Vision: patch간 유사도, NLP: token간 유사도, Audio: segment간 유사도 등에 활용 가능
451
+ """
452
+ def __init__(self, d_model: int, num_heads: int = 8):
453
+ super().__init__()
454
+ self.d_model = d_model
455
+ self.num_heads = num_heads
456
+
457
+ # Standard Q, K projections
458
+ self.q_proj = nn.Linear(d_model, d_model)
459
+ self.k_proj = nn.Linear(d_model, d_model)
460
+
461
+ # Pairwise-guided V projection
462
+ self.v_proj = nn.Linear(d_model, d_model)
463
+
464
+ self.output_proj = nn.Linear(d_model, d_model)
465
+ self.norm = nn.LayerNorm(d_model)
466
+
467
+ def forward(self, x, pairwise_matrix, padding_mask=None):
468
+ """
469
+ Args:
470
+ x: (batch, seq_len, d_model) - sequence embeddings
471
+ pairwise_matrix: (batch, seq_len, seq_len) - pairwise similarity/distance matrix
472
+ padding_mask: (batch, seq_len) - padding mask
473
+ """
474
+ batch_size, seq_len, d_model = x.shape
475
+
476
+ # Standard Q, K, V
477
+ Q = self.q_proj(x)
478
+ K = self.k_proj(x)
479
+ V = self.v_proj(x)
480
+
481
+ # Reshape for multi-head
482
+ Q = Q.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)
483
+ K = K.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)
484
+ V = V.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)
485
+
486
+ # Standard attention scores
487
+ scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_model ** 0.5)
488
+
489
+ # ✅ Combine with pairwise matrix
490
+ #pairwise_expanded = pairwise_matrix.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
491
+ enhanced_scores = scores# + pairwise_expanded 이거 빼고 하기로 했죠?
492
+
493
+ # Apply padding mask
494
+ if padding_mask is not None:
495
+ mask_4d = padding_mask.unsqueeze(1).unsqueeze(1).expand(-1, self.num_heads, seq_len, -1)
496
+ enhanced_scores = enhanced_scores.masked_fill(mask_4d, float('-inf'))
497
+
498
+ # Softmax and apply to V
499
+ attn_weights = F.softmax(enhanced_scores, dim=-1)
500
+ attended = torch.matmul(attn_weights, V)
501
+
502
+ # Reshape and project
503
+ attended = attended.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
504
+ output = self.output_proj(attended)
505
+
506
+ return self.norm(x + output)
507
+
508
+
509
+ class MultiScaleAdaptivePooler(nn.Module):
510
+ """Multi-scale adaptive pooling - 다양한 도메인에서 활용 가능"""
511
+
512
+ def __init__(self, hidden_dim: int, num_heads: int = 8):
513
+ super().__init__()
514
+
515
+ # Attention-based pooling
516
+ self.attention_pool = nn.MultiheadAttention(
517
+ hidden_dim, num_heads=num_heads, batch_first=True
518
+ )
519
+ self.query_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
520
+
521
+ # Complementary pooling strategies
522
+ self.max_pool_proj = nn.Linear(hidden_dim, hidden_dim)
523
+
524
+ self.fusion = nn.Linear(hidden_dim * 3, hidden_dim)
525
+
526
+
527
+ def forward(self, x, padding_mask=None):
528
+ """
529
+ Args:
530
+ x: (batch, seq_len, hidden_dim) - sequence features
531
+ padding_mask: (batch, seq_len) - padding mask
532
+ """
533
+ batch_size = x.size(0)
534
+
535
+ # 1. Global average pooling
536
+ if padding_mask is not None:
537
+ mask_expanded = (~padding_mask).float().unsqueeze(-1)
538
+ global_avg = (x * mask_expanded).sum(dim=1) / mask_expanded.sum(dim=1)
539
+ else:
540
+ global_avg = x.mean(dim=1)
541
+
542
+ # # 2. Global max pooling
543
+ # if padding_mask is not None:
544
+ # x_masked = x.clone()
545
+ # x_masked[padding_mask] = float('-inf')
546
+ # global_max = x_masked.max(dim=1)[0]
547
+ # else:
548
+ # global_max = x.max(dim=1)[0]
549
+
550
+ # global_max = self.max_pool_proj(global_max)
551
+
552
+ # # 3. Attention-based pooling
553
+ # query = self.query_token.expand(batch_size, -1, -1)
554
+ # attn_pooled, _ = self.attention_pool(
555
+ # query, x, x,
556
+ # key_padding_mask=padding_mask
557
+ # )
558
+ # attn_pooled = attn_pooled.squeeze(1)
559
+
560
+ # # 4. Fuse all pooling results
561
+ # #combined = torch.cat([global_avg, global_max, attn_pooled], dim=-1)
562
+ # #output = self.fusion(combined)
563
+ output = global_avg
564
+ return output
565
+
566
+
567
+ class GuidedSegmentTransformer(nn.Module):
568
+ def __init__(self,
569
+ input_dim: int,
570
+ hidden_dim: int = 256,
571
+ num_heads: int = 8,
572
+ num_layers: int = 4,
573
+ dropout: float = 0.1,
574
+ max_sequence_length: int = 1000,
575
+ mode: str = 'only_emb',
576
+ share_parameter: bool = False,
577
+ num_classes: int = 2):
578
+ super().__init__()
579
+
580
+ # Original sequence processing
581
+ self.input_projection = nn.Linear(input_dim, hidden_dim)
582
+ self.mode = mode
583
+ self.share_parameter = share_parameter
584
+ self.num_classes = num_classes
585
+
586
+ # Positional encoding
587
+ position = torch.arange(max_sequence_length).unsqueeze(1)
588
+ div_term = torch.exp(torch.arange(0, hidden_dim, 2) * (-np.log(10000.0) / hidden_dim))
589
+ pos_encoding = torch.zeros(max_sequence_length, hidden_dim)
590
+ pos_encoding[:, 0::2] = torch.sin(position * div_term)
591
+ pos_encoding[:, 1::2] = torch.cos(position * div_term)
592
+ self.register_buffer('pos_encoding', pos_encoding)
593
+
594
+ # ✅ Pairwise-guided transformer layers (범용적 이름)
595
+ self.pairwise_guided_layers = nn.ModuleList([
596
+ PairwiseGuidedTransformer(hidden_dim, num_heads)
597
+ for _ in range(num_layers)
598
+ ])
599
+
600
+ # Pairwise matrix processing (기존 similarity processing)
601
+ self.pairwise_projection = nn.Sequential(
602
+ nn.Conv1d(1, hidden_dim // 2, kernel_size=3, padding=1),
603
+ nn.ReLU(),
604
+ nn.Conv1d(hidden_dim // 2, hidden_dim, kernel_size=3, padding=1),
605
+ nn.ReLU(),
606
+ nn.Dropout(dropout)
607
+ )
608
+
609
+ # ✅ Multi-scale adaptive pooling (범용적 이름)
610
+ self.adaptive_pooler = MultiScaleAdaptivePooler(hidden_dim, num_heads)
611
+
612
+ # Final classification head
613
+ self.classification_head_dim = hidden_dim * 2 if mode == 'both' else hidden_dim
614
+ output_dim = 1 if num_classes == 2 else num_classes
615
+
616
+ self.classification_head = nn.Sequential(
617
+ nn.Linear(self.classification_head_dim, hidden_dim),
618
+ nn.LayerNorm(hidden_dim),
619
+ nn.ReLU(),
620
+ nn.Dropout(dropout),
621
+ nn.Linear(hidden_dim, hidden_dim // 2),
622
+ nn.LayerNorm(hidden_dim // 2),
623
+ nn.ReLU(),
624
+ nn.Dropout(dropout),
625
+ nn.Linear(hidden_dim // 2, output_dim)
626
+ )
627
+
628
+ def forward(self, x: torch.Tensor, padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
629
+ batch_size, seq_len, _ = x.shape
630
+
631
+ # 1. Process sequence
632
+ x1 = self.input_projection(x)
633
+ x1 = x1 + self.pos_encoding[:seq_len].unsqueeze(0)
634
+
635
+ # 2. Calculate pairwise matrix (can be similarity, distance, correlation, etc.)
636
+ x_expanded = x.unsqueeze(2)
637
+ x_transposed = x.unsqueeze(1)
638
+ distances = torch.mean((x_expanded - x_transposed) ** 2, dim=-1)
639
+ pairwise_matrix = torch.exp(-distances) # Convert distance to similarity
640
+
641
+ # Apply padding mask to pairwise matrix
642
+ if padding_mask is not None:
643
+ pairwise_mask = padding_mask.unsqueeze(1) | padding_mask.unsqueeze(2)
644
+ pairwise_matrix = pairwise_matrix.masked_fill(pairwise_mask, 0.0)
645
+
646
+ # ✅ Pairwise-guided processing
647
+ for layer in self.pairwise_guided_layers:
648
+ x1 = layer(x1, pairwise_matrix, padding_mask)
649
+
650
+ # 3. Process pairwise matrix as separate stream (optional)
651
+ if self.mode in ['only_structure', 'both']:
652
+ x2 = pairwise_matrix.unsqueeze(1)
653
+ x2 = x2.view(batch_size * seq_len, 1, seq_len)
654
+ x2 = self.pairwise_projection(x2)
655
+ x2 = x2.mean(dim=2)
656
+ x2 = x2.view(batch_size, seq_len, -1)
657
+ x2 = x2 + self.pos_encoding[:seq_len].unsqueeze(0)
658
+
659
+ # ✅ Multi-scale adaptive pooling
660
+ if self.mode == 'only_emb':
661
+ x = self.adaptive_pooler(x1, padding_mask)
662
+ elif self.mode == 'only_structure':
663
+ x = self.adaptive_pooler(x2, padding_mask)
664
+ elif self.mode == 'both':
665
+ x1_pooled = self.adaptive_pooler(x1, padding_mask)
666
+ x2_pooled = self.adaptive_pooler(x2, padding_mask)
667
+ x = torch.cat([x1_pooled, x2_pooled], dim=-1)
668
+
669
+ x = x
670
+ return self.classification_head(x)
671
+
672
+
673
+ class CrossModalFusionLayer(nn.Module):
674
+ """Structure와 Embedding 정보를 점진적으로 융합"""
675
+
676
+ def __init__(self, d_model: int, num_heads: int = 8):
677
+ super().__init__()
678
+
679
+ # Cross-attention: embedding이 structure를 query하고, structure가 embedding을 query
680
+ self.emb_to_struct_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
681
+ self.struct_to_emb_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
682
+
683
+ # Fusion gate (어느 정보를 얼마나 믿을지)
684
+ self.fusion_gate = nn.Sequential(
685
+ nn.Linear(d_model * 2, d_model),
686
+ nn.Sigmoid()
687
+ )
688
+
689
+ self.norm1 = nn.LayerNorm(d_model)
690
+ self.norm2 = nn.LayerNorm(d_model)
691
+
692
+ def forward(self, emb_features, struct_features, padding_mask=None):
693
+ """
694
+ emb_features: (batch, seq_len, d_model) - 메인 embedding 정보
695
+ struct_features: (batch, seq_len, d_model) - structure 정보
696
+ """
697
+
698
+ # 1. Embedding이 Structure 정보를 참조
699
+ emb_enhanced, _ = self.emb_to_struct_attn(
700
+ emb_features, struct_features, struct_features,
701
+ key_padding_mask=padding_mask
702
+ )
703
+ emb_enhanced = self.norm1(emb_features + emb_enhanced)
704
+
705
+ # 2. Structure가 Embedding 정보를 참조
706
+ struct_enhanced, _ = self.struct_to_emb_attn(
707
+ struct_features, emb_features, emb_features,
708
+ key_padding_mask=padding_mask
709
+ )
710
+ struct_enhanced = self.norm2(struct_features + struct_enhanced)
711
+
712
+ # 3. Adaptive fusion (둘 중 어느 것을 더 믿을지 학습)
713
+ combined = torch.cat([emb_enhanced, struct_enhanced], dim=-1)
714
+ gate_weight = self.fusion_gate(combined) # (batch, seq_len, d_model)
715
+
716
+ # Gated combination
717
+ fused = gate_weight * emb_enhanced + (1 - gate_weight) * struct_enhanced
718
+
719
+ return fused
720
+
721
+
722
+ class FusionSegmentTransformer(nn.Module):
723
+ def __init__(self,
724
+ input_dim: int,
725
+ hidden_dim: int = 256,
726
+ num_heads: int = 8,
727
+ num_layers: int = 4,
728
+ dropout: float = 0.1,
729
+ max_sequence_length: int = 1000,
730
+ mode: str = 'both', # 기본값을 both로
731
+ share_parameter: bool = False,
732
+ num_classes: int = 2):
733
+ super().__init__()
734
+
735
+ self.input_projection = nn.Linear(input_dim, hidden_dim)
736
+ self.mode = mode
737
+ self.num_classes = num_classes
738
+
739
+ # Positional encoding
740
+ position = torch.arange(max_sequence_length).unsqueeze(1)
741
+ div_term = torch.exp(torch.arange(0, hidden_dim, 2) * (-np.log(10000.0) / hidden_dim))
742
+ pos_encoding = torch.zeros(max_sequence_length, hidden_dim)
743
+ pos_encoding[:, 0::2] = torch.sin(position * div_term)
744
+ pos_encoding[:, 1::2] = torch.cos(position * div_term)
745
+ self.register_buffer('pos_encoding', pos_encoding)
746
+
747
+ # ✅ Embedding stream: Pairwise-guided transformer
748
+ self.embedding_layers = nn.ModuleList([
749
+ PairwiseGuidedTransformer(hidden_dim, num_heads)
750
+ for _ in range(num_layers)
751
+ ])
752
+
753
+ # ✅ Structure stream: Pairwise matrix processing
754
+ self.pairwise_projection = nn.Sequential(
755
+ nn.Conv1d(1, hidden_dim // 2, kernel_size=3, padding=1),
756
+ nn.ReLU(),
757
+ nn.Conv1d(hidden_dim // 2, hidden_dim, kernel_size=3, padding=1),
758
+ nn.ReLU(),
759
+ nn.Dropout(dropout)
760
+ )
761
+
762
+ # Structure transformer layers
763
+ self.structure_layers = nn.ModuleList([
764
+ nn.TransformerEncoderLayer(
765
+ d_model=hidden_dim,
766
+ nhead=num_heads,
767
+ dim_feedforward=hidden_dim * 4,
768
+ dropout=dropout,
769
+ batch_first=True
770
+ ) for _ in range(num_layers // 2) # 절반만 사용
771
+ ])
772
+
773
+ # ✅ Cross-modal fusion layers (핵심!)
774
+ self.fusion_layers = nn.ModuleList([
775
+ CrossModalFusionLayer(hidden_dim, num_heads)
776
+ for _ in range(1) # fusion은 하나만 써야 gate가 유의미해질듯
777
+ ])
778
+
779
+ # Adaptive pooling
780
+ self.adaptive_pooler = MultiScaleAdaptivePooler(hidden_dim, num_heads)
781
+
782
+ # Final classification head (이제 단일 차원)
783
+ output_dim = 1 if num_classes == 2 else num_classes
784
+
785
+ self.classification_head = nn.Sequential(
786
+ nn.Linear(hidden_dim, hidden_dim), # 더 이상 concat 안함
787
+ nn.LayerNorm(hidden_dim),
788
+ nn.ReLU(),
789
+ nn.Dropout(dropout),
790
+ nn.Linear(hidden_dim, hidden_dim // 2),
791
+ nn.LayerNorm(hidden_dim // 2),
792
+ nn.ReLU(),
793
+ nn.Dropout(dropout),
794
+ nn.Linear(hidden_dim // 2, output_dim)
795
+ )
796
+
797
+ def forward(self, x: torch.Tensor, padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
798
+ batch_size, seq_len, _ = x.shape
799
+
800
+ # 1. Initialize both streams
801
+ x_emb = self.input_projection(x)
802
+ x_emb = x_emb + self.pos_encoding[:seq_len].unsqueeze(0)
803
+
804
+ # 2. Calculate pairwise matrix
805
+ x_expanded = x.unsqueeze(2)
806
+ x_transposed = x.unsqueeze(1)
807
+ distances = torch.mean((x_expanded - x_transposed) ** 2, dim=-1)
808
+ pairwise_matrix = torch.exp(-distances)
809
+
810
+ if padding_mask is not None:
811
+ pairwise_mask = padding_mask.unsqueeze(1) | padding_mask.unsqueeze(2)
812
+ pairwise_matrix = pairwise_matrix.masked_fill(pairwise_mask, 0.0)
813
+
814
+ # 3. Process structure stream
815
+ x_struct = pairwise_matrix.unsqueeze(1)
816
+ x_struct = x_struct.view(batch_size * seq_len, 1, seq_len)
817
+ x_struct = self.pairwise_projection(x_struct)
818
+ x_struct = x_struct.mean(dim=2)
819
+ x_struct = x_struct.view(batch_size, seq_len, -1)
820
+ x_struct = x_struct + self.pos_encoding[:seq_len].unsqueeze(0)
821
+
822
+ for struct_layer in self.structure_layers:
823
+ x_struct = struct_layer(x_struct, src_key_padding_mask=padding_mask)
824
+
825
+ # 4. Process embedding stream (with pairwise guidance)
826
+ for emb_layer in self.embedding_layers:
827
+ x_emb = emb_layer(x_emb, pairwise_matrix, padding_mask)
828
+
829
+ # ✅ 5. Progressive Cross-modal Fusion (핵심!)
830
+ fused = x_emb # 시작은 embedding에서
831
+ for fusion_layer in self.fusion_layers:
832
+ fused = fusion_layer(fused, x_struct, padding_mask)
833
+ # 이제 fused는 embedding + structure 정보를 모두 포함
834
+
835
+ # 6. Final pooling and classification
836
+ pooled = self.adaptive_pooler(fused, padding_mask)
837
+
838
+ pooled = pooled.half()
839
+ return self.classification_head(pooled)
840
+
841
+ import torch
842
+ import torch.nn as nn
843
+ import torch.nn.functional as F
844
+ import numpy as np
845
+ from typing import Optional
846
+ import math
847
+
848
+ class RMSNorm(nn.Module):
849
+ """RMS Normalization - 안정적"""
850
+ def __init__(self, dim: int, eps: float = 1e-6):
851
+ super().__init__()
852
+ self.eps = eps
853
+ self.weight = nn.Parameter(torch.ones(dim))
854
+
855
+ def forward(self, x):
856
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
857
+
858
+ class SwiGLU(nn.Module):
859
+ """SwiGLU Activation - 단순 버전"""
860
+ def __init__(self, dim: int):
861
+ super().__init__()
862
+ self.w1 = nn.Linear(dim, dim * 2, bias=False)
863
+ self.w2 = nn.Linear(dim, dim, bias=False)
864
+
865
+ def forward(self, x):
866
+ return self.w2(F.silu(self.w1(x)[:, :, :x.size(-1)])) # 차원 맞춤
867
+
868
+ class GroupedQueryAttention(nn.Module):
869
+ """단순한 GQA - 에러 방지"""
870
+ def __init__(self, d_model: int, num_heads: int = 8):
871
+ super().__init__()
872
+ assert d_model % num_heads == 0
873
+
874
+ self.d_model = d_model
875
+ self.num_heads = num_heads
876
+ self.head_dim = d_model // num_heads
877
+
878
+ # 모든 projection을 동일한 차원으로
879
+ self.q_proj = nn.Linear(d_model, d_model, bias=False)
880
+ self.k_proj = nn.Linear(d_model, d_model, bias=False)
881
+ self.v_proj = nn.Linear(d_model, d_model, bias=False)
882
+ self.o_proj = nn.Linear(d_model, d_model, bias=False)
883
+
884
+ self.scale = 1.0 / math.sqrt(self.head_dim)
885
+
886
+ def forward(self, x, pairwise_matrix=None, padding_mask=None):
887
+ B, L, D = x.shape
888
+
889
+ Q = self.q_proj(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
890
+ K = self.k_proj(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
891
+ V = self.v_proj(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
892
+
893
+ scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
894
+
895
+ if pairwise_matrix is not None:
896
+ scores = scores + pairwise_matrix.unsqueeze(1)
897
+
898
+ if padding_mask is not None:
899
+ mask_4d = padding_mask.unsqueeze(1).unsqueeze(1).expand(-1, self.num_heads, L, -1)
900
+ scores = scores.masked_fill(mask_4d, float('-inf'))
901
+
902
+ attn_weights = F.softmax(scores, dim=-1)
903
+ attn_output = torch.matmul(attn_weights, V)
904
+
905
+ attn_output = attn_output.transpose(1, 2).contiguous().view(B, L, D)
906
+ return self.o_proj(attn_output)
907
+
908
+ class SimpleModernLayer(nn.Module):
909
+ """단순하고 안전한 모던 레이어"""
910
+ def __init__(self, d_model: int, num_heads: int = 8):
911
+ super().__init__()
912
+
913
+ # RMSNorm
914
+ self.norm1 = RMSNorm(d_model)
915
+ self.norm2 = RMSNorm(d_model)
916
+
917
+ # Attention
918
+ self.attention = GroupedQueryAttention(d_model, num_heads)
919
+
920
+ # Feed forward
921
+ self.ffn = SwiGLU(d_model)
922
+
923
+ def forward(self, x, pairwise_matrix=None, padding_mask=None):
924
+ # Attention with residual
925
+ normed_x = self.norm1(x)
926
+ attn_out = self.attention(normed_x, pairwise_matrix, padding_mask)
927
+ x = x + attn_out
928
+
929
+ # FFN with residual
930
+ normed_x2 = self.norm2(x)
931
+ ffn_out = self.ffn(normed_x2)
932
+ x = x + ffn_out
933
+
934
+ return x
935
+
936
+ class SimpleQuantumPooling(nn.Module):
937
+ """단순한 어텐션 풀링"""
938
+ def __init__(self, d_model: int):
939
+ super().__init__()
940
+
941
+ # 3가지 풀링 방법
942
+ self.attention_pool = nn.MultiheadAttention(d_model, 8, batch_first=True)
943
+ self.query_token = nn.Parameter(torch.randn(1, 1, d_model) * 0.02)
944
+
945
+ # 결합
946
+ self.final_proj = nn.Linear(d_model * 3, d_model, bias=False)
947
+
948
+ def forward(self, x, padding_mask=None):
949
+ batch_size = x.size(0)
950
+
951
+ # 1. Average pooling
952
+ if padding_mask is not None:
953
+ mask_expanded = (~padding_mask).float().unsqueeze(-1)
954
+ avg_pooled = (x * mask_expanded).sum(dim=1) / mask_expanded.sum(dim=1)
955
+ else:
956
+ avg_pooled = x.mean(dim=1)
957
+
958
+ # 2. Max pooling
959
+ if padding_mask is not None:
960
+ x_masked = x.clone()
961
+ x_masked[padding_mask] = float('-inf')
962
+ max_pooled = x_masked.max(dim=1)[0]
963
+ else:
964
+ max_pooled = x.max(dim=1)[0]
965
+
966
+ # 3. Attention pooling
967
+ query = self.query_token.expand(batch_size, -1, -1)
968
+ attn_pooled, _ = self.attention_pool(
969
+ query, x, x, key_padding_mask=padding_mask
970
+ )
971
+ attn_pooled = attn_pooled.squeeze(1)
972
+
973
+ # 결합
974
+ combined = torch.cat([avg_pooled, max_pooled, attn_pooled], dim=-1).half()
975
+ return self.final_proj(combined)
976
+
977
+ class UltraModernSegmentProcessor(nn.Module):
978
+ """에러 없는 단순 버전 ✅"""
979
+ def __init__(self,
980
+ input_dim: int,
981
+ hidden_dim: int = 512,
982
+ num_heads: int = 8,
983
+ num_layers: int = 6,
984
+ dropout: float = 0.1,
985
+ max_sequence_length: int = 1000,
986
+ num_classes: int = 2):
987
+ super().__init__()
988
+
989
+ assert hidden_dim % num_heads == 0
990
+
991
+ self.hidden_dim = hidden_dim
992
+ self.input_projection = nn.Linear(input_dim, hidden_dim, bias=False)
993
+
994
+ # 모던 레이어들
995
+ self.layers = nn.ModuleList([
996
+ SimpleModernLayer(hidden_dim, num_heads)
997
+ for _ in range(num_layers)
998
+ ])
999
+
1000
+ # 단순 풀링
1001
+ self.pooler = SimpleQuantumPooling(hidden_dim)
1002
+
1003
+ # 분류 헤드
1004
+ output_dim = 1 if num_classes == 2 else num_classes
1005
+
1006
+ self.classifier = nn.Sequential(
1007
+ nn.Linear(hidden_dim, hidden_dim // 2, bias=False),
1008
+ RMSNorm(hidden_dim // 2),
1009
+ nn.SiLU(),
1010
+ nn.Dropout(dropout),
1011
+ nn.Linear(hidden_dim // 2, hidden_dim // 4, bias=False),
1012
+ RMSNorm(hidden_dim // 4),
1013
+ nn.SiLU(),
1014
+ nn.Dropout(dropout),
1015
+ nn.Linear(hidden_dim // 4, output_dim, bias=False)
1016
+ )
1017
+
1018
+ def forward(self, x: torch.Tensor, padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
1019
+ # Input projection
1020
+ x_emb = self.input_projection(x)
1021
+
1022
+ # Pairwise matrix 계산
1023
+ x_expanded = x.unsqueeze(2)
1024
+ x_transposed = x.unsqueeze(1)
1025
+
1026
+ # 유클리드 거리만 사용 (단순하게)
1027
+ distances = torch.mean((x_expanded - x_transposed) ** 2, dim=-1)
1028
+ pairwise_matrix = torch.exp(-distances)
1029
+
1030
+ if padding_mask is not None:
1031
+ pairwise_mask = padding_mask.unsqueeze(1) | padding_mask.unsqueeze(2)
1032
+ pairwise_matrix = pairwise_matrix.masked_fill(pairwise_mask, 0.0)
1033
+
1034
+ # 레이어들 통과
1035
+ for layer in self.layers:
1036
+ x_emb = layer(x_emb, pairwise_matrix, padding_mask)
1037
+
1038
+ # 풀링
1039
+ pooled = self.pooler(x_emb, padding_mask)
1040
+
1041
+ # 분류
1042
+ return self.classifier(pooled)
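
A quick sanity-check sketch for the collate function defined above (made-up tensors; pad_sequence_with_mask pads or truncates each track's sequence of segment embeddings to a fixed length of 48 and returns a boolean mask where True marks padded positions):

# Hypothetical check of pad_sequence_with_mask from model.py.
import torch
from model import pad_sequence_with_mask

# Two "tracks": variable-length sequences of 768-dim segment embeddings with labels.
batch = [(torch.randn(12, 768), 0), (torch.randn(60, 768), 1)]

padded, labels, mask = pad_sequence_with_mask(batch)
print(padded.shape)   # torch.Size([2, 48, 768]) -- padded/truncated to 48 segments
print(labels)         # tensor([0, 1])
print(mask[0].sum())  # tensor(36) -- 36 padded steps for the 12-segment track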
networks.py ADDED
@@ -0,0 +1,560 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import pytorch_lightning as pl
5
+ from transformers import AutoModel, AutoConfig
6
+ from transformers import Wav2Vec2Model, Wav2Vec2Processor, Data2VecAudioModel
7
+ import torchmetrics
8
+
9
+ class cnnblock(nn.Module):
10
+ def __init__(self, embed_dim=512):
11
+ super(cnnblock, self).__init__()
12
+ self.conv_block = nn.Sequential(
13
+ nn.Conv2d(1, 16, kernel_size=3, padding=1),
14
+ nn.ReLU(),
15
+ nn.MaxPool2d(2),
16
+ nn.Conv2d(16, 32, kernel_size=3, padding=1),
17
+ nn.ReLU(),
18
+ nn.MaxPool2d(2),
19
+ nn.AdaptiveAvgPool2d((4, 4))
20
+ )
21
+ self.projection = nn.Linear(32 * 4 * 4, embed_dim)
22
+
23
+ def forward(self, x):
24
+ x = self.conv_block(x)
25
+ B, C, H, W = x.shape
26
+ x = x.view(B, -1)
27
+ x = self.projection(x)
28
+ return x
29
+
30
+ class CrossAttention(nn.Module):
31
+ def __init__(self, embed_dim, num_heads):
32
+ super(CrossAttention, self).__init__()
33
+ self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
34
+ self.layer_norm1 = nn.LayerNorm(embed_dim)
35
+ self.layer_norm2 = nn.LayerNorm(embed_dim)
36
+ self.feed_forward = nn.Sequential(
37
+ nn.Linear(embed_dim, embed_dim * 4),
38
+ nn.ReLU(),
39
+ nn.Linear(embed_dim * 4, embed_dim)
40
+ )
41
+
42
+ def forward(self, x, cross_input):
43
+ attn_output, _ = self.multihead_attn(query=x, key=cross_input, value=cross_input)
44
+ x = self.layer_norm1(x + attn_output)
45
+ ff_output = self.feed_forward(x)
46
+ x = self.layer_norm2(x + ff_output)
47
+ return x
48
+
49
+
50
+ class CrossAttn_Transformer(nn.Module):
51
+ def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2):
52
+ super(CrossAttn_Transformer, self).__init__()
53
+
54
+ self.cross_attention_layers = nn.ModuleList([
55
+ CrossAttention(embed_dim, num_heads) for _ in range(num_layers)
56
+ ])
57
+
58
+ encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads)
59
+ self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
60
+
61
+ self.classifier = nn.Sequential(
62
+ nn.LayerNorm(embed_dim),
63
+ nn.Linear(embed_dim, num_classes)
64
+ )
65
+
66
+ def forward(self, x, cross_attention_input):
67
+ self.attention_maps = []
68
+ for layer in self.cross_attention_layers:
69
+ x = layer(x, cross_attention_input)
70
+
71
+ x = x.permute(1, 0, 2)
72
+ x = self.transformer(x)
73
+ x = x.mean(dim=0)
74
+ x = self.classifier(x)
75
+ return x
76
+
77
+ class MERT(nn.Module):
78
+ def __init__(self, freeze_feature_extractor=True):
79
+ super(MERT, self).__init__()
80
+ config = AutoConfig.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True)
81
+ if not hasattr(config, "conv_pos_batch_norm"):
82
+ setattr(config, "conv_pos_batch_norm", False)
83
+ self.mert = AutoModel.from_pretrained("m-a-p/MERT-v1-95M", config=config, trust_remote_code=True)
84
+
85
+ if freeze_feature_extractor:
86
+ self.freeze()
87
+
88
+ def forward(self, input_values):
89
+ with torch.no_grad():
90
+ outputs = self.mert(input_values, output_hidden_states=True)
91
+ hidden_states = torch.stack(outputs.hidden_states)
92
+ hidden_states = hidden_states.detach().clone().requires_grad_(True)
93
+ time_reduced = hidden_states.mean(dim=2)
94
+ time_reduced = time_reduced.permute(1, 0, 2)
95
+ return time_reduced
96
+
97
+ def freeze(self):
98
+ for param in self.mert.parameters():
99
+ param.requires_grad = False
100
+
101
+ def unfreeze(self):
102
+ for param in self.mert.parameters():
103
+ param.requires_grad = True
104
+
105
+
106
+ class MERT_AudioCNN(pl.LightningModule):
107
+ def __init__(self, embed_dim=768, num_heads=8, num_layers=6, num_classes=2,
108
+ freeze_feature_extractor=False, learning_rate=2e-5, weight_decay=0.01):
109
+ super(MERT_AudioCNN, self).__init__()
110
+ self.save_hyperparameters()
111
+ self.feature_extractor = MERT(freeze_feature_extractor=freeze_feature_extractor)
112
+ self.cross_attention_layers = nn.ModuleList([
113
+ CrossAttention(embed_dim, num_heads) for _ in range(num_layers)
114
+ ])
115
+ encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, batch_first=True)
116
+ self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
117
+ self.classifier = nn.Sequential(
118
+ nn.LayerNorm(embed_dim),
119
+ nn.Linear(embed_dim, 256),
120
+ nn.BatchNorm1d(256),
121
+ nn.ReLU(),
122
+ nn.Dropout(0.3),
123
+ nn.Linear(256, num_classes)
124
+ )
125
+
126
+ # Metrics
127
+ self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
128
+ self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
129
+ self.test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
130
+
131
+ self.train_f1 = torchmetrics.F1Score(task="multiclass", num_classes=num_classes)
132
+ self.val_f1 = torchmetrics.F1Score(task="multiclass", num_classes=num_classes)
133
+ self.test_f1 = torchmetrics.F1Score(task="multiclass", num_classes=num_classes)
134
+
135
+ self.learning_rate = learning_rate
136
+ self.weight_decay = weight_decay
137
+
138
+ def forward(self, input_values):
139
+ features = self.feature_extractor(input_values)
140
+ for layer in self.cross_attention_layers:
141
+ features = layer(features, features)
142
+
143
+ features = features.mean(dim=1).unsqueeze(1)
144
+ encoded = self.transformer(features)
145
+ encoded = encoded.mean(dim=1)
146
+ output = self.classifier(encoded)
147
+ return output, encoded
148
+
149
+ def training_step(self, batch, batch_idx):
150
+ x, y = batch
151
+ logits = self(x)
152
+ loss = F.cross_entropy(logits, y)
153
+
154
+ preds = torch.argmax(logits, dim=1)
155
+ self.train_acc(preds, y)
156
+ self.train_f1(preds, y)
157
+
158
+ self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
159
+ self.log('train_acc', self.train_acc, on_step=False, on_epoch=True, prog_bar=True)
160
+ self.log('train_f1', self.train_f1, on_step=False, on_epoch=True, prog_bar=True)
161
+
162
+ return loss
163
+
164
+ def validation_step(self, batch, batch_idx):
165
+ x, y = batch
166
+ logits = self(x)
167
+ loss = F.cross_entropy(logits, y)
168
+
169
+ preds = torch.argmax(logits, dim=1)
170
+ self.val_acc(preds, y)
171
+ self.val_f1(preds, y)
172
+
173
+ self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
174
+ self.log('val_acc', self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
175
+ self.log('val_f1', self.val_f1, on_step=False, on_epoch=True, prog_bar=True)
176
+
177
+ return loss
178
+
179
+ def test_step(self, batch, batch_idx):
180
+ x, y = batch
181
+ logits, _ = self(x)  # forward returns (logits, embedding)
182
+ loss = F.cross_entropy(logits, y)
183
+
184
+ preds = torch.argmax(logits, dim=1)
185
+ self.test_acc(preds, y)
186
+ self.test_f1(preds, y)
187
+
188
+ self.log('test_loss', loss, on_step=False, on_epoch=True)
189
+ self.log('test_acc', self.test_acc, on_step=False, on_epoch=True)
190
+ self.log('test_f1', self.test_f1, on_step=False, on_epoch=True)
191
+
192
+ return loss
193
+
194
+ def configure_optimizers(self):
195
+ optimizer = torch.optim.AdamW(
196
+ self.parameters(),
197
+ lr=self.learning_rate,
198
+ weight_decay=self.weight_decay
199
+ )
200
+
201
+ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
202
+ optimizer,
203
+ mode='min',
204
+ factor=0.5,
205
+ patience=2,
206
+ verbose=True
207
+ )
208
+
209
+ return {
210
+ "optimizer": optimizer,
211
+ "lr_scheduler": {
212
+ "scheduler": scheduler,
213
+ "monitor": "val_loss",
214
+ "interval": "epoch",
215
+ "frequency": 1
216
+ }
217
+ }
218
+
219
+ def unfreeze_feature_extractor(self):
220
+ self.feature_extractor.unfreeze()
221
+
222
+
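Note that MERT_AudioCNN.forward returns a (logits, embedding) pair, which is why the step methods above unpack the output before computing the loss. A minimal inference sketch under the same 24 kHz assumption, relying on the CrossAttention block defined earlier in this file (illustrative only):

import torch

model = MERT_AudioCNN(freeze_feature_extractor=True).eval()
waveforms = torch.randn(2, 24000 * 5)       # two ~5 s clips at 24 kHz
with torch.no_grad():
    logits, embedding = model(waveforms)
probs = torch.softmax(logits, dim=1)        # [2, 2] class probabilities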
223
+ class Wav2vec_AudioCNN(pl.LightningModule):
224
+ def __init__(self, model_name="facebook/wav2vec2-base", embed_dim=512, num_heads=8,
225
+ num_layers=6, num_classes=2, freeze_feature_extractor=True,
226
+ learning_rate=2e-5, weight_decay=0.01):
227
+ super(Wav2vec_AudioCNN, self).__init__()
228
+ self.save_hyperparameters()
229
+
230
+ self.processor = Wav2Vec2Processor.from_pretrained(model_name)
231
+ self.feature_extractor = Wav2Vec2Model.from_pretrained(model_name)
232
+ if freeze_feature_extractor:
233
+ self.feature_extractor.freeze_feature_encoder()
234
+
235
+ self.projection = nn.Linear(self.feature_extractor.config.hidden_size, embed_dim)
236
+ self.decoder = CrossAttn_Transformer(embed_dim=embed_dim, num_heads=num_heads,
237
+ num_layers=num_layers, num_classes=num_classes)
238
+
239
+ # Metrics
240
+ self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
241
+ self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
242
+ self.test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
243
+
244
+ self.train_f1 = torchmetrics.F1Score(task="multiclass", num_classes=num_classes)
245
+ self.val_f1 = torchmetrics.F1Score(task="multiclass", num_classes=num_classes)
246
+ self.test_f1 = torchmetrics.F1Score(task="multiclass", num_classes=num_classes)
247
+
248
+ self.learning_rate = learning_rate
249
+ self.weight_decay = weight_decay
250
+
251
+ def forward(self, x, cross_attention_input=None):
252
+ x = x.squeeze(1)
253
+
254
+ # Wav2Vec2 Feature Extraction
255
+ features = self.feature_extractor(x).last_hidden_state
256
+ features = self.projection(features)
257
+
258
+ if cross_attention_input is None:
259
+ cross_attention_input = features
260
+
261
+ x = self.decoder(features, cross_attention_input)
262
+
263
+ return x
264
+
265
+ def training_step(self, batch, batch_idx):
266
+ x, y = batch
267
+ logits = self(x)
268
+ loss = F.cross_entropy(logits, y)
269
+
270
+ preds = torch.argmax(logits, dim=1)
271
+ self.train_acc(preds, y)
272
+ self.train_f1(preds, y)
273
+
274
+ self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
275
+ self.log('train_acc', self.train_acc, on_step=False, on_epoch=True, prog_bar=True)
276
+ self.log('train_f1', self.train_f1, on_step=False, on_epoch=True, prog_bar=True)
277
+
278
+ return loss
279
+
280
+ def validation_step(self, batch, batch_idx):
281
+ x, y = batch
282
+ logits = self(x)
283
+ loss = F.cross_entropy(logits, y)
284
+
285
+ preds = torch.argmax(logits, dim=1)
286
+ self.val_acc(preds, y)
287
+ self.val_f1(preds, y)
288
+
289
+ self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
290
+ self.log('val_acc', self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
291
+ self.log('val_f1', self.val_f1, on_step=False, on_epoch=True, prog_bar=True)
292
+
293
+ return loss
294
+
295
+ def test_step(self, batch, batch_idx):
296
+ x, y = batch
297
+ logits = self(x)
298
+ loss = F.cross_entropy(logits, y)
299
+
300
+ preds = torch.argmax(logits, dim=1)
301
+ self.test_acc(preds, y)
302
+ self.test_f1(preds, y)
303
+
304
+ self.log('test_loss', loss, on_step=False, on_epoch=True)
305
+ self.log('test_acc', self.test_acc, on_step=False, on_epoch=True)
306
+ self.log('test_f1', self.test_f1, on_step=False, on_epoch=True)
307
+
308
+ return loss
309
+
310
+ def configure_optimizers(self):
311
+ optimizer = torch.optim.AdamW(
312
+ self.parameters(),
313
+ lr=self.learning_rate,
314
+ weight_decay=self.weight_decay
315
+ )
316
+
317
+ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
318
+ optimizer,
319
+ mode='min',
320
+ factor=0.5,
321
+ patience=2,
322
+ verbose=True
323
+ )
324
+
325
+ return {
326
+ "optimizer": optimizer,
327
+ "lr_scheduler": {
328
+ "scheduler": scheduler,
329
+ "monitor": "val_loss",
330
+ "interval": "epoch",
331
+ "frequency": 1
332
+ }
333
+ }
334
+
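facebook/wav2vec2-base operates on 16 kHz mono audio, and forward() squeezes a channel axis, so inputs are expected as [batch, 1, samples]. A dummy forward pass might look like the following sketch, which assumes the CrossAttn_Transformer defined earlier in this file returns class logits:

import torch

model = Wav2vec_AudioCNN().eval()
clip = torch.randn(2, 1, 16000 * 4)         # two 4 s clips at 16 kHz
with torch.no_grad():
    logits = model(clip)                     # expected shape: [2, 2]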
335
+ class Music2vec_AudioCNN(pl.LightningModule):
336
+ def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2,
337
+ learning_rate=2e-5, weight_decay=0.01):
338
+ super(Music2vec_AudioCNN, self).__init__()
339
+ self.save_hyperparameters()
340
+
341
+ self.feature_extractor = Music2vec(freeze_feature_extractor=True)
342
+ self.projection = nn.Linear(self.feature_extractor.music2vec.config.hidden_size, embed_dim)
343
+ self.decoder = CrossAttn_Transformer(embed_dim=embed_dim, num_heads=num_heads,
344
+ num_layers=num_layers, num_classes=num_classes)
345
+
346
+ # Metrics
347
+ self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
348
+ self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
349
+ self.test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
350
+
351
+ self.train_f1 = torchmetrics.F1Score(task="multiclass", num_classes=num_classes)
352
+ self.val_f1 = torchmetrics.F1Score(task="multiclass", num_classes=num_classes)
353
+ self.test_f1 = torchmetrics.F1Score(task="multiclass", num_classes=num_classes)
354
+
355
+ self.learning_rate = learning_rate
356
+ self.weight_decay = weight_decay
357
+
358
+ def forward(self, x, cross_attention_input=None):
359
+ x = x.squeeze(1)
360
+ features = self.feature_extractor(x)
361
+ features = self.projection(features)
362
+
363
+ if cross_attention_input is None:
364
+ cross_attention_input = features
365
+
366
+ x = self.decoder(features.unsqueeze(1), cross_attention_input.unsqueeze(1))
367
+ return x
368
+
369
+ def training_step(self, batch, batch_idx):
370
+ x, y = batch
371
+ logits = self(x)
372
+ loss = F.cross_entropy(logits, y)
373
+
374
+ preds = torch.argmax(logits, dim=1)
375
+ self.train_acc(preds, y)
376
+ self.train_f1(preds, y)
377
+
378
+ self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
379
+ self.log('train_acc', self.train_acc, on_step=False, on_epoch=True, prog_bar=True)
380
+ self.log('train_f1', self.train_f1, on_step=False, on_epoch=True, prog_bar=True)
381
+
382
+ return loss
383
+
384
+ def validation_step(self, batch, batch_idx):
385
+ x, y = batch
386
+ logits = self(x)
387
+ loss = F.cross_entropy(logits, y)
388
+
389
+ preds = torch.argmax(logits, dim=1)
390
+ self.val_acc(preds, y)
391
+ self.val_f1(preds, y)
392
+
393
+ self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
394
+ self.log('val_acc', self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
395
+ self.log('val_f1', self.val_f1, on_step=False, on_epoch=True, prog_bar=True)
396
+
397
+ return loss
398
+
399
+ def test_step(self, batch, batch_idx):
400
+ x, y = batch
401
+ logits = self(x)
402
+ loss = F.cross_entropy(logits, y)
403
+
404
+ preds = torch.argmax(logits, dim=1)
405
+ self.test_acc(preds, y)
406
+ self.test_f1(preds, y)
407
+
408
+ self.log('test_loss', loss, on_step=False, on_epoch=True)
409
+ self.log('test_acc', self.test_acc, on_step=False, on_epoch=True)
410
+ self.log('test_f1', self.test_f1, on_step=False, on_epoch=True)
411
+
412
+ return loss
413
+
414
+ def configure_optimizers(self):
415
+ optimizer = torch.optim.AdamW(
416
+ self.parameters(),
417
+ lr=self.learning_rate,
418
+ weight_decay=self.weight_decay
419
+ )
420
+
421
+ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
422
+ optimizer,
423
+ mode='min',
424
+ factor=0.5,
425
+ patience=2,
426
+ verbose=True
427
+ )
428
+
429
+ return {
430
+ "optimizer": optimizer,
431
+ "lr_scheduler": {
432
+ "scheduler": scheduler,
433
+ "monitor": "val_loss",
434
+ "interval": "epoch",
435
+ "frequency": 1
436
+ }
437
+ }
438
+
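All four LightningModules in this file share the same optimization recipe: AdamW plus a ReduceLROnPlateau scheduler that halves the learning rate after two epochs without improvement in val_loss. A typical training entry point would look roughly like this (train_loader and val_loader are hypothetical DataLoaders, not defined in this commit):

import pytorch_lightning as pl

trainer = pl.Trainer(max_epochs=10, accelerator="auto", devices=1)
trainer.fit(Music2vec_AudioCNN(),
            train_dataloaders=train_loader,   # hypothetical DataLoader
            val_dataloaders=val_loader)       # hypothetical DataLoader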
439
+ class AudioCNN(pl.LightningModule):
440
+ def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2,
441
+ learning_rate=2e-5, weight_decay=0.01):
442
+ super(AudioCNN, self).__init__()
443
+ self.save_hyperparameters()
444
+
445
+ self.encoder = cnnblock(embed_dim=embed_dim)
446
+ self.decoder = CrossAttn_Transformer(embed_dim=embed_dim, num_heads=num_heads,
447
+ num_layers=num_layers, num_classes=num_classes)
448
+
449
+ # Metrics
450
+ self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
451
+ self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
452
+ self.test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
453
+
454
+ self.train_f1 = torchmetrics.F1Score(task="multiclass", num_classes=num_classes)
455
+ self.val_f1 = torchmetrics.F1Score(task="multiclass", num_classes=num_classes)
456
+ self.test_f1 = torchmetrics.F1Score(task="multiclass", num_classes=num_classes)
457
+
458
+ self.learning_rate = learning_rate
459
+ self.weight_decay = weight_decay
460
+
461
+ def forward(self, x, cross_attention_input=None):
462
+ x = self.encoder(x)
463
+ x = x.unsqueeze(1)
464
+ if cross_attention_input is None:
465
+ cross_attention_input = x
466
+ x = self.decoder(x, cross_attention_input)
467
+ return x
468
+
469
+ def training_step(self, batch, batch_idx):
470
+ x, y = batch
471
+ logits = self(x)
472
+ loss = F.cross_entropy(logits, y)
473
+
474
+ preds = torch.argmax(logits, dim=1)
475
+ self.train_acc(preds, y)
476
+ self.train_f1(preds, y)
477
+
478
+ self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
479
+ self.log('train_acc', self.train_acc, on_step=False, on_epoch=True, prog_bar=True)
480
+ self.log('train_f1', self.train_f1, on_step=False, on_epoch=True, prog_bar=True)
481
+
482
+ return loss
483
+
484
+ def validation_step(self, batch, batch_idx):
485
+ x, y = batch
486
+ logits = self(x)
487
+ loss = F.cross_entropy(logits, y)
488
+
489
+ preds = torch.argmax(logits, dim=1)
490
+ self.val_acc(preds, y)
491
+ self.val_f1(preds, y)
492
+
493
+ self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
494
+ self.log('val_acc', self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
495
+ self.log('val_f1', self.val_f1, on_step=False, on_epoch=True, prog_bar=True)
496
+
497
+ return loss
498
+
499
+ def test_step(self, batch, batch_idx):
500
+ x, y = batch
501
+ logits = self(x)
502
+ loss = F.cross_entropy(logits, y)
503
+
504
+ preds = torch.argmax(logits, dim=1)
505
+ self.test_acc(preds, y)
506
+ self.test_f1(preds, y)
507
+
508
+ self.log('test_loss', loss, on_step=False, on_epoch=True)
509
+ self.log('test_acc', self.test_acc, on_step=False, on_epoch=True)
510
+ self.log('test_f1', self.test_f1, on_step=False, on_epoch=True)
511
+
512
+ return loss
513
+
514
+ def configure_optimizers(self):
515
+ optimizer = torch.optim.AdamW(
516
+ self.parameters(),
517
+ lr=self.learning_rate,
518
+ weight_decay=self.weight_decay
519
+ )
520
+
521
+ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
522
+ optimizer,
523
+ mode='min',
524
+ factor=0.5,
525
+ patience=2,
526
+ verbose=True
527
+ )
528
+
529
+ return {
530
+ "optimizer": optimizer,
531
+ "lr_scheduler": {
532
+ "scheduler": scheduler,
533
+ "monitor": "val_loss",
534
+ "interval": "epoch",
535
+ "frequency": 1
536
+ }
537
+ }
538
+
539
+
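The cross_attention_input=None fallback shared by these decoder-based models feeds the encoder output to the CrossAttn_Transformer as both the query source and the key/value source, so it degenerates to self-attention. With a stand-in embedding, the two paths look like this (illustrative sketch; shapes assume the default embed_dim=512 and the CrossAttn_Transformer interface used in forward() above):

import torch

model = AudioCNN()
feat = torch.randn(2, 1, 512)                      # stand-in encoder output [batch, seq, embed_dim]
self_attn_logits = model.decoder(feat, feat)       # path taken when cross input is None
cross_attn_logits = model.decoder(feat, torch.randn(2, 1, 512))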
540
+ # Auxiliary classes required by the models above
541
+ class Music2vec(nn.Module):
542
+ def __init__(self, freeze_feature_extractor=True):
543
+ super(Music2vec, self).__init__()
544
+ self.processor = Wav2Vec2Processor.from_pretrained("facebook/data2vec-audio-base-960h")
545
+ self.music2vec = Data2VecAudioModel.from_pretrained("m-a-p/music2vec-v1")
546
+
547
+ if freeze_feature_extractor:
548
+ for param in self.music2vec.parameters():
549
+ param.requires_grad = False
550
+ self.conv1d = nn.Conv1d(in_channels=13, out_channels=1, kernel_size=1)
551
+
552
+ def forward(self, input_values):
553
+ input_values = input_values.squeeze(1)
554
+ with torch.no_grad():
555
+ outputs = self.music2vec(input_values, output_hidden_states=True)
556
+ hidden_states = torch.stack(outputs.hidden_states)
557
+ time_reduced = hidden_states.mean(dim=2)
558
+ time_reduced = time_reduced.permute(1, 0, 2)
559
+ weighted_avg = self.conv1d(time_reduced).squeeze(1)
560
+ return weighted_avg
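The Conv1d(in_channels=13, out_channels=1, kernel_size=1) above acts as a learnable weighted average over the 13 stacked hidden-state layers of music2vec (the input embedding plus 12 transformer layers), collapsing [batch, 13, 768] to [batch, 768]. The pooling trick in isolation (illustrative sketch):

import torch
import torch.nn as nn

layers = torch.randn(2, 13, 768)                   # [batch, num_layers, features]
mix = nn.Conv1d(in_channels=13, out_channels=1, kernel_size=1)
pooled = mix(layers).squeeze(1)                    # learnable per-layer weights -> [2, 768]
print(pooled.shape)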