riccorl committed
Commit 626eca0
1 Parent(s): f06d71d

first commit

This view is limited to 50 files because it contains too many changes. See the raw diff for the full changeset.
Files changed (50)
  1. .gitignore +332 -0
  2. MANIFEST.in +1 -0
  3. README.md +1 -13
  4. SETUP.cfg +8 -0
  5. app.py +245 -0
  6. dockerfiles/Dockerfile.cpu +17 -0
  7. dockerfiles/Dockerfile.cuda +38 -0
  8. examples/train_retriever.py +45 -0
  9. pyproject.toml +15 -0
  10. relik/__init__.py +1 -0
  11. relik/common/__init__.py +0 -0
  12. relik/common/log.py +97 -0
  13. relik/common/upload.py +128 -0
  14. relik/common/utils.py +609 -0
  15. relik/inference/__init__.py +0 -0
  16. relik/inference/annotator.py +422 -0
  17. relik/inference/data/__init__.py +0 -0
  18. relik/inference/data/objects.py +64 -0
  19. relik/inference/data/tokenizers/__init__.py +89 -0
  20. relik/inference/data/tokenizers/base_tokenizer.py +84 -0
  21. relik/inference/data/tokenizers/regex_tokenizer.py +73 -0
  22. relik/inference/data/tokenizers/spacy_tokenizer.py +228 -0
  23. relik/inference/data/tokenizers/whitespace_tokenizer.py +70 -0
  24. relik/inference/data/window/__init__.py +0 -0
  25. relik/inference/data/window/manager.py +262 -0
  26. relik/inference/gerbil.py +254 -0
  27. relik/inference/preprocessing.py +4 -0
  28. relik/inference/serve/__init__.py +0 -0
  29. relik/inference/serve/backend/__init__.py +0 -0
  30. relik/inference/serve/backend/relik.py +210 -0
  31. relik/inference/serve/backend/retriever.py +206 -0
  32. relik/inference/serve/backend/utils.py +29 -0
  33. relik/inference/serve/frontend/__init__.py +0 -0
  34. relik/inference/serve/frontend/relik.py +231 -0
  35. relik/inference/serve/frontend/style.css +33 -0
  36. relik/reader/__init__.py +0 -0
  37. relik/reader/conf/config.yaml +14 -0
  38. relik/reader/conf/data/base.yaml +21 -0
  39. relik/reader/conf/data/re.yaml +54 -0
  40. relik/reader/conf/training/base.yaml +12 -0
  41. relik/reader/conf/training/re.yaml +12 -0
  42. relik/reader/data/__init__.py +0 -0
  43. relik/reader/data/patches.py +51 -0
  44. relik/reader/data/relik_reader_data.py +965 -0
  45. relik/reader/data/relik_reader_data_utils.py +51 -0
  46. relik/reader/data/relik_reader_sample.py +49 -0
  47. relik/reader/lightning_modules/__init__.py +0 -0
  48. relik/reader/lightning_modules/relik_reader_pl_module.py +50 -0
  49. relik/reader/lightning_modules/relik_reader_re_pl_module.py +54 -0
  50. relik/reader/pytorch_modules/__init__.py +0 -0
.gitignore ADDED
@@ -0,0 +1,332 @@
+ # custom
+
+ data/*
+ experiments/*
+ retrievers
+ outputs
+ model
+ wandb
+
+ # Created by https://www.toptal.com/developers/gitignore/api/jetbrains+all,vscode,python,jupyternotebooks,linux,windows,macos
+ # Edit at https://www.toptal.com/developers/gitignore?templates=jetbrains+all,vscode,python,jupyternotebooks,linux,windows,macos
+
+ ### JetBrains+all ###
+ # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
+ # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
+
+ # User-specific stuff
+ .idea/**/workspace.xml
+ .idea/**/tasks.xml
+ .idea/**/usage.statistics.xml
+ .idea/**/dictionaries
+ .idea/**/shelf
+
+ # Generated files
+ .idea/**/contentModel.xml
+
+ # Sensitive or high-churn files
+ .idea/**/dataSources/
+ .idea/**/dataSources.ids
+ .idea/**/dataSources.local.xml
+ .idea/**/sqlDataSources.xml
+ .idea/**/dynamic.xml
+ .idea/**/uiDesigner.xml
+ .idea/**/dbnavigator.xml
+
+ # Gradle
+ .idea/**/gradle.xml
+ .idea/**/libraries
+
+ # Gradle and Maven with auto-import
+ # When using Gradle or Maven with auto-import, you should exclude module files,
+ # since they will be recreated, and may cause churn. Uncomment if using
+ # auto-import.
+ # .idea/artifacts
+ # .idea/compiler.xml
+ # .idea/jarRepositories.xml
+ # .idea/modules.xml
+ # .idea/*.iml
+ # .idea/modules
+ # *.iml
+ # *.ipr
+
+ # CMake
+ cmake-build-*/
+
+ # Mongo Explorer plugin
+ .idea/**/mongoSettings.xml
+
+ # File-based project format
+ *.iws
+
+ # IntelliJ
+ out/
+
+ # mpeltonen/sbt-idea plugin
+ .idea_modules/
+
+ # JIRA plugin
+ atlassian-ide-plugin.xml
+
+ # Cursive Clojure plugin
+ .idea/replstate.xml
+
+ # Crashlytics plugin (for Android Studio and IntelliJ)
+ com_crashlytics_export_strings.xml
+ crashlytics.properties
+ crashlytics-build.properties
+ fabric.properties
+
+ # Editor-based Rest Client
+ .idea/httpRequests
+
+ # Android studio 3.1+ serialized cache file
+ .idea/caches/build_file_checksums.ser
+
+ ### JetBrains+all Patch ###
+ # Ignores the whole .idea folder and all .iml files
+ # See https://github.com/joeblau/gitignore.io/issues/186 and https://github.com/joeblau/gitignore.io/issues/360
+
+ .idea/
+
+ # Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-249601023
+
+ *.iml
+ modules.xml
+ .idea/misc.xml
+ *.ipr
+
+ # Sonarlint plugin
+ .idea/sonarlint
+
+ ### JupyterNotebooks ###
+ # gitignore template for Jupyter Notebooks
+ # website: http://jupyter.org/
+
+ .ipynb_checkpoints
+ */.ipynb_checkpoints/*
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # Remove previous ipynb_checkpoints
+ # git rm -r .ipynb_checkpoints/
+
+ ### Linux ###
+ *~
+
+ # temporary files which can be created if a process still has a handle open of a deleted file
+ .fuse_hidden*
+
+ # KDE directory preferences
+ .directory
+
+ # Linux trash folder which might appear on any partition or disk
+ .Trash-*
+
+ # .nfs files are created when an open file is removed but is still being accessed
+ .nfs*
+
+ ### macOS ###
+ # General
+ .DS_Store
+ .AppleDouble
+ .LSOverride
+
+ # Icon must end with two \r
+ Icon
+
+
+ # Thumbnails
+ ._*
+
+ # Files that might appear in the root of a volume
+ .DocumentRevisions-V100
+ .fseventsd
+ .Spotlight-V100
+ .TemporaryItems
+ .Trashes
+ .VolumeIcon.icns
+ .com.apple.timemachine.donotpresent
+
+ # Directories potentially created on remote AFP share
+ .AppleDB
+ .AppleDesktop
+ Network Trash Folder
+ Temporary Items
+ .apdisk
+
+ ### Python ###
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ pytestdebug.log
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+ doc/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+
+ # IPython
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+ pythonenv*
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # profiling data
+ .prof
+
+ ### vscode ###
+ .vscode
+ .vscode/*
+ !.vscode/settings.json
+ !.vscode/tasks.json
+ !.vscode/launch.json
+ !.vscode/extensions.json
+ *.code-workspace
+
+ ### Windows ###
+ # Windows thumbnail cache files
+ Thumbs.db
+ Thumbs.db:encryptable
+ ehthumbs.db
+ ehthumbs_vista.db
+
+ # Dump file
+ *.stackdump
+
+ # Folder config file
+ [Dd]esktop.ini
+
+ # Recycle Bin used on file shares
+ $RECYCLE.BIN/
+
+ # Windows Installer files
+ *.cab
+ *.msi
+ *.msix
+ *.msm
+ *.msp
+
+ # Windows shortcuts
+ *.lnk
+
+ # End of https://www.toptal.com/developers/gitignore/api/jetbrains+all,vscode,python,jupyternotebooks,linux,windows,macos
MANIFEST.in ADDED
@@ -0,0 +1 @@
+ include requirements.txt
README.md CHANGED
@@ -1,13 +1 @@
- ---
- title: Relik
- emoji: 🐨
- colorFrom: gray
- colorTo: pink
- sdk: streamlit
- sdk_version: 1.27.2
- app_file: app.py
- pinned: false
- license: mit
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # relik
SETUP.cfg ADDED
@@ -0,0 +1,8 @@
+ [metadata]
+ description-file = README.md
+
+ [build]
+ build-base = /tmp/build
+
+ [egg_info]
+ egg-base = /tmp
app.py ADDED
@@ -0,0 +1,245 @@
+ import os
+ import re
+ import time
+ from pathlib import Path
+
+ import requests
+ import streamlit as st
+ from spacy import displacy
+ from streamlit_extras.badges import badge
+ from streamlit_extras.stylable_container import stylable_container
+
+ # RELIK = os.getenv("RELIK", "localhost:8000/api/entities")
+
+ import random
+
+ from relik.inference.annotator import Relik
+
+
+ def get_random_color(ents):
+     colors = {}
+     random_colors = generate_pastel_colors(len(ents))
+     for ent in ents:
+         colors[ent] = random_colors.pop(random.randint(0, len(random_colors) - 1))
+     return colors
+
+
+ def floatrange(start, stop, steps):
+     if int(steps) == 1:
+         return [stop]
+     return [
+         start + float(i) * (stop - start) / (float(steps) - 1) for i in range(steps)
+     ]
+
+
+ def hsl_to_rgb(h, s, l):
+     def hue_2_rgb(v1, v2, v_h):
+         while v_h < 0.0:
+             v_h += 1.0
+         while v_h > 1.0:
+             v_h -= 1.0
+         if 6 * v_h < 1.0:
+             return v1 + (v2 - v1) * 6.0 * v_h
+         if 2 * v_h < 1.0:
+             return v2
+         if 3 * v_h < 2.0:
+             return v1 + (v2 - v1) * ((2.0 / 3.0) - v_h) * 6.0
+         return v1
+
+     # if not (0 <= s <= 1): raise ValueError, "s (saturation) parameter must be between 0 and 1."
+     # if not (0 <= l <= 1): raise ValueError, "l (lightness) parameter must be between 0 and 1."
+
+     r, b, g = (l * 255,) * 3
+     if s != 0.0:
+         if l < 0.5:
+             var_2 = l * (1.0 + s)
+         else:
+             var_2 = (l + s) - (s * l)
+         var_1 = 2.0 * l - var_2
+         r = 255 * hue_2_rgb(var_1, var_2, h + (1.0 / 3.0))
+         g = 255 * hue_2_rgb(var_1, var_2, h)
+         b = 255 * hue_2_rgb(var_1, var_2, h - (1.0 / 3.0))
+
+     return int(round(r)), int(round(g)), int(round(b))
+
+
+ def generate_pastel_colors(n):
+     """Return different pastel colours.
+
+     Input:
+         n (integer) : The number of colors to return
+
+     Output:
+         A list of colors in HTML notation (eg. ['#cce0ff', '#ffcccc', '#ccffe0', '#f5ccff', '#f5ffcc'])
+
+     Example:
+         >>> print(generate_pastel_colors(5))
+         ['#cce0ff', '#f5ccff', '#ffcccc', '#f5ffcc', '#ccffe0']
+     """
+     if n == 0:
+         return []
+
+     # To generate colors, we use the HSL colorspace (see http://en.wikipedia.org/wiki/HSL_color_space)
+     start_hue = 0.6  # 0=red 1/3=0.333=green 2/3=0.666=blue
+     saturation = 1.0
+     lightness = 0.8
+     # We take points around the chromatic circle (hue):
+     # (Note: we generate n+1 colors, then drop the last one ([:-1]) because
+     # it equals the first one (hue 0 = hue 1))
+     return [
+         "#%02x%02x%02x" % hsl_to_rgb(hue, saturation, lightness)
+         for hue in floatrange(start_hue, start_hue + 1, n + 1)
+     ][:-1]
+
+
+ def set_sidebar(css):
+     white_link_wrapper = "<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.2/css/all.min.css'><a href='{}'>{}</a>"
+     with st.sidebar:
+         st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
+         st.image(
+             "http://nlp.uniroma1.it/static/website/sapienza-nlp-logo-wh.svg",
+             use_column_width=True,
+         )
+         st.markdown("## ReLiK")
+         st.write(
+             f"""
+             - {white_link_wrapper.format("#", "<i class='fa-solid fa-file'></i>&nbsp; Paper")}
+             - {white_link_wrapper.format("https://github.com/SapienzaNLP/relik", "<i class='fa-brands fa-github'></i>&nbsp; GitHub")}
+             - {white_link_wrapper.format("https://hub.docker.com/repository/docker/sapienzanlp/relik", "<i class='fa-brands fa-docker'></i>&nbsp; Docker Hub")}
+             """,
+             unsafe_allow_html=True,
+         )
+         st.markdown("## Sapienza NLP")
+         st.write(
+             f"""
+             - {white_link_wrapper.format("https://nlp.uniroma1.it", "<i class='fa-solid fa-globe'></i>&nbsp; Webpage")}
+             - {white_link_wrapper.format("https://github.com/SapienzaNLP", "<i class='fa-brands fa-github'></i>&nbsp; GitHub")}
+             - {white_link_wrapper.format("https://twitter.com/SapienzaNLP", "<i class='fa-brands fa-twitter'></i>&nbsp; Twitter")}
+             - {white_link_wrapper.format("https://www.linkedin.com/company/79434450", "<i class='fa-brands fa-linkedin'></i>&nbsp; LinkedIn")}
+             """,
+             unsafe_allow_html=True,
+         )
+
+
+ def get_el_annotations(response):
+     # swap labels key with ents
+     dict_of_ents = {"text": response.text, "ents": []}
+     dict_of_ents["ents"] = response.labels
+     label_in_text = set(l["label"] for l in dict_of_ents["ents"])
+     options = {"ents": label_in_text, "colors": get_random_color(label_in_text)}
+     return dict_of_ents, options
+
+
+ def set_intro(css):
+     # intro
+     st.markdown("# ReLik")
+     st.markdown(
+         "### Retrieve, Read and LinK: Fast and Accurate Entity Linking and Relation Extraction on an Academic Budget"
+     )
+     # st.markdown(
+     #     "This is a front-end for the paper [Universal Semantic Annotator: the First Unified API "
+     #     "for WSD, SRL and Semantic Parsing](https://www.researchgate.net/publication/360671045_Universal_Semantic_Annotator_the_First_Unified_API_for_WSD_SRL_and_Semantic_Parsing), which will be presented at LREC 2022 by "
+     #     "[Riccardo Orlando](https://riccorl.github.io), [Simone Conia](https://c-simone.github.io/), "
+     #     "[Stefano Faralli](https://corsidilaurea.uniroma1.it/it/users/stefanofaralliuniroma1it), and [Roberto Navigli](https://www.diag.uniroma1.it/navigli/)."
+     # )
+     badge(type="github", name="sapienzanlp/relik")
+     badge(type="pypi", name="relik")
+
+
+ def run_client():
+     with open(Path(__file__).parent / "style.css") as f:
+         css = f.read()
+
+     st.set_page_config(
+         page_title="ReLik",
+         page_icon="🦮",
+         layout="wide",
+     )
+     set_sidebar(css)
+     set_intro(css)
+
+     # text input
+     text = st.text_area(
+         "Enter Text Below:",
+         value="Obama went to Rome for a quick vacation.",
+         height=200,
+         max_chars=500,
+     )
+
+     with stylable_container(
+         key="annotate_button",
+         css_styles="""
+             button {
+                 background-color: #802433;
+                 color: white;
+                 border-radius: 25px;
+             }
+             """,
+     ):
+         submit = st.button("Annotate")
+     # submit = st.button("Run")
+
+     relik = Relik(
+         question_encoder="riccorl/relik-retriever-aida-blink-pretrain-omniencoder",
+         document_index="riccorl/index-relik-retriever-aida-blink-pretrain-omniencoder",
+         reader="riccorl/relik-reader-aida-deberta-small",
+         top_k=100,
+         window_size=32,
+         window_stride=16,
+         candidates_preprocessing_fn="relik.inference.preprocessing.wikipedia_title_and_openings_preprocessing",
+     )
+
+     # ReLik API call
+     if submit:
+         text = text.strip()
+         if text:
+             st.markdown("####")
+             st.markdown("#### Entity Linking")
+             with st.spinner(text="In progress"):
+                 response = relik(text)
+                 # response = requests.post(RELIK, json=text)
+                 # if response.status_code != 200:
+                 #     st.error("Error: {}".format(response.status_code))
+                 # else:
+                 #     response = response.json()
+
+             # Entity Linking
+             # with stylable_container(
+             #     key="container_with_border",
+             #     css_styles="""
+             #         {
+             #             border: 1px solid rgba(49, 51, 63, 0.2);
+             #             border-radius: 0.5rem;
+             #             padding: 0.5rem;
+             #             padding-bottom: 2rem;
+             #         }
+             #         """,
+             # ):
+             # st.markdown("##")
+             dict_of_ents, options = get_el_annotations(response=response)
+             display = displacy.render(
+                 dict_of_ents, manual=True, style="ent", options=options
+             )
+             display = display.replace("\n", " ")
+             # wsd_display = re.sub(
+             #     r"(wiki::\d+\w)",
+             #     r"<a href='https://babelnet.org/synset?id=\g<1>&orig=\g<1>&lang={}'>\g<1></a>".format(
+             #         language.upper()
+             #     ),
+             #     wsd_display,
+             # )
+             with st.container():
+                 st.write(display, unsafe_allow_html=True)
+
+             st.markdown("####")
+             st.markdown("#### Relation Extraction")
+
+             with st.container():
+                 st.write("Coming :)", unsafe_allow_html=True)
+
+         else:
+             st.error("Please enter some text.")
+
+
+ if __name__ == "__main__":
+     run_client()
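
The app above drives the `Relik` annotator through Streamlit. A minimal sketch of the same pipeline without the UI, assuming the three Hub checkpoints referenced in `run_client` are available:

    from relik.inference.annotator import Relik

    relik = Relik(
        question_encoder="riccorl/relik-retriever-aida-blink-pretrain-omniencoder",
        document_index="riccorl/index-relik-retriever-aida-blink-pretrain-omniencoder",
        reader="riccorl/relik-reader-aida-deberta-small",
        top_k=100,
        window_size=32,
        window_stride=16,
    )
    output = relik("Obama went to Rome for a quick vacation.")
    # `output.text` and `output.labels` are the fields get_el_annotations() feeds to displacy
    print(output.text, output.labels)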
dockerfiles/Dockerfile.cpu ADDED
@@ -0,0 +1,17 @@
+ FROM tiangolo/uvicorn-gunicorn:python3.10-slim
+
+ # Copy and install requirements.txt
+ COPY ./requirements.txt ./requirements.txt
+ COPY ./src /app
+ COPY ./scripts/start.sh /start.sh
+ COPY ./scripts/prestart.sh /app
+ COPY ./scripts/gunicorn_conf.py /gunicorn_conf.py
+ COPY ./scripts/start-reload.sh /start-reload.sh
+ COPY ./VERSION /
+ RUN mkdir -p /app/resources/model \
+     && pip install --no-cache-dir -r requirements.txt \
+     && chmod +x /start.sh && chmod +x /start-reload.sh
+ ARG MODEL_PATH
+ COPY ${MODEL_PATH}/* /app/resources/model/
+
+ ENV APP_MODULE=main:app
dockerfiles/Dockerfile.cuda ADDED
@@ -0,0 +1,38 @@
+ FROM nvidia/cuda:12.2.0-base-ubuntu20.04
+
+ ARG DEBIAN_FRONTEND=noninteractive
+
+ RUN apt-get update \
+     && apt-get install \
+     curl wget python3.10 \
+     python3.10-distutils \
+     python3-pip \
+     curl wget -y \
+     && rm -rf /var/lib/apt/lists/*
+
+ # FastAPI section
+ # device env
+ ENV DEVICE="cuda"
+ # Copy and install requirements.txt
+ COPY ./gpu-requirements.txt ./requirements.txt
+ COPY ./src /app
+ COPY ./scripts/start.sh /start.sh
+ COPY ./scripts/gunicorn_conf.py /gunicorn_conf.py
+ COPY ./scripts/start-reload.sh /start-reload.sh
+ COPY ./scripts/prestart.sh /app
+ COPY ./VERSION /
+ RUN mkdir -p /app/resources/model \
+     && pip install --upgrade --no-cache-dir -r requirements.txt \
+     && chmod +x /start.sh \
+     && chmod +x /start-reload.sh
+ ARG MODEL_NAME_OR_PATH
+
+ WORKDIR /app
+
+ ENV PYTHONPATH=/app
+
+ EXPOSE 80
+
+ # Run the start script, it will check for an /app/prestart.sh script (e.g. for migrations)
+ # And then will start Gunicorn with Uvicorn
+ CMD ["/start.sh"]
examples/train_retriever.py ADDED
@@ -0,0 +1,45 @@
+ from relik.retriever.trainer import RetrieverTrainer
+ from relik import GoldenRetriever
+ from relik.retriever.indexers.inmemory import InMemoryDocumentIndex
+ from relik.retriever.data.datasets import AidaInBatchNegativesDataset
+
+ if __name__ == "__main__":
+     # instantiate retriever
+     document_index = InMemoryDocumentIndex(
+         documents="/root/golden-retriever-v2/data/dpr-like/el/definitions.txt",
+         device="cuda",
+         precision="16",
+     )
+     retriever = GoldenRetriever(
+         question_encoder="intfloat/e5-small-v2", document_index=document_index
+     )
+
+     train_dataset = AidaInBatchNegativesDataset(
+         name="aida_train",
+         path="/root/golden-retriever-v2/data/dpr-like/el/aida_32_tokens_topic/train.jsonl",
+         tokenizer=retriever.question_tokenizer,
+         question_batch_size=64,
+         passage_batch_size=400,
+         max_passage_length=64,
+         use_topics=True,
+         shuffle=True,
+     )
+     val_dataset = AidaInBatchNegativesDataset(
+         name="aida_val",
+         path="/root/golden-retriever-v2/data/dpr-like/el/aida_32_tokens_topic/val.jsonl",
+         tokenizer=retriever.question_tokenizer,
+         question_batch_size=64,
+         passage_batch_size=400,
+         max_passage_length=64,
+         use_topics=True,
+     )
+
+     trainer = RetrieverTrainer(
+         retriever=retriever,
+         train_dataset=train_dataset,
+         val_dataset=val_dataset,
+         max_steps=25_000,
+         wandb_offline_mode=True,
+     )
+
+     trainer.train()
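
Note that the data paths in this example (`/root/golden-retriever-v2/data/dpr-like/el/...`) are machine-specific and must be pointed at local copies of the DPR-style AIDA data before the script will run; the encoder name (`intfloat/e5-small-v2`) and the batch and passage-length hyperparameters are usable as-is.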
pyproject.toml ADDED
@@ -0,0 +1,15 @@
+ [tool.black]
+ include = '\.pyi?$'
+ exclude = '''
+ /(
+     \.git
+   | \.hg
+   | \.mypy_cache
+   | \.tox
+   | \.venv
+   | _build
+   | buck-out
+   | build
+   | dist
+ )/
+ '''
relik/__init__.py ADDED
@@ -0,0 +1 @@
+ from relik.retriever.pytorch_modules.model import GoldenRetriever
relik/common/__init__.py ADDED
File without changes
relik/common/log.py ADDED
@@ -0,0 +1,97 @@
+ import logging
+ import sys
+ import threading
+ from typing import Optional
+
+ from rich import get_console
+
+ _lock = threading.Lock()
+ _default_handler: Optional[logging.Handler] = None
+
+ _default_log_level = logging.WARNING
+
+ # fancy logger
+ _console = get_console()
+
+
+ def _get_library_name() -> str:
+     return __name__.split(".")[0]
+
+
+ def _get_library_root_logger() -> logging.Logger:
+     return logging.getLogger(_get_library_name())
+
+
+ def _configure_library_root_logger() -> None:
+     global _default_handler
+
+     with _lock:
+         if _default_handler:
+             # This library has already configured the library root logger.
+             return
+         _default_handler = logging.StreamHandler()  # Set sys.stderr as stream.
+         _default_handler.flush = sys.stderr.flush
+
+         # Apply our default configuration to the library root logger.
+         library_root_logger = _get_library_root_logger()
+         library_root_logger.addHandler(_default_handler)
+         library_root_logger.setLevel(_default_log_level)
+         library_root_logger.propagate = False
+
+
+ def _reset_library_root_logger() -> None:
+     global _default_handler
+
+     with _lock:
+         if not _default_handler:
+             return
+
+         library_root_logger = _get_library_root_logger()
+         library_root_logger.removeHandler(_default_handler)
+         library_root_logger.setLevel(logging.NOTSET)
+         _default_handler = None
+
+
+ def set_log_level(level: int, logger: logging.Logger = None) -> None:
+     """
+     Set the log level.
+
+     Args:
+         level (:obj:`int`):
+             Logging level.
+         logger (:obj:`logging.Logger`):
+             Logger to set the log level.
+     """
+     if not logger:
+         _configure_library_root_logger()
+         logger = _get_library_root_logger()
+     logger.setLevel(level)
+
+
+ def get_logger(
+     name: Optional[str] = None,
+     level: Optional[int] = None,
+     formatter: Optional[str] = None,
+ ) -> logging.Logger:
+     """
+     Return a logger with the specified name.
+     """
+
+     if name is None:
+         name = _get_library_name()
+
+     _configure_library_root_logger()
+
+     if level is not None:
+         set_log_level(level)
+
+     if formatter is None:
+         formatter = logging.Formatter(
+             "%(asctime)s - %(levelname)s - %(name)s - %(message)s"
+         )
+     _default_handler.setFormatter(formatter)
+
+     return logging.getLogger(name)
+
+
+ def get_console_logger():
+     return _console
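
A short usage sketch for `relik.common.log` (names taken from the module above; the handler is installed once on the library root logger and shared):

    import logging

    from relik.common.log import get_logger

    logger = get_logger(__name__, level=logging.INFO)
    logger.info("relik logging configured")  # emitted through the shared stderr handler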
relik/common/upload.py ADDED
@@ -0,0 +1,128 @@
+ import argparse
+ import json
+ import logging
+ import os
+ import tempfile
+ import zipfile
+ from datetime import datetime
+ from pathlib import Path
+ from typing import Optional, Union
+
+ import huggingface_hub
+
+ from relik.common.log import get_logger
+ from relik.common.utils import SAPIENZANLP_DATE_FORMAT, get_md5
+
+ logger = get_logger(level=logging.DEBUG)
+
+
+ def create_info_file(tmpdir: Path):
+     logger.debug("Computing md5 of model.zip")
+     md5 = get_md5(tmpdir / "model.zip")
+     date = datetime.now().strftime(SAPIENZANLP_DATE_FORMAT)
+
+     logger.debug("Dumping info.json file")
+     with (tmpdir / "info.json").open("w") as f:
+         json.dump(dict(md5=md5, upload_date=date), f, indent=2)
+
+
+ def zip_run(
+     dir_path: Union[str, os.PathLike],
+     tmpdir: Union[str, os.PathLike],
+     zip_name: str = "model.zip",
+ ) -> Path:
+     logger.debug(f"zipping {dir_path} to {tmpdir}")
+     # creates a zip version of the provided dir_path
+     run_dir = Path(dir_path)
+     zip_path = tmpdir / zip_name
+
+     with zipfile.ZipFile(zip_path, "w") as zip_file:
+         # fully zip the run directory maintaining its structure
+         for file in run_dir.rglob("*.*"):
+             if file.is_dir():
+                 continue
+
+             zip_file.write(file, arcname=file.relative_to(run_dir))
+
+     return zip_path
+
+
+ def upload(
+     model_dir: Union[str, os.PathLike],
+     model_name: str,
+     organization: Optional[str] = None,
+     repo_name: Optional[str] = None,
+     commit: Optional[str] = None,
+     archive: bool = False,
+ ):
+     token = huggingface_hub.HfFolder.get_token()
+     if token is None:
+         print(
+             "No HuggingFace token found. You need to execute `huggingface-cli login` first!"
+         )
+         return
+
+     repo_id = repo_name or model_name
+     if organization is not None:
+         repo_id = f"{organization}/{repo_id}"
+     with tempfile.TemporaryDirectory() as tmpdir:
+         api = huggingface_hub.HfApi()
+         repo_url = api.create_repo(
+             token=token,
+             repo_id=repo_id,
+             exist_ok=True,
+         )
+         repo = huggingface_hub.Repository(
+             str(tmpdir), clone_from=repo_url, use_auth_token=token
+         )
+
+         tmp_path = Path(tmpdir)
+         if archive:
+             # otherwise we zip the model_dir
+             logger.debug(f"Zipping {model_dir} to {tmp_path}")
+             zip_run(model_dir, tmp_path)
+             create_info_file(tmp_path)
+         else:
+             # if the user wants to upload a transformers model, we don't need to zip it
+             # we just need to copy the files to the tmpdir
+             logger.debug(f"Copying {model_dir} to {tmpdir}")
+             os.system(f"cp -r {model_dir}/* {tmpdir}")
+
+         # this method automatically puts large files (>10MB) into git lfs
+         repo.push_to_hub(commit_message=commit or "Automatic push from sapienzanlp")
+
+
+ def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "model_dir", help="The directory of the model you want to upload"
+     )
+     parser.add_argument("model_name", help="The model you want to upload")
+     parser.add_argument(
+         "--organization",
+         help="the name of the organization where you want to upload the model",
+     )
+     parser.add_argument(
+         "--repo_name",
+         help="Optional name to use when uploading to the HuggingFace repository",
+     )
+     parser.add_argument(
+         "--commit", help="Commit message to use when pushing to the HuggingFace Hub"
+     )
+     parser.add_argument(
+         "--archive",
+         action="store_true",
+         help="""
+         Whether to compress the model directory before uploading it.
+         If True, the model directory will be zipped and the zip file will be uploaded.
+         If False, the model directory will be uploaded as is.""",
+     )
+     return parser.parse_args()
+
+
+ def main():
+     upload(**vars(parse_args()))
+
+
+ if __name__ == "__main__":
+     main()
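
The module doubles as a CLI via `parse_args`. A hypothetical invocation, assuming `huggingface-cli login` has already been run: `python -m relik.common.upload ./experiments/my-model my-model --organization sapienzanlp --archive` (the directory and model names are illustrative).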
relik/common/utils.py ADDED
@@ -0,0 +1,609 @@
+ import importlib.util
+ import json
+ import logging
+ import os
+ import shutil
+ import tarfile
+ import tempfile
+ from functools import partial
+ from hashlib import sha256
+ from pathlib import Path
+ from typing import Any, BinaryIO, Dict, List, Optional, Union
+ from urllib.parse import urlparse
+ from zipfile import ZipFile, is_zipfile
+
+ import huggingface_hub
+ import requests
+ from tqdm import tqdm  # import the class, not the module, so the tqdm(...) call below works
+ from filelock import FileLock
+ from transformers.utils.hub import cached_file as hf_cached_file
+
+ from relik.common.log import get_logger
+
+ # name constants
+ WEIGHTS_NAME = "weights.pt"
+ ONNX_WEIGHTS_NAME = "weights.onnx"
+ CONFIG_NAME = "config.yaml"
+ LABELS_NAME = "labels.json"
+
+ # SAPIENZANLP_USER_NAME = "sapienzanlp"
+ SAPIENZANLP_USER_NAME = "riccorl"
+ SAPIENZANLP_HF_MODEL_REPO_URL = "riccorl/{model_id}"
+ SAPIENZANLP_HF_MODEL_REPO_ARCHIVE_URL = (
+     f"{SAPIENZANLP_HF_MODEL_REPO_URL}/resolve/main/model.zip"
+ )
+ # path constants
+ SAPIENZANLP_CACHE_DIR = os.getenv("SAPIENZANLP_CACHE_DIR", Path.home() / ".sapienzanlp")
+ SAPIENZANLP_DATE_FORMAT = "%Y-%m-%d %H-%M-%S"
+
+
+ logger = get_logger(__name__)
+
+
+ def sapienzanlp_model_urls(model_id: str) -> str:
+     """
+     Returns the URL for a possible SapienzaNLP valid model.
+
+     Args:
+         model_id (:obj:`str`):
+             A SapienzaNLP model id.
+
+     Returns:
+         :obj:`str`: The url for the model id.
+     """
+     # check if there is already the namespace of the user
+     if "/" in model_id:
+         return model_id
+     return SAPIENZANLP_HF_MODEL_REPO_URL.format(model_id=model_id)
+
+
+ def is_package_available(package_name: str) -> bool:
+     """
+     Check if a package is available.
+
+     Args:
+         package_name (`str`): The name of the package to check.
+     """
+     return importlib.util.find_spec(package_name) is not None
+
+
+ def load_json(path: Union[str, Path]) -> Any:
+     """
+     Load a json file provided in input.
+
+     Args:
+         path (`Union[str, Path]`): The path to the json file to load.
+
+     Returns:
+         `Any`: The loaded json file.
+     """
+     with open(path, encoding="utf8") as f:
+         return json.load(f)
+
+
+ def dump_json(document: Any, path: Union[str, Path], indent: Optional[int] = None):
+     """
+     Dump input to json file.
+
+     Args:
+         document (`Any`): The document to dump.
+         path (`Union[str, Path]`): The path to dump the document to.
+         indent (`Optional[int]`): The indent to use for the json file.
+
+     """
+     with open(path, "w", encoding="utf8") as outfile:
+         json.dump(document, outfile, indent=indent)
+
+
+ def get_md5(path: Path):
+     """
+     Get the MD5 value of a path.
+     """
+     import hashlib
+
+     with path.open("rb") as fin:
+         data = fin.read()
+     return hashlib.md5(data).hexdigest()
+
+
+ def file_exists(path: Union[str, os.PathLike]) -> bool:
+     """
+     Check if the file at :obj:`path` exists.
+
+     Args:
+         path (:obj:`str`, :obj:`os.PathLike`):
+             Path to check.
+
+     Returns:
+         :obj:`bool`: :obj:`True` if the file exists.
+     """
+     return Path(path).exists()
+
+
+ def dir_exists(path: Union[str, os.PathLike]) -> bool:
+     """
+     Check if the directory at :obj:`path` exists.
+
+     Args:
+         path (:obj:`str`, :obj:`os.PathLike`):
+             Path to check.
+
+     Returns:
+         :obj:`bool`: :obj:`True` if the directory exists.
+     """
+     return Path(path).is_dir()
+
+
+ def is_remote_url(url_or_filename: Union[str, Path]):
+     """
+     Returns :obj:`True` if the input path is an url.
+
+     Args:
+         url_or_filename (:obj:`str`, :obj:`Path`):
+             path to check.
+
+     Returns:
+         :obj:`bool`: :obj:`True` if the input path is an url, :obj:`False` otherwise.
+
+     """
+     if isinstance(url_or_filename, Path):
+         url_or_filename = str(url_or_filename)
+     parsed = urlparse(url_or_filename)
+     return parsed.scheme in ("http", "https")
+
+
+ def url_to_filename(resource: str, etag: str = None) -> str:
+     """
+     Convert a `resource` into a hashed filename in a repeatable way.
+     If `etag` is specified, append its hash to the resource's, delimited
+     by a period.
+     """
+     resource_bytes = resource.encode("utf-8")
+     resource_hash = sha256(resource_bytes)
+     filename = resource_hash.hexdigest()
+
+     if etag:
+         etag_bytes = etag.encode("utf-8")
+         etag_hash = sha256(etag_bytes)
+         filename += "." + etag_hash.hexdigest()
+
+     return filename
+
+
+ def download_resource(
+     url: str,
+     temp_file: BinaryIO,
+     headers=None,
+ ):
+     """
+     Download remote file.
+     """
+
+     if headers is None:
+         headers = {}
+
+     r = requests.get(url, stream=True, headers=headers)
+     r.raise_for_status()
+     content_length = r.headers.get("Content-Length")
+     total = int(content_length) if content_length is not None else None
+     progress = tqdm(
+         unit="B",
+         unit_scale=True,
+         total=total,
+         desc="Downloading",
+         disable=logger.level in [logging.NOTSET],
+     )
+     for chunk in r.iter_content(chunk_size=1024):
+         if chunk:  # filter out keep-alive new chunks
+             progress.update(len(chunk))
+             temp_file.write(chunk)
+     progress.close()
+
+
+ def download_and_cache(
+     url: Union[str, Path],
+     cache_dir: Union[str, Path] = None,
+     force_download: bool = False,
+ ):
+     if cache_dir is None:
+         cache_dir = SAPIENZANLP_CACHE_DIR
+     if isinstance(url, Path):
+         url = str(url)
+
+     # check if cache dir exists
+     Path(cache_dir).mkdir(parents=True, exist_ok=True)
+
+     # check if file is private
+     headers = {}
+     try:
+         r = requests.head(url, allow_redirects=False, timeout=10)
+         r.raise_for_status()
+     except requests.exceptions.HTTPError:
+         if r.status_code == 401:
+             hf_token = huggingface_hub.HfFolder.get_token()
+             if hf_token is None:
+                 raise ValueError(
+                     "You need to login to HuggingFace to download this model "
+                     "(use the `huggingface-cli login` command)"
+                 )
+             headers["Authorization"] = f"Bearer {hf_token}"
+
+     etag = None
+     try:
+         r = requests.head(url, allow_redirects=True, timeout=10, headers=headers)
+         r.raise_for_status()
+         etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
+         # We favor a custom header indicating the etag of the linked resource, and
+         # we fallback to the regular etag header.
+         # If we don't have any of those, raise an error.
+         if etag is None:
+             raise OSError(
+                 "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
+             )
+         # In case of a redirect,
+         # save an extra redirect on the request.get call,
+         # and ensure we download the exact atomic version even if it changed
+         # between the HEAD and the GET (unlikely, but hey).
+         if 300 <= r.status_code <= 399:
+             url = r.headers["Location"]
+     except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
+         # Actually raise for those subclasses of ConnectionError
+         raise
+     except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
+         # Otherwise, our Internet connection is down.
+         # etag is None
+         pass
+
+     # get filename from the url
+     filename = url_to_filename(url, etag)
+     # get cache path to put the file
+     cache_path = cache_dir / filename
+
+     # the file is already here, return it
+     if file_exists(cache_path) and not force_download:
+         logger.info(
+             f"{url} found in cache, set `force_download=True` to force the download"
+         )
+         return cache_path
+
+     cache_path = str(cache_path)
+     # Prevent parallel downloads of the same file with a lock.
+     lock_path = cache_path + ".lock"
+     with FileLock(lock_path):
+         # If the download just completed while the lock was activated.
+         if file_exists(cache_path) and not force_download:
+             # Even if returning early like here, the lock will be released.
+             return cache_path
+
+         temp_file_manager = partial(
+             tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False
+         )
+
+         # Download to temporary file, then copy to cache dir once finished.
+         # Otherwise, you get corrupt cache entries if the download gets interrupted.
+         with temp_file_manager() as temp_file:
+             logger.info(
+                 f"{url} not found in cache or `force_download` set to `True`, downloading to {temp_file.name}"
+             )
+             download_resource(url, temp_file, headers)
+
+         logger.info(f"storing {url} in cache at {cache_path}")
+         os.replace(temp_file.name, cache_path)
+
+         # NamedTemporaryFile creates a file with hardwired 0600 perms (ignoring umask), so fixing it.
+         umask = os.umask(0o666)
+         os.umask(umask)
+         os.chmod(cache_path, 0o666 & ~umask)
+
+         logger.info(f"creating metadata file for {cache_path}")
+         meta = {"url": url}  # , "etag": etag}
+         meta_path = cache_path + ".json"
+         with open(meta_path, "w") as meta_file:
+             json.dump(meta, meta_file)
+
+     return cache_path
+
+
+ def download_from_hf(
+     path_or_repo_id: Union[str, Path],
+     filenames: Optional[List[str]],
+     cache_dir: Union[str, Path] = None,
+     force_download: bool = False,
+     resume_download: bool = False,
+     proxies: Optional[Dict[str, str]] = None,
+     use_auth_token: Optional[Union[bool, str]] = None,
+     revision: Optional[str] = None,
+     local_files_only: bool = False,
+     subfolder: str = "",
+ ):
+     if isinstance(path_or_repo_id, Path):
+         path_or_repo_id = str(path_or_repo_id)
+
+     downloaded_paths = []
+     for filename in filenames:
+         downloaded_path = hf_cached_file(
+             path_or_repo_id,
+             filename,
+             cache_dir=cache_dir,
+             force_download=force_download,
+             proxies=proxies,
+             resume_download=resume_download,
+             use_auth_token=use_auth_token,
+             revision=revision,
+             local_files_only=local_files_only,
+             subfolder=subfolder,
+         )
+         downloaded_paths.append(downloaded_path)
+
+     # we want the folder where the files are downloaded
+     # the best guess is the parent folder of the first file
+     probably_the_folder = Path(downloaded_paths[0]).parent
+     return probably_the_folder
+
+
+ def model_name_or_path_resolver(model_name_or_dir: Union[str, os.PathLike]) -> str:
+     """
+     Resolve a model name or directory to a model archive name or directory.
+
+     Args:
+         model_name_or_dir (:obj:`str` or :obj:`os.PathLike`):
+             A model name or directory.
+
+     Returns:
+         :obj:`str`: The model archive name or directory.
+     """
+     if is_remote_url(model_name_or_dir):
+         # if model_name_or_dir is a URL
+         # download it and try to load
+         model_archive = model_name_or_dir
+     elif Path(model_name_or_dir).is_dir() or Path(model_name_or_dir).is_file():
+         # if model_name_or_dir is a local directory or
+         # an archive file try to load it
+         model_archive = model_name_or_dir
+     else:
+         # probably model_name_or_dir is a sapienzanlp model id
+         # guess the url and try to download
+         model_name_or_dir_ = model_name_or_dir
+         # raise ValueError(f"Providing a model id is not supported yet.")
+         model_archive = sapienzanlp_model_urls(model_name_or_dir_)
+
+     return model_archive
+
+
+ def from_cache(
+     url_or_filename: Union[str, Path],
+     cache_dir: Union[str, Path] = None,
+     force_download: bool = False,
+     resume_download: bool = False,
+     proxies: Optional[Dict[str, str]] = None,
+     use_auth_token: Optional[Union[bool, str]] = None,
+     revision: Optional[str] = None,
+     local_files_only: bool = False,
+     subfolder: str = "",
+     filenames: Optional[List[str]] = None,
+ ) -> Path:
+     """
+     Given something that could be either a local path or a URL (or a SapienzaNLP model id),
+     determine which one and return a path to the corresponding file.
+
+     Args:
+         url_or_filename (:obj:`str` or :obj:`Path`):
+             A path to a local file or a URL (or a SapienzaNLP model id).
+         cache_dir (:obj:`str` or :obj:`Path`, `optional`):
+             Path to a directory in which a downloaded file will be cached.
+         force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
+             Whether or not to re-download the file even if it already exists.
+         resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
+             Whether or not to delete incompletely received files. Attempts to resume the download if such a file
+             exists.
+         proxies (:obj:`Dict[str, str]`, `optional`):
+             A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
+             'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+         use_auth_token (:obj:`Union[bool, str]`, `optional`):
+             Optional string or boolean to use as Bearer token for remote files. If :obj:`True`, will get token from
+             :obj:`~transformers.hf_api.HfApi`. If :obj:`str`, will use that string as token.
+         revision (:obj:`str`, `optional`):
+             The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+             git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
+             identifier allowed by git.
+         local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
+             Whether or not to raise an error if the file to be downloaded is local.
+         subfolder (:obj:`str`, `optional`):
+             In case the relevant file is in a subfolder of the URL, specify it here.
+         filenames (:obj:`List[str]`, `optional`):
+             List of filenames to look for in the directory structure.
+
+     Returns:
+         :obj:`Path`: Path to the cached file.
+     """
+
+     url_or_filename = model_name_or_path_resolver(url_or_filename)
+
+     if cache_dir is None:
+         cache_dir = SAPIENZANLP_CACHE_DIR
+
+     if file_exists(url_or_filename):
+         logger.info(f"{url_or_filename} is a local path or file")
+         output_path = url_or_filename
+     elif is_remote_url(url_or_filename):
+         # URL, so get it from the cache (downloading if necessary)
+         output_path = download_and_cache(
+             url_or_filename,
+             cache_dir=cache_dir,
+             force_download=force_download,
+         )
+     else:
+         if filenames is None:
+             filenames = [WEIGHTS_NAME, CONFIG_NAME, LABELS_NAME]
+         output_path = download_from_hf(
+             url_or_filename,
+             filenames,
+             cache_dir,
+             force_download,
+             resume_download,
+             proxies,
+             use_auth_token,
+             revision,
+             local_files_only,
+             subfolder,
+         )
+
+     # if is_hf_hub_url(url_or_filename):
+     #     # HuggingFace Hub
+     #     output_path = hf_hub_download_url(url_or_filename)
+     # elif is_remote_url(url_or_filename):
+     #     # URL, so get it from the cache (downloading if necessary)
+     #     output_path = download_and_cache(
+     #         url_or_filename,
+     #         cache_dir=cache_dir,
+     #         force_download=force_download,
+     #     )
+     # elif file_exists(url_or_filename):
+     #     logger.info(f"{url_or_filename} is a local path or file")
+     #     # File, and it exists.
+     #     output_path = url_or_filename
+     # elif urlparse(url_or_filename).scheme == "":
+     #     # File, but it doesn't exist.
+     #     raise EnvironmentError(f"file {url_or_filename} not found")
+     # else:
+     #     # Something unknown
+     #     raise ValueError(
+     #         f"unable to parse {url_or_filename} as a URL or as a local path"
+     #     )
+
+     if dir_exists(output_path) or (
+         not is_zipfile(output_path) and not tarfile.is_tarfile(output_path)
+     ):
+         return Path(output_path)
+
+     # Path where we extract compressed archives
+     # for now it will extract it in the same folder
+     # maybe implement extraction in the sapienzanlp folder
+     # when using local archive path?
+     logger.info("Extracting compressed archive")
+     output_dir, output_file = os.path.split(output_path)
+     output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
+     output_path_extracted = os.path.join(output_dir, output_extract_dir_name)
+
+     # already extracted, do not extract
+     if (
+         os.path.isdir(output_path_extracted)
+         and os.listdir(output_path_extracted)
+         and not force_download
+     ):
+         return Path(output_path_extracted)
+
+     # Prevent parallel extractions
+     lock_path = output_path + ".lock"
+     with FileLock(lock_path):
+         shutil.rmtree(output_path_extracted, ignore_errors=True)
+         os.makedirs(output_path_extracted)
+         if is_zipfile(output_path):
+             with ZipFile(output_path, "r") as zip_file:
+                 zip_file.extractall(output_path_extracted)
+                 zip_file.close()
+         elif tarfile.is_tarfile(output_path):
+             tar_file = tarfile.open(output_path)
+             tar_file.extractall(output_path_extracted)
+             tar_file.close()
+         else:
+             raise EnvironmentError(
+                 f"Archive format of {output_path} could not be identified"
+             )
+
+         # remove lock file, is it safe?
+         os.remove(lock_path)
+
+     return Path(output_path_extracted)
+
+
+ def is_str_a_path(maybe_path: str) -> bool:
+     """
+     Check if a string is a path.
+
+     Args:
+         maybe_path (`str`): The string to check.
+
+     Returns:
+         `bool`: `True` if the string is a path, `False` otherwise.
+     """
+     # first check if it is a path
+     if Path(maybe_path).exists():
+         return True
+     # check if it is a relative path
+     if Path(os.path.join(os.getcwd(), maybe_path)).exists():
+         return True
+     # otherwise it is not a path
+     return False
+
+
+ def relative_to_absolute_path(path: str) -> os.PathLike:
+     """
+     Convert a relative path to an absolute path.
+
+     Args:
+         path (`str`): The relative path to convert.
+
+     Returns:
+         `os.PathLike`: The absolute path.
+     """
+     if not is_str_a_path(path):
+         raise ValueError(f"{path} is not a path")
+     if Path(path).exists():
+         return Path(path).absolute()
+     if Path(os.path.join(os.getcwd(), path)).exists():
+         return Path(os.path.join(os.getcwd(), path)).absolute()
+     raise ValueError(f"{path} is not a path")
+
+
+ def to_config(object_to_save: Any) -> Dict[str, Any]:
+     """
+     Convert an object to a dictionary.
+
+     Returns:
+         `Dict[str, Any]`: The dictionary representation of the object.
+     """
+
+     def obj_to_dict(obj):
+         match obj:
+             case dict():
+                 data = {}
+                 for k, v in obj.items():
+                     data[k] = obj_to_dict(v)
+                 return data
+
+             case list() | tuple():
+                 return [obj_to_dict(x) for x in obj]
+
+             case object(__dict__=_):
+                 data = {
+                     "_target_": f"{obj.__class__.__module__}.{obj.__class__.__name__}",
+                 }
+                 for k, v in obj.__dict__.items():
+                     if not k.startswith("_"):
+                         data[k] = obj_to_dict(v)
+                 return data
+
+             case _:
+                 return obj
+
+     return obj_to_dict(object_to_save)
+
+
+ def get_callable_from_string(callable_fn: str) -> Any:
+     """
+     Get a callable from a string.
+
+     Args:
+         callable_fn (`str`):
+             The string representation of the callable.
+
+     Returns:
+         `Any`: The callable.
+     """
+     # separate the function name from the module name
+     module_name, function_name = callable_fn.rsplit(".", 1)
+     # import the module
+     module = importlib.import_module(module_name)
+     # get the function
+     return getattr(module, function_name)
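
A small usage sketch for the string-to-callable helper at the bottom of the module, which is how `app.py` passes `candidates_preprocessing_fn` by dotted path (the target below is just an illustrative stand-in):

    import os

    from relik.common.utils import get_callable_from_string

    fn = get_callable_from_string("os.path.join")  # resolved via importlib
    assert fn("a", "b") == os.path.join("a", "b")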
relik/inference/__init__.py ADDED
File without changes
relik/inference/annotator.py ADDED
@@ -0,0 +1,422 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from typing import Any, Callable, Dict, Optional, Union
4
+
5
+ import hydra
6
+ from omegaconf import OmegaConf
7
+ from relik.retriever.pytorch_modules.hf import GoldenRetrieverModel
8
+ from rich.pretty import pprint
9
+
10
+ from relik.common.log import get_console_logger, get_logger
11
+ from relik.common.upload import upload
12
+ from relik.common.utils import CONFIG_NAME, from_cache, get_callable_from_string
13
+ from relik.inference.data.objects import EntitySpan, RelikOutput
14
+ from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer
15
+ from relik.inference.data.window.manager import WindowManager
16
+ from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction
17
+ from relik.reader.relik_reader import RelikReader
18
+ from relik.retriever.data.utils import batch_generator
19
+ from relik.retriever.indexers.base import BaseDocumentIndex
20
+ from relik.retriever.pytorch_modules.model import GoldenRetriever
21
+
22
+ logger = get_logger(__name__)
23
+ console_logger = get_console_logger()
24
+
25
+
26
+ class Relik:
27
+ """
28
+ Relik main class. It is a wrapper around a retriever and a reader.
29
+
30
+ Args:
31
+ retriever (`Optional[GoldenRetriever]`, `optional`):
32
+ The retriever to use. If `None`, a retriever will be instantiated from the
33
+ provided `question_encoder`, `passage_encoder` and `document_index`.
34
+ Defaults to `None`.
35
+ question_encoder (`Optional[Union[str, GoldenRetrieverModel]]`, `optional`):
36
+ The question encoder to use. If `retriever` is `None`, a retriever will be
37
+ instantiated from this parameter. Defaults to `None`.
38
+ passage_encoder (`Optional[Union[str, GoldenRetrieverModel]]`, `optional`):
39
+ The passage encoder to use. If `retriever` is `None`, a retriever will be
40
+ instantiated from this parameter. Defaults to `None`.
41
+ document_index (`Optional[Union[str, BaseDocumentIndex]]`, `optional`):
42
+ The document index to use. If `retriever` is `None`, a retriever will be
43
+ instantiated from this parameter. Defaults to `None`.
44
+ reader (`Optional[Union[str, RelikReader]]`, `optional`):
45
+ The reader to use. If `None`, a reader will be instantiated from the
46
+ provided `reader`. Defaults to `None`.
47
+ retriever_device (`str`, `optional`, defaults to `cpu`):
48
+ The device to use for the retriever.
49
+
50
+ """
51
+
52
+ def __init__(
53
+ self,
54
+ retriever: GoldenRetriever | None = None,
55
+ question_encoder: str | GoldenRetrieverModel | None = None,
56
+ passage_encoder: str | GoldenRetrieverModel | None = None,
57
+ document_index: str | BaseDocumentIndex | None = None,
58
+ reader: str | RelikReader | None = None,
59
+ device: str = "cpu",
60
+ retriever_device: str | None = None,
61
+ document_index_device: str | None = None,
62
+ reader_device: str | None = None,
63
+ precision: int = 32,
64
+ retriever_precision: int | None = None,
65
+ document_index_precision: int | None = None,
66
+ reader_precision: int | None = None,
67
+ reader_kwargs: dict | None = None,
68
+ retriever_kwargs: dict | None = None,
69
+ candidates_preprocessing_fn: str | Callable | None = None,
70
+ top_k: int | None = None,
71
+ window_size: int | None = None,
72
+ window_stride: int | None = None,
73
+ **kwargs,
74
+ ) -> None:
75
+ # retriever
76
+ retriever_device = retriever_device or device
77
+ document_index_device = document_index_device or device
78
+ retriever_precision = retriever_precision or precision
79
+ document_index_precision = document_index_precision or precision
80
+ if retriever is None and question_encoder is None:
81
+ raise ValueError(
82
+ "Either `retriever` or `question_encoder` must be provided"
83
+ )
84
+ if retriever is None:
85
+ self.retriever_kwargs = dict(
86
+ question_encoder=question_encoder,
87
+ passage_encoder=passage_encoder,
88
+ document_index=document_index,
89
+ device=retriever_device,
90
+ precision=retriever_precision,
91
+ index_device=document_index_device,
92
+ index_precision=document_index_precision,
93
+ )
94
+ # overwrite default_retriever_kwargs with retriever_kwargs
95
+ self.retriever_kwargs.update(retriever_kwargs or {})
96
+ retriever = GoldenRetriever(**self.retriever_kwargs)
97
+ retriever.training = False
98
+ retriever.eval()
99
+ self.retriever = retriever
100
+
101
+ # reader
102
+ self.reader_device = reader_device or device
103
+ self.reader_precision = reader_precision or precision
104
+ self.reader_kwargs = reader_kwargs
105
+ if isinstance(reader, str):
106
+ reader_kwargs = reader_kwargs or {}
107
+ reader = RelikReaderForSpanExtraction(reader, **reader_kwargs)
108
+ self.reader = reader
109
+
110
+ # windowization stuff
111
+ self.tokenizer = SpacyTokenizer(language="en")
112
+ self.window_manager: WindowManager | None = None
113
+
114
+ # candidates preprocessing
115
+ # TODO: maybe move this logic somewhere else
116
+ candidates_preprocessing_fn = candidates_preprocessing_fn or (lambda x: x)
117
+ if isinstance(candidates_preprocessing_fn, str):
118
+ candidates_preprocessing_fn = get_callable_from_string(
119
+ candidates_preprocessing_fn
120
+ )
121
+ self.candidates_preprocessing_fn = candidates_preprocessing_fn
122
+
123
+ # inference params
124
+ self.top_k = top_k
125
+ self.window_size = window_size
126
+ self.window_stride = window_stride
127
+
128
+ def __call__(
129
+ self,
130
+ text: Union[str, list],
131
+ top_k: Optional[int] = None,
132
+ window_size: Optional[int] = None,
133
+ window_stride: Optional[int] = None,
134
+ retriever_batch_size: Optional[int] = 32,
135
+ reader_batch_size: Optional[int] = 32,
136
+ return_also_windows: bool = False,
137
+ **kwargs,
138
+ ) -> Union[RelikOutput, list[RelikOutput]]:
139
+ """
140
+ Annotate a text with entities.
141
+
142
+ Args:
143
+ text (`str` or `list`):
144
+ The text to annotate. If a list is provided, each element of the list
145
+ will be annotated separately.
146
+ top_k (`int`, `optional`, defaults to `None`):
147
+ The number of candidates to retrieve for each window.
148
+ window_size (`int`, `optional`, defaults to `None`):
149
+ The size of the window. If `None`, the whole text will be annotated.
150
+ window_stride (`int`, `optional`, defaults to `None`):
151
+ The stride of the window. If `None`, there will be no overlap between windows.
152
+ retriever_batch_size (`int`, `optional`, defaults to `32`):
153
+ The batch size to use for the retriever. If `None`, the whole input is processed as a single batch.
154
+ reader_batch_size (`int`, `optional`, defaults to `32`):
155
+ The batch size to use for the reader. If `None`, the whole input is processed as a single batch.
156
+ return_also_windows (`bool`, `optional`, defaults to `False`):
157
+ Whether to return the windows in the output.
158
+ **kwargs:
159
+ Additional keyword arguments to pass to the retriever and the reader.
160
+
161
+ Returns:
162
+ `RelikOutput` or `list[RelikOutput]`:
163
+ The annotated text. If a list was provided as input, a list of
164
+ `RelikOutput` objects will be returned.
165
+ """
166
+ if top_k is None:
167
+ top_k = self.top_k or 100
168
+ if window_size is None:
169
+ window_size = self.window_size
170
+ if window_stride is None:
171
+ window_stride = self.window_stride
172
+
173
+ if isinstance(text, str):
174
+ text = [text]
175
+
176
+ if window_size is not None:
177
+ if self.window_manager is None:
178
+ self.window_manager = WindowManager(self.tokenizer)
179
+
180
+ if window_size == "sentence":
181
+ # todo: implement sentence windowizer
182
+ raise NotImplementedError("Sentence windowizer not implemented yet")
183
+
184
+ # if window_size < window_stride:
185
+ # raise ValueError(
186
+ # f"Window size ({window_size}) must be greater than window stride ({window_stride})"
187
+ # )
188
+
189
+ # window generator
190
+ windows = [
191
+ window
192
+ for doc_id, t in enumerate(text)
193
+ for window in self.window_manager.create_windows(
194
+ t,
195
+ window_size=window_size,
196
+ stride=window_stride,
197
+ doc_id=doc_id,
198
+ )
199
+ ]
200
+
201
+ # retrieve candidates first
202
+ windows_candidates = []
203
+ # TODO: Move batching inside retriever
204
+ for batch in batch_generator(windows, batch_size=retriever_batch_size):
205
+ retriever_out = self.retriever.retrieve([b.text for b in batch], k=top_k)
206
+ windows_candidates.extend(
207
+ [[p.label for p in predictions] for predictions in retriever_out]
208
+ )
209
+
210
+ # add passage to the windows
211
+ for window, candidates in zip(windows, windows_candidates):
212
+ window.window_candidates = [
213
+ self.candidates_preprocessing_fn(c) for c in candidates
214
+ ]
215
+
216
+ windows = self.reader.read(samples=windows, max_batch_size=reader_batch_size)
217
+ windows = self.window_manager.merge_windows(windows)
218
+
219
+ # transform predictions into RelikOutput objects
220
+ output = []
221
+ for w in windows:
222
+ sample_output = RelikOutput(
223
+ text=text[w.doc_id],
224
+ labels=sorted(
225
+ [
226
+ EntitySpan(
227
+ start=ss, end=se, label=sl, text=text[w.doc_id][ss:se]
228
+ )
229
+ for ss, se, sl in w.predicted_window_labels_chars
230
+ ],
231
+ key=lambda x: x.start,
232
+ ),
233
+ )
234
+ output.append(sample_output)
235
+
236
+ if return_also_windows:
237
+ for i, sample_output in enumerate(output):
238
+ sample_output.windows = [w for w in windows if w.doc_id == i]
239
+
240
+ # if only one text was provided, return a single RelikOutput object
241
+ if len(output) == 1:
242
+ return output[0]
243
+
244
+ return output
245
+
246
+ @classmethod
247
+ def from_pretrained(
248
+ cls,
249
+ model_name_or_dir: Union[str, os.PathLike],
250
+ config_kwargs: Optional[Dict] = None,
251
+ config_file_name: str = CONFIG_NAME,
252
+ *args,
253
+ **kwargs,
254
+ ) -> "Relik":
255
+ cache_dir = kwargs.pop("cache_dir", None)
256
+ force_download = kwargs.pop("force_download", False)
257
+
258
+ model_dir = from_cache(
259
+ model_name_or_dir,
260
+ filenames=[config_file_name],
261
+ cache_dir=cache_dir,
262
+ force_download=force_download,
263
+ )
264
+
265
+ config_path = model_dir / config_file_name
266
+ if not config_path.exists():
267
+ raise FileNotFoundError(
268
+ f"Model configuration file not found at {config_path}."
269
+ )
270
+
271
+ # overwrite config with config_kwargs
272
+ config = OmegaConf.load(config_path)
273
+ if config_kwargs is not None:
274
+ # TODO: check merging behavior
275
+ config = OmegaConf.merge(config, OmegaConf.create(config_kwargs))
276
+ # do we want to print the config? I like it
277
+ pprint(OmegaConf.to_container(config), console=console_logger, expand_all=True)
278
+
279
+ # load relik from config
280
+ relik = hydra.utils.instantiate(config, *args, **kwargs)
281
+
282
+ return relik
283
+
284
+ def save_pretrained(
285
+ self,
286
+ output_dir: Union[str, os.PathLike],
287
+ config: Optional[Dict[str, Any]] = None,
288
+ config_file_name: Optional[str] = None,
289
+ save_weights: bool = False,
290
+ push_to_hub: bool = False,
291
+ model_id: Optional[str] = None,
292
+ organization: Optional[str] = None,
293
+ repo_name: Optional[str] = None,
294
+ **kwargs,
295
+ ):
296
+ """
297
+ Save the configuration of Relik to the specified directory as a YAML file.
298
+
299
+ Args:
300
+ output_dir (`str`):
301
+ The directory to save the configuration file to.
302
+ config (`Optional[Dict[str, Any]]`, `optional`):
303
+ The configuration to save. If `None`, the current configuration will be
304
+ saved. Defaults to `None`.
305
+ config_file_name (`Optional[str]`, `optional`):
306
+ The name of the configuration file. Defaults to `config.yaml`.
307
+ save_weights (`bool`, `optional`):
308
+ Whether to save the weights of the model. Defaults to `False`.
309
+ push_to_hub (`bool`, `optional`):
310
+ Whether to push the saved model to the hub. Defaults to `False`.
311
+ model_id (`Optional[str]`, `optional`):
312
+ The id of the model to push to the hub. If `None`, the name of the
313
+ directory will be used. Defaults to `None`.
314
+ organization (`Optional[str]`, `optional`):
315
+ The organization to push the model to. Defaults to `None`.
316
+ repo_name (`Optional[str]`, `optional`):
317
+ The name of the repository to push the model to. Defaults to `None`.
318
+ **kwargs:
319
+ Additional keyword arguments to pass to `OmegaConf.save`.
320
+ """
321
+ if config is None:
322
+ # create a default config
323
+ config = {
324
+ "_target_": f"{self.__class__.__module__}.{self.__class__.__name__}"
325
+ }
326
+ if self.retriever is not None:
327
+ if self.retriever.question_encoder is not None:
328
+ config[
329
+ "question_encoder"
330
+ ] = self.retriever.question_encoder.name_or_path
331
+ if self.retriever.passage_encoder is not None:
332
+ config[
333
+ "passage_encoder"
334
+ ] = self.retriever.passage_encoder.name_or_path
335
+ if self.retriever.document_index is not None:
336
+ config["document_index"] = self.retriever.document_index.name_or_dir
337
+ if self.reader is not None:
338
+ config["reader"] = self.reader.model_path
339
+
340
+ config["retriever_kwargs"] = self.retriever_kwargs
341
+ config["reader_kwargs"] = self.reader_kwargs
342
+ # expand the fn as to be able to save it and load it later
343
+ config[
344
+ "candidates_preprocessing_fn"
345
+ ] = f"{self.candidates_preprocessing_fn.__module__}.{self.candidates_preprocessing_fn.__name__}"
346
+
347
+ # these are model-specific and should be saved
348
+ config["top_k"] = self.top_k
349
+ config["window_size"] = self.window_size
350
+ config["window_stride"] = self.window_stride
351
+
352
+ config_file_name = config_file_name or CONFIG_NAME
353
+
354
+ # create the output directory
355
+ output_dir = Path(output_dir)
356
+ output_dir.mkdir(parents=True, exist_ok=True)
357
+
358
+ logger.info(f"Saving relik config to {output_dir / config_file_name}")
359
+ # pretty print the config
360
+ pprint(config, console=console_logger, expand_all=True)
361
+ OmegaConf.save(config, output_dir / config_file_name)
362
+
363
+ if save_weights:
364
+ model_id = model_id or output_dir.name
365
+ retriever_model_id = model_id + "-retriever"
366
+ # save weights
367
+ logger.info(f"Saving retriever to {output_dir / retriever_model_id}")
368
+ self.retriever.save_pretrained(
369
+ output_dir / retriever_model_id,
370
+ question_encoder_name=retriever_model_id + "-question-encoder",
371
+ passage_encoder_name=retriever_model_id + "-passage-encoder",
372
+ document_index_name=retriever_model_id + "-index",
373
+ push_to_hub=push_to_hub,
374
+ organization=organization,
375
+ repo_name=repo_name,
376
+ **kwargs,
377
+ )
378
+ reader_model_id = model_id + "-reader"
379
+ logger.info(f"Saving reader to {output_dir / reader_model_id}")
380
+ self.reader.save_pretrained(
381
+ output_dir / reader_model_id,
382
+ push_to_hub=push_to_hub,
383
+ organization=organization,
384
+ repo_name=repo_name,
385
+ **kwargs,
386
+ )
387
+
388
+ if push_to_hub:
389
+ # push to hub
390
+ logger.info(f"Pushing to hub")
391
+ model_id = model_id or output_dir.name
392
+ upload(output_dir, model_id, organization=organization, repo_name=repo_name)
393
+
394
+
395
+ def main():
396
+ from pprint import pprint
397
+
398
+ relik = Relik(
399
+ question_encoder="riccorl/relik-retriever-aida-blink-pretrain-omniencoder",
400
+ document_index="riccorl/index-relik-retriever-aida-blink-pretrain-omniencoder",
401
+ reader="riccorl/relik-reader-aida-deberta-small",
402
+ device="cuda",
403
+ precision=16,
404
+ top_k=100,
405
+ window_size=32,
406
+ window_stride=16,
407
+ candidates_preprocessing_fn="relik.inference.preprocessing.wikipedia_title_and_openings_preprocessing",
408
+ )
409
+
410
+ input_text = """
411
+ Bernie Ecclestone, the former boss of Formula One, has admitted fraud after failing to declare more than £400m held in a trust in Singapore.
412
+ The 92-year-old billionaire did not disclose the trust to the government in July 2015.
413
+ Appearing at Southwark Crown Court on Thursday, he told the judge "I plead guilty" after having previously pleaded not guilty.
414
+ Ecclestone had been due to go on trial next month.
415
+ """
416
+
417
+ preds = relik(input_text)
418
+ pprint(preds)
419
+
420
+
421
+ if __name__ == "__main__":
422
+ main()
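Usage sketch (not part of the commit): a `Relik` pipeline can be round-tripped through `save_pretrained` / `from_pretrained`. The model IDs are the same ones used in `main()` above; the output directory name is hypothetical.

    from relik.inference.annotator import Relik

    relik = Relik(
        question_encoder="riccorl/relik-retriever-aida-blink-pretrain-omniencoder",
        document_index="riccorl/index-relik-retriever-aida-blink-pretrain-omniencoder",
        reader="riccorl/relik-reader-aida-deberta-small",
        top_k=100,
        window_size=32,
        window_stride=16,
    )
    relik.save_pretrained("relik-aida-small")          # writes config.yaml
    relik = Relik.from_pretrained("relik-aida-small")  # rebuilds from the saved config
    preds = relik("Bernie Ecclestone was the boss of Formula One.")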
relik/inference/data/__init__.py ADDED
File without changes
relik/inference/data/objects.py ADDED
@@ -0,0 +1,64 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import List, NamedTuple, Optional
5
+
6
+ from relik.reader.pytorch_modules.hf.modeling_relik import RelikReaderSample
7
+
8
+
9
+ @dataclass
10
+ class Word:
11
+ """
12
+ A word representation that includes text, index in the sentence, POS tag, lemma,
13
+ dependency relation, and similar information.
14
+
15
+ # Parameters
16
+ text : `str`, optional
17
+ The text representation.
18
+ index : `int`, optional
19
+ The word offset in the sentence.
20
+ lemma : `str`, optional
21
+ The lemma of this word.
22
+ pos : `str`, optional
23
+ The coarse-grained part of speech of this word.
24
+ dep : `str`, optional
25
+ The dependency relation for this word.
26
+
27
+ start_char : `int`, optional
28
+ The start character offset of this word in the original text.
29
+ end_char : `int`, optional
30
+ The end character offset of this word in the original text.
31
+ head : `int`, optional
32
+ The index of the syntactic head of this word, set when dependency
33
+ parsing is enabled.
34
+ """
35
+
36
+ text: str
37
+ index: int
38
+ start_char: Optional[int] = None
39
+ end_char: Optional[int] = None
40
+ # preprocessing fields
41
+ lemma: Optional[str] = None
42
+ pos: Optional[str] = None
43
+ dep: Optional[str] = None
44
+ head: Optional[int] = None
45
+
46
+ def __str__(self):
47
+ return self.text
48
+
49
+ def __repr__(self):
50
+ return self.__str__()
51
+
52
+
53
+ class EntitySpan(NamedTuple):
54
+ start: int
55
+ end: int
56
+ label: str
57
+ text: str
58
+
59
+
60
+ @dataclass
61
+ class RelikOutput:
62
+ text: str
63
+ labels: List[EntitySpan]
64
+ windows: Optional[List[RelikReaderSample]] = None
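An illustrative sketch of how these containers fit together (values are made up):

    from relik.inference.data.objects import EntitySpan, RelikOutput

    text = "Michael Jordan played for the Chicago Bulls."
    output = RelikOutput(
        text=text,
        labels=[
            EntitySpan(start=0, end=14, label="Michael Jordan", text=text[0:14]),
            EntitySpan(start=30, end=43, label="Chicago Bulls", text=text[30:43]),
        ],
    )  # `windows` stays None unless `return_also_windows=True` was requested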
relik/inference/data/tokenizers/__init__.py ADDED
@@ -0,0 +1,89 @@
1
+ SPACY_LANGUAGE_MAPPER = {
2
+ "ca": "ca_core_news_sm",
3
+ "da": "da_core_news_sm",
4
+ "de": "de_core_news_sm",
5
+ "el": "el_core_news_sm",
6
+ "en": "en_core_web_sm",
7
+ "es": "es_core_news_sm",
8
+ "fr": "fr_core_news_sm",
9
+ "it": "it_core_news_sm",
10
+ "ja": "ja_core_news_sm",
11
+ "lt": "lt_core_news_sm",
12
+ "mk": "mk_core_news_sm",
13
+ "nb": "nb_core_news_sm",
14
+ "nl": "nl_core_news_sm",
15
+ "pl": "pl_core_news_sm",
16
+ "pt": "pt_core_news_sm",
17
+ "ro": "ro_core_news_sm",
18
+ "ru": "ru_core_news_sm",
19
+ "xx": "xx_sent_ud_sm",
20
+ "zh": "zh_core_web_sm",
21
+ "ca_core_news_sm": "ca_core_news_sm",
22
+ "ca_core_news_md": "ca_core_news_md",
23
+ "ca_core_news_lg": "ca_core_news_lg",
24
+ "ca_core_news_trf": "ca_core_news_trf",
25
+ "da_core_news_sm": "da_core_news_sm",
26
+ "da_core_news_md": "da_core_news_md",
27
+ "da_core_news_lg": "da_core_news_lg",
28
+ "da_core_news_trf": "da_core_news_trf",
29
+ "de_core_news_sm": "de_core_news_sm",
30
+ "de_core_news_md": "de_core_news_md",
31
+ "de_core_news_lg": "de_core_news_lg",
32
+ "de_dep_news_trf": "de_dep_news_trf",
33
+ "el_core_news_sm": "el_core_news_sm",
34
+ "el_core_news_md": "el_core_news_md",
35
+ "el_core_news_lg": "el_core_news_lg",
36
+ "en_core_web_sm": "en_core_web_sm",
37
+ "en_core_web_md": "en_core_web_md",
38
+ "en_core_web_lg": "en_core_web_lg",
39
+ "en_core_web_trf": "en_core_web_trf",
40
+ "es_core_news_sm": "es_core_news_sm",
41
+ "es_core_news_md": "es_core_news_md",
42
+ "es_core_news_lg": "es_core_news_lg",
43
+ "es_dep_news_trf": "es_dep_news_trf",
44
+ "fr_core_news_sm": "fr_core_news_sm",
45
+ "fr_core_news_md": "fr_core_news_md",
46
+ "fr_core_news_lg": "fr_core_news_lg",
47
+ "fr_dep_news_trf": "fr_dep_news_trf",
48
+ "it_core_news_sm": "it_core_news_sm",
49
+ "it_core_news_md": "it_core_news_md",
50
+ "it_core_news_lg": "it_core_news_lg",
51
+ "ja_core_news_sm": "ja_core_news_sm",
52
+ "ja_core_news_md": "ja_core_news_md",
53
+ "ja_core_news_lg": "ja_core_news_lg",
54
+ "ja_dep_news_trf": "ja_dep_news_trf",
55
+ "lt_core_news_sm": "lt_core_news_sm",
56
+ "lt_core_news_md": "lt_core_news_md",
57
+ "lt_core_news_lg": "lt_core_news_lg",
58
+ "mk_core_news_sm": "mk_core_news_sm",
59
+ "mk_core_news_md": "mk_core_news_md",
60
+ "mk_core_news_lg": "mk_core_news_lg",
61
+ "nb_core_news_sm": "nb_core_news_sm",
62
+ "nb_core_news_md": "nb_core_news_md",
63
+ "nb_core_news_lg": "nb_core_news_lg",
64
+ "nl_core_news_sm": "nl_core_news_sm",
65
+ "nl_core_news_md": "nl_core_news_md",
66
+ "nl_core_news_lg": "nl_core_news_lg",
67
+ "pl_core_news_sm": "pl_core_news_sm",
68
+ "pl_core_news_md": "pl_core_news_md",
69
+ "pl_core_news_lg": "pl_core_news_lg",
70
+ "pt_core_news_sm": "pt_core_news_sm",
71
+ "pt_core_news_md": "pt_core_news_md",
72
+ "pt_core_news_lg": "pt_core_news_lg",
73
+ "ro_core_news_sm": "ro_core_news_sm",
74
+ "ro_core_news_md": "ro_core_news_md",
75
+ "ro_core_news_lg": "ro_core_news_lg",
76
+ "ru_core_news_sm": "ru_core_news_sm",
77
+ "ru_core_news_md": "ru_core_news_md",
78
+ "ru_core_news_lg": "ru_core_news_lg",
79
+ "xx_ent_wiki_sm": "xx_ent_wiki_sm",
80
+ "xx_sent_ud_sm": "xx_sent_ud_sm",
81
+ "zh_core_web_sm": "zh_core_web_sm",
82
+ "zh_core_web_md": "zh_core_web_md",
83
+ "zh_core_web_lg": "zh_core_web_lg",
84
+ "zh_core_web_trf": "zh_core_web_trf",
85
+ }
86
+
87
+ from relik.inference.data.tokenizers.regex_tokenizer import RegexTokenizer
88
+ from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer
89
+ from relik.inference.data.tokenizers.whitespace_tokenizer import WhitespaceTokenizer
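The mapper accepts both two-letter language codes and full spaCy pipeline names, so callers can resolve either form:

    from relik.inference.data.tokenizers import SPACY_LANGUAGE_MAPPER

    SPACY_LANGUAGE_MAPPER["en"]               # -> "en_core_web_sm"
    SPACY_LANGUAGE_MAPPER["en_core_web_trf"]  # -> "en_core_web_trf"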
relik/inference/data/tokenizers/base_tokenizer.py ADDED
@@ -0,0 +1,84 @@
1
+ from typing import List, Union
2
+
3
+ from relik.inference.data.objects import Word
4
+
5
+
6
+ class BaseTokenizer:
7
+ """
8
+ A :obj:`Tokenizer` splits strings of text into single words, optionally adds
9
+ POS tags and performs lemmatization.
10
+ """
11
+
12
+ def __call__(
13
+ self,
14
+ texts: Union[str, List[str], List[List[str]]],
15
+ is_split_into_words: bool = False,
16
+ **kwargs
17
+ ) -> List[List[Word]]:
18
+ """
19
+ Tokenize the input into single words.
20
+
21
+ Args:
22
+ texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
23
+ Text to tag. It can be a single string, a batch of strings, or pre-tokenized strings.
24
+ is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`):
25
+ If :obj:`True` and the input is a string, the input is split on spaces.
26
+
27
+ Returns:
28
+ :obj:`List[List[Word]]`: The input text tokenized in single words.
29
+ """
30
+ raise NotImplementedError
31
+
32
+ def tokenize(self, text: str) -> List[Word]:
33
+ """
34
+ Implements splitting words into tokens.
35
+
36
+ Args:
37
+ text (:obj:`str`):
38
+ Text to tokenize.
39
+
40
+ Returns:
41
+ :obj:`List[Word]`: The input text tokenized in single words.
42
+
43
+ """
44
+ raise NotImplementedError
45
+
46
+ def tokenize_batch(self, texts: List[str]) -> List[List[Word]]:
47
+ """
48
+ Implements batch splitting words into tokens.
49
+
50
+ Args:
51
+ texts (:obj:`List[str]`):
52
+ Batch of text to tokenize.
53
+
54
+ Returns:
55
+ :obj:`List[List[Word]]`: The input batch tokenized in single words.
56
+
57
+ """
58
+ return [self.tokenize(text) for text in texts]
59
+
60
+ @staticmethod
61
+ def check_is_batched(
62
+ texts: Union[str, List[str], List[List[str]]], is_split_into_words: bool
63
+ ):
64
+ """
65
+ Check if input is batched or a single sample.
66
+
67
+ Args:
68
+ texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
69
+ Text to check.
70
+ is_split_into_words (:obj:`bool`):
71
+ If :obj:`True` and the input is a string, the input is split on spaces.
72
+
73
+ Returns:
74
+ :obj:`bool`: ``True`` if ``texts`` is batched, ``False`` otherwise.
75
+ """
76
+ return bool(
77
+ (not is_split_into_words and isinstance(texts, (list, tuple)))
78
+ or (
79
+ is_split_into_words
80
+ and isinstance(texts, (list, tuple))
81
+ and texts
82
+ and isinstance(texts[0], (list, tuple))
83
+ )
84
+ )
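A minimal sketch of a concrete subclass (the class name and splitting rule are hypothetical): `tokenize` does the real work, `tokenize_batch` is inherited, and `__call__` can dispatch with the inherited `check_is_batched`, as the concrete tokenizers in this package do.

    from typing import List

    from relik.inference.data.objects import Word
    from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer

    class CommaTokenizer(BaseTokenizer):
        """Toy tokenizer that splits on commas (illustrative only)."""

        def __call__(self, texts, is_split_into_words: bool = False, **kwargs):
            if self.check_is_batched(texts, is_split_into_words):
                return self.tokenize_batch(texts)
            return self.tokenize(texts)

        def tokenize(self, text: str) -> List[Word]:
            words, offset = [], 0
            for i, chunk in enumerate(text.split(",")):
                words.append(Word(chunk.strip(), i, start_char=offset,
                                  end_char=offset + len(chunk)))
                offset += len(chunk) + 1  # skip the consumed comma
            return words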
relik/inference/data/tokenizers/regex_tokenizer.py ADDED
@@ -0,0 +1,73 @@
1
+ import re
2
+ from typing import List, Union
3
+
4
+ from overrides import overrides
5
+
6
+ from relik.inference.data.objects import Word
7
+ from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer
8
+
9
+
10
+ class RegexTokenizer(BaseTokenizer):
11
+ """
12
+ A :obj:`Tokenizer` that splits the text based on a simple regex.
13
+ """
14
+
15
+ def __init__(self):
16
+ super(RegexTokenizer, self).__init__()
17
+ # regex for splitting on spaces and punctuation and new lines
18
+ # self._regex = re.compile(r"\S+|[\[\](),.!?;:\"]|\\n")
19
+ self._regex = re.compile(
20
+ r"\w+|\$[\d\.]+|\S+", re.UNICODE | re.MULTILINE | re.DOTALL
21
+ )
22
+
23
+ def __call__(
24
+ self,
25
+ texts: Union[str, List[str], List[List[str]]],
26
+ is_split_into_words: bool = False,
27
+ **kwargs,
28
+ ) -> List[List[Word]]:
29
+ """
30
+ Tokenize the input into single words by splitting using a simple regex.
31
+
32
+ Args:
33
+ texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
34
+ Text to tag. It can be a single string, a batch of strings, or pre-tokenized strings.
35
+ is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`):
36
+ If :obj:`True` and the input is a string, the input is split on spaces.
37
+
38
+ Returns:
39
+ :obj:`List[List[Word]]`: The input text tokenized in single words.
40
+
41
+ Example::
42
+
43
+ >>> from relik.inference.data.tokenizers.regex_tokenizer import RegexTokenizer
44
+
45
+ >>> regex_tokenizer = RegexTokenizer()
46
+ >>> regex_tokenizer("Mary sold the car to John.")
47
+
48
+ """
49
+ # check if input is batched or a single sample
50
+ is_batched = self.check_is_batched(texts, is_split_into_words)
51
+
52
+ if is_batched:
53
+ tokenized = self.tokenize_batch(texts)
54
+ else:
55
+ tokenized = self.tokenize(texts)
56
+
57
+ return tokenized
58
+
59
+ @overrides
60
+ def tokenize(self, text: Union[str, List[str]]) -> List[Word]:
61
+ if not isinstance(text, (str, list)):
62
+ raise ValueError(
63
+ f"text must be either `str` or `list`, found: `{type(text)}`"
64
+ )
65
+
66
+ if isinstance(text, list):
67
+ text = " ".join(text)
68
+ return [
69
+ Word(t[0], i, start_char=t[1], end_char=t[2])
70
+ for i, t in enumerate(
71
+ (m.group(0), m.start(), m.end()) for m in self._regex.finditer(text)
72
+ )
73
+ ]
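Continuing the docstring example, each returned `Word` carries its character offsets (output shown for illustration):

    from relik.inference.data.tokenizers.regex_tokenizer import RegexTokenizer

    tokenizer = RegexTokenizer()
    words = tokenizer("Mary sold the car to John.")
    [(w.text, w.start_char, w.end_char) for w in words]
    # [('Mary', 0, 4), ('sold', 5, 9), ('the', 10, 13), ('car', 14, 17),
    #  ('to', 18, 20), ('John', 21, 25), ('.', 25, 26)]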
relik/inference/data/tokenizers/spacy_tokenizer.py ADDED
@@ -0,0 +1,228 @@
1
+ import logging
2
+ from typing import Dict, List, Tuple, Union
3
+
4
+ import spacy
5
+
6
+ # from ipa.common.utils import load_spacy
7
+ from overrides import overrides
8
+ from spacy.cli.download import download as spacy_download
9
+ from spacy.tokens import Doc
10
+
11
+ from relik.common.log import get_logger
12
+ from relik.inference.data.objects import Word
13
+ from relik.inference.data.tokenizers import SPACY_LANGUAGE_MAPPER
14
+ from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer
15
+
16
+ logger = get_logger(level=logging.DEBUG)
17
+
18
+ # Spacy and Stanza stuff
19
+
20
+ LOADED_SPACY_MODELS: Dict[Tuple[str, bool, bool, bool, bool], spacy.Language] = {}
21
+
22
+
23
+ def load_spacy(
24
+ language: str,
25
+ pos_tags: bool = False,
26
+ lemma: bool = False,
27
+ parse: bool = False,
28
+ split_on_spaces: bool = False,
29
+ ) -> spacy.Language:
30
+ """
31
+ Download and load spacy model.
32
+
33
+ Args:
34
+ language (:obj:`str`, defaults to :obj:`en`):
35
+ Language of the text to tokenize.
36
+ pos_tags (:obj:`bool`, optional, defaults to :obj:`False`):
37
+ If :obj:`True`, performs POS tagging with spacy model.
38
+ lemma (:obj:`bool`, optional, defaults to :obj:`False`):
39
+ If :obj:`True`, performs lemmatization with spacy model.
40
+ parse (:obj:`bool`, optional, defaults to :obj:`False`):
41
+ If :obj:`True`, performs dependency parsing with spacy model.
42
+ split_on_spaces (:obj:`bool`, optional, defaults to :obj:`False`):
43
+ If :obj:`True`, will split by spaces without performing tokenization.
44
+
45
+ Returns:
46
+ :obj:`spacy.Language`: The spacy model loaded.
47
+ """
48
+ exclude = ["vectors", "textcat", "ner"]
49
+ if not pos_tags:
50
+ exclude.append("tagger")
51
+ if not lemma:
52
+ exclude.append("lemmatizer")
53
+ if not parse:
54
+ exclude.append("parser")
55
+
56
+ # check if the model is already loaded
57
+ # if so, there is no need to reload it
58
+ spacy_params = (language, pos_tags, lemma, parse, split_on_spaces)
59
+ if spacy_params not in LOADED_SPACY_MODELS:
60
+ try:
61
+ spacy_tagger = spacy.load(language, exclude=exclude)
62
+ except OSError:
63
+ logger.warning(
64
+ "Spacy model '%s' not found. Downloading and installing.", language
65
+ )
66
+ spacy_download(language)
67
+ spacy_tagger = spacy.load(language, exclude=exclude)
68
+
69
+ # if everything is disabled, return only the tokenizer
70
+ # for faster tokenization
71
+ # TODO: is it really faster?
72
+ # if len(exclude) >= 6:
73
+ # spacy_tagger = spacy_tagger.tokenizer
74
+ LOADED_SPACY_MODELS[spacy_params] = spacy_tagger
75
+
76
+ return LOADED_SPACY_MODELS[spacy_params]
77
+
78
+
79
+ class SpacyTokenizer(BaseTokenizer):
80
+ """
81
+ A :obj:`Tokenizer` that uses SpaCy to tokenize and preprocess the text. It returns :obj:`Word` objects.
82
+
83
+ Args:
84
+ language (:obj:`str`, optional, defaults to :obj:`en`):
85
+ Language of the text to tokenize.
86
+ return_pos_tags (:obj:`bool`, optional, defaults to :obj:`False`):
87
+ If :obj:`True`, performs POS tagging with spacy model.
88
+ return_lemmas (:obj:`bool`, optional, defaults to :obj:`False`):
89
+ If :obj:`True`, performs lemmatization with spacy model.
90
+ return_deps (:obj:`bool`, optional, defaults to :obj:`False`):
91
+ If :obj:`True`, performs dependency parsing with spacy model.
92
+ split_on_spaces (:obj:`bool`, optional, defaults to :obj:`False`):
93
+ If :obj:`True`, will split by spaces without performing tokenization.
94
+ use_gpu (:obj:`bool`, optional, defaults to :obj:`False`):
95
+ If :obj:`True`, will load the SpaCy model on GPU.
96
+ """
97
+
98
+ def __init__(
99
+ self,
100
+ language: str = "en",
101
+ return_pos_tags: bool = False,
102
+ return_lemmas: bool = False,
103
+ return_deps: bool = False,
104
+ split_on_spaces: bool = False,
105
+ use_gpu: bool = False,
106
+ ):
107
+ super(SpacyTokenizer, self).__init__()
108
+ if language not in SPACY_LANGUAGE_MAPPER:
109
+ raise ValueError(
110
+ f"`{language}` language not supported. The supported "
111
+ f"languages are: {list(SPACY_LANGUAGE_MAPPER.keys())}."
112
+ )
113
+ if use_gpu:
114
+ # load the model on GPU
115
+ # if the GPU is not available or not correctly configured,
116
+ # it will raise an error
117
+ spacy.require_gpu()
118
+ self.spacy = load_spacy(
119
+ SPACY_LANGUAGE_MAPPER[language],
120
+ return_pos_tags,
121
+ return_lemmas,
122
+ return_deps,
123
+ split_on_spaces,
124
+ )
125
+ self.split_on_spaces = split_on_spaces
126
+
127
+ def __call__(
128
+ self,
129
+ texts: Union[str, List[str], List[List[str]]],
130
+ is_split_into_words: bool = False,
131
+ **kwargs,
132
+ ) -> Union[List[Word], List[List[Word]]]:
133
+ """
134
+ Tokenize the input into single words using SpaCy models.
135
+
136
+ Args:
137
+ texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
138
+ Text to tag. It can be a single string, a batch of string and pre-tokenized strings.
139
+ is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`):
140
+ If :obj:`True` and the input is a string, the input is split on spaces.
141
+
142
+ Returns:
143
+ :obj:`List[List[Word]]`: The input text tokenized in single words.
144
+
145
+ Example::
146
+
147
+ >>> from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer
148
+
149
+ >>> spacy_tokenizer = SpacyTokenizer(language="en", pos_tags=True, lemma=True)
150
+ >>> spacy_tokenizer("Mary sold the car to John.")
151
+
152
+ """
153
+ # check if input is batched or a single sample
154
+ is_batched = self.check_is_batched(texts, is_split_into_words)
155
+ if is_batched:
156
+ tokenized = self.tokenize_batch(texts)
157
+ else:
158
+ tokenized = self.tokenize(texts)
159
+ return tokenized
160
+
161
+ @overrides
162
+ def tokenize(self, text: Union[str, List[str]]) -> List[Word]:
163
+ if self.split_on_spaces:
164
+ if isinstance(text, str):
165
+ text = text.split(" ")
166
+ spaces = [True] * len(text)
167
+ text = Doc(self.spacy.vocab, words=text, spaces=spaces)
168
+ return self._clean_tokens(self.spacy(text))
169
+
170
+ @overrides
171
+ def tokenize_batch(
172
+ self, texts: Union[List[str], List[List[str]]]
173
+ ) -> List[List[Word]]:
174
+ if self.split_on_spaces:
175
+ if isinstance(texts[0], str):
176
+ texts = [text.split(" ") for text in texts]
177
+ spaces = [[True] * len(text) for text in texts]
178
+ texts = [
179
+ Doc(self.spacy.vocab, words=text, spaces=space)
180
+ for text, space in zip(texts, spaces)
181
+ ]
182
+ return [self._clean_tokens(tokens) for tokens in self.spacy.pipe(texts)]
183
+
184
+ @staticmethod
185
+ def _clean_tokens(tokens: Doc) -> List[Word]:
186
+ """
187
+ Converts spaCy tokens to :obj:`Word`.
188
+
189
+ Args:
190
+ tokens (:obj:`spacy.tokens.Doc`):
191
+ Tokens from SpaCy model.
192
+
193
+ Returns:
194
+ :obj:`List[Word]`: The SpaCy model output converted into :obj:`Word` objects.
195
+ """
196
+ words = [
197
+ Word(
198
+ token.text,
199
+ token.i,
200
+ token.idx,
201
+ token.idx + len(token),
202
+ token.lemma_,
203
+ token.pos_,
204
+ token.dep_,
205
+ token.head.i,
206
+ )
207
+ for token in tokens
208
+ ]
209
+ return words
210
+
211
+
212
+ class WhitespaceSpacyTokenizer:
213
+ """Simple white space tokenizer for SpaCy."""
214
+
215
+ def __init__(self, vocab):
216
+ self.vocab = vocab
217
+
218
+ def __call__(self, text):
219
+ if isinstance(text, str):
220
+ words = text.split(" ")
221
+ elif isinstance(text, list):
222
+ words = text
223
+ else:
224
+ raise ValueError(
225
+ f"text must be either `str` or `list`, found: `{type(text)}`"
226
+ )
227
+ spaces = [True] * len(words)
228
+ return Doc(self.vocab, words=words, spaces=spaces)
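A sketch of plugging `WhitespaceSpacyTokenizer` into a spaCy pipeline so that tagging runs on pre-split tokens (assumes `en_core_web_sm` is installed):

    import spacy

    from relik.inference.data.tokenizers.spacy_tokenizer import WhitespaceSpacyTokenizer

    nlp = spacy.load("en_core_web_sm")
    nlp.tokenizer = WhitespaceSpacyTokenizer(nlp.vocab)
    doc = nlp("Mary sold the car to John .")  # one token per space-separated chunk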
relik/inference/data/tokenizers/whitespace_tokenizer.py ADDED
@@ -0,0 +1,70 @@
1
+ import re
2
+ from typing import List, Union
3
+
4
+ from overrides import overrides
5
+
6
+ from relik.inference.data.objects import Word
7
+ from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer
8
+
9
+
10
+ class WhitespaceTokenizer(BaseTokenizer):
11
+ """
12
+ A :obj:`Tokenizer` that splits the text on spaces.
13
+ """
14
+
15
+ def __init__(self):
16
+ super(WhitespaceTokenizer, self).__init__()
17
+ self.whitespace_regex = re.compile(r"\S+")
18
+
19
+ def __call__(
20
+ self,
21
+ texts: Union[str, List[str], List[List[str]]],
22
+ is_split_into_words: bool = False,
23
+ **kwargs,
24
+ ) -> List[List[Word]]:
25
+ """
26
+ Tokenize the input into single words by splitting on spaces.
27
+
28
+ Args:
29
+ texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
30
+ Text to tag. It can be a single string, a batch of strings, or pre-tokenized strings.
31
+ is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`):
32
+ If :obj:`True` and the input is a string, the input is split on spaces.
33
+
34
+ Returns:
35
+ :obj:`List[List[Word]]`: The input text tokenized in single words.
36
+
37
+ Example::
38
+
39
+ >>> from relik.inference.data.tokenizers.whitespace_tokenizer import WhitespaceTokenizer
40
+
41
+ >>> whitespace_tokenizer = WhitespaceTokenizer()
42
+ >>> whitespace_tokenizer("Mary sold the car to John .")
43
+
44
+ """
45
+ # check if input is batched or a single sample
46
+ is_batched = self.check_is_batched(texts, is_split_into_words)
47
+
48
+ if is_batched:
49
+ tokenized = self.tokenize_batch(texts)
50
+ else:
51
+ tokenized = self.tokenize(texts)
52
+
53
+ return tokenized
54
+
55
+ @overrides
56
+ def tokenize(self, text: Union[str, List[str]]) -> List[Word]:
57
+ if not isinstance(text, (str, list)):
58
+ raise ValueError(
59
+ f"text must be either `str` or `list`, found: `{type(text)}`"
60
+ )
61
+
62
+ if isinstance(text, list):
63
+ text = " ".join(text)
64
+ return [
65
+ Word(t[0], i, start_char=t[1], end_char=t[2])
66
+ for i, t in enumerate(
67
+ (m.group(0), m.start(), m.end())
68
+ for m in self.whitespace_regex.finditer(text)
69
+ )
70
+ ]
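A short sketch of the batched path: passing a list of strings routes through `tokenize_batch` and returns one `Word` list per input.

    from relik.inference.data.tokenizers.whitespace_tokenizer import WhitespaceTokenizer

    tokenizer = WhitespaceTokenizer()
    batch = tokenizer(["Mary sold the car", "to John ."])
    assert len(batch) == 2 and batch[0][0].text == "Mary"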
relik/inference/data/window/__init__.py ADDED
File without changes
relik/inference/data/window/manager.py ADDED
@@ -0,0 +1,262 @@
1
+ import collections
2
+ import itertools
3
+ from dataclasses import dataclass
4
+ from typing import List, Optional, Set, Tuple
5
+
6
+ from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer
7
+ from relik.reader.data.relik_reader_sample import RelikReaderSample
8
+
9
+
10
+ @dataclass
11
+ class Window:
12
+ doc_id: int
13
+ window_id: int
14
+ text: str
15
+ tokens: List[str]
16
+ doc_topic: Optional[str]
17
+ offset: int
18
+ token2char_start: dict
19
+ token2char_end: dict
20
+ window_candidates: Optional[List[str]] = None
21
+
22
+
23
+ class WindowManager:
24
+ def __init__(self, tokenizer: BaseTokenizer) -> None:
25
+ self.tokenizer = tokenizer
26
+
27
+ def tokenize(self, document: str) -> Tuple[List[str], List[Tuple[int, int]]]:
28
+ tokenized_document = self.tokenizer(document)
29
+ tokens = []
30
+ tokens_char_mapping = []
31
+ for token in tokenized_document:
32
+ tokens.append(token.text)
33
+ tokens_char_mapping.append((token.start_char, token.end_char))
34
+ return tokens, tokens_char_mapping
35
+
36
+ def create_windows(
37
+ self,
38
+ document: str,
39
+ window_size: int,
40
+ stride: int,
41
+ doc_id: int = 0,
42
+ doc_topic: str = None,
43
+ ) -> List[RelikReaderSample]:
44
+ document_tokens, tokens_char_mapping = self.tokenize(document)
45
+ if doc_topic is None:
46
+ doc_topic = document_tokens[0] if len(document_tokens) > 0 else ""
47
+ document_windows = []
48
+ if len(document_tokens) <= window_size:
49
+ text = document
50
+ # relik_reader_sample = RelikReaderSample()
51
+ document_windows.append(
52
+ # Window(
53
+ RelikReaderSample(
54
+ doc_id=doc_id,
55
+ window_id=0,
56
+ text=text,
57
+ tokens=document_tokens,
58
+ doc_topic=doc_topic,
59
+ offset=0,
60
+ token2char_start={
61
+ str(i): tokens_char_mapping[i][0]
62
+ for i in range(len(document_tokens))
63
+ },
64
+ token2char_end={
65
+ str(i): tokens_char_mapping[i][1]
66
+ for i in range(len(document_tokens))
67
+ },
68
+ )
69
+ )
70
+ else:
71
+ for window_id, i in enumerate(range(0, len(document_tokens), stride)):
72
+ # if the last stride is smaller than the window size, then we can
73
+ # include more tokens from the previous window.
74
+ if i != 0 and i + window_size > len(document_tokens):
75
+ overflowing_tokens = i + window_size - len(document_tokens)
76
+ if overflowing_tokens >= stride:
77
+ break
78
+ i -= overflowing_tokens
79
+
80
+ involved_token_indices = list(
81
+ range(i, min(i + window_size, len(document_tokens)))  # keep the last document token
82
+ )
83
+ window_tokens = [document_tokens[j] for j in involved_token_indices]
84
+ window_text_start = tokens_char_mapping[involved_token_indices[0]][0]
85
+ window_text_end = tokens_char_mapping[involved_token_indices[-1]][1]
86
+ text = document[window_text_start:window_text_end]
87
+ document_windows.append(
88
+ # Window(
89
+ RelikReaderSample(
90
+ # dict(
91
+ doc_id=doc_id,
92
+ window_id=window_id,
93
+ text=text,
94
+ tokens=window_tokens,
95
+ doc_topic=doc_topic,
96
+ offset=window_text_start,
97
+ token2char_start={
98
+ str(i): tokens_char_mapping[ti][0]
99
+ for i, ti in enumerate(involved_token_indices)
100
+ },
101
+ token2char_end={
102
+ str(i): tokens_char_mapping[ti][1]
103
+ for i, ti in enumerate(involved_token_indices)
104
+ },
105
+ # )
106
+ )
107
+ )
108
+ return document_windows
109
+
110
+ def merge_windows(
111
+ self, windows: List[RelikReaderSample]
112
+ ) -> List[RelikReaderSample]:
113
+ windows_by_doc_id = collections.defaultdict(list)
114
+ for window in windows:
115
+ windows_by_doc_id[window.doc_id].append(window)
116
+
117
+ merged_window_by_doc = {
118
+ doc_id: self.merge_doc_windows(doc_windows)
119
+ for doc_id, doc_windows in windows_by_doc_id.items()
120
+ }
121
+
122
+ return list(merged_window_by_doc.values())
123
+
124
+ def merge_doc_windows(self, windows: List[RelikReaderSample]) -> RelikReaderSample:
125
+ if len(windows) == 1:
126
+ return windows[0]
127
+
128
+ if len(windows) > 0 and getattr(windows[0], "offset", None) is not None:
129
+ windows = sorted(windows, key=(lambda x: x.offset))
130
+
131
+ window_accumulator = windows[0]
132
+
133
+ for next_window in windows[1:]:
134
+ window_accumulator = self._merge_window_pair(
135
+ window_accumulator, next_window
136
+ )
137
+
138
+ return window_accumulator
139
+
140
+ def _merge_tokens(
141
+ self, window1: RelikReaderSample, window2: RelikReaderSample
142
+ ) -> Tuple[list, dict, dict]:
143
+ w1_tokens = window1.tokens[1:-1]
144
+ w2_tokens = window2.tokens[1:-1]
145
+
146
+ # find intersection
147
+ tokens_intersection = None
148
+ for k in reversed(range(1, len(w1_tokens))):
149
+ if w1_tokens[-k:] == w2_tokens[:k]:
150
+ tokens_intersection = k
151
+ break
152
+ assert tokens_intersection is not None, (
153
+ f"{window1.doc_id} - {window1.sent_id} - {window1.offset}"
154
+ + f" {window2.doc_id} - {window2.sent_id} - {window2.offset}\n"
155
+ + f"w1 tokens: {w1_tokens}\n"
156
+ + f"w2 tokens: {w2_tokens}\n"
157
+ )
158
+
159
+ final_tokens = (
160
+ [window1.tokens[0]] # CLS
161
+ + w1_tokens
162
+ + w2_tokens[tokens_intersection:]
163
+ + [window1.tokens[-1]] # SEP
164
+ )
165
+
166
+ w2_starting_offset = len(w1_tokens) - tokens_intersection
167
+
168
+ def merge_char_mapping(t2c1: dict, t2c2: dict) -> dict:
169
+ final_t2c = dict()
170
+ final_t2c.update(t2c1)
171
+ for t, c in t2c2.items():
172
+ t = int(t)
173
+ if t < tokens_intersection:
174
+ continue
175
+ final_t2c[str(t + w2_starting_offset)] = c
176
+ return final_t2c
177
+
178
+ return (
179
+ final_tokens,
180
+ merge_char_mapping(window1.token2char_start, window2.token2char_start),
181
+ merge_char_mapping(window1.token2char_end, window2.token2char_end),
182
+ )
183
+
184
+ def _merge_span_annotation(
185
+ self, span_annotation1: List[list], span_annotation2: List[list]
186
+ ) -> List[list]:
187
+ uniq_store = set()
188
+ final_span_annotation_store = []
189
+ for span_annotation in itertools.chain(span_annotation1, span_annotation2):
190
+ span_annotation_id = tuple(span_annotation)
191
+ if span_annotation_id not in uniq_store:
192
+ uniq_store.add(span_annotation_id)
193
+ final_span_annotation_store.append(span_annotation)
194
+ return sorted(final_span_annotation_store, key=lambda x: x[0])
195
+
196
+ def _merge_predictions(
197
+ self,
198
+ window1: RelikReaderSample,
199
+ window2: RelikReaderSample,
200
+ ) -> Tuple[Set[Tuple[int, int, str]], dict]:
201
+ merged_predictions = window1.predicted_window_labels_chars.union(
202
+ window2.predicted_window_labels_chars
203
+ )
204
+
205
+ span_title_probabilities = dict()
206
+ # probabilities
207
+ for span_prediction, predicted_probs in itertools.chain(
208
+ window1.probs_window_labels_chars.items(),
209
+ window2.probs_window_labels_chars.items(),
210
+ ):
211
+ if span_prediction not in span_title_probabilities:
212
+ span_title_probabilities[span_prediction] = predicted_probs
213
+
214
+ return merged_predictions, span_title_probabilities
215
+
216
+ def _merge_window_pair(
217
+ self,
218
+ window1: RelikReaderSample,
219
+ window2: RelikReaderSample,
220
+ ) -> RelikReaderSample:
221
+ merging_output = dict()
222
+
223
+ if getattr(window1, "doc_id", None) is not None:
224
+ assert window1.doc_id == window2.doc_id
225
+
226
+ if getattr(window1, "offset", None) is not None:
227
+ assert (
228
+ window1.offset < window2.offset
229
+ ), f"window 2 offset ({window2.offset}) is smaller that window 1 offset({window1.offset})"
230
+
231
+ merging_output["doc_id"] = window1.doc_id
232
+ merging_output["offset"] = window2.offset
233
+
234
+ m_tokens, m_token2char_start, m_token2char_end = self._merge_tokens(
235
+ window1, window2
236
+ )
237
+
238
+ window_labels = None
239
+ if getattr(window1, "window_labels", None) is not None:
240
+ window_labels = self._merge_span_annotation(
241
+ window1.window_labels, window2.window_labels
242
+ )
243
+ (
244
+ predicted_window_labels_chars,
245
+ probs_window_labels_chars,
246
+ ) = self._merge_predictions(
247
+ window1,
248
+ window2,
249
+ )
250
+
251
+ merging_output.update(
252
+ dict(
253
+ tokens=m_tokens,
254
+ token2char_start=m_token2char_start,
255
+ token2char_end=m_token2char_end,
256
+ window_labels=window_labels,
257
+ predicted_window_labels_chars=predicted_window_labels_chars,
258
+ probs_window_labels_chars=probs_window_labels_chars,
259
+ )
260
+ )
261
+
262
+ return RelikReaderSample(**merging_output)
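A sketch of the windowing arithmetic (assumes an English spaCy model is available): with `window_size=32` and `stride=16`, consecutive windows overlap by 16 tokens, and `merge_windows` later stitches the per-window predictions back into one sample per document.

    from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer
    from relik.inference.data.window.manager import WindowManager

    manager = WindowManager(SpacyTokenizer(language="en"))
    text = " ".join(f"tok{i}" for i in range(80))  # 80 toy tokens
    windows = manager.create_windows(text, window_size=32, stride=16)
    for w in windows:
        print(w.window_id, w.offset, len(w.tokens))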
relik/inference/gerbil.py ADDED
@@ -0,0 +1,254 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ import re
5
+ import sys
6
+ from http.server import BaseHTTPRequestHandler, HTTPServer
7
+ from typing import Iterator, List, Optional, Tuple
8
+
9
+ from relik.inference.annotator import Relik
10
+ from relik.inference.data.objects import RelikOutput
11
+
12
+ # sys.path += ['../']
13
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
14
+
15
+
16
+ import logging
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ class GerbilAlbyManager:
22
+ def __init__(
23
+ self,
24
+ annotator: Optional[Relik] = None,
25
+ response_logger_dir: Optional[str] = None,
26
+ ) -> None:
27
+ self.annotator = annotator
28
+ self.response_logger_dir = response_logger_dir
29
+ self.predictions_counter = 0
30
+ self.labels_mapping = None
31
+
32
+ def annotate(self, document: str):
33
+ relik_output: RelikOutput = self.annotator(document)
34
+ annotations = [(ss, se, l) for ss, se, l, _ in relik_output.labels]
35
+ if self.labels_mapping is not None:
36
+ return [
37
+ (ss, se, self.labels_mapping.get(l, l)) for ss, se, l in annotations
38
+ ]
39
+ return annotations
40
+
41
+ def set_mapping_file(self, mapping_file_path: str):
42
+ with open(mapping_file_path) as f:
43
+ labels_mapping = json.load(f)
44
+ self.labels_mapping = {v: k for k, v in labels_mapping.items()}
45
+
46
+ def write_response_bundle(
47
+ self,
48
+ document: str,
49
+ new_document: str,
50
+ annotations: list,
51
+ mapped_annotations: list,
52
+ ) -> None:
53
+ if self.response_logger_dir is None:
54
+ return
55
+
56
+ if not os.path.isdir(self.response_logger_dir):
57
+ os.mkdir(self.response_logger_dir)
58
+
59
+ with open(
60
+ f"{self.response_logger_dir}/{self.predictions_counter}.json", "w"
61
+ ) as f:
62
+ out_json_obj = dict(
63
+ document=document,
64
+ new_document=new_document,
65
+ annotations=annotations,
66
+ mapped_annotations=mapped_annotations,
67
+ )
68
+
69
+ out_json_obj["span_annotations"] = [
70
+ (ss, se, document[ss:se], label) for (ss, se, label) in annotations
71
+ ]
72
+
73
+ out_json_obj["span_mapped_annotations"] = [
74
+ (ss, se, new_document[ss:se], label)
75
+ for (ss, se, label) in mapped_annotations
76
+ ]
77
+
78
+ json.dump(out_json_obj, f, indent=2)
79
+
80
+ self.predictions_counter += 1
81
+
82
+
83
+ manager = GerbilAlbyManager()
84
+
85
+
86
+ def preprocess_document(document: str) -> Tuple[str, List[Tuple[int, int]]]:
87
+ pattern_subs = {
88
+ "-LPR- ": " (",
89
+ "-RPR-": ")",
90
+ "\n\n": "\n",
91
+ "-LRB-": "(",
92
+ "-RRB-": ")",
93
+ '","': ",",
94
+ }
95
+
96
+ document_acc = document
97
+ curr_offset = 0
98
+ char2offset = []
99
+
100
+ matchings = re.finditer("({})".format("|".join(pattern_subs)), document)
101
+ for span_matching in sorted(matchings, key=lambda x: x.span()[0]):
102
+ span_start, span_end = span_matching.span()
103
+ span_start -= curr_offset
104
+ span_end -= curr_offset
105
+
106
+ span_text = document_acc[span_start:span_end]
107
+ span_sub = pattern_subs[span_text]
108
+ document_acc = document_acc[:span_start] + span_sub + document_acc[span_end:]
109
+
110
+ offset = len(span_text) - len(span_sub)
111
+ curr_offset += offset
112
+
113
+ char2offset.append((span_start + len(span_sub), curr_offset))
114
+
115
+ return document_acc, char2offset
116
+
117
+
118
+ def map_back_annotations(
119
+ annotations: List[Tuple[int, int, str]], char_mapping: List[Tuple[int, int]]
120
+ ) -> Iterator[Tuple[int, int, str]]:
121
+ def map_char(char_idx: int) -> int:
122
+ current_offset = 0
123
+ for offset_idx, offset_value in char_mapping:
124
+ if char_idx >= offset_idx:
125
+ current_offset = offset_value
126
+ else:
127
+ break
128
+ return char_idx + current_offset
129
+
130
+ for ss, se, label in annotations:
131
+ yield map_char(ss), map_char(se), label
132
+
133
+
134
+ def annotate(document: str) -> List[Tuple[int, int, str]]:
135
+ new_document, mapping = preprocess_document(document)
136
+ logger.info("Mapping: " + str(mapping))
137
+ logger.info("Document: " + str(document))
138
+ annotations = [
139
+ (cs, ce, label.replace(" ", "_"))
140
+ for cs, ce, label in manager.annotate(new_document)
141
+ ]
142
+ logger.info("New document: " + str(new_document))
143
+ mapped_annotations = (
144
+ list(map_back_annotations(annotations, mapping))
145
+ if len(mapping) > 0
146
+ else annotations
147
+ )
148
+
149
+ logger.info(
150
+ "Annotations: "
151
+ + str([(ss, se, document[ss:se], ann) for ss, se, ann in mapped_annotations])
152
+ )
153
+
154
+ manager.write_response_bundle(
155
+ document, new_document, mapped_annotations, annotations
156
+ )
157
+
158
+ if not all(
159
+ [
160
+ new_document[ss:se] == document[mss:mse]
161
+ for (mss, mse, _), (ss, se, _) in zip(mapped_annotations, annotations)
162
+ ]
163
+ ):
164
+ diff_mappings = [
165
+ (new_document[ss:se], document[mss:mse])
166
+ for (mss, mse, _), (ss, se, _) in zip(mapped_annotations, annotations)
167
+ ]
168
+ return None
169
+ assert all(
170
+ [
171
+ document[mss:mse] == new_document[ss:se]
172
+ for (mss, mse, _), (ss, se, _) in zip(mapped_annotations, annotations)
173
+ ]
174
+ ), (mapped_annotations, annotations)
175
+
176
+ return [(cs, ce - cs, label) for cs, ce, label in mapped_annotations]
177
+
178
+
179
+ class GetHandler(BaseHTTPRequestHandler):
180
+ def do_POST(self):
181
+ content_length = int(self.headers["Content-Length"])
182
+ post_data = self.rfile.read(content_length)
183
+ self.send_response(200)
184
+ self.end_headers()
185
+ doc_text = read_json(post_data)
186
+ # try:
187
+ response = annotate(doc_text)
188
+
189
+ self.wfile.write(bytes(json.dumps(response), "utf-8"))
190
+ return
191
+
192
+
193
+ def read_json(post_data):
194
+ data = json.loads(post_data.decode("utf-8"))
195
+ # logger.info("received data:", data)
196
+ text = data["text"]
197
+ # spans = [(int(j["start"]), int(j["length"])) for j in data["spans"]]
198
+ return text
199
+
200
+
201
+ def parse_args() -> argparse.Namespace:
202
+ parser = argparse.ArgumentParser()
203
+ parser.add_argument("--relik-model-name", required=True)
204
+ parser.add_argument("--responses-log-dir")
205
+ parser.add_argument("--log-file", default="logs/logging.txt")
206
+ parser.add_argument("--mapping-file")
207
+ return parser.parse_args()
208
+
209
+
210
+ def main():
211
+ args = parse_args()
212
+
213
+ # init manager
214
+ manager.response_logger_dir = args.responses_log_dir
215
+ # manager.annotator = Relik.from_pretrained(args.relik_model_name)
216
+
217
+ print("Debugging, not using you relik model but an hardcoded one.")
218
+ manager.annotator = Relik(
219
+ question_encoder="riccorl/relik-retriever-aida-blink-pretrain-omniencoder",
220
+ document_index="riccorl/index-relik-retriever-aida-blink-pretrain-omniencoder",
221
+ reader="relik/reader/models/relik-reader-deberta-base-new-data",
222
+ window_size=32,
223
+ window_stride=16,
224
+ candidates_preprocessing_fn=(lambda x: x.split("<def>")[0].strip()),
225
+ )
226
+
227
+ if args.mapping_file is not None:
228
+ manager.set_mapping_file(args.mapping_file)
229
+
230
+ port = 6654
231
+ server = HTTPServer(("localhost", port), GetHandler)
232
+ logger.info(f"Starting server at http://localhost:{port}")
233
+
234
+ # Create a file handler and set its level
235
+ file_handler = logging.FileHandler(args.log_file)
236
+ file_handler.setLevel(logging.DEBUG)
237
+
238
+ # Create a log formatter and set it on the handler
239
+ formatter = logging.Formatter(
240
+ "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
241
+ )
242
+ file_handler.setFormatter(formatter)
243
+
244
+ # Add the file handler to the logger
245
+ logger.addHandler(file_handler)
246
+
247
+ try:
248
+ server.serve_forever()
249
+ except KeyboardInterrupt:
250
+ exit(0)
251
+
252
+
253
+ if __name__ == "__main__":
254
+ main()
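Once the server is running, a GERBIL-style client POSTs a JSON body with a `text` field; a minimal sketch with `requests`:

    import json

    import requests

    resp = requests.post(
        "http://localhost:6654",
        data=json.dumps({"text": "Bernie Ecclestone was the boss of Formula One."}),
    )
    # each item is (start_char, length, label), as produced by annotate()
    print(resp.json())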
relik/inference/preprocessing.py ADDED
@@ -0,0 +1,4 @@
1
+ def wikipedia_title_and_openings_preprocessing(
2
+ wikipedia_title_and_openings: str, separator: str = " <def>"
3
+ ):
4
+ return wikipedia_title_and_openings.split(separator, 1)[0]
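The function keeps only the Wikipedia title from a "title <def> opening" candidate string, for example:

    from relik.inference.preprocessing import wikipedia_title_and_openings_preprocessing

    candidate = "Formula One <def> Formula One is the highest class of racing."
    wikipedia_title_and_openings_preprocessing(candidate)  # -> "Formula One"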
relik/inference/serve/__init__.py ADDED
File without changes
relik/inference/serve/backend/__init__.py ADDED
File without changes
relik/inference/serve/backend/relik.py ADDED
@@ -0,0 +1,210 @@
1
+ import logging
2
+ from pathlib import Path
3
+ from typing import List, Optional, Union
4
+
5
+ from relik.common.utils import is_package_available
6
+ from relik.inference.annotator import Relik
7
+
8
+ if not is_package_available("fastapi"):
9
+ raise ImportError(
10
+ "FastAPI is not installed. Please install FastAPI with `pip install relik[serve]`."
11
+ )
12
+ from fastapi import FastAPI, HTTPException
13
+
14
+ if not is_package_available("ray"):
15
+ raise ImportError(
16
+ "Ray is not installed. Please install Ray with `pip install relik[serve]`."
17
+ )
18
+ from ray import serve
19
+
20
+ from relik.common.log import get_logger
21
+ from relik.inference.serve.backend.utils import (
22
+ RayParameterManager,
23
+ ServerParameterManager,
24
+ )
25
+ from relik.retriever.data.utils import batch_generator
+ from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer
+ from relik.inference.data.window.manager import WindowManager
26
+
27
+ logger = get_logger(__name__, level=logging.INFO)
28
+
29
+ VERSION = {} # type: ignore
30
+ with open(
31
+ Path(__file__).parent.parent.parent.parent / "version.py", "r"
32
+ ) as version_file:
33
+ exec(version_file.read(), VERSION)
34
+
35
+ # Env variables for server
36
+ SERVER_MANAGER = ServerParameterManager()
37
+ RAY_MANAGER = RayParameterManager()
38
+
39
+ app = FastAPI(
40
+ title="ReLiK",
41
+ version=VERSION["VERSION"],
42
+ description="ReLiK REST API",
43
+ )
44
+
45
+
46
+ @serve.deployment(
47
+ ray_actor_options={
48
+ "num_gpus": RAY_MANAGER.num_gpus
49
+ if (
50
+ SERVER_MANAGER.retriver_device == "cuda"
51
+ or SERVER_MANAGER.reader_device == "cuda"
52
+ )
53
+ else 0
54
+ },
55
+ autoscaling_config={
56
+ "min_replicas": RAY_MANAGER.min_replicas,
57
+ "max_replicas": RAY_MANAGER.max_replicas,
58
+ },
59
+ )
60
+ @serve.ingress(app)
61
+ class RelikServer:
62
+ def __init__(
63
+ self,
64
+ question_encoder: str,
65
+ document_index: str,
66
+ passage_encoder: Optional[str] = None,
67
+ reader_encoder: Optional[str] = None,
68
+ top_k: int = 100,
69
+ retriver_device: str = "cpu",
70
+ reader_device: str = "cpu",
71
+ index_device: Optional[str] = None,
72
+ precision: int = 32,
73
+ index_precision: Optional[int] = None,
74
+ use_faiss: bool = False,
75
+ window_batch_size: int = 32,
76
+ window_size: int = 32,
77
+ window_stride: int = 16,
78
+ split_on_spaces: bool = False,
79
+ ):
80
+ # parameters
81
+ self.question_encoder = question_encoder
82
+ self.passage_encoder = passage_encoder
83
+ self.reader_encoder = reader_encoder
84
+ self.document_index = document_index
85
+ self.top_k = top_k
86
+ self.retriver_device = retriver_device
87
+ self.index_device = index_device or retriver_device
88
+ self.reader_device = reader_device
89
+ self.precision = precision
90
+ self.index_precision = index_precision or precision
91
+ self.use_faiss = use_faiss
92
+ self.window_batch_size = window_batch_size
93
+ self.window_size = window_size
94
+ self.window_stride = window_stride
95
+ self.split_on_spaces = split_on_spaces
96
+
97
+ # log stuff for debugging
98
+ logger.info("Initializing RelikServer with parameters:")
99
+ logger.info(f"QUESTION_ENCODER: {self.question_encoder}")
100
+ logger.info(f"PASSAGE_ENCODER: {self.passage_encoder}")
101
+ logger.info(f"READER_ENCODER: {self.reader_encoder}")
102
+ logger.info(f"DOCUMENT_INDEX: {self.document_index}")
103
+ logger.info(f"TOP_K: {self.top_k}")
104
+ logger.info(f"RETRIEVER_DEVICE: {self.retriver_device}")
105
+ logger.info(f"READER_DEVICE: {self.reader_device}")
106
+ logger.info(f"INDEX_DEVICE: {self.index_device}")
107
+ logger.info(f"PRECISION: {self.precision}")
108
+ logger.info(f"INDEX_PRECISION: {self.index_precision}")
109
+ logger.info(f"WINDOW_BATCH_SIZE: {self.window_batch_size}")
110
+ logger.info(f"SPLIT_ON_SPACES: {self.split_on_spaces}")
111
+
112
+ self.relik = Relik(
113
+ question_encoder=self.question_encoder,
114
+ passage_encoder=self.passage_encoder,
115
+ document_index=self.document_index,
116
+ reader=self.reader_encoder,
117
+ retriever_device=self.retriver_device,
118
+ document_index_device=self.index_device,
119
+ reader_device=self.reader_device,
120
+ retriever_precision=self.precision,
121
+ document_index_precision=self.index_precision,
122
+ reader_precision=self.precision,
123
+ )
+
+ # tokenizer and window manager used by the gerbil endpoint
+ self.tokenizer = SpacyTokenizer(language="en")
+ self.window_manager = WindowManager(self.tokenizer)
124
+
125
+ # @serve.batch()
126
+ async def handle_batch(self, documents: List[str]) -> List:
127
+ return self.relik(
128
+ documents,
129
+ top_k=self.top_k,
130
+ window_size=self.window_size,
131
+ window_stride=self.window_stride,
132
+ batch_size=self.window_batch_size,
133
+ )
134
+
135
+ @app.post("/api/entities")
136
+ async def entities_endpoint(
137
+ self,
138
+ documents: Union[str, List[str]],
139
+ ):
140
+ try:
141
+ # normalize input
142
+ if isinstance(documents, str):
143
+ documents = [documents]
144
+ # get predictions from the end-to-end pipeline
145
+ return await self.handle_batch(documents)
150
+ except Exception as e:
151
+ # log the entire stack trace
152
+ logger.exception(e)
153
+ raise HTTPException(status_code=500, detail=f"Server Error: {e}")
154
+
155
+ @app.post("/api/gerbil")
156
+ async def gerbil_endpoint(self, documents: Union[str, List[str]]):
157
+ try:
158
+ # normalize input
159
+ if isinstance(documents, str):
160
+ documents = [documents]
161
+
162
+ # output list
163
+ windows_passages = []
164
+ # split documents into windows
165
+ document_windows = [
166
+ window
167
+ for doc_id, document in enumerate(documents)
168
+ for window in self.window_manager(
169
+ self.tokenizer,
170
+ document,
171
+ window_size=self.window_size,
172
+ stride=self.window_stride,
173
+ doc_id=doc_id,
174
+ )
175
+ ]
176
+
177
+ # get text and topic from document windows and create new list
178
+ model_inputs = [
179
+ (window.text, window.doc_topic) for window in document_windows
180
+ ]
181
+
182
+ # batch generator
183
+ for batch in batch_generator(
184
+ model_inputs, batch_size=self.window_batch_size
185
+ ):
186
+ text, _ = zip(*batch)
188
+ batch_predictions = await self.handle_batch(list(text))
188
+ windows_passages.extend(
189
+ [
190
+ [p.label for p in predictions]
191
+ for predictions in batch_predictions
192
+ ]
193
+ )
194
+
195
+ # add passage to document windows
196
+ for window, passages in zip(document_windows, windows_passages):
197
+ # clean up passages (remove everything after first <def> tag if present)
198
+ passages = [c.split(" <def>", 1)[0] for c in passages]
199
+ window.window_candidates = passages
200
+
201
+ # return document windows
202
+ return document_windows
203
+
204
+ except Exception as e:
205
+ # log the entire stack trace
206
+ logger.exception(e)
207
+ raise HTTPException(status_code=500, detail=f"Server Error: {e}")
208
+
209
+
210
+ server = RelikServer.bind(**vars(SERVER_MANAGER))
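A minimal client-side sketch of how the `/api/entities` endpoint defined above can be queried once the deployment is running; it mirrors the `requests.post(..., json=text)` pattern used by this commit's Streamlit frontend, while the host and port are assumptions for illustration:

    import requests

    # the endpoint accepts a single string or a list of strings as the JSON body
    response = requests.post(
        "http://localhost:8000/api/entities",  # assumed local Ray Serve address
        json="Obama went to Rome for a quick vacation.",
    )
    response.raise_for_status()
    print(response.json())  # annotations produced by the Relik pipeline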
relik/inference/serve/backend/retriever.py ADDED
@@ -0,0 +1,206 @@
1
+ import logging
2
+ from pathlib import Path
3
+ from typing import List, Optional, Union
4
+
5
+ from relik.common.utils import is_package_available
6
+
7
+ if not is_package_available("fastapi"):
8
+ raise ImportError(
9
+ "FastAPI is not installed. Please install FastAPI with `pip install relik[serve]`."
10
+ )
11
+ from fastapi import FastAPI, HTTPException
12
+
13
+ if not is_package_available("ray"):
14
+ raise ImportError(
15
+ "Ray is not installed. Please install Ray with `pip install relik[serve]`."
16
+ )
17
+ from ray import serve
18
+
19
+ from relik.common.log import get_logger
20
+ from relik.inference.data.tokenizers import SpacyTokenizer, WhitespaceTokenizer
21
+ from relik.inference.data.window.manager import WindowManager
22
+ from relik.inference.serve.backend.utils import (
23
+ RayParameterManager,
24
+ ServerParameterManager,
25
+ )
26
+ from relik.retriever.data.utils import batch_generator
27
+ from relik.retriever.pytorch_modules import GoldenRetriever
28
+
29
+ logger = get_logger(__name__, level=logging.INFO)
30
+
31
+ VERSION = {} # type: ignore
32
+ with open(Path(__file__).parent.parent.parent.parent / "version.py", "r") as version_file:
33
+ exec(version_file.read(), VERSION)
34
+
35
+ # Env variables for server
36
+ SERVER_MANAGER = ServerParameterManager()
37
+ RAY_MANAGER = RayParameterManager()
38
+
39
+ app = FastAPI(
40
+ title="Golden Retriever",
41
+ version=VERSION["VERSION"],
42
+ description="Golden Retriever REST API",
43
+ )
44
+
45
+
46
+ @serve.deployment(
47
+ ray_actor_options={
48
+ "num_gpus": RAY_MANAGER.num_gpus if SERVER_MANAGER.device == "cuda" else 0
49
+ },
50
+ autoscaling_config={
51
+ "min_replicas": RAY_MANAGER.min_replicas,
52
+ "max_replicas": RAY_MANAGER.max_replicas,
53
+ },
54
+ )
55
+ @serve.ingress(app)
56
+ class GoldenRetrieverServer:
57
+ def __init__(
58
+ self,
59
+ question_encoder: str,
60
+ document_index: str,
61
+ passage_encoder: Optional[str] = None,
62
+ top_k: int = 100,
63
+ device: str = "cpu",
64
+ index_device: Optional[str] = None,
65
+ precision: int = 32,
66
+ index_precision: Optional[int] = None,
67
+ use_faiss: bool = False,
68
+ window_batch_size: int = 32,
69
+ window_size: int = 32,
70
+ window_stride: int = 16,
71
+ split_on_spaces: bool = False,
72
+ ):
73
+ # parameters
74
+ self.question_encoder = question_encoder
75
+ self.passage_encoder = passage_encoder
76
+ self.document_index = document_index
77
+ self.top_k = top_k
78
+ self.device = device
79
+ self.index_device = index_device or device
80
+ self.precision = precision
81
+ self.index_precision = index_precision or precision
82
+ self.use_faiss = use_faiss
83
+ self.window_batch_size = window_batch_size
84
+ self.window_size = window_size
85
+ self.window_stride = window_stride
86
+ self.split_on_spaces = split_on_spaces
87
+
88
+ # log stuff for debugging
89
+ logger.info("Initializing GoldenRetrieverServer with parameters:")
90
+ logger.info(f"QUESTION_ENCODER: {self.question_encoder}")
91
+ logger.info(f"PASSAGE_ENCODER: {self.passage_encoder}")
92
+ logger.info(f"DOCUMENT_INDEX: {self.document_index}")
93
+ logger.info(f"TOP_K: {self.top_k}")
94
+ logger.info(f"DEVICE: {self.device}")
95
+ logger.info(f"INDEX_DEVICE: {self.index_device}")
96
+ logger.info(f"PRECISION: {self.precision}")
97
+ logger.info(f"INDEX_PRECISION: {self.index_precision}")
98
+ logger.info(f"WINDOW_BATCH_SIZE: {self.window_batch_size}")
99
+ logger.info(f"SPLIT_ON_SPACES: {self.split_on_spaces}")
100
+
101
+ self.retriever = GoldenRetriever(
102
+ question_encoder=self.question_encoder,
103
+ passage_encoder=self.passage_encoder,
104
+ document_index=self.document_index,
105
+ device=self.device,
106
+ index_device=self.index_device,
107
+ index_precision=self.index_precision,
108
+ )
109
+ self.retriever.eval()
110
+
111
+ if self.split_on_spaces:
112
+ logger.info("Using WhitespaceTokenizer")
113
+ self.tokenizer = WhitespaceTokenizer()
114
+ # logger.info("Using RegexTokenizer")
115
+ # self.tokenizer = RegexTokenizer()
116
+ else:
117
+ logger.info("Using SpacyTokenizer")
118
+ self.tokenizer = SpacyTokenizer(language="en")
119
+
120
+ self.window_manager = WindowManager(tokenizer=self.tokenizer)
121
+
122
+ # @serve.batch()
123
+ async def handle_batch(
124
+ self, documents: List[str], document_topics: List[str]
125
+ ) -> List:
126
+ return self.retriever.retrieve(
127
+ documents, text_pair=document_topics, k=self.top_k, precision=self.precision
128
+ )
129
+
130
+ @app.post("/api/retrieve")
131
+ async def retrieve_endpoint(
132
+ self,
133
+ documents: Union[str, List[str]],
134
+ document_topics: Optional[Union[str, List[str]]] = None,
135
+ ):
136
+ try:
137
+ # normalize input
138
+ if isinstance(documents, str):
139
+ documents = [documents]
140
+ if document_topics is not None:
141
+ if isinstance(document_topics, str):
142
+ document_topics = [document_topics]
143
+ assert len(documents) == len(document_topics)
144
+ # get predictions
145
+ return await self.handle_batch(documents, document_topics)
146
+ except Exception as e:
147
+ # log the entire stack trace
148
+ logger.exception(e)
149
+ raise HTTPException(status_code=500, detail=f"Server Error: {e}")
150
+
151
+ @app.post("/api/gerbil")
152
+ async def gerbil_endpoint(self, documents: Union[str, List[str]]):
153
+ try:
154
+ # normalize input
155
+ if isinstance(documents, str):
156
+ documents = [documents]
157
+
158
+ # output list
159
+ windows_passages = []
160
+ # split documents into windows
161
+ document_windows = [
162
+ window
163
+ for doc_id, document in enumerate(documents)
164
+ for window in self.window_manager(
165
+ self.tokenizer,
166
+ document,
167
+ window_size=self.window_size,
168
+ stride=self.window_stride,
169
+ doc_id=doc_id,
170
+ )
171
+ ]
172
+
173
+ # get text and topic from document windows and create new list
174
+ model_inputs = [
175
+ (window.text, window.doc_topic) for window in document_windows
176
+ ]
177
+
178
+ # batch generator
179
+ for batch in batch_generator(
180
+ model_inputs, batch_size=self.window_batch_size
181
+ ):
182
+ text, text_pair = zip(*batch)
183
+ batch_predictions = await self.handle_batch(text, text_pair)
184
+ windows_passages.extend(
185
+ [
186
+ [p.label for p in predictions]
187
+ for predictions in batch_predictions
188
+ ]
189
+ )
190
+
191
+ # add passage to document windows
192
+ for window, passages in zip(document_windows, windows_passages):
193
+ # clean up passages (remove everything after first <def> tag if present)
194
+ passages = [c.split(" <def>", 1)[0] for c in passages]
195
+ window.window_candidates = passages
196
+
197
+ # return document windows
198
+ return document_windows
199
+
200
+ except Exception as e:
201
+ # log the entire stack trace
202
+ logger.exception(e)
203
+ raise HTTPException(status_code=500, detail=f"Server Error: {e}")
204
+
205
+
206
+ server = GoldenRetrieverServer.bind(
+     question_encoder=SERVER_MANAGER.question_encoder,
+     passage_encoder=SERVER_MANAGER.passage_encoder,
+     document_index=SERVER_MANAGER.document_index,
+     top_k=SERVER_MANAGER.top_k,
+     device=SERVER_MANAGER.retriever_device,
+     index_device=SERVER_MANAGER.index_device,
+     precision=SERVER_MANAGER.precision,
+     index_precision=SERVER_MANAGER.index_precision,
+     use_faiss=SERVER_MANAGER.use_faiss,
+     window_batch_size=SERVER_MANAGER.window_batch_size,
+     window_size=SERVER_MANAGER.window_size,
+     window_stride=SERVER_MANAGER.window_stride,
+     split_on_spaces=SERVER_MANAGER.split_on_spaces,
+ )
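Likewise, a hedged sketch of a call to the `/api/retrieve` endpoint above; the embedded-body shape follows FastAPI's usual convention for endpoints with multiple body parameters and is an assumption, as are host and port:

    import requests

    response = requests.post(
        "http://localhost:8000/api/retrieve",  # assumed local Ray Serve address
        json={
            "documents": ["What is the capital of Italy?"],
            "document_topics": None,  # optional topics, one per document
        },
    )
    print(response.json())  # top_k passages per input document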
relik/inference/serve/backend/utils.py ADDED
@@ -0,0 +1,29 @@
1
+ import os
2
+ from dataclasses import dataclass
3
+ from typing import Union
4
+
5
+
6
+ @dataclass
7
+ class ServerParameterManager:
8
+ retriever_device: str = os.environ.get("RETRIEVER_DEVICE", "cpu")
9
+ reader_device: str = os.environ.get("READER_DEVICE", "cpu")
10
+ index_device: str = os.environ.get("INDEX_DEVICE", retriever_device)
11
+ precision: Union[str, int] = os.environ.get("PRECISION", "fp32")
12
+ index_precision: Union[str, int] = os.environ.get("INDEX_PRECISION", precision)
13
+ question_encoder: str = os.environ.get("QUESTION_ENCODER", None)
14
+ passage_encoder: str = os.environ.get("PASSAGE_ENCODER", None)
15
+ document_index: str = os.environ.get("DOCUMENT_INDEX", None)
16
+ reader_encoder: str = os.environ.get("READER_ENCODER", None)
17
+ top_k: int = int(os.environ.get("TOP_K", 100))
18
+ use_faiss: bool = os.environ.get("USE_FAISS", "false").lower() in ("true", "1")
19
+ window_batch_size: int = int(os.environ.get("WINDOW_BATCH_SIZE", 32))
20
+ window_size: int = int(os.environ.get("WINDOW_SIZE", 32))
21
+ window_stride: int = int(os.environ.get("WINDOW_STRIDE", 16))
22
+ split_on_spaces: bool = os.environ.get("SPLIT_ON_SPACES", "false").lower() in ("true", "1")
23
+
24
+
25
+ class RayParameterManager:
26
+ def __init__(self) -> None:
27
+ self.num_gpus = int(os.environ.get("NUM_GPUS", 1))
28
+ self.min_replicas = int(os.environ.get("MIN_REPLICAS", 1))
29
+ self.max_replicas = int(os.environ.get("MAX_REPLICAS", 1))
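One subtlety worth illustrating: `ServerParameterManager` evaluates its `os.environ.get` defaults at class-definition time, so the environment must be set before the module is imported, while `RayParameterManager` reads the environment in `__init__`. A small sketch, with the module path taken from this commit's layout:

    import os

    # must happen before the import below for the dataclass defaults to see it
    os.environ["RETRIEVER_DEVICE"] = "cuda"
    os.environ["TOP_K"] = "50"

    from relik.inference.serve.backend.utils import (
        RayParameterManager,
        ServerParameterManager,
    )

    params = ServerParameterManager()
    assert params.retriever_device == "cuda" and params.top_k == 50
    ray_params = RayParameterManager()  # NUM_GPUS etc. read here, at call time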
relik/inference/serve/frontend/__init__.py ADDED
File without changes
relik/inference/serve/frontend/relik.py ADDED
@@ -0,0 +1,231 @@
1
+ import os
2
+ import random
3
+ import re
4
+ from pathlib import Path
5
+ 
6
+ import requests
7
+ import streamlit as st
8
+ from spacy import displacy
9
+ from streamlit_extras.badges import badge
10
+ from streamlit_extras.stylable_container import stylable_container
11
+ 
12
+ RELIK = os.getenv("RELIK", "http://localhost:8000/api/entities")
15
+
16
+
17
+ def get_random_color(ents):
18
+ colors = {}
19
+ random_colors = generate_pastel_colors(len(ents))
20
+ for ent in ents:
21
+ colors[ent] = random_colors.pop(random.randint(0, len(random_colors) - 1))
22
+ return colors
23
+
24
+
25
+ def floatrange(start, stop, steps):
26
+ if int(steps) == 1:
27
+ return [stop]
28
+ return [
29
+ start + float(i) * (stop - start) / (float(steps) - 1) for i in range(steps)
30
+ ]
31
+
32
+
33
+ def hsl_to_rgb(h, s, l):
34
+ def hue_2_rgb(v1, v2, v_h):
35
+ while v_h < 0.0:
36
+ v_h += 1.0
37
+ while v_h > 1.0:
38
+ v_h -= 1.0
39
+ if 6 * v_h < 1.0:
40
+ return v1 + (v2 - v1) * 6.0 * v_h
41
+ if 2 * v_h < 1.0:
42
+ return v2
43
+ if 3 * v_h < 2.0:
44
+ return v1 + (v2 - v1) * ((2.0 / 3.0) - v_h) * 6.0
45
+ return v1
46
+
47
+ # if not (0 <= s <= 1): raise ValueError, "s (saturation) parameter must be between 0 and 1."
48
+ # if not (0 <= l <= 1): raise ValueError, "l (lightness) parameter must be between 0 and 1."
49
+
50
+ r, g, b = (l * 255,) * 3
51
+ if s != 0.0:
52
+ if l < 0.5:
53
+ var_2 = l * (1.0 + s)
54
+ else:
55
+ var_2 = (l + s) - (s * l)
56
+ var_1 = 2.0 * l - var_2
57
+ r = 255 * hue_2_rgb(var_1, var_2, h + (1.0 / 3.0))
58
+ g = 255 * hue_2_rgb(var_1, var_2, h)
59
+ b = 255 * hue_2_rgb(var_1, var_2, h - (1.0 / 3.0))
60
+
61
+ return int(round(r)), int(round(g)), int(round(b))
62
+
63
+
64
+ def generate_pastel_colors(n):
65
+ """Return different pastel colours.
66
+
67
+ Input:
68
+ n (integer) : The number of colors to return
69
+
70
+ Output:
71
+ A list of colors in HTML notation (e.g. ['#cce0ff', '#ffcccc', '#ccffe0', '#f5ccff', '#f5ffcc'])
72
+
73
+ Example:
74
+ >>> print(generate_pastel_colors(5))
75
+ ['#cce0ff', '#f5ccff', '#ffcccc', '#f5ffcc', '#ccffe0']
76
+ """
77
+ if n == 0:
78
+ return []
79
+
80
+ # To generate colors, we use the HSL colorspace (see http://en.wikipedia.org/wiki/HSL_color_space)
81
+ start_hue = 0.6 # 0=red 1/3=0.333=green 2/3=0.666=blue
82
+ saturation = 1.0
83
+ lightness = 0.8
84
+ # We take points around the chromatic circle (hue):
85
+ # (Note: we generate n+1 colors, then drop the last one ([:-1]) because
86
+ # it equals the first one (hue 0 = hue 1))
87
+ return [
88
+ "#%02x%02x%02x" % hsl_to_rgb(hue, saturation, lightness)
89
+ for hue in floatrange(start_hue, start_hue + 1, n + 1)
90
+ ][:-1]
91
+
92
+
93
+ def set_sidebar(css):
94
+ white_link_wrapper = "<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.2/css/all.min.css'><a href='{}'>{}</a>"
95
+ with st.sidebar:
96
+ st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
97
+ st.image(
98
+ "http://nlp.uniroma1.it/static/website/sapienza-nlp-logo-wh.svg",
99
+ use_column_width=True,
100
+ )
101
+ st.markdown("## ReLiK")
102
+ st.write(
103
+ f"""
104
+ - {white_link_wrapper.format("#", "<i class='fa-solid fa-file'></i>&nbsp; Paper")}
105
+ - {white_link_wrapper.format("https://github.com/SapienzaNLP/relik", "<i class='fa-brands fa-github'></i>&nbsp; GitHub")}
106
+ - {white_link_wrapper.format("https://hub.docker.com/repository/docker/sapienzanlp/relik", "<i class='fa-brands fa-docker'></i>&nbsp; Docker Hub")}
107
+ """,
108
+ unsafe_allow_html=True,
109
+ )
110
+ st.markdown("## Sapienza NLP")
111
+ st.write(
112
+ f"""
113
+ - {white_link_wrapper.format("https://nlp.uniroma1.it", "<i class='fa-solid fa-globe'></i>&nbsp; Webpage")}
114
+ - {white_link_wrapper.format("https://github.com/SapienzaNLP", "<i class='fa-brands fa-github'></i>&nbsp; GitHub")}
115
+ - {white_link_wrapper.format("https://twitter.com/SapienzaNLP", "<i class='fa-brands fa-twitter'></i>&nbsp; Twitter")}
116
+ - {white_link_wrapper.format("https://www.linkedin.com/company/79434450", "<i class='fa-brands fa-linkedin'></i>&nbsp; LinkedIn")}
117
+ """,
118
+ unsafe_allow_html=True,
119
+ )
120
+
121
+
122
+ def get_el_annotations(response):
123
+ # swap labels key with ents
124
+ response["ents"] = response.pop("labels")
125
+ label_in_text = set(l["label"] for l in response["ents"])
126
+ options = {"ents": label_in_text, "colors": get_random_color(label_in_text)}
127
+ return response, options
128
+
129
+
130
+ def set_intro(css):
131
+ # intro
132
+ st.markdown("# ReLik")
133
+ st.markdown(
134
+ "### Retrieve, Read and LinK: Fast and Accurate Entity Linking and Relation Extraction on an Academic Budget"
135
+ )
136
+ # st.markdown(
137
+ # "This is a front-end for the paper [Universal Semantic Annotator: the First Unified API "
138
+ # "for WSD, SRL and Semantic Parsing](https://www.researchgate.net/publication/360671045_Universal_Semantic_Annotator_the_First_Unified_API_for_WSD_SRL_and_Semantic_Parsing), which will be presented at LREC 2022 by "
139
+ # "[Riccardo Orlando](https://riccorl.github.io), [Simone Conia](https://c-simone.github.io/), "
140
+ # "[Stefano Faralli](https://corsidilaurea.uniroma1.it/it/users/stefanofaralliuniroma1it), and [Roberto Navigli](https://www.diag.uniroma1.it/navigli/)."
141
+ # )
142
+ badge(type="github", name="sapienzanlp/relik")
143
+ badge(type="pypi", name="relik")
144
+
145
+
146
+ def run_client():
147
+ with open(Path(__file__).parent / "style.css") as f:
148
+ css = f.read()
149
+
150
+ st.set_page_config(
151
+ page_title="ReLik",
152
+ page_icon="🦮",
153
+ layout="wide",
154
+ )
155
+ set_sidebar(css)
156
+ set_intro(css)
157
+
158
+ # text input
159
+ text = st.text_area(
160
+ "Enter Text Below:",
161
+ value="Obama went to Rome for a quick vacation.",
162
+ height=200,
163
+ max_chars=500,
164
+ )
165
+
166
+ with stylable_container(
167
+ key="annotate_button",
168
+ css_styles="""
169
+ button {
170
+ background-color: #802433;
171
+ color: white;
172
+ border-radius: 25px;
173
+ }
174
+ """,
175
+ ):
176
+ submit = st.button("Annotate")
177
+ # submit = st.button("Run")
178
+
179
+ # ReLik API call
180
+ if submit:
181
+ text = text.strip()
182
+ if text:
183
+ st.markdown("####")
184
+ st.markdown("#### Entity Linking")
185
+ with st.spinner(text="In progress"):
186
+ response = requests.post(RELIK, json=text)
187
+ if response.status_code != 200:
188
+ st.error("Error: {}".format(response.status_code))
189
+ else:
190
+ response = response.json()
191
+
192
+ # Entity Linking
193
+ # with stylable_container(
194
+ # key="container_with_border",
195
+ # css_styles="""
196
+ # {
197
+ # border: 1px solid rgba(49, 51, 63, 0.2);
198
+ # border-radius: 0.5rem;
199
+ # padding: 0.5rem;
200
+ # padding-bottom: 2rem;
201
+ # }
202
+ # """,
203
+ # ):
204
+ # st.markdown("##")
205
+ dict_of_ents, options = get_el_annotations(response=response)
206
+ display = displacy.render(
207
+ dict_of_ents, manual=True, style="ent", options=options
208
+ )
209
+ display = display.replace("\n", " ")
210
+ # wsd_display = re.sub(
211
+ # r"(wiki::\d+\w)",
212
+ # r"<a href='https://babelnet.org/synset?id=\g<1>&orig=\g<1>&lang={}'>\g<1></a>".format(
213
+ # language.upper()
214
+ # ),
215
+ # wsd_display,
216
+ # )
217
+ with st.container():
218
+ st.write(display, unsafe_allow_html=True)
219
+
220
+ st.markdown("####")
221
+ st.markdown("#### Relation Extraction")
222
+
223
+ with st.container():
224
+ st.write("Coming :)", unsafe_allow_html=True)
225
+
226
+ else:
227
+ st.error("Please enter some text.")
228
+
229
+
230
+ if __name__ == "__main__":
231
+ run_client()
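As a quick check of the palette helpers defined in this file: `generate_pastel_colors` is deterministic, while `get_random_color` shuffles the assignment, so only the set of colors is stable across runs:

    colors = get_random_color({"PER", "LOC", "ORG"})
    print(colors)  # e.g. {'PER': '#ccccff', 'LOC': '#ccffcc', 'ORG': '#ffcccc'}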
relik/inference/serve/frontend/style.css ADDED
@@ -0,0 +1,33 @@
1
+ /* Sidebar */
2
+ .eczjsme11 {
3
+ background-color: #802433;
4
+ }
5
+
6
+ .st-emotion-cache-10oheav h2 {
7
+ color: white;
8
+ }
9
+
10
+ .st-emotion-cache-10oheav li {
11
+ color: white;
12
+ }
13
+
14
+ /* Main */
15
+ a:link {
16
+ text-decoration: none;
17
+ color: white;
18
+ }
19
+
20
+ a:visited {
21
+ text-decoration: none;
22
+ color: white;
23
+ }
24
+
25
+ a:hover {
26
+ text-decoration: none;
27
+ color: rgba(255, 255, 255, 0.871);
28
+ }
29
+
30
+ a:active {
31
+ text-decoration: none;
32
+ color: white;
33
+ }
relik/reader/__init__.py ADDED
File without changes
relik/reader/conf/config.yaml ADDED
@@ -0,0 +1,14 @@
1
+ # Required to make the "experiments" dir the default one for the output of the models
2
+ hydra:
3
+ run:
4
+ dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
5
+
6
+ model_name: relik-reader-deberta-base # used to name the model in wandb and output dir
7
+ project_name: relik-reader # used to name the project in wandb
8
+
9
+
10
+ defaults:
11
+ - _self_
12
+ - training: base
13
+ - model: base
14
+ - data: base
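A hedged sketch of composing this Hydra config programmatically. Hydra resolves a relative `config_path` against the calling module, so this assumes a script sitting at the repository root; the `model: base` group is assumed to exist in the full (truncated) file listing:

    from hydra import compose, initialize

    with initialize(version_base=None, config_path="relik/reader/conf"):
        cfg = compose(config_name="config", overrides=["data=re", "training=re"])
        print(cfg.model_name)  # relik-reader-deberta-base unless overridden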
relik/reader/conf/data/base.yaml ADDED
@@ -0,0 +1,21 @@
1
+ train_dataset_path: "relik/reader/data/train.jsonl"
2
+ val_dataset_path: "relik/reader/data/testa.jsonl"
3
+
4
+ train_dataset:
5
+ _target_: "relik.reader.relik_reader_data.RelikDataset"
6
+ transformer_model: "${model.model.transformer_model}"
7
+ materialize_samples: False
8
+ shuffle_candidates: 0.5
9
+ random_drop_gold_candidates: 0.05
10
+ noise_param: 0.0
11
+ for_inference: False
12
+ tokens_per_batch: 4096
13
+ special_symbols: null
14
+
15
+ val_dataset:
16
+ _target_: "relik.reader.relik_reader_data.RelikDataset"
17
+ transformer_model: "${model.model.transformer_model}"
18
+ materialize_samples: False
19
+ shuffle_candidates: False
20
+ for_inference: True
21
+ special_symbols: null
relik/reader/conf/data/re.yaml ADDED
@@ -0,0 +1,54 @@
1
+ train_dataset_path: "relik/reader/data/nyt-alby+/train.jsonl"
2
+ val_dataset_path: "relik/reader/data/nyt-alby+/valid.jsonl"
3
+ test_dataset_path: "relik/reader/data/nyt-alby+/test.jsonl"
4
+
5
+ relations_definitions:
6
+ /people/person/nationality: "nationality"
7
+ /sports/sports_team/location: "sports team location"
8
+ /location/country/administrative_divisions: "administrative divisions"
9
+ /business/company/major_shareholders: "shareholders"
10
+ /people/ethnicity/people: "ethnicity"
11
+ /people/ethnicity/geographic_distribution: "geographic distribution"
12
+ /business/company_shareholder/major_shareholder_of: "major shareholder"
13
+ /location/location/contains: "location"
14
+ /business/company/founders: "founders"
15
+ /business/person/company: "company"
16
+ /business/company/advisors: "advisor"
17
+ /people/deceased_person/place_of_death: "place of death"
18
+ /business/company/industry: "industry"
19
+ /people/person/ethnicity: "ethnic background"
20
+ /people/person/place_of_birth: "place of birth"
21
+ /location/administrative_division/country: "country of an administrative division"
22
+ /people/person/place_lived: "place lived"
23
+ /sports/sports_team_location/teams: "sports team"
24
+ /people/person/children: "child"
25
+ /people/person/religion: "religion"
26
+ /location/neighborhood/neighborhood_of: "neighborhood"
27
+ /location/country/capital: "capital"
28
+ /business/company/place_founded: "company founded location"
29
+ /people/person/profession: "occupation"
30
+
31
+ train_dataset:
32
+ _target_: "relik.reader.relik_reader_re_data.RelikREDataset"
33
+ transformer_model: "${model.model.transformer_model}"
34
+ materialize_samples: False
35
+ shuffle_candidates: False
36
+ flip_candidates: 1.0
37
+ noise_param: 0.0
38
+ for_inference: False
39
+ tokens_per_batch: 4096
40
+ min_length: -1
41
+ special_symbols: null
42
+ relations_definitions: ${data.relations_definitions}
43
+ sorting_fields:
44
+ - "predictable_candidates"
45
+ val_dataset:
46
+ _target_: "relik.reader.relik_reader_re_data.RelikREDataset"
47
+ transformer_model: "${model.model.transformer_model}"
48
+ materialize_samples: False
49
+ shuffle_candidates: False
50
+ flip_candidates: False
51
+ for_inference: True
52
+ min_length: -1
53
+ special_symbols: null
54
+ relations_definitions: ${data.relations_definitions}
relik/reader/conf/training/base.yaml ADDED
@@ -0,0 +1,12 @@
1
+ seed: 94
2
+
3
+ trainer:
4
+ _target_: lightning.Trainer
5
+ devices:
6
+ - 0
7
+ precision: "16-mixed"
8
+ max_steps: 50000
9
+ val_check_interval: 1.0
10
+ num_sanity_val_steps: 0
11
+ limit_val_batches: 1
12
+ gradient_clip_val: 1.0
relik/reader/conf/training/re.yaml ADDED
@@ -0,0 +1,12 @@
1
+ seed: 15
2
+
3
+ trainer:
4
+ _target_: lightning.Trainer
5
+ devices:
6
+ - 0
7
+ precision: "16-mixed"
8
+ max_steps: 100000
9
+ val_check_interval: 1.0
10
+ num_sanity_val_steps: 0
11
+ limit_val_batches: 1
12
+ gradient_clip_val: 1.0
relik/reader/data/__init__.py ADDED
File without changes
relik/reader/data/patches.py ADDED
@@ -0,0 +1,51 @@
1
+ from typing import List
2
+
3
+ from relik.reader.data.relik_reader_sample import RelikReaderSample
4
+ from relik.reader.utils.special_symbols import NME_SYMBOL
5
+
6
+
7
+ def merge_patches_predictions(sample) -> None:
8
+ sample._d["predicted_window_labels"] = dict()
9
+ predicted_window_labels = sample._d["predicted_window_labels"]
10
+
11
+ sample._d["span_title_probabilities"] = dict()
12
+ span_title_probabilities = sample._d["span_title_probabilities"]
13
+
14
+ span2title = dict()
15
+ for _, patch_info in sorted(sample.patches.items(), key=lambda x: x[0]):
16
+ # selecting span predictions
17
+ for predicted_title, predicted_spans in patch_info[
18
+ "predicted_window_labels"
19
+ ].items():
20
+ for pred_span in predicted_spans:
21
+ pred_span = tuple(pred_span)
22
+ curr_title = span2title.get(pred_span)
23
+ if curr_title is None or curr_title == NME_SYMBOL:
24
+ span2title[pred_span] = predicted_title
25
+ # else:
26
+ # print("Merging at patch level")
27
+
28
+ # selecting span predictions probability
29
+ for predicted_span, titles_probabilities in patch_info[
30
+ "span_title_probabilities"
31
+ ].items():
32
+ if predicted_span not in span_title_probabilities:
33
+ span_title_probabilities[predicted_span] = titles_probabilities
34
+
35
+ for span, title in span2title.items():
36
+ if title not in predicted_window_labels:
37
+ predicted_window_labels[title] = list()
38
+ predicted_window_labels[title].append(span)
39
+
40
+
41
+ def remove_duplicate_samples(
42
+ samples: List[RelikReaderSample],
43
+ ) -> List[RelikReaderSample]:
44
+ seen_sample = set()
45
+ samples_store = []
46
+ for sample in samples:
47
+ sample_id = f"{sample.doc_id}#{sample.sent_id}#{sample.offset}"
48
+ if sample_id not in seen_sample:
49
+ seen_sample.add(sample_id)
50
+ samples_store.append(sample)
51
+ return samples_store
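A toy illustration of the merge semantics above, using a stand-in object that exposes the same `_d` and `patches` fields as the real `RelikReaderSample` (the span offsets and titles are invented):

    from types import SimpleNamespace

    sample = SimpleNamespace(
        _d={},
        patches={
            0: {
                "predicted_window_labels": {"Rome": [(5, 9)]},
                "span_title_probabilities": {(5, 9): [("Rome", 0.9)]},
            },
            1: {
                # a later patch cannot overwrite a span already assigned a
                # non-NME title, so (5, 9) stays "Rome"
                "predicted_window_labels": {"Roma": [(5, 9)], "Barack Obama": [(0, 4)]},
                "span_title_probabilities": {(0, 4): [("Barack Obama", 0.8)]},
            },
        },
    )
    merge_patches_predictions(sample)
    print(sample._d["predicted_window_labels"])
    # {'Rome': [(5, 9)], 'Barack Obama': [(0, 4)]}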
relik/reader/data/relik_reader_data.py ADDED
@@ -0,0 +1,965 @@
1
+ import logging
2
+ from typing import (
3
+ Any,
4
+ Callable,
5
+ Dict,
6
+ Generator,
7
+ Iterable,
8
+ Iterator,
9
+ List,
10
+ NamedTuple,
11
+ Optional,
12
+ Tuple,
13
+ Union,
14
+ )
15
+
16
+ import numpy as np
17
+ import torch
18
+ from torch.utils.data import IterableDataset
19
+ from tqdm import tqdm
20
+ from transformers import AutoTokenizer, PreTrainedTokenizer
21
+
22
+ from relik.reader.data.relik_reader_data_utils import (
23
+ add_noise_to_value,
24
+ batchify,
25
+ chunks,
26
+ flatten,
27
+ )
28
+ from relik.reader.data.relik_reader_sample import (
29
+ RelikReaderSample,
30
+ load_relik_reader_samples,
31
+ )
32
+ from relik.reader.utils.special_symbols import NME_SYMBOL
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+
37
+ def preprocess_dataset(
38
+ input_dataset: Iterable[dict],
39
+ transformer_model: str,
40
+ add_topic: bool,
41
+ ) -> Iterable[dict]:
42
+ tokenizer = AutoTokenizer.from_pretrained(transformer_model)
43
+ for dataset_elem in tqdm(input_dataset, desc="Preprocessing input dataset"):
44
+ if len(dataset_elem["tokens"]) == 0:
45
+ print(
46
+ f"Dataset element with doc id: {dataset_elem['doc_id']}",
47
+ f"and offset {dataset_elem['offset']} does not contain any token",
48
+ "Skipping it",
49
+ )
50
+ continue
51
+
52
+ new_dataset_elem = dict(
53
+ doc_id=dataset_elem["doc_id"],
54
+ offset=dataset_elem["offset"],
55
+ )
56
+
57
+ tokenization_out = tokenizer(
58
+ dataset_elem["tokens"],
59
+ return_offsets_mapping=True,
60
+ add_special_tokens=False,
61
+ )
62
+
63
+ window_tokens = tokenization_out.input_ids
64
+ window_tokens = flatten(window_tokens)
65
+
66
+ offsets_mapping = [
67
+ [
68
+ (
69
+ ss + dataset_elem["token2char_start"][str(i)],
70
+ se + dataset_elem["token2char_start"][str(i)],
71
+ )
72
+ for ss, se in tokenization_out.offset_mapping[i]
73
+ ]
74
+ for i in range(len(dataset_elem["tokens"]))
75
+ ]
76
+
77
+ offsets_mapping = flatten(offsets_mapping)
78
+
79
+ assert len(offsets_mapping) == len(window_tokens)
80
+
81
+ window_tokens = (
82
+ [tokenizer.cls_token_id] + window_tokens + [tokenizer.sep_token_id]
83
+ )
84
+
85
+ topic_offset = 0
86
+ if add_topic:
87
+ topic_tokens = tokenizer(
88
+ dataset_elem["doc_topic"], add_special_tokens=False
89
+ ).input_ids
90
+ topic_offset = len(topic_tokens)
91
+ new_dataset_elem["topic_tokens"] = topic_offset
92
+ window_tokens = window_tokens[:1] + topic_tokens + window_tokens[1:]
93
+
94
+ new_dataset_elem.update(
95
+ dict(
96
+ tokens=window_tokens,
97
+ token2char_start={
98
+ str(i): s
99
+ for i, (s, _) in enumerate(offsets_mapping, start=topic_offset)
100
+ },
101
+ token2char_end={
102
+ str(i): e
103
+ for i, (_, e) in enumerate(offsets_mapping, start=topic_offset)
104
+ },
105
+ window_candidates=dataset_elem["window_candidates"],
106
+ window_candidates_scores=dataset_elem.get("window_candidates_scores"),
107
+ )
108
+ )
109
+
110
+ if "window_labels" in dataset_elem:
111
+ window_labels = [
112
+ (s, e, l.replace("_", " ")) for s, e, l in dataset_elem["window_labels"]
113
+ ]
114
+
115
+ new_dataset_elem["window_labels"] = window_labels
116
+
117
+ if not all(
118
+ [
119
+ s in new_dataset_elem["token2char_start"].values()
120
+ for s, _, _ in new_dataset_elem["window_labels"]
121
+ ]
122
+ ):
123
+ print(
124
+ "Mismatching token start char mapping with labels",
125
+ new_dataset_elem["token2char_start"],
126
+ new_dataset_elem["window_labels"],
127
+ dataset_elem["tokens"],
128
+ )
129
+ continue
130
+
131
+ if not all(
132
+ [
133
+ e in new_dataset_elem["token2char_end"].values()
134
+ for _, e, _ in new_dataset_elem["window_labels"]
135
+ ]
136
+ ):
137
+ print(
138
+ "Mismatching token end char mapping with labels",
139
+ new_dataset_elem["token2char_end"],
140
+ new_dataset_elem["window_labels"],
141
+ dataset_elem["tokens"],
142
+ )
143
+ continue
144
+
145
+ yield new_dataset_elem
146
+
147
+
148
+ def preprocess_sample(
149
+ relik_sample: RelikReaderSample,
150
+ tokenizer,
151
+ lowercase_policy: float,
152
+ add_topic: bool = False,
153
+ ) -> None:
154
+ if len(relik_sample.tokens) == 0:
155
+ return
156
+
157
+ if lowercase_policy > 0:
158
+ lc_tokens = np.random.uniform(0, 1, len(relik_sample.tokens)) < lowercase_policy
159
+ relik_sample.tokens = [
160
+ t.lower() if lc else t for t, lc in zip(relik_sample.tokens, lc_tokens)
161
+ ]
162
+
163
+ tokenization_out = tokenizer(
164
+ relik_sample.tokens,
165
+ return_offsets_mapping=True,
166
+ add_special_tokens=False,
167
+ )
168
+
169
+ window_tokens = tokenization_out.input_ids
170
+ window_tokens = flatten(window_tokens)
171
+
172
+ offsets_mapping = [
173
+ [
174
+ (
175
+ ss + relik_sample.token2char_start[str(i)],
176
+ se + relik_sample.token2char_start[str(i)],
177
+ )
178
+ for ss, se in tokenization_out.offset_mapping[i]
179
+ ]
180
+ for i in range(len(relik_sample.tokens))
181
+ ]
182
+
183
+ offsets_mapping = flatten(offsets_mapping)
184
+
185
+ assert len(offsets_mapping) == len(window_tokens)
186
+
187
+ window_tokens = [tokenizer.cls_token_id] + window_tokens + [tokenizer.sep_token_id]
188
+
189
+ topic_offset = 0
190
+ if add_topic:
191
+ topic_tokens = tokenizer(
192
+ relik_sample.doc_topic, add_special_tokens=False
193
+ ).input_ids
194
+ topic_offset = len(topic_tokens)
195
+ relik_sample.topic_tokens = topic_offset
196
+ window_tokens = window_tokens[:1] + topic_tokens + window_tokens[1:]
197
+
198
+ relik_sample._d.update(
199
+ dict(
200
+ tokens=window_tokens,
201
+ token2char_start={
202
+ str(i): s
203
+ for i, (s, _) in enumerate(offsets_mapping, start=topic_offset)
204
+ },
205
+ token2char_end={
206
+ str(i): e
207
+ for i, (_, e) in enumerate(offsets_mapping, start=topic_offset)
208
+ },
209
+ )
210
+ )
211
+
212
+ if "window_labels" in relik_sample._d:
213
+ relik_sample.window_labels = [
214
+ (s, e, l.replace("_", " ")) for s, e, l in relik_sample.window_labels
215
+ ]
216
+
217
+
218
+ class TokenizationOutput(NamedTuple):
219
+ input_ids: torch.Tensor
220
+ attention_mask: torch.Tensor
221
+ token_type_ids: torch.Tensor
222
+ prediction_mask: torch.Tensor
223
+ special_symbols_mask: torch.Tensor
224
+
225
+
226
+ class RelikDataset(IterableDataset):
227
+ def __init__(
228
+ self,
229
+ dataset_path: Optional[str],
230
+ materialize_samples: bool,
231
+ transformer_model: Union[str, PreTrainedTokenizer],
232
+ special_symbols: List[str],
233
+ shuffle_candidates: Optional[Union[bool, float]] = False,
234
+ for_inference: bool = False,
235
+ noise_param: float = 0.1,
236
+ sorting_fields: Optional[str] = None,
237
+ tokens_per_batch: int = 2048,
238
+ batch_size: Optional[int] = None,
239
+ max_batch_size: int = 128,
240
+ section_size: int = 50_000,
241
+ prebatch: bool = True,
242
+ random_drop_gold_candidates: float = 0.0,
243
+ use_nme: bool = True,
244
+ max_subwords_per_candidate: int = 22,
245
+ mask_by_instances: bool = False,
246
+ min_length: int = 5,
247
+ max_length: int = 2048,
248
+ model_max_length: int = 1000,
249
+ split_on_cand_overload: bool = True,
250
+ skip_empty_training_samples: bool = False,
251
+ drop_last: bool = False,
252
+ samples: Optional[Iterator[RelikReaderSample]] = None,
253
+ lowercase_policy: float = 0.0,
254
+ **kwargs,
255
+ ):
256
+ super().__init__(**kwargs)
257
+ self.dataset_path = dataset_path
258
+ self.materialize_samples = materialize_samples
259
+ self.samples: Optional[List[RelikReaderSample]] = None
260
+ if self.materialize_samples:
261
+ self.samples = list()
262
+
263
+ if isinstance(transformer_model, str):
264
+ self.tokenizer = self._build_tokenizer(transformer_model, special_symbols)
265
+ else:
266
+ self.tokenizer = transformer_model
267
+ self.special_symbols = special_symbols
268
+ self.shuffle_candidates = shuffle_candidates
269
+ self.for_inference = for_inference
270
+ self.noise_param = noise_param
271
+ self.batching_fields = ["input_ids"]
272
+ self.sorting_fields = (
273
+ sorting_fields if sorting_fields is not None else self.batching_fields
274
+ )
275
+
276
+ self.tokens_per_batch = tokens_per_batch
277
+ self.batch_size = batch_size
278
+ self.max_batch_size = max_batch_size
279
+ self.section_size = section_size
280
+ self.prebatch = prebatch
281
+
282
+ self.random_drop_gold_candidates = random_drop_gold_candidates
283
+ self.use_nme = use_nme
284
+ self.max_subwords_per_candidate = max_subwords_per_candidate
285
+ self.mask_by_instances = mask_by_instances
286
+ self.min_length = min_length
287
+ self.max_length = max_length
288
+ self.model_max_length = (
289
+ model_max_length
290
+ if model_max_length < self.tokenizer.model_max_length
291
+ else self.tokenizer.model_max_length
292
+ )
293
+
294
+ # retrocompatibility workaround
295
+ self.transformer_model = (
296
+ transformer_model
297
+ if isinstance(transformer_model, str)
298
+ else transformer_model.name_or_path
299
+ )
300
+ self.split_on_cand_overload = split_on_cand_overload
301
+ self.skip_empty_training_samples = skip_empty_training_samples
302
+ self.drop_last = drop_last
303
+ self.lowercase_policy = lowercase_policy
304
+ self.samples = samples
305
+
306
+ def _build_tokenizer(self, transformer_model: str, special_symbols: List[str]):
307
+ return AutoTokenizer.from_pretrained(
308
+ transformer_model,
309
+ additional_special_tokens=[ss for ss in special_symbols],
310
+ add_prefix_space=True,
311
+ )
312
+
313
+ @property
314
+ def fields_batcher(self) -> Dict[str, Union[None, Callable[[list], Any]]]:
315
+ fields_batchers = {
316
+ "input_ids": lambda x: batchify(
317
+ x, padding_value=self.tokenizer.pad_token_id
318
+ ),
319
+ "attention_mask": lambda x: batchify(x, padding_value=0),
320
+ "token_type_ids": lambda x: batchify(x, padding_value=0),
321
+ "prediction_mask": lambda x: batchify(x, padding_value=1),
322
+ "global_attention": lambda x: batchify(x, padding_value=0),
323
+ "token2word": None,
324
+ "sample": None,
325
+ "special_symbols_mask": lambda x: batchify(x, padding_value=False),
326
+ "start_labels": lambda x: batchify(x, padding_value=-100),
327
+ "end_labels": lambda x: batchify(x, padding_value=-100),
328
+ "predictable_candidates_symbols": None,
329
+ "predictable_candidates": None,
330
+ "patch_offset": None,
331
+ "optimus_labels": None,
332
+ }
333
+
334
+ if "roberta" in self.transformer_model:
335
+ del fields_batchers["token_type_ids"]
336
+
337
+ return fields_batchers
338
+
339
+ def _build_input_ids(
340
+ self, sentence_input_ids: List[int], candidates_input_ids: List[List[int]]
341
+ ) -> List[int]:
342
+ return (
343
+ [self.tokenizer.cls_token_id]
344
+ + sentence_input_ids
345
+ + [self.tokenizer.sep_token_id]
346
+ + flatten(candidates_input_ids)
347
+ + [self.tokenizer.sep_token_id]
348
+ )
349
+
350
+ def _get_special_symbols_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
351
+ special_symbols_mask = input_ids >= (
352
+ len(self.tokenizer) - len(self.special_symbols)
353
+ )
354
+ special_symbols_mask[0] = True
355
+ return special_symbols_mask
356
+
357
+ def _build_tokenizer_essentials(
358
+ self, input_ids, original_sequence, sample
359
+ ) -> TokenizationOutput:
360
+ input_ids = torch.tensor(input_ids, dtype=torch.long)
361
+ attention_mask = torch.ones_like(input_ids)
362
+
363
+ total_sequence_len = len(input_ids)
364
+ predictable_sentence_len = len(original_sequence)
365
+
366
+ # token type ids
367
+ token_type_ids = torch.cat(
368
+ [
369
+ input_ids.new_zeros(
370
+ predictable_sentence_len + 2
371
+ ), # original sentence bpes + CLS and SEP
372
+ input_ids.new_ones(total_sequence_len - predictable_sentence_len - 2),
373
+ ]
374
+ )
375
+
376
+ # prediction mask -> boolean on tokens that are predictable
377
+
378
+ prediction_mask = torch.tensor(
379
+ [1]
380
+ + ([0] * predictable_sentence_len)
381
+ + ([1] * (total_sequence_len - predictable_sentence_len - 1))
382
+ )
383
+
384
+ # add topic tokens to the prediction mask so that they cannot be predicted
385
+ # or optimized during training
386
+ topic_tokens = getattr(sample, "topic_tokens", None)
387
+ if topic_tokens is not None:
388
+ prediction_mask[1 : 1 + topic_tokens] = 1
389
+
390
+ # If mask by instances is active the prediction mask is applied to everything
391
+ # that is not indicated as an instance in the training set.
392
+ if self.mask_by_instances:
393
+ char_start2token = {
394
+ cs: int(tok) for tok, cs in sample.token2char_start.items()
395
+ }
396
+ char_end2token = {ce: int(tok) for tok, ce in sample.token2char_end.items()}
397
+ instances_mask = torch.ones_like(prediction_mask)
398
+ for _, span_info in sample.instance_id2span_data.items():
399
+ span_info = span_info[0]
400
+ token_start = char_start2token[span_info[0]] + 1 # +1 for the CLS
401
+ token_end = char_end2token[span_info[1]] + 1 # +1 for the CLS
402
+ instances_mask[token_start : token_end + 1] = 0
403
+
404
+ prediction_mask += instances_mask
405
+ prediction_mask[prediction_mask > 1] = 1
406
+
407
+ assert len(prediction_mask) == len(input_ids)
408
+
409
+ # special symbols mask
410
+ special_symbols_mask = self._get_special_symbols_mask(input_ids)
411
+
412
+ return TokenizationOutput(
413
+ input_ids,
414
+ attention_mask,
415
+ token_type_ids,
416
+ prediction_mask,
417
+ special_symbols_mask,
418
+ )
419
+
420
+ def _build_labels(
421
+ self,
422
+ sample,
423
+ tokenization_output: TokenizationOutput,
424
+ predictable_candidates: List[str],
425
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
426
+ start_labels = [0] * len(tokenization_output.input_ids)
427
+ end_labels = [0] * len(tokenization_output.input_ids)
428
+
429
+ char_start2token = {v: int(k) for k, v in sample.token2char_start.items()}
430
+ char_end2token = {v: int(k) for k, v in sample.token2char_end.items()}
431
+ for cs, ce, gold_candidate_title in sample.window_labels:
432
+ if gold_candidate_title not in predictable_candidates:
433
+ if self.use_nme:
434
+ gold_candidate_title = NME_SYMBOL
435
+ else:
436
+ continue
437
+ # +1 is to account for the CLS token
438
+ start_bpe = char_start2token[cs] + 1
439
+ end_bpe = char_end2token[ce] + 1
440
+ class_index = predictable_candidates.index(gold_candidate_title)
441
+ if (
442
+ start_labels[start_bpe] == 0 and end_labels[end_bpe] == 0
443
+ ):  # skip entities that would share a start or end subword with one already assigned
444
+ start_labels[start_bpe] = class_index + 1 # +1 for the NONE class
445
+ end_labels[end_bpe] = class_index + 1 # +1 for the NONE class
446
+ else:
447
+ print(
448
+ "Found entity with the same last subword, it will not be included."
449
+ )
450
+ print(
451
+ cs,
452
+ ce,
453
+ gold_candidate_title,
454
+ start_labels,
455
+ end_labels,
456
+ sample.doc_id,
457
+ )
458
+
459
+ ignored_labels_indices = tokenization_output.prediction_mask == 1
460
+
461
+ start_labels = torch.tensor(start_labels, dtype=torch.long)
462
+ start_labels[ignored_labels_indices] = -100
463
+
464
+ end_labels = torch.tensor(end_labels, dtype=torch.long)
465
+ end_labels[ignored_labels_indices] = -100
466
+
467
+ return start_labels, end_labels
468
+
469
+ def produce_sample_bag(
470
+ self, sample, predictable_candidates: List[str], candidates_starting_offset: int
471
+ ) -> Optional[Tuple[dict, list, int]]:
472
+ # input sentence tokenization
473
+ input_subwords = sample.tokens[1:-1] # removing special tokens
474
+ candidates_symbols = self.special_symbols[candidates_starting_offset:]
475
+
476
+ predictable_candidates = list(predictable_candidates)
477
+ original_predictable_candidates = list(predictable_candidates)
478
+
479
+ # add NME as a possible candidate
480
+ if self.use_nme:
481
+ predictable_candidates.insert(0, NME_SYMBOL)
482
+
483
+ # candidates encoding
484
+ candidates_symbols = candidates_symbols[: len(predictable_candidates)]
485
+ candidates_encoding_result = self.tokenizer.batch_encode_plus(
486
+ [
487
+ "{} {}".format(cs, ct) if ct != NME_SYMBOL else NME_SYMBOL
488
+ for cs, ct in zip(candidates_symbols, predictable_candidates)
489
+ ],
490
+ add_special_tokens=False,
491
+ ).input_ids
492
+
493
+ if (
494
+ self.max_subwords_per_candidate is not None
495
+ and self.max_subwords_per_candidate > 0
496
+ ):
497
+ candidates_encoding_result = [
498
+ cer[: self.max_subwords_per_candidate]
499
+ for cer in candidates_encoding_result
500
+ ]
501
+
502
+ # drop candidates if the number of input tokens is too long for the model
503
+ if (
504
+ sum(map(len, candidates_encoding_result))
505
+ + len(input_subwords)
506
+ + 20 # + 20 special tokens
507
+ > self.model_max_length
508
+ ):
509
+ acceptable_tokens_from_candidates = (
510
+ self.model_max_length - 20 - len(input_subwords)
511
+ )
512
+ i = 0
513
+ cum_len = 0
514
+ while (
515
+ cum_len + len(candidates_encoding_result[i])
516
+ < acceptable_tokens_from_candidates
517
+ ):
518
+ cum_len += len(candidates_encoding_result[i])
519
+ i += 1
520
+
521
+ candidates_encoding_result = candidates_encoding_result[:i]
522
+ candidates_symbols = candidates_symbols[:i]
523
+ predictable_candidates = predictable_candidates[:i]
524
+
525
+ # final input_ids build
526
+ input_ids = self._build_input_ids(
527
+ sentence_input_ids=input_subwords,
528
+ candidates_input_ids=candidates_encoding_result,
529
+ )
530
+
531
+ # complete input building (e.g. attention / prediction mask)
532
+ tokenization_output = self._build_tokenizer_essentials(
533
+ input_ids, input_subwords, sample
534
+ )
535
+
536
+ output_dict = {
537
+ "input_ids": tokenization_output.input_ids,
538
+ "attention_mask": tokenization_output.attention_mask,
539
+ "token_type_ids": tokenization_output.token_type_ids,
540
+ "prediction_mask": tokenization_output.prediction_mask,
541
+ "special_symbols_mask": tokenization_output.special_symbols_mask,
542
+ "sample": sample,
543
+ "predictable_candidates_symbols": candidates_symbols,
544
+ "predictable_candidates": predictable_candidates,
545
+ }
546
+
547
+ # labels creation
548
+ if sample.window_labels is not None:
549
+ start_labels, end_labels = self._build_labels(
550
+ sample,
551
+ tokenization_output,
552
+ predictable_candidates,
553
+ )
554
+ output_dict.update(start_labels=start_labels, end_labels=end_labels)
555
+
556
+ if (
557
+ "roberta" in self.transformer_model
558
+ or "longformer" in self.transformer_model
559
+ ):
560
+ del output_dict["token_type_ids"]
561
+
562
+ predictable_candidates_set = set(predictable_candidates)
563
+ remaining_candidates = [
564
+ candidate
565
+ for candidate in original_predictable_candidates
566
+ if candidate not in predictable_candidates_set
567
+ ]
568
+ total_used_candidates = (
569
+ candidates_starting_offset
570
+ + len(predictable_candidates)
571
+ - (1 if self.use_nme else 0)
572
+ )
573
+
574
+ if self.use_nme:
575
+ assert predictable_candidates[0] == NME_SYMBOL
576
+
577
+ return output_dict, remaining_candidates, total_used_candidates
578
+
579
+ def __iter__(self):
580
+ dataset_iterator = self.dataset_iterator_func()
581
+
582
+ current_dataset_elements = []
583
+
584
+ i = None
585
+ for i, dataset_elem in enumerate(dataset_iterator, start=1):
586
+ if (
587
+ self.section_size is not None
588
+ and len(current_dataset_elements) == self.section_size
589
+ ):
590
+ for batch in self.materialize_batches(current_dataset_elements):
591
+ yield batch
592
+ current_dataset_elements = []
593
+
594
+ current_dataset_elements.append(dataset_elem)
595
+
596
+ if i % 50_000 == 0:
597
+ logger.info(f"Processed: {i} number of elements")
598
+
599
+ if len(current_dataset_elements) != 0:
600
+ for batch in self.materialize_batches(current_dataset_elements):
601
+ yield batch
602
+
603
+ if i is not None:
604
+ logger.info(f"Dataset finished: {i} number of elements processed")
605
+ else:
606
+ logger.warning("Dataset empty")
607
+
608
+ def dataset_iterator_func(self):
609
+ skipped_instances = 0
610
+ data_samples = (
611
+ load_relik_reader_samples(self.dataset_path)
612
+ if self.samples is None
613
+ else self.samples
614
+ )
615
+ for sample in data_samples:
616
+ preprocess_sample(
617
+ sample, self.tokenizer, lowercase_policy=self.lowercase_policy
618
+ )
619
+ current_patch = 0
620
+ sample_bag, used_candidates = None, None
621
+ remaining_candidates = list(sample.window_candidates)
622
+
623
+ if not self.for_inference:
624
+ # randomly drop gold candidates at training time
625
+ if (
626
+ self.random_drop_gold_candidates > 0.0
627
+ and np.random.uniform() < self.random_drop_gold_candidates
628
+ and len(set(ct for _, _, ct in sample.window_labels)) > 1
629
+ ):
630
+ # selecting candidates to drop
631
+ np.random.shuffle(sample.window_labels)
632
+ n_dropped_candidates = np.random.randint(
633
+ 0, len(sample.window_labels) - 1
634
+ )
635
+ dropped_candidates = [
636
+ label_elem[-1]
637
+ for label_elem in sample.window_labels[:n_dropped_candidates]
638
+ ]
639
+ dropped_candidates = set(dropped_candidates)
640
+
641
+ # saving NMEs because they should not be dropped
642
+ if NME_SYMBOL in dropped_candidates:
643
+ dropped_candidates.remove(NME_SYMBOL)
644
+
645
+ # sample update
646
+ sample.window_labels = [
647
+ (s, e, _l)
648
+ if _l not in dropped_candidates
649
+ else (s, e, NME_SYMBOL)
650
+ for s, e, _l in sample.window_labels
651
+ ]
652
+ remaining_candidates = [
653
+ wc
654
+ for wc in remaining_candidates
655
+ if wc not in dropped_candidates
656
+ ]
657
+
658
+ # shuffle candidates
659
+ if (
660
+ isinstance(self.shuffle_candidates, bool)
661
+ and self.shuffle_candidates
662
+ ) or (
663
+ isinstance(self.shuffle_candidates, float)
664
+ and np.random.uniform() < self.shuffle_candidates
665
+ ):
666
+ np.random.shuffle(remaining_candidates)
667
+
668
+ while len(remaining_candidates) != 0:
669
+ sample_bag = self.produce_sample_bag(
670
+ sample,
671
+ predictable_candidates=remaining_candidates,
672
+ candidates_starting_offset=used_candidates
673
+ if used_candidates is not None
674
+ else 0,
675
+ )
676
+ if sample_bag is not None:
677
+ sample_bag, remaining_candidates, used_candidates = sample_bag
678
+ if (
679
+ self.for_inference
680
+ or not self.skip_empty_training_samples
681
+ or (
682
+ (
683
+ sample_bag.get("start_labels") is not None
684
+ and torch.any(sample_bag["start_labels"] > 1).item()
685
+ )
686
+ or (
687
+ sample_bag.get("optimus_labels") is not None
688
+ and len(sample_bag["optimus_labels"]) > 0
689
+ )
690
+ )
691
+ ):
692
+ sample_bag["patch_offset"] = current_patch
693
+ current_patch += 1
694
+ yield sample_bag
695
+ else:
696
+ skipped_instances += 1
697
+ if skipped_instances % 1000 == 0 and skipped_instances != 0:
698
+ logger.info(
699
+ f"Skipped {skipped_instances} instances since they did not have any gold labels..."
700
+ )
701
+
702
+ # if splitting on candidate overload is disabled, keep only
703
+ # the first batch of candidates that fits
704
+ if not self.split_on_cand_overload:
705
+ break
706
+
707
+ def preshuffle_elements(self, dataset_elements: List):
708
+ # This shuffling is done so that when using the sorting function,
709
+ # if it is deterministic given a collection and its order, we will
710
+ # make the whole operation not deterministic anymore.
711
+ # Basically, the aim is not to build every time the same batches.
712
+ if not self.for_inference:
713
+ dataset_elements = np.random.permutation(dataset_elements)
714
+
715
+ sorting_fn = (
716
+ lambda elem: add_noise_to_value(
717
+ sum(len(elem[k]) for k in self.sorting_fields),
718
+ noise_param=self.noise_param,
719
+ )
720
+ if not self.for_inference
721
+ else sum(len(elem[k]) for k in self.sorting_fields)
722
+ )
723
+
724
+ dataset_elements = sorted(dataset_elements, key=sorting_fn)
725
+
726
+ if self.for_inference:
727
+ return dataset_elements
728
+
729
+ ds = list(chunks(dataset_elements, 64))
730
+ np.random.shuffle(ds)
731
+ return flatten(ds)
732
+
733
+ def materialize_batches(
734
+ self, dataset_elements: List[Dict[str, Any]]
735
+ ) -> Generator[Dict[str, Any], None, None]:
736
+ if self.prebatch:
737
+ dataset_elements = self.preshuffle_elements(dataset_elements)
738
+
739
+ current_batch = []
740
+
741
+ # function that creates a batch from the 'current_batch' list
742
+ def output_batch() -> Dict[str, Any]:
743
+ assert (
744
+ len(
745
+ set([len(elem["predictable_candidates"]) for elem in current_batch])
746
+ )
747
+ == 1
748
+ ), " ".join(
749
+ map(
750
+ str, [len(elem["predictable_candidates"]) for elem in current_batch]
751
+ )
752
+ )
753
+
754
+ batch_dict = dict()
755
+
756
+ de_values_by_field = {
757
+ fn: [de[fn] for de in current_batch if fn in de]
758
+ for fn in self.fields_batcher
759
+ }
760
+
761
+ # in case you provide fields batchers but in the batch
762
+ # there are no elements for that field
763
+ de_values_by_field = {
764
+ fn: fvs for fn, fvs in de_values_by_field.items() if len(fvs) > 0
765
+ }
766
+
767
+ # every field present in the batch must cover the same number of elements
+ assert len(set(len(v) for v in de_values_by_field.values())) == 1
768
+
769
+ # todo: maybe we should warn the user about possible
770
+ # fields filtering due to "None" instances
771
+ de_values_by_field = {
772
+ fn: fvs
773
+ for fn, fvs in de_values_by_field.items()
774
+ if all([fv is not None for fv in fvs])
775
+ }
776
+
777
+ for field_name, field_values in de_values_by_field.items():
778
+ field_batch = (
779
+ self.fields_batcher[field_name](field_values)
780
+ if self.fields_batcher[field_name] is not None
781
+ else field_values
782
+ )
783
+
784
+ batch_dict[field_name] = field_batch
785
+
786
+ return batch_dict
787
+
788
+ max_len_discards, min_len_discards = 0, 0
789
+
790
+ should_token_batch = self.batch_size is None
791
+
792
+ curr_pred_elements = -1
793
+ for de in dataset_elements:
794
+ if (
795
+ should_token_batch
796
+ and self.max_batch_size != -1
797
+ and len(current_batch) == self.max_batch_size
798
+ ) or (not should_token_batch and len(current_batch) == self.batch_size):
799
+ yield output_batch()
800
+ current_batch = []
801
+ curr_pred_elements = -1
802
+
803
+ too_long_fields = [
804
+ k
805
+ for k in de
806
+ if self.max_length != -1
807
+ and torch.is_tensor(de[k])
808
+ and len(de[k]) > self.max_length
809
+ ]
810
+ if len(too_long_fields) > 0:
811
+ max_len_discards += 1
812
+ continue
813
+
814
+ too_short_fields = [
815
+ k
816
+ for k in de
817
+ if self.min_length != -1
818
+ and torch.is_tensor(de[k])
819
+ and len(de[k]) < self.min_length
820
+ ]
821
+ if len(too_short_fields) > 0:
822
+ min_len_discards += 1
823
+ continue
824
+
825
+ if should_token_batch:
826
+ de_len = sum(len(de[k]) for k in self.batching_fields)
827
+
828
+ future_max_len = max(
829
+ de_len,
830
+ max(
831
+ [
832
+ sum(len(bde[k]) for k in self.batching_fields)
833
+ for bde in current_batch
834
+ ],
835
+ default=0,
836
+ ),
837
+ )
838
+
839
+ future_tokens_per_batch = future_max_len * (len(current_batch) + 1)
840
+
841
+ num_predictable_candidates = len(de["predictable_candidates"])
842
+
843
+ if len(current_batch) > 0 and (
844
+ future_tokens_per_batch >= self.tokens_per_batch
845
+ or (
846
+ num_predictable_candidates != curr_pred_elements
847
+ and curr_pred_elements != -1
848
+ )
849
+ ):
850
+ yield output_batch()
851
+ current_batch = []
852
+
853
+ current_batch.append(de)
854
+ curr_pred_elements = len(de["predictable_candidates"])
855
+
856
+ if len(current_batch) != 0 and not self.drop_last:
857
+ yield output_batch()
858
+
+     if max_len_discards > 0:
+         if self.for_inference:
+             logger.warning(
+                 f"Inference mode is True but {max_len_discards} samples longer than max length "
+                 "were found. These samples will be DISCARDED. If you are running an evaluation, "
+                 "this can INVALIDATE its results. This might happen if max length was not set "
+                 "to -1 or if the sample length exceeds the maximum length supported by the "
+                 "current model."
+             )
+         else:
+             logger.warning(
+                 f"During iteration, {max_len_discards} elements were "
+                 f"discarded since they were longer than max length {self.max_length}"
+             )
+
+     if min_len_discards > 0:
+         if self.for_inference:
+             logger.warning(
+                 f"Inference mode is True but {min_len_discards} samples shorter than min length "
+                 "were found. These samples will be DISCARDED. If you are running an evaluation, "
+                 "this can INVALIDATE its results. This might happen if min length was not set "
+                 "to -1 or if the sample length is shorter than the minimum length supported by "
+                 "the current model."
+             )
+         else:
+             logger.warning(
+                 f"During iteration, {min_len_discards} elements were "
+                 f"discarded since they were shorter than min length {self.min_length}"
+             )
+
+ @staticmethod
+ def convert_tokens_to_char_annotations(
+     sample: RelikReaderSample,
+     remove_nmes: bool = True,
+ ) -> RelikReaderSample:
+     """
+     Converts the token-level annotations to character-level annotations.
+
+     Args:
+         sample (:obj:`RelikReaderSample`):
+             The sample to convert.
+         remove_nmes (:obj:`bool`, `optional`, defaults to :obj:`True`):
+             Whether to remove the NMEs from the annotations.
+
+     Returns:
+         :obj:`RelikReaderSample`: The converted sample.
+     """
+     char_annotations = set()
+     for (
+         predicted_entity,
+         predicted_spans,
+     ) in sample.predicted_window_labels.items():
+         if predicted_entity == NME_SYMBOL and remove_nmes:
+             continue
+
+         for span_start, span_end in predicted_spans:
+             span_start = sample.token2char_start[str(span_start)]
+             span_end = sample.token2char_end[str(span_end)]
+
+             char_annotations.add((span_start, span_end, predicted_entity))
+
+     # note: only the candidate titles are kept here, not their probabilities
+     char_probs_annotations = dict()
+     for (
+         span_start,
+         span_end,
+     ), candidates_probs in sample.span_title_probabilities.items():
+         span_start = sample.token2char_start[str(span_start)]
+         span_end = sample.token2char_end[str(span_end)]
+         char_probs_annotations[(span_start, span_end)] = {
+             title for title, _ in candidates_probs
+         }
+
+     sample.predicted_window_labels_chars = char_annotations
+     sample.probs_window_labels_chars = char_probs_annotations
+
+     return sample
+
+ @staticmethod
+ def merge_patches_predictions(sample) -> None:
+     sample._d["predicted_window_labels"] = dict()
+     predicted_window_labels = sample._d["predicted_window_labels"]
+
+     sample._d["span_title_probabilities"] = dict()
+     span_title_probabilities = sample._d["span_title_probabilities"]
+
+     span2title = dict()
+     for _, patch_info in sorted(sample.patches.items(), key=lambda x: x[0]):
+         # select span predictions
+         for predicted_title, predicted_spans in patch_info[
+             "predicted_window_labels"
+         ].items():
+             for pred_span in predicted_spans:
+                 pred_span = tuple(pred_span)
+                 curr_title = span2title.get(pred_span)
+                 if curr_title is None or curr_title == NME_SYMBOL:
+                     span2title[pred_span] = predicted_title
+                 # else:
+                 #     print("Merging at patch level")
+
+         # select span prediction probabilities
+         for predicted_span, titles_probabilities in patch_info[
+             "span_title_probabilities"
+         ].items():
+             if predicted_span not in span_title_probabilities:
+                 span_title_probabilities[predicted_span] = titles_probabilities
+
+     for span, title in span2title.items():
+         if title not in predicted_window_labels:
+             predicted_window_labels[title] = list()
+         predicted_window_labels[title].append(span)
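The token-batching path earlier in this file flushes a batch under two conditions: when padding every element to the longest sequence seen so far would blow the token budget, and when the number of predictable candidates changes between elements. Below is a minimal, self-contained sketch of just the budget rule; the function name and toy inputs are illustrative, not part of relik.

from typing import Iterable, List


def token_budget_batches(
    lengths: Iterable[int], tokens_per_batch: int
) -> List[List[int]]:
    # mirrors the future_max_len * (len(current_batch) + 1) >= tokens_per_batch
    # check above, using plain sequence lengths instead of dataset elements
    batches: List[List[int]] = []
    current: List[int] = []
    for length in lengths:
        future_max_len = max([length] + current)
        if current and future_max_len * (len(current) + 1) >= tokens_per_batch:
            batches.append(current)
            current = []
        current.append(length)
    if current:
        batches.append(current)
    return batches


print(token_budget_batches([10, 12, 50, 8, 9], tokens_per_batch=64))
# [[10, 12], [50], [8, 9]]: the 50-token element forces two flushes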
relik/reader/data/relik_reader_data_utils.py ADDED
@@ -0,0 +1,51 @@
+ from typing import List
+
+ import numpy as np
+ import torch
+
+
+ def flatten(lsts: List[list]) -> list:
+     acc_lst = list()
+     for lst in lsts:
+         acc_lst.extend(lst)
+     return acc_lst
+
+
+ def batchify(tensors: List[torch.Tensor], padding_value: int = 0) -> torch.Tensor:
+     return torch.nn.utils.rnn.pad_sequence(
+         tensors, batch_first=True, padding_value=padding_value
+     )
+
+
+ def batchify_matrices(tensors: List[torch.Tensor], padding_value: int) -> torch.Tensor:
+     x = max([t.shape[0] for t in tensors])
+     y = max([t.shape[1] for t in tensors])
+     out_matrix = torch.zeros((len(tensors), x, y))
+     out_matrix += padding_value
+     for i, tensor in enumerate(tensors):
+         out_matrix[i][0 : tensor.shape[0], 0 : tensor.shape[1]] = tensor
+     return out_matrix
+
+
+ def batchify_tensor(tensors: List[torch.Tensor], padding_value: int) -> torch.Tensor:
+     x = max([t.shape[0] for t in tensors])
+     y = max([t.shape[1] for t in tensors])
+     rest = tensors[0].shape[2]
+     out_matrix = torch.zeros((len(tensors), x, y, rest))
+     out_matrix += padding_value
+     for i, tensor in enumerate(tensors):
+         out_matrix[i][0 : tensor.shape[0], 0 : tensor.shape[1], :] = tensor
+     return out_matrix
+
+
+ def chunks(lst: list, chunk_size: int) -> List[list]:
+     chunks_acc = list()
+     for i in range(0, len(lst), chunk_size):
+         chunks_acc.append(lst[i : i + chunk_size])
+     return chunks_acc
+
+
+ def add_noise_to_value(value: int, noise_param: float):
+     noise_value = value * noise_param
+     noise = np.random.uniform(-noise_value, noise_value)
+     return max(1, value + noise)
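As a quick reference, the padding helpers above behave as follows on toy tensors (assuming the module is importable from the path shown in this commit):

import torch

from relik.reader.data.relik_reader_data_utils import (
    batchify,
    batchify_matrices,
    chunks,
)

# 1-D tensors are right-padded to the longest one
print(batchify([torch.ones(3), torch.ones(5)], padding_value=-100).shape)
# torch.Size([2, 5])

# 2-D tensors are padded independently along both dimensions
print(batchify_matrices([torch.ones(2, 3), torch.ones(4, 2)], padding_value=0).shape)
# torch.Size([2, 4, 3])

# lists are split into fixed-size chunks, with a possibly shorter tail
print(chunks(list(range(7)), chunk_size=3))
# [[0, 1, 2], [3, 4, 5], [6]]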
relik/reader/data/relik_reader_sample.py ADDED
@@ -0,0 +1,49 @@
+ import json
+ from typing import Iterable
+
+
+ class RelikReaderSample:
+     def __init__(self, **kwargs):
+         super().__setattr__("_d", {})
+         self._d = kwargs
+
+     def __getattribute__(self, item):
+         return super(RelikReaderSample, self).__getattribute__(item)
+
+     def __getattr__(self, item):
+         if item.startswith("__") and item.endswith("__"):
+             # this is likely some python library-specific variable (such as __deepcopy__ for copy)
+             # better follow standard behavior here
+             raise AttributeError(item)
+         elif item in self._d:
+             return self._d[item]
+         else:
+             return None
+
+     def __setattr__(self, key, value):
+         if key in self._d:
+             self._d[key] = value
+         else:
+             super().__setattr__(key, value)
+
+     def to_jsons(self) -> str:
+         if "predicted_window_labels" in self._d:
+             new_obj = {
+                 k: v
+                 for k, v in self._d.items()
+                 if k != "predicted_window_labels" and k != "span_title_probabilities"
+             }
+             # serialize char-level predictions as [start, end, title] triples
+             new_obj["predicted_window_labels"] = [
+                 [ss, se, pred_title]
+                 for ss, se, pred_title in self.predicted_window_labels_chars
+             ]
+             return json.dumps(new_obj)
+         else:
+             return json.dumps(self._d)
+
+
+ def load_relik_reader_samples(path: str) -> Iterable[RelikReaderSample]:
+     with open(path) as f:
+         for line in f:
+             jsonl_line = json.loads(line.strip())
+             relik_reader_sample = RelikReaderSample(**jsonl_line)
+             yield relik_reader_sample
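RelikReaderSample proxies attribute access onto the backing _d dict: fields passed to the constructor read and write through the dict, while unknown fields return None instead of raising. A small usage sketch (the field names here are arbitrary):

from relik.reader.data.relik_reader_sample import RelikReaderSample

sample = RelikReaderSample(doc_id=1, text="Rome is in Italy")
print(sample.text)     # "Rome is in Italy", read from the backing dict
print(sample.missing)  # None: absent fields do not raise AttributeError
sample.text = "updated text"  # existing keys are written back into the dict
print(sample.to_jsons())      # JSON string of the underlying fields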
relik/reader/lightning_modules/__init__.py ADDED
File without changes
relik/reader/lightning_modules/relik_reader_pl_module.py ADDED
@@ -0,0 +1,50 @@
+ from typing import Any, Optional
+
+ import lightning
+ from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler
+
+ from relik.reader.relik_reader_core import RelikReaderCoreModel
+
+
+ class RelikReaderPLModule(lightning.LightningModule):
+     def __init__(
+         self,
+         cfg: dict,
+         transformer_model: str,
+         additional_special_symbols: int,
+         num_layers: Optional[int] = None,
+         activation: str = "gelu",
+         linears_hidden_size: Optional[int] = 512,
+         use_last_k_layers: int = 1,
+         training: bool = False,
+         *args: Any,
+         **kwargs: Any
+     ):
+         super().__init__(*args, **kwargs)
+         self.save_hyperparameters()
+         self.relik_reader_core_model = RelikReaderCoreModel(
+             transformer_model,
+             additional_special_symbols,
+             num_layers,
+             activation,
+             linears_hidden_size,
+             use_last_k_layers,
+             training=training,
+         )
+         self.optimizer_factory = None
+
+     def training_step(self, batch: dict, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+         relik_output = self.relik_reader_core_model(**batch)
+         self.log("train-loss", relik_output["loss"])
+         return relik_output["loss"]
+
+     def validation_step(
+         self, batch: dict, *args: Any, **kwargs: Any
+     ) -> Optional[STEP_OUTPUT]:
+         return
+
+     def set_optimizer_factory(self, optimizer_factory) -> None:
+         self.optimizer_factory = optimizer_factory
+
+     def configure_optimizers(self) -> OptimizerLRScheduler:
+         return self.optimizer_factory(self.relik_reader_core_model)
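The module defers optimizer construction to an injected factory, so configure_optimizers() only works once set_optimizer_factory() has been called. A minimal wiring sketch; the checkpoint name, special-symbol count, and AdamW hyperparameters are placeholders, not the project's defaults:

import torch

from relik.reader.lightning_modules.relik_reader_pl_module import RelikReaderPLModule

module = RelikReaderPLModule(
    cfg={},
    transformer_model="microsoft/deberta-v3-base",  # placeholder checkpoint
    additional_special_symbols=101,  # placeholder count
    training=True,
)
module.set_optimizer_factory(
    lambda model: torch.optim.AdamW(model.parameters(), lr=1e-5)
)
# Lightning now obtains the optimizer through configure_optimizers(),
# which simply delegates to the injected factory.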
relik/reader/lightning_modules/relik_reader_re_pl_module.py ADDED
@@ -0,0 +1,54 @@
+ from typing import Any, Optional
+
+ import lightning
+ from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler
+
+ from relik.reader.relik_reader_re import RelikReaderForTripletExtraction
+
+
+ class RelikReaderREPLModule(lightning.LightningModule):
+     def __init__(
+         self,
+         cfg: dict,
+         transformer_model: str,
+         additional_special_symbols: int,
+         num_layers: Optional[int] = None,
+         activation: str = "gelu",
+         linears_hidden_size: Optional[int] = 512,
+         use_last_k_layers: int = 1,
+         training: bool = False,
+         *args: Any,
+         **kwargs: Any
+     ):
+         super().__init__(*args, **kwargs)
+         self.save_hyperparameters()
+
+         self.relik_reader_re_model = RelikReaderForTripletExtraction(
+             transformer_model,
+             additional_special_symbols,
+             num_layers,
+             activation,
+             linears_hidden_size,
+             use_last_k_layers,
+             training=training,
+         )
+         self.optimizer_factory = None
+
+     def training_step(self, batch: dict, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+         relik_output = self.relik_reader_re_model(**batch)
+         self.log("train-loss", relik_output["loss"])
+         self.log("train-start_loss", relik_output["ned_start_loss"])
+         self.log("train-end_loss", relik_output["ned_end_loss"])
+         self.log("train-relation_loss", relik_output["re_loss"])
+         return relik_output["loss"]
+
+     def validation_step(
+         self, batch: dict, *args: Any, **kwargs: Any
+     ) -> Optional[STEP_OUTPUT]:
+         return
+
+     def set_optimizer_factory(self, optimizer_factory) -> None:
+         self.optimizer_factory = optimizer_factory
+
+     def configure_optimizers(self) -> OptimizerLRScheduler:
+         return self.optimizer_factory(self.relik_reader_re_model)
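This RE variant mirrors the module above but also logs the component losses of its wrapped model, so training_step implicitly assumes the model's forward returns a dict containing at least "loss", "ned_start_loss", "ned_end_loss", and "re_loss". A hypothetical stub makes that contract explicit; it is for illustration only and not part of relik:

import torch


class StubTripletModel(torch.nn.Module):
    # hypothetical stand-in for RelikReaderForTripletExtraction, showing only
    # the output keys that RelikReaderREPLModule.training_step reads
    def forward(self, **batch):
        zero = torch.zeros((), requires_grad=True)
        return {
            "loss": zero,
            "ned_start_loss": zero,
            "ned_end_loss": zero,
            "re_loss": zero,
        }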
relik/reader/pytorch_modules/__init__.py ADDED
File without changes