import argparse
import json
import os

import ray
import shortuuid
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

from fastchat.conversation import get_default_conv_template, compute_skip_echo_len
from fastchat.utils import disable_torch_init

def run_eval(model_path, model_id, question_file, answer_file, num_gpus):
    # Split the question file into num_gpus chunks and answer each chunk
    # in its own Ray task (one GPU per task).
    ques_jsons = []
    with open(os.path.expanduser(question_file), "r") as ques_file:
        for line in ques_file:
            ques_jsons.append(line)

    chunk_size = len(ques_jsons) // num_gpus
    ans_handles = []
    for i in range(0, len(ques_jsons), chunk_size):
        ans_handles.append(
            get_model_answers.remote(
                model_path, model_id, ques_jsons[i : i + chunk_size]
            )
        )

    # Gather the per-chunk answers and write them out as JSON Lines.
    ans_jsons = []
    for ans_handle in ans_handles:
        ans_jsons.extend(ray.get(ans_handle))

    with open(os.path.expanduser(answer_file), "w") as ans_file:
        for line in ans_jsons:
            ans_file.write(json.dumps(line) + "\n")
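
# Expected file formats, inferred from the fields read and written above
# (the concrete values are illustrative, not taken from the source):
#
#   question file (JSONL), one question per line:
#     {"question_id": 1, "text": "What is the capital of France?"}
#
#   answer file (JSONL), one answer per line:
#     {"question_id": 1, "text": "...", "answer_id": "<shortuuid>",
#      "model_id": "vicuna-7b", "metadata": {}}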

@ray.remote(num_gpus=1)
def get_model_answers(model_path, model_id, question_jsons):
    disable_torch_init()
    model_path = os.path.expanduser(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.float16
    ).cuda()

    ans_jsons = []
    for line in tqdm(question_jsons):
        ques_json = json.loads(line)
        idx = ques_json["question_id"]
        qs = ques_json["text"]

        # Wrap the raw question in the model's conversation template.
        conv = get_default_conv_template(model_id).copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        inputs = tokenizer([prompt])
        output_ids = model.generate(
            torch.as_tensor(inputs.input_ids).cuda(),
            do_sample=True,
            temperature=0.7,
            max_new_tokens=1024,
        )
        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]

        # The decoded output echoes the prompt; strip the echoed prefix so
        # only the model's answer remains.
        skip_echo_len = compute_skip_echo_len(model_id, conv, prompt)
        outputs = outputs[skip_echo_len:].strip()

        ans_id = shortuuid.uuid()
        ans_jsons.append(
            {
                "question_id": idx,
                "text": outputs,
                "answer_id": ans_id,
                "model_id": model_id,
                "metadata": {},
            }
        )
    return ans_jsons

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, required=True)
    parser.add_argument("--model-id", type=str, required=True)
    parser.add_argument("--question-file", type=str, required=True)
    parser.add_argument("--answer-file", type=str, default="answer.jsonl")
    parser.add_argument("--num-gpus", type=int, default=1)
    args = parser.parse_args()

    ray.init()
    run_eval(
        args.model_path,
        args.model_id,
        args.question_file,
        args.answer_file,
        args.num_gpus,
    )
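
# Example invocation (the script name and paths are illustrative assumptions,
# not fixed names from the source):
#
#   python model_qa.py \
#       --model-path ~/models/vicuna-7b \
#       --model-id vicuna-7b \
#       --question-file questions.jsonl \
#       --answer-file answers.jsonl \
#       --num-gpus 2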