# mingru-stories / mingru_lm.py
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Linear, Identity, Module

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

def heinsen_associative_scan_log(log_coeffs, log_values):
    # Heinsen's associative scan in log space: computes the linear recurrence
    # h_t = a_t * h_{t-1} + b_t (with h_0 = 0) given log(a_t) and log(b_t),
    # in parallel over the sequence dimension
    a_star = log_coeffs.cumsum(dim=1)
    log_h0_plus_b_star = (log_values - a_star).logcumsumexp(dim=1)
    log_h = a_star + log_h0_plus_b_star
    return log_h.exp()

def log_g(x):
    # Log of the minGRU activation g(x) = x + 0.5 for x >= 0, sigmoid(x) for x < 0,
    # evaluated stably in log space
    return torch.where(x >= 0, (F.relu(x) + 0.5).log(), -F.softplus(-x))
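
# Optional sanity check (illustrative, not part of the model): the log-space parallel
# scan above should agree with the naive sequential recurrence h_t = a_t * h_{t-1} + b_t.
# The shapes and tolerance below are arbitrary choices for the demonstration.
def _check_scan_matches_recurrence():
    torch.manual_seed(0)
    batch, seq_len, dim = 2, 8, 4
    # Random log-coefficients and log-values; -softplus keeps a_t, b_t in (0, 1)
    log_a = -F.softplus(torch.randn(batch, seq_len, dim))
    log_b = -F.softplus(torch.randn(batch, seq_len, dim))
    parallel = heinsen_associative_scan_log(log_a, log_b)
    # Naive sequential reference with h_0 = 0
    a, b = log_a.exp(), log_b.exp()
    h = torch.zeros(batch, dim)
    sequential = []
    for t in range(seq_len):
        h = a[:, t] * h + b[:, t]
        sequential.append(h)
    sequential = torch.stack(sequential, dim=1)
    assert torch.allclose(parallel, sequential, atol=1e-5)
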

class MinGRU(Module):
    def __init__(self, dim, expansion_factor=1.):
        super().__init__()
        dim_inner = int(dim * expansion_factor)
        # Separate projections for the candidate hidden state and the update gate
        self.to_hidden = Linear(dim, dim_inner, bias=False)
        self.to_gate = Linear(dim, dim_inner, bias=False)
        # Output projection (Identity if no expansion)
        self.to_out = Linear(dim_inner, dim, bias=False) if expansion_factor != 1. else Identity()
    def forward(self, x, prev_hidden=None, return_next_prev_hidden=False):
        # Project the input to the candidate hidden state and the update gate
        hidden = self.to_hidden(x)
        gate = self.to_gate(x)
# Convert to log space for numerical stability
log_coeffs = -F.softplus(gate) # log(1 - σ(gate))
log_z = -F.softplus(-gate) # log(σ(gate))
log_tilde_h = log_g(hidden) # log(g(hidden))
log_values = log_z + log_tilde_h # log(z * h_tilde)
        # Resume from a cached hidden state if one was provided: prev_hidden is already a
        # (positive) hidden state, so take its log directly and prepend it; pad log_coeffs
        # with a leading zero so the scan carries it through with coefficient exp(0) = 1
        if exists(prev_hidden):
            log_values = torch.cat((prev_hidden.log(), log_values), dim=1)
            log_coeffs = F.pad(log_coeffs, (0, 0, 1, 0))
# Apply parallel scan in log space
out = heinsen_associative_scan_log(log_coeffs, log_values)
out = out[:, -x.shape[1]:] # Keep only the relevant sequence length
# Store last hidden state for potential return
next_prev_hidden = out[:, -1:]
# Apply output projection
out = self.to_out(out)
if not return_next_prev_hidden:
return out
return out, next_prev_hidden
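
# Illustrative usage sketch (shapes below are arbitrary): running a sequence in one pass
# should match running it in two chunks when the cached hidden state returned by
# return_next_prev_hidden=True is threaded into the next call.
def _demo_mingru_chunked():
    torch.manual_seed(0)
    batch, seq_len, dim = 2, 10, 16
    gru = MinGRU(dim)
    x = torch.randn(batch, seq_len, dim)
    full = gru(x)                                                 # one pass over the whole sequence
    first, hidden = gru(x[:, :6], return_next_prev_hidden=True)   # first chunk, keep the cache
    second = gru(x[:, 6:], prev_hidden=hidden)                    # second chunk resumes from the cache
    chunked = torch.cat((first, second), dim=1)
    assert torch.allclose(full, chunked, atol=1e-5)
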
class FeedForward(nn.Module):
def __init__(self, dim, mult=4):
super().__init__()
self.dim_inner = int(dim * mult)
self.net = nn.Sequential(
nn.Linear(dim, self.dim_inner),
nn.GELU(),
nn.Linear(self.dim_inner, dim)
)
def forward(self, x):
return self.net(x)
class CausalDepthWiseConv1d(nn.Module):
def __init__(self, dim, kernel_size):
super().__init__()
self.kernel_size = kernel_size
self.net = nn.Sequential(
nn.Conv1d(dim, dim, kernel_size = kernel_size, groups = dim),
nn.Conv1d(dim, dim, kernel_size = 1)
)
    def forward(self, x):
        x = x.transpose(1, 2)  # b n d -> b d n
        # Left-pad by kernel_size - 1 so each output position sees only current and past inputs
        x = F.pad(x, (self.kernel_size - 1, 0), value=0.)
        x = self.net(x)
        return x.transpose(1, 2)  # b d n -> b n d
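
# Illustrative causality check: because the input is left-padded by kernel_size - 1,
# the output at position t should not depend on inputs at positions after t.
# Sizes below are arbitrary.
def _check_causal_conv():
    torch.manual_seed(0)
    conv = CausalDepthWiseConv1d(dim=8, kernel_size=3)
    x = torch.randn(1, 12, 8)
    y = conv(x)
    x_perturbed = x.clone()
    x_perturbed[:, -1] += 1.0        # change only the final timestep
    y_perturbed = conv(x_perturbed)
    # All outputs before the final position are unaffected
    assert torch.allclose(y[:, :-1], y_perturbed[:, :-1], atol=1e-6)
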
class RMSNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.scale = dim ** 0.5
self.gamma = nn.Parameter(torch.zeros(dim))
    def forward(self, x):
        # L2-normalize, rescale by sqrt(dim); gamma starts at zero so the
        # effective scale (gamma + 1) is 1 at initialization
        return F.normalize(x, dim=-1) * self.scale * (self.gamma + 1)
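
# Illustrative note: F.normalize(x, dim=-1) * sqrt(dim) equals dividing x by its
# root-mean-square over the last dimension (up to the eps clamp inside F.normalize),
# so the class above is a standard RMSNorm with a (gamma + 1) scale.
def _check_rmsnorm_equivalence():
    torch.manual_seed(0)
    x = torch.randn(2, 5, 16)
    norm = RMSNorm(16)
    rms = x.pow(2).mean(dim=-1, keepdim=True).sqrt()
    expected = x / rms * (norm.gamma + 1)   # gamma is zero-initialized, so this is x / rms
    assert torch.allclose(norm(x), expected, atol=1e-6)
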
class MinGRU_Layers(nn.Module):
def __init__(self, dim, num_tokens):
super().__init__()
self.emb = nn.Embedding(num_tokens, dim)
        self.causal_conv = CausalDepthWiseConv1d(dim=dim, kernel_size=3)  # constructed but not applied in forward()
self.rms_norm = RMSNorm(dim)
self.gru = MinGRU(dim)
self.ff = FeedForward(dim)
self.norm = RMSNorm(dim)
self.to_logits = nn.Linear(dim, num_tokens, bias=False)
    def forward(self, inputs, labels=None, is_first_layer=True, prev_hiddens=None):
        if is_first_layer:
            # The first layer receives token ids directly
            x = self.emb(inputs)
        else:
            # Later layers receive the previous layer's logits; recover token ids via argmax
            x = self.emb(inputs.argmax(dim=-1))
        if exists(prev_hiddens):
            # With cached hidden states, only the newest token needs to be processed
            x = x[:, -1:]
next_prev_hiddens = []
prev_hiddens = iter(default(prev_hiddens, []))
x = self.rms_norm(x)
prev_hidden = next(prev_hiddens, None)
min_gru_out, next_hidden = self.gru(x, prev_hidden, return_next_prev_hidden=True)
x = min_gru_out + x
next_prev_hiddens.append(next_hidden)
x = self.ff(x) + x
logits = self.to_logits(self.norm(x))
if labels is not None:
loss = F.cross_entropy(logits.transpose(1, 2), labels)
else:
loss = None
return loss, logits, next_prev_hiddens
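
# Illustrative single-layer usage (shapes below are arbitrary): a layer returns the loss,
# the per-token logits, and a list with one cached hidden state per MinGRU it contains.
def _demo_single_layer():
    torch.manual_seed(0)
    layer = MinGRU_Layers(dim=32, num_tokens=100)
    tokens = torch.randint(0, 100, (2, 16))
    loss, logits, hiddens = layer(tokens, labels=tokens)
    assert logits.shape == (2, 16, 100)
    assert hiddens[0].shape == (2, 1, 32)
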
class MinGRU_LM(nn.Module):
def __init__(self, dim, num_tokens, num_layers):
super().__init__()
self.layers = nn.ModuleList([MinGRU_Layers(dim, num_tokens) for _ in range(num_layers)])
def forward(self, inputs, labels):
total_loss = 0
hidden_states = [None] * len(self.layers)
current_input = inputs
for i, layer in enumerate(self.layers):
loss, logits, next_hiddens = layer(
inputs=current_input,
labels=labels,
is_first_layer=(i == 0),
prev_hiddens=hidden_states[i]
)
if loss is not None:
total_loss += loss
current_input = logits # Use the logits as input for the next layer
hidden_states[i] = next_hiddens
return total_loss / len(self.layers), logits
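
if __name__ == "__main__":
    # Minimal usage sketch: one forward/backward pass on random token ids.
    # The hyperparameters and batch shape below are illustrative, not the values
    # used to train mingru-stories.
    torch.manual_seed(0)
    model = MinGRU_LM(dim=128, num_tokens=256, num_layers=2)
    optimizer = optim.Adam(model.parameters(), lr=3e-4)

    tokens = torch.randint(0, 256, (4, 32))            # (batch, seq_len) token ids
    inputs, labels = tokens[:, :-1], tokens[:, 1:]     # next-token prediction targets
    loss, logits = model(inputs, labels)
    loss.backward()
    optimizer.step()
    print(f"loss: {loss.item():.4f}, logits shape: {tuple(logits.shape)}")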