Anyone else currently experimenting with fine-tuning Jamba?

#21
by Severian - opened

Just curious to see what others' experience has been like so far. I'm finding the model trains very quickly and has a tendency to overfit if you don't find the right combination of hyperparameters (I went through dozens and dozens of iterations to compare loss differences). I think overall this might be a good thing, in that Jamba can learn much more quickly, but it will take a different approach in your setup and process. I haven't landed on a good rule of thumb for where to start when training across different datasets, other than experimenting and intuition.
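One thing that has helped me spot the overfitting early is holding out a small eval split and stopping when eval loss stalls. Here's a minimal sketch of that, assuming `model`, `tokenizer`, and a `dataset` with a "text" column already exist; the values are illustrative, not a recommendation:

# Sketch: watch eval loss to catch overfitting early.
# Assumes `model`, `tokenizer`, and a `dataset` with a "text" field are already defined.
from transformers import TrainingArguments, EarlyStoppingCallback
from trl import SFTTrainer

split = dataset.train_test_split(test_size=0.02, seed=42)  # small held-out eval set

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=split["train"],
    eval_dataset=split["test"],
    dataset_text_field="text",
    max_seq_length=4096,
    args=TrainingArguments(
        output_dir="overfit-check",
        num_train_epochs=1,
        learning_rate=2e-4,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        evaluation_strategy="steps",   # evaluate periodically...
        eval_steps=50,
        save_strategy="steps",
        save_steps=50,
        load_best_model_at_end=True,   # ...and keep the best checkpoint
        metric_for_best_model="eval_loss",
    ),
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],  # stop if eval loss stops improving
)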

If someone has more insight or can set my perspective/understanding/approach straight, please share and enlighten! I'm genuinely curious and not bothered by being wrong, so don't hesitate to call me out or share your thoughts :)

Also currently working on it.
I agree it's not easy so far to find the right balance for the learning rate. Their example value is way too high, but conversely axolotl's default is way too low. Currently trying values in the 2e-4 to 5e-4 range.
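For anyone narrowing this down themselves, a quick sweep over a few candidate values with short runs is usually enough to see which end of that range blows up or stalls. A rough sketch, where build_trainer is a hypothetical helper that returns a fresh SFTTrainer configured like the snippets in this thread:

# Hypothetical LR sweep over the 2e-4 to 5e-4 range discussed above.
results = {}
for lr in [2e-4, 3e-4, 5e-4]:
    trainer = build_trainer(learning_rate=lr)   # hypothetical helper, not a real API
    metrics = trainer.train().metrics           # keep runs short, e.g. via max_steps
    results[lr] = metrics["train_loss"]
print(results)  # compare before committing to a full run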

I'd say that, given the context capacity of this model, the most interesting datasets would be for long-document summarization and other long-context tasks.
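For that kind of task, the main prep step is flattening document/summary pairs into the single text field SFTTrainer expects. A minimal sketch, where the dataset name and column names are placeholders:

# Sketch: format a long-document summarization set into a single "text" field.
from datasets import load_dataset

def to_text(example):
    # Keep the whole document in the prompt to actually exercise the long context window.
    return {"text": f"Summarize the following document.\n\n{example['document']}\n\nSummary: {example['summary']}"}

raw = load_dataset("your/long-doc-summarization-dataset", split="train")  # placeholder dataset id
train_dataset = raw.map(to_text, remove_columns=raw.column_names)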

Thanks for your input! It seems we've landed on similar conclusions and settings. Here are the best settings I'm currently getting with Jamba on the OpenHermes dataset:

# Imports for the snippet below
import torch
from peft import LoraConfig
from transformers import TrainingArguments
from trl import SFTTrainer

lora_config = LoraConfig(
    r=8,                      # low rank keeps the adapter small and slows overfitting
    lora_alpha=16,
    target_modules=["embed_tokens", "x_proj", "in_proj", "out_proj"],
    lora_dropout=0.2,
    task_type="CAUSAL_LM",
    bias="none",
)

trainer = SFTTrainer(
    model=model,
    peft_config=lora_config,          # apply the LoRA adapters defined above
    train_dataset=train_dataset,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    args=TrainingArguments(
        num_train_epochs=1,
        lr_scheduler_type="linear",
        learning_rate=2e-5,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        gradient_checkpointing=True,
        warmup_steps=10,
        weight_decay=0.2,             # fairly heavy regularization to fight overfitting
        fp16=not torch.cuda.is_bf16_supported(),
        bf16=torch.cuda.is_bf16_supported(),
        logging_steps=1,
        save_steps=100,
        output_dir="outputs",
        optim="paged_adamw_8bit",
        seed=42,
    ),
)
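For context, here's a rough sketch of the surrounding setup a snippet like this assumes (4-bit, QLoRA-style loading of the base model and flattening OpenHermes into a single text column). The quantization settings and dataset formatting below are illustrative, not necessarily the exact ones used above:

# Sketch of the objects the snippet above references; settings here are assumptions.
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "ai21labs/Jamba-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,                      # needed on older transformers versions
    quantization_config=BitsAndBytesConfig(      # 4-bit QLoRA-style loading
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    ),
)

raw = load_dataset("teknium/OpenHermes-2.5", split="train")
# Flatten each ShareGPT-style conversation into one "text" string for dataset_text_field="text".
train_dataset = raw.map(
    lambda ex: {"text": "\n".join(f"{t['from']}: {t['value']}" for t in ex["conversations"])},
    remove_columns=raw.column_names,
)
max_seq_length = 4096

# ...then build lora_config and trainer as above, and run:
trainer.train()
trainer.save_model("outputs/jamba-lora")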

Following
