Spaces:
Sleeping
Sleeping
from __future__ import annotations | |
import logging | |
from typing import Any, Dict, List, Optional | |
from langchain.embeddings.base import Embeddings | |
from langchain.pydantic_v1 import BaseModel, root_validator | |
from langchain.utils import get_from_dict_or_env | |
from FlagEmbedding import FlagModel | |
logger = logging.getLogger(__name__) | |
class LocalEmbed(BaseModel, Embeddings):
    """Embeddings computed locally with the `BAAI/bge-large-zh-v1.5` FlagModel.

    Passages are encoded with ``FlagModel.encode``; queries are encoded with
    ``FlagModel.encode_queries`` so that the retrieval instruction configured
    on the model is applied to them.
    """

    # Legacy Zhipuai API key. It is resolved from the constructor arguments or
    # the ZHIPUAI_API_KEY environment variable when available, but the local
    # model never uses it, so it may stay ``None``.
    zhipuai_api_key: Optional[str] = None

    # FlagModel instance used for encoding. Built by ``validate_environment``
    # unless the caller passes a ready-made client.
    client: Any = None

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the optional API key and make sure a model client exists.

        Args:
            values: the field values collected by pydantic for this instance.

        Returns:
            The same dictionary with ``zhipuai_api_key`` resolved (``None``
            when it is not configured anywhere) and ``client`` populated.
        """
        try:
            values["zhipuai_api_key"] = get_from_dict_or_env(
                values,
                "zhipuai_api_key",
                "ZHIPUAI_API_KEY",
            )
        except ValueError:
            # The helper raises when the key is found in neither place. The
            # embeddings are computed locally, so a missing key is not fatal.
            values["zhipuai_api_key"] = None

        # Only load the model when the caller did not inject a client.
        if values.get("client") is None:
            values["client"] = FlagModel(
                'BAAI/bge-large-zh-v1.5',
                query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
                use_fp16=True,
            )
        return values

    @staticmethod
    def _as_list(vectors: Any) -> Any:
        """Convert an encoder result (array-like) into plain Python lists."""
        if hasattr(vectors, "tolist"):
            return vectors.tolist()
        return list(vectors)

    def _embed(self, texts: str) -> List[float]:
        """Embed a single passage and return its vector as a list of floats.

        Args:
            texts: the passage to embed.

        Returns:
            The embedding of ``texts``.
        """
        logger.debug("cal embed: %s", texts)
        return self._as_list(self.client.encode(texts))

    def embed_query(self, text: str) -> List[float]:
        """Embed a search query.

        The query goes through ``encode_queries`` so the model's
        ``query_instruction_for_retrieval`` is applied to it.

        Args:
            text: the query to embed.

        Returns:
            The embedding of ``text`` as a list of floats.
        """
        logger.debug("cal embed: %s", text)
        return self._as_list(self.client.encode_queries([text]))[0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of passages in a single batched encoder call.

        Args:
            texts: the passages to embed.

        Returns:
            One embedding per input passage, in input order. An empty input
            yields an empty list without touching the model.
        """
        if not texts:
            return []
        logger.debug("cal embed: %s", texts)
        return self._as_list(self.client.encode(list(texts)))

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronous passage embedding (not supported).

        Raises:
            NotImplementedError: always; use ``embed_documents`` instead.
        """
        raise NotImplementedError(
            "Please use `embed_documents`. Official does not support asynchronous requests")

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronous query embedding (not supported).

        Raises:
            NotImplementedError: always; use ``embed_query`` instead.
        """
        raise NotImplementedError(
            "Please use `embed_query`. Official does not support asynchronous requests")