asigalov61 commited on
Commit
afa4794
·
1 Parent(s): e59e137

Upload x_transformer.py

Browse files
Files changed (1) hide show
  1. x_transformer.py +2001 -0
x_transformer.py ADDED
@@ -0,0 +1,2001 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #===================================================================================================================
2
+
3
+ # X Trasformer Module
4
+ # Partial x-transformers code With useful modifications
5
+ #
6
+ # Version 1.0
7
+ #
8
+ # Original source code courtesy of lucidrains
9
+ # https://github.com/lucidrains/x-transformers
10
+ #
11
+ # Original source code retrieved on 05/10/2023
12
+ #
13
+ # Project Los Angeles
14
+ # Tegridy Code 2023
15
+
16
+ #===================================================================================================================
17
+
18
+ # Critical dependencies
19
+ #
20
+ # !pip install torch
21
+ # !pip install einops
22
+
23
+ #===================================================================================================================
24
+
25
+ from functools import partial
26
+
27
+ import torch
28
+ from torch import nn, einsum, Tensor
29
+ import torch.nn.functional as F
30
+
31
+ from collections import namedtuple
32
+ from functools import wraps
33
+ from packaging import version
34
+ from dataclasses import dataclass
35
+
36
+ from einops import rearrange
37
+
38
+ import math
39
+ from random import random
40
+
41
+ from functools import partial
42
+ from inspect import isfunction
43
+
44
+ from dataclasses import dataclass
45
+ from typing import List
46
+
47
+ from einops import rearrange, repeat, reduce
48
+ from einops.layers.torch import Rearrange
49
+
50
+ from math import ceil
51
+
52
+ from einops import rearrange, pack, unpack
53
+
54
+ #===================================================================================================================
55
+
56
+ # constants
57
+
58
+ EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
59
+
60
+ @dataclass
61
+ class Intermediates:
62
+ qk_similarities: Tensor = None
63
+ pre_softmax_attn: Tensor = None
64
+ post_softmax_attn: Tensor = None
65
+
66
+ # helpers
67
+
68
+ def exists(val):
69
+ return val is not None
70
+
71
+ def default(val, d):
72
+ return val if exists(val) else d
73
+
74
+ def once(fn):
75
+ called = False
76
+ @wraps(fn)
77
+ def inner(x):
78
+ nonlocal called
79
+ if called:
80
+ return
81
+ called = True
82
+ return fn(x)
83
+ return inner
84
+
85
+ print_once = once(print)
86
+
87
+ # main class
88
+
89
+ class Attend(nn.Module):
90
+ def __init__(
91
+ self,
92
+ *,
93
+ dropout = 0.,
94
+ causal = False,
95
+ heads = None,
96
+ talking_heads = False,
97
+ scale = None,
98
+ qk_norm = False,
99
+ flash = False,
100
+ ):
101
+ super().__init__()
102
+ self.scale = scale
103
+ self.qk_norm = qk_norm
104
+ self.causal = causal
105
+ self.attn_fn = partial(F.softmax, dtype = torch.float32) if not qk_norm else F.softmax
106
+
107
+ self.dropout = dropout
108
+ self.attn_dropout = nn.Dropout(dropout)
109
+
110
+ # talking heads
111
+
112
+ assert not (flash and talking_heads), 'talking heads not compatible with flash attention'
113
+
114
+ self.talking_heads = talking_heads
115
+ if talking_heads:
116
+ self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
117
+ self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
118
+
119
+ # flash attention
120
+
121
+ self.flash = flash
122
+ assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
123
+
124
+ # determine efficient attention configs for cuda and cpu
125
+
126
+ self.cpu_config = EfficientAttentionConfig(True, True, True)
127
+ self.cuda_config = None
128
+
129
+ if not torch.cuda.is_available() or not flash:
130
+ return
131
+
132
+ device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
133
+
134
+ if device_properties.major == 8 and device_properties.minor == 0:
135
+ print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
136
+ self.cuda_config = EfficientAttentionConfig(True, False, False)
137
+ else:
138
+ print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
139
+ self.cuda_config = EfficientAttentionConfig(False, True, True)
140
+
141
+ def flash_attn(
142
+ self,
143
+ q, k, v,
144
+ mask = None,
145
+ attn_bias = None
146
+ ):
147
+ batch, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
148
+
149
+ # Recommended for multi-query single-key-value attention by Tri Dao
150
+ # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
151
+
152
+ if k.ndim == 3:
153
+ k = rearrange(k, 'b ... -> b 1 ...').expand_as(q)
154
+
155
+ if v.ndim == 3:
156
+ v = rearrange(v, 'b ... -> b 1 ...').expand_as(q)
157
+
158
+ # handle scale - by default they scale by dim_head ** -0.5, but need to take care if using cosine sim attention
159
+
160
+ if self.qk_norm:
161
+ default_scale = q.shape[-1] ** -0.5
162
+ q = q * (default_scale / self.scale)
163
+
164
+ # Check if mask exists and expand to compatible shape
165
+ # The mask is B L, so it would have to be expanded to B H N L
166
+
167
+ causal = self.causal
168
+
169
+ if exists(mask):
170
+ assert mask.ndim == 4
171
+ mask = mask.expand(batch, heads, q_len, k_len)
172
+
173
+ # manually handle causal mask, if another mask was given
174
+
175
+ if causal:
176
+ causal_mask = torch.ones((q_len, k_len), dtype = torch.bool, device = device).triu(k_len - q_len + 1)
177
+ mask = mask | causal_mask
178
+ causal = False
179
+
180
+ # handle alibi positional bias
181
+ # convert from bool to float
182
+
183
+ if exists(attn_bias):
184
+ attn_bias = rearrange(attn_bias, 'h i j -> 1 h i j').expand(batch, -1, -1, -1)
185
+
186
+ # if mask given, the mask would already contain the causal mask from above logic
187
+ # otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number
188
+
189
+ mask_value = -torch.finfo(q.dtype).max
190
+
191
+ if exists(mask):
192
+ attn_bias = attn_bias.masked_fill(mask, mask_value // 2)
193
+ elif causal:
194
+ causal_mask = torch.ones((q_len, k_len), dtype = torch.bool, device = device).triu(k_len - q_len + 1)
195
+ attn_bias = attn_bias.masked_fill(causal_mask, mask_value // 2)
196
+ causal = False
197
+
198
+ # scaled_dot_product_attention handles attn_mask either as bool or additive bias
199
+ # make it an additive bias here
200
+
201
+ mask = attn_bias
202
+
203
+ # Check if there is a compatible device for flash attention
204
+
205
+ config = self.cuda_config if is_cuda else self.cpu_config
206
+
207
+ # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
208
+
209
+ with torch.backends.cuda.sdp_kernel(**config._asdict()):
210
+ out = F.scaled_dot_product_attention(
211
+ q, k, v,
212
+ attn_mask = mask,
213
+ dropout_p = self.dropout if self.training else 0.,
214
+ is_causal = causal
215
+ )
216
+
217
+ return out, Intermediates()
218
+
219
+ def forward(
220
+ self,
221
+ q, k, v,
222
+ mask = None,
223
+ attn_bias = None,
224
+ prev_attn = None
225
+ ):
226
+ """
227
+ einstein notation
228
+ b - batch
229
+ h - heads
230
+ n, i, j - sequence length (base sequence length, source, target)
231
+ d - feature dimension
232
+ """
233
+
234
+ n, device = q.shape[-2], q.device
235
+
236
+ scale = default(self.scale, q.shape[-1] ** -0.5)
237
+
238
+ if self.flash:
239
+ assert not exists(prev_attn), 'residual attention not compatible with flash attention'
240
+ return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias)
241
+
242
+ kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
243
+
244
+ dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale
245
+
246
+ if exists(prev_attn):
247
+ dots = dots + prev_attn
248
+
249
+ qk_similarities = dots.clone()
250
+
251
+ if self.talking_heads:
252
+ dots = self.pre_softmax_talking_heads(dots)
253
+
254
+ if exists(attn_bias):
255
+ dots = dots + attn_bias
256
+
257
+ dtype = dots.dtype
258
+ pre_softmax_attn = dots.clone()
259
+
260
+ mask_value = -torch.finfo(dots.dtype).max
261
+
262
+ if exists(mask):
263
+ dots = dots.masked_fill(mask, mask_value)
264
+
265
+ if self.causal:
266
+ i, j = dots.shape[-2:]
267
+ causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
268
+ dots = dots.masked_fill(causal_mask, mask_value)
269
+
270
+ attn = self.attn_fn(dots, dim = -1)
271
+ attn = attn.type(dtype)
272
+
273
+ post_softmax_attn = attn.clone()
274
+
275
+ attn = self.attn_dropout(attn)
276
+
277
+ if self.talking_heads:
278
+ attn = self.post_softmax_talking_heads(attn)
279
+
280
+ out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)
281
+
282
+ intermediates = Intermediates(
283
+ qk_similarities = qk_similarities,
284
+ pre_softmax_attn = pre_softmax_attn,
285
+ post_softmax_attn = post_softmax_attn
286
+ )
287
+
288
+ return out, intermediates
289
+
290
+ #===================================================================================================================
291
+
292
+ # constants
293
+
294
+ DEFAULT_DIM_HEAD = 64
295
+
296
+ @dataclass
297
+ class LayerIntermediates:
298
+ hiddens: List[Tensor] = None,
299
+ attn_intermediates: List[Intermediates] = None
300
+
301
+ # helpers
302
+
303
+ def exists(val):
304
+ return val is not None
305
+
306
+ def default(val, d):
307
+ if exists(val):
308
+ return val
309
+ return d() if isfunction(d) else d
310
+
311
+ def cast_tuple(val, depth):
312
+ return val if isinstance(val, tuple) else (val,) * depth
313
+
314
+ def maybe(fn):
315
+ @wraps(fn)
316
+ def inner(x, *args, **kwargs):
317
+ if not exists(x):
318
+ return x
319
+ return fn(x, *args, **kwargs)
320
+ return inner
321
+
322
+ class always():
323
+ def __init__(self, val):
324
+ self.val = val
325
+ def __call__(self, *args, **kwargs):
326
+ return self.val
327
+
328
+ class not_equals():
329
+ def __init__(self, val):
330
+ self.val = val
331
+ def __call__(self, x, *args, **kwargs):
332
+ return x != self.val
333
+
334
+ class equals():
335
+ def __init__(self, val):
336
+ self.val = val
337
+ def __call__(self, x, *args, **kwargs):
338
+ return x == self.val
339
+
340
+ # tensor helpers
341
+
342
+ def max_neg_value(tensor):
343
+ return -torch.finfo(tensor.dtype).max
344
+
345
+ def l2norm(t, groups = 1):
346
+ t = rearrange(t, '... (g d) -> ... g d', g = groups)
347
+ t = F.normalize(t, p = 2, dim = -1)
348
+ return rearrange(t, '... g d -> ... (g d)')
349
+
350
+ def pad_at_dim(t, pad, dim = -1, value = 0.):
351
+ dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
352
+ zeros = ((0, 0) * dims_from_right)
353
+ return F.pad(t, (*zeros, *pad), value = value)
354
+
355
+ def or_reduce(masks):
356
+ head, *body = masks
357
+ for rest in body:
358
+ head = head | rest
359
+ return head
360
+
361
+ # init helpers
362
+
363
+ def init_zero_(layer):
364
+ nn.init.constant_(layer.weight, 0.)
365
+ if exists(layer.bias):
366
+ nn.init.constant_(layer.bias, 0.)
367
+
368
+ # keyword argument helpers
369
+
370
+ def pick_and_pop(keys, d):
371
+ values = list(map(lambda key: d.pop(key), keys))
372
+ return dict(zip(keys, values))
373
+
374
+ def group_dict_by_key(cond, d):
375
+ return_val = [dict(),dict()]
376
+ for key in d.keys():
377
+ match = bool(cond(key))
378
+ ind = int(not match)
379
+ return_val[ind][key] = d[key]
380
+ return (*return_val,)
381
+
382
+ def string_begins_with(prefix, str):
383
+ return str.startswith(prefix)
384
+
385
+ def group_by_key_prefix(prefix, d):
386
+ return group_dict_by_key(partial(string_begins_with, prefix), d)
387
+
388
+ def groupby_prefix_and_trim(prefix, d):
389
+ kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
390
+ kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
391
+ return kwargs_without_prefix, kwargs
392
+
393
+ # initializations
394
+
395
+ def deepnorm_init(
396
+ transformer,
397
+ beta,
398
+ module_name_match_list = ['.ff.', '.to_v', '.to_out']
399
+ ):
400
+ for name, module in transformer.named_modules():
401
+ if type(module) != nn.Linear:
402
+ continue
403
+
404
+ needs_beta_gain = any(map(lambda substr: substr in name, module_name_match_list))
405
+ gain = beta if needs_beta_gain else 1
406
+ nn.init.xavier_normal_(module.weight.data, gain = gain)
407
+
408
+ if exists(module.bias):
409
+ nn.init.constant_(module.bias.data, 0)
410
+
411
+ # structured dropout, more effective than traditional attention dropouts
412
+
413
+ def dropout_seq(seq, mask, dropout):
414
+ b, n, *_, device = *seq.shape, seq.device
415
+ logits = torch.randn(b, n, device = device)
416
+
417
+ if exists(mask):
418
+ mask_value = max_neg_value(logits)
419
+ logits = logits.masked_fill(~mask, mask_value)
420
+
421
+ keep_prob = 1. - dropout
422
+ num_keep = max(1, int(keep_prob * n))
423
+ keep_indices = logits.topk(num_keep, dim = 1).indices
424
+
425
+ batch_indices = torch.arange(b, device = device)
426
+ batch_indices = rearrange(batch_indices, 'b -> b 1')
427
+
428
+ seq = seq[batch_indices, keep_indices]
429
+
430
+ if exists(mask):
431
+ seq_counts = mask.sum(dim = -1)
432
+ seq_keep_counts = torch.ceil(seq_counts * keep_prob).int()
433
+ keep_mask = torch.arange(num_keep, device = device) < rearrange(seq_keep_counts, 'b -> b 1')
434
+
435
+ mask = mask[batch_indices, keep_indices] & keep_mask
436
+
437
+ return seq, mask
438
+
439
+ # activations
440
+
441
+ class ReluSquared(nn.Module):
442
+ def forward(self, x):
443
+ return F.relu(x) ** 2
444
+
445
+ # embedding
446
+
447
+ class TokenEmbedding(nn.Module):
448
+ def __init__(self, dim, num_tokens, l2norm_embed = False):
449
+ super().__init__()
450
+ self.l2norm_embed = l2norm_embed
451
+ self.emb = nn.Embedding(num_tokens, dim)
452
+
453
+ def forward(self, x):
454
+ token_emb = self.emb(x)
455
+ return l2norm(token_emb) if self.l2norm_embed else token_emb
456
+
457
+ # positional embeddings
458
+
459
+ class AbsolutePositionalEmbedding(nn.Module):
460
+ def __init__(self, dim, max_seq_len, l2norm_embed = False):
461
+ super().__init__()
462
+ self.scale = dim ** -0.5 if not l2norm_embed else 1.
463
+ self.max_seq_len = max_seq_len
464
+ self.l2norm_embed = l2norm_embed
465
+ self.emb = nn.Embedding(max_seq_len, dim)
466
+
467
+ def forward(self, x, pos = None):
468
+ seq_len, device = x.shape[1], x.device
469
+ assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'
470
+
471
+ if not exists(pos):
472
+ pos = torch.arange(seq_len, device = device)
473
+
474
+ pos_emb = self.emb(pos)
475
+ pos_emb = pos_emb * self.scale
476
+ return l2norm(pos_emb) if self.l2norm_embed else pos_emb
477
+
478
+ class ScaledSinusoidalEmbedding(nn.Module):
479
+ def __init__(self, dim, theta = 10000):
480
+ super().__init__()
481
+ assert (dim % 2) == 0
482
+ self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)
483
+
484
+ half_dim = dim // 2
485
+ freq_seq = torch.arange(half_dim).float() / half_dim
486
+ inv_freq = theta ** -freq_seq
487
+ self.register_buffer('inv_freq', inv_freq, persistent = False)
488
+
489
+ def forward(self, x, pos = None):
490
+ seq_len, device = x.shape[1], x.device
491
+
492
+ if not exists(pos):
493
+ pos = torch.arange(seq_len, device = device)
494
+
495
+ emb = einsum('i, j -> i j', pos, self.inv_freq)
496
+ emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
497
+ return emb * self.scale
498
+
499
+ class RelativePositionBias(nn.Module):
500
+ def __init__(self, scale, causal = False, num_buckets = 32, max_distance = 128, heads = 8):
501
+ super().__init__()
502
+ self.scale = scale
503
+ self.causal = causal
504
+ self.num_buckets = num_buckets
505
+ self.max_distance = max_distance
506
+ self.relative_attention_bias = nn.Embedding(num_buckets, heads)
507
+
508
+ @staticmethod
509
+ def _relative_position_bucket(relative_position, causal = True, num_buckets = 32, max_distance = 128):
510
+ ret = 0
511
+ n = -relative_position
512
+ if not causal:
513
+ num_buckets //= 2
514
+ ret += (n < 0).long() * num_buckets
515
+ n = torch.abs(n)
516
+ else:
517
+ n = torch.max(n, torch.zeros_like(n))
518
+
519
+ max_exact = num_buckets // 2
520
+ is_small = n < max_exact
521
+
522
+ val_if_large = max_exact + (
523
+ torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
524
+ ).long()
525
+ val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
526
+
527
+ ret += torch.where(is_small, n, val_if_large)
528
+ return ret
529
+
530
+ @property
531
+ def device(self):
532
+ return next(self.parameters()).device
533
+
534
+ def forward(self, i, j):
535
+ device = self.device
536
+ q_pos = torch.arange(j - i, j, dtype = torch.long, device = device)
537
+ k_pos = torch.arange(j, dtype = torch.long, device = device)
538
+ rel_pos = k_pos[None, :] - q_pos[:, None]
539
+ rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
540
+ values = self.relative_attention_bias(rp_bucket)
541
+ bias = rearrange(values, 'i j h -> h i j')
542
+ return bias * self.scale
543
+
544
+ class DynamicPositionBias(nn.Module):
545
+ def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):
546
+ super().__init__()
547
+ assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1'
548
+ self.log_distance = log_distance
549
+
550
+ self.mlp = nn.ModuleList([])
551
+
552
+ self.mlp.append(nn.Sequential(
553
+ nn.Linear(1, dim),
554
+ nn.LayerNorm(dim) if norm else nn.Identity(),
555
+ nn.SiLU()
556
+ ))
557
+
558
+ for _ in range(depth - 1):
559
+ self.mlp.append(nn.Sequential(
560
+ nn.Linear(dim, dim),
561
+ nn.LayerNorm(dim) if norm else nn.Identity(),
562
+ nn.SiLU()
563
+ ))
564
+
565
+ self.mlp.append(nn.Linear(dim, heads))
566
+
567
+ @property
568
+ def device(self):
569
+ return next(self.parameters()).device
570
+
571
+ def forward(self, i, j):
572
+ assert i == j
573
+ n, device = j, self.device
574
+
575
+ # get the (n x n) matrix of distances
576
+ seq_arange = torch.arange(n, device = device)
577
+ context_arange = torch.arange(n, device = device)
578
+ indices = rearrange(seq_arange, 'i -> i 1') - rearrange(context_arange, 'j -> 1 j')
579
+ indices += (n - 1)
580
+
581
+ # input to continuous positions MLP
582
+ pos = torch.arange(-n + 1, n, device = device).float()
583
+ pos = rearrange(pos, '... -> ... 1')
584
+
585
+ if self.log_distance:
586
+ pos = torch.sign(pos) * torch.log(pos.abs() + 1) # log of distance is sign(rel_pos) * log(abs(rel_pos) + 1)
587
+
588
+ for layer in self.mlp:
589
+ pos = layer(pos)
590
+
591
+ # get position biases
592
+ bias = pos[indices]
593
+ bias = rearrange(bias, 'i j h -> h i j')
594
+ return bias
595
+
596
+ class AlibiPositionalBias(nn.Module):
597
+ def __init__(self, heads, total_heads, **kwargs):
598
+ super().__init__()
599
+ self.heads = heads
600
+ self.total_heads = total_heads
601
+
602
+ slopes = Tensor(self._get_slopes(heads))
603
+ slopes = rearrange(slopes, 'h -> h 1 1')
604
+ self.register_buffer('slopes', slopes, persistent = False)
605
+ self.register_buffer('bias', None, persistent = False)
606
+
607
+ def get_bias(self, i, j, device):
608
+ i_arange = torch.arange(j - i, j, device = device)
609
+ j_arange = torch.arange(j, device = device)
610
+ bias = -torch.abs(rearrange(j_arange, 'j -> 1 1 j') - rearrange(i_arange, 'i -> 1 i 1'))
611
+ return bias
612
+
613
+ @staticmethod
614
+ def _get_slopes(heads):
615
+ def get_slopes_power_of_2(n):
616
+ start = (2**(-2**-(math.log2(n)-3)))
617
+ ratio = start
618
+ return [start*ratio**i for i in range(n)]
619
+
620
+ if math.log2(heads).is_integer():
621
+ return get_slopes_power_of_2(heads)
622
+
623
+ closest_power_of_2 = 2 ** math.floor(math.log2(heads))
624
+ return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2]
625
+
626
+ @property
627
+ def device(self):
628
+ return next(self.buffers()).device
629
+
630
+ def forward(self, i, j):
631
+ h, device = self.total_heads, self.device
632
+
633
+ if exists(self.bias) and self.bias.shape[-1] >= j:
634
+ return self.bias[..., :i, :j]
635
+
636
+ bias = self.get_bias(i, j, device)
637
+ bias = bias * self.slopes
638
+
639
+ num_heads_unalibied = h - bias.shape[0]
640
+ bias = pad_at_dim(bias, (0, num_heads_unalibied), dim = 0)
641
+ self.register_buffer('bias', bias, persistent = False)
642
+
643
+ return self.bias
644
+
645
+ class LearnedAlibiPositionalBias(AlibiPositionalBias):
646
+ def __init__(self, heads, total_heads):
647
+ super().__init__(heads, total_heads)
648
+ log_slopes = torch.log(self.slopes)
649
+ self.learned_logslopes = nn.Parameter(log_slopes)
650
+
651
+ def forward(self, i, j):
652
+ h, i, j, device = self.heads, self.device
653
+
654
+ def get_slopes(param):
655
+ return pad_at_dim(param.exp(), (0, h - param.shape[0]), dim = -2)
656
+
657
+ if exists(self.bias) and self.bias.shape[-1] >= j:
658
+ bias = self.bias[..., :i, :j]
659
+ else:
660
+ bias = self.get_bias(i, j, device)
661
+ self.register_buffer('bias', bias, persistent = False)
662
+
663
+ slopes = get_slopes(self.learned_logslopes)
664
+ bias = bias * slopes
665
+
666
+ return bias
667
+
668
+ class RotaryEmbedding(nn.Module):
669
+ def __init__(
670
+ self,
671
+ dim,
672
+ use_xpos = False,
673
+ scale_base = 512
674
+ ):
675
+ super().__init__()
676
+ inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
677
+ self.register_buffer('inv_freq', inv_freq)
678
+
679
+ if not use_xpos:
680
+ self.register_buffer('scale', None)
681
+ return
682
+
683
+ scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
684
+
685
+ self.scale_base = scale_base
686
+ self.register_buffer('scale', scale)
687
+
688
+ def forward(self, seq_len, device):
689
+ t = torch.arange(seq_len, device = device).type_as(self.inv_freq)
690
+ freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
691
+ freqs = torch.cat((freqs, freqs), dim = -1)
692
+
693
+ if not exists(self.scale):
694
+ return freqs, 1.
695
+
696
+ power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
697
+ scale = self.scale ** rearrange(power, 'n -> n 1')
698
+ scale = torch.cat((scale, scale), dim = -1)
699
+
700
+ return freqs, scale
701
+
702
+
703
+ def rotate_half(x):
704
+ x = rearrange(x, '... (j d) -> ... j d', j = 2)
705
+ x1, x2 = x.unbind(dim = -2)
706
+ return torch.cat((-x2, x1), dim = -1)
707
+
708
+ def apply_rotary_pos_emb(t, freqs, scale = 1):
709
+ seq_len = t.shape[-2]
710
+ freqs = freqs[-seq_len:, :]
711
+ return (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
712
+
713
+ # norms
714
+
715
+ class Scale(nn.Module):
716
+ def __init__(self, value, fn):
717
+ super().__init__()
718
+ self.value = value
719
+ self.fn = fn
720
+
721
+ def forward(self, x, **kwargs):
722
+ out = self.fn(x, **kwargs)
723
+ scale_fn = lambda t: t * self.value
724
+
725
+ if not isinstance(out, tuple):
726
+ return scale_fn(out)
727
+
728
+ return (scale_fn(out[0]), *out[1:])
729
+
730
+ class ScaleNorm(nn.Module):
731
+ def __init__(self, dim, eps = 1e-5):
732
+ super().__init__()
733
+ self.eps = eps
734
+ self.g = nn.Parameter(torch.ones(1) * (dim ** -0.5))
735
+
736
+ def forward(self, x):
737
+ norm = torch.norm(x, dim = -1, keepdim = True)
738
+ return x / norm.clamp(min = self.eps) * self.g
739
+
740
+ class RMSNorm(nn.Module):
741
+ def __init__(self, dim, eps = 1e-8):
742
+ super().__init__()
743
+ self.scale = dim ** -0.5
744
+ self.eps = eps
745
+ self.g = nn.Parameter(torch.ones(dim))
746
+
747
+ def forward(self, x):
748
+ norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
749
+ return x / norm.clamp(min = self.eps) * self.g
750
+
751
+ # residual and residual gates
752
+
753
+ class Residual(nn.Module):
754
+ def __init__(self, dim, scale_residual = False, scale_residual_constant = 1.):
755
+ super().__init__()
756
+ self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
757
+ self.scale_residual_constant = scale_residual_constant
758
+
759
+ def forward(self, x, residual):
760
+ if exists(self.residual_scale):
761
+ residual = residual * self.residual_scale
762
+
763
+ if self.scale_residual_constant != 1:
764
+ residual = residual * self.scale_residual_constant
765
+
766
+ return x + residual
767
+
768
+ class GRUGating(nn.Module):
769
+ def __init__(self, dim, scale_residual = False, **kwargs):
770
+ super().__init__()
771
+ self.gru = nn.GRUCell(dim, dim)
772
+ self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
773
+
774
+ def forward(self, x, residual):
775
+ if exists(self.residual_scale):
776
+ residual = residual * self.residual_scale
777
+
778
+ gated_output = self.gru(
779
+ rearrange(x, 'b n d -> (b n) d'),
780
+ rearrange(residual, 'b n d -> (b n) d')
781
+ )
782
+
783
+ return gated_output.reshape_as(x)
784
+
785
+ # token shifting
786
+
787
+ def shift(t, amount, mask = None):
788
+ if amount == 0:
789
+ return t
790
+ else:
791
+ amount = min(amount, t.shape[1])
792
+
793
+ if exists(mask):
794
+ t = t.masked_fill(~mask[..., None], 0.)
795
+
796
+ return pad_at_dim(t, (amount, -amount), dim = - 2, value = 0.)
797
+
798
+ class ShiftTokens(nn.Module):
799
+ def __init__(self, shifts, fn):
800
+ super().__init__()
801
+ self.fn = fn
802
+ self.shifts = tuple(shifts)
803
+
804
+ def forward(self, x, **kwargs):
805
+ mask = kwargs.get('mask', None)
806
+ shifts = self.shifts
807
+ segments = len(shifts)
808
+ feats_per_shift = x.shape[-1] // segments
809
+ splitted = x.split(feats_per_shift, dim = -1)
810
+ segments_to_shift, rest = splitted[:segments], splitted[segments:]
811
+ segments_to_shift = list(map(lambda args: shift(*args, mask = mask), zip(segments_to_shift, shifts)))
812
+ x = torch.cat((*segments_to_shift, *rest), dim = -1)
813
+ return self.fn(x, **kwargs)
814
+
815
+ # feedforward
816
+
817
+ class GLU(nn.Module):
818
+ def __init__(self, dim_in, dim_out, activation):
819
+ super().__init__()
820
+ self.act = activation
821
+ self.proj = nn.Linear(dim_in, dim_out * 2)
822
+
823
+ def forward(self, x):
824
+ x, gate = self.proj(x).chunk(2, dim = -1)
825
+ return x * self.act(gate)
826
+
827
+ class FeedForward(nn.Module):
828
+ def __init__(
829
+ self,
830
+ dim,
831
+ dim_out = None,
832
+ mult = 4,
833
+ glu = False,
834
+ swish = False,
835
+ relu_squared = False,
836
+ post_act_ln = False,
837
+ dropout = 0.,
838
+ no_bias = False,
839
+ zero_init_output = False
840
+ ):
841
+ super().__init__()
842
+ inner_dim = int(dim * mult)
843
+ dim_out = default(dim_out, dim)
844
+
845
+ if relu_squared:
846
+ activation = ReluSquared()
847
+ elif swish:
848
+ activation = nn.SiLU()
849
+ else:
850
+ activation = nn.GELU()
851
+
852
+ project_in = nn.Sequential(
853
+ nn.Linear(dim, inner_dim, bias = not no_bias),
854
+ activation
855
+ ) if not glu else GLU(dim, inner_dim, activation)
856
+
857
+ self.ff = nn.Sequential(
858
+ project_in,
859
+ nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(),
860
+ nn.Dropout(dropout),
861
+ nn.Linear(inner_dim, dim_out, bias = not no_bias)
862
+ )
863
+
864
+ # init last linear layer to 0
865
+ if zero_init_output:
866
+ init_zero_(self.ff[-1])
867
+
868
+ def forward(self, x):
869
+ return self.ff(x)
870
+
871
+ # attention. it is all we need
872
+
873
+ class Attention(nn.Module):
874
+ def __init__(
875
+ self,
876
+ dim,
877
+ dim_head = DEFAULT_DIM_HEAD,
878
+ heads = 8,
879
+ causal = False,
880
+ flash = False,
881
+ talking_heads = False,
882
+ head_scale = False,
883
+ sparse_topk = None,
884
+ num_mem_kv = 0,
885
+ dropout = 0.,
886
+ on_attn = False,
887
+ gate_values = False,
888
+ zero_init_output = False,
889
+ max_attend_past = None,
890
+ qk_norm = False,
891
+ qk_norm_groups = 1,
892
+ qk_norm_scale = 10,
893
+ qk_norm_dim_scale = False,
894
+ one_kv_head = False,
895
+ shared_kv = False,
896
+ value_dim_head = None,
897
+ tensor_product = False # https://arxiv.org/abs/2208.06061
898
+ ):
899
+ super().__init__()
900
+ self.scale = dim_head ** -0.5
901
+
902
+ self.heads = heads
903
+ self.causal = causal
904
+ self.max_attend_past = max_attend_past
905
+
906
+ value_dim_head = default(value_dim_head, dim_head)
907
+ q_dim = k_dim = dim_head * heads
908
+ v_dim = out_dim = value_dim_head * heads
909
+
910
+ self.one_kv_head = one_kv_head
911
+ if one_kv_head:
912
+ k_dim = dim_head
913
+ v_dim = value_dim_head
914
+ out_dim = v_dim * heads
915
+
916
+ self.to_q = nn.Linear(dim, q_dim, bias = False)
917
+ self.to_k = nn.Linear(dim, k_dim, bias = False)
918
+
919
+ # shared key / values, for further memory savings during inference
920
+ assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values'
921
+ self.to_v = nn.Linear(dim, v_dim, bias = False) if not shared_kv else None
922
+
923
+ # relations projection from tp-attention
924
+ self.to_r = nn.Linear(dim, v_dim, bias = False) if tensor_product else None
925
+
926
+ # add GLU gating for aggregated values, from alphafold2
927
+ self.to_v_gate = None
928
+ if gate_values:
929
+ self.to_v_gate = nn.Linear(dim, out_dim)
930
+ nn.init.constant_(self.to_v_gate.weight, 0)
931
+ nn.init.constant_(self.to_v_gate.bias, 1)
932
+
933
+ # cosine sim attention
934
+ self.qk_norm = qk_norm
935
+ self.qk_norm_groups = qk_norm_groups
936
+ self.qk_norm_scale = qk_norm_scale
937
+
938
+ # whether to use the rmsnorm (equivalent to cosine sim attention when scale is equal to 1) - https://arxiv.org/abs/2302.05442
939
+ self.qk_norm_dim_scale = qk_norm_dim_scale
940
+
941
+ self.qk_norm_q_scale = self.qk_norm_k_scale = 1
942
+ if qk_norm and qk_norm_dim_scale:
943
+ self.qk_norm_q_scale = nn.Parameter(torch.ones(dim_head))
944
+ self.qk_norm_k_scale = nn.Parameter(torch.ones(dim_head))
945
+
946
+ assert (not qk_norm) or (dim_head % qk_norm_groups) == 0, 'dimension per attention head must be divisible by the qk norm groups'
947
+ assert not (qk_norm and (dim_head // qk_norm_groups) <= 2), 'the group dimension may be too small (2 was too small in my tests, but 4 still works, surprisingly)'
948
+
949
+ # attend class - includes core attention algorithm + talking heads
950
+
951
+ self.attend = Attend(
952
+ heads = heads,
953
+ causal = causal,
954
+ talking_heads = talking_heads,
955
+ dropout = dropout,
956
+ qk_norm = qk_norm,
957
+ scale = qk_norm_scale if qk_norm else self.scale,
958
+ flash = flash
959
+ )
960
+
961
+ # head scaling
962
+ self.head_scale = head_scale
963
+ if head_scale:
964
+ self.head_scale_params = nn.Parameter(torch.ones(1, heads, 1, 1))
965
+
966
+ # explicit topk sparse attention
967
+ self.sparse_topk = sparse_topk
968
+
969
+ # add memory key / values
970
+ self.num_mem_kv = num_mem_kv
971
+ if num_mem_kv > 0:
972
+ self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
973
+ self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
974
+
975
+ # attention on attention
976
+ self.attn_on_attn = on_attn
977
+ self.to_out = nn.Sequential(nn.Linear(out_dim, dim * 2, bias = False), nn.GLU()) if on_attn else nn.Linear(out_dim, dim, bias = False)
978
+
979
+ # init output projection 0
980
+ if zero_init_output:
981
+ init_zero_(self.to_out)
982
+
983
+ def forward(
984
+ self,
985
+ x,
986
+ context = None,
987
+ mask = None,
988
+ context_mask = None,
989
+ attn_mask = None,
990
+ rel_pos = None,
991
+ rotary_pos_emb = None,
992
+ prev_attn = None,
993
+ mem = None
994
+ ):
995
+ b, n, _, h, head_scale, device, has_context = *x.shape, self.heads, self.head_scale, x.device, exists(context)
996
+ kv_input = default(context, x)
997
+
998
+ q_input = x
999
+ k_input = kv_input
1000
+ v_input = kv_input
1001
+ r_input = x
1002
+
1003
+ if exists(mem):
1004
+ k_input = torch.cat((mem, k_input), dim = -2)
1005
+ v_input = torch.cat((mem, v_input), dim = -2)
1006
+
1007
+ q = self.to_q(q_input)
1008
+ k = self.to_k(k_input)
1009
+ v = self.to_v(v_input) if exists(self.to_v) else k
1010
+ r = self.to_r(r_input) if exists(self.to_r) else None
1011
+
1012
+ q = rearrange(q, 'b n (h d) -> b h n d', h = h)
1013
+
1014
+ if not self.one_kv_head:
1015
+ k, v, r = map(lambda t: maybe(rearrange)(t, 'b n (h d) -> b h n d', h = h), (k, v, r))
1016
+
1017
+ if self.qk_norm:
1018
+ qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
1019
+ q, k = map(qk_l2norm, (q, k))
1020
+ scale = self.qk_norm_scale
1021
+
1022
+ q = q * self.qk_norm_q_scale
1023
+ k = k * self.qk_norm_k_scale
1024
+
1025
+ if exists(rotary_pos_emb) and not has_context:
1026
+ freqs, xpos_scale = rotary_pos_emb
1027
+ l = freqs.shape[-1]
1028
+
1029
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)
1030
+ (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v))
1031
+
1032
+ ql, kl, vl = map(lambda arg: apply_rotary_pos_emb(arg[0], freqs, arg[1]), ((ql, q_xpos_scale), (kl, k_xpos_scale), (vl, k_xpos_scale)))
1033
+ q, k, v = map(lambda t: torch.cat(t, dim = -1), ((ql, qr), (kl, kr), (vl, vr)))
1034
+
1035
+ input_mask = default(context_mask, mask)
1036
+
1037
+ if self.num_mem_kv > 0:
1038
+ mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), (self.mem_k, self.mem_v))
1039
+
1040
+ if self.qk_norm:
1041
+ mem_k = l2norm(mem_k)
1042
+ mem_k = mem_k * self.qk_norm_k_scale
1043
+
1044
+ k = torch.cat((mem_k, k), dim = -2)
1045
+ v = torch.cat((mem_v, v), dim = -2)
1046
+
1047
+ if exists(input_mask):
1048
+ input_mask = pad_at_dim(input_mask, (self.num_mem_kv, 0), dim = -1, value = True)
1049
+
1050
+
1051
+ i, j = map(lambda t: t.shape[-2], (q, k))
1052
+
1053
+ # determine masking
1054
+
1055
+ mask_value = max_neg_value(q)
1056
+ masks = []
1057
+ final_attn_mask = None
1058
+
1059
+ if exists(input_mask):
1060
+ input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
1061
+ masks.append(~input_mask)
1062
+
1063
+ if exists(attn_mask):
1064
+ assert 2 <= attn_mask.ndim <= 4, 'attention mask must have greater than 2 dimensions but less than or equal to 4'
1065
+ if attn_mask.ndim == 2:
1066
+ attn_mask = rearrange(attn_mask, 'i j -> 1 1 i j')
1067
+ elif attn_mask.ndim == 3:
1068
+ attn_mask = rearrange(attn_mask, 'h i j -> 1 h i j')
1069
+ masks.append(~attn_mask)
1070
+
1071
+ if exists(self.max_attend_past):
1072
+ range_q = torch.arange(j - i, j, device = device)
1073
+ range_k = torch.arange(j, device = device)
1074
+ dist = rearrange(range_q, 'i -> 1 1 i 1') - rearrange(range_k, 'j -> 1 1 1 j')
1075
+ max_attend_past_mask = dist > self.max_attend_past
1076
+ masks.append(max_attend_past_mask)
1077
+
1078
+ if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
1079
+ top, _ = dots.topk(self.sparse_topk, dim = -1)
1080
+ vk = rearrange(top[..., -1], '... -> ... 1')
1081
+ sparse_topk_mask = dots < vk
1082
+ masks.append(sparse_topk_mask)
1083
+
1084
+ if len(masks) > 0:
1085
+ final_attn_mask = or_reduce(masks)
1086
+
1087
+ # prepare relative positional bias, if needed
1088
+
1089
+ attn_bias = None
1090
+ if exists(rel_pos):
1091
+ attn_bias = rel_pos(i, j)
1092
+
1093
+ # attention is all we need
1094
+
1095
+ out, intermediates = self.attend(
1096
+ q, k, v,
1097
+ mask = final_attn_mask,
1098
+ attn_bias = attn_bias,
1099
+ prev_attn = prev_attn
1100
+ )
1101
+
1102
+ # https://arxiv.org/abs/2208.06061 proposes to add a residual for better gradients
1103
+
1104
+ if exists(r):
1105
+ out = out * r + out
1106
+
1107
+ # normformer scaling of heads
1108
+
1109
+ if head_scale:
1110
+ out = out * self.head_scale_params
1111
+
1112
+ # merge heads
1113
+
1114
+ out = rearrange(out, 'b h n d -> b n (h d)')
1115
+
1116
+ # alphafold2 styled gating of the values
1117
+
1118
+ if exists(self.to_v_gate):
1119
+ gates = self.to_v_gate(x)
1120
+ out = out * gates.sigmoid()
1121
+
1122
+ # combine the heads
1123
+
1124
+ out = self.to_out(out)
1125
+
1126
+ if exists(mask):
1127
+ mask = rearrange(mask, 'b n -> b n 1')
1128
+ out = out.masked_fill(~mask, 0.)
1129
+
1130
+ return out, intermediates
1131
+
1132
+ class AttentionLayers(nn.Module):
1133
+ def __init__(
1134
+ self,
1135
+ dim,
1136
+ depth,
1137
+ heads = 8,
1138
+ causal = False,
1139
+ cross_attend = False,
1140
+ only_cross = False,
1141
+ use_scalenorm = False,
1142
+ use_rmsnorm = False,
1143
+ alibi_pos_bias = False,
1144
+ alibi_num_heads = None,
1145
+ alibi_learned = False,
1146
+ rel_pos_bias = False,
1147
+ rel_pos_num_buckets = 32,
1148
+ rel_pos_max_distance = 128,
1149
+ dynamic_pos_bias = False,
1150
+ dynamic_pos_bias_log_distance = False,
1151
+ dynamic_pos_bias_mlp_depth = 2,
1152
+ dynamic_pos_bias_norm = False,
1153
+ rotary_pos_emb = False,
1154
+ rotary_emb_dim = None,
1155
+ rotary_xpos = False,
1156
+ rotary_xpos_scale_base = 512,
1157
+ custom_layers = None,
1158
+ sandwich_coef = None,
1159
+ par_ratio = None,
1160
+ residual_attn = False,
1161
+ cross_residual_attn = False,
1162
+ macaron = False,
1163
+ pre_norm = True,
1164
+ gate_residual = False,
1165
+ scale_residual = False,
1166
+ scale_residual_constant = 1.,
1167
+ deepnorm = False,
1168
+ shift_tokens = 0,
1169
+ sandwich_norm = False,
1170
+ resi_dual = False,
1171
+ zero_init_branch_output = False,
1172
+ layer_dropout = 0.,
1173
+ cross_attn_tokens_dropout = 0.,
1174
+ **kwargs
1175
+ ):
1176
+ super().__init__()
1177
+ rotary_pos_emb = rotary_pos_emb or rotary_xpos
1178
+
1179
+ ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
1180
+ attn_kwargs, kwargs = groupby_prefix_and_trim('attn_', kwargs)
1181
+
1182
+ dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
1183
+
1184
+ self.dim = dim
1185
+ self.depth = depth
1186
+ self.layers = nn.ModuleList([])
1187
+
1188
+ self.has_pos_emb = rel_pos_bias or rotary_pos_emb
1189
+
1190
+ rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32)
1191
+
1192
+ assert not (rotary_xpos and not causal), 'rotary xpos is not compatible with bidirectional attention'
1193
+ self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim, use_xpos = rotary_xpos, scale_base = rotary_xpos_scale_base) if rotary_pos_emb else None
1194
+
1195
+ assert not (alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
1196
+ assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
1197
+
1198
+ # relative positional bias
1199
+
1200
+ flash_attn = attn_kwargs.get('flash', False)
1201
+ assert (int(rel_pos_bias) + int(dynamic_pos_bias) + int(alibi_pos_bias)) <= 1, 'you can only choose up to one of t5, alibi, or dynamic positional bias'
1202
+
1203
+ self.rel_pos = None
1204
+ if rel_pos_bias:
1205
+ assert not flash_attn, 'flash attention not compatible with t5 relative positional bias'
1206
+ self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance)
1207
+ elif dynamic_pos_bias:
1208
+ assert not flash_attn, 'flash attention not compatible with dynamic positional bias'
1209
+ self.rel_pos = DynamicPositionBias(dim = dim // 4, heads = heads, log_distance = dynamic_pos_bias_log_distance, depth = dynamic_pos_bias_mlp_depth, norm = dynamic_pos_bias_norm)
1210
+ elif alibi_pos_bias:
1211
+ alibi_num_heads = default(alibi_num_heads, heads)
1212
+ assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
1213
+ alibi_pos_klass = LearnedAlibiPositionalBias if alibi_learned else AlibiPositionalBias
1214
+ self.rel_pos = alibi_pos_klass(heads = alibi_num_heads, total_heads = heads)
1215
+
1216
+ # determine deepnorm and residual scale
1217
+
1218
+ if deepnorm:
1219
+ assert scale_residual_constant == 1, 'scale residual constant is being overridden by deep norm settings'
1220
+ pre_norm = sandwich_norm = resi_dual = False
1221
+ scale_residual = True
1222
+ scale_residual_constant = (2 * depth) ** 0.25
1223
+
1224
+ assert (int(sandwich_norm) + int(resi_dual)) <= 1, 'either sandwich norm or resiDual is selected, but not both'
1225
+ assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'
1226
+ assert not (not pre_norm and resi_dual), 'resiDualcannot be used when not using prenorm'
1227
+ self.pre_norm = pre_norm
1228
+ self.sandwich_norm = sandwich_norm
1229
+ self.resi_dual = resi_dual
1230
+
1231
+ self.residual_attn = residual_attn
1232
+ self.cross_residual_attn = cross_residual_attn
1233
+ self.cross_attend = cross_attend
1234
+
1235
+ norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
1236
+ norm_class = RMSNorm if use_rmsnorm else norm_class
1237
+ norm_fn = partial(norm_class, dim)
1238
+
1239
+ if cross_attend and not only_cross:
1240
+ default_block = ('a', 'c', 'f')
1241
+ elif cross_attend and only_cross:
1242
+ default_block = ('c', 'f')
1243
+ else:
1244
+ default_block = ('a', 'f')
1245
+
1246
+ if macaron:
1247
+ default_block = ('f',) + default_block
1248
+
1249
+ # zero init
1250
+
1251
+ if zero_init_branch_output:
1252
+ attn_kwargs = {**attn_kwargs, 'zero_init_output': True}
1253
+ ff_kwargs = {**ff_kwargs, 'zero_init_output': True}
1254
+
1255
+ # calculate layer block order
1256
+
1257
+ if exists(custom_layers):
1258
+ layer_types = custom_layers
1259
+ elif exists(par_ratio):
1260
+ par_depth = depth * len(default_block)
1261
+ assert 1 < par_ratio <= par_depth, 'par ratio out of range'
1262
+ default_block = tuple(filter(not_equals('f'), default_block))
1263
+ par_attn = par_depth // par_ratio
1264
+ depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
1265
+ par_width = (depth_cut + depth_cut // par_attn) // par_attn
1266
+ assert len(default_block) <= par_width, 'default block is too large for par_ratio'
1267
+ par_block = default_block + ('f',) * (par_width - len(default_block))
1268
+ par_head = par_block * par_attn
1269
+ layer_types = par_head + ('f',) * (par_depth - len(par_head))
1270
+ elif exists(sandwich_coef):
1271
+ assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
1272
+ layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
1273
+ else:
1274
+ layer_types = default_block * depth
1275
+
1276
+ self.layer_types = layer_types
1277
+ self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
1278
+
1279
+ # stochastic depth
1280
+
1281
+ self.layer_dropouts = cast_tuple(layer_dropout, len(layer_types))
1282
+
1283
+ # structured dropout for cross attending
1284
+
1285
+ self.cross_attn_tokens_dropout = cross_attn_tokens_dropout
1286
+
1287
+ # calculate token shifting
1288
+
1289
+ shift_tokens = cast_tuple(shift_tokens, len(layer_types))
1290
+
1291
+ # iterate and construct layers
1292
+
1293
+ for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
1294
+ is_last_layer = ind == (len(self.layer_types) - 1)
1295
+
1296
+ if layer_type == 'a':
1297
+ layer = Attention(dim, heads = heads, causal = causal, **attn_kwargs)
1298
+ elif layer_type == 'c':
1299
+ layer = Attention(dim, heads = heads, **attn_kwargs)
1300
+ elif layer_type == 'f':
1301
+ layer = FeedForward(dim, **ff_kwargs)
1302
+ layer = layer if not macaron else Scale(0.5, layer)
1303
+ else:
1304
+ raise Exception(f'invalid layer type {layer_type}')
1305
+
1306
+ if layer_shift_tokens > 0:
1307
+ shift_range_upper = layer_shift_tokens + 1
1308
+ shift_range_lower = -layer_shift_tokens if not causal else 0
1309
+ layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer)
1310
+
1311
+ residual_fn = GRUGating if gate_residual else Residual
1312
+ residual = residual_fn(dim, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)
1313
+
1314
+ pre_branch_norm = norm_fn() if pre_norm else None
1315
+ post_branch_norm = norm_fn() if sandwich_norm else None
1316
+ post_main_norm = norm_fn() if (resi_dual or not pre_norm) and not is_last_layer else None
1317
+
1318
+ norms = nn.ModuleList([
1319
+ pre_branch_norm,
1320
+ post_branch_norm,
1321
+ post_main_norm
1322
+ ])
1323
+
1324
+ self.layers.append(nn.ModuleList([
1325
+ norms,
1326
+ layer,
1327
+ residual
1328
+ ]))
1329
+
1330
+ if deepnorm:
1331
+ init_gain = (8 * depth) ** -0.25
1332
+ deepnorm_init(self, init_gain)
1333
+
1334
+ def forward(
1335
+ self,
1336
+ x,
1337
+ context = None,
1338
+ mask = None,
1339
+ context_mask = None,
1340
+ attn_mask = None,
1341
+ self_attn_context_mask = None,
1342
+ mems = None,
1343
+ return_hiddens = False
1344
+ ):
1345
+ assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
1346
+
1347
+ hiddens = []
1348
+ intermediates = []
1349
+ prev_attn = None
1350
+ prev_cross_attn = None
1351
+
1352
+ mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
1353
+
1354
+ rotary_pos_emb = None
1355
+ if exists(self.rotary_pos_emb):
1356
+ max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems)))
1357
+ rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)
1358
+
1359
+ outer_residual = x
1360
+
1361
+ for ind, (layer_type, (norm, block, residual_fn), layer_dropout) in enumerate(zip(self.layer_types, self.layers, self.layer_dropouts)):
1362
+ is_last = ind == (len(self.layers) - 1)
1363
+
1364
+ if self.training and layer_dropout > 0. and random() < layer_dropout:
1365
+ continue
1366
+
1367
+ if layer_type == 'a':
1368
+ if return_hiddens:
1369
+ hiddens.append(x)
1370
+ layer_mem = mems.pop(0) if mems else None
1371
+
1372
+ if layer_type == 'c':
1373
+ if self.training and self.cross_attn_tokens_dropout > 0.:
1374
+ context, context_mask = dropout_seq(context, context_mask, self.cross_attn_tokens_dropout)
1375
+
1376
+ inner_residual = x
1377
+
1378
+ pre_norm, post_branch_norm, post_main_norm = norm
1379
+
1380
+ if exists(pre_norm) and not self.resi_dual:
1381
+ x = pre_norm(x)
1382
+
1383
+ if layer_type == 'a':
1384
+ out, inter = block(x, mask = mask, context_mask = self_attn_context_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, mem = layer_mem)
1385
+ elif layer_type == 'c':
1386
+ out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn)
1387
+ elif layer_type == 'f':
1388
+ out = block(x)
1389
+
1390
+ if self.resi_dual:
1391
+ outer_residual = residual_fn(out, outer_residual)
1392
+
1393
+ if exists(post_branch_norm):
1394
+ out = post_branch_norm(out)
1395
+
1396
+ x = residual_fn(out, inner_residual)
1397
+
1398
+ if layer_type in ('a', 'c') and return_hiddens:
1399
+ intermediates.append(inter)
1400
+
1401
+ if layer_type == 'a' and self.residual_attn:
1402
+ prev_attn = inter.pre_softmax_attn
1403
+ elif layer_type == 'c' and self.cross_residual_attn:
1404
+ prev_cross_attn = inter.pre_softmax_attn
1405
+
1406
+ if exists(post_main_norm):
1407
+ x = post_main_norm(x)
1408
+
1409
+ if self.resi_dual:
1410
+ x = x + pre_norm(outer_residual)
1411
+
1412
+ if return_hiddens:
1413
+ intermediates = LayerIntermediates(
1414
+ hiddens = hiddens,
1415
+ attn_intermediates = intermediates
1416
+ )
1417
+
1418
+ return x, intermediates
1419
+
1420
+ return x
1421
+
1422
+ class Encoder(AttentionLayers):
1423
+ def __init__(self, **kwargs):
1424
+ assert 'causal' not in kwargs, 'cannot set causality on encoder'
1425
+ super().__init__(causal = False, **kwargs)
1426
+
1427
+ class Decoder(AttentionLayers):
1428
+ def __init__(self, **kwargs):
1429
+ assert 'causal' not in kwargs, 'cannot set causality on decoder'
1430
+ super().__init__(causal = True, **kwargs)
1431
+
1432
+ class CrossAttender(AttentionLayers):
1433
+ def __init__(self, **kwargs):
1434
+ super().__init__(cross_attend = True, only_cross = True, **kwargs)
1435
+
1436
+ class ViTransformerWrapper(nn.Module):
1437
+ def __init__(
1438
+ self,
1439
+ *,
1440
+ image_size,
1441
+ patch_size,
1442
+ attn_layers,
1443
+ channels = 3,
1444
+ num_classes = None,
1445
+ dropout = 0.,
1446
+ post_emb_norm = False,
1447
+ emb_dropout = 0.
1448
+ ):
1449
+ super().__init__()
1450
+ assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder'
1451
+ assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
1452
+ dim = attn_layers.dim
1453
+ num_patches = (image_size // patch_size) ** 2
1454
+ patch_dim = channels * patch_size ** 2
1455
+
1456
+ self.patch_size = patch_size
1457
+
1458
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
1459
+
1460
+ self.patch_to_embedding = nn.Sequential(
1461
+ nn.LayerNorm(patch_dim),
1462
+ nn.Linear(patch_dim, dim),
1463
+ nn.LayerNorm(dim)
1464
+ )
1465
+
1466
+ self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity()
1467
+ self.dropout = nn.Dropout(emb_dropout)
1468
+
1469
+ self.attn_layers = attn_layers
1470
+ self.norm = nn.LayerNorm(dim)
1471
+ self.mlp_head = nn.Linear(dim, num_classes) if exists(num_classes) else nn.Identity()
1472
+
1473
+ def forward(
1474
+ self,
1475
+ img,
1476
+ return_embeddings = False
1477
+ ):
1478
+ p = self.patch_size
1479
+
1480
+ x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
1481
+ x = self.patch_to_embedding(x)
1482
+ n = x.shape[1]
1483
+
1484
+ x = x + self.pos_embedding[:, :n]
1485
+
1486
+ x = self.post_emb_norm(x)
1487
+ x = self.dropout(x)
1488
+
1489
+ x = self.attn_layers(x)
1490
+ x = self.norm(x)
1491
+
1492
+ if not exists(self.mlp_head) or return_embeddings:
1493
+ return x
1494
+
1495
+ x = x.mean(dim = -2)
1496
+ return self.mlp_head(x)
1497
+
1498
+ class TransformerWrapper(nn.Module):
1499
+ def __init__(
1500
+ self,
1501
+ *,
1502
+ num_tokens,
1503
+ max_seq_len,
1504
+ attn_layers,
1505
+ emb_dim = None,
1506
+ max_mem_len = 0.,
1507
+ shift_mem_down = 0,
1508
+ emb_dropout = 0.,
1509
+ post_emb_norm = False,
1510
+ num_memory_tokens = None,
1511
+ tie_embedding = False,
1512
+ logits_dim = None,
1513
+ use_abs_pos_emb = True,
1514
+ scaled_sinu_pos_emb = False,
1515
+ l2norm_embed = False,
1516
+ emb_frac_gradient = 1. # GLM-130B and Cogview successfully used this, set at 0.1
1517
+ ):
1518
+ super().__init__()
1519
+ assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
1520
+
1521
+ dim = attn_layers.dim
1522
+ emb_dim = default(emb_dim, dim)
1523
+ self.emb_dim = emb_dim
1524
+ self.num_tokens = num_tokens
1525
+ self.token_pad = num_tokens
1526
+
1527
+ self.max_seq_len = max_seq_len
1528
+ self.max_mem_len = max_mem_len
1529
+ self.shift_mem_down = shift_mem_down
1530
+
1531
+ self.l2norm_embed = l2norm_embed
1532
+ self.token_emb = TokenEmbedding(emb_dim, num_tokens, l2norm_embed = l2norm_embed)
1533
+
1534
+ if not (use_abs_pos_emb and not attn_layers.has_pos_emb):
1535
+ self.pos_emb = always(0)
1536
+ elif scaled_sinu_pos_emb:
1537
+ self.pos_emb = ScaledSinusoidalEmbedding(emb_dim)
1538
+ else:
1539
+ self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len, l2norm_embed = l2norm_embed)
1540
+
1541
+ self.emb_frac_gradient = emb_frac_gradient # fraction of the gradient that should go to the embedding, https://arxiv.org/abs/2105.13290
1542
+
1543
+ self.post_emb_norm = nn.LayerNorm(emb_dim) if post_emb_norm else nn.Identity()
1544
+ self.emb_dropout = nn.Dropout(emb_dropout)
1545
+
1546
+ self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
1547
+ self.attn_layers = attn_layers
1548
+ self.norm = nn.LayerNorm(dim)
1549
+
1550
+ self.init_()
1551
+
1552
+ logits_dim = default(logits_dim, num_tokens)
1553
+ self.to_logits = nn.Linear(dim, logits_dim) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
1554
+
1555
+ # memory tokens (like [cls]) from Memory Transformers paper
1556
+ num_memory_tokens = default(num_memory_tokens, 0)
1557
+ self.num_memory_tokens = num_memory_tokens
1558
+ if num_memory_tokens > 0:
1559
+ self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
1560
+
1561
+ def init_(self):
1562
+ if self.l2norm_embed:
1563
+ nn.init.normal_(self.token_emb.emb.weight, std = 1e-5)
1564
+ if not isinstance(self.pos_emb, always):
1565
+ nn.init.normal_(self.pos_emb.emb.weight, std = 1e-5)
1566
+ return
1567
+
1568
+ nn.init.kaiming_normal_(self.token_emb.emb.weight)
1569
+
1570
+ def forward(
1571
+ self,
1572
+ x,
1573
+ return_embeddings = False,
1574
+ return_logits_and_embeddings = False,
1575
+ return_intermediates = False,
1576
+ mask = None,
1577
+ return_mems = False,
1578
+ return_attn = False,
1579
+ mems = None,
1580
+ pos = None,
1581
+ prepend_embeds = None,
1582
+ sum_embeds = None,
1583
+ **kwargs
1584
+ ):
1585
+ b, n, device, num_mem, emb_frac_gradient = *x.shape, x.device, self.num_memory_tokens, self.emb_frac_gradient
1586
+ return_hiddens = return_mems | return_attn
1587
+
1588
+ # absolute positional embedding
1589
+
1590
+ external_pos_emb = exists(pos) and pos.dtype != torch.long
1591
+ pos_emb = self.pos_emb(x, pos = pos) if not external_pos_emb else pos
1592
+ x = self.token_emb(x) + pos_emb
1593
+
1594
+ # for summing embeddings passed externally - needs this for self-conditioning in non-autoregressive training
1595
+
1596
+ if exists(sum_embeds):
1597
+ x = x + sum_embeds
1598
+
1599
+ # post embedding norm, purportedly leads to greater stabilization
1600
+
1601
+ x = self.post_emb_norm(x)
1602
+
1603
+ # whether to append embeds, as in PaLI, for image embeddings
1604
+
1605
+ if exists(prepend_embeds):
1606
+ prepend_seq, prepend_dim = prepend_embeds.shape[1:]
1607
+ assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as text model dimensions'
1608
+
1609
+ x = torch.cat((prepend_embeds, x), dim = -2)
1610
+
1611
+ # whether to reduce the gradient going to the embedding, from cogview paper, corroborated by GLM-130B model
1612
+
1613
+ if emb_frac_gradient < 1:
1614
+ assert emb_frac_gradient > 0
1615
+ x = x * emb_frac_gradient + x.detach() * (1 - emb_frac_gradient)
1616
+
1617
+ # embedding dropout
1618
+
1619
+ x = self.emb_dropout(x)
1620
+
1621
+ x = self.project_emb(x)
1622
+
1623
+ if num_mem > 0:
1624
+ mem = repeat(self.memory_tokens, 'n d -> b n d', b = b)
1625
+ x = torch.cat((mem, x), dim = 1)
1626
+
1627
+ # auto-handle masking after appending memory tokens
1628
+ if exists(mask):
1629
+ mask = pad_at_dim(mask, (num_mem, 0), dim = -1, value = True)
1630
+
1631
+ if self.shift_mem_down and exists(mems):
1632
+ mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:]
1633
+ mems = [*mems_r, *mems_l]
1634
+
1635
+ if return_hiddens:
1636
+ x, intermediates = self.attn_layers(x, mask = mask, mems = mems, return_hiddens = True, **kwargs)
1637
+ else:
1638
+ x = self.attn_layers(x, mask = mask, mems = mems, **kwargs)
1639
+
1640
+ x = self.norm(x)
1641
+
1642
+ mem, x = x[:, :num_mem], x[:, num_mem:]
1643
+
1644
+ if return_logits_and_embeddings:
1645
+ out = (self.to_logits(x), x)
1646
+ elif return_embeddings:
1647
+ out = x
1648
+ else:
1649
+ out = self.to_logits(x)
1650
+
1651
+ if return_intermediates:
1652
+ return out, intermediates
1653
+
1654
+ if return_mems:
1655
+ hiddens = intermediates.hiddens
1656
+ new_mems = list(map(lambda pair: torch.cat(pair, dim = -2), zip(mems, hiddens))) if exists(mems) else hiddens
1657
+ new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
1658
+ return out, new_mems
1659
+
1660
+ if return_attn:
1661
+ attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
1662
+ return out, attn_maps
1663
+
1664
+ return out
1665
+
1666
+ class ContinuousTransformerWrapper(nn.Module):
1667
+ def __init__(
1668
+ self,
1669
+ *,
1670
+ max_seq_len,
1671
+ attn_layers,
1672
+ dim_in = None,
1673
+ dim_out = None,
1674
+ emb_dim = None,
1675
+ post_emb_norm = False,
1676
+ emb_dropout = 0.,
1677
+ use_abs_pos_emb = True,
1678
+ scaled_sinu_pos_emb = False
1679
+ ):
1680
+ super().__init__()
1681
+ assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
1682
+
1683
+ dim = attn_layers.dim
1684
+
1685
+ self.max_seq_len = max_seq_len
1686
+
1687
+ if not (use_abs_pos_emb and not attn_layers.has_pos_emb):
1688
+ self.pos_emb = always(0)
1689
+ elif scaled_sinu_pos_emb:
1690
+ self.pos_emb = ScaledSinusoidalEmbedding(dim)
1691
+ else:
1692
+ self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len)
1693
+
1694
+ self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity()
1695
+ self.emb_dropout = nn.Dropout(emb_dropout)
1696
+
1697
+ self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
1698
+
1699
+ self.attn_layers = attn_layers
1700
+ self.norm = nn.LayerNorm(dim)
1701
+
1702
+ self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
1703
+
1704
+ def forward(
1705
+ self,
1706
+ x,
1707
+ return_embeddings = False,
1708
+ return_intermediates = False,
1709
+ mask = None,
1710
+ return_attn = False,
1711
+ mems = None,
1712
+ pos = None,
1713
+ prepend_embeds = None,
1714
+ **kwargs
1715
+ ):
1716
+ x = self.project_in(x)
1717
+ x = x + self.pos_emb(x, pos = pos)
1718
+
1719
+ x = self.post_emb_norm(x)
1720
+
1721
+ # whether to append embeds, as in PaLI, for image embeddings
1722
+
1723
+ if exists(prepend_embeds):
1724
+ _, prepend_dim = prepend_embeds.shape[1:]
1725
+ assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as model dimensions'
1726
+
1727
+ x = torch.cat((prepend_embeds, x), dim = -2)
1728
+
1729
+ x = self.emb_dropout(x)
1730
+
1731
+ x, intermediates = self.attn_layers(x, mask = mask, mems = mems, return_hiddens = True, **kwargs)
1732
+ x = self.norm(x)
1733
+
1734
+ out = self.project_out(x) if not return_embeddings else x
1735
+
1736
+ if return_intermediates:
1737
+ return out, intermediates
1738
+
1739
+ if return_attn:
1740
+ attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
1741
+ return out, attn_maps
1742
+
1743
+ return out
1744
+
1745
+ class XTransformer(nn.Module):
1746
+ def __init__(
1747
+ self,
1748
+ *,
1749
+ dim,
1750
+ tie_token_emb = False,
1751
+ ignore_index = -100,
1752
+ pad_value = 0,
1753
+ deepnorm = False,
1754
+ cross_attn_tokens_dropout = 0.,
1755
+ **kwargs
1756
+ ):
1757
+ super().__init__()
1758
+ enc_kwargs, kwargs = groupby_prefix_and_trim('enc_', kwargs)
1759
+ dec_kwargs, kwargs = groupby_prefix_and_trim('dec_', kwargs)
1760
+
1761
+ assert 'dim' not in enc_kwargs and 'dim' not in dec_kwargs, 'dimension of either encoder or decoder must be set with `dim` keyword'
1762
+ enc_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], enc_kwargs)
1763
+ enc_transformer_kwargs['emb_dropout'] = enc_kwargs.pop('emb_dropout', 0)
1764
+ enc_transformer_kwargs['num_memory_tokens'] = enc_kwargs.pop('num_memory_tokens', None)
1765
+ enc_transformer_kwargs['scaled_sinu_pos_emb'] = enc_kwargs.pop('scaled_sinu_pos_emb', False)
1766
+ enc_transformer_kwargs['use_abs_pos_emb'] = enc_kwargs.pop('use_abs_pos_emb', True)
1767
+
1768
+ dec_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], dec_kwargs)
1769
+ dec_transformer_kwargs['emb_dropout'] = dec_kwargs.pop('emb_dropout', 0)
1770
+ dec_transformer_kwargs['scaled_sinu_pos_emb'] = dec_kwargs.pop('scaled_sinu_pos_emb', False)
1771
+ dec_transformer_kwargs['use_abs_pos_emb'] = dec_kwargs.pop('use_abs_pos_emb', True)
1772
+
1773
+ self.cross_attn_tokens_dropout = cross_attn_tokens_dropout # how many tokens from the encoder to dropout when cross attending from decoder - seen in a couple papers, including Perceiver AR - this will also be very effective regularization when cross attending to very long memories
1774
+
1775
+ if deepnorm:
1776
+ enc_kwargs['scale_residual'] = True
1777
+ dec_kwargs['scale_residual'] = True
1778
+
1779
+ enc_depth = enc_kwargs['depth']
1780
+ dec_depth = dec_kwargs['depth']
1781
+
1782
+ enc_kwargs['scale_residual_constant'] = 0.81 * ((enc_depth ** 4) * dec_depth) ** .0625
1783
+ dec_kwargs['scale_residual_constant'] = (3 * dec_depth) ** 0.25
1784
+
1785
+ self.encoder = TransformerWrapper(
1786
+ **enc_transformer_kwargs,
1787
+ attn_layers = Encoder(dim = dim, **enc_kwargs)
1788
+ )
1789
+
1790
+ self.decoder = TransformerWrapper(
1791
+ **dec_transformer_kwargs,
1792
+ attn_layers = Decoder(dim = dim, cross_attend = True, **dec_kwargs)
1793
+ )
1794
+
1795
+ if deepnorm:
1796
+ deepnorm_init(self.encoder, 0.87 * ((enc_depth ** 4) * dec_depth) ** -0.0625)
1797
+ deepnorm_init(self.decoder, (12 * dec_depth) ** -0.25)
1798
+
1799
+ if tie_token_emb:
1800
+ self.decoder.token_emb = self.encoder.token_emb
1801
+
1802
+ self.decoder = AutoregressiveWrapper(self.decoder, ignore_index=ignore_index, pad_value=pad_value)
1803
+
1804
+ @torch.no_grad()
1805
+ def generate(self, seq_in, seq_out_start, seq_len, mask = None, attn_mask = None, **kwargs):
1806
+ encodings = self.encoder(seq_in, mask = mask, attn_mask = attn_mask, return_embeddings = True)
1807
+ return self.decoder.generate(seq_out_start, seq_len, context = encodings, context_mask = mask, **kwargs)
1808
+
1809
+ def forward(self, src, tgt, mask = None, attn_mask = None, src_prepend_embeds = None):
1810
+
1811
+ if exists(src_prepend_embeds) and exists(mask):
1812
+ mask = pad_at_dim(mask, (src_prepend_embeds.shape[-2], 0), dim = -1, value = True)
1813
+
1814
+ enc = self.encoder(src, mask = mask, attn_mask = attn_mask, prepend_embeds = src_prepend_embeds, return_embeddings = True)
1815
+
1816
+ if self.training and self.cross_attn_tokens_dropout > 0:
1817
+ enc, mask = dropout_seq(enc, mask, self.cross_attn_tokens_dropout)
1818
+
1819
+ out = self.decoder(tgt, context = enc, context_mask = mask)
1820
+ return out
1821
+
1822
+ #===================================================================================================================
1823
+
1824
+ def exists(val):
1825
+ return val is not None
1826
+
1827
+ def eval_decorator(fn):
1828
+ def inner(self, *args, **kwargs):
1829
+ was_training = self.training
1830
+ self.eval()
1831
+ out = fn(self, *args, **kwargs)
1832
+ self.train(was_training)
1833
+ return out
1834
+ return inner
1835
+
1836
+ # nucleus
1837
+
1838
+ def top_p(logits, thres = 0.9):
1839
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
1840
+ cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
1841
+
1842
+ sorted_indices_to_remove = cum_probs > (1 - thres)
1843
+ sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
1844
+ sorted_indices_to_remove[:, 0] = 0
1845
+
1846
+ sorted_logits[sorted_indices_to_remove] = float('-inf')
1847
+ return sorted_logits.scatter(1, sorted_indices, sorted_logits)
1848
+
1849
+ # topk
1850
+
1851
+ def top_k(logits, thres = 0.9):
1852
+ k = ceil((1 - thres) * logits.shape[-1])
1853
+ val, ind = torch.topk(logits, k)
1854
+ probs = torch.full_like(logits, float('-inf'))
1855
+ probs.scatter_(1, ind, val)
1856
+ return probs
1857
+
1858
+ # top_a
1859
+
1860
+ def top_a(logits, min_p_pow=2.0, min_p_ratio=0.02):
1861
+ probs = F.softmax(logits, dim=-1)
1862
+ limit = torch.pow(torch.max(probs), min_p_pow) * min_p_ratio
1863
+ logits[probs < limit] = float('-inf')
1864
+ logits[probs >= limit] = 1
1865
+ return logits
1866
+
1867
+ # autoregressive wrapper class
1868
+
1869
+ class AutoregressiveWrapper(nn.Module):
1870
+ def __init__(
1871
+ self,
1872
+ net,
1873
+ ignore_index = -100,
1874
+ pad_value = 0,
1875
+ mask_prob = 0.
1876
+ ):
1877
+ super().__init__()
1878
+ self.pad_value = pad_value
1879
+ self.ignore_index = ignore_index
1880
+
1881
+ self.net = net
1882
+ self.max_seq_len = net.max_seq_len
1883
+
1884
+ # paper shows masking (MLM) in conjunction with autoregressive decoder-only training leads to big improvements https://arxiv.org/abs/2210.13432
1885
+ assert mask_prob < 1.
1886
+ self.mask_prob = mask_prob
1887
+
1888
+ @torch.no_grad()
1889
+ @eval_decorator
1890
+ def generate(
1891
+ self,
1892
+ start_tokens,
1893
+ seq_len,
1894
+ eos_token = None,
1895
+ temperature = 1.,
1896
+ filter_logits_fn = top_k,
1897
+ filter_thres = 0.9,
1898
+ min_p_pow = 2.0,
1899
+ min_p_ratio = 0.02,
1900
+ verbose=True,
1901
+ return_prime=False,
1902
+ **kwargs
1903
+ ):
1904
+ device = start_tokens.device
1905
+ num_dims = start_tokens.ndim
1906
+
1907
+ start_tokens, ps = pack([start_tokens], '* n')
1908
+
1909
+ b, t = start_tokens.shape
1910
+
1911
+ out = start_tokens
1912
+
1913
+ if verbose:
1914
+ print("Generating sequence of max length:", seq_len)
1915
+
1916
+ for s in range(seq_len):
1917
+ x = out[:, -self.max_seq_len:]
1918
+
1919
+ logits = self.net(x, **kwargs)[:, -1]
1920
+
1921
+ if filter_logits_fn in {top_k, top_p}:
1922
+ filtered_logits = filter_logits_fn(logits, thres = filter_thres)
1923
+ probs = F.softmax(filtered_logits / temperature, dim=-1)
1924
+
1925
+ elif filter_logits_fn is top_a:
1926
+ filtered_logits = filter_logits_fn(logits, min_p_pow = min_p_pow, min_p_ratio= min_p_ratio)
1927
+ probs = F.softmax(filtered_logits / temperature, dim=-1)
1928
+
1929
+ sample = torch.multinomial(probs, 1)
1930
+
1931
+ out = torch.cat((out, sample), dim=-1)
1932
+
1933
+ if verbose:
1934
+ if s % 32 == 0:
1935
+ print(s, '/', seq_len)
1936
+
1937
+ if exists(eos_token):
1938
+ is_eos_tokens = (out == eos_token)
1939
+
1940
+ if is_eos_tokens.any(dim = -1).all():
1941
+ # mask out everything after the eos tokens
1942
+ shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
1943
+ mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
1944
+ out = out.masked_fill(mask, self.pad_value)
1945
+
1946
+ if verbose:
1947
+ print('Model called the end of sequence at:', s, '/', seq_len)
1948
+
1949
+ break
1950
+
1951
+ if return_prime:
1952
+ return out[:, :]
1953
+
1954
+ else:
1955
+ return out[:, t:]
1956
+
1957
+ out, = unpack(out, ps, '* n')
1958
+
1959
+ return out
1960
+
1961
+ def compute_accuracy(self, logits, labels):
1962
+ out = torch.argmax(logits, dim=-1)
1963
+ out = out.flatten()
1964
+ labels = labels.flatten()
1965
+
1966
+ mask = (labels != 999999) # dummy pad value / supposed to be self.token_pad / will fix later
1967
+ out = out[mask]
1968
+ labels = labels[mask]
1969
+
1970
+ num_right = (out == labels)
1971
+ num_right = torch.sum(num_right).type(torch.float32)
1972
+
1973
+ acc = num_right / len(labels)
1974
+ return acc
1975
+
1976
+ def forward(self, x, labels = None, **kwargs):
1977
+ seq, ignore_index = x.shape[1], self.ignore_index
1978
+
1979
+ inp, target = x[:, :-1], x[:, 1:]
1980
+
1981
+ if self.mask_prob > 0.:
1982
+ rand = torch.randn(inp.shape, device = x.device)
1983
+ rand[:, 0] = -torch.finfo(rand.dtype).max # first token should not be masked out
1984
+ num_mask = min(int(seq * self.mask_prob), seq - 1)
1985
+ indices = rand.topk(num_mask, dim = -1).indices
1986
+ mask = ~torch.zeros_like(inp).scatter(1, indices, 1.).bool()
1987
+ kwargs.update(self_attn_context_mask = mask)
1988
+
1989
+ logits = self.net(inp, **kwargs)
1990
+
1991
+ acc = self.compute_accuracy(logits, target)
1992
+
1993
+ loss = F.cross_entropy(
1994
+ rearrange(logits, 'b n c -> b c n'),
1995
+ target,
1996
+ ignore_index = ignore_index
1997
+ )
1998
+
1999
+ return loss, acc
2000
+
2001
+ #===================================================================================================================