jorgefio committed on
Commit
0779329
1 Parent(s): d24343b

Upload distilbert_classification_run.ipynb

Browse files
Files changed (1) hide show
  1. distilbert_classification_run.ipynb +438 -0
distilbert_classification_run.ipynb ADDED
@@ -0,0 +1,438 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "gpuType": "T4"
8
+ },
9
+ "kernelspec": {
10
+ "name": "python3",
11
+ "display_name": "Python 3"
12
+ },
13
+ "language_info": {
14
+ "name": "python"
15
+ },
16
+ "accelerator": "GPU"
17
+ },
18
+ "cells": [
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": null,
22
+ "metadata": {
23
+ "id": "m8fE5WS67LOk",
24
+ "colab": {
25
+ "base_uri": "https://localhost:8080/"
26
+ },
27
+ "outputId": "8b4c6ecf-030b-4ad6-f17c-12cbdd20f943"
28
+ },
29
+ "outputs": [
30
+ {
31
+ "output_type": "stream",
32
+ "name": "stdout",
33
+ "text": [
34
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m25.3/25.3 MB\u001b[0m \u001b[31m50.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
35
+ "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
36
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m981.5/981.5 kB\u001b[0m \u001b[31m73.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
37
+ "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
38
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m653.6/653.6 kB\u001b[0m \u001b[31m59.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
39
+ "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
40
+ " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
41
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.4/7.4 MB\u001b[0m \u001b[31m99.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
42
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m66.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
43
+ "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
44
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m468.8/468.8 kB\u001b[0m \u001b[31m51.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
45
+ "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
46
+ " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
47
+ " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
48
+ " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
49
+ " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
50
+ " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
51
+ " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
52
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m268.8/268.8 kB\u001b[0m \u001b[31m16.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
53
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m76.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
54
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m64.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
55
+ "\u001b[?25h Building wheel for ktrain (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
56
+ " Building wheel for keras_bert (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
57
+ " Building wheel for keras-transformer (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
58
+ " Building wheel for keras-embed-sim (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
59
+ " Building wheel for keras-layer-normalization (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
60
+ " Building wheel for keras-multi-head (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
61
+ " Building wheel for keras-pos-embd (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
62
+ " Building wheel for keras-position-wise-feed-forward (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
63
+ " Building wheel for keras-self-attention (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
64
+ " Building wheel for cchardet (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
65
+ " Building wheel for langdetect (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
66
+ " Building wheel for tika (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
67
+ ]
68
+ }
69
+ ],
70
+ "source": [
71
+ "!pip install -q ktrain"
72
+ ]
73
+ },
74
+ {
75
+ "cell_type": "code",
76
+ "source": [
77
+ "import ktrain\n",
78
+ "from ktrain import text\n",
79
+ "import pandas as pd\n",
80
+ "from sklearn.model_selection import train_test_split\n",
81
+ "import os\n",
82
+ "from sklearn.metrics import accuracy_score, classification_report, confusion_matrix"
83
+ ],
84
+ "metadata": {
85
+ "id": "F8OQn0v18Zuw"
86
+ },
87
+ "execution_count": null,
88
+ "outputs": []
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "source": [
93
+ "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n",
94
+ "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\""
95
+ ],
96
+ "metadata": {
97
+ "id": "QKUWKSZE8j70"
98
+ },
99
+ "execution_count": null,
100
+ "outputs": []
101
+ },
102
+ {
103
+ "cell_type": "code",
104
+ "source": [
105
+ "root_folder = \"/content/drive/MyDrive/Colab Notebooks/\"\n",
106
+ "test_data_file = root_folder + \"data/internet_provider_test.csv\"\n",
107
+ "test_data = pd.read_csv(test_data_file)\n",
108
+ "categories = ['Slow Connection', 'Billing', 'Setup', 'No Connectivity']"
109
+ ],
110
+ "metadata": {
111
+ "id": "_6ofxOvA8arZ"
112
+ },
113
+ "execution_count": null,
114
+ "outputs": []
115
+ },
116
+ {
117
+ "cell_type": "code",
118
+ "source": [
119
+ "predictor = ktrain.load_predictor(root_folder + \"models/distilbert-model\")"
120
+ ],
121
+ "metadata": {
122
+ "id": "jfwfAwGE_qJo"
123
+ },
124
+ "execution_count": null,
125
+ "outputs": []
126
+ },
127
+ {
128
+ "cell_type": "code",
129
+ "source": [
130
+ "test_predictions = predictor.predict(test_data[\"Text\"].tolist())"
131
+ ],
132
+ "metadata": {
133
+ "id": "wNN96flfMNky"
134
+ },
135
+ "execution_count": null,
136
+ "outputs": []
137
+ },
138
+ {
139
+ "cell_type": "code",
140
+ "source": [
141
+ "accuracy = accuracy_score(test_data[\"Category\"].tolist(), test_predictions)\n",
142
+ "print(f'Test Accuracy: {accuracy}')\n",
143
+ "print(classification_report(test_data[\"Category\"].tolist(), test_predictions))\n",
144
+ "\n",
145
+ "conf_matrix = confusion_matrix(test_data[\"Category\"].tolist(), test_predictions)\n",
146
+ "print('Confusion Matrix:')\n",
147
+ "print(conf_matrix)"
148
+ ],
149
+ "metadata": {
150
+ "colab": {
151
+ "base_uri": "https://localhost:8080/"
152
+ },
153
+ "id": "T-m4v3U2S0cY",
154
+ "outputId": "2c0587f8-e834-41d3-8407-01c0a34cc84a"
155
+ },
156
+ "execution_count": null,
157
+ "outputs": [
158
+ {
159
+ "output_type": "stream",
160
+ "name": "stdout",
161
+ "text": [
162
+ "Test Accuracy: 0.9923664122137404\n",
163
+ " precision recall f1-score support\n",
164
+ "\n",
165
+ " Billing 1.00 0.96 0.98 28\n",
166
+ "No Connectivity 1.00 1.00 1.00 27\n",
167
+ " Setup 1.00 1.00 1.00 57\n",
168
+ "Slow Connection 0.95 1.00 0.97 19\n",
169
+ "\n",
170
+ " accuracy 0.99 131\n",
171
+ " macro avg 0.99 0.99 0.99 131\n",
172
+ " weighted avg 0.99 0.99 0.99 131\n",
173
+ "\n",
174
+ "Confusion Matrix:\n",
175
+ "[[27 0 0 1]\n",
176
+ " [ 0 27 0 0]\n",
177
+ " [ 0 0 57 0]\n",
178
+ " [ 0 0 0 19]]\n"
179
+ ]
180
+ }
181
+ ]
182
+ },
183
+ {
184
+ "cell_type": "markdown",
185
+ "source": [],
186
+ "metadata": {
187
+ "id": "m3qtEjRXSsiu"
188
+ }
189
+ },
190
+ {
191
+ "cell_type": "code",
192
+ "source": [
193
+ "def print_prediction(predictor, text):\n",
194
+ " labels = predictor.get_classes()\n",
195
+ " preds = predictor.predict_proba(text)\n",
196
+ " probs = [f\"{label}: {float(pred)}\" for label, pred in zip(labels, preds)]\n",
197
+ " print(probs)"
198
+ ],
199
+ "metadata": {
200
+ "id": "6ZRMQU3duT95"
201
+ },
202
+ "execution_count": null,
203
+ "outputs": []
204
+ },
205
+ {
206
+ "cell_type": "code",
207
+ "source": [
208
+ "x = \"I connection is very slow\"\n",
209
+ "prediction = predictor.predict(x)\n",
210
+ "print(f\"prediction: {prediction}\")"
211
+ ],
212
+ "metadata": {
213
+ "colab": {
214
+ "base_uri": "https://localhost:8080/"
215
+ },
216
+ "id": "Kl1F196gS0UP",
217
+ "outputId": "9aee3f66-6f5f-414a-9045-e16986dfcd11"
218
+ },
219
+ "execution_count": null,
220
+ "outputs": [
221
+ {
222
+ "output_type": "stream",
223
+ "name": "stdout",
224
+ "text": [
225
+ "prediction: Slow Connection\n"
226
+ ]
227
+ }
228
+ ]
229
+ },
230
+ {
231
+ "cell_type": "code",
232
+ "source": [
233
+ "x = \"I can't connect to any website\"\n",
234
+ "prediction = predictor.predict(x)\n",
235
+ "print(f\"prediction: {prediction}\")"
236
+ ],
237
+ "metadata": {
238
+ "colab": {
239
+ "base_uri": "https://localhost:8080/"
240
+ },
241
+ "id": "BQytvMlgS9cW",
242
+ "outputId": "7bf660aa-3014-4a86-c0cd-c94219b33e5c"
243
+ },
244
+ "execution_count": null,
245
+ "outputs": [
246
+ {
247
+ "output_type": "stream",
248
+ "name": "stdout",
249
+ "text": [
250
+ "prediction: No Connectivity\n"
251
+ ]
252
+ }
253
+ ]
254
+ },
255
+ {
256
+ "cell_type": "code",
257
+ "source": [
258
+ "x = \"I am paying too much for the service\"\n",
259
+ "prediction = predictor.predict(x)\n",
260
+ "print(f\"prediction: {prediction}\")"
261
+ ],
262
+ "metadata": {
263
+ "id": "VK0avHJ6TEWD",
264
+ "outputId": "8108186a-228a-4733-c22b-0b42e0a647d1",
265
+ "colab": {
266
+ "base_uri": "https://localhost:8080/"
267
+ }
268
+ },
269
+ "execution_count": null,
270
+ "outputs": [
271
+ {
272
+ "output_type": "stream",
273
+ "name": "stdout",
274
+ "text": [
275
+ "prediction: Billing\n"
276
+ ]
277
+ }
278
+ ]
279
+ },
280
+ {
281
+ "cell_type": "code",
282
+ "source": [
283
+ "x = \"I am waiting for engineer to configure the connection\"\n",
284
+ "prediction = predictor.predict(x)\n",
285
+ "print(f\"prediction: {prediction}\")"
286
+ ],
287
+ "metadata": {
288
+ "colab": {
289
+ "base_uri": "https://localhost:8080/"
290
+ },
291
+ "id": "-qnLDDmPTWXb",
292
+ "outputId": "770b0922-6f50-4f96-d309-5b1caae07831"
293
+ },
294
+ "execution_count": null,
295
+ "outputs": [
296
+ {
297
+ "output_type": "stream",
298
+ "name": "stdout",
299
+ "text": [
300
+ "prediction: Setup\n"
301
+ ]
302
+ }
303
+ ]
304
+ },
305
+ {
306
+ "cell_type": "code",
307
+ "source": [
308
+ "x = \"My internet is not\"\n",
309
+ "prediction = predictor.predict(x)\n",
310
+ "print(f\"prediction: {prediction}\")"
311
+ ],
312
+ "metadata": {
313
+ "colab": {
314
+ "base_uri": "https://localhost:8080/"
315
+ },
316
+ "id": "08G-waRmsIot",
317
+ "outputId": "caf48b98-4146-4557-d68e-df73434bad8e"
318
+ },
319
+ "execution_count": null,
320
+ "outputs": [
321
+ {
322
+ "output_type": "stream",
323
+ "name": "stdout",
324
+ "text": [
325
+ "prediction: No Connectivity\n"
326
+ ]
327
+ }
328
+ ]
329
+ },
330
+ {
331
+ "cell_type": "code",
332
+ "source": [
333
+ "x = \"My internet is not working\"\n",
334
+ "prediction = predictor.predict(x)\n",
335
+ "print(f\"prediction: {prediction}\")"
336
+ ],
337
+ "metadata": {
338
+ "colab": {
339
+ "base_uri": "https://localhost:8080/"
340
+ },
341
+ "id": "CijAiOSRsOYq",
342
+ "outputId": "64e4ef62-0300-4b29-b5c0-8c9be45532dd"
343
+ },
344
+ "execution_count": null,
345
+ "outputs": [
346
+ {
347
+ "output_type": "stream",
348
+ "name": "stdout",
349
+ "text": [
350
+ "prediction: Slow Connection\n"
351
+ ]
352
+ }
353
+ ]
354
+ },
355
+ {
356
+ "cell_type": "code",
357
+ "source": [
358
+ "x = \"My internet is not working.\"\n",
359
+ "prediction = predictor.predict(x)\n",
360
+ "print(f\"prediction: {prediction}\")"
361
+ ],
362
+ "metadata": {
363
+ "colab": {
364
+ "base_uri": "https://localhost:8080/"
365
+ },
366
+ "id": "2AjgbowSsPkf",
367
+ "outputId": "6c78da5e-f100-4a88-e7c2-ca99245c1f7b"
368
+ },
369
+ "execution_count": null,
370
+ "outputs": [
371
+ {
372
+ "output_type": "stream",
373
+ "name": "stdout",
374
+ "text": [
375
+ "prediction: Slow Connection\n"
376
+ ]
377
+ }
378
+ ]
379
+ },
380
+ {
381
+ "cell_type": "code",
382
+ "source": [
383
+ "x = \"My internet is not working at all\"\n",
384
+ "prediction = predictor.predict(x)\n",
385
+ "print(f\"prediction: {prediction}\")"
386
+ ],
387
+ "metadata": {
388
+ "colab": {
389
+ "base_uri": "https://localhost:8080/"
390
+ },
391
+ "id": "_11U1q7ysRb5",
392
+ "outputId": "c1fd45ac-cdb2-41f2-bcd1-89b919f894a4"
393
+ },
394
+ "execution_count": null,
395
+ "outputs": [
396
+ {
397
+ "output_type": "stream",
398
+ "name": "stdout",
399
+ "text": [
400
+ "prediction: Slow Connection\n"
401
+ ]
402
+ }
403
+ ]
404
+ },
405
+ {
406
+ "cell_type": "code",
407
+ "source": [
408
+ "print_prediction(predictor, \"My internet is not working at all\")"
409
+ ],
410
+ "metadata": {
411
+ "colab": {
412
+ "base_uri": "https://localhost:8080/"
413
+ },
414
+ "id": "SXmSCN3cu9cX",
415
+ "outputId": "fe179c02-9483-4b94-e970-e761f91e18e4"
416
+ },
417
+ "execution_count": null,
418
+ "outputs": [
419
+ {
420
+ "output_type": "stream",
421
+ "name": "stdout",
422
+ "text": [
423
+ "['Billing: 0.0002786574768833816', 'No Connectivity: 0.008474737405776978', 'Setup: 0.0002650754468049854', 'Slow Connection: 0.9909815192222595']\n"
424
+ ]
425
+ }
426
+ ]
427
+ },
428
+ {
429
+ "cell_type": "code",
430
+ "source": [],
431
+ "metadata": {
432
+ "id": "IUcD_l2MvFDy"
433
+ },
434
+ "execution_count": null,
435
+ "outputs": []
436
+ }
437
+ ]
438
+ }