RelativelyUnique committed on
Commit
5abb7cd
1 Parent(s): c1fc48e

migrate to huggingface spaces

Browse files
.pre-commit-config.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v2.3.0
4
+ hooks:
5
+ - id: check-yaml
6
+ - id: end-of-file-fixer
7
+ - id: trailing-whitespace
8
+ - repo: https://github.com/nbQA-dev/nbQA
9
+ rev: 1.3.1
10
+ hooks:
11
+ - id: nbqa-black
12
+ args: [--line-length=99]
13
+ - id: nbqa-pyupgrade
14
+ args: [--py36-plus]
15
+ - id: nbqa-isort
16
+ args: [--profile=black]
app.py ADDED
File without changes
pkmn-classifier/evaluate.ipynb ADDED
@@ -0,0 +1,541 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "name": "some_evaluation.ipynb",
7
+ "provenance": [],
8
+ "authorship_tag": "ABX9TyOnJhILbrhl8aZ1wDYochYn",
9
+ "include_colab_link": true
10
+ },
11
+ "kernelspec": {
12
+ "name": "python3",
13
+ "display_name": "Python 3"
14
+ },
15
+ "language_info": {
16
+ "name": "python"
17
+ }
18
+ },
19
+ "cells": [
20
+ {
21
+ "cell_type": "markdown",
22
+ "metadata": {
23
+ "id": "view-in-github",
24
+ "colab_type": "text"
25
+ },
26
+ "source": [
27
+ "<a href=\"https://colab.research.google.com/github/mrcoombes/projects/blob/main/some_evaluation.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "code",
32
+ "execution_count": null,
33
+ "metadata": {
34
+ "colab": {
35
+ "base_uri": "https://localhost:8080/"
36
+ },
37
+ "id": "zSH83k5bXBVV",
38
+ "outputId": "6bccc72d-1764-4ecd-ea55-3b05fa5f255a"
39
+ },
40
+ "outputs": [
41
+ {
42
+ "output_type": "stream",
43
+ "name": "stdout",
44
+ "text": [
45
+ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
46
+ "Collecting transformers\n",
47
+ " Downloading transformers-4.19.2-py3-none-any.whl (4.2 MB)\n",
48
+ "\u001b[K |████████████████████████████████| 4.2 MB 5.1 MB/s \n",
49
+ "\u001b[?25hRequirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from transformers) (21.3)\n",
50
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers) (3.7.0)\n",
51
+ "Collecting tokenizers!=0.11.3,<0.13,>=0.11.1\n",
52
+ " Downloading tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)\n",
53
+ "\u001b[K |████████████████████████████████| 6.6 MB 45.6 MB/s \n",
54
+ "\u001b[?25hCollecting pyyaml>=5.1\n",
55
+ " Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)\n",
56
+ "\u001b[K |████████████████████████████████| 596 kB 45.0 MB/s \n",
57
+ "\u001b[?25hRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (1.21.6)\n",
58
+ "Collecting huggingface-hub<1.0,>=0.1.0\n",
59
+ " Downloading huggingface_hub-0.7.0-py3-none-any.whl (86 kB)\n",
60
+ "\u001b[K |████████████████████████████████| 86 kB 5.0 MB/s \n",
61
+ "\u001b[?25hRequirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (2019.12.20)\n",
62
+ "Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from transformers) (4.11.4)\n",
63
+ "Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers) (2.23.0)\n",
64
+ "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.7/dist-packages (from transformers) (4.64.0)\n",
65
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<1.0,>=0.1.0->transformers) (4.2.0)\n",
66
+ "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.0->transformers) (3.0.9)\n",
67
+ "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->transformers) (3.8.0)\n",
68
+ "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2.10)\n",
69
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2022.5.18.1)\n",
70
+ "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (3.0.4)\n",
71
+ "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (1.24.3)\n",
72
+ "Installing collected packages: pyyaml, tokenizers, huggingface-hub, transformers\n",
73
+ " Attempting uninstall: pyyaml\n",
74
+ " Found existing installation: PyYAML 3.13\n",
75
+ " Uninstalling PyYAML-3.13:\n",
76
+ " Successfully uninstalled PyYAML-3.13\n",
77
+ "Successfully installed huggingface-hub-0.7.0 pyyaml-6.0 tokenizers-0.12.1 transformers-4.19.2\n",
78
+ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
79
+ "Collecting datasets\n",
80
+ " Downloading datasets-2.2.2-py3-none-any.whl (346 kB)\n",
81
+ "\u001b[K |████████████████████████████████| 346 kB 5.1 MB/s \n",
82
+ "\u001b[?25hRequirement already satisfied: multiprocess in /usr/local/lib/python3.7/dist-packages (from datasets) (0.70.13)\n",
83
+ "Collecting xxhash\n",
84
+ " Downloading xxhash-3.0.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (212 kB)\n",
85
+ "\u001b[K |████████████████████████████████| 212 kB 58.9 MB/s \n",
86
+ "\u001b[?25hRequirement already satisfied: huggingface-hub<1.0.0,>=0.1.0 in /usr/local/lib/python3.7/dist-packages (from datasets) (0.7.0)\n",
87
+ "Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.7/dist-packages (from datasets) (4.64.0)\n",
88
+ "Collecting dill<0.3.5\n",
89
+ " Downloading dill-0.3.4-py2.py3-none-any.whl (86 kB)\n",
90
+ "\u001b[K |████████████████████████████████| 86 kB 5.4 MB/s \n",
91
+ "\u001b[?25hCollecting aiohttp\n",
92
+ " Downloading aiohttp-3.8.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.1 MB)\n",
93
+ "\u001b[K |████████████████████████████████| 1.1 MB 52.8 MB/s \n",
94
+ "\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from datasets) (21.3)\n",
95
+ "Collecting responses<0.19\n",
96
+ " Downloading responses-0.18.0-py3-none-any.whl (38 kB)\n",
97
+ "Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from datasets) (4.11.4)\n",
98
+ "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from datasets) (1.21.6)\n",
99
+ "Collecting fsspec[http]>=2021.05.0\n",
100
+ " Downloading fsspec-2022.5.0-py3-none-any.whl (140 kB)\n",
101
+ "\u001b[K |████████████████████████████████| 140 kB 56.3 MB/s \n",
102
+ "\u001b[?25hRequirement already satisfied: pyarrow>=6.0.0 in /usr/local/lib/python3.7/dist-packages (from datasets) (6.0.1)\n",
103
+ "Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from datasets) (1.3.5)\n",
104
+ "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.7/dist-packages (from datasets) (2.23.0)\n",
105
+ "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<1.0.0,>=0.1.0->datasets) (6.0)\n",
106
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<1.0.0,>=0.1.0->datasets) (4.2.0)\n",
107
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<1.0.0,>=0.1.0->datasets) (3.7.0)\n",
108
+ "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging->datasets) (3.0.9)\n",
109
+ "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->datasets) (1.24.3)\n",
110
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->datasets) (2022.5.18.1)\n",
111
+ "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->datasets) (2.10)\n",
112
+ "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->datasets) (3.0.4)\n",
113
+ "Collecting urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1\n",
114
+ " Downloading urllib3-1.25.11-py2.py3-none-any.whl (127 kB)\n",
115
+ "\u001b[K |████████████████████████████████| 127 kB 60.8 MB/s \n",
116
+ "\u001b[?25hCollecting aiosignal>=1.1.2\n",
117
+ " Downloading aiosignal-1.2.0-py3-none-any.whl (8.2 kB)\n",
118
+ "Collecting yarl<2.0,>=1.0\n",
119
+ " Downloading yarl-1.7.2-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (271 kB)\n",
120
+ "\u001b[K |████████████████████████████████| 271 kB 50.7 MB/s \n",
121
+ "\u001b[?25hCollecting multidict<7.0,>=4.5\n",
122
+ " Downloading multidict-6.0.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (94 kB)\n",
123
+ "\u001b[K |████████████████████████████████| 94 kB 1.0 MB/s \n",
124
+ "\u001b[?25hRequirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (21.4.0)\n",
125
+ "Collecting asynctest==0.13.0\n",
126
+ " Downloading asynctest-0.13.0-py3-none-any.whl (26 kB)\n",
127
+ "Collecting async-timeout<5.0,>=4.0.0a3\n",
128
+ " Downloading async_timeout-4.0.2-py3-none-any.whl (5.8 kB)\n",
129
+ "Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (2.0.12)\n",
130
+ "Collecting frozenlist>=1.1.1\n",
131
+ " Downloading frozenlist-1.3.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (144 kB)\n",
132
+ "\u001b[K |████████████████████████████████| 144 kB 57.3 MB/s \n",
133
+ "\u001b[?25hRequirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->datasets) (3.8.0)\n",
134
+ "Collecting multiprocess\n",
135
+ " Downloading multiprocess-0.70.12.2-py37-none-any.whl (112 kB)\n",
136
+ "\u001b[K |████████████████████████████████| 112 kB 64.6 MB/s \n",
137
+ "\u001b[?25hRequirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas->datasets) (2.8.2)\n",
138
+ "Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.7/dist-packages (from pandas->datasets) (2022.1)\n",
139
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.7.3->pandas->datasets) (1.15.0)\n",
140
+ "Installing collected packages: multidict, frozenlist, yarl, urllib3, asynctest, async-timeout, aiosignal, fsspec, dill, aiohttp, xxhash, responses, multiprocess, datasets\n",
141
+ " Attempting uninstall: urllib3\n",
142
+ " Found existing installation: urllib3 1.24.3\n",
143
+ " Uninstalling urllib3-1.24.3:\n",
144
+ " Successfully uninstalled urllib3-1.24.3\n",
145
+ " Attempting uninstall: dill\n",
146
+ " Found existing installation: dill 0.3.5.1\n",
147
+ " Uninstalling dill-0.3.5.1:\n",
148
+ " Successfully uninstalled dill-0.3.5.1\n",
149
+ " Attempting uninstall: multiprocess\n",
150
+ " Found existing installation: multiprocess 0.70.13\n",
151
+ " Uninstalling multiprocess-0.70.13:\n",
152
+ " Successfully uninstalled multiprocess-0.70.13\n",
153
+ "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
154
+ "datascience 0.10.6 requires folium==0.2.1, but you have folium 0.8.3 which is incompatible.\u001b[0m\n",
155
+ "Successfully installed aiohttp-3.8.1 aiosignal-1.2.0 async-timeout-4.0.2 asynctest-0.13.0 datasets-2.2.2 dill-0.3.4 frozenlist-1.3.0 fsspec-2022.5.0 multidict-6.0.2 multiprocess-0.70.12.2 responses-0.18.0 urllib3-1.25.11 xxhash-3.0.0 yarl-1.7.2\n"
156
+ ]
157
+ }
158
+ ],
159
+ "source": [
160
+ "!pip install transformers\n",
161
+ "!pip install datasets"
162
+ ]
163
+ },
164
+ {
165
+ "cell_type": "code",
166
+ "source": [
167
+ "from transformers import AutoModelForSequenceClassification, AutoTokenizer\n",
168
+ "from transformers import pipeline\n",
169
+ "\n",
170
+ "model = AutoModelForSequenceClassification.from_pretrained('mrcoombes/distilbert-wikipedia-pokemon')\n",
171
+ "tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')\n",
172
+ "\n",
173
+ "classifier = pipeline('text-classification', model = model, tokenizer=tokenizer, return_all_scores=True)\n",
174
+ "\n",
175
+ "clf = lambda x: sorted(classifier(x)[0], key=lambda y: y['score'], reverse=True)"
176
+ ],
177
+ "metadata": {
178
+ "id": "3PGrgNFoXje6"
179
+ },
180
+ "execution_count": 90,
181
+ "outputs": []
182
+ },
183
+ {
184
+ "cell_type": "code",
185
+ "source": [
186
+ "pooh = 'Pooh is a small yellow bear. He is nearly 22 inches (560 mm) tall. He wears an old red t-shirt. His favorite food is honey. The first thing he says when he gets up in the morning is \"What is for breakfast?\". He invented the game \"Poohsticks\".'\n",
187
+ "\n",
188
+ "clf(pooh)"
189
+ ],
190
+ "metadata": {
191
+ "colab": {
192
+ "base_uri": "https://localhost:8080/"
193
+ },
194
+ "id": "SG72dknXqD7K",
195
+ "outputId": "8a923810-4757-48f0-de13-0fe127428bf9"
196
+ },
197
+ "execution_count": 91,
198
+ "outputs": [
199
+ {
200
+ "output_type": "execute_result",
201
+ "data": {
202
+ "text/plain": [
203
+ "[{'label': 'Normal', 'score': 0.2600723206996918},\n",
204
+ " {'label': 'Psychic', 'score': 0.1543654203414917},\n",
205
+ " {'label': 'Fire', 'score': 0.06567206978797913},\n",
206
+ " {'label': 'Electric', 'score': 0.05511392652988434},\n",
207
+ " {'label': 'Bug', 'score': 0.054607123136520386},\n",
208
+ " {'label': 'Water', 'score': 0.05065950006246567},\n",
209
+ " {'label': 'Ghost', 'score': 0.04469247907400131},\n",
210
+ " {'label': 'Fairy', 'score': 0.03949982300400734},\n",
211
+ " {'label': 'Dark', 'score': 0.039207249879837036},\n",
212
+ " {'label': 'Grass', 'score': 0.03893379494547844},\n",
213
+ " {'label': 'Dragon', 'score': 0.034857217222452164},\n",
214
+ " {'label': 'Poison', 'score': 0.02820383384823799},\n",
215
+ " {'label': 'Ice', 'score': 0.028034506365656853},\n",
216
+ " {'label': 'Fighting', 'score': 0.02752748504281044},\n",
217
+ " {'label': 'Ground', 'score': 0.021409546956419945},\n",
218
+ " {'label': 'Steel', 'score': 0.021021585911512375},\n",
219
+ " {'label': 'Rock', 'score': 0.020723488181829453},\n",
220
+ " {'label': 'Flying', 'score': 0.015398567542433739}]"
221
+ ]
222
+ },
223
+ "metadata": {},
224
+ "execution_count": 91
225
+ }
226
+ ]
227
+ },
228
+ {
229
+ "cell_type": "code",
230
+ "source": [
231
+ "clf('Tigger is a fictional tiger. He has orange fur with black stripes. He is easily recognized by his beady eyes, long chin, springy tail, and bouncy personality. As he says himself, \"Bouncing is what Tiggers do best.\" Like other Pooh characters, Tigger is based on one of the stuffed animals of Christopher Robin Milne.')"
232
+ ],
233
+ "metadata": {
234
+ "colab": {
235
+ "base_uri": "https://localhost:8080/"
236
+ },
237
+ "id": "o3-gMTkxe44B",
238
+ "outputId": "632deac3-f378-4fda-c9a3-112edaf0e966"
239
+ },
240
+ "execution_count": 93,
241
+ "outputs": [
242
+ {
243
+ "output_type": "execute_result",
244
+ "data": {
245
+ "text/plain": [
246
+ "[{'label': 'Normal', 'score': 0.335999995470047},\n",
247
+ " {'label': 'Water', 'score': 0.09296873211860657},\n",
248
+ " {'label': 'Psychic', 'score': 0.07045886665582657},\n",
249
+ " {'label': 'Fire', 'score': 0.06954856216907501},\n",
250
+ " {'label': 'Bug', 'score': 0.044683054089546204},\n",
251
+ " {'label': 'Dragon', 'score': 0.0426957830786705},\n",
252
+ " {'label': 'Dark', 'score': 0.04223953187465668},\n",
253
+ " {'label': 'Electric', 'score': 0.040662020444869995},\n",
254
+ " {'label': 'Fairy', 'score': 0.03830638527870178},\n",
255
+ " {'label': 'Grass', 'score': 0.03206343948841095},\n",
256
+ " {'label': 'Poison', 'score': 0.027593791484832764},\n",
257
+ " {'label': 'Ground', 'score': 0.026110077276825905},\n",
258
+ " {'label': 'Ghost', 'score': 0.025732463225722313},\n",
259
+ " {'label': 'Ice', 'score': 0.02510858327150345},\n",
260
+ " {'label': 'Fighting', 'score': 0.023818649351596832},\n",
261
+ " {'label': 'Steel', 'score': 0.022579438984394073},\n",
262
+ " {'label': 'Rock', 'score': 0.021684883162379265},\n",
263
+ " {'label': 'Flying', 'score': 0.017745744436979294}]"
264
+ ]
265
+ },
266
+ "metadata": {},
267
+ "execution_count": 93
268
+ }
269
+ ]
270
+ },
271
+ {
272
+ "cell_type": "code",
273
+ "source": [
274
+ "# sourced from https://avengers.marvelhq.com/characters\n",
275
+ "avengers = {\n",
276
+ "\"black_panther\":\"Driven by an extremely lifelike artificial intelligence and possessing a nearly indestructible robotic body, the synthetic android called the Vision has taken his place among Earth’s Mightiest Heroes, the Avengers!\",\n",
277
+ "\"iron_man\":\"When billionaire industrialist Tony Stark dons his sophisticated steel-mesh armor, he becomes a living high-tech weapon - the world's greatest fighting machine. Tony has primed his ultra modern creation for waging state of the art campaigns, attaining sonic flight, and defending the greater good! He is the Armored Avenger - driven by a heart that is part machine, but all hero! He is the INVINCIBLE IRON MAN!\",\n",
278
+ "\"thor\":\"Nordic legend tells the tale of the son of Odin, the heir to the throne of Asgard - he is THOR, renowned as the mightiest hero of mythology! Thor's strength, endurance, and quest for battle are far greater than his Asgardian brethren. The mighty Thor wields an enchanted Uru hammer, Mjolnir, and is master of thunder and lightning.\",\n",
279
+ "\"black_widow\":\"Natasha Romanoff is the super-spy known as the Black Widow! Trained extensively in the art of espionage and outfitted with state-of-the-art equipment, Black Widow's combat skills are virtually unmatched. One of S.H.I.E.L.D's most valuable agents, she has carried out numerous black-ops missions and has recently been assigned by Director Nick Fury to keep an eye on the Avengers.\",\n",
280
+ "\"hulk\":\"A massive dose of gamma radiation transformed the brilliant but meek scientist Bruce Banner's DNA, awakening the hidden, adrenaline-fed hero in his genes... HULK! A hero of few words and incredible strength, the Hulk has long been pursued by those who want to use his immense power for their own purposes, or by those who thought the Jade Giant's anger was too dangerous to be controlled. Now, as a member of the Avengers, Hulk helps smash the unimaginable threats that no Hero could face alone, hoping to at least prove to the world that he is the strongest HERO there is!\",\n",
281
+ "\"cap\":\"During WWII, the patriotic Steve Rogers was offered a place in the military's top operation: Rebirth. Injected with an experimental super-serum, Rogers emerged from the treatment with heightened endurance, strength, and reaction time. With extensive training and an indestructible Vibranium shield, Rogers soon became the country's ultimate weapon: CAPTAIN AMERICA! Though frozen in ice during a climactic battle toward the end of the war, Rogers was discovered and revived decades later. Now the living legend continues the war against evil in modern times as a member of The Avengers!\",\n",
282
+ "}"
283
+ ],
284
+ "metadata": {
285
+ "id": "s950R1RBgztY"
286
+ },
287
+ "execution_count": 94,
288
+ "outputs": []
289
+ },
290
+ {
291
+ "cell_type": "code",
292
+ "source": [
293
+ "tokenizer.model_max_length"
294
+ ],
295
+ "metadata": {
296
+ "colab": {
297
+ "base_uri": "https://localhost:8080/"
298
+ },
299
+ "id": "pVv2MH8-iNKE",
300
+ "outputId": "adae6f5e-2788-47b7-c382-f46c1c6b17c1"
301
+ },
302
+ "execution_count": 23,
303
+ "outputs": [
304
+ {
305
+ "output_type": "execute_result",
306
+ "data": {
307
+ "text/plain": [
308
+ "512"
309
+ ]
310
+ },
311
+ "metadata": {},
312
+ "execution_count": 23
313
+ }
314
+ ]
315
+ },
316
+ {
317
+ "cell_type": "code",
318
+ "source": [
319
+ "for avenger, description in avengers.items():\n",
320
+ " print(len(tokenizer(description)['input_ids']))\n",
321
+ " assert len(tokenizer(description)['input_ids']) <= tokenizer.model_max_length"
322
+ ],
323
+ "metadata": {
324
+ "colab": {
325
+ "base_uri": "https://localhost:8080/"
326
+ },
327
+ "id": "rUp5CaCIi-E4",
328
+ "outputId": "724e56f4-5d15-4a65-fe5f-4794e760b944"
329
+ },
330
+ "execution_count": 95,
331
+ "outputs": [
332
+ {
333
+ "output_type": "stream",
334
+ "name": "stdout",
335
+ "text": [
336
+ "43\n",
337
+ "87\n",
338
+ "81\n",
339
+ "93\n",
340
+ "125\n",
341
+ "118\n"
342
+ ]
343
+ }
344
+ ]
345
+ },
346
+ {
347
+ "cell_type": "code",
348
+ "source": [
349
+ "types = {}\n",
350
+ "\n",
351
+ "for avenger, description in avengers.items():\n",
352
+ " types[avenger] = clf(description)"
353
+ ],
354
+ "metadata": {
355
+ "id": "piSuS63cjZDR"
356
+ },
357
+ "execution_count": 96,
358
+ "outputs": []
359
+ },
360
+ {
361
+ "cell_type": "code",
362
+ "source": [
363
+ "types"
364
+ ],
365
+ "metadata": {
366
+ "colab": {
367
+ "base_uri": "https://localhost:8080/"
368
+ },
369
+ "id": "lohLMB1WkD3j",
370
+ "outputId": "92fdf9b4-2671-498a-99c6-001ad02c22f1"
371
+ },
372
+ "execution_count": 97,
373
+ "outputs": [
374
+ {
375
+ "output_type": "execute_result",
376
+ "data": {
377
+ "text/plain": [
378
+ "{'black_panther': [{'label': 'Psychic', 'score': 0.14039084315299988},\n",
379
+ " {'label': 'Normal', 'score': 0.11493027955293655},\n",
380
+ " {'label': 'Electric', 'score': 0.08732766658067703},\n",
381
+ " {'label': 'Bug', 'score': 0.07257580012083054},\n",
382
+ " {'label': 'Ghost', 'score': 0.05607102811336517},\n",
383
+ " {'label': 'Fire', 'score': 0.05543342977762222},\n",
384
+ " {'label': 'Dark', 'score': 0.05199592188000679},\n",
385
+ " {'label': 'Steel', 'score': 0.05135666951537132},\n",
386
+ " {'label': 'Dragon', 'score': 0.04997771978378296},\n",
387
+ " {'label': 'Water', 'score': 0.04398226737976074},\n",
388
+ " {'label': 'Fighting', 'score': 0.042541105300188065},\n",
389
+ " {'label': 'Fairy', 'score': 0.04088388383388519},\n",
390
+ " {'label': 'Rock', 'score': 0.0403033122420311},\n",
391
+ " {'label': 'Grass', 'score': 0.0353127047419548},\n",
392
+ " {'label': 'Ice', 'score': 0.03269997984170914},\n",
393
+ " {'label': 'Ground', 'score': 0.030976610258221626},\n",
394
+ " {'label': 'Flying', 'score': 0.026838229969143867},\n",
395
+ " {'label': 'Poison', 'score': 0.02640254609286785}],\n",
396
+ " 'black_widow': [{'label': 'Psychic', 'score': 0.11771736294031143},\n",
397
+ " {'label': 'Normal', 'score': 0.09108898788690567},\n",
398
+ " {'label': 'Electric', 'score': 0.0777704268693924},\n",
399
+ " {'label': 'Ghost', 'score': 0.0661557987332344},\n",
400
+ " {'label': 'Bug', 'score': 0.06480831652879715},\n",
401
+ " {'label': 'Fire', 'score': 0.06366308778524399},\n",
402
+ " {'label': 'Dark', 'score': 0.06326092034578323},\n",
403
+ " {'label': 'Fighting', 'score': 0.05556448549032211},\n",
404
+ " {'label': 'Steel', 'score': 0.05402420461177826},\n",
405
+ " {'label': 'Dragon', 'score': 0.050193771719932556},\n",
406
+ " {'label': 'Grass', 'score': 0.04514629766345024},\n",
407
+ " {'label': 'Water', 'score': 0.0421665795147419},\n",
408
+ " {'label': 'Fairy', 'score': 0.04069530591368675},\n",
409
+ " {'label': 'Ice', 'score': 0.03823541849851608},\n",
410
+ " {'label': 'Rock', 'score': 0.03812859579920769},\n",
411
+ " {'label': 'Ground', 'score': 0.03384226933121681},\n",
412
+ " {'label': 'Flying', 'score': 0.02902897447347641},\n",
413
+ " {'label': 'Poison', 'score': 0.028509100899100304}],\n",
414
+ " 'cap': [{'label': 'Psychic', 'score': 0.1309943050146103},\n",
415
+ " {'label': 'Electric', 'score': 0.1047213152050972},\n",
416
+ " {'label': 'Fire', 'score': 0.08712606877088547},\n",
417
+ " {'label': 'Normal', 'score': 0.07033843547105789},\n",
418
+ " {'label': 'Bug', 'score': 0.061330799013376236},\n",
419
+ " {'label': 'Dragon', 'score': 0.058846309781074524},\n",
420
+ " {'label': 'Ghost', 'score': 0.056743498891592026},\n",
421
+ " {'label': 'Steel', 'score': 0.052078839391469955},\n",
422
+ " {'label': 'Rock', 'score': 0.05096956342458725},\n",
423
+ " {'label': 'Ice', 'score': 0.050031181424856186},\n",
424
+ " {'label': 'Fighting', 'score': 0.042880505323410034},\n",
425
+ " {'label': 'Dark', 'score': 0.042242322117090225},\n",
426
+ " {'label': 'Fairy', 'score': 0.037094846367836},\n",
427
+ " {'label': 'Water', 'score': 0.03594202175736427},\n",
428
+ " {'label': 'Grass', 'score': 0.03098415769636631},\n",
429
+ " {'label': 'Ground', 'score': 0.03098156489431858},\n",
430
+ " {'label': 'Poison', 'score': 0.030156349763274193},\n",
431
+ " {'label': 'Flying', 'score': 0.026537923142313957}],\n",
432
+ " 'hulk': [{'label': 'Psychic', 'score': 0.19098739326000214},\n",
433
+ " {'label': 'Electric', 'score': 0.09343517571687698},\n",
434
+ " {'label': 'Normal', 'score': 0.0896320790052414},\n",
435
+ " {'label': 'Fire', 'score': 0.06670381128787994},\n",
436
+ " {'label': 'Ghost', 'score': 0.062663234770298},\n",
437
+ " {'label': 'Bug', 'score': 0.0534147247672081},\n",
438
+ " {'label': 'Dragon', 'score': 0.05224861577153206},\n",
439
+ " {'label': 'Dark', 'score': 0.048941321671009064},\n",
440
+ " {'label': 'Steel', 'score': 0.04724843055009842},\n",
441
+ " {'label': 'Fighting', 'score': 0.042670682072639465},\n",
442
+ " {'label': 'Water', 'score': 0.038225673139095306},\n",
443
+ " {'label': 'Fairy', 'score': 0.038099996745586395},\n",
444
+ " {'label': 'Ice', 'score': 0.03596445918083191},\n",
445
+ " {'label': 'Rock', 'score': 0.033795569092035294},\n",
446
+ " {'label': 'Grass', 'score': 0.028919294476509094},\n",
447
+ " {'label': 'Flying', 'score': 0.027448808774352074},\n",
448
+ " {'label': 'Ground', 'score': 0.02604469284415245},\n",
449
+ " {'label': 'Poison', 'score': 0.023556090891361237}],\n",
450
+ " 'iron_man': [{'label': 'Electric', 'score': 0.08547740429639816},\n",
451
+ " {'label': 'Steel', 'score': 0.08327846229076385},\n",
452
+ " {'label': 'Fire', 'score': 0.08254428952932358},\n",
453
+ " {'label': 'Psychic', 'score': 0.08083238452672958},\n",
454
+ " {'label': 'Normal', 'score': 0.0718364492058754},\n",
455
+ " {'label': 'Bug', 'score': 0.07048758119344711},\n",
456
+ " {'label': 'Fighting', 'score': 0.06302879750728607},\n",
457
+ " {'label': 'Dragon', 'score': 0.06081298366189003},\n",
458
+ " {'label': 'Rock', 'score': 0.05577201768755913},\n",
459
+ " {'label': 'Dark', 'score': 0.05418119952082634},\n",
460
+ " {'label': 'Ice', 'score': 0.04405869543552399},\n",
461
+ " {'label': 'Grass', 'score': 0.043908316642045975},\n",
462
+ " {'label': 'Ground', 'score': 0.043028734624385834},\n",
463
+ " {'label': 'Ghost', 'score': 0.04071647301316261},\n",
464
+ " {'label': 'Fairy', 'score': 0.03440254554152489},\n",
465
+ " {'label': 'Water', 'score': 0.03119415044784546},\n",
466
+ " {'label': 'Flying', 'score': 0.027475930750370026},\n",
467
+ " {'label': 'Poison', 'score': 0.0269634909927845}],\n",
468
+ " 'thor': [{'label': 'Psychic', 'score': 0.1594998836517334},\n",
469
+ " {'label': 'Normal', 'score': 0.08447221666574478},\n",
470
+ " {'label': 'Fire', 'score': 0.08234388381242752},\n",
471
+ " {'label': 'Electric', 'score': 0.08180674910545349},\n",
472
+ " {'label': 'Ghost', 'score': 0.06200374290347099},\n",
473
+ " {'label': 'Dark', 'score': 0.06072374805808067},\n",
474
+ " {'label': 'Dragon', 'score': 0.05554042384028435},\n",
475
+ " {'label': 'Bug', 'score': 0.0517977699637413},\n",
476
+ " {'label': 'Steel', 'score': 0.04996638000011444},\n",
477
+ " {'label': 'Fighting', 'score': 0.04767195135354996},\n",
478
+ " {'label': 'Water', 'score': 0.04199497774243355},\n",
479
+ " {'label': 'Fairy', 'score': 0.03844171389937401},\n",
480
+ " {'label': 'Ice', 'score': 0.03652443364262581},\n",
481
+ " {'label': 'Rock', 'score': 0.03268720209598541},\n",
482
+ " {'label': 'Grass', 'score': 0.031752120703458786},\n",
483
+ " {'label': 'Ground', 'score': 0.03069768100976944},\n",
484
+ " {'label': 'Flying', 'score': 0.030208230018615723},\n",
485
+ " {'label': 'Poison', 'score': 0.021866854280233383}]}"
486
+ ]
487
+ },
488
+ "metadata": {},
489
+ "execution_count": 97
490
+ }
491
+ ]
492
+ },
493
+ {
494
+ "cell_type": "code",
495
+ "source": [
496
+ "# Well, that's disappointing... anyway...\n",
497
+ "print(clf('Thor thunder and lightning'))\n",
498
+ "print('\\n\\n')\n",
499
+ "print(clf('Thor is the god'))\n",
500
+ "print('\\n\\n')\n",
501
+ "print(clf('Thor is the god of thunder and lightning'))"
502
+ ],
503
+ "metadata": {
504
+ "colab": {
505
+ "base_uri": "https://localhost:8080/"
506
+ },
507
+ "id": "BsmZe_kSkFNs",
508
+ "outputId": "149879af-01f2-4df0-d32c-0f6f56db0868"
509
+ },
510
+ "execution_count": 102,
511
+ "outputs": [
512
+ {
513
+ "output_type": "stream",
514
+ "name": "stdout",
515
+ "text": [
516
+ "[{'label': 'Electric', 'score': 0.11147855967283249}, {'label': 'Psychic', 'score': 0.10549111664295197}, {'label': 'Fire', 'score': 0.08532625436782837}, {'label': 'Normal', 'score': 0.07318642735481262}, {'label': 'Bug', 'score': 0.06542624533176422}, {'label': 'Ghost', 'score': 0.06079600378870964}, {'label': 'Steel', 'score': 0.05627329647541046}, {'label': 'Dragon', 'score': 0.051239024847745895}, {'label': 'Water', 'score': 0.04977051168680191}, {'label': 'Rock', 'score': 0.04720067232847214}, {'label': 'Fighting', 'score': 0.045852769166231155}, {'label': 'Ice', 'score': 0.041491683572530746}, {'label': 'Dark', 'score': 0.041257862001657486}, {'label': 'Grass', 'score': 0.041162408888339996}, {'label': 'Ground', 'score': 0.03666074201464653}, {'label': 'Fairy', 'score': 0.030792705714702606}, {'label': 'Poison', 'score': 0.029198039323091507}, {'label': 'Flying', 'score': 0.027395661920309067}]\n",
517
+ "\n",
518
+ "\n",
519
+ "\n",
520
+ "[{'label': 'Psychic', 'score': 0.10292913764715195}, {'label': 'Normal', 'score': 0.07830759137868881}, {'label': 'Electric', 'score': 0.07660140842199326}, {'label': 'Water', 'score': 0.06917692720890045}, {'label': 'Fire', 'score': 0.06355269998311996}, {'label': 'Bug', 'score': 0.06285296380519867}, {'label': 'Ghost', 'score': 0.06255136430263519}, {'label': 'Dark', 'score': 0.06148482486605644}, {'label': 'Dragon', 'score': 0.04900471866130829}, {'label': 'Grass', 'score': 0.047270145267248154}, {'label': 'Fighting', 'score': 0.04712704196572304}, {'label': 'Steel', 'score': 0.04605235531926155}, {'label': 'Rock', 'score': 0.04519566521048546}, {'label': 'Ground', 'score': 0.04106030985713005}, {'label': 'Ice', 'score': 0.039930786937475204}, {'label': 'Fairy', 'score': 0.03901135176420212}, {'label': 'Poison', 'score': 0.03707478940486908}, {'label': 'Flying', 'score': 0.030815837904810905}]\n",
521
+ "\n",
522
+ "\n",
523
+ "\n",
524
+ "[{'label': 'Psychic', 'score': 0.12898413836956024}, {'label': 'Electric', 'score': 0.10279341787099838}, {'label': 'Fire', 'score': 0.08279084414243698}, {'label': 'Normal', 'score': 0.07643641531467438}, {'label': 'Ghost', 'score': 0.06414799392223358}, {'label': 'Bug', 'score': 0.060264818370342255}, {'label': 'Water', 'score': 0.052734874188899994}, {'label': 'Dragon', 'score': 0.05206158012151718}, {'label': 'Steel', 'score': 0.047912657260894775}, {'label': 'Dark', 'score': 0.04722539335489273}, {'label': 'Fighting', 'score': 0.042949989438056946}, {'label': 'Rock', 'score': 0.042036380618810654}, {'label': 'Ice', 'score': 0.03920420631766319}, {'label': 'Grass', 'score': 0.036403022706508636}, {'label': 'Fairy', 'score': 0.03535248339176178}, {'label': 'Ground', 'score': 0.03424106165766716}, {'label': 'Flying', 'score': 0.027838196605443954}, {'label': 'Poison', 'score': 0.02662259340286255}]\n"
525
+ ]
526
+ }
527
+ ]
528
+ },
529
+ {
530
+ "cell_type": "code",
531
+ "source": [
532
+ ""
533
+ ],
534
+ "metadata": {
535
+ "id": "i-oI9LS1vq7e"
536
+ },
537
+ "execution_count": null,
538
+ "outputs": []
539
+ }
540
+ ]
541
+ }
pkmn-classifier/nlp.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
pkmn-classifier/pokemon.csv ADDED
The diff for this file is too large to render. See raw diff
 
