#!/usr/bin/env python
# train.py
# Launches train_olmoe_adapter.py with the configured parameters.
"""
Run script for fine-tuning OlmoE with adapters on specific text domains.
Handles argument parsing and configuration.
"""
import os
import subprocess
import sys
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class ScriptArguments:
    """
    Arguments for the run script that aren't covered by TrainingArguments.
    """
    model_path: str = field(
        default="allenai/OLMo-7B-Instruct",
        metadata={"help": "Path to the model to fine-tune"},
    )
    output_dir: str = field(
        default="./output_olmoe_adapter",
        metadata={"help": "Directory to save the model and logs"},
    )
    adapter_size: int = field(
        default=64,
        metadata={"help": "Size of the adapter layers"},
    )
    dataset_name: str = field(
        default="mlfoundations/dclm-baseline-1.0",
        metadata={"help": "Name of the dataset to use"},
    )
    max_steps: int = field(
        default=10000,
        metadata={"help": "Maximum number of training steps"},
    )
    learning_rate: float = field(
        default=5e-5,
        metadata={"help": "Learning rate for fine-tuning"},
    )
    per_device_batch_size: int = field(
        default=8,
        metadata={"help": "Batch size per device"},
    )
    gradient_accumulation_steps: int = field(
        default=1,
        metadata={"help": "Number of steps to accumulate gradients"},
    )
    # use_8bit: bool = field(
    #     default=False,
    #     metadata={"help": "Whether to use 8-bit precision"}
    # )
    # use_4bit: bool = field(
    #     default=False,
    #     metadata={"help": "Whether to use 4-bit precision"}
    # )


def main():
    # Parse command-line arguments into the ScriptArguments dataclass
    parser = HfArgumentParser(ScriptArguments)
    args = parser.parse_args_into_dataclasses()[0]
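
    # Note: HfArgumentParser can also populate the same dataclass from a JSON
    # file instead of the command line, e.g. (hypothetical file name):
    #   args = parser.parse_json_file("train_config.json")[0]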

    # Create the output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Build the command for the training script. Keeping it as a list (rather
    # than a single string) lets subprocess pass arguments without shell
    # quoting; sys.executable reuses the interpreter running this script.
    cmd = [
        sys.executable,
        "train_olmoe_adapter.py",
        # Model arguments
        f"--model_name_or_path={args.model_path}",
        f"--adapter_size={args.adapter_size}",
        "--freeze_base_model=True",  # Always freeze the base model
        f"--checkpoint_dir={args.output_dir}",
        # Data arguments
        f"--dataset_name={args.dataset_name}",
        "--streaming=True",  # Always stream for large datasets
        "--streaming_buffer_size=8192",
        "--max_seq_length=1024",
        # Training arguments
        f"--output_dir={args.output_dir}",
        f"--per_device_train_batch_size={args.per_device_batch_size}",
        f"--gradient_accumulation_steps={args.gradient_accumulation_steps}",
        f"--learning_rate={args.learning_rate}",
        f"--max_steps={args.max_steps}",
        "--warmup_steps=500",
        "--logging_steps=10",
        "--save_steps=1000",
        "--save_total_limit=2",
        "--dataloader_num_workers=4",
        "--seed=42",
    ]

    # Add precision flags if needed
    # if args.use_8bit:
    #     cmd.append("--load_in_8bit")
    # if args.use_4bit:
    #     cmd.append("--load_in_4bit")

    # Print the command for logging
    print(f"Running command: {' '.join(cmd)}")

    # Execute the training script. subprocess.run with an argument list avoids
    # shell quoting issues, and returncode is the child's actual exit code
    # (os.system returns an encoded wait status, not the exit code itself).
    env = {**os.environ, "PYTHONPATH": os.getcwd()}
    result = subprocess.run(cmd, env=env)
    if result.returncode != 0:
        print(f"Training failed with exit code {result.returncode}")
        sys.exit(result.returncode)
    print("Training completed successfully!")


if __name__ == "__main__":
    main()