## The following code is adapted from
## https://docs.mystic.ai/docs/llama-2-with-vllm-7b-13b-multi-gpu-70b


from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from arguments import get_args
from dataset_conv import get_chatqa2_input, preprocess
from tqdm import tqdm
import torch
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ['VLLM_NCCL_SO_PATH'] = '/usr/local/lib/python3.8/dist-packages/nvidia/nccl/lib/libnccl.so.2'

def get_prompt_list(args):

    ## get tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)

    ## load and preprocess the evaluation samples (optionally with retrieved contexts)
    data_list = preprocess(args.sample_input_file, inference_only=True, retrieved_neighbours=args.use_retrieved_neighbours)
    print("total number of samples in data_list:", len(data_list))
    if args.start_idx != -1 and args.end_idx != -1:
        print("getting data from %d to %d" % (args.start_idx, args.end_idx))
        data_list = data_list[args.start_idx:args.end_idx]
    
    print("number of test samples in the dataset:", len(data_list))
    prompt_list = get_chatqa2_input(data_list, args.eval_dataset, tokenizer, num_ctx=args.num_ctx, max_output_len=args.max_tokens, max_seq_length=args.max_seq_length)

    return prompt_list


def main():
    args = get_args()
    
    ## bos token for llama-3
    bos_token = "<|begin_of_text|>"
    ## get model_path
    model_path = args.model_folder
    
    ## get prompt_list
    prompt_list = get_prompt_list(args)

    ## create the output directory inside the model folder (no error if it already exists)
    output_path = os.path.join(model_path, "outputs")
    os.makedirs(output_path, exist_ok=True)
    
    ## get output_datapath
    if args.start_idx != -1 and args.end_idx != -1:
        if args.use_retrieved_neighbours:
            output_datapath = os.path.join(output_path, "%s_output_%dto%d_ctx%d.txt" % (args.eval_dataset, args.start_idx, args.end_idx, args.num_ctx))
        else:
            output_datapath = os.path.join(output_path, "%s_output_%dto%d.txt" % (args.eval_dataset, args.start_idx, args.end_idx))
    else:
        if args.use_retrieved_neighbours:
            output_datapath = os.path.join(output_path, "%s_output_ctx%d.txt" % (args.eval_dataset, args.num_ctx))
        else:
            output_datapath = os.path.join(output_path, "%s_output.txt" % (args.eval_dataset))

    ## run inference with greedy decoding
    sampling_params = SamplingParams(temperature=0, top_k=1, max_tokens=args.max_tokens)

    ## load the model with tensor parallelism across 8 GPUs
    model_vllm = LLM(model_path, tensor_parallel_size=8, dtype=torch.bfloat16)
    print(model_vllm)
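
    ## Note (sketch): vLLM's LLM.generate also accepts a list of prompts, e.g.
    ##   outputs = model_vllm.generate([bos_token + p for p in prompt_list], sampling_params)
    ## which lets the engine batch and schedule requests together. The per-prompt loop
    ## below trades throughput for simpler per-sample logging.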

    output_list = []
    for prompt in tqdm(prompt_list):
        prompt = bos_token + prompt
        output = model_vllm.generate([prompt], sampling_params)[0]
        generated_text = output.outputs[0].text
        generated_text = generated_text.strip().replace("\n", " ")

        ## strip llama-3 stop tokens if they leak into the generated text
        if "<|eot_id|>" in generated_text:
            idx = generated_text.index("<|eot_id|>")
            generated_text = generated_text[:idx]
        if "<|end_of_text|>" in generated_text:
            idx = generated_text.index("<|end_of_text|>")
            generated_text = generated_text[:idx]

        print("="*80)
        print("prompt:", prompt)
        print("-"*80)
        print("generated_text:", generated_text)
        print("="*80)
        output_list.append(generated_text)

    print("writing to %s" % output_datapath)
    with open(output_datapath, "w", encoding="utf-8") as f:
        for output in output_list:
            f.write(output + "\n")


if __name__ == "__main__":
    main()
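

## Example invocation (illustrative only: the actual flag names are defined in arguments.get_args
## and may differ from the hypothetical ones shown here):
##
##   python <this_script>.py \
##       --model-folder /path/to/llama3-chatqa-2-70b \
##       --tokenizer-path /path/to/llama3-chatqa-2-70b \
##       --sample-input-file /path/to/eval_samples.json \
##       --eval-dataset <dataset_name> \
##       --num-ctx 5 --max-tokens 128 --max-seq-length 8192 \
##       --start-idx 0 --end-idx 500 --use-retrieved-neighbours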