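# NOTE: the lines below are file-viewer page residue captured with the source,
# kept as comments so the module parses; they are not part of the program.
# profitboost / prediction.py
# 7sugiwa's picture
# Upload 4 files
# 5a56710 verified
# raw
# history blame
# No virus
# 2.2 kB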
import joblib
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.preprocessing import OneHotEncoder
from sklearn.cluster import KMeans
# Custom Transformer: UnitPriceTransformer
class UnitPriceTransformer(BaseEstimator, TransformerMixin):
    """Derive a single ``unit_price`` feature from ``sales`` and ``quantity``.

    The transformer is stateless: ``fit`` learns nothing and ``transform``
    only reads the two source columns.
    """

    def fit(self, X, y=None):
        """Return ``self`` unchanged; neither ``X`` nor ``y`` is inspected."""
        return self

    def transform(self, X):
        """Return a one-column frame holding ``sales / quantity`` per row.

        ``X`` must support lookup of the ``'sales'`` and ``'quantity'``
        columns; it is not modified. The output keeps the row index of the
        division result and contains only the ``'unit_price'`` column.

        The division is not guarded, so a zero or missing ``quantity``
        propagates whatever pandas produces for that row.
        """
        price_per_unit = X['sales'] / X['quantity']
        return price_per_unit.to_frame(name='unit_price')
# Custom Transformer: KMeansAndLabelTransformer
class KMeansAndLabelTransformer(BaseEstimator, TransformerMixin):
    """Cluster rows by ``unit_price`` and tag each row with cluster + sub-category.

    ``fit`` trains a ``KMeans`` model on the ``'unit_price'`` column.
    ``transform`` predicts a cluster for every row and joins it with the row's
    ``'sub_category'`` value as ``"<cluster>_<sub_category>"``.

    Parameters
    ----------
    n_clusters : int, default=3
        Number of clusters passed to the underlying ``KMeans``.
    """

    def __init__(self, n_clusters=3):
        self.n_clusters = n_clusters
        # The attribute name ``kmeans`` is kept as-is so that previously
        # pickled instances (which store it in their state) keep working.
        self.kmeans = KMeans(n_clusters=n_clusters, random_state=42)

    def fit(self, X, y=None):
        """Fit the clustering model on ``X[['unit_price']]`` and return ``self``."""
        # ``n_clusters`` may have been changed after construction (for example
        # through ``set_params``); push the current value into the inner model
        # so the change is not silently ignored.
        self.kmeans.set_params(n_clusters=self.n_clusters)
        self.kmeans.fit(X[['unit_price']])
        return self

    def transform(self, X):
        """Return a one-column frame of ``"<cluster>_<sub_category>"`` labels.

        ``X`` must provide the ``'unit_price'`` and ``'sub_category'`` columns
        and is not modified. The output has a single column,
        ``'distinct_cluster_label'``, indexed like ``X``.
        """
        cluster_ids = self.kmeans.predict(X[['unit_price']])
        # Do the string concatenation in pandas rather than on a NumPy unicode
        # array: adding a Python str to such an array is not supported by every
        # NumPy release, whereas Series concatenation always is.
        cluster_text = pd.Series(cluster_ids, index=X.index).astype(str)
        labels = cluster_text + "_" + X['sub_category']
        return labels.to_frame(name='distinct_cluster_label')
# Custom Transformer: DynamicOneHotEncoder
class DynamicOneHotEncoder(BaseEstimator, TransformerMixin):
    """One-hot encode the ``'distinct_cluster_label'`` column into a DataFrame.

    Wraps a ``OneHotEncoder`` created with ``handle_unknown='ignore'`` and
    returns the encoded matrix as a dense, column-named frame.
    """

    def __init__(self):
        self.encoder = OneHotEncoder(handle_unknown='ignore')

    def fit(self, X, y=None):
        """Fit the wrapped encoder on ``X[['distinct_cluster_label']]``; return ``self``."""
        label_frame = X[['distinct_cluster_label']]
        self.encoder.fit(label_frame)
        return self

    def transform(self, X):
        """Return the dense one-hot encoding of ``X[['distinct_cluster_label']]``.

        Column names come from the encoder's ``get_feature_names_out`` with
        ``'distinct_cluster_label'`` as the input feature name. The returned
        frame is built without an explicit index, so it carries a fresh
        default index rather than the index of ``X``.
        """
        label_frame = X[['distinct_cluster_label']]
        dense_matrix = self.encoder.transform(label_frame).toarray()
        column_names = self.encoder.get_feature_names_out(['distinct_cluster_label'])
        return pd.DataFrame(dense_matrix, columns=column_names)
# Load the pipeline and model
# Both paths are relative, so they are resolved against the current working
# directory at import time. The load happens once, when this module is imported.
# NOTE(review): joblib.load unpickles its input; only point these at trusted
# artifact files. Presumably the pickled pipeline refers to the transformer
# classes defined above, which is why they must exist before loading - confirm.
_PIPELINE_FILE = 'full_pipeline_with_unit_price.pkl'
_MODEL_FILE = 'best_model.pkl'
pipeline = joblib.load(_PIPELINE_FILE)
model = joblib.load(_MODEL_FILE)
def make_prediction(input_features):
    """Return the model's prediction for a single input record.

    ``input_features`` is wrapped in a one-element list and turned into a
    one-row ``DataFrame`` (presumably a mapping of column name to value -
    confirm against callers). That row is run through the module-level
    ``pipeline`` and then through the module-level ``model``; the first
    element of the model's output is returned.
    """
    single_row = pd.DataFrame([input_features])
    model_input = pipeline.transform(single_row)
    return model.predict(model_input)[0]