pminervini committed on
Commit
05346b7
1 Parent(s): 68d5bd5
Files changed (1) hide show
  1. cli/halueval-cli.py +3 -3
cli/halueval-cli.py CHANGED
@@ -28,7 +28,7 @@ def main():
28
  eval_requests: list[EvalRequest] = get_eval_requests(job_status=status, hf_repo=QUEUE_REPO, local_dir=EVAL_REQUESTS_PATH_BACKEND)
29
  eval_request = [r for r in eval_requests if 'bloom-560m' in r.model][0]
30
 
31
- TASKS_HARNESS = [t.value for t in Tasks if 'halueval_qa' in t.value.benchmark]
32
  # task_names = ['triviaqa']
33
  # TASKS_HARNESS = [task.value for task in Tasks]
34
 
@@ -41,8 +41,8 @@ def main():
41
 
42
  for task in TASKS_HARNESS:
43
  print(f"Selected Tasks: [{task}]")
44
- results = evaluator.simple_evaluate(model="hf", model_args=eval_request.get_model_args(), tasks=[task.benchmark], num_fewshot=0,
45
- batch_size=1, device=DEVICE, use_cache=None, limit=8, write_out=True)
46
  print('AAA', results["results"])
47
 
48
  breakpoint()
 
28
  eval_requests: list[EvalRequest] = get_eval_requests(job_status=status, hf_repo=QUEUE_REPO, local_dir=EVAL_REQUESTS_PATH_BACKEND)
29
  eval_request = [r for r in eval_requests if 'bloom-560m' in r.model][0]
30
 
31
+ TASKS_HARNESS = [t.value for t in Tasks if 'xsum' in t.value.benchmark]
32
  # task_names = ['triviaqa']
33
  # TASKS_HARNESS = [task.value for task in Tasks]
34
 
 
41
 
42
  for task in TASKS_HARNESS:
43
  print(f"Selected Tasks: [{task}]")
44
+ results = evaluator.simple_evaluate(model="hf", model_args=eval_request.get_model_args(), tasks=[task.benchmark], num_fewshot=1,
45
+ batch_size=1, device="mps", use_cache=None, limit=1, write_out=True)
46
  print('AAA', results["results"])
47
 
48
  breakpoint()