PyTorch
ssl-aasist
custom_code
ash56's picture
Add files using upload-large-folder tool
d28af7f verified
raw
history blame
2.03 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import glob
import numpy as np
from . import metric as metric_path
from . import predictor as predictor_path
class Evaluator(object):
    """Perform evaluation on a single (downstream) task.

    Two entry points are provided:

    * offline (``__call__``): re-score prediction files that the predictor
      has already written to ``predictor.pred_dir``;
    * online (``evaluate``): run the predictor over a model and dataloader,
      then score the resulting outputs.

    TODO(huxu) saving evaluation results.
    """

    def __init__(self, config, eval_dataloader=None):
        """Instantiate the metric and predictor named in ``config``.

        Args:
            config: object exposing ``metric`` and ``predictor`` attributes,
                each naming a class in the sibling ``metric`` / ``predictor``
                modules. The whole ``config`` is passed to both constructors.
            eval_dataloader: default dataloader used by ``evaluate`` when the
                caller does not supply one.

        Raises:
            ValueError: if ``config.metric`` or ``config.predictor`` is None.
        """
        # A single formatted message; passing two positional args to
        # ValueError would render the message as a tuple repr.
        if config.metric is None:
            raise ValueError("config.metric is {}".format(config.metric))
        metric_cls = getattr(metric_path, config.metric)
        self.metric = metric_cls(config)
        if config.predictor is None:
            raise ValueError("config.predictor is {}".format(config.predictor))
        predictor_cls = getattr(predictor_path, config.predictor)
        self.predictor = predictor_cls(config)
        self.eval_dataloader = eval_dataloader

    def __call__(self):
        """Score prediction files already saved in ``predictor.pred_dir``.

        Each ``*_merged.npy`` file in the directory is loaded, scored and its
        metrics printed (in sorted filename order). Then ``merged.npy`` is
        loaded and scored, and those results are returned.

        Returns:
            ``{"results": results, "metric": self.metric}`` for
            ``merged.npy``, or ``{}`` (after printing a ``[missing]`` notice)
            when ``merged.npy`` does not exist.
        """
        pred_dir = self.predictor.pred_dir
        print(pred_dir)
        # Escape the directory so glob metacharacters in its name (e.g. "[1]")
        # are matched literally; sort because glob order is arbitrary.
        pattern = os.path.join(glob.escape(pred_dir), "*_merged.npy")
        for pred_file in sorted(glob.glob(pattern)):
            outputs = np.load(pred_file)
            results = self.metric.compute_metrics(outputs)
            self.metric.print_computed_metrics(results)

        # Only a missing merged.npy counts as "predictions missing"; errors
        # raised while computing metrics are not swallowed.
        try:
            outputs = np.load(os.path.join(pred_dir, "merged.npy"))
        except FileNotFoundError:
            print("\n[missing]", pred_dir)
            return {}
        results = self.metric.compute_metrics(outputs)
        return {"results": results, "metric": self.metric}

    def evaluate(self, model, eval_dataloader=None, output_file="merged"):
        """Run the predictor on ``model`` and score its outputs.

        Args:
            model: model forwarded to ``predictor.predict_loop``.
            eval_dataloader: dataloader to predict on; falls back to the one
                given at construction when None.
            output_file: name forwarded to ``predict_loop`` (presumably the
                basename its outputs are saved under; confirm in predictor).

        Returns:
            Whatever ``metric.compute_metrics`` returns when called with the
            mapping returned by ``predict_loop`` expanded as keyword args.
        """
        if eval_dataloader is None:
            eval_dataloader = self.eval_dataloader
        outputs = self.predictor.predict_loop(
            model, eval_dataloader, output_file)
        return self.metric.compute_metrics(**outputs)