fix: set fp32 when using CPU because bf16 is slow
#44 opened by jupyterjazz

configuration_xlm_roberta.py  CHANGED
@@ -126,3 +126,5 @@ class XLMRobertaFlashConfig(PretrainedConfig):
             self.torch_dtype = getattr(torch, torch_dtype)
         else:
             self.torch_dtype = torch_dtype
+        if not self.use_flash_attn or not torch.cuda.is_available():
+            self.torch_dtype = torch.float32
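
For context, below is a minimal, runnable sketch of how the changed constructor logic behaves. Only the `use_flash_attn` flag, the `torch_dtype` handling, and the new CUDA check appear in the diff; the class name, the constructor signature, and the guard condition on the first `if` are assumptions added here for illustration.

# Minimal sketch (not the full XLMRobertaFlashConfig) of the dtype fallback
# added in this PR: bf16 only pays off on CUDA with flash attention, so on
# CPU (or with flash attention disabled) the config falls back to float32.
import torch


class DtypeResolvingConfig:
    def __init__(self, torch_dtype="bfloat16", use_flash_attn=True):
        self.use_flash_attn = use_flash_attn

        # Accept either a string like "bfloat16" or an actual torch.dtype.
        # (The exact guard condition is not visible in the diff and is an
        # assumption here.)
        if isinstance(torch_dtype, str) and hasattr(torch, torch_dtype):
            self.torch_dtype = getattr(torch, torch_dtype)
        else:
            self.torch_dtype = torch_dtype

        # The fix from this PR: without flash attention or without a CUDA
        # device, force float32, because bf16 kernels are slow on CPU.
        if not self.use_flash_attn or not torch.cuda.is_available():
            self.torch_dtype = torch.float32


# On a CPU-only machine this prints torch.float32; on a CUDA machine with
# flash attention enabled it keeps the requested bfloat16.
print(DtypeResolvingConfig(torch_dtype="bfloat16").torch_dtype)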