geekyrakshit committed on
Commit
a202ba5
1 Parent(s): 573a89c

fix: LlamaGuardFineTuner

Browse files
application_pages/llama_guard_fine_tuning.py CHANGED
@@ -11,8 +11,20 @@ def initialize_session_state():
11
  st.session_state.train_dataset_range = 0
12
  if "test_dataset_range" not in st.session_state:
13
  st.session_state.test_dataset_range = 0
14
- if "load_dataset_button" not in st.session_state:
15
- st.session_state.load_dataset_button = False
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
 
18
  initialize_session_state()
@@ -30,14 +42,39 @@ if st.session_state.dataset_address != "":
30
  )
31
  st.session_state.train_dataset_range = train_dataset_range
32
  st.session_state.test_dataset_range = test_dataset_range
33
- load_dataset_button = st.sidebar.button("Load Dataset")
34
- st.session_state.load_dataset_button = load_dataset_button
35
- if load_dataset_button:
36
- with st.status("Dataset Arguments"):
37
- dataset_args = DatasetArgs(
38
- dataset_address=st.session_state.dataset_address,
39
- train_dataset_range=st.session_state.train_dataset_range,
40
- test_dataset_range=st.session_state.test_dataset_range,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  )
42
- st.session_state.llama_guard_fine_tuner.load_dataset(dataset_args)
43
- st.session_state.llama_guard_fine_tuner.show_dataset_sample()
 
 
 
 
 
 
 
11
  st.session_state.train_dataset_range = 0
12
  if "test_dataset_range" not in st.session_state:
13
  st.session_state.test_dataset_range = 0
14
+ if "load_fine_tuner_button" not in st.session_state:
15
+ st.session_state.load_fine_tuner_button = False
16
+ if "is_fine_tuner_loaded" not in st.session_state:
17
+ st.session_state.is_fine_tuner_loaded = False
18
+ if "model_name" not in st.session_state:
19
+ st.session_state.model_name = ""
20
+ if "preview_dataset" not in st.session_state:
21
+ st.session_state.preview_dataset = False
22
+ if "evaluate_model" not in st.session_state:
23
+ st.session_state.evaluate_model = False
24
+ if "evaluation_batch_size" not in st.session_state:
25
+ st.session_state.evaluation_batch_size = None
26
+ if "evaluation_temperature" not in st.session_state:
27
+ st.session_state.evaluation_temperature = None
28
 
29
 
30
  initialize_session_state()
 
42
  )
43
  st.session_state.train_dataset_range = train_dataset_range
44
  st.session_state.test_dataset_range = test_dataset_range
45
+
46
+ model_name = st.sidebar.selectbox(
47
+ "Model Name",
48
+ ["meta-llama/Prompt-Guard-86M"],
49
+ )
50
+ st.session_state.model_name = model_name
51
+
52
+ preview_dataset = st.sidebar.toggle("Preview Dataset")
53
+ st.session_state.preview_dataset = preview_dataset
54
+
55
+ evaluate_model = st.sidebar.toggle("Evaluate Model")
56
+ st.session_state.evaluate_model = evaluate_model
57
+
58
+ load_fine_tuner_button = st.sidebar.button("Load Fine-Tuner")
59
+ st.session_state.load_fine_tuner_button = load_fine_tuner_button
60
+
61
+ if st.session_state.load_fine_tuner_button:
62
+ with st.status("Loading Fine-Tuner"):
63
+ st.session_state.llama_guard_fine_tuner.load_dataset(
64
+ DatasetArgs(
65
+ dataset_address=st.session_state.dataset_address,
66
+ train_dataset_range=st.session_state.train_dataset_range,
67
+ test_dataset_range=st.session_state.test_dataset_range,
68
+ )
69
+ )
70
+ st.session_state.llama_guard_fine_tuner.load_model(
71
+ model_name=st.session_state.model_name
72
  )
73
+ if st.session_state.preview_dataset:
74
+ st.session_state.llama_guard_fine_tuner.show_dataset_sample()
75
+ if st.session_state.evaluate_model:
76
+ st.session_state.llama_guard_fine_tuner.evaluate_model(
77
+ batch_size=32,
78
+ temperature=3.0,
79
+ )
80
+ st.session_state.is_fine_tuner_loaded = True
guardrails_genie/train/llama_guard.py CHANGED
@@ -23,12 +23,14 @@ class LlamaGuardFineTuner:
23
  dataset = load_dataset(dataset_args.dataset_address)
24
  self.train_dataset = (
25
  dataset["train"]
26
- if dataset_args.train_dataset_range > 0
 
27
  else dataset["train"].select(range(dataset_args.train_dataset_range))
28
  )
29
  self.test_dataset = (
30
  dataset["test"]
31
- if dataset_args.test_dataset_range > 0
 
32
  else dataset["test"].select(range(dataset_args.test_dataset_range))
33
  )
34
 
@@ -69,7 +71,12 @@ class LlamaGuardFineTuner:
69
  data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
70
 
71
  scores = []
72
- for batch in track(data_loader, description="Evaluating"):
 
 
 
 
 
73
  input_ids, attention_mask = [b.to(self.device) for b in batch]
74
  with torch.no_grad():
75
  logits = self.model(
@@ -81,6 +88,12 @@ class LlamaGuardFineTuner:
81
  probabilities[:, positive_label].cpu().numpy()
82
  )
83
  scores.extend(positive_class_probabilities)
 
 
 
 
 
 
84
 
85
  return scores
86
 
 
23
  dataset = load_dataset(dataset_args.dataset_address)
24
  self.train_dataset = (
25
  dataset["train"]
26
+ if dataset_args.train_dataset_range <= 0
27
+ or dataset_args.train_dataset_range > len(dataset["train"])
28
  else dataset["train"].select(range(dataset_args.train_dataset_range))
29
  )
30
  self.test_dataset = (
31
  dataset["test"]
32
+ if dataset_args.test_dataset_range <= 0
33
+ or dataset_args.test_dataset_range > len(dataset["test"])
34
  else dataset["test"].select(range(dataset_args.test_dataset_range))
35
  )
36
 
 
71
  data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
72
 
73
  scores = []
74
+ progress_bar = (
75
+ st.progress(0, text="Evaluating") if self.streamlit_mode else None
76
+ )
77
+ for i, batch in track(
78
+ enumerate(data_loader), description="Evaluating", total=len(data_loader)
79
+ ):
80
  input_ids, attention_mask = [b.to(self.device) for b in batch]
81
  with torch.no_grad():
82
  logits = self.model(
 
88
  probabilities[:, positive_label].cpu().numpy()
89
  )
90
  scores.extend(positive_class_probabilities)
91
+ if progress_bar:
92
+ progress_percentage = (i + 1) * 100 // len(data_loader)
93
+ progress_bar.progress(
94
+ progress_percentage,
95
+ text=f"Evaluating batch {i + 1}/{len(data_loader)}",
96
+ )
97
 
98
  return scores
99
 
test.ipynb ADDED
File without changes