diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..dbadf32bbad592c5933aa57b988953c1bbaa5e4b
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,332 @@
+# custom
+
+data/*
+experiments/*
+retrievers
+outputs
+model
+wandb
+
+# Created by https://www.toptal.com/developers/gitignore/api/jetbrains+all,vscode,python,jupyternotebooks,linux,windows,macos
+# Edit at https://www.toptal.com/developers/gitignore?templates=jetbrains+all,vscode,python,jupyternotebooks,linux,windows,macos
+
+### JetBrains+all ###
+# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
+# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
+
+# User-specific stuff
+.idea/**/workspace.xml
+.idea/**/tasks.xml
+.idea/**/usage.statistics.xml
+.idea/**/dictionaries
+.idea/**/shelf
+
+# Generated files
+.idea/**/contentModel.xml
+
+# Sensitive or high-churn files
+.idea/**/dataSources/
+.idea/**/dataSources.ids
+.idea/**/dataSources.local.xml
+.idea/**/sqlDataSources.xml
+.idea/**/dynamic.xml
+.idea/**/uiDesigner.xml
+.idea/**/dbnavigator.xml
+
+# Gradle
+.idea/**/gradle.xml
+.idea/**/libraries
+
+# Gradle and Maven with auto-import
+# When using Gradle or Maven with auto-import, you should exclude module files,
+# since they will be recreated, and may cause churn. Uncomment if using
+# auto-import.
+# .idea/artifacts
+# .idea/compiler.xml
+# .idea/jarRepositories.xml
+# .idea/modules.xml
+# .idea/*.iml
+# .idea/modules
+# *.iml
+# *.ipr
+
+# CMake
+cmake-build-*/
+
+# Mongo Explorer plugin
+.idea/**/mongoSettings.xml
+
+# File-based project format
+*.iws
+
+# IntelliJ
+out/
+
+# mpeltonen/sbt-idea plugin
+.idea_modules/
+
+# JIRA plugin
+atlassian-ide-plugin.xml
+
+# Cursive Clojure plugin
+.idea/replstate.xml
+
+# Crashlytics plugin (for Android Studio and IntelliJ)
+com_crashlytics_export_strings.xml
+crashlytics.properties
+crashlytics-build.properties
+fabric.properties
+
+# Editor-based Rest Client
+.idea/httpRequests
+
+# Android studio 3.1+ serialized cache file
+.idea/caches/build_file_checksums.ser
+
+### JetBrains+all Patch ###
+# Ignores the whole .idea folder and all .iml files
+# See https://github.com/joeblau/gitignore.io/issues/186 and https://github.com/joeblau/gitignore.io/issues/360
+
+.idea/
+
+# Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-249601023
+
+*.iml
+modules.xml
+.idea/misc.xml
+*.ipr
+
+# Sonarlint plugin
+.idea/sonarlint
+
+### JupyterNotebooks ###
+# gitignore template for Jupyter Notebooks
+# website: http://jupyter.org/
+
+.ipynb_checkpoints
+*/.ipynb_checkpoints/*
+
+# IPython
+profile_default/
+ipython_config.py
+
+# Remove previous ipynb_checkpoints
+# git rm -r .ipynb_checkpoints/
+
+### Linux ###
+*~
+
+# temporary files which can be created if a process still has a handle open of a deleted file
+.fuse_hidden*
+
+# KDE directory preferences
+.directory
+
+# Linux trash folder which might appear on any partition or disk
+.Trash-*
+
+# .nfs files are created when an open file is removed but is still being accessed
+.nfs*
+
+### macOS ###
+# General
+.DS_Store
+.AppleDouble
+.LSOverride
+
+# Icon must end with two \r
+Icon
+
+
+# Thumbnails
+._*
+
+# Files that might appear in the root of a volume
+.DocumentRevisions-V100
+.fseventsd
+.Spotlight-V100
+.TemporaryItems
+.Trashes
+.VolumeIcon.icns
+.com.apple.timemachine.donotpresent
+
+# Directories potentially created on remote AFP share
+.AppleDB
+.AppleDesktop
+Network Trash Folder
+Temporary Items
+.apdisk
+
+### Python ###
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+pytestdebug.log
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+doc/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+
+# IPython
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+pythonenv*
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# profiling data
+.prof
+
+### vscode ###
+.vscode
+.vscode/*
+!.vscode/settings.json
+!.vscode/tasks.json
+!.vscode/launch.json
+!.vscode/extensions.json
+*.code-workspace
+
+### Windows ###
+# Windows thumbnail cache files
+Thumbs.db
+Thumbs.db:encryptable
+ehthumbs.db
+ehthumbs_vista.db
+
+# Dump file
+*.stackdump
+
+# Folder config file
+[Dd]esktop.ini
+
+# Recycle Bin used on file shares
+$RECYCLE.BIN/
+
+# Windows Installer files
+*.cab
+*.msi
+*.msix
+*.msm
+*.msp
+
+# Windows shortcuts
+*.lnk
+
+# End of https://www.toptal.com/developers/gitignore/api/jetbrains+all,vscode,python,jupyternotebooks,linux,windows,macos
\ No newline at end of file
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 0000000000000000000000000000000000000000..f9bd1455b374de796e12d240c1211dee9829d97e
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1 @@
+include requirements.txt
diff --git a/README.md b/README.md
index 41e51eae2b1a8e039beb534af184753a672370eb..bc6e732d1aee0ed57c8c697ecc7364d35dea3929 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1 @@
----
-title: Relik
-emoji: 🐨
-colorFrom: gray
-colorTo: pink
-sdk: streamlit
-sdk_version: 1.27.2
-app_file: app.py
-pinned: false
-license: mit
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# relik
\ No newline at end of file
diff --git a/SETUP.cfg b/SETUP.cfg
new file mode 100644
index 0000000000000000000000000000000000000000..4b4b70f561fe7171daf99bed49ad08fed98aedb6
--- /dev/null
+++ b/SETUP.cfg
@@ -0,0 +1,8 @@
+[metadata]
+description-file = README.md
+
+[build]
+build-base = /tmp/build
+
+[egg_info]
+egg-base = /tmp
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..67f3c58aab567cdb797214d7d886db81b48d5584
--- /dev/null
+++ b/app.py
@@ -0,0 +1,245 @@
+import os
+import re
+import time
+from pathlib import Path
+
+import requests
+import streamlit as st
+from spacy import displacy
+from streamlit_extras.badges import badge
+from streamlit_extras.stylable_container import stylable_container
+
+# RELIK = os.getenv("RELIK", "localhost:8000/api/entities")
+
+import random
+
+from relik.inference.annotator import Relik
+
+
+def get_random_color(ents):
+ colors = {}
+ random_colors = generate_pastel_colors(len(ents))
+ for ent in ents:
+ colors[ent] = random_colors.pop(random.randint(0, len(random_colors) - 1))
+ return colors
+
+
+def floatrange(start, stop, steps):
+ if int(steps) == 1:
+ return [stop]
+ return [
+ start + float(i) * (stop - start) / (float(steps) - 1) for i in range(steps)
+ ]
+
+
+def hsl_to_rgb(h, s, l):
+ def hue_2_rgb(v1, v2, v_h):
+ while v_h < 0.0:
+ v_h += 1.0
+ while v_h > 1.0:
+ v_h -= 1.0
+ if 6 * v_h < 1.0:
+ return v1 + (v2 - v1) * 6.0 * v_h
+ if 2 * v_h < 1.0:
+ return v2
+ if 3 * v_h < 2.0:
+ return v1 + (v2 - v1) * ((2.0 / 3.0) - v_h) * 6.0
+ return v1
+
+ # if not (0 <= s <= 1): raise ValueError, "s (saturation) parameter must be between 0 and 1."
+ # if not (0 <= l <= 1): raise ValueError, "l (lightness) parameter must be between 0 and 1."
+
+    r, g, b = (l * 255,) * 3
+ if s != 0.0:
+ if l < 0.5:
+ var_2 = l * (1.0 + s)
+ else:
+ var_2 = (l + s) - (s * l)
+ var_1 = 2.0 * l - var_2
+ r = 255 * hue_2_rgb(var_1, var_2, h + (1.0 / 3.0))
+ g = 255 * hue_2_rgb(var_1, var_2, h)
+ b = 255 * hue_2_rgb(var_1, var_2, h - (1.0 / 3.0))
+
+ return int(round(r)), int(round(g)), int(round(b))
+
+
+def generate_pastel_colors(n):
+ """Return different pastel colours.
+
+ Input:
+ n (integer) : The number of colors to return
+
+ Output:
+        A list of colors in HTML notation (e.g. ['#cce0ff', '#ffcccc', '#ccffe0', '#f5ccff', '#f5ffcc'])
+
+ Example:
+        >>> print(generate_pastel_colors(5))
+ ['#cce0ff', '#f5ccff', '#ffcccc', '#f5ffcc', '#ccffe0']
+ """
+ if n == 0:
+ return []
+
+ # To generate colors, we use the HSL colorspace (see http://en.wikipedia.org/wiki/HSL_color_space)
+ start_hue = 0.6 # 0=red 1/3=0.333=green 2/3=0.666=blue
+ saturation = 1.0
+ lightness = 0.8
+ # We take points around the chromatic circle (hue):
+ # (Note: we generate n+1 colors, then drop the last one ([:-1]) because
+ # it equals the first one (hue 0 = hue 1))
+ return [
+ "#%02x%02x%02x" % hsl_to_rgb(hue, saturation, lightness)
+ for hue in floatrange(start_hue, start_hue + 1, n + 1)
+ ][:-1]
+
+
+def set_sidebar(css):
+ white_link_wrapper = "{}"
+ with st.sidebar:
+ st.markdown(f"", unsafe_allow_html=True)
+ st.image(
+ "http://nlp.uniroma1.it/static/website/sapienza-nlp-logo-wh.svg",
+ use_column_width=True,
+ )
+ st.markdown("## ReLiK")
+ st.write(
+ f"""
+ - {white_link_wrapper.format("#", " Paper")}
+ - {white_link_wrapper.format("https://github.com/SapienzaNLP/relik", " GitHub")}
+ - {white_link_wrapper.format("https://hub.docker.com/repository/docker/sapienzanlp/relik", " Docker Hub")}
+ """,
+ unsafe_allow_html=True,
+ )
+ st.markdown("## Sapienza NLP")
+ st.write(
+ f"""
+ - {white_link_wrapper.format("https://nlp.uniroma1.it", " Webpage")}
+ - {white_link_wrapper.format("https://github.com/SapienzaNLP", " GitHub")}
+ - {white_link_wrapper.format("https://twitter.com/SapienzaNLP", " Twitter")}
+ - {white_link_wrapper.format("https://www.linkedin.com/company/79434450", " LinkedIn")}
+ """,
+ unsafe_allow_html=True,
+ )
+
+
+def get_el_annotations(response):
+    # displacy's manual mode expects each entity as a dict with start/end/label keys,
+    # so convert the EntitySpan predictions before rendering
+    ents = [
+        {"start": l.start, "end": l.end, "label": l.label} for l in response.labels
+    ]
+    dict_of_ents = {"text": response.text, "ents": ents}
+    label_in_text = set(l["label"] for l in ents)
+    options = {"ents": label_in_text, "colors": get_random_color(label_in_text)}
+    return dict_of_ents, options
+
+
+def set_intro(css):
+ # intro
+ st.markdown("# ReLik")
+ st.markdown(
+ "### Retrieve, Read and LinK: Fast and Accurate Entity Linking and Relation Extraction on an Academic Budget"
+ )
+ # st.markdown(
+ # "This is a front-end for the paper [Universal Semantic Annotator: the First Unified API "
+ # "for WSD, SRL and Semantic Parsing](https://www.researchgate.net/publication/360671045_Universal_Semantic_Annotator_the_First_Unified_API_for_WSD_SRL_and_Semantic_Parsing), which will be presented at LREC 2022 by "
+ # "[Riccardo Orlando](https://riccorl.github.io), [Simone Conia](https://c-simone.github.io/), "
+ # "[Stefano Faralli](https://corsidilaurea.uniroma1.it/it/users/stefanofaralliuniroma1it), and [Roberto Navigli](https://www.diag.uniroma1.it/navigli/)."
+ # )
+ badge(type="github", name="sapienzanlp/relik")
+ badge(type="pypi", name="relik")
+
+
+def run_client():
+ with open(Path(__file__).parent / "style.css") as f:
+ css = f.read()
+
+ st.set_page_config(
+ page_title="ReLik",
+ page_icon="🦮",
+ layout="wide",
+ )
+ set_sidebar(css)
+ set_intro(css)
+
+ # text input
+ text = st.text_area(
+ "Enter Text Below:",
+ value="Obama went to Rome for a quick vacation.",
+ height=200,
+ max_chars=500,
+ )
+
+ with stylable_container(
+ key="annotate_button",
+ css_styles="""
+ button {
+ background-color: #802433;
+ color: white;
+ border-radius: 25px;
+ }
+ """,
+ ):
+ submit = st.button("Annotate")
+ # submit = st.button("Run")
+
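+    # Instantiate the full ReLiK pipeline (retriever + reader). Note that this runs on
+    # every Streamlit rerun; caching it (e.g. with st.cache_resource) would avoid
+    # reloading the models after each interaction.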
+ relik = Relik(
+ question_encoder="riccorl/relik-retriever-aida-blink-pretrain-omniencoder",
+ document_index="riccorl/index-relik-retriever-aida-blink-pretrain-omniencoder",
+ reader="riccorl/relik-reader-aida-deberta-small",
+ top_k=100,
+ window_size=32,
+ window_stride=16,
+ candidates_preprocessing_fn="relik.inference.preprocessing.wikipedia_title_and_openings_preprocessing",
+ )
+
+ # ReLik API call
+ if submit:
+ text = text.strip()
+ if text:
+ st.markdown("####")
+ st.markdown("#### Entity Linking")
+ with st.spinner(text="In progress"):
+ response = relik(text)
+ # response = requests.post(RELIK, json=text)
+ # if response.status_code != 200:
+ # st.error("Error: {}".format(response.status_code))
+ # else:
+ # response = response.json()
+
+ # Entity Linking
+ # with stylable_container(
+ # key="container_with_border",
+ # css_styles="""
+ # {
+ # border: 1px solid rgba(49, 51, 63, 0.2);
+ # border-radius: 0.5rem;
+ # padding: 0.5rem;
+ # padding-bottom: 2rem;
+ # }
+ # """,
+ # ):
+ # st.markdown("##")
+ dict_of_ents, options = get_el_annotations(response=response)
+ display = displacy.render(
+ dict_of_ents, manual=True, style="ent", options=options
+ )
+ display = display.replace("\n", " ")
+ # wsd_display = re.sub(
+ # r"(wiki::\d+\w)",
+ # r"\g<1>".format(
+ # language.upper()
+ # ),
+ # wsd_display,
+ # )
+ with st.container():
+ st.write(display, unsafe_allow_html=True)
+
+ st.markdown("####")
+ st.markdown("#### Relation Extraction")
+
+ with st.container():
+ st.write("Coming :)", unsafe_allow_html=True)
+
+ else:
+ st.error("Please enter some text.")
+
+
+if __name__ == "__main__":
+ run_client()
diff --git a/dockerfiles/Dockerfile.cpu b/dockerfiles/Dockerfile.cpu
new file mode 100644
index 0000000000000000000000000000000000000000..b27436cea2a992a2b70e890e2435c528d9c0ba09
--- /dev/null
+++ b/dockerfiles/Dockerfile.cpu
@@ -0,0 +1,17 @@
+FROM tiangolo/uvicorn-gunicorn:python3.10-slim
+
+# Copy and install requirements.txt
+COPY ./requirements.txt ./requirements.txt
+COPY ./src /app
+COPY ./scripts/start.sh /start.sh
+COPY ./scripts/prestart.sh /app
+COPY ./scripts/gunicorn_conf.py /gunicorn_conf.py
+COPY ./scripts/start-reload.sh /start-reload.sh
+COPY ./VERSION /
+RUN mkdir -p /app/resources/model \
+ && pip install --no-cache-dir -r requirements.txt \
+ && chmod +x /start.sh && chmod +x /start-reload.sh
+ARG MODEL_PATH
+COPY ${MODEL_PATH}/* /app/resources/model/
+
+ENV APP_MODULE=main:app
diff --git a/dockerfiles/Dockerfile.cuda b/dockerfiles/Dockerfile.cuda
new file mode 100644
index 0000000000000000000000000000000000000000..0ca669a954f1e76caa350e2311b138029b245d8c
--- /dev/null
+++ b/dockerfiles/Dockerfile.cuda
@@ -0,0 +1,38 @@
+FROM nvidia/cuda:12.2.0-base-ubuntu20.04
+
+ARG DEBIAN_FRONTEND=noninteractive
+
+RUN apt-get update \
+    && apt-get install -y \
+    curl wget python3.10 \
+    python3.10-distutils \
+    python3-pip \
+    && rm -rf /var/lib/apt/lists/*
+
+# FastAPI section
+# device env
+ENV DEVICE="cuda"
+# Copy and install requirements.txt
+COPY ./gpu-requirements.txt ./requirements.txt
+COPY ./src /app
+COPY ./scripts/start.sh /start.sh
+COPY ./scripts/gunicorn_conf.py /gunicorn_conf.py
+COPY ./scripts/start-reload.sh /start-reload.sh
+COPY ./scripts/prestart.sh /app
+COPY ./VERSION /
+RUN mkdir -p /app/resources/model \
+ && pip install --upgrade --no-cache-dir -r requirements.txt \
+ && chmod +x /start.sh \
+ && chmod +x /start-reload.sh
+ARG MODEL_NAME_OR_PATH
+
+WORKDIR /app
+
+ENV PYTHONPATH=/app
+
+EXPOSE 80
+
+# Run the start script, it will check for an /app/prestart.sh script (e.g. for migrations)
+# And then will start Gunicorn with Uvicorn
+CMD ["/start.sh"]
diff --git a/examples/train_retriever.py b/examples/train_retriever.py
new file mode 100644
index 0000000000000000000000000000000000000000..514eb0d80706012a50cd203e654bd14c57bb5db0
--- /dev/null
+++ b/examples/train_retriever.py
@@ -0,0 +1,45 @@
+from relik.retriever.trainer import RetrieverTrainer
+from relik import GoldenRetriever
+from relik.retriever.indexers.inmemory import InMemoryDocumentIndex
+from relik.retriever.data.datasets import AidaInBatchNegativesDataset
+
+if __name__ == "__main__":
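+    # NOTE: the dataset and index paths below point to the author's local copies of the
+    # AIDA DPR-like data; replace them with your own paths before running this example.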
+ # instantiate retriever
+ document_index = InMemoryDocumentIndex(
+ documents="/root/golden-retriever-v2/data/dpr-like/el/definitions.txt",
+ device="cuda",
+ precision="16",
+ )
+ retriever = GoldenRetriever(
+ question_encoder="intfloat/e5-small-v2", document_index=document_index
+ )
+
+ train_dataset = AidaInBatchNegativesDataset(
+ name="aida_train",
+ path="/root/golden-retriever-v2/data/dpr-like/el/aida_32_tokens_topic/train.jsonl",
+ tokenizer=retriever.question_tokenizer,
+ question_batch_size=64,
+ passage_batch_size=400,
+ max_passage_length=64,
+ use_topics=True,
+ shuffle=True,
+ )
+ val_dataset = AidaInBatchNegativesDataset(
+ name="aida_val",
+ path="/root/golden-retriever-v2/data/dpr-like/el/aida_32_tokens_topic/val.jsonl",
+ tokenizer=retriever.question_tokenizer,
+ question_batch_size=64,
+ passage_batch_size=400,
+ max_passage_length=64,
+ use_topics=True,
+ )
+
+ trainer = RetrieverTrainer(
+ retriever=retriever,
+ train_dataset=train_dataset,
+ val_dataset=val_dataset,
+ max_steps=25_000,
+ wandb_offline_mode=True,
+ )
+
+ trainer.train()
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..0fbfa01d83e3309a377c8c9013070809b045999a
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,15 @@
+[tool.black]
+include = '\.pyi?$'
+exclude = '''
+/(
+ \.git
+ | \.hg
+ | \.mypy_cache
+ | \.tox
+ | \.venv
+ | _build
+ | buck-out
+ | build
+ | dist
+)/
+'''
diff --git a/relik/__init__.py b/relik/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..42a3df6b991b0af65ec5974fc4faa381b8e555b7
--- /dev/null
+++ b/relik/__init__.py
@@ -0,0 +1 @@
+from relik.retriever.pytorch_modules.model import GoldenRetriever
diff --git a/relik/common/__init__.py b/relik/common/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/relik/common/log.py b/relik/common/log.py
new file mode 100644
index 0000000000000000000000000000000000000000..b91e1fa7bfc22b759e0da4d69563315e31ce0e60
--- /dev/null
+++ b/relik/common/log.py
@@ -0,0 +1,97 @@
+import logging
+import sys
+import threading
+from typing import Optional
+
+from rich import get_console
+
+_lock = threading.Lock()
+_default_handler: Optional[logging.Handler] = None
+
+_default_log_level = logging.WARNING
+
+# fancy logger
+_console = get_console()
+
+
+def _get_library_name() -> str:
+ return __name__.split(".")[0]
+
+
+def _get_library_root_logger() -> logging.Logger:
+ return logging.getLogger(_get_library_name())
+
+
+def _configure_library_root_logger() -> None:
+ global _default_handler
+
+ with _lock:
+ if _default_handler:
+ # This library has already configured the library root logger.
+ return
+ _default_handler = logging.StreamHandler() # Set sys.stderr as stream.
+ _default_handler.flush = sys.stderr.flush
+
+ # Apply our default configuration to the library root logger.
+ library_root_logger = _get_library_root_logger()
+ library_root_logger.addHandler(_default_handler)
+ library_root_logger.setLevel(_default_log_level)
+ library_root_logger.propagate = False
+
+
+def _reset_library_root_logger() -> None:
+ global _default_handler
+
+ with _lock:
+ if not _default_handler:
+ return
+
+ library_root_logger = _get_library_root_logger()
+ library_root_logger.removeHandler(_default_handler)
+ library_root_logger.setLevel(logging.NOTSET)
+ _default_handler = None
+
+
+def set_log_level(level: int, logger: Optional[logging.Logger] = None) -> None:
+ """
+ Set the log level.
+ Args:
+ level (:obj:`int`):
+ Logging level.
+ logger (:obj:`logging.Logger`):
+ Logger to set the log level.
+ """
+ if not logger:
+ _configure_library_root_logger()
+ logger = _get_library_root_logger()
+ logger.setLevel(level)
+
+
+def get_logger(
+ name: Optional[str] = None,
+ level: Optional[int] = None,
+ formatter: Optional[str] = None,
+) -> logging.Logger:
+ """
+ Return a logger with the specified name.
+ """
+
+ if name is None:
+ name = _get_library_name()
+
+ _configure_library_root_logger()
+
+ if level is not None:
+ set_log_level(level)
+
+    if formatter is None:
+        formatter = "%(asctime)s - %(levelname)s - %(name)s - %(message)s"
+    # apply the (default or user-provided) format string to the shared handler
+    _default_handler.setFormatter(logging.Formatter(formatter))
+
+ return logging.getLogger(name)
+
+
+def get_console_logger():
+ return _console
diff --git a/relik/common/upload.py b/relik/common/upload.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2cad77bd95f43992af3144baf296560a496556b
--- /dev/null
+++ b/relik/common/upload.py
@@ -0,0 +1,128 @@
+import argparse
+import json
+import logging
+import os
+import tempfile
+import zipfile
+from datetime import datetime
+from pathlib import Path
+from typing import Optional, Union
+
+import huggingface_hub
+
+from relik.common.log import get_logger
+from relik.common.utils import SAPIENZANLP_DATE_FORMAT, get_md5
+
+logger = get_logger(level=logging.DEBUG)
+
+
+def create_info_file(tmpdir: Path):
+ logger.debug("Computing md5 of model.zip")
+ md5 = get_md5(tmpdir / "model.zip")
+ date = datetime.now().strftime(SAPIENZANLP_DATE_FORMAT)
+
+ logger.debug("Dumping info.json file")
+ with (tmpdir / "info.json").open("w") as f:
+ json.dump(dict(md5=md5, upload_date=date), f, indent=2)
+
+
+def zip_run(
+ dir_path: Union[str, os.PathLike],
+ tmpdir: Union[str, os.PathLike],
+ zip_name: str = "model.zip",
+) -> Path:
+ logger.debug(f"zipping {dir_path} to {tmpdir}")
+ # creates a zip version of the provided dir_path
+ run_dir = Path(dir_path)
+ zip_path = tmpdir / zip_name
+
+ with zipfile.ZipFile(zip_path, "w") as zip_file:
+ # fully zip the run directory maintaining its structure
+ for file in run_dir.rglob("*.*"):
+ if file.is_dir():
+ continue
+
+ zip_file.write(file, arcname=file.relative_to(run_dir))
+
+ return zip_path
+
+
+def upload(
+ model_dir: Union[str, os.PathLike],
+ model_name: str,
+ organization: Optional[str] = None,
+ repo_name: Optional[str] = None,
+ commit: Optional[str] = None,
+ archive: bool = False,
+):
+ token = huggingface_hub.HfFolder.get_token()
+ if token is None:
+ print(
+ "No HuggingFace token found. You need to execute `huggingface-cli login` first!"
+ )
+ return
+
+ repo_id = repo_name or model_name
+ if organization is not None:
+ repo_id = f"{organization}/{repo_id}"
+ with tempfile.TemporaryDirectory() as tmpdir:
+ api = huggingface_hub.HfApi()
+ repo_url = api.create_repo(
+ token=token,
+ repo_id=repo_id,
+ exist_ok=True,
+ )
+ repo = huggingface_hub.Repository(
+ str(tmpdir), clone_from=repo_url, use_auth_token=token
+ )
+
+ tmp_path = Path(tmpdir)
+ if archive:
+            # the user asked for an archive: zip the model_dir before uploading
+ logger.debug(f"Zipping {model_dir} to {tmp_path}")
+ zip_run(model_dir, tmp_path)
+ create_info_file(tmp_path)
+ else:
+ # if the user wants to upload a transformers model, we don't need to zip it
+ # we just need to copy the files to the tmpdir
+ logger.debug(f"Copying {model_dir} to {tmpdir}")
+ os.system(f"cp -r {model_dir}/* {tmpdir}")
+
+ # this method automatically puts large files (>10MB) into git lfs
+ repo.push_to_hub(commit_message=commit or "Automatic push from sapienzanlp")
+
+
+def parse_args() -> argparse.Namespace:
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "model_dir", help="The directory of the model you want to upload"
+ )
+ parser.add_argument("model_name", help="The model you want to upload")
+ parser.add_argument(
+ "--organization",
+ help="the name of the organization where you want to upload the model",
+ )
+ parser.add_argument(
+ "--repo_name",
+ help="Optional name to use when uploading to the HuggingFace repository",
+ )
+ parser.add_argument(
+ "--commit", help="Commit message to use when pushing to the HuggingFace Hub"
+ )
+ parser.add_argument(
+ "--archive",
+ action="store_true",
+ help="""
+ Whether to compress the model directory before uploading it.
+ If True, the model directory will be zipped and the zip file will be uploaded.
+ If False, the model directory will be uploaded as is.""",
+ )
+ return parser.parse_args()
+
+
+def main():
+ upload(**vars(parse_args()))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/relik/common/utils.py b/relik/common/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e49b88feca9a8e6cc517b583fd07c924d57ed8b
--- /dev/null
+++ b/relik/common/utils.py
@@ -0,0 +1,609 @@
+import importlib.util
+import json
+import logging
+import os
+import shutil
+import tarfile
+import tempfile
+from functools import partial
+from hashlib import sha256
+from pathlib import Path
+from typing import Any, BinaryIO, Dict, List, Optional, Union
+from urllib.parse import urlparse
+from zipfile import ZipFile, is_zipfile
+
+import huggingface_hub
+import requests
+from tqdm import tqdm
+from filelock import FileLock
+from transformers.utils.hub import cached_file as hf_cached_file
+
+from relik.common.log import get_logger
+
+# name constants
+WEIGHTS_NAME = "weights.pt"
+ONNX_WEIGHTS_NAME = "weights.onnx"
+CONFIG_NAME = "config.yaml"
+LABELS_NAME = "labels.json"
+
+# SAPIENZANLP_USER_NAME = "sapienzanlp"
+SAPIENZANLP_USER_NAME = "riccorl"
+SAPIENZANLP_HF_MODEL_REPO_URL = "riccorl/{model_id}"
+SAPIENZANLP_HF_MODEL_REPO_ARCHIVE_URL = (
+ f"{SAPIENZANLP_HF_MODEL_REPO_URL}/resolve/main/model.zip"
+)
+# path constants
+SAPIENZANLP_CACHE_DIR = os.getenv("SAPIENZANLP_CACHE_DIR", Path.home() / ".sapienzanlp")
+SAPIENZANLP_DATE_FORMAT = "%Y-%m-%d %H-%M-%S"
+
+
+logger = get_logger(__name__)
+
+
+def sapienzanlp_model_urls(model_id: str) -> str:
+ """
+ Returns the URL for a possible SapienzaNLP valid model.
+
+ Args:
+ model_id (:obj:`str`):
+ A SapienzaNLP model id.
+
+ Returns:
+ :obj:`str`: The url for the model id.
+ """
+ # check if there is already the namespace of the user
+ if "/" in model_id:
+ return model_id
+ return SAPIENZANLP_HF_MODEL_REPO_URL.format(model_id=model_id)
+
+
+def is_package_available(package_name: str) -> bool:
+ """
+ Check if a package is available.
+
+ Args:
+ package_name (`str`): The name of the package to check.
+ """
+ return importlib.util.find_spec(package_name) is not None
+
+
+def load_json(path: Union[str, Path]) -> Any:
+ """
+ Load a json file provided in input.
+
+ Args:
+ path (`Union[str, Path]`): The path to the json file to load.
+
+ Returns:
+ `Any`: The loaded json file.
+ """
+ with open(path, encoding="utf8") as f:
+ return json.load(f)
+
+
+def dump_json(document: Any, path: Union[str, Path], indent: Optional[int] = None):
+ """
+ Dump input to json file.
+
+ Args:
+ document (`Any`): The document to dump.
+ path (`Union[str, Path]`): The path to dump the document to.
+ indent (`Optional[int]`): The indent to use for the json file.
+
+ """
+ with open(path, "w", encoding="utf8") as outfile:
+ json.dump(document, outfile, indent=indent)
+
+
+def get_md5(path: Path):
+ """
+ Get the MD5 value of a path.
+ """
+ import hashlib
+
+ with path.open("rb") as fin:
+ data = fin.read()
+ return hashlib.md5(data).hexdigest()
+
+
+def file_exists(path: Union[str, os.PathLike]) -> bool:
+ """
+ Check if the file at :obj:`path` exists.
+
+ Args:
+ path (:obj:`str`, :obj:`os.PathLike`):
+ Path to check.
+
+ Returns:
+ :obj:`bool`: :obj:`True` if the file exists.
+ """
+ return Path(path).exists()
+
+
+def dir_exists(path: Union[str, os.PathLike]) -> bool:
+ """
+ Check if the directory at :obj:`path` exists.
+
+ Args:
+ path (:obj:`str`, :obj:`os.PathLike`):
+ Path to check.
+
+ Returns:
+ :obj:`bool`: :obj:`True` if the directory exists.
+ """
+ return Path(path).is_dir()
+
+
+def is_remote_url(url_or_filename: Union[str, Path]):
+ """
+ Returns :obj:`True` if the input path is an url.
+
+ Args:
+ url_or_filename (:obj:`str`, :obj:`Path`):
+ path to check.
+
+ Returns:
+ :obj:`bool`: :obj:`True` if the input path is an url, :obj:`False` otherwise.
+
+ """
+ if isinstance(url_or_filename, Path):
+ url_or_filename = str(url_or_filename)
+ parsed = urlparse(url_or_filename)
+ return parsed.scheme in ("http", "https")
+
+
+def url_to_filename(resource: str, etag: str = None) -> str:
+ """
+ Convert a `resource` into a hashed filename in a repeatable way.
+    If `etag` is specified, append its hash to the resource's, delimited
+ by a period.
+ """
+ resource_bytes = resource.encode("utf-8")
+ resource_hash = sha256(resource_bytes)
+ filename = resource_hash.hexdigest()
+
+ if etag:
+ etag_bytes = etag.encode("utf-8")
+ etag_hash = sha256(etag_bytes)
+ filename += "." + etag_hash.hexdigest()
+
+ return filename
+
+
+def download_resource(
+ url: str,
+ temp_file: BinaryIO,
+ headers=None,
+):
+ """
+ Download remote file.
+ """
+
+ if headers is None:
+ headers = {}
+
+ r = requests.get(url, stream=True, headers=headers)
+ r.raise_for_status()
+ content_length = r.headers.get("Content-Length")
+ total = int(content_length) if content_length is not None else None
+ progress = tqdm(
+ unit="B",
+ unit_scale=True,
+ total=total,
+ desc="Downloading",
+ disable=logger.level in [logging.NOTSET],
+ )
+ for chunk in r.iter_content(chunk_size=1024):
+ if chunk: # filter out keep-alive new chunks
+ progress.update(len(chunk))
+ temp_file.write(chunk)
+ progress.close()
+
+
+def download_and_cache(
+ url: Union[str, Path],
+ cache_dir: Union[str, Path] = None,
+ force_download: bool = False,
+):
+ if cache_dir is None:
+ cache_dir = SAPIENZANLP_CACHE_DIR
+ if isinstance(url, Path):
+ url = str(url)
+
+ # check if cache dir exists
+ Path(cache_dir).mkdir(parents=True, exist_ok=True)
+
+ # check if file is private
+ headers = {}
+ try:
+ r = requests.head(url, allow_redirects=False, timeout=10)
+ r.raise_for_status()
+ except requests.exceptions.HTTPError:
+ if r.status_code == 401:
+ hf_token = huggingface_hub.HfFolder.get_token()
+ if hf_token is None:
+ raise ValueError(
+ "You need to login to HuggingFace to download this model "
+ "(use the `huggingface-cli login` command)"
+ )
+ headers["Authorization"] = f"Bearer {hf_token}"
+
+ etag = None
+ try:
+ r = requests.head(url, allow_redirects=True, timeout=10, headers=headers)
+ r.raise_for_status()
+ etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
+ # We favor a custom header indicating the etag of the linked resource, and
+ # we fallback to the regular etag header.
+ # If we don't have any of those, raise an error.
+ if etag is None:
+ raise OSError(
+ "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
+ )
+ # In case of a redirect,
+ # save an extra redirect on the request.get call,
+ # and ensure we download the exact atomic version even if it changed
+ # between the HEAD and the GET (unlikely, but hey).
+ if 300 <= r.status_code <= 399:
+ url = r.headers["Location"]
+ except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
+ # Actually raise for those subclasses of ConnectionError
+ raise
+ except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
+ # Otherwise, our Internet connection is down.
+ # etag is None
+ pass
+
+ # get filename from the url
+ filename = url_to_filename(url, etag)
+ # get cache path to put the file
+    cache_path = Path(cache_dir) / filename
+
+ # the file is already here, return it
+ if file_exists(cache_path) and not force_download:
+ logger.info(
+ f"{url} found in cache, set `force_download=True` to force the download"
+ )
+ return cache_path
+
+ cache_path = str(cache_path)
+ # Prevent parallel downloads of the same file with a lock.
+ lock_path = cache_path + ".lock"
+ with FileLock(lock_path):
+ # If the download just completed while the lock was activated.
+ if file_exists(cache_path) and not force_download:
+ # Even if returning early like here, the lock will be released.
+ return cache_path
+
+ temp_file_manager = partial(
+ tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False
+ )
+
+ # Download to temporary file, then copy to cache dir once finished.
+ # Otherwise, you get corrupt cache entries if the download gets interrupted.
+ with temp_file_manager() as temp_file:
+ logger.info(
+ f"{url} not found in cache or `force_download` set to `True`, downloading to {temp_file.name}"
+ )
+ download_resource(url, temp_file, headers)
+
+ logger.info(f"storing {url} in cache at {cache_path}")
+ os.replace(temp_file.name, cache_path)
+
+ # NamedTemporaryFile creates a file with hardwired 0600 perms (ignoring umask), so fixing it.
+ umask = os.umask(0o666)
+ os.umask(umask)
+ os.chmod(cache_path, 0o666 & ~umask)
+
+ logger.info(f"creating metadata file for {cache_path}")
+ meta = {"url": url} # , "etag": etag}
+ meta_path = cache_path + ".json"
+ with open(meta_path, "w") as meta_file:
+ json.dump(meta, meta_file)
+
+ return cache_path
+
+
+def download_from_hf(
+ path_or_repo_id: Union[str, Path],
+ filenames: Optional[List[str]],
+ cache_dir: Union[str, Path] = None,
+ force_download: bool = False,
+ resume_download: bool = False,
+ proxies: Optional[Dict[str, str]] = None,
+ use_auth_token: Optional[Union[bool, str]] = None,
+ revision: Optional[str] = None,
+ local_files_only: bool = False,
+ subfolder: str = "",
+):
+ if isinstance(path_or_repo_id, Path):
+ path_or_repo_id = str(path_or_repo_id)
+
+ downloaded_paths = []
+ for filename in filenames:
+ downloaded_path = hf_cached_file(
+ path_or_repo_id,
+ filename,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ local_files_only=local_files_only,
+ subfolder=subfolder,
+ )
+ downloaded_paths.append(downloaded_path)
+
+ # we want the folder where the files are downloaded
+ # the best guess is the parent folder of the first file
+ probably_the_folder = Path(downloaded_paths[0]).parent
+ return probably_the_folder
+
+
+def model_name_or_path_resolver(model_name_or_dir: Union[str, os.PathLike]) -> str:
+ """
+ Resolve a model name or directory to a model archive name or directory.
+
+ Args:
+ model_name_or_dir (:obj:`str` or :obj:`os.PathLike`):
+ A model name or directory.
+
+ Returns:
+ :obj:`str`: The model archive name or directory.
+ """
+ if is_remote_url(model_name_or_dir):
+ # if model_name_or_dir is a URL
+ # download it and try to load
+ model_archive = model_name_or_dir
+ elif Path(model_name_or_dir).is_dir() or Path(model_name_or_dir).is_file():
+ # if model_name_or_dir is a local directory or
+ # an archive file try to load it
+ model_archive = model_name_or_dir
+ else:
+ # probably model_name_or_dir is a sapienzanlp model id
+ # guess the url and try to download
+ model_name_or_dir_ = model_name_or_dir
+ # raise ValueError(f"Providing a model id is not supported yet.")
+ model_archive = sapienzanlp_model_urls(model_name_or_dir_)
+
+ return model_archive
+
+
+def from_cache(
+ url_or_filename: Union[str, Path],
+ cache_dir: Union[str, Path] = None,
+ force_download: bool = False,
+ resume_download: bool = False,
+ proxies: Optional[Dict[str, str]] = None,
+ use_auth_token: Optional[Union[bool, str]] = None,
+ revision: Optional[str] = None,
+ local_files_only: bool = False,
+ subfolder: str = "",
+ filenames: Optional[List[str]] = None,
+) -> Path:
+ """
+ Given something that could be either a local path or a URL (or a SapienzaNLP model id),
+ determine which one and return a path to the corresponding file.
+
+ Args:
+ url_or_filename (:obj:`str` or :obj:`Path`):
+ A path to a local file or a URL (or a SapienzaNLP model id).
+ cache_dir (:obj:`str` or :obj:`Path`, `optional`):
+ Path to a directory in which a downloaded file will be cached.
+ force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Whether or not to re-download the file even if it already exists.
+ resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Whether or not to delete incompletely received files. Attempts to resume the download if such a file
+ exists.
+ proxies (:obj:`Dict[str, str]`, `optional`):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+ use_auth_token (:obj:`Union[bool, str]`, `optional`):
+ Optional string or boolean to use as Bearer token for remote files. If :obj:`True`, will get token from
+ :obj:`~transformers.hf_api.HfApi`. If :obj:`str`, will use that string as token.
+ revision (:obj:`str`, `optional`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
+ identifier allowed by git.
+ local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Whether or not to raise an error if the file to be downloaded is local.
+ subfolder (:obj:`str`, `optional`):
+ In case the relevant file is in a subfolder of the URL, specify it here.
+ filenames (:obj:`List[str]`, `optional`):
+ List of filenames to look for in the directory structure.
+
+ Returns:
+ :obj:`Path`: Path to the cached file.
+ """
+
+ url_or_filename = model_name_or_path_resolver(url_or_filename)
+
+ if cache_dir is None:
+ cache_dir = SAPIENZANLP_CACHE_DIR
+
+ if file_exists(url_or_filename):
+ logger.info(f"{url_or_filename} is a local path or file")
+ output_path = url_or_filename
+ elif is_remote_url(url_or_filename):
+ # URL, so get it from the cache (downloading if necessary)
+ output_path = download_and_cache(
+ url_or_filename,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ )
+ else:
+ if filenames is None:
+ filenames = [WEIGHTS_NAME, CONFIG_NAME, LABELS_NAME]
+ output_path = download_from_hf(
+ url_or_filename,
+ filenames,
+ cache_dir,
+ force_download,
+ resume_download,
+ proxies,
+ use_auth_token,
+ revision,
+ local_files_only,
+ subfolder,
+ )
+
+ # if is_hf_hub_url(url_or_filename):
+ # HuggingFace Hub
+ # output_path = hf_hub_download_url(url_or_filename)
+ # elif is_remote_url(url_or_filename):
+ # # URL, so get it from the cache (downloading if necessary)
+ # output_path = download_and_cache(
+ # url_or_filename,
+ # cache_dir=cache_dir,
+ # force_download=force_download,
+ # )
+ # elif file_exists(url_or_filename):
+ # logger.info(f"{url_or_filename} is a local path or file")
+ # # File, and it exists.
+ # output_path = url_or_filename
+ # elif urlparse(url_or_filename).scheme == "":
+ # # File, but it doesn't exist.
+ # raise EnvironmentError(f"file {url_or_filename} not found")
+ # else:
+ # # Something unknown
+ # raise ValueError(
+ # f"unable to parse {url_or_filename} as a URL or as a local path"
+ # )
+
+ if dir_exists(output_path) or (
+ not is_zipfile(output_path) and not tarfile.is_tarfile(output_path)
+ ):
+ return Path(output_path)
+
+ # Path where we extract compressed archives
+ # for now it will extract it in the same folder
+ # maybe implement extraction in the sapienzanlp folder
+ # when using local archive path?
+ logger.info("Extracting compressed archive")
+ output_dir, output_file = os.path.split(output_path)
+ output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
+ output_path_extracted = os.path.join(output_dir, output_extract_dir_name)
+
+ # already extracted, do not extract
+ if (
+ os.path.isdir(output_path_extracted)
+ and os.listdir(output_path_extracted)
+ and not force_download
+ ):
+ return Path(output_path_extracted)
+
+ # Prevent parallel extractions
+ lock_path = output_path + ".lock"
+ with FileLock(lock_path):
+ shutil.rmtree(output_path_extracted, ignore_errors=True)
+ os.makedirs(output_path_extracted)
+ if is_zipfile(output_path):
+ with ZipFile(output_path, "r") as zip_file:
+ zip_file.extractall(output_path_extracted)
+ zip_file.close()
+ elif tarfile.is_tarfile(output_path):
+ tar_file = tarfile.open(output_path)
+ tar_file.extractall(output_path_extracted)
+ tar_file.close()
+ else:
+ raise EnvironmentError(
+ f"Archive format of {output_path} could not be identified"
+ )
+
+ # remove lock file, is it safe?
+ os.remove(lock_path)
+
+ return Path(output_path_extracted)
+
+
+def is_str_a_path(maybe_path: str) -> bool:
+ """
+ Check if a string is a path.
+
+ Args:
+ maybe_path (`str`): The string to check.
+
+ Returns:
+ `bool`: `True` if the string is a path, `False` otherwise.
+ """
+ # first check if it is a path
+ if Path(maybe_path).exists():
+ return True
+ # check if it is a relative path
+ if Path(os.path.join(os.getcwd(), maybe_path)).exists():
+ return True
+ # otherwise it is not a path
+ return False
+
+
+def relative_to_absolute_path(path: str) -> os.PathLike:
+ """
+ Convert a relative path to an absolute path.
+
+ Args:
+ path (`str`): The relative path to convert.
+
+ Returns:
+ `os.PathLike`: The absolute path.
+ """
+ if not is_str_a_path(path):
+ raise ValueError(f"{path} is not a path")
+ if Path(path).exists():
+ return Path(path).absolute()
+ if Path(os.path.join(os.getcwd(), path)).exists():
+ return Path(os.path.join(os.getcwd(), path)).absolute()
+ raise ValueError(f"{path} is not a path")
+
+
+def to_config(object_to_save: Any) -> Dict[str, Any]:
+ """
+ Convert an object to a dictionary.
+
+ Returns:
+ `Dict[str, Any]`: The dictionary representation of the object.
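+
+    Example:
+        A sketch of the output shape: ``to_config(InMemoryDocumentIndex(device="cpu"))``
+        would return something like
+        ``{"_target_": "relik.retriever.indexers.inmemory.InMemoryDocumentIndex", "device": "cpu"}``,
+        i.e. a dictionary that ``hydra.utils.instantiate`` can turn back into an object.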
+ """
+
+ def obj_to_dict(obj):
+ match obj:
+ case dict():
+ data = {}
+ for k, v in obj.items():
+ data[k] = obj_to_dict(v)
+ return data
+
+ case list() | tuple():
+ return [obj_to_dict(x) for x in obj]
+
+ case object(__dict__=_):
+ data = {
+ "_target_": f"{obj.__class__.__module__}.{obj.__class__.__name__}",
+ }
+ for k, v in obj.__dict__.items():
+ if not k.startswith("_"):
+ data[k] = obj_to_dict(v)
+ return data
+
+ case _:
+ return obj
+
+ return obj_to_dict(object_to_save)
+
+
+def get_callable_from_string(callable_fn: str) -> Any:
+ """
+ Get a callable from a string.
+
+ Args:
+ callable_fn (`str`):
+ The string representation of the callable.
+
+ Returns:
+ `Any`: The callable.
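+
+    Example:
+        ``get_callable_from_string("relik.inference.preprocessing.wikipedia_title_and_openings_preprocessing")``
+        imports ``relik.inference.preprocessing`` and returns the function object; this is how
+        the ``candidates_preprocessing_fn`` strings passed to ``Relik`` are resolved.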
+ """
+ # separate the function name from the module name
+ module_name, function_name = callable_fn.rsplit(".", 1)
+ # import the module
+ module = importlib.import_module(module_name)
+ # get the function
+ return getattr(module, function_name)
diff --git a/relik/inference/__init__.py b/relik/inference/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/relik/inference/annotator.py b/relik/inference/annotator.py
new file mode 100644
index 0000000000000000000000000000000000000000..de356f690985a1496b607664dd77d5e49546257d
--- /dev/null
+++ b/relik/inference/annotator.py
@@ -0,0 +1,422 @@
+import os
+from pathlib import Path
+from typing import Any, Callable, Dict, Optional, Union
+
+import hydra
+from omegaconf import OmegaConf
+from relik.retriever.pytorch_modules.hf import GoldenRetrieverModel
+from rich.pretty import pprint
+
+from relik.common.log import get_console_logger, get_logger
+from relik.common.upload import upload
+from relik.common.utils import CONFIG_NAME, from_cache, get_callable_from_string
+from relik.inference.data.objects import EntitySpan, RelikOutput
+from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer
+from relik.inference.data.window.manager import WindowManager
+from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction
+from relik.reader.relik_reader import RelikReader
+from relik.retriever.data.utils import batch_generator
+from relik.retriever.indexers.base import BaseDocumentIndex
+from relik.retriever.pytorch_modules.model import GoldenRetriever
+
+logger = get_logger(__name__)
+console_logger = get_console_logger()
+
+
+class Relik:
+ """
+ Relik main class. It is a wrapper around a retriever and a reader.
+
+ Args:
+ retriever (`Optional[GoldenRetriever]`, `optional`):
+ The retriever to use. If `None`, a retriever will be instantiated from the
+ provided `question_encoder`, `passage_encoder` and `document_index`.
+ Defaults to `None`.
+ question_encoder (`Optional[Union[str, GoldenRetrieverModel]]`, `optional`):
+ The question encoder to use. If `retriever` is `None`, a retriever will be
+ instantiated from this parameter. Defaults to `None`.
+ passage_encoder (`Optional[Union[str, GoldenRetrieverModel]]`, `optional`):
+ The passage encoder to use. If `retriever` is `None`, a retriever will be
+ instantiated from this parameter. Defaults to `None`.
+ document_index (`Optional[Union[str, BaseDocumentIndex]]`, `optional`):
+ The document index to use. If `retriever` is `None`, a retriever will be
+ instantiated from this parameter. Defaults to `None`.
+ reader (`Optional[Union[str, RelikReader]]`, `optional`):
+            The reader to use. It can be a model name or path, or an already
+            instantiated `RelikReader`. Defaults to `None`.
+ retriever_device (`str`, `optional`, defaults to `cpu`):
+ The device to use for the retriever.
+
+ """
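+    Example:
+        A minimal usage sketch (the checkpoint names are the ones used in ``main()``
+        at the bottom of this file)::
+
+            relik = Relik(
+                question_encoder="riccorl/relik-retriever-aida-blink-pretrain-omniencoder",
+                document_index="riccorl/index-relik-retriever-aida-blink-pretrain-omniencoder",
+                reader="riccorl/relik-reader-aida-deberta-small",
+                top_k=100,
+                window_size=32,
+                window_stride=16,
+            )
+            output = relik("Obama went to Rome for a quick vacation.")  # -> RelikOutput
+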
+
+ def __init__(
+ self,
+ retriever: GoldenRetriever | None = None,
+ question_encoder: str | GoldenRetrieverModel | None = None,
+ passage_encoder: str | GoldenRetrieverModel | None = None,
+ document_index: str | BaseDocumentIndex | None = None,
+ reader: str | RelikReader | None = None,
+ device: str = "cpu",
+ retriever_device: str | None = None,
+ document_index_device: str | None = None,
+ reader_device: str | None = None,
+ precision: int = 32,
+ retriever_precision: int | None = None,
+ document_index_precision: int | None = None,
+ reader_precision: int | None = None,
+ reader_kwargs: dict | None = None,
+ retriever_kwargs: dict | None = None,
+ candidates_preprocessing_fn: str | Callable | None = None,
+ top_k: int | None = None,
+ window_size: int | None = None,
+ window_stride: int | None = None,
+ **kwargs,
+ ) -> None:
+ # retriever
+ retriever_device = retriever_device or device
+ document_index_device = document_index_device or device
+ retriever_precision = retriever_precision or precision
+ document_index_precision = document_index_precision or precision
+ if retriever is None and question_encoder is None:
+ raise ValueError(
+ "Either `retriever` or `question_encoder` must be provided"
+ )
+ if retriever is None:
+ self.retriever_kwargs = dict(
+ question_encoder=question_encoder,
+ passage_encoder=passage_encoder,
+ document_index=document_index,
+ device=retriever_device,
+ precision=retriever_precision,
+ index_device=document_index_device,
+ index_precision=document_index_precision,
+ )
+ # overwrite default_retriever_kwargs with retriever_kwargs
+ self.retriever_kwargs.update(retriever_kwargs or {})
+ retriever = GoldenRetriever(**self.retriever_kwargs)
+ retriever.training = False
+ retriever.eval()
+ self.retriever = retriever
+
+ # reader
+ self.reader_device = reader_device or device
+ self.reader_precision = reader_precision or precision
+ self.reader_kwargs = reader_kwargs
+ if isinstance(reader, str):
+ reader_kwargs = reader_kwargs or {}
+ reader = RelikReaderForSpanExtraction(reader, **reader_kwargs)
+ self.reader = reader
+
+ # windowization stuff
+ self.tokenizer = SpacyTokenizer(language="en")
+ self.window_manager: WindowManager | None = None
+
+ # candidates preprocessing
+ # TODO: maybe move this logic somewhere else
+ candidates_preprocessing_fn = candidates_preprocessing_fn or (lambda x: x)
+ if isinstance(candidates_preprocessing_fn, str):
+ candidates_preprocessing_fn = get_callable_from_string(
+ candidates_preprocessing_fn
+ )
+ self.candidates_preprocessing_fn = candidates_preprocessing_fn
+
+ # inference params
+ self.top_k = top_k
+ self.window_size = window_size
+ self.window_stride = window_stride
+
+ def __call__(
+ self,
+ text: Union[str, list],
+ top_k: Optional[int] = None,
+ window_size: Optional[int] = None,
+ window_stride: Optional[int] = None,
+ retriever_batch_size: Optional[int] = 32,
+ reader_batch_size: Optional[int] = 32,
+ return_also_windows: bool = False,
+ **kwargs,
+ ) -> Union[RelikOutput, list[RelikOutput]]:
+ """
+ Annotate a text with entities.
+
+ Args:
+ text (`str` or `list`):
+ The text to annotate. If a list is provided, each element of the list
+ will be annotated separately.
+ top_k (`int`, `optional`, defaults to `None`):
+ The number of candidates to retrieve for each window.
+ window_size (`int`, `optional`, defaults to `None`):
+ The size of the window. If `None`, the whole text will be annotated.
+ window_stride (`int`, `optional`, defaults to `None`):
+ The stride of the window. If `None`, there will be no overlap between windows.
+ retriever_batch_size (`int`, `optional`, defaults to `None`):
+ The batch size to use for the retriever. The whole input is the batch for the retriever.
+ reader_batch_size (`int`, `optional`, defaults to `None`):
+ The batch size to use for the reader. The whole input is the batch for the reader.
+ return_also_windows (`bool`, `optional`, defaults to `False`):
+ Whether to return the windows in the output.
+ **kwargs:
+ Additional keyword arguments to pass to the retriever and the reader.
+
+ Returns:
+ `RelikOutput` or `list[RelikOutput]`:
+ The annotated text. If a list was provided as input, a list of
+ `RelikOutput` objects will be returned.
+ """
+ if top_k is None:
+ top_k = self.top_k or 100
+ if window_size is None:
+ window_size = self.window_size
+ if window_stride is None:
+ window_stride = self.window_stride
+
+ if isinstance(text, str):
+ text = [text]
+
+ if window_size is not None:
+ if self.window_manager is None:
+ self.window_manager = WindowManager(self.tokenizer)
+
+ if window_size == "sentence":
+ # todo: implement sentence windowizer
+ raise NotImplementedError("Sentence windowizer not implemented yet")
+
+ # if window_size < window_stride:
+ # raise ValueError(
+ # f"Window size ({window_size}) must be greater than window stride ({window_stride})"
+ # )
+
+ # window generator
+ windows = [
+ window
+ for doc_id, t in enumerate(text)
+ for window in self.window_manager.create_windows(
+ t,
+ window_size=window_size,
+ stride=window_stride,
+ doc_id=doc_id,
+ )
+ ]
+
+ # retrieve candidates first
+ windows_candidates = []
+ # TODO: Move batching inside retriever
+ for batch in batch_generator(windows, batch_size=retriever_batch_size):
+ retriever_out = self.retriever.retrieve([b.text for b in batch], k=top_k)
+ windows_candidates.extend(
+ [[p.label for p in predictions] for predictions in retriever_out]
+ )
+
+ # add passage to the windows
+ for window, candidates in zip(windows, windows_candidates):
+ window.window_candidates = [
+ self.candidates_preprocessing_fn(c) for c in candidates
+ ]
+
+ windows = self.reader.read(samples=windows, max_batch_size=reader_batch_size)
+ windows = self.window_manager.merge_windows(windows)
+
+ # transform predictions into RelikOutput objects
+ output = []
+ for w in windows:
+ sample_output = RelikOutput(
+ text=text[w.doc_id],
+ labels=sorted(
+ [
+ EntitySpan(
+ start=ss, end=se, label=sl, text=text[w.doc_id][ss:se]
+ )
+ for ss, se, sl in w.predicted_window_labels_chars
+ ],
+ key=lambda x: x.start,
+ ),
+ )
+ output.append(sample_output)
+
+ if return_also_windows:
+ for i, sample_output in enumerate(output):
+ sample_output.windows = [w for w in windows if w.doc_id == i]
+
+ # if only one text was provided, return a single RelikOutput object
+ if len(output) == 1:
+ return output[0]
+
+ return output
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ model_name_or_dir: Union[str, os.PathLike],
+ config_kwargs: Optional[Dict] = None,
+ config_file_name: str = CONFIG_NAME,
+ *args,
+ **kwargs,
+ ) -> "Relik":
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+
+ model_dir = from_cache(
+ model_name_or_dir,
+ filenames=[config_file_name],
+ cache_dir=cache_dir,
+ force_download=force_download,
+ )
+
+ config_path = model_dir / config_file_name
+ if not config_path.exists():
+ raise FileNotFoundError(
+ f"Model configuration file not found at {config_path}."
+ )
+
+ # overwrite config with config_kwargs
+ config = OmegaConf.load(config_path)
+ if config_kwargs is not None:
+ # TODO: check merging behavior
+ config = OmegaConf.merge(config, OmegaConf.create(config_kwargs))
+ # do we want to print the config? I like it
+ pprint(OmegaConf.to_container(config), console=console_logger, expand_all=True)
+
+ # load relik from config
+ relik = hydra.utils.instantiate(config, *args, **kwargs)
+
+ return relik
+
+ def save_pretrained(
+ self,
+ output_dir: Union[str, os.PathLike],
+ config: Optional[Dict[str, Any]] = None,
+ config_file_name: Optional[str] = None,
+ save_weights: bool = False,
+ push_to_hub: bool = False,
+ model_id: Optional[str] = None,
+ organization: Optional[str] = None,
+ repo_name: Optional[str] = None,
+ **kwargs,
+ ):
+ """
+ Save the configuration of Relik to the specified directory as a YAML file.
+
+ Args:
+ output_dir (`str`):
+ The directory to save the configuration file to.
+ config (`Optional[Dict[str, Any]]`, `optional`):
+ The configuration to save. If `None`, the current configuration will be
+ saved. Defaults to `None`.
+ config_file_name (`Optional[str]`, `optional`):
+ The name of the configuration file. Defaults to `config.yaml`.
+ save_weights (`bool`, `optional`):
+ Whether to save the weights of the model. Defaults to `False`.
+ push_to_hub (`bool`, `optional`):
+ Whether to push the saved model to the hub. Defaults to `False`.
+ model_id (`Optional[str]`, `optional`):
+ The id of the model to push to the hub. If `None`, the name of the
+ directory will be used. Defaults to `None`.
+ organization (`Optional[str]`, `optional`):
+ The organization to push the model to. Defaults to `None`.
+ repo_name (`Optional[str]`, `optional`):
+ The name of the repository to push the model to. Defaults to `None`.
+ **kwargs:
+ Additional keyword arguments to pass to `OmegaConf.save`.
+ """
+ if config is None:
+ # create a default config
+ config = {
+ "_target_": f"{self.__class__.__module__}.{self.__class__.__name__}"
+ }
+ if self.retriever is not None:
+ if self.retriever.question_encoder is not None:
+ config[
+ "question_encoder"
+ ] = self.retriever.question_encoder.name_or_path
+ if self.retriever.passage_encoder is not None:
+ config[
+ "passage_encoder"
+ ] = self.retriever.passage_encoder.name_or_path
+ if self.retriever.document_index is not None:
+ config["document_index"] = self.retriever.document_index.name_or_dir
+ if self.reader is not None:
+ config["reader"] = self.reader.model_path
+
+ config["retriever_kwargs"] = self.retriever_kwargs
+ config["reader_kwargs"] = self.reader_kwargs
+ # expand the fn as to be able to save it and load it later
+ config[
+ "candidates_preprocessing_fn"
+ ] = f"{self.candidates_preprocessing_fn.__module__}.{self.candidates_preprocessing_fn.__name__}"
+
+ # these are model-specific and should be saved
+ config["top_k"] = self.top_k
+ config["window_size"] = self.window_size
+ config["window_stride"] = self.window_stride
+
+ config_file_name = config_file_name or CONFIG_NAME
+
+ # create the output directory
+ output_dir = Path(output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ logger.info(f"Saving relik config to {output_dir / config_file_name}")
+ # pretty print the config
+ pprint(config, console=console_logger, expand_all=True)
+ OmegaConf.save(config, output_dir / config_file_name)
+
+ if save_weights:
+ model_id = model_id or output_dir.name
+ retriever_model_id = model_id + "-retriever"
+ # save weights
+ logger.info(f"Saving retriever to {output_dir / retriever_model_id}")
+ self.retriever.save_pretrained(
+ output_dir / retriever_model_id,
+ question_encoder_name=retriever_model_id + "-question-encoder",
+ passage_encoder_name=retriever_model_id + "-passage-encoder",
+ document_index_name=retriever_model_id + "-index",
+ push_to_hub=push_to_hub,
+ organization=organization,
+ repo_name=repo_name,
+ **kwargs,
+ )
+ reader_model_id = model_id + "-reader"
+ logger.info(f"Saving reader to {output_dir / reader_model_id}")
+ self.reader.save_pretrained(
+ output_dir / reader_model_id,
+ push_to_hub=push_to_hub,
+ organization=organization,
+ repo_name=repo_name,
+ **kwargs,
+ )
+
+ if push_to_hub:
+ # push to hub
+            logger.info("Pushing to hub")
+ model_id = model_id or output_dir.name
+ upload(output_dir, model_id, organization=organization, repo_name=repo_name)
+
+
+def main():
+ from pprint import pprint
+
+ relik = Relik(
+ question_encoder="riccorl/relik-retriever-aida-blink-pretrain-omniencoder",
+ document_index="riccorl/index-relik-retriever-aida-blink-pretrain-omniencoder",
+ reader="riccorl/relik-reader-aida-deberta-small",
+ device="cuda",
+ precision=16,
+ top_k=100,
+ window_size=32,
+ window_stride=16,
+ candidates_preprocessing_fn="relik.inference.preprocessing.wikipedia_title_and_openings_preprocessing",
+ )
+
+ input_text = """
+ Bernie Ecclestone, the former boss of Formula One, has admitted fraud after failing to declare more than £400m held in a trust in Singapore.
+ The 92-year-old billionaire did not disclose the trust to the government in July 2015.
+ Appearing at Southwark Crown Court on Thursday, he told the judge "I plead guilty" after having previously pleaded not guilty.
+ Ecclestone had been due to go on trial next month.
+ """
+
+ preds = relik(input_text)
+ pprint(preds)
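+
+    # Illustrative follow-up (a sketch, assuming the save method documented above is
+    # exposed as `save_pretrained`): persist the pipeline config, without weights,
+    # to an arbitrary directory.
+    save_config_to_disk = False  # flip to True to actually write the YAML file
+    if save_config_to_disk:
+        relik.save_pretrained("experiments/relik-aida-demo", save_weights=False)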
+
+
+if __name__ == "__main__":
+ main()
diff --git a/relik/inference/data/__init__.py b/relik/inference/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/relik/inference/data/objects.py b/relik/inference/data/objects.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b11e9641380b9e13d60de427827a73b70cbb9c1
--- /dev/null
+++ b/relik/inference/data/objects.py
@@ -0,0 +1,64 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import List, NamedTuple, Optional
+
+from relik.reader.pytorch_modules.hf.modeling_relik import RelikReaderSample
+
+
+@dataclass
+class Word:
+ """
+ A word representation that includes text, index in the sentence, POS tag, lemma,
+ dependency relation, and similar information.
+
+ # Parameters
+ text : `str`
+ The text representation.
+ index : `int`
+ The word offset in the sentence.
+ start_char : `int`, optional
+ The character offset at which this word starts in the original text.
+ end_char : `int`, optional
+ The character offset at which this word ends in the original text.
+ lemma : `str`, optional
+ The lemma of this word.
+ pos : `str`, optional
+ The coarse-grained part of speech of this word.
+ dep : `str`, optional
+ The dependency relation for this word.
+ head : `int`, optional
+ The index of the syntactic head of this word in the sentence.
+ """
+
+ text: str
+ index: int
+ start_char: Optional[int] = None
+ end_char: Optional[int] = None
+ # preprocessing fields
+ lemma: Optional[str] = None
+ pos: Optional[str] = None
+ dep: Optional[str] = None
+ head: Optional[int] = None
+
+ def __str__(self):
+ return self.text
+
+ def __repr__(self):
+ return self.__str__()
+
+
+class EntitySpan(NamedTuple):
+ start: int
+ end: int
+ label: str
+ text: str
+
+
+@dataclass
+class RelikOutput:
+ text: str
+ labels: List[EntitySpan]
+ windows: Optional[List[RelikReaderSample]] = None
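+
+
+if __name__ == "__main__":
+    # Small illustrative sketch of the containers above (values are made up).
+    word = Word(text="Obama", index=0, start_char=0, end_char=5)
+    span = EntitySpan(start=0, end=5, label="Barack Obama", text="Obama")
+    output = RelikOutput(text="Obama went to Rome.", labels=[span])
+    print(word, span, output)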
diff --git a/relik/inference/data/tokenizers/__init__.py b/relik/inference/data/tokenizers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad70314e8e0ccc18b946ff1317f6415c1892747a
--- /dev/null
+++ b/relik/inference/data/tokenizers/__init__.py
@@ -0,0 +1,89 @@
+SPACY_LANGUAGE_MAPPER = {
+ "ca": "ca_core_news_sm",
+ "da": "da_core_news_sm",
+ "de": "de_core_news_sm",
+ "el": "el_core_news_sm",
+ "en": "en_core_web_sm",
+ "es": "es_core_news_sm",
+ "fr": "fr_core_news_sm",
+ "it": "it_core_news_sm",
+ "ja": "ja_core_news_sm",
+ "lt": "lt_core_news_sm",
+ "mk": "mk_core_news_sm",
+ "nb": "nb_core_news_sm",
+ "nl": "nl_core_news_sm",
+ "pl": "pl_core_news_sm",
+ "pt": "pt_core_news_sm",
+ "ro": "ro_core_news_sm",
+ "ru": "ru_core_news_sm",
+ "xx": "xx_sent_ud_sm",
+ "zh": "zh_core_web_sm",
+ "ca_core_news_sm": "ca_core_news_sm",
+ "ca_core_news_md": "ca_core_news_md",
+ "ca_core_news_lg": "ca_core_news_lg",
+ "ca_core_news_trf": "ca_core_news_trf",
+ "da_core_news_sm": "da_core_news_sm",
+ "da_core_news_md": "da_core_news_md",
+ "da_core_news_lg": "da_core_news_lg",
+ "da_core_news_trf": "da_core_news_trf",
+ "de_core_news_sm": "de_core_news_sm",
+ "de_core_news_md": "de_core_news_md",
+ "de_core_news_lg": "de_core_news_lg",
+ "de_dep_news_trf": "de_dep_news_trf",
+ "el_core_news_sm": "el_core_news_sm",
+ "el_core_news_md": "el_core_news_md",
+ "el_core_news_lg": "el_core_news_lg",
+ "en_core_web_sm": "en_core_web_sm",
+ "en_core_web_md": "en_core_web_md",
+ "en_core_web_lg": "en_core_web_lg",
+ "en_core_web_trf": "en_core_web_trf",
+ "es_core_news_sm": "es_core_news_sm",
+ "es_core_news_md": "es_core_news_md",
+ "es_core_news_lg": "es_core_news_lg",
+ "es_dep_news_trf": "es_dep_news_trf",
+ "fr_core_news_sm": "fr_core_news_sm",
+ "fr_core_news_md": "fr_core_news_md",
+ "fr_core_news_lg": "fr_core_news_lg",
+ "fr_dep_news_trf": "fr_dep_news_trf",
+ "it_core_news_sm": "it_core_news_sm",
+ "it_core_news_md": "it_core_news_md",
+ "it_core_news_lg": "it_core_news_lg",
+ "ja_core_news_sm": "ja_core_news_sm",
+ "ja_core_news_md": "ja_core_news_md",
+ "ja_core_news_lg": "ja_core_news_lg",
+ "ja_dep_news_trf": "ja_dep_news_trf",
+ "lt_core_news_sm": "lt_core_news_sm",
+ "lt_core_news_md": "lt_core_news_md",
+ "lt_core_news_lg": "lt_core_news_lg",
+ "mk_core_news_sm": "mk_core_news_sm",
+ "mk_core_news_md": "mk_core_news_md",
+ "mk_core_news_lg": "mk_core_news_lg",
+ "nb_core_news_sm": "nb_core_news_sm",
+ "nb_core_news_md": "nb_core_news_md",
+ "nb_core_news_lg": "nb_core_news_lg",
+ "nl_core_news_sm": "nl_core_news_sm",
+ "nl_core_news_md": "nl_core_news_md",
+ "nl_core_news_lg": "nl_core_news_lg",
+ "pl_core_news_sm": "pl_core_news_sm",
+ "pl_core_news_md": "pl_core_news_md",
+ "pl_core_news_lg": "pl_core_news_lg",
+ "pt_core_news_sm": "pt_core_news_sm",
+ "pt_core_news_md": "pt_core_news_md",
+ "pt_core_news_lg": "pt_core_news_lg",
+ "ro_core_news_sm": "ro_core_news_sm",
+ "ro_core_news_md": "ro_core_news_md",
+ "ro_core_news_lg": "ro_core_news_lg",
+ "ru_core_news_sm": "ru_core_news_sm",
+ "ru_core_news_md": "ru_core_news_md",
+ "ru_core_news_lg": "ru_core_news_lg",
+ "xx_ent_wiki_sm": "xx_ent_wiki_sm",
+ "xx_sent_ud_sm": "xx_sent_ud_sm",
+ "zh_core_web_sm": "zh_core_web_sm",
+ "zh_core_web_md": "zh_core_web_md",
+ "zh_core_web_lg": "zh_core_web_lg",
+ "zh_core_web_trf": "zh_core_web_trf",
+}
+
+from relik.inference.data.tokenizers.regex_tokenizer import RegexTokenizer
+from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer
+from relik.inference.data.tokenizers.whitespace_tokenizer import WhitespaceTokenizer
diff --git a/relik/inference/data/tokenizers/base_tokenizer.py b/relik/inference/data/tokenizers/base_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..1fed161b3eca085656e85d44cb9a64739f3d1e4c
--- /dev/null
+++ b/relik/inference/data/tokenizers/base_tokenizer.py
@@ -0,0 +1,84 @@
+from typing import List, Union
+
+from relik.inference.data.objects import Word
+
+
+class BaseTokenizer:
+ """
+ A :obj:`Tokenizer` splits strings of text into single words, optionally adds
+ pos tags and perform lemmatization.
+ """
+
+ def __call__(
+ self,
+ texts: Union[str, List[str], List[List[str]]],
+ is_split_into_words: bool = False,
+ **kwargs
+ ) -> List[List[Word]]:
+ """
+ Tokenize the input into single words.
+
+ Args:
+ texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
+ Text to tag. It can be a single string, a batch of strings, or pre-tokenized strings.
+ is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`):
+ If :obj:`True` and the input is a string, the input is split on spaces.
+
+ Returns:
+ :obj:`List[List[Word]]`: The input text tokenized in single words.
+ """
+ raise NotImplementedError
+
+ def tokenize(self, text: str) -> List[Word]:
+ """
+ Implements splitting words into tokens.
+
+ Args:
+ text (:obj:`str`):
+ Text to tokenize.
+
+ Returns:
+ :obj:`List[Word]`: The input text tokenized in single words.
+
+ """
+ raise NotImplementedError
+
+ def tokenize_batch(self, texts: List[str]) -> List[List[Word]]:
+ """
+ Implements batch splitting words into tokens.
+
+ Args:
+ texts (:obj:`List[str]`):
+ Batch of text to tokenize.
+
+ Returns:
+ :obj:`List[List[Word]]`: The input batch tokenized in single words.
+
+ """
+ return [self.tokenize(text) for text in texts]
+
+ @staticmethod
+ def check_is_batched(
+ texts: Union[str, List[str], List[List[str]]], is_split_into_words: bool
+ ):
+ """
+ Check if input is batched or a single sample.
+
+ Args:
+ texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
+ Text to check.
+ is_split_into_words (:obj:`bool`):
+ If :obj:`True` and the input is a string, the input is split on spaces.
+
+ Returns:
+ :obj:`bool`: ``True`` if ``texts`` is batched, ``False`` otherwise.
+ """
+ return bool(
+ (not is_split_into_words and isinstance(texts, (list, tuple)))
+ or (
+ is_split_into_words
+ and isinstance(texts, (list, tuple))
+ and texts
+ and isinstance(texts[0], (list, tuple))
+ )
+ )
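+
+
+if __name__ == "__main__":
+    # Illustrative check of the batching heuristic above: a list of raw strings is a
+    # batch, while a single pre-tokenized sentence (a list of tokens) is not.
+    print(BaseTokenizer.check_is_batched(["Mary sold the car.", "John bought it."], False))  # True
+    print(BaseTokenizer.check_is_batched(["Mary", "sold", "the", "car", "."], True))  # False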
diff --git a/relik/inference/data/tokenizers/regex_tokenizer.py b/relik/inference/data/tokenizers/regex_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebe8656afb891a8318a7030375427e190d1dc383
--- /dev/null
+++ b/relik/inference/data/tokenizers/regex_tokenizer.py
@@ -0,0 +1,73 @@
+import re
+from typing import List, Union
+
+from overrides import overrides
+
+from relik.inference.data.objects import Word
+from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer
+
+
+class RegexTokenizer(BaseTokenizer):
+ """
+ A :obj:`Tokenizer` that splits the text based on a simple regex.
+ """
+
+ def __init__(self):
+ super(RegexTokenizer, self).__init__()
+ # regex for splitting on spaces and punctuation and new lines
+ # self._regex = re.compile(r"\S+|[\[\](),.!?;:\"]|\\n")
+ self._regex = re.compile(
+ r"\w+|\$[\d\.]+|\S+", re.UNICODE | re.MULTILINE | re.DOTALL
+ )
+
+ def __call__(
+ self,
+ texts: Union[str, List[str], List[List[str]]],
+ is_split_into_words: bool = False,
+ **kwargs,
+ ) -> List[List[Word]]:
+ """
+ Tokenize the input into single words by splitting using a simple regex.
+
+ Args:
+ texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
+ Text to tag. It can be a single string, a batch of strings, or pre-tokenized strings.
+ is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`):
+ If :obj:`True` and the input is a string, the input is split on spaces.
+
+ Returns:
+ :obj:`List[List[Word]]`: The input text tokenized in single words.
+
+ Example::
+
+ >>> from relik.inference.data.tokenizers.regex_tokenizer import RegexTokenizer
+
+ >>> regex_tokenizer = RegexTokenizer()
+ >>> regex_tokenizer("Mary sold the car to John.")
+
+ """
+ # check if input is batched or a single sample
+ is_batched = self.check_is_batched(texts, is_split_into_words)
+
+ if is_batched:
+ tokenized = self.tokenize_batch(texts)
+ else:
+ tokenized = self.tokenize(texts)
+
+ return tokenized
+
+ @overrides
+ def tokenize(self, text: Union[str, List[str]]) -> List[Word]:
+ if not isinstance(text, (str, list)):
+ raise ValueError(
+ f"text must be either `str` or `list`, found: `{type(text)}`"
+ )
+
+ if isinstance(text, list):
+ text = " ".join(text)
+ return [
+ Word(t[0], i, start_char=t[1], end_char=t[2])
+ for i, t in enumerate(
+ (m.group(0), m.start(), m.end()) for m in self._regex.finditer(text)
+ )
+ ]
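+
+
+if __name__ == "__main__":
+    # Quick illustrative run: punctuation becomes its own token and every Word keeps
+    # the character offsets computed from the regex matches above.
+    for word in RegexTokenizer()("Mary sold the car to John."):
+        print(word.text, word.start_char, word.end_char)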
diff --git a/relik/inference/data/tokenizers/spacy_tokenizer.py b/relik/inference/data/tokenizers/spacy_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..b949216ed5cf152ae4a7722c4a6be3f883481db2
--- /dev/null
+++ b/relik/inference/data/tokenizers/spacy_tokenizer.py
@@ -0,0 +1,228 @@
+import logging
+from typing import Dict, List, Tuple, Union
+
+import spacy
+
+# from ipa.common.utils import load_spacy
+from overrides import overrides
+from spacy.cli.download import download as spacy_download
+from spacy.tokens import Doc
+
+from relik.common.log import get_logger
+from relik.inference.data.objects import Word
+from relik.inference.data.tokenizers import SPACY_LANGUAGE_MAPPER
+from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer
+
+logger = get_logger(level=logging.DEBUG)
+
+# Spacy and Stanza stuff
+
+LOADED_SPACY_MODELS: Dict[Tuple[str, bool, bool, bool, bool], spacy.Language] = {}
+
+
+def load_spacy(
+ language: str,
+ pos_tags: bool = False,
+ lemma: bool = False,
+ parse: bool = False,
+ split_on_spaces: bool = False,
+) -> spacy.Language:
+ """
+ Download and load spacy model.
+
+ Args:
+ language (:obj:`str`, defaults to :obj:`en`):
+ Language of the text to tokenize.
+ pos_tags (:obj:`bool`, optional, defaults to :obj:`False`):
+ If :obj:`True`, performs POS tagging with spacy model.
+ lemma (:obj:`bool`, optional, defaults to :obj:`False`):
+ If :obj:`True`, performs lemmatization with spacy model.
+ parse (:obj:`bool`, optional, defaults to :obj:`False`):
+ If :obj:`True`, performs dependency parsing with spacy model.
+ split_on_spaces (:obj:`bool`, optional, defaults to :obj:`False`):
+ If :obj:`True`, will split by spaces without performing tokenization.
+
+ Returns:
+ :obj:`spacy.Language`: The spacy model loaded.
+ """
+ exclude = ["vectors", "textcat", "ner"]
+ if not pos_tags:
+ exclude.append("tagger")
+ if not lemma:
+ exclude.append("lemmatizer")
+ if not parse:
+ exclude.append("parser")
+
+ # check if the model is already loaded
+ # if so, there is no need to reload it
+ spacy_params = (language, pos_tags, lemma, parse, split_on_spaces)
+ if spacy_params not in LOADED_SPACY_MODELS:
+ try:
+ spacy_tagger = spacy.load(language, exclude=exclude)
+ except OSError:
+ logger.warning(
+ "Spacy model '%s' not found. Downloading and installing.", language
+ )
+ spacy_download(language)
+ spacy_tagger = spacy.load(language, exclude=exclude)
+
+ # if everything is disabled, return only the tokenizer
+ # for faster tokenization
+ # TODO: is it really faster?
+ # if len(exclude) >= 6:
+ # spacy_tagger = spacy_tagger.tokenizer
+ LOADED_SPACY_MODELS[spacy_params] = spacy_tagger
+
+ return LOADED_SPACY_MODELS[spacy_params]
+
+
+class SpacyTokenizer(BaseTokenizer):
+ """
+ A :obj:`Tokenizer` that uses spaCy to tokenize and preprocess the text. It returns :obj:`Word` objects.
+
+ Args:
+ language (:obj:`str`, optional, defaults to :obj:`en`):
+ Language of the text to tokenize.
+ return_pos_tags (:obj:`bool`, optional, defaults to :obj:`False`):
+ If :obj:`True`, performs POS tagging with spacy model.
+ return_lemmas (:obj:`bool`, optional, defaults to :obj:`False`):
+ If :obj:`True`, performs lemmatization with spacy model.
+ return_deps (:obj:`bool`, optional, defaults to :obj:`False`):
+ If :obj:`True`, performs dependency parsing with spacy model.
+ split_on_spaces (:obj:`bool`, optional, defaults to :obj:`False`):
+ If :obj:`True`, will split by spaces without performing tokenization.
+ use_gpu (:obj:`bool`, optional, defaults to :obj:`False`):
+ If :obj:`True`, will load the spaCy model on GPU.
+ """
+
+ def __init__(
+ self,
+ language: str = "en",
+ return_pos_tags: bool = False,
+ return_lemmas: bool = False,
+ return_deps: bool = False,
+ split_on_spaces: bool = False,
+ use_gpu: bool = False,
+ ):
+ super(SpacyTokenizer, self).__init__()
+ if language not in SPACY_LANGUAGE_MAPPER:
+ raise ValueError(
+ f"`{language}` language not supported. The supported "
+ f"languages are: {list(SPACY_LANGUAGE_MAPPER.keys())}."
+ )
+ if use_gpu:
+ # load the model on GPU
+ # if the GPU is not available or not correctly configured,
+ # it will raise an error
+ spacy.require_gpu()
+ self.spacy = load_spacy(
+ SPACY_LANGUAGE_MAPPER[language],
+ return_pos_tags,
+ return_lemmas,
+ return_deps,
+ split_on_spaces,
+ )
+ self.split_on_spaces = split_on_spaces
+
+ def __call__(
+ self,
+ texts: Union[str, List[str], List[List[str]]],
+ is_split_into_words: bool = False,
+ **kwargs,
+ ) -> Union[List[Word], List[List[Word]]]:
+ """
+ Tokenize the input into single words using SpaCy models.
+
+ Args:
+ texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
+ Text to tag. It can be a single string, a batch of strings, or pre-tokenized strings.
+ is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`):
+ If :obj:`True` and the input is a string, the input is split on spaces.
+
+ Returns:
+ :obj:`List[List[Word]]`: The input text tokenized in single words.
+
+ Example::
+
+ >>> from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer
+
+ >>> spacy_tokenizer = SpacyTokenizer(language="en", return_pos_tags=True, return_lemmas=True)
+ >>> spacy_tokenizer("Mary sold the car to John.")
+
+ """
+ # check if input is batched or a single sample
+ is_batched = self.check_is_batched(texts, is_split_into_words)
+ if is_batched:
+ tokenized = self.tokenize_batch(texts)
+ else:
+ tokenized = self.tokenize(texts)
+ return tokenized
+
+ @overrides
+ def tokenize(self, text: Union[str, List[str]]) -> List[Word]:
+ if self.split_on_spaces:
+ if isinstance(text, str):
+ text = text.split(" ")
+ spaces = [True] * len(text)
+ text = Doc(self.spacy.vocab, words=text, spaces=spaces)
+ return self._clean_tokens(self.spacy(text))
+
+ @overrides
+ def tokenize_batch(
+ self, texts: Union[List[str], List[List[str]]]
+ ) -> List[List[Word]]:
+ if self.split_on_spaces:
+ if isinstance(texts[0], str):
+ texts = [text.split(" ") for text in texts]
+ spaces = [[True] * len(text) for text in texts]
+ texts = [
+ Doc(self.spacy.vocab, words=text, spaces=space)
+ for text, space in zip(texts, spaces)
+ ]
+ return [self._clean_tokens(tokens) for tokens in self.spacy.pipe(texts)]
+
+ @staticmethod
+ def _clean_tokens(tokens: Doc) -> List[Word]:
+ """
+ Converts spaCy tokens to :obj:`Word`.
+
+ Args:
+ tokens (:obj:`spacy.tokens.Doc`):
+ Tokens from SpaCy model.
+
+ Returns:
+ :obj:`List[Word]`: The SpaCy model output converted into :obj:`Word` objects.
+ """
+ words = [
+ Word(
+ token.text,
+ token.i,
+ token.idx,
+ token.idx + len(token),
+ token.lemma_,
+ token.pos_,
+ token.dep_,
+ token.head.i,
+ )
+ for token in tokens
+ ]
+ return words
+
+
+class WhitespaceSpacyTokenizer:
+ """Simple white space tokenizer for SpaCy."""
+
+ def __init__(self, vocab):
+ self.vocab = vocab
+
+ def __call__(self, text):
+ if isinstance(text, str):
+ words = text.split(" ")
+ elif isinstance(text, list):
+ words = text
+ else:
+ raise ValueError(
+ f"text must be either `str` or `list`, found: `{type(text)}`"
+ )
+ spaces = [True] * len(words)
+ return Doc(self.vocab, words=words, spaces=spaces)
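+
+
+if __name__ == "__main__":
+    # Illustrative usage (downloads `en_core_web_sm` on the first run): POS tags and
+    # lemmas are filled in because the corresponding flags are enabled.
+    spacy_tokenizer = SpacyTokenizer(language="en", return_pos_tags=True, return_lemmas=True)
+    for word in spacy_tokenizer("Mary sold the car to John."):
+        print(word.text, word.pos, word.lemma)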
diff --git a/relik/inference/data/tokenizers/whitespace_tokenizer.py b/relik/inference/data/tokenizers/whitespace_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..537ab6fe21eb4f9378d96d7cebfbc8cb12c36104
--- /dev/null
+++ b/relik/inference/data/tokenizers/whitespace_tokenizer.py
@@ -0,0 +1,70 @@
+import re
+from typing import List, Union
+
+from overrides import overrides
+
+from relik.inference.data.objects import Word
+from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer
+
+
+class WhitespaceTokenizer(BaseTokenizer):
+ """
+ A :obj:`Tokenizer` that splits the text on spaces.
+ """
+
+ def __init__(self):
+ super(WhitespaceTokenizer, self).__init__()
+ self.whitespace_regex = re.compile(r"\S+")
+
+ def __call__(
+ self,
+ texts: Union[str, List[str], List[List[str]]],
+ is_split_into_words: bool = False,
+ **kwargs,
+ ) -> List[List[Word]]:
+ """
+ Tokenize the input into single words by splitting on spaces.
+
+ Args:
+ texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
+ Text to tag. It can be a single string, a batch of strings, or pre-tokenized strings.
+ is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`):
+ If :obj:`True` and the input is a string, the input is split on spaces.
+
+ Returns:
+ :obj:`List[List[Word]]`: The input text tokenized in single words.
+
+ Example::
+
+ >>> from relik.inference.data.tokenizers.whitespace_tokenizer import WhitespaceTokenizer
+
+ >>> whitespace_tokenizer = WhitespaceTokenizer()
+ >>> whitespace_tokenizer("Mary sold the car to John .")
+
+ """
+ # check if input is batched or a single sample
+ is_batched = self.check_is_batched(texts, is_split_into_words)
+
+ if is_batched:
+ tokenized = self.tokenize_batch(texts)
+ else:
+ tokenized = self.tokenize(texts)
+
+ return tokenized
+
+ @overrides
+ def tokenize(self, text: Union[str, List[str]]) -> List[Word]:
+ if not isinstance(text, (str, list)):
+ raise ValueError(
+ f"text must be either `str` or `list`, found: `{type(text)}`"
+ )
+
+ if isinstance(text, list):
+ text = " ".join(text)
+ return [
+ Word(t[0], i, start_char=t[1], end_char=t[2])
+ for i, t in enumerate(
+ (m.group(0), m.start(), m.end())
+ for m in self.whitespace_regex.finditer(text)
+ )
+ ]
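+
+
+if __name__ == "__main__":
+    # Illustrative run: the input is split purely on whitespace, so the trailing
+    # period stays attached to "John." as a single token.
+    print([w.text for w in WhitespaceTokenizer()("Mary sold the car to John.")])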
diff --git a/relik/inference/data/window/__init__.py b/relik/inference/data/window/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/relik/inference/data/window/manager.py b/relik/inference/data/window/manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..420609b1827f13bb332780554e3e20421908f6e9
--- /dev/null
+++ b/relik/inference/data/window/manager.py
@@ -0,0 +1,262 @@
+import collections
+import itertools
+from dataclasses import dataclass
+from typing import List, Optional, Set, Tuple
+
+from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer
+from relik.reader.data.relik_reader_sample import RelikReaderSample
+
+
+@dataclass
+class Window:
+ doc_id: int
+ window_id: int
+ text: str
+ tokens: List[str]
+ doc_topic: Optional[str]
+ offset: int
+ token2char_start: dict
+ token2char_end: dict
+ window_candidates: Optional[List[str]] = None
+
+
+class WindowManager:
+ def __init__(self, tokenizer: BaseTokenizer) -> None:
+ self.tokenizer = tokenizer
+
+ def tokenize(self, document: str) -> Tuple[List[str], List[Tuple[int, int]]]:
+ tokenized_document = self.tokenizer(document)
+ tokens = []
+ tokens_char_mapping = []
+ for token in tokenized_document:
+ tokens.append(token.text)
+ tokens_char_mapping.append((token.start_char, token.end_char))
+ return tokens, tokens_char_mapping
+
+ def create_windows(
+ self,
+ document: str,
+ window_size: int,
+ stride: int,
+ doc_id: int = 0,
+ doc_topic: Optional[str] = None,
+ ) -> List[RelikReaderSample]:
+ document_tokens, tokens_char_mapping = self.tokenize(document)
+ if doc_topic is None:
+ doc_topic = document_tokens[0] if len(document_tokens) > 0 else ""
+ document_windows = []
+ if len(document_tokens) <= window_size:
+ text = document
+ # relik_reader_sample = RelikReaderSample()
+ document_windows.append(
+ # Window(
+ RelikReaderSample(
+ doc_id=doc_id,
+ window_id=0,
+ text=text,
+ tokens=document_tokens,
+ doc_topic=doc_topic,
+ offset=0,
+ token2char_start={
+ str(i): tokens_char_mapping[i][0]
+ for i in range(len(document_tokens))
+ },
+ token2char_end={
+ str(i): tokens_char_mapping[i][1]
+ for i in range(len(document_tokens))
+ },
+ )
+ )
+ else:
+ for window_id, i in enumerate(range(0, len(document_tokens), stride)):
+ # if the last stride is smaller than the window size, then we can
+ # include more tokens from the previous window.
+ if i != 0 and i + window_size > len(document_tokens):
+ overflowing_tokens = i + window_size - len(document_tokens)
+ if overflowing_tokens >= stride:
+ break
+ i -= overflowing_tokens
+
+ involved_token_indices = list(
+ range(i, min(i + window_size, len(document_tokens) - 1))
+ )
+ window_tokens = [document_tokens[j] for j in involved_token_indices]
+ window_text_start = tokens_char_mapping[involved_token_indices[0]][0]
+ window_text_end = tokens_char_mapping[involved_token_indices[-1]][1]
+ text = document[window_text_start:window_text_end]
+ document_windows.append(
+ # Window(
+ RelikReaderSample(
+ # dict(
+ doc_id=doc_id,
+ window_id=window_id,
+ text=text,
+ tokens=window_tokens,
+ doc_topic=doc_topic,
+ offset=window_text_start,
+ token2char_start={
+ str(i): tokens_char_mapping[ti][0]
+ for i, ti in enumerate(involved_token_indices)
+ },
+ token2char_end={
+ str(i): tokens_char_mapping[ti][1]
+ for i, ti in enumerate(involved_token_indices)
+ },
+ # )
+ )
+ )
+ return document_windows
+
+ def merge_windows(
+ self, windows: List[RelikReaderSample]
+ ) -> List[RelikReaderSample]:
+ windows_by_doc_id = collections.defaultdict(list)
+ for window in windows:
+ windows_by_doc_id[window.doc_id].append(window)
+
+ merged_window_by_doc = {
+ doc_id: self.merge_doc_windows(doc_windows)
+ for doc_id, doc_windows in windows_by_doc_id.items()
+ }
+
+ return list(merged_window_by_doc.values())
+
+ def merge_doc_windows(self, windows: List[RelikReaderSample]) -> RelikReaderSample:
+ if len(windows) == 1:
+ return windows[0]
+
+ if len(windows) > 0 and getattr(windows[0], "offset", None) is not None:
+ windows = sorted(windows, key=(lambda x: x.offset))
+
+ window_accumulator = windows[0]
+
+ for next_window in windows[1:]:
+ window_accumulator = self._merge_window_pair(
+ window_accumulator, next_window
+ )
+
+ return window_accumulator
+
+ def _merge_tokens(
+ self, window1: RelikReaderSample, window2: RelikReaderSample
+ ) -> Tuple[list, dict, dict]:
+ w1_tokens = window1.tokens[1:-1]
+ w2_tokens = window2.tokens[1:-1]
+
+ # find intersection
+ tokens_intersection = None
+ for k in reversed(range(1, len(w1_tokens))):
+ if w1_tokens[-k:] == w2_tokens[:k]:
+ tokens_intersection = k
+ break
+ assert tokens_intersection is not None, (
+ f"{window1.doc_id} - {window1.sent_id} - {window1.offset}"
+ + f" {window2.doc_id} - {window2.sent_id} - {window2.offset}\n"
+ + f"w1 tokens: {w1_tokens}\n"
+ + f"w2 tokens: {w2_tokens}\n"
+ )
+
+ final_tokens = (
+ [window1.tokens[0]] # CLS
+ + w1_tokens
+ + w2_tokens[tokens_intersection:]
+ + [window1.tokens[-1]] # SEP
+ )
+
+ w2_starting_offset = len(w1_tokens) - tokens_intersection
+
+ def merge_char_mapping(t2c1: dict, t2c2: dict) -> dict:
+ final_t2c = dict()
+ final_t2c.update(t2c1)
+ for t, c in t2c2.items():
+ t = int(t)
+ if t < tokens_intersection:
+ continue
+ final_t2c[str(t + w2_starting_offset)] = c
+ return final_t2c
+
+ return (
+ final_tokens,
+ merge_char_mapping(window1.token2char_start, window2.token2char_start),
+ merge_char_mapping(window1.token2char_end, window2.token2char_end),
+ )
+
+ def _merge_span_annotation(
+ self, span_annotation1: List[list], span_annotation2: List[list]
+ ) -> List[list]:
+ uniq_store = set()
+ final_span_annotation_store = []
+ for span_annotation in itertools.chain(span_annotation1, span_annotation2):
+ span_annotation_id = tuple(span_annotation)
+ if span_annotation_id not in uniq_store:
+ uniq_store.add(span_annotation_id)
+ final_span_annotation_store.append(span_annotation)
+ return sorted(final_span_annotation_store, key=lambda x: x[0])
+
+ def _merge_predictions(
+ self,
+ window1: RelikReaderSample,
+ window2: RelikReaderSample,
+ ) -> Tuple[Set[Tuple[int, int, str]], dict]:
+ merged_predictions = window1.predicted_window_labels_chars.union(
+ window2.predicted_window_labels_chars
+ )
+
+ span_title_probabilities = dict()
+ # probabilities
+ for span_prediction, predicted_probs in itertools.chain(
+ window1.probs_window_labels_chars.items(),
+ window2.probs_window_labels_chars.items(),
+ ):
+ if span_prediction not in span_title_probabilities:
+ span_title_probabilities[span_prediction] = predicted_probs
+
+ return merged_predictions, span_title_probabilities
+
+ def _merge_window_pair(
+ self,
+ window1: RelikReaderSample,
+ window2: RelikReaderSample,
+ ) -> RelikReaderSample:
+ merging_output = dict()
+
+ if getattr(window1, "doc_id", None) is not None:
+ assert window1.doc_id == window2.doc_id
+
+ if getattr(window1, "offset", None) is not None:
+ assert (
+ window1.offset < window2.offset
+ ), f"window 2 offset ({window2.offset}) is smaller that window 1 offset({window1.offset})"
+
+ merging_output["doc_id"] = window1.doc_id
+ merging_output["offset"] = window2.offset
+
+ m_tokens, m_token2char_start, m_token2char_end = self._merge_tokens(
+ window1, window2
+ )
+
+ window_labels = None
+ if getattr(window1, "window_labels", None) is not None:
+ window_labels = self._merge_span_annotation(
+ window1.window_labels, window2.window_labels
+ )
+ (
+ predicted_window_labels_chars,
+ probs_window_labels_chars,
+ ) = self._merge_predictions(
+ window1,
+ window2,
+ )
+
+ merging_output.update(
+ dict(
+ tokens=m_tokens,
+ token2char_start=m_token2char_start,
+ token2char_end=m_token2char_end,
+ window_labels=window_labels,
+ predicted_window_labels_chars=predicted_window_labels_chars,
+ probs_window_labels_chars=probs_window_labels_chars,
+ )
+ )
+
+ return RelikReaderSample(**merging_output)
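+
+
+if __name__ == "__main__":
+    # Illustrative sketch: split a short document into overlapping windows with the
+    # whitespace tokenizer; window_size and stride are arbitrary small values.
+    from relik.inference.data.tokenizers.whitespace_tokenizer import WhitespaceTokenizer
+
+    manager = WindowManager(tokenizer=WhitespaceTokenizer())
+    windows = manager.create_windows(
+        "one two three four five six seven eight nine ten",
+        window_size=4,
+        stride=2,
+    )
+    for window in windows:
+        print(window.window_id, window.tokens)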
diff --git a/relik/inference/gerbil.py b/relik/inference/gerbil.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4c3f17cacea1d5472de99d1a974ad098585fc20
--- /dev/null
+++ b/relik/inference/gerbil.py
@@ -0,0 +1,254 @@
+import argparse
+import json
+import os
+import re
+import sys
+from http.server import BaseHTTPRequestHandler, HTTPServer
+from typing import Iterator, List, Optional, Tuple
+
+from relik.inference.annotator import Relik
+from relik.inference.data.objects import RelikOutput
+
+# sys.path += ['../']
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
+
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class GerbilAlbyManager:
+ def __init__(
+ self,
+ annotator: Optional[Relik] = None,
+ response_logger_dir: Optional[str] = None,
+ ) -> None:
+ self.annotator = annotator
+ self.response_logger_dir = response_logger_dir
+ self.predictions_counter = 0
+ self.labels_mapping = None
+
+ def annotate(self, document: str):
+ relik_output: RelikOutput = self.annotator(document)
+ annotations = [(ss, se, l) for ss, se, l, _ in relik_output.labels]
+ if self.labels_mapping is not None:
+ return [
+ (ss, se, self.labels_mapping.get(l, l)) for ss, se, l in annotations
+ ]
+ return annotations
+
+ def set_mapping_file(self, mapping_file_path: str):
+ with open(mapping_file_path) as f:
+ labels_mapping = json.load(f)
+ self.labels_mapping = {v: k for k, v in labels_mapping.items()}
+
+ def write_response_bundle(
+ self,
+ document: str,
+ new_document: str,
+ annotations: list,
+ mapped_annotations: list,
+ ) -> None:
+ if self.response_logger_dir is None:
+ return
+
+ if not os.path.isdir(self.response_logger_dir):
+ os.mkdir(self.response_logger_dir)
+
+ with open(
+ f"{self.response_logger_dir}/{self.predictions_counter}.json", "w"
+ ) as f:
+ out_json_obj = dict(
+ document=document,
+ new_document=new_document,
+ annotations=annotations,
+ mapped_annotations=mapped_annotations,
+ )
+
+ out_json_obj["span_annotations"] = [
+ (ss, se, document[ss:se], label) for (ss, se, label) in annotations
+ ]
+
+ out_json_obj["span_mapped_annotations"] = [
+ (ss, se, new_document[ss:se], label)
+ for (ss, se, label) in mapped_annotations
+ ]
+
+ json.dump(out_json_obj, f, indent=2)
+
+ self.predictions_counter += 1
+
+
+manager = GerbilAlbyManager()
+
+
+def preprocess_document(document: str) -> Tuple[str, List[Tuple[int, int]]]:
+ pattern_subs = {
+ "-LPR- ": " (",
+ "-RPR-": ")",
+ "\n\n": "\n",
+ "-LRB-": "(",
+ "-RRB-": ")",
+ '","': ",",
+ }
+
+ document_acc = document
+ curr_offset = 0
+ char2offset = []
+
+ matchings = re.finditer("({})".format("|".join(pattern_subs)), document)
+ for span_matching in sorted(matchings, key=lambda x: x.span()[0]):
+ span_start, span_end = span_matching.span()
+ span_start -= curr_offset
+ span_end -= curr_offset
+
+ span_text = document_acc[span_start:span_end]
+ span_sub = pattern_subs[span_text]
+ document_acc = document_acc[:span_start] + span_sub + document_acc[span_end:]
+
+ offset = len(span_text) - len(span_sub)
+ curr_offset += offset
+
+ char2offset.append((span_start + len(span_sub), curr_offset))
+
+ return document_acc, char2offset
+
+
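+# Illustrative example of the round trip implemented by preprocess_document above
+# and map_back_annotations below: "Foo -LRB-Bar-RRB-" is cleaned to "Foo (Bar)",
+# and char2offset records the cumulative number of removed characters at each
+# substitution point, so spans predicted on the cleaned text can be shifted back
+# onto the original document.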
+def map_back_annotations(
+ annotations: List[Tuple[int, int, str]], char_mapping: List[Tuple[int, int]]
+) -> Iterator[Tuple[int, int, str]]:
+ def map_char(char_idx: int) -> int:
+ current_offset = 0
+ for offset_idx, offset_value in char_mapping:
+ if char_idx >= offset_idx:
+ current_offset = offset_value
+ else:
+ break
+ return char_idx + current_offset
+
+ for ss, se, label in annotations:
+ yield map_char(ss), map_char(se), label
+
+
+def annotate(document: str) -> List[Tuple[int, int, str]]:
+ new_document, mapping = preprocess_document(document)
+ logger.info("Mapping: " + str(mapping))
+ logger.info("Document: " + str(document))
+ annotations = [
+ (cs, ce, label.replace(" ", "_"))
+ for cs, ce, label in manager.annotate(new_document)
+ ]
+ logger.info("New document: " + str(new_document))
+ mapped_annotations = (
+ list(map_back_annotations(annotations, mapping))
+ if len(mapping) > 0
+ else annotations
+ )
+
+ logger.info(
+ "Annotations: "
+ + str([(ss, se, document[ss:se], ann) for ss, se, ann in mapped_annotations])
+ )
+
+ manager.write_response_bundle(
+ document, new_document, mapped_annotations, annotations
+ )
+
+ if not all(
+ [
+ new_document[ss:se] == document[mss:mse]
+ for (mss, mse, _), (ss, se, _) in zip(mapped_annotations, annotations)
+ ]
+ ):
+ diff_mappings = [
+ (new_document[ss:se], document[mss:mse])
+ for (mss, mse, _), (ss, se, _) in zip(mapped_annotations, annotations)
+ ]
+ return None
+ assert all(
+ [
+ document[mss:mse] == new_document[ss:se]
+ for (mss, mse, _), (ss, se, _) in zip(mapped_annotations, annotations)
+ ]
+ ), (mapped_annotations, annotations)
+
+ return [(cs, ce - cs, label) for cs, ce, label in mapped_annotations]
+
+
+class GetHandler(BaseHTTPRequestHandler):
+ def do_POST(self):
+ content_length = int(self.headers["Content-Length"])
+ post_data = self.rfile.read(content_length)
+ self.send_response(200)
+ self.end_headers()
+ doc_text = read_json(post_data)
+ # try:
+ response = annotate(doc_text)
+
+ self.wfile.write(bytes(json.dumps(response), "utf-8"))
+ return
+
+
+def read_json(post_data):
+ data = json.loads(post_data.decode("utf-8"))
+ # logger.info("received data:", data)
+ text = data["text"]
+ # spans = [(int(j["start"]), int(j["length"])) for j in data["spans"]]
+ return text
+
+
+def parse_args() -> argparse.Namespace:
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--relik-model-name", required=True)
+ parser.add_argument("--responses-log-dir")
+ parser.add_argument("--log-file", default="logs/logging.txt")
+ parser.add_argument("--mapping-file")
+ return parser.parse_args()
+
+
+def main():
+ args = parse_args()
+
+ # init manager
+ manager.response_logger_dir = args.responses_log_dir
+ # manager.annotator = Relik.from_pretrained(args.relik_model_name)
+
+ print("Debugging, not using you relik model but an hardcoded one.")
+ manager.annotator = Relik(
+ question_encoder="riccorl/relik-retriever-aida-blink-pretrain-omniencoder",
+ document_index="riccorl/index-relik-retriever-aida-blink-pretrain-omniencoder",
+ reader="relik/reader/models/relik-reader-deberta-base-new-data",
+ window_size=32,
+ window_stride=16,
+ candidates_preprocessing_fn=(lambda x: x.split("<def>")[0].strip()),
+ )
+
+ if args.mapping_file is not None:
+ manager.set_mapping_file(args.mapping_file)
+
+ port = 6654
+ server = HTTPServer(("localhost", port), GetHandler)
+ logger.info(f"Starting server at http://localhost:{port}")
+
+ # Create a file handler and set its level
+ file_handler = logging.FileHandler(args.log_file)
+ file_handler.setLevel(logging.DEBUG)
+
+ # Create a log formatter and set it on the handler
+ formatter = logging.Formatter(
+ "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+ )
+ file_handler.setFormatter(formatter)
+
+ # Add the file handler to the logger
+ logger.addHandler(file_handler)
+
+ try:
+ server.serve_forever()
+ except KeyboardInterrupt:
+ exit(0)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/relik/inference/preprocessing.py b/relik/inference/preprocessing.py
new file mode 100644
index 0000000000000000000000000000000000000000..2476fe47ea64d907a8c32c31082253c45b48720c
--- /dev/null
+++ b/relik/inference/preprocessing.py
@@ -0,0 +1,4 @@
+def wikipedia_title_and_openings_preprocessing(
+ wikipedia_title_and_openings: str, sepator: str = " <def>"
+):
+ return wikipedia_title_and_openings.split(sepator, 1)[0]
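+
+
+if __name__ == "__main__":
+    # Illustrative call, assuming candidates are formatted as "<title> <def> <opening>":
+    print(wikipedia_title_and_openings_preprocessing("Rome <def> Rome is the capital of Italy."))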
diff --git a/relik/inference/serve/__init__.py b/relik/inference/serve/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/relik/inference/serve/backend/__init__.py b/relik/inference/serve/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/relik/inference/serve/backend/relik.py b/relik/inference/serve/backend/relik.py
new file mode 100644
index 0000000000000000000000000000000000000000..038e2ef78afbccb0758162996e35cd8dc858d453
--- /dev/null
+++ b/relik/inference/serve/backend/relik.py
@@ -0,0 +1,210 @@
+import logging
+from pathlib import Path
+from typing import List, Optional, Union
+
+from relik.common.utils import is_package_available
+from relik.inference.annotator import Relik
+
+if not is_package_available("fastapi"):
+ raise ImportError(
+ "FastAPI is not installed. Please install FastAPI with `pip install relik[serve]`."
+ )
+from fastapi import FastAPI, HTTPException
+
+if not is_package_available("ray"):
+ raise ImportError(
+ "Ray is not installed. Please install Ray with `pip install relik[serve]`."
+ )
+from ray import serve
+
+from relik.common.log import get_logger
+from relik.inference.serve.backend.utils import (
+ RayParameterManager,
+ ServerParameterManager,
+)
+from relik.retriever.data.utils import batch_generator
+
+logger = get_logger(__name__, level=logging.INFO)
+
+VERSION = {} # type: ignore
+with open(
+ Path(__file__).parent.parent.parent.parent / "version.py", "r"
+) as version_file:
+ exec(version_file.read(), VERSION)
+
+# Env variables for server
+SERVER_MANAGER = ServerParameterManager()
+RAY_MANAGER = RayParameterManager()
+
+app = FastAPI(
+ title="ReLiK",
+ version=VERSION["VERSION"],
+ description="ReLiK REST API",
+)
+
+
+@serve.deployment(
+ ray_actor_options={
+ "num_gpus": RAY_MANAGER.num_gpus
+ if (
+ SERVER_MANAGER.retriver_device == "cuda"
+ or SERVER_MANAGER.reader_device == "cuda"
+ )
+ else 0
+ },
+ autoscaling_config={
+ "min_replicas": RAY_MANAGER.min_replicas,
+ "max_replicas": RAY_MANAGER.max_replicas,
+ },
+)
+@serve.ingress(app)
+class RelikServer:
+ def __init__(
+ self,
+ question_encoder: str,
+ document_index: str,
+ passage_encoder: Optional[str] = None,
+ reader_encoder: Optional[str] = None,
+ top_k: int = 100,
+ retriver_device: str = "cpu",
+ reader_device: str = "cpu",
+ index_device: Optional[str] = None,
+ precision: int = 32,
+ index_precision: Optional[int] = None,
+ use_faiss: bool = False,
+ window_batch_size: int = 32,
+ window_size: int = 32,
+ window_stride: int = 16,
+ split_on_spaces: bool = False,
+ ):
+ # parameters
+ self.question_encoder = question_encoder
+ self.passage_encoder = passage_encoder
+ self.reader_encoder = reader_encoder
+ self.document_index = document_index
+ self.top_k = top_k
+ self.retriver_device = retriver_device
+ self.index_device = index_device or retriver_device
+ self.reader_device = reader_device
+ self.precision = precision
+ self.index_precision = index_precision or precision
+ self.use_faiss = use_faiss
+ self.window_batch_size = window_batch_size
+ self.window_size = window_size
+ self.window_stride = window_stride
+ self.split_on_spaces = split_on_spaces
+
+ # log stuff for debugging
+ logger.info("Initializing RelikServer with parameters:")
+ logger.info(f"QUESTION_ENCODER: {self.question_encoder}")
+ logger.info(f"PASSAGE_ENCODER: {self.passage_encoder}")
+ logger.info(f"READER_ENCODER: {self.reader_encoder}")
+ logger.info(f"DOCUMENT_INDEX: {self.document_index}")
+ logger.info(f"TOP_K: {self.top_k}")
+ logger.info(f"RETRIEVER_DEVICE: {self.retriver_device}")
+ logger.info(f"READER_DEVICE: {self.reader_device}")
+ logger.info(f"INDEX_DEVICE: {self.index_device}")
+ logger.info(f"PRECISION: {self.precision}")
+ logger.info(f"INDEX_PRECISION: {self.index_precision}")
+ logger.info(f"WINDOW_BATCH_SIZE: {self.window_batch_size}")
+ logger.info(f"SPLIT_ON_SPACES: {self.split_on_spaces}")
+
+ self.relik = Relik(
+ question_encoder=self.question_encoder,
+ passage_encoder=self.passage_encoder,
+ document_index=self.document_index,
+ reader=self.reader_encoder,
+ retriever_device=self.retriver_device,
+ document_index_device=self.index_device,
+ reader_device=self.reader_device,
+ retriever_precision=self.precision,
+ document_index_precision=self.index_precision,
+ reader_precision=self.precision,
+ )
+
+ # @serve.batch()
+ async def handle_batch(self, documents: List[str]) -> List:
+ return self.relik(
+ documents,
+ top_k=self.top_k,
+ window_size=self.window_size,
+ window_stride=self.window_stride,
+ batch_size=self.window_batch_size,
+ )
+
+ @app.post("/api/entities")
+ async def entities_endpoint(
+ self,
+ documents: Union[str, List[str]],
+ ):
+ try:
+ # normalize input
+ if isinstance(documents, str):
+ documents = [documents]
+ # get predictions
+ return await self.handle_batch(documents)
+ except Exception as e:
+ # log the entire stack trace
+ logger.exception(e)
+ raise HTTPException(status_code=500, detail=f"Server Error: {e}")
+
+ @app.post("/api/gerbil")
+ async def gerbil_endpoint(self, documents: Union[str, List[str]]):
+ try:
+ # normalize input
+ if isinstance(documents, str):
+ documents = [documents]
+
+ # output list
+ windows_passages = []
+ # split documents into windows
+ document_windows = [
+ window
+ for doc_id, document in enumerate(documents)
+ for window in self.window_manager(
+ self.tokenizer,
+ document,
+ window_size=self.window_size,
+ stride=self.window_stride,
+ doc_id=doc_id,
+ )
+ ]
+
+ # get text and topic from document windows and create new list
+ model_inputs = [
+ (window.text, window.doc_topic) for window in document_windows
+ ]
+
+ # batch generator
+ for batch in batch_generator(
+ model_inputs, batch_size=self.window_batch_size
+ ):
+ text, text_pair = zip(*batch)
+ batch_predictions = await self.handle_batch_retriever(text, text_pair)
+ windows_passages.extend(
+ [
+ [p.label for p in predictions]
+ for predictions in batch_predictions
+ ]
+ )
+
+ # add passage to document windows
+ for window, passages in zip(document_windows, windows_passages):
+ # clean up passages (remove everything after first tag if present)
+ passages = [c.split(" ", 1)[0] for c in passages]
+ window.window_candidates = passages
+
+ # return document windows
+ return document_windows
+
+ except Exception as e:
+ # log the entire stack trace
+ logger.exception(e)
+ raise HTTPException(status_code=500, detail=f"Server Error: {e}")
+
+
+server = RelikServer.bind(**vars(SERVER_MANAGER))
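+
+# Example launch (an illustrative sketch, assuming Ray Serve 2.x and that the
+# environment variables read by ServerParameterManager are set):
+#   QUESTION_ENCODER=<question-encoder> DOCUMENT_INDEX=<index> READER_ENCODER=<reader> \
+#   serve run relik.inference.serve.backend.relik:server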
diff --git a/relik/inference/serve/backend/retriever.py b/relik/inference/serve/backend/retriever.py
new file mode 100644
index 0000000000000000000000000000000000000000..e796893e76b83377a5f8b2c7afdccce21756dcbd
--- /dev/null
+++ b/relik/inference/serve/backend/retriever.py
@@ -0,0 +1,206 @@
+import logging
+from pathlib import Path
+from typing import List, Optional, Union
+
+from relik.common.utils import is_package_available
+
+if not is_package_available("fastapi"):
+ raise ImportError(
+ "FastAPI is not installed. Please install FastAPI with `pip install relik[serve]`."
+ )
+from fastapi import FastAPI, HTTPException
+
+if not is_package_available("ray"):
+ raise ImportError(
+ "Ray is not installed. Please install Ray with `pip install relik[serve]`."
+ )
+from ray import serve
+
+from relik.common.log import get_logger
+from relik.inference.data.tokenizers import SpacyTokenizer, WhitespaceTokenizer
+from relik.inference.data.window.manager import WindowManager
+from relik.inference.serve.backend.utils import (
+ RayParameterManager,
+ ServerParameterManager,
+)
+from relik.retriever.data.utils import batch_generator
+from relik.retriever.pytorch_modules import GoldenRetriever
+
+logger = get_logger(__name__, level=logging.INFO)
+
+VERSION = {} # type: ignore
+with open(
+ Path(__file__).parent.parent.parent.parent / "version.py", "r"
+) as version_file:
+ exec(version_file.read(), VERSION)
+
+# Env variables for server
+SERVER_MANAGER = ServerParameterManager()
+RAY_MANAGER = RayParameterManager()
+
+app = FastAPI(
+ title="Golden Retriever",
+ version=VERSION["VERSION"],
+ description="Golden Retriever REST API",
+)
+
+
+@serve.deployment(
+ ray_actor_options={
+ "num_gpus": RAY_MANAGER.num_gpus if SERVER_MANAGER.device == "cuda" else 0
+ },
+ autoscaling_config={
+ "min_replicas": RAY_MANAGER.min_replicas,
+ "max_replicas": RAY_MANAGER.max_replicas,
+ },
+)
+@serve.ingress(app)
+class GoldenRetrieverServer:
+ def __init__(
+ self,
+ question_encoder: str,
+ document_index: str,
+ passage_encoder: Optional[str] = None,
+ top_k: int = 100,
+ device: str = "cpu",
+ index_device: Optional[str] = None,
+ precision: int = 32,
+ index_precision: Optional[int] = None,
+ use_faiss: bool = False,
+ window_batch_size: int = 32,
+ window_size: int = 32,
+ window_stride: int = 16,
+ split_on_spaces: bool = False,
+ ):
+ # parameters
+ self.question_encoder = question_encoder
+ self.passage_encoder = passage_encoder
+ self.document_index = document_index
+ self.top_k = top_k
+ self.device = device
+ self.index_device = index_device or device
+ self.precision = precision
+ self.index_precision = index_precision or precision
+ self.use_faiss = use_faiss
+ self.window_batch_size = window_batch_size
+ self.window_size = window_size
+ self.window_stride = window_stride
+ self.split_on_spaces = split_on_spaces
+
+ # log stuff for debugging
+ logger.info("Initializing GoldenRetrieverServer with parameters:")
+ logger.info(f"QUESTION_ENCODER: {self.question_encoder}")
+ logger.info(f"PASSAGE_ENCODER: {self.passage_encoder}")
+ logger.info(f"DOCUMENT_INDEX: {self.document_index}")
+ logger.info(f"TOP_K: {self.top_k}")
+ logger.info(f"DEVICE: {self.device}")
+ logger.info(f"INDEX_DEVICE: {self.index_device}")
+ logger.info(f"PRECISION: {self.precision}")
+ logger.info(f"INDEX_PRECISION: {self.index_precision}")
+ logger.info(f"WINDOW_BATCH_SIZE: {self.window_batch_size}")
+ logger.info(f"SPLIT_ON_SPACES: {self.split_on_spaces}")
+
+ self.retriever = GoldenRetriever(
+ question_encoder=self.question_encoder,
+ passage_encoder=self.passage_encoder,
+ document_index=self.document_index,
+ device=self.device,
+ index_device=self.index_device,
+ index_precision=self.index_precision,
+ )
+ self.retriever.eval()
+
+ if self.split_on_spaces:
+ logger.info("Using WhitespaceTokenizer")
+ self.tokenizer = WhitespaceTokenizer()
+ # logger.info("Using RegexTokenizer")
+ # self.tokenizer = RegexTokenizer()
+ else:
+ logger.info("Using SpacyTokenizer")
+ self.tokenizer = SpacyTokenizer(language="en")
+
+ self.window_manager = WindowManager(tokenizer=self.tokenizer)
+
+ # @serve.batch()
+ async def handle_batch(
+ self, documents: List[str], document_topics: List[str]
+ ) -> List:
+ return self.retriever.retrieve(
+ documents, text_pair=document_topics, k=self.top_k, precision=self.precision
+ )
+
+ @app.post("/api/retrieve")
+ async def retrieve_endpoint(
+ self,
+ documents: Union[str, List[str]],
+ document_topics: Optional[Union[str, List[str]]] = None,
+ ):
+ try:
+ # normalize input
+ if isinstance(documents, str):
+ documents = [documents]
+ if document_topics is not None:
+ if isinstance(document_topics, str):
+ document_topics = [document_topics]
+ assert len(documents) == len(document_topics)
+ # get predictions
+ return await self.handle_batch(documents, document_topics)
+ except Exception as e:
+ # log the entire stack trace
+ logger.exception(e)
+ raise HTTPException(status_code=500, detail=f"Server Error: {e}")
+
+ @app.post("/api/gerbil")
+ async def gerbil_endpoint(self, documents: Union[str, List[str]]):
+ try:
+ # normalize input
+ if isinstance(documents, str):
+ documents = [documents]
+
+ # output list
+ windows_passages = []
+ # split documents into windows
+ document_windows = [
+ window
+ for doc_id, document in enumerate(documents)
+ for window in self.window_manager.create_windows(
+ document,
+ window_size=self.window_size,
+ stride=self.window_stride,
+ doc_id=doc_id,
+ )
+ ]
+
+ # get text and topic from document windows and create new list
+ model_inputs = [
+ (window.text, window.doc_topic) for window in document_windows
+ ]
+
+ # batch generator
+ for batch in batch_generator(
+ model_inputs, batch_size=self.window_batch_size
+ ):
+ text, text_pair = zip(*batch)
+ batch_predictions = await self.handle_batch(text, text_pair)
+ windows_passages.extend(
+ [
+ [p.label for p in predictions]
+ for predictions in batch_predictions
+ ]
+ )
+
+ # add passage to document windows
+ for window, passages in zip(document_windows, windows_passages):
+ # clean up passages (remove everything after first tag if present)
+ passages = [c.split(" ", 1)[0] for c in passages]
+ window.window_candidates = passages
+
+ # return document windows
+ return document_windows
+
+ except Exception as e:
+ # log the entire stack trace
+ logger.exception(e)
+ raise HTTPException(status_code=500, detail=f"Server Error: {e}")
+
+
+server = GoldenRetrieverServer.bind(**vars(SERVER_MANAGER))
diff --git a/relik/inference/serve/backend/utils.py b/relik/inference/serve/backend/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..bdf869c1ece0e260355526ee5fcc2f00da7ef887
--- /dev/null
+++ b/relik/inference/serve/backend/utils.py
@@ -0,0 +1,29 @@
+import os
+from dataclasses import dataclass
+from typing import Union
+
+
+@dataclass
+class ServerParameterManager:
+ retriver_device: str = os.environ.get("RETRIEVER_DEVICE", "cpu")
+ reader_device: str = os.environ.get("READER_DEVICE", "cpu")
+ index_device: str = os.environ.get("INDEX_DEVICE", retriver_device)
+ precision: Union[str, int] = os.environ.get("PRECISION", "fp32")
+ index_precision: Union[str, int] = os.environ.get("INDEX_PRECISION", precision)
+ question_encoder: str = os.environ.get("QUESTION_ENCODER", None)
+ passage_encoder: str = os.environ.get("PASSAGE_ENCODER", None)
+ document_index: str = os.environ.get("DOCUMENT_INDEX", None)
+ reader_encoder: str = os.environ.get("READER_ENCODER", None)
+ top_k: int = int(os.environ.get("TOP_K", 100))
+ use_faiss: bool = os.environ.get("USE_FAISS", False)
+ window_batch_size: int = int(os.environ.get("WINDOW_BATCH_SIZE", 32))
+ window_size: int = int(os.environ.get("WINDOW_SIZE", 32))
+ window_stride: int = int(os.environ.get("WINDOW_SIZE", 16))
+ split_on_spaces: bool = os.environ.get("SPLIT_ON_SPACES", False)
+
+
+class RayParameterManager:
+ def __init__(self) -> None:
+ self.num_gpus = int(os.environ.get("NUM_GPUS", 1))
+ self.min_replicas = int(os.environ.get("MIN_REPLICAS", 1))
+ self.max_replicas = int(os.environ.get("MAX_REPLICAS", 1))
diff --git a/relik/inference/serve/frontend/__init__.py b/relik/inference/serve/frontend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/relik/inference/serve/frontend/relik.py b/relik/inference/serve/frontend/relik.py
new file mode 100644
index 0000000000000000000000000000000000000000..5dd8bb4eb2d8c6b056c61bda359959010e688635
--- /dev/null
+++ b/relik/inference/serve/frontend/relik.py
@@ -0,0 +1,231 @@
+import os
+import re
+import time
+from pathlib import Path
+
+import requests
+import streamlit as st
+from spacy import displacy
+from streamlit_extras.badges import badge
+from streamlit_extras.stylable_container import stylable_container
+
+RELIK = os.getenv("RELIK", "localhost:8000/api/entities")
+
+import random
+
+
+def get_random_color(ents):
+ colors = {}
+ random_colors = generate_pastel_colors(len(ents))
+ for ent in ents:
+ colors[ent] = random_colors.pop(random.randint(0, len(random_colors) - 1))
+ return colors
+
+
+def floatrange(start, stop, steps):
+ if int(steps) == 1:
+ return [stop]
+ return [
+ start + float(i) * (stop - start) / (float(steps) - 1) for i in range(steps)
+ ]
+
+
+def hsl_to_rgb(h, s, l):
+ def hue_2_rgb(v1, v2, v_h):
+ while v_h < 0.0:
+ v_h += 1.0
+ while v_h > 1.0:
+ v_h -= 1.0
+ if 6 * v_h < 1.0:
+ return v1 + (v2 - v1) * 6.0 * v_h
+ if 2 * v_h < 1.0:
+ return v2
+ if 3 * v_h < 2.0:
+ return v1 + (v2 - v1) * ((2.0 / 3.0) - v_h) * 6.0
+ return v1
+
+ # if not (0 <= s <= 1): raise ValueError, "s (saturation) parameter must be between 0 and 1."
+ # if not (0 <= l <= 1): raise ValueError, "l (lightness) parameter must be between 0 and 1."
+
+ r, b, g = (l * 255,) * 3
+ if s != 0.0:
+ if l < 0.5:
+ var_2 = l * (1.0 + s)
+ else:
+ var_2 = (l + s) - (s * l)
+ var_1 = 2.0 * l - var_2
+ r = 255 * hue_2_rgb(var_1, var_2, h + (1.0 / 3.0))
+ g = 255 * hue_2_rgb(var_1, var_2, h)
+ b = 255 * hue_2_rgb(var_1, var_2, h - (1.0 / 3.0))
+
+ return int(round(r)), int(round(g)), int(round(b))
+
+
+def generate_pastel_colors(n):
+ """Return different pastel colours.
+
+ Input:
+ n (integer) : The number of colors to return
+
+ Output:
+ A list of colors in HTML notation (e.g. ['#cce0ff', '#ffcccc', '#ccffe0', '#f5ccff', '#f5ffcc'])
+
+ Example:
+ >>> print(generate_pastel_colors(5))
+ ['#cce0ff', '#f5ccff', '#ffcccc', '#f5ffcc', '#ccffe0']
+ """
+ if n == 0:
+ return []
+
+ # To generate colors, we use the HSL colorspace (see http://en.wikipedia.org/wiki/HSL_color_space)
+ start_hue = 0.6 # 0=red 1/3=0.333=green 2/3=0.666=blue
+ saturation = 1.0
+ lightness = 0.8
+ # We take points around the chromatic circle (hue):
+ # (Note: we generate n+1 colors, then drop the last one ([:-1]) because
+ # it equals the first one (hue 0 = hue 1))
+ return [
+ "#%02x%02x%02x" % hsl_to_rgb(hue, saturation, lightness)
+ for hue in floatrange(start_hue, start_hue + 1, n + 1)
+ ][:-1]
+
+
+def set_sidebar(css):
+ white_link_wrapper = "{}"
+ with st.sidebar:
+ st.markdown(f"", unsafe_allow_html=True)
+ st.image(
+ "http://nlp.uniroma1.it/static/website/sapienza-nlp-logo-wh.svg",
+ use_column_width=True,
+ )
+ st.markdown("## ReLiK")
+ st.write(
+ f"""
+ - {white_link_wrapper.format("#", " Paper")}
+ - {white_link_wrapper.format("https://github.com/SapienzaNLP/relik", " GitHub")}
+ - {white_link_wrapper.format("https://hub.docker.com/repository/docker/sapienzanlp/relik", " Docker Hub")}
+ """,
+ unsafe_allow_html=True,
+ )
+ st.markdown("## Sapienza NLP")
+ st.write(
+ f"""
+ - {white_link_wrapper.format("https://nlp.uniroma1.it", " Webpage")}
+ - {white_link_wrapper.format("https://github.com/SapienzaNLP", " GitHub")}
+ - {white_link_wrapper.format("https://twitter.com/SapienzaNLP", " Twitter")}
+ - {white_link_wrapper.format("https://www.linkedin.com/company/79434450", " LinkedIn")}
+ """,
+ unsafe_allow_html=True,
+ )
+
+
+def get_el_annotations(response):
+ # swap labels key with ents
+ response["ents"] = response.pop("labels")
+ label_in_text = set(l["label"] for l in response["ents"])
+ options = {"ents": label_in_text, "colors": get_random_color(label_in_text)}
+ return response, options
+
+
+def set_intro(css):
+ # intro
+ st.markdown("# ReLik")
+ st.markdown(
+ "### Retrieve, Read and LinK: Fast and Accurate Entity Linking and Relation Extraction on an Academic Budget"
+ )
+ # st.markdown(
+ # "This is a front-end for the paper [Universal Semantic Annotator: the First Unified API "
+ # "for WSD, SRL and Semantic Parsing](https://www.researchgate.net/publication/360671045_Universal_Semantic_Annotator_the_First_Unified_API_for_WSD_SRL_and_Semantic_Parsing), which will be presented at LREC 2022 by "
+ # "[Riccardo Orlando](https://riccorl.github.io), [Simone Conia](https://c-simone.github.io/), "
+ # "[Stefano Faralli](https://corsidilaurea.uniroma1.it/it/users/stefanofaralliuniroma1it), and [Roberto Navigli](https://www.diag.uniroma1.it/navigli/)."
+ # )
+ badge(type="github", name="sapienzanlp/relik")
+ badge(type="pypi", name="relik")
+
+
+def run_client():
+ with open(Path(__file__).parent / "style.css") as f:
+ css = f.read()
+
+ st.set_page_config(
+ page_title="ReLik",
+ page_icon="🦮",
+ layout="wide",
+ )
+ set_sidebar(css)
+ set_intro(css)
+
+ # text input
+ text = st.text_area(
+ "Enter Text Below:",
+ value="Obama went to Rome for a quick vacation.",
+ height=200,
+ max_chars=500,
+ )
+
+ with stylable_container(
+ key="annotate_button",
+ css_styles="""
+ button {
+ background-color: #802433;
+ color: white;
+ border-radius: 25px;
+ }
+ """,
+ ):
+ submit = st.button("Annotate")
+ # submit = st.button("Run")
+
+ # ReLik API call
+ if submit:
+ text = text.strip()
+ if text:
+ st.markdown("####")
+ st.markdown("#### Entity Linking")
+ with st.spinner(text="In progress"):
+ response = requests.post(RELIK, json=text)
+ if response.status_code != 200:
+ st.error("Error: {}".format(response.status_code))
+ else:
+ response = response.json()
+
+ # Entity Linking
+ # with stylable_container(
+ # key="container_with_border",
+ # css_styles="""
+ # {
+ # border: 1px solid rgba(49, 51, 63, 0.2);
+ # border-radius: 0.5rem;
+ # padding: 0.5rem;
+ # padding-bottom: 2rem;
+ # }
+ # """,
+ # ):
+ # st.markdown("##")
+ dict_of_ents, options = get_el_annotations(response=response)
+ display = displacy.render(
+ dict_of_ents, manual=True, style="ent", options=options
+ )
+ display = display.replace("\n", " ")
+ # wsd_display = re.sub(
+ # r"(wiki::\d+\w)",
+ # r"\g<1>".format(
+ # language.upper()
+ # ),
+ # wsd_display,
+ # )
+ with st.container():
+ st.write(display, unsafe_allow_html=True)
+
+ st.markdown("####")
+ st.markdown("#### Relation Extraction")
+
+ with st.container():
+                        st.write("Coming soon :)", unsafe_allow_html=True)
+
+ else:
+ st.error("Please enter some text.")
+
+
+if __name__ == "__main__":
+ run_client()
diff --git a/relik/inference/serve/frontend/style.css b/relik/inference/serve/frontend/style.css
new file mode 100644
index 0000000000000000000000000000000000000000..31f0d182cfd9b2636d5db5cbd0e7a1339ed5d1c3
--- /dev/null
+++ b/relik/inference/serve/frontend/style.css
@@ -0,0 +1,33 @@
+/* Sidebar */
+.eczjsme11 {
+ background-color: #802433;
+}
+
+.st-emotion-cache-10oheav h2 {
+ color: white;
+}
+
+.st-emotion-cache-10oheav li {
+ color: white;
+}
+
+/* Main */
+a:link {
+ text-decoration: none;
+ color: white;
+}
+
+a:visited {
+ text-decoration: none;
+ color: white;
+}
+
+a:hover {
+ text-decoration: none;
+ color: rgba(255, 255, 255, 0.871);
+}
+
+a:active {
+ text-decoration: none;
+ color: white;
+}
\ No newline at end of file
diff --git a/relik/reader/__init__.py b/relik/reader/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/relik/reader/conf/config.yaml b/relik/reader/conf/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..05b4e524d060aed56b930d1f578424b986792975
--- /dev/null
+++ b/relik/reader/conf/config.yaml
@@ -0,0 +1,14 @@
+# Required to make the "experiments" dir the default one for the output of the models
+hydra:
+ run:
+ dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+model_name: relik-reader-deberta-base # used to name the model in wandb and output dir
+project_name: relik-reader # used to name the project in wandb
+
+
+defaults:
+ - _self_
+ - training: base
+ - model: base
+ - data: base
diff --git a/relik/reader/conf/data/base.yaml b/relik/reader/conf/data/base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1964d8750bed1a162a9ec15d20be1708ccad9914
--- /dev/null
+++ b/relik/reader/conf/data/base.yaml
@@ -0,0 +1,21 @@
+train_dataset_path: "relik/reader/data/train.jsonl"
+val_dataset_path: "relik/reader/data/testa.jsonl"
+
+train_dataset:
+ _target_: "relik.reader.relik_reader_data.RelikDataset"
+ transformer_model: "${model.model.transformer_model}"
+ materialize_samples: False
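+  # shuffle_candidates: a float is the per-sample probability of shuffling the candidate
+  # order at training time; a bool switches shuffling on/off unconditionally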
+ shuffle_candidates: 0.5
+ random_drop_gold_candidates: 0.05
+ noise_param: 0.0
+ for_inference: False
+ tokens_per_batch: 4096
+ special_symbols: null
+
+val_dataset:
+ _target_: "relik.reader.relik_reader_data.RelikDataset"
+ transformer_model: "${model.model.transformer_model}"
+ materialize_samples: False
+ shuffle_candidates: False
+ for_inference: True
+ special_symbols: null
diff --git a/relik/reader/conf/data/re.yaml b/relik/reader/conf/data/re.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..17c18ee886021bc0157edb156020409fdd799fbc
--- /dev/null
+++ b/relik/reader/conf/data/re.yaml
@@ -0,0 +1,54 @@
+train_dataset_path: "relik/reader/data/nyt-alby+/train.jsonl"
+val_dataset_path: "relik/reader/data/nyt-alby+/valid.jsonl"
+test_dataset_path: "relik/reader/data/nyt-alby+/test.jsonl"
+
+relations_definitions:
+ /people/person/nationality: "nationality"
+ /sports/sports_team/location: "sports team location"
+ /location/country/administrative_divisions: "administrative divisions"
+ /business/company/major_shareholders: "shareholders"
+ /people/ethnicity/people: "ethnicity"
+  /people/ethnicity/geographic_distribution: "geographic distribution"
+ /business/company_shareholder/major_shareholder_of: "major shareholder"
+ /location/location/contains: "location"
+ /business/company/founders: "founders"
+ /business/person/company: "company"
+ /business/company/advisors: "advisor"
+ /people/deceased_person/place_of_death: "place of death"
+ /business/company/industry: "industry"
+ /people/person/ethnicity: "ethnic background"
+ /people/person/place_of_birth: "place of birth"
+  /location/administrative_division/country: "country of an administrative division"
+ /people/person/place_lived: "place lived"
+ /sports/sports_team_location/teams: "sports team"
+ /people/person/children: "child"
+ /people/person/religion: "religion"
+ /location/neighborhood/neighborhood_of: "neighborhood"
+ /location/country/capital: "capital"
+ /business/company/place_founded: "company founded location"
+ /people/person/profession: "occupation"
+
+train_dataset:
+ _target_: "relik.reader.relik_reader_re_data.RelikREDataset"
+ transformer_model: "${model.model.transformer_model}"
+ materialize_samples: False
+ shuffle_candidates: False
+ flip_candidates: 1.0
+ noise_param: 0.0
+ for_inference: False
+ tokens_per_batch: 4096
+ min_length: -1
+ special_symbols: null
+ relations_definitions: ${data.relations_definitions}
+ sorting_fields:
+ - "predictable_candidates"
+val_dataset:
+ _target_: "relik.reader.relik_reader_re_data.RelikREDataset"
+ transformer_model: "${model.model.transformer_model}"
+ materialize_samples: False
+ shuffle_candidates: False
+ flip_candidates: False
+ for_inference: True
+ min_length: -1
+ special_symbols: null
+ relations_definitions: ${data.relations_definitions}
diff --git a/relik/reader/conf/training/base.yaml b/relik/reader/conf/training/base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8e366a96408bd0f8ff5184849e53d19bb477af38
--- /dev/null
+++ b/relik/reader/conf/training/base.yaml
@@ -0,0 +1,12 @@
+seed: 94
+
+trainer:
+ _target_: lightning.Trainer
+ devices:
+ - 0
+ precision: "16-mixed"
+ max_steps: 50000
+ val_check_interval: 1.0
+ num_sanity_val_steps: 0
+ limit_val_batches: 1
+ gradient_clip_val: 1.0
diff --git a/relik/reader/conf/training/re.yaml b/relik/reader/conf/training/re.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8701ae3fca48830649022644a743783a1016bd5b
--- /dev/null
+++ b/relik/reader/conf/training/re.yaml
@@ -0,0 +1,12 @@
+seed: 15
+
+trainer:
+ _target_: lightning.Trainer
+ devices:
+ - 0
+ precision: "16-mixed"
+ max_steps: 100000
+ val_check_interval: 1.0
+ num_sanity_val_steps: 0
+ limit_val_batches: 1
+ gradient_clip_val: 1.0
\ No newline at end of file
diff --git a/relik/reader/data/__init__.py b/relik/reader/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/relik/reader/data/patches.py b/relik/reader/data/patches.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0d03dbdf08d0e205787ce2b8176c6bd47d2dfca
--- /dev/null
+++ b/relik/reader/data/patches.py
@@ -0,0 +1,51 @@
+from typing import List
+
+from relik.reader.data.relik_reader_sample import RelikReaderSample
+from relik.reader.utils.special_symbols import NME_SYMBOL
+
+
+def merge_patches_predictions(sample) -> None:
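+    # Merge the predictions produced for the different "patches" of a window (the sub-inputs
+    # created when the candidate set does not fit in a single forward pass): earlier patches
+    # win, except that an NME prediction can be overridden by a concrete title later on.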
+ sample._d["predicted_window_labels"] = dict()
+ predicted_window_labels = sample._d["predicted_window_labels"]
+
+ sample._d["span_title_probabilities"] = dict()
+ span_title_probabilities = sample._d["span_title_probabilities"]
+
+ span2title = dict()
+ for _, patch_info in sorted(sample.patches.items(), key=lambda x: x[0]):
+ # selecting span predictions
+ for predicted_title, predicted_spans in patch_info[
+ "predicted_window_labels"
+ ].items():
+ for pred_span in predicted_spans:
+ pred_span = tuple(pred_span)
+ curr_title = span2title.get(pred_span)
+ if curr_title is None or curr_title == NME_SYMBOL:
+ span2title[pred_span] = predicted_title
+ # else:
+ # print("Merging at patch level")
+
+ # selecting span predictions probability
+ for predicted_span, titles_probabilities in patch_info[
+ "span_title_probabilities"
+ ].items():
+ if predicted_span not in span_title_probabilities:
+ span_title_probabilities[predicted_span] = titles_probabilities
+
+ for span, title in span2title.items():
+ if title not in predicted_window_labels:
+ predicted_window_labels[title] = list()
+ predicted_window_labels[title].append(span)
+
+
+def remove_duplicate_samples(
+ samples: List[RelikReaderSample],
+) -> List[RelikReaderSample]:
+ seen_sample = set()
+ samples_store = []
+ for sample in samples:
+ sample_id = f"{sample.doc_id}#{sample.sent_id}#{sample.offset}"
+ if sample_id not in seen_sample:
+ seen_sample.add(sample_id)
+ samples_store.append(sample)
+ return samples_store
diff --git a/relik/reader/data/relik_reader_data.py b/relik/reader/data/relik_reader_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c65646f99d37cdcf03ab7005c83eb0069da168c
--- /dev/null
+++ b/relik/reader/data/relik_reader_data.py
@@ -0,0 +1,965 @@
+import logging
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Generator,
+ Iterable,
+ Iterator,
+ List,
+ NamedTuple,
+ Optional,
+ Tuple,
+ Union,
+)
+
+import numpy as np
+import torch
+from torch.utils.data import IterableDataset
+from tqdm import tqdm
+from transformers import AutoTokenizer, PreTrainedTokenizer
+
+from relik.reader.data.relik_reader_data_utils import (
+ add_noise_to_value,
+ batchify,
+ chunks,
+ flatten,
+)
+from relik.reader.data.relik_reader_sample import (
+ RelikReaderSample,
+ load_relik_reader_samples,
+)
+from relik.reader.utils.special_symbols import NME_SYMBOL
+
+logger = logging.getLogger(__name__)
+
+
+def preprocess_dataset(
+ input_dataset: Iterable[dict],
+ transformer_model: str,
+ add_topic: bool,
+) -> Iterable[dict]:
+ tokenizer = AutoTokenizer.from_pretrained(transformer_model)
+ for dataset_elem in tqdm(input_dataset, desc="Preprocessing input dataset"):
+ if len(dataset_elem["tokens"]) == 0:
+ print(
+ f"Dataset element with doc id: {dataset_elem['doc_id']}",
+ f"and offset {dataset_elem['offset']} does not contain any token",
+ "Skipping it",
+ )
+ continue
+
+ new_dataset_elem = dict(
+ doc_id=dataset_elem["doc_id"],
+ offset=dataset_elem["offset"],
+ )
+
+ tokenization_out = tokenizer(
+ dataset_elem["tokens"],
+ return_offsets_mapping=True,
+ add_special_tokens=False,
+ )
+
+ window_tokens = tokenization_out.input_ids
+ window_tokens = flatten(window_tokens)
+
+ offsets_mapping = [
+ [
+ (
+ ss + dataset_elem["token2char_start"][str(i)],
+ se + dataset_elem["token2char_start"][str(i)],
+ )
+ for ss, se in tokenization_out.offset_mapping[i]
+ ]
+ for i in range(len(dataset_elem["tokens"]))
+ ]
+
+ offsets_mapping = flatten(offsets_mapping)
+
+ assert len(offsets_mapping) == len(window_tokens)
+
+ window_tokens = (
+ [tokenizer.cls_token_id] + window_tokens + [tokenizer.sep_token_id]
+ )
+
+ topic_offset = 0
+ if add_topic:
+ topic_tokens = tokenizer(
+ dataset_elem["doc_topic"], add_special_tokens=False
+ ).input_ids
+ topic_offset = len(topic_tokens)
+ new_dataset_elem["topic_tokens"] = topic_offset
+ window_tokens = window_tokens[:1] + topic_tokens + window_tokens[1:]
+
+ new_dataset_elem.update(
+ dict(
+ tokens=window_tokens,
+ token2char_start={
+ str(i): s
+ for i, (s, _) in enumerate(offsets_mapping, start=topic_offset)
+ },
+ token2char_end={
+ str(i): e
+ for i, (_, e) in enumerate(offsets_mapping, start=topic_offset)
+ },
+ window_candidates=dataset_elem["window_candidates"],
+ window_candidates_scores=dataset_elem.get("window_candidates_scores"),
+ )
+ )
+
+ if "window_labels" in dataset_elem:
+ window_labels = [
+ (s, e, l.replace("_", " ")) for s, e, l in dataset_elem["window_labels"]
+ ]
+
+ new_dataset_elem["window_labels"] = window_labels
+
+ if not all(
+ [
+ s in new_dataset_elem["token2char_start"].values()
+ for s, _, _ in new_dataset_elem["window_labels"]
+ ]
+ ):
+ print(
+ "Mismatching token start char mapping with labels",
+ new_dataset_elem["token2char_start"],
+ new_dataset_elem["window_labels"],
+ dataset_elem["tokens"],
+ )
+ continue
+
+ if not all(
+ [
+ e in new_dataset_elem["token2char_end"].values()
+ for _, e, _ in new_dataset_elem["window_labels"]
+ ]
+ ):
+ print(
+ "Mismatching token end char mapping with labels",
+ new_dataset_elem["token2char_end"],
+ new_dataset_elem["window_labels"],
+ dataset_elem["tokens"],
+ )
+ continue
+
+ yield new_dataset_elem
+
+
+def preprocess_sample(
+ relik_sample: RelikReaderSample,
+ tokenizer,
+ lowercase_policy: float,
+ add_topic: bool = False,
+) -> None:
+ if len(relik_sample.tokens) == 0:
+ return
+
+ if lowercase_policy > 0:
+ lc_tokens = np.random.uniform(0, 1, len(relik_sample.tokens)) < lowercase_policy
+ relik_sample.tokens = [
+ t.lower() if lc else t for t, lc in zip(relik_sample.tokens, lc_tokens)
+ ]
+
+ tokenization_out = tokenizer(
+ relik_sample.tokens,
+ return_offsets_mapping=True,
+ add_special_tokens=False,
+ )
+
+ window_tokens = tokenization_out.input_ids
+ window_tokens = flatten(window_tokens)
+
+ offsets_mapping = [
+ [
+ (
+ ss + relik_sample.token2char_start[str(i)],
+ se + relik_sample.token2char_start[str(i)],
+ )
+ for ss, se in tokenization_out.offset_mapping[i]
+ ]
+ for i in range(len(relik_sample.tokens))
+ ]
+
+ offsets_mapping = flatten(offsets_mapping)
+
+ assert len(offsets_mapping) == len(window_tokens)
+
+ window_tokens = [tokenizer.cls_token_id] + window_tokens + [tokenizer.sep_token_id]
+
+ topic_offset = 0
+ if add_topic:
+ topic_tokens = tokenizer(
+ relik_sample.doc_topic, add_special_tokens=False
+ ).input_ids
+ topic_offset = len(topic_tokens)
+ relik_sample.topic_tokens = topic_offset
+ window_tokens = window_tokens[:1] + topic_tokens + window_tokens[1:]
+
+ relik_sample._d.update(
+ dict(
+ tokens=window_tokens,
+ token2char_start={
+ str(i): s
+ for i, (s, _) in enumerate(offsets_mapping, start=topic_offset)
+ },
+ token2char_end={
+ str(i): e
+ for i, (_, e) in enumerate(offsets_mapping, start=topic_offset)
+ },
+ )
+ )
+
+ if "window_labels" in relik_sample._d:
+ relik_sample.window_labels = [
+ (s, e, l.replace("_", " ")) for s, e, l in relik_sample.window_labels
+ ]
+
+
+class TokenizationOutput(NamedTuple):
+ input_ids: torch.Tensor
+ attention_mask: torch.Tensor
+ token_type_ids: torch.Tensor
+ prediction_mask: torch.Tensor
+ special_symbols_mask: torch.Tensor
+
+
+class RelikDataset(IterableDataset):
+ def __init__(
+ self,
+ dataset_path: Optional[str],
+ materialize_samples: bool,
+ transformer_model: Union[str, PreTrainedTokenizer],
+ special_symbols: List[str],
+ shuffle_candidates: Optional[Union[bool, float]] = False,
+ for_inference: bool = False,
+ noise_param: float = 0.1,
+ sorting_fields: Optional[str] = None,
+ tokens_per_batch: int = 2048,
+        batch_size: Optional[int] = None,
+ max_batch_size: int = 128,
+ section_size: int = 50_000,
+ prebatch: bool = True,
+ random_drop_gold_candidates: float = 0.0,
+ use_nme: bool = True,
+        max_subwords_per_candidate: int = 22,
+ mask_by_instances: bool = False,
+ min_length: int = 5,
+ max_length: int = 2048,
+ model_max_length: int = 1000,
+ split_on_cand_overload: bool = True,
+ skip_empty_training_samples: bool = False,
+ drop_last: bool = False,
+ samples: Optional[Iterator[RelikReaderSample]] = None,
+ lowercase_policy: float = 0.0,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.dataset_path = dataset_path
+ self.materialize_samples = materialize_samples
+ self.samples: Optional[List[RelikReaderSample]] = None
+ if self.materialize_samples:
+ self.samples = list()
+
+ if isinstance(transformer_model, str):
+ self.tokenizer = self._build_tokenizer(transformer_model, special_symbols)
+ else:
+ self.tokenizer = transformer_model
+ self.special_symbols = special_symbols
+ self.shuffle_candidates = shuffle_candidates
+ self.for_inference = for_inference
+ self.noise_param = noise_param
+ self.batching_fields = ["input_ids"]
+ self.sorting_fields = (
+ sorting_fields if sorting_fields is not None else self.batching_fields
+ )
+
+ self.tokens_per_batch = tokens_per_batch
+ self.batch_size = batch_size
+ self.max_batch_size = max_batch_size
+ self.section_size = section_size
+ self.prebatch = prebatch
+
+ self.random_drop_gold_candidates = random_drop_gold_candidates
+ self.use_nme = use_nme
+ self.max_subwords_per_candidate = max_subwords_per_candidate
+ self.mask_by_instances = mask_by_instances
+ self.min_length = min_length
+ self.max_length = max_length
+ self.model_max_length = (
+ model_max_length
+ if model_max_length < self.tokenizer.model_max_length
+ else self.tokenizer.model_max_length
+ )
+
+ # retrocompatibility workaround
+ self.transformer_model = (
+ transformer_model
+ if isinstance(transformer_model, str)
+ else transformer_model.name_or_path
+ )
+ self.split_on_cand_overload = split_on_cand_overload
+ self.skip_empty_training_samples = skip_empty_training_samples
+ self.drop_last = drop_last
+ self.lowercase_policy = lowercase_policy
+ self.samples = samples
+
+ def _build_tokenizer(self, transformer_model: str, special_symbols: List[str]):
+ return AutoTokenizer.from_pretrained(
+ transformer_model,
+ additional_special_tokens=[ss for ss in special_symbols],
+ add_prefix_space=True,
+ )
+
+ @property
+ def fields_batcher(self) -> Dict[str, Union[None, Callable[[list], Any]]]:
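+        # Maps every field produced by the dataset to the function used to collate it into a
+        # batch; a value of None means the field is kept as a plain Python list.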
+ fields_batchers = {
+ "input_ids": lambda x: batchify(
+ x, padding_value=self.tokenizer.pad_token_id
+ ),
+ "attention_mask": lambda x: batchify(x, padding_value=0),
+ "token_type_ids": lambda x: batchify(x, padding_value=0),
+ "prediction_mask": lambda x: batchify(x, padding_value=1),
+ "global_attention": lambda x: batchify(x, padding_value=0),
+ "token2word": None,
+ "sample": None,
+ "special_symbols_mask": lambda x: batchify(x, padding_value=False),
+ "start_labels": lambda x: batchify(x, padding_value=-100),
+ "end_labels": lambda x: batchify(x, padding_value=-100),
+ "predictable_candidates_symbols": None,
+ "predictable_candidates": None,
+ "patch_offset": None,
+ "optimus_labels": None,
+ }
+
+ if "roberta" in self.transformer_model:
+ del fields_batchers["token_type_ids"]
+
+ return fields_batchers
+
+ def _build_input_ids(
+ self, sentence_input_ids: List[int], candidates_input_ids: List[List[int]]
+ ) -> List[int]:
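+        # [CLS] <sentence subwords> [SEP] <candidate 1 subwords> ... <candidate N subwords> [SEP]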
+ return (
+ [self.tokenizer.cls_token_id]
+ + sentence_input_ids
+ + [self.tokenizer.sep_token_id]
+ + flatten(candidates_input_ids)
+ + [self.tokenizer.sep_token_id]
+ )
+
+ def _get_special_symbols_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
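+        # The additional candidate symbols occupy the last len(special_symbols) ids of the
+        # vocabulary; the CLS position is flagged as well.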
+ special_symbols_mask = input_ids >= (
+ len(self.tokenizer) - len(self.special_symbols)
+ )
+ special_symbols_mask[0] = True
+ return special_symbols_mask
+
+ def _build_tokenizer_essentials(
+ self, input_ids, original_sequence, sample
+ ) -> TokenizationOutput:
+ input_ids = torch.tensor(input_ids, dtype=torch.long)
+ attention_mask = torch.ones_like(input_ids)
+
+ total_sequence_len = len(input_ids)
+ predictable_sentence_len = len(original_sequence)
+
+ # token type ids
+ token_type_ids = torch.cat(
+ [
+ input_ids.new_zeros(
+ predictable_sentence_len + 2
+ ), # original sentence bpes + CLS and SEP
+ input_ids.new_ones(total_sequence_len - predictable_sentence_len - 2),
+ ]
+ )
+
+ # prediction mask -> boolean on tokens that are predictable
+
+ prediction_mask = torch.tensor(
+ [1]
+ + ([0] * predictable_sentence_len)
+ + ([1] * (total_sequence_len - predictable_sentence_len - 1))
+ )
+
+ # add topic tokens to the prediction mask so that they cannot be predicted
+ # or optimized during training
+ topic_tokens = getattr(sample, "topic_tokens", None)
+ if topic_tokens is not None:
+ prediction_mask[1 : 1 + topic_tokens] = 1
+
+ # If mask by instances is active the prediction mask is applied to everything
+ # that is not indicated as an instance in the training set.
+ if self.mask_by_instances:
+ char_start2token = {
+ cs: int(tok) for tok, cs in sample.token2char_start.items()
+ }
+ char_end2token = {ce: int(tok) for tok, ce in sample.token2char_end.items()}
+ instances_mask = torch.ones_like(prediction_mask)
+ for _, span_info in sample.instance_id2span_data.items():
+ span_info = span_info[0]
+ token_start = char_start2token[span_info[0]] + 1 # +1 for the CLS
+ token_end = char_end2token[span_info[1]] + 1 # +1 for the CLS
+ instances_mask[token_start : token_end + 1] = 0
+
+ prediction_mask += instances_mask
+ prediction_mask[prediction_mask > 1] = 1
+
+ assert len(prediction_mask) == len(input_ids)
+
+ # special symbols mask
+ special_symbols_mask = self._get_special_symbols_mask(input_ids)
+
+ return TokenizationOutput(
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ prediction_mask,
+ special_symbols_mask,
+ )
+
+ def _build_labels(
+ self,
+ sample,
+ tokenization_output: TokenizationOutput,
+ predictable_candidates: List[str],
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
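+        # Build per-subword start/end classification labels: class 0 means "no entity", class
+        # i + 1 is the i-th predictable candidate; positions covered by the prediction mask are
+        # set to -100 so they are ignored by the cross-entropy loss.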
+ start_labels = [0] * len(tokenization_output.input_ids)
+ end_labels = [0] * len(tokenization_output.input_ids)
+
+ char_start2token = {v: int(k) for k, v in sample.token2char_start.items()}
+ char_end2token = {v: int(k) for k, v in sample.token2char_end.items()}
+ for cs, ce, gold_candidate_title in sample.window_labels:
+ if gold_candidate_title not in predictable_candidates:
+ if self.use_nme:
+ gold_candidate_title = NME_SYMBOL
+ else:
+ continue
+ # +1 is to account for the CLS token
+ start_bpe = char_start2token[cs] + 1
+ end_bpe = char_end2token[ce] + 1
+ class_index = predictable_candidates.index(gold_candidate_title)
+ if (
+ start_labels[start_bpe] == 0 and end_labels[end_bpe] == 0
+            ):  # avoid two entities sharing the same start or end subword
+ start_labels[start_bpe] = class_index + 1 # +1 for the NONE class
+ end_labels[end_bpe] = class_index + 1 # +1 for the NONE class
+ else:
+ print(
+ "Found entity with the same last subword, it will not be included."
+ )
+ print(
+ cs,
+ ce,
+ gold_candidate_title,
+ start_labels,
+ end_labels,
+ sample.doc_id,
+ )
+
+ ignored_labels_indices = tokenization_output.prediction_mask == 1
+
+ start_labels = torch.tensor(start_labels, dtype=torch.long)
+ start_labels[ignored_labels_indices] = -100
+
+ end_labels = torch.tensor(end_labels, dtype=torch.long)
+ end_labels[ignored_labels_indices] = -100
+
+ return start_labels, end_labels
+
+ def produce_sample_bag(
+ self, sample, predictable_candidates: List[str], candidates_starting_offset: int
+ ) -> Optional[Tuple[dict, list, int]]:
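+        # Build the model input for one window given the candidates still to be predicted.
+        # Returns the encoded bag, the candidates that did not fit in this forward pass and
+        # the total number of candidate symbols used so far.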
+ # input sentence tokenization
+ input_subwords = sample.tokens[1:-1] # removing special tokens
+ candidates_symbols = self.special_symbols[candidates_starting_offset:]
+
+ predictable_candidates = list(predictable_candidates)
+ original_predictable_candidates = list(predictable_candidates)
+
+ # add NME as a possible candidate
+ if self.use_nme:
+ predictable_candidates.insert(0, NME_SYMBOL)
+
+ # candidates encoding
+ candidates_symbols = candidates_symbols[: len(predictable_candidates)]
+ candidates_encoding_result = self.tokenizer.batch_encode_plus(
+ [
+ "{} {}".format(cs, ct) if ct != NME_SYMBOL else NME_SYMBOL
+ for cs, ct in zip(candidates_symbols, predictable_candidates)
+ ],
+ add_special_tokens=False,
+ ).input_ids
+
+ if (
+ self.max_subwords_per_candidate is not None
+ and self.max_subwords_per_candidate > 0
+ ):
+ candidates_encoding_result = [
+ cer[: self.max_subwords_per_candidate]
+ for cer in candidates_encoding_result
+ ]
+
+ # drop candidates if the number of input tokens is too long for the model
+ if (
+ sum(map(len, candidates_encoding_result))
+ + len(input_subwords)
+ + 20 # + 20 special tokens
+ > self.model_max_length
+ ):
+ acceptable_tokens_from_candidates = (
+ self.model_max_length - 20 - len(input_subwords)
+ )
+ i = 0
+ cum_len = 0
+ while (
+ cum_len + len(candidates_encoding_result[i])
+ < acceptable_tokens_from_candidates
+ ):
+ cum_len += len(candidates_encoding_result[i])
+ i += 1
+
+ candidates_encoding_result = candidates_encoding_result[:i]
+ candidates_symbols = candidates_symbols[:i]
+ predictable_candidates = predictable_candidates[:i]
+
+ # final input_ids build
+ input_ids = self._build_input_ids(
+ sentence_input_ids=input_subwords,
+ candidates_input_ids=candidates_encoding_result,
+ )
+
+ # complete input building (e.g. attention / prediction mask)
+ tokenization_output = self._build_tokenizer_essentials(
+ input_ids, input_subwords, sample
+ )
+
+ output_dict = {
+ "input_ids": tokenization_output.input_ids,
+ "attention_mask": tokenization_output.attention_mask,
+ "token_type_ids": tokenization_output.token_type_ids,
+ "prediction_mask": tokenization_output.prediction_mask,
+ "special_symbols_mask": tokenization_output.special_symbols_mask,
+ "sample": sample,
+ "predictable_candidates_symbols": candidates_symbols,
+ "predictable_candidates": predictable_candidates,
+ }
+
+ # labels creation
+ if sample.window_labels is not None:
+ start_labels, end_labels = self._build_labels(
+ sample,
+ tokenization_output,
+ predictable_candidates,
+ )
+ output_dict.update(start_labels=start_labels, end_labels=end_labels)
+
+ if (
+ "roberta" in self.transformer_model
+ or "longformer" in self.transformer_model
+ ):
+ del output_dict["token_type_ids"]
+
+ predictable_candidates_set = set(predictable_candidates)
+ remaining_candidates = [
+ candidate
+ for candidate in original_predictable_candidates
+ if candidate not in predictable_candidates_set
+ ]
+ total_used_candidates = (
+ candidates_starting_offset
+ + len(predictable_candidates)
+ - (1 if self.use_nme else 0)
+ )
+
+ if self.use_nme:
+ assert predictable_candidates[0] == NME_SYMBOL
+
+ return output_dict, remaining_candidates, total_used_candidates
+
+ def __iter__(self):
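+        # Stream batches: accumulate up to `section_size` preprocessed elements, then sort,
+        # pre-batch and yield them, so sorting never has to materialize the whole dataset.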
+ dataset_iterator = self.dataset_iterator_func()
+
+ current_dataset_elements = []
+
+ i = None
+ for i, dataset_elem in enumerate(dataset_iterator, start=1):
+ if (
+ self.section_size is not None
+ and len(current_dataset_elements) == self.section_size
+ ):
+ for batch in self.materialize_batches(current_dataset_elements):
+ yield batch
+ current_dataset_elements = []
+
+ current_dataset_elements.append(dataset_elem)
+
+ if i % 50_000 == 0:
+ logger.info(f"Processed: {i} number of elements")
+
+ if len(current_dataset_elements) != 0:
+ for batch in self.materialize_batches(current_dataset_elements):
+ yield batch
+
+ if i is not None:
+ logger.info(f"Dataset finished: {i} number of elements processed")
+ else:
+ logger.warning("Dataset empty")
+
+ def dataset_iterator_func(self):
+ skipped_instances = 0
+ data_samples = (
+ load_relik_reader_samples(self.dataset_path)
+ if self.samples is None
+ else self.samples
+ )
+ for sample in data_samples:
+ preprocess_sample(
+ sample, self.tokenizer, lowercase_policy=self.lowercase_policy
+ )
+ current_patch = 0
+ sample_bag, used_candidates = None, None
+ remaining_candidates = list(sample.window_candidates)
+
+ if not self.for_inference:
+ # randomly drop gold candidates at training time
+ if (
+ self.random_drop_gold_candidates > 0.0
+ and np.random.uniform() < self.random_drop_gold_candidates
+ and len(set(ct for _, _, ct in sample.window_labels)) > 1
+ ):
+ # selecting candidates to drop
+ np.random.shuffle(sample.window_labels)
+ n_dropped_candidates = np.random.randint(
+ 0, len(sample.window_labels) - 1
+ )
+ dropped_candidates = [
+ label_elem[-1]
+ for label_elem in sample.window_labels[:n_dropped_candidates]
+ ]
+ dropped_candidates = set(dropped_candidates)
+
+ # saving NMEs because they should not be dropped
+ if NME_SYMBOL in dropped_candidates:
+ dropped_candidates.remove(NME_SYMBOL)
+
+ # sample update
+ sample.window_labels = [
+ (s, e, _l)
+ if _l not in dropped_candidates
+ else (s, e, NME_SYMBOL)
+ for s, e, _l in sample.window_labels
+ ]
+ remaining_candidates = [
+ wc
+ for wc in remaining_candidates
+ if wc not in dropped_candidates
+ ]
+
+ # shuffle candidates
+ if (
+ isinstance(self.shuffle_candidates, bool)
+ and self.shuffle_candidates
+ ) or (
+ isinstance(self.shuffle_candidates, float)
+ and np.random.uniform() < self.shuffle_candidates
+ ):
+ np.random.shuffle(remaining_candidates)
+
+ while len(remaining_candidates) != 0:
+ sample_bag = self.produce_sample_bag(
+ sample,
+ predictable_candidates=remaining_candidates,
+ candidates_starting_offset=used_candidates
+ if used_candidates is not None
+ else 0,
+ )
+ if sample_bag is not None:
+ sample_bag, remaining_candidates, used_candidates = sample_bag
+ if (
+ self.for_inference
+ or not self.skip_empty_training_samples
+ or (
+ (
+ sample_bag.get("start_labels") is not None
+ and torch.any(sample_bag["start_labels"] > 1).item()
+ )
+ or (
+ sample_bag.get("optimus_labels") is not None
+ and len(sample_bag["optimus_labels"]) > 0
+ )
+ )
+ ):
+ sample_bag["patch_offset"] = current_patch
+ current_patch += 1
+ yield sample_bag
+ else:
+ skipped_instances += 1
+ if skipped_instances % 1000 == 0 and skipped_instances != 0:
+ logger.info(
+ f"Skipped {skipped_instances} instances since they did not have any gold labels..."
+ )
+
+                # if splitting on candidate overload is disabled, keep only the
+                # first batch of fitting candidates
+ if not self.split_on_cand_overload:
+ break
+
+ def preshuffle_elements(self, dataset_elements: List):
+        # Pre-shuffle the elements so that the (otherwise deterministic) length-based
+        # sorting below does not always produce exactly the same batches.
+ if not self.for_inference:
+ dataset_elements = np.random.permutation(dataset_elements)
+
+ sorting_fn = (
+ lambda elem: add_noise_to_value(
+ sum(len(elem[k]) for k in self.sorting_fields),
+ noise_param=self.noise_param,
+ )
+ if not self.for_inference
+ else sum(len(elem[k]) for k in self.sorting_fields)
+ )
+
+ dataset_elements = sorted(dataset_elements, key=sorting_fn)
+
+ if self.for_inference:
+ return dataset_elements
+
+ ds = list(chunks(dataset_elements, 64))
+ np.random.shuffle(ds)
+ return flatten(ds)
+
+ def materialize_batches(
+ self, dataset_elements: List[Dict[str, Any]]
+ ) -> Generator[Dict[str, Any], None, None]:
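+        # Group preprocessed elements into batches, either by token budget (`tokens_per_batch`)
+        # or by a fixed `batch_size`, discarding elements that violate the min/max length limits.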
+ if self.prebatch:
+ dataset_elements = self.preshuffle_elements(dataset_elements)
+
+ current_batch = []
+
+ # function that creates a batch from the 'current_batch' list
+ def output_batch() -> Dict[str, Any]:
+ assert (
+ len(
+ set([len(elem["predictable_candidates"]) for elem in current_batch])
+ )
+ == 1
+ ), " ".join(
+ map(
+ str, [len(elem["predictable_candidates"]) for elem in current_batch]
+ )
+ )
+
+ batch_dict = dict()
+
+ de_values_by_field = {
+ fn: [de[fn] for de in current_batch if fn in de]
+ for fn in self.fields_batcher
+ }
+
+            # drop fields that have a batcher registered but no values in this batch
+ de_values_by_field = {
+ fn: fvs for fn, fvs in de_values_by_field.items() if len(fvs) > 0
+ }
+
+ assert len(set([len(v) for v in de_values_by_field.values()]))
+
+            # todo: maybe we should warn the user about fields being dropped because
+            # they contain "None" instances
+ de_values_by_field = {
+ fn: fvs
+ for fn, fvs in de_values_by_field.items()
+ if all([fv is not None for fv in fvs])
+ }
+
+ for field_name, field_values in de_values_by_field.items():
+ field_batch = (
+ self.fields_batcher[field_name](field_values)
+ if self.fields_batcher[field_name] is not None
+ else field_values
+ )
+
+ batch_dict[field_name] = field_batch
+
+ return batch_dict
+
+ max_len_discards, min_len_discards = 0, 0
+
+ should_token_batch = self.batch_size is None
+
+ curr_pred_elements = -1
+ for de in dataset_elements:
+ if (
+ should_token_batch
+ and self.max_batch_size != -1
+ and len(current_batch) == self.max_batch_size
+ ) or (not should_token_batch and len(current_batch) == self.batch_size):
+ yield output_batch()
+ current_batch = []
+ curr_pred_elements = -1
+
+ too_long_fields = [
+ k
+ for k in de
+ if self.max_length != -1
+ and torch.is_tensor(de[k])
+ and len(de[k]) > self.max_length
+ ]
+ if len(too_long_fields) > 0:
+ max_len_discards += 1
+ continue
+
+ too_short_fields = [
+ k
+ for k in de
+ if self.min_length != -1
+ and torch.is_tensor(de[k])
+ and len(de[k]) < self.min_length
+ ]
+ if len(too_short_fields) > 0:
+ min_len_discards += 1
+ continue
+
+ if should_token_batch:
+ de_len = sum(len(de[k]) for k in self.batching_fields)
+
+ future_max_len = max(
+ de_len,
+ max(
+ [
+ sum(len(bde[k]) for k in self.batching_fields)
+ for bde in current_batch
+ ],
+ default=0,
+ ),
+ )
+
+ future_tokens_per_batch = future_max_len * (len(current_batch) + 1)
+
+ num_predictable_candidates = len(de["predictable_candidates"])
+
+ if len(current_batch) > 0 and (
+ future_tokens_per_batch >= self.tokens_per_batch
+ or (
+ num_predictable_candidates != curr_pred_elements
+ and curr_pred_elements != -1
+ )
+ ):
+ yield output_batch()
+ current_batch = []
+
+ current_batch.append(de)
+ curr_pred_elements = len(de["predictable_candidates"])
+
+ if len(current_batch) != 0 and not self.drop_last:
+ yield output_batch()
+
+ if max_len_discards > 0:
+ if self.for_inference:
+ logger.warning(
+ f"WARNING: Inference mode is True but {max_len_discards} samples longer than max length were "
+ f"found. The {max_len_discards} samples will be DISCARDED. If you are doing some kind of evaluation"
+ f", this can INVALIDATE results. This might happen if the max length was not set to -1 or if the "
+ f"sample length exceeds the maximum length supported by the current model."
+ )
+ else:
+ logger.warning(
+ f"During iteration, {max_len_discards} elements were "
+ f"discarded since longer than max length {self.max_length}"
+ )
+
+ if min_len_discards > 0:
+ if self.for_inference:
+ logger.warning(
+ f"WARNING: Inference mode is True but {min_len_discards} samples shorter than min length were "
+ f"found. The {min_len_discards} samples will be DISCARDED. If you are doing some kind of evaluation"
+ f", this can INVALIDATE results. This might happen if the min length was not set to -1 or if the "
+ f"sample length is shorter than the minimum length supported by the current model."
+ )
+ else:
+ logger.warning(
+ f"During iteration, {min_len_discards} elements were "
+ f"discarded since shorter than min length {self.min_length}"
+ )
+
+ @staticmethod
+ def convert_tokens_to_char_annotations(
+ sample: RelikReaderSample,
+ remove_nmes: bool = True,
+ ) -> RelikReaderSample:
+ """
+ Converts the token annotations to char annotations.
+
+ Args:
+ sample (:obj:`RelikReaderSample`):
+ The sample to convert.
+ remove_nmes (:obj:`bool`, `optional`, defaults to :obj:`True`):
+ Whether to remove the NMEs from the annotations.
+ Returns:
+ :obj:`RelikReaderSample`: The converted sample.
+ """
+ char_annotations = set()
+ for (
+ predicted_entity,
+ predicted_spans,
+ ) in sample.predicted_window_labels.items():
+ if predicted_entity == NME_SYMBOL and remove_nmes:
+ continue
+
+ for span_start, span_end in predicted_spans:
+ span_start = sample.token2char_start[str(span_start)]
+ span_end = sample.token2char_end[str(span_end)]
+
+ char_annotations.add((span_start, span_end, predicted_entity))
+
+ char_probs_annotations = dict()
+ for (
+ span_start,
+ span_end,
+ ), candidates_probs in sample.span_title_probabilities.items():
+ span_start = sample.token2char_start[str(span_start)]
+ span_end = sample.token2char_end[str(span_end)]
+ char_probs_annotations[(span_start, span_end)] = {
+ title for title, _ in candidates_probs
+ }
+
+ sample.predicted_window_labels_chars = char_annotations
+ sample.probs_window_labels_chars = char_probs_annotations
+
+ return sample
+
+ @staticmethod
+ def merge_patches_predictions(sample) -> None:
+ sample._d["predicted_window_labels"] = dict()
+ predicted_window_labels = sample._d["predicted_window_labels"]
+
+ sample._d["span_title_probabilities"] = dict()
+ span_title_probabilities = sample._d["span_title_probabilities"]
+
+ span2title = dict()
+ for _, patch_info in sorted(sample.patches.items(), key=lambda x: x[0]):
+ # selecting span predictions
+ for predicted_title, predicted_spans in patch_info[
+ "predicted_window_labels"
+ ].items():
+ for pred_span in predicted_spans:
+ pred_span = tuple(pred_span)
+ curr_title = span2title.get(pred_span)
+ if curr_title is None or curr_title == NME_SYMBOL:
+ span2title[pred_span] = predicted_title
+ # else:
+ # print("Merging at patch level")
+
+ # selecting span predictions probability
+ for predicted_span, titles_probabilities in patch_info[
+ "span_title_probabilities"
+ ].items():
+ if predicted_span not in span_title_probabilities:
+ span_title_probabilities[predicted_span] = titles_probabilities
+
+ for span, title in span2title.items():
+ if title not in predicted_window_labels:
+ predicted_window_labels[title] = list()
+ predicted_window_labels[title].append(span)
diff --git a/relik/reader/data/relik_reader_data_utils.py b/relik/reader/data/relik_reader_data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c7446bee296d14653a35895bf9ec8071c87e5af
--- /dev/null
+++ b/relik/reader/data/relik_reader_data_utils.py
@@ -0,0 +1,51 @@
+from typing import List
+
+import numpy as np
+import torch
+
+
+def flatten(lsts: List[list]) -> list:
+ acc_lst = list()
+ for lst in lsts:
+ acc_lst.extend(lst)
+ return acc_lst
+
+
+def batchify(tensors: List[torch.Tensor], padding_value: int = 0) -> torch.Tensor:
+ return torch.nn.utils.rnn.pad_sequence(
+ tensors, batch_first=True, padding_value=padding_value
+ )
+
+
+def batchify_matrices(tensors: List[torch.Tensor], padding_value: int) -> torch.Tensor:
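+    # Pad a list of 2D tensors to a common (max_rows, max_cols) shape and stack them.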
+ x = max([t.shape[0] for t in tensors])
+ y = max([t.shape[1] for t in tensors])
+ out_matrix = torch.zeros((len(tensors), x, y))
+ out_matrix += padding_value
+ for i, tensor in enumerate(tensors):
+ out_matrix[i][0 : tensor.shape[0], 0 : tensor.shape[1]] = tensor
+ return out_matrix
+
+
+def batchify_tensor(tensors: List[torch.Tensor], padding_value: int) -> torch.Tensor:
+ x = max([t.shape[0] for t in tensors])
+ y = max([t.shape[1] for t in tensors])
+ rest = tensors[0].shape[2]
+ out_matrix = torch.zeros((len(tensors), x, y, rest))
+ out_matrix += padding_value
+ for i, tensor in enumerate(tensors):
+ out_matrix[i][0 : tensor.shape[0], 0 : tensor.shape[1], :] = tensor
+ return out_matrix
+
+
+def chunks(lst: list, chunk_size: int) -> List[list]:
+ chunks_acc = list()
+ for i in range(0, len(lst), chunk_size):
+ chunks_acc.append(lst[i : i + chunk_size])
+ return chunks_acc
+
+
+def add_noise_to_value(value: int, noise_param: float):
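+    # Perturb `value` by up to +/- value * noise_param (result is at least 1); used to add
+    # randomness to the length-based sorting key when pre-batching.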
+ noise_value = value * noise_param
+ noise = np.random.uniform(-noise_value, noise_value)
+ return max(1, value + noise)
diff --git a/relik/reader/data/relik_reader_sample.py b/relik/reader/data/relik_reader_sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d7570411fbb939f99d73d1cc3318b21552bc7c2
--- /dev/null
+++ b/relik/reader/data/relik_reader_sample.py
@@ -0,0 +1,49 @@
+import json
+from typing import Iterable
+
+
+class RelikReaderSample:
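+    # Lightweight record whose attributes live in the internal `_d` dict: unknown attributes
+    # resolve to None instead of raising AttributeError.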
+ def __init__(self, **kwargs):
+ super().__setattr__("_d", {})
+ self._d = kwargs
+
+ def __getattribute__(self, item):
+ return super(RelikReaderSample, self).__getattribute__(item)
+
+ def __getattr__(self, item):
+ if item.startswith("__") and item.endswith("__"):
+ # this is likely some python library-specific variable (such as __deepcopy__ for copy)
+ # better follow standard behavior here
+ raise AttributeError(item)
+ elif item in self._d:
+ return self._d[item]
+ else:
+ return None
+
+ def __setattr__(self, key, value):
+ if key in self._d:
+ self._d[key] = value
+ else:
+ super().__setattr__(key, value)
+
+    def to_jsons(self) -> str:
+        if "predicted_window_labels" in self._d:
+            new_obj = {
+                k: v
+                for k, v in self._d.items()
+                if k != "predicted_window_labels" and k != "span_title_probabilities"
+            }
+            new_obj["predicted_window_labels"] = [
+                [ss, se, pred_title]
+                for ss, se, pred_title in self.predicted_window_labels_chars
+            ]
+            return json.dumps(new_obj)
+        else:
+            return json.dumps(self._d)
+
+
+def load_relik_reader_samples(path: str) -> Iterable[RelikReaderSample]:
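+    # Lazily stream RelikReaderSample objects from a JSONL file, one sample per line.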
+ with open(path) as f:
+ for line in f:
+ jsonl_line = json.loads(line.strip())
+ relik_reader_sample = RelikReaderSample(**jsonl_line)
+ yield relik_reader_sample
diff --git a/relik/reader/lightning_modules/__init__.py b/relik/reader/lightning_modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/relik/reader/lightning_modules/relik_reader_pl_module.py b/relik/reader/lightning_modules/relik_reader_pl_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e66e87b6360fe9b4a72477fd5a7fe6295b53ae9
--- /dev/null
+++ b/relik/reader/lightning_modules/relik_reader_pl_module.py
@@ -0,0 +1,50 @@
+from typing import Any, Optional
+
+import lightning
+from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler
+
+from relik.reader.relik_reader_core import RelikReaderCoreModel
+
+
+class RelikReaderPLModule(lightning.LightningModule):
+ def __init__(
+ self,
+ cfg: dict,
+ transformer_model: str,
+ additional_special_symbols: int,
+ num_layers: Optional[int] = None,
+ activation: str = "gelu",
+ linears_hidden_size: Optional[int] = 512,
+ use_last_k_layers: int = 1,
+ training: bool = False,
+ *args: Any,
+ **kwargs: Any
+ ):
+ super().__init__(*args, **kwargs)
+ self.save_hyperparameters()
+ self.relik_reader_core_model = RelikReaderCoreModel(
+ transformer_model,
+ additional_special_symbols,
+ num_layers,
+ activation,
+ linears_hidden_size,
+ use_last_k_layers,
+ training=training,
+ )
+ self.optimizer_factory = None
+
+ def training_step(self, batch: dict, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+ relik_output = self.relik_reader_core_model(**batch)
+ self.log("train-loss", relik_output["loss"])
+ return relik_output["loss"]
+
+ def validation_step(
+ self, batch: dict, *args: Any, **kwargs: Any
+ ) -> Optional[STEP_OUTPUT]:
+ return
+
+ def set_optimizer_factory(self, optimizer_factory) -> None:
+ self.optimizer_factory = optimizer_factory
+
+ def configure_optimizers(self) -> OptimizerLRScheduler:
+ return self.optimizer_factory(self.relik_reader_core_model)
diff --git a/relik/reader/lightning_modules/relik_reader_re_pl_module.py b/relik/reader/lightning_modules/relik_reader_re_pl_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad1d2084f10d700d68b30e68dc29cbace1450f9b
--- /dev/null
+++ b/relik/reader/lightning_modules/relik_reader_re_pl_module.py
@@ -0,0 +1,54 @@
+from typing import Any, Optional
+
+import lightning
+from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler
+
+from relik.reader.relik_reader_re import RelikReaderForTripletExtraction
+
+
+class RelikReaderREPLModule(lightning.LightningModule):
+ def __init__(
+ self,
+ cfg: dict,
+ transformer_model: str,
+ additional_special_symbols: int,
+ num_layers: Optional[int] = None,
+ activation: str = "gelu",
+ linears_hidden_size: Optional[int] = 512,
+ use_last_k_layers: int = 1,
+ training: bool = False,
+ *args: Any,
+ **kwargs: Any
+ ):
+ super().__init__(*args, **kwargs)
+ self.save_hyperparameters()
+
+ self.relik_reader_re_model = RelikReaderForTripletExtraction(
+ transformer_model,
+ additional_special_symbols,
+ num_layers,
+ activation,
+ linears_hidden_size,
+ use_last_k_layers,
+ training=training,
+ )
+ self.optimizer_factory = None
+
+ def training_step(self, batch: dict, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+ relik_output = self.relik_reader_re_model(**batch)
+ self.log("train-loss", relik_output["loss"])
+ self.log("train-start_loss", relik_output["ned_start_loss"])
+ self.log("train-end_loss", relik_output["ned_end_loss"])
+ self.log("train-relation_loss", relik_output["re_loss"])
+ return relik_output["loss"]
+
+ def validation_step(
+ self, batch: dict, *args: Any, **kwargs: Any
+ ) -> Optional[STEP_OUTPUT]:
+ return
+
+ def set_optimizer_factory(self, optimizer_factory) -> None:
+ self.optimizer_factory = optimizer_factory
+
+ def configure_optimizers(self) -> OptimizerLRScheduler:
+ return self.optimizer_factory(self.relik_reader_re_model)
diff --git a/relik/reader/pytorch_modules/__init__.py b/relik/reader/pytorch_modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/relik/reader/pytorch_modules/base.py b/relik/reader/pytorch_modules/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..75db716d53d6dbdb9e9f95b63dfa1d8619769bbf
--- /dev/null
+++ b/relik/reader/pytorch_modules/base.py
@@ -0,0 +1,248 @@
+import logging
+import os
+from pathlib import Path
+from typing import Any, Dict, List
+
+import torch
+import transformers as tr
+from torch.utils.data import IterableDataset
+from transformers import AutoConfig
+
+from relik.common.log import get_console_logger, get_logger
+from relik.common.utils import get_callable_from_string
+from relik.reader.pytorch_modules.hf.modeling_relik import (
+ RelikReaderConfig,
+ RelikReaderSample,
+)
+
+console_logger = get_console_logger()
+logger = get_logger(__name__, level=logging.INFO)
+
+
+class RelikReaderBase(torch.nn.Module):
+ default_reader_class: str | None = None
+ default_data_class: str | None = None
+
+ def __init__(
+ self,
+ transformer_model: str | tr.PreTrainedModel | None = None,
+ additional_special_symbols: int = 0,
+ num_layers: int | None = None,
+ activation: str = "gelu",
+ linears_hidden_size: int | None = 512,
+ use_last_k_layers: int = 1,
+ training: bool = False,
+ device: str | torch.device | None = None,
+ precision: int = 32,
+ tokenizer: str | tr.PreTrainedTokenizer | None = None,
+ dataset: IterableDataset | str | None = None,
+ default_reader_class: tr.PreTrainedModel | str | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__()
+
+ self.default_reader_class = default_reader_class or self.default_reader_class
+
+ if self.default_reader_class is None:
+ raise ValueError("You must specify a default reader class.")
+
+ # get the callable for the default reader class
+ self.default_reader_class: tr.PreTrainedModel = get_callable_from_string(
+ self.default_reader_class
+ )
+
+ if isinstance(transformer_model, str):
+ config = AutoConfig.from_pretrained(
+ transformer_model, trust_remote_code=True
+ )
+ if "relik-reader" in config.model_type:
+ transformer_model = self.default_reader_class.from_pretrained(
+ transformer_model, **kwargs
+ )
+ else:
+ reader_config = RelikReaderConfig(
+ transformer_model=transformer_model,
+ additional_special_symbols=additional_special_symbols,
+ num_layers=num_layers,
+ activation=activation,
+ linears_hidden_size=linears_hidden_size,
+ use_last_k_layers=use_last_k_layers,
+ training=training,
+ )
+ transformer_model = self.default_reader_class(reader_config)
+
+ self.relik_reader_model = transformer_model
+ self.relik_reader_model_config = self.relik_reader_model.config
+
+ # get the tokenizer
+ self._tokenizer = tokenizer
+
+ # and instantiate the dataset class
+ self.dataset: IterableDataset | None = dataset
+
+ # move the model to the device
+ self.to(device or torch.device("cpu"))
+
+ # set the precision
+ self.precision = precision
+
+ def forward(self, **kwargs) -> Dict[str, Any]:
+ return self.relik_reader_model(**kwargs)
+
+ def _read(self, *args, **kwargs) -> Any:
+ raise NotImplementedError
+
+ @torch.no_grad()
+ @torch.inference_mode()
+ def read(
+ self,
+ text: List[str] | List[List[str]] | None = None,
+ samples: List[RelikReaderSample] | None = None,
+ input_ids: torch.Tensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ token_type_ids: torch.Tensor | None = None,
+ prediction_mask: torch.Tensor | None = None,
+ special_symbols_mask: torch.Tensor | None = None,
+ candidates: List[List[str]] | None = None,
+ max_length: int = 1000,
+ max_batch_size: int = 128,
+ token_batch_size: int = 2048,
+ precision: int | str | None = None,
+ progress_bar: bool = False,
+ *args,
+ **kwargs,
+ ) -> List[RelikReaderSample] | List[List[RelikReaderSample]]:
+ """
+ Reads the given text.
+
+ Args:
+ text (:obj:`List[str]` or :obj:`List[List[str]]`, `optional`):
+ The text to read in tokens. If a list of list of tokens is provided, each
+ inner list is considered a sentence.
+ samples (:obj:`List[RelikReaderSample]`, `optional`):
+ The samples to read. If provided, `text` and `candidates` are ignored.
+ input_ids (:obj:`torch.Tensor`, `optional`):
+ The input ids of the text.
+ attention_mask (:obj:`torch.Tensor`, `optional`):
+ The attention mask of the text.
+ token_type_ids (:obj:`torch.Tensor`, `optional`):
+ The token type ids of the text.
+ prediction_mask (:obj:`torch.Tensor`, `optional`):
+ The prediction mask of the text.
+ special_symbols_mask (:obj:`torch.Tensor`, `optional`):
+ The special symbols mask of the text.
+ candidates (:obj:`List[List[str]]`, `optional`):
+ The candidates of the text.
+            max_length (:obj:`int`, `optional`, defaults to 1000):
+ The maximum length of the text.
+ max_batch_size (:obj:`int`, `optional`, defaults to 128):
+ The maximum batch size.
+ token_batch_size (:obj:`int`, `optional`):
+ The maximum number of tokens per batch.
+ precision (:obj:`int` or :obj:`str`, `optional`):
+ The precision to use. If not provided, the default is 32 bit.
+ progress_bar (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Whether to show a progress bar.
+
+ Returns:
+ The predicted labels for each sample.
+ """
+ if text is None and input_ids is None and samples is None:
+ raise ValueError(
+ "Either `text` or `input_ids` or `samples` must be provided."
+ )
+ if (input_ids is None and samples is None) and (
+ text is None or candidates is None
+ ):
+ raise ValueError(
+ "`text` and `candidates` must be provided to return the predictions when "
+ "`input_ids` and `samples` is not provided."
+ )
+ if text is not None and samples is None:
+ if len(text) != len(candidates):
+ raise ValueError("`text` and `candidates` must have the same length.")
+            if isinstance(text[0], str):  # a single sentence was passed; wrap it into a batch of one
+ text = [text]
+ candidates = [candidates]
+
+ samples = [
+ RelikReaderSample(tokens=t, candidates=c)
+ for t, c in zip(text, candidates)
+ ]
+
+ return self._read(
+ samples,
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ prediction_mask,
+ special_symbols_mask,
+ max_length,
+ max_batch_size,
+ token_batch_size,
+ precision or self.precision,
+ progress_bar,
+ *args,
+ **kwargs,
+ )
+
+ @property
+ def device(self) -> torch.device:
+ """
+ The device of the model.
+ """
+ return next(self.parameters()).device
+
+ @property
+ def tokenizer(self) -> tr.PreTrainedTokenizer:
+ """
+ The tokenizer.
+ """
+ if self._tokenizer:
+ return self._tokenizer
+
+ self._tokenizer = tr.AutoTokenizer.from_pretrained(
+ self.relik_reader_model.config.name_or_path
+ )
+ return self._tokenizer
+
+ def save_pretrained(
+ self,
+ output_dir: str | os.PathLike,
+ model_name: str | None = None,
+ push_to_hub: bool = False,
+ **kwargs,
+ ) -> None:
+ """
+ Saves the model to the given path.
+
+ Args:
+ output_dir (`str` or :obj:`os.PathLike`):
+ The path to save the model to.
+ model_name (`str`, `optional`):
+ The name of the model. If not provided, the model will be saved as
+ `default_reader_class.__name__`.
+ push_to_hub (`bool`, `optional`, defaults to `False`):
+ Whether to push the model to the HuggingFace Hub.
+ **kwargs:
+ Additional keyword arguments to pass to the `save_pretrained` method
+ """
+ # create the output directory
+ output_dir = Path(output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ model_name = model_name or self.default_reader_class.__name__
+
+ logger.info(f"Saving reader to {output_dir / model_name}")
+
+ # save the model
+ self.relik_reader_model.register_for_auto_class()
+ self.relik_reader_model.save_pretrained(
+ output_dir / model_name, push_to_hub=push_to_hub, **kwargs
+ )
+
+ if self.tokenizer:
+ logger.info("Saving also the tokenizer")
+ self.tokenizer.save_pretrained(
+ output_dir / model_name, push_to_hub=push_to_hub, **kwargs
+ )
diff --git a/relik/reader/pytorch_modules/hf/__init__.py b/relik/reader/pytorch_modules/hf/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9c158e6ab6dcd3ab43e60751218600fbb0a5ed5
--- /dev/null
+++ b/relik/reader/pytorch_modules/hf/__init__.py
@@ -0,0 +1,2 @@
+from .configuration_relik import RelikReaderConfig
+from .modeling_relik import RelikReaderREModel
diff --git a/relik/reader/pytorch_modules/hf/configuration_relik.py b/relik/reader/pytorch_modules/hf/configuration_relik.py
new file mode 100644
index 0000000000000000000000000000000000000000..6683823926b4b09a5ad169ef4e0f5b92061d774e
--- /dev/null
+++ b/relik/reader/pytorch_modules/hf/configuration_relik.py
@@ -0,0 +1,33 @@
+from typing import Optional
+
+from transformers import AutoConfig
+from transformers.configuration_utils import PretrainedConfig
+
+
+class RelikReaderConfig(PretrainedConfig):
+ model_type = "relik-reader"
+
+ def __init__(
+ self,
+ transformer_model: str = "microsoft/deberta-v3-base",
+ additional_special_symbols: int = 101,
+ num_layers: Optional[int] = None,
+ activation: str = "gelu",
+ linears_hidden_size: Optional[int] = 512,
+ use_last_k_layers: int = 1,
+ training: bool = False,
+ default_reader_class: Optional[str] = None,
+ **kwargs
+ ) -> None:
+ self.transformer_model = transformer_model
+ self.additional_special_symbols = additional_special_symbols
+ self.num_layers = num_layers
+ self.activation = activation
+ self.linears_hidden_size = linears_hidden_size
+ self.use_last_k_layers = use_last_k_layers
+ self.training = training
+ self.default_reader_class = default_reader_class
+ super().__init__(**kwargs)
+
+
+AutoConfig.register("relik-reader", RelikReaderConfig)
diff --git a/relik/reader/pytorch_modules/hf/modeling_relik.py b/relik/reader/pytorch_modules/hf/modeling_relik.py
new file mode 100644
index 0000000000000000000000000000000000000000..f79fc14e0cabe9f830187467578ff3f65351c9a2
--- /dev/null
+++ b/relik/reader/pytorch_modules/hf/modeling_relik.py
@@ -0,0 +1,981 @@
+from typing import Any, Dict, Optional
+
+import torch
+from transformers import AutoModel, PreTrainedModel
+from transformers.activations import ClippedGELUActivation, GELUActivation
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_utils import PoolerEndLogits
+
+from .configuration_relik import RelikReaderConfig
+
+
+class RelikReaderSample:
+ def __init__(self, **kwargs):
+ super().__setattr__("_d", {})
+ self._d = kwargs
+
+ def __getattribute__(self, item):
+ return super(RelikReaderSample, self).__getattribute__(item)
+
+ def __getattr__(self, item):
+ if item.startswith("__") and item.endswith("__"):
+ # this is likely some python library-specific variable (such as __deepcopy__ for copy)
+ # better follow standard behavior here
+ raise AttributeError(item)
+ elif item in self._d:
+ return self._d[item]
+ else:
+ return None
+
+ def __setattr__(self, key, value):
+ if key in self._d:
+ self._d[key] = value
+ else:
+ super().__setattr__(key, value)
+
+
+activation2functions = {
+ "relu": torch.nn.ReLU(),
+ "gelu": GELUActivation(),
+ "gelu_10": ClippedGELUActivation(-10, 10),
+}
+
+
+class PoolerEndLogitsBi(PoolerEndLogits):
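+    # Variant of the HF PoolerEndLogits head whose final projection outputs two logits per
+    # token instead of a single score.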
+ def __init__(self, config: PretrainedConfig):
+ super().__init__(config)
+ self.dense_1 = torch.nn.Linear(config.hidden_size, 2)
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ start_states: Optional[torch.FloatTensor] = None,
+ start_positions: Optional[torch.LongTensor] = None,
+ p_mask: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ if p_mask is not None:
+ p_mask = p_mask.unsqueeze(-1)
+ logits = super().forward(
+ hidden_states,
+ start_states,
+ start_positions,
+ p_mask,
+ )
+ return logits
+
+
+class RelikReaderSpanModel(PreTrainedModel):
+ config_class = RelikReaderConfig
+
+ def __init__(self, config: RelikReaderConfig, *args, **kwargs):
+ super().__init__(config)
+ # Transformer model declaration
+ self.config = config
+ self.transformer_model = (
+ AutoModel.from_pretrained(self.config.transformer_model)
+ if self.config.num_layers is None
+ else AutoModel.from_pretrained(
+ self.config.transformer_model, num_hidden_layers=self.config.num_layers
+ )
+ )
+ self.transformer_model.resize_token_embeddings(
+ self.transformer_model.config.vocab_size
+ + self.config.additional_special_symbols
+ )
+
+ self.activation = self.config.activation
+ self.linears_hidden_size = self.config.linears_hidden_size
+ self.use_last_k_layers = self.config.use_last_k_layers
+
+ # named entity detection layers
+ self.ned_start_classifier = self._get_projection_layer(
+ self.activation, last_hidden=2, layer_norm=False
+ )
+ self.ned_end_classifier = PoolerEndLogits(self.transformer_model.config)
+
+ # END entity disambiguation layer
+ self.ed_start_projector = self._get_projection_layer(self.activation)
+ self.ed_end_projector = self._get_projection_layer(self.activation)
+
+ self.training = self.config.training
+
+ # criterion
+ self.criterion = torch.nn.CrossEntropyLoss()
+
+ def _get_projection_layer(
+ self,
+ activation: str,
+ last_hidden: Optional[int] = None,
+ input_hidden=None,
+ layer_norm: bool = True,
+ ) -> torch.nn.Sequential:
+ head_components = [
+ torch.nn.Dropout(0.1),
+ torch.nn.Linear(
+ self.transformer_model.config.hidden_size * self.use_last_k_layers
+ if input_hidden is None
+ else input_hidden,
+ self.linears_hidden_size,
+ ),
+ activation2functions[activation],
+ torch.nn.Dropout(0.1),
+ torch.nn.Linear(
+ self.linears_hidden_size,
+ self.linears_hidden_size if last_hidden is None else last_hidden,
+ ),
+ ]
+
+ if layer_norm:
+ head_components.append(
+ torch.nn.LayerNorm(
+ self.linears_hidden_size if last_hidden is None else last_hidden,
+ self.transformer_model.config.layer_norm_eps,
+ )
+ )
+
+ return torch.nn.Sequential(*head_components)
+
+ def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
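+        # tokens where mask == 1 are pushed to a large negative logit so they get
+        # ~zero probability after softmax (a smaller constant is used in fp16 to
+        # avoid overflow)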
+ mask = mask.unsqueeze(-1)
+ if next(self.parameters()).dtype == torch.float16:
+ logits = logits * (1 - mask) - 65500 * mask
+ else:
+ logits = logits * (1 - mask) - 1e30 * mask
+ return logits
+
+ def _get_model_features(
+ self,
+ input_ids: torch.Tensor,
+ attention_mask: torch.Tensor,
+ token_type_ids: Optional[torch.Tensor],
+ ):
+ model_input = {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "output_hidden_states": self.use_last_k_layers > 1,
+ }
+
+ if token_type_ids is not None:
+ model_input["token_type_ids"] = token_type_ids
+
+ model_output = self.transformer_model(**model_input)
+
+ if self.use_last_k_layers > 1:
+ model_features = torch.cat(
+ model_output[1][-self.use_last_k_layers :], dim=-1
+ )
+ else:
+ model_features = model_output[0]
+
+ return model_features
+
+ def compute_ned_end_logits(
+ self,
+ start_predictions,
+ start_labels,
+ model_features,
+ prediction_mask,
+ batch_size,
+ ) -> Optional[torch.Tensor]:
+ # todo: maybe when constraining on the spans,
+ # we should not use a prediction_mask for the end tokens.
+ # at least we should not during training imo
+ start_positions = start_labels if self.training else start_predictions
+ start_positions_indices = (
+ torch.arange(start_positions.size(1), device=start_positions.device)
+ .unsqueeze(0)
+ .expand(batch_size, -1)[start_positions > 0]
+ ).to(start_positions.device)
+
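+        # replicate the sequence features (and the prediction mask) so that each
+        # predicted/gold start position gets its own row for the end classifier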
+ if len(start_positions_indices) > 0:
+ expanded_features = torch.cat(
+ [
+ model_features[i].unsqueeze(0).expand(x, -1, -1)
+ for i, x in enumerate(torch.sum(start_positions > 0, dim=-1))
+ if x > 0
+ ],
+ dim=0,
+ ).to(start_positions_indices.device)
+
+ expanded_prediction_mask = torch.cat(
+ [
+ prediction_mask[i].unsqueeze(0).expand(x, -1)
+ for i, x in enumerate(torch.sum(start_positions > 0, dim=-1))
+ if x > 0
+ ],
+ dim=0,
+ ).to(expanded_features.device)
+
+ end_logits = self.ned_end_classifier(
+ hidden_states=expanded_features,
+ start_positions=start_positions_indices,
+ p_mask=expanded_prediction_mask,
+ )
+
+ return end_logits
+
+ return None
+
+ def compute_classification_logits(
+ self,
+ model_features,
+ special_symbols_mask,
+ prediction_mask,
+ batch_size,
+ start_positions=None,
+ end_positions=None,
+ ) -> torch.Tensor:
+ if start_positions is None or end_positions is None:
+ start_positions = torch.zeros_like(prediction_mask)
+ end_positions = torch.zeros_like(prediction_mask)
+
+ model_start_features = self.ed_start_projector(model_features)
+ model_end_features = self.ed_end_projector(model_features)
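+        # copy the projected end-token features onto the corresponding start
+        # positions, so each detected span is represented by its (start, end)
+        # feature pair at the start token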
+ model_end_features[start_positions > 0] = model_end_features[end_positions > 0]
+
+ model_ed_features = torch.cat(
+ [model_start_features, model_end_features], dim=-1
+ )
+
+ # computing ed features
+ classes_representations = torch.sum(special_symbols_mask, dim=1)[0].item()
+ special_symbols_representation = model_ed_features[special_symbols_mask].view(
+ batch_size, classes_representations, -1
+ )
+
+ logits = torch.bmm(
+ model_ed_features,
+ torch.permute(special_symbols_representation, (0, 2, 1)),
+ )
+
+ logits = self._mask_logits(logits, prediction_mask)
+
+ return logits
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ attention_mask: torch.Tensor,
+ token_type_ids: Optional[torch.Tensor] = None,
+ prediction_mask: Optional[torch.Tensor] = None,
+ special_symbols_mask: Optional[torch.Tensor] = None,
+ start_labels: Optional[torch.Tensor] = None,
+ end_labels: Optional[torch.Tensor] = None,
+ use_predefined_spans: bool = False,
+ *args,
+ **kwargs,
+ ) -> Dict[str, Any]:
+ batch_size, seq_len = input_ids.shape
+
+ model_features = self._get_model_features(
+ input_ids, attention_mask, token_type_ids
+ )
+
+ ned_start_labels = None
+
+ # named entity detection if required
+ if use_predefined_spans: # no need to compute spans
+ ned_start_logits, ned_start_probabilities, ned_start_predictions = (
+ None,
+ None,
+ torch.clone(start_labels)
+ if start_labels is not None
+ else torch.zeros_like(input_ids),
+ )
+ ned_end_logits, ned_end_probabilities, ned_end_predictions = (
+ None,
+ None,
+ torch.clone(end_labels)
+ if end_labels is not None
+ else torch.zeros_like(input_ids),
+ )
+
+ ned_start_predictions[ned_start_predictions > 0] = 1
+ ned_end_predictions[ned_end_predictions > 0] = 1
+
+ else: # compute spans
+ # start boundary prediction
+ ned_start_logits = self.ned_start_classifier(model_features)
+ ned_start_logits = self._mask_logits(ned_start_logits, prediction_mask)
+ ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1)
+ ned_start_predictions = ned_start_probabilities.argmax(dim=-1)
+
+ # end boundary prediction
+ ned_start_labels = (
+ torch.zeros_like(start_labels) if start_labels is not None else None
+ )
+
+ if ned_start_labels is not None:
+ ned_start_labels[start_labels == -100] = -100
+ ned_start_labels[start_labels > 0] = 1
+
+ ned_end_logits = self.compute_ned_end_logits(
+ ned_start_predictions,
+ ned_start_labels,
+ model_features,
+ prediction_mask,
+ batch_size,
+ )
+
+ if ned_end_logits is not None:
+ ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
+ ned_end_predictions = torch.argmax(ned_end_probabilities, dim=-1)
+ else:
+ ned_end_logits, ned_end_probabilities = None, None
+ ned_end_predictions = ned_start_predictions.new_zeros(batch_size)
+
+ # flattening end predictions
+ # (flattening can happen only if the
+ # end boundaries were not predicted using the gold labels)
+ if not self.training:
+ flattened_end_predictions = torch.clone(ned_start_predictions)
+ flattened_end_predictions[flattened_end_predictions > 0] = 0
+
+ batch_start_predictions = list()
+ for elem_idx in range(batch_size):
+ batch_start_predictions.append(
+ torch.where(ned_start_predictions[elem_idx] > 0)[0].tolist()
+ )
+
+ # check that the total number of start predictions
+ # is equal to the end predictions
+ total_start_predictions = sum(map(len, batch_start_predictions))
+ total_end_predictions = len(ned_end_predictions)
+ assert (
+ total_start_predictions == 0
+ or total_start_predictions == total_end_predictions
+ ), (
+ f"Total number of start predictions = {total_start_predictions}. "
+ f"Total number of end predictions = {total_end_predictions}"
+ )
+
+ curr_end_pred_num = 0
+ for elem_idx, bsp in enumerate(batch_start_predictions):
+ for sp in bsp:
+ ep = ned_end_predictions[curr_end_pred_num].item()
+ if ep < sp:
+ ep = sp
+
+                        # if this end token was already used by a previous span, drop the new span (no overlaps)
+ if flattened_end_predictions[elem_idx, ep] == 1:
+ ned_start_predictions[elem_idx, sp] = 0
+ else:
+ flattened_end_predictions[elem_idx, ep] = 1
+
+ curr_end_pred_num += 1
+
+ ned_end_predictions = flattened_end_predictions
+
+ start_position, end_position = (
+ (start_labels, end_labels)
+ if self.training
+ else (ned_start_predictions, ned_end_predictions)
+ )
+
+ # Entity disambiguation
+ ed_logits = self.compute_classification_logits(
+ model_features,
+ special_symbols_mask,
+ prediction_mask,
+ batch_size,
+ start_position,
+ end_position,
+ )
+ ed_probabilities = torch.softmax(ed_logits, dim=-1)
+ ed_predictions = torch.argmax(ed_probabilities, dim=-1)
+
+ # output build
+ output_dict = dict(
+ batch_size=batch_size,
+ ned_start_logits=ned_start_logits,
+ ned_start_probabilities=ned_start_probabilities,
+ ned_start_predictions=ned_start_predictions,
+ ned_end_logits=ned_end_logits,
+ ned_end_probabilities=ned_end_probabilities,
+ ned_end_predictions=ned_end_predictions,
+ ed_logits=ed_logits,
+ ed_probabilities=ed_probabilities,
+ ed_predictions=ed_predictions,
+ )
+
+ # compute loss if labels
+ if start_labels is not None and end_labels is not None and self.training:
+ # named entity detection loss
+
+ # start
+ if ned_start_logits is not None:
+ ned_start_loss = self.criterion(
+ ned_start_logits.view(-1, ned_start_logits.shape[-1]),
+ ned_start_labels.view(-1),
+ )
+ else:
+ ned_start_loss = 0
+
+ # end
+ if ned_end_logits is not None:
+ ned_end_labels = torch.zeros_like(end_labels)
+ ned_end_labels[end_labels == -100] = -100
+ ned_end_labels[end_labels > 0] = 1
+
+ ned_end_loss = self.criterion(
+ ned_end_logits,
+ (
+ torch.arange(
+ ned_end_labels.size(1), device=ned_end_labels.device
+ )
+ .unsqueeze(0)
+ .expand(batch_size, -1)[ned_end_labels > 0]
+ ).to(ned_end_labels.device),
+ )
+
+ else:
+ ned_end_loss = 0
+
+ # entity disambiguation loss
+ start_labels[ned_start_labels != 1] = -100
+ ed_labels = torch.clone(start_labels)
+ ed_labels[end_labels > 0] = end_labels[end_labels > 0]
+ ed_loss = self.criterion(
+ ed_logits.view(-1, ed_logits.shape[-1]),
+ ed_labels.view(-1),
+ )
+
+ output_dict["ned_start_loss"] = ned_start_loss
+ output_dict["ned_end_loss"] = ned_end_loss
+ output_dict["ed_loss"] = ed_loss
+
+ output_dict["loss"] = ned_start_loss + ned_end_loss + ed_loss
+
+ return output_dict
+
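+# Minimal usage sketch for the span model (illustrative; the input tensors are
+# assumed to be produced by the RelikReader dataset, which builds the masks below):
+#   model = RelikReaderSpanModel(RelikReaderConfig())
+#   out = model(input_ids, attention_mask,
+#               prediction_mask=prediction_mask,
+#               special_symbols_mask=special_symbols_mask)
+#   starts, ends = out["ned_start_predictions"], out["ned_end_predictions"]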
+
+class RelikReaderREModel(PreTrainedModel):
+ config_class = RelikReaderConfig
+
+ def __init__(self, config, *args, **kwargs):
+ super().__init__(config)
+ # Transformer model declaration
+ # self.transformer_model_name = transformer_model
+ self.config = config
+ self.transformer_model = (
+ AutoModel.from_pretrained(config.transformer_model)
+ if config.num_layers is None
+ else AutoModel.from_pretrained(
+ config.transformer_model, num_hidden_layers=config.num_layers
+ )
+ )
+ self.transformer_model.resize_token_embeddings(
+ self.transformer_model.config.vocab_size + config.additional_special_symbols
+ )
+
+ # named entity detection layers
+ self.ned_start_classifier = self._get_projection_layer(
+ config.activation, last_hidden=2, layer_norm=False
+ )
+
+ self.ned_end_classifier = PoolerEndLogitsBi(self.transformer_model.config)
+
+ self.entity_type_loss = (
+ config.entity_type_loss if hasattr(config, "entity_type_loss") else False
+ )
+ self.relation_disambiguation_loss = (
+ config.relation_disambiguation_loss
+ if hasattr(config, "relation_disambiguation_loss")
+ else False
+ )
+
+ input_hidden_ents = 2 * self.transformer_model.config.hidden_size
+
+ self.re_subject_projector = self._get_projection_layer(
+ config.activation, input_hidden=input_hidden_ents
+ )
+ self.re_object_projector = self._get_projection_layer(
+ config.activation, input_hidden=input_hidden_ents
+ )
+ self.re_relation_projector = self._get_projection_layer(config.activation)
+
+ if self.entity_type_loss or self.relation_disambiguation_loss:
+ self.re_entities_projector = self._get_projection_layer(
+ config.activation,
+ input_hidden=2 * self.transformer_model.config.hidden_size,
+ )
+ self.re_definition_projector = self._get_projection_layer(
+ config.activation,
+ )
+
+ self.re_classifier = self._get_projection_layer(
+ config.activation,
+ input_hidden=config.linears_hidden_size,
+ last_hidden=2,
+ layer_norm=False,
+ )
+
+ if self.entity_type_loss or self.relation_disambiguation_loss:
+ self.re_ed_classifier = self._get_projection_layer(
+ config.activation,
+ input_hidden=config.linears_hidden_size,
+ last_hidden=2,
+ layer_norm=False,
+ )
+
+ self.training = config.training
+
+ # criterion
+ self.criterion = torch.nn.CrossEntropyLoss()
+
+ def _get_projection_layer(
+ self,
+ activation: str,
+ last_hidden: Optional[int] = None,
+ input_hidden=None,
+ layer_norm: bool = True,
+ ) -> torch.nn.Sequential:
+ head_components = [
+ torch.nn.Dropout(0.1),
+ torch.nn.Linear(
+ self.transformer_model.config.hidden_size
+ * self.config.use_last_k_layers
+ if input_hidden is None
+ else input_hidden,
+ self.config.linears_hidden_size,
+ ),
+ activation2functions[activation],
+ torch.nn.Dropout(0.1),
+ torch.nn.Linear(
+ self.config.linears_hidden_size,
+ self.config.linears_hidden_size if last_hidden is None else last_hidden,
+ ),
+ ]
+
+ if layer_norm:
+ head_components.append(
+ torch.nn.LayerNorm(
+ self.config.linears_hidden_size
+ if last_hidden is None
+ else last_hidden,
+ self.transformer_model.config.layer_norm_eps,
+ )
+ )
+
+ return torch.nn.Sequential(*head_components)
+
+ def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+ mask = mask.unsqueeze(-1)
+ if next(self.parameters()).dtype == torch.float16:
+ logits = logits * (1 - mask) - 65500 * mask
+ else:
+ logits = logits * (1 - mask) - 1e30 * mask
+ return logits
+
+ def _get_model_features(
+ self,
+ input_ids: torch.Tensor,
+ attention_mask: torch.Tensor,
+ token_type_ids: Optional[torch.Tensor],
+ ):
+ model_input = {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "output_hidden_states": self.config.use_last_k_layers > 1,
+ }
+
+ if token_type_ids is not None:
+ model_input["token_type_ids"] = token_type_ids
+
+ model_output = self.transformer_model(**model_input)
+
+ if self.config.use_last_k_layers > 1:
+ model_features = torch.cat(
+ model_output[1][-self.config.use_last_k_layers :], dim=-1
+ )
+ else:
+ model_features = model_output[0]
+
+ return model_features
+
+ def compute_ned_end_logits(
+ self,
+ start_predictions,
+ start_labels,
+ model_features,
+ prediction_mask,
+ batch_size,
+ ) -> Optional[torch.Tensor]:
+ # todo: maybe when constraining on the spans,
+ # we should not use a prediction_mask for the end tokens.
+ # at least we should not during training imo
+ start_positions = start_labels if self.training else start_predictions
+ start_positions_indices = (
+ torch.arange(start_positions.size(1), device=start_positions.device)
+ .unsqueeze(0)
+ .expand(batch_size, -1)[start_positions > 0]
+ ).to(start_positions.device)
+
+ if len(start_positions_indices) > 0:
+ expanded_features = torch.cat(
+ [
+ model_features[i].unsqueeze(0).expand(x, -1, -1)
+ for i, x in enumerate(torch.sum(start_positions > 0, dim=-1))
+ if x > 0
+ ],
+ dim=0,
+ ).to(start_positions_indices.device)
+
+ expanded_prediction_mask = torch.cat(
+ [
+ prediction_mask[i].unsqueeze(0).expand(x, -1)
+ for i, x in enumerate(torch.sum(start_positions > 0, dim=-1))
+ if x > 0
+ ],
+ dim=0,
+ ).to(expanded_features.device)
+
+            # mask out every token that precedes its start position, i.e. set the mask
+            # to 1 for all indices < start_positions_indices (one range(x) per start)
+ expanded_prediction_mask = torch.stack(
+ [
+ torch.cat(
+ [
+ torch.ones(x, device=expanded_features.device),
+ expanded_prediction_mask[i, x:],
+ ]
+ )
+ for i, x in enumerate(start_positions_indices)
+ if x > 0
+ ],
+ dim=0,
+ ).to(expanded_features.device)
+
+ end_logits = self.ned_end_classifier(
+ hidden_states=expanded_features,
+ start_positions=start_positions_indices,
+ p_mask=expanded_prediction_mask,
+ )
+
+ return end_logits
+
+ return None
+
+ def compute_relation_logits(
+ self,
+ model_entity_features,
+ special_symbols_features,
+ ) -> torch.Tensor:
+ model_subject_features = self.re_subject_projector(model_entity_features)
+ model_object_features = self.re_object_projector(model_entity_features)
+ special_symbols_start_representation = self.re_relation_projector(
+ special_symbols_features
+ )
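+        # the einsum builds a (batch, subject_token, object_token, relation, hidden)
+        # tensor from the three projections; re_classifier then reduces the hidden
+        # dimension to 2 logits per (subject, object, relation) triple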
+ re_logits = torch.einsum(
+ "bse,bde,bfe->bsdfe",
+ model_subject_features,
+ model_object_features,
+ special_symbols_start_representation,
+ )
+ re_logits = self.re_classifier(re_logits)
+
+ return re_logits
+
+ def compute_entity_logits(
+ self,
+ model_entity_features,
+ special_symbols_features,
+ ) -> torch.Tensor:
+ model_ed_features = self.re_entities_projector(model_entity_features)
+ special_symbols_ed_representation = self.re_definition_projector(
+ special_symbols_features
+ )
+ logits = torch.einsum(
+ "bce,bde->bcde",
+ model_ed_features,
+ special_symbols_ed_representation,
+ )
+ logits = self.re_ed_classifier(logits)
+ start_logits = self._mask_logits(
+ logits,
+ (model_entity_features == -100)
+ .all(2)
+ .long()
+ .unsqueeze(2)
+ .repeat(1, 1, torch.sum(model_entity_features, dim=1)[0].item()),
+ )
+
+ return logits
+
+ def compute_loss(self, logits, labels, mask=None):
+ logits = logits.view(-1, logits.shape[-1])
+ labels = labels.view(-1).long()
+ if mask is not None:
+ return self.criterion(logits[mask], labels[mask])
+ return self.criterion(logits, labels)
+
+ def compute_ned_end_loss(self, ned_end_logits, end_labels):
+ if ned_end_logits is None:
+ return 0
+ ned_end_labels = torch.zeros_like(end_labels)
+ ned_end_labels[end_labels == -100] = -100
+ ned_end_labels[end_labels > 0] = 1
+ return self.compute_loss(ned_end_logits, ned_end_labels)
+
+ def compute_ned_type_loss(
+ self,
+ disambiguation_labels,
+ re_ned_entities_logits,
+ ned_type_logits,
+ re_entities_logits,
+ entity_types,
+ ):
+        # note: compute_loss expects (logits, labels)
+        if self.entity_type_loss and self.relation_disambiguation_loss:
+            return self.compute_loss(re_ned_entities_logits, disambiguation_labels)
+        if self.entity_type_loss:
+            return self.compute_loss(
+                ned_type_logits, disambiguation_labels[:, :, :entity_types]
+            )
+        if self.relation_disambiguation_loss:
+            return self.compute_loss(re_entities_logits, disambiguation_labels)
+        return 0
+
+ def compute_relation_loss(self, relation_labels, re_logits):
+ return self.compute_loss(
+ re_logits, relation_labels, relation_labels.view(-1) != -100
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ attention_mask: torch.Tensor,
+ token_type_ids: torch.Tensor,
+ prediction_mask: Optional[torch.Tensor] = None,
+ special_symbols_mask: Optional[torch.Tensor] = None,
+ special_symbols_mask_entities: Optional[torch.Tensor] = None,
+ start_labels: Optional[torch.Tensor] = None,
+ end_labels: Optional[torch.Tensor] = None,
+ disambiguation_labels: Optional[torch.Tensor] = None,
+ relation_labels: Optional[torch.Tensor] = None,
+ is_validation: bool = False,
+ is_prediction: bool = False,
+ *args,
+ **kwargs,
+ ) -> Dict[str, Any]:
+ batch_size = input_ids.shape[0]
+
+ model_features = self._get_model_features(
+ input_ids, attention_mask, token_type_ids
+ )
+
+ # named entity detection
+ if is_prediction and start_labels is not None:
+ ned_start_logits, ned_start_probabilities, ned_start_predictions = (
+ None,
+ None,
+ torch.zeros_like(start_labels),
+ )
+ ned_end_logits, ned_end_probabilities, ned_end_predictions = (
+ None,
+ None,
+ torch.zeros_like(end_labels),
+ )
+
+ ned_start_predictions[start_labels > 0] = 1
+ ned_end_predictions[end_labels > 0] = 1
+ ned_end_predictions = ned_end_predictions[~(end_labels == -100).all(2)]
+ else:
+ # start boundary prediction
+ ned_start_logits = self.ned_start_classifier(model_features)
+ ned_start_logits = self._mask_logits(
+ ned_start_logits, prediction_mask
+ ) # why?
+ ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1)
+ ned_start_predictions = ned_start_probabilities.argmax(dim=-1)
+
+ # end boundary prediction
+ ned_start_labels = (
+ torch.zeros_like(start_labels) if start_labels is not None else None
+ )
+
+            # start_labels store the entity id at each position; for boundary detection we only need a binary 1 at entity starts
+ if ned_start_labels is not None:
+ ned_start_labels[start_labels > 0] = 1
+
+ # compute end logits only if there are any start predictions.
+ # For each start prediction, n end predictions are made
+ ned_end_logits = self.compute_ned_end_logits(
+ ned_start_predictions,
+ ned_start_labels,
+ model_features,
+ prediction_mask,
+ batch_size,
+ )
+ # For each start prediction, n end predictions are made based on
+ # binary classification ie. argmax at each position.
+ ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
+ ned_end_predictions = ned_end_probabilities.argmax(dim=-1)
+ if is_prediction or is_validation:
+ end_preds_count = ned_end_predictions.sum(1)
+ # If there are no end predictions for a start prediction, remove the start prediction
+ ned_start_predictions[ned_start_predictions == 1] = (
+ end_preds_count != 0
+ ).long()
+ ned_end_predictions = ned_end_predictions[end_preds_count != 0]
+
+ if end_labels is not None:
+ end_labels = end_labels[~(end_labels == -100).all(2)]
+
+ start_position, end_position = (
+ (start_labels, end_labels)
+ if (not is_prediction and not is_validation)
+ else (ned_start_predictions, ned_end_predictions)
+ )
+
+ start_counts = (start_position > 0).sum(1)
+ ned_end_predictions = ned_end_predictions.split(start_counts.tolist())
+
+ # We can only predict relations if we have start and end predictions
+ if (end_position > 0).sum() > 0:
+ ends_count = (end_position > 0).sum(1)
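+            # build one row per (start, end) pair by concatenating start-token and
+            # end-token features (2 * hidden), then re-pad them into a
+            # (batch, max_entities, 2 * hidden) tensor for the projectors below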
+ model_subject_features = torch.cat(
+ [
+ torch.repeat_interleave(
+ model_features[start_position > 0], ends_count, dim=0
+ ), # start position features
+ torch.repeat_interleave(model_features, start_counts, dim=0)[
+ end_position > 0
+ ], # end position features
+ ],
+ dim=-1,
+ )
+ ents_count = torch.nn.utils.rnn.pad_sequence(
+ torch.split(ends_count, start_counts.tolist()),
+ batch_first=True,
+ padding_value=0,
+ ).sum(1)
+ model_subject_features = torch.nn.utils.rnn.pad_sequence(
+ torch.split(model_subject_features, ents_count.tolist()),
+ batch_first=True,
+ padding_value=-100,
+ )
+
+ if is_validation or is_prediction:
+ model_subject_features = model_subject_features[:, :30, :]
+
+ # entity disambiguation. Here relation_disambiguation_loss would only be useful to
+ # reduce the number of candidate relations for the next step, but currently unused.
+ if self.entity_type_loss or self.relation_disambiguation_loss:
+ (re_ned_entities_logits) = self.compute_entity_logits(
+ model_subject_features,
+ model_features[
+ special_symbols_mask | special_symbols_mask_entities
+ ].view(batch_size, -1, model_features.shape[-1]),
+ )
+ entity_types = torch.sum(special_symbols_mask_entities, dim=1)[0].item()
+ ned_type_logits = re_ned_entities_logits[:, :, :entity_types]
+ re_entities_logits = re_ned_entities_logits[:, :, entity_types:]
+
+ if self.entity_type_loss:
+ ned_type_probabilities = torch.softmax(ned_type_logits, dim=-1)
+ ned_type_predictions = ned_type_probabilities.argmax(dim=-1)
+ ned_type_predictions = ned_type_predictions.argmax(dim=-1)
+
+ re_entities_probabilities = torch.softmax(re_entities_logits, dim=-1)
+ re_entities_predictions = re_entities_probabilities.argmax(dim=-1)
+ else:
+ (
+ ned_type_logits,
+ ned_type_probabilities,
+ re_entities_logits,
+ re_entities_probabilities,
+ ) = (None, None, None, None)
+ ned_type_predictions, re_entities_predictions = (
+ torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
+ torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
+ )
+
+ # Compute relation logits
+ re_logits = self.compute_relation_logits(
+ model_subject_features,
+ model_features[special_symbols_mask].view(
+ batch_size, -1, model_features.shape[-1]
+ ),
+ )
+
+ re_probabilities = torch.softmax(re_logits, dim=-1)
+            # we use a threshold instead of argmax in case it needs to be tweaked
+ re_predictions = re_probabilities[:, :, :, :, 1] > 0.5
+ # re_predictions = re_probabilities.argmax(dim=-1)
+ re_probabilities = re_probabilities[:, :, :, :, 1]
+
+ else:
+ (
+ ned_type_logits,
+ ned_type_probabilities,
+ re_entities_logits,
+ re_entities_probabilities,
+ ) = (None, None, None, None)
+ ned_type_predictions, re_entities_predictions = (
+ torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
+ torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
+ )
+ re_logits, re_probabilities, re_predictions = (
+ torch.zeros(
+ [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
+ ).to(input_ids.device),
+ torch.zeros(
+ [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
+ ).to(input_ids.device),
+ torch.zeros(
+ [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
+ ).to(input_ids.device),
+ )
+
+ # output build
+ output_dict = dict(
+ batch_size=batch_size,
+ ned_start_logits=ned_start_logits,
+ ned_start_probabilities=ned_start_probabilities,
+ ned_start_predictions=ned_start_predictions,
+ ned_end_logits=ned_end_logits,
+ ned_end_probabilities=ned_end_probabilities,
+ ned_end_predictions=ned_end_predictions,
+ ned_type_logits=ned_type_logits,
+ ned_type_probabilities=ned_type_probabilities,
+ ned_type_predictions=ned_type_predictions,
+ re_entities_logits=re_entities_logits,
+ re_entities_probabilities=re_entities_probabilities,
+ re_entities_predictions=re_entities_predictions,
+ re_logits=re_logits,
+ re_probabilities=re_probabilities,
+ re_predictions=re_predictions,
+ )
+
+ if (
+ start_labels is not None
+ and end_labels is not None
+ and relation_labels is not None
+ ):
+ ned_start_loss = self.compute_loss(ned_start_logits, ned_start_labels)
+ ned_end_loss = self.compute_ned_end_loss(ned_end_logits, end_labels)
+ if self.entity_type_loss or self.relation_disambiguation_loss:
+ ned_type_loss = self.compute_ned_type_loss(
+ disambiguation_labels,
+ re_ned_entities_logits,
+ ned_type_logits,
+ re_entities_logits,
+ entity_types,
+ )
+ relation_loss = self.compute_relation_loss(relation_labels, re_logits)
+ # compute loss. We can skip the relation loss if we are in the first epochs (optional)
+ if self.entity_type_loss or self.relation_disambiguation_loss:
+ output_dict["loss"] = (
+ ned_start_loss + ned_end_loss + relation_loss + ned_type_loss
+ ) / 4
+ output_dict["ned_type_loss"] = ned_type_loss
+ else:
+ output_dict["loss"] = (
+ ned_start_loss + ned_end_loss + relation_loss
+ ) / 3
+
+ output_dict["ned_start_loss"] = ned_start_loss
+ output_dict["ned_end_loss"] = ned_end_loss
+ output_dict["re_loss"] = relation_loss
+
+ return output_dict
diff --git a/relik/reader/pytorch_modules/optim/__init__.py b/relik/reader/pytorch_modules/optim/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..369091133267cfa05240306fbfe5ea3b537d5d9c
--- /dev/null
+++ b/relik/reader/pytorch_modules/optim/__init__.py
@@ -0,0 +1,6 @@
+from relik.reader.pytorch_modules.optim.adamw_with_warmup import (
+ AdamWWithWarmupOptimizer,
+)
+from relik.reader.pytorch_modules.optim.layer_wise_lr_decay import (
+ LayerWiseLRDecayOptimizer,
+)
diff --git a/relik/reader/pytorch_modules/optim/adamw_with_warmup.py b/relik/reader/pytorch_modules/optim/adamw_with_warmup.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfaecc4ca3d1c366f25962db4d0024a5b986fd50
--- /dev/null
+++ b/relik/reader/pytorch_modules/optim/adamw_with_warmup.py
@@ -0,0 +1,66 @@
+from typing import List
+
+import torch
+import transformers
+from torch.optim import AdamW
+
+
+class AdamWWithWarmupOptimizer:
+ def __init__(
+ self,
+ lr: float,
+ warmup_steps: int,
+ total_steps: int,
+ weight_decay: float,
+ no_decay_params: List[str],
+ ):
+ self.lr = lr
+ self.warmup_steps = warmup_steps
+ self.total_steps = total_steps
+ self.weight_decay = weight_decay
+ self.no_decay_params = no_decay_params
+
+ def group_params(self, module: torch.nn.Module) -> list:
+ if self.no_decay_params is not None:
+ optimizer_grouped_parameters = [
+ {
+ "params": [
+ p
+ for n, p in module.named_parameters()
+ if not any(nd in n for nd in self.no_decay_params)
+ ],
+ "weight_decay": self.weight_decay,
+ },
+ {
+ "params": [
+ p
+ for n, p in module.named_parameters()
+ if any(nd in n for nd in self.no_decay_params)
+ ],
+ "weight_decay": 0.0,
+ },
+ ]
+
+ else:
+ optimizer_grouped_parameters = [
+ {"params": module.parameters(), "weight_decay": self.weight_decay}
+ ]
+
+ return optimizer_grouped_parameters
+
+ def __call__(self, module: torch.nn.Module):
+ optimizer_grouped_parameters = self.group_params(module)
+ optimizer = AdamW(
+ optimizer_grouped_parameters, lr=self.lr, weight_decay=self.weight_decay
+ )
+ scheduler = transformers.get_linear_schedule_with_warmup(
+ optimizer, self.warmup_steps, self.total_steps
+ )
+ return {
+ "optimizer": optimizer,
+ "lr_scheduler": {
+ "scheduler": scheduler,
+ "interval": "step",
+ "frequency": 1,
+ },
+ }
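+
+
+# Illustrative usage: the returned dict follows the optimizer/scheduler format
+# expected by PyTorch Lightning's configure_optimizers, e.g.
+#   factory = AdamWWithWarmupOptimizer(lr=1e-5, warmup_steps=500, total_steps=10_000,
+#                                      weight_decay=0.01,
+#                                      no_decay_params=["bias", "LayerNorm.weight"])
+#   opt_and_sched = factory(reader_module)  # reader_module: any torch.nn.Module (hypothetical name)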
diff --git a/relik/reader/pytorch_modules/optim/layer_wise_lr_decay.py b/relik/reader/pytorch_modules/optim/layer_wise_lr_decay.py
new file mode 100644
index 0000000000000000000000000000000000000000..d179096153f356196a921c50083c96b3dcd5f246
--- /dev/null
+++ b/relik/reader/pytorch_modules/optim/layer_wise_lr_decay.py
@@ -0,0 +1,104 @@
+import collections
+from typing import List
+
+import torch
+import transformers
+from torch.optim import AdamW
+
+
+class LayerWiseLRDecayOptimizer:
+ def __init__(
+ self,
+ lr: float,
+ warmup_steps: int,
+ total_steps: int,
+ weight_decay: float,
+ lr_decay: float,
+ no_decay_params: List[str],
+ total_reset: int,
+ ):
+ self.lr = lr
+ self.warmup_steps = warmup_steps
+ self.total_steps = total_steps
+ self.weight_decay = weight_decay
+ self.lr_decay = lr_decay
+ self.no_decay_params = no_decay_params
+ self.total_reset = total_reset
+
+ def group_layers(self, module) -> dict:
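+        # group parameters by transformer layer and key each group by its "inverse
+        # depth": 0 for the task head, growing towards the embeddings, so that
+        # group_params can shrink the learning rate for layers closer to the input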
+ grouped_layers = collections.defaultdict(list)
+ module_named_parameters = list(module.named_parameters())
+ for ln, lp in module_named_parameters:
+ if "embeddings" in ln:
+ grouped_layers["embeddings"].append((ln, lp))
+ elif "encoder.layer" in ln:
+ layer_num = ln.split("transformer_model.encoder.layer.")[-1]
+ layer_num = layer_num[0 : layer_num.index(".")]
+ grouped_layers[layer_num].append((ln, lp))
+ else:
+ grouped_layers["head"].append((ln, lp))
+
+ depth = len(grouped_layers) - 1
+ final_dict = dict()
+ for key, value in grouped_layers.items():
+ if key == "head":
+ final_dict[0] = value
+ elif key == "embeddings":
+ final_dict[depth] = value
+ else:
+ # -1 because layer number starts from zero
+ final_dict[depth - int(key) - 1] = value
+
+ assert len(module_named_parameters) == sum(
+ len(v) for _, v in final_dict.items()
+ )
+
+ return final_dict
+
+ def group_params(self, module) -> list:
+ optimizer_grouped_params = []
+ for inverse_depth, layer in self.group_layers(module).items():
+ layer_lr = self.lr * (self.lr_decay**inverse_depth)
+ layer_wd_params = {
+ "params": [
+ lp
+ for ln, lp in layer
+ if not any(nd in ln for nd in self.no_decay_params)
+ ],
+ "weight_decay": self.weight_decay,
+ "lr": layer_lr,
+ }
+ layer_no_wd_params = {
+ "params": [
+ lp
+ for ln, lp in layer
+ if any(nd in ln for nd in self.no_decay_params)
+ ],
+ "weight_decay": 0,
+ "lr": layer_lr,
+ }
+
+            # only add groups that actually contain parameters
+            if len(layer_wd_params["params"]) != 0:
+                optimizer_grouped_params.append(layer_wd_params)
+            if len(layer_no_wd_params["params"]) != 0:
+                optimizer_grouped_params.append(layer_no_wd_params)
+
+ return optimizer_grouped_params
+
+ def __call__(self, module: torch.nn.Module):
+ optimizer_grouped_parameters = self.group_params(module)
+ optimizer = AdamW(optimizer_grouped_parameters, lr=self.lr)
+ scheduler = transformers.get_cosine_with_hard_restarts_schedule_with_warmup(
+ optimizer,
+ self.warmup_steps,
+ self.total_steps,
+ num_cycles=self.total_reset,
+ )
+ return {
+ "optimizer": optimizer,
+ "lr_scheduler": {
+ "scheduler": scheduler,
+ "interval": "step",
+ "frequency": 1,
+ },
+ }
diff --git a/relik/reader/pytorch_modules/span.py b/relik/reader/pytorch_modules/span.py
new file mode 100644
index 0000000000000000000000000000000000000000..349e42cafc1dfbc583adc46e7c8cf63d1d3752d8
--- /dev/null
+++ b/relik/reader/pytorch_modules/span.py
@@ -0,0 +1,367 @@
+import collections
+import contextlib
+import logging
+from typing import Any, Dict, Iterator, List
+
+import torch
+import transformers as tr
+from lightning_fabric.utilities import move_data_to_device
+from torch.utils.data import DataLoader, IterableDataset
+from tqdm import tqdm
+
+from relik.common.log import get_console_logger, get_logger
+from relik.common.utils import get_callable_from_string
+from relik.reader.data.relik_reader_sample import RelikReaderSample
+from relik.reader.pytorch_modules.base import RelikReaderBase
+from relik.reader.utils.special_symbols import get_special_symbols
+from relik.retriever.pytorch_modules import PRECISION_MAP
+
+console_logger = get_console_logger()
+logger = get_logger(__name__, level=logging.INFO)
+
+
+class RelikReaderForSpanExtraction(RelikReaderBase):
+ """
+ A class for the RelikReader model for span extraction.
+
+ Args:
+ transformer_model (:obj:`str` or :obj:`transformers.PreTrainedModel` or :obj:`None`, `optional`):
+ The transformer model to use. If `None`, the default model is used.
+ additional_special_symbols (:obj:`int`, `optional`, defaults to 0):
+ The number of additional special symbols to add to the tokenizer.
+ num_layers (:obj:`int`, `optional`):
+ The number of layers to use. If `None`, all layers are used.
+ activation (:obj:`str`, `optional`, defaults to "gelu"):
+ The activation function to use.
+ linears_hidden_size (:obj:`int`, `optional`, defaults to 512):
+            The hidden size of the linear projection layers.
+ use_last_k_layers (:obj:`int`, `optional`, defaults to 1):
+ The number of last layers to use.
+ training (:obj:`bool`, `optional`, defaults to False):
+ Whether the model is in training mode.
+ device (:obj:`str` or :obj:`torch.device` or :obj:`None`, `optional`):
+ The device to use. If `None`, the default device is used.
+ tokenizer (:obj:`str` or :obj:`transformers.PreTrainedTokenizer` or :obj:`None`, `optional`):
+ The tokenizer to use. If `None`, the default tokenizer is used.
+ dataset (:obj:`IterableDataset` or :obj:`str` or :obj:`None`, `optional`):
+ The dataset to use. If `None`, the default dataset is used.
+ dataset_kwargs (:obj:`Dict[str, Any]` or :obj:`None`, `optional`):
+ The keyword arguments to pass to the dataset class.
+ default_reader_class (:obj:`str` or :obj:`transformers.PreTrainedModel` or :obj:`None`, `optional`):
+ The default reader class to use. If `None`, the default reader class is used.
+ **kwargs:
+ Keyword arguments.
+ """
+
+ default_reader_class: str = (
+ "relik.reader.pytorch_modules.hf.modeling_relik.RelikReaderSpanModel"
+ )
+ default_data_class: str = "relik.reader.data.relik_reader_data.RelikDataset"
+
+ def __init__(
+ self,
+ transformer_model: str | tr.PreTrainedModel | None = None,
+ additional_special_symbols: int = 0,
+ num_layers: int | None = None,
+ activation: str = "gelu",
+ linears_hidden_size: int | None = 512,
+ use_last_k_layers: int = 1,
+ training: bool = False,
+ device: str | torch.device | None = None,
+ tokenizer: str | tr.PreTrainedTokenizer | None = None,
+ dataset: IterableDataset | str | None = None,
+ dataset_kwargs: Dict[str, Any] | None = None,
+ default_reader_class: tr.PreTrainedModel | str | None = None,
+ **kwargs,
+ ):
+ super().__init__(
+ transformer_model=transformer_model,
+ additional_special_symbols=additional_special_symbols,
+ num_layers=num_layers,
+ activation=activation,
+ linears_hidden_size=linears_hidden_size,
+ use_last_k_layers=use_last_k_layers,
+ training=training,
+ device=device,
+ tokenizer=tokenizer,
+ dataset=dataset,
+ default_reader_class=default_reader_class,
+ **kwargs,
+ )
+ # and instantiate the dataset class
+ self.dataset = dataset
+ if self.dataset is None:
+ default_data_kwargs = dict(
+ dataset_path=None,
+ materialize_samples=False,
+ transformer_model=self.tokenizer,
+ special_symbols=get_special_symbols(
+ self.relik_reader_model.config.additional_special_symbols
+ ),
+ for_inference=True,
+ )
+ # merge the default data kwargs with the ones passed to the model
+ default_data_kwargs.update(dataset_kwargs or {})
+ self.dataset = get_callable_from_string(self.default_data_class)(
+ **default_data_kwargs
+ )
+
+ @torch.no_grad()
+ @torch.inference_mode()
+ def _read(
+ self,
+ samples: List[RelikReaderSample] | None = None,
+ input_ids: torch.Tensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ token_type_ids: torch.Tensor | None = None,
+ prediction_mask: torch.Tensor | None = None,
+ special_symbols_mask: torch.Tensor | None = None,
+ max_length: int = 1000,
+ max_batch_size: int = 128,
+ token_batch_size: int = 2048,
+        precision: str | int = 32,
+ annotation_type: str = "char",
+ progress_bar: bool = False,
+ *args: object,
+ **kwargs: object,
+ ) -> List[RelikReaderSample] | List[List[RelikReaderSample]]:
+ """
+ A wrapper around the forward method that returns the predicted labels for each sample.
+
+ Args:
+ samples (:obj:`List[RelikReaderSample]`, `optional`):
+ The samples to read. If provided, `text` and `candidates` are ignored.
+ input_ids (:obj:`torch.Tensor`, `optional`):
+ The input ids of the text. If `samples` is provided, this is ignored.
+ attention_mask (:obj:`torch.Tensor`, `optional`):
+ The attention mask of the text. If `samples` is provided, this is ignored.
+ token_type_ids (:obj:`torch.Tensor`, `optional`):
+ The token type ids of the text. If `samples` is provided, this is ignored.
+ prediction_mask (:obj:`torch.Tensor`, `optional`):
+ The prediction mask of the text. If `samples` is provided, this is ignored.
+ special_symbols_mask (:obj:`torch.Tensor`, `optional`):
+ The special symbols mask of the text. If `samples` is provided, this is ignored.
+ max_length (:obj:`int`, `optional`, defaults to 1000):
+ The maximum length of the text.
+ max_batch_size (:obj:`int`, `optional`, defaults to 128):
+ The maximum batch size.
+ token_batch_size (:obj:`int`, `optional`):
+ The token batch size.
+ progress_bar (:obj:`bool`, `optional`, defaults to False):
+ Whether to show a progress bar.
+ precision (:obj:`str`, `optional`, defaults to 32):
+ The precision to use for the model.
+ annotation_type (:obj:`str`, `optional`, defaults to "char"):
+ The annotation type to use. It can be either "char", "token" or "word".
+ *args:
+ Positional arguments.
+ **kwargs:
+ Keyword arguments.
+
+ Returns:
+ :obj:`List[RelikReaderSample]` or :obj:`List[List[RelikReaderSample]]`:
+ The predicted labels for each sample.
+ """
+
+ precision = precision or self.precision
+ if samples is not None:
+
+ def _read_iterator():
+ def samples_it():
+ for i, sample in enumerate(samples):
+ assert sample._mixin_prediction_position is None
+ sample._mixin_prediction_position = i
+ yield sample
+
+ next_prediction_position = 0
+ position2predicted_sample = {}
+
+ # instantiate dataset
+ if self.dataset is None:
+ raise ValueError(
+ "You need to pass a dataset to the model in order to predict"
+ )
+ self.dataset.samples = samples_it()
+ self.dataset.model_max_length = max_length
+ self.dataset.tokens_per_batch = token_batch_size
+ self.dataset.max_batch_size = max_batch_size
+
+ # instantiate dataloader
+ iterator = DataLoader(
+ self.dataset, batch_size=None, num_workers=0, shuffle=False
+ )
+ if progress_bar:
+ iterator = tqdm(iterator, desc="Predicting with RelikReader")
+
+            # autocast only accepts bare device-type strings such as 'cpu' or 'cuda',
+            # so we strip any device index from the model device
+            device_type_for_autocast = str(self.device).split(":")[0]
+            # autocast on CPU only supports bfloat16, so we fall back to a null context there
+ autocast_mngr = (
+ contextlib.nullcontext()
+ if device_type_for_autocast == "cpu"
+ else (
+ torch.autocast(
+ device_type=device_type_for_autocast,
+ dtype=PRECISION_MAP[precision],
+ )
+ )
+ )
+
+ with autocast_mngr:
+ for batch in iterator:
+ batch = move_data_to_device(batch, self.device)
+ batch_out = self._batch_predict(**batch)
+
+ for sample in batch_out:
+ if (
+ sample._mixin_prediction_position
+ >= next_prediction_position
+ ):
+ position2predicted_sample[
+ sample._mixin_prediction_position
+ ] = sample
+
+ # yield
+ while next_prediction_position in position2predicted_sample:
+ yield position2predicted_sample[next_prediction_position]
+ del position2predicted_sample[next_prediction_position]
+ next_prediction_position += 1
+
+ outputs = list(_read_iterator())
+ for sample in outputs:
+ self.dataset.merge_patches_predictions(sample)
+ self.dataset.convert_tokens_to_char_annotations(sample)
+
+ else:
+ outputs = list(
+ self._batch_predict(
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ prediction_mask,
+ special_symbols_mask,
+ *args,
+ **kwargs,
+ )
+ )
+ return outputs
+
+ def _batch_predict(
+ self,
+ input_ids: torch.Tensor,
+ attention_mask: torch.Tensor,
+ token_type_ids: torch.Tensor | None = None,
+ prediction_mask: torch.Tensor | None = None,
+ special_symbols_mask: torch.Tensor | None = None,
+ sample: List[RelikReaderSample] | None = None,
+ top_k: int = 5, # the amount of top-k most probable entities to predict
+ *args,
+ **kwargs,
+ ) -> Iterator[RelikReaderSample]:
+ """
+ A wrapper around the forward method that returns the predicted labels for each sample.
+ It also adds the predicted labels to the samples.
+
+ Args:
+ input_ids (:obj:`torch.Tensor`):
+ The input ids of the text.
+ attention_mask (:obj:`torch.Tensor`):
+ The attention mask of the text.
+ token_type_ids (:obj:`torch.Tensor`, `optional`):
+ The token type ids of the text.
+ prediction_mask (:obj:`torch.Tensor`, `optional`):
+ The prediction mask of the text.
+ special_symbols_mask (:obj:`torch.Tensor`, `optional`):
+ The special symbols mask of the text.
+ sample (:obj:`List[RelikReaderSample]`, `optional`):
+ The samples to read. If provided, `text` and `candidates` are ignored.
+ top_k (:obj:`int`, `optional`, defaults to 5):
+ The amount of top-k most probable entities to predict.
+ *args:
+ Positional arguments.
+ **kwargs:
+ Keyword arguments.
+
+ Returns:
+ The predicted labels for each sample.
+ """
+ forward_output = self.forward(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ prediction_mask=prediction_mask,
+ special_symbols_mask=special_symbols_mask,
+ )
+
+ ned_start_predictions = forward_output["ned_start_predictions"].cpu().numpy()
+ ned_end_predictions = forward_output["ned_end_predictions"].cpu().numpy()
+ ed_predictions = forward_output["ed_predictions"].cpu().numpy()
+ ed_probabilities = forward_output["ed_probabilities"].cpu().numpy()
+
+ batch_predictable_candidates = kwargs["predictable_candidates"]
+ patch_offset = kwargs["patch_offset"]
+ for ts, ne_sp, ne_ep, edp, edpr, pred_cands, po in zip(
+ sample,
+ ned_start_predictions,
+ ned_end_predictions,
+ ed_predictions,
+ ed_probabilities,
+ batch_predictable_candidates,
+ patch_offset,
+ ):
+ ne_start_indices = [ti for ti, c in enumerate(ne_sp[1:]) if c > 0]
+ ne_end_indices = [ti for ti, c in enumerate(ne_ep[1:]) if c > 0]
+
+ final_class2predicted_spans = collections.defaultdict(list)
+ spans2predicted_probabilities = dict()
+ for start_token_index, end_token_index in zip(
+ ne_start_indices, ne_end_indices
+ ):
+ # predicted candidate
+ token_class = edp[start_token_index + 1] - 1
+ predicted_candidate_title = pred_cands[token_class]
+ final_class2predicted_spans[predicted_candidate_title].append(
+ [start_token_index, end_token_index]
+ )
+
+ # candidates probabilities
+ classes_probabilities = edpr[start_token_index + 1]
+ classes_probabilities_best_indices = classes_probabilities.argsort()[
+ ::-1
+ ]
+ titles_2_probs = []
+ top_k = (
+ min(
+ top_k,
+ len(classes_probabilities_best_indices),
+ )
+ if top_k != -1
+ else len(classes_probabilities_best_indices)
+ )
+ for i in range(top_k):
+ titles_2_probs.append(
+ (
+ pred_cands[classes_probabilities_best_indices[i] - 1],
+ classes_probabilities[
+ classes_probabilities_best_indices[i]
+ ].item(),
+ )
+ )
+ spans2predicted_probabilities[
+ (start_token_index, end_token_index)
+ ] = titles_2_probs
+
+ if "patches" not in ts._d:
+ ts._d["patches"] = dict()
+
+ ts._d["patches"][po] = dict()
+ sample_patch = ts._d["patches"][po]
+
+ sample_patch["predicted_window_labels"] = final_class2predicted_spans
+ sample_patch["span_title_probabilities"] = spans2predicted_probabilities
+
+ # additional info
+ sample_patch["predictable_candidates"] = pred_cands
+
+ yield ts
diff --git a/relik/reader/relik_reader.py b/relik/reader/relik_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..5acd9e8c4774593c4a61245ecf92a5559ab438f2
--- /dev/null
+++ b/relik/reader/relik_reader.py
@@ -0,0 +1,629 @@
+import collections
+import logging
+from pathlib import Path
+from typing import Any, Callable, Dict, Iterator, List, Union
+
+import torch
+import transformers as tr
+from tqdm import tqdm
+from transformers import AutoConfig
+
+from relik.common.log import get_console_logger, get_logger
+from relik.reader.data.relik_reader_data_utils import batchify, flatten
+from relik.reader.data.relik_reader_sample import RelikReaderSample
+from relik.reader.pytorch_modules.hf.modeling_relik import (
+ RelikReaderConfig,
+ RelikReaderSpanModel,
+)
+from relik.reader.relik_reader_predictor import RelikReaderPredictor
+from relik.reader.utils.save_load_utilities import load_model_and_conf
+from relik.reader.utils.special_symbols import NME_SYMBOL, get_special_symbols
+
+console_logger = get_console_logger()
+logger = get_logger(__name__, level=logging.INFO)
+
+
+class RelikReaderForSpanExtraction(torch.nn.Module):
+ def __init__(
+ self,
+ transformer_model: str | tr.PreTrainedModel | None = None,
+ additional_special_symbols: int = 0,
+ num_layers: int | None = None,
+ activation: str = "gelu",
+ linears_hidden_size: int | None = 512,
+ use_last_k_layers: int = 1,
+ training: bool = False,
+ device: str | torch.device | None = None,
+ tokenizer: str | tr.PreTrainedTokenizer | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__()
+
+ if isinstance(transformer_model, str):
+ config = AutoConfig.from_pretrained(
+ transformer_model, trust_remote_code=True
+ )
+ if "relik-reader" in config.model_type:
+ transformer_model = RelikReaderSpanModel.from_pretrained(
+ transformer_model, **kwargs
+ )
+ else:
+ reader_config = RelikReaderConfig(
+ transformer_model=transformer_model,
+ additional_special_symbols=additional_special_symbols,
+ num_layers=num_layers,
+ activation=activation,
+ linears_hidden_size=linears_hidden_size,
+ use_last_k_layers=use_last_k_layers,
+ training=training,
+ )
+ transformer_model = RelikReaderSpanModel(reader_config)
+
+ self.relik_reader_model = transformer_model
+
+ self._tokenizer = tokenizer
+
+ # move the model to the device
+ self.to(device or torch.device("cpu"))
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ attention_mask: torch.Tensor,
+ token_type_ids: torch.Tensor,
+ prediction_mask: torch.Tensor | None = None,
+ special_symbols_mask: torch.Tensor | None = None,
+ special_symbols_mask_entities: torch.Tensor | None = None,
+ start_labels: torch.Tensor | None = None,
+ end_labels: torch.Tensor | None = None,
+ disambiguation_labels: torch.Tensor | None = None,
+ relation_labels: torch.Tensor | None = None,
+ is_validation: bool = False,
+ is_prediction: bool = False,
+ *args,
+ **kwargs,
+ ) -> Dict[str, Any]:
+ return self.relik_reader_model(
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ prediction_mask,
+ special_symbols_mask,
+ special_symbols_mask_entities,
+ start_labels,
+ end_labels,
+ disambiguation_labels,
+ relation_labels,
+ is_validation,
+ is_prediction,
+ *args,
+ **kwargs,
+ )
+
+ def batch_predict(
+ self,
+ input_ids: torch.Tensor,
+ attention_mask: torch.Tensor,
+ token_type_ids: torch.Tensor | None = None,
+ prediction_mask: torch.Tensor | None = None,
+ special_symbols_mask: torch.Tensor | None = None,
+ sample: List[RelikReaderSample] | None = None,
+ top_k: int = 5, # the amount of top-k most probable entities to predict
+ *args,
+ **kwargs,
+ ) -> Iterator[RelikReaderSample]:
+        """
+        A wrapper around the forward method that returns the predicted labels for
+        each sample and attaches them to the samples themselves.
+
+        Args:
+            input_ids: The input ids of the text.
+            attention_mask: The attention mask of the text.
+            token_type_ids: The token type ids of the text.
+            prediction_mask: The prediction mask of the text.
+            special_symbols_mask: The special symbols mask of the text.
+            sample: The samples to annotate with the predictions.
+            top_k: The amount of top-k most probable entities to predict.
+            *args: Positional arguments.
+            **kwargs: Keyword arguments.
+
+        Returns:
+            An iterator over the input samples enriched with the predictions.
+        """
+ forward_output = self.forward(
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ prediction_mask,
+ special_symbols_mask,
+ )
+
+ ned_start_predictions = forward_output["ned_start_predictions"].cpu().numpy()
+ ned_end_predictions = forward_output["ned_end_predictions"].cpu().numpy()
+ ed_predictions = forward_output["ed_predictions"].cpu().numpy()
+ ed_probabilities = forward_output["ed_probabilities"].cpu().numpy()
+
+ batch_predictable_candidates = kwargs["predictable_candidates"]
+ patch_offset = kwargs["patch_offset"]
+ for ts, ne_sp, ne_ep, edp, edpr, pred_cands, po in zip(
+ sample,
+ ned_start_predictions,
+ ned_end_predictions,
+ ed_predictions,
+ ed_probabilities,
+ batch_predictable_candidates,
+ patch_offset,
+ ):
+ ne_start_indices = [ti for ti, c in enumerate(ne_sp[1:]) if c > 0]
+ ne_end_indices = [ti for ti, c in enumerate(ne_ep[1:]) if c > 0]
+
+ final_class2predicted_spans = collections.defaultdict(list)
+ spans2predicted_probabilities = dict()
+ for start_token_index, end_token_index in zip(
+ ne_start_indices, ne_end_indices
+ ):
+ # predicted candidate
+ token_class = edp[start_token_index + 1] - 1
+ predicted_candidate_title = pred_cands[token_class]
+ final_class2predicted_spans[predicted_candidate_title].append(
+ [start_token_index, end_token_index]
+ )
+
+ # candidates probabilities
+ classes_probabilities = edpr[start_token_index + 1]
+ classes_probabilities_best_indices = classes_probabilities.argsort()[
+ ::-1
+ ]
+ titles_2_probs = []
+ top_k = (
+ min(
+ top_k,
+ len(classes_probabilities_best_indices),
+ )
+ if top_k != -1
+ else len(classes_probabilities_best_indices)
+ )
+ for i in range(top_k):
+ titles_2_probs.append(
+ (
+ pred_cands[classes_probabilities_best_indices[i] - 1],
+ classes_probabilities[
+ classes_probabilities_best_indices[i]
+ ].item(),
+ )
+ )
+ spans2predicted_probabilities[
+ (start_token_index, end_token_index)
+ ] = titles_2_probs
+
+ if "patches" not in ts._d:
+ ts._d["patches"] = dict()
+
+ ts._d["patches"][po] = dict()
+ sample_patch = ts._d["patches"][po]
+
+ sample_patch["predicted_window_labels"] = final_class2predicted_spans
+ sample_patch["span_title_probabilities"] = spans2predicted_probabilities
+
+ # additional info
+ sample_patch["predictable_candidates"] = pred_cands
+
+ yield ts
+
+ def _build_input(self, text: List[str], candidates: List[List[str]]) -> list[str]:
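+        # Build the reader input as:
+        #   [CLS] <text tokens> [SEP] <sym_1> <cand_1> <sym_2> <cand_2> ... [SEP]
+        # where each candidate title is preceded by its dedicated special symbol
+        # and the NME candidate is represented by its symbol alone.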
+ candidates_symbols = get_special_symbols(len(candidates))
+ candidates = [
+ [cs, ct] if ct != NME_SYMBOL else [NME_SYMBOL]
+ for cs, ct in zip(candidates_symbols, candidates)
+ ]
+ return (
+ [self.tokenizer.cls_token]
+ + text
+ + [self.tokenizer.sep_token]
+ + flatten(candidates)
+ + [self.tokenizer.sep_token]
+ )
+
+ @staticmethod
+ def _compute_offsets(offsets_mapping):
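+        # Build token<->word alignments from the tokenizer's offset mapping: a token
+        # whose character offset starts at 0 begins a new word (the text is passed
+        # with is_split_into_words=True), otherwise it continues the previous word.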
+ offsets_mapping = offsets_mapping.numpy()
+ token2word = []
+ word2token = {}
+ count = 0
+ for i, offset in enumerate(offsets_mapping):
+ if offset[0] == 0:
+ token2word.append(i - count)
+ word2token[i - count] = [i]
+ else:
+ token2word.append(token2word[-1])
+ word2token[token2word[-1]].append(i)
+ count += 1
+ return token2word, word2token
+
+ @staticmethod
+ def _convert_tokens_to_word_annotations(sample: RelikReaderSample):
+ triplets = []
+ entities = []
+ for entity in sample.predicted_entities:
+ if sample.entity_candidates:
+ entities.append(
+ (
+ sample.token2word[entity[0] - 1],
+ sample.token2word[entity[1] - 1] + 1,
+ sample.entity_candidates[entity[2]],
+ )
+ )
+ else:
+ entities.append(
+ (
+ sample.token2word[entity[0] - 1],
+ sample.token2word[entity[1] - 1] + 1,
+ -1,
+ )
+ )
+ for predicted_triplet, predicted_triplet_probabilities in zip(
+ sample.predicted_relations, sample.predicted_relations_probabilities
+ ):
+ subject, object_, relation = predicted_triplet
+ subject = entities[subject]
+ object_ = entities[object_]
+ relation = sample.candidates[relation]
+ triplets.append(
+ {
+ "subject": {
+ "start": subject[0],
+ "end": subject[1],
+ "type": subject[2],
+ "name": " ".join(sample.tokens[subject[0] : subject[1]]),
+ },
+ "relation": {
+ "name": relation,
+ "probability": float(predicted_triplet_probabilities.round(2)),
+ },
+ "object": {
+ "start": object_[0],
+ "end": object_[1],
+ "type": object_[2],
+ "name": " ".join(sample.tokens[object_[0] : object_[1]]),
+ },
+ }
+ )
+ sample.predicted_entities = entities
+ sample.predicted_relations = triplets
+ sample.predicted_relations_probabilities = None
+
+ @torch.no_grad()
+ @torch.inference_mode()
+ def read(
+ self,
+ text: List[str] | List[List[str]] | None = None,
+ samples: List[RelikReaderSample] | None = None,
+ input_ids: torch.Tensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ token_type_ids: torch.Tensor | None = None,
+ prediction_mask: torch.Tensor | None = None,
+ special_symbols_mask: torch.Tensor | None = None,
+ special_symbols_mask_entities: torch.Tensor | None = None,
+ candidates: List[List[str]] | None = None,
+ max_length: int | None = 1024,
+ max_batch_size: int | None = 64,
+ token_batch_size: int | None = None,
+ progress_bar: bool = False,
+ *args,
+ **kwargs,
+ ) -> List[List[RelikReaderSample]]:
+ """
+ Reads the given text.
+ Args:
+ text: The text to read in tokens.
+            samples: Pre-built samples to read. If provided, `text` and `candidates` are ignored.
+ input_ids: The input ids of the text.
+ attention_mask: The attention mask of the text.
+ token_type_ids: The token type ids of the text.
+ prediction_mask: The prediction mask of the text.
+ special_symbols_mask: The special symbols mask of the text.
+ special_symbols_mask_entities: The special symbols mask entities of the text.
+ candidates: The candidates of the text.
+ max_length: The maximum length of the text.
+ max_batch_size: The maximum batch size.
+ token_batch_size: The maximum number of tokens per batch.
+            progress_bar: Whether to show a progress bar.
+ Returns:
+ The predicted labels for each sample.
+ """
+ if text is None and input_ids is None and samples is None:
+ raise ValueError(
+ "Either `text` or `input_ids` or `samples` must be provided."
+ )
+ if (input_ids is None and samples is None) and (
+ text is None or candidates is None
+ ):
+ raise ValueError(
+ "`text` and `candidates` must be provided to return the predictions when "
+                "`input_ids` and `samples` are not provided."
+ )
+ if text is not None and samples is None:
+ if len(text) != len(candidates):
+ raise ValueError("`text` and `candidates` must have the same length.")
+ if isinstance(text[0], str): # change to list of text
+ text = [text]
+ candidates = [candidates]
+
+ samples = [
+ RelikReaderSample(tokens=t, candidates=c)
+ for t, c in zip(text, candidates)
+ ]
+
+ if samples is not None:
+ # function that creates a batch from the 'current_batch' list
+ def output_batch() -> Dict[str, Any]:
+ assert (
+ len(
+ set(
+ [
+ len(elem["predictable_candidates"])
+ for elem in current_batch
+ ]
+ )
+ )
+ == 1
+ ), " ".join(
+ map(
+ str,
+ [len(elem["predictable_candidates"]) for elem in current_batch],
+ )
+ )
+
+ batch_dict = dict()
+
+ de_values_by_field = {
+ fn: [de[fn] for de in current_batch if fn in de]
+ for fn in self.fields_batcher
+ }
+
+                # in case a fields batcher is provided but the batch
+                # contains no elements for that field
+ de_values_by_field = {
+ fn: fvs for fn, fvs in de_values_by_field.items() if len(fvs) > 0
+ }
+
+ assert len(set([len(v) for v in de_values_by_field.values()]))
+
+                # todo: maybe we should warn the user about possible
+                # field filtering due to "None" instances
+ de_values_by_field = {
+ fn: fvs
+ for fn, fvs in de_values_by_field.items()
+ if all([fv is not None for fv in fvs])
+ }
+
+ for field_name, field_values in de_values_by_field.items():
+ field_batch = (
+ self.fields_batcher[field_name]([fv[0] for fv in field_values])
+ if self.fields_batcher[field_name] is not None
+ else field_values
+ )
+
+ batch_dict[field_name] = field_batch
+
+ batch_dict = {
+ k: v.to(self.device) if isinstance(v, torch.Tensor) else v
+ for k, v in batch_dict.items()
+ }
+ return batch_dict
+
+ current_batch = []
+ predictions = []
+ current_cand_len = -1
+
+ for sample in tqdm(samples, disable=not progress_bar):
+ sample.candidates = [NME_SYMBOL] + sample.candidates
+ inputs_text = self._build_input(sample.tokens, sample.candidates)
+ model_inputs = self.tokenizer(
+ inputs_text,
+ is_split_into_words=True,
+ add_special_tokens=False,
+ padding=False,
+ truncation=True,
+ max_length=max_length or self.tokenizer.model_max_length,
+ return_offsets_mapping=True,
+ return_tensors="pt",
+ )
+ model_inputs["special_symbols_mask"] = (
+ model_inputs["input_ids"] > self.tokenizer.vocab_size
+ )
+                # token type ids are 0 for the input text and 1 from the first candidate special symbol onward
+ model_inputs["token_type_ids"] = (
+ torch.cumsum(model_inputs["special_symbols_mask"], dim=1) > 0
+ ).long()
+                # prediction_mask (1 = not predictable) is token_type_ids shifted one position to the left
+ model_inputs["prediction_mask"] = model_inputs["token_type_ids"].roll(
+ shifts=-1, dims=1
+ )
+ model_inputs["prediction_mask"][:, -1] = 1
+ model_inputs["prediction_mask"][:, 0] = 1
+
+ assert (
+ len(model_inputs["special_symbols_mask"])
+ == len(model_inputs["prediction_mask"])
+ == len(model_inputs["input_ids"])
+ )
+
+ model_inputs["sample"] = sample
+
+ # compute cand_len using special_symbols_mask
+ model_inputs["predictable_candidates"] = sample.candidates[
+ : model_inputs["special_symbols_mask"].sum().item()
+ ]
+ # cand_len = sum([id_ > self.tokenizer.vocab_size for id_ in model_inputs["input_ids"]])
+ offsets = model_inputs.pop("offset_mapping")
+ offsets = offsets[model_inputs["prediction_mask"] == 0]
+ sample.token2word, sample.word2token = self._compute_offsets(offsets)
+ future_max_len = max(
+ len(model_inputs["input_ids"]),
+ max([len(b["input_ids"]) for b in current_batch], default=0),
+ )
+ future_tokens_per_batch = future_max_len * (len(current_batch) + 1)
+
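+                # flush the current batch when the candidate count changes, when adding
+                # this sample would exceed the token budget, or when max_batch_size is reached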
+ if len(current_batch) > 0 and (
+ (
+ len(model_inputs["predictable_candidates"]) != current_cand_len
+ and current_cand_len != -1
+ )
+ or (
+ isinstance(token_batch_size, int)
+ and future_tokens_per_batch >= token_batch_size
+ )
+ or len(current_batch) == max_batch_size
+ ):
+ batch_inputs = output_batch()
+ current_batch = []
+ predictions.extend(list(self.batch_predict(**batch_inputs)))
+ current_cand_len = len(model_inputs["predictable_candidates"])
+ current_batch.append(model_inputs)
+
+ if current_batch:
+ batch_inputs = output_batch()
+ predictions.extend(list(self.batch_predict(**batch_inputs)))
+ else:
+ predictions = list(
+ self.batch_predict(
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ prediction_mask,
+ special_symbols_mask,
+ special_symbols_mask_entities,
+ *args,
+ **kwargs,
+ )
+ )
+ return predictions
+
+ @property
+ def device(self) -> torch.device:
+ """
+ The device of the model.
+ """
+ return next(self.parameters()).device
+
+ @property
+ def tokenizer(self) -> tr.PreTrainedTokenizer:
+ """
+ The tokenizer.
+ """
+ if self._tokenizer:
+ return self._tokenizer
+
+ self._tokenizer = tr.AutoTokenizer.from_pretrained(
+ self.relik_reader_model.config.name_or_path
+ )
+ return self._tokenizer
+
+ @property
+ def fields_batcher(self) -> Dict[str, Union[None, Callable[[list], Any]]]:
+ fields_batchers = {
+ "input_ids": lambda x: batchify(
+ x, padding_value=self.tokenizer.pad_token_id
+ ),
+ "attention_mask": lambda x: batchify(x, padding_value=0),
+ "token_type_ids": lambda x: batchify(x, padding_value=0),
+ "prediction_mask": lambda x: batchify(x, padding_value=1),
+ "global_attention": lambda x: batchify(x, padding_value=0),
+ "token2word": None,
+ "sample": None,
+ "special_symbols_mask": lambda x: batchify(x, padding_value=False),
+ "special_symbols_mask_entities": lambda x: batchify(x, padding_value=False),
+ }
+ if "roberta" in self.relik_reader_model.config.model_type:
+ del fields_batchers["token_type_ids"]
+
+ return fields_batchers
+
+ def save_pretrained(
+ self,
+ output_dir: str,
+ model_name: str | None = None,
+ push_to_hub: bool = False,
+ **kwargs,
+ ) -> None:
+ """
+ Saves the model to the given path.
+ Args:
+ output_dir: The path to save the model to.
+ model_name: The name of the model.
+ push_to_hub: Whether to push the model to the hub.
+ """
+ # create the output directory
+ output_dir = Path(output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ model_name = model_name or "relik-reader-for-span-extraction"
+
+ logger.info(f"Saving reader to {output_dir / model_name}")
+
+ # save the model
+ self.relik_reader_model.register_for_auto_class()
+ self.relik_reader_model.save_pretrained(
+ output_dir / model_name, push_to_hub=push_to_hub, **kwargs
+ )
+
+ logger.info("Saving reader to disk done.")
+
+ if self.tokenizer:
+ self.tokenizer.save_pretrained(
+ output_dir / model_name, push_to_hub=push_to_hub, **kwargs
+ )
+ logger.info("Saving tokenizer to disk done.")
+
+
+class RelikReader:
+ def __init__(self, model_path: str, predict_nmes: bool = False):
+ model, model_conf = load_model_and_conf(model_path)
+ model.training = False
+ model.eval()
+
+ val_dataset_conf = model_conf.data.val_dataset
+ val_dataset_conf.special_symbols = get_special_symbols(
+ model_conf.model.entities_per_forward
+ )
+ val_dataset_conf.transformer_model = model_conf.model.model.transformer_model
+
+ self.predictor = RelikReaderPredictor(
+ model,
+ dataset_conf=model_conf.data.val_dataset,
+ predict_nmes=predict_nmes,
+ )
+ self.model_path = model_path
+
+ def link_entities(
+ self,
+ dataset_path_or_samples: str | Iterator[RelikReaderSample],
+ token_batch_size: int = 2048,
+ progress_bar: bool = False,
+ ) -> List[RelikReaderSample]:
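+        # the predictor accepts either a dataset path or an in-memory iterable of samples as its first two positional arguments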
+ data_input = (
+ (dataset_path_or_samples, None)
+ if isinstance(dataset_path_or_samples, str)
+ else (None, dataset_path_or_samples)
+ )
+ return self.predictor.predict(
+ *data_input,
+ dataset_conf=None,
+ token_batch_size=token_batch_size,
+ progress_bar=progress_bar,
+ )
+
+ # def save_pretrained(self, path: Union[str, Path]):
+ # self.predictor.save(path)
+
+
+def main():
+ rr = RelikReader("riccorl/relik-reader-aida-deberta-small-old", predict_nmes=True)
+ predictions = rr.link_entities(
+ "/Users/ric/Documents/PhD/Projects/relik/data/reader/aida/testa.jsonl"
+ )
+ print(predictions)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/relik/reader/relik_reader_core.py b/relik/reader/relik_reader_core.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d62c5f13b3c1f7e7ba02209d2c88813d4f960ac
--- /dev/null
+++ b/relik/reader/relik_reader_core.py
@@ -0,0 +1,497 @@
+import collections
+from typing import Any, Dict, Iterator, List, Optional
+
+import torch
+from transformers import AutoModel
+from transformers.activations import ClippedGELUActivation, GELUActivation
+from transformers.modeling_utils import PoolerEndLogits
+
+from relik.reader.data.relik_reader_sample import RelikReaderSample
+
+activation2functions = {
+ "relu": torch.nn.ReLU(),
+ "gelu": GELUActivation(),
+ "gelu_10": ClippedGELUActivation(-10, 10),
+}
+
+
+class RelikReaderCoreModel(torch.nn.Module):
+ def __init__(
+ self,
+ transformer_model: str,
+ additional_special_symbols: int,
+ num_layers: Optional[int] = None,
+ activation: str = "gelu",
+ linears_hidden_size: Optional[int] = 512,
+ use_last_k_layers: int = 1,
+ training: bool = False,
+ ) -> None:
+ super().__init__()
+
+ # Transformer model declaration
+ self.transformer_model_name = transformer_model
+ self.transformer_model = (
+ AutoModel.from_pretrained(transformer_model)
+ if num_layers is None
+ else AutoModel.from_pretrained(
+ transformer_model, num_hidden_layers=num_layers
+ )
+ )
+ self.transformer_model.resize_token_embeddings(
+ self.transformer_model.config.vocab_size + additional_special_symbols
+ )
+
+ self.activation = activation
+ self.linears_hidden_size = linears_hidden_size
+ self.use_last_k_layers = use_last_k_layers
+
+ # named entity detection layers
+ self.ned_start_classifier = self._get_projection_layer(
+ self.activation, last_hidden=2, layer_norm=False
+ )
+ self.ned_end_classifier = PoolerEndLogits(self.transformer_model.config)
+
+        # ED (entity disambiguation) projection layers
+ self.ed_start_projector = self._get_projection_layer(self.activation)
+ self.ed_end_projector = self._get_projection_layer(self.activation)
+
+ self.training = training
+
+ # criterion
+ self.criterion = torch.nn.CrossEntropyLoss()
+
+ def _get_projection_layer(
+ self,
+ activation: str,
+ last_hidden: Optional[int] = None,
+ input_hidden=None,
+ layer_norm: bool = True,
+ ) -> torch.nn.Sequential:
+ head_components = [
+ torch.nn.Dropout(0.1),
+ torch.nn.Linear(
+ self.transformer_model.config.hidden_size * self.use_last_k_layers
+ if input_hidden is None
+ else input_hidden,
+ self.linears_hidden_size,
+ ),
+ activation2functions[activation],
+ torch.nn.Dropout(0.1),
+ torch.nn.Linear(
+ self.linears_hidden_size,
+ self.linears_hidden_size if last_hidden is None else last_hidden,
+ ),
+ ]
+
+ if layer_norm:
+ head_components.append(
+ torch.nn.LayerNorm(
+ self.linears_hidden_size if last_hidden is None else last_hidden,
+ self.transformer_model.config.layer_norm_eps,
+ )
+ )
+
+ return torch.nn.Sequential(*head_components)
+
+ def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+ mask = mask.unsqueeze(-1)
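+        # push masked positions to a large negative value (safe for fp16) so they get ~zero probability after softmax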
+ if next(self.parameters()).dtype == torch.float16:
+ logits = logits * (1 - mask) - 65500 * mask
+ else:
+ logits = logits * (1 - mask) - 1e30 * mask
+ return logits
+
+ def _get_model_features(
+ self,
+ input_ids: torch.Tensor,
+ attention_mask: torch.Tensor,
+ token_type_ids: Optional[torch.Tensor],
+ ):
+ model_input = {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "output_hidden_states": self.use_last_k_layers > 1,
+ }
+
+ if token_type_ids is not None:
+ model_input["token_type_ids"] = token_type_ids
+
+ model_output = self.transformer_model(**model_input)
+
+ if self.use_last_k_layers > 1:
+ model_features = torch.cat(
+ model_output[1][-self.use_last_k_layers :], dim=-1
+ )
+ else:
+ model_features = model_output[0]
+
+ return model_features
+
+ def compute_ned_end_logits(
+ self,
+ start_predictions,
+ start_labels,
+ model_features,
+ prediction_mask,
+ batch_size,
+ ) -> Optional[torch.Tensor]:
+ # todo: maybe when constraining on the spans,
+ # we should not use a prediction_mask for the end tokens.
+ # at least we should not during training imo
+ start_positions = start_labels if self.training else start_predictions
+ start_positions_indices = (
+ torch.arange(start_positions.size(1), device=start_positions.device)
+ .unsqueeze(0)
+ .expand(batch_size, -1)[start_positions > 0]
+ ).to(start_positions.device)
+
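+        # repeat the sequence features and prediction mask once per predicted start
+        # position, so that end logits are computed independently for every start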
+ if len(start_positions_indices) > 0:
+ expanded_features = torch.cat(
+ [
+ model_features[i].unsqueeze(0).expand(x, -1, -1)
+ for i, x in enumerate(torch.sum(start_positions > 0, dim=-1))
+ if x > 0
+ ],
+ dim=0,
+ ).to(start_positions_indices.device)
+
+ expanded_prediction_mask = torch.cat(
+ [
+ prediction_mask[i].unsqueeze(0).expand(x, -1)
+ for i, x in enumerate(torch.sum(start_positions > 0, dim=-1))
+ if x > 0
+ ],
+ dim=0,
+ ).to(expanded_features.device)
+
+ end_logits = self.ned_end_classifier(
+ hidden_states=expanded_features,
+ start_positions=start_positions_indices,
+ p_mask=expanded_prediction_mask,
+ )
+
+ return end_logits
+
+ return None
+
+ def compute_classification_logits(
+ self,
+ model_features,
+ special_symbols_mask,
+ prediction_mask,
+ batch_size,
+ start_positions=None,
+ end_positions=None,
+ ) -> torch.Tensor:
+ if start_positions is None or end_positions is None:
+ start_positions = torch.zeros_like(prediction_mask)
+ end_positions = torch.zeros_like(prediction_mask)
+
+ model_start_features = self.ed_start_projector(model_features)
+ model_end_features = self.ed_end_projector(model_features)
+ model_end_features[start_positions > 0] = model_end_features[end_positions > 0]
+
+ model_ed_features = torch.cat(
+ [model_start_features, model_end_features], dim=-1
+ )
+
+ # computing ed features
+ classes_representations = torch.sum(special_symbols_mask, dim=1)[0].item()
+ special_symbols_representation = model_ed_features[special_symbols_mask].view(
+ batch_size, classes_representations, -1
+ )
+
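+        # score every token representation against each candidate's special-symbol
+        # representation via a batched matrix product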
+ logits = torch.bmm(
+ model_ed_features,
+ torch.permute(special_symbols_representation, (0, 2, 1)),
+ )
+
+ logits = self._mask_logits(logits, prediction_mask)
+
+ return logits
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ attention_mask: torch.Tensor,
+ token_type_ids: Optional[torch.Tensor] = None,
+ prediction_mask: Optional[torch.Tensor] = None,
+ special_symbols_mask: Optional[torch.Tensor] = None,
+ start_labels: Optional[torch.Tensor] = None,
+ end_labels: Optional[torch.Tensor] = None,
+ use_predefined_spans: bool = False,
+ *args,
+ **kwargs,
+ ) -> Dict[str, Any]:
+ batch_size, seq_len = input_ids.shape
+
+ model_features = self._get_model_features(
+ input_ids, attention_mask, token_type_ids
+ )
+
+ # named entity detection if required
+ if use_predefined_spans: # no need to compute spans
+ ned_start_logits, ned_start_probabilities, ned_start_predictions = (
+ None,
+ None,
+ torch.clone(start_labels)
+ if start_labels is not None
+ else torch.zeros_like(input_ids),
+ )
+ ned_end_logits, ned_end_probabilities, ned_end_predictions = (
+ None,
+ None,
+ torch.clone(end_labels)
+ if end_labels is not None
+ else torch.zeros_like(input_ids),
+ )
+
+ ned_start_predictions[ned_start_predictions > 0] = 1
+ ned_end_predictions[ned_end_predictions > 0] = 1
+
+ else: # compute spans
+ # start boundary prediction
+ ned_start_logits = self.ned_start_classifier(model_features)
+ ned_start_logits = self._mask_logits(ned_start_logits, prediction_mask)
+ ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1)
+ ned_start_predictions = ned_start_probabilities.argmax(dim=-1)
+
+ # end boundary prediction
+ ned_start_labels = (
+ torch.zeros_like(start_labels) if start_labels is not None else None
+ )
+
+ if ned_start_labels is not None:
+ ned_start_labels[start_labels == -100] = -100
+ ned_start_labels[start_labels > 0] = 1
+
+ ned_end_logits = self.compute_ned_end_logits(
+ ned_start_predictions,
+ ned_start_labels,
+ model_features,
+ prediction_mask,
+ batch_size,
+ )
+
+ if ned_end_logits is not None:
+ ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
+ ned_end_predictions = torch.argmax(ned_end_probabilities, dim=-1)
+ else:
+ ned_end_logits, ned_end_probabilities = None, None
+ ned_end_predictions = ned_start_predictions.new_zeros(batch_size)
+
+ # flattening end predictions
+ # (flattening can happen only if the
+ # end boundaries were not predicted using the gold labels)
+ if not self.training:
+ flattened_end_predictions = torch.clone(ned_start_predictions)
+ flattened_end_predictions[flattened_end_predictions > 0] = 0
+
+ batch_start_predictions = list()
+ for elem_idx in range(batch_size):
+ batch_start_predictions.append(
+ torch.where(ned_start_predictions[elem_idx] > 0)[0].tolist()
+ )
+
+ # check that the total number of start predictions
+ # is equal to the end predictions
+ total_start_predictions = sum(map(len, batch_start_predictions))
+ total_end_predictions = len(ned_end_predictions)
+ assert (
+ total_start_predictions == 0
+ or total_start_predictions == total_end_predictions
+ ), (
+ f"Total number of start predictions = {total_start_predictions}. "
+ f"Total number of end predictions = {total_end_predictions}"
+ )
+
+ curr_end_pred_num = 0
+ for elem_idx, bsp in enumerate(batch_start_predictions):
+ for sp in bsp:
+ ep = ned_end_predictions[curr_end_pred_num].item()
+ if ep < sp:
+ ep = sp
+
+                        # if this end position was already assigned, drop the span (spans must not overlap)
+ if flattened_end_predictions[elem_idx, ep] == 1:
+ ned_start_predictions[elem_idx, sp] = 0
+ else:
+ flattened_end_predictions[elem_idx, ep] = 1
+
+ curr_end_pred_num += 1
+
+ ned_end_predictions = flattened_end_predictions
+
+ start_position, end_position = (
+ (start_labels, end_labels)
+ if self.training
+ else (ned_start_predictions, ned_end_predictions)
+ )
+
+ # Entity disambiguation
+ ed_logits = self.compute_classification_logits(
+ model_features,
+ special_symbols_mask,
+ prediction_mask,
+ batch_size,
+ start_position,
+ end_position,
+ )
+ ed_probabilities = torch.softmax(ed_logits, dim=-1)
+ ed_predictions = torch.argmax(ed_probabilities, dim=-1)
+
+ # output build
+ output_dict = dict(
+ batch_size=batch_size,
+ ned_start_logits=ned_start_logits,
+ ned_start_probabilities=ned_start_probabilities,
+ ned_start_predictions=ned_start_predictions,
+ ned_end_logits=ned_end_logits,
+ ned_end_probabilities=ned_end_probabilities,
+ ned_end_predictions=ned_end_predictions,
+ ed_logits=ed_logits,
+ ed_probabilities=ed_probabilities,
+ ed_predictions=ed_predictions,
+ )
+
+ # compute loss if labels
+ if start_labels is not None and end_labels is not None and self.training:
+ # named entity detection loss
+
+ # start
+ if ned_start_logits is not None:
+ ned_start_loss = self.criterion(
+ ned_start_logits.view(-1, ned_start_logits.shape[-1]),
+ ned_start_labels.view(-1),
+ )
+ else:
+ ned_start_loss = 0
+
+ # end
+ if ned_end_logits is not None:
+ ned_end_labels = torch.zeros_like(end_labels)
+ ned_end_labels[end_labels == -100] = -100
+ ned_end_labels[end_labels > 0] = 1
+
+ ned_end_loss = self.criterion(
+ ned_end_logits,
+ (
+ torch.arange(
+ ned_end_labels.size(1), device=ned_end_labels.device
+ )
+ .unsqueeze(0)
+ .expand(batch_size, -1)[ned_end_labels > 0]
+ ).to(ned_end_labels.device),
+ )
+
+ else:
+ ned_end_loss = 0
+
+ # entity disambiguation loss
+ start_labels[ned_start_labels != 1] = -100
+ ed_labels = torch.clone(start_labels)
+ ed_labels[end_labels > 0] = end_labels[end_labels > 0]
+ ed_loss = self.criterion(
+ ed_logits.view(-1, ed_logits.shape[-1]),
+ ed_labels.view(-1),
+ )
+
+ output_dict["ned_start_loss"] = ned_start_loss
+ output_dict["ned_end_loss"] = ned_end_loss
+ output_dict["ed_loss"] = ed_loss
+
+ output_dict["loss"] = ned_start_loss + ned_end_loss + ed_loss
+
+ return output_dict
+
+ def batch_predict(
+ self,
+ input_ids: torch.Tensor,
+ attention_mask: torch.Tensor,
+ token_type_ids: Optional[torch.Tensor] = None,
+ prediction_mask: Optional[torch.Tensor] = None,
+ special_symbols_mask: Optional[torch.Tensor] = None,
+ sample: Optional[List[RelikReaderSample]] = None,
+ top_k: int = 5, # the amount of top-k most probable entities to predict
+ *args,
+ **kwargs,
+ ) -> Iterator[RelikReaderSample]:
+ forward_output = self.forward(
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ prediction_mask,
+ special_symbols_mask,
+ )
+
+ ned_start_predictions = forward_output["ned_start_predictions"].cpu().numpy()
+ ned_end_predictions = forward_output["ned_end_predictions"].cpu().numpy()
+ ed_predictions = forward_output["ed_predictions"].cpu().numpy()
+ ed_probabilities = forward_output["ed_probabilities"].cpu().numpy()
+
+ batch_predictable_candidates = kwargs["predictable_candidates"]
+ patch_offset = kwargs["patch_offset"]
+ for ts, ne_sp, ne_ep, edp, edpr, pred_cands, po in zip(
+ sample,
+ ned_start_predictions,
+ ned_end_predictions,
+ ed_predictions,
+ ed_probabilities,
+ batch_predictable_candidates,
+ patch_offset,
+ ):
+ ne_start_indices = [ti for ti, c in enumerate(ne_sp[1:]) if c > 0]
+ ne_end_indices = [ti for ti, c in enumerate(ne_ep[1:]) if c > 0]
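+            # indices are computed over ne_sp[1:] / ne_ep[1:], so +1 is added back below
+            # when indexing the token-level predictions (position 0 is the CLS token)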
+
+ final_class2predicted_spans = collections.defaultdict(list)
+ spans2predicted_probabilities = dict()
+ for start_token_index, end_token_index in zip(
+ ne_start_indices, ne_end_indices
+ ):
+ # predicted candidate
+ token_class = edp[start_token_index + 1] - 1
+ predicted_candidate_title = pred_cands[token_class]
+ final_class2predicted_spans[predicted_candidate_title].append(
+ [start_token_index, end_token_index]
+ )
+
+ # candidates probabilities
+ classes_probabilities = edpr[start_token_index + 1]
+ classes_probabilities_best_indices = classes_probabilities.argsort()[
+ ::-1
+ ]
+ titles_2_probs = []
+ top_k = (
+ min(
+ top_k,
+ len(classes_probabilities_best_indices),
+ )
+ if top_k != -1
+ else len(classes_probabilities_best_indices)
+ )
+ for i in range(top_k):
+ titles_2_probs.append(
+ (
+ pred_cands[classes_probabilities_best_indices[i] - 1],
+ classes_probabilities[
+ classes_probabilities_best_indices[i]
+ ].item(),
+ )
+ )
+ spans2predicted_probabilities[
+ (start_token_index, end_token_index)
+ ] = titles_2_probs
+
+ if "patches" not in ts._d:
+ ts._d["patches"] = dict()
+
+ ts._d["patches"][po] = dict()
+ sample_patch = ts._d["patches"][po]
+
+ sample_patch["predicted_window_labels"] = final_class2predicted_spans
+ sample_patch["span_title_probabilities"] = spans2predicted_probabilities
+
+ # additional info
+ sample_patch["predictable_candidates"] = pred_cands
+
+ yield ts
diff --git a/relik/reader/relik_reader_predictor.py b/relik/reader/relik_reader_predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5635d477d67febce3c937ef1945900f004bb269
--- /dev/null
+++ b/relik/reader/relik_reader_predictor.py
@@ -0,0 +1,168 @@
+import logging
+from typing import Iterable, Iterator, List, Optional
+
+import hydra
+import torch
+from lightning.pytorch.utilities import move_data_to_device
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from relik.reader.data.patches import merge_patches_predictions
+from relik.reader.data.relik_reader_sample import (
+ RelikReaderSample,
+ load_relik_reader_samples,
+)
+from relik.reader.relik_reader_core import RelikReaderCoreModel
+from relik.reader.utils.special_symbols import NME_SYMBOL
+
+logger = logging.getLogger(__name__)
+
+
+def convert_tokens_to_char_annotations(
+ sample: RelikReaderSample, remove_nmes: bool = False
+):
+ char_annotations = set()
+
+ for (
+ predicted_entity,
+ predicted_spans,
+ ) in sample.predicted_window_labels.items():
+ if predicted_entity == NME_SYMBOL and remove_nmes:
+ continue
+
+ for span_start, span_end in predicted_spans:
+ span_start = sample.token2char_start[str(span_start)]
+ span_end = sample.token2char_end[str(span_end)]
+
+ char_annotations.add((span_start, span_end, predicted_entity))
+
+ char_probs_annotations = dict()
+ for (
+ span_start,
+ span_end,
+ ), candidates_probs in sample.span_title_probabilities.items():
+ span_start = sample.token2char_start[str(span_start)]
+ span_end = sample.token2char_end[str(span_end)]
+ char_probs_annotations[(span_start, span_end)] = {
+ title for title, _ in candidates_probs
+ }
+
+ sample.predicted_window_labels_chars = char_annotations
+ sample.probs_window_labels_chars = char_probs_annotations
+
+
+class RelikReaderPredictor:
+ def __init__(
+ self,
+ relik_reader_core: RelikReaderCoreModel,
+ dataset_conf: Optional[dict] = None,
+ predict_nmes: bool = False,
+ ) -> None:
+ self.relik_reader_core = relik_reader_core
+ self.dataset_conf = dataset_conf
+ self.predict_nmes = predict_nmes
+
+ if self.dataset_conf is not None:
+ # instantiate dataset
+ self.dataset = hydra.utils.instantiate(
+ dataset_conf,
+ dataset_path=None,
+ samples=None,
+ )
+
+ def predict(
+ self,
+ path: Optional[str],
+ samples: Optional[Iterable[RelikReaderSample]],
+ dataset_conf: Optional[dict],
+ token_batch_size: int = 1024,
+ progress_bar: bool = False,
+ **kwargs,
+ ) -> List[RelikReaderSample]:
+ annotated_samples = list(
+ self._predict(path, samples, dataset_conf, token_batch_size, progress_bar)
+ )
+ for sample in annotated_samples:
+ merge_patches_predictions(sample)
+ convert_tokens_to_char_annotations(
+ sample, remove_nmes=not self.predict_nmes
+ )
+ return annotated_samples
+
+ def _predict(
+ self,
+ path: Optional[str],
+ samples: Optional[Iterable[RelikReaderSample]],
+ dataset_conf: dict,
+ token_batch_size: int = 1024,
+ progress_bar: bool = False,
+ **kwargs,
+ ) -> Iterator[RelikReaderSample]:
+ assert (
+ path is not None or samples is not None
+ ), "Either predict on a path or on an iterable of samples"
+
+ samples = load_relik_reader_samples(path) if samples is None else samples
+
+ # setup infrastructure to re-yield in order
+ def samples_it():
+ for i, sample in enumerate(samples):
+ assert sample._mixin_prediction_position is None
+ sample._mixin_prediction_position = i
+ yield sample
+
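+        # predictions can come back out of order from the (token-batched) dataloader,
+        # so they are buffered and re-emitted in the original input order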
+ next_prediction_position = 0
+ position2predicted_sample = {}
+
+ # instantiate dataset
+ if getattr(self, "dataset", None) is not None:
+ dataset = self.dataset
+ dataset.samples = samples_it()
+ dataset.tokens_per_batch = token_batch_size
+ else:
+ dataset = hydra.utils.instantiate(
+ dataset_conf,
+ dataset_path=None,
+ samples=samples_it(),
+ tokens_per_batch=token_batch_size,
+ )
+
+ # instantiate dataloader
+ iterator = DataLoader(dataset, batch_size=None, num_workers=0, shuffle=False)
+ if progress_bar:
+ iterator = tqdm(iterator, desc="Predicting")
+
+ model_device = next(self.relik_reader_core.parameters()).device
+
+ with torch.inference_mode():
+ for batch in iterator:
+ # do batch predict
+ with torch.autocast(
+ "cpu" if model_device == torch.device("cpu") else "cuda"
+ ):
+ batch = move_data_to_device(batch, model_device)
+ batch_out = self.relik_reader_core.batch_predict(**batch)
+                    # update the prediction position buffer
+ for sample in batch_out:
+ if sample._mixin_prediction_position >= next_prediction_position:
+ position2predicted_sample[
+ sample._mixin_prediction_position
+ ] = sample
+
+ # yield
+ while next_prediction_position in position2predicted_sample:
+ yield position2predicted_sample[next_prediction_position]
+ del position2predicted_sample[next_prediction_position]
+ next_prediction_position += 1
+
+ if len(position2predicted_sample) > 0:
+ logger.warning(
+ "It seems samples have been discarded in your dataset. "
+ "This means that you WON'T have a prediction for each input sample. "
+ "Prediction order will also be partially disrupted"
+ )
+ for k, v in sorted(position2predicted_sample.items(), key=lambda x: x[0]):
+ yield v
+
+ if progress_bar:
+ iterator.close()
diff --git a/relik/reader/relik_reader_re.py b/relik/reader/relik_reader_re.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1efaa87110863901c522b277d6e594989b21997
--- /dev/null
+++ b/relik/reader/relik_reader_re.py
@@ -0,0 +1,556 @@
+import logging
+from pathlib import Path
+from typing import Any, Callable, Dict, Iterator, List, Optional, Union
+
+import numpy as np
+import torch
+import transformers as tr
+from relik.reader.data.relik_reader_data_utils import batchify, flatten
+from relik.reader.data.relik_reader_sample import RelikReaderSample
+from relik.reader.pytorch_modules.hf.modeling_relik import (
+    RelikReaderConfig,
+    RelikReaderREModel,
+)
+from tqdm import tqdm
+from transformers import AutoConfig
+
+from relik.common.log import get_console_logger, get_logger
+from relik.reader.utils.special_symbols import NME_SYMBOL, get_special_symbols_re
+
+console_logger = get_console_logger()
+logger = get_logger(__name__, level=logging.INFO)
+
+
+class RelikReaderForTripletExtraction(torch.nn.Module):
+ def __init__(
+ self,
+ transformer_model: Optional[Union[str, tr.PreTrainedModel]] = None,
+ additional_special_symbols: Optional[int] = 0,
+ num_layers: Optional[int] = None,
+ activation: str = "gelu",
+ linears_hidden_size: Optional[int] = 512,
+ use_last_k_layers: int = 1,
+ training: bool = False,
+ device: Optional[Union[str, torch.device]] = None,
+ tokenizer: Optional[Union[str, tr.PreTrainedTokenizer]] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__()
+
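+        # a string can be either an already-trained RelikReader checkpoint (loaded
+        # directly) or a plain transformer name wrapped in a fresh RelikReaderConfig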
+ if isinstance(transformer_model, str):
+ config = AutoConfig.from_pretrained(
+ transformer_model, trust_remote_code=True
+ )
+ if "relik_reader" in config.model_type:
+ transformer_model = RelikReaderREModel.from_pretrained(
+ transformer_model, **kwargs
+ )
+ else:
+ reader_config = RelikReaderConfig(
+ transformer_model=transformer_model,
+ additional_special_symbols=additional_special_symbols,
+ num_layers=num_layers,
+ activation=activation,
+ linears_hidden_size=linears_hidden_size,
+ use_last_k_layers=use_last_k_layers,
+ training=training,
+ )
+ transformer_model = RelikReaderREModel(reader_config)
+
+ self.relik_reader_re_model = transformer_model
+
+ self._tokenizer = tokenizer
+
+ # move the model to the device
+ self.to(device or torch.device("cpu"))
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ attention_mask: torch.Tensor,
+ token_type_ids: torch.Tensor,
+ prediction_mask: Optional[torch.Tensor] = None,
+ special_symbols_mask: Optional[torch.Tensor] = None,
+ special_symbols_mask_entities: Optional[torch.Tensor] = None,
+ start_labels: Optional[torch.Tensor] = None,
+ end_labels: Optional[torch.Tensor] = None,
+ disambiguation_labels: Optional[torch.Tensor] = None,
+ relation_labels: Optional[torch.Tensor] = None,
+ is_validation: bool = False,
+ is_prediction: bool = False,
+ *args,
+ **kwargs,
+ ) -> Dict[str, Any]:
+ return self.relik_reader_re_model(
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ prediction_mask,
+ special_symbols_mask,
+ special_symbols_mask_entities,
+ start_labels,
+ end_labels,
+ disambiguation_labels,
+ relation_labels,
+ is_validation,
+ is_prediction,
+ *args,
+ **kwargs,
+ )
+
+ def batch_predict(
+ self,
+ input_ids: torch.Tensor,
+ attention_mask: torch.Tensor,
+ token_type_ids: Optional[torch.Tensor] = None,
+ prediction_mask: Optional[torch.Tensor] = None,
+ special_symbols_mask: Optional[torch.Tensor] = None,
+ special_symbols_mask_entities: Optional[torch.Tensor] = None,
+ sample: Optional[List[RelikReaderSample]] = None,
+ *args,
+ **kwargs,
+ ) -> Iterator[RelikReaderSample]:
+ """
+ Predicts the labels for a batch of samples.
+ Args:
+ input_ids: The input ids of the batch.
+ attention_mask: The attention mask of the batch.
+ token_type_ids: The token type ids of the batch.
+ prediction_mask: The prediction mask of the batch.
+ special_symbols_mask: The special symbols mask of the batch.
+ special_symbols_mask_entities: The special symbols mask entities of the batch.
+ sample: The samples of the batch.
+ Returns:
+ The predicted labels for each sample.
+ """
+ forward_output = self.forward(
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ prediction_mask,
+ special_symbols_mask,
+ special_symbols_mask_entities,
+ is_prediction=True,
+ )
+ ned_start_predictions = forward_output["ned_start_predictions"].cpu().numpy()
+ ned_end_predictions = forward_output["ned_end_predictions"] # .cpu().numpy()
+ ed_predictions = forward_output["re_entities_predictions"].cpu().numpy()
+ ned_type_predictions = forward_output["ned_type_predictions"].cpu().numpy()
+ re_predictions = forward_output["re_predictions"].cpu().numpy()
+ re_probabilities = forward_output["re_probabilities"].detach().cpu().numpy()
+ if sample is None:
+ sample = [RelikReaderSample() for _ in range(len(input_ids))]
+ for ts, ne_st, ne_end, re_pred, re_prob, edp, ne_et in zip(
+ sample,
+ ned_start_predictions,
+ ned_end_predictions,
+ re_predictions,
+ re_probabilities,
+ ed_predictions,
+ ned_type_predictions,
+ ):
+ ne_end = ne_end.cpu().numpy()
+ entities = []
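+            # reconstruct candidate entity spans (start, end[, type]) from the start/end prediction matrices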
+ if self.relik_reader_re_model.entity_type_loss:
+ starts = np.argwhere(ne_st)
+ i = 0
+ for start, end in zip(starts, ne_end):
+ ends = np.argwhere(end)
+ for e in ends:
+ entities.append([start[0], e[0], ne_et[i]])
+ i += 1
+ else:
+ starts = np.argwhere(ne_st)
+ for start, end in zip(starts, ne_end):
+ ends = np.argwhere(end)
+ for e in ends:
+ entities.append([start[0], e[0]])
+
+ edp = edp[: len(entities)]
+ re_pred = re_pred[: len(entities), : len(entities)]
+ re_prob = re_prob[: len(entities), : len(entities)]
+ possible_re = np.argwhere(re_pred)
+ predicted_triplets = []
+ predicted_triplets_prob = []
+ for i, j, r in possible_re:
+ if self.relik_reader_re_model.relation_disambiguation_loss:
+ if not (
+ i != j
+ and edp[i, r] == 1
+ and edp[j, r] == 1
+ and edp[i, 0] == 0
+ and edp[j, 0] == 0
+ ):
+ continue
+ predicted_triplets.append([i, j, r])
+ predicted_triplets_prob.append(re_prob[i, j, r])
+
+ ts._d["predicted_relations"] = predicted_triplets
+ ts._d["predicted_entities"] = entities
+ ts._d["predicted_relations_probabilities"] = predicted_triplets_prob
+ if ts.token2word:
+ self._convert_tokens_to_word_annotations(ts)
+ yield ts
+
+ def _build_input(self, text: List[str], candidates: List[List[str]]) -> List[int]:
+ candidates_symbols = get_special_symbols_re(len(candidates))
+ candidates = [
+ [cs, ct] if ct != NME_SYMBOL else [NME_SYMBOL]
+ for cs, ct in zip(candidates_symbols, candidates)
+ ]
+ return (
+ [self.tokenizer.cls_token]
+ + text
+ + [self.tokenizer.sep_token]
+ + flatten(candidates)
+ + [self.tokenizer.sep_token]
+ )
+
+ @staticmethod
+ def _compute_offsets(offsets_mapping):
+ offsets_mapping = offsets_mapping.numpy()
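+        # with pre-split words, an offset starting at 0 marks the first sub-token of a
+        # word; the remaining sub-tokens are mapped back to that same word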
+ token2word = []
+ word2token = {}
+ count = 0
+ for i, offset in enumerate(offsets_mapping):
+ if offset[0] == 0:
+ token2word.append(i - count)
+ word2token[i - count] = [i]
+ else:
+ token2word.append(token2word[-1])
+ word2token[token2word[-1]].append(i)
+ count += 1
+ return token2word, word2token
+
+ @staticmethod
+ def _convert_tokens_to_word_annotations(sample: RelikReaderSample):
+ triplets = []
+ entities = []
+ for entity in sample.predicted_entities:
+ if sample.entity_candidates:
+ entities.append(
+ (
+ sample.token2word[entity[0] - 1],
+ sample.token2word[entity[1] - 1] + 1,
+ sample.entity_candidates[entity[2]],
+ )
+ )
+ else:
+ entities.append(
+ (
+ sample.token2word[entity[0] - 1],
+ sample.token2word[entity[1] - 1] + 1,
+ -1,
+ )
+ )
+ for predicted_triplet, predicted_triplet_probabilities in zip(
+ sample.predicted_relations, sample.predicted_relations_probabilities
+ ):
+ subject, object_, relation = predicted_triplet
+ subject = entities[subject]
+ object_ = entities[object_]
+ relation = sample.candidates[relation]
+ triplets.append(
+ {
+ "subject": {
+ "start": subject[0],
+ "end": subject[1],
+ "type": subject[2],
+ "name": " ".join(sample.tokens[subject[0] : subject[1]]),
+ },
+ "relation": {
+ "name": relation,
+ "probability": float(predicted_triplet_probabilities.round(2)),
+ },
+ "object": {
+ "start": object_[0],
+ "end": object_[1],
+ "type": object_[2],
+ "name": " ".join(sample.tokens[object_[0] : object_[1]]),
+ },
+ }
+ )
+ sample.predicted_entities = entities
+ sample.predicted_relations = triplets
+ sample.predicted_relations_probabilities = None
+
+ @torch.no_grad()
+ @torch.inference_mode()
+ def read(
+ self,
+ text: Optional[Union[List[str], List[List[str]]]] = None,
+ samples: Optional[List[RelikReaderSample]] = None,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ prediction_mask: Optional[torch.Tensor] = None,
+ special_symbols_mask: Optional[torch.Tensor] = None,
+ special_symbols_mask_entities: Optional[torch.Tensor] = None,
+ candidates: Optional[List[List[str]]] = None,
+ max_length: Optional[int] = 1024,
+ max_batch_size: Optional[int] = 64,
+ token_batch_size: Optional[int] = None,
+ progress_bar: bool = False,
+ *args,
+ **kwargs,
+ ) -> List[List[RelikReaderSample]]:
+ """
+ Reads the given text.
+ Args:
+            text: The text to read in tokens.
+            samples: The samples to read, as pre-built RelikReaderSample objects.
+ input_ids: The input ids of the text.
+ attention_mask: The attention mask of the text.
+ token_type_ids: The token type ids of the text.
+ prediction_mask: The prediction mask of the text.
+ special_symbols_mask: The special symbols mask of the text.
+ special_symbols_mask_entities: The special symbols mask entities of the text.
+ candidates: The candidates of the text.
+ max_length: The maximum length of the text.
+ max_batch_size: The maximum batch size.
+            token_batch_size: The maximum number of tokens per batch.
+            progress_bar: Whether to display a progress bar during prediction.
+ Returns:
+ The predicted labels for each sample.
+ """
+ if text is None and input_ids is None and samples is None:
+ raise ValueError(
+ "Either `text` or `input_ids` or `samples` must be provided."
+ )
+ if (input_ids is None and samples is None) and (
+ text is None or candidates is None
+ ):
+ raise ValueError(
+                "`text` and `candidates` must be provided to return the predictions when `input_ids` and `samples` are not provided."
+ )
+ if text is not None and samples is None:
+ if len(text) != len(candidates):
+ raise ValueError("`text` and `candidates` must have the same length.")
+            if isinstance(text[0], str):  # a single tokenized text: wrap it in a batch of one
+ text = [text]
+ candidates = [candidates]
+
+ samples = [
+ RelikReaderSample(tokens=t, candidates=c)
+ for t, c in zip(text, candidates)
+ ]
+
+ if samples is not None:
+ # function that creates a batch from the 'current_batch' list
+ def output_batch() -> Dict[str, Any]:
+ assert (
+ len(
+ set(
+ [
+ len(elem["predictable_candidates"])
+ for elem in current_batch
+ ]
+ )
+ )
+ == 1
+ ), " ".join(
+ map(
+ str,
+ [len(elem["predictable_candidates"]) for elem in current_batch],
+ )
+ )
+
+ batch_dict = dict()
+
+ de_values_by_field = {
+ fn: [de[fn] for de in current_batch if fn in de]
+ for fn in self.fields_batcher
+ }
+
+ # in case you provide fields batchers but in the batch
+ # there are no elements for that field
+ de_values_by_field = {
+ fn: fvs for fn, fvs in de_values_by_field.items() if len(fvs) > 0
+ }
+
+ assert len(set([len(v) for v in de_values_by_field.values()]))
+
+                # todo: maybe we should warn the user about possible
+                # field filtering due to "None" instances
+ de_values_by_field = {
+ fn: fvs
+ for fn, fvs in de_values_by_field.items()
+ if all([fv is not None for fv in fvs])
+ }
+
+ for field_name, field_values in de_values_by_field.items():
+ field_batch = (
+ self.fields_batcher[field_name]([fv[0] for fv in field_values])
+ if self.fields_batcher[field_name] is not None
+ else field_values
+ )
+
+ batch_dict[field_name] = field_batch
+
+ batch_dict = {
+ k: v.to(self.device) if isinstance(v, torch.Tensor) else v
+ for k, v in batch_dict.items()
+ }
+ return batch_dict
+
+ current_batch = []
+ predictions = []
+ current_cand_len = -1
+
+ for sample in tqdm(samples, disable=not progress_bar):
+ sample.candidates = [NME_SYMBOL] + sample.candidates
+ inputs_text = self._build_input(sample.tokens, sample.candidates)
+ model_inputs = self.tokenizer(
+ inputs_text,
+ is_split_into_words=True,
+ add_special_tokens=False,
+ padding=False,
+ truncation=True,
+ max_length=max_length or self.tokenizer.model_max_length,
+ return_offsets_mapping=True,
+ return_tensors="pt",
+ )
+ model_inputs["special_symbols_mask"] = (
+ model_inputs["input_ids"] > self.tokenizer.vocab_size
+ )
+                # token type ids are 0 for the input text and 1 from the first candidate special symbol onward
+ model_inputs["token_type_ids"] = (
+ torch.cumsum(model_inputs["special_symbols_mask"], dim=1) > 0
+ ).long()
+                # prediction_mask (1 = not predictable) is token_type_ids shifted one position to the left
+ model_inputs["prediction_mask"] = model_inputs["token_type_ids"].roll(
+ shifts=-1, dims=1
+ )
+ model_inputs["prediction_mask"][:, -1] = 1
+ model_inputs["prediction_mask"][:, 0] = 1
+
+ assert (
+ len(model_inputs["special_symbols_mask"])
+ == len(model_inputs["prediction_mask"])
+ == len(model_inputs["input_ids"])
+ )
+
+ model_inputs["sample"] = sample
+
+ # compute cand_len using special_symbols_mask
+ model_inputs["predictable_candidates"] = sample.candidates[
+ : model_inputs["special_symbols_mask"].sum().item()
+ ]
+ # cand_len = sum([id_ > self.tokenizer.vocab_size for id_ in model_inputs["input_ids"]])
+ offsets = model_inputs.pop("offset_mapping")
+ offsets = offsets[model_inputs["prediction_mask"] == 0]
+ sample.token2word, sample.word2token = self._compute_offsets(offsets)
+ future_max_len = max(
+ len(model_inputs["input_ids"]),
+ max([len(b["input_ids"]) for b in current_batch], default=0),
+ )
+ future_tokens_per_batch = future_max_len * (len(current_batch) + 1)
+
+ if len(current_batch) > 0 and (
+ (
+ len(model_inputs["predictable_candidates"]) != current_cand_len
+ and current_cand_len != -1
+ )
+ or (
+ isinstance(token_batch_size, int)
+ and future_tokens_per_batch >= token_batch_size
+ )
+ or len(current_batch) == max_batch_size
+ ):
+ batch_inputs = output_batch()
+ current_batch = []
+ predictions.extend(list(self.batch_predict(**batch_inputs)))
+ current_cand_len = len(model_inputs["predictable_candidates"])
+ current_batch.append(model_inputs)
+
+ if current_batch:
+ batch_inputs = output_batch()
+ predictions.extend(list(self.batch_predict(**batch_inputs)))
+ else:
+ predictions = list(
+ self.batch_predict(
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ prediction_mask,
+ special_symbols_mask,
+ special_symbols_mask_entities,
+ *args,
+ **kwargs,
+ )
+ )
+ return predictions
+
+ @property
+ def device(self) -> torch.device:
+ """
+ The device of the model.
+ """
+ return next(self.parameters()).device
+
+ @property
+ def tokenizer(self) -> tr.PreTrainedTokenizer:
+ """
+ The tokenizer.
+ """
+ if self._tokenizer:
+ return self._tokenizer
+
+ self._tokenizer = tr.AutoTokenizer.from_pretrained(
+ self.relik_reader_re_model.config.name_or_path
+ )
+ return self._tokenizer
+
+ @property
+ def fields_batcher(self) -> Dict[str, Union[None, Callable[[list], Any]]]:
+ fields_batchers = {
+ "input_ids": lambda x: batchify(
+ x, padding_value=self.tokenizer.pad_token_id
+ ),
+ "attention_mask": lambda x: batchify(x, padding_value=0),
+ "token_type_ids": lambda x: batchify(x, padding_value=0),
+ "prediction_mask": lambda x: batchify(x, padding_value=1),
+ "global_attention": lambda x: batchify(x, padding_value=0),
+ "token2word": None,
+ "sample": None,
+ "special_symbols_mask": lambda x: batchify(x, padding_value=False),
+ "special_symbols_mask_entities": lambda x: batchify(x, padding_value=False),
+ }
+ if "roberta" in self.relik_reader_re_model.config.model_type:
+ del fields_batchers["token_type_ids"]
+
+ return fields_batchers
+
+ def save_pretrained(
+ self,
+ output_dir: str,
+ model_name: Optional[str] = None,
+ push_to_hub: bool = False,
+ **kwargs,
+ ) -> None:
+ """
+ Saves the model to the given path.
+ Args:
+ output_dir: The path to save the model to.
+ model_name: The name of the model.
+ push_to_hub: Whether to push the model to the hub.
+ """
+ # create the output directory
+ output_dir = Path(output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ model_name = model_name or "relik_reader_for_triplet_extraction"
+
+ logger.info(f"Saving reader to {output_dir / model_name}")
+
+ # save the model
+ self.relik_reader_re_model.register_for_auto_class()
+ self.relik_reader_re_model.save_pretrained(
+ output_dir / model_name, push_to_hub=push_to_hub, **kwargs
+ )
+
+ logger.info("Saving reader to disk done.")
+
+ if self.tokenizer:
+ self.tokenizer.save_pretrained(
+ output_dir / model_name, push_to_hub=push_to_hub, **kwargs
+ )
+ logger.info("Saving tokenizer to disk done.")
diff --git a/relik/reader/relik_reader_re_data.py b/relik/reader/relik_reader_re_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bc43b78eeba4e63447d8bf333dacf2a993716fd
--- /dev/null
+++ b/relik/reader/relik_reader_re_data.py
@@ -0,0 +1,849 @@
+import logging
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Generator,
+ Iterator,
+ List,
+ NamedTuple,
+ Optional,
+ Tuple,
+ Union,
+)
+
+import numpy as np
+import torch
+from relik.reader.data.relik_reader_data_utils import (
+ add_noise_to_value,
+ batchify,
+ batchify_matrices,
+ batchify_tensor,
+ chunks,
+ flatten,
+)
+from relik.reader.data.relik_reader_sample import RelikReaderSample, load_relik_reader_samples
+from torch.utils.data import IterableDataset
+from transformers import AutoTokenizer
+
+from relik.reader.utils.special_symbols import NME_SYMBOL
+
+logger = logging.getLogger(__name__)
+
+
+class TokenizationOutput(NamedTuple):
+ input_ids: torch.Tensor
+ attention_mask: torch.Tensor
+ token_type_ids: torch.Tensor
+ prediction_mask: torch.Tensor
+ special_symbols_mask: torch.Tensor
+ special_symbols_mask_entities: torch.Tensor
+
+
+class RelikREDataset(IterableDataset):
+ def __init__(
+ self,
+ dataset_path: str,
+ materialize_samples: bool,
+ transformer_model: str,
+ special_symbols: List[str],
+ shuffle_candidates: Optional[Union[bool, float]],
+ flip_candidates: Optional[Union[bool, float]],
+ relations_definitions: Union[str, Dict[str, str]],
+ for_inference: bool,
+ entities_definitions: Optional[Union[str, Dict[str, str]]] = None,
+ special_symbols_entities: Optional[List[str]] = None,
+ noise_param: float = 0.1,
+ sorting_fields: Optional[str] = None,
+ tokens_per_batch: int = 2048,
+ batch_size: int = None,
+ max_batch_size: int = 128,
+ section_size: int = 50_000,
+ prebatch: bool = True,
+ max_candidates: int = 0,
+ add_gold_candidates: bool = True,
+ use_nme: bool = True,
+ min_length: int = 5,
+ max_length: int = 2048,
+ model_max_length: int = 1000,
+ skip_empty_training_samples: bool = True,
+ drop_last: bool = False,
+ samples: Optional[Iterator[RelikReaderSample]] = None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.dataset_path = dataset_path
+ self.materialize_samples = materialize_samples
+ self.samples: Optional[List[RelikReaderSample]] = None
+ if self.materialize_samples:
+ self.samples = list()
+
+ self.tokenizer = self._build_tokenizer(transformer_model, special_symbols)
+ self.special_symbols = special_symbols
+ self.special_symbols_entities = special_symbols_entities
+ self.shuffle_candidates = shuffle_candidates
+ self.flip_candidates = flip_candidates
+ self.for_inference = for_inference
+ self.noise_param = noise_param
+ self.batching_fields = ["input_ids"]
+ self.sorting_fields = (
+ sorting_fields if sorting_fields is not None else self.batching_fields
+ )
+
+ # open relations definitions file if needed
+        if isinstance(relations_definitions, str):
+            with open(relations_definitions) as definitions_file:
+                relations_definitions = {
+                    line.split("\t")[0]: line.split("\t")[1]
+                    for line in definitions_file
+                }
+ self.max_candidates = max_candidates
+ self.relations_definitions = relations_definitions
+ self.entities_definitions = entities_definitions
+
+ self.add_gold_candidates = add_gold_candidates
+ self.use_nme = use_nme
+ self.min_length = min_length
+ self.max_length = max_length
+ self.model_max_length = (
+ model_max_length
+ if model_max_length < self.tokenizer.model_max_length
+ else self.tokenizer.model_max_length
+ )
+ self.transformer_model = transformer_model
+ self.skip_empty_training_samples = skip_empty_training_samples
+ self.drop_last = drop_last
+ self.samples = samples
+
+ self.tokens_per_batch = tokens_per_batch
+ self.batch_size = batch_size
+ self.max_batch_size = max_batch_size
+ self.section_size = section_size
+ self.prebatch = prebatch
+
+ def _build_tokenizer(self, transformer_model: str, special_symbols: List[str]):
+ return AutoTokenizer.from_pretrained(
+ transformer_model,
+ additional_special_tokens=[ss for ss in special_symbols],
+ add_prefix_space=True,
+ )
+
+ @property
+ def fields_batcher(self) -> Dict[str, Union[None, Callable[[list], Any]]]:
+ fields_batchers = {
+ "input_ids": lambda x: batchify(
+ x, padding_value=self.tokenizer.pad_token_id
+ ),
+ "attention_mask": lambda x: batchify(x, padding_value=0),
+ "token_type_ids": lambda x: batchify(x, padding_value=0),
+ "prediction_mask": lambda x: batchify(x, padding_value=1),
+ "global_attention": lambda x: batchify(x, padding_value=0),
+ "token2word": None,
+ "sample": None,
+ "special_symbols_mask": lambda x: batchify(x, padding_value=False),
+ "special_symbols_mask_entities": lambda x: batchify(x, padding_value=False),
+ "start_labels": lambda x: batchify(x, padding_value=-100),
+ "end_labels": lambda x: batchify_matrices(x, padding_value=-100),
+ "disambiguation_labels": lambda x: batchify(x, padding_value=-100),
+ "relation_labels": lambda x: batchify_tensor(x, padding_value=-100),
+ "predictable_candidates": None,
+ }
+ if "roberta" in self.transformer_model:
+ del fields_batchers["token_type_ids"]
+
+ return fields_batchers
+
+ def _build_input_ids(
+ self, sentence_input_ids: List[int], candidates_input_ids: List[List[int]]
+ ) -> List[int]:
+ return (
+ [self.tokenizer.cls_token_id]
+ + sentence_input_ids
+ + [self.tokenizer.sep_token_id]
+ + flatten(candidates_input_ids)
+ + [self.tokenizer.sep_token_id]
+ )
+
+ def _get_special_symbols_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
+ special_symbols_mask = input_ids >= (
+ len(self.tokenizer)
+ - len(self.special_symbols + self.special_symbols_entities)
+ )
+ special_symbols_mask[0] = True
+ return special_symbols_mask
+
+ def _build_tokenizer_essentials(
+ self, input_ids, original_sequence
+ ) -> TokenizationOutput:
+ input_ids = torch.tensor(input_ids, dtype=torch.long)
+ attention_mask = torch.ones_like(input_ids)
+
+ total_sequence_len = len(input_ids)
+ predictable_sentence_len = len(original_sequence)
+
+ # token type ids
+ token_type_ids = torch.cat(
+ [
+ input_ids.new_zeros(
+ predictable_sentence_len + 2
+ ), # original sentence bpes + CLS and SEP
+ input_ids.new_ones(total_sequence_len - predictable_sentence_len - 2),
+ ]
+ )
+
+ # prediction mask -> boolean on tokens that are predictable
+
+ prediction_mask = torch.tensor(
+ [1]
+ + ([0] * predictable_sentence_len)
+ + ([1] * (total_sequence_len - predictable_sentence_len - 1))
+ )
+
+ assert len(prediction_mask) == len(input_ids)
+
+ # special symbols mask
+ special_symbols_mask = input_ids >= (
+ len(self.tokenizer)
+ - len(self.special_symbols) # + self.special_symbols_entities)
+ )
+ if self.entities_definitions is not None:
+ # select only the first N true values where N is len(entities_definitions)
+ special_symbols_mask_entities = special_symbols_mask.clone()
+ special_symbols_mask_entities[
+ special_symbols_mask_entities.cumsum(0) > len(self.entities_definitions)
+ ] = False
+ special_symbols_mask = special_symbols_mask ^ special_symbols_mask_entities
+ else:
+ special_symbols_mask_entities = special_symbols_mask.clone()
+
+ return TokenizationOutput(
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ prediction_mask,
+ special_symbols_mask,
+ special_symbols_mask_entities,
+ )
+
+ def _build_labels(
+ self,
+ sample,
+ tokenization_output: TokenizationOutput,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ start_labels = [0] * len(tokenization_output.input_ids)
+ end_labels = []
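+        # end_labels holds one full-length row per distinct start sub-token; the
+        # (1-based) entity index is written at the end sub-token of each span sharing that start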
+
+ sample.entities.sort(key=lambda x: (x[0], x[1]))
+
+ prev_start_bpe = -1
+ num_repeat_start = 0
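+        # disambiguation_labels: one row per gold entity; columns are the entity-type
+        # candidates (when provided) followed by the relation candidates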
+ if self.entities_definitions:
+ sample.entities = [(ce[0], ce[1], ce[2]) for ce in sample.entities]
+ sample.entity_candidates = list(self.entities_definitions.keys())
+ disambiguation_labels = torch.zeros(
+ len(sample.entities),
+ len(sample.entity_candidates) + len(sample.candidates),
+ )
+ else:
+ sample.entities = [(ce[0], ce[1], "") for ce in sample.entities]
+ disambiguation_labels = torch.zeros(
+ len(sample.entities), len(sample.candidates)
+ )
+ ignored_labels_indices = tokenization_output.prediction_mask == 1
+ for idx, c_ent in enumerate(sample.entities):
+ start_bpe = sample.word2token[c_ent[0]][0] + 1
+ end_bpe = sample.word2token[c_ent[1] - 1][-1] + 1
+ class_index = idx
+ start_labels[start_bpe] = class_index + 1 # +1 for the NONE class
+ if start_bpe != prev_start_bpe:
+ end_labels.append([0] * len(tokenization_output.input_ids))
+ # end_labels[-1][:start_bpe] = [-100] * start_bpe
+ end_labels[-1][end_bpe] = class_index + 1
+ else:
+ end_labels[-1][end_bpe] = class_index + 1
+ num_repeat_start += 1
+ if self.entities_definitions:
+ entity_type_idx = sample.entity_candidates.index(c_ent[2])
+ disambiguation_labels[idx, entity_type_idx] = 1
+ prev_start_bpe = start_bpe
+
+ start_labels = torch.tensor(start_labels, dtype=torch.long)
+ start_labels[ignored_labels_indices] = -100
+
+ end_labels = torch.tensor(end_labels, dtype=torch.long)
+ end_labels[ignored_labels_indices.repeat(len(end_labels), 1)] = -100
+
+ relation_labels = torch.zeros(
+ len(sample.entities), len(sample.entities), len(sample.candidates)
+ )
+
+ # sample.relations = []
+ for re in sample.triplets:
+ if re["relation"]["name"] not in sample.candidates:
+ re_class_index = len(sample.candidates) - 1
+ else:
+ re_class_index = sample.candidates.index(
+ re["relation"]["name"]
+ ) # should remove this +1
+ if self.entities_definitions:
+ subject_class_index = sample.entities.index(
+ (
+ re["subject"]["start"],
+ re["subject"]["end"],
+ re["subject"]["type"],
+ )
+ )
+ object_class_index = sample.entities.index(
+ (re["object"]["start"], re["object"]["end"], re["object"]["type"])
+ )
+ else:
+ subject_class_index = sample.entities.index(
+ (re["subject"]["start"], re["subject"]["end"], "")
+ )
+ object_class_index = sample.entities.index(
+ (re["object"]["start"], re["object"]["end"], "")
+ )
+
+ relation_labels[subject_class_index, object_class_index, re_class_index] = 1
+
+ if self.entities_definitions:
+ disambiguation_labels[
+ subject_class_index, re_class_index + len(sample.entity_candidates)
+ ] = 1
+ disambiguation_labels[
+ object_class_index, re_class_index + len(sample.entity_candidates)
+ ] = 1
+ # sample.relations.append([re['subject']['start'], re['subject']['end'], re['subject']['type'], re['relation']['name'], re['object']['start'], re['object']['end'], re['object']['type']])
+ else:
+ disambiguation_labels[subject_class_index, re_class_index] = 1
+ disambiguation_labels[object_class_index, re_class_index] = 1
+ # sample.relations.append([re['subject']['start'], re['subject']['end'], "", re['relation']['name'], re['object']['start'], re['object']['end'], ""])
+ return start_labels, end_labels, disambiguation_labels, relation_labels
+
+ def __iter__(self):
+ dataset_iterator = self.dataset_iterator_func()
+ current_dataset_elements = []
+ i = None
+ for i, dataset_elem in enumerate(dataset_iterator, start=1):
+ if (
+ self.section_size is not None
+ and len(current_dataset_elements) == self.section_size
+ ):
+ for batch in self.materialize_batches(current_dataset_elements):
+ yield batch
+ current_dataset_elements = []
+ current_dataset_elements.append(dataset_elem)
+ if i % 50_000 == 0:
+                logger.info(f"Processed {i} elements")
+ if len(current_dataset_elements) != 0:
+ for batch in self.materialize_batches(current_dataset_elements):
+ yield batch
+ if i is not None:
+            logger.info(f"Dataset finished: processed {i} elements")
+ else:
+ logger.warning("Dataset empty")
+
+ def dataset_iterator_func(self):
+ data_samples = (
+ load_relik_reader_samples(self.dataset_path)
+ if self.samples is None
+ else self.samples
+ )
+ for sample in data_samples:
+ # input sentence tokenization
+ input_tokenized = self.tokenizer(
+ sample.tokens,
+ return_offsets_mapping=True,
+ add_special_tokens=False,
+ is_split_into_words=True,
+ )
+ input_subwords = input_tokenized["input_ids"]
+ offsets = input_tokenized["offset_mapping"]
+ token2word = []
+ word2token = {}
+ count = 0
+ for i, offset in enumerate(offsets):
+ if offset[0] == 0:
+ token2word.append(i - count)
+ word2token[i - count] = [i]
+ else:
+ token2word.append(token2word[-1])
+ word2token[token2word[-1]].append(i)
+ count += 1
+ sample.token2word = token2word
+ sample.word2token = word2token
+ # input_subwords = sample.tokens[1:-1] # removing special tokens
+ candidates_symbols = self.special_symbols
+
+ if self.max_candidates > 0:
+ # truncate candidates
+ sample.candidates = sample.candidates[: self.max_candidates]
+
+ # add NME as a possible candidate
+ if self.use_nme:
+ sample.candidates.insert(0, NME_SYMBOL)
+
+ # training time sample mods
+ if not self.for_inference:
+ # check whether the sample has labels if not skip
+ if (
+ sample.triplets is None or len(sample.triplets) == 0
+ ) and self.skip_empty_training_samples:
+ logger.warning(
+ "Sample {} has no labels, skipping".format(sample.sample_id)
+ )
+ continue
+
+ # add gold candidates if missing
+ if self.add_gold_candidates:
+ candidates_set = set(sample.candidates)
+ candidates_to_add = []
+ for candidate_title in sample.triplets:
+ if candidate_title["relation"]["name"] not in candidates_set:
+ candidates_to_add.append(
+ candidate_title["relation"]["name"]
+ )
+ if len(candidates_to_add) > 0:
+ # replacing last candidates with the gold ones
+ # this is done in order to preserve the ordering
+ added_gold_candidates = 0
+ gold_candidates_titles_set = set(
+ set(ct["relation"]["name"] for ct in sample.triplets)
+ )
+ for i in reversed(range(len(sample.candidates))):
+ if (
+ sample.candidates[i] not in gold_candidates_titles_set
+ and sample.candidates[i] != NME_SYMBOL
+ ):
+ sample.candidates[i] = candidates_to_add[
+ added_gold_candidates
+ ]
+ added_gold_candidates += 1
+ if len(candidates_to_add) == added_gold_candidates:
+ break
+
+ candidates_still_to_add = (
+ len(candidates_to_add) - added_gold_candidates
+ )
+ while (
+ len(sample.candidates) <= len(candidates_symbols)
+ and candidates_still_to_add != 0
+ ):
+ sample.candidates.append(
+ candidates_to_add[added_gold_candidates]
+ )
+ added_gold_candidates += 1
+ candidates_still_to_add -= 1
+
+ # shuffle candidates
+ if (
+ isinstance(self.shuffle_candidates, bool)
+ and self.shuffle_candidates
+ ) or (
+ isinstance(self.shuffle_candidates, float)
+ and np.random.uniform() < self.shuffle_candidates
+ ):
+ np.random.shuffle(sample.candidates)
+ if NME_SYMBOL in sample.candidates:
+ sample.candidates.remove(NME_SYMBOL)
+ sample.candidates.insert(0, NME_SYMBOL)
+
+ # flip candidates
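+                # (adjacent candidates are swapped with probability 0.5 to perturb their order)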
+ if (
+ isinstance(self.flip_candidates, bool) and self.flip_candidates
+ ) or (
+ isinstance(self.flip_candidates, float)
+ and np.random.uniform() < self.flip_candidates
+ ):
+ for i in range(len(sample.candidates) - 1):
+ if np.random.uniform() < 0.5:
+ sample.candidates[i], sample.candidates[i + 1] = (
+ sample.candidates[i + 1],
+ sample.candidates[i],
+ )
+ if NME_SYMBOL in sample.candidates:
+ sample.candidates.remove(NME_SYMBOL)
+ sample.candidates.insert(0, NME_SYMBOL)
+
+ # candidates encoding
+ candidates_symbols = candidates_symbols[: len(sample.candidates)]
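+ # Illustrative example (hypothetical candidates): with special symbols
+ # ["[R-0]", "[R-1]"] and candidates ["place of birth", "founders"], the
+ # list built below pairs each symbol with the textual definition of its
+ # candidate relation, e.g. ["[R-0] <definition of place of birth>",
+ # "[R-1] <definition of founders>"].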
+ relations_defs = [
+ "{} {}".format(cs, self.relations_definitions[ct])
+ if ct != NME_SYMBOL
+ else NME_SYMBOL
+ for cs, ct in zip(candidates_symbols, sample.candidates)
+ ]
+ if self.entities_definitions is not None:
+ candidates_entities_symbols = list(self.special_symbols_entities)
+ candidates_entities_symbols = candidates_entities_symbols[
+ : len(self.entities_definitions)
+ ]
+ entity_defs = [
+ "{} {}".format(cs, self.entities_definitions[ct])
+ for cs, ct in zip(
+ candidates_entities_symbols, self.entities_definitions.keys()
+ )
+ ]
+ relations_defs = (
+ entity_defs + [self.tokenizer.sep_token] + relations_defs
+ )
+
+ candidates_encoding_result = self.tokenizer.batch_encode_plus(
+ relations_defs,
+ add_special_tokens=False,
+ ).input_ids
+
+ # drop candidates if the number of input tokens is too long for the model
+ if (
+ sum(map(len, candidates_encoding_result))
+ + len(input_subwords)
+ + 20 # + 20 special tokens
+ > self.model_max_length
+ ):
+ if self.for_inference:
+ acceptable_tokens_from_candidates = (
+ self.model_max_length - 20 - len(input_subwords)
+ )
+ # greedily keep candidates for as long as they fit in the budget
+ i = 0
+ cum_len = 0
+ while (
+ cum_len + len(candidates_encoding_result[i])
+ < acceptable_tokens_from_candidates
+ ):
+ cum_len += len(candidates_encoding_result[i])
+ i += 1
+
+ candidates_encoding_result = candidates_encoding_result[:i]
+ if self.entities_definitions is not None:
+ candidates_symbols = candidates_symbols[
+ : i - len(self.entities_definitions)
+ ]
+ sample.candidates = sample.candidates[
+ : i - len(self.entities_definitions)
+ ]
+ else:
+ candidates_symbols = candidates_symbols[:i]
+ sample.candidates = sample.candidates[:i]
+
+ else:
+ gold_candidates_set = set(
+ [wl["relation"]["name"] for wl in sample.triplets]
+ )
+ gold_candidates_indices = [
+ i
+ for i, wc in enumerate(sample.candidates)
+ if wc in gold_candidates_set
+ ]
+ if self.entities_definitions is not None:
+ gold_candidates_indices = [
+ i + len(self.entities_definitions)
+ for i in gold_candidates_indices
+ ]
+ # add entities indices
+ gold_candidates_indices = gold_candidates_indices + list(
+ range(len(self.entities_definitions))
+ )
+ necessary_taken_tokens = sum(
+ map(
+ len,
+ [
+ candidates_encoding_result[i]
+ for i in gold_candidates_indices
+ ],
+ )
+ )
+
+ acceptable_tokens_from_candidates = (
+ self.model_max_length
+ - 20
+ - len(input_subwords)
+ - necessary_taken_tokens
+ )
+
+ assert acceptable_tokens_from_candidates > 0
+
+ i = 0
+ cum_len = 0
+ while (
+ cum_len + len(candidates_encoding_result[i])
+ < acceptable_tokens_from_candidates
+ ):
+ if i not in gold_candidates_indices:
+ cum_len += len(candidates_encoding_result[i])
+ i += 1
+
+ new_indices = sorted(
+ list(set(list(range(i)) + gold_candidates_indices))
+ )
+ np.random.shuffle(new_indices)
+
+ candidates_encoding_result = [
+ candidates_encoding_result[i] for i in new_indices
+ ]
+ if self.entities_definitions is not None:
+ sample.candidates = [
+ sample.candidates[i - len(self.entities_definitions)]
+ for i in new_indices
+ ]
+ candidates_symbols = candidates_symbols[
+ : i - len(self.entities_definitions)
+ ]
+ else:
+ candidates_symbols = [
+ candidates_symbols[i] for i in new_indices
+ ]
+ sample.window_candidates = [
+ sample.window_candidates[i] for i in new_indices
+ ]
+ if len(sample.candidates) == 0:
+ logger.warning(
+ "Sample {} has no candidates after truncation due to max length".format(
+ sample.sample_id
+ )
+ )
+ continue
+
+ # final input_ids build
+ input_ids = self._build_input_ids(
+ sentence_input_ids=input_subwords,
+ candidates_input_ids=candidates_encoding_result,
+ )
+
+ # complete input building (e.g. attention / prediction mask)
+ tokenization_output = self._build_tokenizer_essentials(
+ input_ids, input_subwords
+ )
+
+ # labels creation
+ start_labels, end_labels, disambiguation_labels, relation_labels = (
+ None,
+ None,
+ None,
+ None,
+ )
+ if sample.entities is not None and len(sample.entities) > 0:
+ (
+ start_labels,
+ end_labels,
+ disambiguation_labels,
+ relation_labels,
+ ) = self._build_labels(
+ sample,
+ tokenization_output,
+ )
+
+ yield {
+ "input_ids": tokenization_output.input_ids,
+ "attention_mask": tokenization_output.attention_mask,
+ "token_type_ids": tokenization_output.token_type_ids,
+ "prediction_mask": tokenization_output.prediction_mask,
+ "special_symbols_mask": tokenization_output.special_symbols_mask,
+ "special_symbols_mask_entities": tokenization_output.special_symbols_mask_entities,
+ "sample": sample,
+ "start_labels": start_labels,
+ "end_labels": end_labels,
+ "disambiguation_labels": disambiguation_labels,
+ "relation_labels": relation_labels,
+ "predictable_candidates": candidates_symbols,
+ }
+
+ def preshuffle_elements(self, dataset_elements: List):
+ # Shuffle before sorting: even if the sorting function is deterministic
+ # for a given collection and order, the overall operation then is not.
+ # The aim is to avoid building exactly the same batches every time.
+ if not self.for_inference:
+ dataset_elements = np.random.permutation(dataset_elements)
+
+ sorting_fn = (
+ lambda elem: add_noise_to_value(
+ sum(len(elem[k]) for k in self.sorting_fields),
+ noise_param=self.noise_param,
+ )
+ if not self.for_inference
+ else sum(len(elem[k]) for k in self.sorting_fields)
+ )
+
+ dataset_elements = sorted(dataset_elements, key=sorting_fn)
+
+ if self.for_inference:
+ return dataset_elements
+
+ ds = list(chunks(dataset_elements, 64)) # todo: modified
+ np.random.shuffle(ds)
+ return flatten(ds)
+
+ def materialize_batches(
+ self, dataset_elements: List[Dict[str, Any]]
+ ) -> Generator[Dict[str, Any], None, None]:
+ if self.prebatch:
+ dataset_elements = self.preshuffle_elements(dataset_elements)
+
+ current_batch = []
+
+ # function that creates a batch from the 'current_batch' list
+ def output_batch() -> Dict[str, Any]:
+ assert (
+ len(
+ set([len(elem["predictable_candidates"]) for elem in current_batch])
+ )
+ == 1
+ ), " ".join(
+ map(
+ str, [len(elem["predictable_candidates"]) for elem in current_batch]
+ )
+ )
+
+ batch_dict = dict()
+
+ de_values_by_field = {
+ fn: [de[fn] for de in current_batch if fn in de]
+ for fn in self.fields_batcher
+ }
+
+ # in case you provide fields batchers but in the batch
+ # there are no elements for that field
+ de_values_by_field = {
+ fn: fvs for fn, fvs in de_values_by_field.items() if len(fvs) > 0
+ }
+
+ assert len(set([len(v) for v in de_values_by_field.values()]))
+
+ # todo: maybe we should warn the user about fields being
+ # filtered out due to "None" instances
+ de_values_by_field = {
+ fn: fvs
+ for fn, fvs in de_values_by_field.items()
+ if all([fv is not None for fv in fvs])
+ }
+
+ for field_name, field_values in de_values_by_field.items():
+ field_batch = (
+ self.fields_batcher[field_name](field_values)
+ if self.fields_batcher[field_name] is not None
+ else field_values
+ )
+
+ batch_dict[field_name] = field_batch
+
+ return batch_dict
+
+ max_len_discards, min_len_discards = 0, 0
+
+ should_token_batch = self.batch_size is None
+
+ curr_pred_elements = -1
+ for de in dataset_elements:
+ if (
+ should_token_batch
+ and self.max_batch_size != -1
+ and len(current_batch) == self.max_batch_size
+ ) or (not should_token_batch and len(current_batch) == self.batch_size):
+ yield output_batch()
+ current_batch = []
+ curr_pred_elements = -1
+
+ # todo support max length (and min length) as dicts
+
+ too_long_fields = [
+ k
+ for k in de
+ if self.max_length != -1
+ and torch.is_tensor(de[k])
+ and len(de[k]) > self.max_length
+ ]
+ if len(too_long_fields) > 0:
+ max_len_discards += 1
+ continue
+
+ too_short_fields = [
+ k
+ for k in de
+ if self.min_length != -1
+ and torch.is_tensor(de[k])
+ and len(de[k]) < self.min_length
+ ]
+ if len(too_short_fields) > 0:
+ min_len_discards += 1
+ continue
+
+ if should_token_batch:
+ de_len = sum(len(de[k]) for k in self.batching_fields)
+
+ future_max_len = max(
+ de_len,
+ max(
+ [
+ sum(len(bde[k]) for k in self.batching_fields)
+ for bde in current_batch
+ ],
+ default=0,
+ ),
+ )
+
+ future_tokens_per_batch = future_max_len * (len(current_batch) + 1)
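+ # Illustrative example (hypothetical sizes): with 3 elements already in
+ # the batch and a longest element of 200 tokens, adding a 250-token
+ # element gives future_tokens_per_batch = 250 * 4 = 1000, since padding
+ # brings every element up to the batch maximum.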
+
+ num_predictable_candidates = len(de["predictable_candidates"])
+
+ if len(current_batch) > 0 and (
+ future_tokens_per_batch >= self.tokens_per_batch
+ or (
+ num_predictable_candidates != curr_pred_elements
+ and curr_pred_elements != -1
+ )
+ ):
+ yield output_batch()
+ current_batch = []
+
+ current_batch.append(de)
+ curr_pred_elements = len(de["predictable_candidates"])
+
+ if len(current_batch) != 0 and not self.drop_last:
+ yield output_batch()
+
+ if max_len_discards > 0:
+ if self.for_inference:
+ logger.warning(
+ f"WARNING: Inference mode is True but {max_len_discards} samples longer than max length were "
+ f"found. The {max_len_discards} samples will be DISCARDED. If you are doing some kind of evaluation"
+ f", this can INVALIDATE results. This might happen if the max length was not set to -1 or if the "
+ f"sample length exceeds the maximum length supported by the current model."
+ )
+ else:
+ logger.warning(
+ f"During iteration, {max_len_discards} elements were "
+ f"discarded since longer than max length {self.max_length}"
+ )
+
+ if min_len_discards > 0:
+ if self.for_inference:
+ logger.warning(
+ f"WARNING: Inference mode is True but {min_len_discards} samples shorter than min length were "
+ f"found. The {min_len_discards} samples will be DISCARDED. If you are doing some kind of evaluation"
+ f", this can INVALIDATE results. This might happen if the min length was not set to -1 or if the "
+ f"sample length is shorter than the minimum length supported by the current model."
+ )
+ else:
+ logger.warning(
+ f"During iteration, {min_len_discards} elements were "
+ f"discarded since shorter than min length {self.min_length}"
+ )
+
+
+def main():
+ special_symbols = [NME_SYMBOL] + [f"[R-{i}]" for i in range(50)]
+
+ relik_dataset = RelikREDataset(
+ "/home/huguetcabot/alby-re/alby/data/nyt-alby+/valid.jsonl",
+ materialize_samples=False,
+ transformer_model="microsoft/deberta-v3-base",
+ special_symbols=special_symbols,
+ shuffle_candidates=False,
+ flip_candidates=False,
+ for_inference=True,
+ )
+
+ for batch in relik_dataset:
+ print(batch)
+ exit(0)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/relik/reader/trainer/__init__.py b/relik/reader/trainer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/relik/reader/trainer/predict.py b/relik/reader/trainer/predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..3801bef958f9a092f8d094d2e99fe476fd4caed9
--- /dev/null
+++ b/relik/reader/trainer/predict.py
@@ -0,0 +1,57 @@
+import argparse
+from pprint import pprint
+from typing import Optional
+
+from relik.reader.relik_reader import RelikReader
+from relik.reader.utils.strong_matching_eval import StrongMatching
+
+
+def predict(
+ model_path: str,
+ dataset_path: str,
+ token_batch_size: int,
+ is_eval: bool,
+ output_path: Optional[str],
+) -> None:
+ relik_reader = RelikReader(model_path)
+ predicted_samples = relik_reader.link_entities(
+ dataset_path, token_batch_size=token_batch_size
+ )
+ if is_eval:
+ eval_dict = StrongMatching()(predicted_samples)
+ pprint(eval_dict)
+ if output_path is not None:
+ with open(output_path, "w") as f:
+ for sample in predicted_samples:
+ f.write(sample.to_jsons() + "\n")
+
+
+def parse_arg() -> argparse.Namespace:
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--model-path",
+ required=True,
+ )
+ parser.add_argument("--dataset-path", "-i", required=True)
+ parser.add_argument("--is-eval", action="store_true")
+ parser.add_argument(
+ "--output-path",
+ "-o",
+ )
+ parser.add_argument("--token-batch-size", type=int, default=4096)
+ return parser.parse_args()
+
+
+def main():
+ args = parse_arg()
+ predict(
+ args.model_path,
+ args.dataset_path,
+ token_batch_size=args.token_batch_size,
+ is_eval=args.is_eval,
+ output_path=args.output_path,
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/relik/reader/trainer/predict_re.py b/relik/reader/trainer/predict_re.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f21b4b15f857ec561e14f2e8c9fb50ea4f00637
--- /dev/null
+++ b/relik/reader/trainer/predict_re.py
@@ -0,0 +1,125 @@
+import argparse
+
+import torch
+from relik.reader.data.relik_reader_sample import load_relik_reader_samples
+
+from relik.reader.pytorch_modules.hf.modeling_relik import (
+ RelikReaderConfig,
+ RelikReaderREModel,
+)
+from relik.reader.relik_reader_re import RelikReaderForTripletExtraction
+from relik.reader.utils.relation_matching_eval import StrongMatching
+
+dict_nyt = {
+ "/people/person/nationality": "nationality",
+ "/sports/sports_team/location": "sports team location",
+ "/location/country/administrative_divisions": "administrative divisions",
+ "/business/company/major_shareholders": "shareholders",
+ "/people/ethnicity/people": "ethnicity",
+ "/people/ethnicity/geographic_distribution": "geographic distribution",
+ "/business/company_shareholder/major_shareholder_of": "major shareholder",
+ "/location/location/contains": "location",
+ "/business/company/founders": "founders",
+ "/business/person/company": "company",
+ "/business/company/advisors": "advisor",
+ "/people/deceased_person/place_of_death": "place of death",
+ "/business/company/industry": "industry",
+ "/people/person/ethnicity": "ethnic background",
+ "/people/person/place_of_birth": "place of birth",
+ "/location/administrative_division/country": "country of an administration division",
+ "/people/person/place_lived": "place lived",
+ "/sports/sports_team_location/teams": "sports team",
+ "/people/person/children": "child",
+ "/people/person/religion": "religion",
+ "/location/neighborhood/neighborhood_of": "neighborhood",
+ "/location/country/capital": "capital",
+ "/business/company/place_founded": "company founded location",
+ "/people/person/profession": "occupation",
+}
+
+
+def eval(model_path, data_path, is_eval, output_path=None):
+ if model_path.endswith(".ckpt"):
+ # if it is a lightning checkpoint we load the model state dict and the tokenizer from the config
+ model_dict = torch.load(model_path)
+
+ additional_special_symbols = model_dict["hyper_parameters"][
+ "additional_special_symbols"
+ ]
+ from transformers import AutoTokenizer
+
+ from relik.reader.utils.special_symbols import get_special_symbols_re
+
+ special_symbols = get_special_symbols_re(additional_special_symbols - 1)
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_dict["hyper_parameters"]["transformer_model"],
+ additional_special_tokens=special_symbols,
+ add_prefix_space=True,
+ )
+ config_model = RelikReaderConfig(
+ model_dict["hyper_parameters"]["transformer_model"],
+ len(special_symbols),
+ training=False,
+ )
+ model = RelikReaderREModel(config_model)
+ model_dict["state_dict"] = {
+ k.replace("relik_reader_re_model.", ""): v
+ for k, v in model_dict["state_dict"].items()
+ }
+ model.load_state_dict(model_dict["state_dict"], strict=False)
+ reader = RelikReaderForTripletExtraction(
+ model, training=False, device="cuda", tokenizer=tokenizer
+ )
+ else:
+ # if it is a huggingface model we load the model directly. Note that it could even be a string from the hub
+ model = RelikReaderREModel.from_pretrained(model_path)
+ reader = RelikReaderForTripletExtraction(model, training=False, device="cuda")
+
+ samples = list(load_relik_reader_samples(data_path))
+
+ for sample in samples:
+ sample.candidates = [dict_nyt[cand] for cand in sample.candidates]
+ sample.triplets = [
+ {
+ "subject": triplet["subject"],
+ "relation": {
+ "name": dict_nyt[triplet["relation"]["name"]],
+ "type": triplet["relation"]["type"],
+ },
+ "object": triplet["object"],
+ }
+ for triplet in sample.triplets
+ ]
+
+ predicted_samples = reader.read(samples=samples, progress_bar=True)
+ if is_eval:
+ strong_matching_metric = StrongMatching()
+ predicted_samples = list(predicted_samples)
+ for k, v in strong_matching_metric(predicted_samples).items():
+ print(f"test_{k}", v)
+ if output_path is not None:
+ with open(output_path, "w") as f:
+ for sample in predicted_samples:
+ f.write(sample.to_jsons() + "\n")
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--model_path",
+ type=str,
+ default="/home/huguetcabot/alby-re/relik/relik/reader/models/relik_re_reader_base",
+ )
+ parser.add_argument(
+ "--data_path",
+ type=str,
+ default="/home/huguetcabot/alby-re/relik/relik/reader/data/testa.jsonl",
+ )
+ parser.add_argument("--is-eval", action="store_true")
+ parser.add_argument("--output_path", type=str, default=None)
+ args = parser.parse_args()
+ eval(args.model_path, args.data_path, args.is_eval, args.output_path)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/relik/reader/trainer/train.py b/relik/reader/trainer/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1983b38c02199f45c112247fc74bd09d3f1e4f0
--- /dev/null
+++ b/relik/reader/trainer/train.py
@@ -0,0 +1,98 @@
+import hydra
+import lightning
+from hydra.utils import to_absolute_path
+from lightning import Trainer
+from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
+from lightning.pytorch.loggers.wandb import WandbLogger
+from omegaconf import DictConfig, OmegaConf, open_dict
+from relik.reader.data.relik_reader_data import RelikDataset
+from relik.reader.lightning_modules.relik_reader_pl_module import RelikReaderPLModule
+from relik.reader.pytorch_modules.optim import LayerWiseLRDecayOptimizer
+from torch.utils.data import DataLoader
+
+from relik.reader.utils.special_symbols import get_special_symbols
+from relik.reader.utils.strong_matching_eval import ELStrongMatchingCallback
+
+
+@hydra.main(config_path="../conf", config_name="config")
+def train(cfg: DictConfig) -> None:
+ lightning.seed_everything(cfg.training.seed)
+
+ special_symbols = get_special_symbols(cfg.model.entities_per_forward)
+
+ # model declaration
+ model = RelikReaderPLModule(
+ cfg=OmegaConf.to_container(cfg),
+ transformer_model=cfg.model.model.transformer_model,
+ additional_special_symbols=len(special_symbols),
+ training=True,
+ )
+
+ # optimizer declaration
+ opt_conf = cfg.model.optimizer
+ electra_optimizer_factory = LayerWiseLRDecayOptimizer(
+ lr=opt_conf.lr,
+ warmup_steps=opt_conf.warmup_steps,
+ total_steps=opt_conf.total_steps,
+ total_reset=opt_conf.total_reset,
+ no_decay_params=opt_conf.no_decay_params,
+ weight_decay=opt_conf.weight_decay,
+ lr_decay=opt_conf.lr_decay,
+ )
+
+ model.set_optimizer_factory(electra_optimizer_factory)
+
+ # datasets declaration
+ train_dataset: RelikDataset = hydra.utils.instantiate(
+ cfg.data.train_dataset,
+ dataset_path=to_absolute_path(cfg.data.train_dataset_path),
+ special_symbols=special_symbols,
+ )
+
+ # update of validation dataset config with special_symbols since they
+ # are required even from the EvaluationCallback dataset_config
+ with open_dict(cfg):
+ cfg.data.val_dataset.special_symbols = special_symbols
+
+ val_dataset: RelikDataset = hydra.utils.instantiate(
+ cfg.data.val_dataset,
+ dataset_path=to_absolute_path(cfg.data.val_dataset_path),
+ )
+
+ # callbacks declaration
+ callbacks = [
+ ELStrongMatchingCallback(
+ to_absolute_path(cfg.data.val_dataset_path), cfg.data.val_dataset
+ ),
+ ModelCheckpoint(
+ "model",
+ filename="{epoch}-{val_core_f1:.2f}",
+ monitor="val_core_f1",
+ mode="max",
+ ),
+ LearningRateMonitor(),
+ ]
+
+ wandb_logger = WandbLogger(cfg.model_name, project=cfg.project_name)
+
+ # trainer declaration
+ trainer: Trainer = hydra.utils.instantiate(
+ cfg.training.trainer,
+ callbacks=callbacks,
+ logger=wandb_logger,
+ )
+
+ # Trainer fit
+ trainer.fit(
+ model=model,
+ train_dataloaders=DataLoader(train_dataset, batch_size=None, num_workers=0),
+ val_dataloaders=DataLoader(val_dataset, batch_size=None, num_workers=0),
+ )
+
+
+def main():
+ train()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/relik/reader/trainer/train_re.py b/relik/reader/trainer/train_re.py
new file mode 100644
index 0000000000000000000000000000000000000000..550b3fc95d653bfc9503af635931aa72176ca89d
--- /dev/null
+++ b/relik/reader/trainer/train_re.py
@@ -0,0 +1,109 @@
+import hydra
+import lightning
+from hydra.utils import to_absolute_path
+from lightning import Trainer
+from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
+from lightning.pytorch.loggers.wandb import WandbLogger
+from omegaconf import DictConfig, OmegaConf, open_dict
+from relik.reader.pytorch_modules.optim import LayerWiseLRDecayOptimizer
+from torch.utils.data import DataLoader
+
+from relik.reader.lightning_modules.relik_reader_re_pl_module import (
+ RelikReaderREPLModule,
+)
+from relik.reader.relik_reader_re_data import RelikREDataset
+from relik.reader.utils.relation_matching_eval import REStrongMatchingCallback
+from relik.reader.utils.special_symbols import get_special_symbols_re
+
+
+@hydra.main(config_path="conf", config_name="config")
+def train(cfg: DictConfig) -> None:
+ lightning.seed_everything(cfg.training.seed)
+
+ special_symbols = get_special_symbols_re(cfg.model.entities_per_forward)
+
+ # datasets declaration
+ train_dataset: RelikREDataset = hydra.utils.instantiate(
+ cfg.data.train_dataset,
+ dataset_path=to_absolute_path(cfg.data.train_dataset_path),
+ special_symbols=special_symbols,
+ )
+
+ # update of validation dataset config with special_symbols since they
+ # are required even from the EvaluationCallback dataset_config
+ with open_dict(cfg):
+ cfg.data.val_dataset.special_symbols = special_symbols
+
+ val_dataset: RelikREDataset = hydra.utils.instantiate(
+ cfg.data.val_dataset,
+ dataset_path=to_absolute_path(cfg.data.val_dataset_path),
+ )
+
+ # model declaration
+ model = RelikReaderREPLModule(
+ cfg=OmegaConf.to_container(cfg),
+ transformer_model=cfg.model.model.transformer_model,
+ additional_special_symbols=len(special_symbols),
+ training=True,
+ )
+ model.relik_reader_re_model._tokenizer = train_dataset.tokenizer
+ # optimizer declaration
+ opt_conf = cfg.model.optimizer
+
+ # adamw_optimizer_factory = AdamWWithWarmupOptimizer(
+ # lr=opt_conf.lr,
+ # warmup_steps=opt_conf.warmup_steps,
+ # total_steps=opt_conf.total_steps,
+ # no_decay_params=opt_conf.no_decay_params,
+ # weight_decay=opt_conf.weight_decay,
+ # )
+
+ electra_optimizer_factory = LayerWiseLRDecayOptimizer(
+ lr=opt_conf.lr,
+ warmup_steps=opt_conf.warmup_steps,
+ total_steps=opt_conf.total_steps,
+ total_reset=opt_conf.total_reset,
+ no_decay_params=opt_conf.no_decay_params,
+ weight_decay=opt_conf.weight_decay,
+ lr_decay=opt_conf.lr_decay,
+ )
+
+ model.set_optimizer_factory(electra_optimizer_factory)
+
+ # callbacks declaration
+ callbacks = [
+ REStrongMatchingCallback(
+ to_absolute_path(cfg.data.val_dataset_path), cfg.data.val_dataset
+ ),
+ ModelCheckpoint(
+ "model",
+ filename="{epoch}-{val_f1:.2f}",
+ monitor="val_f1",
+ mode="max",
+ ),
+ LearningRateMonitor(),
+ ]
+
+ wandb_logger = WandbLogger(cfg.model_name, project=cfg.project_name)
+
+ # trainer declaration
+ trainer: Trainer = hydra.utils.instantiate(
+ cfg.training.trainer,
+ callbacks=callbacks,
+ logger=wandb_logger,
+ )
+
+ # Trainer fit
+ trainer.fit(
+ model=model,
+ train_dataloaders=DataLoader(train_dataset, batch_size=None, num_workers=0),
+ val_dataloaders=DataLoader(val_dataset, batch_size=None, num_workers=0),
+ )
+
+
+def main():
+ train()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/relik/reader/utils/__init__.py b/relik/reader/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/relik/reader/utils/metrics.py b/relik/reader/utils/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa17bf5d23cc888d6da0c6f40cf7bd3c20d77a66
--- /dev/null
+++ b/relik/reader/utils/metrics.py
@@ -0,0 +1,18 @@
+def safe_divide(num: float, den: float) -> float:
+ if den == 0:
+ return 0.0
+ else:
+ return num / den
+
+
+def f1_measure(precision: float, recall: float) -> float:
+ if precision == 0 or recall == 0:
+ return 0.0
+ return safe_divide(2 * precision * recall, (precision + recall))
+
+
+def compute_metrics(total_correct, total_preds, total_gold):
+ precision = safe_divide(total_correct, total_preds)
+ recall = safe_divide(total_correct, total_gold)
+ f1 = f1_measure(precision, recall)
+ return precision, recall, f1
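+
+
+# Illustrative usage (hypothetical counts): with 8 correct predictions out of
+# 10 predicted and 16 gold items, compute_metrics(8, 10, 16) yields
+# precision = 0.8, recall = 0.5 and f1 = 2 * 0.8 * 0.5 / 1.3 ≈ 0.615.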
diff --git a/relik/reader/utils/relation_matching_eval.py b/relik/reader/utils/relation_matching_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..94a6b1e7a8dc155ed1ab9f6c52cb3c1eebd44505
--- /dev/null
+++ b/relik/reader/utils/relation_matching_eval.py
@@ -0,0 +1,172 @@
+from typing import Dict, List
+
+from lightning.pytorch.callbacks import Callback
+from relik.reader.data.relik_reader_sample import RelikReaderSample
+
+from relik.reader.relik_reader_predictor import RelikReaderPredictor
+from relik.reader.utils.metrics import compute_metrics
+
+
+class StrongMatching:
+ def __call__(self, predicted_samples: List[RelikReaderSample]) -> Dict:
+ # accumulators
+ correct_predictions, total_predictions, total_gold = (
+ 0,
+ 0,
+ 0,
+ )
+ correct_predictions_strict, total_predictions_strict = (
+ 0,
+ 0,
+ )
+ correct_predictions_bound, total_predictions_bound = (
+ 0,
+ 0,
+ )
+ correct_span_predictions, total_span_predictions, total_gold_spans = 0, 0, 0
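+ # Illustrative distinction (hypothetical prediction): a predicted triplet
+ # whose spans and relation match a gold one but whose entity types are
+ # wrong counts towards the plain precision/recall below (types are masked
+ # with -1) but not towards the "strict" variants, which also compare the
+ # predicted entity types.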
+
+ # collect data from samples
+ for sample in predicted_samples:
+ if sample.triplets is None:
+ sample.triplets = []
+
+ if sample.entity_candidates:
+ predicted_annotations_strict = set(
+ [
+ (
+ triplet["subject"]["start"],
+ triplet["subject"]["end"],
+ triplet["subject"]["type"],
+ triplet["relation"]["name"],
+ triplet["object"]["start"],
+ triplet["object"]["end"],
+ triplet["object"]["type"],
+ )
+ for triplet in sample.predicted_relations
+ ]
+ )
+ gold_annotations_strict = set(
+ [
+ (
+ triplet["subject"]["start"],
+ triplet["subject"]["end"],
+ triplet["subject"]["type"],
+ triplet["relation"]["name"],
+ triplet["object"]["start"],
+ triplet["object"]["end"],
+ triplet["object"]["type"],
+ )
+ for triplet in sample.triplets
+ ]
+ )
+ predicted_spans_strict = set(sample.predicted_entities)
+ gold_spans_strict = set(sample.entities)
+ # strict
+ correct_span_predictions += len(
+ predicted_spans_strict.intersection(gold_spans_strict)
+ )
+ total_span_predictions += len(predicted_spans_strict)
+ total_gold_spans += len(gold_spans_strict)
+ correct_predictions_strict += len(
+ predicted_annotations_strict.intersection(gold_annotations_strict)
+ )
+ total_predictions_strict += len(predicted_annotations_strict)
+
+ predicted_annotations = set(
+ [
+ (
+ triplet["subject"]["start"],
+ triplet["subject"]["end"],
+ -1,
+ triplet["relation"]["name"],
+ triplet["object"]["start"],
+ triplet["object"]["end"],
+ -1,
+ )
+ for triplet in sample.predicted_relations
+ ]
+ )
+ gold_annotations = set(
+ [
+ (
+ triplet["subject"]["start"],
+ triplet["subject"]["end"],
+ -1,
+ triplet["relation"]["name"],
+ triplet["object"]["start"],
+ triplet["object"]["end"],
+ -1,
+ )
+ for triplet in sample.triplets
+ ]
+ )
+ predicted_spans = set(
+ [(ss, se) for (ss, se, _) in sample.predicted_entities]
+ )
+ gold_spans = set([(ss, se) for (ss, se, _) in sample.entities])
+ total_gold_spans += len(gold_spans)
+
+ correct_predictions_bound += len(predicted_spans.intersection(gold_spans))
+ total_predictions_bound += len(predicted_spans)
+
+ total_predictions += len(predicted_annotations)
+ total_gold += len(gold_annotations)
+ # correct relation extraction
+ correct_predictions += len(
+ predicted_annotations.intersection(gold_annotations)
+ )
+
+ span_precision, span_recall, span_f1 = compute_metrics(
+ correct_span_predictions, total_span_predictions, total_gold_spans
+ )
+ bound_precision, bound_recall, bound_f1 = compute_metrics(
+ correct_predictions_bound, total_predictions_bound, total_gold_spans
+ )
+
+ precision, recall, f1 = compute_metrics(
+ correct_predictions, total_predictions, total_gold
+ )
+
+ if sample.entity_candidates:
+ precision_strict, recall_strict, f1_strict = compute_metrics(
+ correct_predictions_strict, total_predictions_strict, total_gold
+ )
+ return {
+ "span-precision": span_precision,
+ "span-recall": span_recall,
+ "span-f1": span_f1,
+ "precision": precision,
+ "recall": recall,
+ "f1": f1,
+ "precision-strict": precision_strict,
+ "recall-strict": recall_strict,
+ "f1-strict": f1_strict,
+ }
+ else:
+ return {
+ "span-precision": bound_precision,
+ "span-recall": bound_recall,
+ "span-f1": bound_f1,
+ "precision": precision,
+ "recall": recall,
+ "f1": f1,
+ }
+
+
+class REStrongMatchingCallback(Callback):
+ def __init__(self, dataset_path: str, dataset_conf) -> None:
+ super().__init__()
+ self.dataset_path = dataset_path
+ self.dataset_conf = dataset_conf
+ self.strong_matching_metric = StrongMatching()
+
+ def on_validation_epoch_start(self, trainer, pl_module) -> None:
+ relik_reader_predictor = RelikReaderPredictor(pl_module.relik_reader_re_model)
+ predicted_samples = relik_reader_predictor._predict(
+ self.dataset_path,
+ None,
+ self.dataset_conf,
+ )
+ predicted_samples = list(predicted_samples)
+ for k, v in self.strong_matching_metric(predicted_samples).items():
+ pl_module.log(f"val_{k}", v)
diff --git a/relik/reader/utils/save_load_utilities.py b/relik/reader/utils/save_load_utilities.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e635650c1f69c0e223d268f97ec9d6e0677742c
--- /dev/null
+++ b/relik/reader/utils/save_load_utilities.py
@@ -0,0 +1,76 @@
+import argparse
+import os
+from typing import Tuple
+
+import omegaconf
+import torch
+
+from relik.common.utils import from_cache
+from relik.reader.lightning_modules.relik_reader_pl_module import RelikReaderPLModule
+from relik.reader.relik_reader_core import RelikReaderCoreModel
+
+CKPT_FILE_NAME = "model.ckpt"
+CONFIG_FILE_NAME = "cfg.yaml"
+
+
+def convert_pl_module(pl_module_ckpt_path: str, output_dir: str) -> None:
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir)
+ else:
+ print(f"{output_dir} already exists, aborting operation")
+ exit(1)
+
+ relik_pl_module: RelikReaderPLModule = RelikReaderPLModule.load_from_checkpoint(
+ pl_module_ckpt_path
+ )
+ torch.save(
+ relik_pl_module.relik_reader_core_model, f"{output_dir}/{CKPT_FILE_NAME}"
+ )
+ with open(f"{output_dir}/{CONFIG_FILE_NAME}", "w") as f:
+ omegaconf.OmegaConf.save(
+ omegaconf.OmegaConf.create(relik_pl_module.hparams["cfg"]), f
+ )
+
+
+def load_model_and_conf(
+ model_dir_path: str,
+) -> Tuple[RelikReaderCoreModel, omegaconf.DictConfig]:
+ # TODO: quick workaround to load the model from HF hub
+ model_dir = from_cache(
+ model_dir_path,
+ filenames=[CKPT_FILE_NAME, CONFIG_FILE_NAME],
+ cache_dir=None,
+ force_download=False,
+ )
+
+ ckpt_path = f"{model_dir}/{CKPT_FILE_NAME}"
+ model = torch.load(ckpt_path, map_location=torch.device("cpu"))
+
+ model_cfg_path = f"{model_dir}/{CONFIG_FILE_NAME}"
+ model_conf = omegaconf.OmegaConf.load(model_cfg_path)
+ return model, model_conf
+
+
+def parse_arg() -> argparse.Namespace:
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--ckpt",
+ help="Path to the pytorch lightning ckpt you want to convert.",
+ required=True,
+ )
+ parser.add_argument(
+ "--output-dir",
+ "-o",
+ help="The output dir to store the bare models and the config.",
+ required=True,
+ )
+ return parser.parse_args()
+
+
+def main():
+ args = parse_arg()
+ convert_pl_module(args.ckpt, args.output_dir)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/relik/reader/utils/special_symbols.py b/relik/reader/utils/special_symbols.py
new file mode 100644
index 0000000000000000000000000000000000000000..170909ad6cb2b69e1d6a8384af34cba441e60ce4
--- /dev/null
+++ b/relik/reader/utils/special_symbols.py
@@ -0,0 +1,11 @@
+from typing import List
+
+NME_SYMBOL = "--NME--"
+
+
+def get_special_symbols(num_entities: int) -> List[str]:
+ return [NME_SYMBOL] + [f"[E-{i}]" for i in range(num_entities)]
+
+
+def get_special_symbols_re(num_entities: int) -> List[str]:
+ return [NME_SYMBOL] + [f"[R-{i}]" for i in range(num_entities)]
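+
+
+# Illustrative example: get_special_symbols_re(2) returns
+# ["--NME--", "[R-0]", "[R-1]"], i.e. the NME symbol followed by one
+# placeholder token per relation candidate slot.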
diff --git a/relik/reader/utils/strong_matching_eval.py b/relik/reader/utils/strong_matching_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..88ad651134907c8067c174ee5f6fcbbb5fc2cb73
--- /dev/null
+++ b/relik/reader/utils/strong_matching_eval.py
@@ -0,0 +1,146 @@
+from typing import Dict, List
+
+from lightning.pytorch.callbacks import Callback
+from relik.reader.data.relik_reader_sample import RelikReaderSample
+
+from relik.reader.relik_reader_predictor import RelikReaderPredictor
+from relik.reader.utils.metrics import f1_measure, safe_divide
+from relik.reader.utils.special_symbols import NME_SYMBOL
+
+
+class StrongMatching:
+ def __call__(self, predicted_samples: List[RelikReaderSample]) -> Dict:
+ # accumulators
+ correct_predictions = 0
+ correct_predictions_at_k = 0
+ total_predictions = 0
+ total_gold = 0
+ correct_span_predictions = 0
+ miss_due_to_candidates = 0
+
+ # prediction index stats
+ avg_correct_predicted_index = []
+ avg_wrong_predicted_index = []
+ less_index_predictions = []
+
+ # collect data from samples
+ for sample in predicted_samples:
+ predicted_annotations = sample.predicted_window_labels_chars
+ predicted_annotations_probabilities = sample.probs_window_labels_chars
+ gold_annotations = {
+ (ss, se, entity)
+ for ss, se, entity in sample.window_labels
+ if entity != NME_SYMBOL
+ }
+ total_predictions += len(predicted_annotations)
+ total_gold += len(gold_annotations)
+
+ # correct named entity detection
+ predicted_spans = {(s, e) for s, e, _ in predicted_annotations}
+ gold_spans = {(s, e) for s, e, _ in gold_annotations}
+ correct_span_predictions += len(predicted_spans.intersection(gold_spans))
+
+ # correct entity linking
+ correct_predictions += len(
+ predicted_annotations.intersection(gold_annotations)
+ )
+
+ for ss, se, ge in gold_annotations.difference(predicted_annotations):
+ if ge not in sample.window_candidates:
+ miss_due_to_candidates += 1
+ if ge in predicted_annotations_probabilities.get((ss, se), set()):
+ correct_predictions_at_k += 1
+
+ # indices metrics
+ predicted_spans_index = {
+ (ss, se): ent for ss, se, ent in predicted_annotations
+ }
+ gold_spans_index = {(ss, se): ent for ss, se, ent in gold_annotations}
+
+ for pred_span, pred_ent in predicted_spans_index.items():
+ gold_ent = gold_spans_index.get(pred_span)
+
+ if pred_span not in gold_spans_index:
+ continue
+
+ # missing candidate
+ if gold_ent not in sample.window_candidates:
+ continue
+
+ gold_idx = sample.window_candidates.index(gold_ent)
+ pred_idx = sample.window_candidates.index(pred_ent)
+
+ if gold_ent != pred_ent:
+ avg_wrong_predicted_index.append(pred_idx)
+ if pred_idx > gold_idx:
+ less_index_predictions.append(0)
+ else:
+ less_index_predictions.append(1)
+ else:
+ avg_correct_predicted_index.append(pred_idx)
+
+ # compute NED metrics
+ span_precision = safe_divide(correct_span_predictions, total_predictions)
+ span_recall = safe_divide(correct_span_predictions, total_gold)
+ span_f1 = f1_measure(span_precision, span_recall)
+
+ # compute EL metrics
+ precision = safe_divide(correct_predictions, total_predictions)
+ recall = safe_divide(correct_predictions, total_gold)
+ recall_at_k = safe_divide(
+ (correct_predictions + correct_predictions_at_k), total_gold
+ )
+
+ f1 = f1_measure(precision, recall)
+
+ wrong_for_candidates = safe_divide(miss_due_to_candidates, total_gold)
+
+ out_dict = {
+ "span_precision": span_precision,
+ "span_recall": span_recall,
+ "span_f1": span_f1,
+ "core_precision": precision,
+ "core_recall": recall,
+ "core_recall-at-k": recall_at_k,
+ "core_f1": round(f1, 4),
+ "wrong-for-candidates": wrong_for_candidates,
+ "index_errors_avg-index": safe_divide(
+ sum(avg_wrong_predicted_index), len(avg_wrong_predicted_index)
+ ),
+ "index_correct_avg-index": safe_divide(
+ sum(avg_correct_predicted_index), len(avg_correct_predicted_index)
+ ),
+ "index_avg-index": safe_divide(
+ sum(avg_correct_predicted_index + avg_wrong_predicted_index),
+ len(avg_correct_predicted_index + avg_wrong_predicted_index),
+ ),
+ "index_percentage-favoured-smaller-idx": safe_divide(
+ sum(less_index_predictions), len(less_index_predictions)
+ ),
+ }
+
+ return {k: round(v, 5) for k, v in out_dict.items()}
+
+
+class ELStrongMatchingCallback(Callback):
+ def __init__(self, dataset_path: str, dataset_conf) -> None:
+ super().__init__()
+ self.dataset_path = dataset_path
+ self.dataset_conf = dataset_conf
+ self.strong_matching_metric = StrongMatching()
+
+ def on_validation_epoch_start(self, trainer, pl_module) -> None:
+ relik_reader_predictor = RelikReaderPredictor(pl_module.relik_reader_core_model)
+ predicted_samples = relik_reader_predictor.predict(
+ self.dataset_path,
+ samples=None,
+ dataset_conf=self.dataset_conf,
+ )
+ predicted_samples = list(predicted_samples)
+ for k, v in self.strong_matching_metric(predicted_samples).items():
+ pl_module.log(f"val_{k}", v)
diff --git a/relik/retriever/__init__.py b/relik/retriever/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/relik/retriever/callbacks/__init__.py b/relik/retriever/callbacks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/relik/retriever/callbacks/base.py b/relik/retriever/callbacks/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..43042c94bfc93ac32fb60b344ca644cd1c79c1f3
--- /dev/null
+++ b/relik/retriever/callbacks/base.py
@@ -0,0 +1,168 @@
+from functools import partial
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
+
+import hydra
+import lightning as pl
+import torch
+from lightning.pytorch.trainer.states import RunningStage
+from omegaconf import DictConfig
+from torch.utils.data import DataLoader, Dataset
+
+from relik.common.log import get_logger
+from relik.retriever.data.base.datasets import BaseDataset
+
+logger = get_logger()
+
+
+STAGES_COMPATIBILITY_MAP = {
+ "train": RunningStage.TRAINING,
+ "val": RunningStage.VALIDATING,
+ "test": RunningStage.TESTING,
+}
+
+DEFAULT_STAGES = {
+ RunningStage.VALIDATING,
+ RunningStage.TESTING,
+ RunningStage.SANITY_CHECKING,
+ RunningStage.PREDICTING,
+}
+
+
+class PredictionCallback(pl.Callback):
+ def __init__(
+ self,
+ batch_size: int = 32,
+ stages: Optional[Set[Union[str, RunningStage]]] = None,
+ other_callbacks: Optional[
+ Union[List[DictConfig], List["NLPTemplateCallback"]]
+ ] = None,
+ datasets: Optional[Union[DictConfig, BaseDataset]] = None,
+ dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
+ *args,
+ **kwargs,
+ ):
+ super().__init__()
+ # parameters
+ self.batch_size = batch_size
+ self.datasets = datasets
+ self.dataloaders = dataloaders
+
+ # callback initialization
+ if stages is None:
+ stages = DEFAULT_STAGES
+
+ # compatibility stuff
+ stages = {STAGES_COMPATIBILITY_MAP.get(stage, stage) for stage in stages}
+ self.stages = [RunningStage(stage) for stage in stages]
+ self.other_callbacks = other_callbacks or []
+ for i, callback in enumerate(self.other_callbacks):
+ if isinstance(callback, DictConfig):
+ self.other_callbacks[i] = hydra.utils.instantiate(
+ callback, _recursive_=False
+ )
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ trainer: pl.Trainer,
+ pl_module: pl.LightningModule,
+ *args,
+ **kwargs,
+ ) -> Any:
+ # it should return the predictions
+ raise NotImplementedError
+
+ def on_validation_epoch_end(
+ self, trainer: pl.Trainer, pl_module: pl.LightningModule
+ ):
+ predictions = self(trainer, pl_module)
+ for callback in self.other_callbacks:
+ callback(
+ trainer=trainer,
+ pl_module=pl_module,
+ callback=self,
+ predictions=predictions,
+ )
+
+ def on_test_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
+ predictions = self(trainer, pl_module)
+ for callback in self.other_callbacks:
+ callback(
+ trainer=trainer,
+ pl_module=pl_module,
+ callback=self,
+ predictions=predictions,
+ )
+
+ @staticmethod
+ def _get_datasets_and_dataloaders(
+ dataset: Optional[Union[Dataset, DictConfig]],
+ dataloader: Optional[DataLoader],
+ trainer: pl.Trainer,
+ dataloader_kwargs: Optional[Dict[str, Any]] = None,
+ collate_fn: Optional[Callable] = None,
+ collate_fn_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[List[Dataset], List[DataLoader]]:
+ """
+ Get the datasets and dataloaders from the datamodule or from the dataset provided.
+
+ Args:
+ dataset (`Optional[Union[Dataset, DictConfig]]`):
+ The dataset to use. If `None`, the datamodule is used.
+ dataloader (`Optional[DataLoader]`):
+ The dataloader to use. If `None`, the datamodule is used.
+ trainer (`pl.Trainer`):
+ The trainer that contains the datamodule.
+ dataloader_kwargs (`Optional[Dict[str, Any]]`):
+ The kwargs to pass to the dataloader.
+ collate_fn (`Optional[Callable]`):
+ The collate function to use.
+ collate_fn_kwargs (`Optional[Dict[str, Any]]`):
+ The kwargs to pass to the collate function.
+
+ Returns:
+ `Tuple[List[Dataset], List[DataLoader]]`: The datasets and dataloaders.
+ """
+ # if a dataset is provided, use it
+ if dataset is not None:
+ dataloader_kwargs = dataloader_kwargs or {}
+ # get dataset
+ if isinstance(dataset, DictConfig):
+ dataset = hydra.utils.instantiate(dataset, _recursive_=False)
+ datasets = [dataset] if not isinstance(dataset, list) else dataset
+ if dataloader is not None:
+ dataloaders = (
+ [dataloader] if isinstance(dataloader, DataLoader) else dataloader
+ )
+ else:
+ collate_fn = collate_fn or partial(
+ datasets[0].collate_fn, **collate_fn_kwargs
+ )
+ dataloader_kwargs["collate_fn"] = collate_fn
+ dataloaders = [DataLoader(datasets[0], **dataloader_kwargs)]
+ else:
+ # get the dataloaders and datasets from the datamodule
+ datasets = (
+ trainer.datamodule.test_datasets
+ if trainer.state.stage == RunningStage.TESTING
+ else trainer.datamodule.val_datasets
+ )
+ dataloaders = (
+ trainer.test_dataloaders
+ if trainer.state.stage == RunningStage.TESTING
+ else trainer.val_dataloaders
+ )
+ return datasets, dataloaders
+
+
+class NLPTemplateCallback:
+ def __call__(
+ self,
+ trainer: pl.Trainer,
+ pl_module: pl.LightningModule,
+ callback: PredictionCallback,
+ predictions: Dict[str, Any],
+ *args,
+ **kwargs,
+ ) -> Any:
+ raise NotImplementedError
diff --git a/relik/retriever/callbacks/evaluation_callbacks.py b/relik/retriever/callbacks/evaluation_callbacks.py
new file mode 100644
index 0000000000000000000000000000000000000000..72d8fcb306f9e0e0fb58d1dbdb8ea6cfcea8c7e3
--- /dev/null
+++ b/relik/retriever/callbacks/evaluation_callbacks.py
@@ -0,0 +1,276 @@
+import logging
+from typing import Dict, List, Optional
+
+import lightning as pl
+import torch
+from lightning.pytorch.trainer.states import RunningStage
+from sklearn.metrics import label_ranking_average_precision_score
+
+from relik.common.log import get_console_logger, get_logger
+from relik.retriever.callbacks.base import DEFAULT_STAGES, NLPTemplateCallback
+
+console_logger = get_console_logger()
+logger = get_logger(__name__, level=logging.INFO)
+
+
+class RecallAtKEvaluationCallback(NLPTemplateCallback):
+ """
+ Computes the recall at k for the predictions. Recall at k is computed as the
+ number of gold labels found in the top k predictions divided by the total number
+ of gold labels.
+
+ Args:
+ k (`int`):
+ The number of predictions to consider.
+ prefix (`str`, `optional`):
+ The prefix to add to the metrics.
+ verbose (`bool`, `optional`, defaults to `False`):
+ Whether to log the metrics.
+ prog_bar (`bool`, `optional`, defaults to `True`):
+ Whether to log the metrics to the progress bar.
+ """
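+
+ # Illustrative example of the computation in __call__ (hypothetical values):
+ # with k=2, a sample whose top predictions are ["a", "b", "c"] and whose gold
+ # set is {"a", "c"} contributes hits = 1 and total = 2, i.e. a recall@2
+ # contribution of 0.5.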
+
+ def __init__(
+ self,
+ k: int = 100,
+ prefix: Optional[str] = None,
+ verbose: bool = False,
+ prog_bar: bool = True,
+ *args,
+ **kwargs,
+ ):
+ super().__init__()
+ self.k = k
+ self.prefix = prefix
+ self.verbose = verbose
+ self.prog_bar = prog_bar
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ trainer: pl.Trainer,
+ pl_module: pl.LightningModule,
+ predictions: Dict,
+ *args,
+ **kwargs,
+ ) -> dict:
+ """
+ Computes the recall at k for the predictions.
+
+ Args:
+ trainer (:obj:`lightning.trainer.trainer.Trainer`):
+ The trainer object.
+ pl_module (:obj:`lightning.core.lightning.LightningModule`):
+ The lightning module.
+ predictions (:obj:`Dict`):
+ The predictions.
+
+ Returns:
+ :obj:`Dict`: The computed metrics.
+ """
+ if self.verbose:
+ logger.info(f"Computing recall@{self.k}")
+
+ # metrics to return
+ metrics = {}
+
+ stage = trainer.state.stage
+ if stage not in DEFAULT_STAGES:
+ raise ValueError(
+ f"Stage {stage} not supported, only `validate` and `test` are supported."
+ )
+
+ for dataloader_idx, samples in predictions.items():
+ hits, total = 0, 0
+ for sample in samples:
+ # recall at k: gold labels retrieved within the first k predictions
+ top_k_predictions = sample["predictions"][: self.k]
+ hits += len(set(top_k_predictions) & set(sample["gold"]))
+ total += len(set(sample["gold"]))
+
+ # recall at k for this dataloader
+ recall_at_k = hits / total if total > 0 else 0.0
+ metrics[f"recall@{self.k}_{dataloader_idx}"] = recall_at_k
+ metrics[f"recall@{self.k}"] = sum(metrics.values()) / len(metrics)
+
+ if self.prefix is not None:
+ metrics = {f"{self.prefix}_{k}": v for k, v in metrics.items()}
+ else:
+ metrics = {f"{stage.value}_{k}": v for k, v in metrics.items()}
+ pl_module.log_dict(
+ metrics, on_step=False, on_epoch=True, prog_bar=self.prog_bar
+ )
+
+ if self.verbose:
+ logger.info(
+ f"Recall@{self.k} on {stage.value}: {metrics[f'{stage.value}_recall@{self.k}']}"
+ )
+
+ return metrics
+
+
+class AvgRankingEvaluationCallback(NLPTemplateCallback):
+ """
+ Computes the average ranking of the gold label in the predictions. Average ranking is
+ computed as the average of the rank of the gold label in the predictions.
+
+ Args:
+ k (`int`):
+ The number of predictions to consider.
+ prefix (`str`, `optional`):
+ The prefix to add to the metrics.
+ stages (`List[str]`, `optional`):
+ The stages to compute the metrics on. Defaults to `["validate", "test"]`.
+ verbose (`bool`, `optional`, defaults to `False`):
+ Whether to log the metrics.
+ """
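+
+ # Illustrative example (hypothetical values): if the gold passages of a sample
+ # appear at positions 1 and 4 of the top-k predictions, the rankings list
+ # receives [1, 4], contributing an average ranking of 2.5 (lower is better).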
+
+ def __init__(
+ self,
+ k: int,
+ prefix: Optional[str] = None,
+ stages: Optional[List[str]] = None,
+ verbose: bool = True,
+ *args,
+ **kwargs,
+ ):
+ super().__init__()
+ self.k = k
+ self.prefix = prefix
+ self.verbose = verbose
+ self.stages = (
+ [RunningStage(stage) for stage in stages] if stages else DEFAULT_STAGES
+ )
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ trainer: pl.Trainer,
+ pl_module: pl.LightningModule,
+ predictions: Dict,
+ *args,
+ **kwargs,
+ ) -> dict:
+ """
+ Computes the average ranking of the gold label in the predictions.
+
+ Args:
+ trainer (:obj:`lightning.trainer.trainer.Trainer`):
+ The trainer object.
+ pl_module (:obj:`lightning.core.lightning.LightningModule`):
+ The lightning module.
+ predictions (:obj:`Dict`):
+ The predictions.
+
+ Returns:
+ :obj:`Dict`: The computed metrics.
+ """
+ if not predictions:
+ logger.warning("No predictions to compute the AVG Ranking metrics.")
+ return {}
+
+ if self.verbose:
+ logger.info(f"Computing AVG Ranking@{self.k}")
+
+ # metrics to return
+ metrics = {}
+
+ stage = trainer.state.stage
+ if stage not in self.stages:
+ raise ValueError(
+ f"Stage `{stage}` not supported, only `validate` and `test` are supported."
+ )
+
+ for dataloader_idx, samples in predictions.items():
+ rankings = []
+ for sample in samples:
+ window_candidates = sample["predictions"][: self.k]
+ window_labels = sample["gold"]
+ for wl in window_labels:
+ if wl in window_candidates:
+ rankings.append(window_candidates.index(wl) + 1)
+
+ avg_ranking = sum(rankings) / len(rankings) if len(rankings) > 0 else 0
+ metrics[f"avg_ranking@{self.k}_{dataloader_idx}"] = avg_ranking
+ if len(metrics) == 0:
+ metrics[f"avg_ranking@{self.k}"] = 0
+ else:
+ metrics[f"avg_ranking@{self.k}"] = sum(metrics.values()) / len(metrics)
+
+ prefix = self.prefix or stage.value
+ metrics = {
+ f"{prefix}_{k}": torch.as_tensor(v, dtype=torch.float32)
+ for k, v in metrics.items()
+ }
+ pl_module.log_dict(metrics, on_step=False, on_epoch=True, prog_bar=False)
+
+ if self.verbose:
+ logger.info(
+ f"AVG Ranking@{self.k} on {prefix}: {metrics[f'{prefix}_avg_ranking@{self.k}']}"
+ )
+
+ return metrics
+
+
+class LRAPEvaluationCallback(NLPTemplateCallback):
+ def __init__(
+ self,
+ k: int = 100,
+ prefix: Optional[str] = None,
+ verbose: bool = False,
+ prog_bar: bool = True,
+ *args,
+ **kwargs,
+ ):
+ super().__init__()
+ self.k = k
+ self.prefix = prefix
+ self.verbose = verbose
+ self.prog_bar = prog_bar
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ trainer: pl.Trainer,
+ pl_module: pl.LightningModule,
+ predictions: Dict,
+ *args,
+ **kwargs,
+ ) -> dict:
+ if self.verbose:
+ logger.info(f"Computing LRAP@{self.k}")
+
+ # metrics to return
+ metrics = {}
+
+ stage = trainer.state.stage
+ if stage not in DEFAULT_STAGES:
+ raise ValueError(
+ f"Stage {stage} not supported, only `validate` and `test` are supported."
+ )
+
+ for dataloader_idx, samples in predictions.items():
+ scores = [sample["scores"][: self.k] for sample in samples]
+ golds = [sample["gold"] for sample in samples]
+
+ # compute the LRAP for this dataloader
+ lrap = label_ranking_average_precision_score(golds, scores)
+ metrics[f"lrap@{self.k}_{dataloader_idx}"] = lrap
+ metrics[f"lrap@{self.k}"] = sum(metrics.values()) / len(metrics)
+
+ prefix = self.prefix or stage.value
+ metrics = {
+ f"{prefix}_{k}": torch.as_tensor(v, dtype=torch.float32)
+ for k, v in metrics.items()
+ }
+ pl_module.log_dict(
+ metrics, on_step=False, on_epoch=True, prog_bar=self.prog_bar
+ )
+
+ if self.verbose:
+ logger.info(
+ f"LRAP@{self.k} on {prefix}: {metrics[f'{prefix}_lrap@{self.k}']}"
+ )
+
+ return metrics
diff --git a/relik/retriever/callbacks/prediction_callbacks.py b/relik/retriever/callbacks/prediction_callbacks.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8a051ad396d07872dfac05998d1ec550724677a
--- /dev/null
+++ b/relik/retriever/callbacks/prediction_callbacks.py
@@ -0,0 +1,432 @@
+import logging
+import random
+import time
+from copy import deepcopy
+from pathlib import Path
+from typing import List, Optional, Set, Union
+
+import lightning as pl
+import torch
+from lightning.pytorch.trainer.states import RunningStage
+from omegaconf import DictConfig
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from relik.common.log import get_console_logger, get_logger
+from relik.retriever.callbacks.base import PredictionCallback
+from relik.retriever.common.model_inputs import ModelInputs
+from relik.retriever.data.base.datasets import BaseDataset
+from relik.retriever.data.datasets import GoldenRetrieverDataset
+from relik.retriever.data.utils import HardNegativesManager
+from relik.retriever.indexers.base import BaseDocumentIndex
+from relik.retriever.pytorch_modules.model import GoldenRetriever
+
+console_logger = get_console_logger()
+logger = get_logger(__name__, level=logging.INFO)
+
+
+class GoldenRetrieverPredictionCallback(PredictionCallback):
+ def __init__(
+ self,
+ k: Optional[int] = None,
+ batch_size: int = 32,
+ num_workers: int = 8,
+ document_index: Optional[BaseDocumentIndex] = None,
+ precision: Union[str, int] = 32,
+ force_reindex: bool = True,
+ retriever_dir: Optional[Path] = None,
+ stages: Optional[Set[Union[str, RunningStage]]] = None,
+ other_callbacks: Optional[List[DictConfig]] = None,
+ dataset: Optional[Union[DictConfig, BaseDataset]] = None,
+ dataloader: Optional[DataLoader] = None,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(batch_size, stages, other_callbacks, dataset, dataloader)
+ self.k = k
+ self.num_workers = num_workers
+ self.document_index = document_index
+ self.precision = precision
+ self.force_reindex = force_reindex
+ self.retriever_dir = retriever_dir
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ trainer: pl.Trainer,
+ pl_module: pl.LightningModule,
+ datasets: Optional[
+ Union[DictConfig, BaseDataset, List[DictConfig], List[BaseDataset]]
+ ] = None,
+ dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
+ *args,
+ **kwargs,
+ ) -> dict:
+ stage = trainer.state.stage
+ logger.info(f"Computing predictions for stage {stage.value}")
+ if stage not in self.stages:
+ raise ValueError(
+ f"Stage `{stage}` not supported, only {self.stages} are supported"
+ )
+
+ self.datasets, self.dataloaders = self._get_datasets_and_dataloaders(
+ datasets,
+ dataloaders,
+ trainer,
+ dataloader_kwargs=dict(
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ pin_memory=True,
+ shuffle=False,
+ ),
+ )
+
+ # set the model to eval mode
+ pl_module.eval()
+ # get the retriever
+ retriever: GoldenRetriever = pl_module.model
+
+ # here we will store the samples with predictions for each dataloader
+ dataloader_predictions = {}
+ # compute the passage embeddings index for each dataloader
+ for dataloader_idx, dataloader in enumerate(self.dataloaders):
+ current_dataset: GoldenRetrieverDataset = self.datasets[dataloader_idx]
+ logger.info(
+ f"Computing passage embeddings for dataset {current_dataset.name}"
+ )
+ # passages = self._get_passages_dataloader(current_dataset, trainer)
+
+ tokenizer = current_dataset.tokenizer
+
+ def collate_fn(x):
+ return ModelInputs(
+ tokenizer(
+ x,
+ truncation=True,
+ padding=True,
+ max_length=current_dataset.max_passage_length,
+ return_tensors="pt",
+ )
+ )
+
+ # check if we need to reindex the passages and
+ # also if we need to load the retriever from disk
+ if (self.retriever_dir is not None and trainer.current_epoch == 0) or (
+ self.retriever_dir is not None and stage == RunningStage.TESTING
+ ):
+ force_reindex = False
+ else:
+ force_reindex = self.force_reindex
+
+ if (
+ not force_reindex
+ and self.retriever_dir is not None
+ and stage == RunningStage.TESTING
+ ):
+ retriever = retriever.from_pretrained(self.retriever_dir)
+ # set the retriever to eval mode if we are loading it from disk
+
+ # you never know :)
+ retriever.eval()
+
+ retriever.index(
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ max_length=current_dataset.max_passage_length,
+ collate_fn=collate_fn,
+ precision=self.precision,
+ compute_on_cpu=False,
+ force_reindex=force_reindex,
+ )
+
+ # pl_module_original_device = pl_module.device
+ # if (
+ # and pl_module.device.type == "cuda"
+ # ):
+ # pl_module.to("cpu")
+
+ # now compute the question embeddings and compute the top-k accuracy
+ predictions = []
+ start = time.time()
+ for batch in tqdm(
+ dataloader,
+ desc=f"Computing predictions for dataset {current_dataset.name}",
+ ):
+ batch = batch.to(pl_module.device)
+ # get the top-k indices
+ retriever_output = retriever.retrieve(
+ **batch.questions, k=self.k, precision=self.precision
+ )
+ # compute recall at k
+ for batch_idx, retrieved_samples in enumerate(retriever_output):
+ # get the positive passages
+ gold_passages = batch["positives"][batch_idx]
+ # get the index of the gold passages in the retrieved passages
+ gold_passage_indices = [
+ retriever.get_index_from_passage(passage)
+ for passage in gold_passages
+ ]
+ retrieved_indices = [r.index for r in retrieved_samples]
+ retrieved_passages = [r.label for r in retrieved_samples]
+ retrieved_scores = [r.score for r in retrieved_samples]
+ # correct predictions are the passages that are in the top-k and are gold
+ correct_indices = set(gold_passage_indices) & set(retrieved_indices)
+ # wrong predictions are the passages that are in the top-k and are not gold
+ wrong_indices = set(retrieved_indices) - set(gold_passage_indices)
+ # add the predictions to the list
+ prediction_output = dict(
+ sample_idx=batch.sample_idx[batch_idx],
+ gold=gold_passages,
+ predictions=retrieved_passages,
+ scores=retrieved_scores,
+ correct=[
+ retriever.get_passage_from_index(i) for i in correct_indices
+ ],
+ wrong=[
+ retriever.get_passage_from_index(i) for i in wrong_indices
+ ],
+ )
+ predictions.append(prediction_output)
+ end = time.time()
+ logger.info(f"Time to retrieve: {end - start:.2f} seconds")
+
+ dataloader_predictions[dataloader_idx] = predictions
+
+ # if pl_module_original_device != pl_module.device:
+ # pl_module.to(pl_module_original_device)
+
+ # return the predictions
+ return dataloader_predictions
+
+ # @staticmethod
+ # def _get_passages_dataloader(
+ # indexer: Optional[BaseIndexer] = None,
+ # dataset: Optional[GoldenRetrieverDataset] = None,
+ # trainer: Optional[pl.Trainer] = None,
+ # ):
+ # if indexer is None:
+ # logger.info(
+ # f"Indexer is None, creating indexer from passages not found in dataset {dataset.name}, computing them from the dataloaders"
+ # )
+ # # get the passages from the all the dataloader passage ids
+ # passages = set() # set to avoid duplicates
+ # for batch in trainer.train_dataloader:
+ # passages.update(
+ # [
+ # " ".join(map(str, [c for c in passage_ids.tolist() if c != 0]))
+ # for passage_ids in batch["passages"]["input_ids"]
+ # ]
+ # )
+ # for d in trainer.val_dataloaders:
+ # for batch in d:
+ # passages.update(
+ # [
+ # " ".join(
+ # map(str, [c for c in passage_ids.tolist() if c != 0])
+ # )
+ # for passage_ids in batch["passages"]["input_ids"]
+ # ]
+ # )
+ # for d in trainer.test_dataloaders:
+ # for batch in d:
+ # passages.update(
+ # [
+ # " ".join(
+ # map(str, [c for c in passage_ids.tolist() if c != 0])
+ # )
+ # for passage_ids in batch["passages"]["input_ids"]
+ # ]
+ # )
+ # passages = list(passages)
+ # else:
+ # passages = dataset.passages
+ # return passages
+
+
+class NegativeAugmentationCallback(GoldenRetrieverPredictionCallback):
+ """
+ Callback that computes the predictions of a retriever model on a dataset and computes the
+ negative examples for the training set.
+
+ Args:
+ k (:obj:`int`, `optional`, defaults to 100):
+ The number of top-k retrieved passages to
+ consider for the evaluation.
+ batch_size (:obj:`int`, `optional`, defaults to 32):
+ The batch size to use for the evaluation.
+ num_workers (:obj:`int`, `optional`, defaults to 0):
+ The number of workers to use for the evaluation.
+ force_reindex (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Whether to force the reindexing of the dataset.
+ retriever_dir (:obj:`Path`, `optional`):
+ The path to the retriever directory. If not specified, the retriever will be
+ initialized from scratch.
+ stages (:obj:`Set[str]`, `optional`):
+ The stages to run the callback on. If not specified, the callback will be run on
+ train, validation and test.
+ other_callbacks (:obj:`List[DictConfig]`, `optional`):
+ A list of other callbacks to run on the same stages.
+ dataset (:obj:`Union[DictConfig, BaseDataset]`, `optional`):
+ The dataset to use for the evaluation. If not specified, the dataset will be
+ initialized from scratch.
+ metrics_to_monitor (:obj:`List[str]`, `optional`):
+ The metrics to monitor for the evaluation.
+ threshold (:obj:`float`, `optional`, defaults to 0.8):
+ The threshold to consider. If the recall score of the retriever is above the
+ threshold, the negative examples will be added to the training set.
+ max_negatives (:obj:`int`, `optional`, defaults to 5):
+ The maximum number of negative examples to add to the training set.
+ add_with_probability (:obj:`float`, `optional`, defaults to 1.0):
+ The probability with which to add the negative examples to the training set.
+ refresh_every_n_epochs (:obj:`int`, `optional`, defaults to 1):
+ The number of epochs after which to refresh the index.
+ """
+
+ def __init__(
+ self,
+ k: int = 100,
+ batch_size: int = 32,
+ num_workers: int = 0,
+ force_reindex: bool = False,
+ retriever_dir: Optional[Path] = None,
+ stages: Set[Union[str, RunningStage]] = None,
+ other_callbacks: Optional[List[DictConfig]] = None,
+ dataset: Optional[Union[DictConfig, BaseDataset]] = None,
+ metrics_to_monitor: List[str] = None,
+ threshold: float = 0.8,
+ max_negatives: int = 5,
+ add_with_probability: float = 1.0,
+ refresh_every_n_epochs: int = 1,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(
+ k=k,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ force_reindex=force_reindex,
+ retriever_dir=retriever_dir,
+ stages=stages,
+ other_callbacks=other_callbacks,
+ dataset=dataset,
+ *args,
+ **kwargs,
+ )
+ if metrics_to_monitor is None:
+ metrics_to_monitor = ["val_loss"]
+ self.metrics_to_monitor = metrics_to_monitor
+ self.threshold = threshold
+ self.max_negatives = max_negatives
+ self.add_with_probability = add_with_probability
+ self.refresh_every_n_epochs = refresh_every_n_epochs
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ trainer: pl.Trainer,
+ pl_module: pl.LightningModule,
+ *args,
+ **kwargs,
+ ) -> dict:
+ """
+ Computes the predictions of a retriever model on a dataset and computes the negative
+ examples for the training set.
+
+ Args:
+ trainer (:obj:`pl.Trainer`):
+ The trainer object.
+ pl_module (:obj:`pl.LightningModule`):
+ The lightning module.
+
+ Returns:
+ A dictionary containing the negative examples.
+ """
+ stage = trainer.state.stage
+ if stage not in self.stages:
+ return {}
+
+        # `metrics_to_monitor` can be a single metric name or a list of names
+        if isinstance(self.metrics_to_monitor, str):
+            self.metrics_to_monitor = [self.metrics_to_monitor]
+
+        if all(
+            metric not in trainer.logged_metrics for metric in self.metrics_to_monitor
+        ):
+            raise ValueError(
+                f"None of the metrics in `{self.metrics_to_monitor}` were found in "
+                f"trainer.logged_metrics. "
+                f"Available metrics: {trainer.logged_metrics.keys()}"
+            )
+        # skip the augmentation if every monitored metric is still below the threshold
+        if all(
+            trainer.logged_metrics[metric] < self.threshold
+            for metric in self.metrics_to_monitor
+            if metric in trainer.logged_metrics
+        ):
+            return {}
+
+        if trainer.current_epoch % self.refresh_every_n_epochs != 0:
+            return {}
+
+ logger.info(
+ f"At least one metric from {self.metrics_to_monitor} is above threshold "
+ f"{self.threshold}. Computing hard negatives."
+ )
+
+ # make a copy of the dataset to avoid modifying the original one
+ trainer.datamodule.train_dataset.hn_manager = None
+ dataset_copy = deepcopy(trainer.datamodule.train_dataset)
+ predictions = super().__call__(
+ trainer,
+ pl_module,
+ datasets=dataset_copy,
+ dataloaders=DataLoader(
+ dataset_copy.to_torch_dataset(),
+ shuffle=False,
+ batch_size=None,
+ num_workers=self.num_workers,
+ pin_memory=True,
+ collate_fn=lambda x: x,
+ ),
+ *args,
+ **kwargs,
+ )
+ logger.info(f"Computing hard negatives for epoch {trainer.current_epoch}")
+ # predictions is a dict with the dataloader index as key and the predictions as value
+ # since we only have one dataloader, we can get the predictions directly
+ predictions = list(predictions.values())[0]
+ # store the predictions in a dictionary for faster access based on the sample index
+ hard_negatives_list = {}
+ for prediction in tqdm(predictions, desc="Collecting hard negatives"):
+ if random.random() < 1 - self.add_with_probability:
+ continue
+ top_k_passages = prediction["predictions"]
+ gold_passages = prediction["gold"]
+ # get the ids of the max_negatives wrong passages with the highest similarity
+ wrong_passages = [
+ passage_id
+ for passage_id in top_k_passages
+ if passage_id not in gold_passages
+ ][: self.max_negatives]
+ hard_negatives_list[prediction["sample_idx"]] = wrong_passages
+
+ trainer.datamodule.train_dataset.hn_manager = HardNegativesManager(
+ tokenizer=trainer.datamodule.train_dataset.tokenizer,
+ max_length=trainer.datamodule.train_dataset.max_passage_length,
+ data=hard_negatives_list,
+ )
+
+ # normalize predictions as in the original GoldenRetrieverPredictionCallback
+ predictions = {0: predictions}
+ return predictions
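+
+
+# Illustrative usage sketch (assumptions for illustration only, not the project's
+# actual wiring): the callback is normally instantiated from the Hydra config, but
+# it can also be attached to a Trainer directly.
+#
+#   callback = NegativeAugmentationCallback(
+#       k=100,
+#       batch_size=64,
+#       metrics_to_monitor=["validate_recall@100"],
+#       threshold=0.8,
+#       max_negatives=20,
+#       stages=["validate"],
+#   )
+#   trainer = pl.Trainer(max_steps=25_000, callbacks=[callback])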
diff --git a/relik/retriever/callbacks/utils_callbacks.py b/relik/retriever/callbacks/utils_callbacks.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba73e0d9ee02d9e1424611551befc002bdaaecf3
--- /dev/null
+++ b/relik/retriever/callbacks/utils_callbacks.py
@@ -0,0 +1,287 @@
+import json
+import logging
+import os
+from pathlib import Path
+from typing import Any, Dict, Optional, Union
+
+import lightning as pl
+import torch
+from lightning.pytorch.trainer.states import RunningStage
+
+from relik.common.log import get_console_logger, get_logger
+from relik.retriever.callbacks.base import NLPTemplateCallback, PredictionCallback
+from relik.retriever.pytorch_modules.hf import GoldenRetrieverModel
+
+console_logger = get_console_logger()
+logger = get_logger(__name__, level=logging.INFO)
+
+
+class SavePredictionsCallback(NLPTemplateCallback):
+ def __init__(
+ self,
+ saving_dir: Optional[Union[str, os.PathLike]] = None,
+ verbose: bool = False,
+ *args,
+ **kwargs,
+ ):
+ super().__init__()
+ self.saving_dir = saving_dir
+ self.verbose = verbose
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ trainer: pl.Trainer,
+ pl_module: pl.LightningModule,
+ predictions: Dict,
+ callback: PredictionCallback,
+ *args,
+ **kwargs,
+ ) -> dict:
+ # write the predictions to a file inside the experiment folder
+ if self.saving_dir is None and trainer.logger is None:
+ logger.info(
+ "You need to specify an output directory (`saving_dir`) or a logger to save the predictions.\n"
+ "Skipping saving predictions."
+ )
+ return
+ datasets = callback.datasets
+ for dataloader_idx, dataloader_predictions in predictions.items():
+ # save to file
+ if self.saving_dir is not None:
+ prediction_folder = Path(self.saving_dir)
+ else:
+ try:
+ prediction_folder = (
+ Path(trainer.logger.experiment.dir) / "predictions"
+ )
+ except Exception:
+ logger.info(
+ "You need to specify an output directory (`saving_dir`) or a logger to save the predictions.\n"
+ "Skipping saving predictions."
+ )
+ return
+ prediction_folder.mkdir(exist_ok=True)
+ predictions_path = (
+ prediction_folder
+ / f"{datasets[dataloader_idx].name}_{dataloader_idx}.json"
+ )
+ if self.verbose:
+ logger.info(f"Saving predictions to {predictions_path}")
+ with open(predictions_path, "w") as f:
+ for prediction in dataloader_predictions:
+ for k, v in prediction.items():
+ if isinstance(v, set):
+ prediction[k] = list(v)
+ f.write(json.dumps(prediction) + "\n")
+
+
+class ResetModelCallback(pl.Callback):
+ def __init__(
+ self,
+ question_encoder: str,
+ passage_encoder: Optional[str] = None,
+ verbose: bool = True,
+ ) -> None:
+ super().__init__()
+ self.question_encoder = question_encoder
+ self.passage_encoder = passage_encoder
+ self.verbose = verbose
+
+ def on_train_epoch_start(
+ self, trainer: pl.Trainer, pl_module: pl.LightningModule, *args, **kwargs
+ ) -> None:
+ if trainer.current_epoch == 0:
+ if self.verbose:
+ logger.info("Current epoch is 0, skipping resetting model")
+ return
+
+ if self.verbose:
+ logger.info("Resetting model, optimizer and lr scheduler")
+ # reload model from scratch
+ previous_device = pl_module.device
+ trainer.model.model.question_encoder = GoldenRetrieverModel.from_pretrained(
+ self.question_encoder
+ )
+ trainer.model.model.question_encoder.to(previous_device)
+ if self.passage_encoder is not None:
+ trainer.model.model.passage_encoder = GoldenRetrieverModel.from_pretrained(
+ self.passage_encoder
+ )
+ trainer.model.model.passage_encoder.to(previous_device)
+
+ trainer.strategy.setup_optimizers(trainer)
+
+
+class FreeUpIndexerVRAMCallback(pl.Callback):
+ def __call__(
+ self,
+ pl_module: pl.LightningModule,
+ *args,
+ **kwargs,
+ ) -> Any:
+ logger.info("Freeing up GPU memory")
+
+ # remove the index from the GPU memory
+ # remove the embeddings from the GPU memory first
+ if pl_module.model.document_index is not None:
+ if pl_module.model.document_index.embeddings is not None:
+ pl_module.model.document_index.embeddings.cpu()
+ pl_module.model.document_index.embeddings = None
+
+ import gc
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def on_train_epoch_start(
+ self, trainer: pl.Trainer, pl_module: pl.LightningModule, *args, **kwargs
+ ) -> None:
+ return self(pl_module)
+
+ def on_test_epoch_start(
+ self, trainer: pl.Trainer, pl_module: pl.LightningModule, *args, **kwargs
+ ) -> None:
+ return self(pl_module)
+
+
+class ShuffleTrainDatasetCallback(pl.Callback):
+ def __init__(self, seed: int = 42, verbose: bool = True) -> None:
+ super().__init__()
+ self.seed = seed
+ self.verbose = verbose
+ self.previous_epoch = -1
+
+ def on_validation_epoch_end(self, trainer: pl.Trainer, *args, **kwargs):
+ if self.verbose:
+ if trainer.current_epoch != self.previous_epoch:
+ logger.info(f"Shuffling train dataset at epoch {trainer.current_epoch}")
+
+ # logger.info(f"Shuffling train dataset at epoch {trainer.current_epoch}")
+ if trainer.current_epoch != self.previous_epoch:
+ trainer.datamodule.train_dataset.shuffle_data(
+ seed=self.seed + trainer.current_epoch + 1
+ )
+ self.previous_epoch = trainer.current_epoch
+
+
+class PrefetchTrainDatasetCallback(pl.Callback):
+ def __init__(self, verbose: bool = True) -> None:
+ super().__init__()
+ self.verbose = verbose
+ # self.previous_epoch = -1
+
+ def on_validation_epoch_end(self, trainer: pl.Trainer, *args, **kwargs):
+ if trainer.datamodule.train_dataset.prefetch_batches:
+ if self.verbose:
+ # if trainer.current_epoch != self.previous_epoch:
+ logger.info(
+ f"Prefetching train dataset at epoch {trainer.current_epoch}"
+ )
+ # if trainer.current_epoch != self.previous_epoch:
+ trainer.datamodule.train_dataset.prefetch()
+
+
+class SubsampleTrainDatasetCallback(pl.Callback):
+ def __init__(self, seed: int = 43, verbose: bool = True) -> None:
+ super().__init__()
+ self.seed = seed
+ self.verbose = verbose
+
+ def on_validation_epoch_end(self, trainer: pl.Trainer, *args, **kwargs):
+ if self.verbose:
+ logger.info(f"Subsampling train dataset at epoch {trainer.current_epoch}")
+ trainer.datamodule.train_dataset.random_subsample(
+ seed=self.seed + trainer.current_epoch + 1
+ )
+
+
+class SaveRetrieverCallback(pl.Callback):
+ def __init__(
+ self,
+ saving_dir: Optional[Union[str, os.PathLike]] = None,
+ verbose: bool = True,
+ *args,
+ **kwargs,
+ ):
+ super().__init__()
+ self.saving_dir = saving_dir
+ self.verbose = verbose
+ self.free_up_indexer_callback = FreeUpIndexerVRAMCallback()
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ trainer: pl.Trainer,
+ pl_module: pl.LightningModule,
+ *args,
+ **kwargs,
+ ):
+ if self.saving_dir is None and trainer.logger is None:
+ logger.info(
+ "You need to specify an output directory (`saving_dir`) or a logger to save the retriever.\n"
+ "Skipping saving retriever."
+ )
+ return
+ if self.saving_dir is not None:
+ retriever_folder = Path(self.saving_dir)
+ else:
+ try:
+ retriever_folder = Path(trainer.logger.experiment.dir) / "retriever"
+ except Exception:
+ logger.info(
+ "You need to specify an output directory (`saving_dir`) or a logger to save the retriever.\n"
+ "Skipping saving retriever."
+ )
+ return
+ retriever_folder.mkdir(exist_ok=True, parents=True)
+ if self.verbose:
+ logger.info(f"Saving retriever to {retriever_folder}")
+ pl_module.model.save_pretrained(retriever_folder)
+
+ def on_save_checkpoint(
+ self,
+ trainer: pl.Trainer,
+ pl_module: pl.LightningModule,
+ checkpoint: Dict[str, Any],
+ ):
+ self(trainer, pl_module)
+ # self.free_up_indexer_callback(pl_module)
+
+
+class SampleNegativesDatasetCallback(pl.Callback):
+ def __init__(self, seed: int = 42, verbose: bool = True) -> None:
+ super().__init__()
+ self.seed = seed
+ self.verbose = verbose
+
+ def on_validation_epoch_end(self, trainer: pl.Trainer, *args, **kwargs):
+        if self.verbose:
+            logger.info(
+                f"Sampling negatives for train dataset at epoch {trainer.current_epoch}"
+            )
+ trainer.datamodule.train_dataset.sample_dataset_negatives(
+ seed=self.seed + trainer.current_epoch
+ )
+
+
+class SubsampleDataCallback(pl.Callback):
+ def __init__(self, seed: int = 42, verbose: bool = True) -> None:
+ super().__init__()
+ self.seed = seed
+ self.verbose = verbose
+
+ def on_validation_epoch_start(self, trainer: pl.Trainer, *args, **kwargs):
+        if self.verbose:
+            logger.info(
+                f"Subsampling data for train dataset at epoch {trainer.current_epoch}"
+            )
+ if trainer.state.stage == RunningStage.SANITY_CHECKING:
+ return
+ trainer.datamodule.train_dataset.subsample_data(
+ seed=self.seed + trainer.current_epoch
+ )
+
+ def on_fit_start(self, trainer: pl.Trainer, *args, **kwargs):
+        if self.verbose:
+            logger.info(
+                f"Subsampling data for train dataset at epoch {trainer.current_epoch}"
+            )
+ trainer.datamodule.train_dataset.subsample_data(
+ seed=self.seed + trainer.current_epoch
+ )
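+
+
+# Illustrative sketch (assumed setup, not the project's actual wiring): the utility
+# callbacks above are meant to be passed to the Lightning Trainer alongside the
+# prediction callbacks.
+#
+#   trainer = pl.Trainer(
+#       callbacks=[
+#           SaveRetrieverCallback(saving_dir="experiments/run/retriever"),
+#           FreeUpIndexerVRAMCallback(),
+#           ShuffleTrainDatasetCallback(seed=42),
+#       ]
+#   )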
diff --git a/relik/retriever/common/__init__.py b/relik/retriever/common/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/relik/retriever/common/model_inputs.py b/relik/retriever/common/model_inputs.py
new file mode 100644
index 0000000000000000000000000000000000000000..06b28f74e9594df32f98f8a32a0b46177db54062
--- /dev/null
+++ b/relik/retriever/common/model_inputs.py
@@ -0,0 +1,55 @@
+from __future__ import annotations
+
+from collections import UserDict
+from typing import Any, Union
+
+import torch
+from lightning.fabric.utilities import move_data_to_device
+
+from relik.common.log import get_console_logger
+
+logger = get_console_logger()
+
+
+class ModelInputs(UserDict):
+ """Model input dictionary wrapper."""
+
+ def __getattr__(self, item: str):
+ try:
+ return self.data[item]
+ except KeyError:
+ raise AttributeError(f"`ModelInputs` has no attribute `{item}`")
+
+ def __getitem__(self, item: str) -> Any:
+ return self.data[item]
+
+ def __getstate__(self):
+ return {"data": self.data}
+
+ def __setstate__(self, state):
+ if "data" in state:
+ self.data = state["data"]
+
+ def keys(self):
+ """A set-like object providing a view on D's keys."""
+ return self.data.keys()
+
+ def values(self):
+ """An object providing a view on D's values."""
+ return self.data.values()
+
+ def items(self):
+ """A set-like object providing a view on D's items."""
+ return self.data.items()
+
+ def to(self, device: Union[str, torch.device]) -> ModelInputs:
+ """
+ Send all tensors values to device.
+ Args:
+ device (`str` or `torch.device`): The device to put the tensors on.
+ Returns:
+ :class:`tokenizers.ModelInputs`: The same instance of :class:`~tokenizers.ModelInputs`
+ after modification.
+ """
+ self.data = move_data_to_device(self.data, device)
+ return self
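+
+
+# Minimal usage sketch (illustrative only; the tokenizer name is an assumption):
+#
+#   from transformers import AutoTokenizer
+#   tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
+#   inputs = ModelInputs(tokenizer(["a question"], return_tensors="pt"))
+#   inputs = inputs.to("cuda")      # moves every tensor value to the GPU
+#   ids = inputs.input_ids          # attribute-style access to dictionary keys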
diff --git a/relik/retriever/common/sampler.py b/relik/retriever/common/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..024c57b23da6db71dd76929226b005f75b9e98f5
--- /dev/null
+++ b/relik/retriever/common/sampler.py
@@ -0,0 +1,108 @@
+import math
+
+from torch.utils.data.sampler import BatchSampler, Sampler, SubsetRandomSampler
+
+
+def identity(x):
+ return x
+
+
+class SortedSampler(Sampler):
+ """
+ Samples elements sequentially, always in the same order.
+
+ Args:
+        data (:obj:`Iterable`):
+            Iterable data.
+        sort_key (:obj:`Callable`):
+ Specifies a function of one argument that is used to
+ extract a numerical comparison key from each list element.
+
+ Example:
+ >>> list(SortedSampler(range(10), sort_key=lambda i: -i))
+ [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
+
+ """
+
+ def __init__(self, data, sort_key=identity):
+ super().__init__(data)
+ self.data = data
+ self.sort_key = sort_key
+ zip_ = [(i, self.sort_key(row)) for i, row in enumerate(self.data)]
+ zip_ = sorted(zip_, key=lambda r: r[1])
+ self.sorted_indexes = [item[0] for item in zip_]
+
+ def __iter__(self):
+ return iter(self.sorted_indexes)
+
+ def __len__(self):
+ return len(self.data)
+
+
+class BucketBatchSampler(BatchSampler):
+ """
+ `BucketBatchSampler` toggles between `sampler` batches and sorted batches.
+ Typically, the `sampler` will be a `RandomSampler` allowing the user to toggle between
+ random batches and sorted batches. A larger `bucket_size_multiplier` is more sorted and vice
+ versa.
+ Background:
+ ``BucketBatchSampler`` is similar to a ``BucketIterator`` found in popular libraries like
+ ``AllenNLP`` and ``torchtext``. A ``BucketIterator`` pools together examples with a similar
+ size length to reduce the padding required for each batch while maintaining some noise
+ through bucketing.
+ **AllenNLP Implementation:**
+ https://github.com/allenai/allennlp/blob/master/allennlp/data/iterators/bucket_iterator.py
+ **torchtext Implementation:**
+ https://github.com/pytorch/text/blob/master/torchtext/data/iterator.py#L225
+
+    Args:
+        sampler (:obj:`torch.utils.data.sampler.Sampler`):
+            Base sampler providing the indices to bucket and batch.
+        batch_size (:obj:`int`):
+            Size of mini-batch.
+        drop_last (:obj:`bool`, optional, defaults to :obj:`False`):
+            If :obj:`True`, the sampler will drop the last batch if its size would be less than `batch_size`.
+        sort_key (:obj:`Callable`, optional, defaults to :obj:`identity`):
+            Callable to specify a comparison key for sorting.
+        bucket_size_multiplier (:obj:`int`, optional, defaults to :obj:`100`):
+            Buckets are of size `batch_size * bucket_size_multiplier`.
+ Example:
+ >>> from torchnlp.random import set_seed
+ >>> set_seed(123)
+ >>>
+ >>> from torch.utils.data.sampler import SequentialSampler
+ >>> sampler = SequentialSampler(list(range(10)))
+ >>> list(BucketBatchSampler(sampler, batch_size=3, drop_last=False))
+ [[6, 7, 8], [0, 1, 2], [3, 4, 5], [9]]
+ >>> list(BucketBatchSampler(sampler, batch_size=3, drop_last=True))
+ [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
+
+ """
+
+ def __init__(
+ self,
+ sampler,
+ batch_size,
+ drop_last: bool = False,
+ sort_key=identity,
+ bucket_size_multiplier=100,
+ ):
+ super().__init__(sampler, batch_size, drop_last)
+ self.sort_key = sort_key
+ _bucket_size = batch_size * bucket_size_multiplier
+ if hasattr(sampler, "__len__"):
+ _bucket_size = min(_bucket_size, len(sampler))
+ self.bucket_sampler = BatchSampler(sampler, _bucket_size, False)
+
+ def __iter__(self):
+ for bucket in self.bucket_sampler:
+ sorted_sampler = SortedSampler(bucket, self.sort_key)
+ for batch in SubsetRandomSampler(
+ list(BatchSampler(sorted_sampler, self.batch_size, self.drop_last))
+ ):
+ yield [bucket[i] for i in batch]
+
+ def __len__(self):
+ if self.drop_last:
+ return len(self.sampler) // self.batch_size
+ else:
+ return math.ceil(len(self.sampler) / self.batch_size)
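+
+
+# Usage sketch (illustrative only; `dataset` and `collate_fn` are assumed to exist):
+# bucketing by tokenized length keeps the padding inside each mini-batch small.
+#
+#   from torch.utils.data import DataLoader, RandomSampler
+#
+#   batch_sampler = BucketBatchSampler(
+#       RandomSampler(dataset),
+#       batch_size=32,
+#       sort_key=lambda i: len(dataset[i]["input_ids"]),
+#   )
+#   loader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn)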
diff --git a/relik/retriever/conf/data/aida_dataset.yaml b/relik/retriever/conf/data/aida_dataset.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..22fcc6458a2dc757b569baef37846751dd3c1c7a
--- /dev/null
+++ b/relik/retriever/conf/data/aida_dataset.yaml
@@ -0,0 +1,47 @@
+shared_params:
+ passages_path: null
+ max_passage_length: 64
+ passage_batch_size: 64
+ question_batch_size: 64
+ use_topics: False
+
+datamodule:
+ _target_: relik.retriever.lightning_modules.pl_data_modules.GoldenRetrieverPLDataModule
+ datasets:
+ train:
+ _target_: relik.retriever.data.datasets.AidaInBatchNegativesDataset
+ name: "train"
+ path: null
+ tokenizer: ${model.language_model}
+ max_passage_length: ${data.shared_params.max_passage_length}
+ question_batch_size: ${data.shared_params.question_batch_size}
+ passage_batch_size: ${data.shared_params.passage_batch_size}
+ subsample_strategy: null
+ subsample_portion: 0.1
+ shuffle: True
+ use_topics: ${data.shared_params.use_topics}
+
+ val:
+ - _target_: relik.retriever.data.datasets.AidaInBatchNegativesDataset
+ name: "val"
+ path: null
+ tokenizer: ${model.language_model}
+ max_passage_length: ${data.shared_params.max_passage_length}
+ question_batch_size: ${data.shared_params.question_batch_size}
+ passage_batch_size: ${data.shared_params.passage_batch_size}
+ use_topics: ${data.shared_params.use_topics}
+
+ test:
+ - _target_: relik.retriever.data.datasets.AidaInBatchNegativesDataset
+ name: "test"
+ path: null
+ tokenizer: ${model.language_model}
+ max_passage_length: ${data.shared_params.max_passage_length}
+ question_batch_size: ${data.shared_params.question_batch_size}
+ passage_batch_size: ${data.shared_params.passage_batch_size}
+ use_topics: ${data.shared_params.use_topics}
+
+ num_workers:
+ train: 4
+ val: 4
+ test: 4
diff --git a/relik/retriever/conf/data/dataset_v2.yaml b/relik/retriever/conf/data/dataset_v2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6040616d1f92182c2d002c065a7b89dc240f96b3
--- /dev/null
+++ b/relik/retriever/conf/data/dataset_v2.yaml
@@ -0,0 +1,43 @@
+shared_params:
+ passages_path: null
+ max_passage_length: 64
+ passage_batch_size: 64
+ question_batch_size: 64
+
+datamodule:
+ _target_: relik.retriever.lightning_modules.pl_data_modules.GoldenRetrieverPLDataModule
+ datasets:
+ train:
+ _target_: relik.retriever.data.datasets.InBatchNegativesDataset
+ name: "train"
+ path: null
+ tokenizer: ${model.language_model}
+ max_passage_length: ${data.shared_params.max_passage_length}
+ question_batch_size: ${data.shared_params.question_batch_size}
+ passage_batch_size: ${data.shared_params.passage_batch_size}
+ subsample_strategy: null
+ subsample_portion: 0.1
+ shuffle: True
+
+ val:
+ - _target_: relik.retriever.data.datasets.InBatchNegativesDataset
+ name: "val"
+ path: null
+ tokenizer: ${model.language_model}
+ max_passage_length: ${data.shared_params.max_passage_length}
+ question_batch_size: ${data.shared_params.question_batch_size}
+ passage_batch_size: ${data.shared_params.passage_batch_size}
+
+ test:
+ - _target_: relik.retriever.data.datasets.InBatchNegativesDataset
+ name: "test"
+ path: null
+ tokenizer: ${model.language_model}
+ max_passage_length: ${data.shared_params.max_passage_length}
+ question_batch_size: ${data.shared_params.question_batch_size}
+ passage_batch_size: ${data.shared_params.passage_batch_size}
+
+ num_workers:
+ train: 0
+ val: 0
+ test: 0
diff --git a/relik/retriever/conf/data/dpr_like.yaml b/relik/retriever/conf/data/dpr_like.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d316b9868b0e64ea3e58725fc56549a2b31436be
--- /dev/null
+++ b/relik/retriever/conf/data/dpr_like.yaml
@@ -0,0 +1,31 @@
+datamodule:
+ _target_: relik.retriever.goldenretriever.lightning_modules.pl_data_modules.PLDataModule
+ tokenizer: ${model.language_model}
+ datasets:
+ train:
+ _target_: relik.retriever.data.dpr.datasets.DPRDataset
+ name: "train"
+ passages_path: ${data_overrides.passages_path}
+ path: ${data_overrides.train_path}
+
+ val:
+ - _target_: relik.retriever.data.dpr.datasets.DPRDataset
+ name: "val"
+ passages_path: ${data_overrides.passages_path}
+ path: ${data_overrides.val_path}
+
+ test:
+ - _target_: relik.retriever.data.dpr.datasets.DPRDataset
+ name: "test"
+ passages_path: ${data_overrides.passages_path}
+ path: ${data_overrides.test_path}
+
+ batch_sizes:
+ train: 32
+ val: 64
+ test: 64
+
+ num_workers:
+ train: 4
+ val: 4
+ test: 4
diff --git a/relik/retriever/conf/data/in_batch_negatives.yaml b/relik/retriever/conf/data/in_batch_negatives.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1a90cb4619604e9d93df47f490f5948ebdfbf312
--- /dev/null
+++ b/relik/retriever/conf/data/in_batch_negatives.yaml
@@ -0,0 +1,48 @@
+shared_params:
+ passages_path: null
+ max_passage_length: 64
+ prefetch_batches: True
+ use_topics: False
+
+datamodule:
+ _target_: goldenretriever.lightning_modules.pl_data_modules.PLDataModule
+ tokenizer: ${model.language_model}
+ datasets:
+ train:
+ _target_: goldenretriever.data.dpr.datasets.InBatchNegativesDPRDataset
+ name: "train"
+ path: null
+ passages_path: ${data.shared_params.passages_path}
+ max_passage_length: ${data.shared_params.max_passage_length}
+ prefetch_batches: ${data.shared_params.prefetch_batches}
+ subsample: null
+ shuffle: True
+ use_topics: ${data.shared_params.use_topics}
+
+ val:
+ - _target_: goldenretriever.data.dpr.datasets.InBatchNegativesDPRDataset
+ name: "val"
+ path: null
+ passages_path: ${data.shared_params.passages_path}
+ max_passage_length: ${data.shared_params.max_passage_length}
+ prefetch_batches: ${data.shared_params.prefetch_batches}
+ use_topics: ${data.shared_params.use_topics}
+
+ test:
+ - _target_: goldenretriever.data.dpr.datasets.InBatchNegativesDPRDataset
+ name: "test"
+ path: null
+ passages_path: ${data.shared_params.passages_path}
+ max_passage_length: ${data.shared_params.max_passage_length}
+ prefetch_batches: ${data.shared_params.prefetch_batches}
+ use_topics: ${data.shared_params.use_topics}
+
+ batch_sizes:
+ train: 64
+ val: 64
+ test: 64
+
+ num_workers:
+ train: 4
+ val: 4
+ test: 4
diff --git a/relik/retriever/conf/data/iterable_in_batch_negatives.yaml b/relik/retriever/conf/data/iterable_in_batch_negatives.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..397fda31c4b42ef5cd22c10805200f4d19cd590d
--- /dev/null
+++ b/relik/retriever/conf/data/iterable_in_batch_negatives.yaml
@@ -0,0 +1,52 @@
+shared_params:
+ passages_path: null
+ max_passage_length: 64
+ max_passages_per_batch: 64
+ max_questions_per_batch: 64
+ prefetch_batches: True
+ use_topics: False
+
+datamodule:
+ _target_: relik.retriever.lightning_modules.pl_data_modules.PLDataModule
+ tokenizer: ${model.language_model}
+ datasets:
+ train:
+ _target_: relik.retriever.data.dpr.datasets.InBatchNegativesDPRIterableDataset
+ name: "train"
+ path: null
+ passages_path: ${data.shared_params.passages_path}
+ max_passage_length: ${data.shared_params.max_passage_length}
+ max_questions_per_batch: ${data.shared_params.max_questions_per_batch}
+ max_passages_per_batch: ${data.shared_params.max_passages_per_batch}
+ prefetch_batches: ${data.shared_params.prefetch_batches}
+ subsample: null
+ random_subsample: False
+ shuffle: True
+ use_topics: ${data.shared_params.use_topics}
+
+ val:
+ - _target_: relik.retriever.data.dpr.datasets.InBatchNegativesDPRIterableDataset
+ name: "val"
+ path: null
+ passages_path: ${data.shared_params.passages_path}
+ max_passage_length: ${data.shared_params.max_passage_length}
+ max_questions_per_batch: ${data.shared_params.max_questions_per_batch}
+ max_passages_per_batch: ${data.shared_params.max_passages_per_batch}
+ prefetch_batches: ${data.shared_params.prefetch_batches}
+ use_topics: ${data.shared_params.use_topics}
+
+ test:
+ - _target_: relik.retriever.data.dpr.datasets.InBatchNegativesDPRIterableDataset
+ name: "test"
+ path: null
+ passages_path: ${data.shared_params.passages_path}
+ max_passage_length: ${data.shared_params.max_passage_length}
+ max_questions_per_batch: ${data.shared_params.max_questions_per_batch}
+ max_passages_per_batch: ${data.shared_params.max_passages_per_batch}
+ prefetch_batches: ${data.shared_params.prefetch_batches}
+ use_topics: ${data.shared_params.use_topics}
+
+ num_workers:
+ train: 0
+ val: 0
+ test: 0
diff --git a/relik/retriever/conf/data/sampled_negatives.yaml b/relik/retriever/conf/data/sampled_negatives.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4f581a4a6488a7ad516aca08a5d0a0fba5906fe1
--- /dev/null
+++ b/relik/retriever/conf/data/sampled_negatives.yaml
@@ -0,0 +1,39 @@
+max_passages: 64
+
+datamodule:
+ _target_: relik.retriever.lightning_modules.pl_data_modules.PLDataModule
+ tokenizer: ${model.language_model}
+ datasets:
+ train:
+ _target_: relik.retriever.data.dpr.datasets.SampledNegativesDPRDataset
+ name: "train"
+ passages_path: ${data_overrides.passages_path}
+ max_passage_length: 64
+ max_passages: ${data.max_passages}
+ path: ${data_overrides.train_path}
+
+ val:
+ - _target_: relik.retriever.data.dpr.datasets.SampledNegativesDPRDataset
+ name: "val"
+ passages_path: ${data_overrides.passages_path}
+ max_passage_length: 64
+ max_passages: ${data.max_passages}
+ path: ${data_overrides.val_path}
+
+ test:
+ - _target_: relik.retriever.data.dpr.datasets.SampledNegativesDPRDataset
+ name: "test"
+ passages_path: ${data_overrides.passages_path}
+ max_passage_length: 64
+ max_passages: ${data.max_passages}
+ path: ${data_overrides.test_path}
+
+ batch_sizes:
+ train: 4
+ val: 64
+ test: 64
+
+ num_workers:
+ train: 4
+ val: 4
+ test: 4
diff --git a/relik/retriever/conf/finetune_iterable_in_batch.yaml b/relik/retriever/conf/finetune_iterable_in_batch.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7bcaa034ca3dca3e82d6efa6d6b674839f2ec880
--- /dev/null
+++ b/relik/retriever/conf/finetune_iterable_in_batch.yaml
@@ -0,0 +1,117 @@
+# Required to make the "experiments" dir the default one for the output of the models
+hydra:
+ run:
+ dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+model_name: ${model.language_model} # used to name the model in wandb
+project_name: relik-retriever # used to name the project in wandb
+
+defaults:
+ - _self_
+ - model: golden_retriever
+ - index: inmemory
+ - loss: nce_loss
+ - optimizer: radamw
+ - scheduler: linear_scheduler
+ - data: dataset_v2 # iterable_in_batch_negatives #dataset_v2
+ - logging: wandb_logging
+ - override hydra/job_logging: colorlog
+ - override hydra/hydra_logging: colorlog
+
+train:
+ # reproducibility
+ seed: 42
+ set_determinism_the_old_way: False
+ # torch parameters
+ float32_matmul_precision: "medium"
+ # if true, only test the model
+ only_test: False
+ # if provided, initialize the model with the weights from the checkpoint
+ pretrain_ckpt_path: null
+ # if provided, start training from the checkpoint
+ checkpoint_path: null
+
+ # task specific parameter
+ top_k: 100
+
+ # pl_trainer
+ pl_trainer:
+ _target_: lightning.Trainer
+ accelerator: gpu
+ devices: 1
+ num_nodes: 1
+ strategy: auto
+ accumulate_grad_batches: 1
+ gradient_clip_val: 1.0
+ val_check_interval: 1.0 # you can specify an int "n" here => validation every "n" steps
+ check_val_every_n_epoch: 1
+ max_epochs: 0
+ max_steps: 25_000
+ deterministic: True
+ fast_dev_run: False
+ precision: 16
+ reload_dataloaders_every_n_epochs: 1
+
+ early_stopping_callback:
+ # null
+ _target_: lightning.callbacks.EarlyStopping
+ monitor: validate_recall@${train.top_k}
+ mode: max
+ patience: 3
+
+ model_checkpoint_callback:
+ _target_: lightning.callbacks.ModelCheckpoint
+ monitor: validate_recall@${train.top_k}
+ mode: max
+ verbose: True
+ save_top_k: 1
+ save_last: False
+ filename: "checkpoint-validate_recall@${train.top_k}_{validate_recall@${train.top_k}:.4f}-epoch_{epoch:02d}"
+ auto_insert_metric_name: False
+
+ callbacks:
+ prediction_callback:
+ _target_: relik.retriever.callbacks.prediction_callbacks.GoldenRetrieverPredictionCallback
+ k: ${train.top_k}
+ batch_size: 64
+ precision: 16
+ index_precision: 16
+ other_callbacks:
+ - _target_: relik.retriever.callbacks.evaluation_callbacks.RecallAtKEvaluationCallback
+ k: ${train.top_k}
+ verbose: True
+ - _target_: relik.retriever.callbacks.evaluation_callbacks.RecallAtKEvaluationCallback
+ k: 50
+ verbose: True
+ prog_bar: False
+ - _target_: relik.retriever.callbacks.evaluation_callbacks.AvgRankingEvaluationCallback
+ k: ${train.top_k}
+ verbose: True
+ - _target_: relik.retriever.callbacks.utils_callbacks.SavePredictionsCallback
+
+ hard_negatives_callback:
+ _target_: relik.retriever.callbacks.prediction_callbacks.NegativeAugmentationCallback
+ k: ${train.top_k}
+ batch_size: 64
+ precision: 16
+ index_precision: 16
+ stages: [validate] #[validate, sanity_check]
+ metrics_to_monitor:
+ validate_recall@${train.top_k}
+ # - sanity_check_recall@${train.top_k}
+ threshold: 0.0
+ max_negatives: 20
+ add_with_probability: 1.0
+ refresh_every_n_epochs: 1
+ other_callbacks:
+ - _target_: relik.retriever.callbacks.evaluation_callbacks.AvgRankingEvaluationCallback
+ k: ${train.top_k}
+ verbose: True
+ prefix: "train"
+
+ utils_callbacks:
+ - _target_: relik.retriever.callbacks.utils_callbacks.SaveRetrieverCallback
+ - _target_: relik.retriever.callbacks.utils_callbacks.FreeUpIndexerVRAMCallback
+ # - _target_: relik.retriever.callbacks.utils_callbacks.ResetModelCallback
+ # question_encoder: ${model.pl_module.model.question_encoder}
+ # passage_encoder: ${model.pl_module.model.passage_encoder}
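+
+# Illustrative launch example (the entry-point module is an assumption, not verified;
+# paths are placeholders). Hydra overrides follow the structure of this config:
+#   python -m relik.retriever.trainer.train --config-name finetune_iterable_in_batch \
+#     data.shared_params.passages_path=data/passages.jsonl \
+#     data.datamodule.datasets.train.path=data/train.jsonl \
+#     data.datamodule.datasets.val.0.path=data/val.jsonl \
+#     data.datamodule.datasets.test.0.path=data/test.jsonl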
diff --git a/relik/retriever/conf/index/inmemory.yaml b/relik/retriever/conf/index/inmemory.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d77f5b79946a82384cb4da0c0f60b7b2700e9280
--- /dev/null
+++ b/relik/retriever/conf/index/inmemory.yaml
@@ -0,0 +1,4 @@
+_target_: relik.retriever.indexers.inmemory.InMemoryDocumentIndex
+documents: ${data.shared_params.passages_path}
+device: cuda
+precision: 16
diff --git a/relik/retriever/conf/logging/wandb_logging.yaml b/relik/retriever/conf/logging/wandb_logging.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1908d7e09789ca0b8e4973ec7f3ca5d47d460af3
--- /dev/null
+++ b/relik/retriever/conf/logging/wandb_logging.yaml
@@ -0,0 +1,16 @@
+# don't forget to run `wandb login` before the first use.
+
+log: True # set to False to avoid the logging
+
+wandb_arg:
+ _target_: lightning.loggers.WandbLogger
+ name: ${model_name}
+ project: ${project_name}
+ save_dir: ./
+ log_model: True
+ mode: "online"
+ entity: null
+
+watch:
+ log: "all"
+ log_freq: 100
diff --git a/relik/retriever/conf/loss/nce_loss.yaml b/relik/retriever/conf/loss/nce_loss.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..fe9246b88027fd0d499b9bc3b4beaa7937b23586
--- /dev/null
+++ b/relik/retriever/conf/loss/nce_loss.yaml
@@ -0,0 +1 @@
+_target_: relik.retriever.pytorch_modules.losses.MultiLabelNCELoss
diff --git a/relik/retriever/conf/loss/nll_loss.yaml b/relik/retriever/conf/loss/nll_loss.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1e0a5010025a4a6e9da382e5408af4a58cda6185
--- /dev/null
+++ b/relik/retriever/conf/loss/nll_loss.yaml
@@ -0,0 +1 @@
+_target_: torch.nn.NLLLoss
diff --git a/relik/retriever/conf/optimizer/adamw.yaml b/relik/retriever/conf/optimizer/adamw.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ff0f84e15ebd6c60e3e6e411c30e88fb910fdb29
--- /dev/null
+++ b/relik/retriever/conf/optimizer/adamw.yaml
@@ -0,0 +1,4 @@
+_target_: torch.optim.AdamW
+lr: 1e-5
+weight_decay: 0.01
+fused: False
diff --git a/relik/retriever/conf/optimizer/radam.yaml b/relik/retriever/conf/optimizer/radam.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b5d2a4ecf468327bda98fd205bf483b6828cf653
--- /dev/null
+++ b/relik/retriever/conf/optimizer/radam.yaml
@@ -0,0 +1,3 @@
+_target_: torch.optim.RAdam
+lr: 1e-5
+weight_decay: 0
diff --git a/relik/retriever/conf/optimizer/radamw.yaml b/relik/retriever/conf/optimizer/radamw.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6f1fc8c4696bf793180366a64baf107829cf7752
--- /dev/null
+++ b/relik/retriever/conf/optimizer/radamw.yaml
@@ -0,0 +1,3 @@
+_target_: relik.retriever.pytorch_modules.optim.RAdamW
+lr: 1e-5
+weight_decay: 0.01
diff --git a/relik/retriever/conf/pretrain_iterable_in_batch.yaml b/relik/retriever/conf/pretrain_iterable_in_batch.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c003d612cf718f4b4b84e866009a2944a5c22d9a
--- /dev/null
+++ b/relik/retriever/conf/pretrain_iterable_in_batch.yaml
@@ -0,0 +1,114 @@
+# Required to make the "experiments" dir the default one for the output of the models
+hydra:
+ run:
+ dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+model_name: ${model.language_model} # used to name the model in wandb
+project_name: relik-retriever # used to name the project in wandb
+
+defaults:
+ - _self_
+ - model: golden_retriever
+ - index: inmemory
+ - loss: nce_loss
+ - optimizer: radamw
+ - scheduler: linear_scheduler
+ - data: dataset_v2
+ - logging: wandb_logging
+ - override hydra/job_logging: colorlog
+ - override hydra/hydra_logging: colorlog
+
+train:
+ # reproducibility
+ seed: 42
+ set_determinism_the_old_way: False
+ # torch parameters
+ float32_matmul_precision: "medium"
+ # if true, only test the model
+ only_test: False
+ # if provided, initialize the model with the weights from the checkpoint
+ pretrain_ckpt_path: null
+ # if provided, start training from the checkpoint
+ checkpoint_path: null
+
+ # task specific parameter
+ top_k: 100
+
+ # pl_trainer
+ pl_trainer:
+ _target_: lightning.Trainer
+ accelerator: gpu
+ devices: 1
+ num_nodes: 1
+ strategy: auto
+ accumulate_grad_batches: 1
+ gradient_clip_val: 1.0
+ val_check_interval: 1.0 # you can specify an int "n" here => validation every "n" steps
+ check_val_every_n_epoch: 1
+ max_epochs: 0
+ max_steps: 220_000
+ deterministic: True
+ fast_dev_run: False
+ precision: 16
+ reload_dataloaders_every_n_epochs: 1
+
+ early_stopping_callback:
+ null
+ # _target_: lightning.callbacks.EarlyStopping
+ # monitor: validate_recall@${train.top_k}
+ # mode: max
+ # patience: 15
+
+ model_checkpoint_callback:
+ _target_: lightning.callbacks.ModelCheckpoint
+ monitor: validate_recall@${train.top_k}
+ mode: max
+ verbose: True
+ save_top_k: 1
+ save_last: True
+ filename: "checkpoint-validate_recall@${train.top_k}_{validate_recall@${train.top_k}:.4f}-epoch_{epoch:02d}"
+ auto_insert_metric_name: False
+
+ callbacks:
+ prediction_callback:
+ _target_: relik.retriever.callbacks.prediction_callbacks.GoldenRetrieverPredictionCallback
+ k: ${train.top_k}
+ batch_size: 128
+ precision: 16
+ index_precision: 16
+ other_callbacks:
+ - _target_: relik.retriever.callbacks.evaluation_callbacks.RecallAtKEvaluationCallback
+ k: ${train.top_k}
+ verbose: True
+ - _target_: relik.retriever.callbacks.evaluation_callbacks.RecallAtKEvaluationCallback
+ k: 50
+ verbose: True
+ prog_bar: False
+ - _target_: relik.retriever.callbacks.evaluation_callbacks.AvgRankingEvaluationCallback
+ k: ${train.top_k}
+ verbose: True
+ - _target_: relik.retriever.callbacks.utils_callbacks.SavePredictionsCallback
+
+ hard_negatives_callback:
+ _target_: relik.retriever.callbacks.prediction_callbacks.NegativeAugmentationCallback
+ k: ${train.top_k}
+ batch_size: 128
+ precision: 16
+ index_precision: 16
+ stages: [validate] #[validate, sanity_check]
+ metrics_to_monitor:
+ validate_recall@${train.top_k}
+ # - sanity_check_recall@${train.top_k}
+ threshold: 0.0
+ max_negatives: 15
+ add_with_probability: 0.2
+ refresh_every_n_epochs: 1
+ other_callbacks:
+ - _target_: relik.retriever.callbacks.evaluation_callbacks.AvgRankingEvaluationCallback
+ k: ${train.top_k}
+ verbose: True
+ prefix: "train"
+
+ utils_callbacks:
+ - _target_: relik.retriever.callbacks.utils_callbacks.SaveRetrieverCallback
+ - _target_: relik.retriever.callbacks.utils_callbacks.FreeUpIndexerVRAMCallback
diff --git a/relik/retriever/conf/scheduler/linear_scheduler.yaml b/relik/retriever/conf/scheduler/linear_scheduler.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d1896bff5d01ee2543639e4e379674d60682a0f6
--- /dev/null
+++ b/relik/retriever/conf/scheduler/linear_scheduler.yaml
@@ -0,0 +1,3 @@
+_target_: transformers.get_linear_schedule_with_warmup
+num_warmup_steps: 0
+num_training_steps: ${train.pl_trainer.max_steps}
diff --git a/relik/retriever/conf/scheduler/linear_scheduler_with_warmup.yaml b/relik/retriever/conf/scheduler/linear_scheduler_with_warmup.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..417857489486469032b7e8b19d509a1e45da043c
--- /dev/null
+++ b/relik/retriever/conf/scheduler/linear_scheduler_with_warmup.yaml
@@ -0,0 +1,3 @@
+_target_: transformers.get_linear_schedule_with_warmup
+num_warmup_steps: 5_000
+num_training_steps: ${train.pl_trainer.max_steps}
diff --git a/relik/retriever/conf/scheduler/none.yaml b/relik/retriever/conf/scheduler/none.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ec747fa47ddb81e9bf2d282011ed32aa4c59f932
--- /dev/null
+++ b/relik/retriever/conf/scheduler/none.yaml
@@ -0,0 +1 @@
+null
\ No newline at end of file
diff --git a/relik/retriever/data/__init__.py b/relik/retriever/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/relik/retriever/data/base/__init__.py b/relik/retriever/data/base/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/relik/retriever/data/base/datasets.py b/relik/retriever/data/base/datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a011402a7f6ead5914eb0808a62e9e967c8c12d
--- /dev/null
+++ b/relik/retriever/data/base/datasets.py
@@ -0,0 +1,89 @@
+import os
+from pathlib import Path
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+
+import torch
+from torch.utils.data import Dataset, IterableDataset
+
+from relik.common.log import get_logger
+
+logger = get_logger()
+
+
+class BaseDataset(Dataset):
+ def __init__(
+ self,
+ name: str,
+ path: Optional[Union[str, os.PathLike, List[str], List[os.PathLike]]] = None,
+ data: Any = None,
+ **kwargs,
+ ):
+ super().__init__()
+ self.name = name
+ if path is None and data is None:
+ raise ValueError("Either `path` or `data` must be provided")
+ self.path = path
+ self.project_folder = Path(__file__).parent.parent.parent
+ self.data = data
+
+ def __len__(self) -> int:
+ return len(self.data)
+
+ def __getitem__(
+ self, index
+ ) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
+ return self.data[index]
+
+ def __repr__(self) -> str:
+ return f"Dataset({self.name=}, {self.path=})"
+
+ def load(
+ self,
+ paths: Union[str, os.PathLike, List[str], List[os.PathLike]],
+ *args,
+ **kwargs,
+ ) -> Any:
+ # load data from single or multiple paths in one single dataset
+ raise NotImplementedError
+
+ @staticmethod
+ def collate_fn(batch: Any, *args, **kwargs) -> Any:
+ raise NotImplementedError
+
+
+class IterableBaseDataset(IterableDataset):
+ def __init__(
+ self,
+ name: str,
+ path: Optional[Union[str, Path, List[str], List[Path]]] = None,
+ data: Any = None,
+ *args,
+ **kwargs,
+ ):
+ super().__init__()
+ self.name = name
+ if path is None and data is None:
+ raise ValueError("Either `path` or `data` must be provided")
+ self.path = path
+ self.project_folder = Path(__file__).parent.parent.parent
+ self.data = data
+
+ def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
+ for sample in self.data:
+ yield sample
+
+ def __repr__(self) -> str:
+ return f"Dataset({self.name=}, {self.path=})"
+
+ def load(
+ self,
+ paths: Union[str, os.PathLike, List[str], List[os.PathLike]],
+ *args,
+ **kwargs,
+ ) -> Any:
+ # load data from single or multiple paths in one single dataset
+ raise NotImplementedError
+
+ @staticmethod
+ def collate_fn(batch: Any, *args, **kwargs) -> Any:
+ raise NotImplementedError
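+
+
+# Subclassing sketch (illustrative only; `JsonlDataset` is a hypothetical example,
+# not part of the library): concrete datasets implement `load` and `collate_fn`.
+#
+#   import json
+#
+#   class JsonlDataset(BaseDataset):
+#       def load(self, paths, *args, **kwargs):
+#           samples = []
+#           for path in paths:
+#               with open(path) as f:
+#                   samples.extend(json.loads(line) for line in f)
+#           return samples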
diff --git a/relik/retriever/data/datasets.py b/relik/retriever/data/datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cf1897a9955d9eb902d74dd08b7dbcd09fd68e4
--- /dev/null
+++ b/relik/retriever/data/datasets.py
@@ -0,0 +1,726 @@
+import os
+from copy import deepcopy
+from enum import Enum
+from functools import partial
+from pathlib import Path
+from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
+
+import datasets
+import psutil
+import torch
+import transformers as tr
+from datasets import load_dataset
+from torch.utils.data import Dataset
+from tqdm import tqdm
+
+from relik.common.log import get_console_logger, get_logger
+from relik.retriever.common.model_inputs import ModelInputs
+from relik.retriever.data.base.datasets import BaseDataset, IterableBaseDataset
+from relik.retriever.data.utils import HardNegativesManager
+
+console_logger = get_console_logger()
+
+logger = get_logger(__name__)
+
+
+class SubsampleStrategyEnum(Enum):
+ NONE = "none"
+ RANDOM = "random"
+ IN_ORDER = "in_order"
+
+
+class GoldenRetrieverDataset:
+ def __init__(
+ self,
+ name: str,
+ path: Union[str, os.PathLike, List[str], List[os.PathLike]] = None,
+ data: Any = None,
+ tokenizer: Optional[Union[str, tr.PreTrainedTokenizer]] = None,
+ # passages: Union[str, os.PathLike, List[str]] = None,
+ passage_batch_size: int = 32,
+ question_batch_size: int = 32,
+ max_positives: int = -1,
+ max_negatives: int = 0,
+ max_hard_negatives: int = 0,
+ max_question_length: int = 256,
+ max_passage_length: int = 64,
+ shuffle: bool = False,
+ subsample_strategy: Optional[str] = SubsampleStrategyEnum.NONE,
+ subsample_portion: float = 0.1,
+ num_proc: Optional[int] = None,
+ load_from_cache_file: bool = True,
+ keep_in_memory: bool = False,
+ prefetch: bool = True,
+ load_fn_kwargs: Optional[Dict[str, Any]] = None,
+ batch_fn_kwargs: Optional[Dict[str, Any]] = None,
+ collate_fn_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ if path is None and data is None:
+ raise ValueError("Either `path` or `data` must be provided")
+
+ if tokenizer is None:
+ raise ValueError("A tokenizer must be provided")
+
+ # dataset parameters
+ self.name = name
+        self.path = path
+        if path is not None:
+            # normalize to a list of `Path` objects
+            if isinstance(path, (str, os.PathLike)):
+                path = [path]
+            self.path = [Path(p) for p in path]
+ # self.project_folder = Path(__file__).parent.parent.parent
+ self.data = data
+
+ # hyper-parameters
+ self.passage_batch_size = passage_batch_size
+ self.question_batch_size = question_batch_size
+ self.max_positives = max_positives
+ self.max_negatives = max_negatives
+ self.max_hard_negatives = max_hard_negatives
+ self.max_question_length = max_question_length
+ self.max_passage_length = max_passage_length
+ self.shuffle = shuffle
+ self.num_proc = num_proc
+ self.load_from_cache_file = load_from_cache_file
+ self.keep_in_memory = keep_in_memory
+ self.prefetch = prefetch
+
+ self.tokenizer = tokenizer
+ if isinstance(self.tokenizer, str):
+ self.tokenizer = tr.AutoTokenizer.from_pretrained(self.tokenizer)
+
+ self.padding_ops = {
+ "input_ids": partial(
+ self.pad_sequence,
+ value=self.tokenizer.pad_token_id,
+ ),
+ "attention_mask": partial(self.pad_sequence, value=0),
+ "token_type_ids": partial(
+ self.pad_sequence,
+ value=self.tokenizer.pad_token_type_id,
+ ),
+ }
+
+ # check if subsample strategy is valid
+ if subsample_strategy is not None:
+ # subsample_strategy can be a string or a SubsampleStrategy
+ if isinstance(subsample_strategy, str):
+ try:
+ subsample_strategy = SubsampleStrategyEnum(subsample_strategy)
+ except ValueError:
+ raise ValueError(
+ f"Subsample strategy {subsample_strategy} is not valid. "
+ f"Valid strategies are: {SubsampleStrategyEnum.__members__}"
+ )
+ if not isinstance(subsample_strategy, SubsampleStrategyEnum):
+ raise ValueError(
+ f"Subsample strategy {subsample_strategy} is not valid. "
+ f"Valid strategies are: {SubsampleStrategyEnum.__members__}"
+ )
+ self.subsample_strategy = subsample_strategy
+ self.subsample_portion = subsample_portion
+
+ # load the dataset
+ if data is None:
+ self.data: Dataset = self.load(
+ self.path,
+ tokenizer=self.tokenizer,
+ load_from_cache_file=load_from_cache_file,
+ load_fn_kwargs=load_fn_kwargs,
+ num_proc=num_proc,
+ shuffle=shuffle,
+ keep_in_memory=keep_in_memory,
+ max_positives=max_positives,
+ max_negatives=max_negatives,
+ max_hard_negatives=max_hard_negatives,
+ max_question_length=max_question_length,
+ max_passage_length=max_passage_length,
+ )
+ else:
+ self.data: Dataset = data
+
+ self.hn_manager: Optional[HardNegativesManager] = None
+
+ # keep track of how many times the dataset has been iterated over
+ self.number_of_complete_iterations = 0
+
+ def __repr__(self) -> str:
+ return f"GoldenRetrieverDataset({self.name=}, {self.path=})"
+
+ def __len__(self) -> int:
+ raise NotImplementedError
+
+ def __getitem__(
+ self, index
+ ) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
+ raise NotImplementedError
+
+ def to_torch_dataset(self, *args, **kwargs) -> torch.utils.data.Dataset:
+ raise NotImplementedError
+
+ def load(
+ self,
+ paths: Union[str, os.PathLike, List[str], List[os.PathLike]],
+ tokenizer: tr.PreTrainedTokenizer = None,
+ load_fn_kwargs: Dict = None,
+ load_from_cache_file: bool = True,
+ num_proc: Optional[int] = None,
+ shuffle: bool = False,
+ keep_in_memory: bool = True,
+ max_positives: int = -1,
+ max_negatives: int = -1,
+ max_hard_negatives: int = -1,
+ max_passages: int = -1,
+ max_question_length: int = 256,
+ max_passage_length: int = 64,
+ *args,
+ **kwargs,
+ ) -> Any:
+ # if isinstance(paths, Sequence):
+ # paths = [self.project_folder / path for path in paths]
+ # else:
+ # paths = [self.project_folder / paths]
+
+ # read the data and put it in a placeholder list
+ for path in paths:
+ if not path.exists():
+ raise ValueError(f"{path} does not exist")
+
+ fn_kwargs = dict(
+ tokenizer=tokenizer,
+ max_positives=max_positives,
+ max_negatives=max_negatives,
+ max_hard_negatives=max_hard_negatives,
+ max_passages=max_passages,
+ max_question_length=max_question_length,
+ max_passage_length=max_passage_length,
+ )
+ if load_fn_kwargs is not None:
+ fn_kwargs.update(load_fn_kwargs)
+
+ if num_proc is None:
+ num_proc = psutil.cpu_count(logical=False)
+
+ # The data is a list of dictionaries, each dictionary is a sample
+ # Each sample has the following keys:
+ # - "question": the question
+ # - "answers": a list of answers
+ # - "positive_ctxs": a list of positive passages
+ # - "negative_ctxs": a list of negative passages
+ # - "hard_negative_ctxs": a list of hard negative passages
+ # use the huggingface dataset library to load the data, by default it will load the
+ # data in a dict with the key being "train".
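+        # An illustrative record (shape is an assumption based on the keys above):
+        #   {"question": "...", "answers": ["..."],
+        #    "positive_ctxs": [...], "negative_ctxs": [...], "hard_negative_ctxs": [...]}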
+ logger.info(f"Loading data for dataset {self.name}")
+ data = load_dataset(
+ "json",
+ data_files=[str(p) for p in paths], # datasets needs str paths and not Path
+ split="train",
+ streaming=False, # TODO maybe we can make streaming work
+ keep_in_memory=keep_in_memory,
+ )
+ # add id if not present
+        if isinstance(data, datasets.Dataset):
+            data = data.add_column("sample_idx", range(len(data)))
+        else:
+            # `dict.update` returns None, so build a new dict instead
+            data = data.map(
+                lambda x, idx: {**x, "sample_idx": idx}, with_indices=True
+            )
+
+ map_kwargs = dict(
+ function=self.load_fn,
+ fn_kwargs=fn_kwargs,
+ )
+ if isinstance(data, datasets.Dataset):
+ map_kwargs.update(
+ dict(
+ load_from_cache_file=load_from_cache_file,
+ keep_in_memory=keep_in_memory,
+ num_proc=num_proc,
+ desc="Loading data",
+ )
+ )
+ # preprocess the data
+ data = data.map(**map_kwargs)
+
+ # shuffle the data
+ if shuffle:
+            data = data.shuffle(seed=42)
+
+ return data
+
+ @staticmethod
+ def create_batches(
+ data: Dataset,
+ batch_fn: Callable,
+ batch_fn_kwargs: Optional[Dict[str, Any]] = None,
+ prefetch: bool = True,
+ *args,
+ **kwargs,
+ ) -> Union[Iterable, List]:
+ if not prefetch:
+ # if we are streaming, we don't need to create batches right now
+ # we will create them on the fly when we need them
+ batched_data = (
+ batch
+ for batch in batch_fn(
+ data, **(batch_fn_kwargs if batch_fn_kwargs is not None else {})
+ )
+ )
+ else:
+ batched_data = [
+ batch
+ for batch in tqdm(
+ batch_fn(
+ data, **(batch_fn_kwargs if batch_fn_kwargs is not None else {})
+ ),
+ desc="Creating batches",
+ )
+ ]
+ return batched_data
+
+ @staticmethod
+ def collate_batches(
+ batched_data: Union[Iterable, List],
+ collate_fn: Callable,
+ collate_fn_kwargs: Optional[Dict[str, Any]] = None,
+ prefetch: bool = True,
+ *args,
+ **kwargs,
+ ) -> Union[Iterable, List]:
+ if not prefetch:
+ collated_data = (
+ collate_fn(batch, **(collate_fn_kwargs if collate_fn_kwargs else {}))
+ for batch in batched_data
+ )
+ else:
+ collated_data = [
+ collate_fn(batch, **(collate_fn_kwargs if collate_fn_kwargs else {}))
+ for batch in tqdm(batched_data, desc="Collating batches")
+ ]
+ return collated_data
+
+ @staticmethod
+ def load_fn(sample: Dict, *args, **kwargs) -> Dict:
+ raise NotImplementedError
+
+ @staticmethod
+ def batch_fn(data: Dataset, *args, **kwargs) -> Any:
+ raise NotImplementedError
+
+ @staticmethod
+ def collate_fn(batch: Any, *args, **kwargs) -> Any:
+ raise NotImplementedError
+
+ @staticmethod
+ def pad_sequence(
+ sequence: Union[List, torch.Tensor],
+ length: int,
+ value: Any = None,
+ pad_to_left: bool = False,
+ ) -> Union[List, torch.Tensor]:
+ """
+ Pad the input to the specified length with the given value.
+
+ Args:
+ sequence (:obj:`List`, :obj:`torch.Tensor`):
+ Element to pad, it can be either a :obj:`List` or a :obj:`torch.Tensor`.
+            length (:obj:`int`):
+                Length of the sequence after padding.
+ value (:obj:`Any`, optional):
+ Value to use as padding.
+ pad_to_left (:obj:`bool`, optional, defaults to :obj:`False`):
+ If :obj:`True`, pads to the left, right otherwise.
+
+ Returns:
+ :obj:`List`, :obj:`torch.Tensor`: The padded sequence.
+
+ """
+ padding = [value] * abs(length - len(sequence))
+ if isinstance(sequence, torch.Tensor):
+ if len(sequence.shape) > 1:
+ raise ValueError(
+ f"Sequence tensor must be 1D. Current shape is `{len(sequence.shape)}`"
+ )
+ padding = torch.as_tensor(padding)
+ if pad_to_left:
+ if isinstance(sequence, torch.Tensor):
+ return torch.cat((padding, sequence), -1)
+ return padding + sequence
+ if isinstance(sequence, torch.Tensor):
+ return torch.cat((sequence, padding), -1)
+ return sequence + padding
+
+ def convert_to_batch(
+ self, samples: Any, *args, **kwargs
+ ) -> Dict[str, torch.Tensor]:
+ """
+ Convert the list of samples to a batch.
+
+ Args:
+ samples (:obj:`List`):
+ List of samples to convert to a batch.
+
+ Returns:
+ :obj:`Dict[str, torch.Tensor]`: The batch.
+ """
+ # invert questions from list of dict to dict of list
+ samples = {k: [d[k] for d in samples] for k in samples[0]}
+ # get max length of questions
+ max_len = max(len(x) for x in samples["input_ids"])
+ # pad the questions
+ for key in samples:
+ if key in self.padding_ops:
+ samples[key] = torch.as_tensor(
+ [self.padding_ops[key](b, max_len) for b in samples[key]]
+ )
+ return samples
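+
+ # Rough sketch of `convert_to_batch` (hypothetical keys and values): it turns
+ # [{"input_ids": [1, 2]}, {"input_ids": [3]}] into
+ # {"input_ids": tensor([[1, 2], [3, <pad>]])}, where the padding value comes
+ # from the op registered for "input_ids" in `self.padding_ops`.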
+
+ def shuffle_data(self, seed: int = 42):
+ self.data = self.data.shuffle(seed=seed)
+
+
+class InBatchNegativesDataset(GoldenRetrieverDataset):
+ def __len__(self) -> int:
+ if isinstance(self.data, datasets.Dataset):
+ return len(self.data)
+ raise TypeError(
+ f"`len()` is not supported for data of type `{type(self.data)}`"
+ )
+
+ def __getitem__(
+ self, index
+ ) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
+ return self.data[index]
+
+ def to_torch_dataset(self) -> torch.utils.data.Dataset:
+ shuffle_this_time = self.shuffle
+
+ if (
+ self.subsample_strategy
+ and self.subsample_strategy != SubsampleStrategyEnum.NONE
+ ):
+ number_of_samples = int(len(self.data) * self.subsample_portion)
+ if self.subsample_strategy == SubsampleStrategyEnum.RANDOM:
+ logger.info(
+ f"Random subsampling {number_of_samples} samples from {len(self.data)}"
+ )
+ data = (
+ deepcopy(self.data)
+ .shuffle(seed=42 + self.number_of_complete_iterations)
+ .select(range(0, number_of_samples))
+ )
+ elif self.subsample_strategy == SubsampleStrategyEnum.IN_ORDER:
+ # number_of_samples = int(len(self.data) * self.subsample_portion)
+ already_selected = (
+ number_of_samples * self.number_of_complete_iterations
+ )
+ logger.info(
+ f"Subsampling {number_of_samples} samples out of {len(self.data)}"
+ )
+ to_select = min(already_selected + number_of_samples, len(self.data))
+ logger.info(
+ f"Portion of data selected: {already_selected} " f"to {to_select}"
+ )
+ data = deepcopy(self.data).select(range(already_selected, to_select))
+
+ # don't shuffle the data if we are subsampling, and we have still not completed
+ # one full iteration over the dataset
+ if self.number_of_complete_iterations > 0:
+ shuffle_this_time = False
+
+ # reset the number of complete iterations
+ if to_select >= len(self.data):
+ # reset the number of complete iterations,
+ # we have completed one full iteration over the dataset
+ # the value is -1 because we want to start from 0 at the next iteration
+ self.number_of_complete_iterations = -1
+ else:
+ raise ValueError(
+ f"Subsample strategy `{self.subsample_strategy}` is not valid. "
+ f"Valid strategies are: {SubsampleStrategyEnum.__members__}"
+ )
+
+ else:
+ data = self.data
+
+ # do we need to shuffle the data?
+ if self.shuffle and shuffle_this_time:
+ logger.info("Shuffling the data")
+ data = data.shuffle(seed=42 + self.number_of_complete_iterations)
+
+ batch_fn_kwargs = {
+ "passage_batch_size": self.passage_batch_size,
+ "question_batch_size": self.question_batch_size,
+ "hard_negatives_manager": self.hn_manager,
+ }
+ batched_data = self.create_batches(
+ data,
+ batch_fn=self.batch_fn,
+ batch_fn_kwargs=batch_fn_kwargs,
+ prefetch=self.prefetch,
+ )
+
+ batched_data = self.collate_batches(
+ batched_data, self.collate_fn, prefetch=self.prefetch
+ )
+
+ # increment the number of complete iterations
+ self.number_of_complete_iterations += 1
+
+ if self.prefetch:
+ return BaseDataset(name=self.name, data=batched_data)
+ else:
+ return IterableBaseDataset(name=self.name, data=batched_data)
+
+ @staticmethod
+ def load_fn(
+ sample: Dict,
+ tokenizer: tr.PreTrainedTokenizer,
+ max_positives: int,
+ max_negatives: int,
+ max_hard_negatives: int,
+ max_passages: int = -1,
+ max_question_length: int = 256,
+ max_passage_length: int = 128,
+ *args,
+ **kwargs,
+ ) -> Dict:
+ # remove duplicates and limit the number of passages
+ positives = list(set([p["text"].strip() for p in sample["positive_ctxs"]]))
+ if max_positives != -1:
+ positives = positives[:max_positives]
+ negatives = list(set([n["text"].strip() for n in sample["negative_ctxs"]]))
+ if max_negatives != -1:
+ negatives = negatives[:max_negatives]
+ hard_negatives = list(
+ set([h["text"].strip() for h in sample["hard_negative_ctxs"]])
+ )
+ if max_hard_negatives != -1:
+ hard_negatives = hard_negatives[:max_hard_negatives]
+
+ question = tokenizer(
+ sample["question"], max_length=max_question_length, truncation=True
+ )
+
+ passage = positives + negatives + hard_negatives
+ if max_passages != -1:
+ passage = passage[:max_passages]
+
+ passage = tokenizer(passage, max_length=max_passage_length, truncation=True)
+
+ # invert the passage data structure from a dict of lists to a list of dicts
+ passage = [dict(zip(passage, t)) for t in zip(*passage.values())]
+
+ output = dict(
+ question=question,
+ passage=passage,
+ positives=positives,
+ positive_pssgs=passage[: len(positives)],
+ )
+ return output
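+
+ # `load_fn` expects DPR-style samples; a minimal sketch of the input format:
+ # {"question": "...",
+ # "positive_ctxs": [{"text": "..."}],
+ # "negative_ctxs": [{"text": "..."}],
+ # "hard_negative_ctxs": [{"text": "..."}]}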
+
+ @staticmethod
+ def batch_fn(
+ data: Dataset,
+ passage_batch_size: int,
+ question_batch_size: int,
+ hard_negatives_manager: Optional[HardNegativesManager] = None,
+ *args,
+ **kwargs,
+ ) -> Dict[str, List[Dict[str, Any]]]:
+ def split_batch(
+ batch: Union[Dict[str, Any], ModelInputs], question_batch_size: int
+ ) -> List[ModelInputs]:
+ """
+ Split a batch into multiple batches of size `question_batch_size` while keeping
+ the same number of passages.
+ """
+
+ def split_fn(x):
+ return [
+ x[i : i + question_batch_size]
+ for i in range(0, len(x), question_batch_size)
+ ]
+
+ # split the sample_idx
+ sample_idx = split_fn(batch["sample_idx"])
+ # split the questions
+ questions = split_fn(batch["questions"])
+ # split the positives
+ positives = split_fn(batch["positives"])
+ # split the positives_pssgs
+ positives_pssgs = split_fn(batch["positives_pssgs"])
+
+ # collect the new batches
+ batches = []
+ for i in range(len(questions)):
+ batches.append(
+ ModelInputs(
+ dict(
+ sample_idx=sample_idx[i],
+ questions=questions[i],
+ passages=batch["passages"],
+ positives=positives[i],
+ positives_pssgs=positives_pssgs[i],
+ )
+ )
+ )
+ return batches
+
+ batch = []
+ passages_in_batch = {}
+
+ for sample in data:
+ if len(passages_in_batch) >= passage_batch_size:
+ # create the batch dict
+ batch_dict = ModelInputs(
+ dict(
+ sample_idx=[s["sample_idx"] for s in batch],
+ questions=[s["question"] for s in batch],
+ passages=list(passages_in_batch.values()),
+ positives_pssgs=[s["positive_pssgs"] for s in batch],
+ positives=[s["positives"] for s in batch],
+ )
+ )
+ # split the batch if needed
+ if len(batch) > question_batch_size:
+ for sub_batch in split_batch(batch_dict, question_batch_size):
+ yield sub_batch
+ else:
+ yield batch_dict
+
+ # reset batch
+ batch = []
+ passages_in_batch = {}
+
+ batch.append(sample)
+ # track the passages currently in the batch, keyed by their input_ids
+ # (converted to a tuple so they are hashable); keying by input_ids both
+ # deduplicates passages and gives us the count used to close the batch
+ passages_in_batch.update(
+ {tuple(passage["input_ids"]): passage for passage in sample["passage"]}
+ )
+ # add the hard negatives mined for this sample to the batch, if any
+ if hard_negatives_manager is not None:
+ if sample["sample_idx"] in hard_negatives_manager:
+ passages_in_batch.update(
+ {
+ tuple(passage["input_ids"]): passage
+ for passage in hard_negatives_manager.get(
+ sample["sample_idx"]
+ )
+ }
+ )
+
+ # left over
+ if len(batch) > 0:
+ # create the batch dict
+ batch_dict = ModelInputs(
+ dict(
+ sample_idx=[s["sample_idx"] for s in batch],
+ questions=[s["question"] for s in batch],
+ passages=list(passages_in_batch.values()),
+ positives_pssgs=[s["positive_pssgs"] for s in batch],
+ positives=[s["positives"] for s in batch],
+ )
+ )
+ # split the batch if needed
+ if len(batch) > question_batch_size:
+ for sub_batch in split_batch(batch_dict, question_batch_size):
+ yield sub_batch
+ else:
+ yield batch_dict
+
+ def collate_fn(self, batch: Any, *args, **kwargs) -> Any:
+ # convert questions and passages to a batch
+ questions = self.convert_to_batch(batch.questions)
+ passages = self.convert_to_batch(batch.passages)
+
+ # build an index to map the position of the passage in the batch
+ passage_index = {tuple(c["input_ids"]): i for i, c in enumerate(batch.passages)}
+
+ # now we can create the labels
+ labels = torch.zeros(
+ questions["input_ids"].shape[0], passages["input_ids"].shape[0]
+ )
+ # iterate over the questions and set the labels to 1 if the passage is positive
+ for sample_idx in range(len(questions["input_ids"])):
+ for pssg in batch["positives_pssgs"][sample_idx]:
+ # get the index of the positive passage
+ index = passage_index[tuple(pssg["input_ids"])]
+ # set the label to 1
+ labels[sample_idx, index] = 1
+
+ model_inputs = ModelInputs(
+ {
+ "questions": questions,
+ "passages": passages,
+ "labels": labels,
+ "positives": batch["positives"],
+ "sample_idx": batch["sample_idx"],
+ }
+ )
+ return model_inputs
+
+
+class AidaInBatchNegativesDataset(InBatchNegativesDataset):
+ def __init__(self, use_topics: bool = False, *args, **kwargs):
+ if "load_fn_kwargs" not in kwargs:
+ kwargs["load_fn_kwargs"] = {}
+ kwargs["load_fn_kwargs"]["use_topics"] = use_topics
+ super().__init__(*args, **kwargs)
+
+ @staticmethod
+ def load_fn(
+ sample: Dict,
+ tokenizer: tr.PreTrainedTokenizer,
+ max_positives: int,
+ max_negatives: int,
+ max_hard_negatives: int,
+ max_passages: int = -1,
+ max_question_length: int = 256,
+ max_passage_length: int = 128,
+ use_topics: bool = False,
+ *args,
+ **kwargs,
+ ) -> Dict:
+ # remove duplicates and limit the number of passages
+ positives = list(set([p["text"].strip() for p in sample["positive_ctxs"]]))
+ if max_positives != -1:
+ positives = positives[:max_positives]
+ negatives = list(set([n["text"].strip() for n in sample["negative_ctxs"]]))
+ if max_negatives != -1:
+ negatives = negatives[:max_negatives]
+ hard_negatives = list(
+ set([h["text"].strip() for h in sample["hard_negative_ctxs"]])
+ )
+ if max_hard_negatives != -1:
+ hard_negatives = hard_negatives[:max_hard_negatives]
+
+ question = sample["question"]
+
+ if "doc_topic" in sample and use_topics:
+ question = tokenizer(
+ question,
+ sample["doc_topic"],
+ max_length=max_question_length,
+ truncation=True,
+ )
+ else:
+ question = tokenizer(
+ question, max_length=max_question_length, truncation=True
+ )
+
+ passage = positives + negatives + hard_negatives
+ if max_passages != -1:
+ passage = passage[:max_passages]
+
+ passage = tokenizer(passage, max_length=max_passage_length, truncation=True)
+
+ # invert the passage data structure from a dict of lists to a list of dicts
+ passage = [dict(zip(passage, t)) for t in zip(*passage.values())]
+
+ output = dict(
+ question=question,
+ passage=passage,
+ positives=positives,
+ positive_pssgs=passage[: len(positives)],
+ )
+ return output
diff --git a/relik/retriever/data/labels.py b/relik/retriever/data/labels.py
new file mode 100644
index 0000000000000000000000000000000000000000..de8f87a2186a413b47d686f92b4d3039536a5988
--- /dev/null
+++ b/relik/retriever/data/labels.py
@@ -0,0 +1,338 @@
+import json
+from pathlib import Path
+from typing import Dict, List, Optional, Set, Union
+
+import transformers as tr
+
+
+class Labels:
+ """
+ Class that contains the labels for a model.
+
+ Args:
+ _labels_to_index (:obj:`Dict[str, Dict[str, int]]`):
+ A dictionary from :obj:`str` to :obj:`int`.
+ _index_to_labels (:obj:`Dict[str, Dict[int, str]]`):
+ A dictionary from :obj:`int` to :obj:`str`.
+ """
+
+ def __init__(
+ self,
+ _labels_to_index: Dict[str, Dict[str, int]] = None,
+ _index_to_labels: Dict[str, Dict[int, str]] = None,
+ **kwargs,
+ ):
+ self._labels_to_index = _labels_to_index or {"labels": {}}
+ self._index_to_labels = _index_to_labels or {"labels": {}}
+ # if _labels_to_index is not empty and _index_to_labels is not provided
+ # to the constructor, build the inverted label dictionary
+ if not _index_to_labels and _labels_to_index:
+ for namespace in self._labels_to_index:
+ self._index_to_labels[namespace] = {
+ v: k for k, v in self._labels_to_index[namespace].items()
+ }
+
+ def get_index_from_label(self, label: str, namespace: str = "labels") -> int:
+ """
+ Returns the index of a literal label.
+
+ Args:
+ label (:obj:`str`):
+ The string representation of the label.
+ namespace (:obj:`str`, optional, defaults to ``labels``):
+ The namespace where the label belongs, e.g. ``roles`` for a SRL task.
+
+ Returns:
+ :obj:`int`: The index of the label.
+ """
+ if namespace not in self._labels_to_index:
+ raise ValueError(
+ f"Provided namespace `{namespace}` is not in the label dictionary."
+ )
+
+ if label not in self._labels_to_index[namespace]:
+ raise ValueError(f"Provided label {label} is not in the label dictionary.")
+
+ return self._labels_to_index[namespace][label]
+
+ def get_label_from_index(self, index: int, namespace: str = "labels") -> str:
+ """
+ Returns the string representation of the label index.
+
+ Args:
+ index (:obj:`int`):
+ The index of the label.
+ namespace (:obj:`str`, optional, defaults to ``labels``):
+ The namespace where the label belongs, e.g. ``roles`` for a SRL task.
+
+ Returns:
+ :obj:`str`: The string representation of the label.
+ """
+ if namespace not in self._index_to_labels:
+ raise ValueError(
+ f"Provided namespace `{namespace}` is not in the label dictionary."
+ )
+
+ if index not in self._index_to_labels[namespace]:
+ raise ValueError(
+ f"Provided label `{index}` is not in the label dictionary."
+ )
+
+ return self._index_to_labels[namespace][index]
+
+ def add_labels(
+ self,
+ labels: Union[str, List[str], Set[str], Dict[str, int]],
+ namespace: str = "labels",
+ ) -> List[int]:
+ """
+ Adds the labels in input in the label dictionary.
+
+ Args:
+ labels (:obj:`str`, :obj:`List[str]`, :obj:`Set[str]`):
+ The labels (single label, list of labels or set of labels) to add to the dictionary.
+ namespace (:obj:`str`, optional, defaults to ``labels``):
+ Namespace where the labels belongs.
+
+ Returns:
+ :obj:`List[int]`: The index of the labels just inserted.
+ """
+ if isinstance(labels, dict):
+ self._labels_to_index[namespace] = labels
+ self._index_to_labels[namespace] = {
+ v: k for k, v in self._labels_to_index[namespace].items()
+ }
+ # normalize input
+ if isinstance(labels, str):
+ labels = {labels}
+ elif isinstance(labels, list):
+ labels = set(labels)
+ # if new namespace, add to the dictionaries
+ if namespace not in self._labels_to_index:
+ self._labels_to_index[namespace] = {}
+ self._index_to_labels[namespace] = {}
+ # returns the new indices
+ return [self._add_label(label, namespace) for label in labels]
+
+ def _add_label(self, label: str, namespace: str = "labels") -> int:
+ """
+ Adds the label in input in the label dictionary.
+
+ Args:
+ label (:obj:`str`):
+ The label to add to the dictionary.
+ namespace (:obj:`str`, optional, defaults to ``labels``):
+ Namespace where the label belongs.
+
+ Returns:
+ :obj:`List[int]`: The index of the label just inserted.
+ """
+ if label not in self._labels_to_index[namespace]:
+ index = len(self._labels_to_index[namespace])
+ self._labels_to_index[namespace][label] = index
+ self._index_to_labels[namespace][index] = label
+ return index
+ else:
+ return self._labels_to_index[namespace][label]
+
+ def get_labels(self, namespace: str = "labels") -> Dict[str, int]:
+ """
+ Returns all the labels that belongs to the input namespace.
+
+ Args:
+ namespace (:obj:`str`, optional, defaults to ``labels``):
+ Labels namespace to retrieve.
+
+ Returns:
+ :obj:`Dict[str, int]`: The label dictionary, from ``str`` to ``int``.
+ """
+ if namespace not in self._labels_to_index:
+ raise ValueError(
+ f"Provided namespace `{namespace}` is not in the label dictionary."
+ )
+ return self._labels_to_index[namespace]
+
+ def get_label_size(self, namespace: str = "labels") -> int:
+ """
+ Returns the number of the labels in the namespace dictionary.
+
+ Args:
+ namespace (:obj:`str`, optional, defaults to ``labels``):
+ Labels namespace to retrieve.
+
+ Returns:
+ :obj:`int`: Number of labels.
+ """
+ if namespace not in self._labels_to_index:
+ raise ValueError(
+ f"Provided namespace `{namespace}` is not in the label dictionary."
+ )
+ return len(self._labels_to_index[namespace])
+
+ def get_namespaces(self) -> List[str]:
+ """
+ Returns all the namespaces in the label dictionary.
+
+ Returns:
+ :obj:`List[str]`: The namespaces in the label dictionary.
+ """
+ return list(self._labels_to_index.keys())
+
+ @classmethod
+ def from_file(cls, file_path: Union[str, Path], **kwargs):
+ with open(file_path, "r") as f:
+ labels_to_index = json.load(f)
+ return cls(labels_to_index, **kwargs)
+
+ def save(self, file_path: Union[str, Path], **kwargs):
+ with open(file_path, "w") as f:
+ json.dump(self._labels_to_index, f, indent=2)
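+
+ # Minimal usage sketch for `Labels` (hypothetical labels):
+ # labels = Labels()
+ # labels.add_labels(["positive", "negative"])
+ # idx = labels.get_index_from_label("positive") # e.g. 0
+ # labels.get_label_from_index(idx) # -> "positive"
+ # labels.get_label_size() # -> 2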
+
+
+class PassageManager:
+ def __init__(
+ self,
+ tokenizer: Optional[tr.PreTrainedTokenizer] = None,
+ passages: Optional[Union[Dict[str, Dict[str, int]], Labels, List[str]]] = None,
+ lazy: bool = True,
+ **kwargs,
+ ):
+ if passages is None:
+ self.passages = Labels()
+ elif isinstance(passages, Labels):
+ self.passages = passages
+ elif isinstance(passages, dict):
+ self.passages = Labels(passages)
+ elif isinstance(passages, list):
+ self.passages = Labels()
+ self.passages.add_labels(passages)
+ else:
+ raise ValueError(
+ "`passages` should be either a Labels object or a dictionary."
+ )
+
+ self.tokenizer = tokenizer
+ self.lazy = lazy
+
+ self._tokenized_passages = {}
+
+ if not self.lazy:
+ self._tokenize_passages()
+
+ def __len__(self) -> int:
+ return self.passages.get_label_size()
+
+ def get_index_from_passage(self, passage: str) -> int:
+ """
+ Returns the index of the passage in input.
+
+ Args:
+ passage (:obj:`str`):
+ The passage to get the index from.
+
+ Returns:
+ :obj:`int`: The index of the passage.
+ """
+ return self.passages.get_index_from_label(passage)
+
+ def get_passage_from_index(self, index: int) -> str:
+ """ "
+ Returns the passage from the index in input.
+
+ Args:
+ index (:obj:`int`):
+ The index to get the passage from.
+
+ Returns:
+ :obj:`str`: The passage.
+ """
+ return self.passages.get_label_from_index(index)
+
+ def add_passages(
+ self,
+ passages: Union[str, List[str], Set[str], Dict[str, int]],
+ lazy: Optional[bool] = None,
+ ) -> List[int]:
+ """
+ Adds the passages in input in the passage dictionary.
+
+ Args:
+ passages (:obj:`str`, :obj:`List[str]`, :obj:`Set[str]`, :obj:`Dict[str, int]`):
+ The passages (single passage, list of passages, set of passages or dictionary of passages) to add to the dictionary.
+ lazy (:obj:`bool`, optional, defaults to ``None``):
+ Whether to tokenize the passages right away or not.
+
+ Returns:
+ :obj:`List[int]`: The index of the passages just inserted.
+ """
+
+ return self.passages.add_labels(passages)
+
+ def get_passages(self) -> Dict[str, int]:
+ """
+ Returns all the passages in the passage dictionary.
+
+ Returns:
+ :obj:`Dict[str, int]`: The passage dictionary, from ``str`` to ``int``.
+ """
+ return self.passages.get_labels()
+
+ def get_tokenized_passage(
+ self, passage: Union[str, int], force_tokenize: bool = False, **kwargs
+ ) -> Dict:
+ """
+ Returns the tokenized passage in input.
+
+ Args:
+ passage (:obj:`Union[str, int]`):
+ The passage to tokenize.
+ force_tokenize (:obj:`bool`, optional, defaults to ``False``):
+ Whether to force the tokenization of the passage or not.
+ kwargs:
+ Additional keyword arguments to pass to the tokenizer.
+
+ Returns:
+ :obj:`Dict`: The tokenized passage.
+ """
+ passage_index: Optional[int] = None
+ passage_str: Optional[str] = None
+
+ if isinstance(passage, str):
+ passage_index = self.passages.get_index_from_label(passage)
+ passage_str = passage
+ elif isinstance(passage, int):
+ passage_index = passage
+ passage_str = self.passages.get_label_from_index(passage)
+ else:
+ raise ValueError(
+ f"`passage` should be either a `str` or an `int`. Provided type: {type(passage)}."
+ )
+
+ if passage_index not in self._tokenized_passages or force_tokenize:
+ self._tokenized_passages[passage_index] = self.tokenizer(
+ passage_str, **kwargs
+ )
+
+ return self._tokenized_passages[passage_index]
+
+ def _tokenize_passages(self, **kwargs):
+ for passage in self.passages.get_labels():
+ self.get_tokenized_passage(passage, **kwargs)
+
+ def tokenize(self, text: Union[str, List[str]], **kwargs):
+ """
+ Tokenizes the text in input using the tokenizer.
+
+ Args:
+ text (:obj:`str`, :obj:`List[str]`):
+ The text to tokenize.
+ **kwargs:
+ Additional keyword arguments to pass to the tokenizer.
+
+ Returns:
+ :obj:`List[str]`: The tokenized text.
+
+ """
+ if self.tokenizer is None:
+ raise ValueError(
+ "No tokenizer was provided. Please provide a tokenizer to the passageManager."
+ )
+ return self.tokenizer(text, **kwargs)
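+
+ # Usage sketch for `PassageManager`, assuming `tokenizer` is a Hugging Face
+ # tokenizer (hypothetical passages):
+ # pm = PassageManager(tokenizer=tokenizer, passages=["passage a", "passage b"])
+ # pm.get_index_from_passage("passage a") # e.g. 0
+ # pm.get_tokenized_passage("passage a") # tokenized lazily and cached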
diff --git a/relik/retriever/data/utils.py b/relik/retriever/data/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..928dfb833919ce2e30c9e90fed539c46f4bd3dec
--- /dev/null
+++ b/relik/retriever/data/utils.py
@@ -0,0 +1,176 @@
+import json
+import os
+from collections import defaultdict
+from typing import Any, Dict, Iterable, List, Optional, Union
+
+import numpy as np
+import transformers as tr
+from tqdm import tqdm
+
+
+class HardNegativesManager:
+ def __init__(
+ self,
+ tokenizer: tr.PreTrainedTokenizer,
+ data: Optional[Union[str, os.PathLike, Dict[int, List]]] = None,
+ max_length: int = 64,
+ batch_size: int = 1000,
+ lazy: bool = False,
+ ) -> None:
+ self._db: dict = None
+ self.tokenizer = tokenizer
+ self.max_length = max_length
+
+ if data is None:
+ self._db = {}
+ else:
+ if isinstance(data, Dict):
+ self._db = data
+ elif isinstance(data, (str, os.PathLike)):
+ with open(data) as f:
+ self._db = json.load(f)
+ else:
+ raise ValueError(
+ f"Data type {type(data)} not supported, only Dict and os.PathLike are supported."
+ )
+ # add the tokenizer to the class for future use
+ self.tokenizer = tokenizer
+
+ # invert the db to have a passage -> sample_idx mapping
+ self._passage_db = defaultdict(set)
+ for sample_idx, passages in self._db.items():
+ for passage in passages:
+ self._passage_db[passage].add(sample_idx)
+
+ self._passage_hard_negatives = {}
+ if not lazy:
+ # create a dictionary of passage -> hard_negative mapping
+ batch_size = min(batch_size, len(self._passage_db))
+ unique_passages = list(self._passage_db.keys())
+ for i in tqdm(
+ range(0, len(unique_passages), batch_size),
+ desc="Tokenizing Hard Negatives",
+ ):
+ batch = unique_passages[i : i + batch_size]
+ tokenized_passages = self.tokenizer(
+ batch,
+ max_length=max_length,
+ truncation=True,
+ )
+ for i, passage in enumerate(batch):
+ self._passage_hard_negatives[passage] = {
+ k: tokenized_passages[k][i] for k in tokenized_passages.keys()
+ }
+
+ def __len__(self) -> int:
+ return len(self._db)
+
+ def __getitem__(self, idx: int) -> Dict:
+ return self._db[idx]
+
+ def __iter__(self):
+ for sample in self._db:
+ yield sample
+
+ def __contains__(self, idx: int) -> bool:
+ return idx in self._db
+
+ def get(self, idx: int) -> List[str]:
+ """Get the hard negatives for a given sample index."""
+ if idx not in self._db:
+ raise ValueError(f"Sample index {idx} not in the database.")
+
+ passages = self._db[idx]
+
+ output = []
+ for passage in passages:
+ if passage not in self._passage_hard_negatives:
+ self._passage_hard_negatives[passage] = self._tokenize(passage)
+ output.append(self._passage_hard_negatives[passage])
+
+ return output
+
+ def _tokenize(self, passage: str) -> Dict:
+ return self.tokenizer(passage, max_length=self.max_length, truncation=True)
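+
+ # Usage sketch, assuming `tokenizer` is a Hugging Face tokenizer and the
+ # sample indices/passages are hypothetical:
+ # hn_manager = HardNegativesManager(tokenizer, data={0: ["a hard negative"]})
+ # 0 in hn_manager # -> True
+ # hn_manager.get(0) # -> [{"input_ids": [...], "attention_mask": [...], ...}]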
+
+
+class NegativeSampler:
+ def __init__(
+ self, num_elements: int, probabilities: Optional[Union[List, np.ndarray]] = None
+ ):
+ if probabilities is None:
+ # default to random probabilities that sum to 1
+ probabilities = np.random.random(num_elements)
+ probabilities /= np.sum(probabilities)
+ if not isinstance(probabilities, np.ndarray):
+ probabilities = np.array(probabilities)
+ self.probabilities = probabilities
+
+ def __call__(
+ self,
+ sample_size: int,
+ num_samples: int = 1,
+ probabilities: Optional[np.ndarray] = None,
+ exclude: Optional[List[int]] = None,
+ ) -> np.ndarray:
+ """
+ Fast sampling of `sample_size` elements from `num_elements` elements.
+ The sampling is done by randomly shifting the probabilities and then
+ finding the smallest of the negative numbers. This is much faster than
+ sampling from a multinomial distribution.
+
+ Args:
+ sample_size (`int`):
+ number of elements to sample
+ num_samples (`int`, optional):
+ number of samples to draw. Defaults to 1.
+ probabilities (`np.array`, optional):
+ probabilities of each element. Defaults to None.
+ exclude (`List[int]`, optional):
+ indices of elements to exclude. Defaults to None.
+
+ Returns:
+ `np.array`: array of sampled indices
+ """
+ if probabilities is None:
+ probabilities = self.probabilities
+
+ if exclude is not None:
+ probabilities[exclude] = 0
+ # re-normalize?
+ # probabilities /= np.sum(probabilities)
+
+ # replicate probabilities as many times as `num_samples`
+ replicated_probabilities = np.tile(probabilities, (num_samples, 1))
+ # get random shifting numbers & scale them correctly
+ random_shifts = np.random.random(replicated_probabilities.shape)
+ random_shifts /= random_shifts.sum(axis=1)[:, np.newaxis]
+ # shift by numbers & find largest (by finding the smallest of the negative)
+ shifted_probabilities = random_shifts - replicated_probabilities
+ sampled_indices = np.argpartition(shifted_probabilities, sample_size, axis=1)[
+ :, :sample_size
+ ]
+ return sampled_indices
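+
+ # Usage sketch (illustrative numbers): the shift-and-argpartition trick draws
+ # indices roughly in proportion to the stored probabilities without a
+ # per-sample call to np.random.choice.
+ # sampler = NegativeSampler(num_elements=1000)
+ # negatives = sampler(sample_size=5, num_samples=2) # shape (2, 5)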
+
+
+def batch_generator(samples: Iterable[Any], batch_size: int) -> Iterable[Any]:
+ """
+ Generate batches from samples.
+
+ Args:
+ samples (`Iterable[Any]`): Iterable of samples.
+ batch_size (`int`): Batch size.
+
+ Returns:
+ `Iterable[Any]`: Iterable of batches.
+ """
+ batch = []
+ for sample in samples:
+ batch.append(sample)
+ if len(batch) == batch_size:
+ yield batch
+ batch = []
+
+ # leftover batch
+ if len(batch) > 0:
+ yield batch
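+
+ # Example (hypothetical values):
+ # list(batch_generator(range(5), batch_size=2)) -> [[0, 1], [2, 3], [4]]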
diff --git a/relik/retriever/indexers/__init__.py b/relik/retriever/indexers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/relik/retriever/indexers/base.py b/relik/retriever/indexers/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..27ba1eeae4d8f9a757daedb1558ed03788a6b8bd
--- /dev/null
+++ b/relik/retriever/indexers/base.py
@@ -0,0 +1,319 @@
+import os
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
+
+import hydra
+import numpy
+import torch
+from omegaconf import OmegaConf
+from rich.pretty import pprint
+
+from relik.common import upload
+from relik.common.log import get_console_logger, get_logger
+from relik.common.utils import (
+ from_cache,
+ is_remote_url,
+ is_str_a_path,
+ relative_to_absolute_path,
+ sapienzanlp_model_urls,
+)
+from relik.retriever.data.labels import Labels
+
+# from relik.retriever.models.model import GoldenRetriever, RetrievedSample
+
+
+logger = get_logger(__name__)
+console_logger = get_console_logger()
+
+
+@dataclass
+class IndexerOutput:
+ indices: Union[torch.Tensor, numpy.ndarray]
+ distances: Union[torch.Tensor, numpy.ndarray]
+
+
+class BaseDocumentIndex:
+ CONFIG_NAME = "config.yaml"
+ DOCUMENTS_FILE_NAME = "documents.json"
+ EMBEDDINGS_FILE_NAME = "embeddings.pt"
+
+ def __init__(
+ self,
+ documents: Union[str, List[str], Labels, os.PathLike, List[os.PathLike]] = None,
+ embeddings: Optional[torch.Tensor] = None,
+ name_or_dir: Optional[Union[str, os.PathLike]] = None,
+ ) -> None:
+ if documents is not None:
+ if isinstance(documents, Labels):
+ self.documents = documents
+ else:
+ documents_are_paths = False
+
+ # normalize the documents to list if not already
+ if not isinstance(documents, list):
+ documents = [documents]
+
+ # now check if the documents are a list of paths (either str or os.PathLike)
+ if isinstance(documents[0], (str, os.PathLike)):
+ # check if the str is a path
+ documents_are_paths = is_str_a_path(documents[0])
+
+ # if the documents are a list of paths, then we load them
+ if documents_are_paths:
+ logger.info("Loading documents from paths")
+ _documents = []
+ for doc in documents:
+ with open(relative_to_absolute_path(doc)) as f:
+ _documents += [line.strip() for line in f.readlines()]
+ # remove duplicates
+ documents = list(set(_documents))
+
+ self.documents = Labels()
+ self.documents.add_labels(documents)
+ else:
+ self.documents = Labels()
+
+ self.embeddings = embeddings
+ self.name_or_dir = name_or_dir
+
+ @property
+ def config(self) -> Dict[str, Any]:
+ """
+ The configuration of the document index.
+
+ Returns:
+ `Dict[str, Any]`: The configuration of the retriever.
+ """
+
+ def obj_to_dict(obj):
+ match obj:
+ case dict():
+ data = {}
+ for k, v in obj.items():
+ data[k] = obj_to_dict(v)
+ return data
+
+ case list() | tuple():
+ return [obj_to_dict(x) for x in obj]
+
+ case object(__dict__=_):
+ data = {
+ "_target_": f"{obj.__class__.__module__}.{obj.__class__.__name__}",
+ }
+ for k, v in obj.__dict__.items():
+ if not k.startswith("_"):
+ data[k] = obj_to_dict(v)
+ return data
+
+ case _:
+ return obj
+
+ return obj_to_dict(self)
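+
+ # `config` walks the object graph and emits a hydra-instantiable dict, roughly:
+ # {"_target_": "<module>.<ClassName>", "device": "cpu", "precision": None, ...}
+ # (the exact fields depend on the concrete subclass and its public attributes)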
+
+ def index(
+ self,
+ retriever,
+ *args,
+ **kwargs,
+ ) -> "BaseDocumentIndex":
+ raise NotImplementedError
+
+ def search(self, query: Any, k: int = 1, *args, **kwargs) -> List:
+ raise NotImplementedError
+
+ def get_index_from_passage(self, document: str) -> int:
+ """
+ Get the index of the passage.
+
+ Args:
+ document (`str`):
+ The document to get the index for.
+
+ Returns:
+ `int`: The index of the document.
+ """
+ return self.documents.get_index_from_label(document)
+
+ def get_passage_from_index(self, index: int) -> str:
+ """
+ Get the document from the index.
+
+ Args:
+ index (`int`):
+ The index of the document.
+
+ Returns:
+ `str`: The document.
+ """
+ return self.documents.get_label_from_index(index)
+
+ def get_embeddings_from_index(self, index: int) -> torch.Tensor:
+ """
+ Get the document vector from the index.
+
+ Args:
+ index (`int`):
+ The index of the document.
+
+ Returns:
+ `torch.Tensor`: The document vector.
+ """
+ if self.embeddings is None:
+ raise ValueError(
+ "The documents must be indexed before they can be retrieved."
+ )
+ if index >= self.embeddings.shape[0]:
+ raise ValueError(
+ f"The index {index} is out of bounds. The maximum index is {len(self.embeddings) - 1}."
+ )
+ return self.embeddings[index]
+
+ def get_embeddings_from_passage(self, document: str) -> torch.Tensor:
+ """
+ Get the document vector from the document label.
+
+ Args:
+ document (`str`):
+ The document to get the vector for.
+
+ Returns:
+ `torch.Tensor`: The document vector.
+ """
+ if self.embeddings is None:
+ raise ValueError(
+ "The documents must be indexed before they can be retrieved."
+ )
+ return self.get_embeddings_from_index(self.get_index_from_passage(document))
+
+ def save_pretrained(
+ self,
+ output_dir: Union[str, os.PathLike],
+ config: Optional[Dict[str, Any]] = None,
+ config_file_name: Optional[str] = None,
+ document_file_name: Optional[str] = None,
+ embedding_file_name: Optional[str] = None,
+ push_to_hub: bool = False,
+ **kwargs,
+ ):
+ """
+ Save the retriever to a directory.
+
+ Args:
+ output_dir (`str`):
+ The directory to save the retriever to.
+ config (`Optional[Dict[str, Any]]`, `optional`):
+ The configuration to save. If `None`, the current configuration of the retriever will be
+ saved. Defaults to `None`.
+ config_file_name (`Optional[str]`, `optional`):
+ The name of the configuration file. Defaults to `config.yaml`.
+ document_file_name (`Optional[str]`, `optional`):
+ The name of the document file. Defaults to `documents.json`.
+ embedding_file_name (`Optional[str]`, `optional`):
+ The name of the embedding file. Defaults to `embeddings.pt`.
+ push_to_hub (`bool`, `optional`):
+ Whether to push the saved retriever to the hub. Defaults to `False`.
+ """
+ if config is None:
+ # create a default config
+ config = self.config
+
+ config_file_name = config_file_name or self.CONFIG_NAME
+ document_file_name = document_file_name or self.DOCUMENTS_FILE_NAME
+ embedding_file_name = embedding_file_name or self.EMBEDDINGS_FILE_NAME
+
+ # create the output directory
+ output_dir = Path(output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ logger.info(f"Saving retriever to {output_dir}")
+ logger.info(f"Saving config to {output_dir / config_file_name}")
+ # pretty print the config
+ pprint(config, console=console_logger, expand_all=True)
+ OmegaConf.save(config, output_dir / config_file_name)
+
+ # save the current state of the retriever
+ embedding_path = output_dir / embedding_file_name
+ logger.info(f"Saving retriever state to {output_dir / embedding_path}")
+ torch.save(self.embeddings, embedding_path)
+
+ # save the passage index
+ documents_path = output_dir / document_file_name
+ logger.info(f"Saving passage index to {documents_path}")
+ self.documents.save(documents_path)
+
+ logger.info("Saving document index to disk done.")
+
+ if push_to_hub:
+ # push to hub
+ logger.info(f"Pushing to hub")
+ model_id = model_id or output_dir.name
+ upload(output_dir, model_id, **kwargs)
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ name_or_dir: Union[str, os.PathLike],
+ device: str = "cpu",
+ precision: Optional[str] = None,
+ config_file_name: Optional[str] = None,
+ document_file_name: Optional[str] = None,
+ embedding_file_name: Optional[str] = None,
+ *args,
+ **kwargs,
+ ) -> "BaseDocumentIndex":
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+
+ config_file_name = config_file_name or cls.CONFIG_NAME
+ document_file_name = document_file_name or cls.DOCUMENTS_FILE_NAME
+ embedding_file_name = embedding_file_name or cls.EMBEDDINGS_FILE_NAME
+
+ model_dir = from_cache(
+ name_or_dir,
+ filenames=[config_file_name, document_file_name, embedding_file_name],
+ cache_dir=cache_dir,
+ force_download=force_download,
+ )
+
+ config_path = model_dir / config_file_name
+ if not config_path.exists():
+ raise FileNotFoundError(
+ f"Model configuration file not found at {config_path}."
+ )
+
+ config = OmegaConf.load(config_path)
+ pprint(OmegaConf.to_container(config), console=console_logger, expand_all=True)
+
+ # load the documents
+ documents_path = model_dir / document_file_name
+
+ if not documents_path.exists():
+ raise ValueError(f"Document file `{documents_path}` does not exist.")
+ logger.info(f"Loading documents from {documents_path}")
+ documents = Labels.from_file(documents_path)
+
+ # load the passage embeddings
+ embedding_path = model_dir / embedding_file_name
+ # run some checks
+ embeddings = None
+ if embedding_path.exists():
+ logger.info(f"Loading embeddings from {embedding_path}")
+ embeddings = torch.load(embedding_path, map_location="cpu")
+ else:
+ logger.warning(f"Embedding file `{embedding_path}` does not exist.")
+
+ document_index = hydra.utils.instantiate(
+ config,
+ documents=documents,
+ embeddings=embeddings,
+ device=device,
+ precision=precision,
+ name_or_dir=name_or_dir,
+ *args,
+ **kwargs,
+ )
+
+ return document_index
diff --git a/relik/retriever/indexers/faiss.py b/relik/retriever/indexers/faiss.py
new file mode 100644
index 0000000000000000000000000000000000000000..49b752db9f04833706f600c149ee18d16511ff55
--- /dev/null
+++ b/relik/retriever/indexers/faiss.py
@@ -0,0 +1,399 @@
+import contextlib
+import logging
+import math
+import os
+from dataclasses import dataclass
+from typing import Callable, List, Optional, Union
+
+import numpy
+import torch
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from relik.common.log import get_logger
+from relik.common.utils import is_package_available
+from relik.retriever.common.model_inputs import ModelInputs
+from relik.retriever.data.base.datasets import BaseDataset
+from relik.retriever.data.labels import Labels
+from relik.retriever.indexers.base import BaseDocumentIndex
+from relik.retriever.pytorch_modules import PRECISION_MAP, RetrievedSample
+from relik.retriever.pytorch_modules.model import GoldenRetriever
+
+if is_package_available("faiss"):
+ import faiss
+ import faiss.contrib.torch_utils
+
+logger = get_logger(__name__, level=logging.INFO)
+
+
+@dataclass
+class FaissOutput:
+ indices: Union[torch.Tensor, numpy.ndarray]
+ distances: Union[torch.Tensor, numpy.ndarray]
+
+
+class FaissDocumentIndex(BaseDocumentIndex):
+ DOCUMENTS_FILE_NAME = "documents.json"
+ EMBEDDINGS_FILE_NAME = "embeddings.pt"
+ INDEX_FILE_NAME = "index.faiss"
+
+ def __init__(
+ self,
+ documents: Union[List[str], Labels],
+ embeddings: Optional[Union[torch.Tensor, numpy.ndarray]] = None,
+ index=None,
+ index_type: str = "Flat",
+ metric: int = faiss.METRIC_INNER_PRODUCT,
+ normalize: bool = False,
+ device: str = "cpu",
+ name_or_dir: Optional[Union[str, os.PathLike]] = None,
+ *args,
+ **kwargs,
+ ) -> None:
+ super().__init__(documents, embeddings, name_or_dir)
+
+ if embeddings is not None and documents is not None:
+ logger.info("Both documents and embeddings are provided.")
+ if self.documents.get_label_size() != embeddings.shape[0]:
+ raise ValueError(
+ "The number of documents and embeddings must be the same."
+ )
+
+ # device to store the embeddings
+ self.device = device
+
+ # params
+ self.index_type = index_type
+ self.metric = metric
+ self.normalize = normalize
+
+ if index is not None:
+ self.embeddings = index
+ if self.device == "cuda":
+ # use a single GPU
+ faiss_resource = faiss.StandardGpuResources()
+ self.embeddings = faiss.index_cpu_to_gpu(
+ faiss_resource, 0, self.embeddings
+ )
+ else:
+ if embeddings is not None:
+ # build the faiss index
+ logger.info("Building the index from the embeddings.")
+ self.embeddings = self._build_faiss_index(
+ embeddings=embeddings,
+ index_type=index_type,
+ normalize=normalize,
+ metric=metric,
+ )
+
+ def _build_faiss_index(
+ self,
+ embeddings: Optional[Union[torch.Tensor, numpy.ndarray]],
+ index_type: str,
+ normalize: bool,
+ metric: int,
+ ):
+ # build the faiss index
+ self.normalize = (
+ normalize
+ and metric == faiss.METRIC_INNER_PRODUCT
+ and not isinstance(embeddings, torch.Tensor)
+ )
+ if self.normalize:
+ index_type = f"L2norm,{index_type}"
+ faiss_vector_size = embeddings.shape[1]
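+ # for factory strings with an "x" placeholder (e.g. "IVFx,Flat"), the code
+ # below fills in a value derived from the embedding dimension and, on CPU,
+ # swaps in an HNSW coarse quantizer ("IVFx_HNSW32"); this is a heuristic
+ # default rather than anything required by faiss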
+ if self.device == "cpu":
+ index_type = index_type.replace("x,", "x_HNSW32,")
+ index_type = index_type.replace(
+ "x", str(math.ceil(math.sqrt(faiss_vector_size)) * 4)
+ )
+ self.embeddings = faiss.index_factory(faiss_vector_size, index_type, metric)
+
+ # convert to GPU
+ if self.device == "cuda":
+ # use a single GPU
+ faiss_resource = faiss.StandardGpuResources()
+ self.embeddings = faiss.index_cpu_to_gpu(faiss_resource, 0, self.embeddings)
+ else:
+ # move to CPU if embeddings is a torch.Tensor
+ embeddings = (
+ embeddings.cpu() if isinstance(embeddings, torch.Tensor) else embeddings
+ )
+
+ # convert to float32 if embeddings is a torch.Tensor and is float16
+ if isinstance(embeddings, torch.Tensor) and embeddings.dtype == torch.float16:
+ embeddings = embeddings.float()
+
+ self.embeddings.add(embeddings)
+
+ # save parameters for saving/loading
+ self.index_type = index_type
+ self.metric = metric
+
+ # clear the embeddings to free up memory
+ embeddings = None
+
+ return self.embeddings
+
+ @torch.no_grad()
+ @torch.inference_mode()
+ def index(
+ self,
+ retriever: GoldenRetriever,
+ documents: Optional[List[str]] = None,
+ batch_size: int = 32,
+ num_workers: int = 4,
+ max_length: Optional[int] = None,
+ collate_fn: Optional[Callable] = None,
+ encoder_precision: Optional[Union[str, int]] = None,
+ compute_on_cpu: bool = False,
+ force_reindex: bool = False,
+ *args,
+ **kwargs,
+ ) -> "FaissDocumentIndex":
+ """
+ Index the documents using the encoder.
+
+ Args:
+ retriever (:obj:`torch.nn.Module`):
+ The encoder to be used for indexing.
+ documents (:obj:`List[str]`, `optional`, defaults to None):
+ The documents to be indexed.
+ batch_size (:obj:`int`, `optional`, defaults to 32):
+ The batch size to be used for indexing.
+ num_workers (:obj:`int`, `optional`, defaults to 4):
+ The number of workers to be used for indexing.
+ max_length (:obj:`int`, `optional`, defaults to None):
+ The maximum length of the input to the encoder.
+ collate_fn (:obj:`Callable`, `optional`, defaults to None):
+ The collate function to be used for batching.
+ encoder_precision (:obj:`Union[str, int]`, `optional`, defaults to None):
+ The precision to be used for the encoder.
+ compute_on_cpu (:obj:`bool`, `optional`, defaults to False):
+ Whether to compute the embeddings on CPU.
+ force_reindex (:obj:`bool`, `optional`, defaults to False):
+ Whether to force reindexing.
+
+ Returns:
+ :obj:`InMemoryIndexer`: The indexer object.
+ """
+
+ if self.embeddings is not None and not force_reindex:
+ logger.info(
+ "Embeddings are already present and `force_reindex` is `False`. Skipping indexing."
+ )
+ if documents is None:
+ return self
+
+ # build a default collate function from the passage tokenizer if none is given
+ if collate_fn is None:
+ tokenizer = retriever.passage_tokenizer
+
+ def collate_fn(x):
+ return ModelInputs(
+ tokenizer(
+ x,
+ padding=True,
+ return_tensors="pt",
+ truncation=True,
+ max_length=max_length or tokenizer.model_max_length,
+ )
+ )
+
+ if force_reindex:
+ if documents is not None:
+ self.documents.add_labels(documents)
+ data = [k for k in self.documents.get_labels()]
+
+ else:
+ if documents is not None:
+ data = [k for k in Labels(documents).get_labels()]
+ else:
+ return self
+
+ dataloader = DataLoader(
+ BaseDataset(name="passage", data=data),
+ batch_size=batch_size,
+ shuffle=False,
+ num_workers=num_workers,
+ pin_memory=False,
+ collate_fn=collate_fn,
+ )
+
+ encoder = retriever.passage_encoder
+
+ # Create empty lists to store the passage embeddings and passage index
+ passage_embeddings: List[torch.Tensor] = []
+
+ encoder_device = "cpu" if compute_on_cpu else self.device
+
+ # torch.autocast only accepts bare device-type strings like "cpu" or "cuda",
+ # so strip any device index from the model device
+ device_type_for_autocast = str(encoder_device).split(":")[0]
+ # autocast on CPU only supports bfloat16, so skip it there
+ autocast_pssg_mngr = (
+ contextlib.nullcontext()
+ if device_type_for_autocast == "cpu"
+ else (
+ torch.autocast(
+ device_type=device_type_for_autocast,
+ dtype=PRECISION_MAP[encoder_precision],
+ )
+ )
+ )
+ with autocast_pssg_mngr:
+ # Iterate through each batch in the dataloader
+ for batch in tqdm(dataloader, desc="Indexing"):
+ # Move the batch to the device
+ batch: ModelInputs = batch.to(encoder_device)
+ # Compute the passage embeddings
+ passage_outs = encoder(**batch).pooler_output
+ # Append the passage embeddings to the list
+ if self.device == "cpu":
+ passage_embeddings.extend([c.detach().cpu() for c in passage_outs])
+ else:
+ passage_embeddings.extend([c for c in passage_outs])
+
+ # move the passage embeddings to the CPU if not already done
+ passage_embeddings = [c.detach().cpu() for c in passage_embeddings]
+ # stack it
+ passage_embeddings: torch.Tensor = torch.stack(passage_embeddings, dim=0)
+ # convert to float32 for faiss
+ passage_embeddings = passage_embeddings.to(PRECISION_MAP["float32"])
+
+ # index the embeddings
+ self.embeddings = self._build_faiss_index(
+ embeddings=passage_embeddings,
+ index_type=self.index_type,
+ normalize=self.normalize,
+ metric=self.metric,
+ )
+ # free up memory from the unused variable
+ del passage_embeddings
+
+ return self
+
+ @torch.no_grad()
+ @torch.inference_mode()
+ def search(self, query: torch.Tensor, k: int = 1) -> list[list[RetrievedSample]]:
+ k = min(k, self.embeddings.ntotal)
+
+ if self.normalize:
+ faiss.normalize_L2(query)
+ if isinstance(query, torch.Tensor) and self.device == "cpu":
+ query = query.detach().cpu()
+ # Retrieve the indices of the top k passage embeddings
+ retriever_out = self.embeddings.search(query, k)
+
+ # get int values (second element of the tuple)
+ batch_top_k: List[List[int]] = retriever_out[1].detach().cpu().tolist()
+ # get float values (first element of the tuple)
+ batch_scores: List[List[float]] = retriever_out[0].detach().cpu().tolist()
+ # Retrieve the passages corresponding to the indices
+ batch_passages = [
+ [self.documents.get_label_from_index(i) for i in indices]
+ for indices in batch_top_k
+ ]
+ # build the output object
+ batch_retrieved_samples = [
+ [
+ RetrievedSample(label=passage, index=index, score=score)
+ for passage, index, score in zip(passages, indices, scores)
+ ]
+ for passages, indices, scores in zip(
+ batch_passages, batch_top_k, batch_scores
+ )
+ ]
+ return batch_retrieved_samples
+
+ # def save(self, saving_dir: Union[str, os.PathLike]):
+ # """
+ # Save the indexer to the disk.
+
+ # Args:
+ # saving_dir (:obj:`Union[str, os.PathLike]`):
+ # The directory where the indexer will be saved.
+ # """
+ # saving_dir = Path(saving_dir)
+ # # save the passage embeddings
+ # index_path = saving_dir / self.INDEX_FILE_NAME
+ # logger.info(f"Saving passage embeddings to {index_path}")
+ # faiss.write_index(self.embeddings, str(index_path))
+ # # save the passage index
+ # documents_path = saving_dir / self.DOCUMENTS_FILE_NAME
+ # logger.info(f"Saving passage index to {documents_path}")
+ # self.documents.save(documents_path)
+
+ # @classmethod
+ # def load(
+ # cls,
+ # loading_dir: Union[str, os.PathLike],
+ # device: str = "cpu",
+ # document_file_name: Optional[str] = None,
+ # embedding_file_name: Optional[str] = None,
+ # index_file_name: Optional[str] = None,
+ # **kwargs,
+ # ) -> "FaissDocumentIndex":
+ # loading_dir = Path(loading_dir)
+
+ # document_file_name = document_file_name or cls.DOCUMENTS_FILE_NAME
+ # embedding_file_name = embedding_file_name or cls.EMBEDDINGS_FILE_NAME
+ # index_file_name = index_file_name or cls.INDEX_FILE_NAME
+
+ # # load the documents
+ # documents_path = loading_dir / document_file_name
+
+ # if not documents_path.exists():
+ # raise ValueError(f"Document file `{documents_path}` does not exist.")
+ # logger.info(f"Loading documents from {documents_path}")
+ # documents = Labels.from_file(documents_path)
+
+ # index = None
+ # embeddings = None
+ # # try to load the index directly
+ # index_path = loading_dir / index_file_name
+ # if not index_path.exists():
+ # # try to load the embeddings
+ # embedding_path = loading_dir / embedding_file_name
+ # # run some checks
+ # if embedding_path.exists():
+ # logger.info(f"Loading embeddings from {embedding_path}")
+ # embeddings = torch.load(embedding_path, map_location="cpu")
+ # logger.warning(
+ # f"Index file `{index_path}` and embedding file `{embedding_path}` do not exist."
+ # )
+ # else:
+ # logger.info(f"Loading index from {index_path}")
+ # index = faiss.read_index(str(embedding_path))
+
+ # return cls(
+ # documents=documents,
+ # embeddings=embeddings,
+ # index=index,
+ # device=device,
+ # **kwargs,
+ # )
+
+ def get_embeddings_from_index(
+ self, index: int
+ ) -> Union[torch.Tensor, numpy.ndarray]:
+ """
+ Get the document vector from the index.
+
+ Args:
+ index (`int`):
+ The index of the document.
+
+ Returns:
+ `torch.Tensor`: The document vector.
+ """
+ if self.embeddings is None:
+ raise ValueError(
+ "The documents must be indexed before they can be retrieved."
+ )
+ if index >= self.embeddings.ntotal:
+ raise ValueError(
+ f"The index {index} is out of bounds. The maximum index is {self.embeddings.ntotal}."
+ )
+ return self.embeddings.reconstruct(index)
diff --git a/relik/retriever/indexers/inmemory.py b/relik/retriever/indexers/inmemory.py
new file mode 100644
index 0000000000000000000000000000000000000000..42926991df7fecd3451dbf18c6b52b67983070b9
--- /dev/null
+++ b/relik/retriever/indexers/inmemory.py
@@ -0,0 +1,275 @@
+import contextlib
+import logging
+import os
+from typing import Callable, List, Optional, Tuple, Union
+
+import torch
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from relik.common.log import get_logger
+from relik.retriever.common.model_inputs import ModelInputs
+from relik.retriever.data.base.datasets import BaseDataset
+from relik.retriever.data.labels import Labels
+from relik.retriever.indexers.base import BaseDocumentIndex
+from relik.retriever.pytorch_modules import PRECISION_MAP, RetrievedSample
+
+logger = get_logger(__name__, level=logging.INFO)
+
+
+class InMemoryDocumentIndex(BaseDocumentIndex):
+ DOCUMENTS_FILE_NAME = "documents.json"
+ EMBEDDINGS_FILE_NAME = "embeddings.pt"
+
+ def __init__(
+ self,
+ documents: Union[str, List[str], Labels, os.PathLike, List[os.PathLike]] = None,
+ embeddings: Optional[torch.Tensor] = None,
+ device: str = "cpu",
+ precision: Optional[str] = None,
+ name_or_dir: Optional[Union[str, os.PathLike]] = None,
+ *args,
+ **kwargs,
+ ) -> None:
+ """
+ An in-memory indexer.
+
+ Args:
+ documents (:obj:`Union[List[str], PassageManager]`):
+ The documents to be indexed.
+ embeddings (:obj:`Optional[torch.Tensor]`, `optional`, defaults to :obj:`None`):
+ The embeddings of the documents.
+ device (:obj:`str`, `optional`, defaults to "cpu"):
+ The device to be used for storing the embeddings.
+ """
+
+ super().__init__(documents, embeddings, name_or_dir)
+
+ if embeddings is not None and documents is not None:
+ logger.info("Both documents and embeddings are provided.")
+ if self.documents.get_label_size() != embeddings.shape[0]:
+ raise ValueError(
+ "The number of documents and embeddings must be the same."
+ )
+
+ # embeddings of the documents
+ self.embeddings = embeddings
+ # drop the local reference; `self.embeddings` now owns the tensor
+ del embeddings
+ # convert the embeddings to the desired precision
+ if precision is not None:
+ if (
+ self.embeddings is not None
+ and self.embeddings.dtype != PRECISION_MAP[precision]
+ ):
+ logger.info(
+ f"Index vectors are of type {self.embeddings.dtype}. "
+ f"Converting to {PRECISION_MAP[precision]}."
+ )
+ self.embeddings = self.embeddings.to(PRECISION_MAP[precision])
+ # move the embeddings to the desired device
+ if self.embeddings is not None and not self.embeddings.device == device:
+ self.embeddings = self.embeddings.to(device)
+
+ # device to store the embeddings
+ self.device = device
+ # precision to be used for the embeddings
+ self.precision = precision
+
+ @torch.no_grad()
+ @torch.inference_mode()
+ def index(
+ self,
+ retriever,
+ documents: Optional[List[str]] = None,
+ batch_size: int = 32,
+ num_workers: int = 4,
+ max_length: Optional[int] = None,
+ collate_fn: Optional[Callable] = None,
+ encoder_precision: Optional[Union[str, int]] = None,
+ compute_on_cpu: bool = False,
+ force_reindex: bool = False,
+ add_to_existing_index: bool = False,
+ ) -> "InMemoryDocumentIndex":
+ """
+ Index the documents using the encoder.
+
+ Args:
+ retriever (:obj:`torch.nn.Module`):
+ The encoder to be used for indexing.
+ documents (:obj:`List[str]`, `optional`, defaults to :obj:`None`):
+ The documents to be indexed.
+ batch_size (:obj:`int`, `optional`, defaults to 32):
+ The batch size to be used for indexing.
+ num_workers (:obj:`int`, `optional`, defaults to 4):
+ The number of workers to be used for indexing.
+ max_length (:obj:`int`, `optional`, defaults to None):
+ The maximum length of the input to the encoder.
+ collate_fn (:obj:`Callable`, `optional`, defaults to None):
+ The collate function to be used for batching.
+ encoder_precision (:obj:`Union[str, int]`, `optional`, defaults to None):
+ The precision to be used for the encoder.
+ compute_on_cpu (:obj:`bool`, `optional`, defaults to False):
+ Whether to compute the embeddings on CPU.
+ force_reindex (:obj:`bool`, `optional`, defaults to False):
+ Whether to force reindexing.
+ add_to_existing_index (:obj:`bool`, `optional`, defaults to False):
+ Whether to add the new documents to the existing index.
+
+ Returns:
+ :obj:`InMemoryIndexer`: The indexer object.
+ """
+
+ if documents is None and self.documents is None:
+ raise ValueError("Documents must be provided.")
+
+ if self.embeddings is not None and not force_reindex:
+ logger.info(
+ "Embeddings are already present and `force_reindex` is `False`. Skipping indexing."
+ )
+ if documents is None:
+ return self
+
+ if collate_fn is None:
+ tokenizer = retriever.passage_tokenizer
+
+ def collate_fn(x):
+ return ModelInputs(
+ tokenizer(
+ x,
+ padding=True,
+ return_tensors="pt",
+ truncation=True,
+ max_length=max_length or tokenizer.model_max_length,
+ )
+ )
+
+ if force_reindex:
+ if documents is not None:
+ self.documents.add_labels(documents)
+ data = [k for k in self.documents.get_labels()]
+
+ else:
+ if documents is not None:
+ data = [k for k in Labels(documents).get_labels()]
+ else:
+ return self
+
+ # if force_reindex:
+ # data = [k for k in self.documents.get_labels()]
+
+ dataloader = DataLoader(
+ BaseDataset(name="passage", data=data),
+ batch_size=batch_size,
+ shuffle=False,
+ num_workers=num_workers,
+ pin_memory=False,
+ collate_fn=collate_fn,
+ )
+
+ encoder = retriever.passage_encoder
+
+ # Create empty lists to store the passage embeddings and passage index
+ passage_embeddings: List[torch.Tensor] = []
+
+ encoder_device = "cpu" if compute_on_cpu else self.device
+
+ # torch.autocast only accepts bare device-type strings like "cpu" or "cuda",
+ # so strip any device index from the model device
+ device_type_for_autocast = str(encoder_device).split(":")[0]
+ # autocast on CPU only supports bfloat16, so skip it there
+ autocast_pssg_mngr = (
+ contextlib.nullcontext()
+ if device_type_for_autocast == "cpu"
+ else (
+ torch.autocast(
+ device_type=device_type_for_autocast,
+ dtype=PRECISION_MAP[encoder_precision],
+ )
+ )
+ )
+ with autocast_pssg_mngr:
+ # Iterate through each batch in the dataloader
+ for batch in tqdm(dataloader, desc="Indexing"):
+ # Move the batch to the device
+ batch: ModelInputs = batch.to(encoder_device)
+ # Compute the passage embeddings
+ passage_outs = encoder(**batch).pooler_output
+ # Append the passage embeddings to the list
+ if self.device == "cpu":
+ passage_embeddings.extend([c.detach().cpu() for c in passage_outs])
+ else:
+ passage_embeddings.extend([c for c in passage_outs])
+
+ # move the passage embeddings to the CPU if not already done
+ # the move to cpu and then to gpu is needed to avoid OOM when using mixed precision
+ if not self.device == "cpu": # this if is to avoid unnecessary moves
+ passage_embeddings = [c.detach().cpu() for c in passage_embeddings]
+ # stack it
+ passage_embeddings: torch.Tensor = torch.stack(passage_embeddings, dim=0)
+ # move the passage embeddings to the gpu if needed
+ if not self.device == "cpu":
+ passage_embeddings = passage_embeddings.to(PRECISION_MAP[self.precision])
+ passage_embeddings = passage_embeddings.to(self.device)
+ self.embeddings = passage_embeddings
+
+ # free up memory from the unused variable
+ del passage_embeddings
+
+ return self
+
+ @torch.no_grad()
+ @torch.inference_mode()
+ def search(self, query: torch.Tensor, k: int = 1) -> list[list[RetrievedSample]]:
+ """
+ Search the documents using the query.
+
+ Args:
+ query (:obj:`torch.Tensor`):
+ The query to be used for searching.
+ k (:obj:`int`, `optional`, defaults to 1):
+ The number of documents to be retrieved.
+
+ Returns:
+ :obj:`List[RetrievedSample]`: The retrieved documents.
+ """
+ # torch.autocast only accepts bare device-type strings like "cpu" or "cuda",
+ # so strip any device index from the index device
+ device_type_for_autocast = str(self.device).split(":")[0]
+ # autocast on CPU only supports bfloat16, so skip it there
+ autocast_pssg_mngr = (
+ contextlib.nullcontext()
+ if device_type_for_autocast == "cpu"
+ else (
+ torch.autocast(
+ device_type=device_type_for_autocast,
+ dtype=self.embeddings.dtype,
+ )
+ )
+ )
+ with autocast_pssg_mngr:
+ similarity = torch.matmul(query, self.embeddings.T)
+ # Retrieve the indices of the top k passage embeddings
+ retriever_out: Tuple = torch.topk(
+ similarity, k=min(k, similarity.shape[-1]), dim=1
+ )
+ # get int values
+ batch_top_k: List[List[int]] = retriever_out.indices.detach().cpu().tolist()
+ # get float values
+ batch_scores: List[List[float]] = retriever_out.values.detach().cpu().tolist()
+ # Retrieve the passages corresponding to the indices
+ batch_passages = [
+ [self.documents.get_label_from_index(i) for i in indices]
+ for indices in batch_top_k
+ ]
+ # build the output object
+ batch_retrieved_samples = [
+ [
+ RetrievedSample(label=passage, index=index, score=score)
+ for passage, index, score in zip(passages, indices, scores)
+ ]
+ for passages, indices, scores in zip(
+ batch_passages, batch_top_k, batch_scores
+ )
+ ]
+ return batch_retrieved_samples
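+
+ # Usage sketch, assuming `retriever` is a GoldenRetriever and `query_embeddings`
+ # comes from its question encoder:
+ # index = InMemoryDocumentIndex(documents=["doc a", "doc b"], device="cpu")
+ # index.index(retriever, force_reindex=True) # encode and store the passages
+ # index.search(query_embeddings, k=2) # -> list of lists of RetrievedSample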
diff --git a/relik/retriever/lightning_modules/__init__.py b/relik/retriever/lightning_modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/relik/retriever/lightning_modules/pl_data_modules.py b/relik/retriever/lightning_modules/pl_data_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c69f5f291789cb7b473dd8466f7373b1a010e8a
--- /dev/null
+++ b/relik/retriever/lightning_modules/pl_data_modules.py
@@ -0,0 +1,121 @@
+from typing import Any, List, Optional, Sequence, Union
+
+import hydra
+import lightning as pl
+import torch
+from lightning.pytorch.utilities.types import EVAL_DATALOADERS
+from omegaconf import DictConfig
+from torch.utils.data import DataLoader
+
+from relik.common.log import get_logger
+from relik.retriever.data.datasets import GoldenRetrieverDataset
+
+logger = get_logger()
+
+
+class GoldenRetrieverPLDataModule(pl.LightningDataModule):
+ def __init__(
+ self,
+ train_dataset: Optional[GoldenRetrieverDataset] = None,
+ val_datasets: Optional[Sequence[GoldenRetrieverDataset]] = None,
+ test_datasets: Optional[Sequence[GoldenRetrieverDataset]] = None,
+ num_workers: Optional[Union[DictConfig, int]] = None,
+ datasets: Optional[DictConfig] = None,
+ *args,
+ **kwargs,
+ ):
+ super().__init__()
+ self.datasets = datasets
+ if num_workers is None:
+ num_workers = 0
+ if isinstance(num_workers, int):
+ num_workers = DictConfig(
+ {"train": num_workers, "val": num_workers, "test": num_workers}
+ )
+ self.num_workers = num_workers
+ # data
+ self.train_dataset: Optional[GoldenRetrieverDataset] = train_dataset
+ self.val_datasets: Optional[Sequence[GoldenRetrieverDataset]] = val_datasets
+ self.test_datasets: Optional[Sequence[GoldenRetrieverDataset]] = test_datasets
+
+ def prepare_data(self, *args, **kwargs):
+ """
+ Method for preparing the data before the training. This method is called only once.
+ It is used to download the data, tokenize the data, etc.
+ """
+ pass
+
+ def setup(self, stage: Optional[str] = None):
+ if stage == "fit" or stage is None:
+            # usually there is only one dataset for training;
+            # if you need more than one train loader, you can follow
+            # the same logic as the val and test datasets
+ if self.train_dataset is None:
+ self.train_dataset = hydra.utils.instantiate(self.datasets.train)
+ self.val_datasets = [
+ hydra.utils.instantiate(dataset_cfg)
+ for dataset_cfg in self.datasets.val
+ ]
+ if stage == "test":
+ if self.test_datasets is None:
+ self.test_datasets = [
+ hydra.utils.instantiate(dataset_cfg)
+ for dataset_cfg in self.datasets.test
+ ]
+
+ def train_dataloader(self, *args, **kwargs) -> DataLoader:
+ torch_dataset = self.train_dataset.to_torch_dataset()
+ return DataLoader(
+ # self.train_dataset.to_torch_dataset(),
+ torch_dataset,
+ shuffle=False,
+ batch_size=None,
+ num_workers=self.num_workers.train,
+ pin_memory=True,
+ collate_fn=lambda x: x,
+ )
+
+ def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
+ dataloaders = []
+ for dataset in self.val_datasets:
+ torch_dataset = dataset.to_torch_dataset()
+ dataloaders.append(
+ DataLoader(
+ torch_dataset,
+ shuffle=False,
+ batch_size=None,
+ num_workers=self.num_workers.val,
+ pin_memory=True,
+ collate_fn=lambda x: x,
+ )
+ )
+ return dataloaders
+
+ def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
+ dataloaders = []
+ for dataset in self.test_datasets:
+ torch_dataset = dataset.to_torch_dataset()
+ dataloaders.append(
+ DataLoader(
+ torch_dataset,
+ shuffle=False,
+ batch_size=None,
+ num_workers=self.num_workers.test,
+ pin_memory=True,
+ collate_fn=lambda x: x,
+ )
+ )
+ return dataloaders
+
+ def predict_dataloader(self) -> EVAL_DATALOADERS:
+ raise NotImplementedError
+
+ def transfer_batch_to_device(
+ self, batch: Any, device: torch.device, dataloader_idx: int
+ ) -> Any:
+ return super().transfer_batch_to_device(batch, device, dataloader_idx)
+
+ def __repr__(self) -> str:
+        return (
+            f"{self.__class__.__name__}("
+            f"{self.datasets=}, "
+            f"{self.num_workers=})"
+        )
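+
+
+# NOTE: a minimal usage sketch (the dataset objects are assumptions, not created
+# here). `num_workers` can be a plain int; it is normalized to a per-split
+# DictConfig in `__init__`:
+#
+#     dm = GoldenRetrieverPLDataModule(
+#         train_dataset=train_ds,      # a GoldenRetrieverDataset
+#         val_datasets=[val_ds],
+#         num_workers=4,               # becomes {"train": 4, "val": 4, "test": 4}
+#     )
+#     dm.setup("fit")
+#     train_loader = dm.train_dataloader()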
diff --git a/relik/retriever/lightning_modules/pl_modules.py b/relik/retriever/lightning_modules/pl_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ef6ace6eb8ac95d92e16b1e28dc50d2492b33f4
--- /dev/null
+++ b/relik/retriever/lightning_modules/pl_modules.py
@@ -0,0 +1,123 @@
+from typing import Any, Optional, Union
+
+import hydra
+import lightning as pl
+import torch
+from omegaconf import DictConfig
+
+from relik.retriever.common.model_inputs import ModelInputs
+
+
+class GoldenRetrieverPLModule(pl.LightningModule):
+ def __init__(
+ self,
+ model: Union[torch.nn.Module, DictConfig],
+ optimizer: Union[torch.optim.Optimizer, DictConfig],
+        lr_scheduler: Optional[
+            Union[torch.optim.lr_scheduler.LRScheduler, DictConfig]
+        ] = None,
+ *args,
+ **kwargs,
+ ) -> None:
+ super().__init__()
+ self.save_hyperparameters(ignore=["model"])
+ if isinstance(model, DictConfig):
+ self.model = hydra.utils.instantiate(model)
+ else:
+ self.model = model
+
+ self.optimizer_config = optimizer
+ self.lr_scheduler_config = lr_scheduler
+
+ def forward(self, **kwargs) -> dict:
+ """
+ Method for the forward pass.
+ 'training_step', 'validation_step' and 'test_step' should call
+ this method in order to compute the output predictions and the loss.
+
+ Returns:
+            output_dict: forward output containing the predictions (output logits, etc.) and the loss, if any.
+
+ """
+ return self.model(**kwargs)
+
+ def training_step(self, batch: ModelInputs, batch_idx: int) -> torch.Tensor:
+ forward_output = self.forward(**batch, return_loss=True)
+ self.log(
+ "loss",
+ forward_output["loss"],
+ batch_size=batch["questions"]["input_ids"].size(0),
+ prog_bar=True,
+ )
+ return forward_output["loss"]
+
+ def validation_step(self, batch: ModelInputs, batch_idx: int) -> None:
+ forward_output = self.forward(**batch, return_loss=True)
+ self.log(
+ "val_loss",
+ forward_output["loss"],
+ batch_size=batch["questions"]["input_ids"].size(0),
+ )
+
+ def test_step(self, batch: ModelInputs, batch_idx: int) -> Any:
+ forward_output = self.forward(**batch, return_loss=True)
+ self.log(
+ "test_loss",
+ forward_output["loss"],
+ batch_size=batch["questions"]["input_ids"].size(0),
+ )
+
+ def configure_optimizers(self):
+ if isinstance(self.optimizer_config, DictConfig):
+ param_optimizer = list(self.named_parameters())
+ no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
+ optimizer_grouped_parameters = [
+ {
+ "params": [
+ p for n, p in param_optimizer if "layer_norm_layer" in n
+ ],
+ "weight_decay": self.hparams.optimizer.weight_decay,
+ "lr": 1e-4,
+ },
+ {
+ "params": [
+ p
+ for n, p in param_optimizer
+ if all(nd not in n for nd in no_decay)
+ and "layer_norm_layer" not in n
+ ],
+ "weight_decay": self.hparams.optimizer.weight_decay,
+ },
+ {
+ "params": [
+ p
+ for n, p in param_optimizer
+ if "layer_norm_layer" not in n
+ and any(nd in n for nd in no_decay)
+ ],
+ "weight_decay": 0.0,
+ },
+ ]
+ optimizer = hydra.utils.instantiate(
+ self.optimizer_config,
+ # params=self.parameters(),
+ params=optimizer_grouped_parameters,
+ _convert_="partial",
+ )
+ else:
+ optimizer = self.optimizer_config
+
+ if self.lr_scheduler_config is None:
+ return optimizer
+
+ if isinstance(self.lr_scheduler_config, DictConfig):
+ lr_scheduler = hydra.utils.instantiate(
+ self.lr_scheduler_config, optimizer=optimizer
+ )
+ else:
+ lr_scheduler = self.lr_scheduler_config
+
+ lr_scheduler_config = {
+ "scheduler": lr_scheduler,
+ "interval": "step",
+ "frequency": 1,
+ }
+ return [optimizer], [lr_scheduler_config]
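+
+
+# NOTE: a minimal usage sketch (the retriever object is an assumption). `model`,
+# `optimizer` and `lr_scheduler` may each be passed either as an instantiated
+# object or as a Hydra DictConfig that is instantiated lazily:
+#
+#     pl_module = GoldenRetrieverPLModule(
+#         model=my_retriever,                              # nn.Module or DictConfig
+#         optimizer=torch.optim.AdamW(my_retriever.parameters(), lr=1e-5),
+#         lr_scheduler=None,                               # optional
+#     )
+#     # trainer.fit(pl_module, datamodule=data_module)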
diff --git a/relik/retriever/pytorch_modules/__init__.py b/relik/retriever/pytorch_modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..01752b8aa79367a7bcdc2d18438ae87bebdd87f2
--- /dev/null
+++ b/relik/retriever/pytorch_modules/__init__.py
@@ -0,0 +1,28 @@
+from dataclasses import dataclass
+
+import torch
+
+PRECISION_MAP = {
+ None: torch.float32,
+ 16: torch.float16,
+ 32: torch.float32,
+ "float16": torch.float16,
+ "float32": torch.float32,
+ "half": torch.float16,
+ "float": torch.float32,
+ "16": torch.float16,
+ "32": torch.float32,
+ "fp16": torch.float16,
+ "fp32": torch.float32,
+}
+
+
+@dataclass
+class RetrievedSample:
+ """
+ Dataclass for the output of the GoldenRetriever model.
+ """
+
+ score: float
+ index: int
+ label: str
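+
+
+# NOTE: illustration only. PRECISION_MAP normalizes the different ways a precision
+# can be specified (int, string, fp-style alias) to a torch dtype, e.g.:
+#
+#     PRECISION_MAP[16] is torch.float16
+#     PRECISION_MAP["fp16"] is torch.float16
+#     PRECISION_MAP[None] is torch.float32   # default when no precision is given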
diff --git a/relik/retriever/pytorch_modules/hf.py b/relik/retriever/pytorch_modules/hf.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5868d67ce5ed97a1e66c6d1d3a606350e75267a
--- /dev/null
+++ b/relik/retriever/pytorch_modules/hf.py
@@ -0,0 +1,88 @@
+from typing import Tuple, Union
+
+import torch
+from transformers import PretrainedConfig
+from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
+from transformers.models.bert.modeling_bert import BertModel
+
+
+class GoldenRetrieverConfig(PretrainedConfig):
+ model_type = "bert"
+
+ def __init__(
+ self,
+ vocab_size=30522,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=512,
+ type_vocab_size=2,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ pad_token_id=0,
+ position_embedding_type="absolute",
+ use_cache=True,
+ classifier_dropout=None,
+ **kwargs,
+ ):
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
+
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.hidden_act = hidden_act
+ self.intermediate_size = intermediate_size
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.type_vocab_size = type_vocab_size
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.position_embedding_type = position_embedding_type
+ self.use_cache = use_cache
+ self.classifier_dropout = classifier_dropout
+
+
+class GoldenRetrieverModel(BertModel):
+ config_class = GoldenRetrieverConfig
+
+ def __init__(self, config, *args, **kwargs):
+ super().__init__(config)
+ self.layer_norm_layer = torch.nn.LayerNorm(
+ config.hidden_size, eps=config.layer_norm_eps
+ )
+
+ def forward(
+ self, **kwargs
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+ attention_mask = kwargs.get("attention_mask", None)
+ model_outputs = super().forward(**kwargs)
+ if attention_mask is None:
+ pooler_output = model_outputs.pooler_output
+ else:
+ token_embeddings = model_outputs.last_hidden_state
+ input_mask_expanded = (
+ attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+ )
+ pooler_output = torch.sum(
+ token_embeddings * input_mask_expanded, 1
+ ) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+ pooler_output = self.layer_norm_layer(pooler_output)
+
+ if not kwargs.get("return_dict", True):
+ return (model_outputs[0], pooler_output) + model_outputs[2:]
+
+ return BaseModelOutputWithPoolingAndCrossAttentions(
+ last_hidden_state=model_outputs.last_hidden_state,
+ pooler_output=pooler_output,
+ past_key_values=model_outputs.past_key_values,
+ hidden_states=model_outputs.hidden_states,
+ attentions=model_outputs.attentions,
+ cross_attentions=model_outputs.cross_attentions,
+ )
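+
+
+# NOTE: a sketch of the pooling performed above (illustration only). When an
+# attention mask is provided, the pooler output is a mask-weighted mean of the
+# token embeddings followed by LayerNorm, roughly:
+#
+#     mask = attention_mask.unsqueeze(-1).float()                      # (B, L, 1)
+#     pooled = (last_hidden_state * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
+#     pooled = layer_norm_layer(pooled)                                # (B, H)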
diff --git a/relik/retriever/pytorch_modules/loss.py b/relik/retriever/pytorch_modules/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..643d3a486ca73ca38486094553b357f5e7c28adb
--- /dev/null
+++ b/relik/retriever/pytorch_modules/loss.py
@@ -0,0 +1,34 @@
+from typing import Optional
+
+import torch
+from torch.nn.modules.loss import _WeightedLoss
+
+
+class MultiLabelNCELoss(_WeightedLoss):
+ __constants__ = ["reduction"]
+
+ def __init__(
+ self,
+ weight: Optional[torch.Tensor] = None,
+ size_average=None,
+ reduction: Optional[str] = "mean",
+ ) -> None:
+ super(MultiLabelNCELoss, self).__init__(weight, size_average, None, reduction)
+
+ def forward(
+ self, input: torch.Tensor, target: torch.Tensor, ignore_index: int = -100
+ ) -> torch.Tensor:
+ gold_scores = input.masked_fill(~(target.bool()), 0)
+ gold_scores_sum = gold_scores.sum(-1) # B x C
+ neg_logits = input.masked_fill(target.bool(), float("-inf")) # B x C x L
+ neg_log_sum_exp = torch.logsumexp(neg_logits, -1, keepdim=True) # B x C x 1
+ norm_term = (
+ torch.logaddexp(input, neg_log_sum_exp)
+ .masked_fill(~(target.bool()), 0)
+ .sum(-1)
+ )
+ gold_log_probs = gold_scores_sum - norm_term
+ loss = -gold_log_probs.sum()
+ if self.reduction == "mean":
+ loss /= input.size(0)
+ return loss
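+
+
+# NOTE: a sketch of what the loss above computes (illustration only). For every
+# gold passage g of a query, with N the set of non-gold (negative) passages,
+#
+#     log p(g) = s_g - log( exp(s_g) + sum_{n in N} exp(s_n) )
+#
+# the per-query loss is -sum_g log p(g); with reduction="mean" it is averaged
+# over the batch dimension (input.size(0)).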
diff --git a/relik/retriever/pytorch_modules/model.py b/relik/retriever/pytorch_modules/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..f02aedd5b43cbd789b87ae4afd417c919b72b129
--- /dev/null
+++ b/relik/retriever/pytorch_modules/model.py
@@ -0,0 +1,533 @@
+import contextlib
+import logging
+import os
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Callable, Dict, List, Optional, Union
+
+import torch
+import torch.nn.functional as F
+import transformers as tr
+
+from relik.common.log import get_console_logger, get_logger
+from relik.retriever.common.model_inputs import ModelInputs
+from relik.retriever.data.labels import Labels
+from relik.retriever.indexers.base import BaseDocumentIndex
+from relik.retriever.indexers.inmemory import InMemoryDocumentIndex
+from relik.retriever.pytorch_modules import PRECISION_MAP, RetrievedSample
+from relik.retriever.pytorch_modules.hf import GoldenRetrieverModel
+
+console_logger = get_console_logger()
+logger = get_logger(__name__, level=logging.INFO)
+
+
+@dataclass
+class GoldenRetrieverOutput(tr.file_utils.ModelOutput):
+ """Class for model's outputs."""
+
+ logits: Optional[torch.FloatTensor] = None
+ loss: Optional[torch.FloatTensor] = None
+ question_encodings: Optional[torch.FloatTensor] = None
+ passages_encodings: Optional[torch.FloatTensor] = None
+
+
+class GoldenRetriever(torch.nn.Module):
+ def __init__(
+ self,
+ question_encoder: Union[str, tr.PreTrainedModel],
+ loss_type: Optional[torch.nn.Module] = None,
+ passage_encoder: Optional[Union[str, tr.PreTrainedModel]] = None,
+ document_index: Optional[Union[str, BaseDocumentIndex]] = None,
+ question_tokenizer: Optional[Union[str, tr.PreTrainedTokenizer]] = None,
+ passage_tokenizer: Optional[Union[str, tr.PreTrainedTokenizer]] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ precision: Optional[Union[str, int]] = None,
+ index_precision: Optional[Union[str, int]] = 32,
+ index_device: Optional[Union[str, torch.device]] = "cpu",
+ *args,
+ **kwargs,
+ ):
+ super().__init__()
+
+ self.passage_encoder_is_question_encoder = False
+ # question encoder model
+ if isinstance(question_encoder, str):
+ question_encoder = GoldenRetrieverModel.from_pretrained(
+ question_encoder, **kwargs
+ )
+ self.question_encoder = question_encoder
+ if passage_encoder is None:
+ # if no passage encoder is provided,
+ # share the weights of the question encoder
+ passage_encoder = question_encoder
+ # keep track of the fact that the passage encoder is the same as the question encoder
+ self.passage_encoder_is_question_encoder = True
+ if isinstance(passage_encoder, str):
+ passage_encoder = GoldenRetrieverModel.from_pretrained(
+ passage_encoder, **kwargs
+ )
+ # passage encoder model
+ self.passage_encoder = passage_encoder
+
+ # loss function
+ self.loss_type = loss_type
+
+ # indexer stuff
+ if document_index is None:
+ # if no indexer is provided, create a new one
+ document_index = InMemoryDocumentIndex(
+ device=index_device, precision=index_precision, **kwargs
+ )
+ if isinstance(document_index, str):
+ document_index = BaseDocumentIndex.from_pretrained(
+ document_index, device=index_device, precision=index_precision, **kwargs
+ )
+ self.document_index = document_index
+
+ # lazy load the tokenizer for inference
+ self._question_tokenizer = question_tokenizer
+ self._passage_tokenizer = passage_tokenizer
+
+ # move the model to the device
+ self.to(device or torch.device("cpu"))
+
+ # set the precision
+ self.precision = precision
+
+ def forward(
+ self,
+ questions: Optional[Dict[str, torch.Tensor]] = None,
+ passages: Optional[Dict[str, torch.Tensor]] = None,
+ labels: Optional[torch.Tensor] = None,
+ question_encodings: Optional[torch.Tensor] = None,
+ passages_encodings: Optional[torch.Tensor] = None,
+ passages_per_question: Optional[List[int]] = None,
+ return_loss: bool = False,
+ return_encodings: bool = False,
+ *args,
+ **kwargs,
+ ) -> GoldenRetrieverOutput:
+ """
+ Forward pass of the model.
+
+ Args:
+ questions (`Dict[str, torch.Tensor]`):
+ The questions to encode.
+ passages (`Dict[str, torch.Tensor]`):
+ The passages to encode.
+ labels (`torch.Tensor`):
+ The labels of the sentences.
+ question_encodings (`torch.Tensor`):
+ The encodings of the questions.
+ passages_encodings (`torch.Tensor`):
+ The encodings of the passages.
+ passages_per_question (`List[int]`):
+ The number of passages per question.
+ return_loss (`bool`):
+ Whether to compute the loss.
+ return_encodings (`bool`):
+ Whether to return the encodings.
+
+ Returns:
+            :obj:`GoldenRetrieverOutput`: The outputs of the model.
+ """
+ if questions is None and question_encodings is None:
+ raise ValueError(
+ "Either `questions` or `question_encodings` must be provided"
+ )
+ if passages is None and passages_encodings is None:
+ raise ValueError(
+ "Either `passages` or `passages_encodings` must be provided"
+ )
+
+ if question_encodings is None:
+ question_encodings = self.question_encoder(**questions).pooler_output
+ if passages_encodings is None:
+ passages_encodings = self.passage_encoder(**passages).pooler_output
+
+ if passages_per_question is not None:
+ # multiply each question encoding with a passages_per_question encodings
+ concatenated_passages = torch.stack(
+ torch.split(passages_encodings, passages_per_question)
+ ).transpose(1, 2)
+ if isinstance(self.loss_type, torch.nn.BCEWithLogitsLoss):
+ # normalize the encodings for cosine similarity
+ concatenated_passages = F.normalize(concatenated_passages, p=2, dim=2)
+ question_encodings = F.normalize(question_encodings, p=2, dim=1)
+ logits = torch.bmm(
+ question_encodings.unsqueeze(1), concatenated_passages
+ ).view(question_encodings.shape[0], -1)
+ else:
+ if isinstance(self.loss_type, torch.nn.BCEWithLogitsLoss):
+ # normalize the encodings for cosine similarity
+ question_encodings = F.normalize(question_encodings, p=2, dim=1)
+ passages_encodings = F.normalize(passages_encodings, p=2, dim=1)
+
+ logits = torch.matmul(question_encodings, passages_encodings.T)
+
+ output = dict(logits=logits)
+
+ if return_loss and labels is not None:
+ if self.loss_type is None:
+ raise ValueError(
+ "If `return_loss` is set to `True`, `loss_type` must be provided"
+ )
+ if isinstance(self.loss_type, torch.nn.NLLLoss):
+ labels = labels.argmax(dim=1)
+ logits = F.log_softmax(logits, dim=1)
+ if len(question_encodings.size()) > 1:
+ logits = logits.view(question_encodings.size(0), -1)
+
+ output["loss"] = self.loss_type(logits, labels)
+
+ if return_encodings:
+ output["question_encodings"] = question_encodings
+ output["passages_encodings"] = passages_encodings
+
+ return GoldenRetrieverOutput(**output)
+
+ @torch.no_grad()
+ @torch.inference_mode()
+ def index(
+ self,
+ batch_size: int = 32,
+ num_workers: int = 4,
+ max_length: Optional[int] = None,
+ collate_fn: Optional[Callable] = None,
+ force_reindex: bool = False,
+ compute_on_cpu: bool = False,
+ precision: Optional[Union[str, int]] = None,
+ ):
+ """
+ Index the passages for later retrieval.
+
+ Args:
+ batch_size (`int`):
+ The batch size to use for the indexing.
+ num_workers (`int`):
+ The number of workers to use for the indexing.
+ max_length (`Optional[int]`):
+ The maximum length of the passages.
+ collate_fn (`Callable`):
+ The collate function to use for the indexing.
+ force_reindex (`bool`):
+ Whether to force reindexing even if the passages are already indexed.
+ compute_on_cpu (`bool`):
+ Whether to move the index to the CPU after the indexing.
+ precision (`Optional[Union[str, int]]`):
+ The precision to use for the model.
+ """
+ if self.document_index is None:
+ raise ValueError(
+ "The retriever must be initialized with an indexer to index "
+ "the passages within the retriever."
+ )
+ return self.document_index.index(
+ retriever=self,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ max_length=max_length,
+ collate_fn=collate_fn,
+ encoder_precision=precision or self.precision,
+ compute_on_cpu=compute_on_cpu,
+ force_reindex=force_reindex,
+ )
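+
+    # NOTE: a rough usage sketch (names are assumptions). Assuming `retriever`
+    # was built with a document index that already holds the passages to embed,
+    #
+    #     retriever.index(batch_size=64, precision="fp16", force_reindex=True)
+    #
+    # delegates to `self.document_index.index(...)` and caches the passage
+    # embeddings there for later `retrieve` calls.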
+
+ @torch.no_grad()
+ @torch.inference_mode()
+ def retrieve(
+ self,
+ text: Optional[Union[str, List[str]]] = None,
+ text_pair: Optional[Union[str, List[str]]] = None,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ k: Optional[int] = None,
+ max_length: Optional[int] = None,
+ precision: Optional[Union[str, int]] = None,
+ ) -> List[List[RetrievedSample]]:
+ """
+ Retrieve the passages for the questions.
+
+ Args:
+ text (`Optional[Union[str, List[str]]]`):
+ The questions to retrieve the passages for.
+ text_pair (`Optional[Union[str, List[str]]]`):
+ The questions to retrieve the passages for.
+ input_ids (`torch.Tensor`):
+ The input ids of the questions.
+ attention_mask (`torch.Tensor`):
+ The attention mask of the questions.
+ token_type_ids (`torch.Tensor`):
+ The token type ids of the questions.
+ k (`int`):
+ The number of top passages to retrieve.
+ max_length (`Optional[int]`):
+ The maximum length of the questions.
+ precision (`Optional[Union[str, int]]`):
+ The precision to use for the model.
+
+ Returns:
+ `List[List[RetrievedSample]]`: The retrieved passages and their indices.
+ """
+ if self.document_index is None:
+ raise ValueError(
+                "The retriever must be initialized with a document index "
+                "to retrieve passages."
+ )
+ if text is None and input_ids is None:
+ raise ValueError(
+ "Either `text` or `input_ids` must be provided to retrieve the passages."
+ )
+
+ if text is not None:
+ if isinstance(text, str):
+ text = [text]
+ if text_pair is not None and isinstance(text_pair, str):
+ text_pair = [text_pair]
+ tokenizer = self.question_tokenizer
+ model_inputs = ModelInputs(
+ tokenizer(
+ text,
+ text_pair=text_pair,
+ padding=True,
+ return_tensors="pt",
+ truncation=True,
+ max_length=max_length or tokenizer.model_max_length,
+ )
+ )
+ else:
+ model_inputs = ModelInputs(dict(input_ids=input_ids))
+ if attention_mask is not None:
+ model_inputs["attention_mask"] = attention_mask
+ if token_type_ids is not None:
+ model_inputs["token_type_ids"] = token_type_ids
+
+ model_inputs.to(self.device)
+
+        # torch.autocast only accepts plain device-type strings such as 'cpu' or 'cuda',
+        # so strip any device index from the model device
+        device_type_for_autocast = str(self.device).split(":")[0]
+        # on CPU, autocast only supports bfloat16, so fall back to a no-op context
+ autocast_pssg_mngr = (
+ contextlib.nullcontext()
+ if device_type_for_autocast == "cpu"
+ else (
+ torch.autocast(
+ device_type=device_type_for_autocast,
+                    dtype=PRECISION_MAP[precision or self.precision],
+ )
+ )
+ )
+ with autocast_pssg_mngr:
+ question_encodings = self.question_encoder(**model_inputs).pooler_output
+
+ # TODO: fix if encoder and index are on different device
+ return self.document_index.search(question_encodings, k)
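+
+    # NOTE: a rough usage sketch (the paths are placeholders, not real checkpoints):
+    #
+    #     retriever = GoldenRetriever(
+    #         question_encoder="path/to/question_encoder",
+    #         document_index="path/to/document_index",
+    #     )
+    #     retriever.index()
+    #     samples = retriever.retrieve("who wrote the odyssey?", k=5)
+    #     print([s.label for s in samples[0]])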
+
+ def get_index_from_passage(self, passage: str) -> int:
+ """
+ Get the index of the passage.
+
+ Args:
+ passage (`str`):
+ The passage to get the index for.
+
+ Returns:
+ `int`: The index of the passage.
+ """
+ if self.document_index is None:
+ raise ValueError(
+ "The passages must be indexed before they can be retrieved."
+ )
+ return self.document_index.get_index_from_passage(passage)
+
+ def get_passage_from_index(self, index: int) -> str:
+ """
+ Get the passage from the index.
+
+ Args:
+ index (`int`):
+ The index of the passage.
+
+ Returns:
+ `str`: The passage.
+ """
+ if self.document_index is None:
+ raise ValueError(
+ "The passages must be indexed before they can be retrieved."
+ )
+ return self.document_index.get_passage_from_index(index)
+
+ def get_vector_from_index(self, index: int) -> torch.Tensor:
+ """
+ Get the passage vector from the index.
+
+ Args:
+ index (`int`):
+ The index of the passage.
+
+ Returns:
+ `torch.Tensor`: The passage vector.
+ """
+ if self.document_index is None:
+ raise ValueError(
+ "The passages must be indexed before they can be retrieved."
+ )
+ return self.document_index.get_embeddings_from_index(index)
+
+ def get_vector_from_passage(self, passage: str) -> torch.Tensor:
+ """
+ Get the passage vector from the passage.
+
+ Args:
+ passage (`str`):
+ The passage.
+
+ Returns:
+ `torch.Tensor`: The passage vector.
+ """
+ if self.document_index is None:
+ raise ValueError(
+ "The passages must be indexed before they can be retrieved."
+ )
+ return self.document_index.get_embeddings_from_passage(passage)
+
+ @property
+ def passage_embeddings(self) -> torch.Tensor:
+ """
+ The passage embeddings.
+ """
+ return self._passage_embeddings
+
+ @property
+ def passage_index(self) -> Labels:
+ """
+ The passage index.
+ """
+ return self._passage_index
+
+ @property
+ def device(self) -> torch.device:
+ """
+ The device of the model.
+ """
+ return next(self.parameters()).device
+
+ @property
+ def question_tokenizer(self) -> tr.PreTrainedTokenizer:
+ """
+ The question tokenizer.
+ """
+ if self._question_tokenizer:
+ return self._question_tokenizer
+
+        if (
+            self.question_encoder.config.name_or_path
+            == self.passage_encoder.config.name_or_path
+        ):
+ if not self._question_tokenizer:
+ self._question_tokenizer = tr.AutoTokenizer.from_pretrained(
+ self.question_encoder.config.name_or_path
+ )
+ self._passage_tokenizer = self._question_tokenizer
+ return self._question_tokenizer
+
+ if not self._question_tokenizer:
+ self._question_tokenizer = tr.AutoTokenizer.from_pretrained(
+ self.question_encoder.config.name_or_path
+ )
+ return self._question_tokenizer
+
+ @property
+ def passage_tokenizer(self) -> tr.PreTrainedTokenizer:
+ """
+ The passage tokenizer.
+ """
+ if self._passage_tokenizer:
+ return self._passage_tokenizer
+
+ if (
+ self.question_encoder.config.name_or_path
+ == self.passage_encoder.config.name_or_path
+ ):
+ if not self._question_tokenizer:
+ self._question_tokenizer = tr.AutoTokenizer.from_pretrained(
+ self.question_encoder.config.name_or_path
+ )
+ self._passage_tokenizer = self._question_tokenizer
+ return self._passage_tokenizer
+
+ if not self._passage_tokenizer:
+ self._passage_tokenizer = tr.AutoTokenizer.from_pretrained(
+ self.passage_encoder.config.name_or_path
+ )
+ return self._passage_tokenizer
+
+ def save_pretrained(
+ self,
+ output_dir: Union[str, os.PathLike],
+ question_encoder_name: Optional[str] = None,
+ passage_encoder_name: Optional[str] = None,
+ document_index_name: Optional[str] = None,
+ push_to_hub: bool = False,
+ **kwargs,
+ ):
+ """
+ Save the retriever to a directory.
+
+ Args:
+ output_dir (`str`):
+ The directory to save the retriever to.
+ question_encoder_name (`Optional[str]`):
+ The name of the question encoder.
+ passage_encoder_name (`Optional[str]`):
+ The name of the passage encoder.
+ document_index_name (`Optional[str]`):
+ The name of the document index.
+ push_to_hub (`bool`):
+ Whether to push the model to the hub.
+ """
+
+ # create the output directory
+ output_dir = Path(output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ logger.info(f"Saving retriever to {output_dir}")
+
+ question_encoder_name = question_encoder_name or "question_encoder"
+ passage_encoder_name = passage_encoder_name or "passage_encoder"
+ document_index_name = document_index_name or "document_index"
+
+ logger.info(
+ f"Saving question encoder state to {output_dir / question_encoder_name}"
+ )
+ # self.question_encoder.config._name_or_path = question_encoder_name
+ self.question_encoder.register_for_auto_class()
+ self.question_encoder.save_pretrained(
+ output_dir / question_encoder_name, push_to_hub=push_to_hub, **kwargs
+ )
+ self.question_tokenizer.save_pretrained(
+ output_dir / question_encoder_name, push_to_hub=push_to_hub, **kwargs
+ )
+ if not self.passage_encoder_is_question_encoder:
+ logger.info(
+ f"Saving passage encoder state to {output_dir / passage_encoder_name}"
+ )
+ # self.passage_encoder.config._name_or_path = passage_encoder_name
+ self.passage_encoder.register_for_auto_class()
+ self.passage_encoder.save_pretrained(
+ output_dir / passage_encoder_name, push_to_hub=push_to_hub, **kwargs
+ )
+ self.passage_tokenizer.save_pretrained(
+ output_dir / passage_encoder_name, push_to_hub=push_to_hub, **kwargs
+ )
+
+ if self.document_index is not None:
+ # save the indexer
+ self.document_index.save_pretrained(
+ output_dir / document_index_name, push_to_hub=push_to_hub, **kwargs
+ )
+
+ logger.info("Saving retriever to disk done.")
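+
+    # NOTE: a rough sketch of the resulting layout (the output path is a
+    # placeholder; the directory names are the defaults above):
+    #
+    #     retriever.save_pretrained("outputs/my_retriever")
+    #     # outputs/my_retriever/
+    #     #   question_encoder/   encoder weights + tokenizer
+    #     #   passage_encoder/    only if not shared with the question encoder
+    #     #   document_index/     the document index, if any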
diff --git a/relik/retriever/pytorch_modules/optim.py b/relik/retriever/pytorch_modules/optim.py
new file mode 100644
index 0000000000000000000000000000000000000000..e815acf44948a4835d7b093c74e6ca8cf8539dd1
--- /dev/null
+++ b/relik/retriever/pytorch_modules/optim.py
@@ -0,0 +1,213 @@
+import math
+
+import torch
+from torch.optim import Optimizer
+
+
+class RAdamW(Optimizer):
+    r"""Implements the RAdamW algorithm.
+
+    RAdam from `On the Variance of the Adaptive Learning Rate and Beyond
+    <https://arxiv.org/abs/1908.03265>`_ with AdamW-style decoupled weight decay.
+
+    * `Adam: A Method for Stochastic Optimization
+      <https://arxiv.org/abs/1412.6980>`_
+    * `Decoupled Weight Decay Regularization
+      <https://arxiv.org/abs/1711.05101>`_
+    * `On the Convergence of Adam and Beyond
+      <https://arxiv.org/abs/1904.09237>`_
+    * `On the Variance of the Adaptive Learning Rate and Beyond
+      <https://arxiv.org/abs/1908.03265>`_
+
+ Arguments:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups
+ lr (float, optional): learning rate (default: 1e-3)
+ betas (Tuple[float, float], optional): coefficients used for computing
+ running averages of gradient and its square (default: (0.9, 0.999))
+ eps (float, optional): term added to the denominator to improve
+ numerical stability (default: 1e-8)
+ weight_decay (float, optional): weight decay coefficient (default: 1e-2)
+ """
+
+ def __init__(
+ self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2
+ ):
+ if not 0.0 <= lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 <= eps:
+ raise ValueError("Invalid epsilon value: {}".format(eps))
+ if not 0.0 <= betas[0] < 1.0:
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+ if not 0.0 <= betas[1] < 1.0:
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+ defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+ super(RAdamW, self).__init__(params, defaults)
+
+ def step(self, closure=None):
+ """Performs a single optimization step.
+
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ loss = closure()
+
+ for group in self.param_groups:
+ for p in group["params"]:
+ if p.grad is None:
+ continue
+
+ # Perform optimization step
+ grad = p.grad.data
+ if grad.is_sparse:
+ raise RuntimeError(
+ "Adam does not support sparse gradients, please consider SparseAdam instead"
+ )
+
+ state = self.state[p]
+
+ # State initialization
+ if len(state) == 0:
+ state["step"] = 0
+ # Exponential moving average of gradient values
+ state["exp_avg"] = torch.zeros_like(p.data)
+ # Exponential moving average of squared gradient values
+ state["exp_avg_sq"] = torch.zeros_like(p.data)
+
+ exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+ beta1, beta2 = group["betas"]
+ eps = group["eps"]
+ lr = group["lr"]
+ if "rho_inf" not in group:
+ group["rho_inf"] = 2 / (1 - beta2) - 1
+ rho_inf = group["rho_inf"]
+
+ state["step"] += 1
+ t = state["step"]
+
+ # Decay the first and second moment running average coefficient
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+ rho_t = rho_inf - ((2 * t * (beta2**t)) / (1 - beta2**t))
+
+ # Perform stepweight decay
+ p.data.mul_(1 - lr * group["weight_decay"])
+
+ if rho_t >= 5:
+ var = exp_avg_sq.sqrt().add_(eps)
+ r = math.sqrt(
+ (1 - beta2**t)
+ * ((rho_t - 4) * (rho_t - 2) * rho_inf)
+ / ((rho_inf - 4) * (rho_inf - 2) * rho_t)
+ )
+
+ p.data.addcdiv_(exp_avg, var, value=-lr * r / (1 - beta1**t))
+ else:
+ p.data.add_(exp_avg, alpha=-lr / (1 - beta1**t))
+
+ return loss
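+
+# NOTE: a minimal usage sketch (the model and dataloader are assumptions):
+#
+#     optimizer = RAdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
+#     for batch in dataloader:
+#         optimizer.zero_grad()
+#         loss = model(**batch)["loss"]
+#         loss.backward()
+#         optimizer.step()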
+
+
+# from typing import List
+# import collections
+
+# import torch
+# import transformers
+# from classy.optim.factories import Factory
+# from transformers import AdamW
+
+
+# class ElectraOptimizer(Factory):
+# def __init__(
+# self,
+# lr: float,
+# warmup_steps: int,
+# total_steps: int,
+# weight_decay: float,
+# lr_decay: float,
+# no_decay_params: List[str],
+# ):
+# self.lr = lr
+# self.warmup_steps = warmup_steps
+# self.total_steps = total_steps
+# self.weight_decay = weight_decay
+# self.lr_decay = lr_decay
+# self.no_decay_params = no_decay_params
+
+# def group_layers(self, module) -> dict:
+# grouped_layers = collections.defaultdict(list)
+# module_named_parameters = list(module.named_parameters())
+# for ln, lp in module_named_parameters:
+# if "embeddings" in ln:
+# grouped_layers["embeddings"].append((ln, lp))
+# elif "encoder.layer" in ln:
+# layer_num = ln.replace("transformer_model.encoder.layer.", "")
+# layer_num = layer_num[0 : layer_num.index(".")]
+# grouped_layers[layer_num].append((ln, lp))
+# else:
+# grouped_layers["head"].append((ln, lp))
+
+# depth = len(grouped_layers) - 1
+# final_dict = dict()
+# for key, value in grouped_layers.items():
+# if key == "head":
+# final_dict[0] = value
+# elif key == "embeddings":
+# final_dict[depth] = value
+# else:
+# # -1 because layer number starts from zero
+# final_dict[depth - int(key) - 1] = value
+
+# assert len(module_named_parameters) == sum(
+# len(v) for _, v in final_dict.items()
+# )
+
+# return final_dict
+
+# def group_params(self, module) -> list:
+# optimizer_grouped_params = []
+# for inverse_depth, layer in self.group_layers(module).items():
+# layer_lr = self.lr * (self.lr_decay**inverse_depth)
+# layer_wd_params = {
+# "params": [
+# lp
+# for ln, lp in layer
+# if not any(nd in ln for nd in self.no_decay_params)
+# ],
+# "weight_decay": self.weight_decay,
+# "lr": layer_lr,
+# }
+# layer_no_wd_params = {
+# "params": [
+# lp
+# for ln, lp in layer
+# if any(nd in ln for nd in self.no_decay_params)
+# ],
+# "weight_decay": 0,
+# "lr": layer_lr,
+# }
+
+# if len(layer_wd_params) != 0:
+# optimizer_grouped_params.append(layer_wd_params)
+# if len(layer_no_wd_params) != 0:
+# optimizer_grouped_params.append(layer_no_wd_params)
+
+# return optimizer_grouped_params
+
+# def __call__(self, module: torch.nn.Module):
+# optimizer_grouped_parameters = self.group_params(module)
+# optimizer = AdamW(optimizer_grouped_parameters, lr=self.lr)
+# scheduler = transformers.get_linear_schedule_with_warmup(
+# optimizer, self.warmup_steps, self.total_steps
+# )
+# return {
+# "optimizer": optimizer,
+# "lr_scheduler": {
+# "scheduler": scheduler,
+# "interval": "step",
+# "frequency": 1,
+# },
+# }
diff --git a/relik/retriever/pytorch_modules/scheduler.py b/relik/retriever/pytorch_modules/scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..5edd2433612b27c1a91fe189e57c4e2d41c462b2
--- /dev/null
+++ b/relik/retriever/pytorch_modules/scheduler.py
@@ -0,0 +1,54 @@
+import torch
+from torch.optim.lr_scheduler import LRScheduler
+
+
+class LinearSchedulerWithWarmup(LRScheduler):
+ def __init__(
+ self,
+ optimizer: torch.optim.Optimizer,
+ num_warmup_steps: int,
+ num_training_steps: int,
+ last_epoch: int = -1,
+ verbose: bool = False,
+ **kwargs,
+ ):
+ self.num_warmup_steps = num_warmup_steps
+ self.num_training_steps = num_training_steps
+ super().__init__(optimizer, last_epoch, verbose)
+
+ def get_lr(self):
+ def scheduler_fn(current_step):
+ if current_step < self.num_warmup_steps:
+ return current_step / max(1, self.num_warmup_steps)
+ return max(
+ 0.0,
+ float(self.num_training_steps - current_step)
+ / float(max(1, self.num_training_steps - self.num_warmup_steps)),
+ )
+
+ return [base_lr * scheduler_fn(self.last_epoch) for base_lr in self.base_lrs]
+
+
+class LinearScheduler(LRScheduler):
+ def __init__(
+ self,
+ optimizer: torch.optim.Optimizer,
+ num_training_steps: int,
+ last_epoch: int = -1,
+ verbose: bool = False,
+ **kwargs,
+ ):
+ self.num_training_steps = num_training_steps
+ super().__init__(optimizer, last_epoch, verbose)
+
+ def get_lr(self):
+ def scheduler_fn(current_step):
+ # if current_step < self.num_warmup_steps:
+ # return current_step / max(1, self.num_warmup_steps)
+ return max(
+ 0.0,
+ float(self.num_training_steps - current_step)
+ / float(max(1, self.num_training_steps)),
+ )
+
+ return [base_lr * scheduler_fn(self.last_epoch) for base_lr in self.base_lrs]
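+
+
+# NOTE: illustration only (the optimizer and step counts are assumptions).
+# LinearSchedulerWithWarmup ramps the learning rate from 0 to the base lr over
+# `num_warmup_steps`, then decays it linearly to 0 at `num_training_steps`;
+# LinearScheduler only performs the linear decay.
+#
+#     optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
+#     scheduler = LinearSchedulerWithWarmup(
+#         optimizer, num_warmup_steps=100, num_training_steps=1000
+#     )
+#     # step    0 -> lr = 0.0
+#     # step   50 -> lr = 0.5 * 1e-5
+#     # step  100 -> lr = 1e-5
+#     # step  550 -> lr = 0.5 * 1e-5
+#     # step 1000 -> lr = 0.0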
diff --git a/relik/retriever/trainer/__init__.py b/relik/retriever/trainer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1b18bf79091418217ae2bb782c3796dfa8b5b56
--- /dev/null
+++ b/relik/retriever/trainer/__init__.py
@@ -0,0 +1 @@
+from relik.retriever.trainer.train import RetrieverTrainer
diff --git a/relik/retriever/trainer/train.py b/relik/retriever/trainer/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc55298c178e7f519cf2f724e9a0881a168c9158
--- /dev/null
+++ b/relik/retriever/trainer/train.py
@@ -0,0 +1,667 @@
+import os
+from pathlib import Path
+from typing import List, Optional, Union
+
+import hydra
+import lightning as pl
+import omegaconf
+import torch
+from lightning import Trainer
+from lightning.pytorch.callbacks import (
+ EarlyStopping,
+ LearningRateMonitor,
+ ModelCheckpoint,
+ ModelSummary,
+)
+from lightning.pytorch.loggers import WandbLogger
+from omegaconf import OmegaConf
+from rich.pretty import pprint
+
+from relik.common.log import get_console_logger
+from relik.retriever.callbacks.evaluation_callbacks import (
+ AvgRankingEvaluationCallback,
+ RecallAtKEvaluationCallback,
+)
+from relik.retriever.callbacks.prediction_callbacks import (
+ GoldenRetrieverPredictionCallback,
+ NegativeAugmentationCallback,
+)
+from relik.retriever.callbacks.utils_callbacks import (
+ FreeUpIndexerVRAMCallback,
+ SavePredictionsCallback,
+ SaveRetrieverCallback,
+)
+from relik.retriever.data.datasets import GoldenRetrieverDataset
+from relik.retriever.indexers.base import BaseDocumentIndex
+from relik.retriever.lightning_modules.pl_data_modules import (
+ GoldenRetrieverPLDataModule,
+)
+from relik.retriever.lightning_modules.pl_modules import GoldenRetrieverPLModule
+from relik.retriever.pytorch_modules.loss import MultiLabelNCELoss
+from relik.retriever.pytorch_modules.model import GoldenRetriever
+from relik.retriever.pytorch_modules.optim import RAdamW
+from relik.retriever.pytorch_modules.scheduler import (
+ LinearScheduler,
+ LinearSchedulerWithWarmup,
+)
+
+logger = get_console_logger()
+
+
+class RetrieverTrainer:
+ def __init__(
+ self,
+ retriever: GoldenRetriever,
+ train_dataset: GoldenRetrieverDataset,
+ val_dataset: Union[GoldenRetrieverDataset, list[GoldenRetrieverDataset]],
+ test_dataset: Optional[
+ Union[GoldenRetrieverDataset, list[GoldenRetrieverDataset]]
+ ] = None,
+ num_workers: int = 4,
+ optimizer: torch.optim.Optimizer = RAdamW,
+ lr: float = 1e-5,
+ weight_decay: float = 0.01,
+ lr_scheduler: torch.optim.lr_scheduler.LRScheduler = LinearScheduler,
+ num_warmup_steps: int = 0,
+ loss: torch.nn.Module = MultiLabelNCELoss,
+ callbacks: Optional[list] = None,
+ accelerator: str = "auto",
+ devices: int = 1,
+ num_nodes: int = 1,
+ strategy: str = "auto",
+ accumulate_grad_batches: int = 1,
+ gradient_clip_val: float = 1.0,
+ val_check_interval: float = 1.0,
+ check_val_every_n_epoch: int = 1,
+ max_steps: Optional[int] = None,
+ max_epochs: Optional[int] = None,
+ # checkpoint_path: Optional[Union[str, os.PathLike]] = None,
+ deterministic: bool = True,
+ fast_dev_run: bool = False,
+ precision: int = 16,
+ reload_dataloaders_every_n_epochs: int = 1,
+ top_ks: Union[int, List[int]] = 100,
+ # early stopping parameters
+ early_stopping: bool = True,
+ early_stopping_patience: int = 10,
+ # wandb logger parameters
+ log_to_wandb: bool = True,
+ wandb_entity: Optional[str] = None,
+ wandb_experiment_name: Optional[str] = None,
+ wandb_project_name: Optional[str] = None,
+ wandb_save_dir: Optional[Union[str, os.PathLike]] = None,
+ wandb_log_model: bool = True,
+ wandb_offline_mode: bool = False,
+ wandb_watch: str = "all",
+ # checkpoint parameters
+ model_checkpointing: bool = True,
+ chekpoint_dir: Optional[Union[str, os.PathLike]] = None,
+ checkpoint_filename: Optional[Union[str, os.PathLike]] = None,
+ save_top_k: int = 1,
+ save_last: bool = False,
+ # prediction callback parameters
+ prediction_batch_size: int = 128,
+ # hard negatives callback parameters
+ max_hard_negatives_to_mine: int = 15,
+ hard_negatives_threshold: float = 0.0,
+ metrics_to_monitor_for_hard_negatives: Optional[str] = None,
+ mine_hard_negatives_with_probability: float = 1.0,
+ # other parameters
+ seed: int = 42,
+ float32_matmul_precision: str = "medium",
+ **kwargs,
+ ):
+ # put all the parameters in the class
+ self.retriever = retriever
+ # datasets
+ self.train_dataset = train_dataset
+ self.val_dataset = val_dataset
+ self.test_dataset = test_dataset
+ self.num_workers = num_workers
+ # trainer parameters
+ self.optimizer = optimizer
+ self.lr = lr
+ self.weight_decay = weight_decay
+ self.lr_scheduler = lr_scheduler
+ self.num_warmup_steps = num_warmup_steps
+ self.loss = loss
+ self.callbacks = callbacks
+ self.accelerator = accelerator
+ self.devices = devices
+ self.num_nodes = num_nodes
+ self.strategy = strategy
+ self.accumulate_grad_batches = accumulate_grad_batches
+ self.gradient_clip_val = gradient_clip_val
+ self.val_check_interval = val_check_interval
+ self.check_val_every_n_epoch = check_val_every_n_epoch
+ self.max_steps = max_steps
+ self.max_epochs = max_epochs
+ # self.checkpoint_path = checkpoint_path
+ self.deterministic = deterministic
+ self.fast_dev_run = fast_dev_run
+ self.precision = precision
+ self.reload_dataloaders_every_n_epochs = reload_dataloaders_every_n_epochs
+ self.top_ks = top_ks
+ # early stopping parameters
+ self.early_stopping = early_stopping
+ self.early_stopping_patience = early_stopping_patience
+ # wandb logger parameters
+ self.log_to_wandb = log_to_wandb
+ self.wandb_entity = wandb_entity
+ self.wandb_experiment_name = wandb_experiment_name
+ self.wandb_project_name = wandb_project_name
+ self.wandb_save_dir = wandb_save_dir
+ self.wandb_log_model = wandb_log_model
+ self.wandb_offline_mode = wandb_offline_mode
+ self.wandb_watch = wandb_watch
+ # checkpoint parameters
+ self.model_checkpointing = model_checkpointing
+ self.chekpoint_dir = chekpoint_dir
+ self.checkpoint_filename = checkpoint_filename
+ self.save_top_k = save_top_k
+ self.save_last = save_last
+ # prediction callback parameters
+ self.prediction_batch_size = prediction_batch_size
+ # hard negatives callback parameters
+ self.max_hard_negatives_to_mine = max_hard_negatives_to_mine
+ self.hard_negatives_threshold = hard_negatives_threshold
+ self.metrics_to_monitor_for_hard_negatives = (
+ metrics_to_monitor_for_hard_negatives
+ )
+ self.mine_hard_negatives_with_probability = mine_hard_negatives_with_probability
+ # other parameters
+ self.seed = seed
+ self.float32_matmul_precision = float32_matmul_precision
+
+ if self.max_epochs is None and self.max_steps is None:
+ raise ValueError(
+ "Either `max_epochs` or `max_steps` should be specified in the trainer configuration"
+ )
+
+ if self.max_epochs is not None and self.max_steps is not None:
+ logger.log(
+ "Both `max_epochs` and `max_steps` are specified in the trainer configuration. "
+ "Will use `max_epochs` for the number of training steps"
+ )
+ self.max_steps = None
+
+ # reproducibility
+ pl.seed_everything(self.seed)
+ # set the precision of matmul operations
+ torch.set_float32_matmul_precision(self.float32_matmul_precision)
+
+ # lightning data module declaration
+ self.lightining_datamodule = self.configure_lightning_datamodule()
+
+ if self.max_epochs is not None:
+ logger.log(f"Number of training epochs: {self.max_epochs}")
+ self.max_steps = (
+ len(self.lightining_datamodule.train_dataloader()) * self.max_epochs
+ )
+
+ # optimizer declaration
+ self.optimizer, self.lr_scheduler = self.configure_optimizers()
+
+ # lightning module declaration
+ self.lightining_module = self.configure_lightning_module()
+
+ # callbacks declaration
+ self.callbacks_store: List[pl.Callback] = self.configure_callbacks()
+
+ logger.log("Instantiating the Trainer")
+ self.trainer = pl.Trainer(
+ accelerator=self.accelerator,
+ devices=self.devices,
+ num_nodes=self.num_nodes,
+ strategy=self.strategy,
+ accumulate_grad_batches=self.accumulate_grad_batches,
+ max_epochs=self.max_epochs,
+ max_steps=self.max_steps,
+ gradient_clip_val=self.gradient_clip_val,
+ val_check_interval=self.val_check_interval,
+ check_val_every_n_epoch=self.check_val_every_n_epoch,
+ deterministic=self.deterministic,
+ fast_dev_run=self.fast_dev_run,
+ precision=self.precision,
+ reload_dataloaders_every_n_epochs=self.reload_dataloaders_every_n_epochs,
+ callbacks=self.callbacks_store,
+ logger=self.wandb_logger,
+ )
+
+ def configure_lightning_datamodule(self, *args, **kwargs):
+ # lightning data module declaration
+ if isinstance(self.val_dataset, GoldenRetrieverDataset):
+ self.val_dataset = [self.val_dataset]
+ if self.test_dataset is not None and isinstance(
+ self.test_dataset, GoldenRetrieverDataset
+ ):
+ self.test_dataset = [self.test_dataset]
+
+ self.lightining_datamodule = GoldenRetrieverPLDataModule(
+ train_dataset=self.train_dataset,
+ val_datasets=self.val_dataset,
+ test_datasets=self.test_dataset,
+ num_workers=self.num_workers,
+ *args,
+ **kwargs,
+ )
+ return self.lightining_datamodule
+
+ def configure_lightning_module(self, *args, **kwargs):
+ # add loss object to the retriever
+ if self.retriever.loss_type is None:
+ self.retriever.loss_type = self.loss()
+
+ # lightning module declaration
+ self.lightining_module = GoldenRetrieverPLModule(
+ model=self.retriever,
+ optimizer=self.optimizer,
+ lr_scheduler=self.lr_scheduler,
+ *args,
+ **kwargs,
+ )
+
+ return self.lightining_module
+
+ def configure_optimizers(self, *args, **kwargs):
+ # check if it is the class or the instance
+ if isinstance(self.optimizer, type):
+ self.optimizer = self.optimizer(
+ params=self.retriever.parameters(),
+ lr=self.lr,
+ weight_decay=self.weight_decay,
+ )
+ else:
+ self.optimizer = self.optimizer
+
+ # LR Scheduler declaration
+ # check if it is the class, the instance or a function
+ if self.lr_scheduler is not None:
+ if isinstance(self.lr_scheduler, type):
+ self.lr_scheduler = self.lr_scheduler(
+ optimizer=self.optimizer,
+ num_warmup_steps=self.num_warmup_steps,
+ num_training_steps=self.max_steps,
+ )
+
+ return self.optimizer, self.lr_scheduler
+
+ def configure_callbacks(self, *args, **kwargs):
+ # callbacks declaration
+ self.callbacks_store = self.callbacks or []
+ self.callbacks_store.append(ModelSummary(max_depth=2))
+
+ # metric to monitor
+ if isinstance(self.top_ks, int):
+ self.top_ks = [self.top_ks]
+ # order the top_ks in descending order
+ self.top_ks = sorted(self.top_ks, reverse=True)
+ # get the max top_k to monitor
+ self.top_k = self.top_ks[0]
+ self.metric_to_monitor = f"validate_recall@{self.top_k}"
+ self.monitor_mode = "max"
+
+ # early stopping callback if specified
+ self.early_stopping_callback: Optional[EarlyStopping] = None
+ if self.early_stopping:
+ logger.log(
+                f"Enabling Early Stopping, patience: {self.early_stopping_patience}"
+ )
+ self.early_stopping_callback = EarlyStopping(
+ monitor=self.metric_to_monitor,
+ mode=self.monitor_mode,
+ patience=self.early_stopping_patience,
+ )
+ self.callbacks_store.append(self.early_stopping_callback)
+
+ # wandb logger if specified
+ self.wandb_logger: Optional[WandbLogger] = None
+ self.experiment_path: Optional[Path] = None
+ if self.log_to_wandb:
+ # define some default values for the wandb logger
+ if self.wandb_project_name is None:
+ self.wandb_project_name = "relik-retriever"
+ if self.wandb_save_dir is None:
+ self.wandb_save_dir = "./"
+ logger.log("Instantiating Wandb Logger")
+ self.wandb_logger = WandbLogger(
+ entity=self.wandb_entity,
+ project=self.wandb_project_name,
+ name=self.wandb_experiment_name,
+ save_dir=self.wandb_save_dir,
+ log_model=self.wandb_log_model,
+ mode="offline" if self.wandb_offline_mode else "online",
+ )
+ self.wandb_logger.watch(self.lightining_module, log=self.wandb_watch)
+ self.experiment_path = Path(self.wandb_logger.experiment.dir)
+ # Store the YaML config separately into the wandb dir
+ # yaml_conf: str = OmegaConf.to_yaml(cfg=conf)
+ # (experiment_path / "hparams.yaml").write_text(yaml_conf)
+ # Add a Learning Rate Monitor callback to log the learning rate
+ self.callbacks_store.append(LearningRateMonitor(logging_interval="step"))
+
+ # model checkpoint callback if specified
+ self.model_checkpoint_callback: Optional[ModelCheckpoint] = None
+ if self.model_checkpointing:
+ logger.log("Enabling Model Checkpointing")
+ if self.chekpoint_dir is None:
+ self.chekpoint_dir = (
+ self.experiment_path / "checkpoints"
+ if self.experiment_path
+ else None
+ )
+ if self.checkpoint_filename is None:
+ self.checkpoint_filename = (
+ "checkpoint-validate_recall@"
+ + str(self.top_k)
+ + "_{validate_recall@"
+ + str(self.top_k)
+ + ":.4f}-epoch_{epoch:02d}"
+ )
+ self.model_checkpoint_callback = ModelCheckpoint(
+ monitor=self.metric_to_monitor,
+ mode=self.monitor_mode,
+ verbose=True,
+ save_top_k=self.save_top_k,
+ save_last=self.save_last,
+ filename=self.checkpoint_filename,
+ dirpath=self.chekpoint_dir,
+ auto_insert_metric_name=False,
+ )
+ self.callbacks_store.append(self.model_checkpoint_callback)
+
+ # prediction callback
+ self.other_callbacks_for_prediction = [
+ RecallAtKEvaluationCallback(k) for k in self.top_ks
+ ]
+ self.other_callbacks_for_prediction += [
+ AvgRankingEvaluationCallback(k=self.top_k, verbose=True, prefix="train"),
+ SavePredictionsCallback(),
+ ]
+ self.prediction_callback = GoldenRetrieverPredictionCallback(
+ k=self.top_k,
+ batch_size=self.prediction_batch_size,
+ precision=self.precision,
+ other_callbacks=self.other_callbacks_for_prediction,
+ )
+ self.callbacks_store.append(self.prediction_callback)
+
+ # hard negative mining callback
+ self.hard_negatives_callback: Optional[NegativeAugmentationCallback] = None
+ if self.max_hard_negatives_to_mine > 0:
+ self.metrics_to_monitor = (
+ self.metrics_to_monitor_for_hard_negatives
+ or f"validate_recall@{self.top_k}"
+ )
+ self.hard_negatives_callback = NegativeAugmentationCallback(
+ k=self.top_k,
+ batch_size=self.prediction_batch_size,
+ precision=self.precision,
+ stages=["validate"],
+ metrics_to_monitor=self.metrics_to_monitor,
+ threshold=self.hard_negatives_threshold,
+ max_negatives=self.max_hard_negatives_to_mine,
+ add_with_probability=self.mine_hard_negatives_with_probability,
+ refresh_every_n_epochs=1,
+ other_callbacks=[
+ AvgRankingEvaluationCallback(
+ k=self.top_k, verbose=True, prefix="train"
+ )
+ ],
+ )
+ self.callbacks_store.append(self.hard_negatives_callback)
+
+ # utils callback
+ self.callbacks_store.extend(
+ [SaveRetrieverCallback(), FreeUpIndexerVRAMCallback()]
+ )
+ return self.callbacks_store
+
+ def train(self):
+ self.trainer.fit(self.lightining_module, datamodule=self.lightining_datamodule)
+
+ def test(
+ self,
+ lightining_module: Optional[GoldenRetrieverPLModule] = None,
+ checkpoint_path: Optional[Union[str, os.PathLike]] = None,
+ lightining_datamodule: Optional[GoldenRetrieverPLDataModule] = None,
+ ):
+ if lightining_module is not None:
+ self.lightining_module = lightining_module
+ else:
+ if self.fast_dev_run:
+ best_lightining_module = self.lightining_module
+ else:
+ # load best model for testing
+ if checkpoint_path is not None:
+ best_model_path = checkpoint_path
+                elif getattr(self, "checkpoint_path", None):
+                    best_model_path = self.checkpoint_path
+ elif self.model_checkpoint_callback:
+ best_model_path = self.model_checkpoint_callback.best_model_path
+ else:
+ raise ValueError(
+ "Either `checkpoint_path` or `model_checkpoint_callback` should "
+ "be provided to the trainer"
+ )
+ logger.log(f"Loading best model from {best_model_path}")
+
+ try:
+ best_lightining_module = (
+ GoldenRetrieverPLModule.load_from_checkpoint(best_model_path)
+ )
+ except Exception as e:
+ logger.log(f"Failed to load the model from checkpoint: {e}")
+ logger.log("Using last model instead")
+ best_lightining_module = self.lightining_module
+
+ lightining_datamodule = lightining_datamodule or self.lightining_datamodule
+ # module test
+ self.trainer.test(best_lightining_module, datamodule=lightining_datamodule)
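+
+    # NOTE: a rough usage sketch (the retriever and datasets are assumptions):
+    #
+    #     trainer = RetrieverTrainer(
+    #         retriever=my_retriever,
+    #         train_dataset=train_ds,
+    #         val_dataset=val_ds,
+    #         max_epochs=3,
+    #         log_to_wandb=False,
+    #     )
+    #     trainer.train()
+    #     trainer.test()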
+
+
+def train(conf: omegaconf.DictConfig) -> None:
+ # reproducibility
+ pl.seed_everything(conf.train.seed)
+ torch.set_float32_matmul_precision(conf.train.float32_matmul_precision)
+
+ logger.log(f"Starting training for [bold cyan]{conf.model_name}[/bold cyan] model")
+ if conf.train.pl_trainer.fast_dev_run:
+ logger.log(
+ f"Debug mode {conf.train.pl_trainer.fast_dev_run}. Forcing debugger configuration"
+ )
+ # Debuggers don't like GPUs nor multiprocessing
+ # conf.train.pl_trainer.accelerator = "cpu"
+ conf.train.pl_trainer.devices = 1
+ conf.train.pl_trainer.strategy = "auto"
+ conf.train.pl_trainer.precision = 32
+ if "num_workers" in conf.data.datamodule:
+ conf.data.datamodule.num_workers = {
+ k: 0 for k in conf.data.datamodule.num_workers
+ }
+ # Switch wandb to offline mode to prevent online logging
+ conf.logging.log = None
+ # remove model checkpoint callback
+ conf.train.model_checkpoint_callback = None
+
+ if "print_config" in conf and conf.print_config:
+ pprint(OmegaConf.to_container(conf), console=logger, expand_all=True)
+
+ # data module declaration
+ logger.log("Instantiating the Data Module")
+ pl_data_module: GoldenRetrieverPLDataModule = hydra.utils.instantiate(
+ conf.data.datamodule, _recursive_=False
+ )
+ # force setup to get labels initialized for the model
+ pl_data_module.prepare_data()
+ # main module declaration
+ pl_module: Optional[GoldenRetrieverPLModule] = None
+
+ if not conf.train.only_test:
+ pl_data_module.setup("fit")
+
+ # count the number of training steps
+ if (
+ "max_epochs" in conf.train.pl_trainer
+ and conf.train.pl_trainer.max_epochs > 0
+ ):
+ num_training_steps = (
+ len(pl_data_module.train_dataloader())
+ * conf.train.pl_trainer.max_epochs
+ )
+ if "max_steps" in conf.train.pl_trainer:
+ logger.log(
+ "Both `max_epochs` and `max_steps` are specified in the trainer configuration. "
+ "Will use `max_epochs` for the number of training steps"
+ )
+ conf.train.pl_trainer.max_steps = None
+ elif (
+ "max_steps" in conf.train.pl_trainer and conf.train.pl_trainer.max_steps > 0
+ ):
+ num_training_steps = conf.train.pl_trainer.max_steps
+ conf.train.pl_trainer.max_epochs = None
+ else:
+ raise ValueError(
+ "Either `max_epochs` or `max_steps` should be specified in the trainer configuration"
+ )
+ logger.log(f"Expected number of training steps: {num_training_steps}")
+
+ if "lr_scheduler" in conf.model.pl_module and conf.model.pl_module.lr_scheduler:
+ # set the number of warmup steps as x% of the total number of training steps
+ if conf.model.pl_module.lr_scheduler.num_warmup_steps is None:
+ if (
+ "warmup_steps_ratio" in conf.model.pl_module
+ and conf.model.pl_module.warmup_steps_ratio is not None
+ ):
+ conf.model.pl_module.lr_scheduler.num_warmup_steps = int(
+ conf.model.pl_module.lr_scheduler.num_training_steps
+ * conf.model.pl_module.warmup_steps_ratio
+ )
+ else:
+ conf.model.pl_module.lr_scheduler.num_warmup_steps = 0
+ logger.log(
+ f"Number of warmup steps: {conf.model.pl_module.lr_scheduler.num_warmup_steps}"
+ )
+
+ logger.log("Instantiating the Model")
+ pl_module: GoldenRetrieverPLModule = hydra.utils.instantiate(
+ conf.model.pl_module, _recursive_=False
+ )
+ if (
+ "pretrain_ckpt_path" in conf.train
+ and conf.train.pretrain_ckpt_path is not None
+ ):
+ logger.log(
+ f"Loading pretrained checkpoint from {conf.train.pretrain_ckpt_path}"
+ )
+ pl_module.load_state_dict(
+ torch.load(conf.train.pretrain_ckpt_path)["state_dict"], strict=False
+ )
+
+ if "compile" in conf.model.pl_module and conf.model.pl_module.compile:
+ try:
+ pl_module = torch.compile(pl_module, backend="inductor")
+ except Exception:
+ logger.log(
+ "Failed to compile the model, you may need to install PyTorch 2.0"
+ )
+
+ # callbacks declaration
+ callbacks_store = [ModelSummary(max_depth=2)]
+
+ experiment_logger: Optional[WandbLogger] = None
+ experiment_path: Optional[Path] = None
+ if conf.logging.log:
+ logger.log("Instantiating Wandb Logger")
+ experiment_logger = hydra.utils.instantiate(conf.logging.wandb_arg)
+ if pl_module is not None:
+ # it may happen that the model is not instantiated if we are only testing
+ # in that case, we don't need to watch the model
+ experiment_logger.watch(pl_module, **conf.logging.watch)
+ experiment_path = Path(experiment_logger.experiment.dir)
+ # Store the YaML config separately into the wandb dir
+ yaml_conf: str = OmegaConf.to_yaml(cfg=conf)
+ (experiment_path / "hparams.yaml").write_text(yaml_conf)
+ # Add a Learning Rate Monitor callback to log the learning rate
+ callbacks_store.append(LearningRateMonitor(logging_interval="step"))
+
+ early_stopping_callback: Optional[EarlyStopping] = None
+ if conf.train.early_stopping_callback is not None:
+ early_stopping_callback = hydra.utils.instantiate(
+ conf.train.early_stopping_callback
+ )
+ callbacks_store.append(early_stopping_callback)
+
+ model_checkpoint_callback: Optional[ModelCheckpoint] = None
+ if conf.train.model_checkpoint_callback is not None:
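+ # store checkpoints under the wandb experiment directory when logging is enabled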
+ model_checkpoint_callback = hydra.utils.instantiate(
+ conf.train.model_checkpoint_callback,
+ dirpath=experiment_path / "checkpoints" if experiment_path else None,
+ )
+ callbacks_store.append(model_checkpoint_callback)
+
+ if "callbacks" in conf.train and conf.train.callbacks is not None:
+ for _, callback in conf.train.callbacks.items():
+ # callback can be a list of callbacks or a single callback
+ if isinstance(callback, omegaconf.listconfig.ListConfig):
+ for cb in callback:
+ if cb is not None:
+ callbacks_store.append(
+ hydra.utils.instantiate(cb, _recursive_=False)
+ )
+ else:
+ if callback is not None:
+ callbacks_store.append(hydra.utils.instantiate(callback))
+
+ # trainer
+ logger.log("Instantiating the Trainer")
+ trainer: Trainer = hydra.utils.instantiate(
+ conf.train.pl_trainer, callbacks=callbacks_store, logger=experiment_logger
+ )
+
+ if not conf.train.only_test:
+ # module fit
+ trainer.fit(pl_module, datamodule=pl_data_module)
+
+ if conf.train.pl_trainer.fast_dev_run:
+ best_pl_module = pl_module
+ else:
+ # load best model for testing
+ if conf.evaluation.checkpoint_path:
+ best_model_path = conf.evaluation.checkpoint_path
+ elif model_checkpoint_callback:
+ best_model_path = model_checkpoint_callback.best_model_path
+ else:
+ raise ValueError(
+ "Either `checkpoint_path` or `model_checkpoint_callback` should "
+ "be specified in the evaluation configuration"
+ )
+ logger.log(f"Loading best model from {best_model_path}")
+
+ try:
+ best_pl_module = GoldenRetrieverPLModule.load_from_checkpoint(
+ best_model_path
+ )
+ except Exception as e:
+ logger.log(f"Failed to load the model from checkpoint: {e}")
+ logger.log("Using last model instead")
+ best_pl_module = pl_module
+ if "compile" in conf.model.pl_module and conf.model.pl_module.compile:
+ try:
+ best_pl_module = torch.compile(best_pl_module, backend="inductor")
+ except Exception:
+ logger.log(
+ "Failed to compile the model, you may need to install PyTorch 2.0"
+ )
+
+ # module test
+ trainer.test(best_pl_module, datamodule=pl_data_module)
+
+
+@hydra.main(config_path="../../conf", config_name="default", version_base="1.3")
+def main(conf: omegaconf.DictConfig):
+ train(conf)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/relik/version.py b/relik/version.py
new file mode 100644
index 0000000000000000000000000000000000000000..bed137800c980e0e82d7c8ccdf474053baed630f
--- /dev/null
+++ b/relik/version.py
@@ -0,0 +1,13 @@
+import os
+
+_MAJOR = "0"
+_MINOR = "1"
+# On main and in a nightly release the patch should be one ahead of the last
+# released build.
+_PATCH = "0"
+# This is mainly for nightly builds which have the suffix ".dev$DATE". See
+# https://semver.org/#is-v123-a-semantic-version for the semantics.
+_SUFFIX = os.environ.get("RELIK_VERSION_SUFFIX", "")
+
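+# e.g. VERSION_SHORT == "0.1" and VERSION == "0.1.0" (or "0.1.0.dev20240101" with a suffix)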
+VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR)
+VERSION = "{0}.{1}.{2}{3}".format(_MAJOR, _MINOR, _PATCH, _SUFFIX)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..5933df9197cc3d656a9463e4da1a1ea5680bbf93
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,36 @@
+#------- Core dependencies -------
+torch>=2.0
+transformers[sentencepiece]>=4.34,<4.35
+rich>=13.0.0,<14.0.0
+scikit-learn
+overrides
+
+#------- Optional dependencies -------
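+# A trailing "# needed by: <extra>" comment routes a requirement into the matching
+# pip extra (see parse_requirements_file in setup.py).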
+
+# train
+lightning>=2.0,<2.1
+hydra-core>=1.3,<1.4
+hydra_colorlog
+wandb>=0.15,<0.16
+datasets>=2.13,<2.15
+
+# faiss
+faiss-cpu==1.7.4 # needed by: faiss
+
+# serve
+fastapi>=0.103,<0.104 # needed by: serve
+uvicorn[standard]==0.23.2 # needed by: serve
+gunicorn==21.2.0 # needed by: serve
+ray[serve]>=2.7,<=2.8 # needed by: serve
+ipa-core # needed by: serve
+streamlit>=1.27,<1.28 # needed by: serve
+streamlit_extras>=0.3,<0.4 # needed by: serve
+
+# retriever
+
+# reader
+
+# dev
+pre-commit
+black[d]
+isort
diff --git a/scripts/setup.sh b/scripts/setup.sh
new file mode 100755
index 0000000000000000000000000000000000000000..36f2afd2cba191cb89c2b36ffc64d51cd2274cc5
--- /dev/null
+++ b/scripts/setup.sh
@@ -0,0 +1,46 @@
+#!/bin/bash
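+# Interactive environment setup: creates a conda env, installs PyTorch (CPU-only
+# or CUDA), then installs this package with all optional extras.
+# Usage: ./scripts/setup.sh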
+
+# setup conda
+# check that conda is installed before querying it
+if ! command -v conda &> /dev/null; then
+ echo "Conda is not installed. Please install conda first."
+ exit 1
+fi
+CONDA_BASE=$(conda info --base)
+source "$CONDA_BASE"/etc/profile.d/conda.sh
+
+# create conda env
+read -rp "Enter environment name or prefix: " ENV_NAME
+read -rp "Enter python version (default 3.10): " PYTHON_VERSION
+if [ -z "$PYTHON_VERSION" ]; then
+ PYTHON_VERSION="3.10"
+fi
+
+# check if ENV_NAME is a full path
+if [[ "$ENV_NAME" == /* ]]; then
+ CONDA_NEW_ARG="--prefix"
+else
+ CONDA_NEW_ARG="--name"
+fi
+
+conda create -y "$CONDA_NEW_ARG" "$ENV_NAME" python="$PYTHON_VERSION"
+conda activate "$ENV_NAME"
+
+# replace placeholder env with $ENV_NAME in scripts/train.sh
+# NEW_CONDA_LINE="source \$CONDA_BASE/bin/activate $ENV_NAME"
+# sed -i.bak -e "s,.*bin/activate.*,$NEW_CONDA_LINE,g" scripts/train.sh
+
+# install torch
+read -rp "Enter cuda version (e.g. '11.8', default no cuda support): " CUDA_VERSION
+read -rp "Enter PyTorch version (e.g. '2.1', default latest): " PYTORCH_VERSION
+if [ -n "$PYTORCH_VERSION" ]; then
+ PYTORCH_VERSION="=$PYTORCH_VERSION"
+fi
+if [ -z "$CUDA_VERSION" ]; then
+ conda install -y pytorch"$PYTORCH_VERSION" cpuonly -c pytorch
+else
+ conda install -y pytorch"$PYTORCH_VERSION" pytorch-cuda="$CUDA_VERSION" -c pytorch -c nvidia
+fi
+
+# install python requirements
+pip install -e .[all]
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..37575622936c68b53ced7e8f63fb32a3fa701047
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,128 @@
+"""
+Simple check list from AllenNLP repo: https://github.com/allenai/allennlp/blob/master/setup.py
+To create the package for pypi.
+1. Run `make pre-release` (or `make pre-patch` for a patch release) then run `make fix-copies` to fix the index of the
+ documentation.
+2. Run Tests for Amazon Sagemaker. The documentation is located in `./tests/sagemaker/README.md`, otherwise @philschmid.
+3. Unpin specific versions from setup.py that use a git install.
+4. Commit these changes with the message: "Release: VERSION"
+5. Add a tag in git to mark the release: "git tag VERSION -m 'Adds tag VERSION for pypi' "
+ Push the tag to git: git push --tags origin master
+6. Build both the sources and the wheel. Do not change anything in setup.py between
+ creating the wheel and the source distribution (obviously).
+ For the wheel, run: "python setup.py bdist_wheel" in the top level directory.
+ (this will build a wheel for the python version you use to build it).
+ For the sources, run: "python setup.py sdist"
+ You should now have a /dist directory with both .whl and .tar.gz source versions.
+7. Check that everything looks correct by uploading the package to the pypi test server:
+ twine upload dist/* -r pypitest
+ (PyPI suggests using twine, as other methods upload files via plaintext.)
+ You may have to specify the repository url, use the following command then:
+ twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/
+ Check that you can install it in a virtualenv by running:
+ pip install -i https://testpypi.python.org/pypi transformers
+8. Upload the final version to actual pypi:
+ twine upload dist/* -r pypi
+9. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory.
+10. Run `make post-release` (or `make post-patch` for a patch release).
+"""
+from collections import defaultdict
+
+import setuptools
+
+
+def parse_requirements_file(
+ path, allowed_extras: set = None, include_all_extra: bool = True
+):
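+ """Parse a pip requirements file into (install_requires, extras, find_links).
+
+ A trailing "# needed by: extra1,extra2" comment routes a requirement into the
+ named extras (and into the "all" extra unless include_all_extra is False);
+ "-f"/"--find-links"/"--index-url" lines are collected separately as find_links.
+ """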
+ requirements = []
+ extras = defaultdict(list)
+ find_links = []
+ with open(path) as requirements_file:
+ import re
+
+ def fix_url_dependencies(req: str) -> str:
+ """Pip and setuptools disagree about how URL dependencies should be handled."""
+ m = re.match(
+ r"^(git\+)?(https|ssh)://(git@)?github\.com/([\w-]+)/(?P[\w-]+)\.git",
+ req,
+ )
+ if m is None:
+ return req
+ else:
+ return f"{m.group('name')} @ {req}"
+
+ for line in requirements_file:
+ line = line.strip()
+ if line.startswith("#") or len(line) <= 0:
+ continue
+ if (
+ line.startswith("-f")
+ or line.startswith("--find-links")
+ or line.startswith("--index-url")
+ ):
+ find_links.append(line.split(" ", maxsplit=1)[-1].strip())
+ continue
+
+ req, *needed_by = line.split("# needed by:")
+ req = fix_url_dependencies(req.strip())
+ if needed_by:
+ for extra in needed_by[0].strip().split(","):
+ extra = extra.strip()
+ if allowed_extras is not None and extra not in allowed_extras:
+ raise ValueError(f"invalid extra '{extra}' in {path}")
+ extras[extra].append(req)
+ if include_all_extra and req not in extras["all"]:
+ extras["all"].append(req)
+ else:
+ requirements.append(req)
+ return requirements, extras, find_links
+
+
+allowed_extras = {
+ "onnx",
+ "onnx-gpu",
+ "serve",
+ "retriever",
+ "reader",
+ "all",
+ "faiss",
+ "dev",
+}
+
+# Load requirements.
+install_requirements, extras, find_links = parse_requirements_file(
+ "requirements.txt", allowed_extras=allowed_extras
+)
+
+# version.py defines the VERSION and VERSION_SHORT variables.
+# We use exec here so we don't import relik while setting up.
+VERSION = {} # type: ignore
+with open("relik/version.py", "r") as version_file:
+ exec(version_file.read(), VERSION)
+
+with open("README.md", "r") as fh:
+ long_description = fh.read()
+
+setuptools.setup(
+ name="relik",
+ version=VERSION["VERSION"],
+ author="Edoardo Barba, Riccardo Orlando, Pere-Lluís Huguet Cabot",
+ author_email="orlandorcc@gmail.com",
+ description="Fast and Accurate Entity Linking and Relation Extraction on an Academic Budget",
+ long_description=long_description,
+ long_description_content_type="text/markdown",
+ url="https://github.com/SapienzaNLP/relik",
+ keywords="NLP Sapienza sapienzanlp deep learning transformer pytorch retriever entity linking relation extraction reader budget",
+ packages=setuptools.find_packages(),
+ include_package_data=True,
+ license="Apache",
+ classifiers=[
+ "Intended Audience :: Science/Research",
+ "License :: OSI Approved :: Apache Software License",
+ "Programming Language :: Python :: 3",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ ],
+ install_requires=install_requirements,
+ extras_require=extras,
+ python_requires=">=3.10",
+)
diff --git a/style.css b/style.css
new file mode 100644
index 0000000000000000000000000000000000000000..31f0d182cfd9b2636d5db5cbd0e7a1339ed5d1c3
--- /dev/null
+++ b/style.css
@@ -0,0 +1,33 @@
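+/*
+ * Custom CSS, presumably loaded by the Streamlit front end. The .eczjsme* and
+ * .st-emotion-cache-* selectors are auto-generated by Streamlit/Emotion and may
+ * change across Streamlit releases.
+ */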
+/* Sidebar */
+.eczjsme11 {
+ background-color: #802433;
+}
+
+.st-emotion-cache-10oheav h2 {
+ color: white;
+}
+
+.st-emotion-cache-10oheav li {
+ color: white;
+}
+
+/* Main */
+a:link {
+ text-decoration: none;
+ color: white;
+}
+
+a:visited {
+ text-decoration: none;
+ color: white;
+}
+
+a:hover {
+ text-decoration: none;
+ color: rgba(255, 255, 255, 0.871);
+}
+
+a:active {
+ text-decoration: none;
+ color: white;
+}
\ No newline at end of file