endre sukosd committed on
Commit
3992084
1 Parent(s): 4b647de

Semantic Search HU implementation

.gitattributes CHANGED
@@ -25,3 +25,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
25
  *.zip filter=lfs diff=lfs merge=lfs -text
26
  *.zstandard filter=lfs diff=lfs merge=lfs -text
27
  *tfevents* filter=lfs diff=lfs merge=lfs -text
28
+ data/processed/shortened_abstracts_hu_2021_09_01.txt filter=lfs diff=lfs merge=lfs -text
29
+ data/processed/shortened_abstracts_hu_2021_09_01_embedded.pt filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,191 @@
1
+ # Custom
2
+ hf_venv/
3
+ data/
4
+ *.DS_Store
5
+
6
+ # Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,jupyternotebooks,venv
7
+ # Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode,python,jupyternotebooks,venv
8
+
9
+ ### JupyterNotebooks ###
10
+ # gitignore template for Jupyter Notebooks
11
+ # website: http://jupyter.org/
12
+
13
+ .ipynb_checkpoints
14
+ */.ipynb_checkpoints/*
15
+
16
+ # IPython
17
+ profile_default/
18
+ ipython_config.py
19
+
20
+ # Remove previous ipynb_checkpoints
21
+ # git rm -r .ipynb_checkpoints/
22
+
23
+ ### Python ###
24
+ # Byte-compiled / optimized / DLL files
25
+ __pycache__/
26
+ *.py[cod]
27
+ *$py.class
28
+
29
+ # C extensions
30
+ *.so
31
+
32
+ # Distribution / packaging
33
+ .Python
34
+ build/
35
+ develop-eggs/
36
+ dist/
37
+ downloads/
38
+ eggs/
39
+ .eggs/
40
+ lib/
41
+ lib64/
42
+ parts/
43
+ sdist/
44
+ var/
45
+ wheels/
46
+ share/python-wheels/
47
+ *.egg-info/
48
+ .installed.cfg
49
+ *.egg
50
+ MANIFEST
51
+
52
+ # PyInstaller
53
+ # Usually these files are written by a python script from a template
54
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
55
+ *.manifest
56
+ *.spec
57
+
58
+ # Installer logs
59
+ pip-log.txt
60
+ pip-delete-this-directory.txt
61
+
62
+ # Unit test / coverage reports
63
+ htmlcov/
64
+ .tox/
65
+ .nox/
66
+ .coverage
67
+ .coverage.*
68
+ .cache
69
+ nosetests.xml
70
+ coverage.xml
71
+ *.cover
72
+ *.py,cover
73
+ .hypothesis/
74
+ .pytest_cache/
75
+ cover/
76
+
77
+ # Translations
78
+ *.mo
79
+ *.pot
80
+
81
+ # Django stuff:
82
+ *.log
83
+ local_settings.py
84
+ db.sqlite3
85
+ db.sqlite3-journal
86
+
87
+ # Flask stuff:
88
+ instance/
89
+ .webassets-cache
90
+
91
+ # Scrapy stuff:
92
+ .scrapy
93
+
94
+ # Sphinx documentation
95
+ docs/_build/
96
+
97
+ # PyBuilder
98
+ .pybuilder/
99
+ target/
100
+
101
+ # Jupyter Notebook
102
+
103
+ # IPython
104
+
105
+ # pyenv
106
+ # For a library or package, you might want to ignore these files since the code is
107
+ # intended to run in multiple environments; otherwise, check them in:
108
+ # .python-version
109
+
110
+ # pipenv
111
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
112
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
113
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
114
+ # install all needed dependencies.
115
+ #Pipfile.lock
116
+
117
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
118
+ __pypackages__/
119
+
120
+ # Celery stuff
121
+ celerybeat-schedule
122
+ celerybeat.pid
123
+
124
+ # SageMath parsed files
125
+ *.sage.py
126
+
127
+ # Environments
128
+ .env
129
+ .venv
130
+ env/
131
+ venv/
132
+ ENV/
133
+ env.bak/
134
+ venv.bak/
135
+
136
+ # Spyder project settings
137
+ .spyderproject
138
+ .spyproject
139
+
140
+ # Rope project settings
141
+ .ropeproject
142
+
143
+ # mkdocs documentation
144
+ /site
145
+
146
+ # mypy
147
+ .mypy_cache/
148
+ .dmypy.json
149
+ dmypy.json
150
+
151
+ # Pyre type checker
152
+ .pyre/
153
+
154
+ # pytype static type analyzer
155
+ .pytype/
156
+
157
+ # Cython debug symbols
158
+ cython_debug/
159
+
160
+ ### venv ###
161
+ # Virtualenv
162
+ # http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
163
+ [Bb]in
164
+ [Ii]nclude
165
+ [Ll]ib
166
+ [Ll]ib64
167
+ [Ll]ocal
168
+ [Ss]cripts
169
+ pyvenv.cfg
170
+ pip-selfcheck.json
171
+
172
+ ### VisualStudioCode ###
173
+ .vscode/*
174
+ !.vscode/settings.json
175
+ !.vscode/tasks.json
176
+ !.vscode/launch.json
177
+ !.vscode/extensions.json
178
+ *.code-workspace
179
+
180
+ # Local History for Visual Studio Code
181
+ .history/
182
+
183
+ ### VisualStudioCode Patch ###
184
+ # Ignore all local history of files
185
+ .history
186
+ .ionide
187
+
188
+ # Support for Project snippet scope
189
+ !.vscode/*.code-snippets
190
+
191
+ # End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,jupyternotebooks,venv
README.md CHANGED
@@ -1,37 +1,66 @@
1
  ---
2
  title: SemanticSearch HU
 
3
  emoji: 💻
4
- colorFrom: red
5
- colorTo: indigo
 
 
 
6
  sdk: streamlit
7
- app_file: app.py
 
 
8
  pinned: false
 
9
  ---
10
 
11
- # Configuration
12
 
13
- `title`: _string_
14
- Display title for the Space
15
 
16
- `emoji`: _string_
17
- Space emoji (emoji-only character allowed)
18
 
19
- `colorFrom`: _string_
20
- Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
21
 
22
- `colorTo`: _string_
23
- Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
 
 
 
 
 
24
 
25
- `sdk`: _string_
26
- Can be either `gradio` or `streamlit`
27
 
28
- `sdk_version` : _string_
29
- Only applicable for `streamlit` SDK.
30
- See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
31
 
32
- `app_file`: _string_
33
- Path to your main application file (which contains either `gradio` or `streamlit` Python code).
34
- Path is relative to the root of the repository.
35
 
36
- `pinned`: _boolean_
37
- Whether the Space stays on top of your list.
1
  ---
2
  title: SemanticSearch HU
3
+
4
  emoji: 💻
5
+
6
+ colorFrom: green
7
+
8
+ colorTo: white
9
+
10
  sdk: streamlit
11
+
12
+ app_file: src/app.py
13
+
14
  pinned: false
15
+
16
  ---
17
 
18
+ # Hugging Face Course Project - November 2021
19
+ ## Semantic Search system in Hungarian
20
+
21
+ This repo contains my course project, created during the week of the Hugging Face Course Launch Community Event. The selected project is a dense-retrieval-based semantic search system in my own language, Hungarian. It is based on [this question-answering project idea description](https://discuss.huggingface.co/t/build-a-question-answering-system-in-your-own-language/11570/2).
22
+
23
+ ## Approach
24
+ - finding a **dataset** of question/answer pairs or descriptive paragraphs in my target language (Hungarian)
25
+ - using a **pretrained model** (Hungarian or multilingual) to generate embeddings for all answers, preferably using sentence-transformers
26
+ - **searching for top-K matches** - when a user query is entered, generate the query embedding and search through all the answer embeddings to find the top-K most likely documents
27
+
28
+ ## Dataset - raw text
29
+
30
+ Two datasets were evaluated:
31
+ 1. [not used] [MQA - multilingual Question-Answering](https://huggingface.co/datasets/clips/mqa), with a Hungarian subset
32
+
33
+ This dataset contains two types of data:
34
+ * FAQ: about 800.000 questions and answers scraped from different websites (Common Crawl). The problem with this dataset is that it only contains text from roughly 2.000 different domains (so many of the questions and answers are repetitive), and the quality of the answers varies greatly; for some domains it is not really relevant (for example, full of URL references).
35
+ * CQA: about 27.000 community question answering examples scraped from different forums. Here, for every question there are several answers, but again the quality of the answers varies greatly, with many answers not being relevant.
36
+
37
+ 2. **[used] [DBpedia - short abstracts in Hungarian](https://databus.dbpedia.org/dbpedia/text/short-abstracts)**
38
+
39
+ This data contains 450.000 shortened abstracts from Wikipedia in Hungarian. Each is the text before the table of contents of a Wikipedia article, shortened to approximately 2-3 sentences. These texts seemed like high-quality paragraphs, so I decided to use them as a bank of "answers".
40
+
41
+ The data is in the RDF Turtle format (Resource Description Framework), a rich format for relating metadata and modelling information. In our case we only need a fraction of this data: the pure text of each abstract. The raw text was extracted using the `rdflib` library, as seen in the script `src/data/dbpedia_dump_wiki_text.py`.
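A minimal sketch of that extraction (the dump file name is illustrative, and this sketch keeps every object literal rather than filtering on a specific predicate):

```python
from rdflib import Graph

# Parse the DBpedia short-abstracts Turtle dump (file name is illustrative).
g = Graph()
g.parse("short-abstracts_lang=hu.ttl", format="turtle")

# Keep only the plain text of each abstract, i.e. the object literal of each triple.
with open("data/processed/shortened_abstracts_hu_2021_09_01.txt", "w", encoding="utf-8") as out:
    for _subject, _predicate, abstract in g:
        out.write(str(abstract).replace("\n", " ") + "\n")
```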
42
+
43
+ ## Model - precalculate embeddings
44
 
45
+ To generate the embeddings for each paragraph/shortened abstract, a sentence-embedding approach was used. [SBERT.net](https://www.sbert.net/index.html) offers a framework and many pretrained models, in more than 100 languages, for creating embeddings and comparing them to find the ones with similar meaning.
 
46
 
47
+ This task is also called STS (Semantic Textual Similarity) or semantic search: it seeks to find similarity not just through lexical matches but by comparing vector representations of the content, thus improving accuracy.
 
48
 
49
+ There were various [pretrained models](https://www.sbert.net/docs/pretrained_models.html#sentence-embedding-models) to choose from. For this project the **`paraphrase-multilingual-MiniLM-L12-v2`** checkpoint is used: at 418 MB it is one of the smallest multilingual models, yet it has the second-fastest encoding speed, which seems like a good compromise.
 
50
 
51
+ ```
52
+ Model facts:
53
+ - Checkpoint name: paraphrase-multilingual-MiniLM-L12-v2
54
+ - Dimensions: 384
55
+ - Suitable Score Functions: cosine-similarity
56
+ - Pooling: Mean Pooling
57
+ ```
58
 
59
+ - Embeddings were calculated based on code examples from [huggingface hub](https://huggingface.co/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2)
60
+ - Similarity scores were calculated based on code example from [sentence-transformers site](https://www.sbert.net/examples/applications/semantic-search/README.html)
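Condensed from the precalculation notebook, the embedding step looks roughly like this:

```python
import torch
from transformers import AutoTokenizer, AutoModel

checkpoint = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModel.from_pretrained(checkpoint)

def mean_pooling(model_output, attention_mask):
    # Average the token embeddings, using the attention mask so that
    # padding tokens do not contribute to the sentence embedding.
    token_embeddings = model_output[0]
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)

def embed(sentences):
    encoded = tokenizer(sentences, padding=True, truncation=True,
                        max_length=128, return_tensors="pt")
    with torch.no_grad():
        model_output = model(**encoded)
    return mean_pooling(model_output, encoded["attention_mask"])
```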
61
 
62
+ To reproduce the precalculated embeddings, use the notebook `notebooks/QA_retrieval_precalculate_embeddings.ipynb` with a GPU in Google Colab.
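The notebook embeds the raw text in batches of 1000 lines and appends each batch of embeddings to a pickle file, so an interrupted run loses at most one batch. A condensed sketch, reusing the `embed` helper above:

```python
import pickle

BATCH_SIZE = 1000  # batch size used in the notebook

def save_batch(embeddings, path):
    # Append each batch; the file is later rebuilt with repeated pickle.load calls.
    with open(path, "ab") as f:
        pickle.dump(embeddings, f, protocol=pickle.HIGHEST_PROTOCOL)

with open("shortened_abstracts_hu_2021_09_01.txt", encoding="utf-8") as f:
    batch = []
    for line in f:
        batch.append(line)
        if len(batch) == BATCH_SIZE:
            save_batch(embed(batch), "embeddings.pkl")
            batch = []
    if batch:  # final partial batch
        save_batch(embed(batch), "embeddings.pkl")
```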
 
 
63
 
64
+ ## Search for top-K matches
 
 
65
 
66
+ Finally, having all the precalculated embeddings, we can implement semantic search (dense retrieval). We encode the search query into the same vector space and retrieve the document embeddings that are closest to it (using cosine similarity). By default the top 5 most similar Wikipedia abstracts are returned, as can be seen in the main script `src/main_qa.py`.
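A minimal sketch of the search step, reusing the `embed` helper above and assuming the precalculated embeddings have been loaded (e.g. with `torch.load('shortened_abstracts_hu_2021_09_01_embedded.pt')`) into `corpus_embeddings`, with the matching abstracts in the list `all_sentences` (variable names are illustrative):

```python
import torch
from sentence_transformers import util

def find_top_k(query_embedding, corpus_embeddings, k=5):
    # Cosine similarity between the query and every precalculated abstract embedding.
    scores = util.pytorch_cos_sim(query_embedding, corpus_embeddings).squeeze(0)
    top = torch.topk(scores, k)
    return list(zip(top.indices.tolist(), top.values.tolist()))

# Example query from the notebook: "Which is the most populous city in the world?"
query_embedding = embed(["Melyik a legnépesebb város a világon?"])
for idx, score in find_top_k(query_embedding, corpus_embeddings, k=5):
    print(f"{score:.4f}\t{all_sentences[idx]}")
```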
 
approach.txt ADDED
@@ -0,0 +1,49 @@
1
+
2
+ Types of Question Answering
3
+ - extractive question answering (encoder-only models, e.g. BERT)
4
+ - posing questions about a document and identifying the answers as spans of text in the document itself
5
+ - generative question answering (encoder-decoder models, e.g. T5/BART)
6
+ - open-ended questions, which need to synthesize information
7
+ - retrieval-based/community question answering
8
+
9
+
10
+
11
+ First approach - translate dataset, fine-tune model
12
+ !Not really feasible, because it needs lots of human evaluation to correctly determine the answer start token
13
+
14
+ 1. Translate English QA dataset into Hungarian
15
+ - SQuAD - reading comprehension based on Wikipedia articles
16
+ - ~ 100.000 question/answers
17
+ 2. Fine-tune a model and evaluate on this dataset
18
+
19
+
20
+ Second approach - fine-tune multilingual model
21
+ !MQA format is different from SQuAD, so ModelForQuestionAnswering cannot be used
22
+
23
+ 1. Use a Hungarian dataset
24
+ - MQA - multilingual parsed from Common Crawl
25
+ - FAQ - 878.385 (2.415 domain)
26
+ - CQA - 27.639 (171 domain)
27
+ 2. Fine-tune and evaluate a model on this dataset
28
+
29
+
30
+ Possible steps:
31
+ - Use an existing pre-trained model (Hungarian, Romanian, or multilingual) to generate embeddings
32
+ - Select Model:
33
+ - multilingual which includes hu:
34
+ - distiluse-base-multilingual-cased-v2 (400MB)
35
+ - paraphrase-multilingual-MiniLM-L12-v2 (400MB) - fastest
36
+ - paraphrase-multilingual-mpnet-base-v2 (900MB) - best performing
37
+ - hubert
38
+ - Select a dataset
39
+ - use MQA hungarian subset
40
+ - use hungarian wikipedia pages data, split it up
41
+ - DBpedia, shortened abstracts = 500.000
42
+ - Pre-compute embeddings for all answers/paragraphs
43
+ - Compute embedding for incoming query
44
+ - Compare similarity between query embedding and precomputed embeddings
45
+ - return top-3 answers/questions
46
+
47
+ Alternative steps:
48
+ - train a sentence transformer on the Hungarian / Romanian subsets
49
+ - Use the trained sentence transformer to generate embeddings
notebooks/QA_retrieval_precalculate_embeddings.ipynb ADDED
@@ -0,0 +1 @@
 
1
+ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"QA_retrieval_huggingface_couser_2021_Nov.ipynb","provenance":[],"collapsed_sections":[],"mount_file_id":"1e_NcpgIuSh8rfI_Xf16ltcybK8TbgJWB","authorship_tag":"ABX9TyN3TvKBRyS+wRVSLWNFgC+f"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"accelerator":"GPU"},"cells":[{"cell_type":"code","metadata":{"id":"GI4Sz98ItJW7"},"source":["# TPU\n","# !pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.10-cp37-cp37m-linux_x86_64.whl"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"97-OsdFhlD20","executionInfo":{"status":"ok","timestamp":1637680969592,"user_tz":-60,"elapsed":3348,"user":{"displayName":"Sukosd Endre","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GioD4JjUyxYNK5t_w13NsB1TIlQ1P_x6Xj99-re5w=s64","userId":"02963673169135048018"}},"outputId":"c47a98a7-f016-4a4f-827b-edc9229c5eca"},"source":["!pip install transformers sentence_transformers"],"execution_count":9,"outputs":[{"output_type":"stream","name":"stdout","text":["Requirement already satisfied: transformers in /usr/local/lib/python3.7/dist-packages (4.12.5)\n","Requirement already satisfied: sentence_transformers in /usr/local/lib/python3.7/dist-packages (2.1.0)\n","Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.7/dist-packages (from transformers) (6.0)\n","Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers) (2.23.0)\n","Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (2019.12.20)\n","Requirement already satisfied: tokenizers<0.11,>=0.10.1 in /usr/local/lib/python3.7/dist-packages (from transformers) (0.10.3)\n","Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from transformers) (21.3)\n","Requirement already satisfied: huggingface-hub<1.0,>=0.1.0 in /usr/local/lib/python3.7/dist-packages (from transformers) (0.1.2)\n","Requirement already satisfied: sacremoses in /usr/local/lib/python3.7/dist-packages (from transformers) (0.0.46)\n","Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (1.19.5)\n","Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers) (3.4.0)\n","Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from transformers) (4.8.2)\n","Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.7/dist-packages (from transformers) (4.62.3)\n","Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<1.0,>=0.1.0->transformers) (3.10.0.2)\n","Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.0->transformers) (3.0.6)\n","Requirement already satisfied: nltk in /usr/local/lib/python3.7/dist-packages (from sentence_transformers) (3.2.5)\n","Requirement already satisfied: scikit-learn in /usr/local/lib/python3.7/dist-packages (from sentence_transformers) (1.0.1)\n","Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from sentence_transformers) (1.4.1)\n","Requirement already satisfied: sentencepiece in /usr/local/lib/python3.7/dist-packages (from sentence_transformers) (0.1.96)\n","Requirement already satisfied: torchvision 
in /usr/local/lib/python3.7/dist-packages (from sentence_transformers) (0.11.1+cu111)\n","Requirement already satisfied: torch>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from sentence_transformers) (1.10.0+cu111)\n","Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->transformers) (3.6.0)\n","Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from nltk->sentence_transformers) (1.15.0)\n","Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2.10)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2021.10.8)\n","Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (1.24.3)\n","Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (3.0.4)\n","Requirement already satisfied: click in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers) (7.1.2)\n","Requirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers) (1.1.0)\n","Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from scikit-learn->sentence_transformers) (3.0.0)\n","Requirement already satisfied: pillow!=8.3.0,>=5.3.0 in /usr/local/lib/python3.7/dist-packages (from torchvision->sentence_transformers) (7.1.2)\n"]}]},{"cell_type":"code","metadata":{"id":"3-jkyQkdkdPQ","executionInfo":{"status":"ok","timestamp":1637680970023,"user_tz":-60,"elapsed":3,"user":{"displayName":"Sukosd Endre","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GioD4JjUyxYNK5t_w13NsB1TIlQ1P_x6Xj99-re5w=s64","userId":"02963673169135048018"}}},"source":["from transformers import AutoTokenizer, AutoModel\n","import torch\n","import pickle\n","from sentence_transformers import util\n","from datetime import datetime"],"execution_count":10,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"kA2h5mH8m-n8","executionInfo":{"status":"ok","timestamp":1637654036646,"user_tz":-60,"elapsed":26589,"user":{"displayName":"Sukosd Endre","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GioD4JjUyxYNK5t_w13NsB1TIlQ1P_x6Xj99-re5w=s64","userId":"02963673169135048018"}},"outputId":"88fcd97f-276c-4f70-de60-d1c5c9810443"},"source":["from google.colab import drive\n","drive.mount('/content/drive')\n","#drive.mount('/content/drive', force_remount=True)"],"execution_count":3,"outputs":[{"output_type":"stream","name":"stdout","text":["Mounted at /content/drive\n"]}]},{"cell_type":"markdown","metadata":{"id":"b8SkQGWuB1z7"},"source":["# Load pretrained \n","\n","- multilingual sentence transformers from checkpoint\n","- tokenizer from checkpoint"]},{"cell_type":"code","metadata":{"id":"1R83LLVAk98K","executionInfo":{"status":"ok","timestamp":1637655426545,"user_tz":-60,"elapsed":6237,"user":{"displayName":"Sukosd Endre","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GioD4JjUyxYNK5t_w13NsB1TIlQ1P_x6Xj99-re5w=s64","userId":"02963673169135048018"}}},"source":["multilingual_checkpoint = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'\n","tokenizer = AutoTokenizer.from_pretrained(multilingual_checkpoint)\n","model = 
AutoModel.from_pretrained(multilingual_checkpoint)"],"execution_count":3,"outputs":[]},{"cell_type":"code","metadata":{"id":"wcdik3tQpkyi"},"source":["# GPU\n","device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n","model.to(device)\n","print(device)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"-YzAkemLsrC9"},"source":["# TPU\n","# unfortunately incompatible wheel package for pytorch-xla 1.10 version\n","#import torch_xla.core.xla_model as xm\n","#device = xm.xla_device()\n","#print(device)\n","#pip list | grep torch"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"dfeEQJOglxdw","executionInfo":{"status":"ok","timestamp":1637682096594,"user_tz":-60,"elapsed":362,"user":{"displayName":"Sukosd Endre","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GioD4JjUyxYNK5t_w13NsB1TIlQ1P_x6Xj99-re5w=s64","userId":"02963673169135048018"}}},"source":["#Mean Pooling - Take attention mask into account for correct averaging\n","def mean_pooling(model_output, attention_mask):\n"," token_embeddings = model_output[0] #First element of model_output contains all token embeddings\n"," input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()\n"," sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)\n"," sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)\n"," return sum_embeddings / sum_mask\n","\n","def calculateEmbeddings(sentences,tokenizer,model,device=\"cpu\"):\n"," tokenized_sentences = tokenizer(sentences, padding=True, truncation=True, max_length=128, return_tensors='pt')\n"," tokenized_sentences.to(device)\n"," with torch.no_grad():\n"," model_output = model(**tokenized_sentences)\n"," sentence_embeddings = mean_pooling(model_output, tokenized_sentences['attention_mask'])\n"," del tokenized_sentences\n"," torch.cuda.empty_cache()\n"," return sentence_embeddings\n","\n","def findTopKMostSimilar(query_embedding, embeddings, k):\n"," cosine_scores = util.pytorch_cos_sim(query_embedding, embeddings)\n"," cosine_scores_list = cosine_scores.squeeze().tolist()\n"," pairs = []\n"," for idx,score in enumerate(cosine_scores_list):\n"," pairs.append({'index': idx, 'score': score})\n"," pairs = sorted(pairs, key=lambda x: x['score'], reverse=True)\n"," return pairs[0:k]\n","\n","def saveToDisc(embeddings, output_filename):\n"," with open(output_filename, \"ab\") as f:\n"," pickle.dump(embeddings, f, protocol=pickle.HIGHEST_PROTOCOL)"],"execution_count":23,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"MddjkKfMCH81"},"source":["# Create sentence embeddings\n","\n","\n","* Load sentences from raw text file\n","* Precalculate in batches of 1000, to avoid running out of memory\n","* Save to disc/files incrementally, to be able to reuse later (in total 5 files of 100.000 embedding each)\n","\n"]},{"cell_type":"code","metadata":{"id":"yfOsCAVImIAl"},"source":["batch_size = 1000\n","\n","raw_text_file = '/content/drive/MyDrive/huggingface/shortened_abstracts_hu_2021_09_01.txt'\n","datetime_formatted = datetime.now().strftime('%Y-%m-%d_%H:%M:%S')\n","output_embeddings_file_batched = f'/content/drive/MyDrive/huggingface/embeddings_{batch_size}_batches_at_{datetime_formatted}.pkl'\n","output_embeddings_file = f'/content/drive/MyDrive/huggingface/embeddings_at_{datetime_formatted}.pkl'\n","\n","print(datetime.now())\n","concated_sentence_embeddings = None\n","all_sentences = []\n","line = 'init'\n","total_read = 0\n","total_read_limit = 500000\n","skip_index 
= 400000\n","with open(raw_text_file) as f:\n"," while line and total_read < total_read_limit:\n"," count = 0\n"," sentence_batch = []\n"," while line and count < batch_size:\n"," line = f.readline()\n"," sentence_batch.append(line)\n"," count += 1\n"," \n"," all_sentences.extend(sentence_batch)\n"," \n"," if total_read >= skip_index:\n"," sentence_embeddings = calculateEmbeddings(sentence_batch,tokenizer,model,device)\n"," if concated_sentence_embeddings == None:\n"," concated_sentence_embeddings = sentence_embeddings\n"," else:\n"," concated_sentence_embeddings = torch.cat([concated_sentence_embeddings, sentence_embeddings], dim=0)\n"," print(concated_sentence_embeddings.size())\n"," saveToDisc(sentence_embeddings,output_embeddings_file_batched)\n"," total_read += count\n","print(datetime.now())"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"1rGQc9GRCuNy"},"source":["# Test: Query embeddings"]},{"cell_type":"code","metadata":{"id":"FT7CwpM0Bwhi"},"source":["query_embedding = calculateEmbeddings(['Melyik a legnépesebb város a világon?'],tokenizer,model,device)\n","top_pairs = findTopKMostSimilar(query_embedding, concated_sentence_embeddings, 5)\n","\n","for pair in top_pairs:\n"," i = pair['index']\n"," score = pair['score']\n"," print(\"{} \\t\\t Score: {:.4f}\".format(all_sentences[skip_index+i], score))"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"6Hdu_5FiDYJr"},"source":["# Test: Load pre-calculated embeddings\n","\n","* Load embedding from files and stitch them together\n","* Save into one file\n"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"gkWt0Uj_Ddsp","executionInfo":{"status":"ok","timestamp":1637682006152,"user_tz":-60,"elapsed":1722,"user":{"displayName":"Sukosd Endre","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GioD4JjUyxYNK5t_w13NsB1TIlQ1P_x6Xj99-re5w=s64","userId":"02963673169135048018"}},"outputId":"1921456e-1fd6-4218-9ebb-cbe503f402b1"},"source":["def concatTensors(new_tensor, acc_tensor='None'):\n"," if acc_tensor == None:\n"," acc_tensor = new_tensor\n"," else:\n"," acc_tensor = torch.cat([acc_tensor, new_tensor], dim=0)\n"," return acc_tensor\n","\n","def loadFromDisc(batch_size, number_of_batches, filename):\n"," concated_sentence_embeddings = None\n"," count = 0\n"," batches = 0\n"," with open(filename, \"rb\") as f:\n"," loaded_embeddings = torch.empty([batch_size])\n"," while count < number_of_batches and loaded_embeddings.size()[0]==batch_size:\n"," loaded_embeddings = pickle.load(f)\n"," count += 1\n"," concated_sentence_embeddings = concatTensors(loaded_embeddings,concated_sentence_embeddings)\n"," print(f'Read file using {count} number of read+unpickle operations')\n"," print(concated_sentence_embeddings.size())\n"," return concated_sentence_embeddings\n","\n","\n","output_embeddings_file = 'data/processed/DBpedia_shortened_abstracts_hu_embeddings.pkl'\n","\n","embeddings_files = [\n"," '/content/drive/MyDrive/huggingface/embeddings_1000_batches_at_2021-11-23_08:17:17.pkl',\n"," '/content/drive/MyDrive/huggingface/embeddings_1000_batches_at_2021-11-23_08:28:46.pkl',\n"," '/content/drive/MyDrive/huggingface/embeddings_1000_batches_at_2021-11-23_08:40:54.pkl',\n"," '/content/drive/MyDrive/huggingface/embeddings_1000_batches_at_2021-11-23_08:56:26.pkl',\n"," '/content/drive/MyDrive/huggingface/embeddings_1000_batches_at_2021-11-23_09:31:47.pkl'\n","]\n","\n","all_embeddings = None\n","for idx,emb_file in enumerate(embeddings_files):\n"," 
print(f'Processing file {idx}')\n"," file_embeddings = loadFromDisc(1000, 100, emb_file)\n"," all_embeddings = concatTensors(file_embeddings,all_embeddings)\n","\n","print(all_embeddings.size())"],"execution_count":20,"outputs":[{"output_type":"stream","name":"stdout","text":["Processing file 0\n","Read file using 100 number of read+unpickle operations\n","torch.Size([100000, 384])\n","Processing file 1\n","Read file using 100 number of read+unpickle operations\n","torch.Size([100000, 384])\n","Processing file 2\n","Read file using 100 number of read+unpickle operations\n","torch.Size([100000, 384])\n","Processing file 3\n","Read file using 100 number of read+unpickle operations\n","torch.Size([100000, 384])\n","Processing file 4\n","Read file using 67 number of read+unpickle operations\n","torch.Size([66529, 384])\n","torch.Size([466529, 384])\n"]}]},{"cell_type":"code","metadata":{"id":"M_8RHpNnIU7o","executionInfo":{"status":"ok","timestamp":1637683739951,"user_tz":-60,"elapsed":384,"user":{"displayName":"Sukosd Endre","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GioD4JjUyxYNK5t_w13NsB1TIlQ1P_x6Xj99-re5w=s64","userId":"02963673169135048018"}}},"source":["all_embeddings_output_file = '/content/drive/MyDrive/huggingface/shortened_abstracts_hu_2021_09_01_embedded.pt'\n","#saveToDisc(all_embeddings, all_embeddings_output_file)\n","torch.save(all_embeddings,all_embeddings_output_file)"],"execution_count":28,"outputs":[]},{"cell_type":"code","metadata":{"id":"LYCwyDpMjsXg"},"source":[""],"execution_count":null,"outputs":[]}]}
notebooks/dbpedia_qa_test.ipynb ADDED
@@ -0,0 +1,288 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 18,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stderr",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "Using custom data configuration hu-faq-question-language=hu,scope=faq\n",
13
+ "Reusing dataset mqa (/Users/eend/.cache/huggingface/datasets/clips___mqa/hu-faq-question-language=hu,scope=faq/0.0.0/7eda4cdcbd6f009259fc516f204d776915a5f54ea2ad414c3dcddfaacd4dfe0b)\n",
14
+ "100%|██████████| 1/1 [00:00<00:00, 70.47it/s]\n",
15
+ "Using custom data configuration hu-cqa-question-language=hu,scope=cqa\n",
16
+ "Reusing dataset mqa (/Users/eend/.cache/huggingface/datasets/clips___mqa/hu-cqa-question-language=hu,scope=cqa/0.0.0/7eda4cdcbd6f009259fc516f204d776915a5f54ea2ad414c3dcddfaacd4dfe0b)\n",
17
+ "100%|██████████| 1/1 [00:00<00:00, 389.26it/s]\n",
18
+ "Downloading: 5.27kB [00:00, 2.07MB/s] \n",
19
+ "Downloading: 2.36kB [00:00, 1.39MB/s] \n"
20
+ ]
21
+ },
22
+ {
23
+ "name": "stdout",
24
+ "output_type": "stream",
25
+ "text": [
26
+ "Downloading and preparing dataset squad/plain_text (download: 33.51 MiB, generated: 85.63 MiB, post-processed: Unknown size, total: 119.14 MiB) to /Users/eend/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453...\n"
27
+ ]
28
+ },
29
+ {
30
+ "name": "stderr",
31
+ "output_type": "stream",
32
+ "text": [
33
+ "Downloading: 30.3MB [00:00, 78.5MB/s]\n",
34
+ "Downloading: 4.85MB [00:00, 63.4MB/s] \n",
35
+ "100%|██████████| 2/2 [00:01<00:00, 1.16it/s]\n",
36
+ "100%|██████████| 2/2 [00:00<00:00, 709.70it/s]\n"
37
+ ]
38
+ },
39
+ {
40
+ "name": "stdout",
41
+ "output_type": "stream",
42
+ "text": [
43
+ "Dataset squad downloaded and prepared to /Users/eend/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453. Subsequent calls will reuse this data.\n"
44
+ ]
45
+ },
46
+ {
47
+ "name": "stderr",
48
+ "output_type": "stream",
49
+ "text": [
50
+ "100%|██████████| 2/2 [00:00<00:00, 259.74it/s]\n"
51
+ ]
52
+ }
53
+ ],
54
+ "source": [
55
+ "from datasets import load_dataset\n",
56
+ "\n",
57
+ "faq_hu = load_dataset(\"clips/mqa\", scope=\"faq\", language=\"hu\")\n",
58
+ "cqa_hu = load_dataset(\"clips/mqa\", scope=\"cqa\", language=\"hu\")\n",
59
+ "squad = load_dataset(\"squad\")"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "code",
64
+ "execution_count": 19,
65
+ "metadata": {},
66
+ "outputs": [
67
+ {
68
+ "data": {
69
+ "text/plain": [
70
+ "{'id': ['5733be284776f41900661182', '5733be284776f4190066117f'],\n",
71
+ " 'title': ['University_of_Notre_Dame', 'University_of_Notre_Dame'],\n",
72
+ " 'context': ['Architecturally, the school has a Catholic character. Atop the Main Building\\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend \"Venite Ad Me Omnes\". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.',\n",
73
+ " 'Architecturally, the school has a Catholic character. Atop the Main Building\\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend \"Venite Ad Me Omnes\". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.'],\n",
74
+ " 'question': ['To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?',\n",
75
+ " 'What is in front of the Notre Dame Main Building?'],\n",
76
+ " 'answers': [{'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]},\n",
77
+ " {'text': ['a copper statue of Christ'], 'answer_start': [188]}]}"
78
+ ]
79
+ },
80
+ "execution_count": 19,
81
+ "metadata": {},
82
+ "output_type": "execute_result"
83
+ }
84
+ ],
85
+ "source": [
86
+ "squad['train'][:2]"
87
+ ]
88
+ },
89
+ {
90
+ "cell_type": "code",
91
+ "execution_count": 21,
92
+ "metadata": {},
93
+ "outputs": [
94
+ {
95
+ "data": {
96
+ "text/plain": [
97
+ "{'id': ['4ddf184a540032092a43461d4904ffc0',\n",
98
+ " '2d3fd2e40d3369e9e03acb43f8290d23'],\n",
99
+ " 'text': ['\\n**[JavaFX 1.0](http://www.javafx.com/)** adták csütörtök december 4. \\n\\n\\nMint a fejlesztő, mit gondol a JavaFX? A munkahelyen, van olyan tervei, hogy lépjenek előre JavaFX-alapú alkalmazások vagy weboldalak? Van rövid távú tervek tanulni JavaFX?\\n\\n ',\n",
100
+ " '\\nÉn portolása egy játék, amelyet eredetileg írt a Win32 API, Linux (jó, portolása az OS X port a Win32 port Linux).\\n\\n\\nAzt már végre `QueryPerformanceCounter`azzal, hogy a uSeconds mivel az eljárás elindításához: \\n\\n\\n\\n```\\nBOOL QueryPerformanceCounter(LARGE_INTEGER* performanceCount)\\n{\\n gettimeofday(&currentTimeVal, NULL);\\n performanceCount->QuadPart = (currentTimeVal.tv_sec - startTimeVal.tv_sec);\\n performanceCount->QuadPart *= (1000 * 1000);\\n performanceCount->QuadPart += (currentTimeVal.tv_usec - startTimeVal.tv_usec);\\n\\n return true;\\n}\\n\\n```\\n\\nEz, párosulva `QueryPerformanceFrequency()`így állandó 1000000 a frekvencia, jól működik **a gépemen** , hogy nekem egy 64 bites változót, amely `uSeconds`, mivel a program induló.\\n\\n\\nÍgy *van ez a hordozható?* Nem akarom, hogy felfedezzék azt másként működik, ha a kernel-ben összeállított egy bizonyos módon, vagy ilyesmi. Jól vagyok vele, hogy nem hordozható, hogy valami más, mint a Linux, de.\\n\\n '],\n",
101
+ " 'name': ['JavaFX 1.0 megjelent. Ön mit gondol?\\n====================================\\n\\n',\n",
102
+ " 'Van gettimeofday () garantáltan a us felbontás?\\n===============================================\\n\\n'],\n",
103
+ " 'domain': ['coredump.biz', 'coredump.biz'],\n",
104
+ " 'bucket': ['2020.05', '2020.10'],\n",
105
+ " 'answers': [[{'text': '\\nAmennyire én tudom JavaFX esik lapos rajta arcát. Átnéztem a demókat és példa forráskódot és nem vagyok lenyűgözve. JavaFX egy fárasztó csata, hogy a versenyt az Adobe és a Silverlight amelyeket már a vad egy ideig. Figyeljük meg, hogy én vagyok sokáig Java fejlesztő.\\n\\n ',\n",
106
+ " 'name': '',\n",
107
+ " 'is_accepted': False},\n",
108
+ " {'text': '\\nÉn biztosan gondolom, hogy érdemes egy pillantást, mint amilyennek látszik, mint a RIA itt maradni, és a több platformon / döntéseket, annál jobb. Sun biztos módja mögött, bár tekintve Micrsoft késett a játék Siliverlight és még előttünk álló út V Ha mást nem is, azt szeretném látni, hogy mit tett a Sun másként azok végrehajtását az Adobe és a Microsoft.\\n\\n ',\n",
109
+ " 'name': '',\n",
110
+ " 'is_accepted': False},\n",
111
+ " {'text': '\\nJavaFX az egyetlen nyílt RIA platform, így azt hiszem, hogy felzárkózzon a verseny előbb vagy utóbb.\\n\\n\\nÚgy néz ki, nagyon jó a 1.0 verzió. Demos jól dokumentáltak, és kínál mindent, amit kell.\\n\\n\\nVannak problémák természetesen. Java applet tűnik javult egy kicsit, de ez még mindig messze elmarad. Berakás hosszú ideig magas CPU terhelés. Ez nem mutat előrelépést, mint a szokásos flash alkalmazás tenni, így a felhasználó nem lehet biztos abban, hogy az ő internet lassú, applet vagy nagy java lassú. Azt is el kellett fogadnia bizonyítvány, még több, mint egy néhány demót.\\n\\n ',\n",
112
+ " 'name': '',\n",
113
+ " 'is_accepted': False},\n",
114
+ " {'text': '\\nSzeretem a koncepció JavaFX, de nem volt az esélye, hogy bármit vele. Én nem rendesen Internet alkalmazások, így azt kell menni az utamból, hogy próbálja ki a legújabb platformokon.\\n\\n ',\n",
115
+ " 'name': '',\n",
116
+ " 'is_accepted': False},\n",
117
+ " {'text': '\\nMég egy dolog, hogy megtanulják, hogy szeretnék tanulni, hogy nem volt ideje megtanulni.\\n\\n\\nÍgéretesnek tűnik, de egyetértek másokkal. Ez egy fárasztó csata, és ott van a kétség, hogy ez lesz jellemző a hosszú távon. Egy pozitív Java FX, hogy ez lesz meghosszabbítja a karrierem-beruházások a Java nyelvet.\\n\\n\\nAzt is minél több RIA platformok kialakulni a JVM - így míg a Java csökkenhetnek a JVM továbbra is.\\n\\n ',\n",
118
+ " 'name': '',\n",
119
+ " 'is_accepted': False}],\n",
120
+ " [{'text': '\\nTalán. De van nagyobb problémákat. `gettimeofday()`eredményezhet helytelen időzítés, ha vannak olyan folyamatok a rendszer, hogy a változás az időzítő (azaz ntpd). Egy „normális” linux, bár úgy vélem, a felbontás `gettimeofday()`is 10us. Meg lehet ugrani előre és hátra, és időt, következésképpen alapuló folyamatok fut a rendszer. Ez hatékonyan teszi a választ a kérdésre nincs.\\n\\n\\nMeg kell nézni `clock_gettime(CLOCK_MONOTONIC)`az időzítés időközönként. Ez azonban számos kisebb problémák miatt a dolgok, mint a többmagos rendszerek és külső órajel beállításokat.\\n\\n\\nIs, nézd át a `clock_getres()`funkciót.\\n\\n ',\n",
121
+ " 'name': '',\n",
122
+ " 'is_accepted': True},\n",
123
+ " {'text': '\\nAz én tapasztalataim és amit olvastam az interneten keresztül, a válasz „Nem”, akkor nem garantált. Attól függ, hogy a processzor sebességét, az operációs rendszer, ízét Linux, stb\\n\\n ',\n",
124
+ " 'name': '',\n",
125
+ " 'is_accepted': False},\n",
126
+ " {'text': '\\n\\n> \\n> A tényleges felbontása gettimeofday () függ a hardver architektúra. Intel processzorok, valamint a SPARC gépeket kínálnak nagyfelbontású időzítő, amelyek mérik ezredmásodperc. Egyéb hardverarchitektúrák esik vissza a rendszer időzítő, amely tipikusan beállítása 100 Hz. Ezekben az esetekben az idő felbontás kevésbé lesznek pontosak.\\n> \\n> \\n> \\n\\n\\nKaptam ezt a választ [High Resolution Időmérés és időzítők, I. rész](http://web.archive.org/web/20160711223333/http://www.informit.com/guides/content.aspx?g=cplusplus&seqNum=272)\\n\\n ',\n",
127
+ " 'name': '',\n",
128
+ " 'is_accepted': False},\n",
129
+ " {'text': '\\n**Nagy felbontású, alacsony rezsi időzítése Intel processzorok**\\n\\n\\nHa az Intel hardver, itt van, hogy olvassa el a CPU valós idejű használati számláló. Azt fogja mondani, a CPU-ciklusok számát óta végrehajtott processzort elindult. Ez talán a legfinomabb szemcséjű számláló kaphat a teljesítmény méréséhez.\\n\\n\\nMegjegyzendő, hogy ez a szám a CPU ciklusokat. A linux kaphat a processzor sebességét a / proc / cpuinfo és osztódnak, hogy a másodpercek száma. Alakítja át ezt a kettős elég praktikus.\\n\\n\\nAmikor futtatom ezt én doboz, kapok\\n\\n\\n\\n```\\n11867927879484732\\n11867927879692217\\nit took this long to call printf: 207485\\n\\n```\\n\\nItt a [Intel fejlesztői útmutatót](http://cs.smu.ca/~jamuir/rdtscpm1.pdf) ad tonna részletességgel.\\n\\n\\n\\n```\\n#include <stdio.h>\\n#include <stdint.h>\\n\\ninline uint64_t rdtsc() {\\n uint32_t lo, hi;\\n __asm__ __volatile__ (\\n \"xorl %%eax, %%eax\\\\n\"\\n \"cpuid\\\\n\"\\n \"rdtsc\\\\n\"\\n : \"=a\" (lo), \"=d\" (hi)\\n :\\n : \"%ebx\", \"%ecx\");\\n return (uint64_t)hi << 32 | lo;\\n}\\n\\nmain()\\n{\\n unsigned long long x;\\n unsigned long long y;\\n x = rdtsc();\\n printf(\"%lld\\\\n\",x);\\n y = rdtsc();\\n printf(\"%lld\\\\n\",y);\\n printf(\"it took this long to call printf: %lld\\\\n\",y-x);\\n}\\n\\n```\\n ',\n",
130
+ " 'name': '',\n",
131
+ " 'is_accepted': False},\n",
132
+ " {'text': '\\n\\n> \\n> Tehát azt mondja ezredmásodperc kifejezetten, de azt mondja, a felbontás a rendszer órája nincs megadva. Gondolom felbontás ebben az összefüggésben azt jelenti, hogy az a legkisebb összeg, hogy valaha is növekedhet?\\n> \\n> \\n> \\n\\n\\nAz adatstruktúra úgy definiáljuk, mint amelynek mikroszekundum, mint egy mértékegység, de ez nem jelenti azt, hogy az óra vagy az operációs rendszer valójában képes mérni, hogy finoman.\\n\\n\\nMint a többi ember azt, `gettimeofday()`rossz, mert az idő beállításával okozhat órajelelcsúszás és dobja ki a számításból. `clock_gettime(CLOCK_MONOTONIC)`az, amit akarsz, és `clock_getres()`megmondja, hogy a pontosság az óra.\\n\\n ',\n",
133
+ " 'name': '',\n",
134
+ " 'is_accepted': False},\n",
135
+ " {'text': '\\n@Bernard:\\n\\n\\n\\n> \\n> Be kell vallanom, a legtöbb példa egyenesen a fejem fölött. Ez nem fordul le, és úgy tűnik, működik, mégis. Biztonságos ez az SMP rendszerek vagy SpeedStep?\\n> \\n> \\n> \\n\\n\\nEz egy jó kérdés ... Azt hiszem, hogy a kód rendben van. Gyakorlati szempontból, tudjuk használni a cégem minden nap, és mi fut elég széles skáláját dobozok, minden 2-8 magot. Természetesen YMMV stb, de úgy tűnik, hogy egy megbízható és alacsony rezsi (mert nem teszi vál- tani rendszer-space) módszer az időzítés.\\n\\n\\nÁltalában hogyan működik:\\n\\n\\n* állapítsa meg a blokk kódot kell szerelő (és illékony, ezért az optimalizáló hagyják egyedül).\\n* végrehajtja a CPUID utasítást. Amellett, hogy egyre néhány CPU információk (amelyek nem teszünk semmit) szinkronizálja a CPU végrehajtási puffert úgy hogy az időzítést nem befolyásolja out-of-order végrehajtás.\\n* végrehajtja a rdtsc (értsd timestamp) végrehajtását. Ez letölti száma gépi ciklus óta végrehajtott processzor alaphelyzetbe állt. Ez egy 64 bites érték, így a jelenlegi CPU sebességet akkor körülveszi minden 194 év múlva. Érdekes, hogy az eredeti Pentium referencia megjegyzik, hogy körbe minden 5800 évben.\\n* Az elmúlt pár sor tárolja az értékeket a regiszterek a változók hi és lo, és tegye, hogy a 64 bites visszatérési értéke.\\n\\n\\nKülönös megjegyzések:\\n\\n\\n* out-of-order végrehajtás okozhat hibás eredményeket, így végre a „CPUID” utasítás, amely azon túlmenően, hogy egy kis információt a processzor is szinkronizálja bármely out-of-order utasítás végrehajtását.\\n* A legtöbb operációs rendszer szinkronizálja a számlálók a CPU mikor indul el, így a válasz jó, hogy egy pár nano-másodperc.\\n* A téli álmot alvó megjegyzés valószínűleg igaz, de a gyakorlatban valószínűleg nem törődnek időzítések között hibernáció határokat.\\n* kapcsolatos SpeedStep: Újabb Intel CPU kompenzálja a sebesség változik, és visszatér egy beállított száma. Tettem egy gyors át néhány doboz a hálózatunkon, és már csak egy doboz, amely nem volt meg: a Pentium 3 fut néhány régi adatbázis szerver. (Ezek linux dobozokat, így egyeztettem: grep constant\\\\_tsc / proc / cpuinfo)\\n* Nem vagyok biztos abban, hogy a AMD CPU vagyunk elsősorban Intel bolt, bár tudom, hogy néhány alacsony szintű rendszerek guruk volt egy AMD értékelést.\\n\\n\\nRemélem ez kielégíti a kíváncsiságát, ez egy érdekes és (IMHO) keretében tanulmányozott programozás területén. Tudod, amikor Jeff és Joel volt szó, hogy egy programozó kell tudni C? Azt kiabált nekik, hogy „hé elfelejteni, hogy a magas szintű C dolgok ... szerelő, amit meg kell tanulni, ha azt szeretné tudni, hogy mi a számítógép csinál!”\\n\\n ',\n",
136
+ " 'name': '',\n",
137
+ " 'is_accepted': False},\n",
138
+ " {'text': '\\nA bor valóban használja gettimeofday (), hogy végre QueryPerformanceCounter (), és köztudott, hogy sok Windows játékok dolgozni Linux és Mac.\\n\\n\\nElindítja <http://source.winehq.org/source/dlls/kernel32/cpu.c#L312>\\n\\n\\nvezet <http://source.winehq.org/source/dlls/ntdll/time.c#L448>\\n\\n ',\n",
139
+ " 'name': '',\n",
140
+ " 'is_accepted': False},\n",
141
+ " {'text': '\\nReading a RDTSC nem megbízható az SMP rendszerek, mivel minden egyes CPU fenntartja saját számlálót és minden ellen nem garantált, hogy a szinkronizált a másikhoz képest CPU.\\n\\n\\nLehet, hogy azt sugallják, próbál **`clock_gettime(CLOCK_REALTIME)`**. A POSIX utasítás azt jelzi, hogy ez végre kell hajtani minden kompatibilis rendszereket. Ez olyan ns száma, de valószínűleg ellenőrizni fogja majd **`clock_getres(CLOCK_REALTIME)`**a rendszer, hogy mi a tényleges felbontás.\\n\\n ',\n",
142
+ " 'name': '',\n",
143
+ " 'is_accepted': False},\n",
144
+ " {'text': '\\nLehet, hogy érdekli a [Linux GYIK-`clock_gettime(CLOCK_REALTIME)`](http://juliusdavies.ca/posix_clocks/clock_realtime_linux_faq.html)\\n\\n ',\n",
145
+ " 'name': '',\n",
146
+ " 'is_accepted': False},\n",
147
+ " {'text': '\\n[Ez a válasz](https://stackoverflow.com/a/98/) említi problémák az óra beállítása közben. Mindkét problémákra garantálja kullancs egységek és a problémák az idő beállítása is megoldódnak C ++ 11 a `<chrono>`könyvtárban.\\n\\n\\nAz óra `std::chrono::steady_clock`garantáltan nem kell korrigálni, továbbá előre lép állandó sebességgel képest valós időben, így technológiák, mint a SpeedStep nem befolyásolja azt.\\n\\n\\nTudod kap typesafe egységek átalakításával az egyik `std::chrono::duration`szakterületek, például `std::chrono::microseconds`. Az ilyen típusú, nincs kétség, az egységek által használt kullancs értéket. Ugyanakkor szem előtt tartani, hogy az óra nem feltétlenül ezt az állásfoglalást. Ön tudja alakítani egy időtartamot attoseconds nélkül, hogy ténylegesen egy órát, hogy pontos.\\n\\n ',\n",
148
+ " 'name': '',\n",
149
+ " 'is_accepted': False}]]}"
150
+ ]
151
+ },
152
+ "execution_count": 21,
153
+ "metadata": {},
154
+ "output_type": "execute_result"
155
+ }
156
+ ],
157
+ "source": [
158
+ "cqa_hu['train'][100:102]"
159
+ ]
160
+ },
161
+ {
162
+ "cell_type": "code",
163
+ "execution_count": 22,
164
+ "metadata": {},
165
+ "outputs": [
166
+ {
167
+ "data": {
168
+ "text/plain": [
169
+ "{'id': ['bbfce58894a1bb9140659cfbfe334fb6',\n",
170
+ " '3f98c644c947963c7990c047661fcfc5'],\n",
171
+ " 'text': ['', ''],\n",
172
+ " 'name': ['a rendszerpartíciót újratelepítés nélkül nagyobbá tehetem windows és programok?',\n",
173
+ " 'van-e ingyenes eszköz a c meghajtó nagyobbá tételéhez?'],\n",
174
+ " 'domain': ['hdd-tool.com', 'hdd-tool.com'],\n",
175
+ " 'bucket': ['2020.40', '2020.40'],\n",
176
+ " 'answers': [[{'text': 'igen, ez a cikk háromféle módszert mutat be 3féle eszköz segítségével e feladat elvégzéséhez.',\n",
177
+ " 'name': '',\n",
178
+ " 'is_accepted': True}],\n",
179
+ " [{'text': 'igen, niubi partition editor ingyenes kiadása van a windows 10/8/7/vista/xp otthoni felhasználók.',\n",
180
+ " 'name': '',\n",
181
+ " 'is_accepted': True}]]}"
182
+ ]
183
+ },
184
+ "execution_count": 22,
185
+ "metadata": {},
186
+ "output_type": "execute_result"
187
+ }
188
+ ],
189
+ "source": [
190
+ "faq_hu['train'][100:102]"
191
+ ]
192
+ },
193
+ {
194
+ "cell_type": "code",
195
+ "execution_count": 5,
196
+ "metadata": {},
197
+ "outputs": [
198
+ {
199
+ "data": {
200
+ "text/plain": [
201
+ "171"
202
+ ]
203
+ },
204
+ "execution_count": 5,
205
+ "metadata": {},
206
+ "output_type": "execute_result"
207
+ }
208
+ ],
209
+ "source": [
210
+ "a = set(cqa_hu['train']['domain'])\n",
211
+ "len(a)"
212
+ ]
213
+ },
214
+ {
215
+ "cell_type": "code",
216
+ "execution_count": 16,
217
+ "metadata": {},
218
+ "outputs": [
219
+ {
220
+ "data": {
221
+ "text/plain": [
222
+ "[('hotels.com', 321784),\n",
223
+ " ('travelminit.hu', 105216),\n",
224
+ " ('tripadvisor.co.hu', 84606),\n",
225
+ " ('travelminit.ro', 50327),\n",
226
+ " ('booking.com', 20315),\n",
227
+ " ('aszinonimaszotar.hu', 18896),\n",
228
+ " ('skyscanner.hu', 16717),\n",
229
+ " ('szallasvadasz.hu', 13759),\n",
230
+ " ('esky.hu', 12513),\n",
231
+ " ('travelminit.com', 12455),\n",
232
+ " ('pitchup.com', 9906),\n",
233
+ " ('kiwi.com', 9452),\n",
234
+ " ('languagecourse.net', 8284),\n",
235
+ " ('ekuponok.com', 7385),\n",
236
+ " ('rentalcargroup.com', 6980),\n",
237
+ " ('solvusoft.com', 6807),\n",
238
+ " ('flatio.hu', 5650),\n",
239
+ " ('haziallat.hu', 4255),\n",
240
+ " ('miapanasz.hu', 3814),\n",
241
+ " ('liveagent.hu', 3632)]"
242
+ ]
243
+ },
244
+ "execution_count": 16,
245
+ "metadata": {},
246
+ "output_type": "execute_result"
247
+ }
248
+ ],
249
+ "source": [
250
+ "from collections import Counter\n",
251
+ "\n",
252
+ "faq_domains = Counter(faq_hu['train']['domain'])\n",
253
+ "faq_domains.most_common(20)"
254
+ ]
255
+ },
256
+ {
257
+ "cell_type": "code",
258
+ "execution_count": null,
259
+ "metadata": {},
260
+ "outputs": [],
261
+ "source": []
262
+ }
263
+ ],
264
+ "metadata": {
265
+ "interpreter": {
266
+ "hash": "02e357c7440d8ed11be29edfeecade50b9c6cce68ea0a63234d5a765afff05f4"
267
+ },
268
+ "kernelspec": {
269
+ "display_name": "Python 3.9.6 64-bit ('hf_venv': venv)",
270
+ "name": "python3"
271
+ },
272
+ "language_info": {
273
+ "codemirror_mode": {
274
+ "name": "ipython",
275
+ "version": 3
276
+ },
277
+ "file_extension": ".py",
278
+ "mimetype": "text/x-python",
279
+ "name": "python",
280
+ "nbconvert_exporter": "python",
281
+ "pygments_lexer": "ipython3",
282
+ "version": "3.9.6"
283
+ },
284
+ "orig_nbformat": 4
285
+ },
286
+ "nbformat": 4,
287
+ "nbformat_minor": 2
288
+ }
notebooks/mqa_test.ipynb ADDED
@@ -0,0 +1,409 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stderr",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "Using custom data configuration hu-faq-question-language=hu,scope=faq\n",
13
+ "Reusing dataset mqa (/Users/eend/.cache/huggingface/datasets/clips___mqa/hu-faq-question-language=hu,scope=faq/0.0.0/7eda4cdcbd6f009259fc516f204d776915a5f54ea2ad414c3dcddfaacd4dfe0b)\n",
14
+ "100%|██████████| 1/1 [00:00<00:00, 19.53it/s]\n"
15
+ ]
16
+ }
17
+ ],
18
+ "source": [
19
+ "from datasets import load_dataset\n",
20
+ "\n",
21
+ "faq_hu = load_dataset(\"clips/mqa\", scope=\"faq\", language=\"hu\")"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": 8,
27
+ "metadata": {},
28
+ "outputs": [
29
+ {
30
+ "data": {
31
+ "text/plain": [
32
+ "{'id': 'a44ad85683f3d8afd1ffa42ce55fefcd',\n",
33
+ " 'text': '',\n",
34
+ " 'name': 'szingapúr területén mely kisállatbarát hotelek ideálisak a családok számára?',\n",
35
+ " 'domain': 'tripadvisor.co.hu',\n",
36
+ " 'bucket': '2020.29',\n",
37
+ " 'answers': [{'text': 'a(z) szingapúr területén nyaraló családok tapasztalatai szerint ezek igazán jó kisállatbarát hotelek: \\n**[intercontinental singapore](https://www.tripadvisor.co.hu/hotel_review-g294265-d299199-reviews-intercontinental_singapore-singapore.html?faqtqr=5&faqts=hotels&faqtt=214&faqtup=geo%3a294265%3bzfa%3a9&m=63287)** utazói osztályozás: 4.5/5 \\n**[fraser suites singapore](https://www.tripadvisor.co.hu/hotel_review-g294265-d306172-reviews-fraser_suites_singapore-singapore.html?faqtqr=5&faqts=hotels&faqtt=214&faqtup=geo%3a294265%3bzfa%3a9&m=63287)** utazói osztályozás: 4.5/5 \\n**[holiday inn express singapore katong](https://www.tripadvisor.co.hu/hotel_review-g294265-d8777586-reviews-holiday_inn_express_singapore_katong-singapore.html?faqtqr=5&faqts=hotels&faqtt=214&faqtup=geo%3a294265%3bzfa%3a9&m=63287)** utazói osztályozás: 4.0/5',\n",
38
+ " 'name': '',\n",
39
+ " 'is_accepted': True}]}"
40
+ ]
41
+ },
42
+ "execution_count": 8,
43
+ "metadata": {},
44
+ "output_type": "execute_result"
45
+ }
46
+ ],
47
+ "source": [
48
+ "faq_hu['train'][810000]"
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "code",
53
+ "execution_count": 16,
54
+ "metadata": {},
55
+ "outputs": [
56
+ {
57
+ "data": {
58
+ "text/plain": [
59
+ "tensor([[ 1, 2, 2, 3, 4],\n",
60
+ " [ 2, 3, 4, 5, 7],\n",
61
+ " [ 2, 4, 4, 6, 8],\n",
62
+ " [ 4, 6, 8, 10, 14]])"
63
+ ]
64
+ },
65
+ "execution_count": 16,
66
+ "metadata": {},
67
+ "output_type": "execute_result"
68
+ }
69
+ ],
70
+ "source": [
71
+ "import torch\n",
72
+ "\n",
73
+ "a = torch.tensor([[1,2,2,3,4],[2,3,4,5,7]])\n",
74
+ "b = a * 2\n",
75
+ "\n",
76
+ "tensor_list = []\n",
77
+ "tensor_list.append(a)\n",
78
+ "tensor_list.append(b)\n",
79
+ "torch.cat((a,b),dim=0)"
80
+ ]
81
+ },
82
+ {
83
+ "cell_type": "code",
84
+ "execution_count": 17,
85
+ "metadata": {},
86
+ "outputs": [
87
+ {
88
+ "data": {
89
+ "text/plain": [
90
+ "5"
91
+ ]
92
+ },
93
+ "execution_count": 17,
94
+ "metadata": {},
95
+ "output_type": "execute_result"
96
+ }
97
+ ],
98
+ "source": [
99
+ "a.size()[1]"
100
+ ]
101
+ },
102
+ {
103
+ "cell_type": "code",
104
+ "execution_count": 18,
105
+ "metadata": {},
106
+ "outputs": [
107
+ {
108
+ "data": {
109
+ "text/plain": [
110
+ "tensor([[1, 2, 2, 3, 4],\n",
111
+ " [2, 3, 4, 5, 7]])"
112
+ ]
113
+ },
114
+ "execution_count": 18,
115
+ "metadata": {},
116
+ "output_type": "execute_result"
117
+ }
118
+ ],
119
+ "source": [
120
+ "a[:2]"
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "code",
125
+ "execution_count": 24,
126
+ "metadata": {},
127
+ "outputs": [
128
+ {
129
+ "data": {
130
+ "text/plain": [
131
+ "[[1, 2, 2, 3, 4], [2, 3, 4, 5, 7]]"
132
+ ]
133
+ },
134
+ "execution_count": 24,
135
+ "metadata": {},
136
+ "output_type": "execute_result"
137
+ }
138
+ ],
139
+ "source": [
140
+ "a.tolist()"
141
+ ]
142
+ },
143
+ {
144
+ "cell_type": "code",
145
+ "execution_count": 25,
146
+ "metadata": {},
147
+ "outputs": [],
148
+ "source": [
149
+ "c = torch.empty([1,5])"
150
+ ]
151
+ },
152
+ {
153
+ "cell_type": "code",
154
+ "execution_count": 26,
155
+ "metadata": {},
156
+ "outputs": [
157
+ {
158
+ "data": {
159
+ "text/plain": [
160
+ "tensor([[1.4569e-19, 1.0658e-32, 1.1258e+24, 1.5789e-19, 1.1819e+22]])"
161
+ ]
162
+ },
163
+ "execution_count": 26,
164
+ "metadata": {},
165
+ "output_type": "execute_result"
166
+ }
167
+ ],
168
+ "source": [
169
+ "c"
170
+ ]
171
+ },
172
+ {
173
+ "cell_type": "code",
174
+ "execution_count": 28,
175
+ "metadata": {},
176
+ "outputs": [
177
+ {
178
+ "data": {
179
+ "text/plain": [
180
+ "[1.4568973155122501e-19,\n",
181
+ " 1.0658291767562146e-32,\n",
182
+ " 1.1257918204515671e+24,\n",
183
+ " 1.5789373458898217e-19,\n",
184
+ " 1.1818655764620037e+22]"
185
+ ]
186
+ },
187
+ "execution_count": 28,
188
+ "metadata": {},
189
+ "output_type": "execute_result"
190
+ }
191
+ ],
192
+ "source": [
193
+ "c.squeeze().tolist()"
194
+ ]
195
+ },
196
+ {
197
+ "cell_type": "code",
198
+ "execution_count": 32,
199
+ "metadata": {},
200
+ "outputs": [
201
+ {
202
+ "name": "stdout",
203
+ "output_type": "stream",
204
+ "text": [
205
+ "None\n"
206
+ ]
207
+ }
208
+ ],
209
+ "source": [
210
+ "a = [1,2,3]\n",
211
+ "b= [2,4,5]\n",
212
+ "print(a.extend(b))"
213
+ ]
214
+ },
215
+ {
216
+ "cell_type": "code",
217
+ "execution_count": 33,
218
+ "metadata": {},
219
+ "outputs": [
220
+ {
221
+ "data": {
222
+ "text/plain": [
223
+ "[1, 2, 3, 2, 4, 5]"
224
+ ]
225
+ },
226
+ "execution_count": 33,
227
+ "metadata": {},
228
+ "output_type": "execute_result"
229
+ }
230
+ ],
231
+ "source": [
232
+ "\n"
233
+ ]
234
+ },
235
+ {
236
+ "cell_type": "code",
237
+ "execution_count": 41,
238
+ "metadata": {},
239
+ "outputs": [
240
+ {
241
+ "name": "stdout",
242
+ "output_type": "stream",
243
+ "text": [
244
+ "\n",
245
+ "\n",
246
+ "Types of Question Answering\n",
247
+ "\n",
248
+ " - extractive question answering (encoder only models BERT)\n",
249
+ "\n",
250
+ " - posing questions about a document and identifying the answers as spans of text in the document itself\n",
251
+ "\n",
252
+ " - generative question answering (encoder-decoder T5/BART)\n",
253
+ "\n",
254
+ " - open ended questions, which need to synthesize information\n",
255
+ "\n",
256
+ " - retrieval based/community question answering \n",
257
+ "\n",
258
+ "\n",
259
+ "\n",
260
+ "\n",
261
+ "\n",
262
+ "\n",
263
+ "\n",
264
+ "First approach - translate dataset, fine-tune model\n",
265
+ "\n",
266
+ "!Not really feasible, because it needs lots of human evaluation for correctly determine answer start token\n",
267
+ "\n",
268
+ "\n",
269
+ "\n",
270
+ " 1. Translate English QA dataset into Hungarian\n",
271
+ "\n",
272
+ " - SQuAD - reading comprehension based on Wikipedia articles\n",
273
+ "\n",
274
+ " - ~ 100.000 question/answers\n",
275
+ "\n",
276
+ " 2. Fine-tune a model and evaluate on this dataset\n",
277
+ "\n",
278
+ "\n",
279
+ "\n",
280
+ "\n",
281
+ "\n",
282
+ "Second approach - fine-tune multilingual model\n",
283
+ "\n",
284
+ "!MQA format different than SQuAD, cannot use ModelForQuestionAnswering\n",
285
+ "\n",
286
+ "\n",
287
+ "\n",
288
+ " 1. Use a Hungarian dataset\n",
289
+ "\n",
290
+ " - MQA - multilingual parsed from Common Crawl\n",
291
+ "\n",
292
+ " - FAQ - 878.385 (2.415 domain)\n",
293
+ "\n",
294
+ " - CQA - 27.639 (171 domain)\n",
295
+ "\n",
296
+ " 2. Fine-tune and evaluate a model on this dataset\n",
297
+ "\n",
298
+ " \n",
299
+ "\n",
300
+ " \n",
301
+ "\n",
302
+ " Possible steps:\n",
303
+ "\n",
304
+ " - Use an existing pre-trained model in Hungarian/Romanian/or multilingual to generate embeddings\n",
305
+ "\n",
306
+ " - Select Model:\n",
307
+ "\n",
308
+ " - multilingual which includes hu:\n",
309
+ "\n",
310
+ " - distiluse-base-multilingual-cased-v2 (400MB)\n",
311
+ "\n",
312
+ " - paraphrase-multilingual-MiniLM-L12-v2 (400MB) - fastest\n",
313
+ "\n",
314
+ " - paraphrase-multilingual-mpnet-base-v2 (900MB) - best performing\n",
315
+ "\n",
316
+ " - hubert\n",
317
+ "\n",
318
+ " - Select a dataset\n",
319
+ "\n",
320
+ " - use MQA hungarian subset\n",
321
+ "\n",
322
+ " - use hungarian wikipedia pages data, split it up\n",
323
+ "\n",
324
+ " - DBpedia, shortened abstracts = 500.000\n",
325
+ "\n",
326
+ " - Pre-compute embeddings for all answers/paragraphs\n",
327
+ "\n",
328
+ " - Compute embedding for incoming query\n",
329
+ "\n",
330
+ " - Compare similarity between query embedding and precomputed \n",
331
+ "\n",
332
+ " - return top-3 answers/questions\n",
333
+ "\n",
334
+ " \n",
335
+ "\n",
336
+ " Alternative steps:\n",
337
+ "\n",
338
+ " - train a sentence transformer on the Hungarian / Romanian subsets\n",
339
+ "\n",
340
+ " - Use the trained sentence transformer to generate embeddings\n",
341
+ "\n"
342
+ ]
343
+ }
344
+ ],
345
+ "source": [
346
+ "with open('../approach.txt','r') as f:\n",
347
+ " line = 'init'\n",
348
+ " while line != '':\n",
349
+ " line=f.readline();\n",
350
+ " print(line)"
351
+ ]
352
+ },
353
+ {
354
+ "cell_type": "code",
355
+ "execution_count": 42,
356
+ "metadata": {},
357
+ "outputs": [
358
+ {
359
+ "data": {
360
+ "text/plain": [
361
+ "tensor([1.4013e-45, 0.0000e+00, 2.8026e-45, 0.0000e+00, 2.8026e-45, 0.0000e+00,\n",
362
+ " 4.2039e-45, 0.0000e+00, 5.6052e-45, 0.0000e+00, 2.8026e-45, 0.0000e+00,\n",
363
+ " 4.2039e-45, 0.0000e+00, 5.6052e-45, 0.0000e+00, 7.0065e-45, 0.0000e+00,\n",
364
+ " 9.8091e-45, 0.0000e+00])"
365
+ ]
366
+ },
367
+ "execution_count": 42,
368
+ "metadata": {},
369
+ "output_type": "execute_result"
370
+ }
371
+ ],
372
+ "source": [
373
+ "d = torch.empty([20])\n",
374
+ "d"
375
+ ]
376
+ },
377
+ {
378
+ "cell_type": "code",
379
+ "execution_count": null,
380
+ "metadata": {},
381
+ "outputs": [],
382
+ "source": []
383
+ }
384
+ ],
385
+ "metadata": {
386
+ "interpreter": {
387
+ "hash": "02e357c7440d8ed11be29edfeecade50b9c6cce68ea0a63234d5a765afff05f4"
388
+ },
389
+ "kernelspec": {
390
+ "display_name": "Python 3.9.6 64-bit ('hf_venv': venv)",
391
+ "name": "python3"
392
+ },
393
+ "language_info": {
394
+ "codemirror_mode": {
395
+ "name": "ipython",
396
+ "version": 3
397
+ },
398
+ "file_extension": ".py",
399
+ "mimetype": "text/x-python",
400
+ "name": "python",
401
+ "nbconvert_exporter": "python",
402
+ "pygments_lexer": "ipython3",
403
+ "version": "3.9.6"
404
+ },
405
+ "orig_nbformat": 4
406
+ },
407
+ "nbformat": 4,
408
+ "nbformat_minor": 2
409
+ }
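
The tensor scratch-work in the cells above rehearses the batching pattern used later in the embedding pipeline: embeddings are computed batch by batch and stitched together with torch.cat along dim=0. A minimal self-contained sketch of that accumulation pattern (the shapes are illustrative only; 384 is assumed here as the MiniLM embedding width):

import torch

def accumulate(batches):
    # Concatenate a stream of (batch_size, dim) tensors into one (N, dim) tensor.
    total = None
    for batch in batches:
        total = batch if total is None else torch.cat((total, batch), dim=0)
    return total

fake_batches = [torch.randn(5, 384) for _ in range(3)]  # illustrative data, not the real embeddings
print(accumulate(fake_batches).size())  # torch.Size([15, 384])

Collecting the batches in a list and calling torch.cat once at the end avoids the repeated reallocation this incremental version pays for.
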
requirements.txt ADDED
@@ -0,0 +1,3 @@
1
+ transformers
2
+ torch
3
+ sentence_transformers
requirements_full.txt ADDED
@@ -0,0 +1,133 @@
1
+ aiohttp==3.8.1
2
+ aiosignal==1.2.0
3
+ altair==4.1.0
4
+ appnope==0.1.2
5
+ argon2-cffi==21.1.0
6
+ arrow==1.2.1
7
+ astor==0.8.1
8
+ async-timeout==4.0.1
9
+ attrs==21.2.0
10
+ backcall==0.2.0
11
+ base58==2.1.1
12
+ binaryornot==0.4.4
13
+ bleach==4.1.0
14
+ blinker==1.4
15
+ cachetools==4.2.4
16
+ certifi==2021.10.8
17
+ cffi==1.15.0
18
+ chardet==4.0.0
19
+ charset-normalizer==2.0.7
20
+ click==7.1.2
21
+ cycler==0.11.0
22
+ datasets==1.15.1
23
+ debugpy==1.5.1
24
+ decorator==5.1.0
25
+ defusedxml==0.7.1
26
+ dill==0.3.4
27
+ entrypoints==0.3
28
+ filelock==3.3.2
29
+ fonttools==4.28.1
30
+ frozenlist==1.2.0
31
+ fsspec==2021.11.0
32
+ gitdb==4.0.9
33
+ GitPython==3.1.24
34
+ huggingface-hub==0.1.2
35
+ idna==3.3
36
+ ipykernel==6.5.0
37
+ ipython==7.29.0
38
+ ipython-genutils==0.2.0
39
+ ipywidgets==7.6.5
40
+ isodate==0.6.0
41
+ jedi==0.18.1
42
+ Jinja2==3.0.3
43
+ jinja2-time==0.2.0
44
+ joblib==1.1.0
45
+ jsonschema==4.2.1
46
+ jupyter==1.0.0
47
+ jupyter-client==7.0.6
48
+ jupyter-console==6.4.0
49
+ jupyter-core==4.9.1
50
+ jupyterlab-pygments==0.1.2
51
+ jupyterlab-widgets==1.0.2
52
+ kiwisolver==1.3.2
53
+ MarkupSafe==2.0.1
54
+ matplotlib==3.5.0
55
+ matplotlib-inline==0.1.3
56
+ mistune==0.8.4
57
+ multidict==5.2.0
58
+ multiprocess==0.70.12.2
59
+ nbclient==0.5.8
60
+ nbconvert==6.3.0
61
+ nbformat==5.1.3
62
+ nest-asyncio==1.5.1
63
+ nltk==3.6.5
64
+ notebook==6.4.6
65
+ numpy==1.21.4
66
+ packaging==21.2
67
+ pandas==1.3.4
68
+ pandocfilters==1.5.0
69
+ parso==0.8.2
70
+ pexpect==4.8.0
71
+ pickleshare==0.7.5
72
+ Pillow==8.4.0
73
+ plotly==5.4.0
74
+ poyo==0.5.0
75
+ prometheus-client==0.12.0
76
+ prompt-toolkit==3.0.22
77
+ protobuf==3.19.1
78
+ ptyprocess==0.7.0
79
+ pyarrow==6.0.0
80
+ pycparser==2.21
81
+ pydeck==0.7.1
82
+ Pygments==2.10.0
83
+ Pympler==0.9
84
+ pyparsing==2.4.7
85
+ pyrsistent==0.18.0
86
+ python-dateutil==2.8.2
87
+ python-slugify==5.0.2
88
+ pytz==2021.3
89
+ pytz-deprecation-shim==0.1.0.post0
90
+ PyYAML==6.0
91
+ pyzmq==22.3.0
92
+ qtconsole==5.2.0
93
+ QtPy==1.11.2
94
+ rdflib==6.0.2
95
+ regex==2021.11.10
96
+ requests==2.26.0
97
+ sacremoses==0.0.46
98
+ scikit-learn==1.0.1
99
+ scipy==1.7.2
100
+ seaborn==0.11.2
101
+ Send2Trash==1.8.0
102
+ sentence-transformers==2.1.0
103
+ sentencepiece==0.1.96
104
+ setuptools-scm==6.3.2
105
+ six==1.16.0
106
+ smmap==5.0.0
107
+ streamlit==1.2.0
108
+ tenacity==8.0.1
109
+ terminado==0.12.1
110
+ testpath==0.5.0
111
+ text-unidecode==1.3
112
+ threadpoolctl==3.0.0
113
+ tokenizers==0.10.3
114
+ toml==0.10.2
115
+ tomli==1.2.2
116
+ toolz==0.11.2
117
+ torch==1.10.0
118
+ torchaudio==0.10.0
119
+ torchvision==0.11.1
120
+ tornado==6.1
121
+ tqdm==4.62.3
122
+ traitlets==5.1.1
123
+ transformers==4.12.3
124
+ typing-extensions==3.10.0.2
125
+ tzdata==2021.5
126
+ tzlocal==4.1
127
+ urllib3==1.26.7
128
+ validators==0.18.2
129
+ wcwidth==0.2.5
130
+ webencodings==0.5.1
131
+ widgetsnbextension==3.5.2
132
+ xxhash==2.0.2
133
+ yarl==1.7.2
src/app.py ADDED
@@ -0,0 +1,60 @@
1
+ import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModel
3
+ import torch
4
+ from sentence_transformers import util
5
+
6
+ @st.cache
7
+ def load_raw_sentences(filename):
8
+ with open(filename) as f:
9
+ return f.readlines()
10
+
11
+ @st.cache
12
+ def load_embeddings(filename):
13
+     # torch.load takes the path directly; the extra open() was unnecessary
14
+     return torch.load(filename, map_location=torch.device('cpu'))
15
+
16
+ #Mean Pooling - Take attention mask into account for correct averaging
17
+ def mean_pooling(model_output, attention_mask):
18
+ token_embeddings = model_output[0] #First element of model_output contains all token embeddings
19
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
20
+ sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
21
+ sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
22
+ return sum_embeddings / sum_mask
23
+
24
+ def findTopKMostSimilar(query_embedding, embeddings, all_sentences, k):
25
+ cosine_scores = util.pytorch_cos_sim(query_embedding, embeddings)
26
+ cosine_scores_list = cosine_scores.squeeze().tolist()
27
+ pairs = []
28
+ for idx,score in enumerate(cosine_scores_list):
29
+ pairs.append({'index': idx, 'score': score, 'text': all_sentences[idx]})
30
+ pairs = sorted(pairs, key=lambda x: x['score'], reverse=True)
31
+ return pairs[0:k]
32
+
33
+ def calculateEmbeddings(sentences,tokenizer,model):
34
+ tokenized_sentences = tokenizer(sentences, padding=True, truncation=True, max_length=128, return_tensors='pt')
35
+ with torch.no_grad():
36
+ model_output = model(**tokenized_sentences)
37
+ sentence_embeddings = mean_pooling(model_output, tokenized_sentences['attention_mask'])
38
+ return sentence_embeddings
39
+
40
+
41
+ multilingual_checkpoint = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
42
+ tokenizer = AutoTokenizer.from_pretrained(multilingual_checkpoint)
43
+ model = AutoModel.from_pretrained(multilingual_checkpoint)
44
+
45
+ raw_text_file = 'data/processed/shortened_abstracts_hu_2021_09_01.txt'
46
+ all_sentences = load_raw_sentences(raw_text_file)
47
+
48
+ embeddings_file = 'data/processed/shortened_abstracts_hu_2021_09_01_embedded.pt'
49
+ all_embeddings = load_embeddings(embeddings_file)
50
+
51
+
52
+ st.text('Search Wikipedia abstracts in Hungarian - Input some search term and see the top-5 most similar wikipedia abstracts')
53
+ st.text('Wikipedia absztrakt kereső - adjon meg egy tetszőleges kifejezést és a rendszer visszaadja az 5 hozzá legjobban hasonlító Wikipedia absztraktot')
54
+
55
+ input_query = st.text_area("Hol élnek a bengali tigrisek?")  # "Where do Bengal tigers live?"
56
+
57
+ if input_query:
58
+ query_embedding = calculateEmbeddings([input_query],tokenizer,model)
59
+ top_pairs = findTopKMostSimilar(query_embedding, all_embeddings, all_sentences, 5)
60
+ st.json(top_pairs)
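
The app is launched with `streamlit run src/app.py` from the repository root, so the relative data paths resolve. One caveat, offered as a suggestion rather than something shown in this commit: st.cache re-hashes cached return values on every rerun, which can be slow for a large embedding tensor; allow_output_mutation=True skips that check. A sketch of the cached loader with the flag:

import streamlit as st
import torch

@st.cache(allow_output_mutation=True)  # skip re-hashing the large tensor on each rerun
def load_embeddings(filename):
    return torch.load(filename, map_location=torch.device('cpu'))
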
src/data/dbpedia_dump_embeddings.py ADDED
@@ -0,0 +1,56 @@
1
+ from transformers import AutoTokenizer, AutoModel
2
+ from datetime import datetime
3
+ import torch
4
+ import pickle
5
+
6
+ #Mean Pooling - Take attention mask into account for correct averaging
7
+ def mean_pooling(model_output, attention_mask):
8
+ token_embeddings = model_output[0] #First element of model_output contains all token embeddings
9
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
10
+ sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
11
+ sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
12
+ return sum_embeddings / sum_mask
13
+
14
+ def calculateEmbeddings(sentences,tokenizer,model):
15
+ tokenized_sentences = tokenizer(sentences, padding=True, truncation=True, max_length=128, return_tensors='pt')
16
+ with torch.no_grad():
17
+ model_output = model(**tokenized_sentences)
18
+ sentence_embeddings = mean_pooling(model_output, tokenized_sentences['attention_mask'])
19
+ return sentence_embeddings
20
+
21
+
22
+ def saveEmbeddingsToDisc(embeddings, filename):  # renamed: was shadowed by saveToDisc(sentences, ...) below
23
+ with open(filename, "ab") as f:
24
+ pickle.dump(embeddings, f, protocol=pickle.HIGHEST_PROTOCOL)
25
+
26
+ def saveToDisc(sentences, embeddings, filename):
27
+ with open(filename, "ab") as f:
28
+ pickle.dump({'sentences': sentences, 'embeddings': embeddings}, f, protocol=pickle.HIGHEST_PROTOCOL)
29
+
30
+ dt = datetime.now()
31
+ datetime_formatted = dt.strftime('%Y-%m-%d_%H:%M:%S')
32
+ batch_size = 1000
33
+
34
+ input_text_file = 'data/processed/shortened_abstracts_hu_2021_09_01.txt'
35
+ output_embeddings_file = f'data/processed/embeddings_{batch_size}_batches_at_{datetime_formatted}.pkl'
36
+
37
+ multilingual_checkpoint = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
38
+ tokenizer = AutoTokenizer.from_pretrained(multilingual_checkpoint)
39
+ model = AutoModel.from_pretrained(multilingual_checkpoint)
40
+
41
+
42
+ total_read = 0
43
+ total_read_limit = 3 * batch_size
44
+ with open(input_text_file) as f:
45
+ while total_read < total_read_limit:
46
+ count = 0
47
+ sentences = []
48
+ line = 'init'
49
+ while line and count < batch_size:
50
+             line = f.readline()
51
+             if line:  # readline() returns '' at EOF; don't embed empty strings
+                 sentences.append(line)
+                 count += 1
52
53
+
+         if not sentences:  # EOF reached before the read limit
+             break
54
+ sentence_embeddings = calculateEmbeddings(sentences,tokenizer,model)
55
+         saveToDisc(sentences, sentence_embeddings, output_embeddings_file)
56
+ total_read += count
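
Because saveToDisc appends one pickle record per batch, reading the dump back means calling pickle.load repeatedly until EOFError. A sketch of that read-back, together with one plausible way (an assumption, not shown in this commit) the single shortened_abstracts_hu_2021_09_01_embedded.pt tensor could be assembled from the batch records:

import pickle
import torch

def load_all_batches(filename):
    # Read every {'sentences': ..., 'embeddings': ...} record appended by saveToDisc.
    records = []
    with open(filename, 'rb') as f:
        while True:
            try:
                records.append(pickle.load(f))
            except EOFError:
                break
    return records

batch_file = 'data/processed/embeddings_1000_batches_at_<timestamp>.pkl'  # hypothetical path; the real name carries the run timestamp
records = load_all_batches(batch_file)
all_embeddings = torch.cat([r['embeddings'] for r in records], dim=0)
torch.save(all_embeddings, 'data/processed/shortened_abstracts_hu_2021_09_01_embedded.pt')
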
src/data/dbpedia_dump_wiki_text.py ADDED
@@ -0,0 +1,18 @@
1
+ from rdflib import Graph
2
+
3
+ # Downloaded from https://databus.dbpedia.org/dbpedia/text/short-abstracts
4
+ raw_data_path = 'data/raw/short-abstracts_lang=hu.ttl'
5
+ processed_data_path = 'data/processed/shortened_abstracts_hu_2021_09_01.txt'
6
+
7
+ g = Graph()
8
+ g.parse(raw_data_path, format='turtle')
9
+
10
+ # 'objects' collects one cleaned abstract plus a newline marker per triple
11
+ objects = []
12
+ with open(processed_data_path, 'w') as f:
13
+ print(len(g))
14
+     for subject, predicate, obj in g:  # 'obj' avoids shadowing the builtin 'object'
15
+         objects.append(obj.replace(' +/-','').replace('\n',' '))
16
+         objects.append('\n')
17
18
+ f.writelines(objects)
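
Since the dump writes one cleaned abstract plus one newline per triple, a quick sanity check (a sketch, assuming the file fits in memory) is to compare the output's line count against the len(g) printed while dumping:

with open('data/processed/shortened_abstracts_hu_2021_09_01.txt') as f:
    lines = f.readlines()
print(len(lines))     # should match the triple count printed by the dump script
print(lines[0][:80])  # peek at the first abstract
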
src/exploration/automodel_test.py ADDED
@@ -0,0 +1,23 @@
1
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
+ import torch
3
+
4
+ checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
5
+
6
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint)
7
+ model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
8
+
9
+ raw_inputs = ["This is the second best day of my life.", "Are you freaking kidding me right now?"]
10
+
11
+ tokens = tokenizer(raw_inputs, padding=True, return_tensors="pt")
12
+ print(tokens)
13
+
14
+ raw_outputs = model(**tokens)
15
+ print(raw_outputs.logits)
16
+
17
+ predictions = torch.nn.functional.softmax(raw_outputs.logits, dim=-1)
18
+ print(predictions)
19
+
20
+ # max value, index of max value, and corresponding label
21
+ labels = model.config.id2label
22
+ max_value_index = [(torch.max(p), torch.argmax(p)) for p in predictions]
23
+ [print("{:.5f}".format(e[0].item()),labels[e[1].item()]) for e in max_value_index]
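
The three manual steps above (tokenize, take the logits, then softmax + argmax + id2label) are what the high-level pipeline API wraps. A sketch of the equivalent call:

from transformers import pipeline

classifier = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english")
print(classifier(["This is the second best day of my life.",
                  "Are you freaking kidding me right now?"]))
# each result is a dict of the form {'label': ..., 'score': ...}
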
src/exploration/datetime_test.py ADDED
@@ -0,0 +1,10 @@
1
+ from datetime import datetime
2
+
3
+ dt = datetime.now()
4
+ print(dt)
5
+ print(dt.strftime('%a %d-%m-%Y'))
6
+ print(dt.strftime('%a %d/%m/%Y'))
7
+ print(dt.strftime('%a %d/%m/%y'))
8
+ print(dt.strftime('%A %d-%m-%Y, %H:%M:%S'))
9
+ print(dt.strftime('%X %x'))
10
+ print(dt.strftime('%Y-%m-%d_%H:%M:%S'))
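
One practical note: the last format is the one used for the embedding dump filenames, and it contains ':' characters, which are fine on Linux but invalid in Windows filenames. A filename-safe variant, as a suggestion rather than what the scripts currently use:

from datetime import datetime

dt = datetime.now()
print(dt.strftime('%Y-%m-%d_%H-%M-%S'))  # e.g. 2021-11-20_14-30-05; safe on all platforms
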
src/exploration/mqa_test.py ADDED
@@ -0,0 +1,9 @@
1
+ from datasets import load_dataset
2
+
3
+ faq_hu = load_dataset("clips/mqa", scope="faq", language="hu")
4
+ cqa_hu = load_dataset("clips/mqa", scope="cqa", language="hu")
5
+
6
+ print(faq_hu)
7
+ print(cqa_hu)
8
+ print(faq_hu['train'][:5])
9
+ print(cqa_hu['train'][:5])
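
Judging from the record printed in the notebook above, each FAQ example carries an 'answers' list whose items include 'text' and 'is_accepted' fields. A sketch that keeps only the accepted answers (field names taken from that printed record and assumed stable across the split):

from datasets import load_dataset

faq_hu = load_dataset("clips/mqa", scope="faq", language="hu")

def accepted_answers(example):
    # 'answers', 'text' and 'is_accepted' as seen in the printed record
    return [a['text'] for a in example['answers'] if a['is_accepted']]

print(accepted_answers(faq_hu['train'][810000])[:1])
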
src/exploration/pipeline_test.py ADDED
@@ -0,0 +1,16 @@
1
+ from transformers import pipeline
2
+
3
+ generator = pipeline("text-generation", model="distilgpt2")
4
+ res = lambda: generator("My girlfriend told me that I have a huge", max_length=40)  # zero-arg lambda so the call is easy to re-run
5
+ print(res())
6
+
7
+ top_k = 10
8
+ maskfiller = pipeline("fill-mask", model="distilbert-base-uncased")
9
+ hu_res = lambda: maskfiller("Hungarians are a very [MASK] nation.", top_k=top_k)
10
+ ju_res = lambda: maskfiller("Jews are a very [MASK] nation.", top_k=top_k)
11
+ it_res = lambda: maskfiller("Italians are a very [MASK] nation.", top_k=top_k)
12
+
13
+ token_str = lambda x: [e["token_str"] for e in x]
14
+ print(token_str(hu_res()))
15
+ print(token_str(ju_res()))
16
+ print(token_str(it_res()))
src/exploration/serialize_test.py ADDED
@@ -0,0 +1,30 @@
1
+ import torch
2
+ import pickle
3
+ '''
4
+ a = [1,2,3]
5
+ b = [4,5,6]
6
+ at = torch.tensor([a,a])
7
+ bt = torch.tensor([b,b])
8
+
9
+ with open('serialize_test.pkl', "ab") as f:
10
+ pickle.dump(at,f)
11
+ pickle.dump(bt,f)
12
+
13
+ with open('serialize_test.pkl', "rb") as f:
14
+ print(pickle.load(f))
15
+ print(pickle.load(f))
16
+ '''
17
+
18
+ def loadFromDiskRaw(batch_number, filename='embeddings.pkl'):  # prints the first batch_number pickled records and returns the last one
19
+ count = 0
20
+ with open(filename, "rb") as f:
21
+ while count < batch_number:
22
+ stored_data = pickle.load(f)
23
+ print(stored_data.size())
24
+ print(stored_data[0][:15])
25
+ count += 1
26
+ return stored_data
27
+
28
+ output_embeddings_file = 'data/processed/DBpedia_shortened_abstracts_hu_embeddings.pkl'
29
+ loadFromDiskRaw(3, output_embeddings_file)
30
+
src/features/semantic_retreiver.py ADDED
@@ -0,0 +1,130 @@
1
+ from transformers import AutoTokenizer, AutoModel
2
+ import torch
3
+ import pickle
4
+ from sentence_transformers import util
5
+ from datetime import datetime
6
+
7
+ #Mean Pooling - Take attention mask into account for correct averaging
8
+ def mean_pooling(model_output, attention_mask):
9
+ token_embeddings = model_output[0] #First element of model_output contains all token embeddings
10
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
11
+ sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
12
+ sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
13
+ return sum_embeddings / sum_mask
14
+
15
+
16
+ dt = datetime.now()
17
+ datetime_formatted = dt.strftime('%Y-%m-%d_%H:%M:%S')
18
+ batch_size = 1000
19
+ output_embeddings_file = f'data/processed/embeddings_{batch_size}_batches_at_{datetime_formatted}.pkl'
20
+ def saveToDiscDefault(embeddings):  # renamed: was shadowed by saveToDisc(sentences, ...) below
21
+ with open(output_embeddings_file, "ab") as f:
22
+ pickle.dump(embeddings, f, protocol=pickle.HIGHEST_PROTOCOL)
23
+
24
+
25
+ def saveToDisc(sentences, embeddings, filename='embeddings.pkl'):
26
+ with open(filename, "ab") as f:
27
+ pickle.dump({'sentences': sentences, 'embeddings': embeddings}, f, protocol=pickle.HIGHEST_PROTOCOL)
28
+
29
+ def saveToDiscRaw(embeddings, filename='embeddings.pkl'):
30
+ with open(filename, "ab") as f:
31
+ pickle.dump(embeddings, f, protocol=pickle.HIGHEST_PROTOCOL)
32
+ #for emb in embeddings:
33
+ # torch.save(emb,f)
34
+
35
+ def loadFromDiskRaw(filename='embeddings.pkl'):
36
+ with open(filename, "rb") as f:
37
+ stored_data = pickle.load(f)
38
+ return stored_data
39
+
40
+ def loadFromDisk(filename='embeddings.pkl'):
41
+ with open(filename, "rb") as f:
42
+ stored_data = pickle.load(f)
43
+ stored_sentences = stored_data['sentences']
44
+ stored_embeddings = stored_data['embeddings']
45
+ return stored_sentences, stored_embeddings
46
+
47
+ def findTopKMostSimilarPairs(embeddings, k):
48
+ cosine_scores = util.pytorch_cos_sim(embeddings, embeddings)
49
+ pairs = []
50
+ for i in range(len(cosine_scores)-1):
51
+ for j in range(i+1, len(cosine_scores)):
52
+ pairs.append({'index': [i, j], 'score': cosine_scores[i][j]})
53
+
54
+ pairs = sorted(pairs, key=lambda x: x['score'], reverse=True)
55
+ return pairs[0:k]
56
+
57
+ def findTopKMostSimilar(query_embedding, embeddings, k):
58
+ cosine_scores = util.pytorch_cos_sim(query_embedding, embeddings)
59
+ cosine_scores_list = cosine_scores.squeeze().tolist()
60
+ pairs = []
61
+ for idx,score in enumerate(cosine_scores_list):
62
+ pairs.append({'index': idx, 'score': score})
63
+ pairs = sorted(pairs, key=lambda x: x['score'], reverse=True)
64
+ return pairs[0:k]
65
+
66
+
67
+ def calculateEmbeddings(sentences,tokenizer,model):
68
+ tokenized_sentences = tokenizer(sentences, padding=True, truncation=True, max_length=128, return_tensors='pt')
69
+ with torch.no_grad():
70
+ model_output = model(**tokenized_sentences)
71
+ sentence_embeddings = mean_pooling(model_output, tokenized_sentences['attention_mask'])
72
+ return sentence_embeddings
73
+
74
+ multilingual_checkpoint = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
75
+ tokenizer = AutoTokenizer.from_pretrained(multilingual_checkpoint)
76
+ model = AutoModel.from_pretrained(multilingual_checkpoint)
77
+
78
+ raw_text_file = 'data/processed/shortened_abstracts_hu_2021_09_01.txt'
79
+
80
+
81
+ concated_sentence_embeddings = None
82
+ all_sentences = []
83
+
84
+ print(datetime.now())
85
+ batch_size = 5
86
+ line = 'init'
87
+ total_read = 0
88
+ total_read_limit = 120
89
+ skip_index = 100
90
+ with open(raw_text_file) as f:
91
+ while line and total_read < total_read_limit:
92
+ count = 0
93
+ sentence_batch = []
94
+ while line and count < batch_size:
95
+             line = f.readline()
96
+             if line:  # readline() returns '' at EOF
+                 sentence_batch.append(line)
+                 count += 1
97
98
+
99
+ all_sentences.extend(sentence_batch)
100
+
101
+         if total_read >= skip_index and sentence_batch:  # guard against an empty final batch
102
+ sentence_embeddings = calculateEmbeddings(sentence_batch,tokenizer,model)
103
+             if concated_sentence_embeddings is None:
104
+ concated_sentence_embeddings = sentence_embeddings
105
+ else:
106
+ concated_sentence_embeddings = torch.cat([concated_sentence_embeddings, sentence_embeddings], dim=0)
107
+ print(concated_sentence_embeddings.size())
108
+ #saveToDiscRaw(sentence_embeddings)
109
+
110
+ total_read += count
111
+         if total_read % 5 == 0:
112
+ print(f'total_read:{total_read}')
113
+ print(datetime.now())
114
+
115
+
116
+ query_embedding = calculateEmbeddings(['Melyik a legnépesebb város a világon?'], tokenizer, model)  # "Which is the most populous city in the world?"
117
+ top_pairs = findTopKMostSimilar(query_embedding, concated_sentence_embeddings, 5)
118
+
119
+ for pair in top_pairs:
120
+ i = pair['index']
121
+ score = pair['score']
122
+ print("{} \t\t Score: {:.4f}".format(all_sentences[skip_index+i], score))
123
+ '''
124
+ query = ''
125
+ while query != 'exit':
126
+ query = input("Enter your query: ")
127
+ query_embedding = calculateEmbeddings([query],tokenizer,model)
128
+
129
+
130
+ '''
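
findTopKMostSimilar builds and sorts a Python list over every score; torch.topk performs the same top-k selection directly on the score tensor. An alternative sketch, not the method the file uses:

import torch
from sentence_transformers import util

def find_top_k(query_embedding, embeddings, k):
    # Cosine scores of the single query against every stored embedding.
    scores = util.pytorch_cos_sim(query_embedding, embeddings).squeeze()
    values, indices = torch.topk(scores, k)
    return [{'index': i.item(), 'score': v.item()} for v, i in zip(values, indices)]
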
src/main_qa.py ADDED
@@ -0,0 +1,51 @@
1
+ from transformers import AutoTokenizer, AutoModel
2
+ import torch
3
+ from sentence_transformers import util
4
+
5
+ def load_raw_sentences(filename):
6
+ with open(filename) as f:
7
+ return f.readlines()
8
+
9
+ #Mean Pooling - Take attention mask into account for correct averaging
10
+ def mean_pooling(model_output, attention_mask):
11
+ token_embeddings = model_output[0] #First element of model_output contains all token embeddings
12
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
13
+ sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
14
+ sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
15
+ return sum_embeddings / sum_mask
16
+
17
+ def findTopKMostSimilar(query_embedding, embeddings, k):
18
+ cosine_scores = util.pytorch_cos_sim(query_embedding, embeddings)
19
+ cosine_scores_list = cosine_scores.squeeze().tolist()
20
+ pairs = []
21
+ for idx,score in enumerate(cosine_scores_list):
22
+ pairs.append({'index': idx, 'score': score})
23
+ pairs = sorted(pairs, key=lambda x: x['score'], reverse=True)
24
+ return pairs[0:k]
25
+
26
+ def calculateEmbeddings(sentences,tokenizer,model):
27
+ tokenized_sentences = tokenizer(sentences, padding=True, truncation=True, max_length=128, return_tensors='pt')
28
+ with torch.no_grad():
29
+ model_output = model(**tokenized_sentences)
30
+ sentence_embeddings = mean_pooling(model_output, tokenized_sentences['attention_mask'])
31
+ return sentence_embeddings
32
+
33
+ multilingual_checkpoint = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
34
+ tokenizer = AutoTokenizer.from_pretrained(multilingual_checkpoint)
35
+ model = AutoModel.from_pretrained(multilingual_checkpoint)
36
+
37
+ raw_text_file = 'data/processed/shortened_abstracts_hu_2021_09_01.txt'
38
+ embeddings_file = 'data/processed/shortened_abstracts_hu_2021_09_01_embedded.pt'
39
+
40
+ all_sentences = load_raw_sentences(raw_text_file)
41
+ all_embeddings = torch.load(embeddings_file,map_location=torch.device('cpu') )
42
+
43
+ query = ''
44
+ while query != 'exit':
45
+ query = input("Enter your query: ")
46
+ query_embedding = calculateEmbeddings([query],tokenizer,model)
47
+ top_pairs = findTopKMostSimilar(query_embedding, all_embeddings, 5)
48
+ for pair in top_pairs:
49
+ i = pair['index']
50
+ score = pair['score']
51
+ print("{} \t\t Score: {:.4f}".format(all_sentences[i], score))