jupyterjazz
commited on
Commit
•
3eceb33
1
Parent(s):
e860caa
Update modeling_xlm_roberta.py
Browse files- modeling_xlm_roberta.py +4 -6
modeling_xlm_roberta.py
CHANGED
@@ -210,12 +210,10 @@ class XLMRobertaEncoder(nn.Module):
|
|
210 |
subset_mask: (batch, seqlen), dtype=torch.bool
|
211 |
"""
|
212 |
if key_padding_mask is None or not self.use_flash_attn:
|
213 |
-
mixer_kwargs =
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
)
|
218 |
-
mixer_kwargs['task_type'] = task_type
|
219 |
for layer in self.layers:
|
220 |
if self._grad_checkpointing:
|
221 |
hidden_states = torch.utils.checkpoint.checkpoint(
|
|
|
210 |
subset_mask: (batch, seqlen), dtype=torch.bool
|
211 |
"""
|
212 |
if key_padding_mask is None or not self.use_flash_attn:
|
213 |
+
mixer_kwargs = {'task_type': task_type}
|
214 |
+
if key_padding_mask is not None:
|
215 |
+
mixer_kwargs['key_padding_mask'] = key_padding_mask.bool()
|
216 |
+
|
|
|
|
|
217 |
for layer in self.layers:
|
218 |
if self._grad_checkpointing:
|
219 |
hidden_states = torch.utils.checkpoint.checkpoint(
|