fawazo committed on
Commit
a0a68fc
·
verified ·
1 Parent(s): b3589b2

Upload train_pentest_v3.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_pentest_v3.py +225 -0
train_pentest_v3.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # dependencies = [
3
+ # "trl>=0.12.0",
4
+ # "peft>=0.7.0",
5
+ # "transformers>=4.36.0",
6
+ # "accelerate>=0.24.0",
7
+ # "trackio",
8
+ # "datasets>=2.14.0",
9
+ # ]
10
+ # ///
11
+
12
+ import traceback
13
+ from datasets import load_dataset, concatenate_datasets
14
+
15
# Shared system prompt prepended to every training conversation: it pins the
# model to a strict JSON-only protocol with four action formats (report /
# request / command / complete). Do not edit casually — the exact text is part
# of the training data contract.
PENTEST_SYSTEM_PROMPT = """You are an expert penetration testing AI. Analyze web traffic and respond with JSON only.

Formats:
1. {"action": "report", "vulnerability": {"type": "SQLi|XSS|SSRF|IDOR|RCE", "severity": "critical|high|medium|low", "description": "...", "evidence": "..."}}
2. {"action": "request", "method": "GET|POST", "path": "/...", "reasoning": "..."}
3. {"action": "command", "cmd": "...", "reasoning": "..."}
4. {"action": "complete", "summary": "..."}

Respond with ONLY valid JSON."""
24
+
25
def validate_messages(messages):
    """Return True when *messages* is a usable chat-template conversation.

    A valid conversation is a list of at least two dicts, each with a
    'role' in {system, user, assistant} and a string 'content' that is
    at least 5 characters long after stripping whitespace.
    """
    # Must be a list holding at least one user/assistant exchange.
    if not isinstance(messages, list) or len(messages) < 2:
        return False
    for turn in messages:
        if not isinstance(turn, dict):
            return False
        role = turn.get("role")
        content = turn.get("content")
        if role not in ("system", "user", "assistant"):
            return False
        # Content must be a non-trivial string (filters empty/garbage turns).
        if not isinstance(content, str) or len(content.strip()) < 5:
            return False
    return True
43
+
44
def load_trendyol():
    """Load the Trendyol cybersecurity instruction dataset as chat messages.

    Each user/assistant pair is wrapped with the shared pentest system
    prompt, validated, and invalid rows are dropped.
    """
    print("Loading Trendyol...")
    raw = load_dataset("Trendyol/Trendyol-Cybersecurity-Instruction-Tuning-Dataset", split="train")
    print(f" Raw: {len(raw)}")

    def to_chat(row):
        # Build a 3-turn conversation; coerce fields to stripped strings.
        conversation = [
            {"role": "system", "content": PENTEST_SYSTEM_PROMPT},
            {"role": "user", "content": str(row["user"]).strip()},
            {"role": "assistant", "content": str(row["assistant"]).strip()},
        ]
        return {"messages": conversation, "valid": validate_messages(conversation)}

    cleaned = (
        raw.map(to_chat, remove_columns=raw.column_names)
        .filter(lambda row: row["valid"])
        .remove_columns(["valid"])
    )
    print(f" Valid: {len(cleaned)}")
    return cleaned
62
+
63
def load_fenrir():
    """Load the Fenrir v2.0 cybersecurity dataset as chat messages.

    Mirrors load_trendyol: wraps each user/assistant pair with the shared
    pentest system prompt and drops rows that fail validation.
    """
    print("Loading Fenrir...")
    raw = load_dataset("AlicanKiraz0/Cybersecurity-Dataset-Fenrir-v2.0", split="train")
    print(f" Raw: {len(raw)}")

    def to_chat(row):
        # Build a 3-turn conversation; coerce fields to stripped strings.
        conversation = [
            {"role": "system", "content": PENTEST_SYSTEM_PROMPT},
            {"role": "user", "content": str(row["user"]).strip()},
            {"role": "assistant", "content": str(row["assistant"]).strip()},
        ]
        return {"messages": conversation, "valid": validate_messages(conversation)}

    cleaned = (
        raw.map(to_chat, remove_columns=raw.column_names)
        .filter(lambda row: row["valid"])
        .remove_columns(["valid"])
    )
    print(f" Valid: {len(cleaned)}")
    return cleaned
