| """Script to generate text from a trained model using HuggingFace wrappers.""" |
|
|
| import argparse |
| import json |
| import builtins as __builtin__ |
| import torch |
| import os |
| import pandas as pd |
| import multiprocessing as mp |
|
|
| import sys |
| sys.path.append("/home/ubuntu/model_sft/zr/open_lm") |
| from composer.utils import dist, get_device |
| from open_lm.utils.transformers.hf_model import OpenLMforCausalLM |
| from open_lm.utils.transformers.hf_config import OpenLMConfig |
| from open_lm.utils.llm_foundry_wrapper import SimpleComposerOpenLMCausalLM |
| from open_lm.model import create_params |
| from open_lm.params import add_model_args |
| from transformers import GPTNeoXTokenizerFast, LlamaTokenizerFast |
|
|
# Keep a reference to the built-in print function.
# NOTE(review): nothing in the visible code uses this alias; presumably it is
# a leftover from a rank-aware print override -- confirm before removing.
builtin_print = __builtin__.print
|
|
def load_model_and_tokenizer(args):
    """Load the model and the tokenizer.

    Args:
        args: Parsed command-line arguments. ``args.tokenizer`` and
            ``args.checkpoint`` are read here; the whole namespace is also
            handed to ``create_params`` to build the model configuration.

    Returns:
        A ``(open_lm, tokenizer)`` tuple. ``open_lm.model`` has been switched
        to eval mode and, when ``args.checkpoint`` is set, carries the
        checkpoint weights.

    Raises:
        ValueError: If ``args.tokenizer`` contains neither ``"gpt-neox-20b"``
            nor ``"llama"``.
    """
    open_lm = OpenLMforCausalLM(OpenLMConfig(create_params(args)))

    if "gpt-neox-20b" in args.tokenizer:
        tokenizer = GPTNeoXTokenizerFast.from_pretrained("EleutherAI/gpt-neox-20b")
    elif "llama" in args.tokenizer:
        tokenizer = LlamaTokenizerFast.from_pretrained(args.tokenizer)
    else:
        raise ValueError(f"Unknown tokenizer {args.tokenizer}")

    if args.checkpoint is not None:
        print("Loading checkpoint from disk...")
        # SECURITY: torch.load unpickles the file, so only point --checkpoint
        # at checkpoints from a trusted source.
        # map_location="cpu" makes loading independent of the device the
        # checkpoint was saved from; run_model moves the model to the GPU.
        checkpoint = torch.load(args.checkpoint, map_location="cpu")

        # Strip only a *leading* "module." (presumably added by a
        # DistributedDataParallel-style wrapper at training time). A plain
        # str.replace would also mangle keys that merely contain "module."
        # somewhere in the middle.
        prefix = "module."
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in checkpoint["state_dict"].items()
        }
        open_lm.model.load_state_dict(state_dict)

    open_lm.model.eval()

    return open_lm, tokenizer
|
|
@torch.inference_mode()
def run_model(open_lm: OpenLMforCausalLM, tokenizer, input_text, args):
    """Generate a completion for a single JSON-encoded prompt record.

    Args:
        open_lm: The OpenLM HuggingFace wrapper to generate with.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        input_text: JSON object string with ``instruction`` and ``input``
            keys; the prompt is the two values joined by a single space.
        args: Parsed arguments; ``temperature``, ``top_p``, ``max_gen_len``,
            ``use_cache`` and ``num_beams`` are read.

    Returns:
        A dict with the original ``instruction`` and ``input`` plus the
        decoded continuation under ``output`` (prompt tokens are sliced off).
    """
    # Runs on every call, because each call may land in a different worker.
    dist.initialize_dist(get_device(None), timeout=600)

    record = json.loads(input_text)
    encoded = tokenizer(record['instruction'] + " " + record['input'])
    # Add a batch dimension of 1 to every tokenizer field and move it to GPU.
    batch = {
        name: torch.tensor(values).unsqueeze(0).cuda()
        for name, values in encoded.items()
    }
    prompt_ids = batch["input_ids"]

    composer_model = SimpleComposerOpenLMCausalLM(open_lm, tokenizer).cuda()

    # Sampling is only enabled for a strictly positive temperature.
    sampling = args.temperature > 0
    # NOTE(review): pad id 50282 and EOS id 0 are hard-coded; presumably they
    # match the GPT-NeoX vocabulary -- confirm before using a llama tokenizer.
    generate_args = dict(
        do_sample=sampling,
        pad_token_id=50282,
        max_new_tokens=args.max_gen_len,
        use_cache=args.use_cache,
        num_beams=args.num_beams,
    )
    if sampling:
        generate_args.update(temperature=args.temperature, top_p=args.top_p)

    output = composer_model.generate(prompt_ids, **generate_args, eos_token_id=[0])

    # Drop the echoed prompt so only newly generated tokens are decoded.
    prompt_len = prompt_ids.shape[1]
    completion = tokenizer.decode(output[0][prompt_len:].cpu().numpy())

    return {
        "instruction": record['instruction'],
        "input": record['input'],
        "output": completion,
    }
|
|
def main():
    """Generate completions for every record of a JSONL file and save them.

    Parses the command line, loads the model and tokenizer once, runs
    ``run_model`` over each non-blank input line in a process pool, and
    writes the collected result dicts to an Excel file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint")
    parser.add_argument("--model", type=str, default="open_lm_1b", help="Name of the model to use")
    parser.add_argument("--input-file", required=True, help="Input JSONL file path")
    parser.add_argument("--output-file", default="output.xlsx", help="Output Excel file path")
    parser.add_argument("--max-gen-len", default=200, type=int)
    parser.add_argument("--temperature", default=0.0, type=float)
    parser.add_argument("--top-p", default=0.95, type=float)
    parser.add_argument("--use-cache", default=False, action="store_true")
    parser.add_argument("--tokenizer", default="EleutherAI/gpt-neox-20b", type=str)
    parser.add_argument("--num-beams", default=1, type=int)
    parser.add_argument("--num-workers", default=4, type=int)
    add_model_args(parser)
    args = parser.parse_args()

    open_lm, tokenizer = load_model_and_tokenizer(args)

    # Read as UTF-8 explicitly so non-ASCII prompts do not depend on the
    # platform's default encoding, and skip blank lines, which would
    # otherwise make json.loads fail inside a worker.
    with open(args.input_file, "r", encoding="utf-8") as f:
        input_texts = [
            stripped for stripped in (line.strip() for line in f) if stripped
        ]

    tasks = [(open_lm, tokenizer, input_text, args) for input_text in input_texts]

    # NOTE(review): run_model calls .cuda() inside each worker. PyTorch
    # requires the "spawn"/"forkserver" start method when the parent process
    # has already initialised CUDA -- confirm the default start method is
    # safe for this deployment.
    with mp.Pool(processes=args.num_workers) as pool:
        results = pool.starmap(run_model, tasks)

    df = pd.DataFrame(results)
    df.to_excel(args.output_file, index=False)
    print(f"Results saved to {args.output_file}")


if __name__ == "__main__":
    main()
|
|