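"""Streamlit front-end for a RAG-based e-commerce FAQ chatbot.

Pipeline: load and augment FAQ data, embed and index it, retrieve the FAQs
most relevant to a user query, and generate an answer with a local LLM,
with optional translation between English and the selected UI language.
"""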
import streamlit as st
import time
import os
import gc
import json
import torch
from googletrans import Translator
from src.data_processing import load_huggingface_faq_data, load_faq_data, preprocess_faq, augment_faqs, translate_faq
from src.embedding import FAQEmbedder
from src.llm_response import ResponseGenerator
from src.utils import time_function, format_memory_stats, evaluate_response, evaluate_retrieval, baseline_keyword_search

# Suppress CUDA warnings and Torch path-check errors (forces CPU execution)
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["TORCH_NO_PATH_CHECK"] = "1"
st.set_page_config(page_title="E-Commerce FAQ Chatbot", layout="wide", initial_sidebar_state="expanded")
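
# Build (or restore) the FAQ index and load the LLM once at startup.
# The @time_function decorator (from src.utils) times this call.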
@time_function
def initialize_components(use_huggingface: bool = True, model_name: str = "microsoft/phi-2", enable_augmentation: bool = True):
    """
    Initialize RAG system components: FAQ data, embedder, and response generator.
    Returns (embedder, response_generator, number_of_processed_faqs).
    """
    try:
        # Load FAQs from the Hugging Face Hub or from a local CSV fallback
        if use_huggingface:
            faqs = load_huggingface_faq_data("NebulaByte/E-Commerce_FAQs")
        else:
            faqs = load_faq_data("data/faq_data.csv")
        processed_faqs = augment_faqs(preprocess_faq(faqs), enable_augmentation=enable_augmentation)

        # Reuse cached embeddings when available; otherwise compute and persist them
        embedder = FAQEmbedder()
        if os.path.exists("embeddings"):
            embedder.load("embeddings")
        else:
            embedder.create_embeddings(processed_faqs)
            embedder.save("embeddings")

        # Free memory before loading the LLM
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Warm up the model so the first real query is not slowed by lazy initialization
        response_generator = ResponseGenerator(model_name=model_name)
        response_generator.generate_response("Warmup query", [{"question": "Test", "answer": "Test"}])
        return embedder, response_generator, len(processed_faqs)
    except Exception as e:
        st.error(f"Initialization failed: {e}")
        raise
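
# Streamlit reruns this script top to bottom on every interaction, so all heavy
# objects (embedder, LLM) are kept in st.session_state rather than rebuilt per run.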
def main():
    st.title("E-Commerce Customer Support FAQ Chatbot")
    st.subheader("Ask about orders, shipping, returns, or other e-commerce queries")

    # Sidebar configuration
    st.sidebar.title("Configuration")
    use_huggingface = st.sidebar.checkbox("Use Hugging Face Dataset", value=True)
    enable_augmentation = st.sidebar.checkbox("Enable FAQ Augmentation", value=True, help="Generate paraphrased questions to expand the dataset")
    target_lang = st.sidebar.selectbox("Language", ["en", "es", "fr"], index=0)
    model_options = {
        "Phi-2 (Recommended for 16GB RAM)": "microsoft/phi-2",
        "TinyLlama-1.1B (Fastest)": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        "Mistral-7B (For 15GB+ GPU)": "mistralai/Mistral-7B-Instruct-v0.1"
    }
    selected_model = st.sidebar.selectbox("Select LLM Model", list(model_options.keys()), index=0)
    model_name = model_options[selected_model]
    if st.sidebar.checkbox("Show Memory Usage", value=True):
        st.sidebar.subheader("Memory Usage")
        for key, value in format_memory_stats().items():
            st.sidebar.text(f"{key}: {value}")

    # Persistent state across reruns
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []
    if "query_cache" not in st.session_state:
        st.session_state.query_cache = {}
    if "feedback" not in st.session_state:
        st.session_state.feedback = []
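
    # Initialize (or explicitly reload) the RAG components once; later reruns reuse them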
if "system_initialized" not in st.session_state or st.sidebar.button("Reload System"):
with st.spinner("Initializing system..."):
try:
st.session_state.embedder, st.session_state.response_generator, num_faqs = initialize_components(
use_huggingface=use_huggingface,
model_name=model_name,
enable_augmentation=enable_augmentation
)
st.session_state.system_initialized = True
st.sidebar.success(f"System initialized with {num_faqs} FAQs!")
except Exception as e:
st.error(f"System initialization failed: {e}")
return
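
    # Two-column layout: chat on the left, retrieved FAQ context on the right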
    col1, col2 = st.columns([2, 1])
    with col1:
        st.subheader("Conversation")
        chat_container = st.container(height=400)
        with chat_container:
            for i, message in enumerate(st.session_state.chat_history):
                if message["role"] == "user":
                    st.markdown(f"**You**: {message['content']}")
                else:
                    st.markdown(f"**Bot**: {message['content']}")
                if i < len(st.session_state.chat_history) - 1:
                    st.markdown("---")
        with st.form(key="chat_form"):
            user_query = st.text_input("Type your question:", key="user_input", placeholder="e.g., How do I track my order?")
            submit_button = st.form_submit_button("Ask")
        if len(st.session_state.chat_history) > 0:
            with st.form(key=f"feedback_form_{len(st.session_state.chat_history)}"):
                rating = st.slider("Rate this response (1-5)", 1, 5, key=f"rating_{len(st.session_state.chat_history)}")
                comments = st.text_area("Comments", key=f"comments_{len(st.session_state.chat_history)}")
                if st.form_submit_button("Submit Feedback"):
                    st.session_state.feedback.append({
                        "rating": rating,
                        "comments": comments,
                        "response": st.session_state.chat_history[-1]["content"]
                    })
                    with open("feedback.json", "w") as f:
                        json.dump(st.session_state.feedback, f)
                    st.success("Feedback submitted!")
    with col2:
        if st.session_state.get("system_initialized", False):
            st.subheader("Retrieved Information")
            info_container = st.container(height=500)
            with info_container:
                if "current_faqs" in st.session_state:
                    for i, faq in enumerate(st.session_state.current_faqs):
                        st.markdown(f"**Relevant FAQ #{i+1}**")
                        st.markdown(f"**Q**: {faq['question']}")
                        st.markdown(f"**A**: {faq['answer'][:150]}..." if len(faq['answer']) > 150 else f"**A**: {faq['answer']}")
                        st.markdown(f"*Similarity Score*: {faq['similarity']:.2f}")
                        if 'category' in faq and faq['category']:
                            st.markdown(f"*Category*: {faq['category']}")
                        st.markdown("---")
                else:
                    st.markdown("Ask a question to see relevant FAQs.")
if "retrieval_time" in st.session_state and "generation_time" in st.session_state:
st.sidebar.subheader("Performance Metrics")
st.sidebar.markdown(f"Retrieval time: {st.session_state.retrieval_time:.2f} seconds")
st.sidebar.markdown(f"Response generation: {st.session_state.generation_time:.2f} seconds")
st.sidebar.markdown(f"Total time: {st.session_state.retrieval_time + st.session_state.generation_time:.2f} seconds")
    if submit_button and user_query:
        translator = Translator()
        if target_lang != "en":
            user_query_translated = translator.translate(user_query, dest="en").text
        else:
            user_query_translated = user_query
        # Remember the English-normalized query so the baseline comparison
        # below can use it on later reruns as well
        st.session_state.last_query_translated = user_query_translated
        if user_query_translated in st.session_state.query_cache:
            response, relevant_faqs = st.session_state.query_cache[user_query_translated]
        else:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            start_time = time.time()
            relevant_faqs = st.session_state.embedder.retrieve_relevant_faqs(user_query_translated)
            # Record timings only on cache misses, so cache hits don't hit unset variables
            st.session_state.retrieval_time = time.time() - start_time
            if target_lang != "en":
                relevant_faqs = [translate_faq(faq, target_lang) for faq in relevant_faqs]
            start_time = time.time()
            response = st.session_state.response_generator.generate_response(user_query_translated, relevant_faqs)
            st.session_state.generation_time = time.time() - start_time
            if target_lang != "en":
                response = translator.translate(response, dest=target_lang).text
            st.session_state.query_cache[user_query_translated] = (response, relevant_faqs)
        st.session_state.current_faqs = relevant_faqs
        st.session_state.chat_history.append({"role": "user", "content": user_query})
        st.session_state.chat_history.append({"role": "assistant", "content": response})
if st.button("Clear Chat History"):
st.session_state.chat_history = []
st.session_state.query_cache = {}
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
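
    # Side-by-side sanity check: RAG retrieval vs. a plain keyword-search baseline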
if st.session_state.get("system_initialized", False):
st.sidebar.subheader("Baseline Comparison")
baseline_faqs = baseline_keyword_search(user_query_translated if 'user_query_translated' in locals() else "", st.session_state.embedder.faqs)
st.sidebar.write(f"RAG FAQs: {[faq['question'][:50] for faq in st.session_state.get('current_faqs', [])]}")
st.sidebar.write(f"Keyword FAQs: {[faq['question'][:50] for faq in baseline_faqs]}")
st.subheader("Sample Questions")
sample_questions = [
"How do I track my order?",
"What should I do if my delivery is delayed?",
"How do I return a product?",
"Can I cancel my order after placing it?",
"How quickly will my order be delivered?"
]
cols = st.columns(2)
for i, question in enumerate(sample_questions):
col_idx = i % 2
if cols[col_idx].button(question, key=f"sample_{i}"):
st.session_state.user_input = question
st.session_state.chat_history.append({"role": "user", "content": question})
from src.data_processing import translate_faq
from googletrans import Translator
translator = Translator()
if target_lang != "en":
question_translated = translator.translate(question, dest="en").text
else:
question_translated = question
if question_translated in st.session_state.query_cache:
response, relevant_faqs = st.session_state.query_cache[question_translated]
else:
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
start_time = time.time()
relevant_faqs = st.session_state.embedder.retrieve_relevant_faqs(question_translated)
retrieval_time = time.time() - start_time
if target_lang != "en":
relevant_faqs = [translate_faq(faq, target_lang) for faq in relevant_faqs]
start_time = time.time()
response = st.session_state.response_generator.generate_response(question_translated, relevant_faqs)
generation_time = time.time() - start_time
if target_lang != "en":
response = translator.translate(response, dest=target_lang).text
st.session_state.query_cache[question_translated] = (response, relevant_faqs)
st.session_state.retrieval_time = retrieval_time
st.session_state.generation_time = generation_time
st.session_state.current_faqs = relevant_faqs
st.session_state.chat_history.append({"role": "assistant", "content": response})
if __name__ == "__main__":
    main()