File size: 1,865 Bytes
bfc0ec6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
"""General Text Embeddings (GTE) model. Open-source model, designed to run on device."""
from typing import TYPE_CHECKING, Iterable, cast

from typing_extensions import override

from ..schema import Item, RichData
from ..signal import TextEmbeddingSignal
from ..splitters.chunk_splitter import split_text
from .embedding import compute_split_embeddings
from .transformer_utils import get_model

if TYPE_CHECKING:
  # Placeholder for type-only imports; currently there are none.
  pass

# See https://huggingface.co/spaces/mteb/leaderboard for leaderboard of models.
GTE_SMALL = 'thenlper/gte-small'
GTE_BASE = 'thenlper/gte-base'

# Maps a model name to a nested mapping of device name -> optimal batch size, found empirically.
# The '' key holds the default batch size; presumably `get_model` falls back to it when the
# current device has no dedicated entry (NOTE(review): confirm against `get_model`).
_OPTIMAL_BATCH_SIZES: dict[str, dict[str, int]] = {
  GTE_SMALL: {
    '': 64,  # Default batch size.
    'mps': 256,
  },
  GTE_BASE: {
    '': 64,  # Default batch size.
    'mps': 128,
  },
}


class GTESmall(TextEmbeddingSignal):
  """Computes General Text Embeddings (GTE).

  <br>This embedding runs on-device. See the [model card](https://huggingface.co/thenlper/gte-small)
  for details.
  """

  name = 'gte-small'
  display_name = 'General Text Embeddings (small)'

  # Hugging Face model id loaded by `compute`. Subclasses override this to select another GTE
  # variant; it must be a key of `_OPTIMAL_BATCH_SIZES` (it is indexed directly in `compute`).
  _model_name = GTE_SMALL

  @override
  def compute(self, docs: Iterable[RichData]) -> Iterable[Item]:
    """Call the embedding function.

    Args:
      docs: The documents to embed. They are treated as strings (see the `cast` below).

    Yields:
      The items produced by `compute_split_embeddings` for the given documents.
    """
    # `get_model` is handed this model's per-device batch-size table and returns the batch size
    # to use together with the loaded model.
    batch_size, model = get_model(self._model_name, _OPTIMAL_BATCH_SIZES[self._model_name])
    embed_fn = model.encode
    # Chunk documents with `split_text` only when `self._split` is set; otherwise embed whole docs.
    split_fn = split_text if self._split else None
    # Narrow the static type for `compute_split_embeddings`; this is a no-op at runtime.
    docs = cast(Iterable[str], docs)
    yield from compute_split_embeddings(docs, batch_size, embed_fn=embed_fn, split_fn=split_fn)


class GTEBase(GTESmall):
  """Computes General Text Embeddings (GTE).

  <br>This embedding runs on-device. See the [model card](https://huggingface.co/thenlper/gte-base)
  for details.
  """

  name = 'gte-base'
  display_name = 'General Text Embeddings (base)'

  # Inherits `compute` from `GTESmall`; only the model id (and thus its batch-size table) differs.
  _model_name = GTE_BASE