import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Linear, Identity, Module


def default(v, d):
    return v if exists(v) else d


def exists(v):
    return v is not None


def heinsen_associative_scan_log(log_coeffs, log_values):
    # Parallel (associative) scan computed entirely in log space
    a_star = log_coeffs.cumsum(dim=1)
    log_h0_plus_b_star = (log_values - a_star).logcumsumexp(dim=1)
    log_h = a_star + log_h0_plus_b_star
    return log_h.exp()


def log_g(x):
    # log of the g(x) activation: g(x) = x + 0.5 for x >= 0, sigmoid(x) otherwise
    return torch.where(x >= 0, (F.relu(x) + 0.5).log(), -F.softplus(-x))


class MinGRU(Module):
    def __init__(self, dim, expansion_factor=1.):
        super().__init__()
        dim_inner = int(dim * expansion_factor)
        # Separate linear projections for the candidate hidden state and the gate
        self.to_hidden = Linear(dim, dim_inner, bias=False)
        self.to_gate = Linear(dim, dim_inner, bias=False)
        # Output projection (Identity if no expansion)
        self.to_out = Linear(dim_inner, dim, bias=False) if expansion_factor != 1. else Identity()

    def forward(self, x, prev_hidden=None, return_next_prev_hidden=False):
        # Project the input into candidate hidden state and gate
        hidden = self.to_hidden(x)
        gate = self.to_gate(x)

        # Convert to log space for numerical stability
        log_coeffs = -F.softplus(gate)    # log(1 - σ(gate))
        log_z = -F.softplus(-gate)        # log(σ(gate))
        log_tilde_h = log_g(hidden)       # log(g(hidden))
        log_values = log_z + log_tilde_h  # log(z * h_tilde)

        # Handle previous hidden state if it exists
        if exists(prev_hidden):
            log_values = torch.cat((log_g(prev_hidden), log_values), dim=1)
            log_coeffs = F.pad(log_coeffs, (0, 0, 1, 0))

        # Apply parallel scan in log space
        out = heinsen_associative_scan_log(log_coeffs, log_values)
        out = out[:, -x.shape[1]:]  # Keep only the relevant sequence length

        # Store last hidden state for potential return
        next_prev_hidden = out[:, -1:]

        # Apply output projection
        out = self.to_out(out)

        if not return_next_prev_hidden:
            return out

        return out, next_prev_hidden


class FeedForward(nn.Module):
    def __init__(self, dim, mult=4):
        super().__init__()
        self.dim_inner = int(dim * mult)
        self.net = nn.Sequential(
            nn.Linear(dim, self.dim_inner),
            nn.GELU(),
            nn.Linear(self.dim_inner, dim)
        )

    def forward(self, x):
        return self.net(x)


class CausalDepthWiseConv1d(nn.Module):
    def __init__(self, dim, kernel_size):
        super().__init__()
        self.kernel_size = kernel_size
        self.net = nn.Sequential(
            nn.Conv1d(dim, dim, kernel_size=kernel_size, groups=dim),
            nn.Conv1d(dim, dim, kernel_size=1)
        )

    def forward(self, x):
        x = x.transpose(1, 2)  # b n d -> b d n
        # Left-pad so the convolution never sees future positions
        x = F.pad(x, (self.kernel_size - 1, 0), value=0.)
        x = self.net(x)
        return x.transpose(1, 2)  # b d n -> b n d


class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        return F.normalize(x, dim=-1) * self.scale * (self.gamma + 1)


class MinGRU_Layers(nn.Module):
    def __init__(self, dim, num_tokens):
        super().__init__()
        self.emb = nn.Embedding(num_tokens, dim)
        # Instantiated but not applied in forward
        self.causal_depth = CausalDepthWiseConv1d(dim=dim, kernel_size=3)
        self.rms_norm = RMSNorm(dim)
        self.gru = MinGRU(dim)
        self.ff = FeedForward(dim)
        self.norm = RMSNorm(dim)
        self.to_logits = nn.Linear(dim, num_tokens, bias=False)

    def forward(self, inputs, labels=None, is_first_layer=True, prev_hiddens=None):
        # First layer embeds token ids; deeper layers embed the argmax
        # of the previous layer's logits
        if is_first_layer:
            x = self.emb(inputs)
        else:
            x = self.emb(inputs.argmax(dim=-1))

        # During incremental decoding, only the newest position is processed
        if exists(prev_hiddens):
            x = x[:, -1:]

        next_prev_hiddens = []
        prev_hiddens = iter(default(prev_hiddens, []))

        # minGRU block with residual connection
        x = self.rms_norm(x)
        prev_hidden = next(prev_hiddens, None)
        min_gru_out, next_hidden = self.gru(x, prev_hidden, return_next_prev_hidden=True)
        x = min_gru_out + x
        next_prev_hiddens.append(next_hidden)

        # Feedforward block with residual connection
        x = self.ff(x) + x

        logits = self.to_logits(self.norm(x))

        if labels is not None:
            loss = F.cross_entropy(logits.transpose(1, 2), labels)
        else:
            loss = None

        return loss, logits, next_prev_hiddens


class MinGRU_LM(nn.Module):
    def __init__(self, dim, num_tokens, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([MinGRU_Layers(dim, num_tokens) for _ in range(num_layers)])

    def forward(self, inputs, labels):
        total_loss = 0
        hidden_states = [None] * len(self.layers)
        current_input = inputs

        for i, layer in enumerate(self.layers):
            loss, logits, next_hiddens = layer(
                inputs=current_input,
                labels=labels,
                is_first_layer=(i == 0),
                prev_hiddens=hidden_states[i]
            )
            if loss is not None:
                total_loss += loss
            current_input = logits  # Use the logits as input for the next layer
            hidden_states[i] = next_hiddens

        # Average the per-layer losses
        return total_loss / len(self.layers), logits
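
# Minimal usage sketch (not part of the listing above): the hyperparameters,
# batch size, and sequence length below are illustrative assumptions, chosen
# only to show the expected tensor shapes and a single training step.
if __name__ == "__main__":
    dim, num_tokens, num_layers = 512, 256, 2
    model = MinGRU_LM(dim=dim, num_tokens=num_tokens, num_layers=num_layers)
    optimizer = optim.Adam(model.parameters(), lr=3e-4)

    # Toy batch of random token ids for next-token prediction
    tokens = torch.randint(0, num_tokens, (4, 128))  # (batch, seq_len + 1)
    inputs, labels = tokens[:, :-1], tokens[:, 1:]

    loss, logits = model(inputs, labels)  # logits: (4, 127, num_tokens)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()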