"""PaLM embeddings.""" from typing import TYPE_CHECKING, Iterable, cast import numpy as np from tenacity import retry, stop_after_attempt, wait_random_exponential from typing_extensions import override from ..config import env from ..schema import Item, RichData from ..signals.signal import TextEmbeddingSignal from ..signals.splitters.chunk_splitter import split_text from .embedding import compute_split_embeddings if TYPE_CHECKING: import google.generativeai as palm PALM_BATCH_SIZE = 1 # PaLM API only supports batch size 1. NUM_PARALLEL_REQUESTS = 256 # Because batch size is 1, we can send many requests in parallel. EMBEDDING_MODEL = 'models/embedding-gecko-001' class PaLM(TextEmbeddingSignal): """Computes embeddings using PaLM's embedding API.
**Important**: This will send data to an external server!
To use this signal, you must get a PaLM API key from [makersuite.google.com](https://makersuite.google.com/app/apikey) and add it to your .env.local. """ name = 'palm' display_name = 'PaLM Embeddings' _model: 'palm' @override def setup(self) -> None: api_key = env('PALM_API_KEY') if not api_key: raise ValueError('`PALM_API_KEY` environment variable not set.') try: import google.generativeai as palm palm.configure(api_key=api_key) self._model = palm except ImportError: raise ImportError('Could not import the "google.generativeai" python package. ' 'Please install it with `pip install google-generativeai`.') @override def compute(self, docs: Iterable[RichData]) -> Iterable[Item]: """Compute embeddings for the given documents.""" @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(10)) def embed_fn(texts: list[str]) -> list[np.ndarray]: assert len(texts) == 1, 'PaLM API only supports batch size 1.' response = self._model.generate_embeddings(model=EMBEDDING_MODEL, text=texts[0]) return [np.array(response['embedding'], dtype=np.float32)] docs = cast(Iterable[str], docs) split_fn = split_text if self._split else None yield from compute_split_embeddings( docs, PALM_BATCH_SIZE, embed_fn, split_fn, num_parallel_requests=NUM_PARALLEL_REQUESTS)