# news_classification_model_monitor / classification_model_monitor.py
import logging

import streamlit as st
import pandas as pd
import numpy as np
import plotly.express as px
import matplotlib.pyplot as plt
from sklearn.metrics import balanced_accuracy_score, accuracy_score

from read_predictions_from_db import PredictionDBRead
from read_daily_metrics_from_db import MetricsDBRead
from config import (CLASSIFIER_ADJUSTMENT_THRESHOLD,
                    PERFORMANCE_THRESHOLD)

logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s', level=logging.INFO)

def filter_prediction_data(data: pd.DataFrame):
try:
logging.info("Entering filter_prediction_data()")
        if data is None:
            raise Exception("Input Prediction Data Frame is None")
        # Keep only rows whose true label is fully confident and that were
        # never part of a training, validation, or excluded split.
        training_flags = data['used_for_training'].astype("str")
        filtered_prediction_data = data.loc[
            (data['y_true_proba'] == 1)
            & ~training_flags.str.contains("_train")
            & ~training_flags.str.contains("_excluded")
            & ~training_flags.str.contains("_validation")
        ].copy()
logging.info("Exiting filter_prediction_data()")
return filtered_prediction_data
except Exception as e:
logging.critical(f"Error in filter_prediction_data(): {e}")
return None
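
# A minimal usage sketch for filter_prediction_data(); the demo values are
# assumed from the filters above, not confirmed against the DB schema:
#
# >>> demo = pd.DataFrame({
# ...     'y_true_proba': [1, 1, 0.4, 1],
# ...     'used_for_training': ['batch1_train', 'none', 'none', 'batch1_excluded'],
# ... })
# >>> filter_prediction_data(demo)
# Only the second row survives: its label is fully confident and it was never
# part of a training, validation, or excluded split.
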
def get_adjusted_predictions(df: pd.DataFrame):
try:
logging.info("Entering get_adjusted_predictions()")
if df is None:
raise Exception('Input Filtered Prediction Data Frame is None')
df = df.copy()
df.reset_index(drop=True, inplace=True)
        # Predictions below the adjustment threshold fall back to the 'NATION' class.
        df.loc[df['y_pred_proba'] < CLASSIFIER_ADJUSTMENT_THRESHOLD, 'y_pred'] = 'NATION'
# df.loc[(df['text'].str.contains('Pakistan')) & (df['y_pred'] == 'NATION'), 'y_pred'] = 'WORLD'
# df.loc[(df['text'].str.contains('Zodiac Sign', case=False)) | (df['text'].str.contains('Horoscope', case=False)), 'y_pred'] = 'SCIENCE'
logging.info("Exiting get_adjusted_predictions()")
return df
except Exception as e:
logging.info(f"Error in get_adjusted_predictions(): {e}")
return None
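
# Illustration of the adjustment above (the threshold value here is purely
# hypothetical; the real one comes from config.CLASSIFIER_ADJUSTMENT_THRESHOLD):
# with a threshold of 0.5, a row predicted as ('SCIENCE', y_pred_proba=0.31)
# is relabelled to 'NATION', while ('SCIENCE', y_pred_proba=0.87) is untouched.
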
def display_kpis(data: pd.DataFrame, adj_data: pd.DataFrame):
try:
logging.info("Entering display_kpis()")
        if data is None:
            raise Exception("Input Prediction Data Frame is None")
        if adj_data is None:
            raise Exception("Input Adjusted Data Frame is None")
n_samples = len(data)
balanced_accuracy = np.round(balanced_accuracy_score(data['y_true'], data['y_pred']), 4)
accuracy = np.round(accuracy_score(data['y_true'], data['y_pred']), 4)
adj_balanced_accuracy = np.round(balanced_accuracy_score(adj_data['y_true'], adj_data['y_pred']), 4)
adj_accuracy = np.round(accuracy_score(adj_data['y_true'], adj_data['y_pred']), 4)
        # Inline CSS pinning the metric-column widths, presumably so they
        # don't stack on narrow viewports.
        st.write('''<style>
[data-testid="column"] {
width: calc(33.3333% - 1rem) !important;
flex: 1 1 calc(33.3333% - 1rem) !important;
min-width: calc(33% - 1rem) !important;
}
</style>''',
unsafe_allow_html=True)
        col1, col2 = st.columns(2)
        with col1:
            st.metric(label="Balanced Accuracy", value=balanced_accuracy)
        with col2:
            st.metric(label="Adj Balanced Accuracy", value=adj_balanced_accuracy)
        col3, col4 = st.columns(2)
        with col3:
            st.metric(label="Accuracy", value=accuracy)
        with col4:
            st.metric(label="Adj Accuracy", value=adj_accuracy)
        col5, col6 = st.columns(2)
        with col5:
            st.metric(label="Bal Accuracy Threshold", value=PERFORMANCE_THRESHOLD)
        with col6:
            st.metric(label="N Samples", value=n_samples)
logging.info("Exiting display_kpis()")
except Exception as e:
logging.critical(f'Error in display_kpis(): {e}')
st.error("Couldn't display KPIs")
def plot_daily_metrics(metrics_df: pd.DataFrame):
try:
logging.info("Entering plot_daily_metrics()")
st.write(" ")
if metrics_df is None:
raise Exception('Input Metrics Data Frame is None')
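        # Columns consumed below; all must be present in the daily-metrics table:
        # evaluation_date, mean_balanced_accuracy_score,
        # std_balanced_accuracy_score, evaluation_window_days, n_splits,
        # sample_start_date, sample_end_date, sample_size_of_each_split.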
        metrics_df = metrics_df.copy()  # avoid mutating the caller's frame
        metrics_df['evaluation_date'] = pd.to_datetime(metrics_df['evaluation_date'])
metrics_df['mean_score_minus_std'] = np.round(metrics_df['mean_balanced_accuracy_score'] - metrics_df['std_balanced_accuracy_score'], 4)
metrics_df['mean_score_plus_std'] = np.round(metrics_df['mean_balanced_accuracy_score'] + metrics_df['std_balanced_accuracy_score'], 4)
hover_data={'mean_balanced_accuracy_score': True,
'std_balanced_accuracy_score': False,
'mean_score_minus_std': True,
'mean_score_plus_std': True,
'evaluation_window_days': True,
'n_splits': True,
'sample_start_date': True,
'sample_end_date': True,
'sample_size_of_each_split': True}
hover_labels = {'mean_balanced_accuracy_score': "Mean Score",
'mean_score_minus_std': "Mean Score - Stdev",
'mean_score_plus_std': "Mean Score + Stdev",
'evaluation_window_days': "Observation Window (Days)",
'sample_start_date': "Observation Window Start Date",
'sample_end_date': "Observation Window End Date",
'n_splits': "N Splits For Evaluation",
'sample_size_of_each_split': "Sample Size of Each Split"}
fig = px.line(data_frame=metrics_df, x='evaluation_date',
y='mean_balanced_accuracy_score',
error_y='std_balanced_accuracy_score',
title="Daily Balanced Accuracy",
color_discrete_sequence=['black'],
hover_data=hover_data, labels=hover_labels, markers=True)
fig.add_hline(y=PERFORMANCE_THRESHOLD, line_dash="dash", line_color="green",
annotation_text=f"<b>THRESHOLD</b>",
annotation_position="left top")
        fig.update_layout(dragmode='pan', margin=dict(l=0, r=0, t=110, b=10))
st.plotly_chart(fig, use_container_width=True)
logging.info("Exiting plot_daily_metrics()")
except Exception as e:
logging.critical(f'Error in plot_daily_metrics(): {e}')
st.error("Couldn't Plot Daily Model Metrics")
def get_misclassified_classes(data: pd.DataFrame):
try:
logging.info("Entering get_misclassified_classes()")
if data is None:
raise Exception("Input Prediction Data Frame is None")
data = data.copy()
data['match'] = (data['y_true'] == data['y_pred']).astype('int')
y_pred_counts = data['y_pred'].value_counts()
misclassified_examples = data.loc[data['match'] == 0, ['text', 'y_true', 'y_pred', 'y_pred_proba', 'url']].copy()
misclassified_examples.sort_values(by=['y_pred', 'y_pred_proba'], ascending=[True, False], inplace=True)
misclassifications = data.loc[data['match'] == 0, 'y_pred'].value_counts()
missing_classes = [i for i in y_pred_counts.index if i not in misclassifications.index]
for i in missing_classes:
misclassifications[i] = 0
misclassifications = misclassifications[y_pred_counts.index].copy()
misclassifications /= y_pred_counts
misclassifications.sort_values(ascending=False, inplace=True)
logging.info("Exiting get_misclassified_classes()")
return np.round(misclassifications, 2), misclassified_examples
except Exception as e:
logging.critical(f'Error in get_misclassified_classes(): {e}')
return None, None
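
# Worked example of the rate computation above (class names come from the
# model's label set; the counts are illustrative only):
#   y_pred_counts:                     NATION: 10, WORLD: 5
#   misclassified among those preds:   NATION: 2,  WORLD: 0
#   resulting rates:                   NATION: 0.20, WORLD: 0.00
# i.e. the bar chart shows, per predicted class, the share of its predictions
# that were wrong.
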
def display_misclassified_examples(misclassified_classes: pd.Series, misclassified_examples: pd.DataFrame):
try:
logging.info("Entering display_misclassified_examples()")
st.write(" ")
if misclassified_classes is None:
raise Exception('Misclassified Classes Distribution Data Frame is None')
if misclassified_examples is None:
raise Exception('Misclassified Examples Data Frame is None')
fig, ax = plt.subplots(figsize=(10, 4.5))
        misclassified_classes.plot(kind='bar', ax=ax, color='black',
                                   title="Misclassification rate by predicted class")
        ax.set_yticks([])
        ax.set_xlabel("")
        ax.bar_label(ax.containers[0])
st.pyplot(fig)
st.markdown("<b>Misclassified examples</b>", unsafe_allow_html=True)
st.dataframe(misclassified_examples, hide_index=True)
        # Hide the hover toolbar Streamlit overlays on dataframes.
        st.markdown(
"""
<style>
[data-testid="stElementToolbar"] {
display: none;
}
</style>
""",
unsafe_allow_html=True
)
logging.info("Exiting display_misclassified_examples()")
except Exception as e:
logging.critical(f'Error in display_misclassified_examples(): {e}')
st.error("Couldn't display Misclassification Data")
def classification_model_monitor():
try:
st.write('<h4>Classification Model Monitor<span style="color: red;"> (out of service)</span></h4>', unsafe_allow_html=True)
prediction_db = PredictionDBRead()
metrics_db = MetricsDBRead()
# Read Prediction Data From DB
prediction_data = prediction_db.read_predictions_from_db()
# Filter Prediction Data
filtered_prediction_data = filter_prediction_data(prediction_data)
# Get Adjusted Prediction Data
adjusted_filtered_prediction_data = get_adjusted_predictions(filtered_prediction_data)
# Display KPIs
display_kpis(filtered_prediction_data, adjusted_filtered_prediction_data)
# Read Daily Metrics From DB
metrics_df = metrics_db.read_metrics_from_db()
# Display daily Metrics Line Plot
plot_daily_metrics(metrics_df)
# Get misclassified class distribution and misclassified examples from Prediction Data
misclassified_classes, misclassified_examples = get_misclassified_classes(filtered_prediction_data)
# Display Misclassification Data
display_misclassified_examples(misclassified_classes, misclassified_examples)
st.markdown(
"""<style>
[data-testid="stMetricValue"] {
font-size: 25px;
}
</style>
""", unsafe_allow_html=True
)
except Exception as e:
logging.critical(f"Error in classification_model_monitor(): {e}")
st.error("Unexpected Error. Couldn't display Classification Model Monitor")