fix mixed precision loading with recent transformers versions
#39
by
jupyterjazz
- opened
- modeling_xlm_roberta.py +1 -0
modeling_xlm_roberta.py
CHANGED
@@ -404,6 +404,7 @@ class XLMRobertaPreTrainedModel(PreTrainedModel):
|
|
404 |
config_class = XLMRobertaFlashConfig
|
405 |
base_model_prefix = "roberta"
|
406 |
supports_gradient_checkpointing = True
|
|
|
407 |
|
408 |
def _set_gradient_checkpointing(self, module, value=False):
|
409 |
if isinstance(module, XLMRobertaEncoder):
|
|
|
404 |
config_class = XLMRobertaFlashConfig
|
405 |
base_model_prefix = "roberta"
|
406 |
supports_gradient_checkpointing = True
|
407 |
+
_supports_param_buffer_assignment = False
|
408 |
|
409 |
def _set_gradient_checkpointing(self, module, value=False):
|
410 |
if isinstance(module, XLMRobertaEncoder):
|