Benjamin Bossan commited on
Commit
0598e08
1 Parent(s): d747363

A simple logistic regression model

Browse files

- Update README.md (incl. model card)
- Add training script
- Add model artifact

Files changed (5) hide show
  1. .gitattributes +2 -0
  2. README.md +31 -0
  3. model.pickle +3 -0
  4. requirements.txt +1 -0
  5. train.py +65 -0
.gitattributes CHANGED
@@ -13,6 +13,8 @@
13
  *.ot filter=lfs diff=lfs merge=lfs -text
14
  *.parquet filter=lfs diff=lfs merge=lfs -text
15
  *.pb filter=lfs diff=lfs merge=lfs -text
 
 
16
  *.pt filter=lfs diff=lfs merge=lfs -text
17
  *.pth filter=lfs diff=lfs merge=lfs -text
18
  *.rar filter=lfs diff=lfs merge=lfs -text
 
13
  *.ot filter=lfs diff=lfs merge=lfs -text
14
  *.parquet filter=lfs diff=lfs merge=lfs -text
15
  *.pb filter=lfs diff=lfs merge=lfs -text
16
+ *.pickle filter=lfs diff=lfs merge=lfs -text
17
+ *.pkl filter=lfs diff=lfs merge=lfs -text
18
  *.pt filter=lfs diff=lfs merge=lfs -text
19
  *.pth filter=lfs diff=lfs merge=lfs -text
20
  *.rar filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,34 @@
1
  ---
2
  license: bsd-3-clause
 
 
 
 
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: bsd-3-clause
3
+ tags:
4
+ - sklearn
5
+ datasets:
6
+ - synthetic dataset from sklearn
7
+ metrics:
8
+ - type: accuracy
9
+ value: 0.948
10
  ---
11
+
12
+ # Simple example using plain scikit-learn
13
+
14
+ ## Reproducing the model
15
+
16
+ Inside a Python environment, install the dependencies listed in `requirements.txt` and then run:
17
+
18
+ ``` bash
19
+ python train.py
20
+ ```
21
+
22
+ The resulting model artifact should be stored in `model.pickle`.
23
+
24
+ ## The model
25
+
26
+ The used model is a simple logistic regression trained through gradient descent.
27
+
28
+ ## Intended use & limitations
29
+
30
+ This model is just for demonstration purposes and should thus not be used.
31
+
32
+ ## Dataset
33
+
34
+ The dataset is entirely synthetic and has no real world origin.
model.pickle ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:49024e6163c30049244412395379a7189646f0080a9368d2c92f7ef6cfb3041e
3
+ size 1112
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ scikit-learn==1.0.1
train.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Script to create the model artifact
2
+
3
+ Trains a simple logistic regression with grid search on a synthetic dataset and
4
+ stores the model in a pickle file.
5
+
6
+ """
7
+
8
+ import pickle
9
+
10
+ from sklearn.datasets import make_classification
11
+ from sklearn.linear_model import SGDClassifier
12
+ from sklearn.model_selection import GridSearchCV
13
+
14
+
15
+ SEED = 0
16
+
17
+
18
+ def get_data():
19
+ X, y = make_classification(n_samples=1000, random_state=SEED)
20
+ return X, y
21
+
22
+
23
+ def get_model(**kwargs):
24
+ model = SGDClassifier(random_state=SEED)
25
+ model.set_params(**kwargs)
26
+ return model
27
+
28
+
29
+ def get_hparams():
30
+ hparams = {
31
+ 'penalty': ['l1', 'l2'],
32
+ 'alpha': [0.00001, 0.0001, 0.001],
33
+ }
34
+ return hparams
35
+
36
+
37
+ def grid_search(model, X, y, hparams):
38
+ search = GridSearchCV(model, hparams, cv=5, scoring='accuracy')
39
+ search.fit(X, y)
40
+ return search
41
+
42
+
43
+ def train(model, X, y, hparams):
44
+ search = grid_search(model, X, y, hparams=hparams)
45
+ print(f"Best accuracy: {100 * search.best_score_:.1f}%")
46
+ print(f"Best parameters: {search.best_params_}")
47
+ return search.best_estimator_
48
+
49
+
50
+ def save_model(model, filename):
51
+ with open(filename, 'wb') as f:
52
+ pickle.dump(model, f)
53
+ print(f"Stored model in '{filename}'")
54
+
55
+
56
+ def main():
57
+ X, y = get_data()
58
+ model = get_model()
59
+ hparams = get_hparams()
60
+ model_trained = train(model, X, y, hparams=hparams)
61
+ save_model(model_trained, 'model.pickle')
62
+
63
+
64
+ if __name__ == '__main__':
65
+ main()