Hilbertmeng committed
Commit d2df67e
Parent: 6becb1d

fix k_mask

Files changed (1)
  1. modeling_dcpythia.py +8 -6
modeling_dcpythia.py CHANGED
@@ -120,7 +120,7 @@ class DCPythia(PreTrainedModel):
     def generate(self, input_ids, num_tokens_to_generate=10, compiled_decode_one_token=None):
         batch_size, seq_length = input_ids.shape
         input_pos = torch.arange(seq_length, device=self.device)
-        generated_ids = torch.zeros(batch_size, seq_length + num_tokens_to_generate + 1, dtype=torch.int, device=self.device)
+        generated_ids = torch.zeros(batch_size, seq_length + num_tokens_to_generate, dtype=torch.int, device=self.device)
         generated_ids[:, :seq_length] = input_ids.to(self.device).to(torch.int)
         logits = self.forward(input_ids, input_pos=input_pos,return_tensor=True)
         _next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
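
This hunk trims the preallocated output buffer by one slot: generate now reserves exactly seq_length + num_tokens_to_generate positions. A minimal standalone sketch of the new allocation, using made-up sizes and vocabulary range rather than anything from the model config:

import torch

# Hypothetical prompt of 5 tokens with 10 tokens to generate; the numbers are
# illustrative only and do not come from the model config.
batch_size, seq_length, num_tokens_to_generate = 2, 5, 10
input_ids = torch.randint(0, 50304, (batch_size, seq_length))

# After the fix the buffer holds exactly prompt + generated tokens,
# without the extra trailing slot the old code reserved.
generated_ids = torch.zeros(batch_size, seq_length + num_tokens_to_generate, dtype=torch.int)
generated_ids[:, :seq_length] = input_ids.to(torch.int)

assert generated_ids.shape == (batch_size, seq_length + num_tokens_to_generate)
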
@@ -159,12 +159,14 @@ class DCPythia(PreTrainedModel):
         for i, layer in enumerate(self.layers):
             if self.is_training or self.window_size is None :
                 layer_mask = mask
+                gen_mask = None
             elif self.window_size is not None:
                 layer_mask = mask[:,:,1] if layer.attention.window_size is None else mask[:,:,0]
+                gen_mask = mask[:,:,1] if layer.attention.window_size is not None else None
             if self.use_gradient_checkpointing:
                 x = checkpoint(layer, x, input_pos, freqs_cis, layer_mask)
             else:
-                x = layer(x, input_pos, freqs_cis, layer_mask)
+                x = layer(x, input_pos, freqs_cis, layer_mask, gen_mask=gen_mask)
         x = self.norm(x)
         logits = self.output(x)
         if return_tensor:
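
This hunk threads a separate gen_mask through every transformer block: layer_mask is still chosen by the layer's own attention.window_size, and gen_mask additionally carries mask[:,:,1] for windowed layers so the attention module can build its decoding-time key mask from it (see the last hunk). A standalone sketch of the selection, with dummy layers and a random placeholder standing in for the model's real stacked masks; shapes and values are illustrative only:

import torch
from types import SimpleNamespace

S = 8  # illustrative sequence length

# Two precomputed masks stacked so that mask[:, :, 0] and mask[:, :, 1] select
# between them, mirroring the indexing in the diff; random placeholders stand
# in for whatever modeling_dcpythia.py actually precomputes.
mask = torch.rand(1, 1, 2, S, S) > 0.5

# Dummy layers: one with windowed attention, one with global attention.
layers = [
    SimpleNamespace(attention=SimpleNamespace(window_size=4)),
    SimpleNamespace(attention=SimpleNamespace(window_size=None)),
]

for layer in layers:
    # Same selection as the new code path taken when self.window_size is not None:
    layer_mask = mask[:, :, 1] if layer.attention.window_size is None else mask[:, :, 0]
    gen_mask = mask[:, :, 1] if layer.attention.window_size is not None else None
    print(layer.attention.window_size,
          tuple(layer_mask.shape),
          None if gen_mask is None else tuple(gen_mask.shape))

Global-attention layers get gen_mask=None because, per the final hunk, they keep using their own mask during one-token generation.
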
@@ -183,8 +185,8 @@ class DCPythiaBlock(nn.Module):
         self.attention_norm = nn.LayerNorm(config.dim, eps=config.norm_eps)
         self.use_parallel_residual = config.use_parallel_residual

-    def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
-        h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos, fast_infer=True)
+    def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor, gen_mask=None) -> Tensor:
+        h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos, fast_infer=True, gen_mask=gen_mask)
         if self.use_parallel_residual:
             out = h + self.feed_forward(self.ffn_norm(x))
         else:
@@ -424,7 +426,7 @@ class DCMHAttention(nn.Module):
         y = probs @ v
         return y

-    def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None, fast_infer=True) -> Tensor:
+    def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None, fast_infer=True, gen_mask=None) -> Tensor:
         bsz, seqlen, _ = x.shape

         kv_size = self.n_local_heads * self.head_dim
@@ -501,7 +503,7 @@ class DCMHAttention(nn.Module):
                 y[:,:,start:stop] = _o
         else: # inference
             if seqlen == 1: # one-token generation
-                k_mask = mask if self.window_size is None else mask[:,:,:,:self.kv_cache.seq_length]
+                k_mask = mask if self.window_size is None else gen_mask[:, :, :,:self.kv_cache.seq_length]
                 if fast_infer:
                     y = self._generate_fast(x, input_pos, q, k, v, k_mask)
                 else:
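
This hunk is the k_mask fix named in the commit message: during one-token generation, a windowed-attention layer now slices the gen_mask handed down from the model, rather than its own layer mask, to the current KV-cache length. A standalone sketch of just that slicing, with placeholder tensors and a kv_cache_seq_length variable standing in for self.kv_cache.seq_length:

import torch

kv_cache_seq_length = 16   # stand-in for self.kv_cache.seq_length
window_size = 4            # pretend this layer uses windowed (local) attention

# Placeholder boolean masks shaped (batch, heads, q_len=1, max_positions);
# the real tensors come from the model, this only illustrates the slicing.
mask = torch.ones(1, 1, 1, 32, dtype=torch.bool)
gen_mask = torch.ones(1, 1, 1, 32, dtype=torch.bool)

# New behaviour during one-token generation: a global-attention layer keeps its
# own mask, while a windowed layer slices the generation mask passed down from
# the model to the current KV-cache length.
k_mask = mask if window_size is None else gen_mask[:, :, :, :kv_cache_seq_length]

print(tuple(k_mask.shape))  # (1, 1, 1, 16)
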
 