pgurazada1's picture
fdd11f5 verified
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import gradio as gr
from datasets import load_dataset
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
# Hugging Face dataset repository from which get_data() reads the logged
# model inputs/predictions.
LOGS_DATASET_URI = 'pgurazada1/machine-failure-mlops-demo-logs'
# Load and cache training data
# Fetches OpenML dataset id 42890 as pandas objects (as_frame=True).
dataset = fetch_openml(data_id=42890, as_frame=True, parser="auto")
# NOTE(review): the right-hand side of this assignment is missing in this copy
# of the file. It must be a DataFrame containing both the feature columns and
# the `target` column (presumably `dataset.frame`) — confirm against the
# original source.
data_df =
# Name of the label column in data_df.
target = 'Machine failure'
# Numeric input feature columns.
numeric_features = [
'Air temperature [K]',
'Process temperature [K]',
'Rotational speed [rpm]',
'Torque [Nm]',
'Tool wear [min]'
# NOTE(review): the closing `]` of numeric_features is missing in this copy.
# Categorical input feature columns.
categorical_features = ['Type']
X = data_df[numeric_features + categorical_features]
y = data_df[target]
# Split off a training portion; Xtrain / ytrain serve as the reference data
# for the drift checks and target plot in this file.
# NOTE(review): the remaining train_test_split arguments (e.g. test_size,
# random_state) and the closing parenthesis are missing in this copy — confirm
# against the original source before running.
Xtrain, Xtest, ytrain, ytest = train_test_split(
X, y,
def get_data(n_samples=100):
    """
    Connect to the HuggingFace dataset where the logs are stored and pull a
    random sample of the logged rows into a dataframe.

    Parameters
    ----------
    n_samples : int, default 100
        Maximum number of log rows to sample. If the log holds fewer rows,
        all of them are returned (in shuffled order).

    Returns
    -------
    pandas.DataFrame
        Sampled rows of the ``train`` split of ``LOGS_DATASET_URI``.
    """
    data = load_dataset(LOGS_DATASET_URI)
    logs_df = data['train'].to_pandas()
    # DataFrame.sample samples without replacement and raises ValueError when
    # asked for more rows than exist (e.g. right after deployment), so cap
    # the request at the size of the log.
    return logs_df.sample(min(n_samples, len(logs_df)))
def check_model_drift(sample_df=None):
    """
    Check the proportion of predicted machine failures in the logs against
    its proportion in the training data. If the deviation is more than
    2 standard errors, flag a model drift.

    Parameters
    ----------
    sample_df : pandas.DataFrame, optional
        Logged rows with a ``prediction`` column of 0/1 labels. When omitted,
        a fresh sample is pulled with ``get_data()``.

    Returns
    -------
    str
        "Model Drift Detected! Check Logs!", "No Model Drift!", or
        "No predictions logged yet!" when the sample is empty.
    """
    if sample_df is None:
        sample_df = get_data()

    p_pos_label_training_data = 0.03475
    training_data_size = 8000

    # value_counts() omits labels that never occur, so look each label up
    # with a default of 0 rather than indexing (which raises KeyError and,
    # in the old try/except form, left n_0 unbound).
    counts = sample_df.prediction.value_counts()
    n_0 = int(counts.get(0, 0))
    n_1 = int(counts.get(1, 0))
    sample_size = n_0 + n_1
    if sample_size == 0:
        return "No predictions logged yet!"

    p_pos_label_sample_logs = n_1 / sample_size

    # Standard error of the difference between two independent proportions,
    # assuming both equal the training proportion. The sample term dominates:
    # using the training size alone gives a 2-sigma band of ~0.004, narrower
    # than the 0.01 granularity of a 100-row sample, so every check would
    # report drift.
    variance = (
        p_pos_label_training_data
        * (1 - p_pos_label_training_data)
        * (1 / training_data_size + 1 / sample_size)
    )
    p_diff = abs(p_pos_label_training_data - p_pos_label_sample_logs)
    if p_diff > 2 * math.sqrt(variance):
        return "Model Drift Detected! Check Logs!"
    return "No Model Drift!"
def plot_target_distributions():
    """
    Draw the class balance of the training targets above that of the
    predictions logged by the deployed model.

    Returns
    -------
    matplotlib.figure.Figure
        A 2x1 figure of count plots shown as proportions: training targets
        on top, logged predictions below.
    """
    logged = get_data()
    fig, (train_ax, live_ax) = plt.subplots(2, 1, figsize=(9, 7))

    # Top panel: reference distribution from the training split.
    sns.countplot(x=ytrain, stat='proportion', ax=train_ax)
    train_ax.set_title("Distribution of targets in training data")

    # Bottom panel: what the deployed model has been predicting.
    sns.countplot(x=logged.prediction, stat='proportion', ax=live_ax)
    live_ax.set_title("Distribution of predicted targets from the deployed model")

    return fig
def psi(actual_proportions, expected_proportions, epsilon=1e-4):
    """
    Compute the Population Stability Index between two binned distributions.

    PSI = sum over bins of (actual - expected) * ln(actual / expected).

    Parameters
    ----------
    actual_proportions, expected_proportions : array-like of float
        Bin proportions, aligned element-wise (same bins, same order).
    epsilon : float, default 1e-4
        Floor applied to every proportion before taking the log. Without it,
        an empty bin yields ``inf``, or ``nan`` if the bin is empty on both
        sides, and ``nan`` silently compares False against any threshold.

    Returns
    -------
    float
        The PSI value (0 for identical distributions, larger = more shift).
    """
    actual = np.clip(np.asarray(actual_proportions, dtype=float), epsilon, None)
    expected = np.clip(np.asarray(expected_proportions, dtype=float), epsilon, None)
    psi_values = (actual - expected) * np.log(actual / expected)
    return float(np.sum(psi_values))
def check_data_drift(sample_df=None):
    """
    Compare training data features and live (logged) features and flag drift.

    Numeric features and categorical features are dealt with separately:

    * numeric: flag drift when the mean of the logged sample is more than
      2 training-data standard deviations away from the training mean;
    * categorical: flag drift when the Population Stability Index between the
      logged and training category proportions exceeds 0.1.

    Parameters
    ----------
    sample_df : pandas.DataFrame, optional
        Logged feature rows. When omitted, a fresh sample is pulled with
        ``get_data()``.

    Returns
    -------
    pandas.DataFrame
        One row with one column per feature, each holding
        "Data Drift Detected! Check Logs!" or "No Data Drift!".
    """
    if sample_df is None:
        sample_df = get_data()

    data_drift_status = {}

    # Numeric features (module-level numeric_features list).
    for feature in numeric_features:
        mean_feature_training_data = Xtrain[feature].mean()
        std_feature_training_data = Xtrain[feature].std()
        mean_feature_sample_logs = sample_df[feature].mean()
        mean_diff = abs(mean_feature_training_data - mean_feature_sample_logs)
        if mean_diff > 2 * std_feature_training_data:
            data_drift_status[feature] = ["Data Drift Detected! Check Logs!"]
        else:
            data_drift_status[feature] = ["No Data Drift!"]

    # Categorical features (module-level categorical_features list).
    for feature in categorical_features:
        training_props = Xtrain[feature].value_counts(normalize=True).to_dict()
        live_props = sample_df[feature].value_counts(normalize=True).to_dict()
        # value_counts orders categories by frequency, so the raw `.values` of
        # the two series need not refer to the same categories position by
        # position, and their lengths differ when a category is absent from
        # the sample. Align both on category labels, treating an absent
        # category as proportion 0; skip categories empty on both sides.
        # NOTE(review): assumes the logs use the same category labels as the
        # training data — confirm against the logging code.
        categories = sorted(
            (
                c
                for c in set(training_props) | set(live_props)
                if training_props.get(c, 0.0) > 0 or live_props.get(c, 0.0) > 0
            ),
            key=str,
        )
        expected = np.array([training_props.get(c, 0.0) for c in categories])
        actual = np.array([live_props.get(c, 0.0) for c in categories])
        psi_value = psi(actual, expected)
        if psi_value > 0.1:
            data_drift_status[feature] = ["Data Drift Detected! Check Logs!"]
        else:
            data_drift_status[feature] = ["No Data Drift!"]

    return pd.DataFrame.from_dict(data_drift_status)
with gr.Blocks(theme=gr.themes.Base()) as demo:
    gr.Markdown("# Real-time Monitoring Dashboard")

    # Section 1: model drift status, recomputed by Gradio every 5 seconds.
    gr.Markdown("## Model drift detection (every 5 seconds)")
    with gr.Row():
        with gr.Column():
            gr.Textbox(
                value=check_model_drift,
                every=5,
                label="Model Drift Status",
            )

    # Section 2: target distribution plot, recomputed every 86400 seconds.
    gr.Markdown("## Distribution of Training Targets")
    with gr.Row():
        with gr.Column():
            gr.Plot(
                value=plot_target_distributions,
                every=86400,
                label="Target Data Distributions",
            )

    # Section 3: per-feature data drift table, recomputed every 5 seconds.
    gr.Markdown("## Data drift detection (every 5 seconds)")
    with gr.Row():
        with gr.Column():
            gr.DataFrame(
                value=check_data_drift,
                every=5,
                min_width=240,
                label="Data Drift Status",
            )