File size: 2,344 Bytes
55dc3dd
815971e
55dc3dd
 
 
 
 
c14732f
55dc3dd
 
 
 
 
815971e
 
 
55dc3dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
815971e
 
3600417
 
c14732f
3600417
 
815971e
 
 
 
 
 
 
3600417
55dc3dd
 
 
 
 
 
e9a1c18
 
 
 
 
815971e
55dc3dd
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""OpenAI embeddings."""
from typing import TYPE_CHECKING, Iterable, cast

import numpy as np
from tenacity import retry, stop_after_attempt, wait_random_exponential
from typing_extensions import override

from ..config import env
from ..schema import Item, RichData
from ..signals.signal import TextEmbeddingSignal
from ..signals.splitters.chunk_splitter import split_text
from .embedding import compute_split_embeddings

if TYPE_CHECKING:
  import openai

# Forwarded as `num_parallel_requests` to `compute_split_embeddings`.
NUM_PARALLEL_REQUESTS = 10
# Forwarded to `compute_split_embeddings` as its batch-size argument
# (presumably the maximum number of texts per `embed_fn` call -- confirm against that helper).
OPENAI_BATCH_SIZE = 128
# Model identifier sent with every embedding request.
EMBEDDING_MODEL = 'text-embedding-ada-002'


class OpenAI(TextEmbeddingSignal):
  """Computes embeddings using OpenAI's embedding API.

  <br>**Important**: This will send data to an external server!

  <br>To use this signal, you must get an OpenAI API key from
  [platform.openai.com](https://platform.openai.com/) and add it to your .env.local.

  <br>For details on pricing, see: https://openai.com/pricing.
  """

  name = 'openai'
  display_name = 'OpenAI Embeddings'

  # Handle to `openai.Embedding`; only assigned once `setup()` has run.
  _model: 'openai.Embedding'

  @override
  def setup(self) -> None:
    """Validate the API key and bind the OpenAI embedding endpoint.

    Raises:
      ValueError: If `env('OPENAI_API_KEY')` returns a falsy value.
      ImportError: If the `openai` package cannot be imported. The original import error is
        attached as the cause.
    """
    api_key = env('OPENAI_API_KEY')
    if not api_key:
      raise ValueError('`OPENAI_API_KEY` environment variable not set.')

    # `openai` is imported here rather than at module level (the top-level import is guarded by
    # `TYPE_CHECKING`). Only the import itself is guarded, so unrelated errors from the
    # configuration below are never mistaken for a missing package.
    try:
      import openai
    except ImportError as e:
      raise ImportError('Could not import the "openai" python package. '
                        'Please install it with `pip install openai`.') from e

    openai.api_key = api_key
    self._model = openai.Embedding

  @override
  def compute(self, docs: Iterable[RichData]) -> Iterable[Item]:
    """Compute embeddings for the given documents.

    Args:
      docs: The documents to embed. They are treated as strings (the `cast` only affects type
        checking; no runtime conversion happens).

    Yields:
      Whatever `compute_split_embeddings` yields for the documents.
    """

    # Each request is retried with randomized exponential backoff (waits bounded to 1-20s) and
    # abandoned after 10 attempts.
    @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(10))
    def embed_fn(texts: list[str]) -> list[np.ndarray]:
      """Embed one batch of texts, returning one float32 vector per entry in the response."""

      # Replace newlines, which can negatively affect performance.
      # See https://github.com/search?q=repo%3Aopenai%2Fopenai-python+replace+newlines&type=code
      texts = [text.replace('\n', ' ') for text in texts]

      response = self._model.create(input=texts, model=EMBEDDING_MODEL)
      return [np.array(embedding['embedding'], dtype=np.float32) for embedding in response['data']]

    docs = cast(Iterable[str], docs)
    # Only pass a splitter when `self._split` is truthy; otherwise hand over `None`.
    split_fn = split_text if self._split else None
    yield from compute_split_embeddings(
      docs, OPENAI_BATCH_SIZE, embed_fn, split_fn, num_parallel_requests=NUM_PARALLEL_REQUESTS)