Spaces:
Running
Running
geekyrakshit
commited on
Commit
•
a202ba5
1
Parent(s):
573a89c
fix: LlamaGuardFineTuner
Browse files- application_pages/llama_guard_fine_tuning.py +49 -12
- guardrails_genie/train/llama_guard.py +16 -3
- test.ipynb +0 -0
application_pages/llama_guard_fine_tuning.py
CHANGED
@@ -11,8 +11,20 @@ def initialize_session_state():
|
|
11 |
st.session_state.train_dataset_range = 0
|
12 |
if "test_dataset_range" not in st.session_state:
|
13 |
st.session_state.test_dataset_range = 0
|
14 |
-
if "
|
15 |
-
st.session_state.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
|
18 |
initialize_session_state()
|
@@ -30,14 +42,39 @@ if st.session_state.dataset_address != "":
|
|
30 |
)
|
31 |
st.session_state.train_dataset_range = train_dataset_range
|
32 |
st.session_state.test_dataset_range = test_dataset_range
|
33 |
-
|
34 |
-
st.
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
)
|
42 |
-
st.session_state.
|
43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
st.session_state.train_dataset_range = 0
|
12 |
if "test_dataset_range" not in st.session_state:
|
13 |
st.session_state.test_dataset_range = 0
|
14 |
+
if "load_fine_tuner_button" not in st.session_state:
|
15 |
+
st.session_state.load_fine_tuner_button = False
|
16 |
+
if "is_fine_tuner_loaded" not in st.session_state:
|
17 |
+
st.session_state.is_fine_tuner_loaded = False
|
18 |
+
if "model_name" not in st.session_state:
|
19 |
+
st.session_state.model_name = ""
|
20 |
+
if "preview_dataset" not in st.session_state:
|
21 |
+
st.session_state.preview_dataset = False
|
22 |
+
if "evaluate_model" not in st.session_state:
|
23 |
+
st.session_state.evaluate_model = False
|
24 |
+
if "evaluation_batch_size" not in st.session_state:
|
25 |
+
st.session_state.evaluation_batch_size = None
|
26 |
+
if "evaluation_temperature" not in st.session_state:
|
27 |
+
st.session_state.evaluation_temperature = None
|
28 |
|
29 |
|
30 |
initialize_session_state()
|
|
|
42 |
)
|
43 |
st.session_state.train_dataset_range = train_dataset_range
|
44 |
st.session_state.test_dataset_range = test_dataset_range
|
45 |
+
|
46 |
+
model_name = st.sidebar.selectbox(
|
47 |
+
"Model Name",
|
48 |
+
["meta-llama/Prompt-Guard-86M"],
|
49 |
+
)
|
50 |
+
st.session_state.model_name = model_name
|
51 |
+
|
52 |
+
preview_dataset = st.sidebar.toggle("Preview Dataset")
|
53 |
+
st.session_state.preview_dataset = preview_dataset
|
54 |
+
|
55 |
+
evaluate_model = st.sidebar.toggle("Evaluate Model")
|
56 |
+
st.session_state.evaluate_model = evaluate_model
|
57 |
+
|
58 |
+
load_fine_tuner_button = st.sidebar.button("Load Fine-Tuner")
|
59 |
+
st.session_state.load_fine_tuner_button = load_fine_tuner_button
|
60 |
+
|
61 |
+
if st.session_state.load_fine_tuner_button:
|
62 |
+
with st.status("Loading Fine-Tuner"):
|
63 |
+
st.session_state.llama_guard_fine_tuner.load_dataset(
|
64 |
+
DatasetArgs(
|
65 |
+
dataset_address=st.session_state.dataset_address,
|
66 |
+
train_dataset_range=st.session_state.train_dataset_range,
|
67 |
+
test_dataset_range=st.session_state.test_dataset_range,
|
68 |
+
)
|
69 |
+
)
|
70 |
+
st.session_state.llama_guard_fine_tuner.load_model(
|
71 |
+
model_name=st.session_state.model_name
|
72 |
)
|
73 |
+
if st.session_state.preview_dataset:
|
74 |
+
st.session_state.llama_guard_fine_tuner.show_dataset_sample()
|
75 |
+
if st.session_state.evaluate_model:
|
76 |
+
st.session_state.llama_guard_fine_tuner.evaluate_model(
|
77 |
+
batch_size=32,
|
78 |
+
temperature=3.0,
|
79 |
+
)
|
80 |
+
st.session_state.is_fine_tuner_loaded = True
|
guardrails_genie/train/llama_guard.py
CHANGED
@@ -23,12 +23,14 @@ class LlamaGuardFineTuner:
|
|
23 |
dataset = load_dataset(dataset_args.dataset_address)
|
24 |
self.train_dataset = (
|
25 |
dataset["train"]
|
26 |
-
if dataset_args.train_dataset_range
|
|
|
27 |
else dataset["train"].select(range(dataset_args.train_dataset_range))
|
28 |
)
|
29 |
self.test_dataset = (
|
30 |
dataset["test"]
|
31 |
-
if dataset_args.test_dataset_range
|
|
|
32 |
else dataset["test"].select(range(dataset_args.test_dataset_range))
|
33 |
)
|
34 |
|
@@ -69,7 +71,12 @@ class LlamaGuardFineTuner:
|
|
69 |
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
|
70 |
|
71 |
scores = []
|
72 |
-
|
|
|
|
|
|
|
|
|
|
|
73 |
input_ids, attention_mask = [b.to(self.device) for b in batch]
|
74 |
with torch.no_grad():
|
75 |
logits = self.model(
|
@@ -81,6 +88,12 @@ class LlamaGuardFineTuner:
|
|
81 |
probabilities[:, positive_label].cpu().numpy()
|
82 |
)
|
83 |
scores.extend(positive_class_probabilities)
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
|
85 |
return scores
|
86 |
|
|
|
23 |
dataset = load_dataset(dataset_args.dataset_address)
|
24 |
self.train_dataset = (
|
25 |
dataset["train"]
|
26 |
+
if dataset_args.train_dataset_range <= 0
|
27 |
+
or dataset_args.train_dataset_range > len(dataset["train"])
|
28 |
else dataset["train"].select(range(dataset_args.train_dataset_range))
|
29 |
)
|
30 |
self.test_dataset = (
|
31 |
dataset["test"]
|
32 |
+
if dataset_args.test_dataset_range <= 0
|
33 |
+
or dataset_args.test_dataset_range > len(dataset["test"])
|
34 |
else dataset["test"].select(range(dataset_args.test_dataset_range))
|
35 |
)
|
36 |
|
|
|
71 |
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
|
72 |
|
73 |
scores = []
|
74 |
+
progress_bar = (
|
75 |
+
st.progress(0, text="Evaluating") if self.streamlit_mode else None
|
76 |
+
)
|
77 |
+
for i, batch in track(
|
78 |
+
enumerate(data_loader), description="Evaluating", total=len(data_loader)
|
79 |
+
):
|
80 |
input_ids, attention_mask = [b.to(self.device) for b in batch]
|
81 |
with torch.no_grad():
|
82 |
logits = self.model(
|
|
|
88 |
probabilities[:, positive_label].cpu().numpy()
|
89 |
)
|
90 |
scores.extend(positive_class_probabilities)
|
91 |
+
if progress_bar:
|
92 |
+
progress_percentage = (i + 1) * 100 // len(data_loader)
|
93 |
+
progress_bar.progress(
|
94 |
+
progress_percentage,
|
95 |
+
text=f"Evaluating batch {i + 1}/{len(data_loader)}",
|
96 |
+
)
|
97 |
|
98 |
return scores
|
99 |
|
test.ipynb
ADDED
File without changes
|