Hilbertmeng committed
Commit 51d254e
1 Parent(s): 9ee8b6b

fix k_mask

Files changed (1):
  modeling_dcformer.py (+8 -6)

modeling_dcformer.py CHANGED
@@ -123,7 +123,7 @@ class DCFormer(PreTrainedModel):
     def generate(self, input_ids, num_tokens_to_generate=10, compiled_decode_one_token=None):
         batch_size, seq_length = input_ids.shape
         input_pos = torch.arange(seq_length, device=self.device)
-        generated_ids = torch.zeros(batch_size, seq_length + num_tokens_to_generate + 1, dtype=torch.int, device=self.device)
+        generated_ids = torch.zeros(batch_size, seq_length + num_tokens_to_generate, dtype=torch.int, device=self.device)
         generated_ids[:, :seq_length] = input_ids.to(self.device).to(torch.int)
         logits = self.forward(input_ids, input_pos=input_pos,return_tensor=True)
         _next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
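
The only functional change in this hunk is the size of the preallocated output buffer: the prompt fills seq_length slots and greedy decoding writes exactly num_tokens_to_generate more, so the former trailing +1 column was never used. Below is a minimal standalone sketch of the same bookkeeping, with a toy step_fn standing in for the model; the helper name and the decode loop are illustrative, not the repository's API.

import torch

def greedy_decode(input_ids, step_fn, num_tokens_to_generate=10):
    # input_ids: (batch, seq_length) int tensor; step_fn maps a prefix to (batch, vocab) logits.
    batch_size, seq_length = input_ids.shape
    total_len = seq_length + num_tokens_to_generate       # exact size, no trailing +1 slot
    generated_ids = torch.zeros(batch_size, total_len, dtype=torch.int)
    generated_ids[:, :seq_length] = input_ids.to(torch.int)
    for pos in range(seq_length, total_len):              # writes positions seq_length .. total_len - 1
        logits = step_fn(generated_ids[:, :pos])
        generated_ids[:, pos] = torch.argmax(logits, dim=-1).to(torch.int)
    return generated_ids

The last write lands at index total_len - 1, so seq_length + num_tokens_to_generate columns are exactly enough.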
@@ -162,12 +162,14 @@ class DCFormer(PreTrainedModel):
         for i, layer in enumerate(self.layers):
             if self.is_training or self.window_size is None :
                 layer_mask = mask
+                gen_mask = None
             elif self.window_size is not None:
                 layer_mask = mask[:,:,1] if layer.attention.window_size is None else mask[:,:,0]
+                gen_mask = mask[:,:,1] if layer.attention.window_size is not None else None
             if self.use_gradient_checkpointing:
                 x = checkpoint(layer, x, input_pos, freqs_cis, layer_mask)
             else:
-                x = layer(x, input_pos, freqs_cis, layer_mask)
+                x = layer(x, input_pos, freqs_cis, layer_mask, gen_mask=gen_mask)
         x = self.norm(x)
         logits = self.output(x)
         if return_tensor:
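
Judging from the selection above, mask stacks two variants along dim 2: index 0 appears to be the sliding-window causal mask used by windowed layers, index 1 the full causal mask used by global layers, and the new gen_mask hands the full causal variant to windowed layers as well so the decode path can slice it against the KV cache. A hedged sketch of how such a stacked pair could be built; the helper name and the (1, 1, 2, seq_length, seq_length) layout are assumptions, not the repository's code.

import torch

def build_stacked_masks(seq_length, window_size):
    # Full causal mask: query i may attend to every key j <= i.
    causal = torch.tril(torch.ones(seq_length, seq_length, dtype=torch.bool))
    # Sliding-window causal mask: additionally require i - j < window_size.
    idx = torch.arange(seq_length)
    windowed = causal & ((idx[:, None] - idx[None, :]) < window_size)
    # Stack so masks[:, :, 0] is the windowed mask and masks[:, :, 1] the full causal mask,
    # matching the per-layer indexing in the loop above (assumed layout).
    return torch.stack([windowed, causal], dim=0)[None, None]   # (1, 1, 2, seq_length, seq_length)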
@@ -185,8 +187,8 @@ class DCFormerBlock(nn.Module):
         self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
         self.attention_norm = RMSNorm(config.dim, config.norm_eps)
 
-    def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
-        h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos, fast_infer=True)
+    def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor, gen_mask=None) -> Tensor:
+        h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos, gen_mask=gen_mask, fast_infer=True)
         out = h + self.feed_forward(self.ffn_norm(h))
         return out
 
@@ -416,7 +418,7 @@ class DCMHAttention(nn.Module):
             y = probs @ v
         return y
 
-    def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None, fast_infer=True) -> Tensor:
+    def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None, fast_infer=True, gen_mask=None) -> Tensor:
         bsz, seqlen, _ = x.shape
 
         kv_size = self.n_local_heads * self.head_dim
@@ -483,7 +485,7 @@ class DCMHAttention(nn.Module):
                 y[:,:,start:stop] = _o
         else: # inference
             if seqlen == 1: # one-token generation
-                k_mask = mask if self.window_size is None else mask[:,:,:,:self.kv_cache.seq_length]
+                k_mask = mask if self.window_size is None else gen_mask[:, :, :,:self.kv_cache.seq_length]
                 if fast_infer:
                     y = self._generate_fast(x, input_pos, q, k, v, k_mask)
                 else:
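
This is the k_mask fix named in the commit message: during one-token generation, a windowed layer now slices its key mask from gen_mask (the full causal mask) rather than from its own windowed mask, presumably because the windowed rows would hide cached key positions that should remain visible up to self.kv_cache.seq_length. A toy comparison of the two mask rows for the newest query; the cache length, window size, and layout are illustrative assumptions.

import torch

cache_len, window = 8, 4
pos = torch.arange(cache_len)
# Row of the full causal mask for the newest query: every cached key stays visible.
causal_row = torch.ones(cache_len, dtype=torch.bool)
# Row of the sliding-window mask for the same query: only the last `window` keys survive.
windowed_row = pos > (cache_len - 1 - window)
print(causal_row.int().tolist())    # [1, 1, 1, 1, 1, 1, 1, 1]
print(windowed_row.int().tolist())  # [0, 0, 0, 0, 1, 1, 1, 1]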
 