# Spaces:
# Paused
# Paused
import argparse
import os

import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer
# Key prefix under which the prefix-encoder weights are stored in the
# checkpoint's state dict.
PREFIX_ENCODER_KEY = "transformer.prefix_encoder."


def parse_args(argv=None):
    """Parse the command-line options for the interactive inference script.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``
            (argparse's default behaviour).

    Returns:
        The parsed ``argparse.Namespace``. ``tokenizer`` falls back to the
        value of ``model`` when it was not given explicitly.

    Exits with a usage error when ``--model`` is missing, instead of failing
    later inside ``from_pretrained`` with a less helpful message.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--pt-checkpoint", type=str, default=None, help="The checkpoint path")
    parser.add_argument("--model", type=str, default=None, help="main model weights")
    parser.add_argument(
        "--tokenizer",
        type=str,
        default=None,
        help="tokenizer path or name (defaults to --model)",
    )
    parser.add_argument("--pt-pre-seq-len", type=int, default=128, help="The pre-seq-len used in p-tuning")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--max-new-tokens", type=int, default=128)
    args = parser.parse_args(argv)

    if args.model is None:
        parser.error("--model is required")
    if args.tokenizer is None:
        args.tokenizer = args.model
    return args


def extract_prefix_weights(state_dict):
    """Return the prefix-encoder entries of *state_dict* with the prefix removed.

    Keys that do not start with ``PREFIX_ENCODER_KEY`` are dropped; the
    remaining keys are shortened so they match the prefix encoder's own
    parameter names.
    """
    return {
        key[len(PREFIX_ENCODER_KEY):]: value
        for key, value in state_dict.items()
        if key.startswith(PREFIX_ENCODER_KEY)
    }


def load_model(args):
    """Load the model selected by *args* and move it to ``args.device``.

    When ``args.pt_checkpoint`` is set, the model is built with
    ``pre_seq_len=args.pt_pre_seq_len`` and the prefix-encoder weights are
    read from ``<pt_checkpoint>/pytorch_model.bin``.

    NOTE(security): ``trust_remote_code=True`` runs code shipped with the
    model repository, and ``torch.load`` unpickles the checkpoint. Only use
    models and checkpoints from sources you trust.
    """
    if not args.pt_checkpoint:
        model = AutoModel.from_pretrained(args.model, trust_remote_code=True)
        return model.to(args.device)

    config = AutoConfig.from_pretrained(
        args.model, trust_remote_code=True, pre_seq_len=args.pt_pre_seq_len
    )
    # No hard-coded .cuda() here: the device is chosen once, below, from
    # --device, so the script also works on hosts without CUDA.
    model = AutoModel.from_pretrained(args.model, config=config, trust_remote_code=True)

    checkpoint_file = os.path.join(args.pt_checkpoint, "pytorch_model.bin")
    # map_location="cpu" lets a checkpoint saved on GPU be read on any host.
    checkpoint_state = torch.load(checkpoint_file, map_location="cpu")
    model.transformer.prefix_encoder.load_state_dict(extract_prefix_weights(checkpoint_state))
    return model.to(args.device)


def chat_loop(model, tokenizer, device, max_new_tokens):
    """Read prompts from stdin and print the generated continuation for each.

    The loop ends cleanly on end-of-input (Ctrl-D) or Ctrl-C at the prompt.
    Blank prompts are skipped.
    """
    while True:
        try:
            prompt = input("Prompt:")
        except (EOFError, KeyboardInterrupt):
            print()
            break
        if not prompt.strip():
            continue

        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        prompt_len = inputs["input_ids"].shape[-1]
        output = model.generate(
            input_ids=inputs["input_ids"],
            max_length=prompt_len + max_new_tokens,
        )
        # Slice off the prompt tokens so that only the newly generated text
        # is decoded and printed.
        print("Response:", tokenizer.decode(output[0, prompt_len:], skip_special_tokens=True))


def main(argv=None):
    """Script entry point: parse options, load tokenizer and model, then chat."""
    args = parse_args(argv)
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True)
    model = load_model(args)
    chat_loop(model, tokenizer, args.device, args.max_new_tokens)


if __name__ == "__main__":
    main()