|
import json |
|
import mlxu |
|
from EasyLM.serving import LMClient |
|
|
|
|
|
# Command-line flags for this script. ``FLAGS`` is what main() reads;
# ``FLAGS_DEF`` is returned alongside it by mlxu and is not referenced
# elsewhere in this file's visible code.
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
    # Path of the JSON file main() loads its request data from.
    input_file='',
    # Path main() writes its JSON results to.
    output_file='',
    # Keys main() looks up in the loaded input JSON object.
    prefix_field='prefix',
    text_field='text',
    until_field='until',
    # Evaluation mode; main() handles 'loglikelihood',
    # 'loglikelihood_rolling', 'greedy_until' and 'generate'.
    eval_type='loglikelihood',
    # Configuration handed to the LMClient constructor in main().
    lm_client=LMClient.get_default_config(),
)
|
|
|
|
|
def main(argv):
    """Evaluate one batch of requests read from a JSON file and save results.

    The input file is loaded with ``json.load`` and indexed by the field
    names configured in ``FLAGS`` (``prefix_field``, ``text_field``,
    ``until_field``), so it must contain a JSON object. The values stored
    under those keys are passed straight to ``LMClient``; presumably they
    are parallel lists of strings (one entry per example) -- confirm
    against ``LMClient``.

    Which fields are read, which ``LMClient`` method is called, and what is
    written to ``FLAGS.output_file`` depend on ``FLAGS.eval_type``:

    * ``'loglikelihood'``: ``lm_client.loglikelihood(prefix, text)``;
      writes ``{'loglikelihood': ..., 'is_greedy': ...}``.
    * ``'loglikelihood_rolling'``: ``lm_client.loglikelihood_rolling(text)``;
      writes ``{'loglikelihood': ..., 'is_greedy': ...}``.
    * ``'greedy_until'``: ``lm_client.greedy_until(prefix, until)``;
      writes ``{'output_text': ...}``.
    * ``'generate'``: ``lm_client.generate(prefix)``;
      writes ``{'output_text': ...}``.

    Args:
        argv: Remaining command-line arguments passed in by ``mlxu.run``;
            not used.

    Raises:
        ValueError: If ``eval_type`` is not one of the modes above, or if
            ``input_file`` or ``output_file`` is empty.
        KeyError: If the input JSON object lacks a field the chosen mode
            needs.
    """
    # Validate the configuration before constructing the client or making
    # any LM calls. Previously, an unknown eval_type was only detected after
    # the input was loaded. An empty output_file was only detected after
    # every LM call had finished, which threw the results away.
    supported_eval_types = (
        'loglikelihood',
        'loglikelihood_rolling',
        'greedy_until',
        'generate',
    )
    if FLAGS.eval_type not in supported_eval_types:
        raise ValueError(f'Unknown eval_type: {FLAGS.eval_type}')
    if not FLAGS.input_file:
        raise ValueError('input_file must be set to a JSON input path')
    if not FLAGS.output_file:
        raise ValueError('output_file must be set to a JSON output path')

    lm_client = LMClient(FLAGS.lm_client)
    with mlxu.open_file(FLAGS.input_file, 'r') as fin:
        input_data = json.load(fin)

    if FLAGS.eval_type == 'loglikelihood':
        prefix = input_data[FLAGS.prefix_field]
        text = input_data[FLAGS.text_field]
        loglikelihoods, is_greedys = lm_client.loglikelihood(prefix, text)
        output_data = {
            'loglikelihood': loglikelihoods,
            'is_greedy': is_greedys,
        }
    elif FLAGS.eval_type == 'loglikelihood_rolling':
        text = input_data[FLAGS.text_field]
        loglikelihoods, is_greedys = lm_client.loglikelihood_rolling(text)
        output_data = {
            'loglikelihood': loglikelihoods,
            'is_greedy': is_greedys,
        }
    elif FLAGS.eval_type == 'greedy_until':
        prefix = input_data[FLAGS.prefix_field]
        until = input_data[FLAGS.until_field]
        output_data = {'output_text': lm_client.greedy_until(prefix, until)}
    else:
        # Only 'generate' can reach here: eval_type was validated above.
        prefix = input_data[FLAGS.prefix_field]
        output_data = {'output_text': lm_client.generate(prefix)}

    with mlxu.open_file(FLAGS.output_file, 'w') as fout:
        json.dump(output_data, fout)
|
|
|
|
|
if __name__ == "__main__":
    # Entry point when run as a script. mlxu.run presumably parses the flags
    # declared via define_flags_with_default before invoking main(argv);
    # confirm against mlxu.
    mlxu.run(main)
|
|