qianyuchen commited on
Commit
dcb92ed
1 Parent(s): 52c7ed6

Update resampler.py

Browse files

update for lora finetuning

Files changed (1) hide show
  1. resampler.py +664 -7
resampler.py CHANGED
@@ -19,6 +19,21 @@ from torch.nn.init import trunc_normal_
19
  from torchvision import transforms
20
  from torchvision.transforms import InterpolationMode
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  def get_abs_pos(abs_pos, tgt_size):
23
  # abs_pos: L, C
24
  # tgt_size: (H, W)
@@ -117,24 +132,20 @@ class Resampler(nn.Module):
117
  self.pos_embed = nn.Parameter(
118
  torch.from_numpy(get_2d_sincos_pos_embed(embed_dim, grid_size)).float()
119
  ).requires_grad_(False)
120
-
121
  self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
122
- trunc_normal_(self.query, std=.02)
123
 
124
  if kv_dim is not None and kv_dim != embed_dim:
125
  self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
126
  else:
127
  self.kv_proj = nn.Identity()
128
 
129
- self.attn = nn.MultiheadAttention(embed_dim, num_heads)
130
  self.ln_q = norm_layer(embed_dim)
131
  self.ln_kv = norm_layer(embed_dim)
132
 
133
  self.ln_post = norm_layer(embed_dim)
134
  self.proj = nn.Parameter((embed_dim ** -0.5) * torch.randn(embed_dim, embed_dim))
135
 
136
- self.apply(self._init_weights)
137
-
138
  def _init_weights(self, m):
139
  if isinstance(m, nn.Linear):
140
  trunc_normal_(m.weight, std=.02)
@@ -149,22 +160,668 @@ class Resampler(nn.Module):
149
  pos_embed = torch.Tensor(get_2d_sincos_pos_embed(self.embed_dim, tgt_size)).float().to(device=x.device, dtype=x.dtype)
150
  else:
151
  pos_embed = get_abs_pos(self.pos_embed, tgt_size)
152
-
153
  x = self.kv_proj(x)
154
  x = self.ln_kv(x).permute(1, 0, 2)
155
 
156
  N = x.shape[1]
157
  q = self.ln_q(self.query)
 
158
  out = self.attn(
159
  self._repeat(q, N) + self.pos_embed.unsqueeze(1),
160
  x + pos_embed.unsqueeze(1),
161
  x,
162
  attn_mask=attn_mask)[0]
163
  x = out.permute(1, 0, 2)
164
-
165
  x = self.ln_post(x)
166
  x = x @ self.proj
167
  return x
168
 
169
  def _repeat(self, query, N: int):
