Distilled Google Gemma-2-2b-it


Model Description

This model is a distilled version of Google's Gemma-2-2b-it, created through knowledge distillation from the larger Gemma-2-9b-it model. The distillation process was performed using arcee-ai DistilKit, focusing on preserving the capabilities of the larger model in a more compact form.

Key Features

  • Base (Student) Model: Google Gemma-2-2b-it
  • Teacher Model: Google Gemma-2-9b-it
  • Distillation Tool: arcee-ai DistilKit
  • Training Data: Subset of mlabonne/Tome dataset (30,000 rows)
  • Distillation Method: Logit-based distillation

Distillation Process

The distillation process transferred knowledge from the larger Gemma-2-9b-it model (the teacher) to the smaller Gemma-2-2b-it model (the student). It was carried out with arcee-ai DistilKit, and two aspects of that setup are central:

  1. Logit-based Distillation: This method trains the student model (Gemma-2-2b-it) to match the next-token probability distribution of the teacher model (Gemma-2-9b-it).

  2. Architectural Consistency: The teacher and student belong to the same Gemma 2 architecture family and share a tokenizer, so their logits are defined over the same vocabulary. This is what makes direct logit-based distillation possible.
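The logit-based objective can be sketched in PyTorch as follows. This is a minimal illustration of standard temperature-scaled logit distillation (the KL divergence from the teacher's softened distribution to the student's), not DistilKit's own code; the function name `logit_distillation_loss` and the default temperature of 2.0 are hypothetical choices, not settings reported for this model. It assumes teacher and student share a vocabulary, so the last dimension of both logit tensors lines up.

```python
import torch
import torch.nn.functional as F


def logit_distillation_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    temperature: float = 2.0,  # hypothetical default, not taken from this model's training
) -> torch.Tensor:
    """Temperature-scaled KL(teacher || student), averaged over token positions.

    Both tensors have shape (batch, seq_len, vocab_size) and must use the same
    vocabulary ordering. Padding positions are not masked here; a real training
    loop would typically drop them first.
    """
    vocab_size = student_logits.size(-1)
    # Soften both distributions with the same temperature, then flatten to (tokens, vocab).
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1).reshape(-1, vocab_size)
    teacher_log_probs = F.log_softmax(teacher_logits / temperature, dim=-1).reshape(-1, vocab_size)
    # With log_target=True, kl_div computes sum p_t * (log p_t - log p_s);
    # "batchmean" divides that sum by the number of rows (token positions).
    kl = F.kl_div(student_log_probs, teacher_log_probs, reduction="batchmean", log_target=True)
    # Scaling by T^2 keeps gradient magnitudes comparable across temperatures.
    return kl * temperature**2


if __name__ == "__main__":
    torch.manual_seed(0)
    vocab_size = 32  # toy size for illustration only
    student = torch.randn(2, 5, vocab_size, requires_grad=True)
    with torch.no_grad():  # the teacher is frozen during distillation
        teacher = torch.randn(2, 5, vocab_size)
    loss = logit_distillation_loss(student, teacher)
    loss.backward()  # gradients flow only into the student logits
    print(f"distillation loss: {loss.item():.4f}")
```

In a full training loop this term is commonly mixed with the student's ordinary next-token cross-entropy, for example `alpha * kd_loss + (1 - alpha) * ce_loss`; the weighting and temperature used for this particular model are not stated in the card.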

Dataset

The model was trained on a subset of the mlabonne/Tome dataset, using 30,000 rows because of computational constraints. The dataset was chosen for its quality and its relevance to the model's target tasks.

Model Limitations

While this distilled model retains much of the capability of its larger counterpart, users should be aware of potential limitations:

  • Slightly reduced performance compared with the teacher model, Gemma-2-9b-it
  • Limited to the scope of tasks covered in the training data
  • May not perform as well on highly specialized or domain-specific tasks

Usage

Below are code snippets to help you get started quickly with the model. First, install the Transformers library:

pip install -U transformers

Then copy the snippet from the section relevant to your use case.

Running with the pipeline API

import torch
from transformers import pipeline

pipe = pipeline(
    "text-generation",
    model="Syed-Hasan-8503/Gemma-2-2b-it-distilled",
    model_kwargs={"torch_dtype": torch.bfloat16},
    device="cuda",  # replace with "mps" to run on a Mac device
)

messages = [
    {"role": "user", "content": "Who are you? Please, answer in pirate-speak."},
]

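# The pipeline applies the model's chat template to `messages` before generating
outputs = pipe(messages, max_new_tokens=256)
# "generated_text" holds the whole conversation; its last entry is the model's reply
assistant_response = outputs[0]["generated_text"][-1]["content"].strip()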
print(assistant_response)
# Ahoy, matey! I be Gemma, a digital scallywag, a language-slingin' parrot of the digital seas. I be here to help ye with yer wordy woes, answer yer questions, and spin ye yarns of the digital world.  So, what be yer pleasure, eh? 🦜
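The snippet above hardcodes device="cuda" and suggests "mps" on a Mac. A small helper can pick whichever backend is available instead. The function name `pick_device` is hypothetical, and the CPU fallback will be much slower for a 2.61B-parameter model. Also check that your PyTorch build supports `torch.bfloat16` on the chosen device, and change `torch_dtype` if it does not.

```python
import torch


def pick_device() -> str:
    """Return "cuda", "mps", or "cpu", preferring an available accelerator."""
    if torch.cuda.is_available():
        return "cuda"
    # torch.backends.mps only exists in newer PyTorch builds, so guard the lookup.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"


if __name__ == "__main__":
    # Pass the result as pipeline(..., device=pick_device()) in the snippet above.
    print(pick_device())
```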
Model size: 2.61B params (Safetensors, tensor type BF16)