jbetker commited on
Commit
73e9929
1 Parent(s): 33e4bc7

new autoregressive check-in

Browse files
api_new_autoregressive.py CHANGED
@@ -134,8 +134,8 @@ class TextToSpeech:
134
  self.tokenizer = VoiceBpeTokenizer()
135
  download_models()
136
 
137
- self.autoregressive = AutoregressiveCodegen(512, 12).cpu().eval()
138
- self.autoregressive.load_state_dict(torch.load('D:\\dlas\\experiments\\train_autoregressive_codegen\\models\\23000_codegen_ema.pth'))
139
 
140
  self.clip = VoiceCLIP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=12,
141
  text_seq_len=350, text_heads=8,
 
134
  self.tokenizer = VoiceBpeTokenizer()
135
  download_models()
136
 
137
+ self.autoregressive = AutoregressiveCodegen(1024, 16).cpu().eval()
138
+ self.autoregressive.load_state_dict(torch.load('X:\\dlas\\experiments\\train_autoregressive_codegen\\models\\11000_codegen_ema.pth'))
139
 
140
  self.clip = VoiceCLIP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=12,
141
  text_seq_len=350, text_heads=8,
models/new_autoregressive.py CHANGED
@@ -1,11 +1,9 @@
1
- import functools
2
-
3
  import torch
4
  import torch.nn as nn
5
  import torch.nn.functional as F
6
  from transformers import GPT2PreTrainedModel, GPT2Config
 
7
  from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
8
- from x_transformers import TransformerWrapper, Encoder, Decoder
9
 
10
  from models.arch_util import AttentionBlock
11
 
@@ -87,8 +85,8 @@ class InferenceModel(GPT2PreTrainedModel):
87
  assert labels is None # Training not supported by this inference model.
88
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
89
 
90
- hidden_states = self.transformer.decoder(input_ids, context=self.context, return_embeddings=True)
91
- logits = self.transformer.decoder.transformer.to_logits(hidden_states)
92
 
93
  if not return_dict:
94
  return (logits, )
@@ -157,54 +155,22 @@ class ConditioningEncoder(nn.Module):
157
  return h.mean(dim=2)
158
 
159
 
160
- class CheckpointedLayer(nn.Module):
161
- """
162
- Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses
163
- checkpoint for all other args.
164
- """
165
- def __init__(self, wrap):
166
- super().__init__()
167
- self.wrap = wrap
168
-
169
- def forward(self, x, *args, **kwargs):
170
- for k, v in kwargs.items():
171
- assert not (isinstance(v, torch.Tensor) and v.requires_grad) # This would screw up checkpointing.
172
- partial = functools.partial(self.wrap, **kwargs)
173
- return torch.utils.checkpoint.checkpoint(partial, x, *args)
174
-
175
-
176
- class CheckpointedXTransformerWrapper(nn.Module):
177
- """
178
- Wraps a TransformerWrapper and applies CheckpointedLayer to each layer.
179
- """
180
- def __init__(self, checkpoint=True, **xtransformer_kwargs):
181
- super().__init__()
182
- self.transformer = TransformerWrapper(**xtransformer_kwargs)
183
-
184
- if not checkpoint:
185
- return
186
- for i in range(len(self.transformer.attn_layers.layers)):
187
- n, b, r = self.transformer.attn_layers.layers[i]
188
- self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r])
189
-
190
- def forward(self, x, **kwargs):
191
- return self.transformer(x, **kwargs)
192
-
193
-
194
  class AutoregressiveCodegen(nn.Module):
195
- def __init__(self, model_dim, depth, num_text_tokens=256, num_mel_tokens=8194, max_text_tokens=4000,
196
- max_mel_tokens=4000, dropout=.1):
197
  super().__init__()
 
198
 
199
  self.START_TOKEN=8192
200
  self.STOP_TOKEN=8193
201
- self.max_mel_tokens = max_mel_tokens
202
- self.minicoder = ConditioningEncoder(80, model_dim, do_checkpointing=False)
203
- self.encoder = CheckpointedXTransformerWrapper(
 
204
  num_tokens=num_text_tokens,
205
- max_seq_len=max_text_tokens,
 
206
  attn_layers = Encoder(
207
- depth=depth//2,
208
  heads=model_dim//64,
209
  dim=model_dim,
210
  attn_dropout=dropout,
@@ -213,11 +179,14 @@ class AutoregressiveCodegen(nn.Module):
213
  ff_glu=True,
214
  ff_mult=1,
215
  rotary_pos_emb=True,
216
- rel_pos_bias=True,
217
  ))
218
- self.decoder = CheckpointedXTransformerWrapper(
 
 
219
  num_tokens=num_mel_tokens,
220
- max_seq_len=max_mel_tokens,
 
221
  attn_layers=Decoder(
222
  depth=depth,
223
  heads=model_dim//64,
@@ -228,18 +197,21 @@ class AutoregressiveCodegen(nn.Module):
228
  ff_glu=True,
229
  ff_mult=1,
230
  rotary_pos_emb=True,
231
- rel_pos_bias=True,
232
  cross_attend=True,
 
233
  ))
234
 
235
  def get_grad_norm_parameter_groups(self):
236
  return {
237
  'encoder': list(self.encoder.parameters()),
238
  'decoder': list(self.decoder.parameters()),
239
- 'minicoder': list(self.minicoder.parameters()),
240
  }
241
 
242
  def forward(self, text_codes, conditioning_signal, mel_codes, wav_lengths, return_loss=True):
 
 
 
243
  # Format mel_codes with a stop token on the end.
244
  mel_lengths = wav_lengths // 1024 + 1
245
  for b in range(mel_codes.shape[0]):
@@ -251,43 +223,51 @@ class AutoregressiveCodegen(nn.Module):
251
  conditioning_signal = conditioning_signal.unsqueeze(1)
252
  cond_embs = []
253
  for i in range(conditioning_signal.shape[1]):
254
- cond_embs.append(self.minicoder(conditioning_signal[:, i]))
255
  cond_emb = torch.stack(cond_embs, dim=1).mean(dim=1, keepdim=True)
256
- enc_text = self.encoder(text_codes, return_embeddings=True)
257
- context = torch.cat([cond_emb, enc_text], dim=1)
 
 
 
 
258
 
259
  # Execute the decoder
260
  dec_inputs = F.pad(mel_codes, (1,0), value=self.START_TOKEN)[:, :-1]
261
- dec = self.decoder(dec_inputs, context=context)
262
  if not return_loss:
263
  return dec
264
  loss_mel = F.cross_entropy(dec.permute(0,2,1), mel_codes)
265
  return loss_mel
266
 
267
- def generate(self, conditioning_signal, text_codes, **hf_generate_kwargs):
268
- if not hasattr(self, 'inference_model'):
269
- self.inference_model = InferenceModel(self)
270
-
271
  if len(conditioning_signal.shape) != 4:
272
  conditioning_signal = conditioning_signal.unsqueeze(1)
273
  cond_embs = []
274
  for i in range(conditioning_signal.shape[1]):
275
- cond_embs.append(self.minicoder(conditioning_signal[:, i]))
276
  cond_emb = torch.stack(cond_embs, dim=1).mean(dim=1, keepdim=True)
277
- enc_text = self.encoder(text_codes, return_embeddings=True)
278
- context = torch.cat([cond_emb, enc_text], dim=1)
279
- self.inference_model.store_context(context)
280
-
281
- gen = self.inference_model.generate(bos_token_id=self.START_TOKEN, pad_token_id=self.STOP_TOKEN, eos_token_id=self.STOP_TOKEN,
282
- max_length=250, output_attentions=False, return_dict_in_generate=True,
 
 
 
 
283
  **hf_generate_kwargs)
