""" Sample example """ from contextlib import nullcontext import torch from model import GPTConfig, GPT from transformers import GPT2TokenizerFast from safetensors.torch import save_file, load_file start = "\n" num_samples = 10 # number of samples to generate max_new_tokens = 500 # number of tokens generated in each sample temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions top_k = 200 # retain only the top_k most likely tokens, clamp others to have 0 probability seed = 1337 device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc. dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16' torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype) # Load model checkpoint #st = load_file("model.safetensors", device=device_type) #load_model(model, ) #safe_tensors = from_file(file_path) #print(st) checkpoint = torch.load("pytorch_model.bin", map_location=device) gptconf = GPTConfig(**checkpoint['model_args']) model = GPT(gptconf) state_dict = checkpoint['model'] unwanted_prefix = '_orig_mod.' for k,v in list(state_dict.items()): if k.startswith(unwanted_prefix): state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) model.load_state_dict(state_dict) model.eval() model.to(device) # Prepare the tokenizer calls tokenizer = GPT2TokenizerFast.from_pretrained('gpt2') encode = lambda s: tokenizer.encode(s) decode = lambda l: tokenizer.decode(l) # Create the initial embeddings start_ids = encode(start) x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]) # run generation with torch.no_grad(): with ctx: for k in range(num_samples): y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k) print(decode(y[0].tolist())) print('---------------')