|
import random |
|
|
|
from os import path |
|
from argparse import ArgumentParser |
|
|
|
import torch |
|
|
|
from torch.cuda import is_available as cuda_is_available |
|
|
|
from model import LightGPT, LightGPTInstruct |
|
from data import SmolTalk |
|
|
|
import tiktoken |
|
|
|
|
|
def main():
    """Run an interactive beam-search text generation session.

    Parses the command-line options, loads the base model checkpoint and,
    when ``--lora_path`` is given, a LoRA checkpoint on top of it. It then
    repeatedly reads a prompt from stdin and prints every candidate sequence
    returned by ``model.beam_search`` until the user declines to continue.

    Raises:
        RuntimeError: If a CUDA device is requested but CUDA is unavailable.
    """
    parser = ArgumentParser(
        # The script calls ``model.beam_search``, so describe it as such.
        description="Use a beam search strategy to generate candidate sequences.",
    )

    parser.add_argument(
        "--checkpoint_path", default="./checkpoints/checkpoint.pt", type=str
    )
    parser.add_argument("--lora_path", default=None, type=str)
    parser.add_argument("--max_tokens", default=100, type=int)
    parser.add_argument("--context_length", default=1024, type=int)
    parser.add_argument("--num_candidates", default=3, type=int)
    parser.add_argument("--beam_width", default=16, type=int)
    parser.add_argument("--device", default="cuda", type=str)
    parser.add_argument("--seed", default=None, type=int)

    args = parser.parse_args()

    if "cuda" in args.device and not cuda_is_available():
        raise RuntimeError("Cuda is not available.")

    torch.set_float32_matmul_precision("high")

    # Compare against None explicitly: 0 is a valid seed but is falsy, so a
    # plain truthiness check would silently ignore ``--seed 0``.
    if args.seed is not None:
        torch.manual_seed(args.seed)
        random.seed(args.seed)

    checkpoint = torch.load(
        args.checkpoint_path, map_location=args.device, weights_only=True
    )

    tokenizer = tiktoken.get_encoding(checkpoint["token_encoding"])

    model = LightGPT(**checkpoint["model_args"])

    # NOTE(review): the model is compiled *before* the state dict is loaded,
    # presumably because the checkpoint keys were saved from a compiled
    # module -- confirm before reordering these two statements.
    model = torch.compile(model)

    model.load_state_dict(checkpoint["model"])

    print("Model checkpoint loaded")

    if args.lora_path:
        lora_checkpoint = torch.load(
            args.lora_path, map_location=args.device, weights_only=True
        )

        model = LightGPTInstruct(model, **lora_checkpoint["lora_args"])

        model = torch.compile(model)

        # strict=False: presumably the LoRA checkpoint only holds the adapter
        # weights, not the full base model -- TODO confirm.
        model.load_state_dict(lora_checkpoint["lora"], strict=False)

        model.merge_lora_parameters()

        print("LoRA checkpoint loaded")

    model.to(args.device)

    model.eval()

    while True:
        prompt = input("Enter a prompt: ")

        # Only the instruction-tuned (LoRA) model gets the chat template.
        if args.lora_path:
            prompt = SmolTalk.PROMPT_TEMPLATE.format(role="user", message=prompt)

        prompt = tokenizer.encode_ordinary(prompt)

        prompt = torch.tensor(prompt, dtype=torch.int64, device=args.device)

        candidates = model.beam_search(
            prompt,
            args.max_tokens,
            args.context_length,
            args.num_candidates,
            args.beam_width,
        )

        for i, candidate in enumerate(candidates, start=1):
            print(f"Sequence #{i}")

            out = tokenizer.decode(candidate.tokens.tolist()).strip()

            print(out, end="\n\n")

        print("\n")

        # Any answer containing the letter "y" continues the session.
        if "y" not in input("Go again? (yes|no): ").lower():
            break
|
|
|
|
|
# Script entry point: start the interactive session only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|