nsthorat's picture
Push
56cce61
raw
history blame
No virus
2.58 kB
"""Sentence-BERT embeddings. Open-source models, designed to run on device."""
import functools
import os
from typing import TYPE_CHECKING, Iterable, Optional, cast
from typing_extensions import override
from ..config import data_path
from ..schema import Item, RichData
from ..signals.signal import TextEmbeddingSignal
from ..signals.splitters.chunk_splitter import split_text
from ..utils import log
from .embedding import compute_split_embeddings
if TYPE_CHECKING:
from sentence_transformers import SentenceTransformer
# The `all-mpnet-base-v2` model provides the best quality, while `all-MiniLM-L6-v2` is 5 times
# faster and still offers good quality. See https://www.sbert.net/docs/pretrained_models.html#sentence-embedding-models/
MINI_LM_MODEL = 'all-MiniLM-L6-v2'
# Fallback batch size when no empirically tuned value exists for a (model, device) pair.
SBERT_DEFAULT_BATCH_SIZE = 64
# Maps a tuple of model name and device to the optimal batch size, found empirically.
SBERT_OPTIMAL_BATCH_SIZE: dict[tuple[str, str], int] = {
  (MINI_LM_MODEL, 'mps'): 256,
}
# The model actually loaded by `_sbert` below.
MODEL_NAME = MINI_LM_MODEL
@functools.cache
def _sbert() -> tuple[Optional[str], 'SentenceTransformer']:
  """Lazily load the SBERT model, cached so it is only loaded once per process.

  Returns:
    A tuple of (preferred_device, model) where preferred_device is 'mps' when the
    Apple Metal backend is available, else None (SentenceTransformer then picks
    its own default device).

  Raises:
    ImportError: if the `sentence_transformers` package is not installed.
  """
  try:
    import torch
    from sentence_transformers import SentenceTransformer
  except ImportError as e:
    # Chain the original exception so the real cause (e.g. a broken torch
    # install) is not hidden behind the "please install" message.
    raise ImportError('Could not import the "sentence_transformers" python package. '
                      'Please install it with `pip install sentence-transformers`.') from e
  preferred_device: Optional[str] = None
  if torch.backends.mps.is_available():
    preferred_device = 'mps'
  elif not torch.backends.mps.is_built():
    # MPS is unavailable AND absent from this build; on a build that has MPS but
    # no compatible hardware, stay quiet and fall through to the default device.
    log('MPS not available because the current PyTorch install was not built with MPS enabled.')
  return preferred_device, SentenceTransformer(
    MODEL_NAME, device=preferred_device, cache_folder=os.path.join(data_path(), '.cache'))
def _optimal_batch_size(preferred_device: Optional[str]) -> int:
  """Return the empirically tuned batch size for the current model on this device.

  Falls back to `SBERT_DEFAULT_BATCH_SIZE` when no tuned value is known for the
  (model, device) pair.
  """
  return SBERT_OPTIMAL_BATCH_SIZE.get((MODEL_NAME, str(preferred_device)),
                                      SBERT_DEFAULT_BATCH_SIZE)
class SBERT(TextEmbeddingSignal):
  """Computes embeddings using Sentence-BERT library."""
  name = 'sbert'
  display_name = 'SBERT Embeddings'

  @override
  def compute(self, docs: Iterable[RichData]) -> Iterable[Item]:
    """Call the embedding function."""
    device, transformer = _sbert()
    # Split documents into chunks only when the signal is configured to do so.
    splitter = split_text if self._split else None
    text_docs = cast(Iterable[str], docs)
    yield from compute_split_embeddings(
      text_docs,
      _optimal_batch_size(device),
      embed_fn=transformer.encode,
      split_fn=splitter)