geekyrakshit committed on
Commit
053730f
·
1 Parent(s): 2be5f55

update: app

Browse files
application_pages/train_classifier.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import streamlit as st
2
  from dotenv import load_dotenv
3
 
@@ -26,7 +28,12 @@ dataset_name = st.sidebar.text_input("Dataset Name", value="")
26
  st.session_state.dataset_name = dataset_name
27
 
28
  base_model_name = st.sidebar.selectbox(
29
- "Base Model", options=["distilbert/distilbert-base-uncased", "roberta-base"]
 
 
 
 
 
30
  )
31
  st.session_state.base_model_name = base_model_name
32
 
@@ -46,8 +53,9 @@ if st.session_state.should_start_training:
46
  f"Explore your training logs on [Weights & Biases]({wandb.run.url})"
47
  )
48
  training_output = train_binary_classifier(
49
- project_name="guardrails-genie",
50
- entity_name="geekyrakshit",
 
51
  dataset_repo=st.session_state.dataset_name,
52
  model_name=st.session_state.base_model_name,
53
  batch_size=st.session_state.batch_size,
 
1
+ import os
2
+
3
  import streamlit as st
4
  from dotenv import load_dotenv
5
 
 
28
  st.session_state.dataset_name = dataset_name
29
 
30
  base_model_name = st.sidebar.selectbox(
31
+ "Base Model",
32
+ options=[
33
+ "distilbert/distilbert-base-uncased",
34
+ "FacebookAI/roberta-base",
35
+ "microsoft/deberta-v3-base",
36
+ ],
37
  )
38
  st.session_state.base_model_name = base_model_name
39
 
 
53
  f"Explore your training logs on [Weights & Biases]({wandb.run.url})"
54
  )
55
  training_output = train_binary_classifier(
56
+ project_name=os.getenv("WANDB_PROJECT_NAME"),
57
+ entity_name=os.getenv("WANDB_ENTITY_NAME"),
58
+ run_name=f"{st.session_state.base_model_name}-finetuned",
59
  dataset_repo=st.session_state.dataset_name,
60
  model_name=st.session_state.base_model_name,
61
  batch_size=st.session_state.batch_size,
guardrails_genie/train_classifier.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  import evaluate
3
  import numpy as np
4
  import streamlit as st
@@ -39,6 +38,7 @@ class StreamlitProgressbarCallback(TrainerCallback):
39
  def train_binary_classifier(
40
  project_name: str,
41
  entity_name: str,
 
42
  dataset_repo: str = "geekyrakshit/prompt-injection-dataset",
43
  model_name: str = "distilbert/distilbert-base-uncased",
44
  learning_rate: float = 2e-5,
@@ -47,7 +47,7 @@ def train_binary_classifier(
47
  weight_decay: float = 0.01,
48
  streamlit_mode: bool = False,
49
  ):
50
- wandb.init(project=project_name, entity=entity_name)
51
  dataset = load_dataset(dataset_repo)
52
  tokenizer = AutoTokenizer.from_pretrained(model_name)
53
 
 
 
1
  import evaluate
2
  import numpy as np
3
  import streamlit as st
 
38
  def train_binary_classifier(
39
  project_name: str,
40
  entity_name: str,
41
+ run_name: str,
42
  dataset_repo: str = "geekyrakshit/prompt-injection-dataset",
43
  model_name: str = "distilbert/distilbert-base-uncased",
44
  learning_rate: float = 2e-5,
 
47
  weight_decay: float = 0.01,
48
  streamlit_mode: bool = False,
49
  ):
50
+ wandb.init(project=project_name, entity=entity_name, name=run_name)
51
  dataset = load_dataset(dataset_repo)
52
  tokenizer = AutoTokenizer.from_pretrained(model_name)
53