Spaces:
Running
Running
geekyrakshit
commited on
Commit
•
8382f82
1
Parent(s):
3a7ead3
update: training code
Browse files
guardrails_genie/train_classifier.py
CHANGED
@@ -42,10 +42,11 @@ def train_binary_classifier(
|
|
42 |
dataset_repo: str = "geekyrakshit/prompt-injection-dataset",
|
43 |
model_name: str = "distilbert/distilbert-base-uncased",
|
44 |
prompt_column_name: str = "prompt",
|
45 |
-
learning_rate: float =
|
46 |
batch_size: int = 16,
|
47 |
num_epochs: int = 2,
|
48 |
weight_decay: float = 0.01,
|
|
|
49 |
streamlit_mode: bool = False,
|
50 |
):
|
51 |
wandb.init(project=project_name, entity=entity_name, name=run_name)
|
@@ -88,7 +89,8 @@ def train_binary_classifier(
|
|
88 |
num_train_epochs=num_epochs,
|
89 |
weight_decay=weight_decay,
|
90 |
eval_strategy="epoch",
|
91 |
-
save_strategy="
|
|
|
92 |
load_best_model_at_end=True,
|
93 |
push_to_hub=False,
|
94 |
report_to="wandb",
|
|
|
42 |
dataset_repo: str = "geekyrakshit/prompt-injection-dataset",
|
43 |
model_name: str = "distilbert/distilbert-base-uncased",
|
44 |
prompt_column_name: str = "prompt",
|
45 |
+
learning_rate: float = 1e-5,
|
46 |
batch_size: int = 16,
|
47 |
num_epochs: int = 2,
|
48 |
weight_decay: float = 0.01,
|
49 |
+
save_steps: int = 1000,
|
50 |
streamlit_mode: bool = False,
|
51 |
):
|
52 |
wandb.init(project=project_name, entity=entity_name, name=run_name)
|
|
|
89 |
num_train_epochs=num_epochs,
|
90 |
weight_decay=weight_decay,
|
91 |
eval_strategy="epoch",
|
92 |
+
save_strategy="steps",
|
93 |
+
save_steps=save_steps,
|
94 |
load_best_model_at_end=True,
|
95 |
push_to_hub=False,
|
96 |
report_to="wandb",
|