geekyrakshit commited on
Commit
5e33295
·
1 Parent(s): 177344c

update: LlamaGuardFineTuner

Browse files
.gitignore CHANGED
@@ -168,4 +168,5 @@ temp.txt
168
  binary-classifier/
169
  wandb/
170
  artifacts/
171
- evaluation_results/
 
 
168
  binary-classifier/
169
  wandb/
170
  artifacts/
171
+ evaluation_results/
172
+ checkpoints/
application_pages/llama_guard_fine_tuning.py CHANGED
@@ -1,10 +1,16 @@
 
 
1
  import streamlit as st
2
 
3
  from guardrails_genie.train.llama_guard import DatasetArgs, LlamaGuardFineTuner
4
 
5
 
6
  def initialize_session_state():
7
- st.session_state.llama_guard_fine_tuner = LlamaGuardFineTuner(streamlit_mode=True)
 
 
 
 
8
  if "dataset_address" not in st.session_state:
9
  st.session_state.dataset_address = ""
10
  if "train_dataset_range" not in st.session_state:
@@ -25,6 +31,14 @@ def initialize_session_state():
25
  st.session_state.evaluation_batch_size = None
26
  if "evaluation_temperature" not in st.session_state:
27
  st.session_state.evaluation_temperature = None
 
 
 
 
 
 
 
 
28
 
29
 
30
  initialize_session_state()
@@ -43,18 +57,34 @@ if st.session_state.dataset_address != "":
43
  st.session_state.train_dataset_range = train_dataset_range
44
  st.session_state.test_dataset_range = test_dataset_range
45
 
46
- model_name = st.sidebar.selectbox(
47
- "Model Name",
48
- ["meta-llama/Prompt-Guard-86M"],
49
  )
50
  st.session_state.model_name = model_name
51
 
 
 
 
52
  preview_dataset = st.sidebar.toggle("Preview Dataset")
53
  st.session_state.preview_dataset = preview_dataset
54
 
55
  evaluate_model = st.sidebar.toggle("Evaluate Model")
56
  st.session_state.evaluate_model = evaluate_model
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  load_fine_tuner_button = st.sidebar.button("Load Fine-Tuner")
59
  st.session_state.load_fine_tuner_button = load_fine_tuner_button
60
 
@@ -68,13 +98,19 @@ if st.session_state.dataset_address != "":
68
  )
69
  )
70
  st.session_state.llama_guard_fine_tuner.load_model(
71
- model_name=st.session_state.model_name
 
 
 
 
 
72
  )
73
  if st.session_state.preview_dataset:
74
  st.session_state.llama_guard_fine_tuner.show_dataset_sample()
75
  if st.session_state.evaluate_model:
76
  st.session_state.llama_guard_fine_tuner.evaluate_model(
77
- batch_size=32,
78
- temperature=3.0,
 
79
  )
80
  st.session_state.is_fine_tuner_loaded = True
 
1
+ import os
2
+
3
  import streamlit as st
4
 
5
  from guardrails_genie.train.llama_guard import DatasetArgs, LlamaGuardFineTuner
6
 
7
 
8
  def initialize_session_state():
9
+ st.session_state.llama_guard_fine_tuner = LlamaGuardFineTuner(
10
+ wandb_project=os.getenv("WANDB_PROJECT_NAME"),
11
+ wandb_entity=os.getenv("WANDB_ENTITY_NAME"),
12
+ streamlit_mode=True,
13
+ )
14
  if "dataset_address" not in st.session_state:
15
  st.session_state.dataset_address = ""
16
  if "train_dataset_range" not in st.session_state:
 
31
  st.session_state.evaluation_batch_size = None
32
  if "evaluation_temperature" not in st.session_state:
33
  st.session_state.evaluation_temperature = None
34
+ if "checkpoint" not in st.session_state:
35
+ st.session_state.checkpoint = None
36
+ if "eval_batch_size" not in st.session_state:
37
+ st.session_state.eval_batch_size = 32
38
+ if "eval_positive_label" not in st.session_state:
39
+ st.session_state.eval_positive_label = 2
40
+ if "eval_temperature" not in st.session_state:
41
+ st.session_state.eval_temperature = 1.0
42
 
