from typing import Callable, Optional import gradio as gr from langchain.embeddings.openai import OpenAIEmbeddings from langchain.vectorstores import Zilliz from langchain.document_loaders import TextLoader from langchain.text_splitter import CharacterTextSplitter from langchain.chains import RetrievalQAWithSourcesChain import uuid from project.llm.zhipuai_llm import ZhipuAILLM from project.prompt.answer_by_private_prompt import ( COMBINE_PROMPT, EXAMPLE_PROMPT, QUESTION_PROMPT, DEFAULT_TEXT_QA_PROMPT, DEFAULT_REFINE_PROMPT ) from langchain.chains.combine_documents.refine import RefineDocumentsChain from langchain.chains.llm import LLMChain from langchain.chains.combine_documents import create_stuff_documents_chain from langchain.chains import StuffDocumentsChain from langchain_core.prompts import PromptTemplate chain: Optional[Callable] = None def web_loader(file, openai_key, puzhiai_key, zilliz_uri, user, password): if not file: return "please upload file" loader = TextLoader(file) docs = loader.load() text_splitter = CharacterTextSplitter(chunk_size=1024, chunk_overlap=0) docs = text_splitter.split_documents(docs) embeddings = OpenAIEmbeddings(model="text-embedding-ada-002", openai_api_key=openai_key) if not embeddings: return "embeddings not" texts = [d.page_content for d in docs] docsearch = Zilliz.from_documents( docs, embedding=embeddings, ids=[str(uuid.uuid4()) for _ in range(len(texts))], connection_args={ "uri": zilliz_uri, "user": user, "password": password, "secure": True, }, ) if not docsearch: return "docsearch not" global chain #chain = RetrievalQAWithSourcesChain.from_chain_type( # ZhipuAILLM(model="glm-3-turbo", temperature=0.1, zhipuai_api_key=puzhiai_key), # chain_type="refine", # retriever=docsearch.as_retriever(), #) #chain = RetrievalQAWithSourcesChain.from_llm( # ZhipuAILLM(model="glm-3-turbo", temperature=0.1, zhipuai_api_key=puzhiai_key), # EXAMPLE_PROMPT, # QUESTION_PROMPT, # COMBINE_PROMPT, # retriever=docsearch.as_retriever(), #) llm = ZhipuAILLM(model="glm-3-turbo", temperature=0.1, zhipuai_api_key=puzhiai_key) #initial_chain = LLMChain(llm=llm, prompt=DEFAULT_TEXT_QA_PROMPT) #refine_chain = LLMChain(llm=llm, prompt=DEFAULT_REFINE_PROMPT) #combine_documents_chain = RefineDocumentsChain( # initial_llm_chain=initial_chain, # refine_llm_chain=refine_chain, # document_variable_name="context_str", # initial_response_name="existing_answer", # document_prompt=EXAMPLE_PROMPT, #) document_prompt = PromptTemplate( input_variables=["page_content"], template="{page_content}" ) document_variable_name = "context" # The prompt here should take as an input variable the # `document_variable_name` prompt = PromptTemplate.from_template( """用提供给你的文档去回答问题,不需要编造或者虚构答案,也不需要回答文档之外的内容。 如果在文档中没有找到相关的答案,那么就直接回答'知识库中没有相关问题解答' 请用中文回答。 下边是我给你提供的文档,其中文档格式都是一问一答。问题答案也完全来自所提供的回答: --------------------- {context} --------------------- 问题: {question} 答:""" ) llm_chain = LLMChain(llm=llm, prompt=prompt) combine_documents_chain = StuffDocumentsChain( llm_chain=llm_chain, document_prompt=document_prompt, document_variable_name=document_variable_name ) chain = RetrievalQAWithSourcesChain(combine_documents_chain=combine_documents_chain, retriever=docsearch.as_retriever()) return "success to load data" def query(question): global chain # "What is milvus?" if not chain: return "please load the data first" return chain(inputs={"question": question}, return_only_outputs=True).get( "answer", "fail to get answer" ) if __name__ == "__main__": block = gr.Blocks() with block as demo: gr.Markdown( """