"""Command-line driver that sends one evaluation request to an LMClient.

Reads a JSON file, forwards the configured fields to the LMClient method
selected by ``--eval_type``, and writes the returned values as JSON.
"""
import json

import mlxu

from EasyLM.serving import LMClient


# Command-line flags: input/output paths, the names of the JSON fields to
# read, the evaluation mode, and the nested LMClient configuration.
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
    input_file='',
    output_file='',
    prefix_field='prefix',
    text_field='text',
    until_field='until',
    eval_type='loglikelihood',
    lm_client=LMClient.get_default_config(),
)


def main(argv):
    """Run the evaluation selected by ``FLAGS.eval_type`` and save the result.

    The input JSON is indexed by ``FLAGS.prefix_field``, ``FLAGS.text_field``
    and ``FLAGS.until_field`` as required by the chosen mode, so it is
    expected to be a JSON object containing those keys.
    NOTE(review): the values are presumably lists of strings -- confirm
    against the LMClient API.

    Supported modes and the output they produce:
      * ``loglikelihood``: ``{'loglikelihood': ..., 'is_greedy': ...}``
        from ``LMClient.loglikelihood(prefix, text)``.
      * ``loglikelihood_rolling``: same keys, from
        ``LMClient.loglikelihood_rolling(text)``.
      * ``greedy_until``: ``{'output_text': ...}`` from
        ``LMClient.greedy_until(prefix, until)``.
      * ``generate``: ``{'output_text': ...}`` from
        ``LMClient.generate(prefix)``.

    Args:
        argv: Positional command-line arguments passed by ``mlxu.run``;
            unused.

    Raises:
        ValueError: If ``FLAGS.eval_type`` is not one of the modes above.
    """
    client = LMClient(FLAGS.lm_client)

    with mlxu.open_file(FLAGS.input_file, 'r') as fin:
        records = json.load(fin)

    mode = FLAGS.eval_type
    if mode in ('loglikelihood', 'loglikelihood_rolling'):
        # Both scoring modes return a (loglikelihoods, is_greedys) pair and
        # share the same output layout; only the request differs.
        if mode == 'loglikelihood':
            scores, greedy_flags = client.loglikelihood(
                records[FLAGS.prefix_field],
                records[FLAGS.text_field],
            )
        else:
            scores, greedy_flags = client.loglikelihood_rolling(
                records[FLAGS.text_field]
            )
        results = {
            'loglikelihood': scores,
            'is_greedy': greedy_flags,
        }
    elif mode == 'greedy_until':
        completions = client.greedy_until(
            records[FLAGS.prefix_field],
            records[FLAGS.until_field],
        )
        results = {'output_text': completions}
    elif mode == 'generate':
        completions = client.generate(records[FLAGS.prefix_field])
        results = {'output_text': completions}
    else:
        raise ValueError(f'Unknown eval_type: {FLAGS.eval_type}')

    with mlxu.open_file(FLAGS.output_file, 'w') as fout:
        json.dump(results, fout)


if __name__ == "__main__":
    mlxu.run(main)