# NOTE(review): the lines that were here ("Spaces:", "Runtime error", file
# size, commit hash, and a line-number gutter) were Hugging Face page chrome
# captured by the extractor, not part of the module; replaced with this note
# so the file parses as Python.
"""Gegeral Text Embeddings (GTE) model. Open-source model, designed to run on device."""
from typing import TYPE_CHECKING, Iterable, cast
from typing_extensions import override
from ..schema import Item, RichData
from ..signal import TextEmbeddingSignal
from ..splitters.chunk_splitter import split_text
from .embedding import compute_split_embeddings
from .transformer_utils import get_model
if TYPE_CHECKING:
pass
# See https://huggingface.co/spaces/mteb/leaderboard for leaderboard of models.
GTE_SMALL = 'thenlper/gte-small'
GTE_BASE = 'thenlper/gte-base'
# Maps a tuple of model name and device to the optimal batch size, found empirically.
_OPTIMAL_BATCH_SIZES: dict[str, dict[str, int]] = {
GTE_SMALL: {
'': 64, # Default batch size.
'mps': 256,
},
GTE_BASE: {
'': 64, # Default batch size.
'mps': 128,
}
}
class GTESmall(TextEmbeddingSignal):
  """Computes General Text Embeddings (GTE).

  <br>This embedding runs on-device. See the [model card](https://huggingface.co/thenlper/gte-small)
  for details.
  """

  name = 'gte-small'
  # Fixed typo: "Gegeral" -> "General" (GTE stands for General Text Embeddings).
  display_name = 'General Text Embeddings (small)'

  # The Hugging Face model identifier; subclasses override this to select a
  # different checkpoint while reusing `compute`.
  _model_name = GTE_SMALL

  @override
  def compute(self, docs: Iterable[RichData]) -> Iterable[Item]:
    """Compute embeddings for the given documents.

    Args:
      docs: An iterable of text documents to embed.

    Yields:
      Embedding items, one per input document (split into chunks first when
      `self._split` is set).
    """
    # Look up the model together with the empirically optimal batch size for
    # the device it loads on.
    batch_size, model = get_model(self._model_name, _OPTIMAL_BATCH_SIZES[self._model_name])
    embed_fn = model.encode
    # Only chunk the documents when splitting was requested on the signal.
    split_fn = split_text if self._split else None
    docs = cast(Iterable[str], docs)
    yield from compute_split_embeddings(docs, batch_size, embed_fn=embed_fn, split_fn=split_fn)
class GTEBase(GTESmall):
  """Computes General Text Embeddings (GTE).

  <br>This embedding runs on-device. See the [model card](https://huggingface.co/thenlper/gte-base)
  for details.
  """

  name = 'gte-base'
  # Fixed typo: "Gegeral" -> "General" (GTE stands for General Text Embeddings).
  display_name = 'General Text Embeddings (base)'

  # Reuses GTESmall.compute; only the checkpoint differs.
  _model_name = GTE_BASE