Fix: keep Colab defaults, install only trl+peft with --no-deps, handle all TRL versions
Browse files- KernelX_Training.ipynb +4 -43
KernelX_Training.ipynb
CHANGED
|
@@ -45,7 +45,7 @@
|
|
| 45 |
"execution_count": null,
|
| 46 |
"metadata": {},
|
| 47 |
"outputs": [],
|
| 48 |
-
"source": "#
|
| 49 |
},
|
| 50 |
{
|
| 51 |
"cell_type": "markdown",
|
|
@@ -182,46 +182,7 @@
|
|
| 182 |
"execution_count": null,
|
| 183 |
"metadata": {},
|
| 184 |
"outputs": [],
|
| 185 |
-
"source": [
|
| 186 |
-
"from datasets import Dataset\n",
|
| 187 |
-
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
|
| 188 |
-
"from peft import LoraConfig\n",
|
| 189 |
-
"from trl import SFTTrainer, SFTConfig\n",
|
| 190 |
-
"\n",
|
| 191 |
-
"MODEL_NAME = config['model']['name']\n",
|
| 192 |
-
"FEATURE_NAMES = config['feature_names']\n",
|
| 193 |
-
"\n",
|
| 194 |
-
"def format_state(features):\n",
|
| 195 |
-
" parts = []\n",
|
| 196 |
-
" for name, val in zip(FEATURE_NAMES, features):\n",
|
| 197 |
-
" if val == int(val):\n",
|
| 198 |
-
" parts.append(f'{name}:{int(val)}')\n",
|
| 199 |
-
" else:\n",
|
| 200 |
-
" parts.append(f'{name}:{val:.2f}')\n",
|
| 201 |
-
" return ' | '.join(parts)\n",
|
| 202 |
-
"\n",
|
| 203 |
-
"def make_world_model_example(record):\n",
|
| 204 |
-
" state_str = format_state(record['state'])\n",
|
| 205 |
-
" next_state_str = format_state(record['next_state'])\n",
|
| 206 |
-
" text = (\n",
|
| 207 |
-
" '<|system|>You are a Linux kernel simulator. '\n",
|
| 208 |
-
" 'Predict the next system state.<|end|>\\n'\n",
|
| 209 |
-
" f'<|user|>[STATE] {state_str}\\n'\n",
|
| 210 |
-
" f'[ACTION] {record[\"action\"]:.4f}\\n'\n",
|
| 211 |
-
" f'[PID] {record[\"pid\"]}\\n'\n",
|
| 212 |
-
" 'Predict [NEXT_STATE]<|end|>\\n'\n",
|
| 213 |
-
" f'<|assistant|>[NEXT_STATE] {next_state_str}<|end|>'\n",
|
| 214 |
-
" )\n",
|
| 215 |
-
" return {'text': text}\n",
|
| 216 |
-
"\n",
|
| 217 |
-
"# Use 10K samples for speed\n",
|
| 218 |
-
"MAX_SAMPLES = 10000\n",
|
| 219 |
-
"train_ds = Dataset.from_list([make_world_model_example(r) for r in train[:MAX_SAMPLES]])\n",
|
| 220 |
-
"val_ds = Dataset.from_list([make_world_model_example(r) for r in val[:MAX_SAMPLES // 8]])\n",
|
| 221 |
-
"\n",
|
| 222 |
-
"print(f'World Model dataset: train={len(train_ds)}, val={len(val_ds)}')\n",
|
| 223 |
-
"print(f'\\nSample:\\n{train_ds[0][\"text\"][:300]}...')"
|
| 224 |
-
]
|
| 225 |
},
|
| 226 |
{
|
| 227 |
"cell_type": "code",
|
|
@@ -245,7 +206,7 @@
|
|
| 245 |
"execution_count": null,
|
| 246 |
"metadata": {},
|
| 247 |
"outputs": [],
|
| 248 |
-
"source": "# Train World Model\nimport inspect\n\nlora_config = LoraConfig(\n r=16, lora_alpha=32,\n target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj',\n 'gate_proj', 'up_proj', 'down_proj'],\n lora_dropout=0.05, bias='none', task_type='CAUSAL_LM',\n)\n\n#
|
| 249 |
},
|
| 250 |
{
|
| 251 |
"cell_type": "markdown",
|
|
@@ -321,7 +282,7 @@
|
|
| 321 |
"execution_count": null,
|
| 322 |
"metadata": {},
|
| 323 |
"outputs": [],
|
| 324 |
-
"source": "# Reload fresh base model for Strategist\ntokenizer =
|
| 325 |
},
|
| 326 |
{
|
| 327 |
"cell_type": "markdown",
|
|
|
|
| 45 |
"execution_count": null,
|
| 46 |
"metadata": {},
|
| 47 |
"outputs": [],
|
| 48 |
+
"source": "# Check what Colab already has, only install what's missing\nimport subprocess, sys\n\n# Keep Colab's torch and transformers - don't touch them\n# Only install the small training libraries that are missing\n!pip install -q --no-deps trl peft\n!pip install -q datasets huggingface_hub\n\n# Verify\nimport torch, transformers\nprint(f'torch={torch.__version__} (Colab default - keeping it)')\nprint(f'transformers={transformers.__version__} (Colab default - keeping it)')\nprint(f'CUDA: {torch.cuda.is_available()}')\n\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\nprint('transformers imports: OK')\n\n# Check if trl has what we need\ntry:\n from trl import SFTTrainer, SFTConfig\n print(f'trl SFTTrainer: OK')\nexcept ImportError:\n print('trl SFTTrainer not available, installing compatible version...')\n !pip install -q \"trl==0.12.2\" --no-deps\n from trl import SFTTrainer, SFTConfig\n print(f'trl SFTTrainer: OK (0.12.2)')\n\nfrom peft import LoraConfig\nprint('peft: OK')\nprint('\\nAll imports working!')"
|
| 49 |
},
|
| 50 |
{
|
| 51 |
"cell_type": "markdown",
|
|
|
|
| 182 |
"execution_count": null,
|
| 183 |
"metadata": {},
|
| 184 |
"outputs": [],
|
| 185 |
+
"source": "from datasets import Dataset\nfrom transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments\nfrom peft import LoraConfig\n\n# Import SFT - handle different TRL versions\ntry:\n from trl import SFTTrainer, SFTConfig\n USE_SFT_CONFIG = True\nexcept ImportError:\n from trl import SFTTrainer\n USE_SFT_CONFIG = False\n\nMODEL_NAME = config['model']['name']\nFEATURE_NAMES = config['feature_names']\n\ndef format_state(features):\n parts = []\n for name, val in zip(FEATURE_NAMES, features):\n if val == int(val):\n parts.append(f'{name}:{int(val)}')\n else:\n parts.append(f'{name}:{val:.2f}')\n return ' | '.join(parts)\n\ndef make_world_model_example(record):\n state_str = format_state(record['state'])\n next_state_str = format_state(record['next_state'])\n text = (\n '<|system|>You are a Linux kernel simulator. '\n 'Predict the next system state.<|end|>\\n'\n f'<|user|>[STATE] {state_str}\\n'\n f'[ACTION] {record[\"action\"]:.4f}\\n'\n f'[PID] {record[\"pid\"]}\\n'\n 'Predict [NEXT_STATE]<|end|>\\n'\n f'<|assistant|>[NEXT_STATE] {next_state_str}<|end|>'\n )\n return {'text': text}\n\n# Use 10K samples for speed\nMAX_SAMPLES = 10000\ntrain_ds = Dataset.from_list([make_world_model_example(r) for r in train[:MAX_SAMPLES]])\nval_ds = Dataset.from_list([make_world_model_example(r) for r in val[:MAX_SAMPLES // 8]])\n\nprint(f'World Model dataset: train={len(train_ds)}, val={len(val_ds)}')\nprint(f'Using SFTConfig: {USE_SFT_CONFIG}')\nprint(f'\\nSample:\\n{train_ds[0][\"text\"][:300]}...')"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
},
|
| 187 |
{
|
| 188 |
"cell_type": "code",
|
|
|
|
| 206 |
"execution_count": null,
|
| 207 |
"metadata": {},
|
| 208 |
"outputs": [],
|
| 209 |
+
"source": "# Train World Model\nimport inspect\n\nlora_config = LoraConfig(\n r=16, lora_alpha=32,\n target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj',\n 'gate_proj', 'up_proj', 'down_proj'],\n lora_dropout=0.05, bias='none', task_type='CAUSAL_LM',\n)\n\n# Build training args compatible with installed TRL version\nif USE_SFT_CONFIG:\n sft_sig = inspect.signature(SFTConfig.__init__)\n seq_key = 'max_seq_length' if 'max_seq_length' in sft_sig.parameters else 'max_length'\n training_args = SFTConfig(\n output_dir='./world_model_checkpoints',\n num_train_epochs=2,\n per_device_train_batch_size=16,\n gradient_accumulation_steps=2,\n learning_rate=2e-4,\n lr_scheduler_type='cosine',\n warmup_ratio=0.1,\n logging_steps=10,\n eval_strategy='steps',\n eval_steps=100,\n save_total_limit=1,\n fp16=True,\n report_to='none',\n **{seq_key: 512},\n )\nelse:\n training_args = TrainingArguments(\n output_dir='./world_model_checkpoints',\n num_train_epochs=2,\n per_device_train_batch_size=16,\n gradient_accumulation_steps=2,\n learning_rate=2e-4,\n lr_scheduler_type='cosine',\n warmup_ratio=0.1,\n logging_steps=10,\n eval_strategy='steps',\n eval_steps=100,\n save_total_limit=1,\n fp16=True,\n report_to='none',\n )\n\ntrainer = SFTTrainer(\n model=model, args=training_args,\n train_dataset=train_ds, eval_dataset=val_ds,\n peft_config=lora_config,\n max_seq_length=512,\n)\n\nprint('Training World Model...')\ntrainer.train()\n\ntrainer.save_model('./world_model_final')\ntokenizer.save_pretrained('./world_model_final')\nprint('World Model saved.')"
|
| 210 |
},
|
| 211 |
{
|
| 212 |
"cell_type": "markdown",
|
|
|
|
| 282 |
"execution_count": null,
|
| 283 |
"metadata": {},
|
| 284 |
"outputs": [],
|
| 285 |
+
"source": "# Reload fresh base model for Strategist\ntokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\nmodel = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map='auto')\nif tokenizer.pad_token is None:\n tokenizer.pad_token = tokenizer.eos_token\n\nlora_config = LoraConfig(\n r=16, lora_alpha=32,\n target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj',\n 'gate_proj', 'up_proj', 'down_proj'],\n lora_dropout=0.05, bias='none', task_type='CAUSAL_LM',\n)\n\nif USE_SFT_CONFIG:\n ws_args = SFTConfig(\n output_dir='./strategist_warmstart',\n num_train_epochs=2,\n per_device_train_batch_size=16,\n gradient_accumulation_steps=2,\n learning_rate=2e-4,\n fp16=True,\n logging_steps=5,\n save_total_limit=1,\n report_to='none',\n **{seq_key: 512},\n )\nelse:\n ws_args = TrainingArguments(\n output_dir='./strategist_warmstart',\n num_train_epochs=2,\n per_device_train_batch_size=16,\n gradient_accumulation_steps=2,\n learning_rate=2e-4,\n fp16=True,\n logging_steps=5,\n save_total_limit=1,\n report_to='none',\n )\n\ntrainer = SFTTrainer(\n model=model, args=ws_args,\n train_dataset=ws_dataset, peft_config=lora_config,\n max_seq_length=512,\n)\n\nprint('Training Strategist (warm-start)...')\ntrainer.train()\n\ntrainer.save_model('./strategist_final')\ntokenizer.save_pretrained('./strategist_final')\nprint('Strategist saved.')"
|
| 286 |
},
|
| 287 |
{
|
| 288 |
"cell_type": "markdown",
|