import torch
from torchmetrics.classification import (
    Accuracy, Recall, Precision, F1Score, MatthewsCorrCoef, AUROC,
    BinaryAccuracy, BinaryRecall, BinaryPrecision, BinaryF1Score,
    BinaryMatthewsCorrCoef, BinaryAUROC, MultilabelAveragePrecision,
)
from torchmetrics.regression import SpearmanCorrCoef, MeanSquaredError


def count_f1_max(pred, target):
    """
    F1 score with the optimal threshold. Copied from TorchDrug.

    This function first enumerates all possible thresholds for deciding positive
    and negative samples, and then picks the threshold with the maximal F1 score.

    Parameters:
        pred (Tensor): predictions of shape :math:`(B, N)`
        target (Tensor): binary targets of shape :math:`(B, N)`
    """
    # Sort predictions per sample in descending order and align targets with them.
    order = pred.argsort(descending=True, dim=1)
    target = target.gather(1, order)
    # Per-sample precision/recall at every cutoff along each sample's ranking.
    precision = target.cumsum(1) / torch.ones_like(target).cumsum(1)
    recall = target.cumsum(1) / (target.sum(1, keepdim=True) + 1e-10)
    # Flag each sample's top-ranked entry, scattered back to original column positions.
    is_start = torch.zeros_like(target).bool()
    is_start[:, 0] = 1
    is_start = torch.scatter(is_start, 1, order, is_start)
    # Global ranking over the flattened predictions: one shared threshold sweeps
    # across all samples at once.
    all_order = pred.flatten().argsort(descending=True)
    order = (
        order
        + torch.arange(order.shape[0], device=order.device).unsqueeze(1)
        * order.shape[1]
    )
    order = order.flatten()
    inv_order = torch.zeros_like(order)
    inv_order[order] = torch.arange(order.shape[0], device=order.device)
    is_start = is_start.flatten()[all_order]
    all_order = inv_order[all_order]
    precision = precision.flatten()
    recall = recall.flatten()
    # Accumulate per-sample precision/recall deltas in global threshold order,
    # then average: precision over the samples seen so far, recall over all samples.
    all_precision = precision[all_order] - torch.where(
        is_start, torch.zeros_like(precision), precision[all_order - 1]
    )
    all_precision = all_precision.cumsum(0) / is_start.cumsum(0)
    all_recall = recall[all_order] - torch.where(
        is_start, torch.zeros_like(recall), recall[all_order - 1]
    )
    all_recall = all_recall.cumsum(0) / pred.shape[0]
    all_f1 = 2 * all_precision * all_recall / (all_precision + all_recall + 1e-10)
    return all_f1.max()
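
# Hypothetical sanity check (not part of the original file): when every
# positive outranks every negative, some threshold separates them perfectly,
# so the maximal F1 is 1 (up to the 1e-10 smoothing terms above).
# >>> pred = torch.tensor([[0.9, 0.2, 0.7], [0.1, 0.8, 0.3]])
# >>> target = torch.tensor([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])
# >>> count_f1_max(pred, target)  # ~tensor(1.)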


class MultilabelF1Max(MultilabelAveragePrecision):
    """Multilabel F1-max: reuses MultilabelAveragePrecision's accumulation of
    predictions and targets, but reduces with count_f1_max instead of average
    precision."""

    def compute(self):
        return count_f1_max(torch.cat(self.preds), torch.cat(self.target))
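
# Hypothetical usage sketch: the metric accumulates batches like any other
# torchmetrics metric (update() stores preds/targets; compute() applies the
# F1-max reduction over everything accumulated so far).
# >>> metric = MultilabelF1Max(num_labels=3)
# >>> metric.update(torch.rand(8, 3), torch.randint(0, 2, (8, 3)))
# >>> metric.compute()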


def setup_metrics(args):
    """Set up metrics based on the problem type and the specified metrics list."""
    metrics_dict = {}
    device = "cuda" if torch.cuda.is_available() else "cpu"

    for metric_name in args.metrics:
        metric_config = None  # stays None for unknown metric names or problem types
        if args.problem_type == 'regression':
            metric_config = _setup_regression_metrics(metric_name, device)
        elif args.problem_type == 'single_label_classification':
            if args.num_labels == 2:
                metric_config = _setup_binary_metrics(metric_name, device)
            else:
                metric_config = _setup_multiclass_metrics(metric_name, args.num_labels, device)
        elif args.problem_type == 'multi_label_classification':
            metric_config = _setup_multilabel_metrics(metric_name, args.num_labels, device)

        if metric_config:
            metrics_dict[metric_name] = metric_config['metric']

    # Add loss to metrics if it's the monitor metric
    if args.monitor == 'loss':
        metrics_dict['loss'] = 'loss'
    return metrics_dict


def _setup_regression_metrics(metric_name, device):
    metrics_config = {
        'spearman_corr': {
            'metric': SpearmanCorrCoef().to(device),
        },
        'mse': {
            'metric': MeanSquaredError().to(device),
        },
    }
    return metrics_config.get(metric_name)


def _setup_multiclass_metrics(metric_name, num_classes, device):
    metrics_config = {
        'accuracy': {
            'metric': Accuracy(task='multiclass', num_classes=num_classes).to(device),
        },
        'recall': {
            'metric': Recall(task='multiclass', num_classes=num_classes).to(device),
        },
        'precision': {
            'metric': Precision(task='multiclass', num_classes=num_classes).to(device),
        },
        'f1': {
            'metric': F1Score(task='multiclass', num_classes=num_classes).to(device),
        },
        'mcc': {
            'metric': MatthewsCorrCoef(task='multiclass', num_classes=num_classes).to(device),
        },
        'auroc': {
            'metric': AUROC(task='multiclass', num_classes=num_classes).to(device),
        },
    }
    return metrics_config.get(metric_name)


def _setup_binary_metrics(metric_name, device):
    metrics_config = {
        'accuracy': {
            'metric': BinaryAccuracy().to(device),
        },
        'recall': {
            'metric': BinaryRecall().to(device),
        },
        'precision': {
            'metric': BinaryPrecision().to(device),
        },
        'f1': {
            'metric': BinaryF1Score().to(device),
        },
        'mcc': {
            'metric': BinaryMatthewsCorrCoef().to(device),
        },
        'auroc': {
            'metric': BinaryAUROC().to(device),
        },
    }
    return metrics_config.get(metric_name)


def _setup_multilabel_metrics(metric_name, num_labels, device):
    metrics_config = {
        'f1_max': {
            'metric': MultilabelF1Max(num_labels=num_labels).to(device),
        },
    }
    return metrics_config.get(metric_name)
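

if __name__ == "__main__":
    # Hypothetical smoke test, not part of the original file: drive
    # setup_metrics with an argparse-style namespace whose field names
    # mirror those read above (problem_type, num_labels, metrics, monitor).
    from types import SimpleNamespace

    args = SimpleNamespace(
        problem_type='single_label_classification',
        num_labels=5,
        metrics=['accuracy', 'f1', 'mcc'],
        monitor='loss',
    )
    # Expect {'accuracy': ..., 'f1': ..., 'mcc': ..., 'loss': 'loss'}.
    print(setup_metrics(args))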