oxdev commited on
Commit
55ef8ec
Β·
verified Β·
1 Parent(s): 3c818d7

Add Google Colab training notebook for V2 GRPO training (free T4 path)

Browse files
Files changed (1) hide show
  1. train_grpo_v2_colab.ipynb +482 -0
train_grpo_v2_colab.ipynb ADDED
@@ -0,0 +1,482 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "gpuType": "T4"
8
+ },
9
+ "kernelspec": {
10
+ "name": "python3",
11
+ "display_name": "Python 3"
12
+ },
13
+ "language_info": {
14
+ "name": "python"
15
+ },
16
+ "accelerator": "GPU"
17
+ },
18
+ "cells": [
19
+ {
20
+ "cell_type": "markdown",
21
+ "metadata": {},
22
+ "source": [
23
+ "# πŸ” Smart Contract Security Auditor β€” GRPO V2 Training\n",
24
+ "\n",
25
+ "Train a specialized smart contract security auditor using **Group Relative Policy Optimization (GRPO)**\n",
26
+ "on **50,902 real audit findings** from top security firms.\n",
27
+ "\n",
28
+ "**Model:** Qwen2.5-Coder-0.5B-Instruct β†’ oxdev/security-auditor-grpo\n",
29
+ "\n",
30
+ "**Dataset:** [oxdev/smart-contract-security-audit-v2](https://huggingface.co/datasets/oxdev/smart-contract-security-audit-v2)\n",
31
+ "\n",
32
+ "**Hardware:** Free Colab T4 (16GB VRAM)\n",
33
+ "\n",
34
+ "---\n",
35
+ "\n",
36
+ "## Setup\n",
37
+ "1. Go to **Runtime β†’ Change runtime type β†’ T4 GPU**\n",
38
+ "2. Run all cells in order\n",
39
+ "3. When prompted, enter your HuggingFace token (needs write access)\n",
40
+ "4. Training takes ~4-6 hours on a T4 GPU with 2K samples"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": null,
46
+ "metadata": {},
47
+ "outputs": [],
48
+ "source": [
49
+ "# Cell 1: Install dependencies\n",
50
+ "!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121\n",
51
+ "!pip install -q transformers>=4.51.0 trl>=1.2.0 datasets accelerate huggingface_hub\n",
52
+ "print('\\nβœ… Dependencies installed!')\n",
53
+ "\n",
54
+ "import torch\n",
55
+ "print(f'PyTorch: {torch.__version__}')\n",
56
+ "print(f'CUDA available: {torch.cuda.is_available()}')\n",
57
+ "if torch.cuda.is_available():\n",
58
+ " print(f'GPU: {torch.cuda.get_device_name(0)}')\n",
59
+ " print(f'VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "code",
64
+ "execution_count": null,
65
+ "metadata": {},
66
+ "outputs": [],
67
+ "source": [
68
+ "# Cell 2: Login to HuggingFace (needed to push model)\n",
69
+ "from huggingface_hub import login\n",
70
+ "login() # Will prompt for your token"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "code",
75
+ "execution_count": null,
76
+ "metadata": {},
77
+ "outputs": [],
78
+ "source": [
79
+ "# Cell 3: Configuration\n",
80
+ "# ╔══════════════════════════════════════════════════════════════╗\n",
81
+ "# β•‘ MODIFY THESE SETTINGS AS NEEDED β•‘\n",
82
+ "# β•šβ•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•\n",
83
+ "\n",
84
+ "MODEL_NAME = \"Qwen/Qwen2.5-Coder-0.5B-Instruct\" # Base model\n",
85
+ "DATASET_ID = \"oxdev/smart-contract-security-audit-v2\" # 50K real findings\n",
86
+ "HUB_MODEL_ID = \"oxdev/security-auditor-grpo\" # Where to push\n",
87
+ "OUTPUT_DIR = \"/content/grpo_v2_output\" # Local output\n",
88
+ "\n",
89
+ "# Training hyperparameters (tuned for T4 16GB)\n",
90
+ "SUBSET_SIZE = 2000 # Samples to train on (2K fits in ~4hrs on T4)\n",
91
+ "BATCH_SIZE = 2 # Per-device batch size\n",
92
+ "GRAD_ACCUM = 4 # Gradient accumulation β†’ effective batch = 8\n",
93
+ "NUM_GENERATIONS = 2 # GRPO generations per prompt\n",
94
+ "MAX_COMPLETION_LENGTH = 512 # Max tokens per completion\n",
95
+ "LEARNING_RATE = 1e-6\n",
96
+ "BETA = 0.04 # KL penalty\n",
97
+ "NUM_EPOCHS = 1\n",
98
+ "SAVE_STEPS = 100\n",
99
+ "\n",
100
+ "print(f'Config ready: {SUBSET_SIZE} samples, batch={BATCH_SIZE}Γ—{GRAD_ACCUM}, lr={LEARNING_RATE}')"
101
+ ]
102
+ },
103
+ {
104
+ "cell_type": "code",
105
+ "execution_count": null,
106
+ "metadata": {},
107
+ "outputs": [],
108
+ "source": [
109
+ "# Cell 4: Load and inspect dataset\n",
110
+ "from datasets import load_dataset\n",
111
+ "from collections import Counter\n",
112
+ "\n",
113
+ "print('Loading dataset...')\n",
114
+ "dataset = load_dataset(DATASET_ID, split='train')\n",
115
+ "print(f'Total: {len(dataset)} samples')\n",
116
+ "print(f'Columns: {dataset.column_names}')\n",
117
+ "print()\n",
118
+ "\n",
119
+ "# Show distributions\n",
120
+ "sev_dist = Counter(dataset['severity'])\n",
121
+ "cat_dist = Counter(dataset['category'])\n",
122
+ "src_dist = Counter(dataset['source'])\n",
123
+ "\n",
124
+ "print('Severity distribution:')\n",
125
+ "for sev, count in sorted(sev_dist.items(), key=lambda x: -x[1]):\n",
126
+ " print(f' {sev:15s}: {count:6d} ({count/len(dataset)*100:.1f}%)')\n",
127
+ "\n",
128
+ "print(f'\\nCategory distribution (top 10):')\n",
129
+ "for cat, count in sorted(cat_dist.items(), key=lambda x: -x[1])[:10]:\n",
130
+ " print(f' {cat:20s}: {count:6d}')\n",
131
+ "\n",
132
+ "print(f'\\nSource distribution:')\n",
133
+ "for src, count in sorted(src_dist.items(), key=lambda x: -x[1]):\n",
134
+ " print(f' {src:20s}: {count:6d}')\n",
135
+ "\n",
136
+ "# Show a sample\n",
137
+ "print(f'\\n--- Sample prompt (first 300 chars) ---')\n",
138
+ "p = dataset[0]['prompt']\n",
139
+ "user_msg = [m for m in p if m['role'] == 'user'][0]['content']\n",
140
+ "print(user_msg[:300])"
141
+ ]
142
+ },
143
+ {
144
+ "cell_type": "code",
145
+ "execution_count": null,
146
+ "metadata": {},
147
+ "outputs": [],
148
+ "source": [
149
+ "# Cell 5: Curate high-quality training subset\n",
150
+ "print(f'Selecting top {SUBSET_SIZE} highest-value samples...')\n",
151
+ "\n",
152
+ "indices = []\n",
153
+ "idx_set = set()\n",
154
+ "\n",
155
+ "# Priority 1: HIGH+CRITICAL severity with code (most valuable)\n",
156
+ "for i, row in enumerate(dataset):\n",
157
+ " if row['severity'] in ('high', 'critical') and row['has_code']:\n",
158
+ " indices.append(i)\n",
159
+ " idx_set.add(i)\n",
160
+ "print(f' HIGH+CRITICAL with code: {len(indices)}')\n",
161
+ "\n",
162
+ "# Priority 2: Any with PoC reference\n",
163
+ "for i, row in enumerate(dataset):\n",
164
+ " if row['has_poc'] and i not in idx_set:\n",
165
+ " indices.append(i)\n",
166
+ " idx_set.add(i)\n",
167
+ "print(f' + Has PoC: {len(indices)}')\n",
168
+ "\n",
169
+ "# Priority 3: MEDIUM with code (fill to cap)\n",
170
+ "for i, row in enumerate(dataset):\n",
171
+ " if row['severity'] == 'medium' and row['has_code'] and i not in idx_set:\n",
172
+ " indices.append(i)\n",
173
+ " idx_set.add(i)\n",
174
+ " if len(indices) >= SUBSET_SIZE:\n",
175
+ " break\n",
176
+ "\n",
177
+ "# If still short, add remaining HIGH+CRITICAL without code\n",
178
+ "if len(indices) < SUBSET_SIZE:\n",
179
+ " for i, row in enumerate(dataset):\n",
180
+ " if row['severity'] in ('high', 'critical') and i not in idx_set:\n",
181
+ " indices.append(i)\n",
182
+ " idx_set.add(i)\n",
183
+ " if len(indices) >= SUBSET_SIZE:\n",
184
+ " break\n",
185
+ "\n",
186
+ "train_dataset = dataset.select(indices[:SUBSET_SIZE])\n",
187
+ "print(f'\\nβœ… Final subset: {len(train_dataset)} samples')\n",
188
+ "\n",
189
+ "# Show final distribution\n",
190
+ "final_sev = Counter(train_dataset['severity'])\n",
191
+ "for sev, count in sorted(final_sev.items(), key=lambda x: -x[1]):\n",
192
+ " print(f' {sev:15s}: {count:6d}')"
193
+ ]
194
+ },
195
+ {
196
+ "cell_type": "code",
197
+ "execution_count": null,
198
+ "metadata": {},
199
+ "outputs": [],
200
+ "source": [
201
+ "# Cell 6: Define reward functions\n",
202
+ "import re\n",
203
+ "\n",
204
+ "def format_reward(prompts, completions, completion_ids=None, **kwargs):\n",
205
+ " \"\"\"Reward for producing structured FINDING blocks and proper formatting.\"\"\"\n",
206
+ " rewards = []\n",
207
+ " for completion in completions:\n",
208
+ " text = completion[0]['content'] if isinstance(completion, list) else str(completion)\n",
209
+ " reward = 0.0\n",
210
+ " if re.search(r'FINDING\\s*\\|', text):\n",
211
+ " reward += 0.3\n",
212
+ " fields = ['contract:', 'function:', 'bug_class:', 'confidence:']\n",
213
+ " reward += 0.05 * sum(1 for f in fields if f in text)\n",
214
+ " if re.search(r'```solidity', text):\n",
215
+ " reward += 0.15\n",
216
+ " section_keywords = ['description', 'impact', 'proof', 'fix', 'recommendation', 'mitigation']\n",
217
+ " sect_count = sum(1 for kw in section_keywords if re.search(rf'(?i)(###?\\s*{kw}|{kw}:)', text))\n",
218
+ " reward += 0.05 * min(sect_count, 3)\n",
219
+ " if len(text) < 50: reward -= 0.3\n",
220
+ " elif len(text) > 4000: reward -= 0.1\n",
221
+ " rewards.append(max(-1.0, min(1.0, reward)))\n",
222
+ " return rewards\n",
223
+ "\n",
224
+ "\n",
225
+ "def _sev_rank(sev):\n",
226
+ " return {'critical': 5, 'high': 4, 'medium': 3, 'low': 2, 'informational': 1, 'gas': 0}.get(sev, -1)\n",
227
+ "\n",
228
+ "def severity_reward(prompts, completions, completion_ids=None, severity=None, **kwargs):\n",
229
+ " \"\"\"Reward for correctly identifying the severity level.\"\"\"\n",
230
+ " rewards = []\n",
231
+ " if severity is None:\n",
232
+ " return [0.0] * len(completions)\n",
233
+ " sev_list = severity if isinstance(severity, list) else [severity] * len(completions)\n",
234
+ " for i, completion in enumerate(completions):\n",
235
+ " text = completion[0]['content'] if isinstance(completion, list) else str(completion)\n",
236
+ " gt_sev = sev_list[i] if i < len(sev_list) else 'unknown'\n",
237
+ " if gt_sev == 'unknown':\n",
238
+ " rewards.append(0.0); continue\n",
239
+ " sev_match = re.search(r'(?i)(critical|high|medium|low|informational|gas)', text.lower())\n",
240
+ " if not sev_match:\n",
241
+ " rewards.append(-0.3)\n",
242
+ " else:\n",
243
+ " pred = sev_match.group(1).lower()\n",
244
+ " diff = abs(_sev_rank(pred) - _sev_rank(gt_sev))\n",
245
+ " rewards.append(1.0 if diff == 0 else 0.3 if diff == 1 else -0.5)\n",
246
+ " return rewards\n",
247
+ "\n",
248
+ "\n",
249
+ "CATEGORY_KEYWORDS = {\n",
250
+ " 'reentrancy': ['reentrancy', 'reentrant', 're-enter', 'callback'],\n",
251
+ " 'access-control': ['access control', 'unauthorized', 'permission', 'onlyowner', 'role', 'privilege'],\n",
252
+ " 'oracle': ['oracle', 'price feed', 'chainlink', 'twap', 'price manipulation'],\n",
253
+ " 'flash-loan': ['flash loan', 'flashloan'],\n",
254
+ " 'overflow': ['overflow', 'underflow', 'arithmetic'],\n",
255
+ " 'front-running': ['front-run', 'frontrun', 'sandwich', 'mev'],\n",
256
+ " 'dos': ['denial of service', 'dos', 'gas limit', 'unbounded', 'out of gas'],\n",
257
+ " 'token': ['erc20', 'erc721', 'token', 'fee-on-transfer', 'rebasing'],\n",
258
+ " 'storage': ['storage collision', 'delegatecall', 'proxy', 'slot'],\n",
259
+ " 'cross-chain': ['bridge', 'cross-chain', 'relay', 'message passing'],\n",
260
+ " 'liquidation': ['liquidation', 'collateral', 'health factor'],\n",
261
+ " 'signature': ['signature', 'ecrecover', 'replay', 'nonce', 'eip712'],\n",
262
+ " 'initialization': ['initialize', 'constructor', 'uninitialized'],\n",
263
+ " 'rounding': ['rounding', 'precision', 'truncation', 'decimal'],\n",
264
+ " 'logic': ['logic error', 'incorrect calculation', 'business logic'],\n",
265
+ "}\n",
266
+ "\n",
267
+ "def category_reward(prompts, completions, completion_ids=None, category=None, **kwargs):\n",
268
+ " \"\"\"Reward for identifying the correct vulnerability category.\"\"\"\n",
269
+ " rewards = []\n",
270
+ " if category is None:\n",
271
+ " return [0.0] * len(completions)\n",
272
+ " cat_list = category if isinstance(category, list) else [category] * len(completions)\n",
273
+ " for i, completion in enumerate(completions):\n",
274
+ " text = completion[0]['content'] if isinstance(completion, list) else str(completion)\n",
275
+ " gt_cat = cat_list[i] if i < len(cat_list) else 'other'\n",
276
+ " if gt_cat in ('other', 'unknown'):\n",
277
+ " rewards.append(0.0); continue\n",
278
+ " gt_keywords = CATEGORY_KEYWORDS.get(gt_cat, [])\n",
279
+ " if not gt_keywords:\n",
280
+ " rewards.append(0.0); continue\n",
281
+ " hits = sum(1 for kw in gt_keywords if kw in text.lower())\n",
282
+ " if hits >= 2: rewards.append(1.0)\n",
283
+ " elif hits == 1: rewards.append(0.5)\n",
284
+ " else:\n",
285
+ " any_hit = any(kw in text.lower() for kws in CATEGORY_KEYWORDS.values() for kw in kws)\n",
286
+ " rewards.append(-0.2 if any_hit else -0.5)\n",
287
+ " return rewards\n",
288
+ "\n",
289
+ "\n",
290
+ "def quality_reward(prompts, completions, completion_ids=None, **kwargs):\n",
291
+ " \"\"\"Reward for overall response quality: technical depth, actionability.\"\"\"\n",
292
+ " rewards = []\n",
293
+ " for completion in completions:\n",
294
+ " text = completion[0]['content'] if isinstance(completion, list) else str(completion)\n",
295
+ " reward = 0.0\n",
296
+ " technical_terms = [\n",
297
+ " 'msg.sender', 'tx.origin', 'delegatecall', 'selfdestruct',\n",
298
+ " 'transfer', 'call.value', 'abi.encode', 'keccak256',\n",
299
+ " 'require(', 'assert(', 'revert', 'mapping', 'storage',\n",
300
+ " 'memory', 'calldata', 'modifier', 'interface', 'pragma',\n",
301
+ " 'assembly', 'unchecked', 'payable', 'receive()', 'fallback()',\n",
302
+ " ]\n",
303
+ " reward += min(0.3, 0.03 * sum(1 for t in technical_terms if t in text))\n",
304
+ " reasoning = ['because', 'therefore', 'this means', 'as a result',\n",
305
+ " 'the attacker can', 'this allows', 'leading to',\n",
306
+ " 'step 1', 'step 2', 'first,', 'then,', 'finally,']\n",
307
+ " reward += min(0.3, 0.06 * sum(1 for r in reasoning if r.lower() in text.lower()))\n",
308
+ " fix_ind = ['fix:', 'recommendation:', 'mitigation:', 'should', 'consider', 'instead']\n",
309
+ " reward += min(0.2, 0.05 * sum(1 for f in fix_ind if f.lower() in text.lower()))\n",
310
+ " if re.search(r'line\\s+\\d+|L\\d+|#L\\d+', text): reward += 0.1\n",
311
+ " if re.search(r'function\\s+\\w+\\s*\\(', text): reward += 0.1\n",
312
+ " generic = ['i cannot', \"i don't\", 'no vulnerabilities found', 'the code looks safe']\n",
313
+ " if any(p in text.lower() for p in generic): reward -= 0.5\n",
314
+ " rewards.append(max(-1.0, min(1.0, reward)))\n",
315
+ " return rewards\n",
316
+ "\n",
317
+ "print('βœ… 4 reward functions defined: format, severity, category, quality')"
318
+ ]
319
+ },
320
+ {
321
+ "cell_type": "code",
322
+ "execution_count": null,
323
+ "metadata": {},
324
+ "outputs": [],
325
+ "source": [
326
+ "# Cell 7: Initialize GRPO Trainer\n",
327
+ "from trl import GRPOTrainer, GRPOConfig\n",
328
+ "\n",
329
+ "config = GRPOConfig(\n",
330
+ " output_dir=OUTPUT_DIR,\n",
331
+ " num_train_epochs=NUM_EPOCHS,\n",
332
+ " per_device_train_batch_size=BATCH_SIZE,\n",
333
+ " gradient_accumulation_steps=GRAD_ACCUM,\n",
334
+ " num_generations=NUM_GENERATIONS,\n",
335
+ " max_completion_length=MAX_COMPLETION_LENGTH,\n",
336
+ " learning_rate=LEARNING_RATE,\n",
337
+ " beta=BETA,\n",
338
+ " scale_rewards=True,\n",
339
+ " reward_weights=[0.25, 0.25, 0.25, 0.25],\n",
340
+ " gradient_checkpointing=True,\n",
341
+ " bf16=True,\n",
342
+ " logging_steps=10,\n",
343
+ " logging_first_step=True,\n",
344
+ " logging_strategy='steps',\n",
345
+ " disable_tqdm=False, # Show progress bar in Colab\n",
346
+ " save_strategy='steps',\n",
347
+ " save_steps=SAVE_STEPS,\n",
348
+ " save_total_limit=2,\n",
349
+ " push_to_hub=False, # We push manually at the end\n",
350
+ " log_completions=False,\n",
351
+ " report_to='none',\n",
352
+ " seed=42,\n",
353
+ ")\n",
354
+ "\n",
355
+ "print('Initializing GRPOTrainer...')\n",
356
+ "trainer = GRPOTrainer(\n",
357
+ " model=MODEL_NAME,\n",
358
+ " args=config,\n",
359
+ " reward_funcs=[format_reward, severity_reward, category_reward, quality_reward],\n",
360
+ " train_dataset=train_dataset,\n",
361
+ ")\n",
362
+ "print(f'βœ… GRPOTrainer ready! {len(train_dataset)} samples, ~{len(train_dataset) // (BATCH_SIZE * GRAD_ACCUM)} steps')"
363
+ ]
364
+ },
365
+ {
366
+ "cell_type": "code",
367
+ "execution_count": null,
368
+ "metadata": {},
369
+ "outputs": [],
370
+ "source": [
371
+ "# Cell 8: TRAIN! πŸš€\n",
372
+ "# This takes 4-6 hours on T4. Colab will keep running if you stay connected.\n",
373
+ "# Tip: Keep the tab open and active to prevent disconnection.\n",
374
+ "\n",
375
+ "import time\n",
376
+ "start = time.time()\n",
377
+ "print('πŸš€ Starting GRPO V2 training...')\n",
378
+ "print(f'Estimated time: ~{len(train_dataset) / (BATCH_SIZE * GRAD_ACCUM) * 45 / 3600:.1f} hours')\n",
379
+ "print()\n",
380
+ "\n",
381
+ "trainer.train()\n",
382
+ "\n",
383
+ "elapsed = time.time() - start\n",
384
+ "print(f'\\nβœ… Training complete in {elapsed/3600:.1f} hours!')"
385
+ ]
386
+ },
387
+ {
388
+ "cell_type": "code",
389
+ "execution_count": null,
390
+ "metadata": {},
391
+ "outputs": [],
392
+ "source": [
393
+ "# Cell 9: Save and push to Hub\n",
394
+ "import os\n",
395
+ "from huggingface_hub import HfApi\n",
396
+ "\n",
397
+ "print(f'Saving model to {OUTPUT_DIR}...')\n",
398
+ "trainer.save_model(OUTPUT_DIR)\n",
399
+ "\n",
400
+ "print(f'Pushing to Hub: {HUB_MODEL_ID}...')\n",
401
+ "api = HfApi()\n",
402
+ "api.create_repo(repo_id=HUB_MODEL_ID, exist_ok=True)\n",
403
+ "\n",
404
+ "# Upload model files (skip checkpoints and optimizer states to save time)\n",
405
+ "api.upload_folder(\n",
406
+ " folder_path=OUTPUT_DIR,\n",
407
+ " repo_id=HUB_MODEL_ID,\n",
408
+ " commit_message='GRPO V2 β€” trained on real audit findings, 4 reward functions',\n",
409
+ " ignore_patterns=['checkpoint-*', '*.pt'], # Skip checkpoints\n",
410
+ ")\n",
411
+ "\n",
412
+ "print(f'\\nπŸŽ‰ Model pushed to https://huggingface.co/{HUB_MODEL_ID}')"
413
+ ]
414
+ },
415
+ {
416
+ "cell_type": "code",
417
+ "execution_count": null,
418
+ "metadata": {},
419
+ "outputs": [],
420
+ "source": [
421
+ "# Cell 10: Quick inference test\n",
422
+ "from transformers import pipeline as hf_pipeline\n",
423
+ "\n",
424
+ "print('Loading trained model for inference...')\n",
425
+ "pipe = hf_pipeline('text-generation', model=OUTPUT_DIR, device=0, torch_dtype=torch.bfloat16)\n",
426
+ "\n",
427
+ "test_contract = \"\"\"\n",
428
+ "pragma solidity ^0.8.0;\n",
429
+ "\n",
430
+ "contract SimpleBank {\n",
431
+ " mapping(address => uint256) public balances;\n",
432
+ "\n",
433
+ " function deposit() public payable {\n",
434
+ " balances[msg.sender] += msg.value;\n",
435
+ " }\n",
436
+ "\n",
437
+ " function withdraw(uint256 amount) public {\n",
438
+ " require(balances[msg.sender] >= amount);\n",
439
+ " (bool success, ) = msg.sender.call{value: amount}(\\\"\\\");\n",
440
+ " require(success);\n",
441
+ " balances[msg.sender] -= amount;\n",
442
+ " }\n",
443
+ "}\n",
444
+ "\"\"\"\n",
445
+ "\n",
446
+ "messages = [\n",
447
+ " {'role': 'system', 'content': 'You are an expert smart contract security auditor. Analyze the provided Solidity code for vulnerabilities.'},\n",
448
+ " {'role': 'user', 'content': f'Audit this contract:\\n```solidity\\n{test_contract}\\n```'},\n",
449
+ "]\n",
450
+ "\n",
451
+ "result = pipe(messages, max_new_tokens=512, do_sample=False, return_full_text=False)\n",
452
+ "output = result[0]['generated_text']\n",
453
+ "if isinstance(output, list):\n",
454
+ " output = output[-1]['content']\n",
455
+ "\n",
456
+ "print('\\n=== Audit Result ===')\n",
457
+ "print(output)"
458
+ ]
459
+ },
460
+ {
461
+ "cell_type": "markdown",
462
+ "metadata": {},
463
+ "source": [
464
+ "---\n",
465
+ "\n",
466
+ "## πŸŽ‰ Done!\n",
467
+ "\n",
468
+ "Your V2 model is now pushed to the Hub. Test it interactively at:\n",
469
+ "\n",
470
+ "**Demo Space:** [oxdev/security-auditor-demo](https://huggingface.co/spaces/oxdev/security-auditor-demo)\n",
471
+ "\n",
472
+ "**Model:** [oxdev/security-auditor-grpo](https://huggingface.co/oxdev/security-auditor-grpo)\n",
473
+ "\n",
474
+ "### Next Steps\n",
475
+ "- Train on more data: increase `SUBSET_SIZE` to 5000 or 10000\n",
476
+ "- Use a bigger model: try `Qwen/Qwen2.5-Coder-1.5B-Instruct` (needs A100)\n",
477
+ "- Fine-tune rewards: adjust weights in `reward_weights`\n",
478
+ "- Try different hyperparameters: learning rate, beta, num_generations"
479
+ ]
480
+ }
481
+ ]
482
+ }