lvwerra HF staff commited on
Commit
ef4e73a
1 Parent(s): b767fcc

Upload distilbert-imdb-training.ipynb

Browse files
Files changed (1) hide show
  1. distilbert-imdb-training.ipynb +292 -0
distilbert-imdb-training.ipynb ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Train IMDb Classifier\n",
8
+ "> Train a IMDb classifier with DistilBERT."
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": null,
14
+ "metadata": {},
15
+ "outputs": [],
16
+ "source": [
17
+ "!huggingface-cli login"
18
+ ]
19
+ },
20
+ {
21
+ "cell_type": "markdown",
22
+ "metadata": {},
23
+ "source": [
24
+ "## Load IMDb dataset"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": null,
30
+ "metadata": {},
31
+ "outputs": [],
32
+ "source": [
33
+ "from datasets import load_dataset, load_metric"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "code",
38
+ "execution_count": null,
39
+ "metadata": {},
40
+ "outputs": [],
41
+ "source": [
42
+ "ds = load_dataset(\"imdb\")"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": null,
48
+ "metadata": {},
49
+ "outputs": [
50
+ {
51
+ "data": {
52
+ "text/plain": [
53
+ "DatasetDict({\n",
54
+ " train: Dataset({\n",
55
+ " features: ['text', 'label'],\n",
56
+ " num_rows: 25000\n",
57
+ " })\n",
58
+ " test: Dataset({\n",
59
+ " features: ['text', 'label'],\n",
60
+ " num_rows: 25000\n",
61
+ " })\n",
62
+ " unsupervised: Dataset({\n",
63
+ " features: ['text', 'label'],\n",
64
+ " num_rows: 50000\n",
65
+ " })\n",
66
+ "})"
67
+ ]
68
+ },
69
+ "execution_count": null,
70
+ "metadata": {},
71
+ "output_type": "execute_result"
72
+ }
73
+ ],
74
+ "source": [
75
+ "ds"
76
+ ]
77
+ },
78
+ {
79
+ "cell_type": "code",
80
+ "execution_count": null,
81
+ "metadata": {},
82
+ "outputs": [
83
+ {
84
+ "data": {
85
+ "text/plain": [
86
+ "{'label': ClassLabel(num_classes=2, names=['neg', 'pos'], names_file=None, id=None),\n",
87
+ " 'text': Value(dtype='string', id=None)}"
88
+ ]
89
+ },
90
+ "execution_count": null,
91
+ "metadata": {},
92
+ "output_type": "execute_result"
93
+ }
94
+ ],
95
+ "source": [
96
+ "ds['train'].features"
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "markdown",
101
+ "metadata": {},
102
+ "source": [
103
+ "## Load Pretrained DistilBERT"
104
+ ]
105
+ },
106
+ {
107
+ "cell_type": "code",
108
+ "execution_count": null,
109
+ "metadata": {},
110
+ "outputs": [],
111
+ "source": [
112
+ "from transformers import AutoModelForSequenceClassification, AutoTokenizer\n",
113
+ "\n",
114
+ "model_name = \"distilbert-base-uncased\"\n",
115
+ "model = AutoModelForSequenceClassification.from_pretrained(model_name)\n",
116
+ "tokenizer = AutoTokenizer.from_pretrained(model_name)"
117
+ ]
118
+ },
119
+ {
120
+ "cell_type": "markdown",
121
+ "metadata": {},
122
+ "source": [
123
+ "## Prepocess Data"
124
+ ]
125
+ },
126
+ {
127
+ "cell_type": "code",
128
+ "execution_count": null,
129
+ "metadata": {},
130
+ "outputs": [
131
+ {
132
+ "data": {
133
+ "application/vnd.jupyter.widget-view+json": {
134
+ "model_id": "6ddef2e0d4a04e12ad7513950158236c",
135
+ "version_major": 2,
136
+ "version_minor": 0
137
+ },
138
+ "text/plain": [
139
+ " 0%| | 0/25 [00:00<?, ?ba/s]"
140
+ ]
141
+ },
142
+ "metadata": {},
143
+ "output_type": "display_data"
144
+ },
145
+ {
146
+ "data": {
147
+ "application/vnd.jupyter.widget-view+json": {
148
+ "model_id": "4b1392a042614a1682b6f62642262446",
149
+ "version_major": 2,
150
+ "version_minor": 0
151
+ },
152
+ "text/plain": [
153
+ " 0%| | 0/25 [00:00<?, ?ba/s]"
154
+ ]
155
+ },
156
+ "metadata": {},
157
+ "output_type": "display_data"
158
+ },
159
+ {
160
+ "data": {
161
+ "application/vnd.jupyter.widget-view+json": {
162
+ "model_id": "a7f130baafab4493bfe185fa7f3a9fe9",
163
+ "version_major": 2,
164
+ "version_minor": 0
165
+ },
166
+ "text/plain": [
167
+ " 0%| | 0/50 [00:00<?, ?ba/s]"
168
+ ]
169
+ },
170
+ "metadata": {},
171
+ "output_type": "display_data"
172
+ }
173
+ ],
174
+ "source": [
175
+ "def tokenize(examples):\n",
176
+ " outputs = tokenizer(examples['text'], truncation=True)\n",
177
+ " return outputs\n",
178
+ "\n",
179
+ "tokenized_ds = ds.map(tokenize, batched=True)"
180
+ ]
181
+ },
182
+ {
183
+ "cell_type": "markdown",
184
+ "metadata": {},
185
+ "source": [
186
+ "## Prepare Trainer"
187
+ ]
188
+ },
189
+ {
190
+ "cell_type": "code",
191
+ "execution_count": null,
192
+ "metadata": {},
193
+ "outputs": [],
194
+ "source": [
195
+ "from transformers import TrainingArguments, Trainer, DataCollatorWithPadding"
196
+ ]
197
+ },
198
+ {
199
+ "cell_type": "code",
200
+ "execution_count": null,
201
+ "metadata": {},
202
+ "outputs": [],
203
+ "source": [
204
+ "import numpy as np\n",
205
+ "\n",
206
+ "def compute_metrics(eval_preds):\n",
207
+ " metric = load_metric(\"accuracy\")\n",
208
+ " logits, labels = eval_preds\n",
209
+ " predictions = np.argmax(logits, axis=-1)\n",
210
+ " return metric.compute(predictions=predictions, references=labels)"
211
+ ]
212
+ },
213
+ {
214
+ "cell_type": "code",
215
+ "execution_count": null,
216
+ "metadata": {},
217
+ "outputs": [],
218
+ "source": [
219
+ "training_args = TrainingArguments(num_train_epochs=1,\n",
220
+ " output_dir=\"distilbert-imdb\",\n",
221
+ " push_to_hub=True,\n",
222
+ " per_device_train_batch_size=16,\n",
223
+ " per_device_eval_batch_size=16,\n",
224
+ " evaluation_strategy=\"epoch\")"
225
+ ]
226
+ },
227
+ {
228
+ "cell_type": "code",
229
+ "execution_count": null,
230
+ "metadata": {},
231
+ "outputs": [],
232
+ "source": [
233
+ "data_collator = DataCollatorWithPadding(tokenizer)"
234
+ ]
235
+ },
236
+ {
237
+ "cell_type": "code",
238
+ "execution_count": null,
239
+ "metadata": {},
240
+ "outputs": [],
241
+ "source": [
242
+ "trainer = Trainer(model=model, tokenizer=tokenizer,\n",
243
+ " data_collator=data_collator,\n",
244
+ " args=training_args,\n",
245
+ " train_dataset=tokenized_ds[\"train\"],\n",
246
+ " eval_dataset=tokenized_ds[\"test\"], \n",
247
+ " compute_metrics=compute_metrics)"
248
+ ]
249
+ },
250
+ {
251
+ "cell_type": "markdown",
252
+ "metadata": {},
253
+ "source": [
254
+ "## Train Model and Push to Hub"
255
+ ]
256
+ },
257
+ {
258
+ "cell_type": "code",
259
+ "execution_count": null,
260
+ "metadata": {},
261
+ "outputs": [],
262
+ "source": [
263
+ "trainer.train()"
264
+ ]
265
+ },
266
+ {
267
+ "cell_type": "code",
268
+ "execution_count": null,
269
+ "metadata": {},
270
+ "outputs": [],
271
+ "source": [
272
+ "trainer.push_to_hub()"
273
+ ]
274
+ },
275
+ {
276
+ "cell_type": "code",
277
+ "execution_count": null,
278
+ "metadata": {},
279
+ "outputs": [],
280
+ "source": []
281
+ }
282
+ ],
283
+ "metadata": {
284
+ "kernelspec": {
285
+ "display_name": "Python 3 (ipykernel)",
286
+ "language": "python",
287
+ "name": "python3"
288
+ }
289
+ },
290
+ "nbformat": 4,
291
+ "nbformat_minor": 4
292
+ }