"""Script to generate text from a trained model using HuggingFace wrappers."""

import argparse
import builtins as __builtin__
import json
import os
import sys

import torch
from composer.utils import dist, get_device
from transformers import GPTNeoXTokenizerFast, LlamaTokenizerFast

# Make the repository importable when the script is run from the repository root.
sys.path.append(os.getcwd())

from open_lm.utils.transformers.hf_model import OpenLMforCausalLM
from open_lm.utils.transformers.hf_config import OpenLMConfig
from open_lm.utils.llm_foundry_wrapper import SimpleComposerOpenLMCausalLM
from open_lm.model import create_params
from open_lm.params import add_model_args

builtin_print = __builtin__.print


@torch.inference_mode()
def run_model(open_lm: OpenLMforCausalLM, tokenizer, args):
    # Set up the distributed context used by Composer.
    dist.initialize_dist(get_device(None), timeout=600)

    # The prompt is passed as a JSON string with "instruction" and "input" fields.
    input_text = json.loads(args.input_text)
    inputs = tokenizer(input_text["instruction"] + input_text["input"])
    inputs = {k: torch.tensor(v).unsqueeze(0).cuda() for k, v in inputs.items()}

    composer_model = SimpleComposerOpenLMCausalLM(open_lm, tokenizer)
    composer_model = composer_model.cuda()

    generate_args = {
        "do_sample": args.temperature > 0,
        "pad_token_id": 50282,
        "max_new_tokens": args.max_gen_len,
        "use_cache": args.use_cache,
        "num_beams": args.num_beams,
    }
    # Sampling parameters only apply when sampling is enabled (temperature > 0);
    # otherwise decoding is deterministic.
    if args.temperature > 0:
        generate_args["temperature"] = args.temperature
        generate_args["top_p"] = args.top_p

    output = composer_model.generate(
        inputs["input_ids"],
        **generate_args,
        eos_token_id=[0],
    )
    # Decode only the newly generated tokens, dropping the prompt and the last token (typically EOS).
    output = tokenizer.decode(output[0][len(inputs["input_ids"][0]) : -1].cpu().numpy())
    print("-" * 50)
    print("\t\t Model output:")
    print("-" * 50)
    print(output)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", help="Path to a trained checkpoint to load (optional).")
    parser.add_argument("--model", type=str, default="open_lm_1b", help="Name of the model to use")
    parser.add_argument("--input-text", required=True, help='JSON string with "instruction" and "input" fields.')
    parser.add_argument("--max-gen-len", default=200, type=int)
    parser.add_argument("--temperature", default=0.8, type=float)
    parser.add_argument("--top-p", default=0.95, type=float)
    parser.add_argument("--use-cache", default=False, action="store_true")
    parser.add_argument("--tokenizer", default="EleutherAI/gpt-neox-20b", type=str)
    parser.add_argument("--num-beams", default=1, type=int)

    add_model_args(parser)
    args = parser.parse_args()

    print("Loading model into the right classes...")
    open_lm = OpenLMforCausalLM(OpenLMConfig(create_params(args)))

    if "gpt-neox-20b" in args.tokenizer:
        tokenizer = GPTNeoXTokenizerFast.from_pretrained("EleutherAI/gpt-neox-20b")
    elif "llama" in args.tokenizer:
        tokenizer = LlamaTokenizerFast.from_pretrained(args.tokenizer)
    else:
        raise ValueError(f"Unknown tokenizer {args.tokenizer}")

    if args.checkpoint is not None:
        print("Loading checkpoint from disk...")
        checkpoint = torch.load(args.checkpoint)
        state_dict = checkpoint["state_dict"]
        # Strip the "module." prefix that distributed training wrappers add to parameter names.
        state_dict = {x.replace("module.", ""): y for x, y in state_dict.items()}
        open_lm.model.load_state_dict(state_dict)
    open_lm.model.eval()

    run_model(open_lm, tokenizer, args)


if __name__ == "__main__":
    main()