import streamlit as st
import time
import os
import gc
import json
import torch
from src.data_processing import load_huggingface_faq_data, load_faq_data, preprocess_faq, augment_faqs
from src.embedding import FAQEmbedder
from src.llm_response import ResponseGenerator
from src.utils import time_function, format_memory_stats, evaluate_response, evaluate_retrieval, baseline_keyword_search

# Hide GPUs (force CPU execution) and suppress Torch path-check errors
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["TORCH_NO_PATH_CHECK"] = "1"

st.set_page_config(page_title="E-Commerce FAQ Chatbot", layout="wide", initial_sidebar_state="expanded")

@time_function
def initialize_components(use_huggingface: bool = True, model_name: str = "microsoft/phi-2", enable_augmentation: bool = True):
    """
    Initialize RAG system components
    """
    try:
        if use_huggingface:
            faqs = load_huggingface_faq_data("NebulaByte/E-Commerce_FAQs")
        else:
            faqs = load_faq_data("data/faq_data.csv")
        
        processed_faqs = augment_faqs(preprocess_faq(faqs), enable_augmentation=enable_augmentation)
        embedder = FAQEmbedder()
        
        if os.path.exists("embeddings"):
            embedder.load("embeddings")
        else:
            embedder.create_embeddings(processed_faqs)
            embedder.save("embeddings")
        
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        
        response_generator = ResponseGenerator(model_name=model_name)
        response_generator.generate_response("Warmup query", [{"question": "Test", "answer": "Test"}])
        
        return embedder, response_generator, len(processed_faqs)
    except Exception as e:
        st.error(f"Initialization failed: {e}")
        raise

def main():
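    """Streamlit UI: sidebar configuration, chat interface, retrieved-FAQ panel, and metrics."""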
    st.title("E-Commerce Customer Support FAQ Chatbot")
    st.subheader("Ask about orders, shipping, returns, or other e-commerce queries")
    
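    # Sidebar controls: dataset source, augmentation, language, and model selection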
    st.sidebar.title("Configuration")
    use_huggingface = st.sidebar.checkbox("Use Hugging Face Dataset", value=True)
    enable_augmentation = st.sidebar.checkbox("Enable FAQ Augmentation", value=True, help="Generate paraphrased questions to expand dataset")
    target_lang = st.sidebar.selectbox("Language", ["en", "es", "fr"], index=0)
    
    model_options = {
        "Phi-2 (Recommended for 16GB RAM)": "microsoft/phi-2",
        "TinyLlama-1.1B (Fastest)": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        "Mistral-7B (For 15GB+ GPU)": "mistralai/Mistral-7B-Instruct-v0.1"
    }
    selected_model = st.sidebar.selectbox("Select LLM Model", list(model_options.keys()), index=0)
    model_name = model_options[selected_model]
    
    if st.sidebar.checkbox("Show Memory Usage", value=True):
        st.sidebar.subheader("Memory Usage")
        for key, value in format_memory_stats().items():
            st.sidebar.text(f"{key}: {value}")
    
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []
    if "query_cache" not in st.session_state:
        st.session_state.query_cache = {}
    if "feedback" not in st.session_state:
        st.session_state.feedback = []
    
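    # Initialize (or reload) the embedder and response generator once per session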
    if "system_initialized" not in st.session_state or st.sidebar.button("Reload System"):
        with st.spinner("Initializing system..."):
            try:
                st.session_state.embedder, st.session_state.response_generator, num_faqs = initialize_components(
                    use_huggingface=use_huggingface,
                    model_name=model_name,
                    enable_augmentation=enable_augmentation
                )
                st.session_state.system_initialized = True
                st.sidebar.success(f"System initialized with {num_faqs} FAQs!")
            except Exception as e:
                st.error(f"System initialization failed: {e}")
                return
    
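    # Two-column layout: chat and feedback on the left, retrieved FAQs on the right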
    col1, col2 = st.columns([2, 1])
    
    with col1:
        st.subheader("Conversation")
        chat_container = st.container(height=400)
        with chat_container:
            for i, message in enumerate(st.session_state.chat_history):
                if message["role"] == "user":
                    st.markdown(f"**You**: {message['content']}")
                else:
                    st.markdown(f"**Bot**: {message['content']}")
                if i < len(st.session_state.chat_history) - 1:
                    st.markdown("---")
        
        with st.form(key="chat_form"):
            user_query = st.text_input("Type your question:", key="user_input", placeholder="e.g., How do I track my order?")
            submit_button = st.form_submit_button("Ask")
        
        if len(st.session_state.chat_history) > 0:
            with st.form(key=f"feedback_form_{len(st.session_state.chat_history)}"):
                rating = st.slider("Rate this response (1-5)", 1, 5, key=f"rating_{len(st.session_state.chat_history)}")
                comments = st.text_area("Comments", key=f"comments_{len(st.session_state.chat_history)}")
                if st.form_submit_button("Submit Feedback"):
                    st.session_state.feedback.append({
                        "rating": rating,
                        "comments": comments,
                        "response": st.session_state.chat_history[-1]["content"]
                    })
                    with open("feedback.json", "w") as f:
                        json.dump(st.session_state.feedback, f)
                    st.success("Feedback submitted!")
    
    with col2:
        if st.session_state.get("system_initialized", False):
            st.subheader("Retrieved Information")
            info_container = st.container(height=500)
            with info_container:
                if "current_faqs" in st.session_state:
                    for i, faq in enumerate(st.session_state.current_faqs):
                        st.markdown(f"**Relevant FAQ #{i+1}**")
                        st.markdown(f"**Q**: {faq['question']}")
                        st.markdown(f"**A**: {faq['answer'][:150]}..." if len(faq['answer']) > 150 else f"**A**: {faq['answer']}")
                        st.markdown(f"*Similarity Score*: {faq['similarity']:.2f}")
                        if 'category' in faq and faq['category']:
                            st.markdown(f"*Category*: {faq['category']}")
                        st.markdown("---")
                else:
                    st.markdown("Ask a question to see relevant FAQs.")
    
    if "retrieval_time" in st.session_state and "generation_time" in st.session_state:
        st.sidebar.subheader("Performance Metrics")
        st.sidebar.markdown(f"Retrieval time: {st.session_state.retrieval_time:.2f} seconds")
        st.sidebar.markdown(f"Response generation: {st.session_state.generation_time:.2f} seconds")
        st.sidebar.markdown(f"Total time: {st.session_state.retrieval_time + st.session_state.generation_time:.2f} seconds")
    
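    # Handle a submitted query: translate to English if needed, check the cache,
    # then run retrieval and generation, recording timings for the sidebar metrics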
    if submit_button and user_query:
        from src.data_processing import translate_faq
        from googletrans import Translator
        translator = Translator()
        if target_lang != "en":
            user_query_translated = translator.translate(user_query, dest="en").text
        else:
            user_query_translated = user_query
        
        if user_query_translated in st.session_state.query_cache:
            response, relevant_faqs = st.session_state.query_cache[user_query_translated]
        else:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            
            start_time = time.time()
            relevant_faqs = st.session_state.embedder.retrieve_relevant_faqs(user_query_translated)
            retrieval_time = time.time() - start_time
            
            if target_lang != "en":
                relevant_faqs = [translate_faq(faq, target_lang) for faq in relevant_faqs]
            
            start_time = time.time()
            response = st.session_state.response_generator.generate_response(user_query_translated, relevant_faqs)
            generation_time = time.time() - start_time
            
            if target_lang != "en":
                response = translator.translate(response, dest=target_lang).text
            
            st.session_state.query_cache[user_query_translated] = (response, relevant_faqs)
            st.session_state.retrieval_time = retrieval_time
            st.session_state.generation_time = generation_time
            st.session_state.current_faqs = relevant_faqs
        
        st.session_state.chat_history.append({"role": "user", "content": user_query})
        st.session_state.chat_history.append({"role": "assistant", "content": response})
    
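    # Reset the conversation and cached responses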
    if st.button("Clear Chat History"):
        st.session_state.chat_history = []
        st.session_state.query_cache = {}
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    
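    # Sidebar comparison of semantic (RAG) retrieval vs. a simple keyword-search baseline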
    if st.session_state.get("system_initialized", False):
        st.sidebar.subheader("Baseline Comparison")
        baseline_query = user_query_translated if 'user_query_translated' in locals() else ""
        baseline_faqs = baseline_keyword_search(baseline_query, st.session_state.embedder.faqs)
        st.sidebar.write(f"RAG FAQs: {[faq['question'][:50] for faq in st.session_state.get('current_faqs', [])]}")
        st.sidebar.write(f"Keyword FAQs: {[faq['question'][:50] for faq in baseline_faqs]}")
    
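    # Sample question buttons that run the same translate/cache/retrieve/generate pipeline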
    st.subheader("Sample Questions")
    sample_questions = [
        "How do I track my order?",
        "What should I do if my delivery is delayed?",
        "How do I return a product?",
        "Can I cancel my order after placing it?",
        "How quickly will my order be delivered?"
    ]
    cols = st.columns(2)
    for i, question in enumerate(sample_questions):
        col_idx = i % 2
        if cols[col_idx].button(question, key=f"sample_{i}"):
            st.session_state.chat_history.append({"role": "user", "content": question})
            
            from src.data_processing import translate_faq
            from googletrans import Translator
            translator = Translator()
            if target_lang != "en":
                question_translated = translator.translate(question, dest="en").text
            else:
                question_translated = question
            
            if question_translated in st.session_state.query_cache:
                response, relevant_faqs = st.session_state.query_cache[question_translated]
            else:
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                
                start_time = time.time()
                relevant_faqs = st.session_state.embedder.retrieve_relevant_faqs(question_translated)
                retrieval_time = time.time() - start_time
                
                if target_lang != "en":
                    relevant_faqs = [translate_faq(faq, target_lang) for faq in relevant_faqs]
                
                start_time = time.time()
                response = st.session_state.response_generator.generate_response(question_translated, relevant_faqs)
                generation_time = time.time() - start_time
                
                if target_lang != "en":
                    response = translator.translate(response, dest=target_lang).text
                
                st.session_state.query_cache[question_translated] = (response, relevant_faqs)
                st.session_state.retrieval_time = retrieval_time
                st.session_state.generation_time = generation_time
                st.session_state.current_faqs = relevant_faqs
            
            st.session_state.chat_history.append({"role": "assistant", "content": response})

if __name__ == "__main__":
    main()