lmzjms's picture
Upload 35 files
15ac91d
raw
history blame
1.11 kB
import warnings

import numpy as np
from sklearn import metrics

from pytorch_utils import forward
class Evaluator:
    """Evaluate a model's clip-level predictions against their targets.

    The model is run over a data loader with ``forward`` (from
    ``pytorch_utils``); per-class average precision and ROC AUC are then
    computed with scikit-learn.
    """

    def __init__(self, model):
        """Evaluator.

        Args:
          model: object, the model handed to ``forward`` during evaluation.
        """
        self.model = model

    def evaluate(self, data_loader):
        """Forward evaluation data and calculate statistics.

        Args:
          data_loader: object, passed to ``forward`` as its ``generator``.

        Returns:
          statistics: dict,
            {'average_precision': (classes_num,), 'auc': (classes_num,)}
            An ``auc`` entry is NaN (and a warning is emitted) for any class
            whose target column holds a single value, because ROC AUC is
            undefined there.
        """
        # Forward
        output_dict = forward(
            model=self.model,
            generator=data_loader,
            return_target=True)

        clipwise_output = output_dict['clipwise_output']  # (audios_num, classes_num)
        target = output_dict['target']  # (audios_num, classes_num)

        average_precision = metrics.average_precision_score(
            target, clipwise_output, average=None)
        auc = self._per_class_auc(target, clipwise_output)

        statistics = {'average_precision': average_precision, 'auc': auc}
        return statistics

    @staticmethod
    def _per_class_auc(target, clipwise_output):
        """Return per-class ROC AUC, using NaN for classes where it is undefined.

        ``metrics.roc_auc_score(..., average=None)`` can raise ValueError as
        soon as one class column contains a single label value (e.g. a class
        with no positives in the evaluation set), which would abort the whole
        evaluation. Here each class is scored separately so one degenerate
        class cannot discard the results of all the others.

        Args:
          target: array-like, (audios_num, classes_num) ground-truth labels.
          clipwise_output: array-like, same shape as ``target``, model scores.

        Returns:
          ndarray of shape (classes_num,) for 2-D multi-class input; otherwise
          whatever ``metrics.roc_auc_score(..., average=None)`` returns.
        """
        target = np.asarray(target)
        clipwise_output = np.asarray(clipwise_output)

        if target.ndim != 2 or target.shape[1] < 2:
            # Not the documented multi-class layout: keep sklearn's own behavior.
            return metrics.roc_auc_score(target, clipwise_output, average=None)

        classes_num = target.shape[1]
        auc = np.full(classes_num, np.nan)
        undefined = []
        for k in range(classes_num):
            # ROC AUC needs both label values present in the column.
            if np.unique(target[:, k]).size < 2:
                undefined.append(k)
                continue
            auc[k] = metrics.roc_auc_score(target[:, k], clipwise_output[:, k])

        if undefined:
            warnings.warn(
                'ROC AUC is undefined for {} class(es) whose target contains a '
                'single value (indices: {}); reporting NaN for them.'.format(
                    len(undefined), undefined),
                RuntimeWarning)
        return auc