|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Script prologue: the imports are bracketed by progress messages so the
# console shows when library loading has finished.
print("loading libraries")

import os

import datetime

import torch

import torch.nn as nn

import torch.optim as optim

from torch.utils.data import Dataset, DataLoader

# NOTE(review): duplicate of the `import torch.optim as optim` above; harmless.
import torch.optim as optim

from torch.optim.lr_scheduler import ReduceLROnPlateau

import math

# Used by print_with_line() to read the caller's line number.
import inspect



print("done loading libraries")
|
|
|
# The text the model is trained to memorise (the Gettysburg Address). It is
# echoed to the console together with its character count.
# NOTE: the blank lines inside the triple-quoted literal are part of the string
# and count toward len(Memorized_Speech).
print("Hardcoding Memorized_Speech = Gettysburg Address")

Memorized_Speech = """

Four score and seven years ago our fathers brought forth on this continent, a new nation, conceived in Liberty, and dedicated to the proposition that all men are created equal.



Now we are engaged in a great civil war, testing whether that nation, or any nation so conceived and so dedicated, can long endure. We are met on a great battle-field of that war. We have come to dedicate a portion of that field, as a final resting place for those who here gave their lives that that nation might live. It is altogether fitting and proper that we should do this.



But, in a larger sense, we can not dedicate - we can not consecrate - we can not hallow-this ground. The brave men, living and dead, who struggled here, have consecrated it, far above our poor power to add or detract. The world will little note, nor long remember what we say here, but it can never forget what they did here. It is for us the living, rather, to be dedicated here to the unfinished work which they who fought here have thus far so nobly advanced. It is rather for us to be here dedicated to the great task remaining before us - that from these honored dead we take increased devotion to that cause for which they gave the last full measure of devotion - that we here highly resolve that these dead shall not have died in vain - that this nation, under God, shall have a new birth of freedom - and that government of the people, by the people, for the people, shall not perish from the earth.

"""

print(f'Length of Memorized_Speech = {len(Memorized_Speech)} characters, as follows:')

print(Memorized_Speech)
|
|
|
|
|
|
|
|
|
# Model and training configuration read by main() when it builds the
# tokenizer, data loader, model and trainer.
hyperparameters = {
    # Expected vocabulary size. Tokenizer compares it with the size it computes
    # from the text and prints a warning when the two differ.
    "vocab_size": 152,
    # Reserved tokens; they receive ids 0, 1, 2 at the head of the vocabulary.
    # Tokenizer uses entry [1] as the carriage-return marker and the last entry
    # as the id for unknown words.
    "special_tokens": ["<FreetheLLM>", "<cr>", "<pad>"],
    # Embedding / residual width; must be divisible by n_head.
    "n_embd": 512,
    # Number of transformer Blocks stacked in the model.
    "n_layer": 4,
    # Attention heads per block.
    "n_head": 16,
    # Hidden width of each Block's MLP (4 x n_embd).
    "n_inner": 4 * 512,
    # Rows in the position-embedding table and size of the causal mask, i.e. the
    # longest context the model can process; generate() truncates to it.
    "max_sequence_len": 264,
    # Maximum number of training epochs (training may stop early).
    "epochs": 200,
    # Initial AdamW learning rate; a ReduceLROnPlateau scheduler may lower it.
    "learning_rate": 1e-3,
    # DataLoader batch size.
    "batch_size": 1,
    # Dropout probability. NOTE(review): confirm which modules actually read this
    # key; the attention/MLP layers may use their own hard-coded rate.
    "dropout": 0.2
}

# Length (in tokens) of each training window; main() passes it to Dataset as seq_len.
min_training_input_seq_len = 32

# Training stops once an epoch's average loss falls below this value.
Early_stopping_loss = 0.003

# Loss threshold for flagging individual tokens. Trainer stores it, but no
# active code path in this file reads it.
Per_token_loss_threshold = 0.5
|
|
|
|
|
def print_with_line(message):
    """Print ``message`` tagged with the line number of the call site.

    Args:
        message: Text to print; it is followed by " at script line <N>".
    """
    # f_back is the caller's frame; its f_lineno is the line making this call.
    caller_line = inspect.currentframe().f_back.f_lineno
    print(f"{message} at script line {caller_line}")
