asigalov61 commited on
Commit
1744970
1 Parent(s): 1114cba

Delete x_transformer_1_23_2.py

Browse files
Files changed (1) hide show
  1. x_transformer_1_23_2.py +0 -2464
x_transformer_1_23_2.py DELETED
@@ -1,2464 +0,0 @@
1
- #===================================================================================================================
2
- #
3
- # X Transformer Module
4
- #
5
- # Partial x-transformers code with useful modifications
6
- #
7
- # Version 1.0
8
- #
9
- # Original source code courtesy of lucidrains
10
- # https://github.com/lucidrains/x-transformers
11
- #
12
- # Original source code retrieved on 10/10/2023
13
- #
14
- # Project Los Angeles
15
- # Tegridy Code 2023
16
-
17
- #===================================================================================================================
18
-
19
- # Critical dependencies
20
- #
21
- # !pip install torch
22
- # !pip install einops
23
-
24
- #===================================================================================================================
25
-
26
- from functools import partial
27
- from typing import Optional, Tuple
28
-
29
- import torch
30
- from torch import nn, einsum, Tensor
31
- import torch.nn.functional as F
32
- # from torch.nn.attention import SDPBackend, sdpa_kernel
33
-
34
- from collections import namedtuple
35
- from functools import wraps
36
- from packaging import version
37
- from dataclasses import dataclass
38
-
39
- from einops import rearrange, repeat
40
-
41
- # constants
42
-
43
# Three boolean flags naming which scaled-dot-product-attention kernels are enabled.
EfficientAttentionConfig = namedtuple(
    'EfficientAttentionConfig',
    ['enable_flash', 'enable_math', 'enable_mem_efficient'],
)


@dataclass
class Intermediates:
    """Optional values captured during a single attention call.

    Every field defaults to ``None``; the attention code fills in the ones
    it actually computed.
    """

    qk_similarities: Optional[Tensor] = None
    pre_softmax_attn: Optional[Tensor] = None
    post_softmax_attn: Optional[Tensor] = None
    cached_kv: Optional[Tuple[Tensor, Tensor]] = None

    def to_tuple(self):
        """Return the three attention maps as a tuple (``cached_kv`` is left out)."""
        attention_maps = (
            self.qk_similarities,
            self.pre_softmax_attn,
            self.post_softmax_attn,
        )
        return attention_maps
54
-
55
- # helpers
56
-
57
def exists(val):
    """Return True for any value other than ``None`` (falsy values count as existing)."""
    return not (val is None)
59
-
60
def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return the fallback ``d``."""
    if val is None:
        return d
    return val
62
-
63
def compact(arr):
    """Return a list of the items of ``arr`` that are not ``None``, in order."""
    return [item for item in arr if item is not None]
65
-
66
def once(fn):
    """Wrap ``fn`` so that only the first call is forwarded.

    The first call passes all positional and keyword arguments through and
    returns ``fn``'s result; every later call does nothing and returns
    ``None``. (Previously the wrapper accepted exactly one positional
    argument, so e.g. ``print_once(a, b)`` raised ``TypeError``.)
    """
    called = False

    @wraps(fn)
    def inner(*args, **kwargs):
        nonlocal called
        if called:
            return None
        called = True
        return fn(*args, **kwargs)

    return inner


# `print` that emits only the first message it is given, process-wide
print_once = once(print)
78
-
79
- # functions for creating causal mask
80
- # need a special one for onnx cpu (no support for .triu)
81
-
82
def create_causal_mask(i, j, device):
    """Build an ``(i, j)`` boolean mask that is True at positions to be masked out.

    Entry ``(row, col)`` is True when ``col - row >= j - i + 1``, i.e. the
    mask is aligned to the bottom-right corner when ``j > i``.
    """
    all_true = torch.ones((i, j), device = device, dtype = torch.bool)
    return torch.triu(all_true, diagonal = j - i + 1)
84
-
85
def onnx_create_causal_mask(i, j, device):
    """Build the ``(i, j)`` causal mask (True = masked out) without using ``.triu``.

    A square ``(i, i)`` strict upper-triangular mask is produced by comparing
    row and column positions, then ``j - i`` unmasked (False) columns are
    padded onto the left so the mask is right-aligned.
    """
    positions = torch.arange(i, device = device)
    rows = rearrange(positions, 'i -> i 1')
    cols = rearrange(positions, 'j -> 1 j')
    future = rows < cols
    return F.pad(future, (j - i, 0), value = False)
90
-
91
- # main class
92
-
93
class Attend(nn.Module):
    """Attention kernel: computes ``softmax(q k^T * scale + bias) v`` with options.

    Supports causal masking, talking heads, sparse top-k attention, an extra
    all-zero key/value token, and an alternative path through
    ``F.scaled_dot_product_attention`` (``flash = True``).
    """

    def __init__(
        self,
        *,
        dropout = 0.,
        causal = False,
        heads = None,
        talking_heads = False,
        sparse_topk = None,
        scale = None,
        qk_norm = False,
        flash = False,
        add_zero_kv = False,
        onnxable = False
    ):
        """Store the attention options and validate incompatible combinations.

        Args:
            dropout: attention dropout probability.
            causal: apply a causal mask.
            heads: number of heads; used only to size the talking-heads convs.
            talking_heads: mix attention maps across heads with 1x1 convs.
            sparse_topk: if set, keep only the top-k logits per query row.
            scale: logit scale; ``None`` means ``dim_head ** -0.5``.
            qk_norm: when True, softmax is not forced to float32 and the
                flash path rescales ``q`` by ``scale``.
            flash: route through ``F.scaled_dot_product_attention``.
            add_zero_kv: prepend an all-zero key/value token.
            onnxable: use the causal-mask builder that avoids ``.triu``.
        """
        super().__init__()
        self.scale = scale
        self.qk_norm = qk_norm

        self.causal = causal
        self.create_causal_mask = onnx_create_causal_mask if onnxable else create_causal_mask

        self.attn_fn = partial(F.softmax, dtype = torch.float32) if not qk_norm else F.softmax

        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        # talking heads

        assert not (flash and talking_heads), 'talking heads not compatible with flash attention'

        self.talking_heads = talking_heads
        if talking_heads:
            self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
            self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)

        # sparse topk

        assert not (flash and sparse_topk), 'sparse topk not compatible with flash attention'
        self.sparse_topk = sparse_topk

        # add a key / value token composed of zeros
        # in case this helps controlling outliers, proposed by https://www.evanmiller.org/attention-is-off-by-one.html

        self.add_zero_kv = add_zero_kv

        # flash attention

        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine efficient attention configs for cuda and cpu

        self.cpu_config = EfficientAttentionConfig(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        major, minor = device_properties.major, device_properties.minor

        if (major, minor) == (8, 0):
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = EfficientAttentionConfig(True, False, False)
        elif (major, minor) == (9, 0):
            print_once('H100 GPU detected, using flash attention')
            self.cuda_config = EfficientAttentionConfig(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = EfficientAttentionConfig(False, True, True)

    def flash_attn(
        self,
        q, k, v,
        mask = None,
        attn_bias = None
    ):
        """Attention via ``F.scaled_dot_product_attention``.

        Folds the boolean ``mask`` (True = attend), the causal mask and the
        additive ``attn_bias`` into a single ``attn_mask`` argument.

        Returns:
            ``(out, Intermediates())`` - no attention maps are captured here.
        """
        batch, heads, q_len, _ = q.shape
        k_len, device = k.shape[-2], q.device

        # Recommended for multi-query single-key-value attention by Tri Dao
        # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])

        if k.ndim == 3:
            k = rearrange(k, 'b ... -> b 1 ...').expand_as(q)

        if v.ndim == 3:
            v = rearrange(v, 'b ... -> b 1 ...').expand_as(q)

        # handle scale - by default they scale by dim_head ** -0.5, but need to take care if using cosine sim attention

        if self.qk_norm:
            default_scale = q.shape[-1] ** -0.5
            q = q * (self.scale / default_scale)

        causal = self.causal

        # in the case of kv caching with one token (q_len == 1), just turn off causal masking
        # in speculative decoding, this may go up to 5-6, so right aligned causal mask will be needed there

        if q_len == 1 and causal:
            causal = False

        # expand key padding mask to (batch, heads, q_len, k_len)

        if exists(mask):
            assert mask.ndim == 4
            mask = mask.expand(batch, heads, q_len, k_len)

        # handle kv cache - this should be bypassable in updated flash attention 2

        if k_len > q_len and causal:
            causal_mask = self.create_causal_mask(q_len, k_len, device = device)
            if not exists(mask):
                mask = ~causal_mask
            else:
                mask = mask & ~causal_mask
            causal = False

        # manually handle causal mask, if another mask was given

        row_is_entirely_masked = None

        if exists(mask) and causal:
            causal_mask = self.create_causal_mask(q_len, k_len, device = device)
            mask = mask & ~causal_mask

            # protect against an entire row being masked out:
            # let such rows attend to key 0, their output is zeroed afterwards

            row_is_entirely_masked = ~mask.any(dim = -1)
            mask[..., 0] = mask[..., 0] | row_is_entirely_masked

            causal = False

        # handle alibi positional bias
        # convert from bool to float

        if exists(attn_bias):
            attn_bias = rearrange(attn_bias, 'h i j -> 1 h i j').expand(batch, heads, -1, -1)

            # if mask given, the mask would already contain the causal mask from above logic
            # otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number

            mask_value = -torch.finfo(q.dtype).max

            if exists(mask):
                attn_bias = attn_bias.masked_fill(~mask, mask_value // 2)
            elif causal:
                causal_mask = self.create_causal_mask(q_len, k_len, device = device)
                attn_bias = attn_bias.masked_fill(causal_mask, mask_value // 2)
                causal = False

            # scaled_dot_product_attention handles attn_mask either as bool or additive bias
            # make it an additive bias here

            mask = attn_bias

        # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale

        # NOTE: kernel selection is hard-coded below; self.cpu_config / self.cuda_config
        # are still populated in __init__ but are not consulted here.
        # Legacy code...
        with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=True):

            # New SDP kernel code...
            # with sdpa_kernel(SDPBackend.FLASH_ATTENTION):

            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                dropout_p = self.dropout if self.training else 0.,
                is_causal = causal
            )

        # for a row that is entirely masked out, should zero out the output of that row token

        if exists(row_is_entirely_masked):
            out = out.masked_fill(row_is_entirely_masked[..., None], 0.)

        return out, Intermediates()

    def forward(
        self,
        q, k, v,
        mask = None,
        attn_bias = None,
        prev_attn = None
    ):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension

        ``mask`` is boolean with True meaning "may attend". Returns
        ``(out, Intermediates)``; on the non-flash path the intermediates
        hold the raw similarities and the pre/post-softmax attention maps.
        """

        n, heads, kv_heads, device = q.shape[-2], q.shape[1], k.shape[1], q.device

        scale = default(self.scale, q.shape[-1] ** -0.5)

        causal = self.causal

        # handle kv cached decoding

        if n == 1 and causal:
            causal = False

        # handle grouped multi-query attention

        if kv_heads == 1:
            k, v = map(lambda t: rearrange(t, 'b 1 n d -> b n d'), (k, v))
        elif kv_heads < heads:
            k, v = map(lambda t: repeat(t, 'b kvh n d -> b (r kvh) n d', r = heads // kv_heads), (k, v))

        # handle zero kv, as means for allowing network to attend to nothing

        if self.add_zero_kv:
            k, v = map(lambda t: F.pad(t, (0, 0, 1, 0), value = 0.), (k, v))

            if exists(mask):
                mask = F.pad(mask, (1, 0), value = True)

            if exists(attn_bias):
                attn_bias = F.pad(attn_bias, (1, 0), value = 0.)

        if self.flash:
            assert not exists(prev_attn), 'residual attention not compatible with flash attention'
            return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias)

        kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'

        dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale

        if exists(prev_attn):
            dots = dots + prev_attn

        qk_similarities = dots.clone()

        if self.talking_heads:
            dots = self.pre_softmax_talking_heads(dots)

        if exists(attn_bias):
            dots = dots + attn_bias

        i, j, dtype = *dots.shape[-2:], dots.dtype

        mask_value = -torch.finfo(dots.dtype).max

        if exists(self.sparse_topk) and self.sparse_topk < j:
            top_values, _ = dots.topk(self.sparse_topk, dim = -1)
            # `mask` is a keep-mask (it is inverted in masked_fill below), so the
            # sparse mask must be True for logits at or above the k-th largest.
            # The previous `dots < kth` comparison kept everything except the top-k.
            sparse_topk_mask = dots >= top_values[..., -1:]
            mask = (mask & sparse_topk_mask) if exists(mask) else sparse_topk_mask

        if exists(mask):
            dots = dots.masked_fill(~mask, mask_value)

        if causal:
            causal_mask = self.create_causal_mask(i, j, device = device)
            dots = dots.masked_fill(causal_mask, mask_value)

        pre_softmax_attn = dots.clone()

        attn = self.attn_fn(dots, dim = -1)
        attn = attn.type(dtype)

        post_softmax_attn = attn.clone()

        attn = self.attn_dropout(attn)

        if self.talking_heads:
            attn = self.post_softmax_talking_heads(attn)

        out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)

        intermediates = Intermediates(
            qk_similarities = qk_similarities,
            pre_softmax_attn = pre_softmax_attn,
            post_softmax_attn = post_softmax_attn
        )

        return out, intermediates
