# pgurazada1's picture
# Update app.py
# fdd11f5 verified
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import gradio as gr
from datasets import load_dataset
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
# HuggingFace dataset where the deployed model's logs (input features and a
# 'prediction' column) are stored; get_data() samples its 'train' split.
LOGS_DATASET_URI = 'pgurazada1/machine-failure-mlops-demo-logs'
# Load and cache training data
# (fetch_openml caches the download locally, so restarts can reuse it.)
dataset = fetch_openml(data_id=42890, as_frame=True, parser="auto")
# NOTE(review): `dataset.data` excludes the OpenML default target column if the
# dataset declares one; this code assumes 'Machine failure' is still in it.
data_df = dataset.data
target = 'Machine failure'
# Feature columns selected from the training data to build X.
numeric_features = [
    'Air temperature [K]',
    'Process temperature [K]',
    'Rotational speed [rpm]',
    'Torque [Nm]',
    'Tool wear [min]'
]
categorical_features = ['Type']
X = data_df[numeric_features + categorical_features]
y = data_df[target]
# 80/20 split with a fixed seed, so the training statistics used as the
# monitoring baseline (Xtrain, ytrain) are the same on every restart.
Xtrain, Xtest, ytrain, ytest = train_test_split(
    X, y,
    test_size=0.2,
    random_state=42
)
def get_data(n_samples=100):
    """
    Connect to the HuggingFace dataset where the logs are stored and pull a
    random sample of the logged rows into a dataframe.

    Args:
        n_samples: maximum number of rows to sample. When the logs hold fewer
            rows than this, all of them are returned (in random order).

    Returns:
        A pandas DataFrame with up to ``n_samples`` rows of the 'train' split.
    """
    data = load_dataset(LOGS_DATASET_URI)
    logs_df = data['train'].to_pandas()
    # DataFrame.sample raises ValueError when asked for more rows than exist
    # (sampling without replacement), which would break every dashboard panel
    # while the logs are still sparse.
    sample_df = logs_df.sample(min(n_samples, len(logs_df)))
    return sample_df
def check_model_drift(sample_df=None):
    """
    Check the proportion of predicted machine failures in the live logs against
    the proportion of failures in the training data.

    Model drift is flagged when the sampled proportion deviates from the
    training proportion p by more than 2 standard deviations of a proportion
    estimated from n samples, i.e. more than 2 * sqrt(p * (1 - p) / n), where
    n is the number of sampled predictions labelled 0 or 1.

    Args:
        sample_df: optional DataFrame with a ``prediction`` column. When
            omitted, a fresh sample of the logs is pulled with ``get_data()``.

    Returns:
        A status message for the dashboard.
    """
    if sample_df is None:
        sample_df = get_data()
    # Hard-coded share of positive labels (machine failures) in the training
    # data; presumably the failure rate of ytrain — confirm if the data changes.
    p_pos_label_training_data = 0.03475
    prediction_counts = sample_df.prediction.value_counts()
    # .get with a default: a label that never occurs in the sample counts as 0
    # instead of raising KeyError.
    n_0 = prediction_counts.get(0, 0)
    n_1 = prediction_counts.get(1, 0)
    n_total = n_0 + n_1
    if n_total == 0:
        return "No predictions available to check for model drift."
    p_pos_label_sample_logs = n_1 / n_total
    # The spread of a proportion estimated from n_total predictions depends on
    # that sample size, not on the training-set size. Using the training size
    # (8000) gives a band so narrow that every 100-row sample is flagged.
    variance = (p_pos_label_training_data * (1 - p_pos_label_training_data)) / n_total
    p_diff = abs(p_pos_label_training_data - p_pos_label_sample_logs)
    if p_diff > 2 * math.sqrt(variance):
        return "Model Drift Detected! Check Logs!"
    return "No Model Drift!"
def plot_target_distributions():
    """
    Plot the class proportions of the training targets above the class
    proportions of the predictions in a fresh sample of the live logs.

    Returns:
        The matplotlib Figure holding the two count plots, for gr.Plot.
    """
    sample_df = get_data()
    figure, axes = plt.subplots(2, 1, figsize=(9, 7))

    sns.countplot(x=ytrain, stat='proportion', ax=axes[0])
    axes[0].set_title("Distribution of targets in training data")
    axes[0].set_xlabel('')

    sns.countplot(x=sample_df.prediction, stat='proportion', ax=axes[1])
    axes[1].set_title("Distribution of predicted targets from the deployed model")
    axes[1].set_xlabel('')

    # Close exactly this figure so pyplot does not keep it alive between
    # refreshes. A bare plt.close() closes whichever figure is "current",
    # which is not guaranteed to be the one created here.
    plt.close(figure)
    return figure
