JAX/Flax Implementation

#13 · opened by lemon-mint

DeepMind's Gemma implementation does not appear to have been updated for the new release.

Are there any plans to release the JAX/Flax implementation and model?

Google org

There are! Our focus was on getting the weights out properly. Out of curiosity, why are you interested in Flax/JAX in particular?

> Out of curiosity, why are you interested in Flax/JAX in particular?

I think using TPUs is the most cost-effective way to fully fine-tune the 27B model.

Additionally, the JAX/Flax implementation serves well as a reference implementation. Last time, with Gemma 1, DeepMind's implementation was the only one without bugs.


@canyon289 This would be very convenient. I want to integrate it with our JORA library (JAX-centered LLM PEFT fine-tuning). I believe the only differences from Gemma 1/1.1 are

  • logit soft-capping,
  • sliding-window attention, and
  • query normalization.

Plus, the weights in Flax format (i.e. orbax.checkpoint). A rough sketch of the first two items follows below.
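
For anyone experimenting before the official code lands, here is a minimal JAX sketch of those first two items: logit soft-capping and a sliding-window causal mask. The cap value and window size below are illustrative placeholders, not confirmed Gemma 2 hyperparameters.

```python
import jax.numpy as jnp

def soft_cap(logits, cap):
    # Squash logits smoothly into (-cap, cap); roughly linear near zero.
    return cap * jnp.tanh(logits / cap)

def sliding_window_mask(seq_len, window):
    # Causal mask where each query attends to at most the `window`
    # most recent key positions (itself included).
    q = jnp.arange(seq_len)[:, None]
    k = jnp.arange(seq_len)[None, :]
    return (k <= q) & ((q - k) < window)

# Toy usage with dummy attention logits:
scores = jnp.ones((8, 8))
scores = soft_cap(scores, cap=50.0)        # placeholder cap value
mask = sliding_window_mask(8, window=4)    # placeholder window size
scores = jnp.where(mask, scores, -jnp.inf)
```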

Google org

Thank you both for the answers. There are a couple of other changes, such as GQA (a rough sketch follows below)! Regardless, it's still being worked on and should be out soonish. My apologies for the delay.

JORA looks interesting! I'd suggest adding a link to the paper in the README.
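
In case it helps with porting: GQA just means each key/value head is shared by a group of query heads, which can be implemented by repeating the K/V heads up to the query head count. A minimal JAX sketch, with placeholder head counts rather than the actual Gemma 2 configuration:

```python
import jax.numpy as jnp

def repeat_kv(x, n_rep):
    # Tile [batch, kv_heads, seq, head_dim] so each group of n_rep
    # query heads shares one key/value head.
    return jnp.repeat(x, n_rep, axis=1)

# Placeholder shapes: 8 query heads sharing 4 KV heads (group size 2).
q = jnp.ones((1, 8, 16, 64))   # [batch, q_heads, seq, head_dim]
k = jnp.ones((1, 4, 16, 64))   # [batch, kv_heads, seq, head_dim]
scores = jnp.einsum('bhqd,bhkd->bhqk', q, repeat_kv(k, 8 // 4)) / jnp.sqrt(64.0)
```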