380
-
381
- #===================================================================================================================
382
-
383
- from math import ceil, log
384
- from typing import Optional, Union, Tuple, Callable
385
-
386
- import torch
387
- from torch import nn, Tensor
388
- from torch.nn import Module
389
- import torch.nn.functional as F
390
-
391
- from einops import rearrange, pack, unpack
392
-
393
def exists(val):
    """Tell whether ``val`` holds a value, i.e. is not ``None``."""
    is_missing = val is None
    return not is_missing
395
-
396
def default(val, d):
    """Fall back to ``d`` only when ``val`` is ``None``."""
    return d if val is None else val
398
-
399
def identity(t, *_args, **_kwargs):
    """Return ``t`` unchanged; any extra positional or keyword arguments are ignored."""
    return t
401
-
402
def cast_tuple(t, length = 1):
    """Return ``t`` as-is if it is already a tuple, else a tuple repeating it ``length`` times."""
    if isinstance(t, tuple):
        return t
    return (t,) * length
404
-
405
def eval_decorator(fn):
    """Decorate a module method so it runs in eval mode.

    The module's ``training`` flag is recorded, ``self.eval()`` is called,
    and the previous mode is restored via ``self.train(was_training)``
    afterwards - also when ``fn`` raises, so a failed call no longer
    leaves the module stuck in eval mode.
    """
    @wraps(fn)
    def inner(self, *args, **kwargs):
        was_training = self.training
        self.eval()
        try:
            return fn(self, *args, **kwargs)
        finally:
            self.train(was_training)
    return inner
413
-
414
- # for variable lengthed prefixes
415
-
416
def align_right(t, lens, pad_id = 0):
    """Right-align each row of a 2-D tensor according to its true length.

    Row ``r`` of ``t`` is treated as having ``lens[r]`` valid leading
    entries; it is shifted right so those entries end at the last column,
    and the freed leading positions are filled with ``pad_id``.
    (Previously ``pad_id`` was ignored and 0 was always used.)

    Args:
        t: tensor of shape ``(batch, seq_len)``.
        lens: 1-D tensor of ``batch`` lengths, each ``<= seq_len``.
        pad_id: fill value for the left padding.

    Returns:
        Tensor of the same shape as ``t``.
    """
    batch, seq_len = t.shape
    device = t.device

    assert lens.ndim == 1 and lens.shape[0] == batch
    assert lens.amax() <= seq_len

    pad_lens = seq_len - lens
    max_pad_len = int(pad_lens.amax())

    batch_arange = torch.arange(batch, device = device, dtype = torch.long)[..., None]
    prompt_len_arange = torch.arange(seq_len, device = device, dtype = torch.long)

    # left-pad by the largest shift, then gather a window per row
    t = F.pad(t, (max_pad_len, 0), value = pad_id)
    offset = max_pad_len - pad_lens

    return t[batch_arange, prompt_len_arange + offset[..., None]]
433
-
434
- # nucleus
435
-
436
def top_p(logits, thres = 0.9):
    """Nucleus filtering: keep the smallest top set whose probability mass exceeds ``thres``.

    Tokens are ranked by logit; a token is dropped (set to ``-inf``) when
    the cumulative probability of the tokens ranked before it already
    exceeds ``thres``. The highest-ranked token is always kept. Expects
    2-D ``(batch, vocab)`` logits because the result is scattered on dim 1.
    """
    ranked_logits, ranked_indices = logits.sort(descending = True)
    cumulative = ranked_logits.softmax(dim = -1).cumsum(dim = -1)

    # shift right by one so the token that crosses the threshold is still kept
    drop = F.pad(cumulative > thres, (1, -1), value = False)

    ranked_logits = ranked_logits.masked_fill(drop, float('-inf'))
    return ranked_logits.scatter(1, ranked_indices, ranked_logits)
445
-
446
- # topk
447
-
448
def top_k(logits, frac_num_tokens = 0.1, k = None):
    """Keep the ``k`` largest logits per row and set the rest to ``-inf``.

    When ``k`` is ``None`` it is ``ceil(frac_num_tokens * vocab_size)``;
    in either case it is capped at the vocabulary size. Expects 2-D
    ``(batch, vocab)`` logits because the result is scattered on dim 1.
    """
    vocab_size = logits.shape[-1]

    fraction_based_k = ceil(frac_num_tokens * vocab_size)
    if k is None:
        k = fraction_based_k
    k = min(k, vocab_size)

    kept_values, kept_indices = logits.topk(k)
    filtered = torch.full_like(logits, float('-inf'))
    return filtered.scatter(1, kept_indices, kept_values)
458
-
459
- # top_a
460
-
461
def top_a(logits, min_p_pow = 2.0, min_p_ratio = 0.02):
    """Top-a filtering: drop tokens whose probability is below a max-dependent limit.

    The limit per row is ``max_prob ** min_p_pow * min_p_ratio``; logits
    of tokens below it are replaced by ``-inf``.
    """
    probs = logits.softmax(dim = -1)
    peak = probs.amax(dim = -1, keepdim = True)
    limit = peak.pow(min_p_pow) * min_p_ratio
    too_unlikely = probs < limit
    return logits.masked_fill(too_unlikely, float('-inf'))
466
-
467
- # contrastive decoding function
468
-
469
def contrastive_decode_fn(
    expert_logits,
    amateur_logits,
    alpha = 0.1,
    beta = 0.5
):
    """
    Appendix A Algorithm 2
    https://arxiv.org/abs/2309.09117

    Scores are ``(1 + beta) * expert - beta * amateur``. Tokens whose
    expert logit is below ``log(alpha) + max(expert)`` are set to the most
    negative finite value of the expert dtype.
    """
    strongest = expert_logits.amax(dim = -1, keepdim = True)
    implausible = expert_logits < (log(alpha) + strongest)

    contrast = (1 + beta) * expert_logits - beta * amateur_logits
    fill_value = -torch.finfo(expert_logits.dtype).max
    return contrast.masked_fill(implausible, fill_value)
484
-
485
- # autoregressive wrapper class
486
-
487
class AutoregressiveWrapper(Module):
    """Wrap a token-level network for next-token training and sampling.

    ``forward`` computes the shifted cross-entropy loss and accuracy;
    ``generate`` samples tokens autoregressively, optionally with kv
    caching and contrastive decoding against one or more amateur models.
    """

    def __init__(
        self,
        net,
        ignore_index = -100,
        pad_value = 0,
        mask_prob = 0.,
        add_attn_z_loss = False
    ):
        """
        Args:
            net: wrapped network; must expose ``max_seq_len``.
            ignore_index: target value excluded from loss and accuracy.
            pad_value: token id used for padding.
            mask_prob: fraction of input positions masked out of attention
                during training (must be < 1).
            add_attn_z_loss: add the network's attention z-loss to the loss.
        """
        super().__init__()
        self.pad_value = pad_value
        self.ignore_index = ignore_index

        self.net = net
        self.max_seq_len = net.max_seq_len

        # paper shows masking (MLM) in conjunction with autoregressive decoder-only training leads to big improvements https://arxiv.org/abs/2210.13432
        assert mask_prob < 1.
        self.mask_prob = mask_prob

        # whether to add router z-loss
        self.add_attn_z_loss = add_attn_z_loss

    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        prompts,
        seq_len,
        eos_token = None,
        temperature = 1.,
        prompt_lens: Optional[Tensor] = None,
        filter_logits_fn: Callable = top_k,
        restrict_to_max_seq_len = True,
        amateur_model: Optional[Union[Module, Tuple[Module]]] = None,
        filter_kwargs: dict = dict(),
        contrastive_decode_kwargs: Union[dict, Tuple[dict]] = dict(
            beta = 0.5,
            alpha = 0.1
        ),
        cache_kv = True,
        verbose=True,
        return_prime=False,
        **kwargs
    ):
        """Sample up to ``seq_len`` tokens after ``prompts``.

        Args:
            prompts: token ids, packed to ``(batch, n)``.
            seq_len: maximum number of tokens to sample.
            eos_token: stop once every row contains this token; everything
                after the first occurrence is replaced by ``pad_value``.
            temperature: divisor applied to the filtered logits.
            prompt_lens: per-row prompt lengths for variable-length prompts.
            filter_logits_fn: logits filter (ignored for contrastive decoding).
            restrict_to_max_seq_len: feed only the last ``max_seq_len`` tokens.
            amateur_model: model(s) for contrastive decoding.
            filter_kwargs: keyword arguments for ``filter_logits_fn`` (read only).
            contrastive_decode_kwargs: kwargs (or tuple of kwargs, one per
                amateur model) for ``contrastive_decode_fn`` (read only).
            cache_kv: reuse key/value caches when the network supports it.
            verbose: print progress.
            return_prime: include the prompt in the returned tensor.
            **kwargs: forwarded to the network(s).

        Returns:
            Tensor of token ids, with or without the prompt columns.
        """
        max_seq_len = self.max_seq_len

        prompts, _ = pack([prompts], '* n')

        _, t = prompts.shape

        # handle variable lengthed prompts (prefixes)

        seq_start_pos = None
        if exists(prompt_lens):
            prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value)
            seq_start_pos = t - prompt_lens

        # output from which sampled tokens appended to

        out = prompts

        if verbose:
            print("Generating sequence of max length:", seq_len)

        # kv caches

        cache = None

        # if doing contrastive decoding, turn off filter automatically

        if exists(amateur_model):
            # a list, because wrapped models are replaced in place below
            # (item assignment on the tuple from cast_tuple raised TypeError);
            # the length is passed explicitly so either cast_tuple variant works
            amateur_model = list(cast_tuple(amateur_model, 1))
            contrastive_decode_kwargs = cast_tuple(contrastive_decode_kwargs, 1)

            assert len(amateur_model) == len(contrastive_decode_kwargs)

            amateur_caches = [None] * len(amateur_model)
            filter_logits_fn = identity

            for i, module in enumerate(amateur_model):
                if isinstance(module, AutoregressiveWrapper):
                    amateur_model[i] = module.net

                module.eval()

        # sampling up to seq_len

        for sl in range(seq_len):

            # default to the full sequence so `x` is defined even when
            # restrict_to_max_seq_len is False (previously a NameError)
            x = out

            if restrict_to_max_seq_len:
                x = out[:, -max_seq_len:]

                if exists(cache):
                    # keep the cache one shorter than the window so the new token fits
                    for inter in cache.attn_intermediates:
                        inter.cached_kv = [kv[..., -(max_seq_len - 1):, :] for kv in inter.cached_kv]

            logits, new_cache = self.net(
                x,
                return_intermediates = True,
                cache = cache,
                seq_start_pos = seq_start_pos,
                **kwargs
            )

            if cache_kv and self.net.can_cache_kv:
                cache = new_cache

            logits = logits[:, -1]

            # handle contrastive decoding, Li et al.
            # https://arxiv.org/abs/2210.15097

            if exists(amateur_model):
                for i, (amateur, amateur_cache, amateur_contrastive_decode_kwargs) in enumerate(zip(amateur_model, amateur_caches, contrastive_decode_kwargs)):
                    amateur_logits, next_amateur_cache = amateur(
                        x,
                        return_intermediates = True,
                        cache = amateur_cache,
                        seq_start_pos = seq_start_pos,
                        **kwargs
                    )

                    amateur_logits = amateur_logits[:, -1]

                    assert amateur_logits.shape == logits.shape, 'logits dimension are not the same between amateur and expert model'
                    logits = contrastive_decode_fn(logits, amateur_logits, **amateur_contrastive_decode_kwargs)

                    if cache_kv and amateur.can_cache_kv:
                        amateur_caches[i] = next_amateur_cache

            # filter by top_k, top_p (nucleus), top_a, or custom

            filtered_logits = filter_logits_fn(logits, **filter_kwargs)

            probs = F.softmax(filtered_logits / temperature, dim=-1)

            sample = torch.multinomial(probs, 1)

            out = torch.cat((out, sample), dim=-1)

            if verbose:
                if sl % 32 == 0:
                    print(sl, '/', seq_len)

            if exists(eos_token):
                is_eos_tokens = (out == eos_token)

                if is_eos_tokens.any(dim = -1).all():
                    # mask out everything after the eos tokens
                    shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
                    mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
                    out = out.masked_fill(mask, self.pad_value)

                    if verbose:
                        print('Model called the end of sequence at:', sl, '/', seq_len)

                    break

        if return_prime:
            return out[:, :]

        else:
            return out[:, t:]

        # out, = unpack(out, ps, '* n')

        # return out

    def compute_accuracy(self, logits, labels):
        """Fraction of positions (excluding ``ignore_index``) where argmax equals the label."""
        out = torch.argmax(logits, dim=-1)
        out = out.flatten()
        labels = labels.flatten()

        mask = (labels != self.ignore_index) # can also be self.pad_value (your choice)
        out = out[mask]
        labels = labels[mask]

        num_right = (out == labels)
        num_right = torch.sum(num_right).type(torch.float32)

        acc = num_right / len(labels)
        return acc

    def forward(self, x, **kwargs):
        """Return ``(loss, accuracy)`` for next-token prediction on ``x``.

        ``x[:, :-1]`` is the input and ``x[:, 1:]`` the target. Input
        positions equal to ``ignore_index`` are replaced by ``pad_value``.
        """
        seq, ignore_index, add_attn_z_loss = x.shape[1], self.ignore_index, self.add_attn_z_loss

        inp, target = x[:, :-1], x[:, 1:]
        inp = torch.where(inp == ignore_index, self.pad_value, inp)

        if self.mask_prob > 0.:
            rand = torch.randn(inp.shape, device = x.device)
            rand[:, 0] = -torch.finfo(rand.dtype).max # first token should not be masked out
            num_mask = min(int(seq * self.mask_prob), seq - 1)
            indices = rand.topk(num_mask, dim = -1).indices
            mask = ~torch.zeros_like(inp).scatter(1, indices, 1.).bool()
            kwargs.update(self_attn_kv_mask = mask)

        logits, cache = self.net(
            inp,
            return_intermediates = True,
            return_attn_z_loss = add_attn_z_loss,
            **kwargs
        )

        acc = self.compute_accuracy(logits, target)

        loss = F.cross_entropy(
            rearrange(logits, 'b n c -> b c n'),
            target,
            ignore_index = ignore_index
        )

        if add_attn_z_loss:
            loss = loss + cache.attn_z_loss

        return loss, acc
