File size: 4,141 Bytes
af6e330 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
"""Script to generate text from a trained model using HuggingFace wrappers."""
import argparse
import json
import builtins as __builtin__
import torch
import os
import pandas as pd
import multiprocessing as mp
import sys
sys.path.append("/home/ubuntu/model_sft/zr/open_lm")
from composer.utils import dist, get_device
from open_lm.utils.transformers.hf_model import OpenLMforCausalLM
from open_lm.utils.transformers.hf_config import OpenLMConfig
from open_lm.utils.llm_foundry_wrapper import SimpleComposerOpenLMCausalLM
from open_lm.model import create_params
from open_lm.params import add_model_args
from transformers import GPTNeoXTokenizerFast, LlamaTokenizerFast
builtin_print = __builtin__.print
def load_model_and_tokenizer(args):
"""
加载模型和分词器
"""
open_lm = OpenLMforCausalLM(OpenLMConfig(create_params(args)))
if "gpt-neox-20b" in args.tokenizer:
tokenizer = GPTNeoXTokenizerFast.from_pretrained("EleutherAI/gpt-neox-20b")
elif "llama" in args.tokenizer:
tokenizer = LlamaTokenizerFast.from_pretrained(args.tokenizer)
else:
raise ValueError(f"Unknown tokenizer {args.tokenizer}")
if args.checkpoint is not None:
print("Loading checkpoint from disk...")
checkpoint = torch.load(args.checkpoint)
state_dict = checkpoint["state_dict"]
state_dict = {x.replace("module.", ""): y for x, y in state_dict.items()}
open_lm.model.load_state_dict(state_dict)
open_lm.model.eval()
return open_lm, tokenizer
@torch.inference_mode()
def run_model(open_lm: OpenLMforCausalLM, tokenizer, input_text, args):
    """Generate a completion for a single JSONL record.

    Args:
        open_lm: Model wrapper (as returned by ``load_model_and_tokenizer``).
        tokenizer: The matching HuggingFace tokenizer.
        input_text: One JSON object, i.e. a single JSONL line. It must have an
            ``"instruction"`` key; ``"input"`` is optional and treated as
            ``""`` when absent.
        args: Parsed CLI namespace. Reads ``temperature``, ``top_p``,
            ``max_gen_len``, ``use_cache`` and ``num_beams``.

    Returns:
        dict with keys ``"instruction"``, ``"input"`` and ``"output"``. The
        output is the decoded generated text with the prompt tokens removed.
    """
    # NOTE(review): called once per record; this relies on initialize_dist
    # tolerating repeated calls in the same process — confirm.
    dist.initialize_dist(get_device(None), timeout=600)

    record = json.loads(input_text)
    instruction = record["instruction"]
    extra_input = record.get("input", "")

    # Prompt format unchanged: instruction and input joined by a single space.
    encoded = tokenizer(instruction + " " + extra_input)
    # Add a batch dimension of 1 and move every tokenizer field to the GPU.
    encoded = {name: torch.tensor(ids).unsqueeze(0).cuda() for name, ids in encoded.items()}

    composer_model = SimpleComposerOpenLMCausalLM(open_lm, tokenizer).cuda()

    # Stop on the tokenizer's own EOS id. For GPT-NeoX this is 0, the value that
    # used to be hard-coded. For LLaMA-1/2 vocabularies id 0 is <unk>, so a
    # hard-coded 0 effectively never stopped generation at end-of-sequence.
    eos_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 0

    generate_args = {
        "do_sample": args.temperature > 0,
        # NOTE(review): value kept from the original; looks specific to the
        # GPT-NeoX/open_lm vocabulary — confirm for other tokenizers.
        "pad_token_id": 50282,
        "max_new_tokens": args.max_gen_len,
        "use_cache": args.use_cache,
        "num_beams": args.num_beams,
    }
    # Sampling parameters are passed only when sampling is enabled.
    if args.temperature > 0:
        generate_args["temperature"] = args.temperature
        generate_args["top_p"] = args.top_p

    output = composer_model.generate(
        encoded["input_ids"],
        **generate_args,
        eos_token_id=[eos_token_id],
    )

    # Drop the echoed prompt tokens and decode only the newly generated part.
    prompt_len = encoded["input_ids"].shape[1]
    output_text = tokenizer.decode(output[0][prompt_len:].cpu().numpy())
    return {
        "instruction": instruction,
        "input": extra_input,
        "output": output_text,
    }
# Objects shared with pool workers. _init_worker fills this once per worker
# process, so the (large) model is pickled once per worker rather than once
# per input line.
_WORKER_STATE = {}


def _init_worker(open_lm, tokenizer, args):
    """Pool initializer: stash the model, tokenizer and CLI args in this worker."""
    _WORKER_STATE.update(open_lm=open_lm, tokenizer=tokenizer, args=args)


def _run_in_worker(input_text):
    """Run ``run_model`` on one input line using the objects stashed by ``_init_worker``."""
    return run_model(
        _WORKER_STATE["open_lm"],
        _WORKER_STATE["tokenizer"],
        input_text,
        _WORKER_STATE["args"],
    )


def _read_jsonl_lines(path):
    """Return the stripped, non-blank lines of a UTF-8 text file."""
    with open(path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]


def main():
    """Parse CLI args, generate one completion per JSONL record and save them to Excel."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint")
    parser.add_argument("--model", type=str, default="open_lm_1b", help="Name of the model to use")
    parser.add_argument("--input-file", required=True, help="Input JSONL file path")
    parser.add_argument("--output-file", default="output.xlsx", help="Output Excel file path")
    parser.add_argument("--max-gen-len", default=200, type=int)
    parser.add_argument("--temperature", default=0.0, type=float)
    parser.add_argument("--top-p", default=0.95, type=float)
    parser.add_argument("--use-cache", default=False, action="store_true")
    parser.add_argument("--tokenizer", default="EleutherAI/gpt-neox-20b", type=str)
    parser.add_argument("--num-beams", default=1, type=int)
    parser.add_argument("--num-workers", default=4, type=int)
    add_model_args(parser)
    args = parser.parse_args()

    open_lm, tokenizer = load_model_and_tokenizer(args)
    # Blank lines (e.g. trailing newlines) are skipped instead of crashing json.loads.
    input_texts = _read_jsonl_lines(args.input_file)

    # Never start more workers than there are records. With at most one worker
    # (including --num-workers 0, which Pool() rejects), run in-process and
    # avoid pickling the model altogether.
    num_workers = min(args.num_workers, len(input_texts))
    if num_workers <= 1:
        results = [run_model(open_lm, tokenizer, text, args) for text in input_texts]
    else:
        # "spawn" rather than the Linux-default fork: CUDA cannot be
        # re-initialised in a forked child once the parent has touched it.
        ctx = mp.get_context("spawn")
        with ctx.Pool(
            processes=num_workers,
            initializer=_init_worker,
            initargs=(open_lm, tokenizer, args),
        ) as pool:
            # map() preserves input order, so rows line up with the input file.
            results = pool.map(_run_in_worker, input_texts)

    df = pd.DataFrame(results)
    df.to_excel(args.output_file, index=False)
    print(f"Results saved to {args.output_file}")


# The guard is required so that spawned worker processes, which re-import this
# module, do not re-run main().
if __name__ == "__main__":
    main()
|