"""Generate model answers for a file of benchmark questions, sharding across GPUs with Ray.

Usage:
python get_model_answer.py --model-path /path/to/model --model-id vicuna-7b \
    --question-file question.jsonl --answer-file answer.jsonl --num-gpus 1
"""
import argparse
import json
import os

import ray
import shortuuid
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

from fastchat.conversation import get_default_conv_template, compute_skip_echo_len
from fastchat.utils import disable_torch_init


def run_eval(model_path, model_id, question_file, answer_file, num_gpus):
    # Split the question file into num_gpus chunks and dispatch one Ray task per chunk.
    ques_jsons = []
    with open(os.path.expanduser(question_file), "r") as ques_file:
        for line in ques_file:
            ques_jsons.append(line)

    # Ceiling division, so a remainder never spills into an extra (num_gpus + 1)-th
    # task and chunk_size is never 0 when there are fewer questions than GPUs.
    chunk_size = -(-len(ques_jsons) // num_gpus)
    ans_handles = []
    for i in range(0, len(ques_jsons), chunk_size):
        ans_handles.append(
            get_model_answers.remote(
                model_path, model_id, ques_jsons[i : i + chunk_size]
            )
        )

    # Gather answers from all workers, then write them out as JSON Lines.
    ans_jsons = []
    for ans_handle in ans_handles:
        ans_jsons.extend(ray.get(ans_handle))

    with open(os.path.expanduser(answer_file), "w") as ans_file:
        for line in ans_jsons:
            ans_file.write(json.dumps(line) + "\n")


@ray.remote(num_gpus=1)
@torch.inference_mode()
def get_model_answers(model_path, model_id, question_jsons):
    disable_torch_init()
    model_path = os.path.expanduser(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.float16
    ).cuda()

    ans_jsons = []
    for line in tqdm(question_jsons):
        ques_json = json.loads(line)
        idx = ques_json["question_id"]
        qs = ques_json["text"]

        # Build the prompt with the model's default conversation template,
        # leaving the assistant's turn empty for generation.
        conv = get_default_conv_template(model_id).copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        inputs = tokenizer([prompt])
        output_ids = model.generate(
            torch.as_tensor(inputs.input_ids).cuda(),
            do_sample=True,
            temperature=0.7,
            max_new_tokens=1024,
        )
        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]

        # The decoded output echoes the prompt; strip it so only the answer remains.
        skip_echo_len = compute_skip_echo_len(model_id, conv, prompt)
        outputs = outputs[skip_echo_len:].strip()

        ans_id = shortuuid.uuid()
        ans_jsons.append(
            {
                "question_id": idx,
                "text": outputs,
                "answer_id": ans_id,
                "model_id": model_id,
                "metadata": {},
            }
        )
    return ans_jsons


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, required=True)
    parser.add_argument("--model-id", type=str, required=True)
    parser.add_argument("--question-file", type=str, required=True)
    parser.add_argument("--answer-file", type=str, default="answer.jsonl")
    parser.add_argument("--num-gpus", type=int, default=1)
    args = parser.parse_args()

    ray.init()
    run_eval(
        args.model_path,
        args.model_id,
        args.question_file,
        args.answer_file,
        args.num_gpus,
    )
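# A minimal sketch of the JSONL formats this script consumes and produces,
# inferred from the keys read and written above; the values are hypothetical.
#
# question.jsonl (input) -- one JSON object per line:
#   {"question_id": 1, "text": "How do I reverse a list in Python?"}
#
# answer.jsonl (output) -- one JSON object per line:
#   {"question_id": 1, "text": "...", "answer_id": "mKp3...", "model_id": "vicuna-7b", "metadata": {}}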