| import os |
| import requests |
| import bcrypt |
| import pymysql |
| import streamlit as st |
| |
| from langchain.vectorstores import OceanBase |
| |
| |
| from langchain_community.embeddings import JinaEmbeddings |
| |
| from PyPDF2 import PdfReader |
| from langchain.text_splitter import RecursiveCharacterTextSplitter |
| |
| |
| from langchain_community.llms import Tongyi |
| from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder |
| from langchain_core.output_parsers import StrOutputParser |
| from langchain_core.messages import HumanMessage |
| from langchain_core.runnables import RunnablePassthrough |
| from langchain.chains.combine_documents import create_stuff_documents_chain |
| |
| from dotenv import load_dotenv |
|
|
load_dotenv()

# Per-session state used by the login flow. Only keys that are still missing
# are seeded, so values already stored in this session are left untouched.
_SESSION_DEFAULTS = {
    "login_status": False,  # True once login_user accepted the credentials
    "username": "",  # name of the logged-in user
    "user_id": -1,  # chat_users.user_id of the logged-in user; -1 = none
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
|
|
| |
def create_db_connection():
    """Open a pymysql connection to OceanBase configured from ``OB_*`` env vars.

    Rows are returned as dicts (``DictCursor``) and the connection uses the
    ``utf8mb4`` charset.

    Returns:
        The open connection, or ``None`` when connecting raised a
        ``pymysql.MySQLError`` (the error is shown in the Streamlit UI).
    """
    settings = {
        "host": os.getenv("OB_HOST", "localhost"),
        "port": int(os.getenv("OB_PORT", 2881)),
        "db": os.getenv("OB_DATABASE", "test"),
        "user": os.getenv("OB_USER", "root"),
        "passwd": os.getenv("OB_PASSWORD", ""),
        "charset": "utf8mb4",
        "cursorclass": pymysql.cursors.DictCursor,
    }
    try:
        return pymysql.connect(**settings)
    except pymysql.MySQLError as exc:
        st.error(f"The error '{exc}' occurred")
        return None
|
|
| |
def register_user(connection, username, password):
    """Insert a new ``chat_users`` row holding a bcrypt hash of ``password``.

    Success and failure are reported in the Streamlit UI; nothing is returned.

    Args:
        connection: Open pymysql connection (e.g. from ``create_db_connection``).
        username: Requested user name; must be non-empty.
        password: Plain-text password; must be non-empty. Only its bcrypt
            hash is written to the database.
    """
    if not username or not password:
        # Refuse blank credentials rather than creating an unusable account.
        st.error("Username and password must not be empty.")
        return

    # Store the hash as text so the column holds the same str form that
    # login_user encodes back to bytes before calling bcrypt.checkpw.
    password_hash = bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
    insert_query = """
        INSERT INTO chat_users (username, password)
        VALUES (%s, %s)
    """
    with connection.cursor() as cursor:
        try:
            cursor.execute(insert_query, (username, password_hash))
            connection.commit()
        except pymysql.MySQLError as e:
            # Any driver/database error (possibly a duplicate username, if the
            # table enforces uniqueness) is reported instead of crashing the page.
            st.error(f"The error '{e}' occurred")
            return
    st.success("User registered successfully.")
|
|
| |
def login_user(connection, username, password):
    """Verify ``username``/``password`` against the ``chat_users`` table.

    Args:
        connection: Open pymysql connection whose cursors return dict rows.
        username: User name to look up.
        password: Plain-text password to check against the stored bcrypt hash.

    Returns:
        The user's ``user_id`` on success; otherwise ``-1``. Failure covers an
        unknown user, a wrong password, a missing or malformed stored hash, or
        a database error (which is also shown in the Streamlit UI).
    """
    try:
        with connection.cursor() as cursor:
            cursor.execute(
                "SELECT user_id,password FROM chat_users WHERE username = %s",
                (username,),
            )
            record = cursor.fetchone()
    except pymysql.MySQLError as e:
        # Report DB failures the same way register_user does instead of raising.
        st.error(f"The error '{e}' occurred")
        return -1

    if not record or not record["password"]:
        # Unknown user, or a row with no usable stored hash.
        return -1

    stored_hash = record["password"]
    # The value may come back as str (text column) or bytes (binary column).
    if isinstance(stored_hash, str):
        stored_hash = stored_hash.encode("utf-8")

    try:
        matches = bcrypt.checkpw(password.encode("utf-8"), stored_hash)
    except ValueError:
        # bcrypt rejects values that are not well-formed hashes ("Invalid salt").
        return -1
    return record["user_id"] if matches else -1
|
|
def get_pdf_text(pdf_docs):
    """Extract the text of every page of every given PDF.

    Args:
        pdf_docs: Iterable of PDF sources accepted by ``PdfReader`` (paths or
            binary file-like objects, e.g. files from ``st.file_uploader``).

    Returns:
        All page texts joined with newlines; ``""`` when ``pdf_docs`` is empty.
    """
    page_texts = []
    for pdf in pdf_docs:
        for page in PdfReader(pdf).pages:
            # Defensive: a page that yields no text contributes an empty string.
            page_texts.append(page.extract_text() or "")
    # Separate pages with a newline so text on either side of a page boundary
    # is not fused into one word. join() also avoids quadratic += concatenation.
    return "\n".join(page_texts)
|
|
def load_text_chunks(pdf_docs, chunk_size=1024, chunk_overlap=0):
    """Extract text from the given PDFs and split it into chunks for embedding.

    Args:
        pdf_docs: PDF sources, passed through to ``get_pdf_text``.
        chunk_size: Target maximum chunk size for the splitter (default 1024,
            measured with the splitter's default length function).
        chunk_overlap: Amount of overlap between consecutive chunks
            (default 0, i.e. no overlap).

    Returns:
        The list of chunks from ``RecursiveCharacterTextSplitter.split_text``.
    """
    text = get_pdf_text(pdf_docs)
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    return splitter.split_text(text)
|
|
| |
def get_oceanbase() -> OceanBase:
    """Return the OceanBase vector store for the current session's user.

    Connection settings come from the ``OB_*`` environment variables and the
    Jina API key from ``JINA_AI_API``. Each user gets a separate collection
    named ``langchain_document<user_id>``, where the id comes from
    ``st.session_state['user_id']``.
    """
    env = os.getenv
    # NOTE: the port is passed as the raw environment string here.
    dsn = OceanBase.connection_string_from_db_params(
        host=env("OB_HOST", "localhost"),
        port=env("OB_PORT", "2881"),
        database=env("OB_DATABASE", "test"),
        user=env("OB_USER", "root"),
        password=env("OB_PASSWORD", ""),
    )
    embedder = JinaEmbeddings(
        jina_api_key=env("JINA_AI_API", ""),
        model_name="jina-embeddings-v2-base-zh",
    )
    per_user_collection = f"langchain_document{st.session_state['user_id']}"
    return OceanBase(
        connection_string=dsn,
        embedding_function=embedder,
        collection_name=per_user_collection,
    )
|
|
def get_texts_summary(texts):
    """Summarize each text with Tongyi to produce retrieval-oriented summaries.

    The (Chinese) prompt asks the model for a concise summary of the text,
    optimized for embedding-based retrieval of the original.

    Args:
        texts: Sequence of strings to summarize.

    Returns:
        One summary string per input, in input order.
    """
    template = (
        "您是一名助理,负责总结文本以供检索。 "
        "这些摘要将用于embedding并用于检索原始文本。 "
        "现在请给出针对检索进行优化的简洁摘要。 以下是原始文本: {element} "
    )
    chain = (
        {"element": lambda x: x}
        | ChatPromptTemplate.from_template(template)
        | Tongyi()
        | StrOutputParser()
    )
    # At most five summarization requests are in flight at once.
    return chain.batch(texts, {"max_concurrency": 5})
|
|
def text_rag_chain(retriever):
    """Build the retrieval-augmented question-answering chain.

    The returned runnable expects ``{"messages": [...]}``. It queries
    ``retriever`` with the content of the last message and returns the input
    dict extended with ``context`` (the retrieved documents) and ``answer``
    (the output of the stuff-documents chain).

    Args:
        retriever: Runnable that maps a query string to documents.
    """
    model = Tongyi()

    # System prompt (Chinese): answer from the given context only; if the
    # context has nothing relevant, do not make things up and reply "I don't know".
    system_prompt = (
        "\n"
        "根据下面给出的上下文回答用户的问题。\n"
        "如果下面的上下文中不包含与问题相关的任何信息,请不要编造内容,仅仅回复”我不知道“\n"
        "\n"
        "<context>\n"
        "{context}\n"
        "</context>\n"
    )
    qa_prompt = ChatPromptTemplate.from_messages(
        [("system", system_prompt), MessagesPlaceholder(variable_name="messages")]
    )
    stuff_chain = create_stuff_documents_chain(model, qa_prompt)

    def parse_retriever_input(params):
        # The retrieval query is the text of the newest message only.
        return params["messages"][-1].content

    with_context = RunnablePassthrough.assign(context=parse_retriever_input | retriever)
    return with_context.assign(answer=stuff_chain)
|
|
def run_pipeline(oceanbase, user_question, k=5):
    """Answer ``user_question`` from the vector store and write the answer to the page.

    Args:
        oceanbase: Vector store used as the retrieval source.
        user_question: Question text entered by the user.
        k: Number of chunks to retrieve as context (default 5).
    """
    # The retrieval count belongs in ``search_kwargs``. VectorStoreRetriever
    # has no ``k`` field, so ``as_retriever(k=5)`` never changed how many
    # chunks were retrieved.
    retriever = oceanbase.as_retriever(search_kwargs={"k": k})
    chain = text_rag_chain(retriever)
    response = chain.invoke({"messages": [HumanMessage(content=user_question)]})
    st.write(response["answer"])
|
|
def _render_login_form():
    """Render the login form and authenticate the user when it is submitted."""
    login_username = st.text_input("Username", key="login_username")
    login_password = st.text_input("Password", type="password", key="login_password")
    if not st.button("Login"):
        return

    conn = create_db_connection()
    if conn is None:
        # create_db_connection has already shown the error; nothing to check against.
        return
    try:
        st.session_state['user_id'] = login_user(conn, login_username, login_password)
    finally:
        conn.close()

    if st.session_state['user_id'] != -1:
        st.session_state['login_status'] = True
        st.session_state['username'] = login_username
        st.success(f"Welcome {login_username}!")
    else:
        st.error("Incorrect username/password")


def _render_register_form():
    """Render the registration form and create the account when it is submitted."""
    new_username = st.text_input("Username", key="register_username")
    new_password = st.text_input("Password", type="password", key="register_password")
    if not st.button("Register"):
        return

    conn = create_db_connection()
    if conn is None:
        return
    try:
        register_user(conn, new_username, new_password)
    finally:
        conn.close()


def _render_chat_page():
    """Render the question box and the PDF upload sidebar for a logged-in user."""
    st.set_page_config("Chat PDF")
    oceanbase = get_oceanbase()
    st.header("Chat with PDF")

    user_question = st.text_input("Ask a Question from the PDF Files")
    if user_question:
        run_pipeline(oceanbase, user_question)

    with st.sidebar:
        st.title("Menu:")
        pdf_docs = st.file_uploader("Upload PDF Files", accept_multiple_files=True)
        if st.button("Submit & Process"):
            with st.spinner("Processing..."):
                if pdf_docs:
                    texts = load_text_chunks(pdf_docs)
                    if texts:
                        oceanbase.add_texts(texts=texts)
                        st.success("Done")
                    else:
                        # e.g. image-only PDFs: nothing to embed.
                        st.warning("No extractable text was found in the uploaded PDF files.")


def main():
    """App entry point: show login/register until authenticated, then the chat page."""
    if st.session_state['login_status']:
        _render_chat_page()
        return

    st.title("Login to chat with OceanBase")
    tab = st.radio("Choose a tab:", ["Login", "Register"])
    if tab == "Login":
        _render_login_form()
    elif tab == "Register":
        _render_register_form()


if __name__ == "__main__":
    main()