geekyrakshit committed
Commit
883a576
1 Parent(s): 32d5d0c

add: docs for LlamaGuardFineTuner

Files changed (1)
  1. guardrails_genie/train/llama_guard.py +111 -0
guardrails_genie/train/llama_guard.py CHANGED
 
@@ -24,6 +24,19 @@ class DatasetArgs(BaseModel):


class LlamaGuardFineTuner:
+    """
+    `LlamaGuardFineTuner` is a class designed to fine-tune and evaluate the
+    [Prompt Guard model by Meta Llama](https://huggingface.co/meta-llama/Prompt-Guard-86M) for prompt
+    classification tasks, specifically for detecting prompt injection attacks. It
+    integrates with Weights & Biases for experiment tracking and optionally
+    displays progress in a Streamlit app.
+
+    Args:
+        wandb_project (str): The name of the Weights & Biases project.
+        wandb_entity (str): The Weights & Biases entity (user or team).
+        streamlit_mode (bool): If True, integrates with Streamlit to display progress.
+    """
+
    def __init__(
        self, wandb_project: str, wandb_entity: str, streamlit_mode: bool = False
    ):
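As a quick orientation for readers of this commit, here is a minimal usage sketch based only on the constructor documented above; the project and entity names are placeholders, not values from the repository.

```python
# Minimal sketch based on the class docstring above; project/entity names are
# hypothetical placeholders, not values taken from this repository.
from guardrails_genie.train.llama_guard import LlamaGuardFineTuner

fine_tuner = LlamaGuardFineTuner(
    wandb_project="prompt-guard-finetune",  # assumed W&B project name
    wandb_entity="my-team",                 # assumed W&B entity (user or team)
    streamlit_mode=False,                   # no Streamlit UI in a plain script
)
```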
 
@@ -32,6 +45,24 @@ class LlamaGuardFineTuner:
        self.streamlit_mode = streamlit_mode

    def load_dataset(self, dataset_args: DatasetArgs):
+        """
+        Loads the training and testing datasets based on the provided dataset arguments.
+
+        This function uses the `load_dataset` function from the `datasets` library to load
+        the dataset specified by the `dataset_address` attribute of the `dataset_args` parameter.
+        It then selects a subset of the training and testing datasets based on the specified
+        ranges in `train_dataset_range` and `test_dataset_range` attributes of `dataset_args`.
+        If the specified range is less than or equal to 0 or exceeds the length of the dataset,
+        the entire dataset is used.
+
+        Args:
+            dataset_args (DatasetArgs): An instance of the `DatasetArgs` class containing
+                the dataset address and the ranges for training and testing datasets.
+
+        Attributes:
+            train_dataset: The selected training dataset.
+            test_dataset: The selected testing dataset.
+        """
        dataset = load_dataset(dataset_args.dataset_address)
        self.train_dataset = (
            dataset["train"]
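The range rule described above (fall back to the full split when the range is non-positive or larger than the dataset) can be sketched with the `datasets` library directly; the dataset id below is a hypothetical placeholder, and this is not the repository's implementation.

```python
# Illustrative sketch of the subset-selection rule described in the docstring.
from datasets import load_dataset

dataset = load_dataset("user/prompt-injection-dataset")  # hypothetical dataset id
train_range = 1000

train_split = dataset["train"]
if 0 < train_range <= len(train_split):
    # keep only the first `train_range` rows
    train_split = train_split.select(range(train_range))
# otherwise the entire split is used
```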
 
@@ -47,6 +78,22 @@ class LlamaGuardFineTuner:
        )

    def load_model(self, model_name: str = "meta-llama/Prompt-Guard-86M"):
+        """
+        Loads the specified pre-trained model and tokenizer for sequence classification tasks.
+
+        This function sets the device to GPU if available, otherwise defaults to CPU. It then
+        loads the tokenizer and model from the Hugging Face model hub using the provided model name.
+        The model is moved to the specified device (GPU or CPU).
+
+        Args:
+            model_name (str): The name of the pre-trained model to load.
+
+        Attributes:
+            device (str): The device to run the model on, either "cuda" for GPU or "cpu".
+            model_name (str): The name of the loaded pre-trained model.
+            tokenizer (AutoTokenizer): The tokenizer associated with the pre-trained model.
+            model (AutoModelForSequenceClassification): The loaded pre-trained model for sequence classification.
+        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
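The loading pattern the docstring describes is the standard `transformers` sequence-classification setup; here is a self-contained sketch that mirrors the context lines above rather than adding new behavior.

```python
# Standalone sketch of the loading pattern described in the docstring above.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Prompt-Guard-86M")
model = AutoModelForSequenceClassification.from_pretrained(
    "meta-llama/Prompt-Guard-86M"
).to(device)  # move the weights to GPU when one is available
```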
 
@@ -55,6 +102,19 @@ class LlamaGuardFineTuner:
        )

    def show_dataset_sample(self):
+        """
+        Displays a sample of the training and testing datasets using Streamlit.
+
+        This function checks if the `streamlit_mode` attribute is enabled. If it is,
+        it converts the training and testing datasets to pandas DataFrames and displays
+        the first few rows of each dataset using Streamlit's `dataframe` function. The
+        training dataset sample is displayed under the heading "Train Dataset Sample",
+        and the testing dataset sample is displayed under the heading "Test Dataset Sample".
+
+        Note:
+            This function requires the `streamlit` library to be installed and the
+            `streamlit_mode` attribute to be set to True.
+        """
        if self.streamlit_mode:
            st.markdown("### Train Dataset Sample")
            st.dataframe(self.train_dataset.to_pandas().head())
 
@@ -189,6 +249,31 @@ class LlamaGuardFineTuner:
        truncation: bool = True,
        max_length: int = 512,
    ):
+        """
+        Evaluates the fine-tuned model on the test dataset and visualizes the results.
+
+        This function evaluates the model by processing the test dataset in batches.
+        It computes the test scores using the `evaluate_batch` method, which takes
+        several parameters to control the evaluation process, such as batch size,
+        positive label, temperature, truncation, and maximum sequence length.
+
+        After obtaining the test scores, it visualizes the performance of the model
+        using two methods:
+        1. `visualize_roc_curve`: Plots the Receiver Operating Characteristic (ROC) curve
+           to show the trade-off between the true positive rate and false positive rate.
+        2. `visualize_score_distribution`: Plots the distribution of scores for positive
+           and negative examples to provide insights into the model's performance.
+
+        Args:
+            batch_size (int, optional): The number of samples to process in each batch.
+            positive_label (int, optional): The label considered as positive for evaluation.
+            temperature (float, optional): The temperature parameter for scaling logits.
+            truncation (bool, optional): Whether to truncate sequences to the maximum length.
+            max_length (int, optional): The maximum length of sequences after truncation.
+
+        Returns:
+            list[float]: The test scores obtained from the evaluation.
+        """
        test_scores = self.evaluate_batch(
            self.test_dataset["text"],
            batch_size=batch_size,
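For readers who want the gist of the evaluation step, here is a hedged sketch of the kind of temperature-scaled scoring and ROC computation the docstring refers to; `evaluate_batch` and `visualize_roc_curve` live elsewhere in this file and may differ in detail.

```python
# Hedged sketch of temperature-scaled scoring and ROC/AUC computation; this is
# illustrative and not the file's actual evaluate_batch / visualize_roc_curve code.
import torch
from sklearn.metrics import auc, roc_curve


def score_texts(model, tokenizer, texts, device, positive_label=1,
                temperature=1.0, truncation=True, max_length=512):
    """Return the temperature-scaled probability of the positive class for each text."""
    scores = []
    model.eval()
    with torch.no_grad():
        for text in texts:
            inputs = tokenizer(
                text, return_tensors="pt", truncation=truncation, max_length=max_length
            ).to(device)
            logits = model(**inputs).logits
            probs = torch.softmax(logits / temperature, dim=-1)
            scores.append(probs[0, positive_label].item())
    return scores


def roc_auc(labels, scores, positive_label=1):
    """Compute the ROC curve and its AUC from binary labels and per-example scores."""
    fpr, tpr, _ = roc_curve(labels, scores, pos_label=positive_label)
    return fpr, tpr, auc(fpr, tpr)
```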
 
@@ -217,6 +302,32 @@ class LlamaGuardFineTuner:
        log_interval: int = 20,
        save_interval: int = 1000,
    ):
+        """
+        Fine-tunes the pre-trained LlamaGuard model on the training dataset for a single epoch.
+
+        This function sets up and executes the training loop for the LlamaGuard model.
+        It initializes the Weights & Biases (wandb) logging, configures the model's
+        classifier layer to match the specified number of classes, and sets the model
+        to training mode. The function uses an AdamW optimizer to update the model
+        parameters based on the computed loss.
+
+        The training process involves iterating over the training dataset in batches,
+        computing the loss for each batch, and updating the model parameters. The
+        function logs the loss to wandb at specified intervals and optionally displays
+        a progress bar using Streamlit if `streamlit_mode` is enabled. Model checkpoints
+        are saved at specified intervals during training.
+
+        Args:
+            batch_size (int, optional): The number of samples per batch during training.
+            lr (float, optional): The learning rate for the optimizer.
+            num_classes (int, optional): The number of output classes for the classifier.
+            log_interval (int, optional): The interval (in batches) at which to log the loss.
+            save_interval (int, optional): The interval (in batches) at which to save model checkpoints.
+
+        Note:
+            This function requires the `wandb` and `streamlit` libraries to be installed
+            and configured appropriately.
+        """
        os.makedirs("checkpoints", exist_ok=True)
        wandb.init(
            project=self.wandb_project,
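Finally, a hedged sketch of the single-epoch loop the docstring outlines: AdamW updates, a wandb loss log every `log_interval` batches, and a checkpoint every `save_interval` batches. It is illustrative only; the default values and helper name below are arbitrary, and the repository's actual training code may differ.

```python
# Illustrative single-epoch loop matching the behavior described in the docstring;
# not the repository's implementation. Defaults are arbitrary.
import os

import torch
import wandb
from torch.optim import AdamW


def train_one_epoch(model, tokenizer, texts, labels, device,
                    batch_size=32, lr=5e-6, log_interval=20, save_interval=1000):
    os.makedirs("checkpoints", exist_ok=True)
    optimizer = AdamW(model.parameters(), lr=lr)
    model.train()
    for batch_idx, start in enumerate(range(0, len(texts), batch_size)):
        batch_texts = texts[start : start + batch_size]
        batch_labels = torch.tensor(labels[start : start + batch_size]).to(device)
        inputs = tokenizer(
            batch_texts, return_tensors="pt", padding=True,
            truncation=True, max_length=512,
        ).to(device)
        loss = model(**inputs, labels=batch_labels).loss  # classification loss from the HF head
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch_idx % log_interval == 0:
            wandb.log({"loss": loss.item()}, step=batch_idx)   # periodic loss logging
        if batch_idx > 0 and batch_idx % save_interval == 0:
            model.save_pretrained(f"checkpoints/batch_{batch_idx}")  # periodic checkpoint
```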