704
-
705
- #===============================================================================
706
-
707
- import math
708
- from random import random
709
-
710
- import torch
711
- from torch import nn, einsum, Tensor
712
- import torch.nn.functional as F
713
-
714
- from functools import partial, wraps
715
- from inspect import isfunction
716
- from collections import namedtuple
717
- from dataclasses import dataclass
718
- from typing import List, Callable, Optional
719
-
720
- from einops import rearrange, repeat, reduce, pack, unpack
721
- from einops.layers.torch import Rearrange
722
-
723
- # constants
724
-
725
# NOTE(review): presumably the default per-head dimension for attention layers;
# its use is outside the visible part of the file - confirm.
DEFAULT_DIM_HEAD = 64

# Container for values collected while running a stack of layers.
# Every field is optional and defaults to None.
@dataclass
class LayerIntermediates:
    hiddens: Optional[List[Tensor]] = None
    # one `Intermediates` per attention call; AutoregressiveWrapper.generate
    # iterates this list and trims each entry's `cached_kv`
    attn_intermediates: Optional[List[Intermediates]] = None
    layer_hiddens: Optional[List[Tensor]] = None
    # added to the training loss by AutoregressiveWrapper.forward when enabled
    attn_z_loss: Optional[Tensor] = None
    mems: Optional[Tensor] = None
734
-
735
- # helpers
736
-
737
def exists(val):
    """Return whether ``val`` is something other than ``None``."""
    if val is None:
        return False
    return True
739
-
740
def default(val, d):
    """Return ``val`` if it is not ``None``; otherwise the fallback ``d``.

    If ``d`` is a plain Python function (per ``inspect.isfunction``) it is
    called with no arguments and its result is the fallback.
    """
    if val is None:
        return d() if isfunction(d) else d
    return val
744
-
745
def cast_tuple(val, depth = 1):
    """Return ``val`` unchanged if it is a tuple, else a tuple repeating it ``depth`` times.

    ``depth`` defaults to 1 so that single-argument calls (as made by
    ``AutoregressiveWrapper.generate``) keep working: this definition
    replaces the earlier ``cast_tuple(t, length = 1)`` at module level.
    """
    if isinstance(val, tuple):
        return val
    return (val,) * depth
747
-
748
def divisible_by(num, den):
    """Return whether ``num`` leaves no remainder when divided by ``den``."""
    remainder = num % den
    return remainder == 0
750
-
751
def maybe(fn):
    """Wrap ``fn`` so that a ``None`` first argument short-circuits to ``None``.

    For any other first argument the call is forwarded unchanged,
    including extra positional and keyword arguments.
    """
    @wraps(fn)
    def inner(x, *args, **kwargs):
        return None if x is None else fn(x, *args, **kwargs)
    return inner
758
-
759
class always():
    """Callable that ignores its arguments and returns a fixed value."""

    def __init__(self, val):
        # the constant handed back on every call
        self.val = val

    def __call__(self, *_args, **_kwargs):
        constant = self.val
        return constant
764
-
765
class not_equals():
    """Callable predicate: ``x != val`` for the value given at construction."""

    def __init__(self, val):
        # reference value each argument is compared against
        self.val = val

    def __call__(self, x, *_args, **_kwargs):
        differs = x != self.val
        return differs
770
-
771
class equals():
    """Callable predicate: ``x == val`` for the value given at construction."""

    def __init__(self, val):
        # reference value each argument is compared against
        self.val = val

    def __call__(self, x, *_args, **_kwargs):
        matches = x == self.val
        return matches
776
-
777
def Sequential(*modules):
    """Build an ``nn.Sequential`` from ``modules``, skipping any that are ``None``."""
    present = [module for module in modules if module is not None]
    return nn.Sequential(*present)
779
-
780
- # tensor helpers
781
-
782
def max_neg_value(tensor):
    """Return the most negative finite value representable in ``tensor``'s float dtype."""
    dtype_info = torch.finfo(tensor.dtype)
    return -dtype_info.max
784
-
785
def l2norm(t, groups = 1):
    """L2-normalise the last dimension of ``t`` in ``groups`` equal chunks.

    The last dimension is split into ``groups`` chunks, each chunk is
    normalised to unit L2 norm, and the original shape is restored.
    """
    grouped = rearrange(t, '... (g d) -> ... g d', g = groups)
    normalised = F.normalize(grouped, p = 2, dim = -1)
    return rearrange(normalised, '... g d -> ... (g d)')
789
-
790
def pad_at_dim(t, pad, dim = -1, value = 0.):
    """Pad tensor ``t`` along a single dimension.

    ``pad`` is a ``(before, after)`` pair for dimension ``dim``; all
    dimensions to the right of ``dim`` get zero padding, because
    ``F.pad`` takes its pad amounts starting from the last dimension.
    """
    if dim < 0:
        trailing_dims = -dim - 1
    else:
        trailing_dims = t.ndim - dim - 1

    full_pad = (0, 0) * trailing_dims + tuple(pad)
    return F.pad(t, full_pad, value = value)
794
-
795
def or_reduce(masks):
    """Combine all items of ``masks`` with ``|``, left to right.

    ``masks`` must contain at least one item; a single item is returned as-is.
    """
    combined, *remaining = masks
    for mask in remaining:
        combined = combined | mask
    return combined
800
-
801
- # auxiliary loss helpers
802
-
803
def calc_z_loss(
    pre_softmax_attns: List[Tensor],
    mask = None,
    weight = 1.
):
    """Auxiliary z-loss on attention logits.

    The log-sum-exp over the last dimension is summed across all given
    attention maps, squared, and summed over heads. Without ``mask`` the
    mean over all positions is returned; with ``mask`` only the selected
    positions are averaged. The result is multiplied by ``weight``.
    """
    # the same loss applied to the mixture of experts router logits in https://arxiv.org/abs/2202.08906
    # in the paper, in a tiny footnote, they mention using it on attention logits with stabilizing effects
    # also used in PaLM as one of the measures

    lse = sum((attn.logsumexp(dim = -1) for attn in pre_softmax_attns), 0.)

    per_position = reduce(torch.square(lse), 'b h n -> b n', 'sum')

    if mask is None:
        return per_position.mean() * weight

    # clamp keeps the division finite when the mask selects nothing
    masked_mean = per_position[mask].sum() / mask.sum().clamp(min = 1e-5)
    return masked_mean * weight
825
-
826
- # init helpers
827
-
828
def init_zero_(layer):
    """Set ``layer.weight`` - and ``layer.bias`` when it is not ``None`` - to zeros, in place."""
    nn.init.constant_(layer.weight, 0.)

    if layer.bias is None:
        return

    nn.init.constant_(layer.bias, 0.)
832
-
833
- # keyword argument helpers
834
-
835
def pick_and_pop(keys, d):
    """Remove ``keys`` from dict ``d`` (mutating it) and return them as a new dict.

    Raises ``KeyError`` if a key is missing from ``d``.
    """
    popped = [d.pop(key) for key in keys]
    return dict(zip(keys, popped))
838
-
839
def group_dict_by_key(cond, d):
    """Split ``d`` into two dicts by applying ``cond`` to each key.

    Returns ``(matching, rest)``: entries whose key makes ``cond`` truthy,
    and all other entries.
    """
    matching, rest = dict(), dict()
    for key in d:
        bucket = matching if bool(cond(key)) else rest
        bucket[key] = d[key]
    return (matching, rest)
846
-
847
def string_begins_with(prefix, str):
    """Return whether ``str`` starts with ``prefix`` (prefix-first argument order)."""
    starts = str.startswith(prefix)
    return starts
849
-
850
def group_by_key_prefix(prefix, d):
    """Split ``d`` into ``(entries whose key starts with prefix, all other entries)``."""
    has_prefix = partial(string_begins_with, prefix)
    with_prefix, without_prefix = group_dict_by_key(has_prefix, d)
    return (with_prefix, without_prefix)
852
-
853
def groupby_prefix_and_trim(prefix, d):
    """Split ``d`` by key prefix and strip the prefix from the matching keys.

    Returns ``(prefixed entries with the prefix removed from each key,
    all remaining entries)``.
    """
    has_prefix = partial(string_begins_with, prefix)
    kwargs_with_prefix, kwargs = group_dict_by_key(has_prefix, d)

    prefix_len = len(prefix)
    kwargs_without_prefix = {
        key[prefix_len:]: value for key, value in kwargs_with_prefix.items()
    }
    return kwargs_without_prefix, kwargs
857
-
858
- # structured dropout, more effective than traditional attention dropouts
859
-
860
def dropout_seq(seq, mask, dropout):
    """Randomly drop whole sequence positions (structured dropout).

    Keeps ``max(1, int((1 - dropout) * n))`` positions per batch row,
    chosen by the top-k of random scores; positions where ``mask`` is
    False get the lowest possible score so they are selected last.

    Returns:
        ``(seq, mask)`` gathered at the kept positions. ``mask`` is
        ``None`` if none was given; otherwise it is additionally limited
        to ``ceil(valid_count * keep_prob)`` entries per row.
    """
    batch, length = seq.shape[:2]
    device = seq.device
    scores = torch.randn(batch, length, device = device)

    has_mask = mask is not None

    if has_mask:
        scores = scores.masked_fill(~mask, max_neg_value(scores))

    keep_prob = 1. - dropout
    num_keep = max(1, int(keep_prob * length))
    keep_indices = scores.topk(num_keep, dim = 1).indices

    batch_indices = rearrange(torch.arange(batch, device = device), 'b -> b 1')

    seq = seq[batch_indices, keep_indices]

    if has_mask:
        valid_counts = mask.sum(dim = -1)
        keep_counts = torch.ceil(valid_counts * keep_prob).int()
        keep_mask = torch.arange(num_keep, device = device) < rearrange(keep_counts, 'b -> b 1')

        mask = mask[batch_indices, keep_indices] & keep_mask

    return seq, mask
885
-
886
- # activations
887
-
888
class ReluSquared(nn.Module):
    """Activation that applies ReLU and then squares the result."""

    def forward(self, x):
        rectified = F.relu(x)
        return rectified ** 2
891
-
892
- # embedding
893
-
894
class TokenEmbedding(nn.Module):
    """Token id -> embedding lookup, optionally L2-normalised."""

    def __init__(self, dim, num_tokens, l2norm_embed = False):
        super().__init__()
        self.l2norm_embed = l2norm_embed
        self.emb = nn.Embedding(num_tokens, dim)

    def forward(self, x):
        """Embed token ids ``x``; apply ``l2norm`` when ``l2norm_embed`` is set."""
        token_emb = self.emb(x)
        if self.l2norm_embed:
            return l2norm(token_emb)
        return token_emb
903
-
904
- # positional embeddings
905
-
906
class AbsolutePositionalEmbedding(nn.Module):
    """Learned absolute positional embedding.

    Embeddings are scaled by ``dim ** -0.5``, or L2-normalised instead
    when ``l2norm_embed`` is set.
    """

    def __init__(self, dim, max_seq_len, l2norm_embed = False):
        super().__init__()
        self.scale = 1. if l2norm_embed else dim ** -0.5
        self.max_seq_len = max_seq_len
        self.l2norm_embed = l2norm_embed
        self.emb = nn.Embedding(max_seq_len, dim)

    def forward(self, x, pos = None, seq_start_pos = None):
        """Return positional embeddings for the sequence axis (dim 1) of ``x``.

        Args:
            x: tensor whose ``shape[1]`` is the sequence length.
            pos: explicit position indices; defaults to ``arange(seq_len)``.
            seq_start_pos: per-row offsets subtracted from the positions
                (clamped at 0).
        """
        seq_len, device = x.shape[1], x.device
        assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'

        if pos is None:
            pos = torch.arange(seq_len, device = device)

        if seq_start_pos is not None:
            pos = (pos - seq_start_pos[..., None]).clamp(min = 0)

        pos_emb = self.emb(pos) * self.scale

        if self.l2norm_embed:
            return l2norm(pos_emb)
        return pos_emb
