# NewRAG-chatbot / app.py
# Author: ziphai
# Last commit: d6850af ("Update app.py", verified)
import gradio as gr
import os
from langchain.chains import ConversationalRetrievalChain
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader, Docx2txtLoader, TextLoader
from langchain_community.vectorstores import Chroma
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from dotenv import load_dotenv
# Load environment variables (e.g. from a local .env file) into os.environ.
load_dotenv()
# Validate OpenAI API Key.
# This check must run before anything writes the key back into os.environ:
# assigning None to os.environ raises TypeError, which used to pre-empt this
# clear error message whenever the variable was missing. No explicit
# re-export is needed because load_dotenv() already populates os.environ.
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    raise ValueError("Please set the 'OPENAI_API_KEY' environment variable")
# OpenAI API key
openai_api_key = api_key
def transform_history_for_langchain(history):
    """Convert Gradio chat history into LangChain ``(human, ai)`` tuples.

    Each entry of *history* is indexed as ``[0]`` = user text and
    ``[1]`` = bot text. Entries whose user text is falsy (e.g. ``""``)
    are dropped.

    Returns:
        A new list of ``(user_text, bot_text)`` tuples.
    """
    pairs = []
    for turn in history:
        user_text = turn[0]
        if not user_text:
            continue
        pairs.append((user_text, turn[1]))
    return pairs
def transform_history_for_openai(history):
    """Convert Gradio chat history into OpenAI chat-completion messages.

    Each entry of *history* is indexed as ``[0]`` = user text and
    ``[1]`` = bot text. Every non-empty side becomes one message dict
    (``{"role": ..., "content": ...}``), user before assistant, so a turn
    may contribute zero, one, or two messages.
    """
    messages = []
    for turn in history:
        for role, text in (("user", turn[0]), ("assistant", turn[1])):
            if text:
                messages.append({"role": role, "content": text})
    return messages
# Load and process documents function
# Lower-cased file extension -> LangChain loader class used to read it.
# Legacy binary ".doc" files are deliberately not listed: Docx2txtLoader is
# built on docx2txt, which only reads the zip-based .docx format, so routing a
# .doc file to it made the whole startup fail.
_LOADERS_BY_EXTENSION = {
    ".pdf": PyPDFLoader,
    ".docx": Docx2txtLoader,
    ".txt": TextLoader,
}


def load_and_process_documents(
    folder_path,
    *,
    persist_directory="./tmp",
    chunk_size=1000,
    chunk_overlap=10,
):
    """Load the supported documents in *folder_path* and index them in Chroma.

    Only the top level of *folder_path* is scanned (no recursion). Files are
    matched by extension case-insensitively (``.pdf``, ``.docx``, ``.txt``).
    Sub-directories and all other files are skipped. The loaded documents are
    split with ``CharacterTextSplitter`` and embedded with ``OpenAIEmbeddings``.

    Args:
        folder_path: Directory to scan.
        persist_directory: Where Chroma persists the index (default ``"./tmp"``).
        chunk_size: Chunk size passed to the text splitter.
        chunk_overlap: Overlap between consecutive chunks.

    Returns:
        The ``Chroma`` vector store built from the document chunks.
    """
    documents = []
    # sorted() makes the load (and therefore indexing) order deterministic.
    for file_name in sorted(os.listdir(folder_path)):
        file_path = os.path.join(folder_path, file_name)
        if not os.path.isfile(file_path):
            continue  # e.g. a directory whose name happens to end in ".txt"
        extension = os.path.splitext(file_name)[1].lower()
        loader_cls = _LOADERS_BY_EXTENSION.get(extension)
        if loader_cls is not None:
            documents.extend(loader_cls(file_path).load())
    text_splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    chunks = text_splitter.split_documents(documents)
    # NOTE(review): if persist_directory already holds an index from a previous
    # run, Chroma presumably appends these chunks to the existing collection and
    # duplicates them; confirm, and clear the directory first if so.
    return Chroma.from_documents(
        chunks,
        embedding=OpenAIEmbeddings(),
        persist_directory=persist_directory,
    )
# Build the module-level vector store exactly once. The globals() guard skips
# the expensive rebuild when this code runs again in a namespace that already
# holds `vectordb` (presumably for hot-reload tooling; TODO confirm).
# NOTE(review): "./" is the app directory itself, so any .txt file there (for
# example a requirements.txt, if present) is indexed with the real documents.
if "vectordb" not in globals():
    vectordb = load_and_process_documents(folder_path="./")
# Define query handling function for RAG
def handle_query(user_message, temperature, chat_history):
try:
if not user_message:
return chat_history # Return unchanged chat history
# Use LangChain's ConversationalRetrievalChain to handle the query
preface = """
Instruction: Answer in Traditional Chinese, within 200 characters.這是AI論壇,只回答AI相關問題
If the question is unrelated to the documents, respond with: 此事無可奉告,話說這件事須請教海虔王...
"""
query = f"{preface} Query content: {user_message}"
# Extract previous answers as context, converting them to LangChain format
previous_answers = transform_history_for_langchain(chat_history)
pdf_qa = ConversationalRetrievalChain.from_llm(
ChatOpenAI(temperature=temperature, model_name='gpt-4'),
retriever=vectordb.as_retriever(search_kwargs={'k': 6}),
return_source_documents=True,
verbose=False
)
# Invoke the model to handle the query
result = pdf_qa.invoke({"question": query, "chat_history": previous_answers})
# Ensure 'answer' is present in the result
if "answer" not in result:
return chat_history + [("System", "Sorry, an error occurred.")]
# Update the AI response in chat history
chat_history[-1] = (user_message, result["answer"]) # Update the last record, pairing user input with AI response
return chat_history
except Exception as e:
return chat_history + [("System", f"An error occurred: {str(e)}")]
# Create a custom chat interface using Gradio Blocks API
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>AI Assistant for AI Forum</h1>")
    chatbot = gr.Chatbot()
    # Per-session chat history: a list of (user_message, bot_answer) tuples.
    state = gr.State([])
    with gr.Row():
        # Gradio's `scale` is an integer relative width; 17:3 keeps the
        # original 85% / 15% split between textbox and button.
        with gr.Column(scale=17):
            txt = gr.Textbox(show_label=False, placeholder="Please enter your question...")
        with gr.Column(scale=3, min_width=0):
            submit_btn = gr.Button("Ask")

    def user_input(user_message, history):
        """Echo the user's message immediately and clear the textbox.

        A pending ``(user_message, "")`` entry is appended for bot_response
        to fill in. Blank or whitespace-only input is ignored, so no empty,
        never-answered bubble is left behind.

        Returns:
            ``(chatbot value, cleared textbox value, new state)``.
        """
        if user_message and user_message.strip():
            history.append((user_message, ""))
        return history, "", history

    def bot_response(history):
        """Answer the pending last turn, if any, and refresh chatbot + state."""
        # Nothing to do when the history is empty or its last turn already
        # has an answer (e.g. the user submitted blank input).
        if not history or history[-1][1]:
            return history, history
        user_message = history[-1][0]
        history = handle_query(user_message, 0.7, history)
        return history, history

    # Clicking "Ask" and pressing Enter behave identically. First show the
    # user's message without a queue (so it appears instantly), then generate
    # the answer.
    for trigger in (submit_btn.click, txt.submit):
        trigger(user_input, [txt, state], [chatbot, txt, state], queue=False).then(
            bot_response, state, [chatbot, state]
        )
# Launch the Gradio app only when this file is run as a script
# (`python app.py`), so that importing the module does not start a server.
if __name__ == "__main__":
    demo.launch()