|
import streamlit as st |
|
import pandas as pd |
|
import numpy as np |
|
import seaborn as sns |
|
import plotly.express as px |
|
import matplotlib.pyplot as plt |
|
from read_predictions_from_db import PredictionDBRead |
|
from read_daily_metrics_from_db import MetricsDBRead |
|
from sklearn.metrics import balanced_accuracy_score, accuracy_score |
|
import logging |
|
from config import (CLASSIFIER_ADJUSTMENT_THRESHOLD, |
|
PERFORMANCE_THRESHOLD, |
|
CLASSIFIER_THRESHOLD) |
|
|
|
# Configure the root logger for this dashboard: INFO and above, each record
# prefixed with its timestamp and level name. logging.basicConfig does nothing
# if the root logger already has handlers attached.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s: %(message)s',
)
|
|
|
|
|
def filter_prediction_data(data: pd.DataFrame):
    """Select the prediction rows that should be used for live evaluation.

    A row is kept when its ``y_true_proba`` equals 1 AND its
    ``used_for_training`` value, compared as a string, contains none of the
    markers ``"_train"``, ``"_excluded"`` or ``"_validation"``.

    Args:
        data: Prediction frame with at least the columns ``y_true_proba`` and
            ``used_for_training``.

    Returns:
        A filtered copy of ``data`` (the original index is preserved), or
        ``None`` if ``data`` is ``None`` or filtering fails for any reason.
        Failures are logged at CRITICAL level together with their traceback;
        nothing is raised to the caller.
    """
    try:
        logging.info("Entering filter_prediction_data()")

        if data is None:
            raise ValueError("Input Prediction Data frame is None")

        # Cast to str first so missing values become e.g. "nan"/"None" and the
        # mask never contains NA. The alternation is equivalent to three
        # separate `contains` checks because each marker is a plain literal.
        usage = data['used_for_training'].astype("str")
        held_out_of_eval = usage.str.contains("_train|_excluded|_validation", regex=True)

        keep = (data['y_true_proba'] == 1) & ~held_out_of_eval
        filtered_prediction_data = data.loc[keep].copy()

        logging.info("Exiting filter_prediction_data()")
        return filtered_prediction_data

    except Exception as e:
        logging.critical(f"Error in filter_prediction_data(): {e}", exc_info=True)
        return None
|
|
|
|
|
def get_adjusted_predictions(df, threshold=None, fallback_label='NATION'):
    """Return a copy of ``df`` with low-confidence predictions relabelled.

    Every row whose ``y_pred_proba`` is strictly below ``threshold`` gets its
    ``y_pred`` replaced by ``fallback_label``. The returned frame has a fresh
    0..n-1 index; the input frame is left unmodified.

    Args:
        df: Prediction frame with ``y_pred`` and ``y_pred_proba`` columns.
        threshold: Confidence cut-off. When ``None`` (the default) the
            ``CLASSIFIER_ADJUSTMENT_THRESHOLD`` value imported from ``config``
            is used.
        fallback_label: Label assigned to rows below the cut-off
            (default ``'NATION'``).

    Returns:
        The adjusted copy, or ``None`` if ``df`` is ``None`` or the adjustment
        fails. Failures are logged at CRITICAL level with their traceback;
        nothing is raised to the caller.
    """
    try:
        logging.info("Entering get_adjusted_predictions()")

        if df is None:
            raise ValueError('Input Filtered Prediction Data Frame is None')

        if threshold is None:
            # Looked up only when needed, so callers that pass an explicit
            # threshold do not depend on the config constant.
            threshold = CLASSIFIER_ADJUSTMENT_THRESHOLD

        adjusted = df.copy()
        adjusted.reset_index(drop=True, inplace=True)
        adjusted.loc[adjusted['y_pred_proba'] < threshold, 'y_pred'] = fallback_label

        logging.info("Exiting get_adjusted_predictions()")
        return adjusted

    except Exception as e:
        # CRITICAL (not INFO) so failures here are as visible as in the other helpers.
        logging.critical(f"Error in get_adjusted_predictions(): {e}", exc_info=True)
        return None
