damerajee committed
Commit
b14db6d
1 Parent(s): c232079

Create mingru_lm.py

Files changed (1)
  1. mingru_lm.py +161 -0
mingru_lm.py ADDED
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Linear, Identity, Module


def default(v, d):
    return v if exists(v) else d

def exists(v):
    return v is not None

def heinsen_associative_scan_log(log_coeffs, log_values):
    # Parallel prefix scan of a linear recurrence, computed entirely in log space
    a_star = log_coeffs.cumsum(dim=1)
    log_h0_plus_b_star = (log_values - a_star).logcumsumexp(dim=1)
    log_h = a_star + log_h0_plus_b_star
    return log_h.exp()

def log_g(x):
    # log of the activation g(x) = x + 0.5 for x >= 0, sigmoid(x) otherwise
    return torch.where(x >= 0, (F.relu(x) + 0.5).log(), -F.softplus(-x))

class MinGRU(Module):
    def __init__(self, dim, expansion_factor=1.):
        super().__init__()
        dim_inner = int(dim * expansion_factor)
        # Separate projections for the candidate hidden state and the gate
        self.to_hidden = Linear(dim, dim_inner, bias=False)
        self.to_gate = Linear(dim, dim_inner, bias=False)
        # Output projection (Identity if no expansion)
        self.to_out = Linear(dim_inner, dim, bias=False) if expansion_factor != 1. else Identity()

    def forward(self, x, prev_hidden=None, return_next_prev_hidden=False):
        # Project the input to the candidate hidden state and the gate
        hidden = self.to_hidden(x)
        gate = self.to_gate(x)
        # Convert to log space for numerical stability
        log_coeffs = -F.softplus(gate)    # log(1 - σ(gate))
        log_z = -F.softplus(-gate)        # log(σ(gate))
        log_tilde_h = log_g(hidden)       # log(g(hidden))
        log_values = log_z + log_tilde_h  # log(z * h_tilde)

        # Handle previous hidden state if it exists
        if exists(prev_hidden):
            log_values = torch.cat((log_g(prev_hidden), log_values), dim=1)
            log_coeffs = F.pad(log_coeffs, (0, 0, 1, 0))

        # Apply parallel scan in log space
        out = heinsen_associative_scan_log(log_coeffs, log_values)
        out = out[:, -x.shape[1]:]  # Keep only the relevant sequence length

        # Store last hidden state for potential return
        next_prev_hidden = out[:, -1:]

        # Apply output projection
        out = self.to_out(out)

        if not return_next_prev_hidden:
            return out
        return out, next_prev_hidden

if __name__ == "__main__":
    x = torch.rand(2, 256, 512)
    model = MinGRU(dim=512)
    out, next_prev_hidden = model(x, return_next_prev_hidden=True)

    print("out", out[0, 0, :3])
    print("next_prev_hidden", next_prev_hidden[0, 0, :3])
    print("out shape", out.shape)
    print("X shape", x.shape)
    assert x.shape == out.shape


class FeedForward(nn.Module):
    # Position-wise feed-forward block with a GELU non-linearity
    def __init__(self, dim, mult=4):
        super().__init__()
        self.dim_inner = int(dim * mult)
        self.net = nn.Sequential(
            nn.Linear(dim, self.dim_inner),
            nn.GELU(),
            nn.Linear(self.dim_inner, dim)
        )

    def forward(self, x):
        return self.net(x)

class RMSNorm(nn.Module):
    # RMS normalization with a learned gain (gamma starts at zero, so the effective gain starts at one)
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        return F.normalize(x, dim=-1) * self.scale * (self.gamma + 1)

class MinGRU_Layers(nn.Module):
    # One block: embedding -> RMSNorm -> MinGRU (residual) -> feed-forward (residual) -> logits
    def __init__(self, dim, num_tokens):
        super().__init__()
        self.emb = nn.Embedding(num_tokens, dim)
        self.rms_norm = RMSNorm(dim)
        self.gru = MinGRU(dim)
        self.ff = FeedForward(dim)

        self.norm = RMSNorm(dim)
        self.to_logits = nn.Linear(dim, num_tokens, bias=False)

    def forward(self, inputs, labels=None, is_first_layer=True, prev_hiddens=None):
        # The first layer embeds token ids; deeper layers embed the argmax of the previous layer's logits
        if is_first_layer:
            x = self.emb(inputs)
        else:
            x = self.emb(inputs.argmax(dim=-1))

        # With cached hidden states, only the newest token needs to be processed
        if exists(prev_hiddens):
            x = x[:, -1:]

        next_prev_hiddens = []
        prev_hiddens = iter(default(prev_hiddens, []))

        x = self.rms_norm(x)
        prev_hidden = next(prev_hiddens, None)

        min_gru_out, next_hidden = self.gru(x, prev_hidden, return_next_prev_hidden=True)

        x = min_gru_out + x
        next_prev_hiddens.append(next_hidden)
        x = self.ff(x) + x
        logits = self.to_logits(self.norm(x))

        if labels is not None:
            loss = F.cross_entropy(logits.transpose(1, 2), labels)
        else:
            loss = None

        return loss, logits, next_prev_hiddens

class MinGRU_LM(nn.Module):
    # Stack of MinGRU_Layers; each layer feeds its logits to the next and the losses are averaged
    def __init__(self, dim, num_tokens, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([MinGRU_Layers(dim, num_tokens) for _ in range(num_layers)])

    def forward(self, inputs, labels):
        total_loss = 0
        hidden_states = [None] * len(self.layers)
        current_input = inputs

        for i, layer in enumerate(self.layers):
            loss, logits, next_hiddens = layer(
                inputs=current_input,
                labels=labels,
                is_first_layer=(i == 0),
                prev_hiddens=hidden_states[i]
            )

            if loss is not None:
                total_loss += loss

            current_input = logits  # Use the logits as input for the next layer
            hidden_states[i] = next_hiddens

        return total_loss / len(self.layers), logits
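
Below is a minimal training-step sketch for the stacked MinGRU_LM, assuming the file above is saved as mingru_lm.py and using illustrative settings (dim=512, num_tokens=256, num_layers=2, random token ids) that are not part of this commit; labels are the input ids shifted by one position for next-token prediction.

import torch
from mingru_lm import MinGRU_LM

# Illustrative hyperparameters; not taken from the commit above
model = MinGRU_LM(dim=512, num_tokens=256, num_layers=2)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

# Random token ids standing in for a real batch of shape (batch, seq_len + 1)
tokens = torch.randint(0, 256, (2, 129))
inputs, labels = tokens[:, :-1], tokens[:, 1:]  # shift by one for next-token prediction

loss, logits = model(inputs, labels)  # loss is averaged over the stacked layers
loss.backward()
optimizer.step()
optimizer.zero_grad()

print(loss.item(), logits.shape)  # logits: (batch, seq_len, num_tokens)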