Spaces:
Sleeping
Sleeping
""" | |
Copyright (c) 2022, salesforce.com, inc. | |
All rights reserved. | |
SPDX-License-Identifier: BSD-3-Clause | |
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
""" | |
import json | |
import os | |
import logging | |
import numpy as np | |
import torch | |
from lavis.common.dist_utils import main_process | |
from lavis.common.registry import registry | |
from lavis.tasks.base_task import BaseTask | |
class MultimodalClassificationTask(BaseTask): | |
def __init__(self): | |
super().__init__() | |
def valid_step(self, model, samples): | |
results = [] | |
outputs = model.predict(samples) | |
predictions = outputs["predictions"] | |
targets = outputs["targets"] | |
predictions = predictions.max(1)[1].cpu().numpy() | |
targets = targets.cpu().numpy() | |
indices = samples[self.inst_id_key] | |
for pred, tgt, index in zip(predictions, targets, indices): | |
if isinstance(index, torch.Tensor): | |
index = index.item() | |
results.append( | |
{ | |
self.inst_id_key: index, | |
"prediction": pred.item(), | |
"target": tgt.item(), | |
} | |
) | |
return results | |
def after_evaluation(self, val_result, split_name, epoch, **kwargs): | |
eval_result_file = self.save_result( | |
result=val_result, | |
result_dir=registry.get_path("result_dir"), | |
filename="{}_epoch{}".format(split_name, epoch), | |
remove_duplicate=self.inst_id_key, | |
) | |
metrics = self._report_metrics( | |
eval_result_file=eval_result_file, split_name=split_name | |
) | |
return metrics | |
def _report_metrics(self, eval_result_file, split_name): | |
results = json.load(open(eval_result_file)) | |
predictions = np.array([res["prediction"] for res in results]) | |
targets = np.array([res["target"] for res in results]) | |
accuracy = (targets == predictions).sum() / targets.shape[0] | |
metrics = {"agg_metrics": accuracy, "acc": accuracy} | |
log_stats = {split_name: {k: v for k, v in metrics.items()}} | |
with open( | |
os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a" | |
) as f: | |
f.write(json.dumps(log_stats) + "\n") | |
logging.info(metrics) | |
return metrics | |