---
license: apache-2.0
tags:
- jamba
datasets:
- teknium/OpenHermes-2.5
base_model: ai21labs/Jamba-v0.1
pipeline_tag: text-generation
---
# Jamba-Open-Hermes
<img src="https://cdn-uploads.huggingface.co/production/uploads/64740cf7485a7c8e1bd51ac9/Ph6ZvxwF7a0m_B5Su_EK7.webp" width="500" height="500">
# This model is highly experimental and should be treated as a test run for now. Jamba has been very hard to train, but I wanted to see how it would do on one of the best datasets we have access to. I believe in transparent development, so every *best* working iteration, even the wonky ones, will be pushed here.
---
# New training underway! Thanks to the generous insights provided by **lightblue/Jamba-v0.1-chat-multilingual**, the new training is going much better. We should hopefully end up with a decently trained Jamba-Open-Hermes model for general use and experimentation.
*Testing has been limited so far, so there are no example outputs yet.*
---
## Training
### OpenHermes-2.5 (first 1500 examples only): **[1530/125193 steps · 4:46:45 elapsed < 386:48:08 remaining · 0.09 it/s · Epoch 0.01/1]**
**Notes:**
- Tried 30+ hyperparameter combinations; the settings below are the best I could land on.
- Loss hovered around ~5-6 no matter what I tried with the learning rate.
- Couldn't increase the batch size due to Colab limitations, so the answer may lie in finding the right balance between learning rate and batch size (see the setup sketch below).
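
The hyperparameter block below references `model`, `tokenizer`, `train_dataset`, and `max_seq_length` without defining them. Here is a minimal sketch of one way they might have been prepared; the 4-bit quantization config (to fit Colab), the conversation-flattening function, and `max_seq_length = 2048` are assumptions, not the published setup.

```py
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

max_seq_length = 2048  # assumption; not stated in the original notes

tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")

# 4-bit NF4 quantization so the 52B-param (12B active) MoE model fits on a single Colab GPU
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "ai21labs/Jamba-v0.1",
    quantization_config=quant_config,
    device_map="auto",
    trust_remote_code=True,
)

# OpenHermes-2.5 stores ShareGPT-style conversations; flatten each one into a
# single "text" field, since SFTTrainer below is given dataset_text_field="text".
def to_text(example):
    turns = example["conversations"]
    return {"text": "\n".join(f"{t['from']}: {t['value']}" for t in turns)}

train_dataset = load_dataset("teknium/OpenHermes-2.5", split="train").map(to_text)
```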
### Hyperparameters
```py
import torch
from peft import LoraConfig
from transformers import TrainingArguments
from trl import SFTTrainer

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    # Adapt the embeddings plus Jamba's Mamba-layer projections
    target_modules=["embed_tokens", "x_proj", "in_proj", "out_proj"],
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
    bias="none",
)

trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    peft_config=lora_config,  # wrap the model with the LoRA adapters defined above
    args=TrainingArguments(
        num_train_epochs=1,
        lr_scheduler_type="cosine",
        learning_rate=0.0002,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        gradient_checkpointing=True,
        warmup_steps=10,
        weight_decay=0.01,
        fp16=not torch.cuda.is_bf16_supported(),  # fall back to fp16 on older GPUs
        bf16=torch.cuda.is_bf16_supported(),
        logging_steps=1,
        save_steps=200,
        output_dir="outputs",
        optim="adamw_8bit",  # 8-bit AdamW from bitsandbytes to save optimizer memory
        seed=42,
    ),
)
```
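
Once an adapter checkpoint exists, it can be loaded back onto the base model for a quick smoke test. A hedged sketch; the checkpoint path and the `human:`/`gpt:` prompt format are assumptions on my part, not a documented interface.

```py
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
base = AutoModelForCausalLM.from_pretrained(
    "ai21labs/Jamba-v0.1",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
# Hypothetical adapter path; save_steps=200 above writes checkpoints under outputs/
model = PeftModel.from_pretrained(base, "outputs/checkpoint-200")

prompt = "human: What makes Jamba's architecture different?\ngpt:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=128, do_sample=True, temperature=0.7)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```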