# File: semantic-search/embeddings.py
# Last commit: 918336a "Add application file" (davila7); size 4.26 kB.
"""Wrapper around OpenAI embedding models."""
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra, root_validator
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from openai.error import Timeout, APIError, APIConnectionError, RateLimitError
class OpenAIEmbeddings(BaseModel, Embeddings):
    """Wrapper around OpenAI embedding models.

    To use, you should have the ``openai`` python package installed, and the
    environment variable ``OPENAI_API_KEY`` set with your API key, or pass
    the key as a named parameter to the constructor.

    Documents and queries may be embedded with different models; see
    ``document_model_name`` and ``query_model_name``.

    Example:
        .. code-block:: python

            from langchain.embeddings import OpenAIEmbeddings
            openai = OpenAIEmbeddings(openai_api_key="my-api-key")
    """

    # Set by ``validate_environment`` to ``openai.Embedding``.
    client: Any  #: :meta private:
    # Engine name passed to the API by ``embed_documents``.
    document_model_name: str = "text-embedding-ada-002"
    # Engine name passed to the API by ``embed_query``.
    query_model_name: str = "text-embedding-ada-002"
    # API key; when omitted it is looked up via ``get_from_dict_or_env``.
    openai_api_key: Optional[str] = None

    class Config:
        """Configuration for this pydantic object."""

        # Unknown constructor arguments are rejected.
        extra = Extra.forbid

    # TODO: deprecate this
    @root_validator(pre=True, allow_reuse=True)
    def get_model_names(cls, values: Dict) -> Dict:
        """Expand the legacy ``model_name`` argument into both model names.

        Declared with ``pre=True`` and pops ``model_name`` from ``values``:
        ``model_name`` is not a declared field, and ``Config`` forbids
        extra fields.

        Args:
            values: Raw constructor arguments.

        Returns:
            ``values``, with ``model_name`` (if present) replaced by the
            derived ``document_model_name`` and ``query_model_name``.

        Raises:
            ValueError: If ``model_name`` is combined with either
                ``document_model_name`` or ``query_model_name``.
        """
        if "model_name" in values:
            if "document_model_name" in values:
                raise ValueError(
                    "Both `model_name` and `document_model_name` were provided, "
                    "but only one should be."
                )
            if "query_model_name" in values:
                raise ValueError(
                    "Both `model_name` and `query_model_name` were provided, "
                    "but only one should be."
                )
            model_name = values.pop("model_name")
            values["document_model_name"] = f"text-search-{model_name}-doc-001"
            values["query_model_name"] = f"text-search-{model_name}-query-001"
        return values

    @root_validator(allow_reuse=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key and the ``openai`` package are available.

        Args:
            values: Field values of the model being constructed.

        Returns:
            ``values`` with ``client`` set to ``openai.Embedding``.

        Raises:
            ValueError: If the ``openai`` package cannot be imported.
        """
        openai_api_key = get_from_dict_or_env(
            values, "openai_api_key", "OPENAI_API_KEY"
        )
        # Keep the ``try`` body to the one statement that can raise ImportError.
        # NOTE(review): this module also imports ``openai.error`` at top level,
        # so a missing package likely fails at module import before this runs.
        try:
            import openai
        except ImportError as exc:
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            ) from exc

        # The key is set on the ``openai`` module itself, i.e. process-wide,
        # not per instance.
        openai.api_key = openai_api_key
        values["client"] = openai.Embedding
        return values

    @retry(
        reraise=True,
        stop=stop_after_attempt(100),
        wait=wait_exponential(multiplier=1, min=10, max=60),
        retry=(
            retry_if_exception_type(Timeout)
            | retry_if_exception_type(APIError)
            | retry_if_exception_type(APIConnectionError)
            | retry_if_exception_type(RateLimitError)
        ),
    )
    def _embedding_func(self, text: str, *, engine: str) -> List[float]:
        """Call out to OpenAI's embedding endpoint with exponential backoff.

        The call is retried on ``Timeout``, ``APIError``,
        ``APIConnectionError`` and ``RateLimitError``, stopping after 100
        attempts; with ``reraise=True`` the last error is re-raised once
        attempts are exhausted.

        Args:
            text: The text to embed.
            engine: Name of the embedding engine to request.

        Returns:
            The embedding taken from the first entry of the response's
            ``"data"`` list.
        """
        # replace newlines, which can negatively affect performance.
        text = text.replace("\n", " ")
        response = self.client.create(input=[text], engine=engine)
        return response["data"][0]["embedding"]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call out to OpenAI's embedding endpoint for embedding search docs.

        Each text is sent in its own request, using ``document_model_name``.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text, in input order.
        """
        return [
            self._embedding_func(text, engine=self.document_model_name)
            for text in texts
        ]

    def embed_query(self, text: str) -> List[float]:
        """Call out to OpenAI's embedding endpoint for embedding query text.

        Uses ``query_model_name`` as the engine.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self._embedding_func(text, engine=self.query_model_name)