initial commit
- .gitignore +3 -0
- README.md +100 -12
- app.py +235 -0
- data/faq_data.csv +11 -0
- notes.txt +16 -0
- requirements.txt +10 -0
- src/__init__.py +2 -0
- src/data_processing.py +93 -0
- src/embedding.py +81 -0
- src/llm_response.py +138 -0
- src/utils.py +59 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
# Ignore the virtual environment and Python cache directories
venv
__pycache__
README.md
CHANGED
@@ -1,12 +1,100 @@
# FAQ Chatbot Using RAG for Customer Support - Setup Instructions

Follow these steps to set up and run the e-commerce FAQ chatbot, optimized for hardware with 16-19GB RAM and an 8-11GB GPU.

## Prerequisites

- Python 3.8 or higher
- CUDA-compatible GPU with 8-11GB VRAM
- 16-19GB RAM
- Internet connection (for downloading models and datasets)

## Step 1: Create Project Directory Structure

```bash
# Create the project directory
mkdir faq-rag-chatbot
cd faq-rag-chatbot

# Create the source and data directories
mkdir -p src data
```

## Step 2: Create Virtual Environment

```bash
# Create a virtual environment
python -m venv venv

# Activate the virtual environment
# On Windows:
venv\Scripts\activate
# On macOS/Linux:
source venv/bin/activate
```

## Step 3: Create Project Files

Create all the required files with the optimized code provided (the resulting layout is shown after this list):

1. `requirements.txt`
2. `src/__init__.py`
3. `src/data_processing.py`
4. `src/embedding.py`
5. `src/llm_response.py`
6. `src/utils.py`
7. `app.py`

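After this step the project should look roughly like this (`data/faq_data.csv` is the local fallback dataset included in the repository):

```text
faq-rag-chatbot/
├── app.py
├── requirements.txt
├── data/
│   └── faq_data.csv
└── src/
    ├── __init__.py
    ├── data_processing.py
    ├── embedding.py
    ├── llm_response.py
    └── utils.py
```
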
## Step 4: Install Dependencies

```bash
# Install required packages
pip install -r requirements.txt

# Additional dependency for memory monitoring
pip install psutil
```
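
Optionally, verify that PyTorch can see your GPU before launching the app. This is a minimal sanity check (run it in a Python shell inside the activated environment); it only uses standard PyTorch calls and is not part of the project files:

```python
# Optional sanity check: confirm PyTorch detects the GPU and report its VRAM
import torch

print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"Total VRAM: {total_gb:.1f} GiB")
```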

## Step 5: Run the Application

```bash
# Make sure the virtual environment is activated
# Run the Streamlit app
streamlit run app.py
```

Streamlit prints a local URL (http://localhost:8501 by default); open it in your browser to use the chatbot.

## Memory Optimization Notes

This implementation includes several optimizations for systems with 16-19GB RAM and an 8-11GB GPU:

1. **Default to Smaller Models**: The app defaults to Phi-2, which works well on 8GB GPUs
2. **4-bit Quantization**: Uses 4-bit quantization for larger models like Mistral-7B (see the sketch after this list)
3. **Memory Offloading**: Offloads weights to CPU when not in use
4. **Batch Processing**: Processes embeddings in smaller batches
5. **Garbage Collection**: Aggressively frees memory after operations
6. **Response Length Limits**: Generates shorter responses to save memory
7. **CPU Embedding**: Keeps the embedding model on CPU to save GPU memory for the LLM

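For reference, items 2 and 3 correspond to the model-loading configuration in `src/llm_response.py`. A trimmed sketch of that setup (same parameters as the source, shortened for readability):

```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import torch

# NF4 4-bit quantization with double quantization, as configured in src/llm_response.py
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.1",
    quantization_config=quantization_config,
    device_map="auto",                        # split layers across GPU and CPU as needed
    max_memory={0: "8GiB", "cpu": "16GiB"},   # cap GPU usage, offload the rest to RAM
    low_cpu_mem_usage=True
)
```
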
## Using the Chatbot

1. The application will automatically download the e-commerce FAQ dataset from Hugging Face (see the sketch after this list)
2. Choose an appropriate model based on your available GPU memory:
   - For 8GB GPU: Use Phi-2 (default)
   - For 10-11GB GPU: You can try Mistral-7B with 4-bit quantization
   - For limited GPU or CPU-only: Use TinyLlama-1.1B
3. Type a question or select a sample question
4. The system will retrieve relevant FAQs and generate a response
5. Monitor memory usage in the sidebar

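The dataset download in step 1 is handled by `load_huggingface_faq_data` in `src/data_processing.py`; stripped of error handling and metadata fields, it amounts to roughly this:

```python
from datasets import load_dataset

# Download (and cache) the e-commerce FAQ dataset, keeping the fields the chatbot uses
dataset = load_dataset("NebulaByte/E-Commerce_FAQs")
faqs = [{"question": item["question"], "answer": item["answer"]}
        for item in dataset["train"]]
print(f"Loaded {len(faqs)} FAQ entries")
```
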
## Troubleshooting

- **Out of Memory Errors**: If you encounter CUDA out-of-memory errors, switch to a smaller model such as TinyLlama-1.1B
- **Slow Response Times**: The first response may be slow while the model loads; subsequent responses will be faster
- **Model Loading Issues**: If Mistral-7B fails to load, the system will automatically fall back to Phi-2

## Performance Considerations

- The embedding and retrieval components work efficiently even on limited hardware
- Response generation speed depends on the model size and available GPU memory
- For optimal performance with an 8GB GPU, stick with the Phi-2 model
- For faster responses at the cost of some accuracy, use TinyLlama-1.1B
app.py
ADDED
@@ -0,0 +1,235 @@
import streamlit as st
import time
import os
import gc
import torch
from src.data_processing import load_huggingface_faq_data, load_faq_data, preprocess_faq
from src.embedding import FAQEmbedder
from src.llm_response import ResponseGenerator
from src.utils import time_function, format_memory_stats

# Set page title and layout
st.set_page_config(
    page_title="E-Commerce FAQ Chatbot",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Memory optimization: Force garbage collection before starting
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

@time_function
def initialize_components(use_huggingface: bool = True, model_name: str = "mistralai/Mistral-7B-Instruct-v0.1"):
    """Initialize all components of the RAG system with memory optimization"""

    # Step 1: Load and preprocess FAQ data
    if use_huggingface:
        faqs = load_huggingface_faq_data("NebulaByte/E-Commerce_FAQs")
    else:
        data_path = os.path.join("data", "faq_data.csv")
        faqs = load_faq_data(data_path)

    processed_faqs = preprocess_faq(faqs)

    # Step 2: Initialize and create embeddings
    # Use smaller batch size for memory efficiency
    embedder = FAQEmbedder()
    embedder.create_embeddings(processed_faqs, batch_size=32)

    # Clear memory before loading the LLM
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Step 3: Initialize response generator
    response_generator = ResponseGenerator(model_name=model_name)

    return embedder, response_generator, len(processed_faqs)

def main():
    st.title("E-Commerce Customer Support FAQ Chatbot")
    st.subheader("Ask any question about your orders, shipping, returns, or any other e-commerce related queries")

    # Sidebar configuration
    st.sidebar.title("Configuration")
    use_huggingface = st.sidebar.checkbox("Use Hugging Face Dataset", value=True)

    # Model options - include smaller models by default
    model_options = {
        "Phi-2 (Recommended for 8GB GPU)": "microsoft/phi-2",
        "TinyLlama-1.1B (Smallest, fastest)": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        "Mistral-7B (Requires 4-bit quantization)": "mistralai/Mistral-7B-Instruct-v0.1"
    }

    # Default to Phi-2 for 8-11GB GPU
    selected_model = st.sidebar.selectbox("Select LLM Model", list(model_options.keys()), index=0)
    model_name = model_options[selected_model]

    # Memory usage monitoring
    if st.sidebar.checkbox("Show Memory Usage", value=True):
        st.sidebar.subheader("Memory Usage")
        memory_stats = format_memory_stats()
        for key, value in memory_stats.items():
            st.sidebar.text(f"{key}: {value}")

    # Initialize session state for chat history if it doesn't exist
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []

    # Initialize RAG components (only once)
    if "system_initialized" not in st.session_state or st.sidebar.button("Reload System"):
        with st.spinner("Initializing system components... This may take a few minutes."):
            st.session_state.embedder, st.session_state.response_generator, num_faqs = initialize_components(
                use_huggingface=use_huggingface,
                model_name=model_name
            )
            st.session_state.system_initialized = True
            st.sidebar.success(f"System initialized with {num_faqs} FAQs!")

    # Chat interface
    col1, col2 = st.columns([2, 1])

    with col1:
        # Display chat history
        st.subheader("Conversation")
        chat_container = st.container(height=400)

        with chat_container:
            for i, message in enumerate(st.session_state.chat_history):
                if message["role"] == "user":
                    st.markdown(f"**You**: {message['content']}")
                else:
                    st.markdown(f"**Bot**: {message['content']}")

                if i < len(st.session_state.chat_history) - 1:
                    st.markdown("---")

        # Chat input
        with st.form(key="chat_form"):
            user_query = st.text_input("Type your question:", key="user_input",
                                       placeholder="e.g., How do I track my order?")
            submit_button = st.form_submit_button("Ask")

    with col2:
        if st.session_state.get("system_initialized", False):
            # Show FAQ metadata and information
            st.subheader("Retrieved Information")
            info_container = st.container(height=500)

            with info_container:
                if "current_faqs" in st.session_state:
                    for i, faq in enumerate(st.session_state.get("current_faqs", [])):
                        st.markdown(f"**Relevant FAQ #{i+1}**")
                        st.markdown(f"**Q**: {faq['question']}")
                        # Limit answer length to save UI memory
                        st.markdown(f"**A**: {faq['answer'][:150]}..." if len(faq['answer']) > 150 else f"**A**: {faq['answer']}")
                        st.markdown(f"*Similarity Score*: {faq['similarity']:.2f}")
                        if 'category' in faq and faq['category']:
                            st.markdown(f"*Category*: {faq['category']}")
                        st.markdown("---")
                else:
                    st.markdown("Ask a question to see relevant FAQs here.")

    # Performance metrics in the sidebar
    if "retrieval_time" in st.session_state and "generation_time" in st.session_state:
        st.sidebar.subheader("Performance Metrics")
        st.sidebar.markdown(f"Retrieval time: {st.session_state.retrieval_time:.2f} seconds")
        st.sidebar.markdown(f"Response generation: {st.session_state.generation_time:.2f} seconds")
        st.sidebar.markdown(f"Total time: {st.session_state.retrieval_time + st.session_state.generation_time:.2f} seconds")

    # Process user query
    if submit_button and user_query:
        # Add user query to chat history
        st.session_state.chat_history.append({"role": "user", "content": user_query})

        # Process query
        with st.spinner("Thinking..."):
            # Free memory before processing
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # Step 1: Retrieve relevant FAQs
            start_time = time.time()
            relevant_faqs = st.session_state.embedder.retrieve_relevant_faqs(user_query)
            retrieval_time = time.time() - start_time

            # Step 2: Generate response
            start_time = time.time()
            response = st.session_state.response_generator.generate_response(user_query, relevant_faqs)
            generation_time = time.time() - start_time

            # Store metrics and retrieved FAQs
            st.session_state.retrieval_time = retrieval_time
            st.session_state.generation_time = generation_time
            st.session_state.current_faqs = relevant_faqs

            # Step 3: Add response to chat history
            st.session_state.chat_history.append({"role": "assistant", "content": response})

            # Free memory after processing
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # Rerun to display the updated chat history
        st.experimental_rerun()

    # Add sample questions at the bottom
    st.subheader("Sample Questions")
    sample_questions = [
        "How do I track my order?",
        "What should I do if my delivery is delayed?",
        "How do I return a product?",
        "Can I cancel my order after placing it?",
        "How quickly will my order be delivered?",
        "Why can't I track my order yet?"
    ]

    # Use two columns instead of three to reduce memory usage
    cols = st.columns(2)
    for i, question in enumerate(sample_questions):
        col_idx = i % 2
        if cols[col_idx].button(question, key=f"sample_{i}"):
            # Simulate form submission by appending the question directly.
            # (The text_input's session state is not modified here, because Streamlit
            # does not allow changing a widget's state after the widget has been
            # instantiated in the current run.)
            st.session_state.chat_history.append({"role": "user", "content": question})

            # Process query (similar to above)
            with st.spinner("Thinking..."):
                # Free memory before processing
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

                # Step 1: Retrieve relevant FAQs
                start_time = time.time()
                relevant_faqs = st.session_state.embedder.retrieve_relevant_faqs(question)
                retrieval_time = time.time() - start_time

                # Step 2: Generate response
                start_time = time.time()
                response = st.session_state.response_generator.generate_response(question, relevant_faqs)
                generation_time = time.time() - start_time

                # Store metrics and retrieved FAQs
                st.session_state.retrieval_time = retrieval_time
                st.session_state.generation_time = generation_time
                st.session_state.current_faqs = relevant_faqs

                # Step 3: Add response to chat history
                st.session_state.chat_history.append({"role": "assistant", "content": response})

                # Free memory after processing
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            # Rerun to display the updated chat history
            st.experimental_rerun()

if __name__ == "__main__":
    main()
data/faq_data.csv
ADDED
@@ -0,0 +1,11 @@
question,answer
"How do I reset my password?","To reset your password, click on the 'Forgot Password' link on the login page. You'll receive an email with instructions to set a new password."
"What payment methods do you accept?","We accept Visa, Mastercard, American Express, and PayPal. All payments are processed securely through our payment gateway."
"How do I track my order?","You can track your order by logging into your account and visiting the 'Order History' section. Click on the specific order to see its current status and tracking information."
"What is your return policy?","Our return policy allows returns within 30 days of purchase. Items must be in original condition with tags attached. Please fill out the return form in your order confirmation email to start the process."
"How long does shipping take?","Standard shipping typically takes 3-5 business days within the continental US. Express shipping (1-2 business days) is available for an additional fee. International shipping times vary by destination."
"Do you offer international shipping?","Yes, we ship to most countries worldwide. International shipping costs and delivery times vary based on the destination. You can see the specific shipping options during checkout."
"How do I cancel my subscription?","To cancel your subscription, log in to your account, go to 'Subscription Settings', and click on 'Cancel Subscription'. Follow the prompts to complete the cancellation process."
"Is my personal information secure?","Yes, we take data security seriously. We use industry-standard encryption protocols to protect your personal information. We never share your data with third parties without your consent."
"How do I contact customer support?","You can contact our customer support team via email at support@example.com, through the live chat on our website, or by calling our support line at 1-800-123-4567 during business hours (9 AM - 5 PM EST, Monday-Friday)."
"Do you offer discounts for bulk orders?","Yes, we offer volume discounts for bulk orders. Please contact our sales team at sales@example.com with your requirements to get a custom quote."
notes.txt
ADDED
@@ -0,0 +1,16 @@
# Create a virtual environment (recommended)
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate

# Install required packages
pip install torch transformers sentence-transformers faiss-cpu pandas streamlit nltk

# Activate your virtual environment if not already active
source venv/bin/activate  # On Windows: venv\Scripts\activate

# Run the Streamlit app
streamlit run app.py
requirements.txt
ADDED
@@ -0,0 +1,10 @@
torch>=2.0.0
transformers>=4.30.0
sentence-transformers>=2.2.2
faiss-cpu>=1.7.4
pandas>=1.5.0
streamlit>=1.25.0
numpy>=1.24.0
datasets>=2.10.0
bitsandbytes>=0.40.0
accelerate>=0.20.0
src/__init__.py
ADDED
@@ -0,0 +1,2 @@
# This file is intentionally left empty to mark the directory as a Python package
src/data_processing.py
ADDED
@@ -0,0 +1,93 @@
import pandas as pd
import json
from typing import List, Dict, Any
from datasets import load_dataset

def load_huggingface_faq_data(dataset_name: str = "NebulaByte/E-Commerce_FAQs") -> List[Dict[str, Any]]:
    """
    Load FAQ data from Hugging Face datasets
    """
    print(f"Loading dataset {dataset_name} from Hugging Face...")

    try:
        # Load the dataset
        dataset = load_dataset(dataset_name)

        # Get the train split
        train_data = dataset["train"]

        # Convert to list of dictionaries
        faqs = []
        for item in train_data:
            # Extract the required fields
            faq = {
                "question": item["question"],
                "answer": item["answer"],
                # Include additional metadata
                "category": item.get("category", ""),
                "question_id": item.get("question_id", ""),
                "faq_url": item.get("faq_url", "")
            }
            faqs.append(faq)

        print(f"Loaded {len(faqs)} FAQ entries from Hugging Face")
        return faqs

    except Exception as e:
        print(f"Error loading dataset from Hugging Face: {e}")
        print("Falling back to local data...")
        return load_faq_data("data/faq_data.csv")

def load_faq_data(file_path: str) -> List[Dict[str, Any]]:
    """
    Load FAQ data from a local CSV or JSON file
    """
    print(f"Loading data from {file_path}")
    try:
        if file_path.endswith('.csv'):
            df = pd.read_csv(file_path)
            # Assume CSV has 'question' and 'answer' columns
            faqs = df.to_dict('records')
        elif file_path.endswith('.json'):
            with open(file_path, 'r') as f:
                faqs = json.load(f)
        else:
            raise ValueError(f"Unsupported file format: {file_path}")

        print(f"Loaded {len(faqs)} FAQ entries")
        return faqs
    except Exception as e:
        print(f"Error loading data: {e}")
        # Create a minimal sample dataset as fallback
        print("Creating sample dataset as fallback")
        sample_faqs = [
            {"question": "How do I track my order?",
             "answer": "You can track your order by logging into your account and visiting the Order History section."},
            {"question": "How do I reset my password?",
             "answer": "To reset your password, click on the 'Forgot Password' link on the login page."}
        ]
        return sample_faqs

def preprocess_faq(faqs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Preprocess FAQ data: clean text, handle formatting
    """
    processed_faqs = []
    for faq in faqs:
        # Basic cleaning - remove extra whitespace
        if 'question' in faq and faq['question'] is not None:
            faq['question'] = faq['question'].strip()
        else:
            faq['question'] = ""

        if 'answer' in faq and faq['answer'] is not None:
            faq['answer'] = faq['answer'].strip()
        else:
            faq['answer'] = ""

        # Only include FAQs with both question and answer
        if faq.get('question') and faq.get('answer'):
            processed_faqs.append(faq)

    print(f"After preprocessing: {len(processed_faqs)} valid FAQ entries")
    return processed_faqs
src/embedding.py
ADDED
@@ -0,0 +1,81 @@
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from typing import List, Dict, Any
import torch
import gc

class FAQEmbedder:
    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        """
        Initialize the FAQ embedder with a sentence transformer model
        Optimized for memory efficiency
        """
        print(f"Initializing FAQ Embedder with model: {model_name}")
        # Use CPU for embedding model to save GPU memory for LLM
        self.device = "cpu"
        self.model = SentenceTransformer(model_name, device=self.device)
        self.index = None
        self.faqs = None
        self.embeddings = None

    def create_embeddings(self, faqs: List[Dict[str, Any]], batch_size: int = 32) -> None:
        """
        Create embeddings for all FAQs and build FAISS index
        Using batching for memory efficiency
        """
        self.faqs = faqs
        print(f"Creating embeddings for {len(faqs)} FAQs in batches of {batch_size}...")

        # Extract questions for embedding
        questions = [faq['question'] for faq in faqs]

        # Process in batches to reduce memory usage
        all_embeddings = []
        for i in range(0, len(questions), batch_size):
            batch = questions[i:i+batch_size]
            print(f"Processing batch {i//batch_size + 1}/{(len(questions) + batch_size - 1)//batch_size}")

            # Create embeddings for this batch
            batch_embeddings = self.model.encode(batch, show_progress_bar=False, convert_to_numpy=True)
            all_embeddings.append(batch_embeddings)

        # Combine all batches
        self.embeddings = np.vstack(all_embeddings).astype('float32')

        # Clear memory explicitly
        all_embeddings = None
        gc.collect()

        # Create FAISS index
        dimension = self.embeddings.shape[1]
        self.index = faiss.IndexFlatL2(dimension)
        self.index.add(self.embeddings)

        print(f"Created embeddings of shape {self.embeddings.shape}")
        print(f"FAISS index contains {self.index.ntotal} vectors")

    def retrieve_relevant_faqs(self, query: str, k: int = 3) -> List[Dict[str, Any]]:
        """
        Retrieve top-k relevant FAQs for a given query
        """
        if self.index is None or self.faqs is None or self.embeddings is None:
            raise ValueError("Embeddings not created yet. Call create_embeddings first.")

        # Embed the query
        query_embedding = self.model.encode([query], convert_to_numpy=True).astype('float32')

        # Search in FAISS
        distances, indices = self.index.search(query_embedding, k)

        # Get the relevant FAQs with their similarity scores
        relevant_faqs = []
        for i, idx in enumerate(indices[0]):
            if idx < len(self.faqs):  # Ensure index is valid
                faq = self.faqs[idx].copy()
                # Convert L2 distance to similarity score (higher is better)
                similarity = 1.0 / (1.0 + distances[0][i])
                faq['similarity'] = similarity
                relevant_faqs.append(faq)

        return relevant_faqs
src/llm_response.py
ADDED
@@ -0,0 +1,138 @@
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
from typing import List, Dict, Any
import gc

class ResponseGenerator:
    def __init__(self, model_name: str = "mistralai/Mistral-7B-Instruct-v0.1"):
        """
        Initialize the response generator with an LLM
        Optimized for 8-11GB GPU
        """
        print(f"Loading LLM: {model_name}")
        print("This may take a few minutes...")

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Configure device and data type based on available resources
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {device}")

        # Free up memory before loading model
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()

        # Configure 4-bit quantization for maximum memory efficiency
        try:
            # Use 4-bit quantization for models that support it
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )

            # Load the model with quantization
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                quantization_config=quantization_config,
                device_map="auto",
                torch_dtype=torch.float16,
                # Load model in parts to avoid OOM errors
                max_memory={0: "8GiB", "cpu": "16GiB"},
                offload_folder="offload",
                offload_state_dict=True,  # Offload weights to CPU when not in use
                low_cpu_mem_usage=True
            )
        except Exception as e:
            print(f"4-bit quantization error: {e}")
            print("Falling back to 8-bit quantization...")

            try:
                # Try 8-bit quantization
                # (double quantization is a 4-bit-only option, so it is not set here)
                quantization_config = BitsAndBytesConfig(
                    load_in_8bit=True
                )

                self.model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    quantization_config=quantization_config,
                    device_map="auto",
                    torch_dtype=torch.float16,
                    max_memory={0: "8GiB", "cpu": "16GiB"},
                    offload_folder="offload",
                    low_cpu_mem_usage=True
                )
            except Exception as e2:
                print(f"8-bit quantization error: {e2}")
                print("Falling back to smaller model...")

                # Use a much smaller model as fallback
                backup_model = "microsoft/phi-2"
                self.tokenizer = AutoTokenizer.from_pretrained(backup_model)
                self.model = AutoModelForCausalLM.from_pretrained(
                    backup_model,
                    device_map="auto",
                    torch_dtype=torch.float16 if device == "cuda" else torch.float32
                )

        print("LLM loaded successfully")

    def generate_response(self, query: str, relevant_faqs: List[Dict[str, Any]]) -> str:
        """
        Generate a response using the LLM with retrieved FAQs as context
        Memory-optimized version
        """
        # Create prompt with relevant FAQs
        prompt = self._create_prompt(query, relevant_faqs)

        # Generate response with memory-efficient settings
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        with torch.no_grad():
            # Use more conservative generation parameters
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=200,  # Shorter response for memory efficiency
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id
            )

        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Extract just the response part (after the prompt)
        response = response[len(prompt):].strip()

        # Clear GPU memory after generating response
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return response

    def _create_prompt(self, query: str, relevant_faqs: List[Dict[str, Any]]) -> str:
        """
        Create a prompt for the LLM with retrieved FAQs as context
        """
        # Format FAQs in a way that's suitable for the model
        faq_context = "\n\n".join([
            f"Q: {faq['question']}\nA: {faq['answer']}"
            for faq in relevant_faqs
        ])

        # Create the prompt
        prompt = f"""
Below are some relevant e-commerce customer support FAQ entries:

{faq_context}

Based on the information above, please provide a helpful, accurate, and concise response to the following customer query:
Customer Query: {query}

Response:
"""
        return prompt
src/utils.py
ADDED
@@ -0,0 +1,59 @@
import time
import functools
from typing import Callable, Any, Dict

def time_function(func: Callable) -> Callable:
    """
    Decorator to time function execution
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.time()
        result = func(*args, **kwargs)
        end_time = time.time()
        print(f"{func.__name__} took {end_time - start_time:.2f} seconds to execute")
        return result
    return wrapper

def evaluate_response(generated_response: str, ground_truth: str = None) -> Dict[str, Any]:
    """
    Basic evaluation of generated response
    """
    results = {
        "length": len(generated_response),
        "word_count": len(generated_response.split())
    }

    # If ground truth is provided, we could add metrics like BLEU or ROUGE
    if ground_truth:
        # Simplified evaluation - word overlap
        generated_words = set(generated_response.lower().split())
        ground_truth_words = set(ground_truth.lower().split())
        overlap = len(generated_words.intersection(ground_truth_words))
        results["word_overlap"] = overlap / len(ground_truth_words) if ground_truth_words else 0

    return results

def format_memory_stats():
    """
    Format memory usage statistics for display
    """
    import torch
    import psutil

    # System memory
    system_stats = {
        "RAM": f"{psutil.virtual_memory().used / (1024 ** 3):.1f}GB / {psutil.virtual_memory().total / (1024 ** 3):.1f}GB",
        "RAM Usage": f"{psutil.virtual_memory().percent}%",
    }

    # GPU memory if available
    if torch.cuda.is_available():
        gpu_stats = {}
        for i in range(torch.cuda.device_count()):
            gpu_stats[f"GPU {i}"] = f"{torch.cuda.get_device_name(i)}"
            gpu_stats[f"GPU {i} Memory"] = f"{torch.cuda.memory_allocated(i) / (1024 ** 3):.1f}GB / {torch.cuda.get_device_properties(i).total_memory / (1024 ** 3):.1f}GB"
        system_stats.update(gpu_stats)

    return system_stats