File size: 1,564 Bytes
e4f9cbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3600417
 
 
 
 
 
 
 
 
e4f9cbe
 
 
 
 
3600417
e4f9cbe
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
"""Cohere embeddings."""
from typing import Iterable, cast

import cohere
import numpy as np
from typing_extensions import override

from ..config import CONFIG
from ..schema import Item, RichData
from ..signals.signal import TextEmbeddingSignal
from ..signals.splitters.chunk_splitter import split_text
from .embedding import compute_split_embeddings

# Forwarded to `compute_split_embeddings` as `num_parallel_requests` by `Cohere.compute`.
NUM_PARALLEL_REQUESTS = 10
# Forwarded to `compute_split_embeddings` as its second positional argument by `Cohere.compute`.
# NOTE(review): presumably the number of texts sent per embed request, chosen to respect Cohere's
# per-request limit -- confirm against the Cohere API docs.
COHERE_BATCH_SIZE = 96


class Cohere(TextEmbeddingSignal):
  """Computes embeddings using Cohere's embedding API.

  <br>**Important**: This will send data to an external server!

  <br>To use this signal, you must get a Cohere API key from
  [cohere.com/embed](https://cohere.com/embed) and add it to your .env.local.

  <br>For details on pricing, see: https://cohere.com/pricing.
  """

  name = 'cohere'
  display_name = 'Cohere Embeddings'

  # Cohere API client. Only assigned in `setup()`, so `setup()` must run before `compute()`.
  _model: cohere.Client

  @override
  def setup(self) -> None:
    """Create the Cohere client from the configured API key.

    Raises:
      ValueError: If `COHERE_API_KEY` is missing or empty in `CONFIG`.
    """
    api_key = CONFIG.get('COHERE_API_KEY')
    if not api_key:
      raise ValueError('`COHERE_API_KEY` environment variable not set.')
    # Ask the client to retry failed requests rather than surfacing the first failure.
    self._model = cohere.Client(api_key, max_retries=10)

  @override
  def compute(self, docs: Iterable[RichData]) -> Iterable[Item]:
    """Compute embeddings for the given documents.

    Args:
      docs: The documents to embed. They are treated as an iterable of strings.

    Yields:
      The items produced by `compute_split_embeddings` for the given documents.
    """

    def embed_fn(texts: list[str]) -> list[np.ndarray]:
      """Embed one batch of texts with the Cohere API, one vector per input text."""
      response = self._model.embed(texts, truncate='END')
      # The client hands back plain Python lists of floats; convert each one so the result really
      # is a list[np.ndarray] as annotated. `np.asarray` is a no-op for values that are already
      # arrays and does not force a dtype.
      return [np.asarray(vector) for vector in response.embeddings]

    docs = cast(Iterable[str], docs)
    # Only chunk the documents when this signal instance is configured to split.
    split_fn = split_text if self._split else None
    yield from compute_split_embeddings(
      docs, COHERE_BATCH_SIZE, embed_fn, split_fn, num_parallel_requests=NUM_PARALLEL_REQUESTS)