"""
Sampling script for the nano-coder model.

This script loads a trained nano-coder model and generates Python code completions.
"""
|
|
|
import os |
|
import pickle |
|
import torch |
|
import torch.nn.functional as F |
|
from model import GPTConfig, GPT |
|
|
|
|
|
# -----------------------------------------------------------------------------
# Sampling configuration (module-level settings read by the functions below).
out_dir = 'out-nano-coder'  # directory searched for ckpt.pt
start = "def fibonacci(n):\n "  # prompt text the samples start from
num_samples = 5  # how many completions main() prints
max_new_tokens = 500  # forwarded to model.generate
temperature = 0.8  # forwarded to model.generate
top_k = 200  # forwarded to model.generate
seed = 1337  # passed to torch.manual_seed before sampling

# Use the GPU when one is visible to PyTorch, otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Name of the autocast precision: bfloat16 only when a CUDA device reports
# support for it, float16 in every other case.
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    dtype = 'bfloat16'
else:
    dtype = 'float16'
|
|
|
|
|
def load_model():
    """Build a GPT model from the checkpoint stored under ``out_dir``.

    Returns:
        A ``(model, checkpoint)`` tuple: the model in eval mode, moved to
        ``device``, and the raw checkpoint dict it was restored from.

    Raises:
        FileNotFoundError: if ``<out_dir>/ckpt.pt`` does not exist.
    """
    ckpt_path = os.path.join(out_dir, 'ckpt.pt')
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found at {ckpt_path}. Please train the model first.")

    checkpoint = torch.load(ckpt_path, map_location=device)
    model = GPT(GPTConfig(**checkpoint['model_args']))

    # Rename keys that carry the '_orig_mod.' prefix (presumably left behind
    # when a torch.compile-wrapped model was saved -- confirm against the
    # training script). The dict is edited in place, so the returned
    # checkpoint sees the renamed keys as well.
    weights = checkpoint['model']
    prefix = '_orig_mod.'
    prefixed_names = [name for name in weights if name.startswith(prefix)]
    for name in prefixed_names:
        weights[name[len(prefix):]] = weights.pop(name)

    model.load_state_dict(weights)
    model.eval()
    model.to(device)
    return model, checkpoint
|
|
|
def load_vocab():
    """Read the character vocabulary saved alongside the dataset.

    Returns:
        A ``(stoi, itos)`` tuple taken from the ``'stoi'`` and ``'itos'``
        entries of ``data/python-codes-25k/meta.pkl``.

    Raises:
        FileNotFoundError: if the metadata file does not exist.
    """
    meta_path = os.path.join('data', 'python-codes-25k', 'meta.pkl')
    if not os.path.exists(meta_path):
        raise FileNotFoundError(f"Vocabulary not found at {meta_path}. Please run prepare_code_dataset.py first.")

    # NOTE: pickle can execute arbitrary code while loading; only point this
    # at a meta.pkl you produced yourself.
    with open(meta_path, 'rb') as handle:
        vocab = pickle.load(handle)
    return vocab['stoi'], vocab['itos']
|
|
|
def encode(text, stoi):
    """Map each character of ``text`` to its id in ``stoi``.

    Args:
        text: string to encode, one token per character.
        stoi: mapping from single characters to integer ids.

    Returns:
        A list with one id per character, in order.

    Raises:
        KeyError: if a character of ``text`` is missing from ``stoi``.
    """
    return list(map(stoi.__getitem__, text))
|
|
|
def decode(ids, itos):
    """Turn a sequence of token ids back into a string.

    Args:
        ids: iterable of token ids.
        itos: lookup (indexable by id) giving the text for each id.

    Returns:
        The looked-up pieces concatenated in order.
    """
    pieces = (itos[token_id] for token_id in ids)
    return ''.join(pieces)
|
|
|
def generate_code(model, stoi, itos, start_text, max_new_tokens, temperature, top_k):
    """Generate one code completion for ``start_text``.

    Args:
        model: object exposing ``generate(x, max_new_tokens, temperature=, top_k=)``.
        stoi: character-to-id mapping used to encode the prompt.
        itos: id-to-character lookup used to decode the result.
        start_text: prompt string; every character must be present in ``stoi``.
        max_new_tokens: forwarded to ``model.generate``.
        temperature: forwarded to ``model.generate``.
        top_k: forwarded to ``model.generate``.

    Returns:
        The decoded text of the first (only) sequence returned by
        ``model.generate`` -- presumably the prompt followed by its
        continuation; confirm against the model implementation.

    Raises:
        KeyError: if ``start_text`` contains a character missing from ``stoi``.
    """
    # Local import keeps this block self-contained; the file's import header
    # does not bring in contextlib.
    from contextlib import nullcontext

    start_ids = encode(start_text, stoi)
    # [None, ...] adds a leading batch dimension of size 1.
    x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]

    # Substring match so indexed devices such as 'cuda:0' still count as CUDA
    # (an exact comparison with 'cuda' would send them down the CPU path).
    device_type = 'cuda' if 'cuda' in str(device) else 'cpu'

    if device_type != 'cuda' or dtype == 'float32':
        # Off-GPU the module default dtype is 'float16', which is not a sound
        # CPU autocast target, so run at the model's native precision instead.
        # An explicit 'float32' likewise means "no reduced precision".
        precision_ctx = nullcontext()
    else:
        autocast_dtype = torch.bfloat16 if dtype == 'bfloat16' else torch.float16
        precision_ctx = torch.amp.autocast(device_type=device_type, dtype=autocast_dtype)

    with torch.no_grad(), precision_ctx:
        y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)

    return decode(y[0].tolist(), itos)
|
|
|
def main():
    """Load the model and vocabulary, print a summary, then print the samples."""
    print("Loading nano-coder model...")
    # The checkpoint dict is returned as well but is not needed here.
    model, _checkpoint = load_model()
    stoi, itos = load_vocab()

    divider = "-" * 80
    params_millions = model.get_num_params() / 1e6

    print("Model loaded successfully!")
    print(f"Vocabulary size: {len(stoi)}")
    print(f"Model parameters: {params_millions:.2f}M")
    print(f"Context length: {model.config.block_size}")
    print(f"Generating {num_samples} samples...")
    print(f"Start text: {start!r}")
    print(divider)

    # Seed once, before the loop, so a whole run is reproducible while the
    # individual samples still differ from one another.
    torch.manual_seed(seed)

    for sample_number in range(1, num_samples + 1):
        print(f"\n--- Sample {sample_number} ---")
        text = generate_code(model, stoi, itos, start, max_new_tokens, temperature, top_k)
        print(text)
        print(divider)
|
|
|
# Run the sampler only when executed as a script, not when imported.
if __name__ == "__main__":
    main()