guangliang.yin committed • Commit a7b5657 • Parent(s): 5756ae7
Commit message: Initialize project (初始化项目)
Files changed:
- README.md +5 -5
- app.py +163 -0
- project/embeddings/__init__.py +0 -0
- project/embeddings/local_embed.py +95 -0
- project/embeddings/zhipuai_embedding.py +113 -0
- project/llm/__init__.py +0 -0
- project/llm/check_embed_llm.py +238 -0
- project/llm/self_llm.py +47 -0
- project/prompt/__init__.py +0 -0
- project/prompt/answer_by_private_prompt.py +63 -0
- requirements.txt +10 -0
README.md CHANGED
@@ -1,10 +1,10 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: Hellow LangChain
+emoji: 🐨
+colorFrom: purple
+colorTo: green
 sdk: gradio
-sdk_version: 4.19.
+sdk_version: 4.19.1
 app_file: app.py
 pinned: false
 license: mit
app.py ADDED
@@ -0,0 +1,163 @@
from typing import Callable, Optional

import gradio as gr
from langchain.vectorstores import Zilliz
from langchain.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chains.llm import LLMChain
from langchain.chains import StuffDocumentsChain
from langchain_core.prompts import PromptTemplate
import hashlib
import os
from project.embeddings.local_embed import LocalEmbed
from project.llm.check_embed_llm import CheckEmbedLlm

chain: Optional[Callable] = None

db_host = os.getenv("DB_HOST")
db_user = os.getenv("DB_USER")
db_password = os.getenv("DB_PASSWORD")
zhipuai_api_key = os.getenv("ZHIPU_AI_KEY")


def generate_article_id(content):
    # Use the SHA-256 hash algorithm
    sha256 = hashlib.sha256()

    # Encode the article content as bytes and update the hash object
    sha256.update(content.encode('utf-8'))

    # Get the hexadecimal representation of the hash
    article_id = sha256.hexdigest()

    return article_id


def web_loader(file):
    if not file:
        return "please upload file"
    loader = TextLoader(file)
    docs = loader.load()

    text_splitter = CharacterTextSplitter(chunk_size=512, chunk_overlap=0)
    docs = text_splitter.split_documents(docs)
    # embeddings = OpenAIEmbeddings(model="text-embedding-ada-002", openai_api_key=openai_key)
    # embeddings = ZhipuAIEmbeddings(zhipuai_api_key=zhipuai_api_key)
    embeddings = LocalEmbed(zhipuai_api_key=zhipuai_api_key)

    if not embeddings:
        return "embeddings not"

    texts = [d.page_content for d in docs]
    article_ids = []
    # Iterate over the texts list
    for text in texts:
        # Generate an article ID with generate_article_id and append it to article_ids
        article_id = generate_article_id(text)
        article_ids.append(article_id)

    docsearch = Zilliz.from_documents(
        docs,
        embedding=embeddings,
        ids=article_ids,
        connection_args={
            "uri": db_host,
            "user": db_user,
            "password": db_password,
            "secure": True,
        },
        collection_name="LangChainCollectionYin"
    )

    if not docsearch:
        return "docsearch not"

    global chain

    llm = CheckEmbedLlm(model="glm-3-turbo", temperature=0.1, zhipuai_api_key=zhipuai_api_key)

    document_prompt = PromptTemplate(
        input_variables=["page_content"],
        template="{page_content}"
    )
    document_variable_name = "context"
    # The prompt here should take as an input variable the
    # `document_variable_name`
    prompt = PromptTemplate.from_template(
        """查询到的文档如下:
{context}

问题: {question}
答:"""
    )
    llm_chain = LLMChain(llm=llm, prompt=prompt)
    combine_documents_chain = StuffDocumentsChain(
        llm_chain=llm_chain,
        document_prompt=document_prompt,
        document_variable_name=document_variable_name
    )

    chain = RetrievalQAWithSourcesChain(combine_documents_chain=combine_documents_chain,
                                        retriever=docsearch.as_retriever(search_kwargs={'k': 3}))
    return "success to load data"


def query(question):
    global chain
    # e.g. "What is milvus?"
    if not chain:
        return "please load the data first"
    return chain(inputs={"question": question}, return_only_outputs=True).get(
        "answer", "fail to get answer"
    )


if __name__ == "__main__":
    block = gr.Blocks()
    with block as demo:
        gr.Markdown(
            """
<h1><center>Langchain And Embed App</center></h1>

v.2.29.14.55

"""
        )
        # url_list_text = gr.Textbox(
        #     label="url list",
        #     lines=3,
        #     placeholder="https://milvus.io/docs/overview.md",
        # )
        file = gr.File(label='请上传知识库文件\n可以处理 .txt, .md, .docx, .pdf 结尾的文件',
                       file_types=['.txt', '.md', '.docx', '.pdf'])
        # openai_key_text = gr.Textbox(label="openai api key", type="password", placeholder="sk-******")
        # puzhiai_key_text = gr.Textbox(label="puzhi api key", type="password", placeholder="******")

        loader_output = gr.Textbox(label="load status")
        loader_btn = gr.Button("Load Data")
        loader_btn.click(
            fn=web_loader,
            inputs=[
                file,
            ],
            outputs=loader_output,
            api_name="web_load",
        )

        question_text = gr.Textbox(
            label="question",
            lines=3,
            placeholder="What is milvus?",
        )
        query_output = gr.Textbox(label="question answer", lines=3)
        query_btn = gr.Button("Generate")
        query_btn.click(
            fn=query,
            inputs=[question_text],
            outputs=query_output,
            api_name="generate_answer",
        )

    demo.queue().launch(server_name="0.0.0.0", share=False)
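Since both click handlers register `api_name` endpoints (`web_load` and `generate_answer`), the Space can also be driven programmatically. A minimal sketch (not part of the commit), assuming the app is reachable at a placeholder URL:

from gradio_client import Client

# Placeholder host for wherever the Space is running.
client = Client("http://localhost:7860")
answer = client.predict("What is milvus?", api_name="/generate_answer")
print(answer)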
project/embeddings/__init__.py ADDED
(empty file)
project/embeddings/local_embed.py ADDED
@@ -0,0 +1,95 @@
from __future__ import annotations

import logging
from typing import Any, Dict, List, Optional

from langchain.embeddings.base import Embeddings
from langchain.pydantic_v1 import BaseModel, root_validator
from langchain.utils import get_from_dict_or_env
from FlagEmbedding import LLMEmbedder

logger = logging.getLogger(__name__)


class LocalEmbed(BaseModel, Embeddings):
    """Local embedding model (BAAI/bge-large-zh-v1.5) behind the LangChain `Embeddings` interface."""

    zhipuai_api_key: Optional[str] = None
    """Zhipuai application apikey (validated, but unused on the local encoding path)"""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """
        Validate whether zhipuai_api_key is available in the environment
        variables or the configuration file, then build the local embedder.

        Args:
            values: a dictionary containing configuration information; must
                include the zhipuai_api_key field
        Returns:
            a dictionary containing configuration information. If
            zhipuai_api_key is not provided in the environment variables or
            configuration file, the original values are returned; otherwise,
            values containing zhipuai_api_key are returned.
        Raises:
            ValueError: zhipuai package not found, please install it with
                `pip install zhipuai`
        """
        values["zhipuai_api_key"] = get_from_dict_or_env(
            values,
            "zhipuai_api_key",
            "ZHIPUAI_API_KEY",
        )

        values["client"] = LLMEmbedder('BAAI/bge-large-zh-v1.5',
                                       query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
                                       use_fp16=True)
        return values

    def _embed(self, texts: str) -> List[float]:
        print("cal embed:", texts)

        embeddings = self.client.encode(texts)

        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """
        Embed a single text.

        Args:
            text (str): A text to be embedded.

        Return:
            List[float]: An embedding of the input text, as a list of
            floating-point values.
        """
        resp = self.embed_documents([text])
        return resp[0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Embeds a list of text documents.

        Args:
            texts (List[str]): A list of text documents to embed.

        Returns:
            List[List[float]]: A list of embeddings, one for each document in
            the input list. Each embedding is represented as a list of floats.
        """
        return [self._embed(text) for text in texts]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously embed search docs."""
        raise NotImplementedError(
            "Please use `embed_documents`. This backend does not support asynchronous requests")

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronously embed query text."""
        raise NotImplementedError(
            "Please use `embed_query`. This backend does not support asynchronous requests")
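`LocalEmbed` never calls ZhipuAI for the vectors themselves: the key is only validated, and encoding happens locally through FlagEmbedding. A minimal sketch (not part of the commit), assuming the repo root is on the import path and the BGE checkpoint downloads on first use:

from project.embeddings.local_embed import LocalEmbed

emb = LocalEmbed(zhipuai_api_key="placeholder")  # key is validated but unused for local encoding
vec = emb.embed_query("什么是向量数据库?")
print(len(vec))  # dimensionality of the bge-large-zh-v1.5 output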
project/embeddings/zhipuai_embedding.py ADDED
@@ -0,0 +1,113 @@
from __future__ import annotations

import logging
from typing import Any, Dict, List, Optional

from langchain.embeddings.base import Embeddings
from langchain.pydantic_v1 import BaseModel, root_validator
from langchain.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)


class ZhipuAIEmbeddings(BaseModel, Embeddings):
    """`Zhipuai Embeddings` embedding models."""

    zhipuai_api_key: Optional[str] = None
    """Zhipuai application apikey"""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """
        Validate whether zhipuai_api_key is available in the environment
        variables or the configuration file.

        Args:
            values: a dictionary containing configuration information; must
                include the zhipuai_api_key field
        Returns:
            a dictionary containing configuration information. If
            zhipuai_api_key is not provided in the environment variables or
            configuration file, the original values are returned; otherwise,
            values containing zhipuai_api_key are returned.
        Raises:
            ValueError: zhipuai package not found, please install it with
                `pip install zhipuai`
        """
        values["zhipuai_api_key"] = get_from_dict_or_env(
            values,
            "zhipuai_api_key",
            "ZHIPUAI_API_KEY",
        )

        try:
            from zhipuai import ZhipuAI
            values["client"] = ZhipuAI(api_key=values["zhipuai_api_key"])

        except ImportError:
            raise ValueError(
                "Zhipuai package not found, please install it with "
                "`pip install zhipuai`"
            )
        return values

    def _embed(self, texts: str) -> List[float]:
        # send request
        try:
            print("cal embed:", texts)
            resp = self.client.embeddings.create(
                model="embedding-2",
                input=texts
            )
            # print("resp:", resp)
        except Exception as e:
            raise ValueError(f"Error raised by inference endpoint: {e}")

        if not resp.data:
            raise ValueError(
                "Error raised by inference API HTTP code: %s, %s"
                % (resp["code"], resp["msg"])
            )
        embeddings = resp.data[0].embedding
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """
        Embed a single text.

        Args:
            text (str): A text to be embedded.

        Return:
            List[float]: An embedding of the input text, as a list of
            floating-point values.
        """
        resp = self.embed_documents([text])
        return resp[0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Embeds a list of text documents.

        Args:
            texts (List[str]): A list of text documents to embed.

        Returns:
            List[List[float]]: A list of embeddings, one for each document in
            the input list. Each embedding is represented as a list of floats.
        """
        return [self._embed(text) for text in texts]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously embed search docs."""
        raise NotImplementedError(
            "Please use `embed_documents`. The official SDK does not support asynchronous requests")

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronously embed query text."""
        raise NotImplementedError(
            "Please use `embed_query`. The official SDK does not support asynchronous requests")
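`_embed` sends a single string per request to the `embedding-2` model and reads back `resp.data[0].embedding`, so `embed_documents` issues one API call per chunk. A minimal sketch (not part of the commit), assuming `ZHIPUAI_API_KEY` is set:

from project.embeddings.zhipuai_embedding import ZhipuAIEmbeddings

emb = ZhipuAIEmbeddings()
vectors = emb.embed_documents(["什么是向量数据库?", "Milvus 是什么?"])  # one request per text
print(len(vectors), len(vectors[0]))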
project/llm/__init__.py ADDED
(empty file)
project/llm/check_embed_llm.py ADDED
@@ -0,0 +1,238 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@File    :   check_embed_llm.py
@Time    :   2023/10/16 22:06:26
@Author  :   0-yy-0
@Version :   1.0
@Contact :   310484121@qq.com
@License :   (C)Copyright 2017-2018, Liugroup-NLPR-CASIA
@Desc    :   Custom LLM class built on the Zhipu AI large model (debug
             variant that returns the request payload instead of calling the API)
'''

from __future__ import annotations

import logging
from typing import (
    Any,
    AsyncIterator,
    Dict,
    Iterator,
    List,
    Optional,
)

from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
from langchain.pydantic_v1 import Field, root_validator
from langchain.schema.output import GenerationChunk
from langchain.utils import get_from_dict_or_env
from project.llm.self_llm import Self_LLM
import re

logger = logging.getLogger(__name__)


class CheckEmbedLlm(Self_LLM):
    """Zhipuai hosted open source or customized models.

    Debug variant: `_call` builds the chat request payload and returns it
    without calling the API, so the retrieved context stuffed into the
    prompt can be inspected.

    To use, you should have the ``zhipuai`` python package installed, and
    the environment variable ``zhipuai_api_key`` set with
    your API key and Secret Key.

    zhipuai_api_key is a required parameter which you can get from
    https://open.bigmodel.cn/usercenter/apikeys

    Example:
        .. code-block:: python

        from langchain.llms import ZhipuAILLM
        zhipuai_model = ZhipuAILLM(model="chatglm_std", temperature=temperature)

    """

    model_kwargs: Dict[str, Any] = Field(default_factory=dict)

    client: Any

    model: str = "chatglm_std"
    """Model name in chatglm_pro, chatglm_std, chatglm_lite."""

    zhipuai_api_key: Optional[str] = None

    incremental: Optional[bool] = True
    """Whether to return incremental results or not."""

    streaming: Optional[bool] = False
    """Whether to stream the results or not."""
    # streaming = -incremental

    request_timeout: Optional[int] = 60
    """request timeout for chat http requests"""

    top_p: Optional[float] = 0.8
    temperature: Optional[float] = 0.95
    request_id: Optional[float] = None

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:

        values["zhipuai_api_key"] = get_from_dict_or_env(
            values,
            "zhipuai_api_key",
            "ZHIPUAI_API_KEY",
        )

        params = {
            "zhipuai_api_key": values["zhipuai_api_key"],
            "model": values["model"],
        }
        try:
            from zhipuai import ZhipuAI

            conf_api_key = values["zhipuai_api_key"]
            client = ZhipuAI(api_key=conf_api_key)
            values["client"] = client
        except ImportError:
            raise ValueError(
                "zhipuai package not found, please install it with "
                "`pip install zhipuai`"
            )
        return values

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        return {
            **{"model": self.model},
            **super()._identifying_params,
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "zhipuai"

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the API."""
        normal_params = {
            "streaming": self.streaming,
            "top_p": self.top_p,
            "temperature": self.temperature,
            "request_id": self.request_id,
        }

        return {**normal_params, **self.model_kwargs}

    def _convert_prompt_msg_params(
        self,
        prompt: str,
        **kwargs: Any,
    ) -> dict:
        return {
            **{"prompt": prompt, "model": self.model},
            **self._default_params,
            **kwargs,
        }

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Build (and return) the chat request payload for a prompt.

        Args:
            prompt: The prompt to pass into the model.
        Returns:
            The chat request payload that would be sent to the model.
        """
        if self.streaming:
            completion = ""
            for chunk in self._stream(prompt, stop, run_manager, **kwargs):
                completion += chunk.text
            return completion
        params = self._convert_prompt_msg_params(prompt, **kwargs)

        all_word = params['prompt']

        # Split the stuffed prompt at the last occurrence of "问题": everything
        # before it becomes the system message, the rest the user message.
        keyword = "问题"
        matches = re.finditer(keyword, all_word)
        indexes = [match.start() for match in matches]
        last_index = indexes[-1]

        params = {"messages": [
            {"role": "system", "content": all_word[0:last_index]},
            {"role": "user", "content": all_word[last_index:len(all_word)]}],
            "model": self.model, "stream": False, "top_p": 0.8, "temperature": 0.01, "request_id": None}

        print("params:", params)
        # Debug behaviour: return the payload itself instead of calling the API
        # (note the dict return despite the `str` annotation).
        response_payload = params

        return response_payload

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        if self.streaming:
            completion = ""
            async for chunk in self._astream(prompt, stop, run_manager, **kwargs):
                completion += chunk.text
            return completion

        params = self._convert_prompt_msg_params(prompt, **kwargs)

        # Note: async_invoke targets the legacy zhipuai SDK interface.
        response = await self.client.async_invoke(**params)

        return response

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        params = self._convert_prompt_msg_params(prompt, **kwargs)

        # Note: invoke targets the legacy zhipuai SDK interface.
        for res in self.client.invoke(**params):
            if res:
                chunk = GenerationChunk(text=res)
                yield chunk
                if run_manager:
                    run_manager.on_llm_new_token(chunk.text)

    async def _astream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[GenerationChunk]:
        params = self._convert_prompt_msg_params(prompt, **kwargs)

        # Note: ado targets the legacy zhipuai SDK interface.
        async for res in await self.client.ado(**params):
            if res:
                chunk = GenerationChunk(text=res["data"]["choices"]["content"])

                yield chunk
                if run_manager:
                    await run_manager.on_llm_new_token(chunk.text)
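The heart of `_call` is the split on the last occurrence of "问题": everything before it (instructions plus retrieved context) becomes the system message, and the remainder (the question) becomes the user message. A standalone restatement of that logic, for illustration only (the prompt content is made up):

import re

# Shape of the stuffed prompt built in app.py.
all_word = "查询到的文档如下:\n...retrieved context...\n\n问题: What is milvus?\n答:"
indexes = [m.start() for m in re.finditer("问题", all_word)]
last_index = indexes[-1]  # raises IndexError if the prompt contains no "问题"
system_part, user_part = all_word[:last_index], all_word[last_index:]
print(repr(system_part))
print(repr(user_part))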
project/llm/self_llm.py ADDED
@@ -0,0 +1,47 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@File    :   self_llm.py
@Time    :   2023/10/16 18:48:08
@Author  :   Logan Zou
@Version :   1.0
@Contact :   loganzou0421@163.com
@License :   (C)Copyright 2017-2018, Liugroup-NLPR-CASIA
@Desc    :   Project class wrapped on top of the LangChain LLM base class,
             unifying calls to the GPT, Wenxin, Xunfei and Zhipu APIs
'''

from langchain.llms.base import LLM
from typing import Dict, Any, Mapping
from pydantic import Field

class Self_LLM(LLM):
    # Custom LLM, inheriting from langchain.llms.base.LLM
    # Native API endpoint URL
    url: str = None
    # Defaults to the GPT-3.5 model
    model_name: str = "gpt-3.5-turbo"
    # Upper bound on request latency
    request_timeout: float = None
    # Temperature coefficient
    temperature: float = 0.1
    # API key
    api_key: str = None
    # Extra optional model parameters
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)

    # Property returning the default call parameters
    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for a call."""
        normal_params = {
            "temperature": self.temperature,
            "request_timeout": self.request_timeout,
        }
        # print(type(self.model_kwargs))
        return {**normal_params}

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {**{"model_name": self.model_name}, **self._default_params}
project/llm/zhipuai_llm.py ADDED
@@ -0,0 +1,239 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@File    :   zhipuai_llm.py
@Time    :   2023/10/16 22:06:26
@Author  :   0-yy-0
@Version :   1.0
@Contact :   310484121@qq.com
@License :   (C)Copyright 2017-2018, Liugroup-NLPR-CASIA
@Desc    :   Custom LLM class built on the Zhipu AI large model
'''

from __future__ import annotations

import logging
from typing import (
    Any,
    AsyncIterator,
    Dict,
    Iterator,
    List,
    Optional,
)

from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
from langchain.pydantic_v1 import Field, root_validator
from langchain.schema.output import GenerationChunk
from langchain.utils import get_from_dict_or_env
from project.llm.self_llm import Self_LLM
import re

logger = logging.getLogger(__name__)


class ZhipuAILLM(Self_LLM):
    """Zhipuai hosted open source or customized models.

    To use, you should have the ``zhipuai`` python package installed, and
    the environment variable ``zhipuai_api_key`` set with
    your API key and Secret Key.

    zhipuai_api_key is a required parameter which you can get from
    https://open.bigmodel.cn/usercenter/apikeys

    Example:
        .. code-block:: python

        from langchain.llms import ZhipuAILLM
        zhipuai_model = ZhipuAILLM(model="chatglm_std", temperature=temperature)

    """

    model_kwargs: Dict[str, Any] = Field(default_factory=dict)

    client: Any

    model: str = "chatglm_std"
    """Model name in chatglm_pro, chatglm_std, chatglm_lite."""

    zhipuai_api_key: Optional[str] = None

    incremental: Optional[bool] = True
    """Whether to return incremental results or not."""

    streaming: Optional[bool] = False
    """Whether to stream the results or not."""
    # streaming = -incremental

    request_timeout: Optional[int] = 60
    """request timeout for chat http requests"""

    top_p: Optional[float] = 0.8
    temperature: Optional[float] = 0.95
    request_id: Optional[float] = None

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:

        values["zhipuai_api_key"] = get_from_dict_or_env(
            values,
            "zhipuai_api_key",
            "ZHIPUAI_API_KEY",
        )

        params = {
            "zhipuai_api_key": values["zhipuai_api_key"],
            "model": values["model"],
        }
        try:
            from zhipuai import ZhipuAI

            conf_api_key = values["zhipuai_api_key"]
            client = ZhipuAI(api_key=conf_api_key)
            values["client"] = client
        except ImportError:
            raise ValueError(
                "zhipuai package not found, please install it with "
                "`pip install zhipuai`"
            )
        return values

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        return {
            **{"model": self.model},
            **super()._identifying_params,
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "zhipuai"

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the API."""
        normal_params = {
            "streaming": self.streaming,
            "top_p": self.top_p,
            "temperature": self.temperature,
            "request_id": self.request_id,
        }

        return {**normal_params, **self.model_kwargs}

    def _convert_prompt_msg_params(
        self,
        prompt: str,
        **kwargs: Any,
    ) -> dict:
        return {
            **{"prompt": prompt, "model": self.model},
            **self._default_params,
            **kwargs,
        }

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to a zhipuai model endpoint for each generation with a prompt.

        Args:
            prompt: The prompt to pass into the model.
        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = zhipuai_model("Tell me a joke.")
        """
        if self.streaming:
            completion = ""
            for chunk in self._stream(prompt, stop, run_manager, **kwargs):
                completion += chunk.text
            return completion
        params = self._convert_prompt_msg_params(prompt, **kwargs)

        all_word = params['prompt']

        # Split the stuffed prompt at the last occurrence of "问题": everything
        # before it becomes the system message, the rest the user message.
        keyword = "问题"
        matches = re.finditer(keyword, all_word)
        indexes = [match.start() for match in matches]
        last_index = indexes[-1]

        params = {"messages": [
            {"role": "system", "content": all_word[0:last_index]},
            {"role": "user", "content": all_word[last_index:len(all_word)]}],
            "model": self.model, "stream": False, "top_p": 0.8, "temperature": 0.01, "request_id": None}

        print("params:", params)
        response_payload = self.client.chat.completions.create(**params)
        print("response_payload", response_payload)

        return response_payload.choices[0].message.content

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        if self.streaming:
            completion = ""
            async for chunk in self._astream(prompt, stop, run_manager, **kwargs):
                completion += chunk.text
            return completion

        params = self._convert_prompt_msg_params(prompt, **kwargs)

        # Note: async_invoke targets the legacy zhipuai SDK interface.
        response = await self.client.async_invoke(**params)

        return response

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        params = self._convert_prompt_msg_params(prompt, **kwargs)

        # Note: invoke targets the legacy zhipuai SDK interface.
        for res in self.client.invoke(**params):
            if res:
                chunk = GenerationChunk(text=res)
                yield chunk
                if run_manager:
                    run_manager.on_llm_new_token(chunk.text)

    async def _astream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[GenerationChunk]:
        params = self._convert_prompt_msg_params(prompt, **kwargs)

        # Note: ado targets the legacy zhipuai SDK interface.
        async for res in await self.client.ado(**params):
            if res:
                chunk = GenerationChunk(text=res["data"]["choices"]["content"])

                yield chunk
                if run_manager:
                    await run_manager.on_llm_new_token(chunk.text)
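`ZhipuAILLM` performs the same split as `CheckEmbedLlm` but then actually calls `chat.completions.create`. One consequence: any prompt passed to it must contain "问题" at least once, or `indexes[-1]` raises an IndexError. A minimal sketch (not part of the commit), assuming `ZHIPUAI_API_KEY` is set and the repo root is on the import path; the model name mirrors the one used in app.py:

from project.llm.zhipuai_llm import ZhipuAILLM

llm = ZhipuAILLM(model="glm-3-turbo", temperature=0.1)
# The prompt must contain "问题"; _call splits on its last occurrence.
print(llm("查询到的文档如下:\nMilvus 是一个向量数据库。\n\n问题: What is milvus?\n答:"))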
project/prompt/__init__.py ADDED
(empty file)
project/prompt/answer_by_private_prompt.py ADDED
@@ -0,0 +1,63 @@
from langchain_core.prompts import PromptTemplate

question_prompt_template = """用提供给你的文档去回答问题,不需要编造或者虚构答案,也不需要回答文档之外的内容。
如果在文档中没有找到相关的答案,那么就直接回答'知识库中没有相关问题解答'
请用中文回答。
下边是我给你提供的文档,其中文档格式都是一问一答。问题答案也完全来自所提供的回答:
{context}

问题: {question}
答:"""
QUESTION_PROMPT = PromptTemplate(
    template=question_prompt_template, input_variables=["context", "question"]
)

combine_prompt_template = """
QUESTION: {question}
=========
{summaries}
=========
FINAL ANSWER:"""
COMBINE_PROMPT = PromptTemplate(
    template=combine_prompt_template, input_variables=["summaries", "question"]
)

EXAMPLE_PROMPT = PromptTemplate(
    template="Content: {page_content}\nSource: {source}",
    input_variables=["page_content", "source"],
)

DEFAULT_REFINE_PROMPT_TMPL = (
    "The original question is as follows: {question}\n"
    "We have provided an existing answer, including sources: {existing_answer}\n"
    "We have the opportunity to refine the existing answer "
    "(only if needed) with some more context below.\n"
    "------------\n"
    "{context_str}\n"
    "------------\n"
    "Given the new context, refine the original answer to better "
    "answer the question. "
    "If you do update it, please update the sources as well. "
    "If the context isn't useful, return the original answer."
)
DEFAULT_REFINE_PROMPT = PromptTemplate(
    input_variables=["question", "existing_answer", "context_str"],
    template=DEFAULT_REFINE_PROMPT_TMPL,
)


DEFAULT_TEXT_QA_PROMPT_TMPL = (
    """用提供给你的文档去回答问题,不需要编造或者虚构答案,也不需要回答文档之外的内容。
如果在文档中没有找到相关的答案,那么就直接回答'知识库中没有相关问题解答'
请用中文回答。
下边是我给你提供的文档,其中文档格式都是一问一答。问题答案也完全来自所提供的回答:
---------------------
{context_str}
---------------------

问题: {question}
答:"""
)
DEFAULT_TEXT_QA_PROMPT = PromptTemplate(
    input_variables=["context_str", "question"], template=DEFAULT_TEXT_QA_PROMPT_TMPL
)
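These are plain `PromptTemplate`s, so rendering one shows the exact string a model would receive. Illustration only (not part of the commit); the context and question are made up:

from project.prompt.answer_by_private_prompt import QUESTION_PROMPT

print(QUESTION_PROMPT.format(
    context="问:什么是 Milvus?\n答:Milvus 是一个开源向量数据库。",
    question="什么是 Milvus?",
))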
requirements.txt ADDED
@@ -0,0 +1,10 @@
pymilvus
langchain
openai
tiktoken
gradio
bs4
uuid
zhipuai
transformers
FlagEmbedding