import torch
from torchmetrics.classification import Accuracy, Recall, Precision, MatthewsCorrCoef, AUROC, F1Score
from torchmetrics.classification import BinaryAccuracy, BinaryRecall, BinaryAUROC, BinaryF1Score, BinaryPrecision, BinaryMatthewsCorrCoef
from torchmetrics.regression import SpearmanCorrCoef, MeanSquaredError
from torchmetrics.classification import MultilabelAveragePrecision


def count_f1_max(pred, target):
    """
    F1 score at the optimal threshold. Copied from TorchDrug.

    This function enumerates all possible thresholds for deciding positive and
    negative samples, and then picks the threshold with the maximal F1 score.

    Parameters:
        pred (Tensor): predictions of shape :math:`(B, N)`
        target (Tensor): binary targets of shape :math:`(B, N)`
    """
    # Sort predictions within each row and compute cumulative precision/recall.
    order = pred.argsort(descending=True, dim=1)
    target = target.gather(1, order)
    precision = target.cumsum(1) / torch.ones_like(target).cumsum(1)
    recall = target.cumsum(1) / (target.sum(1, keepdim=True) + 1e-10)
    # Mark, in the original column layout, the top-ranked prediction of each row.
    is_start = torch.zeros_like(target).bool()
    is_start[:, 0] = 1
    is_start = torch.scatter(is_start, 1, order, is_start)
    # Global ranking of all predictions across the batch; each global rank
    # corresponds to one candidate threshold.
    all_order = pred.flatten().argsort(descending=True)
    order = (
        order
        + torch.arange(order.shape[0], device=order.device).unsqueeze(1)
        * order.shape[1]
    )
    order = order.flatten()
    inv_order = torch.zeros_like(order)
    inv_order[order] = torch.arange(order.shape[0], device=order.device)
    is_start = is_start.flatten()[all_order]
    all_order = inv_order[all_order]
    precision = precision.flatten()
    recall = recall.flatten()
    # Average per-row precision over rows with at least one prediction above
    # the threshold, and per-row recall over all rows.
    all_precision = precision[all_order] - torch.where(
        is_start, torch.zeros_like(precision), precision[all_order - 1]
    )
    all_precision = all_precision.cumsum(0) / is_start.cumsum(0)
    all_recall = recall[all_order] - torch.where(
        is_start, torch.zeros_like(recall), recall[all_order - 1]
    )
    all_recall = all_recall.cumsum(0) / pred.shape[0]
    all_f1 = 2 * all_precision * all_recall / (all_precision + all_recall + 1e-10)
    return all_f1.max()
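

# Illustrative sanity check, not part of the original module: the tensor values
# below are made up. Because every positive label outranks every negative one
# in each row, the best global threshold recovers all positives exactly, so
# count_f1_max should return 1.0.
def _demo_count_f1_max():
    pred = torch.tensor([[0.9, 0.1, 0.2],
                         [0.8, 0.7, 0.05]])
    target = torch.tensor([[1.0, 0.0, 0.0],
                           [1.0, 1.0, 0.0]])
    return count_f1_max(pred, target)  # expected: tensor(1.)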


class MultilabelF1Max(MultilabelAveragePrecision):
    """Multilabel F1-max: reuses MultilabelAveragePrecision's accumulated
    predictions and targets but reports the threshold-free F1-max instead."""

    def compute(self):
        return count_f1_max(torch.cat(self.preds), torch.cat(self.target))


def setup_metrics(args):
    """Set up metrics based on the problem type and the specified metrics list."""
    metrics_dict = {}
    device = "cuda" if torch.cuda.is_available() else "cpu"
    for metric_name in args.metrics:
        metric_config = None
        if args.problem_type == 'regression':
            metric_config = _setup_regression_metrics(metric_name, device)
        elif args.problem_type == 'single_label_classification':
            if args.num_labels == 2:
                metric_config = _setup_binary_metrics(metric_name, device)
            else:
                metric_config = _setup_multiclass_metrics(metric_name, args.num_labels, device)
        elif args.problem_type == 'multi_label_classification':
            metric_config = _setup_multilabel_metrics(metric_name, args.num_labels, device)
        if metric_config:
            metrics_dict[metric_name] = metric_config['metric']
    # Add loss to the metrics if it is the monitored quantity.
    if args.monitor == 'loss':
        metrics_dict['loss'] = 'loss'
    return metrics_dict
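

# Usage sketch, not from the original file: the real `args` object presumably
# comes from the project's argument parser; a SimpleNamespace stands in here
# with only the fields setup_metrics actually reads (problem_type, num_labels,
# metrics, monitor). Field values are assumptions for illustration.
def _demo_setup_metrics():
    from types import SimpleNamespace
    args = SimpleNamespace(
        problem_type='single_label_classification',
        num_labels=2,                       # 2 labels -> binary metric variants
        metrics=['accuracy', 'f1', 'mcc'],  # keys must exist in the config dicts below
        monitor='loss',                     # 'loss' adds the sentinel 'loss' entry
    )
    metrics_dict = setup_metrics(args)
    # -> {'accuracy': BinaryAccuracy(), 'f1': BinaryF1Score(),
    #     'mcc': BinaryMatthewsCorrCoef(), 'loss': 'loss'}
    return metrics_dict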


def _setup_regression_metrics(metric_name, device):
    metrics_config = {
        'spearman_corr': {
            'metric': SpearmanCorrCoef().to(device),
        },
        'mse': {
            'metric': MeanSquaredError().to(device),
        }
    }
    return metrics_config.get(metric_name)


def _setup_multiclass_metrics(metric_name, num_classes, device):
    metrics_config = {
        'accuracy': {
            'metric': Accuracy(task='multiclass', num_classes=num_classes).to(device),
        },
        'recall': {
            'metric': Recall(task='multiclass', num_classes=num_classes).to(device),
        },
        'precision': {
            'metric': Precision(task='multiclass', num_classes=num_classes).to(device),
        },
        'f1': {
            'metric': F1Score(task='multiclass', num_classes=num_classes).to(device),
        },
        'mcc': {
            'metric': MatthewsCorrCoef(task='multiclass', num_classes=num_classes).to(device),
        },
        'auroc': {
            'metric': AUROC(task='multiclass', num_classes=num_classes).to(device),
        }
    }
    return metrics_config.get(metric_name)


def _setup_binary_metrics(metric_name, device):
    metrics_config = {
        'accuracy': {
            'metric': BinaryAccuracy().to(device),
        },
        'recall': {
            'metric': BinaryRecall().to(device),
        },
        'precision': {
            'metric': BinaryPrecision().to(device),
        },
        'f1': {
            'metric': BinaryF1Score().to(device),
        },
        'mcc': {
            'metric': BinaryMatthewsCorrCoef().to(device),
        },
        'auroc': {
            'metric': BinaryAUROC().to(device),
        }
    }
    return metrics_config.get(metric_name)


def _setup_multilabel_metrics(metric_name, num_labels, device):
    metrics_config = {
        'f1_max': {
            'metric': MultilabelF1Max(num_labels=num_labels).to(device),
        }
    }
    return metrics_config.get(metric_name)
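

# End-to-end sketch, not part of the original module: how the returned dict is
# typically consumed with the standard torchmetrics update()/compute() cycle.
# The batch shapes and SimpleNamespace fields are assumptions for illustration.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(problem_type='multi_label_classification',
                           num_labels=4, metrics=['f1_max'], monitor='f1_max')
    metrics = setup_metrics(args)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    preds = torch.rand(8, 4, device=device)               # per-label probabilities
    labels = torch.randint(0, 2, (8, 4), device=device)   # binary multi-label targets

    for name, metric in metrics.items():
        if isinstance(metric, str):   # skip the 'loss' sentinel string
            continue
        metric.update(preds, labels)
        print(name, metric.compute().item())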