sukritvemula committed on
Commit
d2983fa
·
verified ·
1 Parent(s): 063cab2

Add validated training script

Browse files
Files changed (1) hide show
  1. train_coding_agent.py +267 -0
train_coding_agent.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Qwen3-8B Coding & Agentic Reasoning Expert — Multi-Dataset SFT Training
3
+ ========================================================================
4
+ Base: Qwen/Qwen3-8B (Apache 2.0, 8.2B params, 32K context)
5
+ Method: QLoRA SFT with assistant-only loss masking
6
+ Datasets:
7
+ - TIGER-Lab/VisCode-200K (visualization/chart generation) — ChatML ready
8
+ - m-a-p/CodeFeedback-Filtered-Instruction (code instruction tuning)
9
+ - nvidia/OpenCodeReasoning (reasoning with <think> blocks)
10
+ - glaiveai/glaive-function-calling-v2 (tool calling)
11
+ - ise-uiuc/Magicoder-OSS-Instruct-75K (code generation)
12
+
13
+ Recipe: Based on Qwen3-Coder-Next + LoRA Without Regret papers
14
+ Target: Coding + agentic reasoning + visualization + tool-use expert
15
+
16
+ Usage:
17
+ pip install "transformers>=4.51.0" "trl>=1.3.0" "peft>=0.15.0" datasets accelerate bitsandbytes torch trackio
18
+ HUB_MODEL_ID=your-username/model-name python train_coding_agent.py
19
+ """
20
+
21
+ import os
22
+ import re
23
+ import json
24
+ import torch
25
+ import trackio
26
+ from datasets import load_dataset, concatenate_datasets, Dataset
27
+ from transformers import AutoTokenizer, BitsAndBytesConfig, TrainerCallback
28
+ from trl import SFTTrainer, SFTConfig
29
+ from peft import LoraConfig, TaskType
30
+
31
# ============================================================
# Configuration
# ============================================================
# Base checkpoint handed to SFTTrainer as `model=` in main().
MODEL_ID = "Qwen/Qwen3-8B"
# Local directory for checkpoints and the final saved model.
OUTPUT_DIR = "./qwen3-8b-coding-agent"
# Hub repository to push to; override with the HUB_MODEL_ID environment variable.
HUB_MODEL_ID = os.environ.get("HUB_MODEL_ID", "sukritvemula/Qwen3-8B-CodeAgent")

# Training hyperparameters (from Qwen3 + LoRA Without Regret papers)
LEARNING_RATE = 2e-4
NUM_EPOCHS = 2
# Per-device batch size; main() combines it with GRAD_ACCUM accumulation steps.
BATCH_SIZE = 2
GRAD_ACCUM = 8
# Passed to SFTConfig as `max_length`.
MAX_LENGTH = 4096
# LoRA rank and alpha; main() builds the LoraConfig with use_rslora=True.
LORA_R = 64
LORA_ALPHA = 16
WARMUP_RATIO = 0.05

# Dataset proportions (~50K samples)
# Each value is the upper bound passed to the matching process_* loader,
# which slices the first N rows of its source split.
MAX_VISCODE = 12000
MAX_CODEFEEDBACK = 10000
MAX_OPENCODE = 10000
MAX_GLAIVE = 8000
MAX_MAGICODER = 10000

# System message prepended by the loaders; the Glaive loader uses the
# dataset's own system text instead.
SYSTEM_PROMPT = """You are an expert AI assistant specialized in coding, agentic reasoning, data visualization, and tool use. You can:
1. Write, debug, and explain code in any programming language
2. Reason step-by-step through complex problems using <think>...</think> blocks
3. Generate charts, graphs, and data visualizations using matplotlib, plotly, seaborn
4. Call functions and tools when needed, returning structured JSON for tool invocations
5. Search the web and read research papers to provide accurate, up-to-date information
6. Replicate images and diagrams programmatically

