from __future__ import annotations

import json
import logging
import re
from importlib import import_module
from importlib.metadata import metadata
from typing import TYPE_CHECKING, Any, Protocol, cast

import safetensors
from joblib import Parallel
from tokenizers import Tokenizer
from tqdm import tqdm

if TYPE_CHECKING:
	from collections.abc import Iterator
	from pathlib import Path

	import numpy as np

logger = logging.getLogger(__name__)


class ProgressParallel(Parallel):
	"""A drop-in replacement for joblib.Parallel that shows a tqdm progress bar."""

	def __init__(self, use_tqdm: bool = True, total: int | None = None, *args: Any, **kwargs: Any) -> None:
		"""
		Initialize the ProgressParallel object.

		:param use_tqdm: Whether to show the progress bar.
		:param total: Total number of tasks (batches) you expect to process. If None,
		            it updates the total dynamically to the number of dispatched tasks.
		:param *args: Additional arguments to pass to `Parallel.__init__`.
		:param **kwargs: Additional keyword arguments to pass to `Parallel.__init__`.
		"""
		self._use_tqdm = use_tqdm
		self._total = total
		super().__init__(*args, **kwargs)

	def __call__(self, *args: Any, **kwargs: Any) -> Any:
		"""Create a tqdm context."""
		with tqdm(disable=not self._use_tqdm, total=self._total) as self._pbar:
			return super().__call__(*args, **kwargs)

	def print_progress(self) -> None:
		"""Hook called by joblib as tasks complete. We update the tqdm bar here."""
		if self._total is None:
			# If no fixed total was given, we dynamically set the total
			self._pbar.total = self.n_dispatched_tasks
		# Move the bar to the number of completed tasks
		self._pbar.n = self.n_completed_tasks
		self._pbar.refresh()


class SafeOpenProtocol(Protocol):
	"""Protocol to fix safetensors safe open."""

	def get_tensor(self, key: str) -> np.ndarray:
		"""Get a tensor."""
		...  # pragma: no cover
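
# Why the protocol (sketch): `safetensors.safe_open` is typed loosely, so
# casting to SafeOpenProtocol gives static checkers the `get_tensor`
# signature, as `load_local_model` does below. The path here is illustrative.
#
#     f = cast("SafeOpenProtocol", safetensors.safe_open("model.safetensors", framework="numpy"))
#     embeddings = f.get_tensor("embeddings")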


_MODULE_MAP = (("scikit-learn", "sklearn"),)
_DIVIDERS = re.compile(r"[=<>!]+")


def get_package_extras(package: str, extra: str) -> Iterator[str]:
	"""Get the extras of the package."""
	try:
		message = metadata(package)
	except Exception as e:
		# For local packages without metadata, return empty iterator
		# This allows the package to work without installed metadata
		logger.debug(f"Could not retrieve metadata for package '{package}': {e}")
		return iter([])

	all_packages = message.get_all("Requires-Dist") or []
	for requirement in all_packages:
		name, *rest = requirement.split(";", maxsplit=1)
		if rest:
			# Extract and clean the extra requirement
			found_extra = rest[0].split("==")[-1].strip(" \"'")
			if found_extra == extra:
				prefix, *_ = _DIVIDERS.split(name)
				yield prefix.strip()
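
# Example sketch (the extra name is an illustrative assumption; output depends
# on the metadata actually installed):
#
#     extras = list(get_package_extras("model2vec", "distill"))
#     # e.g. ["torch", "transformers"] if those are declared under that extra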


def importable(module: str, extra: str) -> None:
	"""Check if a module is importable."""
	module = dict(_MODULE_MAP).get(module, module)
	try:
		import_module(module)
	except ImportError as e:
		msg = f"`{module}` is required. Please reinstall model2vec with the `{extra}` extra: `pip install model2vec[{extra}]`"
		raise ImportError(msg) from e
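
# Example sketch (the extra name is an illustrative assumption): `importable`
# maps distribution names to import names via _MODULE_MAP before importing,
# so passing the PyPI name works too.
#
#     importable("scikit-learn", "inference")  # resolved to `sklearn`;
#     # raises ImportError with an install hint if the module is missing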


def setup_logging() -> None:
	"""Simple logging setup."""
	from rich.logging import RichHandler

	logging.basicConfig(
		level="INFO",
		format="%(name)s - %(message)s",
		datefmt="%Y-%m-%d %H:%M:%S",
		handlers=[RichHandler(rich_tracebacks=True)],
	)
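
# Sketch: call setup_logging() once at program start; module-level loggers
# such as `logger` above then render through RichHandler.
#
#     setup_logging()
#     logger.info("logging configured")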


def load_local_model(folder: Path) -> tuple[np.ndarray, Tokenizer, dict[str, str]]:
	"""Load a local model."""
	embeddings_path = folder / "model.safetensors"
	tokenizer_path = folder / "tokenizer.json"
	config_path = folder / "config.json"

	opened_tensor_file = cast("SafeOpenProtocol", safetensors.safe_open(embeddings_path, framework="numpy"))
	embeddings = opened_tensor_file.get_tensor("embeddings")

	config = json.loads(config_path.read_text()) if config_path.exists() else {}

	tokenizer: Tokenizer = Tokenizer.from_file(str(tokenizer_path))

	if len(tokenizer.get_vocab()) != len(embeddings):
		logger.warning(
			f"Number of tokens does not match number of embeddings: `{len(tokenizer.get_vocab())}` vs `{len(embeddings)}`"
		)

	return embeddings, tokenizer, config
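
# Usage sketch (the folder name is an illustrative assumption; the folder must
# contain model.safetensors, tokenizer.json, and optionally config.json):
#
#     from pathlib import Path
#     embeddings, tokenizer, config = load_local_model(Path("my_model_dir"))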