from __future__ import annotations
import json
import logging
import re
from importlib import import_module
from importlib.metadata import metadata
from typing import TYPE_CHECKING, Any, Protocol, cast
import safetensors
from joblib import Parallel
from tokenizers import Tokenizer
from tqdm import tqdm
if TYPE_CHECKING:
    from collections.abc import Iterator
    from pathlib import Path

    import numpy as np


logger = logging.getLogger(__name__)


class ProgressParallel(Parallel):
    """A drop-in replacement for joblib.Parallel that shows a tqdm progress bar."""

    def __init__(self, use_tqdm: bool = True, total: int | None = None, *args: Any, **kwargs: Any) -> None:
        """
        Initialize the ProgressParallel object.

        :param use_tqdm: Whether to show the progress bar.
        :param total: Total number of tasks (batches) you expect to process. If None,
            the total is updated dynamically to the number of dispatched tasks.
        :param *args: Additional arguments to pass to `Parallel.__init__`.
        :param **kwargs: Additional keyword arguments to pass to `Parallel.__init__`.
        """
        self._use_tqdm = use_tqdm
        self._total = total
        super().__init__(*args, **kwargs)
    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Create a tqdm context and run the parallel call inside it."""
        with tqdm(disable=not self._use_tqdm, total=self._total) as self._pbar:
            return super().__call__(*args, **kwargs)
    def print_progress(self) -> None:
        """Hook called by joblib as tasks complete. We update the tqdm bar here."""
        if self._total is None:
            # If no fixed total was given, we dynamically set the total
            # to the number of dispatched tasks.
            self._pbar.total = self.n_dispatched_tasks
        # Move the bar to the number of completed tasks.
        self._pbar.n = self.n_completed_tasks
        self._pbar.refresh()
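
# Illustrative usage sketch for ProgressParallel above (not part of this module's
# API; the workload is hypothetical). It is used exactly like joblib.Parallel:
#
#     from joblib import delayed
#
#     results = ProgressParallel(n_jobs=2, total=10)(
#         delayed(sum)(range(i)) for i in range(10)
#     )
#
# Passing `total` sizes the tqdm bar up front; without it, the bar grows as tasks
# are dispatched (see `print_progress`).
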
class SafeOpenProtocol(Protocol):
    """Protocol to fix safetensors safe open."""

    def get_tensor(self, key: str) -> np.ndarray:
        """Get a tensor."""
        ...  # pragma: no cover
_MODULE_MAP = (("scikit-learn", "sklearn"),)
_DIVIDERS = re.compile(r"[=<>!]+")
def get_package_extras(package: str, extra: str) -> Iterator[str]:
    """Yield the requirements that the given extra of the package would install."""
    try:
        message = metadata(package)
    except Exception as e:
        # For local packages without metadata, return an empty iterator.
        # This allows the package to work without installed metadata.
        logger.debug(f"Could not retrieve metadata for package '{package}': {e}")
        return

    all_packages = message.get_all("Requires-Dist") or []
    for requirement in all_packages:
        name, *rest = requirement.split(";", maxsplit=1)
        if rest:
            # Extract and clean the extra requirement.
            found_extra = rest[0].split("==")[-1].strip(" \"'")
            if found_extra == extra:
                prefix, *_ = _DIVIDERS.split(name)
                yield prefix.strip()
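
# Illustrative sketch for get_package_extras above (the extra name is hypothetical):
# list the requirements an extra would install, based on the installed metadata.
#
#     extra_packages = list(get_package_extras("model2vec", "distill"))
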
def importable(module: str, extra: str) -> None:
    """Check that a module is importable, raising an informative error if it is not."""
    module = dict(_MODULE_MAP).get(module, module)
    try:
        import_module(module)
    except ImportError:
        msg = f"`{module}` is required. Please reinstall model2vec with the `{extra}` extra: `pip install model2vec[{extra}]`"
        raise ImportError(msg)
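
# Illustrative sketch for importable above (the extra name is hypothetical): guard
# an optional dependency before using it.
#
#     importable("sklearn", extra="inference")
#     from sklearn.pipeline import Pipeline  # safe to import after the check
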
def setup_logging() -> None:
    """Simple logging setup."""
    from rich.logging import RichHandler

    logging.basicConfig(
        level="INFO",
        format="%(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[RichHandler(rich_tracebacks=True)],
    )
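
# Minimal usage sketch for setup_logging above: call it once at program start so
# log records are rendered through rich's RichHandler.
#
#     setup_logging()
#     logger.info("logging is configured")
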
def load_local_model(folder: Path) -> tuple[np.ndarray, Tokenizer, dict[str, str]]:
    """Load a local model's embeddings, tokenizer and config from a folder."""
    embeddings_path = folder / "model.safetensors"
    tokenizer_path = folder / "tokenizer.json"
    config_path = folder / "config.json"

    opened_tensor_file = cast("SafeOpenProtocol", safetensors.safe_open(embeddings_path, framework="numpy"))
    embeddings = opened_tensor_file.get_tensor("embeddings")

    # The config is optional; fall back to an empty dict if it is missing.
    config = json.loads(config_path.read_text()) if config_path.exists() else {}
    tokenizer: Tokenizer = Tokenizer.from_file(str(tokenizer_path))

    if len(tokenizer.get_vocab()) != len(embeddings):
        logger.warning(
            f"Number of tokens does not match number of embeddings: `{len(tokenizer.get_vocab())}` vs `{len(embeddings)}`"
        )

    return embeddings, tokenizer, config
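
# Illustrative sketch for load_local_model above (the path is hypothetical): load a
# model saved in this layout (model.safetensors, tokenizer.json, config.json).
#
#     from pathlib import Path
#
#     embeddings, tokenizer, config = load_local_model(Path("local/my_model"))
#     print(embeddings.shape, len(tokenizer.get_vocab()))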