fawazo commited on
Commit
a71f7c1
·
verified ·
1 Parent(s): bea46e3

Upload train_pentest_sft.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_pentest_sft.py +79 -0
train_pentest_sft.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # dependencies = [
3
+ # "trl>=0.12.0",
4
+ # "peft>=0.7.0",
5
+ # "transformers>=4.36.0",
6
+ # "accelerate>=0.24.0",
7
+ # "trackio",
8
+ # ]
9
+ # ///
10
+
11
+ from datasets import load_dataset
12
+ from peft import LoraConfig
13
+ from trl import SFTTrainer, SFTConfig
14
+
15
+ # Load dataset (ChatML format)
16
+ print("Loading pentest dataset...")
17
+ dataset = load_dataset(
18
+ "jason-oneal/pentest-agent-dataset",
19
+ data_files="chatml_train.jsonl",
20
+ split="train"
21
+ )
22
+ print(f"Dataset loaded: {len(dataset)} examples")
23
+
24
+ # Train/eval split
25
+ dataset_split = dataset.train_test_split(test_size=0.1, seed=42)
26
+ train_dataset = dataset_split["train"]
27
+ eval_dataset = dataset_split["test"]
28
+
29
+ # Training configuration
30
+ config = SFTConfig(
31
+ output_dir="qwen2.5-coder-1.5b-pentest",
32
+ push_to_hub=True,
33
+ hub_model_id="fawazo/qwen2.5-coder-1.5b-pentest",
34
+ hub_strategy="every_save",
35
+
36
+ num_train_epochs=3,
37
+ per_device_train_batch_size=4,
38
+ gradient_accumulation_steps=4,
39
+ learning_rate=2e-5,
40
+
41
+ logging_steps=10,
42
+ save_strategy="steps",
43
+ save_steps=200,
44
+ save_total_limit=2,
45
+
46
+ eval_strategy="steps",
47
+ eval_steps=200,
48
+
49
+ warmup_ratio=0.1,
50
+ lr_scheduler_type="cosine",
51
+
52
+ report_to="trackio",
53
+ project="pentest-coder",
54
+ run_name="qwen2.5-coder-1.5b-sft",
55
+ )
56
+
57
+ # LoRA config for efficient training
58
+ peft_config = LoraConfig(
59
+ r=16,
60
+ lora_alpha=32,
61
+ lora_dropout=0.05,
62
+ bias="none",
63
+ task_type="CAUSAL_LM",
64
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
65
+ )
66
+
67
+ # Train
68
+ print("Starting training...")
69
+ trainer = SFTTrainer(
70
+ model="Qwen/Qwen2.5-Coder-1.5B",
71
+ train_dataset=train_dataset,
72
+ eval_dataset=eval_dataset,
73
+ args=config,
74
+ peft_config=peft_config,
75
+ )
76
+
77
+ trainer.train()
78
+ trainer.push_to_hub()
79
+ print("Model saved to: https://huggingface.co/fawazo/qwen2.5-coder-1.5b-pentest")