SinclairSchneider commited on
Commit
dfa4056
1 Parent(s): 52aabf1

Update modeling_dbrx.py

Browse files

gradient_checkpointing fixed

Files changed (1) hide show
  1. modeling_dbrx.py +7 -7
modeling_dbrx.py CHANGED
@@ -1093,13 +1093,13 @@ class DbrxModel(DbrxPreTrainedModel):
1093
  block_outputs = self._gradient_checkpointing_func(
1094
  block.__call__,
1095
  hidden_states,
1096
- attention_mask=causal_mask,
1097
- position_ids=position_ids,
1098
- past_key_values=past_key_values,
1099
- output_attentions=output_attentions,
1100
- output_router_logits=output_router_logits,
1101
- use_cache=use_cache,
1102
- cache_position=cache_position,
1103
  )
1104
  else:
1105
  block_outputs = block(
 
1093
  block_outputs = self._gradient_checkpointing_func(
1094
  block.__call__,
1095
  hidden_states,
1096
+ causal_mask,
1097
+ position_ids,
1098
+ past_key_values,
1099
+ output_attentions,
1100
+ output_router_logits,
1101
+ use_cache,
1102
+ cache_position,
1103
  )
1104
  else:
1105
  block_outputs = block(