geekyrakshit committed
Commit 2900eb1
1 Parent(s): dfbca8a

add: LlamaGuardFineTuner.train

Files changed (1):
  1. guardrails_genie/train/llama_guard.py +63 -13
guardrails_genie/train/llama_guard.py CHANGED
@@ -1,11 +1,18 @@
+import os
+
 import plotly.graph_objects as go
 import streamlit as st
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
+import torch.optim as optim
+import wandb
 from datasets import load_dataset
 from pydantic import BaseModel
 from rich.progress import track
+from safetensors.torch import save_model
 from sklearn.metrics import roc_auc_score, roc_curve
+from torch.utils.data import DataLoader
 from transformers import AutoModelForSequenceClassification, AutoTokenizer


@@ -16,7 +23,11 @@ class DatasetArgs(BaseModel):


 class LlamaGuardFineTuner:
-    def __init__(self, streamlit_mode: bool = False):
+    def __init__(
+        self, wandb_project: str, wandb_entity: str, streamlit_mode: bool = False
+    ):
+        self.wandb_project = wandb_project
+        self.wandb_entity = wandb_entity
         self.streamlit_mode = streamlit_mode

     def load_dataset(self, dataset_args: DatasetArgs):
@@ -36,6 +47,7 @@ class LlamaGuardFineTuner:

     def load_model(self, model_name: str = "meta-llama/Prompt-Guard-86M"):
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.model_name = model_name
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
         self.model = AutoModelForSequenceClassification.from_pretrained(model_name).to(
             self.device
@@ -101,7 +113,6 @@ class LlamaGuardFineTuner:
         test_labels = [int(elt) for elt in self.test_dataset["label"]]
         fpr, tpr, _ = roc_curve(test_labels, test_scores)
         roc_auc = roc_auc_score(test_labels, test_scores)
-
         fig = go.Figure()
         fig.add_trace(
             go.Scatter(
@@ -121,7 +132,6 @@ class LlamaGuardFineTuner:
                 line=dict(color="navy", width=2, dash="dash"),
             )
         )
-
         fig.update_layout(
             title="Receiver Operating Characteristic",
             xaxis_title="False Positive Rate",
@@ -130,7 +140,6 @@ class LlamaGuardFineTuner:
             yaxis=dict(range=[0.0, 1.05]),
             legend=dict(x=0.8, y=0.2),
         )
-
         if self.streamlit_mode:
             st.plotly_chart(fig)
         else:
@@ -140,10 +149,7 @@ class LlamaGuardFineTuner:
         test_labels = [int(elt) for elt in self.test_dataset["label"]]
         positive_scores = [scores[i] for i in range(500) if test_labels[i] == 1]
         negative_scores = [scores[i] for i in range(500) if test_labels[i] == 0]
-
         fig = go.Figure()
-
-        # Plotting positive scores
         fig.add_trace(
             go.Histogram(
                 x=positive_scores,
@@ -153,8 +159,6 @@ class LlamaGuardFineTuner:
                 opacity=0.75,
             )
         )
-
-        # Plotting negative scores
         fig.add_trace(
             go.Histogram(
                 x=negative_scores,
@@ -164,8 +168,6 @@ class LlamaGuardFineTuner:
                 opacity=0.75,
             )
         )
-
-        # Updating layout
         fig.update_layout(
             title="Score Distribution for Positive and Negative Examples",
             xaxis_title="Score",
@@ -173,8 +175,6 @@ class LlamaGuardFineTuner:
             barmode="overlay",
             legend_title="Scores",
         )
-
-        # Display the plot
         if self.streamlit_mode:
             st.plotly_chart(fig)
         else:
@@ -199,3 +199,53 @@ class LlamaGuardFineTuner:
         self.visualize_roc_curve(test_scores)
         self.visualize_score_distribution(test_scores)
         return test_scores
+
+    def collate_fn(self, batch):
+        texts = [item["text"] for item in batch]
+        labels = torch.tensor([int(item["label"]) for item in batch])
+        encodings = self.tokenizer(
+            texts, padding=True, truncation=True, max_length=512, return_tensors="pt"
+        )
+        return encodings.input_ids, encodings.attention_mask, labels
+
+    def train(self, batch_size: int = 32, lr: float = 5e-6, num_classes: int = 2):
+        wandb.init(
+            project=self.wandb_project,
+            entity=self.wandb_entity,
+            name=f"{self.model_name}-{self.dataset_name}",
+        )
+        self.model.classifier = nn.Linear(
+            self.model.classifier.in_features, num_classes
+        )
+        self.model.num_labels = num_classes
+        self.model.train()
+        optimizer = optim.AdamW(self.model.parameters(), lr=lr)
+        data_loader = DataLoader(
+            self.train_dataset,
+            batch_size=batch_size,
+            shuffle=True,
+            collate_fn=self.collate_fn,
+        )
+        progress_bar = st.progress(0, text="Training") if self.streamlit_mode else None
+        for i, batch in track(
+            enumerate(data_loader), description="Training", total=len(data_loader)
+        ):
+            input_ids, attention_mask, labels = [x.to(self.device) for x in batch]
+            outputs = self.model(
+                input_ids=input_ids, attention_mask=attention_mask, labels=labels
+            )
+            loss = outputs.loss
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+            wandb.log({"loss": loss.item()})
+            if progress_bar:
+                progress_percentage = (i + 1) * 100 // len(data_loader)
+                progress_bar.progress(
                    progress_percentage,
+                    text=f"Training batch {i + 1}/{len(data_loader)}, Loss: {loss.item()}",
+                )
+        save_model(self.model, f"{self.model_name}-{self.dataset_name}.safetensors")
+        wandb.log_model(f"{self.model_name}-{self.dataset_name}.safetensors")
+        wandb.finish()
+        os.remove(f"{self.model_name}-{self.dataset_name}.safetensors")