File size: 1,659 Bytes
a85f909
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import json
import mlxu
from EasyLM.serving import LMClient


# Command-line flags for this script. NOTE(review): FLAGS_DEF is not used in
# the visible code; presumably kept for mlxu conventions — confirm.
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
    # Path of the JSON file read (via mlxu.open_file + json.load) as the request.
    input_file='',
    # Path the JSON-encoded result is written to.
    output_file='',
    # Key in the input JSON holding the prefix; used by 'loglikelihood',
    # 'greedy_until' and 'generate'.
    prefix_field='prefix',
    # Key in the input JSON holding the text; used by 'loglikelihood' and
    # 'loglikelihood_rolling'.
    text_field='text',
    # Key in the input JSON holding the stop value(s); used by 'greedy_until'.
    until_field='until',
    # Which evaluation to run: 'loglikelihood', 'loglikelihood_rolling',
    # 'greedy_until' or 'generate'. Any other value raises ValueError in main.
    eval_type='loglikelihood',
    # Configuration passed to the LMClient constructor.
    lm_client=LMClient.get_default_config(),
)


def _loglikelihood(lm_client, input_data):
    """Call ``lm_client.loglikelihood`` with the prefix and text fields.

    Returns:
        A dict holding the two values returned by the client under the
        ``'loglikelihood'`` and ``'is_greedy'`` keys.
    """
    loglikelihoods, is_greedys = lm_client.loglikelihood(
        input_data[FLAGS.prefix_field], input_data[FLAGS.text_field]
    )
    return {'loglikelihood': loglikelihoods, 'is_greedy': is_greedys}


def _loglikelihood_rolling(lm_client, input_data):
    """Call ``lm_client.loglikelihood_rolling`` with the text field.

    Returns:
        A dict holding the two values returned by the client under the
        ``'loglikelihood'`` and ``'is_greedy'`` keys.
    """
    loglikelihoods, is_greedys = lm_client.loglikelihood_rolling(
        input_data[FLAGS.text_field]
    )
    return {'loglikelihood': loglikelihoods, 'is_greedy': is_greedys}


def _greedy_until(lm_client, input_data):
    """Call ``lm_client.greedy_until`` with the prefix and until fields.

    Returns:
        A dict holding the client's result under the ``'output_text'`` key.
    """
    return {
        'output_text': lm_client.greedy_until(
            input_data[FLAGS.prefix_field], input_data[FLAGS.until_field]
        )
    }


def _generate(lm_client, input_data):
    """Call ``lm_client.generate`` with the prefix field.

    Returns:
        A dict holding the client's result under the ``'output_text'`` key.
    """
    return {'output_text': lm_client.generate(input_data[FLAGS.prefix_field])}


# Maps each supported ``eval_type`` flag value to the handler that runs it.
_EVAL_HANDLERS = {
    'loglikelihood': _loglikelihood,
    'loglikelihood_rolling': _loglikelihood_rolling,
    'greedy_until': _greedy_until,
    'generate': _generate,
}


def main(argv):
    """Run one evaluation request through an LMClient and write the result.

    Reads JSON from ``FLAGS.input_file`` and dispatches on
    ``FLAGS.eval_type`` to the matching ``LMClient`` method. The resulting
    dict is written as JSON to ``FLAGS.output_file``.

    Args:
        argv: Positional command-line arguments supplied by ``mlxu.run``;
            unused.

    Raises:
        ValueError: If ``FLAGS.eval_type`` is not a supported evaluation type.
        KeyError: If the input JSON object lacks a field required by the
            selected evaluation type.
    """
    del argv  # Unused; required by the mlxu.run entry-point signature.

    # Validate the eval type before doing any I/O or creating the client, so
    # a typo in --eval_type fails immediately.
    handler = _EVAL_HANDLERS.get(FLAGS.eval_type)
    if handler is None:
        raise ValueError(f'Unknown eval_type: {FLAGS.eval_type}')

    with mlxu.open_file(FLAGS.input_file, 'r') as fin:
        input_data = json.load(fin)

    # Constructed only after the request is known to be well-formed.
    # NOTE(review): LMClient's constructor may wait on/contact the server —
    # confirm in EasyLM.serving.
    lm_client = LMClient(FLAGS.lm_client)
    output_data = handler(lm_client, input_data)

    with mlxu.open_file(FLAGS.output_file, 'w') as fout:
        json.dump(output_data, fout)


# Script entry point: mlxu.run parses the flags defined above and then
# invokes main with the remaining command-line arguments.
# NOTE(review): argument-passing behaviour is mlxu's — confirm there.
if __name__ == "__main__":
    mlxu.run(main)