927
-
928
class ScaledSinusoidalEmbedding(nn.Module):
    """Sinusoidal positional embedding multiplied by a learned scalar.

    The scale parameter is initialised to ``dim ** -0.5``; the inverse
    frequencies are ``theta ** -(k / half_dim)`` for ``k`` in
    ``range(half_dim)`` and are kept in a non-persistent buffer.
    """

    def __init__(self, dim, theta = 10000):
        super().__init__()
        assert dim % 2 == 0
        self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)

        half_dim = dim // 2
        freq_seq = torch.arange(half_dim).float() / half_dim
        inv_freq = theta ** -freq_seq
        self.register_buffer('inv_freq', inv_freq, persistent = False)

    def forward(self, x, pos = None, seq_start_pos = None):
        """Return ``[sin | cos]`` embeddings for the sequence axis (dim 1) of ``x``.

        Args:
            x: tensor whose ``shape[1]`` is the sequence length.
            pos: explicit positions; defaults to ``arange(seq_len)``.
            seq_start_pos: per-row offsets subtracted from the positions,
                which makes ``pos`` 2-D ``(batch, seq_len)``.

        Returns:
            ``(..., seq_len, dim)`` tensor, where ``...`` are the leading
            dimensions of ``pos`` (none for the default 1-D positions).
        """
        seq_len, device = x.shape[1], x.device

        if pos is None:
            pos = torch.arange(seq_len, device = device)

        if seq_start_pos is not None:
            pos = pos - seq_start_pos[..., None]

        # broadcasting outer product: equals the former 'i, j -> i j' einsum
        # for 1-D positions, and also accepts the batched 2-D positions
        # produced by seq_start_pos, which that einsum could not
        emb = pos[..., None] * self.inv_freq
        emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
        return emb * self.scale
951
-
952
class RelativePositionBias(nn.Module):
    """Bucketed relative position bias with one learned value per bucket and head."""

    def __init__(self, scale, causal = False, num_buckets = 32, max_distance = 128, heads = 8):
        super().__init__()
        self.scale = scale
        self.causal = causal
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(relative_position, causal = True, num_buckets = 32, max_distance = 128):
        """Map relative positions (key index minus query index) to bucket ids.

        Small distances get one bucket each; larger ones share
        logarithmically sized buckets, capped at the last bucket. In the
        non-causal case the buckets are split in half by direction.
        """
        bucket_offset = 0
        distance = -relative_position

        if causal:
            # keys ahead of the query collapse to distance 0
            distance = torch.max(distance, torch.zeros_like(distance))
        else:
            num_buckets //= 2
            bucket_offset += (distance < 0).long() * num_buckets
            distance = torch.abs(distance)

        max_exact = num_buckets // 2
        is_small = distance < max_exact

        log_ratio = torch.log(distance.float() / max_exact) / math.log(max_distance / max_exact)
        val_if_large = max_exact + (log_ratio * (num_buckets - max_exact)).long()
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        return bucket_offset + torch.where(is_small, distance, val_if_large)

    @property
    def device(self):
        """Device of the module's first parameter."""
        first_param = next(self.parameters())
        return first_param.device

    def forward(self, i, j):
        """Return the ``(heads, i, j)`` bias for the last ``i`` queries over ``j`` keys, times ``scale``."""
        device = self.device
        query_positions = torch.arange(j - i, j, dtype = torch.long, device = device)
        key_positions = torch.arange(j, dtype = torch.long, device = device)
        rel_pos = key_positions[None, :] - query_positions[:, None]

        buckets = self._relative_position_bucket(
            rel_pos,
            causal = self.causal,
            num_buckets = self.num_buckets,
            max_distance = self.max_distance
        )
        per_bucket = self.relative_attention_bias(buckets)
        bias = rearrange(per_bucket, 'i j h -> h i j')
        return bias * self.scale
996
-
997
class DynamicPositionBias(nn.Module):
    """Relative position bias produced by a small MLP over the signed distance."""

    def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):
        super().__init__()
        assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1'
        self.log_distance = log_distance

        self.mlp = nn.ModuleList([])

        # the first hidden layer takes the scalar distance, later ones take `dim` features
        input_dims = [1] + [dim] * (depth - 1)

        for input_dim in input_dims:
            self.mlp.append(Sequential(
                nn.Linear(input_dim, dim),
                nn.LayerNorm(dim) if norm else None,
                nn.SiLU()
            ))

        # final projection to one bias value per head
        self.mlp.append(nn.Linear(dim, heads))

    @property
    def device(self):
        """Device of the module's first parameter."""
        first_param = next(self.parameters())
        return first_param.device

    def forward(self, i, j):
        """Return the ``(heads, n, n)`` bias for a square attention map (requires ``i == j``)."""
        assert i == j
        n, device = j, self.device

        # (n x n) matrix of relative distances, shifted into [0, 2n - 2] to index the MLP output
        positions = torch.arange(n, device = device)
        indices = rearrange(positions, 'i -> i 1') - rearrange(positions, 'j -> 1 j')
        indices = indices + (n - 1)

        # input to continuous positions MLP: every distance from -(n - 1) to (n - 1)
        pos = torch.arange(-n + 1, n, device = device).float()
        pos = rearrange(pos, '... -> ... 1')

        if self.log_distance:
            # log of distance is sign(rel_pos) * log(abs(rel_pos) + 1)
            pos = torch.sign(pos) * torch.log(pos.abs() + 1)

        for layer in self.mlp:
            pos = layer(pos)

        # get position biases
        bias = rearrange(pos[indices], 'i j h -> h i j')
        return bias
1048
-
1049
class AlibiPositionalBias(nn.Module):
    """ALiBi bias: negative absolute distance scaled by a fixed slope per head."""

    def __init__(self, heads, total_heads, **kwargs):
        super().__init__()
        self.heads = heads
        self.total_heads = total_heads

        head_slopes = rearrange(Tensor(self._get_slopes(heads)), 'h -> h 1 1')
        self.register_buffer('slopes', head_slopes, persistent = False)
        # filled lazily on the first forward call and reused while it is large enough
        self.register_buffer('bias', None, persistent = False)

    def get_bias(self, i, j, device):
        """Return -|key_pos - query_pos| with shape (1, i, j) for the last i queries."""
        query_idx = torch.arange(j - i, j, device = device)
        key_idx = torch.arange(j, device = device)
        return -torch.abs(rearrange(key_idx, 'j -> 1 1 j') - rearrange(query_idx, 'i -> 1 i 1'))

    @staticmethod
    def _get_slopes(heads):
        """Return `heads` slopes as a list of floats.

        For a power-of-two head count this is a geometric sequence; otherwise the
        sequence for the closest lower power of two is extended with every other
        slope of the next power of two.
        """
        def geometric(count):
            first = (2**(-2**-(math.log2(count)-3)))
            return [first * first ** k for k in range(count)]

        if math.log2(heads).is_integer():
            return geometric(heads)

        lower = 2 ** math.floor(math.log2(heads))
        extra = geometric(2 * lower)[0::2][:heads - lower]
        return geometric(lower) + extra

    @property
    def device(self):
        return next(self.buffers()).device

    def forward(self, i, j):
        """Return the bias for i queries over j keys, padded with zeros up to `total_heads`."""
        total, device = self.total_heads, self.device

        cached = self.bias
        if exists(cached) and cached.shape[-1] >= j and cached.shape[-2] >= i:
            return cached[..., -i:, -j:]

        fresh = self.get_bias(i, j, device) * self.slopes

        # heads beyond the ALiBi ones receive a zero bias
        fresh = pad_at_dim(fresh, (0, total - fresh.shape[0]), dim = 0)
        self.register_buffer('bias', fresh, persistent = False)

        return self.bias
1097
-
1098
class RotaryEmbedding(nn.Module):
    """Generates rotary position frequencies and, optionally, xpos scales.

    Args:
        dim: number of feature dimensions to rotate (frequencies are built for dim / 2 pairs).
        use_xpos: when True, also build the per-dimension xpos scale buffer.
        scale_base: divisor applied to the centred positions when computing the xpos exponent.
        interpolation_factor: positions are divided by this value; must be >= 1.
        base: base of the geometric frequency progression.
        base_rescale_factor: NTK-style rescaling of `base` (1 leaves the base untouched).
    """

    def __init__(
        self,
        dim,
        use_xpos = False,
        scale_base = 512,
        interpolation_factor = 1.,
        base = 10000,
        base_rescale_factor = 1.
    ):
        super().__init__()
        # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
        # has some connection to NTK literature
        # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
        # a factor of 1 is a no-op, so skip it: the exponent divides by (dim - 2)
        # and would otherwise raise ZeroDivisionError for dim == 2
        if base_rescale_factor != 1.:
            base *= base_rescale_factor ** (dim / (dim - 2))

        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        assert interpolation_factor >= 1.
        self.interpolation_factor = interpolation_factor

        if not use_xpos:
            self.register_buffer('scale', None)
            return

        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)

        self.scale_base = scale_base
        self.register_buffer('scale', scale)

    def forward(self, seq_len):
        """Return `(freqs, scale)` for positions 0 .. seq_len - 1.

        `freqs` has shape (seq_len, dim). `scale` is the float 1. when xpos is
        disabled, otherwise a (seq_len, dim) tensor.
        """
        device = self.inv_freq.device
        t = torch.arange(seq_len, device = device).type_as(self.inv_freq)

        t = t / self.interpolation_factor

        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
        # duplicate so both halves of the rotated features share the same angles
        freqs = torch.cat((freqs, freqs), dim = -1)

        if not exists(self.scale):
            return freqs, 1.

        # xpos exponent is centred on the middle of the sequence
        power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
        scale = self.scale ** rearrange(power, 'n -> n 1')
        scale = torch.cat((scale, scale), dim = -1)

        return freqs, scale
1146
-
1147
-
1148
def rotate_half(x):
    """Split the last dimension into two halves (a, b) and return (-b, a) concatenated."""
    halves = rearrange(x, '... (j d) -> ... j d', j = 2)
    first, second = halves.unbind(dim = -2)
    return torch.cat((-second, first), dim = -1)
1152
-
1153
def apply_rotary_pos_emb(t, freqs, scale = 1):
    """Rotate the leading `freqs.shape[-1]` features of `t` by the given angles.

    Args:
        t: tensor whose second to last axis is the sequence and last axis the features.
        freqs: rotation angles with the sequence on the second to last axis; may
            carry a leading batch axis. Only the last `t.shape[-2]` positions are used.
        scale: multiplier applied to the rotated features; a number, or a tensor
            laid out like `freqs` (xpos), in which case it is trimmed the same way.

    Returns:
        Tensor shaped like `t`; features beyond the rotary dimension pass through unchanged.
    """
    rot_dim, seq_len = freqs.shape[-1], t.shape[-2]

    # trim on the sequence axis; `...` keeps this correct when freqs has a batch axis
    freqs = freqs[..., -seq_len:, :]

    # a tensor scale must be trimmed to the same positions, otherwise it does not
    # broadcast against `t` when `t` is shorter than the full sequence (cached decoding)
    if torch.is_tensor(scale) and scale.ndim >= 2:
        scale = scale[..., -seq_len:, :]

    if t.ndim == 4 and freqs.ndim == 3:
        freqs = rearrange(freqs, 'b n d -> b 1 n d')

    # partial rotary embeddings, Wang et al. GPT-J
    t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
    t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
    return torch.cat((t, t_unrotated), dim = -1)
1164
-
1165
- # norms
1166
-
1167
class Scale(nn.Module):
    """Wraps a callable and multiplies its (first) output by a constant."""

    def __init__(self, value, fn):
        super().__init__()
        self.value = value
        self.fn = fn

    def forward(self, x, **kwargs):
        result = self.fn(x, **kwargs)

        if isinstance(result, tuple):
            # only the leading element is scaled; the rest is passed through untouched
            return (result[0] * self.value, *result[1:])

        return result * self.value
1181
-
1182
class ScaleNorm(nn.Module):
    """Divides by the L2 norm of the last dimension and applies one learned gain."""

    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        # single scalar gain, initialised to dim ** -0.5
        self.g = nn.Parameter(torch.ones(1) * (dim ** -0.5))

    def forward(self, x):
        # clamp keeps the division finite for (near) zero vectors
        length = torch.norm(x, dim = -1, keepdim = True).clamp(min = self.eps)
        return x / length * self.g
1191
-
1192
class RMSNorm(nn.Module):
    """Unit-normalises the last dimension, rescales by sqrt(dim) and a learned per-feature gain."""

    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        unit = F.normalize(x, dim = -1)
        return unit * self.scale * self.g
1200
-
1201
class SimpleRMSNorm(nn.Module):
    """Unit-normalises the last dimension and rescales by sqrt(dim); has no learned parameters."""

    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5

    def forward(self, x):
        unit = F.normalize(x, dim = -1)
        return unit * self.scale
1208
-
1209
- # residual and residual gates
1210
-
1211
class Residual(nn.Module):
    """Adds a residual to the branch output, optionally scaling the residual first."""

    def __init__(self, dim, scale_residual = False, scale_residual_constant = 1.):
        super().__init__()
        if scale_residual:
            # learned per-feature multiplier for the residual
            self.residual_scale = nn.Parameter(torch.ones(dim))
        else:
            self.residual_scale = None
        self.scale_residual_constant = scale_residual_constant

    def forward(self, x, residual):
        skip = residual

        if exists(self.residual_scale):
            skip = skip * self.residual_scale

        if self.scale_residual_constant != 1:
            skip = skip * self.scale_residual_constant

        return x + skip