Always think carefully before responding. Be precise, avoid hallucination, and cite sources when possible."""
64
+
65
+
66
class AlertCallback(TrainerCallback):
    """Turn the training loss reported in trainer logs into trackio alerts.

    For every log event that carries a ``"loss"`` entry the callback:

    * announces the first observed loss ("Training Started");
    * reports divergence when the loss is NaN or greater than 20.0;
    * warns once if, after step 100, the loss is still above 90% of the
      first observed loss ("Slow Convergence");
    * warns once per plateau when the best loss has not been beaten for
      more than 200 optimizer steps ("Loss Stagnation");
    * posts a summary whenever the logged step is a multiple of 500.

    Attributes:
        best_loss: Lowest non-diverged loss seen so far.
        initial_loss: First loss value seen, or ``None`` before any log.
        steps_since_improvement: Optimizer steps elapsed between the step
            at which ``best_loss`` was last lowered and the latest log.
    """

    def __init__(self):
        self.best_loss = float('inf')
        self.initial_loss = None
        self.steps_since_improvement = 0
        # Global step at which best_loss was last lowered.
        self._best_step = 0
        # Latches so a persisting condition raises one alert, not one per log.
        self._slow_convergence_alerted = False
        self._stagnation_alerted = False

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Inspect one log event and raise whichever alerts apply.

        Args:
            args: Training arguments supplied by the trainer (unused).
            state: Trainer state; only ``state.global_step`` is read.
            control: Trainer control object (unused, never modified).
            logs: Mapping of logged values. Events without a ``"loss"``
                entry are ignored.
            **kwargs: Extra callback arguments (unused).
        """
        if logs is None:
            return
        loss = logs.get("loss")
        if loss is None:
            return
        step = state.global_step

        if self.initial_loss is None:
            self.initial_loss = loss
            trackio.alert(title="Training Started", text=f"Initial loss={loss:.4f} at step {step}. Model: {MODEL_ID}, lr={LEARNING_RATE}, batch={BATCH_SIZE}x{GRAD_ACCUM}={BATCH_SIZE*GRAD_ACCUM}", level="INFO")

        # `loss != loss` is true only for NaN.
        if loss != loss or loss > 20.0:
            trackio.alert(title="DIVERGENCE DETECTED", text=f"loss={loss} at step {step} — training has diverged. lr likely too high, try lr={LEARNING_RATE*0.1:.1e}", level="ERROR")
            return

        if loss < self.best_loss:
            self.best_loss = loss
            self._best_step = step
            self.steps_since_improvement = 0
            # A new best ends the plateau, so a later one may alert again.
            self._stagnation_alerted = False
        else:
            # Measure in optimizer steps rather than log events so the
            # threshold and the alert text mean what they say regardless of
            # the logging interval.
            self.steps_since_improvement = step - self._best_step

        # `initial_loss > 0` guards the percentage division below; it is also
        # false for a NaN initial loss, for which no baseline exists.
        if (
            not self._slow_convergence_alerted
            and step > 100
            and self.initial_loss > 0
            and loss > self.initial_loss * 0.9
        ):
            self._slow_convergence_alerted = True
            trackio.alert(title="Slow Convergence", text=f"loss={loss:.4f} at step {step}, only {((self.initial_loss - loss) / self.initial_loss * 100):.1f}% reduction from initial {self.initial_loss:.4f}. Consider lr={LEARNING_RATE*2:.1e}", level="WARN")

        if self.steps_since_improvement > 200 and not self._stagnation_alerted:
            self._stagnation_alerted = True
            trackio.alert(title="Loss Stagnation", text=f"No improvement for {self.steps_since_improvement} steps. Best loss={self.best_loss:.4f}, current={loss:.4f}.", level="WARN")

        if step > 0 and step % 500 == 0:
            trackio.alert(title="Training Milestone", text=f"Step {step}: loss={loss:.4f}, best_loss={self.best_loss:.4f}, lr={logs.get('learning_rate', 'N/A')}", level="INFO")
