Spaces:
Sleeping
Sleeping
from __future__ import annotations | |
import logging | |
from typing import Any, Dict, List, Optional | |
from langchain.embeddings.base import Embeddings | |
from langchain.pydantic_v1 import BaseModel, root_validator | |
from langchain.utils import get_from_dict_or_env | |
logger = logging.getLogger(__name__) | |
class ZhipuAIEmbeddings(BaseModel, Embeddings): | |
"""`Zhipuai Embeddings` embedding models.""" | |
zhipuai_api_key: Optional[str] = None | |
"""Zhipuai application apikey""" | |
def validate_environment(cls, values: Dict) -> Dict: | |
""" | |
Validate whether zhipuai_api_key in the environment variables or | |
configuration file are available or not. | |
Args: | |
values: a dictionary containing configuration information, must include the | |
fields of zhipuai_api_key | |
Returns: | |
a dictionary containing configuration information. If zhipuai_api_key | |
are not provided in the environment variables or configuration | |
file, the original values will be returned; otherwise, values containing | |
zhipuai_api_key will be returned. | |
Raises: | |
ValueError: zhipuai package not found, please install it with `pip install | |
zhipuai` | |
""" | |
values["zhipuai_api_key"] = get_from_dict_or_env( | |
values, | |
"zhipuai_api_key", | |
"ZHIPUAI_API_KEY", | |
) | |
try: | |
from zhipuai import ZhipuAI | |
values["client"] = ZhipuAI(api_key=values["zhipuai_api_key"]) | |
except ImportError: | |
raise ValueError( | |
"Zhipuai package not found, please install it with " | |
"`pip install zhipuai`" | |
) | |
return values | |
def _embed(self, texts: str) -> List[float]: | |
# send request | |
try: | |
print("cal embed:", texts) | |
resp = self.client.embeddings.create( | |
model="embedding-2", | |
input=texts | |
) | |
#print("resp:", resp) | |
except Exception as e: | |
raise ValueError(f"Error raised by inference endpoint: {e}") | |
if not resp.data : | |
raise ValueError( | |
"Error raised by inference API HTTP code: %s, %s" | |
% (resp["code"], resp["msg"]) | |
) | |
embeddings = resp.data[0].embedding | |
return embeddings | |
def embed_query(self, text: str) -> List[float]: | |
""" | |
Embedding a text. | |
Args: | |
Text (str): A text to be embedded. | |
Return: | |
List [float]: An embedding list of input text, which is a list of floating-point values. | |
""" | |
resp = self.embed_documents([text]) | |
return resp[0] | |
def embed_documents(self, texts: List[str]) -> List[List[float]]: | |
""" | |
Embeds a list of text documents. | |
Args: | |
texts (List[str]): A list of text documents to embed. | |
Returns: | |
List[List[float]]: A list of embeddings for each document in the input list. | |
Each embedding is represented as a list of float values. | |
""" | |
return [self._embed(text) for text in texts] | |
async def aembed_documents(self, texts: List[str]) -> List[List[float]]: | |
"""Asynchronous Embed search docs.""" | |
raise NotImplementedError( | |
"Please use `embed_documents`. Official does not support asynchronous requests") | |
async def aembed_query(self, text: str) -> List[float]: | |
"""Asynchronous Embed query text.""" | |
raise NotImplementedError( | |
"Please use `aembed_query`. Official does not support asynchronous requests") |