plain-sklearn / train.py
Benjamin Bossan
Update model
4982c2f
"""Script to create the model artifact.

Trains a simple linear classifier (scikit-learn's ``SGDClassifier``) with grid
search on a synthetic dataset and stores the best estimator in a joblib file.
"""
import joblib
from sklearn.datasets import make_classification
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import GridSearchCV
# Random seed passed as ``random_state`` to both the synthetic data generator
# and the classifier.
SEED = 0
# Path that main() hands to save_model() for the trained model.
FILENAME = 'sklearn_model.joblib'
def get_data():
    """Generate the synthetic classification dataset.

    Returns:
        A ``(X, y)`` tuple as produced by ``make_classification`` with 2000
        samples and ``random_state=SEED``.
    """
    features, targets = make_classification(
        n_samples=2000,
        random_state=SEED,
    )
    return features, targets
def get_model(**kwargs):
    """Build the untrained classifier.

    Args:
        **kwargs: Extra estimator parameters applied via ``set_params`` after
            construction.

    Returns:
        An ``SGDClassifier`` created with ``random_state=SEED`` and updated
        with ``kwargs``.
    """
    classifier = SGDClassifier(random_state=SEED)
    # set_params returns the estimator itself, so it can be returned directly.
    return classifier.set_params(**kwargs)
def get_hparams():
    """Return the hyper-parameter grid used for the grid search.

    Returns:
        A new dict mapping parameter names to the list of candidate values.
    """
    return dict(
        penalty=['l1', 'l2'],
        alpha=[0.00001, 0.0001, 0.001],
    )
def grid_search(model, X, y, hparams, *, cv=5, scoring='accuracy'):
    """Fit a cross-validated grid search over ``hparams`` and return it.

    Args:
        model: Estimator handed to ``GridSearchCV``.
        X: Training features passed to ``fit``.
        y: Training targets passed to ``fit``.
        hparams: Parameter grid handed to ``GridSearchCV``.
        cv: Cross-validation setting forwarded to ``GridSearchCV``. Defaults
            to 5, the value that was previously hard-coded.
        scoring: Scoring setting forwarded to ``GridSearchCV``. Defaults to
            ``'accuracy'``, the value that was previously hard-coded.

    Returns:
        The ``GridSearchCV`` instance after ``fit(X, y)`` has been called.
    """
    search = GridSearchCV(model, hparams, cv=cv, scoring=scoring)
    search.fit(X, y)
    return search
def train(model, X, y, hparams):
    """Run the grid search, report the outcome and return the best model.

    Args:
        model: Estimator to tune.
        X: Training features.
        y: Training targets.
        hparams: Parameter grid forwarded to ``grid_search``.

    Returns:
        The ``best_estimator_`` of the fitted search.
    """
    fitted = grid_search(model, X, y, hparams=hparams)
    # Report the winning score (as a percentage) and parameter combination.
    print(f"Best accuracy: {100 * fitted.best_score_:.1f}%")
    print(f"Best parameters: {fitted.best_params_}")
    return fitted.best_estimator_
def save_model(model, filename):
    """Write ``model`` to ``filename`` with joblib and print a confirmation.

    Args:
        model: Object to serialize.
        filename: Destination path passed to ``joblib.dump``.
    """
    joblib.dump(model, filename)
    confirmation = f"Stored model in '{filename}'"
    print(confirmation)
def main():
    """Create the data, tune the model and store the best estimator."""
    X, y = get_data()
    # get_model() is evaluated before get_hparams(), matching the original
    # call order.
    best_model = train(get_model(), X, y, hparams=get_hparams())
    save_model(best_model, FILENAME)
# Run the training pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()