43
 
44
  initialize_session_state()
 
57
  st.session_state.train_dataset_range = train_dataset_range
58
  st.session_state.test_dataset_range = test_dataset_range
59
 
60
+ model_name = st.sidebar.text_input(
61
+ label="Model Name", value="meta-llama/Prompt-Guard-86M"
 
62
  )
63
  st.session_state.model_name = model_name
64
 
65
+ checkpoint = st.sidebar.text_input(label="Fine-tuned Model Checkpoint", value="")
66
+ st.session_state.checkpoint = checkpoint
67
+
68
  preview_dataset = st.sidebar.toggle("Preview Dataset")
69
  st.session_state.preview_dataset = preview_dataset
70
 
71
  evaluate_model = st.sidebar.toggle("Evaluate Model")
72
  st.session_state.evaluate_model = evaluate_model
73
 
74
+ if st.session_state.evaluate_model:
75
+ eval_batch_size = st.sidebar.slider(
76
+ label="Eval Batch Size", min_value=16, max_value=1024, value=32
77
+ )
78
+ st.session_state.eval_batch_size = eval_batch_size
79
+
80
+ eval_positive_label = st.sidebar.number_input("EVal Positive Label", value=2)
81
+ st.session_state.eval_positive_label = eval_positive_label
82
+
83
+ eval_temperature = st.sidebar.slider(
84
+ label="Eval Temperature", min_value=0.0, max_value=5.0, value=1.0
85
+ )
86
+ st.session_state.eval_temperature = eval_temperature
87
+
88
  load_fine_tuner_button = st.sidebar.button("Load Fine-Tuner")
89
  st.session_state.load_fine_tuner_button = load_fine_tuner_button
90
 
 
98
  )
99
  )
100
  st.session_state.llama_guard_fine_tuner.load_model(
101
+ model_name=st.session_state.model_name,
102
+ checkpoint=(
103
+ None
104
+ if st.session_state.checkpoint == ""
105
+ else st.session_state.checkpoint
106
+ ),
107
  )
108
  if st.session_state.preview_dataset:
109
  st.session_state.llama_guard_fine_tuner.show_dataset_sample()
110
  if st.session_state.evaluate_model:
111
  st.session_state.llama_guard_fine_tuner.evaluate_model(
112
+ batch_size=st.session_state.eval_batch_size,
113
+ positive_label=st.session_state.eval_positive_label,
114
+ temperature=st.session_state.eval_temperature,
115
  )
116
  st.session_state.is_fine_tuner_loaded = True
guardrails_genie/train/__init__.py CHANGED
@@ -1,4 +1,4 @@
 
1
  from .train_classifier import train_binary_classifier
2
- from .llama_guard import LlamaGuardFineTuner, DatasetArgs
3
 
4
- __all__ = ["train_binary_classifier", "LlamaGuardFineTuner", "DatasetArgs"]
 
1
+ from .llama_guard import DatasetArgs, LlamaGuardFineTuner
2
  from .train_classifier import train_binary_classifier
 
3
 
4
+ __all__ = ["train_binary_classifier", "LlamaGuardFineTuner", "DatasetArgs"]
guardrails_genie/train/llama_guard.py CHANGED
@@ -1,5 +1,7 @@
1
  import os
2
  import shutil
 
 
3
 
4
  import plotly.graph_objects as go
5
  import streamlit as st
@@ -7,15 +9,16 @@ import torch
7
  import torch.nn as nn
8
  import torch.nn.functional as F
9
  import torch.optim as optim
10
- import wandb
11
  from datasets import load_dataset
12
  from pydantic import BaseModel
13
  from rich.progress import track
14
- from safetensors.torch import save_model
15
  from sklearn.metrics import roc_auc_score, roc_curve
16
  from torch.utils.data import DataLoader
17
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
18
 
 
 
19
 
20
  class DatasetArgs(BaseModel):
21
  dataset_address: str
@@ -30,7 +33,7 @@ class LlamaGuardFineTuner:
30
  classification tasks, specifically for detecting prompt injection attacks. It
