File size: 3,699 Bytes
ccda2ec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
# train.py
# Runs train_olmoe_adapter.py with the parameters given on the command line.
# #!/usr/bin/env python
"""
Run script for fine-tuning OlmoE with adapters on specific text domains.
Handles argument parsing and configuration.
"""
import argparse
import os
import sys
from dataclasses import dataclass, field
from typing import Optional
from transformers import (
HfArgumentParser,
TrainingArguments,
)
@dataclass
class ScriptArguments:
    """
    Launcher options that are not part of TrainingArguments.

    main() parses these from the command line via HfArgumentParser and
    forwards them to train_olmoe_adapter.py.
    """

    # --- Model and output locations ---
    model_path: str = field(
        default="allenai/OLMo-7B-Instruct", metadata={"help": "Path to the model to fine-tune"}
    )
    output_dir: str = field(
        default="./output_olmoe_adapter", metadata={"help": "Directory to save the model and logs"}
    )

    # --- Adapter configuration ---
    adapter_size: int = field(default=64, metadata={"help": "Size of the adapter layers"})

    # --- Data ---
    dataset_name: str = field(
        default="mlfoundations/dclm-baseline-1.0", metadata={"help": "Name of the dataset to use"}
    )

    # --- Optimisation schedule ---
    max_steps: int = field(default=10000, metadata={"help": "Maximum number of training steps"})
    learning_rate: float = field(default=5e-5, metadata={"help": "Learning rate for fine-tuning"})
    per_device_batch_size: int = field(default=8, metadata={"help": "Batch size per device"})
    gradient_accumulation_steps: int = field(
        default=1, metadata={"help": "Number of steps to accumulate gradients"}
    )

    # --- Quantised loading (currently disabled) ---
    # use_8bit: bool = field(default=False, metadata={"help": "Whether to use 8-bit precision"})
    # use_4bit: bool = field(default=False, metadata={"help": "Whether to use 4-bit precision"})
def main():
    """
    Parse launcher options and run train_olmoe_adapter.py as a child process.

    Command-line flags are parsed into a ScriptArguments instance, the output
    directory is created if missing, and the training script is started with
    the current working directory prepended to PYTHONPATH.

    Exits the process with a non-zero status when the training script fails;
    otherwise prints a success message and returns None.
    """
    # Imported locally so this function does not depend on the module header.
    import shlex
    import subprocess

    # Parse command-line arguments
    parser = HfArgumentParser(ScriptArguments)
    args = parser.parse_args_into_dataclasses()[0]

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Prepare command for training. It is kept as an argument list and never
    # passed through a shell, so values containing spaces or shell
    # metacharacters reach the training script unchanged.
    cmd = [
        # Use the interpreter running this launcher (same virtualenv);
        # fall back to "python" only if it cannot be determined.
        sys.executable or "python",
        "train_olmoe_adapter.py",
        # Model arguments
        f"--model_name_or_path={args.model_path}",
        f"--adapter_size={args.adapter_size}",
        "--freeze_base_model=True",  # Always freeze the base model
        f"--checkpoint_dir={args.output_dir}",
        # Data arguments
        f"--dataset_name={args.dataset_name}",
        "--streaming=True",  # Always stream for large datasets
        "--streaming_buffer_size=8192",
        "--max_seq_length=1024",
        # Training arguments
        f"--output_dir={args.output_dir}",
        f"--per_device_train_batch_size={args.per_device_batch_size}",
        f"--gradient_accumulation_steps={args.gradient_accumulation_steps}",
        f"--learning_rate={args.learning_rate}",
        f"--max_steps={args.max_steps}",
        "--warmup_steps=500",
        "--logging_steps=10",
        "--save_steps=1000",
        "--save_total_limit=2",
        "--dataloader_num_workers=4",
        "--seed=42",
    ]

    # Add precision flags if needed
    # if args.use_8bit:
    #     cmd.append("--load_in_8bit")
    # if args.use_4bit:
    #     cmd.append("--load_in_4bit")

    # Print the command for logging, quoted so it can be pasted into a shell.
    cmd_str = " ".join(shlex.quote(part) for part in cmd)
    print(f"Running command: {cmd_str}")

    # Make the current directory importable for the child while keeping any
    # PYTHONPATH entries the caller already configured.
    child_env = os.environ.copy()
    cwd = os.getcwd()
    inherited_path = child_env.get("PYTHONPATH")
    child_env["PYTHONPATH"] = (
        os.pathsep.join([cwd, inherited_path]) if inherited_path else cwd
    )

    # Execute the training script. subprocess.run reports the child's real
    # exit code, whereas os.system returns a wait status on POSIX that
    # sys.exit would truncate (e.g. 256 -> 0), hiding the failure.
    ret = subprocess.run(cmd, env=child_env, check=False).returncode
    if ret != 0:
        print(f"Training failed with exit code {ret}")
        # A negative code means the child was killed by a signal; map it to
        # the conventional 128 + signal number so the status stays non-zero.
        sys.exit(ret if ret > 0 else 128 - ret)

    print("Training completed successfully!")
# Run the launcher only when executed as a script, not when imported.
if __name__ == "__main__":
    main()