File size: 5,004 Bytes
a7b5657
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23fff3a
a7b5657
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23fff3a
a7b5657
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23fff3a
a7b5657
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
from typing import Callable, Optional

import gradio as gr
from langchain.vectorstores import Zilliz
from langchain.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chains.llm import LLMChain
from langchain.chains import StuffDocumentsChain
from langchain_core.prompts import PromptTemplate
import hashlib
import os
from project.embeddings.local_embed import LocalEmbed
from project.llm.check_embed_llm import CheckEmbedLlm

# QA chain built by web_loader(); stays None until a knowledge file is loaded.
chain: Optional[Callable] = None

# Zilliz connection settings and the ZhipuAI API key, read from the
# environment once at import time (each is None if the variable is unset).
db_host = os.environ.get("DB_HOST")
db_user = os.environ.get("DB_USER")
db_password = os.environ.get("DB_PASSWORD")
zhipuai_api_key = os.environ.get("ZHIPU_AI_KEY")


def generate_article_id(content):
    """Return a deterministic id for *content*.

    The id is the hexadecimal SHA-256 digest of the text encoded as UTF-8,
    so identical text always yields the same 64-character id.
    """
    digest = hashlib.sha256(content.encode('utf-8'))
    return digest.hexdigest()


def _dedupe_by_id(docs, ids):
    """Drop documents whose id was already seen, keeping the first occurrence.

    Args:
        docs: Documents, parallel to ``ids``.
        ids: One id per document.

    Returns:
        A ``(docs, ids)`` tuple of lists with duplicate ids removed; the
        original order is preserved.
    """
    seen = set()
    unique_docs, unique_ids = [], []
    for doc, doc_id in zip(docs, ids):
        if doc_id in seen:
            continue
        seen.add(doc_id)
        unique_docs.append(doc)
        unique_ids.append(doc_id)
    return unique_docs, unique_ids


def web_loader(file):
    """Index an uploaded text file in Zilliz and build the global QA chain.

    The file is loaded with ``TextLoader`` and split with
    ``CharacterTextSplitter(chunk_size=512, chunk_overlap=0)``. Each chunk is
    keyed by the SHA-256 of its text (``generate_article_id``) and written to
    the ``CheckEmbedLocalEmbed`` collection. On success, the module-level
    ``chain`` is replaced by a ``RetrievalQAWithSourcesChain`` that retrieves
    the top 3 chunks and answers with ``glm-3-turbo``.

    Args:
        file: The value from the ``gr.File`` component. Depending on the
            Gradio version, this is a path string or a file wrapper exposing
            the path as ``.name``.

    Returns:
        A status message for the UI.
    """
    if not file:
        return "please upload file"
    # Older Gradio passes a tempfile wrapper (path in `.name`); newer versions
    # pass the path itself. A plain str has no `.name`, so it is used as-is.
    file_path = getattr(file, "name", file)
    # NOTE(review): TextLoader reads plain text only; the .docx/.pdf uploads the
    # UI accepts are not parsed into text here — confirm and add proper loaders.
    loader = TextLoader(file_path)
    docs = loader.load()

    text_splitter = CharacterTextSplitter(chunk_size=512, chunk_overlap=0)
    docs = text_splitter.split_documents(docs)
    if not docs:
        # Nothing to embed; don't hand an empty document list to the vector store.
        return "no text found in file"

    embeddings = LocalEmbed(zhipuai_api_key=zhipuai_api_key)
    # Defensive check: a constructed instance is normally truthy.
    if not embeddings:
        return "embeddings not"

    # Ids are derived from chunk text, so identical chunks collide. The
    # Milvus/Zilliz store requires unique ids, so keep the first of each.
    article_ids = [generate_article_id(d.page_content) for d in docs]
    docs, article_ids = _dedupe_by_id(docs, article_ids)

    # NOTE(review): re-uploading the same file reuses the same ids in an
    # existing collection; confirm whether that should replace or skip rows.
    docsearch = Zilliz.from_documents(
        docs,
        embedding=embeddings,
        ids=article_ids,
        connection_args={
            "uri": db_host,
            "user": db_user,
            "password": db_password,
            "secure": True,
        },
        collection_name="CheckEmbedLocalEmbed"
    )

    # Defensive check: a constructed vector store is normally truthy.
    if not docsearch:
        return "docsearch not"

    llm = CheckEmbedLlm(model="glm-3-turbo", temperature=0.1, zhipuai_api_key=zhipuai_api_key)

    # Each retrieved chunk is rendered as its raw text only.
    document_prompt = PromptTemplate(
        input_variables=["page_content"],
        template="{page_content}"
    )
    document_variable_name = "context"
    # The prompt must take `document_variable_name` ("context") as an input
    # variable; "question" is supplied by RetrievalQAWithSourcesChain.
    prompt = PromptTemplate.from_template(
        """查询到的文档如下:
        {context}

        问题: {question}
        答:"""
    )
    llm_chain = LLMChain(llm=llm, prompt=prompt)
    combine_documents_chain = StuffDocumentsChain(
        llm_chain=llm_chain,
        document_prompt=document_prompt,
        document_variable_name=document_variable_name
    )

    global chain
    chain = RetrievalQAWithSourcesChain(combine_documents_chain=combine_documents_chain,
                                        retriever=docsearch.as_retriever(search_kwargs={'k': 3}))
    return "success to load data"


def query(question):
    """Answer *question* using the QA chain built by ``web_loader``.

    Returns a prompt to load data first while no chain exists. Otherwise it
    returns the chain's ``"answer"`` output, or a fallback message when that
    key is missing.
    """
    # `chain` is only read here, so a `global` declaration is unnecessary.
    if not chain:
        return "please load the data first"
    outputs = chain(inputs={"question": question}, return_only_outputs=True)
    return outputs.get("answer", "fail to get answer")


if __name__ == "__main__":
    # Two-step UI: index an uploaded file with web_loader, then ask questions
    # against it with query.
    with gr.Blocks() as demo:
        gr.Markdown(
            """
        <h1><center>Langchain And Embed App</center></h1>
        
        v.2.29.17.30
        
        """
        )
        # NOTE(review): the label advertises .docx and .pdf, but web_loader
        # reads uploads with TextLoader (plain text only). Either add proper
        # loaders or narrow file_types.
        file = gr.File(label='请上传知识库文件\n可以处理 .txt, .md, .docx, .pdf 结尾的文件',
                       file_types=['.txt', '.md', '.docx', '.pdf'])

        loader_output = gr.Textbox(label="load status")
        loader_btn = gr.Button("Load Data")
        loader_btn.click(
            fn=web_loader,
            inputs=[
                file,
            ],
            outputs=loader_output,
            api_name="web_load",
        )

        question_text = gr.Textbox(
            label="question",
            lines=3,
            placeholder="What is milvus?",
        )
        query_output = gr.Textbox(label="question answer", lines=3)
        query_btn = gr.Button("Generate")
        query_btn.click(
            fn=query,
            inputs=[question_text],
            outputs=query_output,
            api_name="generate_answer",
        )

    # Launch only after the `with gr.Blocks()` context has closed (Gradio's
    # documented pattern); launching inside it runs before the Blocks context
    # is exited.
    demo.queue().launch(server_name="0.0.0.0", share=False)