from __future__ import annotations

import logging
from typing import Any, Dict, List, Optional

from langchain.embeddings.base import Embeddings
from langchain.pydantic_v1 import BaseModel, root_validator
from langchain.utils import get_from_dict_or_env
from FlagEmbedding import LLMEmbedder

logger = logging.getLogger(__name__)


class LocalEmbed(BaseModel, Embeddings):
    """Embedding model backed by a locally loaded FlagEmbedding encoder.

    On validation the class builds an ``LLMEmbedder`` for
    ``BAAI/bge-large-zh-v1.5`` and stores it as ``client``; every embedding
    call is delegated to ``client.encode``.

    NOTE(review): construction still requires a Zhipuai API key (field or
    ``ZHIPUAI_API_KEY`` environment variable) although no method in this
    class reads it. This looks like a leftover from a Zhipuai-based
    template; it is kept so existing callers and configuration keep working.
    """

    zhipuai_api_key: Optional[str] = None
    """Zhipuai application API key (resolved during validation, unused here)."""

    # Populated by ``validate_environment``; declared explicitly so the
    # attribute is a known field instead of an undeclared value.
    client: Any = None  #: :meta private:

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the API key and build the embedding client.

        Args:
            values: Field values collected for the model. ``zhipuai_api_key``
                is looked up here first, then in the ``ZHIPUAI_API_KEY``
                environment variable.

        Returns:
            The same dictionary with ``zhipuai_api_key`` resolved and
            ``client`` set to a freshly constructed ``LLMEmbedder``.

        Raises:
            ValueError: presumably raised by ``get_from_dict_or_env`` when the
                key is found in neither place -- confirm against the
                installed langchain version.
        """
        values["zhipuai_api_key"] = get_from_dict_or_env(
            values,
            "zhipuai_api_key",
            "ZHIPUAI_API_KEY",
        )
        # NOTE(review): ``query_instruction_for_retrieval`` is the keyword
        # FlagEmbedding documents for ``FlagModel``; confirm that the
        # installed ``LLMEmbedder`` accepts it and exposes ``encode``.
        values["client"] = LLMEmbedder(
            'BAAI/bge-large-zh-v1.5',
            query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
            use_fp16=True,
        )
        return values

    def _embed(self, texts: str) -> List[float]:
        """Encode a single text with the local client.

        Args:
            texts: The text to encode (one string, despite the plural name).

        Returns:
            The embedding produced by ``client.encode``. Results that expose
            ``tolist()`` (array-like objects) are converted to plain Python
            lists so the value matches the declared return type.
        """
        # Debug output goes through the module logger rather than stdout.
        logger.debug("cal embed: %s", texts)
        embeddings = self.client.encode(texts)
        to_list = getattr(embeddings, "tolist", None)
        if callable(to_list):
            return to_list()
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query text.

        The query goes through ``embed_documents``, i.e. the same encoding
        path as documents.
        NOTE(review): the retrieval instruction passed to the client is
        therefore not applied separately for queries here -- confirm whether
        a dedicated query-encoding call is intended.

        Args:
            text: The text to embed.

        Returns:
            The embedding of ``text``.
        """
        resp = self.embed_documents([text])
        return resp[0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of documents, one ``_embed`` call per document.

        Args:
            texts: The documents to embed.

        Returns:
            One embedding per input document, in input order. An empty input
            yields an empty list.
        """
        return [self._embed(text) for text in texts]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronous document embedding (not supported).

        Raises:
            NotImplementedError: always; use ``embed_documents`` instead.
        """
        raise NotImplementedError(
            "Please use `embed_documents`. Official does not support asynchronous requests")

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronous query embedding (not supported).

        Raises:
            NotImplementedError: always; use ``embed_query`` instead.
        """
        # The message previously pointed back at `aembed_query` itself.
        raise NotImplementedError(
            "Please use `embed_query`. Official does not support asynchronous requests")