# P-FAD / evaluate_models.py
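"""Evaluate a fake-audio detection model on a folder of audio files.

The model architecture and checkpoint are read from a YAML config, the model is
run over every file found under --folder_path, and a Fake/Real prediction with
its probability is printed for each file.

Example invocation (uses the argparse defaults defined below; paths are
illustrative):

    python evaluate_models.py --config config.yaml --folder_path sample_files --cpu
"""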

import argparse
import logging
import sys
from pathlib import Path
from typing import Dict, List, Optional, Union

import torch
import yaml
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score
from torch.utils.data import DataLoader

from src import commons, metrics
from src.models import models
from src.datasets.base_dataset import SimpleAudioFakeDataset
from src.datasets.in_the_wild_dataset import InTheWildDataset
from src.datasets.folder_dataset import FolderDataset, FileDataset


def get_dataset(
    datasets_paths: List[Union[Path, str]],
    amount_to_use: Optional[int],
) -> SimpleAudioFakeDataset:
    # Only the first path is used; amount_to_use is currently not applied here
    data_val = FolderDataset(
        path=datasets_paths[0]
    )
    return data_val


def get_dataset_file(
    datasets_path: Union[Path, str],
    amount_to_use: Optional[int],
) -> SimpleAudioFakeDataset:
    # Wraps a single audio file; amount_to_use is currently not applied here
    data_val = FileDataset(
        path=datasets_path
    )
    return data_val


def evaluate_nn(
    model_paths: Union[Path, str, List[Path]],
    datasets_paths: List[Union[Path, str]],
    model_config: Dict,
    device: str,
    amount_to_use: Optional[int] = None,
    batch_size: int = 8,
):
logging.info("Loading data...")
model_name, model_parameters = model_config["name"], model_config["parameters"]
# Load model architecture
model = models.get_model(
model_name=model_name,
config=model_parameters,
device=device,
)
    # If weights are provided, load them (a single checkpoint path, or the
    # first entry when a list of paths is given)
    if model_paths:
        ckpt_path = model_paths[0] if isinstance(model_paths, (list, tuple)) else model_paths
        state_dict = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(state_dict)
model = model.to(device)
data_val = get_dataset(
datasets_paths=datasets_paths,
amount_to_use=amount_to_use,
)
logging.info(
f"Testing '{model_name}' model, weights path: '{model_paths}', on {len(data_val)} audio files."
)
test_loader = DataLoader(
data_val,
batch_size=batch_size,
shuffle=True,
drop_last=False,
num_workers=3,
)
batches_number = len(data_val) // batch_size
num_correct = 0.0
num_total = 0.0
y_pred = torch.Tensor([]).to(device)
y = torch.Tensor([]).to(device)
y_pred_label = torch.Tensor([]).to(device)
preds = []
for i, (batch_x, _, batch_y, metadata) in enumerate(test_loader):
model.eval()
_, path, _, _ = metadata
if i % 10 == 0:
print(f"Batch [{i}/{batches_number}]")
with torch.no_grad():
batch_x = batch_x.to(device)
batch_y = batch_y.to(device)
num_total += batch_x.size(0)
batch_pred = model(batch_x).squeeze(1)
batch_pred = torch.sigmoid(batch_pred)
batch_pred_label = (batch_pred + 0.5).int()
num_correct += (batch_pred_label == batch_y.int()).sum(dim=0).item()
y_pred = torch.concat([y_pred, batch_pred], dim=0)
y_pred_label = torch.concat([y_pred_label, batch_pred_label], dim=0)
y = torch.concat([y, batch_y], dim=0)
        # Report per-file predictions for the current batch
        for j in range(len(batch_pred_label)):
            label = 'Fake' if batch_pred_label[j] == 0 else 'Real'
            print(f'{path[j]}')
            print(f'  Prediction:  {label}')
            print(f'  Probability: {batch_pred[j].item():.4f}')
            preds.append((label, batch_pred[j].detach().cpu().item()))

    # With ground-truth labels available (both classes present in y), also
    # compute and log evaluation metrics; for unlabeled folder data this is skipped
    if torch.unique(y).numel() > 1:
        eval_accuracy = (num_correct / num_total) * 100
        precision, recall, f1_score, support = precision_recall_fscore_support(
            y.cpu().numpy(), y_pred_label.cpu().numpy(), average="binary", beta=1.0
        )
        auc_score = roc_auc_score(y_true=y.cpu().numpy(), y_score=y_pred.cpu().numpy())

        # Flip labels for EER, following the original evaluation implementation
        y_for_eer = 1 - y
        thresh, eer, fpr, tpr = metrics.calculate_eer(
            y=y_for_eer.cpu().numpy(),
            y_score=y_pred.cpu().numpy(),
        )

        logging.info(
            f"eval/eer: {eer:.4f}, eval/accuracy: {eval_accuracy:.4f}, "
            f"eval/precision: {precision:.4f}, eval/recall: {recall:.4f}, "
            f"eval/f1_score: {f1_score:.4f}, eval/auc: {auc_score:.4f}"
        )

    return preds


def load_model(config, device):
    model_config = config["model"]
    model_name, model_parameters = model_config["name"], model_config["parameters"]
    model_paths = config["checkpoint"].get("path", [])

    # Load model architecture
    model = models.get_model(
        model_name=model_name,
        config=model_parameters,
        device=device,
    )

    # If weights are provided, load them (a single checkpoint path, or the
    # first entry when a list of paths is given)
    if model_paths:
        ckpt_path = model_paths[0] if isinstance(model_paths, (list, tuple)) else model_paths
        state_dict = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(state_dict)

    model = model.to(device)
    return model
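
# The YAML config consumed by load_model() and main() is expected to look
# roughly like this (keys match the ones read in this file; values are
# illustrative):
#
#   model:
#     name: <architecture name understood by src.models.models.get_model>
#     parameters: {}            # architecture-specific settings
#   checkpoint:
#     path: weights/model.pth   # checkpoint file with the trained weights
#   data:
#     seed: 42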


def inference(
    model,
    datasets_path: Union[Path, str],
    device: str,
    amount_to_use: Optional[int] = None,
    batch_size: int = 8,
):
logging.info("Loading data...")
data_val = get_dataset_file(
datasets_path=datasets_path,
amount_to_use=amount_to_use,
)
test_loader = DataLoader(
data_val,
batch_size=batch_size,
shuffle=True,
drop_last=False,
num_workers=3,
)
batches_number = len(data_val) // batch_size
num_correct = 0.0
num_total = 0.0
y_pred = torch.Tensor([]).to(device)
y = torch.Tensor([]).to(device)
y_pred_label = torch.Tensor([]).to(device)
preds = []
for i, (batch_x, _, batch_y, metadata) in enumerate(test_loader):
model.eval()
_, path, _, _ = metadata
if i % 10 == 0:
print(f"Batch [{i}/{batches_number}]")
with torch.no_grad():
batch_x = batch_x.to(device)
batch_y = batch_y.to(device)
num_total += batch_x.size(0)
batch_pred = model(batch_x).squeeze(1)
batch_pred = torch.sigmoid(batch_pred)
batch_pred_label = (batch_pred + 0.5).int()
num_correct += (batch_pred_label == batch_y.int()).sum(dim=0).item()
y_pred = torch.concat([y_pred, batch_pred], dim=0)
y_pred_label = torch.concat([y_pred_label, batch_pred_label], dim=0)
y = torch.concat([y, batch_y], dim=0)
        # Report per-file predictions for the current batch
        for j in range(len(batch_pred_label)):
            label = 'Fake' if batch_pred_label[j] == 0 else 'Real'
            print(f'{path[j]}')
            print(f'  Prediction:  {label}')
            print(f'  Probability: {batch_pred[j].item():.4f}')
            preds.append((label, batch_pred[j].detach().cpu().item()))

    # With ground-truth labels available (both classes present in y), also
    # compute and log evaluation metrics; for unlabeled folder data this is skipped
    if torch.unique(y).numel() > 1:
        eval_accuracy = (num_correct / num_total) * 100
        precision, recall, f1_score, support = precision_recall_fscore_support(
            y.cpu().numpy(), y_pred_label.cpu().numpy(), average="binary", beta=1.0
        )
        auc_score = roc_auc_score(y_true=y.cpu().numpy(), y_score=y_pred.cpu().numpy())

        # Flip labels for EER, following the original evaluation implementation
        y_for_eer = 1 - y
        thresh, eer, fpr, tpr = metrics.calculate_eer(
            y=y_for_eer.cpu().numpy(),
            y_score=y_pred.cpu().numpy(),
        )

        logging.info(
            f"eval/eer: {eer:.4f}, eval/accuracy: {eval_accuracy:.4f}, "
            f"eval/precision: {precision:.4f}, eval/recall: {recall:.4f}, "
            f"eval/f1_score: {f1_score:.4f}, eval/auc: {auc_score:.4f}"
        )

    return preds
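
# Minimal programmatic usage sketch (illustrative paths; assumes a config file
# with the "model" and "checkpoint" sections sketched above):
#
#   with open("config.yaml") as f:
#       config = yaml.safe_load(f)
#   model = load_model(config, device="cpu")
#   predictions = inference(model, "sample_files/example.wav", device="cpu")
#   # -> list of (label, probability) tuples, one per audio file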


def main(args):
    # Log INFO messages with timestamps to stdout
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
    )
if not args.cpu and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
with open(args.config, "r") as f:
config = yaml.safe_load(f)
seed = config["data"].get("seed", 42)
    # Fix all seeds for reproducibility; evaluation itself should be deterministic
commons.set_seed(seed)
evaluate_nn(
model_paths=config["checkpoint"].get("path", []),
datasets_paths=[
args.folder_path,
],
model_config=config["model"],
amount_to_use=args.amount,
device=device,
)


def parse_args():
    parser = argparse.ArgumentParser()

    FOLDER_DATASET_PATH = "sample_files"
    parser.add_argument(
        "--folder_path",
        help=f"Path to the folder with audio files to evaluate (default: {FOLDER_DATASET_PATH}).",
        type=str,
        default=FOLDER_DATASET_PATH,
    )
default_model_config = "config.yaml"
parser.add_argument(
"--config",
help="Model config file path (default: config.yaml)",
type=str,
default=default_model_config,
)
default_amount = None
parser.add_argument(
"--amount",
"-a",
help=f"Amount of files to load from each directory (default: {default_amount} - use all).",
type=int,
default=default_amount,
)
parser.add_argument("--cpu", "-c", help="Force using cpu", action="store_true")
return parser.parse_args()


if __name__ == "__main__":
main(parse_args())