|
|
|
|
|
class Tokenizer:
    """Word-level tokenizer whose vocabulary is built from a reference text.

    Vocabulary layout (ids in order): the special tokens exactly as given, then
    every character of ``punctuation_list`` seen in the text (sorted), then every
    distinct lower-cased word (sorted). Words not in the vocabulary encode as
    the last special token.

    Special tokens written literally in the input (e.g. ``"<FreetheLLM>"``) are
    kept intact as single tokens. They are neither split on their ``<``/``>``
    characters nor lower-cased.
    """

    def __init__(self, text, special_tokens, vocab_size_hyperparameter):
        """Build the vocabulary from ``text``.

        Args:
            text: Reference text to collect words and punctuation from.
            special_tokens: Reserved tokens placed first in the vocabulary.
                ``special_tokens[1]`` marks carriage returns and
                ``special_tokens[-1]`` is the unknown-word id.
            vocab_size_hyperparameter: Expected vocabulary size; a warning is
                printed if the computed size differs.
        """
        self.special_tokens = special_tokens
        self.cr_token = special_tokens[1]
        self.punctuation_list = ['.', ',', '/', '\\', '[', ']', '<', '?', '>', '-']
        estimated_vocab_size = vocab_size_hyperparameter

        found_words = set()
        found_punctuation = set()
        for candidate in self._split_text(text, lowercase=False):
            # Special tokens (including the cr marker inserted for '\r') already
            # head the vocabulary; don't shred them into e.g. "cr", "<", ">".
            if candidate in self.special_tokens:
                continue
            word = ''.join(c for c in candidate if c not in self.punctuation_list)
            if word:
                found_words.add(word.lower())
            found_punctuation.update(c for c in candidate if c in self.punctuation_list)

        self.vocab = self.special_tokens + sorted(found_punctuation) + sorted(found_words)
        self.vocab_size = len(self.vocab)

        if self.vocab_size != estimated_vocab_size:
            print(f"Warning: Calculated vocab_size ({self.vocab_size}) differs from estimated size ({estimated_vocab_size}).")

        self.word_to_index = {word: i for i, word in enumerate(self.vocab)}
        self.index_to_word = {i: word for i, word in enumerate(self.vocab)}

    def separate_punctuation(self, text):
        """Surround each listed punctuation character with spaces and replace
        every '\\r' with the carriage-return token, so a whitespace split
        yields them as separate tokens."""
        for char in self.punctuation_list:
            text = text.replace(char, f' {char} ')

        # Done after the punctuation pass so the '<'/'>' of the marker stay attached.
        text = text.replace('\r', f' {self.cr_token} ')

        return text

    def _split_text(self, text, lowercase):
        """Split ``text`` into candidate tokens, preserving literal special tokens.

        Special tokens are cut out first so that punctuation separation (which
        would split ``<``/``>``) and lower-casing cannot alter them. The remaining
        spans go through ``separate_punctuation`` (and ``str.lower`` when
        ``lowercase`` is true) and are split on whitespace.
        """
        import re

        # Longest first, so a token that is a prefix of another cannot win early.
        specials = sorted({tok for tok in self.special_tokens if tok}, key=len, reverse=True)
        if specials:
            pattern = '(' + '|'.join(re.escape(tok) for tok in specials) + ')'
            segments = re.split(pattern, text)  # capturing group keeps the tokens
        else:
            segments = [text]

        pieces = []
        for segment in segments:
            if segment in specials:
                pieces.append(segment)
                continue
            segment = self.separate_punctuation(segment)
            if lowercase:
                segment = segment.lower()
            pieces.extend(segment.split())
        return pieces

    def tokenize(self, text):
        """Encode ``text`` as a list of token ids.

        Words are lower-cased; anything not in the vocabulary maps to the id of
        the last special token.
        """
        unknown_id = self.word_to_index[self.special_tokens[-1]]
        return [self.word_to_index.get(piece, unknown_id)
                for piece in self._split_text(text, lowercase=True)]

    def detokenize(self, tokens):
        """Join the vocabulary entries of ``tokens`` with single spaces,
        silently skipping ids that are not in the vocabulary."""
        return " ".join([self.index_to_word[token] for token in tokens if token in self.index_to_word])
