CarlosMalaga commited on
Commit
2f044c1
1 Parent(s): 3376207

Upload 201 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. relik/__init__.py +8 -0
  2. relik/common/__init__.py +0 -0
  3. relik/common/__pycache__/__init__.cpython-310.pyc +0 -0
  4. relik/common/__pycache__/log.cpython-310.pyc +0 -0
  5. relik/common/__pycache__/torch_utils.cpython-310.pyc +0 -0
  6. relik/common/__pycache__/upload.cpython-310.pyc +0 -0
  7. relik/common/__pycache__/utils.cpython-310.pyc +0 -0
  8. relik/common/log.py +174 -0
  9. relik/common/torch_utils.py +82 -0
  10. relik/common/upload.py +144 -0
  11. relik/common/utils.py +610 -0
  12. relik/inference/__init__.py +0 -0
  13. relik/inference/__pycache__/__init__.cpython-310.pyc +0 -0
  14. relik/inference/__pycache__/annotator.cpython-310.pyc +0 -0
  15. relik/inference/annotator.py +840 -0
  16. relik/inference/data/__init__.py +0 -0
  17. relik/inference/data/__pycache__/__init__.cpython-310.pyc +0 -0
  18. relik/inference/data/__pycache__/objects.cpython-310.pyc +0 -0
  19. relik/inference/data/objects.py +88 -0
  20. relik/inference/data/splitters/__init__.py +0 -0
  21. relik/inference/data/splitters/__pycache__/__init__.cpython-310.pyc +0 -0
  22. relik/inference/data/splitters/__pycache__/base_sentence_splitter.cpython-310.pyc +0 -0
  23. relik/inference/data/splitters/__pycache__/blank_sentence_splitter.cpython-310.pyc +0 -0
  24. relik/inference/data/splitters/__pycache__/spacy_sentence_splitter.cpython-310.pyc +0 -0
  25. relik/inference/data/splitters/__pycache__/window_based_splitter.cpython-310.pyc +0 -0
  26. relik/inference/data/splitters/base_sentence_splitter.py +55 -0
  27. relik/inference/data/splitters/blank_sentence_splitter.py +29 -0
  28. relik/inference/data/splitters/spacy_sentence_splitter.py +153 -0
  29. relik/inference/data/splitters/window_based_splitter.py +62 -0
  30. relik/inference/data/tokenizers/__init__.py +87 -0
  31. relik/inference/data/tokenizers/__pycache__/__init__.cpython-310.pyc +0 -0
  32. relik/inference/data/tokenizers/__pycache__/base_tokenizer.cpython-310.pyc +0 -0
  33. relik/inference/data/tokenizers/__pycache__/spacy_tokenizer.cpython-310.pyc +0 -0
  34. relik/inference/data/tokenizers/base_tokenizer.py +84 -0
  35. relik/inference/data/tokenizers/spacy_tokenizer.py +194 -0
  36. relik/inference/data/window/__init__.py +0 -0
  37. relik/inference/data/window/__pycache__/__init__.cpython-310.pyc +0 -0
  38. relik/inference/data/window/__pycache__/manager.cpython-310.pyc +0 -0
  39. relik/inference/data/window/manager.py +431 -0
  40. relik/inference/gerbil.py +269 -0
  41. relik/inference/serve/__init__.py +0 -0
  42. relik/inference/serve/backend/__init__.py +0 -0
  43. relik/inference/serve/backend/fastapi.py +122 -0
  44. relik/inference/serve/backend/ray.py +165 -0
  45. relik/inference/serve/backend/utils.py +38 -0
  46. relik/inference/serve/frontend/__init__.py +0 -0
  47. relik/inference/serve/frontend/relik_front.py +229 -0
  48. relik/inference/serve/frontend/relik_re_front.py +251 -0
  49. relik/inference/serve/frontend/style.css +33 -0
  50. relik/inference/serve/frontend/utils.py +132 -0
relik/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
# Package entry point: re-export the main user-facing class.
from relik.inference.annotator import Relik
from pathlib import Path

# Load the package version from the version.py file that sits next to this
# __init__.py, and expose it as ``relik.__version__``.
VERSION = {}  # type: ignore
with open(Path(__file__).parent / "version.py", "r") as version_file:
    exec(version_file.read(), VERSION)

__version__ = VERSION["VERSION"]
relik/common/__init__.py ADDED
File without changes
relik/common/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (180 Bytes). View file
 
relik/common/__pycache__/log.cpython-310.pyc ADDED
Binary file (4.39 kB). View file
 
relik/common/__pycache__/torch_utils.cpython-310.pyc ADDED
Binary file (1.11 kB). View file
 
relik/common/__pycache__/upload.cpython-310.pyc ADDED
Binary file (4.04 kB). View file
 
relik/common/__pycache__/utils.cpython-310.pyc ADDED
Binary file (14.8 kB). View file
 
relik/common/log.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import sys
4
+ import threading
5
+ from logging.config import dictConfig
6
+ from typing import Any, Dict, Optional
7
+
8
+ from art import text2art, tprint
9
+ from colorama import Fore, Style, init
10
+ from rich import get_console
11
+
12
+ _lock = threading.Lock()
13
+ _default_handler: Optional[logging.Handler] = None
14
+
15
+ _default_log_level = logging.WARNING
16
+
17
+ # fancy logger
18
+ _console = get_console()
19
+
20
+
21
class ColorfulFormatter(logging.Formatter):
    """
    Formatter to add coloring to log messages by log type.
    """

    # Map levelname -> ANSI color prefix. INFO is deliberately left
    # uncolored (the commented entry documents the choice).
    COLORS = {
        "WARNING": Fore.YELLOW,
        "ERROR": Fore.RED,
        "CRITICAL": Fore.RED + Style.BRIGHT,
        "DEBUG": Fore.CYAN,
        # "INFO": Fore.GREEN,
    }

    def format(self, record):
        # Inject the distributed-training rank onto the record so the
        # "colorful" format string's %(rank)d field resolves; defaults to 0
        # when LOCAL_RANK is not set (i.e. outside torch DDP launchers).
        record.rank = int(os.getenv("LOCAL_RANK", "0"))
        log_message = super().format(record)
        # Wrap the rendered message in the level's color and reset at the end.
        return self.COLORS.get(record.levelname, "") + log_message + Fore.RESET
38
+
39
+
40
# dictConfig-style logging configuration applied by ``configure_logging``.
# Two handlers: a plain "console" one on the root logger, and a
# "color_console" one (ColorfulFormatter) dedicated to the "relik" logger.
DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
    "version": 1,
    "formatters": {
        "simple": {
            "format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s",
        },
        "colorful": {
            # "()" tells dictConfig to instantiate this custom formatter class
            "()": ColorfulFormatter,
            "format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] [RANK:%(rank)d] %(message)s",
        },
    },
    "filters": {},
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "formatter": "simple",
            "filters": [],
            "stream": sys.stdout,
        },
        "color_console": {
            "class": "logging.StreamHandler",
            "formatter": "colorful",
            "filters": [],
            "stream": sys.stdout,
        },
    },
    # Root level is overridable via the LOG_LEVEL environment variable.
    "root": {"handlers": ["console"], "level": os.getenv("LOG_LEVEL", "INFO")},
    "loggers": {
        # The library's own logger: colored output, does not propagate to root.
        "relik": {
            "handlers": ["color_console"],
            "level": "DEBUG",
            "propagate": False,
        },
    },
}
75
+
76
+
77
def configure_logging(**kwargs):
    """Apply the library's default logging configuration.

    Args:
        **kwargs: top-level dictConfig keys (e.g. ``handlers``, ``loggers``)
            merged over ``DEFAULT_LOGGING_CONFIG`` before it is applied.
    """
    init()  # initialize colorama so ANSI colors also work on Windows
    # Merge into a *copy* of the defaults. The previous implementation called
    # DEFAULT_LOGGING_CONFIG.update(kwargs), mutating the module-level default
    # in place, so overrides from one call leaked into every later call.
    logger_config = {**DEFAULT_LOGGING_CONFIG, **kwargs}
    dictConfig(logger_config)
85
+
86
+
87
+ def _get_library_name() -> str:
88
+ return __name__.split(".")[0]
89
+
90
+
91
def _get_library_root_logger() -> logging.Logger:
    """Return the logger shared by the whole library (named after the package)."""
    return logging.getLogger(_get_library_name())
93
+
94
+
95
def _configure_library_root_logger() -> None:
    """Attach the default stderr handler to the library root logger, once.

    Thread-safe and idempotent: guarded by the module lock and by the
    module-level ``_default_handler`` sentinel.
    """
    global _default_handler

    with _lock:
        if _default_handler:
            # This library has already configured the library root logger.
            return
        _default_handler = logging.StreamHandler()  # Set sys.stderr as stream.
        # Make flushes pass straight through to stderr so messages appear
        # immediately.
        _default_handler.flush = sys.stderr.flush

        # Apply our default configuration to the library root logger.
        library_root_logger = _get_library_root_logger()
        library_root_logger.addHandler(_default_handler)
        library_root_logger.setLevel(_default_log_level)
        library_root_logger.propagate = False
110
+
111
+
112
def _reset_library_root_logger() -> None:
    """Undo ``_configure_library_root_logger``: detach the default handler."""
    global _default_handler

    with _lock:
        if not _default_handler:
            # Nothing was configured, nothing to undo.
            return

        library_root_logger = _get_library_root_logger()
        library_root_logger.removeHandler(_default_handler)
        # NOTSET makes the logger defer to its ancestors' levels again.
        library_root_logger.setLevel(logging.NOTSET)
        _default_handler = None
123
+
124
+
125
def set_log_level(level: int, logger: logging.Logger = None) -> None:
    """
    Set the log level.

    Args:
        level (:obj:`int`):
            Logging level.
        logger (:obj:`logging.Logger`):
            Logger to set the log level. If omitted, the library root logger
            is (lazily configured and) used.
    """
    if not logger:
        _configure_library_root_logger()
        logger = _get_library_root_logger()
    logger.setLevel(level)
138
+
139
+
140
def get_logger(
    name: Optional[str] = None,
    level: Optional[int] = None,
    formatter: Optional[str] = None,
    **kwargs,
) -> logging.Logger:
    """
    Return a logger with the specified name.

    Args:
        name: logger name; defaults to the library name.
        level: if given, applied to the library root logger via ``set_log_level``.
        formatter: formatter for the default handler. NOTE(review): annotated
            as ``str``, but a non-None value is passed straight to
            ``Handler.setFormatter``, which expects a ``logging.Formatter``
            instance — confirm what callers actually pass.
        **kwargs: forwarded to ``configure_logging`` as dictConfig overrides.
    """

    configure_logging(**kwargs)

    if name is None:
        name = _get_library_name()

    _configure_library_root_logger()

    if level is not None:
        set_log_level(level)

    # Default format for the stderr handler when no formatter is supplied.
    if formatter is None:
        formatter = logging.Formatter(
            "%(asctime)s - %(levelname)s - %(name)s - %(message)s"
        )
    _default_handler.setFormatter(formatter)

    return logging.getLogger(name)
167
+
168
+
169
def get_console_logger():
    """Return the module-level rich console used for fancy terminal output."""
    return _console
171
+
172
+
173
def print_relik_text_art(text: str = "relik", font: str = "larry3d", **kwargs):
    """Print ``text`` as ASCII art via the ``art`` package's ``tprint``."""
    tprint(text, font=font, **kwargs)
relik/common/torch_utils.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ import tempfile
3
+
4
+ import torch
5
+ import transformers as tr
6
+
7
+ from relik.common.utils import is_package_available
8
+
9
+ # check if ORT is available
10
+ if is_package_available("onnxruntime"):
11
+ from optimum.onnxruntime import (
12
+ ORTModel,
13
+ ORTModelForCustomTasks,
14
+ ORTModelForSequenceClassification,
15
+ ORTOptimizer,
16
+ )
17
+ from optimum.onnxruntime.configuration import AutoOptimizationConfig
18
+
19
+ # from relik.retriever.pytorch_modules import PRECISION_MAP
20
+
21
+
22
def get_autocast_context(
    device: str | torch.device, precision: str
) -> contextlib.AbstractContextManager:
    """Return a ``torch.autocast`` context for ``device``/``precision``,
    or a no-op context where autocast is not supported.
    """
    # torch.autocast only accepts bare device-type strings such as "cpu" or
    # "cuda", so strip any ":<index>" suffix from the model device.
    device_type_for_autocast = str(device).split(":")[0]

    # Imported inside the function (the module-level import above is
    # commented out) — presumably to avoid a circular import; TODO confirm.
    from relik.retriever.pytorch_modules import PRECISION_MAP

    # autocast doesn't work on cpu/mps with dtypes other than bfloat16,
    # so fall back to a null context in that case.
    autocast_manager = (
        contextlib.nullcontext()
        if device_type_for_autocast in ["cpu", "mps"]
        and PRECISION_MAP[precision] != torch.bfloat16
        else (
            torch.autocast(
                device_type=device_type_for_autocast,
                dtype=PRECISION_MAP[precision],
            )
        )
    )
    return autocast_manager
44
+
45
+
46
+ # def load_ort_optimized_hf_model(
47
+ # hf_model: tr.PreTrainedModel,
48
+ # provider: str = "CPUExecutionProvider",
49
+ # ort_model_type: callable = "ORTModelForCustomTasks",
50
+ # ) -> ORTModel:
51
+ # """
52
+ # Load an optimized ONNX Runtime HF model.
53
+ #
54
+ # Args:
55
+ # hf_model (`tr.PreTrainedModel`):
56
+ # The HF model to optimize.
57
+ # provider (`str`, optional):
58
+ # The ONNX Runtime provider to use. Defaults to "CPUExecutionProvider".
59
+ #
60
+ # Returns:
61
+ # `ORTModel`: The optimized HF model.
62
+ # """
63
+ # if isinstance(hf_model, ORTModel):
64
+ # return hf_model
65
+ # temp_dir = tempfile.mkdtemp()
66
+ # hf_model.save_pretrained(temp_dir)
67
+ # ort_model = ort_model_type.from_pretrained(
68
+ # temp_dir, export=True, provider=provider, use_io_binding=True
69
+ # )
70
+ # if is_package_available("onnxruntime"):
71
+ # optimizer = ORTOptimizer.from_pretrained(ort_model)
72
+ # optimization_config = AutoOptimizationConfig.O4()
73
+ # optimizer.optimize(save_dir=temp_dir, optimization_config=optimization_config)
74
+ # ort_model = ort_model_type.from_pretrained(
75
+ # temp_dir,
76
+ # export=True,
77
+ # provider=provider,
78
+ # use_io_binding=bool(provider == "CUDAExecutionProvider"),
79
+ # )
80
+ # return ort_model
81
+ # else:
82
+ # raise ValueError("onnxruntime is not installed. Please install Ray with `pip install relik[serve]`.")
relik/common/upload.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import logging
4
+ import os
5
+ import tempfile
6
+ import zipfile
7
+ from datetime import datetime
8
+ from pathlib import Path
9
+ from typing import Optional, Union
10
+
11
+ import huggingface_hub
12
+
13
+ from relik.common.log import get_logger
14
+ from relik.common.utils import SAPIENZANLP_DATE_FORMAT, get_md5
15
+
16
+ logger = get_logger(__name__, level=logging.DEBUG)
17
+
18
+
19
def create_info_file(tmpdir: Path):
    """Write an ``info.json`` (md5 of model.zip + upload date) into ``tmpdir``.

    Assumes ``tmpdir / "model.zip"`` already exists (produced by ``zip_run``).
    """
    logger.debug("Computing md5 of model.zip")
    md5 = get_md5(tmpdir / "model.zip")
    date = datetime.now().strftime(SAPIENZANLP_DATE_FORMAT)

    logger.debug("Dumping info.json file")
    with (tmpdir / "info.json").open("w") as f:
        json.dump(dict(md5=md5, upload_date=date), f, indent=2)
27
+
28
+
29
def zip_run(
    dir_path: Union[str, os.PathLike],
    tmpdir: Union[str, os.PathLike],
    zip_name: str = "model.zip",
) -> Path:
    """Zip the contents of ``dir_path`` into ``tmpdir / zip_name``.

    The directory structure is preserved inside the archive (arcnames are
    relative to ``dir_path``).

    Args:
        dir_path: directory to archive.
        tmpdir: directory in which the zip file is created.
        zip_name: name of the archive file.

    Returns:
        Path: path of the created zip file.
    """
    logger.debug(f"zipping {dir_path} to {tmpdir}")
    # creates a zip version of the provided dir_path
    run_dir = Path(dir_path)
    # bug fix: ``tmpdir`` is annotated str | PathLike but was used with the
    # ``/`` operator directly, which raises TypeError for a plain str.
    zip_path = Path(tmpdir) / zip_name

    with zipfile.ZipFile(zip_path, "w") as zip_file:
        # Fully zip the run directory maintaining its structure.
        # rglob("*") instead of "*.*" so extension-less files (e.g. LICENSE)
        # are included too; directories are skipped explicitly below.
        for file in run_dir.rglob("*"):
            if file.is_dir():
                continue

            zip_file.write(file, arcname=file.relative_to(run_dir))

    return zip_path
48
+
49
+
50
def get_logged_in_username():
    """Return the HuggingFace Hub username of the locally logged-in user.

    Raises:
        ValueError: if no HuggingFace token is stored locally.
    """
    token = huggingface_hub.HfFolder.get_token()
    if token is None:
        raise ValueError(
            "No HuggingFace token found. You need to execute `huggingface-cli login` first!"
        )
    api = huggingface_hub.HfApi()
    user = api.whoami(token=token)
    return user["name"]
59
+
60
+
61
def upload(
    model_dir: Union[str, os.PathLike],
    model_name: str,
    filenames: Optional[list[str]] = None,
    organization: Optional[str] = None,
    repo_name: Optional[str] = None,
    commit: Optional[str] = None,
    archive: bool = False,
):
    """Push a model directory to the HuggingFace Hub.

    Args:
        model_dir: local directory containing the model files.
        model_name: model id; also the default repository name.
        filenames: optional subset of files to upload (copy-mode only).
        organization: optional Hub organization to upload under.
        repo_name: optional repository name overriding ``model_name``.
        commit: optional commit message for the push.
        archive: if True, upload a single ``model.zip`` (plus ``info.json``)
            instead of the raw files.

    Raises:
        ValueError: if no HuggingFace token is stored locally.
    """
    token = huggingface_hub.HfFolder.get_token()
    if token is None:
        raise ValueError(
            "No HuggingFace token found. You need to execute `huggingface-cli login` first!"
        )

    repo_id = repo_name or model_name
    if organization is not None:
        repo_id = f"{organization}/{repo_id}"
    with tempfile.TemporaryDirectory() as tmpdir:
        api = huggingface_hub.HfApi()
        repo_url = api.create_repo(
            token=token,
            repo_id=repo_id,
            exist_ok=True,
        )
        repo = huggingface_hub.Repository(
            str(tmpdir), clone_from=repo_url, use_auth_token=token
        )

        tmp_path = Path(tmpdir)
        if archive:
            # zip the model_dir and ship a checksum/date manifest next to it
            logger.debug(f"Zipping {model_dir} to {tmp_path}")
            zip_run(model_dir, tmp_path)
            create_info_file(tmp_path)
        else:
            # if the user wants to upload a transformers model, we don't need
            # to zip it; we just need to copy the files to the tmpdir
            logger.debug(f"Copying {model_dir} to {tmpdir}")
            if filenames is not None:
                # copy only the requested files
                for filename in filenames:
                    # bug fix: the copied file must be the loop variable
                    # (the previous placeholder was broken). NOTE: shell `cp`
                    # is still fragile with spaces/special chars in paths.
                    os.system(f"cp {model_dir}/{filename} {tmpdir}")
            else:
                os.system(f"cp -r {model_dir}/* {tmpdir}")

        # this method automatically puts large files (>10MB) into git lfs
        repo.push_to_hub(commit_message=commit or "Automatic push from sapienzanlp")
