Can't repro MMLU: sliding window attention implementation seems broken

#11
by dzhulgakov - opened

From running quick evals, I'm getting only 58% on MMLU for the base model and 69% for instruct. Compared with our implementation at Fireworks (https://fireworks.ai/models/fireworks/gemma2-9b-it), the numerics diverge at the sliding-window attention layers (i.e., after the 2nd, 4th, etc. layers).
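For context on why this shows up only in alternating layers: Gemma 2 interleaves full attention and sliding-window attention, so only every other layer restricts each query to the last `sliding_window` key positions. Here is a toy sketch of the masking (plain PyTorch, not the actual transformers code; `window=4096` is the Gemma 2 config default):

```python
import torch

def causal_mask(seq_len: int) -> torch.Tensor:
    # mask[i, j] is True if query position i may attend to key position j
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

def sliding_window_mask(seq_len: int, window: int) -> torch.Tensor:
    # Causal mask further restricted to the most recent `window` key positions.
    dist = torch.arange(seq_len).unsqueeze(1) - torch.arange(seq_len).unsqueeze(0)
    return (dist >= 0) & (dist < window)

# For any prompt shorter than the window the two masks coincide, so the
# sliding-window layers should compute exactly the same thing as full attention.
assert torch.equal(causal_mask(1024), sliding_window_mask(1024, window=4096))
```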

Disabling the sliding window (which should be equivalent, since MMLU prompts are shorter than the window) brings the results back to 71%, e.g.:

+++ transformers/models/gemma2/modeling_gemma2.py
@@ -217,6 +217,7 @@
             base=self.rope_theta,
         )
         self.sliding_window = config.sliding_window if layer_idx % 2 else None
+        self.sliding_window = None

     def forward(
         self,
@@ -611,6 +612,7 @@
         self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

         self.is_sliding = bool(layer_idx % 2)
+        self.is_sliding = False
         self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.sliding_window = config.sliding_window
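
If you'd rather not edit the library source, the same workaround can be applied at runtime by clearing the attributes the diff above touches. A rough sketch (attribute names are taken from the patched modeling code and may change across transformers versions; the model id is just an example):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2-9b-it"  # illustrative; use whichever checkpoint you're evaluating
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)

# Neutralize sliding-window attention on every decoder layer, mirroring the patch above.
for layer in model.model.layers:
    layer.is_sliding = False                # decoder-layer flag gating the sliding mask
    layer.self_attn.sliding_window = None   # attention module falls back to full causal attention

prompt = "The capital of France is"
inputs = tok(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    logits = model(**inputs).logits
print(tok.decode(logits[0, -1].argmax().item()))
```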

I didn't track down the exact bug, but it's worth investigating.

Interesting. I'm having major issues with fine-tuning Gemma2. I'm fine-tuning the 27B -- the base model seems bricked.

Google org

Hello! Are you both using the latest transformers version v4.42.3?

This report predates 4.42.3. Now it works, thanks!

dzhulgakov changed discussion status to closed
