# Hosting-page residue captured with the file (not part of the script):
# kevinwang676's picture
# Upload folder using huggingface_hub
# 4721aa1
import argparse
from transformers import AutoConfig, AutoModel, AutoTokenizer
import torch
import os
# Key prefix under which the prefix-encoder weights are stored in the
# p-tuning checkpoint's state dict.
PREFIX_ENCODER_KEY = "transformer.prefix_encoder."


def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the interactive inference script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--pt-checkpoint", type=str, default=None, help="The checkpoint path")
    # --model is mandatory: without it from_pretrained() would be handed None.
    parser.add_argument("--model", type=str, required=True, help="main model weights")
    parser.add_argument("--tokenizer", type=str, default=None, help="tokenizer path (defaults to --model)")
    parser.add_argument("--pt-pre-seq-len", type=int, default=128, help="The pre-seq-len used in p-tuning")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--max-new-tokens", type=int, default=128)
    return parser


def parse_args(argv=None) -> argparse.Namespace:
    """Parse ``argv`` (``sys.argv[1:]`` when None) and fill in derived defaults.

    The tokenizer path falls back to the model path when ``--tokenizer`` is
    not given.
    """
    args = build_parser().parse_args(argv)
    if args.tokenizer is None:
        args.tokenizer = args.model
    return args


def extract_prefix_encoder_state(state_dict: dict) -> dict:
    """Return only the prefix-encoder entries of ``state_dict``.

    Keys that start with ``PREFIX_ENCODER_KEY`` are kept with that prefix
    stripped; every other entry is dropped.
    """
    return {
        key[len(PREFIX_ENCODER_KEY):]: value
        for key, value in state_dict.items()
        if key.startswith(PREFIX_ENCODER_KEY)
    }


def load_model_and_tokenizer(args: argparse.Namespace):
    """Load the tokenizer and model described by ``args``.

    When ``args.pt_checkpoint`` is set, the model config is created with
    ``pre_seq_len=args.pt_pre_seq_len`` and the prefix-encoder weights are
    restored from ``<pt_checkpoint>/pytorch_model.bin``.

    Returns:
        A ``(model, tokenizer)`` pair. The model is NOT moved to a device
        here; the caller decides where it runs.

    Raises:
        ValueError: if the p-tuning checkpoint holds no prefix-encoder weights.
    """
    # NOTE: trust_remote_code=True allows code shipped with the model
    # repository to run locally; only point this at sources you trust.
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True)

    if not args.pt_checkpoint:
        model = AutoModel.from_pretrained(args.model, trust_remote_code=True)
        return model, tokenizer

    config = AutoConfig.from_pretrained(args.model, trust_remote_code=True, pre_seq_len=args.pt_pre_seq_len)
    # No unconditional .cuda() here: the target device is chosen by --device.
    model = AutoModel.from_pretrained(args.model, config=config, trust_remote_code=True)

    checkpoint_file = os.path.join(args.pt_checkpoint, "pytorch_model.bin")
    # map_location="cpu" lets a checkpoint saved from GPU load on any host.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    prefix_state = extract_prefix_encoder_state(torch.load(checkpoint_file, map_location="cpu"))
    if not prefix_state:
        raise ValueError(
            "no '%s*' weights found in %s" % (PREFIX_ENCODER_KEY, checkpoint_file)
        )
    model.transformer.prefix_encoder.load_state_dict(prefix_state)
    return model, tokenizer


def chat_loop(model, tokenizer, device, max_new_tokens):
    """Read prompts from stdin and print the model's continuation for each.

    The loop ends cleanly on end-of-input (Ctrl-D) or Ctrl-C at the prompt.
    """
    while True:
        try:
            prompt = input("Prompt:")
        except (EOFError, KeyboardInterrupt):
            print()
            break
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        prompt_len = inputs["input_ids"].shape[-1]
        # max_length counts the prompt too, so add the prompt length.
        output = model.generate(input_ids=inputs["input_ids"], max_length=prompt_len + max_new_tokens)
        # Drop the echoed prompt tokens; decode only what was generated.
        print("Response:", tokenizer.decode(output[0, prompt_len:], skip_special_tokens=True))


def main(argv=None):
    """Script entry point: parse arguments, load the model, run the prompt loop."""
    args = parse_args(argv)
    model, tokenizer = load_model_and_tokenizer(args)
    model = model.to(args.device)
    chat_loop(model, tokenizer, args.device, args.max_new_tokens)


if __name__ == "__main__":
    main()