|
import streamlit as st |
|
from sklearn.svm import SVC |
|
from sklearn.pipeline import make_pipeline |
|
from sklearn.preprocessing import StandardScaler |
|
import joblib |
|
|
|
import utils.admin_utils as au |
|
|
|
# Seed every session-state slot this page uses with an empty-string
# placeholder (only when the slot is absent), so the tabs below can read
# them unconditionally on every script rerun.
for _state_key in (
    'cleaned_data',
    'sentences_train',
    'sentences_test',
    'labels_train',
    'labels_test',
    'svm_classifier',
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = ''
|
|
|
|
|
# Page heading, followed by one tab per stage of the model-building workflow.
st.title("Let's build our Model...")

tab_titles = [
    'Data Preprocessing',
    'Model Training',
    'Model Evaluation',
    'Save Model',
]
tabs = st.tabs(tab_titles)
|
|
|
# Tab 1: upload a CSV, embed it, and keep the result in session state under
# 'cleaned_data' for the training tab.
with tabs[0]:
    st.header('Data Preprocessing')
    st.write('Here we preprocess the data...')

    data = st.file_uploader("Upload CSV file", type='csv')
    button = st.button("Load data", key='data')

    if button:
        if data is None:
            # st.file_uploader returns None until the user picks a file;
            # without this guard the button would hand None to au.read_data.
            st.warning("Please upload a CSV file before loading the data.")
        else:
            with st.spinner("uploading data..."):
                df = au.read_data(data)
                embedder = au.initiate_embedder()
                st.session_state['cleaned_data'] = au.create_embedding(df, embedder)
            st.success("Finished uploading and embedding csv")
|
|
|
# Tab 2: split the embedded data and fit a scaled SVM classifier on the
# training portion. The splits and the fitted model are kept in session state.
with tabs[1]:
    st.header("Model Training")
    st.write("Here we train the model...")
    button = st.button("Train Model", key='model')

    if button:
        cleaned_data = st.session_state['cleaned_data']
        if isinstance(cleaned_data, str) and cleaned_data == '':
            # Still the empty-string placeholder: no data has been loaded
            # yet, so there is nothing to split or train on.
            st.warning("Please load and embed the data before training the model.")
        else:
            with st.spinner("Training model..."):
                (
                    st.session_state['sentences_train'],
                    st.session_state['sentences_test'],
                    st.session_state['labels_train'],
                    st.session_state['labels_test'],
                ) = au.split_train_test_data(cleaned_data)

                classifier = make_pipeline(StandardScaler(), SVC(class_weight='balanced'))
                classifier.fit(st.session_state['sentences_train'], st.session_state['labels_train'])
                # Publish the model only after fit() succeeded, so a failed
                # training run never leaves an unfitted pipeline behind.
                st.session_state['svm_classifier'] = classifier
            st.success('Finished Training model!')
|
|
|
# Tab 3: score the trained classifier on the held-out split and show one
# sample prediction for a hard-coded complaint.
with tabs[2]:
    st.header('Model Evaluation')
    st.write('Here we evaluate the model...')
    button = st.button("Evaluate model", key="Evaluation")

    if button:
        classifier = st.session_state['svm_classifier']
        if isinstance(classifier, str):
            # Still the empty-string placeholder: no model has been trained,
            # so scoring or predicting would fail.
            st.warning("Please train the model before evaluating it.")
        else:
            with st.spinner("Evaluating model..."):
                acc_score = au.get_score(
                    classifier,
                    st.session_state['sentences_test'],
                    st.session_state['labels_test'],
                )
                st.success(f"Validation accuracy: {100 * acc_score}%")

                st.write("A sample run:")
                text = "Rude driver with scary driving"
                st.write("*** Our Issue ***: " + text)

                embedder = au.initiate_embedder()
                query_result = embedder.embed_query(text)

                # predict() takes a batch, so wrap the single embedding in a list.
                result = classifier.predict([query_result])
                # f-string rather than "+" so a non-string label cannot
                # raise TypeError during concatenation.
                st.write(f"*** Department it belongs to ***: {result[0]}")

            st.success("Finished")
|
|
|
# Tab 4: persist the trained classifier to disk with joblib.
with tabs[3]:
    st.header('Save model')
    st.write('Here we save the model...')

    button = st.button("Save model", key="save")

    if button:
        classifier = st.session_state['svm_classifier']
        if isinstance(classifier, str):
            # Still the empty-string placeholder: without this guard the
            # placeholder would be pickled and "Done!" reported anyway.
            st.warning("Please train the model before saving it.")
        else:
            with st.spinner('Saving model...'):
                # NOTE(review): the '.pk1' extension (digit one) looks like a
                # typo for '.pkl'; kept as-is because other code may load the
                # model by this exact filename - confirm before renaming.
                joblib.dump(classifier, 'modelsvm.pk1')
            st.success('Done!')
|
|