fxtentacle committed on
Commit dbeb0d0
1 Parent(s): b9e6d62

Upload Generate TEVR Tokenizer.ipynb

Files changed (1)
  1. Generate TEVR Tokenizer.ipynb +390 -0
Generate TEVR Tokenizer.ipynb ADDED
@@ -0,0 +1,390 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "8e94ea44",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# TODO: load large text dataset like OSCAR\n",
+    "all_sentences_de = [\"Über vier Jahrzehnte gehörte er zu den führenden Bildhauern Niederbayerns\", \"die katze ist niedlich\"] * 1000"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "e9db6478",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from huggingface_hub import snapshot_download\n",
+    "data_folder = snapshot_download(\"fxtentacle/tevr-token-entropy-predictor-de\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "8b37a91c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from transformers import T5ForConditionalGeneration\n",
+    "model = T5ForConditionalGeneration.from_pretrained(data_folder)\n",
+    "model.to('cuda')\n",
+    "model.eval()\n",
+    "None"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "317a0bb2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "\n",
+    "def text_to_cross_entropy(text):\n",
+    "    ttext = torch.tensor([[0]+list(text.encode('UTF-8'))],dtype=torch.int64).to('cuda')\n",
+    "    tone = torch.tensor([[1]],dtype=torch.int32).to('cuda')\n",
+    "    logits = model.forward(input_ids=tone, attention_mask=tone, decoder_input_ids=ttext, return_dict=False)[0].detach()\n",
+    "    cross_entropy = torch.nn.functional.cross_entropy(input=logits[0][:-1], target=ttext[0][1:], reduction='none').detach().cpu().numpy()\n",
+    "    return cross_entropy"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "aec4c1e1",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Über vier Jahrzehnte gehörte er zu den führenden Bildhauern Niederbayerns\n",
+      "Ü 7.254014\n",
+      "b 0.17521738\n",
+      "e 0.00046933602\n",
+      "r 0.01929327\n",
+      "  0.0003675739\n",
+      "v 0.20927554\n",
+      "i 6.13207\n",
+      "e 0.3896482\n",
+      "r 0.009583538\n",
+      "  2.07364\n",
+      "J 0.02978594\n",
+      "a 2.483246\n",
+      "h 0.1591908\n",
+      "r 0.0045124847\n",
+      "z 0.00028653807\n",
+      "e 4.0242333\n",
+      "h 0.031035878\n",
+      "n 0.028907888\n",
+      "t 0.003264101\n",
+      "e 0.0018929198\n",
+      "  0.05816966\n",
+      "g 1.2782481\n",
+      "e 3.5076692\n",
+      "h 0.694337\n",
+      "ö 0.5319732\n",
+      "r 0.48336726\n",
+      "t 0.0050443523\n",
+      "e 0.0017187123\n",
+      "  0.14511283\n",
+      "e 1.0435015\n",
+      "r 0.18165778\n",
+      "  1.0247636\n",
+      "z 0.3594512\n",
+      "u 0.0077577736\n",
+      "  2.072764\n",
+      "d 0.17377533\n",
+      "e 1.0727838\n",
+      "n 1.2805216\n",
+      "  0.24939628\n",
+      "f 0.27717885\n",
+      "ü 0.012466482\n",
+      "h 4.4356546\n",
+      "r 1.7371752\n",
+      "e 0.051492628\n",
+      "n 2.99407\n",
+      "d 0.009648594\n",
+      "e 0.19667451\n",
+      "n 0.007495021\n",
+      "  0.2529005\n",
+      "B 0.004451485\n",
+      "i 0.024661187\n",
+      "l 0.0028436247\n",
+      "d 2.6620464\n",
+      "h 2.825038\n",
+      "a 0.8215449\n",
+      "u 0.011406565\n",
+      "e 2.9599652\n",
+      "r 0.45834702\n",
+      "n 0.11848967\n",
+      "  0.5955992\n",
+      "N 0.010709903\n",
+      "i 1.5338714\n",
+      "e 0.1834471\n",
+      "d 5.668945\n",
+      "e 2.052247\n",
+      "r 0.7692907\n",
+      "b 0.0675718\n",
+      "a 0.028234791\n",
+      "y 0.0045266068\n",
+      "e 4.1125383\n",
+      "r 1.2630856\n",
+      "n 5.436057\n",
+      "s 0.46446246\n"
+     ]
+    }
+   ],
+   "source": [
+    "text = all_sentences_de[0]\n",
+    "cross_entropy = text_to_cross_entropy(text)\n",
+    "print(text)\n",
+    "for i in range(len(text)):\n",
+    "    print(text[i], cross_entropy[i])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "57350f0e",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "100%|██████████| 2000/2000 [00:09<00:00, 219.00it/s]\n"
+     ]
+    }
+   ],
+   "source": [
+    "from tqdm import tqdm \n",
+    "\n",
+    "sentence_data = all_sentences_de\n",
+    "\n",
+    "text_and_entropies = []\n",
+    "for text in tqdm(sentence_data):\n",
+    "    text_and_entropies.append([text,text_to_cross_entropy(text)])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "502fdacc",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "100%|██████████| 1999/1999 [00:00<00:00, 14645.88it/s]\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[('lich', 1000), ('hnte', 999), ('rbay', 999), ('örte', 999), ('hört', 999), ('ahrz', 999), ('jahr', 999), ('bild', 999)]\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "100%|██████████| 1999/1999 [00:00<00:00, 18574.04it/s]\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[('ist', 1000), ('den', 999), ('ber', 999), ('aue', 999), ('ern', 999), ('uer', 999)]\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "100%|██████████| 1999/1999 [00:00<00:00, 20827.32it/s]\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[('ni', 1000), ('ge', 999), ('er', 999), ('fü', 999), ('vi', 999)]\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "100%|██████████| 1999/1999 [00:00<00:00, 19927.45it/s]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[('e', 2999), ('u', 999), ('n', 999), ('h', 999)]\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "from collections import Counter\n",
+    "\n",
+    "# 4s\n",
+    "#target_lengths = [1]\n",
+    "#token_budgets = [36]\n",
+    "\n",
+    "# 4m\n",
+    "target_lengths = [4,3,2,1]\n",
+    "token_budgets = [40,80,96,36]\n",
+    "\n",
+    "# 4l\n",
+    "#target_lengths = [4,3,2,1]\n",
+    "#token_budgets = [384,320,160,36]\n",
+    "\n",
+    "ngrams = [Counter() for l in target_lengths]\n",
+    "tokens = []\n",
+    "\n",
+    "for tgi,tgl in enumerate(target_lengths):\n",
+    "    for row in tqdm(text_and_entropies[1:]):\n",
+    "        use_text = row[0]\n",
+    "        use_scores = row[1]\n",
+    "        for t in tokens:\n",
+    "            use_text = use_text.replace(t[0],'#')\n",
+    "        candidates = []\n",
+    "        for i in range(len(use_text)-(tgl-1)):\n",
+    "            part = use_text[i:i+tgl].lower()\n",
+    "            if '#' in part: continue\n",
+    "            if ' ' in part: continue\n",
+    "            if '-' in part: continue\n",
+    "            score = sum(use_scores[i:i+tgl])\n",
+    "            # print(part, score)\n",
+    "            candidates.append([score, part])\n",
+    "        candidates.sort(reverse=False)\n",
+    "        candidates = candidates[:max(1,int(len(candidates)/5))]\n",
+    "        #print(candidates)\n",
+    "        ngrams[tgi].update([c[1] for c in candidates])\n",
+    "    new_tokens = ngrams[tgi].most_common(token_budgets[tgi])\n",
+    "    print(new_tokens)\n",
+    "    tokens += new_tokens\n",
+    "    #break"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "323833ad",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "27 ['<pad>', '<eos>', ' ', 'lich', 'hnte', 'rbay', 'örte', 'hört', 'ahrz', 'jahr', 'bild', 'ist', 'den', 'ber', 'aue', 'ern', 'uer', 'ni', 'ge', 'er', 'fü', 'vi', 'e', 'u', 'n', 'h', '?']\n"
+     ]
+    }
+   ],
+   "source": [
+    "all_tokens = ['<pad>','<eos>',' ']+[t[0] for t in tokens]+['?']\n",
+    "print(len(all_tokens), all_tokens)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "34724bef",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import json\n",
+    "with open('./tevr-tokenizer.txt','wt') as f:\n",
+    "    json.dump(all_tokens, f)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "id": "72a32893",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import sys\n",
+    "import os\n",
+    "sys.path.append(data_folder)\n",
+    "from text_tokenizer import HajoTextTokenizer"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "id": "a7405c3b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "text_tokenizer = HajoTextTokenizer('./tevr-tokenizer.txt')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "id": "5ceee8e3",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "gehörte\n",
+      "[18, 25, 6]\n",
+      "['ge', 'h', 'örte']\n",
+      "['gehörte']\n"
+     ]
+    }
+   ],
+   "source": [
+    "sentence = \"gehörte\"\n",
+    "print(sentence)\n",
+    "encoded = text_tokenizer.encode(sentence)\n",
+    "print(encoded)\n",
+    "print([text_tokenizer.all_tokens[i] for i in encoded])\n",
+    "print([text_tokenizer.decode(encoded)])"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.5"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
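
Note on the final cells: HajoTextTokenizer ships as text_tokenizer.py inside the downloaded snapshot, and its implementation is not part of this commit. The sketch below is a hypothetical illustration of how the saved token list could produce the split shown above ([18, 25, 6] -> ['ge', 'h', 'örte']): it applies tokens in vocabulary-list order rather than greedy longest-match. This is an assumption for illustration only, not the shipped HajoTextTokenizer code; the vocabulary is the list printed by cell 8.

# Hypothetical illustration only -- the real encoder is HajoTextTokenizer from
# text_tokenizer.py in the downloaded snapshot, which may be implemented differently.
# The vocabulary below is the list printed by the notebook's cell 8.
all_tokens = ['<pad>', '<eos>', ' ', 'lich', 'hnte', 'rbay', 'örte', 'hört', 'ahrz', 'jahr',
              'bild', 'ist', 'den', 'ber', 'aue', 'ern', 'uer', 'ni', 'ge', 'er',
              'fü', 'vi', 'e', 'u', 'n', 'h', '?']

def encode(text, vocab):
    # Split the lowercased text on each token in list order, so tokens that appear
    # earlier in the vocabulary (the longer, higher-budget n-grams) take priority
    # over later, shorter ones. (Lowercasing is an assumption of this sketch.)
    pieces, ids = [text.lower()], [None]          # None = piece not yet tokenized
    for token_id, token in enumerate(vocab):
        if token in ('<pad>', '<eos>'):
            continue
        next_pieces, next_ids = [], []
        for piece, pid in zip(pieces, ids):
            if pid is not None or token not in piece:
                next_pieces.append(piece)
                next_ids.append(pid)
                continue
            parts = piece.split(token)
            for i, part in enumerate(parts):
                if part:
                    next_pieces.append(part)
                    next_ids.append(None)
                if i < len(parts) - 1:
                    next_pieces.append(token)
                    next_ids.append(token_id)
        pieces, ids = next_pieces, next_ids
    unknown_id = vocab.index('?')                 # fallback for anything never matched
    return [pid if pid is not None else unknown_id for pid in ids]

print(encode("gehörte", all_tokens))              # [18, 25, 6] -> ['ge', 'h', 'örte']

Applying tokens in list order matters here: a plain greedy longest-match scan would consume 'hört' inside 'gehörte' and yield ['ge', 'hört', 'e'] instead of the ['ge', 'h', 'örte'] split shown in the last output.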