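# Source: app.py, revision 251f093 ("Update app.py", uploaded by githubear), 8.4 kB.
# The file-viewer page chrome (Raw / History / Blame / Contribute / Delete)
# that preceded the code is not part of the program.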
import os
import requests
import bcrypt
import pymysql
import streamlit as st
# vectordb
from langchain.vectorstores import OceanBase
# embeddings
# from langchain_openai import OpenAIEmbeddings
from langchain_community.embeddings import JinaEmbeddings
# PDF
from PyPDF2 import PdfReader
from langchain.text_splitter import RecursiveCharacterTextSplitter
# LLM
# from langchain_openai import ChatOpenAI
from langchain_community.llms import Tongyi
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnablePassthrough
from langchain.chains.combine_documents import create_stuff_documents_chain
# env
from dotenv import load_dotenv
load_dotenv()
# Seed the per-session login state the first time the script runs for a
# session; values that are already present are left untouched.
_SESSION_DEFAULTS = (
    ('login_status', False),
    ('username', ''),
    ('user_id', -1),
)
for _state_key, _initial_value in _SESSION_DEFAULTS:
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
# Configure the database connection
def create_db_connection():
    """Open a pymysql connection using settings from the environment.

    Host, port, database, user and password are read from ``OB_HOST``,
    ``OB_PORT``, ``OB_DATABASE``, ``OB_USER`` and ``OB_PASSWORD``, falling
    back to ``localhost``, ``2881``, ``test``, ``root`` and an empty password.

    Returns:
        The open connection (using ``DictCursor``), or ``None`` when
        ``pymysql.connect`` raised ``pymysql.MySQLError``. In that case the
        error is displayed on the Streamlit page.
    """
    settings = {
        "host": os.getenv("OB_HOST", "localhost"),
        "port": int(os.getenv("OB_PORT", 2881)),
        "db": os.getenv("OB_DATABASE", "test"),
        "user": os.getenv("OB_USER", "root"),
        "passwd": os.getenv("OB_PASSWORD", ""),
        "charset": 'utf8mb4',
        "cursorclass": pymysql.cursors.DictCursor,
    }
    try:
        return pymysql.connect(**settings)
    except pymysql.MySQLError as e:
        st.error(f"The error '{e}' occurred")
        return None
# User registration
def register_user(connection, username, password):
    """Insert a new row into ``chat_users`` with a bcrypt-hashed password.

    Args:
        connection: Open pymysql connection; it is committed on success and
            is not closed here.
        username: Name to register. Must not be empty.
        password: Plain-text password. Must not be empty; only its bcrypt
            hash is stored.

    The outcome is reported on the Streamlit page: a success message, an
    error for empty credentials, or the database error that occurred.
    """
    # Reject blank credentials up front instead of creating an account that
    # has an empty name or an empty password.
    if not username or not password:
        st.error("Username and password must not be empty.")
        return

    password_hash = bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt())
    insert_query = """
INSERT INTO chat_users (username, password)
VALUES (%s, %s)
"""
    # Keep the try body limited to the statements that talk to the database.
    try:
        with connection.cursor() as cursor:
            cursor.execute(insert_query, (username, password_hash))
        connection.commit()
    except pymysql.MySQLError as e:
        st.error(f"The error '{e}' occurred")
    else:
        st.success("User registered successfully.")
# User login
def login_user(connection, username, password):
    """Check a username/password pair against the ``chat_users`` table.

    Args:
        connection: Open pymysql connection, or ``None`` when connecting
            failed.
        username: Name to look up.
        password: Plain-text password to verify against the stored bcrypt
            hash.

    Returns:
        The matching row's ``user_id``, or ``-1`` when there is no
        connection, no such user, or the password does not match.
    """
    # create_db_connection() returns None on failure; treat that as a failed
    # login instead of raising AttributeError on ``None.cursor()``.
    if connection is None:
        return -1

    with connection.cursor() as cursor:
        cursor.execute("SELECT user_id,password FROM chat_users WHERE username = %s", (username,))
        record = cursor.fetchone()

    if not record:
        return -1

    # register_user() writes the hash as bytes; depending on the column type
    # the driver may hand it back as str or as bytes, so accept both.
    stored_hash = record['password']
    if isinstance(stored_hash, str):
        stored_hash = stored_hash.encode('utf-8')

    if bcrypt.checkpw(password.encode('utf-8'), stored_hash):
        return record['user_id']
    return -1
def get_pdf_text(pdf_docs):
    """Return the text of every page of every PDF, concatenated in order.

    Args:
        pdf_docs: Iterable of objects accepted by ``PdfReader`` (here, the
            files returned by the Streamlit uploader).

    Returns:
        One string with all page texts joined without a separator; an empty
        string when ``pdf_docs`` is empty.
    """
    page_texts = []
    for pdf in pdf_docs:
        for page in PdfReader(pdf).pages:
            # Defensive: a page that yields no text (None or "") adds nothing
            # instead of raising TypeError on concatenation.
            page_texts.append(page.extract_text() or "")
    # A single join avoids repeated string reallocation on large documents.
    return "".join(page_texts)
def load_text_chunks(pdf_docs, chunk_size=1024, chunk_overlap=0):
    """Extract the text of the given PDFs and split it into chunks.

    Args:
        pdf_docs: Iterable of PDF files, passed on to ``get_pdf_text``.
        chunk_size: ``chunk_size`` given to the splitter. Defaults to the
            previously hard-coded 1024.
        chunk_overlap: ``chunk_overlap`` given to the splitter. Defaults to
            the previously hard-coded 0.

    Returns:
        The list of text chunks produced by
        ``RecursiveCharacterTextSplitter.split_text``.
    """
    text = get_pdf_text(pdf_docs)
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return text_splitter.split_text(text)
## Create the vector store
def get_oceanbase() -> OceanBase:
    """Build the OceanBase vector store for the current session's user.

    Database settings come from the ``OB_*`` environment variables (same
    defaults as the login database connection), the embedding model is Jina
    ``jina-embeddings-v2-base-zh`` keyed by ``JINA_AI_API``, and the
    collection name is ``langchain_document`` followed by the ``user_id``
    held in the Streamlit session state.
    """
    dsn = OceanBase.connection_string_from_db_params(
        host=os.getenv("OB_HOST", "localhost"),
        port=os.getenv("OB_PORT", "2881"),
        database=os.getenv("OB_DATABASE", "test"),
        user=os.getenv("OB_USER", "root"),
        password=os.getenv("OB_PASSWORD", ""),
    )

    # The embedding backend is interchangeable; Jina is the one in use.
    embedder = JinaEmbeddings(
        jina_api_key=os.getenv("JINA_AI_API", ""),
        model_name="jina-embeddings-v2-base-zh",
    )

    # TODO: decide whether to pass pre_delete_collection=True here.
    return OceanBase(
        connection_string=dsn,
        embedding_function=embedder,
        collection_name=f"langchain_document{st.session_state['user_id']}",
    )
def get_texts_summary(texts):
    """Summarise each text with the Tongyi LLM, for use in retrieval.

    Args:
        texts: List of strings; each one is substituted for ``{element}`` in
            the summarisation prompt.

    Returns:
        The result of running the chain in batch over ``texts`` with at most
        five concurrent calls, after ``StrOutputParser``.
    """
    # Same prompt text as a single literal, assembled from adjacent pieces.
    summary_instruction = (
        "您是一名助理,负责总结文本以供检索。 "
        "这些摘要将用于embedding并用于检索原始文本。 "
        "现在请给出针对检索进行优化的简洁摘要。 以下是原始文本: {element} "
    )
    prompt = ChatPromptTemplate.from_template(summary_instruction)
    llm = Tongyi()

    # Feed each input through unchanged under the "element" key.
    as_element = {"element": lambda x: x}
    pipeline = as_element | prompt | llm | StrOutputParser()
    return pipeline.batch(texts, config={"max_concurrency": 5})
def text_rag_chain(retriever):
    """Build the retrieval-augmented question-answering chain.

    Args:
        retriever: Runnable that receives the text of the latest message and
            returns the documents to use as context.

    Returns:
        A runnable that takes ``{"messages": [...]}`` and adds two keys:
        ``context`` (the retrieved documents) and ``answer`` (the Tongyi
        LLM's reply produced by the stuff-documents chain).
    """
    llm = Tongyi()

    # System prompt: answer from the supplied context only, and reply
    # "I don't know" when the context holds nothing relevant.
    system_template = (
        "\n"
        "根据下面给出的上下文回答用户的问题。\n"
        "如果下面的上下文中不包含与问题相关的任何信息,请不要编造内容,仅仅回复”我不知道“\n"
        "<context>\n"
        "{context}\n"
        "</context>\n"
    )
    prompt_messages = [
        ("system", system_template),
        MessagesPlaceholder(variable_name="messages"),
    ]
    qa_prompt = ChatPromptTemplate.from_messages(prompt_messages)
    document_chain = create_stuff_documents_chain(llm, qa_prompt)

    def parse_retriever_input(params):
        # The retriever is queried with the content of the last message.
        messages = params["messages"]
        return messages[-1].content

    with_context = RunnablePassthrough.assign(
        context=parse_retriever_input | retriever,
    )
    return with_context.assign(answer=document_chain)
def run_pipeline(oceanbase, user_question):
    """Answer a question from the vector store and write it to the page.

    Args:
        oceanbase: Vector store to retrieve context documents from.
        user_question: The question text; sent as a single human message.
    """
    # The number of documents to fetch is passed through ``search_kwargs``.
    # A bare ``k=5`` keyword is not the documented way to set it and is
    # presumably not applied to the search, so the intended top-5 retrieval
    # is requested explicitly here.
    retriever = oceanbase.as_retriever(search_kwargs={"k": 5})
    chain = text_rag_chain(retriever)
    response = chain.invoke(
        {
            "messages": [
                HumanMessage(content=user_question)
            ],
        }
    )
    st.write(response["answer"])
def _render_login_tab():
    """Draw the login form and authenticate the user when it is submitted."""
    login_username = st.text_input("Username", key="login_username")
    login_password = st.text_input("Password", type="password", key="login_password")
    if not st.button("Login"):
        return

    conn = create_db_connection()
    if not conn:
        # create_db_connection() already showed the error; do not attempt a
        # query on a missing connection.
        return
    try:
        user_id = login_user(conn, login_username, login_password)
    finally:
        # Close the connection even if the lookup raises.
        conn.close()

    st.session_state['user_id'] = user_id
    if user_id != -1:
        st.session_state['login_status'] = True
        st.session_state['username'] = login_username
        st.success(f"Welcome {login_username}!")
    else:
        st.error("Incorrect username/password")


def _render_register_tab():
    """Draw the registration form and create the account when submitted."""
    new_username = st.text_input("Username", key="register_username")
    new_password = st.text_input("Password", type="password", key="register_password")
    if not st.button("Register"):
        return

    conn = create_db_connection()
    if conn:
        try:
            register_user(conn, new_username, new_password)
        finally:
            conn.close()


def _render_chat_page():
    """Draw the question box and the PDF upload sidebar for a logged-in user."""
    # Page configuration is issued before any other Streamlit call made for
    # this page.
    st.set_page_config("Chat PDF")
    oceanbase = get_oceanbase()
    st.header("Chat with PDF")

    user_question = st.text_input("Ask a Question from the PDF Files")
    if user_question:
        run_pipeline(oceanbase, user_question)

    with st.sidebar:
        st.title("Menu:")
        pdf_docs = st.file_uploader("Upload PDF Files", accept_multiple_files=True)
        if st.button("Submit & Process"):
            if not pdf_docs:
                # Nothing to index: say so instead of reporting success.
                st.warning("Please upload at least one PDF file first.")
                return
            with st.spinner("Processing..."):
                texts = load_text_chunks(pdf_docs)
                # NOTE: summaries from get_texts_summary() are currently not
                # indexed; only the raw chunks are stored.
                if texts:
                    oceanbase.add_texts(texts=texts)
            if texts:
                st.success("Done")
            else:
                st.warning("No text could be extracted from the uploaded files.")


def main():
    """Show the chat page to logged-in users, otherwise the login/register page.

    The login state is read from ``st.session_state['login_status']``.
    """
    if st.session_state['login_status']:
        _render_chat_page()
        return

    # Login / registration page
    st.title("Login to chat with OceanBase")
    tab = st.radio("Choose a tab:", ["Login", "Register"])
    if tab == "Login":
        _render_login_tab()
    elif tab == "Register":
        _render_register_tab()
# Run the app when this file is executed as a script.
if __name__ == "__main__":
    main()