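"""Evaluate or run inference with an audio deepfake detection model.

The model architecture and checkpoint are read from a YAML config; each audio
file is scored with a sigmoid probability and labeled Fake (0) or Real (1).
"""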

import argparse
import logging
import sys
from pathlib import Path
from typing import Dict, List, Optional, Union

import torch
import yaml
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score
from torch.utils.data import DataLoader

from src import commons, metrics
from src.datasets.base_dataset import SimpleAudioFakeDataset
from src.datasets.folder_dataset import FileDataset, FolderDataset
from src.models import models


def get_dataset(
    datasets_paths: List[Union[Path, str]],
    amount_to_use: Optional[int],
) -> SimpleAudioFakeDataset:
    # NOTE: amount_to_use is accepted for interface symmetry but is currently
    # unused; the whole folder is loaded
    data_val = FolderDataset(path=datasets_paths[0])
    return data_val


def get_dataset_file(
    datasets_path: Union[Path, str],
    amount_to_use: Optional[int],
) -> SimpleAudioFakeDataset:
    # NOTE: amount_to_use is currently unused; the single file is loaded as-is
    data_val = FileDataset(path=datasets_path)
    return data_val
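
# Both helpers return a SimpleAudioFakeDataset subclass: FolderDataset scans a
# directory of audio files, while FileDataset wraps a single file (used by
# inference() below). This summary reflects how they are called here, not the
# datasets' full APIs.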


def evaluate_nn(
    model_paths: Union[Path, str],
    datasets_paths: List[Union[Path, str]],
    model_config: Dict,
    device: str,
    amount_to_use: Optional[int] = None,
    batch_size: int = 8,
):
    logging.info("Loading data...")
    model_name, model_parameters = model_config["name"], model_config["parameters"]

    # Load model architecture
    model = models.get_model(
        model_name=model_name,
        config=model_parameters,
        device=device,
    )

    # If weights were provided, load them; torch.load expects a single
    # checkpoint path, so an empty value skips this step
    if len(model_paths):
        state_dict = torch.load(model_paths, map_location=device)
        model.load_state_dict(state_dict)
    model = model.to(device)

    data_val = get_dataset(
        datasets_paths=datasets_paths,
        amount_to_use=amount_to_use,
    )
    logging.info(
        f"Testing '{model_name}' model, weights path: '{model_paths}', "
        f"on {len(data_val)} audio files."
    )

    test_loader = DataLoader(
        data_val,
        batch_size=batch_size,
        shuffle=False,  # deterministic order for evaluation output
        drop_last=False,
        num_workers=3,
    )

    batches_number = len(data_val) // batch_size
    num_correct = 0.0
    num_total = 0.0
    y_pred = torch.Tensor([]).to(device)
    y = torch.Tensor([]).to(device)
    y_pred_label = torch.Tensor([]).to(device)
    all_paths: List[str] = []
    preds = []

    model.eval()
    for i, (batch_x, _, batch_y, metadata) in enumerate(test_loader):
        _, path, _, _ = metadata
        all_paths.extend(path)
        if i % 10 == 0:
            print(f"Batch [{i}/{batches_number}]")

        with torch.no_grad():
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            num_total += batch_x.size(0)

            batch_pred = torch.sigmoid(model(batch_x).squeeze(1))
            # Adding 0.5 and truncating rounds the probability to the nearest
            # binary label (0 = fake, 1 = real)
            batch_pred_label = (batch_pred + 0.5).int()
            num_correct += (batch_pred_label == batch_y.int()).sum(dim=0).item()

            y_pred = torch.concat([y_pred, batch_pred], dim=0)
            y_pred_label = torch.concat([y_pred_label, batch_pred_label], dim=0)
            y = torch.concat([y, batch_y], dim=0)

    # Per-file verdicts; all_paths was collected batch-by-batch, so it lines
    # up with the accumulated predictions
    for i in range(len(y_pred_label)):
        label = "Fake" if y_pred_label[i] == 0 else "Real"
        print(f"{all_paths[i]}")
        print(f"  Prediction:  {label}")
        print(f"  Probability: {y_pred[i].item():.4f}")
        preds.append((label, y_pred[i].detach().cpu().item()))

    # The metrics below need both classes present in the ground truth, so skip
    # them when the folder is unlabeled or single-class
    if len(torch.unique(y)) > 1:
        eval_accuracy = (num_correct / num_total) * 100

        precision, recall, f1_score, support = precision_recall_fscore_support(
            y.cpu().numpy(), y_pred_label.cpu().numpy(), average="binary", beta=1.0
        )
        auc_score = roc_auc_score(y_true=y.cpu().numpy(), y_score=y_pred.cpu().numpy())

        # For EER, flip the values, following the original evaluation implementation
        y_for_eer = 1 - y
        thresh, eer, fpr, tpr = metrics.calculate_eer(
            y=y_for_eer.cpu().numpy(),
            y_score=y_pred.cpu().numpy(),
        )

        logging.info(
            f"eval/eer: {eer:.4f}, eval/accuracy: {eval_accuracy:.4f}, "
            f"eval/precision: {precision:.4f}, eval/recall: {recall:.4f}, "
            f"eval/f1_score: {f1_score:.4f}, eval/auc: {auc_score:.4f}"
        )

    return preds
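
# A minimal usage sketch for evaluate_nn (illustrative, not part of the demo;
# it assumes a config.yaml with the "model" and "checkpoint" sections that
# main() below reads):
#
#   with open("config.yaml", "r") as f:
#       config = yaml.safe_load(f)
#   preds = evaluate_nn(
#       model_paths=config["checkpoint"].get("path", []),
#       datasets_paths=["sample_files"],
#       model_config=config["model"],
#       device="cpu",
#   )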


def load_model(config, device):
    model_config = config["model"]
    model_name, model_parameters = model_config["name"], model_config["parameters"]
    model_paths = config["checkpoint"].get("path", [])

    # Load model architecture
    model = models.get_model(
        model_name=model_name,
        config=model_parameters,
        device=device,
    )

    # If checkpoint weights were provided, load them (an empty value skips this)
    if len(model_paths):
        state_dict = torch.load(model_paths, map_location=device)
        model.load_state_dict(state_dict)
    model = model.to(device)
    return model
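
# Sketch: load_model builds the same architecture/checkpoint combination that
# evaluate_nn loads internally, so a model for inference() can be created
# directly from the parsed config:
#
#   with open("config.yaml", "r") as f:
#       config = yaml.safe_load(f)
#   model = load_model(config, device="cpu")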


def inference(
    model,
    datasets_path: Union[Path, str],
    device: str,
    amount_to_use: Optional[int] = None,
    batch_size: int = 8,
):
    logging.info("Loading data...")
    data_val = get_dataset_file(
        datasets_path=datasets_path,
        amount_to_use=amount_to_use,
    )

    test_loader = DataLoader(
        data_val,
        batch_size=batch_size,
        shuffle=False,  # deterministic order
        drop_last=False,
        num_workers=3,
    )

    batches_number = len(data_val) // batch_size
    y_pred = torch.Tensor([]).to(device)
    y_pred_label = torch.Tensor([]).to(device)
    all_paths: List[str] = []
    preds = []

    # No ground-truth labels are available at inference time, so only the
    # predictions are collected (unlike evaluate_nn, no metrics are computed)
    model.eval()
    for i, (batch_x, _, _, metadata) in enumerate(test_loader):
        _, path, _, _ = metadata
        all_paths.extend(path)
        if i % 10 == 0:
            print(f"Batch [{i}/{batches_number}]")

        with torch.no_grad():
            batch_x = batch_x.to(device)
            batch_pred = torch.sigmoid(model(batch_x).squeeze(1))
            # Round the probability to the nearest binary label (0 = fake, 1 = real)
            batch_pred_label = (batch_pred + 0.5).int()
            y_pred = torch.concat([y_pred, batch_pred], dim=0)
            y_pred_label = torch.concat([y_pred_label, batch_pred_label], dim=0)

    for i in range(len(y_pred_label)):
        label = "Fake" if y_pred_label[i] == 0 else "Real"
        print(f"{all_paths[i]}")
        print(f"  Prediction:  {label}")
        print(f"  Probability: {y_pred[i].item():.4f}")
        preds.append((label, y_pred[i].detach().cpu().item()))
    return preds
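
# Sketch of scoring a single file (the .wav path is hypothetical; FileDataset
# wraps one audio file):
#
#   model = load_model(config, device="cpu")
#   preds = inference(model, "sample_files/example.wav", device="cpu")
#   # preds might look like [("Real", 0.93)] -- illustrative output only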


def main(args):
    # basicConfig is a no-op once a handler is attached, so configure logging
    # in a single call instead of mixing it with a manually added StreamHandler
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
    )

    if not args.cpu and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    with open(args.config, "r") as f:
        config = yaml.safe_load(f)

    seed = config["data"].get("seed", 42)
    # Fix all seeds for reproducibility - this should not actually change anything
    commons.set_seed(seed)

    evaluate_nn(
        model_paths=config["checkpoint"].get("path", []),
        datasets_paths=[args.folder_path],
        model_config=config["model"],
        amount_to_use=args.amount,
        device=device,
    )


def parse_args():
    parser = argparse.ArgumentParser()

    FOLDER_DATASET_PATH = "sample_files"
    parser.add_argument(
        "--folder_path",
        help=f"Path to the folder of audio files to score (default: {FOLDER_DATASET_PATH}).",
        type=str,
        default=FOLDER_DATASET_PATH,
    )

    default_model_config = "config.yaml"
    parser.add_argument(
        "--config",
        help="Model config file path (default: config.yaml)",
        type=str,
        default=default_model_config,
    )

    # If left as None, no file limit is applied
    default_amount = None
    parser.add_argument(
        "--amount",
        "-a",
        help=f"Amount of files to load from each directory (default: {default_amount} - use all).",
        type=int,
        default=default_amount,
    )

    parser.add_argument("--cpu", "-c", help="Force using cpu", action="store_true")
    return parser.parse_args()
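
# Example invocations (the script name depends on how this file is saved):
#
#   python evaluate.py --folder_path sample_files --config config.yaml
#   python evaluate.py -c   # force CPU even if CUDA is available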


if __name__ == "__main__":
    main(parse_args())