waleko committed on
Commit ff83063
1 Parent(s): 2b8c339

first version

app.py ADDED
@@ -0,0 +1,229 @@
+ import pickle
+ import warnings
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import pandas as pd
+ import shap
+ import streamlit as st
+ from catboost import CatBoostClassifier
+ from sklearn.base import BaseEstimator, TransformerMixin
+ from sklearn.pipeline import make_pipeline
+ from sklearn.preprocessing import OneHotEncoder
+
+ warnings.filterwarnings("ignore")
+
+
+ class CustomFeatureTransformer(BaseEstimator, TransformerMixin):
+     def __init__(self, verbose=False):
+         self.verbose = verbose
+         self.column_means_ = None
+
+     def fit(self, X, y=None):
+         X_copy = X.copy()
+
+         self.numerical_columns = list(X_copy.select_dtypes(include=np.number).columns)
+         self.categorical_columns = list(X_copy.select_dtypes(exclude=np.number).columns)
+         # Filter out categorical columns with > 100 unique values.
+         # Iterate over a copy: removing items from a list while iterating over it skips elements.
+         for col in list(self.categorical_columns):
+             if len(X_copy[col].unique()) > 100:
+                 self.categorical_columns.remove(col)
+                 if self.verbose:
+                     print(f'removed {col} with {len(X_copy[col].unique())} unique values')
+
+         # Store means for each numerical column (used to impute missing values in transform)
+         self.column_means_ = X_copy[self.numerical_columns].mean().fillna(0)
+         self.onehot_encoder = OneHotEncoder(sparse_output=False, handle_unknown='ignore')
+         self.onehot_encoder.fit(X_copy[self.categorical_columns])
+
+         return self
+
+     def transform(self, X):
+         X_copy = X.copy()
+         X_copy.reset_index(drop=True, inplace=True)
+         result_dfs = []
+
+         # Process each numerical column
+         for col in self.numerical_columns:
+             # Add an is-null indicator feature
+             is_null = X_copy[col].isna()
+             result_dfs.append(pd.DataFrame({
+                 f"{col}_is_null": is_null.astype(int)
+             }))
+
+             # Impute missing values with the column mean learned in fit()
+             filled_values = X_copy[col].fillna(self.column_means_[col])
+             result_dfs.append(pd.DataFrame({
+                 f"{col}_value": filled_values
+             }))
+
+         # Add categorical columns using one-hot encoding
+         result_dfs.append(pd.DataFrame(
+             self.onehot_encoder.transform(X_copy[self.categorical_columns]),
+             columns=self.onehot_encoder.get_feature_names_out()
+         ))
+
+         # Concatenate all transformed features
+         df = pd.concat(result_dfs, axis=1)
+         assert not df.isna().any().any()
+         return df
+
+
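As a reference for what CustomFeatureTransformer emits, a minimal sketch on a made-up toy frame (illustrative only, not part of the app; it assumes the class definition above):

    # Hypothetical toy input, for illustration only
    toy = pd.DataFrame({
        'pressure': [1.0, None, 3.0],   # numeric, with a missing value
        'station': ['A', 'B', 'A'],     # low-cardinality categorical
    })
    t = CustomFeatureTransformer().fit(toy)
    print(t.transform(toy).columns.tolist())
    # -> ['pressure_is_null', 'pressure_value', 'station_A', 'station_B']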
+ class DayNumberTransformer:
+     def __init__(self):
+         pass
+
+     def fit(self, X, y=None):
+         return self
+
+     def transform(self, X, y=None):
+         X = X.copy()
+         X['message_timestamp'] = pd.to_datetime(X['message_timestamp'])
+         # '%U %w': week-of-year (Sunday as first day) plus weekday digit, as one string
+         X['week_number'] = X['message_timestamp'].dt.strftime('%U %w')
+         return X
+
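As a quick check of the '%U %w' format used above (week-of-year with Sunday as the first day, then the weekday digit):

    # Monday 2024-01-01 falls in week 00; Monday is weekday 1
    print(pd.Timestamp('2024-01-01').strftime('%U %w'))  # -> '00 1'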
+ class WeatherTransformer:
+     def __init__(self, weather):
+         self.weather = weather
+         self.weather['date'] = pd.to_datetime(self.weather['date']).dt.tz_convert('Europe/Berlin')
+
+     def fit(self, X, y=None):
+         return self
+
+     def transform(self, X, y=None):
+         X = X.copy()
+
+         # Round timestamps to the hour so they line up with the hourly weather data
+         X['message_timestamp'] = pd.to_datetime(X['message_timestamp']).dt.tz_localize('Europe/Berlin')
+         X['message_timestamp'] = X['message_timestamp'].dt.round('h')
+
+         # Join weather data on message_timestamp == date
+         X = X.merge(self.weather, left_on='message_timestamp', right_on='date', how='left')
+         # Report rows that found no matching weather record
+         if X['temperature_2m'].isna().sum() > 0:
+             print("Number of rows without weather data: ", X['temperature_2m'].isna().sum())
+
+         return X
+
+ class TopFeaturesSelector:
+     def __init__(self, top_features):
+         self.top_features = top_features
+
+     def fit(self, X, y=None):
+         return self
+
+     def transform(self, X, y=None):
+         return X[self.top_features]
+
+
+
+ # Load weather data, the SHAP-ranked feature list, and the fitted artifacts
+ weather_file = 'hourly_data.csv'
+ shap_importance_file = 'shap_importance.csv'
+ weather = pd.read_csv(weather_file)
+ shap_importance_df = pd.read_csv(shap_importance_file)
+ top_features = shap_importance_df['Feature'].head(25).values
+ catboost = CatBoostClassifier().load_model('catboost_model.cbm')
+ with open('scaler.pkl', 'rb') as f:
+     scaler = pickle.load(f)
+ with open('customfeatureselector.pkl', 'rb') as f:
+     custom_feature_transformer = pickle.load(f)
+
+ # Define the sklearn pipeline: raw rows -> weather join -> date features
+ # -> imputation/one-hot -> top-25 selection -> scaling -> CatBoost
+ pipe = make_pipeline(
+     WeatherTransformer(weather),
+     DayNumberTransformer(),
+     custom_feature_transformer,
+     TopFeaturesSelector(top_features),
+     scaler,
+     catboost
+ )
+
+
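egor_plots below leans on sklearn's pipeline slicing; a short sketch of what the slices used there refer to (data_sample is a hypothetical stand-in for any raw input frame):

    features = pipe[:-2].transform(data_sample)  # weather join, date features, imputation/one-hot, top-25 selection
    X_scaled = pipe[-2].transform(features)      # the scaler loaded from scaler.pkl; returns a numpy array
    proba = pipe[-1].predict_proba(X_scaled)[:, 1]  # the CatBoost classifier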
+ def egor_plots(X_test, k=1000):
+     # Preprocess X_test: everything up to the scaler, capped at k rows
+     X_prescaled = pipe[:-2].transform(X_test)[:k]
+     X_test_preprocessed = pipe[-2].transform(X_prescaled)
+
+     # SHAP Analysis
+     st.write("SHAP Analysis... This may take a couple of minutes depending on the number of samples.")
+     explainer = shap.TreeExplainer(pipe[-1])
+     shap_values = explainer(X_test_preprocessed)
+     shap_values.feature_names = list(X_prescaled.columns)
+
+     # SHAP Summary Plot (summary_plot draws on the current figure and returns None)
+     st.write("### SHAP Summary Plot")
+     shap.summary_plot(shap_values, X_test_preprocessed, show=False)
+     st.pyplot(plt.gcf())
+
+     # SHAP Scatter Plots
+     st.write("### SHAP Scatter Plots")
+     for i in range(len(top_features)):
+         feature_name = top_features[i]
+         st.write(f"#### Scatter Plot for Feature: {feature_name}")
+         fig, ax = plt.subplots()
+         shap.plots.scatter(shap_values[:, i], color=X_test_preprocessed[:, i], show=False, ax=ax)
+         ax.axhline(y=0, color='r', linestyle='--')
+         ax.axvline(x=0, color='g', linestyle='--')
+         st.pyplot(fig)
+
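shap_importance.csv is read at startup for its 'Feature' column; one plausible way such a ranking could be produced offline is mean absolute SHAP value per feature (an assumption about how the file was generated, not confirmed by this commit; X_train_preprocessed and feature_names are hypothetical stand-ins for the training-time arrays):

    vals = shap.TreeExplainer(catboost)(X_train_preprocessed).values
    importance = pd.DataFrame({
        'Feature': feature_names,
        'Importance': np.abs(vals).mean(axis=0),
    }).sort_values('Importance', ascending=False)
    importance.to_csv('shap_importance.csv', index=False)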
+ # Streamlit App
+ st.title("BMW Hackathon Defect Detection")
+ st.write("### Upload your tabular data")
+
+ # File uploader
+ uploaded_file = st.file_uploader("Upload a CSV file", type=["csv"])
+ # Radio button for prediction type
+ prediction_type = st.radio(
+     "Select prediction type",
+     ["predict", "predict_proba"],
+     index=0
+ )
+ k = st.slider("Number of samples for SHAP plots", min_value=10, max_value=1000, value=100)
+
+ if uploaded_file:
+     # Load the uploaded file
+     data = pd.read_csv(uploaded_file)
+     st.write("Uploaded Data:")
+     st.write(data.head())
+
+     st.write("Predicting...")
+     if prediction_type == 'predict':
+         y_pred = pipe.predict(data)
+         # status 1 -> OK, 0 -> NOK
+         status = pd.Series(['OK' if pred == 1 else 'NOK' for pred in y_pred])
+     elif prediction_type == 'predict_proba':
+         status = pipe.predict_proba(data)[:, 1]
+     else:
+         raise ValueError(f"Invalid prediction type: {prediction_type}")
+     res = pd.DataFrame(
+         {"physical_part_id": data["physical_part_id"],
+          "status": status}
+     )
+     st.write("### Results")
+     st.write(res.head())
+     # Download the predictions as CSV
+     csv = res.to_csv(index=False)
+     st.download_button(
+         label="Download predictions as CSV",
+         data=csv,
+         file_name="predictions.csv",
+         mime="text/csv"
+     )
+     st.write("### SHAP plots")
+     # Pass the slider value through so it actually limits the SHAP sample size
+     egor_plots(data, k)
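In predict_proba mode the status column holds raw class-1 probabilities rather than OK/NOK labels; if labels at a custom operating point were wanted, a minimal sketch (the 0.5 threshold is illustrative, not tuned):

    threshold = 0.5  # illustrative; would be tuned on validation data
    proba = pipe.predict_proba(data)[:, 1]
    status = pd.Series(np.where(proba >= threshold, 'OK', 'NOK'))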
catboost_model.cbm ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ad04e1e3d1f47b2472afe968e4f9f5a6944766136c8a51e83d22e38d9d6fbbb5
+ size 32984696
customfeatureselector.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3cdb80823b17987678aeca72c82f1dcc6d96734ee3d90c6c18a29461f0c18094
+ size 18816
hourly_data.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:16dea43ddccacab82981233c1dddec016696c42d887ca55e6d8fddf52e10d524
+ size 160429
predictions.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:59e60207af9ea2f7c7ca343325b2d380f4a3f2cbbb844c55a5caf2ff412c231e
+ size 555823
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ streamlit
+ scikit-learn
+ pandas
+ numpy
+ catboost
+ shap
+ matplotlib
+ tqdm
scaler.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2807a66307ef299ad1e1818faf34916950f5d5c22c5924a4aee8575903c31d67
+ size 1916
shap_importance.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a7cf3a7acd2c10ae46196f67d25e680950c92ebacde9ed6a3802c9d3d2502ddc
+ size 43112