|
|
|
|
|
def _rounded_scores(frame: pd.DataFrame):
    """Return ``(balanced_accuracy, accuracy)`` of ``frame``'s y_true/y_pred, rounded to 4 dp."""
    y_true, y_pred = frame['y_true'], frame['y_pred']
    return (np.round(balanced_accuracy_score(y_true, y_pred), 4),
            np.round(accuracy_score(y_true, y_pred), 4))


def display_kpis(data: pd.DataFrame, adj_data: pd.DataFrame):
    """Render the KPI tiles comparing raw and adjusted predictions.

    Three rows of two ``st.metric`` tiles are shown:
    balanced accuracy / adjusted balanced accuracy, accuracy / adjusted
    accuracy, and the configured ``PERFORMANCE_THRESHOLD`` / number of rows
    in ``data``. Scores are rounded to 4 decimal places.

    Args:
        data: Frame with ``y_true`` and ``y_pred`` columns (raw predictions).
        adj_data: Frame with the same columns holding adjusted predictions.

    Errors (including ``None`` inputs) are logged at CRITICAL level with their
    traceback and reported in the UI via ``st.error``; nothing is raised.
    """
    try:
        logging.info("Entering display_kpis()")

        if data is None:
            raise ValueError("Input Prediction Data frame is None")
        if adj_data is None:
            raise ValueError('Input Adjusted Data frame is None')

        n_samples = len(data)
        balanced_accuracy, accuracy = _rounded_scores(data)
        adj_balanced_accuracy, adj_accuracy = _rounded_scores(adj_data)

        # Page-wide CSS that pins the width of Streamlit column containers.
        # NOTE(review): the data-testid="column" selector depends on Streamlit's
        # internal DOM and may change between versions — confirm it still matches.
        st.write('''<style>
[data-testid="column"] {
    width: calc(33.3333% - 1rem) !important;
    flex: 1 1 calc(33.3333% - 1rem) !important;
    min-width: calc(33% - 1rem) !important;
}
</style>''',
                 unsafe_allow_html=True)

        # One entry per row of tiles; each row is rendered in its own pair of columns.
        tile_rows = [
            (("Balanced Accuracy", balanced_accuracy), ("Adj Balanced Accuracy", adj_balanced_accuracy)),
            (("Accuracy", accuracy), ("Adj Accuracy", adj_accuracy)),
            (("Bal Accuracy Threshold", PERFORMANCE_THRESHOLD), ("N Samples", n_samples)),
        ]
        for tiles in tile_rows:
            for column, (label, value) in zip(st.columns(len(tiles)), tiles):
                column.metric(label=label, value=value)

        logging.info("Exiting display_kpis()")

    except Exception as e:
        logging.critical(f'Error in display_kpis(): {e}', exc_info=True)
        st.error("Couldn't display KPIs")
