guangliang.yin committed on
Commit
a7b5657
1 Parent(s): 5756ae7

Initialize project
README.md CHANGED
@@ -1,10 +1,10 @@
  ---
- title: Hello Embed
- emoji: 🌍
- colorFrom: pink
- colorTo: indigo
+ title: Hello LangChain
+ emoji: 🐨
+ colorFrom: purple
+ colorTo: green
  sdk: gradio
- sdk_version: 4.19.2
+ sdk_version: 4.19.1
  app_file: app.py
  pinned: false
  license: mit
app.py ADDED
@@ -0,0 +1,163 @@
+ from typing import Callable, Optional
+
+ import gradio as gr
+ from langchain.vectorstores import Zilliz
+ from langchain.document_loaders import TextLoader
+ from langchain.text_splitter import CharacterTextSplitter
+ from langchain.chains import RetrievalQAWithSourcesChain
+ from langchain.chains.llm import LLMChain
+ from langchain.chains import StuffDocumentsChain
+ from langchain_core.prompts import PromptTemplate
+ import hashlib
+ import os
+ from project.embeddings.local_embed import LocalEmbed
+ from project.llm.check_embed_llm import CheckEmbedLlm
+
+ chain: Optional[Callable] = None
+
+ db_host = os.getenv("DB_HOST")
+ db_user = os.getenv("DB_USER")
+ db_password = os.getenv("DB_PASSWORD")
+ zhipuai_api_key = os.getenv("ZHIPU_AI_KEY")
+
+
+ def generate_article_id(content):
+     # Hash the chunk content with SHA-256
+     sha256 = hashlib.sha256()
+
+     # Encode the content as UTF-8 bytes and update the hash object
+     sha256.update(content.encode('utf-8'))
+
+     # Use the hex digest as a stable, deduplicating document ID
+     article_id = sha256.hexdigest()
+
+     return article_id
+
+
+ def web_loader(file):
+     if not file:
+         return "please upload a file first"
+     # NOTE: TextLoader reads plain text; .docx/.pdf would need dedicated loaders
+     loader = TextLoader(file)
+     docs = loader.load()
+
+     text_splitter = CharacterTextSplitter(chunk_size=512, chunk_overlap=0)
+     docs = text_splitter.split_documents(docs)
+     # embeddings = OpenAIEmbeddings(model="text-embedding-ada-002", openai_api_key=openai_key)
+     # embeddings = ZhipuAIEmbeddings(zhipuai_api_key=zhipuai_api_key)
+     embeddings = LocalEmbed(zhipuai_api_key=zhipuai_api_key)
+
+     if not embeddings:
+         return "failed to create embeddings"
+
+     texts = [d.page_content for d in docs]
+     article_ids = []
+     # Derive a content-hash ID for every chunk so re-uploading the same
+     # file overwrites records instead of duplicating them
+     for text in texts:
+         article_id = generate_article_id(text)
+         article_ids.append(article_id)
+
+     docsearch = Zilliz.from_documents(
+         docs,
+         embedding=embeddings,
+         ids=article_ids,
+         connection_args={
+             "uri": db_host,
+             "user": db_user,
+             "password": db_password,
+             "secure": True,
+         },
+         collection_name="LangChainCollectionYin"
+     )
+
+     if not docsearch:
+         return "failed to build the vector store"
+
+     global chain
+
+     llm = CheckEmbedLlm(model="glm-3-turbo", temperature=0.1, zhipuai_api_key=zhipuai_api_key)
+
+     document_prompt = PromptTemplate(
+         input_variables=["page_content"],
+         template="{page_content}"
+     )
+     document_variable_name = "context"
+     # The prompt here should take as an input variable the
+     # `document_variable_name`
+     prompt = PromptTemplate.from_template(
+         """查询到的文档如下:
+ {context}
+
+ 问题: {question}
+ 答:"""
+     )
+     llm_chain = LLMChain(llm=llm, prompt=prompt)
+     combine_documents_chain = StuffDocumentsChain(
+         llm_chain=llm_chain,
+         document_prompt=document_prompt,
+         document_variable_name=document_variable_name
+     )
+
+     chain = RetrievalQAWithSourcesChain(combine_documents_chain=combine_documents_chain,
+                                         retriever=docsearch.as_retriever(search_kwargs={'k': 3}))
+     return "data loaded successfully"
+
+
+ def query(question):
+     global chain
+     # e.g. "What is milvus?"
+     if not chain:
+         return "please load the data first"
+     return chain(inputs={"question": question}, return_only_outputs=True).get(
+         "answer", "failed to get an answer"
+     )
+
+
+ if __name__ == "__main__":
+     block = gr.Blocks()
+     with block as demo:
+         gr.Markdown(
+             """
+             <h1><center>LangChain And Embed App</center></h1>
+
+             v.2.29.14.55
+
+             """
+         )
+         # url_list_text = gr.Textbox(
+         #     label="url list",
+         #     lines=3,
+         #     placeholder="https://milvus.io/docs/overview.md",
+         # )
+         file = gr.File(label='请上传知识库文件\n可以处理 .txt, .md, .docx, .pdf 结尾的文件',
+                        file_types=['.txt', '.md', '.docx', '.pdf'])
+         # openai_key_text = gr.Textbox(label="openai api key", type="password", placeholder="sk-******")
+         # puzhiai_key_text = gr.Textbox(label="puzhi api key", type="password", placeholder="******")
+
+         loader_output = gr.Textbox(label="load status")
+         loader_btn = gr.Button("Load Data")
+         loader_btn.click(
+             fn=web_loader,
+             inputs=[
+                 file,
+             ],
+             outputs=loader_output,
+             api_name="web_load",
+         )
+
+         question_text = gr.Textbox(
+             label="question",
+             lines=3,
+             placeholder="What is milvus?",
+         )
+         query_output = gr.Textbox(label="question answer", lines=3)
+         query_btn = gr.Button("Generate")
+         query_btn.click(
+             fn=query,
+             inputs=[question_text],
+             outputs=query_output,
+             api_name="generate_answer",
+         )
+
+     demo.queue().launch(server_name="0.0.0.0", share=False)
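
Note: because both click handlers register an explicit api_name, the Space can also be driven programmatically. A minimal sketch using gradio_client follows; the local URL and the sample file name are assumptions, and depending on the gradio_client version the file argument may need to be wrapped with gradio_client.handle_file.

from gradio_client import Client

client = Client("http://localhost:7860")  # assumed local launch address
# load a knowledge-base file, then ask a question against it
print(client.predict("kb.txt", api_name="/web_load"))
print(client.predict("What is milvus?", api_name="/generate_answer"))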
project/embeddings/__init__.py ADDED
File without changes
project/embeddings/local_embed.py ADDED
@@ -0,0 +1,95 @@
+ from __future__ import annotations
+
+ import logging
+ from typing import Any, Dict, List, Optional
+
+ from langchain.embeddings.base import Embeddings
+ from langchain.pydantic_v1 import BaseModel, root_validator
+ from langchain.utils import get_from_dict_or_env
+ from FlagEmbedding import FlagModel
+
+ logger = logging.getLogger(__name__)
+
+
+ class LocalEmbed(BaseModel, Embeddings):
+     """Local embedding model backed by FlagEmbedding's BAAI/bge-large-zh-v1.5."""
+
+     zhipuai_api_key: Optional[str] = None
+     """Zhipuai application apikey (kept for interface parity; unused by the local model)"""
+
+     client: Any = None
+     """The loaded FlagEmbedding model"""
+
+     @root_validator()
+     def validate_environment(cls, values: Dict) -> Dict:
+         """
+         Validate that zhipuai_api_key is available in the environment variables
+         or the configuration, then load the local embedding model.
+
+         Args:
+             values: a dictionary of configuration information; must include the
+                 zhipuai_api_key field
+
+         Returns:
+             the values dictionary with zhipuai_api_key resolved and the loaded
+             model attached as ``client``
+         """
+         values["zhipuai_api_key"] = get_from_dict_or_env(
+             values,
+             "zhipuai_api_key",
+             "ZHIPUAI_API_KEY",
+         )
+
+         # bge retrieval models expect this Chinese instruction prepended to
+         # queries; FlagModel applies it via query_instruction_for_retrieval
+         # (the original code passed these arguments to LLMEmbedder, which
+         # targets a different model family)
+         values["client"] = FlagModel('BAAI/bge-large-zh-v1.5',
+                                      query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
+                                      use_fp16=True)
+         return values
+
+     def _embed(self, text: str) -> List[float]:
+         logger.debug("embedding text: %s", text)
+
+         # encode() returns a numpy array; convert it to a plain list of floats
+         embedding = self.client.encode(text)
+
+         return embedding.tolist()
+
+     def embed_query(self, text: str) -> List[float]:
+         """
+         Embed a single text.
+
+         Args:
+             text (str): the text to be embedded.
+
+         Returns:
+             List[float]: the embedding of the input text, a list of floats.
+         """
+         # NOTE: queries would ideally go through encode_queries so the
+         # retrieval instruction is applied; kept funneled through _embed
+         # as in the original
+         resp = self.embed_documents([text])
+         return resp[0]
+
+     def embed_documents(self, texts: List[str]) -> List[List[float]]:
+         """
+         Embed a list of text documents.
+
+         Args:
+             texts (List[str]): a list of text documents to embed.
+
+         Returns:
+             List[List[float]]: one embedding per input document, each a list
+                 of float values.
+         """
+         return [self._embed(text) for text in texts]
+
+     async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+         """Asynchronously embed search docs."""
+         raise NotImplementedError(
+             "Please use `embed_documents`. This model does not support asynchronous requests")
+
+     async def aembed_query(self, text: str) -> List[float]:
+         """Asynchronously embed query text."""
+         raise NotImplementedError(
+             "Please use `embed_query`. This model does not support asynchronous requests")
project/embeddings/zhipuai_embedding.py ADDED
@@ -0,0 +1,113 @@
+ from __future__ import annotations
+
+ import logging
+ from typing import Any, Dict, List, Optional
+
+ from langchain.embeddings.base import Embeddings
+ from langchain.pydantic_v1 import BaseModel, root_validator
+ from langchain.utils import get_from_dict_or_env
+
+ logger = logging.getLogger(__name__)
+
+
+ class ZhipuAIEmbeddings(BaseModel, Embeddings):
+     """ZhipuAI embedding models (embedding-2)."""
+
+     zhipuai_api_key: Optional[str] = None
+     """Zhipuai application apikey"""
+
+     client: Any = None
+     """The ZhipuAI client instance"""
+
+     @root_validator()
+     def validate_environment(cls, values: Dict) -> Dict:
+         """
+         Validate that zhipuai_api_key is available in the environment variables
+         or the configuration, and build the API client.
+
+         Args:
+             values: a dictionary of configuration information; must include the
+                 zhipuai_api_key field
+
+         Returns:
+             the values dictionary with zhipuai_api_key resolved and the client
+             attached
+
+         Raises:
+             ValueError: zhipuai package not found, please install it with
+                 `pip install zhipuai`
+         """
+         values["zhipuai_api_key"] = get_from_dict_or_env(
+             values,
+             "zhipuai_api_key",
+             "ZHIPUAI_API_KEY",
+         )
+
+         try:
+             from zhipuai import ZhipuAI
+             values["client"] = ZhipuAI(api_key=values["zhipuai_api_key"])
+
+         except ImportError:
+             raise ValueError(
+                 "Zhipuai package not found, please install it with "
+                 "`pip install zhipuai`"
+             )
+         return values
+
+     def _embed(self, text: str) -> List[float]:
+         # Send one embedding request per text
+         try:
+             logger.debug("embedding text: %s", text)
+             resp = self.client.embeddings.create(
+                 model="embedding-2",
+                 input=text
+             )
+         except Exception as e:
+             raise ValueError(f"Error raised by inference endpoint: {e}")
+
+         if not resp.data:
+             raise ValueError(f"Error raised by inference API: empty response {resp}")
+         embeddings = resp.data[0].embedding
+         return embeddings
+
+     def embed_query(self, text: str) -> List[float]:
+         """
+         Embed a single text.
+
+         Args:
+             text (str): the text to be embedded.
+
+         Returns:
+             List[float]: the embedding of the input text, a list of floats.
+         """
+         resp = self.embed_documents([text])
+         return resp[0]
+
+     def embed_documents(self, texts: List[str]) -> List[List[float]]:
+         """
+         Embed a list of text documents.
+
+         Args:
+             texts (List[str]): a list of text documents to embed.
+
+         Returns:
+             List[List[float]]: one embedding per input document, each a list
+                 of float values.
+         """
+         return [self._embed(text) for text in texts]
+
+     async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+         """Asynchronously embed search docs."""
+         raise NotImplementedError(
+             "Please use `embed_documents`. The official SDK does not support asynchronous requests")
+
+     async def aembed_query(self, text: str) -> List[float]:
+         """Asynchronously embed query text."""
+         raise NotImplementedError(
+             "Please use `embed_query`. The official SDK does not support asynchronous requests")
project/llm/__init__.py ADDED
File without changes
project/llm/check_embed_llm.py ADDED
@@ -0,0 +1,238 @@
+ #!/usr/bin/env python
+ # -*- encoding: utf-8 -*-
+ '''
+ @File    :   check_embed_llm.py
+ @Time    :   2023/10/16 22:06:26
+ @Author  :   0-yy-0
+ @Version :   1.0
+ @Contact :   310484121@qq.com
+ @License :   (C)Copyright 2017-2018, Liugroup-NLPR-CASIA
+ @Desc    :   Debug LLM class built on the Zhipu AI API: it assembles the chat
+              request but returns it without calling the model, so the
+              retrieved context can be inspected
+ '''
+
+ from __future__ import annotations
+
+ import logging
+ from typing import (
+     Any,
+     AsyncIterator,
+     Dict,
+     Iterator,
+     List,
+     Optional,
+ )
+
+ from langchain.callbacks.manager import (
+     AsyncCallbackManagerForLLMRun,
+     CallbackManagerForLLMRun,
+ )
+ from langchain.pydantic_v1 import Field, root_validator
+ from langchain.schema.output import GenerationChunk
+ from langchain.utils import get_from_dict_or_env
+ from project.llm.self_llm import Self_LLM
+ import re
+
+ logger = logging.getLogger(__name__)
+
+
+ class CheckEmbedLlm(Self_LLM):
+     """Debug wrapper for ZhipuAI-hosted models.
+
+     Instead of querying the model, `_call` returns the request payload it
+     would have sent, which makes it easy to check what the retriever stuffed
+     into the prompt.
+
+     To use, you should have the ``zhipuai`` python package installed, and the
+     environment variable ``ZHIPUAI_API_KEY`` set with your API key, which you
+     can get from https://open.bigmodel.cn/usercenter/apikeys
+
+     Example:
+         .. code-block:: python
+
+             from project.llm.check_embed_llm import CheckEmbedLlm
+             check_llm = CheckEmbedLlm(model="glm-3-turbo", temperature=0.1)
+
+     """
+
+     model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+
+     client: Any
+
+     model: str = "chatglm_std"
+     """Model name: one of chatglm_pro, chatglm_std, chatglm_lite."""
+
+     zhipuai_api_key: Optional[str] = None
+
+     incremental: Optional[bool] = True
+     """Whether to return the results incrementally."""
+
+     streaming: Optional[bool] = False
+     """Whether to stream the results."""
+
+     request_timeout: Optional[int] = 60
+     """Request timeout for chat http requests."""
+
+     top_p: Optional[float] = 0.8
+     temperature: Optional[float] = 0.95
+     request_id: Optional[str] = None
+
+     @root_validator()
+     def validate_environment(cls, values: Dict) -> Dict:
+
+         values["zhipuai_api_key"] = get_from_dict_or_env(
+             values,
+             "zhipuai_api_key",
+             "ZHIPUAI_API_KEY",
+         )
+
+         try:
+             from zhipuai import ZhipuAI
+
+             values["client"] = ZhipuAI(api_key=values["zhipuai_api_key"])
+         except ImportError:
+             raise ValueError(
+                 "zhipuai package not found, please install it with "
+                 "`pip install zhipuai`"
+             )
+         return values
+
+     @property
+     def _identifying_params(self) -> Dict[str, Any]:
+         return {
+             **{"model": self.model},
+             **super()._identifying_params,
+         }
+
+     @property
+     def _llm_type(self) -> str:
+         """Return type of llm."""
+         return "zhipuai"
+
+     @property
+     def _default_params(self) -> Dict[str, Any]:
+         """Get the default parameters for calling the ZhipuAI API."""
+         normal_params = {
+             "streaming": self.streaming,
+             "top_p": self.top_p,
+             "temperature": self.temperature,
+             "request_id": self.request_id,
+         }
+
+         return {**normal_params, **self.model_kwargs}
+
+     def _convert_prompt_msg_params(
+         self,
+         prompt: str,
+         **kwargs: Any,
+     ) -> dict:
+         return {
+             **{"prompt": prompt, "model": self.model},
+             **self._default_params,
+             **kwargs,
+         }
+
+     def _call(
+         self,
+         prompt: str,
+         stop: Optional[List[str]] = None,
+         run_manager: Optional[CallbackManagerForLLMRun] = None,
+         **kwargs: Any,
+     ) -> str:
+         """Build the ZhipuAI chat request for a prompt and return it unsent.
+
+         Args:
+             prompt: The prompt to pass into the model.
+         Returns:
+             The request payload, rendered as a string, for inspection.
+         """
+         if self.streaming:
+             completion = ""
+             for chunk in self._stream(prompt, stop, run_manager, **kwargs):
+                 completion += chunk.text
+             return completion
+         params = self._convert_prompt_msg_params(prompt, **kwargs)
+
+         all_word = params['prompt']
+
+         # Split at the last occurrence of "问题": the retrieved context before
+         # it becomes the system message, the question itself the user message
+         keyword = "问题"
+         matches = re.finditer(keyword, all_word)
+         indexes = [match.start() for match in matches]
+         if not indexes:
+             messages = [{"role": "user", "content": all_word}]
+         else:
+             last_index = indexes[-1]
+             messages = [
+                 {"role": "system", "content": all_word[:last_index]},
+                 {"role": "user", "content": all_word[last_index:]},
+             ]
+
+         params = {"messages": messages,
+                   "model": self.model, "stream": False, "top_p": 0.8,
+                   "temperature": 0.01, "request_id": None}
+
+         logger.debug("params: %s", params)
+         # Deliberately skip the API call and echo the built request instead
+         response_payload = params
+
+         return str(response_payload)
+
+     async def _acall(
+         self,
+         prompt: str,
+         stop: Optional[List[str]] = None,
+         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+         **kwargs: Any,
+     ) -> str:
+         if self.streaming:
+             completion = ""
+             async for chunk in self._astream(prompt, stop, run_manager, **kwargs):
+                 completion += chunk.text
+             return completion
+
+         params = self._convert_prompt_msg_params(prompt, **kwargs)
+
+         # NOTE: async_invoke comes from the legacy zhipuai SDK and is kept
+         # as in the original; the sync path above never calls the API
+         response = await self.client.async_invoke(**params)
+
+         return response
+
+     def _stream(
+         self,
+         prompt: str,
+         stop: Optional[List[str]] = None,
+         run_manager: Optional[CallbackManagerForLLMRun] = None,
+         **kwargs: Any,
+     ) -> Iterator[GenerationChunk]:
+         # NOTE: invoke() is a legacy zhipuai SDK call; streaming is off by default
+         params = self._convert_prompt_msg_params(prompt, **kwargs)
+
+         for res in self.client.invoke(**params):
+             if res:
+                 chunk = GenerationChunk(text=res)
+                 yield chunk
+                 if run_manager:
+                     run_manager.on_llm_new_token(chunk.text)
+
+     async def _astream(
+         self,
+         prompt: str,
+         stop: Optional[List[str]] = None,
+         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+         **kwargs: Any,
+     ) -> AsyncIterator[GenerationChunk]:
+         # NOTE: ado() is a legacy zhipuai SDK call, kept as in the original
+         params = self._convert_prompt_msg_params(prompt, **kwargs)
+
+         async for res in await self.client.ado(**params):
+             if res:
+                 chunk = GenerationChunk(text=res["data"]["choices"]["content"])
+
+                 yield chunk
+                 if run_manager:
+                     await run_manager.on_llm_new_token(chunk.text)
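
Note: the split in _call is the core of this class. A standalone illustration with a hypothetical stuffed prompt, showing how the last "问题" becomes the system/user boundary:

import re

prompt = "查询到的文档如下:\n...retrieved chunks...\n\n问题: What is milvus?\n答:"
idx = [m.start() for m in re.finditer("问题", prompt)][-1]
system, user = prompt[:idx], prompt[idx:]
# system -> the retrieved context block
# user   -> "问题: What is milvus?\n答:"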
project/llm/self_llm.py ADDED
@@ -0,0 +1,47 @@
+ #!/usr/bin/env python
+ # -*- encoding: utf-8 -*-
+ '''
+ @File    :   self_llm.py
+ @Time    :   2023/10/16 18:48:08
+ @Author  :   Logan Zou
+ @Version :   1.0
+ @Contact :   loganzou0421@163.com
+ @License :   (C)Copyright 2017-2018, Liugroup-NLPR-CASIA
+ @Desc    :   Project base class wrapped around the LangChain LLM interface;
+              unifies calls to the GPT, Wenxin, Xunfei (Spark) and Zhipu APIs
+ '''
+
+ from langchain.llms.base import LLM
+ from typing import Dict, Any, Mapping, Optional
+ from langchain.pydantic_v1 import Field
+
+
+ class Self_LLM(LLM):
+     # Custom LLM base class, inheriting from langchain.llms.base.LLM
+     # Native API endpoint URL
+     url: Optional[str] = None
+     # Default model name
+     model_name: str = "gpt-3.5-turbo"
+     # Request timeout ceiling
+     request_timeout: Optional[float] = None
+     # Sampling temperature
+     temperature: float = 0.1
+     # API key
+     api_key: Optional[str] = None
+     # Extra model keyword arguments
+     model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+
+     # Return the default call parameters
+     @property
+     def _default_params(self) -> Dict[str, Any]:
+         """Get the default parameters for a call."""
+         normal_params = {
+             "temperature": self.temperature,
+             "request_timeout": self.request_timeout,
+         }
+         return {**normal_params}
+
+     @property
+     def _identifying_params(self) -> Mapping[str, Any]:
+         """Get the identifying parameters."""
+         return {**{"model_name": self.model_name}, **self._default_params}
project/llm/zhipuai_llm.py ADDED
@@ -0,0 +1,239 @@
+ #!/usr/bin/env python
+ # -*- encoding: utf-8 -*-
+ '''
+ @File    :   zhipuai_llm.py
+ @Time    :   2023/10/16 22:06:26
+ @Author  :   0-yy-0
+ @Version :   1.0
+ @Contact :   310484121@qq.com
+ @License :   (C)Copyright 2017-2018, Liugroup-NLPR-CASIA
+ @Desc    :   Custom LLM class built on the Zhipu AI models
+ '''
+
+ from __future__ import annotations
+
+ import logging
+ from typing import (
+     Any,
+     AsyncIterator,
+     Dict,
+     Iterator,
+     List,
+     Optional,
+ )
+
+ from langchain.callbacks.manager import (
+     AsyncCallbackManagerForLLMRun,
+     CallbackManagerForLLMRun,
+ )
+ from langchain.pydantic_v1 import Field, root_validator
+ from langchain.schema.output import GenerationChunk
+ from langchain.utils import get_from_dict_or_env
+ from project.llm.self_llm import Self_LLM
+ import re
+
+ logger = logging.getLogger(__name__)
+
+
+ class ZhipuAILLM(Self_LLM):
+     """Zhipuai hosted open source or customized models.
+
+     To use, you should have the ``zhipuai`` python package installed, and the
+     environment variable ``ZHIPUAI_API_KEY`` set with your API key, which you
+     can get from https://open.bigmodel.cn/usercenter/apikeys
+
+     Example:
+         .. code-block:: python
+
+             from project.llm.zhipuai_llm import ZhipuAILLM
+             zhipuai_model = ZhipuAILLM(model="chatglm_std", temperature=temperature)
+
+     """
+
+     model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+
+     client: Any
+
+     model: str = "chatglm_std"
+     """Model name: one of chatglm_pro, chatglm_std, chatglm_lite."""
+
+     zhipuai_api_key: Optional[str] = None
+
+     incremental: Optional[bool] = True
+     """Whether to return the results incrementally."""
+
+     streaming: Optional[bool] = False
+     """Whether to stream the results."""
+
+     request_timeout: Optional[int] = 60
+     """Request timeout for chat http requests."""
+
+     top_p: Optional[float] = 0.8
+     temperature: Optional[float] = 0.95
+     request_id: Optional[str] = None
+
+     @root_validator()
+     def validate_environment(cls, values: Dict) -> Dict:
+
+         values["zhipuai_api_key"] = get_from_dict_or_env(
+             values,
+             "zhipuai_api_key",
+             "ZHIPUAI_API_KEY",
+         )
+
+         try:
+             from zhipuai import ZhipuAI
+
+             values["client"] = ZhipuAI(api_key=values["zhipuai_api_key"])
+         except ImportError:
+             raise ValueError(
+                 "zhipuai package not found, please install it with "
+                 "`pip install zhipuai`"
+             )
+         return values
+
+     @property
+     def _identifying_params(self) -> Dict[str, Any]:
+         return {
+             **{"model": self.model},
+             **super()._identifying_params,
+         }
+
+     @property
+     def _llm_type(self) -> str:
+         """Return type of llm."""
+         return "zhipuai"
+
+     @property
+     def _default_params(self) -> Dict[str, Any]:
+         """Get the default parameters for calling the ZhipuAI API."""
+         normal_params = {
+             "streaming": self.streaming,
+             "top_p": self.top_p,
+             "temperature": self.temperature,
+             "request_id": self.request_id,
+         }
+
+         return {**normal_params, **self.model_kwargs}
+
+     def _convert_prompt_msg_params(
+         self,
+         prompt: str,
+         **kwargs: Any,
+     ) -> dict:
+         return {
+             **{"prompt": prompt, "model": self.model},
+             **self._default_params,
+             **kwargs,
+         }
+
+     def _call(
+         self,
+         prompt: str,
+         stop: Optional[List[str]] = None,
+         run_manager: Optional[CallbackManagerForLLMRun] = None,
+         **kwargs: Any,
+     ) -> str:
+         """Call out to a ZhipuAI model endpoint for each generation with a prompt.
+
+         Args:
+             prompt: The prompt to pass into the model.
+         Returns:
+             The string generated by the model.
+
+         Example:
+             .. code-block:: python
+
+                 response = zhipuai_model("Tell me a joke.")
+         """
+         if self.streaming:
+             completion = ""
+             for chunk in self._stream(prompt, stop, run_manager, **kwargs):
+                 completion += chunk.text
+             return completion
+         params = self._convert_prompt_msg_params(prompt, **kwargs)
+
+         all_word = params['prompt']
+
+         # Split at the last occurrence of "问题": the retrieved context before
+         # it becomes the system message, the question itself the user message
+         keyword = "问题"
+         matches = re.finditer(keyword, all_word)
+         indexes = [match.start() for match in matches]
+         if not indexes:
+             messages = [{"role": "user", "content": all_word}]
+         else:
+             last_index = indexes[-1]
+             messages = [
+                 {"role": "system", "content": all_word[:last_index]},
+                 {"role": "user", "content": all_word[last_index:]},
+             ]
+
+         params = {"messages": messages,
+                   "model": self.model, "stream": False, "top_p": 0.8,
+                   "temperature": 0.01, "request_id": None}
+
+         logger.debug("params: %s", params)
+         response_payload = self.client.chat.completions.create(**params)
+         logger.debug("response_payload: %s", response_payload)
+
+         return response_payload.choices[0].message.content
+
+     async def _acall(
+         self,
+         prompt: str,
+         stop: Optional[List[str]] = None,
+         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+         **kwargs: Any,
+     ) -> str:
+         if self.streaming:
+             completion = ""
+             async for chunk in self._astream(prompt, stop, run_manager, **kwargs):
+                 completion += chunk.text
+             return completion
+
+         params = self._convert_prompt_msg_params(prompt, **kwargs)
+
+         # NOTE: async_invoke comes from the legacy zhipuai SDK and is kept
+         # as in the original; the sync path above uses the v2 client
+         response = await self.client.async_invoke(**params)
+
+         return response
+
+     def _stream(
+         self,
+         prompt: str,
+         stop: Optional[List[str]] = None,
+         run_manager: Optional[CallbackManagerForLLMRun] = None,
+         **kwargs: Any,
+     ) -> Iterator[GenerationChunk]:
+         # NOTE: invoke() is a legacy zhipuai SDK call; streaming is off by default
+         params = self._convert_prompt_msg_params(prompt, **kwargs)
+
+         for res in self.client.invoke(**params):
+             if res:
+                 chunk = GenerationChunk(text=res)
+                 yield chunk
+                 if run_manager:
+                     run_manager.on_llm_new_token(chunk.text)
+
+     async def _astream(
+         self,
+         prompt: str,
+         stop: Optional[List[str]] = None,
+         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+         **kwargs: Any,
+     ) -> AsyncIterator[GenerationChunk]:
+         # NOTE: ado() is a legacy zhipuai SDK call, kept as in the original
+         params = self._convert_prompt_msg_params(prompt, **kwargs)
+
+         async for res in await self.client.ado(**params):
+             if res:
+                 chunk = GenerationChunk(text=res["data"]["choices"]["content"])
+
+                 yield chunk
+                 if run_manager:
+                     await run_manager.on_llm_new_token(chunk.text)
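
Note: a minimal call sketch for ZhipuAILLM, mirroring how CheckEmbedLlm is constructed in app.py; it assumes ZHIPUAI_API_KEY is set and the zhipuai v2 SDK is installed.

from project.llm.zhipuai_llm import ZhipuAILLM

llm = ZhipuAILLM(model="glm-3-turbo", temperature=0.1)
# the prompt mimics what StuffDocumentsChain produces: context, then 问题
print(llm("查询到的文档如下:\nMilvus 是一个向量数据库。\n\n问题: 什么是 Milvus?\n答:"))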
project/prompt/__init__.py ADDED
File without changes
project/prompt/answer_by_private_prompt.py ADDED
@@ -0,0 +1,63 @@
+ from langchain_core.prompts import PromptTemplate
+
+ question_prompt_template = """用提供给你的文档去回答问题,不需要编造或者虚构答案,也不需要回答文档之外的内容。
+ 如果在文档中没有找到相关的答案,那么就直接回答'知识库中没有相关问题解答'
+ 请用中文回答。
+ 下边是我给你提供的文档,其中文档格式都是一问一答。问题答案也完全来自所提供的回答:
+ {context}
+
+ 问题: {question}
+ 答:"""
+ QUESTION_PROMPT = PromptTemplate(
+     template=question_prompt_template, input_variables=["context", "question"]
+ )
+
+ combine_prompt_template = """
+ QUESTION: {question}
+ =========
+ {summaries}
+ =========
+ FINAL ANSWER:"""
+ COMBINE_PROMPT = PromptTemplate(
+     template=combine_prompt_template, input_variables=["summaries", "question"]
+ )
+
+ EXAMPLE_PROMPT = PromptTemplate(
+     template="Content: {page_content}\nSource: {source}",
+     input_variables=["page_content", "source"],
+ )
+
+ DEFAULT_REFINE_PROMPT_TMPL = (
+     "The original question is as follows: {question}\n"
+     "We have provided an existing answer, including sources: {existing_answer}\n"
+     "We have the opportunity to refine the existing answer "
+     "(only if needed) with some more context below.\n"
+     "------------\n"
+     "{context_str}\n"
+     "------------\n"
+     "Given the new context, refine the original answer to better "
+     "answer the question. "
+     "If you do update it, please update the sources as well. "
+     "If the context isn't useful, return the original answer."
+ )
+ DEFAULT_REFINE_PROMPT = PromptTemplate(
+     input_variables=["question", "existing_answer", "context_str"],
+     template=DEFAULT_REFINE_PROMPT_TMPL,
+ )
+
+
+ DEFAULT_TEXT_QA_PROMPT_TMPL = (
+     """用提供给你的文档去回答问题,不需要编造或者虚构答案,也不需要回答文档之外的内容。
+ 如果在文档中没有找到相关的答案,那么就直接回答'知识库中没有相关问题解答'
+ 请用中文回答。
+ 下边是我给你提供的文档,其中文档格式都是一问一答。问题答案也完全来自所提供的回答:
+ ---------------------
+ {context_str}
+ ---------------------
+
+ 问题: {question}
+ 答:"""
+ )
+ DEFAULT_TEXT_QA_PROMPT = PromptTemplate(
+     input_variables=["context_str", "question"], template=DEFAULT_TEXT_QA_PROMPT_TMPL
+ )
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ pymilvus
+ langchain
+ openai
+ tiktoken
+ gradio
+ bs4
+ uuid
+ zhipuai
+ transformers
+ FlagEmbedding