pgurazada1's picture
Update app.py
166dae4 verified
raw
history blame contribute delete
No virus
3.35 kB
import os
import uuid
import joblib
import json
import gradio as gr
import pandas as pd
from huggingface_hub import CommitScheduler
from pathlib import Path
# Run the training script placed in the same directory as app.py.
# The training script trains and persists a logistic regression
# model with the filename 'model.joblib'.
#
# os.system returns the command's exit status. A non-zero value means the
# training run did not finish cleanly, so 'model.joblib' is either missing or
# left over from an earlier run. Fail loudly here instead of silently serving
# a model that is not the freshly trained one.
training_status = os.system("python train.py")
if training_status != 0:
    raise RuntimeError(
        f"train.py failed (os.system returned {training_status}); "
        "refusing to load 'model.joblib'"
    )

# Load the freshly trained model from disk
machine_failure_predictor = joblib.load('model.joblib')
# Prepare the logging functionality.
# Every start of the app writes to its own uniquely named file inside logs/;
# predict_machine_failure appends one JSON record per line to it.
log_folder = Path("logs/")
log_file = log_folder / f"data_{uuid.uuid4()}.json"

# Background scheduler that uploads the contents of the log folder to the
# "data" directory of the dataset repo at a fixed interval
# (NOTE(review): `every` is in minutes per the huggingface_hub docs - confirm).
scheduler = CommitScheduler(
    repo_id="machine-failure-mlops-demo-logs",
    repo_type="dataset",
    path_in_repo="data",
    folder_path=log_folder,
    every=2,
)
def predict_machine_failure(air_temperature, process_temperature, rotational_speed, torque, tool_wear, type):
    """Predict machine failure for one reading and log the request.

    Runs when 'Submit' is clicked in the UI or when an API request is made.
    The six inputs are assembled into a single-row DataFrame (column names
    are the dictionary keys below) and passed to ``machine_failure_predictor``.
    The inputs together with the prediction are then appended to ``log_file``
    as one JSON object followed by a newline.

    Returns:
        The first element of the model's prediction list.
    """
    features = {
        'Air temperature [K]': air_temperature,
        'Process temperature [K]': process_temperature,
        'Rotational speed [rpm]': rotational_speed,
        'Torque [Nm]': torque,
        'Tool wear [min]': tool_wear,
        'Type': type,
    }

    frame = pd.DataFrame([features])
    predictions = machine_failure_predictor.predict(frame).tolist()

    # Hold the commit scheduler's lock while appending so the log file is not
    # written to while the scheduler is accessing the folder.
    with scheduler.lock, log_file.open("a") as log_handle:
        # The record is the input features plus the prediction, in that order.
        record = {**features, 'prediction': predictions[0]}
        log_handle.write(json.dumps(record) + "\n")

    return predictions[0]
# Input widgets, declared in the same order as the parameters of
# predict_machine_failure; each label matches the feature name used there.
air_temperature_input = gr.Number(label='Air temperature [K]')
process_temperature_input = gr.Number(label='Process temperature [K]')
rotational_speed_input = gr.Number(label='Rotational speed [rpm]')
torque_input = gr.Number(label='Torque [Nm]')
tool_wear_input = gr.Number(label='Tool wear [min]')
type_input = gr.Dropdown(choices=['L', 'M', 'H'], label='Type')

# Output widget that displays the value returned by the predict function
model_output = gr.Label(label="Machine failure")
# Wire the predict function, the input widgets and the output widget together
demo = gr.Interface(
    fn=predict_machine_failure,
    inputs=[
        air_temperature_input,
        process_temperature_input,
        rotational_speed_input,
        torque_input,
        tool_wear_input,
        type_input,
    ],
    outputs=model_output,
    title="Machine Failure Predictor",
    description="This API allows you to predict the machine failure status of an equipment",
    theme=gr.themes.Base(),
    # Each example row follows the order of `inputs` above
    examples=[
        [300.8, 310.3, 1538, 36.1, 198, 'L'],
        [296.3, 307.3, 1368, 49.5, 10, 'M'],
        [298.6, 309.1, 1339, 51.1, 34, 'M'],
        [302.4, 311.1, 1634, 34.2, 184, 'L'],
        [297.9, 307.7, 1546, 37.6, 72, 'L'],
    ],
    concurrency_limit=32,
)

# Enable request queuing, then start the server without a public share link
demo.queue().launch(share=False)