170
  return query.unsqueeze(1).repeat(1, N, 1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  from torchvision import transforms
20
  from torchvision.transforms import InterpolationMode
21
 
22
+ from functools import partial
23
+ import numpy as np
24
+ import warnings
25
+ from typing import Optional, Tuple
26
+ import torch
27
+ from torch import nn
28
+ from torch import Tensor
29
+ import deepspeed
30
+ import torch.nn.functional as F
31
+ from torch.nn.functional import *
32
+ from torch.nn.modules.activation import *
33
+ from torch.nn.init import trunc_normal_
34
+ from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
35
+ from transformers import PreTrainedModel
36
+ from transformers.integrations import is_deepspeed_zero3_enabled
37
  def get_abs_pos(abs_pos, tgt_size):
38
  # abs_pos: L, C
39
  # tgt_size: (H, W)
 
132
  self.pos_embed = nn.Parameter(
133
  torch.from_numpy(get_2d_sincos_pos_embed(embed_dim, grid_size)).float()
134
  ).requires_grad_(False)
 
135
  self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
 
136
 
137
  if kv_dim is not None and kv_dim != embed_dim:
138
  self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
139
  else:
140
  self.kv_proj = nn.Identity()
141
 
142
+ self.attn = MultiheadAttention(embed_dim, num_heads)
143
  self.ln_q = norm_layer(embed_dim)
144
  self.ln_kv = norm_layer(embed_dim)
145
 
146
  self.ln_post = norm_layer(embed_dim)
147
  self.proj = nn.Parameter((embed_dim ** -0.5) * torch.randn(embed_dim, embed_dim))
148
 
 
 
149
  def _init_weights(self, m):
150
  if isinstance(m, nn.Linear):
151
  trunc_normal_(m.weight, std=.02)
 
160
  pos_embed = torch.Tensor(get_2d_sincos_pos_embed(self.embed_dim, tgt_size)).float().to(device=x.device, dtype=x.dtype)
161
  else:
162
  pos_embed = get_abs_pos(self.pos_embed, tgt_size)
163
+
164
  x = self.kv_proj(x)
165
  x = self.ln_kv(x).permute(1, 0, 2)
166
 
167
  N = x.shape[1]
168
  q = self.ln_q(self.query)
169
+
170
  out = self.attn(
171
  self._repeat(q, N) + self.pos_embed.unsqueeze(1),
172
  x + pos_embed.unsqueeze(1),
173
  x,
174
  attn_mask=attn_mask)[0]
175
  x = out.permute(1, 0, 2)
 
176
  x = self.ln_post(x)
177
  x = x @ self.proj
178
  return x
179
 
180
  def _repeat(self, query, N: int):
181
  return query.unsqueeze(1).repeat(1, N, 1)
182
+
183
+
184
+
185
+ class MultiheadAttention(nn.MultiheadAttention):
186
+ def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False,
187
+ add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None):
188
+ super().__init__(embed_dim, num_heads, dropout, bias, add_bias_kv, add_zero_attn, kdim, vdim, batch_first, device, dtype)
189
+
190
+ # rewrite out_proj layer,with nn.Linear
191
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias,)
192
+ print(device)
193
+
194
+ def forward(
195
+ self,
196
+ query: Tensor,
197
+ key: Tensor,
198
+ value: Tensor,
199
+ key_padding_mask: Optional[Tensor] = None,
200
+ need_weights: bool = True,
201
+ attn_mask: Optional[Tensor] = None,
202
+ average_attn_weights: bool = True,
203
+ is_causal : bool = False) -> Tuple[Tensor, Optional[Tensor]]:
204
+ why_not_fast_path = ''
205
+ if ((attn_mask is not None and torch.is_floating_point(attn_mask))
206
+ or (key_padding_mask is not None) and torch.is_floating_point(key_padding_mask)):
207
+ why_not_fast_path = "floating-point masks are not supported for fast path."
208
+
209
+ is_batched = query.dim() == 3
210
+
211
+ key_padding_mask = F._canonical_mask(
212
+ mask=key_padding_mask,
213
+ mask_name="key_padding_mask",
214
+ other_type=F._none_or_dtype(attn_mask),
215
+ other_name="attn_mask",
216
+ target_type=query.dtype
217
+ )
218
+ # _canonical_mask
219
+ attn_mask = F._canonical_mask(
220
+ mask=attn_mask,
221
+ mask_name="attn_mask",
222
+ other_type=None,
223
+ other_name="",
224
+ target_type=query.dtype,
225
+ check_other=False,
226
+ )
227
+
228
+
229
+ if not is_batched:
230
+ why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
231
+ elif query is not key or key is not value:
232
+ # When lifting this restriction, don't forget to either
233
+ # enforce that the dtypes all match or test cases where
234
+ # they don't!
235
+ why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
236
+ elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
237
+ why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
238
+ elif self.in_proj_weight is None:
239
+ why_not_fast_path = "in_proj_weight was None"
240
+ elif query.dtype != self.in_proj_weight.dtype:
241
+ # this case will fail anyway, but at least they'll get a useful error message.
242
+ why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
243
+ elif self.training:
244
+ why_not_fast_path = "training is enabled"
245
+ elif (self.num_heads % 2) != 0:
246
+ why_not_fast_path = "self.num_heads is not even"
247
+ elif not self.batch_first:
248
+ why_not_fast_path = "batch_first was not True"
249
+ elif self.bias_k is not None:
250
+ why_not_fast_path = "self.bias_k was not None"
251
+ elif self.bias_v is not None:
252
+ why_not_fast_path = "self.bias_v was not None"
253
+ elif self.add_zero_attn:
254
+ why_not_fast_path = "add_zero_attn was enabled"
255
+ elif not self._qkv_same_embed_dim:
256
+ why_not_fast_path = "_qkv_same_embed_dim was not True"
257
+ elif query.is_nested and (key_padding_mask is not None or attn_mask is not None):
258
+ why_not_fast_path = "supplying both src_key_padding_mask and src_mask at the same time \
259
+ is not supported with NestedTensor input"
260
+ elif torch.is_autocast_enabled():
261
+ why_not_fast_path = "autocast is enabled"
262
+
263
+ if not why_not_fast_path:
264
+ tensor_args = (
265
+ query,
266
+ key,
267
+ value,
268
+ self.in_proj_weight,
269
+ self.in_proj_bias,
270
+ self.out_proj.weight,
271
+ self.out_proj.bias,
272
+ )
273
+ # We have to use list comprehensions below because TorchScript does not support
274
+ # generator expressions.
275
+ if torch.overrides.has_torch_function(tensor_args):
276
+ why_not_fast_path = "some Tensor argument has_torch_function"
277
+ elif _is_make_fx_tracing():
278
+ why_not_fast_path = "we are running make_fx tracing"
279
+ elif not all(_check_arg_device(x) for x in tensor_args):
280
+ why_not_fast_path = ("some Tensor argument's device is neither one of "
281
+ f"cpu, cuda or {torch.utils.backend_registration._privateuse1_backend_name}")
282
+ elif torch.is_grad_enabled() and any(_arg_requires_grad(x) for x in tensor_args):
283
+ why_not_fast_path = ("grad is enabled and at least one of query or the "
284
+ "input/output projection weights or biases requires_grad")
285
+ if not why_not_fast_path:
286
+ merged_mask, mask_type = self.merge_masks(attn_mask, key_padding_mask, query)
287
+
288
+ if self.in_proj_bias is not None and self.in_proj_weight is not None:
289
+ return torch._native_multi_head_attention(
290
+ query,
291
+ key,
292
+ value,
293
+ self.embed_dim,
294
+ self.num_heads,
295
+ self.in_proj_weight,
296
+ self.in_proj_bias,
297
+ self.out_proj.weight,
298
+ self.out_proj.bias,
299
+ merged_mask,
300
+ need_weights,
301
+ average_attn_weights,
302
+ mask_type)
303
+
304
+ any_nested = query.is_nested or key.is_nested or value.is_nested
305
+ assert not any_nested, ("MultiheadAttention does not support NestedTensor outside of its fast path. " +
306
+ f"The fast path was not hit because {why_not_fast_path}")
307
+
308
+ if self.batch_first and is_batched:
309
+ # make sure that the transpose op does not affect the "is" property
310
+ if key is value:
311
+ if query is key:
312
+ query = key = value = query.transpose(1, 0)
313
+ else:
314
+ query, key = (x.transpose(1, 0) for x in (query, key))
315
+ value = key
316
+ else:
317
+ query, key, value = (x.transpose(1, 0) for x in (query, key, value))
318
+
319
+ if not self._qkv_same_embed_dim:
320
+ attn_output, attn_output_weights = self.multi_head_attention_forward(
321
+ query, key, value, self.embed_dim, self.num_heads,
322
+ self.in_proj_weight, self.in_proj_bias,
323
+ self.bias_k, self.bias_v, self.add_zero_attn,
324
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
325
+ training=self.training,
326
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
327
+ attn_mask=attn_mask,
328
+ use_separate_proj_weight=True,
329
+ q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
330
+ v_proj_weight=self.v_proj_weight,
331
+ average_attn_weights=average_attn_weights,
332
+ is_causal=is_causal)
333
+ else:
334
+ attn_output, attn_output_weights = self.multi_head_attention_forward(
335
+ query, key, value, self.embed_dim, self.num_heads,
336
+ self.in_proj_weight, self.in_proj_bias,
337
+ self.bias_k, self.bias_v, self.add_zero_attn,
338
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
339
+ training=self.training,
340
+ key_padding_mask=key_padding_mask,
341
+ need_weights=need_weights,
342
+ attn_mask=attn_mask,
343
+ average_attn_weights=average_attn_weights,
344
+ is_causal=is_causal)
345
+ if self.batch_first and is_batched:
346
+ return attn_output.transpose(1, 0), attn_output_weights
347
+ else:
348
+ return attn_output, attn_output_weights
349
+
350
+ def multi_head_attention_forward(
351
+ self,
352
+ query: Tensor,
353
+ key: Tensor,
354
+ value: Tensor,
355
+ embed_dim_to_check: int,
356
+ num_heads: int,
357
+ in_proj_weight: Optional[Tensor],
358
+ in_proj_bias: Optional[Tensor],
359
+ bias_k: Optional[Tensor],
360
+ bias_v: Optional[Tensor],
361
+ add_zero_attn: bool,
362
+ dropout_p: float,
363
+ out_proj_weight: Tensor,
364
+ out_proj_bias: Optional[Tensor],
365
+ training: bool = True,
366
+ key_padding_mask: Optional[Tensor] = None,
367
+ need_weights: bool = True,
368
+ attn_mask: Optional[Tensor] = None,
369
+ use_separate_proj_weight: bool = False,
370
+ q_proj_weight: Optional[Tensor] = None,
371
+ k_proj_weight: Optional[Tensor] = None,
372
+ v_proj_weight: Optional[Tensor] = None,
373
+ static_k: Optional[Tensor] = None,
374
+ static_v: Optional[Tensor] = None,
375
+ average_attn_weights: bool = True,
376
+ is_causal: bool = False,
377
+ ) -> Tuple[Tensor, Optional[Tensor]]:
378
+ tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
379
+ if has_torch_function(tens_ops):
380
+ return handle_torch_function(
381
+ multi_head_attention_forward,
382
+ tens_ops,
383
+ query,
384
+ key,
385
+ value,
386
+ embed_dim_to_check,
387
+ num_heads,
388
+ in_proj_weight,
389
+ in_proj_bias,
390
+ bias_k,
391
+ bias_v,
392
+ add_zero_attn,
393
+ dropout_p,
394
+ out_proj_weight,
395
+ out_proj_bias,
396
+ training=training,
397
+ key_padding_mask=key_padding_mask,
398
+ need_weights=need_weights,
399
+ attn_mask=attn_mask,
400
+ is_causal=is_causal,
401
+ use_separate_proj_weight=use_separate_proj_weight,
402
+ q_proj_weight=q_proj_weight,
403
+ k_proj_weight=k_proj_weight,
404
+ v_proj_weight=v_proj_weight,
405
+ static_k=static_k,
406
+ static_v=static_v,
407
+ average_attn_weights=average_attn_weights,
408
+ )
409
+
410
+ is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
411
+
412
+ # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
413
+ # is batched, run the computation and before returning squeeze the
414
+ # batch dimension so that the output doesn't carry this temporary batch dimension.
415
+ if not is_batched:
416
+ # unsqueeze if the input is unbatched
417
+ query = query.unsqueeze(1)
418
+ key = key.unsqueeze(1)
419
+ value = value.unsqueeze(1)
420
+ if key_padding_mask is not None:
421
+ key_padding_mask = key_padding_mask.unsqueeze(0)
422
+
423
+ # set up shape vars
424
+ tgt_len, bsz, embed_dim = query.shape
425
+ src_len, _, _ = key.shape
426
+
427
+ key_padding_mask = _canonical_mask(
428
+ mask=key_padding_mask,
429
+ mask_name="key_padding_mask",
430
+ other_type=_none_or_dtype(attn_mask),
431
+ other_name="attn_mask",
432
+ target_type=query.dtype
433
+ )
434
+
435
+ if is_causal and attn_mask is None:
436
+ raise RuntimeError(
437
+ "Need attn_mask if specifying the is_causal hint. "
438
+ "You may use the Transformer module method "
439
+ "`generate_square_subsequent_mask` to create this mask."
440
+ )
441
+
442
+ if is_causal and key_padding_mask is None and not need_weights:
443
+ # when we have a kpm or need weights, we need attn_mask
444
+ # Otherwise, we use the is_causal hint go as is_causal
445
+ # indicator to SDPA.
446
+ attn_mask = None
447
+ else:
448
+ attn_mask = _canonical_mask(
449
+ mask=attn_mask,
450
+ mask_name="attn_mask",
451
+ other_type=None,
452
+ other_name="",
453
+ target_type=query.dtype,
454
+ check_other=False,
455
+ )
456
+
457
+ if key_padding_mask is not None:
458
+ # We have the attn_mask, and use that to merge kpm into it.
459
+ # Turn off use of is_causal hint, as the merged mask is no
460
+ # longer causal.
461
+ is_causal = False
462
+
463
+ assert embed_dim == embed_dim_to_check, \
464
+ f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
465
+ if isinstance(embed_dim, torch.Tensor):
466
+ # embed_dim can be a tensor when JIT tracing
467
+ head_dim = embed_dim.div(num_heads, rounding_mode='trunc')
468
+ else:
469
+ head_dim = embed_dim // num_heads
470
+ assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
471
+ if use_separate_proj_weight:
472
+ # allow MHA to have different embedding dimensions when separate projection weights are used
473
+ assert key.shape[:2] == value.shape[:2], \
474
+ f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
475
+ else:
476
+ assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
477
+
478
+ #
479
+ # compute in-projection
480
+ #
481
+
482
+ if not use_separate_proj_weight:
483
+ assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
484
+ q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
485
+ else:
486
+ assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
487
+ assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
488
+ assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
489
+ if in_proj_bias is None:
490
+ b_q = b_k = b_v = None
491
+ else:
492
+ b_q, b_k, b_v = in_proj_bias.chunk(3)
493
+ q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)
494
+
495
+ # prep attention mask
496
+
497
+ if attn_mask is not None:
498
+ # ensure attn_mask's dim is 3
499
+ if attn_mask.dim() == 2:
500
+ correct_2d_size = (tgt_len, src_len)
501
+ if attn_mask.shape != correct_2d_size:
502
+ raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
503
+ attn_mask = attn_mask.unsqueeze(0)
504
+ elif attn_mask.dim() == 3:
505
+ correct_3d_size = (bsz * num_heads, tgt_len, src_len)
506
+ if attn_mask.shape != correct_3d_size:
507
+ raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
508
+ else:
509
+ raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
510
+
511
+ # add bias along batch dimension (currently second)
512
+ if bias_k is not None and bias_v is not None:
513
+ assert static_k is None, "bias cannot be added to static key."
514
+ assert static_v is None, "bias cannot be added to static value."
515
+ k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
516
+ v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
517
+ if attn_mask is not None:
518
+ attn_mask = pad(attn_mask, (0, 1))
519
+ if key_padding_mask is not None:
520
+ key_padding_mask = pad(key_padding_mask, (0, 1))
521
+ else:
522
+ assert bias_k is None
523
+ assert bias_v is None
524
+
525
+ #
526
+ # reshape q, k, v for multihead attention and make em batch first
527
+ #
528
+ q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
529
+ if static_k is None:
530
+ k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
531
+ else:
532
+ # TODO finish disentangling control flow so we don't do in-projections when statics are passed
533
+ assert static_k.size(0) == bsz * num_heads, \
534
+ f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
535
+ assert static_k.size(2) == head_dim, \
536
+ f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
537
+ k = static_k
538
+ if static_v is None:
539
+ v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
540
+ else:
541
+ # TODO finish disentangling control flow so we don't do in-projections when statics are passed
542
+ assert static_v.size(0) == bsz * num_heads, \
543
+ f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
544
+ assert static_v.size(2) == head_dim, \
545
+ f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
546
+ v = static_v
547
+
548
+ # add zero attention along batch dimension (now first)
549
+ if add_zero_attn:
550
+ zero_attn_shape = (bsz * num_heads, 1, head_dim)
551
+ k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
552
+ v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
553
+ if attn_mask is not None:
554
+ attn_mask = pad(attn_mask, (0, 1))
555
+ if key_padding_mask is not None:
556
+ key_padding_mask = pad(key_padding_mask, (0, 1))
557
+
558
+ # update source sequence length after adjustments
559
+ src_len = k.size(1)
560
+
561
+ # merge key padding and attention masks
562
+ if key_padding_mask is not None:
563
+ assert key_padding_mask.shape == (bsz, src_len), \
564
+ f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
565
+ key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
566
+ expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
567
+ if attn_mask is None:
568
+ attn_mask = key_padding_mask
569
+ else:
570
+ attn_mask = attn_mask + key_padding_mask
571
+
572
+ # adjust dropout probability
573
+ if not training:
574
+ dropout_p = 0.0
575
+
576
+ #
577
+ # (deep breath) calculate attention and out projection
578
+ #
579
+
580
+ if need_weights:
581
+ B, Nt, E = q.shape
582
+ q_scaled = q / math.sqrt(E)
583
+
584
+ assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
585
+
586
+ if attn_mask is not None:
587
+ attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
588
+ else:
589
+ attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
590
+ attn_output_weights = softmax(attn_output_weights, dim=-1)
591
+ if dropout_p > 0.0:
592
+ attn_output_weights = dropout(attn_output_weights, p=dropout_p)
593
+
594
+ attn_output = torch.bmm(attn_output_weights, v)
595
+
596
+ attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
597
+ attn_output = self.out_proj(attn_output)
598
+
599
+ attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
600
+
601
+ # optionally average attention weights over heads
602
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
603
+ if average_attn_weights:
604
+ attn_output_weights = attn_output_weights.mean(dim=1)
605
+
606
+ if not is_batched:
607
+ # squeeze the output if input was unbatched
608
+ attn_output = attn_output.squeeze(1)
609
+ attn_output_weights = attn_output_weights.squeeze(0)
610
+ return attn_output, attn_output_weights
611
+ else:
612
+ # attn_mask can be either (L,S) or (N*num_heads, L, S)
613
+ # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
614
+ # in order to match the input for SDPA of (N, num_heads, L, S)
615
+ if attn_mask is not None:
616
+ if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
617
+ attn_mask = attn_mask.unsqueeze(0)
618
+ else:
619
+ attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)
620
+
621
+ q = q.view(bsz, num_heads, tgt_len, head_dim)
622
+ k = k.view(bsz, num_heads, src_len, head_dim)
623
+ v = v.view(bsz, num_heads, src_len, head_dim)
624
+
625
+ attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
626
+ attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
627
+
628
+ attn_output = self.out_proj(attn_output)
629
+
630
+ attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
631
+ if not is_batched:
632
+ # squeeze the output if input was unbatched
633
+ attn_output = attn_output.squeeze(1)
634
+ return attn_output, None
635
+
636
+
637
+ def _mha_shape_check(query: Tensor, key: Tensor, value: Tensor,
638
+ key_padding_mask: Optional[Tensor], attn_mask: Optional[Tensor], num_heads: int):
639
+ # Verifies the expected shape for `query, `key`, `value`, `key_padding_mask` and `attn_mask`
640
+ # and returns if the input is batched or not.
641
+ # Raises an error if `query` is not 2-D (unbatched) or 3-D (batched) tensor.
642
+
643
+ # Shape check.
644
+ if query.dim() == 3:
645
+ # Batched Inputs
646
+ is_batched = True
647
+ assert key.dim() == 3 and value.dim() == 3, \
648
+ ("For batched (3-D) `query`, expected `key` and `value` to be 3-D"
649
+ f" but found {key.dim()}-D and {value.dim()}-D tensors respectively")
650
+ if key_padding_mask is not None:
651
+ assert key_padding_mask.dim() == 2, \
652
+ ("For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D"
653
+ f" but found {key_padding_mask.dim()}-D tensor instead")
654
+ if attn_mask is not None:
655
+ assert attn_mask.dim() in (2, 3), \
656
+ ("For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
657
+ f" but found {attn_mask.dim()}-D tensor instead")
658
+ elif query.dim() == 2:
659
+ # Unbatched Inputs
660
+ is_batched = False
661
+ assert key.dim() == 2 and value.dim() == 2, \
662
+ ("For unbatched (2-D) `query`, expected `key` and `value` to be 2-D"
663
+ f" but found {key.dim()}-D and {value.dim()}-D tensors respectively")
664
+
665
+ if key_padding_mask is not None:
666
+ assert key_padding_mask.dim() == 1, \
667
+ ("For unbatched (2-D) `query`, expected `key_padding_mask` to be `None` or 1-D"
668
+ f" but found {key_padding_mask.dim()}-D tensor instead")
669
+
670
+ if attn_mask is not None:
671
+ assert attn_mask.dim() in (2, 3), \
672
+ ("For unbatched (2-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
673
+ f" but found {attn_mask.dim()}-D tensor instead")
674
+ if attn_mask.dim() == 3:
675
+ expected_shape = (num_heads, query.shape[0], key.shape[0])
676
+ assert attn_mask.shape == expected_shape, \
677
+ (f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}")
678
+ else:
679
+ raise AssertionError(
680
+ f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor")
681
+
682
+ return is_batched
683
+
684
+
685
+ def _canonical_mask(
686
+ mask: Optional[Tensor],
687
+ mask_name: str,
688
+ other_type: Optional[DType],
689
+ other_name: str,
690
+ target_type: DType,
691
+ check_other: bool = True,
692
+ ) -> Optional[Tensor]:
693
+
694
+ if mask is not None:
695
+ _mask_dtype = mask.dtype
696
+ _mask_is_float = torch.is_floating_point(mask)
697
+ if _mask_dtype != torch.bool and not _mask_is_float:
698
+ raise AssertionError(
699
+ f"only bool and floating types of {mask_name} are supported")
700
+ if check_other and other_type is not None:
701
+ if _mask_dtype != other_type:
702
+ warnings.warn(
703
+ f"Support for mismatched {mask_name} and {other_name} "
704
+ "is deprecated. Use same type for both instead."
705
+ )
706
+ if not _mask_is_float:
707
+ mask = (
708
+ torch.zeros_like(mask, dtype=target_type)
709
+ .masked_fill_(mask, float("-inf"))
710
+ )
711
+ return mask
712
+
713
+
714
+ def _none_or_dtype(input: Optional[Tensor]) -> Optional[DType]:
715
+ if input is None:
716
+ return None
717
+ elif isinstance(input, torch.Tensor):
718
+ return input.dtype
719
+ raise RuntimeError("input to _none_or_dtype() must be None or torch.Tensor")
720
+
721
+ def _in_projection_packed(
722
+ q: Tensor,
723
+ k: Tensor,
724
+ v: Tensor,
725
+ w: Tensor,
726
+ b: Optional[Tensor] = None,
727
+ ) -> List[Tensor]:
728
+ r"""
729
+ Performs the in-projection step of the attention operation, using packed weights.
730
+ Output is a triple containing projection tensors for query, key and value.
731
+ Args:
732
+ q, k, v: query, key and value tensors to be projected. For self-attention,
733
+ these are typically the same tensor; for encoder-decoder attention,
734
+ k and v are typically the same tensor. (We take advantage of these
735
+ identities for performance if they are present.) Regardless, q, k and v
736
+ must share a common embedding dimension; otherwise their shapes may vary.
737
+ w: projection weights for q, k and v, packed into a single tensor. Weights
738
+ are packed along dimension 0, in q, k, v order.
739
+ b: optional projection biases for q, k and v, packed into a single tensor
740
+ in q, k, v order.
741
+ Shape:
742
+ Inputs:
743
+ - q: :math:`(..., E)` where E is the embedding dimension
744
+ - k: :math:`(..., E)` where E is the embedding dimension
745
+ - v: :math:`(..., E)` where E is the embedding dimension
746
+ - w: :math:`(E * 3, E)` where E is the embedding dimension
747
+ - b: :math:`E * 3` where E is the embedding dimension
748
+ Output:
749
+ - in output list :math:`[q', k', v']`, each output tensor will have the
750
+ same shape as the corresponding input tensor.
751
+ """
752
+ E = q.size(-1)
753
+ if k is v:
754
+ if q is k:
755
+ # self-attention
756
+ proj = linear(q, w, b)
757
+ # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
758
+ proj = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
759
+ return proj[0], proj[1], proj[2]
760
+ else:
761
+ # encoder-decoder attention
762
+ w_q, w_kv = w.split([E, E * 2])
763
+ if b is None:
764
+ b_q = b_kv = None
765
+ else:
766
+ b_q, b_kv = b.split([E, E * 2])
767
+ q_proj = linear(q, w_q, b_q)
768
+ kv_proj = linear(k, w_kv, b_kv)
769
+ # reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk()
770
+ kv_proj = kv_proj.unflatten(-1, (2, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
771
+ return (q_proj, kv_proj[0], kv_proj[1])
772
+ else:
773
+ w_q, w_k, w_v = w.chunk(3)
774
+ if b is None:
775
+ b_q = b_k = b_v = None
776
+ else:
777
+ b_q, b_k, b_v = b.chunk(3)
778
+ return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
779
+
780
+
781
+ def _in_projection(
782
+ q: Tensor,
783
+ k: Tensor,
784
+ v: Tensor,
785
+ w_q: Tensor,
786
+ w_k: Tensor,
787
+ w_v: Tensor,
788
+ b_q: Optional[Tensor] = None,
789
+ b_k: Optional[Tensor] = None,
790
+ b_v: Optional[Tensor] = None,
791
+ ) -> Tuple[Tensor, Tensor, Tensor]:
792
+ r"""
793
+ Performs the in-projection step of the attention operation. This is simply
794
+ a triple of linear projections, with shape constraints on the weights which
795
+ ensure embedding dimension uniformity in the projected outputs.
796
+ Output is a triple containing projection tensors for query, key and value.
797
+ Args:
798
+ q, k, v: query, key and value tensors to be projected.
799
+ w_q, w_k, w_v: weights for q, k and v, respectively.
800
+ b_q, b_k, b_v: optional biases for q, k and v, respectively.
801
+ Shape:
802
+ Inputs:
803
+ - q: :math:`(Qdims..., Eq)` where Eq is the query embedding dimension and Qdims are any
804
+ number of leading dimensions.
805
+ - k: :math:`(Kdims..., Ek)` where Ek is the key embedding dimension and Kdims are any
806
+ number of leading dimensions.
807
+ - v: :math:`(Vdims..., Ev)` where Ev is the value embedding dimension and Vdims are any
808
+ number of leading dimensions.
809
+ - w_q: :math:`(Eq, Eq)`
810
+ - w_k: :math:`(Eq, Ek)`
811
+ - w_v: :math:`(Eq, Ev)`
812
+ - b_q: :math:`(Eq)`
813
+ - b_k: :math:`(Eq)`
814
+ - b_v: :math:`(Eq)`
815
+ Output: in output triple :math:`(q', k', v')`,
816
+ - q': :math:`[Qdims..., Eq]`
817
+ - k': :math:`[Kdims..., Eq]`
818
+ - v': :math:`[Vdims..., Eq]`
819
+ """
820
+ Eq, Ek, Ev = q.size(-1), k.size(-1), v.size(-1)
821
+ assert w_q.shape == (Eq, Eq), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}"
822
+ assert w_k.shape == (Eq, Ek), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}"
823
+ assert w_v.shape == (Eq, Ev), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}"
824
+ assert b_q is None or b_q.shape == (Eq,), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}"
825
+ assert b_k is None or b_k.shape == (Eq,), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}"
826
+ assert b_v is None or b_v.shape == (Eq,), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}"
827
+ return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)