File size: 1,773 Bytes
894c4b4
1109e5f
652d88f
894c4b4
 
e992815
0b755b6
620ce47
894c4b4
 
e1b962a
894c4b4
 
 
1109e5f
652d88f
1109e5f
652d88f
894c4b4
 
 
dbd4d1b
669da77
 
ad7bcbf
4719e45
894c4b4
 
 
 
 
e1b962a
 
 
 
4e10b3e
e1b962a
 
894c4b4
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from lm_eval import tasks, evaluator, utils
from lm_eval.tasks import initialize_tasks, include_task_folder

from src.backend.manage_requests import EvalRequest

from src.backend.tasks.xsum.task import XSum
from src.backend.tasks.cnndm.task import CNNDM
from src.backend.tasks.selfcheckgpt.task import SelfCheckGpt


def run_evaluation(eval_request: EvalRequest, task_names, num_fewshot, batch_size, device, use_cache=None, limit=None, max_nb_samples=100) -> dict:
    """Evaluate the model described by *eval_request* and return the results.

    The task folder ``src/backend/tasks/`` is registered, the task registry
    is initialised, and ``task_names`` is resolved against
    ``tasks.ALL_TASKS`` with ``utils.pattern_match`` before being handed to
    ``evaluator.simple_evaluate`` (model type ``"hf-auto"``).

    Args:
        eval_request: Supplies the model arguments (``get_model_args()``)
            and the ``precision``, ``model`` and ``revision`` values that
            are written into the result config.
        task_names: Task names/patterns to match against the known tasks.
        num_fewshot: Forwarded to ``simple_evaluate``.
        batch_size: Forwarded to ``simple_evaluate``.
        device: Forwarded to ``simple_evaluate``.
        use_cache: Forwarded to ``simple_evaluate``.
        limit: Forwarded to ``simple_evaluate``; a warning is printed when
            it is truthy because it is meant for testing only.
        max_nb_samples: When not ``None``, each per-task entry under
            ``results['samples']`` is cut down to at most this many items.

    Returns:
        The results dict from ``simple_evaluate``, with ``model_dtype``,
        ``model_name`` and ``model_sha`` added under ``"config"``.
    """
    if limit:
        print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")

    # Register the project's task definitions before resolving names.
    include_task_folder("src/backend/tasks/")
    initialize_tasks('INFO')

    task_names = utils.pattern_match(task_names, tasks.ALL_TASKS)
    print(f"Selected Tasks: {task_names}")

    # Other model type names noted here originally:
    # "hf-causal-experimental", "hf-causal"
    results = evaluator.simple_evaluate(
        model="hf-auto",
        model_args=eval_request.get_model_args(),
        tasks=task_names,
        num_fewshot=num_fewshot,
        batch_size=batch_size,
        device=device,
        use_cache=use_cache,
        limit=limit,
        write_out=True,
    )

    # Record which model produced these numbers.
    run_config = results["config"]
    run_config["model_dtype"] = eval_request.precision
    run_config["model_name"] = eval_request.model
    run_config["model_sha"] = eval_request.revision

    # Keep at most ``max_nb_samples`` logged samples per task.
    if max_nb_samples is not None and 'samples' in results:
        per_task_samples = results['samples']
        for name, records in per_task_samples.items():
            if len(records) > max_nb_samples:
                # Rebinding an existing key is safe while iterating.
                per_task_samples[name] = records[:max_nb_samples]

    print(evaluator.make_table(results))

    return results