import pandas as pd
import numpy as np
import os
import requests
import torch
from torch import nn
from torch.nn import functional as F
import sentencepiece as spm
import random
from collections import OrderedDict
from matplotlib import pyplot as plt
import time

# Pick the fastest available backend: CUDA, Apple-silicon MPS, or plain CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Hyperparameters
VOCAB_SIZE = 130
BATCH_SIZE = 32
CONTEXT_WINDOW = 16
EPOCHS = 1000
DIM = 128
LOG_INTERVAL = 10
HEADS = 8
LAYERS = 4
# Download the TinyShakespeare corpus.
url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
response = requests.get(url)

if response.status_code == 200:
    tinyshakespeare = response.text
else:
    raise RuntimeError(f"download failed with status {response.status_code}")

# Train a small BPE tokenizer on the non-empty lines of the corpus.
tinyshakespeare_list = tinyshakespeare.split("\n")
tinyshakespeare_list = [i for i in tinyshakespeare_list if i != ""]

spm.SentencePieceTrainer.Train(
    sentence_iterator=iter(tinyshakespeare_list),
    model_prefix="tinyshakespeare_model",
    vocab_size=VOCAB_SIZE,
    character_coverage=1.0,
    model_type="bpe",
    pad_id=0,
    unk_id=1,
    bos_id=2,
    eos_id=3,
)

# Encode the full corpus into a single 1-D tensor of token ids.
sp = spm.SentencePieceProcessor(model_file="tinyshakespeare_model.model")
dataset_tensor = torch.tensor(sp.Encode(tinyshakespeare))
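# Optional sanity check (my addition, not part of the original script): report the size of
# the encoded corpus and round-trip a short phrase through the tokenizer.
print("corpus tokens:", dataset_tensor.shape[0])
print(sp.Decode(sp.Encode("to be or not to be")))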
def get_batch_train(dataset, batch_size, context_window):
    # First 70% of the token stream is the training split; targets are the inputs shifted
    # by one position. Batches are moved to `device` so they match the model's placement.
    train_data = dataset[:int(.7 * len(dataset))]
    ix = torch.randint(0, train_data.size(0) - context_window - 1, (batch_size,))
    x = torch.stack([train_data[i:i+context_window] for i in ix]).long()
    y = torch.stack([train_data[i+1:i+context_window+1] for i in ix]).long()
    return x.to(device), y.to(device)


def get_batch_val(dataset, batch_size, context_window):
    # Tokens from 70% to 85% of the stream form the validation split.
    val_data = dataset[int(.7 * len(dataset)): int(.85 * len(dataset))]
    ix = torch.randint(0, val_data.size(0) - context_window - 1, (batch_size,))
    x = torch.stack([val_data[i:i+context_window] for i in ix]).long()
    y = torch.stack([val_data[i+1:i+context_window+1] for i in ix]).long()
    return x.to(device), y.to(device)


def get_batch_test(dataset, batch_size, context_window):
    # The final 15% of the stream is held out as the test split.
    test_data = dataset[int(.85 * len(dataset)): len(dataset)]
    ix = torch.randint(0, test_data.size(0) - context_window - 1, (batch_size,))
    x = torch.stack([test_data[i:i+context_window] for i in ix]).long()
    y = torch.stack([test_data[i+1:i+context_window+1] for i in ix]).long()
    return x.to(device), y.to(device)
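# Optional sanity check (my addition): a batch is BATCH_SIZE sequences of CONTEXT_WINDOW
# token ids, and the targets are simply the inputs shifted one position to the right.
_x, _y = get_batch_train(dataset_tensor, BATCH_SIZE, CONTEXT_WINDOW)
print(_x.shape, _y.shape)                   # torch.Size([32, 16]) torch.Size([32, 16])
print(torch.equal(_x[:, 1:], _y[:, :-1]))   # True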
@torch.no_grad()
def calculate_loss(model):
    model.eval()
    train_losses = []
    val_losses = []
    for i in range(EPOCHS):

        x_train, y_train = get_batch_train(dataset_tensor, BATCH_SIZE, CONTEXT_WINDOW)
        _, train_loss = model(x_train, y_train)
        train_losses.append(train_loss.item())

        x_val, y_val = get_batch_val(dataset_tensor, BATCH_SIZE, CONTEXT_WINDOW)
        _, val_loss = model(x_val, y_val)
        val_losses.append(val_loss.item())

    losses_dict = {"train": np.mean(train_losses), "val": np.mean(val_losses)}
    return losses_dict


@torch.no_grad()
def calculate_accuracy(model):
    model.eval()
    correct_predictions = 0
    total_predictions = 0

    for i in range(EPOCHS):

        x_val, y_val = get_batch_val(dataset_tensor, BATCH_SIZE, CONTEXT_WINDOW)

        logits = model(x_val)

        predicted_labels = torch.argmax(logits, dim=-1)

        correct_predictions += (predicted_labels == y_val).sum().item()
        total_predictions += y_val.numel()

    accuracy = correct_predictions / total_predictions
    return accuracy


@torch.no_grad()
def calculate_perplexity(model):
    model.eval()
    val_losses = []

    for i in range(EPOCHS):

        x_val, y_val = get_batch_val(dataset_tensor, BATCH_SIZE, CONTEXT_WINDOW)

        _, val_loss = model(x_val, y_val)
        val_losses.append(val_loss.item())

    mean_val_loss = np.mean(val_losses)

    perplexity = np.exp(mean_val_loss)
    return perplexity
def train(model, optimizer, checkpoint_path="checkpoints"):
    os.makedirs(checkpoint_path, exist_ok=True)
    losses = []
    accs = []
    perps = []
    for epoch in range(EPOCHS):
        model.train()  # the evaluation helpers switch the model to eval mode
        optimizer.zero_grad()
        x_train, y_train = get_batch_train(dataset_tensor, BATCH_SIZE, CONTEXT_WINDOW)
        logits, loss = model(x_train, y_train)
        loss.backward()
        optimizer.step()

        if epoch % LOG_INTERVAL == 0:
            current_loss = calculate_loss(model)
            current_accuracy = calculate_accuracy(model)
            current_perplexity = calculate_perplexity(model)

            losses.append(current_loss)
            accs.append(current_accuracy)
            perps.append(current_perplexity)

            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': current_loss,
                'accuracy': current_accuracy,
                'perplexity': current_perplexity
            }, f"{checkpoint_path}/checkpoint_epoch_{epoch}.pth")

            print(f"Epoch {epoch}: Loss - {current_loss['val']}, Accuracy - {current_accuracy}, Perplexity - {current_perplexity}")

    print("Validation Loss: ", losses[-1]['val'])
    print("Validation Accuracy: ", accs[-1])
    print("Validation Perplexity: ", perps[-1])
    return pd.DataFrame(losses).plot()
class RMSNorm(torch.nn.Module):
    def __init__(self, layer_shape, eps=1e-8, bias=False):
        super(RMSNorm, self).__init__()
        self.register_parameter("scale", torch.nn.Parameter(torch.ones(layer_shape)))

    def forward(self, x):
        # RMS over the (seq, dim) axes: the Frobenius norm rescaled by 1/sqrt(numel).
        ff_rms = torch.linalg.norm(x, dim=(1, 2)) * x[0].numel() ** -.5
        # Normalise the activations, then apply the learned per-position scale.
        x_normed = x / ff_rms.unsqueeze(-1).unsqueeze(-1)
        return self.scale[:x.shape[1], :].unsqueeze(0) * x_normed
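# Optional sanity check (my addition): with `scale` still at its initial value of ones,
# the normalised activations should have an RMS of roughly 1 over the (seq, dim) axes.
_rms_layer = RMSNorm((CONTEXT_WINDOW, DIM))
_sample = torch.randn(2, CONTEXT_WINDOW, DIM)
print(_rms_layer(_sample).pow(2).mean(dim=(1, 2)).sqrt())  # ~tensor([1., 1.])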
def get_rotary_matrix(context_window, embedding_dim):
    # RoPE: for every position, build a block-diagonal matrix of 2x2 rotations, one per
    # pair of embedding dimensions, with angle position * theta_i.
    R = torch.zeros((context_window, embedding_dim, embedding_dim), requires_grad=False)
    for position in range(context_window):
        for i in range(embedding_dim//2):
            theta = 10000. ** (-2.*(i - 1) / embedding_dim)
            m_theta = position * theta
            R[position, 2*i, 2*i] = np.cos(m_theta)
            R[position, 2*i, 2*i+1] = -np.sin(m_theta)
            R[position, 2*i+1, 2*i] = np.sin(m_theta)
            R[position, 2*i+1, 2*i+1] = np.cos(m_theta)
    return R
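# Optional sanity check (my addition): each per-position matrix is a pure rotation, so it
# should be orthogonal and therefore preserve dot products between queries and keys.
_R = get_rotary_matrix(CONTEXT_WINDOW, DIM)
print(torch.allclose(_R[0] @ _R[0].T, torch.eye(DIM), atol=1e-5))  # True
print(torch.allclose(_R[5] @ _R[5].T, torch.eye(DIM), atol=1e-5))  # True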
class RoPEAttentionHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.w_q = nn.Linear(DIM, DIM, bias=False)
        self.w_k = nn.Linear(DIM, DIM, bias=False)
        self.w_v = nn.Linear(DIM, DIM, bias=False)

        # Register the rotary matrices as a buffer so they move with the model to `device`.
        self.register_buffer("R", get_rotary_matrix(CONTEXT_WINDOW, DIM))

    def forward(self, x, return_attn_weights=False):
        b, m, d = x.shape

        q = self.w_q(x)
        k = self.w_k(x)
        v = self.w_v(x)

        # Rotate queries and keys position-wise: (m, b, d) @ (m, d, d) -> (m, b, d).
        q_rotated = (torch.bmm(q.transpose(0, 1), self.R[:m])).transpose(0, 1)
        k_rotated = (torch.bmm(k.transpose(0, 1), self.R[:m])).transpose(0, 1)

        # Causal attention; dropout is only applied while training.
        activations = F.scaled_dot_product_attention(
            q_rotated, k_rotated, v,
            dropout_p=.1 if self.training else 0.0,
            is_causal=True,
        )

        if return_attn_weights:
            # Recompute the causal attention weights explicitly for inspection.
            attn_scores = torch.bmm(q_rotated, k_rotated.transpose(1, 2)) / np.sqrt(d)
            causal_mask = torch.tril(torch.ones((m, m), device=x.device))
            attn_scores = attn_scores.masked_fill(causal_mask == 0, float("-inf"))
            attn_weights = F.softmax(attn_scores, dim=-1)
            return activations, attn_weights
        return activations
class RoPEMultiheadAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.heads = nn.ModuleList([
            RoPEAttentionHead() for _ in range(HEADS)
        ])
        self.linear = nn.Linear(HEADS * DIM, DIM)
        self.dropout = nn.Dropout(.1)

    def forward(self, x):
        heads = [h(x) for h in self.heads]
        x = torch.cat(heads, dim=-1)
        x = self.linear(x)
        x = self.dropout(x)
        return x
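# Optional shape check (my addition): multi-head attention maps (batch, seq, DIM) back to
# (batch, seq, DIM) after concatenating the heads and projecting down.
_mha = RoPEMultiheadAttention()
print(_mha(torch.randn(2, CONTEXT_WINDOW, DIM)).shape)  # torch.Size([2, 16, 128])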
class SwiGLU(nn.Module):
    def __init__(self, size):
        super().__init__()
        self.linear_gate = nn.Linear(size, size)
        self.linear = nn.Linear(size, size)
        # Learnable beta for the Swish gate; assigning an nn.Parameter registers it.
        self.beta = nn.Parameter(torch.ones(1))

    def forward(self, x):
        # Swish_beta(x) = x * sigmoid(beta * x), applied to the gate projection.
        gate = self.linear_gate(x)
        swish_gate = gate * torch.sigmoid(self.beta * gate)
        out = swish_gate * self.linear(x)
        return out
class LlamaBlock(nn.Module):
    def __init__(self):
        super().__init__()

        # A single RMSNorm instance is shared by the attention and feedforward sub-layers.
        self.rms = RMSNorm((CONTEXT_WINDOW, DIM))

        self.attention = RoPEMultiheadAttention()
        self.feedforward = nn.Sequential(
            nn.Linear(DIM, DIM),
            SwiGLU(DIM),
        )

    def forward(self, x):
        # Normalise, then add the attention output as a residual.
        x = self.rms(x)
        x = x + self.attention(x)

        # Normalise again, then add the feedforward output as a residual.
        x = self.rms(x)
        x = x + self.feedforward(x)
        return x
class Llama(nn.Module):
    def __init__(self):
        super().__init__()
        self.embeddings = nn.Embedding(VOCAB_SIZE, DIM)
        self.llama_blocks = nn.Sequential(
            OrderedDict([(f"llama_{i}", LlamaBlock()) for i in range(LAYERS)])
        )

        self.ffn = nn.Sequential(
            nn.Linear(DIM, DIM),
            SwiGLU(DIM),
            nn.Linear(DIM, VOCAB_SIZE),
        )

        print("model params:", sum([m.numel() for m in self.parameters()]))

    def forward(self, idx, targets=None):
        x = self.embeddings(idx)
        x = self.llama_blocks(x)
        logits = self.ffn(x)

        if targets is None:
            return logits
        else:
            loss = F.cross_entropy(logits.view(-1, VOCAB_SIZE), targets.view(-1))
            return logits, loss
llama = Llama().to(device)
optimizer = torch.optim.Adam(llama.parameters())
train(llama, optimizer)
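# A minimal greedy-decoding sketch (my addition, not part of the original training script):
# repeatedly feed the model the last CONTEXT_WINDOW tokens and append its top prediction.
@torch.no_grad()
def generate(model, prompt="ROMEO:", max_new_tokens=50):
    model.eval()
    idx = torch.tensor(sp.Encode(prompt), device=device).unsqueeze(0)  # (1, seq)
    for _ in range(max_new_tokens):
        logits = model(idx[:, -CONTEXT_WINDOW:])                 # (1, seq, VOCAB_SIZE)
        next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
        idx = torch.cat([idx, next_token], dim=-1)
    return sp.Decode(idx[0].tolist())

print(generate(llama))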