asigalov61 committed on
Commit
938c807
1 Parent(s): d2983d5

Upload 2 files

Files changed (2)
  1. midi_to_colab_audio.py +0 -0
  2. x_transformer_1_23_2.py +2474 -0
midi_to_colab_audio.py ADDED
The diff for this file is too large to render. See raw diff
 
x_transformer_1_23_2.py ADDED
@@ -0,0 +1,2474 @@
+ #===================================================================================================================
+ #
+ # X Transformer Module
+ #
+ # Partial x-transformers code with useful modifications
+ #
+ # Version 1.0
+ #
+ # Original source code courtesy of lucidrains
+ # https://github.com/lucidrains/x-transformers
+ #
+ # Original source code retrieved on 10/10/2023
+ #
+ # Project Los Angeles
+ # Tegridy Code 2023
+
+ #===================================================================================================================
+
+ # Critical dependencies
+ #
+ # !pip install torch
+ # !pip install einops
+ # !pip install packaging
+
+ #===================================================================================================================
+
+ from functools import partial
+ from typing import Optional, Tuple
+
+ import os
+ os.environ['USE_FLASH_ATTENTION'] = '1'
+
+ import torch
+ from torch import nn, einsum, Tensor
+ import torch.nn.functional as F
+
+ # Flash attention
+ from torch.nn.attention import SDPBackend, sdpa_kernel
+ torch.backends.cuda.enable_flash_sdp(True)
+
+ from collections import namedtuple
+ from functools import wraps
+ from packaging import version
+ from dataclasses import dataclass
+
+ from einops import rearrange, repeat
+
+ # constants
+
+ EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
+
+ @dataclass
+ class Intermediates:
+     qk_similarities: Optional[Tensor] = None
+     pre_softmax_attn: Optional[Tensor] = None
+     post_softmax_attn: Optional[Tensor] = None
+     cached_kv: Optional[Tuple[Tensor, Tensor]] = None
+
+     def to_tuple(self):
+         return (self.qk_similarities, self.pre_softmax_attn, self.post_softmax_attn)
+
+ # helpers
+
+ def exists(val):
+     return val is not None
+
+ def default(val, d):
+     return val if exists(val) else d
+
+ def compact(arr):
+     return [*filter(exists, arr)]
+
+ def once(fn):
+     called = False
+     @wraps(fn)
+     def inner(x):
+         nonlocal called
+         if called:
+             return
+         called = True
+         return fn(x)
+     return inner
+
+ print_once = once(print)
+
+ # functions for creating causal mask
+ # need a special one for onnx cpu (no support for .triu)
+
+ def create_causal_mask(i, j, device):
+     return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)
+
+ def onnx_create_causal_mask(i, j, device):
+     r = torch.arange(i, device = device)
+     causal_mask = rearrange(r, 'i -> i 1') < rearrange(r, 'j -> 1 j')
+     causal_mask = F.pad(causal_mask, (j - i, 0), value = False)
+     return causal_mask
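The two mask builders above are meant to produce the same right-aligned causal pattern, where `True` marks a key position that a query may not attend to. The following is a minimal pure-Python sketch of that pattern, assuming `k_len >= q_len`; the helper name `causal_mask_rows` is hypothetical and is not part of the committed file.

```python
def causal_mask_rows(q_len, k_len):
    """Return a q_len x k_len grid of booleans; True marks a blocked (future) key.

    Hypothetical illustration helper, assuming k_len >= q_len.
    """
    # Number of already-cached keys that precede the first query position.
    offset = k_len - q_len
    return [[col > row + offset for col in range(k_len)] for row in range(q_len)]


# Square case: a plain upper-triangular mask above the diagonal.
square = causal_mask_rows(3, 3)

# Cached case: 2 new queries against 4 keys, so the mask is right-aligned.
cached = causal_mask_rows(2, 4)

for row in cached:
    print(row)
```

With two queries and four keys, only the last key is hidden from the first query, which is the situation the kv-cache branch of `flash_attn` handles.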
+
+ # main class
+
+ class Attend(nn.Module):
+     def __init__(
+         self,
+         *,
+         dropout = 0.,
+         causal = False,
+         heads = None,
+         talking_heads = False,
+         sparse_topk = None,
+         scale = None,
+         qk_norm = False,
+         flash = False,
+         add_zero_kv = False,
+         onnxable = False
+     ):
+         super().__init__()
+         self.scale = scale
+         self.qk_norm = qk_norm
+
+         self.causal = causal
+         self.create_causal_mask = onnx_create_causal_mask if onnxable else create_causal_mask
+
+         self.attn_fn = partial(F.softmax, dtype = torch.float32) if not qk_norm else F.softmax
+
+         self.dropout = dropout
+         self.attn_dropout = nn.Dropout(dropout)
+
+         # talking heads
+
+         assert not (flash and talking_heads), 'talking heads not compatible with flash attention'
+
+         self.talking_heads = talking_heads
+         if talking_heads:
+             self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
+             self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
+
+         # sparse topk
+
+         assert not (flash and sparse_topk), 'sparse topk not compatible with flash attention'
+         self.sparse_topk = sparse_topk
+
+         # add a key / value token composed of zeros
+         # in case this helps controlling outliers, proposed by https://www.evanmiller.org/attention-is-off-by-one.html
+
+         self.add_zero_kv = add_zero_kv
+
+         # flash attention
+
+         self.flash = flash
+         assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
+
+         # determine efficient attention configs for cuda and cpu
+
+         self.cpu_config = EfficientAttentionConfig(True, True, True)
+         self.cuda_config = None
+
+         if not torch.cuda.is_available() or not flash:
+             return
+
+         device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
+
+         major, minor = device_properties.major, device_properties.minor
+
+         if (major, minor) == (8, 0):
+             print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
+             self.cuda_config = EfficientAttentionConfig(True, False, False)
+         elif (major, minor) == (9, 0):
+             print_once('H100 GPU detected, using flash attention')
+             self.cuda_config = EfficientAttentionConfig(True, False, False)
+         else:
+             print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
+             self.cuda_config = EfficientAttentionConfig(False, True, True)
+
+     def flash_attn(
+         self,
+         q, k, v,
+         mask = None,
+         attn_bias = None
+     ):
+         batch, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
+
+         # Recommended for multi-query single-key-value attention by Tri Dao
+         # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
+
+         if k.ndim == 3:
+             k = rearrange(k, 'b ... -> b 1 ...').expand_as(q)
+
+         if v.ndim == 3:
+             v = rearrange(v, 'b ... -> b 1 ...').expand_as(q)
+
+         # handle scale - by default they scale by dim_head ** -0.5, but need to take care if using cosine sim attention
+
+         if self.qk_norm:
+             default_scale = q.shape[-1] ** -0.5
+             q = q * (self.scale / default_scale)
+
+         # Check if mask exists and expand to compatible shape
+         # The mask is B L, so it would have to be expanded to B H N L
+
+         causal = self.causal
+
+         # in the case of kv caching with one token (q_len == 1), just turn off causal masking
+         # in speculative decoding, this may go up to 5-6, so right aligned causal mask will be needed there
+
+         if q_len == 1 and causal:
+             causal = False
+
+         # expand key padding mask
+
+         if exists(mask):
+             assert mask.ndim == 4
+             mask = mask.expand(batch, heads, q_len, k_len)
+
+         # handle kv cache - this should be bypassable in updated flash attention 2
+
+         if k_len > q_len and causal:
+             causal_mask = self.create_causal_mask(q_len, k_len, device = device)
+             if not exists(mask):
+                 mask = ~causal_mask
+             else:
+                 mask = mask & ~causal_mask
+             causal = False
+
+         # manually handle causal mask, if another mask was given
+
+         row_is_entirely_masked = None
+
+         if exists(mask) and causal:
+             causal_mask = self.create_causal_mask(q_len, k_len, device = device)
+             mask = mask & ~causal_mask
+
+             # protect against an entire row being masked out
+
+             row_is_entirely_masked = ~mask.any(dim = -1)
+             mask[..., 0] = mask[..., 0] | row_is_entirely_masked
+
+             causal = False
+
+         # handle alibi positional bias
+         # convert from bool to float
+
+         if exists(attn_bias):
+             attn_bias = rearrange(attn_bias, 'h i j -> 1 h i j').expand(batch, heads, -1, -1)
+
+             # if mask given, the mask would already contain the causal mask from above logic
+             # otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number
+
+             mask_value = -torch.finfo(q.dtype).max
+
+             if exists(mask):
+                 attn_bias = attn_bias.masked_fill(~mask, mask_value // 2)
+             elif causal:
+                 causal_mask = self.create_causal_mask(q_len, k_len, device = device)
+                 attn_bias = attn_bias.masked_fill(causal_mask, mask_value // 2)
+                 causal = False
+
+             # scaled_dot_product_attention handles attn_mask either as bool or additive bias
+             # make it an additive bias here
+
+             mask = attn_bias
+
+         # Check if there is a compatible device for flash attention
+
+         config = self.cuda_config if is_cuda else self.cpu_config
+
+         # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
+
+         # Legacy code...
+         # with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=True):
+         # with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
+
+         # PyTorch 2.3-2.4 SDPA backend code...
+         with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION]):
+
+             # New PyTorch 2.5 SDPA backend code:
+             # with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
+
+             out = F.scaled_dot_product_attention(
+                 q, k, v,
+                 attn_mask = mask,
+                 dropout_p = self.dropout if self.training else 0.,
+                 is_causal = causal
+             )
+
+         # for a row that is entirely masked out, should zero out the output of that row token
+
+         if exists(row_is_entirely_masked):
+             out = out.masked_fill(row_is_entirely_masked[..., None], 0.)
+
+         return out, Intermediates()
+
+     def forward(
+         self,
+         q, k, v,
+         mask = None,
+         attn_bias = None,
+         prev_attn = None
+     ):
+         """
+         einstein notation
+         b - batch
+         h - heads
+         n, i, j - sequence length (base sequence length, source, target)
+         d - feature dimension
+         """
+
+         n, heads, kv_heads, device = q.shape[-2], q.shape[1], k.shape[1], q.device
+
+         scale = default(self.scale, q.shape[-1] ** -0.5)
+
+         causal = self.causal
+
+         # handle kv cached decoding
+
+         if n == 1 and causal:
+             causal = False
+
+         # handle grouped multi-query attention
+
+         if kv_heads == 1:
+             k, v = map(lambda t: rearrange(t, 'b 1 n d -> b n d'), (k, v))
+         elif kv_heads < heads:
+             k, v = map(lambda t: repeat(t, 'b kvh n d -> b (r kvh) n d', r = heads // kv_heads), (k, v))
+
+         # handle zero kv, as means for allowing network to attend to nothing
+
+         if self.add_zero_kv:
+             k, v = map(lambda t: F.pad(t, (0, 0, 1, 0), value = 0.), (k, v))
+
+             if exists(mask):
+                 mask = F.pad(mask, (1, 0), value = True)
+
+             if exists(attn_bias):
+                 attn_bias = F.pad(attn_bias, (1, 0), value = 0.)
+
+         if self.flash:
+             assert not exists(prev_attn), 'residual attention not compatible with flash attention'
+             return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias)
+
+         kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
+
+         dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale
+
+         if exists(prev_attn):
+             dots = dots + prev_attn
+
+         qk_similarities = dots.clone()
+
+         if self.talking_heads:
+             dots = self.pre_softmax_talking_heads(dots)
+
+         if exists(attn_bias):
+             dots = dots + attn_bias
+
+         i, j, dtype = *dots.shape[-2:], dots.dtype
+
+         mask_value = -torch.finfo(dots.dtype).max
+
+         if exists(self.sparse_topk) and self.sparse_topk < j:
+             top_values, _ = dots.topk(self.sparse_topk, dim = -1)
+             # True marks the top-k scores to keep, matching the mask convention used below (True = attend)
+             sparse_topk_mask = dots >= top_values[..., -1:]
+             mask = (mask & sparse_topk_mask) if exists(mask) else sparse_topk_mask
+
+         if exists(mask):
+             dots = dots.masked_fill(~mask, mask_value)
+
+         if causal:
+             causal_mask = self.create_causal_mask(i, j, device = device)
+             dots = dots.masked_fill(causal_mask, mask_value)
+
+         pre_softmax_attn = dots.clone()
+
+         attn = self.attn_fn(dots, dim = -1)
+         attn = attn.type(dtype)
+
+         post_softmax_attn = attn.clone()
+
+         attn = self.attn_dropout(attn)
+
+         if self.talking_heads:
+             attn = self.post_softmax_talking_heads(attn)
+
+         out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)
+
+         intermediates = Intermediates(
+             qk_similarities = qk_similarities,
+             pre_softmax_attn = pre_softmax_attn,
+             post_softmax_attn = post_softmax_attn
+         )
+
+         return out, intermediates
+
+ #===================================================================================================================
+
+ from math import ceil, log
+ from typing import Optional, Union, Tuple, Callable
+
+ import torch
+ from torch import nn, Tensor
+ from torch.nn import Module
+ import torch.nn.functional as F
+
+ from einops import rearrange, pack, unpack
+
+ def exists(val):
+     return val is not None
+
+ def default(val, d):
+     return val if exists(val) else d
+
+ def identity(t, *args, **kwargs):
+     return t
+
+ def cast_tuple(t, length = 1):
+     return t if isinstance(t, tuple) else (t,) * length
+
+ def eval_decorator(fn):
+     def inner(self, *args, **kwargs):
+         was_training = self.training
+         self.eval()
+         out = fn(self, *args, **kwargs)
+         self.train(was_training)
+         return out
+     return inner
+
+ # for variable-length prefixes
+
+ def align_right(t, lens, pad_id = 0):
+     batch, seq_len, device, dtype = *t.shape, t.device, t.dtype
+
+     assert lens.ndim == 1 and lens.shape[0] == batch
+     assert lens.amax() <= seq_len
+
+     pad_lens = seq_len - lens
+     max_pad_len = pad_lens.amax()
+
+     batch_arange = torch.arange(batch, device = device, dtype = torch.long)[..., None]
+     prompt_len_arange = torch.arange(seq_len, device = device, dtype = torch.long)
+
+     t = F.pad(t, (max_pad_len, 0), value = 0)
+     offset = max_pad_len - pad_lens
+
+     aligned = t[batch_arange, prompt_len_arange + offset[..., None]]
+     return aligned
+
+ # nucleus
+
+ def top_p(logits, thres = 0.9):
+     sorted_logits, sorted_indices = torch.sort(logits, descending = True)
+     cum_probs = torch.cumsum(F.softmax(sorted_logits, dim = -1), dim = -1)
+
+     sorted_indices_to_remove = cum_probs > thres
+     sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, -1), value = False)
+
+     sorted_logits[sorted_indices_to_remove] = float('-inf')
+     return sorted_logits.scatter(1, sorted_indices, sorted_logits)
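The nucleus filter above keeps the most probable tokens until their cumulative probability first exceeds `thres`, including the token that crosses the threshold, and sets every other logit to negative infinity. Below is a minimal pure-Python sketch of the same rule for a single row of logits. It uses plain lists rather than tensors, and the name `nucleus_filter` is hypothetical, not part of the committed file.

```python
import math


def nucleus_filter(logits, thres=0.9):
    """Keep the smallest set of top tokens whose cumulative probability exceeds thres.

    Hypothetical single-row illustration using plain Python lists.
    """
    # Token indices ordered from most to least likely.
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)

    # Numerically stable softmax weights in that order.
    peak = logits[order[0]]
    weights = [math.exp(logits[i] - peak) for i in order]
    total = sum(weights)

    filtered = [float('-inf')] * len(logits)
    cumulative = 0.0
    for idx, weight in zip(order, weights):
        filtered[idx] = logits[idx]      # keep this token
        cumulative += weight / total
        if cumulative > thres:           # the crossing token is kept, the rest are dropped
            break
    return filtered


# Probabilities 0.5, 0.3, 0.15, 0.05 with thres = 0.7: only the first two tokens survive.
example = nucleus_filter([math.log(p) for p in (0.5, 0.3, 0.15, 0.05)], thres=0.7)
print(example)
```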
+
+ # topk
+
+ def top_k(logits, frac_num_tokens = 0.1, k = None):
+     num_tokens = logits.shape[-1]
+
+     k = default(k, ceil(frac_num_tokens * num_tokens))
+     k = min(k, num_tokens)
+
+     val, ind = torch.topk(logits, k)
+     probs = torch.full_like(logits, float('-inf'))
+     probs.scatter_(1, ind, val)
+     return probs
+
+ # top_a
+
+ def top_a(logits, min_p_pow = 2.0, min_p_ratio = 0.02):
+     probs = F.softmax(logits, dim = -1)
+     max_probs = torch.amax(probs, dim = -1, keepdim = True)
+     limit = torch.pow(max_probs, min_p_pow) * min_p_ratio
+     return torch.where(probs < limit, float('-inf'), logits)
+
+ # contrastive decoding function
+
+ def contrastive_decode_fn(
+     expert_logits,
+     amateur_logits,
+     alpha = 0.1,
+     beta = 0.5
+ ):
+     """
+     Appendix A Algorithm 2
+     https://arxiv.org/abs/2309.09117
+     """
+
+     cutoff = log(alpha) + expert_logits.amax(dim = -1, keepdim = True)
+     diffs = (1 + beta) * expert_logits - beta * amateur_logits
+     contrastive_decode_logits = diffs.masked_fill(expert_logits < cutoff, -torch.finfo(expert_logits.dtype).max)
+     return contrastive_decode_logits
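The arithmetic in `contrastive_decode_fn` can be followed by hand on a small example. The sketch below restates it in pure Python for one row of logits. It is an illustration only: the name `contrastive_scores` is hypothetical, and it uses negative infinity where the tensor version fills with the most negative finite value of the dtype.

```python
import math


def contrastive_scores(expert, amateur, alpha=0.1, beta=0.5):
    """Score tokens by (1 + beta) * expert - beta * amateur, dropping implausible tokens.

    Hypothetical single-row illustration using plain Python lists.
    """
    # Tokens whose expert logit falls below log(alpha) + max(expert) are excluded.
    cutoff = math.log(alpha) + max(expert)
    return [
        (1 + beta) * e - beta * a if e >= cutoff else float('-inf')
        for e, a in zip(expert, amateur)
    ]


# The amateur strongly prefers token 0, so token 1 ends up with the higher contrastive score.
scores = contrastive_scores([2.0, 1.5, -3.0], [3.0, 0.0, -3.0])
print(scores)
```

In this example the expert alone would pick token 0, but subtracting the amateur's preference moves token 1 ahead, while token 2 is removed by the plausibility cutoff.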
+
+ # autoregressive wrapper class
+
+ class AutoregressiveWrapper(Module):
+     def __init__(
+         self,
+         net,
+         ignore_index = -100,
+         pad_value = 0,
+         mask_prob = 0.,
+         add_attn_z_loss = False
+     ):
+         super().__init__()
+         self.pad_value = pad_value
+         self.ignore_index = ignore_index
+
+         self.net = net
+         self.max_seq_len = net.max_seq_len
+
+         # paper shows masking (MLM) in conjunction with autoregressive decoder-only training leads to big improvements https://arxiv.org/abs/2210.13432
+         assert mask_prob < 1.
+         self.mask_prob = mask_prob
+
+         # whether to add router z-loss
+         self.add_attn_z_loss = add_attn_z_loss
+
+     @torch.inference_mode()
+     @eval_decorator
+     def generate(
+         self,
+         prompts,
+         seq_len,
+         eos_token = None,
+         temperature = 1.,
+         prompt_lens: Optional[Tensor] = None,
+         filter_logits_fn: Callable = top_k,
+         restrict_to_max_seq_len = True,
+         amateur_model: Optional[Union[Module, Tuple[Module]]] = None,
+         filter_kwargs: dict = dict(),
+         contrastive_decode_kwargs: Union[dict, Tuple[dict]] = dict(
+             beta = 0.5,
+             alpha = 0.1
+         ),
+         cache_kv = True,
+         verbose=True,
+         return_prime=False,
+         **kwargs
+     ):
+         max_seq_len, device = self.max_seq_len, prompts.device
+
+         prompts, ps = pack([prompts], '* n')
+
+         b, t = prompts.shape
+
+         # handle variable-length prompts (prefixes)
+
+         seq_start_pos = None
+         if exists(prompt_lens):
+             prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value)
+             seq_start_pos = t - prompt_lens
+
+         # output to which sampled tokens are appended
+
+         out = prompts
+
+         if verbose:
+             print("Generating sequence of max length:", seq_len)
+
+         # kv caches
+
+         cache = None
+
+         # if doing contrastive decoding, turn off filter automatically
+
+         if exists(amateur_model):
+             # use a list so a wrapped model can be replaced by its inner net below
+             amateur_model = list(cast_tuple(amateur_model))
+             contrastive_decode_kwargs = cast_tuple(contrastive_decode_kwargs)
+
+             assert len(amateur_model) == len(contrastive_decode_kwargs)
+
+             amateur_caches = [None] * len(amateur_model)
+             filter_logits_fn = identity
+
+             for i, module in enumerate(amateur_model):
+                 if isinstance(module, AutoregressiveWrapper):
+                     amateur_model[i] = module.net
+
+                 module.eval()
+
+         # sampling up to seq_len
+
+         for sl in range(seq_len):
+
+             # feed the full sequence unless it has to be cropped to the model's context window
+             x = out
+
+             if restrict_to_max_seq_len:
+                 x = out[:, -max_seq_len:]
+
+                 if exists(cache):
+                     for inter in cache.attn_intermediates:
+                         inter.cached_kv = [t[..., -(max_seq_len - 1):, :] for t in inter.cached_kv]
+
+             logits, new_cache = self.net(
+                 x,
+                 return_intermediates = True,
+                 cache = cache,
+                 seq_start_pos = seq_start_pos,
+                 **kwargs
+             )
+
+             if cache_kv and self.net.can_cache_kv:
+                 cache = new_cache
+
+             logits = logits[:, -1]
+
+             # handle contrastive decoding, Li et al.
+             # https://arxiv.org/abs/2210.15097
+
+             if exists(amateur_model):
+                 for i, (amateur, amateur_cache, amateur_contrastive_decode_kwargs) in enumerate(zip(amateur_model, amateur_caches, contrastive_decode_kwargs)):
+                     amateur_logits, next_amateur_cache = amateur(
+                         x,
+                         return_intermediates = True,
+                         cache = amateur_cache,
+                         seq_start_pos = seq_start_pos,
+                         **kwargs
+                     )
+
+                     amateur_logits = amateur_logits[:, -1]
+
+                     assert amateur_logits.shape == logits.shape, 'logits dimension are not the same between amateur and expert model'
+                     logits = contrastive_decode_fn(logits, amateur_logits, **amateur_contrastive_decode_kwargs)
+
+                     if cache_kv and amateur.can_cache_kv:
+                         amateur_caches[i] = next_amateur_cache
+
+             # filter by top_k, top_p (nucleus), top_a, or custom
+
+             filtered_logits = filter_logits_fn(logits, **filter_kwargs)
+
+             probs = F.softmax(filtered_logits / temperature, dim=-1)
+
+             sample = torch.multinomial(probs, 1)
+
+             out = torch.cat((out, sample), dim=-1)
+
+             if verbose:
+                 if sl % 32 == 0:
+                     print(sl, '/', seq_len)
+
+             if exists(eos_token):
+                 is_eos_tokens = (out == eos_token)
+
+                 if is_eos_tokens.any(dim = -1).all():
+                     # mask out everything after the eos tokens
+                     shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
+                     mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
+                     out = out.masked_fill(mask, self.pad_value)
+
+                     if verbose:
+                         print('Model called the end of sequence at:', sl, '/', seq_len)
+
+                     break
+
+         if return_prime:
+             return out[:, :]
+
+         else:
+             return out[:, t:]
+
+         # out, = unpack(out, ps, '* n')
+
+         # return out
+
+     def compute_accuracy(self, logits, labels):
+         out = torch.argmax(logits, dim=-1)
+         out = out.flatten()
+         labels = labels.flatten()
+
+         mask = (labels != self.ignore_index) # can also be self.pad_value (your choice)
+         out = out[mask]
+         labels = labels[mask]
+
+         num_right = (out == labels)
+         num_right = torch.sum(num_right).type(torch.float32)
+
+         acc = num_right / len(labels)
+         return acc
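`compute_accuracy` scores only the positions whose label differs from `ignore_index`. A minimal pure-Python sketch of that masking step follows, assuming the predictions have already been reduced to token ids. The name `masked_accuracy` is hypothetical, and the early return for an all-ignored batch is an added assumption: the method above has no such guard and divides by the number of kept labels directly.

```python
def masked_accuracy(predictions, labels, ignore_index=-100):
    """Fraction of correct predictions among positions whose label is not ignore_index.

    Hypothetical illustration; returns 0.0 when every label is ignored (an assumption,
    not behaviour taken from the committed code).
    """
    kept = [(p, y) for p, y in zip(predictions, labels) if y != ignore_index]
    if not kept:
        return 0.0
    return sum(p == y for p, y in kept) / len(kept)


# Three positions are scored (the last label is ignored); two of them are correct.
acc = masked_accuracy([1, 2, 3, 4], [1, 0, 3, -100])
print(acc)
```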
+
+     def forward(self, x, **kwargs):
+         seq, ignore_index, add_attn_z_loss = x.shape[1], self.ignore_index, self.add_attn_z_loss
+
+         inp, target = x[:, :-1], x[:, 1:]
+         inp = torch.where(inp == ignore_index, self.pad_value, inp)
+
+         if self.mask_prob > 0.:
+             rand = torch.randn(inp.shape, device = x.device)
+             rand[:, 0] = -torch.finfo(rand.dtype).max # first token should not be masked out
+             num_mask = min(int(seq * self.mask_prob), seq - 1)
+             indices = rand.topk(num_mask, dim = -1).indices
+             mask = ~torch.zeros_like(inp).scatter(1, indices, 1.).bool()
+             kwargs.update(self_attn_kv_mask = mask)
+
+         logits, cache = self.net(
+             inp,
+             return_intermediates = True,
+             return_attn_z_loss = add_attn_z_loss,
+             **kwargs
+         )
+
+         acc = self.compute_accuracy(logits, target)
+
+         loss = F.cross_entropy(
+             rearrange(logits, 'b n c -> b c n'),
+             target,
+             ignore_index = ignore_index
+         )
+
+         if add_attn_z_loss:
+             loss = loss + cache.attn_z_loss
+
+         return loss, acc
+
+ #===============================================================================
+
+ import math
+ from random import random
+
+ import torch
+ from torch import nn, einsum, Tensor
+ import torch.nn.functional as F
+
+ from functools import partial, wraps
+ from inspect import isfunction
+ from collections import namedtuple
+ from dataclasses import dataclass
+ from typing import List, Callable, Optional
+
+ from einops import rearrange, repeat, reduce, pack, unpack
+ from einops.layers.torch import Rearrange
+
+ # constants
+
+ DEFAULT_DIM_HEAD = 64
+
+ @dataclass
+ class LayerIntermediates:
+     hiddens: Optional[List[Tensor]] = None
+     attn_intermediates: Optional[List[Intermediates]] = None
+     layer_hiddens: Optional[List[Tensor]] = None
+     attn_z_loss: Optional[Tensor] = None
+     mems: Optional[Tensor] = None
+
+ # helpers
+
+ def exists(val):
+     return val is not None
+
+ def default(val, d):
+     if exists(val):
+         return val
+     return d() if isfunction(d) else d
+
+ def cast_tuple(val, depth):
+     return val if isinstance(val, tuple) else (val,) * depth
+
+ def divisible_by(num, den):
+     return (num % den) == 0
+
+ def maybe(fn):
+     @wraps(fn)
+     def inner(x, *args, **kwargs):
+         if not exists(x):
+             return x
+         return fn(x, *args, **kwargs)
+     return inner
+
+ class always():
+     def __init__(self, val):
+         self.val = val
+     def __call__(self, *args, **kwargs):
+         return self.val
+
+ class not_equals():
+     def __init__(self, val):
+         self.val = val
+     def __call__(self, x, *args, **kwargs):
+         return x != self.val
+
+ class equals():
+     def __init__(self, val):
+         self.val = val
+     def __call__(self, x, *args, **kwargs):
+         return x == self.val
+
+ def Sequential(*modules):
+     return nn.Sequential(*filter(exists, modules))
+
+ # tensor helpers
+
+ def max_neg_value(tensor):
+     return -torch.finfo(tensor.dtype).max
+
+ def l2norm(t, groups = 1):
+     t = rearrange(t, '... (g d) -> ... g d', g = groups)
+     t = F.normalize(t, p = 2, dim = -1)
+     return rearrange(t, '... g d -> ... (g d)')
+
+ def pad_at_dim(t, pad, dim = -1, value = 0.):
+     dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
+     zeros = ((0, 0) * dims_from_right)
+     return F.pad(t, (*zeros, *pad), value = value)
+
+ def or_reduce(masks):
+     head, *body = masks
+     for rest in body:
+         head = head | rest
+     return head
+
+ # auxiliary loss helpers
+
+ def calc_z_loss(
+     pre_softmax_attns: List[Tensor],
+     mask = None,
+     weight = 1.
+ ):
+     # the same loss applied to the mixture of experts router logits in https://arxiv.org/abs/2202.08906
+     # in the paper, in a tiny footnote, they mention using it on attention logits with stabilizing effects
+     # also used in PaLM as one of the measures
+
+     lse = 0.
+
+     for attn in pre_softmax_attns:
+         lse = lse + attn.logsumexp(dim = -1)
+
+     loss = torch.square(lse)
+     loss = reduce(loss, 'b h n -> b n', 'sum')
+
+     if not exists(mask):
+         return loss.mean() * weight
+
+     loss = loss[mask].sum() / mask.sum().clamp(min = 1e-5)
+     return loss * weight
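The core quantity in `calc_z_loss` is the squared log-sum-exp of the pre-softmax attention logits, which grows as the logits drift to large magnitudes. The sketch below computes it in pure Python for one row of logits from a single layer. The name `z_loss_row` is hypothetical, and the sketch leaves out the accumulation over layers, the sum over heads, and the optional mask that the function above applies.

```python
import math


def z_loss_row(logits):
    """Squared log-sum-exp of one row of attention logits.

    Hypothetical single-row, single-layer illustration.
    """
    # Subtract the maximum before exponentiating for numerical stability.
    peak = max(logits)
    lse = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return lse ** 2


small = z_loss_row([0.0, 0.0])      # log(2) squared
large = z_loss_row([10.0, 10.0])    # (10 + log(2)) squared
print(small, large)
```

Shifting every logit upward leaves the softmax output unchanged but raises this penalty, which is why it discourages logit growth.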
+
+ # init helpers
+
+ def init_zero_(layer):
+     nn.init.constant_(layer.weight, 0.)
+     if exists(layer.bias):
+         nn.init.constant_(layer.bias, 0.)
+
+ # keyword argument helpers
+
+ def pick_and_pop(keys, d):
+     values = list(map(lambda key: d.pop(key), keys))
+     return dict(zip(keys, values))
+
+ def group_dict_by_key(cond, d):
+     return_val = [dict(),dict()]
+     for key in d.keys():
+         match = bool(cond(key))
+         ind = int(not match)
+         return_val[ind][key] = d[key]
+     return (*return_val,)
+
+ def string_begins_with(prefix, str):
+     return str.startswith(prefix)
+
+ def group_by_key_prefix(prefix, d):
+     return group_dict_by_key(partial(string_begins_with, prefix), d)
+
+ def groupby_prefix_and_trim(prefix, d):
+     kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
+     kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
+     return kwargs_without_prefix, kwargs
867
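The keyword helpers above let `AttentionLayers` accept flat keyword arguments such as `attn_flash` or `ff_glu` and route them to the right submodule. A minimal plain-Python sketch of the same split-and-trim idea follows. The function name `split_by_prefix` and the sample keys are illustrative only.

```python
def split_by_prefix(prefix, d):
    """Split d into (entries whose key starts with prefix, prefix removed) and (the rest)."""
    with_prefix = {k[len(prefix):]: v for k, v in d.items() if k.startswith(prefix)}
    rest = {k: v for k, v in d.items() if not k.startswith(prefix)}
    return with_prefix, rest

kwargs = dict(attn_flash=True, attn_dropout=0.1, ff_glu=True, heads=8)

# peel off the feedforward options first, then the attention options
ff_kwargs, kwargs = split_by_prefix('ff_', kwargs)
attn_kwargs, kwargs = split_by_prefix('attn_', kwargs)

print(ff_kwargs, attn_kwargs, kwargs)
```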
+
868
+ # structured dropout, more effective than traditional attention dropouts
869
+
870
+ def dropout_seq(seq, mask, dropout):
871
+ b, n, *_, device = *seq.shape, seq.device
872
+ logits = torch.randn(b, n, device = device)
873
+
874
+ if exists(mask):
875
+ mask_value = max_neg_value(logits)
876
+ logits = logits.masked_fill(~mask, mask_value)
877
+
878
+ keep_prob = 1. - dropout
879
+ num_keep = max(1, int(keep_prob * n))
880
+ keep_indices = logits.topk(num_keep, dim = 1).indices
881
+
882
+ batch_indices = torch.arange(b, device = device)
883
+ batch_indices = rearrange(batch_indices, 'b -> b 1')
884
+
885
+ seq = seq[batch_indices, keep_indices]
886
+
887
+ if exists(mask):
888
+ seq_counts = mask.sum(dim = -1)
889
+ seq_keep_counts = torch.ceil(seq_counts * keep_prob).int()
890
+ keep_mask = torch.arange(num_keep, device = device) < rearrange(seq_keep_counts, 'b -> b 1')
891
+
892
+ mask = mask[batch_indices, keep_indices] & keep_mask
893
+
894
+ return seq, mask
895
+
896
+ # activations
897
+
898
+ class ReluSquared(nn.Module):
899
+ def forward(self, x):
900
+ return F.relu(x) ** 2
901
+
902
+ # embedding
903
+
904
+ class TokenEmbedding(nn.Module):
905
+ def __init__(self, dim, num_tokens, l2norm_embed = False):
906
+ super().__init__()
907
+ self.l2norm_embed = l2norm_embed
908
+ self.emb = nn.Embedding(num_tokens, dim)
909
+
910
+ def forward(self, x):
911
+ token_emb = self.emb(x)
912
+ return l2norm(token_emb) if self.l2norm_embed else token_emb
913
+
914
+ # positional embeddings
915
+
916
+ class AbsolutePositionalEmbedding(nn.Module):
917
+ def __init__(self, dim, max_seq_len, l2norm_embed = False):
918
+ super().__init__()
919
+ self.scale = dim ** -0.5 if not l2norm_embed else 1.
920
+ self.max_seq_len = max_seq_len
921
+ self.l2norm_embed = l2norm_embed
922
+ self.emb = nn.Embedding(max_seq_len, dim)
923
+
924
+ def forward(self, x, pos = None, seq_start_pos = None):
925
+ seq_len, device = x.shape[1], x.device
926
+ assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'
927
+
928
+ if not exists(pos):
929
+ pos = torch.arange(seq_len, device = device)
930
+
931
+ if exists(seq_start_pos):
932
+ pos = (pos - seq_start_pos[..., None]).clamp(min = 0)
933
+
934
+ pos_emb = self.emb(pos)
935
+ pos_emb = pos_emb * self.scale
936
+ return l2norm(pos_emb) if self.l2norm_embed else pos_emb
937
+
938
+ class ScaledSinusoidalEmbedding(nn.Module):
939
+ def __init__(self, dim, theta = 10000):
940
+ super().__init__()
941
+ assert divisible_by(dim, 2)
942
+ self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)
943
+
944
+ half_dim = dim // 2
945
+ freq_seq = torch.arange(half_dim).float() / half_dim
946
+ inv_freq = theta ** -freq_seq
947
+ self.register_buffer('inv_freq', inv_freq, persistent = False)
948
+
949
+ def forward(self, x, pos = None, seq_start_pos = None):
950
+ seq_len, device = x.shape[1], x.device
951
+
952
+ if not exists(pos):
953
+ pos = torch.arange(seq_len, device = device)
954
+
955
+ if exists(seq_start_pos):
956
+ pos = pos - seq_start_pos[..., None]
957
+
958
+ emb = einsum('i, j -> i j', pos, self.inv_freq)
959
+ emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
960
+ return emb * self.scale
961
+
962
+ class RelativePositionBias(nn.Module):
963
+ def __init__(self, scale, causal = False, num_buckets = 32, max_distance = 128, heads = 8):
964
+ super().__init__()
965
+ self.scale = scale
966
+ self.causal = causal
967
+ self.num_buckets = num_buckets
968
+ self.max_distance = max_distance
969
+ self.relative_attention_bias = nn.Embedding(num_buckets, heads)
970
+
971
+ @staticmethod
972
+ def _relative_position_bucket(relative_position, causal = True, num_buckets = 32, max_distance = 128):
973
+ ret = 0
974
+ n = -relative_position
975
+ if not causal:
976
+ num_buckets //= 2
977
+ ret += (n < 0).long() * num_buckets
978
+ n = torch.abs(n)
979
+ else:
980
+ n = torch.max(n, torch.zeros_like(n))
981
+
982
+ max_exact = num_buckets // 2
983
+ is_small = n < max_exact
984
+
985
+ val_if_large = max_exact + (
986
+ torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
987
+ ).long()
988
+ val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
989
+
990
+ ret += torch.where(is_small, n, val_if_large)
991
+ return ret
992
+
993
+ @property
994
+ def device(self):
995
+ return next(self.parameters()).device
996
+
997
+ def forward(self, i, j):
998
+ device = self.device
999
+ q_pos = torch.arange(j - i, j, dtype = torch.long, device = device)
1000
+ k_pos = torch.arange(j, dtype = torch.long, device = device)
1001
+ rel_pos = k_pos[None, :] - q_pos[:, None]
1002
+ rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
1003
+ values = self.relative_attention_bias(rp_bucket)
1004
+ bias = rearrange(values, 'i j h -> h i j')
1005
+ return bias * self.scale
1006
+
1007
+ class DynamicPositionBias(nn.Module):
1008
+ def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):
1009
+ super().__init__()
1010
+ assert depth >= 1, 'depth for dynamic position bias MLP must be greater than or equal to 1'
1011
+ self.log_distance = log_distance
1012
+
1013
+ self.mlp = nn.ModuleList([])
1014
+
1015
+ self.mlp.append(Sequential(
1016
+ nn.Linear(1, dim),
1017
+ nn.LayerNorm(dim) if norm else None,
1018
+ nn.SiLU()
1019
+ ))
1020
+
1021
+ for _ in range(depth - 1):
1022
+ self.mlp.append(Sequential(
1023
+ nn.Linear(dim, dim),
1024
+ nn.LayerNorm(dim) if norm else None,
1025
+ nn.SiLU()
1026
+ ))
1027
+
1028
+ self.mlp.append(nn.Linear(dim, heads))
1029
+
1030
+ @property
1031
+ def device(self):
1032
+ return next(self.parameters()).device
1033
+
1034
+ def forward(self, i, j):
1035
+ assert i == j
1036
+ n, device = j, self.device
1037
+
1038
+ # get the (n x n) matrix of distances
1039
+ seq_arange = torch.arange(n, device = device)
1040
+ context_arange = torch.arange(n, device = device)
1041
+ indices = rearrange(seq_arange, 'i -> i 1') - rearrange(context_arange, 'j -> 1 j')
1042
+ indices += (n - 1)
1043
+
1044
+ # input to continuous positions MLP
1045
+ pos = torch.arange(-n + 1, n, device = device).float()
1046
+ pos = rearrange(pos, '... -> ... 1')
1047
+
1048
+ if self.log_distance:
1049
+ pos = torch.sign(pos) * torch.log(pos.abs() + 1) # log of distance is sign(rel_pos) * log(abs(rel_pos) + 1)
1050
+
1051
+ for layer in self.mlp:
1052
+ pos = layer(pos)
1053
+
1054
+ # get position biases
1055
+ bias = pos[indices]
1056
+ bias = rearrange(bias, 'i j h -> h i j')
1057
+ return bias
1058
+
1059
+ class AlibiPositionalBias(nn.Module):
1060
+ def __init__(self, heads, total_heads, **kwargs):
1061
+ super().__init__()
1062
+ self.heads = heads
1063
+ self.total_heads = total_heads
1064
+
1065
+ slopes = Tensor(self._get_slopes(heads))
1066
+ slopes = rearrange(slopes, 'h -> h 1 1')
1067
+ self.register_buffer('slopes', slopes, persistent = False)
1068
+ self.register_buffer('bias', None, persistent = False)
1069
+
1070
+ def get_bias(self, i, j, device):
1071
+ i_arange = torch.arange(j - i, j, device = device)
1072
+ j_arange = torch.arange(j, device = device)
1073
+ bias = -torch.abs(rearrange(j_arange, 'j -> 1 1 j') - rearrange(i_arange, 'i -> 1 i 1'))
1074
+ return bias
1075
+
1076
+ @staticmethod
1077
+ def _get_slopes(heads):
1078
+ def get_slopes_power_of_2(n):
1079
+ start = (2**(-2**-(math.log2(n)-3)))
1080
+ ratio = start
1081
+ return [start*ratio**i for i in range(n)]
1082
+
1083
+ if math.log2(heads).is_integer():
1084
+ return get_slopes_power_of_2(heads)
1085
+
1086
+ closest_power_of_2 = 2 ** math.floor(math.log2(heads))
1087
+ return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2]
1088
+
1089
+ @property
1090
+ def device(self):
1091
+ return next(self.buffers()).device
1092
+
1093
+ def forward(self, i, j):
1094
+ h, device = self.total_heads, self.device
1095
+
1096
+ if exists(self.bias) and self.bias.shape[-1] >= j and self.bias.shape[-2] >= i:
1097
+ return self.bias[..., -i:, -j:]
1098
+
1099
+ bias = self.get_bias(i, j, device)
1100
+ bias = bias * self.slopes
1101
+
1102
+ num_heads_unalibied = h - bias.shape[0]
1103
+ bias = pad_at_dim(bias, (0, num_heads_unalibied), dim = 0)
1104
+ self.register_buffer('bias', bias, persistent = False)
1105
+
1106
+ return self.bias
1107
+
1108
+ class RotaryEmbedding(nn.Module):
1109
+ def __init__(
1110
+ self,
1111
+ dim,
1112
+ use_xpos = False,
1113
+ scale_base = 512,
1114
+ interpolation_factor = 1.,
1115
+ base = 10000,
1116
+ base_rescale_factor = 1.
1117
+ ):
1118
+ super().__init__()
1119
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
1120
+ # has some connection to NTK literature
1121
+ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
1122
+ base *= base_rescale_factor ** (dim / (dim - 2))
1123
+
1124
+ inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
1125
+ self.register_buffer('inv_freq', inv_freq)
1126
+
1127
+ assert interpolation_factor >= 1.
1128
+ self.interpolation_factor = interpolation_factor
1129
+
1130
+ if not use_xpos:
1131
+ self.register_buffer('scale', None)
1132
+ return
1133
+
1134
+ scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
1135
+
1136
+ self.scale_base = scale_base
1137
+ self.register_buffer('scale', scale)
1138
+
1139
+ def forward(self, seq_len):
1140
+ device = self.inv_freq.device
1141
+ t = torch.arange(seq_len, device = device).type_as(self.inv_freq)
1142
+
1143
+ t = t / self.interpolation_factor
1144
+
1145
+ freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
1146
+ freqs = torch.cat((freqs, freqs), dim = -1)
1147
+
1148
+ if not exists(self.scale):
1149
+ return freqs, 1.
1150
+
1151
+ power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
1152
+ scale = self.scale ** rearrange(power, 'n -> n 1')
1153
+ scale = torch.cat((scale, scale), dim = -1)
1154
+
1155
+ return freqs, scale
1156
+
1157
+
1158
+ def rotate_half(x):
1159
+ x = rearrange(x, '... (j d) -> ... j d', j = 2)
1160
+ x1, x2 = x.unbind(dim = -2)
1161
+ return torch.cat((-x2, x1), dim = -1)
1162
+
1163
+ def apply_rotary_pos_emb(t, freqs, scale = 1):
1164
+ rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
1165
+ freqs = freqs[-seq_len:, :]
1166
+
1167
+ if t.ndim == 4 and freqs.ndim == 3:
1168
+ freqs = rearrange(freqs, 'b n d -> b 1 n d')
1169
+
1170
+ # partial rotary embeddings, Wang et al. GPT-J
1171
+ t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
1172
+ t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
1173
+ return torch.cat((t, t_unrotated), dim = -1)
1174
+
1175
+ # norms
1176
+
1177
+ class Scale(nn.Module):
1178
+ def __init__(self, value, fn):
1179
+ super().__init__()
1180
+ self.value = value
1181
+ self.fn = fn
1182
+
1183
+ def forward(self, x, **kwargs):
1184
+ out = self.fn(x, **kwargs)
1185
+ scale_fn = lambda t: t * self.value
1186
+
1187
+ if not isinstance(out, tuple):
1188
+ return scale_fn(out)
1189
+
1190
+ return (scale_fn(out[0]), *out[1:])
1191
+
1192
+ class ScaleNorm(nn.Module):
1193
+ def __init__(self, dim, eps = 1e-5):
1194
+ super().__init__()
1195
+ self.eps = eps
1196
+ self.g = nn.Parameter(torch.ones(1) * (dim ** -0.5))
1197
+
1198
+ def forward(self, x):
1199
+ norm = torch.norm(x, dim = -1, keepdim = True)
1200
+ return x / norm.clamp(min = self.eps) * self.g
1201
+
1202
+ class RMSNorm(nn.Module):
1203
+ def __init__(self, dim):
1204
+ super().__init__()
1205
+ self.scale = dim ** 0.5
1206
+ self.g = nn.Parameter(torch.ones(dim))
1207
+
1208
+ def forward(self, x):
1209
+ return F.normalize(x, dim = -1) * self.scale * self.g
1210
+
1211
+ class SimpleRMSNorm(nn.Module):
1212
+ def __init__(self, dim):
1213
+ super().__init__()
1214
+ self.scale = dim ** 0.5
1215
+
1216
+ def forward(self, x):
1217
+ return F.normalize(x, dim = -1) * self.scale
1218
+
1219
+ # residual and residual gates
1220
+
1221
+ class Residual(nn.Module):
1222
+ def __init__(self, dim, scale_residual = False, scale_residual_constant = 1.):
1223
+ super().__init__()
1224
+ self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
1225
+ self.scale_residual_constant = scale_residual_constant
1226
+
1227
+ def forward(self, x, residual):
1228
+ if exists(self.residual_scale):
1229
+ residual = residual * self.residual_scale
1230
+
1231
+ if self.scale_residual_constant != 1:
1232
+ residual = residual * self.scale_residual_constant
1233
+
1234
+ return x + residual
1235
+
1236
+ class GRUGating(nn.Module):
1237
+ def __init__(self, dim, scale_residual = False, **kwargs):
1238
+ super().__init__()
1239
+ self.gru = nn.GRUCell(dim, dim)
1240
+ self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
1241
+
1242
+ def forward(self, x, residual):
1243
+ if exists(self.residual_scale):
1244
+ residual = residual * self.residual_scale
1245
+
1246
+ gated_output = self.gru(
1247
+ rearrange(x, 'b n d -> (b n) d'),
1248
+ rearrange(residual, 'b n d -> (b n) d')
1249
+ )
1250
+
1251
+ return gated_output.reshape_as(x)
1252
+
1253
+ # token shifting
1254
+
1255
+ def shift(t, amount, mask = None):
1256
+ if amount == 0:
1257
+ return t
1258
+ else:
1259
+ amount = min(amount, t.shape[1])
1260
+
1261
+ if exists(mask):
1262
+ t = t.masked_fill(~mask[..., None], 0.)
1263
+
1264
+ return pad_at_dim(t, (amount, -amount), dim = - 2, value = 0.)
1265
+
1266
+ class ShiftTokens(nn.Module):
1267
+ def __init__(self, shifts, fn):
1268
+ super().__init__()
1269
+ self.fn = fn
1270
+ self.shifts = tuple(shifts)
1271
+
1272
+ def forward(self, x, **kwargs):
1273
+ mask = kwargs.get('mask', None)
1274
+ shifts = self.shifts
1275
+ segments = len(shifts)
1276
+ feats_per_shift = x.shape[-1] // segments
1277
+ splitted = x.split(feats_per_shift, dim = -1)
1278
+ segments_to_shift, rest = splitted[:segments], splitted[segments:]
1279
+ segments_to_shift = list(map(lambda args: shift(*args, mask = mask), zip(segments_to_shift, shifts)))
1280
+ x = torch.cat((*segments_to_shift, *rest), dim = -1)
1281
+ return self.fn(x, **kwargs)
1282
+
1283
+ # feedforward
1284
+
1285
+ class GLU(nn.Module):
1286
+ def __init__(
1287
+ self,
1288
+ dim_in,
1289
+ dim_out,
1290
+ activation: Callable,
1291
+ mult_bias = False
1292
+ ):
1293
+ super().__init__()
1294
+ self.act = activation
1295
+ self.proj = nn.Linear(dim_in, dim_out * 2)
1296
+ self.mult_bias = nn.Parameter(torch.ones(dim_out)) if mult_bias else 1.
1297
+
1298
+ def forward(self, x):
1299
+ x, gate = self.proj(x).chunk(2, dim = -1)
1300
+ return x * self.act(gate) * self.mult_bias
1301
+
1302
+ class FeedForward(nn.Module):
1303
+ def __init__(
1304
+ self,
1305
+ dim,
1306
+ dim_out = None,
1307
+ mult = 4,
1308
+ glu = False,
1309
+ glu_mult_bias = False,
1310
+ swish = False,
1311
+ relu_squared = False,
1312
+ post_act_ln = False,
1313
+ dropout = 0.,
1314
+ no_bias = False,
1315
+ zero_init_output = False
1316
+ ):
1317
+ super().__init__()
1318
+ inner_dim = int(dim * mult)
1319
+ dim_out = default(dim_out, dim)
1320
+
1321
+ if relu_squared:
1322
+ activation = ReluSquared()
1323
+ elif swish:
1324
+ activation = nn.SiLU()
1325
+ else:
1326
+ activation = nn.GELU()
1327
+
1328
+ if glu:
1329
+ project_in = GLU(dim, inner_dim, activation, mult_bias = glu_mult_bias)
1330
+ else:
1331
+ project_in = nn.Sequential(
1332
+ nn.Linear(dim, inner_dim, bias = not no_bias),
1333
+ activation
1334
+ )
1335
+
1336
+ self.ff = Sequential(
1337
+ project_in,
1338
+ nn.LayerNorm(inner_dim) if post_act_ln else None,
1339
+ nn.Dropout(dropout),
1340
+ nn.Linear(inner_dim, dim_out, bias = not no_bias)
1341
+ )
1342
+
1343
+ # init last linear layer to 0
1344
+ if zero_init_output:
1345
+ init_zero_(self.ff[-1])
1346
+
1347
+ def forward(self, x):
1348
+ return self.ff(x)
1349
+
1350
+ # attention. it is all we need
1351
+
1352
+ class Attention(nn.Module):
1353
+ def __init__(
1354
+ self,
1355
+ dim,
1356
+ dim_head = DEFAULT_DIM_HEAD,
1357
+ heads = 8,
1358
+ causal = False,
1359
+ flash = False,
1360
+ talking_heads = False,
1361
+ head_scale = False,
1362
+ sparse_topk = None,
1363
+ num_mem_kv = 0,
1364
+ dropout = 0.,
1365
+ on_attn = False,
1366
+ gate_value_heads = False,
1367
+ gate_values = False,
1368
+ zero_init_output = False,
1369
+ max_attend_past = None,
1370
+ qk_norm = False,
1371
+ qk_norm_groups = 1,
1372
+ qk_norm_scale = 10,
1373
+ qk_norm_dim_scale = False,
1374
+ one_kv_head = False,
1375
+ kv_heads = None,
1376
+ shared_kv = False,
1377
+ value_dim_head = None,
1378
+ tensor_product = False, # https://arxiv.org/abs/2208.06061
1379
+ add_zero_kv = False, # same as add_zero_attn in pytorch
1380
+ rotary_embed_values = False,
1381
+ onnxable = False
1382
+ ):
1383
+ super().__init__()
1384
+ self.scale = dim_head ** -0.5
1385
+
1386
+ self.heads = heads
1387
+ self.causal = causal
1388
+ self.max_attend_past = max_attend_past
1389
+
1390
+ assert not (exists(kv_heads) and one_kv_head), 'either attn_one_kv_head is set to True (in which case kv_heads is set to 1), or attn_kv_heads is set, but not both'
1391
+
1392
+ value_dim_head = default(value_dim_head, dim_head)
1393
+ kv_heads = default(kv_heads, heads)
1394
+
1395
+ kv_heads = 1 if one_kv_head else kv_heads
1396
+ assert divisible_by(heads, kv_heads)
1397
+
1398
+ self.kv_heads = kv_heads
1399
+
1400
+ q_dim = dim_head * heads
1401
+ k_dim = dim_head * kv_heads
1402
+ v_dim = value_dim_head * kv_heads
1403
+ out_dim = value_dim_head * heads
1404
+
1405
+ self.to_q = nn.Linear(dim, q_dim, bias = False)
1406
+ self.to_k = nn.Linear(dim, k_dim, bias = False)
1407
+
1408
+ # shared key / values, for further memory savings during inference
1409
+ assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values'
1410
+ self.to_v = nn.Linear(dim, v_dim, bias = False) if not shared_kv else None
1411
+
1412
+ # relations projection from tp-attention
1413
+ self.to_r = nn.Linear(dim, v_dim, bias = False) if tensor_product else None
1414
+
1415
+ # add GLU gating for aggregated values, from alphafold2
1416
+ self.to_v_gate = None
1417
+ if gate_values:
1418
+ self.to_v_gate = nn.Linear(dim, out_dim)
1419
+ nn.init.constant_(self.to_v_gate.weight, 0)
1420
+ nn.init.constant_(self.to_v_gate.bias, 10)
1421
+
1422
+ # add per head gating of the output values, from 'Attend to nothing' paper
1423
+ self.to_v_head_gate = None
1424
+ if gate_value_heads:
1425
+ self.to_v_head_gate = nn.Linear(dim, heads)
1426
+ nn.init.constant_(self.to_v_head_gate.weight, 0)
1427
+ nn.init.constant_(self.to_v_head_gate.bias, 10)
1428
+
1429
+ # cosine sim attention
1430
+ self.qk_norm = qk_norm
1431
+ self.qk_norm_groups = qk_norm_groups
1432
+ self.qk_norm_scale = qk_norm_scale
1433
+
1434
+ # whether to use the rmsnorm (equivalent to cosine sim attention when scale is equal to 1) - https://arxiv.org/abs/2302.05442
1435
+ self.qk_norm_dim_scale = qk_norm_dim_scale
1436
+
1437
+ self.qk_norm_q_scale = self.qk_norm_k_scale = 1
1438
+ if qk_norm and qk_norm_dim_scale:
1439
+ self.qk_norm_q_scale = nn.Parameter(torch.ones(heads, 1, dim_head))
1440
+ self.qk_norm_k_scale = nn.Parameter(torch.ones(heads, 1, dim_head))
1441
+
1442
+ assert (not qk_norm) or divisible_by(dim_head, qk_norm_groups), 'dimension per attention head must be divisible by the qk norm groups'
1443
+ assert not (qk_norm and (dim_head // qk_norm_groups) <= 2), 'the group dimension may be too small (2 was too small in my tests, but 4 still works, surprisingly)'
1444
+
1445
+ # attend class - includes core attention algorithm + talking heads
1446
+
1447
+ self.attend = Attend(
1448
+ heads = heads,
1449
+ causal = causal,
1450
+ talking_heads = talking_heads,
1451
+ dropout = dropout,
1452
+ sparse_topk = sparse_topk,
1453
+ qk_norm = qk_norm,
1454
+ scale = qk_norm_scale if qk_norm else self.scale,
1455
+ add_zero_kv = add_zero_kv,
1456
+ flash = flash,
1457
+ onnxable = onnxable
1458
+ )
1459
+
1460
+ # head scaling
1461
+ self.head_scale = head_scale
1462
+ if head_scale:
1463
+ self.head_scale_params = nn.Parameter(torch.ones(1, heads, 1, 1))
1464
+
1465
+ # explicit topk sparse attention
1466
+ self.sparse_topk = sparse_topk
1467
+
1468
+ # add memory key / values
1469
+ self.num_mem_kv = num_mem_kv
1470
+ if num_mem_kv > 0:
1471
+ self.mem_k = nn.Parameter(torch.randn(kv_heads, num_mem_kv, dim_head))
1471
+ self.mem_v = nn.Parameter(torch.randn(kv_heads, num_mem_kv, value_dim_head))
1473
+
1474
+ # attention on attention
1475
+ self.attn_on_attn = on_attn
1476
+ self.to_out = nn.Sequential(nn.Linear(out_dim, dim * 2, bias = False), nn.GLU()) if on_attn else nn.Linear(out_dim, dim, bias = False)
1477
+
1478
+ # whether to rotate positions into values, for absolute positions in addition to relative
1479
+ self.rotary_embed_values = rotary_embed_values
1480
+
1481
+ # init output projection 0
1482
+ if zero_init_output:
1483
+ init_zero_(self.to_out)
1484
+
1485
+ def forward(
1486
+ self,
1487
+ x,
1488
+ context = None,
1489
+ mask = None,
1490
+ context_mask = None,
1491
+ attn_mask = None,
1492
+ rel_pos = None,
1493
+ rotary_pos_emb = None,
1494
+ prev_attn = None,
1495
+ mem = None,
1496
+ return_intermediates = False,
1497
+ cache: Optional[Intermediates] = None,
1498
+ ):
1499
+ b, n, _, h, kv_h, head_scale, device, has_context = *x.shape, self.heads, self.kv_heads, self.head_scale, x.device, exists(context)
1500
+ kv_input = default(context, x)
1501
+
1502
+ q_input = x
1503
+ k_input = kv_input
1504
+ v_input = kv_input
1505
+ r_input = x
1506
+
1507
+ if exists(mem):
1508
+ k_input, mem_packed_shape = pack([mem, k_input], 'b * d')
1509
+ v_input, _ = pack([mem, v_input], 'b * d')
1510
+
1511
+ q = self.to_q(q_input)
1512
+ k = self.to_k(k_input)
1513
+ v = self.to_v(v_input) if exists(self.to_v) else k
1514
+ r = self.to_r(r_input) if exists(self.to_r) else None
1515
+
1516
+ q = rearrange(q, 'b n (h d) -> b h n d', h = h)
1517
+
1518
+ k, v, r = map(lambda t: maybe(rearrange)(t, 'b n (h d) -> b h n d', h = kv_h), (k, v, r))
1519
+
1520
+ if exists(cache) and not has_context:
1521
+ ck, cv = cache.cached_kv
1522
+
1523
+ if exists(mem):
1524
+ mk, k = unpack(k, mem_packed_shape, 'b h * d')
1525
+ mv, v = unpack(v, mem_packed_shape, 'b h * d')
1526
+
1527
+ k = torch.cat((ck, k), dim = -2)
1528
+ v = torch.cat((cv, v), dim = -2)
1529
+
1530
+ if exists(mem):
1531
+ k = torch.cat((mk, k), dim = -2)
1532
+ v = torch.cat((mv, v), dim = -2)
1533
+
1534
+ if return_intermediates:
1535
+ mem_len = mem.shape[-2] if exists(mem) else 0
1536
+ cached_kv = (k[..., mem_len:, :], v[..., mem_len:, :])
1537
+
1538
+ if self.qk_norm:
1539
+ qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
1540
+ q, k = map(qk_l2norm, (q, k))
1541
+ scale = self.qk_norm_scale
1542
+
1543
+ q = q * self.qk_norm_q_scale
1544
+ k = k * self.qk_norm_k_scale
1545
+
1546
+ if exists(rotary_pos_emb) and not has_context:
1547
+ freqs, xpos_scale = rotary_pos_emb
1548
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)
1549
+
1550
+ q = apply_rotary_pos_emb(q, freqs, q_xpos_scale)
1551
+ k = apply_rotary_pos_emb(k, freqs, k_xpos_scale)
1552
+
1553
+ if self.rotary_embed_values:
1554
+ v = apply_rotary_pos_emb(v, freqs, k_xpos_scale)
1555
+
1556
+ input_mask = context_mask
1557
+
1558
+ if not exists(input_mask) and not has_context:
1559
+ input_mask = mask
1560
+
1561
+ if self.num_mem_kv > 0:
1562
+ mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), (self.mem_k, self.mem_v))
1563
+
1564
+ if self.qk_norm:
1565
+ mem_k = l2norm(mem_k)
1566
+ mem_k = mem_k * self.qk_norm_k_scale
1567
+
1568
+ k = torch.cat((mem_k, k), dim = -2)
1569
+ v = torch.cat((mem_v, v), dim = -2)
1570
+
1571
+ if exists(input_mask):
1572
+ input_mask = pad_at_dim(input_mask, (self.num_mem_kv, 0), dim = -1, value = True)
1573
+
1574
+ i, j = map(lambda t: t.shape[-2], (q, k))
1575
+
1576
+ # determine masking
1577
+
1578
+ mask_value = max_neg_value(q)
1579
+ masks = []
1580
+ final_attn_mask = None
1581
+
1582
+ if exists(input_mask):
1583
+ input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
1584
+ masks.append(~input_mask)
1585
+
1586
+ if exists(attn_mask):
1587
+ assert 2 <= attn_mask.ndim <= 4, 'attention mask must have at least 2 dimensions and at most 4'
1588
+ if attn_mask.ndim == 2:
1589
+ attn_mask = rearrange(attn_mask, 'i j -> 1 1 i j')
1590
+ elif attn_mask.ndim == 3:
1591
+ attn_mask = rearrange(attn_mask, 'h i j -> 1 h i j')
1592
+ masks.append(~attn_mask)
1593
+
1594
+ if exists(self.max_attend_past):
1595
+ range_q = torch.arange(j - i, j, device = device)
1596
+ range_k = torch.arange(j, device = device)
1597
+ dist = rearrange(range_q, 'i -> 1 1 i 1') - rearrange(range_k, 'j -> 1 1 1 j')
1598
+ max_attend_past_mask = dist > self.max_attend_past
1599
+ masks.append(max_attend_past_mask)
1600
+
1601
+ if len(masks) > 0:
1602
+ final_attn_mask = ~or_reduce(masks)
1603
+
1604
+ # prepare relative positional bias, if needed
1605
+
1606
+ attn_bias = None
1607
+ if exists(rel_pos):
1608
+ attn_bias = rel_pos(i, j)
1609
+
1610
+ # attention is all we need
1611
+
1612
+ out, intermediates = self.attend(
1613
+ q, k, v,
1614
+ mask = final_attn_mask,
1615
+ attn_bias = attn_bias,
1616
+ prev_attn = prev_attn
1617
+ )
1618
+
1619
+ # https://arxiv.org/abs/2208.06061 proposes to add a residual for better gradients
1620
+
1621
+ if exists(r):
1622
+ out = out * r + out
1623
+
1624
+ # normformer scaling of heads
1625
+
1626
+ if head_scale:
1627
+ out = out * self.head_scale_params
1628
+
1629
+ # per head gating, from https://arxiv.org/abs/2306.12929
1630
+
1631
+ if exists(self.to_v_head_gate):
1632
+ head_gate = self.to_v_head_gate(x)
1633
+ out = out * rearrange(head_gate, 'b n h -> b h n 1').sigmoid()
1634
+
1635
+ # merge heads
1636
+
1637
+ out = rearrange(out, 'b h n d -> b n (h d)')
1638
+
1639
+ # alphafold2 styled gating of the values
1640
+
1641
+ if exists(self.to_v_gate):
1642
+ gates = self.to_v_gate(x)
1643
+ out = out * gates.sigmoid()
1644
+
1645
+ # combine the heads
1646
+
1647
+ out = self.to_out(out)
1648
+
1649
+ if exists(mask):
1650
+ mask = rearrange(mask, 'b n -> b n 1')
1651
+ out = out.masked_fill(~mask, 0.)
1652
+
1653
+ if not return_intermediates:
1654
+ return out
1655
+
1656
+ intermediates.cached_kv = cached_kv
1657
+
1658
+ return out, intermediates
1659
+
1660
+ class AttentionLayers(nn.Module):
1661
+ def __init__(
1662
+ self,
1663
+ dim,
1664
+ depth,
1665
+ heads = 8,
1666
+ causal = False,
1667
+ cross_attend = False,
1668
+ only_cross = False,
1669
+ use_scalenorm = False,
1670
+ use_rmsnorm = False,
1671
+ use_simple_rmsnorm = False,
1672
+ alibi_pos_bias = False,
1673
+ alibi_num_heads = None,
1674
+ rel_pos_bias = False,
1675
+ rel_pos_num_buckets = 32,
1676
+ rel_pos_max_distance = 128,
1677
+ dynamic_pos_bias = False,
1678
+ dynamic_pos_bias_log_distance = False,
1679
+ dynamic_pos_bias_mlp_depth = 2,
1680
+ dynamic_pos_bias_norm = False,
1681
+ rotary_pos_emb = False,
1682
+ rotary_emb_dim = None,
1683
+ rotary_xpos = False,
1684
+ rotary_interpolation_factor = 1.,
1685
+ rotary_xpos_scale_base = 512,
1686
+ rotary_base_rescale_factor = 1.,
1687
+ custom_layers = None,
1688
+ sandwich_coef = None,
1689
+ par_ratio = None,
1690
+ weight_tie_layers = False, # Albert - https://arxiv.org/abs/1909.11942
1691
+ layers_execute_order = None, # generalizes weight tying, can do arbitrary layer execution orders
1692
+ residual_attn = False,
1693
+ cross_residual_attn = False,
1694
+ macaron = False,
1695
+ pre_norm = True,
1696
+ pre_norm_has_final_norm = True,
1697
+ gate_residual = False,
1698
+ scale_residual = False,
1699
+ scale_residual_constant = 1.,
1700
+ shift_tokens = 0,
1701
+ sandwich_norm = False,
1702
+ resi_dual = False,
1703
+ resi_dual_scale = 1.,
1704
+ zero_init_branch_output = False,
1705
+ layer_dropout = 0.,
1706
+ cross_attn_tokens_dropout = 0.,
1707
+ **kwargs
1708
+ ):
1709
+ super().__init__()
1710
+ rotary_pos_emb = rotary_pos_emb or rotary_xpos
1711
+
1712
+ ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
1713
+ attn_kwargs, kwargs = groupby_prefix_and_trim('attn_', kwargs)
1714
+
1715
+ dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
1716
+
1717
+ self.dim = dim
1718
+ self.depth = depth
1719
+ self.causal = causal
1720
+ self.layers = nn.ModuleList([])
1721
+
1722
+ self.has_pos_emb = rel_pos_bias or rotary_pos_emb
1723
+
1724
+ rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32)
1725
+
1726
+ assert not (rotary_xpos and not causal), 'rotary xpos is not compatible with bidirectional attention'
1727
+ self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim, use_xpos = rotary_xpos, scale_base = rotary_xpos_scale_base, interpolation_factor = rotary_interpolation_factor, base_rescale_factor = rotary_base_rescale_factor) if rotary_pos_emb else None
1728
+
1729
+ assert not (alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
1730
+ assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than or equal to the relative position max distance'
1731
+
1732
+ # relative positional bias
1733
+
1734
+ flash_attn = attn_kwargs.get('flash', False)
1735
+ assert (int(rel_pos_bias) + int(dynamic_pos_bias) + int(alibi_pos_bias)) <= 1, 'you can only choose up to one of t5, alibi, or dynamic positional bias'
1736
+
1737
+ self.rel_pos = None
1738
+ if rel_pos_bias:
1739
+ assert not flash_attn, 'flash attention not compatible with t5 relative positional bias'
1740
+ self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance)
1741
+ elif dynamic_pos_bias:
1742
+ assert not flash_attn, 'flash attention not compatible with dynamic positional bias'
1743
+ self.rel_pos = DynamicPositionBias(dim = dim // 4, heads = heads, log_distance = dynamic_pos_bias_log_distance, depth = dynamic_pos_bias_mlp_depth, norm = dynamic_pos_bias_norm)
1744
+ elif alibi_pos_bias:
1745
+ alibi_num_heads = default(alibi_num_heads, heads)
1746
+ assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than or equal to the total number of heads'
1747
+ self.rel_pos = AlibiPositionalBias(heads = alibi_num_heads, total_heads = heads)
1748
+
1749
+ assert (int(sandwich_norm) + int(resi_dual)) <= 1, 'either sandwich norm or resiDual is selected, but not both'
1750
+ assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'
1751
+
1752
+ if resi_dual:
1753
+ pre_norm = False
1754
+
1755
+ self.pre_norm = pre_norm
1756
+ self.sandwich_norm = sandwich_norm
1757
+
1758
+ self.resi_dual = resi_dual
1759
+ assert 0 < resi_dual_scale <= 1., 'resiDual prenorm residual must be scaled by a factor greater than 0 and less than or equal to 1.'
1760
+ self.resi_dual_scale = resi_dual_scale
1761
+
1762
+ self.residual_attn = residual_attn
1763
+ self.cross_residual_attn = cross_residual_attn
1764
+ assert not (flash_attn and (residual_attn or cross_residual_attn)), 'flash attention is not compatible with residual attention'
1765
+
1766
+ self.cross_attend = cross_attend
1767
+
1768
+ assert (int(use_scalenorm) + int(use_rmsnorm) + int(use_simple_rmsnorm)) <= 1, 'you can only use either scalenorm, rmsnorm, or simple rmsnorm'
1769
+
1770
+ if use_scalenorm:
1771
+ norm_class = ScaleNorm
1772
+ elif use_rmsnorm:
1773
+ norm_class = RMSNorm
1774
+ elif use_simple_rmsnorm:
1775
+ norm_class = SimpleRMSNorm
1776
+ else:
1777
+ norm_class = nn.LayerNorm
1778
+
1779
+ norm_fn = partial(norm_class, dim)
1780
+
1781
+ if cross_attend and not only_cross:
1782
+ default_block = ('a', 'c', 'f')
1783
+ elif cross_attend and only_cross:
1784
+ default_block = ('c', 'f')
1785
+ else:
1786
+ default_block = ('a', 'f')
1787
+
1788
+ if macaron:
1789
+ default_block = ('f',) + default_block
1790
+
1791
+ # zero init
1792
+
1793
+ if zero_init_branch_output:
1794
+ attn_kwargs = {**attn_kwargs, 'zero_init_output': True}
1795
+ ff_kwargs = {**ff_kwargs, 'zero_init_output': True}
1796
+
1797
+ # setup weight tying, which is a special case of `layer_execute_order`
1798
+
1799
+ assert not (weight_tie_layers and any([*map(exists, (custom_layers, par_ratio, sandwich_coef))]))
1800
+
1801
+ if weight_tie_layers:
1802
+ assert not exists(layers_execute_order)
1803
+ layers_execute_order = tuple(range(len(default_block))) * depth
1804
+ depth = 1
1805
+
+         # calculate layer block order
+
+         if exists(custom_layers):
+             layer_types = custom_layers
+         elif exists(par_ratio):
+             par_depth = depth * len(default_block)
+             assert 1 < par_ratio <= par_depth, 'par ratio out of range'
+             default_block = tuple(filter(not_equals('f'), default_block))
+             par_attn = par_depth // par_ratio
+             depth_cut = par_depth * 2 // 3  # 2 / 3 attention layer cutoff suggested by PAR paper
+             par_width = (depth_cut + depth_cut // par_attn) // par_attn
+             assert len(default_block) <= par_width, 'default block is too large for par_ratio'
+             par_block = default_block + ('f',) * (par_width - len(default_block))
+             par_head = par_block * par_attn
+             layer_types = par_head + ('f',) * (par_depth - len(par_head))
+         elif exists(sandwich_coef):
+             assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient must be greater than 0 and less than or equal to the depth'
+             layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
+         else:
+             layer_types = default_block * depth
1826
+
1827
+ self.layer_types = layer_types
1828
+ self.layers_execute_order = default(layers_execute_order, tuple(range(len(layer_types))))
1829
+
1830
+ assert all([i < len(self.layer_types) for i in self.layers_execute_order])
1831
+
1832
+ self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
1833
+
1834
+ # stochastic depth
1835
+
1836
+ self.layer_dropouts = cast_tuple(layer_dropout, len(layer_types))
1837
+
1838
+ # structured dropout for cross attending
1839
+
1840
+ self.cross_attn_tokens_dropout = cross_attn_tokens_dropout
1841
+
1842
+ # calculate token shifting
1843
+
1844
+ shift_tokens = cast_tuple(shift_tokens, len(layer_types))
1845
+
1846
+ # final norm, applied only when using pre-norm or resiDual
1847
+
1848
+ self.final_norm = norm_fn() if pre_norm or resi_dual else nn.Identity()
1849
+
1850
+ # iterate and construct layers
1851
+
1852
+ for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
1853
+ is_last_layer = ind == (len(self.layer_types) - 1)
1854
+
1855
+ if layer_type == 'a':
1856
+ layer = Attention(dim, heads = heads, causal = causal, **attn_kwargs)
1857
+ elif layer_type == 'c':
1858
+ layer = Attention(dim, heads = heads, **attn_kwargs)
1859
+ elif layer_type == 'f':
1860
+ layer = FeedForward(dim, **ff_kwargs)
1861
+ layer = layer if not macaron else Scale(0.5, layer)
1862
+ else:
1863
+ raise Exception(f'invalid layer type {layer_type}')
1864
+
1865
+ if layer_shift_tokens > 0:
1866
+ shift_range_upper = layer_shift_tokens + 1
1867
+ shift_range_lower = -layer_shift_tokens if not causal else 0
1868
+ layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer)
1869
+
1870
+ residual_fn = GRUGating if gate_residual else Residual
1871
+ residual = residual_fn(dim, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)
1872
+
1873
+ pre_branch_norm = norm_fn() if pre_norm else None
1874
+ post_branch_norm = norm_fn() if sandwich_norm else None
1875
+ post_main_norm = norm_fn() if not pre_norm else None
1876
+
1877
+ norms = nn.ModuleList([
1878
+ pre_branch_norm,
1879
+ post_branch_norm,
1880
+ post_main_norm
1881
+ ])
1882
+
1883
+ self.layers.append(nn.ModuleList([
1884
+ norms,
1885
+ layer,
1886
+ residual
1887
+ ]))
1888
+
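The layer-order logic in `__init__` above can be hard to follow from the diff alone. The following is a minimal pure-Python sketch of the same branching, written for illustration only: the function name `layer_order` and its keyword arguments are assumptions made here, and the sketch leaves out `custom_layers` and weight tying.

```python
def layer_order(depth, cross_attend=False, only_cross=False, macaron=False,
                sandwich_coef=None, par_ratio=None):
    """Return the tuple of layer types ('a', 'c', 'f') in construction order.

    Hypothetical helper that mirrors the branching in the diff; it is not
    part of the uploaded file.
    """
    if cross_attend and not only_cross:
        block = ('a', 'c', 'f')
    elif cross_attend and only_cross:
        block = ('c', 'f')
    else:
        block = ('a', 'f')

    if macaron:
        # macaron layout puts an extra feedforward in front of each block
        block = ('f',) + block

    if par_ratio is not None:
        par_depth = depth * len(block)
        assert 1 < par_ratio <= par_depth, 'par ratio out of range'
        block = tuple(t for t in block if t != 'f')
        par_attn = par_depth // par_ratio
        depth_cut = par_depth * 2 // 3
        par_width = (depth_cut + depth_cut // par_attn) // par_attn
        assert len(block) <= par_width, 'default block is too large for par_ratio'
        par_block = block + ('f',) * (par_width - len(block))
        par_head = par_block * par_attn
        # attention is concentrated in the head, the tail is feedforward only
        return par_head + ('f',) * (par_depth - len(par_head))

    if sandwich_coef is not None:
        assert 0 < sandwich_coef <= depth
        return ('a',) * sandwich_coef + block * (depth - sandwich_coef) + ('f',) * sandwich_coef

    return block * depth


plain = layer_order(2)
sandwich = layer_order(3, sandwich_coef=1)
par = layer_order(6, par_ratio=4)
```

With these inputs, `plain` alternates attention and feedforward, `sandwich` moves one attention layer to the front and one feedforward layer to the back, and `par` keeps the total layer count while using fewer attention layers.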
1889
+ def forward(
1890
+ self,
1891
+ x,
1892
+ context = None,
1893
+ mask = None,
1894
+ context_mask = None,
1895
+ attn_mask = None,
1896
+ self_attn_kv_mask = None,
1897
+ mems = None,
1898
+ seq_start_pos: Optional[Tensor] = None,
1899
+ cache: Optional[LayerIntermediates] = None,
1900
+ cache_age = 1,
1901
+ return_hiddens = False
1902
+ ):
1903
+ assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
1904
+
1905
+ # initialize accums
1906
+
1907
+ hiddens = []
1908
+ layer_hiddens = []
1909
+ intermediates = []
1910
+
1911
+ prev_attn = None
1912
+ prev_cross_attn = None
1913
+
1914
+ mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
1915
+
+         # handle left padded sequences
+
+         if exists(seq_start_pos):
+             seq_arange = torch.arange(x.shape[-2], device = x.device, dtype = torch.long)
+             left_pad_mask = seq_arange >= seq_start_pos[..., None]
+
+             if exists(self_attn_kv_mask):
+                 self_attn_kv_mask = self_attn_kv_mask & left_pad_mask
+             else:
+                 self_attn_kv_mask = left_pad_mask
1926
+
1927
+ # rotary positions
1928
+
1929
+ rotary_pos_emb = None
1930
+
1931
+ if exists(self.rotary_pos_emb):
1932
+ max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems)))
1933
+ rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length)
1934
+
1935
+ # assume cached key / values
1936
+
1937
+ attn_cache = []
1938
+
1939
+ if exists(cache):
1940
+ assert not self.training and self.causal and not any([*map(exists, (mask, attn_mask))])
1941
+
1942
+ if cache_age > 0:
1943
+ x = x[:, -cache_age:] # for spec decoding, may be greater than 1
1944
+
1945
+ attn_cache = cache.attn_intermediates
1946
+
1947
+ iter_attn_cache = iter(attn_cache)
1948
+
1949
+ # outer residual - for resiDual paper
1950
+
1951
+ outer_residual = x * self.resi_dual_scale
1952
+
1953
+ # get layers to be executed
1954
+
1955
+ layer_variables = (
1956
+ self.layer_types,
1957
+ self.layers,
1958
+ self.layer_dropouts
1959
+ )
1960
+
1961
+ layer_variables = tuple(tuple(layer_variable[i] for i in self.layers_execute_order) for layer_variable in layer_variables)
1962
+
1963
+ # go through the attention and feedforward layers
1964
+
1965
+ for ind, (layer_type, (norm, block, residual_fn), layer_dropout) in enumerate(zip(*layer_variables)):
1966
+ is_last = ind == (len(self.layers) - 1)
1967
+
1968
+ if self.training and layer_dropout > 0. and random() < layer_dropout:
1969
+ continue
1970
+
1971
+ if layer_type == 'a':
1972
+ if return_hiddens:
1973
+ hiddens.append(x)
1974
+ layer_mem = mems.pop(0) if mems else None
1975
+
1976
+ if layer_type == 'c':
1977
+ if self.training and self.cross_attn_tokens_dropout > 0.:
1978
+ context, context_mask = dropout_seq(context, context_mask, self.cross_attn_tokens_dropout)
1979
+
1980
+ inner_residual = x
1981
+
1982
+ if return_hiddens:
1983
+ layer_hiddens.append(x)
1984
+
1985
+ pre_norm, post_branch_norm, post_main_norm = norm
1986
+
1987
+ if exists(pre_norm):
1988
+ x = pre_norm(x)
1989
+
1990
+ if layer_type == 'a':
1991
+ out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, return_intermediates = True)
1992
+ elif layer_type == 'c':
1993
+ out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), return_intermediates = True)
1994
+ elif layer_type == 'f':
1995
+ out = block(x)
1996
+
1997
+ if self.resi_dual:
1998
+ outer_residual = outer_residual + out * self.resi_dual_scale
1999
+
2000
+ if exists(post_branch_norm):
2001
+ out = post_branch_norm(out)
2002
+
2003
+ x = residual_fn(out, inner_residual)
2004
+
2005
+ if layer_type in ('a', 'c') and return_hiddens:
2006
+ intermediates.append(inter)
2007
+
2008
+ if layer_type == 'a' and self.residual_attn:
2009
+ prev_attn = inter.pre_softmax_attn
2010
+ elif layer_type == 'c' and self.cross_residual_attn:
2011
+ prev_cross_attn = inter.pre_softmax_attn
2012
+
2013
+ if exists(post_main_norm):
2014
+ x = post_main_norm(x)
2015
+
2016
+ if return_hiddens:
2017
+ layer_hiddens.append(x)
2018
+
2019
+ if self.resi_dual:
2020
+ x = x + self.final_norm(outer_residual)
2021
+ else:
2022
+ x = self.final_norm(x)
2023
+
2024
+ if not return_hiddens:
2025
+ return x
2026
+
2027
+ intermediates = LayerIntermediates(
2028
+ hiddens = hiddens,
2029
+ attn_intermediates = intermediates,
2030
+ layer_hiddens = layer_hiddens
2031
+ )
2032
+
2033
+ return x, intermediates
2034
+
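The left-padding step in `forward` builds a key/value mask by comparing each position index with the per-row start position. As a rough illustration, here is a pure-Python analogue of that broadcasted comparison. The helper names `left_pad_mask` and `combine_masks` are hypothetical, and nested lists stand in for tensors.

```python
def left_pad_mask(seq_len, seq_start_pos):
    """Pure-Python analogue of `seq_arange >= seq_start_pos[..., None]`.

    seq_start_pos holds, per batch row, the index of the first real token.
    Returns one list of booleans per row; True marks positions that may be
    attended to, False marks left padding.
    """
    return [[pos >= start for pos in range(seq_len)] for start in seq_start_pos]


def combine_masks(kv_mask, pad_mask):
    """Elementwise AND, as in `self_attn_kv_mask & left_pad_mask`."""
    if kv_mask is None:
        return pad_mask
    return [[a and b for a, b in zip(row_a, row_b)]
            for row_a, row_b in zip(kv_mask, pad_mask)]


# row 0 has no padding, row 1 has two padded positions on the left
mask = left_pad_mask(5, [0, 2])
merged = combine_masks([[True, True, True, True, False]] * 2, mask)
```

In this sketch a position is kept only if both the caller's mask and the padding mask allow it, which is the effect of the `&` in the diff.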
+ class Encoder(AttentionLayers):
+     def __init__(self, **kwargs):
+         assert 'causal' not in kwargs, 'cannot set causality on encoder'
+         super().__init__(causal = False, **kwargs)
+
+ class Decoder(AttentionLayers):
+     def __init__(self, **kwargs):
+         assert 'causal' not in kwargs, 'cannot set causality on decoder'
+         super().__init__(causal = True, **kwargs)
+
+ class CrossAttender(AttentionLayers):
+     def __init__(self, **kwargs):
+         super().__init__(cross_attend = True, only_cross = True, **kwargs)
2048
+
2049
+ class ViTransformerWrapper(nn.Module):
2050
+ def __init__(
2051
+ self,
2052
+ *,
2053
+ image_size,
2054
+ patch_size,
2055
+ attn_layers,
2056
+ channels = 3,
2057
+ num_classes = None,
2058
+ post_emb_norm = False,
2059
+ num_register_tokens = 0,
2060
+ emb_dropout = 0.
2061
+ ):
2062
+ super().__init__()
2063
+ assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder'
2064
+ assert divisible_by(image_size, patch_size), 'image dimensions must be divisible by the patch size'
2065
+ dim = attn_layers.dim
2066
+ num_patches = (image_size // patch_size) ** 2
2067
+ patch_dim = channels * patch_size ** 2
2068
+
2069
+ self.patch_size = patch_size
2070
+
2071
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
2072
+
2073
+ has_register_tokens = num_register_tokens > 0
2074
+ self.has_register_tokens = has_register_tokens
2075
+
2076
+ if has_register_tokens:
2077
+ self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))
2078
+
2079
+ self.patch_to_embedding = nn.Sequential(
2080
+ nn.LayerNorm(patch_dim),
2081
+ nn.Linear(patch_dim, dim),
2082
+ nn.LayerNorm(dim)
2083
+ )
2084
+
2085
+ self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity()
2086
+ self.dropout = nn.Dropout(emb_dropout)
2087
+
2088
+ self.attn_layers = attn_layers
2089
+
2090
+ self.mlp_head = nn.Linear(dim, num_classes) if exists(num_classes) else nn.Identity()
2091
+
2092
+ def forward(
2093
+ self,
2094
+ img,
2095
+ return_embeddings = False
2096
+ ):
2097
+ b, p = img.shape[0], self.patch_size
2098
+
2099
+ x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
2100
+ x = self.patch_to_embedding(x)
2101
+ n = x.shape[1]
2102
+
2103
+ x = x + self.pos_embedding[:, :n]
2104
+
2105
+ x = self.post_emb_norm(x)
2106
+ x = self.dropout(x)
2107
+
2108
+ if self.has_register_tokens:
2109
+ r = repeat(self.register_tokens, 'n d -> b n d', b = b)
2110
+ x, ps = pack((x, r), 'b * d')
2111
+
2112
+ x = self.attn_layers(x)
2113
+
2114
+ if self.has_register_tokens:
2115
+ x, _ = unpack(x, ps, 'b * d')
2116
+
2117
+ if not exists(self.mlp_head) or return_embeddings:
2118
+ return x
2119
+
2120
+ x = x.mean(dim = -2)
2121
+ return self.mlp_head(x)
2122
+
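The patch bookkeeping in `ViTransformerWrapper.__init__` is plain integer arithmetic. The sketch below reproduces it so the resulting shapes are easy to check by hand. `patch_geometry` and `attn_seq_len` are names invented for this example, and the register-token count is assumed to simply add to the sequence length while the attention layers run, since the diff packs the register tokens after the patch embeddings and unpacks them afterward.

```python
def patch_geometry(image_size, patch_size, channels=3):
    """Return (num_patches, patch_dim) for a square image, as in the diff."""
    assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
    num_patches = (image_size // patch_size) ** 2
    patch_dim = channels * patch_size ** 2
    return num_patches, patch_dim


def attn_seq_len(num_patches, num_register_tokens=0):
    """Sequence length seen by the attention layers (patches plus registers)."""
    return num_patches + num_register_tokens


num_patches, patch_dim = patch_geometry(224, 16)
seq_len = attn_seq_len(num_patches, num_register_tokens=4)
```

For a 224x224 RGB image with 16x16 patches this gives 196 patches of 768 values each, and 200 positions inside the attention layers when four register tokens are used.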
2123
+ class TransformerWrapper(nn.Module):
2124
+ def __init__(
2125
+ self,
2126
+ *,
2127
+ num_tokens,
2128
+ max_seq_len,
2129
+ attn_layers,
2130
+ emb_dim = None,
2131
+ max_mem_len = 0,
2132
+ shift_mem_down = 0,
2133
+ emb_dropout = 0.,
2134
+ post_emb_norm = False,
2135
+ num_memory_tokens = None,
2136
+ memory_tokens_interspersed_every = None,
2137
+ tie_embedding = False,
2138
+ logits_dim = None,
2139
+ use_abs_pos_emb = True,
2140
+ scaled_sinu_pos_emb = False,
2141
+ l2norm_embed = False,
2142
+ emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1
2143
+ attn_z_loss_weight = 1e-4,
2144
+ ):
2145
+ super().__init__()
2146
+ assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
2147
+
2148
+ dim = attn_layers.dim
2149
+ emb_dim = default(emb_dim, dim)
2150
+ self.emb_dim = emb_dim
2151
+ self.num_tokens = num_tokens
2152
+
2153
+ self.max_seq_len = max_seq_len
2154
+ self.max_mem_len = max_mem_len
2155
+ self.shift_mem_down = shift_mem_down
2156
+
2157
+ self.l2norm_embed = l2norm_embed
2158
+ self.token_emb = TokenEmbedding(emb_dim, num_tokens, l2norm_embed = l2norm_embed)
2159
+
2160
+ if not (use_abs_pos_emb and not attn_layers.has_pos_emb):
2161
+ self.pos_emb = always(0)
2162
+ elif scaled_sinu_pos_emb:
2163
+ self.pos_emb = ScaledSinusoidalEmbedding(emb_dim)
2164
+ else:
2165
+ self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len, l2norm_embed = l2norm_embed)
2166
+
2167
+ self.emb_frac_gradient = emb_frac_gradient # fraction of the gradient that should go to the embedding, https://arxiv.org/abs/2105.13290
2168
+
2169
+ self.post_emb_norm = nn.LayerNorm(emb_dim) if post_emb_norm else nn.Identity()
2170
+ self.emb_dropout = nn.Dropout(emb_dropout)
2171
+
2172
+ self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
2173
+ self.attn_layers = attn_layers
2174
+
2175
+ self.init_()
2176
+
2177
+ logits_dim = default(logits_dim, num_tokens)
2178
+ self.to_logits = nn.Linear(dim, logits_dim) if not tie_embedding else lambda t: t @ self.token_emb.emb.weight.t()
2179
+
2180
+ # memory tokens (like [cls]) from Memory Transformers paper
2181
+
2182
+ num_memory_tokens = default(num_memory_tokens, 0)
2183
+ self.num_memory_tokens = num_memory_tokens
2184
+ if num_memory_tokens > 0:
2185
+ self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
2186
+
2187
+ self.memory_tokens_interspersed_every = memory_tokens_interspersed_every
2188
+
2189
+ # whether can do cached kv decoding
2190
+
2191
+ self.can_cache_kv = self.num_memory_tokens == 0
2192
+
+     def init_(self):
+         if self.l2norm_embed:
+             nn.init.normal_(self.token_emb.emb.weight, std = 1e-5)
+             if not isinstance(self.pos_emb, always):
+                 nn.init.normal_(self.pos_emb.emb.weight, std = 1e-5)
+             return
+
+         nn.init.kaiming_normal_(self.token_emb.emb.weight)
2201
+
2202
+ def forward(
2203
+ self,
2204
+ x,
2205
+ return_embeddings = False,
2206
+ return_logits_and_embeddings = False,
2207
+ return_intermediates = False,
2208
+ mask = None,
2209
+ return_mems = False,
2210
+ return_attn = False,
2211
+ mems = None,
2212
+ pos = None,
2213
+ prepend_embeds = None,
2214
+ sum_embeds = None,
2215
+ return_attn_z_loss = False,
2216
+ attn_z_loss_weight = 1e-4,
2217
+ seq_start_pos = None,
2218
+ cache: Optional[LayerIntermediates] = None,
2219
+ **kwargs
2220
+ ):
2221
+ b, n, device, num_mems, has_memory_tokens, emb_frac_gradient = *x.shape, x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient
2222
+ return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss
2223
+
2224
+ # absolute positional embedding
2225
+
2226
+ external_pos_emb = exists(pos) and pos.dtype != torch.long
2227
+ pos_emb = self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos) if not external_pos_emb else pos
2228
+ x = self.token_emb(x) + pos_emb
2229
+
2230
+ # for summing embeddings passed externally - needs this for self-conditioning in non-autoregressive training
2231
+
2232
+ if exists(sum_embeds):
2233
+ x = x + sum_embeds
2234
+
2235
+ # post embedding norm, purportedly leads to greater stabilization
2236
+
2237
+ x = self.post_emb_norm(x)
2238
+
+         # whether to prepend embeds, as in PaLI, for image embeddings
+
+         if exists(prepend_embeds):
+             prepend_seq, prepend_dim = prepend_embeds.shape[1:]
+             assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as text model dimensions'
+
+             x = torch.cat((prepend_embeds, x), dim = -2)
2246
+
2247
+ # whether to reduce the gradient going to the embedding, from cogview paper, corroborated by GLM-130B model
2248
+
2249
+ if emb_frac_gradient < 1:
2250
+ assert emb_frac_gradient > 0
2251
+ x = x * emb_frac_gradient + x.detach() * (1 - emb_frac_gradient)
2252
+
2253
+ # embedding dropout
2254
+
2255
+ x = self.emb_dropout(x)
2256
+
2257
+ x = self.project_emb(x)
2258
+
+         if has_memory_tokens:
+             mem_every = self.memory_tokens_interspersed_every
+
+             if exists(mem_every):
+                 assert mem_every > 0
+                 assert isinstance(self.attn_layers, Decoder), 'only for decoder'
+                 next_seq_len = math.ceil(n / mem_every) * mem_every
+
+                 x = pad_at_dim(x, (0, next_seq_len - n), dim = -2, value = 0.)
+                 x = rearrange(x, 'b (n m) d -> (b n) m d', m = mem_every)
+
+             mem = repeat(self.memory_tokens, 'n d -> b n d', b = x.shape[0])
+             x, mem_packed_shape = pack((mem, x), 'b * d')
+
+             # auto-handle masking after prepending memory tokens
+             if not exists(mem_every) and exists(mask):
+                 mask = pad_at_dim(mask, (num_mems, 0), dim = -1, value = True)
+
+             if exists(mem_every):
+                 x = rearrange(x, '(b n) m d -> b (n m) d', b = b)
2280
+ if self.shift_mem_down and exists(mems):
2281
+ mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:]
2282
+ mems = [*mems_r, *mems_l]
2283
+
2284
+ x, intermediates = self.attn_layers(x, mask = mask, mems = mems, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
2285
+
2286
+ if has_memory_tokens:
2287
+ if exists(mem_every):
2288
+ x = rearrange(x, 'b (n m) d -> (b n) m d', m = (mem_every + num_mems))
2289
+
2290
+ mem, x = unpack(x, mem_packed_shape, 'b * d')
2291
+
2292
+ if exists(mem_every):
2293
+ x = rearrange(x, '(b n) m d -> b (n m) d', b = b)
2294
+
2295
+ x = x[:, :n]
2296
+
2297
+ if return_logits_and_embeddings:
2298
+ out = (self.to_logits(x), x)
2299
+ elif return_embeddings:
2300
+ out = x
2301
+ else:
2302
+ out = self.to_logits(x)
2303
+
2304
+ if return_attn_z_loss:
2305
+ pre_softmax_attns = list(map(lambda t: t.pre_softmax_attn, intermediates.attn_intermediates))
2306
+ intermediates.attn_z_loss = calc_z_loss(pre_softmax_attns, weight = attn_z_loss_weight)
2307
+ return_intermediates = True
2308
+
2309
+ if return_mems:
2310
+ hiddens = intermediates.hiddens
2311
+ new_mems = list(map(lambda pair: torch.cat(pair, dim = -2), zip(mems, hiddens))) if exists(mems) else hiddens
2312
+ new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
2313
+
2314
+ if not return_intermediates:
2315
+ return out, new_mems
2316
+
2317
+ intermediates.mems = new_mems
2318
+
2319
+ if return_intermediates:
2320
+ return out, intermediates
2321
+
2322
+ if return_attn:
2323
+ attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
2324
+ return out, attn_maps
2325
+
2326
+ return out
2327
+
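The memory-token handling in `TransformerWrapper.forward` changes the sequence layout before the attention layers run. The following sketch shows that layout on a list of placeholder tokens instead of tensors. It is an illustration under assumptions: `interleave_memory` is a hypothetical helper, and the strings `'M'` and `'pad'` stand in for learned memory tokens and zero padding.

```python
import math


def interleave_memory(tokens, num_mems, mem_every=None, mem='M', pad='pad'):
    """Return the token layout that the attention layers would see.

    Without mem_every, memory tokens are prepended once. With mem_every,
    the sequence is padded to a multiple of mem_every, split into chunks,
    and the memory tokens are prepended to every chunk.
    """
    if num_mems == 0:
        return list(tokens)

    mems = [mem] * num_mems

    if mem_every is None:
        return mems + list(tokens)

    assert mem_every > 0
    n = len(tokens)
    next_seq_len = math.ceil(n / mem_every) * mem_every
    padded = list(tokens) + [pad] * (next_seq_len - n)

    out = []
    for start in range(0, next_seq_len, mem_every):
        out.extend(mems + padded[start:start + mem_every])
    return out


layout = interleave_memory([0, 1, 2, 3, 4], num_mems=1, mem_every=2)
prefix_only = interleave_memory([0, 1, 2], num_mems=2)
```

After attention, the diff unpacks the memory tokens again and trims the result with `x[:, :n]`, so the padding introduced here does not reach the logits.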
2328
+ class ContinuousTransformerWrapper(nn.Module):
2329
+ def __init__(
2330
+ self,
2331
+ *,
2332
+ max_seq_len,
2333
+ attn_layers,
2334
+ dim_in = None,
2335
+ dim_out = None,
2336
+ emb_dim = None,
2337
+ max_mem_len = 0,
2338
+ post_emb_norm = False,
2339
+ emb_dropout = 0.,
2340
+ use_abs_pos_emb = True,
2341
+ scaled_sinu_pos_emb = False
2342
+ ):
2343
+ super().__init__()
2344
+ assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
2345
+
2346
+ dim = attn_layers.dim
2347
+
2348
+ self.max_seq_len = max_seq_len
2349
+
2350
+ self.max_mem_len = max_mem_len
2351
+
2352
+ if not (use_abs_pos_emb and not attn_layers.has_pos_emb):
2353
+ self.pos_emb = always(0)
2354
+ elif scaled_sinu_pos_emb:
2355
+ self.pos_emb = ScaledSinusoidalEmbedding(dim)
2356
+ else:
2357
+ self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len)
2358
+
2359
+ self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity()
2360
+ self.emb_dropout = nn.Dropout(emb_dropout)
2361
+
2362
+ self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
2363
+
2364
+ self.attn_layers = attn_layers
2365
+
2366
+ self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
2367
+
2368
+ def forward(
2369
+ self,
2370
+ x,
2371
+ return_embeddings = False,
2372
+ return_intermediates = False,
2373
+ return_mems = False,
2374
+ mask = None,
2375
+ return_attn = False,
2376
+ mems = None,
2377
+ pos = None,
2378
+ prepend_embeds = None,
2379
+ **kwargs
2380
+ ):
2381
+ x = self.project_in(x)
2382
+ x = x + self.pos_emb(x, pos = pos)
2383
+
2384
+ x = self.post_emb_norm(x)
2385
+
+         # whether to prepend embeds, as in PaLI, for image embeddings
+
+         if exists(prepend_embeds):
+             _, prepend_dim = prepend_embeds.shape[1:]
+             assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as model dimensions'
+
+             x = torch.cat((prepend_embeds, x), dim = -2)
2393
+
2394
+ x = self.emb_dropout(x)
2395
+
2396
+ x, intermediates = self.attn_layers(x, mask = mask, mems = mems, return_hiddens = True, **kwargs)
2397
+
2398
+ out = self.project_out(x) if not return_embeddings else x
2399
+
2400
+ if return_intermediates:
2401
+ return out, intermediates
2402
+
2403
+ if return_mems:
2404
+ hiddens = intermediates.hiddens
2405
+ new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), hiddens))
2406
+ return out, new_mems
2407
+
2408
+ if return_attn:
2409
+ attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
2410
+ return out, attn_maps
2411
+
2412
+ return out
2413
+
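Both wrappers trim returned memories with the slice `t[..., -self.max_mem_len:, :]`, and `max_mem_len` defaults to 0. In Python, `-0` equals `0`, so a slice of the form `[-0:]` keeps the whole sequence and does not return an empty one. The sketch below demonstrates this on a plain list. `keep_last` copies the slice pattern from the diff, while `keep_last_guarded` is a hypothetical variant shown only for contrast and is not something the uploaded file does.

```python
def keep_last(seq, max_mem_len):
    """Same slice pattern as the diff: keep the trailing max_mem_len items."""
    return seq[-max_mem_len:]


def keep_last_guarded(seq, max_mem_len):
    """Hypothetical variant that returns nothing when max_mem_len is 0."""
    if max_mem_len <= 0:
        return seq[:0]
    return seq[-max_mem_len:]


hiddens = [10, 11, 12, 13]

last_two = keep_last(hiddens, 2)
with_zero = keep_last(hiddens, 0)            # -0 == 0, so nothing is trimmed
guarded_zero = keep_last_guarded(hiddens, 0)
```

Callers who request `return_mems` while leaving `max_mem_len` at its default may want to check which of the two behaviors they expect.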
2414
+ class XTransformer(nn.Module):
2415
+ def __init__(
2416
+ self,
2417
+ *,
2418
+ dim,
2419
+ tie_token_emb = False,
2420
+ ignore_index = -100,
2421
+ pad_value = 0,
2422
+ cross_attn_tokens_dropout = 0.,
2423
+ **kwargs
2424
+ ):
2425
+ super().__init__()
2426
+ enc_kwargs, kwargs = groupby_prefix_and_trim('enc_', kwargs)
2427
+ dec_kwargs, kwargs = groupby_prefix_and_trim('dec_', kwargs)
2428
+
2429
+ assert 'dim' not in enc_kwargs and 'dim' not in dec_kwargs, 'dimension of both encoder and decoder must be set with the `dim` keyword, not `enc_dim` or `dec_dim`'
2430
+ enc_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], enc_kwargs)
2431
+ enc_transformer_kwargs['emb_dropout'] = enc_kwargs.pop('emb_dropout', 0)
2432
+ enc_transformer_kwargs['num_memory_tokens'] = enc_kwargs.pop('num_memory_tokens', None)
2433
+ enc_transformer_kwargs['scaled_sinu_pos_emb'] = enc_kwargs.pop('scaled_sinu_pos_emb', False)
2434
+ enc_transformer_kwargs['use_abs_pos_emb'] = enc_kwargs.pop('use_abs_pos_emb', True)
2435
+
2436
+ dec_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], dec_kwargs)
2437
+ dec_transformer_kwargs['emb_dropout'] = dec_kwargs.pop('emb_dropout', 0)
2438
+ dec_transformer_kwargs['scaled_sinu_pos_emb'] = dec_kwargs.pop('scaled_sinu_pos_emb', False)
2439
+ dec_transformer_kwargs['use_abs_pos_emb'] = dec_kwargs.pop('use_abs_pos_emb', True)
2440
+
2441
+ self.cross_attn_tokens_dropout = cross_attn_tokens_dropout # how many tokens from the encoder to dropout when cross attending from decoder - seen in a couple papers, including Perceiver AR - this will also be very effective regularization when cross attending to very long memories
2442
+
2443
+ self.encoder = TransformerWrapper(
2444
+ **enc_transformer_kwargs,
2445
+ attn_layers = Encoder(dim = dim, **enc_kwargs)
2446
+ )
2447
+
2448
+ self.decoder = TransformerWrapper(
2449
+ **dec_transformer_kwargs,
2450
+ attn_layers = Decoder(dim = dim, cross_attend = True, **dec_kwargs)
2451
+ )
2452
+
2453
+ if tie_token_emb:
2454
+ self.decoder.token_emb = self.encoder.token_emb
2455
+
2456
+ self.decoder = AutoregressiveWrapper(self.decoder, ignore_index=ignore_index, pad_value=pad_value)
2457
+
+     @torch.no_grad()
+     def generate(self, seq_in, seq_out_start, seq_len, mask = None, attn_mask = None, **kwargs):
+         encodings = self.encoder(seq_in, mask = mask, attn_mask = attn_mask, return_embeddings = True)
+         return self.decoder.generate(seq_out_start, seq_len, context = encodings, context_mask = mask, **kwargs)
+
+     def forward(self, src, tgt, mask = None, attn_mask = None, src_prepend_embeds = None):
+
+         if exists(src_prepend_embeds) and exists(mask):
+             mask = pad_at_dim(mask, (src_prepend_embeds.shape[-2], 0), dim = -1, value = True)
+
+         enc = self.encoder(src, mask = mask, attn_mask = attn_mask, prepend_embeds = src_prepend_embeds, return_embeddings = True)
+
+         if self.training and self.cross_attn_tokens_dropout > 0:
+             enc, mask = dropout_seq(enc, mask, self.cross_attn_tokens_dropout)
+
+         out = self.decoder(tgt, context = enc, context_mask = mask)
+         return out
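`XTransformer.__init__` routes keyword arguments to the encoder and decoder by their `enc_` and `dec_` prefixes. The helpers it calls, `groupby_prefix_and_trim` and `pick_and_pop`, are defined outside this part of the diff. The sketch below uses simplified stand-ins, `split_prefixed` and `pick_and_pop_sketch`, whose behavior is assumed from how the results are used here, not copied from the file.

```python
def split_prefixed(prefix, kwargs):
    """Return (kwargs that started with prefix, prefix removed; the rest).

    Simplified stand-in for the file's prefix-grouping helper.
    """
    matched = {k[len(prefix):]: v for k, v in kwargs.items() if k.startswith(prefix)}
    rest = {k: v for k, v in kwargs.items() if not k.startswith(prefix)}
    return matched, rest


def pick_and_pop_sketch(keys, d):
    """Remove the given keys from d and return them as a new dict."""
    return {k: d.pop(k) for k in keys}


kwargs = dict(
    enc_num_tokens=256, enc_max_seq_len=1024, enc_depth=6,
    dec_num_tokens=256, dec_max_seq_len=512, dec_depth=4, dec_heads=8,
)

enc_kwargs, kwargs = split_prefixed('enc_', kwargs)
dec_kwargs, kwargs = split_prefixed('dec_', kwargs)

# wrapper-level settings are peeled off; what remains goes to Encoder(...)
enc_wrapper_kwargs = pick_and_pop_sketch(['num_tokens', 'max_seq_len'], enc_kwargs)
```

After this split, `enc_wrapper_kwargs` holds what the diff passes to `TransformerWrapper`, and the remaining `enc_kwargs` are what it forwards to `Encoder(dim = dim, **enc_kwargs)`.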