|
|
|
|
|
class CausalSelfAttention(nn.Module):
    """Multi-head masked (causal) self-attention.

    Config keys read: ``n_embd``, ``n_head``, ``max_sequence_len`` and,
    optionally, ``dropout`` (defaults to 0.1, the rate formerly hard-coded
    here) used for both the attention-weight and output dropout.

    Raises:
        ValueError: if ``n_embd`` is not divisible by ``n_head``.
    """

    def __init__(self, config):
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if config["n_embd"] % config["n_head"] != 0:
            raise ValueError(
                f"n_embd ({config['n_embd']}) must be divisible by n_head ({config['n_head']})"
            )
        dropout = config.get("dropout", 0.1)

        # A single projection produces queries, keys and values side by side.
        self.c_attn = nn.Linear(config["n_embd"], 3 * config["n_embd"])
        self.c_proj = nn.Linear(config["n_embd"], config["n_embd"])
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)
        self.n_head = config["n_head"]
        self.n_embd = config["n_embd"]

        max_len = config["max_sequence_len"]
        # Lower-triangular mask: position t may only attend to positions <= t.
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(max_len, max_len)).view(1, 1, max_len, max_len),
        )

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: Tensor of shape (B, T, C) with C == n_embd and T <= max_sequence_len.

        Returns:
            Tensor of the same shape as ``x``.
        """
        B, T, C = x.size()
        head_dim = C // self.n_head

        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        # (B, T, C) -> (B, n_head, T, head_dim)
        q = q.view(B, T, self.n_head, head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, head_dim).transpose(1, 2)

        # Scaled dot-product scores with future positions masked out.
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_dim))
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
        att = torch.softmax(att, dim=-1)
        att = self.attn_dropout(att)

        y = att @ v
        # Re-merge the heads: (B, n_head, T, head_dim) -> (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_dropout(self.c_proj(y))
|
|
|
class Block(nn.Module):
    """Pre-LayerNorm transformer block.

    Runs causal self-attention and then a GELU MLP, each on a LayerNorm-ed input
    and added back through a residual connection.

    Config keys read here: ``n_embd``, ``n_inner`` and, optionally, ``dropout``
    (defaults to 0.1, the rate formerly hard-coded for the MLP output).
    CausalSelfAttention reads further keys from the same dict.
    """

    def __init__(self, config):
        super().__init__()
        n_embd = config["n_embd"]
        self.ln_1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(n_embd)
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, config["n_inner"]),
            nn.GELU(),
            nn.Linear(config["n_inner"], n_embd),
            # Honour the configured dropout instead of ignoring it.
            nn.Dropout(config.get("dropout", 0.1)),
        )

    def forward(self, x):
        """Apply attention and MLP sub-layers with residual connections.

        Args:
            x: Tensor of shape (B, T, n_embd).

        Returns:
            Tensor of the same shape as ``x``.
        """
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
|
|
|
class ToyGPT2(nn.Module):
    """Small GPT-2-style decoder-only language model.

    Config keys read here: ``vocab_size``, ``n_embd``, ``n_layer`` and
    ``max_sequence_len``. The Blocks read further keys from the same dict.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.token_embedding_table = nn.Embedding(config["vocab_size"], config["n_embd"])
        self.position_embedding_table = nn.Embedding(config["max_sequence_len"], config["n_embd"])
        self.blocks = nn.Sequential(*[Block(config) for _ in range(config["n_layer"])])
        self.ln_f = nn.LayerNorm(config["n_embd"])
        self.lm_head = nn.Linear(config["n_embd"], config["vocab_size"])

        self.apply(self._init_weights)

        # Weight tying: the output projection reuses the token-embedding matrix.
        # Done after apply() so the shared tensor carries the Embedding's N(0, 0.02) init.
        self.lm_head.weight = self.token_embedding_table.weight

    def _init_weights(self, module):
        """Zero Linear biases and draw Embedding weights from N(0, 0.02).

        Linear weights and LayerNorm parameters keep PyTorch's default initialisation.
        """
        if isinstance(module, nn.Linear) and module.bias is not None:
            torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: LongTensor of token ids, shape (B, T) with T <= max_sequence_len.
            targets: Optional LongTensor of shape (B, T) holding the next-token ids.

        Returns:
            ``(logits, loss)``. Without targets, logits has shape
            (B, T, vocab_size) and loss is None. With targets, logits is
            flattened to (B * T, vocab_size) and loss is the mean cross-entropy.
        """
        _, T = idx.shape
        positions = torch.arange(T, device=idx.device)
        x = self.token_embedding_table(idx) + self.position_embedding_table(positions)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        if targets is None:
            return logits, None

        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        loss = nn.functional.cross_entropy(logits, targets.view(B * T))
        return logits, loss

    def generate(self, input_ids, max_new_tokens, temperature=1.0):
        """Autoregressively extend ``input_ids`` by ``max_new_tokens`` tokens.

        Args:
            input_ids: LongTensor of shape (B, T0) holding the prompt.
            max_new_tokens: Number of tokens to append.
            temperature: Softmax temperature for sampling. Values <= 0 pick the
                most likely token (greedy decoding) instead of dividing by zero.

        Returns:
            LongTensor of shape (B, T0 + max_new_tokens).

        The module runs in eval mode while generating. Afterwards it is restored
        to whichever mode (train/eval) it was in before the call.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for _ in range(max_new_tokens):
                    # Only the most recent max_sequence_len tokens fit the position table.
                    context = input_ids[:, -self.config["max_sequence_len"]:]
                    logits, _ = self(context)
                    logits = logits[:, -1, :]
                    if temperature <= 0:
                        next_token = torch.argmax(logits, dim=-1, keepdim=True)
                    else:
                        probs = torch.softmax(logits / temperature, dim=-1)
                        next_token = torch.multinomial(probs, num_samples=1)
                    input_ids = torch.cat((input_ids, next_token), dim=1)
        finally:
            # Previously this always switched to train mode, clobbering eval().
            self.train(was_training)
        return input_ids