96
+
97
+
98
def process_viscode(max_samples):
    """Load the first ``max_samples`` rows of VisCode-200K as chat conversations.

    The dataset's ``messages`` column is used as-is, with ``SYSTEM_PROMPT``
    prepended when a conversation does not already start with a system
    message. Every other source column is dropped so the result has the
    same single-column ``messages`` layout as the other loaders in this
    file, which all strip their source columns.

    Args:
        max_samples: Number of leading rows of the ``train`` split to load.

    Returns:
        A dataset whose only column is ``messages``.
    """
    print(f"Loading VisCode-200K (max {max_samples})...")
    ds = load_dataset("TIGER-Lab/VisCode-200K", split=f"train[:{max_samples}]")

    def add_system(example):
        messages = example["messages"]
        # Empty conversations are passed through untouched.
        if messages and messages[0]["role"] != "system":
            messages = [{"role": "system", "content": SYSTEM_PROMPT}] + messages
        return {"messages": messages}

    # Keep `messages` (rewritten by add_system) and discard anything else the
    # source ships; an empty list here is a no-op.
    extra_columns = [name for name in ds.column_names if name != "messages"]
    ds = ds.map(add_system, remove_columns=extra_columns, num_proc=4)
    print(f" VisCode: {len(ds)} samples loaded")
    return ds
109
+
110
+
111
def process_codefeedback(max_samples):
    """Load the first ``max_samples`` rows of CodeFeedback as chat conversations.

    Each row becomes a three-turn conversation: ``SYSTEM_PROMPT`` as the
    system message, the row's ``query`` as the user message and its
    ``answer`` as the assistant message. All source columns are removed.

    Args:
        max_samples: Number of leading rows of the ``train`` split to load.

    Returns:
        A dataset whose only column is ``messages``.
    """
    print(f"Loading CodeFeedback (max {max_samples})...")
    raw = load_dataset(
        "m-a-p/CodeFeedback-Filtered-Instruction",
        split=f"train[:{max_samples}]",
    )

    def to_messages(example):
        system_turn = {"role": "system", "content": SYSTEM_PROMPT}
        user_turn = {"role": "user", "content": example["query"]}
        assistant_turn = {"role": "assistant", "content": example["answer"]}
        return {"messages": [system_turn, user_turn, assistant_turn]}

    source_columns = raw.column_names
    converted = raw.map(to_messages, remove_columns=source_columns, num_proc=4)
    print(f" CodeFeedback: {len(converted)} samples loaded")
    return converted
123
+
124
+
125
def process_opencode_reasoning(max_samples):
    """Load the first ``max_samples`` rows of OpenCodeReasoning as chat conversations.

    Reads the ``split_0`` configuration and split. Each row becomes a
    three-turn conversation: ``SYSTEM_PROMPT``, then the row's ``input`` as
    the user message and its ``output`` as the assistant message. All source
    columns are removed.

    Args:
        max_samples: Number of leading rows of ``split_0`` to load.

    Returns:
        A dataset whose only column is ``messages``.
    """
    print(f"Loading OpenCodeReasoning (max {max_samples})...")
    split_spec = f"split_0[:{max_samples}]"
    rows = load_dataset("nvidia/OpenCodeReasoning", "split_0", split=split_spec)

    def to_messages(example):
        conversation = []
        for role, content in (
            ("system", SYSTEM_PROMPT),
            ("user", example["input"]),
            ("assistant", example["output"]),
        ):
            conversation.append({"role": role, "content": content})
        return {"messages": conversation}

    rows = rows.map(to_messages, remove_columns=rows.column_names, num_proc=4)
    print(f" OpenCodeReasoning: {len(rows)} samples loaded")
    return rows
137
+
138
+
139
# Speaker markers found in a Glaive transcript, mapped to the chat role the
# turn is stored under and the text that turn starts with. Function results
# are folded into user turns, tagged with a "[Function Response] " prefix.
_GLAIVE_TURN_MARKERS = {
    "USER:": ("user", ""),
    "ASSISTANT:": ("assistant", ""),
    "FUNCTION RESPONSE:": ("user", "[Function Response] "),
}

# End-of-sequence marker the Glaive transcripts embed as literal text.
# NOTE(review): believed to follow every assistant turn in this dataset;
# confirm against the dataset card.
_GLAIVE_EOS_MARKER = "<|endoftext|>"