284
  return gen.sequences
285
 
286
 
287
  if __name__ == '__main__':
288
- codegen = AutoregressiveCodegen(1024, 20)
289
- codegen.generate(torch.randn((1,80,120)), torch.randint(0,256,(1,200)))
 
290
  codegen(torch.randint(0,256, (2,200)),
291
  torch.randn(2,80,120),
292
  torch.randint(0,8192, (2,350)),
293
- torch.tensor([192,350]))
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
  from transformers import GPT2PreTrainedModel, GPT2Config
5
+ from models.xtransformers import TransformerWrapper, Encoder, Decoder
6
  from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
 
7
 
8
  from models.arch_util import AttentionBlock
9
 
 
85
  assert labels is None # Training not supported by this inference model.
86
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
87
 
88
+ hidden_states = self.transformer.decoder(input_ids, full_context=self.context, return_embeddings=True)
89
+ logits = self.transformer.decoder.to_logits(hidden_states)
90
 
91
  if not return_dict:
92
  return (logits, )
 
155
  return h.mean(dim=2)
156
 
157
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  class AutoregressiveCodegen(nn.Module):
159
+ def __init__(self, model_dim, depth, num_text_tokens=256, num_mel_tokens=8194, dropout=.1):
 
160
  super().__init__()
161
+ assert depth >= 8 # This is the minimum bound to support the context interleaving that happens later.
162
 
163
  self.START_TOKEN=8192
164
  self.STOP_TOKEN=8193
165
+ self.max_text_token_id = num_text_tokens
166
+ self.max_mel_token_id = num_mel_tokens
167
+ self.mel_embedding = ConditioningEncoder(80, model_dim, do_checkpointing=False)
168
+ self.encoder = TransformerWrapper(
169
  num_tokens=num_text_tokens,
170
+ use_pos_emb=False,
171
+ max_seq_len=-1,
172
  attn_layers = Encoder(
173
+ depth=depth,
174
  heads=model_dim//64,
175
  dim=model_dim,
176
  attn_dropout=dropout,
 
179
  ff_glu=True,
180
  ff_mult=1,
181
  rotary_pos_emb=True,
182
+ attn_rel_pos_bias=True,
183
  ))
184
+ self.encoder.norm = nn.Identity() # This layer and the next are unused.
185
+ self.encoder.to_logits = nn.Identity()
186
+ self.decoder = TransformerWrapper(
187
  num_tokens=num_mel_tokens,
188
+ use_pos_emb=False,
189
+ max_seq_len=-1,
190
  attn_layers=Decoder(
191
  depth=depth,
192
  heads=model_dim//64,
 
197
  ff_glu=True,
198
  ff_mult=1,
199
  rotary_pos_emb=True,
 
200
  cross_attend=True,
201
+ attn_rel_pos_bias=True,
202
  ))
203
 
204
  def get_grad_norm_parameter_groups(self):
205
  return {
206
  'encoder': list(self.encoder.parameters()),
207
  'decoder': list(self.decoder.parameters()),
208
+ 'minicoder': list(self.mel_embedding.parameters()),
209
  }
210
 
211
  def forward(self, text_codes, conditioning_signal, mel_codes, wav_lengths, return_loss=True):
212
+ assert text_codes.max() < self.max_text_token_id and text_codes.min() >= 0, f'Invalid text code encountered: {text_codes.max()}, {text_codes.min()}'
213
+ assert mel_codes.max() < self.max_mel_token_id and mel_codes.min() >= 0, f'Invalid mel code encountered: {mel_codes.max()}, {mel_codes.min()}'
214
+
215
  # Format mel_codes with a stop token on the end.
216
  mel_lengths = wav_lengths // 1024 + 1
217
  for b in range(mel_codes.shape[0]):
 
223
  conditioning_signal = conditioning_signal.unsqueeze(1)
224
  cond_embs = []
225
  for i in range(conditioning_signal.shape[1]):
226
+ cond_embs.append(self.mel_embedding(conditioning_signal[:, i]))
227
  cond_emb = torch.stack(cond_embs, dim=1).mean(dim=1, keepdim=True)
228
+ _, enc_text = self.encoder(text_codes, return_hiddens=True)
229
+ # Interleave cond_emb into the first few contexts.
230
+ full_context = enc_text
231
+ full_context[1] = cond_emb
232
+ full_context[3] = cond_emb
233
+ full_context[6] = cond_emb
234
 
235
  # Execute the decoder
236
  dec_inputs = F.pad(mel_codes, (1,0), value=self.START_TOKEN)[:, :-1]
237
+ dec = self.decoder(dec_inputs, full_context=full_context)
238
  if not return_loss:
239
  return dec
240
  loss_mel = F.cross_entropy(dec.permute(0,2,1), mel_codes)
241
  return loss_mel
242
 
243
+ def generate(self, conditioning_signal, text_codes, max_tokens=256, **hf_generate_kwargs):
244
+ inference_model = InferenceModel(self)
245
+ # Build the context
 
246
  if len(conditioning_signal.shape) != 4:
247
  conditioning_signal = conditioning_signal.unsqueeze(1)
248
  cond_embs = []
249
  for i in range(conditioning_signal.shape[1]):
250
+ cond_embs.append(self.mel_embedding(conditioning_signal[:, i]))
251
  cond_emb = torch.stack(cond_embs, dim=1).mean(dim=1, keepdim=True)
252
+ _, enc_text = self.encoder(text_codes, return_hiddens=True)
253
+ # Interleave cond_emb into the first few contexts.
254
+ full_context = enc_text
255
+ full_context[1] = cond_emb
256
+ full_context[3] = cond_emb
257
+ full_context[6] = cond_emb
258
+ inference_model.store_context(full_context)
259
+
260
+ gen = inference_model.generate(bos_token_id=self.START_TOKEN, pad_token_id=self.STOP_TOKEN, eos_token_id=self.STOP_TOKEN,
261
+ max_length=max_tokens, output_attentions=False, return_dict_in_generate=True,
262
  **hf_generate_kwargs)
263
  return gen.sequences
264
 
265
 
266
  if __name__ == '__main__':
