|
|
|
|
|
|
|
""" |
|
Run script for fine-tuning OlmoE with adapters on specific text domains. |
|
Handles argument parsing and configuration. |
|
""" |
|
|
|
import argparse
import os
import subprocess
import sys
from dataclasses import dataclass, field
from typing import Optional

from transformers import (
    HfArgumentParser,
    TrainingArguments,
)
|
|
|
|
|
@dataclass
class ScriptArguments:
    """
    Command-line options for the launcher that are not part of the
    downstream trainer's TrainingArguments.
    """

    # Base model to attach adapters to (HF hub id or local path).
    model_path: str = field(
        default="allenai/OLMo-7B-Instruct",
        metadata={"help": "Path to the model to fine-tune"},
    )
    # Where checkpoints and logs are written; created if missing.
    output_dir: str = field(
        default="./output_olmoe_adapter",
        metadata={"help": "Directory to save the model and logs"},
    )
    # Bottleneck dimension of the adapter layers.
    adapter_size: int = field(
        default=64,
        metadata={"help": "Size of the adapter layers"},
    )
    # HF dataset identifier for the training corpus.
    dataset_name: str = field(
        default="mlfoundations/dclm-baseline-1.0",
        metadata={"help": "Name of the dataset to use"},
    )
    # Hard cap on optimizer steps (streaming dataset has no epoch notion).
    max_steps: int = field(
        default=10000,
        metadata={"help": "Maximum number of training steps"},
    )
    # Peak learning rate passed through to the trainer.
    learning_rate: float = field(
        default=5e-5,
        metadata={"help": "Learning rate for fine-tuning"},
    )
    # Micro-batch size on each device.
    per_device_batch_size: int = field(
        default=8,
        metadata={"help": "Batch size per device"},
    )
    # Multiplier for the effective batch size via gradient accumulation.
    gradient_accumulation_steps: int = field(
        default=1,
        metadata={"help": "Number of steps to accumulate gradients"},
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Parse launcher arguments and run the adapter training script.

    Builds the argument vector for ``train_olmoe_adapter.py`` and executes
    it as a subprocess, exiting with the child's return code on failure.
    """
    parser = HfArgumentParser(ScriptArguments)
    args = parser.parse_args_into_dataclasses()[0]

    os.makedirs(args.output_dir, exist_ok=True)

    cmd = [
        # Use the interpreter running this script, not whatever "python"
        # happens to resolve to on PATH.
        sys.executable,
        "train_olmoe_adapter.py",

        # Model / adapter configuration
        f"--model_name_or_path={args.model_path}",
        f"--adapter_size={args.adapter_size}",
        "--freeze_base_model=True",
        f"--checkpoint_dir={args.output_dir}",

        # Dataset configuration (streamed, so no full download)
        f"--dataset_name={args.dataset_name}",
        "--streaming=True",
        "--streaming_buffer_size=8192",
        "--max_seq_length=1024",

        # Trainer configuration
        f"--output_dir={args.output_dir}",
        f"--per_device_train_batch_size={args.per_device_batch_size}",
        f"--gradient_accumulation_steps={args.gradient_accumulation_steps}",
        f"--learning_rate={args.learning_rate}",
        f"--max_steps={args.max_steps}",
        "--warmup_steps=500",
        "--logging_steps=10",
        "--save_steps=1000",
        "--save_total_limit=2",
        "--dataloader_num_workers=4",
        "--seed=42",
    ]

    print(f"Running command: {' '.join(cmd)}")

    # Pass a copied environment so the parent process is not mutated.
    env = os.environ.copy()
    env["PYTHONPATH"] = os.getcwd()

    # subprocess.run with an argv list (shell=False) avoids the quoting /
    # injection problems of os.system on a joined string, and — unlike
    # os.system, whose POSIX return value is a raw wait status — yields the
    # child's actual exit code.
    result = subprocess.run(cmd, env=env)

    if result.returncode != 0:
        print(f"Training failed with exit code {result.returncode}")
        sys.exit(result.returncode)

    print("Training completed successfully!")
|
|
|
|
|
# Entry point guard: run training only when executed as a script, so the
# module can still be imported (e.g. to reuse ScriptArguments) without
# side effects.
if __name__ == "__main__":

    main()