1225
-
1226
class GRUGating(nn.Module):
    """Combines branch output and residual with a GRU cell instead of a plain sum."""

    def __init__(self, dim, scale_residual = False, **kwargs):
        super().__init__()
        self.gru = nn.GRUCell(dim, dim)
        if scale_residual:
            self.residual_scale = nn.Parameter(torch.ones(dim))
        else:
            self.residual_scale = None

    def forward(self, x, residual):
        if exists(self.residual_scale):
            residual = residual * self.residual_scale

        # the GRU cell works on 2-D input: fold batch and sequence together.
        # the branch output is the cell input, the residual is its hidden state.
        flat_input = rearrange(x, 'b n d -> (b n) d')
        flat_hidden = rearrange(residual, 'b n d -> (b n) d')

        return self.gru(flat_input, flat_hidden).reshape_as(x)
1242
-
1243
- # token shifting
1244
-
1245
def shift(t, amount, mask = None):
    """Shift `t` along its second to last (sequence) axis, filling vacated slots with zeros.

    Args:
        t: tensor whose second to last axis is the sequence.
        amount: number of positions to shift; positive moves content towards later
            positions, negative towards earlier ones, 0 returns `t` unchanged.
        mask: optional boolean mask over all but the last axis of `t`; positions
            where it is False are zeroed before shifting.

    Returns:
        Tensor with the same shape as `t`.
    """
    if amount == 0:
        return t

    # clamp the magnitude in both directions: the shift is implemented as a pad on
    # one side plus an equally sized crop on the other, and the crop must not
    # exceed the sequence length (previously only positive shifts were clamped)
    seq_len = t.shape[-2]
    amount = max(-seq_len, min(amount, seq_len))

    if exists(mask):
        t = t.masked_fill(~mask[..., None], 0.)

    return pad_at_dim(t, (amount, -amount), dim = - 2, value = 0.)
1255
-
1256
class ShiftTokens(nn.Module):
    """Shifts equal-width feature groups along the sequence by different amounts before calling `fn`."""

    def __init__(self, shifts, fn):
        super().__init__()
        self.fn = fn
        self.shifts = tuple(shifts)

    def forward(self, x, **kwargs):
        mask = kwargs.get('mask', None)
        amounts = self.shifts
        num_groups = len(amounts)

        # one feature group per shift amount; leftover groups are left unshifted
        group_width = x.shape[-1] // num_groups
        chunks = x.split(group_width, dim = -1)
        to_shift, untouched = chunks[:num_groups], chunks[num_groups:]

        shifted = [shift(chunk, amount, mask = mask) for chunk, amount in zip(to_shift, amounts)]

        x = torch.cat((*shifted, *untouched), dim = -1)
        return self.fn(x, **kwargs)
1272
-
1273
- # feedforward
1274
-
1275
class GLU(nn.Module):
    """Gated linear unit: one projection yields a value half and a gate half."""

    def __init__(
        self,
        dim_in,
        dim_out,
        activation: Callable,
        mult_bias = False
    ):
        super().__init__()
        self.act = activation
        # a single linear layer produces value and gate side by side
        self.proj = nn.Linear(dim_in, dim_out * 2)

        if mult_bias:
            self.mult_bias = nn.Parameter(torch.ones(dim_out))
        else:
            self.mult_bias = 1.

    def forward(self, x):
        value, gate = self.proj(x).chunk(2, dim = -1)
        return value * self.act(gate) * self.mult_bias
1291
-
1292
class FeedForward(nn.Module):
    """Feedforward block: input projection with activation (or GLU), optional norm, dropout, output projection."""

    def __init__(
        self,
        dim,
        dim_out = None,
        mult = 4,
        glu = False,
        glu_mult_bias = False,
        swish = False,
        relu_squared = False,
        post_act_ln = False,
        dropout = 0.,
        no_bias = False,
        zero_init_output = False
    ):
        super().__init__()
        hidden_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        use_bias = not no_bias

        # relu_squared wins over swish; GELU is the fallback
        if relu_squared:
            activation = ReluSquared()
        else:
            activation = nn.SiLU() if swish else nn.GELU()

        if not glu:
            project_in = nn.Sequential(
                nn.Linear(dim, hidden_dim, bias = use_bias),
                activation
            )
        else:
            project_in = GLU(dim, hidden_dim, activation, mult_bias = glu_mult_bias)

        self.ff = Sequential(
            project_in,
            nn.LayerNorm(hidden_dim) if post_act_ln else None,
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim_out, bias = use_bias)
        )

        # init last linear layer to 0
        if zero_init_output:
            init_zero_(self.ff[-1])

    def forward(self, x):
        return self.ff(x)
1339
-
1340
- # attention. it is all we need
1341
-
1342
class Attention(nn.Module):
    """Multi-head attention layer with optional grouped key/value heads, learned
    memory key/values, cosine-sim (qk norm) attention, value gating and key/value caching.

    The core attention computation is delegated to `Attend`; this class handles the
    projections, rotary embedding application, mask assembly and output gating.
    """

    def __init__(
        self,
        dim,
        dim_head = DEFAULT_DIM_HEAD,
        heads = 8,
        causal = False,
        flash = False,
        talking_heads = False,
        head_scale = False,
        sparse_topk = None,
        num_mem_kv = 0,
        dropout = 0.,
        on_attn = False,
        gate_value_heads = False,
        gate_values = False,
        zero_init_output = False,
        max_attend_past = None,
        qk_norm = False,
        qk_norm_groups = 1,
        qk_norm_scale = 10,
        qk_norm_dim_scale = False,
        one_kv_head = False,
        kv_heads = None,
        shared_kv = False,
        value_dim_head = None,
        tensor_product = False,  # https://arxiv.org/abs/2208.06061
        add_zero_kv = False,  # same as add_zero_attn in pytorch
        rotary_embed_values = False,
        onnxable = False
    ):
        super().__init__()
        self.scale = dim_head ** -0.5

        self.heads = heads
        self.causal = causal
        self.max_attend_past = max_attend_past

        assert not (exists(kv_heads) and one_kv_head), 'either attn_one_kv_head is set to True (in which case kv_heads is set to 1), or attn_kv_heads is set, but not both'

        value_dim_head = default(value_dim_head, dim_head)
        kv_heads = default(kv_heads, heads)

        kv_heads = 1 if one_kv_head else kv_heads
        assert divisible_by(heads, kv_heads)

        self.kv_heads = kv_heads

        q_dim = dim_head * heads
        k_dim = dim_head * kv_heads
        v_dim = value_dim_head * kv_heads
        out_dim = value_dim_head * heads

        self.to_q = nn.Linear(dim, q_dim, bias = False)
        self.to_k = nn.Linear(dim, k_dim, bias = False)

        # shared key / values, for further memory savings during inference
        assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values'
        self.to_v = nn.Linear(dim, v_dim, bias = False) if not shared_kv else None

        # relations projection from tp-attention
        self.to_r = nn.Linear(dim, v_dim, bias = False) if tensor_product else None

        # add GLU gating for aggregated values, from alphafold2
        self.to_v_gate = None
        if gate_values:
            self.to_v_gate = nn.Linear(dim, out_dim)
            # zero weight + large bias: the sigmoid gate starts out (almost) fully open
            nn.init.constant_(self.to_v_gate.weight, 0)
            nn.init.constant_(self.to_v_gate.bias, 10)

        # add per head gating of the output values, from 'Attend to nothing' paper
        self.to_v_head_gate = None
        if gate_value_heads:
            self.to_v_head_gate = nn.Linear(dim, heads)
            nn.init.constant_(self.to_v_head_gate.weight, 0)
            nn.init.constant_(self.to_v_head_gate.bias, 10)

        # cosine sim attention
        self.qk_norm = qk_norm
        self.qk_norm_groups = qk_norm_groups
        self.qk_norm_scale = qk_norm_scale

        # whether to use the rmsnorm (equivalent to cosine sim attention when scale is equal to 1) - https://arxiv.org/abs/2302.05442
        self.qk_norm_dim_scale = qk_norm_dim_scale

        self.qk_norm_q_scale = self.qk_norm_k_scale = 1
        if qk_norm and qk_norm_dim_scale:
            self.qk_norm_q_scale = nn.Parameter(torch.ones(heads, 1, dim_head))
            self.qk_norm_k_scale = nn.Parameter(torch.ones(heads, 1, dim_head))

        assert (not qk_norm) or divisible_by(dim_head, qk_norm_groups), 'dimension per attention head must be divisible by the qk norm groups'
        assert not (qk_norm and (dim_head // qk_norm_groups) <= 2), 'the group dimension may be too small (2 was too small in my tests, but 4 still works, surprisingly)'

        # attend class - includes core attention algorithm + talking heads

        self.attend = Attend(
            heads = heads,
            causal = causal,
            talking_heads = talking_heads,
            dropout = dropout,
            sparse_topk = sparse_topk,
            qk_norm = qk_norm,
            scale = qk_norm_scale if qk_norm else self.scale,
            add_zero_kv = add_zero_kv,
            flash = flash,
            onnxable = onnxable
        )

        # head scaling
        self.head_scale = head_scale
        if head_scale:
            self.head_scale_params = nn.Parameter(torch.ones(1, heads, 1, 1))

        # explicit topk sparse attention
        self.sparse_topk = sparse_topk

        # add memory key / values
        # they are concatenated onto k / v in forward, which carry `kv_heads` heads and
        # (for v) `value_dim_head` features, so they must be allocated with those sizes;
        # the shapes are unchanged when kv_heads == heads and value_dim_head == dim_head
        self.num_mem_kv = num_mem_kv
        if num_mem_kv > 0:
            self.mem_k = nn.Parameter(torch.randn(kv_heads, num_mem_kv, dim_head))
            self.mem_v = nn.Parameter(torch.randn(kv_heads, num_mem_kv, value_dim_head))

        # attention on attention
        self.attn_on_attn = on_attn
        self.to_out = nn.Sequential(nn.Linear(out_dim, dim * 2, bias = False), nn.GLU()) if on_attn else nn.Linear(out_dim, dim, bias = False)

        # whether to rotate positions into values, for absolute positions in addition to relative
        self.rotary_embed_values = rotary_embed_values

        # init output projection 0
        if zero_init_output:
            init_zero_(self.to_out)

    def forward(
        self,
        x,
        context = None,
        mask = None,
        context_mask = None,
        attn_mask = None,
        rel_pos = None,
        rotary_pos_emb = None,
        prev_attn = None,
        mem = None,
        return_intermediates = False,
        cache: Optional[Intermediates] = None,
    ):
        """Run attention over `x`, or from `x` onto `context` when it is given.

        Args:
            x: input of shape (batch, seq, dim); always the source of the queries.
            context: optional source for keys / values; when given, the cache and
                rotary embeddings are not applied.
            mask: boolean (batch, seq) mask for `x`; used as the key mask when there
                is neither a context nor a `context_mask`, and to zero masked output rows.
            context_mask: boolean (batch, keys) mask; takes precedence over `mask` as key mask.
            attn_mask: boolean mask with 2 to 4 dimensions, broadcast to (batch, heads, i, j).
            rel_pos: optional callable `rel_pos(i, j)` returning an attention bias.
            rotary_pos_emb: optional `(freqs, xpos_scale)` pair.
            prev_attn: forwarded unchanged to `Attend`.
            mem: optional tensor prepended to the key / value input along the sequence.
            return_intermediates: when True, also return the intermediates from
                `Attend`, with `cached_kv` set to this call's keys / values.
            cache: intermediates of a previous call whose `cached_kv` is prepended
                to the keys / values of this call.

        Returns:
            The output tensor of shape (batch, seq, dim), or `(output, intermediates)`
            when `return_intermediates` is True.
        """
        b, n, _, h, kv_h, head_scale, device, has_context = *x.shape, self.heads, self.kv_heads, self.head_scale, x.device, exists(context)
        kv_input = default(context, x)

        q_input = x
        k_input = kv_input
        v_input = kv_input
        r_input = x

        if exists(mem):
            k_input, mem_packed_shape = pack([mem, k_input], 'b * d')
            v_input, _ = pack([mem, v_input], 'b * d')

        q = self.to_q(q_input)
        k = self.to_k(k_input)
        # with shared key / values there is no value projection and the keys are reused
        v = self.to_v(v_input) if exists(self.to_v) else k
        r = self.to_r(r_input) if exists(self.to_r) else None

        q = rearrange(q, 'b n (h d) -> b h n d', h = h)

        k, v, r = map(lambda t: maybe(rearrange)(t, 'b n (h d) -> b h n d', h = kv_h), (k, v, r))

        if exists(cache) and not has_context:
            ck, cv = cache.cached_kv

            # cached keys / values go between the memory tokens and the new tokens,
            # so split the memory part off first and re-attach it afterwards
            if exists(mem):
                mk, k = unpack(k, mem_packed_shape, 'b h * d')
                mv, v = unpack(v, mem_packed_shape, 'b h * d')

            k = torch.cat((ck, k), dim = -2)
            v = torch.cat((cv, v), dim = -2)

            if exists(mem):
                k = torch.cat((mk, k), dim = -2)
                v = torch.cat((mv, v), dim = -2)

        if return_intermediates:
            # the cache excludes the memory tokens and is taken before qk norm / rotary
            mem_len = mem.shape[-2] if exists(mem) else 0
            cached_kv = (k[..., mem_len:, :], v[..., mem_len:, :])

        if self.qk_norm:
            qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
            q, k = map(qk_l2norm, (q, k))
            scale = self.qk_norm_scale

            q = q * self.qk_norm_q_scale
            k = k * self.qk_norm_k_scale

        if exists(rotary_pos_emb) and not has_context:
            freqs, xpos_scale = rotary_pos_emb
            q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)

            q = apply_rotary_pos_emb(q, freqs, q_xpos_scale)
            k = apply_rotary_pos_emb(k, freqs, k_xpos_scale)

            if self.rotary_embed_values:
                v = apply_rotary_pos_emb(v, freqs, k_xpos_scale)

        input_mask = context_mask

        if not exists(input_mask) and not has_context:
            input_mask = mask

        if self.num_mem_kv > 0:
            mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), (self.mem_k, self.mem_v))

            if self.qk_norm:
                mem_k = l2norm(mem_k)
                mem_k = mem_k * self.qk_norm_k_scale

            k = torch.cat((mem_k, k), dim = -2)
            v = torch.cat((mem_v, v), dim = -2)

            # learned memory key / values are always attendable
            if exists(input_mask):
                input_mask = pad_at_dim(input_mask, (self.num_mem_kv, 0), dim = -1, value = True)

        i, j = map(lambda t: t.shape[-2], (q, k))

        # determine masking
        # each entry of `masks` is True where attention must be blocked

        mask_value = max_neg_value(q)
        masks = []
        final_attn_mask = None

        if exists(input_mask):
            input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
            masks.append(~input_mask)

        if exists(attn_mask):
            assert 2 <= attn_mask.ndim <= 4, 'attention mask must have greater than 2 dimensions but less than or equal to 4'
            if attn_mask.ndim == 2:
                attn_mask = rearrange(attn_mask, 'i j -> 1 1 i j')
            elif attn_mask.ndim == 3:
                attn_mask = rearrange(attn_mask, 'h i j -> 1 h i j')
            masks.append(~attn_mask)

        if exists(self.max_attend_past):
            range_q = torch.arange(j - i, j, device = device)
            range_k = torch.arange(j, device = device)
            dist = rearrange(range_q, 'i -> 1 1 i 1') - rearrange(range_k, 'j -> 1 1 1 j')
            max_attend_past_mask = dist > self.max_attend_past
            masks.append(max_attend_past_mask)

        if len(masks) > 0:
            # flip back so that True means "may attend"
            final_attn_mask = ~or_reduce(masks)

        # prepare relative positional bias, if needed

        attn_bias = None
        if exists(rel_pos):
            attn_bias = rel_pos(i, j)

        # attention is all we need

        out, intermediates = self.attend(
            q, k, v,
            mask = final_attn_mask,
            attn_bias = attn_bias,
            prev_attn = prev_attn
        )

        # https://arxiv.org/abs/2208.06061 proposes to add a residual for better gradients

        if exists(r):
            out = out * r + out

        # normformer scaling of heads

        if head_scale:
            out = out * self.head_scale_params

        # per head gating, from https://arxiv.org/abs/2306.12929

        if exists(self.to_v_head_gate):
            head_gate = self.to_v_head_gate(x)
            out = out * rearrange(head_gate, 'b n h -> b h n 1').sigmoid()

        # merge heads

        out = rearrange(out, 'b h n d -> b n (h d)')

        # alphafold2 styled gating of the values

        if exists(self.to_v_gate):
            gates = self.to_v_gate(x)
            out = out * gates.sigmoid()

        # combine the heads

        out = self.to_out(out)

        if exists(mask):
            mask = rearrange(mask, 'b n -> b n 1')
            out = out.masked_fill(~mask, 0.)

        if not return_intermediates:
            return out

        intermediates.cached_kv = cached_kv

        return out, intermediates
