geekyrakshit commited on
Commit
8382f82
1 Parent(s): 3a7ead3

update: training code

Browse files
guardrails_genie/train_classifier.py CHANGED
@@ -42,10 +42,11 @@ def train_binary_classifier(
42
  dataset_repo: str = "geekyrakshit/prompt-injection-dataset",
43
  model_name: str = "distilbert/distilbert-base-uncased",
44
  prompt_column_name: str = "prompt",
45
- learning_rate: float = 2e-5,
46
  batch_size: int = 16,
47
  num_epochs: int = 2,
48
  weight_decay: float = 0.01,
 
49
  streamlit_mode: bool = False,
50
  ):
51
  wandb.init(project=project_name, entity=entity_name, name=run_name)
@@ -88,7 +89,8 @@ def train_binary_classifier(
88
  num_train_epochs=num_epochs,
89
  weight_decay=weight_decay,
90
  eval_strategy="epoch",
91
- save_strategy="epoch",
 
92
  load_best_model_at_end=True,
93
  push_to_hub=False,
94
  report_to="wandb",
 
42
  dataset_repo: str = "geekyrakshit/prompt-injection-dataset",
43
  model_name: str = "distilbert/distilbert-base-uncased",
44
  prompt_column_name: str = "prompt",
45
+ learning_rate: float = 1e-5,
46
  batch_size: int = 16,
47
  num_epochs: int = 2,
48
  weight_decay: float = 0.01,
49
+ save_steps: int = 1000,
50
  streamlit_mode: bool = False,
51
  ):
52
  wandb.init(project=project_name, entity=entity_name, name=run_name)
 
89
  num_train_epochs=num_epochs,
90
  weight_decay=weight_decay,
91
  eval_strategy="epoch",
92
+ save_strategy="steps",
93
+ save_steps=save_steps,
94
  load_best_model_at_end=True,
95
  push_to_hub=False,
96
  report_to="wandb",