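# Spaces: Sleeping -- appears to be a hosting-page status banner captured along with the source; it is not part of the program.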
from typing import Callable, Optional | |
import gradio as gr | |
from langchain.vectorstores import Zilliz | |
from langchain.document_loaders import TextLoader | |
from langchain.text_splitter import CharacterTextSplitter | |
from langchain.chains import RetrievalQAWithSourcesChain | |
from langchain.chains.llm import LLMChain | |
from langchain.chains import StuffDocumentsChain | |
from langchain_core.prompts import PromptTemplate | |
import hashlib | |
import os | |
from project.embeddings.local_embed import LocalEmbed | |
from project.llm.check_embed_llm import CheckEmbedLlm | |
# Retrieval-QA chain shared by the UI callbacks. It stays None until
# web_loader() has indexed a document and assigned a chain to it.
chain: Optional[Callable] = None

# Connection settings and the API key are read from the environment once,
# at import time; any variable that is not set yields None.
db_host = os.environ.get("DB_HOST")
db_user = os.environ.get("DB_USER")
db_password = os.environ.get("DB_PASSWORD")
zhipuai_api_key = os.environ.get("ZHIPU_AI_KEY")
def generate_article_id(content):
    """Return the SHA-256 hex digest of ``content`` encoded as UTF-8.

    The digest serves as a deterministic identifier: identical text always
    maps to the same 64-character hexadecimal string.
    """
    return hashlib.sha256(content.encode("utf-8")).hexdigest()
def _build_qa_chain(docsearch):
    """Wire a retrieval-QA-with-sources chain on top of ``docsearch``.

    Args:
        docsearch: The vector store whose retriever feeds the chain.

    Returns:
        A ``RetrievalQAWithSourcesChain`` that retrieves 3 documents per
        question and stuffs them into a single prompt for the LLM.
    """
    llm = CheckEmbedLlm(model="glm-3-turbo", temperature=0.1, zhipuai_api_key=zhipuai_api_key)

    # Each retrieved document is rendered as its raw text.
    document_prompt = PromptTemplate(
        input_variables=["page_content"],
        template="{page_content}",
    )

    # The QA prompt must take `document_variable_name` as an input variable;
    # the rendered documents are inserted under that name.
    document_variable_name = "context"
    prompt = PromptTemplate.from_template(
        "查询到的文档如下:\n"
        "{context}\n"
        "问题: {question}\n"
        "答:"
    )

    combine_documents_chain = StuffDocumentsChain(
        llm_chain=LLMChain(llm=llm, prompt=prompt),
        document_prompt=document_prompt,
        document_variable_name=document_variable_name,
    )
    return RetrievalQAWithSourcesChain(
        combine_documents_chain=combine_documents_chain,
        retriever=docsearch.as_retriever(search_kwargs={'k': 3}),
    )


def web_loader(file):
    """Index an uploaded text file and build the module-level QA chain.

    The file is read with ``TextLoader``, split with ``CharacterTextSplitter``
    (``chunk_size=512``, ``chunk_overlap=0``), embedded with ``LocalEmbed``
    and stored in the Zilliz collection ``CheckEmbedLocalEmbed``. Each chunk
    is stored under the SHA-256 hex digest of its text. On success the global
    ``chain`` is replaced with a chain backed by the new vector store.

    Args:
        file: The value delivered by the ``gr.File`` input: either a path
            string or an object exposing the path as ``.name``. A falsy value
            means nothing was uploaded.

    Returns:
        A short status message for display in the UI.
    """
    global chain

    if not file:
        return "please upload file"

    # Depending on the gradio version / ``type`` setting, gr.File passes the
    # callback either a path string or a temp-file wrapper whose ``.name`` is
    # the path. TextLoader needs the path itself, so unwrap when necessary.
    file_path = getattr(file, "name", file)

    raw_docs = TextLoader(file_path).load()
    splitter = CharacterTextSplitter(chunk_size=512, chunk_overlap=0)
    docs = splitter.split_documents(raw_docs)

    # Project-local embedding wrapper; OpenAI / ZhipuAI embeddings were used
    # at this spot in earlier experiments.
    embeddings = LocalEmbed(zhipuai_api_key=zhipuai_api_key)
    if not embeddings:
        return "embeddings not"

    # One content-derived id per chunk, in the same order as ``docs``.
    article_ids = [generate_article_id(doc.page_content) for doc in docs]

    docsearch = Zilliz.from_documents(
        docs,
        embedding=embeddings,
        ids=article_ids,
        connection_args={
            "uri": db_host,
            "user": db_user,
            "password": db_password,
            "secure": True,
        },
        collection_name="CheckEmbedLocalEmbed",
    )
    if not docsearch:
        return "docsearch not"

    chain = _build_qa_chain(docsearch)
    return "success to load data"
def query(question):
    """Answer ``question`` using the module-level retrieval-QA ``chain``.

    Returns:
        The chain's ``"answer"`` output; ``"fail to get answer"`` when the
        chain result has no such key; or ``"please load the data first"``
        when no chain has been built yet.
    """
    # e.g. question = "What is milvus?"
    if chain:
        result = chain(inputs={"question": question}, return_only_outputs=True)
        return result.get("answer", "fail to get answer")
    return "please load the data first"
if __name__ == "__main__":
    # Two-step UI: upload and index a knowledge-base file, then ask questions
    # that are answered from the indexed content.
    with gr.Blocks() as demo:
        gr.Markdown(
            "\n"
            "<h1><center>Langchain And Embed App</center></h1>\n"
            "v.2.29.17.30\n"
        )

        # web_loader() reads the upload with TextLoader, which only handles
        # plain text, so only text formats are offered here. Re-add binary
        # formats such as .docx / .pdf once a matching loader is wired in.
        file = gr.File(
            label='请上传知识库文件\n可以处理 .txt, .md 结尾的文件',
            file_types=['.txt', '.md'],
        )

        loader_output = gr.Textbox(label="load status")
        loader_btn = gr.Button("Load Data")
        loader_btn.click(
            fn=web_loader,
            inputs=[file],
            outputs=loader_output,
            api_name="web_load",
        )

        question_text = gr.Textbox(
            label="question",
            lines=3,
            placeholder="What is milvus?",
        )
        query_output = gr.Textbox(label="question answer", lines=3)
        query_btn = gr.Button("Generate")
        query_btn.click(
            fn=query,
            inputs=[question_text],
            outputs=query_output,
            api_name="generate_answer",
        )

    # Listen on all interfaces; no public share link is created.
    demo.queue().launch(server_name="0.0.0.0", share=False)