109
+
110
+
111
def parse_args() -> argparse.Namespace:
    """Parse the command-line arguments for the model-upload entry point."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "model_dir", help="The directory of the model you want to upload"
    )
    parser.add_argument("model_name", help="The model you want to upload")
    parser.add_argument(
        "--organization",
        help="the name of the organization where you want to upload the model",
    )
    parser.add_argument(
        "--repo_name",
        help="Optional name to use when uploading to the HuggingFace repository",
    )
    parser.add_argument(
        "--commit", help="Commit message to use when pushing to the HuggingFace Hub"
    )
    parser.add_argument(
        "--archive",
        action="store_true",
        help="""
        Whether to compress the model directory before uploading it.
        If True, the model directory will be zipped and the zip file will be uploaded.
        If False, the model directory will be uploaded as is.""",
    )
    return parser.parse_args()
137
+
138
+
139
def main():
    """CLI entry point: parse arguments and upload the model."""
    upload(**vars(parse_args()))


if __name__ == "__main__":
    main()
relik/common/utils.py ADDED
@@ -0,0 +1,610 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib.util
2
+ import json
3
+ import logging
4
+ import os
5
+ import shutil
6
+ import tarfile
7
+ import tempfile
8
+ from functools import partial
9
+ from hashlib import sha256
10
+ from pathlib import Path
11
+ from typing import Any, BinaryIO, Dict, List, Optional, Union
12
+ from urllib.parse import urlparse
13
+ from zipfile import ZipFile, is_zipfile
14
+
15
+ import huggingface_hub
16
+ import requests
17
+ import tqdm
18
+ from filelock import FileLock
19
+ from transformers.utils.hub import cached_file as hf_cached_file
20
+
21
+ from relik.common.log import get_logger
22
+
23
+ # name constants
24
+ WEIGHTS_NAME = "weights.pt"
25
+ ONNX_WEIGHTS_NAME = "weights.onnx"
26
+ CONFIG_NAME = "config.yaml"
27
+ LABELS_NAME = "labels.json"
28
+
29
+ # SAPIENZANLP_USER_NAME = "sapienzanlp"
30
+ SAPIENZANLP_USER_NAME = "riccorl"
31
+ SAPIENZANLP_HF_MODEL_REPO_URL = "riccorl/{model_id}"
32
+ SAPIENZANLP_HF_MODEL_REPO_ARCHIVE_URL = (
33
+ f"{SAPIENZANLP_HF_MODEL_REPO_URL}/resolve/main/model.zip"
34
+ )
35
+ # path constants
36
+ HF_CACHE_DIR = Path(os.getenv("HF_HOME", Path.home() / ".cache/huggingface/hub"))
37
+ SAPIENZANLP_CACHE_DIR = os.getenv("SAPIENZANLP_CACHE_DIR", HF_CACHE_DIR)
38
+ SAPIENZANLP_DATE_FORMAT = "%Y-%m-%d %H-%M-%S"
39
+
40
+ logger = get_logger(__name__)
41
+
42
+
def sapienzanlp_model_urls(model_id: str) -> str:
    """
    Returns the URL for a possible SapienzaNLP valid model.

    Args:
        model_id (:obj:`str`):
            A SapienzaNLP model id.

    Returns:
        :obj:`str`: The url for the model id.
    """
    # An id that already carries a user namespace is returned untouched;
    # otherwise it is slotted into the default repository template.
    already_namespaced = "/" in model_id
    if already_namespaced:
        return model_id
    return SAPIENZANLP_HF_MODEL_REPO_URL.format(model_id=model_id)
58
+
59
def is_package_available(package_name: str) -> bool:
    """
    Check if a package is available.

    Args:
        package_name (`str`): The name of the package to check.
    """
    spec = importlib.util.find_spec(package_name)
    return spec is not None
68
+
69
def load_json(path: Union[str, Path]) -> Any:
    """
    Read and deserialize the JSON document stored at ``path``.

    Args:
        path (`Union[str, Path]`): Location of the JSON file.

    Returns:
        `Any`: The parsed content.
    """
    with open(path, encoding="utf8") as fin:
        parsed = json.load(fin)
    return parsed
82
+
83
def dump_json(document: Any, path: Union[str, Path], indent: Optional[int] = None):
    """
    Serialize ``document`` as JSON and write it to ``path`` (UTF-8).

    Args:
        document (`Any`): The document to dump.
        path (`Union[str, Path]`): Destination file.
        indent (`Optional[int]`): Indentation level, compact if ``None``.
    """
    with open(path, "w", encoding="utf8") as sink:
        json.dump(document, sink, indent=indent)
96
+
97
def get_md5(path: Path):
    """
    Get the MD5 value of a path.
    """
    import hashlib

    hasher = hashlib.md5()
    with path.open("rb") as fin:
        hasher.update(fin.read())
    return hasher.hexdigest()
107
+
108
def file_exists(path: Union[str, os.PathLike]) -> bool:
    """
    Check if the file at :obj:`path` exists.

    Args:
        path (:obj:`str`, :obj:`os.PathLike`):
            Path to check.

    Returns:
        :obj:`bool`: :obj:`True` if the file exists.
    """
    candidate = Path(path)
    return candidate.exists()
121
+
122
def dir_exists(path: Union[str, os.PathLike]) -> bool:
    """
    Check if the directory at :obj:`path` exists.

    Args:
        path (:obj:`str`, :obj:`os.PathLike`):
            Path to check.

    Returns:
        :obj:`bool`: :obj:`True` if the directory exists.
    """
    candidate = Path(path)
    return candidate.is_dir()
135
+
136
def is_remote_url(url_or_filename: Union[str, Path]):
    """
    Returns :obj:`True` if the input path is an url.

    Args:
        url_or_filename (:obj:`str`, :obj:`Path`):
            path to check.

    Returns:
        :obj:`bool`: :obj:`True` if the input path is an url, :obj:`False` otherwise.
    """
    # Normalize Path inputs to str before parsing; an http(s) scheme marks a
    # remote resource.
    parsed = urlparse(str(url_or_filename))
    return parsed.scheme in ("http", "https")
153
+
154
+
155
+ def url_to_filename(resource: str, etag: str = None) -> str:
156
+ """
157
+ Convert a `resource` into a hashed filename in a repeatable way.
158
+ If `etag` is specified, append its hash to the resources's, delimited
159
+ by a period.
160
+ """
161
+ resource_bytes = resource.encode("utf-8")
162
+ resource_hash = sha256(resource_bytes)
163
+ filename = resource_hash.hexdigest()
164
+
165
+ if etag:
166
+ etag_bytes = etag.encode("utf-8")
167
+ etag_hash = sha256(etag_bytes)
168
+ filename += "." + etag_hash.hexdigest()
169
+
170
+ return filename
171
def download_resource(
    url: str,
    temp_file: BinaryIO,
    headers=None,
):
    """Stream ``url`` into ``temp_file``, showing a progress bar.

    Args:
        url: resource to download.
        temp_file: open binary file object the content is written to.
        headers: optional HTTP headers (e.g. an Authorization bearer token).

    Raises:
        requests.HTTPError: if the server responds with an error status.
    """

    if headers is None:
        headers = {}

    r = requests.get(url, stream=True, headers=headers)
    r.raise_for_status()
    content_length = r.headers.get("Content-Length")
    total = int(content_length) if content_length is not None else None
    # bug fix: the file imports the *module* (``import tqdm``), so calling
    # ``tqdm(...)`` raised "TypeError: 'module' object is not callable";
    # the progress-bar class is ``tqdm.tqdm``.
    progress = tqdm.tqdm(
        unit="B",
        unit_scale=True,
        total=total,
        desc="Downloading",
        disable=logger.level in [logging.NOTSET],
    )
    for chunk in r.iter_content(chunk_size=1024):
        if chunk:  # filter out keep-alive new chunks
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()
202
def download_and_cache(
    url: Union[str, Path],
    cache_dir: Union[str, Path] = None,
    force_download: bool = False,
):
    """Download ``url`` into ``cache_dir`` and return the local cached path.

    A previously cached copy is reused unless ``force_download`` is set.
    Private HF resources are retried with the locally stored token; network
    failures fall back to whatever is already in the cache.

    Args:
        url: remote resource to fetch.
        cache_dir: cache directory (defaults to ``SAPIENZANLP_CACHE_DIR``).
        force_download: re-download even if a cached copy exists.

    Returns:
        The path of the cached file.
    """
    if cache_dir is None:
        cache_dir = SAPIENZANLP_CACHE_DIR
    if isinstance(url, Path):
        url = str(url)

    # check if cache dir exists
    Path(cache_dir).mkdir(parents=True, exist_ok=True)

    # check if file is private
    headers = {}
    try:
        r = requests.head(url, allow_redirects=False, timeout=10)
        r.raise_for_status()
    except requests.exceptions.HTTPError:
        # 401 means a private resource: retry with the stored HF token.
        # Other HTTP errors are deliberately ignored here; the authenticated
        # HEAD below will surface them.
        if r.status_code == 401:
            hf_token = huggingface_hub.HfFolder.get_token()
            if hf_token is None:
                raise ValueError(
                    "You need to login to HuggingFace to download this model "
                    "(use the `huggingface-cli login` command)"
                )
            headers["Authorization"] = f"Bearer {hf_token}"

    etag = None
    try:
        r = requests.head(url, allow_redirects=True, timeout=10, headers=headers)
        r.raise_for_status()
        etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
        # We favor a custom header indicating the etag of the linked resource,
        # and we fallback to the regular etag header.
        # If we don't have any of those, raise an error.
        if etag is None:
            raise OSError(
                "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
            )
        # In case of a redirect,
        # save an extra redirect on the request.get call,
        # and ensure we download the exact atomic version even if it changed
        # between the HEAD and the GET (unlikely, but hey).
        if 300 <= r.status_code <= 399:
            url = r.headers["Location"]
    except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
        # Actually raise for those subclasses of ConnectionError
        raise
    except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
        # Otherwise, our Internet connection is down.
        # etag is None
        pass

    # get filename from the url
    filename = url_to_filename(url, etag)
    # get cache path to put the file
    # bug fix: ``cache_dir`` may be a plain str (SAPIENZANLP_CACHE_DIR is a
    # str when its env var is set); ``cache_dir / filename`` raised TypeError
    # in that case.
    cache_path = Path(cache_dir) / filename

    # the file is already here, return it
    if file_exists(cache_path) and not force_download:
        logger.info(
            f"{url} found in cache, set `force_download=True` to force the download"
        )
        return cache_path

    cache_path = str(cache_path)
    # Prevent parallel downloads of the same file with a lock.
    lock_path = cache_path + ".lock"
    with FileLock(lock_path):
        # If the download just completed while the lock was activated.
        if file_exists(cache_path) and not force_download:
            # Even if returning early like here, the lock will be released.
            return cache_path

        temp_file_manager = partial(
            tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False
        )

        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise, you get corrupt cache entries if the download gets interrupted.
        with temp_file_manager() as temp_file:
            logger.info(
                f"{url} not found in cache or `force_download` set to `True`, downloading to {temp_file.name}"
            )
            download_resource(url, temp_file, headers)

        logger.info(f"storing {url} in cache at {cache_path}")
        os.replace(temp_file.name, cache_path)

        # NamedTemporaryFile creates a file with hardwired 0600 perms (ignoring umask), so fixing it.
        umask = os.umask(0o666)
        os.umask(umask)
        os.chmod(cache_path, 0o666 & ~umask)

        logger.info(f"creating metadata file for {cache_path}")
        meta = {"url": url}  # , "etag": etag}
        meta_path = cache_path + ".json"
        with open(meta_path, "w") as meta_file:
            json.dump(meta, meta_file)

    return cache_path
+
def download_from_hf(
    path_or_repo_id: Union[str, Path],
    filenames: List[str],
    cache_dir: Union[str, Path] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    subfolder: str = "",
    repo_type: str = "model",
):
    """Fetch ``filenames`` from a HF Hub repo via the transformers cache and
    return the local folder they landed in.

    NOTE(review): ``repo_type`` is accepted but never forwarded to
    ``hf_cached_file`` — confirm whether that is intentional.
    """
    if isinstance(path_or_repo_id, Path):
        path_or_repo_id = str(path_or_repo_id)

    downloaded_paths = []
    for filename in filenames:
        downloaded_path = hf_cached_file(
            path_or_repo_id,
            filename,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            use_auth_token=use_auth_token,
            revision=revision,
            local_files_only=local_files_only,
            subfolder=subfolder,
        )
        downloaded_paths.append(downloaded_path)

    # we want the folder where the files are downloaded
    # the best guess is the parent folder of the first file
    probably_the_folder = Path(downloaded_paths[0]).parent
    return probably_the_folder
343
def model_name_or_path_resolver(model_name_or_dir: Union[str, os.PathLike]) -> str:
    """
    Resolve a model name or directory to a model archive name or directory.

    Args:
        model_name_or_dir (:obj:`str` or :obj:`os.PathLike`):
            A model name or directory.

    Returns:
        :obj:`str`: The model archive name or directory.
    """
    # A URL is handed back verbatim; the caller downloads it.
    if is_remote_url(model_name_or_dir):
        return model_name_or_dir

    # A local directory or archive file is also returned unchanged.
    candidate = Path(model_name_or_dir)
    if candidate.is_dir() or candidate.is_file():
        return model_name_or_dir

    # Otherwise assume a sapienzanlp model id and build its repository URL.
    return sapienzanlp_model_urls(model_name_or_dir)
372
+
373
+
374
def from_cache(
    url_or_filename: Union[str, Path],
    cache_dir: Union[str, Path] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    subfolder: str = "",
    filenames: Optional[List[str]] = None,
) -> Path:
    """
    Given something that could be either a local path or a URL (or a SapienzaNLP model id),
    determine which one and return a path to the corresponding file.

    Args:
        url_or_filename (:obj:`str` or :obj:`Path`):
            A path to a local file or a URL (or a SapienzaNLP model id).
        cache_dir (:obj:`str` or :obj:`Path`, `optional`):
            Path to a directory in which a downloaded file will be cached.
        force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to re-download the file even if it already exists.
        resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to delete incompletely received files. Attempts to resume the download if such a file
            exists.
        proxies (:obj:`Dict[str, str]`, `optional`):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
        use_auth_token (:obj:`Union[bool, str]`, `optional`):
            Optional string or boolean to use as Bearer token for remote files. If :obj:`True`, will get token from
            :obj:`~transformers.hf_api.HfApi`. If :obj:`str`, will use that string as token.
        revision (:obj:`str`, `optional`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
            identifier allowed by git.
        local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to raise an error if the file to be downloaded is local.
        subfolder (:obj:`str`, `optional`):
            In case the relevant file is in a subfolder of the URL, specify it here.
        filenames (:obj:`List[str]`, `optional`):
            List of filenames to look for in the directory structure.

    Returns:
        :obj:`Path`: Path to the cached file, or to the extracted directory when
        the resolved file is a zip/tar archive.
    """

    # resolve SapienzaNLP model ids to an actual path or URL first
    url_or_filename = model_name_or_path_resolver(url_or_filename)

    if cache_dir is None:
        cache_dir = SAPIENZANLP_CACHE_DIR

    if file_exists(url_or_filename):
        # local path: use it as-is
        logger.info(f"{url_or_filename} is a local path or file")
        output_path = url_or_filename
    elif is_remote_url(url_or_filename):
        # URL, so get it from the cache (downloading if necessary)
        output_path = download_and_cache(
            url_or_filename,
            cache_dir=cache_dir,
            force_download=force_download,
        )
    else:
        # neither a local path nor a URL: treat it as a HuggingFace Hub repo id
        if filenames is None:
            filenames = [WEIGHTS_NAME, CONFIG_NAME, LABELS_NAME]
        output_path = download_from_hf(
            url_or_filename,
            filenames,
            cache_dir,
            force_download,
            resume_download,
            proxies,
            use_auth_token,
            revision,
            local_files_only,
            subfolder,
        )

    # directories and non-archive files need no extraction
    if dir_exists(output_path) or (
        not is_zipfile(output_path) and not tarfile.is_tarfile(output_path)
    ):
        return Path(output_path)

    # Path where we extract compressed archives
    # for now it will extract it in the same folder
    # maybe implement extraction in the sapienzanlp folder
    # when using local archive path?
    logger.info("Extracting compressed archive")
    output_dir, output_file = os.path.split(output_path)
    output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
    output_path_extracted = os.path.join(output_dir, output_extract_dir_name)

    # already extracted, do not extract again
    if (
        os.path.isdir(output_path_extracted)
        and os.listdir(output_path_extracted)
        and not force_download
    ):
        return Path(output_path_extracted)

    # Prevent parallel extractions
    lock_path = output_path + ".lock"
    with FileLock(lock_path):
        shutil.rmtree(output_path_extracted, ignore_errors=True)
        os.makedirs(output_path_extracted)
        if is_zipfile(output_path):
            # context manager closes the archive even if extraction fails
            with ZipFile(output_path, "r") as zip_file:
                zip_file.extractall(output_path_extracted)
        elif tarfile.is_tarfile(output_path):
            # context manager closes the tar handle even if extraction fails
            with tarfile.open(output_path) as tar_file:
                tar_file.extractall(output_path_extracted)
        else:
            raise EnvironmentError(
                f"Archive format of {output_path} could not be identified"
            )

    # remove lock file, is it safe?
    os.remove(lock_path)

    return Path(output_path_extracted)
519
+
520
+
521
def is_str_a_path(maybe_path: str) -> bool:
    """
    Check if a string is a path.

    Args:
        maybe_path (`str`): The string to check.

    Returns:
        `bool`: `True` if the string is a path, `False` otherwise.
    """
    # the string may already be a valid absolute or relative path
    candidate = Path(maybe_path)
    if candidate.exists():
        return True
    # otherwise, try interpreting it relative to the current working directory
    return Path(os.path.join(os.getcwd(), maybe_path)).exists()
539
+
540
+
541
def relative_to_absolute_path(path: str) -> os.PathLike:
    """
    Convert a relative path to an absolute path.

    Args:
        path (`str`): The relative path to convert.

    Returns:
        `os.PathLike`: The absolute path.

    Raises:
        ValueError: If `path` does not exist, neither as given nor relative
            to the current working directory.
    """
    # try the path as given first (absolute, or relative to the process cwd)
    direct = Path(path)
    if direct.exists():
        return direct.absolute()
    # then try it joined onto the current working directory
    joined = Path(os.path.join(os.getcwd(), path))
    if joined.exists():
        return joined.absolute()
    # neither form exists: same error as before
    raise ValueError(f"{path} is not a path")
558
+
559
+
560
def to_config(object_to_save: Any) -> Dict[str, Any]:
    """
    Convert an object to a dictionary.

    Mappings are converted key by key, lists and tuples element by element
    (tuples become lists), and objects carrying a ``__dict__`` become a dict
    with a ``_target_`` entry naming their class plus every public
    (non-underscore) attribute. Anything else is returned unchanged.

    Returns:
        `Dict[str, Any]`: The dictionary representation of the object.
    """

    def _serialize(node):
        # mappings: recurse on every value
        if isinstance(node, dict):
            return {key: _serialize(value) for key, value in node.items()}
        # sequences: recurse on every element
        if isinstance(node, (list, tuple)):
            return [_serialize(item) for item in node]
        # generic objects with instance attributes
        if hasattr(node, "__dict__"):
            serialized = {
                "_target_": f"{node.__class__.__module__}.{node.__class__.__name__}",
            }
            for key, value in node.__dict__.items():
                if not key.startswith("_"):
                    serialized[key] = _serialize(value)
            return serialized
        # primitives and everything else stay as they are
        return node

    return _serialize(object_to_save)
592
+
593
+
594
def get_callable_from_string(callable_fn: str) -> Any:
    """
    Get a callable from a string.

    Args:
        callable_fn (`str`):
            The string representation of the callable
            (e.g. ``"os.path.join"``).

    Returns:
        `Any`: The callable.
    """
    # split "<module path>.<attribute>" at the last dot
    module_path, attribute_name = callable_fn.rsplit(".", 1)
    # import the module and pull the attribute off it
    target_module = importlib.import_module(module_path)
    return getattr(target_module, attribute_name)
relik/inference/__init__.py ADDED
File without changes
relik/inference/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (183 Bytes). View file
 
relik/inference/__pycache__/annotator.cpython-310.pyc ADDED
Binary file (22.7 kB). View file
 
relik/inference/annotator.py ADDED
@@ -0,0 +1,840 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import time
4
+ from pathlib import Path
5
+ from typing import Any, Dict, List, Optional, Union
6
+
7
+ import hydra
8
+ import torch
9
+ from omegaconf import DictConfig, OmegaConf
10
+ from pprintpp import pformat
11
+
12
+ from relik.inference.data.splitters.blank_sentence_splitter import BlankSentenceSplitter
13
+ from relik.common.log import get_logger
14
+ from relik.common.upload import get_logged_in_username, upload
15
+ from relik.common.utils import CONFIG_NAME, from_cache
16
+ from relik.inference.data.objects import (
17
+ AnnotationType,
18
+ RelikOutput,
19
+ Span,
20
+ TaskType,
21
+ Triples,
22
+ )
23
+ from relik.inference.data.splitters.base_sentence_splitter import BaseSentenceSplitter
24
+ from relik.inference.data.splitters.spacy_sentence_splitter import SpacySentenceSplitter
25
+ from relik.inference.data.splitters.window_based_splitter import WindowSentenceSplitter
26
+ from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer
27
+ from relik.inference.data.window.manager import WindowManager
28
+ from relik.reader.data.relik_reader_sample import RelikReaderSample
29
+ from relik.reader.pytorch_modules.base import RelikReaderBase
30
+ from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction
31
+ from relik.reader.pytorch_modules.triplet import RelikReaderForTripletExtraction
32
+ from relik.retriever.indexers.base import BaseDocumentIndex
33
+ from relik.retriever.indexers.document import Document
34
+ from relik.retriever.pytorch_modules import PRECISION_MAP
35
+ from relik.retriever.pytorch_modules.model import GoldenRetriever
36
+
37
# set tokenizers parallelism to False

# default TOKENIZERS_PARALLELISM to "false" unless the caller already set it
os.environ["TOKENIZERS_PARALLELISM"] = os.getenv("TOKENIZERS_PARALLELISM", "false")

# opt-in flag: when RELIK_LOG_QUERY_ON_FILE=true, annotated queries are also
# mirrored to a log file on disk
LOG_QUERY = os.getenv("RELIK_LOG_QUERY_ON_FILE", "false").lower() == "true"

logger = get_logger(__name__, level=logging.INFO)
# secondary logger used to mirror queries to `relik.log`; stays None when disabled
file_logger = None
if LOG_QUERY:
    # log file placed three directories above this module (presumably the
    # repository root — TODO confirm)
    RELIK_LOG_PATH = Path(__file__).parent.parent.parent / "relik.log"
    # file handler capturing INFO and above
    fh = logging.FileHandler(RELIK_LOG_PATH)
    fh.setLevel(logging.INFO)
    file_logger = get_logger("relik", level=logging.INFO)
    file_logger.addHandler(fh)
52
+
53
+
54
+ class Relik:
55
+ """
56
+ Relik main class. It is a wrapper around a retriever and a reader.
57
+
58
+ Args:
59
+ retriever (:obj:`GoldenRetriever`):
60
+ The retriever to use.
61
+ reader (:obj:`RelikReaderBase`):
62
+ The reader to use.
63
+ document_index (:obj:`BaseDocumentIndex`, `optional`):
64
+ The document index to use. If `None`, the retriever's document index will be used.
65
+ device (`str`, `optional`, defaults to `cpu`):
66
+ The device to use for both the retriever and the reader.
67
+ retriever_device (`str`, `optional`, defaults to `None`):
68
+ The device to use for the retriever. If `None`, the `device` argument will be used.
69
+ document_index_device (`str`, `optional`, defaults to `None`):
70
+ The device to use for the document index. If `None`, the `device` argument will be used.
71
+ reader_device (`str`, `optional`, defaults to `None`):
72
+ The device to use for the reader. If `None`, the `device` argument will be used.
73
+ precision (`int`, `str` or `torch.dtype`, `optional`, defaults to `32`):
74
+ The precision to use for both the retriever and the reader.
75
+ retriever_precision (`int`, `str` or `torch.dtype`, `optional`, defaults to `None`):
76
+ The precision to use for the retriever. If `None`, the `precision` argument will be used.
77
+ document_index_precision (`int`, `str` or `torch.dtype`, `optional`, defaults to `None`):
78
+ The precision to use for the document index. If `None`, the `precision` argument will be used.
79
+ reader_precision (`int`, `str` or `torch.dtype`, `optional`, defaults to `None`):
80
+ The precision to use for the reader. If `None`, the `precision` argument will be used.
81
+ metadata_fields (`list[str]`, `optional`, defaults to `None`):
82
+ The fields to add to the candidates for the reader.
83
+ top_k (`int`, `optional`, defaults to `None`):
84
+ The number of candidates to retrieve for each window.
85
+ window_size (`int`, `optional`, defaults to `None`):
86
+ The size of the window. If `None`, the whole text will be annotated.
87
+ window_stride (`int`, `optional`, defaults to `None`):
88
+ The stride of the window. If `None`, there will be no overlap between windows.
89
+ **kwargs:
90
+ Additional keyword arguments to pass to the retriever and the reader.
91
+ """
92
+
93
    def __init__(
        self,
        retriever: GoldenRetriever | DictConfig | Dict | None = None,
        reader: RelikReaderBase | DictConfig | None = None,
        device: str | None = None,
        retriever_device: str | None = None,
        document_index_device: str | None = None,
        reader_device: str | None = None,
        precision: int | str | torch.dtype | None = None,
        retriever_precision: int | str | torch.dtype | None = None,
        document_index_precision: int | str | torch.dtype | None = None,
        reader_precision: int | str | torch.dtype | None = None,
        task: TaskType | str = TaskType.SPAN,
        metadata_fields: list[str] | None = None,
        top_k: int | None = None,
        window_size: int | str | None = None,
        window_stride: int | None = None,
        retriever_kwargs: Dict[str, Any] | None = None,
        reader_kwargs: Dict[str, Any] | None = None,
        **kwargs,
    ) -> None:
        """
        Initialize the Relik pipeline; see the class docstring for the
        meaning of each argument.
        """
        # NOTE(review): `retriever_kwargs`, `reader_kwargs` and `**kwargs` are
        # accepted but not used anywhere in this body — confirm whether they
        # should be forwarded to the components.

        # parse task into a TaskType
        if isinstance(task, str):
            try:
                task = TaskType(task.lower())
            except ValueError:
                raise ValueError(
                    f"Task `{task}` not recognized. "
                    f"Please choose one of {list(TaskType)}."
                )
        self.task = task

        # organize devices: the generic `device` fills any component device
        # that was not explicitly provided
        if device is not None:
            if retriever_device is None:
                retriever_device = device
            if document_index_device is None:
                document_index_device = device
            if reader_device is None:
                reader_device = device

        # organize precision: same fallback scheme as for devices
        if precision is not None:
            if retriever_precision is None:
                retriever_precision = precision
            if document_index_precision is None:
                document_index_precision = precision
            if reader_precision is None:
                reader_precision = precision

        # retriever: one (optional) retriever per supported task type
        self.retriever: Dict[TaskType, GoldenRetriever] = {
            TaskType.SPAN: None,
            TaskType.TRIPLET: None,
        }

        if retriever:
            # check retriever type, it can be a GoldenRetriever, a DictConfig or a Dict
            if not isinstance(retriever, (GoldenRetriever, DictConfig, Dict)):
                raise ValueError(
                    f"`retriever` must be a `GoldenRetriever`, a `DictConfig` or "
                    f"a `Dict`, got `{type(retriever)}`."
                )

            # we need to check whether the DictConfig is a DictConfig for an instance of GoldenRetriever
            # or a primitive Dict
            if isinstance(retriever, DictConfig):
                # no `_target_` key: then it is probably a primitive Dict
                if "_target_" not in retriever:
                    retriever = OmegaConf.to_container(retriever, resolve=True)
                    # convert the key to TaskType
                    try:
                        retriever = {
                            TaskType(k.lower()): v for k, v in retriever.items()
                        }
                    except ValueError as e:
                        raise ValueError(
                            f"Please choose a valid task type (one of {list(TaskType)}) for each retriever."
                        ) from e

            if isinstance(retriever, Dict):
                # convert the key to TaskType
                retriever = {TaskType(k): v for k, v in retriever.items()}
            else:
                # single retriever: associate it with the pipeline's task
                retriever = {task: retriever}

            # instantiate each retriever for the task(s) this pipeline handles
            if self.task in [TaskType.SPAN, TaskType.BOTH]:
                self.retriever[TaskType.SPAN] = self._instantiate_retriever(
                    retriever[TaskType.SPAN],
                    retriever_device,
                    retriever_precision,
                    None,
                    document_index_device,
                    document_index_precision,
                )
            if self.task in [TaskType.TRIPLET, TaskType.BOTH]:
                self.retriever[TaskType.TRIPLET] = self._instantiate_retriever(
                    retriever[TaskType.TRIPLET],
                    retriever_device,
                    retriever_precision,
                    None,
                    document_index_device,
                    document_index_precision,
                )

            # clean up None retrievers from the dictionary
            self.retriever = {
                task_type: r for task_type, r in self.retriever.items() if r is not None
            }
            # torch compile
            # self.retriever = {task_type: torch.compile(r, backend="onnxrt") for task_type, r in self.retriever.items()}

        # reader: instantiated from config if needed, then put in eval mode
        self.reader: RelikReaderBase | None = None
        if reader:
            reader = (
                hydra.utils.instantiate(
                    reader,
                    device=reader_device,
                    precision=reader_precision,
                )
                if isinstance(reader, DictConfig)
                else reader
            )
            reader.training = False
            reader.eval()
            if reader_device is not None:
                logger.info(f"Moving reader to `{reader_device}`.")
                reader.to(reader_device)
            # only cast when the requested precision differs from the current one
            if reader_precision is not None and reader.precision != PRECISION_MAP[reader_precision]:
                logger.info(
                    f"Setting precision of reader to `{PRECISION_MAP[reader_precision]}`."
                )
                reader.to(PRECISION_MAP[reader_precision])
            self.reader = reader
            # self.reader = torch.compile(self.reader, backend="tvm")

        # windowization stuff: splitter/manager are built lazily on first call
        self.tokenizer = SpacyTokenizer(language="en")  # TODO: parametrize?
        self.sentence_splitter: BaseSentenceSplitter | None = None
        self.window_manager: WindowManager | None = None

        if metadata_fields is None:
            metadata_fields = []
        self.metadata_fields = metadata_fields

        # inference params
        self.top_k = top_k
        self.window_size = window_size
        self.window_stride = window_stride
244
+
245
    @staticmethod
    def _instantiate_retriever(
        retriever,
        retriever_device,
        retriever_precision,
        document_index,
        document_index_device,
        document_index_precision,
    ):
        """
        Build (if needed) and prepare a single `GoldenRetriever` for inference.

        `retriever` may already be a `GoldenRetriever` instance, or a config
        accepted by `OmegaConf.create` that is instantiated via hydra. The
        retriever is switched to eval mode, optionally given `document_index`,
        and moved to the requested device/precision.
        """
        if not isinstance(retriever, GoldenRetriever):
            # convert to DictConfig and let hydra build the instance
            retriever = hydra.utils.instantiate(
                OmegaConf.create(retriever),
                device=retriever_device,
                precision=retriever_precision,
                index_device=document_index_device,
                index_precision=document_index_precision,
            )
        # inference only: disable training behaviour
        retriever.training = False
        retriever.eval()
        if document_index is not None:
            if retriever.document_index is not None:
                logger.info(
                    "The Retriever already has a document index, replacing it with the provided one."
                    "If you want to keep using the old one, please do not provide a document index."
                )
            retriever.document_index = document_index
        # we override the device and the precision of the document index if provided
        if document_index_device is not None:
            logger.info(f"Moving document index to `{document_index_device}`.")
            retriever.document_index.to(document_index_device)
        if document_index_precision is not None:
            logger.info(
                f"Setting precision of document index to `{PRECISION_MAP[document_index_precision]}`."
            )
            retriever.document_index.to(PRECISION_MAP[document_index_precision])
        # retriever.document_index = document_index
        # now we can move the retriever to the right device and set the precision
        if retriever_device is not None:
            logger.info(f"Moving retriever to `{retriever_device}`.")
            retriever.to(retriever_device)
        if retriever_precision is not None:
            logger.info(
                f"Setting precision of retriever to `{PRECISION_MAP[retriever_precision]}`."
            )
            retriever.to(PRECISION_MAP[retriever_precision])
        return retriever
292
+
293
    def __call__(
        self,
        text: str | List[str] | None = None,
        windows: List[RelikReaderSample] | None = None,
        candidates: List[str]
        | List[Document]
        | Dict[TaskType, List[Document]]
        | None = None,
        mentions: List[List[int]] | List[List[List[int]]] | None = None,
        top_k: int | None = None,
        window_size: int | None = None,
        window_stride: int | None = None,
        is_split_into_words: bool = False,
        retriever_batch_size: int | None = 32,
        reader_batch_size: int | None = 32,
        return_also_windows: bool = False,
        annotation_type: str | AnnotationType = AnnotationType.CHAR,
        progress_bar: bool = False,
        **kwargs,
    ) -> Union[RelikOutput, list[RelikOutput]]:
        """
        Annotate a text with entities.

        Args:
            text (`str` or `list`):
                The text to annotate. If a list is provided, each element of the list
                will be annotated separately.
            windows (`list[RelikReaderSample]`, `optional`, defaults to `None`):
                Pre-built windows to annotate instead of `text`. At least one of
                `text` and `windows` must be provided.
            candidates (`list[str]`, `list[Document]`, `optional`, defaults to `None`):
                The candidates to use for the reader. If `None`, the candidates will be
                retrieved from the retriever.
            mentions (`list[list[int]]` or `list[list[list[int]]]`, `optional`, defaults to `None`):
                The mentions to use for the reader. If `None`, the mentions will be
                predicted by the reader.
            top_k (`int`, `optional`, defaults to `None`):
                The number of candidates to retrieve for each window.
            window_size (`int`, `optional`, defaults to `None`):
                The size of the window. If `None`, the whole text will be annotated.
            window_stride (`int`, `optional`, defaults to `None`):
                The stride of the window. If `None`, there will be no overlap between windows.
            is_split_into_words (`bool`, `optional`, defaults to `False`):
                Whether `text` is already split into words.
            retriever_batch_size (`int`, `optional`, defaults to `None`):
                The batch size to use for the retriever. The whole input is the batch for the retriever.
            reader_batch_size (`int`, `optional`, defaults to `None`):
                The batch size to use for the reader. The whole input is the batch for the reader.
            return_also_windows (`bool`, `optional`, defaults to `False`):
                Whether to return the windows in the output.
            annotation_type (`str` or `AnnotationType`, `optional`, defaults to `char`):
                The type of annotation to return. If `char`, the spans will be in terms of
                character offsets. If `word`, the spans will be in terms of word offsets.
            progress_bar (`bool`, `optional`, defaults to `False`):
                Whether to show progress bars for the retriever and reader steps.
            **kwargs:
                Additional keyword arguments to pass to the retriever and the reader.

        Returns:
            `RelikOutput` or `list[RelikOutput]`:
                The annotated text. If a list was provided as input, a list of
                `RelikOutput` objects will be returned.
        """

        if text is None and windows is None:
            raise ValueError(
                "Either `text` or `windows` must be provided. Both are `None`."
            )

        # normalize the annotation type to the enum
        if isinstance(annotation_type, str):
            try:
                annotation_type = AnnotationType(annotation_type)
            except ValueError:
                raise ValueError(
                    f"Annotation type {annotation_type} not recognized. "
                    f"Please choose one of {list(AnnotationType)}."
                )

        # per-call parameters fall back to the instance-level defaults
        if top_k is None:
            top_k = self.top_k or 100
        if window_size is None:
            window_size = self.window_size
        if window_stride is None:
            window_stride = self.window_stride

        if text:
            if isinstance(text, str):
                text = [text]
                # keep `mentions` aligned with the single-text batch
                if mentions is not None:
                    mentions = [mentions]
            if file_logger is not None:
                file_logger.info("Annotating the following text:")
                for t in text:
                    file_logger.info(f"  {t}")

            # lazily build the sentence splitter / window manager on first use
            if self.window_manager is None:
                if window_size == "none":
                    self.sentence_splitter = BlankSentenceSplitter()
                elif window_size == "sentence":
                    self.sentence_splitter = SpacySentenceSplitter()
                else:
                    self.sentence_splitter = WindowSentenceSplitter(
                        window_size=window_size, window_stride=window_stride
                    )
                self.window_manager = WindowManager(
                    self.tokenizer, self.sentence_splitter
                )

            # a numeric window size must be at least as large as the stride
            if (
                window_size not in ["sentence", "none"]
                and window_stride is not None
                and window_size < window_stride
            ):
                raise ValueError(
                    f"Window size ({window_size}) must be greater than window stride ({window_stride})"
                )

        if windows is None:
            # no precomputed windows: split the input text into windows
            windows, blank_windows = self.window_manager.create_windows(
                text,
                window_size,
                window_stride,
                is_split_into_words=is_split_into_words,
                mentions=mentions
            )
        else:
            # windows were provided by the caller, use them as-is
            blank_windows = []
            text = {w.doc_id: w.text for w in windows}

        if candidates is not None and any(
            r is not None for r in self.retriever.values()
        ):
            logger.info(
                "Both candidates and a retriever were provided. "
                "Retriever will be ignored."
            )

        windows_candidates = {TaskType.SPAN: None, TaskType.TRIPLET: None}
        if candidates is not None:
            # again, check if candidates is a dict keyed by task type
            if isinstance(candidates, Dict):
                if self.task not in candidates:
                    raise ValueError(
                        f"Task `{self.task}` not found in `candidates`."
                        f"Please choose one of {list(TaskType)}."
                    )
            else:
                candidates = {self.task: candidates}

            for task_type, _candidates in candidates.items():
                if isinstance(_candidates, list):
                    # wrap raw strings into Document objects, one list per window
                    # (indexed by the window's doc_id)
                    _candidates = [
                        [
                            c if isinstance(c, Document) else Document(c)
                            for c in _candidates[w.doc_id]
                        ]
                        for w in windows
                    ]
                windows_candidates[task_type] = _candidates

        else:
            # retrieve candidates first
            if self.retriever is None:
                raise ValueError(
                    "No retriever was provided, please provide a retriever or candidates."
                )
            start_retr = time.time()
            for task_type, retriever in self.retriever.items():
                retriever_out = retriever.retrieve(
                    [w.text for w in windows],
                    text_pair=[w.doc_topic.text if w.doc_topic is not None else None for w in windows],
                    k=top_k,
                    batch_size=retriever_batch_size,
                    progress_bar=progress_bar,
                    **kwargs,
                )
                windows_candidates[task_type] = [
                    [p.document for p in predictions] for predictions in retriever_out
                ]
            end_retr = time.time()
            logger.info(f"Retrieval took {end_retr - start_retr} seconds.")

        # clean up None's
        windows_candidates = {
            t: c for t, c in windows_candidates.items() if c is not None
        }

        # add passage to the windows
        # NOTE(review): the loop variable `candidates` shadows the parameter of
        # the same name; the parameter is not needed past this point, but
        # renaming the loop variable would make this clearer.
        for task_type, task_candidates in windows_candidates.items():
            for window, candidates in zip(windows, task_candidates):
                # construct the candidates for the reader
                formatted_candidates = []
                for candidate in candidates:
                    window_candidate_text = candidate.text
                    # append the requested metadata fields to the candidate text
                    for field in self.metadata_fields:
                        window_candidate_text += f"{candidate.metadata.get(field, '')}"
                    formatted_candidates.append(window_candidate_text)
                # create a member for the windows that is named like the task
                setattr(window, f"{task_type.value}_candidates", formatted_candidates)

        # blank windows get empty candidates and empty predictions
        for task_type, task_candidates in windows_candidates.items():
            for window in blank_windows:
                setattr(window, f"{task_type.value}_candidates", [])
                setattr(window, "predicted_spans", [])
                setattr(window, "predicted_triples", [])
        if self.reader is not None:
            start_read = time.time()
            windows = self.reader.read(
                samples=windows,
                max_batch_size=reader_batch_size,
                annotation_type=annotation_type,
                progress_bar=progress_bar,
                **kwargs,
            )
            end_read = time.time()
            logger.info(f"Reading took {end_read - start_read} seconds.")
            # TODO: check merging behavior without a reader
            # do we want to merge windows if there is no reader?

            # merge overlapping windows back into one sample per document
            if self.window_size is not None and self.window_size not in ["sentence", "none"]:
                start_w = time.time()
                windows = windows + blank_windows
                windows.sort(key=lambda x: (x.doc_id, x.offset))
                merged_windows = self.window_manager.merge_windows(windows)
                end_w = time.time()
                logger.info(f"Merging took {end_w - start_w} seconds.")
            else:
                merged_windows = windows
        else:
            windows = windows + blank_windows
            windows.sort(key=lambda x: (x.doc_id, x.offset))
            merged_windows = windows

        # transform predictions into RelikOutput objects
        output = []
        for w in merged_windows:
            span_labels = []
            triples_labels = []
            # span extraction should always be present
            if getattr(w, "predicted_spans", None) is not None:
                span_labels = sorted(
                    [
                        Span(start=ss, end=se, label=sl, text=text[w.doc_id][ss:se])
                        if annotation_type == AnnotationType.CHAR
                        else Span(start=ss, end=se, label=sl, text=w.words[ss:se])
                        for ss, se, sl in w.predicted_spans
                    ],
                    key=lambda x: x.start,
                )
            # triple extraction is optional, if here add it
            if getattr(w, "predicted_triples", None) is not None:
                triples_labels = [
                    Triples(
                        subject=span_labels[subj],
                        label=label,
                        object=span_labels[obj],
                        confidence=conf,
                    )
                    for subj, label, obj, conf in w.predicted_triples
                ]
            # create the output
            sample_output = RelikOutput(
                text=text[w.doc_id],
                tokens=w.words,
                spans=span_labels,
                triples=triples_labels,
                candidates={
                    task_type: [
                        r.document_index.documents.get_document_from_text(c)
                        for c in getattr(w, f"{task_type.value}_candidates", [])
                        if r.document_index.documents.get_document_from_text(c) is not None
                    ]
                    for task_type, r in self.retriever.items()
                },
            )
            output.append(sample_output)

        # add windows to the output if requested
        # do we want to force windows to be returned if there is no reader?
        if return_also_windows:
            for i, sample_output in enumerate(output):
                sample_output.windows = [w for w in windows if w.doc_id == i]

        # if only one text was provided, return a single RelikOutput object
        if len(output) == 1:
            return output[0]

        return output
575
+
576
    @classmethod
    def from_pretrained(
        cls,
        model_name_or_dir: Union[str, os.PathLike],
        config_file_name: str = CONFIG_NAME,
        *args,
        **kwargs,
    ) -> "Relik":
        """
        Instantiate a `Relik` from a pretrained model.

        Args:
            model_name_or_dir (`str` or `os.PathLike`):
                The name or path of the model to load.
            config_file_name (`str`, `optional`, defaults to `config.yaml`):
                The name of the configuration file to load.
            *args:
                Additional positional arguments to pass to `hydra.utils.instantiate`.
            **kwargs:
                Additional keyword arguments merged over the loaded config
                (`cache_dir` and `force_download` are consumed here instead).

        Returns:
            `Relik`:
                The instantiated `Relik`.

        """
        # `cache_dir`/`force_download` control the download step and are not
        # forwarded into the config
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)

        # resolve the model id/path/URL to a local directory (downloading if needed)
        model_dir = from_cache(
            model_name_or_dir,
            filenames=[config_file_name],
            cache_dir=cache_dir,
            force_download=force_download,
        )

        config_path = model_dir / config_file_name
        if not config_path.exists():
            raise FileNotFoundError(
                f"Model configuration file not found at {config_path}."
            )

        # overwrite config with config_kwargs
        config = OmegaConf.load(config_path)
        # if kwargs is not None:
        # remaining kwargs override values loaded from the YAML config
        config = OmegaConf.merge(config, OmegaConf.create(kwargs))
        # do we want to print the config? I like it
        logger.info(f"Loading Relik from {model_name_or_dir}")
        logger.info(pformat(OmegaConf.to_container(config)))

        # load relik from config
        # _recursive_=False: sub-component configs are handed to Relik.__init__,
        # which instantiates them itself
        relik = hydra.utils.instantiate(config, _recursive_=False, *args)

        return relik
630
+
631
    def save_pretrained(
        self,
        output_dir: Union[str, os.PathLike],
        config: Optional[Dict[str, Any]] = None,
        config_file_name: Optional[str] = None,
        save_weights: bool = False,
        push_to_hub: bool = False,
        model_id: Optional[str] = None,
        organization: Optional[str] = None,
        repo_name: Optional[str] = None,
        retriever_model_id: Optional[str] = None,
        reader_model_id: Optional[str] = None,
        **kwargs,
    ):
        """
        Save the configuration of Relik to the specified directory as a YAML file.

        Args:
            output_dir (`str`):
                The directory to save the configuration file to.
            config (`Optional[Dict[str, Any]]`, `optional`):
                The configuration to save. If `None`, the current configuration will be
                saved. Defaults to `None`.
            config_file_name (`Optional[str]`, `optional`):
                The name of the configuration file. Defaults to `config.yaml`.
            save_weights (`bool`, `optional`):
                Whether to save the weights of the model. Defaults to `False`.
            push_to_hub (`bool`, `optional`):
                Whether to push the saved model to the hub. Defaults to `False`.
            model_id (`Optional[str]`, `optional`):
                The id of the model to push to the hub. If `None`, the name of the
                directory will be used. Defaults to `None`.
            organization (`Optional[str]`, `optional`):
                The organization to push the model to. Defaults to `None`.
            repo_name (`Optional[str]`, `optional`):
                The name of the repository to push the model to. Defaults to `None`.
            retriever_model_id (`Optional[str]`, `optional`):
                The id of the retriever model to push to the hub. If `None`, the name of the
                directory will be used. Defaults to `None`.
            reader_model_id (`Optional[str]`, `optional`):
                The id of the reader model to push to the hub. If `None`, the name of the
                directory will be used. Defaults to `None`.
            **kwargs:
                Additional keyword arguments to pass to `OmegaConf.save`.
        """
        # create the output directory
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        # per-task bookkeeping of the encoder/index names that end up in the
        # saved config; filled in by one of the branches below
        retrievers_names: Dict[TaskType, Dict | None] = {
            TaskType.SPAN: {
                "question_encoder_name": None,
                "passage_encoder_name": None,
                "document_index_name": None,
            },
            TaskType.TRIPLET: {
                "question_encoder_name": None,
                "passage_encoder_name": None,
                "document_index_name": None,
            },
        }

        if save_weights:
            # save weights
            # retriever
            model_id = model_id or output_dir.name
            retriever_model_id = retriever_model_id or f"retriever-{model_id}"
            for task_type, retriever in self.retriever.items():
                if retriever is None:
                    continue
                # derive deterministic sub-directory names per task
                task_retriever_model_id = f"{retriever_model_id}-{task_type.value}"
                question_encoder_name = f"{task_retriever_model_id}-question-encoder"
                passage_encoder_name = f"{task_retriever_model_id}-passage-encoder"
                document_index_name = f"{task_retriever_model_id}-index"
                logger.info(
                    f"Saving retriever to {output_dir / task_retriever_model_id}"
                )
                retriever.save_pretrained(
                    output_dir / task_retriever_model_id,
                    question_encoder_name=question_encoder_name,
                    passage_encoder_name=passage_encoder_name,
                    document_index_name=document_index_name,
                    push_to_hub=push_to_hub,
                    organization=organization,
                    **kwargs,
                )
                # NOTE(review): the "reader_model_id" key looks like a typo for
                # "retriever_model_id"; it is never read back later in this
                # method, so it is harmless, but worth renaming.
                retrievers_names[task_type] = {
                    "reader_model_id": task_retriever_model_id,
                    "question_encoder_name": question_encoder_name,
                    "passage_encoder_name": passage_encoder_name,
                    "document_index_name": document_index_name,
                }

            # reader
            reader_model_id = reader_model_id or f"reader-{model_id}"
            logger.info(f"Saving reader to {output_dir / reader_model_id}")
            self.reader.save_pretrained(
                output_dir / reader_model_id,
                push_to_hub=push_to_hub,
                organization=organization,
                **kwargs,
            )

            if push_to_hub:
                user = organization or get_logged_in_username()
                # we need to update the config with the model ids that will
                # result from the push to hub
                # NOTE(review): this loop also rewrites entries for tasks whose
                # retriever was skipped above (their names are still None),
                # producing strings like "user/None". Those entries are never
                # read when building the config below, but an explicit skip
                # would be cleaner.
                for task_type, retriever_names in retrievers_names.items():
                    retriever_names[
                        "question_encoder_name"
                    ] = f"{user}/{retriever_names['question_encoder_name']}"
                    retriever_names[
                        "passage_encoder_name"
                    ] = f"{user}/{retriever_names['passage_encoder_name']}"
                    retriever_names[
                        "document_index_name"
                    ] = f"{user}/{retriever_names['document_index_name']}"
                    # question_encoder_name = f"{user}/{question_encoder_name}"
                    # passage_encoder_name = f"{user}/{passage_encoder_name}"
                    # document_index_name = f"{user}/{document_index_name}"
                reader_model_id = f"{user}/{reader_model_id}"
            else:
                # weights stay local: point the config at the on-disk paths
                for task_type, retriever_names in retrievers_names.items():
                    retriever_names["question_encoder_name"] = (
                        output_dir / retriever_names["question_encoder_name"]
                    )
                    retriever_names["passage_encoder_name"] = (
                        output_dir / retriever_names["passage_encoder_name"]
                    )
                    retriever_names["document_index_name"] = (
                        output_dir / retriever_names["document_index_name"]
                    )
                reader_model_id = output_dir / reader_model_id
        else:
            # save config only
            # reuse the names/paths the components were originally loaded from
            for task_type, retriever_names in retrievers_names.items():
                retriever = self.retriever.get(task_type, None)
                if retriever is None:
                    continue
                retriever_names[
                    "question_encoder_name"
                ] = retriever.question_encoder.name_or_path
                retriever_names[
                    "passage_encoder_name"
                ] = retriever.passage_encoder.name_or_path
                retriever_names[
                    "document_index_name"
                ] = retriever.document_index.name_or_path

            reader_model_id = self.reader.name_or_path

        if config is None:
            # create a default config
            config = {
                "_target_": f"{self.__class__.__module__}.{self.__class__.__name__}"
            }
            if self.retriever is not None:
                config["retriever"] = {}
                for task_type, retriever in self.retriever.items():
                    if retriever is None:
                        continue
                    config["retriever"][task_type.value] = {
                        "_target_": f"{retriever.__class__.__module__}.{retriever.__class__.__name__}",
                    }
                    if retriever.question_encoder is not None:
                        config["retriever"][task_type.value][
                            "question_encoder"
                        ] = retrievers_names[task_type]["question_encoder_name"]
                    # the passage encoder is stored only when it is a distinct
                    # model from the question encoder
                    if (
                        retriever.passage_encoder is not None
                        and not retriever.passage_encoder_is_question_encoder
                    ):
                        config["retriever"][task_type.value][
                            "passage_encoder"
                        ] = retrievers_names[task_type]["passage_encoder_name"]
                    if retriever.document_index is not None:
                        config["retriever"][task_type.value][
                            "document_index"
                        ] = retrievers_names[task_type]["document_index_name"]
            if self.reader is not None:
                config["reader"] = {
                    "_target_": f"{self.reader.__class__.__module__}.{self.reader.__class__.__name__}",
                    "transformer_model": reader_model_id,
                }

            # these are model-specific and should be saved
            config["task"] = self.task
            config["metadata_fields"] = self.metadata_fields
            config["top_k"] = self.top_k
            config["window_size"] = self.window_size
            config["window_stride"] = self.window_stride

        config_file_name = config_file_name or CONFIG_NAME

        logger.info(f"Saving relik config to {output_dir / config_file_name}")
        # pretty print the config
        logger.info(pformat(config))
        OmegaConf.save(config, output_dir / config_file_name)

        if push_to_hub:
            # push to hub
            logger.info("Pushing to hub")
            model_id = model_id or output_dir.name
            upload(
                output_dir,
                model_id,
                filenames=[config_file_name],
                organization=organization,
                repo_name=repo_name,
            )
relik/inference/data/__init__.py ADDED
File without changes
relik/inference/data/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (188 Bytes). View file
 
relik/inference/data/__pycache__/objects.cpython-310.pyc ADDED
Binary file (3.24 kB). View file
 
relik/inference/data/objects.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Dict, List, NamedTuple, Optional
5
+
6
+ from relik.reader.pytorch_modules.hf.modeling_relik import RelikReaderSample
7
+ from relik.retriever.indexers.document import Document
8
+
9
+
10
@dataclass
class Word:
    """
    A single word of a sentence, carrying its surface form, position and
    optional linguistic annotations.

    # Parameters
    text : `str`
        The text representation.
    i : `int`
        The word offset in the sentence.
    idx : `int`, optional
        Start offset of the word in the original text (presumably a character
        offset — confirm against the tokenizer that fills it).
    idx_end : `int`, optional
        End offset of the word in the original text (see `idx`).
    lemma : `str`, optional
        The lemma of this word.
    pos : `str`, optional
        The coarse-grained part of speech of this word.
    dep : `str`, optional
        The dependency relation for this word.
    head : `int`, optional
        Offset of the syntactic head of this word, if dependency parsing
        was performed.
    """

    text: str
    i: int
    idx: Optional[int] = None
    idx_end: Optional[int] = None
    # preprocessing fields
    lemma: Optional[str] = None
    pos: Optional[str] = None
    dep: Optional[str] = None
    head: Optional[int] = None

    def __str__(self):
        # a Word prints as its surface form
        return self.text

    def __repr__(self):
        return self.__str__()
52
+
53
+
54
class Span(NamedTuple):
    """
    An annotated span: start/end offsets (character- or word-level depending
    on the chosen `AnnotationType` — confirm against the caller), the
    predicted label and the covered surface text.
    """

    start: int
    end: int
    label: str
    text: str
59
+
60
+
61
class Triples(NamedTuple):
    """
    A single extracted relation triple: subject span, relation label, object
    span and the model's confidence score.
    """

    subject: Span
    label: str
    object: Span
    confidence: float
66
+
67
@dataclass
class RelikOutput:
    """
    Annotation result for a single input text: the tokenized text, predicted
    entity spans and relation triples, the retrieved candidate documents per
    task, and (optionally) the raw reader windows.
    """

    # NOTE: `TaskType` is defined later in this module; the forward reference
    # works because of `from __future__ import annotations` at the file top.
    text: str
    tokens: List[str]
    spans: List[Span]
    triples: List[Triples]
    candidates: Dict[TaskType, List[Document]]
    # raw per-window reader samples, populated only when requested
    windows: Optional[List[RelikReaderSample]] = None
75
+
76
+
77
# NOTE(review): import placed mid-module (after its first annotated use,
# which is legal thanks to `from __future__ import annotations`); consider
# moving it to the top of the file with the other imports.
from enum import Enum


class AnnotationType(Enum):
    # granularity of the span offsets: character-based or word-based
    CHAR = "char"
    WORD = "word"


class TaskType(Enum):
    # which extraction task(s) the pipeline runs: entity spans, relation
    # triplets, or both
    SPAN = "span"
    TRIPLET = "triplet"
    BOTH = "both"
relik/inference/data/splitters/__init__.py ADDED
File without changes
relik/inference/data/splitters/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (198 Bytes). View file
 
relik/inference/data/splitters/__pycache__/base_sentence_splitter.cpython-310.pyc ADDED
Binary file (2.38 kB). View file
 
relik/inference/data/splitters/__pycache__/blank_sentence_splitter.cpython-310.pyc ADDED
Binary file (1.6 kB). View file
 
relik/inference/data/splitters/__pycache__/spacy_sentence_splitter.cpython-310.pyc ADDED
Binary file (5.31 kB). View file
 
relik/inference/data/splitters/__pycache__/window_based_splitter.cpython-310.pyc ADDED
Binary file (2.49 kB). View file
 
relik/inference/data/splitters/base_sentence_splitter.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Union
2
+
3
+
4
class BaseSentenceSplitter:
    """
    Abstract interface for objects that split strings into sentences.
    Subclasses implement :meth:`split_sentences`.
    """

    def __call__(self, *args, **kwargs):
        """Delegate to :meth:`split_sentences`."""
        return self.split_sentences(*args, **kwargs)

    def split_sentences(
        self, text: str, max_len: int = 0, *args, **kwargs
    ) -> List[str]:
        """
        Split a `text` paragraph into a list of sentence strings.
        Must be provided by subclasses.
        """
        raise NotImplementedError

    def split_sentences_batch(
        self, texts: List[str], *args, **kwargs
    ) -> List[List[str]]:
        """Split each text in `texts` via :meth:`split_sentences`."""
        return list(map(self.split_sentences, texts))

    @staticmethod
    def check_is_batched(
        texts: Union[str, List[str], List[List[str]]], is_split_into_words: bool
    ):
        """
        Tell whether `texts` is a batch of inputs or a single sample.

        Args:
            texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
                Text to check.
            is_split_into_words (:obj:`bool`):
                If :obj:`True` and the input is a string, the input is split on spaces.

        Returns:
            :obj:`bool`: ``True`` if ``texts`` is batched, ``False`` otherwise.
        """
        # a bare string is never a batch
        if not isinstance(texts, (list, tuple)):
            return False
        # an untokenized list of strings is a batch of texts
        if not is_split_into_words:
            return True
        # pre-tokenized input is batched only when it is a list of lists
        return bool(texts) and isinstance(texts[0], (list, tuple))
relik/inference/data/splitters/blank_sentence_splitter.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Union
2
+
3
+
4
class BlankSentenceSplitter:
    """
    A no-op sentence splitter: the whole input text is treated as one
    "sentence" and returned unchanged.
    """

    def __call__(self, *args, **kwargs):
        """Delegate to :meth:`split_sentences`."""
        return self.split_sentences(*args, **kwargs)

    def split_sentences(
        self, text: str, max_len: int = 0, *args, **kwargs
    ) -> List[str]:
        """Return `text` unchanged, wrapped in a single-element list."""
        return [text]

    def split_sentences_batch(
        self, texts: List[str], *args, **kwargs
    ) -> List[List[str]]:
        """Apply :meth:`split_sentences` to every text in `texts`."""
        return list(map(self.split_sentences, texts))
relik/inference/data/splitters/spacy_sentence_splitter.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Iterable, List, Optional, Union
2
+
3
+ import spacy
4
+
5
+ from relik.inference.data.objects import Word
6
+ from relik.inference.data.splitters.base_sentence_splitter import BaseSentenceSplitter
7
+ from relik.inference.data.tokenizers.spacy_tokenizer import load_spacy
8
+
9
# Maps ISO 639-1 language codes to the spaCy pipeline used for sentence
# splitting. Most languages fall back to the multi-language `xx_sent_ud_sm`
# model; a few (el, ja, pl, zh) use a dedicated pipeline.
SPACY_LANGUAGE_MAPPER = {
    "cs": "xx_sent_ud_sm",
    "da": "xx_sent_ud_sm",
    "de": "xx_sent_ud_sm",
    "fa": "xx_sent_ud_sm",
    "fi": "xx_sent_ud_sm",
    "fr": "xx_sent_ud_sm",
    "el": "el_core_news_sm",
    "en": "xx_sent_ud_sm",
    "es": "xx_sent_ud_sm",
    "ga": "xx_sent_ud_sm",
    "hr": "xx_sent_ud_sm",
    "id": "xx_sent_ud_sm",
    "it": "xx_sent_ud_sm",
    "ja": "ja_core_news_sm",
    "lv": "xx_sent_ud_sm",
    "lt": "xx_sent_ud_sm",
    "mr": "xx_sent_ud_sm",
    "nb": "xx_sent_ud_sm",
    "nl": "xx_sent_ud_sm",
    "no": "xx_sent_ud_sm",
    "pl": "pl_core_news_sm",
    "pt": "xx_sent_ud_sm",
    "ro": "xx_sent_ud_sm",
    "ru": "xx_sent_ud_sm",
    "sk": "xx_sent_ud_sm",
    "sr": "xx_sent_ud_sm",
    "sv": "xx_sent_ud_sm",
    "te": "xx_sent_ud_sm",
    "vi": "xx_sent_ud_sm",
    "zh": "zh_core_web_sm",
}
41
+
42
+
43
class SpacySentenceSplitter(BaseSentenceSplitter):
    """
    A :obj:`SentenceSplitter` that uses spaCy's built-in sentence boundary detection.

    Args:
        language (:obj:`str`, optional, defaults to :obj:`en`):
            Language of the text to tokenize.
        model_type (:obj:`str`, optional, defaults to :obj:`statistical`):
            Three different type of sentence splitter:
            - ``dependency``: sentence splitter uses a dependency parse to detect sentence boundaries,
              slow, but accurate.
            - ``statistical``: uses spaCy's trained ``senter`` pipe, a fast
              statistical sentence segmenter.
            - ``rule_based``: It's fast and has a small memory footprint, since it uses punctuation to detect
              sentence boundaries.
    """

    def __init__(self, language: str = "en", model_type: str = "statistical") -> None:
        # we need spacy's dependency parser if we're not using rule-based sentence boundary detection.
        # self.spacy = get_spacy_model(language, parse=not rule_based, ner=False)
        dep = bool(model_type == "dependency")
        if language in SPACY_LANGUAGE_MAPPER:
            self.spacy = load_spacy(SPACY_LANGUAGE_MAPPER[language], parse=dep)
        else:
            # unknown language: fall back to a blank pipeline, which has no
            # trained components
            self.spacy = spacy.blank(language)
            # force type to rule_based since there is no pre-trained model
            model_type = "rule_based"
        if model_type == "dependency":
            # dependency type must declared at model init
            pass
        elif model_type == "statistical":
            if not self.spacy.has_pipe("senter"):
                self.spacy.enable_pipe("senter")
        elif model_type == "rule_based":
            # we use `sentencizer`, a built-in spacy module for rule-based sentence boundary detection.
            # depending on the spacy version, it could be called 'sentencizer' or 'sbd'
            if not self.spacy.has_pipe("sentencizer"):
                self.spacy.add_pipe("sentencizer")
        else:
            raise ValueError(
                f"type {model_type} not supported. Choose between `dependency`, `statistical` or `rule_based`"
            )

    def __call__(
        self,
        texts: Union[str, List[str], List[List[str]]],
        max_length: Optional[int] = None,
        is_split_into_words: bool = False,
        **kwargs,
    ) -> Union[List[str], List[List[str]]]:
        """
        Tokenize the input into single words using SpaCy models.

        Args:
            texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
                Text to tag. It can be a single string, a batch of string and pre-tokenized strings.
            max_length (:obj:`int`, optional, defaults to :obj:`None`):
                Maximum length of a single text. If the text is longer than `max_length`, it will be split
                into multiple sentences.

        Returns:
            :obj:`List[List[str]]`: The input doc split into sentences.
        """
        # check if input is batched or a single sample
        is_batched = self.check_is_batched(texts, is_split_into_words)

        if is_batched:
            # NOTE(review): the batch path does not forward `max_length`
            sents = self.split_sentences_batch(texts)
        else:
            sents = self.split_sentences(texts, max_length)
        return sents

    @staticmethod
    def chunked(iterable, n: int) -> Iterable[List[Any]]:
        """
        Chunks a list into n sized chunks.

        Args:
            iterable (:obj:`List[Any]`):
                List to chunk.
            n (:obj:`int`):
                Size of the chunks.

        Returns:
            :obj:`Iterable[List[Any]]`: The input list chunked into n sized chunks.
        """
        return [iterable[i : i + n] for i in range(0, len(iterable), n)]

    def split_sentences(
        self, text: str | List[Word], max_length: Optional[int] = None, *args, **kwargs
    ) -> List[str]:
        """
        Splits a `text` into smaller sentences.

        Args:
            text (:obj:`str`):
                Text to split.
            max_length (:obj:`int`, optional, defaults to :obj:`None`):
                Maximum length of a single sentence. If the text is longer than `max_length`, it will be split
                into multiple sentences.

        Returns:
            :obj:`List[str]`: The input text split into sentences.

        NOTE(review): despite the `List[str]` annotation, the returned items
        are spaCy `Span` objects (from `Doc.sents`), and `chunked` slices of
        them — callers appear to rely on that; confirm before changing.
        """
        sentences = [sent for sent in self.spacy(text).sents]
        if max_length is not None and max_length > 0:
            # re-split any sentence longer than `max_length` into fixed-size chunks
            sentences = [
                chunk
                for sentence in sentences
                for chunk in self.chunked(sentence, max_length)
            ]
        return sentences
relik/inference/data/splitters/window_based_splitter.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Union
2
+
3
+ from relik.inference.data.splitters.base_sentence_splitter import BaseSentenceSplitter
4
+
5
+
6
class WindowSentenceSplitter(BaseSentenceSplitter):
    """
    A :obj:`BaseSentenceSplitter` that chops a text into (possibly
    overlapping) fixed-size token windows, advancing by a fixed stride.
    """

    def __init__(self, window_size: int, window_stride: int, *args, **kwargs) -> None:
        super().__init__()
        # number of tokens per window
        self.window_size = window_size
        # number of tokens to advance between consecutive windows
        self.window_stride = window_stride

    def __call__(
        self,
        texts: Union[str, List[str], List[List[str]]],
        is_split_into_words: bool = False,
        **kwargs,
    ) -> Union[List[str], List[List[str]]]:
        """
        Split the input into token windows.

        Args:
            texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
                Text to split. It can be a single string or a pre-tokenized list.

        Returns:
            :obj:`List[List[str]]`: The input split into windows.
        """
        return self.split_sentences(texts)

    def split_sentences(self, text: str | List, *args, **kwargs) -> List[List]:
        """
        Split `text` into token windows of `window_size`, stepping by
        `window_stride`.

        Args:
            text (:obj:`str` or :obj:`List`):
                Text to split; a string is whitespace-tokenized first.

        Returns:
            :obj:`List[List]`: The windows, each a list of tokens.
        """
        tokens = text.split() if isinstance(text, str) else text
        total = len(tokens)
        windows: List[List] = []
        for begin in range(0, total, self.window_stride):
            if begin != 0 and begin + self.window_size > total:
                # the final window would overrun the text: pull it back so it
                # stays full-sized, borrowing tokens from the previous window —
                # unless that would rewind a whole stride or more, in which
                # case the remainder is already covered and we stop.
                spill = begin + self.window_size - total
                if spill >= self.window_stride:
                    break
                begin -= spill
            end = min(begin + self.window_size, total)
            windows.append(tokens[begin:end])
        return windows
relik/inference/data/tokenizers/__init__.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Maps both ISO 639-1 language codes (to a default small pipeline) and full
# spaCy pipeline names (identity entries) to the spaCy model to load, so
# callers may pass either `"en"` or e.g. `"en_core_web_trf"`.
SPACY_LANGUAGE_MAPPER = {
    "ca": "ca_core_news_sm",
    "da": "da_core_news_sm",
    "de": "de_core_news_sm",
    "el": "el_core_news_sm",
    "en": "en_core_web_sm",
    "es": "es_core_news_sm",
    "fr": "fr_core_news_sm",
    "it": "it_core_news_sm",
    "ja": "ja_core_news_sm",
    "lt": "lt_core_news_sm",
    "mk": "mk_core_news_sm",
    "nb": "nb_core_news_sm",
    "nl": "nl_core_news_sm",
    "pl": "pl_core_news_sm",
    "pt": "pt_core_news_sm",
    "ro": "ro_core_news_sm",
    "ru": "ru_core_news_sm",
    "xx": "xx_sent_ud_sm",
    "zh": "zh_core_web_sm",
    "ca_core_news_sm": "ca_core_news_sm",
    "ca_core_news_md": "ca_core_news_md",
    "ca_core_news_lg": "ca_core_news_lg",
    "ca_core_news_trf": "ca_core_news_trf",
    "da_core_news_sm": "da_core_news_sm",
    "da_core_news_md": "da_core_news_md",
    "da_core_news_lg": "da_core_news_lg",
    "da_core_news_trf": "da_core_news_trf",
    "de_core_news_sm": "de_core_news_sm",
    "de_core_news_md": "de_core_news_md",
    "de_core_news_lg": "de_core_news_lg",
    "de_dep_news_trf": "de_dep_news_trf",
    "el_core_news_sm": "el_core_news_sm",
    "el_core_news_md": "el_core_news_md",
    "el_core_news_lg": "el_core_news_lg",
    "en_core_web_sm": "en_core_web_sm",
    "en_core_web_md": "en_core_web_md",
    "en_core_web_lg": "en_core_web_lg",
    "en_core_web_trf": "en_core_web_trf",
    "es_core_news_sm": "es_core_news_sm",
    "es_core_news_md": "es_core_news_md",
    "es_core_news_lg": "es_core_news_lg",
    "es_dep_news_trf": "es_dep_news_trf",
    "fr_core_news_sm": "fr_core_news_sm",
    "fr_core_news_md": "fr_core_news_md",
    "fr_core_news_lg": "fr_core_news_lg",
    "fr_dep_news_trf": "fr_dep_news_trf",
    "it_core_news_sm": "it_core_news_sm",
    "it_core_news_md": "it_core_news_md",
    "it_core_news_lg": "it_core_news_lg",
    "ja_core_news_sm": "ja_core_news_sm",
    "ja_core_news_md": "ja_core_news_md",
    "ja_core_news_lg": "ja_core_news_lg",
    "ja_dep_news_trf": "ja_dep_news_trf",
    "lt_core_news_sm": "lt_core_news_sm",
    "lt_core_news_md": "lt_core_news_md",
    "lt_core_news_lg": "lt_core_news_lg",
    "mk_core_news_sm": "mk_core_news_sm",
    "mk_core_news_md": "mk_core_news_md",
    "mk_core_news_lg": "mk_core_news_lg",
    "nb_core_news_sm": "nb_core_news_sm",
    "nb_core_news_md": "nb_core_news_md",
    "nb_core_news_lg": "nb_core_news_lg",
    "nl_core_news_sm": "nl_core_news_sm",
    "nl_core_news_md": "nl_core_news_md",
    "nl_core_news_lg": "nl_core_news_lg",
    "pl_core_news_sm": "pl_core_news_sm",
    "pl_core_news_md": "pl_core_news_md",
    "pl_core_news_lg": "pl_core_news_lg",
    "pt_core_news_sm": "pt_core_news_sm",
    "pt_core_news_md": "pt_core_news_md",
    "pt_core_news_lg": "pt_core_news_lg",
    "ro_core_news_sm": "ro_core_news_sm",
    "ro_core_news_md": "ro_core_news_md",
    "ro_core_news_lg": "ro_core_news_lg",
    "ru_core_news_sm": "ru_core_news_sm",
    "ru_core_news_md": "ru_core_news_md",
    "ru_core_news_lg": "ru_core_news_lg",
    "xx_ent_wiki_sm": "xx_ent_wiki_sm",
    "xx_sent_ud_sm": "xx_sent_ud_sm",
    "zh_core_web_sm": "zh_core_web_sm",
    "zh_core_web_md": "zh_core_web_md",
    "zh_core_web_lg": "zh_core_web_lg",
    "zh_core_web_trf": "zh_core_web_trf",
}

# This import is deliberately placed at the bottom: `spacy_tokenizer` imports
# SPACY_LANGUAGE_MAPPER from this package, so importing it before the mapping
# is defined would create a circular-import failure.
from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer
relik/inference/data/tokenizers/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (2.31 kB). View file
 
relik/inference/data/tokenizers/__pycache__/base_tokenizer.cpython-310.pyc ADDED
Binary file (3.13 kB). View file
 
relik/inference/data/tokenizers/__pycache__/spacy_tokenizer.cpython-310.pyc ADDED
Binary file (6.55 kB). View file
 
relik/inference/data/tokenizers/base_tokenizer.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Union
2
+
3
+ from relik.inference.data.objects import Word
4
+
5
+
6
class BaseTokenizer:
    """
    Abstract interface for tokenizers that split strings of text into single
    :obj:`Word` objects, optionally adding POS tags and performing
    lemmatization.
    """

    def __call__(
        self,
        texts: Union[str, List[str], List[List[str]]],
        is_split_into_words: bool = False,
        **kwargs
    ) -> List[List[Word]]:
        """
        Tokenize the input into single words.

        Args:
            texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
                Text to tag. It can be a single string, a batch of string and pre-tokenized strings.
            is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`):
                If :obj:`True` and the input is a string, the input is split on spaces.

        Returns:
            :obj:`List[List[Word]]`: The input text tokenized in single words.
        """
        raise NotImplementedError

    def tokenize(self, text: str) -> List[Word]:
        """
        Split a single `text` into tokens. Must be provided by subclasses.

        Args:
            text (:obj:`str`):
                Text to tokenize.

        Returns:
            :obj:`List[Word]`: The input text tokenized in single words.
        """
        raise NotImplementedError

    def tokenize_batch(self, texts: List[str]) -> List[List[Word]]:
        """
        Tokenize every text in `texts` via :meth:`tokenize`.

        Args:
            texts (:obj:`List[str]`):
                Batch of text to tokenize.

        Returns:
            :obj:`List[List[Word]]`: The input batch tokenized in single words.
        """
        return list(map(self.tokenize, texts))

    @staticmethod
    def check_is_batched(
        texts: Union[str, List[str], List[List[str]]], is_split_into_words: bool
    ):
        """
        Tell whether `texts` is a batch of inputs or a single sample.

        Args:
            texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
                Text to check.
            is_split_into_words (:obj:`bool`):
                If :obj:`True` and the input is a string, the input is split on spaces.

        Returns:
            :obj:`bool`: ``True`` if ``texts`` is batched, ``False`` otherwise.
        """
        # a bare string is never a batch
        if not isinstance(texts, (list, tuple)):
            return False
        # an untokenized list of strings is a batch of texts
        if not is_split_into_words:
            return True
        # pre-tokenized input is batched only when it is a list of lists
        return bool(texts) and isinstance(texts[0], (list, tuple))
relik/inference/data/tokenizers/spacy_tokenizer.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from copy import deepcopy
3
+ from typing import Dict, List, Tuple, Union, Any
4
+
5
+ import spacy
6
+
7
+ # from ipa.common.utils import load_spacy
8
+ from spacy.cli.download import download as spacy_download
9
+ from spacy.tokens import Doc
10
+
11
+ from relik.common.log import get_logger
12
+ from relik.inference.data.objects import Word
13
+ from relik.inference.data.tokenizers import SPACY_LANGUAGE_MAPPER
14
+ from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer
15
+
16
logger = get_logger(level=logging.DEBUG)

# Spacy and Stanza stuff

# process-wide cache of loaded spaCy pipelines, keyed by
# (language, pos_tags, lemma, parse, split_on_spaces) so that each
# configuration is loaded at most once
LOADED_SPACY_MODELS: Dict[Tuple[str, bool, bool, bool, bool], spacy.Language] = {}


def load_spacy(
    language: str,
    pos_tags: bool = False,
    lemma: bool = False,
    parse: bool = False,
    split_on_spaces: bool = False,
) -> spacy.Language:
    """
    Download and load spacy model.

    Args:
        language (:obj:`str`, defaults to :obj:`en`):
            Language of the text to tokenize.
        pos_tags (:obj:`bool`, optional, defaults to :obj:`False`):
            If :obj:`True`, performs POS tagging with spacy model.
        lemma (:obj:`bool`, optional, defaults to :obj:`False`):
            If :obj:`True`, performs lemmatization with spacy model.
        parse (:obj:`bool`, optional, defaults to :obj:`False`):
            If :obj:`True`, performs dependency parsing with spacy model.
        split_on_spaces (:obj:`bool`, optional, defaults to :obj:`False`):
            If :obj:`True`, will split by spaces without performing tokenization.

    Returns:
        :obj:`spacy.Language`: The spacy model loaded.
    """
    # pipes that are never needed; optional pipes are excluded only when the
    # corresponding flag is off
    exclude = ["vectors", "textcat", "ner"]
    if not pos_tags:
        exclude.append("tagger")
    if not lemma:
        exclude.append("lemmatizer")
    if not parse:
        exclude.append("parser")

    # check if the model is already loaded
    # if so, there is no need to reload it
    spacy_params = (language, pos_tags, lemma, parse, split_on_spaces)
    if spacy_params not in LOADED_SPACY_MODELS:
        try:
            spacy_tagger = spacy.load(language, exclude=exclude)
        except OSError:
            # model not installed: fetch it once, then retry the load
            logger.warning(
                "Spacy model '%s' not found. Downloading and installing.", language
            )
            spacy_download(language)
            spacy_tagger = spacy.load(language, exclude=exclude)

        # if everything is disabled, return only the tokenizer
        # for faster tokenization
        # TODO: is it really faster?
        # TODO: check split_on_spaces behaviour if we don't do this if
        # (`exclude` holds at most 6 entries, so `>= 6` means all optional
        # pipes — tagger, lemmatizer, parser — are disabled)
        if len(exclude) >= 6 and split_on_spaces:
            spacy_tagger = spacy_tagger.tokenizer
        LOADED_SPACY_MODELS[spacy_params] = spacy_tagger

    return LOADED_SPACY_MODELS[spacy_params]
78
+
79
+
80
+ class SpacyTokenizer(BaseTokenizer):
81
+ """
82
+ A :obj:`Tokenizer` that uses SpaCy to tokenizer and preprocess the text. It returns :obj:`Word` objects.
83
+
84
+ Args:
85
+ language (:obj:`str`, optional, defaults to :obj:`en`):
86
+ Language of the text to tokenize.
87
+ return_pos_tags (:obj:`bool`, optional, defaults to :obj:`False`):
88
+ If :obj:`True`, performs POS tagging with spacy model.
89
+ return_lemmas (:obj:`bool`, optional, defaults to :obj:`False`):
90
+ If :obj:`True`, performs lemmatization with spacy model.
91
+ return_deps (:obj:`bool`, optional, defaults to :obj:`False`):
92
+ If :obj:`True`, performs dependency parsing with spacy model.
93
+ use_gpu (:obj:`bool`, optional, defaults to :obj:`False`):
94
+ If :obj:`True`, will load the Stanza model on GPU.
95
+ """
96
+
97
+ def __init__(
98
+ self,
99
+ language: str = "en",
100
+ return_pos_tags: bool = False,
101
+ return_lemmas: bool = False,
102
+ return_deps: bool = False,
103
+ use_gpu: bool = False,
104
+ ):
105
+ super().__init__()
106
+ if language not in SPACY_LANGUAGE_MAPPER:
107
+ raise ValueError(
108
+ f"`{language}` language not supported. The supported "
109
+ f"languages are: {list(SPACY_LANGUAGE_MAPPER.keys())}."
110
+ )
111
+ if use_gpu:
112
+ # load the model on GPU
113
+ # if the GPU is not available or not correctly configured,
114
+ # it will rise an error
115
+ spacy.require_gpu()
116
+ self.spacy = load_spacy(
117
+ SPACY_LANGUAGE_MAPPER[language],
118
+ return_pos_tags,
119
+ return_lemmas,
120
+ return_deps,
121
+ )
122
+
123
+ def __call__(
124
+ self,
125
+ texts: Union[str, List[str], List[List[str]]],
126
+ is_split_into_words: bool = False,
127
+ **kwargs,
128
+ ) -> Union[List[Word], List[List[Word]]]:
129
+ """
130
+ Tokenize the input into single words using SpaCy models.
131
+
132
+ Args:
133
+ texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
134
+ Text to tag. It can be a single string, a batch of string and pre-tokenized strings.
135
+ is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`):
136
+ If :obj:`True` and the input is a string, the input is split on spaces.
137
+
138
+ Returns:
139
+ :obj:`List[List[Word]]`: The input text tokenized in single words.
140
+
141
+ Example::
142
+
143
+ >>> from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer
144
+
145
+ >>> spacy_tokenizer = SpacyTokenizer(language="en", pos_tags=True, lemma=True)
146
+ >>> spacy_tokenizer("Mary sold the car to John.")
147
+
148
+ """
149
+ # check if input is batched or a single sample
150
+ is_batched = self.check_is_batched(texts, is_split_into_words)
151
+
152
+ if is_batched:
153
+ tokenized = self.tokenize_batch(texts, is_split_into_words)
154
+ else:
155
+ tokenized = self.tokenize(texts, is_split_into_words)
156
+
157
+ return tokenized
158
+
159
+ def tokenize(self, text: Union[str, List[str]], is_split_into_words: bool) -> Doc:
160
+ if is_split_into_words:
161
+ if isinstance(text, str):
162
+ text = text.split(" ")
163
+ elif isinstance(text, list):
164
+ text = text
165
+ else:
166
+ raise ValueError(
167
+ f"text must be either `str` or `list`, found: `{type(text)}`"
168
+ )
169
+ spaces = [True] * len(text)
170
+ return self.spacy(Doc(self.spacy.vocab, words=text, spaces=spaces))
171
+ return self.spacy(text)
172
+
173
+ def tokenize_batch(
174
+ self, texts: Union[List[str], List[List[str]]], is_split_into_words: bool
175
+ ) -> list[Any] | list[Doc]:
176
+ try:
177
+ if is_split_into_words:
178
+ if isinstance(texts[0], str):
179
+ texts = [text.split(" ") for text in texts]
180
+ elif isinstance(texts[0], list):
181
+ texts = texts
182
+ else:
183
+ raise ValueError(
184
+ f"text must be either `str` or `list`, found: `{type(texts[0])}`"
185
+ )
186
+ spaces = [[True] * len(text) for text in texts]
187
+ texts = [
188
+ Doc(self.spacy.vocab, words=text, spaces=space)
189
+ for text, space in zip(texts, spaces)
190
+ ]
191
+ return list(self.spacy.pipe(texts))
192
+ except AttributeError:
193
+ # a WhitespaceSpacyTokenizer has no `pipe()` method, we use simple for loop
194
+ return [self.spacy(tokens) for tokens in texts]
relik/inference/data/window/__init__.py ADDED
File without changes
relik/inference/data/window/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (195 Bytes). View file
 
relik/inference/data/window/__pycache__/manager.cpython-310.pyc ADDED
Binary file (11.2 kB). View file
 
relik/inference/data/window/manager.py ADDED
@@ -0,0 +1,431 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import itertools
3
+ from typing import Dict, List, Optional, Set, Tuple
4
+
5
+ from relik.inference.data.splitters.blank_sentence_splitter import BlankSentenceSplitter
6
+ from relik.inference.data.splitters.base_sentence_splitter import BaseSentenceSplitter
7
+ from relik.inference.data.tokenizers.base_tokenizer import BaseTokenizer
8
+ from relik.reader.data.relik_reader_sample import RelikReaderSample
9
+
10
+
11
+ class WindowManager:
12
+ def __init__(
13
+ self, tokenizer: BaseTokenizer, splitter: BaseSentenceSplitter | None = None
14
+ ) -> None:
15
+ self.tokenizer = tokenizer
16
+ self.splitter = splitter or BlankSentenceSplitter()
17
+
18
+ def create_windows(
19
+ self,
20
+ documents: str | List[str],
21
+ window_size: int | None = None,
22
+ stride: int | None = None,
23
+ max_length: int | None = None,
24
+ doc_id: str | int | None = None,
25
+ doc_topic: str | None = None,
26
+ is_split_into_words: bool = False,
27
+ mentions: List[List[List[int]]] = None,
28
+ ) -> Tuple[List[RelikReaderSample], List[RelikReaderSample]]:
29
+ """
30
+ Create windows from a list of documents.
31
+
32
+ Args:
33
+ documents (:obj:`str` or :obj:`List[str]`):
34
+ The document(s) to split in windows.
35
+ window_size (:obj:`int`):
36
+ The size of the window.
37
+ stride (:obj:`int`):
38
+ The stride between two windows.
39
+ max_length (:obj:`int`, `optional`):
40
+ The maximum length of a window.
41
+ doc_id (:obj:`str` or :obj:`int`, `optional`):
42
+ The id of the document(s).
43
+ doc_topic (:obj:`str`, `optional`):
44
+ The topic of the document(s).
45
+ is_split_into_words (:obj:`bool`, `optional`, defaults to :obj:`False`):
46
+ Whether the input is already pre-tokenized (e.g., split into words). If :obj:`False`, the
47
+ input will first be tokenized using the tokenizer, then the tokens will be split into words.
48
+ mentions (:obj:`List[List[List[int]]]`, `optional`):
49
+ The mentions of the document(s).
50
+
51
+ Returns:
52
+ :obj:`List[RelikReaderSample]`: The windows created from the documents.
53
+ """
54
+ # normalize input
55
+ if isinstance(documents, str) or is_split_into_words:
56
+ documents = [documents]
57
+
58
+ # batch tokenize
59
+ documents_tokens = self.tokenizer(
60
+ documents, is_split_into_words=is_split_into_words
61
+ )
62
+
63
+ # set splitter params
64
+ if hasattr(self.splitter, "window_size"):
65
+ self.splitter.window_size = window_size or self.splitter.window_size
66
+ if hasattr(self.splitter, "window_stride"):
67
+ self.splitter.window_stride = stride or self.splitter.window_stride
68
+
69
+ windowed_documents, windowed_blank_documents = [], []
70
+
71
+ if mentions is not None:
72
+ assert len(documents) == len(
73
+ mentions
74
+ ), f"documents and mentions should have the same length, got {len(documents)} and {len(mentions)}"
75
+ doc_iter = zip(documents, documents_tokens, mentions)
76
+ else:
77
+ doc_iter = zip(documents, documents_tokens, itertools.repeat([]))
78
+
79
+ for infered_doc_id, (document, document_tokens, document_mentions) in enumerate(
80
+ doc_iter
81
+ ):
82
+ if doc_topic is None:
83
+ doc_topic = document_tokens[0] if len(document_tokens) > 0 else ""
84
+
85
+ if doc_id is None:
86
+ doc_id = infered_doc_id
87
+
88
+ splitted_document = self.splitter(document_tokens, max_length=max_length)
89
+
90
+ document_windows = []
91
+ for window_id, window in enumerate(splitted_document):
92
+ window_text_start = window[0].idx
93
+ window_text_end = window[-1].idx + len(window[-1].text)
94
+ if isinstance(document, str):
95
+ text = document[window_text_start:window_text_end]
96
+ else:
97
+ # window_text_start = window[0].idx
98
+ # window_text_end = window[-1].i
99
+ text = " ".join([w.text for w in window])
100
+ sample = RelikReaderSample(
101
+ doc_id=doc_id,
102
+ window_id=window_id,
103
+ text=text,
104
+ tokens=[w.text for w in window],
105
+ words=[w.text for w in window],
106
+ doc_topic=doc_topic,
107
+ offset=window_text_start,
108
+ spans=[
109
+ [m[0], m[1]] for m in document_mentions
110
+ if window_text_end > m[0] >= window_text_start and window_text_end >= m[1] >= window_text_start
111
+ ],
112
+ token2char_start={str(i): w.idx for i, w in enumerate(window)},
113
+ token2char_end={
114
+ str(i): w.idx + len(w.text) for i, w in enumerate(window)
115
+ },
116
+ char2token_start={
117
+ str(w.idx): w.i for i, w in enumerate(window)
118
+ },
119
+ char2token_end={
120
+ str(w.idx + len(w.text)): w.i for i, w in enumerate(window)
121
+ },
122
+ )
123
+ if mentions is not None and len(sample.spans) == 0:
124
+ windowed_blank_documents.append(sample)
125
+ else:
126
+ document_windows.append(sample)
127
+
128
+ windowed_documents.extend(document_windows)
129
+ if mentions is not None:
130
+ return windowed_documents, windowed_blank_documents
131
+ else:
132
+ return windowed_documents, windowed_blank_documents
133
+
134
+ def merge_windows(
135
+ self, windows: List[RelikReaderSample]
136
+ ) -> List[RelikReaderSample]:
137
+ windows_by_doc_id = collections.defaultdict(list)
138
+ for window in windows:
139
+ windows_by_doc_id[window.doc_id].append(window)
140
+
141
+ merged_window_by_doc = {
142
+ doc_id: self._merge_doc_windows(doc_windows)
143
+ for doc_id, doc_windows in windows_by_doc_id.items()
144
+ }
145
+
146
+ return list(merged_window_by_doc.values())
147
+
148
+ def _merge_doc_windows(self, windows: List[RelikReaderSample]) -> RelikReaderSample:
149
+ if len(windows) == 1:
150
+ return windows[0]
151
+
152
+ if len(windows) > 0 and getattr(windows[0], "offset", None) is not None:
153
+ windows = sorted(windows, key=(lambda x: x.offset))
154
+
155
+ window_accumulator = windows[0]
156
+
157
+ for next_window in windows[1:]:
158
+ window_accumulator = self._merge_window_pair(
159
+ window_accumulator, next_window
160
+ )
161
+
162
+ return window_accumulator
163
+
164
+ @staticmethod
165
+ def _merge_tokens(
166
+ window1: RelikReaderSample, window2: RelikReaderSample
167
+ ) -> Tuple[list, dict, dict]:
168
+ w1_tokens = window1.tokens[1:-1]
169
+ w2_tokens = window2.tokens[1:-1]
170
+
171
+ # find intersection if any
172
+ tokens_intersection = 0
173
+ for k in reversed(range(1, len(w1_tokens))):
174
+ if w1_tokens[-k:] == w2_tokens[:k]:
175
+ tokens_intersection = k
176
+ break
177
+
178
+ final_tokens = (
179
+ [window1.tokens[0]] # CLS
180
+ + w1_tokens
181
+ + w2_tokens[tokens_intersection:]
182
+ + [window1.tokens[-1]] # SEP
183
+ )
184
+
185
+ w2_starting_offset = len(w1_tokens) - tokens_intersection
186
+
187
+ def merge_char_mapping(t2c1: dict, t2c2: dict) -> dict:
188
+ final_t2c = dict()
189
+ final_t2c.update(t2c1)
190
+ for t, c in t2c2.items():
191
+ t = int(t)
192
+ if t < tokens_intersection:
193
+ continue
194
+ final_t2c[str(t + w2_starting_offset)] = c
195
+ return final_t2c
196
+
197
+ return (
198
+ final_tokens,
199
+ merge_char_mapping(window1.token2char_start, window2.token2char_start),
200
+ merge_char_mapping(window1.token2char_end, window2.token2char_end),
201
+ )
202
+
203
+ @staticmethod
204
+ def _merge_words(
205
+ window1: RelikReaderSample, window2: RelikReaderSample
206
+ ) -> Tuple[list, dict, dict]:
207
+ w1_words = window1.words
208
+ w2_words = window2.words
209
+
210
+ # find intersection if any
211
+ words_intersection = 0
212
+ for k in reversed(range(1, len(w1_words))):
213
+ if w1_words[-k:] == w2_words[:k]:
214
+ words_intersection = k
215
+ break
216
+
217
+ final_words = w1_words + w2_words[words_intersection:]
218
+
219
+ w2_starting_offset = len(w1_words) - words_intersection
220
+
221
+ def merge_word_mapping(t2c1: dict, t2c2: dict) -> dict:
222
+ final_t2c = dict()
223
+ if t2c1 is None:
224
+ t2c1 = dict()
225
+ if t2c2 is None:
226
+ t2c2 = dict()
227
+ final_t2c.update(t2c1)
228
+ for t, c in t2c2.items():
229
+ t = int(t)
230
+ if t < words_intersection:
231
+ continue
232
+ final_t2c[str(t + w2_starting_offset)] = c
233
+ return final_t2c
234
+
235
+ return (
236
+ final_words,
237
+ merge_word_mapping(window1.token2word_start, window2.token2word_start),
238
+ merge_word_mapping(window1.token2word_end, window2.token2word_end),
239
+ )
240
+
241
+ @staticmethod
242
+ def _merge_span_annotation(
243
+ span_annotation1: List[list], span_annotation2: List[list]
244
+ ) -> List[list]:
245
+ uniq_store = set()
246
+ final_span_annotation_store = []
247
+ for span_annotation in itertools.chain(span_annotation1, span_annotation2):
248
+ span_annotation_id = tuple(span_annotation)
249
+ if span_annotation_id not in uniq_store:
250
+ uniq_store.add(span_annotation_id)
251
+ final_span_annotation_store.append(span_annotation)
252
+ return sorted(final_span_annotation_store, key=lambda x: x[0])
253
+
254
+ @staticmethod
255
+ def _merge_predictions(
256
+ window1: RelikReaderSample, window2: RelikReaderSample
257
+ ) -> Tuple[Set[Tuple[int, int, str]], dict]:
258
+ # a RelikReaderSample should have a filed called `predicted_spans`
259
+ # that stores the span-level predictions, or a filed called
260
+ # `predicted_triples` that stores the triple-level predictions
261
+
262
+ # span predictions
263
+ merged_span_predictions: Set = set()
264
+ merged_span_probabilities = dict()
265
+ # triple predictions
266
+ merged_triplet_predictions: Set = set()
267
+ merged_triplet_probs: Dict = dict()
268
+
269
+ if (
270
+ getattr(window1, "predicted_spans", None) is not None
271
+ and getattr(window2, "predicted_spans", None) is not None
272
+ ):
273
+ merged_span_predictions = set(window1.predicted_spans).union(
274
+ set(window2.predicted_spans)
275
+ )
276
+ merged_span_predictions = sorted(merged_span_predictions)
277
+ # probabilities
278
+ for span_prediction, predicted_probs in itertools.chain(
279
+ window1.probs_window_labels_chars.items()
280
+ if window1.probs_window_labels_chars is not None
281
+ else [],
282
+ window2.probs_window_labels_chars.items()
283
+ if window2.probs_window_labels_chars is not None
284
+ else [],
285
+ ):
286
+ if span_prediction not in merged_span_probabilities:
287
+ merged_span_probabilities[span_prediction] = predicted_probs
288
+
289
+ if (
290
+ getattr(window1, "predicted_triples", None) is not None
291
+ and getattr(window2, "predicted_triples", None) is not None
292
+ ):
293
+ # try to merge the triples predictions
294
+ # add offset to the second window
295
+ window1_triplets = [
296
+ (
297
+ merged_span_predictions.index(window1.predicted_spans[t[0]]),
298
+ t[1],
299
+ merged_span_predictions.index(window1.predicted_spans[t[2]]),
300
+ t[3]
301
+ )
302
+ for t in window1.predicted_triples
303
+ ]
304
+ window2_triplets = [
305
+ (
306
+ merged_span_predictions.index(window2.predicted_spans[t[0]]),
307
+ t[1],
308
+ merged_span_predictions.index(window2.predicted_spans[t[2]]),
309
+ t[3]
310
+ )
311
+ for t in window2.predicted_triples
312
+ ]
313
+ merged_triplet_predictions = set(window1_triplets).union(
314
+ set(window2_triplets)
315
+ )
316
+ merged_triplet_predictions = sorted(merged_triplet_predictions)
317
+ # for now no triplet probs, we don't need them for the moment
318
+
319
+ return (
320
+ merged_span_predictions,
321
+ merged_span_probabilities,
322
+ merged_triplet_predictions,
323
+ merged_triplet_probs,
324
+ )
325
+
326
+ @staticmethod
327
+ def _merge_candidates(window1: RelikReaderSample, window2: RelikReaderSample):
328
+ candidates = []
329
+ windows_candidates = []
330
+
331
+ # TODO: retro-compatibility
332
+ if getattr(window1, "candidates", None) is not None:
333
+ candidates = window1.candidates
334
+ if getattr(window2, "candidates", None) is not None:
335
+ candidates += window2.candidates
336
+
337
+ # TODO: retro-compatibility
338
+ if getattr(window1, "windows_candidates", None) is not None:
339
+ windows_candidates = window1.windows_candidates
340
+ if getattr(window2, "windows_candidates", None) is not None:
341
+ windows_candidates += window2.windows_candidates
342
+
343
+ # TODO: add programmatically
344
+ span_candidates = []
345
+ if getattr(window1, "span_candidates", None) is not None:
346
+ span_candidates = window1.span_candidates
347
+ if getattr(window2, "span_candidates", None) is not None:
348
+ span_candidates += window2.span_candidates
349
+
350
+ triplet_candidates = []
351
+ if getattr(window1, "triplet_candidates", None) is not None:
352
+ triplet_candidates = window1.triplet_candidates
353
+ if getattr(window2, "triplet_candidates", None) is not None:
354
+ triplet_candidates += window2.triplet_candidates
355
+
356
+ # make them unique
357
+ candidates = list(set(candidates))
358
+ windows_candidates = list(set(windows_candidates))
359
+
360
+ span_candidates = list(set(span_candidates))
361
+ triplet_candidates = list(set(triplet_candidates))
362
+
363
+ return candidates, windows_candidates, span_candidates, triplet_candidates
364
+
365
+ def _merge_window_pair(
366
+ self,
367
+ window1: RelikReaderSample,
368
+ window2: RelikReaderSample,
369
+ ) -> RelikReaderSample:
370
+ merging_output = dict()
371
+
372
+ if getattr(window1, "doc_id", None) is not None:
373
+ assert window1.doc_id == window2.doc_id
374
+
375
+ if getattr(window1, "offset", None) is not None:
376
+ assert (
377
+ window1.offset < window2.offset
378
+ ), f"window 2 offset ({window2.offset}) is smaller that window 1 offset({window1.offset})"
379
+
380
+ merging_output["doc_id"] = window1.doc_id
381
+ merging_output["offset"] = window2.offset
382
+
383
+ m_tokens, m_token2char_start, m_token2char_end = self._merge_tokens(
384
+ window1, window2
385
+ )
386
+
387
+ m_words, m_token2word_start, m_token2word_end = self._merge_words(
388
+ window1, window2
389
+ )
390
+
391
+ (
392
+ m_candidates,
393
+ m_windows_candidates,
394
+ m_span_candidates,
395
+ m_triplet_candidates,
396
+ ) = self._merge_candidates(window1, window2)
397
+
398
+ window_labels = None
399
+ if getattr(window1, "window_labels", None) is not None:
400
+ window_labels = self._merge_span_annotation(
401
+ window1.window_labels, window2.window_labels
402
+ )
403
+
404
+ (
405
+ predicted_spans,
406
+ predicted_spans_probs,
407
+ predicted_triples,
408
+ predicted_triples_probs,
409
+ ) = self._merge_predictions(window1, window2)
410
+
411
+ merging_output.update(
412
+ dict(
413
+ tokens=m_tokens,
414
+ words=m_words,
415
+ token2char_start=m_token2char_start,
416
+ token2char_end=m_token2char_end,
417
+ token2word_start=m_token2word_start,
418
+ token2word_end=m_token2word_end,
419
+ window_labels=window_labels,
420
+ candidates=m_candidates,
421
+ span_candidates=m_span_candidates,
422
+ triplet_candidates=m_triplet_candidates,
423
+ windows_candidates=m_windows_candidates,
424
+ predicted_spans=predicted_spans,
425
+ predicted_spans_probs=predicted_spans_probs,
426
+ predicted_triples=predicted_triples,
427
+ predicted_triples_probs=predicted_triples_probs,
428
+ )
429
+ )
430
+
431
+ return RelikReaderSample(**merging_output)
relik/inference/gerbil.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import logging
4
+ import os
5
+ from pathlib import Path
6
+ import re
7
+ import sys
8
+ from http.server import BaseHTTPRequestHandler, HTTPServer
9
+ from typing import Iterator, List, Optional, Tuple
10
+ from urllib import parse
11
+
12
+ from relik.inference.annotator import Relik
13
+ from relik.inference.data.objects import RelikOutput
14
+
15
+ # sys.path += ['../']
16
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
17
+
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ class GerbilAlbyManager:
23
+ def __init__(
24
+ self,
25
+ annotator: Optional[Relik] = None,
26
+ response_logger_dir: Optional[str] = None,
27
+ ) -> None:
28
+ self.annotator = annotator
29
+ self.response_logger_dir = response_logger_dir
30
+ self.predictions_counter = 0
31
+ self.labels_mapping = None
32
+
33
+ def annotate(self, document: str):
34
+ relik_output: RelikOutput = self.annotator(
35
+ document, retriever_batch_size=2, reader_batch_size=1
36
+ )
37
+ annotations = [(ss, se, l) for ss, se, l, _ in relik_output.spans]
38
+ if self.labels_mapping is not None:
39
+ return [
40
+ (ss, se, self.labels_mapping.get(l, l)) for ss, se, l in annotations
41
+ ]
42
+ return annotations
43
+
44
+ def set_mapping_file(self, mapping_file_path: str):
45
+ with open(mapping_file_path) as f:
46
+ labels_mapping = json.load(f)
47
+ self.labels_mapping = {v: k for k, v in labels_mapping.items()}
48
+
49
+ def write_response_bundle(
50
+ self,
51
+ document: str,
52
+ new_document: str,
53
+ annotations: list,
54
+ mapped_annotations: list,
55
+ ) -> None:
56
+ if self.response_logger_dir is None:
57
+ return
58
+
59
+ if not os.path.isdir(self.response_logger_dir):
60
+ os.mkdir(self.response_logger_dir)
61
+
62
+ with open(
63
+ f"{self.response_logger_dir}/{self.predictions_counter}.json", "w"
64
+ ) as f:
65
+ out_json_obj = dict(
66
+ document=document,
67
+ new_document=new_document,
68
+ annotations=annotations,
69
+ mapped_annotations=mapped_annotations,
70
+ )
71
+
72
+ out_json_obj["span_annotations"] = [
73
+ (ss, se, document[ss:se], label) for (ss, se, label) in annotations
74
+ ]
75
+
76
+ out_json_obj["span_mapped_annotations"] = [
77
+ (ss, se, new_document[ss:se], label)
78
+ for (ss, se, label) in mapped_annotations
79
+ ]
80
+
81
+ json.dump(out_json_obj, f, indent=2)
82
+
83
+ self.predictions_counter += 1
84
+
85
+
86
+ manager = GerbilAlbyManager()
87
+
88
+
89
+ def preprocess_document(document: str) -> Tuple[str, List[Tuple[int, int]]]:
90
+ pattern_subs = {
91
+ "-LPR- ": " (",
92
+ "-RPR-": ")",
93
+ "\n\n": "\n",
94
+ "-LRB-": "(",
95
+ "-RRB-": ")",
96
+ '","': ",",
97
+ }
98
+
99
+ document_acc = document
100
+ curr_offset = 0
101
+ char2offset = []
102
+
103
+ matchings = re.finditer("({})".format("|".join(pattern_subs)), document)
104
+ for span_matching in sorted(matchings, key=lambda x: x.span()[0]):
105
+ span_start, span_end = span_matching.span()
106
+ span_start -= curr_offset
107
+ span_end -= curr_offset
108
+
109
+ span_text = document_acc[span_start:span_end]
110
+ span_sub = pattern_subs[span_text]
111
+ document_acc = document_acc[:span_start] + span_sub + document_acc[span_end:]
112
+
113
+ offset = len(span_text) - len(span_sub)
114
+ curr_offset += offset
115
+
116
+ char2offset.append((span_start + len(span_sub), curr_offset))
117
+
118
+ return document_acc, char2offset
119
+
120
+
121
+ def map_back_annotations(
122
+ annotations: List[Tuple[int, int, str]], char_mapping: List[Tuple[int, int]]
123
+ ) -> Iterator[Tuple[int, int, str]]:
124
+ def map_char(char_idx: int) -> int:
125
+ current_offset = 0
126
+ for offset_idx, offset_value in char_mapping:
127
+ if char_idx >= offset_idx:
128
+ current_offset = offset_value
129
+ else:
130
+ break
131
+ return char_idx + current_offset
132
+
133
+ for ss, se, label in annotations:
134
+ yield map_char(ss), map_char(se), label
135
+
136
+
137
+ def annotate(document: str) -> List[Tuple[int, int, str]]:
138
+ new_document, mapping = preprocess_document(document)
139
+ logger.info("Mapping: " + str(mapping))
140
+ logger.info("Document: " + str(document))
141
+ annotations = [
142
+ (cs, ce, label.replace(" ", "_"))
143
+ for cs, ce, label in manager.annotate(new_document)
144
+ ]
145
+ logger.info("New document: " + str(new_document))
146
+ mapped_annotations = (
147
+ list(map_back_annotations(annotations, mapping))
148
+ if len(mapping) > 0
149
+ else annotations
150
+ )
151
+
152
+ logger.info(
153
+ "Annotations: "
154
+ + str([(ss, se, document[ss:se], ann) for ss, se, ann in mapped_annotations])
155
+ )
156
+
157
+ manager.write_response_bundle(
158
+ document, new_document, mapped_annotations, annotations
159
+ )
160
+
161
+ if not all(
162
+ [
163
+ new_document[ss:se] == document[mss:mse]
164
+ for (mss, mse, _), (ss, se, _) in zip(mapped_annotations, annotations)
165
+ ]
166
+ ):
167
+ diff_mappings = [
168
+ (new_document[ss:se], document[mss:mse])
169
+ for (mss, mse, _), (ss, se, _) in zip(mapped_annotations, annotations)
170
+ ]
171
+ return None
172
+ assert all(
173
+ [
174
+ document[mss:mse] == new_document[ss:se]
175
+ for (mss, mse, _), (ss, se, _) in zip(mapped_annotations, annotations)
176
+ ]
177
+ ), (mapped_annotations, annotations)
178
+
179
+ return [(cs, ce - cs, label) for cs, ce, label in mapped_annotations]
180
+
181
+
182
+ class GetHandler(BaseHTTPRequestHandler):
183
+ def do_POST(self):
184
+ content_length = int(self.headers["Content-Length"])
185
+ post_data = self.rfile.read(content_length)
186
+ self.send_response(200)
187
+ self.end_headers()
188
+ doc_text = read_json(post_data)
189
+ # try:
190
+ response = annotate(doc_text)
191
+
192
+ self.wfile.write(bytes(json.dumps(response), "utf-8"))
193
+ return
194
+
195
+
196
+ def read_json(post_data):
197
+ data = json.loads(post_data.decode("utf-8"))
198
+ # logger.info("received data:", data)
199
+ text = data["text"]
200
+ # spans = [(int(j["start"]), int(j["length"])) for j in data["spans"]]
201
+ return text
202
+
203
+
204
+ def parse_args() -> argparse.Namespace:
205
+ parser = argparse.ArgumentParser()
206
+ parser.add_argument("--relik-model-name", required=True)
207
+ parser.add_argument("--responses-log-dir")
208
+ parser.add_argument("--log-file", default="experiments/logging.txt")
209
+ parser.add_argument("--mapping-file")
210
+ return parser.parse_args()
211
+
212
+
213
+ def main():
214
+ args = parse_args()
215
+
216
+ responses_log_dir = Path(args.responses_log_dir)
217
+ responses_log_dir.mkdir(parents=True, exist_ok=True)
218
+
219
+ # init manager
220
+ manager.response_logger_dir = args.responses_log_dir
221
+ manager.annotator = Relik.from_pretrained(
222
+ args.relik_model_name,
223
+ device="cuda",
224
+ # document_index_device="cpu",
225
+ # document_index_precision="fp32",
226
+ # reader_device="cpu",
227
+ precision="fp16", # , reader_device="cpu", reader_precision="fp32"
228
+ dataset_kwargs={"use_nme": True}
229
+ )
230
+
231
+ # print("Debugging, not using you relik model but an hardcoded one.")
232
+ # manager.annotator = Relik(
233
+ # question_encoder="riccorl/relik-retriever-aida-blink-pretrain-omniencoder",
234
+ # document_index="riccorl/index-relik-retriever-aida-blink-pretrain-omniencoder",
235
+ # reader="relik/reader/models/relik-reader-deberta-base-new-data",
236
+ # window_size=32,
237
+ # window_stride=16,
238
+ # candidates_preprocessing_fn=(lambda x: x.split("<def>")[0].strip()),
239
+ # )
240
+
241
+ if args.mapping_file is not None:
242
+ manager.set_mapping_file(args.mapping_file)
243
+
244
+ # port = 6654
245
+ port = 5555
246
+ server = HTTPServer(("localhost", port), GetHandler)
247
+ logger.info(f"Starting server at http://localhost:{port}")
248
+
249
+ # Create a file handler and set its level
250
+ file_handler = logging.FileHandler(args.log_file)
251
+ file_handler.setLevel(logging.DEBUG)
252
+
253
+ # Create a log formatter and set it on the handler
254
+ formatter = logging.Formatter(
255
+ "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
256
+ )
257
+ file_handler.setFormatter(formatter)
258
+
259
+ # Add the file handler to the logger
260
+ logger.addHandler(file_handler)
261
+
262
+ try:
263
+ server.serve_forever()
264
+ except KeyboardInterrupt:
265
+ exit(0)
266
+
267
+
268
+ if __name__ == "__main__":
269
+ main()
relik/inference/serve/__init__.py ADDED
File without changes
relik/inference/serve/backend/__init__.py ADDED
File without changes
relik/inference/serve/backend/fastapi.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from pathlib import Path
4
+ from typing import List, Union
5
+ import psutil
6
+
7
+ import torch
8
+
9
+ from relik.common.utils import is_package_available
10
+ from relik.inference.annotator import Relik
11
+
12
+ if not is_package_available("fastapi"):
13
+ raise ImportError(
14
+ "FastAPI is not installed. Please install FastAPI with `pip install relik[serve]`."
15
+ )
16
+ from fastapi import FastAPI, HTTPException, APIRouter
17
+
18
+
19
+ from relik.common.log import get_logger
20
+ from relik.inference.serve.backend.utils import (
21
+ RayParameterManager,
22
+ ServerParameterManager,
23
+ )
24
+
25
+ logger = get_logger(__name__, level=logging.INFO)
26
+
27
+ VERSION = {} # type: ignore
28
+ with open(
29
+ Path(__file__).parent.parent.parent.parent / "version.py", "r"
30
+ ) as version_file:
31
+ exec(version_file.read(), VERSION)
32
+
33
+ # Env variables for server
34
+ SERVER_MANAGER = ServerParameterManager()
35
+ RAY_MANAGER = RayParameterManager()
36
+
37
+
38
+ class RelikServer:
39
+ def __init__(
40
+ self,
41
+ relik_pretrained: str | None = None,
42
+ device: str = "cpu",
43
+ retriever_device: str | None = None,
44
+ document_index_device: str | None = None,
45
+ reader_device: str | None = None,
46
+ precision: str | int | torch.dtype = 32,
47
+ retriever_precision: str | int | torch.dtype | None = None,
48
+ document_index_precision: str | int | torch.dtype | None = None,
49
+ reader_precision: str | int | torch.dtype | None = None,
50
+ annotation_type: str = "char",
51
+ **kwargs,
52
+ ):
53
+ num_threads = os.getenv("TORCH_NUM_THREADS", psutil.cpu_count(logical=False))
54
+ torch.set_num_threads(num_threads)
55
+ logger.info(f"Torch is running on {num_threads} threads.")
56
+ # parameters
57
+ logger.info(f"RELIK_PRETRAINED: {relik_pretrained}")
58
+ self.relik_pretrained = relik_pretrained
59
+ logger.info(f"DEVICE: {device}")
60
+ self.device = device
61
+ if retriever_device is not None:
62
+ logger.info(f"RETRIEVER_DEVICE: {retriever_device}")
63
+ self.retriever_device = retriever_device or device
64
+ if document_index_device is not None:
65
+ logger.info(f"INDEX_DEVICE: {document_index_device}")
66
+ self.document_index_device = document_index_device or retriever_device
67
+ if reader_device is not None:
68
+ logger.info(f"READER_DEVICE: {reader_device}")
69
+ self.reader_device = reader_device
70
+ logger.info(f"PRECISION: {precision}")
71
+ self.precision = precision
72
+ if retriever_precision is not None:
73
+ logger.info(f"RETRIEVER_PRECISION: {retriever_precision}")
74
+ self.retriever_precision = retriever_precision or precision
75
+ if document_index_precision is not None:
76
+ logger.info(f"INDEX_PRECISION: {document_index_precision}")
77
+ self.document_index_precision = document_index_precision or precision
78
+ if reader_precision is not None:
79
+ logger.info(f"READER_PRECISION: {reader_precision}")
80
+ self.reader_precision = reader_precision or precision
81
+ logger.info(f"ANNOTATION_TYPE: {annotation_type}")
82
+ self.annotation_type = annotation_type
83
+
84
+ self.relik = Relik.from_pretrained(
85
+ self.relik_pretrained,
86
+ device=self.device,
87
+ retriever_device=self.retriever_device,
88
+ document_index_device=self.document_index_device,
89
+ reader_device=self.reader_device,
90
+ precision=self.precision,
91
+ retriever_precision=self.retriever_precision,
92
+ document_index_precision=self.document_index_precision,
93
+ reader_precision=self.reader_precision,
94
+ )
95
+
96
+ self.router = APIRouter()
97
+ self.router.add_api_route("/api/relik", self.relik_endpoint, methods=["POST"])
98
+
99
+ logger.info("RelikServer initialized.")
100
+
101
+ # @serve.batch()
102
+ async def __call__(self, text: List[str]) -> List:
103
+ return self.relik(text, annotation_type=self.annotation_type)
104
+
105
+ # @app.post("/api/relik")
106
+ async def relik_endpoint(self, text: Union[str, List[str]]):
107
+ try:
108
+ # get predictions for the retriever
109
+ return await self(text)
110
+ except Exception as e:
111
+ # log the entire stack trace
112
+ logger.exception(e)
113
+ raise HTTPException(status_code=500, detail=f"Server Error: {e}")
114
+
115
+
116
+ app = FastAPI(
117
+ title="ReLiK",
118
+ version=VERSION["VERSION"],
119
+ description="ReLiK REST API",
120
+ )
121
+ server = RelikServer(**vars(SERVER_MANAGER))
122
+ app.include_router(server.router)
relik/inference/serve/backend/ray.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from pathlib import Path
4
+ from typing import List, Union
5
+ import psutil
6
+
7
+ import torch
8
+
9
+ from relik.common.utils import is_package_available
10
+ from relik.inference.annotator import Relik
11
+
12
+ if not is_package_available("fastapi"):
13
+ raise ImportError(
14
+ "FastAPI is not installed. Please install FastAPI with `pip install relik[serve]`."
15
+ )
16
+ from fastapi import FastAPI, HTTPException
17
+
18
+ if not is_package_available("ray"):
19
+ raise ImportError(
20
+ "Ray is not installed. Please install Ray with `pip install relik[serve]`."
21
+ )
22
+ from ray import serve
23
+
24
+ from relik.common.log import get_logger
25
+ from relik.inference.serve.backend.utils import (
26
+ RayParameterManager,
27
+ ServerParameterManager,
28
+ )
29
+
30
logger = get_logger(__name__, level=logging.INFO)

# Read the package version by executing version.py into a throwaway namespace
# (avoids importing the whole package at server startup).
VERSION = {}  # type: ignore
with open(
    Path(__file__).parent.parent.parent.parent / "version.py", "r"
) as version_file:
    exec(version_file.read(), VERSION)

# Env variables for server
SERVER_MANAGER = ServerParameterManager()
RAY_MANAGER = RayParameterManager()

# FastAPI app that Ray Serve mounts via serve.ingress on the deployment below.
app = FastAPI(
    title="ReLiK",
    version=VERSION["VERSION"],
    description="ReLiK REST API",
)
47
+
48
+
49
@serve.deployment(
    ray_actor_options={
        # Reserve GPUs for the replica only if at least one component is
        # configured to run on CUDA; otherwise request a CPU-only actor.
        "num_gpus": RAY_MANAGER.num_gpus
        if (
            SERVER_MANAGER.device == "cuda"
            or SERVER_MANAGER.retriever_device == "cuda"
            or SERVER_MANAGER.reader_device == "cuda"
        )
        else 0
    },
    autoscaling_config={
        "min_replicas": RAY_MANAGER.min_replicas,
        "max_replicas": RAY_MANAGER.max_replicas,
    },
)
@serve.ingress(app)
class RelikServer:
    """Ray Serve deployment wrapping a pretrained ReLiK pipeline.

    The FastAPI ``app`` is mounted on this deployment via ``serve.ingress``,
    so ``relik_endpoint`` serves POST /api/relik on every replica.
    """

    def __init__(
        self,
        relik_pretrained: str | None = None,
        device: str = "cpu",
        retriever_device: str | None = None,
        document_index_device: str | None = None,
        reader_device: str | None = None,
        precision: str | int | torch.dtype = 32,
        retriever_precision: str | int | torch.dtype | None = None,
        document_index_precision: str | int | torch.dtype | None = None,
        reader_precision: str | int | torch.dtype | None = None,
        annotation_type: str = "char",
        retriever_batch_size: int = 32,
        reader_batch_size: int = 32,
        relik_config_override: dict | None = None,
        **kwargs,
    ):
        """Load the ReLiK pipeline with per-component device/precision settings.

        Args:
            relik_pretrained: model name or path passed to ``Relik.from_pretrained``.
            device: default device for all components.
            retriever_device / document_index_device / reader_device: optional
                per-component overrides.
            precision and the *_precision variants: numeric precision, with the
                same fallback scheme as the devices.
            annotation_type: span annotation granularity (e.g. "char").
            retriever_batch_size / reader_batch_size: inference batch sizes.
            relik_config_override: extra kwargs forwarded to ``from_pretrained``.
            **kwargs: ignored; tolerated so env-driven construction can pass extras.
        """
        # BUG FIX: os.getenv returns a *string* when the variable is set, but
        # torch.set_num_threads requires an int. Also guard against
        # psutil.cpu_count(logical=False) returning None on some platforms.
        num_threads = int(
            os.getenv("TORCH_NUM_THREADS", psutil.cpu_count(logical=False) or 1)
        )
        torch.set_num_threads(num_threads)
        logger.info(f"Torch is running on {num_threads} threads.")

        # parameters
        logger.info(f"RELIK_PRETRAINED: {relik_pretrained}")
        self.relik_pretrained = relik_pretrained

        if relik_config_override is None:
            relik_config_override = {}
        logger.info(f"RELIK_CONFIG_OVERRIDE: {relik_config_override}")
        self.relik_config_override = relik_config_override

        logger.info(f"DEVICE: {device}")
        self.device = device

        if retriever_device is not None:
            logger.info(f"RETRIEVER_DEVICE: {retriever_device}")
        self.retriever_device = retriever_device or device

        if document_index_device is not None:
            logger.info(f"INDEX_DEVICE: {document_index_device}")
        # NOTE(review): falls back to the *raw* retriever_device argument
        # (which may be None), not to self.retriever_device — confirm intended.
        self.document_index_device = document_index_device or retriever_device

        if reader_device is not None:
            logger.info(f"READER_DEVICE: {reader_device}")
        # NOTE(review): unlike the retriever, there is no ``or device``
        # fallback here — confirm intended.
        self.reader_device = reader_device

        logger.info(f"PRECISION: {precision}")
        self.precision = precision

        if retriever_precision is not None:
            logger.info(f"RETRIEVER_PRECISION: {retriever_precision}")
        self.retriever_precision = retriever_precision or precision

        if document_index_precision is not None:
            logger.info(f"INDEX_PRECISION: {document_index_precision}")
        self.document_index_precision = document_index_precision or precision

        if reader_precision is not None:
            logger.info(f"READER_PRECISION: {reader_precision}")
        self.reader_precision = reader_precision or precision

        logger.info(f"ANNOTATION_TYPE: {annotation_type}")
        self.annotation_type = annotation_type

        self.relik = Relik.from_pretrained(
            self.relik_pretrained,
            device=self.device,
            retriever_device=self.retriever_device,
            document_index_device=self.document_index_device,
            reader_device=self.reader_device,
            precision=self.precision,
            retriever_precision=self.retriever_precision,
            document_index_precision=self.document_index_precision,
            reader_precision=self.reader_precision,
            **self.relik_config_override,
        )

        self.retriever_batch_size = retriever_batch_size
        self.reader_batch_size = reader_batch_size

    # @serve.batch()
    async def handle_batch(self, text: List[str]) -> List:
        """Run the full ReLiK pipeline on a batch of input texts."""
        return self.relik(
            text,
            annotation_type=self.annotation_type,
            retriever_batch_size=self.retriever_batch_size,
            reader_batch_size=self.reader_batch_size,
        )

    @app.post("/api/relik")
    async def relik_endpoint(self, text: Union[str, List[str]]):
        """POST /api/relik: annotate ``text``; map any failure to HTTP 500."""
        try:
            # get predictions for the retriever
            return await self.handle_batch(text)
        except Exception as e:
            # log the entire stack trace
            logger.exception(e)
            raise HTTPException(status_code=500, detail=f"Server Error: {e}")
163
+
164
+
165
# Deployment handle built from env-derived parameters; `serve run` deploys it.
server = RelikServer.bind(**vars(SERVER_MANAGER))
relik/inference/serve/backend/utils.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import os
3
+ from dataclasses import dataclass
4
+
5
+
6
+ @dataclass
7
+ class ServerParameterManager:
8
+ relik_pretrained: str = os.environ.get("RELIK_PRETRAINED", None)
9
+ device: str = os.environ.get("DEVICE", "cpu")
10
+ retriever_device: str | None = os.environ.get("RETRIEVER_DEVICE", None)
11
+ document_index_device: str | None = os.environ.get("INDEX_DEVICE", None)
12
+ reader_device: str | None = os.environ.get("READER_DEVICE", None)
13
+ precision: int | str | None = os.environ.get("PRECISION", "fp32")
14
+ retriever_precision: int | str | None = os.environ.get("RETRIEVER_PRECISION", None)
15
+ document_index_precision: int | str | None = os.environ.get("INDEX_PRECISION", None)
16
+ reader_precision: int | str | None = os.environ.get("READER_PRECISION", None)
17
+ annotation_type: str = os.environ.get("ANNOTATION_TYPE", "char")
18
+ question_encoder: str = os.environ.get("QUESTION_ENCODER", None)
19
+ passage_encoder: str = os.environ.get("PASSAGE_ENCODER", None)
20
+ document_index: str = os.environ.get("DOCUMENT_INDEX", None)
21
+ reader_encoder: str = os.environ.get("READER_ENCODER", None)
22
+ top_k: int = int(os.environ.get("TOP_K", 100))
23
+ use_faiss: bool = os.environ.get("USE_FAISS", False)
24
+ retriever_batch_size: int = int(os.environ.get("RETRIEVER_BATCH_SIZE", 32))
25
+ reader_batch_size: int = int(os.environ.get("READER_BATCH_SIZE", 32))
26
+ window_size: int = int(os.environ.get("WINDOW_SIZE", 32))
27
+ window_stride: int = int(os.environ.get("WINDOW_SIZE", 16))
28
+ split_on_spaces: bool = os.environ.get("SPLIT_ON_SPACES", False)
29
+ # relik_config_override: dict = ast.literal_eval(
30
+ # os.environ.get("RELIK_CONFIG_OVERRIDE", None)
31
+ # )
32
+
33
+
34
class RayParameterManager:
    """Ray Serve scaling knobs read from the environment once at startup."""

    def __init__(self) -> None:
        env = os.environ
        # GPUs reserved per replica (see the @serve.deployment actor options).
        self.num_gpus = int(env.get("NUM_GPUS", 1))
        # Autoscaling bounds for the deployment.
        self.min_replicas = int(env.get("MIN_REPLICAS", 1))
        self.max_replicas = int(env.get("MAX_REPLICAS", 1))
relik/inference/serve/frontend/__init__.py ADDED
File without changes
relik/inference/serve/frontend/relik_front.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+
4
+ import requests
5
+ import streamlit as st
6
+ from spacy import displacy
7
+ from streamlit_extras.badges import badge
8
+ from streamlit_extras.stylable_container import stylable_container
9
+
10
# Backend endpoint (host:port/path) for annotation requests; env-overridable.
RELIK = os.getenv("RELIK", "localhost:8000/api/entities")

import random
13
+
14
+
15
def get_random_color(ents):
    """Map each entity label to a distinct pastel color, picked at random."""
    palette = generate_pastel_colors(len(ents))
    assignment = {}
    for label in ents:
        # Remove the chosen color so no two labels share one.
        pick = random.randint(0, len(palette) - 1)
        assignment[label] = palette.pop(pick)
    return assignment
21
+
22
+
23
def floatrange(start, stop, steps):
    """Return ``steps`` evenly spaced floats from ``start`` to ``stop`` inclusive."""
    if int(steps) == 1:
        # Degenerate case: a single point is the endpoint itself.
        return [stop]
    span = stop - start
    denom = float(steps) - 1
    return [start + i * span / denom for i in range(steps)]
29
+
30
+
31
def hsl_to_rgb(h, s, l):
    """Convert HSL (all components in [0, 1]) to an (r, g, b) tuple of 0-255 ints."""

    def channel(v1, v2, hue):
        # Wrap the hue into [0, 1] before picking the piecewise segment.
        while hue < 0.0:
            hue += 1.0
        while hue > 1.0:
            hue -= 1.0
        if 6 * hue < 1.0:
            return v1 + (v2 - v1) * 6.0 * hue
        if 2 * hue < 1.0:
            return v2
        if 3 * hue < 2.0:
            return v1 + (v2 - v1) * ((2.0 / 3.0) - hue) * 6.0
        return v1

    # Achromatic case (s == 0): every channel equals the scaled lightness.
    r = g = b = l * 255
    if s != 0.0:
        if l < 0.5:
            v2 = l * (1.0 + s)
        else:
            v2 = (l + s) - (s * l)
        v1 = 2.0 * l - v2
        r = 255 * channel(v1, v2, h + (1.0 / 3.0))
        g = 255 * channel(v1, v2, h)
        b = 255 * channel(v1, v2, h - (1.0 / 3.0))

    return int(round(r)), int(round(g)), int(round(b))
60
+
61
+
62
def generate_pastel_colors(n):
    """Return ``n`` distinct pastel colors as HTML hex strings.

    Input:
      n (integer) : The number of colors to return

    Output:
      A list of colors in HTML notation (e.g. ['#99c2ff', ...])
    """
    if n == 0:
        return []

    # Colors are points around the chromatic circle (hue) in HSL space with
    # fixed saturation and lightness. We generate n + 1 hues and drop the last
    # one, because hue 0 and hue 1 are the same color.
    start_hue = 0.6  # 0=red 1/3=0.333=green 2/3=0.666=blue
    saturation = 1.0
    lightness = 0.8
    hues = floatrange(start_hue, start_hue + 1, n + 1)[:-1]
    return ["#%02x%02x%02x" % hsl_to_rgb(hue, saturation, lightness) for hue in hues]
89
+
90
+
91
def set_sidebar(css):
    """Render the Streamlit sidebar: injected CSS, Sapienza NLP logo, links."""
    # Font Awesome stylesheet plus a link; formatted per (url, label) pair.
    white_link_wrapper = "<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.2/css/all.min.css'><a href='{}'>{}</a>"
    with st.sidebar:
        # Inject the page stylesheet (sidebar colors, link styling).
        st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
        st.image(
            "http://nlp.uniroma1.it/static/website/sapienza-nlp-logo-wh.svg",
            use_column_width=True,
        )
        st.markdown("## ReLiK")
        st.write(
            f"""
            - {white_link_wrapper.format("#", "<i class='fa-solid fa-file'></i>&nbsp; Paper")}
            - {white_link_wrapper.format("https://github.com/SapienzaNLP/relik", "<i class='fa-brands fa-github'></i>&nbsp; GitHub")}
            - {white_link_wrapper.format("https://hub.docker.com/repository/docker/sapienzanlp/relik", "<i class='fa-brands fa-docker'></i>&nbsp; Docker Hub")}
            """,
            unsafe_allow_html=True,
        )
        st.markdown("## Sapienza NLP")
        st.write(
            f"""
            - {white_link_wrapper.format("https://nlp.uniroma1.it", "<i class='fa-solid fa-globe'></i>&nbsp; Webpage")}
            - {white_link_wrapper.format("https://github.com/SapienzaNLP", "<i class='fa-brands fa-github'></i>&nbsp; GitHub")}
            - {white_link_wrapper.format("https://twitter.com/SapienzaNLP", "<i class='fa-brands fa-twitter'></i>&nbsp; Twitter")}
            - {white_link_wrapper.format("https://www.linkedin.com/company/79434450", "<i class='fa-brands fa-linkedin'></i>&nbsp; LinkedIn")}
            """,
            unsafe_allow_html=True,
        )
118
+
119
+
120
def get_el_annotations(response):
    """Adapt a backend response to displacy's manual 'ent' format.

    Renames the "labels" key to "ents" (in place) and builds the displacy
    render options with one random pastel color per distinct label.
    """
    # displacy expects the spans under "ents"; the backend calls them "labels".
    response["ents"] = response.pop("labels")
    labels_in_text = {span["label"] for span in response["ents"]}
    options = {"ents": labels_in_text, "colors": get_random_color(labels_in_text)}
    return response, options
126
+
127
+
128
def set_intro(css):
    """Render the page header: title, subtitle, and GitHub/PyPI badges.

    NOTE(review): ``css`` is currently unused here; the stylesheet is injected
    by ``set_sidebar`` — confirm whether the parameter can be dropped.
    """
    # intro
    st.markdown("# ReLik")
    st.markdown(
        "### Retrieve, Read and LinK: Fast and Accurate Entity Linking and Relation Extraction on an Academic Budget"
    )
    # st.markdown(
    #     "This is a front-end for the paper [Universal Semantic Annotator: the First Unified API "
    #     "for WSD, SRL and Semantic Parsing](https://www.researchgate.net/publication/360671045_Universal_Semantic_Annotator_the_First_Unified_API_for_WSD_SRL_and_Semantic_Parsing), which will be presented at LREC 2022 by "
    #     "[Riccardo Orlando](https://riccorl.github.io), [Simone Conia](https://c-simone.github.io/), "
    #     "[Stefano Faralli](https://corsidilaurea.uniroma1.it/it/users/stefanofaralliuniroma1it), and [Roberto Navigli](https://www.diag.uniroma1.it/navigli/)."
    # )
    badge(type="github", name="sapienzanlp/relik")
    badge(type="pypi", name="relik")
142
+
143
+
144
def run_client():
    """Drive the Streamlit EL front-end: layout, input, backend call, render."""
    # Page stylesheet shipped next to this module.
    with open(Path(__file__).parent / "style.css") as f:
        css = f.read()

    st.set_page_config(
        page_title="ReLik",
        page_icon="🦮",
        layout="wide",
    )
    set_sidebar(css)
    set_intro(css)

    # text input
    text = st.text_area(
        "Enter Text Below:",
        value="Obama went to Rome for a quick vacation.",
        height=200,
        max_chars=500,
    )

    # Styled container so the submit button picks up the brand color.
    with stylable_container(
        key="annotate_button",
        css_styles="""
            button {
                background-color: #802433;
                color: white;
                border-radius: 25px;
            }
            """,
    ):
        submit = st.button("Annotate")
        # submit = st.button("Run")

    # ReLik API call
    if submit:
        text = text.strip()
        if text:
            st.markdown("####")
            st.markdown("#### Entity Linking")
            with st.spinner(text="In progress"):
                # POST the raw text to the backend service (see RELIK env var).
                response = requests.post(RELIK, json=text)
                if response.status_code != 200:
                    st.error("Error: {}".format(response.status_code))
                else:
                    response = response.json()

                    # Entity Linking
                    # with stylable_container(
                    #     key="container_with_border",
                    #     css_styles="""
                    #         {
                    #             border: 1px solid rgba(49, 51, 63, 0.2);
                    #             border-radius: 0.5rem;
                    #             padding: 0.5rem;
                    #             padding-bottom: 2rem;
                    #         }
                    #         """,
                    # ):
                    #     st.markdown("##")
                    # Render entities with displacy's manual 'ent' mode.
                    dict_of_ents, options = get_el_annotations(response=response)
                    display = displacy.render(
                        dict_of_ents, manual=True, style="ent", options=options
                    )
                    # Newlines interfere with st.write's HTML rendering.
                    display = display.replace("\n", " ")
                    # wsd_display = re.sub(
                    #     r"(wiki::\d+\w)",
                    #     r"<a href='https://babelnet.org/synset?id=\g<1>&orig=\g<1>&lang={}'>\g<1></a>".format(
                    #         language.upper()
                    #     ),
                    #     wsd_display,
                    # )
                    with st.container():
                        st.write(display, unsafe_allow_html=True)

                    st.markdown("####")
                    st.markdown("#### Relation Extraction")

                    # RE output is not wired into this page yet.
                    with st.container():
                        st.write("Coming :)", unsafe_allow_html=True)

        else:
            st.error("Please enter some text.")
