---
license: apache-2.0
tags:
  - jamba
datasets:
  - teknium/OpenHermes-2.5
base_model: ai21labs/Jamba-v0.1
pipeline_tag: text-generation
---

Jamba-Open-Hermes

This model is highly experimental and should be treated purely as a test for now. Jamba has been very hard to train, but I wanted to see how it performs on one of the best datasets we have access to. I believe in transparent development, so every best-working iteration will be pushed here, even if it is a bit wonky.

Unfortunately, I've gone way over budget and spent a significant amount of money over the past few days trying to figure out the best way to fine-tune Jamba. New iterations may be sparse until Jamba is converted to MLX or I find buried treasure somewhere. If you've downloaded it, feel free to provide any feedback so I can improve the next training cycle! Thanks for checking it out.

Testing has been limited so far, so there are no example outputs yet.


Training

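Training used OpenHermes-2.5, but only the first ~1,500 optimizer steps were run. Trainer progress when the run was stopped: step 1530 of 125193, 4:46:45 elapsed, about 386:48:08 remaining at 0.09 it/s, epoch 0.01 of 1.

Per-step training loss for the last logged steps (logging_steps=1):

Step	Training Loss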
1483	5.986700
1484	5.764100
1485	5.887200
1486	5.445200
1487	6.086300
1488	5.718300
1489	5.670300
1490	5.440900
1491	4.945900
1492	6.154700
1493	5.624800
1494	6.868100
1495	5.627100
1496	5.192700
1497	5.826800
1498	5.512200
1499	5.869900
1500	5.852300
1501	5.574800
1502	5.299200
1503	5.631200
1504	5.535600
1505	5.626000
1506	5.093300
1507	5.278000
1508	5.585400
1509	5.318600
1510	5.319200
1511	5.513900
1512	5.375400
1513	5.460600
1514	5.045300
1515	6.013600
1516	5.812300
1517	5.707400
1518	5.109800
1519	5.212900
1520	5.317200
1521	5.935400
1522	5.733900
1523	5.866000
1524	5.675400
1525	5.580800
1526	4.996900
1527	5.666700
1528	4.979900
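
With per_device_train_batch_size=1 and gradient_accumulation_steps=8, each logged step averages the loss over only 8 sequences per device, so the per-step values above are noisy and a trend is easier to read after smoothing. The following is a minimal, standard-library sketch that copies the logged values and reports summary statistics plus a trailing moving average. The 10-step window and the function names (`moving_average`, `summarize`) are illustrative choices and not part of the original training script.

```python
from statistics import mean, stdev

# First step shown in the log above.
FIRST_STEP = 1483

# Per-step training loss copied from the log (steps 1483-1528).
LOSSES = [
    5.9867, 5.7641, 5.8872, 5.4452, 6.0863, 5.7183, 5.6703, 5.4409,
    4.9459, 6.1547, 5.6248, 6.8681, 5.6271, 5.1927, 5.8268, 5.5122,
    5.8699, 5.8523, 5.5748, 5.2992, 5.6312, 5.5356, 5.6260, 5.0933,
    5.2780, 5.5854, 5.3186, 5.3192, 5.5139, 5.3754, 5.4606, 5.0453,
    6.0136, 5.8123, 5.7074, 5.1098, 5.2129, 5.3172, 5.9354, 5.7339,
    5.8660, 5.6754, 5.5808, 4.9969, 5.6667, 4.9799,
]


def moving_average(values, window):
    """Trailing mean over up to `window` values (fewer at the start)."""
    out = []
    for i in range(len(values)):
        chunk = values[max(0, i - window + 1): i + 1]
        out.append(sum(chunk) / len(chunk))
    return out


def summarize(losses, first_step, window=10):
    """Basic statistics plus the smoothed loss at the start and end of the log."""
    smoothed = moving_average(losses, window)
    lo = min(range(len(losses)), key=losses.__getitem__)
    hi = max(range(len(losses)), key=losses.__getitem__)
    return {
        "steps": len(losses),
        "mean": mean(losses),
        "stdev": stdev(losses),
        "min": (first_step + lo, losses[lo]),
        "max": (first_step + hi, losses[hi]),
        # First full window vs. last window of the moving average.
        "smoothed_start": smoothed[window - 1],
        "smoothed_end": smoothed[-1],
    }


if __name__ == "__main__":
    for key, value in summarize(LOSSES, FIRST_STEP).items():
        print(f"{key}: {value}")
```

A trailing window is used so that each smoothed point depends only on steps already seen, which matches how the log is read during training.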

Hyperparameters


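import torch
from peft import LoraConfig
from transformers import TrainingArguments
from trl import SFTTrainer

# model, tokenizer, train_dataset and max_seq_length are assumed to be
# defined earlier in the training script (not shown here).

lora_config = LoraConfig(
    r=8,            # LoRA rank
    lora_alpha=16,  # low-rank update is scaled by lora_alpha / r = 2
    # Mamba mixer projections (in_proj, x_proj, out_proj) plus the token embedding
    target_modules=["embed_tokens", "x_proj", "in_proj", "out_proj"],
    lora_dropout=0.2,
    task_type="CAUSAL_LM",
    bias="none",
)

# Note: lora_config is not passed to the trainer below. Presumably the model was
# already wrapped with peft.get_peft_model(model, lora_config) (not shown);
# otherwise pass peft_config=lora_config to SFTTrainer.
# Older TRL signature: dataset_text_field, max_seq_length and tokenizer are trainer
# kwargs here; newer TRL releases configure these differently (e.g. via SFTConfig).
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    args=TrainingArguments(
        num_train_epochs=1,
        lr_scheduler_type="linear",
        learning_rate=2e-5,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,  # effective batch: 1 x 8 = 8 sequences per device
        gradient_checkpointing=True,    # trade extra compute for lower activation memory
        warmup_steps=10,
        weight_decay=0.2,
        # Use bf16 where the GPU supports it, otherwise fall back to fp16
        fp16=not torch.cuda.is_bf16_supported(),
        bf16=torch.cuda.is_bf16_supported(),
        logging_steps=1,                # log the training loss every optimizer step
        save_steps=100,
        output_dir="outputs",
        optim="paged_adamw_8bit",       # paged 8-bit AdamW; requires bitsandbytes
        seed=42,
    ),
)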
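
These settings determine how the Trainer progress bar in the Training section translates into data seen and time remaining. The following is a rough sketch that assumes a single GPU (`num_devices=1` is an assumption, since the card does not say how many GPUs were used). Each optimizer step consumes per_device_train_batch_size × gradient_accumulation_steps = 8 sequences, so step 1530 corresponds to 12,240 examples, and 125,193 steps per epoch correspond to roughly 1.0M examples. It also shows PEFT's default (non-rsLoRA) scaling of lora_alpha / r. `RunConfig`, `effective_batch_size`, `lora_scaling`, `parse_hms` and `progress_report` are hypothetical helper names.

```python
from dataclasses import dataclass


@dataclass
class RunConfig:
    """Hypothetical container for the settings above that affect throughput."""
    per_device_train_batch_size: int = 1  # from TrainingArguments above
    gradient_accumulation_steps: int = 8  # from TrainingArguments above
    num_devices: int = 1                  # assumption: a single GPU
    lora_r: int = 8                       # from LoraConfig above
    lora_alpha: int = 16                  # from LoraConfig above


def effective_batch_size(cfg: RunConfig) -> int:
    """Sequences consumed per optimizer step across all devices."""
    return cfg.per_device_train_batch_size * cfg.gradient_accumulation_steps * cfg.num_devices


def lora_scaling(cfg: RunConfig) -> float:
    """Default (non-rsLoRA) PEFT scaling applied to the low-rank update."""
    return cfg.lora_alpha / cfg.lora_r


def parse_hms(text: str) -> int:
    """Convert 'H:MM:SS' (hours may exceed 24) to seconds."""
    hours, minutes, seconds = (int(part) for part in text.split(":"))
    return hours * 3600 + minutes * 60 + seconds


def progress_report(cfg: RunConfig, step: int, total_steps: int, elapsed: str) -> dict:
    """Translate a Trainer progress-bar reading into examples seen and time left."""
    ebs = effective_batch_size(cfg)
    rate = step / parse_hms(elapsed)  # optimizer steps per second
    return {
        "effective_batch_size": ebs,
        "examples_seen": step * ebs,
        "approx_examples_per_epoch": total_steps * ebs,
        "steps_per_second": rate,
        "hours_remaining": (total_steps - step) / rate / 3600,
    }


if __name__ == "__main__":
    cfg = RunConfig()
    print("LoRA scaling:", lora_scaling(cfg))
    # Numbers taken from the progress bar in the Training section.
    for key, value in progress_report(cfg, 1530, 125193, "4:46:45").items():
        print(f"{key}: {value}")
```

The computed rate (about 0.089 steps/s) and remaining time (about 386 hours) should roughly reproduce the progress bar's 0.09 it/s and 386:48:08, which is a quick sanity check on the assumed batch settings.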