from lm_eval import evaluator
from lm_eval.tasks import TaskManager

from src.backend.manage_requests import EvalRequest

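# These task imports are not referenced below; they appear to be needed for their
# side effects, i.e. making the custom task classes importable so the TaskManager
# can resolve them from ./src/backend/tasks/ (an assumption based on this file alone).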
from src.backend.tasks.xsum.task import XSum
from src.backend.tasks.xsum.task_v2 import XSumv2

from src.backend.tasks.cnndm.task import CNNDM
from src.backend.tasks.cnndm.task_v2 import CNNDMv2

from src.backend.tasks.selfcheckgpt.task import SelfCheckGPT

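# Custom HF model wrapper; per the note below, the "hf-chat" model type applies the
# tokenizer's chat template. Presumably imported here for its registration side effect.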
from src.backend.huggingface_generate_until import HFLMwithChatTemplate

def run_evaluation(eval_request: EvalRequest, task_names, num_fewshot, batch_size,
                   device, use_cache=None, limit=None, max_nb_samples=100) -> dict:
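    """Run the lm-eval harness on `task_names` and return its results dict.

    `limit` caps the number of documents per task (testing only), and
    `max_nb_samples` truncates the per-task sample outputs kept in the results.
    """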
    if limit:
        print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")

    # include_task_folder("src/backend/tasks/")
    # initialize_tasks('INFO')

    print(f"Allocating task manager for: {task_names}")

    task_manager = TaskManager(include_path="./src/backend/tasks/")
    # task_manager.initialize_tasks('INFO')

    print(f"Considered Tasks: {task_names}")
    # print(f"Allowed Tasks: {tasks.ALL_TASKS}")

    # task_names = utils.pattern_match(task_names, tasks.ALL_TASKS)

    print(f"Selected Tasks: {task_names}")
    print(f"Eval Request: {eval_request.get_model_args()}")
    # hf-chat is implemented to use apply_chat_template
    results = evaluator.simple_evaluate(model="hf-auto",  # "hf-causal-experimental",  # "hf-causal", hf-chat
                                        model_args=eval_request.get_model_args(),
                                        tasks=task_names,
                                        num_fewshot=num_fewshot,
                                        batch_size=batch_size,
                                        max_batch_size=8,
                                        device=device,
                                        use_cache=use_cache,
                                        limit=limit,
                                        write_out=True,
                                        task_manager=task_manager)

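    # Attach model metadata to the results config for downstream reporting.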
    results["config"]["model_dtype"] = eval_request.precision
    results["config"]["model_name"] = eval_request.model
    results["config"]["model_sha"] = eval_request.revision

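    # Keep at most `max_nb_samples` sample records per task to bound the payload size.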
    if max_nb_samples is not None and 'samples' in results:
        samples = results['samples']
        for task_name in samples:
            if len(samples[task_name]) > max_nb_samples:
                samples[task_name] = samples[task_name][:max_nb_samples]

    # print(evaluator.make_table(results))

    return results
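
# Example usage (a minimal sketch; the EvalRequest constructor arguments shown here
# are assumptions based only on the attributes accessed above, i.e. `model`,
# `revision`, `precision`, and `get_model_args()`):
#
#     request = EvalRequest(model="gpt2", revision="main", precision="float16")
#     results = run_evaluation(request, task_names=["selfcheckgpt"], num_fewshot=0,
#                              batch_size=1, device="cuda:0", limit=2)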