## Llama3-ChatQA-2-70B / code / evaluate_cqa_vllm_chatqa2.py
## The following code is adapted from
## https://docs.mystic.ai/docs/llama-2-with-vllm-7b-13b-multi-gpu-70b
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from arguments import get_args
from dataset_conv import get_chatqa2_input, preprocess
from tqdm import tqdm
import torch
import os
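
## Disable HuggingFace tokenizer parallelism to avoid fork warnings, and point vLLM at the NCCL shared library.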
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ['VLLM_NCCL_SO_PATH'] = '/usr/local/lib/python3.8/dist-packages/nvidia/nccl/lib/libnccl.so.2'


def get_prompt_list(args):
    ## get tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)

    ## load and preprocess the evaluation samples (inference only, optionally with retrieved contexts)
    data_list = preprocess(args.sample_input_file, inference_only=True, retrieved_neighbours=args.use_retrieved_neighbours)
    print("number of total data_list:", len(data_list))

    ## optionally evaluate only the [start_idx, end_idx) slice of the dataset
    if args.start_idx != -1 and args.end_idx != -1:
        print("getting data from %d to %d" % (args.start_idx, args.end_idx))
        data_list = data_list[args.start_idx:args.end_idx]
    print("number of test samples in the dataset:", len(data_list))

    ## build ChatQA-2 formatted prompts for the chosen evaluation dataset
    prompt_list = get_chatqa2_input(data_list, args.eval_dataset, tokenizer, num_ctx=args.num_ctx, max_output_len=args.max_tokens, max_seq_length=args.max_seq_length)

    return prompt_list


def main():
    args = get_args()

    ## bos token for llama-3
    bos_token = "<|begin_of_text|>"

    ## get model_path
    model_path = args.model_folder

    ## get prompt_list
    prompt_list = get_prompt_list(args)

    ## create the output directory next to the model folder
    output_path = os.path.join(model_path, "outputs")
    if not os.path.exists(output_path):
        os.mkdir(output_path)

    ## get output_datapath; the filename encodes the dataset, optional index range, and context count
    if args.start_idx != -1 and args.end_idx != -1:
        if args.use_retrieved_neighbours:
            output_datapath = os.path.join(output_path, "%s_output_%dto%d_ctx%d.txt" % (args.eval_dataset, args.start_idx, args.end_idx, args.num_ctx))
        else:
            output_datapath = os.path.join(output_path, "%s_output_%dto%d.txt" % (args.eval_dataset, args.start_idx, args.end_idx))
    else:
        if args.use_retrieved_neighbours:
            output_datapath = os.path.join(output_path, "%s_output_ctx%d.txt" % (args.eval_dataset, args.num_ctx))
        else:
            output_datapath = os.path.join(output_path, "%s_output.txt" % (args.eval_dataset))

    ## run inference with greedy decoding (temperature=0, top_k=1)
    sampling_params = SamplingParams(temperature=0, top_k=1, max_tokens=args.max_tokens)

    ## shard the model across 8 GPUs with tensor parallelism
    model_vllm = LLM(model_path, tensor_parallel_size=8, dtype=torch.bfloat16)
    print(model_vllm)
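
    ## Note: the loop below calls generate() once per prompt so each example can be printed
    ## as it completes; vLLM's generate() also accepts the full prompt list in a single call.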
    output_list = []
    for prompt in tqdm(prompt_list):
        ## llama-3 prompts must start with the bos token
        prompt = bos_token + prompt
        output = model_vllm.generate([prompt], sampling_params)[0]
        generated_text = output.outputs[0].text
        generated_text = generated_text.strip().replace("\n", " ")

        ## for llama3: truncate at the first stop token if one appears in the output
        if "<|eot_id|>" in generated_text:
            idx = generated_text.index("<|eot_id|>")
            generated_text = generated_text[:idx]
        if "<|end_of_text|>" in generated_text:
            idx = generated_text.index("<|end_of_text|>")
            generated_text = generated_text[:idx]

        print("="*80)
        print("prompt:", prompt)
        print("-"*80)
        print("generated_text:", generated_text)
        print("="*80)

        output_list.append(generated_text)
print("writing to %s" % output_datapath)
with open(output_datapath, "w", encoding="utf-8") as f:
for output in output_list:
f.write(output + "\n")


if __name__ == "__main__":
    main()
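
## Example launch (illustrative only: the exact flag spellings are defined in arguments.get_args(),
## so the names below are assumptions based on the attributes accessed above, and the values are placeholders):
##
##   python evaluate_cqa_vllm_chatqa2.py \
##       --model_folder /path/to/Llama3-ChatQA-2-70B \
##       --tokenizer_path /path/to/Llama3-ChatQA-2-70B \
##       --sample_input_file /path/to/eval_data.json \
##       --eval_dataset <dataset_name> \
##       --num_ctx 5 --max_tokens 256 --max_seq_length 128000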