mrneuralnet committed on
Commit
3fb4562
1 Parent(s): f067c08

Initial commit

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +3 -33
  2. .gitignore +160 -0
  3. LICENSE +21 -0
  4. app.py +108 -0
  5. config.yaml +16 -0
  6. configs/finetuning/whisper_frontend_mesonet.yaml +16 -0
  7. configs/training/lcnn.yaml +14 -0
  8. configs/training/mesonet.yaml +15 -0
  9. configs/training/rawnet3.yaml +13 -0
  10. configs/training/specrnet.yaml +14 -0
  11. configs/training/whisper_frontend_lcnn.yaml +16 -0
  12. configs/training/whisper_frontend_lcnn_mfcc.yaml +15 -0
  13. configs/training/whisper_frontend_mesonet.yaml +16 -0
  14. configs/training/whisper_frontend_mesonet_mfcc.yaml +17 -0
  15. configs/training/whisper_frontend_specrnet.yaml +15 -0
  16. configs/training/whisper_frontend_specrnet_mfcc.yaml +16 -0
  17. configs/training/whisper_lcnn.yaml +15 -0
  18. configs/training/whisper_mesonet.yaml +16 -0
  19. configs/training/whisper_specrnet.yaml +15 -0
  20. download_whisper.py +29 -0
  21. evaluate_models.py +316 -0
  22. install.sh +6 -0
  23. mesonet_whisper_mfcc_finetuned.pth +3 -0
  24. sample_files/[FAKE] - jokowi - cupid [vocals].mp3 +3 -0
  25. sample_files/[REAL] - Obama at Rutgers: 'Ignorance Is Not a Virtue'_[cut_49sec].mp3 +3 -0
  26. sample_files/[REAL] - Obama's speech to the class of 2020 in 2 minutes | The Washington Post.wav +3 -0
  27. sample_files/[[FAKE] - y2mate.com - DeepFake AI generated synthetic video of Barack Obama.mp3 +3 -0
  28. src/__init__.py +3 -0
  29. src/commons.py +22 -0
  30. src/datasets/__init__.py +0 -0
  31. src/datasets/asvspoof_dataset.py +155 -0
  32. src/datasets/base_dataset.py +180 -0
  33. src/datasets/deepfake_asvspoof_dataset.py +86 -0
  34. src/datasets/detection_dataset.py +125 -0
  35. src/datasets/fakeavceleb_dataset.py +94 -0
  36. src/datasets/folder_dataset.py +75 -0
  37. src/datasets/in_the_wild_dataset.py +62 -0
  38. src/datasets/wavefake_dataset.py +85 -0
  39. src/frontends.py +72 -0
  40. src/metrics.py +15 -0
  41. src/models/__init__.py +0 -0
  42. src/models/assets/mel_filters.npz +0 -0
  43. src/models/assets/tiny_enc.en.pt +3 -0
  44. src/models/lcnn.py +247 -0
  45. src/models/meso_net.py +146 -0
  46. src/models/models.py +73 -0
  47. src/models/rawnet3.py +323 -0
  48. src/models/specrnet.py +226 -0
  49. src/models/whisper_lcnn.py +89 -0
  50. src/models/whisper_main.py +323 -0
.gitattributes CHANGED
@@ -1,35 +1,5 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
  *.pt filter=lfs diff=lfs merge=lfs -text
23
  *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
1
  *.pt filter=lfs diff=lfs merge=lfs -text
2
+ *.jpg filter=lfs diff=lfs merge=lfs -text
3
+ *.wav filter=lfs diff=lfs merge=lfs -text
4
+ *.mp3 filter=lfs diff=lfs merge=lfs -text
5
  *.pth filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,160 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) Piotr Kawa, Marcin Plata, Michał Czuba, Piotr Szymański, Piotr Syga
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
app.py ADDED
@@ -0,0 +1,108 @@
1
+ import base64
2
+ import json
3
+ import os, shutil
4
+ import re
5
+ import time
6
+ import uuid
7
+
8
+ import cv2
9
+
10
+ import numpy as np
11
+ import streamlit as st
12
+ from pydub import AudioSegment
13
+ import torch
14
+ import yaml
15
+ # from extract_video import extract_method_single_video
16
+
17
+ from utils import st_file_selector, img2base64
18
+ from evaluate_models import inference, load_model
19
+ from src import commons
20
+
21
+ import os
22
+
23
+ DEBUG = True
24
+
25
+ def main():
26
+ st.markdown("###")
27
+ uploaded_file = st.file_uploader('Upload an audio file', type=['wav', 'mp3'], accept_multiple_files=False)
28
+
29
+ with st.spinner(f'Loading samples...'):
30
+ while not os.path.isdir("sample_files"):
31
+ time.sleep(1)
32
+ st.markdown("### or")
33
+ selected_file = st_file_selector(st, path='sample_files', key='selected', label='Choose a sample audio file')
34
+
35
+ if uploaded_file:
36
+ random_id = uuid.uuid1()
37
+ ext = uploaded_file.name.split('.')[-1]
38
+
39
+ base_folder = "temps"
40
+ filename = "{}.{}".format(random_id, ext)
41
+ file_type = uploaded_file.type.split("/")[0]
42
+ filepath = f"{base_folder}/{filename}"
43
+
44
+ uploaded_file_length = len(uploaded_file.getvalue())
45
+ if uploaded_file_length > 0:
46
+ with open(filepath, 'wb') as f:
47
+ f.write(uploaded_file.read())
48
+ st.audio(uploaded_file, format=ext)
49
+ elif selected_file:
50
+ base_folder = "sample_files"
51
+ file_type = selected_file.split(".")[-1]
52
+ filename = selected_file.split("/")[-1]
53
+ filepath = f"{base_folder}/{selected_file}"
54
+
55
+ st.write('file_type', file_type)
56
+ with open(filepath, 'rb') as f:
57
+ audio_bytes = f.read()
58
+ st.audio(audio_bytes, format=file_type)
59
+ else:
60
+ return
61
+
62
+
63
+
64
+
65
+ with st.spinner(f'Analyzing {file_type}...'):
66
+
67
+
68
+ seed = config["data"].get("seed", 42)
69
+ # fix all seeds - this should not actually change anything
70
+ commons.set_seed(seed)
71
+
72
+ result = inference(
73
+ model,
74
+ datasets_path=filepath,
75
+ device=device,
76
+ )
77
+ result = result[0]
78
+
79
+ if 'Real' == result[0]:
80
+ st.success(f'Audio is real! \nprob:{result[1]}', icon="✅")
81
+ else:
82
+ st.error(f'Audio is fake! \nprob:{result[1]}', icon="🚨")
83
+
84
+ st.divider()
85
+ st.write('## Response JSON')
86
+ st.write(result)
87
+
88
+
89
+ def setup():
90
+ if not os.path.isdir("temps"):
91
+ os.makedirs("temps")
92
+
93
+
94
+
95
+ if __name__ == "__main__":
96
+ if torch.cuda.is_available():
97
+ device = "cuda"
98
+ else:
99
+ device = "cpu"
100
+
101
+ with open('config.yaml', "r") as f:
102
+ config = yaml.safe_load(f)
103
+
104
+ model = load_model(config, device)
105
+
106
+ st.title("Audio Deepfake Detection")
107
+ setup()
108
+ main()
config.yaml ADDED
@@ -0,0 +1,16 @@
1
+ data:
2
+ seed: 42
3
+
4
+ checkpoint:
5
+ path: "mesonet_whisper_mfcc_finetuned.pth"
6
+
7
+ model:
8
+ name: whisper_frontend_mesonet
9
+ optimizer:
10
+ lr: 1.0e-06
11
+ weight_decay: 0.0001
12
+ parameters:
13
+ fc1_dim: 1024
14
+ freeze_encoder: false
15
+ frontend_algorithm: ["mfcc"]
16
+ input_channels: 2
configs/finetuning/whisper_frontend_mesonet.yaml ADDED
@@ -0,0 +1,16 @@
1
+ data:
2
+ seed: 42
3
+
4
+ checkpoint:
5
+ path: "trained_models/whisper_frontend_mesonet/ckpt.pth"
6
+
7
+ model:
8
+ name: "whisper_frontend_mesonet"
9
+ parameters:
10
+ freeze_encoder: false
11
+ input_channels: 2
12
+ fc1_dim: 1024
13
+ frontend_algorithm: ["lfcc"]
14
+ optimizer:
15
+ lr: 1.0e-06
16
+ weight_decay: 0.0001
configs/training/lcnn.yaml ADDED
@@ -0,0 +1,14 @@
1
+ data:
2
+ seed: 42
3
+
4
+ checkpoint:
5
+ path: ""
6
+
7
+ model:
8
+ name: "lcnn"
9
+ parameters:
10
+ input_channels: 1
11
+ frontend_algorithm: ["mfcc"]
12
+ optimizer:
13
+ lr: 0.0001
14
+ weight_decay: 0.0001
configs/training/mesonet.yaml ADDED
@@ -0,0 +1,15 @@
1
+ data:
2
+ seed: 42
3
+
4
+ checkpoint:
5
+ path: ""
6
+
7
+ model:
8
+ name: "mesonet"
9
+ parameters:
10
+ input_channels: 1
11
+ fc1_dim: 1024
12
+ frontend_algorithm: ["lfcc"]
13
+ optimizer:
14
+ lr: 0.0001
15
+ weight_decay: 0.0001
configs/training/rawnet3.yaml ADDED
@@ -0,0 +1,13 @@
1
+ data:
2
+ seed: 42
3
+
4
+ checkpoint:
5
+ path: ""
6
+
7
+ model:
8
+ name: "rawnet3"
9
+ parameters: {}
10
+ optimizer:
11
+ lr: 0.001
12
+ weight_decay: 0.00005 # 5e-5
13
+
configs/training/specrnet.yaml ADDED
@@ -0,0 +1,14 @@
1
+ data:
2
+ seed: 42
3
+
4
+ checkpoint:
5
+ path: ""
6
+
7
+ model:
8
+ name: "specrnet"
9
+ parameters:
10
+ input_channels: 1
11
+ frontend_algorithm: ["lfcc"]
12
+ optimizer:
13
+ lr: 0.0001
14
+ weight_decay: 0.0001
configs/training/whisper_frontend_lcnn.yaml ADDED
@@ -0,0 +1,16 @@
1
+ data:
2
+ seed: 42
3
+
4
+ checkpoint:
5
+ path: ""
6
+
7
+ model:
8
+ name: "whisper_frontend_lcnn"
9
+ parameters:
10
+ freeze_encoder: True
11
+ input_channels: 2
12
+ frontend_algorithm: ["lfcc"]
13
+ optimizer:
14
+ lr: 0.0001
15
+ weight_decay: 0.0001
16
+
configs/training/whisper_frontend_lcnn_mfcc.yaml ADDED
@@ -0,0 +1,15 @@
1
+ data:
2
+ seed: 42
3
+
4
+ checkpoint:
5
+ path: ""
6
+
7
+ model:
8
+ name: "whisper_frontend_lcnn"
9
+ parameters:
10
+ freeze_encoder: True
11
+ input_channels: 2
12
+ frontend_algorithm: ["mfcc"]
13
+ optimizer:
14
+ lr: 0.0001
15
+ weight_decay: 0.0001
configs/training/whisper_frontend_mesonet.yaml ADDED
@@ -0,0 +1,16 @@
1
+ data:
2
+ seed: 42
3
+
4
+ checkpoint:
5
+ path: ""
6
+
7
+ model:
8
+ name: "whisper_frontend_mesonet"
9
+ parameters:
10
+ freeze_encoder: True
11
+ input_channels: 2
12
+ fc1_dim: 1024
13
+ frontend_algorithm: ["lfcc"]
14
+ optimizer:
15
+ lr: 0.0001
16
+ weight_decay: 0.0001
configs/training/whisper_frontend_mesonet_mfcc.yaml ADDED
@@ -0,0 +1,17 @@
1
+ data:
2
+ seed: 42
3
+
4
+ checkpoint:
5
+ path: ""
6
+
7
+ model:
8
+ name: "whisper_frontend_mesonet"
9
+ parameters:
10
+ freeze_encoder: True
11
+ input_channels: 2
12
+ fc1_dim: 1024
13
+ frontend_algorithm: ["mfcc"]
14
+ optimizer:
15
+ lr: 0.0001
16
+ weight_decay: 0.0001
17
+
configs/training/whisper_frontend_specrnet.yaml ADDED
@@ -0,0 +1,15 @@
1
+ data:
2
+ seed: 42
3
+
4
+ checkpoint:
5
+ path: ""
6
+
7
+ model:
8
+ name: "whisper_frontend_specrnet"
9
+ parameters:
10
+ freeze_encoder: True
11
+ input_channels: 2
12
+ frontend_algorithm: ["lfcc"]
13
+ optimizer:
14
+ lr: 0.0001
15
+ weight_decay: 0.0001
configs/training/whisper_frontend_specrnet_mfcc.yaml ADDED
@@ -0,0 +1,16 @@
1
+ data:
2
+ seed: 42
3
+
4
+ checkpoint:
5
+ path: ""
6
+
7
+ model:
8
+ name: "whisper_frontend_specrnet"
9
+ parameters:
10
+ freeze_encoder: True
11
+ input_channels: 2
12
+ frontend_algorithm: ["mfcc"]
13
+ optimizer:
14
+ lr: 0.0001
15
+ weight_decay: 0.0001
16
+
configs/training/whisper_lcnn.yaml ADDED
@@ -0,0 +1,15 @@
1
+ data:
2
+ seed: 42
3
+
4
+ checkpoint:
5
+ path: ""
6
+
7
+ model:
8
+ name: "whisper_lcnn"
9
+ parameters:
10
+ freeze_encoder: True
11
+ input_channels: 1
12
+ frontend_algorithm: ["lfcc"]
13
+ optimizer:
14
+ lr: 0.0001
15
+ weight_decay: 0.0001
configs/training/whisper_mesonet.yaml ADDED
@@ -0,0 +1,16 @@
1
+ data:
2
+ seed: 42
3
+
4
+ checkpoint:
5
+ path: ""
6
+
7
+ model:
8
+ name: "whisper_mesonet"
9
+ parameters:
10
+ freeze_encoder: True
11
+ input_channels: 1
12
+ fc1_dim: 1024
13
+ frontend_algorithm: []
14
+ optimizer:
15
+ lr: 0.0001
16
+ weight_decay: 0.0001
configs/training/whisper_specrnet.yaml ADDED
@@ -0,0 +1,15 @@
1
+ data:
2
+ seed: 42
3
+
4
+ checkpoint:
5
+ path: ""
6
+
7
+ model:
8
+ name: "whisper_specrnet"
9
+ parameters:
10
+ freeze_encoder: True
11
+ input_channels: 1
12
+ frontend_algorithm: ["lfcc"]
13
+ optimizer:
14
+ lr: 0.0001
15
+ weight_decay: 0.0001
download_whisper.py ADDED
@@ -0,0 +1,29 @@
1
+ # pip install git+https://github.com/openai/whisper.git
2
+ from collections import OrderedDict
3
+ import whisper
4
+ import torch
5
+
6
+ from src.commons import WHISPER_MODEL_WEIGHTS_PATH
7
+
8
+ def download_whisper():
9
+ model = whisper.load_model("tiny.en")
10
+ return model
11
+
12
+
13
+ def extract_and_save_encoder(model):
14
+ model_ckpt = OrderedDict()
15
+
16
+ model_ckpt['model_state_dict'] = OrderedDict()
17
+
18
+ for key, value in model.encoder.state_dict().items():
19
+ model_ckpt['model_state_dict'][f'encoder.{key}'] = value
20
+
21
+ model_ckpt['dims'] = model.dims
22
+ torch.save(model_ckpt, WHISPER_MODEL_WEIGHTS_PATH)
23
+
24
+
25
+ if __name__ == "__main__":
26
+ model = download_whisper()
27
+ print("Downloaded Whisper model!")
28
+ extract_and_save_encoder(model)
29
+ print(f"Saved encoder at '{WHISPER_MODEL_WEIGHTS_PATH}'")
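Note: a minimal usage sketch for the script above; it only restates the functions defined in download_whisper.py and writes the encoder checkpoint to WHISPER_MODEL_WEIGHTS_PATH (src/models/assets/tiny_enc.en.pt).

from download_whisper import download_whisper, extract_and_save_encoder

model = download_whisper()        # loads the openai-whisper "tiny.en" model
extract_and_save_encoder(model)   # saves only the encoder weights plus the model dims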
evaluate_models.py ADDED
@@ -0,0 +1,316 @@
1
+ import argparse
2
+ import logging
3
+ from pathlib import Path
4
+ from typing import Dict, List, Optional, Union
5
+ import sys
6
+
7
+ import torch
8
+ import yaml
9
+ from sklearn.metrics import precision_recall_fscore_support, roc_auc_score
10
+ from torch.utils.data import DataLoader
11
+
12
+ from src import metrics, commons
13
+ from src.models import models
14
+ from src.datasets.base_dataset import SimpleAudioFakeDataset
15
+ from src.datasets.in_the_wild_dataset import InTheWildDataset
16
+ from src.datasets.folder_dataset import FolderDataset, FileDataset
17
+
18
+
19
+ def get_dataset(
20
+ datasets_paths: List[Union[Path, str]],
21
+ amount_to_use: Optional[int],
22
+ ) -> SimpleAudioFakeDataset:
23
+ data_val = FolderDataset(
24
+ path=datasets_paths[0]
25
+ )
26
+ return data_val
27
+
28
+ def get_dataset_file(
29
+ datasets_path,
30
+ amount_to_use: Optional[int],
31
+ ) -> SimpleAudioFakeDataset:
32
+ data_val = FileDataset(
33
+ path=datasets_path
34
+ )
35
+ return data_val
36
+
37
+
38
+ def evaluate_nn(
39
+ model_paths: List[Path],
40
+ datasets_paths: List[Union[Path, str]],
41
+ model_config: Dict,
42
+ device: str,
43
+ amount_to_use: Optional[int] = None,
44
+ batch_size: int = 8,
45
+ ):
46
+ logging.info("Loading data...")
47
+ model_name, model_parameters = model_config["name"], model_config["parameters"]
48
+
49
+ # Load model architecture
50
+ model = models.get_model(
51
+ model_name=model_name,
52
+ config=model_parameters,
53
+ device=device,
54
+ )
55
+ # If provided weights, apply corresponding ones (from an appropriate fold)
56
+ if len(model_paths):
57
+ state_dict = torch.load(model_paths, map_location=device)
58
+ model.load_state_dict(state_dict)
59
+ model = model.to(device)
60
+
61
+ data_val = get_dataset(
62
+ datasets_paths=datasets_paths,
63
+ amount_to_use=amount_to_use,
64
+ )
65
+
66
+ logging.info(
67
+ f"Testing '{model_name}' model, weights path: '{model_paths}', on {len(data_val)} audio files."
68
+ )
69
+ test_loader = DataLoader(
70
+ data_val,
71
+ batch_size=batch_size,
72
+ shuffle=True,
73
+ drop_last=False,
74
+ num_workers=3,
75
+ )
76
+
77
+ batches_number = len(data_val) // batch_size
78
+ num_correct = 0.0
79
+ num_total = 0.0
80
+
81
+ y_pred = torch.Tensor([]).to(device)
82
+ y = torch.Tensor([]).to(device)
83
+ y_pred_label = torch.Tensor([]).to(device)
84
+
85
+ preds = []
86
+
87
+ for i, (batch_x, _, batch_y, metadata) in enumerate(test_loader):
88
+ model.eval()
89
+ _, path, _, _ = metadata
90
+ if i % 10 == 0:
91
+ print(f"Batch [{i}/{batches_number}]")
92
+
93
+ with torch.no_grad():
94
+ batch_x = batch_x.to(device)
95
+ batch_y = batch_y.to(device)
96
+ num_total += batch_x.size(0)
97
+
98
+ batch_pred = model(batch_x).squeeze(1)
99
+ batch_pred = torch.sigmoid(batch_pred)
100
+ batch_pred_label = (batch_pred + 0.5).int()
101
+
102
+ num_correct += (batch_pred_label == batch_y.int()).sum(dim=0).item()
103
+
104
+ y_pred = torch.concat([y_pred, batch_pred], dim=0)
105
+ y_pred_label = torch.concat([y_pred_label, batch_pred_label], dim=0)
106
+ y = torch.concat([y, batch_y], dim=0)
107
+
108
+ for i in range(len(y_pred_label)):
109
+ label = 'Fake' if y_pred_label[i] == 0 else 'Real'
110
+ print(f'{path[i]}')
111
+ print(f' Prediction: : {label}')
112
+ print(f' Probability: {y_pred[i]})')
113
+ preds.append((label, y_pred[i].detach().cpu().item()))
114
+
115
+ return preds
116
+
117
+ eval_accuracy = (num_correct / num_total) * 100
118
+
119
+ precision, recall, f1_score, support = precision_recall_fscore_support(
120
+ y.cpu().numpy(), y_pred_label.cpu().numpy(), average="binary", beta=1.0
121
+ )
122
+ auc_score = roc_auc_score(y_true=y.cpu().numpy(), y_score=y_pred.cpu().numpy())
123
+
124
+ # For EER flip values, following original evaluation implementation
125
+ y_for_eer = 1 - y
126
+
127
+ thresh, eer, fpr, tpr = metrics.calculate_eer(
128
+ y=y_for_eer.cpu().numpy(),
129
+ y_score=y_pred.cpu().numpy(),
130
+ )
131
+
132
+ eer_label = f"eval/eer"
133
+ accuracy_label = f"eval/accuracy"
134
+ precision_label = f"eval/precision"
135
+ recall_label = f"eval/recall"
136
+ f1_label = f"eval/f1_score"
137
+ auc_label = f"eval/auc"
138
+
139
+ logging.info(
140
+ f"{eer_label}: {eer:.4f}, {accuracy_label}: {eval_accuracy:.4f}, {precision_label}: {precision:.4f}, {recall_label}: {recall:.4f}, {f1_label}: {f1_score:.4f}, {auc_label}: {auc_score:.4f}"
141
+ )
142
+
143
+ def load_model(config, device):
144
+ model_config = config['model']
145
+ model_name, model_parameters = model_config["name"], model_config["parameters"]
146
+ model_paths = config["checkpoint"].get("path", [])
147
+ # Load model architecture
148
+ model = models.get_model(
149
+ model_name=model_name,
150
+ config=model_parameters,
151
+ device=device,
152
+ )
153
+ # If provided weights, apply corresponding ones (from an appropriate fold)
154
+ if len(model_paths):
155
+ state_dict = torch.load(model_paths, map_location=device)
156
+ model.load_state_dict(state_dict)
157
+ model = model.to(device)
158
+ return model
159
+
160
+ def inference(
161
+ model,
162
+ datasets_path,
163
+ device: str,
164
+ amount_to_use: Optional[int] = None,
165
+ batch_size: int = 8,
166
+ ):
167
+ logging.info("Loading data...")
168
+
169
+
170
+ data_val = get_dataset_file(
171
+ datasets_path=datasets_path,
172
+ amount_to_use=amount_to_use,
173
+ )
174
+
175
+ test_loader = DataLoader(
176
+ data_val,
177
+ batch_size=batch_size,
178
+ shuffle=True,
179
+ drop_last=False,
180
+ num_workers=3,
181
+ )
182
+
183
+ batches_number = len(data_val) // batch_size
184
+ num_correct = 0.0
185
+ num_total = 0.0
186
+
187
+ y_pred = torch.Tensor([]).to(device)
188
+ y = torch.Tensor([]).to(device)
189
+ y_pred_label = torch.Tensor([]).to(device)
190
+
191
+ preds = []
192
+
193
+ for i, (batch_x, _, batch_y, metadata) in enumerate(test_loader):
194
+ model.eval()
195
+ _, path, _, _ = metadata
196
+ if i % 10 == 0:
197
+ print(f"Batch [{i}/{batches_number}]")
198
+
199
+ with torch.no_grad():
200
+ batch_x = batch_x.to(device)
201
+ batch_y = batch_y.to(device)
202
+ num_total += batch_x.size(0)
203
+
204
+ batch_pred = model(batch_x).squeeze(1)
205
+ batch_pred = torch.sigmoid(batch_pred)
206
+ batch_pred_label = (batch_pred + 0.5).int()
207
+
208
+ num_correct += (batch_pred_label == batch_y.int()).sum(dim=0).item()
209
+
210
+ y_pred = torch.concat([y_pred, batch_pred], dim=0)
211
+ y_pred_label = torch.concat([y_pred_label, batch_pred_label], dim=0)
212
+ y = torch.concat([y, batch_y], dim=0)
213
+
214
+ for i in range(len(y_pred_label)):
215
+ label = 'Fake' if y_pred_label[i] == 0 else 'Real'
216
+ print(f'{path[i]}')
217
+ print(f' Prediction: {label}')
218
+ print(f' Probability: {y_pred[i]}')
219
+ preds.append((label, y_pred[i].detach().cpu().item()))
220
+
221
+ return preds
222
+
223
+ eval_accuracy = (num_correct / num_total) * 100
224
+
225
+ precision, recall, f1_score, support = precision_recall_fscore_support(
226
+ y.cpu().numpy(), y_pred_label.cpu().numpy(), average="binary", beta=1.0
227
+ )
228
+ auc_score = roc_auc_score(y_true=y.cpu().numpy(), y_score=y_pred.cpu().numpy())
229
+
230
+ # For EER flip values, following original evaluation implementation
231
+ y_for_eer = 1 - y
232
+
233
+ thresh, eer, fpr, tpr = metrics.calculate_eer(
234
+ y=y_for_eer.cpu().numpy(),
235
+ y_score=y_pred.cpu().numpy(),
236
+ )
237
+
238
+ eer_label = f"eval/eer"
239
+ accuracy_label = f"eval/accuracy"
240
+ precision_label = f"eval/precision"
241
+ recall_label = f"eval/recall"
242
+ f1_label = f"eval/f1_score"
243
+ auc_label = f"eval/auc"
244
+
245
+ logging.info(
246
+ f"{eer_label}: {eer:.4f}, {accuracy_label}: {eval_accuracy:.4f}, {precision_label}: {precision:.4f}, {recall_label}: {recall:.4f}, {f1_label}: {f1_score:.4f}, {auc_label}: {auc_score:.4f}"
247
+ )
248
+
249
+
250
+ def main(args):
251
+ LOGGER = logging.getLogger()
252
+ LOGGER.setLevel(logging.INFO)
253
+
254
+ ch = logging.StreamHandler()
255
+ formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
256
+ ch.setFormatter(formatter)
257
+ LOGGER.addHandler(ch)
258
+ logging.basicConfig(stream=sys.stdout, level=logging.INFO)
259
+
260
+ if not args.cpu and torch.cuda.is_available():
261
+ device = "cuda"
262
+ else:
263
+ device = "cpu"
264
+
265
+ with open(args.config, "r") as f:
266
+ config = yaml.safe_load(f)
267
+
268
+ seed = config["data"].get("seed", 42)
269
+ # fix all seeds - this should not actually change anything
270
+ commons.set_seed(seed)
271
+
272
+ evaluate_nn(
273
+ model_paths=config["checkpoint"].get("path", []),
274
+ datasets_paths=[
275
+ args.folder_path,
276
+ ],
277
+ model_config=config["model"],
278
+ amount_to_use=args.amount,
279
+ device=device,
280
+ )
281
+
282
+
283
+ def parse_args():
284
+ parser = argparse.ArgumentParser()
285
+
286
+ # If assigned as None, then it won't be taken into account
287
+ FOLDER_DATASET_PATH = "sample_files"
288
+
289
+ parser.add_argument(
290
+ "--folder_path", type=str, default=FOLDER_DATASET_PATH
291
+ )
292
+
293
+ default_model_config = "config.yaml"
294
+ parser.add_argument(
295
+ "--config",
296
+ help="Model config file path (default: config.yaml)",
297
+ type=str,
298
+ default=default_model_config,
299
+ )
300
+
301
+ default_amount = None
302
+ parser.add_argument(
303
+ "--amount",
304
+ "-a",
305
+ help=f"Amount of files to load from each directory (default: {default_amount} - use all).",
306
+ type=int,
307
+ default=default_amount,
308
+ )
309
+
310
+ parser.add_argument("--cpu", "-c", help="Force using cpu", action="store_true")
311
+
312
+ return parser.parse_args()
313
+
314
+
315
+ if __name__ == "__main__":
316
+ main(parse_args())
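Note: a minimal sketch of single-file scoring with load_model and inference above, mirroring how app.py calls them; config.yaml is the checked-in config and the audio path is a hypothetical example.

import torch
import yaml

from evaluate_models import inference, load_model

with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = load_model(config, device)

# inference returns a list of (label, probability) tuples, one per audio file
preds = inference(model, datasets_path="sample_files/example.wav", device=device)
print(preds[0])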
install.sh ADDED
@@ -0,0 +1,6 @@
1
+ conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch -y
2
+
3
+ pip install asteroid-filterbanks==0.4.0
4
+ pip install librosa==0.9.2
5
+ pip install git+https://github.com/openai/whisper.git@7858aa9c08d98f75575035ecd6481f462d66ca27
6
+ pip install pandas==2.0.2
mesonet_whisper_mfcc_finetuned.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a34a00d0961303274e1cf7a2dc2b6e9f9d568ff0416300be1aaee1c2e2ceee12
3
+ size 32983925
sample_files/[FAKE] - jokowi - cupid [vocals].mp3 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1ce8dce41de4f44908c57deea26d4efe5a74f9a37700a76a94ac065e862304c0
3
+ size 775449
sample_files/[REAL] - Obama at Rutgers: 'Ignorance Is Not a Virtue'_[cut_49sec].mp3 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6694c6d329f8a372896808c1f7c1e487eec65e5ad2fb3d244d80729b211ac0c4
3
+ size 1950720
sample_files/[REAL] - Obama's speech to the class of 2020 in 2 minutes | The Washington Post.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3f045f2b80fdc136c63bfc5897dd9a3d34a3b60dba886cae297d4425db30d5d9
3
+ size 27507540
sample_files/[[FAKE] - y2mate.com - DeepFake AI generated synthetic video of Barack Obama.mp3 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fd7100f013cb23ae4af3e00594330838dcae39dc86d669ef8fd215a6a6d88f53
3
+ size 273900
src/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ import logging
2
+
3
+ logging.getLogger(__name__).addHandler(logging.NullHandler())
src/commons.py ADDED
@@ -0,0 +1,22 @@
1
+ """Utility file for src toolkit."""
2
+ import os
3
+ import random
4
+
5
+ import numpy as np
6
+ import torch
7
+
8
+ WHISPER_MODEL_WEIGHTS_PATH = "src/models/assets/tiny_enc.en.pt"
9
+
10
+
11
+ def set_seed(seed: int):
12
+ """Fix PRNG seed for reproducible experiments.
13
+ """
14
+ random.seed(seed)
15
+ np.random.seed(seed)
16
+ torch.manual_seed(seed)
17
+ if torch.cuda.is_available():
18
+ torch.cuda.manual_seed(seed)
19
+ torch.cuda.manual_seed_all(seed)
20
+ torch.backends.cudnn.deterministic = True
21
+ torch.backends.cudnn.benchmark = False
22
+ os.environ["PYTHONHASHSEED"] = str(seed)
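Note: a small sketch of the intended call order, as app.py and evaluate_models.py use it — set_seed is called once before any model or dataloader is constructed.

from src.commons import set_seed

set_seed(42)  # seeds random, numpy and torch (CPU/CUDA), makes cuDNN deterministic, sets PYTHONHASHSEED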
src/datasets/__init__.py ADDED
File without changes
src/datasets/asvspoof_dataset.py ADDED
@@ -0,0 +1,155 @@
1
+ from pathlib import Path
2
+
3
+ import pandas as pd
4
+ if __name__ == "__main__":
5
+ import sys
6
+ sys.path.append(str(Path(__file__).parent.parent.parent.absolute()))
7
+
8
+ from src.datasets.base_dataset import SimpleAudioFakeDataset
9
+
10
+ ASVSPOOF_SPLIT = {
11
+ "train": ['A01', 'A07', 'A08', 'A02', 'A09', 'A10', 'A03', 'A04', 'A05', 'A06', 'A11', 'A12', 'A13', 'A14', 'A15', 'A16', 'A17', 'A18', 'A19'],
12
+ "test": ['A01', 'A07', 'A08', 'A02', 'A09', 'A10', 'A03', 'A04', 'A05', 'A06', 'A11', 'A12', 'A13', 'A14', 'A15', 'A16', 'A17', 'A18', 'A19'],
13
+ "val": ['A01', 'A07', 'A08', 'A02', 'A09', 'A10', 'A03', 'A04', 'A05', 'A06', 'A11', 'A12', 'A13', 'A14', 'A15', 'A16', 'A17', 'A18', 'A19'],
14
+ "partition_ratio": [0.7, 0.15],
15
+ "seed": 45,
16
+ }
17
+
18
+
19
+ class ASVSpoofDataset(SimpleAudioFakeDataset):
20
+
21
+ protocol_folder_name = "ASVspoof2019_LA_cm_protocols"
22
+ subset_dir_prefix = "ASVspoof2019_LA_"
23
+ subsets = ("train", "dev", "eval")
24
+
25
+ def __init__(self, path, subset="train", transform=None):
26
+ super().__init__(subset, transform)
27
+ self.path = path
28
+
29
+ self.allowed_attacks = ASVSPOOF_SPLIT[subset]
30
+ self.partition_ratio = ASVSPOOF_SPLIT["partition_ratio"]
31
+ self.seed = ASVSPOOF_SPLIT["seed"]
32
+
33
+ self.samples = pd.DataFrame()
34
+
35
+ for subset in self.subsets:
36
+ subset_dir = Path(self.path) / f"{self.subset_dir_prefix}{subset}"
37
+ subset_protocol_path = self.get_protocol_path(subset)
38
+ subset_samples = self.read_protocol(subset_dir, subset_protocol_path)
39
+
40
+ self.samples = pd.concat([self.samples, subset_samples])
41
+
42
+ self.transform = transform
43
+
44
+ def get_protocol_path(self, subset):
45
+ paths = list((Path(self.path) / self.protocol_folder_name).glob("*.txt"))
46
+ for path in paths:
47
+ if subset in Path(path).stem:
48
+ return path
49
+
50
+ def read_protocol(self, subset_dir, protocol_path):
51
+ samples = {
52
+ "user_id": [],
53
+ "sample_name": [],
54
+ "attack_type": [],
55
+ "label": [],
56
+ "path": []
57
+ }
58
+
59
+ real_samples = []
60
+ fake_samples = []
61
+ with open(protocol_path, "r") as file:
62
+ for line in file:
63
+ attack_type = line.strip().split(" ")[3]
64
+
65
+ if attack_type == "-":
66
+ real_samples.append(line)
67
+ elif attack_type in self.allowed_attacks:
68
+ fake_samples.append(line)
69
+
70
+ if attack_type not in self.allowed_attacks:
71
+ continue
72
+
73
+ fake_samples = self.split_samples(fake_samples)
74
+ for line in fake_samples:
75
+ samples = self.add_line_to_samples(samples, line, subset_dir)
76
+
77
+ real_samples = self.split_samples(real_samples)
78
+ for line in real_samples:
79
+ samples = self.add_line_to_samples(samples, line, subset_dir)
80
+
81
+ return pd.DataFrame(samples)
82
+
83
+ @staticmethod
84
+ def add_line_to_samples(samples, line, subset_dir):
85
+ user_id, sample_name, _, attack_type, label = line.strip().split(" ")
86
+ samples["user_id"].append(user_id)
87
+ samples["sample_name"].append(sample_name)
88
+ samples["attack_type"].append(attack_type)
89
+ samples["label"].append(label)
90
+
91
+ assert (subset_dir / "flac" / f"{sample_name}.flac").exists()
92
+ samples["path"].append(subset_dir / "flac" / f"{sample_name}.flac")
93
+
94
+ return samples
95
+
96
+ class ASVSpoof2019DatasetOriginal(ASVSpoofDataset):
97
+
98
+ subsets = {"train": "train", "test": "dev", "val": "eval"}
99
+
100
+ protocol_folder_name = "ASVspoof2019_LA_cm_protocols"
101
+ subset_dir_prefix = "ASVspoof2019_LA_"
102
+ subset_dirs_attacks = {
103
+ "train": ["A01", "A02", "A03", "A04", "A05", "A06"],
104
+ "dev": ["A01", "A02", "A03", "A04", "A05", "A06"],
105
+ "eval": [
106
+ "A07", "A08", "A09", "A10", "A11", "A12", "A13", "A14", "A15",
107
+ "A16", "A17", "A18", "A19"
108
+ ]
109
+ }
110
+
111
+
112
+ def __init__(self, path, fold_subset="train"):
113
+ """
114
+ Initialise object. Skip __init__ of ASVSpoofDataset due to different
115
+ logic, but follow SimpleAudioFakeDataset constructor.
116
+ """
117
+ super(ASVSpoofDataset, self).__init__(float('inf'), fold_subset)
118
+ self.path = path
119
+ subset = self.subsets[fold_subset]
120
+ self.allowed_attacks = self.subset_dirs_attacks[subset]
121
+ subset_dir = Path(self.path) / f"{self.subset_dir_prefix}{subset}"
122
+ subset_protocol_path = self.get_protocol_path(subset)
123
+ self.samples = self.read_protocol(subset_dir, subset_protocol_path)
124
+
125
+ def read_protocol(self, subset_dir, protocol_path):
126
+ samples = {
127
+ "user_id": [],
128
+ "sample_name": [],
129
+ "attack_type": [],
130
+ "label": [],
131
+ "path": []
132
+ }
133
+
134
+ real_samples = []
135
+ fake_samples = []
136
+
137
+ with open(protocol_path, "r") as file:
138
+ for line in file:
139
+ attack_type = line.strip().split(" ")[3]
140
+ if attack_type == "-":
141
+ real_samples.append(line)
142
+ elif attack_type in self.allowed_attacks:
143
+ fake_samples.append(line)
144
+ else:
145
+ raise ValueError(
146
+ "Tried to load attack that shouldn't be here!"
147
+ )
148
+
149
+ for line in fake_samples:
150
+ samples = self.add_line_to_samples(samples, line, subset_dir)
151
+ for line in real_samples:
152
+ samples = self.add_line_to_samples(samples, line, subset_dir)
153
+
154
+ return pd.DataFrame(samples)
155
+
src/datasets/base_dataset.py ADDED
@@ -0,0 +1,180 @@
1
+ """Base dataset classes."""
2
+ import logging
3
+ import math
4
+ import random
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ import torch
9
+ import torchaudio
10
+ from torch.utils.data import Dataset
11
+ from torch.utils.data.dataset import T_co
12
+
13
+
14
+ LOGGER = logging.getLogger(__name__)
15
+
16
+ SAMPLING_RATE = 16_000
17
+ APPLY_NORMALIZATION = True
18
+ APPLY_TRIMMING = True
19
+ APPLY_PADDING = True
20
+ FRAMES_NUMBER = 480_000 # <- originally 64_600
21
+
22
+
23
+ SOX_SILENCE = [
24
+ # trim all silence that is longer than 0.2s and louder than 1% volume (relative to the file)
25
+ # from beginning and middle/end
26
+ ["silence", "1", "0.2", "1%", "-1", "0.2", "1%"],
27
+ ]
28
+
29
+
30
+ class SimpleAudioFakeDataset(Dataset):
31
+ def __init__(
32
+ self,
33
+ subset,
34
+ transform=None,
35
+ return_label: bool = True,
36
+ return_meta: bool = True,
37
+ ):
38
+ self.transform = transform
39
+
40
+ self.subset = subset
41
+ self.allowed_attacks = None
42
+ self.partition_ratio = None
43
+ self.seed = None
44
+ self.return_label = return_label
45
+ self.return_meta = return_meta
46
+
47
+ def split_samples(self, samples_list):
48
+ if isinstance(samples_list, pd.DataFrame):
49
+ samples_list = samples_list.sort_values(by=list(samples_list.columns))
50
+ samples_list = samples_list.sample(frac=1, random_state=self.seed)
51
+ else:
52
+ samples_list = sorted(samples_list)
53
+ random.seed(self.seed)
54
+ random.shuffle(samples_list)
55
+
56
+ p, s = self.partition_ratio
57
+ subsets = np.split(
58
+ samples_list, [int(p * len(samples_list)), int((p + s) * len(samples_list))]
59
+ )
60
+ return dict(zip(["train", "test", "val"], subsets))[self.subset]
61
+
62
+ def df2tuples(self):
63
+ tuple_samples = []
64
+ for i, elem in self.samples.iterrows():
65
+ tuple_samples.append(
66
+ (str(elem["path"]), elem["label"], elem["attack_type"])
67
+ )
68
+
69
+ self.samples = tuple_samples
70
+
71
+
72
+ return self.samples
73
+
74
+ def __getitem__(self, index) -> T_co:
75
+ if isinstance(self.samples, pd.DataFrame):
76
+ sample = self.samples.iloc[index]
77
+
78
+ path = str(sample["path"])
79
+ label = sample["label"]
80
+ attack_type = sample["attack_type"]
81
+ if type(attack_type) != str and math.isnan(attack_type):
82
+ attack_type = "N/A"
83
+ else:
84
+ path, label, attack_type = self.samples[index]
85
+
86
+ waveform, sample_rate = torchaudio.load(path, normalize=APPLY_NORMALIZATION)
87
+ # import librosa  # only needed for the commented-out librosa loading below
88
+ # waveform, sample_rate = librosa.load(path, sr=SAMPLING_RATE)
89
+ # waveform = torch.tensor(waveform)
90
+ # print('waveform', waveform)  # debug leftover
91
+ real_sec_length = len(waveform[0]) / sample_rate
92
+
93
+ waveform, sample_rate = apply_preprocessing(waveform, sample_rate)
94
+
95
+ return_data = [waveform, sample_rate]
96
+ if self.return_label:
97
+ label = 1 if label == "bonafide" else 0
98
+ return_data.append(label)
99
+
100
+ if self.return_meta:
101
+ return_data.append(
102
+ (
103
+ attack_type,
104
+ path,
105
+ self.subset,
106
+ real_sec_length,
107
+ )
108
+ )
109
+ return return_data
110
+
111
+ def __len__(self):
112
+ return len(self.samples)
113
+
114
+
115
+ def apply_preprocessing(
116
+ waveform,
117
+ sample_rate,
118
+ ):
119
+ if sample_rate != SAMPLING_RATE and SAMPLING_RATE != -1:
120
+ waveform, sample_rate = resample_wave(waveform, sample_rate, SAMPLING_RATE)
121
+
122
+ # Stereo to mono
123
+ if waveform.dim() > 1 and waveform.shape[0] > 1:
124
+ waveform = waveform[:1, ...]
125
+
126
+ # Trim too long utterances...
127
+ if APPLY_TRIMMING:
128
+ waveform, sample_rate = apply_trim(waveform, sample_rate)
129
+
130
+ # ... or pad too short ones.
131
+ if APPLY_PADDING:
132
+ waveform = apply_pad(waveform, FRAMES_NUMBER)
133
+
134
+ return waveform, sample_rate
135
+
136
+
137
+ def resample_wave(waveform, sample_rate, target_sample_rate):
138
+ # waveform, sample_rate = torchaudio.sox_effects.apply_effects_tensor(
139
+ # waveform, sample_rate, [["rate", f"{target_sample_rate}"]]
140
+ # )
141
+ waveform = torchaudio.functional.resample(waveform, orig_freq=sample_rate, new_freq=target_sample_rate)
142
+ return waveform, target_sample_rate
143
+
144
+
145
+ def resample_file(path, target_sample_rate, normalize=True):
146
+ waveform, sample_rate = torchaudio.sox_effects.apply_effects_file(
147
+ path, [["rate", f"{target_sample_rate}"]], normalize=normalize
148
+ )
149
+
150
+ return waveform, sample_rate
151
+
152
+
153
+ def apply_trim(waveform, sample_rate):
154
+ # (
155
+ # waveform_trimmed,
156
+ # sample_rate_trimmed,
157
+ # ) = torchaudio.sox_effects.apply_effects_tensor(waveform, sample_rate, SOX_SILENCE)
158
+
159
+ # ["silence", "1", "0.2", "1%", "-1", "0.2", "1%"]  # sox params kept for reference; VAD used below
160
+ waveform_trimmed = torchaudio.functional.vad(waveform, sample_rate=sample_rate)
161
+
162
+ if waveform_trimmed.size()[1] > 0:
163
+ waveform = waveform_trimmed
164
+
165
+ return waveform, sample_rate
166
+
167
+
168
+ def apply_pad(waveform, cut):
169
+ """Pad wave by repeating signal until `cut` length is achieved."""
170
+ waveform = waveform.squeeze(0)
171
+ waveform_len = waveform.shape[0]
172
+
173
+ if waveform_len >= cut:
174
+ return waveform[:cut]
175
+
176
+ # need to pad
177
+ num_repeats = int(cut / waveform_len) + 1
178
+ padded_waveform = torch.tile(waveform, (1, num_repeats))[:, :cut][0]
179
+
180
+ return padded_waveform
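Note: a minimal sketch of what apply_preprocessing above does to one clip; the 1-second random waveform is an assumed stand-in for the tensors torchaudio.load returns in __getitem__.

import torch

from src.datasets.base_dataset import FRAMES_NUMBER, SAMPLING_RATE, apply_preprocessing

wave = torch.randn(1, SAMPLING_RATE)                  # assumed 1 s mono clip at 16 kHz
wave, sr = apply_preprocessing(wave, SAMPLING_RATE)   # mono + VAD trim, then pad by tiling
print(wave.shape)                                     # torch.Size([480000]) == FRAMES_NUMBER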
src/datasets/deepfake_asvspoof_dataset.py ADDED
@@ -0,0 +1,86 @@
1
+ import logging
2
+ from pathlib import Path
3
+
4
+ import pandas as pd
5
+
6
+ from src.datasets.base_dataset import SimpleAudioFakeDataset
7
+
8
+ DF_ASVSPOOF_SPLIT = {
9
+ "partition_ratio": [0.7, 0.15],
10
+ "seed": 45
11
+ }
12
+
13
+ LOGGER = logging.getLogger()
14
+
15
+ class DeepFakeASVSpoofDataset(SimpleAudioFakeDataset):
16
+
17
+ protocol_file_name = "keys/CM/trial_metadata.txt"
18
+ subset_dir_prefix = "ASVspoof2021_DF_eval"
19
+ subset_parts = ("part00", "part01", "part02", "part03")
20
+
21
+ def __init__(self, path, subset="train", transform=None):
22
+ super().__init__(subset, transform)
23
+ self.path = path
24
+
25
+ self.partition_ratio = DF_ASVSPOOF_SPLIT["partition_ratio"]
26
+ self.seed = DF_ASVSPOOF_SPLIT["seed"]
27
+
28
+ self.flac_paths = self.get_file_references()
29
+ self.samples = self.read_protocol()
30
+
31
+ self.transform = transform
32
+ LOGGER.info(f"Spoof: {len(self.samples[self.samples['label'] == 'spoof'])}")
33
+ LOGGER.info(f"Original: {len(self.samples[self.samples['label'] == 'bonafide'])}")
34
+
35
+ def get_file_references(self):
36
+ flac_paths = {}
37
+ for part in self.subset_parts:
38
+ path = Path(self.path) / f"{self.subset_dir_prefix}_{part}" / self.subset_dir_prefix / "flac"
39
+ flac_list = list(path.glob("*.flac"))
40
+
41
+ for path in flac_list:
42
+ flac_paths[path.stem] = path
43
+
44
+ return flac_paths
45
+
46
+ def read_protocol(self):
47
+ samples = {
48
+ "sample_name": [],
49
+ "label": [],
50
+ "path": [],
51
+ "attack_type": [],
52
+ }
53
+
54
+ real_samples = []
55
+ fake_samples = []
56
+ with open(Path(self.path) / self.protocol_file_name, "r") as file:
57
+ for line in file:
58
+ label = line.strip().split(" ")[5]
59
+
60
+ if label == "bonafide":
61
+ real_samples.append(line)
62
+ elif label == "spoof":
63
+ fake_samples.append(line)
64
+
65
+ fake_samples = self.split_samples(fake_samples)
66
+ for line in fake_samples:
67
+ samples = self.add_line_to_samples(samples, line)
68
+
69
+ real_samples = self.split_samples(real_samples)
70
+ for line in real_samples:
71
+ samples = self.add_line_to_samples(samples, line)
72
+
73
+ return pd.DataFrame(samples)
74
+
75
+ def add_line_to_samples(self, samples, line):
76
+ _, sample_name, _, _, _, label, _, _ = line.strip().split(" ")
77
+ samples["sample_name"].append(sample_name)
78
+ samples["label"].append(label)
79
+ samples["attack_type"].append(label)
80
+
81
+ sample_path = self.flac_paths[sample_name]
82
+ assert sample_path.exists()
83
+ samples["path"].append(sample_path)
84
+
85
+ return samples
86
+
src/datasets/detection_dataset.py ADDED
@@ -0,0 +1,125 @@
1
+ import logging
2
+ from typing import List, Optional
3
+
4
+
5
+ import pandas as pd
6
+
7
+ from src.datasets.base_dataset import SimpleAudioFakeDataset
8
+ from src.datasets.deepfake_asvspoof_dataset import DeepFakeASVSpoofDataset
9
+ from src.datasets.fakeavceleb_dataset import FakeAVCelebDataset
10
+ from src.datasets.wavefake_dataset import WaveFakeDataset
11
+ from src.datasets.asvspoof_dataset import ASVSpoof2019DatasetOriginal
12
+
13
+
14
+ LOGGER = logging.getLogger()
15
+
16
+
17
+ class DetectionDataset(SimpleAudioFakeDataset):
18
+ def __init__(
19
+ self,
20
+ asvspoof_path=None,
21
+ wavefake_path=None,
22
+ fakeavceleb_path=None,
23
+ asvspoof2019_path=None,
24
+ subset: str = "val",
25
+ transform=None,
26
+ oversample: bool = True,
27
+ undersample: bool = False,
28
+ return_label: bool = True,
29
+ reduced_number: Optional[int] = None,
30
+ return_meta: bool = False,
31
+ ):
32
+ super().__init__(
33
+ subset=subset,
34
+ transform=transform,
35
+ return_label=return_label,
36
+ return_meta=return_meta,
37
+ )
38
+ datasets = self._init_datasets(
39
+ asvspoof_path=asvspoof_path,
40
+ wavefake_path=wavefake_path,
41
+ fakeavceleb_path=fakeavceleb_path,
42
+ asvspoof2019_path=asvspoof2019_path,
43
+ subset=subset,
44
+ )
45
+ self.samples = pd.concat([ds.samples for ds in datasets], ignore_index=True)
46
+
47
+ if oversample:
48
+ self.oversample_dataset()
49
+ elif undersample:
50
+ self.undersample_dataset()
51
+
52
+ if reduced_number:
53
+ LOGGER.info(f"Using reduced number of samples - {reduced_number}!")
54
+ self.samples = self.samples.sample(
55
+ min(len(self.samples), reduced_number),
56
+ random_state=42,
57
+ )
58
+
59
+ def _init_datasets(
60
+ self,
61
+ asvspoof_path: Optional[str],
62
+ wavefake_path: Optional[str],
63
+ fakeavceleb_path: Optional[str],
64
+ asvspoof2019_path: Optional[str],
65
+ subset: str,
66
+ ) -> List[SimpleAudioFakeDataset]:
67
+ datasets = []
68
+
69
+ if asvspoof_path is not None:
70
+ asvspoof_dataset = DeepFakeASVSpoofDataset(asvspoof_path, subset=subset)
71
+ datasets.append(asvspoof_dataset)
72
+
73
+ if wavefake_path is not None:
74
+ wavefake_dataset = WaveFakeDataset(wavefake_path, subset=subset)
75
+ datasets.append(wavefake_dataset)
76
+
77
+ if fakeavceleb_path is not None:
78
+ fakeavceleb_dataset = FakeAVCelebDataset(fakeavceleb_path, subset=subset)
79
+ datasets.append(fakeavceleb_dataset)
80
+
81
+ if asvspoof2019_path is not None:
82
+ la_dataset = ASVSpoof2019DatasetOriginal(
83
+ asvspoof2019_path, fold_subset=subset
84
+ )
85
+ datasets.append(la_dataset)
86
+
87
+ return datasets
88
+
89
+ def oversample_dataset(self):
90
+ samples = self.samples.groupby(by=["label"])
91
+ bona_length = len(samples.groups["bonafide"])
92
+ spoof_length = len(samples.groups["spoof"])
93
+
94
+ diff_length = spoof_length - bona_length
95
+
96
+ if diff_length < 0:
97
+ raise NotImplementedError
98
+
99
+ if diff_length > 0:
100
+ bonafide = samples.get_group("bonafide").sample(diff_length, replace=True)
101
+ self.samples = pd.concat([self.samples, bonafide], ignore_index=True)
102
+
103
+ def undersample_dataset(self):
104
+ samples = self.samples.groupby(by=["label"])
105
+ bona_length = len(samples.groups["bonafide"])
106
+ spoof_length = len(samples.groups["spoof"])
107
+
108
+ if spoof_length < bona_length:
109
+ raise NotImplementedError
110
+
111
+ if spoof_length > bona_length:
112
+ spoofs = samples.get_group("spoof").sample(bona_length, replace=True)
113
+ self.samples = pd.concat(
114
+ [samples.get_group("bonafide"), spoofs], ignore_index=True
115
+ )
116
+
117
+ def get_bonafide_only(self):
118
+ samples = self.samples.groupby(by=["label"])
119
+ self.samples = samples.get_group("bonafide")
120
+ return self.samples
121
+
122
+ def get_spoof_only(self):
123
+ samples = self.samples.groupby(by=["label"])
124
+ self.samples = samples.get_group("spoof")
125
+ return self.samples
src/datasets/fakeavceleb_dataset.py ADDED
@@ -0,0 +1,94 @@
1
+ from pathlib import Path
2
+
3
+ import pandas as pd
4
+
5
+ from src.datasets.base_dataset import SimpleAudioFakeDataset
6
+
7
+ FAKEAVCELEB_SPLIT = {
8
+ "train": ['faceswap-wav2lip', 'fsgan-wav2lip', 'wav2lip', 'rtvc'],
9
+ "test": ['faceswap-wav2lip', 'fsgan-wav2lip', 'wav2lip', 'rtvc'],
10
+ "val": ['faceswap-wav2lip', 'fsgan-wav2lip', 'wav2lip', 'rtvc'],
11
+ "partition_ratio": [0.7, 0.15],
12
+ "seed": 45
13
+ }
14
+
15
+
16
+ class FakeAVCelebDataset(SimpleAudioFakeDataset):
17
+
18
+ audio_folder = "FakeAVCeleb-audio"
19
+ audio_extension = ".mp3"
20
+ metadata_file = Path(audio_folder) / "meta_data.csv"
21
+ subsets = ("train", "dev", "eval")
22
+
23
+ def __init__(self, path, subset="train", transform=None):
24
+ super().__init__(subset, transform)
25
+ self.path = path
26
+
27
+ self.subset = subset
28
+ self.allowed_attacks = FAKEAVCELEB_SPLIT[subset]
29
+ self.partition_ratio = FAKEAVCELEB_SPLIT["partition_ratio"]
30
+ self.seed = FAKEAVCELEB_SPLIT["seed"]
31
+
32
+ self.metadata = self.get_metadata()
33
+
34
+ self.samples = pd.concat([self.get_fake_samples(), self.get_real_samples()], ignore_index=True)
35
+
36
+ def get_metadata(self):
37
+ md = pd.read_csv(Path(self.path) / self.metadata_file)
38
+ md["audio_type"] = md["type"].apply(lambda x: x.split("-")[-1])
39
+ return md
40
+
41
+ def get_fake_samples(self):
42
+ samples = {
43
+ "user_id": [],
44
+ "sample_name": [],
45
+ "attack_type": [],
46
+ "label": [],
47
+ "path": []
48
+ }
49
+
50
+ for attack_name in self.allowed_attacks:
51
+ fake_samples = self.metadata[
52
+ (self.metadata["method"] == attack_name) & (self.metadata["audio_type"] == "FakeAudio")
53
+ ]
54
+
55
+ samples_list = fake_samples.iterrows()
56
+ samples_list = self.split_samples(samples_list)
57
+
58
+ for _, sample in samples_list:
59
+ samples["user_id"].append(sample["source"])
60
+ samples["sample_name"].append(Path(sample["filename"]).stem)
61
+ samples["attack_type"].append(sample["method"])
62
+ samples["label"].append("spoof")
63
+ samples["path"].append(self.get_file_path(sample))
64
+
65
+ return pd.DataFrame(samples)
66
+
67
+ def get_real_samples(self):
68
+ samples = {
69
+ "user_id": [],
70
+ "sample_name": [],
71
+ "attack_type": [],
72
+ "label": [],
73
+ "path": []
74
+ }
75
+
76
+ samples_list = self.metadata[
77
+ (self.metadata["method"] == "real") & (self.metadata["audio_type"] == "RealAudio")
78
+ ]
79
+
80
+ samples_list = self.split_samples(samples_list)
81
+
82
+ for index, sample in samples_list.iterrows():
83
+ samples["user_id"].append(sample["source"])
84
+ samples["sample_name"].append(Path(sample["filename"]).stem)
85
+ samples["attack_type"].append("-")
86
+ samples["label"].append("bonafide")
87
+ samples["path"].append(self.get_file_path(sample))
88
+
89
+ return pd.DataFrame(samples)
90
+
91
+ def get_file_path(self, sample):
92
+ path = "/".join([self.audio_folder, *sample["path"].split("/")[1:]])
93
+ return Path(self.path) / path / Path(sample["filename"]).with_suffix(self.audio_extension)
94
+
src/datasets/folder_dataset.py ADDED
@@ -0,0 +1,75 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ import os
4
+ from pathlib import Path
5
+
6
+ from src.datasets.base_dataset import SimpleAudioFakeDataset
7
+
8
+
9
+ class FolderDataset(SimpleAudioFakeDataset):
10
+
11
+ def __init__(
12
+ self,
13
+ path,
14
+ subset="test",
15
+ transform=None,
16
+ ):
17
+ super().__init__(subset=subset, transform=transform)
18
+ self.path = path
19
+ self.samples = self.read_samples()
20
+
21
+
22
+ def read_samples(self):
23
+ path = Path(self.path)
24
+ print('ori path', path)
25
+ print('list', os.listdir(path))
26
+
27
+ samples = []
28
+ for filepath in os.listdir(path):
29
+ samples.append({
30
+ 'path': path / filepath,
31
+ 'label': '',
32
+ 'attack_type': '',
33
+ })
34
+
35
+ samples = pd.DataFrame(samples)
36
+ print('samples', samples)
37
+ return samples
38
+
39
+
40
+ class FileDataset(SimpleAudioFakeDataset):
41
+
42
+ def __init__(
43
+ self,
44
+ path,
45
+ subset="test",
46
+ transform=None,
47
+ ):
48
+ super().__init__(subset=subset, transform=transform)
49
+ self.path = path
50
+ self.samples = self.read_samples()
51
+
52
+
53
+ def read_samples(self):
54
+ path = Path(self.path)
55
+
56
+ samples = [{'path': path, 'label': '', 'attack_type':''}]
57
+
58
+ samples = pd.DataFrame(samples)
59
+ print('samples', samples)
60
+ return samples
61
+
62
+
63
+ if __name__ == "__main__":
64
+ dataset = InTheWildDataset(
65
+ path="../datasets/release_in_the_wild",
66
+ subset="val",
67
+ seed=242,
68
+ split_strategy="per_speaker"
69
+ )
70
+
71
+ print(len(dataset))
72
+ print(len(dataset.samples["user_id"].unique()))
73
+ print(dataset.samples["user_id"].unique())
74
+
75
+ print(dataset[0])
src/datasets/in_the_wild_dataset.py ADDED
@@ -0,0 +1,62 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ from pathlib import Path
4
+
5
+ from src.datasets.base_dataset import SimpleAudioFakeDataset
6
+
7
+
8
+ class InTheWildDataset(SimpleAudioFakeDataset):
9
+
10
+ def __init__(
11
+ self,
12
+ path,
13
+ subset="train",
14
+ transform=None,
15
+ seed=None,
16
+ partition_ratio=(0.7, 0.15),
17
+ split_strategy="random"
18
+ ):
19
+ super().__init__(subset=subset, transform=transform)
20
+ self.path = path
21
+ self.read_samples()
22
+ self.partition_ratio = partition_ratio
23
+ self.seed = seed
24
+
25
+
26
+ def read_samples(self):
27
+ path = Path(self.path)
28
+ meta_path = path / "meta.csv"
29
+
30
+ self.samples = pd.read_csv(meta_path)
31
+ self.samples["path"] = self.samples["file"].apply(lambda n: str(path / n))
32
+ self.samples["file"] = self.samples["file"].apply(lambda n: Path(n).stem)
33
+ self.samples["label"] = self.samples["label"].map({"bona-fide": "bonafide", "spoof": "spoof"})
34
+ self.samples["attack_type"] = self.samples["label"].map({"bonafide": "-", "spoof": "X"})
35
+ self.samples.rename(columns={'file': 'sample_name', 'speaker': 'user_id'}, inplace=True)
36
+
37
+
38
+ def split_samples_per_speaker(self, samples):
39
+ speaker_list = pd.Series(samples["user_id"].unique())
40
+ speaker_list = speaker_list.sort_values()
41
+ speaker_list = speaker_list.sample(frac=1, random_state=self.seed)
42
+ speaker_list = list(speaker_list)
43
+
44
+ p, s = self.partition_ratio
45
+ subsets = np.split(speaker_list, [int(p * len(speaker_list)), int((p + s) * len(speaker_list))])
46
+ speaker_subset = dict(zip(['train', 'test', 'val'], subsets))[self.subset]
47
+ return self.samples[self.samples["user_id"].isin(speaker_subset)]
48
+
49
+
50
+ if __name__ == "__main__":
51
+ dataset = InTheWildDataset(
52
+ path="../datasets/release_in_the_wild",
53
+ subset="val",
54
+ seed=242,
55
+ split_strategy="per_speaker"
56
+ )
57
+
58
+ print(len(dataset))
59
+ print(len(dataset.samples["user_id"].unique()))
60
+ print(dataset.samples["user_id"].unique())
61
+
62
+ print(dataset[0])
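
The per-speaker split above shuffles the unique speaker list with a fixed seed and cuts it at the partition_ratio boundaries; the standalone sketch below mirrors that logic with made-up speaker names:

    import numpy as np
    import pandas as pd

    speakers = pd.Series(["alice", "bob", "carol", "dave", "eve",
                          "frank", "grace", "heidi", "ivan", "judy"])
    speakers = list(speakers.sort_values().sample(frac=1, random_state=242))

    p, s = 0.7, 0.15  # partition_ratio: 70% of speakers for train, 15% for test, the rest for val
    train, test, val = np.split(speakers, [int(p * len(speakers)), int((p + s) * len(speakers))])
    print(len(train), len(test), len(val))  # 7 1 2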
src/datasets/wavefake_dataset.py ADDED
@@ -0,0 +1,85 @@
1
+ from pathlib import Path
2
+
3
+ import pandas as pd
4
+
5
+ from src.datasets.base_dataset import SimpleAudioFakeDataset
6
+
7
+ WAVEFAKE_SPLIT = {
8
+ "train": ['multi_band_melgan', 'melgan_large', 'parallel_wavegan', 'waveglow', 'full_band_melgan', 'melgan', 'hifiGAN'],
9
+ "test": ['multi_band_melgan', 'melgan_large', 'parallel_wavegan', 'waveglow', 'full_band_melgan', 'melgan', 'hifiGAN'],
10
+ "val": ['multi_band_melgan', 'melgan_large', 'parallel_wavegan', 'waveglow', 'full_band_melgan', 'melgan', 'hifiGAN'],
11
+ "partition_ratio": [0.7, 0.15],
12
+ "seed": 45
13
+ }
14
+
15
+
16
+ class WaveFakeDataset(SimpleAudioFakeDataset):
17
+
18
+ fake_data_path = "generated_audio"
19
+ jsut_real_data_path = "real_audio/jsut_ver1.1/basic5000/wav"
20
+ ljspeech_real_data_path = "real_audio/LJSpeech-1.1/wavs"
21
+
22
+ def __init__(self, path, subset="train", transform=None):
23
+ super().__init__(subset, transform)
24
+ self.path = Path(path)
25
+
26
+ self.fold_subset = subset
27
+ self.allowed_attacks = WAVEFAKE_SPLIT[subset]
28
+ self.partition_ratio = WAVEFAKE_SPLIT["partition_ratio"]
29
+ self.seed = WAVEFAKE_SPLIT["seed"]
30
+
31
+ self.samples = pd.concat([self.get_fake_samples(), self.get_real_samples()], ignore_index=True)
32
+
33
+ def get_fake_samples(self):
34
+ samples = {
35
+ "user_id": [],
36
+ "sample_name": [],
37
+ "attack_type": [],
38
+ "label": [],
39
+ "path": []
40
+ }
41
+
42
+ samples_list = list((self.path / self.fake_data_path).glob("*/*.wav"))
43
+ samples_list = self.filter_samples_by_attack(samples_list)
44
+ samples_list = self.split_samples(samples_list)
45
+
46
+ for sample in samples_list:
47
+ samples["user_id"].append(None)
48
+ samples["sample_name"].append("_".join(sample.stem.split("_")[:-1]))
49
+ samples["attack_type"].append(self.get_attack_from_path(sample))
50
+ samples["label"].append("spoof")
51
+ samples["path"].append(sample)
52
+
53
+ return pd.DataFrame(samples)
54
+
55
+ def filter_samples_by_attack(self, samples_list):
56
+ return [s for s in samples_list if self.get_attack_from_path(s) in self.allowed_attacks]
57
+
58
+ def get_real_samples(self):
59
+ samples = {
60
+ "user_id": [],
61
+ "sample_name": [],
62
+ "attack_type": [],
63
+ "label": [],
64
+ "path": []
65
+ }
66
+
67
+ samples_list = list((self.path / self.jsut_real_data_path).glob("*.wav"))
68
+ samples_list += list((self.path / self.ljspeech_real_data_path).glob("*.wav"))
69
+ samples_list = self.split_samples(samples_list)
70
+
71
+ for sample in samples_list:
72
+ samples["user_id"].append(None)
73
+ samples["sample_name"].append(sample.stem)
74
+ samples["attack_type"].append("-")
75
+ samples["label"].append("bonafide")
76
+ samples["path"].append(sample)
77
+
78
+ return pd.DataFrame(samples)
79
+
80
+ @staticmethod
81
+ def get_attack_from_path(path):
82
+ folder_name = path.parents[0].relative_to(path.parents[1])
83
+ return str(folder_name).split("_", maxsplit=1)[-1]
84
+
85
+
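
get_attack_from_path reads the attack label out of the generator folder name: the sample's parent directory, taken relative to its own parent, is something like "ljspeech_multi_band_melgan", and everything after the first underscore is treated as the attack. A quick sketch with an illustrative path:

    from pathlib import Path

    sample = Path("generated_audio/ljspeech_multi_band_melgan/LJ001-0001_gen.wav")  # illustrative layout
    folder_name = sample.parents[0].relative_to(sample.parents[1])
    print(str(folder_name).split("_", maxsplit=1)[-1])  # -> "multi_band_melgan"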
src/frontends.py ADDED
@@ -0,0 +1,72 @@
1
+ from typing import List, Union, Callable
2
+
3
+ import torch
4
+ import torchaudio
5
+
6
+ SAMPLING_RATE = 16_000
7
+ win_length = 400 # int((25 / 1_000) * SAMPLING_RATE)
8
+ hop_length = 160 # int((10 / 1_000) * SAMPLING_RATE)
9
+
10
+ device = "cuda" if torch.cuda.is_available() else "cpu"
11
+
12
+ MFCC_FN = torchaudio.transforms.MFCC(
13
+ sample_rate=SAMPLING_RATE,
14
+ n_mfcc=128,
15
+ melkwargs={
16
+ "n_fft": 512,
17
+ "win_length": win_length,
18
+ "hop_length": hop_length,
19
+ },
20
+ ).to(device)
21
+
22
+
23
+ LFCC_FN = torchaudio.transforms.LFCC(
24
+ sample_rate=SAMPLING_RATE,
25
+ n_lfcc=128,
26
+ speckwargs={
27
+ "n_fft": 512,
28
+ "win_length": win_length,
29
+ "hop_length": hop_length,
30
+ },
31
+ ).to(device)
32
+
33
+ MEL_SCALE_FN = torchaudio.transforms.MelScale(
34
+ n_mels=80,
35
+ n_stft=257,
36
+ sample_rate=SAMPLING_RATE,
37
+ ).to(device)
38
+
39
+ delta_fn = torchaudio.transforms.ComputeDeltas(
40
+ win_length=400,
41
+ mode="replicate",
42
+ )
43
+
44
+
45
+ def get_frontend(
46
+ frontends: List[str],
47
+ ) -> Union[torchaudio.transforms.MFCC, torchaudio.transforms.LFCC, Callable,]:
48
+ if "mfcc" in frontends:
49
+ return prepare_mfcc_double_delta
50
+ elif "lfcc" in frontends:
51
+ return prepare_lfcc_double_delta
52
+ raise ValueError(f"{frontends} frontend is not supported!")
53
+
54
+
55
+ def prepare_lfcc_double_delta(input):
56
+ if input.ndim < 4:
57
+ input = input.unsqueeze(1) # (bs, 1, n_samples)
58
+ x = LFCC_FN(input)
59
+ delta = delta_fn(x)
60
+ double_delta = delta_fn(delta)
61
+ x = torch.cat((x, delta, double_delta), 2) # -> [bs, 1, 128 * 3, 1500]
62
+ return x[:, :, :, :3000] # (bs, n, n_lfcc * 3, frames)
63
+
64
+
65
+ def prepare_mfcc_double_delta(input):
66
+ if input.ndim < 4:
67
+ input = input.unsqueeze(1) # (bs, 1, n_samples)
68
+ x = MFCC_FN(input)
69
+ delta = delta_fn(x)
70
+ double_delta = delta_fn(delta)
71
+ x = torch.cat((x, delta, double_delta), 2) # -> [bs, 1, 128 * 3, 1500]
72
+ return x[:, :, :, :3000] # (bs, n, n_lfcc * 3, frames)
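
A shape sanity check for the double-delta front-ends defined above (a sketch: the batch of random audio is illustrative, and because MFCC_FN/LFCC_FN are instantiated on the module-level device, the input must live on that same device):

    import torch
    from src.frontends import get_frontend, device

    frontend_fn = get_frontend(["mfcc"])               # resolves to prepare_mfcc_double_delta
    waveforms = torch.rand(4, 64_600, device=device)   # (batch, samples), roughly 4 s at 16 kHz
    features = frontend_fn(waveforms)
    print(features.shape)  # (4, 1, 384, n_frames): 128 MFCCs stacked with their deltas and double-deltas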
src/metrics.py ADDED
@@ -0,0 +1,15 @@
1
+ from typing import Tuple
2
+
3
+ import numpy as np
4
+ from scipy.interpolate import interp1d
5
+ from scipy.optimize import brentq
6
+ from sklearn.metrics import roc_curve
8
+
9
+
10
+ def calculate_eer(y, y_score) -> Tuple[float, float, np.ndarray, np.ndarray]:
11
+ fpr, tpr, thresholds = roc_curve(y, -y_score)
12
+
13
+ eer = brentq(lambda x: 1.0 - x - interp1d(fpr, tpr)(x), 0.0, 1.0)
14
+ thresh = interp1d(fpr, thresholds)(eer)
15
+ return thresh, eer, fpr, tpr
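
A small usage sketch for calculate_eer. Note that the function negates y_score before computing the ROC curve, so the convention assumed here is that lower scores indicate the positive class; the labels and scores below are made up:

    import numpy as np
    from src.metrics import calculate_eer

    y = np.array([1, 1, 1, 0, 0, 0])                      # 1 = target class, 0 = non-target
    y_score = np.array([0.1, 0.2, 0.6, 0.4, 0.8, 0.9])    # lower score -> more target-like
    thresh, eer, fpr, tpr = calculate_eer(y, y_score)
    print(f"EER: {eer:.3f} at threshold {float(thresh):.3f}")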
src/models/__init__.py ADDED
File without changes
src/models/assets/mel_filters.npz ADDED
Binary file (2.05 kB). View file
 
src/models/assets/tiny_enc.en.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:206cca585e8ee06b813f958f72c548aebd489f125ef8949ad437f9fcc86e8cda
3
+ size 32853468
src/models/lcnn.py ADDED
@@ -0,0 +1,247 @@
1
+ """
2
+ This code is a modified version of the LCNN baseline
3
+ from ASVSpoof2021 challenge - https://github.com/asvspoof-challenge/2021/blob/main/LA/Baseline-LFCC-LCNN/project/baseline_LA/model.py
4
+ """
5
+ import sys
6
+
7
+ import torch
8
+ import torch.nn as torch_nn
9
+
10
+ from src import frontends
11
+
12
+
13
+ NUM_COEFFICIENTS = 384
14
+
15
+
16
+ # For blstm
17
+ class BLSTMLayer(torch_nn.Module):
18
+ """ Wrapper over dilated conv1D
19
+ Input tensor: (batchsize=1, length, dim_in)
20
+ Output tensor: (batchsize=1, length, dim_out)
21
+ We want to keep the length the same
22
+ """
23
+ def __init__(self, input_dim, output_dim):
24
+ super().__init__()
25
+ if output_dim % 2 != 0:
26
+ print("Output_dim of BLSTMLayer is {:d}".format(output_dim))
27
+ print("BLSTMLayer expects a layer size of even number")
28
+ sys.exit(1)
29
+ # bi-directional LSTM
30
+ self.l_blstm = torch_nn.LSTM(
31
+ input_dim,
32
+ output_dim // 2,
33
+ bidirectional=True
34
+ )
35
+ def forward(self, x):
36
+ # permute to (length, batchsize=1, dim)
37
+ blstm_data, _ = self.l_blstm(x.permute(1, 0, 2))
38
+ # permute it back to (batchsize=1, length, dim)
39
+ return blstm_data.permute(1, 0, 2)
40
+
41
+
42
+ class MaxFeatureMap2D(torch_nn.Module):
43
+ """ Max feature map (along 2D)
44
+
45
+ MaxFeatureMap2D(max_dim=1)
46
+
47
+ l_conv2d = MaxFeatureMap2D(1)
48
+ data_in = torch.rand([1, 4, 5, 5])
49
+ data_out = l_conv2d(data_in)
50
+
51
+
52
+ Input:
53
+ ------
54
+ data_in: tensor of shape (batch, channel, ...)
55
+
56
+ Output:
57
+ -------
58
+ data_out: tensor of shape (batch, channel//2, ...)
59
+
60
+ Note
61
+ ----
62
+ By default, Max-feature-map is on channel dimension,
63
+ and maxout is used on (channel ...)
64
+ """
65
+ def __init__(self, max_dim = 1):
66
+ super().__init__()
67
+ self.max_dim = max_dim
68
+
69
+ def forward(self, inputs):
70
+ # suppose inputs (batchsize, channel, length, dim)
71
+
72
+ shape = list(inputs.size())
73
+
74
+ if self.max_dim >= len(shape):
75
+ print("MaxFeatureMap: maximize on %d dim" % (self.max_dim))
76
+ print("But input has %d dimensions" % (len(shape)))
77
+ sys.exit(1)
78
+ if shape[self.max_dim] // 2 * 2 != shape[self.max_dim]:
79
+ print("MaxFeatureMap: maximize on %d dim" % (self.max_dim))
80
+ print("But this dimension has an odd number of data")
81
+ sys.exit(1)
82
+ shape[self.max_dim] = shape[self.max_dim]//2
83
+ shape.insert(self.max_dim, 2)
84
+
85
+ # view to (batchsize, 2, channel//2, ...)
86
+ # maximize on the 2nd dim
87
+ m, i = inputs.view(*shape).max(self.max_dim)
88
+ return m
89
+
90
+
91
+ ##############
92
+ ## FOR MODEL
93
+ ##############
94
+
95
+ class LCNN(torch_nn.Module):
96
+ """ Model definition
97
+ """
98
+ def __init__(self, **kwargs):
99
+ super().__init__()
100
+ input_channels = kwargs.get("input_channels", 1)
101
+ num_coefficients = kwargs.get("num_coefficients", NUM_COEFFICIENTS)
102
+
103
+ # Working sampling rate
104
+ self.num_coefficients = num_coefficients
105
+
106
+ # dimension of embedding vectors
107
+ # here, the embedding is just the activation before sigmoid()
108
+ self.v_emd_dim = 1
109
+
110
+ # the model can handle multiple front-end configurations,
111
+ # but by default only a single front-end is used
112
+
113
+ self.m_transform = torch_nn.Sequential(
114
+ torch_nn.Conv2d(input_channels, 64, (5, 5), 1, padding=(2, 2)),
115
+ MaxFeatureMap2D(),
116
+ torch.nn.MaxPool2d((2, 2), (2, 2)),
117
+
118
+ torch_nn.Conv2d(32, 64, (1, 1), 1, padding=(0, 0)),
119
+ MaxFeatureMap2D(),
120
+ torch_nn.BatchNorm2d(32, affine=False),
121
+ torch_nn.Conv2d(32, 96, (3, 3), 1, padding=(1, 1)),
122
+ MaxFeatureMap2D(),
123
+
124
+ torch.nn.MaxPool2d((2, 2), (2, 2)),
125
+ torch_nn.BatchNorm2d(48, affine=False),
126
+
127
+ torch_nn.Conv2d(48, 96, (1, 1), 1, padding=(0, 0)),
128
+ MaxFeatureMap2D(),
129
+ torch_nn.BatchNorm2d(48, affine=False),
130
+ torch_nn.Conv2d(48, 128, (3, 3), 1, padding=(1, 1)),
131
+ MaxFeatureMap2D(),
132
+
133
+ torch.nn.MaxPool2d((2, 2), (2, 2)),
134
+
135
+ torch_nn.Conv2d(64, 128, (1, 1), 1, padding=(0, 0)),
136
+ MaxFeatureMap2D(),
137
+ torch_nn.BatchNorm2d(64, affine=False),
138
+ torch_nn.Conv2d(64, 64, (3, 3), 1, padding=(1, 1)),
139
+ MaxFeatureMap2D(),
140
+ torch_nn.BatchNorm2d(32, affine=False),
141
+
142
+ torch_nn.Conv2d(32, 64, (1, 1), 1, padding=(0, 0)),
143
+ MaxFeatureMap2D(),
144
+ torch_nn.BatchNorm2d(32, affine=False),
145
+ torch_nn.Conv2d(32, 64, (3, 3), 1, padding=(1, 1)),
146
+ MaxFeatureMap2D(),
147
+ torch_nn.MaxPool2d((2, 2), (2, 2)),
148
+
149
+ torch_nn.Dropout(0.7)
150
+ )
151
+
152
+ self.m_before_pooling = torch_nn.Sequential(
153
+ BLSTMLayer((self.num_coefficients//16) * 32, (self.num_coefficients//16) * 32),
154
+ BLSTMLayer((self.num_coefficients//16) * 32, (self.num_coefficients//16) * 32)
155
+ )
156
+
157
+ self.m_output_act = torch_nn.Linear((self.num_coefficients // 16) * 32, self.v_emd_dim)
158
+
159
+ def _compute_embedding(self, x):
160
+ """ definition of forward method
161
+ Assume x (batchsize, length, dim)
162
+ Output x (batchsize * number_filter, output_dim)
163
+ """
164
+ # resample if necessary
165
+ # x = self.m_resampler(x.squeeze(-1)).unsqueeze(-1)
166
+
167
+ # number of sub models
168
+ batch_size = x.shape[0]
169
+
170
+ # buffer to store output scores from sub-models
171
+ output_emb = torch.zeros(
172
+ [batch_size, self.v_emd_dim],
173
+ device=x.device,
174
+ dtype=x.dtype
175
+ )
176
+
177
+ # compute scores for each sub-models
178
+ idx = 0
179
+
180
+ # compute scores
181
+ # 1. unsqueeze to (batch, 1, frame_length, fft_bin)
182
+ # 2. compute hidden features
183
+ x = x.permute(0,1,3,2)
184
+ hidden_features = self.m_transform(x)
185
+
186
+ # 3. (batch, channel, frame//N, feat_dim//N) ->
187
+ # (batch, frame//N, channel * feat_dim//N)
188
+ # where N is caused by conv with stride
189
+ hidden_features = hidden_features.permute(0, 2, 1, 3).contiguous()
190
+ frame_num = hidden_features.shape[1]
191
+
192
+ hidden_features = hidden_features.view(batch_size, frame_num, -1)
193
+ # 4. pooling
194
+ # pass through the BLSTM layers, add the skip connection, then average over frames
195
+ hidden_features_lstm = self.m_before_pooling(hidden_features)
196
+
197
+ # 5. pass through the output layer
198
+ tmp_emb = self.m_output_act((hidden_features_lstm + hidden_features).mean(1))
199
+ output_emb[idx * batch_size : (idx+1) * batch_size] = tmp_emb
200
+
201
+ return output_emb
202
+
203
+ def _compute_score(self, feature_vec):
204
+ # feature_vec is [batch * submodel, 1]
205
+ return torch.sigmoid(feature_vec).squeeze(1)
206
+
207
+ def forward(self, x):
208
+ feature_vec = self._compute_embedding(x)
209
+ return feature_vec
210
+
211
+
212
+
213
+ class FrontendLCNN(LCNN):
214
+ """ Model definition
215
+ """
216
+ def __init__(self, device: str = "cuda", **kwargs):
217
+ super().__init__(**kwargs)
218
+
219
+ self.device = device
220
+
221
+ frontend_name = kwargs.get("frontend_algorithm", [])
222
+ self.frontend = frontends.get_frontend(frontend_name)
223
+ print(f"Using {frontend_name} frontend")
224
+
225
+ def _compute_frontend(self, x):
226
+ frontend = self.frontend(x)
227
+ if frontend.ndim < 4:
228
+ return frontend.unsqueeze(1) # (bs, 1, n_lfcc, frames)
229
+ return frontend # (bs, n, n_lfcc, frames)
230
+
231
+ def forward(self, x):
232
+ x = self._compute_frontend(x)
233
+ feature_vec = self._compute_embedding(x)
234
+
235
+ return feature_vec
236
+
237
+
238
+ if __name__ == "__main__":
239
+
240
+ device = "cuda"
241
+ print("Definition of model")
242
+ model = FrontendLCNN(input_channels=1, device=device, frontend_algorithm=["lfcc"])  # the "mel_spec" front-end is not provided by src.frontends; LFCC yields 1 channel with the default 384 coefficients
243
+ model = model.to(device)
244
+ batch_size = 12
245
+ mock_input = torch.rand((batch_size, 64_600,), device=device)
246
+ output = model(mock_input)
247
+ print(output.shape)
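
The MaxFeatureMap2D docstring above describes the max-feature-map operation informally; the tiny numeric sketch below shows what it does to a 4-channel tensor (channels are halved by taking an element-wise max across the two channel halves):

    import torch
    from src.models.lcnn import MaxFeatureMap2D

    mfm = MaxFeatureMap2D(max_dim=1)
    x = torch.arange(16.0).view(1, 4, 2, 2)
    y = mfm(x)
    print(y.shape)                                             # torch.Size([1, 2, 2, 2])
    print(torch.equal(y, torch.maximum(x[:, :2], x[:, 2:])))   # True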
src/models/meso_net.py ADDED
@@ -0,0 +1,146 @@
1
+ """
2
+ This code is a modified version of the MesoNet DeepFake detection solution
3
+ from FakeAVCeleb repository - https://github.com/DASH-Lab/FakeAVCeleb/blob/main/models/MesoNet.py.
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from src import frontends
9
+
10
+
11
+ class MesoInception4(nn.Module):
12
+ """
13
+ PyTorch implementation of MesoInception4
14
+ Author: Honggu Liu
15
+ Date: July 7, 2019
16
+ """
17
+ def __init__(self, num_classes=1, **kwargs):
18
+ super().__init__()
19
+
20
+ self.fc1_dim = kwargs.get("fc1_dim", 1024)
21
+ input_channels = kwargs.get("input_channels", 3)
22
+ self.num_classes = num_classes
23
+
24
+ #InceptionLayer1
25
+ self.Incption1_conv1 = nn.Conv2d(input_channels, 1, 1, padding=0, bias=False)
26
+ self.Incption1_conv2_1 = nn.Conv2d(input_channels, 4, 1, padding=0, bias=False)
27
+ self.Incption1_conv2_2 = nn.Conv2d(4, 4, 3, padding=1, bias=False)
28
+ self.Incption1_conv3_1 = nn.Conv2d(input_channels, 4, 1, padding=0, bias=False)
29
+ self.Incption1_conv3_2 = nn.Conv2d(4, 4, 3, padding=2, dilation=2, bias=False)
30
+ self.Incption1_conv4_1 = nn.Conv2d(input_channels, 2, 1, padding=0, bias=False)
31
+ self.Incption1_conv4_2 = nn.Conv2d(2, 2, 3, padding=3, dilation=3, bias=False)
32
+ self.Incption1_bn = nn.BatchNorm2d(11)
33
+
34
+
35
+ #InceptionLayer2
36
+ self.Incption2_conv1 = nn.Conv2d(11, 2, 1, padding=0, bias=False)
37
+ self.Incption2_conv2_1 = nn.Conv2d(11, 4, 1, padding=0, bias=False)
38
+ self.Incption2_conv2_2 = nn.Conv2d(4, 4, 3, padding=1, bias=False)
39
+ self.Incption2_conv3_1 = nn.Conv2d(11, 4, 1, padding=0, bias=False)
40
+ self.Incption2_conv3_2 = nn.Conv2d(4, 4, 3, padding=2, dilation=2, bias=False)
41
+ self.Incption2_conv4_1 = nn.Conv2d(11, 2, 1, padding=0, bias=False)
42
+ self.Incption2_conv4_2 = nn.Conv2d(2, 2, 3, padding=3, dilation=3, bias=False)
43
+ self.Incption2_bn = nn.BatchNorm2d(12)
44
+
45
+ #Normal Layer
46
+ self.conv1 = nn.Conv2d(12, 16, 5, padding=2, bias=False)
47
+ self.relu = nn.ReLU(inplace=True)
48
+ self.leakyrelu = nn.LeakyReLU(0.1)
49
+ self.bn1 = nn.BatchNorm2d(16)
50
+ self.maxpooling1 = nn.MaxPool2d(kernel_size=(2, 2))
51
+
52
+ self.conv2 = nn.Conv2d(16, 16, 5, padding=2, bias=False)
53
+ self.maxpooling2 = nn.MaxPool2d(kernel_size=(4, 4))
54
+
55
+ self.dropout = nn.Dropout2d(0.5)
56
+ self.fc1 = nn.Linear(self.fc1_dim, 16)
57
+ self.fc2 = nn.Linear(16, num_classes)
58
+
59
+
60
+ #InceptionLayer
61
+ def InceptionLayer1(self, input):
62
+ x1 = self.Incption1_conv1(input)
63
+ x2 = self.Incption1_conv2_1(input)
64
+ x2 = self.Incption1_conv2_2(x2)
65
+ x3 = self.Incption1_conv3_1(input)
66
+ x3 = self.Incption1_conv3_2(x3)
67
+ x4 = self.Incption1_conv4_1(input)
68
+ x4 = self.Incption1_conv4_2(x4)
69
+ y = torch.cat((x1, x2, x3, x4), 1)
70
+ y = self.Incption1_bn(y)
71
+ y = self.maxpooling1(y)
72
+
73
+ return y
74
+
75
+ def InceptionLayer2(self, input):
76
+ x1 = self.Incption2_conv1(input)
77
+ x2 = self.Incption2_conv2_1(input)
78
+ x2 = self.Incption2_conv2_2(x2)
79
+ x3 = self.Incption2_conv3_1(input)
80
+ x3 = self.Incption2_conv3_2(x3)
81
+ x4 = self.Incption2_conv4_1(input)
82
+ x4 = self.Incption2_conv4_2(x4)
83
+ y = torch.cat((x1, x2, x3, x4), 1)
84
+ y = self.Incption2_bn(y)
85
+ y = self.maxpooling1(y)
86
+
87
+ return y
88
+
89
+ def forward(self, input):
90
+ x = self._compute_embedding(input)
91
+ return x
92
+
93
+ def _compute_embedding(self, input):
94
+ x = self.InceptionLayer1(input) #(Batch, 11, 128, 128)
95
+ x = self.InceptionLayer2(x) #(Batch, 12, 64, 64)
96
+
97
+ x = self.conv1(x) #(Batch, 16, 64 ,64)
98
+ x = self.relu(x)
99
+ x = self.bn1(x)
100
+ x = self.maxpooling1(x) #(Batch, 16, 32, 32)
101
+
102
+ x = self.conv2(x) #(Batch, 16, 32, 32)
103
+ x = self.relu(x)
104
+ x = self.bn1(x)
105
+ x = self.maxpooling2(x) #(Batch, 16, 8, 8)
106
+
107
+ x = x.view(x.size(0), -1) #(Batch, 16*8*8)
108
+ x = self.dropout(x)
109
+
110
+ x = nn.AdaptiveAvgPool1d(self.fc1_dim)(x)
111
+ x = self.fc1(x) #(Batch, 16)
112
+ x = self.leakyrelu(x)
113
+ x = self.dropout(x)
114
+ x = self.fc2(x)
115
+ return x
116
+
117
+
118
+ class FrontendMesoInception4(MesoInception4):
119
+
120
+ def __init__(self, **kwargs):
121
+ super().__init__(**kwargs)
122
+
123
+ self.device = kwargs['device']
124
+
125
+ frontend_name = kwargs.get("frontend_algorithm", [])
126
+ self.frontend = frontends.get_frontend(frontend_name)
127
+ print(f"Using {frontend_name} frontend")
128
+
129
+ def forward(self, x):
130
+ x = self.frontend(x)
131
+ x = self._compute_embedding(x)
132
+ return x
133
+
134
+
135
+ if __name__ == "__main__":
136
+ model = FrontendMesoInception4(
137
+ input_channels=2,
138
+ fc1_dim=1024,
139
+ device='cuda',
140
+ frontend_algorithm="lfcc"
141
+ )
142
+
143
+ def count_parameters(model) -> int:
144
+ pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
145
+ return pytorch_total_params
146
+ print(count_parameters(model))
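
A forward-pass sketch to complement the parameter count above (not part of the training pipeline: the waveform batch is random, and everything is placed on the device the front-ends in src.frontends were initialised on):

    import torch
    from src.frontends import device as frontend_device
    from src.models.meso_net import FrontendMesoInception4

    model = FrontendMesoInception4(
        input_channels=1,            # a single LFCC front-end yields one input channel
        fc1_dim=1024,
        device=frontend_device,
        frontend_algorithm=["lfcc"],
    ).to(frontend_device)

    waveforms = torch.rand(2, 64_600, device=frontend_device)
    print(model(waveforms).shape)    # (2, 1): one raw score per utterance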
src/models/models.py ADDED
@@ -0,0 +1,73 @@
1
+ from typing import Dict
2
+
3
+ from src.models import (
4
+ lcnn,
5
+ specrnet,
6
+ whisper_specrnet,
7
+ rawnet3,
8
+ whisper_lcnn,
9
+ meso_net,
10
+ whisper_meso_net
11
+ )
12
+
13
+
14
+ def get_model(model_name: str, config: Dict, device: str):
15
+ if model_name == "rawnet3":
16
+ return rawnet3.prepare_model()
17
+ elif model_name == "lcnn":
18
+ return lcnn.FrontendLCNN(device=device, **config)
19
+ elif model_name == "specrnet":
20
+ return specrnet.FrontendSpecRNet(
21
+ device=device,
22
+ **config,
23
+ )
24
+ elif model_name == "mesonet":
25
+ return meso_net.FrontendMesoInception4(
26
+ input_channels=config.get("input_channels", 1),
27
+ fc1_dim=config.get("fc1_dim", 1024),
28
+ frontend_algorithm=config.get("frontend_algorithm", "lfcc"),
29
+ device=device,
30
+ )
31
+ elif model_name == "whisper_lcnn":
32
+ return whisper_lcnn.WhisperLCNN(
33
+ input_channels=config.get("input_channels", 1),
34
+ freeze_encoder=config.get("freeze_encoder", False),
35
+ device=device,
36
+ )
37
+ elif model_name == "whisper_specrnet":
38
+ return whisper_specrnet.WhisperSpecRNet(
39
+ input_channels=config.get("input_channels", 1),
40
+ freeze_encoder=config.get("freeze_encoder", False),
41
+ device=device,
42
+ )
43
+ elif model_name == "whisper_mesonet":
44
+ return whisper_meso_net.WhisperMesoNet(
45
+ input_channels=config.get("input_channels", 1),
46
+ freeze_encoder=config.get("freeze_encoder", True),
47
+ fc1_dim=config.get("fc1_dim", 1024),
48
+ device=device,
49
+ )
50
+ elif model_name == "whisper_frontend_lcnn":
51
+ return whisper_lcnn.WhisperMultiFrontLCNN(
52
+ input_channels=config.get("input_channels", 2),
53
+ freeze_encoder=config.get("freeze_encoder", False),
54
+ frontend_algorithm=config.get("frontend_algorithm", "lfcc"),
55
+ device=device,
56
+ )
57
+ elif model_name == "whisper_frontend_specrnet":
58
+ return whisper_specrnet.WhisperMultiFrontSpecRNet(
59
+ input_channels=config.get("input_channels", 2),
60
+ freeze_encoder=config.get("freeze_encoder", False),
61
+ frontend_algorithm=config.get("frontend_algorithm", "lfcc"),
62
+ device=device,
63
+ )
64
+ elif model_name == "whisper_frontend_mesonet":
65
+ return whisper_meso_net.WhisperMultiFrontMesoNet(
66
+ input_channels=config.get("input_channels", 2),
67
+ fc1_dim=config.get("fc1_dim", 1024),
68
+ freeze_encoder=config.get("freeze_encoder", True),
69
+ frontend_algorithm=config.get("frontend_algorithm", "lfcc"),
70
+ device=device,
71
+ )
72
+ else:
73
+ raise ValueError(f"Model '{model_name}' not supported")
src/models/rawnet3.py ADDED
@@ -0,0 +1,323 @@
1
+ """
2
+ This file contains an implementation of the RawNet3 architecture.
3
+ The original implementation can be found here: https://github.com/Jungjee/RawNet/tree/master/python/RawNet3
4
+ """
5
+ import math
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from asteroid_filterbanks import Encoder, ParamSincFB # pip install asteroid_filterbanks
11
+
12
+
13
+ class RawNet3(nn.Module):
14
+ def __init__(self, block, model_scale, context, summed, C=1024, **kwargs):
15
+ super().__init__()
16
+
17
+ nOut = kwargs["nOut"]
18
+
19
+ self.context = context
20
+ self.encoder_type = kwargs["encoder_type"]
21
+ self.log_sinc = kwargs["log_sinc"]
22
+ self.norm_sinc = kwargs["norm_sinc"]
23
+ self.out_bn = kwargs["out_bn"]
24
+ self.summed = summed
25
+
26
+ self.preprocess = nn.Sequential(
27
+ PreEmphasis(), nn.InstanceNorm1d(1, eps=1e-4, affine=True)
28
+ )
29
+ self.conv1 = Encoder(
30
+ ParamSincFB(
31
+ C // 4,
32
+ 251,
33
+ stride=kwargs["sinc_stride"],
34
+ )
35
+ )
36
+ self.relu = nn.ReLU()
37
+ self.bn1 = nn.BatchNorm1d(C // 4)
38
+
39
+ self.layer1 = block(
40
+ C // 4, C, kernel_size=3, dilation=2, scale=model_scale, pool=5
41
+ )
42
+ self.layer2 = block(
43
+ C, C, kernel_size=3, dilation=3, scale=model_scale, pool=3
44
+ )
45
+ self.layer3 = block(C, C, kernel_size=3, dilation=4, scale=model_scale)
46
+ self.layer4 = nn.Conv1d(3 * C, 1536, kernel_size=1)
47
+
48
+ if self.context:
49
+ attn_input = 1536 * 3
50
+ else:
51
+ attn_input = 1536
52
+ print("self.encoder_type", self.encoder_type)
53
+ if self.encoder_type == "ECA":
54
+ attn_output = 1536
55
+ elif self.encoder_type == "ASP":
56
+ attn_output = 1
57
+ else:
58
+ raise ValueError("Undefined encoder")
59
+
60
+ self.attention = nn.Sequential(
61
+ nn.Conv1d(attn_input, 128, kernel_size=1),
62
+ nn.ReLU(),
63
+ nn.BatchNorm1d(128),
64
+ nn.Conv1d(128, attn_output, kernel_size=1),
65
+ nn.Softmax(dim=2),
66
+ )
67
+
68
+ self.bn5 = nn.BatchNorm1d(3072)
69
+
70
+ self.fc6 = nn.Linear(3072, nOut)
71
+ self.bn6 = nn.BatchNorm1d(nOut)
72
+
73
+ self.mp3 = nn.MaxPool1d(3)
74
+
75
+ def forward(self, x):
76
+ """
77
+ :param x: input mini-batch (bs, samp)
78
+ """
79
+
80
+ with torch.cuda.amp.autocast(enabled=False):
81
+ x = self.preprocess(x)
82
+ x = torch.abs(self.conv1(x))
83
+ if self.log_sinc:
84
+ x = torch.log(x + 1e-6)
85
+ if self.norm_sinc == "mean":
86
+ x = x - torch.mean(x, dim=-1, keepdim=True)
87
+ elif self.norm_sinc == "mean_std":
88
+ m = torch.mean(x, dim=-1, keepdim=True)
89
+ s = torch.std(x, dim=-1, keepdim=True)
90
+ s[s < 0.001] = 0.001
91
+ x = (x - m) / s
92
+
93
+ if self.summed:
94
+ x1 = self.layer1(x)
95
+ x2 = self.layer2(x1)
96
+ x3 = self.layer3(self.mp3(x1) + x2)
97
+ else:
98
+ x1 = self.layer1(x)
99
+ x2 = self.layer2(x1)
100
+ x3 = self.layer3(x2)
101
+
102
+ x = self.layer4(torch.cat((self.mp3(x1), x2, x3), dim=1))
103
+ x = self.relu(x)
104
+
105
+ t = x.size()[-1]
106
+
107
+ if self.context:
108
+ global_x = torch.cat(
109
+ (
110
+ x,
111
+ torch.mean(x, dim=2, keepdim=True).repeat(1, 1, t),
112
+ torch.sqrt(
113
+ torch.var(x, dim=2, keepdim=True).clamp(
114
+ min=1e-4, max=1e4
115
+ )
116
+ ).repeat(1, 1, t),
117
+ ),
118
+ dim=1,
119
+ )
120
+ else:
121
+ global_x = x
122
+
123
+ w = self.attention(global_x)
124
+
125
+ mu = torch.sum(x * w, dim=2)
126
+ sg = torch.sqrt(
127
+ (torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-4, max=1e4)
128
+ )
129
+
130
+ x = torch.cat((mu, sg), 1)
131
+
132
+ x = self.bn5(x)
133
+
134
+ x = self.fc6(x)
135
+
136
+ if self.out_bn:
137
+ x = self.bn6(x)
138
+
139
+ return x
140
+
141
+
142
+ class PreEmphasis(torch.nn.Module):
143
+ def __init__(self, coef: float = 0.97) -> None:
144
+ super().__init__()
145
+ self.coef = coef
146
+ # make kernel
147
+ # In pytorch, the convolution operation uses cross-correlation. So, filter is flipped.
148
+ self.register_buffer(
149
+ "flipped_filter",
150
+ torch.FloatTensor([-self.coef, 1.0]).unsqueeze(0).unsqueeze(0),
151
+ )
152
+
153
+ def forward(self, input: torch.tensor) -> torch.tensor:
154
+ assert (
155
+ len(input.size()) == 2
156
+ ), "The number of dimensions of input tensor must be 2!"
157
+ # reflect padding to match lengths of in/out
158
+ input = input.unsqueeze(1)
159
+ input = F.pad(input, (1, 0), "reflect")
160
+ return F.conv1d(input, self.flipped_filter)
161
+
162
+
163
+ class AFMS(nn.Module):
164
+ """
165
+ Alpha-Feature map scaling, added to the output of each residual block[1,2].
166
+
167
+ Reference:
168
+ [1] RawNet2 : https://www.isca-speech.org/archive/Interspeech_2020/pdfs/1011.pdf
169
+ [2] AMFS : https://www.koreascience.or.kr/article/JAKO202029757857763.page
170
+ """
171
+
172
+ def __init__(self, nb_dim: int) -> None:
173
+ super().__init__()
174
+ self.alpha = nn.Parameter(torch.ones((nb_dim, 1)))
175
+ self.fc = nn.Linear(nb_dim, nb_dim)
176
+ self.sig = nn.Sigmoid()
177
+
178
+ def forward(self, x):
179
+ y = F.adaptive_avg_pool1d(x, 1).view(x.size(0), -1)
180
+ y = self.sig(self.fc(y)).view(x.size(0), x.size(1), -1)
181
+
182
+ x = x + self.alpha
183
+ x = x * y
184
+ return x
185
+
186
+
187
+ class Bottle2neck(nn.Module):
188
+ def __init__(
189
+ self,
190
+ inplanes,
191
+ planes,
192
+ kernel_size=None,
193
+ dilation=None,
194
+ scale=4,
195
+ pool=False,
196
+ ):
197
+
198
+ super().__init__()
199
+
200
+ width = int(math.floor(planes / scale))
201
+
202
+ self.conv1 = nn.Conv1d(inplanes, width * scale, kernel_size=1)
203
+ self.bn1 = nn.BatchNorm1d(width * scale)
204
+
205
+ self.nums = scale - 1
206
+
207
+ convs = []
208
+ bns = []
209
+
210
+ num_pad = math.floor(kernel_size / 2) * dilation
211
+
212
+ for i in range(self.nums):
213
+ convs.append(
214
+ nn.Conv1d(
215
+ width,
216
+ width,
217
+ kernel_size=kernel_size,
218
+ dilation=dilation,
219
+ padding=num_pad,
220
+ )
221
+ )
222
+ bns.append(nn.BatchNorm1d(width))
223
+
224
+ self.convs = nn.ModuleList(convs)
225
+ self.bns = nn.ModuleList(bns)
226
+
227
+ self.conv3 = nn.Conv1d(width * scale, planes, kernel_size=1)
228
+ self.bn3 = nn.BatchNorm1d(planes)
229
+
230
+ self.relu = nn.ReLU()
231
+
232
+ self.width = width
233
+
234
+ self.mp = nn.MaxPool1d(pool) if pool else False
235
+ self.afms = AFMS(planes)
236
+
237
+ if inplanes != planes: # if change in number of filters
238
+ self.residual = nn.Sequential(
239
+ nn.Conv1d(inplanes, planes, kernel_size=1, stride=1, bias=False)
240
+ )
241
+ else:
242
+ self.residual = nn.Identity()
243
+
244
+ def forward(self, x):
245
+ residual = self.residual(x)
246
+
247
+ out = self.conv1(x)
248
+ out = self.relu(out)
249
+ out = self.bn1(out)
250
+
251
+ spx = torch.split(out, self.width, 1)
252
+ for i in range(self.nums):
253
+ if i == 0:
254
+ sp = spx[i]
255
+ else:
256
+ sp = sp + spx[i]
257
+ sp = self.convs[i](sp)
258
+ sp = self.relu(sp)
259
+ sp = self.bns[i](sp)
260
+ if i == 0:
261
+ out = sp
262
+ else:
263
+ out = torch.cat((out, sp), 1)
264
+
265
+ out = torch.cat((out, spx[self.nums]), 1)
266
+
267
+ out = self.conv3(out)
268
+ out = self.relu(out)
269
+ out = self.bn3(out)
270
+
271
+ out += residual
272
+ if self.mp:
273
+ out = self.mp(out)
274
+ out = self.afms(out)
275
+
276
+ return out
277
+
278
+
279
+ def prepare_model():
280
+ model = RawNet3(
281
+ Bottle2neck,
282
+ model_scale=8,
283
+ context=True,
284
+ summed=True,
285
+ encoder_type="ECA",
286
+ nOut=1, # number of slices
287
+ out_bn=False,
288
+ sinc_stride=10,
289
+ log_sinc=True,
290
+ norm_sinc="mean",
291
+ grad_mult=1,
292
+ )
293
+ return model
294
+
295
+
296
+ if __name__ == "__main__":
297
+ model = RawNet3(
298
+ Bottle2neck,
299
+ model_scale=8,
300
+ context=True,
301
+ summed=True,
302
+ encoder_type="ECA",
303
+ nOut=1, # number of slices
304
+ out_bn=False,
305
+ sinc_stride=10,
306
+ log_sinc=True,
307
+ norm_sinc="mean",
308
+ grad_mult=1,
309
+ )
310
+ gpu = False
311
+
312
+ model.eval()
313
+ print("RawNet3 initialised & weights loaded!")
314
+
315
+ if torch.cuda.is_available():
316
+ print("Cuda available, conducting inference on GPU")
317
+ model = model.to("cuda")
318
+ gpu = True
319
+
320
+ audios = torch.rand(32, 64_600)
321
+ audios = audios.to("cuda") if gpu else audios
322
+ out = model(audios)
323
+ print(out.shape)
src/models/specrnet.py ADDED
@@ -0,0 +1,226 @@
1
+ """
2
+ This file contains an implementation of the SpecRNet architecture.
3
+ We base our codebase on the implementation of RawNet2 by Hemlata Tak (tak@eurecom.fr).
4
+ It is available here: https://github.com/asvspoof-challenge/2021/blob/main/LA/Baseline-RawNet2/model.py
5
+ """
6
+ from typing import Dict
7
+
8
+ import torch.nn as nn
9
+
10
+ from src import frontends
11
+
12
+
13
+ def get_config(input_channels: int) -> Dict:
14
+ return {
15
+ "filts": [input_channels, [input_channels, 20], [20, 64], [64, 64]],
16
+ "nb_fc_node": 64,
17
+ "gru_node": 64,
18
+ "nb_gru_layer": 2,
19
+ "nb_classes": 1,
20
+ }
21
+
22
+
23
+ class Residual_block2D(nn.Module):
24
+ def __init__(self, nb_filts, first=False):
25
+ super().__init__()
26
+ self.first = first
27
+
28
+ if not self.first:
29
+ self.bn1 = nn.BatchNorm2d(num_features=nb_filts[0])
30
+
31
+ self.lrelu = nn.LeakyReLU(negative_slope=0.3)
32
+
33
+ self.conv1 = nn.Conv2d(
34
+ in_channels=nb_filts[0],
35
+ out_channels=nb_filts[1],
36
+ kernel_size=3,
37
+ padding=1,
38
+ stride=1,
39
+ )
40
+
41
+ self.bn2 = nn.BatchNorm2d(num_features=nb_filts[1])
42
+ self.conv2 = nn.Conv2d(
43
+ in_channels=nb_filts[1],
44
+ out_channels=nb_filts[1],
45
+ padding=1,
46
+ kernel_size=3,
47
+ stride=1,
48
+ )
49
+
50
+ if nb_filts[0] != nb_filts[1]:
51
+ self.downsample = True
52
+ self.conv_downsample = nn.Conv2d(
53
+ in_channels=nb_filts[0],
54
+ out_channels=nb_filts[1],
55
+ padding=0,
56
+ kernel_size=1,
57
+ stride=1,
58
+ )
59
+
60
+ else:
61
+ self.downsample = False
62
+ self.mp = nn.MaxPool2d(2)
63
+
64
+ def forward(self, x):
65
+ identity = x
66
+ if not self.first:
67
+ out = self.bn1(x)
68
+ out = self.lrelu(out)
69
+ else:
70
+ out = x
71
+
72
+ out = self.conv1(x)
73
+ out = self.bn2(out)
74
+ out = self.lrelu(out)
75
+ out = self.conv2(out)
76
+
77
+ if self.downsample:
78
+ identity = self.conv_downsample(identity)
79
+
80
+ out += identity
81
+ out = self.mp(out)
82
+ return out
83
+
84
+
85
+ class SpecRNet(nn.Module):
86
+ def __init__(self, input_channels, **kwargs):
87
+ super().__init__()
88
+ config = get_config(input_channels=input_channels)
89
+
90
+ self.device = kwargs.get("device", "cuda")
91
+
92
+ self.first_bn = nn.BatchNorm2d(num_features=config["filts"][0])
93
+ self.selu = nn.SELU(inplace=True)
94
+ self.block0 = nn.Sequential(
95
+ Residual_block2D(nb_filts=config["filts"][1], first=True)
96
+ )
97
+ self.block2 = nn.Sequential(Residual_block2D(nb_filts=config["filts"][2]))
98
+ config["filts"][2][0] = config["filts"][2][1]
99
+ self.block4 = nn.Sequential(Residual_block2D(nb_filts=config["filts"][2]))
100
+ self.avgpool = nn.AdaptiveAvgPool2d(1)
101
+
102
+ self.fc_attention0 = self._make_attention_fc(
103
+ in_features=config["filts"][1][-1], l_out_features=config["filts"][1][-1]
104
+ )
105
+ self.fc_attention2 = self._make_attention_fc(
106
+ in_features=config["filts"][2][-1], l_out_features=config["filts"][2][-1]
107
+ )
108
+ self.fc_attention4 = self._make_attention_fc(
109
+ in_features=config["filts"][2][-1], l_out_features=config["filts"][2][-1]
110
+ )
111
+
112
+ self.bn_before_gru = nn.BatchNorm2d(num_features=config["filts"][2][-1])
113
+ self.gru = nn.GRU(
114
+ input_size=config["filts"][2][-1],
115
+ hidden_size=config["gru_node"],
116
+ num_layers=config["nb_gru_layer"],
117
+ batch_first=True,
118
+ bidirectional=True,
119
+ )
120
+
121
+ self.fc1_gru = nn.Linear(
122
+ in_features=config["gru_node"] * 2, out_features=config["nb_fc_node"] * 2
123
+ )
124
+
125
+ self.fc2_gru = nn.Linear(
126
+ in_features=config["nb_fc_node"] * 2,
127
+ out_features=config["nb_classes"],
128
+ bias=True,
129
+ )
130
+
131
+ self.sig = nn.Sigmoid()
132
+
133
+ def _compute_embedding(self, x):
134
+ x = self.first_bn(x)
135
+ x = self.selu(x)
136
+
137
+ x0 = self.block0(x)
138
+ y0 = self.avgpool(x0).view(x0.size(0), -1)
139
+ y0 = self.fc_attention0(y0)
140
+ y0 = self.sig(y0).view(y0.size(0), y0.size(1), -1)
141
+ y0 = y0.unsqueeze(-1)
142
+ x = x0 * y0 + y0
143
+
144
+ x = nn.MaxPool2d(2)(x)
145
+
146
+ x2 = self.block2(x)
147
+ y2 = self.avgpool(x2).view(x2.size(0), -1)
148
+ y2 = self.fc_attention2(y2)
149
+ y2 = self.sig(y2).view(y2.size(0), y2.size(1), -1)
150
+ y2 = y2.unsqueeze(-1)
151
+ x = x2 * y2 + y2
152
+
153
+ x = nn.MaxPool2d(2)(x)
154
+
155
+ x4 = self.block4(x)
156
+ y4 = self.avgpool(x4).view(x4.size(0), -1)
157
+ y4 = self.fc_attention4(y4)
158
+ y4 = self.sig(y4).view(y4.size(0), y4.size(1), -1)
159
+ y4 = y4.unsqueeze(-1)
160
+ x = x4 * y4 + y4
161
+
162
+ x = nn.MaxPool2d(2)(x)
163
+
164
+ x = self.bn_before_gru(x)
165
+ x = self.selu(x)
166
+ x = nn.AdaptiveAvgPool2d((1, None))(x)
167
+ x = x.squeeze(-2)
168
+ x = x.permute(0, 2, 1)
169
+ self.gru.flatten_parameters()
170
+ x, _ = self.gru(x)
171
+ x = x[:, -1, :]
172
+ x = self.fc1_gru(x)
173
+ x = self.fc2_gru(x)
174
+ return x
175
+
176
+ def forward(self, x):
177
+ x = self._compute_embedding(x)
178
+ return x
179
+
180
+ def _make_attention_fc(self, in_features, l_out_features):
181
+ l_fc = []
182
+ l_fc.append(nn.Linear(in_features=in_features, out_features=l_out_features))
183
+ return nn.Sequential(*l_fc)
184
+
185
+
186
+ class FrontendSpecRNet(SpecRNet):
187
+ def __init__(self, input_channels, **kwargs):
188
+ super().__init__(input_channels, **kwargs)
189
+
190
+ self.device = kwargs['device']
191
+
192
+ frontend_name = kwargs.get("frontend_algorithm", [])
193
+ self.frontend = frontends.get_frontend(frontend_name)
194
+ print(f"Using {frontend_name} frontend")
195
+
196
+ def _compute_frontend(self, x):
197
+ frontend = self.frontend(x)
198
+ if frontend.ndim < 4:
199
+ return frontend.unsqueeze(1) # (bs, 1, n_lfcc, frames)
200
+ return frontend # (bs, n, n_lfcc, frames)
201
+
202
+ def forward(self, x):
203
+ x = self._compute_frontend(x)
204
+ x = self._compute_embedding(x)
205
+ return x
206
+
207
+
208
+ if __name__ == "__main__":
209
+ print("Definition of model")
210
+ device = "cuda"
211
+
212
+ input_channels = 1
213
+ config = {
214
+ "filts": [input_channels, [input_channels, 20], [20, 64], [64, 64]],
215
+ "nb_fc_node": 64,
216
+ "gru_node": 64,
217
+ "nb_gru_layer": 2,
218
+ "nb_classes": 1,
219
+ }
220
+
221
+ def count_parameters(model) -> int:
222
+ pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
223
+ return pytorch_total_params
224
+ model = FrontendSpecRNet(input_channels=1, device=device, frontend_algorithm=["lfcc"])
225
+ model = model.to(device)
226
+ print(count_parameters(model))
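
As with the LCNN smoke test earlier, a forward pass can be sketched as follows (random audio, tensors kept on the device used to initialise the front-ends; the output is one raw score per utterance):

    import torch
    from src.frontends import device as frontend_device
    from src.models.specrnet import FrontendSpecRNet

    model = FrontendSpecRNet(input_channels=1, device=frontend_device, frontend_algorithm=["lfcc"]).to(frontend_device)
    waveforms = torch.rand(2, 64_600, device=frontend_device)
    print(model(waveforms).shape)  # (2, 1)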
src/models/whisper_lcnn.py ADDED
@@ -0,0 +1,89 @@
1
+ import torch
2
+
3
+ from src.models.whisper_main import ModelDimensions, Whisper, log_mel_spectrogram
4
+ from src.models.lcnn import LCNN
5
+ from src import frontends
6
+ from src.commons import WHISPER_MODEL_WEIGHTS_PATH
7
+
8
+
9
+ class WhisperLCNN(LCNN):
10
+
11
+ def __init__(self, input_channels, freeze_encoder, **kwargs):
12
+ super().__init__(input_channels=input_channels, **kwargs)
13
+
14
+ self.device = kwargs['device']
15
+ checkpoint = torch.load(WHISPER_MODEL_WEIGHTS_PATH)
16
+ dims = ModelDimensions(**checkpoint["dims"].__dict__)
17
+ model = Whisper(dims)
18
+ model = model.to(self.device)
19
+ model.load_state_dict(checkpoint["model_state_dict"])
20
+ self.whisper_model = model
21
+ if freeze_encoder:
22
+ for param in self.whisper_model.parameters():
23
+ param.requires_grad = False
24
+
25
+ def compute_whisper_features(self, x):
26
+ specs = []
27
+ for sample in x:
28
+ specs.append(log_mel_spectrogram(sample))
29
+ x = torch.stack(specs)
30
+ x = self.whisper_model(x)
31
+
32
+ x = x.permute(0, 2, 1) # (bs, frames, 3 x n_lfcc)
33
+ x = x.unsqueeze(1) # (bs, 1, frames, 3 x n_lfcc)
34
+ x = x.repeat(
35
+ (1, 1, 1, 2)
36
+ ) # (bs, 1, frames, 3 x n_lfcc) -> (bs, 1, frames, 3000)
37
+ return x
38
+
39
+ def forward(self, x):
40
+ # we assume that the data is correct (i.e. 30s)
41
+ x = self.compute_whisper_features(x)
42
+ out = self._compute_embedding(x)
43
+ return out
44
+
45
+
46
+ class WhisperMultiFrontLCNN(WhisperLCNN):
47
+
48
+ def __init__(self, input_channels, freeze_encoder, **kwargs):
49
+ super().__init__(input_channels=input_channels, freeze_encoder=freeze_encoder, **kwargs)
50
+
51
+ self.frontend = frontends.get_frontend(kwargs['frontend_algorithm'])
52
+ print(f"Using {self.frontend} frontend!")
53
+
54
+ def forward(self, x):
55
+ # Frontend computation
56
+ frontend_x = self.frontend(x)
57
+ x = self.compute_whisper_features(x)
58
+
59
+ x = torch.cat([x, frontend_x], 1)
60
+ out = self._compute_embedding(x)
61
+ return out
62
+
63
+
64
+ if __name__ == "__main__":
65
+ import numpy as np
66
+
67
+ input_channels = 1
68
+ device = "cpu"
69
+ classifier = WhisperLCNN(
70
+ input_channels=input_channels,
71
+ freeze_encoder=True,
72
+ device=device,
73
+ )
74
+
75
+ input_channels = 2
76
+ classifier_2 = WhisperMultiFrontLCNN(
77
+ input_channels=input_channels,
78
+ freeze_encoder=True,
79
+ device=device,
80
+ frontend_algorithm="lfcc"
81
+ )
82
+ x = np.random.rand(2, 30 * 16_000).astype(np.float32)
83
+ x = torch.from_numpy(x)
84
+
85
+ out = classifier(x)
86
+ print(out.shape)
87
+
88
+ out = classifier_2(x)
89
+ print(out.shape)
src/models/whisper_main.py ADDED
@@ -0,0 +1,323 @@
1
+ # Based on https://github.com/openai/whisper/blob/main/whisper/model.py
2
+ from dataclasses import dataclass
3
+ from functools import lru_cache
4
+ import os
5
+ from typing import Iterable, Optional, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torch import Tensor
11
+ from torch import nn
12
+
13
+
14
+ def exact_div(x, y):
15
+ assert x % y == 0
16
+ return x // y
17
+
18
+
19
+ # hard-coded audio hyperparameters
20
+ SAMPLE_RATE = 16000
21
+ N_FFT = 400
22
+ N_MELS = 80
23
+ HOP_LENGTH = 160
24
+ CHUNK_LENGTH = 30
25
+ N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk
26
+ N_FRAMES = exact_div(
27
+ N_SAMPLES, HOP_LENGTH
28
+ ) # 3000: number of frames in a mel spectrogram input
29
+
30
+
31
+ def pad_or_trim(
32
+ array: Union[torch.Tensor, np.ndarray],
33
+ length: int = N_SAMPLES,
34
+ *,
35
+ axis: int = -1,
36
+ ) -> torch.Tensor:
37
+ """
38
+ Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
39
+ """
40
+ if not torch.is_tensor(array):
41
+ array = torch.from_numpy(array)
42
+
43
+ if array.shape[axis] > length:
44
+ array = array.index_select(
45
+ dim=axis, index=torch.arange(length, device=array.device)
46
+ )
47
+
48
+ if array.shape[axis] < length:
49
+ # pad multiple times
50
+ num_repeats = int(length / array.shape[axis]) + 1
51
+ array = torch.tile(array, (1, num_repeats))[:, :length]
52
+ return array
53
+
54
+
55
+ @lru_cache(maxsize=None)
56
+ def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
57
+ """
58
+ load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
59
+ Allows decoupling librosa dependency; saved using:
60
+
61
+ np.savez_compressed(
62
+ "mel_filters.npz",
63
+ mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
64
+ )
65
+ """
66
+ assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
67
+ with np.load(
68
+ os.path.join(os.path.dirname(__file__), "assets/mel_filters.npz")
69
+ ) as f:
70
+ return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
71
+
72
+
73
+ def log_mel_spectrogram(audio: torch.Tensor, n_mels: int = N_MELS):
74
+ """
75
+ Compute the log-Mel spectrogram of an audio waveform.
76
+
77
+ Parameters
78
+ ----------
79
+ audio: torch.Tensor, shape = (*)
80
+
81
+ A Tensor containing the audio waveform, sampled at 16 kHz
82
+ n_mels: int
83
+ The number of Mel-frequency filters, only 80 is supported
84
+
85
+ Returns
86
+ -------
87
+ torch.Tensor, shape = (80, n_frames)
88
+ A Tensor that contains the Mel spectrogram
89
+ """
90
+ window = torch.hann_window(N_FFT).to(audio.device)
91
+ stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
92
+ magnitudes = stft[:, :-1].abs() ** 2
93
+
94
+ filters = mel_filters(audio.device, n_mels)
95
+ mel_spec = filters @ magnitudes
96
+
97
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10()
98
+ log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
99
+ log_spec = (log_spec + 4.0) / 4.0
100
+ return log_spec
101
+
102
+
103
+ @dataclass
104
+ class ModelDimensions:
105
+ n_mels: int
106
+ n_audio_ctx: int
107
+ n_audio_state: int
108
+ n_audio_head: int
109
+ n_audio_layer: int
110
+ n_vocab: int
111
+ n_text_ctx: int
112
+ n_text_state: int
113
+ n_text_head: int
114
+ n_text_layer: int
115
+
116
+
117
+ class LayerNorm(nn.LayerNorm):
118
+ def forward(self, x: Tensor) -> Tensor:
119
+ return super().forward(x.float()).type(x.dtype)
120
+
121
+
122
+ class Linear(nn.Linear):
123
+ def forward(self, x: Tensor) -> Tensor:
124
+ return F.linear(
125
+ x,
126
+ self.weight.to(x.dtype),
127
+ None if self.bias is None else self.bias.to(x.dtype),
128
+ )
129
+
130
+
131
+ class Conv1d(nn.Conv1d):
132
+ def _conv_forward(
133
+ self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
134
+ ) -> Tensor:
135
+ return super()._conv_forward(
136
+ x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
137
+ )
138
+
139
+
140
+ def sinusoids(length, channels, max_timescale=10_000):
141
+ """Returns sinusoids for positional embedding"""
142
+ assert channels % 2 == 0
143
+ log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
144
+ inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
145
+ scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
146
+ return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
147
+
148
+
149
+ class MultiHeadAttention(nn.Module):
150
+ def __init__(self, n_state: int, n_head: int):
151
+ super().__init__()
152
+ self.n_head = n_head
153
+ self.query = Linear(n_state, n_state)
154
+ self.key = Linear(n_state, n_state, bias=False)
155
+ self.value = Linear(n_state, n_state)
156
+ self.out = Linear(n_state, n_state)
157
+
158
+ def forward(
159
+ self,
160
+ x: Tensor,
161
+ xa: Optional[Tensor] = None,
162
+ mask: Optional[Tensor] = None,
163
+ kv_cache: Optional[dict] = None,
164
+ ):
165
+ q = self.query(x)
166
+
167
+ if kv_cache is None or xa is None or self.key not in kv_cache:
168
+ # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
169
+ # otherwise, perform key/value projections for self- or cross-attention as usual.
170
+ k = self.key(x if xa is None else xa)
171
+ v = self.value(x if xa is None else xa)
172
+ else:
173
+ # for cross-attention, calculate keys and values once and reuse in subsequent calls.
174
+ k = kv_cache[self.key]
175
+ v = kv_cache[self.value]
176
+
177
+ wv = self.qkv_attention(q, k, v, mask)
178
+ return self.out(wv)
179
+
180
+ def qkv_attention(
181
+ self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
182
+ ):
183
+ n_batch, n_ctx, n_state = q.shape
184
+ scale = (n_state // self.n_head) ** -0.25
185
+ q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
186
+ k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
187
+ v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
188
+
189
+ qk = q @ k
190
+ if mask is not None:
191
+ qk = qk + mask[:n_ctx, :n_ctx]
192
+
193
+ w = F.softmax(qk.float(), dim=-1).to(q.dtype)
194
+ return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
195
+
196
+
197
+ class ResidualAttentionBlock(nn.Module):
198
+ def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
199
+ super().__init__()
200
+
201
+ self.attn = MultiHeadAttention(n_state, n_head)
202
+ self.attn_ln = LayerNorm(n_state)
203
+
204
+ self.cross_attn = (
205
+ MultiHeadAttention(n_state, n_head) if cross_attention else None
206
+ )
207
+ self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
208
+
209
+ n_mlp = n_state * 4
210
+ self.mlp = nn.Sequential(
211
+ Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
212
+ )
213
+ self.mlp_ln = LayerNorm(n_state)
214
+
215
+ def forward(
216
+ self,
217
+ x: Tensor,
218
+ xa: Optional[Tensor] = None,
219
+ mask: Optional[Tensor] = None,
220
+ kv_cache: Optional[dict] = None,
221
+ ):
222
+ x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
223
+ if self.cross_attn:
224
+ x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
225
+ x = x + self.mlp(self.mlp_ln(x))
226
+ return x
227
+
228
+
229
+ class AudioEncoder(nn.Module):
230
+ def __init__(
231
+ self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
232
+ ):
233
+ super().__init__()
234
+ self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
235
+ self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
236
+ self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
237
+
238
+ self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
239
+ [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
240
+ )
241
+ self.ln_post = LayerNorm(n_state)
242
+
243
+ def forward(self, x: Tensor):
244
+ """
245
+ x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
246
+ the mel spectrogram of the audio
247
+ """
248
+ x = F.gelu(self.conv1(x))
249
+ x = F.gelu(self.conv2(x))
250
+ x = x.permute(0, 2, 1)
251
+
252
+ assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
253
+ x = (x + self.positional_embedding).to(x.dtype)
254
+ for block in self.blocks:
255
+ x = block(x)
256
+
257
+ x = self.ln_post(x)
258
+ return x
259
+
260
+
261
+ class TextDecoder(nn.Module):
262
+ def __init__(
263
+ self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
264
+ ):
265
+ super().__init__()
266
+
267
+ self.token_embedding = nn.Embedding(n_vocab, n_state)
268
+ self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
269
+
270
+ self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
271
+ [
272
+ ResidualAttentionBlock(n_state, n_head, cross_attention=True)
273
+ for _ in range(n_layer)
274
+ ]
275
+ )
276
+ self.ln = LayerNorm(n_state)
277
+
278
+ mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
279
+ self.register_buffer("mask", mask, persistent=False)
280
+
281
+ def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
282
+ """
283
+ x : torch.LongTensor, shape = (batch_size, <= n_ctx)
284
+ the text tokens
285
+ xa : torch.Tensor, shape = (batch_size, n_mels, n_audio_ctx)
286
+ the encoded audio features to be attended on
287
+ """
288
+ offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
289
+ x = (
290
+ self.token_embedding(x)
291
+ + self.positional_embedding[offset : offset + x.shape[-1]]
292
+ )
293
+ x = x.to(xa.dtype)
294
+
295
+ for block in self.blocks:
296
+ x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
297
+
298
+ x = self.ln(x)
299
+ logits = (
300
+ x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
301
+ ).float()
302
+
303
+ return logits
304
+
305
+
306
+ class Whisper(nn.Module):
307
+ def __init__(self, dims: ModelDimensions):
308
+ super().__init__()
309
+ self.dims = dims
310
+ self.encoder = AudioEncoder(
311
+ self.dims.n_mels,
312
+ self.dims.n_audio_ctx,
313
+ self.dims.n_audio_state,
314
+ self.dims.n_audio_head,
315
+ self.dims.n_audio_layer,
316
+ )
317
+
318
+ def forward(self, mel: torch.Tensor):
319
+ return self.encoder(mel)
320
+
321
+ @property
322
+ def device(self):
323
+ return next(self.parameters()).device
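
Putting the pieces of this file together: the encoder-only checkpoint shipped under src/models/assets can be loaded and applied to a 30-second waveform roughly as WhisperLCNN does above. This is a sketch; it assumes, as in src/models/whisper_lcnn.py, that the checkpoint stores "dims" and "model_state_dict" entries and that the input is already exactly 30 s long:

    import torch
    from src.commons import WHISPER_MODEL_WEIGHTS_PATH
    from src.models.whisper_main import (
        CHUNK_LENGTH, SAMPLE_RATE, ModelDimensions, Whisper, log_mel_spectrogram,
    )

    checkpoint = torch.load(WHISPER_MODEL_WEIGHTS_PATH, map_location="cpu")
    model = Whisper(ModelDimensions(**checkpoint["dims"].__dict__))
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    audio = torch.rand(CHUNK_LENGTH * SAMPLE_RATE)   # stand-in for a 30 s, 16 kHz waveform
    mel = log_mel_spectrogram(audio)                 # (80, 3000)
    with torch.no_grad():
        features = model(mel.unsqueeze(0))           # (1, n_audio_ctx, n_audio_state)
    print(features.shape)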