File size: 540 Bytes
4d4dd90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch

from ..utils.tools import get_class
from .eval_pipeline import EvalPipeline


def get_benchmark(benchmark):
    """Look up the pipeline class for the benchmark called ``benchmark``.

    The dotted path ``<this module's name>.<benchmark>`` is handed to
    ``get_class`` together with ``EvalPipeline``; whatever ``get_class``
    returns is passed back unchanged.
    """
    dotted_path = "{}.{}".format(__name__, benchmark)
    return get_class(dotted_path, EvalPipeline)


@torch.no_grad()
def run_benchmark(benchmark, eval_conf, experiment_dir, model=None):
    """Run the named benchmark, overwriting any existing results.

    Args:
        benchmark: Name used by ``get_benchmark`` to find the pipeline class.
        eval_conf: Configuration passed to the pipeline constructor.
        experiment_dir: Output directory; must provide ``mkdir`` (e.g. a
            ``pathlib.Path``). It is created, with parents, if missing.
        model: Optional model forwarded to the pipeline's ``run``.

    Returns:
        Whatever the pipeline's ``run`` method returns.
    """
    # Ensure the output location exists before resolving the benchmark.
    experiment_dir.mkdir(parents=True, exist_ok=True)

    pipeline_cls = get_benchmark(benchmark)
    runner = pipeline_cls(eval_conf)

    # Both overwrite flags are forced on for every run.
    results = runner.run(
        experiment_dir,
        model=model,
        overwrite=True,
        overwrite_eval=True,
    )
    return results