1649
-
1650
- class AttentionLayers(nn.Module):
1651
- def __init__(
1652
- self,
1653
- dim,
1654
- depth,
1655
- heads = 8,
1656
- causal = False,
1657
- cross_attend = False,
1658
- only_cross = False,
1659
- use_scalenorm = False,
1660
- use_rmsnorm = False,
1661
- use_simple_rmsnorm = False,
1662
- alibi_pos_bias = False,
1663
- alibi_num_heads = None,
1664
- rel_pos_bias = False,
1665
- rel_pos_num_buckets = 32,
1666
- rel_pos_max_distance = 128,
1667
- dynamic_pos_bias = False,
1668
- dynamic_pos_bias_log_distance = False,
1669
- dynamic_pos_bias_mlp_depth = 2,
1670
- dynamic_pos_bias_norm = False,
1671
- rotary_pos_emb = False,
1672
- rotary_emb_dim = None,
1673
- rotary_xpos = False,
1674
- rotary_interpolation_factor = 1.,
1675
- rotary_xpos_scale_base = 512,
1676
- rotary_base_rescale_factor = 1.,
1677
- custom_layers = None,
1678
- sandwich_coef = None,
1679
- par_ratio = None,
1680
- weight_tie_layers = False, # Albert - https://arxiv.org/abs/1909.11942
1681
- layers_execute_order = None, # generalizes weight tying, can do arbitrary layer execution orders
1682
- residual_attn = False,
1683
- cross_residual_attn = False,
1684
- macaron = False,
1685
- pre_norm = True,
1686
- pre_norm_has_final_norm = True,
1687
- gate_residual = False,
1688
- scale_residual = False,
1689
- scale_residual_constant = 1.,
1690
- shift_tokens = 0,
1691
- sandwich_norm = False,
1692
- resi_dual = False,
1693
- resi_dual_scale = 1.,
1694
- zero_init_branch_output = False,
1695
- layer_dropout = 0.,
1696
- cross_attn_tokens_dropout = 0.,
1697
- **kwargs
1698
- ):
1699
- super().__init__()
1700
- rotary_pos_emb = rotary_pos_emb or rotary_xpos
1701
-
1702
- ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
1703
- attn_kwargs, kwargs = groupby_prefix_and_trim('attn_', kwargs)
1704
-
1705
- dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
1706
-
1707
- self.dim = dim
1708
- self.depth = depth
1709
- self.causal = causal
1710
- self.layers = nn.ModuleList([])
1711
-
1712
- self.has_pos_emb = rel_pos_bias or rotary_pos_emb
1713
-
1714
- rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32)
1715
-
1716
- assert not (rotary_xpos and not causal), 'rotary xpos is not compatible with bidirectional attention'
1717
- self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim, use_xpos = rotary_xpos, scale_base = rotary_xpos_scale_base, interpolation_factor = rotary_interpolation_factor, base_rescale_factor = rotary_base_rescale_factor) if rotary_pos_emb else None
1718
-
1719
- assert not (alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
1720
- assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
1721
-
1722
- # relative positional bias
1723
-
1724
- flash_attn = attn_kwargs.get('flash', False)
1725
- assert (int(rel_pos_bias) + int(dynamic_pos_bias) + int(alibi_pos_bias)) <= 1, 'you can only choose up to one of t5, alibi, or dynamic positional bias'
1726
-
1727
- self.rel_pos = None
1728
- if rel_pos_bias:
1729
- assert not flash_attn, 'flash attention not compatible with t5 relative positional bias'
1730
- self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance)
1731
- elif dynamic_pos_bias:
1732
- assert not flash_attn, 'flash attention not compatible with dynamic positional bias'
1733
- self.rel_pos = DynamicPositionBias(dim = dim // 4, heads = heads, log_distance = dynamic_pos_bias_log_distance, depth = dynamic_pos_bias_mlp_depth, norm = dynamic_pos_bias_norm)
1734
- elif alibi_pos_bias:
1735
- alibi_num_heads = default(alibi_num_heads, heads)
1736
- assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
1737
- self.rel_pos = AlibiPositionalBias(heads = alibi_num_heads, total_heads = heads)
1738
-
1739
- assert (int(sandwich_norm) + int(resi_dual)) <= 1, 'either sandwich norm or resiDual is selected, but not both'
1740
- assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'
1741
-
1742
- if resi_dual:
1743
- pre_norm = False
1744
-
1745
- self.pre_norm = pre_norm
1746
- self.sandwich_norm = sandwich_norm
1747
-
1748
- self.resi_dual = resi_dual
1749
- assert 0 < resi_dual_scale <= 1., 'resiDual prenorm residual must be scaled by a factor greater than 0 and less than or equal to 1.'
1750
- self.resi_dual_scale = resi_dual_scale
1751
-
1752
- self.residual_attn = residual_attn
1753
- self.cross_residual_attn = cross_residual_attn
1754
- assert not (flash_attn and (residual_attn or cross_residual_attn)), 'flash attention is not compatible with residual attention'
1755
-
1756
- self.cross_attend = cross_attend
1757
-
1758
- assert (int(use_scalenorm) + int(use_rmsnorm) + int(use_simple_rmsnorm)) <= 1, 'you can only use either scalenorm, rmsnorm, or simple rmsnorm'
1759
-
1760
- if use_scalenorm:
1761
- norm_class = ScaleNorm
1762
- elif use_rmsnorm:
1763
- norm_class = RMSNorm
1764
- elif use_simple_rmsnorm:
1765
- norm_class = SimpleRMSNorm
1766
- else:
1767
- norm_class = nn.LayerNorm
1768
-
1769
- norm_fn = partial(norm_class, dim)
1770
-
1771
- if cross_attend and not only_cross:
1772
- default_block = ('a', 'c', 'f')
1773
- elif cross_attend and only_cross:
1774
- default_block = ('c', 'f')
1775
- else:
1776
- default_block = ('a', 'f')
1777
-
1778
- if macaron:
1779
- default_block = ('f',) + default_block
1780
-
1781
- # zero init
1782
-
1783
- if zero_init_branch_output:
1784
- attn_kwargs = {**attn_kwargs, 'zero_init_output': True}
1785
- ff_kwargs = {**ff_kwargs, 'zero_init_output': True}
1786
-
1787
- # setup weight tying, which is a special case of `layer_execute_order`
1788
-
1789
- assert not (weight_tie_layers and any([*map(exists, (custom_layers, par_ratio, sandwich_coef))]))
1790
-
1791
- if weight_tie_layers:
1792
- assert not exists(layers_execute_order)
1793
- layers_execute_order = tuple(range(len(default_block))) * depth
1794
- depth = 1
1795
-
1796
- # calculate layer block order
1797
-
1798
- if exists(custom_layers):
1799
- layer_types = custom_layers
1800
- elif exists(par_ratio):
1801
- par_depth = depth * len(default_block)
1802
- assert 1 < par_ratio <= par_depth, 'par ratio out of range'
1803
- default_block = tuple(filter(not_equals('f'), default_block))
1804
- par_attn = par_depth // par_ratio
1805
- depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
1806
- par_width = (depth_cut + depth_cut // par_attn) // par_attn
1807
- assert len(default_block) <= par_width, 'default block is too large for par_ratio'
1808
- par_block = default_block + ('f',) * (par_width - len(default_block))
1809
- par_head = par_block * par_attn
1810
- layer_types = par_head + ('f',) * (par_depth - len(par_head))
1811
- elif exists(sandwich_coef):
1812
- assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
1813
- layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
1814
- else:
1815
- layer_types = default_block * depth
1816
-
1817
- self.layer_types = layer_types
1818
- self.layers_execute_order = default(layers_execute_order, tuple(range(len(layer_types))))
1819
-
1820
- assert all([i < len(self.layer_types) for i in self.layers_execute_order])
1821
-
1822
- self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
1823
-
1824
- # stochastic depth
1825
-
1826
- self.layer_dropouts = cast_tuple(layer_dropout, len(layer_types))
1827
-
1828
- # structured dropout for cross attending
1829
-
1830
- self.cross_attn_tokens_dropout = cross_attn_tokens_dropout
1831
-
1832
- # calculate token shifting
1833
-
1834
- shift_tokens = cast_tuple(shift_tokens, len(layer_types))
1835
-
1836
- # whether it has post norm
1837
-
1838
- self.final_norm = norm_fn() if pre_norm or resi_dual else nn.Identity()
1839
-
1840
- # iterate and construct layers
1841
-
1842
- for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
1843
- is_last_layer = ind == (len(self.layer_types) - 1)
1844
-
1845
- if layer_type == 'a':
1846
- layer = Attention(dim, heads = heads, causal = causal, **attn_kwargs)
1847
- elif layer_type == 'c':
1848
- layer = Attention(dim, heads = heads, **attn_kwargs)
1849
- elif layer_type == 'f':
1850
- layer = FeedForward(dim, **ff_kwargs)
1851
- layer = layer if not macaron else Scale(0.5, layer)
1852
- else:
1853
- raise Exception(f'invalid layer type {layer_type}')
1854
-
1855
- if layer_shift_tokens > 0:
1856
- shift_range_upper = layer_shift_tokens + 1
1857
- shift_range_lower = -layer_shift_tokens if not causal else 0
1858
- layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer)
1859
-
1860
- residual_fn = GRUGating if gate_residual else Residual
1861
- residual = residual_fn(dim, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)
1862
-
1863
- pre_branch_norm = norm_fn() if pre_norm else None
1864
- post_branch_norm = norm_fn() if sandwich_norm else None
1865
- post_main_norm = norm_fn() if not pre_norm else None
1866
-
1867
- norms = nn.ModuleList([
1868
- pre_branch_norm,
1869
- post_branch_norm,
1870
- post_main_norm
1871
- ])
1872
-
1873
- self.layers.append(nn.ModuleList([
1874
- norms,
1875
- layer,
1876
- residual
1877
- ]))
1878
-
1879
def forward(
    self,
    x,
    context = None,
    mask = None,
    context_mask = None,
    attn_mask = None,
    self_attn_kv_mask = None,
    mems = None,
    seq_start_pos: Optional[Tensor] = None,
    cache: Optional[LayerIntermediates] = None,
    cache_age = 1,
    return_hiddens = False
):
    """Run `x` through the self-attention ('a'), cross-attention ('c') and feedforward ('f') blocks.

    Args:
        x: input tensor; the code indexes it as `x[:, -cache_age:]` and reads
            `x.shape[1]` / `x.shape[-2]` as the sequence length
            (assumes (batch, seq, dim) - TODO confirm).
        context: cross-attention source; must be given exactly when
            `self.cross_attend` is set.
        mask: forwarded as `mask` to both self- and cross-attention blocks.
        context_mask: forwarded to cross-attention blocks.
        attn_mask: forwarded to self-attention blocks.
        self_attn_kv_mask: forwarded to self-attention blocks as their
            `context_mask`; combined (logical and) with the left-padding mask
            derived from `seq_start_pos` when that is given.
        mems: optional list with one entry per self-attention layer, consumed
            front to back; a copy is taken so the caller's list is untouched.
        seq_start_pos: per-sequence start index used to mask left padding.
        cache: previous `LayerIntermediates`; its `attn_intermediates` are
            handed, in order, to the attention blocks.
        cache_age: when a cache is given and this is positive, only the last
            `cache_age` positions of `x` are processed.
        return_hiddens: also return the collected intermediates.

    Returns:
        The normed output, or `(output, LayerIntermediates)` when
        `return_hiddens` is true.
    """
    assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'

    # initialize accums

    hiddens = []
    layer_hiddens = []
    intermediates = []

    prev_attn = None
    prev_cross_attn = None

    mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers

    # handle left padded sequences

    if exists(seq_start_pos):
        seq_arange = torch.arange(x.shape[-2], device = x.device, dtype = torch.long)
        left_pad_mask = seq_arange >= seq_start_pos[..., None]

        if exists(self_attn_kv_mask):
            self_attn_kv_mask = self_attn_kv_mask & left_pad_mask
        else:
            self_attn_kv_mask = left_pad_mask

    # rotary positions

    rotary_pos_emb = None

    if exists(self.rotary_pos_emb):
        seq_len = x.shape[1]

        # `default` covers an empty `mems` list (no self attention layers, or the
        # caller passed []), where a bare max() would raise ValueError
        max_rotary_emb_length = max(
            ((m.shape[1] if exists(m) else 0) + seq_len for m in mems),
            default = seq_len
        )
        rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length)

    # assume cached key / values

    attn_cache = []

    if exists(cache):
        assert not self.training and self.causal and not (exists(mask) or exists(attn_mask))

        if cache_age > 0:
            x = x[:, -cache_age:] # for spec decoding, may be greater than 1

        attn_cache = cache.attn_intermediates

    iter_attn_cache = iter(attn_cache)

    # outer residual - for resiDual paper

    outer_residual = x * self.resi_dual_scale

    # get layers to be executed, in the configured execution order

    layer_variables = (
        self.layer_types,
        self.layers,
        self.layer_dropouts
    )

    layer_variables = tuple(tuple(layer_variable[i] for i in self.layers_execute_order) for layer_variable in layer_variables)

    # go through the attention and feedforward layers

    for layer_type, (norm, block, residual_fn), layer_dropout in zip(*layer_variables):

        # stochastic layer dropout, training only

        if self.training and layer_dropout > 0. and random() < layer_dropout:
            continue

        if layer_type == 'a':
            if return_hiddens:
                hiddens.append(x)
            layer_mem = mems.pop(0) if mems else None

        if layer_type == 'c':
            if self.training and self.cross_attn_tokens_dropout > 0.:
                context, context_mask = dropout_seq(context, context_mask, self.cross_attn_tokens_dropout)

        inner_residual = x

        if return_hiddens:
            layer_hiddens.append(x)

        pre_norm, post_branch_norm, post_main_norm = norm

        if exists(pre_norm):
            x = pre_norm(x)

        # `inter` is only bound for attention layers; every later use is guarded by layer_type

        if layer_type == 'a':
            out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, return_intermediates = True)
        elif layer_type == 'c':
            out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), return_intermediates = True)
        elif layer_type == 'f':
            out = block(x)

        if self.resi_dual:
            outer_residual = outer_residual + out * self.resi_dual_scale

        if exists(post_branch_norm):
            out = post_branch_norm(out)

        x = residual_fn(out, inner_residual)

        if layer_type in ('a', 'c') and return_hiddens:
            intermediates.append(inter)

        if layer_type == 'a' and self.residual_attn:
            prev_attn = inter.pre_softmax_attn
        elif layer_type == 'c' and self.cross_residual_attn:
            prev_cross_attn = inter.pre_softmax_attn

        if exists(post_main_norm):
            x = post_main_norm(x)

    if return_hiddens:
        layer_hiddens.append(x)

    if self.resi_dual:
        x = x + self.final_norm(outer_residual)
    else:
        x = self.final_norm(x)

    if not return_hiddens:
        return x

    intermediates = LayerIntermediates(
        hiddens = hiddens,
        attn_intermediates = intermediates,
        layer_hiddens = layer_hiddens
    )

    return x, intermediates
