Spaces:
Running
Running
Commit
·
159baa9
1
Parent(s):
053730f
update: app
Browse files
README.md
CHANGED
@@ -20,6 +20,8 @@ source .venv/bin/activate
|
|
20 |
```bash
|
21 |
export OPENAI_API_KEY="YOUR_OPENAI_API_KEY"
|
22 |
export WEAVE_PROJECT="YOUR_WEAVE_PROJECT"
|
|
|
|
|
23 |
export WANDB_LOG_MODEL="checkpoint"
|
24 |
streamlit run app.py
|
25 |
```
|
|
|
20 |
```bash
|
21 |
export OPENAI_API_KEY="YOUR_OPENAI_API_KEY"
|
22 |
export WEAVE_PROJECT="YOUR_WEAVE_PROJECT"
|
23 |
+
export WANDB_PROJECT_NAME="YOUR_WANDB_PROJECT_NAME"
|
24 |
+
export WANDB_ENTITY_NAME="YOUR_WANDB_ENTITY_NAME"
|
25 |
export WANDB_LOG_MODEL="checkpoint"
|
26 |
streamlit run app.py
|
27 |
```
|
application_pages/train_classifier.py
CHANGED
@@ -3,7 +3,6 @@ import os
|
|
3 |
import streamlit as st
|
4 |
from dotenv import load_dotenv
|
5 |
|
6 |
-
import wandb
|
7 |
from guardrails_genie.train_classifier import train_binary_classifier
|
8 |
|
9 |
|
@@ -49,9 +48,6 @@ st.session_state.should_start_training = (
|
|
49 |
|
50 |
if st.session_state.should_start_training:
|
51 |
with st.expander("Training", expanded=True):
|
52 |
-
st.markdown(
|
53 |
-
f"Explore your training logs on [Weights & Biases]({wandb.run.url})"
|
54 |
-
)
|
55 |
training_output = train_binary_classifier(
|
56 |
project_name=os.getenv("WANDB_PROJECT_NAME"),
|
57 |
entity_name=os.getenv("WANDB_ENTITY_NAME"),
|
|
|
3 |
import streamlit as st
|
4 |
from dotenv import load_dotenv
|
5 |
|
|
|
6 |
from guardrails_genie.train_classifier import train_binary_classifier
|
7 |
|
8 |
|
|
|
48 |
|
49 |
if st.session_state.should_start_training:
|
50 |
with st.expander("Training", expanded=True):
|
|
|
|
|
|
|
51 |
training_output = train_binary_classifier(
|
52 |
project_name=os.getenv("WANDB_PROJECT_NAME"),
|
53 |
entity_name=os.getenv("WANDB_ENTITY_NAME"),
|
guardrails_genie/train_classifier.py
CHANGED
@@ -48,6 +48,10 @@ def train_binary_classifier(
|
|
48 |
streamlit_mode: bool = False,
|
49 |
):
|
50 |
wandb.init(project=project_name, entity=entity_name, name=run_name)
|
|
|
|
|
|
|
|
|
51 |
dataset = load_dataset(dataset_repo)
|
52 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
53 |
|
@@ -97,6 +101,10 @@ def train_binary_classifier(
|
|
97 |
compute_metrics=compute_metrics,
|
98 |
callbacks=[StreamlitProgressbarCallback()] if streamlit_mode else [],
|
99 |
)
|
100 |
-
|
|
|
|
|
|
|
|
|
101 |
wandb.finish()
|
102 |
return training_output
|
|
|
48 |
streamlit_mode: bool = False,
|
49 |
):
|
50 |
wandb.init(project=project_name, entity=entity_name, name=run_name)
|
51 |
+
if streamlit_mode:
|
52 |
+
st.markdown(
|
53 |
+
f"Explore your training logs on [Weights & Biases]({wandb.run.url})"
|
54 |
+
)
|
55 |
dataset = load_dataset(dataset_repo)
|
56 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
57 |
|
|
|
101 |
compute_metrics=compute_metrics,
|
102 |
callbacks=[StreamlitProgressbarCallback()] if streamlit_mode else [],
|
103 |
)
|
104 |
+
try:
|
105 |
+
training_output = trainer.train()
|
106 |
+
except Exception as e:
|
107 |
+
wandb.finish()
|
108 |
+
raise e
|
109 |
wandb.finish()
|
110 |
return training_output
|