# faq-rag-chatbot/app.py
import streamlit as st
import time
import os
import gc
import json
import torch
from src.data_processing import load_huggingface_faq_data, load_faq_data, preprocess_faq, augment_faqs, translate_faq
from src.embedding import FAQEmbedder
from src.llm_response import ResponseGenerator
from src.utils import time_function, format_memory_stats, evaluate_response, evaluate_retrieval, baseline_keyword_search
from deep_translator import GoogleTranslator  # replaces googletrans
# Suppress CUDA warning and Torch path errors
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["TORCH_NO_PATH_CHECK"] = "1"
st.set_page_config(page_title="E-Commerce FAQ Chatbot", layout="wide", initial_sidebar_state="expanded")
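
# Initialization pipeline: load FAQs (Hugging Face hub or local CSV), preprocess
# and optionally augment them, build or reload the vector index from the
# "embeddings" directory, then load the LLM and run one warmup generation so the
# first user query does not absorb the model's cold-start latency.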
@time_function
def initialize_components(use_huggingface: bool = True, model_name: str = "microsoft/phi-2", enable_augmentation: bool = True):
    """Initialize RAG system components."""
    try:
        if use_huggingface:
            faqs = load_huggingface_faq_data("NebulaByte/E-Commerce_FAQs")
        else:
            faqs = load_faq_data("data/faq_data.csv")
        processed_faqs = augment_faqs(preprocess_faq(faqs), enable_augmentation=enable_augmentation)
        embedder = FAQEmbedder()
        if os.path.exists("embeddings"):
            embedder.load("embeddings")
        else:
            embedder.create_embeddings(processed_faqs)
            embedder.save("embeddings")
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        response_generator = ResponseGenerator(model_name=model_name)
        response_generator.generate_response("Warmup query", [{"question": "Test", "answer": "Test"}])
        return embedder, response_generator, len(processed_faqs)
    except Exception as e:
        st.error(f"Initialization failed: {e}")
        raise
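
# NOTE: the on-disk "embeddings" cache is reused whenever the directory exists,
# even if the dataset source or augmentation setting changed since it was built;
# delete the "embeddings" directory to force a rebuild. (This assumes
# FAQEmbedder.save/load persist the full index under that path.)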
def main():
    st.title("E-Commerce Customer Support FAQ Chatbot")
    st.subheader("Ask about orders, shipping, returns, or other e-commerce queries")
    st.sidebar.title("Configuration")
    use_huggingface = st.sidebar.checkbox("Use Hugging Face Dataset", value=True)
    enable_augmentation = st.sidebar.checkbox("Enable FAQ Augmentation", value=True, help="Generate paraphrased questions to expand dataset")
    target_lang = st.sidebar.selectbox("Language", ["en", "es", "fr"], index=0)
    model_options = {
        "Phi-2 (Recommended for 16GB RAM)": "microsoft/phi-2",
        "TinyLlama-1.1B (Fastest)": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        "Mistral-7B (For 15GB+ GPU)": "mistralai/Mistral-7B-Instruct-v0.1"
    }
    selected_model = st.sidebar.selectbox("Select LLM Model", list(model_options.keys()), index=0)
    model_name = model_options[selected_model]
    if st.sidebar.checkbox("Show Memory Usage", value=True):
        st.sidebar.subheader("Memory Usage")
        for key, value in format_memory_stats().items():
            st.sidebar.text(f"{key}: {value}")
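
    # Session state survives Streamlit reruns: the chat transcript, a
    # query -> (response, FAQs) cache, and collected feedback are each
    # initialized once per session.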
if "chat_history" not in st.session_state:
st.session_state.chat_history = []
if "query_cache" not in st.session_state:
st.session_state.query_cache = {}
if "feedback" not in st.session_state:
st.session_state.feedback = []
if "system_initialized" not in st.session_state or st.sidebar.button("Reload System"):
with st.spinner("Initializing system..."):
try:
st.session_state.embedder, st.session_state.response_generator, num_faqs = initialize_components(
use_huggingface=use_huggingface,
model_name=model_name,
enable_augmentation=enable_augmentation
)
st.session_state.system_initialized = True
st.sidebar.success(f"System initialized with {num_faqs} FAQs!")
except Exception as e:
st.error(f"System initialization failed: {e}")
return
col1, col2 = st.columns([2, 1])
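
    # Two-column layout: chat transcript and input on the left, retrieved FAQ
    # context for the latest answer on the right.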
    with col1:
        st.subheader("Conversation")
        chat_container = st.container(height=400)
        with chat_container:
            for i, message in enumerate(st.session_state.chat_history):
                if message["role"] == "user":
                    st.markdown(f"**You**: {message['content']}")
                else:
                    st.markdown(f"**Bot**: {message['content']}")
                if i < len(st.session_state.chat_history) - 1:
                    st.markdown("---")
        with st.form(key="chat_form"):
            user_query = st.text_input("Type your question:", key="user_input", placeholder="e.g., How do I track my order?")
            submit_button = st.form_submit_button("Ask")
        if len(st.session_state.chat_history) > 0:
            with st.form(key=f"feedback_form_{len(st.session_state.chat_history)}"):
                rating = st.slider("Rate this response (1-5)", 1, 5, key=f"rating_{len(st.session_state.chat_history)}")
                comments = st.text_area("Comments", key=f"comments_{len(st.session_state.chat_history)}")
                if st.form_submit_button("Submit Feedback"):
                    st.session_state.feedback.append({
                        "rating": rating,
                        "comments": comments,
                        "response": st.session_state.chat_history[-1]["content"]
                    })
                    with open("feedback.json", "w") as f:
                        json.dump(st.session_state.feedback, f)
                    st.success("Feedback submitted!")
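
        # Feedback is held in session state and rewritten to feedback.json in
        # full on every submit, so the file always mirrors the current session.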
    with col2:
        if st.session_state.get("system_initialized", False):
            st.subheader("Retrieved Information")
            info_container = st.container(height=500)
            with info_container:
                if "current_faqs" in st.session_state:
                    for i, faq in enumerate(st.session_state.current_faqs):
                        st.markdown(f"**Relevant FAQ #{i+1}**")
                        st.markdown(f"**Q**: {faq['question']}")
                        if len(faq['answer']) > 150:
                            st.markdown(f"**A**: {faq['answer'][:150]}...")
                        else:
                            st.markdown(f"**A**: {faq['answer']}")
                        st.markdown(f"*Similarity Score*: {faq['similarity']:.2f}")
                        if faq.get('category'):
                            st.markdown(f"*Category*: {faq['category']}")
                        st.markdown("---")
                else:
                    st.markdown("Ask a question to see relevant FAQs.")
if "retrieval_time" in st.session_state and "generation_time" in st.session_state:
st.sidebar.subheader("Performance Metrics")
st.sidebar.markdown(f"Retrieval time: {st.session_state.retrieval_time:.2f} seconds")
st.sidebar.markdown(f"Response generation: {st.session_state.generation_time:.2f} seconds")
st.sidebar.markdown(f"Total time: {st.session_state.retrieval_time + st.session_state.generation_time:.2f} seconds")
    if submit_button and user_query:
        # Normalize the query to English for retrieval; the embeddings are built
        # from English FAQ text
        translator = GoogleTranslator(source='auto', target='en')
        if target_lang != "en":
            user_query_translated = translator.translate(user_query)
        else:
            user_query_translated = user_query
        if user_query_translated in st.session_state.query_cache:
            response, relevant_faqs = st.session_state.query_cache[user_query_translated]
        else:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            start_time = time.time()
            relevant_faqs = st.session_state.embedder.retrieve_relevant_faqs(user_query_translated)
            retrieval_time = time.time() - start_time
            if target_lang != "en":
                relevant_faqs = [translate_faq(faq, target_lang) for faq in relevant_faqs]
            start_time = time.time()
            response = st.session_state.response_generator.generate_response(user_query_translated, relevant_faqs)
            generation_time = time.time() - start_time
            if target_lang != "en":
                # deep_translator takes the target language in the constructor,
                # not as a keyword to translate()
                response = GoogleTranslator(source='en', target=target_lang).translate(response)
            st.session_state.query_cache[user_query_translated] = (response, relevant_faqs)
            st.session_state.retrieval_time = retrieval_time
            st.session_state.generation_time = generation_time
        st.session_state.current_faqs = relevant_faqs
        st.session_state.chat_history.append({"role": "user", "content": user_query})
        st.session_state.chat_history.append({"role": "assistant", "content": response})
if st.button("Clear Chat History"):
st.session_state.chat_history = []
st.session_state.query_cache = {}
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
    if st.session_state.get("system_initialized", False):
        st.sidebar.subheader("Baseline Comparison")
        baseline_query = user_query_translated if 'user_query_translated' in locals() else ""
        baseline_faqs = baseline_keyword_search(baseline_query, st.session_state.embedder.faqs)
        st.sidebar.write(f"RAG FAQs: {[faq['question'][:50] for faq in st.session_state.get('current_faqs', [])]}")
        st.sidebar.write(f"Keyword FAQs: {[faq['question'][:50] for faq in baseline_faqs]}")
st.subheader("Sample Questions")
sample_questions = [
"How do I track my order?",
"What should I do if my delivery is delayed?",
"How do I return a product?",
"Can I cancel my order after placing it?",
"How quickly will my order be delivered?"
]
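
    # Clicking a sample question runs the same translate -> retrieve -> generate
    # path as the form submit above.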
    cols = st.columns(2)
    for i, question in enumerate(sample_questions):
        col_idx = i % 2
        if cols[col_idx].button(question, key=f"sample_{i}"):
            # Do not write to st.session_state.user_input here: Streamlit raises
            # an exception when a widget's keyed state is modified after the
            # widget has been instantiated in the same run.
            st.session_state.chat_history.append({"role": "user", "content": question})
            translator = GoogleTranslator(source='auto', target='en')
            if target_lang != "en":
                question_translated = translator.translate(question)
            else:
                question_translated = question
            if question_translated in st.session_state.query_cache:
                response, relevant_faqs = st.session_state.query_cache[question_translated]
            else:
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                start_time = time.time()
                relevant_faqs = st.session_state.embedder.retrieve_relevant_faqs(question_translated)
                retrieval_time = time.time() - start_time
                if target_lang != "en":
                    relevant_faqs = [translate_faq(faq, target_lang) for faq in relevant_faqs]
                start_time = time.time()
                response = st.session_state.response_generator.generate_response(question_translated, relevant_faqs)
                generation_time = time.time() - start_time
                if target_lang != "en":
                    response = GoogleTranslator(source='en', target=target_lang).translate(response)
                st.session_state.query_cache[question_translated] = (response, relevant_faqs)
                st.session_state.retrieval_time = retrieval_time
                st.session_state.generation_time = generation_time
            st.session_state.current_faqs = relevant_faqs
            st.session_state.chat_history.append({"role": "assistant", "content": response})
if __name__ == "__main__":
    main()