Markus28 commited on
Commit
ca5f516
1 Parent(s): eec6c0e

fix: cast mask to bool

Browse files
Files changed (1) hide show
  1. modeling_bert.py +1 -1
modeling_bert.py CHANGED
@@ -184,7 +184,7 @@ class BertEncoder(nn.Module):
184
  """
185
  if key_padding_mask is None or not self.use_flash_attn:
186
  mixer_kwargs = (
187
- {"key_padding_mask": key_padding_mask} if key_padding_mask is not None else None
188
  )
189
  for layer in self.layers:
190
  hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
 
184
  """
185
  if key_padding_mask is None or not self.use_flash_attn:
186
  mixer_kwargs = (
187
+ {"key_padding_mask": key_padding_mask.bool()} if key_padding_mask is not None else None
188
  )
189
  for layer in self.layers:
190
  hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)