226
+
227
+
228
# Script entry point: launch the Streamlit client.
if __name__ == "__main__":
    run_client()
relik/inference/serve/frontend/relik_re_front.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from datetime import datetime as dt
3
+ from pathlib import Path
4
+
5
+ import requests
6
+ import spacy
7
+ import streamlit as st
8
+ import streamlit.components.v1 as components
9
+ from pyvis.network import Network
10
+ from spacy import displacy
11
+ from spacy.tokens import Doc
12
+ from streamlit_extras.badges import badge
13
+ from streamlit_extras.stylable_container import stylable_container
14
+ from utils import get_random_color, visualize_parser
15
+
16
+ from relik import Relik
17
+
18
+ # RELIK = os.getenv("RELIK", "localhost:8000/api/relik")
19
+
20
# Defaults for the Streamlit session-state keys this page manages.
state_variables = {"has_run_free": False, "html_free": ""}
21
+
22
+
23
def init_state_variables():
    """Seed st.session_state with this page's defaults (first run only)."""
    for key, default in state_variables.items():
        # Never clobber a value the user session already holds.
        if key in st.session_state:
            continue
        st.session_state[key] = default
27
+
28
+
29
def free_reset_session():
    """Remove this page's keys from st.session_state so the page starts fresh.

    Raises KeyError if a key is absent (same as the original ``del``).
    """
    for key in state_variables:
        del st.session_state[key]
