stmasson committed on
Commit
2cd4c7b
·
verified ·
1 Parent(s): f5c4f4a

Upload scripts/train_qwen3_dpo_reasoning.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/train_qwen3_dpo_reasoning.py +239 -0
scripts/train_qwen3_dpo_reasoning.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "transformers>=4.51.0",
5
+ # "trl>=0.12.0",
6
+ # "peft>=0.13.0",
7
+ # "datasets>=3.0.0",
8
+ # "accelerate>=1.0.0",
9
+ # "huggingface_hub>=0.26.0",
10
+ # "torch>=2.4.0",
11
+ # "bitsandbytes>=0.44.0",
12
+ # ]
13
+ # [tool.uv]
14
+ # index-strategy = "unsafe-best-match"
15
+ # extra-index-url = ["https://download.pytorch.org/whl/cu124"]
16
+ # ///
17
+ """
18
+ DPO Training Script for Qwen3-0.6B on n8n Workflow Reasoning
19
+
20
+ This script fine-tunes Qwen3-0.6B using Direct Preference Optimization (DPO)
21
+ to improve reasoning quality when generating n8n workflows.
22
+
23
+ The dataset contains:
24
+ - prompt: task description for generating n8n workflow
25
+ - chosen: high-quality response with detailed <thinking> reasoning
26
+ - rejected: low-quality response with superficial reasoning or errors
27
+
28
+ Usage:
29
+ hf jobs uv run \
30
+ --script train_qwen3_dpo_reasoning.py \
31
+ --flavor l40sx1 \
32
+ --name qwen3-dpo-reasoning \
33
+ --timeout 12h
34
+ """
35
+
36
+ import os
37
+ import torch
38
+ from datasets import load_dataset
39
+ from transformers import AutoModelForCausalLM, AutoTokenizer
40
+ from peft import LoraConfig
41
+ from trl import DPOConfig, DPOTrainer
42
+ from huggingface_hub import login
43
+
44
+ # ============================================================================
45
+ # CONFIGURATION
46
+ # ============================================================================
47
+
48
+ # Base model
49
+ MODEL_NAME = os.environ.get("BASE_MODEL", "Qwen/Qwen3-0.6B")
50
+
51
+ # Dataset
52
+ DATASET_REPO = "stmasson/n8n-workflows-thinking"
53
+ DATA_DIR = "data/dpo"
54
+
55
+ # Output
56
+ OUTPUT_DIR = "./qwen3-dpo-reasoning"
57
+ HF_REPO = os.environ.get("HF_REPO", "stmasson/qwen3-0.6b-n8n-reasoning")
58
+
59
+ # Hyperparameters
60
+ NUM_EPOCHS = int(os.environ.get("NUM_EPOCHS", "1"))
61
+ BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "1"))
62
+ GRAD_ACCUM = int(os.environ.get("GRAD_ACCUM", "8"))
63
+ LEARNING_RATE = float(os.environ.get("LEARNING_RATE", "5e-6"))
64
+ MAX_LENGTH = int(os.environ.get("MAX_LENGTH", "4096"))
65
+ MAX_PROMPT_LENGTH = int(os.environ.get("MAX_PROMPT_LENGTH", "512"))
66
+ BETA = float(os.environ.get("BETA", "0.1")) # DPO beta parameter
67
+
68
+ # LoRA configuration
69
+ LORA_R = int(os.environ.get("LORA_R", "32"))
70
+ LORA_ALPHA = int(os.environ.get("LORA_ALPHA", "64"))
71
+ LORA_DROPOUT = float(os.environ.get("LORA_DROPOUT", "0.05"))
72
+
73
+ # ============================================================================
74
+ # AUTHENTICATION
75
+ # ============================================================================
76
+
77
+ print("=" * 60)
78
+ print("DPO TRAINING - QWEN3-0.6B N8N REASONING")
79
+ print("=" * 60)
80
+
81
+ hf_token = os.environ.get("HF_TOKEN")
82
+ if hf_token:
83
+ login(token=hf_token)
84
+ print("Authenticated with HuggingFace")
85
+ else:
86
+ print("Warning: HF_TOKEN not set, push disabled")
87
+
88
+ # ============================================================================
89
+ # LOAD MODEL AND TOKENIZER
90
+ # ============================================================================
91
+
92
+ print(f"\nLoading model: {MODEL_NAME}")
93
+
94
+ model = AutoModelForCausalLM.from_pretrained(
95
+ MODEL_NAME,
96
+ torch_dtype=torch.bfloat16,
97
+ attn_implementation="sdpa",
98
+ device_map="auto",
99
+ trust_remote_code=True,
100
+ )
101
+
102
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
103
+ if tokenizer.pad_token is None:
104
+ tokenizer.pad_token = tokenizer.eos_token
105
+ tokenizer.padding_side = "left" # Important for DPO
106
+
107
+ print(f"Model loaded: {model.config.num_hidden_layers} layers, {model.config.hidden_size} hidden size")
108
+
109
+ # ============================================================================
110
+ # LORA CONFIGURATION
111
+ # ============================================================================
112
+
113
+ print(f"\nLoRA config: r={LORA_R}, alpha={LORA_ALPHA}")
114
+
115
+ peft_config = LoraConfig(
116
+ r=LORA_R,
117
+ lora_alpha=LORA_ALPHA,
118
+ target_modules=[
119
+ "q_proj", "k_proj", "v_proj", "o_proj",
120
+ "gate_proj", "up_proj", "down_proj"
121
+ ],
122
+ lora_dropout=LORA_DROPOUT,
123
+ bias="none",
124
+ task_type="CAUSAL_LM"
125
+ )
126
+
127
+ # ============================================================================
128
+ # LOAD DATASET
129
+ # ============================================================================
130
+
131
+ print(f"\nLoading dataset: {DATASET_REPO}")
132
+
133
+ train_dataset = load_dataset(DATASET_REPO, data_dir=DATA_DIR, split="train")
134
+ eval_dataset = load_dataset(DATASET_REPO, data_dir=DATA_DIR, split="validation")
135
+
136
+ print(f"Train: {len(train_dataset)} examples")
137
+ print(f"Validation: {len(eval_dataset)} examples")
138
+
139
+ # Filter out extremely long examples to avoid OOM
140
def filter_by_length(example, max_chars=50000):
    """Return True when a preference example is short enough to keep.

    Size is measured in characters, not tokens: the prompt length plus the
    length of the longer of the two responses must be strictly below
    ``max_chars``. An example whose total equals ``max_chars`` is dropped.

    Args:
        example: Mapping with "prompt", "chosen" and "rejected" entries;
            ``len()`` is applied to each entry.
        max_chars: Exclusive upper bound on the combined character count.
            The default of 50000 keeps the original cut-off (roughly 12500
            tokens by the original author's estimate).

    Returns:
        bool: True to keep the example, False to filter it out.
    """
    # Only the longer response matters: it determines the worst-case
    # sequence length for the pair.
    longest_response = max(len(example["chosen"]), len(example["rejected"]))
    return len(example["prompt"]) + longest_response < max_chars
