sonebu
commited on
Commit
•
2f6628d
1
Parent(s):
b4c89d1
moving from github
Browse files- .gitattributes +8 -0
- .gitignore +2 -0
- LICENSE +2 -0
- README.md +48 -3
- assets/ai8x-nlp-demo.gif +3 -0
- assets/demo.elf +3 -0
- assets/en.json +3 -0
- assets/es.json +3 -0
- assets/es2en_hw_cp6.pt +3 -0
- dataloader.py +33 -0
- demo.ipynb +620 -0
- evaluation.ipynb +257 -0
- functions.py +70 -0
- layers.py +169 -0
- models.py +240 -0
- news-comm-v15/news-comm-v15-all-test.en +3 -0
- news-comm-v15/news-comm-v15-all-test.es +3 -0
- news-comm-v15/news-comm-v15-all-valid.en +3 -0
- news-comm-v15/news-comm-v15-all-valid.es +3 -0
- news-comm-v15/news-comm-v15-all.en +3 -0
- news-comm-v15/news-comm-v15-all.es +3 -0
- news-comm-v15/readme +2 -0
- utils.py +99 -0
.gitattributes
CHANGED
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.pth.tar filter=lfs diff=lfs merge=lfs -text
|
37 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
39 |
+
*.es filter=lfs diff=lfs merge=lfs -text
|
40 |
+
*.en filter=lfs diff=lfs merge=lfs -text
|
41 |
+
*.elf filter=lfs diff=lfs merge=lfs -text
|
42 |
+
*.gif filter=lfs diff=lfs merge=lfs -text
|
43 |
+
*.json filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
**/.ipynb_checkpoints/
|
2 |
+
**/__pycache__/
|
LICENSE
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
NLP demo software by HyperbeeAI
|
2 |
+
Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai
|
README.md
CHANGED
@@ -1,3 +1,48 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# NLP demo software by HyperbeeAI
|
2 |
+
Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai
|
3 |
+
|
4 |
+
This repository contains evaluation tools for the ai85 Spanish-to-English translation project.
|
5 |
+
|
6 |
+
To run the demo, see explanations in "demo.ipynb", which acts as the serial terminal to communicate with the ai85 from the host PC. Further explanations are provided below as well as in the notebooks.
|
7 |
+
|
8 |
+
![Demo](./assets/ai8x-nlp-demo.gif)
|
9 |
+
|
10 |
+
### Contents:
|
11 |
+
|
12 |
+
- **.py files:** python modules used by the Jupyter notebooks. These files define a simulation environment for the MAX78000 CNN accelerator hardware + some peripheral tools that help evaluation. Note that the simulator only includes the chip features that are relevant to this project (e.g., pooling not implemented because this project does not need it).
|
13 |
+
|
14 |
+
- **evaluation.ipynb:** this Jupyter notebook provides an interface to try out different sentences from the test set on the model in the simulation environment, and compute the BLEU score of the model over the test set.
|
15 |
+
|
16 |
+
- **demo.ipynb:** this Jupyter notebook acts as the serial interface with the chip. A sentence in the source language is sent over to the chip for translation via the serial port, the implementation on the chip translates this and sends it back via the same serial port in the target language, and the result is displayed on the notebook cell. This needs to be run together with the "assets/demo.elf" program on the chip, which does the actual translation job on the ai85. There is a specific cell on the notebook that needs to be run before the ai85 demo.elf is started. Check the notebook for further info.
|
17 |
+
|
18 |
+
- **assets/demo.elf:** C program running the actual translation application. Run this together with the demo.ipynb notebook for the translation demo. See further explanations inside demo.ipynb.
|
19 |
+
|
20 |
+
### Extras/Notes:
|
21 |
+
|
22 |
+
- the demo C program does not require any extra modules/libraries, it can be directly run the same way as the Maxim SDK examples (i.e., using the arm gdb, defining the target as "remote localhost:3333", doing "load" etc.). However, note that the Jupyter notebook demo.ipynb needs to be run together with the C program for meaningful output. There is a specific cell on the notebook that needs to be run before the ai85 demo.elf is started. Check the notebook for further info.
|
23 |
+
|
24 |
+
- The demo.ipynb notebook needs to run on the same host PC that programs the ai85 since it uses the on-board (USB) serial port (that programs the ai85) to communicate with the chip while the translation application is running.
|
25 |
+
|
26 |
+
- Although the program should run on both the EVKit and the FeatherBoard without errors (since it uses common functionality), it was only explicitly tested with the FeatherBoard for now.
|
27 |
+
|
28 |
+
### Setup:
|
29 |
+
|
30 |
+
This demo has been tested with the following configuration:
|
31 |
+
|
32 |
+
Python 3.8.10.
|
33 |
+
datasets 1.8.0
|
34 |
+
huggingface-hub 0.0.10
|
35 |
+
ipykernel 5.5.3
|
36 |
+
ipython 7.22.0
|
37 |
+
notebook 6.3.0
|
38 |
+
numpy 1.20.2
|
39 |
+
pyserial 3.5
|
40 |
+
sacrebleu 1.5.1
|
41 |
+
tokenizers 0.10.3
|
42 |
+
torch 1.8.1
|
43 |
+
torchtext 0.9.1
|
44 |
+
tqdm 4.49.0
|
45 |
+
|
46 |
+
Note1: torchtext might default to older versions (e.g., v0.8) on some containers (typically in those provided by AWS, which use older versions of python that don't align well with the newer torchtext versions), in that case, the .legacy submodule path needs to be removed from the import directives in the .py files and Jupyter notebooks.
|
47 |
+
|
48 |
+
Note2: there are multiple python packages on pip that provide serial port implementation, with conflicting function/object names too. Although the package used here gets imported with "import serial", it needs to be installed via "pip install pyserial", not "pip install serial". Make sure you get the correct version.
|
assets/ai8x-nlp-demo.gif
ADDED
Git LFS Details
|
assets/demo.elf
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:504440cab7269b333570f11888979dd63e610bcfe9e84466a0f3dca79b49ebda
|
3 |
+
size 2483932
|
assets/en.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f774c53ea142a16a7e507a67e46d882755e0b052604ea9f8afb4e51ccd48f894
|
3 |
+
size 394357
|
assets/es.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0ee2fab6b130bffdc8748cd8ce8330fba8406eb61a83cdb0128972067bdc0a82
|
3 |
+
size 407380
|
assets/es2en_hw_cp6.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f30f0e64f114594c83761887ecc9dd6edac9433d6efa9b25929f767423302fc8
|
3 |
+
size 9953564
|
dataloader.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
###########################################################################
|
2 |
+
# NLP demo software by HyperbeeAI. #
|
3 |
+
# Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai #
|
4 |
+
###########################################################################
|
5 |
+
license_statement = "NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai"
|
6 |
+
print("imported dataloader.py")
|
7 |
+
print(license_statement)
|
8 |
+
print("")
|
9 |
+
|
10 |
+
from torchtext.legacy.datasets import TranslationDataset
|
11 |
+
from torchtext.legacy.data import Field, BucketIterator
|
12 |
+
import os
|
13 |
+
|
14 |
+
class NewsDataset(TranslationDataset):
|
15 |
+
|
16 |
+
name = 'news-comm-v15'
|
17 |
+
|
18 |
+
@staticmethod
|
19 |
+
def sort_key(ex):
|
20 |
+
return len(ex.src)
|
21 |
+
|
22 |
+
@classmethod
|
23 |
+
def splits(cls, exts, fields, root='./',
|
24 |
+
train='news-comm-v15-all', validation='news-comm-v15-all-valid', test='news-comm-v15-all-test', **kwargs):
|
25 |
+
|
26 |
+
if 'path' not in kwargs:
|
27 |
+
expected_folder = os.path.join(root, cls.name)
|
28 |
+
path = expected_folder if os.path.exists(expected_folder) else None
|
29 |
+
else:
|
30 |
+
path = kwargs['path']
|
31 |
+
del kwargs['path']
|
32 |
+
|
33 |
+
return super(NewsDataset, cls).splits(exts, fields, path, root, train, validation, test, **kwargs)
|
demo.ipynb
ADDED
@@ -0,0 +1,620 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"id": "d3092ed4",
|
6 |
+
"metadata": {},
|
7 |
+
"source": [
|
8 |
+
"# NLP demo software by HyperbeeAI\n",
|
9 |
+
"\n",
|
10 |
+
"Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai \n",
|
11 |
+
"\n",
|
12 |
+
"### Deployment\n",
|
13 |
+
"\n",
|
14 |
+
"This notebook acts as the serial terminal that we use in the ai85 translation demo.\n",
|
15 |
+
"\n",
|
16 |
+
"- load parameter set\n",
|
17 |
+
"- run a test on the PC to determine what to expect from the chip\n",
|
18 |
+
"- run test on the chip via serial terminal on PC"
|
19 |
+
]
|
20 |
+
},
|
21 |
+
{
|
22 |
+
"cell_type": "markdown",
|
23 |
+
"id": "e6208384",
|
24 |
+
"metadata": {},
|
25 |
+
"source": [
|
26 |
+
"### Initialization"
|
27 |
+
]
|
28 |
+
},
|
29 |
+
{
|
30 |
+
"cell_type": "code",
|
31 |
+
"execution_count": 1,
|
32 |
+
"id": "6c10cb53",
|
33 |
+
"metadata": {},
|
34 |
+
"outputs": [
|
35 |
+
{
|
36 |
+
"name": "stdout",
|
37 |
+
"output_type": "stream",
|
38 |
+
"text": [
|
39 |
+
"imported utils.py\n",
|
40 |
+
"NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai\n",
|
41 |
+
"\n",
|
42 |
+
"imported layers.py\n",
|
43 |
+
"NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai\n",
|
44 |
+
"\n",
|
45 |
+
"imported functions.py\n",
|
46 |
+
"NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai\n",
|
47 |
+
"\n",
|
48 |
+
"imported models.py\n",
|
49 |
+
"NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai\n",
|
50 |
+
"\n",
|
51 |
+
"imported dataloader.py\n",
|
52 |
+
"NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai\n",
|
53 |
+
"\n"
|
54 |
+
]
|
55 |
+
}
|
56 |
+
],
|
57 |
+
"source": [
|
58 |
+
"import torch, random\n",
|
59 |
+
"import numpy as np\n",
|
60 |
+
"import torch.nn as nn\n",
|
61 |
+
"from torchtext.legacy.datasets import TranslationDataset\n",
|
62 |
+
"from torchtext.legacy.data import Field, BucketIterator\n",
|
63 |
+
"from utils import tokenize_es, tokenize_en, tokenizer_es, tokenizer_en, TRG_PAD_IDX, \\\n",
|
64 |
+
" translate_sentence, calculate_bleu, license_statement\n",
|
65 |
+
"from models import encoder, decoder, seq2seq\n",
|
66 |
+
"from dataloader import NewsDataset\n",
|
67 |
+
"\n",
|
68 |
+
"import serial"
|
69 |
+
]
|
70 |
+
},
|
71 |
+
{
|
72 |
+
"cell_type": "code",
|
73 |
+
"execution_count": 2,
|
74 |
+
"id": "9966ccad",
|
75 |
+
"metadata": {},
|
76 |
+
"outputs": [],
|
77 |
+
"source": [
|
78 |
+
"SEED = 1234\n",
|
79 |
+
"random.seed(SEED)\n",
|
80 |
+
"torch.manual_seed(SEED)\n",
|
81 |
+
"torch.cuda.manual_seed(SEED)\n",
|
82 |
+
"torch.backends.cudnn.deterministic = True\n",
|
83 |
+
"BATCH_SIZE = 48"
|
84 |
+
]
|
85 |
+
},
|
86 |
+
{
|
87 |
+
"cell_type": "code",
|
88 |
+
"execution_count": 3,
|
89 |
+
"id": "6d864c26",
|
90 |
+
"metadata": {},
|
91 |
+
"outputs": [
|
92 |
+
{
|
93 |
+
"name": "stdout",
|
94 |
+
"output_type": "stream",
|
95 |
+
"text": [
|
96 |
+
"Working with device: cuda\n"
|
97 |
+
]
|
98 |
+
}
|
99 |
+
],
|
100 |
+
"source": [
|
101 |
+
"SRC = Field(tokenize = tokenize_es, \n",
|
102 |
+
" init_token = tokenizer_es.token_to_id(\"<BOS>\"), \n",
|
103 |
+
" eos_token = tokenizer_es.token_to_id(\"<EOS>\"), \n",
|
104 |
+
" pad_token = tokenizer_es.token_to_id(\"<PAD>\"),\n",
|
105 |
+
" unk_token = tokenizer_es.token_to_id(\"<UNK>\"),\n",
|
106 |
+
" use_vocab = False,\n",
|
107 |
+
" batch_first = True)\n",
|
108 |
+
"\n",
|
109 |
+
"TRG = Field(tokenize = tokenize_en, \n",
|
110 |
+
" init_token = tokenizer_en.token_to_id(\"<BOS>\"), \n",
|
111 |
+
" eos_token = tokenizer_en.token_to_id(\"<EOS>\"), \n",
|
112 |
+
" pad_token = tokenizer_en.token_to_id(\"<PAD>\"),\n",
|
113 |
+
" unk_token = tokenizer_en.token_to_id(\"<UNK>\"),\n",
|
114 |
+
" use_vocab = False,\n",
|
115 |
+
" batch_first = True)\n",
|
116 |
+
"\n",
|
117 |
+
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
118 |
+
"#device = 'cpu'\n",
|
119 |
+
"print(\"Working with device:\", device)"
|
120 |
+
]
|
121 |
+
},
|
122 |
+
{
|
123 |
+
"cell_type": "code",
|
124 |
+
"execution_count": 4,
|
125 |
+
"id": "7f1f2efb",
|
126 |
+
"metadata": {},
|
127 |
+
"outputs": [],
|
128 |
+
"source": [
|
129 |
+
"train_data, valid_data, test_data = NewsDataset.splits(exts=('.es', '.en'), fields=(SRC, TRG))\n",
|
130 |
+
"train_iterator, valid_iterator, test_iterator = BucketIterator.splits(\n",
|
131 |
+
" (train_data, valid_data, test_data),\n",
|
132 |
+
" batch_size = BATCH_SIZE,\n",
|
133 |
+
" device = device)"
|
134 |
+
]
|
135 |
+
},
|
136 |
+
{
|
137 |
+
"cell_type": "code",
|
138 |
+
"execution_count": 5,
|
139 |
+
"id": "ccd6c1fc",
|
140 |
+
"metadata": {},
|
141 |
+
"outputs": [],
|
142 |
+
"source": [
|
143 |
+
"enc = encoder(device)\n",
|
144 |
+
"dec = decoder(device, TRG_PAD_IDX)\n",
|
145 |
+
"model = seq2seq(enc, dec)"
|
146 |
+
]
|
147 |
+
},
|
148 |
+
{
|
149 |
+
"cell_type": "code",
|
150 |
+
"execution_count": 6,
|
151 |
+
"id": "6ae348e3",
|
152 |
+
"metadata": {},
|
153 |
+
"outputs": [],
|
154 |
+
"source": [
|
155 |
+
"trained_checkpoint = \"assets/es2en_hw_cp6.pt\"\n",
|
156 |
+
"model.load_state_dict(torch.load(trained_checkpoint, map_location=device), strict=False);\n",
|
157 |
+
"model.to(device);"
|
158 |
+
]
|
159 |
+
},
|
160 |
+
{
|
161 |
+
"cell_type": "markdown",
|
162 |
+
"id": "ddb1a23b",
|
163 |
+
"metadata": {},
|
164 |
+
"source": [
|
165 |
+
"### serial conversion functions"
|
166 |
+
]
|
167 |
+
},
|
168 |
+
{
|
169 |
+
"cell_type": "code",
|
170 |
+
"execution_count": 7,
|
171 |
+
"id": "534e72f2",
|
172 |
+
"metadata": {},
|
173 |
+
"outputs": [],
|
174 |
+
"source": [
|
175 |
+
"def singlepass64_tensor2serial(seq_length, tensor):\n",
|
176 |
+
" data = tensor.cpu().detach().numpy();\n",
|
177 |
+
" char_array = '';\n",
|
178 |
+
"\n",
|
179 |
+
" i=0;\n",
|
180 |
+
" while i < 64:\n",
|
181 |
+
" for j in range(0,seq_length):\n",
|
182 |
+
" ch3 = data[0,i+3,j].astype('int8')\n",
|
183 |
+
" ch2 = data[0,i+2,j].astype('int8')\n",
|
184 |
+
" ch1 = data[0,i+1,j].astype('int8')\n",
|
185 |
+
" ch0 = data[0,i+0,j].astype('int8')\n",
|
186 |
+
"\n",
|
187 |
+
" # 2s complements\n",
|
188 |
+
" val3 = \"{0:#0{1}x}\".format(int(np.binary_repr(ch3, width=8), 2),4)\n",
|
189 |
+
" val2 = \"{0:#0{1}x}\".format(int(np.binary_repr(ch2, width=8), 2),4)\n",
|
190 |
+
" val1 = \"{0:#0{1}x}\".format(int(np.binary_repr(ch1, width=8), 2),4)\n",
|
191 |
+
" val0 = \"{0:#0{1}x}\".format(int(np.binary_repr(ch0, width=8), 2),4)\n",
|
192 |
+
"\n",
|
193 |
+
" char_array += val3[2:] + val2[2:] + val1[2:] + val0[2:]\n",
|
194 |
+
"\n",
|
195 |
+
" i=i+4\n",
|
196 |
+
" \n",
|
197 |
+
" return char_array\n",
|
198 |
+
"\n",
|
199 |
+
"def twos_comp(val, bits):\n",
|
200 |
+
" if (val & (1 << (bits - 1))) != 0:\n",
|
201 |
+
" val = val - (1 << bits)\n",
|
202 |
+
" return val\n",
|
203 |
+
"\n",
|
204 |
+
"def tensor_fromserial_singlepass64(char_array, seq_length, typetensor):\n",
|
205 |
+
" out_tensor = torch.zeros_like(typetensor)\n",
|
206 |
+
" i=0;\n",
|
207 |
+
" while i < 64:\n",
|
208 |
+
" for j in range(0, seq_length):\n",
|
209 |
+
" cursor = (i*seq_length*2 + j*8); # seq_length*2 because we use 2 characters per element due to pyserial \\CR \\LF issue\n",
|
210 |
+
" word = char_array[cursor : cursor+8];\n",
|
211 |
+
" \n",
|
212 |
+
" # 2s complements\n",
|
213 |
+
" val3 = twos_comp(int(word[0:2],16), 8)\n",
|
214 |
+
" val2 = twos_comp(int(word[2:4],16), 8)\n",
|
215 |
+
" val1 = twos_comp(int(word[4:6],16), 8)\n",
|
216 |
+
" val0 = twos_comp(int(word[6:8],16), 8)\n",
|
217 |
+
" \n",
|
218 |
+
" out_tensor[0,i+3,j] = val3;\n",
|
219 |
+
" out_tensor[0,i+2,j] = val2;\n",
|
220 |
+
" out_tensor[0,i+1,j] = val1;\n",
|
221 |
+
" out_tensor[0,i+0,j] = val0;\n",
|
222 |
+
" \n",
|
223 |
+
" i=i+4\n",
|
224 |
+
"\n",
|
225 |
+
" return out_tensor\n",
|
226 |
+
"\n",
|
227 |
+
"def widemode_twos_comp(val, bits):\n",
|
228 |
+
" if (val & (1 << (bits - 1))) != 0:\n",
|
229 |
+
" val = ((val - (1 << bits)) >> 5) + 1\n",
|
230 |
+
" return (val >> 5)\n",
|
231 |
+
"\n",
|
232 |
+
"def tensor_fromserial_widemode64(char_array, seq_length, typetensor):\n",
|
233 |
+
" out_tensor = torch.zeros_like(typetensor)\n",
|
234 |
+
" i=0;\n",
|
235 |
+
" while i < 64:\n",
|
236 |
+
" for j in range(0, seq_length):\n",
|
237 |
+
" cursor = (i*seq_length*8 + j*32); # seq_length*8 now because we use 8 characters per element, same pyserial issue\n",
|
238 |
+
" word = char_array[cursor : cursor+32];\n",
|
239 |
+
" \n",
|
240 |
+
" # 2s complements\n",
|
241 |
+
" val0 = twos_comp(int(word[0:8],16), 32)\n",
|
242 |
+
" val1 = twos_comp(int(word[8:16],16), 32)\n",
|
243 |
+
" val2 = twos_comp(int(word[16:24],16), 32)\n",
|
244 |
+
" val3 = twos_comp(int(word[24:32],16), 32)\n",
|
245 |
+
" \n",
|
246 |
+
" out_tensor[0,i+0,j] = val0;\n",
|
247 |
+
" out_tensor[0,i+1,j] = val1;\n",
|
248 |
+
" out_tensor[0,i+2,j] = val2;\n",
|
249 |
+
" out_tensor[0,i+3,j] = val3;\n",
|
250 |
+
" \n",
|
251 |
+
" i=i+4\n",
|
252 |
+
"\n",
|
253 |
+
" return out_tensor"
|
254 |
+
]
|
255 |
+
},
|
256 |
+
{
|
257 |
+
"cell_type": "markdown",
|
258 |
+
"id": "f248bc1d",
|
259 |
+
"metadata": {},
|
260 |
+
"source": [
|
261 |
+
"## Test"
|
262 |
+
]
|
263 |
+
},
|
264 |
+
{
|
265 |
+
"cell_type": "markdown",
|
266 |
+
"id": "76d11d80",
|
267 |
+
"metadata": {},
|
268 |
+
"source": [
|
269 |
+
"### choose id of example"
|
270 |
+
]
|
271 |
+
},
|
272 |
+
{
|
273 |
+
"cell_type": "code",
|
274 |
+
"execution_count": 8,
|
275 |
+
"id": "cdbfd418",
|
276 |
+
"metadata": {},
|
277 |
+
"outputs": [],
|
278 |
+
"source": [
|
279 |
+
"example_idx = 120"
|
280 |
+
]
|
281 |
+
},
|
282 |
+
{
|
283 |
+
"cell_type": "markdown",
|
284 |
+
"id": "26e82b50",
|
285 |
+
"metadata": {},
|
286 |
+
"source": [
|
287 |
+
"### on PC"
|
288 |
+
]
|
289 |
+
},
|
290 |
+
{
|
291 |
+
"cell_type": "code",
|
292 |
+
"execution_count": 9,
|
293 |
+
"id": "250dcc52",
|
294 |
+
"metadata": {},
|
295 |
+
"outputs": [
|
296 |
+
{
|
297 |
+
"name": "stdout",
|
298 |
+
"output_type": "stream",
|
299 |
+
"text": [
|
300 |
+
"trg = but this won ’ t be the last answer , although for the time being it will drive corporate restructuring and the managerial mind .\n",
|
301 |
+
"\n",
|
302 |
+
"predicted trg = but this will not be the latest response , though it will now be the central force of corporate restructuring and managerial thinking .\n",
|
303 |
+
"\n",
|
304 |
+
"src = pero esto no será la última respuesta , aunque por ahora será la fuerza central de la reestructuración corporativa y el pensamiento gerencial .\n",
|
305 |
+
"\n"
|
306 |
+
]
|
307 |
+
}
|
308 |
+
],
|
309 |
+
"source": [
|
310 |
+
"model.to(device)\n",
|
311 |
+
"src = vars(test_data.examples[example_idx])['src']\n",
|
312 |
+
"trg = tokenizer_en.decode(vars(test_data.examples[example_idx])['trg'], skip_special_tokens=False)\n",
|
313 |
+
"print(f'trg = {trg}')\n",
|
314 |
+
"print(\"\")\n",
|
315 |
+
"translation = translate_sentence(src, SRC, TRG, model, device)\n",
|
316 |
+
"print(f'predicted trg = {translation}')\n",
|
317 |
+
"print(\"\")\n",
|
318 |
+
"src = tokenizer_es.decode(src, skip_special_tokens=False)\n",
|
319 |
+
"print(f'src = {src}')\n",
|
320 |
+
"print(\"\")"
|
321 |
+
]
|
322 |
+
},
|
323 |
+
{
|
324 |
+
"cell_type": "markdown",
|
325 |
+
"id": "10e43fe8",
|
326 |
+
"metadata": {},
|
327 |
+
"source": [
|
328 |
+
"### on chip"
|
329 |
+
]
|
330 |
+
},
|
331 |
+
{
|
332 |
+
"cell_type": "code",
|
333 |
+
"execution_count": 10,
|
334 |
+
"id": "b7aa9adc",
|
335 |
+
"metadata": {},
|
336 |
+
"outputs": [],
|
337 |
+
"source": [
|
338 |
+
"enc_pre = model.encoder.pre.to(device)\n",
|
339 |
+
"dec_pre = model.decoder.pre.to(device)\n",
|
340 |
+
"dec_i2w = model.decoder.fff.to(device)\n",
|
341 |
+
"\n",
|
342 |
+
"src = vars(test_data.examples[example_idx])['src']\n",
|
343 |
+
"trg = tokenizer_en.decode(vars(test_data.examples[example_idx])['trg'], skip_special_tokens=False)"
|
344 |
+
]
|
345 |
+
},
|
346 |
+
{
|
347 |
+
"cell_type": "markdown",
|
348 |
+
"id": "738e668a",
|
349 |
+
"metadata": {},
|
350 |
+
"source": [
|
351 |
+
"**MARK**\n",
|
352 |
+
"\n",
|
353 |
+
"The below cell starts running a serial terminal on this notebook. First run this cell, and when it says \"waiting for ai85\", load the \"assets/demo.elf\" program onto the ai85 chip, and start running it (type c in gdb). This should trigger the terminal here, and operation should resume normally.\n",
|
354 |
+
"\n",
|
355 |
+
"The cell is designed to translate a single sentence."
|
356 |
+
]
|
357 |
+
},
|
358 |
+
{
|
359 |
+
"cell_type": "code",
|
360 |
+
"execution_count": 11,
|
361 |
+
"id": "0f5a5628",
|
362 |
+
"metadata": {},
|
363 |
+
"outputs": [],
|
364 |
+
"source": [
|
365 |
+
"def ai85_demo_function():\n",
|
366 |
+
" \n",
|
367 |
+
" print(\"Please enter a Spanish sentence\")\n",
|
368 |
+
" textinput = input()\n",
|
369 |
+
" print(\"\")\n",
|
370 |
+
" print(\"\")\n",
|
371 |
+
"\n",
|
372 |
+
" src = (tokenizer_es.encode(textinput)).ids\n",
|
373 |
+
" trg = tokenizer_en.decode(vars(test_data.examples[example_idx])['trg'], skip_special_tokens=False)\n",
|
374 |
+
" with serial.Serial('/dev/ttyACM0', 115200) as ser: # , timeout=5 (not necessary, just for info)\n",
|
375 |
+
" tokens = src\n",
|
376 |
+
" tokens = [SRC.init_token] + tokens + [SRC.eos_token] + [SRC.pad_token] * (48 - 2 - len(tokens)) \n",
|
377 |
+
" src_tensor = torch.LongTensor(tokens).unsqueeze(0).to(device)\n",
|
378 |
+
"\n",
|
379 |
+
" batch_size = src_tensor.shape[0];\n",
|
380 |
+
" src_len = src_tensor.shape[1];\n",
|
381 |
+
" enc_pre_d = enc_pre(src_tensor, 0, src_len, batch_size);\n",
|
382 |
+
" encarray = singlepass64_tensor2serial(48, enc_pre_d);\n",
|
383 |
+
"\n",
|
384 |
+
" #### to chip\n",
|
385 |
+
" print(\"** shallow.AI ai85 demo **\")\n",
|
386 |
+
" print(\"** loading demo to ai85 **\")\n",
|
387 |
+
" line = ser.readline()\n",
|
388 |
+
" while(line != b''):\n",
|
389 |
+
" line = ser.readline()\n",
|
390 |
+
" if(line == b'GJcav7Wf2kmhaXJdsO0QVzX3slsv96Ck\\r\\n'):\n",
|
391 |
+
" ser.write(encarray.encode(encoding=\"ascii\"))\n",
|
392 |
+
" line = ser.readline()\n",
|
393 |
+
" break\n",
|
394 |
+
"\n",
|
395 |
+
" trg_indexes = [TRG.init_token, ] + [TRG.pad_token] * (48 - 1) \n",
|
396 |
+
"\n",
|
397 |
+
" done_decoding_flag = False\n",
|
398 |
+
" for i in range(47):\n",
|
399 |
+
" start_idx = max(0, i - 7)\n",
|
400 |
+
" trg_tensor = torch.LongTensor(trg_indexes[start_idx:start_idx + 8]).unsqueeze(0).to(device)\n",
|
401 |
+
" batch_size = trg_tensor.shape[0]\n",
|
402 |
+
" trg_len = trg_tensor.shape[1]\n",
|
403 |
+
" pos_start = max(0, i - 7)\n",
|
404 |
+
" dec_pre_d = dec_pre(trg_tensor, pos_start, trg_len + pos_start, batch_size)\n",
|
405 |
+
" decarray = singlepass64_tensor2serial(8, dec_pre_d);\n",
|
406 |
+
" while(line != b''):\n",
|
407 |
+
" line = ser.readline()\n",
|
408 |
+
" if(line == b'gZMFxLf6muLVf9P6Iyea56VbA4qktpUR\\r\\n'):\n",
|
409 |
+
" if(done_decoding_flag):\n",
|
410 |
+
" print(\"****** ai85 is done ******\")\n",
|
411 |
+
" decarray = \"done\" + decarray[4:]\n",
|
412 |
+
" ser.write(decarray.encode(encoding=\"ascii\"))\n",
|
413 |
+
" line = ser.readline()\n",
|
414 |
+
" break\n",
|
415 |
+
"\n",
|
416 |
+
" if(done_decoding_flag):\n",
|
417 |
+
" break\n",
|
418 |
+
"\n",
|
419 |
+
" line = ser.readline()\n",
|
420 |
+
" h2e_out = tensor_fromserial_widemode64(line, 1, dec_pre_d[:,:,0:1]) / (128.0 * 2**(5+1))\n",
|
421 |
+
" output = dec_i2w(h2e_out.permute(0, 2, 1))\n",
|
422 |
+
" pred_token = output.argmax(2)\n",
|
423 |
+
" trg_indexes[i + 1] = pred_token\n",
|
424 |
+
" if pred_token == TRG.eos_token:\n",
|
425 |
+
" done_decoding_flag = True\n",
|
426 |
+
" \n",
|
427 |
+
" try:\n",
|
428 |
+
" trg_indexes = trg_indexes[1:trg_indexes.index(TRG.eos_token)]\n",
|
429 |
+
" except ValueError: \n",
|
430 |
+
" trg_indexes = trg_indexes[1:]\n",
|
431 |
+
"\n",
|
432 |
+
" trg_tokens = tokenizer_en.decode(trg_indexes, skip_special_tokens=False)\n",
|
433 |
+
" \n",
|
434 |
+
" print(\"\")\n",
|
435 |
+
" print(\"\")\n",
|
436 |
+
" print(\"English translation on ai85:\")\n",
|
437 |
+
" print(f'{trg_tokens}')"
|
438 |
+
]
|
439 |
+
},
|
440 |
+
{
|
441 |
+
"cell_type": "markdown",
|
442 |
+
"id": "af1aa370",
|
443 |
+
"metadata": {},
|
444 |
+
"source": [
|
445 |
+
"# NLP demo software by HyperbeeAI\n",
|
446 |
+
"\n",
|
447 |
+
"Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai "
|
448 |
+
]
|
449 |
+
},
|
450 |
+
{
|
451 |
+
"cell_type": "code",
|
452 |
+
"execution_count": 12,
|
453 |
+
"id": "7df357a0",
|
454 |
+
"metadata": {},
|
455 |
+
"outputs": [
|
456 |
+
{
|
457 |
+
"name": "stdout",
|
458 |
+
"output_type": "stream",
|
459 |
+
"text": [
|
460 |
+
"Please enter a Spanish sentence\n",
|
461 |
+
"La vinculación entre el crecimiento económico y el bienestar humano parece evidente.\n",
|
462 |
+
"\n",
|
463 |
+
"\n",
|
464 |
+
"** shallow.AI ai85 demo **\n",
|
465 |
+
"** loading demo to ai85 **\n",
|
466 |
+
"****** ai85 is done ******\n",
|
467 |
+
"\n",
|
468 |
+
"\n",
|
469 |
+
"English translation on ai85:\n",
|
470 |
+
"the link between economic growth and human welfare seems clear .\n"
|
471 |
+
]
|
472 |
+
}
|
473 |
+
],
|
474 |
+
"source": [
|
475 |
+
"ai85_demo_function()"
|
476 |
+
]
|
477 |
+
},
|
478 |
+
{
|
479 |
+
"cell_type": "code",
|
480 |
+
"execution_count": null,
|
481 |
+
"id": "3e7577a0",
|
482 |
+
"metadata": {},
|
483 |
+
"outputs": [],
|
484 |
+
"source": []
|
485 |
+
},
|
486 |
+
{
|
487 |
+
"cell_type": "code",
|
488 |
+
"execution_count": null,
|
489 |
+
"id": "52a397de",
|
490 |
+
"metadata": {},
|
491 |
+
"outputs": [],
|
492 |
+
"source": []
|
493 |
+
},
|
494 |
+
{
|
495 |
+
"cell_type": "code",
|
496 |
+
"execution_count": null,
|
497 |
+
"id": "96f7b68e",
|
498 |
+
"metadata": {},
|
499 |
+
"outputs": [],
|
500 |
+
"source": []
|
501 |
+
},
|
502 |
+
{
|
503 |
+
"cell_type": "code",
|
504 |
+
"execution_count": null,
|
505 |
+
"id": "3fae6816",
|
506 |
+
"metadata": {},
|
507 |
+
"outputs": [],
|
508 |
+
"source": []
|
509 |
+
},
|
510 |
+
{
|
511 |
+
"cell_type": "code",
|
512 |
+
"execution_count": null,
|
513 |
+
"id": "0a92e88d",
|
514 |
+
"metadata": {},
|
515 |
+
"outputs": [],
|
516 |
+
"source": []
|
517 |
+
},
|
518 |
+
{
|
519 |
+
"cell_type": "code",
|
520 |
+
"execution_count": null,
|
521 |
+
"id": "e60ac632",
|
522 |
+
"metadata": {},
|
523 |
+
"outputs": [],
|
524 |
+
"source": []
|
525 |
+
},
|
526 |
+
{
|
527 |
+
"cell_type": "code",
|
528 |
+
"execution_count": null,
|
529 |
+
"id": "9f982aec",
|
530 |
+
"metadata": {},
|
531 |
+
"outputs": [],
|
532 |
+
"source": []
|
533 |
+
},
|
534 |
+
{
|
535 |
+
"cell_type": "code",
|
536 |
+
"execution_count": null,
|
537 |
+
"id": "bfbc6cfc",
|
538 |
+
"metadata": {},
|
539 |
+
"outputs": [],
|
540 |
+
"source": []
|
541 |
+
},
|
542 |
+
{
|
543 |
+
"cell_type": "code",
|
544 |
+
"execution_count": null,
|
545 |
+
"id": "b59b5243",
|
546 |
+
"metadata": {},
|
547 |
+
"outputs": [],
|
548 |
+
"source": []
|
549 |
+
},
|
550 |
+
{
|
551 |
+
"cell_type": "code",
|
552 |
+
"execution_count": null,
|
553 |
+
"id": "61b8c8d3",
|
554 |
+
"metadata": {},
|
555 |
+
"outputs": [],
|
556 |
+
"source": []
|
557 |
+
},
|
558 |
+
{
|
559 |
+
"cell_type": "code",
|
560 |
+
"execution_count": null,
|
561 |
+
"id": "459a0550",
|
562 |
+
"metadata": {},
|
563 |
+
"outputs": [],
|
564 |
+
"source": []
|
565 |
+
},
|
566 |
+
{
|
567 |
+
"cell_type": "code",
|
568 |
+
"execution_count": null,
|
569 |
+
"id": "82cc8933",
|
570 |
+
"metadata": {},
|
571 |
+
"outputs": [],
|
572 |
+
"source": []
|
573 |
+
},
|
574 |
+
{
|
575 |
+
"cell_type": "code",
|
576 |
+
"execution_count": null,
|
577 |
+
"id": "d9e43f05",
|
578 |
+
"metadata": {},
|
579 |
+
"outputs": [],
|
580 |
+
"source": []
|
581 |
+
},
|
582 |
+
{
|
583 |
+
"cell_type": "code",
|
584 |
+
"execution_count": null,
|
585 |
+
"id": "04c6aee2",
|
586 |
+
"metadata": {},
|
587 |
+
"outputs": [],
|
588 |
+
"source": []
|
589 |
+
},
|
590 |
+
{
|
591 |
+
"cell_type": "code",
|
592 |
+
"execution_count": null,
|
593 |
+
"id": "de644855",
|
594 |
+
"metadata": {},
|
595 |
+
"outputs": [],
|
596 |
+
"source": []
|
597 |
+
}
|
598 |
+
],
|
599 |
+
"metadata": {
|
600 |
+
"kernelspec": {
|
601 |
+
"display_name": "Python 3",
|
602 |
+
"language": "python",
|
603 |
+
"name": "python3"
|
604 |
+
},
|
605 |
+
"language_info": {
|
606 |
+
"codemirror_mode": {
|
607 |
+
"name": "ipython",
|
608 |
+
"version": 3
|
609 |
+
},
|
610 |
+
"file_extension": ".py",
|
611 |
+
"mimetype": "text/x-python",
|
612 |
+
"name": "python",
|
613 |
+
"nbconvert_exporter": "python",
|
614 |
+
"pygments_lexer": "ipython3",
|
615 |
+
"version": "3.8.10"
|
616 |
+
}
|
617 |
+
},
|
618 |
+
"nbformat": 4,
|
619 |
+
"nbformat_minor": 5
|
620 |
+
}
|
evaluation.ipynb
ADDED
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"id": "acb67391",
|
6 |
+
"metadata": {},
|
7 |
+
"source": [
|
8 |
+
"# NLP demo software by HyperbeeAI\n",
|
9 |
+
"\n",
|
10 |
+
"Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai \n",
|
11 |
+
"\n",
|
12 |
+
"### Evaluation\n",
|
13 |
+
"\n",
|
14 |
+
"This notebook evaluates the model on the test set with chosen examples, and calculates the BLEU score. A simulation of the ai85 chip implemented in pytorch is used for this purpose. See imported .py modules for further info."
|
15 |
+
]
|
16 |
+
},
|
17 |
+
{
|
18 |
+
"cell_type": "code",
|
19 |
+
"execution_count": 1,
|
20 |
+
"id": "3899e26e",
|
21 |
+
"metadata": {},
|
22 |
+
"outputs": [
|
23 |
+
{
|
24 |
+
"name": "stdout",
|
25 |
+
"output_type": "stream",
|
26 |
+
"text": [
|
27 |
+
"imported utils.py\n",
|
28 |
+
"NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai\n",
|
29 |
+
"\n",
|
30 |
+
"imported layers.py\n",
|
31 |
+
"NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai\n",
|
32 |
+
"\n",
|
33 |
+
"imported functions.py\n",
|
34 |
+
"NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai\n",
|
35 |
+
"\n",
|
36 |
+
"imported models.py\n",
|
37 |
+
"NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai\n",
|
38 |
+
"\n",
|
39 |
+
"imported dataloader.py\n",
|
40 |
+
"NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai\n",
|
41 |
+
"\n"
|
42 |
+
]
|
43 |
+
}
|
44 |
+
],
|
45 |
+
"source": [
|
46 |
+
"import torch, random\n",
|
47 |
+
"import torch.nn as nn\n",
|
48 |
+
"from torchtext.legacy.datasets import TranslationDataset\n",
|
49 |
+
"from torchtext.legacy.data import Field, BucketIterator\n",
|
50 |
+
"from utils import tokenize_es, tokenize_en, tokenizer_es, tokenizer_en, TRG_PAD_IDX, \\\n",
|
51 |
+
" translate_sentence, calculate_bleu\n",
|
52 |
+
"from models import encoder, decoder, seq2seq\n",
|
53 |
+
"from dataloader import NewsDataset"
|
54 |
+
]
|
55 |
+
},
|
56 |
+
{
|
57 |
+
"cell_type": "code",
|
58 |
+
"execution_count": 2,
|
59 |
+
"id": "812af6e8",
|
60 |
+
"metadata": {},
|
61 |
+
"outputs": [],
|
62 |
+
"source": [
|
63 |
+
"SEED = 1234\n",
|
64 |
+
"random.seed(SEED)\n",
|
65 |
+
"torch.manual_seed(SEED)\n",
|
66 |
+
"torch.cuda.manual_seed(SEED)\n",
|
67 |
+
"torch.backends.cudnn.deterministic = True\n",
|
68 |
+
"BATCH_SIZE = 48"
|
69 |
+
]
|
70 |
+
},
|
71 |
+
{
|
72 |
+
"cell_type": "code",
|
73 |
+
"execution_count": 3,
|
74 |
+
"id": "b5717979",
|
75 |
+
"metadata": {},
|
76 |
+
"outputs": [
|
77 |
+
{
|
78 |
+
"name": "stdout",
|
79 |
+
"output_type": "stream",
|
80 |
+
"text": [
|
81 |
+
"Working with device: cuda\n"
|
82 |
+
]
|
83 |
+
}
|
84 |
+
],
|
85 |
+
"source": [
|
86 |
+
"SRC = Field(tokenize = tokenize_es, \n",
|
87 |
+
" init_token = tokenizer_es.token_to_id(\"<BOS>\"), \n",
|
88 |
+
" eos_token = tokenizer_es.token_to_id(\"<EOS>\"), \n",
|
89 |
+
" pad_token = tokenizer_es.token_to_id(\"<PAD>\"),\n",
|
90 |
+
" unk_token = tokenizer_es.token_to_id(\"<UNK>\"),\n",
|
91 |
+
" use_vocab = False,\n",
|
92 |
+
" batch_first = True)\n",
|
93 |
+
"\n",
|
94 |
+
"TRG = Field(tokenize = tokenize_en, \n",
|
95 |
+
" init_token = tokenizer_en.token_to_id(\"<BOS>\"), \n",
|
96 |
+
" eos_token = tokenizer_en.token_to_id(\"<EOS>\"), \n",
|
97 |
+
" pad_token = tokenizer_en.token_to_id(\"<PAD>\"),\n",
|
98 |
+
" unk_token = tokenizer_en.token_to_id(\"<UNK>\"),\n",
|
99 |
+
" use_vocab = False,\n",
|
100 |
+
" batch_first = True)\n",
|
101 |
+
"\n",
|
102 |
+
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
103 |
+
"#device = 'cpu'\n",
|
104 |
+
"print(\"Working with device:\", device)"
|
105 |
+
]
|
106 |
+
},
|
107 |
+
{
|
108 |
+
"cell_type": "code",
|
109 |
+
"execution_count": 4,
|
110 |
+
"id": "5819e256",
|
111 |
+
"metadata": {},
|
112 |
+
"outputs": [],
|
113 |
+
"source": [
|
114 |
+
"train_data, valid_data, test_data = NewsDataset.splits(exts=('.es', '.en'), fields=(SRC, TRG))\n",
|
115 |
+
"_, _, test_iterator = BucketIterator.splits(\n",
|
116 |
+
" (train_data, valid_data, test_data),\n",
|
117 |
+
" batch_size = BATCH_SIZE,\n",
|
118 |
+
" device = device)"
|
119 |
+
]
|
120 |
+
},
|
121 |
+
{
|
122 |
+
"cell_type": "code",
|
123 |
+
"execution_count": 5,
|
124 |
+
"id": "a2cbdf99",
|
125 |
+
"metadata": {},
|
126 |
+
"outputs": [],
|
127 |
+
"source": [
|
128 |
+
"enc = encoder(device)\n",
|
129 |
+
"dec = decoder(device, TRG_PAD_IDX)\n",
|
130 |
+
"model = seq2seq(enc, dec)"
|
131 |
+
]
|
132 |
+
},
|
133 |
+
{
|
134 |
+
"cell_type": "code",
|
135 |
+
"execution_count": 6,
|
136 |
+
"id": "516e80e4",
|
137 |
+
"metadata": {},
|
138 |
+
"outputs": [],
|
139 |
+
"source": [
|
140 |
+
"trained_checkpoint = \"assets/es2en_hw_cp6.pt\"\n",
|
141 |
+
"res = model.load_state_dict(torch.load(trained_checkpoint, map_location=device), strict=False);\n",
|
142 |
+
"model.to(device);"
|
143 |
+
]
|
144 |
+
},
|
145 |
+
{
|
146 |
+
"cell_type": "code",
|
147 |
+
"execution_count": 7,
|
148 |
+
"id": "14a2a9ef",
|
149 |
+
"metadata": {
|
150 |
+
"scrolled": true
|
151 |
+
},
|
152 |
+
"outputs": [
|
153 |
+
{
|
154 |
+
"name": "stdout",
|
155 |
+
"output_type": "stream",
|
156 |
+
"text": [
|
157 |
+
"Example from test data:\n",
|
158 |
+
"trg = for a relatively poor country like china , real unions could help balance employers ’ power , bringing quality - of - life benefits that outweigh the growth costs .\n",
|
159 |
+
"\n",
|
160 |
+
"predicted trg = for a relatively poor country as china , the existence of real unions could help balance employers ’ power , generating higher life benefits than the costs for growth .\n",
|
161 |
+
"\n",
|
162 |
+
"src = para un país relativamente pobre como es china , la existencia de sindicatos reales podría ayudar a equilibrar el poder de los empleadores , generando beneficios de calidad de vida mayores que los costes para el crecimiento .\n",
|
163 |
+
"\n"
|
164 |
+
]
|
165 |
+
}
|
166 |
+
],
|
167 |
+
"source": [
|
168 |
+
"print(\"Example from test data:\")\n",
|
169 |
+
"example_idx = 800\n",
|
170 |
+
"src = vars(test_data.examples[example_idx])['src']\n",
|
171 |
+
"trg = tokenizer_en.decode(vars(test_data.examples[example_idx])['trg'], skip_special_tokens=False)\n",
|
172 |
+
"print(f'trg = {trg}')\n",
|
173 |
+
"print(\"\")\n",
|
174 |
+
"translation = translate_sentence(src, SRC, TRG, model, device)\n",
|
175 |
+
"print(f'predicted trg = {translation}')\n",
|
176 |
+
"print(\"\")\n",
|
177 |
+
"src = tokenizer_es.decode(src, skip_special_tokens=False)\n",
|
178 |
+
"print(f'src = {src}')\n",
|
179 |
+
"print(\"\")"
|
180 |
+
]
|
181 |
+
},
|
182 |
+
{
|
183 |
+
"cell_type": "code",
|
184 |
+
"execution_count": 8,
|
185 |
+
"id": "7e64577f",
|
186 |
+
"metadata": {},
|
187 |
+
"outputs": [
|
188 |
+
{
|
189 |
+
"name": "stderr",
|
190 |
+
"output_type": "stream",
|
191 |
+
"text": [
|
192 |
+
"1it [00:00, 5.08it/s]"
|
193 |
+
]
|
194 |
+
},
|
195 |
+
{
|
196 |
+
"name": "stdout",
|
197 |
+
"output_type": "stream",
|
198 |
+
"text": [
|
199 |
+
"Evaluate on bleu:\n"
|
200 |
+
]
|
201 |
+
},
|
202 |
+
{
|
203 |
+
"name": "stderr",
|
204 |
+
"output_type": "stream",
|
205 |
+
"text": [
|
206 |
+
"3998it [14:55, 4.47it/s]\n",
|
207 |
+
"That's 100 lines that end in a tokenized period ('.')\n",
|
208 |
+
"It looks like you forgot to detokenize your test data, which may hurt your score.\n",
|
209 |
+
"If you insist your data is detokenized, or don't care, you can suppress this message with '--force'.\n"
|
210 |
+
]
|
211 |
+
},
|
212 |
+
{
|
213 |
+
"name": "stdout",
|
214 |
+
"output_type": "stream",
|
215 |
+
"text": [
|
216 |
+
"BLEU score:\n",
|
217 |
+
"{'score': 28.35048236992193, 'counts': [57540, 32851, 20648, 13309], 'totals': [100210, 96590, 92970, 89354], 'precisions': [57.41941921963876, 34.01076716016151, 22.209314832741743, 14.894688542202923], 'bp': 1.0, 'sys_len': 100210, 'ref_len': 91115}\n"
|
218 |
+
]
|
219 |
+
}
|
220 |
+
],
|
221 |
+
"source": [
|
222 |
+
"b_score = calculate_bleu(test_data, SRC, TRG, model, device)\n",
|
223 |
+
"print('BLEU score:')\n",
|
224 |
+
"print(b_score)"
|
225 |
+
]
|
226 |
+
},
|
227 |
+
{
|
228 |
+
"cell_type": "code",
|
229 |
+
"execution_count": null,
|
230 |
+
"id": "dd6ae971",
|
231 |
+
"metadata": {},
|
232 |
+
"outputs": [],
|
233 |
+
"source": []
|
234 |
+
}
|
235 |
+
],
|
236 |
+
"metadata": {
|
237 |
+
"kernelspec": {
|
238 |
+
"display_name": "Python 3",
|
239 |
+
"language": "python",
|
240 |
+
"name": "python3"
|
241 |
+
},
|
242 |
+
"language_info": {
|
243 |
+
"codemirror_mode": {
|
244 |
+
"name": "ipython",
|
245 |
+
"version": 3
|
246 |
+
},
|
247 |
+
"file_extension": ".py",
|
248 |
+
"mimetype": "text/x-python",
|
249 |
+
"name": "python",
|
250 |
+
"nbconvert_exporter": "python",
|
251 |
+
"pygments_lexer": "ipython3",
|
252 |
+
"version": "3.8.10"
|
253 |
+
}
|
254 |
+
},
|
255 |
+
"nbformat": 4,
|
256 |
+
"nbformat_minor": 5
|
257 |
+
}
|
functions.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
###########################################################################
|
2 |
+
# NLP demo software by HyperbeeAI. #
|
3 |
+
# Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai #
|
4 |
+
###########################################################################
|
5 |
+
license_statement = "NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai"
|
6 |
+
print("imported functions.py")
|
7 |
+
print(license_statement)
|
8 |
+
print("")
|
9 |
+
|
10 |
+
import torch, sys
|
11 |
+
import torch.nn as nn
|
12 |
+
from torch.autograd import Function
|
13 |
+
|
14 |
+
class Q_ud(Function):
|
15 |
+
@staticmethod
|
16 |
+
def forward(_, x, xb):
|
17 |
+
factor = 2**(xb-1)
|
18 |
+
return x.mul(factor).add(.5).floor().div(factor)
|
19 |
+
|
20 |
+
class Q_u(Function):
|
21 |
+
@staticmethod
|
22 |
+
def forward(_, x, xb):
|
23 |
+
factor = 2**(8-xb)
|
24 |
+
return x.mul(factor).add(.5).floor()
|
25 |
+
|
26 |
+
class Q_d(Function):
|
27 |
+
@staticmethod
|
28 |
+
def forward(_, x, xb):
|
29 |
+
factor = 2**(xb-1)
|
30 |
+
return x.div(factor).add(.5).floor()
|
31 |
+
|
32 |
+
class quantization(nn.Module):
|
33 |
+
def __init__(self, xb = 8, mode='updown', wide=False):
|
34 |
+
super().__init__()
|
35 |
+
self.xb = xb
|
36 |
+
self.mode = mode
|
37 |
+
self.wide = wide
|
38 |
+
|
39 |
+
def forward(self, x):
|
40 |
+
if(self.mode=='updown'):
|
41 |
+
return Q_ud.apply(x, self.xb)
|
42 |
+
elif(self.mode=='down'):
|
43 |
+
if(self.wide):
|
44 |
+
return Q_d.apply(x, self.xb - 5)
|
45 |
+
else:
|
46 |
+
return Q_d.apply(x, self.xb)
|
47 |
+
elif(self.mode=='up'):
|
48 |
+
return Q_u.apply(x, self.xb)
|
49 |
+
else:
|
50 |
+
print('wrong quantization mode. exiting')
|
51 |
+
sys.exit()
|
52 |
+
|
53 |
+
class clamping_hw(nn.Module):
|
54 |
+
def __init__(self, xb = 8, wide=False):
|
55 |
+
super().__init__()
|
56 |
+
if(wide):
|
57 |
+
self.min_val = -2**(30-1)
|
58 |
+
self.max_val = 2**(30-1)-1
|
59 |
+
else:
|
60 |
+
self.min_val = -2**(xb-1)
|
61 |
+
self.max_val = 2**(xb-1)-1
|
62 |
+
|
63 |
+
def forward(self, x):
|
64 |
+
return x.clamp(min=self.min_val, max=self.max_val)
|
65 |
+
|
66 |
+
###################################################
|
67 |
+
### Linear layer functional
|
68 |
+
def linear_functional(x, weight, bias, _stride, _padding):
|
69 |
+
# dummy linear function that has same arguments as conv
|
70 |
+
return nn.functional.linear(x, weight, bias)
|
layers.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
###########################################################################
|
2 |
+
# NLP demo software by HyperbeeAI. #
|
3 |
+
# Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai #
|
4 |
+
###########################################################################
|
5 |
+
license_statement = "NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai"
|
6 |
+
print("imported layers.py")
|
7 |
+
print(license_statement)
|
8 |
+
print("")
|
9 |
+
|
10 |
+
import torch, sys
|
11 |
+
import torch.nn as nn
|
12 |
+
import numpy as np
|
13 |
+
from torch.autograd import Function
|
14 |
+
from functions import quantization, clamping_hw, linear_functional
|
15 |
+
|
16 |
+
class ai85_base(nn.Module):
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
operation_module = None,
|
20 |
+
operation_fcnl = None,
|
21 |
+
activation_module = None,
|
22 |
+
output_width_30b = False
|
23 |
+
):
|
24 |
+
super().__init__()
|
25 |
+
self.op = operation_module
|
26 |
+
self.op_fcn = operation_fcnl
|
27 |
+
self.act = activation_module
|
28 |
+
self.wide = output_width_30b
|
29 |
+
self.quantize_Q_d_8b = None
|
30 |
+
self.quantize_Q_u_wb = None
|
31 |
+
self.quantize_Q_d_wide = None
|
32 |
+
self.clamp_C_hw_8b = None
|
33 |
+
self.clamp_C_hw_wide = None
|
34 |
+
self.output_shift = nn.Parameter(torch.Tensor([ 0 ]), requires_grad=False)
|
35 |
+
self.weight_bits = nn.Parameter(torch.Tensor([ 8 ]), requires_grad=False)
|
36 |
+
self.bias_bits = nn.Parameter(torch.Tensor([ 8 ]), requires_grad=False)
|
37 |
+
self.quantize_activation = nn.Parameter(torch.Tensor([ 1 ]), requires_grad=False)
|
38 |
+
self.adjust_output_shift = nn.Parameter(torch.Tensor([ 0 ]), requires_grad=False)
|
39 |
+
self.shift_quantile = nn.Parameter(torch.Tensor([ 1 ]), requires_grad=False)
|
40 |
+
weight_bits = self.weight_bits
|
41 |
+
bias_bits = self.bias_bits
|
42 |
+
shift_quantile = self.shift_quantile
|
43 |
+
self.configure_layer_base( weight_bits, bias_bits, shift_quantile )
|
44 |
+
|
45 |
+
def configure_layer_base(self, weight_bits, bias_bits, shift_quantile):
|
46 |
+
self.quantize_Q_d_8b = quantization(xb = 8, mode ='down' , wide=False) # 8 here is activation bits
|
47 |
+
self.quantize_Q_u_wb = quantization(xb = weight_bits, mode ='up' , wide=False)
|
48 |
+
self.quantize_Q_d_wide = quantization(xb = 8, mode ='down' , wide=True) # 8 here is activation bits, but its wide, so check inside
|
49 |
+
self.clamp_C_hw_8b = clamping_hw(xb = 8, wide=False) # 8 here is activation bits
|
50 |
+
self.clamp_C_hw_wide = clamping_hw(xb = None, wide=True) # None to avoid misleading info on the # of bits, check inside
|
51 |
+
self.weight_bits = nn.Parameter(torch.Tensor([ weight_bits ]), requires_grad=False)
|
52 |
+
self.bias_bits = nn.Parameter(torch.Tensor([ bias_bits ]), requires_grad=False)
|
53 |
+
self.shift_quantile = nn.Parameter(torch.Tensor([ shift_quantile ]), requires_grad=False)
|
54 |
+
|
55 |
+
def forward(self, x):
|
56 |
+
w = self.op.weight
|
57 |
+
b = self.op.bias
|
58 |
+
los = self.output_shift
|
59 |
+
s_o = 2**(los)
|
60 |
+
w_q = self.quantize_Q_u_wb(w);
|
61 |
+
b_q = self.quantize_Q_u_wb(b);
|
62 |
+
|
63 |
+
x = self.op_fcn(x, w_q, b_q, self.op.stride, self.op.padding) # convolution / linear
|
64 |
+
x = x*s_o
|
65 |
+
if(self.act is not None):
|
66 |
+
x = self.act(x)
|
67 |
+
if((self.wide) and (self.act is None)):
|
68 |
+
x = self.quantize_Q_d_wide(x)
|
69 |
+
x = self.clamp_C_hw_wide(x)
|
70 |
+
### The +5 here is the 5 fractional bits the chip adds to the number in wide mode
|
71 |
+
### we divide the number back here to get it back into range. ai8x-training does not do this for some reason
|
72 |
+
### until the synthesis/deployment phase, and they do a +1 bit, why?
|
73 |
+
x = x / (2**(5)); # this is simulation of chip behavior
|
74 |
+
x = x / 128.0 # this is ours, for convenience + this part is done outside the chip since it's the step before table lookup
|
75 |
+
x = x / 2.0; # this is ours, for convenience + this part is done outside the chip since it's the step before table lookup
|
76 |
+
else:
|
77 |
+
x = self.quantize_Q_d_8b(x)
|
78 |
+
x = self.clamp_C_hw_8b(x)
|
79 |
+
|
80 |
+
return x
|
81 |
+
|
82 |
+
class ai85_conv1d(ai85_base):
|
83 |
+
def __init__(
|
84 |
+
self,
|
85 |
+
C_in_channels = None,
|
86 |
+
D_out_channels = None,
|
87 |
+
K_kernel_dimension = None,
|
88 |
+
padding = 0,
|
89 |
+
activation = None,
|
90 |
+
output_width_30b = False,
|
91 |
+
):
|
92 |
+
|
93 |
+
if(activation is None):
|
94 |
+
activation_fcn = None;
|
95 |
+
elif(activation == 'relu'):
|
96 |
+
activation_fcn = nn.ReLU(inplace=True);
|
97 |
+
else:
|
98 |
+
print('wrong activation type in model. only {relu} is acceptable. exiting')
|
99 |
+
sys.exit()
|
100 |
+
|
101 |
+
operation_mdl = nn.Conv1d(C_in_channels, D_out_channels, kernel_size=K_kernel_dimension, stride=1, padding=padding, bias=True);
|
102 |
+
operation_fcn = nn.functional.conv1d
|
103 |
+
|
104 |
+
super().__init__(
|
105 |
+
activation_module = activation_fcn,
|
106 |
+
operation_module = operation_mdl,
|
107 |
+
operation_fcnl = operation_fcn,
|
108 |
+
output_width_30b = output_width_30b,
|
109 |
+
)
|
110 |
+
|
111 |
+
class ai85_add(nn.Module):
|
112 |
+
def __init__(self ):
|
113 |
+
super().__init__()
|
114 |
+
self.clamp_C_hw_8b = clamping_hw( xb = 8, wide=False) # 8 here is activation bits
|
115 |
+
|
116 |
+
def forward(self, x, res):
|
117 |
+
x = self.clamp_C_hw_8b(x+res)
|
118 |
+
return x
|
119 |
+
|
120 |
+
class ai85_fullyconnected(ai85_base):
|
121 |
+
def __init__(
|
122 |
+
self,
|
123 |
+
in_features = None,
|
124 |
+
out_features = None,
|
125 |
+
activation = None,
|
126 |
+
output_width_30b = False):
|
127 |
+
|
128 |
+
if(activation is None):
|
129 |
+
activation_fcn = None;
|
130 |
+
elif(activation == 'relu'):
|
131 |
+
activation_fcn = nn.ReLU(inplace=True);
|
132 |
+
else:
|
133 |
+
print('wrong activation type in model. only {relu} is acceptable. exiting')
|
134 |
+
sys.exit()
|
135 |
+
|
136 |
+
operation_mdl = nn.Linear(in_features, out_features, bias=True);
|
137 |
+
operation_fcn = linear_functional
|
138 |
+
|
139 |
+
super().__init__(
|
140 |
+
activation_module = activation_fcn,
|
141 |
+
operation_module = operation_mdl,
|
142 |
+
operation_fcnl = operation_fcn,
|
143 |
+
output_width_30b = output_width_30b
|
144 |
+
)
|
145 |
+
# Define dummy arguments to make Linear and conv compatible in ai85_base.
|
146 |
+
# the name "op" here refers to op in super, i.e., in base_layer
|
147 |
+
self.op.stride = None
|
148 |
+
self.op.padding = None
|
149 |
+
|
150 |
+
class lpre(nn.Module):
|
151 |
+
def __init__(self):
|
152 |
+
super().__init__()
|
153 |
+
self.ee1 = nn.Embedding(16384, 64)
|
154 |
+
self.ee2 = nn.Embedding(48, 64)
|
155 |
+
self.quantize = quantization(xb = 8, mode ='updown', wide=False)
|
156 |
+
|
157 |
+
def forward(self, x, sp1, sp2, sb):
|
158 |
+
pp= torch.arange(sp1, sp2).unsqueeze(0).repeat(sb, 1).to(x.device)
|
159 |
+
ee2_d = self.ee2(pp)
|
160 |
+
ee1_d = self.ee1(x)
|
161 |
+
ed = ee1_d + ee2_d
|
162 |
+
min_w = self.ee2.weight.data.min() + self.ee1.weight.data.min()
|
163 |
+
max_w = self.ee2.weight.data.max() + self.ee1.weight.data.max()
|
164 |
+
t = (ed - min_w) / (max_w - min_w)
|
165 |
+
t = t.add(-0.5).mul(2.0)
|
166 |
+
t = self.quantize(t)
|
167 |
+
t = t.clamp(min= -1.0, max=1.0-(1.0/128.0))
|
168 |
+
t = t.mul(2**(8-1)).add(0.5).floor().clamp(min=-128, max=127)
|
169 |
+
return t.permute(0, 2, 1)
|
models.py
ADDED
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
###########################################################################
|
2 |
+
# NLP demo software by HyperbeeAI. #
|
3 |
+
# Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai #
|
4 |
+
###########################################################################
|
5 |
+
license_statement = "NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai"
|
6 |
+
print("imported models.py")
|
7 |
+
print(license_statement)
|
8 |
+
print("")
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import layers
|
13 |
+
|
14 |
+
class encoder_ai85cnn(nn.Module):
|
15 |
+
def __init__(
|
16 |
+
self,
|
17 |
+
device,
|
18 |
+
**kwargs
|
19 |
+
):
|
20 |
+
super().__init__()
|
21 |
+
self.cc0 = layers.ai85_conv1d( 64, 112, 1, 0, activation=None)
|
22 |
+
self.cc1 = layers.ai85_conv1d( 112, 112, 3, 1, activation='relu')
|
23 |
+
self.res1 = layers.ai85_add()
|
24 |
+
self.cc2 = layers.ai85_conv1d( 112, 112, 3, 1, activation='relu')
|
25 |
+
self.res2 = layers.ai85_add()
|
26 |
+
self.cc3 = layers.ai85_conv1d( 112, 112, 3, 1, activation='relu')
|
27 |
+
self.res3 = layers.ai85_add()
|
28 |
+
self.cc4 = layers.ai85_conv1d( 112, 112, 3, 1, activation='relu')
|
29 |
+
self.res4 = layers.ai85_add()
|
30 |
+
self.cc5 = layers.ai85_conv1d( 112, 64 , 1, 0, activation=None)
|
31 |
+
self.resg = layers.ai85_add()
|
32 |
+
self.device = device
|
33 |
+
|
34 |
+
def forward(self, x):
|
35 |
+
r = self.cc0(x)
|
36 |
+
t = self.cc1( r )
|
37 |
+
r = self.res1(t, r)
|
38 |
+
t = self.cc2( r )
|
39 |
+
r = self.res2(t, r)
|
40 |
+
t = self.cc3( r )
|
41 |
+
r = self.res3(t, r)
|
42 |
+
t = self.cc4( r )
|
43 |
+
r = self.res4(t, r)
|
44 |
+
t = self.cc5(r)
|
45 |
+
y = self.resg(t, x)
|
46 |
+
return y
|
47 |
+
|
48 |
+
class encoder(nn.Module):
|
49 |
+
def __init__(
|
50 |
+
self,
|
51 |
+
device,
|
52 |
+
**kwargs
|
53 |
+
):
|
54 |
+
super().__init__()
|
55 |
+
self.pre = layers.lpre()
|
56 |
+
self.cnn = encoder_ai85cnn(device = device);
|
57 |
+
self.device = device
|
58 |
+
|
59 |
+
def forward(self, x):
|
60 |
+
ssb = x.shape[0]
|
61 |
+
sl = x.shape[1]
|
62 |
+
pre_d = self.pre(x, 0, sl, ssb)
|
63 |
+
out = self.cnn(pre_d)
|
64 |
+
return out, pre_d
|
65 |
+
|
66 |
+
class decoder_ai85cnn_ccf(nn.Module):
|
67 |
+
def __init__(self, **kwargs):
|
68 |
+
super().__init__()
|
69 |
+
self.op = layers.ai85_conv1d( 112, 64 , 1, 0, activation=None, output_width_30b=True)
|
70 |
+
|
71 |
+
def forward(self, x):
|
72 |
+
y = self.op(x)
|
73 |
+
return y
|
74 |
+
|
75 |
+
class decoder_ai85cnn_cpr(nn.Module):
|
76 |
+
def __init__(self, **kwargs):
|
77 |
+
super().__init__()
|
78 |
+
self.layer1 = layers.ai85_conv1d( 64*2, 64, 1, 0, activation='relu')
|
79 |
+
self.layer2 = layers.ai85_conv1d( 64, 64, 1, 0, activation='relu')
|
80 |
+
|
81 |
+
def forward(self, x):
|
82 |
+
x = self.layer1(x)
|
83 |
+
y = self.layer2(x)
|
84 |
+
return y
|
85 |
+
|
86 |
+
class decoder_ai85cnn_cl1(nn.Module):
|
87 |
+
def __init__(self, **kwargs):
|
88 |
+
super().__init__()
|
89 |
+
self.op = layers.ai85_conv1d( 112, 112, 3, 0, activation='relu')
|
90 |
+
|
91 |
+
def forward(self, x):
|
92 |
+
y = self.op(x)
|
93 |
+
return y
|
94 |
+
|
95 |
+
class decoder_ai85cnn_cma(nn.Module):
|
96 |
+
def __init__(self, **kwargs):
|
97 |
+
super().__init__()
|
98 |
+
self.op = layers.ai85_conv1d( 64, 112, 1, 0, activation=None)
|
99 |
+
self.res= layers.ai85_add()
|
100 |
+
|
101 |
+
def forward(self, x, res):
|
102 |
+
t = self.op(x)
|
103 |
+
y = self.res(t, res)
|
104 |
+
return y
|
105 |
+
|
106 |
+
class decoder_ai85cnn_claa(nn.Module):
|
107 |
+
def __init__(self, **kwargs):
|
108 |
+
super().__init__()
|
109 |
+
self.op = layers.ai85_conv1d( 112, 112, 3, 0, activation='relu')
|
110 |
+
|
111 |
+
def forward(self, x):
|
112 |
+
y = self.op(x)
|
113 |
+
return y
|
114 |
+
|
115 |
+
class decoder_ai85cnn_cl0(nn.Module):
|
116 |
+
def __init__(self, **kwargs):
|
117 |
+
super().__init__()
|
118 |
+
self.op = layers.ai85_conv1d( 64, 112, 1, 0, activation=None)
|
119 |
+
|
120 |
+
def forward(self, x):
|
121 |
+
y = self.op(x)
|
122 |
+
return y
|
123 |
+
|
124 |
+
class decoder_ai85cnn_clfa(nn.Module):
|
125 |
+
def __init__(self, **kwargs):
|
126 |
+
super().__init__()
|
127 |
+
self.op = layers.ai85_conv1d( 112, 112, 3, 0, activation='relu')
|
128 |
+
|
129 |
+
def forward(self, x):
|
130 |
+
y = self.op(x)
|
131 |
+
return y
|
132 |
+
|
133 |
+
class decoder_ai85cnn_ccac(nn.Module):
|
134 |
+
def __init__(self, **kwargs):
|
135 |
+
super().__init__()
|
136 |
+
self.op = layers.ai85_conv1d( 112, 112, 3, 0, activation='relu')
|
137 |
+
|
138 |
+
def forward(self, x):
|
139 |
+
y = self.op(x)
|
140 |
+
return y
|
141 |
+
|
142 |
+
class decoder_ai85cnn_cib(nn.Module):
|
143 |
+
def __init__(self, **kwargs):
|
144 |
+
super().__init__()
|
145 |
+
self.op = layers.ai85_conv1d( 112, 64 , 1, 0, activation=None)
|
146 |
+
|
147 |
+
def forward(self, x):
|
148 |
+
y = self.op(x)
|
149 |
+
return y
|
150 |
+
|
151 |
+
class decoder(nn.Module):
|
152 |
+
def __init__(
|
153 |
+
self,
|
154 |
+
device,
|
155 |
+
tpi,
|
156 |
+
**kwargs
|
157 |
+
):
|
158 |
+
super().__init__()
|
159 |
+
|
160 |
+
self.device = device
|
161 |
+
self.tpi = tpi
|
162 |
+
self.pre = layers.lpre()
|
163 |
+
self.fff = nn.Linear(64, 16384)
|
164 |
+
self.fff.weight = self.pre.ee1.weight # i.e., fff is not a layer, this is just an easy way of doing reverse embedding on pytorch
|
165 |
+
self.cl0 = decoder_ai85cnn_cl0();
|
166 |
+
self.ccf = decoder_ai85cnn_ccf();
|
167 |
+
self.cib = decoder_ai85cnn_cib();
|
168 |
+
self.cma = decoder_ai85cnn_cma();
|
169 |
+
self.cpr = decoder_ai85cnn_cpr();
|
170 |
+
self.cl1 = decoder_ai85cnn_cl1();
|
171 |
+
self.claa = decoder_ai85cnn_claa();
|
172 |
+
self.clfa = decoder_ai85cnn_clfa();
|
173 |
+
self.ccac = decoder_ai85cnn_ccac();
|
174 |
+
|
175 |
+
def forward(self, x, ees , pss=0):
|
176 |
+
ssb = x.shape[0]
|
177 |
+
sst = x.shape[1]
|
178 |
+
sl = ees.shape[2]
|
179 |
+
|
180 |
+
pre_d = self.pre(x, pss, sst + pss, ssb)
|
181 |
+
t = self.cl0(pre_d)
|
182 |
+
cl0_out = t
|
183 |
+
ssb, ts1, _ = t.shape
|
184 |
+
tp = torch.zeros(ssb, ts1, 2).fill_(self.tpi).to(t.device)
|
185 |
+
t = torch.cat((tp, t), dim = 2)
|
186 |
+
xconv = self.cl1(t)
|
187 |
+
t = self.cib(xconv)
|
188 |
+
ssb, ss_p, sst = t.shape
|
189 |
+
x2 = ees.unsqueeze(3).repeat(1, 1, 1, sst).view(ssb, ss_p, -1)
|
190 |
+
t = t.unsqueeze(2).repeat(1, 1, sl, 1).view(ssb, ss_p, -1)
|
191 |
+
t = torch.cat([t, x2], dim=1)
|
192 |
+
t = self.cpr(t)
|
193 |
+
t = t.view(ssb, ss_p, sl, sst)
|
194 |
+
t = torch.max(t, dim=2).values
|
195 |
+
t = self.cma(t, xconv)
|
196 |
+
t = torch.cat((tp, t), dim = 2)
|
197 |
+
xconv = self.claa(t)
|
198 |
+
t = self.cib(xconv)
|
199 |
+
t = t.unsqueeze(2).repeat(1, 1, sl, 1).view(ssb, ss_p, -1)
|
200 |
+
t = torch.cat([t, x2], dim=1)
|
201 |
+
t = self.cpr(t)
|
202 |
+
t = t.view(ssb, ss_p, sl, sst)
|
203 |
+
t = torch.max(t, dim=2).values
|
204 |
+
t = self.cma(t, xconv)
|
205 |
+
t = torch.cat((tp, t), dim = 2)
|
206 |
+
xconv = self.clfa(t)
|
207 |
+
t = self.cib(xconv)
|
208 |
+
t = t.unsqueeze(2).repeat(1, 1, sl, 1).view(ssb, ss_p, -1)
|
209 |
+
t = torch.cat([t, x2], dim=1)
|
210 |
+
t = self.cpr(t)
|
211 |
+
t = t.view(ssb, ss_p, sl, sst)
|
212 |
+
t = torch.max(t, dim=2).values
|
213 |
+
t = self.cma(t, xconv)
|
214 |
+
t = torch.cat((tp, t), dim = 2)
|
215 |
+
xconv = self.ccac(t)
|
216 |
+
t = self.cib(xconv)
|
217 |
+
t = t.unsqueeze(2).repeat(1, 1, sl, 1).view(ssb, ss_p, -1)
|
218 |
+
t = torch.cat([t, x2], dim=1)
|
219 |
+
t = self.cpr(t)
|
220 |
+
t = t.view(ssb, ss_p, sl, sst)
|
221 |
+
t = torch.max(t, dim=2).values
|
222 |
+
t = self.cma(t, xconv)
|
223 |
+
pss = t + sst
|
224 |
+
ccf_out = self.ccf(t)
|
225 |
+
output = self.fff(ccf_out.permute(0, 2, 1))
|
226 |
+
|
227 |
+
return output, pre_d, ccf_out
|
228 |
+
|
229 |
+
class seq2seq(nn.Module):
|
230 |
+
def __init__(self, encoder, decoder):
|
231 |
+
super().__init__()
|
232 |
+
|
233 |
+
self.encoder = encoder
|
234 |
+
self.decoder = decoder
|
235 |
+
|
236 |
+
def forward(self, src, trg):
|
237 |
+
enc_out, _ = self.encoder(src)
|
238 |
+
output, _, _ = self.decoder(trg, enc_out)
|
239 |
+
return output
|
240 |
+
|
news-comm-v15/news-comm-v15-all-test.en
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:757cea85bddca13bdbb0d4dbc187f748d3b97a4e04a5360b6ce7235c38b85261
|
3 |
+
size 562915
|
news-comm-v15/news-comm-v15-all-test.es
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5f459d6c1333abd7c545e0fd140e248dcfd05135562f25e14ffc6a98d3bccaa5
|
3 |
+
size 654959
|
news-comm-v15/news-comm-v15-all-valid.en
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3973e022e93220f9212c18d0d0c543ae7c309e46640da93a4a0314de999f5112
|
3 |
+
size 1
|
news-comm-v15/news-comm-v15-all-valid.es
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3973e022e93220f9212c18d0d0c543ae7c309e46640da93a4a0314de999f5112
|
3 |
+
size 1
|
news-comm-v15/news-comm-v15-all.en
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2e0bfde74c1665f5b44edfe370780d9cffc413768a6ad2e1530e1e42d0b77ae2
|
3 |
+
size 201
|
news-comm-v15/news-comm-v15-all.es
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f0cbc37784a40546152cd146c8f4468e44bb4a23921c51d19b1309fbd0e63200
|
3 |
+
size 259
|
news-comm-v15/readme
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
Test data sampled from:
|
2 |
+
https://data.statmt.org/news-commentary/v15/training/news-commentary-v15.en-es.tsv.gz
|
utils.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
###########################################################################
|
2 |
+
# NLP demo software by HyperbeeAI. #
|
3 |
+
# Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai #
|
4 |
+
###########################################################################
|
5 |
+
license_statement = "NLP demo software by HyperbeeAI. Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. main@shallow.ai"
|
6 |
+
print("imported utils.py")
|
7 |
+
print(license_statement)
|
8 |
+
print("")
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import layers
|
12 |
+
from tokenizers import Tokenizer
|
13 |
+
import time, torch, datasets
|
14 |
+
from tqdm import tqdm
|
15 |
+
|
16 |
+
tokenizer_en = None
|
17 |
+
tokenizer_es = None
|
18 |
+
|
19 |
+
def tokenize_es(text):
|
20 |
+
return tokenizer_es.encode(text).ids[:48 - 2]
|
21 |
+
|
22 |
+
def tokenize_en(text):
|
23 |
+
return tokenizer_en.encode(text).ids[:48 - 1]
|
24 |
+
|
25 |
+
def translate_sentence(sentence, src_field, trg_field, model, device):
|
26 |
+
|
27 |
+
model.eval()
|
28 |
+
if isinstance(sentence, str):
|
29 |
+
tokens = tokenize_es(sentence)
|
30 |
+
else:
|
31 |
+
tokens = sentence
|
32 |
+
|
33 |
+
tokens = [src_field.init_token] + tokens + [src_field.eos_token] + [src_field.pad_token] * (48 - 2 - len(tokens))
|
34 |
+
src_tensor = torch.LongTensor(tokens).unsqueeze(0).to(device)
|
35 |
+
|
36 |
+
with torch.no_grad():
|
37 |
+
enc_out, _ = model.encoder(src_tensor)
|
38 |
+
|
39 |
+
trg_indexes = [trg_field.init_token, ] + [trg_field.pad_token] * (48 - 1)
|
40 |
+
|
41 |
+
for i in range(48 - 1):
|
42 |
+
start_idx = max(0, i - 7)
|
43 |
+
|
44 |
+
trg_tensor = torch.LongTensor(trg_indexes[start_idx:start_idx + 8]).unsqueeze(0).to(device)
|
45 |
+
|
46 |
+
with torch.no_grad():
|
47 |
+
output, _, _ = model.decoder(trg_tensor, enc_out, max(0, i - 7))
|
48 |
+
|
49 |
+
pred_token = output.argmax(2)[:, min(i, 7)].item()
|
50 |
+
trg_indexes[i + 1] = pred_token
|
51 |
+
if pred_token == trg_field.eos_token:
|
52 |
+
break
|
53 |
+
|
54 |
+
try:
|
55 |
+
trg_indexes = trg_indexes[1:trg_indexes.index(trg_field.eos_token)]
|
56 |
+
except ValueError:
|
57 |
+
trg_indexes = trg_indexes[1:]
|
58 |
+
|
59 |
+
trg_tokens = tokenizer_en.decode(trg_indexes, skip_special_tokens=False)
|
60 |
+
|
61 |
+
return trg_tokens
|
62 |
+
|
63 |
+
|
64 |
+
def calculate_bleu(data, src_field, trg_field, model, device, spiece=False, output_file = f"test.{time.time()}.out"):
|
65 |
+
|
66 |
+
if spiece:
|
67 |
+
from tokenizers import pre_tokenizers
|
68 |
+
pre_tokenizer = pre_tokenizers.Digits(individual_digits=True)
|
69 |
+
else:
|
70 |
+
pre_tokenizer = tokenizer_en.pre_tokenizer
|
71 |
+
|
72 |
+
trgs = []
|
73 |
+
pred_trgs = []
|
74 |
+
print('Evaluate on bleu:')
|
75 |
+
for src, trg in tqdm(zip(open("news-comm-v15/news-comm-v15-all-test.es"), open("news-comm-v15/news-comm-v15-all-test.en"))):
|
76 |
+
|
77 |
+
if len(src) < 3 or len(trg) < 3:
|
78 |
+
continue
|
79 |
+
|
80 |
+
normalized = pre_tokenizer.pre_tokenize_str(tokenizer_en.normalizer.normalize_str(trg))
|
81 |
+
|
82 |
+
if len(normalized) > 48:
|
83 |
+
continue
|
84 |
+
|
85 |
+
trgs.append([ " ".join(map(lambda x: x[0], normalized)) ])
|
86 |
+
|
87 |
+
pred_trg = translate_sentence(src, src_field, trg_field, model, device)
|
88 |
+
pred_trgs.append(pred_trg)
|
89 |
+
|
90 |
+
|
91 |
+
with open(output_file, "w") as fo:
|
92 |
+
fo.write("\n".join(pred_trgs))
|
93 |
+
|
94 |
+
sacrebleu = datasets.load_metric('sacrebleu')
|
95 |
+
return sacrebleu.compute(predictions=pred_trgs, references=trgs)
|
96 |
+
|
97 |
+
tokenizer_es = Tokenizer.from_file(f"assets/es.json")
|
98 |
+
tokenizer_en = Tokenizer.from_file(f"assets/en.json")
|
99 |
+
TRG_PAD_IDX = tokenizer_en.token_to_id("<PAD>")
|