32
+
33
+
34
def generate_graph(dict_ents, response, filename, options):
    """Build a pyvis graph of the predicted triples and write it to ``filename``.

    dict_ents maps (start, end) token spans to (node_id, surface_text);
    node colors come from the displacy ``options["colors"]`` mapping.
    Side effect: writes an interactive HTML file to disk via ``g.show``.
    """
    g = Network(
        width="720px",
        height="600px",
        directed=True,
        notebook=False,
        bgcolor="#222222",
        font_color="white",
    )
    # Physics layout parameters tuned for small entity graphs.
    g.barnes_hut(
        gravity=-3000,
        central_gravity=0.3,
        spring_length=50,
        spring_strength=0.001,
        damping=0.09,
        overlap=0,
    )
    # One node per predicted entity span.
    for ent in dict_ents:
        g.add_node(
            dict_ents[ent][0],
            label=dict_ents[ent][1],
            color=options["colors"][dict_ents[ent][0]],
            title=dict_ents[ent][0],
            size=15,
            labelHighlightBold=True,
        )

    # One directed edge per predicted relation triple (subject -> object).
    for rel in response.triples:
        g.add_edge(
            dict_ents[(rel.subject.start, rel.subject.end)][0],
            dict_ents[(rel.object.start, rel.object.end)][0],
            label=rel.label,
            title=rel.label,
        )
    g.show(filename, notebook=False)
