# File size: 1,907 Bytes
# 5a63fc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# This script runs lm_eval_harness evaluations against a served language model.
# Typically, you need to run a language model server first, e.g.:
#    python -m EasyLM.models.gptj.gptj_serve ...

import dataclasses
import pprint
from functools import partial
import os
from tqdm import tqdm, trange
import numpy as np
import mlxu

from flax.traverse_util import flatten_dict
from lm_eval import evaluator, tasks
from lm_eval.base import LM

from EasyLM.serving import LMClient


# Command-line flags for this script, declared together with their defaults.
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
    # Comma-separated task names; main() splits this on ',' and hands the
    # result to tasks.get_task_dict.
    tasks='wsc,piqa,winogrande,openbookqa,logiqa',
    # Passed as the fourth positional argument of evaluator.evaluate
    # (presumably the number of few-shot examples -- confirm against lm_eval).
    shots=0,
    # Values <= 0 are translated to limit=None in main().
    limit=0,
    # Forwarded unchanged as evaluator.evaluate(write_out=...).
    write_out=False,
    # Configuration used to construct the LMClient in main().
    lm_client=LMClient.get_default_config(),
    # Configuration used to construct the mlxu.WandBLogger in main().
    logger=mlxu.WandBLogger.get_default_config(),
)


class LMEvalHarnessInterface(LM):
    """Adapter that exposes an LM client through lm_eval's ``LM`` base class.

    Every method forwards the incoming request batch to the method of the
    same name on ``lm_client`` and returns the reply in list form.
    """

    def __init__(self, lm_client):
        """Store the client that answers all requests.

        Args:
            lm_client: Object providing ``greedy_until``, ``loglikelihood``
                and ``loglikelihood_rolling`` (an ``LMClient`` in this script).
        """
        # Run the base-class initializer so that any state it sets up is
        # present on this instance as well.
        super().__init__()
        self.lm_client = lm_client

    def greedy_until(self, inputs):
        """Forward generation requests to the client.

        Args:
            inputs: Iterable of ``(prefix, until)`` pairs.

        Returns:
            Whatever ``lm_client.greedy_until`` returns for the batch, or an
            empty list when ``inputs`` is empty.
        """
        requests = list(inputs)
        if not requests:
            # zip(*[]) yields nothing, so the two-name unpacking below would
            # raise ValueError; an empty batch simply has no results.
            return []
        prefix, until = zip(*requests)
        return self.lm_client.greedy_until(prefix, until)

    def loglikelihood_rolling(self, inputs):
        """Forward rolling log-likelihood requests to the client.

        Args:
            inputs: Request batch, passed to the client unchanged.

        Returns:
            List of ``(loglikelihood, is_greedy)`` pairs built by zipping the
            two sequences returned by the client.
        """
        loglikelihood, is_greedy = self.lm_client.loglikelihood_rolling(inputs)
        return list(zip(loglikelihood, is_greedy))

    def loglikelihood(self, inputs):
        """Forward log-likelihood requests to the client.

        Args:
            inputs: Iterable of ``(prefix, text)`` pairs.

        Returns:
            List of ``(loglikelihood, is_greedy)`` pairs, one per request, or
            an empty list when ``inputs`` is empty.
        """
        requests = list(inputs)
        if not requests:
            # Same empty-batch guard as in greedy_until.
            return []
        prefix, text = zip(*requests)
        loglikelihood, is_greedy = self.lm_client.loglikelihood(prefix, text)
        return list(zip(loglikelihood, is_greedy))


def main(argv):
    """Evaluate the served model on the configured tasks and report results.

    Builds the task list from ``FLAGS.tasks``, runs ``evaluator.evaluate``
    against an ``LMEvalHarnessInterface`` wrapping an ``LMClient``, logs the
    flattened ``results['results']`` mapping to the WandB logger and
    pretty-prints the full result dictionary.

    Args:
        argv: Positional command-line arguments supplied by the launcher;
            not used here.

    Raises:
        ValueError: If ``FLAGS.tasks`` contains no task names.
    """
    # Tolerate spaces around names and stray commas ("wsc, piqa," etc.), which
    # would otherwise be looked up verbatim as task names.
    task_list = [
        name.strip() for name in FLAGS.tasks.split(',') if name.strip()
    ]
    if not task_list:
        # Fail before a logger or client is created.
        raise ValueError(
            'No evaluation tasks given; --tasks must be a comma-separated '
            'list of task names.'
        )

    logger = mlxu.WandBLogger(
        config=FLAGS.logger, variant=mlxu.get_user_flags(FLAGS, FLAGS_DEF)
    )
    model = LMEvalHarnessInterface(LMClient(FLAGS.lm_client))
    results = evaluator.evaluate(
        model, tasks.get_task_dict(task_list), False, FLAGS.shots,
        # A non-positive limit flag means "no limit".
        limit=None if FLAGS.limit <= 0 else FLAGS.limit,
        write_out=FLAGS.write_out,
    )
    logger.log(flatten_dict(results['results'], sep='/'))
    pprint.pprint(results)


# Script entry point: when executed directly, hand main() to mlxu.run.
if __name__ == "__main__":
    mlxu.run(main)