2024
-
2025
class Encoder(AttentionLayers):
    """Attention stack that is always constructed with `causal = False`."""

    def __init__(self, **kwargs):
        # causality is fixed by this subclass, so the caller may not supply it
        caller_set_causal = 'causal' in kwargs
        assert not caller_set_causal, 'cannot set causality on encoder'

        super().__init__(causal = False, **kwargs)
2029
-
2030
class Decoder(AttentionLayers):
    """Attention stack that is always constructed with `causal = True`."""

    def __init__(self, **kwargs):
        # causality is fixed by this subclass, so the caller may not supply it
        caller_set_causal = 'causal' in kwargs
        assert not caller_set_causal, 'cannot set causality on decoder'

        super().__init__(causal = True, **kwargs)
2034
-
2035
class CrossAttender(AttentionLayers):
    """Attention stack constructed with `cross_attend = True` and `only_cross = True`."""

    def __init__(self, **kwargs):
        # both flags are pinned here; passing either one again raises a duplicate keyword TypeError
        pinned = dict(cross_attend = True, only_cross = True)
        super().__init__(**pinned, **kwargs)
2038
-
2039
class ViTransformerWrapper(nn.Module):
    """Patch-based image wrapper around an `Encoder`.

    Images are cut into square patches, linearly embedded, given a learned
    positional embedding, optionally joined by learned register tokens, and
    passed through `attn_layers`. Unless `return_embeddings` is requested the
    token outputs are mean-pooled and sent through `mlp_head`.
    """

    def __init__(
        self,
        *,
        image_size,
        patch_size,
        attn_layers,
        channels = 3,
        num_classes = None,
        post_emb_norm = False,
        num_register_tokens = 0,
        emb_dropout = 0.
    ):
        super().__init__()
        assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder'
        assert divisible_by(image_size, patch_size), 'image dimensions must be divisible by the patch size'

        model_dim = attn_layers.dim
        patches_per_side = image_size // patch_size
        total_patches = patches_per_side ** 2
        flat_patch_dim = channels * patch_size ** 2

        self.patch_size = patch_size

        # one learned position vector per patch

        self.pos_embedding = nn.Parameter(torch.randn(1, total_patches, model_dim))

        # optional learned register tokens, appended after the patch tokens in forward

        self.has_register_tokens = num_register_tokens > 0

        if self.has_register_tokens:
            self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, model_dim))

        # flattened patch -> model dimension

        self.patch_to_embedding = nn.Sequential(
            nn.LayerNorm(flat_patch_dim),
            nn.Linear(flat_patch_dim, model_dim),
            nn.LayerNorm(model_dim)
        )

        if post_emb_norm:
            self.post_emb_norm = nn.LayerNorm(model_dim)
        else:
            self.post_emb_norm = nn.Identity()

        self.dropout = nn.Dropout(emb_dropout)

        self.attn_layers = attn_layers

        # classification head, identity when no class count is given

        if exists(num_classes):
            self.mlp_head = nn.Linear(model_dim, num_classes)
        else:
            self.mlp_head = nn.Identity()

    def forward(
        self,
        img,
        return_embeddings = False
    ):
        """Embed `img` and return per-token embeddings or the pooled head output.

        `img` is rearranged with the pattern 'b c (h p1) (w p2)', i.e. it is
        read as (batch, channels, height, width).
        """
        batch = img.shape[0]
        side = self.patch_size

        # image -> sequence of flattened patches -> embeddings

        tokens = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = side, p2 = side)
        tokens = self.patch_to_embedding(tokens)

        num_tokens = tokens.shape[1]
        tokens = tokens + self.pos_embedding[:, :num_tokens]

        tokens = self.dropout(self.post_emb_norm(tokens))

        # register tokens ride along through attention and are stripped afterwards

        if self.has_register_tokens:
            registers = repeat(self.register_tokens, 'n d -> b n d', b = batch)
            tokens, packed_shape = pack((tokens, registers), 'b * d')

        tokens = self.attn_layers(tokens)

        if self.has_register_tokens:
            tokens, _ = unpack(tokens, packed_shape, 'b * d')

        # pooled head output is the default; raw token embeddings only on request

        if exists(self.mlp_head) and not return_embeddings:
            pooled = tokens.mean(dim = -2)
            return self.mlp_head(pooled)

        return tokens
