import os
import glob
import numpy as np

from . import metric as metric_path
from . import predictor as predictor_path

class Evaluator(object):
    """
    Perform evaluation on a single (downstream) task.
    Works both offline (scoring saved prediction files) and online
    (running a model over a dataloader).
    TODO(huxu): saving evaluation results.
    """

    def __init__(self, config, eval_dataloader=None):
        if config.metric is None:
            raise ValueError("config.metric is not set.")
        # resolve the metric class by name from the sibling `metric` module.
        metric_cls = getattr(metric_path, config.metric)
        self.metric = metric_cls(config)
        if config.predictor is None:
            raise ValueError("config.predictor is not set.")
        # resolve the predictor class by name from the sibling `predictor` module.
        predictor_cls = getattr(predictor_path, config.predictor)
        self.predictor = predictor_cls(config)
        self.eval_dataloader = eval_dataloader

    def __call__(self):
        try:
            print(self.predictor.pred_dir)
            # offline evaluation: score each per-shard prediction file.
            for pred_file in glob.glob(
                    self.predictor.pred_dir + "/*_merged.npy"):
                outputs = np.load(pred_file)
                results = self.metric.compute_metrics(outputs)
                self.metric.print_computed_metrics(results)

            # the fully merged predictions produce the returned results.
            outputs = np.load(os.path.join(
                self.predictor.pred_dir, "merged.npy"))
            results = self.metric.compute_metrics(outputs)
            return {"results": results, "metric": self.metric}
        except FileNotFoundError:
            print("\n[missing]", self.predictor.pred_dir)
            return {}

    def evaluate(self, model, eval_dataloader=None, output_file="merged"):
        # online evaluation: run the model over the dataloader, then score.
        if eval_dataloader is None:
            eval_dataloader = self.eval_dataloader
        outputs = self.predictor.predict_loop(
            model, eval_dataloader, output_file)
        results = self.metric.compute_metrics(**outputs)
        return results
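
# Usage sketch (not part of the original module): the config field values below
# (`RetrievalMetric`, `RetrievalPredictor`) are illustrative assumptions; any
# metric/predictor class exported by the sibling `metric` and `predictor`
# modules can be referenced by name in the config.
#
#   config = SomeConfig(metric="RetrievalMetric",
#                       predictor="RetrievalPredictor")
#   evaluator = Evaluator(config, eval_dataloader)
#   results = evaluator.evaluate(model)   # online: run the model, then score
#   cached = evaluator()                  # offline: score saved *_merged.npy files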