146
+
147
+ train_dataset = train_dataset.filter(filter_by_length)
148
+ eval_dataset = eval_dataset.filter(filter_by_length)
149
+
150
+ print(f"After filtering - Train: {len(train_dataset)}, Val: {len(eval_dataset)}")
151
+
152
+ # Show example
153
+ print("\nExample prompt:", train_dataset[0]["prompt"][:100], "...")
154
+
155
+ # ============================================================================
156
+ # DPO TRAINING CONFIGURATION
157
+ # ============================================================================
158
+
159
+ print(f"\nTraining configuration:")
160
+ print(f" - Epochs: {NUM_EPOCHS}")
161
+ print(f" - Batch size: {BATCH_SIZE}")
162
+ print(f" - Gradient accumulation: {GRAD_ACCUM}")
163
+ print(f" - Effective batch size: {BATCH_SIZE * GRAD_ACCUM}")
164
+ print(f" - Learning rate: {LEARNING_RATE}")
165
+ print(f" - Max length: {MAX_LENGTH}")
166
+ print(f" - DPO beta: {BETA}")
167
+
168
+ training_args = DPOConfig(
169
+ output_dir=OUTPUT_DIR,
170
+ num_train_epochs=NUM_EPOCHS,
171
+ per_device_train_batch_size=BATCH_SIZE,
172
+ per_device_eval_batch_size=BATCH_SIZE,
173
+ gradient_accumulation_steps=GRAD_ACCUM,
174
+ learning_rate=LEARNING_RATE,
175
+ lr_scheduler_type="cosine",
176
+ warmup_ratio=0.1,
177
+ weight_decay=0.01,
178
+ bf16=True,
179
+ tf32=True,
180
+ logging_steps=10,
181
+ save_strategy="steps",
182
+ save_steps=500,
183
+ save_total_limit=3,
184
+ eval_strategy="steps",
185
+ eval_steps=500,
186
+ max_length=MAX_LENGTH,
187
+ max_prompt_length=MAX_PROMPT_LENGTH,
188
+ beta=BETA,
189
+ loss_type="sigmoid", # Standard DPO loss
190
+ gradient_checkpointing=True,
191
+ gradient_checkpointing_kwargs={"use_reentrant": False},
192
+ report_to="none",
193
+ run_name="qwen3-dpo-reasoning",
194
+ hub_model_id=HF_REPO if hf_token else None,
195
+ push_to_hub=bool(hf_token),
196
+ hub_strategy="checkpoint",
197
+ )
198
+
199
+ # ============================================================================
200
+ # TRAINING
201
+ # ============================================================================
202
+
203
+ print("\nInitializing DPO trainer...")
204
+
205
+ trainer = DPOTrainer(
206
+ model=model,
207
+ args=training_args,
208
+ train_dataset=train_dataset,
209
+ eval_dataset=eval_dataset,
210
+ peft_config=peft_config,
211
+ processing_class=tokenizer,
212
+ )
213
+
214
+ # Show trainable parameters
215
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
216
+ total_params = sum(p.numel() for p in model.parameters())
217
+ print(f"\nTrainable parameters: {trainable_params:,} / {total_params:,} ({100 * trainable_params / total_params:.2f}%)")
218
+
219
+ print("\n" + "=" * 60)
220
+ print("STARTING DPO TRAINING")
221
+ print("=" * 60)
222
+
223
+ trainer.train()
224
+
225
+ # ============================================================================
226
+ # SAVE MODEL
227
+ # ============================================================================
228
+
229
+ print("\nSaving model...")
230
+ trainer.save_model(f"{OUTPUT_DIR}/final")
231
+
232
+ if hf_token:
233
+ print(f"Pushing to {HF_REPO}...")
234
+ trainer.push_to_hub()
235
+ print(f"Model available at: https://huggingface.co/{HF_REPO}")
236
+
237
+ print("\n" + "=" * 60)
238
+ print("DPO TRAINING COMPLETE")
239
+ print("=" * 60)