Spaces: Runtime error
mrneuralnet committed · Commit 3fb4562 · 1 Parent(s): f067c08
Initial commit
This view is limited to 50 files because it contains too many changes.
- .gitattributes +3 -33
- .gitignore +160 -0
- LICENSE +21 -0
- app.py +108 -0
- config.yaml +16 -0
- configs/finetuning/whisper_frontend_mesonet.yaml +16 -0
- configs/training/lcnn.yaml +14 -0
- configs/training/mesonet.yaml +15 -0
- configs/training/rawnet3.yaml +13 -0
- configs/training/specrnet.yaml +14 -0
- configs/training/whisper_frontend_lcnn.yaml +16 -0
- configs/training/whisper_frontend_lcnn_mfcc.yaml +15 -0
- configs/training/whisper_frontend_mesonet.yaml +16 -0
- configs/training/whisper_frontend_mesonet_mfcc.yaml +17 -0
- configs/training/whisper_frontend_specrnet.yaml +15 -0
- configs/training/whisper_frontend_specrnet_mfcc.yaml +16 -0
- configs/training/whisper_lcnn.yaml +15 -0
- configs/training/whisper_mesonet.yaml +16 -0
- configs/training/whisper_specrnet.yaml +15 -0
- download_whisper.py +29 -0
- evaluate_models.py +316 -0
- install.sh +6 -0
- mesonet_whisper_mfcc_finetuned.pth +3 -0
- sample_files/[FAKE] - jokowi - cupid [vocals].mp3 +3 -0
- sample_files/[REAL] - Obama at Rutgers: 'Ignorance Is Not a Virtue'_[cut_49sec].mp3 +3 -0
- sample_files/[REAL] - Obama's speech to the class of 2020 in 2 minutes | The Washington Post.wav +3 -0
- sample_files/[[FAKE] - y2mate.com - DeepFake AI generated synthetic video of Barack Obama.mp3 +3 -0
- src/__init__.py +3 -0
- src/commons.py +22 -0
- src/datasets/__init__.py +0 -0
- src/datasets/asvspoof_dataset.py +155 -0
- src/datasets/base_dataset.py +180 -0
- src/datasets/deepfake_asvspoof_dataset.py +86 -0
- src/datasets/detection_dataset.py +125 -0
- src/datasets/fakeavceleb_dataset.py +94 -0
- src/datasets/folder_dataset.py +75 -0
- src/datasets/in_the_wild_dataset.py +62 -0
- src/datasets/wavefake_dataset.py +85 -0
- src/frontends.py +72 -0
- src/metrics.py +15 -0
- src/models/__init__.py +0 -0
- src/models/assets/mel_filters.npz +0 -0
- src/models/assets/tiny_enc.en.pt +3 -0
- src/models/lcnn.py +247 -0
- src/models/meso_net.py +146 -0
- src/models/models.py +73 -0
- src/models/rawnet3.py +323 -0
- src/models/specrnet.py +226 -0
- src/models/whisper_lcnn.py +89 -0
- src/models/whisper_main.py +323 -0
.gitattributes
CHANGED
@@ -1,35 +1,5 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
 *.pt filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.wav filter=lfs diff=lfs merge=lfs -text
+*.mp3 filter=lfs diff=lfs merge=lfs -text
 *.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,160 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) Piotr Kawa, Marcin Plata, Michał Czuba, Piotr Szymański, Piotr Syga

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
app.py
ADDED
@@ -0,0 +1,108 @@
import base64
import json
import os, shutil
import re
import time
import uuid

import cv2

import numpy as np
import streamlit as st
from pydub import AudioSegment
import torch
import yaml
# from extract_video import extract_method_single_video

from utils import st_file_selector, img2base64
from evaluate_models import inference, load_model
from src import commons

import os

DEBUG = True

def main():
    st.markdown("###")
    uploaded_file = st.file_uploader('Upload an audio file', type=['wav', 'mp3'], accept_multiple_files=False)

    with st.spinner(f'Loading samples...'):
        while not os.path.isdir("sample_files"):
            time.sleep(1)
    st.markdown("### or")
    selected_file = st_file_selector(st, path='sample_files', key='selected', label='Choose a sample image/video')

    if uploaded_file:
        random_id = uuid.uuid1()
        ext = uploaded_file.name.split('.')[-1]

        base_folder = "temps"
        filename = "{}.{}".format(random_id, ext)
        file_type = uploaded_file.type.split("/")[0]
        filepath = f"{base_folder}/{filename}"

        uploaded_file_length = len(uploaded_file.getvalue())
        if uploaded_file_length > 0:
            with open(filepath, 'wb') as f:
                f.write(uploaded_file.read())
            st.audio(uploaded_file, format=ext)
    elif selected_file:
        base_folder = "sample_files"
        file_type = selected_file.split(".")[-1]
        filename = selected_file.split("/")[-1]
        filepath = f"{base_folder}/{selected_file}"

        st.write('file_type', file_type)
        with open(filepath, 'rb') as f:
            audio_bytes = f.read()
        st.audio(audio_bytes, format=file_type)
    else:
        return

    with st.spinner(f'Analyzing {file_type}...'):
        seed = config["data"].get("seed", 42)
        # fix all seeds - this should not actually change anything
        commons.set_seed(seed)

        result = inference(
            model,
            datasets_path=filepath,
            device=device,
        )
        result = result[0]

        if 'Real' == result[0]:
            st.success(f'Audio is real! \nprob:{result[1]}', icon="✅")
        else:
            st.error(f'Audio is fake! \nprob:{result[1]}', icon="🚨")

        st.divider()
        st.write('## Response JSON')
        st.write(result)


def setup():
    if not os.path.isdir("temps"):
        os.makedirs("temps")


if __name__ == "__main__":
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    with open('config.yaml', "r") as f:
        config = yaml.safe_load(f)

    model = load_model(config, device)

    st.title("Face Fake Detection")
    setup()
    main()
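Note that app.py imports st_file_selector and img2base64 from a local utils module that is not among the 50 files shown in this truncated view. For orientation only, a minimal sketch of what a directory-based st_file_selector could look like — the name and call shape come from the import above, the body is purely an assumption:

# Hypothetical sketch; the real utils.py is not visible in this diff view.
import os

def st_file_selector(st_obj, path='.', label='Select a file', key=None):
    # List regular files in `path` and let the user pick one via a selectbox.
    files = sorted(
        f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f))
    )
    return st_obj.selectbox(label, files, key=key)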
config.yaml
ADDED
@@ -0,0 +1,16 @@
data:
  seed: 42

checkpoint:
  path: C:\Users\manfr\Projects\deepfake-whisper-features\mesonet_whisper_mfcc_finetuned.pth

model:
  name: whisper_frontend_mesonet
  optimizer:
    lr: 1.0e-06
    weight_decay: 0.0001
  parameters:
    fc1_dim: 1024
    freeze_encoder: false
    frontend_algorithm: ["mfcc"]
    input_channels: 2
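Note that checkpoint.path here is an absolute Windows path; inside a Linux-based Space container it cannot resolve, which is one plausible (but unconfirmed) explanation for the Space's "Runtime error" status. For reference, a sketch of the config keys that inference actually consumes, per app.py above and evaluate_models.py below:

# Sketch: the subset of config.yaml read at inference time.
import yaml

with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)

seed = config["data"].get("seed", 42)             # fixed via commons.set_seed
ckpt_path = config["checkpoint"].get("path", [])  # weights file for load_model
model_name = config["model"]["name"]              # "whisper_frontend_mesonet"
params = config["model"]["parameters"]            # forwarded to models.get_model
print(model_name, ckpt_path, params)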
configs/finetuning/whisper_frontend_mesonet.yaml
ADDED
@@ -0,0 +1,16 @@
data:
  seed: 42

checkpoint:
  path: "trained_models/whisper_frontend_mesonet/ckpt.pth"

model:
  name: "whisper_frontend_mesonet"
  parameters:
    freeze_encoder: false
    input_channels: 2
    fc1_dim: 1024
    frontend_algorithm: ["lfcc"]
  optimizer:
    lr: 1.0e-06
    weight_decay: 0.0001
configs/training/lcnn.yaml
ADDED
@@ -0,0 +1,14 @@
data:
  seed: 42

checkpoint:
  path: ""

model:
  name: "lcnn"
  parameters:
    input_channels: 1
    frontend_algorithm: ["mfcc"]
  optimizer:
    lr: 0.0001
    weight_decay: 0.0001
configs/training/mesonet.yaml
ADDED
@@ -0,0 +1,15 @@
data:
  seed: 42

checkpoint:
  path: ""

model:
  name: "mesonet"
  parameters:
    input_channels: 1
    fc1_dim: 1024
    frontend_algorithm: ["lfcc"]
  optimizer:
    lr: 0.0001
    weight_decay: 0.0001
configs/training/rawnet3.yaml
ADDED
@@ -0,0 +1,13 @@
data:
  seed: 42

checkpoint:
  path: ""

model:
  name: "rawnet3"
  parameters: {}
  optimizer:
    lr: 0.001
    weight_decay: 0.00005 # 5e-5
configs/training/specrnet.yaml
ADDED
@@ -0,0 +1,14 @@
data:
  seed: 42

checkpoint:
  path: ""

model:
  name: "specrnet"
  parameters:
    input_channels: 1
    frontend_algorithm: ["lfcc"]
  optimizer:
    lr: 0.0001
    weight_decay: 0.0001
configs/training/whisper_frontend_lcnn.yaml
ADDED
@@ -0,0 +1,16 @@
data:
  seed: 42

checkpoint:
  path: ""

model:
  name: "whisper_frontend_lcnn"
  parameters:
    freeze_encoder: True
    input_channels: 2
    frontend_algorithm: ["lfcc"]
  optimizer:
    lr: 0.0001
    weight_decay: 0.0001
configs/training/whisper_frontend_lcnn_mfcc.yaml
ADDED
@@ -0,0 +1,15 @@
data:
  seed: 42

checkpoint:
  path: ""

model:
  name: "whisper_frontend_lcnn"
  parameters:
    freeze_encoder: True
    input_channels: 2
    frontend_algorithm: ["mfcc"]
  optimizer:
    lr: 0.0001
    weight_decay: 0.0001
configs/training/whisper_frontend_mesonet.yaml
ADDED
@@ -0,0 +1,16 @@
data:
  seed: 42

checkpoint:
  path: ""

model:
  name: "whisper_frontend_mesonet"
  parameters:
    freeze_encoder: True
    input_channels: 2
    fc1_dim: 1024
    frontend_algorithm: ["lfcc"]
  optimizer:
    lr: 0.0001
    weight_decay: 0.0001
configs/training/whisper_frontend_mesonet_mfcc.yaml
ADDED
@@ -0,0 +1,17 @@
data:
  seed: 42

checkpoint:
  path: ""

model:
  name: "whisper_frontend_mesonet"
  parameters:
    freeze_encoder: True
    input_channels: 2
    fc1_dim: 1024
    frontend_algorithm: ["mfcc"]
  optimizer:
    lr: 0.0001
    weight_decay: 0.0001
configs/training/whisper_frontend_specrnet.yaml
ADDED
@@ -0,0 +1,15 @@
data:
  seed: 42

checkpoint:
  path: ""

model:
  name: "whisper_frontend_specrnet"
  parameters:
    freeze_encoder: True
    input_channels: 2
    frontend_algorithm: ["lfcc"]
  optimizer:
    lr: 0.0001
    weight_decay: 0.0001
configs/training/whisper_frontend_specrnet_mfcc.yaml
ADDED
@@ -0,0 +1,16 @@
data:
  seed: 42

checkpoint:
  path: ""

model:
  name: "whisper_frontend_specrnet"
  parameters:
    freeze_encoder: True
    input_channels: 2
    frontend_algorithm: ["mfcc"]
  optimizer:
    lr: 0.0001
    weight_decay: 0.0001
configs/training/whisper_lcnn.yaml
ADDED
@@ -0,0 +1,15 @@
data:
  seed: 42

checkpoint:
  path: ""

model:
  name: "whisper_lcnn"
  parameters:
    freeze_encoder: True
    input_channels: 1
    frontend_algorithm: ["lfcc"]
  optimizer:
    lr: 0.0001
    weight_decay: 0.0001
configs/training/whisper_mesonet.yaml
ADDED
@@ -0,0 +1,16 @@
data:
  seed: 42

checkpoint:
  path: ""

model:
  name: "whisper_mesonet"
  parameters:
    freeze_encoder: True
    input_channels: 1
    fc1_dim: 1024
    frontend_algorithm: []
  optimizer:
    lr: 0.0001
    weight_decay: 0.0001
configs/training/whisper_specrnet.yaml
ADDED
@@ -0,0 +1,15 @@
data:
  seed: 42

checkpoint:
  path: ""

model:
  name: "whisper_specrnet"
  parameters:
    freeze_encoder: True
    input_channels: 1
    frontend_algorithm: ["lfcc"]
  optimizer:
    lr: 0.0001
    weight_decay: 0.0001
download_whisper.py
ADDED
@@ -0,0 +1,29 @@
# pip install git+https://github.com/openai/whisper.git
from collections import OrderedDict
import whisper
import torch

from src.commons import WHISPER_MODEL_WEIGHTS_PATH

def download_whisper():
    model = whisper.load_model("tiny.en")
    return model


def extract_and_save_encoder(model):
    model_ckpt = OrderedDict()

    model_ckpt['model_state_dict'] = OrderedDict()

    for key, value in model.encoder.state_dict().items():
        model_ckpt['model_state_dict'][f'encoder.{key}'] = value

    model_ckpt['dims'] = model.dims
    torch.save(model_ckpt, WHISPER_MODEL_WEIGHTS_PATH)


if __name__ == "__main__":
    model = download_whisper()
    print("Downloaded Whisper model!")
    extract_and_save_encoder(model)
    print(f"Saved encoder at '{WHISPER_MODEL_WEIGHTS_PATH}'")
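The checkpoint written above stores the encoder weights under keys prefixed with 'encoder.' plus the model dims, matching src/models/assets/tiny_enc.en.pt in this commit. A short sketch of inspecting that file after running the script:

# Sketch: inspect the encoder-only checkpoint written by download_whisper.py.
import torch

ckpt = torch.load("src/models/assets/tiny_enc.en.pt", map_location="cpu")
print(ckpt["dims"])               # Whisper ModelDimensions
state = ckpt["model_state_dict"]  # keys are prefixed with 'encoder.'
print(next(iter(state)))          # first key, prefixed with 'encoder.'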
evaluate_models.py
ADDED
@@ -0,0 +1,316 @@
import argparse
import logging
from pathlib import Path
from typing import Dict, List, Optional, Union
import sys

import torch
import yaml
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score
from torch.utils.data import DataLoader

from src import metrics, commons
from src.models import models
from src.datasets.base_dataset import SimpleAudioFakeDataset
from src.datasets.in_the_wild_dataset import InTheWildDataset
from src.datasets.folder_dataset import FolderDataset, FileDataset


def get_dataset(
    datasets_paths: List[Union[Path, str]],
    amount_to_use: Optional[int],
) -> SimpleAudioFakeDataset:
    data_val = FolderDataset(
        path=datasets_paths[0]
    )
    return data_val

def get_dataset_file(
    datasets_path,
    amount_to_use: Optional[int],
) -> SimpleAudioFakeDataset:
    data_val = FileDataset(
        path=datasets_path
    )
    return data_val


def evaluate_nn(
    model_paths: List[Path],
    datasets_paths: List[Union[Path, str]],
    model_config: Dict,
    device: str,
    amount_to_use: Optional[int] = None,
    batch_size: int = 8,
):
    logging.info("Loading data...")
    model_name, model_parameters = model_config["name"], model_config["parameters"]

    # Load model architecture
    model = models.get_model(
        model_name=model_name,
        config=model_parameters,
        device=device,
    )
    # If provided weights, apply corresponding ones (from an appropriate fold)
    if len(model_paths):
        state_dict = torch.load(model_paths, map_location=device)
        model.load_state_dict(state_dict)
    model = model.to(device)

    data_val = get_dataset(
        datasets_paths=datasets_paths,
        amount_to_use=amount_to_use,
    )

    logging.info(
        f"Testing '{model_name}' model, weights path: '{model_paths}', on {len(data_val)} audio files."
    )
    test_loader = DataLoader(
        data_val,
        batch_size=batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=3,
    )

    batches_number = len(data_val) // batch_size
    num_correct = 0.0
    num_total = 0.0

    y_pred = torch.Tensor([]).to(device)
    y = torch.Tensor([]).to(device)
    y_pred_label = torch.Tensor([]).to(device)

    preds = []

    for i, (batch_x, _, batch_y, metadata) in enumerate(test_loader):
        model.eval()
        _, path, _, _ = metadata
        if i % 10 == 0:
            print(f"Batch [{i}/{batches_number}]")

        with torch.no_grad():
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            num_total += batch_x.size(0)

            batch_pred = model(batch_x).squeeze(1)
            batch_pred = torch.sigmoid(batch_pred)
            batch_pred_label = (batch_pred + 0.5).int()

            num_correct += (batch_pred_label == batch_y.int()).sum(dim=0).item()

            y_pred = torch.concat([y_pred, batch_pred], dim=0)
            y_pred_label = torch.concat([y_pred_label, batch_pred_label], dim=0)
            y = torch.concat([y, batch_y], dim=0)

    # `path` holds only the last batch's file paths; fine for single-file runs.
    for i in range(len(y_pred_label)):
        label = 'Fake' if y_pred_label[i] == 0 else 'Real'
        print(f'{path[i]}')
        print(f' Prediction: {label}')
        print(f' Probability: {y_pred[i]}')
        preds.append((label, y_pred[i].detach().cpu().item()))

    return preds

    # NOTE: the metric computation below is unreachable due to the early return above.
    eval_accuracy = (num_correct / num_total) * 100

    precision, recall, f1_score, support = precision_recall_fscore_support(
        y.cpu().numpy(), y_pred_label.cpu().numpy(), average="binary", beta=1.0
    )
    auc_score = roc_auc_score(y_true=y.cpu().numpy(), y_score=y_pred.cpu().numpy())

    # For EER flip values, following original evaluation implementation
    y_for_eer = 1 - y

    thresh, eer, fpr, tpr = metrics.calculate_eer(
        y=y_for_eer.cpu().numpy(),
        y_score=y_pred.cpu().numpy(),
    )

    eer_label = f"eval/eer"
    accuracy_label = f"eval/accuracy"
    precision_label = f"eval/precision"
    recall_label = f"eval/recall"
    f1_label = f"eval/f1_score"
    auc_label = f"eval/auc"

    logging.info(
        f"{eer_label}: {eer:.4f}, {accuracy_label}: {eval_accuracy:.4f}, {precision_label}: {precision:.4f}, {recall_label}: {recall:.4f}, {f1_label}: {f1_score:.4f}, {auc_label}: {auc_score:.4f}"
    )

def load_model(config, device):
    model_config = config['model']
    model_name, model_parameters = model_config["name"], model_config["parameters"]
    model_paths = config["checkpoint"].get("path", [])
    # Load model architecture
    model = models.get_model(
        model_name=model_name,
        config=model_parameters,
        device=device,
    )
    # If provided weights, apply corresponding ones (from an appropriate fold)
    if len(model_paths):
        state_dict = torch.load(model_paths, map_location=device)
        model.load_state_dict(state_dict)
    model = model.to(device)
    return model

def inference(
    model,
    datasets_path,
    device: str,
    amount_to_use: Optional[int] = None,
    batch_size: int = 8,
):
    logging.info("Loading data...")

    data_val = get_dataset_file(
        datasets_path=datasets_path,
        amount_to_use=amount_to_use,
    )

    test_loader = DataLoader(
        data_val,
        batch_size=batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=3,
    )

    batches_number = len(data_val) // batch_size
    num_correct = 0.0
    num_total = 0.0

    y_pred = torch.Tensor([]).to(device)
    y = torch.Tensor([]).to(device)
    y_pred_label = torch.Tensor([]).to(device)

    preds = []

    for i, (batch_x, _, batch_y, metadata) in enumerate(test_loader):
        model.eval()
        _, path, _, _ = metadata
        if i % 10 == 0:
            print(f"Batch [{i}/{batches_number}]")

        with torch.no_grad():
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            num_total += batch_x.size(0)

            batch_pred = model(batch_x).squeeze(1)
            batch_pred = torch.sigmoid(batch_pred)
            batch_pred_label = (batch_pred + 0.5).int()

            num_correct += (batch_pred_label == batch_y.int()).sum(dim=0).item()

            y_pred = torch.concat([y_pred, batch_pred], dim=0)
            y_pred_label = torch.concat([y_pred_label, batch_pred_label], dim=0)
            y = torch.concat([y, batch_y], dim=0)

    for i in range(len(y_pred_label)):
        label = 'Fake' if y_pred_label[i] == 0 else 'Real'
        print(f'{path[i]}')
        print(f' Prediction: {label}')
        print(f' Probability: {y_pred[i]}')
        preds.append((label, y_pred[i].detach().cpu().item()))

    return preds

    # NOTE: the metric computation below is unreachable due to the early return above.
    eval_accuracy = (num_correct / num_total) * 100

    precision, recall, f1_score, support = precision_recall_fscore_support(
        y.cpu().numpy(), y_pred_label.cpu().numpy(), average="binary", beta=1.0
    )
    auc_score = roc_auc_score(y_true=y.cpu().numpy(), y_score=y_pred.cpu().numpy())

    # For EER flip values, following original evaluation implementation
    y_for_eer = 1 - y

    thresh, eer, fpr, tpr = metrics.calculate_eer(
        y=y_for_eer.cpu().numpy(),
        y_score=y_pred.cpu().numpy(),
    )

    eer_label = f"eval/eer"
    accuracy_label = f"eval/accuracy"
    precision_label = f"eval/precision"
    recall_label = f"eval/recall"
    f1_label = f"eval/f1_score"
    auc_label = f"eval/auc"

    logging.info(
        f"{eer_label}: {eer:.4f}, {accuracy_label}: {eval_accuracy:.4f}, {precision_label}: {precision:.4f}, {recall_label}: {recall:.4f}, {f1_label}: {f1_score:.4f}, {auc_label}: {auc_score:.4f}"
    )


def main(args):
    LOGGER = logging.getLogger()
    LOGGER.setLevel(logging.INFO)

    ch = logging.StreamHandler()
    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    ch.setFormatter(formatter)
    LOGGER.addHandler(ch)
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    if not args.cpu and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    with open(args.config, "r") as f:
        config = yaml.safe_load(f)

    seed = config["data"].get("seed", 42)
    # fix all seeds - this should not actually change anything
    commons.set_seed(seed)

    evaluate_nn(
        model_paths=config["checkpoint"].get("path", []),
        datasets_paths=[
            args.folder_path,
        ],
        model_config=config["model"],
        amount_to_use=args.amount,
        device=device,
    )


def parse_args():
    parser = argparse.ArgumentParser()

    # If assigned as None, then it won't be taken into account
    FOLDER_DATASET_PATH = "sample_files"

    parser.add_argument(
        "--folder_path", type=str, default=FOLDER_DATASET_PATH
    )

    default_model_config = "config.yaml"
    parser.add_argument(
        "--config",
        help="Model config file path (default: config.yaml)",
        type=str,
        default=default_model_config,
    )

    default_amount = None
    parser.add_argument(
        "--amount",
        "-a",
        help=f"Amount of files to load from each directory (default: {default_amount} - use all).",
        type=int,
        default=default_amount,
    )

    parser.add_argument("--cpu", "-c", help="Force using cpu", action="store_true")

    return parser.parse_args()


if __name__ == "__main__":
    main(parse_args())
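One detail worth calling out: (batch_pred + 0.5).int() simply rounds the sigmoid output, i.e. it thresholds the probability at 0.5, with label 1 printed as 'Real' and 0 as 'Fake'. A quick check of that equivalence:

# Sketch: adding 0.5 and truncating to int equals thresholding at 0.5.
import torch

p = torch.tensor([0.10, 0.49, 0.50, 0.90])
print((p + 0.5).int())   # tensor([0, 0, 1, 1], dtype=torch.int32)
print((p >= 0.5).int())  # same labels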
install.sh
ADDED
@@ -0,0 +1,6 @@
conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch -y

pip install asteroid-filterbanks==0.4.0
pip install librosa==0.9.2
pip install git+https://github.com/openai/whisper.git@7858aa9c08d98f75575035ecd6481f462d66ca27
pip install pandas==2.0.2
mesonet_whisper_mfcc_finetuned.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a34a00d0961303274e1cf7a2dc2b6e9f9d568ff0416300be1aaee1c2e2ceee12
size 32983925
sample_files/[FAKE] - jokowi - cupid [vocals].mp3
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1ce8dce41de4f44908c57deea26d4efe5a74f9a37700a76a94ac065e862304c0
size 775449
sample_files/[REAL] - Obama at Rutgers: 'Ignorance Is Not a Virtue'_[cut_49sec].mp3
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6694c6d329f8a372896808c1f7c1e487eec65e5ad2fb3d244d80729b211ac0c4
size 1950720
sample_files/[REAL] - Obama's speech to the class of 2020 in 2 minutes | The Washington Post.wav
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3f045f2b80fdc136c63bfc5897dd9a3d34a3b60dba886cae297d4425db30d5d9
size 27507540
sample_files/[[FAKE] - y2mate.com - DeepFake AI generated synthetic video of Barack Obama.mp3
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fd7100f013cb23ae4af3e00594330838dcae39dc86d669ef8fd215a6a6d88f53
size 273900
src/__init__.py
ADDED
@@ -0,0 +1,3 @@
import logging

logging.getLogger(__name__).addHandler(logging.NullHandler())
src/commons.py
ADDED
@@ -0,0 +1,22 @@
"""Utility file for src toolkit."""
import os
import random

import numpy as np
import torch

WHISPER_MODEL_WEIGHTS_PATH = "src/models/assets/tiny_enc.en.pt"


def set_seed(seed: int):
    """Fix PRNG seed for reproducible experiments."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    os.environ["PYTHONHASHSEED"] = str(seed)
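set_seed pins Python's, NumPy's and PyTorch's generators (and disables cuDNN autotuning on GPU), so re-seeding replays the same random sequence; for example:

# Sketch: identical draws after re-seeding with set_seed.
from src.commons import set_seed
import torch

set_seed(42)
a = torch.rand(3)
set_seed(42)
b = torch.rand(3)
assert torch.equal(a, b)  # same seed -> same sequence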
src/datasets/__init__.py
ADDED
File without changes
src/datasets/asvspoof_dataset.py
ADDED
@@ -0,0 +1,155 @@
from pathlib import Path

import pandas as pd
if __name__ == "__main__":
    import sys
    sys.path.append(str(Path(__file__).parent.parent.parent.absolute()))

from src.datasets.base_dataset import SimpleAudioFakeDataset

ASVSPOOF_SPLIT = {
    "train": ['A01', 'A07', 'A08', 'A02', 'A09', 'A10', 'A03', 'A04', 'A05', 'A06', 'A11', 'A12', 'A13', 'A14', 'A15', 'A16', 'A17', 'A18', 'A19'],
    "test": ['A01', 'A07', 'A08', 'A02', 'A09', 'A10', 'A03', 'A04', 'A05', 'A06', 'A11', 'A12', 'A13', 'A14', 'A15', 'A16', 'A17', 'A18', 'A19'],
    "val": ['A01', 'A07', 'A08', 'A02', 'A09', 'A10', 'A03', 'A04', 'A05', 'A06', 'A11', 'A12', 'A13', 'A14', 'A15', 'A16', 'A17', 'A18', 'A19'],
    "partition_ratio": [0.7, 0.15],
    "seed": 45,
}


class ASVSpoofDataset(SimpleAudioFakeDataset):

    protocol_folder_name = "ASVspoof2019_LA_cm_protocols"
    subset_dir_prefix = "ASVspoof2019_LA_"
    subsets = ("train", "dev", "eval")

    def __init__(self, path, subset="train", transform=None):
        super().__init__(subset, transform)
        self.path = path

        self.allowed_attacks = ASVSPOOF_SPLIT[subset]
        self.partition_ratio = ASVSPOOF_SPLIT["partition_ratio"]
        self.seed = ASVSPOOF_SPLIT["seed"]

        self.samples = pd.DataFrame()

        for subset in self.subsets:
            subset_dir = Path(self.path) / f"{self.subset_dir_prefix}{subset}"
            subset_protocol_path = self.get_protocol_path(subset)
            subset_samples = self.read_protocol(subset_dir, subset_protocol_path)

            self.samples = pd.concat([self.samples, subset_samples])

        self.transform = transform

    def get_protocol_path(self, subset):
        paths = list((Path(self.path) / self.protocol_folder_name).glob("*.txt"))
        for path in paths:
            if subset in Path(path).stem:
                return path

    def read_protocol(self, subset_dir, protocol_path):
        samples = {
            "user_id": [],
            "sample_name": [],
            "attack_type": [],
            "label": [],
            "path": []
        }

        real_samples = []
        fake_samples = []
        with open(protocol_path, "r") as file:
            for line in file:
                attack_type = line.strip().split(" ")[3]

                if attack_type == "-":
                    real_samples.append(line)
                elif attack_type in self.allowed_attacks:
                    fake_samples.append(line)

                if attack_type not in self.allowed_attacks:
                    continue

        fake_samples = self.split_samples(fake_samples)
        for line in fake_samples:
            samples = self.add_line_to_samples(samples, line, subset_dir)

        real_samples = self.split_samples(real_samples)
        for line in real_samples:
            samples = self.add_line_to_samples(samples, line, subset_dir)

        return pd.DataFrame(samples)

    @staticmethod
    def add_line_to_samples(samples, line, subset_dir):
        user_id, sample_name, _, attack_type, label = line.strip().split(" ")
        samples["user_id"].append(user_id)
        samples["sample_name"].append(sample_name)
        samples["attack_type"].append(attack_type)
        samples["label"].append(label)

        assert (subset_dir / "flac" / f"{sample_name}.flac").exists()
        samples["path"].append(subset_dir / "flac" / f"{sample_name}.flac")

        return samples

class ASVSpoof2019DatasetOriginal(ASVSpoofDataset):

    subsets = {"train": "train", "test": "dev", "val": "eval"}

    protocol_folder_name = "ASVspoof2019_LA_cm_protocols"
    subset_dir_prefix = "ASVspoof2019_LA_"
    subset_dirs_attacks = {
        "train": ["A01", "A02", "A03", "A04", "A05", "A06"],
        "dev": ["A01", "A02", "A03", "A04", "A05", "A06"],
        "eval": [
            "A07", "A08", "A09", "A10", "A11", "A12", "A13", "A14", "A15",
            "A16", "A17", "A18", "A19"
        ]
    }

    def __init__(self, path, fold_subset="train"):
        """
        Initialise object. Skip __init__ of ASVSpoofDataset due to different
        logic, but follow SimpleAudioFakeDataset constructor.
        """
        super(ASVSpoofDataset, self).__init__(float('inf'), fold_subset)
        self.path = path
        subset = self.subsets[fold_subset]
        self.allowed_attacks = self.subset_dirs_attacks[subset]
        subset_dir = Path(self.path) / f"{self.subset_dir_prefix}{subset}"
        subset_protocol_path = self.get_protocol_path(subset)
        self.samples = self.read_protocol(subset_dir, subset_protocol_path)

    def read_protocol(self, subset_dir, protocol_path):
        samples = {
            "user_id": [],
            "sample_name": [],
            "attack_type": [],
            "label": [],
            "path": []
        }

        real_samples = []
        fake_samples = []

        with open(protocol_path, "r") as file:
            for line in file:
                attack_type = line.strip().split(" ")[3]
                if attack_type == "-":
                    real_samples.append(line)
                elif attack_type in self.allowed_attacks:
                    fake_samples.append(line)
                else:
                    raise ValueError(
                        "Tried to load attack that shouldn't be here!"
                    )

        for line in fake_samples:
            samples = self.add_line_to_samples(samples, line, subset_dir)
        for line in real_samples:
            samples = self.add_line_to_samples(samples, line, subset_dir)

        return pd.DataFrame(samples)
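The partition_ratio [0.7, 0.15] used here is consumed by split_samples (defined in base_dataset.py below), which shuffles deterministically and cuts the sample list at the 70% and 85% marks, leaving 15% for val. A small illustration of the same np.split pattern:

# Sketch: how a [0.7, 0.15] partition splits 100 shuffled samples.
import numpy as np

samples = [f"sample_{i:03d}" for i in range(100)]
p, s = 0.7, 0.15
train, test, val = np.split(samples, [int(p * 100), int((p + s) * 100)])
print(len(train), len(test), len(val))  # 70 15 15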
src/datasets/base_dataset.py
ADDED
@@ -0,0 +1,180 @@
"""Base dataset classes."""
import logging
import math
import random

import numpy as np
import pandas as pd
import torch
import torchaudio
from torch.utils.data import Dataset
from torch.utils.data.dataset import T_co


LOGGER = logging.getLogger(__name__)

SAMPLING_RATE = 16_000
APPLY_NORMALIZATION = True
APPLY_TRIMMING = True
APPLY_PADDING = True
FRAMES_NUMBER = 480_000  # <- originally 64_600


SOX_SILENCE = [
    # trim all silence that is longer than 0.2s and louder than 1% volume (relative to the file)
    # from beginning and middle/end
    ["silence", "1", "0.2", "1%", "-1", "0.2", "1%"],
]


class SimpleAudioFakeDataset(Dataset):
    def __init__(
        self,
        subset,
        transform=None,
        return_label: bool = True,
        return_meta: bool = True,
    ):
        self.transform = transform

        self.subset = subset
        self.allowed_attacks = None
        self.partition_ratio = None
        self.seed = None
        self.return_label = return_label
        self.return_meta = return_meta

    def split_samples(self, samples_list):
        if isinstance(samples_list, pd.DataFrame):
            samples_list = samples_list.sort_values(by=list(samples_list.columns))
            samples_list = samples_list.sample(frac=1, random_state=self.seed)
        else:
            samples_list = sorted(samples_list)
            random.seed(self.seed)
            random.shuffle(samples_list)

        p, s = self.partition_ratio
        subsets = np.split(
            samples_list, [int(p * len(samples_list)), int((p + s) * len(samples_list))]
        )
        return dict(zip(["train", "test", "val"], subsets))[self.subset]

    def df2tuples(self):
        tuple_samples = []
        for i, elem in self.samples.iterrows():
            tuple_samples.append(
                (str(elem["path"]), elem["label"], elem["attack_type"])
            )

        self.samples = tuple_samples

        return self.samples

    def __getitem__(self, index) -> T_co:
        if isinstance(self.samples, pd.DataFrame):
            sample = self.samples.iloc[index]

            path = str(sample["path"])
            label = sample["label"]
            attack_type = sample["attack_type"]
            if type(attack_type) != str and math.isnan(attack_type):
                attack_type = "N/A"
        else:
            path, label, attack_type = self.samples[index]

        waveform, sample_rate = torchaudio.load(path, normalize=APPLY_NORMALIZATION)
        import librosa
        # waveform, sample_rate = librosa.load(path, sr=SAMPLING_RATE)
        # waveform = torch.tensor(waveform)
        print('waveform', waveform)
        real_sec_length = len(waveform[0]) / sample_rate

        waveform, sample_rate = apply_preprocessing(waveform, sample_rate)

        return_data = [waveform, sample_rate]
        if self.return_label:
            label = 1 if label == "bonafide" else 0
            return_data.append(label)

        if self.return_meta:
            return_data.append(
                (
                    attack_type,
                    path,
                    self.subset,
                    real_sec_length,
                )
            )
        return return_data

    def __len__(self):
        return len(self.samples)


def apply_preprocessing(
    waveform,
    sample_rate,
):
    if sample_rate != SAMPLING_RATE and SAMPLING_RATE != -1:
        waveform, sample_rate = resample_wave(waveform, sample_rate, SAMPLING_RATE)

    # Stereo to mono
    if waveform.dim() > 1 and waveform.shape[0] > 1:
        waveform = waveform[:1, ...]

    # Trim too long utterances...
    if APPLY_TRIMMING:
        waveform, sample_rate = apply_trim(waveform, sample_rate)

    # ... or pad too short ones.
    if APPLY_PADDING:
        waveform = apply_pad(waveform, FRAMES_NUMBER)

    return waveform, sample_rate


def resample_wave(waveform, sample_rate, target_sample_rate):
    # waveform, sample_rate = torchaudio.sox_effects.apply_effects_tensor(
    #     waveform, sample_rate, [["rate", f"{target_sample_rate}"]]
    # )
    waveform = torchaudio.functional.resample(waveform, orig_freq=sample_rate, new_freq=target_sample_rate)
    return waveform, target_sample_rate


def resample_file(path, target_sample_rate, normalize=True):
    waveform, sample_rate = torchaudio.sox_effects.apply_effects_file(
        path, [["rate", f"{target_sample_rate}"]], normalize=normalize
    )

    return waveform, sample_rate


def apply_trim(waveform, sample_rate):
    # (
    #     waveform_trimmed,
    #     sample_rate_trimmed,
    # ) = torchaudio.sox_effects.apply_effects_tensor(waveform, sample_rate, SOX_SILENCE)
    ["silence", "1", "0.2", "1%", "-1", "0.2", "1%"],
    waveform_trimmed = torchaudio.functional.vad(waveform, sample_rate=sample_rate)

    if waveform_trimmed.size()[1] > 0:
        waveform = waveform_trimmed

    return waveform, sample_rate


def apply_pad(waveform, cut):
    """Pad wave by repeating signal until `cut` length is achieved."""
    waveform = waveform.squeeze(0)
    waveform_len = waveform.shape[0]

    if waveform_len >= cut:
        return waveform[:cut]

    # need to pad
    num_repeats = int(cut / waveform_len) + 1
    padded_waveform = torch.tile(waveform, (1, num_repeats))[:, :cut][0]

    return padded_waveform
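apply_pad tiles a short waveform until it reaches `cut` samples (FRAMES_NUMBER = 480,000, i.e. 30 s at the 16 kHz SAMPLING_RATE) and truncates anything longer; a toy run of the same logic:

# Sketch: repeat-padding a 4-sample waveform up to 10 samples, mirroring apply_pad.
import torch

wave = torch.tensor([1.0, 2.0, 3.0, 4.0])
cut = 10
num_repeats = int(cut / wave.shape[0]) + 1
padded = torch.tile(wave, (1, num_repeats))[:, :cut][0]
print(padded)  # tensor([1., 2., 3., 4., 1., 2., 3., 4., 1., 2.])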
src/datasets/deepfake_asvspoof_dataset.py
ADDED
@@ -0,0 +1,86 @@
import logging
from pathlib import Path

import pandas as pd

from src.datasets.base_dataset import SimpleAudioFakeDataset

DF_ASVSPOOF_SPLIT = {
    "partition_ratio": [0.7, 0.15],
    "seed": 45
}

LOGGER = logging.getLogger()

class DeepFakeASVSpoofDataset(SimpleAudioFakeDataset):

    protocol_file_name = "keys/CM/trial_metadata.txt"
    subset_dir_prefix = "ASVspoof2021_DF_eval"
    subset_parts = ("part00", "part01", "part02", "part03")

    def __init__(self, path, subset="train", transform=None):
        super().__init__(subset, transform)
        self.path = path

        self.partition_ratio = DF_ASVSPOOF_SPLIT["partition_ratio"]
        self.seed = DF_ASVSPOOF_SPLIT["seed"]

        self.flac_paths = self.get_file_references()
        self.samples = self.read_protocol()

        self.transform = transform
        LOGGER.info(f"Spoof: {len(self.samples[self.samples['label'] == 'spoof'])}")
        LOGGER.info(f"Original: {len(self.samples[self.samples['label'] == 'bonafide'])}")

    def get_file_references(self):
        flac_paths = {}
        for part in self.subset_parts:
            path = Path(self.path) / f"{self.subset_dir_prefix}_{part}" / self.subset_dir_prefix / "flac"
            flac_list = list(path.glob("*.flac"))

            for path in flac_list:
                flac_paths[path.stem] = path

        return flac_paths

    def read_protocol(self):
        samples = {
            "sample_name": [],
            "label": [],
            "path": [],
            "attack_type": [],
        }

        real_samples = []
        fake_samples = []
        with open(Path(self.path) / self.protocol_file_name, "r") as file:
            for line in file:
                label = line.strip().split(" ")[5]

                if label == "bonafide":
                    real_samples.append(line)
                elif label == "spoof":
                    fake_samples.append(line)

        fake_samples = self.split_samples(fake_samples)
        for line in fake_samples:
            samples = self.add_line_to_samples(samples, line)

        real_samples = self.split_samples(real_samples)
        for line in real_samples:
            samples = self.add_line_to_samples(samples, line)

        return pd.DataFrame(samples)

    def add_line_to_samples(self, samples, line):
        _, sample_name, _, _, _, label, _, _ = line.strip().split(" ")
        samples["sample_name"].append(sample_name)
        samples["label"].append(label)
        samples["attack_type"].append(label)

        sample_path = self.flac_paths[sample_name]
        assert sample_path.exists()
        samples["path"].append(sample_path)

        return samples
src/datasets/detection_dataset.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging
from typing import List, Optional

import pandas as pd

from src.datasets.base_dataset import SimpleAudioFakeDataset
from src.datasets.deepfake_asvspoof_dataset import DeepFakeASVSpoofDataset
from src.datasets.fakeavceleb_dataset import FakeAVCelebDataset
from src.datasets.wavefake_dataset import WaveFakeDataset
from src.datasets.asvspoof_dataset import ASVSpoof2019DatasetOriginal


LOGGER = logging.getLogger()


class DetectionDataset(SimpleAudioFakeDataset):
    def __init__(
        self,
        asvspoof_path=None,
        wavefake_path=None,
        fakeavceleb_path=None,
        asvspoof2019_path=None,
        subset: str = "val",
        transform=None,
        oversample: bool = True,
        undersample: bool = False,
        return_label: bool = True,
        reduced_number: Optional[int] = None,
        return_meta: bool = False,
    ):
        super().__init__(
            subset=subset,
            transform=transform,
            return_label=return_label,
            return_meta=return_meta,
        )
        datasets = self._init_datasets(
            asvspoof_path=asvspoof_path,
            wavefake_path=wavefake_path,
            fakeavceleb_path=fakeavceleb_path,
            asvspoof2019_path=asvspoof2019_path,
            subset=subset,
        )
        self.samples = pd.concat([ds.samples for ds in datasets], ignore_index=True)

        if oversample:
            self.oversample_dataset()
        elif undersample:
            self.undersample_dataset()

        if reduced_number:
            LOGGER.info(f"Using reduced number of samples - {reduced_number}!")
            self.samples = self.samples.sample(
                min(len(self.samples), reduced_number),
                random_state=42,
            )

    def _init_datasets(
        self,
        asvspoof_path: Optional[str],
        wavefake_path: Optional[str],
        fakeavceleb_path: Optional[str],
        asvspoof2019_path: Optional[str],
        subset: str,
    ) -> List[SimpleAudioFakeDataset]:
        datasets = []

        if asvspoof_path is not None:
            asvspoof_dataset = DeepFakeASVSpoofDataset(asvspoof_path, subset=subset)
            datasets.append(asvspoof_dataset)

        if wavefake_path is not None:
            wavefake_dataset = WaveFakeDataset(wavefake_path, subset=subset)
            datasets.append(wavefake_dataset)

        if fakeavceleb_path is not None:
            fakeavceleb_dataset = FakeAVCelebDataset(fakeavceleb_path, subset=subset)
            datasets.append(fakeavceleb_dataset)

        if asvspoof2019_path is not None:
            la_dataset = ASVSpoof2019DatasetOriginal(
                asvspoof2019_path, fold_subset=subset
            )
            datasets.append(la_dataset)

        return datasets

    def oversample_dataset(self):
        samples = self.samples.groupby(by=["label"])
        bona_length = len(samples.groups["bonafide"])
        spoof_length = len(samples.groups["spoof"])

        diff_length = spoof_length - bona_length

        if diff_length < 0:
            raise NotImplementedError

        if diff_length > 0:
            bonafide = samples.get_group("bonafide").sample(diff_length, replace=True)
            self.samples = pd.concat([self.samples, bonafide], ignore_index=True)

    def undersample_dataset(self):
        samples = self.samples.groupby(by=["label"])
        bona_length = len(samples.groups["bonafide"])
        spoof_length = len(samples.groups["spoof"])

        if spoof_length < bona_length:
            raise NotImplementedError

        if spoof_length > bona_length:
            spoofs = samples.get_group("spoof").sample(bona_length, replace=True)
            self.samples = pd.concat(
                [samples.get_group("bonafide"), spoofs], ignore_index=True
            )

    def get_bonafide_only(self):
        samples = self.samples.groupby(by=["label"])
        self.samples = samples.get_group("bonafide")
        return self.samples

    def get_spoof_only(self):
        samples = self.samples.groupby(by=["label"])
        self.samples = samples.get_group("spoof")
        return self.samples
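A minimal usage sketch of the merged dataset; the dataset locations below are placeholders, not paths shipped with this repo.

# Usage sketch - the paths are hypothetical.
from src.datasets.detection_dataset import DetectionDataset

train_ds = DetectionDataset(
    asvspoof_path="/data/ASVspoof2021_DF",  # placeholder location
    wavefake_path="/data/WaveFake",         # placeholder location
    subset="train",
    oversample=True,  # duplicates bonafide rows until the classes balance
)
print(len(train_ds.samples))
print(train_ds.samples["label"].value_counts())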
src/datasets/fakeavceleb_dataset.py
ADDED
@@ -0,0 +1,94 @@
from pathlib import Path

import pandas as pd

from src.datasets.base_dataset import SimpleAudioFakeDataset

FAKEAVCELEB_SPLIT = {
    "train": ['faceswap-wav2lip', 'fsgan-wav2lip', 'wav2lip', 'rtvc'],
    "test": ['faceswap-wav2lip', 'fsgan-wav2lip', 'wav2lip', 'rtvc'],
    "val": ['faceswap-wav2lip', 'fsgan-wav2lip', 'wav2lip', 'rtvc'],
    "partition_ratio": [0.7, 0.15],
    "seed": 45
}


class FakeAVCelebDataset(SimpleAudioFakeDataset):

    audio_folder = "FakeAVCeleb-audio"
    audio_extension = ".mp3"
    metadata_file = Path(audio_folder) / "meta_data.csv"
    subsets = ("train", "dev", "eval")

    def __init__(self, path, subset="train", transform=None):
        super().__init__(subset, transform)
        self.path = path

        self.subset = subset
        self.allowed_attacks = FAKEAVCELEB_SPLIT[subset]
        self.partition_ratio = FAKEAVCELEB_SPLIT["partition_ratio"]
        self.seed = FAKEAVCELEB_SPLIT["seed"]

        self.metadata = self.get_metadata()

        self.samples = pd.concat([self.get_fake_samples(), self.get_real_samples()], ignore_index=True)

    def get_metadata(self):
        md = pd.read_csv(Path(self.path) / self.metadata_file)
        md["audio_type"] = md["type"].apply(lambda x: x.split("-")[-1])
        return md

    def get_fake_samples(self):
        samples = {
            "user_id": [],
            "sample_name": [],
            "attack_type": [],
            "label": [],
            "path": []
        }

        for attack_name in self.allowed_attacks:
            fake_samples = self.metadata[
                (self.metadata["method"] == attack_name) & (self.metadata["audio_type"] == "FakeAudio")
            ]

            samples_list = fake_samples.iterrows()
            samples_list = self.split_samples(samples_list)

            for _, sample in samples_list:
                samples["user_id"].append(sample["source"])
                samples["sample_name"].append(Path(sample["filename"]).stem)
                samples["attack_type"].append(sample["method"])
                samples["label"].append("spoof")
                samples["path"].append(self.get_file_path(sample))

        return pd.DataFrame(samples)

    def get_real_samples(self):
        samples = {
            "user_id": [],
            "sample_name": [],
            "attack_type": [],
            "label": [],
            "path": []
        }

        samples_list = self.metadata[
            (self.metadata["method"] == "real") & (self.metadata["audio_type"] == "RealAudio")
        ]

        samples_list = self.split_samples(samples_list)

        for index, sample in samples_list.iterrows():
            samples["user_id"].append(sample["source"])
            samples["sample_name"].append(Path(sample["filename"]).stem)
            samples["attack_type"].append("-")
            samples["label"].append("bonafide")
            samples["path"].append(self.get_file_path(sample))

        return pd.DataFrame(samples)

    def get_file_path(self, sample):
        path = "/".join([self.audio_folder, *sample["path"].split("/")[1:]])
        return Path(self.path) / path / Path(sample["filename"]).with_suffix(self.audio_extension)
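The audio_type column used above comes from splitting the metadata "type" field on "-"; a quick sketch, where the type values are assumptions about FakeAVCeleb's meta_data.csv conventions:

# Assumed "type" values; only the split logic mirrors get_metadata above.
for t in ["RealVideo-RealAudio", "FakeVideo-FakeAudio", "RealVideo-FakeAudio"]:
    print(t, "->", t.split("-")[-1])  # RealAudio / FakeAudio / FakeAudio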
src/datasets/folder_dataset.py
ADDED
@@ -0,0 +1,75 @@
import pandas as pd
import os
from pathlib import Path

from src.datasets.base_dataset import SimpleAudioFakeDataset


class FolderDataset(SimpleAudioFakeDataset):

    def __init__(
        self,
        path,
        subset="test",
        transform=None,
    ):
        super().__init__(subset=subset, transform=transform)
        self.path = path
        self.samples = self.read_samples()

    def read_samples(self):
        path = Path(self.path)
        print('ori path', path)
        print('list', os.listdir(path))

        samples = []
        for filepath in os.listdir(path):
            samples.append({
                'path': path / filepath,
                'label': '',
                'attack_type': '',
            })

        samples = pd.DataFrame(samples)
        print('samples', samples)
        return samples


class FileDataset(SimpleAudioFakeDataset):

    def __init__(
        self,
        path,
        subset="test",
        transform=None,
    ):
        super().__init__(subset=subset, transform=transform)
        self.path = path
        self.samples = self.read_samples()

    def read_samples(self):
        path = Path(self.path)

        samples = [{'path': path, 'label': '', 'attack_type': ''}]

        samples = pd.DataFrame(samples)
        print('samples', samples)
        return samples


if __name__ == "__main__":
    # The original demo instantiated InTheWildDataset, which is not defined
    # in this module; exercise FolderDataset instead.
    dataset = FolderDataset(
        path="../datasets/release_in_the_wild",
        subset="test",
    )

    print(len(dataset))
    print(dataset[0])
src/datasets/in_the_wild_dataset.py
ADDED
@@ -0,0 +1,62 @@
import numpy as np
import pandas as pd
from pathlib import Path

from src.datasets.base_dataset import SimpleAudioFakeDataset


class InTheWildDataset(SimpleAudioFakeDataset):

    def __init__(
        self,
        path,
        subset="train",
        transform=None,
        seed=None,
        partition_ratio=(0.7, 0.15),
        split_strategy="random"
    ):
        super().__init__(subset=subset, transform=transform)
        self.path = path
        self.read_samples()
        self.partition_ratio = partition_ratio
        self.seed = seed

    def read_samples(self):
        path = Path(self.path)
        meta_path = path / "meta.csv"

        self.samples = pd.read_csv(meta_path)
        self.samples["path"] = self.samples["file"].apply(lambda n: str(path / n))
        self.samples["file"] = self.samples["file"].apply(lambda n: Path(n).stem)
        self.samples["label"] = self.samples["label"].map({"bona-fide": "bonafide", "spoof": "spoof"})
        self.samples["attack_type"] = self.samples["label"].map({"bonafide": "-", "spoof": "X"})
        self.samples.rename(columns={'file': 'sample_name', 'speaker': 'user_id'}, inplace=True)

    def split_samples_per_speaker(self, samples):
        speaker_list = pd.Series(samples["user_id"].unique())
        speaker_list = speaker_list.sort_values()
        speaker_list = speaker_list.sample(frac=1, random_state=self.seed)
        speaker_list = list(speaker_list)

        p, s = self.partition_ratio
        subsets = np.split(speaker_list, [int(p * len(speaker_list)), int((p + s) * len(speaker_list))])
        speaker_subset = dict(zip(['train', 'test', 'val'], subsets))[self.subset]
        return self.samples[self.samples["user_id"].isin(speaker_subset)]


if __name__ == "__main__":
    dataset = InTheWildDataset(
        path="../datasets/release_in_the_wild",
        subset="val",
        seed=242,
        split_strategy="per_speaker"
    )

    print(len(dataset))
    print(len(dataset.samples["user_id"].unique()))
    print(dataset.samples["user_id"].unique())

    print(dataset[0])
src/datasets/wavefake_dataset.py
ADDED
@@ -0,0 +1,85 @@
from pathlib import Path

import pandas as pd

from src.datasets.base_dataset import SimpleAudioFakeDataset

WAVEFAKE_SPLIT = {
    "train": ['multi_band_melgan', 'melgan_large', 'parallel_wavegan', 'waveglow', 'full_band_melgan', 'melgan', 'hifiGAN'],
    "test": ['multi_band_melgan', 'melgan_large', 'parallel_wavegan', 'waveglow', 'full_band_melgan', 'melgan', 'hifiGAN'],
    "val": ['multi_band_melgan', 'melgan_large', 'parallel_wavegan', 'waveglow', 'full_band_melgan', 'melgan', 'hifiGAN'],
    "partition_ratio": [0.7, 0.15],
    "seed": 45
}


class WaveFakeDataset(SimpleAudioFakeDataset):

    fake_data_path = "generated_audio"
    jsut_real_data_path = "real_audio/jsut_ver1.1/basic5000/wav"
    ljspeech_real_data_path = "real_audio/LJSpeech-1.1/wavs"

    def __init__(self, path, subset="train", transform=None):
        super().__init__(subset, transform)
        self.path = Path(path)

        self.fold_subset = subset
        self.allowed_attacks = WAVEFAKE_SPLIT[subset]
        self.partition_ratio = WAVEFAKE_SPLIT["partition_ratio"]
        self.seed = WAVEFAKE_SPLIT["seed"]

        self.samples = pd.concat([self.get_fake_samples(), self.get_real_samples()], ignore_index=True)

    def get_fake_samples(self):
        samples = {
            "user_id": [],
            "sample_name": [],
            "attack_type": [],
            "label": [],
            "path": []
        }

        samples_list = list((self.path / self.fake_data_path).glob("*/*.wav"))
        samples_list = self.filter_samples_by_attack(samples_list)
        samples_list = self.split_samples(samples_list)

        for sample in samples_list:
            samples["user_id"].append(None)
            samples["sample_name"].append("_".join(sample.stem.split("_")[:-1]))
            samples["attack_type"].append(self.get_attack_from_path(sample))
            samples["label"].append("spoof")
            samples["path"].append(sample)

        return pd.DataFrame(samples)

    def filter_samples_by_attack(self, samples_list):
        return [s for s in samples_list if self.get_attack_from_path(s) in self.allowed_attacks]

    def get_real_samples(self):
        samples = {
            "user_id": [],
            "sample_name": [],
            "attack_type": [],
            "label": [],
            "path": []
        }

        samples_list = list((self.path / self.jsut_real_data_path).glob("*.wav"))
        samples_list += list((self.path / self.ljspeech_real_data_path).glob("*.wav"))
        samples_list = self.split_samples(samples_list)

        for sample in samples_list:
            samples["user_id"].append(None)
            samples["sample_name"].append(sample.stem)
            samples["attack_type"].append("-")
            samples["label"].append("bonafide")
            samples["path"].append(sample)

        return pd.DataFrame(samples)

    @staticmethod
    def get_attack_from_path(path):
        folder_name = path.parents[0].relative_to(path.parents[1])
        return str(folder_name).split("_", maxsplit=1)[-1]
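get_attack_from_path reads the attack name from the parent folder, relying on WaveFake's "<dataset>_<attack>" directory convention; the sample path below is made up to match that pattern:

from pathlib import Path

# Hypothetical file following the generated_audio layout assumed above.
p = Path("generated_audio/ljspeech_multi_band_melgan/LJ001-0001_gen.wav")
folder_name = p.parents[0].relative_to(p.parents[1])  # ljspeech_multi_band_melgan
print(str(folder_name).split("_", maxsplit=1)[-1])    # -> multi_band_melgan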
src/frontends.py
ADDED
@@ -0,0 +1,72 @@
from typing import List, Union, Callable

import torch
import torchaudio

SAMPLING_RATE = 16_000
win_length = 400  # int((25 / 1_000) * SAMPLING_RATE)
hop_length = 160  # int((10 / 1_000) * SAMPLING_RATE)

device = "cuda" if torch.cuda.is_available() else "cpu"

MFCC_FN = torchaudio.transforms.MFCC(
    sample_rate=SAMPLING_RATE,
    n_mfcc=128,
    melkwargs={
        "n_fft": 512,
        "win_length": win_length,
        "hop_length": hop_length,
    },
).to(device)


LFCC_FN = torchaudio.transforms.LFCC(
    sample_rate=SAMPLING_RATE,
    n_lfcc=128,
    speckwargs={
        "n_fft": 512,
        "win_length": win_length,
        "hop_length": hop_length,
    },
).to(device)

MEL_SCALE_FN = torchaudio.transforms.MelScale(
    n_mels=80,
    n_stft=257,
    sample_rate=SAMPLING_RATE,
).to(device)

delta_fn = torchaudio.transforms.ComputeDeltas(
    win_length=400,
    mode="replicate",
)


def get_frontend(
    frontends: List[str],
) -> Union[torchaudio.transforms.MFCC, torchaudio.transforms.LFCC, Callable]:
    if "mfcc" in frontends:
        return prepare_mfcc_double_delta
    elif "lfcc" in frontends:
        return prepare_lfcc_double_delta
    raise ValueError(f"{frontends} frontend is not supported!")


def prepare_lfcc_double_delta(input):
    if input.ndim < 4:
        input = input.unsqueeze(1)  # (bs, 1, n_samples)
    x = LFCC_FN(input)
    delta = delta_fn(x)
    double_delta = delta_fn(delta)
    x = torch.cat((x, delta, double_delta), 2)  # -> [bs, 1, 128 * 3, frames]
    return x[:, :, :, :3000]  # (bs, 1, n_lfcc * 3, frames)


def prepare_mfcc_double_delta(input):
    if input.ndim < 4:
        input = input.unsqueeze(1)  # (bs, 1, n_samples)
    x = MFCC_FN(input)
    delta = delta_fn(x)
    double_delta = delta_fn(delta)
    x = torch.cat((x, delta, double_delta), 2)  # -> [bs, 1, 128 * 3, frames]
    return x[:, :, :, :3000]  # (bs, 1, n_mfcc * 3, frames)
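A quick shape check for these frontends; the 4-second clip is arbitrary, and the expected shape follows from n_mfcc=128 (tripled by the deltas) and hop_length=160:

import torch
from src.frontends import get_frontend, device

waveform = torch.rand(2, 64_000).to(device)  # (bs, n_samples), random audio
frontend_fn = get_frontend(["mfcc"])
features = frontend_fn(waveform)
print(features.shape)  # -> torch.Size([2, 1, 384, 401])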
src/metrics.py
ADDED
@@ -0,0 +1,15 @@
from typing import Tuple

import numpy as np
from scipy.interpolate import interp1d
from scipy.optimize import brentq
from sklearn.metrics import roc_curve


def calculate_eer(y, y_score) -> Tuple[float, float, np.ndarray, np.ndarray]:
    fpr, tpr, thresholds = roc_curve(y, -y_score)

    eer = brentq(lambda x: 1.0 - x - interp1d(fpr, tpr)(x), 0.0, 1.0)
    thresh = interp1d(fpr, thresholds)(eer)
    return thresh, eer, fpr, tpr
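A toy invocation with invented scores; because calculate_eer negates y_score internally, lower raw scores push a sample toward the positive class (label 1):

import numpy as np
from src.metrics import calculate_eer

labels = np.array([0, 1, 1, 0, 1, 0])
scores = np.array([0.05, 0.1, 0.2, 0.5, 0.6, 0.9])
thresh, eer, fpr, tpr = calculate_eer(labels, scores)
print(f"EER ~ {eer:.2f}")  # roughly 0.33 on this toy data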
src/models/__init__.py
ADDED
File without changes
src/models/assets/mel_filters.npz
ADDED
Binary file (2.05 kB).
src/models/assets/tiny_enc.en.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:206cca585e8ee06b813f958f72c548aebd489f125ef8949ad437f9fcc86e8cda
size 32853468
src/models/lcnn.py
ADDED
@@ -0,0 +1,247 @@
"""
This code is a modified version of the LCNN baseline
from the ASVspoof2021 challenge - https://github.com/asvspoof-challenge/2021/blob/main/LA/Baseline-LFCC-LCNN/project/baseline_LA/model.py
"""
import sys

import torch
import torch.nn as torch_nn

from src import frontends


NUM_COEFFICIENTS = 384


# For blstm
class BLSTMLayer(torch_nn.Module):
    """ Wrapper over a bi-directional LSTM
    Input tensor: (batchsize=1, length, dim_in)
    Output tensor: (batchsize=1, length, dim_out)
    We want to keep the length the same
    """
    def __init__(self, input_dim, output_dim):
        super().__init__()
        if output_dim % 2 != 0:
            print("Output_dim of BLSTMLayer is {:d}".format(output_dim))
            print("BLSTMLayer expects a layer size of even number")
            sys.exit(1)
        # bi-directional LSTM
        self.l_blstm = torch_nn.LSTM(
            input_dim,
            output_dim // 2,
            bidirectional=True
        )

    def forward(self, x):
        # permute to (length, batchsize=1, dim)
        blstm_data, _ = self.l_blstm(x.permute(1, 0, 2))
        # permute it back to (batchsize=1, length, dim)
        return blstm_data.permute(1, 0, 2)


class MaxFeatureMap2D(torch_nn.Module):
    """ Max feature map (along 2D)

    MaxFeatureMap2D(max_dim=1)

    l_conv2d = MaxFeatureMap2D(1)
    data_in = torch.rand([1, 4, 5, 5])
    data_out = l_conv2d(data_in)


    Input:
    ------
    data_in: tensor of shape (batch, channel, ...)

    Output:
    -------
    data_out: tensor of shape (batch, channel//2, ...)

    Note
    ----
    By default, Max-feature-map is on channel dimension,
    and maxout is used on (channel ...)
    """
    def __init__(self, max_dim=1):
        super().__init__()
        self.max_dim = max_dim

    def forward(self, inputs):
        # suppose inputs (batchsize, channel, length, dim)

        shape = list(inputs.size())

        if self.max_dim >= len(shape):
            print("MaxFeatureMap: maximize on %d dim" % (self.max_dim))
            print("But input has %d dimensions" % (len(shape)))
            sys.exit(1)
        if shape[self.max_dim] // 2 * 2 != shape[self.max_dim]:
            print("MaxFeatureMap: maximize on %d dim" % (self.max_dim))
            print("But this dimension has an odd number of data")
            sys.exit(1)
        shape[self.max_dim] = shape[self.max_dim] // 2
        shape.insert(self.max_dim, 2)

        # view to (batchsize, 2, channel//2, ...)
        # maximize on the 2nd dim
        m, i = inputs.view(*shape).max(self.max_dim)
        return m


##############
## FOR MODEL
##############

class LCNN(torch_nn.Module):
    """ Model definition
    """
    def __init__(self, **kwargs):
        super().__init__()
        input_channels = kwargs.get("input_channels", 1)
        num_coefficients = kwargs.get("num_coefficients", NUM_COEFFICIENTS)

        # Working sampling rate
        self.num_coefficients = num_coefficients

        # dimension of embedding vectors
        # here, the embedding is just the activation before sigmoid()
        self.v_emd_dim = 1

        # it can handle models with multiple front-end configurations
        # by default, only a single front-end

        self.m_transform = torch_nn.Sequential(
            torch_nn.Conv2d(input_channels, 64, (5, 5), 1, padding=(2, 2)),
            MaxFeatureMap2D(),
            torch.nn.MaxPool2d((2, 2), (2, 2)),

            torch_nn.Conv2d(32, 64, (1, 1), 1, padding=(0, 0)),
            MaxFeatureMap2D(),
            torch_nn.BatchNorm2d(32, affine=False),
            torch_nn.Conv2d(32, 96, (3, 3), 1, padding=(1, 1)),
            MaxFeatureMap2D(),

            torch.nn.MaxPool2d((2, 2), (2, 2)),
            torch_nn.BatchNorm2d(48, affine=False),

            torch_nn.Conv2d(48, 96, (1, 1), 1, padding=(0, 0)),
            MaxFeatureMap2D(),
            torch_nn.BatchNorm2d(48, affine=False),
            torch_nn.Conv2d(48, 128, (3, 3), 1, padding=(1, 1)),
            MaxFeatureMap2D(),

            torch.nn.MaxPool2d((2, 2), (2, 2)),

            torch_nn.Conv2d(64, 128, (1, 1), 1, padding=(0, 0)),
            MaxFeatureMap2D(),
            torch_nn.BatchNorm2d(64, affine=False),
            torch_nn.Conv2d(64, 64, (3, 3), 1, padding=(1, 1)),
            MaxFeatureMap2D(),
            torch_nn.BatchNorm2d(32, affine=False),

            torch_nn.Conv2d(32, 64, (1, 1), 1, padding=(0, 0)),
            MaxFeatureMap2D(),
            torch_nn.BatchNorm2d(32, affine=False),
            torch_nn.Conv2d(32, 64, (3, 3), 1, padding=(1, 1)),
            MaxFeatureMap2D(),
            torch_nn.MaxPool2d((2, 2), (2, 2)),

            torch_nn.Dropout(0.7)
        )

        self.m_before_pooling = torch_nn.Sequential(
            BLSTMLayer((self.num_coefficients // 16) * 32, (self.num_coefficients // 16) * 32),
            BLSTMLayer((self.num_coefficients // 16) * 32, (self.num_coefficients // 16) * 32)
        )

        self.m_output_act = torch_nn.Linear((self.num_coefficients // 16) * 32, self.v_emd_dim)

    def _compute_embedding(self, x):
        """ definition of forward method
        Assume x (batchsize, length, dim)
        Output x (batchsize * number_filter, output_dim)
        """
        # resample if necessary
        # x = self.m_resampler(x.squeeze(-1)).unsqueeze(-1)

        # number of sub models
        batch_size = x.shape[0]

        # buffer to store output scores from sub-models
        output_emb = torch.zeros(
            [batch_size, self.v_emd_dim],
            device=x.device,
            dtype=x.dtype
        )

        # compute scores for each sub-model
        idx = 0

        # compute scores
        # 1. unsqueeze to (batch, 1, frame_length, fft_bin)
        # 2. compute hidden features
        x = x.permute(0, 1, 3, 2)
        hidden_features = self.m_transform(x)

        # 3. (batch, channel, frame//N, feat_dim//N) ->
        #    (batch, frame//N, channel * feat_dim//N)
        #    where N is caused by conv with stride
        hidden_features = hidden_features.permute(0, 2, 1, 3).contiguous()
        frame_num = hidden_features.shape[1]

        hidden_features = hidden_features.view(batch_size, frame_num, -1)
        # 4. pass through LSTM then sum
        hidden_features_lstm = self.m_before_pooling(hidden_features)

        # 5. pass through the output layer
        tmp_emb = self.m_output_act((hidden_features_lstm + hidden_features).mean(1))
        output_emb[idx * batch_size: (idx + 1) * batch_size] = tmp_emb

        return output_emb

    def _compute_score(self, feature_vec):
        # feature_vec is [batch * submodel, 1]
        return torch.sigmoid(feature_vec).squeeze(1)

    def forward(self, x):
        feature_vec = self._compute_embedding(x)
        return feature_vec


class FrontendLCNN(LCNN):
    """ Model definition
    """
    def __init__(self, device: str = "cuda", **kwargs):
        super().__init__(**kwargs)

        self.device = device

        frontend_name = kwargs.get("frontend_algorithm", [])
        self.frontend = frontends.get_frontend(frontend_name)
        print(f"Using {frontend_name} frontend")

    def _compute_frontend(self, x):
        frontend = self.frontend(x)
        if frontend.ndim < 4:
            return frontend.unsqueeze(1)  # (bs, 1, n_lfcc, frames)
        return frontend  # (bs, n, n_lfcc, frames)

    def forward(self, x):
        x = self._compute_frontend(x)
        feature_vec = self._compute_embedding(x)

        return feature_vec


if __name__ == "__main__":

    device = "cuda"
    print("Definition of model")
    model = FrontendLCNN(input_channels=2, num_coefficients=80, device=device, frontend_algorithm=["mel_spec"])
    model = model.to(device)
    batch_size = 12
    mock_input = torch.rand((batch_size, 64_600,), device=device)
    output = model(mock_input)
    print(output.shape)
src/models/meso_net.py
ADDED
@@ -0,0 +1,146 @@
"""
This code is a modified version of the MesoNet DeepFake detection solution
from the FakeAVCeleb repository - https://github.com/DASH-Lab/FakeAVCeleb/blob/main/models/MesoNet.py.
"""
import torch
import torch.nn as nn

from src import frontends


class MesoInception4(nn.Module):
    """
    Pytorch implementation of MesoInception4
    Author: Honggu Liu
    Date: July 7, 2019
    """
    def __init__(self, num_classes=1, **kwargs):
        super().__init__()

        self.fc1_dim = kwargs.get("fc1_dim", 1024)
        input_channels = kwargs.get("input_channels", 3)
        self.num_classes = num_classes

        # InceptionLayer1
        self.Incption1_conv1 = nn.Conv2d(input_channels, 1, 1, padding=0, bias=False)
        self.Incption1_conv2_1 = nn.Conv2d(input_channels, 4, 1, padding=0, bias=False)
        self.Incption1_conv2_2 = nn.Conv2d(4, 4, 3, padding=1, bias=False)
        self.Incption1_conv3_1 = nn.Conv2d(input_channels, 4, 1, padding=0, bias=False)
        self.Incption1_conv3_2 = nn.Conv2d(4, 4, 3, padding=2, dilation=2, bias=False)
        self.Incption1_conv4_1 = nn.Conv2d(input_channels, 2, 1, padding=0, bias=False)
        self.Incption1_conv4_2 = nn.Conv2d(2, 2, 3, padding=3, dilation=3, bias=False)
        self.Incption1_bn = nn.BatchNorm2d(11)

        # InceptionLayer2
        self.Incption2_conv1 = nn.Conv2d(11, 2, 1, padding=0, bias=False)
        self.Incption2_conv2_1 = nn.Conv2d(11, 4, 1, padding=0, bias=False)
        self.Incption2_conv2_2 = nn.Conv2d(4, 4, 3, padding=1, bias=False)
        self.Incption2_conv3_1 = nn.Conv2d(11, 4, 1, padding=0, bias=False)
        self.Incption2_conv3_2 = nn.Conv2d(4, 4, 3, padding=2, dilation=2, bias=False)
        self.Incption2_conv4_1 = nn.Conv2d(11, 2, 1, padding=0, bias=False)
        self.Incption2_conv4_2 = nn.Conv2d(2, 2, 3, padding=3, dilation=3, bias=False)
        self.Incption2_bn = nn.BatchNorm2d(12)

        # Normal Layer
        self.conv1 = nn.Conv2d(12, 16, 5, padding=2, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.leakyrelu = nn.LeakyReLU(0.1)
        self.bn1 = nn.BatchNorm2d(16)
        self.maxpooling1 = nn.MaxPool2d(kernel_size=(2, 2))

        self.conv2 = nn.Conv2d(16, 16, 5, padding=2, bias=False)
        self.maxpooling2 = nn.MaxPool2d(kernel_size=(4, 4))

        self.dropout = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(self.fc1_dim, 16)
        self.fc2 = nn.Linear(16, num_classes)

    # InceptionLayer
    def InceptionLayer1(self, input):
        x1 = self.Incption1_conv1(input)
        x2 = self.Incption1_conv2_1(input)
        x2 = self.Incption1_conv2_2(x2)
        x3 = self.Incption1_conv3_1(input)
        x3 = self.Incption1_conv3_2(x3)
        x4 = self.Incption1_conv4_1(input)
        x4 = self.Incption1_conv4_2(x4)
        y = torch.cat((x1, x2, x3, x4), 1)
        y = self.Incption1_bn(y)
        y = self.maxpooling1(y)

        return y

    def InceptionLayer2(self, input):
        x1 = self.Incption2_conv1(input)
        x2 = self.Incption2_conv2_1(input)
        x2 = self.Incption2_conv2_2(x2)
        x3 = self.Incption2_conv3_1(input)
        x3 = self.Incption2_conv3_2(x3)
        x4 = self.Incption2_conv4_1(input)
        x4 = self.Incption2_conv4_2(x4)
        y = torch.cat((x1, x2, x3, x4), 1)
        y = self.Incption2_bn(y)
        y = self.maxpooling1(y)

        return y

    def forward(self, input):
        x = self._compute_embedding(input)
        return x

    def _compute_embedding(self, input):
        x = self.InceptionLayer1(input)  # (Batch, 11, 128, 128)
        x = self.InceptionLayer2(x)      # (Batch, 12, 64, 64)

        x = self.conv1(x)  # (Batch, 16, 64, 64)
        x = self.relu(x)
        x = self.bn1(x)
        x = self.maxpooling1(x)  # (Batch, 16, 32, 32)

        x = self.conv2(x)  # (Batch, 16, 32, 32)
        x = self.relu(x)
        x = self.bn1(x)
        x = self.maxpooling2(x)  # (Batch, 16, 8, 8)

        x = x.view(x.size(0), -1)  # (Batch, 16*8*8)
        x = self.dropout(x)

        x = nn.AdaptiveAvgPool1d(self.fc1_dim)(x)
        x = self.fc1(x)  # (Batch, 16)
        x = self.leakyrelu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x


class FrontendMesoInception4(MesoInception4):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.device = kwargs['device']

        frontend_name = kwargs.get("frontend_algorithm", [])
        self.frontend = frontends.get_frontend(frontend_name)
        print(f"Using {frontend_name} frontend")

    def forward(self, x):
        x = self.frontend(x)
        x = self._compute_embedding(x)
        return x


if __name__ == "__main__":
    model = FrontendMesoInception4(
        input_channels=2,
        fc1_dim=1024,
        device='cuda',
        frontend_algorithm="lfcc"
    )

    def count_parameters(model) -> int:
        pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        return pytorch_total_params
    print(count_parameters(model))
src/models/models.py
ADDED
@@ -0,0 +1,73 @@
from typing import Dict

from src.models import (
    lcnn,
    specrnet,
    whisper_specrnet,
    rawnet3,
    whisper_lcnn,
    meso_net,
    whisper_meso_net
)


def get_model(model_name: str, config: Dict, device: str):
    if model_name == "rawnet3":
        return rawnet3.prepare_model()
    elif model_name == "lcnn":
        return lcnn.FrontendLCNN(device=device, **config)
    elif model_name == "specrnet":
        return specrnet.FrontendSpecRNet(
            device=device,
            **config,
        )
    elif model_name == "mesonet":
        return meso_net.FrontendMesoInception4(
            input_channels=config.get("input_channels", 1),
            fc1_dim=config.get("fc1_dim", 1024),
            frontend_algorithm=config.get("frontend_algorithm", "lfcc"),
            device=device,
        )
    elif model_name == "whisper_lcnn":
        return whisper_lcnn.WhisperLCNN(
            input_channels=config.get("input_channels", 1),
            freeze_encoder=config.get("freeze_encoder", False),
            device=device,
        )
    elif model_name == "whisper_specrnet":
        return whisper_specrnet.WhisperSpecRNet(
            input_channels=config.get("input_channels", 1),
            freeze_encoder=config.get("freeze_encoder", False),
            device=device,
        )
    elif model_name == "whisper_mesonet":
        return whisper_meso_net.WhisperMesoNet(
            input_channels=config.get("input_channels", 1),
            freeze_encoder=config.get("freeze_encoder", True),
            fc1_dim=config.get("fc1_dim", 1024),
            device=device,
        )
    elif model_name == "whisper_frontend_lcnn":
        return whisper_lcnn.WhisperMultiFrontLCNN(
            input_channels=config.get("input_channels", 2),
            freeze_encoder=config.get("freeze_encoder", False),
            frontend_algorithm=config.get("frontend_algorithm", "lfcc"),
            device=device,
        )
    elif model_name == "whisper_frontend_specrnet":
        return whisper_specrnet.WhisperMultiFrontSpecRNet(
            input_channels=config.get("input_channels", 2),
            freeze_encoder=config.get("freeze_encoder", False),
            frontend_algorithm=config.get("frontend_algorithm", "lfcc"),
            device=device,
        )
    elif model_name == "whisper_frontend_mesonet":
        return whisper_meso_net.WhisperMultiFrontMesoNet(
            input_channels=config.get("input_channels", 2),
            fc1_dim=config.get("fc1_dim", 1024),
            freeze_encoder=config.get("freeze_encoder", True),
            frontend_algorithm=config.get("frontend_algorithm", "lfcc"),
            device=device,
        )
    else:
        raise ValueError(f"Model '{model_name}' not supported")
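A minimal sketch of the factory; the config keys mirror the YAML files under configs/training/, and the values here are illustrative:

from src.models.models import get_model

config = {"input_channels": 1, "fc1_dim": 1024, "frontend_algorithm": ["lfcc"]}
model = get_model("mesonet", config=config, device="cpu")
print(type(model).__name__)  # -> FrontendMesoInception4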
src/models/rawnet3.py
ADDED
@@ -0,0 +1,323 @@
"""
This file contains implementation of RawNet3 architecture.
The original implementation can be found here: https://github.com/Jungjee/RawNet/tree/master/python/RawNet3
"""
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from asteroid_filterbanks import Encoder, ParamSincFB  # pip install asteroid_filterbanks


class RawNet3(nn.Module):
    def __init__(self, block, model_scale, context, summed, C=1024, **kwargs):
        super().__init__()

        nOut = kwargs["nOut"]

        self.context = context
        self.encoder_type = kwargs["encoder_type"]
        self.log_sinc = kwargs["log_sinc"]
        self.norm_sinc = kwargs["norm_sinc"]
        self.out_bn = kwargs["out_bn"]
        self.summed = summed

        self.preprocess = nn.Sequential(
            PreEmphasis(), nn.InstanceNorm1d(1, eps=1e-4, affine=True)
        )
        self.conv1 = Encoder(
            ParamSincFB(
                C // 4,
                251,
                stride=kwargs["sinc_stride"],
            )
        )
        self.relu = nn.ReLU()
        self.bn1 = nn.BatchNorm1d(C // 4)

        self.layer1 = block(
            C // 4, C, kernel_size=3, dilation=2, scale=model_scale, pool=5
        )
        self.layer2 = block(
            C, C, kernel_size=3, dilation=3, scale=model_scale, pool=3
        )
        self.layer3 = block(C, C, kernel_size=3, dilation=4, scale=model_scale)
        self.layer4 = nn.Conv1d(3 * C, 1536, kernel_size=1)

        if self.context:
            attn_input = 1536 * 3
        else:
            attn_input = 1536
        print("self.encoder_type", self.encoder_type)
        if self.encoder_type == "ECA":
            attn_output = 1536
        elif self.encoder_type == "ASP":
            attn_output = 1
        else:
            raise ValueError("Undefined encoder")

        self.attention = nn.Sequential(
            nn.Conv1d(attn_input, 128, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Conv1d(128, attn_output, kernel_size=1),
            nn.Softmax(dim=2),
        )

        self.bn5 = nn.BatchNorm1d(3072)

        self.fc6 = nn.Linear(3072, nOut)
        self.bn6 = nn.BatchNorm1d(nOut)

        self.mp3 = nn.MaxPool1d(3)

    def forward(self, x):
        """
        :param x: input mini-batch (bs, samp)
        """

        with torch.cuda.amp.autocast(enabled=False):
            x = self.preprocess(x)
            x = torch.abs(self.conv1(x))
            if self.log_sinc:
                x = torch.log(x + 1e-6)
            if self.norm_sinc == "mean":
                x = x - torch.mean(x, dim=-1, keepdim=True)
            elif self.norm_sinc == "mean_std":
                m = torch.mean(x, dim=-1, keepdim=True)
                s = torch.std(x, dim=-1, keepdim=True)
                s[s < 0.001] = 0.001
                x = (x - m) / s

        if self.summed:
            x1 = self.layer1(x)
            x2 = self.layer2(x1)
            x3 = self.layer3(self.mp3(x1) + x2)
        else:
            x1 = self.layer1(x)
            x2 = self.layer2(x1)
            x3 = self.layer3(x2)

        x = self.layer4(torch.cat((self.mp3(x1), x2, x3), dim=1))
        x = self.relu(x)

        t = x.size()[-1]

        if self.context:
            global_x = torch.cat(
                (
                    x,
                    torch.mean(x, dim=2, keepdim=True).repeat(1, 1, t),
                    torch.sqrt(
                        torch.var(x, dim=2, keepdim=True).clamp(
                            min=1e-4, max=1e4
                        )
                    ).repeat(1, 1, t),
                ),
                dim=1,
            )
        else:
            global_x = x

        w = self.attention(global_x)

        mu = torch.sum(x * w, dim=2)
        sg = torch.sqrt(
            (torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-4, max=1e4)
        )

        x = torch.cat((mu, sg), 1)

        x = self.bn5(x)

        x = self.fc6(x)

        if self.out_bn:
            x = self.bn6(x)

        return x


class PreEmphasis(torch.nn.Module):
    def __init__(self, coef: float = 0.97) -> None:
        super().__init__()
        self.coef = coef
        # make kernel
        # In pytorch, the convolution operation uses cross-correlation. So, filter is flipped.
        self.register_buffer(
            "flipped_filter",
            torch.FloatTensor([-self.coef, 1.0]).unsqueeze(0).unsqueeze(0),
        )

    def forward(self, input: torch.tensor) -> torch.tensor:
        assert (
            len(input.size()) == 2
        ), "The number of dimensions of input tensor must be 2!"
        # reflect padding to match lengths of in/out
        input = input.unsqueeze(1)
        input = F.pad(input, (1, 0), "reflect")
        return F.conv1d(input, self.flipped_filter)


class AFMS(nn.Module):
    """
    Alpha-Feature map scaling, added to the output of each residual block[1,2].

    Reference:
    [1] RawNet2 : https://www.isca-speech.org/archive/Interspeech_2020/pdfs/1011.pdf
    [2] AFMS : https://www.koreascience.or.kr/article/JAKO202029757857763.page
    """

    def __init__(self, nb_dim: int) -> None:
        super().__init__()
        self.alpha = nn.Parameter(torch.ones((nb_dim, 1)))
        self.fc = nn.Linear(nb_dim, nb_dim)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        y = F.adaptive_avg_pool1d(x, 1).view(x.size(0), -1)
        y = self.sig(self.fc(y)).view(x.size(0), x.size(1), -1)

        x = x + self.alpha
        x = x * y
        return x


class Bottle2neck(nn.Module):
    def __init__(
        self,
        inplanes,
        planes,
        kernel_size=None,
        dilation=None,
        scale=4,
        pool=False,
    ):

        super().__init__()

        width = int(math.floor(planes / scale))

        self.conv1 = nn.Conv1d(inplanes, width * scale, kernel_size=1)
        self.bn1 = nn.BatchNorm1d(width * scale)

        self.nums = scale - 1

        convs = []
        bns = []

        num_pad = math.floor(kernel_size / 2) * dilation

        for i in range(self.nums):
            convs.append(
                nn.Conv1d(
                    width,
                    width,
                    kernel_size=kernel_size,
                    dilation=dilation,
                    padding=num_pad,
                )
            )
            bns.append(nn.BatchNorm1d(width))

        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)

        self.conv3 = nn.Conv1d(width * scale, planes, kernel_size=1)
        self.bn3 = nn.BatchNorm1d(planes)

        self.relu = nn.ReLU()

        self.width = width

        self.mp = nn.MaxPool1d(pool) if pool else False
        self.afms = AFMS(planes)

        if inplanes != planes:  # if change in number of filters
            self.residual = nn.Sequential(
                nn.Conv1d(inplanes, planes, kernel_size=1, stride=1, bias=False)
            )
        else:
            self.residual = nn.Identity()

    def forward(self, x):
        residual = self.residual(x)

        out = self.conv1(x)
        out = self.relu(out)
        out = self.bn1(out)

        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = sp + spx[i]
            sp = self.convs[i](sp)
            sp = self.relu(sp)
            sp = self.bns[i](sp)
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)

        out = torch.cat((out, spx[self.nums]), 1)

        out = self.conv3(out)
        out = self.relu(out)
        out = self.bn3(out)

        out += residual
        if self.mp:
            out = self.mp(out)
        out = self.afms(out)

        return out


def prepare_model():
    model = RawNet3(
        Bottle2neck,
        model_scale=8,
        context=True,
        summed=True,
        encoder_type="ECA",
        nOut=1,  # number of slices
        out_bn=False,
        sinc_stride=10,
        log_sinc=True,
        norm_sinc="mean",
        grad_mult=1,
    )
    return model


if __name__ == "__main__":
    model = RawNet3(
        Bottle2neck,
        model_scale=8,
        context=True,
        summed=True,
        encoder_type="ECA",
        nOut=1,  # number of slices
        out_bn=False,
        sinc_stride=10,
        log_sinc=True,
        norm_sinc="mean",
        grad_mult=1,
    )
    gpu = False

    model.eval()
    print("RawNet3 initialised & weights loaded!")

    if torch.cuda.is_available():
        print("Cuda available, conducting inference on GPU")
        model = model.to("cuda")
        gpu = True

    audios = torch.rand(32, 64_600)
    if gpu:
        # move the mock batch to the same device as the model
        audios = audios.to("cuda")

    out = model(audios)
    print(out.shape)
src/models/specrnet.py
ADDED
@@ -0,0 +1,226 @@
1 |
+
"""
|
2 |
+
This file contains implementation of SpecRNet architecture.
|
3 |
+
We base our codebase on the implementation of RawNet2 by Hemlata Tak (tak@eurecom.fr).
|
4 |
+
It is available here: https://github.com/asvspoof-challenge/2021/blob/main/LA/Baseline-RawNet2/model.py
|
5 |
+
"""
|
6 |
+
from typing import Dict
|
7 |
+
|
8 |
+
import torch.nn as nn
|
9 |
+
|
10 |
+
from src import frontends
|
11 |
+
|
12 |
+
|
13 |
+
def get_config(input_channels: int) -> Dict:
|
14 |
+
return {
|
15 |
+
"filts": [input_channels, [input_channels, 20], [20, 64], [64, 64]],
|
16 |
+
"nb_fc_node": 64,
|
17 |
+
"gru_node": 64,
|
18 |
+
"nb_gru_layer": 2,
|
19 |
+
"nb_classes": 1,
|
20 |
+
}
|
21 |
+
|
22 |
+
|
23 |
+
class Residual_block2D(nn.Module):
|
24 |
+
def __init__(self, nb_filts, first=False):
|
25 |
+
super().__init__()
|
26 |
+
self.first = first
|
27 |
+
|
28 |
+
if not self.first:
|
29 |
+
self.bn1 = nn.BatchNorm2d(num_features=nb_filts[0])
|
30 |
+
|
31 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.3)
|
32 |
+
|
33 |
+
self.conv1 = nn.Conv2d(
|
34 |
+
in_channels=nb_filts[0],
|
35 |
+
out_channels=nb_filts[1],
|
36 |
+
kernel_size=3,
|
37 |
+
padding=1,
|
38 |
+
stride=1,
|
39 |
+
)
|
40 |
+
|
41 |
+
self.bn2 = nn.BatchNorm2d(num_features=nb_filts[1])
|
42 |
+
self.conv2 = nn.Conv2d(
|
43 |
+
in_channels=nb_filts[1],
|
44 |
+
out_channels=nb_filts[1],
|
45 |
+
padding=1,
|
46 |
+
kernel_size=3,
|
47 |
+
stride=1,
|
48 |
+
)
|
49 |
+
|
50 |
+
if nb_filts[0] != nb_filts[1]:
|
51 |
+
self.downsample = True
|
52 |
+
self.conv_downsample = nn.Conv2d(
|
53 |
+
in_channels=nb_filts[0],
|
54 |
+
out_channels=nb_filts[1],
|
55 |
+
padding=0,
|
56 |
+
kernel_size=1,
|
57 |
+
stride=1,
|
58 |
+
)
|
59 |
+
|
60 |
+
else:
|
61 |
+
self.downsample = False
|
62 |
+
self.mp = nn.MaxPool2d(2)
|
63 |
+
|
64 |
+
def forward(self, x):
|
65 |
+
identity = x
|
66 |
+
if not self.first:
|
67 |
+
out = self.bn1(x)
|
68 |
+
out = self.lrelu(out)
|
69 |
+
else:
|
70 |
+
out = x
|
71 |
+
|
72 |
+
out = self.conv1(x)
|
73 |
+
out = self.bn2(out)
|
74 |
+
out = self.lrelu(out)
|
75 |
+
out = self.conv2(out)
|
76 |
+
|
77 |
+
if self.downsample:
|
78 |
+
identity = self.conv_downsample(identity)
|
79 |
+
|
80 |
+
out += identity
|
81 |
+
out = self.mp(out)
|
82 |
+
return out
|
83 |
+
|
84 |
+
|
85 |
+
class SpecRNet(nn.Module):
|
86 |
+
def __init__(self, input_channels, **kwargs):
|
87 |
+
super().__init__()
|
88 |
+
config = get_config(input_channels=input_channels)
|
89 |
+
|
90 |
+
self.device = kwargs.get("device", "cuda")
|
91 |
+
|
92 |
+
self.first_bn = nn.BatchNorm2d(num_features=config["filts"][0])
|
93 |
+
self.selu = nn.SELU(inplace=True)
|
94 |
+
self.block0 = nn.Sequential(
|
95 |
+
Residual_block2D(nb_filts=config["filts"][1], first=True)
|
96 |
+
)
|
97 |
+
self.block2 = nn.Sequential(Residual_block2D(nb_filts=config["filts"][2]))
|
98 |
+
config["filts"][2][0] = config["filts"][2][1]
|
99 |
+
self.block4 = nn.Sequential(Residual_block2D(nb_filts=config["filts"][2]))
|
100 |
+
self.avgpool = nn.AdaptiveAvgPool2d(1)
|
101 |
+
|
102 |
+
self.fc_attention0 = self._make_attention_fc(
|
103 |
+
in_features=config["filts"][1][-1], l_out_features=config["filts"][1][-1]
|
104 |
+
)
|
105 |
+
self.fc_attention2 = self._make_attention_fc(
|
106 |
+
in_features=config["filts"][2][-1], l_out_features=config["filts"][2][-1]
|
107 |
+
)
|
108 |
        self.fc_attention4 = self._make_attention_fc(
            in_features=config["filts"][2][-1], l_out_features=config["filts"][2][-1]
        )

        self.bn_before_gru = nn.BatchNorm2d(num_features=config["filts"][2][-1])
        self.gru = nn.GRU(
            input_size=config["filts"][2][-1],
            hidden_size=config["gru_node"],
            num_layers=config["nb_gru_layer"],
            batch_first=True,
            bidirectional=True,
        )

        self.fc1_gru = nn.Linear(
            in_features=config["gru_node"] * 2, out_features=config["nb_fc_node"] * 2
        )

        self.fc2_gru = nn.Linear(
            in_features=config["nb_fc_node"] * 2,
            out_features=config["nb_classes"],
            bias=True,
        )

        self.sig = nn.Sigmoid()

    def _compute_embedding(self, x):
        x = self.first_bn(x)
        x = self.selu(x)

        x0 = self.block0(x)
        y0 = self.avgpool(x0).view(x0.size(0), -1)
        y0 = self.fc_attention0(y0)
        y0 = self.sig(y0).view(y0.size(0), y0.size(1), -1)
        y0 = y0.unsqueeze(-1)
        x = x0 * y0 + y0

        x = nn.MaxPool2d(2)(x)

        x2 = self.block2(x)
        y2 = self.avgpool(x2).view(x2.size(0), -1)
        y2 = self.fc_attention2(y2)
        y2 = self.sig(y2).view(y2.size(0), y2.size(1), -1)
        y2 = y2.unsqueeze(-1)
        x = x2 * y2 + y2

        x = nn.MaxPool2d(2)(x)

        x4 = self.block4(x)
        y4 = self.avgpool(x4).view(x4.size(0), -1)
        y4 = self.fc_attention4(y4)
        y4 = self.sig(y4).view(y4.size(0), y4.size(1), -1)
        y4 = y4.unsqueeze(-1)
        x = x4 * y4 + y4

        x = nn.MaxPool2d(2)(x)

        x = self.bn_before_gru(x)
        x = self.selu(x)
        x = nn.AdaptiveAvgPool2d((1, None))(x)
        x = x.squeeze(-2)
        x = x.permute(0, 2, 1)
        self.gru.flatten_parameters()
        x, _ = self.gru(x)
        x = x[:, -1, :]
        x = self.fc1_gru(x)
        x = self.fc2_gru(x)
        return x

    def forward(self, x):
        x = self._compute_embedding(x)
        return x

    def _make_attention_fc(self, in_features, l_out_features):
        l_fc = []
        l_fc.append(nn.Linear(in_features=in_features, out_features=l_out_features))
        return nn.Sequential(*l_fc)


class FrontendSpecRNet(SpecRNet):
    def __init__(self, input_channels, **kwargs):
        super().__init__(input_channels, **kwargs)

        self.device = kwargs['device']

        frontend_name = kwargs.get("frontend_algorithm", [])
        self.frontend = frontends.get_frontend(frontend_name)
        print(f"Using {frontend_name} frontend")

    def _compute_frontend(self, x):
        frontend = self.frontend(x)
        if frontend.ndim < 4:
            return frontend.unsqueeze(1)  # (bs, 1, n_lfcc, frames)
        return frontend  # (bs, n, n_lfcc, frames)

    def forward(self, x):
        x = self._compute_frontend(x)
        x = self._compute_embedding(x)
        return x


if __name__ == "__main__":
    print("Definition of model")
    device = "cuda"

    input_channels = 1
    config = {
        "filts": [input_channels, [input_channels, 20], [20, 64], [64, 64]],
        "nb_fc_node": 64,
        "gru_node": 64,
        "nb_gru_layer": 2,
        "nb_classes": 1,
    }

    def count_parameters(model) -> int:
        pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        return pytorch_total_params

    model = FrontendSpecRNet(input_channels=1, device=device, frontend_algorithm=["lfcc"])
    model = model.to(device)
    print(count_parameters(model))
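The __main__ block above only instantiates the model and counts its parameters; a forward pass can be smoke-tested the same way. A minimal sketch, assuming the default LFCC settings in src.frontends and a CPU device (the exact frame count depends on the frontend configuration):

import torch

from src.models.specrnet import FrontendSpecRNet

model = FrontendSpecRNet(input_channels=1, device="cpu", frontend_algorithm=["lfcc"])

# Four seconds of fake 16 kHz audio; _compute_frontend turns each waveform
# into an LFCC map of shape (bs, 1, n_lfcc, frames) before the embedding.
waveform = torch.randn(2, 4 * 16_000)
logits = model(waveform)
print(logits.shape)  # expected: torch.Size([2, 1]), one logit per sample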
src/models/whisper_lcnn.py
ADDED
@@ -0,0 +1,89 @@
import torch

from src.models.whisper_main import ModelDimensions, Whisper, log_mel_spectrogram
from src.models.lcnn import LCNN
from src import frontends
from src.commons import WHISPER_MODEL_WEIGHTS_PATH


class WhisperLCNN(LCNN):
    def __init__(self, input_channels, freeze_encoder, **kwargs):
        super().__init__(input_channels=input_channels, **kwargs)

        self.device = kwargs['device']
        checkpoint = torch.load(WHISPER_MODEL_WEIGHTS_PATH)
        dims = ModelDimensions(**checkpoint["dims"].__dict__)
        model = Whisper(dims)
        model = model.to(self.device)
        model.load_state_dict(checkpoint["model_state_dict"])
        self.whisper_model = model
        if freeze_encoder:
            for param in self.whisper_model.parameters():
                param.requires_grad = False

    def compute_whisper_features(self, x):
        specs = []
        for sample in x:
            specs.append(log_mel_spectrogram(sample))
        x = torch.stack(specs)
        x = self.whisper_model(x)  # (bs, frames, 3 x n_lfcc)

        x = x.permute(0, 2, 1)  # (bs, 3 x n_lfcc, frames)
        x = x.unsqueeze(1)  # (bs, 1, 3 x n_lfcc, frames)
        x = x.repeat(
            (1, 1, 1, 2)
        )  # (bs, 1, 3 x n_lfcc, frames) -> (bs, 1, 3 x n_lfcc, 3000)
        return x

    def forward(self, x):
        # we assume the input is already the expected length (i.e. 30 s)
        x = self.compute_whisper_features(x)
        out = self._compute_embedding(x)
        return out


class WhisperMultiFrontLCNN(WhisperLCNN):
    def __init__(self, input_channels, freeze_encoder, **kwargs):
        super().__init__(input_channels=input_channels, freeze_encoder=freeze_encoder, **kwargs)

        self.frontend = frontends.get_frontend(kwargs['frontend_algorithm'])
        print(f"Using {self.frontend} frontend!")

    def forward(self, x):
        # Frontend computation
        frontend_x = self.frontend(x)
        x = self.compute_whisper_features(x)

        x = torch.cat([x, frontend_x], 1)
        out = self._compute_embedding(x)
        return out


if __name__ == "__main__":
    import numpy as np

    input_channels = 1
    device = "cpu"
    classifier = WhisperLCNN(
        input_channels=input_channels,
        freeze_encoder=True,
        device=device,
    )

    input_channels = 2
    classifier_2 = WhisperMultiFrontLCNN(
        input_channels=input_channels,
        freeze_encoder=True,
        device=device,
        frontend_algorithm="lfcc",
    )
    x = np.random.rand(2, 30 * 16_000).astype(np.float32)
    x = torch.from_numpy(x)

    out = classifier(x)
    print(out.shape)

    out = classifier_2(x)
    print(out.shape)
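The shape bookkeeping in compute_whisper_features is the subtle part: the encoder output is doubled along the time axis so that it lines up with a 3000-frame LFCC map and can be concatenated channel-wise in WhisperMultiFrontLCNN. A standalone sketch of just that tensor flow, assuming the tiny encoder's dimensions (n_audio_ctx = 1500, n_audio_state = 384) and an LFCC frontend that stacks deltas to 384 coefficients:

import torch

# Stand-in for the tiny Whisper encoder output on a 30 s clip.
enc_out = torch.randn(2, 1500, 384)  # (bs, frames, n_audio_state)

x = enc_out.permute(0, 2, 1)  # (2, 384, 1500)
x = x.unsqueeze(1)            # (2, 1, 384, 1500)
x = x.repeat((1, 1, 1, 2))    # (2, 1, 384, 3000), time axis doubled

lfcc = torch.randn(2, 1, 384, 3000)   # stand-in for the LFCC map with deltas
print(torch.cat([x, lfcc], 1).shape)  # torch.Size([2, 2, 384, 3000])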
src/models/whisper_main.py
ADDED
@@ -0,0 +1,323 @@
# Based on https://github.com/openai/whisper/blob/main/whisper/model.py
from dataclasses import dataclass
from functools import lru_cache
import os
from typing import Iterable, Optional, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch import nn


def exact_div(x, y):
    assert x % y == 0
    return x // y


# hard-coded audio hyperparameters
SAMPLE_RATE = 16000
N_FFT = 400
N_MELS = 80
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000: number of samples in a chunk
N_FRAMES = exact_div(
    N_SAMPLES, HOP_LENGTH
)  # 3000: number of frames in a mel spectrogram input


def pad_or_trim(
    array: Union[torch.Tensor, np.ndarray],
    length: int = N_SAMPLES,
    *,
    axis: int = -1,
) -> torch.Tensor:
    """
    Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
    """
    if not torch.is_tensor(array):
        array = torch.from_numpy(array)

    if array.shape[axis] > length:
        array = array.index_select(
            dim=axis, index=torch.arange(length, device=array.device)
        )

    if array.shape[axis] < length:
        # pad by tiling the signal until it reaches the target length
        num_repeats = int(length / array.shape[axis]) + 1
        array = torch.tile(array, (1, num_repeats))[:, :length]
    return array


@lru_cache(maxsize=None)
def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
    """
    Load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
    Allows decoupling the librosa dependency; saved using:

        np.savez_compressed(
            "mel_filters.npz",
            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
        )
    """
    assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
    with np.load(
        os.path.join(os.path.dirname(__file__), "assets/mel_filters.npz")
    ) as f:
        return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)


def log_mel_spectrogram(audio: torch.Tensor, n_mels: int = N_MELS):
    """
    Compute the log-Mel spectrogram of an audio waveform.

    Parameters
    ----------
    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
        The path to audio, or a NumPy array or Tensor containing the audio waveform sampled at 16 kHz

    n_mels: int
        The number of Mel-frequency filters; only 80 is supported

    Returns
    -------
    torch.Tensor, shape = (80, n_frames)
        A Tensor that contains the Mel spectrogram
    """
    window = torch.hann_window(N_FFT).to(audio.device)
    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
    magnitudes = stft[:, :-1].abs() ** 2

    filters = mel_filters(audio.device, n_mels)
    mel_spec = filters @ magnitudes

    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec


@dataclass
class ModelDimensions:
    n_mels: int
    n_audio_ctx: int
    n_audio_state: int
    n_audio_head: int
    n_audio_layer: int
    n_vocab: int
    n_text_ctx: int
    n_text_state: int
    n_text_head: int
    n_text_layer: int


class LayerNorm(nn.LayerNorm):
    def forward(self, x: Tensor) -> Tensor:
        return super().forward(x.float()).type(x.dtype)


class Linear(nn.Linear):
    def forward(self, x: Tensor) -> Tensor:
        return F.linear(
            x,
            self.weight.to(x.dtype),
            None if self.bias is None else self.bias.to(x.dtype),
        )


class Conv1d(nn.Conv1d):
    def _conv_forward(
        self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
    ) -> Tensor:
        return super()._conv_forward(
            x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
        )


def sinusoids(length, channels, max_timescale=10_000):
    """Returns sinusoids for positional embedding"""
    assert channels % 2 == 0
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)


class MultiHeadAttention(nn.Module):
    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        q = self.query(x)

        if kv_cache is None or xa is None or self.key not in kv_cache:
            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
            # otherwise, perform key/value projections for self- or cross-attention as usual.
            k = self.key(x if xa is None else xa)
            v = self.value(x if xa is None else xa)
        else:
            # for cross-attention, calculate keys and values once and reuse in subsequent calls.
            k = kv_cache[self.key]
            v = kv_cache[self.value]

        wv = self.qkv_attention(q, k, v, mask)
        return self.out(wv)

    def qkv_attention(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
    ):
        n_batch, n_ctx, n_state = q.shape
        scale = (n_state // self.n_head) ** -0.25
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        qk = q @ k
        if mask is not None:
            qk = qk + mask[:n_ctx, :n_ctx]

        w = F.softmax(qk.float(), dim=-1).to(q.dtype)
        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)


class ResidualAttentionBlock(nn.Module):
    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = LayerNorm(n_state)

        self.cross_attn = (
            MultiHeadAttention(n_state, n_head) if cross_attention else None
        )
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None

        n_mlp = n_state * 4
        self.mlp = nn.Sequential(
            Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
        )
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
        if self.cross_attn:
            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
        x = x + self.mlp(self.mlp_ln(x))
        return x


class AudioEncoder(nn.Module):
    def __init__(
        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        )
        self.ln_post = LayerNorm(n_state)

    def forward(self, x: Tensor):
        """
        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
            the mel spectrogram of the audio
        """
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)

        assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
        x = (x + self.positional_embedding).to(x.dtype)
        for block in self.blocks:
            x = block(x)

        x = self.ln_post(x)
        return x


class TextDecoder(nn.Module):
    def __init__(
        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()

        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [
                ResidualAttentionBlock(n_state, n_head, cross_attention=True)
                for _ in range(n_layer)
            ]
        )
        self.ln = LayerNorm(n_state)

        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
        """
        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
            the text tokens
        xa : torch.Tensor, shape = (batch_size, n_mels, n_audio_ctx)
            the encoded audio features to be attended on
        """
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        x = (
            self.token_embedding(x)
            + self.positional_embedding[offset : offset + x.shape[-1]]
        )
        x = x.to(xa.dtype)

        for block in self.blocks:
            x = block(x, xa, mask=self.mask, kv_cache=kv_cache)

        x = self.ln(x)
        logits = (
            x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
        ).float()

        return logits


class Whisper(nn.Module):
    # encoder-only variant: the decoder is not needed for feature extraction
    def __init__(self, dims: ModelDimensions):
        super().__init__()
        self.dims = dims
        self.encoder = AudioEncoder(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
            self.dims.n_audio_state,
            self.dims.n_audio_head,
            self.dims.n_audio_layer,
        )

    def forward(self, mel: torch.Tensor):
        return self.encoder(mel)

    @property
    def device(self):
        return next(self.parameters()).device
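Since this stripped-down Whisper keeps only the encoder, it can be exercised in isolation. A minimal sketch with randomly initialized weights, assuming the tiny-encoder dimensions (the n_vocab and n_text_* fields must be filled in but are never used, because no decoder is instantiated):

import torch

from src.models.whisper_main import (
    ModelDimensions,
    Whisper,
    log_mel_spectrogram,
    N_SAMPLES,
)

# Tiny-encoder dimensions; the text-side values are placeholders here.
dims = ModelDimensions(
    n_mels=80, n_audio_ctx=1500, n_audio_state=384,
    n_audio_head=6, n_audio_layer=4,
    n_vocab=51864, n_text_ctx=448, n_text_state=384,
    n_text_head=6, n_text_layer=4,
)
model = Whisper(dims)

audio = torch.randn(N_SAMPLES)       # 30 s of fake 16 kHz audio
mel = log_mel_spectrogram(audio)     # (80, 3000)
features = model(mel.unsqueeze(0))   # (1, 1500, 384)
print(features.shape)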