fawazo commited on
Commit
b3589b2
·
verified ·
1 Parent(s): bde55d4

Upload train_pentest_v2.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_pentest_v2.py +191 -0
train_pentest_v2.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # dependencies = [
3
+ # "trl>=0.12.0",
4
+ # "peft>=0.7.0",
5
+ # "transformers>=4.36.0",
6
+ # "accelerate>=0.24.0",
7
+ # "trackio",
8
+ # "datasets>=2.14.0",
9
+ # ]
10
+ # ///
11
+
12
+ import json
13
+ import traceback
14
+ from datasets import load_dataset, concatenate_datasets, Dataset
15
+ from peft import LoraConfig
16
+ from trl import SFTTrainer, SFTConfig
17
+
18
+ # Custom system prompt for pentesting JSON output
19
+ PENTEST_SYSTEM_PROMPT = """You are an expert penetration testing AI assistant. Analyze web traffic and respond with JSON only.
20
+
21
+ Response formats:
22
+ 1. Vulnerability found:
23
+ {"action": "report", "vulnerability": {"type": "SQLi|XSS|SSRF|IDOR|RCE|LFI|XXE", "severity": "critical|high|medium|low", "description": "...", "evidence": "..."}}
24
+
25
+ 2. Send follow-up request:
26
+ {"action": "request", "method": "GET|POST", "path": "/...", "headers": {}, "body": "", "reasoning": "..."}
27
+
28
+ 3. Run command:
29
+ {"action": "command", "cmd": "...", "reasoning": "..."}
30
+
31
+ 4. Analysis complete:
32
+ {"action": "complete", "summary": "...", "tested": ["..."]}
33
+
34
+ Respond with ONLY valid JSON."""
35
+
36
+ def load_trendyol():
37
+ print("Loading Trendyol dataset...")
38
+ ds = load_dataset("Trendyol/Trendyol-Cybersecurity-Instruction-Tuning-Dataset", split="train")
39
+ print(f" Loaded {len(ds)} examples")
40
+
41
+ def convert(example):
42
+ return {
43
+ "messages": [
44
+ {"role": "system", "content": PENTEST_SYSTEM_PROMPT},
45
+ {"role": "user", "content": example["user"]},
46
+ {"role": "assistant", "content": example["assistant"]}
47
+ ]
48
+ }
49
+ return ds.map(convert, remove_columns=ds.column_names)
50
+
51
+ def load_fenrir():
52
+ print("Loading Fenrir v2.0 dataset...")
53
+ ds = load_dataset("AlicanKiraz0/Cybersecurity-Dataset-Fenrir-v2.0", split="train")
54
+ print(f" Loaded {len(ds)} examples")
55
+
56
+ def convert(example):
57
+ return {
58
+ "messages": [
59
+ {"role": "system", "content": PENTEST_SYSTEM_PROMPT},
60
+ {"role": "user", "content": example["user"]},
61
+ {"role": "assistant", "content": example["assistant"]}
62
+ ]
63
+ }
64
+ cols = [c for c in ds.column_names]
65
+ return ds.map(convert, remove_columns=cols)
66
+
67
+ def load_pentest():
68
+ print("Loading pentest-agent dataset...")
69
+ try:
70
+ ds = load_dataset("jason-oneal/pentest-agent-dataset", data_files="chatml_train.jsonl", split="train")
71
+ print(f" Loaded {len(ds)} examples")
72
+
73
+ def update_system(example):
74
+ messages = example["messages"]
75
+ if messages and len(messages) > 0:
76
+ if messages[0]["role"] == "system":
77
+ messages[0]["content"] = PENTEST_SYSTEM_PROMPT
78
+ else:
79
+ messages.insert(0, {"role": "system", "content": PENTEST_SYSTEM_PROMPT})
80
+ return {"messages": messages}
81
+ return ds.map(update_system)
82
+ except Exception as e:
83
+ print(f" Warning: Could not load pentest-agent: {e}")
84
+ return None
85
+
86
+ # Main execution
87
+ print("=" * 50)
88
+ print("LOADING DATASETS")
89
+ print("=" * 50)
90
+
91
+ datasets_list = []
92
+
93
+ try:
94
+ ds1 = load_trendyol()
95
+ datasets_list.append(ds1)
96
+ except Exception as e:
97
+ print(f"ERROR loading Trendyol: {e}")
98
+ traceback.print_exc()
99
+
100
+ try:
101
+ ds2 = load_fenrir()
102
+ datasets_list.append(ds2)
103
+ except Exception as e:
104
+ print(f"ERROR loading Fenrir: {e}")
105
+ traceback.print_exc()
106
+
107
+ try:
108
+ ds3 = load_pentest()
109
+ if ds3:
110
+ datasets_list.append(ds3)
111
+ except Exception as e:
112
+ print(f"ERROR loading pentest-agent: {e}")
113
+ traceback.print_exc()
114
+
115
+ if not datasets_list:
116
+ raise RuntimeError("No datasets loaded!")
117
+
118
+ print(f"\nCombining {len(datasets_list)} datasets...")
119
+ combined = concatenate_datasets(datasets_list)
120
+ print(f"Total: {len(combined)} examples")
121
+
122
+ combined = combined.shuffle(seed=42)
123
+ split_ds = combined.train_test_split(test_size=0.02, seed=42)
124
+ train_ds = split_ds["train"]
125
+ eval_ds = split_ds["test"]
126
+ print(f"Train: {len(train_ds)}, Eval: {len(eval_ds)}")
127
+
128
+ # Training config
129
+ print("\n" + "=" * 50)
130
+ print("STARTING TRAINING")
131
+ print("=" * 50)
132
+
133
+ config = SFTConfig(
134
+ output_dir="qwen2.5-coder-3b-pentest",
135
+ push_to_hub=True,
136
+ hub_model_id="fawazo/qwen2.5-coder-3b-pentest",
137
+ hub_strategy="every_save",
138
+
139
+ num_train_epochs=2,
140
+ per_device_train_batch_size=2,
141
+ gradient_accumulation_steps=8,
142
+ learning_rate=1e-4,
143
+ max_length=2048,
144
+
145
+ gradient_checkpointing=True,
146
+ bf16=True,
147
+
148
+ logging_steps=25,
149
+ save_strategy="steps",
150
+ save_steps=500,
151
+ save_total_limit=2,
152
+
153
+ eval_strategy="steps",
154
+ eval_steps=500,
155
+
156
+ warmup_ratio=0.03,
157
+ lr_scheduler_type="cosine",
158
+
159
+ report_to="trackio",
160
+ project="pentest-agent",
161
+ run_name="qwen-3b-cybersec-150k",
162
+ )
163
+
164
+ peft_config = LoraConfig(
165
+ r=32,
166
+ lora_alpha=64,
167
+ lora_dropout=0.05,
168
+ bias="none",
169
+ task_type="CAUSAL_LM",
170
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
171
+ )
172
+
173
+ print("Loading model Qwen/Qwen2.5-Coder-3B...")
174
+ trainer = SFTTrainer(
175
+ model="Qwen/Qwen2.5-Coder-3B",
176
+ train_dataset=train_ds,
177
+ eval_dataset=eval_ds,
178
+ args=config,
179
+ peft_config=peft_config,
180
+ )
181
+
182
+ print("Training...")
183
+ trainer.train()
184
+
185
+ print("Pushing to Hub...")
186
+ trainer.push_to_hub()
187
+
188
+ print("\n" + "=" * 50)
189
+ print("TRAINING COMPLETE!")
190
+ print("Model: https://huggingface.co/fawazo/qwen2.5-coder-3b-pentest")
191
+ print("=" * 50)