Techbite committed
Commit 26d1a81 · 1 Parent(s): 13c7994

initial commit

Files changed (11)
  1. .gitignore +3 -0
  2. README.md +100 -12
  3. app.py +235 -0
  4. data/faq_data.csv +11 -0
  5. notes.txt +16 -0
  6. requirements.txt +10 -0
  7. src/__init__.py +2 -0
  8. src/data_processing.py +93 -0
  9. src/embedding.py +81 -0
  10. src/llm_response.py +138 -0
  11. src/utils.py +59 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
+ # Ignore the virtual environment and Python cache directories
+ venv
+ __pycache__
README.md CHANGED
@@ -1,12 +1,100 @@
- ---
- title: Faq Rag Chatbot
- emoji: 🌖
- colorFrom: yellow
- colorTo: indigo
- sdk: streamlit
- sdk_version: 1.44.1
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # FAQ Chatbot Using RAG for Customer Support - Setup Instructions
+
+ Follow these steps to set up and run the e-commerce FAQ chatbot, optimized for hardware with 16-19GB of RAM and an 8-11GB GPU.
+
+ ## Prerequisites
+
+ - Python 3.8 or higher
+ - CUDA-compatible GPU with 8-11GB VRAM
+ - 16-19GB RAM
+ - Internet connection (for downloading models and datasets)
+
+ ## Step 1: Create Project Directory Structure
+
+ ```bash
+ # Create the project directory
+ mkdir faq-rag-chatbot
+ cd faq-rag-chatbot
+
+ # Create the source and data directories
+ mkdir -p src data
+ ```
+
+ ## Step 2: Create Virtual Environment
+
+ ```bash
+ # Create a virtual environment
+ python -m venv venv
+
+ # Activate the virtual environment
+ # On Windows:
+ venv\Scripts\activate
+ # On macOS/Linux:
+ source venv/bin/activate
+ ```
+
+ ## Step 3: Create Project Files
+
+ Create all the required files with the optimized code provided:
+
+ 1. `requirements.txt`
+ 2. `src/__init__.py`
+ 3. `src/data_processing.py`
+ 4. `src/embedding.py`
+ 5. `src/llm_response.py`
+ 6. `src/utils.py`
+ 7. `app.py`
+
+ ## Step 4: Install Dependencies
+
+ ```bash
+ # Install required packages
+ pip install -r requirements.txt
+
+ # Additional dependency for memory monitoring
+ pip install psutil
+ ```
+
+ ## Step 5: Run the Application
+
+ ```bash
+ # Make sure the virtual environment is activated
+ # Run the Streamlit app
+ streamlit run app.py
+ ```
+
+ ## Memory Optimization Notes
+
+ This implementation includes several optimizations for systems with 16-19GB of RAM and an 8-11GB GPU:
+
+ 1. **Default to Smaller Models**: The app defaults to Phi-2, which works well on 8GB GPUs
+ 2. **4-bit Quantization**: Uses 4-bit quantization for larger models such as Mistral-7B (see the sketch after this list)
+ 3. **Memory Offloading**: Offloads weights to CPU when not in use
+ 4. **Batch Processing**: Processes embeddings in smaller batches
+ 5. **Garbage Collection**: Aggressively frees memory after operations
+ 6. **Response Length Limits**: Generates shorter responses to save memory
+ 7. **CPU Embedding**: Keeps the embedding model on CPU to save GPU memory for the LLM
+
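A minimal sketch of the 4-bit loading path mentioned in item 2, mirroring the configuration used in `src/llm_response.py` below (the model name and memory caps are this project's defaults, not requirements):

```python
# Sketch: load Mistral-7B with 4-bit NF4 quantization, as src/llm_response.py does.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_name = "mistralai/Mistral-7B-Instruct-v0.1"
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # store weights in 4-bit NF4
    bnb_4bit_compute_dtype=torch.float16,   # run the matmuls in fp16
    bnb_4bit_use_double_quant=True,         # also quantize the quantization constants
    bnb_4bit_quant_type="nf4",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    device_map="auto",                      # let accelerate split layers across GPU/CPU
    max_memory={0: "8GiB", "cpu": "16GiB"}, # cap usage to fit an 8GB card
)
```

With NF4 and double quantization the 7B weights occupy roughly 4GB, which is what makes Mistral-7B feasible on an 8-11GB GPU.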
+ ## Using the Chatbot
+
+ 1. The application automatically downloads the e-commerce FAQ dataset from Hugging Face
+ 2. Choose an appropriate model based on your available GPU memory:
+    - For an 8GB GPU: use Phi-2 (default)
+    - For a 10-11GB GPU: you can try Mistral-7B with 4-bit quantization
+    - For limited GPU or CPU-only setups: use TinyLlama-1.1B
+ 3. Type a question or select a sample question
+ 4. The system retrieves relevant FAQs and generates a response
+ 5. Monitor memory usage in the sidebar
+
+ ## Troubleshooting
+
+ - **Out of Memory Errors**: If you encounter CUDA out-of-memory errors, switch to a smaller model such as TinyLlama-1.1B
+ - **Slow Response Times**: The first response may be slow while the model loads; subsequent responses will be faster
+ - **Model Loading Issues**: If Mistral-7B fails to load, the system will automatically fall back to Phi-2
+
+ ## Performance Considerations
+
+ - The embedding and retrieval components work efficiently even on limited hardware
+ - Response generation speed depends on the model size and available GPU memory
+ - For optimal performance on an 8GB GPU, stick with the Phi-2 model
+ - For faster responses at the cost of some accuracy, use TinyLlama-1.1B
app.py ADDED
@@ -0,0 +1,235 @@
+ import streamlit as st
+ import time
+ import os
+ import gc
+ import torch
+ from src.data_processing import load_huggingface_faq_data, load_faq_data, preprocess_faq
+ from src.embedding import FAQEmbedder
+ from src.llm_response import ResponseGenerator
+ from src.utils import time_function, format_memory_stats
+
+ # Set page title and layout
+ st.set_page_config(
+     page_title="E-Commerce FAQ Chatbot",
+     layout="wide",
+     initial_sidebar_state="expanded"
+ )
+
+ # Memory optimization: Force garbage collection before starting
+ gc.collect()
+ if torch.cuda.is_available():
+     torch.cuda.empty_cache()
+
+ @time_function
+ def initialize_components(use_huggingface: bool = True, model_name: str = "mistralai/Mistral-7B-Instruct-v0.1"):
+     """Initialize all components of the RAG system with memory optimization"""
+
+     # Step 1: Load and preprocess FAQ data
+     if use_huggingface:
+         faqs = load_huggingface_faq_data("NebulaByte/E-Commerce_FAQs")
+     else:
+         data_path = os.path.join("data", "faq_data.csv")
+         faqs = load_faq_data(data_path)
+
+     processed_faqs = preprocess_faq(faqs)
+
+     # Step 2: Initialize and create embeddings
+     # Use a smaller batch size for memory efficiency
+     embedder = FAQEmbedder()
+     embedder.create_embeddings(processed_faqs, batch_size=32)
+
+     # Clear memory before loading the LLM
+     gc.collect()
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+
+     # Step 3: Initialize response generator
+     response_generator = ResponseGenerator(model_name=model_name)
+
+     return embedder, response_generator, len(processed_faqs)
+
+ def main():
+     st.title("E-Commerce Customer Support FAQ Chatbot")
+     st.subheader("Ask any question about your orders, shipping, returns, or any other e-commerce related queries")
+
+     # Sidebar configuration
+     st.sidebar.title("Configuration")
+     use_huggingface = st.sidebar.checkbox("Use Hugging Face Dataset", value=True)
+
+     # Model options - include smaller models by default
+     model_options = {
+         "Phi-2 (Recommended for 8GB GPU)": "microsoft/phi-2",
+         "TinyLlama-1.1B (Smallest, fastest)": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+         "Mistral-7B (Requires 4-bit quantization)": "mistralai/Mistral-7B-Instruct-v0.1"
+     }
+
+     # Default to Phi-2 for 8-11GB GPU
+     selected_model = st.sidebar.selectbox("Select LLM Model", list(model_options.keys()), index=0)
+     model_name = model_options[selected_model]
+
+     # Memory usage monitoring
+     if st.sidebar.checkbox("Show Memory Usage", value=True):
+         st.sidebar.subheader("Memory Usage")
+         memory_stats = format_memory_stats()
+         for key, value in memory_stats.items():
+             st.sidebar.text(f"{key}: {value}")
+
+     # Initialize session state for chat history if it doesn't exist
+     if "chat_history" not in st.session_state:
+         st.session_state.chat_history = []
+
+     # Initialize RAG components (only once)
+     if "system_initialized" not in st.session_state or st.sidebar.button("Reload System"):
+         with st.spinner("Initializing system components... This may take a few minutes."):
+             st.session_state.embedder, st.session_state.response_generator, num_faqs = initialize_components(
+                 use_huggingface=use_huggingface,
+                 model_name=model_name
+             )
+             st.session_state.system_initialized = True
+             st.sidebar.success(f"System initialized with {num_faqs} FAQs!")
+
+     # Chat interface
+     col1, col2 = st.columns([2, 1])
+
+     with col1:
+         # Display chat history
+         st.subheader("Conversation")
+         chat_container = st.container(height=400)
+
+         with chat_container:
+             for i, message in enumerate(st.session_state.chat_history):
+                 if message["role"] == "user":
+                     st.markdown(f"**You**: {message['content']}")
+                 else:
+                     st.markdown(f"**Bot**: {message['content']}")
+
+                 if i < len(st.session_state.chat_history) - 1:
+                     st.markdown("---")
+
+         # Chat input
+         with st.form(key="chat_form"):
+             user_query = st.text_input("Type your question:", key="user_input",
+                                        placeholder="e.g., How do I track my order?")
+             submit_button = st.form_submit_button("Ask")
+
+     with col2:
+         if st.session_state.get("system_initialized", False):
+             # Show FAQ metadata and information
+             st.subheader("Retrieved Information")
+             info_container = st.container(height=500)
+
+             with info_container:
+                 if "current_faqs" in st.session_state:
+                     for i, faq in enumerate(st.session_state.get("current_faqs", [])):
+                         st.markdown(f"**Relevant FAQ #{i+1}**")
+                         st.markdown(f"**Q**: {faq['question']}")
+                         # Limit answer length to save UI memory
+                         st.markdown(f"**A**: {faq['answer'][:150]}..." if len(faq['answer']) > 150 else f"**A**: {faq['answer']}")
+                         st.markdown(f"*Similarity Score*: {faq['similarity']:.2f}")
+                         if 'category' in faq and faq['category']:
+                             st.markdown(f"*Category*: {faq['category']}")
+                         st.markdown("---")
+                 else:
+                     st.markdown("Ask a question to see relevant FAQs here.")
+
+     # Performance metrics in the sidebar
+     if "retrieval_time" in st.session_state and "generation_time" in st.session_state:
+         st.sidebar.subheader("Performance Metrics")
+         st.sidebar.markdown(f"Retrieval time: {st.session_state.retrieval_time:.2f} seconds")
+         st.sidebar.markdown(f"Response generation: {st.session_state.generation_time:.2f} seconds")
+         st.sidebar.markdown(f"Total time: {st.session_state.retrieval_time + st.session_state.generation_time:.2f} seconds")
+
+     # Process user query
+     if submit_button and user_query:
+         # Add user query to chat history
+         st.session_state.chat_history.append({"role": "user", "content": user_query})
+
+         # Process query
+         with st.spinner("Thinking..."):
+             # Free memory before processing
+             gc.collect()
+             if torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+
+             # Step 1: Retrieve relevant FAQs
+             start_time = time.time()
+             relevant_faqs = st.session_state.embedder.retrieve_relevant_faqs(user_query)
+             retrieval_time = time.time() - start_time
+
+             # Step 2: Generate response
+             start_time = time.time()
+             response = st.session_state.response_generator.generate_response(user_query, relevant_faqs)
+             generation_time = time.time() - start_time
+
+             # Store metrics and retrieved FAQs
+             st.session_state.retrieval_time = retrieval_time
+             st.session_state.generation_time = generation_time
+             st.session_state.current_faqs = relevant_faqs
+
+             # Step 3: Add response to chat history
+             st.session_state.chat_history.append({"role": "assistant", "content": response})
+
+             # Free memory after processing
+             gc.collect()
+             if torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+
+         # Rerun to display the updated chat history
+         st.experimental_rerun()
+
+     # Add sample questions at the bottom
+     st.subheader("Sample Questions")
+     sample_questions = [
+         "How do I track my order?",
+         "What should I do if my delivery is delayed?",
+         "How do I return a product?",
+         "Can I cancel my order after placing it?",
+         "How quickly will my order be delivered?",
+         "Why can't I track my order yet?"
+     ]
+
+     # Use two columns instead of three to reduce memory usage
+     cols = st.columns(2)
+     for i, question in enumerate(sample_questions):
+         col_idx = i % 2
+         if cols[col_idx].button(question, key=f"sample_{i}"):
+             # Clear the text input and set the sample question
+             st.session_state.user_input = question
+             # Simulate form submission
+             st.session_state.chat_history.append({"role": "user", "content": question})
+
+             # Process query (similar to above)
+             with st.spinner("Thinking..."):
+                 # Free memory before processing
+                 gc.collect()
+                 if torch.cuda.is_available():
+                     torch.cuda.empty_cache()
+
+                 # Step 1: Retrieve relevant FAQs
+                 start_time = time.time()
+                 relevant_faqs = st.session_state.embedder.retrieve_relevant_faqs(question)
+                 retrieval_time = time.time() - start_time
+
+                 # Step 2: Generate response
+                 start_time = time.time()
+                 response = st.session_state.response_generator.generate_response(question, relevant_faqs)
+                 generation_time = time.time() - start_time
+
+                 # Store metrics and retrieved FAQs
+                 st.session_state.retrieval_time = retrieval_time
+                 st.session_state.generation_time = generation_time
+                 st.session_state.current_faqs = relevant_faqs
+
+                 # Step 3: Add response to chat history
+                 st.session_state.chat_history.append({"role": "assistant", "content": response})
+
+                 # Free memory after processing
+                 gc.collect()
+                 if torch.cuda.is_available():
+                     torch.cuda.empty_cache()
+
+             # Rerun to display the updated chat history
+             st.experimental_rerun()
+
+ if __name__ == "__main__":
+     main()
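Note: `st.experimental_rerun()` is deprecated in recent Streamlit releases in favor of `st.rerun()` (added in 1.27). If the app is run on a newer Streamlit than the pinned `>=1.25.0`, a version-tolerant call could look like this sketch (not part of the commit):

```python
# Prefer st.rerun() when the running Streamlit provides it; otherwise
# fall back to the older st.experimental_rerun().
rerun = getattr(st, "rerun", None) or st.experimental_rerun
rerun()
```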
data/faq_data.csv ADDED
@@ -0,0 +1,11 @@
+ question,answer
+ "How do I reset my password?","To reset your password, click on the 'Forgot Password' link on the login page. You'll receive an email with instructions to set a new password."
+ "What payment methods do you accept?","We accept Visa, Mastercard, American Express, and PayPal. All payments are processed securely through our payment gateway."
+ "How do I track my order?","You can track your order by logging into your account and visiting the 'Order History' section. Click on the specific order to see its current status and tracking information."
+ "What is your return policy?","Our return policy allows returns within 30 days of purchase. Items must be in original condition with tags attached. Please fill out the return form in your order confirmation email to start the process."
+ "How long does shipping take?","Standard shipping typically takes 3-5 business days within the continental US. Express shipping (1-2 business days) is available for an additional fee. International shipping times vary by destination."
+ "Do you offer international shipping?","Yes, we ship to most countries worldwide. International shipping costs and delivery times vary based on the destination. You can see the specific shipping options during checkout."
+ "How do I cancel my subscription?","To cancel your subscription, log in to your account, go to 'Subscription Settings', and click on 'Cancel Subscription'. Follow the prompts to complete the cancellation process."
+ "Is my personal information secure?","Yes, we take data security seriously. We use industry-standard encryption protocols to protect your personal information. We never share your data with third parties without your consent."
+ "How do I contact customer support?","You can contact our customer support team via email at support@example.com, through the live chat on our website, or by calling our support line at 1-800-123-4567 during business hours (9 AM - 5 PM EST, Monday-Friday)."
+ "Do you offer discounts for bulk orders?","Yes, we offer volume discounts for bulk orders. Please contact our sales team at sales@example.com with your requirements to get a custom quote."
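A quick, optional sanity check for this fallback file (a sketch; assumes the file is saved as `data/faq_data.csv`, the path `load_faq_data` expects):

```python
# Verify the fallback CSV parses and exposes the columns the loader assumes.
import pandas as pd

df = pd.read_csv("data/faq_data.csv")
assert {"question", "answer"} <= set(df.columns), "CSV needs question/answer columns"
print(f"{len(df)} FAQ rows loaded")
```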
notes.txt ADDED
@@ -0,0 +1,16 @@
+ # Create a virtual environment (recommended)
+ python -m venv venv
+ source venv/bin/activate  # On Windows: venv\Scripts\activate
+
+ # Install required packages
+ pip install torch transformers sentence-transformers faiss-cpu pandas streamlit nltk
+
+
+
+
+
+ # Activate your virtual environment if not already active
+ source venv/bin/activate  # On Windows: venv\Scripts\activate
+
+ # Run the Streamlit app
+ streamlit run app.py
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ torch>=2.0.0
+ transformers>=4.30.0
+ sentence-transformers>=2.2.2
+ faiss-cpu>=1.7.4
+ pandas>=1.5.0
+ streamlit>=1.25.0
+ numpy>=1.24.0
+ datasets>=2.10.0
+ bitsandbytes>=0.40.0
+ accelerate>=0.20.0
src/__init__.py ADDED
@@ -0,0 +1,2 @@
+
+ # This file is intentionally left empty to mark the directory as a Python package
src/data_processing.py ADDED
@@ -0,0 +1,93 @@
+ import pandas as pd
+ import json
+ from typing import List, Dict, Any
+ from datasets import load_dataset
+
+ def load_huggingface_faq_data(dataset_name: str = "NebulaByte/E-Commerce_FAQs") -> List[Dict[str, Any]]:
+     """
+     Load FAQ data from Hugging Face datasets
+     """
+     print(f"Loading dataset {dataset_name} from Hugging Face...")
+
+     try:
+         # Load the dataset
+         dataset = load_dataset(dataset_name)
+
+         # Get the train split
+         train_data = dataset["train"]
+
+         # Convert to list of dictionaries
+         faqs = []
+         for item in train_data:
+             # Extract the required fields
+             faq = {
+                 "question": item["question"],
+                 "answer": item["answer"],
+                 # Include additional metadata
+                 "category": item.get("category", ""),
+                 "question_id": item.get("question_id", ""),
+                 "faq_url": item.get("faq_url", "")
+             }
+             faqs.append(faq)
+
+         print(f"Loaded {len(faqs)} FAQ entries from Hugging Face")
+         return faqs
+
+     except Exception as e:
+         print(f"Error loading dataset from Hugging Face: {e}")
+         print("Falling back to local data...")
+         return load_faq_data("data/faq_data.csv")
+
+ def load_faq_data(file_path: str) -> List[Dict[str, Any]]:
+     """
+     Load FAQ data from a local CSV or JSON file
+     """
+     print(f"Loading data from {file_path}")
+     try:
+         if file_path.endswith('.csv'):
+             df = pd.read_csv(file_path)
+             # Assume CSV has 'question' and 'answer' columns
+             faqs = df.to_dict('records')
+         elif file_path.endswith('.json'):
+             with open(file_path, 'r') as f:
+                 faqs = json.load(f)
+         else:
+             raise ValueError(f"Unsupported file format: {file_path}")
+
+         print(f"Loaded {len(faqs)} FAQ entries")
+         return faqs
+     except Exception as e:
+         print(f"Error loading data: {e}")
+         # Create a minimal sample dataset as fallback
+         print("Creating sample dataset as fallback")
+         sample_faqs = [
+             {"question": "How do I track my order?",
+              "answer": "You can track your order by logging into your account and visiting the Order History section."},
+             {"question": "How do I reset my password?",
+              "answer": "To reset your password, click on the 'Forgot Password' link on the login page."}
+         ]
+         return sample_faqs
+
+ def preprocess_faq(faqs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+     """
+     Preprocess FAQ data: clean text, handle formatting
+     """
+     processed_faqs = []
+     for faq in faqs:
+         # Basic cleaning - remove extra whitespace
+         if 'question' in faq and faq['question'] is not None:
+             faq['question'] = faq['question'].strip()
+         else:
+             faq['question'] = ""
+
+         if 'answer' in faq and faq['answer'] is not None:
+             faq['answer'] = faq['answer'].strip()
+         else:
+             faq['answer'] = ""
+
+         # Only include FAQs with both question and answer
+         if faq.get('question') and faq.get('answer'):
+             processed_faqs.append(faq)
+
+     print(f"After preprocessing: {len(processed_faqs)} valid FAQ entries")
+     return processed_faqs
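For reference, the module can be exercised on its own roughly like this (a sketch, not part of the commit; the dataset download needs network access and falls back to `data/faq_data.csv` on failure):

```python
# Load the Hugging Face FAQ dataset, clean it, and inspect the first entry.
from src.data_processing import load_huggingface_faq_data, preprocess_faq

faqs = preprocess_faq(load_huggingface_faq_data("NebulaByte/E-Commerce_FAQs"))
print(len(faqs), "FAQs after preprocessing")
print(faqs[0]["question"], "->", faqs[0]["answer"][:80])
```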
src/embedding.py ADDED
@@ -0,0 +1,81 @@
+ from sentence_transformers import SentenceTransformer
+ import faiss
+ import numpy as np
+ from typing import List, Dict, Any
+ import torch
+ import gc
+
+ class FAQEmbedder:
+     def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
+         """
+         Initialize the FAQ embedder with a sentence transformer model
+         Optimized for memory efficiency
+         """
+         print(f"Initializing FAQ Embedder with model: {model_name}")
+         # Use CPU for the embedding model to save GPU memory for the LLM
+         self.device = "cpu"
+         self.model = SentenceTransformer(model_name, device=self.device)
+         self.index = None
+         self.faqs = None
+         self.embeddings = None
+
+     def create_embeddings(self, faqs: List[Dict[str, Any]], batch_size: int = 32) -> None:
+         """
+         Create embeddings for all FAQs and build the FAISS index
+         Uses batching for memory efficiency
+         """
+         self.faqs = faqs
+         print(f"Creating embeddings for {len(faqs)} FAQs in batches of {batch_size}...")
+
+         # Extract questions for embedding
+         questions = [faq['question'] for faq in faqs]
+
+         # Process in batches to reduce memory usage
+         all_embeddings = []
+         for i in range(0, len(questions), batch_size):
+             batch = questions[i:i+batch_size]
+             print(f"Processing batch {i//batch_size + 1}/{(len(questions) + batch_size - 1)//batch_size}")
+
+             # Create embeddings for this batch
+             batch_embeddings = self.model.encode(batch, show_progress_bar=False, convert_to_numpy=True)
+             all_embeddings.append(batch_embeddings)
+
+         # Combine all batches
+         self.embeddings = np.vstack(all_embeddings).astype('float32')
+
+         # Clear memory explicitly
+         all_embeddings = None
+         gc.collect()
+
+         # Create FAISS index
+         dimension = self.embeddings.shape[1]
+         self.index = faiss.IndexFlatL2(dimension)
+         self.index.add(self.embeddings)
+
+         print(f"Created embeddings of shape {self.embeddings.shape}")
+         print(f"FAISS index contains {self.index.ntotal} vectors")
+
+     def retrieve_relevant_faqs(self, query: str, k: int = 3) -> List[Dict[str, Any]]:
+         """
+         Retrieve the top-k relevant FAQs for a given query
+         """
+         if self.index is None or self.faqs is None or self.embeddings is None:
+             raise ValueError("Embeddings not created yet. Call create_embeddings first.")
+
+         # Embed the query
+         query_embedding = self.model.encode([query], convert_to_numpy=True).astype('float32')
+
+         # Search in FAISS
+         distances, indices = self.index.search(query_embedding, k)
+
+         # Get the relevant FAQs with their similarity scores
+         relevant_faqs = []
+         for i, idx in enumerate(indices[0]):
+             if idx < len(self.faqs):  # Ensure index is valid
+                 faq = self.faqs[idx].copy()
+                 # Convert L2 distance to a similarity score (higher is better)
+                 similarity = 1.0 / (1.0 + distances[0][i])
+                 faq['similarity'] = similarity
+                 relevant_faqs.append(faq)
+
+         return relevant_faqs
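A minimal retrieval round-trip with this class might look like the following sketch (not part of the commit; the toy FAQs are hypothetical and the first run downloads the `all-MiniLM-L6-v2` model):

```python
from src.embedding import FAQEmbedder

faqs = [
    {"question": "How do I track my order?", "answer": "Use the Order History page."},
    {"question": "What is your return policy?", "answer": "Returns are accepted within 30 days."},
]
embedder = FAQEmbedder()          # loads all-MiniLM-L6-v2 on CPU
embedder.create_embeddings(faqs)  # builds the FAISS IndexFlatL2
for hit in embedder.retrieve_relevant_faqs("Where is my package?", k=1):
    print(f"{hit['similarity']:.2f}  {hit['question']}")
```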
src/llm_response.py ADDED
@@ -0,0 +1,138 @@
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+ import torch
+ from typing import List, Dict, Any
+ import gc
+
+ class ResponseGenerator:
+     def __init__(self, model_name: str = "mistralai/Mistral-7B-Instruct-v0.1"):
+         """
+         Initialize the response generator with an LLM
+         Optimized for an 8-11GB GPU
+         """
+         print(f"Loading LLM: {model_name}")
+         print("This may take a few minutes...")
+
+         # Load tokenizer
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+         # Configure device and data type based on available resources
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         print(f"Using device: {device}")
+
+         # Free up memory before loading the model
+         gc.collect()
+         if device == "cuda":
+             torch.cuda.empty_cache()
+
+         # Configure 4-bit quantization for maximum memory efficiency
+         try:
+             # Use 4-bit quantization for models that support it
+             quantization_config = BitsAndBytesConfig(
+                 load_in_4bit=True,
+                 bnb_4bit_compute_dtype=torch.float16,
+                 bnb_4bit_use_double_quant=True,
+                 bnb_4bit_quant_type="nf4"
+             )
+
+             # Load the model with quantization
+             self.model = AutoModelForCausalLM.from_pretrained(
+                 model_name,
+                 quantization_config=quantization_config,
+                 device_map="auto",
+                 torch_dtype=torch.float16,
+                 # Load the model in parts to avoid OOM errors
+                 max_memory={0: "8GiB", "cpu": "16GiB"},
+                 offload_folder="offload",
+                 offload_state_dict=True,  # Offload weights to CPU when not in use
+                 low_cpu_mem_usage=True
+             )
+         except Exception as e:
+             print(f"4-bit quantization error: {e}")
+             print("Falling back to 8-bit quantization...")
+
+             try:
+                 # Try 8-bit quantization
+                 quantization_config = BitsAndBytesConfig(
+                     load_in_8bit=True,
+                     bnb_8bit_use_double_quant=True
+                 )
+
+                 self.model = AutoModelForCausalLM.from_pretrained(
+                     model_name,
+                     quantization_config=quantization_config,
+                     device_map="auto",
+                     torch_dtype=torch.float16,
+                     max_memory={0: "8GiB", "cpu": "16GiB"},
+                     offload_folder="offload",
+                     low_cpu_mem_usage=True
+                 )
+             except Exception as e2:
+                 print(f"8-bit quantization error: {e2}")
+                 print("Falling back to smaller model...")
+
+                 # Use a much smaller model as fallback
+                 backup_model = "microsoft/phi-2"
+                 self.tokenizer = AutoTokenizer.from_pretrained(backup_model)
+                 self.model = AutoModelForCausalLM.from_pretrained(
+                     backup_model,
+                     device_map="auto",
+                     torch_dtype=torch.float16 if device == "cuda" else torch.float32
+                 )
+
+         print("LLM loaded successfully")
+
+     def generate_response(self, query: str, relevant_faqs: List[Dict[str, Any]]) -> str:
+         """
+         Generate a response using the LLM with retrieved FAQs as context
+         Memory-optimized version
+         """
+         # Create prompt with relevant FAQs
+         prompt = self._create_prompt(query, relevant_faqs)
+
+         # Generate response with memory-efficient settings
+         inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
+
+         with torch.no_grad():
+             # Use more conservative generation parameters
+             outputs = self.model.generate(
+                 **inputs,
+                 max_new_tokens=200,  # Shorter response for memory efficiency
+                 temperature=0.7,
+                 top_p=0.9,
+                 do_sample=True,
+                 pad_token_id=self.tokenizer.eos_token_id
+             )
+
+         response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+         # Extract just the response part (after the prompt)
+         response = response[len(prompt):].strip()
+
+         # Clear GPU memory after generating the response
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+         return response
+
+     def _create_prompt(self, query: str, relevant_faqs: List[Dict[str, Any]]) -> str:
+         """
+         Create a prompt for the LLM with retrieved FAQs as context
+         """
+         # Format FAQs in a way that's suitable for the model
+         faq_context = "\n\n".join([
+             f"Q: {faq['question']}\nA: {faq['answer']}"
+             for faq in relevant_faqs
+         ])
+
+         # Create the prompt
+         prompt = f"""
+ Below are some relevant e-commerce customer support FAQ entries:
+
+ {faq_context}
+
+ Based on the information above, please provide a helpful, accurate, and concise response to the following customer query:
+ Customer Query: {query}
+
+ Response:
+ """
+         return prompt
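Standalone usage follows the same pattern the app uses, as in this sketch (not part of the commit; the lighter `microsoft/phi-2` is passed here to avoid the large Mistral-7B download, and the FAQ entry is hypothetical):

```python
from src.llm_response import ResponseGenerator

generator = ResponseGenerator(model_name="microsoft/phi-2")
faqs = [{
    "question": "How do I track my order?",
    "answer": "Log into your account and open the Order History section.",
    "similarity": 0.92,
}]
print(generator.generate_response("Where can I see my delivery status?", faqs))
```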
src/utils.py ADDED
@@ -0,0 +1,59 @@
+ import time
+ import functools
+ from typing import Callable, Any, Dict
+
+ def time_function(func: Callable) -> Callable:
+     """
+     Decorator to time function execution
+     """
+     @functools.wraps(func)
+     def wrapper(*args, **kwargs):
+         start_time = time.time()
+         result = func(*args, **kwargs)
+         end_time = time.time()
+         print(f"{func.__name__} took {end_time - start_time:.2f} seconds to execute")
+         return result
+     return wrapper
+
+ def evaluate_response(generated_response: str, ground_truth: str = None) -> Dict[str, Any]:
+     """
+     Basic evaluation of a generated response
+     """
+     results = {
+         "length": len(generated_response),
+         "word_count": len(generated_response.split())
+     }
+
+     # If ground truth is provided, we could add metrics like BLEU or ROUGE
+     if ground_truth:
+         # Simplified evaluation - word overlap
+         generated_words = set(generated_response.lower().split())
+         ground_truth_words = set(ground_truth.lower().split())
+         overlap = len(generated_words.intersection(ground_truth_words))
+         results["word_overlap"] = overlap / len(ground_truth_words) if ground_truth_words else 0
+
+     return results
+
+ def format_memory_stats():
+     """
+     Format memory usage statistics for display
+     """
+     import torch
+     import psutil
+
+     # System memory
+     system_stats = {
+         "RAM": f"{psutil.virtual_memory().used / (1024 ** 3):.1f}GB / {psutil.virtual_memory().total / (1024 ** 3):.1f}GB",
+         "RAM Usage": f"{psutil.virtual_memory().percent}%",
+     }
+
+     # GPU memory if available
+     if torch.cuda.is_available():
+         gpu_stats = {}
+         for i in range(torch.cuda.device_count()):
+             gpu_stats[f"GPU {i}"] = f"{torch.cuda.get_device_name(i)}"
+             gpu_stats[f"GPU {i} Memory"] = f"{torch.cuda.memory_allocated(i) / (1024 ** 3):.1f}GB / {torch.cuda.get_device_properties(i).total_memory / (1024 ** 3):.1f}GB"
+         system_stats.update(gpu_stats)
+
+     return system_stats
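These helpers are used by `app.py`, but they can also be tried directly, as in this sketch (not part of the commit; `format_memory_stats` requires `psutil`, as noted in the README, and the sample strings are hypothetical):

```python
from src.utils import time_function, evaluate_response, format_memory_stats

@time_function
def slow_add(a, b):
    return a + b

slow_add(1, 2)  # prints the elapsed time via the decorator
print(evaluate_response("Track it in Order History", "Use the Order History page"))
print(format_memory_stats())  # RAM (and GPU, if present) usage strings
```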