69
+
70
+
71
def set_sidebar(css):
    """Render the Streamlit sidebar: injected CSS, Sapienza NLP logo, links."""
    # Font Awesome stylesheet plus a link; formatted per (url, label) pair.
    white_link_wrapper = (
        "<link rel='stylesheet' "
        "href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.2/css/all.min.css'><a href='{}'>{}</a>"
    )
    with st.sidebar:
        # Inject the page stylesheet (sidebar colors, link styling).
        st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
        st.image(
            "http://nlp.uniroma1.it/static/website/sapienza-nlp-logo-wh.svg",
            use_column_width=True,
        )
        st.markdown("## ReLiK")
        st.write(
            f"""
            - {white_link_wrapper.format("#", "<i class='fa-solid fa-file'></i>&nbsp; Paper")}
            - {white_link_wrapper.format("https://github.com/SapienzaNLP/relik", "<i class='fa-brands fa-github'></i>&nbsp; GitHub")}
            - {white_link_wrapper.format("https://hub.docker.com/repository/docker/sapienzanlp/relik", "<i class='fa-brands fa-docker'></i>&nbsp; Docker Hub")}
            """,
            unsafe_allow_html=True,
        )
        st.markdown("## Sapienza NLP")
        st.write(
            f"""
            - {white_link_wrapper.format("https://nlp.uniroma1.it", "<i class='fa-solid fa-globe'></i>&nbsp; Webpage")}
            - {white_link_wrapper.format("https://github.com/SapienzaNLP", "<i class='fa-brands fa-github'></i>&nbsp; GitHub")}
            - {white_link_wrapper.format("https://twitter.com/SapienzaNLP", "<i class='fa-brands fa-twitter'></i>&nbsp; Twitter")}
            - {white_link_wrapper.format("https://www.linkedin.com/company/79434450", "<i class='fa-brands fa-linkedin'></i>&nbsp; LinkedIn")}
            """,
            unsafe_allow_html=True,
        )
