"""Script to generate text from a trained model using HuggingFace wrappers."""

import argparse
import json
import multiprocessing as mp
import sys

import pandas as pd
import torch

# Make the local open_lm checkout importable.
sys.path.append("/home/ubuntu/model_sft/zr/open_lm")
from composer.utils import dist, get_device
from open_lm.utils.transformers.hf_model import OpenLMforCausalLM
from open_lm.utils.transformers.hf_config import OpenLMConfig
from open_lm.utils.llm_foundry_wrapper import SimpleComposerOpenLMCausalLM
from open_lm.model import create_params
from open_lm.params import add_model_args
from transformers import GPTNeoXTokenizerFast, LlamaTokenizerFast

def load_model_and_tokenizer(args):
    """
    加载模型和分词器
    """
    open_lm = OpenLMforCausalLM(OpenLMConfig(create_params(args)))

    if "gpt-neox-20b" in args.tokenizer:
        tokenizer = GPTNeoXTokenizerFast.from_pretrained("EleutherAI/gpt-neox-20b")
    elif "llama" in args.tokenizer:
        tokenizer = LlamaTokenizerFast.from_pretrained(args.tokenizer)
    else:
        raise ValueError(f"Unknown tokenizer {args.tokenizer}")

    if args.checkpoint is not None:
        print("Loading checkpoint from disk...")
        # Load onto CPU so the parent process does not initialize CUDA before
        # the worker processes are started.
        checkpoint = torch.load(args.checkpoint, map_location="cpu")
        state_dict = checkpoint["state_dict"]
        # Strip the "module." prefix left by distributed training wrappers.
        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
        open_lm.model.load_state_dict(state_dict)

    open_lm.model.eval()

    return open_lm, tokenizer

# Per-process cache so each pool worker initializes distributed communication
# and builds the Composer wrapper once, instead of once per input.
_worker_state = {}

@torch.inference_mode()
def run_model(open_lm: OpenLMforCausalLM, tokenizer, input_text, args):
    if "composer_model" not in _worker_state:
        dist.initialize_dist(get_device(None), timeout=600)
        _worker_state["composer_model"] = SimpleComposerOpenLMCausalLM(open_lm, tokenizer).cuda()
    composer_model = _worker_state["composer_model"]

    record = json.loads(input_text)
    encoded = tokenizer(record["instruction"] + " " + record["input"])
    encoded = {k: torch.tensor(v).unsqueeze(0).cuda() for k, v in encoded.items()}

    generate_args = {
        "do_sample": args.temperature > 0,  # sample only when temperature > 0, else greedy
        "pad_token_id": 50282,  # pad token id used by this model's vocabulary
        "max_new_tokens": args.max_gen_len,
        "use_cache": args.use_cache,
        "num_beams": args.num_beams,
    }

    # Sampling hyperparameters are only valid when do_sample is True.
    if args.temperature > 0:
        generate_args["temperature"] = args.temperature
        generate_args["top_p"] = args.top_p

    output = composer_model.generate(
        encoded["input_ids"],
        **generate_args,
        eos_token_id=[0],
    )
    # Decode only the newly generated tokens, skipping the prompt.
    prompt_len = len(encoded["input_ids"][0])
    output_text = tokenizer.decode(output[0][prompt_len:].cpu().numpy())
    return {
        "instruction": record["instruction"],
        "input": record["input"],
        "output": output_text,
    }

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint")
    parser.add_argument("--model", type=str, default="open_lm_1b", help="Name of the model to use")
    parser.add_argument("--input-file", required=True, help="Input JSONL file path")
    parser.add_argument("--output-file", default="output.xlsx", help="Output Excel file path")
    parser.add_argument("--max-gen-len", default=200, type=int)
    parser.add_argument("--temperature", default=0.0, type=float)
    parser.add_argument("--top-p", default=0.95, type=float)
    parser.add_argument("--use-cache", default=False, action="store_true")
    parser.add_argument("--tokenizer", default="EleutherAI/gpt-neox-20b", type=str)
    parser.add_argument("--num-beams", default=1, type=int)
    parser.add_argument("--num-workers", default=4, type=int)
    add_model_args(parser)
    args = parser.parse_args()

    open_lm, tokenizer = load_model_and_tokenizer(args)

    with open(args.input_file, "r") as f:
        input_texts = [line.strip() for line in f if line.strip()]

    # Use the "spawn" start method so each worker gets a clean CUDA context;
    # CUDA cannot be re-initialized in a forked child once the parent process
    # has touched it. Each worker loads its own copy of the model onto the GPU.
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=args.num_workers) as pool:
        results = pool.starmap(run_model, [(open_lm, tokenizer, text, args) for text in input_texts])

    df = pd.DataFrame(results)
    df.to_excel(args.output_file, index=False)
    print(f"Results saved to {args.output_file}")

if __name__ == "__main__":
    main()