|
|
|
|
|
class Dataset(Dataset):
    """Fixed-length (input, target) training windows cut from one text.

    NOTE(review): this class shadows torch.utils.data.Dataset (its own base
    class) at module level. main() constructs it under this name.

    Each sample is a pair of 1-D int64 tensors of length ``seq_len``; the
    target is the input shifted one token to the right. Windows start every
    ``seq_len`` tokens. When that stride would leave the final tokens unused,
    one extra window aligned to the end of the text is added.

    NOTE(review): windows are only ``seq_len`` long, so training exercises just
    the first ``seq_len`` rows of the model's position-embedding table.
    Generating longer sequences relies on untrained positions. TODO: consider
    this when choosing seq_len.

    Args:
        data: Raw text to tokenize.
        tokenizer: Object providing ``tokenize``, ``vocab_size`` and ``index_to_word``.
        seq_len: Window length in tokens (must be positive).
    """

    def __init__(self, data, tokenizer, seq_len):
        self.tokenizer = tokenizer
        self.seq_len = seq_len

        print_with_line("# Tokenize the entire data")
        self.tokens = self.tokenizer.tokenize(data)
        print(f"DEBUG: Total tokens: {len(self.tokens)} in Dataset(")

        self.token_counts = self._calculate_token_counts()

        self.data = [
            (torch.tensor(self.tokens[start:start + seq_len]),
             torch.tensor(self.tokens[start + 1:start + seq_len + 1]))
            for start in self._window_starts()
        ]
        print(f"DEBUG: Number of data samples created in class Dataset(Dataset): {len(self.data)}")

        # print_vocabulary_info prints its own header line.
        self.print_vocabulary_info()

    def _window_starts(self):
        """Return the start offsets of all (input, target) windows.

        A window starting at ``i`` needs ``tokens[i : i + seq_len + 1]``, so the
        last valid start is ``len(tokens) - seq_len - 1``. The original range
        stopped one short of it, and the stride left the text's tail untrained.
        """
        last_start = len(self.tokens) - self.seq_len - 1
        if last_start < 0:
            return []
        starts = list(range(0, last_start + 1, self.seq_len))
        if starts[-1] != last_start:
            # Extra (overlapping) window so the final tokens are also targets.
            starts.append(last_start)
        return starts

    def _calculate_token_counts(self):
        """Return a dict mapping token id -> number of occurrences in ``self.tokens``."""
        from collections import Counter

        return dict(Counter(self.tokens))

    def print_vocabulary_info(self):
        """Print every vocabulary entry with its occurrence count in the data."""
        print_with_line("# Print token-vocabulary information:")
        for token_id in range(self.tokenizer.vocab_size):
            token = self.tokenizer.index_to_word[token_id]
            count = self.token_counts.get(token_id, 0)
            print(f" Token {token_id}:'{token}' \t\t occurs {count} times in the dataset")

    def __len__(self):
        """Number of training windows."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(input_seq, target_seq)`` tensor pair at ``idx``."""
        return self.data[idx]
