Gemma fine-tuning won't work with any method except SFT

#35
by erfanzar - opened

I have a library named EasyDeL, and I have re-implemented Gemma for it with extra options such as flash-attention, ring-attention, blockwise_ffn, and more.
The problem is that training doesn't actually learn anything: the loss starts around 8 and never drops below 4.23, no matter which model I try (I have already tried all of the Gemma models). Here's an example of training and fine-tuning a model with EasyDeL. This code is not for the DPOTrainer, but the same thing happens there: the loss averages ~50 and the model doesn't learn anything.
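
For reference, a minimal sanity check (just a sketch, assuming the reference PyTorch transformers implementation; the prompt string below is only illustrative) is to compute the pretrained checkpoint's own loss on a Gemma-formatted example and compare it with the ~8 starting loss the trainer reports:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "google/gemma-2b-it"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.bfloat16)

# Illustrative prompt in Gemma's chat format; any formatted training sample works here.
text = (
    "<start_of_turn>user\nWhat is JAX?<end_of_turn>\n"
    "<start_of_turn>model\nJAX is a numerical computing library.<end_of_turn>\n"
)
inputs = tokenizer(text, return_tensors="pt")

with torch.no_grad():
    # Passing labels=input_ids makes the model return the causal-LM cross-entropy loss.
    out = model(**inputs, labels=inputs["input_ids"])
print("reference loss:", out.loss.item())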

Installation dependencies

You need EasyDeL installed from head:

pip install git+https://github.com/erfanzar/EasyDeL.git -q -U
pip install jax[tpu]==0.4.22 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -q

Fine-tuning Code

from EasyDel import (
    AutoEasyDelModelForCausalLM,
    TrainArguments,
    CausalLanguageModelTrainer,
    EasyDelOptimizers,
    EasyDelSchedulers,
    EasyDelGradientCheckPointers,
    EasyDelState,
    EasyDeLXRapTureConfig,
    get_modules_by_type,
    easystate_to_huggingface_model
)
from datasets import load_dataset

from flax.core import FrozenDict
from transformers import AutoTokenizer
from jax import numpy as jnp
import jax
from transformers import GemmaForCausalLM as ModuleTorch


def main(use_lora=False):
    pretrained_model_name_or_path = "google/gemma-2b-it"

    model, params = AutoEasyDelModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path,
        device=jax.devices('cpu')[0],
        input_shape=(1, 1),
        device_map="auto",
        sharding_axis_dims=(1, 1, 1, -1)
    )

    config = model.config

    model_parameters = FrozenDict({"params": params})

    dtype = jnp.bfloat16
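    # Pick the attention implementation and the kernel block sizes used by EasyDeL's attention modules.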
    config.add_basic_configurations(
        attn_mechanism="normal",
        block_b=1,
        block_q=128,
        block_k=128,
        block_k_major=128,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path,
        trust_remote_code=True
    )

    max_length = 4096

    configs_to_initialize_model_class = {
        'config': config,
        'dtype': dtype,
        'param_dtype': dtype,
        'input_shape': (1, max_length)
    }

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    rapture_config = EasyDeLXRapTureConfig(
        model_parameters,
        lora_dim=64,
        fully_fine_tune_parameters=["embed_tokens"],
        lora_fine_tune_parameters=["q_proj", "v_proj", "k_proj", "o_proj"],
        verbose=True
    ) if use_lora else None

    dataset = load_dataset(
        "erfanzar/Zeus-v0.1-Llama",
        split="train",
    )

    def gemma_prompt(x):
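        # Rewrite Llama-2 chat markup ([INST], <<SYS>>, </s>) from the dataset into Gemma's <start_of_turn>/<end_of_turn> format.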
        return x.replace(
            "[/INST]", "<end_of_turn>\n<start_of_turn>model\n").replace(
            "</s><s>[INST]", "<end_of_turn>\n").replace(
            "<s>[INST] <<SYS>>\n", "<start_of_turn>system\n").replace(
            "<s>[INST]", "<start_of_turn>user\n").replace(
            "<</SYS>>\n", "<end_of_turn>\n").replace(
            "<end_of_turn>\n\n", "<end_of_turn>\n"
        )

    def tokenization_process(data_chunk) -> dict:
        return tokenizer(
            gemma_prompt(data_chunk["prompt"]),
            add_special_tokens=False,
            max_length=max_length,
            padding="max_length"
        )

    dataset = dataset.map(
        tokenization_process,
        num_proc=18,
        remove_columns=dataset.column_names
    )

    train_args = TrainArguments(

        model_class=get_modules_by_type(config.model_type)[1],
        configs_to_initialize_model_class=configs_to_initialize_model_class,
        custom_rule=config.get_partition_rules(True),

        model_name="Jupyter",

        num_train_epochs=2,
        learning_rate=5e-5,
        learning_rate_end=7e-6,
        warmup_steps=200,
        optimizer=EasyDelOptimizers.ADAMW,
        scheduler=EasyDelSchedulers.LINEAR,
        weight_decay=0.02,
        total_batch_size=64,
        max_sequence_length=max_length,
        gradient_checkpointing=EasyDelGradientCheckPointers.NOTHING_SAVEABLE,
        sharding_array=(1, 1, 1, -1),
        use_pjit_attention_force=False,
        gradient_accumulation_steps=1,

        init_input_shape=(1, max_length),

        dtype=dtype,
        param_dtype=dtype,

        step_start_point=0,

        training_time="7H",
        rapture_config=rapture_config,
        wandb_entity=None
    )

    trainer = CausalLanguageModelTrainer(
        train_args,
        dataset.shuffle().shuffle().shuffle(),
        checkpoint_path=None
    )

    model_parameters = model_parameters if not use_lora else None

    output = trainer.train(
        model_parameters=model_parameters,
        state=None
    )

    with jax.default_device(jax.devices("cpu")[0]):
        model = easystate_to_huggingface_model(
            state=EasyDelState.load_state(
                output.checkpoint_path
            ),
            base_huggingface_module=ModuleTorch,
            config=config
        )

    model = model.half()
    model.push_to_hub("Gemma-2B-Fine-tuned")
    tokenizer.push_to_hub("Gemma-2B-Fine-tuned")


if __name__ == "__main__":
    main()
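
Optionally, before calling push_to_hub, a quick generation smoke test on the converted PyTorch model helps confirm the weights survived the conversion. This is only a sketch (the prompt is illustrative, and on CPU it may be easier to run it before model.half()):

# Hypothetical smoke test; run inside main() right after easystate_to_huggingface_model.
prompt = "<start_of_turn>user\nHello, who are you?<end_of_turn>\n<start_of_turn>model\n"
inputs = tokenizer(prompt, return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(generated[0], skip_special_tokens=True))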
Reply from a Google org member:

I haven't seen this code before; do you track any other metrics at the start of training that might indicate what's wrong?

Just to check, are you finetuning the pretrained checkpoints?

Yes, I'm fine-tuning the pre-trained model, and EasyDeL is my library. I track metrics such as TPU/GPU/CPU usage, mean loss and accuracy, perplexity, trained tokens, and the learning rate.

These are the gemma-7b-it charts. I tried a higher learning rate here, but exactly the same thing happens at lower learning rates.

Additional information

  • The model generates text fine
    • Used code:
from EasyDel import JAXServer, JAXServerConfig, EasyServe
from fjformer import get_dtype
from EasyDel.serve.prompters import GemmaPrompter, Llama2Prompter, OpenChatPrompter, Qwen2Prompter
from EasyDel.serve.prompters.base_prompter import BasePrompter
from jax import numpy as jnp, lax
import jax
from typing import List, Union, Optional

max_sequence_length = 8192
max_compile_tokens = 256
max_new_tokens_ratio = 25

dtype = "bf16"

prompter_type = "gemma"

sharding_axis_dims = (1, 1, 1, -1)
pretrained_model_name_or_path = "google/gemma-7b-it"
attn_mechanism = "normal"
scan_mlp_chunk_size = max_compile_tokens
use_scan_mlp = True
scan_ring_attention = True
block_k = 128
block_q = 128
use_sharded_kv_caching = False

server_config = JAXServerConfig(
    max_sequence_length=max_sequence_length,
    max_compile_tokens=max_compile_tokens,
    max_new_tokens=max_compile_tokens * max_new_tokens_ratio,
    dtype=dtype,
    pre_compile=False
)

prompters = {
    "gemma": GemmaPrompter(),
    "llama": Llama2Prompter(),
    "openchat": OpenChatPrompter(),
    "qwen2": Qwen2Prompter()
}

prompter: BasePrompter = prompters[prompter_type]

class JAXServerC(JAXServer):
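    # Format chat and instruct prompts with the selected prompter (Gemma by default here).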
    @staticmethod
    def format_chat(history: List[List[str]], prompt: str, system: Union[str, None]) -> str:
        return prompter.format_message(
            history=history,
            prompt=prompt,
            system_message=system,
            prefix=None
        )

    @staticmethod
    def format_instruct(system: str, instruction: str) -> str:
        return prompter.format_message(
            prefix=None,
            system_message=system,
            prompt=instruction,
            history=[]
        )

server = JAXServerC.from_torch_pretrained(
    server_config=server_config,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    device=jax.devices('cpu')[0],
    dtype=get_dtype(dtype=dtype),
    param_dtype=get_dtype(dtype=dtype),
    precision=jax.lax.Precision("fastest"),
    sharding_axis_dims=sharding_axis_dims,
    sharding_axis_names=("dp", "fsdp", "tp", "sp"),
    input_shape=(1, server_config.max_sequence_length),
    model_config_kwargs=dict(
        fully_sharded_data_parallel=True,
        attn_mechanism=attn_mechanism,
        scan_mlp_chunk_size=max_compile_tokens,
        use_scan_mlp=use_scan_mlp,
        scan_ring_attention=scan_ring_attention,
        block_k=block_k,
        block_q=block_q,
        use_sharded_kv_caching=use_sharded_kv_caching
    )
)

history = []
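# Simple chat loop: format the history with the prompter, stream the sampled tokens, and print only the newly generated suffix.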
while True:
    user_prompt = input("> ")
    model_prompt = server.format_chat(history, user_prompt, None)

    past_response_length = 0
    
    for response, used_tokens in server.sample(
        model_prompt,
        greedy=False
    ):
        print(response[past_response_length:], end="")
        past_response_length = len(response)
    
    history.append([user_prompt, response])
  • Trainer loops (DPO, CLM) are both working fine.

(Three screenshots of the gemma-7b-it training charts are attached.)
