BucketOfFish committed
Commit 3649bbb
1 Parent(s): 4f25dda

Got output to match Phi2 exactly

Files changed (1):
  1. attention.py +25 -5
attention.py CHANGED
@@ -370,11 +370,29 @@ class MHA(nn.Module):
             dtype=kv.dtype,
             device=kv.device,
         )
+
+        batch_start = kv_cache.batch_size_offset
+        batch_end = batch_start + kv.shape[0]
+        sequence_start = kv_cache.seqlen_offset
+        sequence_end = sequence_start + kv.shape[1]
+
+        # TODO: figure out why they're doing this
+        if sequence_end >= kv_cache.max_seqlen:
+            kv_cache.kv_block_map[block_n] = torch.concatenate(
+                (kv_cache.kv_block_map[block_n], kv),
+                dim=1,
+            )
         kv_cache.kv_block_map[block_n][
-            kv_cache.batch_size_offset: kv_cache.batch_size_offset + kv.shape[0],
-            kv_cache.seqlen_offset: kv_cache.seqlen_offset + kv.shape[1],
+            batch_start:batch_end,
+            sequence_start:sequence_end,
             ...
         ] = kv
+        kv = kv_cache.kv_block_map[block_n][
+            batch_start:batch_end,
+            :sequence_end,
+            ...
+        ]
+        return kv
 
     def _forward_cross_attn(
         self,
@@ -396,8 +414,9 @@ class MHA(nn.Module):
             ],
             dim=2,
         )
-        self._update_kv_cache(kv, kv_cache, self.block_n)
-        causal = False  # turning off causal mask for cross attention
+        kv = self._update_kv_cache(kv, kv_cache, self.block_n)
+
+        causal = (kv_cache.seqlen_offset == 0)
 
         if self.using_flash_attn and unpad_input and pad_input:  # not touching flash attention code
             batch_size, seqlen_q = q.shape[0], q.shape[1]
@@ -528,4 +547,5 @@ class ParallelAttentionBlock(nn.Module):
             key_padding_mask=key_padding_mask,
         )
         mlp_outputs = self.mlp(x)
-        return self.dropout(attn_outputs + mlp_outputs) + residual
+        x = self.dropout(attn_outputs + mlp_outputs) + residual
+        return x

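The first hunk turns _update_kv_cache from a pure cache write into a write followed by a read-back: the new keys/values are stored at the current batch/sequence offsets, and the method now returns everything cached up to sequence_end, so the caller attends over past and present tokens. Below is a minimal sketch of that behavior; the InferenceCache container, its preallocation, and the [batch, seqlen, 2, n_heads, head_dim] kv layout are assumptions for illustration, while kv_block_map, batch_size_offset, seqlen_offset, and max_seqlen come from the diff.

    from dataclasses import dataclass, field

    import torch


    @dataclass
    class InferenceCache:
        # Illustrative stand-in for the repo's cache object; only the field names used
        # in the diff (kv_block_map, batch_size_offset, seqlen_offset, max_seqlen) are real.
        max_seqlen: int
        batch_size_offset: int = 0
        seqlen_offset: int = 0
        kv_block_map: dict = field(default_factory=dict)


    def update_kv_cache(kv: torch.Tensor, kv_cache: InferenceCache, block_n: int) -> torch.Tensor:
        if block_n not in kv_cache.kv_block_map:
            # Pre-allocate a buffer for this block up to max_seqlen (assumed layout:
            # [batch, seqlen, 2, n_heads, head_dim], where dim 2 holds keys and values).
            batch, _, two, n_heads, head_dim = kv.shape
            kv_cache.kv_block_map[block_n] = torch.zeros(
                batch, kv_cache.max_seqlen, two, n_heads, head_dim,
                dtype=kv.dtype, device=kv.device,
            )

        batch_start = kv_cache.batch_size_offset
        batch_end = batch_start + kv.shape[0]
        sequence_start = kv_cache.seqlen_offset
        sequence_end = sequence_start + kv.shape[1]

        # Mirrors the TODO'd branch in the diff: grow the buffer by concatenation
        # if the incoming tokens would run past the preallocated max_seqlen.
        if sequence_end >= kv_cache.max_seqlen:
            kv_cache.kv_block_map[block_n] = torch.concatenate(
                (kv_cache.kv_block_map[block_n], kv), dim=1,
            )

        # Write the new keys/values into their slot ...
        kv_cache.kv_block_map[block_n][batch_start:batch_end, sequence_start:sequence_end, ...] = kv
        # ... and return everything cached so far, which is what the caller now attends over.
        return kv_cache.kv_block_map[block_n][batch_start:batch_end, :sequence_end, ...]


    # Prefill 4 prompt tokens, then decode 1 token: the second call returns all 5 positions.
    cache = InferenceCache(max_seqlen=16)
    kv = update_kv_cache(torch.randn(1, 4, 2, 8, 32), cache, block_n=0)  # -> [1, 4, 2, 8, 32]
    cache.seqlen_offset = 4  # advanced by the caller between steps in the real code
    kv = update_kv_cache(torch.randn(1, 1, 2, 8, 32), cache, block_n=0)  # -> [1, 5, 2, 8, 32]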
 
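The second hunk feeds the cache's return value back into the attention path (kv = self._update_kv_cache(...)) and replaces the hard-coded causal = False with causal = (kv_cache.seqlen_offset == 0): a causal mask is only needed while the full prompt is being prefilled. During incremental decoding the single new query sits at the last position, so attending to every cached key is already causal. A sketch of the idea, assuming the same [batch, seqlen, 2, n_heads, head_dim] kv layout and using torch's scaled_dot_product_attention purely for illustration; the repo's flash-attention branch is untouched, as the diff's own comment notes.

    import torch
    import torch.nn.functional as F


    def attend(q: torch.Tensor, kv: torch.Tensor, seqlen_offset: int) -> torch.Tensor:
        k, v = kv.unbind(dim=2)                           # each [batch, seqlen_k, n_heads, head_dim]
        q, k, v = [t.transpose(1, 2) for t in (q, k, v)]  # -> [batch, n_heads, seqlen, head_dim]
        causal = (seqlen_offset == 0)                     # mask only while processing the prompt
        out = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
        return out.transpose(1, 2)


    batch, n_heads, head_dim = 1, 8, 32

    # Prefill: queries and keys cover the same 4 prompt positions, so the mask matters.
    prompt_q = torch.randn(batch, 4, n_heads, head_dim)
    prompt_kv = torch.randn(batch, 4, 2, n_heads, head_dim)
    attend(prompt_q, prompt_kv, seqlen_offset=0)

    # Decode: one new query against the 5 cached positions returned by the cache update;
    # it is the last position, so an explicit causal mask would not hide anything anyway.
    step_q = torch.randn(batch, 1, n_heads, head_dim)
    full_kv = torch.randn(batch, 5, 2, n_heads, head_dim)
    attend(step_q, full_kv, seqlen_offset=4)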
 
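The last hunk only splits ParallelAttentionBlock's final return into an assignment plus return, but the surrounding context shows the parallel layout Phi-2 uses: attention and MLP read the same input, their outputs are summed, dropped out, and added back to the residual. A condensed sketch of that forward; layer_norm and mixer are assumed names and nn.MultiheadAttention stands in for the repo's MHA, while mlp, dropout, attn_outputs, mlp_outputs, residual, and key_padding_mask appear in the diff.

    import torch
    import torch.nn as nn


    class ParallelBlockSketch(nn.Module):
        def __init__(self, d_model=64, n_heads=8, dropout=0.1):
            super().__init__()
            self.layer_norm = nn.LayerNorm(d_model)
            self.mixer = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
            self.mlp = nn.Sequential(
                nn.Linear(d_model, 4 * d_model),
                nn.GELU(),
                nn.Linear(4 * d_model, d_model),
            )
            self.dropout = nn.Dropout(dropout)

        def forward(self, x, key_padding_mask=None):
            residual = x                  # keep the unnormalized input for the skip connection
            x = self.layer_norm(x)
            attn_outputs, _ = self.mixer(x, x, x, key_padding_mask=key_padding_mask)
            mlp_outputs = self.mlp(x)     # attention and MLP branch off the same normalized input
            x = self.dropout(attn_outputs + mlp_outputs) + residual
            return x


    block = ParallelBlockSketch()
    out = block(torch.randn(2, 10, 64))   # [batch, seqlen, d_model]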