{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "8e94ea44", "metadata": {}, "outputs": [], "source": [ "# TODO: load large text dataset like OSCAR\n", "all_sentences_de = [\"Über vier Jahrzehnte gehörte er zu den führenden Bildhauern Niederbayerns\", \"die katze ist niedlich\"] * 1000" ] }, { "cell_type": "code", "execution_count": 2, "id": "e9db6478", "metadata": {}, "outputs": [], "source": [ "from huggingface_hub import snapshot_download\n", "data_folder = snapshot_download(\"fxtentacle/tevr-token-entropy-predictor-de\")" ] }, { "cell_type": "code", "execution_count": 3, "id": "8b37a91c", "metadata": {}, "outputs": [], "source": [ "from transformers import T5ForConditionalGeneration\n", "model = T5ForConditionalGeneration.from_pretrained(data_folder)\n", "model.to('cuda')\n", "model.eval()\n", "None" ] }, { "cell_type": "code", "execution_count": 4, "id": "317a0bb2", "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "def text_to_cross_entropy(text):\n", " ttext = torch.tensor([[0]+list(text.encode('UTF-8'))],dtype=torch.int64).to('cuda')\n", " tone = torch.tensor([[1]],dtype=torch.int32).to('cuda')\n", " logits = model.forward(input_ids=tone, attention_mask=tone, decoder_input_ids=ttext, return_dict=False)[0].detach()\n", " cross_entropy = torch.nn.functional.cross_entropy(input=logits[0][:-1], target=ttext[0][1:], reduction='none').detach().cpu().numpy()\n", " return cross_entropy" ] }, { "cell_type": "code", "execution_count": 5, "id": "aec4c1e1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Über vier Jahrzehnte gehörte er zu den führenden Bildhauern Niederbayerns\n", "Ü 7.254014\n", "b 0.17521738\n", "e 0.00046933602\n", "r 0.01929327\n", " 0.0003675739\n", "v 0.20927554\n", "i 6.13207\n", "e 0.3896482\n", "r 0.009583538\n", " 2.07364\n", "J 0.02978594\n", "a 2.483246\n", "h 0.1591908\n", "r 0.0045124847\n", "z 0.00028653807\n", "e 4.0242333\n", "h 0.031035878\n", "n 0.028907888\n", "t 0.003264101\n", "e 0.0018929198\n", " 0.05816966\n", "g 1.2782481\n", "e 3.5076692\n", "h 0.694337\n", "ö 0.5319732\n", "r 0.48336726\n", "t 0.0050443523\n", "e 0.0017187123\n", " 0.14511283\n", "e 1.0435015\n", "r 0.18165778\n", " 1.0247636\n", "z 0.3594512\n", "u 0.0077577736\n", " 2.072764\n", "d 0.17377533\n", "e 1.0727838\n", "n 1.2805216\n", " 0.24939628\n", "f 0.27717885\n", "ü 0.012466482\n", "h 4.4356546\n", "r 1.7371752\n", "e 0.051492628\n", "n 2.99407\n", "d 0.009648594\n", "e 0.19667451\n", "n 0.007495021\n", " 0.2529005\n", "B 0.004451485\n", "i 0.024661187\n", "l 0.0028436247\n", "d 2.6620464\n", "h 2.825038\n", "a 0.8215449\n", "u 0.011406565\n", "e 2.9599652\n", "r 0.45834702\n", "n 0.11848967\n", " 0.5955992\n", "N 0.010709903\n", "i 1.5338714\n", "e 0.1834471\n", "d 5.668945\n", "e 2.052247\n", "r 0.7692907\n", "b 0.0675718\n", "a 0.028234791\n", "y 0.0045266068\n", "e 4.1125383\n", "r 1.2630856\n", "n 5.436057\n", "s 0.46446246\n" ] } ], "source": [ "text = all_sentences_de[0]\n", "cross_entropy = text_to_cross_entropy(text)\n", "print(text)\n", "for i in range(len(text)):\n", " print(text[i], cross_entropy[i])" ] }, { "cell_type": "code", "execution_count": 6, "id": "57350f0e", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 2000/2000 [00:09<00:00, 219.00it/s]\n" ] } ], "source": [ "from tqdm import tqdm \n", "\n", "sentence_data = all_sentences_de\n", "\n", "text_and_entropies = []\n", "for text in tqdm(sentence_data):\n", " 
    text_and_entropies.append([text,text_to_cross_entropy(text)])" ] },
{ "cell_type": "code", "execution_count": 7, "id": "502fdacc", "metadata": {}, "outputs": [
{ "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 1999/1999 [00:00<00:00, 14645.88it/s]\n" ] },
{ "name": "stdout", "output_type": "stream", "text": [ "[('lich', 1000), ('hnte', 999), ('rbay', 999), ('örte', 999), ('hört', 999), ('ahrz', 999), ('jahr', 999), ('bild', 999)]\n" ] },
{ "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 1999/1999 [00:00<00:00, 18574.04it/s]\n" ] },
{ "name": "stdout", "output_type": "stream", "text": [ "[('ist', 1000), ('den', 999), ('ber', 999), ('aue', 999), ('ern', 999), ('uer', 999)]\n" ] },
{ "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 1999/1999 [00:00<00:00, 20827.32it/s]\n" ] },
{ "name": "stdout", "output_type": "stream", "text": [ "[('ni', 1000), ('ge', 999), ('er', 999), ('fü', 999), ('vi', 999)]\n" ] },
{ "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 1999/1999 [00:00<00:00, 19927.45it/s]" ] },
{ "name": "stdout", "output_type": "stream", "text": [ "[('e', 2999), ('u', 999), ('n', 999), ('h', 999)]\n" ] },
{ "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [
"from collections import Counter\n", "\n",
"# pick one of the three preset pairs of n-gram lengths and per-length token budgets\n",
"# 4s\n", "#target_lengths = [1]\n", "#token_budgets = [36]\n", "\n",
"# 4m\n", "target_lengths = [4,3,2,1]\n", "token_budgets = [40,80,96,36]\n", "\n",
"# 4l\n", "#target_lengths = [4,3,2,1]\n", "#token_budgets = [384,320,160,36]\n", "\n",
"ngrams = [Counter() for l in target_lengths]\n", "tokens = []\n", "\n",
"for tgi,tgl in enumerate(target_lengths):\n",
"    for row in tqdm(text_and_entropies[1:]):\n",
"        use_text = row[0]\n",
"        use_scores = row[1]\n",
"        # mask out characters already covered by previously selected tokens\n",
"        for t in tokens:\n",
"            use_text = use_text.replace(t[0],'#')\n",
"        candidates = []\n",
"        for i in range(len(use_text)-(tgl-1)):\n",
"            part = use_text[i:i+tgl].lower()\n",
"            if '#' in part: continue\n",
"            if ' ' in part: continue\n",
"            if '-' in part: continue\n",
"            # score = total cross-entropy of the characters in this n-gram\n",
"            score = sum(use_scores[i:i+tgl])\n",
"            # print(part, score)\n",
"            candidates.append([score, part])\n",
"        # keep only the lowest-entropy fifth of this sentence's candidates\n",
"        candidates.sort(reverse=False)\n",
"        candidates = candidates[:max(1,int(len(candidates)/5))]\n",
"        #print(candidates)\n",
"        ngrams[tgi].update([c[1] for c in candidates])\n",
"    # the most frequent surviving n-grams of this length become new tokens\n",
"    new_tokens = ngrams[tgi].most_common(token_budgets[tgi])\n",
"    print(new_tokens)\n",
"    tokens += new_tokens\n",
"    #break" ] },
{ "cell_type": "code", "execution_count": 8, "id": "323833ad", "metadata": {}, "outputs": [
{ "name": "stdout", "output_type": "stream", "text": [ "27 ['', '', ' ', 'lich', 'hnte', 'rbay', 'örte', 'hört', 'ahrz', 'jahr', 'bild', 'ist', 'den', 'ber', 'aue', 'ern', 'uer', 'ni', 'ge', 'er', 'fü', 'vi', 'e', 'u', 'n', 'h', '?']\n" ] } ], "source": [
"all_tokens = ['','',' ']+[t[0] for t in tokens]+['?']\n",
"print(len(all_tokens), all_tokens)" ] },
{ "cell_type": "code", "execution_count": 9, "id": "34724bef", "metadata": {}, "outputs": [], "source": [
"import json\n",
"with open('./tevr-tokenizer.txt','wt') as f:\n",
"    json.dump(all_tokens, f)" ] },
{ "cell_type": "code", "execution_count": 10, "id": "72a32893", "metadata": {}, "outputs": [], "source": [
"import sys\n",
"import os\n",
"# make the text_tokenizer module shipped with the downloaded snapshot importable\n",
"sys.path.append(data_folder)\n",
"from text_tokenizer import HajoTextTokenizer" ] },
{ "cell_type": "code", "execution_count": 11, "id": "a7405c3b", "metadata": {}, "outputs": [], "source": [
"text_tokenizer = HajoTextTokenizer('./tevr-tokenizer.txt')" ] },
{ "cell_type": "code",
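"execution_count": null, "id": "3f9a1b2c", "metadata": {}, "outputs": [], "source": [
"# Optional sanity check, a minimal sketch: the tokenizer file written above is plain JSON,\n",
"# so it can be re-read with nothing but the standard json module and compared to all_tokens.\n",
"import json\n",
"with open('./tevr-tokenizer.txt','rt') as f:\n",
"    loaded_tokens = json.load(f)\n",
"print(len(loaded_tokens), loaded_tokens == all_tokens)" ] },
{ "cell_type": "markdown", "id": "7d4e5f60", "metadata": {}, "source": [ "The next cell round-trips the word *gehörte* through the tokenizer: `encode` maps it onto ids of the learned n-grams and `decode` reconstructs the original string." ] },
{ "cell_type": "code",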
"execution_count": 12, "id": "5ceee8e3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "gehörte\n", "[18, 25, 6]\n", "['ge', 'h', 'örte']\n", "['gehörte']\n" ] } ], "source": [ "sentence = \"gehörte\"\n", "print(sentence)\n", "encoded = text_tokenizer.encode(sentence)\n", "print(encoded)\n", "print([text_tokenizer.all_tokens[i] for i in encoded])\n", "print([text_tokenizer.decode(encoded)])" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" } }, "nbformat": 4, "nbformat_minor": 5 }