|
|
|
|
|
def plot_daily_metrics(metrics_df: pd.DataFrame):
    """Render the daily balanced-accuracy line chart with ±1 std error bars.

    Plots ``mean_balanced_accuracy_score`` against ``evaluation_date``, with
    ``std_balanced_accuracy_score`` as the error bar, a dashed horizontal line
    at ``PERFORMANCE_THRESHOLD``, and a hover tooltip that includes mean ± std
    and the evaluation-window metadata columns.

    Args:
        metrics_df: Daily metrics frame. Besides the plotted columns it must
            contain ``evaluation_window_days``, ``n_splits``,
            ``sample_start_date``, ``sample_end_date`` and
            ``sample_size_of_each_split`` (all used in the tooltip). The
            caller's frame is not modified.

    Errors (including a ``None`` input) are logged at CRITICAL level with their
    traceback and reported in the UI via ``st.error``; nothing is raised.
    """
    try:
        logging.info("Entering plot_daily_metrics()")
        st.write(" ")

        if metrics_df is None:
            raise ValueError('Input Metrics Data Frame is None')

        # Work on a copy so the caller's frame does not gain derived columns
        # or have its evaluation_date dtype changed as a side effect.
        metrics_df = metrics_df.copy()
        metrics_df['evaluation_date'] = pd.to_datetime(metrics_df['evaluation_date'])

        mean_score = metrics_df['mean_balanced_accuracy_score']
        std_score = metrics_df['std_balanced_accuracy_score']
        metrics_df['mean_score_minus_std'] = np.round(mean_score - std_score, 4)
        metrics_df['mean_score_plus_std'] = np.round(mean_score + std_score, 4)

        # The raw std is already shown as the error bar, so it is hidden from the tooltip.
        hover_data = {'mean_balanced_accuracy_score': True,
                      'std_balanced_accuracy_score': False,
                      'mean_score_minus_std': True,
                      'mean_score_plus_std': True,
                      'evaluation_window_days': True,
                      'n_splits': True,
                      'sample_start_date': True,
                      'sample_end_date': True,
                      'sample_size_of_each_split': True}

        hover_labels = {'mean_balanced_accuracy_score': "Mean Score",
                        'mean_score_minus_std': "Mean Score - Stdev",
                        'mean_score_plus_std': "Mean Score + Stdev",
                        'evaluation_window_days': "Observation Window (Days)",
                        'sample_start_date': "Observation Window Start Date",
                        'sample_end_date': "Observation Window End Date",
                        'n_splits': "N Splits For Evaluation",
                        'sample_size_of_each_split': "Sample Size of Each Split"}

        fig = px.line(data_frame=metrics_df, x='evaluation_date',
                      y='mean_balanced_accuracy_score',
                      error_y='std_balanced_accuracy_score',
                      title="Daily Balanced Accuracy",
                      color_discrete_sequence=['black'],
                      hover_data=hover_data, labels=hover_labels, markers=True)

        fig.add_hline(y=PERFORMANCE_THRESHOLD, line_dash="dash", line_color="green",
                      annotation_text="<b>THRESHOLD</b>",
                      annotation_position="left top")

        fig.update_layout(dragmode='pan', margin=dict(l=0, r=0, t=110, b=10))
        st.plotly_chart(fig, use_container_width=True)

        logging.info("Exiting plot_daily_metrics()")

    except Exception as e:
        logging.critical(f'Error in plot_daily_metrics(): {e}', exc_info=True)
        st.error("Couldn't Plot Daily Model Metrics")
|
|
|
|
|
def get_misclassified_classes(data):
    """Compute per-predicted-class error rates and collect the misclassified rows.

    A row is misclassified when ``y_true != y_pred``.

    Args:
        data: Prediction frame with ``text``, ``y_true``, ``y_pred``,
            ``y_pred_proba`` and ``url`` columns. It is not modified.

    Returns:
        ``(rates, examples)``:

        * ``rates`` — Series indexed by predicted class, holding the fraction
          of that class's predictions that were wrong (0.0 for classes that
          were never wrong), rounded to 2 dp and sorted in descending order.
        * ``examples`` — the misclassified rows restricted to the five columns
          above, sorted by ``y_pred`` ascending, then ``y_pred_proba``
          descending.

        ``(None, None)`` if ``data`` is ``None`` or the computation fails; the
        failure is logged at CRITICAL level with its traceback.
    """
    try:
        logging.info("Entering get_misclassified_classes()")

        if data is None:
            raise ValueError("Input Prediction Data Frame is None")

        is_wrong = data['y_true'] != data['y_pred']
        y_pred_counts = data['y_pred'].value_counts()

        misclassified_examples = data.loc[is_wrong, ['text', 'y_true', 'y_pred', 'y_pred_proba', 'url']].copy()
        misclassified_examples.sort_values(by=['y_pred', 'y_pred_proba'], ascending=[True, False], inplace=True)

        # Error count per predicted class, aligned to y_pred_counts' order;
        # classes that were never wrong get 0 instead of being dropped.
        error_counts = (data.loc[is_wrong, 'y_pred']
                        .value_counts()
                        .reindex(y_pred_counts.index, fill_value=0))

        # Stable sort: ties keep y_pred_counts' order, so the chart is deterministic.
        error_rates = (error_counts / y_pred_counts).sort_values(ascending=False, kind='mergesort')

        logging.info("Exiting get_misclassified_classes()")
        return np.round(error_rates, 2), misclassified_examples

    except Exception as e:
        logging.critical(f'Error in get_misclassified_classes(): {e}', exc_info=True)
        return None, None
