Text Generation
Transformers
PyTorch
mosaic_gpt
custom_code
anas-awadalla committed on
Commit f0a13e4
1 Parent(s): 38fdb61

turn attention_mask to bool in forward pass

Files changed (1)
  1. mosaic_gpt.py +1 -0
mosaic_gpt.py CHANGED
@@ -247,6 +247,7 @@ class MosaicGPT(PreTrainedModel):
                      use_cache: Optional[bool] = None):
         return_dict = return_dict if return_dict is not None else self.config.return_dict
         use_cache = use_cache if use_cache is not None else self.config.use_cache
+        attention_mask = attention_mask.bool()
 
         # These args are passed in by keyword in huggingface's generate function
         # https://github.com/huggingface/transformers/blob/68287689f2f0d8b7063c400230b3766987abf18d/src/transformers/generation/utils.py#L2201-L2206
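For context, HuggingFace tokenizers and generate pass attention_mask as an integer tensor of 0s and 1s, and the cast added here presumably converts it to the boolean mask the downstream masking logic expects. A minimal standalone sketch of what the one-line change does (the example tensor is illustrative, not taken from the model code):

import torch

# attention_mask as a HuggingFace tokenizer would produce it: int64 zeros and ones
attention_mask = torch.tensor([[1, 1, 1, 0, 0]])
print(attention_mask.dtype)  # torch.int64

# the added line casts it to a boolean mask before it reaches the attention code
attention_mask = attention_mask.bool()
print(attention_mask.dtype)  # torch.bool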