Fine-Tune Gemma with ChatML and Transformer Reinforcement Learning

by Ateeqq
  • Recent Release: Google introduced Gemma just last month.
  • Open-Source LLMs: Gemma is a family of open-source large language models.
  • Multiple Variants: It comes in two sizes: a 7B-parameter model for powerful devices and a 2B-parameter model for more resource-constrained environments.
  • Deployment Options: Each size offers base and fine-tuned versions for various use cases.

Introduction: Gemma Open Models

Gemma Open Models represent a family of advanced, lightweight models crafted using cutting-edge research and technology similar to that employed in the creation of the Gemini models. Developed by Google DeepMind in conjunction with various teams across Google, Gemma draws inspiration from Gemini, with its name derived from the Latin word "gemma," signifying "precious stone."

However, following its initial release, it became apparent that Gemma posed challenges when fine-tuned using the ChatML format, which is widely adopted by the open-source community in projects like OpenHermes and Dolphin. This blog post will guide you through the process of fine-tuning Gemma using ChatML and TRL.

What is ChatML?

ChatML, a new way to structure information for large language models, simplifies conversations between humans and AI by differentiating between user input, AI responses, and system messages. While designed for chat, it can also be used for other LLM tasks and is envisioned as a standard for LLM interaction, making it easier for developers to build applications.
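For illustration, a short exchange in raw ChatML looks like this; each turn is wrapped in <|im_start|> and <|im_end|> markers, with the role name on the opening line:

<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
What is the capital of France?<|im_end|>
<|im_start|>assistant
The capital of France is Paris.<|im_end|>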

What is Hugging Face TRL?

Hugging Face TRL, or Transformer Reinforcement Learning, is a library designed to train transformer-based language models using reinforcement learning techniques. In simpler terms, it allows you to fine-tune these powerful language models by rewarding them for desired behaviors. TRL integrates with the popular Transformers library, so you can leverage a wide range of pre-trained models.

Check: "Fine-Tune LLMs with Hugging Face," specifically tailored to the fine-tuning of Gemma 7B.

We'll leverage Hugging Face TRL, Transformers, and datasets to achieve our objectives.

Setting Up the Development Environment:

Our initial step involves installing the Hugging Face libraries and PyTorch, including trl, transformers, and datasets. If you're unfamiliar with trl, it's a library built on top of transformers and datasets that streamlines the fine-tuning and alignment of open LLMs.

! pip install transformers==4.38.2 datasets==2.16.1 evaluate==0.4.1 bitsandbytes==0.42.0 trl==0.7.11 peft==0.8.2 accelerate==0.26.1

If you're using a GPU with the Ampere architecture or newer, such as an NVIDIA A10G or RTX 4090/3090, you can leverage Flash Attention. Flash Attention optimizes the attention computation, significantly boosting speed and reducing memory usage, with reported training speedups of up to 3x. You can learn more in the Flash Attention repository.

import torch
assert torch.cuda.get_device_capability()[0] >= 8, 'Hardware not supported for Flash Attention'

# Install flash-attn
!pip install ninja packaging
!MAX_JOBS=4 pip install flash-attn --no-build-isolation --upgrade

Please note that installing flash attention from source may take some time, ranging from 10 to 45 minutes.
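If your GPU does not pass the compute-capability check above, one reasonable fallback (my suggestion, not part of the original post; it assumes transformers >= 4.36) is PyTorch's built-in scaled-dot-product attention when loading the model later on:

import torch
from transformers import AutoModelForCausalLM

# Fallback for pre-Ampere GPUs: use PyTorch SDPA instead of flash_attention_2,
# and float16 instead of bfloat16 (older GPUs lack native bfloat16 support)
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-7b",
    device_map="auto",
    attn_implementation="sdpa",
    torch_dtype=torch.float16,
)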

Additionally, you'll need to log in to your Hugging Face account to access Gemma. Gemma is a gated model, so before using it you must accept the terms of use on its model page.

from huggingface_hub import login

login(
  token="", # ADD YOUR TOKEN HERE
  add_to_git_credential=True
)

Creating and Preparing the Dataset:

For this blog post, we won't delve into dataset creation. We'll utilize the Databricks Dolly dataset, already formatted as messages, enabling us to fine-tune our model in a conversational format.

from datasets import load_dataset

# Load Dolly Dataset
dataset = load_dataset("philschmid/dolly-15k-oai-style", split="train")

print(dataset[3]["messages"])
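Each record is a list of role/content dictionaries in the OpenAI messages format, so the printed sample should look roughly like this (illustrative structure, not actual output):

# [{'role': 'user', 'content': '...'},
#  {'role': 'assistant', 'content': '...'}]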

Fine-Tuning the LLM with trl and the SFTTrainer:

We will use the SFTTrainer from trl to fine-tune our model. The SFTTrainer simplifies supervised fine-tuning of open LLMs. It's a subclass of the Trainer from the transformers library, offering additional quality-of-life features such as:

  • Dataset formatting, including conversational and instruction format
  • Training on completions only, ignoring prompts (see the sketch after this list)
  • Packing datasets for more efficient training
  • PEFT (parameter-efficient fine-tuning) support, including Q-LoRA
  • Preparing the model and tokenizer for conversational fine-tuning (e.g., adding special tokens)
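A quick aside on training on completions only: this post uses packing instead (the two are mutually exclusive in trl), but a minimal sketch of the completion-only setup with trl's DataCollatorForCompletionOnlyLM would look like this, assuming the ChatML assistant header as the response template:

from trl import DataCollatorForCompletionOnlyLM

# Mask everything before the assistant turn so the loss is computed on completions only.
# Note: this collator requires packing=False in the SFTTrainer.
response_template = "<|im_start|>assistant\n"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)
# Pass data_collator=collator to the SFTTrainer instead of packing=True.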

For example, Gemma, which comes with a vocabulary of ~250,000 tokens, requires inputs to start with a <bos> token. We'll prepare our model and tokenizer accordingly:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Hugging Face model id
model_id = "google/gemma-7b"
tokenizer_id = "philschmid/gemma-tokenizer-chatml"

# BitsAndBytesConfig int-4 config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config
)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
tokenizer.padding_side = 'right' # to prevent warnings
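To verify that the ChatML tokenizer is wired up correctly, you can render a toy conversation through its chat template; this is just a quick sanity check, and the exact string depends on the tokenizer's template:

# Sanity check: the output should wrap turns in <|im_start|>/<|im_end|> markers
messages = [{"role": "user", "content": "Hello, who are you?"}]
print(tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))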

The SFTTrainer also supports a native integration with PEFT, making it super easy to efficiently tune LLMs using methods like QLoRA. We only need to create our LoraConfig and provide it to the trainer:

from peft import LoraConfig

# LoRA config based on the QLoRA paper & Sebastian Raschka's experiments
peft_config = LoraConfig(
    lora_alpha=8,
    lora_dropout=0.05,
    r=6,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
)

Before starting the training, we define the hyperparameters (TrainingArguments):

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="gemma-7b-dolly-chatml", # directory to save and repository id
    num_train_epochs=3,                     # number of training epochs
    per_device_train_batch_size=2,          # batch size per device during training
    gradient_accumulation_steps=2,          # number of steps before performing a backward/update pass
    gradient_checkpointing=True,            # use gradient checkpointing to save memory
    optim="adamw_torch_fused",              # use fused adamw optimizer
    logging_steps=10,                       # log every 10 steps
    save_strategy="epoch",                  # save checkpoint every epoch
    bf16=True,                              # use bfloat16 precision
    tf32=True,                              # use tf32 precision
    learning_rate=2e-4,                     # learning rate, based on QLoRA paper
    max_grad_norm=0.3,                      # max gradient norm based on QLoRA paper
    warmup_ratio=0.03,                      # warmup ratio based on QLoRA paper
    lr_scheduler_type="constant",           # use constant learning rate scheduler
    push_to_hub=False,                      # whether to push the model to the hub (disabled here)
    report_to="tensorboard",                # report metrics to tensorboard
)
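As a side note, with per_device_train_batch_size=2 and gradient_accumulation_steps=2, the effective batch size works out to 4 samples per optimizer step on a single GPU.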

With these building blocks, we create our SFTTrainer to start training the model:

from trl import SFTTrainer

max_seq_length = 1512 # max sequence length for model and packing of the dataset

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset,
    peft_config=peft_config,
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    packing=True,
    dataset_kwargs={
        "add_special_tokens": False, # We template with special tokens
        "append_concat_token": False, # No need to add additional separator token
    }
)
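Optionally, you can confirm how few parameters LoRA actually trains; at this point the trainer has already wrapped the model with the PEFT config:

# Report trainable vs. total parameters of the LoRA-wrapped model
trainer.model.print_trainable_parameters()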

# Start training
trainer.train()

# Save model
trainer.save_model()

The training with Flash Attention for 3 epochs on the 15k-sample dataset took approximately 4 hours and 14 minutes on a g5.2xlarge instance, at a cost of around $5.30.

Testing the Model and Running Inference:

After training, we want to test our model. We'll run a couple of sample prompts through a simple loop and manually inspect the model's responses.

from transformers import AutoTokenizer, pipeline
from peft import AutoPeftModelForCausalLM

# Free memory
del model
del trainer
torch.cuda.empty_cache()

# Directory where trainer.save_model() stored the adapter (the output_dir above)
peft_model_id = "gemma-7b-dolly-chatml"

# Load adapted model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
model = AutoPeftModelForCausalLM.from_pretrained(peft_model_id, device_map="auto", torch_dtype=torch.float16)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

# Get token id for end of conversation (the ChatML end-of-turn token)
eos_token = tokenizer("<|im_end|>", add_special_tokens=False)["input_ids"][0]
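Optionally, you can also merge the LoRA adapter into the base weights and save a standalone model for faster inference. This is a sketch, with a hypothetical output directory; the merged model no longer needs peft at load time:

# Optional: fold the adapter into the base model and save the result
merged_model = model.merge_and_unload()
merged_model.save_pretrained("gemma-7b-dolly-chatml-merged", safe_serialization=True)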

Let's test some sample prompts and see how the model performs.

prompts = [
    "What is Fine Tuning? Explain why thats the case and if it was different in the past?",
    "Write a Python function to calculate the factorial of a number.",
]

def test_inference(prompt):
    prompt = pipe.tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
    outputs = pipe(prompt, max_new_tokens=1024, do_sample=True, temperature=0.7, top_k=50, top_p=0.95, eos_token_id=eos_token)
    return outputs[0]['generated_text'][len(prompt):].strip()

# Test inference
for prompt in prompts:
    print(f"    Prompt:\n{prompt}")
    print(f"    Response:\n{test_inference(prompt)}")
    print("-"*50)

This concludes the setup, fine-tuning, and evaluation process for the Gemma AI model using ChatML and Hugging Face TRL.

See also: an easy way to fine-tune Gemma: https://exnrt.com/blog/ai/finetune-gemma-with-huggingface-transformers/

Additional Information:

What is Fine-Tuning?

Fine-tuning refers to the process of taking a pre-trained machine learning model and further training it on a specific task or dataset to adapt it to a new task or improve its performance on a related task.

In fine-tuning, the pre-trained model's parameters are adjusted by continuing the training process with new data, typically using a smaller learning rate than during the original training phase. This allows the model to retain the general knowledge learned during pre-training while also learning task-specific features from the new data.

Fine-tuning is commonly used in transfer learning, where models trained on large datasets for generic tasks (such as image classification or natural language understanding) are adapted to perform specific tasks with smaller, task-specific datasets. This approach can save significant time and computational resources compared to training a new model from scratch. Fine-tuning is widely used in various fields, including computer vision, natural language processing, and speech recognition.

It would be nice if you referenced my blog: https://www.philschmid.de/fine-tune-llms-in-2024-with-trl. It seems like you just copied the content.