81
+
82
def load_pentest():
    """Load and normalize the jason-oneal pentest-agent ChatML dataset.

    Each conversation is cleaned turn-by-turn: roles are lower-cased,
    empty turns are dropped, and any source system message is replaced by
    (or, if absent, PENTEST_SYSTEM_PROMPT is prepended as) the shared
    system prompt. Rows failing validate_messages are filtered out.

    Returns:
        The cleaned dataset, or None when loading/processing fails —
        callers treat None as "skip this source".
    """
    print("Loading pentest-agent...")
    try:
        ds = load_dataset("jason-oneal/pentest-agent-dataset", data_files="chatml_train.jsonl", split="train")
        print(f" Raw: {len(ds)}")

        def fix_messages(ex):
            msgs = ex.get("messages", [])
            if not msgs:
                return {"messages": [], "valid": False}

            # Ensure system prompt
            new_msgs = []
            has_system = False
            for m in msgs:
                if isinstance(m, dict) and "role" in m and "content" in m:
                    role = str(m["role"]).strip().lower()
                    content = str(m["content"]).strip() if m["content"] else ""
                    if role == "system":
                        # Replace whatever system text the source had with ours.
                        has_system = True
                        new_msgs.append({"role": "system", "content": PENTEST_SYSTEM_PROMPT})
                    elif role in ["user", "assistant"] and content:
                        new_msgs.append({"role": role, "content": content})

            if not has_system:
                new_msgs.insert(0, {"role": "system", "content": PENTEST_SYSTEM_PROMPT})

            return {"messages": new_msgs, "valid": validate_messages(new_msgs)}

        ds = ds.map(fix_messages, remove_columns=ds.column_names)
        ds = ds.filter(lambda x: x["valid"])
        ds = ds.remove_columns(["valid"])
        print(f" Valid: {len(ds)}")
        return ds
    except Exception as e:
        # Best-effort source: log the full traceback (the module imports
        # `traceback` at the top but previously never used it, losing all
        # diagnostic context) and let the caller skip this dataset.
        print(f" Error: {e}")
        traceback.print_exc()
        return None
119
+
120
# Load datasets
print("=" * 50)
print("LOADING AND VALIDATING DATASETS")
print("=" * 50)

all_ds = []

# Each source is loaded independently so one failing download does not
# abort the whole run.
try:
    all_ds.append(load_trendyol())
except Exception as e:
    print(f"Trendyol error: {e}")

try:
    all_ds.append(load_fenrir())
except Exception as e:
    print(f"Fenrir error: {e}")

try:
    pds = load_pentest()
    if pds and len(pds) > 0:
        all_ds.append(pds)
except Exception as e:
    print(f"Pentest error: {e}")

# Fail fast with a clear message if every source failed; otherwise
# concatenate_datasets([]) would raise a confusing internal error.
if not all_ds:
    raise RuntimeError("No datasets were loaded successfully; cannot train.")

print(f"\nCombining {len(all_ds)} datasets...")
combined = concatenate_datasets(all_ds)
combined = combined.shuffle(seed=42)
print(f"Total valid examples: {len(combined)}")

# Split: hold out 2% of the shuffled data for evaluation.
split = combined.train_test_split(test_size=0.02, seed=42)
train_ds = split["train"]
eval_ds = split["test"]
print(f"Train: {len(train_ds)}, Eval: {len(eval_ds)}")

# Verify a sample
print("\nSample message structure:")
sample = train_ds[0]["messages"]
for m in sample:
    print(f" {m['role']}: {m['content'][:50]}...")

# Training
print("\n" + "=" * 50)
print("TRAINING")
print("=" * 50)

# Heavy third-party imports are deferred until after data prep so data
# errors surface quickly.
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig

config = SFTConfig(
    output_dir="qwen2.5-coder-3b-pentest",
    push_to_hub=True,
    hub_model_id="fawazo/qwen2.5-coder-3b-pentest",
    hub_strategy="every_save",

    # Effective batch size = 2 * 8 = 16 via gradient accumulation.
    num_train_epochs=2,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    learning_rate=1e-4,
    max_length=2048,

    gradient_checkpointing=True,
    bf16=True,

    logging_steps=25,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=2,

    eval_strategy="steps",
    eval_steps=500,

    warmup_ratio=0.03,
    lr_scheduler_type="cosine",

    report_to="trackio",
    project="pentest-agent",
    run_name="qwen-3b-cybersec-v3",
)

# LoRA on all attention and MLP projections of the Qwen2.5 architecture.
peft_config = LoraConfig(
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)

print("Initializing trainer...")
trainer = SFTTrainer(
    model="Qwen/Qwen2.5-Coder-3B",
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    args=config,
    peft_config=peft_config,
)

print("Starting training...")
trainer.train()
trainer.push_to_hub()

print("\n" + "=" * 50)
print("COMPLETE!")
print("https://huggingface.co/fawazo/qwen2.5-coder-3b-pentest")
print("=" * 50)