|
|
|
|
|
def display_misclassified_examples(misclassified_classes, misclassified_examples):
    """Render the misclassification bar chart and the misclassified-rows table.

    Args:
        misclassified_classes: Series of per-class error rates (index = class),
            drawn as a black bar chart with each bar labelled by its value.
        misclassified_examples: DataFrame shown below the chart via
            ``st.dataframe`` without its index.

    Also injects CSS that hides Streamlit's element toolbar (the selector
    ``[data-testid="stElementToolbar"]``). Errors (including ``None`` inputs)
    are logged at CRITICAL level with their traceback and reported in the UI
    via ``st.error``; nothing is raised.
    """
    try:
        logging.info("Entering display_misclassified_examples()")
        st.write(" ")

        if misclassified_classes is None:
            raise ValueError('Misclassified Classes Distribution Data Frame is None')
        if misclassified_examples is None:
            raise ValueError('Misclassified Examples Data Frame is None')

        fig, ax = plt.subplots(figsize=(10, 4.5))
        try:
            misclassified_classes.plot(kind='bar', ax=ax, color='black', title="Misclassification percentage")
            # Configure this figure's axes explicitly rather than pyplot's
            # implicit "current axes".
            ax.set_yticks([])
            ax.set_xlabel("")
            ax.bar_label(ax.containers[0])
            st.pyplot(fig)
        finally:
            # pyplot keeps every figure it creates alive until it is closed;
            # without this, each Streamlit rerun would leak one figure.
            plt.close(fig)

        st.markdown("<b>Misclassified examples</b>", unsafe_allow_html=True)
        st.dataframe(misclassified_examples, hide_index=True)
        st.markdown(
            """
            <style>
            [data-testid="stElementToolbar"] {
                display: none;
            }
            </style>
            """,
            unsafe_allow_html=True
        )

        logging.info("Exiting display_misclassified_examples()")

    except Exception as e:
        logging.critical(f'Error in display_misclassified_examples(): {e}', exc_info=True)
        st.error("Couldn't display Misclassification Data")
|
|
|
|
|
def classification_model_monitor():
    """Render the complete Classification Model Monitor page.

    Steps, in order:
      1. Read predictions through ``PredictionDBRead``.
      2. Filter them, derive the threshold-adjusted variant and show the KPIs.
      3. Read daily metrics through ``MetricsDBRead`` and plot them.
      4. Show the misclassification breakdown.

    The helper functions called here catch their own errors. Anything else that
    escapes — e.g. a failing DB read — is logged at CRITICAL level with its
    traceback and reported in the UI via ``st.error``.
    """
    try:
        st.write('<h4>Classification Model Monitor<span style="color: red;"> (out of service)</span></h4>', unsafe_allow_html=True)

        prediction_db = PredictionDBRead()
        metrics_db = MetricsDBRead()

        # KPIs from raw vs. threshold-adjusted predictions.
        prediction_data = prediction_db.read_predictions_from_db()
        filtered_prediction_data = filter_prediction_data(prediction_data)
        adjusted_filtered_prediction_data = get_adjusted_predictions(filtered_prediction_data)
        display_kpis(filtered_prediction_data, adjusted_filtered_prediction_data)

        # Daily metric history.
        metrics_df = metrics_db.read_metrics_from_db()
        plot_daily_metrics(metrics_df)

        # Where the model goes wrong (uses the unadjusted predictions).
        misclassified_classes, misclassified_examples = get_misclassified_classes(filtered_prediction_data)
        display_misclassified_examples(misclassified_classes, misclassified_examples)

        # Page-wide CSS setting the font size of st.metric values.
        st.markdown(
            """<style>
            [data-testid="stMetricValue"] {
                font-size: 25px;
            }
            </style>
            """, unsafe_allow_html=True
        )

    except Exception as e:
        logging.critical(f"Error in classification_model_monitor(): {e}", exc_info=True)
        st.error("Unexpected Error. Couldn't display Classification Model Monitor")
|
|