2112
-
2113
class TransformerWrapper(nn.Module):
    """Token-level wrapper around an `AttentionLayers` stack.

    Handles token embedding, optional absolute positional embedding, optional
    memory tokens, projection in and out of the attention stack, and the
    different return modes (logits, embeddings, intermediates, mems, attention
    maps).
    """

    def __init__(
        self,
        *,
        num_tokens,
        max_seq_len,
        attn_layers,
        emb_dim = None,
        max_mem_len = 0,
        shift_mem_down = 0,
        emb_dropout = 0.,
        post_emb_norm = False,
        num_memory_tokens = None,
        memory_tokens_interspersed_every = None,
        tie_embedding = False,
        logits_dim = None,
        use_abs_pos_emb = True,
        scaled_sinu_pos_emb = False,
        l2norm_embed = False,
        emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1
        attn_z_loss_weight = 1e-4, # accepted but not stored here; forward() takes its own `attn_z_loss_weight`
    ):
        super().__init__()
        assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'

        dim = attn_layers.dim
        emb_dim = default(emb_dim, dim)
        self.emb_dim = emb_dim
        self.num_tokens = num_tokens

        self.max_seq_len = max_seq_len
        self.max_mem_len = max_mem_len
        self.shift_mem_down = shift_mem_down

        self.l2norm_embed = l2norm_embed
        self.token_emb = TokenEmbedding(emb_dim, num_tokens, l2norm_embed = l2norm_embed)

        # absolute positions are only added when requested and the attention layers have no positional scheme of their own

        if not (use_abs_pos_emb and not attn_layers.has_pos_emb):
            self.pos_emb = always(0)
        elif scaled_sinu_pos_emb:
            self.pos_emb = ScaledSinusoidalEmbedding(emb_dim)
        else:
            self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len, l2norm_embed = l2norm_embed)

        self.emb_frac_gradient = emb_frac_gradient # fraction of the gradient that should go to the embedding, https://arxiv.org/abs/2105.13290

        self.post_emb_norm = nn.LayerNorm(emb_dim) if post_emb_norm else nn.Identity()
        self.emb_dropout = nn.Dropout(emb_dropout)

        self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
        self.attn_layers = attn_layers

        self.init_()

        # with tied embeddings the logits reuse the token embedding matrix instead of a separate linear layer

        logits_dim = default(logits_dim, num_tokens)
        self.to_logits = nn.Linear(dim, logits_dim) if not tie_embedding else lambda t: t @ self.token_emb.emb.weight.t()

        # memory tokens (like [cls]) from Memory Transformers paper

        num_memory_tokens = default(num_memory_tokens, 0)
        self.num_memory_tokens = num_memory_tokens
        if num_memory_tokens > 0:
            self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))

        self.memory_tokens_interspersed_every = memory_tokens_interspersed_every

        # whether can do cached kv decoding

        self.can_cache_kv = self.num_memory_tokens == 0

    def init_(self):
        """Initialise the embedding weights.

        With `l2norm_embed` the token table (and the positional table, when the
        positional module has one) gets a small-std normal init; otherwise the
        token table gets a kaiming normal init.
        """
        if self.l2norm_embed:
            nn.init.normal_(self.token_emb.emb.weight, std = 1e-5)

            # only positional modules that actually own an `emb` table are initialised;
            # without the hasattr guard a module lacking `emb` raises AttributeError
            # (presumably the scaled sinusoidal variant - verify against its definition)
            if not isinstance(self.pos_emb, always) and hasattr(self.pos_emb, 'emb'):
                nn.init.normal_(self.pos_emb.emb.weight, std = 1e-5)
            return

        nn.init.kaiming_normal_(self.token_emb.emb.weight)

    def forward(
        self,
        x,
        return_embeddings = False,
        return_logits_and_embeddings = False,
        return_intermediates = False,
        mask = None,
        return_mems = False,
        return_attn = False,
        mems = None,
        pos = None,
        prepend_embeds = None,
        sum_embeds = None,
        return_attn_z_loss = False,
        attn_z_loss_weight = 1e-4,
        seq_start_pos = None,
        cache: Optional[LayerIntermediates] = None,
        **kwargs
    ):
        """Embed token ids `x`, run the attention stack and build the requested output.

        Args:
            x: 2-D tensor of token ids; its shape is unpacked as (batch, seq).
            return_embeddings: return the attention output instead of logits.
            return_logits_and_embeddings: return `(logits, embeddings)`.
            return_intermediates: also return the `LayerIntermediates`.
            mask: forwarded to the attention layers (left-padded with `True`
                for memory tokens when they are prepended once).
            return_mems: also return detached hiddens as new memories.
            return_attn: also return the post-softmax attention maps.
            mems: previous memories forwarded to the attention layers.
            pos: positions; a non-long tensor is used directly as the positional
                embedding, otherwise it is passed to `self.pos_emb`.
            prepend_embeds: embeddings concatenated in front of the sequence.
            sum_embeds: embeddings added to the token embeddings.
            return_attn_z_loss: compute `attn_z_loss` on the intermediates
                (this also forces intermediates to be returned).
            attn_z_loss_weight: weight passed to `calc_z_loss`.
            seq_start_pos: forwarded to the positional embedding and attention layers.
            cache: forwarded to the attention layers.
            **kwargs: forwarded to the attention layers.

        Returns:
            `out`, or `(out, extra)` where `extra` is the new mems, the
            intermediates, or the attention maps, checked in that order.
        """
        b, n = x.shape
        num_mems = self.num_memory_tokens
        has_memory_tokens = num_mems > 0
        emb_frac_gradient = self.emb_frac_gradient

        # absolute positional embedding

        external_pos_emb = exists(pos) and pos.dtype != torch.long
        pos_emb = self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos) if not external_pos_emb else pos
        x = self.token_emb(x) + pos_emb

        # for summing embeddings passed externally - needs this for self-conditioning in non-autoregressive training

        if exists(sum_embeds):
            x = x + sum_embeds

        # post embedding norm, purportedly leads to greater stabilization

        x = self.post_emb_norm(x)

        # whether to prepend embeds, as in PaLI, for image embeddings

        if exists(prepend_embeds):
            _, prepend_dim = prepend_embeds.shape[1:]
            assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as text model dimensions'

            x = torch.cat((prepend_embeds, x), dim = -2)

        # whether to reduce the gradient going to the embedding, from cogview paper, corroborated by GLM-130B model

        if emb_frac_gradient < 1:
            assert emb_frac_gradient > 0
            x = x * emb_frac_gradient + x.detach() * (1 - emb_frac_gradient)

        # embedding dropout

        x = self.emb_dropout(x)

        x = self.project_emb(x)

        if has_memory_tokens:
            mem_every = self.memory_tokens_interspersed_every

            if exists(mem_every):
                assert mem_every > 0
                assert isinstance(self.attn_layers, Decoder), 'only for decoder'
                next_seq_len = math.ceil(n / mem_every) * mem_every

                # pad to a multiple of mem_every, then fold each chunk into the batch
                x = pad_at_dim(x, (0, next_seq_len - n), dim = -2, value = 0.)
                x = rearrange(x, 'b (n m) d -> (b n) m d', m = mem_every)

            mem = repeat(self.memory_tokens, 'n d -> b n d', b = x.shape[0])
            x, mem_packed_shape = pack((mem, x), 'b * d')

            # auto-handle masking after appending memory tokens
            if not exists(mem_every) and exists(mask):
                mask = pad_at_dim(mask, (num_mems, 0), dim = -1, value = True)

            if exists(mem_every):
                x = rearrange(x, '(b n) m d -> b (n m) d', b = b)

        # rotate the memory list by shift_mem_down entries

        if self.shift_mem_down and exists(mems):
            mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:]
            mems = [*mems_r, *mems_l]

        x, intermediates = self.attn_layers(x, mask = mask, mems = mems, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)

        # strip the memory tokens back out

        if has_memory_tokens:
            if exists(mem_every):
                x = rearrange(x, 'b (n m) d -> (b n) m d', m = (mem_every + num_mems))

            mem, x = unpack(x, mem_packed_shape, 'b * d')

            if exists(mem_every):
                x = rearrange(x, '(b n) m d -> b (n m) d', b = b)

            x = x[:, :n]

        if return_logits_and_embeddings:
            out = (self.to_logits(x), x)
        elif return_embeddings:
            out = x
        else:
            out = self.to_logits(x)

        if return_attn_z_loss:
            pre_softmax_attns = [t.pre_softmax_attn for t in intermediates.attn_intermediates]
            intermediates.attn_z_loss = calc_z_loss(pre_softmax_attns, weight = attn_z_loss_weight)
            return_intermediates = True

        if return_mems:
            hiddens = intermediates.hiddens

            if exists(mems):
                new_mems = [torch.cat(pair, dim = -2) for pair in zip(mems, hiddens)]
            else:
                new_mems = hiddens

            # NOTE(review): with max_mem_len == 0 the slice is `-0:`, which keeps the whole sequence rather than nothing - confirm this is intended
            new_mems = [t[..., -self.max_mem_len:, :].detach() for t in new_mems]

            if not return_intermediates:
                return out, new_mems

            intermediates.mems = new_mems

        if return_intermediates:
            return out, intermediates

        if return_attn:
            attn_maps = [t.post_softmax_attn for t in intermediates.attn_intermediates]
            return out, attn_maps

        return out
2317
-
2318
class ContinuousTransformerWrapper(nn.Module):
    """Wrapper around an `AttentionLayers` stack for continuous (already embedded) inputs.

    Inputs are optionally projected to the model dimension, given an optional
    absolute positional embedding, run through `attn_layers`, and optionally
    projected to `dim_out`.
    """

    def __init__(
        self,
        *,
        max_seq_len,
        attn_layers,
        dim_in = None,
        dim_out = None,
        emb_dim = None,
        max_mem_len = 0,
        post_emb_norm = False,
        emb_dropout = 0.,
        use_abs_pos_emb = True,
        scaled_sinu_pos_emb = False
    ):
        super().__init__()
        assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'

        model_dim = attn_layers.dim

        self.max_seq_len = max_seq_len
        self.max_mem_len = max_mem_len

        # absolute positions only when requested and not already provided by the attention layers

        wants_abs_pos = use_abs_pos_emb and not attn_layers.has_pos_emb

        if wants_abs_pos and scaled_sinu_pos_emb:
            self.pos_emb = ScaledSinusoidalEmbedding(model_dim)
        elif wants_abs_pos:
            self.pos_emb = AbsolutePositionalEmbedding(model_dim, max_seq_len)
        else:
            self.pos_emb = always(0)

        if post_emb_norm:
            self.post_emb_norm = nn.LayerNorm(model_dim)
        else:
            self.post_emb_norm = nn.Identity()

        self.emb_dropout = nn.Dropout(emb_dropout)

        # input projection, identity when no input dimension is given

        if exists(dim_in):
            self.project_in = nn.Linear(dim_in, model_dim)
        else:
            self.project_in = nn.Identity()

        self.attn_layers = attn_layers

        # output projection, identity when no output dimension is given

        if exists(dim_out):
            self.project_out = nn.Linear(model_dim, dim_out)
        else:
            self.project_out = nn.Identity()

    def forward(
        self,
        x,
        return_embeddings = False,
        return_intermediates = False,
        return_mems = False,
        mask = None,
        return_attn = False,
        mems = None,
        pos = None,
        prepend_embeds = None,
        **kwargs
    ):
        """Project, position, attend and project out.

        Returns `out`, or `(out, extra)` where `extra` is the intermediates,
        the new mems, or the attention maps, checked in that order.
        """
        hidden = self.project_in(x)
        hidden = hidden + self.pos_emb(hidden, pos = pos)

        hidden = self.post_emb_norm(hidden)

        # whether to prepend embeds, as in PaLI, for image embeddings

        if exists(prepend_embeds):
            _, prepend_dim = prepend_embeds.shape[1:]
            assert prepend_dim == hidden.shape[-1], 'prepended embeddings need to have same dimensions as model dimensions'

            hidden = torch.cat((prepend_embeds, hidden), dim = -2)

        hidden = self.emb_dropout(hidden)

        hidden, intermediates = self.attn_layers(hidden, mask = mask, mems = mems, return_hiddens = True, **kwargs)

        if return_embeddings:
            out = hidden
        else:
            out = self.project_out(hidden)

        if return_intermediates:
            return out, intermediates

        if return_mems:
            # keep the trailing max_mem_len positions of each hidden, detached from the graph
            new_mems = [h[..., -self.max_mem_len:, :].detach() for h in intermediates.hiddens]
            return out, new_mems

        if return_attn:
            attn_maps = [inter.post_softmax_attn for inter in intermediates.attn_intermediates]
            return out, attn_maps

        return out
2403
-
2404
class XTransformer(nn.Module):
    """Encoder / decoder pair built from two `TransformerWrapper`s.

    Keyword arguments prefixed with `enc_` go to the encoder side and those
    prefixed with `dec_` to the decoder side. The decoder is wrapped in an
    `AutoregressiveWrapper` and cross-attends to the encoder output.
    """

    def __init__(
        self,
        *,
        dim,
        tie_token_emb = False,
        ignore_index = -100,
        pad_value = 0,
        cross_attn_tokens_dropout = 0.,
        **kwargs
    ):
        super().__init__()
        enc_kwargs, kwargs = groupby_prefix_and_trim('enc_', kwargs)
        dec_kwargs, kwargs = groupby_prefix_and_trim('dec_', kwargs)

        assert not ('dim' in enc_kwargs or 'dim' in dec_kwargs), 'dimension of either encoder or decoder must be set with `dim` keyword'

        # split off the options that belong to the wrappers; what remains goes to the attention layers

        enc_wrapper_options = (
            ('emb_dropout', 0),
            ('num_memory_tokens', None),
            ('scaled_sinu_pos_emb', False),
            ('use_abs_pos_emb', True)
        )

        enc_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], enc_kwargs)
        for option, fallback in enc_wrapper_options:
            enc_transformer_kwargs[option] = enc_kwargs.pop(option, fallback)

        dec_wrapper_options = (
            ('emb_dropout', 0),
            ('scaled_sinu_pos_emb', False),
            ('use_abs_pos_emb', True)
        )

        dec_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], dec_kwargs)
        for option, fallback in dec_wrapper_options:
            dec_transformer_kwargs[option] = dec_kwargs.pop(option, fallback)

        # how many tokens from the encoder to dropout when cross attending from decoder - seen in a couple papers, including Perceiver AR - this will also be very effective regularization when cross attending to very long memories
        self.cross_attn_tokens_dropout = cross_attn_tokens_dropout

        encoder_layers = Encoder(dim = dim, **enc_kwargs)
        self.encoder = TransformerWrapper(
            **enc_transformer_kwargs,
            attn_layers = encoder_layers
        )

        decoder_layers = Decoder(dim = dim, cross_attend = True, **dec_kwargs)
        self.decoder = TransformerWrapper(
            **dec_transformer_kwargs,
            attn_layers = decoder_layers
        )

        # optionally share one token embedding between both sides

        if tie_token_emb:
            self.decoder.token_emb = self.encoder.token_emb

        self.decoder = AutoregressiveWrapper(self.decoder, ignore_index=ignore_index, pad_value=pad_value)

    @torch.no_grad()
    def generate(self, seq_in, seq_out_start, seq_len, mask = None, attn_mask = None, **kwargs):
        """Encode `seq_in`, then let the decoder generate `seq_len` steps from `seq_out_start`."""
        encodings = self.encoder(seq_in, mask = mask, attn_mask = attn_mask, return_embeddings = True)

        return self.decoder.generate(
            seq_out_start,
            seq_len,
            context = encodings,
            context_mask = mask,
            **kwargs
        )

    def forward(self, src, tgt, mask = None, attn_mask = None, src_prepend_embeds = None):
        """Encode `src` and return what the wrapped decoder returns for `tgt`."""

        # the source mask has to cover the prepended embeddings as well

        if exists(src_prepend_embeds) and exists(mask):
            num_prepended = src_prepend_embeds.shape[-2]
            mask = pad_at_dim(mask, (num_prepended, 0), dim = -1, value = True)

        enc = self.encoder(
            src,
            mask = mask,
            attn_mask = attn_mask,
            prepend_embeds = src_prepend_embeds,
            return_embeddings = True
        )

        # training-time dropout of encoder tokens before cross attention

        if self.training and self.cross_attn_tokens_dropout > 0:
            enc, mask = dropout_seq(enc, mask, self.cross_attn_tokens_dropout)

        return self.decoder(tgt, context = enc, context_mask = mask)