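"""Streamlit front-end for a RAG-based e-commerce FAQ chatbot.

Pipeline: load FAQs (Hugging Face dataset or local CSV) -> embed and index
them -> retrieve the most similar FAQs for a user query -> generate an answer
with a local LLM, with optional translation to and from English via
deep_translator.
"""
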
import streamlit as st
import time
import os
import gc
import json
import torch
from src.data_processing import load_huggingface_faq_data, load_faq_data, preprocess_faq, augment_faqs, translate_faq
from src.embedding import FAQEmbedder
from src.llm_response import ResponseGenerator
from src.utils import time_function, format_memory_stats, evaluate_response, evaluate_retrieval, baseline_keyword_search
from deep_translator import GoogleTranslator

# Force CPU execution and suppress Torch path-check errors
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["TORCH_NO_PATH_CHECK"] = "1"

st.set_page_config(page_title="E-Commerce FAQ Chatbot", layout="wide", initial_sidebar_state="expanded")
def initialize_components(use_huggingface: bool = True, model_name: str = "microsoft/phi-2", enable_augmentation: bool = True):
    """
    Initialize the RAG system components: load and preprocess the FAQ data,
    build (or load) the embedding index, and warm up the response model.
    Returns (embedder, response_generator, number_of_faqs).
    """
    try:
        if use_huggingface:
            faqs = load_huggingface_faq_data("NebulaByte/E-Commerce_FAQs")
        else:
            faqs = load_faq_data("data/faq_data.csv")
        processed_faqs = augment_faqs(preprocess_faq(faqs), enable_augmentation=enable_augmentation)
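        # Persist the vector index on disk: loading saved embeddings is much
        # cheaper than re-encoding every FAQ on each cold start.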
        embedder = FAQEmbedder()
        if os.path.exists("embeddings"):
            embedder.load("embeddings")
        else:
            embedder.create_embeddings(processed_faqs)
            embedder.save("embeddings")
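        # Release whatever memory the embedding step held before the LLM loads.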
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        response_generator = ResponseGenerator(model_name=model_name)
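        # Throwaway generation so model loading/compilation cost is paid here,
        # not on the first real user query.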
        response_generator.generate_response("Warmup query", [{"question": "Test", "answer": "Test"}])
        return embedder, response_generator, len(processed_faqs)
    except Exception as e:
        st.error(f"Initialization failed: {e}")
        raise

def main():
    st.title("E-Commerce Customer Support FAQ Chatbot")
    st.subheader("Ask about orders, shipping, returns, or other e-commerce queries")

    st.sidebar.title("Configuration")
    use_huggingface = st.sidebar.checkbox("Use Hugging Face Dataset", value=True)
    enable_augmentation = st.sidebar.checkbox("Enable FAQ Augmentation", value=True, help="Generate paraphrased questions to expand dataset")
    target_lang = st.sidebar.selectbox("Language", ["en", "es", "fr"], index=0)

    model_options = {
        "Phi-2 (Recommended for 16GB RAM)": "microsoft/phi-2",
        "TinyLlama-1.1B (Fastest)": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        "Mistral-7B (For 15GB+ GPU)": "mistralai/Mistral-7B-Instruct-v0.1"
    }
    selected_model = st.sidebar.selectbox("Select LLM Model", list(model_options.keys()), index=0)
    model_name = model_options[selected_model]

    if st.sidebar.checkbox("Show Memory Usage", value=True):
        st.sidebar.subheader("Memory Usage")
        for key, value in format_memory_stats().items():
            st.sidebar.text(f"{key}: {value}")
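    # Streamlit reruns this script top to bottom on every interaction;
    # st.session_state is what survives between reruns.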
if "chat_history" not in st.session_state: | |
st.session_state.chat_history = [] | |
if "query_cache" not in st.session_state: | |
st.session_state.query_cache = {} | |
if "feedback" not in st.session_state: | |
st.session_state.feedback = [] | |
if "system_initialized" not in st.session_state or st.sidebar.button("Reload System"): | |
with st.spinner("Initializing system..."): | |
try: | |
st.session_state.embedder, st.session_state.response_generator, num_faqs = initialize_components( | |
use_huggingface=use_huggingface, | |
model_name=model_name, | |
enable_augmentation=enable_augmentation | |
) | |
st.session_state.system_initialized = True | |
st.sidebar.success(f"System initialized with {num_faqs} FAQs!") | |
except Exception as e: | |
st.error(f"System initialization failed: {e}") | |
return | |
    col1, col2 = st.columns([2, 1])

    with col1:
        st.subheader("Conversation")
        chat_container = st.container(height=400)
        with chat_container:
            for i, message in enumerate(st.session_state.chat_history):
                if message["role"] == "user":
                    st.markdown(f"**You**: {message['content']}")
                else:
                    st.markdown(f"**Bot**: {message['content']}")
                if i < len(st.session_state.chat_history) - 1:
                    st.markdown("---")
        with st.form(key="chat_form"):
            user_query = st.text_input("Type your question:", key="user_input", placeholder="e.g., How do I track my order?")
            submit_button = st.form_submit_button("Ask")
        if len(st.session_state.chat_history) > 0:
            with st.form(key=f"feedback_form_{len(st.session_state.chat_history)}"):
                rating = st.slider("Rate this response (1-5)", 1, 5, key=f"rating_{len(st.session_state.chat_history)}")
                comments = st.text_area("Comments", key=f"comments_{len(st.session_state.chat_history)}")
                if st.form_submit_button("Submit Feedback"):
                    st.session_state.feedback.append({
                        "rating": rating,
                        "comments": comments,
                        "response": st.session_state.chat_history[-1]["content"]
                    })
                    with open("feedback.json", "w") as f:
                        json.dump(st.session_state.feedback, f)
                    st.success("Feedback submitted!")
    with col2:
        if st.session_state.get("system_initialized", False):
            st.subheader("Retrieved Information")
            info_container = st.container(height=500)
            with info_container:
                if "current_faqs" in st.session_state:
                    for i, faq in enumerate(st.session_state.current_faqs):
                        st.markdown(f"**Relevant FAQ #{i+1}**")
                        st.markdown(f"**Q**: {faq['question']}")
                        st.markdown(f"**A**: {faq['answer'][:150]}..." if len(faq['answer']) > 150 else f"**A**: {faq['answer']}")
                        st.markdown(f"*Similarity Score*: {faq['similarity']:.2f}")
                        if 'category' in faq and faq['category']:
                            st.markdown(f"*Category*: {faq['category']}")
                        st.markdown("---")
                else:
                    st.markdown("Ask a question to see relevant FAQs.")
if "retrieval_time" in st.session_state and "generation_time" in st.session_state: | |
st.sidebar.subheader("Performance Metrics") | |
st.sidebar.markdown(f"Retrieval time: {st.session_state.retrieval_time:.2f} seconds") | |
st.sidebar.markdown(f"Response generation: {st.session_state.generation_time:.2f} seconds") | |
st.sidebar.markdown(f"Total time: {st.session_state.retrieval_time + st.session_state.generation_time:.2f} seconds") | |
    if submit_button and user_query:
        if target_lang != "en":
            user_query_translated = GoogleTranslator(source='auto', target='en').translate(user_query)
        else:
            user_query_translated = user_query
        st.session_state.last_query = user_query_translated
        # Cache per (query, language) so switching languages never serves a
        # response in the wrong language
        cache_key = (user_query_translated, target_lang)
        if cache_key in st.session_state.query_cache:
            response, relevant_faqs = st.session_state.query_cache[cache_key]
        else:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            start_time = time.time()
            relevant_faqs = st.session_state.embedder.retrieve_relevant_faqs(user_query_translated)
            retrieval_time = time.time() - start_time
            if target_lang != "en":
                relevant_faqs = [translate_faq(faq, target_lang) for faq in relevant_faqs]
            start_time = time.time()
            response = st.session_state.response_generator.generate_response(user_query_translated, relevant_faqs)
            generation_time = time.time() - start_time
            if target_lang != "en":
                # deep_translator fixes source/target at construction time, so
                # build a second translator for the English -> target direction
                response = GoogleTranslator(source='en', target=target_lang).translate(response)
            st.session_state.query_cache[cache_key] = (response, relevant_faqs)
            # Only record timings on a cache miss; cached answers skip both steps
            st.session_state.retrieval_time = retrieval_time
            st.session_state.generation_time = generation_time
        st.session_state.current_faqs = relevant_faqs
        st.session_state.chat_history.append({"role": "user", "content": user_query})
        st.session_state.chat_history.append({"role": "assistant", "content": response})
if st.button("Clear Chat History"): | |
st.session_state.chat_history = [] | |
st.session_state.query_cache = {} | |
gc.collect() | |
if torch.cuda.is_available(): | |
torch.cuda.empty_cache() | |
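    # Sanity check in the sidebar: compare what embedding retrieval picked
    # against a plain keyword-match baseline for the most recent query.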
if st.session_state.get("system_initialized", False): | |
st.sidebar.subheader("Baseline Comparison") | |
baseline_faqs = baseline_keyword_search(user_query_translated if 'user_query_translated' in locals() else "", st.session_state.embedder.faqs) | |
st.sidebar.write(f"RAG FAQs: {[faq['question'][:50] for faq in st.session_state.get('current_faqs', [])]}") | |
st.sidebar.write(f"Keyword FAQs: {[faq['question'][:50] for faq in baseline_faqs]}") | |
st.subheader("Sample Questions") | |
sample_questions = [ | |
"How do I track my order?", | |
"What should I do if my delivery is delayed?", | |
"How do I return a product?", | |
"Can I cancel my order after placing it?", | |
"How quickly will my order be delivered?" | |
] | |
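    # Each sample-question button feeds the same translate -> retrieve ->
    # generate pipeline as the chat form above.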
    cols = st.columns(2)
    for i, question in enumerate(sample_questions):
        col_idx = i % 2
        if cols[col_idx].button(question, key=f"sample_{i}"):
            # Note: st.session_state.user_input cannot be assigned here, since
            # the text_input widget with that key was already instantiated
            st.session_state.chat_history.append({"role": "user", "content": question})
            if target_lang != "en":
                question_translated = GoogleTranslator(source='auto', target='en').translate(question)
            else:
                question_translated = question
            st.session_state.last_query = question_translated
            cache_key = (question_translated, target_lang)
            if cache_key in st.session_state.query_cache:
                response, relevant_faqs = st.session_state.query_cache[cache_key]
            else:
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                start_time = time.time()
                relevant_faqs = st.session_state.embedder.retrieve_relevant_faqs(question_translated)
                retrieval_time = time.time() - start_time
                if target_lang != "en":
                    relevant_faqs = [translate_faq(faq, target_lang) for faq in relevant_faqs]
                start_time = time.time()
                response = st.session_state.response_generator.generate_response(question_translated, relevant_faqs)
                generation_time = time.time() - start_time
                if target_lang != "en":
                    response = GoogleTranslator(source='en', target=target_lang).translate(response)
                st.session_state.query_cache[cache_key] = (response, relevant_faqs)
                st.session_state.retrieval_time = retrieval_time
                st.session_state.generation_time = generation_time
            st.session_state.current_faqs = relevant_faqs
            st.session_state.chat_history.append({"role": "assistant", "content": response})

if __name__ == "__main__":
    main()
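# Launch locally with: streamlit run app.py
# (assumes this file is saved as app.py and the src/ package is importable)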