#!/usr/bin/env python
# train.py
# Launches train_olmoe_adapter.py with the parsed command-line parameters.
"""
Run script for fine-tuning OlmoE with adapters on specific text domains.
Handles argument parsing and configuration.
"""

import os
import subprocess
import sys
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class ScriptArguments:
    """
    Arguments for the run script that aren't covered by TrainingArguments.
    """
    model_path: str = field(
        default="allenai/OLMo-7B-Instruct",
        metadata={"help": "Path to the model to fine-tune"}
    )
    output_dir: str = field(
        default="./output_olmoe_adapter",
        metadata={"help": "Directory to save the model and logs"}
    )
    adapter_size: int = field(
        default=64,
        metadata={"help": "Size of the adapter layers"}
    )
    dataset_name: str = field(
        default="mlfoundations/dclm-baseline-1.0",
        metadata={"help": "Name of the dataset to use"}
    )
    max_steps: int = field(
        default=10000,
        metadata={"help": "Maximum number of training steps"}
    )
    learning_rate: float = field(
        default=5e-5,
        metadata={"help": "Learning rate for fine-tuning"}
    )
    per_device_batch_size: int = field(
        default=8,
        metadata={"help": "Batch size per device"}
    )
    gradient_accumulation_steps: int = field(
        default=1,
        metadata={"help": "Number of steps to accumulate gradients"}
    )
    # use_8bit: bool = field(
    #     default=False,
    #     metadata={"help": "Whether to use 8-bit precision"}
    # )
    # use_4bit: bool = field(
    #     default=False,
    #     metadata={"help": "Whether to use 4-bit precision"}
    # )


def main():
    # Parse command-line arguments
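    # (HfArgumentParser turns each ScriptArguments field into a --<field_name>
    # option, using the field's metadata["help"] string as its help text.)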
    parser = HfArgumentParser(ScriptArguments)
    args = parser.parse_args_into_dataclasses()[0]
    
    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)
    
    # Prepare command for training
    cmd = [
        "python",
        "train_olmoe_adapter.py",
        
        # Model arguments
        f"--model_name_or_path={args.model_path}",
        f"--adapter_size={args.adapter_size}",
        "--freeze_base_model=True",  # Always freeze the base model
        f"--checkpoint_dir={args.output_dir}",
        
        # Data arguments
        f"--dataset_name={args.dataset_name}",
        "--streaming=True",  # Always stream for large datasets
        "--streaming_buffer_size=8192",
        "--max_seq_length=1024",
        
        # Training arguments
        f"--output_dir={args.output_dir}",
        f"--per_device_train_batch_size={args.per_device_batch_size}",
        f"--gradient_accumulation_steps={args.gradient_accumulation_steps}",
        f"--learning_rate={args.learning_rate}",
        f"--max_steps={args.max_steps}",
        "--warmup_steps=500",
        "--logging_steps=10",
        "--save_steps=1000",
        "--save_total_limit=2",
        "--dataloader_num_workers=4",
        "--seed=42",
    ]
    
    # Add precision flags if needed
    # if args.use_8bit:
    #     cmd.append("--load_in_8bit")
    
    # if args.use_4bit:
    #     cmd.append("--load_in_4bit")
    
    # Print the command for logging
    cmd_str = " ".join(cmd)
    print(f"Running command: {cmd_str}")
    
    # Execute the training script. Passing the argument list to subprocess.run
    # avoids shell quoting issues, and os.system's raw wait status is replaced
    # by a proper return code.
    env = os.environ.copy()
    env["PYTHONPATH"] = os.getcwd()
    result = subprocess.run(cmd, env=env)

    if result.returncode != 0:
        print(f"Training failed with exit code {result.returncode}")
        sys.exit(result.returncode)
    
    print("Training completed successfully!")


if __name__ == "__main__":
    main()