"""
Sample from the trained model with PyTorch.

The script loads a checkpoint, rebuilds the Transformer from the model
arguments stored in it, picks a tokenizer, encodes the prompt and prints
`num_samples` generated continuations.

Every module-level setting in the configuration section below can be
overridden by `configurator.py` (command line or config file).
"""
import os
import pickle
from contextlib import nullcontext
import torch
from model import ModelArgs, Transformer
from tokenizer import Tokenizer

from tinystories import get_tokenizer_model_path

# -----------------------------------------------------------------------------
checkpoint = 'out/ckpt.pt'
start = ""  # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
num_samples = 1  # number of samples to draw
max_new_tokens = 100  # number of tokens generated in each sample
temperature = 1.0  # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
top_k = 300  # retain only the top_k most likely tokens, clamp others to have 0 probability
tokenizer = ""  # override the tokenizer model path
seed = 1337
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
# dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'  # 'float32' or 'bfloat16' or 'float16'
dtype = "float32"
compile = False  # use PyTorch 2.0 to compile the model to be faster
# Overrides from command line or config file.
# NOTE(review): exec runs configurator.py with this module's globals; only run
# this script with a configurator.py you trust.
# The context manager closes the file handle, which the bare
# `exec(open(...).read())` form never did. The handle name is underscored so
# it does not look like a configuration value.
with open('configurator.py') as _config_file:
    exec(_config_file.read())
# -----------------------------------------------------------------------------

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu'  # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
# Autocast is only enabled off-CPU; on CPU generation runs in a no-op context.
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# init from a model saved in a specific directory
# NOTE(review): torch.load unpickles the checkpoint, so loading a file from an
# untrusted source can execute arbitrary code. Consider weights_only=True once
# it is confirmed that the checkpoint holds only tensors and plain containers.
checkpoint_dict = torch.load(checkpoint, map_location=device)
gptconf = ModelArgs(**checkpoint_dict['model_args'])
model = Transformer(gptconf)
state_dict = checkpoint_dict['model']
# Strip the '_orig_mod.' prefix from parameter names (presumably added when
# the checkpoint was saved from a torch.compile-wrapped model - TODO confirm).
# Iterate over a snapshot of the keys because the dict is mutated in the loop.
unwanted_prefix = '_orig_mod.'
for key in list(state_dict):
    if key.startswith(unwanted_prefix):
        state_dict[key[len(unwanted_prefix):]] = state_dict.pop(key)
# strict=False: missing or unexpected keys are tolerated rather than raising.
model.load_state_dict(state_dict, strict=False)

model.eval()
model.to(device)
if compile:
    print("Compiling the model...")
    model = torch.compile(model)  # requires PyTorch 2.0 (optional)

# load the tokenizer
vocab_source = checkpoint_dict["config"].get("vocab_source", "llama2")
vocab_size = gptconf.vocab_size
if tokenizer:
    # a specific tokenizer is provided, use it
    tokenizer_model = tokenizer
else:
    # let's try to find the tokenizer model automatically. bit gross here...
    # A vocab_size of 0 is passed when the vocab source is "llama2".
    query_vocab_size = 0 if vocab_source == "llama2" else vocab_size
    tokenizer_model = get_tokenizer_model_path(vocab_size=query_vocab_size)
enc = Tokenizer(tokenizer_model=tokenizer_model)

# encode the beginning of the prompt
if start.startswith('FILE:'):
    # Everything after the 'FILE:' marker is a path whose contents are the prompt.
    with open(start[5:], 'r', encoding='utf-8') as f:
        start = f.read()
start_ids = enc.encode(start, bos=True, eos=False)
# [None, ...] adds a leading dimension of size 1 in front of the token ids.
x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])

# run generation
with torch.no_grad(), ctx:
    for _ in range(num_samples):
        y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
        print(enc.decode(y[0].tolist()))
        print('---------------')