from typing import Callable, Optional
import gradio as gr
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Zilliz
from langchain.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains import RetrievalQAWithSourcesChain
import uuid
from project.llm.zhipuai_llm import ZhipuAILLM
from project.prompt.answer_by_private_prompt import (
COMBINE_PROMPT,
EXAMPLE_PROMPT,
QUESTION_PROMPT,
DEFAULT_TEXT_QA_PROMPT,
DEFAULT_REFINE_PROMPT
)
from langchain.chains.combine_documents.refine import RefineDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.chains import StuffDocumentsChain
from langchain_core.prompts import PromptTemplate
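
# Module-level QA chain: built once by web_loader() and reused by query().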
chain: Optional[Callable] = None
def web_loader(file, openai_key, puzhiai_key, zilliz_uri, user, password):
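    """Load the uploaded file, split it, embed it into Zilliz Cloud,
    and store the resulting RetrievalQAWithSourcesChain in the global `chain`."""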
if not file:
return "please upload file"
loader = TextLoader(file)
docs = loader.load()
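    # Split the loaded document into ~1024-character chunks with no overlap.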
text_splitter = CharacterTextSplitter(chunk_size=1024, chunk_overlap=0)
docs = text_splitter.split_documents(docs)
    # Embed the chunks with OpenAI's text-embedding-ada-002 model.
    embeddings = OpenAIEmbeddings(model="text-embedding-ada-002", openai_api_key=openai_key)
    if not embeddings:
        return "failed to create embeddings"
texts = [d.page_content for d in docs]
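    # Build the Zilliz Cloud vector store from the chunks, one random UUID per chunk.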
docsearch = Zilliz.from_documents(
docs,
embedding=embeddings,
ids=[str(uuid.uuid4()) for _ in range(len(texts))],
connection_args={
"uri": zilliz_uri,
"user": user,
"password": password,
"secure": True,
},
)
    if not docsearch:
        return "failed to build the vector store"
global chain
#chain = RetrievalQAWithSourcesChain.from_chain_type(
# ZhipuAILLM(model="glm-3-turbo", temperature=0.1, zhipuai_api_key=puzhiai_key),
# chain_type="refine",
# retriever=docsearch.as_retriever(),
#)
#chain = RetrievalQAWithSourcesChain.from_llm(
# ZhipuAILLM(model="glm-3-turbo", temperature=0.1, zhipuai_api_key=puzhiai_key),
# EXAMPLE_PROMPT,
# QUESTION_PROMPT,
# COMBINE_PROMPT,
# retriever=docsearch.as_retriever(),
#)
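    # Use ZhipuAI glm-3-turbo with a low temperature as the answering LLM.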
llm = ZhipuAILLM(model="glm-3-turbo", temperature=0.1, zhipuai_api_key=puzhiai_key)
#initial_chain = LLMChain(llm=llm, prompt=DEFAULT_TEXT_QA_PROMPT)
#refine_chain = LLMChain(llm=llm, prompt=DEFAULT_REFINE_PROMPT)
#combine_documents_chain = RefineDocumentsChain(
# initial_llm_chain=initial_chain,
# refine_llm_chain=refine_chain,
# document_variable_name="context_str",
# initial_response_name="existing_answer",
# document_prompt=EXAMPLE_PROMPT,
#)
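    # Render each retrieved document into the prompt as its raw page_content.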
document_prompt = PromptTemplate(
input_variables=["page_content"],
template="{page_content}"
)
document_variable_name = "context"
# The prompt here should take as an input variable the
# `document_variable_name`
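    # QA prompt (Chinese): answer only from the supplied Q&A-style documents, reply in
    # Chinese, and say the knowledge base has no relevant answer if nothing is found.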
prompt = PromptTemplate.from_template(
"""用提供给你的文档去回答问题,不需要编造或者虚构答案,也不需要回答文档之外的内容。
如果在文档中没有找到相关的答案,那么就直接回答'知识库中没有相关问题解答'
请用中文回答。
下边是我给你提供的文档,其中文档格式都是一问一答。问题答案也完全来自所提供的回答:
---------------------
{context}
---------------------
问题: {question}
答:"""
)
llm_chain = LLMChain(llm=llm, prompt=prompt)
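    # "Stuff" strategy: put all retrieved chunks into one prompt and answer in a single LLM call.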
combine_documents_chain = StuffDocumentsChain(
llm_chain=llm_chain,
document_prompt=document_prompt,
document_variable_name=document_variable_name
)
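    # Retrieval QA with sources: retrieve from Zilliz, then answer with the stuff chain above.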
chain = RetrievalQAWithSourcesChain(combine_documents_chain=combine_documents_chain,
retriever=docsearch.as_retriever())
return "success to load data"
def query(question):
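    """Answer a question with the QA chain built by web_loader()."""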
global chain
# "What is milvus?"
if not chain:
return "please load the data first"
return chain(inputs={"question": question}, return_only_outputs=True).get(
"answer", "fail to get answer"
)
if __name__ == "__main__":
block = gr.Blocks()
with block as demo:
gr.Markdown(
"""
<h1><center>Langchain And Zilliz App</center></h1>
v.2.27.15.27
"""
)
# url_list_text = gr.Textbox(
# label="url list",
# lines=3,
# placeholder="https://milvus.io/docs/overview.md",
# )
file = gr.File(label='请上传知识库文件',
file_types=['.txt', '.md', '.docx', '.pdf'])
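        # Note: web_loader uses TextLoader, which reads plain text; .docx/.pdf files would need a dedicated loader.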
openai_key_text = gr.Textbox(label="openai api key", type="password", placeholder="sk-******")
        puzhiai_key_text = gr.Textbox(label="zhipuai api key", type="password", placeholder="******")
with gr.Row():
zilliz_uri_text = gr.Textbox(
label="zilliz cloud uri",
placeholder="https://<instance-id>.<cloud-region-id>.vectordb.zillizcloud.com:<port>",
)
user_text = gr.Textbox(label="username", placeholder="db_admin")
password_text = gr.Textbox(
label="password", type="password", placeholder="******"
)
loader_output = gr.Textbox(label="load status")
loader_btn = gr.Button("Load Data")
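        # Wire the "Load Data" button to web_loader; the status textbox shows the result.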
loader_btn.click(
fn=web_loader,
inputs=[
file,
openai_key_text,
puzhiai_key_text,
zilliz_uri_text,
user_text,
password_text,
],
outputs=loader_output,
api_name="web_load",
)
question_text = gr.Textbox(
label="question",
lines=3,
placeholder="What is milvus?",
)
query_output = gr.Textbox(label="question answer", lines=3)
query_btn = gr.Button("Generate")
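        # Wire the "Generate" button to query().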
query_btn.click(
fn=query,
inputs=[question_text],
outputs=query_output,
api_name="generate_answer",
)
demo.queue().launch(server_name="0.0.0.0", share=False)