ccdv commited on
Commit
93a4f16
1 Parent(s): d766a2f
Files changed (1) hide show
  1. modeling_lsg_distilbert.py +1 -1
modeling_lsg_distilbert.py CHANGED
@@ -766,7 +766,7 @@ class LSGTransformer(Transformer):
766
  return_dict: Optional[bool] = None,
767
  ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]: # docstyle-ignore
768
 
769
- attn_mask = attn_mask.float()
770
  mask_value = 0
771
  n, t = attn_mask.size()
772
 
 
766
  return_dict: Optional[bool] = None,
767
  ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]: # docstyle-ignore
768
 
769
+ attn_mask = attn_mask.to(dtype=x.dtype)
770
  mask_value = 0
771
  n, t = attn_mask.size()
772