# hello-embed / app.py
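"""Gradio demo: load a local text file, split it into chunks, embed the chunks
into a Zilliz (Milvus) collection, and answer questions against it with a
RetrievalQAWithSourcesChain backed by a ZhipuAI glm-3-turbo wrapper."""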
from typing import Callable, Optional
import gradio as gr
from langchain.vectorstores import Zilliz
from langchain.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chains.llm import LLMChain
from langchain.chains import StuffDocumentsChain
from langchain_core.prompts import PromptTemplate
import hashlib
import os
from project.embeddings.local_embed import LocalEmbed
from project.llm.check_embed_llm import CheckEmbedLlm

# The retrieval chain is built lazily in web_loader() and reused by query().
chain: Optional[Callable] = None

# Zilliz (Milvus) connection settings and the ZhipuAI API key come from the environment.
db_host = os.getenv("DB_HOST")
db_user = os.getenv("DB_USER")
db_password = os.getenv("DB_PASSWORD")
zhipuai_api_key = os.getenv("ZHIPU_AI_KEY")


def generate_article_id(content):
    # Hash the chunk text with SHA-256.
    sha256 = hashlib.sha256()
    # Encode the text as UTF-8 bytes and feed it to the hash object.
    sha256.update(content.encode('utf-8'))
    # Use the hex digest as a deterministic document id.
    article_id = sha256.hexdigest()
    return article_id
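
# Example (hypothetical input): the same chunk text always yields the same
# 64-character hex digest, so every chunk gets a deterministic, stable id:
#   generate_article_id("What is milvus?")  # -> 64-char hex string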


def web_loader(file):
    if not file:
        return "please upload a file"
    loader = TextLoader(file)
    docs = loader.load()
    text_splitter = CharacterTextSplitter(chunk_size=512, chunk_overlap=0)
    docs = text_splitter.split_documents(docs)
    # embeddings = OpenAIEmbeddings(model="text-embedding-ada-002", openai_api_key=openai_key)
    # embeddings = ZhipuAIEmbeddings(zhipuai_api_key=zhipuai_api_key)
    embeddings = LocalEmbed(zhipuai_api_key=zhipuai_api_key)
    if not embeddings:
        return "failed to create embeddings"
    texts = [d.page_content for d in docs]
    article_ids = []
    # Generate a content-hash id for every chunk.
    for text in texts:
        article_id = generate_article_id(text)
        article_ids.append(article_id)
    docsearch = Zilliz.from_documents(
        docs,
        embedding=embeddings,
        ids=article_ids,
        connection_args={
            "uri": db_host,
            "user": db_user,
            "password": db_password,
            "secure": True,
        },
        collection_name="CheckEmbedLocalEmbed",
    )
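    # from_documents embeds every chunk with LocalEmbed and inserts the vectors into the
    # "CheckEmbedLocalEmbed" collection; the content-hash ids give each chunk a stable
    # identifier across repeated loads of the same file.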
    if not docsearch:
        return "failed to create docsearch"
    llm = CheckEmbedLlm(model="glm-3-turbo", temperature=0.1, zhipuai_api_key=zhipuai_api_key)
    # Each retrieved document is rendered with this prompt before being stuffed into the context.
    document_prompt = PromptTemplate(
        input_variables=["page_content"],
        template="{page_content}",
    )
    document_variable_name = "context"
    # The QA prompt must take `document_variable_name` ("context") as an input variable.
    # Chinese prompt; roughly: "Retrieved documents:\n{context}\nQuestion: {question}\nAnswer:"
    prompt = PromptTemplate.from_template(
        """查询到的文档如下:
{context}
问题: {question}
答:"""
    )
    llm_chain = LLMChain(llm=llm, prompt=prompt)
    combine_documents_chain = StuffDocumentsChain(
        llm_chain=llm_chain,
        document_prompt=document_prompt,
        document_variable_name=document_variable_name,
    )
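    # StuffDocumentsChain formats each of the k retrieved chunks with document_prompt,
    # concatenates them, and substitutes the result for {context} in the QA prompt
    # before calling the LLM once with the combined text.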
    global chain
    chain = RetrievalQAWithSourcesChain(
        combine_documents_chain=combine_documents_chain,
        retriever=docsearch.as_retriever(search_kwargs={'k': 3}),
    )
    return "successfully loaded data"


def query(question):
    global chain
    # e.g. "What is milvus?"
    if not chain:
        return "please load the data first"
    # RetrievalQAWithSourcesChain returns a dict with "answer" and "sources" keys.
    return chain(inputs={"question": question}, return_only_outputs=True).get(
        "answer", "failed to get an answer"
    )


if __name__ == "__main__":
    block = gr.Blocks()
    with block as demo:
        gr.Markdown(
            """
<h1><center>Langchain And Embed App</center></h1>
v.2.29.17.30
"""
        )
        # url_list_text = gr.Textbox(
        #     label="url list",
        #     lines=3,
        #     placeholder="https://milvus.io/docs/overview.md",
        # )
        # Label (Chinese): "Please upload a knowledge-base file; .txt, .md, .docx and .pdf files are handled."
        file = gr.File(label='请上传知识库文件\n可以处理 .txt, .md, .docx, .pdf 结尾的文件',
                       file_types=['.txt', '.md', '.docx', '.pdf'])
        # openai_key_text = gr.Textbox(label="openai api key", type="password", placeholder="sk-******")
        # puzhiai_key_text = gr.Textbox(label="puzhi api key", type="password", placeholder="******")
        loader_output = gr.Textbox(label="load status")
        loader_btn = gr.Button("Load Data")
        loader_btn.click(
            fn=web_loader,
            inputs=[
                file,
            ],
            outputs=loader_output,
            api_name="web_load",
        )
        question_text = gr.Textbox(
            label="question",
            lines=3,
            placeholder="What is milvus?",
        )
        query_output = gr.Textbox(label="question answer", lines=3)
        query_btn = gr.Button("Generate")
        query_btn.click(
            fn=query,
            inputs=[question_text],
            outputs=query_output,
            api_name="generate_answer",
        )
    demo.queue().launch(server_name="0.0.0.0", share=False)
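
# A minimal client-side sketch (an assumption, not part of this app): with the server
# running locally on Gradio's default port, the endpoint named via api_name above can
# be called with gradio_client, e.g.:
#
#   from gradio_client import Client
#   client = Client("http://localhost:7860")
#   print(client.predict("What is milvus?", api_name="/generate_answer"))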