def psi(actual_proportions, expected_proportions, eps=1e-4):
    """
    Compute the population stability index (PSI) between two distributions.

    PSI = sum_i (a_i - e_i) * ln(a_i / e_i), summed over the categories i.

    Args:
        actual_proportions: category proportions observed in the live data
            (array-like).
        expected_proportions: proportions of the same categories, in the same
            order, in the reference (training) data (array-like).
        eps: floor applied to every proportion before the logarithm, so that a
            category absent from one side gives a large but finite
            contribution instead of inf/nan.

    Returns:
        The PSI as a float; 0.0 for identical distributions.

    Raises:
        ValueError: if the two inputs do not have the same shape.
    """
    actual = np.asarray(actual_proportions, dtype=float)
    expected = np.asarray(expected_proportions, dtype=float)
    if actual.shape != expected.shape:
        # Without this check NumPy would silently broadcast, e.g. a single
        # live category against all training categories.
        raise ValueError(
            f"proportion arrays differ in shape: {actual.shape} vs {expected.shape}"
        )
    actual = np.clip(actual, eps, None)
    expected = np.clip(expected, eps, None)
    psi_values = (actual - expected) * np.log(actual / expected)
    return float(np.sum(psi_values))
def check_data_drift():
    """
    Compare training data features and live features and flag data drift.

    Numeric features and categorical features are dealt with separately:

    * a numeric feature is flagged when the mean of the live sample differs
      from the training mean by more than 2 training standard deviations;
    * a categorical feature is flagged when the population stability index
      (PSI) of its category proportions, live vs. training, exceeds 0.1.

    Returns:
        A one-row DataFrame with one column per feature (numeric features
        first, then categorical ones) holding that feature's drift status.
    """
    sample_df = get_data()
    data_drift_status = {}

    # Numeric features
    for feature in numeric_features:
        mean_feature_training_data = Xtrain[feature].mean()
        std_feature_training_data = Xtrain[feature].std()
        mean_feature_sample_logs = sample_df[feature].mean()
        mean_diff = abs(mean_feature_training_data - mean_feature_sample_logs)
        # NOTE(review): the threshold uses the per-row standard deviation, not
        # the standard error of the sample mean, so only large shifts are flagged.
        if mean_diff > 2 * std_feature_training_data:
            data_drift_status[feature] = ["Data Drift Detected! Check Logs!"]
        else:
            data_drift_status[feature] = ["No Data Drift!"]

    # Categorical features
    for feature in categorical_features:
        # value_counts() orders categories by frequency, so the two sides must
        # be aligned by category label before comparison. A category missing
        # from one side gets proportion 0. Labels are compared as strings
        # because the training column and the logged column may have different
        # dtypes (e.g. categorical vs. object).
        training_proportions = Xtrain[feature].astype(str).value_counts(normalize=True)
        live_proportions = sample_df[feature].astype(str).value_counts(normalize=True)
        categories = training_proportions.index.union(live_proportions.index)
        psi_value = psi(
            live_proportions.reindex(categories, fill_value=0).to_numpy(),
            training_proportions.reindex(categories, fill_value=0).to_numpy(),
        )
        if psi_value > 0.1:
            data_drift_status[feature] = ["Data Drift Detected! Check Logs!"]
        else:
            data_drift_status[feature] = ["No Data Drift!"]

    return pd.DataFrame.from_dict(data_drift_status)
# Dashboard layout. Each component below receives a callable as its value:
# Gradio calls it when the page loads and then again every `every` seconds
# while the client connection is open.
with gr.Blocks(theme=gr.themes.Base()) as demo:
    gr.Markdown("# Real-time Monitoring Dashboard")

    gr.Markdown("## Model drift detection (every 5 seconds)")
    with gr.Row():
        with gr.Column():
            gr.Textbox(check_model_drift, every=5, label="Model Drift Status")

    gr.Markdown("## Distribution of Training Targets")
    with gr.Row():
        with gr.Column():
            # Refreshed once a day (86400 seconds).
            gr.Plot(plot_target_distributions, every=86400, label="Target Data Distributions")

    gr.Markdown("## Data drift detection (every 5 seconds)")
    with gr.Row():
        with gr.Column():
            gr.DataFrame(check_data_drift, every=5, min_width=240, label="Data Drift Status")

# Launch only when run as a script, so importing this module (e.g. by tooling
# or Gradio's reload mode) does not start a server as a side effect.
if __name__ == "__main__":
    # Enable the request queue, which the periodic `every=` updates run through.
    demo.queue().launch()