# TinyStories_Transformer / better_transformer.py
# Uploaded by Kc-12 ("Upload better_transformer.py", commit 31cb92e).
import os
import json
import random
import time
import streamlit as st
import re
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer
# Checkpoint file name. It is a relative path, so it is resolved against the
# current working directory (place it next to app.py and run the app from there).
MODEL_FILE = r'bt_8_LAYERs_100_DATA_PCT_768_EMBD_DIM_epoch_10.pt'
# Create tensors on the GPU by default when one is available; otherwise fall back
# to the CPU so that building the model does not fail on machines without CUDA.
torch.set_default_device(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
# Better Transformer Class –––––––––––––––––––––––––––––––––––––––––––––––
class MLP(nn.Module):
    """Position-wise feed-forward sub-layer of a transformer block.

    Each token vector is expanded from ``n_embd`` to ``4 * n_embd``, passed
    through GELU and dropout, then projected back to ``n_embd``.

    Args:
        n_embd: width of the residual stream.
        dropout: dropout probability applied after the activation.
    """

    def __init__(self, n_embd, dropout=0.1):
        super().__init__()
        hidden = 4 * n_embd
        # The layer order fixes the parameter names (net.0.*, net.3.*) that saved
        # checkpoints refer to: 0=Linear, 1=GELU, 2=Dropout, 3=Linear.
        layers = [
            nn.Linear(n_embd, hidden),
            nn.GELU(),  # replaced ReLU
            nn.Dropout(p=dropout),
            nn.Linear(hidden, n_embd),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the feed-forward network over the last dimension of ``x``."""
        out = self.net(x)
        return out
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        n_embd: model width; must be divisible by ``n_head``.
        n_head: number of attention heads.
        seq_length: maximum sequence length (stored on the module; not used in
            the computation).
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, n_embd, n_head, seq_length, dropout=0.1):
        super().__init__()
        self.n_embd = n_embd
        self.n_head = n_head
        self.head_dim = n_embd // n_head  # width of each head's query/key/value
        assert self.head_dim * n_head == self.n_embd, "n_embd must be divisible by n_head"
        self.seq_length = seq_length
        self.drop = nn.Dropout(p=dropout)
        # Bias-free projections, created and registered in the order
        # query, key, value, out (the names are checkpoint keys).
        # ``out`` is the matrix that recombines the concatenated heads.
        self.query, self.key, self.value, self.out = (
            nn.Linear(n_embd, n_embd, bias=False) for _ in range(4)
        )

    def split_heads(self, x):
        """Reshape ``(B, S, n_embd)`` into ``(B, n_head, S, head_dim)``."""
        batch, seq, _ = x.size()
        per_head = x.view(batch, seq, self.n_head, self.head_dim)
        return per_head.transpose(1, 2)

    def combine_heads(self, x):
        """Inverse of ``split_heads``: ``(B, n_head, S, head_dim)`` -> ``(B, S, n_embd)``."""
        batch, _, seq, _ = x.size()
        # transpose() gives a non-contiguous tensor and view() needs contiguous memory
        merged = x.transpose(1, 2).contiguous()
        return merged.view(batch, seq, self.n_embd)

    def scaled_dot_product(self, q, k, v, dropout, mask=None):
        """Attention over ``(B, n_head, S, head_dim)`` inputs.

        ``mask`` (optional) must broadcast to ``(B, n_head, S, S)``; entries that
        are True are blocked (set to -inf before the softmax). A query row whose
        every entry is blocked yields NaN weights.
        """
        scale = np.sqrt(self.head_dim)
        scores = q @ k.transpose(-2, -1) / scale  # (B, n_head, S, S)
        if mask is not None:
            scores = scores.masked_fill(mask, float('-inf'))
        weights = dropout(F.softmax(scores, dim=-1))
        return weights @ v

    def forward(self, x, mask=None):
        """Self-attend over ``x`` of shape ``(B, S, n_embd)``; returns the same shape.

        ``mask`` follows the convention of ``scaled_dot_product`` (True = blocked).
        """
        heads = [self.split_heads(proj(x)) for proj in (self.query, self.key, self.value)]
        attended = self.scaled_dot_product(*heads, self.drop, mask)
        return self.out(self.combine_heads(attended))
class Block(nn.Module):
    """Pre-LayerNorm transformer decoder block.

    Self-attention and then the MLP each receive a LayerNorm-ed copy of the
    residual stream; their dropout-regularised outputs are added back onto it.
    """

    def __init__(self, n_embd, n_head, seq_length, dropout=0.1):
        super().__init__()
        self.sa = MultiHeadAttention(n_embd, n_head, seq_length, dropout)
        self.mlp = MLP(n_embd, dropout)
        # LayerNorm is applied *before* each sub-layer (pre-norm), see forward()
        self.ln1, self.ln2 = nn.LayerNorm(n_embd), nn.LayerNorm(n_embd)
        self.drop = nn.Dropout(p=dropout)

    def forward(self, x, mask):
        """Run one block over ``x`` (B, S, n_embd); ``mask`` is passed to attention unchanged."""
        attn_update = self.sa(self.ln1(x), mask)
        x = x + self.drop(attn_update)
        mlp_update = self.mlp(self.ln2(x))
        return x + self.drop(mlp_update)
class PositionalEncoding(nn.Module):
    """
    Fixed sinusoidal positional encodings from the original Transformer paper:
        PE(pos, 2i)   = sin(pos / 10000^{2i / d_model})
        PE(pos, 2i+1) = cos(pos / 10000^{2i / d_model})
    See reference for more details:
    https://kikaben.com/transformers-positional-encoding/

    Args:
        d_model: embedding width (n_embd). Odd widths are supported; the last
            column is then a sine with no cosine partner.
        max_len: number of positions to precompute (the model's seq_length).
    """

    def __init__(self, d_model, max_len):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)  # [max_len, 1]
        # One frequency per sin/cos pair: ceil(d_model / 2) values
        divisor = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
        angles = position * divisor  # [max_len, ceil(d_model / 2)]
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # 1::2 has only d_model // 2 columns, one fewer than 0::2 when d_model is odd
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Persistent buffer: saved to / loaded from the state_dict under the key 'pe'
        self.register_buffer('pe', pe)  # [max_len, d_model]

    def forward(self, x):
        """Return the encodings for the first ``x.size(0)`` positions.

        Only the length of ``x`` is used (callers pass ``torch.arange(S)``). The
        result has shape ``(x.size(0), d_model)`` and is *not* added to ``x``.
        """
        return self.pe[:x.size(0)]
class BetterTransformer(nn.Module):
    """Decoder-only (GPT-style) transformer language model.

    Token embeddings plus fixed sinusoidal position encodings feed a stack of
    pre-LayerNorm ``Block``s, followed by a linear head over the vocabulary.

    Args:
        vocab_size: number of token ids.
        seq_length: maximum context length; also the size of the positional table.
        n_embd: model width.
        n_head: attention heads per block.
        n_layer: number of blocks.
        pad_idx: padding token id; masked as an attention key and ignored by the loss.
        eos_token_id: token id that stops ``generate`` early.
        device: device the attention masks are placed on.
        dropout: dropout probability used throughout.
    """

    # Decoding strategies accepted by ``generate``.
    _GENERATION_METHODS = ('multinomial', 'temperature', 'greedy', 'nucleus', 'top-k')

    def __init__(self, vocab_size, seq_length, n_embd, n_head, n_layer, pad_idx, eos_token_id, device, dropout=0.1):
        super().__init__()
        # padding_idx keeps the pad row out of gradient updates
        self.token_embedding = nn.Embedding(vocab_size, n_embd, padding_idx=pad_idx)
        self.position_embedding = PositionalEncoding(n_embd, seq_length)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head, seq_length, dropout)
                                      for _ in range(n_layer)])
        self.lm_head = nn.Linear(n_embd, vocab_size)
        self.drop = nn.Dropout(dropout)
        self.seq_length = seq_length
        self.pad_idx = pad_idx
        self.eos_token_id = eos_token_id
        self.device = device
        self.init_params()

    def init_params(self, default_initialization=False):
        """Re-initialise every weight matrix with Xavier-uniform unless ``default_initialization``.

        Only parameters with more than one dimension are touched, so biases and
        LayerNorm parameters keep PyTorch's defaults. The token-embedding row at
        ``padding_idx`` is then reset to zero. Xavier init would otherwise overwrite
        the zero vector ``nn.Embedding`` gives it.
        """
        if default_initialization:
            return
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        pad = self.token_embedding.padding_idx
        if pad is not None:
            with torch.no_grad():
                self.token_embedding.weight[pad].zero_()

    def get_causal_mask(self, x):
        """Return a (1, S, S) bool mask, True where query i may attend to key j (j <= i).

        ``S`` is ``x.size(-1)`` (x is (batch_size, seq_len) token ids).
        """
        seq_len = x.size(-1)
        # Built directly on the target device (no host-to-device copy)
        allowed = torch.ones(seq_len, seq_len, device=self.device).tril().bool()
        return allowed.unsqueeze(0)

    def get_pad_mask(self, x, pad_idx):
        """Return a (B, 1, 1, S) bool mask, True where ``x`` is not ``pad_idx``.

        The singleton dimensions broadcast over heads and query positions.
        """
        return (x != pad_idx).unsqueeze(1).unsqueeze(-2).to(self.device)

    def forward(self, x, targets=None):
        """Compute next-token logits and, if ``targets`` is given, the loss.

        Args:
            x: (batch_size, seq_len) token ids. seq_len must not exceed
                ``seq_length``, the size of the positional table.
            targets: optional (batch_size, seq_len) next-token ids. Positions equal
                to ``pad_idx`` are ignored by the loss.

        Returns:
            ``(logits, loss)``:
                - Without targets, logits is (batch_size, seq_len, vocab_size) and
                  loss is None.
                - With targets, logits is flattened to
                  (batch_size * seq_len, vocab_size) and loss is a scalar tensor.
        """
        # should already be int64 tokens, but cast explicitly in case
        x = x.to(torch.int64)
        _, S = x.shape
        # True where attention is allowed: (B, 1, 1, S) & (1, S, S) -> (B, 1, S, S)
        allowed = self.get_pad_mask(x, self.pad_idx) & self.get_causal_mask(x)
        # MultiHeadAttention fills True entries with -inf, so pass the complement
        blocked = ~allowed
        tok_emb = self.token_embedding(x)  # (B, S, n_embd)
        pos_emb = self.position_embedding(torch.arange(S, device=x.device))  # (S, n_embd)
        h = self.drop(tok_emb + pos_emb)
        for block in self.blocks:
            h = block(h, blocked)  # (B, S, n_embd)
        logits = self.lm_head(h)  # (B, S, vocab_size)
        if targets is None:
            return logits, None
        # Teacher forcing: each of the B*S positions predicts its own target.
        # ignore_index skips PAD targets so no compute is spent learning PAD -> PAD.
        B, S, C = logits.shape
        logits = logits.reshape(B * S, C)
        loss = F.cross_entropy(logits, targets.reshape(B * S), ignore_index=self.pad_idx)
        return logits, loss

    @torch.no_grad()
    def generate(self, input_ids, method='multinomial',
                 max_new_tokens=1000, temp=None,
                 num_beams=None, p_nucleus=None, k=None):
        """Autoregressively extend ``input_ids`` with new tokens.

        Args:
            input_ids: (batch_size, seq_len) prompt token ids.
            method: 'multinomial', 'temperature', 'greedy', 'nucleus' or 'top-k'.
            max_new_tokens: maximum number of tokens to append.
            temp: temperature in (0, 1]; required for 'temperature'.
            num_beams: unused (beam search is not implemented).
            p_nucleus: cumulative-probability cutoff in (0, 1]; required for 'nucleus'.
            k: positive int; required for 'top-k'.

        Returns:
            (batch_size, seq_len + n_new) tensor: the prompt followed by the new
            tokens. Generation stops early once every sequence has produced
            ``eos_token_id``. Sequences that finish earlier keep receiving tokens
            until then.

        Raises:
            ValueError: if ``method`` is not a supported name.
            AssertionError: if the parameter required by ``method`` is missing or
                out of range.
        """
        # See Huggingface's .generate() for a fuller implementation:
        # https://huggingface.co/transformers/v3.4.0/_modules/transformers/generation_utils.html
        if method not in self._GENERATION_METHODS:
            raise ValueError(f"unknown generation method {method!r}; "
                             f"expected one of {self._GENERATION_METHODS}")
        if method == 'temperature':
            assert (temp is not None) and (0 < temp) and (temp <= 1)
        if method == 'top-k':
            assert isinstance(k, int) and (k > 0)
        if method == 'nucleus':
            assert p_nucleus is not None and (0 < p_nucleus) and (p_nucleus <= 1)

        finished = torch.zeros(input_ids.size(0), dtype=torch.bool, device=input_ids.device)
        for _ in range(max_new_tokens):
            # i) condition on at most the last `seq_length` tokens
            text_cond = input_ids[:, -self.seq_length:]
            # ii) forward pass; no targets, so no loss
            logits, _ = self(text_cond)  # (B, S, vocab_size)
            # iii) keep only the prediction for the next position
            logits = logits[:, -1, :]  # (B, vocab_size)
            if method == 'temperature':
                logits = logits / temp
            # iv) probabilities over the vocabulary
            probs = F.softmax(logits, dim=-1)
            # v) choose the next token
            if method == 'greedy':
                next_idx = probs.argmax(dim=-1, keepdim=True)  # (B, 1)
            else:
                if method == 'nucleus':
                    # Keep the most likely tokens up to and including the first one at
                    # which the cumulative probability exceeds p_nucleus; zero the rest.
                    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
                    remove_sorted = sorted_probs.cumsum(dim=-1) > p_nucleus
                    # Shift right by one so the crossing token (and the top token) is kept
                    remove_sorted[..., 1:] = remove_sorted[..., :-1].clone()
                    remove_sorted[..., 0] = False
                    # argsort of the sort permutation is its inverse; gather maps the
                    # mask from sorted order back to vocabulary order
                    remove_mask = remove_sorted.gather(dim=-1, index=sorted_idx.argsort(dim=-1))
                    probs[remove_mask] = 0
                elif method == 'top-k':
                    # The k-th largest probability of each row is the cutoff
                    kth_largest = torch.topk(probs, k).values[..., -1, None]  # (B, 1)
                    probs[probs < kth_largest] = 0
                # multinomial accepts unnormalised non-negative weights
                next_idx = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # vi) append and check for <EOS> (per sequence, so batch_size > 1 works)
            input_ids = torch.cat((input_ids, next_idx), dim=-1)
            finished |= next_idx.squeeze(-1) == self.eos_token_id
            if finished.all():
                break
        return input_ids
# END OF Better Transformer Class –––––––––––––––––––––––––––––––––––––––––––––––
def set_seed(seed = 42):
    """Seed every random number generator used here so generation is reproducible.

    Seeds Python's ``random``, NumPy and PyTorch, exports ``PYTHONHASHSEED`` and
    puts cuDNN into deterministic, non-benchmarking mode.

    Args:
        seed: integer seed (default 42).
    """
    # Python-level and NumPy generators
    for seeder in (random.seed, np.random.seed):
        seeder(seed)
    # NOTE: PYTHONHASHSEED is read at interpreter start-up, so this only affects
    # child processes, not string hashing in the current one.
    os.environ['PYTHONHASHSEED'] = str(seed)
    # torch.manual_seed seeds the generators of all devices, including CUDA
    torch.manual_seed(seed)
    # deterministic=True only governs cuDNN (convolution) algorithms. benchmark mode
    # could still pick a different algorithm from run to run, so it is turned off too.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def load_tokenizer(device):
    """Load the EleutherAI/gpt-neo-1.3B tokenizer and a one-token BOS prompt.

    If the tokenizer has no padding token, '[PAD]' is registered as one.

    Args:
        device: device on which to place the prompt tensor.

    Returns:
        ``(tokenizer, prompt)``, where ``prompt`` is a (1, 1) long tensor holding
        ``tokenizer.bos_token_id``.
    """
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
    needs_pad = tokenizer.pad_token is None
    if needs_pad:
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    bos_prompt = torch.full((1, 1), tokenizer.bos_token_id, dtype=torch.long)
    return tokenizer, bos_prompt.to(device)
def load_big_model(tokenizer, device):
    """Build the 8-layer BetterTransformer and load its trained weights from MODEL_FILE.

    Args:
        tokenizer: supplies ``pad_token_id`` and ``eos_token_id`` to the model.
        device: torch device (or device string) for the model and its masks.

    Returns:
        The model on ``device``, in evaluation mode (dropout disabled).
    """
    ## Model architecture; must match the checkpoint in MODEL_FILE
    N_HEAD = 16
    N_LAYER = 8
    N_EMBD = 768
    # presumably the 50257-token GPT-Neo vocabulary plus the added [PAD] token — TODO confirm
    VOCAB_SIZE = 50258
    SEQ_LENGTH = 384
    model = BetterTransformer(VOCAB_SIZE, SEQ_LENGTH, N_EMBD, N_HEAD, N_LAYER,
                              tokenizer.pad_token_id, tokenizer.eos_token_id, device=device)
    # No extra init_params() call: the constructor already runs it, and the strict
    # load_state_dict below overwrites every parameter and persistent buffer anyway.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(MODEL_FILE, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    # nn.Module starts in training mode; without eval() dropout stays active while
    # generating. Also move the weights to `device`, the device the masks use.
    model.to(device)
    model.eval()
    return model
# Mis-decoded (mojibake) UTF-8 punctuation seen in generated text, and its
# intended replacement. Order matters: the bare two-character prefix must be
# replaced after the longer sequences that start with it.
# NOTE(review): the 'Ò€…' spellings are kept from the previous version of this
# file; the 'â€…' spellings are the usual cp1252 mojibake of curly quotes.
# Confirm which form the model actually emits.
_MOJIBAKE_REPLACEMENTS = (
    ('â€œ', '"'),   # opening double quote
    ('Ò€œ', '"'),
    ('â€™', "'"),   # apostrophe / right single quote
    ('Ò€ℒ', "'"),
    ('â€', '"'),    # closing double quote (its third byte is commonly lost)
    ('Ò€', '"'),
)


def _clean_mojibake(text):
    """Replace the mis-decoded quote sequences in ``_MOJIBAKE_REPLACEMENTS``."""
    for bad, good in _MOJIBAKE_REPLACEMENTS:
        text = text.replace(bad, good)
    return text


def generate(model, tokenizer, device, method=None, k=None,
             p_nucleus=None, temp=None, max_new_tokens=None, cond="", deterministic=None):
    """
    Wrapper for generating text using the specified model. Generates unconditionally
    if cond is empty or None.
    Inputs:
        -model: Decoder model to be used for text generation
        -tokenizer: Compatible tokenizer
        -device: Device of model (CPU/CUDA)
        -method (str): Decoding method for text generation ('multinomial', 'temperature', 'greedy', 'nucleus', or 'top-k')
        -k (int): Positive integer for top-k logits to sample if top-k decoding
        -p_nucleus (float/int): Cumulative probability cutoff if nucleus/top-p decoding
        -temp (float/int): Temperature if temperature decoding
        -max_new_tokens (int): Maximum number of tokens to generate (250 if None)
        -cond (str=None): If non-empty, serves as the conditional prompt; the output
            then shows ' || ' between the prompt and the generated continuation
        -deterministic (int): If given, used as the seed (via set_seed) before generating
    Returns:
        -res (str): Generated text string, terminated with " <|endoftext|>"
    """
    assert method in ['multinomial', 'temperature', 'greedy', 'nucleus', 'top-k'], \
        "method must be 'multinomial', 'temperature', 'greedy', 'nucleus', or 'top-k'"
    # model.generate loops over range(max_new_tokens), which cannot take None
    if max_new_tokens is None:
        max_new_tokens = 250
    if deterministic is not None:
        set_seed(deterministic)
    sampling = dict(method=method, k=k, p_nucleus=p_nucleus, temp=temp,
                    max_new_tokens=max_new_tokens)

    if cond:
        prompt_ids = tokenizer(cond).input_ids
        prompt = torch.tensor(prompt_ids, dtype=torch.long, device=device).unsqueeze(0)
        gen_ids = model.generate(prompt, **sampling)[0].tolist()
        # Insert the ' || ' delimiter where the prompt ends. Splicing lists works for
        # any delimiter token count and when no new tokens were generated.
        n_prompt = len(prompt_ids)
        delimiter = tokenizer.encode(' || ')
        res = tokenizer.decode(gen_ids[:n_prompt] + delimiter + gen_ids[n_prompt:])
        ## Remove end token (the prompt carries no leading BOS here)
        res = re.sub(re.escape(tokenizer.bos_token), '', res, count=1)
    else:
        start = torch.full((1, 1), tokenizer.bos_token_id, dtype=torch.long, device=device)
        res = tokenizer.batch_decode(model.generate(start, **sampling))[0]
        ## Remove start and end tokens
        res = re.sub(re.escape(tokenizer.bos_token), '', res, count=2)

    res = _clean_mojibake(res)
    return res + " <|endoftext|>"  ## better end token