101
+
102
+
103
def get_span_annotations(response):
    """Turn predicted spans into BIO labels with Wikipedia-linked HTML tags.

    Returns (tokens, labels, displacy options, dict_ents) where labels is a
    BIO sequence aligned with ``response.tokens`` and dict_ents maps each
    (start, end) span to (unique_label, surface_text) for graph building.
    """
    # HTML wrapper that turns an entity label into a Wikipedia link with the
    # Font Awesome Wikipedia icon; formatted with (url_title, label).
    el_link_wrapper = (
        "<link rel='stylesheet' "
        "href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.2/css/all.min.css'>"
        "<a href='https://en.wikipedia.org/wiki/{}' style='color: #414141'><i class='fa-brands"
        " fa-wikipedia-w fa-xs'></i> <span style='font-size: 1.0em; font-family: monospace'> "
        "{}</span></a>"
    )
    tokens = response.tokens
    labels = ["O"] * len(tokens)
    dict_ents = {}
    # make BIO labels
    for idx, span in enumerate(response.spans):
        # NME (no matching entity) spans keep a plain label suffixed with the
        # span index; linked entities get the Wikipedia HTML wrapper instead.
        labels[span.start] = (
            "B-" + span.label + str(idx)
            if span.label == "NME"
            else "B-" + el_link_wrapper.format(span.label.replace(" ", "_"), span.label)
        )
        for i in range(span.start + 1, span.end):
            labels[i] = (
                "I-" + span.label + str(idx)
                if span.label == "NME"
                else "I-"
                + el_link_wrapper.format(span.label.replace(" ", "_"), span.label)
            )
        dict_ents[(span.start, span.end)] = (
            span.label + str(idx),
            " ".join(tokens[span.start : span.end]),
        )
    # Strip the B-/I- prefixes to get the distinct display labels.
    unique_labels = set(w[2:] for w in labels if w != "O")
    options = {"ents": unique_labels, "colors": get_random_color(unique_labels)}
    return tokens, labels, options, dict_ents
