import os
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler

from datasets import load_dataset
from transformers import GPT2Tokenizer

from einops import einsum
from einops.layers.torch import Rearrange
from tqdm import tqdm
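
# A small language model whose feed-forward blocks are replaced by PEER layers:
# product-key retrieval over a large pool of single-neuron experts. Trained on
# WikiText-103 with DistributedDataParallel.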


def exists(v):
    return v is not None


def default(v, d):
    return v if exists(v) else d
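

# RMSNorm, used by PEER for optional pre-normalisation (`pre_rmsnorm=True`).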
class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return F.normalize(x, dim=-1) * self.scale * self.gamma
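

# A minimal standalone product-key scorer: it splits the query in half and scores
# each half against a shared set of sub-keys. Note that the model below builds its
# product-key lookup directly inside PEER and does not use this class.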
class ProductKeyMemory(nn.Module):
    def __init__(self, dim, num_keys):
        super().__init__()
        self.dim = dim
        self.num_keys = num_keys
        self.keys = nn.Parameter(torch.randn(num_keys, dim // 2))

    def forward(self, query):
        query = query.view(query.shape[0], 2, -1)
        dots = torch.einsum('bkd,nd->bkn', query, self.keys)
        return dots.view(query.shape[0], -1)
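

# PEER layer: each token emits `heads` queries, each split into two halves that are
# scored against two sets of `num_keys` sub-keys. Taking a top-k over each half and
# combining the results as a Cartesian product selects `num_experts_per_head`
# experts out of `num_keys ** 2`, each expert being a single down/up projection
# pair stored as embedding rows.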
class PEER(nn.Module):
    def __init__(
        self,
        dim,
        *,
        heads=8,
        num_experts=1_000_000,
        num_experts_per_head=16,
        activation=nn.GELU,
        dim_key=None,
        product_key_topk=None,
        separate_embed_per_head=False,
        pre_rmsnorm=False,
        dropout=0.
    ):
        super().__init__()

        self.norm = RMSNorm(dim) if pre_rmsnorm else nn.Identity()

        self.heads = heads
        self.separate_embed_per_head = separate_embed_per_head
        self.num_experts = num_experts

        num_expert_sets = heads if separate_embed_per_head else 1

        self.weight_down_embed = nn.Embedding(num_experts * num_expert_sets, dim)
        self.weight_up_embed = nn.Embedding(num_experts * num_expert_sets, dim)

        self.activation = activation()

        assert (num_experts ** 0.5).is_integer(), '`num_experts` needs to be a square'
        assert (dim % 2) == 0, 'feature dimension should be divisible by 2'

        dim_key = default(dim_key, dim // 2)
        self.num_keys = int(num_experts ** 0.5)

        self.to_queries = nn.Sequential(
            nn.Linear(dim, dim_key * heads * 2, bias=False),
            Rearrange('b n (p h d) -> p b n h d', p=2, h=heads)
        )

        self.product_key_topk = default(product_key_topk, num_experts_per_head)
        self.num_experts_per_head = num_experts_per_head

        self.keys = nn.Parameter(torch.randn(heads, self.num_keys, 2, dim_key))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.norm(x)

        # Queries: (2, batch, seq, heads, dim_key) -- one half per product-key axis.
        queries = self.to_queries(x)

        sim = einsum(queries, self.keys, 'p b n h d, h k p d -> p b n h k')

        # Top-k per product-key axis. `topk` returns (values, indices), each of shape
        # (2, b, n, h, topk); unpack each along the leading axis.
        (scores_x, scores_y), (indices_x, indices_y) = sim.topk(self.product_key_topk, dim=-1)

        # Cartesian product of the two top-k sets -> candidate expert scores and ids.
        all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
        all_indices = indices_x.unsqueeze(-1) * self.num_keys + indices_y.unsqueeze(-2)

        all_scores = all_scores.view(*all_scores.shape[:-2], -1)
        all_indices = all_indices.view(*all_indices.shape[:-2], -1)

        # Final top-k over the candidate set, then map back to expert ids.
        scores, pk_indices = all_scores.topk(self.num_experts_per_head, dim=-1)
        indices = all_indices.gather(-1, pk_indices)

        if self.separate_embed_per_head:
            head_expert_offsets = torch.arange(self.heads, device=x.device) * self.num_experts
            indices = indices + head_expert_offsets.view(1, 1, -1, 1)

        # Look up the selected experts' weights by expert id (not by their position
        # in the candidate list).
        weights_down = self.weight_down_embed(indices)
        weights_up = self.weight_up_embed(indices)

        x = einsum(x, weights_down, 'b n d, b n h k d -> b n h k')

        x = self.activation(x)
        x = self.dropout(x)

        x = x * F.softmax(scores, dim=-1)

        x = einsum(x, weights_up, 'b n h k, b n h k d -> b n d')

        return x
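

# Transformer block: causal multi-head self-attention followed by two chained PEER
# layers in place of a dense feed-forward, with residual connections and post-norm
# LayerNorm.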
class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, num_experts, num_experts_per_head, dropout=0.1):
        super().__init__()

        # batch_first=True so the attention module accepts (batch, seq, dim) inputs.
        self.attention = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.peer1 = PEER(dim, heads=num_heads, num_experts=num_experts, num_experts_per_head=num_experts_per_head)
        self.peer2 = PEER(dim, heads=num_heads, num_experts=num_experts, num_experts_per_head=num_experts_per_head)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Causal mask (True = blocked) so each position attends only to itself and the past.
        seq_len = x.size(1)
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool), diagonal=1
        )
        attn_output, _ = self.attention(x, x, x, attn_mask=causal_mask)
        x = x + self.dropout(attn_output)
        x = self.norm1(x)

        peer_output1 = self.peer1(x)
        peer_output2 = self.peer2(F.gelu(peer_output1))
        x = x + self.dropout(peer_output2)
        x = self.norm2(x)

        return x
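

# Decoder-only language model. Sequence length is capped at 512 tokens by the
# learned positional embedding table.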
class PEERLanguageModel(nn.Module):
    def __init__(self, vocab_size, dim, num_layers, num_heads, num_experts, top_k):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, dim)
        self.position_embedding = nn.Embedding(512, dim)
        self.layers = nn.ModuleList([
            TransformerBlock(dim, num_heads, num_experts, top_k)
            for _ in range(num_layers)
        ])
        self.layer_norm = nn.LayerNorm(dim)
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)

    def forward(self, x):
        b, s = x.shape
        positions = torch.arange(s, device=x.device).unsqueeze(0).expand(b, s)

        x = self.token_embedding(x) + self.position_embedding(positions)

        for layer in self.layers:
            x = layer(x)

        x = self.layer_norm(x)
        logits = self.lm_head(x)
        return logits
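

# Despite the name, this wraps WikiText-103 (raw) from the Hugging Face hub:
# empty lines are dropped, the training split is truncated to 300k rows, and each
# line is tokenized and padded/truncated to `max_length`.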
class PileDataset(Dataset):
    def __init__(self, file_path, tokenizer, split='train', max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length

        self.data = load_dataset(file_path, "wikitext-103-raw-v1", split=split)
        self.data = self.data.filter(lambda x: len(x['text']) > 0)
        if split == "train":
            self.data = self.data.select(range(0, 300000))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[idx]['text']
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )
        return encoding['input_ids'].squeeze(), encoding['attention_mask'].squeeze()
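

# One training epoch: next-token prediction with inputs and targets shifted by one
# position; positions holding the pad token are excluded from the loss.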
def train(model, train_loader, optimizer, device, pad_token_id):
    model.train()
    total_loss = 0
    for batch in tqdm(train_loader, disable=dist.get_rank() != 0):
        input_ids, _ = batch  # the attention mask is unused; padding is masked out of the loss below
        input_ids = input_ids.to(device)

        optimizer.zero_grad()

        # Shift by one position: predict token t+1 from tokens up to t.
        targets = input_ids[:, 1:].contiguous()
        input_ids = input_ids[:, :-1].contiguous()

        outputs = model(input_ids)

        outputs = outputs.view(-1, outputs.size(-1))
        targets = targets.view(-1)

        loss = F.cross_entropy(outputs, targets, ignore_index=pad_token_id)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(train_loader)
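

# Validation mirrors the training objective (shifted next-token prediction with
# padding ignored) but without gradient updates.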
def validate(model, val_loader, device, pad_token_id):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch in tqdm(val_loader, disable=dist.get_rank() != 0):
            input_ids, _ = batch
            input_ids = input_ids.to(device)

            # Same shift as in training so the loss measures next-token prediction.
            targets = input_ids[:, 1:].contiguous()
            input_ids = input_ids[:, :-1].contiguous()

            outputs = model(input_ids)
            loss = F.cross_entropy(
                outputs.view(-1, outputs.size(-1)),
                targets.view(-1),
                ignore_index=pad_token_id
            )

            total_loss += loss.item()

    return total_loss / len(val_loader)
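

# Entry point. Assumes a torchrun-style launch that sets LOCAL_RANK, e.g.
# (with this file saved as train_peer.py -- adjust the name to your setup):
#   torchrun --nproc_per_node=<num_gpus> train_peer.py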
if __name__ == "__main__":
    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)

    vocab_size = 50257
    dim = 256
    num_layers = 8
    num_heads = 8
    num_experts = 512 * 512
    top_k = 16
    batch_size = 6
    num_epochs = 10
    learning_rate = 1e-4

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token
    model = PEERLanguageModel(vocab_size, dim, num_layers, num_heads, num_experts, top_k).to(device)

    model = DDP(model, device_ids=[local_rank], output_device=local_rank)

    train_dataset = PileDataset('Salesforce/wikitext', tokenizer, split='train')
    val_dataset = PileDataset('Salesforce/wikitext', tokenizer, split='validation')

    # Training data is sharded across ranks; validation runs in full on every rank.
    train_sampler = DistributedSampler(train_dataset)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    if local_rank == 0:
        print("Number of parameters:", sum(p.numel() for p in model.parameters()))

    best_val_loss = float('inf')
    for epoch in range(num_epochs):
        train_sampler.set_epoch(epoch)

        if local_rank == 0:
            print(f"Epoch {epoch + 1}/{num_epochs}: training")
        train_loss = train(model, train_loader, optimizer, device, tokenizer.pad_token_id)

        if local_rank == 0:
            print(f"Epoch {epoch + 1}/{num_epochs}: validating")
        val_loss = validate(model, val_loader, device, tokenizer.pad_token_id)

        if local_rank == 0:
            print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

            # Checkpoint only on rank 0, and save the unwrapped module so the weights
            # load without the DDP "module." prefix.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(model.module.state_dict(), 'best_peer_language_model.pth')

    if local_rank == 0:
        torch.save(model.module.state_dict(), 'final_peer_language_model.pth')

    dist.destroy_process_group()