geekyrakshit commited on
Commit
159baa9
·
1 Parent(s): 053730f

update: app

Browse files
README.md CHANGED
@@ -20,6 +20,8 @@ source .venv/bin/activate
20
  ```bash
21
  export OPENAI_API_KEY="YOUR_OPENAI_API_KEY"
22
  export WEAVE_PROJECT="YOUR_WEAVE_PROJECT"
 
 
23
  export WANDB_LOG_MODEL="checkpoint"
24
  streamlit run app.py
25
  ```
 
20
  ```bash
21
  export OPENAI_API_KEY="YOUR_OPENAI_API_KEY"
22
  export WEAVE_PROJECT="YOUR_WEAVE_PROJECT"
23
+ export WANDB_PROJECT_NAME="YOUR_WANDB_PROJECT_NAME"
24
+ export WANDB_ENTITY_NAME="YOUR_WANDB_ENTITY_NAME"
25
  export WANDB_LOG_MODEL="checkpoint"
26
  streamlit run app.py
27
  ```
application_pages/train_classifier.py CHANGED
@@ -3,7 +3,6 @@ import os
3
  import streamlit as st
4
  from dotenv import load_dotenv
5
 
6
- import wandb
7
  from guardrails_genie.train_classifier import train_binary_classifier
8
 
9
 
@@ -49,9 +48,6 @@ st.session_state.should_start_training = (
49
 
50
  if st.session_state.should_start_training:
51
  with st.expander("Training", expanded=True):
52
- st.markdown(
53
- f"Explore your training logs on [Weights & Biases]({wandb.run.url})"
54
- )
55
  training_output = train_binary_classifier(
56
  project_name=os.getenv("WANDB_PROJECT_NAME"),
57
  entity_name=os.getenv("WANDB_ENTITY_NAME"),
 
3
  import streamlit as st
4
  from dotenv import load_dotenv
5
 
 
6
  from guardrails_genie.train_classifier import train_binary_classifier
7
 
8
 
 
48
 
49
  if st.session_state.should_start_training:
50
  with st.expander("Training", expanded=True):
 
 
 
51
  training_output = train_binary_classifier(
52
  project_name=os.getenv("WANDB_PROJECT_NAME"),
53
  entity_name=os.getenv("WANDB_ENTITY_NAME"),
guardrails_genie/train_classifier.py CHANGED
@@ -48,6 +48,10 @@ def train_binary_classifier(
48
  streamlit_mode: bool = False,
49
  ):
50
  wandb.init(project=project_name, entity=entity_name, name=run_name)
 
 
 
 
51
  dataset = load_dataset(dataset_repo)
52
  tokenizer = AutoTokenizer.from_pretrained(model_name)
53
 
@@ -97,6 +101,10 @@ def train_binary_classifier(
97
  compute_metrics=compute_metrics,
98
  callbacks=[StreamlitProgressbarCallback()] if streamlit_mode else [],
99
  )
100
- training_output = trainer.train()
 
 
 
 
101
  wandb.finish()
102
  return training_output
 
48
  streamlit_mode: bool = False,
49
  ):
50
  wandb.init(project=project_name, entity=entity_name, name=run_name)
51
+ if streamlit_mode:
52
+ st.markdown(
53
+ f"Explore your training logs on [Weights & Biases]({wandb.run.url})"
54
+ )
55
  dataset = load_dataset(dataset_repo)
56
  tokenizer = AutoTokenizer.from_pretrained(model_name)
57
 
 
101
  compute_metrics=compute_metrics,
102
  callbacks=[StreamlitProgressbarCallback()] if streamlit_mode else [],
103
  )
104
+ try:
105
+ training_output = trainer.train()
106
+ except Exception as e:
107
+ wandb.finish()
108
+ raise e
109
  wandb.finish()
110
  return training_output