sonebu committed
Commit
2f6628d
1 Parent(s): b4c89d1

moving from github

.gitattributes CHANGED
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.pth.tar filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
+ *.es filter=lfs diff=lfs merge=lfs -text
+ *.en filter=lfs diff=lfs merge=lfs -text
+ *.elf filter=lfs diff=lfs merge=lfs -text
+ *.gif filter=lfs diff=lfs merge=lfs -text
+ *.json filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
+ **/.ipynb_checkpoints/
+ **/__pycache__/
LICENSE ADDED
@@ -0,0 +1,2 @@
+ NLP demo software by HyperbeeAI
+ Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai
README.md CHANGED
@@ -1,3 +1,48 @@
- ---
- license: other
- ---
+ # NLP demo software by HyperbeeAI
+ Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai
+
+ This repository contains evaluation tools for the ai85 Spanish-to-English translation project.
+
+ To run the demo, see the explanations in "demo.ipynb", which acts as the serial terminal used to communicate with the ai85 from the host PC. Further explanations are provided below as well as in the notebooks.
+
+ ![Demo](./assets/ai8x-nlp-demo.gif)
+
+ ### Contents:
+
+ - **.py files:** Python modules used by the Jupyter notebooks. These files define a simulation environment for the MAX78000 CNN accelerator hardware, plus some peripheral tools that help with evaluation. Note that the simulator only includes the chip features that are relevant to this project (e.g., pooling is not implemented because this project does not need it).
+
+ - **evaluation.ipynb:** this Jupyter notebook provides an interface for trying out different sentences from the test set on the model in the simulation environment, and for computing the BLEU score of the model over the test set.
+
+ - **demo.ipynb:** this Jupyter notebook acts as the serial interface with the chip. A sentence in the source language is sent to the chip over the serial port, the implementation on the chip translates it into the target language and sends the result back over the same serial port, and the result is displayed in the notebook cell. This notebook needs to be run together with the "assets/demo.elf" program on the chip, which does the actual translation job on the ai85. There is a specific cell in the notebook that needs to be run before the ai85 demo.elf is started; check the notebook for further info.
+
+ - **assets/demo.elf:** the C program that runs the actual translation application. Run this together with the demo.ipynb notebook for the translation demo. See further explanations inside demo.ipynb.
+
+ ### Extras/Notes:
+
+ - The demo C program does not require any extra modules/libraries; it can be run directly, the same way as the Maxim SDK examples (i.e., using the arm gdb, defining the target as "remote localhost:3333", doing "load", etc.). However, note that the Jupyter notebook demo.ipynb needs to be run together with the C program for meaningful output. There is a specific cell in the notebook that needs to be run before the ai85 demo.elf is started; check the notebook for further info.
+
+ - The demo.ipynb notebook needs to run on the same host PC that programs the ai85, since it uses the on-board (USB) serial port (the one that programs the ai85) to communicate with the chip while the translation application is running.
+
+ - Although the program should run on both the EVKit and the FeatherBoard without errors (since it uses common functionality), it has only been explicitly tested with the FeatherBoard so far.
+
+ ### Setup:
+
+ This demo has been tested with the following configuration:
+
+ Python 3.8.10
+ datasets 1.8.0
+ huggingface-hub 0.0.10
+ ipykernel 5.5.3
+ ipython 7.22.0
+ notebook 6.3.0
+ numpy 1.20.2
+ pyserial 3.5
+ sacrebleu 1.5.1
+ tokenizers 0.10.3
+ torch 1.8.1
+ torchtext 0.9.1
+ tqdm 4.49.0
+
+ Note1: torchtext might default to older versions (e.g., v0.8) on some containers (typically those provided by AWS, which use older versions of Python that don't align well with the newer torchtext versions). In that case, the .legacy submodule path needs to be removed from the import directives in the .py files and the Jupyter notebooks.
+
+ Note2: there are multiple Python packages on pip that provide serial port implementations, with conflicting function/object names. Although the package used here is imported with "import serial", it needs to be installed via "pip install pyserial", not "pip install serial". Make sure you get the correct one.
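
A minimal sketch of the Note1 workaround (not part of this repository; it assumes only that torchtext.legacy may be missing on the older release):

```python
# On torchtext 0.9.x the classes used by this repo live under
# torchtext.legacy; older releases (e.g., v0.8) expose them at the
# top level, so the imports can be guarded like this.
try:
    from torchtext.legacy.datasets import TranslationDataset
    from torchtext.legacy.data import Field, BucketIterator
except ImportError:
    # older torchtext without the .legacy submodule
    from torchtext.datasets import TranslationDataset
    from torchtext.data import Field, BucketIterator
```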
assets/ai8x-nlp-demo.gif ADDED

Git LFS Details

  • SHA256: c7100951ce0b1aa5809782f5a27f1586c6a20f991844914755dca8f20cf6e32a
  • Pointer size: 132 Bytes
  • Size of remote file: 2.53 MB
assets/demo.elf ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:504440cab7269b333570f11888979dd63e610bcfe9e84466a0f3dca79b49ebda
+ size 2483932
assets/en.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f774c53ea142a16a7e507a67e46d882755e0b052604ea9f8afb4e51ccd48f894
+ size 394357
assets/es.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0ee2fab6b130bffdc8748cd8ce8330fba8406eb61a83cdb0128972067bdc0a82
+ size 407380
assets/es2en_hw_cp6.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f30f0e64f114594c83761887ecc9dd6edac9433d6efa9b25929f767423302fc8
+ size 9953564
dataloader.py ADDED
@@ -0,0 +1,33 @@
+ ###########################################################################
+ # NLP demo software by HyperbeeAI.                                        #
+ # Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai #
+ ###########################################################################
+ license_statement = "NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai"
+ print("imported dataloader.py")
+ print(license_statement)
+ print("")
+
+ from torchtext.legacy.datasets import TranslationDataset
+ from torchtext.legacy.data import Field, BucketIterator
+ import os
+
+ class NewsDataset(TranslationDataset):
+
+     name = 'news-comm-v15'
+
+     @staticmethod
+     def sort_key(ex):
+         return len(ex.src)
+
+     @classmethod
+     def splits(cls, exts, fields, root='./',
+                train='news-comm-v15-all', validation='news-comm-v15-all-valid', test='news-comm-v15-all-test', **kwargs):
+
+         if 'path' not in kwargs:
+             expected_folder = os.path.join(root, cls.name)
+             path = expected_folder if os.path.exists(expected_folder) else None
+         else:
+             path = kwargs['path']
+             del kwargs['path']
+
+         return super(NewsDataset, cls).splits(exts, fields, path, root, train, validation, test, **kwargs)
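
For reference, the default arguments of NewsDataset.splits() above mean the loader expects the corpus files that this commit adds under ./news-comm-v15/ (a layout summary, not extra code):

```
news-comm-v15/
├── news-comm-v15-all.es        / news-comm-v15-all.en        (train)
├── news-comm-v15-all-valid.es  / news-comm-v15-all-valid.en  (validation)
└── news-comm-v15-all-test.es   / news-comm-v15-all-test.en   (test)
```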
demo.ipynb ADDED
@@ -0,0 +1,620 @@
+ {
+  "cells": [
+   {
+    "cell_type": "markdown",
+    "id": "d3092ed4",
+    "metadata": {},
+    "source": [
+     "# NLP demo software by HyperbeeAI\n",
+     "\n",
+     "Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai \n",
+     "\n",
+     "### Deployment\n",
+     "\n",
+     "This notebook acts as the serial terminal that we use in the ai85 translation demo.\n",
+     "\n",
+     "- load parameter set\n",
+     "- run a test on the PC to determine what to expect from the chip\n",
+     "- run test on the chip via serial terminal on PC"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "id": "e6208384",
+    "metadata": {},
+    "source": [
+     "### Initialization"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 1,
+    "id": "6c10cb53",
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "imported utils.py\n",
+       "NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai\n",
+       "\n",
+       "imported layers.py\n",
+       "NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai\n",
+       "\n",
+       "imported functions.py\n",
+       "NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai\n",
+       "\n",
+       "imported models.py\n",
+       "NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai\n",
+       "\n",
+       "imported dataloader.py\n",
+       "NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai\n",
+       "\n"
+      ]
+     }
+    ],
+    "source": [
+     "import torch, random\n",
+     "import numpy as np\n",
+     "import torch.nn as nn\n",
+     "from torchtext.legacy.datasets import TranslationDataset\n",
+     "from torchtext.legacy.data import Field, BucketIterator\n",
+     "from utils import tokenize_es, tokenize_en, tokenizer_es, tokenizer_en, TRG_PAD_IDX, \\\n",
+     "                  translate_sentence, calculate_bleu, license_statement\n",
+     "from models import encoder, decoder, seq2seq\n",
+     "from dataloader import NewsDataset\n",
+     "\n",
+     "import serial"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 2,
+    "id": "9966ccad",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "SEED = 1234\n",
+     "random.seed(SEED)\n",
+     "torch.manual_seed(SEED)\n",
+     "torch.cuda.manual_seed(SEED)\n",
+     "torch.backends.cudnn.deterministic = True\n",
+     "BATCH_SIZE = 48"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 3,
+    "id": "6d864c26",
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "Working with device: cuda\n"
+      ]
+     }
+    ],
+    "source": [
+     "SRC = Field(tokenize = tokenize_es, \n",
+     "            init_token = tokenizer_es.token_to_id(\"<BOS>\"), \n",
+     "            eos_token = tokenizer_es.token_to_id(\"<EOS>\"), \n",
+     "            pad_token = tokenizer_es.token_to_id(\"<PAD>\"),\n",
+     "            unk_token = tokenizer_es.token_to_id(\"<UNK>\"),\n",
+     "            use_vocab = False,\n",
+     "            batch_first = True)\n",
+     "\n",
+     "TRG = Field(tokenize = tokenize_en, \n",
+     "            init_token = tokenizer_en.token_to_id(\"<BOS>\"), \n",
+     "            eos_token = tokenizer_en.token_to_id(\"<EOS>\"), \n",
+     "            pad_token = tokenizer_en.token_to_id(\"<PAD>\"),\n",
+     "            unk_token = tokenizer_en.token_to_id(\"<UNK>\"),\n",
+     "            use_vocab = False,\n",
+     "            batch_first = True)\n",
+     "\n",
+     "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+     "#device = 'cpu'\n",
+     "print(\"Working with device:\", device)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 4,
+    "id": "7f1f2efb",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "train_data, valid_data, test_data = NewsDataset.splits(exts=('.es', '.en'), fields=(SRC, TRG))\n",
+     "train_iterator, valid_iterator, test_iterator = BucketIterator.splits(\n",
+     "    (train_data, valid_data, test_data),\n",
+     "    batch_size = BATCH_SIZE,\n",
+     "    device = device)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 5,
+    "id": "ccd6c1fc",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "enc = encoder(device)\n",
+     "dec = decoder(device, TRG_PAD_IDX)\n",
+     "model = seq2seq(enc, dec)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 6,
+    "id": "6ae348e3",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "trained_checkpoint = \"assets/es2en_hw_cp6.pt\"\n",
+     "model.load_state_dict(torch.load(trained_checkpoint, map_location=device), strict=False);\n",
+     "model.to(device);"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "id": "ddb1a23b",
+    "metadata": {},
+    "source": [
+     "### serial conversion functions"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 7,
+    "id": "534e72f2",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "def singlepass64_tensor2serial(seq_length, tensor):\n",
+     "    data = tensor.cpu().detach().numpy();\n",
+     "    char_array = '';\n",
+     "\n",
+     "    i=0;\n",
+     "    while i < 64:\n",
+     "        for j in range(0,seq_length):\n",
+     "            ch3 = data[0,i+3,j].astype('int8')\n",
+     "            ch2 = data[0,i+2,j].astype('int8')\n",
+     "            ch1 = data[0,i+1,j].astype('int8')\n",
+     "            ch0 = data[0,i+0,j].astype('int8')\n",
+     "\n",
+     "            # 2s complements\n",
+     "            val3 = \"{0:#0{1}x}\".format(int(np.binary_repr(ch3, width=8), 2),4)\n",
+     "            val2 = \"{0:#0{1}x}\".format(int(np.binary_repr(ch2, width=8), 2),4)\n",
+     "            val1 = \"{0:#0{1}x}\".format(int(np.binary_repr(ch1, width=8), 2),4)\n",
+     "            val0 = \"{0:#0{1}x}\".format(int(np.binary_repr(ch0, width=8), 2),4)\n",
+     "\n",
+     "            char_array += val3[2:] + val2[2:] + val1[2:] + val0[2:]\n",
+     "\n",
+     "        i=i+4\n",
+     "\n",
+     "    return char_array\n",
+     "\n",
+     "def twos_comp(val, bits):\n",
+     "    if (val & (1 << (bits - 1))) != 0:\n",
+     "        val = val - (1 << bits)\n",
+     "    return val\n",
+     "\n",
+     "def tensor_fromserial_singlepass64(char_array, seq_length, typetensor):\n",
+     "    out_tensor = torch.zeros_like(typetensor)\n",
+     "    i=0;\n",
+     "    while i < 64:\n",
+     "        for j in range(0, seq_length):\n",
+     "            cursor = (i*seq_length*2 + j*8); # seq_length*2 because we use 2 characters per element due to pyserial \\CR \\LF issue\n",
+     "            word = char_array[cursor : cursor+8];\n",
+     "\n",
+     "            # 2s complements\n",
+     "            val3 = twos_comp(int(word[0:2],16), 8)\n",
+     "            val2 = twos_comp(int(word[2:4],16), 8)\n",
+     "            val1 = twos_comp(int(word[4:6],16), 8)\n",
+     "            val0 = twos_comp(int(word[6:8],16), 8)\n",
+     "\n",
+     "            out_tensor[0,i+3,j] = val3;\n",
+     "            out_tensor[0,i+2,j] = val2;\n",
+     "            out_tensor[0,i+1,j] = val1;\n",
+     "            out_tensor[0,i+0,j] = val0;\n",
+     "\n",
+     "        i=i+4\n",
+     "\n",
+     "    return out_tensor\n",
+     "\n",
+     "def widemode_twos_comp(val, bits):\n",
+     "    if (val & (1 << (bits - 1))) != 0:\n",
+     "        val = ((val - (1 << bits)) >> 5) + 1\n",
+     "    return (val >> 5)\n",
+     "\n",
+     "def tensor_fromserial_widemode64(char_array, seq_length, typetensor):\n",
+     "    out_tensor = torch.zeros_like(typetensor)\n",
+     "    i=0;\n",
+     "    while i < 64:\n",
+     "        for j in range(0, seq_length):\n",
+     "            cursor = (i*seq_length*8 + j*32); # seq_length*8 now because we use 8 characters per element, same pyserial issue\n",
+     "            word = char_array[cursor : cursor+32];\n",
+     "\n",
+     "            # 2s complements\n",
+     "            val0 = twos_comp(int(word[0:8],16), 32)\n",
+     "            val1 = twos_comp(int(word[8:16],16), 32)\n",
+     "            val2 = twos_comp(int(word[16:24],16), 32)\n",
+     "            val3 = twos_comp(int(word[24:32],16), 32)\n",
+     "\n",
+     "            out_tensor[0,i+0,j] = val0;\n",
+     "            out_tensor[0,i+1,j] = val1;\n",
+     "            out_tensor[0,i+2,j] = val2;\n",
+     "            out_tensor[0,i+3,j] = val3;\n",
+     "\n",
+     "        i=i+4\n",
+     "\n",
+     "    return out_tensor"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "id": "f248bc1d",
+    "metadata": {},
+    "source": [
+     "## Test"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "id": "76d11d80",
+    "metadata": {},
+    "source": [
+     "### choose id of example"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 8,
+    "id": "cdbfd418",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "example_idx = 120"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "id": "26e82b50",
+    "metadata": {},
+    "source": [
+     "### on PC"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 9,
+    "id": "250dcc52",
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "trg = but this won ’ t be the last answer , although for the time being it will drive corporate restructuring and the managerial mind .\n",
+       "\n",
+       "predicted trg = but this will not be the latest response , though it will now be the central force of corporate restructuring and managerial thinking .\n",
+       "\n",
+       "src = pero esto no será la última respuesta , aunque por ahora será la fuerza central de la reestructuración corporativa y el pensamiento gerencial .\n",
+       "\n"
+      ]
+     }
+    ],
+    "source": [
+     "model.to(device)\n",
+     "src = vars(test_data.examples[example_idx])['src']\n",
+     "trg = tokenizer_en.decode(vars(test_data.examples[example_idx])['trg'], skip_special_tokens=False)\n",
+     "print(f'trg = {trg}')\n",
+     "print(\"\")\n",
+     "translation = translate_sentence(src, SRC, TRG, model, device)\n",
+     "print(f'predicted trg = {translation}')\n",
+     "print(\"\")\n",
+     "src = tokenizer_es.decode(src, skip_special_tokens=False)\n",
+     "print(f'src = {src}')\n",
+     "print(\"\")"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "id": "10e43fe8",
+    "metadata": {},
+    "source": [
+     "### on chip"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 10,
+    "id": "b7aa9adc",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "enc_pre = model.encoder.pre.to(device)\n",
+     "dec_pre = model.decoder.pre.to(device)\n",
+     "dec_i2w = model.decoder.fff.to(device)\n",
+     "\n",
+     "src = vars(test_data.examples[example_idx])['src']\n",
+     "trg = tokenizer_en.decode(vars(test_data.examples[example_idx])['trg'], skip_special_tokens=False)"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "id": "738e668a",
+    "metadata": {},
+    "source": [
+     "**MARK**\n",
+     "\n",
+     "The cell below starts a serial terminal in this notebook. First run this cell, and when it says \"waiting for ai85\", load the \"assets/demo.elf\" program onto the ai85 chip and start running it (type c in gdb). This should trigger the terminal here, and operation should resume normally.\n",
+     "\n",
+     "The cell is designed to translate a single sentence."
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 11,
+    "id": "0f5a5628",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "def ai85_demo_function():\n",
+     "\n",
+     "    print(\"Please enter a Spanish sentence\")\n",
+     "    textinput = input()\n",
+     "    print(\"\")\n",
+     "    print(\"\")\n",
+     "\n",
+     "    src = (tokenizer_es.encode(textinput)).ids\n",
+     "    trg = tokenizer_en.decode(vars(test_data.examples[example_idx])['trg'], skip_special_tokens=False)\n",
+     "    with serial.Serial('/dev/ttyACM0', 115200) as ser: # , timeout=5 (not necessary, just for info)\n",
+     "        tokens = src\n",
+     "        tokens = [SRC.init_token] + tokens + [SRC.eos_token] + [SRC.pad_token] * (48 - 2 - len(tokens))\n",
+     "        src_tensor = torch.LongTensor(tokens).unsqueeze(0).to(device)\n",
+     "\n",
+     "        batch_size = src_tensor.shape[0];\n",
+     "        src_len = src_tensor.shape[1];\n",
+     "        enc_pre_d = enc_pre(src_tensor, 0, src_len, batch_size);\n",
+     "        encarray = singlepass64_tensor2serial(48, enc_pre_d);\n",
+     "\n",
+     "        #### to chip\n",
+     "        print(\"** shallow.AI ai85 demo **\")\n",
+     "        print(\"** loading demo to ai85 **\")\n",
+     "        line = ser.readline()\n",
+     "        while(line != b''):\n",
+     "            line = ser.readline()\n",
+     "            if(line == b'GJcav7Wf2kmhaXJdsO0QVzX3slsv96Ck\\r\\n'):\n",
+     "                ser.write(encarray.encode(encoding=\"ascii\"))\n",
+     "                line = ser.readline()\n",
+     "                break\n",
+     "\n",
+     "        trg_indexes = [TRG.init_token, ] + [TRG.pad_token] * (48 - 1)\n",
+     "\n",
+     "        done_decoding_flag = False\n",
+     "        for i in range(47):\n",
+     "            start_idx = max(0, i - 7)\n",
+     "            trg_tensor = torch.LongTensor(trg_indexes[start_idx:start_idx + 8]).unsqueeze(0).to(device)\n",
+     "            batch_size = trg_tensor.shape[0]\n",
+     "            trg_len = trg_tensor.shape[1]\n",
+     "            pos_start = max(0, i - 7)\n",
+     "            dec_pre_d = dec_pre(trg_tensor, pos_start, trg_len + pos_start, batch_size)\n",
+     "            decarray = singlepass64_tensor2serial(8, dec_pre_d);\n",
+     "            while(line != b''):\n",
+     "                line = ser.readline()\n",
+     "                if(line == b'gZMFxLf6muLVf9P6Iyea56VbA4qktpUR\\r\\n'):\n",
+     "                    if(done_decoding_flag):\n",
+     "                        print(\"****** ai85 is done ******\")\n",
+     "                        decarray = \"done\" + decarray[4:]\n",
+     "                    ser.write(decarray.encode(encoding=\"ascii\"))\n",
+     "                    line = ser.readline()\n",
+     "                    break\n",
+     "\n",
+     "            if(done_decoding_flag):\n",
+     "                break\n",
+     "\n",
+     "            line = ser.readline()\n",
+     "            h2e_out = tensor_fromserial_widemode64(line, 1, dec_pre_d[:,:,0:1]) / (128.0 * 2**(5+1))\n",
+     "            output = dec_i2w(h2e_out.permute(0, 2, 1))\n",
+     "            pred_token = output.argmax(2)\n",
+     "            trg_indexes[i + 1] = pred_token\n",
+     "            if pred_token == TRG.eos_token:\n",
+     "                done_decoding_flag = True\n",
+     "\n",
+     "    try:\n",
+     "        trg_indexes = trg_indexes[1:trg_indexes.index(TRG.eos_token)]\n",
+     "    except ValueError:\n",
+     "        trg_indexes = trg_indexes[1:]\n",
+     "\n",
+     "    trg_tokens = tokenizer_en.decode(trg_indexes, skip_special_tokens=False)\n",
+     "\n",
+     "    print(\"\")\n",
+     "    print(\"\")\n",
+     "    print(\"English translation on ai85:\")\n",
+     "    print(f'{trg_tokens}')"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "id": "af1aa370",
+    "metadata": {},
+    "source": [
+     "# NLP demo software by HyperbeeAI\n",
+     "\n",
+     "Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai "
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 12,
+    "id": "7df357a0",
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "Please enter a Spanish sentence\n",
+       "La vinculación entre el crecimiento económico y el bienestar humano parece evidente.\n",
+       "\n",
+       "\n",
+       "** shallow.AI ai85 demo **\n",
+       "** loading demo to ai85 **\n",
+       "****** ai85 is done ******\n",
+       "\n",
+       "\n",
+       "English translation on ai85:\n",
+       "the link between economic growth and human welfare seems clear .\n"
+      ]
+     }
+    ],
+    "source": [
+     "ai85_demo_function()"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "3e7577a0",
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "52a397de",
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "96f7b68e",
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "3fae6816",
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "0a92e88d",
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "e60ac632",
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "9f982aec",
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "bfbc6cfc",
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "b59b5243",
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "61b8c8d3",
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "459a0550",
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "82cc8933",
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "d9e43f05",
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "04c6aee2",
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "de644855",
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "Python 3",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.8.10"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 5
+ }
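
As a quick illustration of the wire format used by singlepass64_tensor2serial in the notebook above (a standalone sketch, not part of the notebook): each int8 activation is sent as the two hex characters of its two's-complement byte.

```python
import numpy as np

# Reproduces the per-byte encoding step of singlepass64_tensor2serial.
def int8_to_hex(v):
    return "{0:#0{1}x}".format(int(np.binary_repr(np.int8(v), width=8), 2), 4)[2:]

for v in [0, 5, 127, -1, -128]:
    print(v, "->", int8_to_hex(v))  # 0 -> 00, 5 -> 05, 127 -> 7f, -1 -> ff, -128 -> 80
```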
evaluation.ipynb ADDED
@@ -0,0 +1,257 @@
+ {
+  "cells": [
+   {
+    "cell_type": "markdown",
+    "id": "acb67391",
+    "metadata": {},
+    "source": [
+     "# NLP demo software by HyperbeeAI\n",
+     "\n",
+     "Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai \n",
+     "\n",
+     "### Evaluation\n",
+     "\n",
+     "This notebook evaluates the model on the test set with chosen examples, and calculates the BLEU score. A simulation of the ai85 chip implemented in pytorch is used for this purpose. See imported .py modules for further info."
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 1,
+    "id": "3899e26e",
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "imported utils.py\n",
+       "NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai\n",
+       "\n",
+       "imported layers.py\n",
+       "NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai\n",
+       "\n",
+       "imported functions.py\n",
+       "NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai\n",
+       "\n",
+       "imported models.py\n",
+       "NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai\n",
+       "\n",
+       "imported dataloader.py\n",
+       "NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai\n",
+       "\n"
+      ]
+     }
+    ],
+    "source": [
+     "import torch, random\n",
+     "import torch.nn as nn\n",
+     "from torchtext.legacy.datasets import TranslationDataset\n",
+     "from torchtext.legacy.data import Field, BucketIterator\n",
+     "from utils import tokenize_es, tokenize_en, tokenizer_es, tokenizer_en, TRG_PAD_IDX, \\\n",
+     "                  translate_sentence, calculate_bleu\n",
+     "from models import encoder, decoder, seq2seq\n",
+     "from dataloader import NewsDataset"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 2,
+    "id": "812af6e8",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "SEED = 1234\n",
+     "random.seed(SEED)\n",
+     "torch.manual_seed(SEED)\n",
+     "torch.cuda.manual_seed(SEED)\n",
+     "torch.backends.cudnn.deterministic = True\n",
+     "BATCH_SIZE = 48"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 3,
+    "id": "b5717979",
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "Working with device: cuda\n"
+      ]
+     }
+    ],
+    "source": [
+     "SRC = Field(tokenize = tokenize_es, \n",
+     "            init_token = tokenizer_es.token_to_id(\"<BOS>\"), \n",
+     "            eos_token = tokenizer_es.token_to_id(\"<EOS>\"), \n",
+     "            pad_token = tokenizer_es.token_to_id(\"<PAD>\"),\n",
+     "            unk_token = tokenizer_es.token_to_id(\"<UNK>\"),\n",
+     "            use_vocab = False,\n",
+     "            batch_first = True)\n",
+     "\n",
+     "TRG = Field(tokenize = tokenize_en, \n",
+     "            init_token = tokenizer_en.token_to_id(\"<BOS>\"), \n",
+     "            eos_token = tokenizer_en.token_to_id(\"<EOS>\"), \n",
+     "            pad_token = tokenizer_en.token_to_id(\"<PAD>\"),\n",
+     "            unk_token = tokenizer_en.token_to_id(\"<UNK>\"),\n",
+     "            use_vocab = False,\n",
+     "            batch_first = True)\n",
+     "\n",
+     "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+     "#device = 'cpu'\n",
+     "print(\"Working with device:\", device)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 4,
+    "id": "5819e256",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "train_data, valid_data, test_data = NewsDataset.splits(exts=('.es', '.en'), fields=(SRC, TRG))\n",
+     "_, _, test_iterator = BucketIterator.splits(\n",
+     "    (train_data, valid_data, test_data),\n",
+     "    batch_size = BATCH_SIZE,\n",
+     "    device = device)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 5,
+    "id": "a2cbdf99",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "enc = encoder(device)\n",
+     "dec = decoder(device, TRG_PAD_IDX)\n",
+     "model = seq2seq(enc, dec)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 6,
+    "id": "516e80e4",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "trained_checkpoint = \"assets/es2en_hw_cp6.pt\"\n",
+     "res = model.load_state_dict(torch.load(trained_checkpoint, map_location=device), strict=False);\n",
+     "model.to(device);"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 7,
+    "id": "14a2a9ef",
+    "metadata": {
+     "scrolled": true
+    },
+    "outputs": [
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "Example from test data:\n",
+       "trg = for a relatively poor country like china , real unions could help balance employers ’ power , bringing quality - of - life benefits that outweigh the growth costs .\n",
+       "\n",
+       "predicted trg = for a relatively poor country as china , the existence of real unions could help balance employers ’ power , generating higher life benefits than the costs for growth .\n",
+       "\n",
+       "src = para un país relativamente pobre como es china , la existencia de sindicatos reales podría ayudar a equilibrar el poder de los empleadores , generando beneficios de calidad de vida mayores que los costes para el crecimiento .\n",
+       "\n"
+      ]
+     }
+    ],
+    "source": [
+     "print(\"Example from test data:\")\n",
+     "example_idx = 800\n",
+     "src = vars(test_data.examples[example_idx])['src']\n",
+     "trg = tokenizer_en.decode(vars(test_data.examples[example_idx])['trg'], skip_special_tokens=False)\n",
+     "print(f'trg = {trg}')\n",
+     "print(\"\")\n",
+     "translation = translate_sentence(src, SRC, TRG, model, device)\n",
+     "print(f'predicted trg = {translation}')\n",
+     "print(\"\")\n",
+     "src = tokenizer_es.decode(src, skip_special_tokens=False)\n",
+     "print(f'src = {src}')\n",
+     "print(\"\")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 8,
+    "id": "7e64577f",
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stderr",
+      "output_type": "stream",
+      "text": [
+       "1it [00:00, 5.08it/s]"
+      ]
+     },
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "Evaluate on bleu:\n"
+      ]
+     },
+     {
+      "name": "stderr",
+      "output_type": "stream",
+      "text": [
+       "3998it [14:55, 4.47it/s]\n",
+       "That's 100 lines that end in a tokenized period ('.')\n",
+       "It looks like you forgot to detokenize your test data, which may hurt your score.\n",
+       "If you insist your data is detokenized, or don't care, you can suppress this message with '--force'.\n"
+      ]
+     },
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "BLEU score:\n",
+       "{'score': 28.35048236992193, 'counts': [57540, 32851, 20648, 13309], 'totals': [100210, 96590, 92970, 89354], 'precisions': [57.41941921963876, 34.01076716016151, 22.209314832741743, 14.894688542202923], 'bp': 1.0, 'sys_len': 100210, 'ref_len': 91115}\n"
+      ]
+     }
+    ],
+    "source": [
+     "b_score = calculate_bleu(test_data, SRC, TRG, model, device)\n",
+     "print('BLEU score:')\n",
+     "print(b_score)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "dd6ae971",
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "Python 3",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.8.10"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 5
+ }
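
A brief reading of the sacrebleu output in the notebook above (an editorial note on standard sacrebleu fields): 'score' is the corpus BLEU (about 28.35), 'precisions' are the modified 1- to 4-gram precisions in percent, 'counts' and 'totals' are the matched and total n-gram counts those precisions are computed from, and 'bp' is the brevity penalty, which is 1.0 here because the system output (sys_len = 100210 tokens) is longer than the reference (ref_len = 91115).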
functions.py ADDED
@@ -0,0 +1,70 @@
+ ###########################################################################
+ # NLP demo software by HyperbeeAI.                                        #
+ # Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai #
+ ###########################################################################
+ license_statement = "NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai"
+ print("imported functions.py")
+ print(license_statement)
+ print("")
+
+ import torch, sys
+ import torch.nn as nn
+ from torch.autograd import Function
+
+ class Q_ud(Function):
+     @staticmethod
+     def forward(_, x, xb):
+         factor = 2**(xb-1)
+         return x.mul(factor).add(.5).floor().div(factor)
+
+ class Q_u(Function):
+     @staticmethod
+     def forward(_, x, xb):
+         factor = 2**(8-xb)
+         return x.mul(factor).add(.5).floor()
+
+ class Q_d(Function):
+     @staticmethod
+     def forward(_, x, xb):
+         factor = 2**(xb-1)
+         return x.div(factor).add(.5).floor()
+
+ class quantization(nn.Module):
+     def __init__(self, xb = 8, mode='updown', wide=False):
+         super().__init__()
+         self.xb = xb
+         self.mode = mode
+         self.wide = wide
+
+     def forward(self, x):
+         if(self.mode=='updown'):
+             return Q_ud.apply(x, self.xb)
+         elif(self.mode=='down'):
+             if(self.wide):
+                 return Q_d.apply(x, self.xb - 5)
+             else:
+                 return Q_d.apply(x, self.xb)
+         elif(self.mode=='up'):
+             return Q_u.apply(x, self.xb)
+         else:
+             print('wrong quantization mode. exiting')
+             sys.exit()
+
+ class clamping_hw(nn.Module):
+     def __init__(self, xb = 8, wide=False):
+         super().__init__()
+         if(wide):
+             self.min_val = -2**(30-1)
+             self.max_val = 2**(30-1)-1
+         else:
+             self.min_val = -2**(xb-1)
+             self.max_val = 2**(xb-1)-1
+
+     def forward(self, x):
+         return x.clamp(min=self.min_val, max=self.max_val)
+
+ ###################################################
+ ### Linear layer functional
+ def linear_functional(x, weight, bias, _stride, _padding):
+     # dummy linear function that has same arguments as conv
+     return nn.functional.linear(x, weight, bias)
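
A small numeric sketch of the quantizers above (an editorial illustration, not part of functions.py; assumes only that functions.py is importable from the repo root):

```python
import torch
from functions import Q_ud, Q_d

# 'updown' mode keeps values in range but snaps them to the 1/2**(xb-1) grid:
print(Q_ud.apply(torch.tensor([0.1, -0.3]), 8))   # tensor([ 0.1016, -0.2969])

# 'down' mode rescales integer-range values down by 2**(xb-1) and rounds:
print(Q_d.apply(torch.tensor([100., -100.]), 8))  # tensor([ 1., -1.])
```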
layers.py ADDED
@@ -0,0 +1,169 @@
+ ###########################################################################
+ # NLP demo software by HyperbeeAI.                                        #
+ # Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai #
+ ###########################################################################
+ license_statement = "NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai"
+ print("imported layers.py")
+ print(license_statement)
+ print("")
+
+ import torch, sys
+ import torch.nn as nn
+ import numpy as np
+ from torch.autograd import Function
+ from functions import quantization, clamping_hw, linear_functional
+
+ class ai85_base(nn.Module):
+     def __init__(
+         self,
+         operation_module = None,
+         operation_fcnl = None,
+         activation_module = None,
+         output_width_30b = False
+     ):
+         super().__init__()
+         self.op = operation_module
+         self.op_fcn = operation_fcnl
+         self.act = activation_module
+         self.wide = output_width_30b
+         self.quantize_Q_d_8b = None
+         self.quantize_Q_u_wb = None
+         self.quantize_Q_d_wide = None
+         self.clamp_C_hw_8b = None
+         self.clamp_C_hw_wide = None
+         self.output_shift = nn.Parameter(torch.Tensor([ 0 ]), requires_grad=False)
+         self.weight_bits = nn.Parameter(torch.Tensor([ 8 ]), requires_grad=False)
+         self.bias_bits = nn.Parameter(torch.Tensor([ 8 ]), requires_grad=False)
+         self.quantize_activation = nn.Parameter(torch.Tensor([ 1 ]), requires_grad=False)
+         self.adjust_output_shift = nn.Parameter(torch.Tensor([ 0 ]), requires_grad=False)
+         self.shift_quantile = nn.Parameter(torch.Tensor([ 1 ]), requires_grad=False)
+         weight_bits = self.weight_bits
+         bias_bits = self.bias_bits
+         shift_quantile = self.shift_quantile
+         self.configure_layer_base( weight_bits, bias_bits, shift_quantile )
+
+     def configure_layer_base(self, weight_bits, bias_bits, shift_quantile):
+         self.quantize_Q_d_8b = quantization(xb = 8, mode ='down', wide=False) # 8 here is activation bits
+         self.quantize_Q_u_wb = quantization(xb = weight_bits, mode ='up', wide=False)
+         self.quantize_Q_d_wide = quantization(xb = 8, mode ='down', wide=True) # 8 here is activation bits, but it's wide, so check inside
+         self.clamp_C_hw_8b = clamping_hw(xb = 8, wide=False) # 8 here is activation bits
+         self.clamp_C_hw_wide = clamping_hw(xb = None, wide=True) # None to avoid misleading info on the # of bits, check inside
+         self.weight_bits = nn.Parameter(torch.Tensor([ weight_bits ]), requires_grad=False)
+         self.bias_bits = nn.Parameter(torch.Tensor([ bias_bits ]), requires_grad=False)
+         self.shift_quantile = nn.Parameter(torch.Tensor([ shift_quantile ]), requires_grad=False)
+
+     def forward(self, x):
+         w = self.op.weight
+         b = self.op.bias
+         los = self.output_shift
+         s_o = 2**(los)
+         w_q = self.quantize_Q_u_wb(w);
+         b_q = self.quantize_Q_u_wb(b);
+
+         x = self.op_fcn(x, w_q, b_q, self.op.stride, self.op.padding) # convolution / linear
+         x = x*s_o
+         if(self.act is not None):
+             x = self.act(x)
+         if((self.wide) and (self.act is None)):
+             x = self.quantize_Q_d_wide(x)
+             x = self.clamp_C_hw_wide(x)
+             ### The +5 here is the 5 fractional bits the chip adds to the number in wide mode
+             ### we divide the number back here to get it back into range. ai8x-training does not do this for some reason
+             ### until the synthesis/deployment phase, and they do a +1 bit, why?
+             x = x / (2**(5)); # this is simulation of chip behavior
+             x = x / 128.0 # this is ours, for convenience + this part is done outside the chip since it's the step before table lookup
+             x = x / 2.0; # this is ours, for convenience + this part is done outside the chip since it's the step before table lookup
+         else:
+             x = self.quantize_Q_d_8b(x)
+             x = self.clamp_C_hw_8b(x)
+
+         return x
+
+ class ai85_conv1d(ai85_base):
+     def __init__(
+         self,
+         C_in_channels = None,
+         D_out_channels = None,
+         K_kernel_dimension = None,
+         padding = 0,
+         activation = None,
+         output_width_30b = False,
+     ):
+
+         if(activation is None):
+             activation_fcn = None;
+         elif(activation == 'relu'):
+             activation_fcn = nn.ReLU(inplace=True);
+         else:
+             print('wrong activation type in model. only {relu} is acceptable. exiting')
+             sys.exit()
+
+         operation_mdl = nn.Conv1d(C_in_channels, D_out_channels, kernel_size=K_kernel_dimension, stride=1, padding=padding, bias=True);
+         operation_fcn = nn.functional.conv1d
+
+         super().__init__(
+             activation_module = activation_fcn,
+             operation_module = operation_mdl,
+             operation_fcnl = operation_fcn,
+             output_width_30b = output_width_30b,
+         )
+
+ class ai85_add(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.clamp_C_hw_8b = clamping_hw(xb = 8, wide=False) # 8 here is activation bits
+
+     def forward(self, x, res):
+         x = self.clamp_C_hw_8b(x+res)
+         return x
+
+ class ai85_fullyconnected(ai85_base):
+     def __init__(
+         self,
+         in_features = None,
+         out_features = None,
+         activation = None,
+         output_width_30b = False):
+
+         if(activation is None):
+             activation_fcn = None;
+         elif(activation == 'relu'):
+             activation_fcn = nn.ReLU(inplace=True);
+         else:
+             print('wrong activation type in model. only {relu} is acceptable. exiting')
+             sys.exit()
+
+         operation_mdl = nn.Linear(in_features, out_features, bias=True);
+         operation_fcn = linear_functional
+
+         super().__init__(
+             activation_module = activation_fcn,
+             operation_module = operation_mdl,
+             operation_fcnl = operation_fcn,
+             output_width_30b = output_width_30b
+         )
+         # Define dummy arguments to make Linear and conv compatible in ai85_base.
+         # the name "op" here refers to op in super, i.e., in base_layer
+         self.op.stride = None
+         self.op.padding = None
+
+ class lpre(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.ee1 = nn.Embedding(16384, 64)
+         self.ee2 = nn.Embedding(48, 64)
+         self.quantize = quantization(xb = 8, mode ='updown', wide=False)
+
+     def forward(self, x, sp1, sp2, sb):
+         pp = torch.arange(sp1, sp2).unsqueeze(0).repeat(sb, 1).to(x.device)
+         ee2_d = self.ee2(pp)
+         ee1_d = self.ee1(x)
+         ed = ee1_d + ee2_d
+         min_w = self.ee2.weight.data.min() + self.ee1.weight.data.min()
+         max_w = self.ee2.weight.data.max() + self.ee1.weight.data.max()
+         t = (ed - min_w) / (max_w - min_w)
+         t = t.add(-0.5).mul(2.0)
+         t = self.quantize(t)
+         t = t.clamp(min= -1.0, max=1.0-(1.0/128.0))
+         t = t.mul(2**(8-1)).add(0.5).floor().clamp(min=-128, max=127)
+         return t.permute(0, 2, 1)
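
A short note on the wide-mode scaling in ai85_base.forward above: the chip's 30-bit wide output carries 5 extra fractional bits, so the simulation divides by 2**5 to undo them, then by 128.0 (the 8-bit activation scale) and by 2.0 (the pre-lookup-table step). The combined divisor, 2**5 * 128 * 2 = 8192, matches the single division by (128.0 * 2**(5+1)) applied to the wide serial output in demo.ipynb, so the host-side demo and the simulator agree on the scale.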
models.py ADDED
@@ -0,0 +1,240 @@
+ ###########################################################################
+ # NLP demo software by HyperbeeAI.                                        #
+ # Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai #
+ ###########################################################################
+ license_statement = "NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai"
+ print("imported models.py")
+ print(license_statement)
+ print("")
+
+ import torch
+ import torch.nn as nn
+ import layers
+
+ class encoder_ai85cnn(nn.Module):
+     def __init__(
+         self,
+         device,
+         **kwargs
+     ):
+         super().__init__()
+         self.cc0 = layers.ai85_conv1d( 64, 112, 1, 0, activation=None)
+         self.cc1 = layers.ai85_conv1d( 112, 112, 3, 1, activation='relu')
+         self.res1 = layers.ai85_add()
+         self.cc2 = layers.ai85_conv1d( 112, 112, 3, 1, activation='relu')
+         self.res2 = layers.ai85_add()
+         self.cc3 = layers.ai85_conv1d( 112, 112, 3, 1, activation='relu')
+         self.res3 = layers.ai85_add()
+         self.cc4 = layers.ai85_conv1d( 112, 112, 3, 1, activation='relu')
+         self.res4 = layers.ai85_add()
+         self.cc5 = layers.ai85_conv1d( 112, 64, 1, 0, activation=None)
+         self.resg = layers.ai85_add()
+         self.device = device
+
+     def forward(self, x):
+         r = self.cc0(x)
+         t = self.cc1( r )
+         r = self.res1(t, r)
+         t = self.cc2( r )
+         r = self.res2(t, r)
+         t = self.cc3( r )
+         r = self.res3(t, r)
+         t = self.cc4( r )
+         r = self.res4(t, r)
+         t = self.cc5(r)
+         y = self.resg(t, x)
+         return y
+
+ class encoder(nn.Module):
+     def __init__(
+         self,
+         device,
+         **kwargs
+     ):
+         super().__init__()
+         self.pre = layers.lpre()
+         self.cnn = encoder_ai85cnn(device = device);
+         self.device = device
+
+     def forward(self, x):
+         ssb = x.shape[0]
+         sl = x.shape[1]
+         pre_d = self.pre(x, 0, sl, ssb)
+         out = self.cnn(pre_d)
+         return out, pre_d
+
+ class decoder_ai85cnn_ccf(nn.Module):
+     def __init__(self, **kwargs):
+         super().__init__()
+         self.op = layers.ai85_conv1d( 112, 64, 1, 0, activation=None, output_width_30b=True)
+
+     def forward(self, x):
+         y = self.op(x)
+         return y
+
+ class decoder_ai85cnn_cpr(nn.Module):
+     def __init__(self, **kwargs):
+         super().__init__()
+         self.layer1 = layers.ai85_conv1d( 64*2, 64, 1, 0, activation='relu')
+         self.layer2 = layers.ai85_conv1d( 64, 64, 1, 0, activation='relu')
+
+     def forward(self, x):
+         x = self.layer1(x)
+         y = self.layer2(x)
+         return y
+
+ class decoder_ai85cnn_cl1(nn.Module):
+     def __init__(self, **kwargs):
+         super().__init__()
+         self.op = layers.ai85_conv1d( 112, 112, 3, 0, activation='relu')
+
+     def forward(self, x):
+         y = self.op(x)
+         return y
+
+ class decoder_ai85cnn_cma(nn.Module):
+     def __init__(self, **kwargs):
+         super().__init__()
+         self.op = layers.ai85_conv1d( 64, 112, 1, 0, activation=None)
+         self.res = layers.ai85_add()
+
+     def forward(self, x, res):
+         t = self.op(x)
+         y = self.res(t, res)
+         return y
+
+ class decoder_ai85cnn_claa(nn.Module):
+     def __init__(self, **kwargs):
+         super().__init__()
+         self.op = layers.ai85_conv1d( 112, 112, 3, 0, activation='relu')
+
+     def forward(self, x):
+         y = self.op(x)
+         return y
+
+ class decoder_ai85cnn_cl0(nn.Module):
+     def __init__(self, **kwargs):
+         super().__init__()
+         self.op = layers.ai85_conv1d( 64, 112, 1, 0, activation=None)
+
+     def forward(self, x):
+         y = self.op(x)
+         return y
+
+ class decoder_ai85cnn_clfa(nn.Module):
+     def __init__(self, **kwargs):
+         super().__init__()
+         self.op = layers.ai85_conv1d( 112, 112, 3, 0, activation='relu')
+
+     def forward(self, x):
+         y = self.op(x)
+         return y
+
+ class decoder_ai85cnn_ccac(nn.Module):
+     def __init__(self, **kwargs):
+         super().__init__()
+         self.op = layers.ai85_conv1d( 112, 112, 3, 0, activation='relu')
+
+     def forward(self, x):
+         y = self.op(x)
+         return y
+
+ class decoder_ai85cnn_cib(nn.Module):
+     def __init__(self, **kwargs):
+         super().__init__()
+         self.op = layers.ai85_conv1d( 112, 64, 1, 0, activation=None)
+
+     def forward(self, x):
+         y = self.op(x)
+         return y
+
+ class decoder(nn.Module):
+     def __init__(
+         self,
+         device,
+         tpi,
+         **kwargs
+     ):
+         super().__init__()
+
+         self.device = device
+         self.tpi = tpi
+         self.pre = layers.lpre()
+         self.fff = nn.Linear(64, 16384)
+         self.fff.weight = self.pre.ee1.weight # i.e., fff is not a layer, this is just an easy way of doing reverse embedding on pytorch
+         self.cl0 = decoder_ai85cnn_cl0();
+         self.ccf = decoder_ai85cnn_ccf();
+         self.cib = decoder_ai85cnn_cib();
+         self.cma = decoder_ai85cnn_cma();
+         self.cpr = decoder_ai85cnn_cpr();
+         self.cl1 = decoder_ai85cnn_cl1();
+         self.claa = decoder_ai85cnn_claa();
+         self.clfa = decoder_ai85cnn_clfa();
+         self.ccac = decoder_ai85cnn_ccac();
+
+     def forward(self, x, ees, pss=0):
+         ssb = x.shape[0]
+         sst = x.shape[1]
+         sl = ees.shape[2]
+
+         pre_d = self.pre(x, pss, sst + pss, ssb)
+         t = self.cl0(pre_d)
+         cl0_out = t
+         ssb, ts1, _ = t.shape
+         tp = torch.zeros(ssb, ts1, 2).fill_(self.tpi).to(t.device)
+         t = torch.cat((tp, t), dim = 2)
+         xconv = self.cl1(t)
+         t = self.cib(xconv)
+         ssb, ss_p, sst = t.shape
+         x2 = ees.unsqueeze(3).repeat(1, 1, 1, sst).view(ssb, ss_p, -1)
+         t = t.unsqueeze(2).repeat(1, 1, sl, 1).view(ssb, ss_p, -1)
+         t = torch.cat([t, x2], dim=1)
+         t = self.cpr(t)
+         t = t.view(ssb, ss_p, sl, sst)
+         t = torch.max(t, dim=2).values
+         t = self.cma(t, xconv)
+         t = torch.cat((tp, t), dim = 2)
+         xconv = self.claa(t)
+         t = self.cib(xconv)
+         t = t.unsqueeze(2).repeat(1, 1, sl, 1).view(ssb, ss_p, -1)
+         t = torch.cat([t, x2], dim=1)
+         t = self.cpr(t)
+         t = t.view(ssb, ss_p, sl, sst)
+         t = torch.max(t, dim=2).values
+         t = self.cma(t, xconv)
+         t = torch.cat((tp, t), dim = 2)
+         xconv = self.clfa(t)
+         t = self.cib(xconv)
+         t = t.unsqueeze(2).repeat(1, 1, sl, 1).view(ssb, ss_p, -1)
+         t = torch.cat([t, x2], dim=1)
+         t = self.cpr(t)
+         t = t.view(ssb, ss_p, sl, sst)
+         t = torch.max(t, dim=2).values
+         t = self.cma(t, xconv)
+         t = torch.cat((tp, t), dim = 2)
+         xconv = self.ccac(t)
+         t = self.cib(xconv)
+         t = t.unsqueeze(2).repeat(1, 1, sl, 1).view(ssb, ss_p, -1)
+         t = torch.cat([t, x2], dim=1)
+         t = self.cpr(t)
+         t = t.view(ssb, ss_p, sl, sst)
+         t = torch.max(t, dim=2).values
+         t = self.cma(t, xconv)
+         pss = t + sst
+         ccf_out = self.ccf(t)
+         output = self.fff(ccf_out.permute(0, 2, 1))
+
+         return output, pre_d, ccf_out
+
+ class seq2seq(nn.Module):
+     def __init__(self, encoder, decoder):
+         super().__init__()
+
+         self.encoder = encoder
+         self.decoder = decoder
+
+     def forward(self, src, trg):
+         enc_out, _ = self.encoder(src)
+         output, _, _ = self.decoder(trg, enc_out)
+         return output
+
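A shape-check sketch for the encoder above (an editorial illustration, not part of the commit; the weights here are random, so the output values are meaningless and only the shapes matter): lpre maps a batch of 48 token ids to a (batch, 64, 48) embedding, and the residual CNN preserves that shape, since cc0 widens 64 -> 112 channels and cc5 projects back to 64.

```python
import torch
from models import encoder

enc = encoder(device='cpu')
x = torch.randint(0, 16384, (1, 48))  # one sequence of 48 token ids
out, pre_d = enc(x)
print(pre_d.shape, out.shape)  # both torch.Size([1, 64, 48])
```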
news-comm-v15/news-comm-v15-all-test.en ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:757cea85bddca13bdbb0d4dbc187f748d3b97a4e04a5360b6ce7235c38b85261
+ size 562915
news-comm-v15/news-comm-v15-all-test.es ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5f459d6c1333abd7c545e0fd140e248dcfd05135562f25e14ffc6a98d3bccaa5
+ size 654959
news-comm-v15/news-comm-v15-all-valid.en ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3973e022e93220f9212c18d0d0c543ae7c309e46640da93a4a0314de999f5112
+ size 1
news-comm-v15/news-comm-v15-all-valid.es ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3973e022e93220f9212c18d0d0c543ae7c309e46640da93a4a0314de999f5112
+ size 1
news-comm-v15/news-comm-v15-all.en ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2e0bfde74c1665f5b44edfe370780d9cffc413768a6ad2e1530e1e42d0b77ae2
+ size 201
news-comm-v15/news-comm-v15-all.es ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f0cbc37784a40546152cd146c8f4468e44bb4a23921c51d19b1309fbd0e63200
+ size 259
news-comm-v15/readme ADDED
@@ -0,0 +1,2 @@
+ Test data sampled from:
+ https://data.statmt.org/news-commentary/v15/training/news-commentary-v15.en-es.tsv.gz
utils.py ADDED
@@ -0,0 +1,99 @@
+ ###########################################################################
+ # NLP demo software by HyperbeeAI.                                        #
+ # Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai #
+ ###########################################################################
+ license_statement = "NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai"
+ print("imported utils.py")
+ print(license_statement)
+ print("")
+
+ import torch
+ import layers
+ from tokenizers import Tokenizer
+ import time, torch, datasets
+ from tqdm import tqdm
+
+ tokenizer_en = None
+ tokenizer_es = None
+
+ def tokenize_es(text):
+     return tokenizer_es.encode(text).ids[:48 - 2]
+
+ def tokenize_en(text):
+     return tokenizer_en.encode(text).ids[:48 - 1]
+
+ def translate_sentence(sentence, src_field, trg_field, model, device):
+
+     model.eval()
+     if isinstance(sentence, str):
+         tokens = tokenize_es(sentence)
+     else:
+         tokens = sentence
+
+     tokens = [src_field.init_token] + tokens + [src_field.eos_token] + [src_field.pad_token] * (48 - 2 - len(tokens))
+     src_tensor = torch.LongTensor(tokens).unsqueeze(0).to(device)
+
+     with torch.no_grad():
+         enc_out, _ = model.encoder(src_tensor)
+
+     trg_indexes = [trg_field.init_token, ] + [trg_field.pad_token] * (48 - 1)
+
+     for i in range(48 - 1):
+         start_idx = max(0, i - 7)
+
+         trg_tensor = torch.LongTensor(trg_indexes[start_idx:start_idx + 8]).unsqueeze(0).to(device)
+
+         with torch.no_grad():
+             output, _, _ = model.decoder(trg_tensor, enc_out, max(0, i - 7))
+
+         pred_token = output.argmax(2)[:, min(i, 7)].item()
+         trg_indexes[i + 1] = pred_token
+         if pred_token == trg_field.eos_token:
+             break
+
+     try:
+         trg_indexes = trg_indexes[1:trg_indexes.index(trg_field.eos_token)]
+     except ValueError:
+         trg_indexes = trg_indexes[1:]
+
+     trg_tokens = tokenizer_en.decode(trg_indexes, skip_special_tokens=False)
+
+     return trg_tokens
+
+
+ def calculate_bleu(data, src_field, trg_field, model, device, spiece=False, output_file = f"test.{time.time()}.out"):
+
+     if spiece:
+         from tokenizers import pre_tokenizers
+         pre_tokenizer = pre_tokenizers.Digits(individual_digits=True)
+     else:
+         pre_tokenizer = tokenizer_en.pre_tokenizer
+
+     trgs = []
+     pred_trgs = []
+     print('Evaluate on bleu:')
+     for src, trg in tqdm(zip(open("news-comm-v15/news-comm-v15-all-test.es"), open("news-comm-v15/news-comm-v15-all-test.en"))):
+
+         if len(src) < 3 or len(trg) < 3:
+             continue
+
+         normalized = pre_tokenizer.pre_tokenize_str(tokenizer_en.normalizer.normalize_str(trg))
+
+         if len(normalized) > 48:
+             continue
+
+         trgs.append([ " ".join(map(lambda x: x[0], normalized)) ])
+
+         pred_trg = translate_sentence(src, src_field, trg_field, model, device)
+         pred_trgs.append(pred_trg)
+
+
+     with open(output_file, "w") as fo:
+         fo.write("\n".join(pred_trgs))
+
+     sacrebleu = datasets.load_metric('sacrebleu')
+     return sacrebleu.compute(predictions=pred_trgs, references=trgs)
+
+ tokenizer_es = Tokenizer.from_file(f"assets/es.json")
+ tokenizer_en = Tokenizer.from_file(f"assets/en.json")
+ TRG_PAD_IDX = tokenizer_en.token_to_id("<PAD>")
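
An illustration of the decoding loop in translate_sentence above (an editorial sketch, not part of utils.py): the model generates autoregressively through a fixed 8-token window, so at step i the decoder input is trg_indexes[max(0, i - 7) : max(0, i - 7) + 8] and the new token is read from window position min(i, 7).

```python
# Print the window the decoder sees at a few representative steps.
for i in [0, 1, 6, 7, 8, 20]:
    start = max(0, i - 7)
    print(f"step {i:2d}: input window [{start:2d}:{start + 8:2d}), prediction read at position {min(i, 7)}")
```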