Techbite committed on
Commit f402ae8 · 1 Parent(s): 3c4eeeb

changes:minor changes

Files changed (11)
  1. .gitignore +22 -3
  2. app.py +135 -120
  3. data/faq_data.csv +6 -1
  4. notes.txt +0 -16
  5. requirements.txt +8 -2
  6. src/__init__.py +1 -2
  7. src/data_processing.py +98 -44
  8. src/embedding.py +31 -23
  9. src/llm_response.py +35 -69
  10. src/utils.py +54 -13
  11. test_set.json +7 -0
.gitignore CHANGED
@@ -1,3 +1,22 @@
- # Ignore all files in the .venv directory
- venv
- __pycache__
+ # Python virtual environment
+ venv/
+
+ # Python cache files
+ __pycache__/
+
+ # Data and model files
+ data/ecommerce_faqs.json
+ embeddings/
+ offload/
+ *.bin
+
+ # Feedback data
+ feedback.json
+
+ # NLTK data
+ nltk_data/
+
+ # Streamlit cache
+ .venv/
+ streamlit_cache/
+
app.py CHANGED
@@ -3,233 +3,248 @@ import time
  import os
  import gc
  import torch
- from src.data_processing import load_huggingface_faq_data, load_faq_data, preprocess_faq
  from src.embedding import FAQEmbedder
  from src.llm_response import ResponseGenerator
- from src.utils import time_function, format_memory_stats

- # Set page title and layout
- st.set_page_config(
- page_title="E-Commerce FAQ Chatbot",
- layout="wide",
- initial_sidebar_state="expanded"
- )

- # Memory optimization: Force garbage collection before starting
- gc.collect()
- if torch.cuda.is_available():
- torch.cuda.empty_cache()

  @time_function
- def initialize_components(use_huggingface: bool = True, model_name: str = "mistralai/Mistral-7B-Instruct-v0.1"):
- """Initialize all components of the RAG system with memory optimization"""
-
- # Step 1: Load and preprocess FAQ data
- if use_huggingface:
- faqs = load_huggingface_faq_data("NebulaByte/E-Commerce_FAQs")
- else:
- data_path = os.path.join("data", "faq_data.csv")
- faqs = load_faq_data(data_path)
-
- processed_faqs = preprocess_faq(faqs)
-
- # Step 2: Initialize and create embeddings
- # Use smaller batch size for memory efficiency
- embedder = FAQEmbedder()
- embedder.create_embeddings(processed_faqs, batch_size=32)
-
- # Clear memory before loading the LLM
- gc.collect()
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
-
- # Step 3: Initialize response generator
- response_generator = ResponseGenerator(model_name=model_name)
-
- return embedder, response_generator, len(processed_faqs)

  def main():
  st.title("E-Commerce Customer Support FAQ Chatbot")
- st.subheader("Ask any question about your orders, shipping, returns, or any other e-commerce related queries")

- # Sidebar configuration
  st.sidebar.title("Configuration")
  use_huggingface = st.sidebar.checkbox("Use Hugging Face Dataset", value=True)

- # Model options - include smaller models by default
  model_options = {
- "Phi-2 (Recommended for 8GB GPU)": "microsoft/phi-2",
- "TinyLlama-1.1B (Smallest, fastest)": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
- "Mistral-7B (Requires 4-bit quantization)": "mistralai/Mistral-7B-Instruct-v0.1"
  }
-
- # Default to Phi-2 for 8-11GB GPU
  selected_model = st.sidebar.selectbox("Select LLM Model", list(model_options.keys()), index=0)
  model_name = model_options[selected_model]

- # Memory usage monitoring
  if st.sidebar.checkbox("Show Memory Usage", value=True):
  st.sidebar.subheader("Memory Usage")
- memory_stats = format_memory_stats()
- for key, value in memory_stats.items():
  st.sidebar.text(f"{key}: {value}")

- # Initialize session state for chat history if it doesn't exist
  if "chat_history" not in st.session_state:
  st.session_state.chat_history = []

- # Initialize RAG components (only once)
  if "system_initialized" not in st.session_state or st.sidebar.button("Reload System"):
- with st.spinner("Initializing system components... This may take a few minutes."):
- st.session_state.embedder, st.session_state.response_generator, num_faqs = initialize_components(
- use_huggingface=use_huggingface,
- model_name=model_name
- )
- st.session_state.system_initialized = True
- st.sidebar.success(f"System initialized with {num_faqs} FAQs!")

- # Chat interface
  col1, col2 = st.columns([2, 1])

  with col1:
- # Display chat history
  st.subheader("Conversation")
  chat_container = st.container(height=400)
-
  with chat_container:
  for i, message in enumerate(st.session_state.chat_history):
  if message["role"] == "user":
  st.markdown(f"**You**: {message['content']}")
  else:
  st.markdown(f"**Bot**: {message['content']}")
-
  if i < len(st.session_state.chat_history) - 1:
  st.markdown("---")
-
- # Chat input
  with st.form(key="chat_form"):
- user_query = st.text_input("Type your question:", key="user_input",
- placeholder="e.g., How do I track my order?")
  submit_button = st.form_submit_button("Ask")

  with col2:
  if st.session_state.get("system_initialized", False):
- # Show FAQ metadata and information
  st.subheader("Retrieved Information")
  info_container = st.container(height=500)
-
  with info_container:
  if "current_faqs" in st.session_state:
- for i, faq in enumerate(st.session_state.get("current_faqs", [])):
  st.markdown(f"**Relevant FAQ #{i+1}**")
  st.markdown(f"**Q**: {faq['question']}")
- # Limit answer length to save UI memory
  st.markdown(f"**A**: {faq['answer'][:150]}..." if len(faq['answer']) > 150 else f"**A**: {faq['answer']}")
  st.markdown(f"*Similarity Score*: {faq['similarity']:.2f}")
  if 'category' in faq and faq['category']:
  st.markdown(f"*Category*: {faq['category']}")
  st.markdown("---")
  else:
- st.markdown("Ask a question to see relevant FAQs here.")

- # Performance metrics in the sidebar
  if "retrieval_time" in st.session_state and "generation_time" in st.session_state:
  st.sidebar.subheader("Performance Metrics")
  st.sidebar.markdown(f"Retrieval time: {st.session_state.retrieval_time:.2f} seconds")
  st.sidebar.markdown(f"Response generation: {st.session_state.generation_time:.2f} seconds")
  st.sidebar.markdown(f"Total time: {st.session_state.retrieval_time + st.session_state.generation_time:.2f} seconds")

- # Process user query
  if submit_button and user_query:
- # Add user query to chat history
- st.session_state.chat_history.append({"role": "user", "content": user_query})

- # Process query
- with st.spinner("Thinking..."):
- # Free memory before processing
  gc.collect()
  if torch.cuda.is_available():
  torch.cuda.empty_cache()
-
- # Step 1: Retrieve relevant FAQs
  start_time = time.time()
- relevant_faqs = st.session_state.embedder.retrieve_relevant_faqs(user_query)
  retrieval_time = time.time() - start_time

- # Step 2: Generate response
  start_time = time.time()
- response = st.session_state.response_generator.generate_response(user_query, relevant_faqs)
  generation_time = time.time() - start_time

- # Store metrics and retrieved FAQs
  st.session_state.retrieval_time = retrieval_time
  st.session_state.generation_time = generation_time
  st.session_state.current_faqs = relevant_faqs
-
- # Step 3: Add response to chat history
- st.session_state.chat_history.append({"role": "assistant", "content": response})
-
- # Free memory after processing
- gc.collect()
- if torch.cuda.is_available():
- torch.cuda.empty_cache()

- # Rerun to display the updated chat history
- st.experimental_rerun()

- # Add sample questions at the bottom
  st.subheader("Sample Questions")
  sample_questions = [
  "How do I track my order?",
  "What should I do if my delivery is delayed?",
  "How do I return a product?",
  "Can I cancel my order after placing it?",
- "How quickly will my order be delivered?",
- "Why can't I track my order yet?"
  ]
-
- # Use two columns instead of three to reduce memory usage
  cols = st.columns(2)
  for i, question in enumerate(sample_questions):
  col_idx = i % 2
  if cols[col_idx].button(question, key=f"sample_{i}"):
- # Clear the text input and set the sample question
  st.session_state.user_input = question
- # Simulate form submission
  st.session_state.chat_history.append({"role": "user", "content": question})

- # Process query (similar to above)
- with st.spinner("Thinking..."):
- # Free memory before processing
  gc.collect()
  if torch.cuda.is_available():
  torch.cuda.empty_cache()
-
- # Step 1: Retrieve relevant FAQs
  start_time = time.time()
- relevant_faqs = st.session_state.embedder.retrieve_relevant_faqs(question)
  retrieval_time = time.time() - start_time

- # Step 2: Generate response
  start_time = time.time()
- response = st.session_state.response_generator.generate_response(question, relevant_faqs)
  generation_time = time.time() - start_time

- # Store metrics and retrieved FAQs
  st.session_state.retrieval_time = retrieval_time
  st.session_state.generation_time = generation_time
  st.session_state.current_faqs = relevant_faqs
-
- # Step 3: Add response to chat history
- st.session_state.chat_history.append({"role": "assistant", "content": response})
-
- # Free memory after processing
- gc.collect()
- if torch.cuda.is_available():
- torch.cuda.empty_cache()

- # Rerun to display the updated chat history
- st.experimental_rerun()

  if __name__ == "__main__":
  main()
  import os
  import gc
  import torch
+ from src.data_processing import load_huggingface_faq_data, load_faq_data, preprocess_faq, augment_faqs
  from src.embedding import FAQEmbedder
  from src.llm_response import ResponseGenerator
+ from src.utils import time_function, format_memory_stats, evaluate_response, evaluate_retrieval, baseline_keyword_search

+ # Suppress CUDA warning and Torch path errors
+ os.environ["CUDA_VISIBLE_DEVICES"] = ""
+ os.environ["TORCH_NO_PATH_CHECK"] = "1"

+ st.set_page_config(page_title="E-Commerce FAQ Chatbot", layout="wide", initial_sidebar_state="expanded")

  @time_function
+ def initialize_components(use_huggingface: bool = True, model_name: str = "microsoft/phi-2", enable_augmentation: bool = True):
+ """
+ Initialize RAG system components
+ """
+ try:
+ if use_huggingface:
+ faqs = load_huggingface_faq_data("NebulaByte/E-Commerce_FAQs")
+ else:
+ faqs = load_faq_data("data/faq_data.csv")
+
+ processed_faqs = augment_faqs(preprocess_faq(faqs), enable_augmentation=enable_augmentation)
+ embedder = FAQEmbedder()
+
+ if os.path.exists("embeddings"):
+ embedder.load("embeddings")
+ else:
+ embedder.create_embeddings(processed_faqs)
+ embedder.save("embeddings")
+
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ response_generator = ResponseGenerator(model_name=model_name)
+ response_generator.generate_response("Warmup query", [{"question": "Test", "answer": "Test"}])
+
+ return embedder, response_generator, len(processed_faqs)
+ except Exception as e:
+ st.error(f"Initialization failed: {e}")
+ raise

  def main():
  st.title("E-Commerce Customer Support FAQ Chatbot")
+ st.subheader("Ask about orders, shipping, returns, or other e-commerce queries")

  st.sidebar.title("Configuration")
  use_huggingface = st.sidebar.checkbox("Use Hugging Face Dataset", value=True)
+ enable_augmentation = st.sidebar.checkbox("Enable FAQ Augmentation", value=True, help="Generate paraphrased questions to expand dataset")
+ target_lang = st.sidebar.selectbox("Language", ["en", "es", "fr"], index=0)

  model_options = {
+ "Phi-2 (Recommended for 16GB RAM)": "microsoft/phi-2",
+ "TinyLlama-1.1B (Fastest)": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+ "Mistral-7B (For 15GB+ GPU)": "mistralai/Mistral-7B-Instruct-v0.1"
  }
  selected_model = st.sidebar.selectbox("Select LLM Model", list(model_options.keys()), index=0)
  model_name = model_options[selected_model]

  if st.sidebar.checkbox("Show Memory Usage", value=True):
  st.sidebar.subheader("Memory Usage")
+ for key, value in format_memory_stats().items():
  st.sidebar.text(f"{key}: {value}")

  if "chat_history" not in st.session_state:
  st.session_state.chat_history = []
+ if "query_cache" not in st.session_state:
+ st.session_state.query_cache = {}
+ if "feedback" not in st.session_state:
+ st.session_state.feedback = []

  if "system_initialized" not in st.session_state or st.sidebar.button("Reload System"):
+ with st.spinner("Initializing system..."):
+ try:
+ st.session_state.embedder, st.session_state.response_generator, num_faqs = initialize_components(
+ use_huggingface=use_huggingface,
+ model_name=model_name,
+ enable_augmentation=enable_augmentation
+ )
+ st.session_state.system_initialized = True
+ st.sidebar.success(f"System initialized with {num_faqs} FAQs!")
+ except Exception as e:
+ st.error(f"System initialization failed: {e}")
+ return

  col1, col2 = st.columns([2, 1])

  with col1:
  st.subheader("Conversation")
  chat_container = st.container(height=400)
  with chat_container:
  for i, message in enumerate(st.session_state.chat_history):
  if message["role"] == "user":
  st.markdown(f"**You**: {message['content']}")
  else:
  st.markdown(f"**Bot**: {message['content']}")
  if i < len(st.session_state.chat_history) - 1:
  st.markdown("---")
+
  with st.form(key="chat_form"):
+ user_query = st.text_input("Type your question:", key="user_input", placeholder="e.g., How do I track my order?")
  submit_button = st.form_submit_button("Ask")
+
+ if len(st.session_state.chat_history) > 0:
+ with st.form(key=f"feedback_form_{len(st.session_state.chat_history)}"):
+ rating = st.slider("Rate this response (1-5)", 1, 5, key=f"rating_{len(st.session_state.chat_history)}")
+ comments = st.text_area("Comments", key=f"comments_{len(st.session_state.chat_history)}")
+ if st.form_submit_button("Submit Feedback"):
+ st.session_state.feedback.append({
+ "rating": rating,
+ "comments": comments,
+ "response": st.session_state.chat_history[-1]["content"]
+ })
+ with open("feedback.json", "w") as f:
+ json.dump(st.session_state.feedback, f)
+ st.success("Feedback submitted!")

  with col2:
  if st.session_state.get("system_initialized", False):
  st.subheader("Retrieved Information")
  info_container = st.container(height=500)
  with info_container:
  if "current_faqs" in st.session_state:
+ for i, faq in enumerate(st.session_state.current_faqs):
  st.markdown(f"**Relevant FAQ #{i+1}**")
  st.markdown(f"**Q**: {faq['question']}")
  st.markdown(f"**A**: {faq['answer'][:150]}..." if len(faq['answer']) > 150 else f"**A**: {faq['answer']}")
  st.markdown(f"*Similarity Score*: {faq['similarity']:.2f}")
  if 'category' in faq and faq['category']:
  st.markdown(f"*Category*: {faq['category']}")
  st.markdown("---")
  else:
+ st.markdown("Ask a question to see relevant FAQs.")

  if "retrieval_time" in st.session_state and "generation_time" in st.session_state:
  st.sidebar.subheader("Performance Metrics")
  st.sidebar.markdown(f"Retrieval time: {st.session_state.retrieval_time:.2f} seconds")
  st.sidebar.markdown(f"Response generation: {st.session_state.generation_time:.2f} seconds")
  st.sidebar.markdown(f"Total time: {st.session_state.retrieval_time + st.session_state.generation_time:.2f} seconds")

  if submit_button and user_query:
+ from src.data_processing import translate_faq
+ from googletrans import Translator
+ translator = Translator()
+ if target_lang != "en":
+ user_query_translated = translator.translate(user_query, dest="en").text
+ else:
+ user_query_translated = user_query

+ if user_query_translated in st.session_state.query_cache:
+ response, relevant_faqs = st.session_state.query_cache[user_query_translated]
+ else:
  gc.collect()
  if torch.cuda.is_available():
  torch.cuda.empty_cache()
+
  start_time = time.time()
+ relevant_faqs = st.session_state.embedder.retrieve_relevant_faqs(user_query_translated)
  retrieval_time = time.time() - start_time

+ if target_lang != "en":
+ relevant_faqs = [translate_faq(faq, target_lang) for faq in relevant_faqs]
+
  start_time = time.time()
+ response = st.session_state.response_generator.generate_response(user_query_translated, relevant_faqs)
  generation_time = time.time() - start_time

+ if target_lang != "en":
+ response = translator.translate(response, dest=target_lang).text
+
+ st.session_state.query_cache[user_query_translated] = (response, relevant_faqs)
  st.session_state.retrieval_time = retrieval_time
  st.session_state.generation_time = generation_time
  st.session_state.current_faqs = relevant_faqs

+ st.session_state.chat_history.append({"role": "user", "content": user_query})
+ st.session_state.chat_history.append({"role": "assistant", "content": response})
+
+ if st.button("Clear Chat History"):
+ st.session_state.chat_history = []
+ st.session_state.query_cache = {}
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ if st.session_state.get("system_initialized", False):
+ st.sidebar.subheader("Baseline Comparison")
+ baseline_faqs = baseline_keyword_search(user_query_translated if 'user_query_translated' in locals() else "", st.session_state.embedder.faqs)
+ st.sidebar.write(f"RAG FAQs: {[faq['question'][:50] for faq in st.session_state.get('current_faqs', [])]}")
+ st.sidebar.write(f"Keyword FAQs: {[faq['question'][:50] for faq in baseline_faqs]}")

  st.subheader("Sample Questions")
  sample_questions = [
  "How do I track my order?",
  "What should I do if my delivery is delayed?",
  "How do I return a product?",
  "Can I cancel my order after placing it?",
+ "How quickly will my order be delivered?"
  ]
  cols = st.columns(2)
  for i, question in enumerate(sample_questions):
  col_idx = i % 2
  if cols[col_idx].button(question, key=f"sample_{i}"):
  st.session_state.user_input = question
  st.session_state.chat_history.append({"role": "user", "content": question})

+ from src.data_processing import translate_faq
+ from googletrans import Translator
+ translator = Translator()
+ if target_lang != "en":
+ question_translated = translator.translate(question, dest="en").text
+ else:
+ question_translated = question
+
+ if question_translated in st.session_state.query_cache:
+ response, relevant_faqs = st.session_state.query_cache[question_translated]
+ else:
  gc.collect()
  if torch.cuda.is_available():
  torch.cuda.empty_cache()
+
  start_time = time.time()
+ relevant_faqs = st.session_state.embedder.retrieve_relevant_faqs(question_translated)
  retrieval_time = time.time() - start_time

+ if target_lang != "en":
+ relevant_faqs = [translate_faq(faq, target_lang) for faq in relevant_faqs]
+
  start_time = time.time()
+ response = st.session_state.response_generator.generate_response(question_translated, relevant_faqs)
  generation_time = time.time() - start_time

+ if target_lang != "en":
+ response = translator.translate(response, dest=target_lang).text
+
+ st.session_state.query_cache[question_translated] = (response, relevant_faqs)
  st.session_state.retrieval_time = retrieval_time
  st.session_state.generation_time = generation_time
  st.session_state.current_faqs = relevant_faqs

+ st.session_state.chat_history.append({"role": "assistant", "content": response})

  if __name__ == "__main__":
  main()
data/faq_data.csv CHANGED
@@ -8,4 +8,9 @@ question,answer
  "How do I cancel my subscription?","To cancel your subscription, log in to your account, go to 'Subscription Settings', and click on 'Cancel Subscription'. Follow the prompts to complete the cancellation process."
  "Is my personal information secure?","Yes, we take data security seriously. We use industry-standard encryption protocols to protect your personal information. We never share your data with third parties without your consent."
  "How do I contact customer support?","You can contact our customer support team via email at support@example.com, through the live chat on our website, or by calling our support line at 1-800-123-4567 during business hours (9 AM - 5 PM EST, Monday-Friday)."
- "Do you offer discounts for bulk orders?","Yes, we offer volume discounts for bulk orders. Please contact our sales team at sales@example.com with your requirements to get a custom quote."
  "How do I cancel my subscription?","To cancel your subscription, log in to your account, go to 'Subscription Settings', and click on 'Cancel Subscription'. Follow the prompts to complete the cancellation process."
  "Is my personal information secure?","Yes, we take data security seriously. We use industry-standard encryption protocols to protect your personal information. We never share your data with third parties without your consent."
  "How do I contact customer support?","You can contact our customer support team via email at support@example.com, through the live chat on our website, or by calling our support line at 1-800-123-4567 during business hours (9 AM - 5 PM EST, Monday-Friday)."
+ "Do you offer discounts for bulk orders?","Yes, we offer volume discounts for bulk orders. Please contact our sales team at sales@example.com with your requirements to get a custom quote."
+ "Can I change my shipping address?","You can change your shipping address before the order is processed by contacting support at support@example.com."
+ "What happens if my order is damaged?","If your order arrives damaged, please contact us within 7 days with photos of the damage to initiate a replacement or refund."
+ "How do I apply a discount code?","Enter your discount code at checkout in the 'Promo Code' field. The discount will be applied to eligible items in your cart."
+ "What if I receive the wrong item?","If you receive the wrong item, contact support within 7 days with your order number and photos of the item received. We'll arrange a replacement or refund."
+ "Do you offer gift cards?","Yes, we offer digital gift cards in various denominations. Purchase them on our website and send them via email to the recipient."
notes.txt DELETED
@@ -1,16 +0,0 @@
- # Create a virtual environment (recommended)
- python -m venv venv
- source venv/bin/activate # On Windows: venv\Scripts\activate
-
- # Install required packages
- pip install torch transformers sentence-transformers faiss-cpu pandas streamlit nltk
-
-
-
-
-
- # Activate your virtual environment if not already active
- source venv/bin/activate # On Windows: venv\Scripts\activate
-
- # Run the Streamlit app
- streamlit run app.py
requirements.txt CHANGED
@@ -3,8 +3,14 @@ transformers>=4.30.0
  sentence-transformers>=2.2.2
  faiss-cpu>=1.7.4
  pandas>=1.5.0
- streamlit>=1.25.0
  numpy>=1.24.0
  datasets>=2.10.0
  bitsandbytes>=0.40.0
- accelerate>=0.20.0
  sentence-transformers>=2.2.2
  faiss-cpu>=1.7.4
  pandas>=1.5.0
+ streamlit>=1.36.0
  numpy>=1.24.0
  datasets>=2.10.0
  bitsandbytes>=0.40.0
+ accelerate>=0.20.0
+ evaluate>=0.4.0
+ scikit-learn>=1.2.0
+ nlpaug>=1.1.0
+ googletrans==4.0.0-rc1
+ psutil>=5.9.0
+ nltk>=3.8.0
src/__init__.py CHANGED
@@ -1,2 +1 @@
-
- # This file is intentionally left empty to mark the directory as a Python package
+ # This file marks the src directory as a Python package
src/data_processing.py CHANGED
@@ -1,40 +1,56 @@
  import pandas as pd
  import json
  from typing import List, Dict, Any
  from datasets import load_dataset

  def load_huggingface_faq_data(dataset_name: str = "NebulaByte/E-Commerce_FAQs") -> List[Dict[str, Any]]:
  """
- Load FAQ data from Hugging Face datasets
  """
- print(f"Loading dataset {dataset_name} from Hugging Face...")

  try:
- # Load the dataset
  dataset = load_dataset(dataset_name)
-
- # Get the train split (as seen in the screenshots)
- train_data = dataset["train"]
-
- # Convert to list of dictionaries
- faqs = []
- for item in train_data:
- # Extract the required fields
- faq = {
- "question": item["question"],
- "answer": item["answer"],
- # Include additional metadata
- "category": item.get("category", ""),
- "question_id": item.get("question_id", ""),
- "faq_url": item.get("faq_url", "")
- }
- faqs.append(faq)
-
- print(f"Loaded {len(faqs)} FAQ entries from Hugging Face")
  return faqs
-
  except Exception as e:
- print(f"Error loading dataset from Hugging Face: {e}")
  print("Falling back to local data...")
  return load_faq_data("data/faq_data.csv")

@@ -46,48 +62,86 @@ def load_faq_data(file_path: str) -> List[Dict[str, Any]]:
  try:
  if file_path.endswith('.csv'):
  df = pd.read_csv(file_path)
- # Assume CSV has 'question' and 'answer' columns
  faqs = df.to_dict('records')
  elif file_path.endswith('.json'):
  with open(file_path, 'r') as f:
  faqs = json.load(f)
  else:
  raise ValueError(f"Unsupported file format: {file_path}")
-
  print(f"Loaded {len(faqs)} FAQ entries")
  return faqs
  except Exception as e:
  print(f"Error loading data: {e}")
- # Create a minimal sample dataset as fallback
  print("Creating sample dataset as fallback")
  sample_faqs = [
- {"question": "How do I track my order?",
- "answer": "You can track your order by logging into your account and visiting the Order History section."},
- {"question": "How do I reset my password?",
- "answer": "To reset your password, click on the 'Forgot Password' link on the login page."}
  ]
  return sample_faqs

  def preprocess_faq(faqs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
  """
- Preprocess FAQ data: clean text, handle formatting
  """
  processed_faqs = []
  for faq in faqs:
- # Basic cleaning - remove extra whitespace
- if 'question' in faq and faq['question'] is not None:
- faq['question'] = faq['question'].strip()
- else:
- faq['question'] = ""
-
- if 'answer' in faq and faq['answer'] is not None:
- faq['answer'] = faq['answer'].strip()
- else:
- faq['answer'] = ""

  # Only include FAQs with both question and answer
- if faq.get('question') and faq.get('answer'):
  processed_faqs.append(faq)

  print(f"After preprocessing: {len(processed_faqs)} valid FAQ entries")
- return processed_faqs
  import pandas as pd
  import json
+ import os
+ import nltk
  from typing import List, Dict, Any
  from datasets import load_dataset
+ import nlpaug.augmenter.word as naw
+ from googletrans import Translator
+
+ # Configure NLTK data path and download required resources
+ NLTK_DATA_PATH = os.path.join(os.path.dirname(__file__), "../nltk_data")
+ os.makedirs(NLTK_DATA_PATH, exist_ok=True)
+ nltk.data.path.append(NLTK_DATA_PATH)
+
+ def ensure_nltk_resources():
+ """
+ Ensure NLTK resources are downloaded and available
+ """
+ try:
+ nltk.download('averaged_perceptron_tagger', download_dir=NLTK_DATA_PATH)
+ nltk.download('punkt', download_dir=NLTK_DATA_PATH)
+ print(f"NLTK resources downloaded to {NLTK_DATA_PATH}")
+ return True
+ except Exception as e:
+ print(f"Failed to download NLTK resources: {e}")
+ return False

  def load_huggingface_faq_data(dataset_name: str = "NebulaByte/E-Commerce_FAQs") -> List[Dict[str, Any]]:
  """
+ Load FAQ data from Hugging Face datasets, cache locally
  """
+ local_path = "data/ecommerce_faqs.json"
+ if os.path.exists(local_path):
+ print(f"Loading cached dataset from {local_path}")
+ with open(local_path, 'r') as f:
+ return json.load(f)

+ print(f"Loading dataset {dataset_name} from Hugging Face...")
  try:
  dataset = load_dataset(dataset_name)
+ faqs = [{
+ "question": item["question"],
+ "answer": item["answer"],
+ "category": item.get("category", ""),
+ "question_id": item.get("question_id", ""),
+ "faq_url": item.get("faq_url", "")
+ } for item in dataset["train"]]
+ with open(local_path, 'w') as f:
+ json.dump(faqs, f)
+ print(f"Saved dataset to {local_path}, loaded {len(faqs)} FAQs")
  return faqs
  except Exception as e:
+ print(f"Error loading dataset: {e}")
  print("Falling back to local data...")
  return load_faq_data("data/faq_data.csv")

  try:
  if file_path.endswith('.csv'):
  df = pd.read_csv(file_path)
  faqs = df.to_dict('records')
  elif file_path.endswith('.json'):
  with open(file_path, 'r') as f:
  faqs = json.load(f)
  else:
  raise ValueError(f"Unsupported file format: {file_path}")
  print(f"Loaded {len(faqs)} FAQ entries")
  return faqs
  except Exception as e:
  print(f"Error loading data: {e}")
  print("Creating sample dataset as fallback")
  sample_faqs = [
+ {"question": "How do I track my order?", "answer": "You can track your order by logging into your account and visiting the Order History section."},
+ {"question": "How do I reset my password?", "answer": "To reset your password, click on the 'Forgot Password' link on the login page."}
  ]
  return sample_faqs

  def preprocess_faq(faqs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
  """
+ Preprocess FAQ data: clean text, handle formatting, and filter invalid entries
  """
  processed_faqs = []
  for faq in faqs:
+ # Safely handle question and answer fields
+ question = faq.get('question')
+ answer = faq.get('answer')
+
+ # Convert to string and strip, handling None values
+ question = str(question).strip() if question is not None else ""
+ answer = str(answer).strip() if answer is not None else ""
+
+ # Update FAQ dictionary
+ faq['question'] = question
+ faq['answer'] = answer

  # Only include FAQs with both question and answer
+ if question and answer:
  processed_faqs.append(faq)
+ else:
+ print(f"Skipping invalid FAQ: question='{question}', answer='{answer}'")

  print(f"After preprocessing: {len(processed_faqs)} valid FAQ entries")
+ return processed_faqs
+
+ def augment_faqs(faqs: List[Dict[str, Any]], max_faqs: int = 1000, enable_augmentation: bool = True) -> List[Dict[str, Any]]:
+ """
+ Augment FAQs with paraphrased questions if enabled
+ """
+ if not enable_augmentation:
+ print("Augmentation disabled; returning original FAQs")
+ return faqs
+
+ if not ensure_nltk_resources():
+ print("NLTK resources unavailable; skipping augmentation")
+ return faqs
+
+ aug = naw.SynonymAug()
+ augmented = []
+ for faq in faqs:
+ augmented.append(faq)
+ if len(augmented) < max_faqs:
+ try:
+ aug_question = aug.augment(faq['question'])[0]
+ augmented.append({"question": aug_question, "answer": faq['answer'], "category": faq.get("category", "")})
+ except Exception as e:
+ print(f"Augmentation error for question '{faq['question'][:50]}...': {e}")
+ print(f"Augmented to {len(augmented)} FAQs")
+ return augmented
+
+ def translate_faq(faq: Dict[str, Any], target_lang: str = "es") -> Dict[str, Any]:
+ """
+ Translate FAQ to a target language
+ """
+ try:
+ translator = Translator()
+ translated = faq.copy()
+ translated["question"] = translator.translate(faq["question"], dest=target_lang).text
+ translated["answer"] = translator.translate(faq["answer"], dest=target_lang).text
+ translated["language"] = target_lang
+ return translated
+ except Exception as e:
+ print(f"Translation error: {e}")
+ return faq
src/embedding.py CHANGED
@@ -4,50 +4,43 @@ import numpy as np
  from typing import List, Dict, Any
  import torch
  import gc

  class FAQEmbedder:
  def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
  """
  Initialize the FAQ embedder with a sentence transformer model
- Optimized for memory efficiency
  """
- print(f"Initializing FAQ Embedder with model: {model_name}")
- # Use CPU for embedding model to save GPU memory for LLM
- self.device = "cpu"
  self.model = SentenceTransformer(model_name, device=self.device)
  self.index = None
  self.faqs = None
  self.embeddings = None

- def create_embeddings(self, faqs: List[Dict[str, Any]], batch_size: int = 32) -> None:
  """
  Create embeddings for all FAQs and build FAISS index
- Using batching for memory efficiency
  """
  self.faqs = faqs
  print(f"Creating embeddings for {len(faqs)} FAQs in batches of {batch_size}...")

- # Extract questions for embedding
  questions = [faq['question'] for faq in faqs]
-
- # Process in batches to reduce memory usage
  all_embeddings = []
  for i in range(0, len(questions), batch_size):
  batch = questions[i:i+batch_size]
  print(f"Processing batch {i//batch_size + 1}/{(len(questions) + batch_size - 1)//batch_size}")
-
- # Create embeddings for this batch
  batch_embeddings = self.model.encode(batch, show_progress_bar=False, convert_to_numpy=True)
  all_embeddings.append(batch_embeddings)
-
- # Combine all batches
- self.embeddings = np.vstack(all_embeddings).astype('float32')

- # Clear memory explicitly
  all_embeddings = None
  gc.collect()

- # Create FAISS index
  dimension = self.embeddings.shape[1]
  self.index = faiss.IndexFlatL2(dimension)
  self.index.add(self.embeddings)
@@ -62,20 +55,35 @@ class FAQEmbedder:
  if self.index is None or self.faqs is None or self.embeddings is None:
  raise ValueError("Embeddings not created yet. Call create_embeddings first.")

- # Embed the query
  query_embedding = self.model.encode([query], convert_to_numpy=True).astype('float32')
-
- # Search in FAISS
  distances, indices = self.index.search(query_embedding, k)

- # Get the relevant FAQs with their similarity scores
  relevant_faqs = []
  for i, idx in enumerate(indices[0]):
- if idx < len(self.faqs): # Ensure index is valid
  faq = self.faqs[idx].copy()
- # Convert L2 distance to similarity score (higher is better)
  similarity = 1.0 / (1.0 + distances[0][i])
  faq['similarity'] = similarity
  relevant_faqs.append(faq)

- return relevant_faqs
  from typing import List, Dict, Any
  import torch
  import gc
+ import os
+ import psutil
+ import json

  class FAQEmbedder:
  def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
  """
  Initialize the FAQ embedder with a sentence transformer model
  """
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"Embedding model using device: {self.device}")
  self.model = SentenceTransformer(model_name, device=self.device)
  self.index = None
  self.faqs = None
  self.embeddings = None

+ def create_embeddings(self, faqs: List[Dict[str, Any]], batch_size: int = None) -> None:
  """
  Create embeddings for all FAQs and build FAISS index
  """
  self.faqs = faqs
+ available_memory = psutil.virtual_memory().available / (1024 ** 3) # GB
+ batch_size = batch_size or min(64, int(available_memory * 4))
  print(f"Creating embeddings for {len(faqs)} FAQs in batches of {batch_size}...")

  questions = [faq['question'] for faq in faqs]
  all_embeddings = []
  for i in range(0, len(questions), batch_size):
  batch = questions[i:i+batch_size]
  print(f"Processing batch {i//batch_size + 1}/{(len(questions) + batch_size - 1)//batch_size}")
  batch_embeddings = self.model.encode(batch, show_progress_bar=False, convert_to_numpy=True)
  all_embeddings.append(batch_embeddings)

+ self.embeddings = np.vstack(all_embeddings).astype('float32')
  all_embeddings = None
  gc.collect()

  dimension = self.embeddings.shape[1]
  self.index = faiss.IndexFlatL2(dimension)
  self.index.add(self.embeddings)
  if self.index is None or self.faqs is None or self.embeddings is None:
  raise ValueError("Embeddings not created yet. Call create_embeddings first.")

  query_embedding = self.model.encode([query], convert_to_numpy=True).astype('float32')
  distances, indices = self.index.search(query_embedding, k)

  relevant_faqs = []
  for i, idx in enumerate(indices[0]):
+ if idx < len(self.faqs):
  faq = self.faqs[idx].copy()
  similarity = 1.0 / (1.0 + distances[0][i])
  faq['similarity'] = similarity
  relevant_faqs.append(faq)

+ return relevant_faqs
+
+ def save(self, path: str):
+ """
+ Save embeddings and FAQs to disk
+ """
+ os.makedirs(path, exist_ok=True)
+ self.model.save(path)
+ faiss.write_index(self.index, f"{path}/index.bin")
+ with open(f"{path}/faqs.json", "w") as f:
+ json.dump(self.faqs, f)
+
+ def load(self, path: str):
+ """
+ Load embeddings and FAQs from disk
+ """
+ self.model = SentenceTransformer(path)
+ self.index = faiss.read_index(f"{path}/index.bin")
+ with open(f"{path}/faqs.json", "r") as f:
+ self.faqs = json.load(f)
+ self.embeddings = np.array([self.model.encode(faq["question"]) for faq in self.faqs]).astype('float32')
@@ -2,101 +2,76 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
2
  import torch
3
  from typing import List, Dict, Any
4
  import gc
 
5
 
6
  class ResponseGenerator:
7
- def __init__(self, model_name: str = "mistralai/Mistral-7B-Instruct-v0.1"):
8
  """
9
  Initialize the response generator with an LLM
10
- Optimized for 8-11GB GPU
11
  """
12
  print(f"Loading LLM: {model_name}")
13
- print("This may take a few minutes...")
 
14
 
15
- # Load tokenizer
16
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
17
-
18
- # Configure device and data type based on available resources
19
- device = "cuda" if torch.cuda.is_available() else "cpu"
20
- print(f"Using device: {device}")
21
-
22
- # Free up memory before loading model
23
  gc.collect()
24
- if device == "cuda":
25
  torch.cuda.empty_cache()
26
 
27
- # Configure 4-bit quantization for maximum memory efficiency
28
  try:
29
- # Use 4-bit quantization for models that support it
30
- quantization_config = BitsAndBytesConfig(
31
- load_in_4bit=True,
32
- bnb_4bit_compute_dtype=torch.float16,
33
- bnb_4bit_use_double_quant=True,
34
- bnb_4bit_quant_type="nf4"
35
- )
36
-
37
- # Load the model with quantization
38
- self.model = AutoModelForCausalLM.from_pretrained(
39
- model_name,
40
- quantization_config=quantization_config,
41
- device_map="auto",
42
- torch_dtype=torch.float16,
43
- # Load model in parts to avoid OOM errors
44
- max_memory={0: "8GiB", "cpu": "16GiB"},
45
- offload_folder="offload",
46
- offload_state_dict=True, # Offload weights to CPU when not in use
47
- low_cpu_mem_usage=True
48
- )
49
- except Exception as e:
50
- print(f"4-bit quantization error: {e}")
51
- print("Falling back to 8-bit quantization...")
52
-
53
- try:
54
- # Try 8-bit quantization
55
  quantization_config = BitsAndBytesConfig(
56
- load_in_8bit=True,
57
- bnb_8bit_use_double_quant=True
 
 
58
  )
 
 
 
 
59
 
60
  self.model = AutoModelForCausalLM.from_pretrained(
61
  model_name,
62
  quantization_config=quantization_config,
63
  device_map="auto",
64
  torch_dtype=torch.float16,
65
- max_memory={0: "8GiB", "cpu": "16GiB"},
66
  offload_folder="offload",
 
67
  low_cpu_mem_usage=True
68
  )
69
- except Exception as e2:
70
- print(f"8-bit quantization error: {e2}")
71
- print("Falling back to smaller model...")
72
-
73
- # Use a much smaller model as fallback
74
- backup_model = "microsoft/phi-2"
75
- self.tokenizer = AutoTokenizer.from_pretrained(backup_model)
76
  self.model = AutoModelForCausalLM.from_pretrained(
77
- backup_model,
78
- device_map="auto",
79
- torch_dtype=torch.float16 if device == "cuda" else torch.float32
 
80
  )
 
 
 
 
 
 
 
 
 
 
81
 
82
  print("LLM loaded successfully")
83
 
84
  def generate_response(self, query: str, relevant_faqs: List[Dict[str, Any]]) -> str:
85
  """
86
  Generate a response using the LLM with retrieved FAQs as context
87
- Memory-optimized version
88
  """
89
- # Create prompt with relevant FAQs
90
  prompt = self._create_prompt(query, relevant_faqs)
91
-
92
- # Generate response with memory-efficient settings
93
  inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
94
 
95
  with torch.no_grad():
96
- # Use more conservative generation parameters
97
  outputs = self.model.generate(
98
  **inputs,
99
- max_new_tokens=200, # Shorter response for memory efficiency
100
  temperature=0.7,
101
  top_p=0.9,
102
  do_sample=True,
@@ -104,12 +79,9 @@ class ResponseGenerator:
104
  )
105
 
106
  response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
107
-
108
- # Extract just the response part (after the prompt)
109
  response = response[len(prompt):].strip()
110
 
111
- # Clear GPU memory after generating response
112
- if torch.cuda.is_available():
113
  torch.cuda.empty_cache()
114
 
115
  return response
@@ -118,19 +90,13 @@ class ResponseGenerator:
118
  """
119
  Create a prompt for the LLM with retrieved FAQs as context
120
  """
121
- # Format FAQs in a way that's suitable for the model
122
- faq_context = "\n\n".join([
123
- f"Q: {faq['question']}\nA: {faq['answer']}"
124
- for faq in relevant_faqs
125
- ])
126
-
127
- # Create the prompt
128
  prompt = f"""
129
  Below are some relevant e-commerce customer support FAQ entries:
130
 
131
  {faq_context}
132
 
133
- Based on the information above, please provide a helpful, accurate, and concise response to the following customer query:
134
  Customer Query: {query}
135
 
136
  Response:
 
2
  import torch
3
  from typing import List, Dict, Any
4
  import gc
5
+ import psutil
6
 
7
  class ResponseGenerator:
8
+ def __init__(self, model_name: str = "microsoft/phi-2"):
9
  """
10
  Initialize the response generator with an LLM
 
11
  """
12
  print(f"Loading LLM: {model_name}")
13
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
14
+ print(f"Using device: {self.device}")
15
 
 
16
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 
 
 
 
 
17
  gc.collect()
18
+ if self.device == "cuda":
19
  torch.cuda.empty_cache()
20
 
 
21
  try:
22
+ if self.device == "cuda":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  quantization_config = BitsAndBytesConfig(
24
+ load_in_4bit=True,
25
+ bnb_4bit_compute_dtype=torch.float16,
26
+ bnb_4bit_use_double_quant=True,
27
+ bnb_4bit_quant_type="nf4"
28
  )
29
+ available_memory = psutil.virtual_memory().total / (1024 ** 3)
30
+ gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
31
+ max_memory = {0: f"{min(gpu_memory, 15)}GiB", "cpu": f"{min(available_memory, 30)}GiB"}
32
+ print(f"Setting max_memory: {max_memory}")
33
 
34
  self.model = AutoModelForCausalLM.from_pretrained(
35
  model_name,
36
  quantization_config=quantization_config,
37
  device_map="auto",
38
  torch_dtype=torch.float16,
39
+ max_memory=max_memory,
40
  offload_folder="offload",
41
+ offload_state_dict=True,
42
  low_cpu_mem_usage=True
43
  )
44
+ else:
 
 
 
 
 
 
45
  self.model = AutoModelForCausalLM.from_pretrained(
46
+ model_name,
47
+ device_map={"": "cpu"},
48
+ torch_dtype=torch.float32,
49
+ low_cpu_mem_usage=True
50
  )
51
+ except Exception as e:
52
+ print(f"Model loading error: {e}")
53
+ print("Falling back to TinyLlama...")
54
+ model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
55
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
56
+ self.model = AutoModelForCausalLM.from_pretrained(
57
+ model_name,
58
+ device_map={"": self.device},
59
+ torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
60
+ )
61
 
62
  print("LLM loaded successfully")
63
 
64
  def generate_response(self, query: str, relevant_faqs: List[Dict[str, Any]]) -> str:
65
  """
66
  Generate a response using the LLM with retrieved FAQs as context
 
67
  """
 
68
  prompt = self._create_prompt(query, relevant_faqs)
 
 
69
  inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
70
 
71
  with torch.no_grad():
 
72
  outputs = self.model.generate(
73
  **inputs,
74
+ max_new_tokens=150,
75
  temperature=0.7,
76
  top_p=0.9,
77
  do_sample=True,
 
79
  )
80
 
81
  response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
82
  response = response[len(prompt):].strip()
83
 
84
+ if self.device == "cuda":
 
85
  torch.cuda.empty_cache()
86
 
87
  return response
 
90
  """
91
  Create a prompt for the LLM with retrieved FAQs as context
92
  """
93
+ faq_context = "\n\n".join([f"Q: {faq['question']}\nA: {faq['answer']}" for faq in relevant_faqs])
 
 
 
 
 
 
94
  prompt = f"""
95
  Below are some relevant e-commerce customer support FAQ entries:
96
 
97
  {faq_context}
98
 
99
+ Based on the information above, provide a helpful, accurate, and concise response to the following customer query:
100
  Customer Query: {query}
101
 
102
  Response:
src/utils.py CHANGED
@@ -1,6 +1,12 @@
  import time
  import functools
- from typing import Callable, Any, Dict

  def time_function(func: Callable) -> Callable:
  """
@@ -17,38 +23,73 @@ def time_function(func: Callable) -> Callable:

  def evaluate_response(generated_response: str, ground_truth: str = None) -> Dict[str, Any]:
  """
- Basic evaluation of generated response
  """
  results = {
  "length": len(generated_response),
  "word_count": len(generated_response.split())
  }

- # If ground truth is provided, we could add metrics like BLEU or ROUGE
  if ground_truth:
- # Simplified evaluation - word overlap
  generated_words = set(generated_response.lower().split())
  ground_truth_words = set(ground_truth.lower().split())
  overlap = len(generated_words.intersection(ground_truth_words))
- results["word_overlap"] = overlap / len(ground_truth_words) if ground_truth_words else 0

  return results

- def format_memory_stats():
  """
- Format memory usage statistics for display
  """
- import torch
- import psutil
- import os

- # System memory
  system_stats = {
  "RAM": f"{psutil.virtual_memory().used / (1024 ** 3):.1f}GB / {psutil.virtual_memory().total / (1024 ** 3):.1f}GB",
- "RAM Usage": f"{psutil.virtual_memory().percent}%",
  }

- # GPU memory if available
  if torch.cuda.is_available():
  gpu_stats = {}
  for i in range(torch.cuda.device_count()):
  import time
  import functools
+ from typing import Callable, Any, Dict, List
+ import torch
+ import psutil
+ import json
+ from evaluate import load
+ from sklearn.feature_extraction.text import TfidfVectorizer
+ from sklearn.metrics.pairwise import cosine_similarity

  def time_function(func: Callable) -> Callable:
  """

  def evaluate_response(generated_response: str, ground_truth: str = None) -> Dict[str, Any]:
  """
+ Evaluate generated response with BLEU, ROUGE, and word overlap
  """
  results = {
  "length": len(generated_response),
  "word_count": len(generated_response.split())
  }

  if ground_truth:
+ bleu = load("bleu")
+ rouge = load("rouge")
+ bleu_score = bleu.compute(predictions=[generated_response], references=[[ground_truth]])
+ rouge_score = rouge.compute(predictions=[generated_response], references=[ground_truth])
  generated_words = set(generated_response.lower().split())
  ground_truth_words = set(ground_truth.lower().split())
  overlap = len(generated_words.intersection(ground_truth_words))
+ results.update({
+ "bleu": bleu_score["bleu"],
+ "rouge": rouge_score["rougeL"],
+ "word_overlap": overlap / len(ground_truth_words) if ground_truth_words else 0
+ })

  return results

+ def evaluate_retrieval(embedder, test_set_path: str, k: int = 3) -> Dict[str, float]:
  """
+ Evaluate retrieval quality with Precision@k and Recall@k
  """
+ with open(test_set_path, 'r') as f:
+ test_set = json.load(f)
+
+ precision, recall = [], []
+ for item in test_set:
+ query = item['query']
+ true_ids = set(item['relevant_ids'])
+ retrieved_faqs = embedder.retrieve_relevant_faqs(query, k)
+ retrieved_ids = set(range(len(retrieved_faqs)))
+
+ true_positives = len(true_ids & retrieved_ids)
+ precision.append(true_positives / k if k > 0 else 0)
+ recall.append(true_positives / len(true_ids) if true_ids else 0)

+ return {
+ "Precision@k": sum(precision) / len(precision) if precision else 0,
+ "Recall@k": sum(recall) / len(recall) if recall else 0
+ }
+
+ def baseline_keyword_search(query: str, faqs: List[Dict[str, Any]], k: int = 3) -> List[Dict[str, Any]]:
+ """
+ Keyword-based search baseline using TF-IDF
+ """
+ questions = [faq['question'] for faq in faqs]
+ vectorizer = TfidfVectorizer()
+ question_vectors = vectorizer.fit_transform(questions)
+ query_vector = vectorizer.transform([query])
+ similarities = cosine_similarity(query_vector, question_vectors).flatten()
+ top_k_indices = similarities.argsort()[-k:][::-1]
+ return [faqs[i] for i in top_k_indices]
+
+ def format_memory_stats():
+ """
+ Format memory usage statistics
+ """
  system_stats = {
  "RAM": f"{psutil.virtual_memory().used / (1024 ** 3):.1f}GB / {psutil.virtual_memory().total / (1024 ** 3):.1f}GB",
+ "RAM Usage": f"{psutil.virtual_memory().percent}%"
  }

  if torch.cuda.is_available():
  gpu_stats = {}
  for i in range(torch.cuda.device_count()):
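Reading aid (not part of the commit): a minimal sketch of calling the TF-IDF baseline added above, assuming a small in-memory FAQ list; app.py uses it the same way to show a side-by-side comparison with the embedding-based retrieval.

from src.utils import baseline_keyword_search

faqs = [
    {"question": "How do I track my order?", "answer": "..."},
    {"question": "How do I return a product?", "answer": "..."},
    {"question": "Do you offer gift cards?", "answer": "..."},
]

# Rank FAQ questions by TF-IDF cosine similarity to the query and keep the top k.
top = baseline_keyword_search("track my package", faqs, k=2)
print([faq["question"] for faq in top])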
test_set.json ADDED
@@ -0,0 +1,7 @@
+ [
+ {"query": "How do I track my order?", "relevant_ids": [2]},
+ {"query": "How to reset password?", "relevant_ids": [0]},
+ {"query": "What is the return policy?", "relevant_ids": [3]},
+ {"query": "Can I change my shipping address?", "relevant_ids": [10]},
+ {"query": "Do you offer gift cards?", "relevant_ids": [14]}
+ ]
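Reading aid (not part of the commit): a minimal sketch of feeding the new test_set.json to evaluate_retrieval from src/utils.py, assuming the relevant_ids index into the loaded FAQ list and that the local CSV is used to build the index.

from src.data_processing import load_faq_data, preprocess_faq
from src.embedding import FAQEmbedder
from src.utils import evaluate_retrieval

# Build an index over the local FAQ file, then score retrieval against the test set.
faqs = preprocess_faq(load_faq_data("data/faq_data.csv"))
embedder = FAQEmbedder()
embedder.create_embeddings(faqs)

metrics = evaluate_retrieval(embedder, "test_set.json", k=3)
print(metrics)  # e.g. {"Precision@k": ..., "Recall@k": ...}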