minor changes

Files changed:
- .gitignore +22 -3
- app.py +135 -120
- data/faq_data.csv +6 -1
- notes.txt +0 -16
- requirements.txt +8 -2
- src/__init__.py +1 -2
- src/data_processing.py +98 -44
- src/embedding.py +31 -23
- src/llm_response.py +35 -69
- src/utils.py +54 -13
- test_set.json +7 -0
.gitignore
CHANGED
@@ -1,3 +1,22 @@
# Python virtual environment
venv/

# Python cache files
__pycache__/

# Data and model files
data/ecommerce_faqs.json
embeddings/
offload/
*.bin

# Feedback data
feedback.json

# NLTK data
nltk_data/

# Streamlit cache
.venv/
streamlit_cache/
app.py
CHANGED
@@ -3,233 +3,248 @@ import time
import os
import gc
import torch
from src.data_processing import load_huggingface_faq_data, load_faq_data, preprocess_faq, augment_faqs
from src.embedding import FAQEmbedder
from src.llm_response import ResponseGenerator
from src.utils import time_function, format_memory_stats, evaluate_response, evaluate_retrieval, baseline_keyword_search

# Suppress CUDA warning and Torch path errors
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["TORCH_NO_PATH_CHECK"] = "1"

st.set_page_config(page_title="E-Commerce FAQ Chatbot", layout="wide", initial_sidebar_state="expanded")

@time_function
def initialize_components(use_huggingface: bool = True, model_name: str = "microsoft/phi-2", enable_augmentation: bool = True):
    """
    Initialize RAG system components
    """
    try:
        if use_huggingface:
            faqs = load_huggingface_faq_data("NebulaByte/E-Commerce_FAQs")
        else:
            faqs = load_faq_data("data/faq_data.csv")

        processed_faqs = augment_faqs(preprocess_faq(faqs), enable_augmentation=enable_augmentation)
        embedder = FAQEmbedder()

        if os.path.exists("embeddings"):
            embedder.load("embeddings")
        else:
            embedder.create_embeddings(processed_faqs)
            embedder.save("embeddings")

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        response_generator = ResponseGenerator(model_name=model_name)
        response_generator.generate_response("Warmup query", [{"question": "Test", "answer": "Test"}])

        return embedder, response_generator, len(processed_faqs)
    except Exception as e:
        st.error(f"Initialization failed: {e}")
        raise

def main():
    st.title("E-Commerce Customer Support FAQ Chatbot")
    st.subheader("Ask about orders, shipping, returns, or other e-commerce queries")

    st.sidebar.title("Configuration")
    use_huggingface = st.sidebar.checkbox("Use Hugging Face Dataset", value=True)
    enable_augmentation = st.sidebar.checkbox("Enable FAQ Augmentation", value=True, help="Generate paraphrased questions to expand dataset")
    target_lang = st.sidebar.selectbox("Language", ["en", "es", "fr"], index=0)

    model_options = {
        "Phi-2 (Recommended for 16GB RAM)": "microsoft/phi-2",
        "TinyLlama-1.1B (Fastest)": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        "Mistral-7B (For 15GB+ GPU)": "mistralai/Mistral-7B-Instruct-v0.1"
    }
    selected_model = st.sidebar.selectbox("Select LLM Model", list(model_options.keys()), index=0)
    model_name = model_options[selected_model]

    if st.sidebar.checkbox("Show Memory Usage", value=True):
        st.sidebar.subheader("Memory Usage")
        for key, value in format_memory_stats().items():
            st.sidebar.text(f"{key}: {value}")

    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []
    if "query_cache" not in st.session_state:
        st.session_state.query_cache = {}
    if "feedback" not in st.session_state:
        st.session_state.feedback = []

    if "system_initialized" not in st.session_state or st.sidebar.button("Reload System"):
        with st.spinner("Initializing system..."):
            try:
                st.session_state.embedder, st.session_state.response_generator, num_faqs = initialize_components(
                    use_huggingface=use_huggingface,
                    model_name=model_name,
                    enable_augmentation=enable_augmentation
                )
                st.session_state.system_initialized = True
                st.sidebar.success(f"System initialized with {num_faqs} FAQs!")
            except Exception as e:
                st.error(f"System initialization failed: {e}")
                return

    col1, col2 = st.columns([2, 1])

    with col1:
        st.subheader("Conversation")
        chat_container = st.container(height=400)
        with chat_container:
            for i, message in enumerate(st.session_state.chat_history):
                if message["role"] == "user":
                    st.markdown(f"**You**: {message['content']}")
                else:
                    st.markdown(f"**Bot**: {message['content']}")
                if i < len(st.session_state.chat_history) - 1:
                    st.markdown("---")

        with st.form(key="chat_form"):
            user_query = st.text_input("Type your question:", key="user_input", placeholder="e.g., How do I track my order?")
            submit_button = st.form_submit_button("Ask")

        if len(st.session_state.chat_history) > 0:
            with st.form(key=f"feedback_form_{len(st.session_state.chat_history)}"):
                rating = st.slider("Rate this response (1-5)", 1, 5, key=f"rating_{len(st.session_state.chat_history)}")
                comments = st.text_area("Comments", key=f"comments_{len(st.session_state.chat_history)}")
                if st.form_submit_button("Submit Feedback"):
                    st.session_state.feedback.append({
                        "rating": rating,
                        "comments": comments,
                        "response": st.session_state.chat_history[-1]["content"]
                    })
                    with open("feedback.json", "w") as f:
                        json.dump(st.session_state.feedback, f)
                    st.success("Feedback submitted!")

    with col2:
        if st.session_state.get("system_initialized", False):
            st.subheader("Retrieved Information")
            info_container = st.container(height=500)
            with info_container:
                if "current_faqs" in st.session_state:
                    for i, faq in enumerate(st.session_state.current_faqs):
                        st.markdown(f"**Relevant FAQ #{i+1}**")
                        st.markdown(f"**Q**: {faq['question']}")
                        st.markdown(f"**A**: {faq['answer'][:150]}..." if len(faq['answer']) > 150 else f"**A**: {faq['answer']}")
                        st.markdown(f"*Similarity Score*: {faq['similarity']:.2f}")
                        if 'category' in faq and faq['category']:
                            st.markdown(f"*Category*: {faq['category']}")
                        st.markdown("---")
                else:
                    st.markdown("Ask a question to see relevant FAQs.")

    if "retrieval_time" in st.session_state and "generation_time" in st.session_state:
        st.sidebar.subheader("Performance Metrics")
        st.sidebar.markdown(f"Retrieval time: {st.session_state.retrieval_time:.2f} seconds")
        st.sidebar.markdown(f"Response generation: {st.session_state.generation_time:.2f} seconds")
        st.sidebar.markdown(f"Total time: {st.session_state.retrieval_time + st.session_state.generation_time:.2f} seconds")

    if submit_button and user_query:
        from src.data_processing import translate_faq
        from googletrans import Translator
        translator = Translator()
        if target_lang != "en":
            user_query_translated = translator.translate(user_query, dest="en").text
        else:
            user_query_translated = user_query

        if user_query_translated in st.session_state.query_cache:
            response, relevant_faqs = st.session_state.query_cache[user_query_translated]
        else:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            start_time = time.time()
            relevant_faqs = st.session_state.embedder.retrieve_relevant_faqs(user_query_translated)
            retrieval_time = time.time() - start_time

            if target_lang != "en":
                relevant_faqs = [translate_faq(faq, target_lang) for faq in relevant_faqs]

            start_time = time.time()
            response = st.session_state.response_generator.generate_response(user_query_translated, relevant_faqs)
            generation_time = time.time() - start_time

            if target_lang != "en":
                response = translator.translate(response, dest=target_lang).text

            st.session_state.query_cache[user_query_translated] = (response, relevant_faqs)
            st.session_state.retrieval_time = retrieval_time
            st.session_state.generation_time = generation_time
            st.session_state.current_faqs = relevant_faqs

        st.session_state.chat_history.append({"role": "user", "content": user_query})
        st.session_state.chat_history.append({"role": "assistant", "content": response})

    if st.button("Clear Chat History"):
        st.session_state.chat_history = []
        st.session_state.query_cache = {}
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    if st.session_state.get("system_initialized", False):
        st.sidebar.subheader("Baseline Comparison")
        baseline_faqs = baseline_keyword_search(user_query_translated if 'user_query_translated' in locals() else "", st.session_state.embedder.faqs)
        st.sidebar.write(f"RAG FAQs: {[faq['question'][:50] for faq in st.session_state.get('current_faqs', [])]}")
        st.sidebar.write(f"Keyword FAQs: {[faq['question'][:50] for faq in baseline_faqs]}")

    st.subheader("Sample Questions")
    sample_questions = [
        "How do I track my order?",
        "What should I do if my delivery is delayed?",
        "How do I return a product?",
        "Can I cancel my order after placing it?",
        "How quickly will my order be delivered?"
    ]
    cols = st.columns(2)
    for i, question in enumerate(sample_questions):
        col_idx = i % 2
        if cols[col_idx].button(question, key=f"sample_{i}"):
            st.session_state.user_input = question
            st.session_state.chat_history.append({"role": "user", "content": question})

            from src.data_processing import translate_faq
            from googletrans import Translator
            translator = Translator()
            if target_lang != "en":
                question_translated = translator.translate(question, dest="en").text
            else:
                question_translated = question

            if question_translated in st.session_state.query_cache:
                response, relevant_faqs = st.session_state.query_cache[question_translated]
            else:
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

                start_time = time.time()
                relevant_faqs = st.session_state.embedder.retrieve_relevant_faqs(question_translated)
                retrieval_time = time.time() - start_time

                if target_lang != "en":
                    relevant_faqs = [translate_faq(faq, target_lang) for faq in relevant_faqs]

                start_time = time.time()
                response = st.session_state.response_generator.generate_response(question_translated, relevant_faqs)
                generation_time = time.time() - start_time

                if target_lang != "en":
                    response = translator.translate(response, dest=target_lang).text

                st.session_state.query_cache[question_translated] = (response, relevant_faqs)
                st.session_state.retrieval_time = retrieval_time
                st.session_state.generation_time = generation_time
                st.session_state.current_faqs = relevant_faqs

            st.session_state.chat_history.append({"role": "assistant", "content": response})

if __name__ == "__main__":
    main()
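The retrieve-then-generate flow that app.py wires into Streamlit can also be exercised from a plain script. The sketch below is illustrative only: it assumes the project layout in this commit (the src package and data/faq_data.csv) and skips the UI, query cache, and translation steps.

# standalone smoke test (hypothetical, not part of this commit)
from src.data_processing import load_faq_data, preprocess_faq
from src.embedding import FAQEmbedder
from src.llm_response import ResponseGenerator

faqs = preprocess_faq(load_faq_data("data/faq_data.csv"))

embedder = FAQEmbedder()
embedder.create_embeddings(faqs)

generator = ResponseGenerator(model_name="microsoft/phi-2")

query = "How do I track my order?"
relevant_faqs = embedder.retrieve_relevant_faqs(query)
print(generator.generate_response(query, relevant_faqs))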
data/faq_data.csv
CHANGED
@@ -8,4 +8,9 @@ question,answer
"How do I cancel my subscription?","To cancel your subscription, log in to your account, go to 'Subscription Settings', and click on 'Cancel Subscription'. Follow the prompts to complete the cancellation process."
"Is my personal information secure?","Yes, we take data security seriously. We use industry-standard encryption protocols to protect your personal information. We never share your data with third parties without your consent."
"How do I contact customer support?","You can contact our customer support team via email at support@example.com, through the live chat on our website, or by calling our support line at 1-800-123-4567 during business hours (9 AM - 5 PM EST, Monday-Friday)."
"Do you offer discounts for bulk orders?","Yes, we offer volume discounts for bulk orders. Please contact our sales team at sales@example.com with your requirements to get a custom quote."
"Can I change my shipping address?","You can change your shipping address before the order is processed by contacting support at support@example.com."
"What happens if my order is damaged?","If your order arrives damaged, please contact us within 7 days with photos of the damage to initiate a replacement or refund."
"How do I apply a discount code?","Enter your discount code at checkout in the 'Promo Code' field. The discount will be applied to eligible items in your cart."
"What if I receive the wrong item?","If you receive the wrong item, contact support within 7 days with your order number and photos of the item received. We'll arrange a replacement or refund."
"Do you offer gift cards?","Yes, we offer digital gift cards in various denominations. Purchase them on our website and send them via email to the recipient."
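For reference, load_faq_data in src/data_processing.py turns this CSV into the list-of-dicts shape the rest of the pipeline expects; a minimal equivalent (illustrative only):

import pandas as pd

faqs = pd.read_csv("data/faq_data.csv").to_dict("records")
print(len(faqs), faqs[0]["question"])  # each record has 'question' and 'answer' keys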
notes.txt
DELETED
@@ -1,16 +0,0 @@
# Create a virtual environment (recommended)
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate

# Install required packages
pip install torch transformers sentence-transformers faiss-cpu pandas streamlit nltk

# Activate your virtual environment if not already active
source venv/bin/activate  # On Windows: venv\Scripts\activate

# Run the Streamlit app
streamlit run app.py
requirements.txt
CHANGED
@@ -3,8 +3,14 @@ transformers>=4.30.0
sentence-transformers>=2.2.2
faiss-cpu>=1.7.4
pandas>=1.5.0
streamlit>=1.36.0
numpy>=1.24.0
datasets>=2.10.0
bitsandbytes>=0.40.0
accelerate>=0.20.0
evaluate>=0.4.0
scikit-learn>=1.2.0
nlpaug>=1.1.0
googletrans==4.0.0-rc1
psutil>=5.9.0
nltk>=3.8.0
src/__init__.py
CHANGED
@@ -1,2 +1 @@
# This file marks the src directory as a Python package
src/data_processing.py
CHANGED
@@ -1,40 +1,56 @@
import pandas as pd
import json
import os
import nltk
from typing import List, Dict, Any
from datasets import load_dataset
import nlpaug.augmenter.word as naw
from googletrans import Translator

# Configure NLTK data path and download required resources
NLTK_DATA_PATH = os.path.join(os.path.dirname(__file__), "../nltk_data")
os.makedirs(NLTK_DATA_PATH, exist_ok=True)
nltk.data.path.append(NLTK_DATA_PATH)

def ensure_nltk_resources():
    """
    Ensure NLTK resources are downloaded and available
    """
    try:
        nltk.download('averaged_perceptron_tagger', download_dir=NLTK_DATA_PATH)
        nltk.download('punkt', download_dir=NLTK_DATA_PATH)
        print(f"NLTK resources downloaded to {NLTK_DATA_PATH}")
        return True
    except Exception as e:
        print(f"Failed to download NLTK resources: {e}")
        return False

def load_huggingface_faq_data(dataset_name: str = "NebulaByte/E-Commerce_FAQs") -> List[Dict[str, Any]]:
    """
    Load FAQ data from Hugging Face datasets, cache locally
    """
    local_path = "data/ecommerce_faqs.json"
    if os.path.exists(local_path):
        print(f"Loading cached dataset from {local_path}")
        with open(local_path, 'r') as f:
            return json.load(f)

    print(f"Loading dataset {dataset_name} from Hugging Face...")
    try:
        dataset = load_dataset(dataset_name)
        faqs = [{
            "question": item["question"],
            "answer": item["answer"],
            "category": item.get("category", ""),
            "question_id": item.get("question_id", ""),
            "faq_url": item.get("faq_url", "")
        } for item in dataset["train"]]
        with open(local_path, 'w') as f:
            json.dump(faqs, f)
        print(f"Saved dataset to {local_path}, loaded {len(faqs)} FAQs")
        return faqs
    except Exception as e:
        print(f"Error loading dataset: {e}")
        print("Falling back to local data...")
        return load_faq_data("data/faq_data.csv")

@@ -46,48 +62,86 @@ def load_faq_data(file_path: str) -> List[Dict[str, Any]]:
    try:
        if file_path.endswith('.csv'):
            df = pd.read_csv(file_path)
            faqs = df.to_dict('records')
        elif file_path.endswith('.json'):
            with open(file_path, 'r') as f:
                faqs = json.load(f)
        else:
            raise ValueError(f"Unsupported file format: {file_path}")
        print(f"Loaded {len(faqs)} FAQ entries")
        return faqs
    except Exception as e:
        print(f"Error loading data: {e}")
        print("Creating sample dataset as fallback")
        sample_faqs = [
            {"question": "How do I track my order?", "answer": "You can track your order by logging into your account and visiting the Order History section."},
            {"question": "How do I reset my password?", "answer": "To reset your password, click on the 'Forgot Password' link on the login page."}
        ]
        return sample_faqs

def preprocess_faq(faqs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Preprocess FAQ data: clean text, handle formatting, and filter invalid entries
    """
    processed_faqs = []
    for faq in faqs:
        # Safely handle question and answer fields
        question = faq.get('question')
        answer = faq.get('answer')

        # Convert to string and strip, handling None values
        question = str(question).strip() if question is not None else ""
        answer = str(answer).strip() if answer is not None else ""

        # Update FAQ dictionary
        faq['question'] = question
        faq['answer'] = answer

        # Only include FAQs with both question and answer
        if question and answer:
            processed_faqs.append(faq)
        else:
            print(f"Skipping invalid FAQ: question='{question}', answer='{answer}'")

    print(f"After preprocessing: {len(processed_faqs)} valid FAQ entries")
    return processed_faqs

def augment_faqs(faqs: List[Dict[str, Any]], max_faqs: int = 1000, enable_augmentation: bool = True) -> List[Dict[str, Any]]:
    """
    Augment FAQs with paraphrased questions if enabled
    """
    if not enable_augmentation:
        print("Augmentation disabled; returning original FAQs")
        return faqs

    if not ensure_nltk_resources():
        print("NLTK resources unavailable; skipping augmentation")
        return faqs

    aug = naw.SynonymAug()
    augmented = []
    for faq in faqs:
        augmented.append(faq)
        if len(augmented) < max_faqs:
            try:
                aug_question = aug.augment(faq['question'])[0]
                augmented.append({"question": aug_question, "answer": faq['answer'], "category": faq.get("category", "")})
            except Exception as e:
                print(f"Augmentation error for question '{faq['question'][:50]}...': {e}")
    print(f"Augmented to {len(augmented)} FAQs")
    return augmented

def translate_faq(faq: Dict[str, Any], target_lang: str = "es") -> Dict[str, Any]:
    """
    Translate FAQ to a target language
    """
    try:
        translator = Translator()
        translated = faq.copy()
        translated["question"] = translator.translate(faq["question"], dest=target_lang).text
        translated["answer"] = translator.translate(faq["answer"], dest=target_lang).text
        translated["language"] = target_lang
        return translated
    except Exception as e:
        print(f"Translation error: {e}")
        return faq
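A minimal sketch of how these helpers chain together, mirroring initialize_components in app.py; it assumes the nlpaug, googletrans, and nltk dependencies from requirements.txt are installed and that network access is available for the Hugging Face download and the translation call.

from src.data_processing import load_huggingface_faq_data, preprocess_faq, augment_faqs, translate_faq

faqs = load_huggingface_faq_data("NebulaByte/E-Commerce_FAQs")  # cached to data/ecommerce_faqs.json
faqs = preprocess_faq(faqs)               # drop entries missing a question or answer
faqs = augment_faqs(faqs, max_faqs=1000)  # interleave SynonymAug paraphrases up to the cap

# translation is applied per retrieved FAQ at query time in app.py
spanish_faq = translate_faq(faqs[0], target_lang="es")
print(spanish_faq["language"], spanish_faq["question"])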
src/embedding.py
CHANGED
@@ -4,50 +4,43 @@ import numpy as np
from typing import List, Dict, Any
import torch
import gc
import os
import psutil
import json

class FAQEmbedder:
    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        """
        Initialize the FAQ embedder with a sentence transformer model
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Embedding model using device: {self.device}")
        self.model = SentenceTransformer(model_name, device=self.device)
        self.index = None
        self.faqs = None
        self.embeddings = None

    def create_embeddings(self, faqs: List[Dict[str, Any]], batch_size: int = None) -> None:
        """
        Create embeddings for all FAQs and build FAISS index
        """
        self.faqs = faqs
        available_memory = psutil.virtual_memory().available / (1024 ** 3)  # GB
        batch_size = batch_size or min(64, int(available_memory * 4))
        print(f"Creating embeddings for {len(faqs)} FAQs in batches of {batch_size}...")

        questions = [faq['question'] for faq in faqs]
        all_embeddings = []
        for i in range(0, len(questions), batch_size):
            batch = questions[i:i+batch_size]
            print(f"Processing batch {i//batch_size + 1}/{(len(questions) + batch_size - 1)//batch_size}")
            batch_embeddings = self.model.encode(batch, show_progress_bar=False, convert_to_numpy=True)
            all_embeddings.append(batch_embeddings)

        self.embeddings = np.vstack(all_embeddings).astype('float32')
        all_embeddings = None
        gc.collect()

        dimension = self.embeddings.shape[1]
        self.index = faiss.IndexFlatL2(dimension)
        self.index.add(self.embeddings)

@@ -62,20 +55,35 @@ class FAQEmbedder:
        if self.index is None or self.faqs is None or self.embeddings is None:
            raise ValueError("Embeddings not created yet. Call create_embeddings first.")

        query_embedding = self.model.encode([query], convert_to_numpy=True).astype('float32')
        distances, indices = self.index.search(query_embedding, k)

        relevant_faqs = []
        for i, idx in enumerate(indices[0]):
            if idx < len(self.faqs):
                faq = self.faqs[idx].copy()
                similarity = 1.0 / (1.0 + distances[0][i])
                faq['similarity'] = similarity
                relevant_faqs.append(faq)

        return relevant_faqs

    def save(self, path: str):
        """
        Save embeddings and FAQs to disk
        """
        os.makedirs(path, exist_ok=True)
        self.model.save(path)
        faiss.write_index(self.index, f"{path}/index.bin")
        with open(f"{path}/faqs.json", "w") as f:
            json.dump(self.faqs, f)

    def load(self, path: str):
        """
        Load embeddings and FAQs from disk
        """
        self.model = SentenceTransformer(path)
        self.index = faiss.read_index(f"{path}/index.bin")
        with open(f"{path}/faqs.json", "r") as f:
            self.faqs = json.load(f)
        self.embeddings = np.array([self.model.encode(faq["question"]) for faq in self.faqs]).astype('float32')
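Typical FAQEmbedder usage as app.py now does it: build the index once, persist it to the embeddings/ directory (ignored by git in this commit), and reload it on later runs. A sketch under those assumptions:

from src.data_processing import load_faq_data, preprocess_faq
from src.embedding import FAQEmbedder

faqs = preprocess_faq(load_faq_data("data/faq_data.csv"))

embedder = FAQEmbedder("all-MiniLM-L6-v2")
embedder.create_embeddings(faqs)  # batch size derived from available RAM when not given
embedder.save("embeddings")       # writes the model, index.bin, and faqs.json

fresh = FAQEmbedder()
fresh.load("embeddings")          # re-encodes questions to rebuild self.embeddings
for faq in fresh.retrieve_relevant_faqs("Where is my parcel?"):
    print(f"{faq['similarity']:.2f}  {faq['question']}")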
src/llm_response.py
CHANGED
@@ -2,101 +2,76 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
from typing import List, Dict, Any
import gc
import psutil

class ResponseGenerator:
    def __init__(self, model_name: str = "microsoft/phi-2"):
        """
        Initialize the response generator with an LLM
        """
        print(f"Loading LLM: {model_name}")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        gc.collect()
        if self.device == "cuda":
            torch.cuda.empty_cache()

        try:
            if self.device == "cuda":
                quantization_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4"
                )
                available_memory = psutil.virtual_memory().total / (1024 ** 3)
                gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
                max_memory = {0: f"{min(gpu_memory, 15)}GiB", "cpu": f"{min(available_memory, 30)}GiB"}
                print(f"Setting max_memory: {max_memory}")

                self.model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    quantization_config=quantization_config,
                    device_map="auto",
                    torch_dtype=torch.float16,
                    max_memory=max_memory,
                    offload_folder="offload",
                    offload_state_dict=True,
                    low_cpu_mem_usage=True
                )
            else:
                self.model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    device_map={"": "cpu"},
                    torch_dtype=torch.float32,
                    low_cpu_mem_usage=True
                )
        except Exception as e:
            print(f"Model loading error: {e}")
            print("Falling back to TinyLlama...")
            model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map={"": self.device},
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
            )

        print("LLM loaded successfully")

    def generate_response(self, query: str, relevant_faqs: List[Dict[str, Any]]) -> str:
        """
        Generate a response using the LLM with retrieved FAQs as context
        """
        prompt = self._create_prompt(query, relevant_faqs)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=150,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,

@@ -104,12 +79,9 @@ class ResponseGenerator:
            )

        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        response = response[len(prompt):].strip()

        if self.device == "cuda":
            torch.cuda.empty_cache()

        return response

@@ -118,19 +90,13 @@ class ResponseGenerator:
        """
        Create a prompt for the LLM with retrieved FAQs as context
        """
        faq_context = "\n\n".join([f"Q: {faq['question']}\nA: {faq['answer']}" for faq in relevant_faqs])
        prompt = f"""
Below are some relevant e-commerce customer support FAQ entries:

{faq_context}

Based on the information above, provide a helpful, accurate, and concise response to the following customer query:
Customer Query: {query}

Response:
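Calling ResponseGenerator on its own looks roughly like the sketch below; the model download and (on GPU) bitsandbytes quantization happen inside __init__, so construction is the slow step. The context dicts only need 'question' and 'answer' keys.

from src.llm_response import ResponseGenerator

generator = ResponseGenerator(model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0")  # smallest option offered in app.py

context_faqs = [
    {"question": "How do I track my order?",
     "answer": "Log into your account and open the Order History section."}
]
print(generator.generate_response("Where can I see my order status?", context_faqs))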
src/utils.py
CHANGED
@@ -1,6 +1,12 @@
import time
import functools
from typing import Callable, Any, Dict, List
import torch
import psutil
import json
from evaluate import load
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

def time_function(func: Callable) -> Callable:
    """

@@ -17,38 +23,73 @@ def time_function(func: Callable) -> Callable:
def evaluate_response(generated_response: str, ground_truth: str = None) -> Dict[str, Any]:
    """
    Evaluate generated response with BLEU, ROUGE, and word overlap
    """
    results = {
        "length": len(generated_response),
        "word_count": len(generated_response.split())
    }

    if ground_truth:
        bleu = load("bleu")
        rouge = load("rouge")
        bleu_score = bleu.compute(predictions=[generated_response], references=[[ground_truth]])
        rouge_score = rouge.compute(predictions=[generated_response], references=[ground_truth])
        generated_words = set(generated_response.lower().split())
        ground_truth_words = set(ground_truth.lower().split())
        overlap = len(generated_words.intersection(ground_truth_words))
        results.update({
            "bleu": bleu_score["bleu"],
            "rouge": rouge_score["rougeL"],
            "word_overlap": overlap / len(ground_truth_words) if ground_truth_words else 0
        })

    return results

def evaluate_retrieval(embedder, test_set_path: str, k: int = 3) -> Dict[str, float]:
    """
    Evaluate retrieval quality with Precision@k and Recall@k
    """
    with open(test_set_path, 'r') as f:
        test_set = json.load(f)

    precision, recall = [], []
    for item in test_set:
        query = item['query']
        true_ids = set(item['relevant_ids'])
        retrieved_faqs = embedder.retrieve_relevant_faqs(query, k)
        retrieved_ids = set(range(len(retrieved_faqs)))

        true_positives = len(true_ids & retrieved_ids)
        precision.append(true_positives / k if k > 0 else 0)
        recall.append(true_positives / len(true_ids) if true_ids else 0)

    return {
        "Precision@k": sum(precision) / len(precision) if precision else 0,
        "Recall@k": sum(recall) / len(recall) if recall else 0
    }

def baseline_keyword_search(query: str, faqs: List[Dict[str, Any]], k: int = 3) -> List[Dict[str, Any]]:
    """
    Keyword-based search baseline using TF-IDF
    """
    questions = [faq['question'] for faq in faqs]
    vectorizer = TfidfVectorizer()
    question_vectors = vectorizer.fit_transform(questions)
    query_vector = vectorizer.transform([query])
    similarities = cosine_similarity(query_vector, question_vectors).flatten()
    top_k_indices = similarities.argsort()[-k:][::-1]
    return [faqs[i] for i in top_k_indices]

def format_memory_stats():
    """
    Format memory usage statistics
    """
    system_stats = {
        "RAM": f"{psutil.virtual_memory().used / (1024 ** 3):.1f}GB / {psutil.virtual_memory().total / (1024 ** 3):.1f}GB",
        "RAM Usage": f"{psutil.virtual_memory().percent}%"
    }

    if torch.cuda.is_available():
        gpu_stats = {}
        for i in range(torch.cuda.device_count()):
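The new evaluation helpers can be run offline against test_set.json (added below in this commit); a sketch, assuming embedder is an already-initialized FAQEmbedder:

from src.utils import evaluate_retrieval, baseline_keyword_search, evaluate_response

print(evaluate_retrieval(embedder, "test_set.json", k=3))  # {'Precision@k': ..., 'Recall@k': ...}

# compare the FAISS retriever against the TF-IDF keyword baseline
keyword_hits = baseline_keyword_search("How do I track my order?", embedder.faqs, k=3)
print([faq["question"] for faq in keyword_hits])

# response-level metrics need a ground-truth answer to compare against
print(evaluate_response("Track it from Order History.", ground_truth=embedder.faqs[0]["answer"]))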
test_set.json
ADDED
@@ -0,0 +1,7 @@
[
    {"query": "How do I track my order?", "relevant_ids": [2]},
    {"query": "How to reset password?", "relevant_ids": [0]},
    {"query": "What is the return policy?", "relevant_ids": [3]},
    {"query": "Can I change my shipping address?", "relevant_ids": [10]},
    {"query": "Do you offer gift cards?", "relevant_ids": [14]}
]