Commit
•
d2df67e
1
Parent(s):
6becb1d
fix k_mask
Browse files — modeling_dcpythia.py (+8, −6)
modeling_dcpythia.py
CHANGED
@@ -120,7 +120,7 @@ class DCPythia(PreTrainedModel):
|
|
120 |
def generate(self, input_ids, num_tokens_to_generate=10, compiled_decode_one_token=None):
|
121 |
batch_size, seq_length = input_ids.shape
|
122 |
input_pos = torch.arange(seq_length, device=self.device)
|
123 |
-
generated_ids = torch.zeros(batch_size, seq_length + num_tokens_to_generate
|
124 |
generated_ids[:, :seq_length] = input_ids.to(self.device).to(torch.int)
|
125 |
logits = self.forward(input_ids, input_pos=input_pos,return_tensor=True)
|
126 |
_next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
|
@@ -159,12 +159,14 @@ class DCPythia(PreTrainedModel):
|
|
159 |
for i, layer in enumerate(self.layers):
|
160 |
if self.is_training or self.window_size is None :
|
161 |
layer_mask = mask
|
|
|
162 |
elif self.window_size is not None:
|
163 |
layer_mask = mask[:,:,1] if layer.attention.window_size is None else mask[:,:,0]
|
|
|
164 |
if self.use_gradient_checkpointing:
|
165 |
x = checkpoint(layer, x, input_pos, freqs_cis, layer_mask)
|
166 |
else:
|
167 |
-
x = layer(x, input_pos, freqs_cis, layer_mask)
|
168 |
x = self.norm(x)
|
169 |
logits = self.output(x)
|
170 |
if return_tensor:
|
@@ -183,8 +185,8 @@ class DCPythiaBlock(nn.Module):
|
|
183 |
self.attention_norm = nn.LayerNorm(config.dim, eps=config.norm_eps)
|
184 |
self.use_parallel_residual = config.use_parallel_residual
|
185 |
|
186 |
-
def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
|
187 |
-
h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos, fast_infer=True)
|
188 |
if self.use_parallel_residual:
|
189 |
out = h + self.feed_forward(self.ffn_norm(x))
|
190 |
else:
|
@@ -424,7 +426,7 @@ class DCMHAttention(nn.Module):
|
|
424 |
y = probs @ v
|
425 |
return y
|
426 |
|
427 |
-
def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None, fast_infer=True) -> Tensor:
|
428 |
bsz, seqlen, _ = x.shape
|
429 |
|
430 |
kv_size = self.n_local_heads * self.head_dim
|
@@ -501,7 +503,7 @@ class DCMHAttention(nn.Module):
|
|
501 |
y[:,:,start:stop] = _o
|
502 |
else: # inference
|
503 |
if seqlen == 1: # one-token generation
|
504 |
-
k_mask = mask if self.window_size is None else
|
505 |
if fast_infer:
|
506 |
y = self._generate_fast(x, input_pos, q, k, v, k_mask)
|
507 |
else:
|
|
|
120 |
def generate(self, input_ids, num_tokens_to_generate=10, compiled_decode_one_token=None):
|
121 |
batch_size, seq_length = input_ids.shape
|
122 |
input_pos = torch.arange(seq_length, device=self.device)
|
123 |
+
generated_ids = torch.zeros(batch_size, seq_length + num_tokens_to_generate, dtype=torch.int, device=self.device)
|
124 |
generated_ids[:, :seq_length] = input_ids.to(self.device).to(torch.int)
|
125 |
logits = self.forward(input_ids, input_pos=input_pos,return_tensor=True)
|
126 |
_next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
|
|
|
159 |
for i, layer in enumerate(self.layers):
|
160 |
if self.is_training or self.window_size is None :
|
161 |
layer_mask = mask
|
162 |
+
gen_mask = None
|
163 |
elif self.window_size is not None:
|
164 |
layer_mask = mask[:,:,1] if layer.attention.window_size is None else mask[:,:,0]
|
165 |
+
gen_mask = mask[:,:,1] if layer.attention.window_size is not None else None
|
166 |
if self.use_gradient_checkpointing:
|
167 |
x = checkpoint(layer, x, input_pos, freqs_cis, layer_mask)
|
168 |
else:
|
169 |
+
x = layer(x, input_pos, freqs_cis, layer_mask, gen_mask=gen_mask)
|
170 |
x = self.norm(x)
|
171 |
logits = self.output(x)
|
172 |
if return_tensor:
|
|
|
185 |
self.attention_norm = nn.LayerNorm(config.dim, eps=config.norm_eps)
|
186 |
self.use_parallel_residual = config.use_parallel_residual
|
187 |
|
188 |
+
def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor, gen_mask=None) -> Tensor:
|
189 |
+
h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos, fast_infer=True, gen_mask=gen_mask)
|
190 |
if self.use_parallel_residual:
|
191 |
out = h + self.feed_forward(self.ffn_norm(x))
|
192 |
else:
|
|
|
426 |
y = probs @ v
|
427 |
return y
|
428 |
|
429 |
+
def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None, fast_infer=True, gen_mask=None) -> Tensor:
|
430 |
bsz, seqlen, _ = x.shape
|
431 |
|
432 |
kv_size = self.n_local_heads * self.head_dim
|
|
|
503 |
y[:,:,start:stop] = _o
|
504 |
else: # inference
|
505 |
if seqlen == 1: # one-token generation
|
506 |
+
k_mask = mask if self.window_size is None else gen_mask[:, :, :, :self.kv_cache.seq_length]
|
507 |
if fast_infer:
|
508 |
y = self._generate_fast(x, input_pos, q, k, v, k_mask)
|
509 |
else:
|