|
from __future__ import annotations |
|
|
|
import logging |
|
from typing import Any, Dict, List, Optional |
|
|
|
from langchain.embeddings.base import Embeddings |
|
from langchain.pydantic_v1 import BaseModel, root_validator |
|
from langchain.utils import get_from_dict_or_env |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class ZhipuAIEmbeddings(BaseModel, Embeddings): |
|
"""`Zhipuai Embeddings` embedding models.""" |
|
|
|
zhipuai_api_key: Optional[str] = None |
|
"""Zhipuai application apikey""" |
|
|
|
@root_validator() |
|
def validate_environment(cls, values: Dict) -> Dict: |
|
""" |
|
Validate whether zhipuai_api_key in the environment variables or |
|
configuration file are available or not. |
|
|
|
Args: |
|
|
|
values: a dictionary containing configuration information, must include the |
|
fields of zhipuai_api_key |
|
Returns: |
|
|
|
a dictionary containing configuration information. If zhipuai_api_key |
|
are not provided in the environment variables or configuration |
|
file, the original values will be returned; otherwise, values containing |
|
zhipuai_api_key will be returned. |
|
Raises: |
|
|
|
ValueError: zhipuai package not found, please install it with `pip install |
|
zhipuai` |
|
""" |
|
values["zhipuai_api_key"] = get_from_dict_or_env( |
|
values, |
|
"zhipuai_api_key", |
|
"ZHIPUAI_API_KEY", |
|
) |
|
|
|
try: |
|
import zhipuai |
|
zhipuai.api_key = values["zhipuai_api_key"] |
|
values["client"] = zhipuai.model_api |
|
|
|
except ImportError: |
|
raise ValueError( |
|
"Zhipuai package not found, please install it with " |
|
"`pip install zhipuai`" |
|
) |
|
return values |
|
|
|
def _embed(self, texts: str) -> List[float]: |
|
|
|
try: |
|
resp = self.client.invoke( |
|
model="text_embedding", |
|
prompt=texts |
|
) |
|
except Exception as e: |
|
raise ValueError(f"Error raised by inference endpoint: {e}") |
|
|
|
if resp["code"] != 200: |
|
raise ValueError( |
|
"Error raised by inference API HTTP code: %s, %s" |
|
% (resp["code"], resp["msg"]) |
|
) |
|
embeddings = resp["data"]["embedding"] |
|
return embeddings |
|
|
|
def embed_query(self, text: str) -> List[float]: |
|
""" |
|
Embedding a text. |
|
|
|
Args: |
|
|
|
Text (str): A text to be embedded. |
|
|
|
Return: |
|
|
|
List [float]: An embedding list of input text, which is a list of floating-point values. |
|
""" |
|
resp = self.embed_documents([text]) |
|
return resp[0] |
|
|
|
def embed_documents(self, texts: List[str]) -> List[List[float]]: |
|
""" |
|
Embeds a list of text documents. |
|
|
|
Args: |
|
texts (List[str]): A list of text documents to embed. |
|
|
|
Returns: |
|
List[List[float]]: A list of embeddings for each document in the input list. |
|
Each embedding is represented as a list of float values. |
|
""" |
|
return [self._embed(text) for text in texts] |
|
|
|
async def aembed_documents(self, texts: List[str]) -> List[List[float]]: |
|
"""Asynchronous Embed search docs.""" |
|
raise NotImplementedError( |
|
"Please use `embed_documents`. Official does not support asynchronous requests") |
|
|
|
async def aembed_query(self, text: str) -> List[float]: |
|
"""Asynchronous Embed query text.""" |
|
raise NotImplementedError( |
|
"Please use `aembed_query`. Official does not support asynchronous requests") |
|
|