geekyrakshit committed on
Commit
98ced8b
1 Parent(s): 1af55c6

add: train_binary_classifier

Browse files
guardrails_genie/train_classifier.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import evaluate
2
+ import numpy as np
3
+ import wandb
4
+ from datasets import load_dataset
5
+ from transformers import (
6
+ AutoModelForSequenceClassification,
7
+ AutoTokenizer,
8
+ DataCollatorWithPadding,
9
+ Trainer,
10
+ TrainingArguments,
11
+ )
12
+
13
+
14
def train_binary_classifier(
    project_name: str,
    entity_name: str,
    dataset_repo: str = "geekyrakshit/prompt-injection-dataset",
    model_name: str = "distilbert/distilbert-base-uncased",
    learning_rate: float = 2e-5,
    batch_size: int = 16,
    num_epochs: int = 2,
    weight_decay: float = 0.01,
    output_dir: str = "binary-classifier",
    push_to_hub: bool = True,
):
    """Fine-tune a two-label (SAFE / INJECTION) sequence classifier.

    The dataset is loaded with ``datasets.load_dataset``, its ``"prompt"``
    column is tokenized with the tokenizer belonging to ``model_name``, and
    the model is trained with a ``transformers.Trainer`` that evaluates and
    checkpoints once per epoch and reports to Weights & Biases.

    Args:
        project_name: W&B project the run is logged to.
        entity_name: W&B entity (user or team) that owns the project.
        dataset_repo: Dataset identifier passed to ``load_dataset``. It must
            expose ``"train"`` and ``"test"`` splits and a ``"prompt"``
            column. NOTE(review): the label column is not checked here; it
            is presumably named so that ``Trainer`` picks it up -- confirm
            against the dataset.
        model_name: Checkpoint used for both the tokenizer and the model.
        learning_rate: Learning rate handed to ``TrainingArguments``.
        batch_size: Per-device batch size for both training and evaluation.
        num_epochs: Number of training epochs.
        weight_decay: Weight decay handed to ``TrainingArguments``.
        output_dir: Directory where checkpoints are written.
        push_to_hub: Forwarded to ``TrainingArguments(push_to_hub=...)``.

    Returns:
        The value returned by ``Trainer.train()``.

    Raises:
        Any exception raised while loading data/model or training is
        re-raised after the W&B run has been closed with a failure code.
    """
    wandb.init(project=project_name, entity=entity_name)
    try:
        dataset = load_dataset(dataset_repo)
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        def preprocess_function(examples):
            # Padding is deferred to the data collator (per-batch padding).
            return tokenizer(examples["prompt"], truncation=True)

        tokenized_datasets = dataset.map(preprocess_function, batched=True)
        data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
        accuracy = evaluate.load("accuracy")

        def compute_metrics(eval_pred):
            predictions, labels = eval_pred
            # Collapse per-class scores to the index of the highest one.
            predictions = np.argmax(predictions, axis=1)
            return accuracy.compute(predictions=predictions, references=labels)

        id2label = {0: "SAFE", 1: "INJECTION"}
        label2id = {"SAFE": 0, "INJECTION": 1}

        model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            num_labels=2,
            id2label=id2label,
            label2id=label2id,
        )

        trainer = Trainer(
            model=model,
            args=TrainingArguments(
                output_dir=output_dir,
                learning_rate=learning_rate,
                per_device_train_batch_size=batch_size,
                per_device_eval_batch_size=batch_size,
                num_train_epochs=num_epochs,
                weight_decay=weight_decay,
                eval_strategy="epoch",
                save_strategy="epoch",
                load_best_model_at_end=True,
                push_to_hub=push_to_hub,
                report_to="wandb",
            ),
            train_dataset=tokenized_datasets["train"],
            eval_dataset=tokenized_datasets["test"],
            processing_class=tokenizer,
            data_collator=data_collator,
            compute_metrics=compute_metrics,
        )
        training_output = trainer.train()
    except BaseException:
        # Close the run as failed so it does not stay open (and keep
        # receiving logs) after an error or a keyboard interrupt.
        wandb.finish(exit_code=1)
        raise
    wandb.finish()
    return training_output