jfeng1115's picture
init commit
58d33f0
"""Wrapper around Cohere embedding models."""
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra, root_validator
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
class CohereEmbeddings(BaseModel, Embeddings):
    """Wrapper around Cohere embedding models.

    To use, you should have the ``cohere`` python package installed, and the
    environment variable ``COHERE_API_KEY`` set with your API key or pass it
    as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from langchain.embeddings import CohereEmbeddings
            cohere = CohereEmbeddings(model="medium", cohere_api_key="my-api-key")
    """

    client: Any  #: :meta private:
    model: str = "large"
    """Model name to use."""
    truncate: Optional[str] = None
    """Truncate embeddings that are too long from start or end ("NONE"|"START"|"END")"""
    # Resolved in ``validate_environment`` from this field or ``COHERE_API_KEY``.
    cohere_api_key: Optional[str] = None

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment.

        Looks up the API key via ``get_from_dict_or_env`` (field
        ``cohere_api_key`` / variable ``COHERE_API_KEY``), imports ``cohere``
        and stores a ``cohere.Client`` built from that key under
        ``values["client"]``.

        Args:
            values: The field values collected by pydantic.

        Returns:
            The same ``values`` dict with ``"client"`` populated.

        Raises:
            ValueError: If the ``cohere`` package cannot be imported.
        """
        cohere_api_key = get_from_dict_or_env(
            values, "cohere_api_key", "COHERE_API_KEY"
        )
        # Keep the try body limited to the import so that only a genuinely
        # missing package is reported as such.
        try:
            import cohere
        except ImportError as exc:
            raise ValueError(
                "Could not import cohere python package. "
                "Please install it with `pip install cohere`."
            ) from exc
        values["client"] = cohere.Client(cohere_api_key)
        return values

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call out to Cohere's embedding endpoint.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        response = self.client.embed(
            model=self.model, texts=texts, truncate=self.truncate
        )
        # Coerce every component to a builtin float before returning.
        return [list(map(float, vector)) for vector in response.embeddings]

    def embed_query(self, text: str) -> List[float]:
        """Call out to Cohere's embedding endpoint.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        # A query is embedded as a one-element batch; reuse the batch path so
        # the request and the float conversion live in a single place.
        return self.embed_documents([text])[0]