def _glaive_to_messages(system, chat):
    """Parse one Glaive ``system``/``chat`` pair into a list of chat messages.

    The transcript is split on its ``USER:`` / ``ASSISTANT:`` /
    ``FUNCTION RESPONSE:`` markers. Literal ``<|endoftext|>`` markers are
    removed from the text, consecutive turns with the same role are joined
    with a newline, and the leading ``SYSTEM:`` label is stripped from the
    system text.

    Args:
        system: The row's system text.
        chat: The row's flat transcript.

    Returns:
        The message list, or ``[]`` when the result has fewer than three
        messages or does not end with an assistant turn.
    """
    messages = [{"role": "system", "content": re.sub(r'^SYSTEM:\s*', '', system)}]

    turns = []
    role, text = None, ""
    for part in re.split(r'\n*(USER:|ASSISTANT:|FUNCTION RESPONSE:)', chat):
        part = part.strip()
        if not part:
            continue
        marker = _GLAIVE_TURN_MARKERS.get(part)
        if marker is None:
            # Plain text belongs to the turn currently being collected. Text
            # seen before the first marker has no role and is discarded when
            # that marker resets the buffer.
            text += part.replace(_GLAIVE_EOS_MARKER, "").strip()
            continue
        # A new speaker starts: store the finished turn unless it is empty.
        if role and text.strip():
            turns.append((role, text.strip()))
        role, text = marker
    if role and text.strip():
        turns.append((role, text.strip()))

    # Function responses share the "user" role, so they can land next to a
    # real user turn; merge such neighbours so roles alternate.
    for turn_role, content in turns:
        if messages[-1]["role"] == turn_role:
            messages[-1]["content"] += "\n" + content
        else:
            messages.append({"role": turn_role, "content": content})

    if len(messages) < 3 or messages[-1]["role"] != "assistant":
        return []
    return messages


def process_glaive_function_calling(max_samples):
    """Load the first ``max_samples`` rows of Glaive function-calling v2 as chats.

    Each row's ``system`` and ``chat`` fields are converted into a
    ``messages`` list. The dataset's own system text is used rather than
    ``SYSTEM_PROMPT``. Rows that do not yield a usable conversation are
    dropped, so the result may hold fewer than ``max_samples`` rows.

    Args:
        max_samples: Number of leading rows of the ``train`` split to load.

    Returns:
        A dataset whose only column is ``messages``.
    """
    print(f"Loading Glaive Function Calling (max {max_samples})...")
    ds = load_dataset("glaiveai/glaive-function-calling-v2", split=f"train[:{max_samples}]")

    def to_messages(example):
        return {"messages": _glaive_to_messages(example["system"], example["chat"])}

    ds = ds.map(to_messages, remove_columns=ds.column_names, num_proc=4)
    # Rejected rows come back with an empty list and fail the length check.
    ds = ds.filter(lambda x: len(x["messages"]) >= 3 and any(m["role"] == "assistant" for m in x["messages"]))
    print(f" Glaive Function Calling: {len(ds)} samples loaded")
    return ds
182
+
183
+
184
def process_magicoder(max_samples):
    """Load the first ``max_samples`` rows of Magicoder-OSS-Instruct as chats.

    Each row becomes a three-turn conversation: ``SYSTEM_PROMPT``, then the
    row's ``problem`` as the user message and its ``solution`` as the
    assistant message. All source columns are removed.

    Args:
        max_samples: Number of leading rows of the ``train`` split to load.

    Returns:
        A dataset whose only column is ``messages``.
    """
    print(f"Loading Magicoder (max {max_samples})...")
    dataset_name = "ise-uiuc/Magicoder-OSS-Instruct-75K"
    samples = load_dataset(dataset_name, split=f"train[:{max_samples}]")

    def to_messages(example):
        chat = [{"role": "system", "content": SYSTEM_PROMPT}]
        chat.append({"role": "user", "content": example["problem"]})
        chat.append({"role": "assistant", "content": example["solution"]})
        return {"messages": chat}

    samples = samples.map(
        to_messages,
        remove_columns=samples.column_names,
        num_proc=4,
    )
    print(f" Magicoder: {len(samples)} samples loaded")
    return samples
