Spaces:
Running
Running
geekyrakshit
commited on
Commit
·
053730f
1
Parent(s):
2be5f55
update: app
Browse files
application_pages/train_classifier.py
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
from dotenv import load_dotenv
|
3 |
|
@@ -26,7 +28,12 @@ dataset_name = st.sidebar.text_input("Dataset Name", value="")
|
|
26 |
st.session_state.dataset_name = dataset_name
|
27 |
|
28 |
base_model_name = st.sidebar.selectbox(
|
29 |
-
"Base Model",
|
|
|
|
|
|
|
|
|
|
|
30 |
)
|
31 |
st.session_state.base_model_name = base_model_name
|
32 |
|
@@ -46,8 +53,9 @@ if st.session_state.should_start_training:
|
|
46 |
f"Explore your training logs on [Weights & Biases]({wandb.run.url})"
|
47 |
)
|
48 |
training_output = train_binary_classifier(
|
49 |
-
project_name="
|
50 |
-
entity_name="
|
|
|
51 |
dataset_repo=st.session_state.dataset_name,
|
52 |
model_name=st.session_state.base_model_name,
|
53 |
batch_size=st.session_state.batch_size,
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
import streamlit as st
|
4 |
from dotenv import load_dotenv
|
5 |
|
|
|
28 |
st.session_state.dataset_name = dataset_name
|
29 |
|
30 |
base_model_name = st.sidebar.selectbox(
|
31 |
+
"Base Model",
|
32 |
+
options=[
|
33 |
+
"distilbert/distilbert-base-uncased",
|
34 |
+
"FacebookAI/roberta-base",
|
35 |
+
"microsoft/deberta-v3-base",
|
36 |
+
],
|
37 |
)
|
38 |
st.session_state.base_model_name = base_model_name
|
39 |
|
|
|
53 |
f"Explore your training logs on [Weights & Biases]({wandb.run.url})"
|
54 |
)
|
55 |
training_output = train_binary_classifier(
|
56 |
+
project_name=os.getenv("WANDB_PROJECT_NAME"),
|
57 |
+
entity_name=os.getenv("WANDB_ENTITY_NAME"),
|
58 |
+
run_name=f"{st.session_state.base_model_name}-finetuned",
|
59 |
dataset_repo=st.session_state.dataset_name,
|
60 |
model_name=st.session_state.base_model_name,
|
61 |
batch_size=st.session_state.batch_size,
|
guardrails_genie/train_classifier.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
|
2 |
import evaluate
|
3 |
import numpy as np
|
4 |
import streamlit as st
|
@@ -39,6 +38,7 @@ class StreamlitProgressbarCallback(TrainerCallback):
|
|
39 |
def train_binary_classifier(
|
40 |
project_name: str,
|
41 |
entity_name: str,
|
|
|
42 |
dataset_repo: str = "geekyrakshit/prompt-injection-dataset",
|
43 |
model_name: str = "distilbert/distilbert-base-uncased",
|
44 |
learning_rate: float = 2e-5,
|
@@ -47,7 +47,7 @@ def train_binary_classifier(
|
|
47 |
weight_decay: float = 0.01,
|
48 |
streamlit_mode: bool = False,
|
49 |
):
|
50 |
-
wandb.init(project=project_name, entity=entity_name)
|
51 |
dataset = load_dataset(dataset_repo)
|
52 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
53 |
|
|
|
|
|
1 |
import evaluate
|
2 |
import numpy as np
|
3 |
import streamlit as st
|
|
|
38 |
def train_binary_classifier(
|
39 |
project_name: str,
|
40 |
entity_name: str,
|
41 |
+
run_name: str,
|
42 |
dataset_repo: str = "geekyrakshit/prompt-injection-dataset",
|
43 |
model_name: str = "distilbert/distilbert-base-uncased",
|
44 |
learning_rate: float = 2e-5,
|
|
|
47 |
weight_decay: float = 0.01,
|
48 |
streamlit_mode: bool = False,
|
49 |
):
|
50 |
+
wandb.init(project=project_name, entity=entity_name, name=run_name)
|
51 |
dataset = load_dataset(dataset_repo)
|
52 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
53 |
|