|
|
|
|
|
|
|
|
|
class Trainer:
    """Training loop for a language model whose forward returns ``(logits, loss)``.

    Args:
        model: Module called as ``model(input_seq, targets=target_seq)``.
        tokenizer: Tokenizer, kept for diagnostics.
        train_loader: DataLoader yielding ``(input_seq, target_seq)`` batches.
        hyperparameters: Dict with at least "learning_rate" and "epochs"; it is
            also written into every checkpoint.
        device: Device each batch is moved to before the forward pass.

    The early-stopping and per-token loss thresholds come from the module-level
    constants ``Early_stopping_loss`` and ``Per_token_loss_threshold``.
    """

    def __init__(self, model, tokenizer, train_loader, hyperparameters, device):
        self.model = model
        self.tokenizer = tokenizer
        self.train_loader = train_loader
        self.hyperparameters = hyperparameters
        self.Per_token_loss_threshold = Per_token_loss_threshold
        self.Early_stopping_loss = Early_stopping_loss
        self.device = device

        self.optimizer = optim.AdamW(self.model.parameters(), lr=hyperparameters["learning_rate"])
        # Multiplies the LR by 0.99 once the epoch loss has not improved for
        # `patience` epochs.
        self.scheduler = ReduceLROnPlateau(self.optimizer, mode='min', factor=0.99, patience=10)

    def train(self):
        """Train for up to ``hyperparameters["epochs"]`` epochs.

        After each epoch this prints the average loss, steps the LR scheduler on
        it, and saves a checkpoint. Training stops early once the average loss
        drops below ``Early_stopping_loss``.

        Raises:
            ValueError: if ``train_loader`` yields no batches.
        """
        if len(self.train_loader) == 0:
            raise ValueError("train_loader is empty; the dataset produced no training samples.")

        self.model.train()
        for epoch in range(self.hyperparameters["epochs"]):
            average_loss = self._train_one_epoch()
            print(f"Epoch {epoch+1}/{self.hyperparameters['epochs']}, Loss: {average_loss:.4f}")
            # Thresholds apply to the epoch average (previously: last batch only).
            if average_loss < 0.01:
                print(" LOSS IS BELOW 0.01")
            if average_loss < 0.001:
                print(" LOSS IS BELOW 0.001")

            # Read the LR before and after stepping. Comparing two post-step
            # reads can never detect a reduction.
            lr_before = self.optimizer.param_groups[0]['lr']
            self.scheduler.step(average_loss)
            lr_after = self.optimizer.param_groups[0]['lr']
            if lr_after != lr_before:
                print(f"Learning rate reduced to {lr_after:.6f}")

            if epoch % 100 == 0:
                print(f"Epoch {epoch + 1}: Current learning rate: {lr_after:.6f}")

            # NOTE(review): one timestamped checkpoint per epoch; long runs can
            # use a lot of disk space.
            self.save_checkpoint(f"model_checkpoint_epoch_{epoch + 1}.pth", epoch, average_loss)

            if average_loss < self.Early_stopping_loss:
                print(f"Early stopping: Average loss {average_loss:.4f} is below the threshold ({self.Early_stopping_loss}).")
                self.save_checkpoint("model_checkpoint_early_stop.pth", epoch, average_loss)
                break

    def _train_one_epoch(self):
        """Run one optimisation pass over ``train_loader``; return the mean batch loss."""
        total_loss = 0.0
        for input_seq, target_seq in self.train_loader:
            input_seq = input_seq.to(self.device)
            target_seq = target_seq.to(self.device)

            self.optimizer.zero_grad()
            _, loss = self.model(input_seq, targets=target_seq)
            # TODO: optional per-token loss report against self.Per_token_loss_threshold.
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
        return total_loss / len(self.train_loader)

    def save_checkpoint(self, path, epoch, average_loss):
        """Save model/optimizer state and metadata with ``torch.save``.

        The script's file name and a timestamp (one-second resolution) are
        inserted before the extension of ``path``. For example, ``ckpt.pth``
        becomes ``ckpt_<script>_<YYYY-mm-dd_HH-MM-SS>.pth``.
        """
        script_filename = os.path.basename(__file__)
        current_datetime = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        base_filename, extension = os.path.splitext(path)
        new_filename = f"{base_filename}_{script_filename}_{current_datetime}{extension}"

        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'loss': average_loss,
            'hyperparameters': self.hyperparameters
        }, new_filename)
