Copyright 2024 DeepMind Technologies Limited.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

# Fine-tuning the 2B Griffin model with Flax

In this tutorial you will learn how to fine-tune the 2B Griffin model for a simple translation task.

## Setup

In [1]:
!git clone https://github.com/google-deepmind/recurrentgemma.git

Cloning into 'recurrentgemma'...
remote: Enumerating objects: 52, done.
remote: Counting objects: 100% (49/49), done.
remote: Compressing objects: 100% (47/47), done.
remote: Total 52 (delta 16), reused 5 (delta 2), pack-reused 3
Receiving objects: 100% (52/52), 74.57 KiB | 1.01 MiB/s, done.
Resolving deltas: 100% (16/16), done.


In [7]:
# @title Installation
! pip install 'git+https://github.com/google-deepmind/recurrentgemma.git#egg=recurrentgemma[jax]'
! pip install tensorflow-cpu  # Might require a session restart
! pip install --user kaggle
! pip install datasets

In [10]:
# @title Python imports
import pathlib
from typing import Any, Mapping, Iterator
import enum
import functools

# We import JAX and some related packages.
import chex
import jax
import jax.numpy as jnp
import optax



# We import SentencePiece for the tokenizer and RecurrentGemma for the model.
import sentencepiece as spm
from recurrentgemma import jax as recurrentgemma

# We use TensorFlow to handle the dataset.
import tensorflow as tf
import tensorflow_datasets as tfds

ModuleNotFoundError: No module named 'tensorflow'

### Downloading the checkpoint

To use Griffin's checkpoints, you'll need a Kaggle account and API key. Here's how to get them:

1. Visit https://www.kaggle.com/ and create an account.
2. Go to your account settings, then the 'API' section.
3. Click 'Create new token' to download your key.

You will also need to accept the Terms and Conditions of the RecurrentGemma models at https://www.kaggle.com/models/google/recurrentgemma/ before you can download the model weights and the tokenizer.

Then run the cell below.

If everything went well, you should see:
```
Kaggle credentials set.
Kaggle credentials successfully validated.
```
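
Before running a credentials cell, it can help to check that a token is actually visible locally. The Kaggle API client reads credentials either from the `KAGGLE_USERNAME` and `KAGGLE_KEY` environment variables or from `~/.kaggle/kaggle.json`. The snippet below is a minimal sketch under that assumption; the helper name `find_kaggle_credentials` is hypothetical and the function never contacts Kaggle.

```python
import json
import os
import pathlib


def find_kaggle_credentials(env=None, home=None):
  """Returns (username, key) if Kaggle credentials are configured, else None.

  Hypothetical helper for illustration: it only inspects local settings.
  """
  env = os.environ if env is None else env
  home = pathlib.Path.home() if home is None else pathlib.Path(home)

  # Option 1: credentials passed through environment variables.
  if env.get('KAGGLE_USERNAME') and env.get('KAGGLE_KEY'):
    return env['KAGGLE_USERNAME'], env['KAGGLE_KEY']

  # Option 2: the token file downloaded from the 'API' section of the settings.
  token_file = home / '.kaggle' / 'kaggle.json'
  if token_file.is_file():
    token = json.loads(token_file.read_text())
    if token.get('username') and token.get('key'):
      return token['username'], token['key']

  return None


found = find_kaggle_credentials()
print('Credentials found.' if found else 'No credentials found.')
```

If this prints `No credentials found.`, place the downloaded `kaggle.json` in `~/.kaggle/` or set the two environment variables before continuing.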

Now select and download the checkpoint you want to try. The 2b model can fit in memory for fine-tuning.

You need to visit the Kaggle page and agree to the model's terms before downloading.

In [11]:
!git clone https://huggingface.co/yingbei/recurrentg-2b-it


fatal: destination path 'recurrentg-2b-it' already exists and is not an empty directory.




In [13]:
VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
weights_dir = pathlib.Path("./recurrentg-2b-it")
ckpt_path = weights_dir / VARIANT
vocab_path = weights_dir / 'tokenizer.model'

## Step 1: prepare the dataset



In [None]:
from datasets import load_dataset
code_sharegpt = load_dataset("sanjay920/code74k-sharegpt")

In [None]:
code_sharegpt["train"][0]["conversations"]

In [None]:
import json
chat_prefix = "<start_of_turn>"
chat_suffix = "<end_of_turn>"
user_role = "user\n"
preprocessed_code_sharegpt_data = []
for itor in code_sharegpt["train"]:
  c = itor["conversations"]
  c = json.loads(c)
  assert c[-1]["from"] == "gpt"
  assert c[0]["from"] == "human"
  assert len(c) == 2
  # Avoid shadowing the built-in `input`.
  prompt = chat_prefix + user_role + c[0]["value"] + chat_suffix
  response = c[1]["value"]
  preprocessed_code_sharegpt_data.append({"input": prompt, "output": response})

print(json.dumps(preprocessed_code_sharegpt_data[0], indent=4))
print(len(preprocessed_code_sharegpt_data))
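
To see what the preprocessing above produces, here is a minimal sketch on a single hand-written record. The record is hypothetical (not taken from the dataset) and the helper name `to_training_pair` is an assumption made for illustration; the turn markers and the `from`/`value` fields are the ones used in the cell above.

```python
import json

CHAT_PREFIX = '<start_of_turn>'
CHAT_SUFFIX = '<end_of_turn>'


def to_training_pair(conversation_json: str) -> dict[str, str]:
  """Turns a two-message ShareGPT-style conversation into an input/output pair."""
  turns = json.loads(conversation_json)
  if len(turns) != 2:
    raise ValueError(f'Expected 2 turns, got {len(turns)}')
  human, gpt = turns
  if human['from'] != 'human' or gpt['from'] != 'gpt':
    raise ValueError('Expected a human turn followed by a gpt turn')
  # Only the user turn is wrapped in chat markers; the response stays raw.
  prompt = CHAT_PREFIX + 'user\n' + human['value'] + CHAT_SUFFIX
  return {'input': prompt, 'output': gpt['value']}


# Hypothetical record, written by hand for illustration.
record = json.dumps([
    {'from': 'human', 'value': 'Reverse a string in Python.'},
    {'from': 'gpt', 'value': 's[::-1]'},
])
pair = to_training_pair(record)
print(repr(pair['input']))
print(repr(pair['output']))
```

Raising `ValueError` instead of using `assert` is a design choice of this sketch: it keeps the checks active even when Python runs with optimizations enabled.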


In [None]:

def load_custom_data(data):
    # convert list of dicts to tfds dataset format
    def preprocess(item):
        # Convert your item here, e.g., tokenize text
        return {
            'src': item['input'],  # Assume these are already preprocessed
            'dst': item['output'],
        }

    # Create a Dataset from the list of dictionaries
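    # `output_signature` replaces the deprecated `output_types` argument.
    ds = tf.data.Dataset.from_generator(
        lambda: (preprocess(item) for item in data),
        output_signature={
            'src': tf.TensorSpec(shape=(), dtype=tf.string),
            'dst': tf.TensorSpec(shape=(), dtype=tf.string),
        })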

    # Further dataset operations (batching, padding, etc.) go here
    # For example, to batch:
    # ds = ds.batch(2)

    return ds

### Tokenizer

Let's start by loading our vocabulary base tokenizer, which we'll construct using the [SentencePiece](https://github.com/google/sentencepiece) library.

In [None]:
vocab = spm.SentencePieceProcessor()
vocab.Load(str(vocab_path))

Let's customize `SentencePieceProcessor` for our English-to-French translation task. Since we're fine-tuning the English-only Griffin 2B model, we need a few adjustments:

- **Input Prefix**: Adding a common prefix to each input signals the translation task. For example, we could use a prompt like `Translate this into French: [INPUT_SENTENCE]`.

- **Translation Start suffix**: Adding a suffix at the end of each prompt tells the model exactly when to begin the translation. A newline should do the job.

- **LM Tokens**: Griffin models expect a *beginning of sequence* token at the beginning of each sequence. Similarly, we need to add an *end of sequence* token at the end of each training example.
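
The three adjustments above can be sketched with a toy tokenizer. This is an illustration only: `ToyTokenizer` is a hypothetical whitespace tokenizer standing in for SentencePiece, and the special-token ids are placeholders chosen for the example, not values read from the real vocabulary.

```python
BOS_ID, EOS_ID = 2, 1  # Placeholder ids for this toy example.


class ToyTokenizer:
  """Whitespace tokenizer standing in for SentencePiece (illustration only)."""

  def __init__(self):
    self._ids = {}

  def encode(self, text: str) -> list[int]:
    # Assign ids on first sight, starting after the special tokens.
    # Splitting on whitespace drops the newline suffix; a real tokenizer
    # would keep it as a token.
    return [self._ids.setdefault(w, len(self._ids) + 3) for w in text.split()]


def tokenize(tok, example, prefix='', suffix='', add_eos=True):
  """Prepends BOS, encodes prefix + example + suffix, optionally appends EOS."""
  ids = [BOS_ID] + tok.encode(prefix + example + suffix)
  if add_eos:
    ids.append(EOS_ID)
  return ids


tok = ToyTokenizer()
# Source: prefix and suffix are added, and there is no EOS because the
# model must continue after the prompt.
src = tokenize(tok, 'Hello world', prefix='Translate this into French: ',
               suffix='\n', add_eos=False)
# Destination: no prefix or suffix, but an EOS so the model learns to stop.
dst = tokenize(tok, 'Bonjour le monde', add_eos=True)
print(src)
print(dst)
```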

In [None]:
class GriffinTokenizer:
  """Custom wrapper around a SentencePieceProcessor for tensorflow."""

  def __init__(self, spm_processor: spm.SentencePieceProcessor):
    self._spm_processor = spm_processor

  @property
  def pad_id(self) -> int:
    """Fast access to the pad id."""
    return self._spm_processor.pad_id()

  def tokenize(
      self,
      example: str | bytes,
      prefix: str = '',
      suffix: str = '',
      add_eos: bool = True,
  ) -> jax.Array:
    """
    Tokenization function.

    Args:
      example: input string to tokenize.
      prefix:  prefix to add to the input string.
      suffix:  suffix to add to the input string.
      add_eos: if True, add an end of sentence token at the end of the output
               sequence.
    Returns:
      Tokens corresponding to the input string.
    """
    int_list = [self._spm_processor.bos_id()]
    int_list.extend(self._spm_processor.EncodeAsIds(prefix + example + suffix))
    if add_eos:
      int_list.append(self._spm_processor.eos_id())

    return jnp.array(int_list, dtype=jnp.int32)

  def tokenize_tf_op(
      self,
      str_tensor: tf.Tensor,
      prefix: str = '',
      suffix: str = '',
      add_eos: bool = True,
  ) -> tf.Tensor:
    """TensorFlow operator for the tokenize function."""
    encoded = tf.numpy_function(
        self.tokenize,
        [str_tensor, prefix, suffix, add_eos],
        tf.int32)
    encoded.set_shape([None])
    return encoded

  def to_string(self, tokens: jax.Array) -> str:
    """Convert an array of tokens to a string."""
    return self._spm_processor.DecodeIds(tokens.tolist())

Now let's try our custom tokenizer on the first two examples of the preprocessed dataset.

In [None]:
def tokenize_source(tokenizer, example: tf.Tensor):
  return tokenizer.tokenize_tf_op(
      example,
      prefix='',
      suffix='\n<start_of_turn>model\n',
      add_eos=False
  )
def tokenize_destination(tokenizer, example: tf.Tensor):
  return tokenizer.tokenize_tf_op(example, add_eos=True)

tokenizer = GriffinTokenizer(vocab)
# ds = tfds.load("mtnt/en-fr",split="train")

# ds = ds.take(2)
# for d in ds:
#   print(d)

ds = load_custom_data(preprocessed_code_sharegpt_data[:2])
print(ds)
ds = ds.map(lambda x: {
    'input': tokenize_source(tokenizer, x['src']),
    'output': tokenize_destination(tokenizer, x['dst'])
  })
ds = ds.as_numpy_iterator()
for idx, example in enumerate(ds):
  print(f'Example {idx}:')
  for key, val in example.items():
    print(f'{key}: {val}')
  print()

### Data loader

We can now wrap everything up and build our data loader.

In [None]:
@chex.dataclass(frozen=True)
class TrainingInput:
  # Input tokens given to the model
  input_tokens: jax.Array

  # A mask that determines which tokens contribute to the target loss
  # calculation.
  target_mask: jax.Array

class DatasetSplit(enum.Enum):
  TRAIN = 'train'
  VALIDATION = 'valid'


class MyDatasetBuilder:
  """Data loader for the preprocessed code ShareGPT data."""

  N_ITEMS = {DatasetSplit.TRAIN: 2000, DatasetSplit.VALIDATION: 100}

  BUFFER_SIZE_SHUFFLE = 1000
  TRANSLATION_PREFIX = ''
  TRANSLATION_SUFFIX = '\n<start_of_turn>model\n'

  def __init__(self,
               tokenizer : GriffinTokenizer,
               max_seq_len: int):
    """Constructor.

    Args:
      tokenizer: Gemma tokenizer to use.
      max_seq_len: size of each sequence in a given batch.
    """
    self._tokenizer = tokenizer
    self._base_data = {
        DatasetSplit.TRAIN: load_custom_data(preprocessed_code_sharegpt_data[:2000]),
        DatasetSplit.VALIDATION: load_custom_data(preprocessed_code_sharegpt_data[-100:]),
    }
    self._max_seq_len = max_seq_len

  def _tokenize_source(self, example: tf.Tensor):
    """Tokenization function for the source."""
    return self._tokenizer.tokenize_tf_op(
        example, prefix=self.TRANSLATION_PREFIX, suffix=self.TRANSLATION_SUFFIX,
        add_eos=False
    )

  def _tokenize_destination(self, example: tf.Tensor):
    """Tokenization function for the destination (the model response)."""
    return self._tokenizer.tokenize_tf_op(example, add_eos=True)

  def _pad_up_to_max_len(self,
                         input_tensor: tf.Tensor,
                         pad_value: int | bool,
                         ) -> tf.Tensor:
    """Pad the given tensor up to sequence length of a batch."""
    seq_len = tf.shape(input_tensor)[0]
    to_pad = tf.maximum(self._max_seq_len - seq_len, 0)
    return tf.pad(
        input_tensor, [[0, to_pad]], mode='CONSTANT', constant_values=pad_value,
    )

  def _to_training_input(
      self,
      src_tokens: jax.Array,
      dst_tokens: jax.Array,
  ) -> TrainingInput:
    """Build a training input from a tuple of source and destination tokens."""

    # The input sequence fed to the model is simply the concatenation of the
    # source and the destination.
    tokens = tf.concat([src_tokens, dst_tokens], axis=0)

    # We want to prevent the model from updating based on the source (input)
    # tokens. To achieve this, we add a target mask to each input.
    q_mask = tf.zeros_like(src_tokens, dtype=tf.bool)
    a_mask = tf.ones_like(dst_tokens, dtype=tf.bool)
    mask = tf.concat([q_mask, a_mask], axis=0)

    # If the output tokens sequence is smaller than the target sequence size,
    # then we pad it with pad tokens.
    tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)

    # We don't want to perform the backward on the pad tokens.
    mask = self._pad_up_to_max_len(mask, False)

    return TrainingInput(input_tokens=tokens, target_mask=mask)


  def get_train_dataset(self, batch_size: int, num_epochs: int):
    """Build the training dataset."""

    # Tokenize each sample
    ds = self._base_data[DatasetSplit.TRAIN].map(
        lambda x : (self._tokenize_source(x['src']),
                    self._tokenize_destination(x['dst']))
    )
    print(ds)

    # Convert them to training inputs
    ds = ds.map(lambda x, y: self._to_training_input(x, y))

    # Remove the samples which are too long
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)

    # Shuffle the dataset
    ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)

    # Repeat if necessary
    ds = ds.repeat(num_epochs)

    # Build batches
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds

  def get_validation_dataset(self, batch_size: int):
    """Build the validation dataset."""

    # Same as the training dataset, but no shuffling and no repetition
    ds = self._base_data[DatasetSplit.VALIDATION].map(
        lambda x : (self._tokenize_source(x['src']),
                    self._tokenize_destination(x['dst']))
    )
    ds = ds.map(lambda x, y: self._to_training_input(x, y))
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds
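
To make the masking and padding logic of `_to_training_input` concrete, here is a minimal pure-Python sketch of the same steps on plain lists. The token ids and the pad id of `0` are made up for the example; the real method works on TensorFlow tensors and takes the pad id from the tokenizer.

```python
def to_training_input(src_tokens, dst_tokens, max_seq_len, pad_id=0):
  """Concatenates source and destination, then builds the loss mask."""
  tokens = list(src_tokens) + list(dst_tokens)
  # Source tokens are masked out; only destination tokens contribute to the loss.
  mask = [False] * len(src_tokens) + [True] * len(dst_tokens)
  # Pad up to max_seq_len. Longer sequences are left untouched here; the
  # dataset builder drops them later with a filter.
  to_pad = max(max_seq_len - len(tokens), 0)
  tokens += [pad_id] * to_pad
  mask += [False] * to_pad  # Pad positions never contribute to the loss.
  return tokens, mask


tokens, mask = to_training_input([2, 10, 11], [2, 20, 1], max_seq_len=8)
print(tokens)
print(mask)
```

Every sequence in a batch then has the same length, and the mask tells the loss function which positions to score.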

### Backup dataset class

In [None]:
class MTNTDatasetBuilder:
  """Data loader for the MTNT dataset."""

  N_ITEMS = {DatasetSplit.TRAIN: 35_692, DatasetSplit.VALIDATION: 811}

  BUFFER_SIZE_SHUFFLE = 10_000
  TRANSLATION_PREFIX = 'Translate this into French:\n'
  TRANSLATION_SUFFIX = '\n'

  def __init__(self,
               tokenizer : GriffinTokenizer,
               max_seq_len: int):
    """Constructor.

    Args:
      tokenizer: Gemma tokenizer to use.
      max_seq_len: size of each sequence in a given batch.
    """
    self._tokenizer = tokenizer
    self._base_data = {
        DatasetSplit.TRAIN: tfds.load("mtnt/en-fr",split="train"),
        DatasetSplit.VALIDATION: tfds.load("mtnt/en-fr",split="valid"),
    }
    self._max_seq_len = max_seq_len

  def _tokenize_source(self, example: tf.Tensor):
    """Tokenization function for the source."""
    return self._tokenizer.tokenize_tf_op(
        example, prefix=self.TRANSLATION_PREFIX, suffix=self.TRANSLATION_SUFFIX,
        add_eos=False
    )

  def _tokenize_destination(self, example: tf.Tensor):
    """Tokenization function for the French translation."""
    return self._tokenizer.tokenize_tf_op(example, add_eos=True)

  def _pad_up_to_max_len(self,
                         input_tensor: tf.Tensor,
                         pad_value: int | bool,
                         ) -> tf.Tensor:
    """Pad the given tensor up to sequence length of a batch."""
    seq_len = tf.shape(input_tensor)[0]
    to_pad = tf.maximum(self._max_seq_len - seq_len, 0)
    return tf.pad(
        input_tensor, [[0, to_pad]], mode='CONSTANT', constant_values=pad_value,
    )

  def _to_training_input(
      self,
      src_tokens: jax.Array,
      dst_tokens: jax.Array,
  ) -> TrainingInput:
    """Build a training input from a tuple of source and destination tokens."""

    # The input sequence fed to the model is simply the concatenation of the
    # source and the destination.
    tokens = tf.concat([src_tokens, dst_tokens], axis=0)

    # We want to prevent the model from updating based on the source (input)
    # tokens. To achieve this, we add a target mask to each input.
    q_mask = tf.zeros_like(src_tokens, dtype=tf.bool)
    a_mask = tf.ones_like(dst_tokens, dtype=tf.bool)
    mask = tf.concat([q_mask, a_mask], axis=0)

    # If the output tokens sequence is smaller than the target sequence size,
    # then we pad it with pad tokens.
    tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)

    # We don't want to perform the backward on the pad tokens.
    mask = self._pad_up_to_max_len(mask, False)

    return TrainingInput(input_tokens=tokens, target_mask=mask)


  def get_train_dataset(self, batch_size: int, num_epochs: int):
    """Build the training dataset."""

    # Tokenize each sample
    ds = self._base_data[DatasetSplit.TRAIN].map(
        lambda x : (self._tokenize_source(x['src']),
                    self._tokenize_destination(x['dst']))
    )

    # Convert them to training inputs
    ds = ds.map(lambda x, y: self._to_training_input(x, y))

    # Remove the samples which are too long
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)

    # Shuffle the dataset
    ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)

    # Repeat if necessary
    ds = ds.repeat(num_epochs)

    # Build batches
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds

  def get_validation_dataset(self, batch_size: int):
    """Build the validation dataset."""

    # Same as the training dataset, but no shuffling and no repetition
    ds = self._base_data[DatasetSplit.VALIDATION].map(
        lambda x : (self._tokenize_source(x['src']),
                    self._tokenize_destination(x['dst']))
    )
    ds = ds.map(lambda x, y: self._to_training_input(x, y))
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds

### Try

Let's give it a try.

In [None]:
dataset_builder = MyDatasetBuilder(tokenizer, max_seq_len=4000)
ds = dataset_builder.get_train_dataset(3, 1)
ds = ds.take(2)
ds = ds.as_numpy_iterator()
for idx, example in enumerate(ds):
  print(f'Example {idx}:')
  for key, val in example.items():
    print(f'{key}: {val}')
  print()

## Fine tuning Griffin

### Getting started

First, let's load the model. Use the `recurrentgemma.GriffinConfig.from_flax_params_or_variables` function to automatically infer the correct configuration from the checkpoint parameters.

In [None]:
# Load parameters
params =  recurrentgemma.load_parameters(ckpt_path, "single_device")
config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(params)
model = recurrentgemma.Griffin(config)

Can our model handle the task before fine-tuning? Let's try it out!

In [None]:
sampler = recurrentgemma.Sampler(model=model, vocab=vocab, params=params)

In [None]:
output = sampler(
  ["Develop a Python code snippet that generates an abbreviated version of a given full name.\nname = 'John Smith'"],
  # number of steps performed when generating
  total_generation_steps=300,
)
print(output.text[0])

As expected, it didn't work. Let's see if we can get better results by fine-tuning.

### Model forward and loss function

The `Griffin` class inherits from [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/flax_basics.html). It offers two essential methods:

- `init`: Initializes the model's parameters.

- `apply`: Executes the model's `__call__` function using a given set of parameters.

Since we are working with pre-trained weights, we won't use the `init` function.

With `apply`, we can now build the `forward_and_loss_fn` function, which performs the forward pass and the loss computation.

In [None]:
def forward_and_loss_fn(
    params,
    *,
    model: recurrentgemma.Griffin,
    input_tokens: jax.Array,            # Shape [B, L]
    input_mask: jax.Array,              # Shape [B, L]
    positions: jax.Array,               # Shape [B, L]
) -> jax.Array:
  """Forward pass and loss function.

  Args:
    params: model's input parameters.
    model: Griffin model to call.
    input_tokens: input tokens sequence, shape [B, L].
    input_mask: mask of the tokens that contribute to the loss, shape [B, L].
    positions: relative position of each token, shape [B, L].

  Returns:
    Softmax cross-entropy loss for the next-token prediction task.
  """
  batch_size = input_tokens.shape[0]
  # Forward pass on the input data.
  # No attention cache is needed here.
  # Exclude the last step as it does not appear in the targets.
  logits, _ = model.apply(
        {"params": params},
        tokens=input_tokens[:, :-1],
        segment_pos=positions[:, :-1],
        cache=None,
    )

  # Similarly, the first token cannot be predicted.
  target_tokens = input_tokens[:, 1:]
  target_mask = input_mask[:, 1:]

  # Convert the target labels into one-hot encoded vectors.
  one_hot = jax.nn.one_hot(target_tokens, logits.shape[-1])

  # Don't update on unwanted tokens.
  one_hot = one_hot * target_mask.astype(one_hot.dtype)[...,None]

  # Normalisation factor.
  norm_factor = batch_size * (jnp.sum(target_mask) + 1e-8)

  # Return the nll loss.
  return -jnp.sum(jax.nn.log_softmax(logits) * one_hot) / norm_factor
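
The masked loss can be checked by hand on a tiny example. The following is a pure-Python sketch, not the JAX implementation: it assumes the logits and targets are already shifted by one position, and it reuses the normalisation from `forward_and_loss_fn` (batch size times the number of unmasked targets). The function names are chosen for this illustration.

```python
import math


def log_softmax(row):
  """Numerically stable log-softmax of a list of logits."""
  m = max(row)
  log_z = m + math.log(sum(math.exp(x - m) for x in row))
  return [x - log_z for x in row]


def masked_nll(logits, targets, mask):
  """Masked negative log-likelihood.

  logits: [B][L][V] floats, targets: [B][L] ints, mask: [B][L] bools.
  """
  batch_size = len(logits)
  total, n_targets = 0.0, 0
  for seq_logits, seq_targets, seq_mask in zip(logits, targets, mask):
    for row, target, keep in zip(seq_logits, seq_targets, seq_mask):
      if keep:  # Masked-out positions are skipped entirely.
        total -= log_softmax(row)[target]
        n_targets += 1
  return total / (batch_size * (n_targets + 1e-8))


# One sequence, two positions, vocabulary of size two.
# Only the second position is scored, and its logits are uniform,
# so the loss is log(2).
logits = [[[5.0, -5.0], [0.0, 0.0]]]
targets = [[0, 1]]
mask = [[False, True]]
loss = masked_nll(logits, targets, mask)
print(loss, math.log(2))
```

Changing the logits at the masked position leaves the loss unchanged, which is exactly what the target mask is for.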

We can now build the `train_step` function, which performs the backward pass and updates the model's parameters accordingly.

In [None]:
Params = Mapping[str, Any]

def get_positions(example: jax.Array, pad_id : int) -> jax.Array:
  """Builds the position vector from the given tokens."""
  pad_mask = example != pad_id
  positions = jnp.cumsum(pad_mask, axis=-1)
  # Subtract one for all positions from the first valid one as they are
  # 0-indexed
  positions = positions - (positions >= 1)
  return positions

@functools.partial(
    jax.jit,
    static_argnames=['model', 'optimizer'],
    donate_argnames=['params', 'opt_state'],
)
def train_step(
    model: recurrentgemma.Griffin,
    params: Params,
    optimizer: optax.GradientTransformation,
    opt_state: optax.OptState,
    pad_id: int,
    example: TrainingInput,
) -> tuple[jax.Array, Params, optax.OptState]:
  """Train step.

  Args:
    model: Griffin model.
    params: model's input parameters.
    optimizer: optax optimizer to use.
    opt_state: input optimizer's state.
    pad_id: id of the pad token.
    example: input batch.

  Returns:
    Training loss, updated parameters, updated optimizer state.
  """

  positions = get_positions(example.input_tokens, pad_id)

  # Forward and backward passes
  train_loss, grads = jax.value_and_grad(forward_and_loss_fn)(
      params,
      model=model,
      input_tokens=example.input_tokens,
      input_mask=example.target_mask,
      positions=positions,
  )
  # Update the parameters
  updates, opt_state = optimizer.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)

  return train_loss, params, opt_state
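
The `get_positions` helper is easiest to understand on a small example. Below is a pure-Python sketch of the same computation for a single sequence (a running count of non-pad tokens, shifted to be 0-indexed); the token ids and the pad id of `0` are made up for illustration.

```python
def get_positions(tokens, pad_id):
  """Position of each token, counting only non-pad tokens (0-indexed)."""
  positions, count = [], 0
  for token in tokens:
    count += int(token != pad_id)  # Cumulative sum of the non-pad mask.
    # Subtract one once the first valid token has been seen.
    positions.append(count - (1 if count >= 1 else 0))
  return positions


# Right-padded sequence: pad tokens repeat the last valid position.
right_padded = get_positions([2, 10, 11, 1, 0, 0], pad_id=0)
# Left-padded sequence: leading pads stay at position 0.
left_padded = get_positions([0, 0, 5, 6], pad_id=0)
print(right_padded)
print(left_padded)
```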

Similarly, we build a `validation_step` function without backward pass.

In [None]:
@functools.partial(jax.jit, static_argnames=['model'])
def validation_step(
    model: recurrentgemma.Griffin,
    params: Params,
    pad_id: int,
    example: TrainingInput,
) -> jax.Array:
  return forward_and_loss_fn(
      params,
      model=model,
      input_tokens=example.input_tokens,
      input_mask=example.target_mask,
      positions=get_positions(example.input_tokens, pad_id),
  )

And now the training loop itself.

In [None]:
def train_loop(
    model: recurrentgemma.Griffin,
    params: Params,
    optimizer: optax.GradientTransformation,
    train_ds: Iterator[TrainingInput],
    validation_ds: tf.data.Dataset,
    num_steps: int | None = None,
    eval_every_n: int = 20,
):
  opt_state = jax.jit(optimizer.init)(params)

  step_counter = 0
  avg_loss=0

  # A first round of validation loss
  n_steps_eval = 0
  eval_loss = 0
  for val_example in validation_ds.as_numpy_iterator():
    eval_loss += validation_step(
        model, params, dataset_builder._tokenizer.pad_id, val_example
    )
    n_steps_eval += 1
  print(f"Start, validation loss: {eval_loss/n_steps_eval}")

  for train_example in train_ds:
    train_loss, params, opt_state = train_step(
        model=model,
        params=params,
        optimizer=optimizer,
        opt_state=opt_state,
        pad_id=dataset_builder._tokenizer.pad_id,
        example=train_example,
    )

    step_counter += 1
    avg_loss += train_loss
    if step_counter % eval_every_n == 0:
      eval_loss = 0

      n_steps_eval = 0
      val_iterator = validation_ds.as_numpy_iterator()
      for val_example in val_iterator:
        eval_loss += validation_step(
            model,
            params,
            dataset_builder._tokenizer.pad_id,
            val_example,
        )
        n_steps_eval +=1
      avg_loss /= eval_every_n
      eval_loss /= n_steps_eval
      print(f"STEP {step_counter} training loss: {avg_loss} - eval loss: {eval_loss}")
      avg_loss=0
    if num_steps is not None and step_counter >= num_steps:
      break
  return params

Here you have to choose an optimizer. For devices with less memory (like the T4 GPU), we suggest using SGD, as it has a much lower memory footprint. For the best fine-tuning performance, we suggest trying AdamW. We provide tuned hyperparameters for each optimizer, for the task in this notebook and the '2b-it' checkpoint.
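
The difference in memory footprint comes from the optimizer state: plain SGD without momentum keeps no per-parameter state, while AdamW keeps two moment estimates per parameter. The back-of-the-envelope sketch below makes this concrete. The parameter count is a hypothetical round number, the state is assumed to be stored in 32-bit floats, and gradients and activations are ignored.

```python
def optimizer_state_bytes(num_params, bytes_per_value, optimizer):
  """Rough size of the per-parameter optimizer state, in bytes."""
  # Number of extra buffers of the same shape as the parameters.
  buffers_per_param = {'sgd': 0, 'adamw': 2}[optimizer]
  return num_params * bytes_per_value * buffers_per_param


NUM_PARAMS = 2_000_000_000  # Hypothetical round number for a "2B" model.
GIB = 1024 ** 3

for name in ('sgd', 'adamw'):
  size = optimizer_state_bytes(NUM_PARAMS, bytes_per_value=4, optimizer=name)
  print(f'{name}: {size / GIB:.1f} GiB of optimizer state')
```

Under these assumptions AdamW needs roughly twice the size of the float32 parameters in additional memory, which is why SGD is the safer choice on small accelerators.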

In [None]:
def griffin_weight_decay_mask(params_like: optax.Params) -> Any:
  # Don't put weight decay on the RGLRU, the embeddings and any biases
  def enable_weight_decay(path: list[Any], _: Any) -> bool:
    # Parameters in the LRU and embedder
    path = [dict_key.key for dict_key in path]
    if 'rg_lru' in path or 'embedder' in path:
      return False
    # All biases and scales
    if path[-1] in ('b', 'scale'):
      return False
    return True

  return jax.tree_util.tree_map_with_path(enable_weight_decay, params_like)

optimizer_choice = "adamw" #@param ["sgd", "adamw"]

if optimizer_choice == "sgd":
  optimizer = optax.sgd(learning_rate=1e-3)
  num_steps = 300
elif optimizer_choice == "adamw":
  optimizer = optax.adamw(
        learning_rate=1e-4,
        b2=0.96,
        eps=1e-8,
        weight_decay=0.1,
        mask=griffin_weight_decay_mask,
    )
  num_steps = 100
else:
  raise ValueError(f"Unknown optimizer: {optimizer_choice}")
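
To see which parameters the weight decay mask selects, here is a pure-Python sketch of the same rule applied to a small nested dictionary. The parameter names in `toy_params` are hypothetical and only mimic the structure of a parameter tree; the real function walks the tree with `jax.tree_util.tree_map_with_path`.

```python
def weight_decay_mask(tree, path=()):
  """Returns a tree of bools: True where weight decay should be applied."""
  if isinstance(tree, dict):
    return {k: weight_decay_mask(v, path + (k,)) for k, v in tree.items()}
  # No weight decay for anything inside the RG-LRU or the embedder.
  if 'rg_lru' in path or 'embedder' in path:
    return False
  # No weight decay for biases and scales.
  return path[-1] not in ('b', 'scale')


# Hypothetical parameter tree, for illustration only.
toy_params = {
    'embedder': {'input_embedding': 0.0},
    'block_0': {
        'rg_lru': {'a_param': 0.0},
        'mlp': {'w': 0.0, 'b': 0.0},
        'norm': {'scale': 0.0},
    },
}

decay_mask = weight_decay_mask(toy_params)
print(decay_mask)
```

Only the `w` matrix of the MLP is decayed in this toy tree; everything else is excluded by one of the two rules.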

Finally, we prepare the training and validation datasets.

In [None]:
# Small seq size so that everything fits in memory
num_epochs = 1 #@param {type: "integer"}
batch_size = 1 #@param {type: "integer"}
sequence_length = 4000 #@param {type: "integer"}

# Make the dataset builder
tokenizer = GriffinTokenizer(vocab)
dataset_builder= MTNTDatasetBuilder(tokenizer, sequence_length + 1)

# Build the training dataset
train_ds = dataset_builder.get_train_dataset(
    batch_size=batch_size,
    num_epochs=num_epochs,
).as_numpy_iterator()

# Build the validation dataset, with a limited number of samples for this demo
validation_ds = dataset_builder.get_validation_dataset(
    batch_size=batch_size,
).take(50)

We can now fine-tune our model for a limited number of steps.

In [None]:
trained_params = train_loop(
    model=model,
    params=params,
    optimizer=optimizer,
    train_ds=train_ds,
    validation_ds=validation_ds,
    num_steps=num_steps,
)

Both the training and validation losses are going down. But is it working?

Let's try again with our previous example. To ensure our input matches the training format, remember to use the prefix 'Translate this into French:\n' and a newline character at the end. This signals the model to begin translation.

In [None]:
sampler.params = trained_params
output = sampler(
    ["Translate this into French:\nHello, my name is Morgane.\n"],
    total_generation_steps=30,
)
print(output.text[0])