epalvarez commited on
Commit
a86b5b2
1 Parent(s): 04e753c

First upload of the app.py without the scheduler

Browse files

The scheduler code has been commented out on the app.py file for this initial version.

Files changed (3) hide show
  1. app.py +87 -0
  2. model_mf.joblib +3 -0
  3. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # +++
2
+ import os
3
+ import uuid
4
+ import joblib
5
+ import json
6
+
7
+ # IMPORTANT: I already installed the package "gradio" in my current Virtual Environment (VEnvDSDIL_gpu_Py3.12) as: pip install -q gradio_client
8
+ # Do NOT install "gradio_client" package again in Anaconda otherwise it will mess up the package.
9
+ import gradio as gr
10
+ import pandas as pd
11
+
12
+ # must install the package "huggingface_hub" first in the current python Virtual Environment, with pip, not with conda, as follows
13
+ # pip install huggingface_hub
14
+ # i.e., in the command line interface within the activated Virtual Environment:
15
+ # (VEnvDSDIL_gpu_Py3.12) epalvarez@DSDILmStation01:~ $ pip install huggingface_hub
16
+ from huggingface_hub import CommitScheduler
17
+ from pathlib import Path
18
+
19
+ log_file = Path("logs/") / f"data_{uuid.uuid4()}.json"
20
+ log_folder = log_file.parent
21
+
22
+ # Scheduler will log every 2 API calls:
23
+ # scheduler = CommitScheduler(
24
+ # repo_id="machine-failure-logs",
25
+ # repo_type="dataset",
26
+ # folder_path=log_folder,
27
+ # path_in_repo="data",
28
+ # every=2
29
+ # )
30
+
31
+ machine_failure_predictor = joblib.load('model_mf.joblib')
32
+
33
+ air_temperature_input = gr.Number(label='Air temperature [K]')
34
+ process_temperature_input = gr.Number(label='Process temperature [K]')
35
+ rotational_speed_input = gr.Number(label='Rotational speed [rpm]')
36
+ torque_input = gr.Number(label='Torque [Nm]')
37
+ tool_wear_input = gr.Number(label='Tool wear [min]')
38
+ type_input = gr.Dropdown(
39
+ ['L', 'M', 'H'],
40
+ label='Type'
41
+ )
42
+
43
+ model_output = gr.Label(label="Machine failure")
44
+
45
+ def predict_machine_failure(air_temperature, process_temperature, rotational_speed, torque, tool_wear, type):
46
+ sample = {
47
+ 'Air temperature [K]': air_temperature,
48
+ 'Process temperature [K]': process_temperature,
49
+ 'Rotational speed [rpm]': rotational_speed,
50
+ 'Torque [Nm]': torque,
51
+ 'Tool wear [min]': tool_wear,
52
+ 'Type': type
53
+ }
54
+ data_point = pd.DataFrame([sample])
55
+ prediction = machine_failure_predictor.predict(data_point).tolist()
56
+
57
+ # Each time we get a prediction we will determine if we should log it to a hugging_face dataset according to the schedule definition outside this function
58
+ # with scheduler.lock:
59
+ # with log_file.open("a") as f:
60
+ # f.write(json.dumps(
61
+ # {
62
+ # 'Air temperature [K]': air_temperature,
63
+ # 'Process temperature [K]': process_temperature,
64
+ # 'Rotational speed [rpm]': rotational_speed,
65
+ # 'Torque [Nm]': torque,
66
+ # 'Tool wear [min]': tool_wear,
67
+ # 'Type': type,
68
+ # 'prediction': prediction[0]
69
+ # }
70
+ # ))
71
+ # f.write("\n")
72
+
73
+ return prediction[0]
74
+
75
+ demo = gr.Interface(
76
+ fn=predict_machine_failure,
77
+ inputs=[air_temperature_input, process_temperature_input, rotational_speed_input,
78
+ torque_input, tool_wear_input, type_input],
79
+ outputs=model_output,
80
+ title="Machine Failure Predictor",
81
+ description="This API allows you to predict the machine failure status of an equipment",
82
+ allow_flagging="auto",
83
+ concurrency_limit=8
84
+ )
85
+
86
+ demo.queue()
87
+ demo.launch(share=False)
model_mf.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e8745b4d4e2c0da514f0edb99c23932b1a12b6af2b97fbcf517af800f4ad5088
3
+ size 4238
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ #scikit-learn==1.2.2
2
+ scikit-learn==1.5.0
3
+ joblib==1.4.0
4
+ pandas==2.2.2
5
+ numpy==2.0.0