135
+
136
+
137
@st.cache_resource()
def load_model():
    """Load the pretrained ReLiK RE model, cached once per Streamlit process."""
    return Relik.from_pretrained("riccorl/relik-relation-extraction-nyt-small")
140
+
141
+
142
def set_intro(css):
    """Render the page header: title, subtitle, and GitHub/PyPI badges.

    NOTE(review): ``css`` is unused here; the stylesheet is injected by
    ``set_sidebar`` — confirm whether the parameter can be dropped.
    """
    # intro
    st.markdown("# ReLik")
    st.markdown(
        "### Retrieve, Read and LinK: Fast and Accurate Entity Linking "
        "and Relation Extraction on an Academic Budget"
    )
    # st.markdown(
    #     "This is a front-end for the paper [Universal Semantic Annotator: the First Unified API "
    #     "for WSD, SRL and Semantic Parsing](https://www.researchgate.net/publication/360671045_Universal
    #     _Semantic_Annotator_the_First_Unified_API_for_WSD_SRL_and_Semantic_Parsing),
    #     which will be presented at LREC 2022 by "
    #     "[Riccardo Orlando](https://riccorl.github.io), [Simone Conia](https://c-simone.github.io/), "
    #     "[Stefano Faralli](https://corsidilaurea.uniroma1.it/it/users/stefanofaralliuniroma1it),
    #     and [Roberto Navigli](https://www.diag.uniroma1.it/navigli/)."
    # )
    badge(type="github", name="sapienzanlp/relik")
    badge(type="pypi", name="relik")
