Hilbertmeng committed on
Commit
052c75b
1 Parent(s): d2df67e
Files changed (1) hide show
  1. modeling_dcpythia.py +1 -1
modeling_dcpythia.py CHANGED
@@ -503,7 +503,7 @@ class DCMHAttention(nn.Module):
503
  y[:,:,start:stop] = _o
504
  else: # inference
505
  if seqlen == 1: # one-token generation
506
- k_mask = mask if self.window_size is None else gen_mask[:, :, :,:self.kv_cache.seq_length
507
  if fast_infer:
508
  y = self._generate_fast(x, input_pos, q, k, v, k_mask)
509
  else:
 
503
  y[:,:,start:stop] = _o
504
  else: # inference
505
  if seqlen == 1: # one-token generation
506
+ k_mask = mask if self.window_size is None else gen_mask[:, :, :,:self.kv_cache.seq_length]
507
  if fast_infer:
508
  y = self._generate_fast(x, input_pos, q, k, v, k_mask)
509
  else: