guangliang.yin commited on
Commit
1bf8b22
1 Parent(s): b0c2444

切换为zhipu 向量计算 -1

Browse files
app.py CHANGED
@@ -21,6 +21,7 @@ from langchain.chains.combine_documents import create_stuff_documents_chain
21
  from langchain.chains import StuffDocumentsChain
22
  from langchain_core.prompts import PromptTemplate
23
  import hashlib
 
24
 
25
  chain: Optional[Callable] = None
26
 
@@ -46,7 +47,8 @@ def web_loader(file, openai_key, puzhiai_key, zilliz_uri, user, password):
46
 
47
  text_splitter = CharacterTextSplitter(chunk_size=1024, chunk_overlap=0)
48
  docs = text_splitter.split_documents(docs)
49
- embeddings = OpenAIEmbeddings(model="text-embedding-ada-002", openai_api_key=openai_key)
 
50
 
51
  if not embeddings:
52
  return "embeddings not"
@@ -69,6 +71,7 @@ def web_loader(file, openai_key, puzhiai_key, zilliz_uri, user, password):
69
  "password": password,
70
  "secure": True,
71
  },
 
72
  )
73
 
74
  if not docsearch:
@@ -149,7 +152,7 @@ if __name__ == "__main__":
149
  """
150
  <h1><center>Langchain And Zilliz App</center></h1>
151
 
152
- v.2.27.16.14
153
 
154
  """
155
  )
 
21
  from langchain.chains import StuffDocumentsChain
22
  from langchain_core.prompts import PromptTemplate
23
  import hashlib
24
+ from project.embeddings.zhipuai_embedding import ZhipuAIEmbeddings
25
 
26
  chain: Optional[Callable] = None
27
 
 
47
 
48
  text_splitter = CharacterTextSplitter(chunk_size=1024, chunk_overlap=0)
49
  docs = text_splitter.split_documents(docs)
50
+ #embeddings = OpenAIEmbeddings(model="text-embedding-ada-002", openai_api_key=openai_key)
51
+ embeddings = ZhipuAIEmbeddings(zhipuai_api_key=puzhiai_key)
52
 
53
  if not embeddings:
54
  return "embeddings not"
 
71
  "password": password,
72
  "secure": True,
73
  },
74
+ collection_name="LangChainCollectionYin"
75
  )
76
 
77
  if not docsearch:
 
152
  """
153
  <h1><center>Langchain And Zilliz App</center></h1>
154
 
155
+ v.2.27.17.18
156
 
157
  """
158
  )
project/embeddings/__init__.py ADDED
File without changes
project/embeddings/zhipuai_embedding.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from typing import Any, Dict, List, Optional
5
+
6
+ from langchain.embeddings.base import Embeddings
7
+ from langchain.pydantic_v1 import BaseModel, root_validator
8
+ from langchain.utils import get_from_dict_or_env
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ class ZhipuAIEmbeddings(BaseModel, Embeddings):
14
+ """`Zhipuai Embeddings` embedding models."""
15
+
16
+ zhipuai_api_key: Optional[str] = None
17
+ """Zhipuai application apikey"""
18
+
19
+ @root_validator()
20
+ def validate_environment(cls, values: Dict) -> Dict:
21
+ """
22
+ Validate whether zhipuai_api_key in the environment variables or
23
+ configuration file are available or not.
24
+
25
+ Args:
26
+
27
+ values: a dictionary containing configuration information, must include the
28
+ fields of zhipuai_api_key
29
+ Returns:
30
+
31
+ a dictionary containing configuration information. If zhipuai_api_key
32
+ are not provided in the environment variables or configuration
33
+ file, the original values will be returned; otherwise, values containing
34
+ zhipuai_api_key will be returned.
35
+ Raises:
36
+
37
+ ValueError: zhipuai package not found, please install it with `pip install
38
+ zhipuai`
39
+ """
40
+ values["zhipuai_api_key"] = get_from_dict_or_env(
41
+ values,
42
+ "zhipuai_api_key",
43
+ "ZHIPUAI_API_KEY",
44
+ )
45
+
46
+ try:
47
+ from zhipuai import ZhipuAI
48
+ values["client"] = ZhipuAI(api_key=values["zhipuai_api_key"])
49
+
50
+ except ImportError:
51
+ raise ValueError(
52
+ "Zhipuai package not found, please install it with "
53
+ "`pip install zhipuai`"
54
+ )
55
+ return values
56
+
57
+ def _embed(self, texts: str) -> List[float]:
58
+ # send request
59
+ try:
60
+ resp = self.client.embeddings.create(
61
+ model="embedding-2",
62
+ input=texts
63
+ )
64
+ except Exception as e:
65
+ raise ValueError(f"Error raised by inference endpoint: {e}")
66
+
67
+ if resp["code"] != 200:
68
+ raise ValueError(
69
+ "Error raised by inference API HTTP code: %s, %s"
70
+ % (resp["code"], resp["msg"])
71
+ )
72
+ embeddings = resp["data"]["embedding"]
73
+ return embeddings
74
+
75
+ def embed_query(self, text: str) -> List[float]:
76
+ """
77
+ Embedding a text.
78
+
79
+ Args:
80
+
81
+ Text (str): A text to be embedded.
82
+
83
+ Return:
84
+
85
+ List [float]: An embedding list of input text, which is a list of floating-point values.
86
+ """
87
+ resp = self.embed_documents([text])
88
+ return resp[0]
89
+
90
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
91
+ """
92
+ Embeds a list of text documents.
93
+
94
+ Args:
95
+ texts (List[str]): A list of text documents to embed.
96
+
97
+ Returns:
98
+ List[List[float]]: A list of embeddings for each document in the input list.
99
+ Each embedding is represented as a list of float values.
100
+ """
101
+ return [self._embed(text) for text in texts]
102
+
103
+ async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
104
+ """Asynchronous Embed search docs."""
105
+ raise NotImplementedError(
106
+ "Please use `embed_documents`. Official does not support asynchronous requests")
107
+
108
+ async def aembed_query(self, text: str) -> List[float]:
109
+ """Asynchronous Embed query text."""
110
+ raise NotImplementedError(
111
+ "Please use `aembed_query`. Official does not support asynchronous requests")