31
  integrates with Weights & Biases for experiment tracking and optionally
32
  displays progress in a Streamlit app.
33
-
34
  !!! example "Sample Usage"
35
  ```python
36
  from guardrails_genie.train.llama_guard import LlamaGuardFineTuner, DatasetArgs
@@ -98,7 +101,11 @@ class LlamaGuardFineTuner:
98
  else dataset["test"].select(range(dataset_args.test_dataset_range))
99
  )
100
 
101
- def load_model(self, model_name: str = "meta-llama/Prompt-Guard-86M"):
 
 
 
 
102
  """
103
  Loads the specified pre-trained model and tokenizer for sequence classification tasks.
104
 
@@ -118,9 +125,20 @@ class LlamaGuardFineTuner:
118
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
119
  self.model_name = model_name
120
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
121
- self.model = AutoModelForSequenceClassification.from_pretrained(model_name).to(
122
- self.device
123
- )
 
 
 
 
 
 
 
 
 
 
 
124
 
125
  def show_dataset_sample(self):
126
  """
 
1
  import os
2
  import shutil
3
+ from glob import glob
4
+ from typing import Optional
5
 
6
  import plotly.graph_objects as go
7
  import streamlit as st
 
9
  import torch.nn as nn
10
  import torch.nn.functional as F
11
  import torch.optim as optim
 
12
  from datasets import load_dataset
13
  from pydantic import BaseModel
14
  from rich.progress import track
15
+ from safetensors.torch import load_model, save_model
16
  from sklearn.metrics import roc_auc_score, roc_curve
17
  from torch.utils.data import DataLoader
18
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
19
 
20
+ import wandb
21
+
22
 
23
  class DatasetArgs(BaseModel):
24
  dataset_address: str
 
33
  classification tasks, specifically for detecting prompt injection attacks. It
34
  integrates with Weights & Biases for experiment tracking and optionally
35
  displays progress in a Streamlit app.
36
+
37
  !!! example "Sample Usage"
38
  ```python
39
  from guardrails_genie.train.llama_guard import LlamaGuardFineTuner, DatasetArgs
 
101
  else dataset["test"].select(range(dataset_args.test_dataset_range))
102
  )
103
 
104
+ def load_model(
105
+ self,
106
+ model_name: str = "meta-llama/Prompt-Guard-86M",
107
+ checkpoint: Optional[str] = None,
108
+ ):
109
  """
110
  Loads the specified pre-trained model and tokenizer for sequence classification tasks.
111
 
 
125
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
126
  self.model_name = model_name
127
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
128
+ if checkpoint is None:
129
+ self.model = AutoModelForSequenceClassification.from_pretrained(
130
+ model_name
131
+ ).to(self.device)
132
+ else:
133
+ api = wandb.Api()
134
+ artifact = api.artifact(checkpoint.removeprefix("wandb://"))
135
+ artifact_dir = artifact.download()
136
+ model_file_path = glob(os.path.join(artifact_dir, "model-*.safetensors"))[0]
137
+ self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
138
+ self.model.classifier = nn.Linear(self.model.classifier.in_features, 2)
139
+ self.model.num_labels = 2
140
+ load_model(self.model, model_file_path)
141
+ self.model = self.model.to(self.device)
142
 
143
  def show_dataset_sample(self):
144
  """
guardrails_genie/train/train_classifier.py CHANGED
@@ -1,7 +1,6 @@
1
  import evaluate
2
  import numpy as np
3
  import streamlit as st
4
- import wandb
5
  from datasets import load_dataset
6
  from transformers import (
7
  AutoModelForSequenceClassification,
@@ -11,6 +10,7 @@ from transformers import (
11
  TrainingArguments,
12
  )
13
 
 
14
  from guardrails_genie.utils import StreamlitProgressbarCallback
15
 
16
 
 
1
  import evaluate
2
  import numpy as np
3
  import streamlit as st
 
4
  from datasets import load_dataset
5
  from transformers import (
6
  AutoModelForSequenceClassification,
 
10
  TrainingArguments,
11
  )
12
 
13
+ import wandb
14
  from guardrails_genie.utils import StreamlitProgressbarCallback
15
 
16