160
+
161
+
162
def run_client():
    """Drive the Streamlit RE front-end: run the local model, render entities
    with displacy and relations as an embedded pyvis HTML graph."""
    # Page stylesheet shipped next to this module.
    with open(Path(__file__).parent / "style.css") as f:
        css = f.read()

    st.set_page_config(
        page_title="ReLik",
        page_icon="🦮",
        layout="wide",
    )
    set_sidebar(css)
    set_intro(css)

    # text input
    text = st.text_area(
        "Enter Text Below:",
        value="Michael Jordan was one of the best players in the NBA.",
        height=200,
        max_chars=1500,
    )

    # Styled container so the submit button picks up the brand color.
    with stylable_container(
        key="annotate_button",
        css_styles="""
            button {
                background-color: #802433;
                color: white;
                border-radius: 25px;
            }
            """,
    ):
        submit = st.button("Annotate")

    # Load the model once and keep it in session state across reruns.
    if "relik_model" not in st.session_state.keys():
        st.session_state["relik_model"] = load_model()
    relik_model = st.session_state["relik_model"]
    init_state_variables()
    # ReLik API call

    # spacy for span visualization
    nlp = spacy.blank("xx")

    if submit:
        text = text.strip()
        if text:
            # Unique per-run temp filename for the pyvis HTML output.
            st.session_state["filename"] = str(dt.now().timestamp() * 1000) + ".html"

            with st.spinner(text="In progress"):
                # Run the RE pipeline locally (word-level annotations).
                response = relik_model(text, annotation_type="word", num_workers=0)
                # response = requests.post(RELIK, json=text)
                # if response.status_code != 200:
                #     st.error("Error: {}".format(response.status_code))
                # else:
                #     response = response.json()

                # EL
                st.markdown("####")
                st.markdown("#### Entities")
                tokens, labels, options, dict_ents = get_span_annotations(
                    response=response
                )
                # Build a spaCy Doc from BIO labels so displacy can render it.
                doc = Doc(nlp.vocab, words=tokens, ents=labels)
                display_el = displacy.render(doc, style="ent", options=options)
                display_el = display_el.replace("\n", " ")
                # heuristic, prevents split of annotation decorations
                display_el = display_el.replace(
                    "border-radius: 0.35em;",
                    "border-radius: 0.35em; white-space: nowrap;",
                )
                with st.container():
                    st.write(display_el, unsafe_allow_html=True)

                # RE
                # Render the triples graph to a temp HTML file, read it into
                # session state, then delete the file.
                generate_graph(
                    dict_ents, response, st.session_state["filename"], options
                )
                HtmlFile = open(st.session_state["filename"], "r", encoding="utf-8")
                source_code = HtmlFile.read()
                st.session_state["html_free"] = source_code
                os.remove(st.session_state["filename"])
                st.session_state["has_run_free"] = True
        else:
            st.error("Please enter some text.")

    # Persisted across reruns so the graph survives widget interactions.
    if st.session_state["has_run_free"]:
        st.markdown("#### Relations")
        components.html(st.session_state["html_free"], width=720, height=600)
248
+
249
+
250
# Script entry point: launch the Streamlit client.
if __name__ == "__main__":
    run_client()
relik/inference/serve/frontend/style.css ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
/* Sidebar */
/* NOTE(review): .eczjsme11 and .st-emotion-cache-* are auto-generated
   Streamlit/Emotion class names; they are likely to break when Streamlit
   is upgraded — re-check after version bumps. */
.eczjsme11 {
    background-color: #802433;
}

.st-emotion-cache-10oheav h2 {
    color: white;
}

.st-emotion-cache-10oheav li {
    color: white;
}

/* Main */
/* Make all link states white and undecorated so links blend into the
   dark-red branded backgrounds. */
a:link {
    text-decoration: none;
    color: white;
}

a:visited {
    text-decoration: none;
    color: white;
}

a:hover {
    text-decoration: none;
    color: rgba(255, 255, 255, 0.871);
}

a:active {
    text-decoration: none;
    color: white;
}
relik/inference/serve/frontend/utils.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import random
3
+ from typing import Dict, List, Optional, Union
4
+
5
+ import spacy
6
+ import streamlit as st
7
+ from spacy import displacy
8
+
9
+
10
def get_html(html: str):
    """Wrap rendered HTML in a scrollable, bordered container for Streamlit."""
    WRAPPER = """<div style="overflow-x: auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem; margin-bottom: 2.5rem">{}</div>"""
    # Newlines seem to mess with the rendering
    cleaned = html.replace("\n", " ")
    return WRAPPER.format(cleaned)
16
+
17
+
18
def get_svg(svg: str, style: str = "", wrap: bool = True):
    """Embed an SVG as a base64-encoded <img>; optionally wrap via get_html."""
    encoded = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
    img = f'<img src="data:image/svg+xml;base64,{encoded}" style="{style}"/>'
    if wrap:
        return get_html(img)
    return img
23
+
24
+
25
def visualize_parser(
    doc: Union[spacy.tokens.Doc, List[Dict[str, str]]],
    *,
    title: Optional[str] = None,
    key: Optional[str] = None,
    manual: bool = False,
    displacy_options: Optional[Dict] = None,
) -> None:
    """Visualizer for dependency parses.

    doc (Doc, List): The document to visualize.
    key (str): Key used for the streamlit component for selecting labels.
        NOTE(review): currently unused in the body — confirm it can be dropped.
    title (str): The title displayed at the top of the parser visualization.
    manual (bool): Flag signifying whether the doc argument is a Doc object or a List of Dicts containing parse information.
    displacy_options (Dict): Dictionary of options to be passed to the displacy render method for generating the HTML to be rendered.
        See: https://spacy.io/api/top-level#options-dep
    """
    if displacy_options is None:
        displacy_options = dict()
    if title:
        st.header(title)
    # Single-element list kept for symmetry with multi-sentence rendering.
    docs = [doc]
    # add selected options to options provided by user
    # `options` from `displacy_options` are overwritten by user provided
    # options from the checkboxes
    for sent in docs:
        html = displacy.render(
            sent, options=displacy_options, style="dep", manual=manual
        )
        # Double newlines seem to mess with the rendering
        html = html.replace("\n\n", "\n")
        # Render the SVG as a base64 <img> so Streamlit displays it inline.
        st.write(get_svg(html), unsafe_allow_html=True)
+ st.write(get_svg(html), unsafe_allow_html=True)
57
+
58
+
59
def get_random_color(ents):
    """Assign every entity label a distinct, randomly chosen pastel color."""
    remaining = generate_pastel_colors(len(ents))
    mapping = {}
    for entity in ents:
        # Popping the chosen index guarantees uniqueness across labels.
        mapping[entity] = remaining.pop(random.randint(0, len(remaining) - 1))
    return mapping
65
+
66
+
67
def floatrange(start, stop, steps):
    """Return ``steps`` evenly spaced floats from ``start`` to ``stop`` inclusive."""
    if int(steps) == 1:
        # A single step collapses to the endpoint.
        return [stop]
    spacing = (stop - start) / (float(steps) - 1)
    return [start + float(idx) * spacing for idx in range(steps)]
73
+
74
+
75
def hsl_to_rgb(h, s, l):
    """Convert HSL (all components in [0, 1]) to an (r, g, b) tuple of 0-255 ints."""

    def hue_component(v1, v2, hue):
        # Normalize hue into [0, 1], then evaluate the piecewise HSL curve.
        while hue < 0.0:
            hue += 1.0
        while hue > 1.0:
            hue -= 1.0
        if 6 * hue < 1.0:
            return v1 + (v2 - v1) * 6.0 * hue
        if 2 * hue < 1.0:
            return v2
        if 3 * hue < 2.0:
            return v1 + (v2 - v1) * ((2.0 / 3.0) - hue) * 6.0
        return v1

    # Achromatic default (s == 0): all channels equal the scaled lightness.
    r = g = b = l * 255
    if s != 0.0:
        v2 = l * (1.0 + s) if l < 0.5 else (l + s) - (s * l)
        v1 = 2.0 * l - v2
        r = 255 * hue_component(v1, v2, h + (1.0 / 3.0))
        g = 255 * hue_component(v1, v2, h)
        b = 255 * hue_component(v1, v2, h - (1.0 / 3.0))

    return int(round(r)), int(round(g)), int(round(b))
104
+
105
+
106
def generate_pastel_colors(n):
    """Return ``n`` distinct pastel colors as HTML hex strings.

    Input:
      n (integer) : The number of colors to return

    Output:
      A list of colors in HTML notation (e.g. ['#ffcccc', ...])
    """
    if n == 0:
        return []

    # Sample points around the chromatic circle (hue) in HSL space with fixed
    # saturation/lightness. n + 1 hues are generated and the last dropped,
    # because hue 0 and hue 1 denote the same color.
    start_hue = 0.0  # 0=red 1/3=0.333=green 2/3=0.666=blue
    saturation = 1.0
    lightness = 0.9
    hues = floatrange(start_hue, start_hue + 1, n + 1)[:-1]
    return ["#%02x%02x%02x" % hsl_to_rgb(hue, saturation, lightness) for hue in hues]