merve (HF staff) committed on
Commit
a487597
1 Parent(s): 0b2f5a5

Upload train.py

Files changed (1):
  train.py  +189 -0
train.py ADDED
@@ -0,0 +1,189 @@
# %% Import the dependencies we need
import numpy as np
import torch
from sklearn.datasets import fetch_20newsgroups
from sklearn.metrics import (accuracy_score, f1_score, confusion_matrix,
                             ConfusionMatrixDisplay, classification_report)
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from skops import card, hub_utils
from skorch import NeuralNetClassifier
from skorch.callbacks import LRScheduler, ProgressBar
from skorch.hf import HuggingfacePretrainedTokenizer
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from transformers import AutoModelForSequenceClassification

# for model hosting and requirements
from pathlib import Path
import transformers
import skorch
import sklearn

# %%
# Choose a tokenizer and a BERT model that work together
TOKENIZER = "distilbert-base-uncased"
PRETRAINED_MODEL = "distilbert-base-uncased"

# Model hyperparameters
OPTIMIZER = torch.optim.AdamW
LR = 5e-5
MAX_EPOCHS = 3
CRITERION = nn.CrossEntropyLoss
BATCH_SIZE = 8

# Device: use the GPU if one is available
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# %% Load the dataset, define features & labels, and split
dataset = fetch_20newsgroups()

print(dataset.DESCR.split('Usage')[0])

dataset.target_names

X = dataset.data
y = dataset.target
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)
num_training_steps = MAX_EPOCHS * (len(X_train) // BATCH_SIZE + 1)

# %%
# Define the learning rate schedule and wrap BERT in an nn.Module

def lr_schedule(current_step):
    # Linear decay from 1.0 at step 0 toward 0 at the end of training
    factor = float(num_training_steps - current_step) / float(max(1, num_training_steps))
    assert factor > 0
    return factor

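# %% (Optional) Sanity-check the schedule. A minimal sketch, not part of the
# original training flow: the factor should decay linearly from 1.0 at step 0
# toward 0 at the final step.
for step in (0, num_training_steps // 2, num_training_steps - 1):
    print(step, round(lr_schedule(step), 4))
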
class BertModule(nn.Module):
    def __init__(self, name, num_labels):
        super().__init__()
        self.name = name
        self.num_labels = num_labels

        self.reset_weights()

    def reset_weights(self):
        # Load the pretrained weights; calling this again restores the initial weights
        self.bert = AutoModelForSequenceClassification.from_pretrained(
            self.name, num_labels=self.num_labels
        )

    def forward(self, **kwargs):
        # The tokenizer step supplies input_ids, attention_mask, etc. as keyword arguments
        pred = self.bert(**kwargs)
        return pred.logits

# %% Chain the tokenizer and BERT into one pipeline
pipeline = Pipeline([
    ('tokenizer', HuggingfacePretrainedTokenizer(TOKENIZER)),
    ('net', NeuralNetClassifier(
        BertModule,
        module__name=PRETRAINED_MODEL,
        module__num_labels=len(set(y_train)),
        optimizer=OPTIMIZER,
        lr=LR,
        max_epochs=MAX_EPOCHS,
        criterion=CRITERION,
        batch_size=BATCH_SIZE,
        iterator_train__shuffle=True,
        device=DEVICE,
        callbacks=[
            LRScheduler(LambdaLR, lr_lambda=lr_schedule, step_every='batch'),
            ProgressBar(),
        ],
    )),
])

# Seed everything for reproducibility
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)
np.random.seed(0)

# %% Training
%time pipeline.fit(X_train, y_train)

# %% Evaluate the model
%%time
with torch.inference_mode():
    y_pred = pipeline.predict(X_test)

accuracy_score(y_test, y_pred)

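# %% (Optional) Inspect a few predictions. A minimal sketch: the predicted
# class indices map back to human-readable labels via dataset.target_names.
for text, pred in zip(X_test[:3], y_pred[:3]):
    print(dataset.target_names[pred], "<-", text[:60].replace("\n", " "))
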
# %% Save the model
import pickle
with open("model.pkl", mode="wb") as f:
    pickle.dump(pipeline, file=f)

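# %% (Optional) Verify the pickle round-trips. A minimal sketch, assuming the
# same session (the pickle references the BertModule class defined above):
with open("model.pkl", mode="rb") as f:
    loaded_pipeline = pickle.load(f)
print(loaded_pipeline.predict(X_test[:1]))
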
# %% Initialize the repository for the Hub
local_repo = "model_repo"
hub_utils.init(
    model="model.pkl",
    requirements=[
        f"scikit-learn={sklearn.__version__}",
        f"transformers={transformers.__version__}",
        f"torch={torch.__version__}",
        f"skorch={skorch.__version__}",
    ],
    dst=local_repo,
    task="text-classification",
    data=X_test,
)

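# %% (Optional) Peek at what hub_utils.init created. A minimal sketch: the
# local repo folder should now contain the pickle plus skops configuration.
print(sorted(p.name for p in Path(local_repo).iterdir()))
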
# %% Create the model card
model_card = card.Card(pipeline, metadata=card.metadata_from_config(Path("model_repo")))

# %% Add information about the model
model_description = (
    "This is a neural net classifier and a DistilBERT model chained in an sklearn "
    "Pipeline, trained on the 20 newsgroups dataset."
)
limitations = "This model is trained for a tutorial and is not ready to be used in production."
model_card.add(
    model_description=model_description,
    limitations=limitations,
)

# %% We can add plots, evaluation results, and more!
eval_descr = (
    "The model is evaluated on a held-out split of the 20 newsgroups dataset,"
    " using accuracy and F1-score with micro averaging."
)
model_card.add(eval_method=eval_descr)

accuracy = accuracy_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred, average="micro")
model_card.add_metrics(**{"accuracy": accuracy, "f1 score": f1})

# Plot the confusion matrix and save it for the model card
cm = confusion_matrix(y_test, y_pred, labels=pipeline.classes_)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=pipeline.classes_)
disp.plot()

disp.figure_.savefig(Path(local_repo) / "confusion_matrix.png")
model_card.add_plot(**{"Confusion matrix": "confusion_matrix.png"})

# %% We can add the classification report as a table
clf_report = classification_report(
    y_test, y_pred, output_dict=True, target_names=dataset.target_names
)
# Convert the report to a DataFrame first, since add_table expects tabular data
import pandas as pd
del clf_report["accuracy"]
clf_report = pd.DataFrame(clf_report).T.reset_index()
model_card.add_table(
    folded=True,
    **{
        "Classification Report": clf_report,
    },
)

# %% Save the model card
model_card.save(Path(local_repo) / "README.md")

# %% Add the training script to the repository
hub_utils.add_files(__file__, dst=local_repo)

# %% Push to the Hub! This requires us to authenticate first.
from huggingface_hub import notebook_login
notebook_login()

hub_utils.push(
    repo_id="scikit-learn/skorch-text-classification",
    source=local_repo,
    create_remote=True,
)
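
# %% (Optional) Consume the pushed model. A minimal sketch, assuming the repo
# is public, the pinned requirements are installed, and BertModule is
# importable when unpickling: download the pickle from the Hub and predict.
from huggingface_hub import hf_hub_download

path = hf_hub_download(
    repo_id="scikit-learn/skorch-text-classification", filename="model.pkl"
)
with open(path, mode="rb") as f:
    remote_pipeline = pickle.load(f)
print(remote_pipeline.predict(["The GPU in my new workstation is quite fast."]))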