267
+ codegen = AutoregressiveCodegen(256, 10)
268
+ torch.save(codegen.state_dict(), 'sample.pth')
269
+ #codegen.generate(torch.randn((1,80,120)), torch.randint(0,256,(1,200)))
270
  codegen(torch.randint(0,256, (2,200)),
271
  torch.randn(2,80,120),
272
  torch.randint(0,8192, (2,350)),
273
+ torch.tensor([192,350]))
models/xtransformers.py ADDED
@@ -0,0 +1,1259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import math
3
+ import torch
4
+ from torch import nn, einsum
5
+ import torch.nn.functional as F
6
+ from functools import partial
7
+ from inspect import isfunction
8
+ from collections import namedtuple
9
+
10
+ from einops import rearrange, repeat, reduce
11
+ from einops.layers.torch import Rearrange
12
+
13
+ from entmax import entmax15
14
+ from torch.utils.checkpoint import checkpoint
15
+
16
+ from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
17
+
18
+ DEFAULT_DIM_HEAD = 64
19
+
20
+ Intermediates = namedtuple('Intermediates', [
21
+ 'pre_softmax_attn',
22
+ 'post_softmax_attn'
23
+ ])
24
+
25
+ LayerIntermediates = namedtuple('Intermediates', [
26
+ 'hiddens',
27
+ 'attn_intermediates'
28
+ ])
29
+
30
+
31
+ # helpers
32
+
33
+ def exists(val):
34
+ return val is not None
35
+
36
+
37
+ def default(val, d):
38
+ if exists(val):
39
+ return val
40
+ return d() if isfunction(d) else d
41
+
42
+
43
+ def cast_tuple(val, depth):
44
+ return val if isinstance(val, tuple) else (val,) * depth
45
+
46
+
47
+ class always():
48
+ def __init__(self, val):
49
+ self.val = val
50
+
51
+ def __call__(self, *args, **kwargs):
52
+ return self.val
53
+
54
+
55
+ class not_equals():
56
+ def __init__(self, val):
57
+ self.val = val
58
+
59
+ def __call__(self, x, *args, **kwargs):
60
+ return x != self.val
61
+
62
+
63
+ class equals():
64
+ def __init__(self, val):
65
+ self.val = val
66
+
67
+ def __call__(self, x, *args, **kwargs):
68
+ return x == self.val
69
+
70
+
71
+ def max_neg_value(tensor):
72
+ return -torch.finfo(tensor.dtype).max
73
+
74
+
75
+ def l2norm(t):
76
+ return F.normalize(t, p=2, dim=-1)
77
+
78
+
79
+ # init helpers
80
+
81
+ def init_zero_(layer):
82
+ nn.init.constant_(layer.weight, 0.)
83
+ if exists(layer.bias):
84
+ nn.init.constant_(layer.bias, 0.)
85
+
86
+
87
+ # keyword argument helpers
88
+
89
+ def pick_and_pop(keys, d):
90
+ values = list(map(lambda key: d.pop(key), keys))
91
+ return dict(zip(keys, values))
92
+
93
+
94
+ def group_dict_by_key(cond, d):
95
+ return_val = [dict(), dict()]
96
+ for key in d.keys():
97
+ match = bool(cond(key))
98
+ ind = int(not match)
99
+ return_val[ind][key] = d[key]
100
+ return (*return_val,)
101
+
102
+
103
+ def string_begins_with(prefix, str):
104
+ return str.startswith(prefix)
105
+
106
+
107
+ def group_by_key_prefix(prefix, d):
108
+ return group_dict_by_key(partial(string_begins_with, prefix), d)
109
+
110
+
111
+ def groupby_prefix_and_trim(prefix, d):
112
+ kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
113
+ kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
114
+ return kwargs_without_prefix, kwargs
115
+
116
+
117
+ # activations
118
+
119
+ class ReluSquared(nn.Module):
120
+ def forward(self, x):
121
+ return F.relu(x) ** 2
122
+
123
+
124
+ # positional embeddings
125
+
126
+ class AbsolutePositionalEmbedding(nn.Module):
127
+ def __init__(self, dim, max_seq_len):
128
+ super().__init__()
129
+ self.scale = dim ** -0.5
130
+ self.emb = nn.Embedding(max_seq_len, dim)
131
+
132
+ def forward(self, x):
133
+ n = torch.arange(x.shape[1], device=x.device)
134
+ pos_emb = self.emb(n)
135
+ pos_emb = rearrange(pos_emb, 'n d -> () n d')
136
+ return pos_emb * self.scale
137
+
138
+
139
+ class FixedPositionalEmbedding(nn.Module):
140
+ def __init__(self, dim):
141
+ super().__init__()
142
+ inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
143
+ self.register_buffer('inv_freq', inv_freq)
144
+
145
+ def forward(self, x, seq_dim=1, offset=0):
146
+ t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
147
+ sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
148
+ emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
149
+ return rearrange(emb, 'n d -> () n d')
150
+
151
+
152
+ class RelativePositionBias(nn.Module):
153
+ def __init__(self, scale, causal=False, num_buckets=32, max_distance=128, heads=8):
154
+ super().__init__()
155
+ self.scale = scale
156
+ self.causal = causal
157
+ self.num_buckets = num_buckets
158
+ self.max_distance = max_distance
159
+ self.relative_attention_bias = nn.Embedding(num_buckets, heads)
160
+
161
+ @staticmethod
162
+ def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128):
163
+ ret = 0
164
+ n = -relative_position
165
+ if not causal:
166
+ num_buckets //= 2
167
+ ret += (n < 0).long() * num_buckets
168
+ n = torch.abs(n)
169
+ else:
170
+ n = torch.max(n, torch.zeros_like(n))
171
+
172
+ max_exact = num_buckets // 2
173
+ is_small = n < max_exact
174
+
175
+ val_if_large = max_exact + (
176
+ torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
177
+ ).long()
178
+ val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
179
+
180
+ ret += torch.where(is_small, n, val_if_large)
181
+ return ret
182
+
183
+ def forward(self, qk_dots):
184
+ i, j, device = *qk_dots.shape[-2:], qk_dots.device
185
+ q_pos = torch.arange(i, dtype=torch.long, device=device)
186
+ k_pos = torch.arange(j, dtype=torch.long, device=device)
187
+ rel_pos = k_pos[None, :] - q_pos[:, None]
188
+ rp_bucket = self._relative_position_bucket(rel_pos, causal=self.causal, num_buckets=self.num_buckets,
189
+ max_distance=self.max_distance)
190
+ values = self.relative_attention_bias(rp_bucket)
191
+ bias = rearrange(values, 'i j h -> () h i j')
192
+ return qk_dots + (bias * self.scale)
193
+
194
+
195
+ class AlibiPositionalBias(nn.Module):
196
+ def __init__(self, heads, **kwargs):
197
+ super().__init__()
198
+ self.heads = heads
199
+ slopes = torch.Tensor(self._get_slopes(heads))
200
+ slopes = rearrange(slopes, 'h -> () h () ()')
201
+ self.register_buffer('slopes', slopes, persistent=False)
202
+ self.register_buffer('bias', None, persistent=False)
203
+
204
+ @staticmethod
205
+ def _get_slopes(heads):
206
+ def get_slopes_power_of_2(n):
207
+ start = (2 ** (-2 ** -(math.log2(n) - 3)))
208
+ ratio = start
209
+ return [start * ratio ** i for i in range(n)]
210
+
211
+ if math.log2(heads).is_integer():
212
+ return get_slopes_power_of_2(heads)
213
+
214
+ closest_power_of_2 = 2 ** math.floor(math.log2(heads))
215
+ return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][
216
+ :heads - closest_power_of_2]
217
+
218
+ def forward(self, qk_dots):
219
+ h, i, j, device = *qk_dots.shape[-3:], qk_dots.device
220
+
221
+ if exists(self.bias) and self.bias.shape[-1] >= j:
222
+ return qk_dots + self.bias[..., :j]
223
+
224
+ bias = torch.arange(j, device=device)
225
+ bias = rearrange(bias, 'j -> () () () j')
226
+ bias = bias * self.slopes
227
+
228
+ num_heads_unalibied = h - bias.shape[1]
229
+ bias = F.pad(bias, (0, 0, 0, 0, 0, num_heads_unalibied))
230
+
231
+ self.register_buffer('bias', bias, persistent=False)
232
+ return qk_dots + self.bias
233
+
234
+
235
+ class LearnedAlibiPositionalBias(AlibiPositionalBias):
236
+ def __init__(self, heads, bidirectional=False):
237
+ super().__init__(heads)
238
+ los_slopes = torch.log(self.slopes)
239
+ self.learned_logslopes = nn.Parameter(los_slopes)
240
+
241
+ self.bidirectional = bidirectional
242
+ if self.bidirectional:
243
+ self.learned_logslopes_future = nn.Parameter(los_slopes)
244
+
245
+ def forward(self, qk_dots):
246
+ h, i, j, device = *qk_dots.shape[-3:], qk_dots.device
247
+
248
+ def get_slopes(param):
249
+ return F.pad(param.exp(), (0, 0, 0, 0, 0, h - param.shape[1]))
250
+
251
+ if exists(self.bias) and self.bias.shape[-1] >= j:
252
+ bias = self.bias[..., :i, :j]
253
+ else:
254
+ i_arange = torch.arange(i, device=device)
255
+ j_arange = torch.arange(j, device=device)
256
+ bias = rearrange(j_arange, 'j -> 1 1 1 j') - rearrange(i_arange, 'i -> 1 1 i 1')
257
+ self.register_buffer('bias', bias, persistent=False)
258
+
259
+ if self.bidirectional:
260
+ past_slopes = get_slopes(self.learned_logslopes)
261
+ future_slopes = get_slopes(self.learned_logslopes_future)
262
+ bias = torch.tril(bias * past_slopes) + torch.triu(bias * future_slopes)
263
+ else:
264
+ slopes = get_slopes(self.learned_logslopes)
265
+ bias = bias * slopes
266
+
267
+ return qk_dots + bias
268
+
269
+
270
+ class RotaryEmbedding(nn.Module):
271
+ def __init__(self, dim):
272
+ super().__init__()
273
+ inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
274
+ self.register_buffer('inv_freq', inv_freq)
275
+
276
+ def forward(self, max_seq_len, device):
277
+ t = torch.arange(max_seq_len, device=device).type_as(self.inv_freq)
278
+ freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
279
+ emb = torch.cat((freqs, freqs), dim=-1)
280
+ return rearrange(emb, 'n d -> () () n d')
281
+
282
+
283
+ def rotate_half(x):
284
+ x = rearrange(x, '... (j d) -> ... j d', j=2)
285
+ x1, x2 = x.unbind(dim=-2)
286
+ return torch.cat((-x2, x1), dim=-1)
287
+
288
+
289
+ def apply_rotary_pos_emb(t, freqs):
290
+ seq_len = t.shape[-2]
291
+ freqs = freqs[:, :, -seq_len:]
292
+ return (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
293
+
294
+
295
+ # norms
296
+
297
+ class Scale(nn.Module):
298
+ def __init__(self, value, fn):
299
+ super().__init__()
300
+ self.value = value
301
+ self.fn = fn
302
+
303
+ def forward(self, x, **kwargs):
304
+ out = self.fn(x, **kwargs)
305
+ scale_fn = lambda t: t * self.value
306
+
307
+ if not isinstance(out, tuple):
308
+ return scale_fn(out)
309
+
310
+ return (scale_fn(out[0]), *out[1:])
311
+
312
+
313
+ class Rezero(nn.Module):
314
+ def __init__(self, fn):
315
+ super().__init__()
316
+ self.fn = fn
317
+ self.g = nn.Parameter(torch.zeros(1))
318
+
319
+ def forward(self, x, **kwargs):
320
+ out = self.fn(x, **kwargs)
321
+ rezero_fn = lambda t: t * self.g
322
+
323
+ if not isinstance(out, tuple):
324
+ return rezero_fn(out)
325
+
326
+ return (rezero_fn(out[0]), *out[1:])
327
+
328
+
329
+ class ScaleNorm(nn.Module):
330
+ def __init__(self, dim, eps=1e-5):
331
+ super().__init__()
332
+ self.scale = dim ** -0.5
333
+ self.eps = eps
334
+ self.g = nn.Parameter(torch.ones(1))
335
+
336
+ def forward(self, x):
337
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
338
+ return x / norm.clamp(min=self.eps) * self.g
339
+
340
+
341
+ class RMSNorm(nn.Module):
342
+ def __init__(self, dim, eps=1e-8):
343
+ super().__init__()
344
+ self.scale = dim ** -0.5
345
+ self.eps = eps
346
+ self.g = nn.Parameter(torch.ones(dim))
347
+
348
+ def forward(self, x):
349
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
350
+ return x / norm.clamp(min=self.eps) * self.g
351
+
352
+
353
+ class RMSScaleShiftNorm(nn.Module):
354
+ def __init__(self, dim, eps=1e-8):
355
+ super().__init__()
356
+ self.scale = dim ** -0.5
357
+ self.eps = eps
358
+ self.g = nn.Parameter(torch.ones(dim))
359
+ self.scale_shift_process = nn.Linear(dim * 2, dim * 2)
360
+
361
+ def forward(self, x, norm_scale_shift_inp):
362
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
363
+ norm = x / norm.clamp(min=self.eps) * self.g
364
+
365
+ ss_emb = self.scale_shift_process(norm_scale_shift_inp)
366
+ scale, shift = torch.chunk(ss_emb, 2, dim=1)
367
+ h = norm * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
368
+ return h
369
+
370
+
371
+ # residual and residual gates
372
+
373
+ class Residual(nn.Module):
374
+ def __init__(self, dim, scale_residual=False):
375
+ super().__init__()
376
+ self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
377
+
378
+ def forward(self, x, residual):
379
+ if exists(self.residual_scale):
380
+ residual = residual * self.residual_scale
381
+
382
+ return x + residual
383
+
384
+
385
+ class GRUGating(nn.Module):
386
+ def __init__(self, dim, scale_residual=False):
387
+ super().__init__()
388
+ self.gru = nn.GRUCell(dim, dim)
389
+ self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
390
+
391
+ def forward(self, x, residual):
392
+ if exists(self.residual_scale):
393
+ residual = residual * self.residual_scale
394
+
395
+ gated_output = self.gru(
396
+ rearrange(x, 'b n d -> (b n) d'),
397
+ rearrange(residual, 'b n d -> (b n) d')
398
+ )
399
+
400
+ return gated_output.reshape_as(x)
401
+
402
+
403
+ # token shifting
404
+
405
+ def shift(t, amount, mask=None):
406
+ if amount == 0:
407
+ return t
408
+
409
+ if exists(mask):
410
+ t = t.masked_fill(~mask[..., None], 0.)
411
+
412
+ return F.pad(t, (0, 0, amount, -amount), value=0.)
413
+
414
+
415
+ class ShiftTokens(nn.Module):
416
+ def __init__(self, shifts, fn):
417
+ super().__init__()
418
+ self.fn = fn
419
+ self.shifts = tuple(shifts)
420
+
421
+ def forward(self, x, **kwargs):
422
+ mask = kwargs.get('mask', None)
423
+ shifts = self.shifts
424
+ segments = len(shifts)
425
+ feats_per_shift = x.shape[-1] // segments
426
+ splitted = x.split(feats_per_shift, dim=-1)
427
+ segments_to_shift, rest = splitted[:segments], splitted[segments:]
428
+ segments_to_shift = list(map(lambda args: shift(*args, mask=mask), zip(segments_to_shift, shifts)))
429
+ x = torch.cat((*segments_to_shift, *rest), dim=-1)
430
+ return self.fn(x, **kwargs)
431
+
432
+
433
+ # feedforward
434
+
435
+ class GLU(nn.Module):
436
+ def __init__(self, dim_in, dim_out, activation):
437
+ super().__init__()
438
+ self.act = activation
439
+ self.proj = nn.Linear(dim_in, dim_out * 2)
440
+
441
+ def forward(self, x):
442
+ x, gate = self.proj(x).chunk(2, dim=-1)
443
+ return x * self.act(gate)
444
+
445
+
446
+ class FeedForward(nn.Module):
447
+ def __init__(
448
+ self,
449
+ dim,
450
+ dim_out=None,
451
+ mult=4,
452
+ glu=False,
453
+ relu_squared=False,
454
+ post_act_ln=False,
455
+ dropout=0.,
456
+ zero_init_output=False
457
+ ):
458
+ super().__init__()
459
+ inner_dim = int(dim * mult)
460
+ dim_out = default(dim_out, dim)
461
+ activation = ReluSquared() if relu_squared else nn.GELU()
462
+
463
+ project_in = nn.Sequential(
464
+ nn.Linear(dim, inner_dim),
465
+ activation
466
+ ) if not glu else GLU(dim, inner_dim, activation)
467
+
468
+ self.net = nn.Sequential(
469
+ project_in,
470
+ nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(),
471
+ nn.Dropout(dropout),
472
+ nn.Linear(inner_dim, dim_out)
473
+ )
474
+
475
+ # init last linear layer to 0
476
+ if zero_init_output:
477
+ init_zero_(self.net[-1])
478
+
479
+ def forward(self, x):
480
+ return self.net(x)
481
+
482
+
483
+ # attention.
484
+
485
+ class Attention(nn.Module):
486
+ def __init__(
487
+ self,
488
+ dim,
489
+ dim_head=DEFAULT_DIM_HEAD,
490
+ heads=8,
491
+ causal=False,
492
+ talking_heads=False,
493
+ head_scale=False,
494
+ collab_heads=False,
495
+ collab_compression=.3,
496
+ sparse_topk=None,
497
+ use_entmax15=False,
498
+ num_mem_kv=0,
499
+ dropout=0.,
500
+ on_attn=False,
501
+ gate_values=False,
502
+ zero_init_output=False,
503
+ max_attend_past=None,
504
+ qk_norm=False,
505
+ scale_init_value=None,
506
+ rel_pos_bias=False,
507
+ rel_pos_num_buckets=32,
508
+ rel_pos_max_distance=128,
509
+ ):
510
+ super().__init__()
511
+ self.scale = dim_head ** -0.5
512
+
513
+ self.heads = heads
514
+ self.causal = causal
515
+ self.max_attend_past = max_attend_past
516
+
517
+ qk_dim = v_dim = dim_head * heads
518
+
519
+ # collaborative heads
520
+ self.collab_heads = collab_heads
521
+ if self.collab_heads:
522
+ qk_dim = int(collab_compression * qk_dim)
523
+ self.collab_mixing = nn.Parameter(torch.randn(heads, qk_dim))
524
+
525
+ self.to_q = nn.Linear(dim, qk_dim, bias=False)
526
+ self.to_k = nn.Linear(dim, qk_dim, bias=False)
527
+ self.to_v = nn.Linear(dim, v_dim, bias=False)
528
+
529
+ self.dropout = nn.Dropout(dropout)
530
+
531
+ # add GLU gating for aggregated values, from alphafold2
532
+ self.to_v_gate = None
533
+ if gate_values:
534
+ self.to_v_gate = nn.Linear(dim, v_dim)
535
+ nn.init.constant_(self.to_v_gate.weight, 0)
536
+ nn.init.constant_(self.to_v_gate.bias, 1)
537
+
538
+ # cosine sim attention
539
+ self.qk_norm = qk_norm
540
+ if qk_norm:
541
+ scale_init_value = default(scale_init_value,
542
+ -3) # if not provided, initialize as though it were sequence length of 1024
543
+ self.scale = nn.Parameter(torch.ones(1, heads, 1, 1) * scale_init_value)
544
+
545
+ # talking heads
546
+ self.talking_heads = talking_heads
547
+ if talking_heads:
548
+ self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
549
+ self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))
550
+
551
+ # head scaling
552
+ self.head_scale = head_scale
553
+ if head_scale:
554
+ self.head_scale_params = nn.Parameter(torch.ones(1, heads, 1, 1))
555
+
556
+ # explicit topk sparse attention
557
+ self.sparse_topk = sparse_topk
558
+
559
+ # entmax
560
+ self.attn_fn = entmax15 if use_entmax15 else F.softmax
561
+
562
+ # add memory key / values
563
+ self.num_mem_kv = num_mem_kv
564
+ if num_mem_kv > 0:
565
+ self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
566
+ self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
567
+
568
+ # attention on attention
569
+ self.attn_on_attn = on_attn
570
+ self.to_out = nn.Sequential(nn.Linear(v_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(v_dim, dim)
571
+
572
+ self.rel_pos_bias = rel_pos_bias
573
+ if rel_pos_bias:
574
+ assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
575
+ self.rel_pos = RelativePositionBias(scale=dim_head ** 0.5, causal=causal, heads=heads,
576
+ num_buckets=rel_pos_num_buckets, max_distance=rel_pos_max_distance)
577
+
578
+ # init output projection 0
579
+ if zero_init_output:
580
+ init_zero_(self.to_out)
581
+
582
+ def forward(
583
+ self,
584
+ x,
585
+ context=None,
586
+ mask=None,
587
+ context_mask=None,
588
+ attn_mask=None,
589
+ sinusoidal_emb=None,
590
+ rotary_pos_emb=None,
591
+ prev_attn=None,
592
+ mem=None
593
+ ):
594
+ b, n, _, h, talking_heads, collab_heads, head_scale, scale, device, has_context = *x.shape, self.heads, self.talking_heads, self.collab_heads, self.head_scale, self.scale, x.device, exists(
595
+ context)
596
+ kv_input = default(context, x)
597
+
598
+ q_input = x
599
+ k_input = kv_input
600
+ v_input = kv_input
601
+
602
+ if exists(mem):
603
+ k_input = torch.cat((mem, k_input), dim=-2)
604
+ v_input = torch.cat((mem, v_input), dim=-2)
605
+
606
+ if exists(sinusoidal_emb):
607
+ # in shortformer, the query would start at a position offset depending on the past cached memory
608
+ offset = k_input.shape[-2] - q_input.shape[-2]
609
+ q_input = q_input + sinusoidal_emb(q_input, offset=offset)
610
+ k_input = k_input + sinusoidal_emb(k_input)
611
+
612
+ q = self.to_q(q_input)
613
+ k = self.to_k(k_input)
614
+ v = self.to_v(v_input)
615
+
616
+ if not collab_heads:
617
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
618
+ else:
619
+ q = einsum('b i d, h d -> b h i d', q, self.collab_mixing)
620
+ k = rearrange(k, 'b n d -> b () n d')
621
+ v = rearrange(v, 'b n (h d) -> b h n d', h=h)
622
+
623
+ if exists(rotary_pos_emb) and not has_context:
624
+ l = rotary_pos_emb.shape[-1]
625
+ (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v))
626
+ ql, kl, vl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl))
627
+ q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr)))
628
+
629
+ input_mask = None
630
+ if any(map(exists, (mask, context_mask))):
631
+ q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
632
+ k_mask = q_mask if not exists(context) else context_mask
633
+ k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
634
+ q_mask = rearrange(q_mask, 'b i -> b () i ()')
635
+ k_mask = rearrange(k_mask, 'b j -> b () () j')
636
+ input_mask = q_mask * k_mask
637
+
638
+ if self.num_mem_kv > 0:
639
+ mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
640
+ k = torch.cat((mem_k, k), dim=-2)
641
+ v = torch.cat((mem_v, v), dim=-2)
642
+ if exists(input_mask):
643
+ input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)
644
+
645
+ if collab_heads:
646
+ k = k.expand(-1, h, -1, -1)
647
+
648
+ if self.qk_norm:
649
+ q, k = map(l2norm, (q, k))
650
+ scale = 1 / (self.scale.exp().clamp(min=1e-2))
651
+
652
+ dots = einsum('b h i d, b h j d -> b h i j', q, k) * scale
653
+ mask_value = max_neg_value(dots)
654
+
655
+ if exists(prev_attn):
656
+ dots = dots + prev_attn
657
+
658
+ pre_softmax_attn = dots.clone()
659
+
660
+ if talking_heads:
661
+ dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()
662
+
663
+ if self.rel_pos_bias:
664
+ dots = self.rel_pos(dots)
665
+
666
+ if exists(input_mask):
667
+ dots.masked_fill_(~input_mask, mask_value)
668
+ del input_mask
669
+
670
+ if exists(attn_mask):
671
+ assert 2 <= attn_mask.ndim <= 4, 'attention mask must have greater than 2 dimensions but less than or equal to 4'
672
+ if attn_mask.ndim == 2:
673
+ attn_mask = rearrange(attn_mask, 'i j -> () () i j')
674
+ elif attn_mask.ndim == 3:
675
+ attn_mask = rearrange(attn_mask, 'h i j -> () h i j')
676
+ dots.masked_fill_(~attn_mask, mask_value)
677
+
678
+ if exists(self.max_attend_past):
679
+ i, j = dots.shape[-2:]
680
+ range_q = torch.arange(j - i, j, device=device)
681
+ range_k = torch.arange(j, device=device)
682
+ dist = rearrange(range_q, 'i -> () () i ()') - rearrange(range_k, 'j -> () () () j')
683
+ mask = dist > self.max_attend_past
684
+ dots.masked_fill_(mask, mask_value)
685
+ del mask
686
+
687
+ if self.causal:
688
+ i, j = dots.shape[-2:]
689
+ r = torch.arange(i, device=device)
690
+ mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
691
+ mask = F.pad(mask, (j - i, 0), value=False)
692
+ dots.masked_fill_(mask, mask_value)
693
+ del mask
694
+
695
+ if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
696
+ top, _ = dots.topk(self.sparse_topk, dim=-1)
697
+ vk = top[..., -1].unsqueeze(-1).expand_as(dots)
698
+ mask = dots < vk
699
+ dots.masked_fill_(mask, mask_value)
700
+ del mask
701
+
702
+ attn = self.attn_fn(dots, dim=-1)
703
+ post_softmax_attn = attn.clone()
704
+
705
+ attn = self.dropout(attn)
706
+
707
+ if talking_heads:
708
+ attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()
709
+
710
+ out = einsum('b h i j, b h j d -> b h i d', attn, v)
711
+
712
+ if head_scale:
713
+ out = out * self.head_scale_params
714
+
715
+ out = rearrange(out, 'b h n d -> b n (h d)')
716
+
717
+ if exists(self.to_v_gate):
718
+ gates = self.to_v_gate(x)
719
+ out = out * gates.sigmoid()
720
+
721
+ intermediates = Intermediates(
722
+ pre_softmax_attn=pre_softmax_attn,
723
+ post_softmax_attn=post_softmax_attn
724
+ )
725
+
726
+ return self.to_out(out), intermediates
727
+
728
+
729
+ class AttentionLayers(nn.Module):
730
+ def __init__(
731
+ self,
732
+ dim,
733
+ depth,
734
+ heads=8,
735
+ causal=False,
736
+ cross_attend=False,
737
+ only_cross=False,
738
+ use_scalenorm=False,
739
+ use_rms_scaleshift_norm=False,
740
+ use_rmsnorm=False,
741
+ use_rezero=False,
742
+ alibi_pos_bias=False,
743
+ alibi_num_heads=None,
744
+ alibi_learned=False,
745
+ position_infused_attn=False,
746
+ rotary_pos_emb=False,
747
+ rotary_emb_dim=None,
748
+ custom_layers=None,
749
+ sandwich_coef=None,
750
+ par_ratio=None,
751
+ residual_attn=False,
752
+ cross_residual_attn=False,
753
+ macaron=False,
754
+ pre_norm=True,
755
+ gate_residual=False,
756
+ scale_residual=False,
757
+ shift_tokens=0,
758
+ sandwich_norm=False,
759
+ use_qk_norm_attn=False,
760
+ qk_norm_attn_seq_len=None,
761
+ zero_init_branch_output=False,
762
+ **kwargs
763
+ ):
764
+ super().__init__()
765
+ ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
766
+ attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)
767
+
768
+ dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
769
+
770
+ self.dim = dim
771
+ self.depth = depth
772
+ self.layers = nn.ModuleList([])
773
+
774
+ rel_pos_bias = 'rel_pos_bias' in attn_kwargs
775
+ self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb
776
+ self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
777
+
778
+ rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32)
779
+ self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim) if rotary_pos_emb else None
780
+
781
+ assert not (
782
+ alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
783
+
784
+ if alibi_pos_bias:
785
+ alibi_num_heads = default(alibi_num_heads, heads)
786
+ assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
787
+ alibi_pos_klass = LearnedAlibiPositionalBias if alibi_learned or not causal else AlibiPositionalBias
788
+ self.rel_pos = alibi_pos_klass(heads=alibi_num_heads, bidirectional=not causal)
789
+ else:
790
+ self.rel_pos = None
791
+
792
+ assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'
793
+ self.pre_norm = pre_norm
794
+ self.sandwich_norm = sandwich_norm
795
+
796
+ self.residual_attn = residual_attn
797
+ self.cross_residual_attn = cross_residual_attn
798
+ self.cross_attend = cross_attend
799
+
800
+ norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
801
+ norm_class = RMSNorm if use_rmsnorm else norm_class
802
+ norm_class = RMSScaleShiftNorm if use_rms_scaleshift_norm else norm_class
803
+ norm_fn = partial(norm_class, dim)
804
+
805
+ norm_fn = nn.Identity if use_rezero else norm_fn
806
+ branch_fn = Rezero if use_rezero else None
807
+
808
+ if cross_attend and not only_cross:
809
+ default_block = ('a', 'c', 'f')
810
+ elif cross_attend and only_cross:
811
+ default_block = ('c', 'f')
812
+ else:
813
+ default_block = ('a', 'f')
814
+
815
+ if macaron:
816
+ default_block = ('f',) + default_block
817
+
818
+ # qk normalization
819
+
820
+ if use_qk_norm_attn:
821
+ attn_scale_init_value = -math.log(math.log2(qk_norm_attn_seq_len ** 2 - qk_norm_attn_seq_len)) if exists(
822
+ qk_norm_attn_seq_len) else None
823
+ attn_kwargs = {**attn_kwargs, 'qk_norm': True, 'scale_init_value': attn_scale_init_value}
824
+
825
+ # zero init
826
+
827
+ if zero_init_branch_output:
828
+ attn_kwargs = {**attn_kwargs, 'zero_init_output': True}
829
+ ff_kwargs = {**ff_kwargs, 'zero_init_output': True}
830
+
831
+ # calculate layer block order
832
+
833
+ if exists(custom_layers):
834
+ layer_types = custom_layers
835
+ elif exists(par_ratio):
836
+ par_depth = depth * len(default_block)
837
+ assert 1 < par_ratio <= par_depth, 'par ratio out of range'
838
+ default_block = tuple(filter(not_equals('f'), default_block))
839
+ par_attn = par_depth // par_ratio
840
+ depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
841
+ par_width = (depth_cut + depth_cut // par_attn) // par_attn
842
+ assert len(default_block) <= par_width, 'default block is too large for par_ratio'
843
+ par_block = default_block + ('f',) * (par_width - len(default_block))
844
+ par_head = par_block * par_attn
845
+ layer_types = par_head + ('f',) * (par_depth - len(par_head))
846
+ elif exists(sandwich_coef):
847
+ assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
848
+ layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
849
+ else:
850
+ layer_types = default_block * depth
851
+
852
+ self.layer_types = layer_types
853
+ self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
854
+
855
+ # calculate token shifting
856
+
857
+ shift_tokens = cast_tuple(shift_tokens, len(layer_types))
858
+
859
+ # iterate and construct layers
860
+
861
+ for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
862
+ is_last_layer = ind == (len(self.layer_types) - 1)
863
+
864
+ if layer_type == 'a':
865
+ layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
866
+ elif layer_type == 'c':
867
+ layer = Attention(dim, heads=heads, **attn_kwargs)
868
+ elif layer_type == 'f':
869
+ layer = FeedForward(dim, **ff_kwargs)
870
+ layer = layer if not macaron else Scale(0.5, layer)
871
+ else:
872
+ raise Exception(f'invalid layer type {layer_type}')
873
+
874
+ if layer_shift_tokens > 0:
875
+ shift_range_upper = layer_shift_tokens + 1
876
+ shift_range_lower = -layer_shift_tokens if not causal else 0
877
+ layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer)
878
+
879
+ if exists(branch_fn):
880
+ layer = branch_fn(layer)
881
+
882
+ residual_fn = GRUGating if gate_residual else Residual
883
+ residual = residual_fn(dim, scale_residual=scale_residual)
884
+
885
+ layer_uses_qk_norm = use_qk_norm_attn and layer_type in ('a', 'c')
886
+
887
+ pre_branch_norm = norm_fn() if pre_norm and not layer_uses_qk_norm else None
888
+ post_branch_norm = norm_fn() if sandwich_norm or layer_uses_qk_norm else None
889
+ post_main_norm = norm_fn() if not pre_norm and not is_last_layer else None
890
+
891
+ norms = nn.ModuleList([
892
+ pre_branch_norm,
893
+ post_branch_norm,
894
+ post_main_norm
895
+ ])
896
+
897
+ self.layers.append(nn.ModuleList([
898
+ norms,
899
+ layer,
900
+ residual
901
+ ]))
902
+
903
+ def forward(
904
+ self,
905
+ x,
906
+ context=None,
907
+ full_context=None, # for passing a list of hidden states from an encoder
908
+ mask=None,
909
+ context_mask=None,
910
+ attn_mask=None,
911
+ mems=None,
912
+ return_hiddens=False,
913
+ norm_scale_shift_inp=None,
914
+ ):
915
+
916
+ assert not (self.cross_attend ^ (exists(context) or exists(
917
+ full_context))), 'context must be passed in if cross_attend is set to True'
918
+ assert context is None or full_context is None, 'only one of full_context or context can be provided'
919
+
920
+ hiddens = []
921
+ intermediates = []
922
+ prev_attn = None
923
+ prev_cross_attn = None
924
+
925
+ mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
926
+ norm_args = {}
927
+ if exists(norm_scale_shift_inp):
928
+ norm_args['norm_scale_shift_inp'] = norm_scale_shift_inp
929
+
930
+ rotary_pos_emb = None
931
+ if exists(self.rotary_pos_emb):
932
+ max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems)))
933
+ rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)
934
+
935
+ cross_attn_count = 0
936
+ for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
937
+ if layer_type == 'a':
938
+ layer_mem = mems.pop(0) if mems else None
939
+
940
+ residual = x
941
+
942
+ pre_branch_norm, post_branch_norm, post_main_norm = norm
943
+
944
+ if exists(pre_branch_norm):
945
+ x = pre_branch_norm(x, **norm_args)
946
+
947
+ if layer_type == 'a':
948
+ out, inter = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
949
+ prev_attn, layer_mem)
950
+ elif layer_type == 'c':
951
+ if exists(full_context):
952
+ out, inter = checkpoint(block, x, full_context[cross_attn_count], mask, context_mask, None, None,
953
+ None, prev_attn)
954
+ else:
955
+ out, inter = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn)
956
+ elif layer_type == 'f':
957
+ out = checkpoint(block, x)
958
+
959
+ if exists(post_branch_norm):
960
+ out = post_branch_norm(out, **norm_args)
961
+
962
+ x = residual_fn(out, residual)
963
+
964
+ if layer_type in ('a', 'c'):
965
+ intermediates.append(inter)
966
+
967
+ if layer_type == 'a' and self.residual_attn:
968
+ prev_attn = inter.pre_softmax_attn
969
+ elif layer_type == 'c' and self.cross_residual_attn:
970
+ prev_cross_attn = inter.pre_softmax_attn
971
+
972
+ if exists(post_main_norm):
973
+ x = post_main_norm(x, **norm_args)
974
+
975
+ if layer_type == 'c':
976
+ cross_attn_count += 1
977
+
978
+ if layer_type == 'f':
979
+ hiddens.append(x)
980
+
981
+ if return_hiddens:
982
+ intermediates = LayerIntermediates(
983
+ hiddens=hiddens,
984
+ attn_intermediates=intermediates
985
+ )
986
+
987
+ return x, intermediates
988
+
989
+ return x
990
+
991
+
992
+ class Encoder(AttentionLayers):
993
+ def __init__(self, **kwargs):
994
+ assert 'causal' not in kwargs, 'cannot set causality on encoder'
995
+ super().__init__(causal=False, **kwargs)
996
+
997
+
998
+ class Decoder(AttentionLayers):
999
+ def __init__(self, **kwargs):
1000
+ assert 'causal' not in kwargs, 'cannot set causality on decoder'
1001
+ super().__init__(causal=True, **kwargs)
1002
+
1003
+
1004
+ class CrossAttender(AttentionLayers):
1005
+ def __init__(self, **kwargs):
1006
+ super().__init__(cross_attend=True, only_cross=True, **kwargs)
1007
+
1008
+
1009
+ class ViTransformerWrapper(nn.Module):
1010
+ def __init__(
1011
+ self,
1012
+ *,
1013
+ image_size,
1014
+ patch_size,
1015
+ attn_layers,
1016
+ num_classes=None,
1017
+ dropout=0.,
1018
+ emb_dropout=0.
1019
+ ):
1020
+ super().__init__()
1021
+ assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder'
1022
+ assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
1023
+ dim = attn_layers.dim
1024
+ num_patches = (image_size // patch_size) ** 2
1025
+ patch_dim = 3 * patch_size ** 2
1026
+
1027
+ self.patch_size = patch_size
1028
+
1029
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
1030
+ self.patch_to_embedding = nn.Linear(patch_dim, dim)
1031
+ self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
1032
+ self.dropout = nn.Dropout(emb_dropout)
1033
+
1034
+ self.attn_layers = attn_layers
1035
+ self.norm = nn.LayerNorm(dim)
1036
+ self.mlp_head = FeedForward(dim, dim_out=num_classes, dropout=dropout) if exists(num_classes) else None
1037
+
1038
+ def forward(
1039
+ self,
1040
+ img,
1041
+ return_embeddings=False
1042
+ ):
1043
+ p = self.patch_size
1044
+
1045
+ x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
1046
+ x = self.patch_to_embedding(x)
1047
+ b, n, _ = x.shape
1048
+
1049
+ cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
1050
+ x = torch.cat((cls_tokens, x), dim=1)
1051
+ x = x + self.pos_embedding[:, :(n + 1)]
1052
+ x = self.dropout(x)
1053
+
1054
+ x = self.attn_layers(x)
1055
+ x = self.norm(x)
1056
+
1057
+ if not exists(self.mlp_head) or return_embeddings:
1058
+ return x
1059
+
1060
+ return self.mlp_head(x[:, 0])
1061
+
1062
+
1063
+ class TransformerWrapper(nn.Module):
1064
+ def __init__(
1065
+ self,
1066
+ *,
1067
+ num_tokens,
1068
+ max_seq_len,
1069
+ attn_layers,
1070
+ emb_dim=None,
1071
+ max_mem_len=0.,
1072
+ shift_mem_down=0,
1073
+ emb_dropout=0.,
1074
+ num_memory_tokens=None,
1075
+ tie_embedding=False,
1076
+ use_pos_emb=True
1077
+ ):
1078
+ super().__init__()
1079
+ assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
1080
+
1081
+ dim = attn_layers.dim
1082
+ emb_dim = default(emb_dim, dim)
1083
+
1084
+ self.max_seq_len = max_seq_len
1085
+ self.max_mem_len = max_mem_len
1086
+ self.shift_mem_down = shift_mem_down
1087
+
1088
+ self.token_emb = nn.Embedding(num_tokens, emb_dim)
1089
+ self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
1090
+ use_pos_emb and not attn_layers.has_pos_emb) else always(0)
1091
+ self.emb_dropout = nn.Dropout(emb_dropout)
1092
+
1093
+ self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
1094
+ self.attn_layers = attn_layers
1095
+ self.norm = nn.LayerNorm(dim)
1096
+
1097
+ self.init_()
1098
+
1099
+ self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
1100
+
1101
+ # memory tokens (like [cls]) from Memory Transformers paper
1102
+ num_memory_tokens = default(num_memory_tokens, 0)
1103
+ self.num_memory_tokens = num_memory_tokens
1104
+ if num_memory_tokens > 0:
1105
+ self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
1106
+
1107
+ def init_(self):
1108
+ nn.init.kaiming_normal_(self.token_emb.weight)
1109
+
1110
+ def forward(
1111
+ self,
1112
+ x,
1113
+ return_embeddings=False,
1114
+ mask=None,
1115
+ return_hiddens=False,
1116
+ return_attn=False,
1117
+ mems=None,
1118
+ **kwargs
1119
+ ):
1120
+ b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
1121
+ x = self.token_emb(x)
1122
+ x = x + self.pos_emb(x)
1123
+ x = self.emb_dropout(x)
1124
+
1125
+ x = self.project_emb(x)
1126
+
1127
+ if num_mem > 0:
1128
+ mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)
1129
+ x = torch.cat((mem, x), dim=1)
1130
+
1131
+ # auto-handle masking after appending memory tokens
1132
+ if exists(mask):
1133
+ mask = F.pad(mask, (num_mem, 0), value=True)
1134
+
1135
+ if self.shift_mem_down and exists(mems):
1136
+ mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:]
1137
+ mems = [*mems_r, *mems_l]
1138
+
1139
+ x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
1140
+ x = self.norm(x)
1141
+
1142
+ mem, x = x[:, :num_mem], x[:, num_mem:]
1143
+
1144
+ out = self.to_logits(x) if not return_embeddings else x
1145
+
1146
+ if return_hiddens:
1147
+ hiddens = intermediates.hiddens
1148
+ return out, hiddens
1149
+
1150
+ if return_attn:
1151
+ attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
1152
+ return out, attn_maps
1153
+
1154
+ return out
1155
+
1156
+
1157
+ class ContinuousTransformerWrapper(nn.Module):
1158
+ def __init__(
1159
+ self,
1160
+ *,
1161
+ max_seq_len,
1162
+ attn_layers,
1163
+ dim_in=None,
1164
+ dim_out=None,
1165
+ emb_dim=None,
1166
+ emb_dropout=0.,
1167
+ use_pos_emb=True
1168
+ ):
1169
+ super().__init__()
1170
+ assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
1171
+
1172
+ dim = attn_layers.dim
1173
+
1174
+ self.max_seq_len = max_seq_len
1175
+
1176
+ self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len) if (
1177
+ use_pos_emb and not attn_layers.has_pos_emb) else always(0)
1178
+ self.emb_dropout = nn.Dropout(emb_dropout)
1179
+
1180
+ self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
1181
+
1182
+ self.attn_layers = attn_layers
1183
+ self.norm = nn.LayerNorm(dim)
1184
+
1185
+ self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
1186
+
1187
+ def forward(
1188
+ self,
1189
+ x,
1190
+ return_embeddings=False,
1191
+ mask=None,
1192
+ return_attn=False,
1193
+ mems=None,
1194
+ **kwargs
1195
+ ):
1196
+ b, n, _, device = *x.shape, x.device
1197
+
1198
+ x = self.project_in(x)
1199
+ x = x + self.pos_emb(x)
1200
+ x = self.emb_dropout(x)
1201
+
1202
+ x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
1203
+ x = self.norm(x)
1204
+
1205
+ out = self.project_out(x) if not return_embeddings else x
1206
+
1207
+ if return_attn:
1208
+ attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
1209
+ return out, attn_maps
1210
+
1211
+ return out
1212
+
1213
+
1214
+ class XTransformer(nn.Module):
1215
+ def __init__(
1216
+ self,
1217
+ *,
1218
+ dim,
1219
+ tie_token_emb=False,
1220
+ **kwargs
1221
+ ):
1222
+ super().__init__()
1223
+ enc_kwargs, kwargs = groupby_prefix_and_trim('enc_', kwargs)
1224
+ dec_kwargs, kwargs = groupby_prefix_and_trim('dec_', kwargs)
1225
+
1226
+ assert 'dim' not in enc_kwargs and 'dim' not in dec_kwargs, 'dimension of either encoder or decoder must be set with `dim` keyword'
1227
+ enc_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], enc_kwargs)
1228
+ enc_transformer_kwargs['emb_dropout'] = enc_kwargs.pop('emb_dropout', 0)
1229
+ enc_transformer_kwargs['num_memory_tokens'] = enc_kwargs.pop('num_memory_tokens', None)
1230
+ enc_transformer_kwargs['use_pos_emb'] = enc_kwargs.pop('use_pos_emb', True)
1231
+
1232
+ dec_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], dec_kwargs)
1233
+ dec_transformer_kwargs['emb_dropout'] = dec_kwargs.pop('emb_dropout', 0)
1234
+ dec_transformer_kwargs['use_pos_emb'] = dec_kwargs.pop('use_pos_emb', True)
1235
+
1236
+ self.encoder = TransformerWrapper(
1237
+ **enc_transformer_kwargs,
1238
+ attn_layers=Encoder(dim=dim, **enc_kwargs)
1239
+ )
1240
+
1241
+ self.decoder = TransformerWrapper(
1242
+ **dec_transformer_kwargs,
1243
+ attn_layers=Decoder(dim=dim, cross_attend=True, **dec_kwargs)
1244
+ )
1245
+
1246
+ if tie_token_emb:
1247
+ self.decoder.token_emb = self.encoder.token_emb
1248
+
1249
+ self.decoder = AutoregressiveWrapper(self.decoder)
1250
+
1251
+ @torch.no_grad()
1252
+ def generate(self, seq_in, seq_out_start, seq_len, src_mask=None, src_attn_mask=None, **kwargs):
1253
+ encodings = self.encoder(seq_in, mask=src_mask, attn_mask=src_attn_mask, return_embeddings=True)
1254
+ return self.decoder.generate(seq_out_start, seq_len, context=encodings, context_mask=src_mask, **kwargs)
1255
+
1256
+ def forward(self, src, tgt, src_mask=None, tgt_mask=None, src_attn_mask=None):
1257
+ enc = self.encoder(src, mask=src_mask, attn_mask=src_attn_mask, return_embeddings=True)
1258
+ out = self.decoder(tgt, context=enc, mask=tgt_mask, context_mask=src_mask)
1259
+ return out