diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..09d5fd7e15066e91a2e4e6bf06842a151bc2a21b 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,35 +1,5 @@ -*.7z filter=lfs diff=lfs merge=lfs -text -*.arrow filter=lfs diff=lfs merge=lfs -text -*.bin filter=lfs diff=lfs merge=lfs -text -*.bz2 filter=lfs diff=lfs merge=lfs -text -*.ckpt filter=lfs diff=lfs merge=lfs -text -*.ftz filter=lfs diff=lfs merge=lfs -text -*.gz filter=lfs diff=lfs merge=lfs -text -*.h5 filter=lfs diff=lfs merge=lfs -text -*.joblib filter=lfs diff=lfs merge=lfs -text -*.lfs.* filter=lfs diff=lfs merge=lfs -text -*.mlmodel filter=lfs diff=lfs merge=lfs -text -*.model filter=lfs diff=lfs merge=lfs -text -*.msgpack filter=lfs diff=lfs merge=lfs -text -*.npy filter=lfs diff=lfs merge=lfs -text -*.npz filter=lfs diff=lfs merge=lfs -text -*.onnx filter=lfs diff=lfs merge=lfs -text -*.ot filter=lfs diff=lfs merge=lfs -text -*.parquet filter=lfs diff=lfs merge=lfs -text -*.pb filter=lfs diff=lfs merge=lfs -text -*.pickle filter=lfs diff=lfs merge=lfs -text -*.pkl filter=lfs diff=lfs merge=lfs -text *.pt filter=lfs diff=lfs merge=lfs -text +*.jpg filter=lfs diff=lfs merge=lfs -text +*.wav filter=lfs diff=lfs merge=lfs -text +*.mp3 filter=lfs diff=lfs merge=lfs -text *.pth filter=lfs diff=lfs merge=lfs -text -*.rar filter=lfs diff=lfs merge=lfs -text -*.safetensors filter=lfs diff=lfs merge=lfs -text -saved_model/**/* filter=lfs diff=lfs merge=lfs -text -*.tar.* filter=lfs diff=lfs merge=lfs -text -*.tar filter=lfs diff=lfs merge=lfs -text -*.tflite filter=lfs diff=lfs merge=lfs -text -*.tgz filter=lfs diff=lfs merge=lfs -text -*.wasm filter=lfs diff=lfs merge=lfs -text -*.xz filter=lfs diff=lfs merge=lfs -text -*.zip filter=lfs diff=lfs merge=lfs -text -*.zst filter=lfs diff=lfs merge=lfs -text -*tfevents* filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..68bc17f9ff2104a9d7b6777058bb4c343ca72609 --- /dev/null +++ b/.gitignore @@ -0,0 +1,160 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
+# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..cc541dce72d1e0f961ec29cd124d51330572ff81 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) Piotr Kawa, Marcin Plata, Michał Czuba, Piotr Szymański, Piotr Syga + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
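The diffs that follow add the Streamlit front end (app.py), a config.yaml pointing at the fine-tuned Whisper+MesoNet checkpoint, and the evaluate_models.py helpers that app.py imports. As a quick orientation, here is a minimal, hedged sketch (not part of the diff) of how those pieces fit together outside Streamlit, using only functions introduced below (load_model, inference, commons.set_seed); the audio path "my_clip.wav" is a placeholder.

# Minimal usage sketch: score one audio file with the same model that app.py
# below wires into Streamlit. Assumes config.yaml and the checkpoint it
# references are present; "my_clip.wav" is a hypothetical input path.
import torch
import yaml

from evaluate_models import inference, load_model
from src import commons

device = "cuda" if torch.cuda.is_available() else "cpu"

with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)

# Fix seeds the same way app.py does before running inference.
commons.set_seed(config["data"].get("seed", 42))

model = load_model(config, device)

# inference() returns a list of (label, score) tuples, one per input file,
# where label is 'Real' or 'Fake' and score is the model's sigmoid output.
label, score = inference(model, datasets_path="my_clip.wav", device=device)[0]
print(f"{label} (score: {score:.3f})")

app.py performs essentially the same calls, wrapped in Streamlit widgets for file upload and sample selection.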
diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..d1a508d301d4ca0bce29fc9a9589898d1eac5e68 --- /dev/null +++ b/app.py @@ -0,0 +1,108 @@ +import base64 +import json +import os, shutil +import re +import time +import uuid + +import cv2 + +import numpy as np +import streamlit as st +from pydub import AudioSegment +import torch +import yaml +# from extract_video import extract_method_single_video + +from utils import st_file_selector, img2base64 +from evaluate_models import inference, load_model +from src import commons + +import os + +DEBUG = True + +def main(): + st.markdown("###") + uploaded_file = st.file_uploader('Upload an audio file', type=['wav', 'mp3'], accept_multiple_files=False) + + with st.spinner(f'Loading samples...'): + while not os.path.isdir("sample_files"): + time.sleep(1) + st.markdown("### or") + selected_file = st_file_selector(st, path='sample_files', key = 'selected', label = 'Choose a sample image/video') + + if uploaded_file: + random_id = uuid.uuid1() + ext = uploaded_file.name.split('.')[-1] + + base_folder = "temps" + filename = "{}.{}".format(random_id, ext) + file_type = uploaded_file.type.split("/")[0] + filepath = f"{base_folder}/{filename}" + + uploaded_file_length = len(uploaded_file.getvalue()) + if uploaded_file_length > 0: + with open(filepath, 'wb') as f: + f.write(uploaded_file.read()) + st.audio(uploaded_file, format=ext) + elif selected_file: + base_folder = "sample_files" + file_type = selected_file.split(".")[-1] + filename = selected_file.split("/")[-1] + filepath = f"{base_folder}/{selected_file}" + + st.write('file_type', file_type) + with open(filepath, 'rb') as f: + audio_bytes = f.read() + st.audio(audio_bytes, format=file_type) + else: + return + + + + + with st.spinner(f'Analyzing {file_type}...'): + + + seed = config["data"].get("seed", 42) + # fix all seeds - this should not actually change anything + commons.set_seed(seed) + + result = inference( + model, + datasets_path=filepath, + device=device, + ) + result = result[0] + + if 'Real' == result[0]: + st.success(f'Audio is real! \nprob:{result[1]}', icon="✅") + else: + st.error(f'Audio is fake! 
\nprob:{result[1]}', icon="🚨") + + st.divider() + st.write('## Response JSON') + st.write(result) + + +def setup(): + if not os.path.isdir("temps"): + os.makedirs("temps") + + + +if __name__ == "__main__": + if torch.cuda.is_available(): + device = "cuda" + else: + device = "cpu" + + with open('config.yaml', "r") as f: + config = yaml.safe_load(f) + + model = load_model(config, device) + + st.title("Face Fake Detection") + setup() + main() \ No newline at end of file diff --git a/config.yaml b/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f05b51991c384ef5ee5fcd6fc2e6744394c51cc3 --- /dev/null +++ b/config.yaml @@ -0,0 +1,16 @@ +data: + seed: 42 + +checkpoint: + path: C:\Users\manfr\Projects\deepfake-whisper-features\mesonet_whisper_mfcc_finetuned.pth + +model: + name: whisper_frontend_mesonet + optimizer: + lr: 1.0e-06 + weight_decay: 0.0001 + parameters: + fc1_dim: 1024 + freeze_encoder: false + frontend_algorithm: ["mfcc"] + input_channels: 2 diff --git a/configs/finetuning/whisper_frontend_mesonet.yaml b/configs/finetuning/whisper_frontend_mesonet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..12fd2c41d6f47b6ce2e3ee5b98c3d4b68dfe6a02 --- /dev/null +++ b/configs/finetuning/whisper_frontend_mesonet.yaml @@ -0,0 +1,16 @@ +data: + seed: 42 + +checkpoint: + path: "trained_models/whisper_frontend_mesonet/ckpt.pth" + +model: + name: "whisper_frontend_mesonet" + parameters: + freeze_encoder: false + input_channels: 2 + fc1_dim: 1024 + frontend_algorithm: ["lfcc"] + optimizer: + lr: 1.0e-06 + weight_decay: 0.0001 \ No newline at end of file diff --git a/configs/training/lcnn.yaml b/configs/training/lcnn.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7e02fad540d4f84b5750e960d52880356aca3e0b --- /dev/null +++ b/configs/training/lcnn.yaml @@ -0,0 +1,14 @@ +data: + seed: 42 + +checkpoint: + path: "" + +model: + name: "lcnn" + parameters: + input_channels: 1 + frontend_algorithm: ["mfcc"] + optimizer: + lr: 0.0001 + weight_decay: 0.0001 diff --git a/configs/training/mesonet.yaml b/configs/training/mesonet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f49d09fcd6e9d79097c2f63f8148be029ce8cde3 --- /dev/null +++ b/configs/training/mesonet.yaml @@ -0,0 +1,15 @@ +data: + seed: 42 + +checkpoint: + path: "" + +model: + name: "mesonet" + parameters: + input_channels: 1 + fc1_dim: 1024 + frontend_algorithm: ["lfcc"] + optimizer: + lr: 0.0001 + weight_decay: 0.0001 diff --git a/configs/training/rawnet3.yaml b/configs/training/rawnet3.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0f1b60a872f3a705058dbabcf9540c58b90e5320 --- /dev/null +++ b/configs/training/rawnet3.yaml @@ -0,0 +1,13 @@ +data: + seed: 42 + +checkpoint: + path: "" + +model: + name: "rawnet3" + parameters: {} + optimizer: + lr: 0.001 + weight_decay: 0.00005 # 5e-5 + diff --git a/configs/training/specrnet.yaml b/configs/training/specrnet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..64e23e848de15c30d14a149237d59d15f0f8355a --- /dev/null +++ b/configs/training/specrnet.yaml @@ -0,0 +1,14 @@ +data: + seed: 42 + +checkpoint: + path: "" + +model: + name: "specrnet" + parameters: + input_channels: 1 + frontend_algorithm: ["lfcc"] + optimizer: + lr: 0.0001 + weight_decay: 0.0001 diff --git a/configs/training/whisper_frontend_lcnn.yaml b/configs/training/whisper_frontend_lcnn.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..608c0bef8754e42534097769818b111a549bb461 --- /dev/null +++ b/configs/training/whisper_frontend_lcnn.yaml @@ -0,0 +1,16 @@ +data: + seed: 42 + +checkpoint: + path: "" + +model: + name: "whisper_frontend_lcnn" + parameters: + freeze_encoder: True + input_channels: 2 + frontend_algorithm: ["lfcc"] + optimizer: + lr: 0.0001 + weight_decay: 0.0001 + diff --git a/configs/training/whisper_frontend_lcnn_mfcc.yaml b/configs/training/whisper_frontend_lcnn_mfcc.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b7e1a90e3f23157a9ecb4b0065f7010686894996 --- /dev/null +++ b/configs/training/whisper_frontend_lcnn_mfcc.yaml @@ -0,0 +1,15 @@ +data: + seed: 42 + +checkpoint: + path: "" + +model: + name: "whisper_frontend_lcnn" + parameters: + freeze_encoder: True + input_channels: 2 + frontend_algorithm: ["mfcc"] + optimizer: + lr: 0.0001 + weight_decay: 0.0001 diff --git a/configs/training/whisper_frontend_mesonet.yaml b/configs/training/whisper_frontend_mesonet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2b4bcda552ba0ebc20daf5204fa281a183375199 --- /dev/null +++ b/configs/training/whisper_frontend_mesonet.yaml @@ -0,0 +1,16 @@ +data: + seed: 42 + +checkpoint: + path: "" + +model: + name: "whisper_frontend_mesonet" + parameters: + freeze_encoder: True + input_channels: 2 + fc1_dim: 1024 + frontend_algorithm: ["lfcc"] + optimizer: + lr: 0.0001 + weight_decay: 0.0001 diff --git a/configs/training/whisper_frontend_mesonet_mfcc.yaml b/configs/training/whisper_frontend_mesonet_mfcc.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c0b3328a25c8fdfbfd7ad69dac32f109aaa35533 --- /dev/null +++ b/configs/training/whisper_frontend_mesonet_mfcc.yaml @@ -0,0 +1,17 @@ +data: + seed: 42 + +checkpoint: + path: "" + +model: + name: "whisper_frontend_mesonet" + parameters: + freeze_encoder: True + input_channels: 2 + fc1_dim: 1024 + frontend_algorithm: ["mfcc"] + optimizer: + lr: 0.0001 + weight_decay: 0.0001 + diff --git a/configs/training/whisper_frontend_specrnet.yaml b/configs/training/whisper_frontend_specrnet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9827f53f34033cee23ad72bd9ca3add1dab3dae6 --- /dev/null +++ b/configs/training/whisper_frontend_specrnet.yaml @@ -0,0 +1,15 @@ +data: + seed: 42 + +checkpoint: + path: "" + +model: + name: "whisper_frontend_specrnet" + parameters: + freeze_encoder: True + input_channels: 2 + frontend_algorithm: ["lfcc"] + optimizer: + lr: 0.0001 + weight_decay: 0.0001 diff --git a/configs/training/whisper_frontend_specrnet_mfcc.yaml b/configs/training/whisper_frontend_specrnet_mfcc.yaml new file mode 100644 index 0000000000000000000000000000000000000000..93fc30a7ca4a53c94d8a15a956247c52e294ce7c --- /dev/null +++ b/configs/training/whisper_frontend_specrnet_mfcc.yaml @@ -0,0 +1,16 @@ +data: + seed: 42 + +checkpoint: + path: "" + +model: + name: "whisper_frontend_specrnet" + parameters: + freeze_encoder: True + input_channels: 2 + frontend_algorithm: ["mfcc"] + optimizer: + lr: 0.0001 + weight_decay: 0.0001 + diff --git a/configs/training/whisper_lcnn.yaml b/configs/training/whisper_lcnn.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5dfeaeee2934512b85df228540cc57cbde4a4486 --- /dev/null +++ b/configs/training/whisper_lcnn.yaml @@ -0,0 +1,15 @@ +data: + seed: 42 + +checkpoint: + path: "" + +model: + name: "whisper_lcnn" + parameters: + freeze_encoder: True + input_channels: 1 + frontend_algorithm: ["lfcc"] + 
optimizer: + lr: 0.0001 + weight_decay: 0.0001 diff --git a/configs/training/whisper_mesonet.yaml b/configs/training/whisper_mesonet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..175977b74f2757d2b1e568c4d4af96940281f49e --- /dev/null +++ b/configs/training/whisper_mesonet.yaml @@ -0,0 +1,16 @@ +data: + seed: 42 + +checkpoint: + path: "" + +model: + name: "whisper_mesonet" + parameters: + freeze_encoder: True + input_channels: 1 + fc1_dim: 1024 + frontend_algorithm: [] + optimizer: + lr: 0.0001 + weight_decay: 0.0001 diff --git a/configs/training/whisper_specrnet.yaml b/configs/training/whisper_specrnet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6676d071973ab580602b6d31fa38cab78ea0034a --- /dev/null +++ b/configs/training/whisper_specrnet.yaml @@ -0,0 +1,15 @@ +data: + seed: 42 + +checkpoint: + path: "" + +model: + name: "whisper_specrnet" + parameters: + freeze_encoder: True + input_channels: 1 + frontend_algorithm: ["lfcc"] + optimizer: + lr: 0.0001 + weight_decay: 0.0001 diff --git a/download_whisper.py b/download_whisper.py new file mode 100644 index 0000000000000000000000000000000000000000..ad377c82135e0bbd413eda89174414dc9b541e0a --- /dev/null +++ b/download_whisper.py @@ -0,0 +1,29 @@ +# pip install git+https://github.com/openai/whisper.git +from collections import OrderedDict +import whisper +import torch + +from src.commons import WHISPER_MODEL_WEIGHTS_PATH + +def download_whisper(): + model = whisper.load_model("tiny.en") + return model + + +def extract_and_save_encoder(model): + model_ckpt = OrderedDict() + + model_ckpt['model_state_dict'] = OrderedDict() + + for key, value in model.encoder.state_dict().items(): + model_ckpt['model_state_dict'][f'encoder.{key}'] = value + + model_ckpt['dims'] = model.dims + torch.save(model_ckpt, WHISPER_MODEL_WEIGHTS_PATH) + + +if __name__ == "__main__": + model = download_whisper() + print("Downloaded Whisper model!") + extract_and_save_encoder(model) + print(f"Saved encoder at '{WHISPER_MODEL_WEIGHTS_PATH}'") \ No newline at end of file diff --git a/evaluate_models.py b/evaluate_models.py new file mode 100644 index 0000000000000000000000000000000000000000..055cfe73336ff560eeede06648fcb300884aaee4 --- /dev/null +++ b/evaluate_models.py @@ -0,0 +1,316 @@ +import argparse +import logging +from pathlib import Path +from typing import Dict, List, Optional, Union +import sys + +import torch +import yaml +from sklearn.metrics import precision_recall_fscore_support, roc_auc_score +from torch.utils.data import DataLoader + +from src import metrics, commons +from src.models import models +from src.datasets.base_dataset import SimpleAudioFakeDataset +from src.datasets.in_the_wild_dataset import InTheWildDataset +from src.datasets.folder_dataset import FolderDataset, FileDataset + + +def get_dataset( + datasets_paths: List[Union[Path, str]], + amount_to_use: Optional[int], +) -> SimpleAudioFakeDataset: + data_val = FolderDataset( + path=datasets_paths[0] + ) + return data_val + +def get_dataset_file( + datasets_path, + amount_to_use: Optional[int], +) -> SimpleAudioFakeDataset: + data_val = FileDataset( + path=datasets_path + ) + return data_val + + +def evaluate_nn( + model_paths: List[Path], + datasets_paths: List[Union[Path, str]], + model_config: Dict, + device: str, + amount_to_use: Optional[int] = None, + batch_size: int = 8, +): + logging.info("Loading data...") + model_name, model_parameters = model_config["name"], model_config["parameters"] + + # Load model architecture + model = 
models.get_model( + model_name=model_name, + config=model_parameters, + device=device, + ) + # If provided weights, apply corresponding ones (from an appropriate fold) + if len(model_paths): + state_dict = torch.load(model_paths, map_location=device) + model.load_state_dict(state_dict) + model = model.to(device) + + data_val = get_dataset( + datasets_paths=datasets_paths, + amount_to_use=amount_to_use, + ) + + logging.info( + f"Testing '{model_name}' model, weights path: '{model_paths}', on {len(data_val)} audio files." + ) + test_loader = DataLoader( + data_val, + batch_size=batch_size, + shuffle=True, + drop_last=False, + num_workers=3, + ) + + batches_number = len(data_val) // batch_size + num_correct = 0.0 + num_total = 0.0 + + y_pred = torch.Tensor([]).to(device) + y = torch.Tensor([]).to(device) + y_pred_label = torch.Tensor([]).to(device) + + preds = [] + + for i, (batch_x, _, batch_y, metadata) in enumerate(test_loader): + model.eval() + _, path, _, _ = metadata + if i % 10 == 0: + print(f"Batch [{i}/{batches_number}]") + + with torch.no_grad(): + batch_x = batch_x.to(device) + batch_y = batch_y.to(device) + num_total += batch_x.size(0) + + batch_pred = model(batch_x).squeeze(1) + batch_pred = torch.sigmoid(batch_pred) + batch_pred_label = (batch_pred + 0.5).int() + + num_correct += (batch_pred_label == batch_y.int()).sum(dim=0).item() + + y_pred = torch.concat([y_pred, batch_pred], dim=0) + y_pred_label = torch.concat([y_pred_label, batch_pred_label], dim=0) + y = torch.concat([y, batch_y], dim=0) + + for i in range(len(y_pred_label)): + label = 'Fake' if y_pred_label[i] == 0 else 'Real' + print(f'{path[i]}') + print(f' Prediction: : {label}') + print(f' Probability: {y_pred[i]})') + preds.append((label, y_pred[i].detach().cpu().item())) + + return preds + + eval_accuracy = (num_correct / num_total) * 100 + + precision, recall, f1_score, support = precision_recall_fscore_support( + y.cpu().numpy(), y_pred_label.cpu().numpy(), average="binary", beta=1.0 + ) + auc_score = roc_auc_score(y_true=y.cpu().numpy(), y_score=y_pred.cpu().numpy()) + + # For EER flip values, following original evaluation implementation + y_for_eer = 1 - y + + thresh, eer, fpr, tpr = metrics.calculate_eer( + y=y_for_eer.cpu().numpy(), + y_score=y_pred.cpu().numpy(), + ) + + eer_label = f"eval/eer" + accuracy_label = f"eval/accuracy" + precision_label = f"eval/precision" + recall_label = f"eval/recall" + f1_label = f"eval/f1_score" + auc_label = f"eval/auc" + + logging.info( + f"{eer_label}: {eer:.4f}, {accuracy_label}: {eval_accuracy:.4f}, {precision_label}: {precision:.4f}, {recall_label}: {recall:.4f}, {f1_label}: {f1_score:.4f}, {auc_label}: {auc_score:.4f}" + ) + +def load_model(config, device): + model_config = config['model'] + model_name, model_parameters = model_config["name"], model_config["parameters"] + model_paths = config["checkpoint"].get("path", []) + # Load model architecture + model = models.get_model( + model_name=model_name, + config=model_parameters, + device=device, + ) + # If provided weights, apply corresponding ones (from an appropriate fold) + if len(model_paths): + state_dict = torch.load(model_paths, map_location=device) + model.load_state_dict(state_dict) + model = model.to(device) + return model + +def inference( + model, + datasets_path, + device: str, + amount_to_use: Optional[int] = None, + batch_size: int = 8, +): + logging.info("Loading data...") + + + data_val = get_dataset_file( + datasets_path=datasets_path, + amount_to_use=amount_to_use, + ) + + test_loader = DataLoader( + 
data_val, + batch_size=batch_size, + shuffle=True, + drop_last=False, + num_workers=3, + ) + + batches_number = len(data_val) // batch_size + num_correct = 0.0 + num_total = 0.0 + + y_pred = torch.Tensor([]).to(device) + y = torch.Tensor([]).to(device) + y_pred_label = torch.Tensor([]).to(device) + + preds = [] + + for i, (batch_x, _, batch_y, metadata) in enumerate(test_loader): + model.eval() + _, path, _, _ = metadata + if i % 10 == 0: + print(f"Batch [{i}/{batches_number}]") + + with torch.no_grad(): + batch_x = batch_x.to(device) + batch_y = batch_y.to(device) + num_total += batch_x.size(0) + + batch_pred = model(batch_x).squeeze(1) + batch_pred = torch.sigmoid(batch_pred) + batch_pred_label = (batch_pred + 0.5).int() + + num_correct += (batch_pred_label == batch_y.int()).sum(dim=0).item() + + y_pred = torch.concat([y_pred, batch_pred], dim=0) + y_pred_label = torch.concat([y_pred_label, batch_pred_label], dim=0) + y = torch.concat([y, batch_y], dim=0) + + for i in range(len(y_pred_label)): + label = 'Fake' if y_pred_label[i] == 0 else 'Real' + print(f'{path[i]}') + print(f' Prediction: : {label}') + print(f' Probability: {y_pred[i]})') + preds.append((label, y_pred[i].detach().cpu().item())) + + return preds + + eval_accuracy = (num_correct / num_total) * 100 + + precision, recall, f1_score, support = precision_recall_fscore_support( + y.cpu().numpy(), y_pred_label.cpu().numpy(), average="binary", beta=1.0 + ) + auc_score = roc_auc_score(y_true=y.cpu().numpy(), y_score=y_pred.cpu().numpy()) + + # For EER flip values, following original evaluation implementation + y_for_eer = 1 - y + + thresh, eer, fpr, tpr = metrics.calculate_eer( + y=y_for_eer.cpu().numpy(), + y_score=y_pred.cpu().numpy(), + ) + + eer_label = f"eval/eer" + accuracy_label = f"eval/accuracy" + precision_label = f"eval/precision" + recall_label = f"eval/recall" + f1_label = f"eval/f1_score" + auc_label = f"eval/auc" + + logging.info( + f"{eer_label}: {eer:.4f}, {accuracy_label}: {eval_accuracy:.4f}, {precision_label}: {precision:.4f}, {recall_label}: {recall:.4f}, {f1_label}: {f1_score:.4f}, {auc_label}: {auc_score:.4f}" + ) + + +def main(args): + LOGGER = logging.getLogger() + LOGGER.setLevel(logging.INFO) + + ch = logging.StreamHandler() + formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") + ch.setFormatter(formatter) + LOGGER.addHandler(ch) + logging.basicConfig(stream=sys.stdout, level=logging.INFO) + + if not args.cpu and torch.cuda.is_available(): + device = "cuda" + else: + device = "cpu" + + with open(args.config, "r") as f: + config = yaml.safe_load(f) + + seed = config["data"].get("seed", 42) + # fix all seeds - this should not actually change anything + commons.set_seed(seed) + + evaluate_nn( + model_paths=config["checkpoint"].get("path", []), + datasets_paths=[ + args.folder_path, + ], + model_config=config["model"], + amount_to_use=args.amount, + device=device, + ) + + +def parse_args(): + parser = argparse.ArgumentParser() + + # If assigned as None, then it won't be taken into account + FOLDER_DATASET_PATH = "sample_files" + + parser.add_argument( + "--folder_path", type=str, default=FOLDER_DATASET_PATH + ) + + default_model_config = "config.yaml" + parser.add_argument( + "--config", + help="Model config file path (default: config.yaml)", + type=str, + default=default_model_config, + ) + + default_amount = None + parser.add_argument( + "--amount", + "-a", + help=f"Amount of files to load from each directory (default: {default_amount} - use all).", + type=int, + 
default=default_amount, + ) + + parser.add_argument("--cpu", "-c", help="Force using cpu", action="store_true") + + return parser.parse_args() + + +if __name__ == "__main__": + main(parse_args()) diff --git a/install.sh b/install.sh new file mode 100644 index 0000000000000000000000000000000000000000..7f4598c4ee79d2951f437b8fc48934a9d6dee27f --- /dev/null +++ b/install.sh @@ -0,0 +1,6 @@ +conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch -y + +pip install asteroid-filterbanks==0.4.0 +pip install librosa==0.9.2 +pip install git+https://github.com/openai/whisper.git@7858aa9c08d98f75575035ecd6481f462d66ca27 +pip install pandas==2.0.2 diff --git a/mesonet_whisper_mfcc_finetuned.pth b/mesonet_whisper_mfcc_finetuned.pth new file mode 100644 index 0000000000000000000000000000000000000000..95709ca6df14548467d159c98d1c1e25b202ae5d --- /dev/null +++ b/mesonet_whisper_mfcc_finetuned.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a34a00d0961303274e1cf7a2dc2b6e9f9d568ff0416300be1aaee1c2e2ceee12 +size 32983925 diff --git a/sample_files/[FAKE] - jokowi - cupid [vocals].mp3 b/sample_files/[FAKE] - jokowi - cupid [vocals].mp3 new file mode 100644 index 0000000000000000000000000000000000000000..a97996505377712f7450d59a104768fa691ce2be --- /dev/null +++ b/sample_files/[FAKE] - jokowi - cupid [vocals].mp3 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1ce8dce41de4f44908c57deea26d4efe5a74f9a37700a76a94ac065e862304c0 +size 775449 diff --git "a/sample_files/[REAL] - Obama at Rutgers\357\274\232 'Ignorance Is Not a Virtue'_[cut_49sec].mp3" "b/sample_files/[REAL] - Obama at Rutgers\357\274\232 'Ignorance Is Not a Virtue'_[cut_49sec].mp3" new file mode 100644 index 0000000000000000000000000000000000000000..a93bb07a9d87f913811f653419d04d690b6fc2d4 --- /dev/null +++ "b/sample_files/[REAL] - Obama at Rutgers\357\274\232 'Ignorance Is Not a Virtue'_[cut_49sec].mp3" @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6694c6d329f8a372896808c1f7c1e487eec65e5ad2fb3d244d80729b211ac0c4 +size 1950720 diff --git "a/sample_files/[REAL] - Obama's speech to the class of 2020 in 2 minutes \357\275\234 The Washington Post.wav" "b/sample_files/[REAL] - Obama's speech to the class of 2020 in 2 minutes \357\275\234 The Washington Post.wav" new file mode 100644 index 0000000000000000000000000000000000000000..48cf46a28c817b6a0ba45de25299733b1f6ff03e --- /dev/null +++ "b/sample_files/[REAL] - Obama's speech to the class of 2020 in 2 minutes \357\275\234 The Washington Post.wav" @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3f045f2b80fdc136c63bfc5897dd9a3d34a3b60dba886cae297d4425db30d5d9 +size 27507540 diff --git a/sample_files/[[FAKE] - y2mate.com - DeepFake AI generated synthetic video of Barack Obama.mp3 b/sample_files/[[FAKE] - y2mate.com - DeepFake AI generated synthetic video of Barack Obama.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..5271abaf4b45f16bdc45a150a1b83428ce6509b5 --- /dev/null +++ b/sample_files/[[FAKE] - y2mate.com - DeepFake AI generated synthetic video of Barack Obama.mp3 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fd7100f013cb23ae4af3e00594330838dcae39dc86d669ef8fd215a6a6d88f53 +size 273900 diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..967b94b78280fb1ecad3249a78b490832597e2e5 --- /dev/null +++ b/src/__init__.py @@ -0,0 +1,3 @@ +import logging + 
+logging.getLogger(__name__).addHandler(logging.NullHandler()) diff --git a/src/commons.py b/src/commons.py new file mode 100644 index 0000000000000000000000000000000000000000..b9ed59920d3584d5111ed1f261c9ec5006bd5b44 --- /dev/null +++ b/src/commons.py @@ -0,0 +1,22 @@ +"""Utility file for src toolkit.""" +import os +import random + +import numpy as np +import torch + +WHISPER_MODEL_WEIGHTS_PATH = "src/models/assets/tiny_enc.en.pt" + + +def set_seed(seed: int): + """Fix PRNG seed for reproducable experiments. + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + os.environ["PYTHONHASHSEED"] = str(seed) diff --git a/src/datasets/__init__.py b/src/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/datasets/asvspoof_dataset.py b/src/datasets/asvspoof_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..6db0e8f05b05de8cadf1c33b120028fedc9cfa99 --- /dev/null +++ b/src/datasets/asvspoof_dataset.py @@ -0,0 +1,155 @@ +from pathlib import Path + +import pandas as pd +if __name__ == "__main__": + import sys + sys.path.append(str(Path(__file__).parent.parent.parent.absolute())) + +from src.datasets.base_dataset import SimpleAudioFakeDataset + +ASVSPOOF_SPLIT = { + "train": ['A01', 'A07', 'A08', 'A02', 'A09', 'A10', 'A03', 'A04', 'A05', 'A06', 'A11', 'A12', 'A13', 'A14', 'A15', 'A16', 'A17', 'A18', 'A19'], + "test": ['A01', 'A07', 'A08', 'A02', 'A09', 'A10', 'A03', 'A04', 'A05', 'A06', 'A11', 'A12', 'A13', 'A14', 'A15', 'A16', 'A17', 'A18', 'A19'], + "val": ['A01', 'A07', 'A08', 'A02', 'A09', 'A10', 'A03', 'A04', 'A05', 'A06', 'A11', 'A12', 'A13', 'A14', 'A15', 'A16', 'A17', 'A18', 'A19'], + "partition_ratio": [0.7, 0.15], + "seed": 45, +} + + +class ASVSpoofDataset(SimpleAudioFakeDataset): + + protocol_folder_name = "ASVspoof2019_LA_cm_protocols" + subset_dir_prefix = "ASVspoof2019_LA_" + subsets = ("train", "dev", "eval") + + def __init__(self, path, subset="train", transform=None): + super().__init__(subset, transform) + self.path = path + + self.allowed_attacks = ASVSPOOF_SPLIT[subset] + self.partition_ratio = ASVSPOOF_SPLIT["partition_ratio"] + self.seed = ASVSPOOF_SPLIT["seed"] + + self.samples = pd.DataFrame() + + for subset in self.subsets: + subset_dir = Path(self.path) / f"{self.subset_dir_prefix}{subset}" + subset_protocol_path = self.get_protocol_path(subset) + subset_samples = self.read_protocol(subset_dir, subset_protocol_path) + + self.samples = pd.concat([self.samples, subset_samples]) + + self.transform = transform + + def get_protocol_path(self, subset): + paths = list((Path(self.path) / self.protocol_folder_name).glob("*.txt")) + for path in paths: + if subset in Path(path).stem: + return path + + def read_protocol(self, subset_dir, protocol_path): + samples = { + "user_id": [], + "sample_name": [], + "attack_type": [], + "label": [], + "path": [] + } + + real_samples = [] + fake_samples = [] + with open(protocol_path, "r") as file: + for line in file: + attack_type = line.strip().split(" ")[3] + + if attack_type == "-": + real_samples.append(line) + elif attack_type in self.allowed_attacks: + fake_samples.append(line) + + if attack_type not in self.allowed_attacks: + continue + + fake_samples = self.split_samples(fake_samples) + for line in fake_samples: + 
samples = self.add_line_to_samples(samples, line, subset_dir) + + real_samples = self.split_samples(real_samples) + for line in real_samples: + samples = self.add_line_to_samples(samples, line, subset_dir) + + return pd.DataFrame(samples) + + @staticmethod + def add_line_to_samples(samples, line, subset_dir): + user_id, sample_name, _, attack_type, label = line.strip().split(" ") + samples["user_id"].append(user_id) + samples["sample_name"].append(sample_name) + samples["attack_type"].append(attack_type) + samples["label"].append(label) + + assert (subset_dir / "flac" / f"{sample_name}.flac").exists() + samples["path"].append(subset_dir / "flac" / f"{sample_name}.flac") + + return samples + +class ASVSpoof2019DatasetOriginal(ASVSpoofDataset): + + subsets = {"train": "train", "test": "dev", "val": "eval"} + + protocol_folder_name = "ASVspoof2019_LA_cm_protocols" + subset_dir_prefix = "ASVspoof2019_LA_" + subset_dirs_attacks = { + "train": ["A01", "A02", "A03", "A04", "A05", "A06"], + "dev": ["A01", "A02", "A03", "A04", "A05", "A06"], + "eval": [ + "A07", "A08", "A09", "A10", "A11", "A12", "A13", "A14", "A15", + "A16", "A17", "A18", "A19" + ] + } + + + def __init__(self, path, fold_subset="train"): + """ + Initialise object. Skip __init__ of ASVSpoofDataset doe to different + logic, but follow SimpleAudioFakeDataset constructor. + """ + super(ASVSpoofDataset, self).__init__(float('inf'), fold_subset) + self.path = path + subset = self.subsets[fold_subset] + self.allowed_attacks = self.subset_dirs_attacks[subset] + subset_dir = Path(self.path) / f"{self.subset_dir_prefix}{subset}" + subset_protocol_path = self.get_protocol_path(subset) + self.samples = self.read_protocol(subset_dir, subset_protocol_path) + + def read_protocol(self, subset_dir, protocol_path): + samples = { + "user_id": [], + "sample_name": [], + "attack_type": [], + "label": [], + "path": [] + } + + real_samples = [] + fake_samples = [] + + with open(protocol_path, "r") as file: + for line in file: + attack_type = line.strip().split(" ")[3] + if attack_type == "-": + real_samples.append(line) + elif attack_type in self.allowed_attacks: + fake_samples.append(line) + else: + raise ValueError( + "Tried to load attack that shouldn't be here!" 
+ ) + + for line in fake_samples: + samples = self.add_line_to_samples(samples, line, subset_dir) + for line in real_samples: + samples = self.add_line_to_samples(samples, line, subset_dir) + + return pd.DataFrame(samples) + diff --git a/src/datasets/base_dataset.py b/src/datasets/base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..04451d22721896409668eb4875f2e6c8b1088899 --- /dev/null +++ b/src/datasets/base_dataset.py @@ -0,0 +1,180 @@ +"""Base dataset classes.""" +import logging +import math +import random + +import numpy as np +import pandas as pd +import torch +import torchaudio +from torch.utils.data import Dataset +from torch.utils.data.dataset import T_co + + +LOGGER = logging.getLogger(__name__) + +SAMPLING_RATE = 16_000 +APPLY_NORMALIZATION = True +APPLY_TRIMMING = True +APPLY_PADDING = True +FRAMES_NUMBER = 480_000 # <- originally 64_600 + + +SOX_SILENCE = [ + # trim all silence that is longer than 0.2s and louder than 1% volume (relative to the file) + # from beginning and middle/end + ["silence", "1", "0.2", "1%", "-1", "0.2", "1%"], +] + + +class SimpleAudioFakeDataset(Dataset): + def __init__( + self, + subset, + transform=None, + return_label: bool = True, + return_meta: bool = True, + ): + self.transform = transform + + self.subset = subset + self.allowed_attacks = None + self.partition_ratio = None + self.seed = None + self.return_label = return_label + self.return_meta = return_meta + + def split_samples(self, samples_list): + if isinstance(samples_list, pd.DataFrame): + samples_list = samples_list.sort_values(by=list(samples_list.columns)) + samples_list = samples_list.sample(frac=1, random_state=self.seed) + else: + samples_list = sorted(samples_list) + random.seed(self.seed) + random.shuffle(samples_list) + + p, s = self.partition_ratio + subsets = np.split( + samples_list, [int(p * len(samples_list)), int((p + s) * len(samples_list))] + ) + return dict(zip(["train", "test", "val"], subsets))[self.subset] + + def df2tuples(self): + tuple_samples = [] + for i, elem in self.samples.iterrows(): + tuple_samples.append( + (str(elem["path"]), elem["label"], elem["attack_type"]) + ) + + self.samples = tuple_samples + + + return self.samples + + def __getitem__(self, index) -> T_co: + if isinstance(self.samples, pd.DataFrame): + sample = self.samples.iloc[index] + + path = str(sample["path"]) + label = sample["label"] + attack_type = sample["attack_type"] + if type(attack_type) != str and math.isnan(attack_type): + attack_type = "N/A" + else: + path, label, attack_type = self.samples[index] + + waveform, sample_rate = torchaudio.load(path, normalize=APPLY_NORMALIZATION) + import librosa + # waveform, sample_rate = librosa.load(path, sr=SAMPLING_RATE) + # waveform = torch.tensor(waveform) + print('waveform', waveform) + real_sec_length = len(waveform[0]) / sample_rate + + waveform, sample_rate = apply_preprocessing(waveform, sample_rate) + + return_data = [waveform, sample_rate] + if self.return_label: + label = 1 if label == "bonafide" else 0 + return_data.append(label) + + if self.return_meta: + return_data.append( + ( + attack_type, + path, + self.subset, + real_sec_length, + ) + ) + return return_data + + def __len__(self): + return len(self.samples) + + +def apply_preprocessing( + waveform, + sample_rate, +): + if sample_rate != SAMPLING_RATE and SAMPLING_RATE != -1: + waveform, sample_rate = resample_wave(waveform, sample_rate, SAMPLING_RATE) + + # Stereo to mono + if waveform.dim() > 1 and waveform.shape[0] > 1: + waveform = waveform[:1, 
...] + + # Trim too long utterances... + if APPLY_TRIMMING: + waveform, sample_rate = apply_trim(waveform, sample_rate) + + # ... or pad too short ones. + if APPLY_PADDING: + waveform = apply_pad(waveform, FRAMES_NUMBER) + + return waveform, sample_rate + + +def resample_wave(waveform, sample_rate, target_sample_rate): + # waveform, sample_rate = torchaudio.sox_effects.apply_effects_tensor( + # waveform, sample_rate, [["rate", f"{target_sample_rate}"]] + # ) + waveform = torchaudio.functional.resample(waveform, orig_freq=sample_rate, new_freq=target_sample_rate) + return waveform, target_sample_rate + + +def resample_file(path, target_sample_rate, normalize=True): + waveform, sample_rate = torchaudio.sox_effects.apply_effects_file( + path, [["rate", f"{target_sample_rate}"]], normalize=normalize + ) + + return waveform, sample_rate + + +def apply_trim(waveform, sample_rate): + # ( + # waveform_trimmed, + # sample_rate_trimmed, + # ) = torchaudio.sox_effects.apply_effects_tensor(waveform, sample_rate, SOX_SILENCE) + + ["silence", "1", "0.2", "1%", "-1", "0.2", "1%"], + waveform_trimmed = torchaudio.functional.vad(waveform, sample_rate=sample_rate) + + if waveform_trimmed.size()[1] > 0: + waveform = waveform_trimmed + + return waveform, sample_rate + + +def apply_pad(waveform, cut): + """Pad wave by repeating signal until `cut` length is achieved.""" + waveform = waveform.squeeze(0) + waveform_len = waveform.shape[0] + + if waveform_len >= cut: + return waveform[:cut] + + # need to pad + num_repeats = int(cut / waveform_len) + 1 + padded_waveform = torch.tile(waveform, (1, num_repeats))[:, :cut][0] + + return padded_waveform diff --git a/src/datasets/deepfake_asvspoof_dataset.py b/src/datasets/deepfake_asvspoof_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..c4f02c0542c05f9a822668dc76b90f75acbb6014 --- /dev/null +++ b/src/datasets/deepfake_asvspoof_dataset.py @@ -0,0 +1,86 @@ +import logging +from pathlib import Path + +import pandas as pd + +from src.datasets.base_dataset import SimpleAudioFakeDataset + +DF_ASVSPOOF_SPLIT = { + "partition_ratio": [0.7, 0.15], + "seed": 45 +} + +LOGGER = logging.getLogger() + +class DeepFakeASVSpoofDataset(SimpleAudioFakeDataset): + + protocol_file_name = "keys/CM/trial_metadata.txt" + subset_dir_prefix = "ASVspoof2021_DF_eval" + subset_parts = ("part00", "part01", "part02", "part03") + + def __init__(self, path, subset="train", transform=None): + super().__init__(subset, transform) + self.path = path + + self.partition_ratio = DF_ASVSPOOF_SPLIT["partition_ratio"] + self.seed = DF_ASVSPOOF_SPLIT["seed"] + + self.flac_paths = self.get_file_references() + self.samples = self.read_protocol() + + self.transform = transform + LOGGER.info(f"Spoof: {len(self.samples[self.samples['label'] == 'spoof'])}") + LOGGER.info(f"Original: {len(self.samples[self.samples['label'] == 'bonafide'])}") + + def get_file_references(self): + flac_paths = {} + for part in self.subset_parts: + path = Path(self.path) / f"{self.subset_dir_prefix}_{part}" / self.subset_dir_prefix / "flac" + flac_list = list(path.glob("*.flac")) + + for path in flac_list: + flac_paths[path.stem] = path + + return flac_paths + + def read_protocol(self): + samples = { + "sample_name": [], + "label": [], + "path": [], + "attack_type": [], + } + + real_samples = [] + fake_samples = [] + with open(Path(self.path) / self.protocol_file_name, "r") as file: + for line in file: + label = line.strip().split(" ")[5] + + if label == "bonafide": + real_samples.append(line) + elif label == 
"spoof": + fake_samples.append(line) + + fake_samples = self.split_samples(fake_samples) + for line in fake_samples: + samples = self.add_line_to_samples(samples, line) + + real_samples = self.split_samples(real_samples) + for line in real_samples: + samples = self.add_line_to_samples(samples, line) + + return pd.DataFrame(samples) + + def add_line_to_samples(self, samples, line): + _, sample_name, _, _, _, label, _, _ = line.strip().split(" ") + samples["sample_name"].append(sample_name) + samples["label"].append(label) + samples["attack_type"].append(label) + + sample_path = self.flac_paths[sample_name] + assert sample_path.exists() + samples["path"].append(sample_path) + + return samples + diff --git a/src/datasets/detection_dataset.py b/src/datasets/detection_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..9f148477adab01a8ec6778a2a98e8465b9555973 --- /dev/null +++ b/src/datasets/detection_dataset.py @@ -0,0 +1,125 @@ +import logging +from typing import List, Optional + + +import pandas as pd + +from src.datasets.base_dataset import SimpleAudioFakeDataset +from src.datasets.deepfake_asvspoof_dataset import DeepFakeASVSpoofDataset +from src.datasets.fakeavceleb_dataset import FakeAVCelebDataset +from src.datasets.wavefake_dataset import WaveFakeDataset +from src.datasets.asvspoof_dataset import ASVSpoof2019DatasetOriginal + + +LOGGER = logging.getLogger() + + +class DetectionDataset(SimpleAudioFakeDataset): + def __init__( + self, + asvspoof_path=None, + wavefake_path=None, + fakeavceleb_path=None, + asvspoof2019_path=None, + subset: str = "val", + transform=None, + oversample: bool = True, + undersample: bool = False, + return_label: bool = True, + reduced_number: Optional[int] = None, + return_meta: bool = False, + ): + super().__init__( + subset=subset, + transform=transform, + return_label=return_label, + return_meta=return_meta, + ) + datasets = self._init_datasets( + asvspoof_path=asvspoof_path, + wavefake_path=wavefake_path, + fakeavceleb_path=fakeavceleb_path, + asvspoof2019_path=asvspoof2019_path, + subset=subset, + ) + self.samples = pd.concat([ds.samples for ds in datasets], ignore_index=True) + + if oversample: + self.oversample_dataset() + elif undersample: + self.undersample_dataset() + + if reduced_number: + LOGGER.info(f"Using reduced number of samples - {reduced_number}!") + self.samples = self.samples.sample( + min(len(self.samples), reduced_number), + random_state=42, + ) + + def _init_datasets( + self, + asvspoof_path: Optional[str], + wavefake_path: Optional[str], + fakeavceleb_path: Optional[str], + asvspoof2019_path: Optional[str], + subset: str, + ) -> List[SimpleAudioFakeDataset]: + datasets = [] + + if asvspoof_path is not None: + asvspoof_dataset = DeepFakeASVSpoofDataset(asvspoof_path, subset=subset) + datasets.append(asvspoof_dataset) + + if wavefake_path is not None: + wavefake_dataset = WaveFakeDataset(wavefake_path, subset=subset) + datasets.append(wavefake_dataset) + + if fakeavceleb_path is not None: + fakeavceleb_dataset = FakeAVCelebDataset(fakeavceleb_path, subset=subset) + datasets.append(fakeavceleb_dataset) + + if asvspoof2019_path is not None: + la_dataset = ASVSpoof2019DatasetOriginal( + asvspoof2019_path, fold_subset=subset + ) + datasets.append(la_dataset) + + return datasets + + def oversample_dataset(self): + samples = self.samples.groupby(by=["label"]) + bona_length = len(samples.groups["bonafide"]) + spoof_length = len(samples.groups["spoof"]) + + diff_length = spoof_length - bona_length + + if diff_length < 
0: + raise NotImplementedError + + if diff_length > 0: + bonafide = samples.get_group("bonafide").sample(diff_length, replace=True) + self.samples = pd.concat([self.samples, bonafide], ignore_index=True) + + def undersample_dataset(self): + samples = self.samples.groupby(by=["label"]) + bona_length = len(samples.groups["bonafide"]) + spoof_length = len(samples.groups["spoof"]) + + if spoof_length < bona_length: + raise NotImplementedError + + if spoof_length > bona_length: + spoofs = samples.get_group("spoof").sample(bona_length, replace=True) + self.samples = pd.concat( + [samples.get_group("bonafide"), spoofs], ignore_index=True + ) + + def get_bonafide_only(self): + samples = self.samples.groupby(by=["label"]) + self.samples = samples.get_group("bonafide") + return self.samples + + def get_spoof_only(self): + samples = self.samples.groupby(by=["label"]) + self.samples = samples.get_group("spoof") + return self.samples diff --git a/src/datasets/fakeavceleb_dataset.py b/src/datasets/fakeavceleb_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e1f6a001c24c32b1cae1691d77c1777154ce3b2f --- /dev/null +++ b/src/datasets/fakeavceleb_dataset.py @@ -0,0 +1,94 @@ +from pathlib import Path + +import pandas as pd + +from src.datasets.base_dataset import SimpleAudioFakeDataset + +FAKEAVCELEB_SPLIT = { + "train": ['faceswap-wav2lip', 'fsgan-wav2lip', 'wav2lip', 'rtvc'], + "test": ['faceswap-wav2lip', 'fsgan-wav2lip', 'wav2lip', 'rtvc'], + "val": ['faceswap-wav2lip', 'fsgan-wav2lip', 'wav2lip', 'rtvc'], + "partition_ratio": [0.7, 0.15], + "seed": 45 +} + + +class FakeAVCelebDataset(SimpleAudioFakeDataset): + + audio_folder = "FakeAVCeleb-audio" + audio_extension = ".mp3" + metadata_file = Path(audio_folder) / "meta_data.csv" + subsets = ("train", "dev", "eval") + + def __init__(self, path, subset="train", transform=None): + super().__init__(subset, transform) + self.path = path + + self.subset = subset + self.allowed_attacks = FAKEAVCELEB_SPLIT[subset] + self.partition_ratio = FAKEAVCELEB_SPLIT["partition_ratio"] + self.seed = FAKEAVCELEB_SPLIT["seed"] + + self.metadata = self.get_metadata() + + self.samples = pd.concat([self.get_fake_samples(), self.get_real_samples()], ignore_index=True) + + def get_metadata(self): + md = pd.read_csv(Path(self.path) / self.metadata_file) + md["audio_type"] = md["type"].apply(lambda x: x.split("-")[-1]) + return md + + def get_fake_samples(self): + samples = { + "user_id": [], + "sample_name": [], + "attack_type": [], + "label": [], + "path": [] + } + + for attack_name in self.allowed_attacks: + fake_samples = self.metadata[ + (self.metadata["method"] == attack_name) & (self.metadata["audio_type"] == "FakeAudio") + ] + + samples_list = fake_samples.iterrows() + samples_list = self.split_samples(samples_list) + + for _, sample in samples_list: + samples["user_id"].append(sample["source"]) + samples["sample_name"].append(Path(sample["filename"]).stem) + samples["attack_type"].append(sample["method"]) + samples["label"].append("spoof") + samples["path"].append(self.get_file_path(sample)) + + return pd.DataFrame(samples) + + def get_real_samples(self): + samples = { + "user_id": [], + "sample_name": [], + "attack_type": [], + "label": [], + "path": [] + } + + samples_list = self.metadata[ + (self.metadata["method"] == "real") & (self.metadata["audio_type"] == "RealAudio") + ] + + samples_list = self.split_samples(samples_list) + + for index, sample in samples_list.iterrows(): + samples["user_id"].append(sample["source"]) + 
samples["sample_name"].append(Path(sample["filename"]).stem) + samples["attack_type"].append("-") + samples["label"].append("bonafide") + samples["path"].append(self.get_file_path(sample)) + + return pd.DataFrame(samples) + + def get_file_path(self, sample): + path = "/".join([self.audio_folder, *sample["path"].split("/")[1:]]) + return Path(self.path) / path / Path(sample["filename"]).with_suffix(self.audio_extension) + diff --git a/src/datasets/folder_dataset.py b/src/datasets/folder_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..431571e30cf62805f5fdd8284df19221f6ee5aa2 --- /dev/null +++ b/src/datasets/folder_dataset.py @@ -0,0 +1,75 @@ +import numpy as np +import pandas as pd +import os +from pathlib import Path + +from src.datasets.base_dataset import SimpleAudioFakeDataset + + +class FolderDataset(SimpleAudioFakeDataset): + + def __init__( + self, + path, + subset="test", + transform=None, + ): + super().__init__(subset=subset, transform=transform) + self.path = path + self.samples = self.read_samples() + + + def read_samples(self): + path = Path(self.path) + print('ori path', path) + print('list', os.listdir(path)) + + samples = [] + for filepath in os.listdir(path): + samples.append({ + 'path': path / filepath, + 'label': '', + 'attack_type': '', + }) + + samples = pd.DataFrame(samples) + print('samples', samples) + return samples + + +class FileDataset(SimpleAudioFakeDataset): + + def __init__( + self, + path, + subset="test", + transform=None, + ): + super().__init__(subset=subset, transform=transform) + self.path = path + self.samples = self.read_samples() + + + def read_samples(self): + path = Path(self.path) + + samples = [{'path': path, 'label': '', 'attack_type':''}] + + samples = pd.DataFrame(samples) + print('samples', samples) + return samples + + +if __name__ == "__main__": + dataset = InTheWildDataset( + path="../datasets/release_in_the_wild", + subset="val", + seed=242, + split_strategy="per_speaker" + ) + + print(len(dataset)) + print(len(dataset.samples["user_id"].unique())) + print(dataset.samples["user_id"].unique()) + + print(dataset[0]) diff --git a/src/datasets/in_the_wild_dataset.py b/src/datasets/in_the_wild_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..3151ab37992f672a2d3bd7e881ccfa16587d136f --- /dev/null +++ b/src/datasets/in_the_wild_dataset.py @@ -0,0 +1,62 @@ +import numpy as np +import pandas as pd +from pathlib import Path + +from src.datasets.base_dataset import SimpleAudioFakeDataset + + +class InTheWildDataset(SimpleAudioFakeDataset): + + def __init__( + self, + path, + subset="train", + transform=None, + seed=None, + partition_ratio=(0.7, 0.15), + split_strategy="random" + ): + super().__init__(subset=subset, transform=transform) + self.path = path + self.read_samples() + self.partition_ratio = partition_ratio + self.seed = seed + + + def read_samples(self): + path = Path(self.path) + meta_path = path / "meta.csv" + + self.samples = pd.read_csv(meta_path) + self.samples["path"] = self.samples["file"].apply(lambda n: str(path / n)) + self.samples["file"] = self.samples["file"].apply(lambda n: Path(n).stem) + self.samples["label"] = self.samples["label"].map({"bona-fide": "bonafide", "spoof": "spoof"}) + self.samples["attack_type"] = self.samples["label"].map({"bonafide": "-", "spoof": "X"}) + self.samples.rename(columns={'file': 'sample_name', 'speaker': 'user_id'}, inplace=True) + + + def split_samples_per_speaker(self, samples): + speaker_list = 
pd.Series(samples["user_id"].unique()) + speaker_list = speaker_list.sort_values() + speaker_list = speaker_list.sample(frac=1, random_state=self.seed) + speaker_list = list(speaker_list) + + p, s = self.partition_ratio + subsets = np.split(speaker_list, [int(p * len(speaker_list)), int((p + s) * len(speaker_list))]) + speaker_subset = dict(zip(['train', 'test', 'val'], subsets))[self.subset] + return self.samples[self.samples["user_id"].isin(speaker_subset)] + + +if __name__ == "__main__": + dataset = InTheWildDataset( + path="../datasets/release_in_the_wild", + subset="val", + seed=242, + split_strategy="per_speaker" + ) + + print(len(dataset)) + print(len(dataset.samples["user_id"].unique())) + print(dataset.samples["user_id"].unique()) + + print(dataset[0]) diff --git a/src/datasets/wavefake_dataset.py b/src/datasets/wavefake_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..c9c3a706833b4f5daefebc9d872b7c8e292aed8f --- /dev/null +++ b/src/datasets/wavefake_dataset.py @@ -0,0 +1,85 @@ +from pathlib import Path + +import pandas as pd + +from src.datasets.base_dataset import SimpleAudioFakeDataset + +WAVEFAKE_SPLIT = { + "train": ['multi_band_melgan', 'melgan_large', 'parallel_wavegan', 'waveglow', 'full_band_melgan', 'melgan', 'hifiGAN'], + "test": ['multi_band_melgan', 'melgan_large', 'parallel_wavegan', 'waveglow', 'full_band_melgan', 'melgan', 'hifiGAN'], + "val": ['multi_band_melgan', 'melgan_large', 'parallel_wavegan', 'waveglow', 'full_band_melgan', 'melgan', 'hifiGAN'], + "partition_ratio": [0.7, 0.15], + "seed": 45 +} + + +class WaveFakeDataset(SimpleAudioFakeDataset): + + fake_data_path = "generated_audio" + jsut_real_data_path = "real_audio/jsut_ver1.1/basic5000/wav" + ljspeech_real_data_path = "real_audio/LJSpeech-1.1/wavs" + + def __init__(self, path, subset="train", transform=None): + super().__init__(subset, transform) + self.path = Path(path) + + self.fold_subset = subset + self.allowed_attacks = WAVEFAKE_SPLIT[subset] + self.partition_ratio = WAVEFAKE_SPLIT["partition_ratio"] + self.seed = WAVEFAKE_SPLIT["seed"] + + self.samples = pd.concat([self.get_fake_samples(), self.get_real_samples()], ignore_index=True) + + def get_fake_samples(self): + samples = { + "user_id": [], + "sample_name": [], + "attack_type": [], + "label": [], + "path": [] + } + + samples_list = list((self.path / self.fake_data_path).glob("*/*.wav")) + samples_list = self.filter_samples_by_attack(samples_list) + samples_list = self.split_samples(samples_list) + + for sample in samples_list: + samples["user_id"].append(None) + samples["sample_name"].append("_".join(sample.stem.split("_")[:-1])) + samples["attack_type"].append(self.get_attack_from_path(sample)) + samples["label"].append("spoof") + samples["path"].append(sample) + + return pd.DataFrame(samples) + + def filter_samples_by_attack(self, samples_list): + return [s for s in samples_list if self.get_attack_from_path(s) in self.allowed_attacks] + + def get_real_samples(self): + samples = { + "user_id": [], + "sample_name": [], + "attack_type": [], + "label": [], + "path": [] + } + + samples_list = list((self.path / self.jsut_real_data_path).glob("*.wav")) + samples_list += list((self.path / self.ljspeech_real_data_path).glob("*.wav")) + samples_list = self.split_samples(samples_list) + + for sample in samples_list: + samples["user_id"].append(None) + samples["sample_name"].append(sample.stem) + samples["attack_type"].append("-") + samples["label"].append("bonafide") + samples["path"].append(sample) + + return 
pd.DataFrame(samples) + + @staticmethod + def get_attack_from_path(path): + folder_name = path.parents[0].relative_to(path.parents[1]) + return str(folder_name).split("_", maxsplit=1)[-1] + + diff --git a/src/frontends.py b/src/frontends.py new file mode 100644 index 0000000000000000000000000000000000000000..95258ec2804190124503a5a406f5846161236308 --- /dev/null +++ b/src/frontends.py @@ -0,0 +1,72 @@ +from typing import List, Union, Callable + +import torch +import torchaudio + +SAMPLING_RATE = 16_000 +win_length = 400 # int((25 / 1_000) * SAMPLING_RATE) +hop_length = 160 # int((10 / 1_000) * SAMPLING_RATE) + +device = "cuda" if torch.cuda.is_available() else "cpu" + +MFCC_FN = torchaudio.transforms.MFCC( + sample_rate=SAMPLING_RATE, + n_mfcc=128, + melkwargs={ + "n_fft": 512, + "win_length": win_length, + "hop_length": hop_length, + }, +).to(device) + + +LFCC_FN = torchaudio.transforms.LFCC( + sample_rate=SAMPLING_RATE, + n_lfcc=128, + speckwargs={ + "n_fft": 512, + "win_length": win_length, + "hop_length": hop_length, + }, +).to(device) + +MEL_SCALE_FN = torchaudio.transforms.MelScale( + n_mels=80, + n_stft=257, + sample_rate=SAMPLING_RATE, +).to(device) + +delta_fn = torchaudio.transforms.ComputeDeltas( + win_length=400, + mode="replicate", +) + + +def get_frontend( + frontends: List[str], +) -> Union[torchaudio.transforms.MFCC, torchaudio.transforms.LFCC, Callable,]: + if "mfcc" in frontends: + return prepare_mfcc_double_delta + elif "lfcc" in frontends: + return prepare_lfcc_double_delta + raise ValueError(f"{frontends} frontend is not supported!") + + +def prepare_lfcc_double_delta(input): + if input.ndim < 4: + input = input.unsqueeze(1) # (bs, 1, n_lfcc, frames) + x = LFCC_FN(input) + delta = delta_fn(x) + double_delta = delta_fn(delta) + x = torch.cat((x, delta, double_delta), 2) # -> [bs, 1, 128 * 3, 1500] + return x[:, :, :, :3000] # (bs, n, n_lfcc * 3, frames) + + +def prepare_mfcc_double_delta(input): + if input.ndim < 4: + input = input.unsqueeze(1) # (bs, 1, n_lfcc, frames) + x = MFCC_FN(input) + delta = delta_fn(x) + double_delta = delta_fn(delta) + x = torch.cat((x, delta, double_delta), 2) # -> [bs, 1, 128 * 3, 1500] + return x[:, :, :, :3000] # (bs, n, n_lfcc * 3, frames) diff --git a/src/metrics.py b/src/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..3c299ee4809762ca446fae1d0e97352fa5ad3df7 --- /dev/null +++ b/src/metrics.py @@ -0,0 +1,15 @@ +from typing import Tuple + +import numpy as np +from scipy.interpolate import interp1d +from scipy.optimize import brentq +from sklearn.metrics import roc_curve +from sklearn.metrics import roc_curve + + +def calculate_eer(y, y_score) -> Tuple[float, float, np.ndarray, np.ndarray]: + fpr, tpr, thresholds = roc_curve(y, -y_score) + + eer = brentq(lambda x: 1.0 - x - interp1d(fpr, tpr)(x), 0.0, 1.0) + thresh = interp1d(fpr, thresholds)(eer) + return thresh, eer, fpr, tpr diff --git a/src/models/__init__.py b/src/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/models/assets/mel_filters.npz b/src/models/assets/mel_filters.npz new file mode 100644 index 0000000000000000000000000000000000000000..1a7839244dfb6b1cc02e4f3cfe12e4817a073bc7 Binary files /dev/null and b/src/models/assets/mel_filters.npz differ diff --git a/src/models/assets/tiny_enc.en.pt b/src/models/assets/tiny_enc.en.pt new file mode 100644 index 0000000000000000000000000000000000000000..411c61c740e1a38dc1ed145c878776c956c26879 --- /dev/null 
+++ b/src/models/assets/tiny_enc.en.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:206cca585e8ee06b813f958f72c548aebd489f125ef8949ad437f9fcc86e8cda +size 32853468 diff --git a/src/models/lcnn.py b/src/models/lcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..01c8aa7ef4950177b57f6c4a15d513008462af0d --- /dev/null +++ b/src/models/lcnn.py @@ -0,0 +1,247 @@ +""" +This code is modified version of LCNN baseline +from ASVSpoof2021 challenge - https://github.com/asvspoof-challenge/2021/blob/main/LA/Baseline-LFCC-LCNN/project/baseline_LA/model.py +""" +import sys + +import torch +import torch.nn as torch_nn + +from src import frontends + + +NUM_COEFFICIENTS = 384 + + +# For blstm +class BLSTMLayer(torch_nn.Module): + """ Wrapper over dilated conv1D + Input tensor: (batchsize=1, length, dim_in) + Output tensor: (batchsize=1, length, dim_out) + We want to keep the length the same + """ + def __init__(self, input_dim, output_dim): + super().__init__() + if output_dim % 2 != 0: + print("Output_dim of BLSTMLayer is {:d}".format(output_dim)) + print("BLSTMLayer expects a layer size of even number") + sys.exit(1) + # bi-directional LSTM + self.l_blstm = torch_nn.LSTM( + input_dim, + output_dim // 2, + bidirectional=True + ) + def forward(self, x): + # permute to (length, batchsize=1, dim) + blstm_data, _ = self.l_blstm(x.permute(1, 0, 2)) + # permute it backt to (batchsize=1, length, dim) + return blstm_data.permute(1, 0, 2) + + +class MaxFeatureMap2D(torch_nn.Module): + """ Max feature map (along 2D) + + MaxFeatureMap2D(max_dim=1) + + l_conv2d = MaxFeatureMap2D(1) + data_in = torch.rand([1, 4, 5, 5]) + data_out = l_conv2d(data_in) + + + Input: + ------ + data_in: tensor of shape (batch, channel, ...) + + Output: + ------- + data_out: tensor of shape (batch, channel//2, ...) + + Note + ---- + By default, Max-feature-map is on channel dimension, + and maxout is used on (channel ...) + """ + def __init__(self, max_dim = 1): + super().__init__() + self.max_dim = max_dim + + def forward(self, inputs): + # suppose inputs (batchsize, channel, length, dim) + + shape = list(inputs.size()) + + if self.max_dim >= len(shape): + print("MaxFeatureMap: maximize on %d dim" % (self.max_dim)) + print("But input has %d dimensions" % (len(shape))) + sys.exit(1) + if shape[self.max_dim] // 2 * 2 != shape[self.max_dim]: + print("MaxFeatureMap: maximize on %d dim" % (self.max_dim)) + print("But this dimension has an odd number of data") + sys.exit(1) + shape[self.max_dim] = shape[self.max_dim]//2 + shape.insert(self.max_dim, 2) + + # view to (batchsize, 2, channel//2, ...) 
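+ # e.g. an input of shape (batch, 64, H, W) is viewed as (batch, 2, 32, H, W) and reduced to (batch, 32, H, W)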
+ # maximize on the 2nd dim + m, i = inputs.view(*shape).max(self.max_dim) + return m + + +############## +## FOR MODEL +############## + +class LCNN(torch_nn.Module): + """ Model definition + """ + def __init__(self, **kwargs): + super().__init__() + input_channels = kwargs.get("input_channels", 1) + num_coefficients = kwargs.get("num_coefficients", NUM_COEFFICIENTS) + + # Working sampling rate + self.num_coefficients = num_coefficients + + # dimension of embedding vectors + # here, the embedding is just the activation before sigmoid() + self.v_emd_dim = 1 + + # it can handle models with multiple front-end configuration + # by default, only a single front-end + + self.m_transform = torch_nn.Sequential( + torch_nn.Conv2d(input_channels, 64, (5, 5), 1, padding=(2, 2)), + MaxFeatureMap2D(), + torch.nn.MaxPool2d((2, 2), (2, 2)), + + torch_nn.Conv2d(32, 64, (1, 1), 1, padding=(0, 0)), + MaxFeatureMap2D(), + torch_nn.BatchNorm2d(32, affine=False), + torch_nn.Conv2d(32, 96, (3, 3), 1, padding=(1, 1)), + MaxFeatureMap2D(), + + torch.nn.MaxPool2d((2, 2), (2, 2)), + torch_nn.BatchNorm2d(48, affine=False), + + torch_nn.Conv2d(48, 96, (1, 1), 1, padding=(0, 0)), + MaxFeatureMap2D(), + torch_nn.BatchNorm2d(48, affine=False), + torch_nn.Conv2d(48, 128, (3, 3), 1, padding=(1, 1)), + MaxFeatureMap2D(), + + torch.nn.MaxPool2d((2, 2), (2, 2)), + + torch_nn.Conv2d(64, 128, (1, 1), 1, padding=(0, 0)), + MaxFeatureMap2D(), + torch_nn.BatchNorm2d(64, affine=False), + torch_nn.Conv2d(64, 64, (3, 3), 1, padding=(1, 1)), + MaxFeatureMap2D(), + torch_nn.BatchNorm2d(32, affine=False), + + torch_nn.Conv2d(32, 64, (1, 1), 1, padding=(0, 0)), + MaxFeatureMap2D(), + torch_nn.BatchNorm2d(32, affine=False), + torch_nn.Conv2d(32, 64, (3, 3), 1, padding=(1, 1)), + MaxFeatureMap2D(), + torch_nn.MaxPool2d((2, 2), (2, 2)), + + torch_nn.Dropout(0.7) + ) + + self.m_before_pooling = torch_nn.Sequential( + BLSTMLayer((self.num_coefficients//16) * 32, (self.num_coefficients//16) * 32), + BLSTMLayer((self.num_coefficients//16) * 32, (self.num_coefficients//16) * 32) + ) + + self.m_output_act = torch_nn.Linear((self.num_coefficients // 16) * 32, self.v_emd_dim) + + def _compute_embedding(self, x): + """ definition of forward method + Assume x (batchsize, length, dim) + Output x (batchsize * number_filter, output_dim) + """ + # resample if necessary + # x = self.m_resampler(x.squeeze(-1)).unsqueeze(-1) + + # number of sub models + batch_size = x.shape[0] + + # buffer to store output scores from sub-models + output_emb = torch.zeros( + [batch_size, self.v_emd_dim], + device=x.device, + dtype=x.dtype + ) + + # compute scores for each sub-models + idx = 0 + + # compute scores + # 1. unsqueeze to (batch, 1, frame_length, fft_bin) + # 2. compute hidden features + x = x.permute(0,1,3,2) + hidden_features = self.m_transform(x) + + # 3. (batch, channel, frame//N, feat_dim//N) -> + # (batch, frame//N, channel * feat_dim//N) + # where N is caused by conv with stride + hidden_features = hidden_features.permute(0, 2, 1, 3).contiguous() + frame_num = hidden_features.shape[1] + + hidden_features = hidden_features.view(batch_size, frame_num, -1) + # 4. pooling + # 4. pass through LSTM then summingc + hidden_features_lstm = self.m_before_pooling(hidden_features) + + # 5. 
pass through the output layer + tmp_emb = self.m_output_act((hidden_features_lstm + hidden_features).mean(1)) + output_emb[idx * batch_size : (idx+1) * batch_size] = tmp_emb + + return output_emb + + def _compute_score(self, feature_vec): + # feature_vec is [batch * submodel, 1] + return torch.sigmoid(feature_vec).squeeze(1) + + def forward(self, x): + feature_vec = self._compute_embedding(x) + return feature_vec + + + +class FrontendLCNN(LCNN): + """ Model definition + """ + def __init__(self, device: str = "cuda", **kwargs): + super().__init__(**kwargs) + + self.device = device + + frontend_name = kwargs.get("frontend_algorithm", []) + self.frontend = frontends.get_frontend(frontend_name) + print(f"Using {frontend_name} frontend") + + def _compute_frontend(self, x): + frontend = self.frontend(x) + if frontend.ndim < 4: + return frontend.unsqueeze(1) # (bs, 1, n_lfcc, frames) + return frontend # (bs, n, n_lfcc, frames) + + def forward(self, x): + x = self._compute_frontend(x) + feature_vec = self._compute_embedding(x) + + return feature_vec + + +if __name__ == "__main__": + + device = "cuda" + print("Definition of model") + model = FrontendLCNN(input_channels=2, num_coefficients=80, device=device, frontend_algorithm=["mel_spec"]) + model = model.to(device) + batch_size = 12 + mock_input = torch.rand((batch_size, 64_600,), device=device) + output = model(mock_input) + print(output.shape) diff --git a/src/models/meso_net.py b/src/models/meso_net.py new file mode 100644 index 0000000000000000000000000000000000000000..bc9035dded40a2e520da1a190ffa822127085e00 --- /dev/null +++ b/src/models/meso_net.py @@ -0,0 +1,146 @@ +""" +This code is modified version of MesoNet DeepFake detection solution +from FakeAVCeleb repository - https://github.com/DASH-Lab/FakeAVCeleb/blob/main/models/MesoNet.py. 
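+The backbone is reused here on 2D audio front-end features (e.g. LFCC) rather than RGB face crops, so input_channels is configurable instead of the fixed 3 channels.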
+""" +import torch +import torch.nn as nn + +from src import frontends + + +class MesoInception4(nn.Module): + """ + Pytorch Implemention of MesoInception4 + Author: Honggu Liu + Date: July 7, 2019 + """ + def __init__(self, num_classes=1, **kwargs): + super().__init__() + + self.fc1_dim = kwargs.get("fc1_dim", 1024) + input_channels = kwargs.get("input_channels", 3) + self.num_classes = num_classes + + #InceptionLayer1 + self.Incption1_conv1 = nn.Conv2d(input_channels, 1, 1, padding=0, bias=False) + self.Incption1_conv2_1 = nn.Conv2d(input_channels, 4, 1, padding=0, bias=False) + self.Incption1_conv2_2 = nn.Conv2d(4, 4, 3, padding=1, bias=False) + self.Incption1_conv3_1 = nn.Conv2d(input_channels, 4, 1, padding=0, bias=False) + self.Incption1_conv3_2 = nn.Conv2d(4, 4, 3, padding=2, dilation=2, bias=False) + self.Incption1_conv4_1 = nn.Conv2d(input_channels, 2, 1, padding=0, bias=False) + self.Incption1_conv4_2 = nn.Conv2d(2, 2, 3, padding=3, dilation=3, bias=False) + self.Incption1_bn = nn.BatchNorm2d(11) + + + #InceptionLayer2 + self.Incption2_conv1 = nn.Conv2d(11, 2, 1, padding=0, bias=False) + self.Incption2_conv2_1 = nn.Conv2d(11, 4, 1, padding=0, bias=False) + self.Incption2_conv2_2 = nn.Conv2d(4, 4, 3, padding=1, bias=False) + self.Incption2_conv3_1 = nn.Conv2d(11, 4, 1, padding=0, bias=False) + self.Incption2_conv3_2 = nn.Conv2d(4, 4, 3, padding=2, dilation=2, bias=False) + self.Incption2_conv4_1 = nn.Conv2d(11, 2, 1, padding=0, bias=False) + self.Incption2_conv4_2 = nn.Conv2d(2, 2, 3, padding=3, dilation=3, bias=False) + self.Incption2_bn = nn.BatchNorm2d(12) + + #Normal Layer + self.conv1 = nn.Conv2d(12, 16, 5, padding=2, bias=False) + self.relu = nn.ReLU(inplace=True) + self.leakyrelu = nn.LeakyReLU(0.1) + self.bn1 = nn.BatchNorm2d(16) + self.maxpooling1 = nn.MaxPool2d(kernel_size=(2, 2)) + + self.conv2 = nn.Conv2d(16, 16, 5, padding=2, bias=False) + self.maxpooling2 = nn.MaxPool2d(kernel_size=(4, 4)) + + self.dropout = nn.Dropout2d(0.5) + self.fc1 = nn.Linear(self.fc1_dim, 16) + self.fc2 = nn.Linear(16, num_classes) + + + #InceptionLayer + def InceptionLayer1(self, input): + x1 = self.Incption1_conv1(input) + x2 = self.Incption1_conv2_1(input) + x2 = self.Incption1_conv2_2(x2) + x3 = self.Incption1_conv3_1(input) + x3 = self.Incption1_conv3_2(x3) + x4 = self.Incption1_conv4_1(input) + x4 = self.Incption1_conv4_2(x4) + y = torch.cat((x1, x2, x3, x4), 1) + y = self.Incption1_bn(y) + y = self.maxpooling1(y) + + return y + + def InceptionLayer2(self, input): + x1 = self.Incption2_conv1(input) + x2 = self.Incption2_conv2_1(input) + x2 = self.Incption2_conv2_2(x2) + x3 = self.Incption2_conv3_1(input) + x3 = self.Incption2_conv3_2(x3) + x4 = self.Incption2_conv4_1(input) + x4 = self.Incption2_conv4_2(x4) + y = torch.cat((x1, x2, x3, x4), 1) + y = self.Incption2_bn(y) + y = self.maxpooling1(y) + + return y + + def forward(self, input): + x = self._compute_embedding(input) + return x + + def _compute_embedding(self, input): + x = self.InceptionLayer1(input) #(Batch, 11, 128, 128) + x = self.InceptionLayer2(x) #(Batch, 12, 64, 64) + + x = self.conv1(x) #(Batch, 16, 64 ,64) + x = self.relu(x) + x = self.bn1(x) + x = self.maxpooling1(x) #(Batch, 16, 32, 32) + + x = self.conv2(x) #(Batch, 16, 32, 32) + x = self.relu(x) + x = self.bn1(x) + x = self.maxpooling2(x) #(Batch, 16, 8, 8) + + x = x.view(x.size(0), -1) #(Batch, 16*8*8) + x = self.dropout(x) + + x = nn.AdaptiveAvgPool1d(self.fc1_dim)(x) + x = self.fc1(x) #(Batch, 16) ### <-- o tu + x = self.leakyrelu(x) + x = self.dropout(x) + x = 
self.fc2(x) + return x + + +class FrontendMesoInception4(MesoInception4): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + self.device = kwargs['device'] + + frontend_name = kwargs.get("frontend_algorithm", []) + self.frontend = frontends.get_frontend(frontend_name) + print(f"Using {frontend_name} frontend") + + def forward(self, x): + x = self.frontend(x) + x = self._compute_embedding(x) + return x + + +if __name__ == "__main__": + model = FrontendMesoInception4( + input_channels=2, + fc1_dim=1024, + device='cuda', + frontend_algorithm="lfcc" + ) + + def count_parameters(model) -> int: + pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + return pytorch_total_params + print(count_parameters(model)) \ No newline at end of file diff --git a/src/models/models.py b/src/models/models.py new file mode 100644 index 0000000000000000000000000000000000000000..bd1a43d5c7984e7bdebe182873cae71215e9bd7c --- /dev/null +++ b/src/models/models.py @@ -0,0 +1,73 @@ +from typing import Dict + +from src.models import ( + lcnn, + specrnet, + whisper_specrnet, + rawnet3, + whisper_lcnn, + meso_net, + whisper_meso_net +) + + +def get_model(model_name: str, config: Dict, device: str): + if model_name == "rawnet3": + return rawnet3.prepare_model() + elif model_name == "lcnn": + return lcnn.FrontendLCNN(device=device, **config) + elif model_name == "specrnet": + return specrnet.FrontendSpecRNet( + device=device, + **config, + ) + elif model_name == "mesonet": + return meso_net.FrontendMesoInception4( + input_channels=config.get("input_channels", 1), + fc1_dim=config.get("fc1_dim", 1024), + frontend_algorithm=config.get("frontend_algorithm", "lfcc"), + device=device, + ) + elif model_name == "whisper_lcnn": + return whisper_lcnn.WhisperLCNN( + input_channels=config.get("input_channels", 1), + freeze_encoder=config.get("freeze_encoder", False), + device=device, + ) + elif model_name == "whisper_specrnet": + return whisper_specrnet.WhisperSpecRNet( + input_channels=config.get("input_channels", 1), + freeze_encoder=config.get("freeze_encoder", False), + device=device, + ) + elif model_name == "whisper_mesonet": + return whisper_meso_net.WhisperMesoNet( + input_channels=config.get("input_channels", 1), + freeze_encoder=config.get("freeze_encoder", True), + fc1_dim=config.get("fc1_dim", 1024), + device=device, + ) + elif model_name == "whisper_frontend_lcnn": + return whisper_lcnn.WhisperMultiFrontLCNN( + input_channels=config.get("input_channels", 2), + freeze_encoder=config.get("freeze_encoder", False), + frontend_algorithm=config.get("frontend_algorithm", "lfcc"), + device=device, + ) + elif model_name == "whisper_frontend_specrnet": + return whisper_specrnet.WhisperMultiFrontSpecRNet( + input_channels=config.get("input_channels", 2), + freeze_encoder=config.get("freeze_encoder", False), + frontend_algorithm=config.get("frontend_algorithm", "lfcc"), + device=device, + ) + elif model_name == "whisper_frontend_mesonet": + return whisper_meso_net.WhisperMultiFrontMesoNet( + input_channels=config.get("input_channels", 2), + fc1_dim=config.get("fc1_dim", 1024), + freeze_encoder=config.get("freeze_encoder", True), + frontend_algorithm=config.get("frontend_algorithm", "lfcc"), + device=device, + ) + else: + raise ValueError(f"Model '{model_name}' not supported") diff --git a/src/models/rawnet3.py b/src/models/rawnet3.py new file mode 100644 index 0000000000000000000000000000000000000000..0ddd68c1ba1f7370e352d69258bd1747510fcd91 --- /dev/null +++ b/src/models/rawnet3.py @@ 
-0,0 +1,323 @@ +""" +This file contains implementation of RawNet3 architecture. +The original implementation can be found here: https://github.com/Jungjee/RawNet/tree/master/python/RawNet3 +""" +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from asteroid_filterbanks import Encoder, ParamSincFB # pip install asteroid_filterbanks + + +class RawNet3(nn.Module): + def __init__(self, block, model_scale, context, summed, C=1024, **kwargs): + super().__init__() + + nOut = kwargs["nOut"] + + self.context = context + self.encoder_type = kwargs["encoder_type"] + self.log_sinc = kwargs["log_sinc"] + self.norm_sinc = kwargs["norm_sinc"] + self.out_bn = kwargs["out_bn"] + self.summed = summed + + self.preprocess = nn.Sequential( + PreEmphasis(), nn.InstanceNorm1d(1, eps=1e-4, affine=True) + ) + self.conv1 = Encoder( + ParamSincFB( + C // 4, + 251, + stride=kwargs["sinc_stride"], + ) + ) + self.relu = nn.ReLU() + self.bn1 = nn.BatchNorm1d(C // 4) + + self.layer1 = block( + C // 4, C, kernel_size=3, dilation=2, scale=model_scale, pool=5 + ) + self.layer2 = block( + C, C, kernel_size=3, dilation=3, scale=model_scale, pool=3 + ) + self.layer3 = block(C, C, kernel_size=3, dilation=4, scale=model_scale) + self.layer4 = nn.Conv1d(3 * C, 1536, kernel_size=1) + + if self.context: + attn_input = 1536 * 3 + else: + attn_input = 1536 + print("self.encoder_type", self.encoder_type) + if self.encoder_type == "ECA": + attn_output = 1536 + elif self.encoder_type == "ASP": + attn_output = 1 + else: + raise ValueError("Undefined encoder") + + self.attention = nn.Sequential( + nn.Conv1d(attn_input, 128, kernel_size=1), + nn.ReLU(), + nn.BatchNorm1d(128), + nn.Conv1d(128, attn_output, kernel_size=1), + nn.Softmax(dim=2), + ) + + self.bn5 = nn.BatchNorm1d(3072) + + self.fc6 = nn.Linear(3072, nOut) + self.bn6 = nn.BatchNorm1d(nOut) + + self.mp3 = nn.MaxPool1d(3) + + def forward(self, x): + """ + :param x: input mini-batch (bs, samp) + """ + + with torch.cuda.amp.autocast(enabled=False): + x = self.preprocess(x) + x = torch.abs(self.conv1(x)) + if self.log_sinc: + x = torch.log(x + 1e-6) + if self.norm_sinc == "mean": + x = x - torch.mean(x, dim=-1, keepdim=True) + elif self.norm_sinc == "mean_std": + m = torch.mean(x, dim=-1, keepdim=True) + s = torch.std(x, dim=-1, keepdim=True) + s[s < 0.001] = 0.001 + x = (x - m) / s + + if self.summed: + x1 = self.layer1(x) + x2 = self.layer2(x1) + x3 = self.layer3(self.mp3(x1) + x2) + else: + x1 = self.layer1(x) + x2 = self.layer2(x1) + x3 = self.layer3(x2) + + x = self.layer4(torch.cat((self.mp3(x1), x2, x3), dim=1)) + x = self.relu(x) + + t = x.size()[-1] + + if self.context: + global_x = torch.cat( + ( + x, + torch.mean(x, dim=2, keepdim=True).repeat(1, 1, t), + torch.sqrt( + torch.var(x, dim=2, keepdim=True).clamp( + min=1e-4, max=1e4 + ) + ).repeat(1, 1, t), + ), + dim=1, + ) + else: + global_x = x + + w = self.attention(global_x) + + mu = torch.sum(x * w, dim=2) + sg = torch.sqrt( + (torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-4, max=1e4) + ) + + x = torch.cat((mu, sg), 1) + + x = self.bn5(x) + + x = self.fc6(x) + + if self.out_bn: + x = self.bn6(x) + + return x + + +class PreEmphasis(torch.nn.Module): + def __init__(self, coef: float = 0.97) -> None: + super().__init__() + self.coef = coef + # make kernel + # In pytorch, the convolution operation uses cross-correlation. So, filter is flipped. 
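+ # pre-emphasis y[t] = x[t] - coef * x[t-1], realised as a fixed conv1d kernel [-coef, 1.0]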
+ self.register_buffer( + "flipped_filter", + torch.FloatTensor([-self.coef, 1.0]).unsqueeze(0).unsqueeze(0), + ) + + def forward(self, input: torch.tensor) -> torch.tensor: + assert ( + len(input.size()) == 2 + ), "The number of dimensions of input tensor must be 2!" + # reflect padding to match lengths of in/out + input = input.unsqueeze(1) + input = F.pad(input, (1, 0), "reflect") + return F.conv1d(input, self.flipped_filter) + + +class AFMS(nn.Module): + """ + Alpha-Feature map scaling, added to the output of each residual block[1,2]. + + Reference: + [1] RawNet2 : https://www.isca-speech.org/archive/Interspeech_2020/pdfs/1011.pdf + [2] AMFS : https://www.koreascience.or.kr/article/JAKO202029757857763.page + """ + + def __init__(self, nb_dim: int) -> None: + super().__init__() + self.alpha = nn.Parameter(torch.ones((nb_dim, 1))) + self.fc = nn.Linear(nb_dim, nb_dim) + self.sig = nn.Sigmoid() + + def forward(self, x): + y = F.adaptive_avg_pool1d(x, 1).view(x.size(0), -1) + y = self.sig(self.fc(y)).view(x.size(0), x.size(1), -1) + + x = x + self.alpha + x = x * y + return x + + +class Bottle2neck(nn.Module): + def __init__( + self, + inplanes, + planes, + kernel_size=None, + dilation=None, + scale=4, + pool=False, + ): + + super().__init__() + + width = int(math.floor(planes / scale)) + + self.conv1 = nn.Conv1d(inplanes, width * scale, kernel_size=1) + self.bn1 = nn.BatchNorm1d(width * scale) + + self.nums = scale - 1 + + convs = [] + bns = [] + + num_pad = math.floor(kernel_size / 2) * dilation + + for i in range(self.nums): + convs.append( + nn.Conv1d( + width, + width, + kernel_size=kernel_size, + dilation=dilation, + padding=num_pad, + ) + ) + bns.append(nn.BatchNorm1d(width)) + + self.convs = nn.ModuleList(convs) + self.bns = nn.ModuleList(bns) + + self.conv3 = nn.Conv1d(width * scale, planes, kernel_size=1) + self.bn3 = nn.BatchNorm1d(planes) + + self.relu = nn.ReLU() + + self.width = width + + self.mp = nn.MaxPool1d(pool) if pool else False + self.afms = AFMS(planes) + + if inplanes != planes: # if change in number of filters + self.residual = nn.Sequential( + nn.Conv1d(inplanes, planes, kernel_size=1, stride=1, bias=False) + ) + else: + self.residual = nn.Identity() + + def forward(self, x): + residual = self.residual(x) + + out = self.conv1(x) + out = self.relu(out) + out = self.bn1(out) + + spx = torch.split(out, self.width, 1) + for i in range(self.nums): + if i == 0: + sp = spx[i] + else: + sp = sp + spx[i] + sp = self.convs[i](sp) + sp = self.relu(sp) + sp = self.bns[i](sp) + if i == 0: + out = sp + else: + out = torch.cat((out, sp), 1) + + out = torch.cat((out, spx[self.nums]), 1) + + out = self.conv3(out) + out = self.relu(out) + out = self.bn3(out) + + out += residual + if self.mp: + out = self.mp(out) + out = self.afms(out) + + return out + + +def prepare_model(): + model = RawNet3( + Bottle2neck, + model_scale=8, + context=True, + summed=True, + encoder_type="ECA", + nOut=1, # number of slices + out_bn=False, + sinc_stride=10, + log_sinc=True, + norm_sinc="mean", + grad_mult=1, + ) + return model + + +if __name__ == "__main__": + model = RawNet3( + Bottle2neck, + model_scale=8, + context=True, + summed=True, + encoder_type="ECA", + nOut=1, # number of slices + out_bn=False, + sinc_stride=10, + log_sinc=True, + norm_sinc="mean", + grad_mult=1, + ) + gpu = False + + model.eval() + print("RawNet3 initialised & weights loaded!") + + if torch.cuda.is_available(): + print("Cuda available, conducting inference on GPU") + model = model.to("cuda") + gpu = True + + audios = 
torch.rand(32, 64_600) + + out = model(audios) + print(out.shape) diff --git a/src/models/specrnet.py b/src/models/specrnet.py new file mode 100644 index 0000000000000000000000000000000000000000..98bdeadb3069dfc976e7c5e7e812a0537c9d75d1 --- /dev/null +++ b/src/models/specrnet.py @@ -0,0 +1,226 @@ +""" +This file contains implementation of SpecRNet architecture. +We base our codebase on the implementation of RawNet2 by Hemlata Tak (tak@eurecom.fr). +It is available here: https://github.com/asvspoof-challenge/2021/blob/main/LA/Baseline-RawNet2/model.py +""" +from typing import Dict + +import torch.nn as nn + +from src import frontends + + +def get_config(input_channels: int) -> Dict: + return { + "filts": [input_channels, [input_channels, 20], [20, 64], [64, 64]], + "nb_fc_node": 64, + "gru_node": 64, + "nb_gru_layer": 2, + "nb_classes": 1, + } + + +class Residual_block2D(nn.Module): + def __init__(self, nb_filts, first=False): + super().__init__() + self.first = first + + if not self.first: + self.bn1 = nn.BatchNorm2d(num_features=nb_filts[0]) + + self.lrelu = nn.LeakyReLU(negative_slope=0.3) + + self.conv1 = nn.Conv2d( + in_channels=nb_filts[0], + out_channels=nb_filts[1], + kernel_size=3, + padding=1, + stride=1, + ) + + self.bn2 = nn.BatchNorm2d(num_features=nb_filts[1]) + self.conv2 = nn.Conv2d( + in_channels=nb_filts[1], + out_channels=nb_filts[1], + padding=1, + kernel_size=3, + stride=1, + ) + + if nb_filts[0] != nb_filts[1]: + self.downsample = True + self.conv_downsample = nn.Conv2d( + in_channels=nb_filts[0], + out_channels=nb_filts[1], + padding=0, + kernel_size=1, + stride=1, + ) + + else: + self.downsample = False + self.mp = nn.MaxPool2d(2) + + def forward(self, x): + identity = x + if not self.first: + out = self.bn1(x) + out = self.lrelu(out) + else: + out = x + + out = self.conv1(x) + out = self.bn2(out) + out = self.lrelu(out) + out = self.conv2(out) + + if self.downsample: + identity = self.conv_downsample(identity) + + out += identity + out = self.mp(out) + return out + + +class SpecRNet(nn.Module): + def __init__(self, input_channels, **kwargs): + super().__init__() + config = get_config(input_channels=input_channels) + + self.device = kwargs.get("device", "cuda") + + self.first_bn = nn.BatchNorm2d(num_features=config["filts"][0]) + self.selu = nn.SELU(inplace=True) + self.block0 = nn.Sequential( + Residual_block2D(nb_filts=config["filts"][1], first=True) + ) + self.block2 = nn.Sequential(Residual_block2D(nb_filts=config["filts"][2])) + config["filts"][2][0] = config["filts"][2][1] + self.block4 = nn.Sequential(Residual_block2D(nb_filts=config["filts"][2])) + self.avgpool = nn.AdaptiveAvgPool2d(1) + + self.fc_attention0 = self._make_attention_fc( + in_features=config["filts"][1][-1], l_out_features=config["filts"][1][-1] + ) + self.fc_attention2 = self._make_attention_fc( + in_features=config["filts"][2][-1], l_out_features=config["filts"][2][-1] + ) + self.fc_attention4 = self._make_attention_fc( + in_features=config["filts"][2][-1], l_out_features=config["filts"][2][-1] + ) + + self.bn_before_gru = nn.BatchNorm2d(num_features=config["filts"][2][-1]) + self.gru = nn.GRU( + input_size=config["filts"][2][-1], + hidden_size=config["gru_node"], + num_layers=config["nb_gru_layer"], + batch_first=True, + bidirectional=True, + ) + + self.fc1_gru = nn.Linear( + in_features=config["gru_node"] * 2, out_features=config["nb_fc_node"] * 2 + ) + + self.fc2_gru = nn.Linear( + in_features=config["nb_fc_node"] * 2, + out_features=config["nb_classes"], + bias=True, + ) + + self.sig = 
nn.Sigmoid() + + def _compute_embedding(self, x): + x = self.first_bn(x) + x = self.selu(x) + + x0 = self.block0(x) + y0 = self.avgpool(x0).view(x0.size(0), -1) + y0 = self.fc_attention0(y0) + y0 = self.sig(y0).view(y0.size(0), y0.size(1), -1) + y0 = y0.unsqueeze(-1) + x = x0 * y0 + y0 + + x = nn.MaxPool2d(2)(x) + + x2 = self.block2(x) + y2 = self.avgpool(x2).view(x2.size(0), -1) + y2 = self.fc_attention2(y2) + y2 = self.sig(y2).view(y2.size(0), y2.size(1), -1) + y2 = y2.unsqueeze(-1) + x = x2 * y2 + y2 + + x = nn.MaxPool2d(2)(x) + + x4 = self.block4(x) + y4 = self.avgpool(x4).view(x4.size(0), -1) + y4 = self.fc_attention4(y4) + y4 = self.sig(y4).view(y4.size(0), y4.size(1), -1) + y4 = y4.unsqueeze(-1) + x = x4 * y4 + y4 + + x = nn.MaxPool2d(2)(x) + + x = self.bn_before_gru(x) + x = self.selu(x) + x = nn.AdaptiveAvgPool2d((1, None))(x) + x = x.squeeze(-2) + x = x.permute(0, 2, 1) + self.gru.flatten_parameters() + x, _ = self.gru(x) + x = x[:, -1, :] + x = self.fc1_gru(x) + x = self.fc2_gru(x) + return x + + def forward(self, x): + x = self._compute_embedding(x) + return x + + def _make_attention_fc(self, in_features, l_out_features): + l_fc = [] + l_fc.append(nn.Linear(in_features=in_features, out_features=l_out_features)) + return nn.Sequential(*l_fc) + + +class FrontendSpecRNet(SpecRNet): + def __init__(self, input_channels, **kwargs): + super().__init__(input_channels, **kwargs) + + self.device = kwargs['device'] + + frontend_name = kwargs.get("frontend_algorithm", []) + self.frontend = frontends.get_frontend(frontend_name) + print(f"Using {frontend_name} frontend") + + def _compute_frontend(self, x): + frontend = self.frontend(x) + if frontend.ndim < 4: + return frontend.unsqueeze(1) # (bs, 1, n_lfcc, frames) + return frontend # (bs, n, n_lfcc, frames) + + def forward(self, x): + x = self._compute_frontend(x) + x = self._compute_embedding(x) + return x + + +if __name__ == "__main__": + print("Definition of model") + device = "cuda" + + input_channels = 1 + config = { + "filts": [input_channels, [input_channels, 20], [20, 64], [64, 64]], + "nb_fc_node": 64, + "gru_node": 64, + "nb_gru_layer": 2, + "nb_classes": 1, + } + + def count_parameters(model) -> int: + pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + return pytorch_total_params + model = FrontendSpecRNet(input_channels=1, device=device, frontend_algorithm=["lfcc"]) + model = model.to(device) + print(count_parameters(model)) diff --git a/src/models/whisper_lcnn.py b/src/models/whisper_lcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..8752cee28a7502de4e6a0ccff0fd9e77fe121f53 --- /dev/null +++ b/src/models/whisper_lcnn.py @@ -0,0 +1,89 @@ +import torch + +from src.models.whisper_main import ModelDimensions, Whisper, log_mel_spectrogram +from src.models.lcnn import LCNN +from src import frontends +from src.commons import WHISPER_MODEL_WEIGHTS_PATH + + +class WhisperLCNN(LCNN): + + def __init__(self, input_channels, freeze_encoder, **kwargs): + super().__init__(input_channels=input_channels, **kwargs) + + self.device = kwargs['device'] + checkpoint = torch.load(WHISPER_MODEL_WEIGHTS_PATH) + dims = ModelDimensions(**checkpoint["dims"].__dict__) + model = Whisper(dims) + model = model.to(self.device) + model.load_state_dict(checkpoint["model_state_dict"]) + self.whisper_model = model + if freeze_encoder: + for param in self.whisper_model.parameters(): + param.requires_grad = False + + def compute_whisper_features(self, x): + specs = [] + for sample in x: + 
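# each (assumed 30 s, 16 kHz) waveform is converted to a log-Mel spectrogram for the Whisper encoder + 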
specs.append(log_mel_spectrogram(sample)) + x = torch.stack(specs) + x = self.whisper_model(x) + + x = x.permute(0, 2, 1) # (bs, frames, 3 x n_lfcc) + x = x.unsqueeze(1) # (bs, 1, frames, 3 x n_lfcc) + x = x.repeat( + (1, 1, 1, 2) + ) # (bs, 1, frames, 3 x n_lfcc) -> (bs, 1, frames, 3000) + return x + + def forward(self, x): + # we assume that the data is correct (i.e. 30s) + x = self.compute_whisper_features(x) + out = self._compute_embedding(x) + return out + + +class WhisperMultiFrontLCNN(WhisperLCNN): + + def __init__(self, input_channels, freeze_encoder, **kwargs): + super().__init__(input_channels=input_channels, freeze_encoder=freeze_encoder, **kwargs) + + self.frontend = frontends.get_frontend(kwargs['frontend_algorithm']) + print(f"Using {self.frontend} frontend!") + + def forward(self, x): + # Frontend computation + frontend_x = self.frontend(x) + x = self.compute_whisper_features(x) + + x = torch.cat([x, frontend_x], 1) + out = self._compute_embedding(x) + return out + + +if __name__ == "__main__": + import numpy as np + + input_channels = 1 + device = "cpu" + classifier = WhisperLCNN( + input_channels=input_channels, + freeze_encoder=True, + device=device, + ) + + input_channels = 2 + classifier_2 = WhisperMultiFrontLCNN( + input_channels=input_channels, + freeze_encoder=True, + device=device, + frontend_algorithm="lfcc" + ) + x = np.random.rand(2, 30 * 16_000).astype(np.float32) + x = torch.from_numpy(x) + + out = classifier(x) + print(out.shape) + + out = classifier_2(x) + print(out.shape) diff --git a/src/models/whisper_main.py b/src/models/whisper_main.py new file mode 100644 index 0000000000000000000000000000000000000000..9d7ba5312f064956385ce42ccba9792a5126f94a --- /dev/null +++ b/src/models/whisper_main.py @@ -0,0 +1,323 @@ +# Based on https://github.com/openai/whisper/blob/main/whisper/model.py +from dataclasses import dataclass +from functools import lru_cache +import os +from typing import Iterable, Optional, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor +from torch import nn + + +def exact_div(x, y): + assert x % y == 0 + return x // y + + +# hard-coded audio hyperparameters +SAMPLE_RATE = 16000 +N_FFT = 400 +N_MELS = 80 +HOP_LENGTH = 160 +CHUNK_LENGTH = 30 +N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk +N_FRAMES = exact_div( + N_SAMPLES, HOP_LENGTH +) # 3000: number of frames in a mel spectrogram input + + +def pad_or_trim( + array: Union[torch.Tensor, np.ndarray], + length: int = N_SAMPLES, + *, + axis: int = -1, +) -> torch.Tensor: + """ + Pad or trim the audio array to N_SAMPLES, as expected by the encoder. + """ + if not torch.is_tensor(array): + array = torch.from_numpy(array) + + if array.shape[axis] > length: + array = array.index_select( + dim=axis, index=torch.arange(length, device=array.device) + ) + + if array.shape[axis] < length: + # pad multiple times + num_repeats = int(length / array.shape[axis]) + 1 + array = torch.tile(array, (1, num_repeats))[:, :length] + return array + + +@lru_cache(maxsize=None) +def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor: + """ + load the mel filterbank matrix for projecting STFT into a Mel spectrogram. 
+ Allows decoupling librosa dependency; saved using: + + np.savez_compressed( + "mel_filters.npz", + mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), + ) + """ + assert n_mels == 80, f"Unsupported n_mels: {n_mels}" + with np.load( + os.path.join(os.path.dirname(__file__), "assets/mel_filters.npz") + ) as f: + return torch.from_numpy(f[f"mel_{n_mels}"]).to(device) + + +def log_mel_spectrogram(audio: torch.Tensor, n_mels: int = N_MELS): + """ + Compute the log-Mel spectrogram of + + Parameters + ---------- + audio: Union[str, np.ndarray, torch.Tensor], shape = (*) + The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz + + n_mels: int + The number of Mel-frequency filters, only 80 is supported + + Returns + ------- + torch.Tensor, shape = (80, n_frames) + A Tensor that contains the Mel spectrogram + """ + window = torch.hann_window(N_FFT).to(audio.device) + stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) + magnitudes = stft[:, :-1].abs() ** 2 + + filters = mel_filters(audio.device, n_mels) + mel_spec = filters @ magnitudes + + log_spec = torch.clamp(mel_spec, min=1e-10).log10() + log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + return log_spec + + +@dataclass +class ModelDimensions: + n_mels: int + n_audio_ctx: int + n_audio_state: int + n_audio_head: int + n_audio_layer: int + n_vocab: int + n_text_ctx: int + n_text_state: int + n_text_head: int + n_text_layer: int + + +class LayerNorm(nn.LayerNorm): + def forward(self, x: Tensor) -> Tensor: + return super().forward(x.float()).type(x.dtype) + + +class Linear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + return F.linear( + x, + self.weight.to(x.dtype), + None if self.bias is None else self.bias.to(x.dtype), + ) + + +class Conv1d(nn.Conv1d): + def _conv_forward( + self, x: Tensor, weight: Tensor, bias: Optional[Tensor] + ) -> Tensor: + return super()._conv_forward( + x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype) + ) + + +def sinusoids(length, channels, max_timescale=10_000): + """Returns sinusoids for positional embedding""" + assert channels % 2 == 0 + log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) + inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2)) + scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] + return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) + + +class MultiHeadAttention(nn.Module): + def __init__(self, n_state: int, n_head: int): + super().__init__() + self.n_head = n_head + self.query = Linear(n_state, n_state) + self.key = Linear(n_state, n_state, bias=False) + self.value = Linear(n_state, n_state) + self.out = Linear(n_state, n_state) + + def forward( + self, + x: Tensor, + xa: Optional[Tensor] = None, + mask: Optional[Tensor] = None, + kv_cache: Optional[dict] = None, + ): + q = self.query(x) + + if kv_cache is None or xa is None or self.key not in kv_cache: + # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors; + # otherwise, perform key/value projections for self- or cross-attention as usual. + k = self.key(x if xa is None else xa) + v = self.value(x if xa is None else xa) + else: + # for cross-attention, calculate keys and values once and reuse in subsequent calls. 
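+ # the cache dict is keyed by the key/value projection modules themselves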
+ k = kv_cache[self.key] + v = kv_cache[self.value] + + wv = self.qkv_attention(q, k, v, mask) + return self.out(wv) + + def qkv_attention( + self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None + ): + n_batch, n_ctx, n_state = q.shape + scale = (n_state // self.n_head) ** -0.25 + q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale + k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale + v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) + + qk = q @ k + if mask is not None: + qk = qk + mask[:n_ctx, :n_ctx] + + w = F.softmax(qk.float(), dim=-1).to(q.dtype) + return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, n_state: int, n_head: int, cross_attention: bool = False): + super().__init__() + + self.attn = MultiHeadAttention(n_state, n_head) + self.attn_ln = LayerNorm(n_state) + + self.cross_attn = ( + MultiHeadAttention(n_state, n_head) if cross_attention else None + ) + self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None + + n_mlp = n_state * 4 + self.mlp = nn.Sequential( + Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state) + ) + self.mlp_ln = LayerNorm(n_state) + + def forward( + self, + x: Tensor, + xa: Optional[Tensor] = None, + mask: Optional[Tensor] = None, + kv_cache: Optional[dict] = None, + ): + x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache) + if self.cross_attn: + x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache) + x = x + self.mlp(self.mlp_ln(x)) + return x + + +class AudioEncoder(nn.Module): + def __init__( + self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int + ): + super().__init__() + self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1) + self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1) + self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state)) + + self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( + [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)] + ) + self.ln_post = LayerNorm(n_state) + + def forward(self, x: Tensor): + """ + x : torch.Tensor, shape = (batch_size, n_mels, n_ctx) + the mel spectrogram of the audio + """ + x = F.gelu(self.conv1(x)) + x = F.gelu(self.conv2(x)) + x = x.permute(0, 2, 1) + + assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape" + x = (x + self.positional_embedding).to(x.dtype) + for block in self.blocks: + x = block(x) + + x = self.ln_post(x) + return x + + +class TextDecoder(nn.Module): + def __init__( + self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int + ): + super().__init__() + + self.token_embedding = nn.Embedding(n_vocab, n_state) + self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state)) + + self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( + [ + ResidualAttentionBlock(n_state, n_head, cross_attention=True) + for _ in range(n_layer) + ] + ) + self.ln = LayerNorm(n_state) + + mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1) + self.register_buffer("mask", mask, persistent=False) + + def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None): + """ + x : torch.LongTensor, shape = (batch_size, <= n_ctx) + the text tokens + xa : torch.Tensor, shape = (batch_size, n_mels, n_audio_ctx) + the encoded audio features to be attended on + """ + offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 + x = ( + self.token_embedding(x) + + 
self.positional_embedding[offset : offset + x.shape[-1]] + ) + x = x.to(xa.dtype) + + for block in self.blocks: + x = block(x, xa, mask=self.mask, kv_cache=kv_cache) + + x = self.ln(x) + logits = ( + x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1) + ).float() + + return logits + + +class Whisper(nn.Module): + def __init__(self, dims: ModelDimensions): + super().__init__() + self.dims = dims + self.encoder = AudioEncoder( + self.dims.n_mels, + self.dims.n_audio_ctx, + self.dims.n_audio_state, + self.dims.n_audio_head, + self.dims.n_audio_layer, + ) + + def forward(self, mel: torch.Tensor): + return self.encoder(mel) + + @property + def device(self): + return next(self.parameters()).device diff --git a/src/models/whisper_meso_net.py b/src/models/whisper_meso_net.py new file mode 100644 index 0000000000000000000000000000000000000000..9b1fc689dd85ada2246f0cdfd732eedf22c389c8 --- /dev/null +++ b/src/models/whisper_meso_net.py @@ -0,0 +1,88 @@ +import torch +from src import frontends + +from src.models.whisper_main import ModelDimensions, Whisper, log_mel_spectrogram +from src.models.meso_net import MesoInception4 +from src.commons import WHISPER_MODEL_WEIGHTS_PATH + + +class WhisperMesoNet(MesoInception4): + def __init__(self, freeze_encoder, **kwargs): + super().__init__(**kwargs) + + self.device = kwargs['device'] + checkpoint = torch.load(WHISPER_MODEL_WEIGHTS_PATH) + dims = ModelDimensions(**checkpoint["dims"].__dict__) + model = Whisper(dims) + model = model.to(self.device) + model.load_state_dict(checkpoint["model_state_dict"]) + self.whisper_model = model + if freeze_encoder: + for param in self.whisper_model.parameters(): + param.requires_grad = False + + def compute_whisper_features(self, x): + specs = [] + for sample in x: + specs.append(log_mel_spectrogram(sample)) + x = torch.stack(specs) + x = self.whisper_model(x) + + x = x.permute(0, 2, 1) # (bs, frames, 3 x n_lfcc) + x = x.unsqueeze(1) # (bs, 1, frames, 3 x n_lfcc) + x = x.repeat( + (1, 1, 1, 2) + ) # (bs, 1, frames, 3 x n_lfcc) -> (bs, 1, frames, 3000) + return x + + def forward(self, x): + # we assume that the data is correct (i.e. 
30s) + x = self.compute_whisper_features(x) + out = self._compute_embedding(x) + return out + + +class WhisperMultiFrontMesoNet(WhisperMesoNet): + def __init__(self, freeze_encoder, **kwargs): + super().__init__(freeze_encoder=freeze_encoder, **kwargs) + self.frontend = frontends.get_frontend(kwargs['frontend_algorithm']) + print(f"Using {self.frontend} frontend!") + + def forward(self, x): + # Frontend computation + frontend_x = self.frontend(x) + x = self.compute_whisper_features(x) + + x = torch.cat([x, frontend_x], 1) + out = self._compute_embedding(x) + return out + + +if __name__ == "__main__": + import numpy as np + + input_channels = 1 + device = "cpu" + classifier = WhisperMesoNet( + input_channels=input_channels, + freeze_encoder=True, + fc1_dim=1024, + device=device, + ) + + input_channels = 2 + classifier_2 = WhisperMultiFrontMesoNet( + input_channels=input_channels, + freeze_encoder=True, + fc1_dim=1024, + device=device, + frontend_algorithm="lfcc" + ) + x = np.random.rand(2, 30 * 16_000).astype(np.float32) + x = torch.from_numpy(x) + + out = classifier(x) + print(out.shape) + + out = classifier_2(x) + print(out.shape) \ No newline at end of file diff --git a/src/models/whisper_specrnet.py b/src/models/whisper_specrnet.py new file mode 100644 index 0000000000000000000000000000000000000000..e9974d499b2f61eed1a68173afac2fba2ed9a813 --- /dev/null +++ b/src/models/whisper_specrnet.py @@ -0,0 +1,97 @@ +import numpy as np +import torch + +from src import frontends +from src.models.whisper_main import ModelDimensions, Whisper, log_mel_spectrogram +from src.models.specrnet import SpecRNet +from src.commons import WHISPER_MODEL_WEIGHTS_PATH + + +class WhisperSpecRNet(SpecRNet): + def __init__(self, input_channels, freeze_encoder, **kwargs): + super().__init__(input_channels=input_channels, **kwargs) + + self.device = kwargs["device"] + checkpoint = torch.load(WHISPER_MODEL_WEIGHTS_PATH) + dims = ModelDimensions(**checkpoint["dims"].__dict__) + model = Whisper(dims) + model = model.to(self.device) + model.load_state_dict(checkpoint["model_state_dict"]) + self.whisper_model = model + if freeze_encoder: + for param in self.whisper_model.parameters(): + param.requires_grad = False + + def compute_whisper_features(self, x): + specs = [] + for sample in x: + specs.append(log_mel_spectrogram(sample)) + x = torch.stack(specs) + x = self.whisper_model(x) + + x = x.permute(0, 2, 1) # (bs, frames, 3 x n_lfcc) + x = x.unsqueeze(1) # (bs, 1, frames, 3 x n_lfcc) + x = x.repeat( + (1, 1, 1, 2) + ) # (bs, 1, frames, 3 x n_lfcc) -> (bs, 1, frames, 3000) + return x + + def forward(self, x): + # we assume that the data is correct (i.e. 
30s) + x = self.compute_whisper_features(x) + out = self._compute_embedding(x) + return out + + +class WhisperMultiFrontSpecRNet(WhisperSpecRNet): + def __init__(self, input_channels, freeze_encoder, **kwargs): + super().__init__( + input_channels=input_channels, + freeze_encoder=freeze_encoder, + **kwargs, + ) + self.frontend = frontends.get_frontend(kwargs["frontend_algorithm"]) + print(f"Using {self.frontend} frontend!") + + def forward(self, x): + # Frontend computation + frontend_x = self.frontend(x) + x = self.compute_whisper_features(x) + + x = torch.cat([x, frontend_x], 1) + out = self._compute_embedding(x) + return out + + +if __name__ == "__main__": + import numpy as np + + input_channels = 1 + config = { + "filts": [input_channels, [input_channels, 20], [20, 64], [64, 64]], + "nb_fc_node": 64, + "gru_node": 64, + "nb_gru_layer": 2, + "nb_classes": 1, + } + device = "cpu" + classifier = WhisperSpecRNet( + input_channels, + freeze_encoder=False, + device=device, + ) + input_channels = 2 + classifier_2 = WhisperMultiFrontSpecRNet( + input_channels, + freeze_encoder=False, + device=device, + frontend_algorithm="lfcc" + ) + x = np.random.rand(2, 30 * 16_000).astype(np.float32) + x = torch.from_numpy(x) + + out = classifier(x) + print(out.shape) + + out = classifier_2(x) + print(out.shape) \ No newline at end of file diff --git a/src/trainer.py b/src/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..ba139f7cde31e922b7df6786a21b40ed78a5c119 --- /dev/null +++ b/src/trainer.py @@ -0,0 +1,173 @@ +"""A generic training wrapper.""" +from copy import deepcopy +import logging +from typing import Callable, List, Optional + +import torch +from torch.utils.data import DataLoader + + +LOGGER = logging.getLogger(__name__) + + +class Trainer: + def __init__( + self, + epochs: int = 20, + batch_size: int = 32, + device: str = "cpu", + optimizer_fn: Callable = torch.optim.Adam, + optimizer_kwargs: dict = {"lr": 1e-3}, + use_scheduler: bool = False, + ) -> None: + self.epochs = epochs + self.batch_size = batch_size + self.device = device + self.optimizer_fn = optimizer_fn + self.optimizer_kwargs = optimizer_kwargs + self.epoch_test_losses: List[float] = [] + self.use_scheduler = use_scheduler + + +def forward_and_loss(model, criterion, batch_x, batch_y, **kwargs): + batch_out = model(batch_x) + batch_loss = criterion(batch_out, batch_y) + return batch_out, batch_loss + + +class GDTrainer(Trainer): + def train( + self, + dataset: torch.utils.data.Dataset, + model: torch.nn.Module, + test_len: Optional[float] = None, + test_dataset: Optional[torch.utils.data.Dataset] = None, + ): + if test_dataset is not None: + train = dataset + test = test_dataset + else: + test_len = int(len(dataset) * test_len) + train_len = len(dataset) - test_len + lengths = [train_len, test_len] + train, test = torch.utils.data.random_split(dataset, lengths) + + train_loader = DataLoader( + train, + batch_size=self.batch_size, + shuffle=True, + drop_last=True, + num_workers=6, + ) + test_loader = DataLoader( + test, + batch_size=self.batch_size, + shuffle=True, + drop_last=True, + num_workers=6, + ) + + criterion = torch.nn.BCEWithLogitsLoss() + optim = self.optimizer_fn(model.parameters(), **self.optimizer_kwargs) + + best_model = None + best_acc = 0 + + LOGGER.info(f"Starting training for {self.epochs} epochs!") + + forward_and_loss_fn = forward_and_loss + + if self.use_scheduler: + batches_per_epoch = len(train_loader) * 2 # every 2nd epoch + scheduler = 
torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( + optimizer=optim, + T_0=batches_per_epoch, + T_mult=1, + eta_min=5e-6, + # verbose=True, + ) + use_cuda = self.device != "cpu" + + for epoch in range(self.epochs): + LOGGER.info(f"Epoch num: {epoch}") + + running_loss = 0 + num_correct = 0.0 + num_total = 0.0 + model.train() + + for i, (batch_x, _, batch_y) in enumerate(train_loader): + batch_size = batch_x.size(0) + num_total += batch_size + batch_x = batch_x.to(self.device) + + batch_y = batch_y.unsqueeze(1).type(torch.float32).to(self.device) + + batch_out, batch_loss = forward_and_loss_fn( + model, criterion, batch_x, batch_y, use_cuda=use_cuda + ) + batch_pred = (torch.sigmoid(batch_out) + 0.5).int() + num_correct += (batch_pred == batch_y.int()).sum(dim=0).item() + + running_loss += batch_loss.item() * batch_size + + if i % 100 == 0: + LOGGER.info( + f"[{epoch:04d}][{i:05d}]: {running_loss / num_total} {num_correct/num_total*100}" + ) + + optim.zero_grad() + batch_loss.backward() + optim.step() + if self.use_scheduler: + scheduler.step() + + running_loss /= num_total + train_accuracy = (num_correct / num_total) * 100 + + LOGGER.info( + f"Epoch [{epoch+1}/{self.epochs}]: train/loss: {running_loss}, train/accuracy: {train_accuracy}" + ) + + test_running_loss = 0.0 + num_correct = 0.0 + num_total = 0.0 + model.eval() + eer_val = 0 + + for batch_x, _, batch_y in test_loader: + batch_size = batch_x.size(0) + num_total += batch_size + batch_x = batch_x.to(self.device) + + with torch.no_grad(): + batch_pred = model(batch_x) + + batch_y = batch_y.unsqueeze(1).type(torch.float32).to(self.device) + batch_loss = criterion(batch_pred, batch_y) + + test_running_loss += batch_loss.item() * batch_size + + batch_pred = torch.sigmoid(batch_pred) + batch_pred_label = (batch_pred + 0.5).int() + num_correct += (batch_pred_label == batch_y.int()).sum(dim=0).item() + + if num_total == 0: + num_total = 1 + + test_running_loss /= num_total + test_acc = 100 * (num_correct / num_total) + LOGGER.info( + f"Epoch [{epoch+1}/{self.epochs}]: test/loss: {test_running_loss}, test/accuracy: {test_acc}, test/eer: {eer_val}" + ) + + if best_model is None or test_acc > best_acc: + best_acc = test_acc + best_model = deepcopy(model.state_dict()) + + LOGGER.info( + f"[{epoch:04d}]: {running_loss} - train acc: {train_accuracy} - test_acc: {test_acc}" + ) + + model.load_state_dict(best_model) + return model diff --git a/temps/53cc0ee8-4227-11ee-8cbf-00a554bbe3d7.mp3 b/temps/53cc0ee8-4227-11ee-8cbf-00a554bbe3d7.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..06773a940bb6a986ffd67046feeceb1432f4b737 --- /dev/null +++ b/temps/53cc0ee8-4227-11ee-8cbf-00a554bbe3d7.mp3 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2b9cb546cabedbb92126529d496a39ac1c2547ec6ea64d8f48bdea61597c1b93 +size 729600 diff --git a/temps/58631253-422a-11ee-a8ec-00a554bbe3d7.mp3 b/temps/58631253-422a-11ee-a8ec-00a554bbe3d7.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..c48e19384009f24e372d85b62078fc1d68234535 --- /dev/null +++ b/temps/58631253-422a-11ee-a8ec-00a554bbe3d7.mp3 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ce14d946d66b259b3700dce2a0457ffafb42c6227ea9cf1ce9c8509901fa495a +size 223208 diff --git a/temps/7252a08b-422a-11ee-a8ec-00a554bbe3d7.mp3 b/temps/7252a08b-422a-11ee-a8ec-00a554bbe3d7.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..c48e19384009f24e372d85b62078fc1d68234535 --- /dev/null +++ 
b/temps/7252a08b-422a-11ee-a8ec-00a554bbe3d7.mp3 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ce14d946d66b259b3700dce2a0457ffafb42c6227ea9cf1ce9c8509901fa495a +size 223208 diff --git a/temps/82cc9d15-422a-11ee-a8ec-00a554bbe3d7.mp3 b/temps/82cc9d15-422a-11ee-a8ec-00a554bbe3d7.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..c48e19384009f24e372d85b62078fc1d68234535 --- /dev/null +++ b/temps/82cc9d15-422a-11ee-a8ec-00a554bbe3d7.mp3 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ce14d946d66b259b3700dce2a0457ffafb42c6227ea9cf1ce9c8509901fa495a +size 223208 diff --git a/temps/8c943835-422a-11ee-a8ec-00a554bbe3d7.mp3 b/temps/8c943835-422a-11ee-a8ec-00a554bbe3d7.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..c48e19384009f24e372d85b62078fc1d68234535 --- /dev/null +++ b/temps/8c943835-422a-11ee-a8ec-00a554bbe3d7.mp3 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ce14d946d66b259b3700dce2a0457ffafb42c6227ea9cf1ce9c8509901fa495a +size 223208 diff --git a/temps/99e6d45c-422a-11ee-a8ec-00a554bbe3d7.mp3 b/temps/99e6d45c-422a-11ee-a8ec-00a554bbe3d7.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..c48e19384009f24e372d85b62078fc1d68234535 --- /dev/null +++ b/temps/99e6d45c-422a-11ee-a8ec-00a554bbe3d7.mp3 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ce14d946d66b259b3700dce2a0457ffafb42c6227ea9cf1ce9c8509901fa495a +size 223208 diff --git a/temps/a4cc39ba-422a-11ee-a8ec-00a554bbe3d7.mp3 b/temps/a4cc39ba-422a-11ee-a8ec-00a554bbe3d7.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..c48e19384009f24e372d85b62078fc1d68234535 --- /dev/null +++ b/temps/a4cc39ba-422a-11ee-a8ec-00a554bbe3d7.mp3 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ce14d946d66b259b3700dce2a0457ffafb42c6227ea9cf1ce9c8509901fa495a +size 223208 diff --git a/temps/b1501022-4227-11ee-92e8-00a554bbe3d7.mp3 b/temps/b1501022-4227-11ee-92e8-00a554bbe3d7.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..06773a940bb6a986ffd67046feeceb1432f4b737 --- /dev/null +++ b/temps/b1501022-4227-11ee-92e8-00a554bbe3d7.mp3 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2b9cb546cabedbb92126529d496a39ac1c2547ec6ea64d8f48bdea61597c1b93 +size 729600 diff --git a/temps/c60c7f4b-422a-11ee-a8ec-00a554bbe3d7.mp3 b/temps/c60c7f4b-422a-11ee-a8ec-00a554bbe3d7.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..c48e19384009f24e372d85b62078fc1d68234535 --- /dev/null +++ b/temps/c60c7f4b-422a-11ee-a8ec-00a554bbe3d7.mp3 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ce14d946d66b259b3700dce2a0457ffafb42c6227ea9cf1ce9c8509901fa495a +size 223208 diff --git a/temps/c7b4085b-4227-11ee-ad3e-00a554bbe3d7.mp3 b/temps/c7b4085b-4227-11ee-ad3e-00a554bbe3d7.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..06773a940bb6a986ffd67046feeceb1432f4b737 --- /dev/null +++ b/temps/c7b4085b-4227-11ee-ad3e-00a554bbe3d7.mp3 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2b9cb546cabedbb92126529d496a39ac1c2547ec6ea64d8f48bdea61597c1b93 +size 729600 diff --git a/temps/ca6fcdff-4229-11ee-a8ec-00a554bbe3d7.mp3 b/temps/ca6fcdff-4229-11ee-a8ec-00a554bbe3d7.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..c48e19384009f24e372d85b62078fc1d68234535 --- /dev/null +++ b/temps/ca6fcdff-4229-11ee-a8ec-00a554bbe3d7.mp3 @@ -0,0 +1,3 @@ 
diff --git a/train_and_test.py b/train_and_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..59a10bc254ce3c9e7d90f2133e72bc4cc7a80302
--- /dev/null
+++ b/train_and_test.py
@@ -0,0 +1,147 @@
+import argparse
+import logging
+from pathlib import Path
+
+import torch
+import yaml
+
+import train_models
+import evaluate_models
+from src.commons import set_seed
+
+
+LOGGER = logging.getLogger()
+LOGGER.setLevel(logging.INFO)
+
+ch = logging.StreamHandler()
+formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
+ch.setFormatter(formatter)
+LOGGER.addHandler(ch)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+
+    ASVSPOOF_DATASET_PATH = "../datasets/ASVspoof2021/DF"
+    IN_THE_WILD_DATASET_PATH = "../datasets/release_in_the_wild"
+
+    parser.add_argument(
+        "--asv_path",
+        type=str,
+        default=ASVSPOOF_DATASET_PATH,
+        help="Path to ASVspoof2021 dataset directory",
+    )
+    parser.add_argument(
+        "--in_the_wild_path",
+        type=str,
+        default=IN_THE_WILD_DATASET_PATH,
+        help="Path to In The Wild dataset directory",
+    )
+    default_model_config = "config.yaml"
+    parser.add_argument(
+        "--config",
+        help="Model config file path (default: config.yaml)",
+        type=str,
+        default=default_model_config,
+    )
+
+    default_train_amount = None
+    parser.add_argument(
+        "--train_amount",
+        "-a",
+        help="Amount of files to load for training.",
+        type=int,
+        default=default_train_amount,
+    )
+
+    default_valid_amount = None
+    parser.add_argument(
+        "--valid_amount",
+        "-va",
+        help="Amount of files to load for validation.",
+        type=int,
+        default=default_valid_amount,
+    )
+
+    default_test_amount = None
+    parser.add_argument(
+        "--test_amount",
+        "-ta",
+        help="Amount of files to load for testing.",
+        type=int,
+        default=default_test_amount,
+    )
+
+    default_batch_size = 8
+    parser.add_argument(
+        "--batch_size",
+        "-b",
+        help=f"Batch size (default: {default_batch_size}).",
+        type=int,
+        default=default_batch_size,
+    )
+
+    default_epochs = 10  # it was 5 originally
+    parser.add_argument(
+        "--epochs",
+        "-e",
+        help=f"Epochs (default: {default_epochs}).",
+        type=int,
+        default=default_epochs,
+    )
+
+    default_model_dir = "trained_models"
+    parser.add_argument(
+        "--ckpt",
+        help=f"Checkpoint directory (default: {default_model_dir}).",
+        type=str,
+        default=default_model_dir,
+    )
+
+    parser.add_argument("--cpu", "-c", help="Force using cpu?", action="store_true")
+
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    # TRAIN MODEL
+
+    with open(args.config, "r") as f:
+        config = yaml.safe_load(f)
+
+    seed = config["data"].get("seed", 42)
+    # fix all seeds
+    set_seed(seed)
+
+    if not args.cpu and torch.cuda.is_available():
+        device = "cuda"
+    else:
+        device = "cpu"
+
+    model_dir = Path(args.ckpt)
+    model_dir.mkdir(parents=True, exist_ok=True)
+
+    evaluation_config_path, model_path = train_models.train_nn(
+        datasets_paths=[
+            args.asv_path,
+        ],
+        device=device,
+        amount_to_use=(args.train_amount, args.valid_amount),
+        batch_size=args.batch_size,
+        epochs=args.epochs,
+        model_dir=model_dir,
+        config=config,
+    )
+
+    with open(evaluation_config_path, "r") as f:
+        config = yaml.safe_load(f)
+
+    evaluate_models.evaluate_nn(
+        model_paths=config["checkpoint"].get("path", []),
+        batch_size=args.batch_size,
+        datasets_paths=[args.in_the_wild_path],
+        model_config=config["model"],
+        amount_to_use=args.test_amount,
+        device=device,
+    )
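train_and_test.py relies on set_seed from src.commons, which is not included in this diff. Assuming it follows the usual reproducibility pattern, it would look roughly like the sketch below (an illustration under that assumption, not the repository's actual implementation):

import random

import numpy as np
import torch


def set_seed(seed: int) -> None:
    # Seed every RNG the pipeline touches so training runs are repeatable.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)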
diff --git a/train_models.py b/train_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..5debba5eeacbf20bd0b5e6e174f2381473891e1d
--- /dev/null
+++ b/train_models.py
@@ -0,0 +1,235 @@
+import argparse
+import logging
+import sys
+import time
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+import yaml
+
+from src.datasets.detection_dataset import DetectionDataset
+from src.models import models
+from src.trainer import GDTrainer
+from src.commons import set_seed
+
+
+def save_model(
+    model: torch.nn.Module,
+    model_dir: Union[Path, str],
+    name: str,
+) -> None:
+    full_model_dir = Path(f"{model_dir}/{name}")
+    full_model_dir.mkdir(parents=True, exist_ok=True)
+    torch.save(model.state_dict(), f"{full_model_dir}/ckpt.pth")
+
+
+def get_datasets(
+    datasets_paths: List[Union[Path, str]],
+    amount_to_use: Tuple[Optional[int], Optional[int]],
+) -> Tuple[DetectionDataset, DetectionDataset]:
+    data_train = DetectionDataset(
+        asvspoof_path=datasets_paths[0],
+        subset="train",
+        reduced_number=amount_to_use[0],
+        oversample=True,
+    )
+    data_test = DetectionDataset(
+        asvspoof_path=datasets_paths[0],
+        subset="test",
+        reduced_number=amount_to_use[1],
+        oversample=True,
+    )
+
+    return data_train, data_test
+
+
+def train_nn(
+    datasets_paths: List[Union[Path, str]],
+    batch_size: int,
+    epochs: int,
+    device: str,
+    config: Dict,
+    model_dir: Optional[Path] = None,
+    amount_to_use: Tuple[Optional[int], Optional[int]] = (None, None),
+    config_save_path: str = "configs",
+) -> Tuple[str, str]:
+    logging.info("Loading data...")
+    model_config = config["model"]
+    model_name, model_parameters = model_config["name"], model_config["parameters"]
+    optimizer_config = model_config["optimizer"]
+
+    timestamp = time.time()
+    checkpoint_path = ""
+
+    data_train, data_test = get_datasets(
+        datasets_paths=datasets_paths,
+        amount_to_use=amount_to_use,
+    )
+
+    current_model = models.get_model(
+        model_name=model_name,
+        config=model_parameters,
+        device=device,
+    )
+
+    # If checkpoint weights are provided, load them and fine-tune the model
+    model_path = config["checkpoint"]["path"]
+    if model_path:
+        current_model.load_state_dict(torch.load(model_path))
+        logging.info(
+            f"Finetuning '{model_name}' model, weights path: '{model_path}', on {len(data_train)} audio files."
+        )
+        if config["model"]["parameters"].get("freeze_encoder"):
+            for param in current_model.whisper_model.parameters():
+                param.requires_grad = False
+    else:
+        logging.info(f"Training '{model_name}' model on {len(data_train)} audio files.")
+    current_model = current_model.to(device)
+
+    use_scheduler = "rawnet3" in model_name.lower()
+
+    current_model = GDTrainer(
+        device=device,
+        batch_size=batch_size,
+        epochs=epochs,
+        optimizer_kwargs=optimizer_config,
+        use_scheduler=use_scheduler,
+    ).train(
+        dataset=data_train,
+        model=current_model,
+        test_dataset=data_test,
+    )
+
+    if model_dir is not None:
+        save_name = f"model__{model_name}__{timestamp}"
+        save_model(
+            model=current_model,
+            model_dir=model_dir,
+            name=save_name,
+        )
+        checkpoint_path = str(model_dir.resolve() / save_name / "ckpt.pth")
+
+    # Save config for testing
+    if model_dir is not None:
+        config["checkpoint"] = {"path": checkpoint_path}
+        config_name = f"model__{model_name}__{timestamp}.yaml"
+        config_save_path = str(Path(config_save_path) / config_name)
+        with open(config_save_path, "w") as f:
+            yaml.dump(config, f)
+        logging.info("Test config saved at location '{}'!".format(config_save_path))
+    return config_save_path, checkpoint_path
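The freeze_encoder branch in train_nn above disables gradients for the Whisper encoder so that only the remaining layers are fine-tuned. A small helper like the one below (hypothetical, not part of this diff) can confirm what is still trainable after freezing:

def count_parameters(model):
    # Returns (trainable, total) parameter counts for a torch.nn.Module.
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total

# e.g. right after the freezing loop:
#   trainable, total = count_parameters(current_model)
#   logging.info(f"Trainable parameters: {trainable}/{total}")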
+
+
+def main(args):
+    LOGGER = logging.getLogger()
+    LOGGER.setLevel(logging.INFO)
+
+    ch = logging.StreamHandler()
+    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
+    ch.setFormatter(formatter)
+    LOGGER.addHandler(ch)
+    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+
+    with open(args.config, "r") as f:
+        config = yaml.safe_load(f)
+
+    seed = config["data"].get("seed", 42)
+    # fix all seeds
+    set_seed(seed)
+
+    if not args.cpu and torch.cuda.is_available():
+        device = "cuda"
+    else:
+        device = "cpu"
+
+    model_dir = Path(args.ckpt)
+    model_dir.mkdir(parents=True, exist_ok=True)
+
+    train_nn(
+        datasets_paths=[
+            # parse_args() below only defines --asv_path, and get_datasets
+            # uses the first path only.
+            args.asv_path,
+        ],
+        device=device,
+        amount_to_use=(args.train_amount, args.test_amount),
+        batch_size=args.batch_size,
+        epochs=args.epochs,
+        model_dir=model_dir,
+        config=config,
+    )
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+
+    ASVSPOOF_DATASET_PATH = "../datasets/ASVspoof2021/DF"
+
+    parser.add_argument(
+        "--asv_path",
+        type=str,
+        default=ASVSPOOF_DATASET_PATH,
+        help="Path to ASVspoof2021 dataset directory",
+    )
+
+    default_model_config = "config.yaml"
+    parser.add_argument(
+        "--config",
+        help="Model config file path (default: config.yaml)",
+        type=str,
+        default=default_model_config,
+    )
+
+    default_train_amount = None
+    parser.add_argument(
+        "--train_amount",
+        "-a",
+        help="Amount of files to load for training.",
+        type=int,
+        default=default_train_amount,
+    )
+
+    default_test_amount = None
+    parser.add_argument(
+        "--test_amount",
+        "-ta",
+        help="Amount of files to load for testing.",
+        type=int,
+        default=default_test_amount,
+    )
+
+    default_batch_size = 8
+    parser.add_argument(
+        "--batch_size",
+        "-b",
+        help=f"Batch size (default: {default_batch_size}).",
+        type=int,
+        default=default_batch_size,
+    )
+
+    default_epochs = 10
+    parser.add_argument(
+        "--epochs",
+        "-e",
+        help=f"Epochs (default: {default_epochs}).",
+        type=int,
+        default=default_epochs,
+    )
+
+    default_model_dir = "trained_models"
+    parser.add_argument(
+        "--ckpt",
+        help=f"Checkpoint directory (default: {default_model_dir}).",
+        type=str,
+        default=default_model_dir,
+    )
+
+    parser.add_argument("--cpu", "-c", help="Force using cpu?", action="store_true")
+
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    main(parse_args())
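For reference, the ckpt.pth written by save_model and the YAML produced at the end of train_nn are what the evaluation step consumes. Reloading such a checkpoint by hand would look roughly like this sketch; the config path is a placeholder for the file train_nn actually returns:

import torch
import yaml

from src.models import models

config_path = "configs/model__example__0000000000.0.yaml"  # placeholder name

with open(config_path, "r") as f:
    config = yaml.safe_load(f)

model = models.get_model(
    model_name=config["model"]["name"],
    config=config["model"]["parameters"],
    device="cpu",
)
model.load_state_dict(torch.load(config["checkpoint"]["path"], map_location="cpu"))
model.eval()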
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e322954902db3068b7bff7f44b9e67e13dda5e9f
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,101 @@
+import torch
+import numpy as np
+import cv2
+import tempfile, base64
+
+
+def readb64(uri):
+    encoded_data = uri.split(',')[-1]
+    nparr = np.frombuffer(base64.b64decode(encoded_data), np.uint8)
+    img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+    return img
+
+def img2base64(img, extension="jpg"):
+    _, img_encoded = cv2.imencode(f".{extension}", img)
+    img_base64 = base64.b64encode(img_encoded)
+    img_base64 = img_base64.decode('utf-8')
+    return img_base64
+
+def binary2video(video_binary):
+    # byte_arr = BytesIO()
+    # byte_arr.write(video_binary)
+
+    temp_ = tempfile.NamedTemporaryFile(suffix='.mp4')
+    # decoded_string = base64.b64decode(video_binary)
+
+    temp_.write(video_binary)
+    video_capture = cv2.VideoCapture(temp_.name)
+    ret, frame = video_capture.read()
+    return video_capture
+
+def extract_frames(data_path, interval=30, max_frames=50):
+    """Extract every `interval`-th frame from a video, up to `max_frames` frames."""
+    cap = cv2.VideoCapture(data_path)
+    frame_num = 0
+    frames = list()
+
+    while cap.isOpened():
+        success, image = cap.read()
+        if not success:
+            break
+        # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        # image = torch.tensor(image) - torch.tensor([104, 117, 123])
+        if frame_num % interval == 0:
+            frames.append(image)
+        frame_num += 1
+        if len(frames) > max_frames:
+            break
+    cap.release()
+    # if len(frames) > max_frames:
+    #     samples = np.random.choice(
+    #         np.arange(0, len(frames)), size=max_frames, replace=False)
+    #     return [frames[_] for _ in samples]
+    return frames
+
+"""FilePicker for Streamlit.
+Streamlit still lacks a built-in way to select files that live on the server it runs on;
+this is a reasonably functional workaround.
+Usage:
+```
+import streamlit as st
+from filepicker import st_file_selector
+tif_file = st_file_selector(st, key='tif', label='Choose tif file')
+```
+"""
+
+import os
+import streamlit as st
+
+def update_dir(key):
+    choice = st.session_state[key]
+    if os.path.isdir(os.path.join(st.session_state[key+'curr_dir'], choice)):
+        st.session_state[key+'curr_dir'] = os.path.normpath(os.path.join(st.session_state[key+'curr_dir'], choice))
+        files = sorted(os.listdir(st.session_state[key+'curr_dir']))
+        if "images" in files:
+            files.remove("images")
+        st.session_state[key+'files'] = files
+
+def st_file_selector(st_placeholder, path='.', label='Select a file/folder', key='selected'):
+    if key+'curr_dir' not in st.session_state:
+        base_path = '.' if path is None or path == '' else path
+        base_path = base_path if os.path.isdir(base_path) else os.path.dirname(base_path)
+        base_path = '.' if base_path is None or base_path == '' else base_path
+
+        files = sorted(os.listdir(base_path))
+        files.insert(0, 'Choose a file...')
+        if "images" in files:
+            files.remove("images")
+        st.session_state[key+'files'] = files
+        st.session_state[key+'curr_dir'] = base_path
+    else:
+        base_path = st.session_state[key+'curr_dir']
+
+    selected_file = st_placeholder.selectbox(label=label,
+                                             options=st.session_state[key+'files'],
+                                             key=key,
+                                             on_change=lambda: update_dir(key))
+
+    if selected_file == "Choose a file...":
+        return None
+
+    return selected_file
\ No newline at end of file
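A short usage sketch for the OpenCV helpers in utils.py (the video path is a placeholder); PNG is used so the base64 round trip stays lossless:

from utils import extract_frames, img2base64, readb64

frames = extract_frames("example.mp4", interval=30, max_frames=10)  # placeholder path
if frames:
    b64 = img2base64(frames[0], extension="png")
    restored = readb64(b64)
    print(restored.shape == frames[0].shape)  # True for a successfully decoded frame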