File size: 5,466 Bytes
0ded6bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc29a51
0ded6bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b4f1cb3
c68717e
 
0ded6bb
 
 
 
c68717e
0ded6bb
 
d455d12
 
 
 
0ded6bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81f39f1
 
0ded6bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c68717e
0ded6bb
 
 
 
 
 
 
 
 
 
c68717e
 
0ded6bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""
GPT-OSS Medical o1 SFT Training Configuration
Dataset: FreedomIntelligence/medical-o1-reasoning-SFT
Format: Question | Complex_CoT | Response β†’ GPT-OSS Harmony text

This configuration uses GPT-OSS Harmony formatting to combine the medical
dataset's question, chain-of-thought (Complex_CoT), and final response into a
single assistant turn, with optional system and developer messages.
"""

from config.train_gpt_oss_custom import GPTOSSEnhancedCustomConfig

# Medical-o1 SFT configuration for GPT-OSS.
# Combines Question + Complex_CoT + Response into a single Harmony-formatted
# assistant turn (see module docstring). All values are passed as keyword
# arguments to the project-level GPTOSSEnhancedCustomConfig.
config = GPTOSSEnhancedCustomConfig(
    # ============================================================================
    # DATASET CONFIGURATION
    # ============================================================================
    dataset_name="FreedomIntelligence/medical-o1-reasoning-SFT",
    dataset_config="en",               # Use English split by default (can be changed to en_mix/zh/zh_mix)
    dataset_split="train",
    dataset_format="medical_o1_sft",   # Enable medical formatter in training script

    # Field mapping and prefixes
    input_field="Question",            # used for length filtering pre-format
    target_field="Response",           # used for length filtering pre-format
    question_field="Question",
    reasoning_field="Complex_CoT",
    response_field="Response",
    reason_prefix="Reasoning: ",
    answer_prefix="Final Answer: ",

    # GPT-OSS Harmony formatting
    use_harmony_format=True,
    use_chat_template=False,           # Harmony formatter is used instead of the tokenizer chat template
    system_message=(
        "You are GPT-Tonic, a large language model trained by TonicAI."
    ),
    developer_message=(
        # NOTE: fixed doubled word ("You are are" -> "You are")
        "You are GPT-Tonic, an intelligent assistant that always answers health-related queries scientifically."
    ),
    chat_template_kwargs={
        "add_generation_prompt": True,
        "tokenize": False,
        "reasoning_effort": "low",
        "model_identity": "You are GPT-Tonic, a large language model trained by TonicAI.",
        "builtin_tools": [],
    },

    # Filtering & sampling
    filter_bad_entries=False,
    max_samples=None,                  # None = use the full split
    min_length=10,
    max_length=2048,

    # ============================================================================
    # TRAINING HYPERPARAMETERS
    # ============================================================================
    num_train_epochs=2.0,
    batch_size=4,
    gradient_accumulation_steps=4,     # effective batch size = 4 x 4 = 16
    learning_rate=2e-4,
    min_lr=2e-5,
    weight_decay=0.01,
    warmup_ratio=0.03,                 # both ratio and steps provided; which wins depends on the trainer -- TODO confirm
    warmup_steps=50,
    max_grad_norm=1.0,

    # Scheduler: use broadly compatible cosine by default to avoid TRL signature issues
    scheduler="cosine",
    lr_scheduler_kwargs={},

    # Sequence length
    max_seq_length=2048,

    # ============================================================================
    # MIXED PRECISION / PERFORMANCE
    # ============================================================================
    fp16=False,
    bf16=True,                         # bf16 preferred over fp16 for this setup
    tf32=True,

    dataloader_num_workers=4,
    dataloader_pin_memory=True,
    dataloader_prefetch_factor=2,
    group_by_length=True,
    remove_unused_columns=True,

    # ============================================================================
    # LORA & QUANTIZATION
    # ============================================================================
    use_lora=True,
    lora_config={
        "r": 16,
        "lora_alpha": 32,              # alpha = 2 * r
        "lora_dropout": 0.05,
        "target_modules": "all-linear",
        # Additionally target MoE expert projections at a few specific layers
        "target_parameters": [
            "7.mlp.experts.gate_up_proj",
            "7.mlp.experts.down_proj",
            "15.mlp.experts.gate_up_proj",
            "15.mlp.experts.down_proj",
            "23.mlp.experts.gate_up_proj",
            "23.mlp.experts.down_proj",
        ],
        "bias": "none",
        "task_type": "CAUSAL_LM",
    },

    use_quantization=True,
    quantization_config={
        "dequantize": True,
        "load_in_4bit": False,
        # Optional MXFP4 config is auto-applied by training script if available
    },

    # ============================================================================
    # LOGGING & EVAL
    # ============================================================================
    eval_strategy="steps",
    eval_steps=100,
    logging_steps=10,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,
    save_only_model=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,           # lower eval_loss is better
    load_best_model_at_end=False,
    eval_accumulation_steps=2,
    eval_batch_size=1,
    eval_ratio=0.001,                  # tiny held-out eval slice
    test_ratio=0.0005,

    # ============================================================================
    # MONITORING & HUB
    # ============================================================================
    enable_tracking=True,
    log_artifacts=False,
    log_metrics=True,
    log_config=True,
    push_to_hub=False,
    hub_model_id=None,
    hub_private_repo=False,
)

# Quick summary for visibility when the config is imported.
# Each entry is printed on its own line; strings are kept byte-identical
# so the console output matches the original layout exactly.
_summary = [
    "\n🩺 GPT-OSS Medical o1 SFT Configuration",
    "=" * 60,
    f"📊 Dataset: {config.dataset_name} [{config.dataset_config}] (medical_o1_sft)",
    f"📈 Training: {config.num_train_epochs} epoch | batch {config.batch_size} x acc {config.gradient_accumulation_steps}",
    f"🧠 LoRA Rank: {config.lora_config['r']}",
    f"📏 Sequence Length: {config.max_seq_length}",
    f"🎡 Harmony Format: {'Enabled' if config.use_harmony_format else 'Disabled'}",
    "=" * 60,
]
for _line in _summary:
    print(_line)