pkmn-classifier/transformer.ipynb ADDED
@@ -0,0 +1,1503 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "source": [
6
+ "!pip install transformers\n",
7
+ "!pip install datasets"
8
+ ],
9
+ "metadata": {
10
+ "colab": {
11
+ "base_uri": "https://localhost:8080/"
12
+ },
13
+ "id": "v6Hl9dT9xI3h",
14
+ "outputId": "99c0c9dc-432d-43b4-ce25-828e9e467326"
15
+ },
16
+ "execution_count": 1,
17
+ "outputs": [
18
+ {
19
+ "output_type": "stream",
20
+ "name": "stdout",
21
+ "text": [
22
+ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
23
+ "Requirement already satisfied: transformers in /usr/local/lib/python3.7/dist-packages (4.19.2)\n",
24
+ "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (1.21.6)\n",
25
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers) (3.7.0)\n",
26
+ "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (2019.12.20)\n",
27
+ "Requirement already satisfied: tokenizers!=0.11.3,<0.13,>=0.11.1 in /usr/local/lib/python3.7/dist-packages (from transformers) (0.12.1)\n",
28
+ "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from transformers) (21.3)\n",
29
+ "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.7/dist-packages (from transformers) (6.0)\n",
30
+ "Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from transformers) (4.11.4)\n",
31
+ "Requirement already satisfied: huggingface-hub<1.0,>=0.1.0 in /usr/local/lib/python3.7/dist-packages (from transformers) (0.7.0)\n",
32
+ "Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers) (2.23.0)\n",
33
+ "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.7/dist-packages (from transformers) (4.64.0)\n",
34
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<1.0,>=0.1.0->transformers) (4.2.0)\n",
35
+ "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.0->transformers) (3.0.9)\n",
36
+ "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->transformers) (3.8.0)\n",
37
+ "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (3.0.4)\n",
38
+ "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2.10)\n",
39
+ "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (1.25.11)\n",
40
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2022.5.18.1)\n",
41
+ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
42
+ "Requirement already satisfied: datasets in /usr/local/lib/python3.7/dist-packages (2.2.2)\n",
43
+ "Requirement already satisfied: pyarrow>=6.0.0 in /usr/local/lib/python3.7/dist-packages (from datasets) (6.0.1)\n",
44
+ "Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.7/dist-packages (from datasets) (4.64.0)\n",
45
+ "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.7/dist-packages (from datasets) (2.23.0)\n",
46
+ "Requirement already satisfied: huggingface-hub<1.0.0,>=0.1.0 in /usr/local/lib/python3.7/dist-packages (from datasets) (0.7.0)\n",
47
+ "Requirement already satisfied: dill<0.3.5 in /usr/local/lib/python3.7/dist-packages (from datasets) (0.3.4)\n",
48
+ "Requirement already satisfied: responses<0.19 in /usr/local/lib/python3.7/dist-packages (from datasets) (0.18.0)\n",
49
+ "Requirement already satisfied: multiprocess in /usr/local/lib/python3.7/dist-packages (from datasets) (0.70.12.2)\n",
50
+ "Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from datasets) (1.3.5)\n",
51
+ "Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from datasets) (21.3)\n",
52
+ "Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from datasets) (4.11.4)\n",
53
+ "Requirement already satisfied: fsspec[http]>=2021.05.0 in /usr/local/lib/python3.7/dist-packages (from datasets) (2022.5.0)\n",
54
+ "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from datasets) (1.21.6)\n",
55
+ "Requirement already satisfied: aiohttp in /usr/local/lib/python3.7/dist-packages (from datasets) (3.8.1)\n",
56
+ "Requirement already satisfied: xxhash in /usr/local/lib/python3.7/dist-packages (from datasets) (3.0.0)\n",
57
+ "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<1.0.0,>=0.1.0->datasets) (6.0)\n",
58
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<1.0.0,>=0.1.0->datasets) (4.2.0)\n",
59
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<1.0.0,>=0.1.0->datasets) (3.7.0)\n",
60
+ "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging->datasets) (3.0.9)\n",
61
+ "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->datasets) (2.10)\n",
62
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->datasets) (2022.5.18.1)\n",
63
+ "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->datasets) (3.0.4)\n",
64
+ "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests>=2.19.0->datasets) (1.25.11)\n",
65
+ "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (4.0.2)\n",
66
+ "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (1.3.0)\n",
67
+ "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (21.4.0)\n",
68
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (6.0.2)\n",
69
+ "Requirement already satisfied: asynctest==0.13.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (0.13.0)\n",
70
+ "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (1.2.0)\n",
71
+ "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (1.7.2)\n",
72
+ "Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (2.0.12)\n",
73
+ "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->datasets) (3.8.0)\n",
74
+ "Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.7/dist-packages (from pandas->datasets) (2022.1)\n",
75
+ "Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas->datasets) (2.8.2)\n",
76
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.7.3->pandas->datasets) (1.15.0)\n"
77
+ ]
78
+ }
79
+ ]
80
+ },
81
+ {
82
+ "cell_type": "code",
83
+ "execution_count": 2,
84
+ "metadata": {
85
+ "id": "i-B5sPHELBBj"
86
+ },
87
+ "outputs": [],
88
+ "source": [
89
+ "import matplotlib.pyplot as plt\n",
90
+ "import pandas as pd\n",
91
+ "import seaborn as sns"
92
+ ]
93
+ },
94
+ {
95
+ "cell_type": "code",
96
+ "execution_count": 3,
97
+ "metadata": {
98
+ "id": "uxhFjfeHLBBr"
99
+ },
100
+ "outputs": [],
101
+ "source": [
102
+ "# Read the pokedex we scraped in web_scrape.ipynb into a DataFrame\n",
103
+ "pkmn = pd.read_csv(\"pokemon.csv\")\n",
104
+ "pkmn.rename(columns={\"Unnamed: 0\": \"wiki_index\"}, inplace=True)\n",
105
+ "pkmn = pkmn[pkmn.primary_type != \"Bird\"] # MissingNo is special, but not special enough to break the rules."
106
+ ]
107
+ },
108
+ {
109
+ "cell_type": "code",
110
+ "source": [
111
+ "# Fixing Inference.\n",
112
+ "\n",
113
+ "lil = pkmn[['primary_type', 'Notes']].copy()"
114
+ ],
115
+ "metadata": {
116
+ "id": "y_n4NNSOxzXL"
117
+ },
118
+ "execution_count": 4,
119
+ "outputs": []
120
+ },
121
+ {
122
+ "cell_type": "code",
123
+ "source": [
124
+ "from datasets.dataset_dict import DatasetDict\n",
125
+ "from datasets import Dataset\n",
126
+ "import datasets"
127
+ ],
128
+ "metadata": {
129
+ "id": "F_HSR6Kk2M0H"
130
+ },
131
+ "execution_count": 5,
132
+ "outputs": []
133
+ },
134
+ {
135
+ "cell_type": "code",
136
+ "source": [
137
+ "lil['primary_type'] = lil['primary_type'].astype('category') \n",
138
+ "lil['label'] = lil['primary_type'].cat.codes\n",
139
+ "df = lil[['label', 'Notes']].copy()\n",
140
+ "df = df.rename(columns={'Notes': 'text'})"
141
+ ],
142
+ "metadata": {
143
+ "id": "Dt-gNH3D4P9Y"
144
+ },
145
+ "execution_count": 6,
146
+ "outputs": []
147
+ },
148
+ {
149
+ "cell_type": "code",
150
+ "source": [
151
+ "id2label = {k: v for k, v in enumerate(lil['primary_type'].cat.categories)}\n",
152
+ "label2id = {v: k for k, v in enumerate(lil['primary_type'].cat.categories)}"
153
+ ],
154
+ "metadata": {
155
+ "id": "3eJghd1m4TSC"
156
+ },
157
+ "execution_count": 7,
158
+ "outputs": []
159
+ },
160
+ {
161
+ "cell_type": "code",
162
+ "execution_count": 8,
163
+ "metadata": {
164
+ "colab": {
165
+ "base_uri": "https://localhost:8080/"
166
+ },
167
+ "id": "VyS7Nc_Iv7ys",
168
+ "outputId": "4558a505-5e59-4c41-f6e5-33e637891a49"
169
+ },
170
+ "outputs": [
171
+ {
172
+ "output_type": "stream",
173
+ "name": "stderr",
174
+ "text": [
175
+ "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_transform.weight', 'vocab_projector.weight', 'vocab_transform.bias']\n",
176
+ "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
177
+ "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
178
+ "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'classifier.weight', 'pre_classifier.bias', 'classifier.bias']\n",
179
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
180
+ ]
181
+ }
182
+ ],
183
+ "source": [
184
+ "from transformers import AutoModelForSequenceClassification, AutoTokenizer\n",
185
+ "\n",
186
+ "tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-uncased\")\n",
187
+ "model = AutoModelForSequenceClassification.from_pretrained(\n",
188
+ " \"distilbert-base-uncased\", num_labels=18, id2label=id2label, label2id=label2id\n",
189
+ ")"
190
+ ]
191
+ },
192
+ {
193
+ "cell_type": "code",
194
+ "source": [
195
+ "train_df = df.sample(frac=0.7)\n",
196
+ "test_df = df.drop(train_df.index, inplace=False)\n",
197
+ "\n",
198
+ "train_dataset = Dataset.from_dict(train_df)\n",
199
+ "test_dataset = Dataset.from_dict(test_df)\n",
200
+ "my_dataset_dict = DatasetDict({\"train\":train_dataset,\"test\":test_dataset})\n",
201
+ "\n",
202
+ "my_dataset_dict"
203
+ ],
204
+ "metadata": {
205
+ "colab": {
206
+ "base_uri": "https://localhost:8080/"
207
+ },
208
+ "id": "GhlYrzASzDUS",
209
+ "outputId": "5c41102d-0c54-498f-f24e-0dae59a0ed56"
210
+ },
211
+ "execution_count": 10,
212
+ "outputs": [
213
+ {
214
+ "output_type": "execute_result",
215
+ "data": {
216
+ "text/plain": [
217
+ "DatasetDict({\n",
218
+ " train: Dataset({\n",
219
+ " features: ['label', 'text'],\n",
220
+ " num_rows: 654\n",
221
+ " })\n",
222
+ " test: Dataset({\n",
223
+ " features: ['label', 'text'],\n",
224
+ " num_rows: 281\n",
225
+ " })\n",
226
+ "})"
227
+ ]
228
+ },
229
+ "metadata": {},
230
+ "execution_count": 10
231
+ }
232
+ ]
233
+ },
234
+ {
235
+ "cell_type": "code",
236
+ "source": [
237
+ "def tokenize_function(examples):\n",
238
+ " return tokenizer(examples[\"text\"], padding=\"max_length\", truncation=True)\n",
239
+ "\n",
240
+ "dataset = my_dataset_dict\n",
241
+ "tokenized_datasets = dataset.map(tokenize_function, batched=True)"
242
+ ],
243
+ "metadata": {
244
+ "colab": {
245
+ "base_uri": "https://localhost:8080/",
246
+ "height": 81,
247
+ "referenced_widgets": [
248
+ "e397276a3494413a8537fa5ba48f7015",
249
+ "3681f628eb844524bd4a121dd313d2ef",
250
+ "4699b0136d4d449ca4accf9bc90fc45a",
251
+ "0435b641146f45488248e042ff0d4f31",
252
+ "8233bb294c7549b6a8ae5e140e8ca5b6",
253
+ "608d0d01fe4d4a14877920e51b10233a",
254
+ "246696956b774368ad7572db606a2414",
255
+ "cd1d0bd35b464f74aff7c7421de03fc9",
256
+ "eac0a76ba4cf4b61ae82e9aa2f770dce",
257
+ "68e198c9120f4451bd9b9f033a9d4bae",
258
+ "bd8b8f2e781e4de1a477c8af4c450d1f",
259
+ "b2eba59325544c0790fea0b5a08916b1",
260
+ "deeb1a85968d4619a7928712c09168d2",
261
+ "39d408f797bf46aa8a9617a68c8ba913",
262
+ "e643c45c934d414d9b46abdc64eccbbc",
263
+ "f7972177ace147fbb45918eebe106915",
264
+ "aa8e9ad5e1ce4b84b49c194c61f90820",
265
+ "93c35117cea74e5b9de2b871168b7095",
266
+ "09542f17885d4ce1b6fb5e8682beb6de",
267
+ "e1fce70e1a67446982a09c9d4948b48d",
268
+ "7712b1987ee143fe9ceb6ef13bded85d",
269
+ "51f19ba89e5f4d208564f03ff6f2b0da"
270
+ ]
271
+ },
272
+ "id": "HB1hGkNLy9rG",
273
+ "outputId": "365046b5-635d-402d-9c81-3fec9d05ac64"
274
+ },
275
+ "execution_count": 11,
276
+ "outputs": [
277
+ {
278
+ "output_type": "display_data",
279
+ "data": {
280
+ "text/plain": [
281
+ " 0%| | 0/1 [00:00<?, ?ba/s]"
282
+ ],
283
+ "application/vnd.jupyter.widget-view+json": {
284
+ "version_major": 2,
285
+ "version_minor": 0,
286
+ "model_id": "e397276a3494413a8537fa5ba48f7015"
287
+ }
288
+ },
289
+ "metadata": {}
290
+ },
291
+ {
292
+ "output_type": "display_data",
293
+ "data": {
294
+ "text/plain": [
295
+ " 0%| | 0/1 [00:00<?, ?ba/s]"
296
+ ],
297
+ "application/vnd.jupyter.widget-view+json": {
298
+ "version_major": 2,
299
+ "version_minor": 0,
300
+ "model_id": "b2eba59325544c0790fea0b5a08916b1"
301
+ }
302
+ },
303
+ "metadata": {}
304
+ }
305
+ ]
306
+ },
307
+ {
308
+ "cell_type": "code",
309
+ "source": [
310
+ "tokenized_datasets"
311
+ ],
312
+ "metadata": {
313
+ "colab": {
314
+ "base_uri": "https://localhost:8080/"
315
+ },
316
+ "id": "ZXN4Vn-l3dYd",
317
+ "outputId": "5fcee721-01eb-48c4-b111-af15c88b5078"
318
+ },
319
+ "execution_count": 12,
320
+ "outputs": [
321
+ {
322
+ "output_type": "execute_result",
323
+ "data": {
324
+ "text/plain": [
325
+ "DatasetDict({\n",
326
+ " train: Dataset({\n",
327
+ " features: ['label', 'text', 'input_ids', 'attention_mask'],\n",
328
+ " num_rows: 654\n",
329
+ " })\n",
330
+ " test: Dataset({\n",
331
+ " features: ['label', 'text', 'input_ids', 'attention_mask'],\n",
332
+ " num_rows: 281\n",
333
+ " })\n",
334
+ "})"
335
+ ]
336
+ },
337
+ "metadata": {},
338
+ "execution_count": 12
339
+ }
340
+ ]
341
+ },
342
+ {
343
+ "cell_type": "code",
344
+ "source": [
345
+ "small_train_dataset = tokenized_datasets[\"train\"]\n",
346
+ "small_eval_dataset = tokenized_datasets[\"test\"]"
347
+ ],
348
+ "metadata": {
349
+ "id": "PuYI3C5vyIdB"
350
+ },
351
+ "execution_count": 13,
352
+ "outputs": []
353
+ },
354
+ {
355
+ "cell_type": "code",
356
+ "source": [
357
+ "from transformers import TrainingArguments\n",
358
+ "\n",
359
+ "training_args = TrainingArguments(output_dir=\"test_trainer\")"
360
+ ],
361
+ "metadata": {
362
+ "id": "d9jCOTqy3mJl"
363
+ },
364
+ "execution_count": 14,
365
+ "outputs": []
366
+ },
367
+ {
368
+ "cell_type": "code",
369
+ "source": [
370
+ "import numpy as np\n",
371
+ "from datasets import load_metric\n",
372
+ "\n",
373
+ "metric = load_metric(\"accuracy\")\n"
374
+ ],
375
+ "metadata": {
376
+ "id": "SSmBwKCy3wO9"
377
+ },
378
+ "execution_count": 15,
379
+ "outputs": []
380
+ },
381
+ {
382
+ "cell_type": "code",
383
+ "source": [
384
+ "def compute_metrics(eval_pred):\n",
385
+ " logits, labels = eval_pred\n",
386
+ " predictions = np.argmax(logits, axis=-1)\n",
387
+ " return metric.compute(predictions=predictions, references=labels)"
388
+ ],
389
+ "metadata": {
390
+ "id": "8dOc3PsL30Q4"
391
+ },
392
+ "execution_count": 16,
393
+ "outputs": []
394
+ },
395
+ {
396
+ "cell_type": "code",
397
+ "source": [
398
+ "from transformers import TrainingArguments, Trainer\n",
399
+ "\n",
400
+ "training_args = TrainingArguments(output_dir=\"test_trainer\", evaluation_strategy=\"epoch\")"
401
+ ],
402
+ "metadata": {
403
+ "id": "-P9jBD1U37PZ"
404
+ },
405
+ "execution_count": 17,
406
+ "outputs": []
407
+ },
408
+ {
409
+ "cell_type": "code",
410
+ "source": [
411
+ "trainer = Trainer(\n",
412
+ " model=model,\n",
413
+ " args=training_args,\n",
414
+ " train_dataset=small_train_dataset,\n",
415
+ " eval_dataset=small_eval_dataset,\n",
416
+ " compute_metrics=compute_metrics,\n",
417
+ ")"
418
+ ],
419
+ "metadata": {
420
+ "id": "wuKk1EvI4DSo"
421
+ },
422
+ "execution_count": 18,
423
+ "outputs": []
424
+ },
425
+ {
426
+ "cell_type": "code",
427
+ "source": [
428
+ "trainer.train()"
429
+ ],
430
+ "metadata": {
431
+ "colab": {
432
+ "base_uri": "https://localhost:8080/",
433
+ "height": 675
434
+ },
435
+ "id": "CCTgG2cl4GZb",
436
+ "outputId": "ebe231ae-ad4e-4a44-eb0a-070ff395d6e9"
437
+ },
438
+ "execution_count": 19,
439
+ "outputs": [
440
+ {
441
+ "output_type": "stream",
442
+ "name": "stderr",
443
+ "text": [
444
+ "The following columns in the training set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `DistilBertForSequenceClassification.forward`, you can safely ignore this message.\n",
445
+ "/usr/local/lib/python3.7/dist-packages/transformers/optimization.py:309: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
446
+ " FutureWarning,\n",
447
+ "***** Running training *****\n",
448
+ " Num examples = 654\n",
449
+ " Num Epochs = 3\n",
450
+ " Instantaneous batch size per device = 8\n",
451
+ " Total train batch size (w. parallel, distributed & accumulation) = 8\n",
452
+ " Gradient Accumulation steps = 1\n",
453
+ " Total optimization steps = 246\n"
454
+ ]
455
+ },
456
+ {
457
+ "output_type": "display_data",
458
+ "data": {
459
+ "text/plain": [
460
+ "<IPython.core.display.HTML object>"
461
+ ],
462
+ "text/html": [
463
+ "\n",
464
+ " <div>\n",
465
+ " \n",
466
+ " <progress value='246' max='246' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
467
+ " [246/246 01:46, Epoch 3/3]\n",
468
+ " </div>\n",
469
+ " <table border=\"1\" class=\"dataframe\">\n",
470
+ " <thead>\n",
471
+ " <tr style=\"text-align: left;\">\n",
472
+ " <th>Epoch</th>\n",
473
+ " <th>Training Loss</th>\n",
474
+ " <th>Validation Loss</th>\n",
475
+ " <th>Accuracy</th>\n",
476
+ " </tr>\n",
477
+ " </thead>\n",
478
+ " <tbody>\n",
479
+ " <tr>\n",
480
+ " <td>1</td>\n",
481
+ " <td>No log</td>\n",
482
+ " <td>2.471577</td>\n",
483
+ " <td>0.274021</td>\n",
484
+ " </tr>\n",
485
+ " <tr>\n",
486
+ " <td>2</td>\n",
487
+ " <td>No log</td>\n",
488
+ " <td>2.191889</td>\n",
489
+ " <td>0.437722</td>\n",
490
+ " </tr>\n",
491
+ " <tr>\n",
492
+ " <td>3</td>\n",
493
+ " <td>No log</td>\n",
494
+ " <td>2.077948</td>\n",
495
+ " <td>0.473310</td>\n",
496
+ " </tr>\n",
497
+ " </tbody>\n",
498
+ "</table><p>"
499
+ ]
500
+ },
501
+ "metadata": {}
502
+ },
503
+ {
504
+ "output_type": "stream",
505
+ "name": "stderr",
506
+ "text": [
507
+ "The following columns in the evaluation set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `DistilBertForSequenceClassification.forward`, you can safely ignore this message.\n",
508
+ "***** Running Evaluation *****\n",
509
+ " Num examples = 281\n",
510
+ " Batch size = 8\n",
511
+ "The following columns in the evaluation set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `DistilBertForSequenceClassification.forward`, you can safely ignore this message.\n",
512
+ "***** Running Evaluation *****\n",
513
+ " Num examples = 281\n",
514
+ " Batch size = 8\n",
515
+ "The following columns in the evaluation set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `DistilBertForSequenceClassification.forward`, you can safely ignore this message.\n",
516
+ "***** Running Evaluation *****\n",
517
+ " Num examples = 281\n",
518
+ " Batch size = 8\n",
519
+ "\n",
520
+ "\n",
521
+ "Training completed. Do not forget to share your model on huggingface.co/models =)\n",
522
+ "\n",
523
+ "\n"
524
+ ]
525
+ },
526
+ {
527
+ "output_type": "execute_result",
528
+ "data": {
529
+ "text/plain": [
530
+ "TrainOutput(global_step=246, training_loss=2.312268264894563, metrics={'train_runtime': 106.6333, 'train_samples_per_second': 18.4, 'train_steps_per_second': 2.307, 'total_flos': 259975195619328.0, 'train_loss': 2.312268264894563, 'epoch': 3.0})"
531
+ ]
532
+ },
533
+ "metadata": {},
534
+ "execution_count": 19
535
+ }
536
+ ]
537
+ },
538
+ {
539
+ "cell_type": "code",
540
+ "source": [
541
+ "model.save_pretrained(\"./model\")"
542
+ ],
543
+ "metadata": {
544
+ "colab": {
545
+ "base_uri": "https://localhost:8080/"
546
+ },
547
+ "id": "svA3ZvuKW6LA",
548
+ "outputId": "ca4674d4-92c2-4c9a-a151-1d093ddbf954"
549
+ },
550
+ "execution_count": 20,
551
+ "outputs": [
552
+ {
553
+ "output_type": "stream",
554
+ "name": "stderr",
555
+ "text": [
556
+ "Configuration saved in ./config.json\n",
557
+ "Model weights saved in ./pytorch_model.bin\n"
558
+ ]
559
+ }
560
+ ]
561
+ },
562
+ {
563
+ "cell_type": "code",
564
+ "source": [
565
+ "model2 = AutoModelForSequenceClassification.from_pretrained('./model')"
566
+ ],
567
+ "metadata": {
568
+ "colab": {
569
+ "base_uri": "https://localhost:8080/"
570
+ },
571
+ "id": "rHGVDBRc3eqe",
572
+ "outputId": "29d1d861-c58a-41a0-8179-19abc44efac1"
573
+ },
574
+ "execution_count": 21,
575
+ "outputs": [
576
+ {
577
+ "output_type": "stream",
578
+ "name": "stderr",
579
+ "text": [
580
+ "loading configuration file ./config.json\n",
581
+ "Model config DistilBertConfig {\n",
582
+ " \"_name_or_path\": \".\",\n",
583
+ " \"activation\": \"gelu\",\n",
584
+ " \"architectures\": [\n",
585
+ " \"DistilBertForSequenceClassification\"\n",
586
+ " ],\n",
587
+ " \"attention_dropout\": 0.1,\n",
588
+ " \"dim\": 768,\n",
589
+ " \"dropout\": 0.1,\n",
590
+ " \"hidden_dim\": 3072,\n",
591
+ " \"id2label\": {\n",
592
+ " \"0\": \"Bug\",\n",
593
+ " \"1\": \"Dark\",\n",
594
+ " \"2\": \"Dragon\",\n",
595
+ " \"3\": \"Electric\",\n",
596
+ " \"4\": \"Fairy\",\n",
597
+ " \"5\": \"Fighting\",\n",
598
+ " \"6\": \"Fire\",\n",
599
+ " \"7\": \"Flying\",\n",
600
+ " \"8\": \"Ghost\",\n",
601
+ " \"9\": \"Grass\",\n",
602
+ " \"10\": \"Ground\",\n",
603
+ " \"11\": \"Ice\",\n",
604
+ " \"12\": \"Normal\",\n",
605
+ " \"13\": \"Poison\",\n",
606
+ " \"14\": \"Psychic\",\n",
607
+ " \"15\": \"Rock\",\n",
608
+ " \"16\": \"Steel\",\n",
609
+ " \"17\": \"Water\"\n",
610
+ " },\n",
611
+ " \"initializer_range\": 0.02,\n",
612
+ " \"label2id\": {\n",
613
+ " \"Bug\": 0,\n",
614
+ " \"Dark\": 1,\n",
615
+ " \"Dragon\": 2,\n",
616
+ " \"Electric\": 3,\n",
617
+ " \"Fairy\": 4,\n",
618
+ " \"Fighting\": 5,\n",
619
+ " \"Fire\": 6,\n",
620
+ " \"Flying\": 7,\n",
621
+ " \"Ghost\": 8,\n",
622
+ " \"Grass\": 9,\n",
623
+ " \"Ground\": 10,\n",
624
+ " \"Ice\": 11,\n",
625
+ " \"Normal\": 12,\n",
626
+ " \"Poison\": 13,\n",
627
+ " \"Psychic\": 14,\n",
628
+ " \"Rock\": 15,\n",
629
+ " \"Steel\": 16,\n",
630
+ " \"Water\": 17\n",
631
+ " },\n",
632
+ " \"max_position_embeddings\": 512,\n",
633
+ " \"model_type\": \"distilbert\",\n",
634
+ " \"n_heads\": 12,\n",
635
+ " \"n_layers\": 6,\n",
636
+ " \"pad_token_id\": 0,\n",
637
+ " \"problem_type\": \"single_label_classification\",\n",
638
+ " \"qa_dropout\": 0.1,\n",
639
+ " \"seq_classif_dropout\": 0.2,\n",
640
+ " \"sinusoidal_pos_embds\": false,\n",
641
+ " \"tie_weights_\": true,\n",
642
+ " \"torch_dtype\": \"float32\",\n",
643
+ " \"transformers_version\": \"4.19.2\",\n",
644
+ " \"vocab_size\": 30522\n",
645
+ "}\n",
646
+ "\n",
647
+ "loading weights file ./pytorch_model.bin\n",
648
+ "All model checkpoint weights were used when initializing DistilBertForSequenceClassification.\n",
649
+ "\n",
650
+ "All the weights of DistilBertForSequenceClassification were initialized from the model checkpoint at ..\n",
651
+ "If your task is similar to the task the model of the checkpoint was trained on, you can already use DistilBertForSequenceClassification for predictions without further training.\n"
652
+ ]
653
+ }
654
+ ]
655
+ },
656
+ {
657
+ "cell_type": "code",
658
+ "source": [
659
+ "from transformers import pipeline\n",
660
+ "\n",
661
+ "classifier = pipeline(task=\"text-classification\", tokenizer=tokenizer, model=model2.to('cpu'))"
662
+ ],
663
+ "metadata": {
664
+ "id": "yP--Matd4WCe"
665
+ },
666
+ "execution_count": 22,
667
+ "outputs": []
668
+ },
669
+ {
670
+ "cell_type": "code",
671
+ "source": [
672
+ "classifier('This pokemon climbs buildings at night.')"
673
+ ],
674
+ "metadata": {
675
+ "colab": {
676
+ "base_uri": "https://localhost:8080/"
677
+ },
678
+ "id": "qVAlioYj6jBO",
679
+ "outputId": "c820c497-da3e-4e1e-c1d0-ef2494bffb03"
680
+ },
681
+ "execution_count": 41,
682
+ "outputs": [
683
+ {
684
+ "output_type": "execute_result",
685
+ "data": {
686
+ "text/plain": [
687
+ "[{'label': 'Bug', 'score': 0.17771221697330475}]"
688
+ ]
689
+ },
690
+ "metadata": {},
691
+ "execution_count": 41
692
+ }
693
+ ]
694
+ },
695
+ {
696
+ "cell_type": "code",
697
+ "source": [
698
+ "classifier('This pokemon climbs buildings at night. They frequent midnight pool parties')"
699
+ ],
700
+ "metadata": {
701
+ "colab": {
702
+ "base_uri": "https://localhost:8080/"
703
+ },
704
+ "id": "3AJcXRR54Yk4",
705
+ "outputId": "5bd7cdd1-fd34-483f-d30c-560d54a28493"
706
+ },
707
+ "execution_count": 36,
708
+ "outputs": [
709
+ {
710
+ "output_type": "execute_result",
711
+ "data": {
712
+ "text/plain": [
713
+ "[{'label': 'Water', 'score': 0.4050225019454956}]"
714
+ ]
715
+ },
716
+ "metadata": {},
717
+ "execution_count": 36
718
+ }
719
+ ]
720
+ },
721
+ {
722
+ "cell_type": "code",
723
+ "source": [
724
+ "classifier('This pokemon climbs buildings at night. They frequent midnight garden parties')"
725
+ ],
726
+ "metadata": {
727
+ "colab": {
728
+ "base_uri": "https://localhost:8080/"
729
+ },
730
+ "id": "kJ54PeHz4pbO",
731
+ "outputId": "3e1571bd-6a0c-4800-ab4d-60835cb4f3a5"
732
+ },
733
+ "execution_count": 37,
734
+ "outputs": [
735
+ {
736
+ "output_type": "execute_result",
737
+ "data": {
738
+ "text/plain": [
739
+ "[{'label': 'Grass', 'score': 0.38808730244636536}]"
740
+ ]
741
+ },
742
+ "metadata": {},
743
+ "execution_count": 37
744
+ }
745
+ ]
746
+ },
747
+ {
748
+ "cell_type": "code",
749
+ "source": [
750
+ "classifier('This pokemon climbs buildings at night. They frequent midnight flame-throwing parties')"
751
+ ],
752
+ "metadata": {
753
+ "colab": {
754
+ "base_uri": "https://localhost:8080/"
755
+ },
756
+ "id": "GXVKzk-N8_So",
757
+ "outputId": "61e50b07-f3a5-4ff9-c152-1ea2d846dcfe"
758
+ },
759
+ "execution_count": 38,
760
+ "outputs": [
761
+ {
762
+ "output_type": "execute_result",
763
+ "data": {
764
+ "text/plain": [
765
+ "[{'label': 'Fire', 'score': 0.22531799972057343}]"
766
+ ]
767
+ },
768
+ "metadata": {},
769
+ "execution_count": 38
770
+ }
771
+ ]
772
+ },
773
+ {
774
+ "cell_type": "code",
775
+ "source": [
776
+ ""
777
+ ],
778
+ "metadata": {
779
+ "id": "idRtKyjM9KIE"
780
+ },
781
+ "execution_count": null,
782
+ "outputs": []
783
+ }
784
+ ],
785
+ "metadata": {
786
+ "colab": {
787
+ "name": "nlp.ipynb",
788
+ "provenance": []
789
+ },
790
+ "interpreter": {
791
+ "hash": "45e1260056979d5382785f386f12ee00f44622d9a136ee7663e9a61a67ca2a68"
792
+ },
793
+ "kernelspec": {
794
+ "display_name": "Python 3.10.0 ('projects-vBrzsZbN-py3.10')",
795
+ "language": "python",
796
+ "name": "python3"
797
+ },
798
+ "language_info": {
799
+ "codemirror_mode": {
800
+ "name": "ipython",
801
+ "version": 3
802
+ },
803
+ "file_extension": ".py",
804
+ "mimetype": "text/x-python",
805
+ "name": "python",
806
+ "nbconvert_exporter": "python",
807
+ "pygments_lexer": "ipython3",
808
+ "version": "3.10.0"
809
+ },
810
+ "orig_nbformat": 4,
811
+ "accelerator": "GPU",
812
+ "widgets": {
813
+ "application/vnd.jupyter.widget-state+json": {
814
+ "e397276a3494413a8537fa5ba48f7015": {
815
+ "model_module": "@jupyter-widgets/controls",
816
+ "model_name": "HBoxModel",
817
+ "model_module_version": "1.5.0",
818
+ "state": {
819
+ "_dom_classes": [],
820
+ "_model_module": "@jupyter-widgets/controls",
821
+ "_model_module_version": "1.5.0",
822
+ "_model_name": "HBoxModel",
823
+ "_view_count": null,
824
+ "_view_module": "@jupyter-widgets/controls",
825
+ "_view_module_version": "1.5.0",
826
+ "_view_name": "HBoxView",
827
+ "box_style": "",
828
+ "children": [
829
+ "IPY_MODEL_3681f628eb844524bd4a121dd313d2ef",
830
+ "IPY_MODEL_4699b0136d4d449ca4accf9bc90fc45a",
831
+ "IPY_MODEL_0435b641146f45488248e042ff0d4f31"
832
+ ],
833
+ "layout": "IPY_MODEL_8233bb294c7549b6a8ae5e140e8ca5b6"
834
+ }
835
+ },
836
+ "3681f628eb844524bd4a121dd313d2ef": {
837
+ "model_module": "@jupyter-widgets/controls",
838
+ "model_name": "HTMLModel",
839
+ "model_module_version": "1.5.0",
840
+ "state": {
841
+ "_dom_classes": [],
842
+ "_model_module": "@jupyter-widgets/controls",
843
+ "_model_module_version": "1.5.0",
844
+ "_model_name": "HTMLModel",
845
+ "_view_count": null,
846
+ "_view_module": "@jupyter-widgets/controls",
847
+ "_view_module_version": "1.5.0",
848
+ "_view_name": "HTMLView",
849
+ "description": "",
850
+ "description_tooltip": null,
851
+ "layout": "IPY_MODEL_608d0d01fe4d4a14877920e51b10233a",
852
+ "placeholder": "​",
853
+ "style": "IPY_MODEL_246696956b774368ad7572db606a2414",
854
+ "value": "100%"
855
+ }
856
+ },
857
+ "4699b0136d4d449ca4accf9bc90fc45a": {
858
+ "model_module": "@jupyter-widgets/controls",
859
+ "model_name": "FloatProgressModel",
860
+ "model_module_version": "1.5.0",
861
+ "state": {
862
+ "_dom_classes": [],
863
+ "_model_module": "@jupyter-widgets/controls",
864
+ "_model_module_version": "1.5.0",
865
+ "_model_name": "FloatProgressModel",
866
+ "_view_count": null,
867
+ "_view_module": "@jupyter-widgets/controls",
868
+ "_view_module_version": "1.5.0",
869
+ "_view_name": "ProgressView",
870
+ "bar_style": "success",
871
+ "description": "",
872
+ "description_tooltip": null,
873
+ "layout": "IPY_MODEL_cd1d0bd35b464f74aff7c7421de03fc9",
874
+ "max": 1,
875
+ "min": 0,
876
+ "orientation": "horizontal",
877
+ "style": "IPY_MODEL_eac0a76ba4cf4b61ae82e9aa2f770dce",
878
+ "value": 1
879
+ }
880
+ },
881
+ "0435b641146f45488248e042ff0d4f31": {
882
+ "model_module": "@jupyter-widgets/controls",
883
+ "model_name": "HTMLModel",
884
+ "model_module_version": "1.5.0",
885
+ "state": {
886
+ "_dom_classes": [],
887
+ "_model_module": "@jupyter-widgets/controls",
888
+ "_model_module_version": "1.5.0",
889
+ "_model_name": "HTMLModel",
890
+ "_view_count": null,
891
+ "_view_module": "@jupyter-widgets/controls",
892
+ "_view_module_version": "1.5.0",
893
+ "_view_name": "HTMLView",
894
+ "description": "",
895
+ "description_tooltip": null,
896
+ "layout": "IPY_MODEL_68e198c9120f4451bd9b9f033a9d4bae",
897
+ "placeholder": "​",
898
+ "style": "IPY_MODEL_bd8b8f2e781e4de1a477c8af4c450d1f",
899
+ "value": " 1/1 [00:00&lt;00:00, 2.22ba/s]"
900
+ }
901
+ },
902
+ "8233bb294c7549b6a8ae5e140e8ca5b6": {
903
+ "model_module": "@jupyter-widgets/base",
904
+ "model_name": "LayoutModel",
905
+ "model_module_version": "1.2.0",
906
+ "state": {
907
+ "_model_module": "@jupyter-widgets/base",
908
+ "_model_module_version": "1.2.0",
909
+ "_model_name": "LayoutModel",
910
+ "_view_count": null,
911
+ "_view_module": "@jupyter-widgets/base",
912
+ "_view_module_version": "1.2.0",
913
+ "_view_name": "LayoutView",
914
+ "align_content": null,
915
+ "align_items": null,
916
+ "align_self": null,
917
+ "border": null,
918
+ "bottom": null,
919
+ "display": null,
920
+ "flex": null,
921
+ "flex_flow": null,
922
+ "grid_area": null,
923
+ "grid_auto_columns": null,
924
+ "grid_auto_flow": null,
925
+ "grid_auto_rows": null,
926
+ "grid_column": null,
927
+ "grid_gap": null,
928
+ "grid_row": null,
929
+ "grid_template_areas": null,
930
+ "grid_template_columns": null,
931
+ "grid_template_rows": null,
932
+ "height": null,
933
+ "justify_content": null,
934
+ "justify_items": null,
935
+ "left": null,
936
+ "margin": null,
937
+ "max_height": null,
938
+ "max_width": null,
939
+ "min_height": null,
940
+ "min_width": null,
941
+ "object_fit": null,
942
+ "object_position": null,
943
+ "order": null,
944
+ "overflow": null,
945
+ "overflow_x": null,
946
+ "overflow_y": null,
947
+ "padding": null,
948
+ "right": null,
949
+ "top": null,
950
+ "visibility": null,
951
+ "width": null
952
+ }
953
+ },
954
+ "608d0d01fe4d4a14877920e51b10233a": {
955
+ "model_module": "@jupyter-widgets/base",
956
+ "model_name": "LayoutModel",
957
+ "model_module_version": "1.2.0",
958
+ "state": {
959
+ "_model_module": "@jupyter-widgets/base",
960
+ "_model_module_version": "1.2.0",
961
+ "_model_name": "LayoutModel",
962
+ "_view_count": null,
963
+ "_view_module": "@jupyter-widgets/base",
964
+ "_view_module_version": "1.2.0",
965
+ "_view_name": "LayoutView",
966
+ "align_content": null,
967
+ "align_items": null,
968
+ "align_self": null,
969
+ "border": null,
970
+ "bottom": null,
971
+ "display": null,
972
+ "flex": null,
973
+ "flex_flow": null,
974
+ "grid_area": null,
975
+ "grid_auto_columns": null,
976
+ "grid_auto_flow": null,
977
+ "grid_auto_rows": null,
978
+ "grid_column": null,
979
+ "grid_gap": null,
980
+ "grid_row": null,
981
+ "grid_template_areas": null,
982
+ "grid_template_columns": null,
983
+ "grid_template_rows": null,
984
+ "height": null,
985
+ "justify_content": null,
986
+ "justify_items": null,
987
+ "left": null,
988
+ "margin": null,
989
+ "max_height": null,
990
+ "max_width": null,
991
+ "min_height": null,
992
+ "min_width": null,
993
+ "object_fit": null,
994
+ "object_position": null,
995
+ "order": null,
996
+ "overflow": null,
997
+ "overflow_x": null,
998
+ "overflow_y": null,
999
+ "padding": null,
1000
+ "right": null,
1001
+ "top": null,
1002
+ "visibility": null,
1003
+ "width": null
1004
+ }
1005
+ },
1006
+ "246696956b774368ad7572db606a2414": {
1007
+ "model_module": "@jupyter-widgets/controls",
1008
+ "model_name": "DescriptionStyleModel",
1009
+ "model_module_version": "1.5.0",
1010
+ "state": {
1011
+ "_model_module": "@jupyter-widgets/controls",
1012
+ "_model_module_version": "1.5.0",
1013
+ "_model_name": "DescriptionStyleModel",
1014
+ "_view_count": null,
1015
+ "_view_module": "@jupyter-widgets/base",
1016
+ "_view_module_version": "1.2.0",
1017
+ "_view_name": "StyleView",
1018
+ "description_width": ""
1019
+ }
1020
+ },
1021
+ "cd1d0bd35b464f74aff7c7421de03fc9": {
1022
+ "model_module": "@jupyter-widgets/base",
1023
+ "model_name": "LayoutModel",
1024
+ "model_module_version": "1.2.0",
1025
+ "state": {
1026
+ "_model_module": "@jupyter-widgets/base",
1027
+ "_model_module_version": "1.2.0",
1028
+ "_model_name": "LayoutModel",
1029
+ "_view_count": null,
1030
+ "_view_module": "@jupyter-widgets/base",
1031
+ "_view_module_version": "1.2.0",
1032
+ "_view_name": "LayoutView",
1033
+ "align_content": null,
1034
+ "align_items": null,
1035
+ "align_self": null,
1036
+ "border": null,
1037
+ "bottom": null,
1038
+ "display": null,
1039
+ "flex": null,
1040
+ "flex_flow": null,
1041
+ "grid_area": null,
1042
+ "grid_auto_columns": null,
1043
+ "grid_auto_flow": null,
1044
+ "grid_auto_rows": null,
1045
+ "grid_column": null,
1046
+ "grid_gap": null,
1047
+ "grid_row": null,
1048
+ "grid_template_areas": null,
1049
+ "grid_template_columns": null,
1050
+ "grid_template_rows": null,
1051
+ "height": null,
1052
+ "justify_content": null,
1053
+ "justify_items": null,
1054
+ "left": null,
1055
+ "margin": null,
1056
+ "max_height": null,
1057
+ "max_width": null,
1058
+ "min_height": null,
1059
+ "min_width": null,
1060
+ "object_fit": null,
1061
+ "object_position": null,
1062
+ "order": null,
1063
+ "overflow": null,
1064
+ "overflow_x": null,
1065
+ "overflow_y": null,
1066
+ "padding": null,
1067
+ "right": null,
1068
+ "top": null,
1069
+ "visibility": null,
1070
+ "width": null
1071
+ }
1072
+ },
1073
+ "eac0a76ba4cf4b61ae82e9aa2f770dce": {
1074
+ "model_module": "@jupyter-widgets/controls",
1075
+ "model_name": "ProgressStyleModel",
1076
+ "model_module_version": "1.5.0",
1077
+ "state": {
1078
+ "_model_module": "@jupyter-widgets/controls",
1079
+ "_model_module_version": "1.5.0",
1080
+ "_model_name": "ProgressStyleModel",
1081
+ "_view_count": null,
1082
+ "_view_module": "@jupyter-widgets/base",
1083
+ "_view_module_version": "1.2.0",
1084
+ "_view_name": "StyleView",
1085
+ "bar_color": null,
1086
+ "description_width": ""
1087
+ }
1088
+ },
1089
+ "68e198c9120f4451bd9b9f033a9d4bae": {
1090
+ "model_module": "@jupyter-widgets/base",
1091
+ "model_name": "LayoutModel",
1092
+ "model_module_version": "1.2.0",
1093
+ "state": {
1094
+ "_model_module": "@jupyter-widgets/base",
1095
+ "_model_module_version": "1.2.0",
1096
+ "_model_name": "LayoutModel",
1097
+ "_view_count": null,
1098
+ "_view_module": "@jupyter-widgets/base",
1099
+ "_view_module_version": "1.2.0",
1100
+ "_view_name": "LayoutView",
1101
+ "align_content": null,
1102
+ "align_items": null,
1103
+ "align_self": null,
1104
+ "border": null,
1105
+ "bottom": null,
1106
+ "display": null,
1107
+ "flex": null,
1108
+ "flex_flow": null,
1109
+ "grid_area": null,
1110
+ "grid_auto_columns": null,
1111
+ "grid_auto_flow": null,
1112
+ "grid_auto_rows": null,
1113
+ "grid_column": null,
1114
+ "grid_gap": null,
1115
+ "grid_row": null,
1116
+ "grid_template_areas": null,
1117
+ "grid_template_columns": null,
1118
+ "grid_template_rows": null,
1119
+ "height": null,
1120
+ "justify_content": null,
1121
+ "justify_items": null,
1122
+ "left": null,
1123
+ "margin": null,
1124
+ "max_height": null,
1125
+ "max_width": null,
1126
+ "min_height": null,
1127
+ "min_width": null,
1128
+ "object_fit": null,
1129
+ "object_position": null,
1130
+ "order": null,
1131
+ "overflow": null,
1132
+ "overflow_x": null,
1133
+ "overflow_y": null,
1134
+ "padding": null,
1135
+ "right": null,
1136
+ "top": null,
1137
+ "visibility": null,
1138
+ "width": null
1139
+ }
1140
+ },
1141
+ "bd8b8f2e781e4de1a477c8af4c450d1f": {
1142
+ "model_module": "@jupyter-widgets/controls",
1143
+ "model_name": "DescriptionStyleModel",
1144
+ "model_module_version": "1.5.0",
1145
+ "state": {
1146
+ "_model_module": "@jupyter-widgets/controls",
1147
+ "_model_module_version": "1.5.0",
1148
+ "_model_name": "DescriptionStyleModel",
1149
+ "_view_count": null,
1150
+ "_view_module": "@jupyter-widgets/base",
1151
+ "_view_module_version": "1.2.0",
1152
+ "_view_name": "StyleView",
1153
+ "description_width": ""
1154
+ }
1155
+ },
1156
+ "b2eba59325544c0790fea0b5a08916b1": {
1157
+ "model_module": "@jupyter-widgets/controls",
1158
+ "model_name": "HBoxModel",
1159
+ "model_module_version": "1.5.0",
1160
+ "state": {
1161
+ "_dom_classes": [],
1162
+ "_model_module": "@jupyter-widgets/controls",
1163
+ "_model_module_version": "1.5.0",
1164
+ "_model_name": "HBoxModel",
1165
+ "_view_count": null,
1166
+ "_view_module": "@jupyter-widgets/controls",
1167
+ "_view_module_version": "1.5.0",
1168
+ "_view_name": "HBoxView",
1169
+ "box_style": "",
1170
+ "children": [
1171
+ "IPY_MODEL_deeb1a85968d4619a7928712c09168d2",
1172
+ "IPY_MODEL_39d408f797bf46aa8a9617a68c8ba913",
1173
+ "IPY_MODEL_e643c45c934d414d9b46abdc64eccbbc"
1174
+ ],
1175
+ "layout": "IPY_MODEL_f7972177ace147fbb45918eebe106915"
1176
+ }
1177
+ },
1178
+ "deeb1a85968d4619a7928712c09168d2": {
1179
+ "model_module": "@jupyter-widgets/controls",
1180
+ "model_name": "HTMLModel",
1181
+ "model_module_version": "1.5.0",
1182
+ "state": {
1183
+ "_dom_classes": [],
1184
+ "_model_module": "@jupyter-widgets/controls",
1185
+ "_model_module_version": "1.5.0",
1186
+ "_model_name": "HTMLModel",
1187
+ "_view_count": null,
1188
+ "_view_module": "@jupyter-widgets/controls",
1189
+ "_view_module_version": "1.5.0",
1190
+ "_view_name": "HTMLView",
1191
+ "description": "",
1192
+ "description_tooltip": null,
1193
+ "layout": "IPY_MODEL_aa8e9ad5e1ce4b84b49c194c61f90820",
1194
+ "placeholder": "​",
1195
+ "style": "IPY_MODEL_93c35117cea74e5b9de2b871168b7095",
1196
+ "value": "100%"
1197
+ }
1198
+ },
1199
+ "39d408f797bf46aa8a9617a68c8ba913": {
1200
+ "model_module": "@jupyter-widgets/controls",
1201
+ "model_name": "FloatProgressModel",
1202
+ "model_module_version": "1.5.0",
1203
+ "state": {
1204
+ "_dom_classes": [],
1205
+ "_model_module": "@jupyter-widgets/controls",
1206
+ "_model_module_version": "1.5.0",
1207
+ "_model_name": "FloatProgressModel",
1208
+ "_view_count": null,
1209
+ "_view_module": "@jupyter-widgets/controls",
1210
+ "_view_module_version": "1.5.0",
1211
+ "_view_name": "ProgressView",
1212
+ "bar_style": "success",
1213
+ "description": "",
1214
+ "description_tooltip": null,
1215
+ "layout": "IPY_MODEL_09542f17885d4ce1b6fb5e8682beb6de",
1216
+ "max": 1,
1217
+ "min": 0,
1218
+ "orientation": "horizontal",
1219
+ "style": "IPY_MODEL_e1fce70e1a67446982a09c9d4948b48d",
1220
+ "value": 1
1221
+ }
1222
+ },
1223
+ "e643c45c934d414d9b46abdc64eccbbc": {
1224
+ "model_module": "@jupyter-widgets/controls",
1225
+ "model_name": "HTMLModel",
1226
+ "model_module_version": "1.5.0",
1227
+ "state": {
1228
+ "_dom_classes": [],
1229
+ "_model_module": "@jupyter-widgets/controls",
1230
+ "_model_module_version": "1.5.0",
1231
+ "_model_name": "HTMLModel",
1232
+ "_view_count": null,
1233
+ "_view_module": "@jupyter-widgets/controls",
1234
+ "_view_module_version": "1.5.0",
1235
+ "_view_name": "HTMLView",
1236
+ "description": "",
1237
+ "description_tooltip": null,
1238
+ "layout": "IPY_MODEL_7712b1987ee143fe9ceb6ef13bded85d",
1239
+ "placeholder": "​",
1240
+ "style": "IPY_MODEL_51f19ba89e5f4d208564f03ff6f2b0da",
1241
+ "value": " 1/1 [00:00&lt;00:00, 3.74ba/s]"
1242
+ }
1243
+ },
1244
+ "f7972177ace147fbb45918eebe106915": {
1245
+ "model_module": "@jupyter-widgets/base",
1246
+ "model_name": "LayoutModel",
1247
+ "model_module_version": "1.2.0",
1248
+ "state": {
1249
+ "_model_module": "@jupyter-widgets/base",
1250
+ "_model_module_version": "1.2.0",
1251
+ "_model_name": "LayoutModel",
1252
+ "_view_count": null,
1253
+ "_view_module": "@jupyter-widgets/base",
1254
+ "_view_module_version": "1.2.0",
1255
+ "_view_name": "LayoutView",
1256
+ "align_content": null,
1257
+ "align_items": null,
1258
+ "align_self": null,
1259
+ "border": null,
1260
+ "bottom": null,
1261
+ "display": null,
1262
+ "flex": null,
1263
+ "flex_flow": null,
1264
+ "grid_area": null,
1265
+ "grid_auto_columns": null,
1266
+ "grid_auto_flow": null,
1267
+ "grid_auto_rows": null,
1268
+ "grid_column": null,
1269
+ "grid_gap": null,
1270
+ "grid_row": null,
1271
+ "grid_template_areas": null,
1272
+ "grid_template_columns": null,
1273
+ "grid_template_rows": null,
1274
+ "height": null,
1275
+ "justify_content": null,
1276
+ "justify_items": null,
1277
+ "left": null,
1278
+ "margin": null,
1279
+ "max_height": null,
1280
+ "max_width": null,
1281
+ "min_height": null,
1282
+ "min_width": null,
1283
+ "object_fit": null,
1284
+ "object_position": null,
1285
+ "order": null,
1286
+ "overflow": null,
1287
+ "overflow_x": null,
1288
+ "overflow_y": null,
1289
+ "padding": null,
1290
+ "right": null,
1291
+ "top": null,
1292
+ "visibility": null,
1293
+ "width": null
1294
+ }
1295
+ },
1296
+ "aa8e9ad5e1ce4b84b49c194c61f90820": {
1297
+ "model_module": "@jupyter-widgets/base",
1298
+ "model_name": "LayoutModel",
1299
+ "model_module_version": "1.2.0",
1300
+ "state": {
1301
+ "_model_module": "@jupyter-widgets/base",
1302
+ "_model_module_version": "1.2.0",
1303
+ "_model_name": "LayoutModel",
1304
+ "_view_count": null,
1305
+ "_view_module": "@jupyter-widgets/base",
1306
+ "_view_module_version": "1.2.0",
1307
+ "_view_name": "LayoutView",
1308
+ "align_content": null,
1309
+ "align_items": null,
1310
+ "align_self": null,
1311
+ "border": null,
1312
+ "bottom": null,
1313
+ "display": null,
1314
+ "flex": null,
1315
+ "flex_flow": null,
1316
+ "grid_area": null,
1317
+ "grid_auto_columns": null,
1318
+ "grid_auto_flow": null,
1319
+ "grid_auto_rows": null,
1320
+ "grid_column": null,
1321
+ "grid_gap": null,
1322
+ "grid_row": null,
1323
+ "grid_template_areas": null,
1324
+ "grid_template_columns": null,
1325
+ "grid_template_rows": null,
1326
+ "height": null,
1327
+ "justify_content": null,
1328
+ "justify_items": null,
1329
+ "left": null,
1330
+ "margin": null,
1331
+ "max_height": null,
1332
+ "max_width": null,
1333
+ "min_height": null,
1334
+ "min_width": null,
1335
+ "object_fit": null,
1336
+ "object_position": null,
1337
+ "order": null,
1338
+ "overflow": null,
1339
+ "overflow_x": null,
1340
+ "overflow_y": null,
1341
+ "padding": null,
1342
+ "right": null,
1343
+ "top": null,
1344
+ "visibility": null,
1345
+ "width": null
1346
+ }
1347
+ },
1348
+ "93c35117cea74e5b9de2b871168b7095": {
1349
+ "model_module": "@jupyter-widgets/controls",
1350
+ "model_name": "DescriptionStyleModel",
1351
+ "model_module_version": "1.5.0",
1352
+ "state": {
1353
+ "_model_module": "@jupyter-widgets/controls",
1354
+ "_model_module_version": "1.5.0",
1355
+ "_model_name": "DescriptionStyleModel",
1356
+ "_view_count": null,
1357
+ "_view_module": "@jupyter-widgets/base",
1358
+ "_view_module_version": "1.2.0",
1359
+ "_view_name": "StyleView",
1360
+ "description_width": ""
1361
+ }
1362
+ },
1363
+ "09542f17885d4ce1b6fb5e8682beb6de": {
1364
+ "model_module": "@jupyter-widgets/base",
1365
+ "model_name": "LayoutModel",
1366
+ "model_module_version": "1.2.0",
1367
+ "state": {
1368
+ "_model_module": "@jupyter-widgets/base",
1369
+ "_model_module_version": "1.2.0",
1370
+ "_model_name": "LayoutModel",
1371
+ "_view_count": null,
1372
+ "_view_module": "@jupyter-widgets/base",
1373
+ "_view_module_version": "1.2.0",
1374
+ "_view_name": "LayoutView",
1375
+ "align_content": null,
1376
+ "align_items": null,
1377
+ "align_self": null,
1378
+ "border": null,
1379
+ "bottom": null,
1380
+ "display": null,
1381
+ "flex": null,
1382
+ "flex_flow": null,
1383
+ "grid_area": null,
1384
+ "grid_auto_columns": null,
1385
+ "grid_auto_flow": null,
1386
+ "grid_auto_rows": null,
1387
+ "grid_column": null,
1388
+ "grid_gap": null,
1389
+ "grid_row": null,
1390
+ "grid_template_areas": null,
1391
+ "grid_template_columns": null,
1392
+ "grid_template_rows": null,
1393
+ "height": null,
1394
+ "justify_content": null,
1395
+ "justify_items": null,
1396
+ "left": null,
1397
+ "margin": null,
1398
+ "max_height": null,
1399
+ "max_width": null,
1400
+ "min_height": null,
1401
+ "min_width": null,
1402
+ "object_fit": null,
1403
+ "object_position": null,
1404
+ "order": null,
1405
+ "overflow": null,
1406
+ "overflow_x": null,
1407
+ "overflow_y": null,
1408
+ "padding": null,
1409
+ "right": null,
1410
+ "top": null,
1411
+ "visibility": null,
1412
+ "width": null
1413
+ }
1414
+ },
1415
+ "e1fce70e1a67446982a09c9d4948b48d": {
1416
+ "model_module": "@jupyter-widgets/controls",
1417
+ "model_name": "ProgressStyleModel",
1418
+ "model_module_version": "1.5.0",
1419
+ "state": {
1420
+ "_model_module": "@jupyter-widgets/controls",
1421
+ "_model_module_version": "1.5.0",
1422
+ "_model_name": "ProgressStyleModel",
1423
+ "_view_count": null,
1424
+ "_view_module": "@jupyter-widgets/base",
1425
+ "_view_module_version": "1.2.0",
1426
+ "_view_name": "StyleView",
1427
+ "bar_color": null,
1428
+ "description_width": ""
1429
+ }
1430
+ },
1431
+ "7712b1987ee143fe9ceb6ef13bded85d": {
1432
+ "model_module": "@jupyter-widgets/base",
1433
+ "model_name": "LayoutModel",
1434
+ "model_module_version": "1.2.0",
1435
+ "state": {
1436
+ "_model_module": "@jupyter-widgets/base",
1437
+ "_model_module_version": "1.2.0",
1438
+ "_model_name": "LayoutModel",
1439
+ "_view_count": null,
1440
+ "_view_module": "@jupyter-widgets/base",
1441
+ "_view_module_version": "1.2.0",
1442
+ "_view_name": "LayoutView",
1443
+ "align_content": null,
1444
+ "align_items": null,
1445
+ "align_self": null,
1446
+ "border": null,
1447
+ "bottom": null,
1448
+ "display": null,
1449
+ "flex": null,
1450
+ "flex_flow": null,
1451
+ "grid_area": null,
1452
+ "grid_auto_columns": null,
1453
+ "grid_auto_flow": null,
1454
+ "grid_auto_rows": null,
1455
+ "grid_column": null,
1456
+ "grid_gap": null,
1457
+ "grid_row": null,
1458
+ "grid_template_areas": null,
1459
+ "grid_template_columns": null,
1460
+ "grid_template_rows": null,
1461
+ "height": null,
1462
+ "justify_content": null,
1463
+ "justify_items": null,
1464
+ "left": null,
1465
+ "margin": null,
1466
+ "max_height": null,
1467
+ "max_width": null,
1468
+ "min_height": null,
1469
+ "min_width": null,
1470
+ "object_fit": null,
1471
+ "object_position": null,
1472
+ "order": null,
1473
+ "overflow": null,
1474
+ "overflow_x": null,
1475
+ "overflow_y": null,
1476
+ "padding": null,
1477
+ "right": null,
1478
+ "top": null,
1479
+ "visibility": null,
1480
+ "width": null
1481
+ }
1482
+ },
1483
+ "51f19ba89e5f4d208564f03ff6f2b0da": {
1484
+ "model_module": "@jupyter-widgets/controls",
1485
+ "model_name": "DescriptionStyleModel",
1486
+ "model_module_version": "1.5.0",
1487
+ "state": {
1488
+ "_model_module": "@jupyter-widgets/controls",
1489
+ "_model_module_version": "1.5.0",
1490
+ "_model_name": "DescriptionStyleModel",
1491
+ "_view_count": null,
1492
+ "_view_module": "@jupyter-widgets/base",
1493
+ "_view_module_version": "1.2.0",
1494
+ "_view_name": "StyleView",
1495
+ "description_width": ""
1496
+ }
1497
+ }
1498
+ }
1499
+ }
1500
+ },
1501
+ "nbformat": 4,
1502
+ "nbformat_minor": 0
1503
+ }
pkmn-classifier/web_scrape.ipynb ADDED
@@ -0,0 +1,476 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import bs4 as BeautifulSoup\n",
10
+ "import pandas as pd\n",
11
+ "import requests"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": 2,
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": [
20
+ "list_of_urls = [\n",
21
+ " \"https://en.wikipedia.org/wiki/List_of_generation_I_Pok%C3%A9mon\",\n",
22
+ " \"https://en.wikipedia.org/wiki/List_of_generation_II_Pok%C3%A9mon\",\n",
23
+ " \"https://en.wikipedia.org/wiki/List_of_generation_III_Pok%C3%A9mon\",\n",
24
+ " \"https://en.wikipedia.org/wiki/List_of_generation_IV_Pok%C3%A9mon\",\n",
25
+ " \"https://en.wikipedia.org/wiki/List_of_generation_V_Pok%C3%A9mon\",\n",
26
+ " \"https://en.wikipedia.org/wiki/List_of_generation_VI_Pok%C3%A9mon\",\n",
27
+ " \"https://en.wikipedia.org/wiki/List_of_generation_VII_Pok%C3%A9mon\",\n",
28
+ " \"https://en.wikipedia.org/wiki/List_of_generation_VIII_Pok%C3%A9mon\",\n",
29
+ " \"https://en.wikipedia.org/wiki/List_of_generation_IX_Pok%C3%A9mon\",\n",
30
+ "]\n",
31
+ "\n",
32
+ "generations = [\"I\", \"II\", \"III\", \"IV\", \"V\", \"VI\", \"VII\", \"VIII\", \"IX\"]"
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "code",
37
+ "execution_count": 3,
38
+ "metadata": {},
39
+ "outputs": [
40
+ {
41
+ "name": "stdout",
42
+ "output_type": "stream",
43
+ "text": [
44
+ "I\n",
45
+ "II\n",
46
+ "III\n",
47
+ "IV\n",
48
+ "V\n",
49
+ "VI\n",
50
+ "VII\n",
51
+ "VIII\n",
52
+ "IX\n"
53
+ ]
54
+ }
55
+ ],
56
+ "source": [
57
+ "import time\n",
58
+ "\n",
59
+ "generation_list = []\n",
60
+ "for generation, url in zip(generations, list_of_urls):\n",
61
+ " print(generation)\n",
62
+ " generation_list.append(pd.read_html(url))\n",
63
+ " time.sleep(2)"
64
+ ]
65
+ },
66
+ {
67
+ "cell_type": "code",
68
+ "execution_count": 4,
69
+ "metadata": {},
70
+ "outputs": [],
71
+ "source": [
72
+ "biggest_table_from_each_url = []\n",
73
+ "\n",
74
+ "for tables in generation_list[:-1]:\n",
75
+ " largest_in_generation = tables[0]\n",
76
+ " for table in tables:\n",
77
+ " if table.size > largest_in_generation.size:\n",
78
+ " largest_in_generation = table\n",
79
+ "\n",
80
+ " biggest_table_from_each_url.append(largest_in_generation)\n",
81
+ "\n",
82
+ "biggest_table_from_each_url.append(\n",
83
+ " generation_list[-1][1]\n",
84
+ ") # hacky, gen IX has fewer new pokemon than the table at the bottom of the wikipedia page."
85
+ ]
86
+ },
87
+ {
88
+ "cell_type": "code",
89
+ "execution_count": 5,
90
+ "metadata": {},
91
+ "outputs": [
92
+ {
93
+ "data": {
94
+ "text/html": [
95
+ "<div>\n",
96
+ "<style scoped>\n",
97
+ " .dataframe tbody tr th:only-of-type {\n",
98
+ " vertical-align: middle;\n",
99
+ " }\n",
100
+ "\n",
101
+ " .dataframe tbody tr th {\n",
102
+ " vertical-align: top;\n",
103
+ " }\n",
104
+ "\n",
105
+ " .dataframe thead tr th {\n",
106
+ " text-align: left;\n",
107
+ " }\n",
108
+ "</style>\n",
109
+ "<table border=\"1\" class=\"dataframe\">\n",
110
+ " <thead>\n",
111
+ " <tr>\n",
112
+ " <th></th>\n",
113
+ " <th colspan=\"2\" halign=\"left\">Name</th>\n",
114
+ " <th>National Pokédexnumber</th>\n",
115
+ " <th colspan=\"2\" halign=\"left\">Type(s)</th>\n",
116
+ " <th>Evolves from</th>\n",
117
+ " <th>Evolves into</th>\n",
118
+ " <th>Notes</th>\n",
119
+ " </tr>\n",
120
+ " <tr>\n",
121
+ " <th></th>\n",
122
+ " <th>English</th>\n",
123
+ " <th>Japanese</th>\n",
124
+ " <th>National Pokédexnumber</th>\n",
125
+ " <th>Primary</th>\n",
126
+ " <th>Secondary</th>\n",
127
+ " <th>Evolves from</th>\n",
128
+ " <th>Evolves into</th>\n",
129
+ " <th>Notes</th>\n",
130
+ " </tr>\n",
131
+ " </thead>\n",
132
+ " <tbody>\n",
133
+ " <tr>\n",
134
+ " <th>0</th>\n",
135
+ " <td>Sprigatito</td>\n",
136
+ " <td>Nyaoha (ニャオハ)</td>\n",
137
+ " <td>TBA</td>\n",
138
+ " <td>Grass</td>\n",
139
+ " <td>Grass</td>\n",
140
+ " <td>Beginning of evolution</td>\n",
141
+ " <td>Unknown</td>\n",
142
+ " <td>Sprigatito is a cat-like Pokémon and the Grass...</td>\n",
143
+ " </tr>\n",
144
+ " <tr>\n",
145
+ " <th>1</th>\n",
146
+ " <td>Fuecoco</td>\n",
147
+ " <td>Hogēta (ホゲータ)</td>\n",
148
+ " <td>TBA</td>\n",
149
+ " <td>Fire</td>\n",
150
+ " <td>Fire</td>\n",
151
+ " <td>Beginning of evolution</td>\n",
152
+ " <td>Unknown</td>\n",
153
+ " <td>Fuecoco is a crocodile-like Pokémon and the Fi...</td>\n",
154
+ " </tr>\n",
155
+ " <tr>\n",
156
+ " <th>2</th>\n",
157
+ " <td>Quaxly</td>\n",
158
+ " <td>Kuwassu (クワッス)</td>\n",
159
+ " <td>TBA</td>\n",
160
+ " <td>Water</td>\n",
161
+ " <td>Water</td>\n",
162
+ " <td>Beginning of evolution</td>\n",
163
+ " <td>Unknown</td>\n",
164
+ " <td>Quaxly is a duck-like Pokémon and the Water-ty...</td>\n",
165
+ " </tr>\n",
166
+ " </tbody>\n",
167
+ "</table>\n",
168
+ "</div>"
169
+ ],
170
+ "text/plain": [
171
+ " Name National Pokédexnumber Type(s) \\\n",
172
+ " English Japanese National Pokédexnumber Primary Secondary \n",
173
+ "0 Sprigatito Nyaoha (ニャオハ) TBA Grass Grass \n",
174
+ "1 Fuecoco Hogēta (ホゲータ) TBA Fire Fire \n",
175
+ "2 Quaxly Kuwassu (クワッス) TBA Water Water \n",
176
+ "\n",
177
+ " Evolves from Evolves into \\\n",
178
+ " Evolves from Evolves into \n",
179
+ "0 Beginning of evolution Unknown \n",
180
+ "1 Beginning of evolution Unknown \n",
181
+ "2 Beginning of evolution Unknown \n",
182
+ "\n",
183
+ " Notes \n",
184
+ " Notes \n",
185
+ "0 Sprigatito is a cat-like Pokémon and the Grass... \n",
186
+ "1 Fuecoco is a crocodile-like Pokémon and the Fi... \n",
187
+ "2 Quaxly is a duck-like Pokémon and the Water-ty... "
188
+ ]
189
+ },
190
+ "execution_count": 5,
191
+ "metadata": {},
192
+ "output_type": "execute_result"
193
+ }
194
+ ],
195
+ "source": [
196
+ "biggest_table_from_each_url[-1]"
197
+ ]
198
+ },
199
+ {
200
+ "cell_type": "code",
201
+ "execution_count": 6,
202
+ "metadata": {},
203
+ "outputs": [
204
+ {
205
+ "data": {
206
+ "text/plain": [
207
+ "9"
208
+ ]
209
+ },
210
+ "execution_count": 6,
211
+ "metadata": {},
212
+ "output_type": "execute_result"
213
+ }
214
+ ],
215
+ "source": [
216
+ "len(biggest_table_from_each_url)"
217
+ ]
218
+ },
219
+ {
220
+ "cell_type": "code",
221
+ "execution_count": 7,
222
+ "metadata": {},
223
+ "outputs": [
224
+ {
225
+ "data": {
226
+ "text/plain": [
227
+ "[(152, 8),\n",
228
+ " (100, 8),\n",
229
+ " (138, 8),\n",
230
+ " (117, 8),\n",
231
+ " (158, 8),\n",
232
+ " (73, 8),\n",
233
+ " (94, 8),\n",
234
+ " (101, 8),\n",
235
+ " (3, 8)]"
236
+ ]
237
+ },
238
+ "execution_count": 7,
239
+ "metadata": {},
240
+ "output_type": "execute_result"
241
+ }
242
+ ],
243
+ "source": [
244
+ "[i.shape for i in biggest_table_from_each_url]"
245
+ ]
246
+ },
247
+ {
248
+ "cell_type": "code",
249
+ "execution_count": 8,
250
+ "metadata": {},
251
+ "outputs": [],
252
+ "source": [
253
+ "all_pkmn = pd.concat(biggest_table_from_each_url)"
254
+ ]
255
+ },
256
+ {
257
+ "cell_type": "code",
258
+ "execution_count": 9,
259
+ "metadata": {},
260
+ "outputs": [
261
+ {
262
+ "data": {
263
+ "text/plain": [
264
+ "MultiIndex([( 'Name', 'English'),\n",
265
+ " ( 'Name', 'Japanese'),\n",
266
+ " ('National Pokédexnumber', 'National Pokédexnumber'),\n",
267
+ " ( 'Type(s)', 'Primary'),\n",
268
+ " ( 'Type(s)', 'Secondary'),\n",
269
+ " ( 'Evolves from', 'Evolves from'),\n",
270
+ " ( 'Evolves into', 'Evolves into'),\n",
271
+ " ( 'Notes', 'Notes')],\n",
272
+ " )"
273
+ ]
274
+ },
275
+ "execution_count": 9,
276
+ "metadata": {},
277
+ "output_type": "execute_result"
278
+ }
279
+ ],
280
+ "source": [
281
+ "all_pkmn.columns"
282
+ ]
283
+ },
284
+ {
285
+ "cell_type": "code",
286
+ "execution_count": 10,
287
+ "metadata": {},
288
+ "outputs": [],
289
+ "source": [
290
+ "all_pkmn = (\n",
291
+ " all_pkmn.T.reset_index()\n",
292
+ " .drop(columns=[\"level_0\"])\n",
293
+ " .set_index(\"level_1\")\n",
294
+ " .T.rename(\n",
295
+ " columns={\n",
296
+ " \"English\": \"english_name\",\n",
297
+ " \"Japanese\": \"japanese_name\",\n",
298
+ " \"Primary\": \"primary_type\",\n",
299
+ " \"Secondary\": \"secondary_type\",\n",
300
+ " }\n",
301
+ " )\n",
302
+ ")"
303
+ ]
304
+ },
305
+ {
306
+ "cell_type": "code",
307
+ "execution_count": 12,
308
+ "metadata": {},
309
+ "outputs": [],
310
+ "source": [
311
+ "df = all_pkmn[[\"english_name\", \"primary_type\", \"secondary_type\", \"Notes\"]]"
312
+ ]
313
+ },
314
+ {
315
+ "cell_type": "code",
316
+ "execution_count": 17,
317
+ "metadata": {},
318
+ "outputs": [],
319
+ "source": [
320
+ "def select_text_before_first_bracket(string):\n",
321
+ " return string.split(\"[\")[0]\n",
322
+ "\n",
323
+ "\n",
324
+ "all_pkmn[\"primary_type\"] = all_pkmn[\"primary_type\"].apply(select_text_before_first_bracket)\n",
325
+ "all_pkmn[\"secondary_type\"] = all_pkmn[\"secondary_type\"].apply(select_text_before_first_bracket)"
326
+ ]
327
+ },
328
+ {
329
+ "cell_type": "code",
330
+ "execution_count": 20,
331
+ "metadata": {},
332
+ "outputs": [
333
+ {
334
+ "data": {
335
+ "text/plain": [
336
+ "Water 126\n",
337
+ "Normal 111\n",
338
+ "Grass 89\n",
339
+ "Bug 80\n",
340
+ "Psychic 64\n",
341
+ "Fire 61\n",
342
+ "Electric 55\n",
343
+ "Rock 50\n",
344
+ "Fighting 38\n",
345
+ "Dark 37\n",
346
+ "Ground 36\n",
347
+ "Poison 36\n",
348
+ "Ghost 32\n",
349
+ "Dragon 31\n",
350
+ "Steel 30\n",
351
+ "Ice 29\n",
352
+ "Fairy 23\n",
353
+ "Flying 7\n",
354
+ "Bird 1\n",
355
+ "Name: primary_type, dtype: int64"
356
+ ]
357
+ },
358
+ "execution_count": 20,
359
+ "metadata": {},
360
+ "output_type": "execute_result"
361
+ }
362
+ ],
363
+ "source": [
364
+ "# Wikipedia adds a number of footnotes to the types; with them stripped, this output is much cleaner.\n",
365
+ "all_pkmn[\"primary_type\"].value_counts()"
366
+ ]
367
+ },
368
+ {
369
+ "cell_type": "code",
370
+ "execution_count": 26,
371
+ "metadata": {},
372
+ "outputs": [
373
+ {
374
+ "data": {
375
+ "text/html": [
376
+ "<div>\n",
377
+ "<style scoped>\n",
378
+ " .dataframe tbody tr th:only-of-type {\n",
379
+ " vertical-align: middle;\n",
380
+ " }\n",
381
+ "\n",
382
+ " .dataframe tbody tr th {\n",
383
+ " vertical-align: top;\n",
384
+ " }\n",
385
+ "\n",
386
+ " .dataframe thead th {\n",
387
+ " text-align: right;\n",
388
+ " }\n",
389
+ "</style>\n",
390
+ "<table border=\"1\" class=\"dataframe\">\n",
391
+ " <thead>\n",
392
+ " <tr style=\"text-align: right;\">\n",
393
+ " <th>level_1</th>\n",
394
+ " <th>english_name</th>\n",
395
+ " <th>japanese_name</th>\n",
396
+ " <th>National Pokédexnumber</th>\n",
397
+ " <th>primary_type</th>\n",
398
+ " <th>secondary_type</th>\n",
399
+ " <th>Evolves from</th>\n",
400
+ " <th>Evolves into</th>\n",
401
+ " <th>Notes</th>\n",
402
+ " </tr>\n",
403
+ " </thead>\n",
404
+ " <tbody>\n",
405
+ " <tr>\n",
406
+ " <th>151</th>\n",
407
+ " <td>MissingNo.</td>\n",
408
+ " <td>Ketsuban (けつばん)</td>\n",
409
+ " <td>None[nb 8]</td>\n",
410
+ " <td>Bird</td>\n",
411
+ " <td>Normal</td>\n",
412
+ " <td>No evolution</td>\n",
413
+ " <td>No evolution</td>\n",
414
+ " <td>An error handler species, \"Missing Number\" was...</td>\n",
415
+ " </tr>\n",
416
+ " </tbody>\n",
417
+ "</table>\n",
418
+ "</div>"
419
+ ],
420
+ "text/plain": [
421
+ "level_1 english_name japanese_name National Pokédexnumber primary_type \\\n",
422
+ "151 MissingNo. Ketsuban (けつばん) None[nb 8] Bird \n",
423
+ "\n",
424
+ "level_1 secondary_type Evolves from Evolves into \\\n",
425
+ "151 Normal No evolution No evolution \n",
426
+ "\n",
427
+ "level_1 Notes \n",
428
+ "151 An error handler species, \"Missing Number\" was... "
429
+ ]
430
+ },
431
+ "execution_count": 26,
432
+ "metadata": {},
433
+ "output_type": "execute_result"
434
+ }
435
+ ],
436
+ "source": [
437
+ "# Bird Type? Is that like flying or something?\n",
438
+ "all_pkmn[all_pkmn.primary_type == \"Bird\"]"
439
+ ]
440
+ },
441
+ {
442
+ "cell_type": "code",
443
+ "execution_count": 27,
444
+ "metadata": {},
445
+ "outputs": [],
446
+ "source": [
447
+ "all_pkmn.to_csv(\"pokemon.csv\")"
448
+ ]
449
+ }
450
+ ],
451
+ "metadata": {
452
+ "interpreter": {
453
+ "hash": "45e1260056979d5382785f386f12ee00f44622d9a136ee7663e9a61a67ca2a68"
454
+ },
455
+ "kernelspec": {
456
+ "display_name": "Python 3.10.0 ('projects-vBrzsZbN-py3.10')",
457
+ "language": "python",
458
+ "name": "python3"
459
+ },
460
+ "language_info": {
461
+ "codemirror_mode": {
462
+ "name": "ipython",
463
+ "version": 3
464
+ },
465
+ "file_extension": ".py",
466
+ "mimetype": "text/x-python",
467
+ "name": "python",
468
+ "nbconvert_exporter": "python",
469
+ "pygments_lexer": "ipython3",
470
+ "version": "3.10.0"
471
+ },
472
+ "orig_nbformat": 4
473
+ },
474
+ "nbformat": 4,
475
+ "nbformat_minor": 2
476
+ }
poetry.lock ADDED
The diff for this file is too large to render. See raw diff
 
pyproject.toml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.poetry]
2
+ name = "projects"
3
+ version = "0.1.0"
4
+ description = ""
5
+ authors = ["Martin Coombes <mrcjdc@protonmail.com>"]
6
+
7
+ [tool.poetry.dependencies]
8
+ python = ">=3.10, <3.11"
9
+ beautifulsoup4 = "^4.11.1"
10
+ jupyterlab = "^3.4.2"
11
+ lxml = "^4.8.0"
12
+ numpy = "^1.22.4"
13
+ scikit-learn = "^1.1.1"
14
+ scipy = "^1.8.1"
15
+ seaborn = "^0.11.2"
16
+ streamlit = "^1.9.2"
17
+ transformers = "^4.19.2"
18
+ torch = "^1.11.0"
19
+
20
+ [tool.poetry.dev-dependencies]
21
+ black = "^22.3.0"
22
+ nbqa = "^1.3.1"
23
+ pre-commit = "^2.19.0"
24
+
25
+ [build-system]
26
+ requires = ["poetry-core>=1.0.0"]
27
+ build-backend = "poetry.core.masonry.api"