import os

import pytest

import lm_eval.api as api
import lm_eval.evaluator as evaluator
from lm_eval import tasks


@pytest.mark.parametrize(
    "limit,model,model_args",
    [
        (
            10,
            "hf",
            "pretrained=EleutherAI/pythia-160m,dtype=float32,device=cpu",
        ),
    ],
)
def test_include_correctness(limit: int, model: str, model_args: str):
    task_name = ["arc_easy"]
    task_manager = tasks.TaskManager()
    # built here for parity with the second run; simple_evaluate() resolves
    # tasks by name itself, so this dict is not passed below
    task_dict = tasks.get_task_dict(task_name, task_manager)

    e1 = evaluator.simple_evaluate(
        model=model,
        tasks=task_name,
        limit=limit,
        model_args=model_args,
    )
    assert e1 is not None

    # run with evaluate() and "arc_easy" test config (included from ./testconfigs path)
    lm = api.registry.get_model(model).create_from_arg_string(
        model_args,
        {
            "batch_size": None,
            "max_batch_size": None,
            "device": None,
        },
    )

    task_name = ["arc_easy"]
    task_manager = tasks.TaskManager(
        include_path=os.path.dirname(os.path.abspath(__file__)) + "/testconfigs",
        include_defaults=False,
    )
    task_dict = tasks.get_task_dict(task_name, task_manager)

    e2 = evaluator.evaluate(
        lm=lm,
        task_dict=task_dict,
        limit=limit,
    )
    assert e2 is not None

    # check that the run from the included ./testconfigs config reproduces
    # the results of the default "arc_easy" config
    def r(x):
        return x["results"]["arc_easy"]

    assert all(
        x == y
        for x, y in zip([y for _, y in r(e1).items()], [y for _, y in r(e2).items()])
    )
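    # For reference, r(e1) / r(e2) are per-metric dicts roughly of the shape
    # {"acc,none": ..., "acc_stderr,none": ..., "alias": "arc_easy"}
    # (illustrative keys; exact metric names depend on the lm-eval version),
    # so the zip above compares the two runs metric-by-metric.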


# test that setting include_defaults=False works as expected and that include_path works
def test_no_include_defaults():
    task_name = ["arc_easy"]
    task_manager = tasks.TaskManager(
        include_path=os.path.dirname(os.path.abspath(__file__)) + "/testconfigs",
        include_defaults=False,
    )

    # should succeed, because we've included an 'arc_easy' task from this dir
    task_dict = tasks.get_task_dict(task_name, task_manager)

    # should fail, since ./testconfigs has no arc_challenge task
    task_name = ["arc_challenge"]
    with pytest.raises(KeyError):
        task_dict = tasks.get_task_dict(task_name, task_manager)  # noqa: F841
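

# Extra check (a hedged sketch, not part of the original suite): newer
# TaskManager versions expose an `all_tasks` listing of registered task
# names. If that attribute is present, we can assert directly that
# include_defaults=False hides everything except the tasks found under
# ./testconfigs; otherwise we skip rather than fail, since `all_tasks`
# is an assumed API here.
def test_only_included_tasks_visible():
    task_manager = tasks.TaskManager(
        include_path=os.path.dirname(os.path.abspath(__file__)) + "/testconfigs",
        include_defaults=False,
    )
    if not hasattr(task_manager, "all_tasks"):
        pytest.skip("TaskManager.all_tasks not available in this version")
    assert "arc_easy" in task_manager.all_tasks
    assert "arc_challenge" not in task_manager.all_tasks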


# test that include_path containing a task shadowing another task's name fails
# def test_shadowed_name_fails():
#     task_name = ["arc_easy"]
#     task_manager = tasks.TaskManager(
#         include_path=os.path.dirname(os.path.abspath(__file__)) + "/testconfigs"
#     )
#     task_dict = tasks.get_task_dict(task_name, task_manager)
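#
# A possible completion of the stub above (kept commented out like the
# original, since the exception type raised on a shadowed task name is an
# assumption here, not confirmed behavior):
# def test_shadowed_name_fails():
#     task_manager = tasks.TaskManager(
#         include_path=os.path.dirname(os.path.abspath(__file__)) + "/testconfigs"
#     )
#     with pytest.raises(ValueError):
#         tasks.get_task_dict(["arc_easy"], task_manager)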