import pandas as pd
import numpy as np
import os
import requests
import torch
from torch import nn
from torch.nn import functional as F
import sentencepiece as spm
import random
from collections import OrderedDict
from matplotlib import pyplot as plt
import time
# Pick the best available accelerator, falling back to CPU.
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
VOCAB_SIZE = 130      # SentencePiece BPE vocabulary size
BATCH_SIZE = 32       # sequences per training batch
CONTEXT_WINDOW = 16   # maximum sequence length seen by the model
EPOCHS = 1000         # number of training iterations (one batch each)
DIM = 128             # embedding / hidden dimension
LOG_INTERVAL = 10     # evaluate and checkpoint every N epochs
HEADS = 8             # attention heads per block
LAYERS = 4            # number of LlamaBlocks
# Download the Tiny Shakespeare corpus.
url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
response = requests.get(url)
response.raise_for_status()  # fail fast if the download did not succeed
tinyshakespeare = response.text
# One sentence per line for the tokenizer trainer; drop empty lines.
tinyshakespeare_list = [line for line in tinyshakespeare.split("\n") if line != ""]
# Train a small BPE tokenizer directly from the in-memory line iterator.
spm.SentencePieceTrainer.Train(
    sentence_iterator=iter(tinyshakespeare_list),
    model_prefix="tinyshakespeare_model",
    vocab_size=VOCAB_SIZE,
    character_coverage=1.0,
    model_type="bpe",
    pad_id=0,
    unk_id=1,
    bos_id=2,
    eos_id=3,
)
sp = spm.SentencePieceProcessor(model_file="tinyshakespeare_model.model")
dataset_tensor = torch.tensor(sp.Encode(tinyshakespeare))
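# Optional helper (not part of the training flow): a quick, illustrative sanity check
# of the freshly trained tokenizer. The sample text is an arbitrary choice; call this
# manually if you want to inspect the vocabulary size and an encode/decode round trip.
def _tokenizer_sanity_check(text="ROMEO: what say you?"):
    ids = sp.Encode(text)
    print("vocab size:", sp.GetPieceSize())
    print("ids:", ids)
    print("decoded:", sp.Decode(ids))
    return ids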
def get_batch_train(dataset, batch_size, context_window):
    # First 70% of the token stream is the training split.
    train_data = dataset[:int(.7 * len(dataset))]
    ix = torch.randint(0, train_data.size(0) - context_window - 1, (batch_size,))
    x = torch.stack([train_data[i:i+context_window] for i in ix]).long()
    y = torch.stack([train_data[i+1:i+context_window+1] for i in ix]).long()
    return x.to(device), y.to(device)
def get_batch_val(dataset, batch_size, context_window):
    # Next 15% of the token stream is the validation split.
    val_data = dataset[int(.7 * len(dataset)): int(.85 * len(dataset))]
    ix = torch.randint(0, val_data.size(0) - context_window - 1, (batch_size,))
    x = torch.stack([val_data[i:i+context_window] for i in ix]).long()
    y = torch.stack([val_data[i+1:i+context_window+1] for i in ix]).long()
    return x.to(device), y.to(device)
def get_batch_test(dataset, batch_size, context_window):
    # Final 15% of the token stream is held out as the test split.
    test_data = dataset[int(.85 * len(dataset)):]
    ix = torch.randint(0, test_data.size(0) - context_window - 1, (batch_size,))
    x = torch.stack([test_data[i:i+context_window] for i in ix]).long()
    y = torch.stack([test_data[i+1:i+context_window+1] for i in ix]).long()
    return x.to(device), y.to(device)
@torch.no_grad()
def calculate_loss(model):
model.eval()
train_losses = []
val_losses = []
for i in range(EPOCHS):
#train evaluation
x_train, y_train = get_batch_train(dataset_tensor, BATCH_SIZE, CONTEXT_WINDOW)
_, train_loss = model(x_train, y_train)
train_losses.append(train_loss.item())
#val evaluation
x_val, y_val = get_batch_val(dataset_tensor, BATCH_SIZE, CONTEXT_WINDOW)
_, val_loss = model(x_val, y_val)
val_losses.append(val_loss.item())
losses_dict = {"train": np.mean(train_losses), "val": np.mean(val_losses)}
return losses_dict
@torch.no_grad()
def calculate_accuracy(model):
model.eval()
correct_predictions = 0
total_predictions = 0
for i in range(EPOCHS):
# Get a batch of validation data
x_val, y_val = get_batch_val(dataset_tensor, BATCH_SIZE, CONTEXT_WINDOW)
# Get model predictions
logits = model(x_val)
# Convert predictions to class labels
predicted_labels = torch.argmax(logits, dim=-1)
# Compare with true labels
correct_predictions += (predicted_labels == y_val).sum().item()
total_predictions += y_val.numel()
accuracy = correct_predictions / total_predictions
return accuracy
@torch.no_grad()
def calculate_perplexity(model):
model.eval()
val_losses = []
for i in range(EPOCHS):
# Get a batch of validation data
x_val, y_val = get_batch_val(dataset_tensor, BATCH_SIZE, CONTEXT_WINDOW)
# Get model predictions and loss
_, val_loss = model(x_val, y_val)
val_losses.append(val_loss.item())
# Calculate the mean validation loss
mean_val_loss = np.mean(val_losses)
# Perplexity is the exponential of the cross-entropy loss
perplexity = np.exp(mean_val_loss)
return perplexity
def train(model, optimizer, checkpoint_path="checkpoints"):
    os.makedirs(checkpoint_path, exist_ok=True)
    losses = []
    accs = []
    perps = []
    for epoch in range(EPOCHS):
        model.train()  # the evaluation helpers switch the model to eval mode
        optimizer.zero_grad()
        x_train, y_train = get_batch_train(dataset_tensor, BATCH_SIZE, CONTEXT_WINDOW)
        logits, loss = model(x_train, y_train)
        loss.backward()
        optimizer.step()
        if epoch % LOG_INTERVAL == 0:
            current_loss = calculate_loss(model)
            current_accuracy = calculate_accuracy(model)
            current_perplexity = calculate_perplexity(model)
            losses.append(current_loss)
            accs.append(current_accuracy)
            perps.append(current_perplexity)
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': current_loss,
                'accuracy': current_accuracy,
                'perplexity': current_perplexity
            }, f"{checkpoint_path}/checkpoint_epoch_{epoch}.pth")
            print(f"Epoch {epoch}: Loss - {current_loss['val']}, Accuracy - {current_accuracy}, Perplexity - {current_perplexity}")
    print("Validation Loss: ", losses[-1]['val'])
    print("Validation Accuracy: ", accs[-1])
    print("Validation Perplexity: ", perps[-1])
    return pd.DataFrame(losses).plot()
class RMSNorm(torch.nn.Module):
    def __init__(self, layer_shape, eps=1e-8, bias=False):
        super(RMSNorm, self).__init__()
        # eps and bias are accepted for interface parity but unused in this simplified version.
        self.register_parameter("scale", torch.nn.Parameter(torch.ones(layer_shape)))
    def forward(self, x):
        # RMS over the (seq, dim) slice of each example: 1/sqrt(N) times the Frobenius norm.
        ff_rms = torch.linalg.norm(x, dim=(1, 2)) * x[0].numel() ** -.5
        # Normalize x by its RMS, then apply the learned per-position scale.
        raw = x / ff_rms.unsqueeze(-1).unsqueeze(-1)
        return self.scale[:x.shape[1], :].unsqueeze(0) * raw
def get_rotary_matrix(context_window, embedding_dim):
    # Build one (embedding_dim x embedding_dim) block-diagonal rotation matrix per
    # position, as in RoPE: each 2D pair (2i, 2i+1) is rotated by position * theta_i.
    R = torch.zeros((context_window, embedding_dim, embedding_dim), requires_grad=False)
    for position in range(context_window):
        for i in range(embedding_dim // 2):
            theta = 10000. ** (-2. * (i - 1) / embedding_dim)
            m_theta = position * theta
            R[position, 2*i, 2*i] = np.cos(m_theta)
            R[position, 2*i, 2*i+1] = -np.sin(m_theta)
            R[position, 2*i+1, 2*i] = np.sin(m_theta)
            R[position, 2*i+1, 2*i+1] = np.cos(m_theta)
    return R
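# Optional helper (not called during training): an illustrative check of the RoPE
# property that dot products between rotated queries and keys depend only on the
# relative offset between positions, not on their absolute positions. The window
# and dimension arguments here are arbitrary small values for a quick test.
def _rope_relative_position_check(context_window=8, embedding_dim=4):
    R = get_rotary_matrix(context_window, embedding_dim)
    q = torch.randn(embedding_dim)
    k = torch.randn(embedding_dim)
    # Offset of 2 at two different absolute position pairs should give (near-)equal scores.
    score_a = (q @ R[2]) @ (k @ R[0])
    score_b = (q @ R[5]) @ (k @ R[3])
    return torch.allclose(score_a, score_b, atol=1e-5)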
class RoPEAttentionHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.w_q = nn.Linear(DIM, DIM, bias=False)
        self.w_k = nn.Linear(DIM, DIM, bias=False)
        self.w_v = nn.Linear(DIM, DIM, bias=False)
        # Precomputed rotary matrices, registered as a buffer so they follow the
        # model across devices and into checkpoints.
        self.register_buffer("R", get_rotary_matrix(CONTEXT_WINDOW, DIM))
    def forward(self, x, return_attn_weights=False):
        b, m, d = x.shape
        q = self.w_q(x)
        k = self.w_k(x)
        v = self.w_v(x)
        # Rotate queries and keys position-wise: (m, b, d) @ (m, d, d) -> (m, b, d).
        q_rotated = torch.bmm(q.transpose(0, 1), self.R[:m]).transpose(0, 1)
        k_rotated = torch.bmm(k.transpose(0, 1), self.R[:m]).transpose(0, 1)
        activations = F.scaled_dot_product_attention(
            q_rotated, k_rotated, v,
            dropout_p=.1 if self.training else 0.0,  # no attention dropout at eval time
            is_causal=True,
        )
        if return_attn_weights:
            # Recompute the causal attention weights explicitly for inspection.
            attn_weights = torch.bmm(q_rotated, k_rotated.transpose(1, 2)) / np.sqrt(d)
            causal_mask = torch.tril(torch.ones((m, m), device=x.device, dtype=torch.bool))
            attn_weights = attn_weights.masked_fill(~causal_mask, float("-inf"))
            attn_weights = F.softmax(attn_weights, dim=-1)
            return activations, attn_weights
        return activations
class RoPEMultiheadAttention(nn.Module):
def __init__(self):
super().__init__()
self.heads = nn.ModuleList([
RoPEAttentionHead() for _ in range(HEADS)
])
self.linear = nn.Linear(HEADS * DIM, DIM)
self.dropout = nn.Dropout(.1)
def forward(self, x):
heads = [h(x) for h in self.heads]
x = torch.cat(heads, dim=-1)
x = self.linear(x)
x = self.dropout(x)
return x
class SwiGLU(nn.Module):
    """Swish-gated linear unit with a learnable beta."""
    def __init__(self, size):
        super().__init__()
        self.linear_gate = nn.Linear(size, size)
        self.linear = nn.Linear(size, size)
        # A single learnable beta for the Swish gate (assigning an nn.Parameter
        # attribute registers it automatically).
        self.beta = nn.Parameter(torch.ones(1))
    def forward(self, x):
        gate = self.linear_gate(x)
        swish_gate = gate * torch.sigmoid(self.beta * gate)
        out = swish_gate * self.linear(x)
        return out
class LlamaBlock(nn.Module):
def __init__(self):
super().__init__()
self.rms = RMSNorm((CONTEXT_WINDOW, DIM))
self.attention = RoPEMultiheadAttention()
self.feedforward = nn.Sequential(
nn.Linear(DIM, DIM),
SwiGLU(DIM),
)
    def forward(self, x):
        x = self.rms(x)  # RMS normalization
        x = x + self.attention(x)  # self-attention with residual connection
        x = self.rms(x)  # RMS normalization (the same module is reused)
        x = x + self.feedforward(x)  # feed-forward (SwiGLU) with residual connection
        return x
class Llama(nn.Module):
def __init__(self):
super().__init__()
self.embeddings = nn.Embedding(VOCAB_SIZE, DIM)
self.llama_blocks = nn.Sequential(
OrderedDict([(f"llama_{i}", LlamaBlock()) for i in range(LAYERS)])
)
self.ffn = nn.Sequential(
nn.Linear(DIM, DIM),
SwiGLU(DIM),
nn.Linear(DIM, VOCAB_SIZE),
)
print("model params:", sum([m.numel() for m in self.parameters()]))
def forward(self, idx, targets=None):
x = self.embeddings(idx)
x = self.llama_blocks(x)
logits = self.ffn(x)
if targets is None:
return logits
else:
loss = F.cross_entropy(logits.view(-1, VOCAB_SIZE), targets.view(-1))
return logits, loss
llama = Llama().to(device)
optimizer = torch.optim.Adam(llama.parameters())
train(llama, optimizer)
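# Optional: a minimal greedy-decoding sketch (not part of the training script above)
# showing how the trained model and tokenizer can be used together. Greedy argmax
# sampling and the example prompt are arbitrary choices, not the only way to generate.
@torch.no_grad()
def generate(model, prompt, max_new_tokens=50):
    model.eval()
    model_device = next(model.parameters()).device
    idx = torch.tensor([sp.Encode(prompt)], dtype=torch.long, device=model_device)
    for _ in range(max_new_tokens):
        # The model only sees at most CONTEXT_WINDOW tokens at a time.
        idx_cond = idx[:, -CONTEXT_WINDOW:]
        logits = model(idx_cond)  # (1, T, VOCAB_SIZE)
        next_id = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
        idx = torch.cat([idx, next_id], dim=1)
    return sp.Decode(idx[0].tolist())
# Example usage (prompt is arbitrary):
# print(generate(llama, "ROMEO:"))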