fix: set fp32 when using CPU because bf16 is slow

#44
Files changed (1)
  1. configuration_xlm_roberta.py +2 -0
configuration_xlm_roberta.py CHANGED
@@ -126,3 +126,5 @@ class XLMRobertaFlashConfig(PretrainedConfig):
             self.torch_dtype = getattr(torch, torch_dtype)
         else:
             self.torch_dtype = torch_dtype
+        if not self.use_flash_attn or not torch.cuda.is_available():
+            self.torch_dtype = torch.float32
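
For context, here is a minimal, runnable sketch of the fallback this diff introduces. It is not the actual `XLMRobertaFlashConfig`; the two-argument constructor is assumed purely for illustration, and only the dtype-selection lines touched by the change are mirrored.

```python
# Stand-in for the dtype logic in XLMRobertaFlashConfig (sketch only; the
# real class has a much larger constructor).
import torch


class DtypeFallbackSketch:
    def __init__(self, torch_dtype="bfloat16", use_flash_attn=True):
        self.use_flash_attn = use_flash_attn
        # Existing behaviour: resolve a string dtype name to a torch dtype.
        if isinstance(torch_dtype, str):
            self.torch_dtype = getattr(torch, torch_dtype)
        else:
            self.torch_dtype = torch_dtype
        # Behaviour added by this PR: bf16 is slow on CPU, so force fp32
        # whenever flash attention is disabled or no CUDA device is available.
        if not self.use_flash_attn or not torch.cuda.is_available():
            self.torch_dtype = torch.float32


cfg = DtypeFallbackSketch(torch_dtype="bfloat16", use_flash_attn=False)
print(cfg.torch_dtype)  # torch.float32 on a CPU-only machine
```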