# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import glob
import numpy as np

from . import metric as metric_path
from . import predictor as predictor_path


class Evaluator(object):
    """
    Perform evaluation on a single (downstream) task.

    Two entry points are provided:

    * calling the instance (offline) scores prediction files already saved
      under ``self.predictor.pred_dir``;
    * ``evaluate`` (online) runs the predictor on a model and scores the
      outputs it returns.

    TODO(huxu) saving evaluation results.
    """

    def __init__(self, config, eval_dataloader=None):
        """
        Build the metric and predictor named by ``config``.

        Args:
            config: object exposing ``metric`` and ``predictor`` attributes.
                Each is used as an attribute name looked up on the sibling
                ``metric`` / ``predictor`` module, and the class found there
                is instantiated with ``config``.
            eval_dataloader: default dataloader used by ``evaluate`` when
                the caller does not pass one.

        Raises:
            ValueError: if ``config.metric`` or ``config.predictor`` is None.
            AttributeError: if the named class does not exist in the
                corresponding module (``getattr`` is called without a
                default).
        """
        if config.metric is None:
            # A single message string; passing several positional args to
            # ValueError makes the message render as a tuple.
            raise ValueError(
                "config.metric is None; expected the name of a metric class")
        metric_cls = getattr(metric_path, config.metric)
        self.metric = metric_cls(config)

        if config.predictor is None:
            raise ValueError(
                "config.predictor is None; "
                "expected the name of a predictor class")
        predictor_cls = getattr(predictor_path, config.predictor)
        self.predictor = predictor_cls(config)

        self.eval_dataloader = eval_dataloader

    def __call__(self):
        """
        Score the predictions stored in ``self.predictor.pred_dir``.

        Every ``*_merged.npy`` file in the directory is loaded, scored and
        printed via the metric. Then ``merged.npy`` is loaded and scored,
        and that final result is returned.

        Returns:
            ``{"results": <metrics of merged.npy>, "metric": self.metric}``
            on success, or ``{}`` if a ``FileNotFoundError`` is raised
            anywhere in the process (a ``[missing]`` line is printed).
        """
        pred_dir = self.predictor.pred_dir
        try:
            print(pred_dir)
            # Escape the directory so glob metacharacters in its name are
            # matched literally; sort so the printout order is stable
            # (glob.glob returns entries in arbitrary order).
            pattern = os.path.join(glob.escape(pred_dir), "*_merged.npy")
            for pred_file in sorted(glob.glob(pattern)):
                outputs = np.load(pred_file)
                results = self.metric.compute_metrics(outputs)
                self.metric.print_computed_metrics(results)

            outputs = np.load(os.path.join(pred_dir, "merged.npy"))
            results = self.metric.compute_metrics(outputs)
            return {"results": results, "metric": self.metric}
        except FileNotFoundError:
            # Best effort: missing predictions are reported, not fatal.
            print("\n[missing]", pred_dir)
            return {}

    def evaluate(self, model, eval_dataloader=None, output_file="merged"):
        """
        Run the predictor on ``model`` and compute metrics on its outputs.

        Args:
            model: passed through to ``self.predictor.predict_loop``.
            eval_dataloader: dataloader to predict on; falls back to the
                one given at construction time when None.
            output_file: passed through to ``predict_loop`` as its third
                argument.

        Returns:
            Whatever ``self.metric.compute_metrics`` returns.
        """
        if eval_dataloader is None:
            eval_dataloader = self.eval_dataloader
        outputs = self.predictor.predict_loop(
            model, eval_dataloader, output_file)
        # predict_loop's result is unpacked as keyword arguments, so it
        # must be a mapping.
        results = self.metric.compute_metrics(**outputs)
        return results