Rayugacodes committed on
Commit 9709cdc · verified · 1 Parent(s): 32a197f

GPU training script for HF

Files changed (1): train_on_hf.py (+378, -0)
train_on_hf.py ADDED
@@ -0,0 +1,378 @@
+ #!/usr/bin/env python3
+ """
+ KernelX — Full GPU Training Script for Hugging Face
+
+ Run this on an HF Space or notebook with a GPU (T4/A10/A100).
+ It handles the full pipeline: download data, train the World Model (SFT),
+ train the Strategist (warm-start SFT + GRPO), merge the LoRA adapter, and
+ push the results back to the HF Hub. GGUF export is a follow-up step.
+
+ Usage (on HF with GPU):
+     pip install torch transformers trl peft datasets accelerate huggingface_hub
+     python train_on_hf.py --hf-token YOUR_TOKEN
+ """
+
+ import argparse
+ import json
+ import os
+ import sys
+ from pathlib import Path
+
+
+ def setup(hf_token: str):
+     """Login and download data from HF."""
+     from huggingface_hub import login, hf_hub_download, snapshot_download
+     login(token=hf_token)
+
+     # Download training data
+     data_dir = Path("data")
+     data_dir.mkdir(exist_ok=True)
+
+     for fname in ["state_transitions.jsonl", "train.jsonl", "val.jsonl", "test.jsonl", "preprocessing_config.json"]:
+         path = hf_hub_download(
+             repo_id="Rayugacodes/kernelx-training-data",
+             filename=fname,
+             repo_type="dataset",
+             local_dir=str(data_dir),
+         )
+         print(f"Downloaded {fname}")
+
+     # Download training scripts
+     snapshot_download(
+         repo_id="Rayugacodes/kernelx-strategist",
+         local_dir="model_repo",
+         allow_patterns=["training/**"],
+     )
+     print("Downloaded training scripts")
+
+     return data_dir
+
+
+ def train_world_model(data_dir: Path, max_samples: int = 50000):
+     """Stage 2: Train World Model via SFT."""
+     from datasets import Dataset
+     from transformers import AutoModelForCausalLM, AutoTokenizer
+     from peft import LoraConfig
+     from trl import SFTTrainer, SFTConfig
+
+     config = json.load(open(data_dir / "preprocessing_config.json"))
+     MODEL_NAME = config["model"]["name"]
+     FEATURE_NAMES = config["feature_names"]
+
+     def format_state(features):
+         parts = []
+         for name, val in zip(FEATURE_NAMES, features):
+             if val == int(val):
+                 parts.append(f"{name}:{int(val)}")
+             else:
+                 parts.append(f"{name}:{val:.2f}")
+         return " | ".join(parts)
+
+     def make_sft_example(record):
+         state_str = format_state(record["state"])
+         action_str = f"{record['action']:.4f}"
+         next_state_str = format_state(record["next_state"])
+         text = (
+             "<|system|>You are a Linux kernel simulator. "
+             "Predict the next system state.<|end|>\n"
+             f"<|user|>[STATE] {state_str}\n"
+             f"[ACTION] {action_str}\n"
+             f"[PID] {record['pid']}\n"
+             "Predict [NEXT_STATE]<|end|>\n"
+             f"<|assistant|>[NEXT_STATE] {next_state_str}<|end|>"
+         )
+         return {"text": text}
+
+     print("\n=== Stage 2: World Model SFT ===")
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+     model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="auto")
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     train_records = [json.loads(l) for l in open(data_dir / "train.jsonl") if l.strip()][:max_samples]
+     val_records = [json.loads(l) for l in open(data_dir / "val.jsonl") if l.strip()][:max_samples // 8]
+
+     train_dataset = Dataset.from_list([make_sft_example(r) for r in train_records])
+     val_dataset = Dataset.from_list([make_sft_example(r) for r in val_records])
+     print(f" Train: {len(train_dataset)} Val: {len(val_dataset)}")
+
+     lora_config = LoraConfig(
+         r=16, lora_alpha=32,
+         target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
+                         "gate_proj", "up_proj", "down_proj"],
+         lora_dropout=0.05, bias="none", task_type="CAUSAL_LM",
+     )
+
+     training_args = SFTConfig(
+         output_dir="./world_model_checkpoints",
+         num_train_epochs=3,
+         per_device_train_batch_size=8,
+         gradient_accumulation_steps=2,
+         learning_rate=2e-4,
+         lr_scheduler_type="cosine",
+         warmup_ratio=0.1,
+         logging_steps=10,
+         eval_strategy="steps",
+         eval_steps=200,
+         save_steps=500,
+         save_total_limit=2,
+         fp16=True,
+         max_length=512,
+         report_to="none",
+     )
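+     # Effective batch: 8 per device x 2 accumulation steps = 16 sequences per
+     # optimizer step (per GPU).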
+
+     trainer = SFTTrainer(
+         model=model, args=training_args,
+         train_dataset=train_dataset, eval_dataset=val_dataset,
+         peft_config=lora_config,
+     )
+
+     trainer.train()
+     trainer.save_model("./world_model_final")
+     tokenizer.save_pretrained("./world_model_final")
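+     # Note: with a peft_config, save_model() writes only the LoRA adapter (and
+     # its config) to ./world_model_final, not merged base weights.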
+     print("World Model saved.")
+     return model, tokenizer
+
+
+ def train_strategist(data_dir: Path, max_samples: int = 10000):
+     """Stage 3: Warm-start SFT + GRPO for the Strategist."""
+     import re
+     import random
+     import numpy as np
+     from datasets import Dataset
+     from transformers import AutoModelForCausalLM, AutoTokenizer
+     from peft import LoraConfig
+     from trl import SFTTrainer, SFTConfig, GRPOConfig, GRPOTrainer
+
+     config = json.load(open(data_dir / "preprocessing_config.json"))
+     MODEL_NAME = config["model"]["name"]
+     FEATURE_NAMES = config["feature_names"]
+     IDX_WAIT_US = 9
+     IDX_CTX_SWITCHES = 8
+     IDX_EXEC_NS = 4
+
+     def format_state(features):
+         parts = []
+         for name, val in zip(FEATURE_NAMES, features):
+             if val == int(val):
+                 parts.append(f"{name}:{int(val)}")
+             else:
+                 parts.append(f"{name}:{val:.2f}")
+         return " | ".join(parts)
+
+     def build_prompt(state, pid, cpu):
+         state_str = format_state(state)
+         return (
+             "<|system|>You are a Linux kernel scheduling strategist. "
+             "Given the current system state, output a scheduling action.<|end|>\n"
+             f"<|user|>[STATE] {state_str}\n"
+             f"[PID] {pid} [CPU] {cpu}\n"
+             "[ACTION]<|end|>\n"
+             "<|assistant|>"
+         )
+
+     def parse_action(text):
+         m = re.search(r"\[ACTION\]\s*([-+]?\d*\.?\d+)", text)
+         if not m:
+             m = re.search(r"([-+]?\d*\.?\d+)", text)
+         if not m:
+             raise ValueError("No action found")
+         return float(m.group(1))
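+     # e.g. parse_action("[ACTION] -0.2500") -> -0.25; a bare "-0.25" also
+     # parses via the fallback regex.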
+
+     # Load data
+     all_records = [json.loads(l) for l in open(data_dir / "train.jsonl") if l.strip()]
+     records = random.sample(all_records, min(max_samples, len(all_records)))
+     print(f"\n=== Stage 3: Strategist Training ({len(records)} samples) ===")
+
+     # --- Phase 1: Warm-start SFT ---
+     print("\n--- Phase 1: Warm-start SFT ---")
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+     model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="auto")
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
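+     # Heuristic teacher for warm-start: map a few coarse state buckets (long
+     # waits, heavy context switching, idle) to fixed actions so the model
+     # learns the output format before GRPO shapes the actual policy.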
+     warmstart_examples = []
+     for rec in records[:500]:
+         state = rec["state"]
+         wait_us = state[IDX_WAIT_US]
+         csw = state[IDX_CTX_SWITCHES]
+         if wait_us > 15:
+             action = -0.6
+         elif csw > 10:
+             action = -0.3
+         elif wait_us < 3:
+             action = 0.1
+         else:
+             action = 0.05
+         prompt = build_prompt(state, rec["pid"], rec["cpu"])
+         warmstart_examples.append({"text": f"{prompt}{action:.4f}<|end|>"})
+
+     ws_dataset = Dataset.from_list(warmstart_examples)
+
+     lora_config = LoraConfig(
+         r=16, lora_alpha=32,
+         target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
+                         "gate_proj", "up_proj", "down_proj"],
+         lora_dropout=0.05, bias="none", task_type="CAUSAL_LM",
+     )
+
+     ws_args = SFTConfig(
+         output_dir="./strategist_warmstart",
+         num_train_epochs=2,
+         per_device_train_batch_size=8,
+         gradient_accumulation_steps=2,
+         learning_rate=2e-4,
+         fp16=True,
+         max_length=512,
+         logging_steps=5,
+         save_steps=100,
+         report_to="none",
+     )
+
+     trainer = SFTTrainer(
+         model=model, args=ws_args,
+         train_dataset=ws_dataset, peft_config=lora_config,
+     )
+     trainer.train()
+     trainer.save_model("./strategist_warmstart")
+     tokenizer.save_pretrained("./strategist_warmstart")
+     print("Warm-start complete.")
+
+     # --- Phase 2: GRPO ---
+     print("\n--- Phase 2: GRPO RL Training ---")
+
+     # Reload the warm-start weights explicitly: SFTTrainer saved only the
+     # adapter, and handing the Phase 1 model object straight to GRPOTrainer
+     # with a fresh peft_config would stack adapters. Merge, then re-adapt.
+     from peft import PeftModel
+     base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="auto")
+     model = PeftModel.from_pretrained(base_model, "./strategist_warmstart").merge_and_unload()
+
+     # Build nearest-neighbor simulator from data
+     all_states = np.array([r["state"] for r in records])
+     all_next_states = [r["next_state"] for r in records]
+
+     def simulate(state_features, action_val):
+         # Cheap proxy simulator: nearest neighbor over a 500-state subsample.
+         # Note it ignores action_val; a trained World Model would condition on it.
+         state_arr = np.array(state_features)
+         dists = np.linalg.norm(all_states[:500] - state_arr, axis=1)
+         return all_next_states[int(np.argmin(dists))]
+
+     def reward_fn(completions, prompts, **kwargs):
+         # TRL passes extra dataset columns / metadata to reward functions as
+         # keyword arguments, so accept **kwargs rather than a fixed signature.
+         rewards = []
+         for prompt, completion in zip(prompts, completions):
+             try:
+                 # Parse state from prompt
+                 state_match = re.search(r"\[STATE\]\s*(.+?)(?:\n|$)", prompt)
+                 values = []
+                 for part in state_match.group(1).split("|"):
+                     part = part.strip()
+                     if ":" in part:
+                         values.append(float(part.split(":")[1]))
+
+                 action_val = parse_action(completion)
+                 next_state = simulate(values, action_val)
+
+                 # Reward: throughput + latency + stability + format
+                 exec_delta = next_state[IDX_EXEC_NS] - values[IDX_EXEC_NS]
+                 r_throughput = float(np.log(max(0.0, exec_delta) + 1))
+                 wait_delta = next_state[IDX_WAIT_US] - values[IDX_WAIT_US]
+                 r_latency = -2.0 * max(0.0, wait_delta)
+                 r_stability = -0.5 * abs(action_val)
+                 r_format = 1.0 if -1.0 <= action_val <= 1.0 else 0.0
+
+                 rewards.append(r_throughput + r_latency + r_stability + r_format)
+             except (ValueError, IndexError, AttributeError):
+                 rewards.append(-5.0)
+         return rewards
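+     # Worked example (illustrative numbers): exec_ns up 100, wait_us up 2,
+     # action -0.3 -> r_throughput = ln(101) ≈ 4.62, r_latency = -4.0,
+     # r_stability = -0.15, r_format = 1.0, total ≈ 1.47.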
+
+     prompt_dataset = Dataset.from_list([
+         {"prompt": build_prompt(r["state"], r["pid"], r["cpu"])}
+         for r in records
+     ])
+
+     grpo_lora = LoraConfig(
+         r=16, lora_alpha=32,
+         target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
+                         "gate_proj", "up_proj", "down_proj"],
+         lora_dropout=0.05, bias="none", task_type="CAUSAL_LM",
+     )
+
+     grpo_config = GRPOConfig(
+         output_dir="./strategist_grpo",
+         num_train_epochs=1,
+         per_device_train_batch_size=2,
+         gradient_accumulation_steps=8,
+         learning_rate=5e-6,
+         num_generations=4,
+         max_completion_length=16,
+         max_prompt_length=384,
+         logging_steps=5,
+         save_steps=200,
+         save_total_limit=2,
+         temperature=0.7,
+         fp16=True,
+         report_to="none",
+     )
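+     # Rough batch math (semantics vary across TRL versions, so verify): 2
+     # completions/device x 8 accumulation steps = 16 completions per step,
+     # i.e. 16 / num_generations(4) = 4 prompt groups; GRPO computes its
+     # advantage baseline within each 4-completion group.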
+
+     grpo_trainer = GRPOTrainer(
+         model=model,
+         args=grpo_config,
+         train_dataset=prompt_dataset,
+         reward_funcs=reward_fn,
+         peft_config=grpo_lora,
+     )
+
+     grpo_trainer.train()
+     grpo_trainer.save_model("./strategist_final")
+     tokenizer.save_pretrained("./strategist_final")
+     print("GRPO training complete.")
+
+     return model, tokenizer
+
+
+ def merge_and_push(hf_token: str):
+     """Merge LoRA, push merged model to HF Hub."""
+     from transformers import AutoModelForCausalLM, AutoTokenizer
+     from peft import PeftModel
+     from huggingface_hub import login
+     login(token=hf_token)
+
+     config = json.load(open("data/preprocessing_config.json"))
+     MODEL_NAME = config["model"]["name"]
+
+     print("\n=== Merging LoRA and pushing to HF ===")
+     base = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="cpu")
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+     model = PeftModel.from_pretrained(base, "./strategist_final")
+     merged = model.merge_and_unload()
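+     # merge_and_unload() folds the LoRA deltas into the base weights, so the
+     # pushed checkpoint loads with plain AutoModelForCausalLM (no peft needed).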
340
+
341
+ merged.save_pretrained("./strategist_merged")
342
+ tokenizer.save_pretrained("./strategist_merged")
343
+
344
+ merged.push_to_hub("Rayugacodes/kernelx-strategist", commit_message="Merged strategist (warm-start + GRPO)")
345
+ tokenizer.push_to_hub("Rayugacodes/kernelx-strategist", commit_message="Tokenizer")
346
+ print("Pushed to https://huggingface.co/Rayugacodes/kernelx-strategist")
347
+
348
+
+ def main():
+     parser = argparse.ArgumentParser(description="KernelX GPU Training on HF")
+     parser.add_argument("--hf-token", required=True, help="HuggingFace token")
+     parser.add_argument("--world-model-samples", type=int, default=50000)
+     parser.add_argument("--strategist-samples", type=int, default=10000)
+     parser.add_argument("--skip-world-model", action="store_true")
+     parser.add_argument("--skip-strategist", action="store_true")
+     parser.add_argument("--skip-merge", action="store_true")
+     args = parser.parse_args()
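+     # Example invocations (flags as defined above):
+     #   python train_on_hf.py --hf-token hf_xxx
+     #   python train_on_hf.py --hf-token hf_xxx --skip-world-model   # strategist only
+     #   python train_on_hf.py --hf-token hf_xxx --skip-merge         # train, don't push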
+
+     # Setup
+     data_dir = setup(args.hf_token)
+
+     # Train
+     if not args.skip_world_model:
+         train_world_model(data_dir, max_samples=args.world_model_samples)
+
+     if not args.skip_strategist:
+         train_strategist(data_dir, max_samples=args.strategist_samples)
+
+     if not args.skip_merge:
+         merge_and_push(args.hf_token)
+
+     print("\n=== All done! ===")
+     print("Model: https://huggingface.co/Rayugacodes/kernelx-strategist")
+     print("Next: convert to GGUF for sub-50ms CPU inference")
+
+
+ if __name__ == "__main__":
+     main()