196
+
197
+
198
def main():
    """Build the mixed SFT dataset, fine-tune the base model with QLoRA, and publish it.

    Steps, in order:

    1. Run each dataset loader with its sample cap. A loader that raises is
       reported and skipped; if none succeed a ``ValueError`` is raised.
    2. Concatenate the loaded datasets and shuffle them with seed 42.
    3. Configure 4-bit NF4 quantization, a LoRA adapter and the SFT
       training arguments.
    4. Train, save the model to ``OUTPUT_DIR``, push it to the Hub and send
       a completion alert.

    Raises:
        ValueError: If every dataset loader failed.
    """
    banner = "=" * 60
    print(banner)
    print("Qwen3-8B Coding & Agentic Reasoning Expert Training")
    print(banner)

    # (loader, sample cap) pairs, in the order they are concatenated.
    sources = (
        (process_viscode, MAX_VISCODE),
        (process_codefeedback, MAX_CODEFEEDBACK),
        (process_opencode_reasoning, MAX_OPENCODE),
        (process_glaive_function_calling, MAX_GLAIVE),
        (process_magicoder, MAX_MAGICODER),
    )
    datasets_list = []
    for build, cap in sources:
        try:
            loaded = build(cap)
        except Exception as e:
            # Best effort: train on whatever subset of sources is available.
            print(f" ⚠️ Failed: {e}")
        else:
            datasets_list.append(loaded)

    if len(datasets_list) == 0:
        raise ValueError("No datasets loaded!")

    combined = concatenate_datasets(datasets_list)
    combined = combined.shuffle(seed=42)
    print(f"✅ Total training samples: {len(combined)}")

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )

    peft_config = LoraConfig(
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        lora_dropout=0.05,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
        target_modules="all-linear",
        use_rslora=True,
    )

    # Keyword arguments forwarded when the trainer instantiates MODEL_ID.
    model_init_kwargs = {
        "quantization_config": bnb_config,
        "device_map": "auto",
        "use_cache": False,
        "torch_dtype": torch.bfloat16,
    }

    training_args = SFTConfig(
        # Output and Hub publishing
        output_dir=OUTPUT_DIR,
        push_to_hub=True,
        hub_model_id=HUB_MODEL_ID,
        hub_strategy="every_save",
        # Sequence handling and loss masking
        max_length=MAX_LENGTH,
        packing=False,
        assistant_only_loss=True,
        # Optimization schedule
        num_train_epochs=NUM_EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRAD_ACCUM,
        learning_rate=LEARNING_RATE,
        lr_scheduler_type="cosine",
        warmup_ratio=WARMUP_RATIO,
        weight_decay=0.01,
        max_grad_norm=1.0,
        # Precision and memory
        bf16=True,
        tf32=True,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        # Logging, checkpointing, evaluation
        logging_steps=10,
        logging_first_step=True,
        disable_tqdm=True,
        save_strategy="steps",
        save_steps=500,
        save_total_limit=3,
        eval_strategy="no",
        report_to="trackio",
        run_name="sft-qwen3-8b-coding-agent-v1",
        # Model loading
        model_init_kwargs=model_init_kwargs,
        # Reproducibility and data loading
        seed=42,
        dataloader_num_workers=4,
        dataloader_pin_memory=True,
    )

    trainer = SFTTrainer(
        model=MODEL_ID,
        args=training_args,
        train_dataset=combined,
        peft_config=peft_config,
        callbacks=[AlertCallback()],
    )

    # Count all parameters and the trainable subset in a single pass.
    total_params = 0
    trainable_params = 0
    for p in trainer.model.parameters():
        count = p.numel()
        total_params += count
        if p.requires_grad:
            trainable_params += count
    print(f"Total: {total_params:,} | Trainable: {trainable_params:,} ({100 * trainable_params / total_params:.2f}%)")

    train_result = trainer.train()
    trainer.save_model(OUTPUT_DIR)
    trainer.push_to_hub(commit_message="Training complete: Qwen3-8B Coding Agent v1")

    metrics = train_result.metrics
    trackio.alert(title="Training Complete", text=f"Final loss={metrics.get('train_loss', 'N/A')}, hub_model={HUB_MODEL_ID}", level="INFO")
    print(f"✅ DONE! Model: https://huggingface.co/{HUB_MODEL_ID}")
264
+
265
+
266
+ if __name__ == "__main__":
267
+ main()