diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..dbadf32bbad592c5933aa57b988953c1bbaa5e4b --- /dev/null +++ b/.gitignore @@ -0,0 +1,332 @@ +# custom + +data/* +experiments/* +retrievers +outputs +model +wandb + +# Created by https://www.toptal.com/developers/gitignore/api/jetbrains+all,vscode,python,jupyternotebooks,linux,windows,macos +# Edit at https://www.toptal.com/developers/gitignore?templates=jetbrains+all,vscode,python,jupyternotebooks,linux,windows,macos + +### JetBrains+all ### +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider +# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 + +# User-specific stuff +.idea/**/workspace.xml +.idea/**/tasks.xml +.idea/**/usage.statistics.xml +.idea/**/dictionaries +.idea/**/shelf + +# Generated files +.idea/**/contentModel.xml + +# Sensitive or high-churn files +.idea/**/dataSources/ +.idea/**/dataSources.ids +.idea/**/dataSources.local.xml +.idea/**/sqlDataSources.xml +.idea/**/dynamic.xml +.idea/**/uiDesigner.xml +.idea/**/dbnavigator.xml + +# Gradle +.idea/**/gradle.xml +.idea/**/libraries + +# Gradle and Maven with auto-import +# When using Gradle or Maven with auto-import, you should exclude module files, +# since they will be recreated, and may cause churn. Uncomment if using +# auto-import. +# .idea/artifacts +# .idea/compiler.xml +# .idea/jarRepositories.xml +# .idea/modules.xml +# .idea/*.iml +# .idea/modules +# *.iml +# *.ipr + +# CMake +cmake-build-*/ + +# Mongo Explorer plugin +.idea/**/mongoSettings.xml + +# File-based project format +*.iws + +# IntelliJ +out/ + +# mpeltonen/sbt-idea plugin +.idea_modules/ + +# JIRA plugin +atlassian-ide-plugin.xml + +# Cursive Clojure plugin +.idea/replstate.xml + +# Crashlytics plugin (for Android Studio and IntelliJ) +com_crashlytics_export_strings.xml +crashlytics.properties +crashlytics-build.properties +fabric.properties + +# Editor-based Rest Client +.idea/httpRequests + +# Android studio 3.1+ serialized cache file +.idea/caches/build_file_checksums.ser + +### JetBrains+all Patch ### +# Ignores the whole .idea folder and all .iml files +# See https://github.com/joeblau/gitignore.io/issues/186 and https://github.com/joeblau/gitignore.io/issues/360 + +.idea/ + +# Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-249601023 + +*.iml +modules.xml +.idea/misc.xml +*.ipr + +# Sonarlint plugin +.idea/sonarlint + +### JupyterNotebooks ### +# gitignore template for Jupyter Notebooks +# website: http://jupyter.org/ + +.ipynb_checkpoints +*/.ipynb_checkpoints/* + +# IPython +profile_default/ +ipython_config.py + +# Remove previous ipynb_checkpoints +# git rm -r .ipynb_checkpoints/ + +### Linux ### +*~ + +# temporary files which can be created if a process still has a handle open of a deleted file +.fuse_hidden* + +# KDE directory preferences +.directory + +# Linux trash folder which might appear on any partition or disk +.Trash-* + +# .nfs files are created when an open file is removed but is still being accessed +.nfs* + +### macOS ### +# General +.DS_Store +.AppleDouble +.LSOverride + +# Icon must end with two \r +Icon + + +# Thumbnails +._* + +# Files that might appear in the root of a volume +.DocumentRevisions-V100 +.fseventsd +.Spotlight-V100 +.TemporaryItems +.Trashes +.VolumeIcon.icns +.com.apple.timemachine.donotpresent + +# Directories potentially created on remote AFP share +.AppleDB +.AppleDesktop +Network Trash Folder 
+Temporary Items +.apdisk + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +pytestdebug.log + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ +doc/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook + +# IPython + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ +pythonenv* + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# profiling data +.prof + +### vscode ### +.vscode +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +*.code-workspace + +### Windows ### +# Windows thumbnail cache files +Thumbs.db +Thumbs.db:encryptable +ehthumbs.db +ehthumbs_vista.db + +# Dump file +*.stackdump + +# Folder config file +[Dd]esktop.ini + +# Recycle Bin used on file shares +$RECYCLE.BIN/ + +# Windows Installer files +*.cab +*.msi +*.msix +*.msm +*.msp + +# Windows shortcuts +*.lnk + +# End of https://www.toptal.com/developers/gitignore/api/jetbrains+all,vscode,python,jupyternotebooks,linux,windows,macos \ No newline at end of file diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000000000000000000000000000000000000..f9bd1455b374de796e12d240c1211dee9829d97e --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +include requirements.txt diff --git a/README.md b/README.md index 41e51eae2b1a8e039beb534af184753a672370eb..bc6e732d1aee0ed57c8c697ecc7364d35dea3929 100644 --- a/README.md +++ b/README.md @@ -1,13 +1 @@ ---- -title: Relik -emoji: 🐨 -colorFrom: gray -colorTo: pink -sdk: streamlit -sdk_version: 1.27.2 -app_file: app.py -pinned: false -license: mit ---- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +# relik \ No newline at end of file diff --git a/SETUP.cfg b/SETUP.cfg new file mode 100644 index 0000000000000000000000000000000000000000..4b4b70f561fe7171daf99bed49ad08fed98aedb6 --- /dev/null +++ b/SETUP.cfg @@ -0,0 +1,8 @@ +[metadata] 
+description-file = README.md
+
+[build]
+build-base = /tmp/build
+
+[egg_info]
+egg-base = /tmp
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..67f3c58aab567cdb797214d7d886db81b48d5584
--- /dev/null
+++ b/app.py
@@ -0,0 +1,245 @@
+import os
+import re
+import time
+from pathlib import Path
+
+import requests
+import streamlit as st
+from spacy import displacy
+from streamlit_extras.badges import badge
+from streamlit_extras.stylable_container import stylable_container
+
+# RELIK = os.getenv("RELIK", "localhost:8000/api/entities")
+
+import random
+
+from relik.inference.annotator import Relik
+
+
+def get_random_color(ents):
+    colors = {}
+    random_colors = generate_pastel_colors(len(ents))
+    for ent in ents:
+        colors[ent] = random_colors.pop(random.randint(0, len(random_colors) - 1))
+    return colors
+
+
+def floatrange(start, stop, steps):
+    if int(steps) == 1:
+        return [stop]
+    return [
+        start + float(i) * (stop - start) / (float(steps) - 1) for i in range(steps)
+    ]
+
+
+def hsl_to_rgb(h, s, l):
+    def hue_2_rgb(v1, v2, v_h):
+        while v_h < 0.0:
+            v_h += 1.0
+        while v_h > 1.0:
+            v_h -= 1.0
+        if 6 * v_h < 1.0:
+            return v1 + (v2 - v1) * 6.0 * v_h
+        if 2 * v_h < 1.0:
+            return v2
+        if 3 * v_h < 2.0:
+            return v1 + (v2 - v1) * ((2.0 / 3.0) - v_h) * 6.0
+        return v1
+
+    # if not (0 <= s <= 1): raise ValueError, "s (saturation) parameter must be between 0 and 1."
+    # if not (0 <= l <= 1): raise ValueError, "l (lightness) parameter must be between 0 and 1."
+
+    r, b, g = (l * 255,) * 3
+    if s != 0.0:
+        if l < 0.5:
+            var_2 = l * (1.0 + s)
+        else:
+            var_2 = (l + s) - (s * l)
+        var_1 = 2.0 * l - var_2
+        r = 255 * hue_2_rgb(var_1, var_2, h + (1.0 / 3.0))
+        g = 255 * hue_2_rgb(var_1, var_2, h)
+        b = 255 * hue_2_rgb(var_1, var_2, h - (1.0 / 3.0))
+
+    return int(round(r)), int(round(g)), int(round(b))
+
+
+def generate_pastel_colors(n):
+    """Return different pastel colours.
+
+    Input:
+        n (integer) : The number of colors to return
+
+    Output:
+        A list of colors in HTML notation (eg.['#cce0ff', '#ffcccc', '#ccffe0', '#f5ccff', '#f5ffcc'])
+
+    Example:
+        >>> print(generate_pastel_colors(5))
+        ['#cce0ff', '#f5ccff', '#ffcccc', '#f5ffcc', '#ccffe0']
+    """
+    if n == 0:
+        return []
+
+    # To generate colors, we use the HSL colorspace (see http://en.wikipedia.org/wiki/HSL_color_space)
+    start_hue = 0.6  # 0=red 1/3=0.333=green 2/3=0.666=blue
+    saturation = 1.0
+    lightness = 0.8
+    # We take points around the chromatic circle (hue):
+    # (Note: we generate n+1 colors, then drop the last one ([:-1]) because
+    # it equals the first one (hue 0 = hue 1))
+    return [
+        "#%02x%02x%02x" % hsl_to_rgb(hue, saturation, lightness)
+        for hue in floatrange(start_hue, start_hue + 1, n + 1)
+    ][:-1]
+
+
+def set_sidebar(css):
+    # assumption: plain white-text anchors; the original inline markup of this
+    # wrapper is not preserved here
+    white_link_wrapper = "<a href='{}' style='color: white;'>{}</a>"
+    with st.sidebar:
+        # assumption: inject the stylesheet read in run_client()
+        st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
+        st.image(
+            "http://nlp.uniroma1.it/static/website/sapienza-nlp-logo-wh.svg",
+            use_column_width=True,
+        )
+        st.markdown("## ReLiK")
+        st.write(
+            f"""
+            - {white_link_wrapper.format("#", "  Paper")}
+            - {white_link_wrapper.format("https://github.com/SapienzaNLP/relik", "  GitHub")}
+            - {white_link_wrapper.format("https://hub.docker.com/repository/docker/sapienzanlp/relik", "  Docker Hub")}
+            """,
+            unsafe_allow_html=True,
+        )
+        st.markdown("## Sapienza NLP")
+        st.write(
+            f"""
+            - {white_link_wrapper.format("https://nlp.uniroma1.it", "  Webpage")}
+            - {white_link_wrapper.format("https://github.com/SapienzaNLP", "  GitHub")}
+            - {white_link_wrapper.format("https://twitter.com/SapienzaNLP", "  Twitter")}
+            - {white_link_wrapper.format("https://www.linkedin.com/company/79434450", "  LinkedIn")}
+            """,
+            unsafe_allow_html=True,
+        )
+
+
+def get_el_annotations(response):
+    # swap labels key with ents
+    dict_of_ents = {"text": response.text, "ents": []}
+    dict_of_ents["ents"] = response.labels
+    label_in_text = set(l["label"] for l in dict_of_ents["ents"])
+    options = {"ents": label_in_text, "colors": get_random_color(label_in_text)}
+    return dict_of_ents, options
+
+
+def set_intro(css):
+    # intro
+    st.markdown("# ReLik")
+    st.markdown(
+        "### Retrieve, Read and LinK: Fast and Accurate Entity Linking and Relation Extraction on an Academic Budget"
+    )
+    # st.markdown(
+    #     "This is a front-end for the paper [Universal Semantic Annotator: the First Unified API "
+    #     "for WSD, SRL and Semantic Parsing](https://www.researchgate.net/publication/360671045_Universal_Semantic_Annotator_the_First_Unified_API_for_WSD_SRL_and_Semantic_Parsing), which will be presented at LREC 2022 by "
+    #     "[Riccardo Orlando](https://riccorl.github.io), [Simone Conia](https://c-simone.github.io/), "
+    #     "[Stefano Faralli](https://corsidilaurea.uniroma1.it/it/users/stefanofaralliuniroma1it), and [Roberto Navigli](https://www.diag.uniroma1.it/navigli/)."
+    # )
+    badge(type="github", name="sapienzanlp/relik")
+    badge(type="pypi", name="relik")
+
+
+def run_client():
+    with open(Path(__file__).parent / "style.css") as f:
+        css = f.read()
+
+    st.set_page_config(
+        page_title="ReLik",
+        page_icon="🦮",
+        layout="wide",
+    )
+    set_sidebar(css)
+    set_intro(css)
+
+    # text input
+    text = st.text_area(
+        "Enter Text Below:",
+        value="Obama went to Rome for a quick vacation.",
+        height=200,
+        max_chars=500,
+    )
+
+    with stylable_container(
+        key="annotate_button",
+        css_styles="""
+            button {
+                background-color: #802433;
+                color: white;
+                border-radius: 25px;
+            }
+            """,
+    ):
+        submit = st.button("Annotate")
+        # submit = st.button("Run")
+
+    relik = Relik(
+        question_encoder="riccorl/relik-retriever-aida-blink-pretrain-omniencoder",
+        document_index="riccorl/index-relik-retriever-aida-blink-pretrain-omniencoder",
+        reader="riccorl/relik-reader-aida-deberta-small",
+        top_k=100,
+        window_size=32,
+        window_stride=16,
+        candidates_preprocessing_fn="relik.inference.preprocessing.wikipedia_title_and_openings_preprocessing",
+    )
+
+    # ReLik API call
+    if submit:
+        text = text.strip()
+        if text:
+            st.markdown("####")
+            st.markdown("#### Entity Linking")
+            with st.spinner(text="In progress"):
+                response = relik(text)
+                # response = requests.post(RELIK, json=text)
+                # if response.status_code != 200:
+                #     st.error("Error: {}".format(response.status_code))
+                # else:
+                #     response = response.json()
+
+            # Entity Linking
+            # with stylable_container(
+            #     key="container_with_border",
+            #     css_styles="""
+            #         {
+            #             border: 1px solid rgba(49, 51, 63, 0.2);
+            #             border-radius: 0.5rem;
+            #             padding: 0.5rem;
+            #             padding-bottom: 2rem;
+            #         }
+            #         """,
+            # ):
+            #     st.markdown("##")
+            dict_of_ents, options = get_el_annotations(response=response)
+            display = displacy.render(
+                dict_of_ents, manual=True, style="ent", options=options
+            )
+            display = display.replace("\n", " ")
+            # wsd_display = re.sub(
+            #     r"(wiki::\d+\w)",
+            #     r"\g<1>".format(
+            #         language.upper()
+            #     ),
+            #     wsd_display,
+            # )
+            with st.container():
+                st.write(display, unsafe_allow_html=True)
+
+            st.markdown("####")
+            st.markdown("#### Relation Extraction")
+
+            with st.container():
+                st.write("Coming :)", unsafe_allow_html=True)
+
+        else:
+            st.error("Please enter some text.")
+
+
+if __name__ == "__main__":
+    run_client()
diff --git a/dockerfiles/Dockerfile.cpu b/dockerfiles/Dockerfile.cpu
new file mode 100644
index 0000000000000000000000000000000000000000..b27436cea2a992a2b70e890e2435c528d9c0ba09
--- /dev/null
+++ b/dockerfiles/Dockerfile.cpu
@@ -0,0 +1,17 @@
+FROM tiangolo/uvicorn-gunicorn:python3.10-slim
+
+# Copy and install requirements.txt
+COPY ./requirements.txt ./requirements.txt
+COPY ./src /app
+COPY ./scripts/start.sh /start.sh
+COPY ./scripts/prestart.sh /app
+COPY ./scripts/gunicorn_conf.py /gunicorn_conf.py
+COPY ./scripts/start-reload.sh /start-reload.sh
+COPY ./VERSION /
+RUN mkdir -p /app/resources/model \
+    && pip install --no-cache-dir -r requirements.txt \
+    && chmod +x /start.sh && chmod +x /start-reload.sh
+ARG MODEL_PATH
+COPY ${MODEL_PATH}/* /app/resources/model/
+
+ENV APP_MODULE=main:app
diff --git a/dockerfiles/Dockerfile.cuda b/dockerfiles/Dockerfile.cuda
new file mode 100644
index 0000000000000000000000000000000000000000..0ca669a954f1e76caa350e2311b138029b245d8c
--- /dev/null
+++ b/dockerfiles/Dockerfile.cuda
@@ -0,0 +1,38 @@
+FROM nvidia/cuda:12.2.0-base-ubuntu20.04
+
+ARG DEBIAN_FRONTEND=noninteractive
+
+RUN apt-get update \
+    && apt-get install \
+    curl wget python3.10 \
python3.10-distutils \ + python3-pip \ + curl wget -y \ + && rm -rf /var/lib/apt/lists/* + +# FastAPI section +# device env +ENV DEVICE="cuda" +# Copy and install requirements.txt +COPY ./gpu-requirements.txt ./requirements.txt +COPY ./src /app +COPY ./scripts/start.sh /start.sh +COPY ./scripts/gunicorn_conf.py /gunicorn_conf.py +COPY ./scripts/start-reload.sh /start-reload.sh +COPY ./scripts/prestart.sh /app +COPY ./VERSION / +RUN mkdir -p /app/resources/model \ + && pip install --upgrade --no-cache-dir -r requirements.txt \ + && chmod +x /start.sh \ + && chmod +x /start-reload.sh +ARG MODEL_NAME_OR_PATH + +WORKDIR /app + +ENV PYTHONPATH=/app + +EXPOSE 80 + +# Run the start script, it will check for an /app/prestart.sh script (e.g. for migrations) +# And then will start Gunicorn with Uvicorn +CMD ["/start.sh"] diff --git a/examples/train_retriever.py b/examples/train_retriever.py new file mode 100644 index 0000000000000000000000000000000000000000..514eb0d80706012a50cd203e654bd14c57bb5db0 --- /dev/null +++ b/examples/train_retriever.py @@ -0,0 +1,45 @@ +from relik.retriever.trainer import RetrieverTrainer +from relik import GoldenRetriever +from relik.retriever.indexers.inmemory import InMemoryDocumentIndex +from relik.retriever.data.datasets import AidaInBatchNegativesDataset + +if __name__ == "__main__": + # instantiate retriever + document_index = InMemoryDocumentIndex( + documents="/root/golden-retriever-v2/data/dpr-like/el/definitions.txt", + device="cuda", + precision="16", + ) + retriever = GoldenRetriever( + question_encoder="intfloat/e5-small-v2", document_index=document_index + ) + + train_dataset = AidaInBatchNegativesDataset( + name="aida_train", + path="/root/golden-retriever-v2/data/dpr-like/el/aida_32_tokens_topic/train.jsonl", + tokenizer=retriever.question_tokenizer, + question_batch_size=64, + passage_batch_size=400, + max_passage_length=64, + use_topics=True, + shuffle=True, + ) + val_dataset = AidaInBatchNegativesDataset( + name="aida_val", + path="/root/golden-retriever-v2/data/dpr-like/el/aida_32_tokens_topic/val.jsonl", + tokenizer=retriever.question_tokenizer, + question_batch_size=64, + passage_batch_size=400, + max_passage_length=64, + use_topics=True, + ) + + trainer = RetrieverTrainer( + retriever=retriever, + train_dataset=train_dataset, + val_dataset=val_dataset, + max_steps=25_000, + wandb_offline_mode=True, + ) + + trainer.train() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..0fbfa01d83e3309a377c8c9013070809b045999a --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,15 @@ +[tool.black] +include = '\.pyi?$' +exclude = ''' +/( + \.git + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | _build + | buck-out + | build + | dist +)/ +''' diff --git a/relik/__init__.py b/relik/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..42a3df6b991b0af65ec5974fc4faa381b8e555b7 --- /dev/null +++ b/relik/__init__.py @@ -0,0 +1 @@ +from relik.retriever.pytorch_modules.model import GoldenRetriever diff --git a/relik/common/__init__.py b/relik/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/relik/common/log.py b/relik/common/log.py new file mode 100644 index 0000000000000000000000000000000000000000..b91e1fa7bfc22b759e0da4d69563315e31ce0e60 --- /dev/null +++ b/relik/common/log.py @@ -0,0 +1,97 @@ +import logging +import sys +import threading +from typing import Optional + +from rich import 
get_console + +_lock = threading.Lock() +_default_handler: Optional[logging.Handler] = None + +_default_log_level = logging.WARNING + +# fancy logger +_console = get_console() + + +def _get_library_name() -> str: + return __name__.split(".")[0] + + +def _get_library_root_logger() -> logging.Logger: + return logging.getLogger(_get_library_name()) + + +def _configure_library_root_logger() -> None: + global _default_handler + + with _lock: + if _default_handler: + # This library has already configured the library root logger. + return + _default_handler = logging.StreamHandler() # Set sys.stderr as stream. + _default_handler.flush = sys.stderr.flush + + # Apply our default configuration to the library root logger. + library_root_logger = _get_library_root_logger() + library_root_logger.addHandler(_default_handler) + library_root_logger.setLevel(_default_log_level) + library_root_logger.propagate = False + + +def _reset_library_root_logger() -> None: + global _default_handler + + with _lock: + if not _default_handler: + return + + library_root_logger = _get_library_root_logger() + library_root_logger.removeHandler(_default_handler) + library_root_logger.setLevel(logging.NOTSET) + _default_handler = None + + +def set_log_level(level: int, logger: logging.Logger = None) -> None: + """ + Set the log level. + Args: + level (:obj:`int`): + Logging level. + logger (:obj:`logging.Logger`): + Logger to set the log level. + """ + if not logger: + _configure_library_root_logger() + logger = _get_library_root_logger() + logger.setLevel(level) + + +def get_logger( + name: Optional[str] = None, + level: Optional[int] = None, + formatter: Optional[str] = None, +) -> logging.Logger: + """ + Return a logger with the specified name. + """ + + if name is None: + name = _get_library_name() + + _configure_library_root_logger() + + if level is not None: + set_log_level(level) + + if formatter is None: + formatter = logging.Formatter( + "%(asctime)s - %(levelname)s - %(name)s - %(message)s" + ) + _default_handler.setFormatter(formatter) + + return logging.getLogger(name) + + +def get_console_logger(): + return _console diff --git a/relik/common/upload.py b/relik/common/upload.py new file mode 100644 index 0000000000000000000000000000000000000000..b2cad77bd95f43992af3144baf296560a496556b --- /dev/null +++ b/relik/common/upload.py @@ -0,0 +1,128 @@ +import argparse +import json +import logging +import os +import tempfile +import zipfile +from datetime import datetime +from pathlib import Path +from typing import Optional, Union + +import huggingface_hub + +from relik.common.log import get_logger +from relik.common.utils import SAPIENZANLP_DATE_FORMAT, get_md5 + +logger = get_logger(level=logging.DEBUG) + + +def create_info_file(tmpdir: Path): + logger.debug("Computing md5 of model.zip") + md5 = get_md5(tmpdir / "model.zip") + date = datetime.now().strftime(SAPIENZANLP_DATE_FORMAT) + + logger.debug("Dumping info.json file") + with (tmpdir / "info.json").open("w") as f: + json.dump(dict(md5=md5, upload_date=date), f, indent=2) + + +def zip_run( + dir_path: Union[str, os.PathLike], + tmpdir: Union[str, os.PathLike], + zip_name: str = "model.zip", +) -> Path: + logger.debug(f"zipping {dir_path} to {tmpdir}") + # creates a zip version of the provided dir_path + run_dir = Path(dir_path) + zip_path = tmpdir / zip_name + + with zipfile.ZipFile(zip_path, "w") as zip_file: + # fully zip the run directory maintaining its structure + for file in run_dir.rglob("*.*"): + if file.is_dir(): + continue + + zip_file.write(file, 
arcname=file.relative_to(run_dir)) + + return zip_path + + +def upload( + model_dir: Union[str, os.PathLike], + model_name: str, + organization: Optional[str] = None, + repo_name: Optional[str] = None, + commit: Optional[str] = None, + archive: bool = False, +): + token = huggingface_hub.HfFolder.get_token() + if token is None: + print( + "No HuggingFace token found. You need to execute `huggingface-cli login` first!" + ) + return + + repo_id = repo_name or model_name + if organization is not None: + repo_id = f"{organization}/{repo_id}" + with tempfile.TemporaryDirectory() as tmpdir: + api = huggingface_hub.HfApi() + repo_url = api.create_repo( + token=token, + repo_id=repo_id, + exist_ok=True, + ) + repo = huggingface_hub.Repository( + str(tmpdir), clone_from=repo_url, use_auth_token=token + ) + + tmp_path = Path(tmpdir) + if archive: + # otherwise we zip the model_dir + logger.debug(f"Zipping {model_dir} to {tmp_path}") + zip_run(model_dir, tmp_path) + create_info_file(tmp_path) + else: + # if the user wants to upload a transformers model, we don't need to zip it + # we just need to copy the files to the tmpdir + logger.debug(f"Copying {model_dir} to {tmpdir}") + os.system(f"cp -r {model_dir}/* {tmpdir}") + + # this method automatically puts large files (>10MB) into git lfs + repo.push_to_hub(commit_message=commit or "Automatic push from sapienzanlp") + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument( + "model_dir", help="The directory of the model you want to upload" + ) + parser.add_argument("model_name", help="The model you want to upload") + parser.add_argument( + "--organization", + help="the name of the organization where you want to upload the model", + ) + parser.add_argument( + "--repo_name", + help="Optional name to use when uploading to the HuggingFace repository", + ) + parser.add_argument( + "--commit", help="Commit message to use when pushing to the HuggingFace Hub" + ) + parser.add_argument( + "--archive", + action="store_true", + help=""" + Whether to compress the model directory before uploading it. + If True, the model directory will be zipped and the zip file will be uploaded. 
+ If False, the model directory will be uploaded as is.""", + ) + return parser.parse_args() + + +def main(): + upload(**vars(parse_args())) + + +if __name__ == "__main__": + main() diff --git a/relik/common/utils.py b/relik/common/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6e49b88feca9a8e6cc517b583fd07c924d57ed8b --- /dev/null +++ b/relik/common/utils.py @@ -0,0 +1,609 @@ +import importlib.util +import json +import logging +import os +import shutil +import tarfile +import tempfile +from functools import partial +from hashlib import sha256 +from pathlib import Path +from typing import Any, BinaryIO, Dict, List, Optional, Union +from urllib.parse import urlparse +from zipfile import ZipFile, is_zipfile + +import huggingface_hub +import requests +import tqdm +from filelock import FileLock +from transformers.utils.hub import cached_file as hf_cached_file + +from relik.common.log import get_logger + +# name constants +WEIGHTS_NAME = "weights.pt" +ONNX_WEIGHTS_NAME = "weights.onnx" +CONFIG_NAME = "config.yaml" +LABELS_NAME = "labels.json" + +# SAPIENZANLP_USER_NAME = "sapienzanlp" +SAPIENZANLP_USER_NAME = "riccorl" +SAPIENZANLP_HF_MODEL_REPO_URL = "riccorl/{model_id}" +SAPIENZANLP_HF_MODEL_REPO_ARCHIVE_URL = ( + f"{SAPIENZANLP_HF_MODEL_REPO_URL}/resolve/main/model.zip" +) +# path constants +SAPIENZANLP_CACHE_DIR = os.getenv("SAPIENZANLP_CACHE_DIR", Path.home() / ".sapienzanlp") +SAPIENZANLP_DATE_FORMAT = "%Y-%m-%d %H-%M-%S" + + +logger = get_logger(__name__) + + +def sapienzanlp_model_urls(model_id: str) -> str: + """ + Returns the URL for a possible SapienzaNLP valid model. + + Args: + model_id (:obj:`str`): + A SapienzaNLP model id. + + Returns: + :obj:`str`: The url for the model id. + """ + # check if there is already the namespace of the user + if "/" in model_id: + return model_id + return SAPIENZANLP_HF_MODEL_REPO_URL.format(model_id=model_id) + + +def is_package_available(package_name: str) -> bool: + """ + Check if a package is available. + + Args: + package_name (`str`): The name of the package to check. + """ + return importlib.util.find_spec(package_name) is not None + + +def load_json(path: Union[str, Path]) -> Any: + """ + Load a json file provided in input. + + Args: + path (`Union[str, Path]`): The path to the json file to load. + + Returns: + `Any`: The loaded json file. + """ + with open(path, encoding="utf8") as f: + return json.load(f) + + +def dump_json(document: Any, path: Union[str, Path], indent: Optional[int] = None): + """ + Dump input to json file. + + Args: + document (`Any`): The document to dump. + path (`Union[str, Path]`): The path to dump the document to. + indent (`Optional[int]`): The indent to use for the json file. + + """ + with open(path, "w", encoding="utf8") as outfile: + json.dump(document, outfile, indent=indent) + + +def get_md5(path: Path): + """ + Get the MD5 value of a path. + """ + import hashlib + + with path.open("rb") as fin: + data = fin.read() + return hashlib.md5(data).hexdigest() + + +def file_exists(path: Union[str, os.PathLike]) -> bool: + """ + Check if the file at :obj:`path` exists. + + Args: + path (:obj:`str`, :obj:`os.PathLike`): + Path to check. + + Returns: + :obj:`bool`: :obj:`True` if the file exists. + """ + return Path(path).exists() + + +def dir_exists(path: Union[str, os.PathLike]) -> bool: + """ + Check if the directory at :obj:`path` exists. + + Args: + path (:obj:`str`, :obj:`os.PathLike`): + Path to check. + + Returns: + :obj:`bool`: :obj:`True` if the directory exists. 
+ """ + return Path(path).is_dir() + + +def is_remote_url(url_or_filename: Union[str, Path]): + """ + Returns :obj:`True` if the input path is an url. + + Args: + url_or_filename (:obj:`str`, :obj:`Path`): + path to check. + + Returns: + :obj:`bool`: :obj:`True` if the input path is an url, :obj:`False` otherwise. + + """ + if isinstance(url_or_filename, Path): + url_or_filename = str(url_or_filename) + parsed = urlparse(url_or_filename) + return parsed.scheme in ("http", "https") + + +def url_to_filename(resource: str, etag: str = None) -> str: + """ + Convert a `resource` into a hashed filename in a repeatable way. + If `etag` is specified, append its hash to the resources's, delimited + by a period. + """ + resource_bytes = resource.encode("utf-8") + resource_hash = sha256(resource_bytes) + filename = resource_hash.hexdigest() + + if etag: + etag_bytes = etag.encode("utf-8") + etag_hash = sha256(etag_bytes) + filename += "." + etag_hash.hexdigest() + + return filename + + +def download_resource( + url: str, + temp_file: BinaryIO, + headers=None, +): + """ + Download remote file. + """ + + if headers is None: + headers = {} + + r = requests.get(url, stream=True, headers=headers) + r.raise_for_status() + content_length = r.headers.get("Content-Length") + total = int(content_length) if content_length is not None else None + progress = tqdm( + unit="B", + unit_scale=True, + total=total, + desc="Downloading", + disable=logger.level in [logging.NOTSET], + ) + for chunk in r.iter_content(chunk_size=1024): + if chunk: # filter out keep-alive new chunks + progress.update(len(chunk)) + temp_file.write(chunk) + progress.close() + + +def download_and_cache( + url: Union[str, Path], + cache_dir: Union[str, Path] = None, + force_download: bool = False, +): + if cache_dir is None: + cache_dir = SAPIENZANLP_CACHE_DIR + if isinstance(url, Path): + url = str(url) + + # check if cache dir exists + Path(cache_dir).mkdir(parents=True, exist_ok=True) + + # check if file is private + headers = {} + try: + r = requests.head(url, allow_redirects=False, timeout=10) + r.raise_for_status() + except requests.exceptions.HTTPError: + if r.status_code == 401: + hf_token = huggingface_hub.HfFolder.get_token() + if hf_token is None: + raise ValueError( + "You need to login to HuggingFace to download this model " + "(use the `huggingface-cli login` command)" + ) + headers["Authorization"] = f"Bearer {hf_token}" + + etag = None + try: + r = requests.head(url, allow_redirects=True, timeout=10, headers=headers) + r.raise_for_status() + etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag") + # We favor a custom header indicating the etag of the linked resource, and + # we fallback to the regular etag header. + # If we don't have any of those, raise an error. + if etag is None: + raise OSError( + "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility." + ) + # In case of a redirect, + # save an extra redirect on the request.get call, + # and ensure we download the exact atomic version even if it changed + # between the HEAD and the GET (unlikely, but hey). + if 300 <= r.status_code <= 399: + url = r.headers["Location"] + except (requests.exceptions.SSLError, requests.exceptions.ProxyError): + # Actually raise for those subclasses of ConnectionError + raise + except (requests.exceptions.ConnectionError, requests.exceptions.Timeout): + # Otherwise, our Internet connection is down. 
+ # etag is None + pass + + # get filename from the url + filename = url_to_filename(url, etag) + # get cache path to put the file + cache_path = cache_dir / filename + + # the file is already here, return it + if file_exists(cache_path) and not force_download: + logger.info( + f"{url} found in cache, set `force_download=True` to force the download" + ) + return cache_path + + cache_path = str(cache_path) + # Prevent parallel downloads of the same file with a lock. + lock_path = cache_path + ".lock" + with FileLock(lock_path): + # If the download just completed while the lock was activated. + if file_exists(cache_path) and not force_download: + # Even if returning early like here, the lock will be released. + return cache_path + + temp_file_manager = partial( + tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False + ) + + # Download to temporary file, then copy to cache dir once finished. + # Otherwise, you get corrupt cache entries if the download gets interrupted. + with temp_file_manager() as temp_file: + logger.info( + f"{url} not found in cache or `force_download` set to `True`, downloading to {temp_file.name}" + ) + download_resource(url, temp_file, headers) + + logger.info(f"storing {url} in cache at {cache_path}") + os.replace(temp_file.name, cache_path) + + # NamedTemporaryFile creates a file with hardwired 0600 perms (ignoring umask), so fixing it. + umask = os.umask(0o666) + os.umask(umask) + os.chmod(cache_path, 0o666 & ~umask) + + logger.info(f"creating metadata file for {cache_path}") + meta = {"url": url} # , "etag": etag} + meta_path = cache_path + ".json" + with open(meta_path, "w") as meta_file: + json.dump(meta, meta_file) + + return cache_path + + +def download_from_hf( + path_or_repo_id: Union[str, Path], + filenames: Optional[List[str]], + cache_dir: Union[str, Path] = None, + force_download: bool = False, + resume_download: bool = False, + proxies: Optional[Dict[str, str]] = None, + use_auth_token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + local_files_only: bool = False, + subfolder: str = "", +): + if isinstance(path_or_repo_id, Path): + path_or_repo_id = str(path_or_repo_id) + + downloaded_paths = [] + for filename in filenames: + downloaded_path = hf_cached_file( + path_or_repo_id, + filename, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + use_auth_token=use_auth_token, + revision=revision, + local_files_only=local_files_only, + subfolder=subfolder, + ) + downloaded_paths.append(downloaded_path) + + # we want the folder where the files are downloaded + # the best guess is the parent folder of the first file + probably_the_folder = Path(downloaded_paths[0]).parent + return probably_the_folder + + +def model_name_or_path_resolver(model_name_or_dir: Union[str, os.PathLike]) -> str: + """ + Resolve a model name or directory to a model archive name or directory. + + Args: + model_name_or_dir (:obj:`str` or :obj:`os.PathLike`): + A model name or directory. + + Returns: + :obj:`str`: The model archive name or directory. 
+ """ + if is_remote_url(model_name_or_dir): + # if model_name_or_dir is a URL + # download it and try to load + model_archive = model_name_or_dir + elif Path(model_name_or_dir).is_dir() or Path(model_name_or_dir).is_file(): + # if model_name_or_dir is a local directory or + # an archive file try to load it + model_archive = model_name_or_dir + else: + # probably model_name_or_dir is a sapienzanlp model id + # guess the url and try to download + model_name_or_dir_ = model_name_or_dir + # raise ValueError(f"Providing a model id is not supported yet.") + model_archive = sapienzanlp_model_urls(model_name_or_dir_) + + return model_archive + + +def from_cache( + url_or_filename: Union[str, Path], + cache_dir: Union[str, Path] = None, + force_download: bool = False, + resume_download: bool = False, + proxies: Optional[Dict[str, str]] = None, + use_auth_token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + local_files_only: bool = False, + subfolder: str = "", + filenames: Optional[List[str]] = None, +) -> Path: + """ + Given something that could be either a local path or a URL (or a SapienzaNLP model id), + determine which one and return a path to the corresponding file. + + Args: + url_or_filename (:obj:`str` or :obj:`Path`): + A path to a local file or a URL (or a SapienzaNLP model id). + cache_dir (:obj:`str` or :obj:`Path`, `optional`): + Path to a directory in which a downloaded file will be cached. + force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to re-download the file even if it already exists. + resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to delete incompletely received files. Attempts to resume the download if such a file + exists. + proxies (:obj:`Dict[str, str]`, `optional`): + A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. + use_auth_token (:obj:`Union[bool, str]`, `optional`): + Optional string or boolean to use as Bearer token for remote files. If :obj:`True`, will get token from + :obj:`~transformers.hf_api.HfApi`. If :obj:`str`, will use that string as token. + revision (:obj:`str`, `optional`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any + identifier allowed by git. + local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to raise an error if the file to be downloaded is local. + subfolder (:obj:`str`, `optional`): + In case the relevant file is in a subfolder of the URL, specify it here. + filenames (:obj:`List[str]`, `optional`): + List of filenames to look for in the directory structure. + + Returns: + :obj:`Path`: Path to the cached file. 
+ """ + + url_or_filename = model_name_or_path_resolver(url_or_filename) + + if cache_dir is None: + cache_dir = SAPIENZANLP_CACHE_DIR + + if file_exists(url_or_filename): + logger.info(f"{url_or_filename} is a local path or file") + output_path = url_or_filename + elif is_remote_url(url_or_filename): + # URL, so get it from the cache (downloading if necessary) + output_path = download_and_cache( + url_or_filename, + cache_dir=cache_dir, + force_download=force_download, + ) + else: + if filenames is None: + filenames = [WEIGHTS_NAME, CONFIG_NAME, LABELS_NAME] + output_path = download_from_hf( + url_or_filename, + filenames, + cache_dir, + force_download, + resume_download, + proxies, + use_auth_token, + revision, + local_files_only, + subfolder, + ) + + # if is_hf_hub_url(url_or_filename): + # HuggingFace Hub + # output_path = hf_hub_download_url(url_or_filename) + # elif is_remote_url(url_or_filename): + # # URL, so get it from the cache (downloading if necessary) + # output_path = download_and_cache( + # url_or_filename, + # cache_dir=cache_dir, + # force_download=force_download, + # ) + # elif file_exists(url_or_filename): + # logger.info(f"{url_or_filename} is a local path or file") + # # File, and it exists. + # output_path = url_or_filename + # elif urlparse(url_or_filename).scheme == "": + # # File, but it doesn't exist. + # raise EnvironmentError(f"file {url_or_filename} not found") + # else: + # # Something unknown + # raise ValueError( + # f"unable to parse {url_or_filename} as a URL or as a local path" + # ) + + if dir_exists(output_path) or ( + not is_zipfile(output_path) and not tarfile.is_tarfile(output_path) + ): + return Path(output_path) + + # Path where we extract compressed archives + # for now it will extract it in the same folder + # maybe implement extraction in the sapienzanlp folder + # when using local archive path? + logger.info("Extracting compressed archive") + output_dir, output_file = os.path.split(output_path) + output_extract_dir_name = output_file.replace(".", "-") + "-extracted" + output_path_extracted = os.path.join(output_dir, output_extract_dir_name) + + # already extracted, do not extract + if ( + os.path.isdir(output_path_extracted) + and os.listdir(output_path_extracted) + and not force_download + ): + return Path(output_path_extracted) + + # Prevent parallel extractions + lock_path = output_path + ".lock" + with FileLock(lock_path): + shutil.rmtree(output_path_extracted, ignore_errors=True) + os.makedirs(output_path_extracted) + if is_zipfile(output_path): + with ZipFile(output_path, "r") as zip_file: + zip_file.extractall(output_path_extracted) + zip_file.close() + elif tarfile.is_tarfile(output_path): + tar_file = tarfile.open(output_path) + tar_file.extractall(output_path_extracted) + tar_file.close() + else: + raise EnvironmentError( + f"Archive format of {output_path} could not be identified" + ) + + # remove lock file, is it safe? + os.remove(lock_path) + + return Path(output_path_extracted) + + +def is_str_a_path(maybe_path: str) -> bool: + """ + Check if a string is a path. + + Args: + maybe_path (`str`): The string to check. + + Returns: + `bool`: `True` if the string is a path, `False` otherwise. 
+ """ + # first check if it is a path + if Path(maybe_path).exists(): + return True + # check if it is a relative path + if Path(os.path.join(os.getcwd(), maybe_path)).exists(): + return True + # otherwise it is not a path + return False + + +def relative_to_absolute_path(path: str) -> os.PathLike: + """ + Convert a relative path to an absolute path. + + Args: + path (`str`): The relative path to convert. + + Returns: + `os.PathLike`: The absolute path. + """ + if not is_str_a_path(path): + raise ValueError(f"{path} is not a path") + if Path(path).exists(): + return Path(path).absolute() + if Path(os.path.join(os.getcwd(), path)).exists(): + return Path(os.path.join(os.getcwd(), path)).absolute() + raise ValueError(f"{path} is not a path") + + +def to_config(object_to_save: Any) -> Dict[str, Any]: + """ + Convert an object to a dictionary. + + Returns: + `Dict[str, Any]`: The dictionary representation of the object. + """ + + def obj_to_dict(obj): + match obj: + case dict(): + data = {} + for k, v in obj.items(): + data[k] = obj_to_dict(v) + return data + + case list() | tuple(): + return [obj_to_dict(x) for x in obj] + + case object(__dict__=_): + data = { + "_target_": f"{obj.__class__.__module__}.{obj.__class__.__name__}", + } + for k, v in obj.__dict__.items(): + if not k.startswith("_"): + data[k] = obj_to_dict(v) + return data + + case _: + return obj + + return obj_to_dict(object_to_save) + + +def get_callable_from_string(callable_fn: str) -> Any: + """ + Get a callable from a string. + + Args: + callable_fn (`str`): + The string representation of the callable. + + Returns: + `Any`: The callable. + """ + # separate the function name from the module name + module_name, function_name = callable_fn.rsplit(".", 1) + # import the module + module = importlib.import_module(module_name) + # get the function + return getattr(module, function_name) diff --git a/relik/inference/__init__.py b/relik/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/relik/inference/annotator.py b/relik/inference/annotator.py new file mode 100644 index 0000000000000000000000000000000000000000..de356f690985a1496b607664dd77d5e49546257d --- /dev/null +++ b/relik/inference/annotator.py @@ -0,0 +1,422 @@ +import os +from pathlib import Path +from typing import Any, Callable, Dict, Optional, Union + +import hydra +from omegaconf import OmegaConf +from relik.retriever.pytorch_modules.hf import GoldenRetrieverModel +from rich.pretty import pprint + +from relik.common.log import get_console_logger, get_logger +from relik.common.upload import upload +from relik.common.utils import CONFIG_NAME, from_cache, get_callable_from_string +from relik.inference.data.objects import EntitySpan, RelikOutput +from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer +from relik.inference.data.window.manager import WindowManager +from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction +from relik.reader.relik_reader import RelikReader +from relik.retriever.data.utils import batch_generator +from relik.retriever.indexers.base import BaseDocumentIndex +from relik.retriever.pytorch_modules.model import GoldenRetriever + +logger = get_logger(__name__) +console_logger = get_console_logger() + + +class Relik: + """ + Relik main class. It is a wrapper around a retriever and a reader. + + Args: + retriever (`Optional[GoldenRetriever]`, `optional`): + The retriever to use. 
If `None`, a retriever will be instantiated from the + provided `question_encoder`, `passage_encoder` and `document_index`. + Defaults to `None`. + question_encoder (`Optional[Union[str, GoldenRetrieverModel]]`, `optional`): + The question encoder to use. If `retriever` is `None`, a retriever will be + instantiated from this parameter. Defaults to `None`. + passage_encoder (`Optional[Union[str, GoldenRetrieverModel]]`, `optional`): + The passage encoder to use. If `retriever` is `None`, a retriever will be + instantiated from this parameter. Defaults to `None`. + document_index (`Optional[Union[str, BaseDocumentIndex]]`, `optional`): + The document index to use. If `retriever` is `None`, a retriever will be + instantiated from this parameter. Defaults to `None`. + reader (`Optional[Union[str, RelikReader]]`, `optional`): + The reader to use. If `None`, a reader will be instantiated from the + provided `reader`. Defaults to `None`. + retriever_device (`str`, `optional`, defaults to `cpu`): + The device to use for the retriever. + + """ + + def __init__( + self, + retriever: GoldenRetriever | None = None, + question_encoder: str | GoldenRetrieverModel | None = None, + passage_encoder: str | GoldenRetrieverModel | None = None, + document_index: str | BaseDocumentIndex | None = None, + reader: str | RelikReader | None = None, + device: str = "cpu", + retriever_device: str | None = None, + document_index_device: str | None = None, + reader_device: str | None = None, + precision: int = 32, + retriever_precision: int | None = None, + document_index_precision: int | None = None, + reader_precision: int | None = None, + reader_kwargs: dict | None = None, + retriever_kwargs: dict | None = None, + candidates_preprocessing_fn: str | Callable | None = None, + top_k: int | None = None, + window_size: int | None = None, + window_stride: int | None = None, + **kwargs, + ) -> None: + # retriever + retriever_device = retriever_device or device + document_index_device = document_index_device or device + retriever_precision = retriever_precision or precision + document_index_precision = document_index_precision or precision + if retriever is None and question_encoder is None: + raise ValueError( + "Either `retriever` or `question_encoder` must be provided" + ) + if retriever is None: + self.retriever_kwargs = dict( + question_encoder=question_encoder, + passage_encoder=passage_encoder, + document_index=document_index, + device=retriever_device, + precision=retriever_precision, + index_device=document_index_device, + index_precision=document_index_precision, + ) + # overwrite default_retriever_kwargs with retriever_kwargs + self.retriever_kwargs.update(retriever_kwargs or {}) + retriever = GoldenRetriever(**self.retriever_kwargs) + retriever.training = False + retriever.eval() + self.retriever = retriever + + # reader + self.reader_device = reader_device or device + self.reader_precision = reader_precision or precision + self.reader_kwargs = reader_kwargs + if isinstance(reader, str): + reader_kwargs = reader_kwargs or {} + reader = RelikReaderForSpanExtraction(reader, **reader_kwargs) + self.reader = reader + + # windowization stuff + self.tokenizer = SpacyTokenizer(language="en") + self.window_manager: WindowManager | None = None + + # candidates preprocessing + # TODO: maybe move this logic somewhere else + candidates_preprocessing_fn = candidates_preprocessing_fn or (lambda x: x) + if isinstance(candidates_preprocessing_fn, str): + candidates_preprocessing_fn = get_callable_from_string( + 
candidates_preprocessing_fn + ) + self.candidates_preprocessing_fn = candidates_preprocessing_fn + + # inference params + self.top_k = top_k + self.window_size = window_size + self.window_stride = window_stride + + def __call__( + self, + text: Union[str, list], + top_k: Optional[int] = None, + window_size: Optional[int] = None, + window_stride: Optional[int] = None, + retriever_batch_size: Optional[int] = 32, + reader_batch_size: Optional[int] = 32, + return_also_windows: bool = False, + **kwargs, + ) -> Union[RelikOutput, list[RelikOutput]]: + """ + Annotate a text with entities. + + Args: + text (`str` or `list`): + The text to annotate. If a list is provided, each element of the list + will be annotated separately. + top_k (`int`, `optional`, defaults to `None`): + The number of candidates to retrieve for each window. + window_size (`int`, `optional`, defaults to `None`): + The size of the window. If `None`, the whole text will be annotated. + window_stride (`int`, `optional`, defaults to `None`): + The stride of the window. If `None`, there will be no overlap between windows. + retriever_batch_size (`int`, `optional`, defaults to `None`): + The batch size to use for the retriever. The whole input is the batch for the retriever. + reader_batch_size (`int`, `optional`, defaults to `None`): + The batch size to use for the reader. The whole input is the batch for the reader. + return_also_windows (`bool`, `optional`, defaults to `False`): + Whether to return the windows in the output. + **kwargs: + Additional keyword arguments to pass to the retriever and the reader. + + Returns: + `RelikOutput` or `list[RelikOutput]`: + The annotated text. If a list was provided as input, a list of + `RelikOutput` objects will be returned. + """ + if top_k is None: + top_k = self.top_k or 100 + if window_size is None: + window_size = self.window_size + if window_stride is None: + window_stride = self.window_stride + + if isinstance(text, str): + text = [text] + + if window_size is not None: + if self.window_manager is None: + self.window_manager = WindowManager(self.tokenizer) + + if window_size == "sentence": + # todo: implement sentence windowizer + raise NotImplementedError("Sentence windowizer not implemented yet") + + # if window_size < window_stride: + # raise ValueError( + # f"Window size ({window_size}) must be greater than window stride ({window_stride})" + # ) + + # window generator + windows = [ + window + for doc_id, t in enumerate(text) + for window in self.window_manager.create_windows( + t, + window_size=window_size, + stride=window_stride, + doc_id=doc_id, + ) + ] + + # retrieve candidates first + windows_candidates = [] + # TODO: Move batching inside retriever + for batch in batch_generator(windows, batch_size=retriever_batch_size): + retriever_out = self.retriever.retrieve([b.text for b in batch], k=top_k) + windows_candidates.extend( + [[p.label for p in predictions] for predictions in retriever_out] + ) + + # add passage to the windows + for window, candidates in zip(windows, windows_candidates): + window.window_candidates = [ + self.candidates_preprocessing_fn(c) for c in candidates + ] + + windows = self.reader.read(samples=windows, max_batch_size=reader_batch_size) + windows = self.window_manager.merge_windows(windows) + + # transform predictions into RelikOutput objects + output = [] + for w in windows: + sample_output = RelikOutput( + text=text[w.doc_id], + labels=sorted( + [ + EntitySpan( + start=ss, end=se, label=sl, text=text[w.doc_id][ss:se] + ) + for ss, se, sl in 
w.predicted_window_labels_chars + ], + key=lambda x: x.start, + ), + ) + output.append(sample_output) + + if return_also_windows: + for i, sample_output in enumerate(output): + sample_output.windows = [w for w in windows if w.doc_id == i] + + # if only one text was provided, return a single RelikOutput object + if len(output) == 1: + return output[0] + + return output + + @classmethod + def from_pretrained( + cls, + model_name_or_dir: Union[str, os.PathLike], + config_kwargs: Optional[Dict] = None, + config_file_name: str = CONFIG_NAME, + *args, + **kwargs, + ) -> "Relik": + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + + model_dir = from_cache( + model_name_or_dir, + filenames=[config_file_name], + cache_dir=cache_dir, + force_download=force_download, + ) + + config_path = model_dir / config_file_name + if not config_path.exists(): + raise FileNotFoundError( + f"Model configuration file not found at {config_path}." + ) + + # overwrite config with config_kwargs + config = OmegaConf.load(config_path) + if config_kwargs is not None: + # TODO: check merging behavior + config = OmegaConf.merge(config, OmegaConf.create(config_kwargs)) + # do we want to print the config? I like it + pprint(OmegaConf.to_container(config), console=console_logger, expand_all=True) + + # load relik from config + relik = hydra.utils.instantiate(config, *args, **kwargs) + + return relik + + def save_pretrained( + self, + output_dir: Union[str, os.PathLike], + config: Optional[Dict[str, Any]] = None, + config_file_name: Optional[str] = None, + save_weights: bool = False, + push_to_hub: bool = False, + model_id: Optional[str] = None, + organization: Optional[str] = None, + repo_name: Optional[str] = None, + **kwargs, + ): + """ + Save the configuration of Relik to the specified directory as a YAML file. + + Args: + output_dir (`str`): + The directory to save the configuration file to. + config (`Optional[Dict[str, Any]]`, `optional`): + The configuration to save. If `None`, the current configuration will be + saved. Defaults to `None`. + config_file_name (`Optional[str]`, `optional`): + The name of the configuration file. Defaults to `config.yaml`. + save_weights (`bool`, `optional`): + Whether to save the weights of the model. Defaults to `False`. + push_to_hub (`bool`, `optional`): + Whether to push the saved model to the hub. Defaults to `False`. + model_id (`Optional[str]`, `optional`): + The id of the model to push to the hub. If `None`, the name of the + directory will be used. Defaults to `None`. + organization (`Optional[str]`, `optional`): + The organization to push the model to. Defaults to `None`. + repo_name (`Optional[str]`, `optional`): + The name of the repository to push the model to. Defaults to `None`. + **kwargs: + Additional keyword arguments to pass to `OmegaConf.save`. 
+ """ + if config is None: + # create a default config + config = { + "_target_": f"{self.__class__.__module__}.{self.__class__.__name__}" + } + if self.retriever is not None: + if self.retriever.question_encoder is not None: + config[ + "question_encoder" + ] = self.retriever.question_encoder.name_or_path + if self.retriever.passage_encoder is not None: + config[ + "passage_encoder" + ] = self.retriever.passage_encoder.name_or_path + if self.retriever.document_index is not None: + config["document_index"] = self.retriever.document_index.name_or_dir + if self.reader is not None: + config["reader"] = self.reader.model_path + + config["retriever_kwargs"] = self.retriever_kwargs + config["reader_kwargs"] = self.reader_kwargs + # expand the fn as to be able to save it and load it later + config[ + "candidates_preprocessing_fn" + ] = f"{self.candidates_preprocessing_fn.__module__}.{self.candidates_preprocessing_fn.__name__}" + + # these are model-specific and should be saved + config["top_k"] = self.top_k + config["window_size"] = self.window_size + config["window_stride"] = self.window_stride + + config_file_name = config_file_name or CONFIG_NAME + + # create the output directory + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + logger.info(f"Saving relik config to {output_dir / config_file_name}") + # pretty print the config + pprint(config, console=console_logger, expand_all=True) + OmegaConf.save(config, output_dir / config_file_name) + + if save_weights: + model_id = model_id or output_dir.name + retriever_model_id = model_id + "-retriever" + # save weights + logger.info(f"Saving retriever to {output_dir / retriever_model_id}") + self.retriever.save_pretrained( + output_dir / retriever_model_id, + question_encoder_name=retriever_model_id + "-question-encoder", + passage_encoder_name=retriever_model_id + "-passage-encoder", + document_index_name=retriever_model_id + "-index", + push_to_hub=push_to_hub, + organization=organization, + repo_name=repo_name, + **kwargs, + ) + reader_model_id = model_id + "-reader" + logger.info(f"Saving reader to {output_dir / reader_model_id}") + self.reader.save_pretrained( + output_dir / reader_model_id, + push_to_hub=push_to_hub, + organization=organization, + repo_name=repo_name, + **kwargs, + ) + + if push_to_hub: + # push to hub + logger.info(f"Pushing to hub") + model_id = model_id or output_dir.name + upload(output_dir, model_id, organization=organization, repo_name=repo_name) + + +def main(): + from pprint import pprint + + relik = Relik( + question_encoder="riccorl/relik-retriever-aida-blink-pretrain-omniencoder", + document_index="riccorl/index-relik-retriever-aida-blink-pretrain-omniencoder", + reader="riccorl/relik-reader-aida-deberta-small", + device="cuda", + precision=16, + top_k=100, + window_size=32, + window_stride=16, + candidates_preprocessing_fn="relik.inference.preprocessing.wikipedia_title_and_openings_preprocessing", + ) + + input_text = """ + Bernie Ecclestone, the former boss of Formula One, has admitted fraud after failing to declare more than £400m held in a trust in Singapore. + The 92-year-old billionaire did not disclose the trust to the government in July 2015. + Appearing at Southwark Crown Court on Thursday, he told the judge "I plead guilty" after having previously pleaded not guilty. + Ecclestone had been due to go on trial next month. 
+ """ + + preds = relik(input_text) + pprint(preds) + + +if __name__ == "__main__": + main() diff --git a/relik/inference/data/__init__.py b/relik/inference/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/relik/inference/data/objects.py b/relik/inference/data/objects.py new file mode 100644 index 0000000000000000000000000000000000000000..4b11e9641380b9e13d60de427827a73b70cbb9c1 --- /dev/null +++ b/relik/inference/data/objects.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import List, NamedTuple, Optional + +from relik.reader.pytorch_modules.hf.modeling_relik import RelikReaderSample + + +@dataclass +class Word: + """ + A word representation that includes text, index in the sentence, POS tag, lemma, + dependency relation, and similar information. + + # Parameters + text : `str`, optional + The text representation. + index : `int`, optional + The word offset in the sentence. + lemma : `str`, optional + The lemma of this word. + pos : `str`, optional + The coarse-grained part of speech of this word. + dep : `str`, optional + The dependency relation for this word. + + input_id : `int`, optional + Integer representation of the word, used to pass it to a model. + token_type_id : `int`, optional + Token type id used by some transformers. + attention_mask: `int`, optional + Attention mask used by transformers, indicates to the model which tokens should + be attended to, and which should not. + """ + + text: str + index: int + start_char: Optional[int] = None + end_char: Optional[int] = None + # preprocessing fields + lemma: Optional[str] = None + pos: Optional[str] = None + dep: Optional[str] = None + head: Optional[int] = None + + def __str__(self): + return self.text + + def __repr__(self): + return self.__str__() + + +class EntitySpan(NamedTuple): + start: int + end: int + label: str + text: str + + +@dataclass +class RelikOutput: + text: str + labels: List[EntitySpan] + windows: Optional[List[RelikReaderSample]] = None diff --git a/relik/inference/data/tokenizers/__init__.py b/relik/inference/data/tokenizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ad70314e8e0ccc18b946ff1317f6415c1892747a --- /dev/null +++ b/relik/inference/data/tokenizers/__init__.py @@ -0,0 +1,89 @@ +SPACY_LANGUAGE_MAPPER = { + "ca": "ca_core_news_sm", + "da": "da_core_news_sm", + "de": "de_core_news_sm", + "el": "el_core_news_sm", + "en": "en_core_web_sm", + "es": "es_core_news_sm", + "fr": "fr_core_news_sm", + "it": "it_core_news_sm", + "ja": "ja_core_news_sm", + "lt": "lt_core_news_sm", + "mk": "mk_core_news_sm", + "nb": "nb_core_news_sm", + "nl": "nl_core_news_sm", + "pl": "pl_core_news_sm", + "pt": "pt_core_news_sm", + "ro": "ro_core_news_sm", + "ru": "ru_core_news_sm", + "xx": "xx_sent_ud_sm", + "zh": "zh_core_web_sm", + "ca_core_news_sm": "ca_core_news_sm", + "ca_core_news_md": "ca_core_news_md", + "ca_core_news_lg": "ca_core_news_lg", + "ca_core_news_trf": "ca_core_news_trf", + "da_core_news_sm": "da_core_news_sm", + "da_core_news_md": "da_core_news_md", + "da_core_news_lg": "da_core_news_lg", + "da_core_news_trf": "da_core_news_trf", + "de_core_news_sm": "de_core_news_sm", + "de_core_news_md": "de_core_news_md", + "de_core_news_lg": "de_core_news_lg", + "de_dep_news_trf": "de_dep_news_trf", + "el_core_news_sm": "el_core_news_sm", + "el_core_news_md": "el_core_news_md", + "el_core_news_lg": "el_core_news_lg", + "en_core_web_sm": 
"en_core_web_sm", + "en_core_web_md": "en_core_web_md", + "en_core_web_lg": "en_core_web_lg", + "en_core_web_trf": "en_core_web_trf", + "es_core_news_sm": "es_core_news_sm", + "es_core_news_md": "es_core_news_md", + "es_core_news_lg": "es_core_news_lg", + "es_dep_news_trf": "es_dep_news_trf", + "fr_core_news_sm": "fr_core_news_sm", + "fr_core_news_md": "fr_core_news_md", + "fr_core_news_lg": "fr_core_news_lg", + "fr_dep_news_trf": "fr_dep_news_trf", + "it_core_news_sm": "it_core_news_sm", + "it_core_news_md": "it_core_news_md", + "it_core_news_lg": "it_core_news_lg", + "ja_core_news_sm": "ja_core_news_sm", + "ja_core_news_md": "ja_core_news_md", + "ja_core_news_lg": "ja_core_news_lg", + "ja_dep_news_trf": "ja_dep_news_trf", + "lt_core_news_sm": "lt_core_news_sm", + "lt_core_news_md": "lt_core_news_md", + "lt_core_news_lg": "lt_core_news_lg", + "mk_core_news_sm": "mk_core_news_sm", + "mk_core_news_md": "mk_core_news_md", + "mk_core_news_lg": "mk_core_news_lg", + "nb_core_news_sm": "nb_core_news_sm", + "nb_core_news_md": "nb_core_news_md", + "nb_core_news_lg": "nb_core_news_lg", + "nl_core_news_sm": "nl_core_news_sm", + "nl_core_news_md": "nl_core_news_md", + "nl_core_news_lg": "nl_core_news_lg", + "pl_core_news_sm": "pl_core_news_sm", + "pl_core_news_md": "pl_core_news_md", + "pl_core_news_lg": "pl_core_news_lg", + "pt_core_news_sm": "pt_core_news_sm", + "pt_core_news_md": "pt_core_news_md", + "pt_core_news_lg": "pt_core_news_lg", + "ro_core_news_sm": "ro_core_news_sm", + "ro_core_news_md": "ro_core_news_md", + "ro_core_news_lg": "ro_core_news_lg", + "ru_core_news_sm": "ru_core_news_sm", + "ru_core_news_md": "ru_core_news_md", + "ru_core_news_lg": "ru_core_news_lg", + "xx_ent_wiki_sm": "xx_ent_wiki_sm", + "xx_sent_ud_sm": "xx_sent_ud_sm", + "zh_core_web_sm": "zh_core_web_sm", + "zh_core_web_md": "zh_core_web_md", + "zh_core_web_lg": "zh_core_web_lg", + "zh_core_web_trf": "zh_core_web_trf", +} + +from relik.inference.data.tokenizers.regex_tokenizer import RegexTokenizer +from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer +from relik.inference.data.tokenizers.whitespace_tokenizer import WhitespaceTokenizer diff --git a/relik/inference/data/tokenizers/base_tokenizer.py b/relik/inference/data/tokenizers/base_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..1fed161b3eca085656e85d44cb9a64739f3d1e4c --- /dev/null +++ b/relik/inference/data/tokenizers/base_tokenizer.py @@ -0,0 +1,84 @@ +from typing import List, Union + +from relik.inference.data.objects import Word + + +class BaseTokenizer: + """ + A :obj:`Tokenizer` splits strings of text into single words, optionally adds + pos tags and perform lemmatization. + """ + + def __call__( + self, + texts: Union[str, List[str], List[List[str]]], + is_split_into_words: bool = False, + **kwargs + ) -> List[List[Word]]: + """ + Tokenize the input into single words. + + Args: + texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`): + Text to tag. It can be a single string, a batch of string and pre-tokenized strings. + is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`): + If :obj:`True` and the input is a string, the input is split on spaces. + + Returns: + :obj:`List[List[Word]]`: The input text tokenized in single words. + """ + raise NotImplementedError + + def tokenize(self, text: str) -> List[Word]: + """ + Implements splitting words into tokens. + + Args: + text (:obj:`str`): + Text to tokenize. + + Returns: + :obj:`List[Word]`: The input text tokenized in single words. 
+ + """ + raise NotImplementedError + + def tokenize_batch(self, texts: List[str]) -> List[List[Word]]: + """ + Implements batch splitting words into tokens. + + Args: + texts (:obj:`List[str]`): + Batch of text to tokenize. + + Returns: + :obj:`List[List[Word]]`: The input batch tokenized in single words. + + """ + return [self.tokenize(text) for text in texts] + + @staticmethod + def check_is_batched( + texts: Union[str, List[str], List[List[str]]], is_split_into_words: bool + ): + """ + Check if input is batched or a single sample. + + Args: + texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`): + Text to check. + is_split_into_words (:obj:`bool`): + If :obj:`True` and the input is a string, the input is split on spaces. + + Returns: + :obj:`bool`: ``True`` if ``texts`` is batched, ``False`` otherwise. + """ + return bool( + (not is_split_into_words and isinstance(texts, (list, tuple))) + or ( + is_split_into_words + and isinstance(texts, (list, tuple)) + and texts + and isinstance(texts[0], (list, tuple)) + ) + ) diff --git a/relik/inference/data/tokenizers/regex_tokenizer.py b/relik/inference/data/tokenizers/regex_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..ebe8656afb891a8318a7030375427e190d1dc383 --- /dev/null +++ b/relik/inference/data/tokenizers/regex_tokenizer.py @@ -0,0 +1,73 @@ +import re +from typing import List, Union + +from overrides import overrides + +from relik.inference.data.objects import Word +from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer + + +class RegexTokenizer(BaseTokenizer): + """ + A :obj:`Tokenizer` that splits the text based on a simple regex. + """ + + def __init__(self): + super(RegexTokenizer, self).__init__() + # regex for splitting on spaces and punctuation and new lines + # self._regex = re.compile(r"\S+|[\[\](),.!?;:\"]|\\n") + self._regex = re.compile( + r"\w+|\$[\d\.]+|\S+", re.UNICODE | re.MULTILINE | re.DOTALL + ) + + def __call__( + self, + texts: Union[str, List[str], List[List[str]]], + is_split_into_words: bool = False, + **kwargs, + ) -> List[List[Word]]: + """ + Tokenize the input into single words by splitting using a simple regex. + + Args: + texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`): + Text to tag. It can be a single string, a batch of string and pre-tokenized strings. + is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`): + If :obj:`True` and the input is a string, the input is split on spaces. + + Returns: + :obj:`List[List[Word]]`: The input text tokenized in single words. 
+ + Example:: + + >>> from relik.retriever.serve.tokenizers.regex_tokenizer import RegexTokenizer + + >>> regex_tokenizer = RegexTokenizer() + >>> regex_tokenizer("Mary sold the car to John.") + + """ + # check if input is batched or a single sample + is_batched = self.check_is_batched(texts, is_split_into_words) + + if is_batched: + tokenized = self.tokenize_batch(texts) + else: + tokenized = self.tokenize(texts) + + return tokenized + + @overrides + def tokenize(self, text: Union[str, List[str]]) -> List[Word]: + if not isinstance(text, (str, list)): + raise ValueError( + f"text must be either `str` or `list`, found: `{type(text)}`" + ) + + if isinstance(text, list): + text = " ".join(text) + return [ + Word(t[0], i, start_char=t[1], end_char=t[2]) + for i, t in enumerate( + (m.group(0), m.start(), m.end()) for m in self._regex.finditer(text) + ) + ] diff --git a/relik/inference/data/tokenizers/spacy_tokenizer.py b/relik/inference/data/tokenizers/spacy_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..b949216ed5cf152ae4a7722c4a6be3f883481db2 --- /dev/null +++ b/relik/inference/data/tokenizers/spacy_tokenizer.py @@ -0,0 +1,228 @@ +import logging +from typing import Dict, List, Tuple, Union + +import spacy + +# from ipa.common.utils import load_spacy +from overrides import overrides +from spacy.cli.download import download as spacy_download +from spacy.tokens import Doc + +from relik.common.log import get_logger +from relik.inference.data.objects import Word +from relik.inference.data.tokenizers import SPACY_LANGUAGE_MAPPER +from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer + +logger = get_logger(level=logging.DEBUG) + +# Spacy and Stanza stuff + +LOADED_SPACY_MODELS: Dict[Tuple[str, bool, bool, bool, bool], spacy.Language] = {} + + +def load_spacy( + language: str, + pos_tags: bool = False, + lemma: bool = False, + parse: bool = False, + split_on_spaces: bool = False, +) -> spacy.Language: + """ + Download and load spacy model. + + Args: + language (:obj:`str`, defaults to :obj:`en`): + Language of the text to tokenize. + pos_tags (:obj:`bool`, optional, defaults to :obj:`False`): + If :obj:`True`, performs POS tagging with spacy model. + lemma (:obj:`bool`, optional, defaults to :obj:`False`): + If :obj:`True`, performs lemmatization with spacy model. + parse (:obj:`bool`, optional, defaults to :obj:`False`): + If :obj:`True`, performs dependency parsing with spacy model. + split_on_spaces (:obj:`bool`, optional, defaults to :obj:`False`): + If :obj:`True`, will split by spaces without performing tokenization. + + Returns: + :obj:`spacy.Language`: The spacy model loaded. + """ + exclude = ["vectors", "textcat", "ner"] + if not pos_tags: + exclude.append("tagger") + if not lemma: + exclude.append("lemmatizer") + if not parse: + exclude.append("parser") + + # check if the model is already loaded + # if so, there is no need to reload it + spacy_params = (language, pos_tags, lemma, parse, split_on_spaces) + if spacy_params not in LOADED_SPACY_MODELS: + try: + spacy_tagger = spacy.load(language, exclude=exclude) + except OSError: + logger.warning( + "Spacy model '%s' not found. Downloading and installing.", language + ) + spacy_download(language) + spacy_tagger = spacy.load(language, exclude=exclude) + + # if everything is disabled, return only the tokenizer + # for faster tokenization + # TODO: is it really faster? 
+ # if len(exclude) >= 6: + # spacy_tagger = spacy_tagger.tokenizer + LOADED_SPACY_MODELS[spacy_params] = spacy_tagger + + return LOADED_SPACY_MODELS[spacy_params] + + +class SpacyTokenizer(BaseTokenizer): + """ + A :obj:`Tokenizer` that uses SpaCy to tokenizer and preprocess the text. It returns :obj:`Word` objects. + + Args: + language (:obj:`str`, optional, defaults to :obj:`en`): + Language of the text to tokenize. + return_pos_tags (:obj:`bool`, optional, defaults to :obj:`False`): + If :obj:`True`, performs POS tagging with spacy model. + return_lemmas (:obj:`bool`, optional, defaults to :obj:`False`): + If :obj:`True`, performs lemmatization with spacy model. + return_deps (:obj:`bool`, optional, defaults to :obj:`False`): + If :obj:`True`, performs dependency parsing with spacy model. + split_on_spaces (:obj:`bool`, optional, defaults to :obj:`False`): + If :obj:`True`, will split by spaces without performing tokenization. + use_gpu (:obj:`bool`, optional, defaults to :obj:`False`): + If :obj:`True`, will load the Stanza model on GPU. + """ + + def __init__( + self, + language: str = "en", + return_pos_tags: bool = False, + return_lemmas: bool = False, + return_deps: bool = False, + split_on_spaces: bool = False, + use_gpu: bool = False, + ): + super(SpacyTokenizer, self).__init__() + if language not in SPACY_LANGUAGE_MAPPER: + raise ValueError( + f"`{language}` language not supported. The supported " + f"languages are: {list(SPACY_LANGUAGE_MAPPER.keys())}." + ) + if use_gpu: + # load the model on GPU + # if the GPU is not available or not correctly configured, + # it will rise an error + spacy.require_gpu() + self.spacy = load_spacy( + SPACY_LANGUAGE_MAPPER[language], + return_pos_tags, + return_lemmas, + return_deps, + split_on_spaces, + ) + self.split_on_spaces = split_on_spaces + + def __call__( + self, + texts: Union[str, List[str], List[List[str]]], + is_split_into_words: bool = False, + **kwargs, + ) -> Union[List[Word], List[List[Word]]]: + """ + Tokenize the input into single words using SpaCy models. + + Args: + texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`): + Text to tag. It can be a single string, a batch of string and pre-tokenized strings. + is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`): + If :obj:`True` and the input is a string, the input is split on spaces. + + Returns: + :obj:`List[List[Word]]`: The input text tokenized in single words. 
+ + Example:: + + >>> from ipa import SpacyTokenizer + + >>> spacy_tokenizer = SpacyTokenizer(language="en", pos_tags=True, lemma=True) + >>> spacy_tokenizer("Mary sold the car to John.") + + """ + # check if input is batched or a single sample + is_batched = self.check_is_batched(texts, is_split_into_words) + if is_batched: + tokenized = self.tokenize_batch(texts) + else: + tokenized = self.tokenize(texts) + return tokenized + + @overrides + def tokenize(self, text: Union[str, List[str]]) -> List[Word]: + if self.split_on_spaces: + if isinstance(text, str): + text = text.split(" ") + spaces = [True] * len(text) + text = Doc(self.spacy.vocab, words=text, spaces=spaces) + return self._clean_tokens(self.spacy(text)) + + @overrides + def tokenize_batch( + self, texts: Union[List[str], List[List[str]]] + ) -> List[List[Word]]: + if self.split_on_spaces: + if isinstance(texts[0], str): + texts = [text.split(" ") for text in texts] + spaces = [[True] * len(text) for text in texts] + texts = [ + Doc(self.spacy.vocab, words=text, spaces=space) + for text, space in zip(texts, spaces) + ] + return [self._clean_tokens(tokens) for tokens in self.spacy.pipe(texts)] + + @staticmethod + def _clean_tokens(tokens: Doc) -> List[Word]: + """ + Converts spaCy tokens to :obj:`Word`. + + Args: + tokens (:obj:`spacy.tokens.Doc`): + Tokens from SpaCy model. + + Returns: + :obj:`List[Word]`: The SpaCy model output converted into :obj:`Word` objects. + """ + words = [ + Word( + token.text, + token.i, + token.idx, + token.idx + len(token), + token.lemma_, + token.pos_, + token.dep_, + token.head.i, + ) + for token in tokens + ] + return words + + +class WhitespaceSpacyTokenizer: + """Simple white space tokenizer for SpaCy.""" + + def __init__(self, vocab): + self.vocab = vocab + + def __call__(self, text): + if isinstance(text, str): + words = text.split(" ") + elif isinstance(text, list): + words = text + else: + raise ValueError( + f"text must be either `str` or `list`, found: `{type(text)}`" + ) + spaces = [True] * len(words) + return Doc(self.vocab, words=words, spaces=spaces) diff --git a/relik/inference/data/tokenizers/whitespace_tokenizer.py b/relik/inference/data/tokenizers/whitespace_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..537ab6fe21eb4f9378d96d7cebfbc8cb12c36104 --- /dev/null +++ b/relik/inference/data/tokenizers/whitespace_tokenizer.py @@ -0,0 +1,70 @@ +import re +from typing import List, Union + +from overrides import overrides + +from relik.inference.data.objects import Word +from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer + + +class WhitespaceTokenizer(BaseTokenizer): + """ + A :obj:`Tokenizer` that splits the text on spaces. + """ + + def __init__(self): + super(WhitespaceTokenizer, self).__init__() + self.whitespace_regex = re.compile(r"\S+") + + def __call__( + self, + texts: Union[str, List[str], List[List[str]]], + is_split_into_words: bool = False, + **kwargs, + ) -> List[List[Word]]: + """ + Tokenize the input into single words by splitting on spaces. + + Args: + texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`): + Text to tag. It can be a single string, a batch of string and pre-tokenized strings. + is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`): + If :obj:`True` and the input is a string, the input is split on spaces. + + Returns: + :obj:`List[List[Word]]`: The input text tokenized in single words. 
+ + Example:: + + >>> from nlp_preprocessing_wrappers import WhitespaceTokenizer + + >>> whitespace_tokenizer = WhitespaceTokenizer() + >>> whitespace_tokenizer("Mary sold the car to John .") + + """ + # check if input is batched or a single sample + is_batched = self.check_is_batched(texts, is_split_into_words) + + if is_batched: + tokenized = self.tokenize_batch(texts) + else: + tokenized = self.tokenize(texts) + + return tokenized + + @overrides + def tokenize(self, text: Union[str, List[str]]) -> List[Word]: + if not isinstance(text, (str, list)): + raise ValueError( + f"text must be either `str` or `list`, found: `{type(text)}`" + ) + + if isinstance(text, list): + text = " ".join(text) + return [ + Word(t[0], i, start_char=t[1], end_char=t[2]) + for i, t in enumerate( + (m.group(0), m.start(), m.end()) + for m in self.whitespace_regex.finditer(text) + ) + ] diff --git a/relik/inference/data/window/__init__.py b/relik/inference/data/window/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/relik/inference/data/window/manager.py b/relik/inference/data/window/manager.py new file mode 100644 index 0000000000000000000000000000000000000000..420609b1827f13bb332780554e3e20421908f6e9 --- /dev/null +++ b/relik/inference/data/window/manager.py @@ -0,0 +1,262 @@ +import collections +import itertools +from dataclasses import dataclass +from typing import List, Optional, Set, Tuple + +from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer +from relik.reader.data.relik_reader_sample import RelikReaderSample + + +@dataclass +class Window: + doc_id: int + window_id: int + text: str + tokens: List[str] + doc_topic: Optional[str] + offset: int + token2char_start: dict + token2char_end: dict + window_candidates: Optional[List[str]] = None + + +class WindowManager: + def __init__(self, tokenizer: BaseTokenizer) -> None: + self.tokenizer = tokenizer + + def tokenize(self, document: str) -> Tuple[List[str], List[Tuple[int, int]]]: + tokenized_document = self.tokenizer(document) + tokens = [] + tokens_char_mapping = [] + for token in tokenized_document: + tokens.append(token.text) + tokens_char_mapping.append((token.start_char, token.end_char)) + return tokens, tokens_char_mapping + + def create_windows( + self, + document: str, + window_size: int, + stride: int, + doc_id: int = 0, + doc_topic: str = None, + ) -> List[RelikReaderSample]: + document_tokens, tokens_char_mapping = self.tokenize(document) + if doc_topic is None: + doc_topic = document_tokens[0] if len(document_tokens) > 0 else "" + document_windows = [] + if len(document_tokens) <= window_size: + text = document + # relik_reader_sample = RelikReaderSample() + document_windows.append( + # Window( + RelikReaderSample( + doc_id=doc_id, + window_id=0, + text=text, + tokens=document_tokens, + doc_topic=doc_topic, + offset=0, + token2char_start={ + str(i): tokens_char_mapping[i][0] + for i in range(len(document_tokens)) + }, + token2char_end={ + str(i): tokens_char_mapping[i][1] + for i in range(len(document_tokens)) + }, + ) + ) + else: + for window_id, i in enumerate(range(0, len(document_tokens), stride)): + # if the last stride is smaller than the window size, then we can + # include more tokens form the previous window. 
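+ # two cases when a window would run past the end of the document:
+ # - if the required left-shift is at least `stride`, the remaining tokens are
+ #   already covered by the previous window, so stop;
+ # - otherwise move the window start back by the overflow so the last window
+ #   still spans (roughly) `window_size` tokens.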
+ if i != 0 and i + window_size > len(document_tokens): + overflowing_tokens = i + window_size - len(document_tokens) + if overflowing_tokens >= stride: + break + i -= overflowing_tokens + + involved_token_indices = list( + range(i, min(i + window_size, len(document_tokens) - 1)) + ) + window_tokens = [document_tokens[j] for j in involved_token_indices] + window_text_start = tokens_char_mapping[involved_token_indices[0]][0] + window_text_end = tokens_char_mapping[involved_token_indices[-1]][1] + text = document[window_text_start:window_text_end] + document_windows.append( + # Window( + RelikReaderSample( + # dict( + doc_id=doc_id, + window_id=window_id, + text=text, + tokens=window_tokens, + doc_topic=doc_topic, + offset=window_text_start, + token2char_start={ + str(i): tokens_char_mapping[ti][0] + for i, ti in enumerate(involved_token_indices) + }, + token2char_end={ + str(i): tokens_char_mapping[ti][1] + for i, ti in enumerate(involved_token_indices) + }, + # ) + ) + ) + return document_windows + + def merge_windows( + self, windows: List[RelikReaderSample] + ) -> List[RelikReaderSample]: + windows_by_doc_id = collections.defaultdict(list) + for window in windows: + windows_by_doc_id[window.doc_id].append(window) + + merged_window_by_doc = { + doc_id: self.merge_doc_windows(doc_windows) + for doc_id, doc_windows in windows_by_doc_id.items() + } + + return list(merged_window_by_doc.values()) + + def merge_doc_windows(self, windows: List[RelikReaderSample]) -> RelikReaderSample: + if len(windows) == 1: + return windows[0] + + if len(windows) > 0 and getattr(windows[0], "offset", None) is not None: + windows = sorted(windows, key=(lambda x: x.offset)) + + window_accumulator = windows[0] + + for next_window in windows[1:]: + window_accumulator = self._merge_window_pair( + window_accumulator, next_window + ) + + return window_accumulator + + def _merge_tokens( + self, window1: RelikReaderSample, window2: RelikReaderSample + ) -> Tuple[list, dict, dict]: + w1_tokens = window1.tokens[1:-1] + w2_tokens = window2.tokens[1:-1] + + # find intersection + tokens_intersection = None + for k in reversed(range(1, len(w1_tokens))): + if w1_tokens[-k:] == w2_tokens[:k]: + tokens_intersection = k + break + assert tokens_intersection is not None, ( + f"{window1.doc_id} - {window1.sent_id} - {window1.offset}" + + f" {window2.doc_id} - {window2.sent_id} - {window2.offset}\n" + + f"w1 tokens: {w1_tokens}\n" + + f"w2 tokens: {w2_tokens}\n" + ) + + final_tokens = ( + [window1.tokens[0]] # CLS + + w1_tokens + + w2_tokens[tokens_intersection:] + + [window1.tokens[-1]] # SEP + ) + + w2_starting_offset = len(w1_tokens) - tokens_intersection + + def merge_char_mapping(t2c1: dict, t2c2: dict) -> dict: + final_t2c = dict() + final_t2c.update(t2c1) + for t, c in t2c2.items(): + t = int(t) + if t < tokens_intersection: + continue + final_t2c[str(t + w2_starting_offset)] = c + return final_t2c + + return ( + final_tokens, + merge_char_mapping(window1.token2char_start, window2.token2char_start), + merge_char_mapping(window1.token2char_end, window2.token2char_end), + ) + + def _merge_span_annotation( + self, span_annotation1: List[list], span_annotation2: List[list] + ) -> List[list]: + uniq_store = set() + final_span_annotation_store = [] + for span_annotation in itertools.chain(span_annotation1, span_annotation2): + span_annotation_id = tuple(span_annotation) + if span_annotation_id not in uniq_store: + uniq_store.add(span_annotation_id) + final_span_annotation_store.append(span_annotation) + return 
sorted(final_span_annotation_store, key=lambda x: x[0]) + + def _merge_predictions( + self, + window1: RelikReaderSample, + window2: RelikReaderSample, + ) -> Tuple[Set[Tuple[int, int, str]], dict]: + merged_predictions = window1.predicted_window_labels_chars.union( + window2.predicted_window_labels_chars + ) + + span_title_probabilities = dict() + # probabilities + for span_prediction, predicted_probs in itertools.chain( + window1.probs_window_labels_chars.items(), + window2.probs_window_labels_chars.items(), + ): + if span_prediction not in span_title_probabilities: + span_title_probabilities[span_prediction] = predicted_probs + + return merged_predictions, span_title_probabilities + + def _merge_window_pair( + self, + window1: RelikReaderSample, + window2: RelikReaderSample, + ) -> RelikReaderSample: + merging_output = dict() + + if getattr(window1, "doc_id", None) is not None: + assert window1.doc_id == window2.doc_id + + if getattr(window1, "offset", None) is not None: + assert ( + window1.offset < window2.offset + ), f"window 2 offset ({window2.offset}) is smaller that window 1 offset({window1.offset})" + + merging_output["doc_id"] = window1.doc_id + merging_output["offset"] = window2.offset + + m_tokens, m_token2char_start, m_token2char_end = self._merge_tokens( + window1, window2 + ) + + window_labels = None + if getattr(window1, "window_labels", None) is not None: + window_labels = self._merge_span_annotation( + window1.window_labels, window2.window_labels + ) + ( + predicted_window_labels_chars, + probs_window_labels_chars, + ) = self._merge_predictions( + window1, + window2, + ) + + merging_output.update( + dict( + tokens=m_tokens, + token2char_start=m_token2char_start, + token2char_end=m_token2char_end, + window_labels=window_labels, + predicted_window_labels_chars=predicted_window_labels_chars, + probs_window_labels_chars=probs_window_labels_chars, + ) + ) + + return RelikReaderSample(**merging_output) diff --git a/relik/inference/gerbil.py b/relik/inference/gerbil.py new file mode 100644 index 0000000000000000000000000000000000000000..d4c3f17cacea1d5472de99d1a974ad098585fc20 --- /dev/null +++ b/relik/inference/gerbil.py @@ -0,0 +1,254 @@ +import argparse +import json +import os +import re +import sys +from http.server import BaseHTTPRequestHandler, HTTPServer +from typing import Iterator, List, Optional, Tuple + +from relik.inference.annotator import Relik +from relik.inference.data.objects import RelikOutput + +# sys.path += ['../'] +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))) + + +import logging + +logger = logging.getLogger(__name__) + + +class GerbilAlbyManager: + def __init__( + self, + annotator: Optional[Relik] = None, + response_logger_dir: Optional[str] = None, + ) -> None: + self.annotator = annotator + self.response_logger_dir = response_logger_dir + self.predictions_counter = 0 + self.labels_mapping = None + + def annotate(self, document: str): + relik_output: RelikOutput = self.annotator(document) + annotations = [(ss, se, l) for ss, se, l, _ in relik_output.labels] + if self.labels_mapping is not None: + return [ + (ss, se, self.labels_mapping.get(l, l)) for ss, se, l in annotations + ] + return annotations + + def set_mapping_file(self, mapping_file_path: str): + with open(mapping_file_path) as f: + labels_mapping = json.load(f) + self.labels_mapping = {v: k for k, v in labels_mapping.items()} + + def write_response_bundle( + self, + document: str, + new_document: str, + annotations: list, + mapped_annotations: list, + ) 
-> None: + if self.response_logger_dir is None: + return + + if not os.path.isdir(self.response_logger_dir): + os.mkdir(self.response_logger_dir) + + with open( + f"{self.response_logger_dir}/{self.predictions_counter}.json", "w" + ) as f: + out_json_obj = dict( + document=document, + new_document=new_document, + annotations=annotations, + mapped_annotations=mapped_annotations, + ) + + out_json_obj["span_annotations"] = [ + (ss, se, document[ss:se], label) for (ss, se, label) in annotations + ] + + out_json_obj["span_mapped_annotations"] = [ + (ss, se, new_document[ss:se], label) + for (ss, se, label) in mapped_annotations + ] + + json.dump(out_json_obj, f, indent=2) + + self.predictions_counter += 1 + + +manager = GerbilAlbyManager() + + +def preprocess_document(document: str) -> Tuple[str, List[Tuple[int, int]]]: + pattern_subs = { + "-LPR- ": " (", + "-RPR-": ")", + "\n\n": "\n", + "-LRB-": "(", + "-RRB-": ")", + '","': ",", + } + + document_acc = document + curr_offset = 0 + char2offset = [] + + matchings = re.finditer("({})".format("|".join(pattern_subs)), document) + for span_matching in sorted(matchings, key=lambda x: x.span()[0]): + span_start, span_end = span_matching.span() + span_start -= curr_offset + span_end -= curr_offset + + span_text = document_acc[span_start:span_end] + span_sub = pattern_subs[span_text] + document_acc = document_acc[:span_start] + span_sub + document_acc[span_end:] + + offset = len(span_text) - len(span_sub) + curr_offset += offset + + char2offset.append((span_start + len(span_sub), curr_offset)) + + return document_acc, char2offset + + +def map_back_annotations( + annotations: List[Tuple[int, int, str]], char_mapping: List[Tuple[int, int]] +) -> Iterator[Tuple[int, int, str]]: + def map_char(char_idx: int) -> int: + current_offset = 0 + for offset_idx, offset_value in char_mapping: + if char_idx >= offset_idx: + current_offset = offset_value + else: + break + return char_idx + current_offset + + for ss, se, label in annotations: + yield map_char(ss), map_char(se), label + + +def annotate(document: str) -> List[Tuple[int, int, str]]: + new_document, mapping = preprocess_document(document) + logger.info("Mapping: " + str(mapping)) + logger.info("Document: " + str(document)) + annotations = [ + (cs, ce, label.replace(" ", "_")) + for cs, ce, label in manager.annotate(new_document) + ] + logger.info("New document: " + str(new_document)) + mapped_annotations = ( + list(map_back_annotations(annotations, mapping)) + if len(mapping) > 0 + else annotations + ) + + logger.info( + "Annotations: " + + str([(ss, se, document[ss:se], ann) for ss, se, ann in mapped_annotations]) + ) + + manager.write_response_bundle( + document, new_document, mapped_annotations, annotations + ) + + if not all( + [ + new_document[ss:se] == document[mss:mse] + for (mss, mse, _), (ss, se, _) in zip(mapped_annotations, annotations) + ] + ): + diff_mappings = [ + (new_document[ss:se], document[mss:mse]) + for (mss, mse, _), (ss, se, _) in zip(mapped_annotations, annotations) + ] + return None + assert all( + [ + document[mss:mse] == new_document[ss:se] + for (mss, mse, _), (ss, se, _) in zip(mapped_annotations, annotations) + ] + ), (mapped_annotations, annotations) + + return [(cs, ce - cs, label) for cs, ce, label in mapped_annotations] + + +class GetHandler(BaseHTTPRequestHandler): + def do_POST(self): + content_length = int(self.headers["Content-Length"]) + post_data = self.rfile.read(content_length) + self.send_response(200) + self.end_headers() + doc_text = read_json(post_data) + # 
try: + response = annotate(doc_text) + + self.wfile.write(bytes(json.dumps(response), "utf-8")) + return + + +def read_json(post_data): + data = json.loads(post_data.decode("utf-8")) + # logger.info("received data:", data) + text = data["text"] + # spans = [(int(j["start"]), int(j["length"])) for j in data["spans"]] + return text + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument("--relik-model-name", required=True) + parser.add_argument("--responses-log-dir") + parser.add_argument("--log-file", default="logs/logging.txt") + parser.add_argument("--mapping-file") + return parser.parse_args() + + +def main(): + args = parse_args() + + # init manager + manager.response_logger_dir = args.responses_log_dir + # manager.annotator = Relik.from_pretrained(args.relik_model_name) + + print("Debugging, not using you relik model but an hardcoded one.") + manager.annotator = Relik( + question_encoder="riccorl/relik-retriever-aida-blink-pretrain-omniencoder", + document_index="riccorl/index-relik-retriever-aida-blink-pretrain-omniencoder", + reader="relik/reader/models/relik-reader-deberta-base-new-data", + window_size=32, + window_stride=16, + candidates_preprocessing_fn=(lambda x: x.split("")[0].strip()), + ) + + if args.mapping_file is not None: + manager.set_mapping_file(args.mapping_file) + + port = 6654 + server = HTTPServer(("localhost", port), GetHandler) + logger.info(f"Starting server at http://localhost:{port}") + + # Create a file handler and set its level + file_handler = logging.FileHandler(args.log_file) + file_handler.setLevel(logging.DEBUG) + + # Create a log formatter and set it on the handler + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + file_handler.setFormatter(formatter) + + # Add the file handler to the logger + logger.addHandler(file_handler) + + try: + server.serve_forever() + except KeyboardInterrupt: + exit(0) + + +if __name__ == "__main__": + main() diff --git a/relik/inference/preprocessing.py b/relik/inference/preprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..2476fe47ea64d907a8c32c31082253c45b48720c --- /dev/null +++ b/relik/inference/preprocessing.py @@ -0,0 +1,4 @@ +def wikipedia_title_and_openings_preprocessing( + wikipedia_title_and_openings: str, sepator: str = " " +): + return wikipedia_title_and_openings.split(sepator, 1)[0] diff --git a/relik/inference/serve/__init__.py b/relik/inference/serve/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/relik/inference/serve/backend/__init__.py b/relik/inference/serve/backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/relik/inference/serve/backend/relik.py b/relik/inference/serve/backend/relik.py new file mode 100644 index 0000000000000000000000000000000000000000..038e2ef78afbccb0758162996e35cd8dc858d453 --- /dev/null +++ b/relik/inference/serve/backend/relik.py @@ -0,0 +1,210 @@ +import logging +from pathlib import Path +from typing import List, Optional, Union + +from relik.common.utils import is_package_available +from relik.inference.annotator import Relik + +if not is_package_available("fastapi"): + raise ImportError( + "FastAPI is not installed. Please install FastAPI with `pip install relik[serve]`." 
+ ) +from fastapi import FastAPI, HTTPException + +if not is_package_available("ray"): + raise ImportError( + "Ray is not installed. Please install Ray with `pip install relik[serve]`." + ) +from ray import serve + +from relik.common.log import get_logger +from relik.inference.serve.backend.utils import ( + RayParameterManager, + ServerParameterManager, +) +from relik.retriever.data.utils import batch_generator + +logger = get_logger(__name__, level=logging.INFO) + +VERSION = {} # type: ignore +with open( + Path(__file__).parent.parent.parent.parent / "version.py", "r" +) as version_file: + exec(version_file.read(), VERSION) + +# Env variables for server +SERVER_MANAGER = ServerParameterManager() +RAY_MANAGER = RayParameterManager() + +app = FastAPI( + title="ReLiK", + version=VERSION["VERSION"], + description="ReLiK REST API", +) + + +@serve.deployment( + ray_actor_options={ + "num_gpus": RAY_MANAGER.num_gpus + if ( + SERVER_MANAGER.retriver_device == "cuda" + or SERVER_MANAGER.reader_device == "cuda" + ) + else 0 + }, + autoscaling_config={ + "min_replicas": RAY_MANAGER.min_replicas, + "max_replicas": RAY_MANAGER.max_replicas, + }, +) +@serve.ingress(app) +class RelikServer: + def __init__( + self, + question_encoder: str, + document_index: str, + passage_encoder: Optional[str] = None, + reader_encoder: Optional[str] = None, + top_k: int = 100, + retriver_device: str = "cpu", + reader_device: str = "cpu", + index_device: Optional[str] = None, + precision: int = 32, + index_precision: Optional[int] = None, + use_faiss: bool = False, + window_batch_size: int = 32, + window_size: int = 32, + window_stride: int = 16, + split_on_spaces: bool = False, + ): + # parameters + self.question_encoder = question_encoder + self.passage_encoder = passage_encoder + self.reader_encoder = reader_encoder + self.document_index = document_index + self.top_k = top_k + self.retriver_device = retriver_device + self.index_device = index_device or retriver_device + self.reader_device = reader_device + self.precision = precision + self.index_precision = index_precision or precision + self.use_faiss = use_faiss + self.window_batch_size = window_batch_size + self.window_size = window_size + self.window_stride = window_stride + self.split_on_spaces = split_on_spaces + + # log stuff for debugging + logger.info("Initializing RelikServer with parameters:") + logger.info(f"QUESTION_ENCODER: {self.question_encoder}") + logger.info(f"PASSAGE_ENCODER: {self.passage_encoder}") + logger.info(f"READER_ENCODER: {self.reader_encoder}") + logger.info(f"DOCUMENT_INDEX: {self.document_index}") + logger.info(f"TOP_K: {self.top_k}") + logger.info(f"RETRIEVER_DEVICE: {self.retriver_device}") + logger.info(f"READER_DEVICE: {self.reader_device}") + logger.info(f"INDEX_DEVICE: {self.index_device}") + logger.info(f"PRECISION: {self.precision}") + logger.info(f"INDEX_PRECISION: {self.index_precision}") + logger.info(f"WINDOW_BATCH_SIZE: {self.window_batch_size}") + logger.info(f"SPLIT_ON_SPACES: {self.split_on_spaces}") + + self.relik = Relik( + question_encoder=self.question_encoder, + passage_encoder=self.passage_encoder, + document_index=self.document_index, + reader=self.reader_encoder, + retriever_device=self.retriver_device, + document_index_device=self.index_device, + reader_device=self.reader_device, + retriever_precision=self.precision, + document_index_precision=self.index_precision, + reader_precision=self.precision, + ) + + # @serve.batch() + async def handle_batch(self, documents: List[str]) -> List: + return self.relik( + 
documents, + top_k=self.top_k, + window_size=self.window_size, + window_stride=self.window_stride, + batch_size=self.window_batch_size, + ) + + @app.post("/api/entities") + async def entities_endpoint( + self, + documents: Union[str, List[str]], + ): + try: + # normalize input + if isinstance(documents, str): + documents = [documents] + if document_topics is not None: + if isinstance(document_topics, str): + document_topics = [document_topics] + assert len(documents) == len(document_topics) + # get predictions for the retriever + return await self.handle_batch(documents, document_topics) + except Exception as e: + # log the entire stack trace + logger.exception(e) + raise HTTPException(status_code=500, detail=f"Server Error: {e}") + + @app.post("/api/gerbil") + async def gerbil_endpoint(self, documents: Union[str, List[str]]): + try: + # normalize input + if isinstance(documents, str): + documents = [documents] + + # output list + windows_passages = [] + # split documents into windows + document_windows = [ + window + for doc_id, document in enumerate(documents) + for window in self.window_manager( + self.tokenizer, + document, + window_size=self.window_size, + stride=self.window_stride, + doc_id=doc_id, + ) + ] + + # get text and topic from document windows and create new list + model_inputs = [ + (window.text, window.doc_topic) for window in document_windows + ] + + # batch generator + for batch in batch_generator( + model_inputs, batch_size=self.window_batch_size + ): + text, text_pair = zip(*batch) + batch_predictions = await self.handle_batch_retriever(text, text_pair) + windows_passages.extend( + [ + [p.label for p in predictions] + for predictions in batch_predictions + ] + ) + + # add passage to document windows + for window, passages in zip(document_windows, windows_passages): + # clean up passages (remove everything after first tag if present) + passages = [c.split(" ", 1)[0] for c in passages] + window.window_candidates = passages + + # return document windows + return document_windows + + except Exception as e: + # log the entire stack trace + logger.exception(e) + raise HTTPException(status_code=500, detail=f"Server Error: {e}") + + +server = RelikServer.bind(**vars(SERVER_MANAGER)) diff --git a/relik/inference/serve/backend/retriever.py b/relik/inference/serve/backend/retriever.py new file mode 100644 index 0000000000000000000000000000000000000000..e796893e76b83377a5f8b2c7afdccce21756dcbd --- /dev/null +++ b/relik/inference/serve/backend/retriever.py @@ -0,0 +1,206 @@ +import logging +from pathlib import Path +from typing import List, Optional, Union + +from relik.common.utils import is_package_available + +if not is_package_available("fastapi"): + raise ImportError( + "FastAPI is not installed. Please install FastAPI with `pip install relik[serve]`." + ) +from fastapi import FastAPI, HTTPException + +if not is_package_available("ray"): + raise ImportError( + "Ray is not installed. Please install Ray with `pip install relik[serve]`." 
+ ) +from ray import serve + +from relik.common.log import get_logger +from relik.inference.data.tokenizers import SpacyTokenizer, WhitespaceTokenizer +from relik.inference.data.window.manager import WindowManager +from relik.inference.serve.backend.utils import ( + RayParameterManager, + ServerParameterManager, +) +from relik.retriever.data.utils import batch_generator +from relik.retriever.pytorch_modules import GoldenRetriever + +logger = get_logger(__name__, level=logging.INFO) + +VERSION = {} # type: ignore +with open(Path(__file__).parent.parent.parent / "version.py", "r") as version_file: + exec(version_file.read(), VERSION) + +# Env variables for server +SERVER_MANAGER = ServerParameterManager() +RAY_MANAGER = RayParameterManager() + +app = FastAPI( + title="Golden Retriever", + version=VERSION["VERSION"], + description="Golden Retriever REST API", +) + + +@serve.deployment( + ray_actor_options={ + "num_gpus": RAY_MANAGER.num_gpus if SERVER_MANAGER.device == "cuda" else 0 + }, + autoscaling_config={ + "min_replicas": RAY_MANAGER.min_replicas, + "max_replicas": RAY_MANAGER.max_replicas, + }, +) +@serve.ingress(app) +class GoldenRetrieverServer: + def __init__( + self, + question_encoder: str, + document_index: str, + passage_encoder: Optional[str] = None, + top_k: int = 100, + device: str = "cpu", + index_device: Optional[str] = None, + precision: int = 32, + index_precision: Optional[int] = None, + use_faiss: bool = False, + window_batch_size: int = 32, + window_size: int = 32, + window_stride: int = 16, + split_on_spaces: bool = False, + ): + # parameters + self.question_encoder = question_encoder + self.passage_encoder = passage_encoder + self.document_index = document_index + self.top_k = top_k + self.device = device + self.index_device = index_device or device + self.precision = precision + self.index_precision = index_precision or precision + self.use_faiss = use_faiss + self.window_batch_size = window_batch_size + self.window_size = window_size + self.window_stride = window_stride + self.split_on_spaces = split_on_spaces + + # log stuff for debugging + logger.info("Initializing GoldenRetrieverServer with parameters:") + logger.info(f"QUESTION_ENCODER: {self.question_encoder}") + logger.info(f"PASSAGE_ENCODER: {self.passage_encoder}") + logger.info(f"DOCUMENT_INDEX: {self.document_index}") + logger.info(f"TOP_K: {self.top_k}") + logger.info(f"DEVICE: {self.device}") + logger.info(f"INDEX_DEVICE: {self.index_device}") + logger.info(f"PRECISION: {self.precision}") + logger.info(f"INDEX_PRECISION: {self.index_precision}") + logger.info(f"WINDOW_BATCH_SIZE: {self.window_batch_size}") + logger.info(f"SPLIT_ON_SPACES: {self.split_on_spaces}") + + self.retriever = GoldenRetriever( + question_encoder=self.question_encoder, + passage_encoder=self.passage_encoder, + document_index=self.document_index, + device=self.device, + index_device=self.index_device, + index_precision=self.index_precision, + ) + self.retriever.eval() + + if self.split_on_spaces: + logger.info("Using WhitespaceTokenizer") + self.tokenizer = WhitespaceTokenizer() + # logger.info("Using RegexTokenizer") + # self.tokenizer = RegexTokenizer() + else: + logger.info("Using SpacyTokenizer") + self.tokenizer = SpacyTokenizer(language="en") + + self.window_manager = WindowManager(tokenizer=self.tokenizer) + + # @serve.batch() + async def handle_batch( + self, documents: List[str], document_topics: List[str] + ) -> List: + return self.retriever.retrieve( + documents, text_pair=document_topics, k=self.top_k, 
precision=self.precision + ) + + @app.post("/api/retrieve") + async def retrieve_endpoint( + self, + documents: Union[str, List[str]], + document_topics: Optional[Union[str, List[str]]] = None, + ): + try: + # normalize input + if isinstance(documents, str): + documents = [documents] + if document_topics is not None: + if isinstance(document_topics, str): + document_topics = [document_topics] + assert len(documents) == len(document_topics) + # get predictions + return await self.handle_batch(documents, document_topics) + except Exception as e: + # log the entire stack trace + logger.exception(e) + raise HTTPException(status_code=500, detail=f"Server Error: {e}") + + @app.post("/api/gerbil") + async def gerbil_endpoint(self, documents: Union[str, List[str]]): + try: + # normalize input + if isinstance(documents, str): + documents = [documents] + + # output list + windows_passages = [] + # split documents into windows + document_windows = [ + window + for doc_id, document in enumerate(documents) + for window in self.window_manager( + self.tokenizer, + document, + window_size=self.window_size, + stride=self.window_stride, + doc_id=doc_id, + ) + ] + + # get text and topic from document windows and create new list + model_inputs = [ + (window.text, window.doc_topic) for window in document_windows + ] + + # batch generator + for batch in batch_generator( + model_inputs, batch_size=self.window_batch_size + ): + text, text_pair = zip(*batch) + batch_predictions = await self.handle_batch(text, text_pair) + windows_passages.extend( + [ + [p.label for p in predictions] + for predictions in batch_predictions + ] + ) + + # add passage to document windows + for window, passages in zip(document_windows, windows_passages): + # clean up passages (remove everything after first tag if present) + passages = [c.split(" ", 1)[0] for c in passages] + window.window_candidates = passages + + # return document windows + return document_windows + + except Exception as e: + # log the entire stack trace + logger.exception(e) + raise HTTPException(status_code=500, detail=f"Server Error: {e}") + + +server = GoldenRetrieverServer.bind(**vars(SERVER_MANAGER)) diff --git a/relik/inference/serve/backend/utils.py b/relik/inference/serve/backend/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bdf869c1ece0e260355526ee5fcc2f00da7ef887 --- /dev/null +++ b/relik/inference/serve/backend/utils.py @@ -0,0 +1,29 @@ +import os +from dataclasses import dataclass +from typing import Union + + +@dataclass +class ServerParameterManager: + retriver_device: str = os.environ.get("RETRIEVER_DEVICE", "cpu") + reader_device: str = os.environ.get("READER_DEVICE", "cpu") + index_device: str = os.environ.get("INDEX_DEVICE", retriver_device) + precision: Union[str, int] = os.environ.get("PRECISION", "fp32") + index_precision: Union[str, int] = os.environ.get("INDEX_PRECISION", precision) + question_encoder: str = os.environ.get("QUESTION_ENCODER", None) + passage_encoder: str = os.environ.get("PASSAGE_ENCODER", None) + document_index: str = os.environ.get("DOCUMENT_INDEX", None) + reader_encoder: str = os.environ.get("READER_ENCODER", None) + top_k: int = int(os.environ.get("TOP_K", 100)) + use_faiss: bool = os.environ.get("USE_FAISS", False) + window_batch_size: int = int(os.environ.get("WINDOW_BATCH_SIZE", 32)) + window_size: int = int(os.environ.get("WINDOW_SIZE", 32)) + window_stride: int = int(os.environ.get("WINDOW_SIZE", 16)) + split_on_spaces: bool = os.environ.get("SPLIT_ON_SPACES", False) + + +class 
RayParameterManager: + def __init__(self) -> None: + self.num_gpus = int(os.environ.get("NUM_GPUS", 1)) + self.min_replicas = int(os.environ.get("MIN_REPLICAS", 1)) + self.max_replicas = int(os.environ.get("MAX_REPLICAS", 1)) diff --git a/relik/inference/serve/frontend/__init__.py b/relik/inference/serve/frontend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/relik/inference/serve/frontend/relik.py b/relik/inference/serve/frontend/relik.py new file mode 100644 index 0000000000000000000000000000000000000000..5dd8bb4eb2d8c6b056c61bda359959010e688635 --- /dev/null +++ b/relik/inference/serve/frontend/relik.py @@ -0,0 +1,231 @@ +import os +import re +import time +from pathlib import Path + +import requests +import streamlit as st +from spacy import displacy +from streamlit_extras.badges import badge +from streamlit_extras.stylable_container import stylable_container + +RELIK = os.getenv("RELIK", "localhost:8000/api/entities") + +import random + + +def get_random_color(ents): + colors = {} + random_colors = generate_pastel_colors(len(ents)) + for ent in ents: + colors[ent] = random_colors.pop(random.randint(0, len(random_colors) - 1)) + return colors + + +def floatrange(start, stop, steps): + if int(steps) == 1: + return [stop] + return [ + start + float(i) * (stop - start) / (float(steps) - 1) for i in range(steps) + ] + + +def hsl_to_rgb(h, s, l): + def hue_2_rgb(v1, v2, v_h): + while v_h < 0.0: + v_h += 1.0 + while v_h > 1.0: + v_h -= 1.0 + if 6 * v_h < 1.0: + return v1 + (v2 - v1) * 6.0 * v_h + if 2 * v_h < 1.0: + return v2 + if 3 * v_h < 2.0: + return v1 + (v2 - v1) * ((2.0 / 3.0) - v_h) * 6.0 + return v1 + + # if not (0 <= s <= 1): raise ValueError, "s (saturation) parameter must be between 0 and 1." + # if not (0 <= l <= 1): raise ValueError, "l (lightness) parameter must be between 0 and 1." + + r, b, g = (l * 255,) * 3 + if s != 0.0: + if l < 0.5: + var_2 = l * (1.0 + s) + else: + var_2 = (l + s) - (s * l) + var_1 = 2.0 * l - var_2 + r = 255 * hue_2_rgb(var_1, var_2, h + (1.0 / 3.0)) + g = 255 * hue_2_rgb(var_1, var_2, h) + b = 255 * hue_2_rgb(var_1, var_2, h - (1.0 / 3.0)) + + return int(round(r)), int(round(g)), int(round(b)) + + +def generate_pastel_colors(n): + """Return different pastel colours. 
+ + Input: + n (integer) : The number of colors to return + + Output: + A list of colors in HTML notation (eg.['#cce0ff', '#ffcccc', '#ccffe0', '#f5ccff', '#f5ffcc']) + + Example: + >>> print generate_pastel_colors(5) + ['#cce0ff', '#f5ccff', '#ffcccc', '#f5ffcc', '#ccffe0'] + """ + if n == 0: + return [] + + # To generate colors, we use the HSL colorspace (see http://en.wikipedia.org/wiki/HSL_color_space) + start_hue = 0.6 # 0=red 1/3=0.333=green 2/3=0.666=blue + saturation = 1.0 + lightness = 0.8 + # We take points around the chromatic circle (hue): + # (Note: we generate n+1 colors, then drop the last one ([:-1]) because + # it equals the first one (hue 0 = hue 1)) + return [ + "#%02x%02x%02x" % hsl_to_rgb(hue, saturation, lightness) + for hue in floatrange(start_hue, start_hue + 1, n + 1) + ][:-1] + + +def set_sidebar(css): + white_link_wrapper = "{}" + with st.sidebar: + st.markdown(f"", unsafe_allow_html=True) + st.image( + "http://nlp.uniroma1.it/static/website/sapienza-nlp-logo-wh.svg", + use_column_width=True, + ) + st.markdown("## ReLiK") + st.write( + f""" + - {white_link_wrapper.format("#", "  Paper")} + - {white_link_wrapper.format("https://github.com/SapienzaNLP/relik", "  GitHub")} + - {white_link_wrapper.format("https://hub.docker.com/repository/docker/sapienzanlp/relik", "  Docker Hub")} + """, + unsafe_allow_html=True, + ) + st.markdown("## Sapienza NLP") + st.write( + f""" + - {white_link_wrapper.format("https://nlp.uniroma1.it", "  Webpage")} + - {white_link_wrapper.format("https://github.com/SapienzaNLP", "  GitHub")} + - {white_link_wrapper.format("https://twitter.com/SapienzaNLP", "  Twitter")} + - {white_link_wrapper.format("https://www.linkedin.com/company/79434450", "  LinkedIn")} + """, + unsafe_allow_html=True, + ) + + +def get_el_annotations(response): + # swap labels key with ents + response["ents"] = response.pop("labels") + label_in_text = set(l["label"] for l in response["ents"]) + options = {"ents": label_in_text, "colors": get_random_color(label_in_text)} + return response, options + + +def set_intro(css): + # intro + st.markdown("# ReLik") + st.markdown( + "### Retrieve, Read and LinK: Fast and Accurate Entity Linking and Relation Extraction on an Academic Budget" + ) + # st.markdown( + # "This is a front-end for the paper [Universal Semantic Annotator: the First Unified API " + # "for WSD, SRL and Semantic Parsing](https://www.researchgate.net/publication/360671045_Universal_Semantic_Annotator_the_First_Unified_API_for_WSD_SRL_and_Semantic_Parsing), which will be presented at LREC 2022 by " + # "[Riccardo Orlando](https://riccorl.github.io), [Simone Conia](https://c-simone.github.io/), " + # "[Stefano Faralli](https://corsidilaurea.uniroma1.it/it/users/stefanofaralliuniroma1it), and [Roberto Navigli](https://www.diag.uniroma1.it/navigli/)." 
+ # ) + badge(type="github", name="sapienzanlp/relik") + badge(type="pypi", name="relik") + + +def run_client(): + with open(Path(__file__).parent / "style.css") as f: + css = f.read() + + st.set_page_config( + page_title="ReLik", + page_icon="🦮", + layout="wide", + ) + set_sidebar(css) + set_intro(css) + + # text input + text = st.text_area( + "Enter Text Below:", + value="Obama went to Rome for a quick vacation.", + height=200, + max_chars=500, + ) + + with stylable_container( + key="annotate_button", + css_styles=""" + button { + background-color: #802433; + color: white; + border-radius: 25px; + } + """, + ): + submit = st.button("Annotate") + # submit = st.button("Run") + + # ReLik API call + if submit: + text = text.strip() + if text: + st.markdown("####") + st.markdown("#### Entity Linking") + with st.spinner(text="In progress"): + response = requests.post(RELIK, json=text) + if response.status_code != 200: + st.error("Error: {}".format(response.status_code)) + else: + response = response.json() + + # Entity Linking + # with stylable_container( + # key="container_with_border", + # css_styles=""" + # { + # border: 1px solid rgba(49, 51, 63, 0.2); + # border-radius: 0.5rem; + # padding: 0.5rem; + # padding-bottom: 2rem; + # } + # """, + # ): + # st.markdown("##") + dict_of_ents, options = get_el_annotations(response=response) + display = displacy.render( + dict_of_ents, manual=True, style="ent", options=options + ) + display = display.replace("\n", " ") + # wsd_display = re.sub( + # r"(wiki::\d+\w)", + # r"\g<1>".format( + # language.upper() + # ), + # wsd_display, + # ) + with st.container(): + st.write(display, unsafe_allow_html=True) + + st.markdown("####") + st.markdown("#### Relation Extraction") + + with st.container(): + st.write("Coming :)", unsafe_allow_html=True) + + else: + st.error("Please enter some text.") + + +if __name__ == "__main__": + run_client() diff --git a/relik/inference/serve/frontend/style.css b/relik/inference/serve/frontend/style.css new file mode 100644 index 0000000000000000000000000000000000000000..31f0d182cfd9b2636d5db5cbd0e7a1339ed5d1c3 --- /dev/null +++ b/relik/inference/serve/frontend/style.css @@ -0,0 +1,33 @@ +/* Sidebar */ +.eczjsme11 { + background-color: #802433; +} + +.st-emotion-cache-10oheav h2 { + color: white; +} + +.st-emotion-cache-10oheav li { + color: white; +} + +/* Main */ +a:link { + text-decoration: none; + color: white; +} + +a:visited { + text-decoration: none; + color: white; +} + +a:hover { + text-decoration: none; + color: rgba(255, 255, 255, 0.871); +} + +a:active { + text-decoration: none; + color: white; +} \ No newline at end of file diff --git a/relik/reader/__init__.py b/relik/reader/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/relik/reader/conf/config.yaml b/relik/reader/conf/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..05b4e524d060aed56b930d1f578424b986792975 --- /dev/null +++ b/relik/reader/conf/config.yaml @@ -0,0 +1,14 @@ +# Required to make the "experiments" dir the default one for the output of the models +hydra: + run: + dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S} + +model_name: relik-reader-deberta-base # used to name the model in wandb and output dir +project_name: relik-reader # used to name the project in wandb + + +defaults: + - _self_ + - training: base + - model: base + - data: base diff --git a/relik/reader/conf/data/base.yaml b/relik/reader/conf/data/base.yaml 
new file mode 100644 index 0000000000000000000000000000000000000000..1964d8750bed1a162a9ec15d20be1708ccad9914 --- /dev/null +++ b/relik/reader/conf/data/base.yaml @@ -0,0 +1,21 @@ +train_dataset_path: "relik/reader/data/train.jsonl" +val_dataset_path: "relik/reader/data/testa.jsonl" + +train_dataset: + _target_: "relik.reader.relik_reader_data.RelikDataset" + transformer_model: "${model.model.transformer_model}" + materialize_samples: False + shuffle_candidates: 0.5 + random_drop_gold_candidates: 0.05 + noise_param: 0.0 + for_inference: False + tokens_per_batch: 4096 + special_symbols: null + +val_dataset: + _target_: "relik.reader.relik_reader_data.RelikDataset" + transformer_model: "${model.model.transformer_model}" + materialize_samples: False + shuffle_candidates: False + for_inference: True + special_symbols: null diff --git a/relik/reader/conf/data/re.yaml b/relik/reader/conf/data/re.yaml new file mode 100644 index 0000000000000000000000000000000000000000..17c18ee886021bc0157edb156020409fdd799fbc --- /dev/null +++ b/relik/reader/conf/data/re.yaml @@ -0,0 +1,54 @@ +train_dataset_path: "relik/reader/data/nyt-alby+/train.jsonl" +val_dataset_path: "relik/reader/data/nyt-alby+/valid.jsonl" +test_dataset_path: "relik/reader/data/nyt-alby+/test.jsonl" + +relations_definitions: + /people/person/nationality: "nationality" + /sports/sports_team/location: "sports team location" + /location/country/administrative_divisions: "administrative divisions" + /business/company/major_shareholders: "shareholders" + /people/ethnicity/people: "ethnicity" + /people/ethnicity/geographic_distribution: "geographic distributi6on" + /business/company_shareholder/major_shareholder_of: "major shareholder" + /location/location/contains: "location" + /business/company/founders: "founders" + /business/person/company: "company" + /business/company/advisors: "advisor" + /people/deceased_person/place_of_death: "place of death" + /business/company/industry: "industry" + /people/person/ethnicity: "ethnic background" + /people/person/place_of_birth: "place of birth" + /location/administrative_division/country: "country of an administration division" + /people/person/place_lived: "place lived" + /sports/sports_team_location/teams: "sports team" + /people/person/children: "child" + /people/person/religion: "religion" + /location/neighborhood/neighborhood_of: "neighborhood" + /location/country/capital: "capital" + /business/company/place_founded: "company founded location" + /people/person/profession: "occupation" + +train_dataset: + _target_: "relik.reader.relik_reader_re_data.RelikREDataset" + transformer_model: "${model.model.transformer_model}" + materialize_samples: False + shuffle_candidates: False + flip_candidates: 1.0 + noise_param: 0.0 + for_inference: False + tokens_per_batch: 4096 + min_length: -1 + special_symbols: null + relations_definitions: ${data.relations_definitions} + sorting_fields: + - "predictable_candidates" +val_dataset: + _target_: "relik.reader.relik_reader_re_data.RelikREDataset" + transformer_model: "${model.model.transformer_model}" + materialize_samples: False + shuffle_candidates: False + flip_candidates: False + for_inference: True + min_length: -1 + special_symbols: null + relations_definitions: ${data.relations_definitions} diff --git a/relik/reader/conf/training/base.yaml b/relik/reader/conf/training/base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8e366a96408bd0f8ff5184849e53d19bb477af38 --- /dev/null +++ b/relik/reader/conf/training/base.yaml @@ -0,0 +1,12 
@@ +seed: 94 + +trainer: + _target_: lightning.Trainer + devices: + - 0 + precision: "16-mixed" + max_steps: 50000 + val_check_interval: 1.0 + num_sanity_val_steps: 0 + limit_val_batches: 1 + gradient_clip_val: 1.0 diff --git a/relik/reader/conf/training/re.yaml b/relik/reader/conf/training/re.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8701ae3fca48830649022644a743783a1016bd5b --- /dev/null +++ b/relik/reader/conf/training/re.yaml @@ -0,0 +1,12 @@ +seed: 15 + +trainer: + _target_: lightning.Trainer + devices: + - 0 + precision: "16-mixed" + max_steps: 100000 + val_check_interval: 1.0 + num_sanity_val_steps: 0 + limit_val_batches: 1 + gradient_clip_val: 1.0 \ No newline at end of file diff --git a/relik/reader/data/__init__.py b/relik/reader/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/relik/reader/data/patches.py b/relik/reader/data/patches.py new file mode 100644 index 0000000000000000000000000000000000000000..b0d03dbdf08d0e205787ce2b8176c6bd47d2dfca --- /dev/null +++ b/relik/reader/data/patches.py @@ -0,0 +1,51 @@ +from typing import List + +from relik.reader.data.relik_reader_sample import RelikReaderSample +from relik.reader.utils.special_symbols import NME_SYMBOL + + +def merge_patches_predictions(sample) -> None: + sample._d["predicted_window_labels"] = dict() + predicted_window_labels = sample._d["predicted_window_labels"] + + sample._d["span_title_probabilities"] = dict() + span_title_probabilities = sample._d["span_title_probabilities"] + + span2title = dict() + for _, patch_info in sorted(sample.patches.items(), key=lambda x: x[0]): + # selecting span predictions + for predicted_title, predicted_spans in patch_info[ + "predicted_window_labels" + ].items(): + for pred_span in predicted_spans: + pred_span = tuple(pred_span) + curr_title = span2title.get(pred_span) + if curr_title is None or curr_title == NME_SYMBOL: + span2title[pred_span] = predicted_title + # else: + # print("Merging at patch level") + + # selecting span predictions probability + for predicted_span, titles_probabilities in patch_info[ + "span_title_probabilities" + ].items(): + if predicted_span not in span_title_probabilities: + span_title_probabilities[predicted_span] = titles_probabilities + + for span, title in span2title.items(): + if title not in predicted_window_labels: + predicted_window_labels[title] = list() + predicted_window_labels[title].append(span) + + +def remove_duplicate_samples( + samples: List[RelikReaderSample], +) -> List[RelikReaderSample]: + seen_sample = set() + samples_store = [] + for sample in samples: + sample_id = f"{sample.doc_id}#{sample.sent_id}#{sample.offset}" + if sample_id not in seen_sample: + seen_sample.add(sample_id) + samples_store.append(sample) + return samples_store diff --git a/relik/reader/data/relik_reader_data.py b/relik/reader/data/relik_reader_data.py new file mode 100644 index 0000000000000000000000000000000000000000..3c65646f99d37cdcf03ab7005c83eb0069da168c --- /dev/null +++ b/relik/reader/data/relik_reader_data.py @@ -0,0 +1,965 @@ +import logging +from typing import ( + Any, + Callable, + Dict, + Generator, + Iterable, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union, +) + +import numpy as np +import torch +from torch.utils.data import IterableDataset +from tqdm import tqdm +from transformers import AutoTokenizer, PreTrainedTokenizer + +from relik.reader.data.relik_reader_data_utils import ( + add_noise_to_value, + 
batchify, + chunks, + flatten, +) +from relik.reader.data.relik_reader_sample import ( + RelikReaderSample, + load_relik_reader_samples, +) +from relik.reader.utils.special_symbols import NME_SYMBOL + +logger = logging.getLogger(__name__) + + +def preprocess_dataset( + input_dataset: Iterable[dict], + transformer_model: str, + add_topic: bool, +) -> Iterable[dict]: + tokenizer = AutoTokenizer.from_pretrained(transformer_model) + for dataset_elem in tqdm(input_dataset, desc="Preprocessing input dataset"): + if len(dataset_elem["tokens"]) == 0: + print( + f"Dataset element with doc id: {dataset_elem['doc_id']}", + f"and offset {dataset_elem['offset']} does not contain any token", + "Skipping it", + ) + continue + + new_dataset_elem = dict( + doc_id=dataset_elem["doc_id"], + offset=dataset_elem["offset"], + ) + + tokenization_out = tokenizer( + dataset_elem["tokens"], + return_offsets_mapping=True, + add_special_tokens=False, + ) + + window_tokens = tokenization_out.input_ids + window_tokens = flatten(window_tokens) + + offsets_mapping = [ + [ + ( + ss + dataset_elem["token2char_start"][str(i)], + se + dataset_elem["token2char_start"][str(i)], + ) + for ss, se in tokenization_out.offset_mapping[i] + ] + for i in range(len(dataset_elem["tokens"])) + ] + + offsets_mapping = flatten(offsets_mapping) + + assert len(offsets_mapping) == len(window_tokens) + + window_tokens = ( + [tokenizer.cls_token_id] + window_tokens + [tokenizer.sep_token_id] + ) + + topic_offset = 0 + if add_topic: + topic_tokens = tokenizer( + dataset_elem["doc_topic"], add_special_tokens=False + ).input_ids + topic_offset = len(topic_tokens) + new_dataset_elem["topic_tokens"] = topic_offset + window_tokens = window_tokens[:1] + topic_tokens + window_tokens[1:] + + new_dataset_elem.update( + dict( + tokens=window_tokens, + token2char_start={ + str(i): s + for i, (s, _) in enumerate(offsets_mapping, start=topic_offset) + }, + token2char_end={ + str(i): e + for i, (_, e) in enumerate(offsets_mapping, start=topic_offset) + }, + window_candidates=dataset_elem["window_candidates"], + window_candidates_scores=dataset_elem.get("window_candidates_scores"), + ) + ) + + if "window_labels" in dataset_elem: + window_labels = [ + (s, e, l.replace("_", " ")) for s, e, l in dataset_elem["window_labels"] + ] + + new_dataset_elem["window_labels"] = window_labels + + if not all( + [ + s in new_dataset_elem["token2char_start"].values() + for s, _, _ in new_dataset_elem["window_labels"] + ] + ): + print( + "Mismatching token start char mapping with labels", + new_dataset_elem["token2char_start"], + new_dataset_elem["window_labels"], + dataset_elem["tokens"], + ) + continue + + if not all( + [ + e in new_dataset_elem["token2char_end"].values() + for _, e, _ in new_dataset_elem["window_labels"] + ] + ): + print( + "Mismatching token end char mapping with labels", + new_dataset_elem["token2char_end"], + new_dataset_elem["window_labels"], + dataset_elem["tokens"], + ) + continue + + yield new_dataset_elem + + +def preprocess_sample( + relik_sample: RelikReaderSample, + tokenizer, + lowercase_policy: float, + add_topic: bool = False, +) -> None: + if len(relik_sample.tokens) == 0: + return + + if lowercase_policy > 0: + lc_tokens = np.random.uniform(0, 1, len(relik_sample.tokens)) < lowercase_policy + relik_sample.tokens = [ + t.lower() if lc else t for t, lc in zip(relik_sample.tokens, lc_tokens) + ] + + tokenization_out = tokenizer( + relik_sample.tokens, + return_offsets_mapping=True, + add_special_tokens=False, + ) + + window_tokens = 
tokenization_out.input_ids + window_tokens = flatten(window_tokens) + + offsets_mapping = [ + [ + ( + ss + relik_sample.token2char_start[str(i)], + se + relik_sample.token2char_start[str(i)], + ) + for ss, se in tokenization_out.offset_mapping[i] + ] + for i in range(len(relik_sample.tokens)) + ] + + offsets_mapping = flatten(offsets_mapping) + + assert len(offsets_mapping) == len(window_tokens) + + window_tokens = [tokenizer.cls_token_id] + window_tokens + [tokenizer.sep_token_id] + + topic_offset = 0 + if add_topic: + topic_tokens = tokenizer( + relik_sample.doc_topic, add_special_tokens=False + ).input_ids + topic_offset = len(topic_tokens) + relik_sample.topic_tokens = topic_offset + window_tokens = window_tokens[:1] + topic_tokens + window_tokens[1:] + + relik_sample._d.update( + dict( + tokens=window_tokens, + token2char_start={ + str(i): s + for i, (s, _) in enumerate(offsets_mapping, start=topic_offset) + }, + token2char_end={ + str(i): e + for i, (_, e) in enumerate(offsets_mapping, start=topic_offset) + }, + ) + ) + + if "window_labels" in relik_sample._d: + relik_sample.window_labels = [ + (s, e, l.replace("_", " ")) for s, e, l in relik_sample.window_labels + ] + + +class TokenizationOutput(NamedTuple): + input_ids: torch.Tensor + attention_mask: torch.Tensor + token_type_ids: torch.Tensor + prediction_mask: torch.Tensor + special_symbols_mask: torch.Tensor + + +class RelikDataset(IterableDataset): + def __init__( + self, + dataset_path: Optional[str], + materialize_samples: bool, + transformer_model: Union[str, PreTrainedTokenizer], + special_symbols: List[str], + shuffle_candidates: Optional[Union[bool, float]] = False, + for_inference: bool = False, + noise_param: float = 0.1, + sorting_fields: Optional[str] = None, + tokens_per_batch: int = 2048, + batch_size: int = None, + max_batch_size: int = 128, + section_size: int = 50_000, + prebatch: bool = True, + random_drop_gold_candidates: float = 0.0, + use_nme: bool = True, + max_subwords_per_candidate: bool = 22, + mask_by_instances: bool = False, + min_length: int = 5, + max_length: int = 2048, + model_max_length: int = 1000, + split_on_cand_overload: bool = True, + skip_empty_training_samples: bool = False, + drop_last: bool = False, + samples: Optional[Iterator[RelikReaderSample]] = None, + lowercase_policy: float = 0.0, + **kwargs, + ): + super().__init__(**kwargs) + self.dataset_path = dataset_path + self.materialize_samples = materialize_samples + self.samples: Optional[List[RelikReaderSample]] = None + if self.materialize_samples: + self.samples = list() + + if isinstance(transformer_model, str): + self.tokenizer = self._build_tokenizer(transformer_model, special_symbols) + else: + self.tokenizer = transformer_model + self.special_symbols = special_symbols + self.shuffle_candidates = shuffle_candidates + self.for_inference = for_inference + self.noise_param = noise_param + self.batching_fields = ["input_ids"] + self.sorting_fields = ( + sorting_fields if sorting_fields is not None else self.batching_fields + ) + + self.tokens_per_batch = tokens_per_batch + self.batch_size = batch_size + self.max_batch_size = max_batch_size + self.section_size = section_size + self.prebatch = prebatch + + self.random_drop_gold_candidates = random_drop_gold_candidates + self.use_nme = use_nme + self.max_subwords_per_candidate = max_subwords_per_candidate + self.mask_by_instances = mask_by_instances + self.min_length = min_length + self.max_length = max_length + self.model_max_length = ( + model_max_length + if model_max_length < 
self.tokenizer.model_max_length + else self.tokenizer.model_max_length + ) + + # retrocompatibility workaround + self.transformer_model = ( + transformer_model + if isinstance(transformer_model, str) + else transformer_model.name_or_path + ) + self.split_on_cand_overload = split_on_cand_overload + self.skip_empty_training_samples = skip_empty_training_samples + self.drop_last = drop_last + self.lowercase_policy = lowercase_policy + self.samples = samples + + def _build_tokenizer(self, transformer_model: str, special_symbols: List[str]): + return AutoTokenizer.from_pretrained( + transformer_model, + additional_special_tokens=[ss for ss in special_symbols], + add_prefix_space=True, + ) + + @property + def fields_batcher(self) -> Dict[str, Union[None, Callable[[list], Any]]]: + fields_batchers = { + "input_ids": lambda x: batchify( + x, padding_value=self.tokenizer.pad_token_id + ), + "attention_mask": lambda x: batchify(x, padding_value=0), + "token_type_ids": lambda x: batchify(x, padding_value=0), + "prediction_mask": lambda x: batchify(x, padding_value=1), + "global_attention": lambda x: batchify(x, padding_value=0), + "token2word": None, + "sample": None, + "special_symbols_mask": lambda x: batchify(x, padding_value=False), + "start_labels": lambda x: batchify(x, padding_value=-100), + "end_labels": lambda x: batchify(x, padding_value=-100), + "predictable_candidates_symbols": None, + "predictable_candidates": None, + "patch_offset": None, + "optimus_labels": None, + } + + if "roberta" in self.transformer_model: + del fields_batchers["token_type_ids"] + + return fields_batchers + + def _build_input_ids( + self, sentence_input_ids: List[int], candidates_input_ids: List[List[int]] + ) -> List[int]: + return ( + [self.tokenizer.cls_token_id] + + sentence_input_ids + + [self.tokenizer.sep_token_id] + + flatten(candidates_input_ids) + + [self.tokenizer.sep_token_id] + ) + + def _get_special_symbols_mask(self, input_ids: torch.Tensor) -> torch.Tensor: + special_symbols_mask = input_ids >= ( + len(self.tokenizer) - len(self.special_symbols) + ) + special_symbols_mask[0] = True + return special_symbols_mask + + def _build_tokenizer_essentials( + self, input_ids, original_sequence, sample + ) -> TokenizationOutput: + input_ids = torch.tensor(input_ids, dtype=torch.long) + attention_mask = torch.ones_like(input_ids) + + total_sequence_len = len(input_ids) + predictable_sentence_len = len(original_sequence) + + # token type ids + token_type_ids = torch.cat( + [ + input_ids.new_zeros( + predictable_sentence_len + 2 + ), # original sentence bpes + CLS and SEP + input_ids.new_ones(total_sequence_len - predictable_sentence_len - 2), + ] + ) + + # prediction mask -> boolean on tokens that are predictable + + prediction_mask = torch.tensor( + [1] + + ([0] * predictable_sentence_len) + + ([1] * (total_sequence_len - predictable_sentence_len - 1)) + ) + + # add topic tokens to the prediction mask so that they cannot be predicted + # or optimized during training + topic_tokens = getattr(sample, "topic_tokens", None) + if topic_tokens is not None: + prediction_mask[1 : 1 + topic_tokens] = 1 + + # If mask by instances is active the prediction mask is applied to everything + # that is not indicated as an instance in the training set. 
+ if self.mask_by_instances: + char_start2token = { + cs: int(tok) for tok, cs in sample.token2char_start.items() + } + char_end2token = {ce: int(tok) for tok, ce in sample.token2char_end.items()} + instances_mask = torch.ones_like(prediction_mask) + for _, span_info in sample.instance_id2span_data.items(): + span_info = span_info[0] + token_start = char_start2token[span_info[0]] + 1 # +1 for the CLS + token_end = char_end2token[span_info[1]] + 1 # +1 for the CLS + instances_mask[token_start : token_end + 1] = 0 + + prediction_mask += instances_mask + prediction_mask[prediction_mask > 1] = 1 + + assert len(prediction_mask) == len(input_ids) + + # special symbols mask + special_symbols_mask = self._get_special_symbols_mask(input_ids) + + return TokenizationOutput( + input_ids, + attention_mask, + token_type_ids, + prediction_mask, + special_symbols_mask, + ) + + def _build_labels( + self, + sample, + tokenization_output: TokenizationOutput, + predictable_candidates: List[str], + ) -> Tuple[torch.Tensor, torch.Tensor]: + start_labels = [0] * len(tokenization_output.input_ids) + end_labels = [0] * len(tokenization_output.input_ids) + + char_start2token = {v: int(k) for k, v in sample.token2char_start.items()} + char_end2token = {v: int(k) for k, v in sample.token2char_end.items()} + for cs, ce, gold_candidate_title in sample.window_labels: + if gold_candidate_title not in predictable_candidates: + if self.use_nme: + gold_candidate_title = NME_SYMBOL + else: + continue + # +1 is to account for the CLS token + start_bpe = char_start2token[cs] + 1 + end_bpe = char_end2token[ce] + 1 + class_index = predictable_candidates.index(gold_candidate_title) + if ( + start_labels[start_bpe] == 0 and end_labels[end_bpe] == 0 + ): # prevent from having entities that ends with the same label + start_labels[start_bpe] = class_index + 1 # +1 for the NONE class + end_labels[end_bpe] = class_index + 1 # +1 for the NONE class + else: + print( + "Found entity with the same last subword, it will not be included." 
+ ) + print( + cs, + ce, + gold_candidate_title, + start_labels, + end_labels, + sample.doc_id, + ) + + ignored_labels_indices = tokenization_output.prediction_mask == 1 + + start_labels = torch.tensor(start_labels, dtype=torch.long) + start_labels[ignored_labels_indices] = -100 + + end_labels = torch.tensor(end_labels, dtype=torch.long) + end_labels[ignored_labels_indices] = -100 + + return start_labels, end_labels + + def produce_sample_bag( + self, sample, predictable_candidates: List[str], candidates_starting_offset: int + ) -> Optional[Tuple[dict, list, int]]: + # input sentence tokenization + input_subwords = sample.tokens[1:-1] # removing special tokens + candidates_symbols = self.special_symbols[candidates_starting_offset:] + + predictable_candidates = list(predictable_candidates) + original_predictable_candidates = list(predictable_candidates) + + # add NME as a possible candidate + if self.use_nme: + predictable_candidates.insert(0, NME_SYMBOL) + + # candidates encoding + candidates_symbols = candidates_symbols[: len(predictable_candidates)] + candidates_encoding_result = self.tokenizer.batch_encode_plus( + [ + "{} {}".format(cs, ct) if ct != NME_SYMBOL else NME_SYMBOL + for cs, ct in zip(candidates_symbols, predictable_candidates) + ], + add_special_tokens=False, + ).input_ids + + if ( + self.max_subwords_per_candidate is not None + and self.max_subwords_per_candidate > 0 + ): + candidates_encoding_result = [ + cer[: self.max_subwords_per_candidate] + for cer in candidates_encoding_result + ] + + # drop candidates if the number of input tokens is too long for the model + if ( + sum(map(len, candidates_encoding_result)) + + len(input_subwords) + + 20 # + 20 special tokens + > self.model_max_length + ): + acceptable_tokens_from_candidates = ( + self.model_max_length - 20 - len(input_subwords) + ) + i = 0 + cum_len = 0 + while ( + cum_len + len(candidates_encoding_result[i]) + < acceptable_tokens_from_candidates + ): + cum_len += len(candidates_encoding_result[i]) + i += 1 + + candidates_encoding_result = candidates_encoding_result[:i] + candidates_symbols = candidates_symbols[:i] + predictable_candidates = predictable_candidates[:i] + + # final input_ids build + input_ids = self._build_input_ids( + sentence_input_ids=input_subwords, + candidates_input_ids=candidates_encoding_result, + ) + + # complete input building (e.g. 
attention / prediction mask) + tokenization_output = self._build_tokenizer_essentials( + input_ids, input_subwords, sample + ) + + output_dict = { + "input_ids": tokenization_output.input_ids, + "attention_mask": tokenization_output.attention_mask, + "token_type_ids": tokenization_output.token_type_ids, + "prediction_mask": tokenization_output.prediction_mask, + "special_symbols_mask": tokenization_output.special_symbols_mask, + "sample": sample, + "predictable_candidates_symbols": candidates_symbols, + "predictable_candidates": predictable_candidates, + } + + # labels creation + if sample.window_labels is not None: + start_labels, end_labels = self._build_labels( + sample, + tokenization_output, + predictable_candidates, + ) + output_dict.update(start_labels=start_labels, end_labels=end_labels) + + if ( + "roberta" in self.transformer_model + or "longformer" in self.transformer_model + ): + del output_dict["token_type_ids"] + + predictable_candidates_set = set(predictable_candidates) + remaining_candidates = [ + candidate + for candidate in original_predictable_candidates + if candidate not in predictable_candidates_set + ] + total_used_candidates = ( + candidates_starting_offset + + len(predictable_candidates) + - (1 if self.use_nme else 0) + ) + + if self.use_nme: + assert predictable_candidates[0] == NME_SYMBOL + + return output_dict, remaining_candidates, total_used_candidates + + def __iter__(self): + dataset_iterator = self.dataset_iterator_func() + + current_dataset_elements = [] + + i = None + for i, dataset_elem in enumerate(dataset_iterator, start=1): + if ( + self.section_size is not None + and len(current_dataset_elements) == self.section_size + ): + for batch in self.materialize_batches(current_dataset_elements): + yield batch + current_dataset_elements = [] + + current_dataset_elements.append(dataset_elem) + + if i % 50_000 == 0: + logger.info(f"Processed: {i} number of elements") + + if len(current_dataset_elements) != 0: + for batch in self.materialize_batches(current_dataset_elements): + yield batch + + if i is not None: + logger.info(f"Dataset finished: {i} number of elements processed") + else: + logger.warning("Dataset empty") + + def dataset_iterator_func(self): + skipped_instances = 0 + data_samples = ( + load_relik_reader_samples(self.dataset_path) + if self.samples is None + else self.samples + ) + for sample in data_samples: + preprocess_sample( + sample, self.tokenizer, lowercase_policy=self.lowercase_policy + ) + current_patch = 0 + sample_bag, used_candidates = None, None + remaining_candidates = list(sample.window_candidates) + + if not self.for_inference: + # randomly drop gold candidates at training time + if ( + self.random_drop_gold_candidates > 0.0 + and np.random.uniform() < self.random_drop_gold_candidates + and len(set(ct for _, _, ct in sample.window_labels)) > 1 + ): + # selecting candidates to drop + np.random.shuffle(sample.window_labels) + n_dropped_candidates = np.random.randint( + 0, len(sample.window_labels) - 1 + ) + dropped_candidates = [ + label_elem[-1] + for label_elem in sample.window_labels[:n_dropped_candidates] + ] + dropped_candidates = set(dropped_candidates) + + # saving NMEs because they should not be dropped + if NME_SYMBOL in dropped_candidates: + dropped_candidates.remove(NME_SYMBOL) + + # sample update + sample.window_labels = [ + (s, e, _l) + if _l not in dropped_candidates + else (s, e, NME_SYMBOL) + for s, e, _l in sample.window_labels + ] + remaining_candidates = [ + wc + for wc in remaining_candidates + if wc not in 
dropped_candidates + ] + + # shuffle candidates + if ( + isinstance(self.shuffle_candidates, bool) + and self.shuffle_candidates + ) or ( + isinstance(self.shuffle_candidates, float) + and np.random.uniform() < self.shuffle_candidates + ): + np.random.shuffle(remaining_candidates) + + while len(remaining_candidates) != 0: + sample_bag = self.produce_sample_bag( + sample, + predictable_candidates=remaining_candidates, + candidates_starting_offset=used_candidates + if used_candidates is not None + else 0, + ) + if sample_bag is not None: + sample_bag, remaining_candidates, used_candidates = sample_bag + if ( + self.for_inference + or not self.skip_empty_training_samples + or ( + ( + sample_bag.get("start_labels") is not None + and torch.any(sample_bag["start_labels"] > 1).item() + ) + or ( + sample_bag.get("optimus_labels") is not None + and len(sample_bag["optimus_labels"]) > 0 + ) + ) + ): + sample_bag["patch_offset"] = current_patch + current_patch += 1 + yield sample_bag + else: + skipped_instances += 1 + if skipped_instances % 1000 == 0 and skipped_instances != 0: + logger.info( + f"Skipped {skipped_instances} instances since they did not have any gold labels..." + ) + + # Just use the first fitting candidates if split on + # cand is not True + if not self.split_on_cand_overload: + break + + def preshuffle_elements(self, dataset_elements: List): + # This shuffling is done so that when using the sorting function, + # if it is deterministic given a collection and its order, we will + # make the whole operation not deterministic anymore. + # Basically, the aim is not to build every time the same batches. + if not self.for_inference: + dataset_elements = np.random.permutation(dataset_elements) + + sorting_fn = ( + lambda elem: add_noise_to_value( + sum(len(elem[k]) for k in self.sorting_fields), + noise_param=self.noise_param, + ) + if not self.for_inference + else sum(len(elem[k]) for k in self.sorting_fields) + ) + + dataset_elements = sorted(dataset_elements, key=sorting_fn) + + if self.for_inference: + return dataset_elements + + ds = list(chunks(dataset_elements, 64)) + np.random.shuffle(ds) + return flatten(ds) + + def materialize_batches( + self, dataset_elements: List[Dict[str, Any]] + ) -> Generator[Dict[str, Any], None, None]: + if self.prebatch: + dataset_elements = self.preshuffle_elements(dataset_elements) + + current_batch = [] + + # function that creates a batch from the 'current_batch' list + def output_batch() -> Dict[str, Any]: + assert ( + len( + set([len(elem["predictable_candidates"]) for elem in current_batch]) + ) + == 1 + ), " ".join( + map( + str, [len(elem["predictable_candidates"]) for elem in current_batch] + ) + ) + + batch_dict = dict() + + de_values_by_field = { + fn: [de[fn] for de in current_batch if fn in de] + for fn in self.fields_batcher + } + + # in case you provide fields batchers but in the batch + # there are no elements for that field + de_values_by_field = { + fn: fvs for fn, fvs in de_values_by_field.items() if len(fvs) > 0 + } + + assert len(set([len(v) for v in de_values_by_field.values()])) + + # todo: maybe we should report the user about possible + # fields filtering due to "None" instances + de_values_by_field = { + fn: fvs + for fn, fvs in de_values_by_field.items() + if all([fv is not None for fv in fvs]) + } + + for field_name, field_values in de_values_by_field.items(): + field_batch = ( + self.fields_batcher[field_name](field_values) + if self.fields_batcher[field_name] is not None + else field_values + ) + + batch_dict[field_name] = 
field_batch + + return batch_dict + + max_len_discards, min_len_discards = 0, 0 + + should_token_batch = self.batch_size is None + + curr_pred_elements = -1 + for de in dataset_elements: + if ( + should_token_batch + and self.max_batch_size != -1 + and len(current_batch) == self.max_batch_size + ) or (not should_token_batch and len(current_batch) == self.batch_size): + yield output_batch() + current_batch = [] + curr_pred_elements = -1 + + too_long_fields = [ + k + for k in de + if self.max_length != -1 + and torch.is_tensor(de[k]) + and len(de[k]) > self.max_length + ] + if len(too_long_fields) > 0: + max_len_discards += 1 + continue + + too_short_fields = [ + k + for k in de + if self.min_length != -1 + and torch.is_tensor(de[k]) + and len(de[k]) < self.min_length + ] + if len(too_short_fields) > 0: + min_len_discards += 1 + continue + + if should_token_batch: + de_len = sum(len(de[k]) for k in self.batching_fields) + + future_max_len = max( + de_len, + max( + [ + sum(len(bde[k]) for k in self.batching_fields) + for bde in current_batch + ], + default=0, + ), + ) + + future_tokens_per_batch = future_max_len * (len(current_batch) + 1) + + num_predictable_candidates = len(de["predictable_candidates"]) + + if len(current_batch) > 0 and ( + future_tokens_per_batch >= self.tokens_per_batch + or ( + num_predictable_candidates != curr_pred_elements + and curr_pred_elements != -1 + ) + ): + yield output_batch() + current_batch = [] + + current_batch.append(de) + curr_pred_elements = len(de["predictable_candidates"]) + + if len(current_batch) != 0 and not self.drop_last: + yield output_batch() + + if max_len_discards > 0: + if self.for_inference: + logger.warning( + f"WARNING: Inference mode is True but {max_len_discards} samples longer than max length were " + f"found. The {max_len_discards} samples will be DISCARDED. If you are doing some kind of evaluation" + f", this can INVALIDATE results. This might happen if the max length was not set to -1 or if the " + f"sample length exceeds the maximum length supported by the current model." + ) + else: + logger.warning( + f"During iteration, {max_len_discards} elements were " + f"discarded since longer than max length {self.max_length}" + ) + + if min_len_discards > 0: + if self.for_inference: + logger.warning( + f"WARNING: Inference mode is True but {min_len_discards} samples shorter than min length were " + f"found. The {min_len_discards} samples will be DISCARDED. If you are doing some kind of evaluation" + f", this can INVALIDATE results. This might happen if the min length was not set to -1 or if the " + f"sample length is shorter than the minimum length supported by the current model." + ) + else: + logger.warning( + f"During iteration, {min_len_discards} elements were " + f"discarded since shorter than min length {self.min_length}" + ) + + @staticmethod + def convert_tokens_to_char_annotations( + sample: RelikReaderSample, + remove_nmes: bool = True, + ) -> RelikReaderSample: + """ + Converts the token annotations to char annotations. + + Args: + sample (:obj:`RelikReaderSample`): + The sample to convert. + remove_nmes (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to remove the NMEs from the annotations. + Returns: + :obj:`RelikReaderSample`: The converted sample. 
+ """ + char_annotations = set() + for ( + predicted_entity, + predicted_spans, + ) in sample.predicted_window_labels.items(): + if predicted_entity == NME_SYMBOL and remove_nmes: + continue + + for span_start, span_end in predicted_spans: + span_start = sample.token2char_start[str(span_start)] + span_end = sample.token2char_end[str(span_end)] + + char_annotations.add((span_start, span_end, predicted_entity)) + + char_probs_annotations = dict() + for ( + span_start, + span_end, + ), candidates_probs in sample.span_title_probabilities.items(): + span_start = sample.token2char_start[str(span_start)] + span_end = sample.token2char_end[str(span_end)] + char_probs_annotations[(span_start, span_end)] = { + title for title, _ in candidates_probs + } + + sample.predicted_window_labels_chars = char_annotations + sample.probs_window_labels_chars = char_probs_annotations + + return sample + + @staticmethod + def merge_patches_predictions(sample) -> None: + sample._d["predicted_window_labels"] = dict() + predicted_window_labels = sample._d["predicted_window_labels"] + + sample._d["span_title_probabilities"] = dict() + span_title_probabilities = sample._d["span_title_probabilities"] + + span2title = dict() + for _, patch_info in sorted(sample.patches.items(), key=lambda x: x[0]): + # selecting span predictions + for predicted_title, predicted_spans in patch_info[ + "predicted_window_labels" + ].items(): + for pred_span in predicted_spans: + pred_span = tuple(pred_span) + curr_title = span2title.get(pred_span) + if curr_title is None or curr_title == NME_SYMBOL: + span2title[pred_span] = predicted_title + # else: + # print("Merging at patch level") + + # selecting span predictions probability + for predicted_span, titles_probabilities in patch_info[ + "span_title_probabilities" + ].items(): + if predicted_span not in span_title_probabilities: + span_title_probabilities[predicted_span] = titles_probabilities + + for span, title in span2title.items(): + if title not in predicted_window_labels: + predicted_window_labels[title] = list() + predicted_window_labels[title].append(span) diff --git a/relik/reader/data/relik_reader_data_utils.py b/relik/reader/data/relik_reader_data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3c7446bee296d14653a35895bf9ec8071c87e5af --- /dev/null +++ b/relik/reader/data/relik_reader_data_utils.py @@ -0,0 +1,51 @@ +from typing import List + +import numpy as np +import torch + + +def flatten(lsts: List[list]) -> list: + acc_lst = list() + for lst in lsts: + acc_lst.extend(lst) + return acc_lst + + +def batchify(tensors: List[torch.Tensor], padding_value: int = 0) -> torch.Tensor: + return torch.nn.utils.rnn.pad_sequence( + tensors, batch_first=True, padding_value=padding_value + ) + + +def batchify_matrices(tensors: List[torch.Tensor], padding_value: int) -> torch.Tensor: + x = max([t.shape[0] for t in tensors]) + y = max([t.shape[1] for t in tensors]) + out_matrix = torch.zeros((len(tensors), x, y)) + out_matrix += padding_value + for i, tensor in enumerate(tensors): + out_matrix[i][0 : tensor.shape[0], 0 : tensor.shape[1]] = tensor + return out_matrix + + +def batchify_tensor(tensors: List[torch.Tensor], padding_value: int) -> torch.Tensor: + x = max([t.shape[0] for t in tensors]) + y = max([t.shape[1] for t in tensors]) + rest = tensors[0].shape[2] + out_matrix = torch.zeros((len(tensors), x, y, rest)) + out_matrix += padding_value + for i, tensor in enumerate(tensors): + out_matrix[i][0 : tensor.shape[0], 0 : tensor.shape[1], :] = tensor + 
return out_matrix + + +def chunks(lst: list, chunk_size: int) -> List[list]: + chunks_acc = list() + for i in range(0, len(lst), chunk_size): + chunks_acc.append(lst[i : i + chunk_size]) + return chunks_acc + + +def add_noise_to_value(value: int, noise_param: float): + noise_value = value * noise_param + noise = np.random.uniform(-noise_value, noise_value) + return max(1, value + noise) diff --git a/relik/reader/data/relik_reader_sample.py b/relik/reader/data/relik_reader_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..3d7570411fbb939f99d73d1cc3318b21552bc7c2 --- /dev/null +++ b/relik/reader/data/relik_reader_sample.py @@ -0,0 +1,49 @@ +import json +from typing import Iterable + + +class RelikReaderSample: + def __init__(self, **kwargs): + super().__setattr__("_d", {}) + self._d = kwargs + + def __getattribute__(self, item): + return super(RelikReaderSample, self).__getattribute__(item) + + def __getattr__(self, item): + if item.startswith("__") and item.endswith("__"): + # this is likely some python library-specific variable (such as __deepcopy__ for copy) + # better follow standard behavior here + raise AttributeError(item) + elif item in self._d: + return self._d[item] + else: + return None + + def __setattr__(self, key, value): + if key in self._d: + self._d[key] = value + else: + super().__setattr__(key, value) + + def to_jsons(self) -> str: + if "predicted_window_labels" in self._d: + new_obj = { + k: v + for k, v in self._d.items() + if k != "predicted_window_labels" and k != "span_title_probabilities" + } + new_obj["predicted_window_labels"] = [ + [ss, se, pred_title] + for (ss, se), pred_title in self.predicted_window_labels_chars + ] + else: + return json.dumps(self._d) + + +def load_relik_reader_samples(path: str) -> Iterable[RelikReaderSample]: + with open(path) as f: + for line in f: + jsonl_line = json.loads(line.strip()) + relik_reader_sample = RelikReaderSample(**jsonl_line) + yield relik_reader_sample diff --git a/relik/reader/lightning_modules/__init__.py b/relik/reader/lightning_modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/relik/reader/lightning_modules/relik_reader_pl_module.py b/relik/reader/lightning_modules/relik_reader_pl_module.py new file mode 100644 index 0000000000000000000000000000000000000000..4e66e87b6360fe9b4a72477fd5a7fe6295b53ae9 --- /dev/null +++ b/relik/reader/lightning_modules/relik_reader_pl_module.py @@ -0,0 +1,50 @@ +from typing import Any, Optional + +import lightning +from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler + +from relik.reader.relik_reader_core import RelikReaderCoreModel + + +class RelikReaderPLModule(lightning.LightningModule): + def __init__( + self, + cfg: dict, + transformer_model: str, + additional_special_symbols: int, + num_layers: Optional[int] = None, + activation: str = "gelu", + linears_hidden_size: Optional[int] = 512, + use_last_k_layers: int = 1, + training: bool = False, + *args: Any, + **kwargs: Any + ): + super().__init__(*args, **kwargs) + self.save_hyperparameters() + self.relik_reader_core_model = RelikReaderCoreModel( + transformer_model, + additional_special_symbols, + num_layers, + activation, + linears_hidden_size, + use_last_k_layers, + training=training, + ) + self.optimizer_factory = None + + def training_step(self, batch: dict, *args: Any, **kwargs: Any) -> STEP_OUTPUT: + relik_output = self.relik_reader_core_model(**batch) + self.log("train-loss", 
relik_output["loss"]) + return relik_output["loss"] + + def validation_step( + self, batch: dict, *args: Any, **kwargs: Any + ) -> Optional[STEP_OUTPUT]: + return + + def set_optimizer_factory(self, optimizer_factory) -> None: + self.optimizer_factory = optimizer_factory + + def configure_optimizers(self) -> OptimizerLRScheduler: + return self.optimizer_factory(self.relik_reader_core_model) diff --git a/relik/reader/lightning_modules/relik_reader_re_pl_module.py b/relik/reader/lightning_modules/relik_reader_re_pl_module.py new file mode 100644 index 0000000000000000000000000000000000000000..ad1d2084f10d700d68b30e68dc29cbace1450f9b --- /dev/null +++ b/relik/reader/lightning_modules/relik_reader_re_pl_module.py @@ -0,0 +1,54 @@ +from typing import Any, Optional + +import lightning +from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler + +from relik.reader.relik_reader_re import RelikReaderForTripletExtraction + + +class RelikReaderREPLModule(lightning.LightningModule): + def __init__( + self, + cfg: dict, + transformer_model: str, + additional_special_symbols: int, + num_layers: Optional[int] = None, + activation: str = "gelu", + linears_hidden_size: Optional[int] = 512, + use_last_k_layers: int = 1, + training: bool = False, + *args: Any, + **kwargs: Any + ): + super().__init__(*args, **kwargs) + self.save_hyperparameters() + + self.relik_reader_re_model = RelikReaderForTripletExtraction( + transformer_model, + additional_special_symbols, + num_layers, + activation, + linears_hidden_size, + use_last_k_layers, + training=training, + ) + self.optimizer_factory = None + + def training_step(self, batch: dict, *args: Any, **kwargs: Any) -> STEP_OUTPUT: + relik_output = self.relik_reader_re_model(**batch) + self.log("train-loss", relik_output["loss"]) + self.log("train-start_loss", relik_output["ned_start_loss"]) + self.log("train-end_loss", relik_output["ned_end_loss"]) + self.log("train-relation_loss", relik_output["re_loss"]) + return relik_output["loss"] + + def validation_step( + self, batch: dict, *args: Any, **kwargs: Any + ) -> Optional[STEP_OUTPUT]: + return + + def set_optimizer_factory(self, optimizer_factory) -> None: + self.optimizer_factory = optimizer_factory + + def configure_optimizers(self) -> OptimizerLRScheduler: + return self.optimizer_factory(self.relik_reader_re_model) diff --git a/relik/reader/pytorch_modules/__init__.py b/relik/reader/pytorch_modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/relik/reader/pytorch_modules/base.py b/relik/reader/pytorch_modules/base.py new file mode 100644 index 0000000000000000000000000000000000000000..75db716d53d6dbdb9e9f95b63dfa1d8619769bbf --- /dev/null +++ b/relik/reader/pytorch_modules/base.py @@ -0,0 +1,248 @@ +import logging +import os +from pathlib import Path +from typing import Any, Dict, List + +import torch +import transformers as tr +from torch.utils.data import IterableDataset +from transformers import AutoConfig + +from relik.common.log import get_console_logger, get_logger +from relik.common.utils import get_callable_from_string +from relik.reader.pytorch_modules.hf.modeling_relik import ( + RelikReaderConfig, + RelikReaderSample, +) + +console_logger = get_console_logger() +logger = get_logger(__name__, level=logging.INFO) + + +class RelikReaderBase(torch.nn.Module): + default_reader_class: str | None = None + default_data_class: str | None = None + + def __init__( + self, + transformer_model: str | 
tr.PreTrainedModel | None = None, + additional_special_symbols: int = 0, + num_layers: int | None = None, + activation: str = "gelu", + linears_hidden_size: int | None = 512, + use_last_k_layers: int = 1, + training: bool = False, + device: str | torch.device | None = None, + precision: int = 32, + tokenizer: str | tr.PreTrainedTokenizer | None = None, + dataset: IterableDataset | str | None = None, + default_reader_class: tr.PreTrainedModel | str | None = None, + **kwargs, + ) -> None: + super().__init__() + + self.default_reader_class = default_reader_class or self.default_reader_class + + if self.default_reader_class is None: + raise ValueError("You must specify a default reader class.") + + # get the callable for the default reader class + self.default_reader_class: tr.PreTrainedModel = get_callable_from_string( + self.default_reader_class + ) + + if isinstance(transformer_model, str): + config = AutoConfig.from_pretrained( + transformer_model, trust_remote_code=True + ) + if "relik-reader" in config.model_type: + transformer_model = self.default_reader_class.from_pretrained( + transformer_model, **kwargs + ) + else: + reader_config = RelikReaderConfig( + transformer_model=transformer_model, + additional_special_symbols=additional_special_symbols, + num_layers=num_layers, + activation=activation, + linears_hidden_size=linears_hidden_size, + use_last_k_layers=use_last_k_layers, + training=training, + ) + transformer_model = self.default_reader_class(reader_config) + + self.relik_reader_model = transformer_model + self.relik_reader_model_config = self.relik_reader_model.config + + # get the tokenizer + self._tokenizer = tokenizer + + # and instantiate the dataset class + self.dataset: IterableDataset | None = dataset + + # move the model to the device + self.to(device or torch.device("cpu")) + + # set the precision + self.precision = precision + + def forward(self, **kwargs) -> Dict[str, Any]: + return self.relik_reader_model(**kwargs) + + def _read(self, *args, **kwargs) -> Any: + raise NotImplementedError + + @torch.no_grad() + @torch.inference_mode() + def read( + self, + text: List[str] | List[List[str]] | None = None, + samples: List[RelikReaderSample] | None = None, + input_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + token_type_ids: torch.Tensor | None = None, + prediction_mask: torch.Tensor | None = None, + special_symbols_mask: torch.Tensor | None = None, + candidates: List[List[str]] | None = None, + max_length: int = 1000, + max_batch_size: int = 128, + token_batch_size: int = 2048, + precision: int | str | None = None, + progress_bar: bool = False, + *args, + **kwargs, + ) -> List[RelikReaderSample] | List[List[RelikReaderSample]]: + """ + Reads the given text. + + Args: + text (:obj:`List[str]` or :obj:`List[List[str]]`, `optional`): + The text to read in tokens. If a list of list of tokens is provided, each + inner list is considered a sentence. + samples (:obj:`List[RelikReaderSample]`, `optional`): + The samples to read. If provided, `text` and `candidates` are ignored. + input_ids (:obj:`torch.Tensor`, `optional`): + The input ids of the text. + attention_mask (:obj:`torch.Tensor`, `optional`): + The attention mask of the text. + token_type_ids (:obj:`torch.Tensor`, `optional`): + The token type ids of the text. + prediction_mask (:obj:`torch.Tensor`, `optional`): + The prediction mask of the text. + special_symbols_mask (:obj:`torch.Tensor`, `optional`): + The special symbols mask of the text. 
+ candidates (:obj:`List[List[str]]`, `optional`): + The candidates of the text. + max_length (:obj:`int`, `optional`, defaults to 1024): + The maximum length of the text. + max_batch_size (:obj:`int`, `optional`, defaults to 128): + The maximum batch size. + token_batch_size (:obj:`int`, `optional`): + The maximum number of tokens per batch. + precision (:obj:`int` or :obj:`str`, `optional`): + The precision to use. If not provided, the default is 32 bit. + progress_bar (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to show a progress bar. + + Returns: + The predicted labels for each sample. + """ + if text is None and input_ids is None and samples is None: + raise ValueError( + "Either `text` or `input_ids` or `samples` must be provided." + ) + if (input_ids is None and samples is None) and ( + text is None or candidates is None + ): + raise ValueError( + "`text` and `candidates` must be provided to return the predictions when " + "`input_ids` and `samples` is not provided." + ) + if text is not None and samples is None: + if len(text) != len(candidates): + raise ValueError("`text` and `candidates` must have the same length.") + if isinstance(text[0], str): # change to list of text + text = [text] + candidates = [candidates] + + samples = [ + RelikReaderSample(tokens=t, candidates=c) + for t, c in zip(text, candidates) + ] + + return self._read( + samples, + input_ids, + attention_mask, + token_type_ids, + prediction_mask, + special_symbols_mask, + max_length, + max_batch_size, + token_batch_size, + precision or self.precision, + progress_bar, + *args, + **kwargs, + ) + + @property + def device(self) -> torch.device: + """ + The device of the model. + """ + return next(self.parameters()).device + + @property + def tokenizer(self) -> tr.PreTrainedTokenizer: + """ + The tokenizer. + """ + if self._tokenizer: + return self._tokenizer + + self._tokenizer = tr.AutoTokenizer.from_pretrained( + self.relik_reader_model.config.name_or_path + ) + return self._tokenizer + + def save_pretrained( + self, + output_dir: str | os.PathLike, + model_name: str | None = None, + push_to_hub: bool = False, + **kwargs, + ) -> None: + """ + Saves the model to the given path. + + Args: + output_dir (`str` or :obj:`os.PathLike`): + The path to save the model to. + model_name (`str`, `optional`): + The name of the model. If not provided, the model will be saved as + `default_reader_class.__name__`. + push_to_hub (`bool`, `optional`, defaults to `False`): + Whether to push the model to the HuggingFace Hub. 
+ **kwargs: + Additional keyword arguments to pass to the `save_pretrained` method + """ + # create the output directory + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + model_name = model_name or self.default_reader_class.__name__ + + logger.info(f"Saving reader to {output_dir / model_name}") + + # save the model + self.relik_reader_model.register_for_auto_class() + self.relik_reader_model.save_pretrained( + output_dir / model_name, push_to_hub=push_to_hub, **kwargs + ) + + if self.tokenizer: + logger.info("Saving also the tokenizer") + self.tokenizer.save_pretrained( + output_dir / model_name, push_to_hub=push_to_hub, **kwargs + ) diff --git a/relik/reader/pytorch_modules/hf/__init__.py b/relik/reader/pytorch_modules/hf/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c9c158e6ab6dcd3ab43e60751218600fbb0a5ed5 --- /dev/null +++ b/relik/reader/pytorch_modules/hf/__init__.py @@ -0,0 +1,2 @@ +from .configuration_relik import RelikReaderConfig +from .modeling_relik import RelikReaderREModel diff --git a/relik/reader/pytorch_modules/hf/configuration_relik.py b/relik/reader/pytorch_modules/hf/configuration_relik.py new file mode 100644 index 0000000000000000000000000000000000000000..6683823926b4b09a5ad169ef4e0f5b92061d774e --- /dev/null +++ b/relik/reader/pytorch_modules/hf/configuration_relik.py @@ -0,0 +1,33 @@ +from typing import Optional + +from transformers import AutoConfig +from transformers.configuration_utils import PretrainedConfig + + +class RelikReaderConfig(PretrainedConfig): + model_type = "relik-reader" + + def __init__( + self, + transformer_model: str = "microsoft/deberta-v3-base", + additional_special_symbols: int = 101, + num_layers: Optional[int] = None, + activation: str = "gelu", + linears_hidden_size: Optional[int] = 512, + use_last_k_layers: int = 1, + training: bool = False, + default_reader_class: Optional[str] = None, + **kwargs + ) -> None: + self.transformer_model = transformer_model + self.additional_special_symbols = additional_special_symbols + self.num_layers = num_layers + self.activation = activation + self.linears_hidden_size = linears_hidden_size + self.use_last_k_layers = use_last_k_layers + self.training = training + self.default_reader_class = default_reader_class + super().__init__(**kwargs) + + +AutoConfig.register("relik-reader", RelikReaderConfig) diff --git a/relik/reader/pytorch_modules/hf/modeling_relik.py b/relik/reader/pytorch_modules/hf/modeling_relik.py new file mode 100644 index 0000000000000000000000000000000000000000..f79fc14e0cabe9f830187467578ff3f65351c9a2 --- /dev/null +++ b/relik/reader/pytorch_modules/hf/modeling_relik.py @@ -0,0 +1,981 @@ +from typing import Any, Dict, Optional + +import torch +from transformers import AutoModel, PreTrainedModel +from transformers.activations import ClippedGELUActivation, GELUActivation +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_utils import PoolerEndLogits + +from .configuration_relik import RelikReaderConfig + + +class RelikReaderSample: + def __init__(self, **kwargs): + super().__setattr__("_d", {}) + self._d = kwargs + + def __getattribute__(self, item): + return super(RelikReaderSample, self).__getattribute__(item) + + def __getattr__(self, item): + if item.startswith("__") and item.endswith("__"): + # this is likely some python library-specific variable (such as __deepcopy__ for copy) + # better follow standard behavior here + raise AttributeError(item) + elif item in self._d: + 
return self._d[item] + else: + return None + + def __setattr__(self, key, value): + if key in self._d: + self._d[key] = value + else: + super().__setattr__(key, value) + + +activation2functions = { + "relu": torch.nn.ReLU(), + "gelu": GELUActivation(), + "gelu_10": ClippedGELUActivation(-10, 10), +} + + +class PoolerEndLogitsBi(PoolerEndLogits): + def __init__(self, config: PretrainedConfig): + super().__init__(config) + self.dense_1 = torch.nn.Linear(config.hidden_size, 2) + + def forward( + self, + hidden_states: torch.FloatTensor, + start_states: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + p_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + if p_mask is not None: + p_mask = p_mask.unsqueeze(-1) + logits = super().forward( + hidden_states, + start_states, + start_positions, + p_mask, + ) + return logits + + +class RelikReaderSpanModel(PreTrainedModel): + config_class = RelikReaderConfig + + def __init__(self, config: RelikReaderConfig, *args, **kwargs): + super().__init__(config) + # Transformer model declaration + self.config = config + self.transformer_model = ( + AutoModel.from_pretrained(self.config.transformer_model) + if self.config.num_layers is None + else AutoModel.from_pretrained( + self.config.transformer_model, num_hidden_layers=self.config.num_layers + ) + ) + self.transformer_model.resize_token_embeddings( + self.transformer_model.config.vocab_size + + self.config.additional_special_symbols + ) + + self.activation = self.config.activation + self.linears_hidden_size = self.config.linears_hidden_size + self.use_last_k_layers = self.config.use_last_k_layers + + # named entity detection layers + self.ned_start_classifier = self._get_projection_layer( + self.activation, last_hidden=2, layer_norm=False + ) + self.ned_end_classifier = PoolerEndLogits(self.transformer_model.config) + + # END entity disambiguation layer + self.ed_start_projector = self._get_projection_layer(self.activation) + self.ed_end_projector = self._get_projection_layer(self.activation) + + self.training = self.config.training + + # criterion + self.criterion = torch.nn.CrossEntropyLoss() + + def _get_projection_layer( + self, + activation: str, + last_hidden: Optional[int] = None, + input_hidden=None, + layer_norm: bool = True, + ) -> torch.nn.Sequential: + head_components = [ + torch.nn.Dropout(0.1), + torch.nn.Linear( + self.transformer_model.config.hidden_size * self.use_last_k_layers + if input_hidden is None + else input_hidden, + self.linears_hidden_size, + ), + activation2functions[activation], + torch.nn.Dropout(0.1), + torch.nn.Linear( + self.linears_hidden_size, + self.linears_hidden_size if last_hidden is None else last_hidden, + ), + ] + + if layer_norm: + head_components.append( + torch.nn.LayerNorm( + self.linears_hidden_size if last_hidden is None else last_hidden, + self.transformer_model.config.layer_norm_eps, + ) + ) + + return torch.nn.Sequential(*head_components) + + def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + mask = mask.unsqueeze(-1) + if next(self.parameters()).dtype == torch.float16: + logits = logits * (1 - mask) - 65500 * mask + else: + logits = logits * (1 - mask) - 1e30 * mask + return logits + + def _get_model_features( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + token_type_ids: Optional[torch.Tensor], + ): + model_input = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "output_hidden_states": self.use_last_k_layers > 1, + } + + if 
token_type_ids is not None: + model_input["token_type_ids"] = token_type_ids + + model_output = self.transformer_model(**model_input) + + if self.use_last_k_layers > 1: + model_features = torch.cat( + model_output[1][-self.use_last_k_layers :], dim=-1 + ) + else: + model_features = model_output[0] + + return model_features + + def compute_ned_end_logits( + self, + start_predictions, + start_labels, + model_features, + prediction_mask, + batch_size, + ) -> Optional[torch.Tensor]: + # todo: maybe when constraining on the spans, + # we should not use a prediction_mask for the end tokens. + # at least we should not during training imo + start_positions = start_labels if self.training else start_predictions + start_positions_indices = ( + torch.arange(start_positions.size(1), device=start_positions.device) + .unsqueeze(0) + .expand(batch_size, -1)[start_positions > 0] + ).to(start_positions.device) + + if len(start_positions_indices) > 0: + expanded_features = torch.cat( + [ + model_features[i].unsqueeze(0).expand(x, -1, -1) + for i, x in enumerate(torch.sum(start_positions > 0, dim=-1)) + if x > 0 + ], + dim=0, + ).to(start_positions_indices.device) + + expanded_prediction_mask = torch.cat( + [ + prediction_mask[i].unsqueeze(0).expand(x, -1) + for i, x in enumerate(torch.sum(start_positions > 0, dim=-1)) + if x > 0 + ], + dim=0, + ).to(expanded_features.device) + + end_logits = self.ned_end_classifier( + hidden_states=expanded_features, + start_positions=start_positions_indices, + p_mask=expanded_prediction_mask, + ) + + return end_logits + + return None + + def compute_classification_logits( + self, + model_features, + special_symbols_mask, + prediction_mask, + batch_size, + start_positions=None, + end_positions=None, + ) -> torch.Tensor: + if start_positions is None or end_positions is None: + start_positions = torch.zeros_like(prediction_mask) + end_positions = torch.zeros_like(prediction_mask) + + model_start_features = self.ed_start_projector(model_features) + model_end_features = self.ed_end_projector(model_features) + model_end_features[start_positions > 0] = model_end_features[end_positions > 0] + + model_ed_features = torch.cat( + [model_start_features, model_end_features], dim=-1 + ) + + # computing ed features + classes_representations = torch.sum(special_symbols_mask, dim=1)[0].item() + special_symbols_representation = model_ed_features[special_symbols_mask].view( + batch_size, classes_representations, -1 + ) + + logits = torch.bmm( + model_ed_features, + torch.permute(special_symbols_representation, (0, 2, 1)), + ) + + logits = self._mask_logits(logits, prediction_mask) + + return logits + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + token_type_ids: Optional[torch.Tensor] = None, + prediction_mask: Optional[torch.Tensor] = None, + special_symbols_mask: Optional[torch.Tensor] = None, + start_labels: Optional[torch.Tensor] = None, + end_labels: Optional[torch.Tensor] = None, + use_predefined_spans: bool = False, + *args, + **kwargs, + ) -> Dict[str, Any]: + batch_size, seq_len = input_ids.shape + + model_features = self._get_model_features( + input_ids, attention_mask, token_type_ids + ) + + ned_start_labels = None + + # named entity detection if required + if use_predefined_spans: # no need to compute spans + ned_start_logits, ned_start_probabilities, ned_start_predictions = ( + None, + None, + torch.clone(start_labels) + if start_labels is not None + else torch.zeros_like(input_ids), + ) + ned_end_logits, ned_end_probabilities, 
ned_end_predictions = ( + None, + None, + torch.clone(end_labels) + if end_labels is not None + else torch.zeros_like(input_ids), + ) + + ned_start_predictions[ned_start_predictions > 0] = 1 + ned_end_predictions[ned_end_predictions > 0] = 1 + + else: # compute spans + # start boundary prediction + ned_start_logits = self.ned_start_classifier(model_features) + ned_start_logits = self._mask_logits(ned_start_logits, prediction_mask) + ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1) + ned_start_predictions = ned_start_probabilities.argmax(dim=-1) + + # end boundary prediction + ned_start_labels = ( + torch.zeros_like(start_labels) if start_labels is not None else None + ) + + if ned_start_labels is not None: + ned_start_labels[start_labels == -100] = -100 + ned_start_labels[start_labels > 0] = 1 + + ned_end_logits = self.compute_ned_end_logits( + ned_start_predictions, + ned_start_labels, + model_features, + prediction_mask, + batch_size, + ) + + if ned_end_logits is not None: + ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1) + ned_end_predictions = torch.argmax(ned_end_probabilities, dim=-1) + else: + ned_end_logits, ned_end_probabilities = None, None + ned_end_predictions = ned_start_predictions.new_zeros(batch_size) + + # flattening end predictions + # (flattening can happen only if the + # end boundaries were not predicted using the gold labels) + if not self.training: + flattened_end_predictions = torch.clone(ned_start_predictions) + flattened_end_predictions[flattened_end_predictions > 0] = 0 + + batch_start_predictions = list() + for elem_idx in range(batch_size): + batch_start_predictions.append( + torch.where(ned_start_predictions[elem_idx] > 0)[0].tolist() + ) + + # check that the total number of start predictions + # is equal to the end predictions + total_start_predictions = sum(map(len, batch_start_predictions)) + total_end_predictions = len(ned_end_predictions) + assert ( + total_start_predictions == 0 + or total_start_predictions == total_end_predictions + ), ( + f"Total number of start predictions = {total_start_predictions}. 
" + f"Total number of end predictions = {total_end_predictions}" + ) + + curr_end_pred_num = 0 + for elem_idx, bsp in enumerate(batch_start_predictions): + for sp in bsp: + ep = ned_end_predictions[curr_end_pred_num].item() + if ep < sp: + ep = sp + + # if we already set this span throw it (no overlap) + if flattened_end_predictions[elem_idx, ep] == 1: + ned_start_predictions[elem_idx, sp] = 0 + else: + flattened_end_predictions[elem_idx, ep] = 1 + + curr_end_pred_num += 1 + + ned_end_predictions = flattened_end_predictions + + start_position, end_position = ( + (start_labels, end_labels) + if self.training + else (ned_start_predictions, ned_end_predictions) + ) + + # Entity disambiguation + ed_logits = self.compute_classification_logits( + model_features, + special_symbols_mask, + prediction_mask, + batch_size, + start_position, + end_position, + ) + ed_probabilities = torch.softmax(ed_logits, dim=-1) + ed_predictions = torch.argmax(ed_probabilities, dim=-1) + + # output build + output_dict = dict( + batch_size=batch_size, + ned_start_logits=ned_start_logits, + ned_start_probabilities=ned_start_probabilities, + ned_start_predictions=ned_start_predictions, + ned_end_logits=ned_end_logits, + ned_end_probabilities=ned_end_probabilities, + ned_end_predictions=ned_end_predictions, + ed_logits=ed_logits, + ed_probabilities=ed_probabilities, + ed_predictions=ed_predictions, + ) + + # compute loss if labels + if start_labels is not None and end_labels is not None and self.training: + # named entity detection loss + + # start + if ned_start_logits is not None: + ned_start_loss = self.criterion( + ned_start_logits.view(-1, ned_start_logits.shape[-1]), + ned_start_labels.view(-1), + ) + else: + ned_start_loss = 0 + + # end + if ned_end_logits is not None: + ned_end_labels = torch.zeros_like(end_labels) + ned_end_labels[end_labels == -100] = -100 + ned_end_labels[end_labels > 0] = 1 + + ned_end_loss = self.criterion( + ned_end_logits, + ( + torch.arange( + ned_end_labels.size(1), device=ned_end_labels.device + ) + .unsqueeze(0) + .expand(batch_size, -1)[ned_end_labels > 0] + ).to(ned_end_labels.device), + ) + + else: + ned_end_loss = 0 + + # entity disambiguation loss + start_labels[ned_start_labels != 1] = -100 + ed_labels = torch.clone(start_labels) + ed_labels[end_labels > 0] = end_labels[end_labels > 0] + ed_loss = self.criterion( + ed_logits.view(-1, ed_logits.shape[-1]), + ed_labels.view(-1), + ) + + output_dict["ned_start_loss"] = ned_start_loss + output_dict["ned_end_loss"] = ned_end_loss + output_dict["ed_loss"] = ed_loss + + output_dict["loss"] = ned_start_loss + ned_end_loss + ed_loss + + return output_dict + + +class RelikReaderREModel(PreTrainedModel): + config_class = RelikReaderConfig + + def __init__(self, config, *args, **kwargs): + super().__init__(config) + # Transformer model declaration + # self.transformer_model_name = transformer_model + self.config = config + self.transformer_model = ( + AutoModel.from_pretrained(config.transformer_model) + if config.num_layers is None + else AutoModel.from_pretrained( + config.transformer_model, num_hidden_layers=config.num_layers + ) + ) + self.transformer_model.resize_token_embeddings( + self.transformer_model.config.vocab_size + config.additional_special_symbols + ) + + # named entity detection layers + self.ned_start_classifier = self._get_projection_layer( + config.activation, last_hidden=2, layer_norm=False + ) + + self.ned_end_classifier = PoolerEndLogitsBi(self.transformer_model.config) + + self.entity_type_loss = ( + 
config.entity_type_loss if hasattr(config, "entity_type_loss") else False + ) + self.relation_disambiguation_loss = ( + config.relation_disambiguation_loss + if hasattr(config, "relation_disambiguation_loss") + else False + ) + + input_hidden_ents = 2 * self.transformer_model.config.hidden_size + + self.re_subject_projector = self._get_projection_layer( + config.activation, input_hidden=input_hidden_ents + ) + self.re_object_projector = self._get_projection_layer( + config.activation, input_hidden=input_hidden_ents + ) + self.re_relation_projector = self._get_projection_layer(config.activation) + + if self.entity_type_loss or self.relation_disambiguation_loss: + self.re_entities_projector = self._get_projection_layer( + config.activation, + input_hidden=2 * self.transformer_model.config.hidden_size, + ) + self.re_definition_projector = self._get_projection_layer( + config.activation, + ) + + self.re_classifier = self._get_projection_layer( + config.activation, + input_hidden=config.linears_hidden_size, + last_hidden=2, + layer_norm=False, + ) + + if self.entity_type_loss or self.relation_disambiguation_loss: + self.re_ed_classifier = self._get_projection_layer( + config.activation, + input_hidden=config.linears_hidden_size, + last_hidden=2, + layer_norm=False, + ) + + self.training = config.training + + # criterion + self.criterion = torch.nn.CrossEntropyLoss() + + def _get_projection_layer( + self, + activation: str, + last_hidden: Optional[int] = None, + input_hidden=None, + layer_norm: bool = True, + ) -> torch.nn.Sequential: + head_components = [ + torch.nn.Dropout(0.1), + torch.nn.Linear( + self.transformer_model.config.hidden_size + * self.config.use_last_k_layers + if input_hidden is None + else input_hidden, + self.config.linears_hidden_size, + ), + activation2functions[activation], + torch.nn.Dropout(0.1), + torch.nn.Linear( + self.config.linears_hidden_size, + self.config.linears_hidden_size if last_hidden is None else last_hidden, + ), + ] + + if layer_norm: + head_components.append( + torch.nn.LayerNorm( + self.config.linears_hidden_size + if last_hidden is None + else last_hidden, + self.transformer_model.config.layer_norm_eps, + ) + ) + + return torch.nn.Sequential(*head_components) + + def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + mask = mask.unsqueeze(-1) + if next(self.parameters()).dtype == torch.float16: + logits = logits * (1 - mask) - 65500 * mask + else: + logits = logits * (1 - mask) - 1e30 * mask + return logits + + def _get_model_features( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + token_type_ids: Optional[torch.Tensor], + ): + model_input = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "output_hidden_states": self.config.use_last_k_layers > 1, + } + + if token_type_ids is not None: + model_input["token_type_ids"] = token_type_ids + + model_output = self.transformer_model(**model_input) + + if self.config.use_last_k_layers > 1: + model_features = torch.cat( + model_output[1][-self.config.use_last_k_layers :], dim=-1 + ) + else: + model_features = model_output[0] + + return model_features + + def compute_ned_end_logits( + self, + start_predictions, + start_labels, + model_features, + prediction_mask, + batch_size, + ) -> Optional[torch.Tensor]: + # todo: maybe when constraining on the spans, + # we should not use a prediction_mask for the end tokens. 
+ # at least we should not during training imo + start_positions = start_labels if self.training else start_predictions + start_positions_indices = ( + torch.arange(start_positions.size(1), device=start_positions.device) + .unsqueeze(0) + .expand(batch_size, -1)[start_positions > 0] + ).to(start_positions.device) + + if len(start_positions_indices) > 0: + expanded_features = torch.cat( + [ + model_features[i].unsqueeze(0).expand(x, -1, -1) + for i, x in enumerate(torch.sum(start_positions > 0, dim=-1)) + if x > 0 + ], + dim=0, + ).to(start_positions_indices.device) + + expanded_prediction_mask = torch.cat( + [ + prediction_mask[i].unsqueeze(0).expand(x, -1) + for i, x in enumerate(torch.sum(start_positions > 0, dim=-1)) + if x > 0 + ], + dim=0, + ).to(expanded_features.device) + + # mask all tokens before start_positions_indices ie, mask all tokens with + # indices < start_positions_indices with 1, ie. [range(x) for x in start_positions_indices] + expanded_prediction_mask = torch.stack( + [ + torch.cat( + [ + torch.ones(x, device=expanded_features.device), + expanded_prediction_mask[i, x:], + ] + ) + for i, x in enumerate(start_positions_indices) + if x > 0 + ], + dim=0, + ).to(expanded_features.device) + + end_logits = self.ned_end_classifier( + hidden_states=expanded_features, + start_positions=start_positions_indices, + p_mask=expanded_prediction_mask, + ) + + return end_logits + + return None + + def compute_relation_logits( + self, + model_entity_features, + special_symbols_features, + ) -> torch.Tensor: + model_subject_features = self.re_subject_projector(model_entity_features) + model_object_features = self.re_object_projector(model_entity_features) + special_symbols_start_representation = self.re_relation_projector( + special_symbols_features + ) + re_logits = torch.einsum( + "bse,bde,bfe->bsdfe", + model_subject_features, + model_object_features, + special_symbols_start_representation, + ) + re_logits = self.re_classifier(re_logits) + + return re_logits + + def compute_entity_logits( + self, + model_entity_features, + special_symbols_features, + ) -> torch.Tensor: + model_ed_features = self.re_entities_projector(model_entity_features) + special_symbols_ed_representation = self.re_definition_projector( + special_symbols_features + ) + logits = torch.einsum( + "bce,bde->bcde", + model_ed_features, + special_symbols_ed_representation, + ) + logits = self.re_ed_classifier(logits) + start_logits = self._mask_logits( + logits, + (model_entity_features == -100) + .all(2) + .long() + .unsqueeze(2) + .repeat(1, 1, torch.sum(model_entity_features, dim=1)[0].item()), + ) + + return logits + + def compute_loss(self, logits, labels, mask=None): + logits = logits.view(-1, logits.shape[-1]) + labels = labels.view(-1).long() + if mask is not None: + return self.criterion(logits[mask], labels[mask]) + return self.criterion(logits, labels) + + def compute_ned_end_loss(self, ned_end_logits, end_labels): + if ned_end_logits is None: + return 0 + ned_end_labels = torch.zeros_like(end_labels) + ned_end_labels[end_labels == -100] = -100 + ned_end_labels[end_labels > 0] = 1 + return self.compute_loss(ned_end_logits, ned_end_labels) + + def compute_ned_type_loss( + self, + disambiguation_labels, + re_ned_entities_logits, + ned_type_logits, + re_entities_logits, + entity_types, + ): + if self.entity_type_loss and self.relation_disambiguation_loss: + return self.compute_loss(disambiguation_labels, re_ned_entities_logits) + if self.entity_type_loss: + return self.compute_loss( + disambiguation_labels[:, :, 
:entity_types], ned_type_logits + ) + if self.relation_disambiguation_loss: + return self.compute_loss(disambiguation_labels, re_entities_logits) + return 0 + + def compute_relation_loss(self, relation_labels, re_logits): + return self.compute_loss( + re_logits, relation_labels, relation_labels.view(-1) != -100 + ) + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + token_type_ids: torch.Tensor, + prediction_mask: Optional[torch.Tensor] = None, + special_symbols_mask: Optional[torch.Tensor] = None, + special_symbols_mask_entities: Optional[torch.Tensor] = None, + start_labels: Optional[torch.Tensor] = None, + end_labels: Optional[torch.Tensor] = None, + disambiguation_labels: Optional[torch.Tensor] = None, + relation_labels: Optional[torch.Tensor] = None, + is_validation: bool = False, + is_prediction: bool = False, + *args, + **kwargs, + ) -> Dict[str, Any]: + batch_size = input_ids.shape[0] + + model_features = self._get_model_features( + input_ids, attention_mask, token_type_ids + ) + + # named entity detection + if is_prediction and start_labels is not None: + ned_start_logits, ned_start_probabilities, ned_start_predictions = ( + None, + None, + torch.zeros_like(start_labels), + ) + ned_end_logits, ned_end_probabilities, ned_end_predictions = ( + None, + None, + torch.zeros_like(end_labels), + ) + + ned_start_predictions[start_labels > 0] = 1 + ned_end_predictions[end_labels > 0] = 1 + ned_end_predictions = ned_end_predictions[~(end_labels == -100).all(2)] + else: + # start boundary prediction + ned_start_logits = self.ned_start_classifier(model_features) + ned_start_logits = self._mask_logits( + ned_start_logits, prediction_mask + ) # why? + ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1) + ned_start_predictions = ned_start_probabilities.argmax(dim=-1) + + # end boundary prediction + ned_start_labels = ( + torch.zeros_like(start_labels) if start_labels is not None else None + ) + + # start_labels contain entity id at their position, we just need 1 for start of entity + if ned_start_labels is not None: + ned_start_labels[start_labels > 0] = 1 + + # compute end logits only if there are any start predictions. + # For each start prediction, n end predictions are made + ned_end_logits = self.compute_ned_end_logits( + ned_start_predictions, + ned_start_labels, + model_features, + prediction_mask, + batch_size, + ) + # For each start prediction, n end predictions are made based on + # binary classification ie. argmax at each position. 
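+ # note: compute_ned_end_logits returns None when no start tokens were predicted; the softmax below assumes at least one start prediction in the batch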
+ ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1) + ned_end_predictions = ned_end_probabilities.argmax(dim=-1) + if is_prediction or is_validation: + end_preds_count = ned_end_predictions.sum(1) + # If there are no end predictions for a start prediction, remove the start prediction + ned_start_predictions[ned_start_predictions == 1] = ( + end_preds_count != 0 + ).long() + ned_end_predictions = ned_end_predictions[end_preds_count != 0] + + if end_labels is not None: + end_labels = end_labels[~(end_labels == -100).all(2)] + + start_position, end_position = ( + (start_labels, end_labels) + if (not is_prediction and not is_validation) + else (ned_start_predictions, ned_end_predictions) + ) + + start_counts = (start_position > 0).sum(1) + ned_end_predictions = ned_end_predictions.split(start_counts.tolist()) + + # We can only predict relations if we have start and end predictions + if (end_position > 0).sum() > 0: + ends_count = (end_position > 0).sum(1) + model_subject_features = torch.cat( + [ + torch.repeat_interleave( + model_features[start_position > 0], ends_count, dim=0 + ), # start position features + torch.repeat_interleave(model_features, start_counts, dim=0)[ + end_position > 0 + ], # end position features + ], + dim=-1, + ) + ents_count = torch.nn.utils.rnn.pad_sequence( + torch.split(ends_count, start_counts.tolist()), + batch_first=True, + padding_value=0, + ).sum(1) + model_subject_features = torch.nn.utils.rnn.pad_sequence( + torch.split(model_subject_features, ents_count.tolist()), + batch_first=True, + padding_value=-100, + ) + + if is_validation or is_prediction: + model_subject_features = model_subject_features[:, :30, :] + + # entity disambiguation. Here relation_disambiguation_loss would only be useful to + # reduce the number of candidate relations for the next step, but currently unused. 
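+ # the first `entity_types` columns of re_ned_entities_logits correspond to the entity-type special symbols (special_symbols_mask_entities); the remaining columns correspond to the other candidate symbols, and the tensor is split accordingly below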
+ if self.entity_type_loss or self.relation_disambiguation_loss: + (re_ned_entities_logits) = self.compute_entity_logits( + model_subject_features, + model_features[ + special_symbols_mask | special_symbols_mask_entities + ].view(batch_size, -1, model_features.shape[-1]), + ) + entity_types = torch.sum(special_symbols_mask_entities, dim=1)[0].item() + ned_type_logits = re_ned_entities_logits[:, :, :entity_types] + re_entities_logits = re_ned_entities_logits[:, :, entity_types:] + + if self.entity_type_loss: + ned_type_probabilities = torch.softmax(ned_type_logits, dim=-1) + ned_type_predictions = ned_type_probabilities.argmax(dim=-1) + ned_type_predictions = ned_type_predictions.argmax(dim=-1) + + re_entities_probabilities = torch.softmax(re_entities_logits, dim=-1) + re_entities_predictions = re_entities_probabilities.argmax(dim=-1) + else: + ( + ned_type_logits, + ned_type_probabilities, + re_entities_logits, + re_entities_probabilities, + ) = (None, None, None, None) + ned_type_predictions, re_entities_predictions = ( + torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device), + torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device), + ) + + # Compute relation logits + re_logits = self.compute_relation_logits( + model_subject_features, + model_features[special_symbols_mask].view( + batch_size, -1, model_features.shape[-1] + ), + ) + + re_probabilities = torch.softmax(re_logits, dim=-1) + # we set a thresshold instead of argmax in cause it needs to be tweaked + re_predictions = re_probabilities[:, :, :, :, 1] > 0.5 + # re_predictions = re_probabilities.argmax(dim=-1) + re_probabilities = re_probabilities[:, :, :, :, 1] + + else: + ( + ned_type_logits, + ned_type_probabilities, + re_entities_logits, + re_entities_probabilities, + ) = (None, None, None, None) + ned_type_predictions, re_entities_predictions = ( + torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device), + torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device), + ) + re_logits, re_probabilities, re_predictions = ( + torch.zeros( + [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long + ).to(input_ids.device), + torch.zeros( + [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long + ).to(input_ids.device), + torch.zeros( + [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long + ).to(input_ids.device), + ) + + # output build + output_dict = dict( + batch_size=batch_size, + ned_start_logits=ned_start_logits, + ned_start_probabilities=ned_start_probabilities, + ned_start_predictions=ned_start_predictions, + ned_end_logits=ned_end_logits, + ned_end_probabilities=ned_end_probabilities, + ned_end_predictions=ned_end_predictions, + ned_type_logits=ned_type_logits, + ned_type_probabilities=ned_type_probabilities, + ned_type_predictions=ned_type_predictions, + re_entities_logits=re_entities_logits, + re_entities_probabilities=re_entities_probabilities, + re_entities_predictions=re_entities_predictions, + re_logits=re_logits, + re_probabilities=re_probabilities, + re_predictions=re_predictions, + ) + + if ( + start_labels is not None + and end_labels is not None + and relation_labels is not None + ): + ned_start_loss = self.compute_loss(ned_start_logits, ned_start_labels) + ned_end_loss = self.compute_ned_end_loss(ned_end_logits, end_labels) + if self.entity_type_loss or self.relation_disambiguation_loss: + ned_type_loss = self.compute_ned_type_loss( + disambiguation_labels, + re_ned_entities_logits, + ned_type_logits, + re_entities_logits, + 
entity_types, + ) + relation_loss = self.compute_relation_loss(relation_labels, re_logits) + # compute loss. We can skip the relation loss if we are in the first epochs (optional) + if self.entity_type_loss or self.relation_disambiguation_loss: + output_dict["loss"] = ( + ned_start_loss + ned_end_loss + relation_loss + ned_type_loss + ) / 4 + output_dict["ned_type_loss"] = ned_type_loss + else: + output_dict["loss"] = ( + ned_start_loss + ned_end_loss + relation_loss + ) / 3 + + output_dict["ned_start_loss"] = ned_start_loss + output_dict["ned_end_loss"] = ned_end_loss + output_dict["re_loss"] = relation_loss + + return output_dict diff --git a/relik/reader/pytorch_modules/optim/__init__.py b/relik/reader/pytorch_modules/optim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..369091133267cfa05240306fbfe5ea3b537d5d9c --- /dev/null +++ b/relik/reader/pytorch_modules/optim/__init__.py @@ -0,0 +1,6 @@ +from relik.reader.pytorch_modules.optim.adamw_with_warmup import ( + AdamWWithWarmupOptimizer, +) +from relik.reader.pytorch_modules.optim.layer_wise_lr_decay import ( + LayerWiseLRDecayOptimizer, +) diff --git a/relik/reader/pytorch_modules/optim/adamw_with_warmup.py b/relik/reader/pytorch_modules/optim/adamw_with_warmup.py new file mode 100644 index 0000000000000000000000000000000000000000..dfaecc4ca3d1c366f25962db4d0024a5b986fd50 --- /dev/null +++ b/relik/reader/pytorch_modules/optim/adamw_with_warmup.py @@ -0,0 +1,66 @@ +from typing import List + +import torch +import transformers +from torch.optim import AdamW + + +class AdamWWithWarmupOptimizer: + def __init__( + self, + lr: float, + warmup_steps: int, + total_steps: int, + weight_decay: float, + no_decay_params: List[str], + ): + self.lr = lr + self.warmup_steps = warmup_steps + self.total_steps = total_steps + self.weight_decay = weight_decay + self.no_decay_params = no_decay_params + + def group_params(self, module: torch.nn.Module) -> list: + if self.no_decay_params is not None: + optimizer_grouped_parameters = [ + { + "params": [ + p + for n, p in module.named_parameters() + if not any(nd in n for nd in self.no_decay_params) + ], + "weight_decay": self.weight_decay, + }, + { + "params": [ + p + for n, p in module.named_parameters() + if any(nd in n for nd in self.no_decay_params) + ], + "weight_decay": 0.0, + }, + ] + + else: + optimizer_grouped_parameters = [ + {"params": module.parameters(), "weight_decay": self.weight_decay} + ] + + return optimizer_grouped_parameters + + def __call__(self, module: torch.nn.Module): + optimizer_grouped_parameters = self.group_params(module) + optimizer = AdamW( + optimizer_grouped_parameters, lr=self.lr, weight_decay=self.weight_decay + ) + scheduler = transformers.get_linear_schedule_with_warmup( + optimizer, self.warmup_steps, self.total_steps + ) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "interval": "step", + "frequency": 1, + }, + } diff --git a/relik/reader/pytorch_modules/optim/layer_wise_lr_decay.py b/relik/reader/pytorch_modules/optim/layer_wise_lr_decay.py new file mode 100644 index 0000000000000000000000000000000000000000..d179096153f356196a921c50083c96b3dcd5f246 --- /dev/null +++ b/relik/reader/pytorch_modules/optim/layer_wise_lr_decay.py @@ -0,0 +1,104 @@ +import collections +from typing import List + +import torch +import transformers +from torch.optim import AdamW + + +class LayerWiseLRDecayOptimizer: + def __init__( + self, + lr: float, + warmup_steps: int, + total_steps: int, + weight_decay: float, + lr_decay: 
float, + no_decay_params: List[str], + total_reset: int, + ): + self.lr = lr + self.warmup_steps = warmup_steps + self.total_steps = total_steps + self.weight_decay = weight_decay + self.lr_decay = lr_decay + self.no_decay_params = no_decay_params + self.total_reset = total_reset + + def group_layers(self, module) -> dict: + grouped_layers = collections.defaultdict(list) + module_named_parameters = list(module.named_parameters()) + for ln, lp in module_named_parameters: + if "embeddings" in ln: + grouped_layers["embeddings"].append((ln, lp)) + elif "encoder.layer" in ln: + layer_num = ln.split("transformer_model.encoder.layer.")[-1] + layer_num = layer_num[0 : layer_num.index(".")] + grouped_layers[layer_num].append((ln, lp)) + else: + grouped_layers["head"].append((ln, lp)) + + depth = len(grouped_layers) - 1 + final_dict = dict() + for key, value in grouped_layers.items(): + if key == "head": + final_dict[0] = value + elif key == "embeddings": + final_dict[depth] = value + else: + # -1 because layer number starts from zero + final_dict[depth - int(key) - 1] = value + + assert len(module_named_parameters) == sum( + len(v) for _, v in final_dict.items() + ) + + return final_dict + + def group_params(self, module) -> list: + optimizer_grouped_params = [] + for inverse_depth, layer in self.group_layers(module).items(): + layer_lr = self.lr * (self.lr_decay**inverse_depth) + layer_wd_params = { + "params": [ + lp + for ln, lp in layer + if not any(nd in ln for nd in self.no_decay_params) + ], + "weight_decay": self.weight_decay, + "lr": layer_lr, + } + layer_no_wd_params = { + "params": [ + lp + for ln, lp in layer + if any(nd in ln for nd in self.no_decay_params) + ], + "weight_decay": 0, + "lr": layer_lr, + } + + if len(layer_wd_params) != 0: + optimizer_grouped_params.append(layer_wd_params) + if len(layer_no_wd_params) != 0: + optimizer_grouped_params.append(layer_no_wd_params) + + return optimizer_grouped_params + + def __call__(self, module: torch.nn.Module): + optimizer_grouped_parameters = self.group_params(module) + optimizer = AdamW(optimizer_grouped_parameters, lr=self.lr) + scheduler = transformers.get_cosine_with_hard_restarts_schedule_with_warmup( + optimizer, + self.warmup_steps, + self.total_steps, + num_cycles=self.total_reset, + ) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "interval": "step", + "frequency": 1, + }, + } diff --git a/relik/reader/pytorch_modules/span.py b/relik/reader/pytorch_modules/span.py new file mode 100644 index 0000000000000000000000000000000000000000..349e42cafc1dfbc583adc46e7c8cf63d1d3752d8 --- /dev/null +++ b/relik/reader/pytorch_modules/span.py @@ -0,0 +1,367 @@ +import collections +import contextlib +import logging +from typing import Any, Dict, Iterator, List + +import torch +import transformers as tr +from lightning_fabric.utilities import move_data_to_device +from torch.utils.data import DataLoader, IterableDataset +from tqdm import tqdm + +from relik.common.log import get_console_logger, get_logger +from relik.common.utils import get_callable_from_string +from relik.reader.data.relik_reader_sample import RelikReaderSample +from relik.reader.pytorch_modules.base import RelikReaderBase +from relik.reader.utils.special_symbols import get_special_symbols +from relik.retriever.pytorch_modules import PRECISION_MAP + +console_logger = get_console_logger() +logger = get_logger(__name__, level=logging.INFO) + + +class RelikReaderForSpanExtraction(RelikReaderBase): + """ + A class for the RelikReader model for 
span extraction. + + Args: + transformer_model (:obj:`str` or :obj:`transformers.PreTrainedModel` or :obj:`None`, `optional`): + The transformer model to use. If `None`, the default model is used. + additional_special_symbols (:obj:`int`, `optional`, defaults to 0): + The number of additional special symbols to add to the tokenizer. + num_layers (:obj:`int`, `optional`): + The number of layers to use. If `None`, all layers are used. + activation (:obj:`str`, `optional`, defaults to "gelu"): + The activation function to use. + linears_hidden_size (:obj:`int`, `optional`, defaults to 512): + The hidden size of the linears. + use_last_k_layers (:obj:`int`, `optional`, defaults to 1): + The number of last layers to use. + training (:obj:`bool`, `optional`, defaults to False): + Whether the model is in training mode. + device (:obj:`str` or :obj:`torch.device` or :obj:`None`, `optional`): + The device to use. If `None`, the default device is used. + tokenizer (:obj:`str` or :obj:`transformers.PreTrainedTokenizer` or :obj:`None`, `optional`): + The tokenizer to use. If `None`, the default tokenizer is used. + dataset (:obj:`IterableDataset` or :obj:`str` or :obj:`None`, `optional`): + The dataset to use. If `None`, the default dataset is used. + dataset_kwargs (:obj:`Dict[str, Any]` or :obj:`None`, `optional`): + The keyword arguments to pass to the dataset class. + default_reader_class (:obj:`str` or :obj:`transformers.PreTrainedModel` or :obj:`None`, `optional`): + The default reader class to use. If `None`, the default reader class is used. + **kwargs: + Keyword arguments. + """ + + default_reader_class: str = ( + "relik.reader.pytorch_modules.hf.modeling_relik.RelikReaderSpanModel" + ) + default_data_class: str = "relik.reader.data.relik_reader_data.RelikDataset" + + def __init__( + self, + transformer_model: str | tr.PreTrainedModel | None = None, + additional_special_symbols: int = 0, + num_layers: int | None = None, + activation: str = "gelu", + linears_hidden_size: int | None = 512, + use_last_k_layers: int = 1, + training: bool = False, + device: str | torch.device | None = None, + tokenizer: str | tr.PreTrainedTokenizer | None = None, + dataset: IterableDataset | str | None = None, + dataset_kwargs: Dict[str, Any] | None = None, + default_reader_class: tr.PreTrainedModel | str | None = None, + **kwargs, + ): + super().__init__( + transformer_model=transformer_model, + additional_special_symbols=additional_special_symbols, + num_layers=num_layers, + activation=activation, + linears_hidden_size=linears_hidden_size, + use_last_k_layers=use_last_k_layers, + training=training, + device=device, + tokenizer=tokenizer, + dataset=dataset, + default_reader_class=default_reader_class, + **kwargs, + ) + # and instantiate the dataset class + self.dataset = dataset + if self.dataset is None: + default_data_kwargs = dict( + dataset_path=None, + materialize_samples=False, + transformer_model=self.tokenizer, + special_symbols=get_special_symbols( + self.relik_reader_model.config.additional_special_symbols + ), + for_inference=True, + ) + # merge the default data kwargs with the ones passed to the model + default_data_kwargs.update(dataset_kwargs or {}) + self.dataset = get_callable_from_string(self.default_data_class)( + **default_data_kwargs + ) + + @torch.no_grad() + @torch.inference_mode() + def _read( + self, + samples: List[RelikReaderSample] | None = None, + input_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + token_type_ids: torch.Tensor | None = None, + 
prediction_mask: torch.Tensor | None = None, + special_symbols_mask: torch.Tensor | None = None, + max_length: int = 1000, + max_batch_size: int = 128, + token_batch_size: int = 2048, + precision: str = 32, + annotation_type: str = "char", + progress_bar: bool = False, + *args: object, + **kwargs: object, + ) -> List[RelikReaderSample] | List[List[RelikReaderSample]]: + """ + A wrapper around the forward method that returns the predicted labels for each sample. + + Args: + samples (:obj:`List[RelikReaderSample]`, `optional`): + The samples to read. If provided, `text` and `candidates` are ignored. + input_ids (:obj:`torch.Tensor`, `optional`): + The input ids of the text. If `samples` is provided, this is ignored. + attention_mask (:obj:`torch.Tensor`, `optional`): + The attention mask of the text. If `samples` is provided, this is ignored. + token_type_ids (:obj:`torch.Tensor`, `optional`): + The token type ids of the text. If `samples` is provided, this is ignored. + prediction_mask (:obj:`torch.Tensor`, `optional`): + The prediction mask of the text. If `samples` is provided, this is ignored. + special_symbols_mask (:obj:`torch.Tensor`, `optional`): + The special symbols mask of the text. If `samples` is provided, this is ignored. + max_length (:obj:`int`, `optional`, defaults to 1000): + The maximum length of the text. + max_batch_size (:obj:`int`, `optional`, defaults to 128): + The maximum batch size. + token_batch_size (:obj:`int`, `optional`): + The token batch size. + progress_bar (:obj:`bool`, `optional`, defaults to False): + Whether to show a progress bar. + precision (:obj:`str`, `optional`, defaults to 32): + The precision to use for the model. + annotation_type (:obj:`str`, `optional`, defaults to "char"): + The annotation type to use. It can be either "char", "token" or "word". + *args: + Positional arguments. + **kwargs: + Keyword arguments. + + Returns: + :obj:`List[RelikReaderSample]` or :obj:`List[List[RelikReaderSample]]`: + The predicted labels for each sample. 
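+ Note: when `samples` is provided, `self.dataset` must be set, since it is used to batch the samples and to convert token-level predictions back to annotations; otherwise a ValueError is raised.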
+ """ + + precision = precision or self.precision + if samples is not None: + + def _read_iterator(): + def samples_it(): + for i, sample in enumerate(samples): + assert sample._mixin_prediction_position is None + sample._mixin_prediction_position = i + yield sample + + next_prediction_position = 0 + position2predicted_sample = {} + + # instantiate dataset + if self.dataset is None: + raise ValueError( + "You need to pass a dataset to the model in order to predict" + ) + self.dataset.samples = samples_it() + self.dataset.model_max_length = max_length + self.dataset.tokens_per_batch = token_batch_size + self.dataset.max_batch_size = max_batch_size + + # instantiate dataloader + iterator = DataLoader( + self.dataset, batch_size=None, num_workers=0, shuffle=False + ) + if progress_bar: + iterator = tqdm(iterator, desc="Predicting with RelikReader") + + # fucking autocast only wants pure strings like 'cpu' or 'cuda' + # we need to convert the model device to that + device_type_for_autocast = str(self.device).split(":")[0] + # autocast doesn't work with CPU and stuff different from bfloat16 + autocast_mngr = ( + contextlib.nullcontext() + if device_type_for_autocast == "cpu" + else ( + torch.autocast( + device_type=device_type_for_autocast, + dtype=PRECISION_MAP[precision], + ) + ) + ) + + with autocast_mngr: + for batch in iterator: + batch = move_data_to_device(batch, self.device) + batch_out = self._batch_predict(**batch) + + for sample in batch_out: + if ( + sample._mixin_prediction_position + >= next_prediction_position + ): + position2predicted_sample[ + sample._mixin_prediction_position + ] = sample + + # yield + while next_prediction_position in position2predicted_sample: + yield position2predicted_sample[next_prediction_position] + del position2predicted_sample[next_prediction_position] + next_prediction_position += 1 + + outputs = list(_read_iterator()) + for sample in outputs: + self.dataset.merge_patches_predictions(sample) + self.dataset.convert_tokens_to_char_annotations(sample) + + else: + outputs = list( + self._batch_predict( + input_ids, + attention_mask, + token_type_ids, + prediction_mask, + special_symbols_mask, + *args, + **kwargs, + ) + ) + return outputs + + def _batch_predict( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + token_type_ids: torch.Tensor | None = None, + prediction_mask: torch.Tensor | None = None, + special_symbols_mask: torch.Tensor | None = None, + sample: List[RelikReaderSample] | None = None, + top_k: int = 5, # the amount of top-k most probable entities to predict + *args, + **kwargs, + ) -> Iterator[RelikReaderSample]: + """ + A wrapper around the forward method that returns the predicted labels for each sample. + It also adds the predicted labels to the samples. + + Args: + input_ids (:obj:`torch.Tensor`): + The input ids of the text. + attention_mask (:obj:`torch.Tensor`): + The attention mask of the text. + token_type_ids (:obj:`torch.Tensor`, `optional`): + The token type ids of the text. + prediction_mask (:obj:`torch.Tensor`, `optional`): + The prediction mask of the text. + special_symbols_mask (:obj:`torch.Tensor`, `optional`): + The special symbols mask of the text. + sample (:obj:`List[RelikReaderSample]`, `optional`): + The samples to read. If provided, `text` and `candidates` are ignored. + top_k (:obj:`int`, `optional`, defaults to 5): + The amount of top-k most probable entities to predict. + *args: + Positional arguments. + **kwargs: + Keyword arguments. + + Returns: + The predicted labels for each sample. 
+ """ + forward_output = self.forward( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + prediction_mask=prediction_mask, + special_symbols_mask=special_symbols_mask, + ) + + ned_start_predictions = forward_output["ned_start_predictions"].cpu().numpy() + ned_end_predictions = forward_output["ned_end_predictions"].cpu().numpy() + ed_predictions = forward_output["ed_predictions"].cpu().numpy() + ed_probabilities = forward_output["ed_probabilities"].cpu().numpy() + + batch_predictable_candidates = kwargs["predictable_candidates"] + patch_offset = kwargs["patch_offset"] + for ts, ne_sp, ne_ep, edp, edpr, pred_cands, po in zip( + sample, + ned_start_predictions, + ned_end_predictions, + ed_predictions, + ed_probabilities, + batch_predictable_candidates, + patch_offset, + ): + ne_start_indices = [ti for ti, c in enumerate(ne_sp[1:]) if c > 0] + ne_end_indices = [ti for ti, c in enumerate(ne_ep[1:]) if c > 0] + + final_class2predicted_spans = collections.defaultdict(list) + spans2predicted_probabilities = dict() + for start_token_index, end_token_index in zip( + ne_start_indices, ne_end_indices + ): + # predicted candidate + token_class = edp[start_token_index + 1] - 1 + predicted_candidate_title = pred_cands[token_class] + final_class2predicted_spans[predicted_candidate_title].append( + [start_token_index, end_token_index] + ) + + # candidates probabilities + classes_probabilities = edpr[start_token_index + 1] + classes_probabilities_best_indices = classes_probabilities.argsort()[ + ::-1 + ] + titles_2_probs = [] + top_k = ( + min( + top_k, + len(classes_probabilities_best_indices), + ) + if top_k != -1 + else len(classes_probabilities_best_indices) + ) + for i in range(top_k): + titles_2_probs.append( + ( + pred_cands[classes_probabilities_best_indices[i] - 1], + classes_probabilities[ + classes_probabilities_best_indices[i] + ].item(), + ) + ) + spans2predicted_probabilities[ + (start_token_index, end_token_index) + ] = titles_2_probs + + if "patches" not in ts._d: + ts._d["patches"] = dict() + + ts._d["patches"][po] = dict() + sample_patch = ts._d["patches"][po] + + sample_patch["predicted_window_labels"] = final_class2predicted_spans + sample_patch["span_title_probabilities"] = spans2predicted_probabilities + + # additional info + sample_patch["predictable_candidates"] = pred_cands + + yield ts diff --git a/relik/reader/relik_reader.py b/relik/reader/relik_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..5acd9e8c4774593c4a61245ecf92a5559ab438f2 --- /dev/null +++ b/relik/reader/relik_reader.py @@ -0,0 +1,629 @@ +import collections +import logging +from pathlib import Path +from typing import Any, Callable, Dict, Iterator, List, Union + +import torch +import transformers as tr +from tqdm import tqdm +from transformers import AutoConfig + +from relik.common.log import get_console_logger, get_logger +from relik.reader.data.relik_reader_data_utils import batchify, flatten +from relik.reader.data.relik_reader_sample import RelikReaderSample +from relik.reader.pytorch_modules.hf.modeling_relik import ( + RelikReaderConfig, + RelikReaderSpanModel, +) +from relik.reader.relik_reader_predictor import RelikReaderPredictor +from relik.reader.utils.save_load_utilities import load_model_and_conf +from relik.reader.utils.special_symbols import NME_SYMBOL, get_special_symbols + +console_logger = get_console_logger() +logger = get_logger(__name__, level=logging.INFO) + + +class RelikReaderForSpanExtraction(torch.nn.Module): + def 
__init__( + self, + transformer_model: str | tr.PreTrainedModel | None = None, + additional_special_symbols: int = 0, + num_layers: int | None = None, + activation: str = "gelu", + linears_hidden_size: int | None = 512, + use_last_k_layers: int = 1, + training: bool = False, + device: str | torch.device | None = None, + tokenizer: str | tr.PreTrainedTokenizer | None = None, + **kwargs, + ) -> None: + super().__init__() + + if isinstance(transformer_model, str): + config = AutoConfig.from_pretrained( + transformer_model, trust_remote_code=True + ) + if "relik-reader" in config.model_type: + transformer_model = RelikReaderSpanModel.from_pretrained( + transformer_model, **kwargs + ) + else: + reader_config = RelikReaderConfig( + transformer_model=transformer_model, + additional_special_symbols=additional_special_symbols, + num_layers=num_layers, + activation=activation, + linears_hidden_size=linears_hidden_size, + use_last_k_layers=use_last_k_layers, + training=training, + ) + transformer_model = RelikReaderSpanModel(reader_config) + + self.relik_reader_model = transformer_model + + self._tokenizer = tokenizer + + # move the model to the device + self.to(device or torch.device("cpu")) + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + token_type_ids: torch.Tensor, + prediction_mask: torch.Tensor | None = None, + special_symbols_mask: torch.Tensor | None = None, + special_symbols_mask_entities: torch.Tensor | None = None, + start_labels: torch.Tensor | None = None, + end_labels: torch.Tensor | None = None, + disambiguation_labels: torch.Tensor | None = None, + relation_labels: torch.Tensor | None = None, + is_validation: bool = False, + is_prediction: bool = False, + *args, + **kwargs, + ) -> Dict[str, Any]: + return self.relik_reader_model( + input_ids, + attention_mask, + token_type_ids, + prediction_mask, + special_symbols_mask, + special_symbols_mask_entities, + start_labels, + end_labels, + disambiguation_labels, + relation_labels, + is_validation, + is_prediction, + *args, + **kwargs, + ) + + def batch_predict( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + token_type_ids: torch.Tensor | None = None, + prediction_mask: torch.Tensor | None = None, + special_symbols_mask: torch.Tensor | None = None, + sample: List[RelikReaderSample] | None = None, + top_k: int = 5, # the amount of top-k most probable entities to predict + *args, + **kwargs, + ) -> Iterator[RelikReaderSample]: + """ + + + Args: + input_ids: + attention_mask: + token_type_ids: + prediction_mask: + special_symbols_mask: + sample: + top_k: + *args: + **kwargs: + + Returns: + + """ + forward_output = self.forward( + input_ids, + attention_mask, + token_type_ids, + prediction_mask, + special_symbols_mask, + ) + + ned_start_predictions = forward_output["ned_start_predictions"].cpu().numpy() + ned_end_predictions = forward_output["ned_end_predictions"].cpu().numpy() + ed_predictions = forward_output["ed_predictions"].cpu().numpy() + ed_probabilities = forward_output["ed_probabilities"].cpu().numpy() + + batch_predictable_candidates = kwargs["predictable_candidates"] + patch_offset = kwargs["patch_offset"] + for ts, ne_sp, ne_ep, edp, edpr, pred_cands, po in zip( + sample, + ned_start_predictions, + ned_end_predictions, + ed_predictions, + ed_probabilities, + batch_predictable_candidates, + patch_offset, + ): + ne_start_indices = [ti for ti, c in enumerate(ne_sp[1:]) if c > 0] + ne_end_indices = [ti for ti, c in enumerate(ne_ep[1:]) if c > 0] + + final_class2predicted_spans = 
collections.defaultdict(list) + spans2predicted_probabilities = dict() + for start_token_index, end_token_index in zip( + ne_start_indices, ne_end_indices + ): + # predicted candidate + token_class = edp[start_token_index + 1] - 1 + predicted_candidate_title = pred_cands[token_class] + final_class2predicted_spans[predicted_candidate_title].append( + [start_token_index, end_token_index] + ) + + # candidates probabilities + classes_probabilities = edpr[start_token_index + 1] + classes_probabilities_best_indices = classes_probabilities.argsort()[ + ::-1 + ] + titles_2_probs = [] + top_k = ( + min( + top_k, + len(classes_probabilities_best_indices), + ) + if top_k != -1 + else len(classes_probabilities_best_indices) + ) + for i in range(top_k): + titles_2_probs.append( + ( + pred_cands[classes_probabilities_best_indices[i] - 1], + classes_probabilities[ + classes_probabilities_best_indices[i] + ].item(), + ) + ) + spans2predicted_probabilities[ + (start_token_index, end_token_index) + ] = titles_2_probs + + if "patches" not in ts._d: + ts._d["patches"] = dict() + + ts._d["patches"][po] = dict() + sample_patch = ts._d["patches"][po] + + sample_patch["predicted_window_labels"] = final_class2predicted_spans + sample_patch["span_title_probabilities"] = spans2predicted_probabilities + + # additional info + sample_patch["predictable_candidates"] = pred_cands + + yield ts + + def _build_input(self, text: List[str], candidates: List[List[str]]) -> list[str]: + candidates_symbols = get_special_symbols(len(candidates)) + candidates = [ + [cs, ct] if ct != NME_SYMBOL else [NME_SYMBOL] + for cs, ct in zip(candidates_symbols, candidates) + ] + return ( + [self.tokenizer.cls_token] + + text + + [self.tokenizer.sep_token] + + flatten(candidates) + + [self.tokenizer.sep_token] + ) + + @staticmethod + def _compute_offsets(offsets_mapping): + offsets_mapping = offsets_mapping.numpy() + token2word = [] + word2token = {} + count = 0 + for i, offset in enumerate(offsets_mapping): + if offset[0] == 0: + token2word.append(i - count) + word2token[i - count] = [i] + else: + token2word.append(token2word[-1]) + word2token[token2word[-1]].append(i) + count += 1 + return token2word, word2token + + @staticmethod + def _convert_tokens_to_word_annotations(sample: RelikReaderSample): + triplets = [] + entities = [] + for entity in sample.predicted_entities: + if sample.entity_candidates: + entities.append( + ( + sample.token2word[entity[0] - 1], + sample.token2word[entity[1] - 1] + 1, + sample.entity_candidates[entity[2]], + ) + ) + else: + entities.append( + ( + sample.token2word[entity[0] - 1], + sample.token2word[entity[1] - 1] + 1, + -1, + ) + ) + for predicted_triplet, predicted_triplet_probabilities in zip( + sample.predicted_relations, sample.predicted_relations_probabilities + ): + subject, object_, relation = predicted_triplet + subject = entities[subject] + object_ = entities[object_] + relation = sample.candidates[relation] + triplets.append( + { + "subject": { + "start": subject[0], + "end": subject[1], + "type": subject[2], + "name": " ".join(sample.tokens[subject[0] : subject[1]]), + }, + "relation": { + "name": relation, + "probability": float(predicted_triplet_probabilities.round(2)), + }, + "object": { + "start": object_[0], + "end": object_[1], + "type": object_[2], + "name": " ".join(sample.tokens[object_[0] : object_[1]]), + }, + } + ) + sample.predicted_entities = entities + sample.predicted_relations = triplets + sample.predicted_relations_probabilities = None + + @torch.no_grad() + 
@torch.inference_mode() + def read( + self, + text: List[str] | List[List[str]] | None = None, + samples: List[RelikReaderSample] | None = None, + input_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + token_type_ids: torch.Tensor | None = None, + prediction_mask: torch.Tensor | None = None, + special_symbols_mask: torch.Tensor | None = None, + special_symbols_mask_entities: torch.Tensor | None = None, + candidates: List[List[str]] | None = None, + max_length: int | None = 1024, + max_batch_size: int | None = 64, + token_batch_size: int | None = None, + progress_bar: bool = False, + *args, + **kwargs, + ) -> List[List[RelikReaderSample]]: + """ + Reads the given text. + Args: + text: The text to read in tokens. + samples: + input_ids: The input ids of the text. + attention_mask: The attention mask of the text. + token_type_ids: The token type ids of the text. + prediction_mask: The prediction mask of the text. + special_symbols_mask: The special symbols mask of the text. + special_symbols_mask_entities: The special symbols mask entities of the text. + candidates: The candidates of the text. + max_length: The maximum length of the text. + max_batch_size: The maximum batch size. + token_batch_size: The maximum number of tokens per batch. + progress_bar: + Returns: + The predicted labels for each sample. + """ + if text is None and input_ids is None and samples is None: + raise ValueError( + "Either `text` or `input_ids` or `samples` must be provided." + ) + if (input_ids is None and samples is None) and ( + text is None or candidates is None + ): + raise ValueError( + "`text` and `candidates` must be provided to return the predictions when " + "`input_ids` and `samples` is not provided." + ) + if text is not None and samples is None: + if len(text) != len(candidates): + raise ValueError("`text` and `candidates` must have the same length.") + if isinstance(text[0], str): # change to list of text + text = [text] + candidates = [candidates] + + samples = [ + RelikReaderSample(tokens=t, candidates=c) + for t, c in zip(text, candidates) + ] + + if samples is not None: + # function that creates a batch from the 'current_batch' list + def output_batch() -> Dict[str, Any]: + assert ( + len( + set( + [ + len(elem["predictable_candidates"]) + for elem in current_batch + ] + ) + ) + == 1 + ), " ".join( + map( + str, + [len(elem["predictable_candidates"]) for elem in current_batch], + ) + ) + + batch_dict = dict() + + de_values_by_field = { + fn: [de[fn] for de in current_batch if fn in de] + for fn in self.fields_batcher + } + + # in case you provide fields batchers but in the batch + # there are no elements for that field + de_values_by_field = { + fn: fvs for fn, fvs in de_values_by_field.items() if len(fvs) > 0 + } + + assert len(set([len(v) for v in de_values_by_field.values()])) + + # todo: maybe we should report the user about possible + # fields filtering due to "None" instances + de_values_by_field = { + fn: fvs + for fn, fvs in de_values_by_field.items() + if all([fv is not None for fv in fvs]) + } + + for field_name, field_values in de_values_by_field.items(): + field_batch = ( + self.fields_batcher[field_name]([fv[0] for fv in field_values]) + if self.fields_batcher[field_name] is not None + else field_values + ) + + batch_dict[field_name] = field_batch + + batch_dict = { + k: v.to(self.device) if isinstance(v, torch.Tensor) else v + for k, v in batch_dict.items() + } + return batch_dict + + current_batch = [] + predictions = [] + current_cand_len = -1 + + for 
sample in tqdm(samples, disable=not progress_bar): + sample.candidates = [NME_SYMBOL] + sample.candidates + inputs_text = self._build_input(sample.tokens, sample.candidates) + model_inputs = self.tokenizer( + inputs_text, + is_split_into_words=True, + add_special_tokens=False, + padding=False, + truncation=True, + max_length=max_length or self.tokenizer.model_max_length, + return_offsets_mapping=True, + return_tensors="pt", + ) + model_inputs["special_symbols_mask"] = ( + model_inputs["input_ids"] > self.tokenizer.vocab_size + ) + # prediction mask is 0 until the first special symbol + model_inputs["token_type_ids"] = ( + torch.cumsum(model_inputs["special_symbols_mask"], dim=1) > 0 + ).long() + # shift prediction_mask to the left + model_inputs["prediction_mask"] = model_inputs["token_type_ids"].roll( + shifts=-1, dims=1 + ) + model_inputs["prediction_mask"][:, -1] = 1 + model_inputs["prediction_mask"][:, 0] = 1 + + assert ( + len(model_inputs["special_symbols_mask"]) + == len(model_inputs["prediction_mask"]) + == len(model_inputs["input_ids"]) + ) + + model_inputs["sample"] = sample + + # compute cand_len using special_symbols_mask + model_inputs["predictable_candidates"] = sample.candidates[ + : model_inputs["special_symbols_mask"].sum().item() + ] + # cand_len = sum([id_ > self.tokenizer.vocab_size for id_ in model_inputs["input_ids"]]) + offsets = model_inputs.pop("offset_mapping") + offsets = offsets[model_inputs["prediction_mask"] == 0] + sample.token2word, sample.word2token = self._compute_offsets(offsets) + future_max_len = max( + len(model_inputs["input_ids"]), + max([len(b["input_ids"]) for b in current_batch], default=0), + ) + future_tokens_per_batch = future_max_len * (len(current_batch) + 1) + + if len(current_batch) > 0 and ( + ( + len(model_inputs["predictable_candidates"]) != current_cand_len + and current_cand_len != -1 + ) + or ( + isinstance(token_batch_size, int) + and future_tokens_per_batch >= token_batch_size + ) + or len(current_batch) == max_batch_size + ): + batch_inputs = output_batch() + current_batch = [] + predictions.extend(list(self.batch_predict(**batch_inputs))) + current_cand_len = len(model_inputs["predictable_candidates"]) + current_batch.append(model_inputs) + + if current_batch: + batch_inputs = output_batch() + predictions.extend(list(self.batch_predict(**batch_inputs))) + else: + predictions = list( + self.batch_predict( + input_ids, + attention_mask, + token_type_ids, + prediction_mask, + special_symbols_mask, + special_symbols_mask_entities, + *args, + **kwargs, + ) + ) + return predictions + + @property + def device(self) -> torch.device: + """ + The device of the model. + """ + return next(self.parameters()).device + + @property + def tokenizer(self) -> tr.PreTrainedTokenizer: + """ + The tokenizer. 
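+ Lazily loaded from the underlying model's `name_or_path` on first access.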
+ """ + if self._tokenizer: + return self._tokenizer + + self._tokenizer = tr.AutoTokenizer.from_pretrained( + self.relik_reader_model.config.name_or_path + ) + return self._tokenizer + + @property + def fields_batcher(self) -> Dict[str, Union[None, Callable[[list], Any]]]: + fields_batchers = { + "input_ids": lambda x: batchify( + x, padding_value=self.tokenizer.pad_token_id + ), + "attention_mask": lambda x: batchify(x, padding_value=0), + "token_type_ids": lambda x: batchify(x, padding_value=0), + "prediction_mask": lambda x: batchify(x, padding_value=1), + "global_attention": lambda x: batchify(x, padding_value=0), + "token2word": None, + "sample": None, + "special_symbols_mask": lambda x: batchify(x, padding_value=False), + "special_symbols_mask_entities": lambda x: batchify(x, padding_value=False), + } + if "roberta" in self.relik_reader_model.config.model_type: + del fields_batchers["token_type_ids"] + + return fields_batchers + + def save_pretrained( + self, + output_dir: str, + model_name: str | None = None, + push_to_hub: bool = False, + **kwargs, + ) -> None: + """ + Saves the model to the given path. + Args: + output_dir: The path to save the model to. + model_name: The name of the model. + push_to_hub: Whether to push the model to the hub. + """ + # create the output directory + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + model_name = model_name or "relik-reader-for-span-extraction" + + logger.info(f"Saving reader to {output_dir / model_name}") + + # save the model + self.relik_reader_model.register_for_auto_class() + self.relik_reader_model.save_pretrained( + output_dir / model_name, push_to_hub=push_to_hub, **kwargs + ) + + logger.info("Saving reader to disk done.") + + if self.tokenizer: + self.tokenizer.save_pretrained( + output_dir / model_name, push_to_hub=push_to_hub, **kwargs + ) + logger.info("Saving tokenizer to disk done.") + + +class RelikReader: + def __init__(self, model_path: str, predict_nmes: bool = False): + model, model_conf = load_model_and_conf(model_path) + model.training = False + model.eval() + + val_dataset_conf = model_conf.data.val_dataset + val_dataset_conf.special_symbols = get_special_symbols( + model_conf.model.entities_per_forward + ) + val_dataset_conf.transformer_model = model_conf.model.model.transformer_model + + self.predictor = RelikReaderPredictor( + model, + dataset_conf=model_conf.data.val_dataset, + predict_nmes=predict_nmes, + ) + self.model_path = model_path + + def link_entities( + self, + dataset_path_or_samples: str | Iterator[RelikReaderSample], + token_batch_size: int = 2048, + progress_bar: bool = False, + ) -> List[RelikReaderSample]: + data_input = ( + (dataset_path_or_samples, None) + if isinstance(dataset_path_or_samples, str) + else (None, dataset_path_or_samples) + ) + return self.predictor.predict( + *data_input, + dataset_conf=None, + token_batch_size=token_batch_size, + progress_bar=progress_bar, + ) + + # def save_pretrained(self, path: Union[str, Path]): + # self.predictor.save(path) + + +def main(): + rr = RelikReader("riccorl/relik-reader-aida-deberta-small-old", predict_nmes=True) + predictions = rr.link_entities( + "/Users/ric/Documents/PhD/Projects/relik/data/reader/aida/testa.jsonl" + ) + print(predictions) + + +if __name__ == "__main__": + main() diff --git a/relik/reader/relik_reader_core.py b/relik/reader/relik_reader_core.py new file mode 100644 index 0000000000000000000000000000000000000000..1d62c5f13b3c1f7e7ba02209d2c88813d4f960ac --- /dev/null +++ 
b/relik/reader/relik_reader_core.py @@ -0,0 +1,497 @@ +import collections +from typing import Any, Dict, Iterator, List, Optional + +import torch +from transformers import AutoModel +from transformers.activations import ClippedGELUActivation, GELUActivation +from transformers.modeling_utils import PoolerEndLogits + +from relik.reader.data.relik_reader_sample import RelikReaderSample + +activation2functions = { + "relu": torch.nn.ReLU(), + "gelu": GELUActivation(), + "gelu_10": ClippedGELUActivation(-10, 10), +} + + +class RelikReaderCoreModel(torch.nn.Module): + def __init__( + self, + transformer_model: str, + additional_special_symbols: int, + num_layers: Optional[int] = None, + activation: str = "gelu", + linears_hidden_size: Optional[int] = 512, + use_last_k_layers: int = 1, + training: bool = False, + ) -> None: + super().__init__() + + # Transformer model declaration + self.transformer_model_name = transformer_model + self.transformer_model = ( + AutoModel.from_pretrained(transformer_model) + if num_layers is None + else AutoModel.from_pretrained( + transformer_model, num_hidden_layers=num_layers + ) + ) + self.transformer_model.resize_token_embeddings( + self.transformer_model.config.vocab_size + additional_special_symbols + ) + + self.activation = activation + self.linears_hidden_size = linears_hidden_size + self.use_last_k_layers = use_last_k_layers + + # named entity detection layers + self.ned_start_classifier = self._get_projection_layer( + self.activation, last_hidden=2, layer_norm=False + ) + self.ned_end_classifier = PoolerEndLogits(self.transformer_model.config) + + # END entity disambiguation layer + self.ed_start_projector = self._get_projection_layer(self.activation) + self.ed_end_projector = self._get_projection_layer(self.activation) + + self.training = training + + # criterion + self.criterion = torch.nn.CrossEntropyLoss() + + def _get_projection_layer( + self, + activation: str, + last_hidden: Optional[int] = None, + input_hidden=None, + layer_norm: bool = True, + ) -> torch.nn.Sequential: + head_components = [ + torch.nn.Dropout(0.1), + torch.nn.Linear( + self.transformer_model.config.hidden_size * self.use_last_k_layers + if input_hidden is None + else input_hidden, + self.linears_hidden_size, + ), + activation2functions[activation], + torch.nn.Dropout(0.1), + torch.nn.Linear( + self.linears_hidden_size, + self.linears_hidden_size if last_hidden is None else last_hidden, + ), + ] + + if layer_norm: + head_components.append( + torch.nn.LayerNorm( + self.linears_hidden_size if last_hidden is None else last_hidden, + self.transformer_model.config.layer_norm_eps, + ) + ) + + return torch.nn.Sequential(*head_components) + + def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + mask = mask.unsqueeze(-1) + if next(self.parameters()).dtype == torch.float16: + logits = logits * (1 - mask) - 65500 * mask + else: + logits = logits * (1 - mask) - 1e30 * mask + return logits + + def _get_model_features( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + token_type_ids: Optional[torch.Tensor], + ): + model_input = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "output_hidden_states": self.use_last_k_layers > 1, + } + + if token_type_ids is not None: + model_input["token_type_ids"] = token_type_ids + + model_output = self.transformer_model(**model_input) + + if self.use_last_k_layers > 1: + model_features = torch.cat( + model_output[1][-self.use_last_k_layers :], dim=-1 + ) + else: + model_features = 
model_output[0] + + return model_features + + def compute_ned_end_logits( + self, + start_predictions, + start_labels, + model_features, + prediction_mask, + batch_size, + ) -> Optional[torch.Tensor]: + # todo: maybe when constraining on the spans, + # we should not use a prediction_mask for the end tokens. + # at least we should not during training imo + start_positions = start_labels if self.training else start_predictions + start_positions_indices = ( + torch.arange(start_positions.size(1), device=start_positions.device) + .unsqueeze(0) + .expand(batch_size, -1)[start_positions > 0] + ).to(start_positions.device) + + if len(start_positions_indices) > 0: + expanded_features = torch.cat( + [ + model_features[i].unsqueeze(0).expand(x, -1, -1) + for i, x in enumerate(torch.sum(start_positions > 0, dim=-1)) + if x > 0 + ], + dim=0, + ).to(start_positions_indices.device) + + expanded_prediction_mask = torch.cat( + [ + prediction_mask[i].unsqueeze(0).expand(x, -1) + for i, x in enumerate(torch.sum(start_positions > 0, dim=-1)) + if x > 0 + ], + dim=0, + ).to(expanded_features.device) + + end_logits = self.ned_end_classifier( + hidden_states=expanded_features, + start_positions=start_positions_indices, + p_mask=expanded_prediction_mask, + ) + + return end_logits + + return None + + def compute_classification_logits( + self, + model_features, + special_symbols_mask, + prediction_mask, + batch_size, + start_positions=None, + end_positions=None, + ) -> torch.Tensor: + if start_positions is None or end_positions is None: + start_positions = torch.zeros_like(prediction_mask) + end_positions = torch.zeros_like(prediction_mask) + + model_start_features = self.ed_start_projector(model_features) + model_end_features = self.ed_end_projector(model_features) + model_end_features[start_positions > 0] = model_end_features[end_positions > 0] + + model_ed_features = torch.cat( + [model_start_features, model_end_features], dim=-1 + ) + + # computing ed features + classes_representations = torch.sum(special_symbols_mask, dim=1)[0].item() + special_symbols_representation = model_ed_features[special_symbols_mask].view( + batch_size, classes_representations, -1 + ) + + logits = torch.bmm( + model_ed_features, + torch.permute(special_symbols_representation, (0, 2, 1)), + ) + + logits = self._mask_logits(logits, prediction_mask) + + return logits + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + token_type_ids: Optional[torch.Tensor] = None, + prediction_mask: Optional[torch.Tensor] = None, + special_symbols_mask: Optional[torch.Tensor] = None, + start_labels: Optional[torch.Tensor] = None, + end_labels: Optional[torch.Tensor] = None, + use_predefined_spans: bool = False, + *args, + **kwargs, + ) -> Dict[str, Any]: + batch_size, seq_len = input_ids.shape + + model_features = self._get_model_features( + input_ids, attention_mask, token_type_ids + ) + + # named entity detection if required + if use_predefined_spans: # no need to compute spans + ned_start_logits, ned_start_probabilities, ned_start_predictions = ( + None, + None, + torch.clone(start_labels) + if start_labels is not None + else torch.zeros_like(input_ids), + ) + ned_end_logits, ned_end_probabilities, ned_end_predictions = ( + None, + None, + torch.clone(end_labels) + if end_labels is not None + else torch.zeros_like(input_ids), + ) + + ned_start_predictions[ned_start_predictions > 0] = 1 + ned_end_predictions[ned_end_predictions > 0] = 1 + + else: # compute spans + # start boundary prediction + ned_start_logits = 
self.ned_start_classifier(model_features) + ned_start_logits = self._mask_logits(ned_start_logits, prediction_mask) + ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1) + ned_start_predictions = ned_start_probabilities.argmax(dim=-1) + + # end boundary prediction + ned_start_labels = ( + torch.zeros_like(start_labels) if start_labels is not None else None + ) + + if ned_start_labels is not None: + ned_start_labels[start_labels == -100] = -100 + ned_start_labels[start_labels > 0] = 1 + + ned_end_logits = self.compute_ned_end_logits( + ned_start_predictions, + ned_start_labels, + model_features, + prediction_mask, + batch_size, + ) + + if ned_end_logits is not None: + ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1) + ned_end_predictions = torch.argmax(ned_end_probabilities, dim=-1) + else: + ned_end_logits, ned_end_probabilities = None, None + ned_end_predictions = ned_start_predictions.new_zeros(batch_size) + + # flattening end predictions + # (flattening can happen only if the + # end boundaries were not predicted using the gold labels) + if not self.training: + flattened_end_predictions = torch.clone(ned_start_predictions) + flattened_end_predictions[flattened_end_predictions > 0] = 0 + + batch_start_predictions = list() + for elem_idx in range(batch_size): + batch_start_predictions.append( + torch.where(ned_start_predictions[elem_idx] > 0)[0].tolist() + ) + + # check that the total number of start predictions + # is equal to the end predictions + total_start_predictions = sum(map(len, batch_start_predictions)) + total_end_predictions = len(ned_end_predictions) + assert ( + total_start_predictions == 0 + or total_start_predictions == total_end_predictions + ), ( + f"Total number of start predictions = {total_start_predictions}. 
" + f"Total number of end predictions = {total_end_predictions}" + ) + + curr_end_pred_num = 0 + for elem_idx, bsp in enumerate(batch_start_predictions): + for sp in bsp: + ep = ned_end_predictions[curr_end_pred_num].item() + if ep < sp: + ep = sp + + # if we already set this span throw it (no overlap) + if flattened_end_predictions[elem_idx, ep] == 1: + ned_start_predictions[elem_idx, sp] = 0 + else: + flattened_end_predictions[elem_idx, ep] = 1 + + curr_end_pred_num += 1 + + ned_end_predictions = flattened_end_predictions + + start_position, end_position = ( + (start_labels, end_labels) + if self.training + else (ned_start_predictions, ned_end_predictions) + ) + + # Entity disambiguation + ed_logits = self.compute_classification_logits( + model_features, + special_symbols_mask, + prediction_mask, + batch_size, + start_position, + end_position, + ) + ed_probabilities = torch.softmax(ed_logits, dim=-1) + ed_predictions = torch.argmax(ed_probabilities, dim=-1) + + # output build + output_dict = dict( + batch_size=batch_size, + ned_start_logits=ned_start_logits, + ned_start_probabilities=ned_start_probabilities, + ned_start_predictions=ned_start_predictions, + ned_end_logits=ned_end_logits, + ned_end_probabilities=ned_end_probabilities, + ned_end_predictions=ned_end_predictions, + ed_logits=ed_logits, + ed_probabilities=ed_probabilities, + ed_predictions=ed_predictions, + ) + + # compute loss if labels + if start_labels is not None and end_labels is not None and self.training: + # named entity detection loss + + # start + if ned_start_logits is not None: + ned_start_loss = self.criterion( + ned_start_logits.view(-1, ned_start_logits.shape[-1]), + ned_start_labels.view(-1), + ) + else: + ned_start_loss = 0 + + # end + if ned_end_logits is not None: + ned_end_labels = torch.zeros_like(end_labels) + ned_end_labels[end_labels == -100] = -100 + ned_end_labels[end_labels > 0] = 1 + + ned_end_loss = self.criterion( + ned_end_logits, + ( + torch.arange( + ned_end_labels.size(1), device=ned_end_labels.device + ) + .unsqueeze(0) + .expand(batch_size, -1)[ned_end_labels > 0] + ).to(ned_end_labels.device), + ) + + else: + ned_end_loss = 0 + + # entity disambiguation loss + start_labels[ned_start_labels != 1] = -100 + ed_labels = torch.clone(start_labels) + ed_labels[end_labels > 0] = end_labels[end_labels > 0] + ed_loss = self.criterion( + ed_logits.view(-1, ed_logits.shape[-1]), + ed_labels.view(-1), + ) + + output_dict["ned_start_loss"] = ned_start_loss + output_dict["ned_end_loss"] = ned_end_loss + output_dict["ed_loss"] = ed_loss + + output_dict["loss"] = ned_start_loss + ned_end_loss + ed_loss + + return output_dict + + def batch_predict( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + token_type_ids: Optional[torch.Tensor] = None, + prediction_mask: Optional[torch.Tensor] = None, + special_symbols_mask: Optional[torch.Tensor] = None, + sample: Optional[List[RelikReaderSample]] = None, + top_k: int = 5, # the amount of top-k most probable entities to predict + *args, + **kwargs, + ) -> Iterator[RelikReaderSample]: + forward_output = self.forward( + input_ids, + attention_mask, + token_type_ids, + prediction_mask, + special_symbols_mask, + ) + + ned_start_predictions = forward_output["ned_start_predictions"].cpu().numpy() + ned_end_predictions = forward_output["ned_end_predictions"].cpu().numpy() + ed_predictions = forward_output["ed_predictions"].cpu().numpy() + ed_probabilities = forward_output["ed_probabilities"].cpu().numpy() + + batch_predictable_candidates = 
kwargs["predictable_candidates"] + patch_offset = kwargs["patch_offset"] + for ts, ne_sp, ne_ep, edp, edpr, pred_cands, po in zip( + sample, + ned_start_predictions, + ned_end_predictions, + ed_predictions, + ed_probabilities, + batch_predictable_candidates, + patch_offset, + ): + ne_start_indices = [ti for ti, c in enumerate(ne_sp[1:]) if c > 0] + ne_end_indices = [ti for ti, c in enumerate(ne_ep[1:]) if c > 0] + + final_class2predicted_spans = collections.defaultdict(list) + spans2predicted_probabilities = dict() + for start_token_index, end_token_index in zip( + ne_start_indices, ne_end_indices + ): + # predicted candidate + token_class = edp[start_token_index + 1] - 1 + predicted_candidate_title = pred_cands[token_class] + final_class2predicted_spans[predicted_candidate_title].append( + [start_token_index, end_token_index] + ) + + # candidates probabilities + classes_probabilities = edpr[start_token_index + 1] + classes_probabilities_best_indices = classes_probabilities.argsort()[ + ::-1 + ] + titles_2_probs = [] + top_k = ( + min( + top_k, + len(classes_probabilities_best_indices), + ) + if top_k != -1 + else len(classes_probabilities_best_indices) + ) + for i in range(top_k): + titles_2_probs.append( + ( + pred_cands[classes_probabilities_best_indices[i] - 1], + classes_probabilities[ + classes_probabilities_best_indices[i] + ].item(), + ) + ) + spans2predicted_probabilities[ + (start_token_index, end_token_index) + ] = titles_2_probs + + if "patches" not in ts._d: + ts._d["patches"] = dict() + + ts._d["patches"][po] = dict() + sample_patch = ts._d["patches"][po] + + sample_patch["predicted_window_labels"] = final_class2predicted_spans + sample_patch["span_title_probabilities"] = spans2predicted_probabilities + + # additional info + sample_patch["predictable_candidates"] = pred_cands + + yield ts diff --git a/relik/reader/relik_reader_predictor.py b/relik/reader/relik_reader_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..e5635d477d67febce3c937ef1945900f004bb269 --- /dev/null +++ b/relik/reader/relik_reader_predictor.py @@ -0,0 +1,168 @@ +import logging +from typing import Iterable, Iterator, List, Optional + +import hydra +import torch +from lightning.pytorch.utilities import move_data_to_device +from torch.utils.data import DataLoader +from tqdm import tqdm + +from relik.reader.data.patches import merge_patches_predictions +from relik.reader.data.relik_reader_sample import ( + RelikReaderSample, + load_relik_reader_samples, +) +from relik.reader.relik_reader_core import RelikReaderCoreModel +from relik.reader.utils.special_symbols import NME_SYMBOL + +logger = logging.getLogger(__name__) + + +def convert_tokens_to_char_annotations( + sample: RelikReaderSample, remove_nmes: bool = False +): + char_annotations = set() + + for ( + predicted_entity, + predicted_spans, + ) in sample.predicted_window_labels.items(): + if predicted_entity == NME_SYMBOL and remove_nmes: + continue + + for span_start, span_end in predicted_spans: + span_start = sample.token2char_start[str(span_start)] + span_end = sample.token2char_end[str(span_end)] + + char_annotations.add((span_start, span_end, predicted_entity)) + + char_probs_annotations = dict() + for ( + span_start, + span_end, + ), candidates_probs in sample.span_title_probabilities.items(): + span_start = sample.token2char_start[str(span_start)] + span_end = sample.token2char_end[str(span_end)] + char_probs_annotations[(span_start, span_end)] = { + title for title, _ in candidates_probs + } + + 
sample.predicted_window_labels_chars = char_annotations + sample.probs_window_labels_chars = char_probs_annotations + + +class RelikReaderPredictor: + def __init__( + self, + relik_reader_core: RelikReaderCoreModel, + dataset_conf: Optional[dict] = None, + predict_nmes: bool = False, + ) -> None: + self.relik_reader_core = relik_reader_core + self.dataset_conf = dataset_conf + self.predict_nmes = predict_nmes + + if self.dataset_conf is not None: + # instantiate dataset + self.dataset = hydra.utils.instantiate( + dataset_conf, + dataset_path=None, + samples=None, + ) + + def predict( + self, + path: Optional[str], + samples: Optional[Iterable[RelikReaderSample]], + dataset_conf: Optional[dict], + token_batch_size: int = 1024, + progress_bar: bool = False, + **kwargs, + ) -> List[RelikReaderSample]: + annotated_samples = list( + self._predict(path, samples, dataset_conf, token_batch_size, progress_bar) + ) + for sample in annotated_samples: + merge_patches_predictions(sample) + convert_tokens_to_char_annotations( + sample, remove_nmes=not self.predict_nmes + ) + return annotated_samples + + def _predict( + self, + path: Optional[str], + samples: Optional[Iterable[RelikReaderSample]], + dataset_conf: dict, + token_batch_size: int = 1024, + progress_bar: bool = False, + **kwargs, + ) -> Iterator[RelikReaderSample]: + assert ( + path is not None or samples is not None + ), "Either predict on a path or on an iterable of samples" + + samples = load_relik_reader_samples(path) if samples is None else samples + + # setup infrastructure to re-yield in order + def samples_it(): + for i, sample in enumerate(samples): + assert sample._mixin_prediction_position is None + sample._mixin_prediction_position = i + yield sample + + next_prediction_position = 0 + position2predicted_sample = {} + + # instantiate dataset + if getattr(self, "dataset", None) is not None: + dataset = self.dataset + dataset.samples = samples_it() + dataset.tokens_per_batch = token_batch_size + else: + dataset = hydra.utils.instantiate( + dataset_conf, + dataset_path=None, + samples=samples_it(), + tokens_per_batch=token_batch_size, + ) + + # instantiate dataloader + iterator = DataLoader(dataset, batch_size=None, num_workers=0, shuffle=False) + if progress_bar: + iterator = tqdm(iterator, desc="Predicting") + + model_device = next(self.relik_reader_core.parameters()).device + + with torch.inference_mode(): + for batch in iterator: + # do batch predict + with torch.autocast( + "cpu" if model_device == torch.device("cpu") else "cuda" + ): + batch = move_data_to_device(batch, model_device) + batch_out = self.relik_reader_core.batch_predict(**batch) + # update prediction position position + for sample in batch_out: + if sample._mixin_prediction_position >= next_prediction_position: + position2predicted_sample[ + sample._mixin_prediction_position + ] = sample + + # yield + while next_prediction_position in position2predicted_sample: + yield position2predicted_sample[next_prediction_position] + del position2predicted_sample[next_prediction_position] + next_prediction_position += 1 + + if len(position2predicted_sample) > 0: + logger.warning( + "It seems samples have been discarded in your dataset. " + "This means that you WON'T have a prediction for each input sample. 
" + "Prediction order will also be partially disrupted" + ) + for k, v in sorted(position2predicted_sample.items(), key=lambda x: x[0]): + yield v + + if progress_bar: + iterator.close() diff --git a/relik/reader/relik_reader_re.py b/relik/reader/relik_reader_re.py new file mode 100644 index 0000000000000000000000000000000000000000..d1efaa87110863901c522b277d6e594989b21997 --- /dev/null +++ b/relik/reader/relik_reader_re.py @@ -0,0 +1,556 @@ +import logging +from pathlib import Path +from typing import Any, Callable, Dict, Iterator, List, Optional, Union + +import numpy as np +import torch +import transformers as tr +from reader.data.relik_reader_data_utils import batchify, flatten +from reader.data.relik_reader_sample import RelikReaderSample +from reader.pytorch_modules.hf.modeling_relik import ( + RelikReaderConfig, + RelikReaderREModel, +) +from tqdm import tqdm +from transformers import AutoConfig + +from relik.common.log import get_console_logger, get_logger +from relik.reader.utils.special_symbols import NME_SYMBOL, get_special_symbols_re + +console_logger = get_console_logger() +logger = get_logger(__name__, level=logging.INFO) + + +class RelikReaderForTripletExtraction(torch.nn.Module): + def __init__( + self, + transformer_model: Optional[Union[str, tr.PreTrainedModel]] = None, + additional_special_symbols: Optional[int] = 0, + num_layers: Optional[int] = None, + activation: str = "gelu", + linears_hidden_size: Optional[int] = 512, + use_last_k_layers: int = 1, + training: bool = False, + device: Optional[Union[str, torch.device]] = None, + tokenizer: Optional[Union[str, tr.PreTrainedTokenizer]] = None, + **kwargs, + ) -> None: + super().__init__() + + if isinstance(transformer_model, str): + config = AutoConfig.from_pretrained( + transformer_model, trust_remote_code=True + ) + if "relik_reader" in config.model_type: + transformer_model = RelikReaderREModel.from_pretrained( + transformer_model, **kwargs + ) + else: + reader_config = RelikReaderConfig( + transformer_model=transformer_model, + additional_special_symbols=additional_special_symbols, + num_layers=num_layers, + activation=activation, + linears_hidden_size=linears_hidden_size, + use_last_k_layers=use_last_k_layers, + training=training, + ) + transformer_model = RelikReaderREModel(reader_config) + + self.relik_reader_re_model = transformer_model + + self._tokenizer = tokenizer + + # move the model to the device + self.to(device or torch.device("cpu")) + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + token_type_ids: torch.Tensor, + prediction_mask: Optional[torch.Tensor] = None, + special_symbols_mask: Optional[torch.Tensor] = None, + special_symbols_mask_entities: Optional[torch.Tensor] = None, + start_labels: Optional[torch.Tensor] = None, + end_labels: Optional[torch.Tensor] = None, + disambiguation_labels: Optional[torch.Tensor] = None, + relation_labels: Optional[torch.Tensor] = None, + is_validation: bool = False, + is_prediction: bool = False, + *args, + **kwargs, + ) -> Dict[str, Any]: + return self.relik_reader_re_model( + input_ids, + attention_mask, + token_type_ids, + prediction_mask, + special_symbols_mask, + special_symbols_mask_entities, + start_labels, + end_labels, + disambiguation_labels, + relation_labels, + is_validation, + is_prediction, + *args, + **kwargs, + ) + + def batch_predict( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + token_type_ids: Optional[torch.Tensor] = None, + prediction_mask: Optional[torch.Tensor] = None, + 
special_symbols_mask: Optional[torch.Tensor] = None, + special_symbols_mask_entities: Optional[torch.Tensor] = None, + sample: Optional[List[RelikReaderSample]] = None, + *args, + **kwargs, + ) -> Iterator[RelikReaderSample]: + """ + Predicts the labels for a batch of samples. + Args: + input_ids: The input ids of the batch. + attention_mask: The attention mask of the batch. + token_type_ids: The token type ids of the batch. + prediction_mask: The prediction mask of the batch. + special_symbols_mask: The special symbols mask of the batch. + special_symbols_mask_entities: The special symbols mask entities of the batch. + sample: The samples of the batch. + Returns: + The predicted labels for each sample. + """ + forward_output = self.forward( + input_ids, + attention_mask, + token_type_ids, + prediction_mask, + special_symbols_mask, + special_symbols_mask_entities, + is_prediction=True, + ) + ned_start_predictions = forward_output["ned_start_predictions"].cpu().numpy() + ned_end_predictions = forward_output["ned_end_predictions"] # .cpu().numpy() + ed_predictions = forward_output["re_entities_predictions"].cpu().numpy() + ned_type_predictions = forward_output["ned_type_predictions"].cpu().numpy() + re_predictions = forward_output["re_predictions"].cpu().numpy() + re_probabilities = forward_output["re_probabilities"].detach().cpu().numpy() + if sample is None: + sample = [RelikReaderSample() for _ in range(len(input_ids))] + for ts, ne_st, ne_end, re_pred, re_prob, edp, ne_et in zip( + sample, + ned_start_predictions, + ned_end_predictions, + re_predictions, + re_probabilities, + ed_predictions, + ned_type_predictions, + ): + ne_end = ne_end.cpu().numpy() + entities = [] + if self.relik_reader_re_model.entity_type_loss: + starts = np.argwhere(ne_st) + i = 0 + for start, end in zip(starts, ne_end): + ends = np.argwhere(end) + for e in ends: + entities.append([start[0], e[0], ne_et[i]]) + i += 1 + else: + starts = np.argwhere(ne_st) + for start, end in zip(starts, ne_end): + ends = np.argwhere(end) + for e in ends: + entities.append([start[0], e[0]]) + + edp = edp[: len(entities)] + re_pred = re_pred[: len(entities), : len(entities)] + re_prob = re_prob[: len(entities), : len(entities)] + possible_re = np.argwhere(re_pred) + predicted_triplets = [] + predicted_triplets_prob = [] + for i, j, r in possible_re: + if self.relik_reader_re_model.relation_disambiguation_loss: + if not ( + i != j + and edp[i, r] == 1 + and edp[j, r] == 1 + and edp[i, 0] == 0 + and edp[j, 0] == 0 + ): + continue + predicted_triplets.append([i, j, r]) + predicted_triplets_prob.append(re_prob[i, j, r]) + + ts._d["predicted_relations"] = predicted_triplets + ts._d["predicted_entities"] = entities + ts._d["predicted_relations_probabilities"] = predicted_triplets_prob + if ts.token2word: + self._convert_tokens_to_word_annotations(ts) + yield ts + + def _build_input(self, text: List[str], candidates: List[List[str]]) -> List[int]: + candidates_symbols = get_special_symbols_re(len(candidates)) + candidates = [ + [cs, ct] if ct != NME_SYMBOL else [NME_SYMBOL] + for cs, ct in zip(candidates_symbols, candidates) + ] + return ( + [self.tokenizer.cls_token] + + text + + [self.tokenizer.sep_token] + + flatten(candidates) + + [self.tokenizer.sep_token] + ) + + @staticmethod + def _compute_offsets(offsets_mapping): + offsets_mapping = offsets_mapping.numpy() + token2word = [] + word2token = {} + count = 0 + for i, offset in enumerate(offsets_mapping): + if offset[0] == 0: + token2word.append(i - count) + word2token[i - count] = [i] + 
else: + token2word.append(token2word[-1]) + word2token[token2word[-1]].append(i) + count += 1 + return token2word, word2token + + @staticmethod + def _convert_tokens_to_word_annotations(sample: RelikReaderSample): + triplets = [] + entities = [] + for entity in sample.predicted_entities: + if sample.entity_candidates: + entities.append( + ( + sample.token2word[entity[0] - 1], + sample.token2word[entity[1] - 1] + 1, + sample.entity_candidates[entity[2]], + ) + ) + else: + entities.append( + ( + sample.token2word[entity[0] - 1], + sample.token2word[entity[1] - 1] + 1, + -1, + ) + ) + for predicted_triplet, predicted_triplet_probabilities in zip( + sample.predicted_relations, sample.predicted_relations_probabilities + ): + subject, object_, relation = predicted_triplet + subject = entities[subject] + object_ = entities[object_] + relation = sample.candidates[relation] + triplets.append( + { + "subject": { + "start": subject[0], + "end": subject[1], + "type": subject[2], + "name": " ".join(sample.tokens[subject[0] : subject[1]]), + }, + "relation": { + "name": relation, + "probability": float(predicted_triplet_probabilities.round(2)), + }, + "object": { + "start": object_[0], + "end": object_[1], + "type": object_[2], + "name": " ".join(sample.tokens[object_[0] : object_[1]]), + }, + } + ) + sample.predicted_entities = entities + sample.predicted_relations = triplets + sample.predicted_relations_probabilities = None + + @torch.no_grad() + @torch.inference_mode() + def read( + self, + text: Optional[Union[List[str], List[List[str]]]] = None, + samples: Optional[List[RelikReaderSample]] = None, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + prediction_mask: Optional[torch.Tensor] = None, + special_symbols_mask: Optional[torch.Tensor] = None, + special_symbols_mask_entities: Optional[torch.Tensor] = None, + candidates: Optional[List[List[str]]] = None, + max_length: Optional[int] = 1024, + max_batch_size: Optional[int] = 64, + token_batch_size: Optional[int] = None, + progress_bar: bool = False, + *args, + **kwargs, + ) -> List[List[RelikReaderSample]]: + """ + Reads the given text. + Args: + text: The text to read in tokens. + input_ids: The input ids of the text. + attention_mask: The attention mask of the text. + token_type_ids: The token type ids of the text. + prediction_mask: The prediction mask of the text. + special_symbols_mask: The special symbols mask of the text. + special_symbols_mask_entities: The special symbols mask entities of the text. + candidates: The candidates of the text. + max_length: The maximum length of the text. + max_batch_size: The maximum batch size. + token_batch_size: The maximum number of tokens per batch. + Returns: + The predicted labels for each sample. + """ + if text is None and input_ids is None and samples is None: + raise ValueError( + "Either `text` or `input_ids` or `samples` must be provided." + ) + if (input_ids is None and samples is None) and ( + text is None or candidates is None + ): + raise ValueError( + "`text` and `candidates` must be provided to return the predictions when `input_ids` and `samples` is not provided." 
+ ) + if text is not None and samples is None: + if len(text) != len(candidates): + raise ValueError("`text` and `candidates` must have the same length.") + if isinstance(text[0], str): # change to list of text + text = [text] + candidates = [candidates] + + samples = [ + RelikReaderSample(tokens=t, candidates=c) + for t, c in zip(text, candidates) + ] + + if samples is not None: + # function that creates a batch from the 'current_batch' list + def output_batch() -> Dict[str, Any]: + assert ( + len( + set( + [ + len(elem["predictable_candidates"]) + for elem in current_batch + ] + ) + ) + == 1 + ), " ".join( + map( + str, + [len(elem["predictable_candidates"]) for elem in current_batch], + ) + ) + + batch_dict = dict() + + de_values_by_field = { + fn: [de[fn] for de in current_batch if fn in de] + for fn in self.fields_batcher + } + + # in case you provide fields batchers but in the batch + # there are no elements for that field + de_values_by_field = { + fn: fvs for fn, fvs in de_values_by_field.items() if len(fvs) > 0 + } + + assert len(set([len(v) for v in de_values_by_field.values()])) + + # todo: maybe we should report the user about possible + # fields filtering due to "None" instances + de_values_by_field = { + fn: fvs + for fn, fvs in de_values_by_field.items() + if all([fv is not None for fv in fvs]) + } + + for field_name, field_values in de_values_by_field.items(): + field_batch = ( + self.fields_batcher[field_name]([fv[0] for fv in field_values]) + if self.fields_batcher[field_name] is not None + else field_values + ) + + batch_dict[field_name] = field_batch + + batch_dict = { + k: v.to(self.device) if isinstance(v, torch.Tensor) else v + for k, v in batch_dict.items() + } + return batch_dict + + current_batch = [] + predictions = [] + current_cand_len = -1 + + for sample in tqdm(samples, disable=not progress_bar): + sample.candidates = [NME_SYMBOL] + sample.candidates + inputs_text = self._build_input(sample.tokens, sample.candidates) + model_inputs = self.tokenizer( + inputs_text, + is_split_into_words=True, + add_special_tokens=False, + padding=False, + truncation=True, + max_length=max_length or self.tokenizer.model_max_length, + return_offsets_mapping=True, + return_tensors="pt", + ) + model_inputs["special_symbols_mask"] = ( + model_inputs["input_ids"] > self.tokenizer.vocab_size + ) + # prediction mask is 0 until the first special symbol + model_inputs["token_type_ids"] = ( + torch.cumsum(model_inputs["special_symbols_mask"], dim=1) > 0 + ).long() + # shift prediction_mask to the left + model_inputs["prediction_mask"] = model_inputs["token_type_ids"].roll( + shifts=-1, dims=1 + ) + model_inputs["prediction_mask"][:, -1] = 1 + model_inputs["prediction_mask"][:, 0] = 1 + + assert ( + len(model_inputs["special_symbols_mask"]) + == len(model_inputs["prediction_mask"]) + == len(model_inputs["input_ids"]) + ) + + model_inputs["sample"] = sample + + # compute cand_len using special_symbols_mask + model_inputs["predictable_candidates"] = sample.candidates[ + : model_inputs["special_symbols_mask"].sum().item() + ] + # cand_len = sum([id_ > self.tokenizer.vocab_size for id_ in model_inputs["input_ids"]]) + offsets = model_inputs.pop("offset_mapping") + offsets = offsets[model_inputs["prediction_mask"] == 0] + sample.token2word, sample.word2token = self._compute_offsets(offsets) + future_max_len = max( + len(model_inputs["input_ids"]), + max([len(b["input_ids"]) for b in current_batch], default=0), + ) + future_tokens_per_batch = future_max_len * (len(current_batch) + 1) + + if 
len(current_batch) > 0 and ( + ( + len(model_inputs["predictable_candidates"]) != current_cand_len + and current_cand_len != -1 + ) + or ( + isinstance(token_batch_size, int) + and future_tokens_per_batch >= token_batch_size + ) + or len(current_batch) == max_batch_size + ): + batch_inputs = output_batch() + current_batch = [] + predictions.extend(list(self.batch_predict(**batch_inputs))) + current_cand_len = len(model_inputs["predictable_candidates"]) + current_batch.append(model_inputs) + + if current_batch: + batch_inputs = output_batch() + predictions.extend(list(self.batch_predict(**batch_inputs))) + else: + predictions = list( + self.batch_predict( + input_ids, + attention_mask, + token_type_ids, + prediction_mask, + special_symbols_mask, + special_symbols_mask_entities, + *args, + **kwargs, + ) + ) + return predictions + + @property + def device(self) -> torch.device: + """ + The device of the model. + """ + return next(self.parameters()).device + + @property + def tokenizer(self) -> tr.PreTrainedTokenizer: + """ + The tokenizer. + """ + if self._tokenizer: + return self._tokenizer + + self._tokenizer = tr.AutoTokenizer.from_pretrained( + self.relik_reader_re_model.config.name_or_path + ) + return self._tokenizer + + @property + def fields_batcher(self) -> Dict[str, Union[None, Callable[[list], Any]]]: + fields_batchers = { + "input_ids": lambda x: batchify( + x, padding_value=self.tokenizer.pad_token_id + ), + "attention_mask": lambda x: batchify(x, padding_value=0), + "token_type_ids": lambda x: batchify(x, padding_value=0), + "prediction_mask": lambda x: batchify(x, padding_value=1), + "global_attention": lambda x: batchify(x, padding_value=0), + "token2word": None, + "sample": None, + "special_symbols_mask": lambda x: batchify(x, padding_value=False), + "special_symbols_mask_entities": lambda x: batchify(x, padding_value=False), + } + if "roberta" in self.relik_reader_re_model.config.model_type: + del fields_batchers["token_type_ids"] + + return fields_batchers + + def save_pretrained( + self, + output_dir: str, + model_name: Optional[str] = None, + push_to_hub: bool = False, + **kwargs, + ) -> None: + """ + Saves the model to the given path. + Args: + output_dir: The path to save the model to. + model_name: The name of the model. + push_to_hub: Whether to push the model to the hub. 
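+ Example (a sketch; the output directory and model name below are illustrative):
+ `reader.save_pretrained("outputs", model_name="relik-reader-re", push_to_hub=False)`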
+ """ + # create the output directory + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + model_name = model_name or "relik_reader_for_triplet_extraction" + + logger.info(f"Saving reader to {output_dir / model_name}") + + # save the model + self.relik_reader_re_model.register_for_auto_class() + self.relik_reader_re_model.save_pretrained( + output_dir / model_name, push_to_hub=push_to_hub, **kwargs + ) + + logger.info("Saving reader to disk done.") + + if self.tokenizer: + self.tokenizer.save_pretrained( + output_dir / model_name, push_to_hub=push_to_hub, **kwargs + ) + logger.info("Saving tokenizer to disk done.") diff --git a/relik/reader/relik_reader_re_data.py b/relik/reader/relik_reader_re_data.py new file mode 100644 index 0000000000000000000000000000000000000000..2bc43b78eeba4e63447d8bf333dacf2a993716fd --- /dev/null +++ b/relik/reader/relik_reader_re_data.py @@ -0,0 +1,849 @@ +import logging +from typing import ( + Any, + Callable, + Dict, + Generator, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union, +) + +import numpy as np +import torch +from reader.data.relik_reader_data_utils import ( + add_noise_to_value, + batchify, + batchify_matrices, + batchify_tensor, + chunks, + flatten, +) +from reader.data.relik_reader_sample import RelikReaderSample, load_relik_reader_samples +from torch.utils.data import IterableDataset +from transformers import AutoTokenizer + +from relik.reader.utils.special_symbols import NME_SYMBOL + +logger = logging.getLogger(__name__) + + +class TokenizationOutput(NamedTuple): + input_ids: torch.Tensor + attention_mask: torch.Tensor + token_type_ids: torch.Tensor + prediction_mask: torch.Tensor + special_symbols_mask: torch.Tensor + special_symbols_mask_entities: torch.Tensor + + +class RelikREDataset(IterableDataset): + def __init__( + self, + dataset_path: str, + materialize_samples: bool, + transformer_model: str, + special_symbols: List[str], + shuffle_candidates: Optional[Union[bool, float]], + flip_candidates: Optional[Union[bool, float]], + relations_definitions: Union[str, Dict[str, str]], + for_inference: bool, + entities_definitions: Optional[Union[str, Dict[str, str]]] = None, + special_symbols_entities: Optional[List[str]] = None, + noise_param: float = 0.1, + sorting_fields: Optional[str] = None, + tokens_per_batch: int = 2048, + batch_size: int = None, + max_batch_size: int = 128, + section_size: int = 50_000, + prebatch: bool = True, + max_candidates: int = 0, + add_gold_candidates: bool = True, + use_nme: bool = True, + min_length: int = 5, + max_length: int = 2048, + model_max_length: int = 1000, + skip_empty_training_samples: bool = True, + drop_last: bool = False, + samples: Optional[Iterator[RelikReaderSample]] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.dataset_path = dataset_path + self.materialize_samples = materialize_samples + self.samples: Optional[List[RelikReaderSample]] = None + if self.materialize_samples: + self.samples = list() + + self.tokenizer = self._build_tokenizer(transformer_model, special_symbols) + self.special_symbols = special_symbols + self.special_symbols_entities = special_symbols_entities + self.shuffle_candidates = shuffle_candidates + self.flip_candidates = flip_candidates + self.for_inference = for_inference + self.noise_param = noise_param + self.batching_fields = ["input_ids"] + self.sorting_fields = ( + sorting_fields if sorting_fields is not None else self.batching_fields + ) + + # open relations definitions file if needed + if 
type(relations_definitions) == str: + relations_definitions = { + line.split("\t")[0]: line.split("\t")[1] + for line in open(relations_definitions) + } + self.max_candidates = max_candidates + self.relations_definitions = relations_definitions + self.entities_definitions = entities_definitions + + self.add_gold_candidates = add_gold_candidates + self.use_nme = use_nme + self.min_length = min_length + self.max_length = max_length + self.model_max_length = ( + model_max_length + if model_max_length < self.tokenizer.model_max_length + else self.tokenizer.model_max_length + ) + self.transformer_model = transformer_model + self.skip_empty_training_samples = skip_empty_training_samples + self.drop_last = drop_last + self.samples = samples + + self.tokens_per_batch = tokens_per_batch + self.batch_size = batch_size + self.max_batch_size = max_batch_size + self.section_size = section_size + self.prebatch = prebatch + + def _build_tokenizer(self, transformer_model: str, special_symbols: List[str]): + return AutoTokenizer.from_pretrained( + transformer_model, + additional_special_tokens=[ss for ss in special_symbols], + add_prefix_space=True, + ) + + @property + def fields_batcher(self) -> Dict[str, Union[None, Callable[[list], Any]]]: + fields_batchers = { + "input_ids": lambda x: batchify( + x, padding_value=self.tokenizer.pad_token_id + ), + "attention_mask": lambda x: batchify(x, padding_value=0), + "token_type_ids": lambda x: batchify(x, padding_value=0), + "prediction_mask": lambda x: batchify(x, padding_value=1), + "global_attention": lambda x: batchify(x, padding_value=0), + "token2word": None, + "sample": None, + "special_symbols_mask": lambda x: batchify(x, padding_value=False), + "special_symbols_mask_entities": lambda x: batchify(x, padding_value=False), + "start_labels": lambda x: batchify(x, padding_value=-100), + "end_labels": lambda x: batchify_matrices(x, padding_value=-100), + "disambiguation_labels": lambda x: batchify(x, padding_value=-100), + "relation_labels": lambda x: batchify_tensor(x, padding_value=-100), + "predictable_candidates": None, + } + if "roberta" in self.transformer_model: + del fields_batchers["token_type_ids"] + + return fields_batchers + + def _build_input_ids( + self, sentence_input_ids: List[int], candidates_input_ids: List[List[int]] + ) -> List[int]: + return ( + [self.tokenizer.cls_token_id] + + sentence_input_ids + + [self.tokenizer.sep_token_id] + + flatten(candidates_input_ids) + + [self.tokenizer.sep_token_id] + ) + + def _get_special_symbols_mask(self, input_ids: torch.Tensor) -> torch.Tensor: + special_symbols_mask = input_ids >= ( + len(self.tokenizer) + - len(self.special_symbols + self.special_symbols_entities) + ) + special_symbols_mask[0] = True + return special_symbols_mask + + def _build_tokenizer_essentials( + self, input_ids, original_sequence + ) -> TokenizationOutput: + input_ids = torch.tensor(input_ids, dtype=torch.long) + attention_mask = torch.ones_like(input_ids) + + total_sequence_len = len(input_ids) + predictable_sentence_len = len(original_sequence) + + # token type ids + token_type_ids = torch.cat( + [ + input_ids.new_zeros( + predictable_sentence_len + 2 + ), # original sentence bpes + CLS and SEP + input_ids.new_ones(total_sequence_len - predictable_sentence_len - 2), + ] + ) + + # prediction mask -> boolean on tokens that are predictable + + prediction_mask = torch.tensor( + [1] + + ([0] * predictable_sentence_len) + + ([1] * (total_sequence_len - predictable_sentence_len - 1)) + ) + + assert len(prediction_mask) == 
len(input_ids) + + # special symbols mask + special_symbols_mask = input_ids >= ( + len(self.tokenizer) + - len(self.special_symbols) # + self.special_symbols_entities) + ) + if self.entities_definitions is not None: + # select only the first N true values where N is len(entities_definitions) + special_symbols_mask_entities = special_symbols_mask.clone() + special_symbols_mask_entities[ + special_symbols_mask_entities.cumsum(0) > len(self.entities_definitions) + ] = False + special_symbols_mask = special_symbols_mask ^ special_symbols_mask_entities + else: + special_symbols_mask_entities = special_symbols_mask.clone() + + return TokenizationOutput( + input_ids, + attention_mask, + token_type_ids, + prediction_mask, + special_symbols_mask, + special_symbols_mask_entities, + ) + + def _build_labels( + self, + sample, + tokenization_output: TokenizationOutput, + ) -> Tuple[torch.Tensor, torch.Tensor]: + start_labels = [0] * len(tokenization_output.input_ids) + end_labels = [] + + sample.entities.sort(key=lambda x: (x[0], x[1])) + + prev_start_bpe = -1 + num_repeat_start = 0 + if self.entities_definitions: + sample.entities = [(ce[0], ce[1], ce[2]) for ce in sample.entities] + sample.entity_candidates = list(self.entities_definitions.keys()) + disambiguation_labels = torch.zeros( + len(sample.entities), + len(sample.entity_candidates) + len(sample.candidates), + ) + else: + sample.entities = [(ce[0], ce[1], "") for ce in sample.entities] + disambiguation_labels = torch.zeros( + len(sample.entities), len(sample.candidates) + ) + ignored_labels_indices = tokenization_output.prediction_mask == 1 + for idx, c_ent in enumerate(sample.entities): + start_bpe = sample.word2token[c_ent[0]][0] + 1 + end_bpe = sample.word2token[c_ent[1] - 1][-1] + 1 + class_index = idx + start_labels[start_bpe] = class_index + 1 # +1 for the NONE class + if start_bpe != prev_start_bpe: + end_labels.append([0] * len(tokenization_output.input_ids)) + # end_labels[-1][:start_bpe] = [-100] * start_bpe + end_labels[-1][end_bpe] = class_index + 1 + else: + end_labels[-1][end_bpe] = class_index + 1 + num_repeat_start += 1 + if self.entities_definitions: + entity_type_idx = sample.entity_candidates.index(c_ent[2]) + disambiguation_labels[idx, entity_type_idx] = 1 + prev_start_bpe = start_bpe + + start_labels = torch.tensor(start_labels, dtype=torch.long) + start_labels[ignored_labels_indices] = -100 + + end_labels = torch.tensor(end_labels, dtype=torch.long) + end_labels[ignored_labels_indices.repeat(len(end_labels), 1)] = -100 + + relation_labels = torch.zeros( + len(sample.entities), len(sample.entities), len(sample.candidates) + ) + + # sample.relations = [] + for re in sample.triplets: + if re["relation"]["name"] not in sample.candidates: + re_class_index = len(sample.candidates) - 1 + else: + re_class_index = sample.candidates.index( + re["relation"]["name"] + ) # should remove this +1 + if self.entities_definitions: + subject_class_index = sample.entities.index( + ( + re["subject"]["start"], + re["subject"]["end"], + re["subject"]["type"], + ) + ) + object_class_index = sample.entities.index( + (re["object"]["start"], re["object"]["end"], re["object"]["type"]) + ) + else: + subject_class_index = sample.entities.index( + (re["subject"]["start"], re["subject"]["end"], "") + ) + object_class_index = sample.entities.index( + (re["object"]["start"], re["object"]["end"], "") + ) + + relation_labels[subject_class_index, object_class_index, re_class_index] = 1 + + if self.entities_definitions: + disambiguation_labels[ + 
subject_class_index, re_class_index + len(sample.entity_candidates) + ] = 1 + disambiguation_labels[ + object_class_index, re_class_index + len(sample.entity_candidates) + ] = 1 + # sample.relations.append([re['subject']['start'], re['subject']['end'], re['subject']['type'], re['relation']['name'], re['object']['start'], re['object']['end'], re['object']['type']]) + else: + disambiguation_labels[subject_class_index, re_class_index] = 1 + disambiguation_labels[object_class_index, re_class_index] = 1 + # sample.relations.append([re['subject']['start'], re['subject']['end'], "", re['relation']['name'], re['object']['start'], re['object']['end'], ""]) + return start_labels, end_labels, disambiguation_labels, relation_labels + + def __iter__(self): + dataset_iterator = self.dataset_iterator_func() + current_dataset_elements = [] + i = None + for i, dataset_elem in enumerate(dataset_iterator, start=1): + if ( + self.section_size is not None + and len(current_dataset_elements) == self.section_size + ): + for batch in self.materialize_batches(current_dataset_elements): + yield batch + current_dataset_elements = [] + current_dataset_elements.append(dataset_elem) + if i % 50_000 == 0: + logger.info(f"Processed: {i} number of elements") + if len(current_dataset_elements) != 0: + for batch in self.materialize_batches(current_dataset_elements): + yield batch + if i is not None: + logger.info(f"Dataset finished: {i} number of elements processed") + else: + logger.warning("Dataset empty") + + def dataset_iterator_func(self): + data_samples = ( + load_relik_reader_samples(self.dataset_path) + if self.samples is None + else self.samples + ) + for sample in data_samples: + # input sentence tokenization + input_tokenized = self.tokenizer( + sample.tokens, + return_offsets_mapping=True, + add_special_tokens=False, + is_split_into_words=True, + ) + input_subwords = input_tokenized["input_ids"] + offsets = input_tokenized["offset_mapping"] + token2word = [] + word2token = {} + count = 0 + for i, offset in enumerate(offsets): + if offset[0] == 0: + token2word.append(i - count) + word2token[i - count] = [i] + else: + token2word.append(token2word[-1]) + word2token[token2word[-1]].append(i) + count += 1 + sample.token2word = token2word + sample.word2token = word2token + # input_subwords = sample.tokens[1:-1] # removing special tokens + candidates_symbols = self.special_symbols + + if self.max_candidates > 0: + # truncate candidates + sample.candidates = sample.candidates[: self.max_candidates] + + # add NME as a possible candidate + if self.use_nme: + sample.candidates.insert(0, NME_SYMBOL) + + # training time sample mods + if not self.for_inference: + # check whether the sample has labels if not skip + if ( + sample.triplets is None or len(sample.triplets) == 0 + ) and self.skip_empty_training_samples: + logger.warning( + "Sample {} has no labels, skipping".format(sample.sample_id) + ) + continue + + # add gold candidates if missing + if self.add_gold_candidates: + candidates_set = set(sample.candidates) + candidates_to_add = [] + for candidate_title in sample.triplets: + if candidate_title["relation"]["name"] not in candidates_set: + candidates_to_add.append( + candidate_title["relation"]["name"] + ) + if len(candidates_to_add) > 0: + # replacing last candidates with the gold ones + # this is done in order to preserve the ordering + added_gold_candidates = 0 + gold_candidates_titles_set = set( + set(ct["relation"]["name"] for ct in sample.triplets) + ) + for i in reversed(range(len(sample.candidates))): + if ( + 
sample.candidates[i] not in gold_candidates_titles_set + and sample.candidates[i] != NME_SYMBOL + ): + sample.candidates[i] = candidates_to_add[ + added_gold_candidates + ] + added_gold_candidates += 1 + if len(candidates_to_add) == added_gold_candidates: + break + + candidates_still_to_add = ( + len(candidates_to_add) - added_gold_candidates + ) + while ( + len(sample.candidates) <= len(candidates_symbols) + and candidates_still_to_add != 0 + ): + sample.candidates.append( + candidates_to_add[added_gold_candidates] + ) + added_gold_candidates += 1 + candidates_still_to_add -= 1 + + # shuffle candidates + if ( + isinstance(self.shuffle_candidates, bool) + and self.shuffle_candidates + ) or ( + isinstance(self.shuffle_candidates, float) + and np.random.uniform() < self.shuffle_candidates + ): + np.random.shuffle(sample.candidates) + if NME_SYMBOL in sample.candidates: + sample.candidates.remove(NME_SYMBOL) + sample.candidates.insert(0, NME_SYMBOL) + + # flip candidates + if ( + isinstance(self.flip_candidates, bool) and self.flip_candidates + ) or ( + isinstance(self.flip_candidates, float) + and np.random.uniform() < self.flip_candidates + ): + for i in range(len(sample.candidates) - 1): + if np.random.uniform() < 0.5: + sample.candidates[i], sample.candidates[i + 1] = ( + sample.candidates[i + 1], + sample.candidates[i], + ) + if NME_SYMBOL in sample.candidates: + sample.candidates.remove(NME_SYMBOL) + sample.candidates.insert(0, NME_SYMBOL) + + # candidates encoding + candidates_symbols = candidates_symbols[: len(sample.candidates)] + relations_defs = [ + "{} {}".format(cs, self.relations_definitions[ct]) + if ct != NME_SYMBOL + else NME_SYMBOL + for cs, ct in zip(candidates_symbols, sample.candidates) + ] + if self.entities_definitions is not None: + candidates_entities_symbols = list(self.special_symbols_entities) + candidates_entities_symbols = candidates_entities_symbols[ + : len(self.entities_definitions) + ] + entity_defs = [ + "{} {}".format(cs, self.entities_definitions[ct]) + for cs, ct in zip( + candidates_entities_symbols, self.entities_definitions.keys() + ) + ] + relations_defs = ( + entity_defs + [self.tokenizer.sep_token] + relations_defs + ) + + candidates_encoding_result = self.tokenizer.batch_encode_plus( + relations_defs, + add_special_tokens=False, + ).input_ids + + # drop candidates if the number of input tokens is too long for the model + if ( + sum(map(len, candidates_encoding_result)) + + len(input_subwords) + + 20 # + 20 special tokens + > self.model_max_length + ): + if self.for_inference: + acceptable_tokens_from_candidates = ( + self.model_max_length - 20 - len(input_subwords) + ) + while ( + cum_len + len(candidates_encoding_result[i]) + < acceptable_tokens_from_candidates + ): + cum_len += len(candidates_encoding_result[i]) + i += 1 + + candidates_encoding_result = candidates_encoding_result[:i] + if self.entities_definitions is not None: + candidates_symbols = candidates_symbols[ + : i - len(self.entities_definitions) + ] + sample.candidates = sample.candidates[ + : i - len(self.entities_definitions) + ] + else: + candidates_symbols = candidates_symbols[:i] + sample.candidates = sample.candidates[:i] + + else: + gold_candidates_set = set( + [wl["relation"]["name"] for wl in sample.triplets] + ) + gold_candidates_indices = [ + i + for i, wc in enumerate(sample.candidates) + if wc in gold_candidates_set + ] + if self.entities_definitions is not None: + gold_candidates_indices = [ + i + len(self.entities_definitions) + for i in gold_candidates_indices + ] + # 
add entities indices + gold_candidates_indices = gold_candidates_indices + list( + range(len(self.entities_definitions)) + ) + necessary_taken_tokens = sum( + map( + len, + [ + candidates_encoding_result[i] + for i in gold_candidates_indices + ], + ) + ) + + acceptable_tokens_from_candidates = ( + self.model_max_length + - 20 + - len(input_subwords) + - necessary_taken_tokens + ) + + assert acceptable_tokens_from_candidates > 0 + + i = 0 + cum_len = 0 + while ( + cum_len + len(candidates_encoding_result[i]) + < acceptable_tokens_from_candidates + ): + if i not in gold_candidates_indices: + cum_len += len(candidates_encoding_result[i]) + i += 1 + + new_indices = sorted( + list(set(list(range(i)) + gold_candidates_indices)) + ) + np.random.shuffle(new_indices) + + candidates_encoding_result = [ + candidates_encoding_result[i] for i in new_indices + ] + if self.entities_definitions is not None: + sample.candidates = [ + sample.candidates[i - len(self.entities_definitions)] + for i in new_indices + ] + candidates_symbols = candidates_symbols[ + : i - len(self.entities_definitions) + ] + else: + candidates_symbols = [ + candidates_symbols[i] for i in new_indices + ] + sample.window_candidates = [ + sample.window_candidates[i] for i in new_indices + ] + if len(sample.candidates) == 0: + logger.warning( + "Sample {} has no candidates after truncation due to max length".format( + sample.sample_id + ) + ) + continue + + # final input_ids build + input_ids = self._build_input_ids( + sentence_input_ids=input_subwords, + candidates_input_ids=candidates_encoding_result, + ) + + # complete input building (e.g. attention / prediction mask) + tokenization_output = self._build_tokenizer_essentials( + input_ids, input_subwords + ) + + # labels creation + start_labels, end_labels, disambiguation_labels, relation_labels = ( + None, + None, + None, + None, + ) + if sample.entities is not None and len(sample.entities) > 0: + ( + start_labels, + end_labels, + disambiguation_labels, + relation_labels, + ) = self._build_labels( + sample, + tokenization_output, + ) + + yield { + "input_ids": tokenization_output.input_ids, + "attention_mask": tokenization_output.attention_mask, + "token_type_ids": tokenization_output.token_type_ids, + "prediction_mask": tokenization_output.prediction_mask, + "special_symbols_mask": tokenization_output.special_symbols_mask, + "special_symbols_mask_entities": tokenization_output.special_symbols_mask_entities, + "sample": sample, + "start_labels": start_labels, + "end_labels": end_labels, + "disambiguation_labels": disambiguation_labels, + "relation_labels": relation_labels, + "predictable_candidates": candidates_symbols, + } + + def preshuffle_elements(self, dataset_elements: List): + # This shuffling is done so that when using the sorting function, + # if it is deterministic given a collection and its order, we will + # make the whole operation not deterministic anymore. + # Basically, the aim is not to build every time the same batches. 
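+ # Concretely, the code below: (1) randomly permutes the elements at training time,
+ # (2) sorts them by the total length of the `sorting_fields`, adding noise to that
+ # length when training so the ordering varies between epochs, and (3) at inference
+ # time returns the sorted list as is, while at training time it splits the sorted
+ # list into chunks of 64, shuffles the chunks, and flattens them back into one list.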
+ if not self.for_inference: + dataset_elements = np.random.permutation(dataset_elements) + + sorting_fn = ( + lambda elem: add_noise_to_value( + sum(len(elem[k]) for k in self.sorting_fields), + noise_param=self.noise_param, + ) + if not self.for_inference + else sum(len(elem[k]) for k in self.sorting_fields) + ) + + dataset_elements = sorted(dataset_elements, key=sorting_fn) + + if self.for_inference: + return dataset_elements + + ds = list(chunks(dataset_elements, 64)) # todo: modified + np.random.shuffle(ds) + return flatten(ds) + + def materialize_batches( + self, dataset_elements: List[Dict[str, Any]] + ) -> Generator[Dict[str, Any], None, None]: + if self.prebatch: + dataset_elements = self.preshuffle_elements(dataset_elements) + + current_batch = [] + + # function that creates a batch from the 'current_batch' list + def output_batch() -> Dict[str, Any]: + assert ( + len( + set([len(elem["predictable_candidates"]) for elem in current_batch]) + ) + == 1 + ), " ".join( + map( + str, [len(elem["predictable_candidates"]) for elem in current_batch] + ) + ) + + batch_dict = dict() + + de_values_by_field = { + fn: [de[fn] for de in current_batch if fn in de] + for fn in self.fields_batcher + } + + # in case you provide fields batchers but in the batch + # there are no elements for that field + de_values_by_field = { + fn: fvs for fn, fvs in de_values_by_field.items() if len(fvs) > 0 + } + + assert len(set([len(v) for v in de_values_by_field.values()])) + + # todo: maybe we should report the user about possible + # fields filtering due to "None" instances + de_values_by_field = { + fn: fvs + for fn, fvs in de_values_by_field.items() + if all([fv is not None for fv in fvs]) + } + + for field_name, field_values in de_values_by_field.items(): + field_batch = ( + self.fields_batcher[field_name](field_values) + if self.fields_batcher[field_name] is not None + else field_values + ) + + batch_dict[field_name] = field_batch + + return batch_dict + + max_len_discards, min_len_discards = 0, 0 + + should_token_batch = self.batch_size is None + + curr_pred_elements = -1 + for de in dataset_elements: + if ( + should_token_batch + and self.max_batch_size != -1 + and len(current_batch) == self.max_batch_size + ) or (not should_token_batch and len(current_batch) == self.batch_size): + yield output_batch() + current_batch = [] + curr_pred_elements = -1 + + # todo support max length (and min length) as dicts + + too_long_fields = [ + k + for k in de + if self.max_length != -1 + and torch.is_tensor(de[k]) + and len(de[k]) > self.max_length + ] + if len(too_long_fields) > 0: + max_len_discards += 1 + continue + + too_short_fields = [ + k + for k in de + if self.min_length != -1 + and torch.is_tensor(de[k]) + and len(de[k]) < self.min_length + ] + if len(too_short_fields) > 0: + min_len_discards += 1 + continue + + if should_token_batch: + de_len = sum(len(de[k]) for k in self.batching_fields) + + future_max_len = max( + de_len, + max( + [ + sum(len(bde[k]) for k in self.batching_fields) + for bde in current_batch + ], + default=0, + ), + ) + + future_tokens_per_batch = future_max_len * (len(current_batch) + 1) + + num_predictable_candidates = len(de["predictable_candidates"]) + + if len(current_batch) > 0 and ( + future_tokens_per_batch >= self.tokens_per_batch + or ( + num_predictable_candidates != curr_pred_elements + and curr_pred_elements != -1 + ) + ): + yield output_batch() + current_batch = [] + + current_batch.append(de) + curr_pred_elements = len(de["predictable_candidates"]) + + if len(current_batch) 
!= 0 and not self.drop_last: + yield output_batch() + + if max_len_discards > 0: + if self.for_inference: + logger.warning( + f"WARNING: Inference mode is True but {max_len_discards} samples longer than max length were " + f"found. The {max_len_discards} samples will be DISCARDED. If you are doing some kind of evaluation" + f", this can INVALIDATE results. This might happen if the max length was not set to -1 or if the " + f"sample length exceeds the maximum length supported by the current model." + ) + else: + logger.warning( + f"During iteration, {max_len_discards} elements were " + f"discarded since longer than max length {self.max_length}" + ) + + if min_len_discards > 0: + if self.for_inference: + logger.warning( + f"WARNING: Inference mode is True but {min_len_discards} samples shorter than min length were " + f"found. The {min_len_discards} samples will be DISCARDED. If you are doing some kind of evaluation" + f", this can INVALIDATE results. This might happen if the min length was not set to -1 or if the " + f"sample length is shorter than the minimum length supported by the current model." + ) + else: + logger.warning( + f"During iteration, {min_len_discards} elements were " + f"discarded since shorter than min length {self.min_length}" + ) + + +def main(): + special_symbols = [NME_SYMBOL] + [f"R-{i}" for i in range(50)] + + relik_dataset = RelikREDataset( + "/home/huguetcabot/alby-re/alby/data/nyt-alby+/valid.jsonl", + materialize_samples=False, + transformer_model="microsoft/deberta-v3-base", + special_symbols=special_symbols, + shuffle_candidates=False, + flip_candidates=False, + for_inference=True, + ) + + for batch in relik_dataset: + print(batch) + exit(0) + + +if __name__ == "__main__": + main() diff --git a/relik/reader/trainer/__init__.py b/relik/reader/trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/relik/reader/trainer/predict.py b/relik/reader/trainer/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..3801bef958f9a092f8d094d2e99fe476fd4caed9 --- /dev/null +++ b/relik/reader/trainer/predict.py @@ -0,0 +1,57 @@ +import argparse +from pprint import pprint +from typing import Optional + +from relik.reader.relik_reader import RelikReader +from relik.reader.utils.strong_matching_eval import StrongMatching + + +def predict( + model_path: str, + dataset_path: str, + token_batch_size: int, + is_eval: bool, + output_path: Optional[str], +) -> None: + relik_reader = RelikReader(model_path) + predicted_samples = relik_reader.link_entities( + dataset_path, token_batch_size=token_batch_size + ) + if is_eval: + eval_dict = StrongMatching()(predicted_samples) + pprint(eval_dict) + if output_path is not None: + with open(output_path, "w") as f: + for sample in predicted_samples: + f.write(sample.to_jsons() + "\n") + + +def parse_arg() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument( + "--model-path", + required=True, + ) + parser.add_argument("--dataset-path", "-i", required=True) + parser.add_argument("--is-eval", action="store_true") + parser.add_argument( + "--output-path", + "-o", + ) + parser.add_argument("--token-batch-size", default=4096) + return parser.parse_args() + + +def main(): + args = parse_arg() + predict( + args.model_path, + args.dataset_path, + token_batch_size=args.token_batch_size, + is_eval=args.is_eval, + output_path=args.output_path, + ) + + +if __name__ == "__main__": + main() diff --git 
a/relik/reader/trainer/predict_re.py b/relik/reader/trainer/predict_re.py new file mode 100644 index 0000000000000000000000000000000000000000..2f21b4b15f857ec561e14f2e8c9fb50ea4f00637 --- /dev/null +++ b/relik/reader/trainer/predict_re.py @@ -0,0 +1,125 @@ +import argparse + +import torch +from reader.data.relik_reader_sample import load_relik_reader_samples + +from relik.reader.pytorch_modules.hf.modeling_relik import ( + RelikReaderConfig, + RelikReaderREModel, +) +from relik.reader.relik_reader_re import RelikReaderForTripletExtraction +from relik.reader.utils.relation_matching_eval import StrongMatching + +dict_nyt = { + "/people/person/nationality": "nationality", + "/sports/sports_team/location": "sports team location", + "/location/country/administrative_divisions": "administrative divisions", + "/business/company/major_shareholders": "shareholders", + "/people/ethnicity/people": "ethnicity", + "/people/ethnicity/geographic_distribution": "geographic distribution", + "/business/company_shareholder/major_shareholder_of": "major shareholder", + "/location/location/contains": "location", + "/business/company/founders": "founders", + "/business/person/company": "company", + "/business/company/advisors": "advisor", + "/people/deceased_person/place_of_death": "place of death", + "/business/company/industry": "industry", + "/people/person/ethnicity": "ethnic background", + "/people/person/place_of_birth": "place of birth", + "/location/administrative_division/country": "country of an administration division", + "/people/person/place_lived": "place lived", + "/sports/sports_team_location/teams": "sports team", + "/people/person/children": "child", + "/people/person/religion": "religion", + "/location/neighborhood/neighborhood_of": "neighborhood", + "/location/country/capital": "capital", + "/business/company/place_founded": "company founded location", + "/people/person/profession": "occupation", +} + + +def eval(model_path, data_path, is_eval, output_path=None): + if model_path.endswith(".ckpt"): + # if it is a lightning checkpoint we load the model state dict and the tokenizer from the config + model_dict = torch.load(model_path) + + additional_special_symbols = model_dict["hyper_parameters"][ + "additional_special_symbols" + ] + from transformers import AutoTokenizer + + from relik.reader.utils.special_symbols import get_special_symbols_re + + special_symbols = get_special_symbols_re(additional_special_symbols - 1) + tokenizer = AutoTokenizer.from_pretrained( + model_dict["hyper_parameters"]["transformer_model"], + additional_special_tokens=special_symbols, + add_prefix_space=True, + ) + config_model = RelikReaderConfig( + model_dict["hyper_parameters"]["transformer_model"], + len(special_symbols), + training=False, + ) + model = RelikReaderREModel(config_model) + model_dict["state_dict"] = { + k.replace("relik_reader_re_model.", ""): v + for k, v in model_dict["state_dict"].items() + } + model.load_state_dict(model_dict["state_dict"], strict=False) + reader = RelikReaderForTripletExtraction( + model, training=False, device="cuda", tokenizer=tokenizer + ) + else: + # if it is a huggingface model we load the model directly.
Note that it could even be a string from the hub + model = RelikReaderREModel.from_pretrained(model_path) + reader = RelikReaderForTripletExtraction(model, training=False, device="cuda") + + samples = list(load_relik_reader_samples(data_path)) + + for sample in samples: + sample.candidates = [dict_nyt[cand] for cand in sample.candidates] + sample.triplets = [ + { + "subject": triplet["subject"], + "relation": { + "name": dict_nyt[triplet["relation"]["name"]], + "type": triplet["relation"]["type"], + }, + "object": triplet["object"], + } + for triplet in sample.triplets + ] + + predicted_samples = reader.read(samples=samples, progress_bar=True) + if is_eval: + strong_matching_metric = StrongMatching() + predicted_samples = list(predicted_samples) + for k, v in strong_matching_metric(predicted_samples).items(): + print(f"test_{k}", v) + if output_path is not None: + with open(output_path, "w") as f: + for sample in predicted_samples: + f.write(sample.to_jsons() + "\n") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_path", + type=str, + default="/home/huguetcabot/alby-re/relik/relik/reader/models/relik_re_reader_base", + ) + parser.add_argument( + "--data_path", + type=str, + default="/home/huguetcabot/alby-re/relik/relik/reader/data/testa.jsonl", + ) + parser.add_argument("--is-eval", action="store_true") + parser.add_argument("--output_path", type=str, default=None) + args = parser.parse_args() + eval(args.model_path, args.data_path, args.is_eval, args.output_path) + + +if __name__ == "__main__": + main() diff --git a/relik/reader/trainer/train.py b/relik/reader/trainer/train.py new file mode 100644 index 0000000000000000000000000000000000000000..f1983b38c02199f45c112247fc74bd09d3f1e4f0 --- /dev/null +++ b/relik/reader/trainer/train.py @@ -0,0 +1,98 @@ +import hydra +import lightning +from hydra.utils import to_absolute_path +from lightning import Trainer +from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint +from lightning.pytorch.loggers.wandb import WandbLogger +from omegaconf import DictConfig, OmegaConf, open_dict +from reader.data.relik_reader_data import RelikDataset +from reader.lightning_modules.relik_reader_pl_module import RelikReaderPLModule +from reader.pytorch_modules.optim import LayerWiseLRDecayOptimizer +from torch.utils.data import DataLoader + +from relik.reader.utils.special_symbols import get_special_symbols +from relik.reader.utils.strong_matching_eval import ELStrongMatchingCallback + + +@hydra.main(config_path="../conf", config_name="config") +def train(cfg: DictConfig) -> None: + lightning.seed_everything(cfg.training.seed) + + special_symbols = get_special_symbols(cfg.model.entities_per_forward) + + # model declaration + model = RelikReaderPLModule( + cfg=OmegaConf.to_container(cfg), + transformer_model=cfg.model.model.transformer_model, + additional_special_symbols=len(special_symbols), + training=True, + ) + + # optimizer declaration + opt_conf = cfg.model.optimizer + electra_optimizer_factory = LayerWiseLRDecayOptimizer( + lr=opt_conf.lr, + warmup_steps=opt_conf.warmup_steps, + total_steps=opt_conf.total_steps, + total_reset=opt_conf.total_reset, + no_decay_params=opt_conf.no_decay_params, + weight_decay=opt_conf.weight_decay, + lr_decay=opt_conf.lr_decay, + ) + + model.set_optimizer_factory(electra_optimizer_factory) + + # datasets declaration + train_dataset: RelikDataset = hydra.utils.instantiate( + cfg.data.train_dataset, + dataset_path=to_absolute_path(cfg.data.train_dataset_path), + 
special_symbols=special_symbols, + ) + + # update of validation dataset config with special_symbols since they + # are required even from the EvaluationCallback dataset_config + with open_dict(cfg): + cfg.data.val_dataset.special_symbols = special_symbols + + val_dataset: RelikDataset = hydra.utils.instantiate( + cfg.data.val_dataset, + dataset_path=to_absolute_path(cfg.data.val_dataset_path), + ) + + # callbacks declaration + callbacks = [ + ELStrongMatchingCallback( + to_absolute_path(cfg.data.val_dataset_path), cfg.data.val_dataset + ), + ModelCheckpoint( + "model", + filename="{epoch}-{val_core_f1:.2f}", + monitor="val_core_f1", + mode="max", + ), + LearningRateMonitor(), + ] + + wandb_logger = WandbLogger(cfg.model_name, project=cfg.project_name) + + # trainer declaration + trainer: Trainer = hydra.utils.instantiate( + cfg.training.trainer, + callbacks=callbacks, + logger=wandb_logger, + ) + + # Trainer fit + trainer.fit( + model=model, + train_dataloaders=DataLoader(train_dataset, batch_size=None, num_workers=0), + val_dataloaders=DataLoader(val_dataset, batch_size=None, num_workers=0), + ) + + +def main(): + train() + + +if __name__ == "__main__": + main() diff --git a/relik/reader/trainer/train_re.py b/relik/reader/trainer/train_re.py new file mode 100644 index 0000000000000000000000000000000000000000..550b3fc95d653bfc9503af635931aa72176ca89d --- /dev/null +++ b/relik/reader/trainer/train_re.py @@ -0,0 +1,109 @@ +import hydra +import lightning +from hydra.utils import to_absolute_path +from lightning import Trainer +from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint +from lightning.pytorch.loggers.wandb import WandbLogger +from omegaconf import DictConfig, OmegaConf, open_dict +from reader.pytorch_modules.optim import LayerWiseLRDecayOptimizer +from torch.utils.data import DataLoader + +from relik.reader.lightning_modules.relik_reader_re_pl_module import ( + RelikReaderREPLModule, +) +from relik.reader.relik_reader_re_data import RelikREDataset +from relik.reader.utils.relation_matching_eval import REStrongMatchingCallback +from relik.reader.utils.special_symbols import get_special_symbols_re + + +@hydra.main(config_path="conf", config_name="config") +def train(cfg: DictConfig) -> None: + lightning.seed_everything(cfg.training.seed) + + special_symbols = get_special_symbols_re(cfg.model.entities_per_forward) + + # datasets declaration + train_dataset: RelikREDataset = hydra.utils.instantiate( + cfg.data.train_dataset, + dataset_path=to_absolute_path(cfg.data.train_dataset_path), + special_symbols=special_symbols, + ) + + # update of validation dataset config with special_symbols since they + # are required even from the EvaluationCallback dataset_config + with open_dict(cfg): + cfg.data.val_dataset.special_symbols = special_symbols + + val_dataset: RelikREDataset = hydra.utils.instantiate( + cfg.data.val_dataset, + dataset_path=to_absolute_path(cfg.data.val_dataset_path), + ) + + # model declaration + model = RelikReaderREPLModule( + cfg=OmegaConf.to_container(cfg), + transformer_model=cfg.model.model.transformer_model, + additional_special_symbols=len(special_symbols), + training=True, + ) + model.relik_reader_re_model._tokenizer = train_dataset.tokenizer + # optimizer declaration + opt_conf = cfg.model.optimizer + + # adamw_optimizer_factory = AdamWWithWarmupOptimizer( + # lr=opt_conf.lr, + # warmup_steps=opt_conf.warmup_steps, + # total_steps=opt_conf.total_steps, + # no_decay_params=opt_conf.no_decay_params, + # weight_decay=opt_conf.weight_decay, + # 
) + + electra_optimizer_factory = LayerWiseLRDecayOptimizer( + lr=opt_conf.lr, + warmup_steps=opt_conf.warmup_steps, + total_steps=opt_conf.total_steps, + total_reset=opt_conf.total_reset, + no_decay_params=opt_conf.no_decay_params, + weight_decay=opt_conf.weight_decay, + lr_decay=opt_conf.lr_decay, + ) + + model.set_optimizer_factory(electra_optimizer_factory) + + # callbacks declaration + callbacks = [ + REStrongMatchingCallback( + to_absolute_path(cfg.data.val_dataset_path), cfg.data.val_dataset + ), + ModelCheckpoint( + "model", + filename="{epoch}-{val_f1:.2f}", + monitor="val_f1", + mode="max", + ), + LearningRateMonitor(), + ] + + wandb_logger = WandbLogger(cfg.model_name, project=cfg.project_name) + + # trainer declaration + trainer: Trainer = hydra.utils.instantiate( + cfg.training.trainer, + callbacks=callbacks, + logger=wandb_logger, + ) + + # Trainer fit + trainer.fit( + model=model, + train_dataloaders=DataLoader(train_dataset, batch_size=None, num_workers=0), + val_dataloaders=DataLoader(val_dataset, batch_size=None, num_workers=0), + ) + + +def main(): + train() + + +if __name__ == "__main__": + main() diff --git a/relik/reader/utils/__init__.py b/relik/reader/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/relik/reader/utils/metrics.py b/relik/reader/utils/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..fa17bf5d23cc888d6da0c6f40cf7bd3c20d77a66 --- /dev/null +++ b/relik/reader/utils/metrics.py @@ -0,0 +1,18 @@ +def safe_divide(num: float, den: float) -> float: + if den == 0: + return 0 + else: + return num / den + + +def f1_measure(precision: float, recall: float) -> float: + if precision == 0 or recall == 0: + return 0.0 + return safe_divide(2 * precision * recall, (precision + recall)) + + +def compute_metrics(total_correct, total_preds, total_gold): + precision = safe_divide(total_correct, total_preds) + recall = safe_divide(total_correct, total_gold) + f1 = f1_measure(precision, recall) + return precision, recall, f1 diff --git a/relik/reader/utils/relation_matching_eval.py b/relik/reader/utils/relation_matching_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..94a6b1e7a8dc155ed1ab9f6c52cb3c1eebd44505 --- /dev/null +++ b/relik/reader/utils/relation_matching_eval.py @@ -0,0 +1,172 @@ +from typing import Dict, List + +from lightning.pytorch.callbacks import Callback +from reader.data.relik_reader_sample import RelikReaderSample + +from relik.reader.relik_reader_predictor import RelikReaderPredictor +from relik.reader.utils.metrics import compute_metrics + + +class StrongMatching: + def __call__(self, predicted_samples: List[RelikReaderSample]) -> Dict: + # accumulators + correct_predictions, total_predictions, total_gold = ( + 0, + 0, + 0, + ) + correct_predictions_strict, total_predictions_strict = ( + 0, + 0, + ) + correct_predictions_bound, total_predictions_bound = ( + 0, + 0, + ) + correct_span_predictions, total_span_predictions, total_gold_spans = 0, 0, 0 + + # collect data from samples + for sample in predicted_samples: + if sample.triplets is None: + sample.triplets = [] + + if sample.entity_candidates: + predicted_annotations_strict = set( + [ + ( + triplet["subject"]["start"], + triplet["subject"]["end"], + triplet["subject"]["type"], + triplet["relation"]["name"], + triplet["object"]["start"], + triplet["object"]["end"], + triplet["object"]["type"], + ) + for triplet in sample.predicted_relations + ] + ) + 
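+ # note: each "strict" annotation is the 7-tuple
+ # (subject_start, subject_end, subject_type, relation_name, object_start, object_end, object_type);
+ # a predicted triplet counts as correct only when boundaries, entity types and the
+ # relation label all coincide with a gold tuple. The non-strict sets built below replace
+ # the two type fields with -1, so they compare boundaries and relation label only.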
gold_annotations_strict = set( + [ + ( + triplet["subject"]["start"], + triplet["subject"]["end"], + triplet["subject"]["type"], + triplet["relation"]["name"], + triplet["object"]["start"], + triplet["object"]["end"], + triplet["object"]["type"], + ) + for triplet in sample.triplets + ] + ) + predicted_spans_strict = set(sample.predicted_entities) + gold_spans_strict = set(sample.entities) + # strict + correct_span_predictions += len( + predicted_spans_strict.intersection(gold_spans_strict) + ) + total_span_predictions += len(predicted_spans_strict) + total_gold_spans += len(gold_spans_strict) + correct_predictions_strict += len( + predicted_annotations_strict.intersection(gold_annotations_strict) + ) + total_predictions_strict += len(predicted_annotations_strict) + + predicted_annotations = set( + [ + ( + triplet["subject"]["start"], + triplet["subject"]["end"], + -1, + triplet["relation"]["name"], + triplet["object"]["start"], + triplet["object"]["end"], + -1, + ) + for triplet in sample.predicted_relations + ] + ) + gold_annotations = set( + [ + ( + triplet["subject"]["start"], + triplet["subject"]["end"], + -1, + triplet["relation"]["name"], + triplet["object"]["start"], + triplet["object"]["end"], + -1, + ) + for triplet in sample.triplets + ] + ) + predicted_spans = set( + [(ss, se) for (ss, se, _) in sample.predicted_entities] + ) + gold_spans = set([(ss, se) for (ss, se, _) in sample.entities]) + total_gold_spans += len(gold_spans) + + correct_predictions_bound += len(predicted_spans.intersection(gold_spans)) + total_predictions_bound += len(predicted_spans) + + total_predictions += len(predicted_annotations) + total_gold += len(gold_annotations) + # correct relation extraction + correct_predictions += len( + predicted_annotations.intersection(gold_annotations) + ) + + span_precision, span_recall, span_f1 = compute_metrics( + correct_span_predictions, total_span_predictions, total_gold_spans + ) + bound_precision, bound_recall, bound_f1 = compute_metrics( + correct_predictions_bound, total_predictions_bound, total_gold_spans + ) + + precision, recall, f1 = compute_metrics( + correct_predictions, total_predictions, total_gold + ) + + if sample.entity_candidates: + precision_strict, recall_strict, f1_strict = compute_metrics( + correct_predictions_strict, total_predictions_strict, total_gold + ) + return { + "span-precision": span_precision, + "span-recall": span_recall, + "span-f1": span_f1, + "precision": precision, + "recall": recall, + "f1": f1, + "precision-strict": precision_strict, + "recall-strict": recall_strict, + "f1-strict": f1_strict, + } + else: + return { + "span-precision": bound_precision, + "span-recall": bound_recall, + "span-f1": bound_f1, + "precision": precision, + "recall": recall, + "f1": f1, + } + + +class REStrongMatchingCallback(Callback): + def __init__(self, dataset_path: str, dataset_conf) -> None: + super().__init__() + self.dataset_path = dataset_path + self.dataset_conf = dataset_conf + self.strong_matching_metric = StrongMatching() + + def on_validation_epoch_start(self, trainer, pl_module) -> None: + relik_reader_predictor = RelikReaderPredictor(pl_module.relik_reader_re_model) + predicted_samples = relik_reader_predictor._predict( + self.dataset_path, + None, + self.dataset_conf, + ) + predicted_samples = list(predicted_samples) + for k, v in self.strong_matching_metric(predicted_samples).items(): + pl_module.log(f"val_{k}", v) diff --git a/relik/reader/utils/save_load_utilities.py b/relik/reader/utils/save_load_utilities.py new file mode 100644 
index 0000000000000000000000000000000000000000..1e635650c1f69c0e223d268f97ec9d6e0677742c --- /dev/null +++ b/relik/reader/utils/save_load_utilities.py @@ -0,0 +1,76 @@ +import argparse +import os +from typing import Tuple + +import omegaconf +import torch + +from relik.common.utils import from_cache +from relik.reader.lightning_modules.relik_reader_pl_module import RelikReaderPLModule +from relik.reader.relik_reader_core import RelikReaderCoreModel + +CKPT_FILE_NAME = "model.ckpt" +CONFIG_FILE_NAME = "cfg.yaml" + + +def convert_pl_module(pl_module_ckpt_path: str, output_dir: str) -> None: + if not os.path.exists(output_dir): + os.makedirs(output_dir) + else: + print(f"{output_dir} already exists, aborting operation") + exit(1) + + relik_pl_module: RelikReaderPLModule = RelikReaderPLModule.load_from_checkpoint( + pl_module_ckpt_path + ) + torch.save( + relik_pl_module.relik_reader_core_model, f"{output_dir}/{CKPT_FILE_NAME}" + ) + with open(f"{output_dir}/{CONFIG_FILE_NAME}", "w") as f: + omegaconf.OmegaConf.save( + omegaconf.OmegaConf.create(relik_pl_module.hparams["cfg"]), f + ) + + +def load_model_and_conf( + model_dir_path: str, +) -> Tuple[RelikReaderCoreModel, omegaconf.DictConfig]: + # TODO: quick workaround to load the model from HF hub + model_dir = from_cache( + model_dir_path, + filenames=[CKPT_FILE_NAME, CONFIG_FILE_NAME], + cache_dir=None, + force_download=False, + ) + + ckpt_path = f"{model_dir}/{CKPT_FILE_NAME}" + model = torch.load(ckpt_path, map_location=torch.device("cpu")) + + model_cfg_path = f"{model_dir}/{CONFIG_FILE_NAME}" + model_conf = omegaconf.OmegaConf.load(model_cfg_path) + return model, model_conf + + +def parse_arg() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument( + "--ckpt", + help="Path to the pytorch lightning ckpt you want to convert.", + required=True, + ) + parser.add_argument( + "--output-dir", + "-o", + help="The output dir to store the bare models and the config.", + required=True, + ) + return parser.parse_args() + + +def main(): + args = parse_arg() + convert_pl_module(args.ckpt, args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/relik/reader/utils/special_symbols.py b/relik/reader/utils/special_symbols.py new file mode 100644 index 0000000000000000000000000000000000000000..170909ad6cb2b69e1d6a8384af34cba441e60ce4 --- /dev/null +++ b/relik/reader/utils/special_symbols.py @@ -0,0 +1,11 @@ +from typing import List + +NME_SYMBOL = "--NME--" + + +def get_special_symbols(num_entities: int) -> List[str]: + return [NME_SYMBOL] + [f"[E-{i}]" for i in range(num_entities)] + + +def get_special_symbols_re(num_entities: int) -> List[str]: + return [NME_SYMBOL] + [f"[R-{i}]" for i in range(num_entities)] diff --git a/relik/reader/utils/strong_matching_eval.py b/relik/reader/utils/strong_matching_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..88ad651134907c8067c174ee5f6fcbbb5fc2cb73 --- /dev/null +++ b/relik/reader/utils/strong_matching_eval.py @@ -0,0 +1,146 @@ +from typing import Dict, List + +from lightning.pytorch.callbacks import Callback +from reader.data.relik_reader_sample import RelikReaderSample + +from relik.reader.relik_reader_predictor import RelikReaderPredictor +from relik.reader.utils.metrics import f1_measure, safe_divide +from relik.reader.utils.special_symbols import NME_SYMBOL + + +class StrongMatching: + def __call__(self, predicted_samples: List[RelikReaderSample]) -> Dict: + # accumulators + correct_predictions = 0 + correct_predictions_at_k = 0 + 
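+ # correct_predictions_at_k counts gold links that are missed by the top prediction but
+ # whose gold entity still appears among the candidate probabilities for that span;
+ # it only feeds the core_recall-at-k metric below.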
total_predictions = 0 + total_gold = 0 + correct_span_predictions = 0 + miss_due_to_candidates = 0 + + # prediction index stats + avg_correct_predicted_index = [] + avg_wrong_predicted_index = [] + less_index_predictions = [] + + # collect data from samples + for sample in predicted_samples: + predicted_annotations = sample.predicted_window_labels_chars + predicted_annotations_probabilities = sample.probs_window_labels_chars + gold_annotations = { + (ss, se, entity) + for ss, se, entity in sample.window_labels + if entity != NME_SYMBOL + } + total_predictions += len(predicted_annotations) + total_gold += len(gold_annotations) + + # correct named entity detection + predicted_spans = {(s, e) for s, e, _ in predicted_annotations} + gold_spans = {(s, e) for s, e, _ in gold_annotations} + correct_span_predictions += len(predicted_spans.intersection(gold_spans)) + + # correct entity linking + correct_predictions += len( + predicted_annotations.intersection(gold_annotations) + ) + + for ss, se, ge in gold_annotations.difference(predicted_annotations): + if ge not in sample.window_candidates: + miss_due_to_candidates += 1 + if ge in predicted_annotations_probabilities.get((ss, se), set()): + correct_predictions_at_k += 1 + + # indices metrics + predicted_spans_index = { + (ss, se): ent for ss, se, ent in predicted_annotations + } + gold_spans_index = {(ss, se): ent for ss, se, ent in gold_annotations} + + for pred_span, pred_ent in predicted_spans_index.items(): + gold_ent = gold_spans_index.get(pred_span) + + if pred_span not in gold_spans_index: + continue + + # missing candidate + if gold_ent not in sample.window_candidates: + continue + + gold_idx = sample.window_candidates.index(gold_ent) + if gold_idx is None: + continue + pred_idx = sample.window_candidates.index(pred_ent) + + if gold_ent != pred_ent: + avg_wrong_predicted_index.append(pred_idx) + + if gold_idx is not None: + if pred_idx > gold_idx: + less_index_predictions.append(0) + else: + less_index_predictions.append(1) + + else: + avg_correct_predicted_index.append(pred_idx) + + # compute NED metrics + span_precision = safe_divide(correct_span_predictions, total_predictions) + span_recall = safe_divide(correct_span_predictions, total_gold) + span_f1 = f1_measure(span_precision, span_recall) + + # compute EL metrics + precision = safe_divide(correct_predictions, total_predictions) + recall = safe_divide(correct_predictions, total_gold) + recall_at_k = safe_divide( + (correct_predictions + correct_predictions_at_k), total_gold + ) + + f1 = f1_measure(precision, recall) + + wrong_for_candidates = safe_divide(miss_due_to_candidates, total_gold) + + out_dict = { + "span_precision": span_precision, + "span_recall": span_recall, + "span_f1": span_f1, + "core_precision": precision, + "core_recall": recall, + "core_recall-at-k": recall_at_k, + "core_f1": round(f1, 4), + "wrong-for-candidates": wrong_for_candidates, + "index_errors_avg-index": safe_divide( + sum(avg_wrong_predicted_index), len(avg_wrong_predicted_index) + ), + "index_correct_avg-index": safe_divide( + sum(avg_correct_predicted_index), len(avg_correct_predicted_index) + ), + "index_avg-index": safe_divide( + sum(avg_correct_predicted_index + avg_wrong_predicted_index), + len(avg_correct_predicted_index + avg_wrong_predicted_index), + ), + "index_percentage-favoured-smaller-idx": safe_divide( + sum(less_index_predictions), len(less_index_predictions) + ), + } + + return {k: round(v, 5) for k, v in out_dict.items()} + + +class ELStrongMatchingCallback(Callback): + def 
__init__(self, dataset_path: str, dataset_conf) -> None: + super().__init__() + self.dataset_path = dataset_path + self.dataset_conf = dataset_conf + self.strong_matching_metric = StrongMatching() + + def on_validation_epoch_start(self, trainer, pl_module) -> None: + relik_reader_predictor = RelikReaderPredictor(pl_module.relik_reader_core_model) + predicted_samples = relik_reader_predictor.predict( + self.dataset_path, + samples=None, + dataset_conf=self.dataset_conf, + ) + predicted_samples = list(predicted_samples) + for k, v in self.strong_matching_metric(predicted_samples).items(): + pl_module.log(f"val_{k}", v) diff --git a/relik/retriever/__init__.py b/relik/retriever/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/relik/retriever/callbacks/__init__.py b/relik/retriever/callbacks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/relik/retriever/callbacks/base.py b/relik/retriever/callbacks/base.py new file mode 100644 index 0000000000000000000000000000000000000000..43042c94bfc93ac32fb60b344ca644cd1c79c1f3 --- /dev/null +++ b/relik/retriever/callbacks/base.py @@ -0,0 +1,168 @@ +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union + +import hydra +import lightning as pl +import torch +from lightning.pytorch.trainer.states import RunningStage +from omegaconf import DictConfig +from torch.utils.data import DataLoader, Dataset + +from relik.common.log import get_logger +from relik.retriever.data.base.datasets import BaseDataset + +logger = get_logger() + + +STAGES_COMPATIBILITY_MAP = { + "train": RunningStage.TRAINING, + "val": RunningStage.VALIDATING, + "test": RunningStage.TESTING, +} + +DEFAULT_STAGES = { + RunningStage.VALIDATING, + RunningStage.TESTING, + RunningStage.SANITY_CHECKING, + RunningStage.PREDICTING, +} + + +class PredictionCallback(pl.Callback): + def __init__( + self, + batch_size: int = 32, + stages: Optional[Set[Union[str, RunningStage]]] = None, + other_callbacks: Optional[ + Union[List[DictConfig], List["NLPTemplateCallback"]] + ] = None, + datasets: Optional[Union[DictConfig, BaseDataset]] = None, + dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + *args, + **kwargs, + ): + super().__init__() + # parameters + self.batch_size = batch_size + self.datasets = datasets + self.dataloaders = dataloaders + + # callback initialization + if stages is None: + stages = DEFAULT_STAGES + + # compatibily stuff + stages = {STAGES_COMPATIBILITY_MAP.get(stage, stage) for stage in stages} + self.stages = [RunningStage(stage) for stage in stages] + self.other_callbacks = other_callbacks or [] + for i, callback in enumerate(self.other_callbacks): + if isinstance(callback, DictConfig): + self.other_callbacks[i] = hydra.utils.instantiate( + callback, _recursive_=False + ) + + @torch.no_grad() + def __call__( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + *args, + **kwargs, + ) -> Any: + # it should return the predictions + raise NotImplementedError + + def on_validation_epoch_end( + self, trainer: pl.Trainer, pl_module: pl.LightningModule + ): + predictions = self(trainer, pl_module) + for callback in self.other_callbacks: + callback( + trainer=trainer, + pl_module=pl_module, + callback=self, + predictions=predictions, + ) + + def on_test_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): + predictions = 
self(trainer, pl_module) + for callback in self.other_callbacks: + callback( + trainer=trainer, + pl_module=pl_module, + callback=self, + predictions=predictions, + ) + + @staticmethod + def _get_datasets_and_dataloaders( + dataset: Optional[Union[Dataset, DictConfig]], + dataloader: Optional[DataLoader], + trainer: pl.Trainer, + dataloader_kwargs: Optional[Dict[str, Any]] = None, + collate_fn: Optional[Callable] = None, + collate_fn_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[List[Dataset], List[DataLoader]]: + """ + Get the datasets and dataloaders from the datamodule or from the dataset provided. + + Args: + dataset (`Optional[Union[Dataset, DictConfig]]`): + The dataset to use. If `None`, the datamodule is used. + dataloader (`Optional[DataLoader]`): + The dataloader to use. If `None`, the datamodule is used. + trainer (`pl.Trainer`): + The trainer that contains the datamodule. + dataloader_kwargs (`Optional[Dict[str, Any]]`): + The kwargs to pass to the dataloader. + collate_fn (`Optional[Callable]`): + The collate function to use. + collate_fn_kwargs (`Optional[Dict[str, Any]]`): + The kwargs to pass to the collate function. + + Returns: + `Tuple[List[Dataset], List[DataLoader]]`: The datasets and dataloaders. + """ + # if a dataset is provided, use it + if dataset is not None: + dataloader_kwargs = dataloader_kwargs or {} + # get dataset + if isinstance(dataset, DictConfig): + dataset = hydra.utils.instantiate(dataset, _recursive_=False) + datasets = [dataset] if not isinstance(dataset, list) else dataset + if dataloader is not None: + dataloaders = ( + [dataloader] if isinstance(dataloader, DataLoader) else dataloader + ) + else: + collate_fn = collate_fn or partial( + datasets[0].collate_fn, **collate_fn_kwargs + ) + dataloader_kwargs["collate_fn"] = collate_fn + dataloaders = [DataLoader(datasets[0], **dataloader_kwargs)] + else: + # get the dataloaders and datasets from the datamodule + datasets = ( + trainer.datamodule.test_datasets + if trainer.state.stage == RunningStage.TESTING + else trainer.datamodule.val_datasets + ) + dataloaders = ( + trainer.test_dataloaders + if trainer.state.stage == RunningStage.TESTING + else trainer.val_dataloaders + ) + return datasets, dataloaders + + +class NLPTemplateCallback: + def __call__( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + callback: PredictionCallback, + predictions: Dict[str, Any], + *args, + **kwargs, + ) -> Any: + raise NotImplementedError diff --git a/relik/retriever/callbacks/evaluation_callbacks.py b/relik/retriever/callbacks/evaluation_callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..72d8fcb306f9e0e0fb58d1dbdb8ea6cfcea8c7e3 --- /dev/null +++ b/relik/retriever/callbacks/evaluation_callbacks.py @@ -0,0 +1,276 @@ +import logging +from typing import Dict, List, Optional + +import lightning as pl +import torch +from lightning.pytorch.trainer.states import RunningStage +from sklearn.metrics import label_ranking_average_precision_score + +from relik.common.log import get_console_logger, get_logger +from relik.retriever.callbacks.base import DEFAULT_STAGES, NLPTemplateCallback + +console_logger = get_console_logger() +logger = get_logger(__name__, level=logging.INFO) + + +class RecallAtKEvaluationCallback(NLPTemplateCallback): + """ + Computes the recall at k for the predictions. Recall at k is computed as the number of + correct predictions in the top k predictions divided by the total number of correct + predictions. 
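+
+ For illustration (made-up values): with `k=2`, `gold = {"a", "b", "c"}` and
+ `predictions = ["a", "d", "b"]`, the top-2 predictions contain one gold item, so the
+ sample contributes 1 hit out of 3 gold items and recall@2 = 1/3.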
+ + Args: + k (`int`): + The number of predictions to consider. + prefix (`str`, `optional`): + The prefix to add to the metrics. + verbose (`bool`, `optional`, defaults to `False`): + Whether to log the metrics. + prog_bar (`bool`, `optional`, defaults to `True`): + Whether to log the metrics to the progress bar. + """ + + def __init__( + self, + k: int = 100, + prefix: Optional[str] = None, + verbose: bool = False, + prog_bar: bool = True, + *args, + **kwargs, + ): + super().__init__() + self.k = k + self.prefix = prefix + self.verbose = verbose + self.prog_bar = prog_bar + + @torch.no_grad() + def __call__( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + predictions: Dict, + *args, + **kwargs, + ) -> dict: + """ + Computes the recall at k for the predictions. + + Args: + trainer (:obj:`lightning.trainer.trainer.Trainer`): + The trainer object. + pl_module (:obj:`lightning.core.lightning.LightningModule`): + The lightning module. + predictions (:obj:`Dict`): + The predictions. + + Returns: + :obj:`Dict`: The computed metrics. + """ + if self.verbose: + logger.info(f"Computing recall@{self.k}") + + # metrics to return + metrics = {} + + stage = trainer.state.stage + if stage not in DEFAULT_STAGES: + raise ValueError( + f"Stage {stage} not supported, only `validate` and `test` are supported." + ) + + for dataloader_idx, samples in predictions.items(): + hits, total = 0, 0 + for sample in samples: + # compute the recall at k + # cut the predictions to the first k elements + predictions = sample["predictions"][: self.k] + hits += len(set(predictions) & set(sample["gold"])) + total += len(set(sample["gold"])) + + # compute the mean recall at k + recall_at_k = hits / total + metrics[f"recall@{self.k}_{dataloader_idx}"] = recall_at_k + metrics[f"recall@{self.k}"] = sum(metrics.values()) / len(metrics) + + if self.prefix is not None: + metrics = {f"{self.prefix}_{k}": v for k, v in metrics.items()} + else: + metrics = {f"{stage.value}_{k}": v for k, v in metrics.items()} + pl_module.log_dict( + metrics, on_step=False, on_epoch=True, prog_bar=self.prog_bar + ) + + if self.verbose: + logger.info( + f"Recall@{self.k} on {stage.value}: {metrics[f'{stage.value}_recall@{self.k}']}" + ) + + return metrics + + +class AvgRankingEvaluationCallback(NLPTemplateCallback): + """ + Computes the average ranking of the gold label in the predictions. Average ranking is + computed as the average of the rank of the gold label in the predictions. + + Args: + k (`int`): + The number of predictions to consider. + prefix (`str`, `optional`): + The prefix to add to the metrics. + stages (`List[str]`, `optional`): + The stages to compute the metrics on. Defaults to `["validate", "test"]`. + verbose (`bool`, `optional`, defaults to `False`): + Whether to log the metrics. + """ + + def __init__( + self, + k: int, + prefix: Optional[str] = None, + stages: Optional[List[str]] = None, + verbose: bool = True, + *args, + **kwargs, + ): + super().__init__() + self.k = k + self.prefix = prefix + self.verbose = verbose + self.stages = ( + [RunningStage(stage) for stage in stages] if stages else DEFAULT_STAGES + ) + + @torch.no_grad() + def __call__( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + predictions: Dict, + *args, + **kwargs, + ) -> dict: + """ + Computes the average ranking of the gold label in the predictions. + + Args: + trainer (:obj:`lightning.trainer.trainer.Trainer`): + The trainer object. + pl_module (:obj:`lightning.core.lightning.LightningModule`): + The lightning module. 
+ predictions (:obj:`Dict`): + The predictions. + + Returns: + :obj:`Dict`: The computed metrics. + """ + if not predictions: + logger.warning("No predictions to compute the AVG Ranking metrics.") + return {} + + if self.verbose: + logger.info(f"Computing AVG Ranking@{self.k}") + + # metrics to return + metrics = {} + + stage = trainer.state.stage + if stage not in self.stages: + raise ValueError( + f"Stage `{stage}` not supported, only `validate` and `test` are supported." + ) + + for dataloader_idx, samples in predictions.items(): + rankings = [] + for sample in samples: + window_candidates = sample["predictions"][: self.k] + window_labels = sample["gold"] + for wl in window_labels: + if wl in window_candidates: + rankings.append(window_candidates.index(wl) + 1) + + avg_ranking = sum(rankings) / len(rankings) if len(rankings) > 0 else 0 + metrics[f"avg_ranking@{self.k}_{dataloader_idx}"] = avg_ranking + if len(metrics) == 0: + metrics[f"avg_ranking@{self.k}"] = 0 + else: + metrics[f"avg_ranking@{self.k}"] = sum(metrics.values()) / len(metrics) + + prefix = self.prefix or stage.value + metrics = { + f"{prefix}_{k}": torch.as_tensor(v, dtype=torch.float32) + for k, v in metrics.items() + } + pl_module.log_dict(metrics, on_step=False, on_epoch=True, prog_bar=False) + + if self.verbose: + logger.info( + f"AVG Ranking@{self.k} on {prefix}: {metrics[f'{prefix}_avg_ranking@{self.k}']}" + ) + + return metrics + + +class LRAPEvaluationCallback(NLPTemplateCallback): + def __init__( + self, + k: int = 100, + prefix: Optional[str] = None, + verbose: bool = False, + prog_bar: bool = True, + *args, + **kwargs, + ): + super().__init__() + self.k = k + self.prefix = prefix + self.verbose = verbose + self.prog_bar = prog_bar + + @torch.no_grad() + def __call__( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + predictions: Dict, + *args, + **kwargs, + ) -> dict: + if self.verbose: + logger.info(f"Computing recall@{self.k}") + + # metrics to return + metrics = {} + + stage = trainer.state.stage + if stage not in DEFAULT_STAGES: + raise ValueError( + f"Stage {stage} not supported, only `validate` and `test` are supported." 
+ ) + + for dataloader_idx, samples in predictions.items(): + scores = [sample["scores"][: self.k] for sample in samples] + golds = [sample["gold"] for sample in samples] + + # compute the mean recall at k + lrap = label_ranking_average_precision_score(golds, scores) + metrics[f"lrap@{self.k}_{dataloader_idx}"] = lrap + metrics[f"lrap@{self.k}"] = sum(metrics.values()) / len(metrics) + + prefix = self.prefix or stage.value + metrics = { + f"{prefix}_{k}": torch.as_tensor(v, dtype=torch.float32) + for k, v in metrics.items() + } + pl_module.log_dict( + metrics, on_step=False, on_epoch=True, prog_bar=self.prog_bar + ) + + if self.verbose: + logger.info( + f"Recall@{self.k} on {stage.value}: {metrics[f'{stage.value}_recall@{self.k}']}" + ) + + return metrics diff --git a/relik/retriever/callbacks/prediction_callbacks.py b/relik/retriever/callbacks/prediction_callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..f8a051ad396d07872dfac05998d1ec550724677a --- /dev/null +++ b/relik/retriever/callbacks/prediction_callbacks.py @@ -0,0 +1,432 @@ +import logging +import random +import time +from copy import deepcopy +from pathlib import Path +from typing import List, Optional, Set, Union + +import lightning as pl +import torch +from lightning.pytorch.trainer.states import RunningStage +from omegaconf import DictConfig +from torch.utils.data import DataLoader +from tqdm import tqdm + +from relik.common.log import get_console_logger, get_logger +from relik.retriever.callbacks.base import PredictionCallback +from relik.retriever.common.model_inputs import ModelInputs +from relik.retriever.data.base.datasets import BaseDataset +from relik.retriever.data.datasets import GoldenRetrieverDataset +from relik.retriever.data.utils import HardNegativesManager +from relik.retriever.indexers.base import BaseDocumentIndex +from relik.retriever.pytorch_modules.model import GoldenRetriever + +console_logger = get_console_logger() +logger = get_logger(__name__, level=logging.INFO) + + +class GoldenRetrieverPredictionCallback(PredictionCallback): + def __init__( + self, + k: Optional[int] = None, + batch_size: int = 32, + num_workers: int = 8, + document_index: Optional[BaseDocumentIndex] = None, + precision: Union[str, int] = 32, + force_reindex: bool = True, + retriever_dir: Optional[Path] = None, + stages: Optional[Set[Union[str, RunningStage]]] = None, + other_callbacks: Optional[List[DictConfig]] = None, + dataset: Optional[Union[DictConfig, BaseDataset]] = None, + dataloader: Optional[DataLoader] = None, + *args, + **kwargs, + ): + super().__init__(batch_size, stages, other_callbacks, dataset, dataloader) + self.k = k + self.num_workers = num_workers + self.document_index = document_index + self.precision = precision + self.force_reindex = force_reindex + self.retriever_dir = retriever_dir + + @torch.no_grad() + def __call__( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + datasets: Optional[ + Union[DictConfig, BaseDataset, List[DictConfig], List[BaseDataset]] + ] = None, + dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + *args, + **kwargs, + ) -> dict: + stage = trainer.state.stage + logger.info(f"Computing predictions for stage {stage.value}") + if stage not in self.stages: + raise ValueError( + f"Stage `{stage}` not supported, only {self.stages} are supported" + ) + + self.datasets, self.dataloaders = self._get_datasets_and_dataloaders( + datasets, + dataloaders, + trainer, + dataloader_kwargs=dict( + batch_size=self.batch_size, + 
num_workers=self.num_workers, + pin_memory=True, + shuffle=False, + ), + ) + + # set the model to eval mode + pl_module.eval() + # get the retriever + retriever: GoldenRetriever = pl_module.model + + # here we will store the samples with predictions for each dataloader + dataloader_predictions = {} + # compute the passage embeddings index for each dataloader + for dataloader_idx, dataloader in enumerate(self.dataloaders): + current_dataset: GoldenRetrieverDataset = self.datasets[dataloader_idx] + logger.info( + f"Computing passage embeddings for dataset {current_dataset.name}" + ) + # passages = self._get_passages_dataloader(current_dataset, trainer) + + tokenizer = current_dataset.tokenizer + + def collate_fn(x): + return ModelInputs( + tokenizer( + x, + truncation=True, + padding=True, + max_length=current_dataset.max_passage_length, + return_tensors="pt", + ) + ) + + # check if we need to reindex the passages and + # also if we need to load the retriever from disk + if (self.retriever_dir is not None and trainer.current_epoch == 0) or ( + self.retriever_dir is not None and stage == RunningStage.TESTING + ): + force_reindex = False + else: + force_reindex = self.force_reindex + + if ( + not force_reindex + and self.retriever_dir is not None + and stage == RunningStage.TESTING + ): + retriever = retriever.from_pretrained(self.retriever_dir) + # set the retriever to eval mode if we are loading it from disk + + # you never know :) + retriever.eval() + + retriever.index( + batch_size=self.batch_size, + num_workers=self.num_workers, + max_length=current_dataset.max_passage_length, + collate_fn=collate_fn, + precision=self.precision, + compute_on_cpu=False, + force_reindex=force_reindex, + ) + + # pl_module_original_device = pl_module.device + # if ( + # and pl_module.device.type == "cuda" + # ): + # pl_module.to("cpu") + + # now compute the question embeddings and compute the top-k accuracy + predictions = [] + start = time.time() + for batch in tqdm( + dataloader, + desc=f"Computing predictions for dataset {current_dataset.name}", + ): + batch = batch.to(pl_module.device) + # get the top-k indices + retriever_output = retriever.retrieve( + **batch.questions, k=self.k, precision=self.precision + ) + # compute recall at k + for batch_idx, retrieved_samples in enumerate(retriever_output): + # get the positive passages + gold_passages = batch["positives"][batch_idx] + # get the index of the gold passages in the retrieved passages + gold_passage_indices = [ + retriever.get_index_from_passage(passage) + for passage in gold_passages + ] + retrieved_indices = [r.index for r in retrieved_samples] + retrieved_passages = [r.label for r in retrieved_samples] + retrieved_scores = [r.score for r in retrieved_samples] + # correct predictions are the passages that are in the top-k and are gold + correct_indices = set(gold_passage_indices) & set(retrieved_indices) + # wrong predictions are the passages that are in the top-k and are not gold + wrong_indices = set(retrieved_indices) - set(gold_passage_indices) + # add the predictions to the list + prediction_output = dict( + sample_idx=batch.sample_idx[batch_idx], + gold=gold_passages, + predictions=retrieved_passages, + scores=retrieved_scores, + correct=[ + retriever.get_passage_from_index(i) for i in correct_indices + ], + wrong=[ + retriever.get_passage_from_index(i) for i in wrong_indices + ], + ) + predictions.append(prediction_output) + end = time.time() + logger.info(f"Time to retrieve: {str(end - start)}") + + dataloader_predictions[dataloader_idx] = 
predictions + + # if pl_module_original_device != pl_module.device: + # pl_module.to(pl_module_original_device) + + # return the predictions + return dataloader_predictions + + # @staticmethod + # def _get_passages_dataloader( + # indexer: Optional[BaseIndexer] = None, + # dataset: Optional[GoldenRetrieverDataset] = None, + # trainer: Optional[pl.Trainer] = None, + # ): + # if indexer is None: + # logger.info( + # f"Indexer is None, creating indexer from passages not found in dataset {dataset.name}, computing them from the dataloaders" + # ) + # # get the passages from the all the dataloader passage ids + # passages = set() # set to avoid duplicates + # for batch in trainer.train_dataloader: + # passages.update( + # [ + # " ".join(map(str, [c for c in passage_ids.tolist() if c != 0])) + # for passage_ids in batch["passages"]["input_ids"] + # ] + # ) + # for d in trainer.val_dataloaders: + # for batch in d: + # passages.update( + # [ + # " ".join( + # map(str, [c for c in passage_ids.tolist() if c != 0]) + # ) + # for passage_ids in batch["passages"]["input_ids"] + # ] + # ) + # for d in trainer.test_dataloaders: + # for batch in d: + # passages.update( + # [ + # " ".join( + # map(str, [c for c in passage_ids.tolist() if c != 0]) + # ) + # for passage_ids in batch["passages"]["input_ids"] + # ] + # ) + # passages = list(passages) + # else: + # passages = dataset.passages + # return passages + + +class NegativeAugmentationCallback(GoldenRetrieverPredictionCallback): + """ + Callback that computes the predictions of a retriever model on a dataset and computes the + negative examples for the training set. + + Args: + k (:obj:`int`, `optional`, defaults to 100): + The number of top-k retrieved passages to + consider for the evaluation. + batch_size (:obj:`int`, `optional`, defaults to 32): + The batch size to use for the evaluation. + num_workers (:obj:`int`, `optional`, defaults to 0): + The number of workers to use for the evaluation. + force_reindex (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to force the reindexing of the dataset. + retriever_dir (:obj:`Path`, `optional`): + The path to the retriever directory. If not specified, the retriever will be + initialized from scratch. + stages (:obj:`Set[str]`, `optional`): + The stages to run the callback on. If not specified, the callback will be run on + train, validation and test. + other_callbacks (:obj:`List[DictConfig]`, `optional`): + A list of other callbacks to run on the same stages. + dataset (:obj:`Union[DictConfig, BaseDataset]`, `optional`): + The dataset to use for the evaluation. If not specified, the dataset will be + initialized from scratch. + metrics_to_monitor (:obj:`List[str]`, `optional`): + The metrics to monitor for the evaluation. + threshold (:obj:`float`, `optional`, defaults to 0.8): + The threshold to consider. If the recall score of the retriever is above the + threshold, the negative examples will be added to the training set. + max_negatives (:obj:`int`, `optional`, defaults to 5): + The maximum number of negative examples to add to the training set. + add_with_probability (:obj:`float`, `optional`, defaults to 1.0): + The probability with which to add the negative examples to the training set. + refresh_every_n_epochs (:obj:`int`, `optional`, defaults to 1): + The number of epochs after which to refresh the index. 
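+
+ In short, as implemented below: when the monitored metric is at or above `threshold`
+ (and only every `refresh_every_n_epochs` epochs), the callback re-runs retrieval over a
+ copy of the training set and, for each sample (kept with probability `add_with_probability`),
+ collects up to `max_negatives` of the top-k retrieved passages that are not gold,
+ registering them on the training dataset as hard negatives via a `HardNegativesManager`.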
+ """ + + def __init__( + self, + k: int = 100, + batch_size: int = 32, + num_workers: int = 0, + force_reindex: bool = False, + retriever_dir: Optional[Path] = None, + stages: Set[Union[str, RunningStage]] = None, + other_callbacks: Optional[List[DictConfig]] = None, + dataset: Optional[Union[DictConfig, BaseDataset]] = None, + metrics_to_monitor: List[str] = None, + threshold: float = 0.8, + max_negatives: int = 5, + add_with_probability: float = 1.0, + refresh_every_n_epochs: int = 1, + *args, + **kwargs, + ): + super().__init__( + k=k, + batch_size=batch_size, + num_workers=num_workers, + force_reindex=force_reindex, + retriever_dir=retriever_dir, + stages=stages, + other_callbacks=other_callbacks, + dataset=dataset, + *args, + **kwargs, + ) + if metrics_to_monitor is None: + metrics_to_monitor = ["val_loss"] + self.metrics_to_monitor = metrics_to_monitor + self.threshold = threshold + self.max_negatives = max_negatives + self.add_with_probability = add_with_probability + self.refresh_every_n_epochs = refresh_every_n_epochs + + @torch.no_grad() + def __call__( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + *args, + **kwargs, + ) -> dict: + """ + Computes the predictions of a retriever model on a dataset and computes the negative + examples for the training set. + + Args: + trainer (:obj:`pl.Trainer`): + The trainer object. + pl_module (:obj:`pl.LightningModule`): + The lightning module. + + Returns: + A dictionary containing the negative examples. + """ + stage = trainer.state.stage + if stage not in self.stages: + return {} + + if self.metrics_to_monitor not in trainer.logged_metrics: + raise ValueError( + f"Metric `{self.metrics_to_monitor}` not found in trainer.logged_metrics" + f"Available metrics: {trainer.logged_metrics.keys()}" + ) + if trainer.logged_metrics[self.metrics_to_monitor] < self.threshold: + return {} + + if trainer.current_epoch % self.refresh_every_n_epochs != 0: + return {} + + # if all( + # [ + # trainer.logged_metrics.get(metric) is None + # for metric in self.metrics_to_monitor + # ] + # ): + # raise ValueError( + # f"No metric from {self.metrics_to_monitor} not found in trainer.logged_metrics" + # f"Available metrics: {trainer.logged_metrics.keys()}" + # ) + + # if all( + # [ + # trainer.logged_metrics.get(metric) < self.threshold + # for metric in self.metrics_to_monitor + # if trainer.logged_metrics.get(metric) is not None + # ] + # ): + # return {} + + if trainer.current_epoch % self.refresh_every_n_epochs != 0: + return {} + + logger.info( + f"At least one metric from {self.metrics_to_monitor} is above threshold " + f"{self.threshold}. Computing hard negatives." 
+ ) + + # make a copy of the dataset to avoid modifying the original one + trainer.datamodule.train_dataset.hn_manager = None + dataset_copy = deepcopy(trainer.datamodule.train_dataset) + predictions = super().__call__( + trainer, + pl_module, + datasets=dataset_copy, + dataloaders=DataLoader( + dataset_copy.to_torch_dataset(), + shuffle=False, + batch_size=None, + num_workers=self.num_workers, + pin_memory=True, + collate_fn=lambda x: x, + ), + *args, + **kwargs, + ) + logger.info(f"Computing hard negatives for epoch {trainer.current_epoch}") + # predictions is a dict with the dataloader index as key and the predictions as value + # since we only have one dataloader, we can get the predictions directly + predictions = list(predictions.values())[0] + # store the predictions in a dictionary for faster access based on the sample index + hard_negatives_list = {} + for prediction in tqdm(predictions, desc="Collecting hard negatives"): + if random.random() < 1 - self.add_with_probability: + continue + top_k_passages = prediction["predictions"] + gold_passages = prediction["gold"] + # get the ids of the max_negatives wrong passages with the highest similarity + wrong_passages = [ + passage_id + for passage_id in top_k_passages + if passage_id not in gold_passages + ][: self.max_negatives] + hard_negatives_list[prediction["sample_idx"]] = wrong_passages + + trainer.datamodule.train_dataset.hn_manager = HardNegativesManager( + tokenizer=trainer.datamodule.train_dataset.tokenizer, + max_length=trainer.datamodule.train_dataset.max_passage_length, + data=hard_negatives_list, + ) + + # normalize predictions as in the original GoldenRetrieverPredictionCallback + predictions = {0: predictions} + return predictions diff --git a/relik/retriever/callbacks/utils_callbacks.py b/relik/retriever/callbacks/utils_callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..ba73e0d9ee02d9e1424611551befc002bdaaecf3 --- /dev/null +++ b/relik/retriever/callbacks/utils_callbacks.py @@ -0,0 +1,287 @@ +import json +import logging +import os +from pathlib import Path +from typing import Any, Dict, Optional, Union + +import lightning as pl +import torch +from lightning.pytorch.trainer.states import RunningStage + +from relik.common.log import get_console_logger, get_logger +from relik.retriever.callbacks.base import NLPTemplateCallback, PredictionCallback +from relik.retriever.pytorch_modules.hf import GoldenRetrieverModel + +console_logger = get_console_logger() +logger = get_logger(__name__, level=logging.INFO) + + +class SavePredictionsCallback(NLPTemplateCallback): + def __init__( + self, + saving_dir: Optional[Union[str, os.PathLike]] = None, + verbose: bool = False, + *args, + **kwargs, + ): + super().__init__() + self.saving_dir = saving_dir + self.verbose = verbose + + @torch.no_grad() + def __call__( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + predictions: Dict, + callback: PredictionCallback, + *args, + **kwargs, + ) -> dict: + # write the predictions to a file inside the experiment folder + if self.saving_dir is None and trainer.logger is None: + logger.info( + "You need to specify an output directory (`saving_dir`) or a logger to save the predictions.\n" + "Skipping saving predictions." 
+ ) + return + datasets = callback.datasets + for dataloader_idx, dataloader_predictions in predictions.items(): + # save to file + if self.saving_dir is not None: + prediction_folder = Path(self.saving_dir) + else: + try: + prediction_folder = ( + Path(trainer.logger.experiment.dir) / "predictions" + ) + except Exception: + logger.info( + "You need to specify an output directory (`saving_dir`) or a logger to save the predictions.\n" + "Skipping saving predictions." + ) + return + prediction_folder.mkdir(exist_ok=True) + predictions_path = ( + prediction_folder + / f"{datasets[dataloader_idx].name}_{dataloader_idx}.json" + ) + if self.verbose: + logger.info(f"Saving predictions to {predictions_path}") + with open(predictions_path, "w") as f: + for prediction in dataloader_predictions: + for k, v in prediction.items(): + if isinstance(v, set): + prediction[k] = list(v) + f.write(json.dumps(prediction) + "\n") + + +class ResetModelCallback(pl.Callback): + def __init__( + self, + question_encoder: str, + passage_encoder: Optional[str] = None, + verbose: bool = True, + ) -> None: + super().__init__() + self.question_encoder = question_encoder + self.passage_encoder = passage_encoder + self.verbose = verbose + + def on_train_epoch_start( + self, trainer: pl.Trainer, pl_module: pl.LightningModule, *args, **kwargs + ) -> None: + if trainer.current_epoch == 0: + if self.verbose: + logger.info("Current epoch is 0, skipping resetting model") + return + + if self.verbose: + logger.info("Resetting model, optimizer and lr scheduler") + # reload model from scratch + previous_device = pl_module.device + trainer.model.model.question_encoder = GoldenRetrieverModel.from_pretrained( + self.question_encoder + ) + trainer.model.model.question_encoder.to(previous_device) + if self.passage_encoder is not None: + trainer.model.model.passage_encoder = GoldenRetrieverModel.from_pretrained( + self.passage_encoder + ) + trainer.model.model.passage_encoder.to(previous_device) + + trainer.strategy.setup_optimizers(trainer) + + +class FreeUpIndexerVRAMCallback(pl.Callback): + def __call__( + self, + pl_module: pl.LightningModule, + *args, + **kwargs, + ) -> Any: + logger.info("Freeing up GPU memory") + + # remove the index from the GPU memory + # remove the embeddings from the GPU memory first + if pl_module.model.document_index is not None: + if pl_module.model.document_index.embeddings is not None: + pl_module.model.document_index.embeddings.cpu() + pl_module.model.document_index.embeddings = None + + import gc + + gc.collect() + torch.cuda.empty_cache() + + def on_train_epoch_start( + self, trainer: pl.Trainer, pl_module: pl.LightningModule, *args, **kwargs + ) -> None: + return self(pl_module) + + def on_test_epoch_start( + self, trainer: pl.Trainer, pl_module: pl.LightningModule, *args, **kwargs + ) -> None: + return self(pl_module) + + +class ShuffleTrainDatasetCallback(pl.Callback): + def __init__(self, seed: int = 42, verbose: bool = True) -> None: + super().__init__() + self.seed = seed + self.verbose = verbose + self.previous_epoch = -1 + + def on_validation_epoch_end(self, trainer: pl.Trainer, *args, **kwargs): + if self.verbose: + if trainer.current_epoch != self.previous_epoch: + logger.info(f"Shuffling train dataset at epoch {trainer.current_epoch}") + + # logger.info(f"Shuffling train dataset at epoch {trainer.current_epoch}") + if trainer.current_epoch != self.previous_epoch: + trainer.datamodule.train_dataset.shuffle_data( + seed=self.seed + trainer.current_epoch + 1 + ) + self.previous_epoch = 
trainer.current_epoch + + +class PrefetchTrainDatasetCallback(pl.Callback): + def __init__(self, verbose: bool = True) -> None: + super().__init__() + self.verbose = verbose + # self.previous_epoch = -1 + + def on_validation_epoch_end(self, trainer: pl.Trainer, *args, **kwargs): + if trainer.datamodule.train_dataset.prefetch_batches: + if self.verbose: + # if trainer.current_epoch != self.previous_epoch: + logger.info( + f"Prefetching train dataset at epoch {trainer.current_epoch}" + ) + # if trainer.current_epoch != self.previous_epoch: + trainer.datamodule.train_dataset.prefetch() + self.previous_epoch = trainer.current_epoch + + +class SubsampleTrainDatasetCallback(pl.Callback): + def __init__(self, seed: int = 43, verbose: bool = True) -> None: + super().__init__() + self.seed = seed + self.verbose = verbose + + def on_validation_epoch_end(self, trainer: pl.Trainer, *args, **kwargs): + if self.verbose: + logger.info(f"Subsampling train dataset at epoch {trainer.current_epoch}") + trainer.datamodule.train_dataset.random_subsample( + seed=self.seed + trainer.current_epoch + 1 + ) + + +class SaveRetrieverCallback(pl.Callback): + def __init__( + self, + saving_dir: Optional[Union[str, os.PathLike]] = None, + verbose: bool = True, + *args, + **kwargs, + ): + super().__init__() + self.saving_dir = saving_dir + self.verbose = verbose + self.free_up_indexer_callback = FreeUpIndexerVRAMCallback() + + @torch.no_grad() + def __call__( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + *args, + **kwargs, + ): + if self.saving_dir is None and trainer.logger is None: + logger.info( + "You need to specify an output directory (`saving_dir`) or a logger to save the retriever.\n" + "Skipping saving retriever." + ) + return + if self.saving_dir is not None: + retriever_folder = Path(self.saving_dir) + else: + try: + retriever_folder = Path(trainer.logger.experiment.dir) / "retriever" + except Exception: + logger.info( + "You need to specify an output directory (`saving_dir`) or a logger to save the retriever.\n" + "Skipping saving retriever." 
+ ) + return + retriever_folder.mkdir(exist_ok=True, parents=True) + if self.verbose: + logger.info(f"Saving retriever to {retriever_folder}") + pl_module.model.save_pretrained(retriever_folder) + + def on_save_checkpoint( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + checkpoint: Dict[str, Any], + ): + self(trainer, pl_module) + # self.free_up_indexer_callback(pl_module) + + +class SampleNegativesDatasetCallback(pl.Callback): + def __init__(self, seed: int = 42, verbose: bool = True) -> None: + super().__init__() + self.seed = seed + self.verbose = verbose + + def on_validation_epoch_end(self, trainer: pl.Trainer, *args, **kwargs): + if self.verbose: + f"Sampling negatives for train dataset at epoch {trainer.current_epoch}" + trainer.datamodule.train_dataset.sample_dataset_negatives( + seed=self.seed + trainer.current_epoch + ) + + +class SubsampleDataCallback(pl.Callback): + def __init__(self, seed: int = 42, verbose: bool = True) -> None: + super().__init__() + self.seed = seed + self.verbose = verbose + + def on_validation_epoch_start(self, trainer: pl.Trainer, *args, **kwargs): + if self.verbose: + f"Subsampling data for train dataset at epoch {trainer.current_epoch}" + if trainer.state.stage == RunningStage.SANITY_CHECKING: + return + trainer.datamodule.train_dataset.subsample_data( + seed=self.seed + trainer.current_epoch + ) + + def on_fit_start(self, trainer: pl.Trainer, *args, **kwargs): + if self.verbose: + f"Subsampling data for train dataset at epoch {trainer.current_epoch}" + trainer.datamodule.train_dataset.subsample_data( + seed=self.seed + trainer.current_epoch + ) diff --git a/relik/retriever/common/__init__.py b/relik/retriever/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/relik/retriever/common/model_inputs.py b/relik/retriever/common/model_inputs.py new file mode 100644 index 0000000000000000000000000000000000000000..06b28f74e9594df32f98f8a32a0b46177db54062 --- /dev/null +++ b/relik/retriever/common/model_inputs.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from collections import UserDict +from typing import Any, Union + +import torch +from lightning.fabric.utilities import move_data_to_device + +from relik.common.log import get_console_logger + +logger = get_console_logger() + + +class ModelInputs(UserDict): + """Model input dictionary wrapper.""" + + def __getattr__(self, item: str): + try: + return self.data[item] + except KeyError: + raise AttributeError(f"`ModelInputs` has no attribute `{item}`") + + def __getitem__(self, item: str) -> Any: + return self.data[item] + + def __getstate__(self): + return {"data": self.data} + + def __setstate__(self, state): + if "data" in state: + self.data = state["data"] + + def keys(self): + """A set-like object providing a view on D's keys.""" + return self.data.keys() + + def values(self): + """An object providing a view on D's values.""" + return self.data.values() + + def items(self): + """A set-like object providing a view on D's items.""" + return self.data.items() + + def to(self, device: Union[str, torch.device]) -> ModelInputs: + """ + Send all tensors values to device. + Args: + device (`str` or `torch.device`): The device to put the tensors on. + Returns: + :class:`tokenizers.ModelInputs`: The same instance of :class:`~tokenizers.ModelInputs` + after modification. 
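+        Example:
+            A minimal usage sketch (assumes a CUDA device is available):
+
+            >>> inputs = ModelInputs({"input_ids": torch.tensor([[1, 2, 3]])})
+            >>> inputs = inputs.to("cuda")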
+ """ + self.data = move_data_to_device(self.data, device) + return self diff --git a/relik/retriever/common/sampler.py b/relik/retriever/common/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..024c57b23da6db71dd76929226b005f75b9e98f5 --- /dev/null +++ b/relik/retriever/common/sampler.py @@ -0,0 +1,108 @@ +import math + +from torch.utils.data.sampler import BatchSampler, Sampler, SubsetRandomSampler + + +def identity(x): + return x + + +class SortedSampler(Sampler): + """ + Samples elements sequentially, always in the same order. + + Args: + data (`obj`: `Iterable`): + Iterable data. + sort_key (`obj`: `Callable`): + Specifies a function of one argument that is used to + extract a numerical comparison key from each list element. + + Example: + >>> list(SortedSampler(range(10), sort_key=lambda i: -i)) + [9, 8, 7, 6, 5, 4, 3, 2, 1, 0] + + """ + + def __init__(self, data, sort_key=identity): + super().__init__(data) + self.data = data + self.sort_key = sort_key + zip_ = [(i, self.sort_key(row)) for i, row in enumerate(self.data)] + zip_ = sorted(zip_, key=lambda r: r[1]) + self.sorted_indexes = [item[0] for item in zip_] + + def __iter__(self): + return iter(self.sorted_indexes) + + def __len__(self): + return len(self.data) + + +class BucketBatchSampler(BatchSampler): + """ + `BucketBatchSampler` toggles between `sampler` batches and sorted batches. + Typically, the `sampler` will be a `RandomSampler` allowing the user to toggle between + random batches and sorted batches. A larger `bucket_size_multiplier` is more sorted and vice + versa. + Background: + ``BucketBatchSampler`` is similar to a ``BucketIterator`` found in popular libraries like + ``AllenNLP`` and ``torchtext``. A ``BucketIterator`` pools together examples with a similar + size length to reduce the padding required for each batch while maintaining some noise + through bucketing. + **AllenNLP Implementation:** + https://github.com/allenai/allennlp/blob/master/allennlp/data/iterators/bucket_iterator.py + **torchtext Implementation:** + https://github.com/pytorch/text/blob/master/torchtext/data/iterator.py#L225 + + Args: + sampler (`obj`: `torch.data.utils.sampler.Sampler): + batch_size (`int`): + Size of mini-batch. + drop_last (`bool`, optional, defaults to `False`): + If `True` the sampler will drop the last batch if its size would be less than `batch_size`. + sort_key (`obj`: `Callable`, optional, defaults to `identity`): + Callable to specify a comparison key for sorting. + bucket_size_multiplier (`int`, optional, defaults to `100`): + Buckets are of size `batch_size * bucket_size_multiplier`. 
+ Example: + >>> from torchnlp.random import set_seed + >>> set_seed(123) + >>> + >>> from torch.utils.data.sampler import SequentialSampler + >>> sampler = SequentialSampler(list(range(10))) + >>> list(BucketBatchSampler(sampler, batch_size=3, drop_last=False)) + [[6, 7, 8], [0, 1, 2], [3, 4, 5], [9]] + >>> list(BucketBatchSampler(sampler, batch_size=3, drop_last=True)) + [[0, 1, 2], [3, 4, 5], [6, 7, 8]] + + """ + + def __init__( + self, + sampler, + batch_size, + drop_last: bool = False, + sort_key=identity, + bucket_size_multiplier=100, + ): + super().__init__(sampler, batch_size, drop_last) + self.sort_key = sort_key + _bucket_size = batch_size * bucket_size_multiplier + if hasattr(sampler, "__len__"): + _bucket_size = min(_bucket_size, len(sampler)) + self.bucket_sampler = BatchSampler(sampler, _bucket_size, False) + + def __iter__(self): + for bucket in self.bucket_sampler: + sorted_sampler = SortedSampler(bucket, self.sort_key) + for batch in SubsetRandomSampler( + list(BatchSampler(sorted_sampler, self.batch_size, self.drop_last)) + ): + yield [bucket[i] for i in batch] + + def __len__(self): + if self.drop_last: + return len(self.sampler) // self.batch_size + else: + return math.ceil(len(self.sampler) / self.batch_size) diff --git a/relik/retriever/conf/data/aida_dataset.yaml b/relik/retriever/conf/data/aida_dataset.yaml new file mode 100644 index 0000000000000000000000000000000000000000..22fcc6458a2dc757b569baef37846751dd3c1c7a --- /dev/null +++ b/relik/retriever/conf/data/aida_dataset.yaml @@ -0,0 +1,47 @@ +shared_params: + passages_path: null + max_passage_length: 64 + passage_batch_size: 64 + question_batch_size: 64 + use_topics: False + +datamodule: + _target_: relik.retriever.lightning_modules.pl_data_modules.GoldenRetrieverPLDataModule + datasets: + train: + _target_: relik.retriever.data.datasets.AidaInBatchNegativesDataset + name: "train" + path: null + tokenizer: ${model.language_model} + max_passage_length: ${data.shared_params.max_passage_length} + question_batch_size: ${data.shared_params.question_batch_size} + passage_batch_size: ${data.shared_params.passage_batch_size} + subsample_strategy: null + subsample_portion: 0.1 + shuffle: True + use_topics: ${data.shared_params.use_topics} + + val: + - _target_: relik.retriever.data.datasets.AidaInBatchNegativesDataset + name: "val" + path: null + tokenizer: ${model.language_model} + max_passage_length: ${data.shared_params.max_passage_length} + question_batch_size: ${data.shared_params.question_batch_size} + passage_batch_size: ${data.shared_params.passage_batch_size} + use_topics: ${data.shared_params.use_topics} + + test: + - _target_: relik.retriever.data.datasets.AidaInBatchNegativesDataset + name: "test" + path: null + tokenizer: ${model.language_model} + max_passage_length: ${data.shared_params.max_passage_length} + question_batch_size: ${data.shared_params.question_batch_size} + passage_batch_size: ${data.shared_params.passage_batch_size} + use_topics: ${data.shared_params.use_topics} + + num_workers: + train: 4 + val: 4 + test: 4 diff --git a/relik/retriever/conf/data/dataset_v2.yaml b/relik/retriever/conf/data/dataset_v2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6040616d1f92182c2d002c065a7b89dc240f96b3 --- /dev/null +++ b/relik/retriever/conf/data/dataset_v2.yaml @@ -0,0 +1,43 @@ +shared_params: + passages_path: null + max_passage_length: 64 + passage_batch_size: 64 + question_batch_size: 64 + +datamodule: + _target_: 
relik.retriever.lightning_modules.pl_data_modules.GoldenRetrieverPLDataModule + datasets: + train: + _target_: relik.retriever.data.datasets.InBatchNegativesDataset + name: "train" + path: null + tokenizer: ${model.language_model} + max_passage_length: ${data.shared_params.max_passage_length} + question_batch_size: ${data.shared_params.question_batch_size} + passage_batch_size: ${data.shared_params.passage_batch_size} + subsample_strategy: null + subsample_portion: 0.1 + shuffle: True + + val: + - _target_: relik.retriever.data.datasets.InBatchNegativesDataset + name: "val" + path: null + tokenizer: ${model.language_model} + max_passage_length: ${data.shared_params.max_passage_length} + question_batch_size: ${data.shared_params.question_batch_size} + passage_batch_size: ${data.shared_params.passage_batch_size} + + test: + - _target_: relik.retriever.data.datasets.InBatchNegativesDataset + name: "test" + path: null + tokenizer: ${model.language_model} + max_passage_length: ${data.shared_params.max_passage_length} + question_batch_size: ${data.shared_params.question_batch_size} + passage_batch_size: ${data.shared_params.passage_batch_size} + + num_workers: + train: 0 + val: 0 + test: 0 diff --git a/relik/retriever/conf/data/dpr_like.yaml b/relik/retriever/conf/data/dpr_like.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d316b9868b0e64ea3e58725fc56549a2b31436be --- /dev/null +++ b/relik/retriever/conf/data/dpr_like.yaml @@ -0,0 +1,31 @@ +datamodule: + _target_: relik.retriever.goldenretriever.lightning_modules.pl_data_modules.PLDataModule + tokenizer: ${model.language_model} + datasets: + train: + _target_: relik.retriever.data.dpr.datasets.DPRDataset + name: "train" + passages_path: ${data_overrides.passages_path} + path: ${data_overrides.train_path} + + val: + - _target_: relik.retriever.data.dpr.datasets.DPRDataset + name: "val" + passages_path: ${data_overrides.passages_path} + path: ${data_overrides.val_path} + + test: + - _target_: relik.retriever.data.dpr.datasets.DPRDataset + name: "test" + passages_path: ${data_overrides.passages_path} + path: ${data_overrides.test_path} + + batch_sizes: + train: 32 + val: 64 + test: 64 + + num_workers: + train: 4 + val: 4 + test: 4 diff --git a/relik/retriever/conf/data/in_batch_negatives.yaml b/relik/retriever/conf/data/in_batch_negatives.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1a90cb4619604e9d93df47f490f5948ebdfbf312 --- /dev/null +++ b/relik/retriever/conf/data/in_batch_negatives.yaml @@ -0,0 +1,48 @@ +shared_params: + passages_path: null + max_passage_length: 64 + prefetch_batches: True + use_topics: False + +datamodule: + _target_: goldenretriever.lightning_modules.pl_data_modules.PLDataModule + tokenizer: ${model.language_model} + datasets: + train: + _target_: goldenretriever.data.dpr.datasets.InBatchNegativesDPRDataset + name: "train" + path: null + passages_path: ${data.shared_params.passages_path} + max_passage_length: ${data.shared_params.max_passage_length} + prefetch_batches: ${data.shared_params.prefetch_batches} + subsample: null + shuffle: True + use_topics: ${data.shared_params.use_topics} + + val: + - _target_: goldenretriever.data.dpr.datasets.InBatchNegativesDPRDataset + name: "val" + path: null + passages_path: ${data.shared_params.passages_path} + max_passage_length: ${data.shared_params.max_passage_length} + prefetch_batches: ${data.shared_params.prefetch_batches} + use_topics: ${data.shared_params.use_topics} + + test: + - _target_: 
goldenretriever.data.dpr.datasets.InBatchNegativesDPRDataset + name: "test" + path: null + passages_path: ${data.shared_params.passages_path} + max_passage_length: ${data.shared_params.max_passage_length} + prefetch_batches: ${data.shared_params.prefetch_batches} + use_topics: ${data.shared_params.use_topics} + + batch_sizes: + train: 64 + val: 64 + test: 64 + + num_workers: + train: 4 + val: 4 + test: 4 diff --git a/relik/retriever/conf/data/iterable_in_batch_negatives.yaml b/relik/retriever/conf/data/iterable_in_batch_negatives.yaml new file mode 100644 index 0000000000000000000000000000000000000000..397fda31c4b42ef5cd22c10805200f4d19cd590d --- /dev/null +++ b/relik/retriever/conf/data/iterable_in_batch_negatives.yaml @@ -0,0 +1,52 @@ +shared_params: + passages_path: null + max_passage_length: 64 + max_passages_per_batch: 64 + max_questions_per_batch: 64 + prefetch_batches: True + use_topics: False + +datamodule: + _target_: relik.retriever.lightning_modules.pl_data_modules.PLDataModule + tokenizer: ${model.language_model} + datasets: + train: + _target_: relik.retriever.data.dpr.datasets.InBatchNegativesDPRIterableDataset + name: "train" + path: null + passages_path: ${data.shared_params.passages_path} + max_passage_length: ${data.shared_params.max_passage_length} + max_questions_per_batch: ${data.shared_params.max_questions_per_batch} + max_passages_per_batch: ${data.shared_params.max_passages_per_batch} + prefetch_batches: ${data.shared_params.prefetch_batches} + subsample: null + random_subsample: False + shuffle: True + use_topics: ${data.shared_params.use_topics} + + val: + - _target_: relik.retriever.data.dpr.datasets.InBatchNegativesDPRIterableDataset + name: "val" + path: null + passages_path: ${data.shared_params.passages_path} + max_passage_length: ${data.shared_params.max_passage_length} + max_questions_per_batch: ${data.shared_params.max_questions_per_batch} + max_passages_per_batch: ${data.shared_params.max_passages_per_batch} + prefetch_batches: ${data.shared_params.prefetch_batches} + use_topics: ${data.shared_params.use_topics} + + test: + - _target_: relik.retriever.data.dpr.datasets.InBatchNegativesDPRIterableDataset + name: "test" + path: null + passages_path: ${data.shared_params.passages_path} + max_passage_length: ${data.shared_params.max_passage_length} + max_questions_per_batch: ${data.shared_params.max_questions_per_batch} + max_passages_per_batch: ${data.shared_params.max_passages_per_batch} + prefetch_batches: ${data.shared_params.prefetch_batches} + use_topics: ${data.shared_params.use_topics} + + num_workers: + train: 0 + val: 0 + test: 0 diff --git a/relik/retriever/conf/data/sampled_negatives.yaml b/relik/retriever/conf/data/sampled_negatives.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4f581a4a6488a7ad516aca08a5d0a0fba5906fe1 --- /dev/null +++ b/relik/retriever/conf/data/sampled_negatives.yaml @@ -0,0 +1,39 @@ +max_passages: 64 + +datamodule: + _target_: relik.retriever.lightning_modules.pl_data_modules.PLDataModule + tokenizer: ${model.language_model} + datasets: + train: + _target_: relik.retriever.data.dpr.datasets.SampledNegativesDPRDataset + name: "train" + passages_path: ${data_overrides.passages_path} + max_passage_length: 64 + max_passages: ${data.max_passages} + path: ${data_overrides.train_path} + + val: + - _target_: relik.retriever.data.dpr.datasets.SampledNegativesDPRDataset + name: "val" + passages_path: ${data_overrides.passages_path} + max_passage_length: 64 + max_passages: ${data.max_passages} + path: 
${data_overrides.val_path} + + test: + - _target_: relik.retriever.data.dpr.datasets.SampledNegativesDPRDataset + name: "test" + passages_path: ${data_overrides.passages_path} + max_passage_length: 64 + max_passages: ${data.max_passages} + path: ${data_overrides.test_path} + + batch_sizes: + train: 4 + val: 64 + test: 64 + + num_workers: + train: 4 + val: 4 + test: 4 diff --git a/relik/retriever/conf/finetune_iterable_in_batch.yaml b/relik/retriever/conf/finetune_iterable_in_batch.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7bcaa034ca3dca3e82d6efa6d6b674839f2ec880 --- /dev/null +++ b/relik/retriever/conf/finetune_iterable_in_batch.yaml @@ -0,0 +1,117 @@ +# Required to make the "experiments" dir the default one for the output of the models +hydra: + run: + dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S} + +model_name: ${model.language_model} # used to name the model in wandb +project_name: relik-retriever # used to name the project in wandb + +defaults: + - _self_ + - model: golden_retriever + - index: inmemory + - loss: nce_loss + - optimizer: radamw + - scheduler: linear_scheduler + - data: dataset_v2 # iterable_in_batch_negatives #dataset_v2 + - logging: wandb_logging + - override hydra/job_logging: colorlog + - override hydra/hydra_logging: colorlog + +train: + # reproducibility + seed: 42 + set_determinism_the_old_way: False + # torch parameters + float32_matmul_precision: "medium" + # if true, only test the model + only_test: False + # if provided, initialize the model with the weights from the checkpoint + pretrain_ckpt_path: null + # if provided, start training from the checkpoint + checkpoint_path: null + + # task specific parameter + top_k: 100 + + # pl_trainer + pl_trainer: + _target_: lightning.Trainer + accelerator: gpu + devices: 1 + num_nodes: 1 + strategy: auto + accumulate_grad_batches: 1 + gradient_clip_val: 1.0 + val_check_interval: 1.0 # you can specify an int "n" here => validation every "n" steps + check_val_every_n_epoch: 1 + max_epochs: 0 + max_steps: 25_000 + deterministic: True + fast_dev_run: False + precision: 16 + reload_dataloaders_every_n_epochs: 1 + + early_stopping_callback: + # null + _target_: lightning.callbacks.EarlyStopping + monitor: validate_recall@${train.top_k} + mode: max + patience: 3 + + model_checkpoint_callback: + _target_: lightning.callbacks.ModelCheckpoint + monitor: validate_recall@${train.top_k} + mode: max + verbose: True + save_top_k: 1 + save_last: False + filename: "checkpoint-validate_recall@${train.top_k}_{validate_recall@${train.top_k}:.4f}-epoch_{epoch:02d}" + auto_insert_metric_name: False + + callbacks: + prediction_callback: + _target_: relik.retriever.callbacks.prediction_callbacks.GoldenRetrieverPredictionCallback + k: ${train.top_k} + batch_size: 64 + precision: 16 + index_precision: 16 + other_callbacks: + - _target_: relik.retriever.callbacks.evaluation_callbacks.RecallAtKEvaluationCallback + k: ${train.top_k} + verbose: True + - _target_: relik.retriever.callbacks.evaluation_callbacks.RecallAtKEvaluationCallback + k: 50 + verbose: True + prog_bar: False + - _target_: relik.retriever.callbacks.evaluation_callbacks.AvgRankingEvaluationCallback + k: ${train.top_k} + verbose: True + - _target_: relik.retriever.callbacks.utils_callbacks.SavePredictionsCallback + + hard_negatives_callback: + _target_: relik.retriever.callbacks.prediction_callbacks.NegativeAugmentationCallback + k: ${train.top_k} + batch_size: 64 + precision: 16 + index_precision: 16 + stages: [validate] #[validate, 
sanity_check] + metrics_to_monitor: + validate_recall@${train.top_k} + # - sanity_check_recall@${train.top_k} + threshold: 0.0 + max_negatives: 20 + add_with_probability: 1.0 + refresh_every_n_epochs: 1 + other_callbacks: + - _target_: relik.retriever.callbacks.evaluation_callbacks.AvgRankingEvaluationCallback + k: ${train.top_k} + verbose: True + prefix: "train" + + utils_callbacks: + - _target_: relik.retriever.callbacks.utils_callbacks.SaveRetrieverCallback + - _target_: relik.retriever.callbacks.utils_callbacks.FreeUpIndexerVRAMCallback + # - _target_: relik.retriever.callbacks.utils_callbacks.ResetModelCallback + # question_encoder: ${model.pl_module.model.question_encoder} + # passage_encoder: ${model.pl_module.model.passage_encoder} diff --git a/relik/retriever/conf/index/inmemory.yaml b/relik/retriever/conf/index/inmemory.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d77f5b79946a82384cb4da0c0f60b7b2700e9280 --- /dev/null +++ b/relik/retriever/conf/index/inmemory.yaml @@ -0,0 +1,4 @@ +_target_: relik.retriever.indexers.inmemory.InMemoryDocumentIndex +documents: ${data.shared_params.passages_path} +device: cuda +precision: 16 diff --git a/relik/retriever/conf/logging/wandb_logging.yaml b/relik/retriever/conf/logging/wandb_logging.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1908d7e09789ca0b8e4973ec7f3ca5d47d460af3 --- /dev/null +++ b/relik/retriever/conf/logging/wandb_logging.yaml @@ -0,0 +1,16 @@ +# don't forget loggers.login() for the first usage. + +log: True # set to False to avoid the logging + +wandb_arg: + _target_: lightning.loggers.WandbLogger + name: ${model_name} + project: ${project_name} + save_dir: ./ + log_model: True + mode: "online" + entity: null + +watch: + log: "all" + log_freq: 100 diff --git a/relik/retriever/conf/loss/nce_loss.yaml b/relik/retriever/conf/loss/nce_loss.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fe9246b88027fd0d499b9bc3b4beaa7937b23586 --- /dev/null +++ b/relik/retriever/conf/loss/nce_loss.yaml @@ -0,0 +1 @@ +_target_: relik.retriever.pythorch_modules.losses.MultiLabelNCELoss diff --git a/relik/retriever/conf/loss/nll_loss.yaml b/relik/retriever/conf/loss/nll_loss.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1e0a5010025a4a6e9da382e5408af4a58cda6185 --- /dev/null +++ b/relik/retriever/conf/loss/nll_loss.yaml @@ -0,0 +1 @@ +_target_: torch.nn.NLLLoss diff --git a/relik/retriever/conf/optimizer/adamw.yaml b/relik/retriever/conf/optimizer/adamw.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ff0f84e15ebd6c60e3e6e411c30e88fb910fdb29 --- /dev/null +++ b/relik/retriever/conf/optimizer/adamw.yaml @@ -0,0 +1,4 @@ +_target_: torch.optim.AdamW +lr: 1e-5 +weight_decay: 0.01 +fused: False diff --git a/relik/retriever/conf/optimizer/radam.yaml b/relik/retriever/conf/optimizer/radam.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b5d2a4ecf468327bda98fd205bf483b6828cf653 --- /dev/null +++ b/relik/retriever/conf/optimizer/radam.yaml @@ -0,0 +1,3 @@ +_target_: torch.optim.RAdam +lr: 1e-5 +weight_decay: 0 diff --git a/relik/retriever/conf/optimizer/radamw.yaml b/relik/retriever/conf/optimizer/radamw.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6f1fc8c4696bf793180366a64baf107829cf7752 --- /dev/null +++ b/relik/retriever/conf/optimizer/radamw.yaml @@ -0,0 +1,3 @@ +_target_: relik.retriever.pytorch_modules.optim.RAdamW +lr: 1e-5 +weight_decay: 0.01 diff --git 
a/relik/retriever/conf/pretrain_iterable_in_batch.yaml b/relik/retriever/conf/pretrain_iterable_in_batch.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c003d612cf718f4b4b84e866009a2944a5c22d9a --- /dev/null +++ b/relik/retriever/conf/pretrain_iterable_in_batch.yaml @@ -0,0 +1,114 @@ +# Required to make the "experiments" dir the default one for the output of the models +hydra: + run: + dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S} + +model_name: ${model.language_model} # used to name the model in wandb +project_name: relik-retriever # used to name the project in wandb + +defaults: + - _self_ + - model: golden_retriever + - index: inmemory + - loss: nce_loss + - optimizer: radamw + - scheduler: linear_scheduler + - data: dataset_v2 + - logging: wandb_logging + - override hydra/job_logging: colorlog + - override hydra/hydra_logging: colorlog + +train: + # reproducibility + seed: 42 + set_determinism_the_old_way: False + # torch parameters + float32_matmul_precision: "medium" + # if true, only test the model + only_test: False + # if provided, initialize the model with the weights from the checkpoint + pretrain_ckpt_path: null + # if provided, start training from the checkpoint + checkpoint_path: null + + # task specific parameter + top_k: 100 + + # pl_trainer + pl_trainer: + _target_: lightning.Trainer + accelerator: gpu + devices: 1 + num_nodes: 1 + strategy: auto + accumulate_grad_batches: 1 + gradient_clip_val: 1.0 + val_check_interval: 1.0 # you can specify an int "n" here => validation every "n" steps + check_val_every_n_epoch: 1 + max_epochs: 0 + max_steps: 220_000 + deterministic: True + fast_dev_run: False + precision: 16 + reload_dataloaders_every_n_epochs: 1 + + early_stopping_callback: + null + # _target_: lightning.callbacks.EarlyStopping + # monitor: validate_recall@${train.top_k} + # mode: max + # patience: 15 + + model_checkpoint_callback: + _target_: lightning.callbacks.ModelCheckpoint + monitor: validate_recall@${train.top_k} + mode: max + verbose: True + save_top_k: 1 + save_last: True + filename: "checkpoint-validate_recall@${train.top_k}_{validate_recall@${train.top_k}:.4f}-epoch_{epoch:02d}" + auto_insert_metric_name: False + + callbacks: + prediction_callback: + _target_: relik.retriever.callbacks.prediction_callbacks.GoldenRetrieverPredictionCallback + k: ${train.top_k} + batch_size: 128 + precision: 16 + index_precision: 16 + other_callbacks: + - _target_: relik.retriever.callbacks.evaluation_callbacks.RecallAtKEvaluationCallback + k: ${train.top_k} + verbose: True + - _target_: relik.retriever.callbacks.evaluation_callbacks.RecallAtKEvaluationCallback + k: 50 + verbose: True + prog_bar: False + - _target_: relik.retriever.callbacks.evaluation_callbacks.AvgRankingEvaluationCallback + k: ${train.top_k} + verbose: True + - _target_: relik.retriever.callbacks.utils_callbacks.SavePredictionsCallback + + hard_negatives_callback: + _target_: relik.retriever.callbacks.prediction_callbacks.NegativeAugmentationCallback + k: ${train.top_k} + batch_size: 128 + precision: 16 + index_precision: 16 + stages: [validate] #[validate, sanity_check] + metrics_to_monitor: + validate_recall@${train.top_k} + # - sanity_check_recall@${train.top_k} + threshold: 0.0 + max_negatives: 15 + add_with_probability: 0.2 + refresh_every_n_epochs: 1 + other_callbacks: + - _target_: relik.retriever.callbacks.evaluation_callbacks.AvgRankingEvaluationCallback + k: ${train.top_k} + verbose: True + prefix: "train" + + utils_callbacks: + - _target_: 
relik.retriever.callbacks.utils_callbacks.SaveRetrieverCallback + - _target_: relik.retriever.callbacks.utils_callbacks.FreeUpIndexerVRAMCallback diff --git a/relik/retriever/conf/scheduler/linear_scheduler.yaml b/relik/retriever/conf/scheduler/linear_scheduler.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d1896bff5d01ee2543639e4e379674d60682a0f6 --- /dev/null +++ b/relik/retriever/conf/scheduler/linear_scheduler.yaml @@ -0,0 +1,3 @@ +_target_: transformers.get_linear_schedule_with_warmup +num_warmup_steps: 0 +num_training_steps: ${train.pl_trainer.max_steps} diff --git a/relik/retriever/conf/scheduler/linear_scheduler_with_warmup.yaml b/relik/retriever/conf/scheduler/linear_scheduler_with_warmup.yaml new file mode 100644 index 0000000000000000000000000000000000000000..417857489486469032b7e8b19d509a1e45da043c --- /dev/null +++ b/relik/retriever/conf/scheduler/linear_scheduler_with_warmup.yaml @@ -0,0 +1,3 @@ +_target_: transformers.get_linear_schedule_with_warmup +num_warmup_steps: 5_000 +num_training_steps: ${train.pl_trainer.max_steps} diff --git a/relik/retriever/conf/scheduler/none.yaml b/relik/retriever/conf/scheduler/none.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ec747fa47ddb81e9bf2d282011ed32aa4c59f932 --- /dev/null +++ b/relik/retriever/conf/scheduler/none.yaml @@ -0,0 +1 @@ +null \ No newline at end of file diff --git a/relik/retriever/data/__init__.py b/relik/retriever/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/relik/retriever/data/base/__init__.py b/relik/retriever/data/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/relik/retriever/data/base/datasets.py b/relik/retriever/data/base/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..8a011402a7f6ead5914eb0808a62e9e967c8c12d --- /dev/null +++ b/relik/retriever/data/base/datasets.py @@ -0,0 +1,89 @@ +import os +from pathlib import Path +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union + +import torch +from torch.utils.data import Dataset, IterableDataset + +from relik.common.log import get_logger + +logger = get_logger() + + +class BaseDataset(Dataset): + def __init__( + self, + name: str, + path: Optional[Union[str, os.PathLike, List[str], List[os.PathLike]]] = None, + data: Any = None, + **kwargs, + ): + super().__init__() + self.name = name + if path is None and data is None: + raise ValueError("Either `path` or `data` must be provided") + self.path = path + self.project_folder = Path(__file__).parent.parent.parent + self.data = data + + def __len__(self) -> int: + return len(self.data) + + def __getitem__( + self, index + ) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: + return self.data[index] + + def __repr__(self) -> str: + return f"Dataset({self.name=}, {self.path=})" + + def load( + self, + paths: Union[str, os.PathLike, List[str], List[os.PathLike]], + *args, + **kwargs, + ) -> Any: + # load data from single or multiple paths in one single dataset + raise NotImplementedError + + @staticmethod + def collate_fn(batch: Any, *args, **kwargs) -> Any: + raise NotImplementedError + + +class IterableBaseDataset(IterableDataset): + def __init__( + self, + name: str, + path: Optional[Union[str, Path, List[str], List[Path]]] = None, + data: Any = None, + *args, + **kwargs, + ): + super().__init__() + self.name = name 
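+        # mirrors BaseDataset: at least one of `path` or `data` must be provided;
+        # __iter__ then simply yields the pre-built samples one by one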
+ if path is None and data is None: + raise ValueError("Either `path` or `data` must be provided") + self.path = path + self.project_folder = Path(__file__).parent.parent.parent + self.data = data + + def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: + for sample in self.data: + yield sample + + def __repr__(self) -> str: + return f"Dataset({self.name=}, {self.path=})" + + def load( + self, + paths: Union[str, os.PathLike, List[str], List[os.PathLike]], + *args, + **kwargs, + ) -> Any: + # load data from single or multiple paths in one single dataset + raise NotImplementedError + + @staticmethod + def collate_fn(batch: Any, *args, **kwargs) -> Any: + raise NotImplementedError diff --git a/relik/retriever/data/datasets.py b/relik/retriever/data/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..6cf1897a9955d9eb902d74dd08b7dbcd09fd68e4 --- /dev/null +++ b/relik/retriever/data/datasets.py @@ -0,0 +1,726 @@ +import os +from copy import deepcopy +from enum import Enum +from functools import partial +from pathlib import Path +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union + +import datasets +import psutil +import torch +import transformers as tr +from datasets import load_dataset +from torch.utils.data import Dataset +from tqdm import tqdm + +from relik.common.log import get_console_logger, get_logger +from relik.retriever.common.model_inputs import ModelInputs +from relik.retriever.data.base.datasets import BaseDataset, IterableBaseDataset +from relik.retriever.data.utils import HardNegativesManager + +console_logger = get_console_logger() + +logger = get_logger(__name__) + + +class SubsampleStrategyEnum(Enum): + NONE = "none" + RANDOM = "random" + IN_ORDER = "in_order" + + +class GoldenRetrieverDataset: + def __init__( + self, + name: str, + path: Union[str, os.PathLike, List[str], List[os.PathLike]] = None, + data: Any = None, + tokenizer: Optional[Union[str, tr.PreTrainedTokenizer]] = None, + # passages: Union[str, os.PathLike, List[str]] = None, + passage_batch_size: int = 32, + question_batch_size: int = 32, + max_positives: int = -1, + max_negatives: int = 0, + max_hard_negatives: int = 0, + max_question_length: int = 256, + max_passage_length: int = 64, + shuffle: bool = False, + subsample_strategy: Optional[str] = SubsampleStrategyEnum.NONE, + subsample_portion: float = 0.1, + num_proc: Optional[int] = None, + load_from_cache_file: bool = True, + keep_in_memory: bool = False, + prefetch: bool = True, + load_fn_kwargs: Optional[Dict[str, Any]] = None, + batch_fn_kwargs: Optional[Dict[str, Any]] = None, + collate_fn_kwargs: Optional[Dict[str, Any]] = None, + ): + if path is None and data is None: + raise ValueError("Either `path` or `data` must be provided") + + if tokenizer is None: + raise ValueError("A tokenizer must be provided") + + # dataset parameters + self.name = name + self.path = Path(path) or path + if path is not None and not isinstance(self.path, Sequence): + self.path = [self.path] + # self.project_folder = Path(__file__).parent.parent.parent + self.data = data + + # hyper-parameters + self.passage_batch_size = passage_batch_size + self.question_batch_size = question_batch_size + self.max_positives = max_positives + self.max_negatives = max_negatives + self.max_hard_negatives = max_hard_negatives + self.max_question_length = max_question_length + self.max_passage_length = max_passage_length + self.shuffle = shuffle + self.num_proc = num_proc + self.load_from_cache_file = load_from_cache_file + 
self.keep_in_memory = keep_in_memory + self.prefetch = prefetch + + self.tokenizer = tokenizer + if isinstance(self.tokenizer, str): + self.tokenizer = tr.AutoTokenizer.from_pretrained(self.tokenizer) + + self.padding_ops = { + "input_ids": partial( + self.pad_sequence, + value=self.tokenizer.pad_token_id, + ), + "attention_mask": partial(self.pad_sequence, value=0), + "token_type_ids": partial( + self.pad_sequence, + value=self.tokenizer.pad_token_type_id, + ), + } + + # check if subsample strategy is valid + if subsample_strategy is not None: + # subsample_strategy can be a string or a SubsampleStrategy + if isinstance(subsample_strategy, str): + try: + subsample_strategy = SubsampleStrategyEnum(subsample_strategy) + except ValueError: + raise ValueError( + f"Subsample strategy {subsample_strategy} is not valid. " + f"Valid strategies are: {SubsampleStrategyEnum.__members__}" + ) + if not isinstance(subsample_strategy, SubsampleStrategyEnum): + raise ValueError( + f"Subsample strategy {subsample_strategy} is not valid. " + f"Valid strategies are: {SubsampleStrategyEnum.__members__}" + ) + self.subsample_strategy = subsample_strategy + self.subsample_portion = subsample_portion + + # load the dataset + if data is None: + self.data: Dataset = self.load( + self.path, + tokenizer=self.tokenizer, + load_from_cache_file=load_from_cache_file, + load_fn_kwargs=load_fn_kwargs, + num_proc=num_proc, + shuffle=shuffle, + keep_in_memory=keep_in_memory, + max_positives=max_positives, + max_negatives=max_negatives, + max_hard_negatives=max_hard_negatives, + max_question_length=max_question_length, + max_passage_length=max_passage_length, + ) + else: + self.data: Dataset = data + + self.hn_manager: Optional[HardNegativesManager] = None + + # keep track of how many times the dataset has been iterated over + self.number_of_complete_iterations = 0 + + def __repr__(self) -> str: + return f"GoldenRetrieverDataset({self.name=}, {self.path=})" + + def __len__(self) -> int: + raise NotImplementedError + + def __getitem__( + self, index + ) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: + raise NotImplementedError + + def to_torch_dataset(self, *args, **kwargs) -> torch.utils.data.Dataset: + raise NotImplementedError + + def load( + self, + paths: Union[str, os.PathLike, List[str], List[os.PathLike]], + tokenizer: tr.PreTrainedTokenizer = None, + load_fn_kwargs: Dict = None, + load_from_cache_file: bool = True, + num_proc: Optional[int] = None, + shuffle: bool = False, + keep_in_memory: bool = True, + max_positives: int = -1, + max_negatives: int = -1, + max_hard_negatives: int = -1, + max_passages: int = -1, + max_question_length: int = 256, + max_passage_length: int = 64, + *args, + **kwargs, + ) -> Any: + # if isinstance(paths, Sequence): + # paths = [self.project_folder / path for path in paths] + # else: + # paths = [self.project_folder / paths] + + # read the data and put it in a placeholder list + for path in paths: + if not path.exists(): + raise ValueError(f"{path} does not exist") + + fn_kwargs = dict( + tokenizer=tokenizer, + max_positives=max_positives, + max_negatives=max_negatives, + max_hard_negatives=max_hard_negatives, + max_passages=max_passages, + max_question_length=max_question_length, + max_passage_length=max_passage_length, + ) + if load_fn_kwargs is not None: + fn_kwargs.update(load_fn_kwargs) + + if num_proc is None: + num_proc = psutil.cpu_count(logical=False) + + # The data is a list of dictionaries, each dictionary is a sample + # Each sample has the following 
keys: + # - "question": the question + # - "answers": a list of answers + # - "positive_ctxs": a list of positive passages + # - "negative_ctxs": a list of negative passages + # - "hard_negative_ctxs": a list of hard negative passages + # use the huggingface dataset library to load the data, by default it will load the + # data in a dict with the key being "train". + logger.info(f"Loading data for dataset {self.name}") + data = load_dataset( + "json", + data_files=[str(p) for p in paths], # datasets needs str paths and not Path + split="train", + streaming=False, # TODO maybe we can make streaming work + keep_in_memory=keep_in_memory, + ) + # add id if not present + if isinstance(data, datasets.Dataset): + data = data.add_column("sample_idx", range(len(data))) + else: + data = data.map( + lambda x, idx: x.update({"sample_idx": idx}), with_indices=True + ) + + map_kwargs = dict( + function=self.load_fn, + fn_kwargs=fn_kwargs, + ) + if isinstance(data, datasets.Dataset): + map_kwargs.update( + dict( + load_from_cache_file=load_from_cache_file, + keep_in_memory=keep_in_memory, + num_proc=num_proc, + desc="Loading data", + ) + ) + # preprocess the data + data = data.map(**map_kwargs) + + # shuffle the data + if shuffle: + data.shuffle(seed=42) + + return data + + @staticmethod + def create_batches( + data: Dataset, + batch_fn: Callable, + batch_fn_kwargs: Optional[Dict[str, Any]] = None, + prefetch: bool = True, + *args, + **kwargs, + ) -> Union[Iterable, List]: + if not prefetch: + # if we are streaming, we don't need to create batches right now + # we will create them on the fly when we need them + batched_data = ( + batch + for batch in batch_fn( + data, **(batch_fn_kwargs if batch_fn_kwargs is not None else {}) + ) + ) + else: + batched_data = [ + batch + for batch in tqdm( + batch_fn( + data, **(batch_fn_kwargs if batch_fn_kwargs is not None else {}) + ), + desc="Creating batches", + ) + ] + return batched_data + + @staticmethod + def collate_batches( + batched_data: Union[Iterable, List], + collate_fn: Callable, + collate_fn_kwargs: Optional[Dict[str, Any]] = None, + prefetch: bool = True, + *args, + **kwargs, + ) -> Union[Iterable, List]: + if not prefetch: + collated_data = ( + collate_fn(batch, **(collate_fn_kwargs if collate_fn_kwargs else {})) + for batch in batched_data + ) + else: + collated_data = [ + collate_fn(batch, **(collate_fn_kwargs if collate_fn_kwargs else {})) + for batch in tqdm(batched_data, desc="Collating batches") + ] + return collated_data + + @staticmethod + def load_fn(sample: Dict, *args, **kwargs) -> Dict: + raise NotImplementedError + + @staticmethod + def batch_fn(data: Dataset, *args, **kwargs) -> Any: + raise NotImplementedError + + @staticmethod + def collate_fn(batch: Any, *args, **kwargs) -> Any: + raise NotImplementedError + + @staticmethod + def pad_sequence( + sequence: Union[List, torch.Tensor], + length: int, + value: Any = None, + pad_to_left: bool = False, + ) -> Union[List, torch.Tensor]: + """ + Pad the input to the specified length with the given value. + + Args: + sequence (:obj:`List`, :obj:`torch.Tensor`): + Element to pad, it can be either a :obj:`List` or a :obj:`torch.Tensor`. + length (:obj:`int`, :obj:`str`, optional, defaults to :obj:`subtoken`): + Length after pad. + value (:obj:`Any`, optional): + Value to use as padding. + pad_to_left (:obj:`bool`, optional, defaults to :obj:`False`): + If :obj:`True`, pads to the left, right otherwise. + + Returns: + :obj:`List`, :obj:`torch.Tensor`: The padded sequence. 
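+        Example:
+            >>> GoldenRetrieverDataset.pad_sequence([1, 2, 3], length=5, value=0)
+            [1, 2, 3, 0, 0]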
+ + """ + padding = [value] * abs(length - len(sequence)) + if isinstance(sequence, torch.Tensor): + if len(sequence.shape) > 1: + raise ValueError( + f"Sequence tensor must be 1D. Current shape is `{len(sequence.shape)}`" + ) + padding = torch.as_tensor(padding) + if pad_to_left: + if isinstance(sequence, torch.Tensor): + return torch.cat((padding, sequence), -1) + return padding + sequence + if isinstance(sequence, torch.Tensor): + return torch.cat((sequence, padding), -1) + return sequence + padding + + def convert_to_batch( + self, samples: Any, *args, **kwargs + ) -> Dict[str, torch.Tensor]: + """ + Convert the list of samples to a batch. + + Args: + samples (:obj:`List`): + List of samples to convert to a batch. + + Returns: + :obj:`Dict[str, torch.Tensor]`: The batch. + """ + # invert questions from list of dict to dict of list + samples = {k: [d[k] for d in samples] for k in samples[0]} + # get max length of questions + max_len = max(len(x) for x in samples["input_ids"]) + # pad the questions + for key in samples: + if key in self.padding_ops: + samples[key] = torch.as_tensor( + [self.padding_ops[key](b, max_len) for b in samples[key]] + ) + return samples + + def shuffle_data(self, seed: int = 42): + self.data = self.data.shuffle(seed=seed) + + +class InBatchNegativesDataset(GoldenRetrieverDataset): + def __len__(self) -> int: + if isinstance(self.data, datasets.Dataset): + return len(self.data) + + def __getitem__( + self, index + ) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: + return self.data[index] + + def to_torch_dataset(self) -> torch.utils.data.Dataset: + shuffle_this_time = self.shuffle + + if ( + self.subsample_strategy + and self.subsample_strategy != SubsampleStrategyEnum.NONE + ): + number_of_samples = int(len(self.data) * self.subsample_portion) + if self.subsample_strategy == SubsampleStrategyEnum.RANDOM: + logger.info( + f"Random subsampling {number_of_samples} samples from {len(self.data)}" + ) + data = ( + deepcopy(self.data) + .shuffle(seed=42 + self.number_of_complete_iterations) + .select(range(0, number_of_samples)) + ) + elif self.subsample_strategy == SubsampleStrategyEnum.IN_ORDER: + # number_of_samples = int(len(self.data) * self.subsample_portion) + already_selected = ( + number_of_samples * self.number_of_complete_iterations + ) + logger.info( + f"Subsampling {number_of_samples} samples out of {len(self.data)}" + ) + to_select = min(already_selected + number_of_samples, len(self.data)) + logger.info( + f"Portion of data selected: {already_selected} " f"to {to_select}" + ) + data = deepcopy(self.data).select(range(already_selected, to_select)) + + # don't shuffle the data if we are subsampling, and we have still not completed + # one full iteration over the dataset + if self.number_of_complete_iterations > 0: + shuffle_this_time = False + + # reset the number of complete iterations + if to_select >= len(self.data): + # reset the number of complete iterations, + # we have completed one full iteration over the dataset + # the value is -1 because we want to start from 0 at the next iteration + self.number_of_complete_iterations = -1 + else: + raise ValueError( + f"Subsample strategy `{self.subsample_strategy}` is not valid. " + f"Valid strategies are: {SubsampleStrategyEnum.__members__}" + ) + + else: + data = data = self.data + + # do we need to shuffle the data? 
+ if self.shuffle and shuffle_this_time: + logger.info("Shuffling the data") + data = data.shuffle(seed=42 + self.number_of_complete_iterations) + + batch_fn_kwargs = { + "passage_batch_size": self.passage_batch_size, + "question_batch_size": self.question_batch_size, + "hard_negatives_manager": self.hn_manager, + } + batched_data = self.create_batches( + data, + batch_fn=self.batch_fn, + batch_fn_kwargs=batch_fn_kwargs, + prefetch=self.prefetch, + ) + + batched_data = self.collate_batches( + batched_data, self.collate_fn, prefetch=self.prefetch + ) + + # increment the number of complete iterations + self.number_of_complete_iterations += 1 + + if self.prefetch: + return BaseDataset(name=self.name, data=batched_data) + else: + return IterableBaseDataset(name=self.name, data=batched_data) + + @staticmethod + def load_fn( + sample: Dict, + tokenizer: tr.PreTrainedTokenizer, + max_positives: int, + max_negatives: int, + max_hard_negatives: int, + max_passages: int = -1, + max_question_length: int = 256, + max_passage_length: int = 128, + *args, + **kwargs, + ) -> Dict: + # remove duplicates and limit the number of passages + positives = list(set([p["text"].strip() for p in sample["positive_ctxs"]])) + if max_positives != -1: + positives = positives[:max_positives] + negatives = list(set([n["text"].strip() for n in sample["negative_ctxs"]])) + if max_negatives != -1: + negatives = negatives[:max_negatives] + hard_negatives = list( + set([h["text"].strip() for h in sample["hard_negative_ctxs"]]) + ) + if max_hard_negatives != -1: + hard_negatives = hard_negatives[:max_hard_negatives] + + question = tokenizer( + sample["question"], max_length=max_question_length, truncation=True + ) + + passage = positives + negatives + hard_negatives + if max_passages != -1: + passage = passage[:max_passages] + + passage = tokenizer(passage, max_length=max_passage_length, truncation=True) + + # invert the passage data structure from a dict of lists to a list of dicts + passage = [dict(zip(passage, t)) for t in zip(*passage.values())] + + output = dict( + question=question, + passage=passage, + positives=positives, + positive_pssgs=passage[: len(positives)], + ) + return output + + @staticmethod + def batch_fn( + data: Dataset, + passage_batch_size: int, + question_batch_size: int, + hard_negatives_manager: Optional[HardNegativesManager] = None, + *args, + **kwargs, + ) -> Dict[str, List[Dict[str, Any]]]: + def split_batch( + batch: Union[Dict[str, Any], ModelInputs], question_batch_size: int + ) -> List[ModelInputs]: + """ + Split a batch into multiple batches of size `question_batch_size` while keeping + the same number of passages. 
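+            For example, a batch of 70 questions sharing one passage pool, with
+            ``question_batch_size=32``, is split into sub-batches of 32, 32 and 6
+            questions, each carrying the same ``passages`` list.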
+ """ + + def split_fn(x): + return [ + x[i : i + question_batch_size] + for i in range(0, len(x), question_batch_size) + ] + + # split the sample_idx + sample_idx = split_fn(batch["sample_idx"]) + # split the questions + questions = split_fn(batch["questions"]) + # split the positives + positives = split_fn(batch["positives"]) + # split the positives_pssgs + positives_pssgs = split_fn(batch["positives_pssgs"]) + + # collect the new batches + batches = [] + for i in range(len(questions)): + batches.append( + ModelInputs( + dict( + sample_idx=sample_idx[i], + questions=questions[i], + passages=batch["passages"], + positives=positives[i], + positives_pssgs=positives_pssgs[i], + ) + ) + ) + return batches + + batch = [] + passages_in_batch = {} + + for sample in data: + if len(passages_in_batch) >= passage_batch_size: + # create the batch dict + batch_dict = ModelInputs( + dict( + sample_idx=[s["sample_idx"] for s in batch], + questions=[s["question"] for s in batch], + passages=list(passages_in_batch.values()), + positives_pssgs=[s["positive_pssgs"] for s in batch], + positives=[s["positives"] for s in batch], + ) + ) + # split the batch if needed + if len(batch) > question_batch_size: + for splited_batch in split_batch(batch_dict, question_batch_size): + yield splited_batch + else: + yield batch_dict + + # reset batch + batch = [] + passages_in_batch = {} + + batch.append(sample) + # yes it's a bit ugly but it works :) + # count the number of passages in the batch and stop if we reach the limit + # we use a set to avoid counting the same passage twice + # we use a tuple because set doesn't support lists + # we use input_ids as discriminator + passages_in_batch.update( + {tuple(passage["input_ids"]): passage for passage in sample["passage"]} + ) + # check for hard negatives and add with a probability of 0.1 + if hard_negatives_manager is not None: + if sample["sample_idx"] in hard_negatives_manager: + passages_in_batch.update( + { + tuple(passage["input_ids"]): passage + for passage in hard_negatives_manager.get( + sample["sample_idx"] + ) + } + ) + + # left over + if len(batch) > 0: + # create the batch dict + batch_dict = ModelInputs( + dict( + sample_idx=[s["sample_idx"] for s in batch], + questions=[s["question"] for s in batch], + passages=list(passages_in_batch.values()), + positives_pssgs=[s["positive_pssgs"] for s in batch], + positives=[s["positives"] for s in batch], + ) + ) + # split the batch if needed + if len(batch) > question_batch_size: + for splited_batch in split_batch(batch_dict, question_batch_size): + yield splited_batch + else: + yield batch_dict + + def collate_fn(self, batch: Any, *args, **kwargs) -> Any: + # convert questions and passages to a batch + questions = self.convert_to_batch(batch.questions) + passages = self.convert_to_batch(batch.passages) + + # build an index to map the position of the passage in the batch + passage_index = {tuple(c["input_ids"]): i for i, c in enumerate(batch.passages)} + + # now we can create the labels + labels = torch.zeros( + questions["input_ids"].shape[0], passages["input_ids"].shape[0] + ) + # iterate over the questions and set the labels to 1 if the passage is positive + for sample_idx in range(len(questions["input_ids"])): + for pssg in batch["positives_pssgs"][sample_idx]: + # get the index of the positive passage + index = passage_index[tuple(pssg["input_ids"])] + # set the label to 1 + labels[sample_idx, index] = 1 + + model_inputs = ModelInputs( + { + "questions": questions, + "passages": passages, + "labels": labels, + 
"positives": batch["positives"], + "sample_idx": batch["sample_idx"], + } + ) + return model_inputs + + +class AidaInBatchNegativesDataset(InBatchNegativesDataset): + def __init__(self, use_topics: bool = False, *args, **kwargs): + if "load_fn_kwargs" not in kwargs: + kwargs["load_fn_kwargs"] = {} + kwargs["load_fn_kwargs"]["use_topics"] = use_topics + super().__init__(*args, **kwargs) + + @staticmethod + def load_fn( + sample: Dict, + tokenizer: tr.PreTrainedTokenizer, + max_positives: int, + max_negatives: int, + max_hard_negatives: int, + max_passages: int = -1, + max_question_length: int = 256, + max_passage_length: int = 128, + use_topics: bool = False, + *args, + **kwargs, + ) -> Dict: + # remove duplicates and limit the number of passages + positives = list(set([p["text"].strip() for p in sample["positive_ctxs"]])) + if max_positives != -1: + positives = positives[:max_positives] + negatives = list(set([n["text"].strip() for n in sample["negative_ctxs"]])) + if max_negatives != -1: + negatives = negatives[:max_negatives] + hard_negatives = list( + set([h["text"].strip() for h in sample["hard_negative_ctxs"]]) + ) + if max_hard_negatives != -1: + hard_negatives = hard_negatives[:max_hard_negatives] + + question = sample["question"] + + if "doc_topic" in sample and use_topics: + question = tokenizer( + question, + sample["doc_topic"], + max_length=max_question_length, + truncation=True, + ) + else: + question = tokenizer( + question, max_length=max_question_length, truncation=True + ) + + passage = positives + negatives + hard_negatives + if max_passages != -1: + passage = passage[:max_passages] + + passage = tokenizer(passage, max_length=max_passage_length, truncation=True) + + # invert the passage data structure from a dict of lists to a list of dicts + passage = [dict(zip(passage, t)) for t in zip(*passage.values())] + + output = dict( + question=question, + passage=passage, + positives=positives, + positive_pssgs=passage[: len(positives)], + ) + return output diff --git a/relik/retriever/data/labels.py b/relik/retriever/data/labels.py new file mode 100644 index 0000000000000000000000000000000000000000..de8f87a2186a413b47d686f92b4d3039536a5988 --- /dev/null +++ b/relik/retriever/data/labels.py @@ -0,0 +1,338 @@ +import json +from pathlib import Path +from typing import Dict, List, Optional, Set, Union + +import transformers as tr + + +class Labels: + """ + Class that contains the labels for a model. + + Args: + _labels_to_index (:obj:`Dict[str, Dict[str, int]]`): + A dictionary from :obj:`str` to :obj:`int`. + _index_to_labels (:obj:`Dict[str, Dict[int, str]]`): + A dictionary from :obj:`int` to :obj:`str`. + """ + + def __init__( + self, + _labels_to_index: Dict[str, Dict[str, int]] = None, + _index_to_labels: Dict[str, Dict[int, str]] = None, + **kwargs, + ): + self._labels_to_index = _labels_to_index or {"labels": {}} + self._index_to_labels = _index_to_labels or {"labels": {}} + # if _labels_to_index is not empty and _index_to_labels is not provided + # to the constructor, build the inverted label dictionary + if not _index_to_labels and _labels_to_index: + for namespace in self._labels_to_index: + self._index_to_labels[namespace] = { + v: k for k, v in self._labels_to_index[namespace].items() + } + + def get_index_from_label(self, label: str, namespace: str = "labels") -> int: + """ + Returns the index of a literal label. + + Args: + label (:obj:`str`): + The string representation of the label. 
+ namespace (:obj:`str`, optional, defaults to ``labels``): + The namespace where the label belongs, e.g. ``roles`` for a SRL task. + + Returns: + :obj:`int`: The index of the label. + """ + if namespace not in self._labels_to_index: + raise ValueError( + f"Provided namespace `{namespace}` is not in the label dictionary." + ) + + if label not in self._labels_to_index[namespace]: + raise ValueError(f"Provided label {label} is not in the label dictionary.") + + return self._labels_to_index[namespace][label] + + def get_label_from_index(self, index: int, namespace: str = "labels") -> str: + """ + Returns the string representation of the label index. + + Args: + index (:obj:`int`): + The index of the label. + namespace (:obj:`str`, optional, defaults to ``labels``): + The namespace where the label belongs, e.g. ``roles`` for a SRL task. + + Returns: + :obj:`str`: The string representation of the label. + """ + if namespace not in self._index_to_labels: + raise ValueError( + f"Provided namespace `{namespace}` is not in the label dictionary." + ) + + if index not in self._index_to_labels[namespace]: + raise ValueError( + f"Provided label `{index}` is not in the label dictionary." + ) + + return self._index_to_labels[namespace][index] + + def add_labels( + self, + labels: Union[str, List[str], Set[str], Dict[str, int]], + namespace: str = "labels", + ) -> List[int]: + """ + Adds the labels in input in the label dictionary. + + Args: + labels (:obj:`str`, :obj:`List[str]`, :obj:`Set[str]`): + The labels (single label, list of labels or set of labels) to add to the dictionary. + namespace (:obj:`str`, optional, defaults to ``labels``): + Namespace where the labels belongs. + + Returns: + :obj:`List[int]`: The index of the labels just inserted. + """ + if isinstance(labels, dict): + self._labels_to_index[namespace] = labels + self._index_to_labels[namespace] = { + v: k for k, v in self._labels_to_index[namespace].items() + } + # normalize input + if isinstance(labels, (str, list)): + labels = set(labels) + # if new namespace, add to the dictionaries + if namespace not in self._labels_to_index: + self._labels_to_index[namespace] = {} + self._index_to_labels[namespace] = {} + # returns the new indices + return [self._add_label(label, namespace) for label in labels] + + def _add_label(self, label: str, namespace: str = "labels") -> int: + """ + Adds the label in input in the label dictionary. + + Args: + label (:obj:`str`): + The label to add to the dictionary. + namespace (:obj:`str`, optional, defaults to ``labels``): + Namespace where the label belongs. + + Returns: + :obj:`List[int]`: The index of the label just inserted. + """ + if label not in self._labels_to_index[namespace]: + index = len(self._labels_to_index[namespace]) + self._labels_to_index[namespace][label] = index + self._index_to_labels[namespace][index] = label + return index + else: + return self._labels_to_index[namespace][label] + + def get_labels(self, namespace: str = "labels") -> Dict[str, int]: + """ + Returns all the labels that belongs to the input namespace. + + Args: + namespace (:obj:`str`, optional, defaults to ``labels``): + Labels namespace to retrieve. + + Returns: + :obj:`Dict[str, int]`: The label dictionary, from ``str`` to ``int``. + """ + if namespace not in self._labels_to_index: + raise ValueError( + f"Provided namespace `{namespace}` is not in the label dictionary." 
+ ) + return self._labels_to_index[namespace] + + def get_label_size(self, namespace: str = "labels") -> int: + """ + Returns the number of the labels in the namespace dictionary. + + Args: + namespace (:obj:`str`, optional, defaults to ``labels``): + Labels namespace to retrieve. + + Returns: + :obj:`int`: Number of labels. + """ + if namespace not in self._labels_to_index: + raise ValueError( + f"Provided namespace `{namespace}` is not in the label dictionary." + ) + return len(self._labels_to_index[namespace]) + + def get_namespaces(self) -> List[str]: + """ + Returns all the namespaces in the label dictionary. + + Returns: + :obj:`List[str]`: The namespaces in the label dictionary. + """ + return list(self._labels_to_index.keys()) + + @classmethod + def from_file(cls, file_path: Union[str, Path, dict], **kwargs): + with open(file_path, "r") as f: + labels_to_index = json.load(f) + return cls(labels_to_index, **kwargs) + + def save(self, file_path: Union[str, Path, dict], **kwargs): + with open(file_path, "w") as f: + json.dump(self._labels_to_index, f, indent=2) + + +class PassageManager: + def __init__( + self, + tokenizer: Optional[tr.PreTrainedTokenizer] = None, + passages: Optional[Union[Dict[str, Dict[str, int]], Labels, List[str]]] = None, + lazy: bool = True, + **kwargs, + ): + if passages is None: + self.passages = Labels() + elif isinstance(passages, Labels): + self.passages = passages + elif isinstance(passages, dict): + self.passages = Labels(passages) + elif isinstance(passages, list): + self.passages = Labels() + self.passages.add_labels(passages) + else: + raise ValueError( + "`passages` should be either a Labels object or a dictionary." + ) + + self.tokenizer = tokenizer + self.lazy = lazy + + self._tokenized_passages = {} + + if not self.lazy: + self._tokenize_passages(self.passages) + + def __len__(self) -> int: + return self.passages.get_label_size() + + def get_index_from_passage(self, passage: str) -> int: + """ + Returns the index of the passage in input. + + Args: + passage (:obj:`str`): + The passage to get the index from. + + Returns: + :obj:`int`: The index of the passage. + """ + return self.passages.get_index_from_label(passage) + + def get_passage_from_index(self, index: int) -> str: + """ " + Returns the passage from the index in input. + + Args: + index (:obj:`int`): + The index to get the passage from. + + Returns: + :obj:`str`: The passage. + """ + return self.passages.get_label_from_index(index) + + def add_passages( + self, + passages: Union[str, List[str], Set[str], Dict[str, int]], + lazy: Optional[bool] = None, + ) -> List[int]: + """ + Adds the passages in input in the passage dictionary. + + Args: + passages (:obj:`str`, :obj:`List[str]`, :obj:`Set[str]`, :obj:`Dict[str, int]`): + The passages (single passage, list of passages, set of passages or dictionary of passages) to add to the dictionary. + lazy (:obj:`bool`, optional, defaults to ``None``): + Whether to tokenize the passages right away or not. + + Returns: + :obj:`List[int]`: The index of the passages just inserted. + """ + + return self.passages.add_labels(passages) + + def get_passages(self) -> Dict[str, int]: + """ + Returns all the passages in the passage dictionary. + + Returns: + :obj:`Dict[str, int]`: The passage dictionary, from ``str`` to ``int``. + """ + return self.passages.get_labels() + + def get_tokenized_passage( + self, passage: Union[str, int], force_tokenize: bool = False, **kwargs + ) -> Dict: + """ + Returns the tokenized passage in input. 
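+        Tokenized passages are cached, so each passage is tokenized at most
+        once unless ``force_tokenize`` is set to ``True``.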
+ + Args: + passage (:obj:`Union[str, int]`): + The passage to tokenize. + force_tokenize (:obj:`bool`, optional, defaults to ``False``): + Whether to force the tokenization of the passage or not. + kwargs: + Additional keyword arguments to pass to the tokenizer. + + Returns: + :obj:`Dict`: The tokenized passage. + """ + passage_index: Optional[int] = None + passage_str: Optional[str] = None + + if isinstance(passage, str): + passage_index = self.passages.get_index_from_label(passage) + passage_str = passage + elif isinstance(passage, int): + passage_index = passage + passage_str = self.passages.get_label_from_index(passage) + else: + raise ValueError( + f"`passage` should be either a `str` or an `int`. Provided type: {type(passage)}." + ) + + if passage_index not in self._tokenized_passages or force_tokenize: + self._tokenized_passages[passage_index] = self.tokenizer( + passage_str, **kwargs + ) + + return self._tokenized_passages[passage_index] + + def _tokenize_passages(self, **kwargs): + for passage in self.passages.get_labels(): + self.get_tokenized_passage(passage, **kwargs) + + def tokenize(self, text: Union[str, List[str]], **kwargs): + """ + Tokenizes the text in input using the tokenizer. + + Args: + text (:obj:`str`, :obj:`List[str]`): + The text to tokenize. + **kwargs: + Additional keyword arguments to pass to the tokenizer. + + Returns: + :obj:`List[str]`: The tokenized text. + + """ + if self.tokenizer is None: + raise ValueError( + "No tokenizer was provided. Please provide a tokenizer to the passageManager." + ) + return self.tokenizer(text, **kwargs) diff --git a/relik/retriever/data/utils.py b/relik/retriever/data/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..928dfb833919ce2e30c9e90fed539c46f4bd3dec --- /dev/null +++ b/relik/retriever/data/utils.py @@ -0,0 +1,176 @@ +import json +import os +from collections import defaultdict +from typing import Any, Dict, Iterable, List, Optional, Union + +import numpy as np +import transformers as tr +from tqdm import tqdm + + +class HardNegativesManager: + def __init__( + self, + tokenizer: tr.PreTrainedTokenizer, + data: Union[List[Dict], os.PathLike, Dict[int, List]] = None, + max_length: int = 64, + batch_size: int = 1000, + lazy: bool = False, + ) -> None: + self._db: dict = None + self.tokenizer = tokenizer + + if data is None: + self._db = {} + else: + if isinstance(data, Dict): + self._db = data + elif isinstance(data, os.PathLike): + with open(data) as f: + self._db = json.load(f) + else: + raise ValueError( + f"Data type {type(data)} not supported, only Dict and os.PathLike are supported." 
+ ) + # add the tokenizer to the class for future use + self.tokenizer = tokenizer + + # invert the db to have a passage -> sample_idx mapping + self._passage_db = defaultdict(set) + for sample_idx, passages in self._db.items(): + for passage in passages: + self._passage_db[passage].add(sample_idx) + + self._passage_hard_negatives = {} + if not lazy: + # create a dictionary of passage -> hard_negative mapping + batch_size = min(batch_size, len(self._passage_db)) + unique_passages = list(self._passage_db.keys()) + for i in tqdm( + range(0, len(unique_passages), batch_size), + desc="Tokenizing Hard Negatives", + ): + batch = unique_passages[i : i + batch_size] + tokenized_passages = self.tokenizer( + batch, + max_length=max_length, + truncation=True, + ) + for i, passage in enumerate(batch): + self._passage_hard_negatives[passage] = { + k: tokenized_passages[k][i] for k in tokenized_passages.keys() + } + + def __len__(self) -> int: + return len(self._db) + + def __getitem__(self, idx: int) -> Dict: + return self._db[idx] + + def __iter__(self): + for sample in self._db: + yield sample + + def __contains__(self, idx: int) -> bool: + return idx in self._db + + def get(self, idx: int) -> List[str]: + """Get the hard negatives for a given sample index.""" + if idx not in self._db: + raise ValueError(f"Sample index {idx} not in the database.") + + passages = self._db[idx] + + output = [] + for passage in passages: + if passage not in self._passage_hard_negatives: + self._passage_hard_negatives[passage] = self._tokenize(passage) + output.append(self._passage_hard_negatives[passage]) + + return output + + def _tokenize(self, passage: str) -> Dict: + return self.tokenizer(passage, max_length=self.max_length, truncation=True) + + +class NegativeSampler: + def __init__( + self, num_elements: int, probabilities: Optional[Union[List, np.ndarray]] = None + ): + if not isinstance(probabilities, np.ndarray): + probabilities = np.array(probabilities) + + if probabilities is None: + # probabilities should sum to 1 + probabilities = np.random.random(num_elements) + probabilities /= np.sum(probabilities) + self.probabilities = probabilities + + def __call__( + self, + sample_size: int, + num_samples: int = 1, + probabilities: np.array = None, + exclude: List[int] = None, + ) -> np.array: + """ + Fast sampling of `sample_size` elements from `num_elements` elements. + The sampling is done by randomly shifting the probabilities and then + finding the smallest of the negative numbers. This is much faster than + sampling from a multinomial distribution. + + Args: + sample_size (`int`): + number of elements to sample + num_samples (`int`, optional): + number of samples to draw. Defaults to 1. + probabilities (`np.array`, optional): + probabilities of each element. Defaults to None. + exclude (`List[int]`, optional): + indices of elements to exclude. Defaults to None. + + Returns: + `np.array`: array of sampled indices + """ + if probabilities is None: + probabilities = self.probabilities + + if exclude is not None: + probabilities[exclude] = 0 + # re-normalize? 
+ # probabilities /= np.sum(probabilities) + + # replicate probabilities as many times as `num_samples` + replicated_probabilities = np.tile(probabilities, (num_samples, 1)) + # get random shifting numbers & scale them correctly + random_shifts = np.random.random(replicated_probabilities.shape) + random_shifts /= random_shifts.sum(axis=1)[:, np.newaxis] + # shift by numbers & find largest (by finding the smallest of the negative) + shifted_probabilities = random_shifts - replicated_probabilities + sampled_indices = np.argpartition(shifted_probabilities, sample_size, axis=1)[ + :, :sample_size + ] + return sampled_indices + + +def batch_generator(samples: Iterable[Any], batch_size: int) -> Iterable[Any]: + """ + Generate batches from samples. + + Args: + samples (`Iterable[Any]`): Iterable of samples. + batch_size (`int`): Batch size. + + Returns: + `Iterable[Any]`: Iterable of batches. + """ + batch = [] + for sample in samples: + batch.append(sample) + if len(batch) == batch_size: + yield batch + batch = [] + + # leftover batch + if len(batch) > 0: + yield batch diff --git a/relik/retriever/indexers/__init__.py b/relik/retriever/indexers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/relik/retriever/indexers/base.py b/relik/retriever/indexers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..27ba1eeae4d8f9a757daedb1558ed03788a6b8bd --- /dev/null +++ b/relik/retriever/indexers/base.py @@ -0,0 +1,319 @@ +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import hydra +import numpy +import torch +from omegaconf import OmegaConf +from rich.pretty import pprint + +from relik.common import upload +from relik.common.log import get_console_logger, get_logger +from relik.common.utils import ( + from_cache, + is_remote_url, + is_str_a_path, + relative_to_absolute_path, + sapienzanlp_model_urls, +) +from relik.retriever.data.labels import Labels + +# from relik.retriever.models.model import GoldenRetriever, RetrievedSample + + +logger = get_logger(__name__) +console_logger = get_console_logger() + + +@dataclass +class IndexerOutput: + indices: Union[torch.Tensor, numpy.ndarray] + distances: Union[torch.Tensor, numpy.ndarray] + + +class BaseDocumentIndex: + CONFIG_NAME = "config.yaml" + DOCUMENTS_FILE_NAME = "documents.json" + EMBEDDINGS_FILE_NAME = "embeddings.pt" + + def __init__( + self, + documents: Union[str, List[str], Labels, os.PathLike, List[os.PathLike]] = None, + embeddings: Optional[torch.Tensor] = None, + name_or_dir: Optional[Union[str, os.PathLike]] = None, + ) -> None: + if documents is not None: + if isinstance(documents, Labels): + self.documents = documents + else: + documents_are_paths = False + + # normalize the documents to list if not already + if not isinstance(documents, list): + documents = [documents] + + # now check if the documents are a list of paths (either str or os.PathLike) + if isinstance(documents[0], str) or isinstance( + documents[0], os.PathLike + ): + # check if the str is a path + documents_are_paths = is_str_a_path(documents[0]) + + # if the documents are a list of paths, then we load them + if documents_are_paths: + logger.info("Loading documents from paths") + _documents = [] + for doc in documents: + with open(relative_to_absolute_path(doc)) as f: + _documents += [line.strip() for line in f.readlines()] + # remove duplicates + documents = list(set(_documents)) + + 
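+                # wrap the raw strings in a `Labels` object so that every
+                # document is mapped to a stable integer index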
self.documents = Labels() + self.documents.add_labels(documents) + else: + self.documents = Labels() + + self.embeddings = embeddings + self.name_or_dir = name_or_dir + + @property + def config(self) -> Dict[str, Any]: + """ + The configuration of the document index. + + Returns: + `Dict[str, Any]`: The configuration of the retriever. + """ + + def obj_to_dict(obj): + match obj: + case dict(): + data = {} + for k, v in obj.items(): + data[k] = obj_to_dict(v) + return data + + case list() | tuple(): + return [obj_to_dict(x) for x in obj] + + case object(__dict__=_): + data = { + "_target_": f"{obj.__class__.__module__}.{obj.__class__.__name__}", + } + for k, v in obj.__dict__.items(): + if not k.startswith("_"): + data[k] = obj_to_dict(v) + return data + + case _: + return obj + + return obj_to_dict(self) + + def index( + self, + retriever, + *args, + **kwargs, + ) -> "BaseDocumentIndex": + raise NotImplementedError + + def search(self, query: Any, k: int = 1, *args, **kwargs) -> List: + raise NotImplementedError + + def get_index_from_passage(self, document: str) -> int: + """ + Get the index of the passage. + + Args: + document (`str`): + The document to get the index for. + + Returns: + `int`: The index of the document. + """ + return self.documents.get_index_from_label(document) + + def get_passage_from_index(self, index: int) -> str: + """ + Get the document from the index. + + Args: + index (`int`): + The index of the document. + + Returns: + `str`: The document. + """ + return self.documents.get_label_from_index(index) + + def get_embeddings_from_index(self, index: int) -> torch.Tensor: + """ + Get the document vector from the index. + + Args: + index (`int`): + The index of the document. + + Returns: + `torch.Tensor`: The document vector. + """ + if self.embeddings is None: + raise ValueError( + "The documents must be indexed before they can be retrieved." + ) + if index >= self.embeddings.shape[0]: + raise ValueError( + f"The index {index} is out of bounds. The maximum index is {len(self.embeddings) - 1}." + ) + return self.embeddings[index] + + def get_embeddings_from_passage(self, document: str) -> torch.Tensor: + """ + Get the document vector from the document label. + + Args: + document (`str`): + The document to get the vector for. + + Returns: + `torch.Tensor`: The document vector. + """ + if self.embeddings is None: + raise ValueError( + "The documents must be indexed before they can be retrieved." + ) + return self.get_embeddings_from_index(self.get_index_from_passage(document)) + + def save_pretrained( + self, + output_dir: Union[str, os.PathLike], + config: Optional[Dict[str, Any]] = None, + config_file_name: Optional[str] = None, + document_file_name: Optional[str] = None, + embedding_file_name: Optional[str] = None, + push_to_hub: bool = False, + **kwargs, + ): + """ + Save the retriever to a directory. + + Args: + output_dir (`str`): + The directory to save the retriever to. + config (`Optional[Dict[str, Any]]`, `optional`): + The configuration to save. If `None`, the current configuration of the retriever will be + saved. Defaults to `None`. + config_file_name (`Optional[str]`, `optional`): + The name of the configuration file. Defaults to `config.yaml`. + document_file_name (`Optional[str]`, `optional`): + The name of the document file. Defaults to `documents.json`. + embedding_file_name (`Optional[str]`, `optional`): + The name of the embedding file. Defaults to `embeddings.pt`. + push_to_hub (`bool`, `optional`): + Whether to push the saved retriever to the hub. 
Defaults to `False`. + """ + if config is None: + # create a default config + config = self.config + + config_file_name = config_file_name or self.CONFIG_NAME + document_file_name = document_file_name or self.DOCUMENTS_FILE_NAME + embedding_file_name = embedding_file_name or self.EMBEDDINGS_FILE_NAME + + # create the output directory + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + logger.info(f"Saving retriever to {output_dir}") + logger.info(f"Saving config to {output_dir / config_file_name}") + # pretty print the config + pprint(config, console=console_logger, expand_all=True) + OmegaConf.save(config, output_dir / config_file_name) + + # save the current state of the retriever + embedding_path = output_dir / embedding_file_name + logger.info(f"Saving retriever state to {output_dir / embedding_path}") + torch.save(self.embeddings, embedding_path) + + # save the passage index + documents_path = output_dir / document_file_name + logger.info(f"Saving passage index to {documents_path}") + self.documents.save(documents_path) + + logger.info("Saving document index to disk done.") + + if push_to_hub: + # push to hub + logger.info(f"Pushing to hub") + model_id = model_id or output_dir.name + upload(output_dir, model_id, **kwargs) + + @classmethod + def from_pretrained( + cls, + name_or_dir: Union[str, os.PathLike], + device: str = "cpu", + precision: Optional[str] = None, + config_file_name: Optional[str] = None, + document_file_name: Optional[str] = None, + embedding_file_name: Optional[str] = None, + *args, + **kwargs, + ) -> "BaseDocumentIndex": + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + + config_file_name = config_file_name or cls.CONFIG_NAME + document_file_name = document_file_name or cls.DOCUMENTS_FILE_NAME + embedding_file_name = embedding_file_name or cls.EMBEDDINGS_FILE_NAME + + model_dir = from_cache( + name_or_dir, + filenames=[config_file_name, document_file_name, embedding_file_name], + cache_dir=cache_dir, + force_download=force_download, + ) + + config_path = model_dir / config_file_name + if not config_path.exists(): + raise FileNotFoundError( + f"Model configuration file not found at {config_path}." 
+ ) + + config = OmegaConf.load(config_path) + pprint(OmegaConf.to_container(config), console=console_logger, expand_all=True) + + # load the documents + documents_path = model_dir / document_file_name + + if not documents_path.exists(): + raise ValueError(f"Document file `{documents_path}` does not exist.") + logger.info(f"Loading documents from {documents_path}") + documents = Labels.from_file(documents_path) + + # load the passage embeddings + embedding_path = model_dir / embedding_file_name + # run some checks + embeddings = None + if embedding_path.exists(): + logger.info(f"Loading embeddings from {embedding_path}") + embeddings = torch.load(embedding_path, map_location="cpu") + else: + logger.warning(f"Embedding file `{embedding_path}` does not exist.") + + document_index = hydra.utils.instantiate( + config, + documents=documents, + embeddings=embeddings, + device=device, + precision=precision, + name_or_dir=name_or_dir, + *args, + **kwargs, + ) + + return document_index diff --git a/relik/retriever/indexers/faiss.py b/relik/retriever/indexers/faiss.py new file mode 100644 index 0000000000000000000000000000000000000000..49b752db9f04833706f600c149ee18d16511ff55 --- /dev/null +++ b/relik/retriever/indexers/faiss.py @@ -0,0 +1,399 @@ +import contextlib +import logging +import math +import os +from dataclasses import dataclass +from typing import Callable, List, Optional, Union + +import numpy +import torch +from pytorch_modules import RetrievedSample +from torch.utils.data import DataLoader +from tqdm import tqdm + +from relik.common.log import get_logger +from relik.common.utils import is_package_available +from relik.retriever.common.model_inputs import ModelInputs +from relik.retriever.data.base.datasets import BaseDataset +from relik.retriever.data.labels import Labels +from relik.retriever.indexers.base import BaseDocumentIndex +from relik.retriever.pytorch_modules import PRECISION_MAP +from relik.retriever.pytorch_modules.model import GoldenRetriever + +if is_package_available("faiss"): + import faiss + import faiss.contrib.torch_utils + +logger = get_logger(__name__, level=logging.INFO) + + +@dataclass +class FaissOutput: + indices: Union[torch.Tensor, numpy.ndarray] + distances: Union[torch.Tensor, numpy.ndarray] + + +class FaissDocumentIndex(BaseDocumentIndex): + DOCUMENTS_FILE_NAME = "documents.json" + EMBEDDINGS_FILE_NAME = "embeddings.pt" + INDEX_FILE_NAME = "index.faiss" + + def __init__( + self, + documents: Union[List[str], Labels], + embeddings: Optional[Union[torch.Tensor, numpy.ndarray]] = None, + index=None, + index_type: str = "Flat", + metric: int = faiss.METRIC_INNER_PRODUCT, + normalize: bool = False, + device: str = "cpu", + name_or_dir: Optional[Union[str, os.PathLike]] = None, + *args, + **kwargs, + ) -> None: + super().__init__(documents, embeddings, name_or_dir) + + if embeddings is not None and documents is not None: + logger.info("Both documents and embeddings are provided.") + if documents.get_label_size() != embeddings.shape[0]: + raise ValueError( + "The number of documents and embeddings must be the same." 
+ ) + + # device to store the embeddings + self.device = device + + # params + self.index_type = index_type + self.metric = metric + self.normalize = normalize + + if index is not None: + self.embeddings = index + if self.device == "cuda": + # use a single GPU + faiss_resource = faiss.StandardGpuResources() + self.embeddings = faiss.index_cpu_to_gpu( + faiss_resource, 0, self.embeddings + ) + else: + if embeddings is not None: + # build the faiss index + logger.info("Building the index from the embeddings.") + self.embeddings = self._build_faiss_index( + embeddings=embeddings, + index_type=index_type, + normalize=normalize, + metric=metric, + ) + + def _build_faiss_index( + self, + embeddings: Optional[Union[torch.Tensor, numpy.ndarray]], + index_type: str, + normalize: bool, + metric: int, + ): + # build the faiss index + self.normalize = ( + normalize + and metric == faiss.METRIC_INNER_PRODUCT + and not isinstance(embeddings, torch.Tensor) + ) + if self.normalize: + index_type = f"L2norm,{index_type}" + faiss_vector_size = embeddings.shape[1] + if self.device == "cpu": + index_type = index_type.replace("x,", "x_HNSW32,") + index_type = index_type.replace( + "x", str(math.ceil(math.sqrt(faiss_vector_size)) * 4) + ) + self.embeddings = faiss.index_factory(faiss_vector_size, index_type, metric) + + # convert to GPU + if self.device == "cuda": + # use a single GPU + faiss_resource = faiss.StandardGpuResources() + self.embeddings = faiss.index_cpu_to_gpu(faiss_resource, 0, self.embeddings) + else: + # move to CPU if embeddings is a torch.Tensor + embeddings = ( + embeddings.cpu() if isinstance(embeddings, torch.Tensor) else embeddings + ) + + # convert to float32 if embeddings is a torch.Tensor and is float16 + if isinstance(embeddings, torch.Tensor) and embeddings.dtype == torch.float16: + embeddings = embeddings.float() + + self.embeddings.add(embeddings) + + # save parameters for saving/loading + self.index_type = index_type + self.metric = metric + + # clear the embeddings to free up memory + embeddings = None + + return self.embeddings + + @torch.no_grad() + @torch.inference_mode() + def index( + self, + retriever: GoldenRetriever, + documents: Optional[List[str]] = None, + batch_size: int = 32, + num_workers: int = 4, + max_length: Optional[int] = None, + collate_fn: Optional[Callable] = None, + encoder_precision: Optional[Union[str, int]] = None, + compute_on_cpu: bool = False, + force_reindex: bool = False, + *args, + **kwargs, + ) -> "FaissDocumentIndex": + """ + Index the documents using the encoder. + + Args: + retriever (:obj:`torch.nn.Module`): + The encoder to be used for indexing. + documents (:obj:`List[str]`, `optional`, defaults to None): + The documents to be indexed. + batch_size (:obj:`int`, `optional`, defaults to 32): + The batch size to be used for indexing. + num_workers (:obj:`int`, `optional`, defaults to 4): + The number of workers to be used for indexing. + max_length (:obj:`int`, `optional`, defaults to None): + The maximum length of the input to the encoder. + collate_fn (:obj:`Callable`, `optional`, defaults to None): + The collate function to be used for batching. + encoder_precision (:obj:`Union[str, int]`, `optional`, defaults to None): + The precision to be used for the encoder. + compute_on_cpu (:obj:`bool`, `optional`, defaults to False): + Whether to compute the embeddings on CPU. + force_reindex (:obj:`bool`, `optional`, defaults to False): + Whether to force reindexing. + + Returns: + :obj:`InMemoryIndexer`: The indexer object. 
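+
+        Example (illustrative sketch; ``retriever`` is a trained ``GoldenRetriever``,
+        ``passages`` a list of passage strings and ``query_embeddings`` a tensor
+        produced by its question encoder)::
+
+            index = FaissDocumentIndex(documents=passages)
+            index.index(retriever, batch_size=64, force_reindex=True)
+            results = index.search(query_embeddings, k=5)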
+ """ + + if self.embeddings is not None and not force_reindex: + logger.log( + "Embeddings are already present and `force_reindex` is `False`. Skipping indexing." + ) + if documents is None: + return self + + # release the memory + if collate_fn is None: + tokenizer = retriever.passage_tokenizer + + def collate_fn(x): + return ModelInputs( + tokenizer( + x, + padding=True, + return_tensors="pt", + truncation=True, + max_length=max_length or tokenizer.model_max_length, + ) + ) + + if force_reindex: + if documents is not None: + self.documents.add_labels(documents) + data = [k for k in self.documents.get_labels()] + + else: + if documents is not None: + data = [k for k in Labels(documents).get_labels()] + else: + return self + + dataloader = DataLoader( + BaseDataset(name="passage", data=data), + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + pin_memory=False, + collate_fn=collate_fn, + ) + + encoder = retriever.passage_encoder + + # Create empty lists to store the passage embeddings and passage index + passage_embeddings: List[torch.Tensor] = [] + + encoder_device = "cpu" if compute_on_cpu else self.device + + # fucking autocast only wants pure strings like 'cpu' or 'cuda' + # we need to convert the model device to that + device_type_for_autocast = str(encoder_device).split(":")[0] + # autocast doesn't work with CPU and stuff different from bfloat16 + autocast_pssg_mngr = ( + contextlib.nullcontext() + if device_type_for_autocast == "cpu" + else ( + torch.autocast( + device_type=device_type_for_autocast, + dtype=PRECISION_MAP[encoder_precision], + ) + ) + ) + with autocast_pssg_mngr: + # Iterate through each batch in the dataloader + for batch in tqdm(dataloader, desc="Indexing"): + # Move the batch to the device + batch: ModelInputs = batch.to(encoder_device) + # Compute the passage embeddings + passage_outs = encoder(**batch) + # Append the passage embeddings to the list + if self.device == "cpu": + passage_embeddings.extend([c.detach().cpu() for c in passage_outs]) + else: + passage_embeddings.extend([c for c in passage_outs]) + + # move the passage embeddings to the CPU if not already done + passage_embeddings = [c.detach().cpu() for c in passage_embeddings] + # stack it + passage_embeddings: torch.Tensor = torch.stack(passage_embeddings, dim=0) + # convert to float32 for faiss + passage_embeddings.to(PRECISION_MAP["float32"]) + + # index the embeddings + self.embeddings = self._build_faiss_index( + embeddings=passage_embeddings, + index_type=self.index_type, + normalize=self.normalize, + metric=self.metric, + ) + # free up memory from the unused variable + del passage_embeddings + + return self + + @torch.no_grad() + @torch.inference_mode() + def search(self, query: torch.Tensor, k: int = 1) -> list[list[RetrievedSample]]: + k = min(k, self.embeddings.ntotal) + + if self.normalize: + faiss.normalize_L2(query) + if isinstance(query, torch.Tensor) and self.device == "cpu": + query = query.detach().cpu() + # Retrieve the indices of the top k passage embeddings + retriever_out = self.embeddings.search(query, k) + + # get int values (second element of the tuple) + batch_top_k: List[List[int]] = retriever_out[1].detach().cpu().tolist() + # get float values (first element of the tuple) + batch_scores: List[List[float]] = retriever_out[0].detach().cpu().tolist() + # Retrieve the passages corresponding to the indices + batch_passages = [ + [self.documents.get_label_from_index(i) for i in indices] + for indices in batch_top_k + ] + # build the output object + 
batch_retrieved_samples = [ + [ + RetrievedSample(label=passage, index=index, score=score) + for passage, index, score in zip(passages, indices, scores) + ] + for passages, indices, scores in zip( + batch_passages, batch_top_k, batch_scores + ) + ] + return batch_retrieved_samples + + # def save(self, saving_dir: Union[str, os.PathLike]): + # """ + # Save the indexer to the disk. + + # Args: + # saving_dir (:obj:`Union[str, os.PathLike]`): + # The directory where the indexer will be saved. + # """ + # saving_dir = Path(saving_dir) + # # save the passage embeddings + # index_path = saving_dir / self.INDEX_FILE_NAME + # logger.info(f"Saving passage embeddings to {index_path}") + # faiss.write_index(self.embeddings, str(index_path)) + # # save the passage index + # documents_path = saving_dir / self.DOCUMENTS_FILE_NAME + # logger.info(f"Saving passage index to {documents_path}") + # self.documents.save(documents_path) + + # @classmethod + # def load( + # cls, + # loading_dir: Union[str, os.PathLike], + # device: str = "cpu", + # document_file_name: Optional[str] = None, + # embedding_file_name: Optional[str] = None, + # index_file_name: Optional[str] = None, + # **kwargs, + # ) -> "FaissDocumentIndex": + # loading_dir = Path(loading_dir) + + # document_file_name = document_file_name or cls.DOCUMENTS_FILE_NAME + # embedding_file_name = embedding_file_name or cls.EMBEDDINGS_FILE_NAME + # index_file_name = index_file_name or cls.INDEX_FILE_NAME + + # # load the documents + # documents_path = loading_dir / document_file_name + + # if not documents_path.exists(): + # raise ValueError(f"Document file `{documents_path}` does not exist.") + # logger.info(f"Loading documents from {documents_path}") + # documents = Labels.from_file(documents_path) + + # index = None + # embeddings = None + # # try to load the index directly + # index_path = loading_dir / index_file_name + # if not index_path.exists(): + # # try to load the embeddings + # embedding_path = loading_dir / embedding_file_name + # # run some checks + # if embedding_path.exists(): + # logger.info(f"Loading embeddings from {embedding_path}") + # embeddings = torch.load(embedding_path, map_location="cpu") + # logger.warning( + # f"Index file `{index_path}` and embedding file `{embedding_path}` do not exist." + # ) + # else: + # logger.info(f"Loading index from {index_path}") + # index = faiss.read_index(str(embedding_path)) + + # return cls( + # documents=documents, + # embeddings=embeddings, + # index=index, + # device=device, + # **kwargs, + # ) + + def get_embeddings_from_index( + self, index: int + ) -> Union[torch.Tensor, numpy.ndarray]: + """ + Get the document vector from the index. + + Args: + index (`int`): + The index of the document. + + Returns: + `torch.Tensor`: The document vector. + """ + if self.embeddings is None: + raise ValueError( + "The documents must be indexed before they can be retrieved." + ) + if index >= self.embeddings.ntotal: + raise ValueError( + f"The index {index} is out of bounds. The maximum index is {self.embeddings.ntotal}." 
+ ) + return self.embeddings.reconstruct(index) diff --git a/relik/retriever/indexers/inmemory.py b/relik/retriever/indexers/inmemory.py new file mode 100644 index 0000000000000000000000000000000000000000..42926991df7fecd3451dbf18c6b52b67983070b9 --- /dev/null +++ b/relik/retriever/indexers/inmemory.py @@ -0,0 +1,275 @@ +import contextlib +import logging +import os +from typing import Callable, List, Optional, Tuple, Union + +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm + +from relik.common.log import get_logger +from relik.retriever.common.model_inputs import ModelInputs +from relik.retriever.data.base.datasets import BaseDataset +from relik.retriever.data.labels import Labels +from relik.retriever.indexers.base import BaseDocumentIndex +from relik.retriever.pytorch_modules import PRECISION_MAP, RetrievedSample + +logger = get_logger(__name__, level=logging.INFO) + + +class InMemoryDocumentIndex(BaseDocumentIndex): + DOCUMENTS_FILE_NAME = "documents.json" + EMBEDDINGS_FILE_NAME = "embeddings.pt" + + def __init__( + self, + documents: Union[str, List[str], Labels, os.PathLike, List[os.PathLike]] = None, + embeddings: Optional[torch.Tensor] = None, + device: str = "cpu", + precision: Optional[str] = None, + name_or_dir: Optional[Union[str, os.PathLike]] = None, + *args, + **kwargs, + ) -> None: + """ + An in-memory indexer. + + Args: + documents (:obj:`Union[List[str], PassageManager]`): + The documents to be indexed. + embeddings (:obj:`Optional[torch.Tensor]`, `optional`, defaults to :obj:`None`): + The embeddings of the documents. + device (:obj:`str`, `optional`, defaults to "cpu"): + The device to be used for storing the embeddings. + """ + + super().__init__(documents, embeddings, name_or_dir) + + if embeddings is not None and documents is not None: + logger.info("Both documents and embeddings are provided.") + if documents.get_label_size() != embeddings.shape[0]: + raise ValueError( + "The number of documents and embeddings must be the same." + ) + + # embeddings of the documents + self.embeddings = embeddings + # does this do anything? + del embeddings + # convert the embeddings to the desired precision + if precision is not None: + if ( + self.embeddings is not None + and self.embeddings.dtype != PRECISION_MAP[precision] + ): + logger.info( + f"Index vectors are of type {self.embeddings.dtype}. " + f"Converting to {PRECISION_MAP[precision]}." + ) + self.embeddings = self.embeddings.to(PRECISION_MAP[precision]) + # move the embeddings to the desired device + if self.embeddings is not None and not self.embeddings.device == device: + self.embeddings = self.embeddings.to(device) + + # device to store the embeddings + self.device = device + # precision to be used for the embeddings + self.precision = precision + + @torch.no_grad() + @torch.inference_mode() + def index( + self, + retriever, + documents: Optional[List[str]] = None, + batch_size: int = 32, + num_workers: int = 4, + max_length: Optional[int] = None, + collate_fn: Optional[Callable] = None, + encoder_precision: Optional[Union[str, int]] = None, + compute_on_cpu: bool = False, + force_reindex: bool = False, + add_to_existing_index: bool = False, + ) -> "InMemoryDocumentIndex": + """ + Index the documents using the encoder. + + Args: + retriever (:obj:`torch.nn.Module`): + The encoder to be used for indexing. + documents (:obj:`List[str]`, `optional`, defaults to :obj:`None`): + The documents to be indexed. 
+ batch_size (:obj:`int`, `optional`, defaults to 32): + The batch size to be used for indexing. + num_workers (:obj:`int`, `optional`, defaults to 4): + The number of workers to be used for indexing. + max_length (:obj:`int`, `optional`, defaults to None): + The maximum length of the input to the encoder. + collate_fn (:obj:`Callable`, `optional`, defaults to None): + The collate function to be used for batching. + encoder_precision (:obj:`Union[str, int]`, `optional`, defaults to None): + The precision to be used for the encoder. + compute_on_cpu (:obj:`bool`, `optional`, defaults to False): + Whether to compute the embeddings on CPU. + force_reindex (:obj:`bool`, `optional`, defaults to False): + Whether to force reindexing. + add_to_existing_index (:obj:`bool`, `optional`, defaults to False): + Whether to add the new documents to the existing index. + + Returns: + :obj:`InMemoryIndexer`: The indexer object. + """ + + if documents is None and self.documents is None: + raise ValueError("Documents must be provided.") + + if self.embeddings is not None and not force_reindex: + logger.info( + "Embeddings are already present and `force_reindex` is `False`. Skipping indexing." + ) + if documents is None: + return self + + if collate_fn is None: + tokenizer = retriever.passage_tokenizer + + def collate_fn(x): + return ModelInputs( + tokenizer( + x, + padding=True, + return_tensors="pt", + truncation=True, + max_length=max_length or tokenizer.model_max_length, + ) + ) + + if force_reindex: + if documents is not None: + self.documents.add_labels(documents) + data = [k for k in self.documents.get_labels()] + + else: + if documents is not None: + data = [k for k in Labels(documents).get_labels()] + else: + return self + + # if force_reindex: + # data = [k for k in self.documents.get_labels()] + + dataloader = DataLoader( + BaseDataset(name="passage", data=data), + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + pin_memory=False, + collate_fn=collate_fn, + ) + + encoder = retriever.passage_encoder + + # Create empty lists to store the passage embeddings and passage index + passage_embeddings: List[torch.Tensor] = [] + + encoder_device = "cpu" if compute_on_cpu else self.device + + # fucking autocast only wants pure strings like 'cpu' or 'cuda' + # we need to convert the model device to that + device_type_for_autocast = str(encoder_device).split(":")[0] + # autocast doesn't work with CPU and stuff different from bfloat16 + autocast_pssg_mngr = ( + contextlib.nullcontext() + if device_type_for_autocast == "cpu" + else ( + torch.autocast( + device_type=device_type_for_autocast, + dtype=PRECISION_MAP[encoder_precision], + ) + ) + ) + with autocast_pssg_mngr: + # Iterate through each batch in the dataloader + for batch in tqdm(dataloader, desc="Indexing"): + # Move the batch to the device + batch: ModelInputs = batch.to(encoder_device) + # Compute the passage embeddings + passage_outs = encoder(**batch).pooler_output + # Append the passage embeddings to the list + if self.device == "cpu": + passage_embeddings.extend([c.detach().cpu() for c in passage_outs]) + else: + passage_embeddings.extend([c for c in passage_outs]) + + # move the passage embeddings to the CPU if not already done + # the move to cpu and then to gpu is needed to avoid OOM when using mixed precision + if not self.device == "cpu": # this if is to avoid unnecessary moves + passage_embeddings = [c.detach().cpu() for c in passage_embeddings] + # stack it + passage_embeddings: torch.Tensor = 
torch.stack(passage_embeddings, dim=0) + # move the passage embeddings to the gpu if needed + if not self.device == "cpu": + passage_embeddings = passage_embeddings.to(PRECISION_MAP[self.precision]) + passage_embeddings = passage_embeddings.to(self.device) + self.embeddings = passage_embeddings + + # free up memory from the unused variable + del passage_embeddings + + return self + + @torch.no_grad() + @torch.inference_mode() + def search(self, query: torch.Tensor, k: int = 1) -> list[list[RetrievedSample]]: + """ + Search the documents using the query. + + Args: + query (:obj:`torch.Tensor`): + The query to be used for searching. + k (:obj:`int`, `optional`, defaults to 1): + The number of documents to be retrieved. + + Returns: + :obj:`List[RetrievedSample]`: The retrieved documents. + """ + # fucking autocast only wants pure strings like 'cpu' or 'cuda' + # we need to convert the model device to that + device_type_for_autocast = str(self.device).split(":")[0] + # autocast doesn't work with CPU and stuff different from bfloat16 + autocast_pssg_mngr = ( + contextlib.nullcontext() + if device_type_for_autocast == "cpu" + else ( + torch.autocast( + device_type=device_type_for_autocast, + dtype=self.embeddings.dtype, + ) + ) + ) + with autocast_pssg_mngr: + similarity = torch.matmul(query, self.embeddings.T) + # Retrieve the indices of the top k passage embeddings + retriever_out: Tuple = torch.topk( + similarity, k=min(k, similarity.shape[-1]), dim=1 + ) + # get int values + batch_top_k: List[List[int]] = retriever_out.indices.detach().cpu().tolist() + # get float values + batch_scores: List[List[float]] = retriever_out.values.detach().cpu().tolist() + # Retrieve the passages corresponding to the indices + batch_passages = [ + [self.documents.get_label_from_index(i) for i in indices] + for indices in batch_top_k + ] + # build the output object + batch_retrieved_samples = [ + [ + RetrievedSample(label=passage, index=index, score=score) + for passage, index, score in zip(passages, indices, scores) + ] + for passages, indices, scores in zip( + batch_passages, batch_top_k, batch_scores + ) + ] + return batch_retrieved_samples diff --git a/relik/retriever/lightning_modules/__init__.py b/relik/retriever/lightning_modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/relik/retriever/lightning_modules/pl_data_modules.py b/relik/retriever/lightning_modules/pl_data_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..8c69f5f291789cb7b473dd8466f7373b1a010e8a --- /dev/null +++ b/relik/retriever/lightning_modules/pl_data_modules.py @@ -0,0 +1,121 @@ +from typing import Any, List, Optional, Sequence, Union + +import hydra +import lightning as pl +import torch +from lightning.pytorch.utilities.types import EVAL_DATALOADERS +from omegaconf import DictConfig +from torch.utils.data import DataLoader + +from relik.common.log import get_logger +from relik.retriever.data.datasets import GoldenRetrieverDataset + +logger = get_logger() + + +class GoldenRetrieverPLDataModule(pl.LightningDataModule): + def __init__( + self, + train_dataset: Optional[GoldenRetrieverDataset] = None, + val_datasets: Optional[Sequence[GoldenRetrieverDataset]] = None, + test_datasets: Optional[Sequence[GoldenRetrieverDataset]] = None, + num_workers: Optional[Union[DictConfig, int]] = None, + datasets: Optional[DictConfig] = None, + *args, + **kwargs, + ): + super().__init__() + self.datasets = datasets + if num_workers is 
None: + num_workers = 0 + if isinstance(num_workers, int): + num_workers = DictConfig( + {"train": num_workers, "val": num_workers, "test": num_workers} + ) + self.num_workers = num_workers + # data + self.train_dataset: Optional[GoldenRetrieverDataset] = train_dataset + self.val_datasets: Optional[Sequence[GoldenRetrieverDataset]] = val_datasets + self.test_datasets: Optional[Sequence[GoldenRetrieverDataset]] = test_datasets + + def prepare_data(self, *args, **kwargs): + """ + Method for preparing the data before the training. This method is called only once. + It is used to download the data, tokenize the data, etc. + """ + pass + + def setup(self, stage: Optional[str] = None): + if stage == "fit" or stage is None: + # usually there is only one dataset for train + # if you need more train loader, you can follow + # the same logic as val and test datasets + if self.train_dataset is None: + self.train_dataset = hydra.utils.instantiate(self.datasets.train) + self.val_datasets = [ + hydra.utils.instantiate(dataset_cfg) + for dataset_cfg in self.datasets.val + ] + if stage == "test": + if self.test_datasets is None: + self.test_datasets = [ + hydra.utils.instantiate(dataset_cfg) + for dataset_cfg in self.datasets.test + ] + + def train_dataloader(self, *args, **kwargs) -> DataLoader: + torch_dataset = self.train_dataset.to_torch_dataset() + return DataLoader( + # self.train_dataset.to_torch_dataset(), + torch_dataset, + shuffle=False, + batch_size=None, + num_workers=self.num_workers.train, + pin_memory=True, + collate_fn=lambda x: x, + ) + + def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: + dataloaders = [] + for dataset in self.val_datasets: + torch_dataset = dataset.to_torch_dataset() + dataloaders.append( + DataLoader( + torch_dataset, + shuffle=False, + batch_size=None, + num_workers=self.num_workers.val, + pin_memory=True, + collate_fn=lambda x: x, + ) + ) + return dataloaders + + def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: + dataloaders = [] + for dataset in self.test_datasets: + torch_dataset = dataset.to_torch_dataset() + dataloaders.append( + DataLoader( + torch_dataset, + shuffle=False, + batch_size=None, + num_workers=self.num_workers.test, + pin_memory=True, + collate_fn=lambda x: x, + ) + ) + return dataloaders + + def predict_dataloader(self) -> EVAL_DATALOADERS: + raise NotImplementedError + + def transfer_batch_to_device( + self, batch: Any, device: torch.device, dataloader_idx: int + ) -> Any: + return super().transfer_batch_to_device(batch, device, dataloader_idx) + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" f"{self.datasets=}, " f"{self.num_workers=}, " + ) diff --git a/relik/retriever/lightning_modules/pl_modules.py b/relik/retriever/lightning_modules/pl_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..1ef6ace6eb8ac95d92e16b1e28dc50d2492b33f4 --- /dev/null +++ b/relik/retriever/lightning_modules/pl_modules.py @@ -0,0 +1,123 @@ +from typing import Any, Union + +import hydra +import lightning as pl +import torch +from omegaconf import DictConfig + +from relik.retriever.common.model_inputs import ModelInputs + + +class GoldenRetrieverPLModule(pl.LightningModule): + def __init__( + self, + model: Union[torch.nn.Module, DictConfig], + optimizer: Union[torch.optim.Optimizer, DictConfig], + lr_scheduler: Union[torch.optim.lr_scheduler.LRScheduler, DictConfig] = None, + *args, + **kwargs, + ) -> None: + super().__init__() + 
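+        # log the constructor arguments as hyperparameters, excluding the
+        # (potentially large) model object itself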
self.save_hyperparameters(ignore=["model"]) + if isinstance(model, DictConfig): + self.model = hydra.utils.instantiate(model) + else: + self.model = model + + self.optimizer_config = optimizer + self.lr_scheduler_config = lr_scheduler + + def forward(self, **kwargs) -> dict: + """ + Method for the forward pass. + 'training_step', 'validation_step' and 'test_step' should call + this method in order to compute the output predictions and the loss. + + Returns: + output_dict: forward output containing the predictions (output logits ecc...) and the loss if any. + + """ + return self.model(**kwargs) + + def training_step(self, batch: ModelInputs, batch_idx: int) -> torch.Tensor: + forward_output = self.forward(**batch, return_loss=True) + self.log( + "loss", + forward_output["loss"], + batch_size=batch["questions"]["input_ids"].size(0), + prog_bar=True, + ) + return forward_output["loss"] + + def validation_step(self, batch: ModelInputs, batch_idx: int) -> None: + forward_output = self.forward(**batch, return_loss=True) + self.log( + "val_loss", + forward_output["loss"], + batch_size=batch["questions"]["input_ids"].size(0), + ) + + def test_step(self, batch: ModelInputs, batch_idx: int) -> Any: + forward_output = self.forward(**batch, return_loss=True) + self.log( + "test_loss", + forward_output["loss"], + batch_size=batch["questions"]["input_ids"].size(0), + ) + + def configure_optimizers(self): + if isinstance(self.optimizer_config, DictConfig): + param_optimizer = list(self.named_parameters()) + no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [ + p for n, p in param_optimizer if "layer_norm_layer" in n + ], + "weight_decay": self.hparams.optimizer.weight_decay, + "lr": 1e-4, + }, + { + "params": [ + p + for n, p in param_optimizer + if all(nd not in n for nd in no_decay) + and "layer_norm_layer" not in n + ], + "weight_decay": self.hparams.optimizer.weight_decay, + }, + { + "params": [ + p + for n, p in param_optimizer + if "layer_norm_layer" not in n + and any(nd in n for nd in no_decay) + ], + "weight_decay": 0.0, + }, + ] + optimizer = hydra.utils.instantiate( + self.optimizer_config, + # params=self.parameters(), + params=optimizer_grouped_parameters, + _convert_="partial", + ) + else: + optimizer = self.optimizer_config + + if self.lr_scheduler_config is None: + return optimizer + + if isinstance(self.lr_scheduler_config, DictConfig): + lr_scheduler = hydra.utils.instantiate( + self.lr_scheduler_config, optimizer=optimizer + ) + else: + lr_scheduler = self.lr_scheduler_config + + lr_scheduler_config = { + "scheduler": lr_scheduler, + "interval": "step", + "frequency": 1, + } + return [optimizer], [lr_scheduler_config] diff --git a/relik/retriever/pytorch_modules/__init__.py b/relik/retriever/pytorch_modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..01752b8aa79367a7bcdc2d18438ae87bebdd87f2 --- /dev/null +++ b/relik/retriever/pytorch_modules/__init__.py @@ -0,0 +1,28 @@ +from dataclasses import dataclass + +import torch + +PRECISION_MAP = { + None: torch.float32, + 16: torch.float16, + 32: torch.float32, + "float16": torch.float16, + "float32": torch.float32, + "half": torch.float16, + "float": torch.float32, + "16": torch.float16, + "32": torch.float32, + "fp16": torch.float16, + "fp32": torch.float32, +} + + +@dataclass +class RetrievedSample: + """ + Dataclass for the output of the GoldenRetriever model. 
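+
+    ``score`` is the similarity score assigned by the retriever, ``index`` is the
+    position of the passage in the document index, and ``label`` is the passage text.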
+ """ + + score: float + index: int + label: str diff --git a/relik/retriever/pytorch_modules/hf.py b/relik/retriever/pytorch_modules/hf.py new file mode 100644 index 0000000000000000000000000000000000000000..b5868d67ce5ed97a1e66c6d1d3a606350e75267a --- /dev/null +++ b/relik/retriever/pytorch_modules/hf.py @@ -0,0 +1,88 @@ +from typing import Tuple, Union + +import torch +from transformers import PretrainedConfig +from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions +from transformers.models.bert.modeling_bert import BertModel + + +class GoldenRetrieverConfig(PretrainedConfig): + model_type = "bert" + + def __init__( + self, + vocab_size=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + pad_token_id=0, + position_embedding_type="absolute", + use_cache=True, + classifier_dropout=None, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.position_embedding_type = position_embedding_type + self.use_cache = use_cache + self.classifier_dropout = classifier_dropout + + +class GoldenRetrieverModel(BertModel): + config_class = GoldenRetrieverConfig + + def __init__(self, config, *args, **kwargs): + super().__init__(config) + self.layer_norm_layer = torch.nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) + + def forward( + self, **kwargs + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + attention_mask = kwargs.get("attention_mask", None) + model_outputs = super().forward(**kwargs) + if attention_mask is None: + pooler_output = model_outputs.pooler_output + else: + token_embeddings = model_outputs.last_hidden_state + input_mask_expanded = ( + attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + ) + pooler_output = torch.sum( + token_embeddings * input_mask_expanded, 1 + ) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) + + pooler_output = self.layer_norm_layer(pooler_output) + + if not kwargs.get("return_dict", True): + return (model_outputs[0], pooler_output) + model_outputs[2:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=model_outputs.last_hidden_state, + pooler_output=pooler_output, + past_key_values=model_outputs.past_key_values, + hidden_states=model_outputs.hidden_states, + attentions=model_outputs.attentions, + cross_attentions=model_outputs.cross_attentions, + ) diff --git a/relik/retriever/pytorch_modules/loss.py b/relik/retriever/pytorch_modules/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..643d3a486ca73ca38486094553b357f5e7c28adb --- /dev/null +++ b/relik/retriever/pytorch_modules/loss.py @@ -0,0 +1,34 @@ +from typing import Optional + +import torch +from torch.nn.modules.loss import _WeightedLoss + + +class 
MultiLabelNCELoss(_WeightedLoss): + __constants__ = ["reduction"] + + def __init__( + self, + weight: Optional[torch.Tensor] = None, + size_average=None, + reduction: Optional[str] = "mean", + ) -> None: + super(MultiLabelNCELoss, self).__init__(weight, size_average, None, reduction) + + def forward( + self, input: torch.Tensor, target: torch.Tensor, ignore_index: int = -100 + ) -> torch.Tensor: + gold_scores = input.masked_fill(~(target.bool()), 0) + gold_scores_sum = gold_scores.sum(-1) # B x C + neg_logits = input.masked_fill(target.bool(), float("-inf")) # B x C x L + neg_log_sum_exp = torch.logsumexp(neg_logits, -1, keepdim=True) # B x C x 1 + norm_term = ( + torch.logaddexp(input, neg_log_sum_exp) + .masked_fill(~(target.bool()), 0) + .sum(-1) + ) + gold_log_probs = gold_scores_sum - norm_term + loss = -gold_log_probs.sum() + if self.reduction == "mean": + loss /= input.size(0) + return loss diff --git a/relik/retriever/pytorch_modules/model.py b/relik/retriever/pytorch_modules/model.py new file mode 100644 index 0000000000000000000000000000000000000000..f02aedd5b43cbd789b87ae4afd417c919b72b129 --- /dev/null +++ b/relik/retriever/pytorch_modules/model.py @@ -0,0 +1,533 @@ +import contextlib +import logging +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Callable, Dict, List, Optional, Union + +import torch +import torch.nn.functional as F +import transformers as tr + +from relik.common.log import get_console_logger, get_logger +from relik.retriever.common.model_inputs import ModelInputs +from relik.retriever.data.labels import Labels +from relik.retriever.indexers.base import BaseDocumentIndex +from relik.retriever.indexers.inmemory import InMemoryDocumentIndex +from relik.retriever.pytorch_modules import PRECISION_MAP, RetrievedSample +from relik.retriever.pytorch_modules.hf import GoldenRetrieverModel + +console_logger = get_console_logger() +logger = get_logger(__name__, level=logging.INFO) + + +@dataclass +class GoldenRetrieverOutput(tr.file_utils.ModelOutput): + """Class for model's outputs.""" + + logits: Optional[torch.FloatTensor] = None + loss: Optional[torch.FloatTensor] = None + question_encodings: Optional[torch.FloatTensor] = None + passages_encodings: Optional[torch.FloatTensor] = None + + +class GoldenRetriever(torch.nn.Module): + def __init__( + self, + question_encoder: Union[str, tr.PreTrainedModel], + loss_type: Optional[torch.nn.Module] = None, + passage_encoder: Optional[Union[str, tr.PreTrainedModel]] = None, + document_index: Optional[Union[str, BaseDocumentIndex]] = None, + question_tokenizer: Optional[Union[str, tr.PreTrainedTokenizer]] = None, + passage_tokenizer: Optional[Union[str, tr.PreTrainedTokenizer]] = None, + device: Optional[Union[str, torch.device]] = None, + precision: Optional[Union[str, int]] = None, + index_precision: Optional[Union[str, int]] = 32, + index_device: Optional[Union[str, torch.device]] = "cpu", + *args, + **kwargs, + ): + super().__init__() + + self.passage_encoder_is_question_encoder = False + # question encoder model + if isinstance(question_encoder, str): + question_encoder = GoldenRetrieverModel.from_pretrained( + question_encoder, **kwargs + ) + self.question_encoder = question_encoder + if passage_encoder is None: + # if no passage encoder is provided, + # share the weights of the question encoder + passage_encoder = question_encoder + # keep track of the fact that the passage encoder is the same as the question encoder + self.passage_encoder_is_question_encoder = True + if 
isinstance(passage_encoder, str): + passage_encoder = GoldenRetrieverModel.from_pretrained( + passage_encoder, **kwargs + ) + # passage encoder model + self.passage_encoder = passage_encoder + + # loss function + self.loss_type = loss_type + + # indexer stuff + if document_index is None: + # if no indexer is provided, create a new one + document_index = InMemoryDocumentIndex( + device=index_device, precision=index_precision, **kwargs + ) + if isinstance(document_index, str): + document_index = BaseDocumentIndex.from_pretrained( + document_index, device=index_device, precision=index_precision, **kwargs + ) + self.document_index = document_index + + # lazy load the tokenizer for inference + self._question_tokenizer = question_tokenizer + self._passage_tokenizer = passage_tokenizer + + # move the model to the device + self.to(device or torch.device("cpu")) + + # set the precision + self.precision = precision + + def forward( + self, + questions: Optional[Dict[str, torch.Tensor]] = None, + passages: Optional[Dict[str, torch.Tensor]] = None, + labels: Optional[torch.Tensor] = None, + question_encodings: Optional[torch.Tensor] = None, + passages_encodings: Optional[torch.Tensor] = None, + passages_per_question: Optional[List[int]] = None, + return_loss: bool = False, + return_encodings: bool = False, + *args, + **kwargs, + ) -> GoldenRetrieverOutput: + """ + Forward pass of the model. + + Args: + questions (`Dict[str, torch.Tensor]`): + The questions to encode. + passages (`Dict[str, torch.Tensor]`): + The passages to encode. + labels (`torch.Tensor`): + The labels of the sentences. + return_loss (`bool`): + Whether to compute the predictions. + question_encodings (`torch.Tensor`): + The encodings of the questions. + passages_encodings (`torch.Tensor`): + The encodings of the passages. + passages_per_question (`List[int]`): + The number of passages per question. + return_loss (`bool`): + Whether to compute the loss. + return_encodings (`bool`): + Whether to return the encodings. + + Returns: + obj:`torch.Tensor`: The outputs of the model. 
+ """ + if questions is None and question_encodings is None: + raise ValueError( + "Either `questions` or `question_encodings` must be provided" + ) + if passages is None and passages_encodings is None: + raise ValueError( + "Either `passages` or `passages_encodings` must be provided" + ) + + if question_encodings is None: + question_encodings = self.question_encoder(**questions).pooler_output + if passages_encodings is None: + passages_encodings = self.passage_encoder(**passages).pooler_output + + if passages_per_question is not None: + # multiply each question encoding with a passages_per_question encodings + concatenated_passages = torch.stack( + torch.split(passages_encodings, passages_per_question) + ).transpose(1, 2) + if isinstance(self.loss_type, torch.nn.BCEWithLogitsLoss): + # normalize the encodings for cosine similarity + concatenated_passages = F.normalize(concatenated_passages, p=2, dim=2) + question_encodings = F.normalize(question_encodings, p=2, dim=1) + logits = torch.bmm( + question_encodings.unsqueeze(1), concatenated_passages + ).view(question_encodings.shape[0], -1) + else: + if isinstance(self.loss_type, torch.nn.BCEWithLogitsLoss): + # normalize the encodings for cosine similarity + question_encodings = F.normalize(question_encodings, p=2, dim=1) + passages_encodings = F.normalize(passages_encodings, p=2, dim=1) + + logits = torch.matmul(question_encodings, passages_encodings.T) + + output = dict(logits=logits) + + if return_loss and labels is not None: + if self.loss_type is None: + raise ValueError( + "If `return_loss` is set to `True`, `loss_type` must be provided" + ) + if isinstance(self.loss_type, torch.nn.NLLLoss): + labels = labels.argmax(dim=1) + logits = F.log_softmax(logits, dim=1) + if len(question_encodings.size()) > 1: + logits = logits.view(question_encodings.size(0), -1) + + output["loss"] = self.loss_type(logits, labels) + + if return_encodings: + output["question_encodings"] = question_encodings + output["passages_encodings"] = passages_encodings + + return GoldenRetrieverOutput(**output) + + @torch.no_grad() + @torch.inference_mode() + def index( + self, + batch_size: int = 32, + num_workers: int = 4, + max_length: Optional[int] = None, + collate_fn: Optional[Callable] = None, + force_reindex: bool = False, + compute_on_cpu: bool = False, + precision: Optional[Union[str, int]] = None, + ): + """ + Index the passages for later retrieval. + + Args: + batch_size (`int`): + The batch size to use for the indexing. + num_workers (`int`): + The number of workers to use for the indexing. + max_length (`Optional[int]`): + The maximum length of the passages. + collate_fn (`Callable`): + The collate function to use for the indexing. + force_reindex (`bool`): + Whether to force reindexing even if the passages are already indexed. + compute_on_cpu (`bool`): + Whether to move the index to the CPU after the indexing. + precision (`Optional[Union[str, int]]`): + The precision to use for the model. + """ + if self.document_index is None: + raise ValueError( + "The retriever must be initialized with an indexer to index " + "the passages within the retriever." 
+ ) + return self.document_index.index( + retriever=self, + batch_size=batch_size, + num_workers=num_workers, + max_length=max_length, + collate_fn=collate_fn, + encoder_precision=precision or self.precision, + compute_on_cpu=compute_on_cpu, + force_reindex=force_reindex, + ) + + @torch.no_grad() + @torch.inference_mode() + def retrieve( + self, + text: Optional[Union[str, List[str]]] = None, + text_pair: Optional[Union[str, List[str]]] = None, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + k: Optional[int] = None, + max_length: Optional[int] = None, + precision: Optional[Union[str, int]] = None, + ) -> List[List[RetrievedSample]]: + """ + Retrieve the passages for the questions. + + Args: + text (`Optional[Union[str, List[str]]]`): + The questions to retrieve the passages for. + text_pair (`Optional[Union[str, List[str]]]`): + The questions to retrieve the passages for. + input_ids (`torch.Tensor`): + The input ids of the questions. + attention_mask (`torch.Tensor`): + The attention mask of the questions. + token_type_ids (`torch.Tensor`): + The token type ids of the questions. + k (`int`): + The number of top passages to retrieve. + max_length (`Optional[int]`): + The maximum length of the questions. + precision (`Optional[Union[str, int]]`): + The precision to use for the model. + + Returns: + `List[List[RetrievedSample]]`: The retrieved passages and their indices. + """ + if self.document_index is None: + raise ValueError( + "The indexer must be indexed before it can be used within the retriever." + ) + if text is None and input_ids is None: + raise ValueError( + "Either `text` or `input_ids` must be provided to retrieve the passages." + ) + + if text is not None: + if isinstance(text, str): + text = [text] + if text_pair is not None and isinstance(text_pair, str): + text_pair = [text_pair] + tokenizer = self.question_tokenizer + model_inputs = ModelInputs( + tokenizer( + text, + text_pair=text_pair, + padding=True, + return_tensors="pt", + truncation=True, + max_length=max_length or tokenizer.model_max_length, + ) + ) + else: + model_inputs = ModelInputs(dict(input_ids=input_ids)) + if attention_mask is not None: + model_inputs["attention_mask"] = attention_mask + if token_type_ids is not None: + model_inputs["token_type_ids"] = token_type_ids + + model_inputs.to(self.device) + + # fucking autocast only wants pure strings like 'cpu' or 'cuda' + # we need to convert the model device to that + device_type_for_autocast = str(self.device).split(":")[0] + # autocast doesn't work with CPU and stuff different from bfloat16 + autocast_pssg_mngr = ( + contextlib.nullcontext() + if device_type_for_autocast == "cpu" + else ( + torch.autocast( + device_type=device_type_for_autocast, + dtype=PRECISION_MAP[precision], + ) + ) + ) + with autocast_pssg_mngr: + question_encodings = self.question_encoder(**model_inputs).pooler_output + + # TODO: fix if encoder and index are on different device + return self.document_index.search(question_encodings, k) + + def get_index_from_passage(self, passage: str) -> int: + """ + Get the index of the passage. + + Args: + passage (`str`): + The passage to get the index for. + + Returns: + `int`: The index of the passage. + """ + if self.document_index is None: + raise ValueError( + "The passages must be indexed before they can be retrieved." 
+ ) + return self.document_index.get_index_from_passage(passage) + + def get_passage_from_index(self, index: int) -> str: + """ + Get the passage from the index. + + Args: + index (`int`): + The index of the passage. + + Returns: + `str`: The passage. + """ + if self.document_index is None: + raise ValueError( + "The passages must be indexed before they can be retrieved." + ) + return self.document_index.get_passage_from_index(index) + + def get_vector_from_index(self, index: int) -> torch.Tensor: + """ + Get the passage vector from the index. + + Args: + index (`int`): + The index of the passage. + + Returns: + `torch.Tensor`: The passage vector. + """ + if self.document_index is None: + raise ValueError( + "The passages must be indexed before they can be retrieved." + ) + return self.document_index.get_embeddings_from_index(index) + + def get_vector_from_passage(self, passage: str) -> torch.Tensor: + """ + Get the passage vector from the passage. + + Args: + passage (`str`): + The passage. + + Returns: + `torch.Tensor`: The passage vector. + """ + if self.document_index is None: + raise ValueError( + "The passages must be indexed before they can be retrieved." + ) + return self.document_index.get_embeddings_from_passage(passage) + + @property + def passage_embeddings(self) -> torch.Tensor: + """ + The passage embeddings. + """ + return self._passage_embeddings + + @property + def passage_index(self) -> Labels: + """ + The passage index. + """ + return self._passage_index + + @property + def device(self) -> torch.device: + """ + The device of the model. + """ + return next(self.parameters()).device + + @property + def question_tokenizer(self) -> tr.PreTrainedTokenizer: + """ + The question tokenizer. + """ + if self._question_tokenizer: + return self._question_tokenizer + + if ( + self.question_encoder.config.name_or_path + == self.question_encoder.config.name_or_path + ): + if not self._question_tokenizer: + self._question_tokenizer = tr.AutoTokenizer.from_pretrained( + self.question_encoder.config.name_or_path + ) + self._passage_tokenizer = self._question_tokenizer + return self._question_tokenizer + + if not self._question_tokenizer: + self._question_tokenizer = tr.AutoTokenizer.from_pretrained( + self.question_encoder.config.name_or_path + ) + return self._question_tokenizer + + @property + def passage_tokenizer(self) -> tr.PreTrainedTokenizer: + """ + The passage tokenizer. + """ + if self._passage_tokenizer: + return self._passage_tokenizer + + if ( + self.question_encoder.config.name_or_path + == self.passage_encoder.config.name_or_path + ): + if not self._question_tokenizer: + self._question_tokenizer = tr.AutoTokenizer.from_pretrained( + self.question_encoder.config.name_or_path + ) + self._passage_tokenizer = self._question_tokenizer + return self._passage_tokenizer + + if not self._passage_tokenizer: + self._passage_tokenizer = tr.AutoTokenizer.from_pretrained( + self.passage_encoder.config.name_or_path + ) + return self._passage_tokenizer + + def save_pretrained( + self, + output_dir: Union[str, os.PathLike], + question_encoder_name: Optional[str] = None, + passage_encoder_name: Optional[str] = None, + document_index_name: Optional[str] = None, + push_to_hub: bool = False, + **kwargs, + ): + """ + Save the retriever to a directory. + + Args: + output_dir (`str`): + The directory to save the retriever to. + question_encoder_name (`Optional[str]`): + The name of the question encoder. + passage_encoder_name (`Optional[str]`): + The name of the passage encoder. 
+ document_index_name (`Optional[str]`): + The name of the document index. + push_to_hub (`bool`): + Whether to push the model to the hub. + """ + + # create the output directory + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + logger.info(f"Saving retriever to {output_dir}") + + question_encoder_name = question_encoder_name or "question_encoder" + passage_encoder_name = passage_encoder_name or "passage_encoder" + document_index_name = document_index_name or "document_index" + + logger.info( + f"Saving question encoder state to {output_dir / question_encoder_name}" + ) + # self.question_encoder.config._name_or_path = question_encoder_name + self.question_encoder.register_for_auto_class() + self.question_encoder.save_pretrained( + output_dir / question_encoder_name, push_to_hub=push_to_hub, **kwargs + ) + self.question_tokenizer.save_pretrained( + output_dir / question_encoder_name, push_to_hub=push_to_hub, **kwargs + ) + if not self.passage_encoder_is_question_encoder: + logger.info( + f"Saving passage encoder state to {output_dir / passage_encoder_name}" + ) + # self.passage_encoder.config._name_or_path = passage_encoder_name + self.passage_encoder.register_for_auto_class() + self.passage_encoder.save_pretrained( + output_dir / passage_encoder_name, push_to_hub=push_to_hub, **kwargs + ) + self.passage_tokenizer.save_pretrained( + output_dir / passage_encoder_name, push_to_hub=push_to_hub, **kwargs + ) + + if self.document_index is not None: + # save the indexer + self.document_index.save_pretrained( + output_dir / document_index_name, push_to_hub=push_to_hub, **kwargs + ) + + logger.info("Saving retriever to disk done.") diff --git a/relik/retriever/pytorch_modules/optim.py b/relik/retriever/pytorch_modules/optim.py new file mode 100644 index 0000000000000000000000000000000000000000..e815acf44948a4835d7b093c74e6ca8cf8539dd1 --- /dev/null +++ b/relik/retriever/pytorch_modules/optim.py @@ -0,0 +1,213 @@ +import math + +import torch +from torch.optim import Optimizer + + +class RAdamW(Optimizer): + r"""Implements RAdamW algorithm. 
+ + RAdam from `On the Variance of the Adaptive Learning Rate and Beyond + `_ + + * `Adam: A Method for Stochastic Optimization + `_ + * `Decoupled Weight Decay Regularization + `_ + * `On the Convergence of Adam and Beyond + `_ + * `On the Variance of the Adaptive Learning Rate and Beyond + `_ + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + """ + + def __init__( + self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2 + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + super(RAdamW, self).__init__(params, defaults) + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + # Perform optimization step + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError( + "Adam does not support sparse gradients, please consider SparseAdam instead" + ) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(p.data) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] + eps = group["eps"] + lr = group["lr"] + if "rho_inf" not in group: + group["rho_inf"] = 2 / (1 - beta2) - 1 + rho_inf = group["rho_inf"] + + state["step"] += 1 + t = state["step"] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + rho_t = rho_inf - ((2 * t * (beta2**t)) / (1 - beta2**t)) + + # Perform stepweight decay + p.data.mul_(1 - lr * group["weight_decay"]) + + if rho_t >= 5: + var = exp_avg_sq.sqrt().add_(eps) + r = math.sqrt( + (1 - beta2**t) + * ((rho_t - 4) * (rho_t - 2) * rho_inf) + / ((rho_inf - 4) * (rho_inf - 2) * rho_t) + ) + + p.data.addcdiv_(exp_avg, var, value=-lr * r / (1 - beta1**t)) + else: + p.data.add_(exp_avg, alpha=-lr / (1 - beta1**t)) + + return loss + + +# from typing import List +# import collections + +# import torch +# import transformers +# from classy.optim.factories import Factory +# from transformers import AdamW + + +# class ElectraOptimizer(Factory): +# def __init__( +# self, +# lr: float, +# warmup_steps: int, +# total_steps: int, +# weight_decay: float, +# lr_decay: float, +# no_decay_params: List[str], +# ): +# self.lr = lr +# self.warmup_steps = warmup_steps +# 
self.total_steps = total_steps +# self.weight_decay = weight_decay +# self.lr_decay = lr_decay +# self.no_decay_params = no_decay_params + +# def group_layers(self, module) -> dict: +# grouped_layers = collections.defaultdict(list) +# module_named_parameters = list(module.named_parameters()) +# for ln, lp in module_named_parameters: +# if "embeddings" in ln: +# grouped_layers["embeddings"].append((ln, lp)) +# elif "encoder.layer" in ln: +# layer_num = ln.replace("transformer_model.encoder.layer.", "") +# layer_num = layer_num[0 : layer_num.index(".")] +# grouped_layers[layer_num].append((ln, lp)) +# else: +# grouped_layers["head"].append((ln, lp)) + +# depth = len(grouped_layers) - 1 +# final_dict = dict() +# for key, value in grouped_layers.items(): +# if key == "head": +# final_dict[0] = value +# elif key == "embeddings": +# final_dict[depth] = value +# else: +# # -1 because layer number starts from zero +# final_dict[depth - int(key) - 1] = value + +# assert len(module_named_parameters) == sum( +# len(v) for _, v in final_dict.items() +# ) + +# return final_dict + +# def group_params(self, module) -> list: +# optimizer_grouped_params = [] +# for inverse_depth, layer in self.group_layers(module).items(): +# layer_lr = self.lr * (self.lr_decay**inverse_depth) +# layer_wd_params = { +# "params": [ +# lp +# for ln, lp in layer +# if not any(nd in ln for nd in self.no_decay_params) +# ], +# "weight_decay": self.weight_decay, +# "lr": layer_lr, +# } +# layer_no_wd_params = { +# "params": [ +# lp +# for ln, lp in layer +# if any(nd in ln for nd in self.no_decay_params) +# ], +# "weight_decay": 0, +# "lr": layer_lr, +# } + +# if len(layer_wd_params) != 0: +# optimizer_grouped_params.append(layer_wd_params) +# if len(layer_no_wd_params) != 0: +# optimizer_grouped_params.append(layer_no_wd_params) + +# return optimizer_grouped_params + +# def __call__(self, module: torch.nn.Module): +# optimizer_grouped_parameters = self.group_params(module) +# optimizer = AdamW(optimizer_grouped_parameters, lr=self.lr) +# scheduler = transformers.get_linear_schedule_with_warmup( +# optimizer, self.warmup_steps, self.total_steps +# ) +# return { +# "optimizer": optimizer, +# "lr_scheduler": { +# "scheduler": scheduler, +# "interval": "step", +# "frequency": 1, +# }, +# } diff --git a/relik/retriever/pytorch_modules/scheduler.py b/relik/retriever/pytorch_modules/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..5edd2433612b27c1a91fe189e57c4e2d41c462b2 --- /dev/null +++ b/relik/retriever/pytorch_modules/scheduler.py @@ -0,0 +1,54 @@ +import torch +from torch.optim.lr_scheduler import LRScheduler + + +class LinearSchedulerWithWarmup(LRScheduler): + def __init__( + self, + optimizer: torch.optim.Optimizer, + num_warmup_steps: int, + num_training_steps: int, + last_epoch: int = -1, + verbose: bool = False, + **kwargs, + ): + self.num_warmup_steps = num_warmup_steps + self.num_training_steps = num_training_steps + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + def scheduler_fn(current_step): + if current_step < self.num_warmup_steps: + return current_step / max(1, self.num_warmup_steps) + return max( + 0.0, + float(self.num_training_steps - current_step) + / float(max(1, self.num_training_steps - self.num_warmup_steps)), + ) + + return [base_lr * scheduler_fn(self.last_epoch) for base_lr in self.base_lrs] + + +class LinearScheduler(LRScheduler): + def __init__( + self, + optimizer: torch.optim.Optimizer, + num_training_steps: int, + last_epoch: int = -1, + 
verbose: bool = False, + **kwargs, + ): + self.num_training_steps = num_training_steps + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + def scheduler_fn(current_step): + # if current_step < self.num_warmup_steps: + # return current_step / max(1, self.num_warmup_steps) + return max( + 0.0, + float(self.num_training_steps - current_step) + / float(max(1, self.num_training_steps)), + ) + + return [base_lr * scheduler_fn(self.last_epoch) for base_lr in self.base_lrs] diff --git a/relik/retriever/trainer/__init__.py b/relik/retriever/trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f1b18bf79091418217ae2bb782c3796dfa8b5b56 --- /dev/null +++ b/relik/retriever/trainer/__init__.py @@ -0,0 +1 @@ +from relik.retriever.trainer.train import RetrieverTrainer diff --git a/relik/retriever/trainer/train.py b/relik/retriever/trainer/train.py new file mode 100644 index 0000000000000000000000000000000000000000..dc55298c178e7f519cf2f724e9a0881a168c9158 --- /dev/null +++ b/relik/retriever/trainer/train.py @@ -0,0 +1,667 @@ +import os +from pathlib import Path +from typing import List, Optional, Union + +import hydra +import lightning as pl +import omegaconf +import torch +from lightning import Trainer +from lightning.pytorch.callbacks import ( + EarlyStopping, + LearningRateMonitor, + ModelCheckpoint, + ModelSummary, +) +from lightning.pytorch.loggers import WandbLogger +from omegaconf import OmegaConf +from rich.pretty import pprint + +from relik.common.log import get_console_logger +from relik.retriever.callbacks.evaluation_callbacks import ( + AvgRankingEvaluationCallback, + RecallAtKEvaluationCallback, +) +from relik.retriever.callbacks.prediction_callbacks import ( + GoldenRetrieverPredictionCallback, + NegativeAugmentationCallback, +) +from relik.retriever.callbacks.utils_callbacks import ( + FreeUpIndexerVRAMCallback, + SavePredictionsCallback, + SaveRetrieverCallback, +) +from relik.retriever.data.datasets import GoldenRetrieverDataset +from relik.retriever.indexers.base import BaseDocumentIndex +from relik.retriever.lightning_modules.pl_data_modules import ( + GoldenRetrieverPLDataModule, +) +from relik.retriever.lightning_modules.pl_modules import GoldenRetrieverPLModule +from relik.retriever.pytorch_modules.loss import MultiLabelNCELoss +from relik.retriever.pytorch_modules.model import GoldenRetriever +from relik.retriever.pytorch_modules.optim import RAdamW +from relik.retriever.pytorch_modules.scheduler import ( + LinearScheduler, + LinearSchedulerWithWarmup, +) + +logger = get_console_logger() + + +class RetrieverTrainer: + def __init__( + self, + retriever: GoldenRetriever, + train_dataset: GoldenRetrieverDataset, + val_dataset: Union[GoldenRetrieverDataset, list[GoldenRetrieverDataset]], + test_dataset: Optional[ + Union[GoldenRetrieverDataset, list[GoldenRetrieverDataset]] + ] = None, + num_workers: int = 4, + optimizer: torch.optim.Optimizer = RAdamW, + lr: float = 1e-5, + weight_decay: float = 0.01, + lr_scheduler: torch.optim.lr_scheduler.LRScheduler = LinearScheduler, + num_warmup_steps: int = 0, + loss: torch.nn.Module = MultiLabelNCELoss, + callbacks: Optional[list] = None, + accelerator: str = "auto", + devices: int = 1, + num_nodes: int = 1, + strategy: str = "auto", + accumulate_grad_batches: int = 1, + gradient_clip_val: float = 1.0, + val_check_interval: float = 1.0, + check_val_every_n_epoch: int = 1, + max_steps: Optional[int] = None, + max_epochs: Optional[int] = None, + # checkpoint_path: Optional[Union[str, 
os.PathLike]] = None, + deterministic: bool = True, + fast_dev_run: bool = False, + precision: int = 16, + reload_dataloaders_every_n_epochs: int = 1, + top_ks: Union[int, List[int]] = 100, + # early stopping parameters + early_stopping: bool = True, + early_stopping_patience: int = 10, + # wandb logger parameters + log_to_wandb: bool = True, + wandb_entity: Optional[str] = None, + wandb_experiment_name: Optional[str] = None, + wandb_project_name: Optional[str] = None, + wandb_save_dir: Optional[Union[str, os.PathLike]] = None, + wandb_log_model: bool = True, + wandb_offline_mode: bool = False, + wandb_watch: str = "all", + # checkpoint parameters + model_checkpointing: bool = True, + chekpoint_dir: Optional[Union[str, os.PathLike]] = None, + checkpoint_filename: Optional[Union[str, os.PathLike]] = None, + save_top_k: int = 1, + save_last: bool = False, + # prediction callback parameters + prediction_batch_size: int = 128, + # hard negatives callback parameters + max_hard_negatives_to_mine: int = 15, + hard_negatives_threshold: float = 0.0, + metrics_to_monitor_for_hard_negatives: Optional[str] = None, + mine_hard_negatives_with_probability: float = 1.0, + # other parameters + seed: int = 42, + float32_matmul_precision: str = "medium", + **kwargs, + ): + # put all the parameters in the class + self.retriever = retriever + # datasets + self.train_dataset = train_dataset + self.val_dataset = val_dataset + self.test_dataset = test_dataset + self.num_workers = num_workers + # trainer parameters + self.optimizer = optimizer + self.lr = lr + self.weight_decay = weight_decay + self.lr_scheduler = lr_scheduler + self.num_warmup_steps = num_warmup_steps + self.loss = loss + self.callbacks = callbacks + self.accelerator = accelerator + self.devices = devices + self.num_nodes = num_nodes + self.strategy = strategy + self.accumulate_grad_batches = accumulate_grad_batches + self.gradient_clip_val = gradient_clip_val + self.val_check_interval = val_check_interval + self.check_val_every_n_epoch = check_val_every_n_epoch + self.max_steps = max_steps + self.max_epochs = max_epochs + # self.checkpoint_path = checkpoint_path + self.deterministic = deterministic + self.fast_dev_run = fast_dev_run + self.precision = precision + self.reload_dataloaders_every_n_epochs = reload_dataloaders_every_n_epochs + self.top_ks = top_ks + # early stopping parameters + self.early_stopping = early_stopping + self.early_stopping_patience = early_stopping_patience + # wandb logger parameters + self.log_to_wandb = log_to_wandb + self.wandb_entity = wandb_entity + self.wandb_experiment_name = wandb_experiment_name + self.wandb_project_name = wandb_project_name + self.wandb_save_dir = wandb_save_dir + self.wandb_log_model = wandb_log_model + self.wandb_offline_mode = wandb_offline_mode + self.wandb_watch = wandb_watch + # checkpoint parameters + self.model_checkpointing = model_checkpointing + self.chekpoint_dir = chekpoint_dir + self.checkpoint_filename = checkpoint_filename + self.save_top_k = save_top_k + self.save_last = save_last + # prediction callback parameters + self.prediction_batch_size = prediction_batch_size + # hard negatives callback parameters + self.max_hard_negatives_to_mine = max_hard_negatives_to_mine + self.hard_negatives_threshold = hard_negatives_threshold + self.metrics_to_monitor_for_hard_negatives = ( + metrics_to_monitor_for_hard_negatives + ) + self.mine_hard_negatives_with_probability = mine_hard_negatives_with_probability + # other parameters + self.seed = seed + self.float32_matmul_precision = 
float32_matmul_precision + + if self.max_epochs is None and self.max_steps is None: + raise ValueError( + "Either `max_epochs` or `max_steps` should be specified in the trainer configuration" + ) + + if self.max_epochs is not None and self.max_steps is not None: + logger.log( + "Both `max_epochs` and `max_steps` are specified in the trainer configuration. " + "Will use `max_epochs` for the number of training steps" + ) + self.max_steps = None + + # reproducibility + pl.seed_everything(self.seed) + # set the precision of matmul operations + torch.set_float32_matmul_precision(self.float32_matmul_precision) + + # lightning data module declaration + self.lightining_datamodule = self.configure_lightning_datamodule() + + if self.max_epochs is not None: + logger.log(f"Number of training epochs: {self.max_epochs}") + self.max_steps = ( + len(self.lightining_datamodule.train_dataloader()) * self.max_epochs + ) + + # optimizer declaration + self.optimizer, self.lr_scheduler = self.configure_optimizers() + + # lightning module declaration + self.lightining_module = self.configure_lightning_module() + + # callbacks declaration + self.callbacks_store: List[pl.Callback] = self.configure_callbacks() + + logger.log("Instantiating the Trainer") + self.trainer = pl.Trainer( + accelerator=self.accelerator, + devices=self.devices, + num_nodes=self.num_nodes, + strategy=self.strategy, + accumulate_grad_batches=self.accumulate_grad_batches, + max_epochs=self.max_epochs, + max_steps=self.max_steps, + gradient_clip_val=self.gradient_clip_val, + val_check_interval=self.val_check_interval, + check_val_every_n_epoch=self.check_val_every_n_epoch, + deterministic=self.deterministic, + fast_dev_run=self.fast_dev_run, + precision=self.precision, + reload_dataloaders_every_n_epochs=self.reload_dataloaders_every_n_epochs, + callbacks=self.callbacks_store, + logger=self.wandb_logger, + ) + + def configure_lightning_datamodule(self, *args, **kwargs): + # lightning data module declaration + if isinstance(self.val_dataset, GoldenRetrieverDataset): + self.val_dataset = [self.val_dataset] + if self.test_dataset is not None and isinstance( + self.test_dataset, GoldenRetrieverDataset + ): + self.test_dataset = [self.test_dataset] + + self.lightining_datamodule = GoldenRetrieverPLDataModule( + train_dataset=self.train_dataset, + val_datasets=self.val_dataset, + test_datasets=self.test_dataset, + num_workers=self.num_workers, + *args, + **kwargs, + ) + return self.lightining_datamodule + + def configure_lightning_module(self, *args, **kwargs): + # add loss object to the retriever + if self.retriever.loss_type is None: + self.retriever.loss_type = self.loss() + + # lightning module declaration + self.lightining_module = GoldenRetrieverPLModule( + model=self.retriever, + optimizer=self.optimizer, + lr_scheduler=self.lr_scheduler, + *args, + **kwargs, + ) + + return self.lightining_module + + def configure_optimizers(self, *args, **kwargs): + # check if it is the class or the instance + if isinstance(self.optimizer, type): + self.optimizer = self.optimizer( + params=self.retriever.parameters(), + lr=self.lr, + weight_decay=self.weight_decay, + ) + else: + self.optimizer = self.optimizer + + # LR Scheduler declaration + # check if it is the class, the instance or a function + if self.lr_scheduler is not None: + if isinstance(self.lr_scheduler, type): + self.lr_scheduler = self.lr_scheduler( + optimizer=self.optimizer, + num_warmup_steps=self.num_warmup_steps, + num_training_steps=self.max_steps, + ) + + return self.optimizer, 
self.lr_scheduler + + def configure_callbacks(self, *args, **kwargs): + # callbacks declaration + self.callbacks_store = self.callbacks or [] + self.callbacks_store.append(ModelSummary(max_depth=2)) + + # metric to monitor + if isinstance(self.top_ks, int): + self.top_ks = [self.top_ks] + # order the top_ks in descending order + self.top_ks = sorted(self.top_ks, reverse=True) + # get the max top_k to monitor + self.top_k = self.top_ks[0] + self.metric_to_monitor = f"validate_recall@{self.top_k}" + self.monitor_mode = "max" + + # early stopping callback if specified + self.early_stopping_callback: Optional[EarlyStopping] = None + if self.early_stopping: + logger.log( + f"Eanbling Early Stopping, patience: {self.early_stopping_patience}" + ) + self.early_stopping_callback = EarlyStopping( + monitor=self.metric_to_monitor, + mode=self.monitor_mode, + patience=self.early_stopping_patience, + ) + self.callbacks_store.append(self.early_stopping_callback) + + # wandb logger if specified + self.wandb_logger: Optional[WandbLogger] = None + self.experiment_path: Optional[Path] = None + if self.log_to_wandb: + # define some default values for the wandb logger + if self.wandb_project_name is None: + self.wandb_project_name = "relik-retriever" + if self.wandb_save_dir is None: + self.wandb_save_dir = "./" + logger.log("Instantiating Wandb Logger") + self.wandb_logger = WandbLogger( + entity=self.wandb_entity, + project=self.wandb_project_name, + name=self.wandb_experiment_name, + save_dir=self.wandb_save_dir, + log_model=self.wandb_log_model, + mode="offline" if self.wandb_offline_mode else "online", + ) + self.wandb_logger.watch(self.lightining_module, log=self.wandb_watch) + self.experiment_path = Path(self.wandb_logger.experiment.dir) + # Store the YaML config separately into the wandb dir + # yaml_conf: str = OmegaConf.to_yaml(cfg=conf) + # (experiment_path / "hparams.yaml").write_text(yaml_conf) + # Add a Learning Rate Monitor callback to log the learning rate + self.callbacks_store.append(LearningRateMonitor(logging_interval="step")) + + # model checkpoint callback if specified + self.model_checkpoint_callback: Optional[ModelCheckpoint] = None + if self.model_checkpointing: + logger.log("Enabling Model Checkpointing") + if self.chekpoint_dir is None: + self.chekpoint_dir = ( + self.experiment_path / "checkpoints" + if self.experiment_path + else None + ) + if self.checkpoint_filename is None: + self.checkpoint_filename = ( + "checkpoint-validate_recall@" + + str(self.top_k) + + "_{validate_recall@" + + str(self.top_k) + + ":.4f}-epoch_{epoch:02d}" + ) + self.model_checkpoint_callback = ModelCheckpoint( + monitor=self.metric_to_monitor, + mode=self.monitor_mode, + verbose=True, + save_top_k=self.save_top_k, + save_last=self.save_last, + filename=self.checkpoint_filename, + dirpath=self.chekpoint_dir, + auto_insert_metric_name=False, + ) + self.callbacks_store.append(self.model_checkpoint_callback) + + # prediction callback + self.other_callbacks_for_prediction = [ + RecallAtKEvaluationCallback(k) for k in self.top_ks + ] + self.other_callbacks_for_prediction += [ + AvgRankingEvaluationCallback(k=self.top_k, verbose=True, prefix="train"), + SavePredictionsCallback(), + ] + self.prediction_callback = GoldenRetrieverPredictionCallback( + k=self.top_k, + batch_size=self.prediction_batch_size, + precision=self.precision, + other_callbacks=self.other_callbacks_for_prediction, + ) + self.callbacks_store.append(self.prediction_callback) + + # hard negative mining callback + self.hard_negatives_callback: 
Optional[NegativeAugmentationCallback] = None + if self.max_hard_negatives_to_mine > 0: + self.metrics_to_monitor = ( + self.metrics_to_monitor_for_hard_negatives + or f"validate_recall@{self.top_k}" + ) + self.hard_negatives_callback = NegativeAugmentationCallback( + k=self.top_k, + batch_size=self.prediction_batch_size, + precision=self.precision, + stages=["validate"], + metrics_to_monitor=self.metrics_to_monitor, + threshold=self.hard_negatives_threshold, + max_negatives=self.max_hard_negatives_to_mine, + add_with_probability=self.mine_hard_negatives_with_probability, + refresh_every_n_epochs=1, + other_callbacks=[ + AvgRankingEvaluationCallback( + k=self.top_k, verbose=True, prefix="train" + ) + ], + ) + self.callbacks_store.append(self.hard_negatives_callback) + + # utils callback + self.callbacks_store.extend( + [SaveRetrieverCallback(), FreeUpIndexerVRAMCallback()] + ) + return self.callbacks_store + + def train(self): + self.trainer.fit(self.lightining_module, datamodule=self.lightining_datamodule) + + def test( + self, + lightining_module: Optional[GoldenRetrieverPLModule] = None, + checkpoint_path: Optional[Union[str, os.PathLike]] = None, + lightining_datamodule: Optional[GoldenRetrieverPLDataModule] = None, + ): + if lightining_module is not None: + self.lightining_module = lightining_module + else: + if self.fast_dev_run: + best_lightining_module = self.lightining_module + else: + # load best model for testing + if checkpoint_path is not None: + best_model_path = checkpoint_path + elif self.checkpoint_path: + best_model_path = self.checkpoint_path + elif self.model_checkpoint_callback: + best_model_path = self.model_checkpoint_callback.best_model_path + else: + raise ValueError( + "Either `checkpoint_path` or `model_checkpoint_callback` should " + "be provided to the trainer" + ) + logger.log(f"Loading best model from {best_model_path}") + + try: + best_lightining_module = ( + GoldenRetrieverPLModule.load_from_checkpoint(best_model_path) + ) + except Exception as e: + logger.log(f"Failed to load the model from checkpoint: {e}") + logger.log("Using last model instead") + best_lightining_module = self.lightining_module + + lightining_datamodule = lightining_datamodule or self.lightining_datamodule + # module test + self.trainer.test(best_lightining_module, datamodule=lightining_datamodule) + + +def train(conf: omegaconf.DictConfig) -> None: + # reproducibility + pl.seed_everything(conf.train.seed) + torch.set_float32_matmul_precision(conf.train.float32_matmul_precision) + + logger.log(f"Starting training for [bold cyan]{conf.model_name}[/bold cyan] model") + if conf.train.pl_trainer.fast_dev_run: + logger.log( + f"Debug mode {conf.train.pl_trainer.fast_dev_run}. 
Forcing debugger configuration" + ) + # Debuggers don't like GPUs nor multiprocessing + # conf.train.pl_trainer.accelerator = "cpu" + conf.train.pl_trainer.devices = 1 + conf.train.pl_trainer.strategy = "auto" + conf.train.pl_trainer.precision = 32 + if "num_workers" in conf.data.datamodule: + conf.data.datamodule.num_workers = { + k: 0 for k in conf.data.datamodule.num_workers + } + # Switch wandb to offline mode to prevent online logging + conf.logging.log = None + # remove model checkpoint callback + conf.train.model_checkpoint_callback = None + + if "print_config" in conf and conf.print_config: + pprint(OmegaConf.to_container(conf), console=logger, expand_all=True) + + # data module declaration + logger.log("Instantiating the Data Module") + pl_data_module: GoldenRetrieverPLDataModule = hydra.utils.instantiate( + conf.data.datamodule, _recursive_=False + ) + # force setup to get labels initialized for the model + pl_data_module.prepare_data() + # main module declaration + pl_module: Optional[GoldenRetrieverPLModule] = None + + if not conf.train.only_test: + pl_data_module.setup("fit") + + # count the number of training steps + if ( + "max_epochs" in conf.train.pl_trainer + and conf.train.pl_trainer.max_epochs > 0 + ): + num_training_steps = ( + len(pl_data_module.train_dataloader()) + * conf.train.pl_trainer.max_epochs + ) + if "max_steps" in conf.train.pl_trainer: + logger.log( + "Both `max_epochs` and `max_steps` are specified in the trainer configuration. " + "Will use `max_epochs` for the number of training steps" + ) + conf.train.pl_trainer.max_steps = None + elif ( + "max_steps" in conf.train.pl_trainer and conf.train.pl_trainer.max_steps > 0 + ): + num_training_steps = conf.train.pl_trainer.max_steps + conf.train.pl_trainer.max_epochs = None + else: + raise ValueError( + "Either `max_epochs` or `max_steps` should be specified in the trainer configuration" + ) + logger.log(f"Expected number of training steps: {num_training_steps}") + + if "lr_scheduler" in conf.model.pl_module and conf.model.pl_module.lr_scheduler: + # set the number of warmup steps as x% of the total number of training steps + if conf.model.pl_module.lr_scheduler.num_warmup_steps is None: + if ( + "warmup_steps_ratio" in conf.model.pl_module + and conf.model.pl_module.warmup_steps_ratio is not None + ): + conf.model.pl_module.lr_scheduler.num_warmup_steps = int( + conf.model.pl_module.lr_scheduler.num_training_steps + * conf.model.pl_module.warmup_steps_ratio + ) + else: + conf.model.pl_module.lr_scheduler.num_warmup_steps = 0 + logger.log( + f"Number of warmup steps: {conf.model.pl_module.lr_scheduler.num_warmup_steps}" + ) + + logger.log("Instantiating the Model") + pl_module: GoldenRetrieverPLModule = hydra.utils.instantiate( + conf.model.pl_module, _recursive_=False + ) + if ( + "pretrain_ckpt_path" in conf.train + and conf.train.pretrain_ckpt_path is not None + ): + logger.log( + f"Loading pretrained checkpoint from {conf.train.pretrain_ckpt_path}" + ) + pl_module.load_state_dict( + torch.load(conf.train.pretrain_ckpt_path)["state_dict"], strict=False + ) + + if "compile" in conf.model.pl_module and conf.model.pl_module.compile: + try: + pl_module = torch.compile(pl_module, backend="inductor") + except Exception: + logger.log( + "Failed to compile the model, you may need to install PyTorch 2.0" + ) + + # callbacks declaration + callbacks_store = [ModelSummary(max_depth=2)] + + experiment_logger: Optional[WandbLogger] = None + experiment_path: Optional[Path] = None + if conf.logging.log: + 
logger.log("Instantiating Wandb Logger") + experiment_logger = hydra.utils.instantiate(conf.logging.wandb_arg) + if pl_module is not None: + # it may happen that the model is not instantiated if we are only testing + # in that case, we don't need to watch the model + experiment_logger.watch(pl_module, **conf.logging.watch) + experiment_path = Path(experiment_logger.experiment.dir) + # Store the YaML config separately into the wandb dir + yaml_conf: str = OmegaConf.to_yaml(cfg=conf) + (experiment_path / "hparams.yaml").write_text(yaml_conf) + # Add a Learning Rate Monitor callback to log the learning rate + callbacks_store.append(LearningRateMonitor(logging_interval="step")) + + early_stopping_callback: Optional[EarlyStopping] = None + if conf.train.early_stopping_callback is not None: + early_stopping_callback = hydra.utils.instantiate( + conf.train.early_stopping_callback + ) + callbacks_store.append(early_stopping_callback) + + model_checkpoint_callback: Optional[ModelCheckpoint] = None + if conf.train.model_checkpoint_callback is not None: + model_checkpoint_callback = hydra.utils.instantiate( + conf.train.model_checkpoint_callback, + dirpath=experiment_path / "checkpoints" if experiment_path else None, + ) + callbacks_store.append(model_checkpoint_callback) + + if "callbacks" in conf.train and conf.train.callbacks is not None: + for _, callback in conf.train.callbacks.items(): + # callback can be a list of callbacks or a single callback + if isinstance(callback, omegaconf.listconfig.ListConfig): + for cb in callback: + if cb is not None: + callbacks_store.append( + hydra.utils.instantiate(cb, _recursive_=False) + ) + else: + if callback is not None: + callbacks_store.append(hydra.utils.instantiate(callback)) + + # trainer + logger.log("Instantiating the Trainer") + trainer: Trainer = hydra.utils.instantiate( + conf.train.pl_trainer, callbacks=callbacks_store, logger=experiment_logger + ) + + if not conf.train.only_test: + # module fit + trainer.fit(pl_module, datamodule=pl_data_module) + + if conf.train.pl_trainer.fast_dev_run: + best_pl_module = pl_module + else: + # load best model for testing + if conf.train.checkpoint_path: + best_model_path = conf.evaluation.checkpoint_path + elif model_checkpoint_callback: + best_model_path = model_checkpoint_callback.best_model_path + else: + raise ValueError( + "Either `checkpoint_path` or `model_checkpoint_callback` should " + "be specified in the evaluation configuration" + ) + logger.log(f"Loading best model from {best_model_path}") + + try: + best_pl_module = GoldenRetrieverPLModule.load_from_checkpoint( + best_model_path + ) + except Exception as e: + logger.log(f"Failed to load the model from checkpoint: {e}") + logger.log("Using last model instead") + best_pl_module = pl_module + if "compile" in conf.model.pl_module and conf.model.pl_module.compile: + try: + best_pl_module = torch.compile(best_pl_module, backend="inductor") + except Exception: + logger.log( + "Failed to compile the model, you may need to install PyTorch 2.0" + ) + + # module test + trainer.test(best_pl_module, datamodule=pl_data_module) + + +@hydra.main(config_path="../../conf", config_name="default", version_base="1.3") +def main(conf: omegaconf.DictConfig): + train(conf) + + +if __name__ == "__main__": + main() diff --git a/relik/version.py b/relik/version.py new file mode 100644 index 0000000000000000000000000000000000000000..bed137800c980e0e82d7c8ccdf474053baed630f --- /dev/null +++ b/relik/version.py @@ -0,0 +1,13 @@ +import os + +_MAJOR = "0" +_MINOR = "1" +# On 
main and in a nightly release the patch should be one ahead of the last +# released build. +_PATCH = "0" +# This is mainly for nightly builds which have the suffix ".dev$DATE". See +# https://semver.org/#is-v123-a-semantic-version for the semantics. +_SUFFIX = os.environ.get("RELIK_VERSION_SUFFIX", "") + +VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR) +VERSION = "{0}.{1}.{2}{3}".format(_MAJOR, _MINOR, _PATCH, _SUFFIX) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..5933df9197cc3d656a9463e4da1a1ea5680bbf93 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,36 @@ +#------- Core dependencies ------- +torch>=2.0 +transformers[sentencepiece]>=4.34,<4.35 +rich>=13.0.0,<14.0.0 +scikit-learn +overrides + +#------- Optional dependencies ------- + +# train +lightning>=2.0,<2.1 +hydra-core>=1.3,<1.4 +hydra_colorlog +wandb>=0.15,<0.16 +datasets>=2.13,<2.15 + +# faiss +faiss-cpu==1.7.4 # needed by: faiss + +# serve +fastapi>=0.103,<0.104 # needed by: serve +uvicorn[standard]==0.23.2 # needed by: serve +gunicorn==21.2.0 # needed by: serve +ray[serve]>=2.7,<=2.8 # needed by: serve +ipa-core # needed by: serve +streamlit>=1.27,<1.28 # needed by: serve +streamlit_extras>=0.3,<0.4 # needed by: serve + +# retriever + +# reader + +# dev +pre-commit +black[d] +isort diff --git a/scripts/setup.sh b/scripts/setup.sh new file mode 100755 index 0000000000000000000000000000000000000000..36f2afd2cba191cb89c2b36ffc64d51cd2274cc5 --- /dev/null +++ b/scripts/setup.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +# setup conda +CONDA_BASE=$(conda info --base) +# check if conda is installed +if [ -z "$CONDA_BASE" ]; then + echo "Conda is not installed. Please install conda first." + exit 1 +fi +source "$CONDA_BASE"/etc/profile.d/conda.sh + +# create conda env +read -rp "Enter environment name or prefix: " ENV_NAME +read -rp "Enter python version (default 3.10): " PYTHON_VERSION +if [ -z "$PYTHON_VERSION" ]; then + PYTHON_VERSION="3.10" +fi + +# check if ENV_NAME is a full path +if [[ "$ENV_NAME" == /* ]]; then + CONDA_NEW_ARG="--prefix" +else + CONDA_NEW_ARG="--name" +fi + +conda create -y "$CONDA_NEW_ARG" "$ENV_NAME" python="$PYTHON_VERSION" +conda activate "$ENV_NAME" + +# replace placeholder env with $ENV_NAME in scripts/train.sh +# NEW_CONDA_LINE="source \$CONDA_BASE/bin/activate $ENV_NAME" +# sed -i.bak -e "s,.*bin/activate.*,$NEW_CONDA_LINE,g" scripts/train.sh + +# install torch +read -rp "Enter cuda version (e.g. '11.8', default no cuda support): " CUDA_VERSION +read -rp "Enter PyTorch version (e.g. '2.1', default latest): " PYTORCH_VERSION +if [ -n "$PYTORCH_VERSION" ]; then + PYTORCH_VERSION="=$PYTORCH_VERSION" +fi +if [ -z "$CUDA_VERSION" ]; then + conda install -y pytorch"$PYTORCH_VERSION" cpuonly -c pytorch +else + conda install -y pytorch"$PYTORCH_VERSION" pytorch-cuda="$CUDA_VERSION" -c pytorch -c nvidia +fi + +# install python requirements +pip install -e .[all] diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..37575622936c68b53ced7e8f63fb32a3fa701047 --- /dev/null +++ b/setup.py @@ -0,0 +1,128 @@ +""" +Simple check list from AllenNLP repo: https://github.com/allenai/allennlp/blob/master/setup.py +To create the package for pypi. +1. Run `make pre-release` (or `make pre-patch` for a patch release) then run `make fix-copies` to fix the index of the + documentation. +2. Run Tests for Amazon Sagemaker. The documentation is located in `./tests/sagemaker/README.md`, otherwise @philschmid. +3. 
Unpin specific versions from setup.py that use a git install. +4. Commit these changes with the message: "Release: VERSION" +5. Add a tag in git to mark the release: "git tag VERSION -m 'Adds tag VERSION for pypi' " + Push the tag to git: git push --tags origin master +6. Build both the sources and the wheel. Do not change anything in setup.py between + creating the wheel and the source distribution (obviously). + For the wheel, run: "python setup.py bdist_wheel" in the top level directory. + (this will build a wheel for the python version you use to build it). + For the sources, run: "python setup.py sdist" + You should now have a /dist directory with both .whl and .tar.gz source versions. +7. Check that everything looks correct by uploading the package to the pypi test server: + twine upload dist/* -r pypitest + (pypi suggest using twine as other methods upload files via plaintext.) + You may have to specify the repository url, use the following command then: + twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/ + Check that you can install it in a virtualenv by running: + pip install -i https://testpypi.python.org/pypi transformers +8. Upload the final version to actual pypi: + twine upload dist/* -r pypi +9. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory. +10. Run `make post-release` (or `make post-patch` for a patch release). +""" +from collections import defaultdict + +import setuptools + + +def parse_requirements_file( + path, allowed_extras: set = None, include_all_extra: bool = True +): + requirements = [] + extras = defaultdict(list) + find_links = [] + with open(path) as requirements_file: + import re + + def fix_url_dependencies(req: str) -> str: + """Pip and setuptools disagree about how URL dependencies should be handled.""" + m = re.match( + r"^(git\+)?(https|ssh)://(git@)?github\.com/([\w-]+)/(?P[\w-]+)\.git", + req, + ) + if m is None: + return req + else: + return f"{m.group('name')} @ {req}" + + for line in requirements_file: + line = line.strip() + if line.startswith("#") or len(line) <= 0: + continue + if ( + line.startswith("-f") + or line.startswith("--find-links") + or line.startswith("--index-url") + ): + find_links.append(line.split(" ", maxsplit=1)[-1].strip()) + continue + + req, *needed_by = line.split("# needed by:") + req = fix_url_dependencies(req.strip()) + if needed_by: + for extra in needed_by[0].strip().split(","): + extra = extra.strip() + if allowed_extras is not None and extra not in allowed_extras: + raise ValueError(f"invalid extra '{extra}' in {path}") + extras[extra].append(req) + if include_all_extra and req not in extras["all"]: + extras["all"].append(req) + else: + requirements.append(req) + return requirements, extras, find_links + + +allowed_extras = { + "onnx", + "onnx-gpu", + "serve", + "retriever", + "reader", + "all", + "faiss", + "dev", +} + +# Load requirements. +install_requirements, extras, find_links = parse_requirements_file( + "requirements.txt", allowed_extras=allowed_extras +) + +# version.py defines the VERSION and VERSION_SHORT variables. +# We use exec here, so we don't import allennlp whilst setting up. 
+VERSION = {} # type: ignore +with open("relik/version.py", "r") as version_file: + exec(version_file.read(), VERSION) + +with open("README.md", "r") as fh: + long_description = fh.read() + +setuptools.setup( + name="relik", + version=VERSION["VERSION"], + author="Edoardo Barba, Riccardo Orlando, Pere-Lluís Huguet Cabot", + author_email="orlandorcc@gmail.com", + description="Fast and Accurate Entity Linking and Relation Extraction on an Academic Budget", + long_description=long_description, + long_description_content_type="text/markdown", + url="https://github.com/SapienzaNLP/relik", + keywords="NLP Sapienza sapienzanlp deep learning transformer pytorch retriever entity linking relation extraction reader budget", + packages=setuptools.find_packages(), + include_package_data=True, + license="Apache", + classifiers=[ + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + ], + install_requires=install_requirements, + extras_require=extras, + python_requires=">=3.10", +) diff --git a/style.css b/style.css new file mode 100644 index 0000000000000000000000000000000000000000..31f0d182cfd9b2636d5db5cbd0e7a1339ed5d1c3 --- /dev/null +++ b/style.css @@ -0,0 +1,33 @@ +/* Sidebar */ +.eczjsme11 { + background-color: #802433; +} + +.st-emotion-cache-10oheav h2 { + color: white; +} + +.st-emotion-cache-10oheav li { + color: white; +} + +/* Main */ +a:link { + text-decoration: none; + color: white; +} + +a:visited { + text-decoration: none; + color: white; +} + +a:hover { + text-decoration: none; + color: rgba(255, 255, 255, 0.871); +} + +a:active { + text-decoration: none; + color: white; +} \ No newline at end of file
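Rough usage sketch (not part of the diff above): the retriever pieces introduced in this patch are meant to compose roughly as follows. Only the GoldenRetriever constructor, the retrieve() signature, and the RetrievedSample fields (score, index, label) shown in the diff are assumed; the checkpoint paths and the query string below are placeholders.

from relik.retriever.pytorch_modules.model import GoldenRetriever

# Placeholder paths: a question-encoder checkpoint and document index laid out
# the way GoldenRetriever.save_pretrained() writes them (hypothetical names).
retriever = GoldenRetriever(
    question_encoder="path/to/question_encoder",
    document_index="path/to/document_index",  # loaded via BaseDocumentIndex.from_pretrained
)

# retrieve() accepts a single string or a list of strings and returns one
# list of RetrievedSample per query.
retrieved = retriever.retrieve("Who wrote the Divine Comedy?", k=5)
for sample in retrieved[0]:
    print(f"{sample.score:.3f}\t{sample.index}\t{sample.label}")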