import hashlib
import os
from typing import Optional

import gradio as gr
from langchain.chains import RetrievalQAWithSourcesChain, StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Zilliz
from langchain_core.prompts import PromptTemplate

from project.embeddings.local_embed import LocalEmbed
from project.llm.check_embed_llm import CheckEmbedLlm

# Populated by web_loader() once the knowledge base has been indexed.
chain: Optional[RetrievalQAWithSourcesChain] = None

# Database connection details and the API key are read from the environment.
db_host = os.getenv("DB_HOST")
db_user = os.getenv("DB_USER")
db_password = os.getenv("DB_PASSWORD")
zhipuai_api_key = os.getenv("ZHIPU_AI_KEY")


def generate_article_id(content):
    # Use the SHA-256 hash algorithm.
    sha256 = hashlib.sha256()
    # Encode the article content as UTF-8 bytes and update the hash object.
    sha256.update(content.encode("utf-8"))
    # Use the hexadecimal digest as the article ID.
    article_id = sha256.hexdigest()
    return article_id


def web_loader(file):
    if not file:
        return "please upload a file"
    # NOTE: TextLoader only reads plain text; .docx and .pdf files would need
    # dedicated loaders.
    loader = TextLoader(file)
    docs = loader.load()
    text_splitter = CharacterTextSplitter(chunk_size=512, chunk_overlap=0)
    docs = text_splitter.split_documents(docs)
    # embeddings = OpenAIEmbeddings(model="text-embedding-ada-002", openai_api_key=openai_key)
    # embeddings = ZhipuAIEmbeddings(zhipuai_api_key=zhipuai_api_key)
    embeddings = LocalEmbed(zhipuai_api_key=zhipuai_api_key)
    if not embeddings:
        return "failed to create embeddings"
    texts = [d.page_content for d in docs]
    article_ids = []
    # Generate a content-based ID for each chunk with generate_article_id and
    # collect the IDs, so re-loading identical text does not duplicate entries.
    for text in texts:
        article_id = generate_article_id(text)
        article_ids.append(article_id)
    docsearch = Zilliz.from_documents(
        docs,
        embedding=embeddings,
        ids=article_ids,
        connection_args={
            "uri": db_host,
            "user": db_user,
            "password": db_password,
            "secure": True,
        },
        collection_name="CheckEmbedLocalEmbed",
    )
    if not docsearch:
        return "failed to build the vector store"
    llm = CheckEmbedLlm(model="glm-3-turbo", temperature=0.1, zhipuai_api_key=zhipuai_api_key)
    document_prompt = PromptTemplate(
        input_variables=["page_content"], template="{page_content}"
    )
    document_variable_name = "context"
    # The prompt here should take as an input variable the
    # `document_variable_name`.
    prompt = PromptTemplate.from_template(
        """The retrieved documents are as follows:
{context}

Question: {question}
Answer:"""
    )
    llm_chain = LLMChain(llm=llm, prompt=prompt)
    combine_documents_chain = StuffDocumentsChain(
        llm_chain=llm_chain,
        document_prompt=document_prompt,
        document_variable_name=document_variable_name,
    )
    global chain
    chain = RetrievalQAWithSourcesChain(
        combine_documents_chain=combine_documents_chain,
        retriever=docsearch.as_retriever(search_kwargs={"k": 3}),
    )
    return "successfully loaded data"


def query(question):
    global chain
    # e.g. "What is milvus?"
    if not chain:
        return "please load the data first"
    return chain(inputs={"question": question}, return_only_outputs=True).get(
        "answer", "failed to get an answer"
    )


if __name__ == "__main__":
    block = gr.Blocks()
    with block as demo:
        gr.Markdown(
            """

Langchain And Embed App

v.2.29.17.30 """ ) # url_list_text = gr.Textbox( # label="url list", # lines=3, # placeholder="https://milvus.io/docs/overview.md", # ) file = gr.File(label='请上传知识库文件\n可以处理 .txt, .md, .docx, .pdf 结尾的文件', file_types=['.txt', '.md', '.docx', '.pdf']) #openai_key_text = gr.Textbox(label="openai api key", type="password", placeholder="sk-******") #puzhiai_key_text = gr.Textbox(label="puzhi api key", type="password", placeholder="******") loader_output = gr.Textbox(label="load status") loader_btn = gr.Button("Load Data") loader_btn.click( fn=web_loader, inputs=[ file, ], outputs=loader_output, api_name="web_load", ) question_text = gr.Textbox( label="question", lines=3, placeholder="What is milvus?", ) query_output = gr.Textbox(label="question answer", lines=3) query_btn = gr.Button("Generate") query_btn.click( fn=query, inputs=[question_text], outputs=query_output, api_name="generate_answer", ) demo.queue().launch(server_name="0.0.0.0", share=False)