|
|
|
|
|
|
|
def main():
    """Build the tokenizer, dataset and model, train, then run two generation demos."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    print_with_line("# Initialize tokenizer")
    tokenizer = Tokenizer(Memorized_Speech, hyperparameters["special_tokens"], hyperparameters["vocab_size"])
    print(f"Vocabulary Size: {tokenizer.vocab_size}")

    # Size the model from the vocabulary actually built. The hard-coded estimate
    # only triggers a warning, and a smaller estimate would make the embedding
    # lookup fail on out-of-range ids.
    config = {**hyperparameters, "vocab_size": tokenizer.vocab_size}

    print_with_line("# Prepare dataset")
    dataset = Dataset(Memorized_Speech, tokenizer, min_training_input_seq_len)
    train_loader = DataLoader(dataset, batch_size=config["batch_size"])

    print_with_line("# Initialize model")
    print(f"HyperParamters = {config}")
    model = ToyGPT2(config).to(device)

    print_with_line("# Initialize trainer")
    trainer = Trainer(model, tokenizer, train_loader, config, device)

    print_with_line("# Train the model")
    trainer.train()

    print("")
    print_with_line("# --- Inference Examples ---")
    model.eval()

    print_with_line("# Example 1: Recite the Gettysburg Address")
    start_text = "four score"
    start_tokens = torch.tensor(tokenizer.tokenize(start_text)).unsqueeze(0).to(device)
    print("Prompt:", start_text)
    # start_tokens is (1, prompt_len): size(1) is the prompt length (len() would
    # give the batch size). Prompt + continuation then matches the speech length.
    max_new_tokens = max(len(dataset.tokens) - start_tokens.size(1), 0)
    generated_tokens = model.generate(start_tokens, max_new_tokens=max_new_tokens, temperature=1.0)
    # Index the single batch row; squeeze() would yield a bare int for length-1 output.
    generated_text = tokenizer.detokenize(generated_tokens[0].tolist())
    print("\nResponse:\n", generated_text)

    print("")

    print_with_line("# Example 2: Free text generation after encountering <FreetheLLM>")
    start_text = "we here highly resolve that these dead shall not have died in vain and that this nation under god shall have a new "
    special_token = tokenizer.special_tokens[0]
    start_text += special_token
    print("Prompt:", start_text)

    start_tokens = torch.tensor(tokenizer.tokenize(start_text)).unsqueeze(0).to(device)
    generated_tokens = model.generate(start_tokens, max_new_tokens=100, temperature=1.0)
    generated_text = tokenizer.detokenize(generated_tokens[0].tolist())
    print("\nFreestyle Generation:\n", generated_text)
|
|
|
|
|
# Run the train-then-generate pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()