llm_knowledge_base / app 2.py
allinaigc's picture
Upload 35 files
b2e325f verified
"""
1. 完成了用Qwen通义千问作为知识库查询。
1. 总共有三个区块:知识库回答,应用来源,相关问题。
1. 在Huggingface的API上部署了一个在线BGE的模型,用于回答问题。OpenAI的Emebedding或者Langchain的Embedding都不可以用(会报错: self.d)。
"""
##TODO: 1. 建立一个upload file的模块。
# -*- coding: utf-8 -*-
import requests
import streamlit as st
import openai
import os
import numpy as np
import pandas as pd
import csv
import tempfile
from tempfile import NamedTemporaryFile
import pathlib
from pathlib import Path
import re
from re import sub
import matplotlib.pyplot as plt
from itertools import product
from tqdm import tqdm_notebook, tqdm, trange
import time
from time import sleep
from matplotlib.pyplot import style
from rich import print
import warnings
import streamlit_authenticator as stauth
# from langchain.vectorstores import FAISS
from langchain_community.vectorstores import FAISS
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.runnables import RunnableParallel
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from typing import Dict, List, Optional, Tuple, Union
import requests
import json
import streamlit as st
# import rag_reponse_001
import qwen_response
import rag_reponse_002
# import chatgpt
# from st_copy_to_clipboard import st_copy_to_clipboard
import clipboard
import dashscope
# warnings.filterwarnings('ignore')
from dotenv import load_dotenv
load_dotenv()
### 设置openai的API key
os.environ["OPENAI_API_KEY"] = os.environ['user_token']
openai.api_key = os.environ['user_token']
bing_search_api_key = os.environ['bing_api_key']
dashscope.api_key = os.environ['dashscope_api_key']
### Streamlit页面设定。
st.set_page_config(layout="wide")
st.title("大语言模型智能知识库查询中心")
# st.title("大语言模型本地知识库问答系统")
# st.subheader("Large Language Model-based Knowledge Base QA System")
# st.warning("_声明:内容由人工智能生成,仅供参考。如果您本人使用或对外传播本服务生成的输出,您应当主动核查输出内容的真实性、准确性,避免传播虚假信息。_")
st.caption("_声明:内容由人工智能生成,仅供参考。您应当主动核查输出内容的真实性、准确性,避免传播虚假信息。_")
# st.caption("_声明:内容由人工智能生成,仅供参考。如果您本人使用或对外传播本服务生成的输出,您应当主动核查输出内容的真实性、准确性,避免传播虚假信息。_")
# st.info("_声明:内容由人工智能生成,仅供参考。如果您本人使用或对外传播本服务生成的输出,您应当主动核查输出内容的真实性、准确性,避免传播虚假信息。_")
# st.divider()
### 上传文件的模块
def upload_file(uploaded_file):
if uploaded_file is not None:
# filename = uploaded_file.name
# st.write(filename) # print out the whole file name to validate. not to show in the final version.
try:
# if '.pdf' in filename: ### original code here.
if '.pdf' in uploaded_file.name:
pdf_filename = uploaded_file.name ### original code here.
filename = uploaded_file.name
# print('PDF file:', pdf_filename)
# with st.status('正在为您解析新知识库...', expanded=False, state='running') as status:
spinner = st.spinner('正在为您解析新知识库...请耐心等待')
with spinner:
### 以下是langchain方案。
import langchain_KB
import save_database_info
uploaded_file_name = "File_provided"
temp_dir = tempfile.TemporaryDirectory()
# ! working.
uploaded_file_path = pathlib.Path(temp_dir.name) / uploaded_file_name
with open(pdf_filename, 'wb') as output_temporary_file:
# with open(f'./{username}_upload.pdf', 'wb') as output_temporary_file: ### original code here. 可能会造成在引用信息来源时文件名不对的问题。
# ! 必须用这种格式读入内容,然后才可以写入temporary文件夹中。
# output_temporary_file.write(uploaded_file.getvalue())
output_temporary_file.write(uploaded_file.getvalue())
langchain_KB.langchain_localKB_construct(output_temporary_file, username)
## 在屏幕上展示当前知识库的信息,包括名字和加载日期。
save_database_info.save_database_info(f'./{username}/database_name.csv', pdf_filename, str(datetime.now(pytz.timezone('Asia/Shanghai')).strftime("%Y-%m-%d %H:%M")))
st.markdown('新知识库解析成功,请务必刷新页面,然后开启对话 🔃')
return pdf_filename
else:
# if '.csv' in filename: ### original code here.
if '.csv' in uploaded_file.name:
print('start the csv file processing...')
csv_filename = uploaded_file.name
filename = uploaded_file.name
csv_file = pd.read_csv(uploaded_file)
csv_file.to_csv(f'./{username}/{username}_upload.csv', encoding='utf-8', index=False)
st.write(csv_file[:3]) # 这里只是显示文件,后面需要定位文件所在的绝对路径。
else:
xls_file = pd.read_excel(uploaded_file)
xls_file.to_csv(f'./{username}_upload.csv', index=False)
st.write(xls_file[:3])
print('end the csv file processing...')
# uploaded_file_name = "File_provided"
# temp_dir = tempfile.TemporaryDirectory()
# ! working.
# uploaded_file_path = pathlib.Path(temp_dir.name) / uploaded_file_name
# with open('./upload.csv', 'wb') as output_temporary_file:
# with open(f'./{username}_upload.csv', 'wb') as output_temporary_file:
# print(f'./{name}_upload.csv')
# ! 必须用这种格式读入内容,然后才可以写入temporary文件夹中。
# output_temporary_file.write(uploaded_file.getvalue())
# st.write(uploaded_file_path) #* 可以查看文件是否真实存在,然后是否可以
except Exception as e:
st.write(e)
## 以下代码是为了解决上传文件后,文件路径和文件名不对的问题。
# uploaded_file_name = "File_provided"
# temp_dir = tempfile.TemporaryDirectory()
# # ! working.
# uploaded_file_path = pathlib.Path(temp_dir.name) / uploaded_file_name
# # with open('./upload.csv', 'wb') as output_temporary_file:
# with open(f'./{name}_upload.csv', 'wb') as output_temporary_file:
# # print(f'./{name}_upload.csv')
# # ! 必须用这种格式读入内容,然后才可以写入temporary文件夹中。
# # output_temporary_file.write(uploaded_file.getvalue())
# output_temporary_file.write(uploaded_file.getvalue())
# # st.write(uploaded_file_path) # * 可以查看文件是否真实存在,然后是否可以
# # st.write('Now file saved successfully.')
# return pdf_filename, csv_filename
return filename
### upload file
# username = 'test'
# path = f'./{username}/faiss_index/index.faiss'
# if os.path.exists(path):
# print(f'{path} local KB exists')
# database_info = pd.read_csv(f'./{username}/database_name.csv')
# current_database_name = database_info.iloc[-1][0]
# current_database_date = database_info.iloc[-1][1]
# database_claim = f"当前知识库为:{current_database_name},创建于{current_database_date}。可以开始提问!"
# st.markdown(database_claim)
# uploaded_file = st.file_uploader(
# "选择上传一个新知识库", type=(["pdf"]))
# # 默认状态下没有上传文件,None,会报错。需要判断。
# if uploaded_file is not None:
# # uploaded_file_path = upload_file(uploaded_file)
# upload_file(uploaded_file)
# # ## 创建向量数据库
# from langchain.embeddings.openai import OpenAIEmbeddings
# embeddings = OpenAIEmbeddings(disallowed_special=()) ## 这里是联网情况下,部署在Huggingface上后使用。
# print('embeddings:', embeddings)
# embedding_model_name = 'GanymedeNil/text2vec-large-chinese'
# # embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name) ## 这里是联网情况下连接huggingface后使用。
# embeddings = HuggingFaceEmbeddings(model_name='/Users/yunshi/Downloads/360Data/Data Center/Working-On Task/演讲与培训/2023ChatGPT/Coding/RAG/bge-large-zh') ## 切换成BGE的embedding。
# embeddings = HuggingFaceEmbeddings(model_name='/Users/yunshi/Downloads/360Data/Data Center/Working-On Task/演讲与培训/2023ChatGPT/RAG/bge-large-zh/') ## 切换成BGE的embedding。
# embeddings = HuggingFaceEmbeddings(model_name='/Users/yunshi/Downloads/chatGLM/My_LocalKB_Project/GanymedeNil_text2vec-large-chinese/') ## 这里会有个“No sentence-transformers model found with name“的warning,但不是error,不影响使用。
### authentication with a local yaml file.
import yaml
from yaml.loader import SafeLoader
with open('./config.yaml') as file:
config = yaml.load(file, Loader=SafeLoader)
authenticator = stauth.Authenticate(
config['credentials'],
config['cookie']['name'],
config['cookie']['key'],
config['cookie']['expiry_days'],
config['preauthorized']
)
user, authentication_status, username = authenticator.login('用户登录', 'main')
if authentication_status:
with st.sidebar:
st.markdown(
"""
<style>
[data-testid="stSidebar"][aria-expanded="true"]{
min-width: 450px;
max-width: 450px;
}
""",
unsafe_allow_html=True,
)
### siderbar的题目。
### siderbar的题目。
# st.header(f'**大语言模型专家系统工作设定区**')
st.header(f'**欢迎 **{username}** 使用本系统** ')
st.write(f'_Large Language Model Expert System Working Environment_')
# st.write(f'_Welcome and Hope U Enjoy Staying Here_')
authenticator.logout('登出', 'sidebar')
## 在sidebar上的三个分页显示,用st.tabs实现。
tab_1, tab_2, tab_3, tab_4 = st.tabs(['使用须知', '模型参数', '提示词模板', '系统角色设定'])
# with st.expander(label='**使用须知**', expanded=False):
with tab_1:
# st.markdown("#### 快速上手指南")
# with st.text(body="说明"):
# st.markdown("* 重启一轮新对话时,只需要刷新页面(按Ctrl/Command + R)即可。")
with st.text(body="说明"):
st.markdown("* 为了保护数据与隐私,所有对话均不会被保存,刷新页面立即删除。敬请放心。")
# with st.text(body="说明"):
# st.markdown("* “GPT-4”回答质量极佳,但速度缓慢,建议适当使用。")
with st.text(body="说明"):
st.markdown("* 查询知识库模式与所有的搜索引擎或者数据库检索方式一样,仅限一轮对话,将不会保持之前的会话记录。")
with st.text(body="说明"):
st.markdown("""* 系统的工作流程如下:
1. 用户输入问题。
1. 系统将问题转换为机器可理解的格式。
1. 系统使用大语言模型来生成与问题相关的候选答案。
1. 系统使用本地知识库来评估候选答案的准确性。
1. 系统返回最准确的答案。""")
## 大模型参数
# with st.expander(label='**大语言模型参数**', expanded=True):
with tab_2:
max_tokens = st.slider(label='Max_Token(生成结果时最大字数)', min_value=100, max_value=8096, value=4096,step=100)
temperature = st.slider(label='Temperature (温度)', min_value=0.0, max_value=1.0, value=0.8, step=0.1)
top_p = st.slider(label='Top_P (核采样)', min_value=0.0, max_value=1.0, value=0.6, step=0.1)
frequency_penalty = st.slider(label='Frequency Penalty (重复度惩罚因子)', min_value=-2.0, max_value=2.0, value=1.0, step=0.1)
presence_penalty = st.slider(label='Presence Penalty (控制主题的重复度)', min_value=-2.0, max_value=2.0, value=1.0, step=0.1)
with tab_3:
# st.markdown("#### Prompt提示词参考资料")
# with st.expander(label="**大语言模型基础提示词Prompt示例**", expanded=False):
st.code(
body="我是一个企业主,我需要关注哪些“存货”相关的数据资源规则?", language='plaintext')
st.code(
body="作为零售商,了解哪些关键的库存管理指标对我至关重要?", language='plaintext')
st.code(body="企业主在监控库存时,应如何确保遵守行业法规和最佳实践?",
language='plaintext')
st.code(body="在数字化时代,我应该关注哪些技术工具或平台来优化我的库存数据流程?", language='plaintext')
st.code(body="我应该如何定期审查和分析这些库存数据以提高运营效率?", language='plaintext')
st.code(body="如何设置预警系统来避免过度库存或缺货情况?", language='plaintext')
with tab_4:
st.text_area(label='系统角色设定', value='你是一个人工智能,你需要回答我提出的问题,或者完成我交代的任务。你需要使用我提问的语言(如中文、英文)来回答。', height=200, label_visibility='hidden')
elif authentication_status == False:
st.error('⛔ 用户名或密码错误!')
elif authentication_status == None:
st.warning('⬆ 请先登录!')
#### start: 主程序
## 清楚所有对话记录。
def clear_all():
st.session_state.conversation = None
st.session_state.chat_history = None
st.session_state.messages = []
message_placeholder = st.empty()
## 只用这一个就可以了。
st.rerun()
return None
if "copied" not in st.session_state:
st.session_state.copied = []
if "llm_response" not in st.session_state:
st.session_state.llm_response = ""
## copy to clipboard function with a button.
def copy_to_clipboard(text):
st.session_state.copied.append(text)
clipboard.copy(text)
def main():
# llm = ChatGLM() ## 启动一个实例。
col1, col2 = st.columns([2, 1])
# st.markdown('### 数据库查询区')
# with st.expander(label='**查询企业内部知识库**', expanded=True):
with col1:
KB_mode = True
user_input = st.text_input(label='**🧭 大模型数据库对话区**', placeholder='请输入您的问题', label_visibility='visible')
if user_input:
## 非stream输出,原始状态,不需要改变api.py中的内容。
# with st.status('检索中...', expanded=True, state='running') as status:
spinner = st.spinner('思考中...请耐心等待')
with spinner:
if KB_mode == True:
# import rag_reponse_001
# clear_all()
# response = rag_reponse_001.rag_response(user_input=user_input, k=5) ## working.
print('user_input:', user_input)
response, source = rag_reponse_002.rag_response(user_input=user_input, k=3)
print('llm response:', response)
sim_prompt = f"""你需要根据以下的问题来提出5个可能的后续问题{user_input}
"""
# sim_questions = chatgpt.chatgpt(user_prompt=sim_prompt) ## chatgpt to get similar questions.
sim_questions = qwen_response.call_with_messages(sim_prompt)
if len(user_input) != 0:
sim_prompt = f"""你需要根据以下的初始问题来提出3个相似的问题和3个后续问题。
初始问题是:{user_input}
--------------------
你回答的时候,需要使用如下格式:
**相似问题:**
**后续问题:**
"""
# sim_prompt = f"""你需要根据以下的问题来提出5个可能的后续问题{user_input}"""
### 这里用chatgpt来生成相似问题。
# sim_questions = chatgpt.chatgpt(user_prompt=sim_prompt)
### 这里用Qwen来生成相似问题。
sim_questions = qwen_response.call_with_messages(sim_prompt)
st.markdown(response)
# st_copy_to_clipboard(text=str(response), show_text=True, before_copy_label="📋", after_copy_label="✅")
## 如果这样使用,每次按button都会重新提交问题。
# st.button(label="📃", on_click=copy_to_clipboard, args=(response,))
st.divider()
st.caption(source)
st.divider()
## 初始状态下response未被定义。
try:
if response:
with col2:
with st.expander(label='## **您可能还会关注以下内容**', expanded=True):
st.info(sim_questions)
except:
pass
# st.stop()
return None
#### End: 主程序
if __name__ == '__main__':
main()