---
license: apache-2.0
tags:
  - jamba
datasets:
  - teknium/OpenHermes-2.5
base_model: ai21labs/Jamba-v0.1
pipeline_tag: text-generation
---

# Jamba-Open-Hermes

This model is highly experimental and should be treated as a test run for now. Jamba has been very hard to train, but I wanted to see how it did on one of the best datasets we have access to. I believe in transparent development, so all of the best working iterations, even if they are a bit wonky, will be pushed here.

I've unfortunately gone way over budget and spent a significant amount of money over the past few days trying to figure out the best way to fine-tune Jamba. New iterations may be sparse until Jamba is converted to MLX or I find buried treasure somewhere. If you've downloaded it, feel free to provide any feedback so I can improve on the next training cycle! Thanks for checking it out.

There has been only limited testing so far, so no example outputs yet.


## Training

OpenHermes-2.5 (first 1,500 examples only): `[ 1530/125193 4:46:45 < 386:48:08, 0.09 it/s, Epoch 0.01/1]`

Notes:

- Tried more than 30 combinations of hyperparameters; the settings below are the best I could land on.
- Loss hovered around ~5-6 no matter what I tried with the learning rate.
- Couldn't increase the batch size due to Colab limitations, so the answer may lie in finding the right balance between learning rate and batch size.
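One knob worth remembering when the per-device batch size is pinned at 1 is that gradient accumulation still sets the *effective* batch size each optimizer step sees. A tiny sketch (function name is illustrative, not from the training script):

```python
def effective_batch_size(per_device_batch_size: int,
                         grad_accum_steps: int,
                         num_devices: int = 1) -> int:
    """Number of examples that contribute to each optimizer update."""
    return per_device_batch_size * grad_accum_steps * num_devices

# With the settings below (batch size 1, 8 accumulation steps, one Colab GPU):
print(effective_batch_size(1, 8))  # -> 8
```

So even on a single Colab GPU, raising `gradient_accumulation_steps` is the available lever for trading step time against a larger effective batch.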

## Hyperparameters


```python
import torch
from peft import LoraConfig
from transformers import TrainingArguments
from trl import SFTTrainer

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    # Jamba's Mamba-layer projections plus the embedding table
    target_modules=["embed_tokens", "x_proj", "in_proj", "out_proj"],
    lora_dropout=0.2,
    task_type="CAUSAL_LM",
    bias="none",
)

# model, tokenizer, train_dataset, and max_seq_length are defined earlier in the notebook
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    peft_config=lora_config,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    args=TrainingArguments(
        num_train_epochs=1,
        lr_scheduler_type="linear",
        learning_rate=2e-5,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        gradient_checkpointing=True,
        warmup_steps=10,
        weight_decay=0.2,
        fp16=not torch.cuda.is_bf16_supported(),
        bf16=torch.cuda.is_bf16_supported(),
        logging_steps=1,
        save_steps=100,
        output_dir="outputs",
        optim="paged_adamw_8bit",
        seed=42,
    ),
)
```
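The trainer reads a plain `text` column, so each OpenHermes conversation has to be flattened into a single string beforehand. The exact template used for this run isn't shown above; a hypothetical ChatML-style formatter (all names here are illustrative, not taken from the training script) might look like:

```python
def format_conversation(conversations: list[dict]) -> str:
    """Flatten an OpenHermes-style turn list into one ChatML-style string.

    Each turn is a dict like {"from": "human", "value": "..."}.
    """
    role_map = {"system": "system", "human": "user", "gpt": "assistant"}
    parts = []
    for turn in conversations:
        role = role_map.get(turn["from"], turn["from"])
        parts.append(f"<|im_start|>{role}\n{turn['value']}<|im_end|>")
    return "\n".join(parts)

example = [
    {"from": "human", "value": "What is Jamba?"},
    {"from": "gpt", "value": "A hybrid SSM-Transformer model from AI21."},
]
print(format_conversation(example))
```

Whatever template is used, it needs to match at inference time, since the model only ever sees the flattened `text` strings during fine-tuning.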