Fix compatibility with transformers 5.0

#8

Fix three breaking changes introduced by transformers>=5.0:

  1. Remove is_torch_fx_available import — This function was removed from transformers.utils.import_utils in v5.0. The torch.fx wrapping of _prepare_4d_causal_attention_mask is no longer needed.

  2. Fix _init_rope rope_scaling key access — transformers 5.0 renamed rope_scaling["type"] to rope_scaling["rope_type"] and may auto-populate rope_scaling even when the original config has it as null. Use .get() with fallback for both key names, default factor to 1.0, and fall back to default DeepseekRotaryEmbedding for unknown scaling types instead of raising.

  3. Add **kwargs to forward() signatures — transformers 5.0 passes additional keyword arguments (e.g., loss_mask) to model forward methods. Without **kwargs, these cause TypeError.
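
For reviewers, here is a minimal sketch of the logic behind changes 2 and 3. It is not the actual diff. The helper name resolve_rope_scaling, the set of known scaling types, and the TinyModel class are hypothetical and only illustrate the pattern; the real code builds rotary embedding modules rather than returning a tuple.

```python
from typing import Any, Optional, Tuple

# Assumption: the scaling types the model code knows how to build.
# The real set lives in the model's _init_rope and may differ.
KNOWN_SCALING_TYPES = {"linear", "dynamic", "yarn"}


def resolve_rope_scaling(rope_scaling: Optional[dict]) -> Tuple[str, float]:
    """Return (scaling_type, factor), tolerating both key spellings.

    Hypothetical helper: "default" stands for the plain rotary embedding.
    """
    if not rope_scaling:
        # Config has rope_scaling set to null or an empty dict.
        return "default", 1.0

    # Prefer the newer "rope_type" key, fall back to the older "type" key.
    scaling_type = rope_scaling.get("rope_type", rope_scaling.get("type"))

    # Default the factor to 1.0 when it is missing or None.
    factor = rope_scaling.get("factor")
    if factor is None:
        factor = 1.0

    if scaling_type not in KNOWN_SCALING_TYPES:
        # Unknown or auto-populated type: use the default embedding
        # instead of raising.
        return "default", 1.0
    return scaling_type, float(factor)


class TinyModel:
    """Hypothetical stand-in for a model class, not a real transformers model."""

    def forward(self, input_ids, attention_mask=None, **kwargs: Any):
        # Extra keyword arguments (e.g. loss_mask) are accepted and ignored,
        # so callers that pass them no longer trigger a TypeError.
        return len(input_ids)
```

With this pattern, a config that uses either key name resolves to the same scaling type, and a forward() call that receives an unexpected keyword argument still runs.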

Ref: https://github.com/huggingface/transformers/issues/44561

I ran into the same issue. I hope they fix it.

