"""Sentence-BERT embeddings. Open-source models, designed to run on device."""
import functools
import os
from typing import Iterable, Optional, cast
import torch
from sentence_transformers import SentenceTransformer
from typing_extensions import override
from ..config import data_path
from ..schema import Item, RichData
from ..signals.signal import TextEmbeddingSignal
from ..signals.splitters.chunk_splitter import split_text
from ..utils import log
from .embedding import compute_split_embeddings
# The `all-mpnet-base-v2` model provides the best quality, while `all-MiniLM-L6-v2` is 5 times
# faster and still offers good quality. See https://www.sbert.net/docs/pretrained_models.html#sentence-embedding-models/
MINI_LM_MODEL = 'all-MiniLM-L6-v2'
# Batch size used when no empirically-tuned value is known for the (model, device) pair.
SBERT_DEFAULT_BATCH_SIZE = 64
# Maps a tuple of (model name, device) to the optimal batch size, found empirically.
SBERT_OPTIMAL_BATCH_SIZE: dict[tuple[str, str], int] = {
  (MINI_LM_MODEL, 'mps'): 256,
}
# The model this module loads; kept as a separate alias so it can be swapped in one place.
MODEL_NAME = MINI_LM_MODEL
@functools.cache
def _sbert() -> tuple[Optional[str], SentenceTransformer]:
  """Load the Sentence-BERT model once and pick the preferred torch device.

  Returns:
    A tuple of (preferred device name or None, the loaded model). The result is
    cached so the model is only downloaded/loaded on the first call.
  """
  # Prefer Apple-silicon GPU acceleration (MPS) when the torch build supports it.
  device: Optional[str] = 'mps' if torch.backends.mps.is_available() else None
  if device is None and not torch.backends.mps.is_built():
    log('MPS not available because the current PyTorch install was not built with MPS enabled.')
  model = SentenceTransformer(
    MODEL_NAME, device=device, cache_folder=os.path.join(data_path(), '.cache'))
  return device, model
def _optimal_batch_size(preferred_device: Optional[str]) -> int:
  """Return the encode batch size to use for the current model and device.

  Args:
    preferred_device: The torch device name chosen by `_sbert` (e.g. 'mps'), or None.

  Returns:
    The empirically-tuned batch size for (MODEL_NAME, device) when one is known,
    otherwise SBERT_DEFAULT_BATCH_SIZE.
  """
  # Single dict lookup with a default instead of an `in` check followed by indexing.
  return SBERT_OPTIMAL_BATCH_SIZE.get((MODEL_NAME, str(preferred_device)),
                                      SBERT_DEFAULT_BATCH_SIZE)
class SBERT(TextEmbeddingSignal):
  """Computes embeddings using Sentence-BERT library."""
  name = 'sbert'
  display_name = 'SBERT Embeddings'

  @override
  def compute(self, docs: Iterable[RichData]) -> Iterable[Item]:
    """Call the embedding function."""
    device, model = _sbert()
    # Only split documents into chunks when the signal was configured to do so.
    splitter = split_text if self._split else None
    text_docs = cast(Iterable[str], docs)
    yield from compute_split_embeddings(
      text_docs, _optimal_batch_size(device), embed_fn=model.encode, split_fn=splitter)
|