Update model.py
Browse files
model.py
CHANGED
@@ -355,7 +355,7 @@ class StripedHyena(nn.Module):
|
|
355 |
self.gradient_checkpointing = False
|
356 |
self._gradient_checkpointing_func = None
|
357 |
|
358 |
-
def forward(self,
|
359 |
L = x.shape[1]
|
360 |
x = self.embedding_layer.embed(x)
|
361 |
if inference_params_dict is not None:
|
@@ -370,7 +370,7 @@ class StripedHyena(nn.Module):
|
|
370 |
x = self.unembed.unembed(x)
|
371 |
return x, inference_params_dict_out
|
372 |
|
373 |
-
def stateful_forward(self,
|
374 |
for block_idx, block in enumerate(self.blocks):
|
375 |
block_name = "mha" if block_idx in self.config.attn_layer_idxs else "hyena"
|
376 |
inference_params = inference_params_dict[block_name]
|
@@ -378,7 +378,7 @@ class StripedHyena(nn.Module):
|
|
378 |
|
379 |
return x, inference_params_dict
|
380 |
|
381 |
-
def stateless_forward(self,
|
382 |
if type(padding_mask) == torch.Tensor:
|
383 |
x = x * padding_mask[..., None]
|
384 |
|
|
|
355 |
self.gradient_checkpointing = False
|
356 |
self._gradient_checkpointing_func = None
|
357 |
|
358 |
+
def forward(self, x, inference_params_dict=None, padding_mask=None):
|
359 |
L = x.shape[1]
|
360 |
x = self.embedding_layer.embed(x)
|
361 |
if inference_params_dict is not None:
|
|
|
370 |
x = self.unembed.unembed(x)
|
371 |
return x, inference_params_dict_out
|
372 |
|
373 |
+
def stateful_forward(self, x, inference_params_dict=None):
|
374 |
for block_idx, block in enumerate(self.blocks):
|
375 |
block_name = "mha" if block_idx in self.config.attn_layer_idxs else "hyena"
|
376 |
inference_params = inference_params_dict[block_name]
|
|
|
378 |
|
379 |
return x, inference_params_dict
|
380 |
|
381 |
+
def stateless_forward(self, x, padding_mask=None):
|
382 |
if type(padding_mask) == torch.Tensor:
|
383 |
x = x * padding_mask[..., None]
|
384 |
|