wh1tet3a committed
Commit ed8c770 · 0 Parent(s)

add model

Files changed (5)
  1. .gitattributes +35 -0
  2. README.md +190 -0
  3. config.json +1 -0
  4. model.py +797 -0
  5. model.safetensors +3 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,190 @@
+ ---
+ library_name: pytorch
+ tags:
+ - audio
+ - spoofing-detection
+ - anti-spoofing
+ - wav2vec2
+ - aasist
+ license: apache-2.0
+ pipeline_tag: audio-classification
+ model-index:
+ - name: spectra_aasist
+   results:
+   - task:
+       type: Speech Antispoofing
+     dataset:
+       name: ASVspoof19_LA
+       type: ASVspoof19_LA
+     metrics:
+     - name: Equal Error Rate
+       type: Equal Error Rate
+       value: 0.159
+   - task:
+       type: Speech Antispoofing
+     dataset:
+       name: ASVspoof21_LA
+       type: ASVspoof21_LA
+     metrics:
+     - name: Equal Error Rate
+       type: Equal Error Rate
+       value: 5.164
+   - task:
+       type: Speech Antispoofing
+     dataset:
+       name: ASVspoof21_DF
+       type: ASVspoof21_DF
+     metrics:
+     - name: Equal Error Rate
+       type: Equal Error Rate
+       value: 2.568
+   - task:
+       type: Speech Antispoofing
+     dataset:
+       name: ASVspoof5
+       type: ASVspoof5
+     metrics:
+     - name: Equal Error Rate
+       type: Equal Error Rate
+       value: 14.056
+   - task:
+       type: Speech Antispoofing
+     dataset:
+       name: ADD2022
+       type: ADD2022
+     metrics:
+     - name: Equal Error Rate
+       type: Equal Error Rate
+       value: 15.205
+   - task:
+       type: Speech Antispoofing
+     dataset:
+       name: In-the-Wild
+       type: In-the-Wild
+     metrics:
+     - name: Equal Error Rate
+       type: Equal Error Rate
+       value: 1.461
+ ---
+
+ ## Model Card: Spectra-0 (anti-spoofing / bonafide vs spoof)
+
+ `Spectra-AASIST` is a model for **speech spoofing detection** (binary classification: `bonafide` vs `spoof`) operating on **raw audio waveforms**. Architecture: SSL encoder (`Wav2Vec2`) → MLP projection → `AASIST` 2-class classifier.
+
+ - **Input**: waveform (`float32`), shape `(batch, num_samples)` (typically 16 kHz).
+ - **Output**: logits of shape `(batch, 2)`, where **index 0 = spoof**, **index 1 = bonafide**.
+
+ On first run, the model automatically downloads the SSL encoder `facebook/wav2vec2-xls-r-300m` via `transformers`.
+
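+ As a quick sanity check of the I/O contract above (a sketch; assumes you run it from a local clone of this repo):
+
+ ```python
+ import torch
+ from model import spectra_aasist
+
+ model = spectra_aasist.from_pretrained(".").eval()
+ with torch.inference_mode():
+     logits = model(torch.randn(1, 64600))  # random noise, just to verify shapes
+ print(logits.shape)  # torch.Size([1, 2])
+ ```
+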
+ ## Evaluation Results
+
+ All numbers are Equal Error Rate (EER, %); lower is better.
+
+ | Model | ASVspoof19 LA | ASVspoof21 LA | ASVspoof21 DF | ASVspoof5 | ADD2022 | In-the-Wild |
+ |-----------|--------|--------|--------|--------|--------|--------|
+ | [Res2TCNGuard](https://github.com/mtuciru/Res2TCNGuard) | 7.487 | 19.130 | 19.883 | 37.620 | 49.538 | 49.246 |
+ | [AASIST3](https://huggingface.co/MTUCI/AASIST3) | 27.585 | 37.407 | 33.099 | 41.001 | 47.192 | 39.626 |
+ | [XSLS](https://github.com/QiShanZhang/SLSforASVspoof-2021-DF) | 0.231 | 7.714 | 4.220 | 17.688 | 33.951 | 7.453 |
+ | [TCM-ADD](https://github.com/ductuantruong/tcm_add) | **0.152** | 6.655 | 3.444 | 19.505 | 35.252 | 7.767 |
+ | [DF Arena 1B](https://huggingface.co/Speech-Arena-2025/DF_Arena_1B_V_1) | 43.793 | 40.137 | 42.994 | 35.333 | 42.139 | 17.598 |
+ | **Spectra-AASIST** | 0.159 | **5.164** | **2.568** | **14.056** | **15.205** | **1.461** |
+
+ ## Quickstart
+
+ ### Clone from Hugging Face
+
+ This repository is hosted on the Hugging Face Hub: `https://huggingface.co/MTUCI/spectra_aasist`.
+
+ ```bash
+ git lfs install
+ git clone https://huggingface.co/MTUCI/spectra_aasist
+ cd spectra_aasist
+ ```
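+
+ Alternatively, a minimal sketch using `huggingface_hub` (installed in the next step) to fetch the files without git-lfs:
+
+ ```python
+ from huggingface_hub import snapshot_download
+
+ local_dir = snapshot_download(repo_id="MTUCI/spectra_aasist")
+ print(local_dir)  # contains model.py, config.json, model.safetensors
+ ```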
+
+ ### Install dependencies
+
+ ```bash
+ pip install -U torch torchaudio transformers huggingface_hub safetensors soundfile
+ ```
+
+ ### Single-file inference (example preprocessing)
+
+ ```python
+ import random
+ import torch
+ import torchaudio
+ import soundfile as sf
+
+ from model import spectra_aasist
+
+
+ def pad_random(x: torch.Tensor, max_len: int = 64600) -> torch.Tensor:
+     # x: (num_samples,) or (1, num_samples)
+     if x.ndim > 1:
+         x = x.squeeze()
+     x_len = x.shape[0]
+     if x_len >= max_len:
+         start = random.randint(0, x_len - max_len)
+         return x[start:start + max_len]
+     num_repeats = int(max_len / x_len) + 1
+     return x.repeat(num_repeats)[:max_len]
+
+
+ def load_audio_mono(path: str) -> torch.Tensor:
+     audio, sr = sf.read(path, dtype="float32")
+     audio = torch.from_numpy(audio)
+     if audio.ndim > 1:
+         # (num_samples, channels) -> mono
+         audio = audio.mean(dim=1)
+     if sr != 16000:
+         audio = torchaudio.functional.resample(audio, sr, 16000)
+     return audio
+
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model = spectra_aasist.from_pretrained(pretrained_model_name_or_path=".").eval().to(device)
+
+ audio = load_audio_mono("path/to/audio.wav")
+ audio = torchaudio.functional.preemphasis(audio.unsqueeze(0))  # (1, T)
+ audio = pad_random(audio.squeeze(0), 64600).unsqueeze(0)  # (1, 64600)
+
+ with torch.inference_mode():
+     logits = model(audio.to(device))  # (1, 2)
+     score_spoof = logits[0, 0].item()
+     score_bonafide = logits[0, 1].item()
+
+ print({"score_bonafide": score_bonafide, "score_spoof": score_spoof})
+ ```
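+
+ If you prefer probabilities over raw logits, a minimal follow-up sketch (a convenience on top of the example above, not part of the repo's API) is to softmax the two logits:
+
+ ```python
+ import torch.nn.functional as F
+
+ probs = F.softmax(logits, dim=-1)  # (1, 2): [p_spoof, p_bonafide]
+ print(probs[0, 1].item())          # probability-like bonafide score
+ ```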
+
+ ## Threshold-based classification (and how to tune it)
+
+ In `model.py`, the `SpectraAASIST` class provides `classify()` with a **default threshold** chosen as an "optimal" value for the original evaluation setting:
+
+ - **Default threshold**: `-1.0625009` (it thresholds `logit_bonafide = logits[:, 1]`)
+ - **Note**: this threshold **may not be optimal** on a different dataset/domain. Tune it on your own data using **EER** (Equal Error Rate) or a target FAR/FRR.
+
+ Example:
+
+ ```python
+ with torch.inference_mode():
+     pred = model.classify(audio.to(device), threshold=-1.0625009)  # 1=bonafide, 0=spoof
+ ```
+
+ ### Tuning the threshold via EER (typical workflow)
+
+ 1) Run the model on a labeled set and collect the bonafide-logit scores (`logits[:, 1]`) separately for bonafide and spoof utterances.
+
+ 2) Compute the EER and take the threshold at the operating point where the false acceptance rate equals the false rejection rate, as in the sketch below.
+
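+ A minimal NumPy sketch (the helper name and exact tie-breaking are ours, not part of the repo):
+
+ ```python
+ import numpy as np
+
+ def eer_and_threshold(bonafide_scores: np.ndarray, spoof_scores: np.ndarray):
+     """Return (eer, threshold) at the point where FAR ≈ FRR (higher score = more bonafide)."""
+     thresholds = np.sort(np.concatenate([bonafide_scores, spoof_scores]))
+     far = np.array([(spoof_scores >= t).mean() for t in thresholds])    # spoof accepted
+     frr = np.array([(bonafide_scores < t).mean() for t in thresholds])  # bonafide rejected
+     i = int(np.argmin(np.abs(far - frr)))
+     return (far[i] + frr[i]) / 2.0, float(thresholds[i])
+ ```
+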
+ ## Limitations and notes
+
+ - This is a **pre-release** model.
+ - Significantly stronger models are planned for **Q3–Q4 2026**; stay tuned.
+
+ ## License
+
+ Apache-2.0 (see the `license` field in the model repo header).
+
+ ## Contacts
+
+ - Telegram channel: https://t.me/korallll_ai
+ - Email: k.n.borodin@mtuci.ru
+ - Website: https://lab260.ru/
config.json ADDED
@@ -0,0 +1 @@
+ {}
model.py ADDED
@@ -0,0 +1,797 @@
+ import torch
+ import torch.nn as nn
+ from transformers import Wav2Vec2Model
+ import torch.nn.functional as F
+ from huggingface_hub import PyTorchModelHubMixin
+
+
+ class Wav2Vec2Encoder(nn.Module):
+     """SSL encoder based on Hugging Face's Wav2Vec2 model."""
+
+     def __init__(self,
+                  model_name_or_path: str = "facebook/wav2vec2-base-960h",
+                  ssl_out_dim: int = 1024,
+                  use_ssl_n_layers: int = None,
+                  freeze_ssl_n_layers: int = 0,
+                  output_attentions: bool = False,
+                  output_hidden_states: bool = False,
+                  normalize_waveform: bool = True):
+         """Initialize the Wav2Vec2 encoder.
+
+         Args:
+             model_name_or_path: HuggingFace model name or path to a local model.
+             ssl_out_dim: Output dimension of the Wav2Vec2 encoder.
+             use_ssl_n_layers: Number of Wav2Vec2 layers to use. If None, use all layers.
+             freeze_ssl_n_layers: Number of Wav2Vec2 layers to freeze during training.
+             output_attentions: Whether to output attentions.
+             output_hidden_states: Whether to output hidden states.
+             normalize_waveform: Whether to normalize the waveform input.
+         """
+         super().__init__()
+
+         self.model_name_or_path = model_name_or_path
+         self.ssl_out_dim = ssl_out_dim
+         self.use_ssl_n_layers = use_ssl_n_layers
+         self.freeze_ssl_n_layers = freeze_ssl_n_layers
+         self.output_attentions = output_attentions
+         self.output_hidden_states = output_hidden_states
+         self.normalize_waveform = normalize_waveform
+
+         # Load the Wav2Vec2 model
+         self.model = Wav2Vec2Model.from_pretrained(
+             model_name_or_path,
+             gradient_checkpointing=False)
+         self.model.config.apply_spec_augment = False
+         self.model.masked_spec_embed = None
+
+         # Handle layer freezing
+         if freeze_ssl_n_layers > 0:
+             self._freeze_layers(freeze_ssl_n_layers)
+
+     def _freeze_layers(self, n_layers):
+         """Freeze the first n_layers layers of the Wav2Vec2 encoder.
+
+         Args:
+             n_layers: Number of layers to freeze.
+         """
+         # Freeze the feature extractor
+         if n_layers > 0:
+             for param in self.model.feature_extractor.parameters():
+                 param.requires_grad = False
+
+         # Freeze encoder layers
+         encoder_layers = self.model.encoder.layers
+         total_layers = len(encoder_layers)
+         layers_to_freeze = min(n_layers - 1, total_layers)  # -1 because feature_extractor counts as one layer
+
+         if layers_to_freeze > 0:
+             for i in range(layers_to_freeze):
+                 for param in encoder_layers[i].parameters():
+                     param.requires_grad = False
+
+     def forward(self, x):
+         """Forward pass through the Wav2Vec2 encoder.
+
+         Args:
+             x: Input tensor of shape (batch_size, sequence_length, channels)
+
+         Returns:
+             Extracted features of shape (batch_size, num_frames, ssl_out_dim)
+         """
+         # Handle shape: convert (batch_size, sequence_length, channels) to (batch_size, sequence_length)
+         if x.ndim == 3:
+             x = x.squeeze(-1)  # Remove channel dimension if present
+
+         # Normalize input if specified
+         if self.normalize_waveform:
+             x = x / (torch.max(torch.abs(x), dim=1, keepdim=True)[0] + 1e-8)
+
+         # Wav2Vec2 forward pass
+         outputs = self.model(
+             x,
+             output_attentions=self.output_attentions,
+             output_hidden_states=self.output_hidden_states,
+             return_dict=True
+         )
+
+         # Extract the last hidden state
+         last_hidden_state = outputs.last_hidden_state
+
+         # Optionally average the last N hidden states (requires output_hidden_states=True)
+         if self.use_ssl_n_layers is not None and self.output_hidden_states and outputs.hidden_states is not None:
+             selected = outputs.hidden_states[-self.use_ssl_n_layers:]
+             last_hidden_state = torch.mean(torch.stack(selected, dim=0), dim=0)
+             del outputs
+
+         return last_hidden_state
+
+
+ class MLPBridge(nn.Module):
+     """MLP bridge between the SSL encoder and the AASIST model."""
+
+     def __init__(self, input_dim: int, output_dim: int, hidden_dim: int = None,
+                  dropout: float = 0.1, activation: nn.Module = nn.ReLU(), n_layers: int = 1):
+         """Initialize the MLP bridge.
+
+         Args:
+             input_dim: The input dimension from the SSL encoder.
+             output_dim: The output dimension for the AASIST model.
+             hidden_dim: Hidden dimension size. If None, use the average of input and output dims.
+             dropout: Dropout probability to apply between layers.
+             activation: Activation module instance to use (e.g. nn.SELU()).
+             n_layers: Number of MLP layers (repeats of Linear+Activation+Dropout blocks).
+         """
+         super().__init__()
+
+         if hidden_dim is None:
+             hidden_dim = (input_dim + output_dim) // 2
+
+         self.input_dim = input_dim
+         self.output_dim = output_dim
+         self.hidden_dim = hidden_dim
+         self.n_layers = n_layers
+
+         assert hasattr(activation, 'forward') and callable(getattr(activation, 'forward', None)), \
+             "Activation must be a module instance with a callable forward() method."
+         act_fn = activation  # the same (stateless) instance is reused in every block
+
+         layers = []
+         for i in range(n_layers):
+             in_dim = input_dim if i == 0 else hidden_dim
+             out_dim = hidden_dim
+             layers.append(nn.Linear(in_dim, out_dim))
+             layers.append(act_fn)
+             layers.append(nn.Dropout(dropout) if dropout > 0 else nn.Identity())
+         # Final output layer
+         layers.append(nn.Linear(hidden_dim, output_dim))
+         layers.append(nn.Dropout(dropout) if dropout > 0 else nn.Identity())
+
+         self.mlp = nn.Sequential(*layers)
+
+     def forward(self, x):
+         """Forward pass through the bridge.
+
+         Args:
+             x: The input tensor from the SSL encoder.
+
+         Returns:
+             The transformed tensor for the AASIST model.
+         """
+         return self.mlp(x)
+
+
+ class HtrgGraphAttentionLayer(nn.Module):
+     def __init__(self, in_dim, out_dim, size, layer="Linear", **kwargs):
+         super().__init__()
+         if layer == "Linear":
+             self.proj_type1 = nn.Linear(in_dim, in_dim)
+             self.proj_type2 = nn.Linear(in_dim, in_dim)
+             self.att_proj = nn.Linear(in_dim, out_dim)
+             self.att_projM = nn.Linear(in_dim, out_dim)
+             self.proj_with_att = nn.Linear(in_dim, out_dim)
+             self.proj_without_att = nn.Linear(in_dim, out_dim)
+             self.proj_with_attM = nn.Linear(in_dim, out_dim)
+             self.proj_without_attM = nn.Linear(in_dim, out_dim)
+         else:
+             # only the plain Linear variant ships in this release
+             raise ValueError(f"Invalid layer type: {layer}")
+         self.att_weight11 = self._init_new_params(out_dim, 1)
+         self.att_weight22 = self._init_new_params(out_dim, 1)
+         self.att_weight12 = self._init_new_params(out_dim, 1)
+         self.att_weightM = self._init_new_params(out_dim, 1)
+         self.bn = nn.BatchNorm1d(out_dim)
+         self.input_drop = nn.Dropout(p=0.2)
+         self.act = nn.SELU(inplace=True)
+         self.temp = 1.
+         if "temperature" in kwargs:
+             self.temp = kwargs["temperature"]
+
+     def forward(self, x1, x2, master=None):
+         '''
+         x1 :(#bs, #node, #dim)
+         x2 :(#bs, #node, #dim)
+         '''
+         num_type1 = x1.size(1)
+         num_type2 = x2.size(1)
+
+         x1 = self.proj_type1(x1)
+         x2 = self.proj_type2(x2)
+
+         x = torch.cat([x1, x2], dim=1)
+
+         if master is None:
+             master = torch.mean(x, dim=1, keepdim=True)
+
+         # apply input dropout
+         x = self.input_drop(x)
+
+         # derive attention map
+         att_map = self._derive_att_map(x, num_type1, num_type2)
+
+         # directional edge for master node
+         master = self._update_master(x, master)
+
+         # projection
+         x = self._project(x, att_map)
+
+         # apply batch norm
+         x = self._apply_BN(x)
+         # x = self.act(x)
+
+         x1 = x.narrow(1, 0, num_type1)
+         x2 = x.narrow(1, num_type1, num_type2)
+
+         return x1, x2, master
+
+     def _update_master(self, x, master):
+         att_map = self._derive_att_map_master(x, master)
+         master = self._project_master(x, master, att_map)
+
+         return master
+
+     def _pairwise_mul_nodes(self, x):
+         '''
+         Calculates pairwise multiplication of nodes.
+         - for attention map
+         x :(#bs, #node, #dim)
+         out_shape :(#bs, #node, #node, #dim)
+         '''
+         nb_nodes = x.size(1)
+         x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
+         x_mirror = x.transpose(1, 2)
+
+         return x * x_mirror
+
+     def _derive_att_map_master(self, x, master):
+         '''
+         x :(#bs, #node, #dim)
+         out_shape :(#bs, #node, #node, 1)
+         '''
+         att_map = x * master
+         att_map = torch.tanh(self.att_projM(att_map))
+
+         att_map = torch.matmul(att_map, self.att_weightM)
+
+         # apply temperature
+         att_map = att_map / self.temp
+
+         att_map = F.softmax(att_map, dim=-2)
+
+         return att_map
+
+     def _derive_att_map(self, x, num_type1, num_type2):
+         '''
+         x :(#bs, #node, #dim)
+         out_shape :(#bs, #node, #node, 1)
+         '''
+         att_map = self._pairwise_mul_nodes(x)
+         # size: (#bs, #node, #node, #dim_out)
+         att_map = torch.tanh(self.att_proj(att_map))
+         # size: (#bs, #node, #node, 1)
+
+         att_board = torch.zeros_like(att_map[:, :, :, 0]).unsqueeze(-1)
+
+         att_board[:, :num_type1, :num_type1, :] = torch.matmul(
+             att_map[:, :num_type1, :num_type1, :], self.att_weight11)
+         att_board[:, num_type1:, num_type1:, :] = torch.matmul(
+             att_map[:, num_type1:, num_type1:, :], self.att_weight22)
+         att_board[:, :num_type1, num_type1:, :] = torch.matmul(
+             att_map[:, :num_type1, num_type1:, :], self.att_weight12)
+         att_board[:, num_type1:, :num_type1, :] = torch.matmul(
+             att_map[:, num_type1:, :num_type1, :], self.att_weight12)
+
+         att_map = att_board
+
+         # apply temperature
+         att_map = att_map / self.temp
+
+         att_map = F.softmax(att_map, dim=-2)
+
+         return att_map
+
+     def _project(self, x, att_map):
+         x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
+         x2 = self.proj_without_att(x)
+
+         return x1 + x2
+
+     def _project_master(self, x, master, att_map):
+         x1 = self.proj_with_attM(torch.matmul(
+             att_map.squeeze(-1).unsqueeze(1), x))
+         x2 = self.proj_without_attM(master)
+
+         return x1 + x2
+
+     def _apply_BN(self, x):
+         org_size = x.size()
+         x = x.view(-1, org_size[-1])
+         x = self.bn(x)
+         x = x.view(org_size)
+
+         return x
+
+     def _init_new_params(self, *size):
+         out = nn.Parameter(torch.FloatTensor(*size))
+         nn.init.xavier_normal_(out)
+         return out
+
+
+ class GraphPool(nn.Module):
+     def __init__(self, k: float, in_dim: int, p, size, layer="Linear"):
+         super().__init__()
+         self.k = k
+         self.sigmoid = nn.Sigmoid()
+         if layer == "Linear":
+             self.proj = nn.Linear(in_dim, 1)
+         else:
+             raise ValueError(f"Invalid layer type: {layer}")
+         self.drop = nn.Dropout(p=p) if p > 0 else nn.Identity()
+         self.in_dim = in_dim
+
+     def forward(self, h):
+         Z = self.drop(h)
+         weights = self.proj(Z)
+         scores = self.sigmoid(weights)
+         new_h = self.top_k_graph(scores, h, self.k)
+
+         return new_h
+
+     def top_k_graph(self, scores, h, k):
+         """
+         args
+         =====
+         scores: attention-based weights (#bs, #node, 1)
+         h: graph data (#bs, #node, #dim)
+         k: ratio of remaining nodes, (float)
+
+         returns
+         =====
+         h: graph pool applied data (#bs, #node', #dim)
+         """
+         _, n_nodes, n_feat = h.size()
+         n_nodes = max(int(n_nodes * k), 1)
+         _, idx = torch.topk(scores, n_nodes, dim=1)
+         idx = idx.expand(-1, -1, n_feat)
+
+         h = h * scores
+         h = torch.gather(h, 1, idx)
+
+         return h
+
+
+ class GraphAttentionLayer(nn.Module):
+     def __init__(self, in_dim, out_dim, layer="Linear", **kwargs):
+         super().__init__()
+         # attention map
+         if layer == "Linear":
+             self.att_proj = nn.Linear(in_dim, out_dim)
+             self.proj_with_att = nn.Linear(in_dim, out_dim)
+             self.proj_without_att = nn.Linear(in_dim, out_dim)
+         else:
+             raise ValueError(f"Invalid layer type: {layer}")
+         self.att_weight = self._init_new_params(out_dim, 1)
+
+         # batch norm
+         self.bn = nn.BatchNorm1d(out_dim)
+
+         # dropout for inputs
+         self.input_drop = nn.Dropout(p=0.2)
+
+         # activation
+         self.act = nn.SELU(inplace=True)
+
+         # temperature
+         self.temp = 1.
+         if "temperature" in kwargs:
+             self.temp = kwargs["temperature"]
+
+     def forward(self, x):
+         '''
+         x :(#bs, #node, #dim)
+         '''
+         # apply input dropout
+         x = self.input_drop(x)
+
+         # derive attention map
+         att_map = self._derive_att_map(x)
+
+         # projection
+         x = self._project(x, att_map)
+
+         # apply batch norm
+         x = self._apply_BN(x)
+         x = self.act(x)
+         return x
+
+     def _pairwise_mul_nodes(self, x):
+         '''
+         Calculates pairwise multiplication of nodes.
+         - for attention map
+         x :(#bs, #node, #dim)
+         out_shape :(#bs, #node, #node, #dim)
+         '''
+         nb_nodes = x.size(1)
+         x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
+         x_mirror = x.transpose(1, 2)
+
+         return x * x_mirror
+
+     def _derive_att_map(self, x):
+         '''
+         x :(#bs, #node, #dim)
+         out_shape :(#bs, #node, #node, 1)
+         '''
+         att_map = self._pairwise_mul_nodes(x)
+         # size: (#bs, #node, #node, #dim_out)
+         att_map = torch.tanh(self.att_proj(att_map))
+         # size: (#bs, #node, #node, 1)
+         att_map = torch.matmul(att_map, self.att_weight)
+
+         # apply temperature
+         att_map = att_map / self.temp
+
+         att_map = F.softmax(att_map, dim=-2)
+
+         return att_map
+
+     def _project(self, x, att_map):
+         x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
+         x2 = self.proj_without_att(x)
+
+         return x1 + x2
+
+     def _apply_BN(self, x):
+         org_size = x.size()
+         x = x.view(-1, org_size[-1])
+         x = self.bn(x)
+         x = x.view(org_size)
+
+         return x
+
+     def _init_new_params(self, *size):
+         out = nn.Parameter(torch.FloatTensor(*size))
+         nn.init.xavier_normal_(out)
+         return out
+
+
+ class Res2NetBlock(nn.Module):
+     # Note: defined but not used by SpectraAASIST below; kept as an alternative residual block.
+     def __init__(self, in_channels, out_channels, scale=4, kernel_size=(2, 3), stride=1, padding=(1, 1)):
+         super().__init__()
+         assert out_channels % scale == 0, "out_channels must be divisible by scale"
+         self.scale = scale
+         self.width = out_channels // scale
+         self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)
+         self.convs = nn.ModuleList([
+             nn.Conv2d(self.width, self.width, kernel_size=kernel_size, stride=stride, padding=padding)
+             for _ in range(scale)
+         ])
+         self.bn = nn.BatchNorm2d(out_channels)
+         self.selu = nn.SELU(inplace=True)
+         self.conv3 = nn.Conv2d(out_channels, out_channels, kernel_size=1)
+         self.downsample = None
+         if in_channels != out_channels:
+             self.downsample = nn.Conv2d(in_channels, out_channels, kernel_size=1)
+
+     def forward(self, x):
+         identity = x
+         out = self.conv1(x)
+         xs = torch.chunk(out, self.scale, dim=1)
+         ys = []
+         for s in range(self.scale):
+             if s == 0:
+                 ys.append(self.convs[s](xs[s]))
+             else:
+                 ys.append(self.convs[s](xs[s] + ys[s - 1]))
+         out = torch.cat(ys, dim=1)
+         out = self.bn(out)
+         out = self.selu(out)
+         out = self.conv3(out)
+         if self.downsample is not None:
+             identity = self.downsample(identity)
+         out += identity
+         return out
+
+
+ class Residual_block(nn.Module):
+     def __init__(self, nb_filts, first=False):
+         super().__init__()
+         self.first = first
+
+         if not self.first:
+             self.bn1 = nn.BatchNorm2d(num_features=nb_filts[0])
+         self.conv1 = nn.Conv2d(in_channels=nb_filts[0],
+                                out_channels=nb_filts[1],
+                                kernel_size=(2, 3),
+                                padding=(1, 1),
+                                stride=1)
+         self.selu = nn.SELU(inplace=True)
+
+         self.bn2 = nn.BatchNorm2d(num_features=nb_filts[1])
+         self.conv2 = nn.Conv2d(in_channels=nb_filts[1],
+                                out_channels=nb_filts[1],
+                                kernel_size=(2, 3),
+                                padding=(0, 1),
+                                stride=1)
+
+         if nb_filts[0] != nb_filts[1]:
+             self.downsample = True
+             self.conv_downsample = nn.Conv2d(in_channels=nb_filts[0],
+                                              out_channels=nb_filts[1],
+                                              padding=(0, 1),
+                                              kernel_size=(1, 3),
+                                              stride=1)
+         else:
+             self.downsample = False
+
+     def forward(self, x):
+         identity = x
+         if not self.first:
+             out = self.bn1(x)
+             out = self.selu(out)
+         else:
+             out = x
+
+         out = self.conv1(out)
+         out = self.bn2(out)
+         out = self.selu(out)
+         out = self.conv2(out)
+
+         if self.downsample:
+             identity = self.conv_downsample(identity)
+
+         out += identity
+         return out
+
+
+ class Encoder(nn.Module):
+     def __init__(self, filts):
+         super().__init__()
+
+         self.first_bn = nn.BatchNorm2d(num_features=1)
+         self.first_bn1 = nn.BatchNorm2d(num_features=64)
+
+         self.selu = nn.SELU(inplace=True)
+         self.enc = nn.Sequential(
+             nn.Sequential(Residual_block(nb_filts=filts[1], first=True)),
+             nn.Sequential(Residual_block(nb_filts=filts[2])),
+             nn.Sequential(Residual_block(nb_filts=filts[3])),
+             nn.Sequential(Residual_block(nb_filts=filts[4])),
+             nn.Sequential(Residual_block(nb_filts=filts[4])),
+             nn.Sequential(Residual_block(nb_filts=filts[4]))
+         )
+
+     def forward(self, x):
+         x = x.transpose(1, 2)
+         x = x.unsqueeze(dim=1)
+
+         x = F.max_pool2d(torch.abs(x), (3, 3))
+         x = self.first_bn(x)
+         x = self.selu(x)
+
+         # get embeddings using the encoder
+         # (#bs, #filt, #spec, #seq)
+         x = self.enc(x)
+
+         x = self.first_bn1(x)
+         x = self.selu(x)
+
+         return x
+
+
+ class HSGALBranch_v1(nn.Module):
+     def __init__(self, gat_dims, temperatures, pool_ratios, size=200, layer="Linear"):
+         super().__init__()
+
+         self.master = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
+         self.HtrgGAT_layer_ST1 = HtrgGraphAttentionLayer(
+             gat_dims[0], gat_dims[1], temperature=temperatures[2], size=size, layer=layer
+         )
+         self.HtrgGAT_layer_ST2 = HtrgGraphAttentionLayer(
+             gat_dims[1], gat_dims[1], temperature=temperatures[2], size=size, layer=layer
+         )
+
+         self.pool_hS = GraphPool(pool_ratios[2], gat_dims[1], 0.3, size=size, layer=layer)
+         self.pool_hT = GraphPool(pool_ratios[2], gat_dims[1], 0.3, size=size, layer=layer)
+
+         self.drop_way = nn.Dropout(0.2, inplace=True)
+
+     def forward(self, out_t, out_s):
+         out_T, out_S, master = self.HtrgGAT_layer_ST1(
+             out_t, out_s, master=self.master
+         )
+
+         out_S = self.pool_hS(out_S)
+         out_T = self.pool_hT(out_T)
+
+         out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST2(
+             out_T, out_S, master=master
+         )
+         out_T = out_T + out_T_aug
+         out_S = out_S + out_S_aug
+         master = master + master_aug
+
+         out_T = self.drop_way(out_T)
+         out_S = self.drop_way(out_S)
+         master = self.drop_way(master)
+
+         return out_T, out_S, master
+
+
+ class KANAASIST(nn.Module):
+     """KAN-AASIST model with graph attention layers."""
+
+     def __init__(
+         self,
+         d_args={
+             "architecture": "AASIST",
+             "nb_samp": 64600,
+             "filts": [512, [1, 32], [32, 32], [32, 64], [64, 64]],
+             "gat_dims": [64, 32],
+             "pool_ratios": [0.5, 0.5, 0.5, 0.5],
+             "temperatures": [2.0, 2.0, 100.0, 100.0]
+         },
+         encoder=Encoder,
+         size=200,
+         n_frames=400,
+         layer_type="Linear",
+         **kwargs
+     ):
+         super().__init__()
+
+         layer = layer_type
+         self.d_args = d_args
+         filts = d_args["filts"]
+         gat_dims = d_args["gat_dims"]
+         pool_ratios = d_args["pool_ratios"]
+         temperatures = d_args["temperatures"]
+
+         self.drop = nn.Dropout(0.5, inplace=True)
+         self.drop_way = nn.Dropout(0.2, inplace=True)
+
+         self.attention = nn.Sequential(
+             nn.Conv2d(64, 128, kernel_size=(1, 1)),
+             nn.SELU(inplace=True),
+             nn.BatchNorm2d(128),
+             nn.Conv2d(128, 64, kernel_size=(1, 1)),
+         )
+
+         self.pos_S = nn.Parameter(torch.randn(1, filts[0] // 3, filts[-1][-1]))
+         self.pos_T = nn.Parameter(torch.randn(1, n_frames, filts[0]))
+
+         self.GAT_layer_S = GraphAttentionLayer(filts[-1][-1],
+                                                gat_dims[0],
+                                                temperature=temperatures[0], size=size, layer=layer)
+         self.GAT_layer_T = GraphAttentionLayer(filts[-1][-1],
+                                                gat_dims[0],
+                                                temperature=temperatures[1], size=size, layer=layer)
+
+         self.branch1 = HSGALBranch_v1(gat_dims, temperatures, pool_ratios, size, layer=layer)
+         self.branch2 = HSGALBranch_v1(gat_dims, temperatures, pool_ratios, size, layer=layer)
+         self.branch3 = HSGALBranch_v1(gat_dims, temperatures, pool_ratios, size, layer=layer)
+         self.branch4 = HSGALBranch_v1(gat_dims, temperatures, pool_ratios, size, layer=layer)
+
+         self.pool_S = GraphPool(pool_ratios[0], gat_dims[0], 0.3, size=size, layer=layer)
+         self.pool_T = GraphPool(pool_ratios[1], gat_dims[0], 0.3, size=size, layer=layer)
+
+         out_features = 2
+         in_features = 5 * gat_dims[1]
+         if layer == 'Linear':
+             self.out_layer = nn.Linear(in_features, out_features)
+         else:
+             raise ValueError(f"Invalid layer type: {layer}")
+         self.enc = encoder(filts=filts)
+
+     def forward(self, x, Freq_aug=False):
+         """Forward pass through the KAN-AASIST model.
+
+         Args:
+             x: Input tensor of shape (batch_size, seq_len, channels)
+             Freq_aug: Whether to use frequency augmentation
+
+         Returns:
+             Model output for binary classification.
+         """
+         x = x + self.pos_T[:, :x.size(1), :]
+         x = self.enc(x)
+         # attention block assumes x is (batch, time, feature_dim)
+         # Adapt the attention block if needed for SSL features
+         w = self.attention(x)
+         w1 = F.softmax(w, dim=-1)
+         m = torch.sum(x * w1, dim=-1)
+         e_S = m.transpose(1, 2) + self.pos_S
+
+         gat_S = self.GAT_layer_S(e_S)
+         out_S = self.pool_S(gat_S)  # (#bs, #node, #dim)
+
+         w2 = F.softmax(w, dim=-2)
+         m1 = torch.sum(x * w2, dim=-2)
+
+         e_T = m1.transpose(1, 2)
+
+         gat_T = self.GAT_layer_T(e_T)
+         out_T = self.pool_T(gat_T)
+
+         out_T1, out_S1, master1 = self.branch1(out_T, out_S)
+         out_T2, out_S2, master2 = self.branch2(out_T, out_S)
+         out_T3, out_S3, master3 = self.branch3(out_T, out_S)
+         out_T4, out_S4, master4 = self.branch4(out_T, out_S)
+
+         out_T = torch.amax(torch.stack([out_T1, out_T2, out_T3, out_T4]), dim=0)
+         out_S = torch.amax(torch.stack([out_S1, out_S2, out_S3, out_S4]), dim=0)
+         master = torch.amax(torch.stack([master1, master2, master3, master4]), dim=0)
+
+         T_max, _ = torch.max(torch.abs(out_T), dim=1)
+         T_avg = torch.mean(out_T, dim=1)
+
+         S_max, _ = torch.max(torch.abs(out_S), dim=1)
+         S_avg = torch.mean(out_S, dim=1)
+
+         last_hidden = torch.cat(
+             [T_max, T_avg, S_max, S_avg, master.squeeze(1)], dim=1)
+
+         last_hidden = self.drop(last_hidden)
+         output = self.out_layer(last_hidden)
+
+         return output
+
+
+ class SpectraAASIST(nn.Module, PyTorchModelHubMixin):
+     def __init__(self, **kwargs):
+         super().__init__()
+         self.ssl_encoder = Wav2Vec2Encoder(model_name_or_path="facebook/wav2vec2-xls-r-300m",
+                                            ssl_out_dim=1024,
+                                            use_ssl_n_layers=None,
+                                            freeze_ssl_n_layers=0,
+                                            output_attentions=False,
+                                            output_hidden_states=False,
+                                            normalize_waveform=False)
+         self.bridge = MLPBridge(input_dim=1024,
+                                 output_dim=128,
+                                 hidden_dim=128, dropout=0.1, activation=nn.SELU(), n_layers=1)
+         self.aasist = KANAASIST(
+             d_args={
+                 "architecture": "AASIST",
+                 "nb_samp": 64400,
+                 "filts": [128, [1, 32], [32, 32], [32, 64], [64, 64]],
+                 "gat_dims": [64, 32],
+                 "pool_ratios": [0.5, 0.5, 0.5, 0.5],
+                 "temperatures": [2.0, 2.0, 100.0, 100.0]
+             },
+             size=200,
+             layer_type="Linear"
+         )
+
+     def forward(self, x):
+         x = self.ssl_encoder(x)
+         x = self.bridge(x)
+         x = self.aasist(x)
+         return x
+
+     @torch.inference_mode()
+     def classify(self, x, threshold: float = -1.0625009):
+         # Thresholds the bonafide logit; expects a single-item batch and
+         # returns a Python float (1.0 = bonafide, 0.0 = spoof).
+         x = self.forward(x)[:, 1]
+         x = (x > threshold).float()
+         return x.item()
+
+
+ spectra_aasist = SpectraAASIST
+
+ if __name__ == "__main__":
+     model = SpectraAASIST()
+     # Developer-local checkpoint path; replace with your own weights.
+     model.load_state_dict(torch.load("/data/home/maslov/maslov/WEIGHTS/aasist/baseline_v1_linear/weights/model.pt", map_location="cpu"))
+     model.eval()
+     x = torch.randn(1, 64400)
+     print(model(x).shape)
+     print(model(x))
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2e2727a7397f78d28b0a2a2b8ee031ff08143b9c431ea7f06fc29a808b0180db
+ size 1264151840