|
import torch |
|
|
|
|
|
from builtin_architecture import make_model |
|
import os |
|
import sys |
|
import time |
|
from dataset import dataset, get_train_dataset |
|
import torch.nn.functional as F |
|
|
|
# Directory holding this experiment's checkpoints (see "ckpt/best.pt" below).
EXPERIMENT_DIRECTORY = "runs/code-decoder-v10-vanilla-smaller-batchfirst"

# Prefer the Apple-Silicon GPU backend (MPS) when available, else CPU.
device = "mps" if torch.backends.mps.is_available() else "cpu"

# NOTE(review): this hard override makes the MPS check above dead code.
# Presumably a deliberate debug choice (e.g. MPS numerical issues — this script
# checks for NaNs below) — confirm before removing.
device = "cpu"
|
|
|
|
|
# Build the model architecture and move it to the selected device.
net = make_model()

net.to(device)

# Restore the best checkpoint for inference. weights_only=True restricts
# torch.load to plain tensor data (safe deserialization, no arbitrary pickle).
net.load_state_dict(
    torch.load(os.path.join(EXPERIMENT_DIRECTORY, "ckpt", "best.pt"), weights_only=True)
)
|
|
|
|
|
# Sanity-check the restored weights: report every parameter containing NaNs.
bad_params = [n for n, p in net.named_parameters() if torch.isnan(p).any()]
for n in bad_params:
    print(f"NaN found in {n}")
# Likewise for gradients (normally all None right after loading a checkpoint,
# since no backward pass has run).
bad_grads = [
    n
    for n, p in net.named_parameters()
    if p.grad is not None and torch.isnan(p.grad).any()
]
for n in bad_grads:
    print(f"NaN found in gradients of {n}")
|
|
|
|
|
# Tokenizer conventions for this dataset (0 = padding; no separator token used).
pad_token_id = 0

sep_token_id = None

input_text = input("Prompt: ")

# Number of tokens to generate in the sampling loop below.
max_length = 100

# Encode the prompt to token ids.
# NOTE(review): dtype=int resolves to torch.int64 here, but torch.long would be
# the explicit spelling used elsewhere in this file.
input_ids = torch.tensor(dataset.manager.encode(input_text), dtype=int)

print(input_ids.shape)

# NOTE(review): squeeze(0) suggests attention_mask expects an unbatched id
# sequence, yet input_ids above is built without an explicit batch dim —
# confirm the shape contract of dataset.manager.encode / attention_mask.
attention_mask = dataset.manager.attention_mask(input_ids.squeeze(0)).to(device)

# Round-trip the prompt through the tokenizer as a sanity check and show it.
generated_text = dataset.manager.decode(input_ids)

print(generated_text)

generated_text = ""

# NOTE(review): the encoded prompt is discarded here — generation is seeded
# with a single random token id in [0, 199), so the user's prompt never
# conditions the model. Looks like leftover debugging; confirm whether the
# encoded input_ids should be reused as the seed instead.
input_ids = torch.randint(199, (1, 1), dtype=torch.long).to(device)
|
|
|
# Autoregressive sampling loop: run the current token sequence through the
# model, sample the next token from the top-k of the predicted distribution,
# append it, and repeat for max_length steps.
net.eval()
temp = 1.0  # softmax temperature; >1 flattens, <1 sharpens the distribution

for _ in range(max_length):
    with torch.no_grad():
        output = net(input_ids)
        # NOTE(review): output[-1] takes the last slice along dim 0 as the
        # next-token logits — confirm against the model's output layout
        # (the run directory name says "batchfirst").
        logits = F.log_softmax(output[-1], dim=-1)
        # BUG FIX: torch.multinomial requires non-negative weights, but
        # log-softmax values are <= 0, so sampling the raw (temperature-scaled)
        # log-probs raised "invalid multinomial distribution". Exponentiating
        # restores non-negative weights while keeping the same temperature-
        # sampling semantics: exp(log p / T) is proportional to p**(1/T), and
        # multinomial renormalizes its weights.
        word_weights = logits.div(temp).exp().cpu()

    # Restrict sampling to the k most likely tokens (clamped to vocab size).
    top_k = 10
    vocab_size = word_weights.size(0)
    top_k = min(top_k, vocab_size)

    top_probs, top_indices = torch.topk(word_weights, k=top_k)

    if top_probs.size(0) == 1:
        # Degenerate single-candidate case: take it greedily.
        word_idx = top_indices[0]
    else:
        # Sample an index proportionally to the (non-negative) top-k weights.
        sampled_idx = torch.multinomial(top_probs, 1).item()
        word_idx = top_indices[sampled_idx]

    print(word_idx)
    predicted_token = dataset.manager.decode(word_idx.item())
    print(predicted_token, end=" ")
    generated_text += predicted_token

    # Debug output for inspecting the sampling distribution each step.
    print("Word Weights:", word_weights)
    print("Top Probabilities:", top_probs)
    print("Top Indices:", top_indices)

    # Append the sampled token and continue from the extended sequence.
    word_tensor = torch.tensor([[word_idx]], dtype=torch.long).to(device)
    input_ids = torch.cat([input_ids, word_tensor], dim=1)
|
|
|
# Show the final result and persist it to disk.
print("\nGenerated text:", generated_text)
# "w" instead of "w+": the file is only written, never read back; explicit
# UTF-8 avoids a locale-dependent default encoding for the output file.
with open("output.txt", "w", encoding="utf-8") as f:
    f.write(generated_text)
|
|