Update app/bot.py
app/bot.py (+631 -625)
CHANGED
@@ -1,626 +1,632 @@
- (previous 625-line version of app/bot.py removed; replaced in full below)
# app/bot.py
from __future__ import annotations

import os
# Set cache directories before importing transformers
os.environ['HF_HOME'] = '/app/.cache'
os.environ['TRANSFORMERS_CACHE'] = '/app/.cache/transformers'
os.environ['SENTENCE_TRANSFORMERS_HOME'] = '/app/.cache/sentence_transformers'
os.environ['TORCH_HOME'] = '/app/.cache/torch'

import logging
import re
import unicodedata
import warnings
from pathlib import Path
from typing import Any, List, Dict, Tuple
import json

import numpy as np
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer, CrossEncoder
from sklearn.metrics.pairwise import cosine_similarity
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
import nltk

# Download required NLTK data
try:
    nltk.download('punkt', quiet=True)
    nltk.download('stopwords', quiet=True)
except Exception:
    pass

warnings.filterwarnings("ignore")


class RequirementError(RuntimeError):
    pass


class JupiterFAQBot:
    # ------------------------------------------------------------------ #
    # Free Models Configuration
    # ------------------------------------------------------------------ #
    MODELS = {
        "bi": "sentence-transformers/all-MiniLM-L6-v2",   # Fast semantic search
        "cross": "cross-encoder/ms-marco-MiniLM-L-6-v2",  # Reranking
        "qa": "deepset/roberta-base-squad2",              # Better QA model
        "summarizer": "facebook/bart-large-cnn",          # Better summarization
    }

    # Retrieval parameters
    TOP_K = 15        # More candidates for better coverage
    HIGH_SIM = 0.85   # High confidence threshold
    CROSS_OK = 0.50   # Cross-encoder threshold
    MIN_SIM = 0.40    # Minimum similarity to consider

    # Paths
    EMB_CACHE = Path("data/faq_embeddings.npy")
    FAQ_PATH = Path("data/faqs.csv")

    # Response templates for better UX
    CONFIDENCE_LEVELS = {
        "high": "This information matches your query based on our FAQs:\n\n",
        "medium": "This appears to be relevant to your question:\n\n",
        "low": "This may be related to your query and could be helpful:\n\n",
        "none": (
            "We couldn't find a direct match for your question. "
            "However, we can assist with topics such as:\n"
            "• Account opening and KYC\n"
            "• Payments and UPI\n"
            "• Rewards and cashback\n"
            "• Credit cards and loans\n"
            "• Investments and savings\n\n"
            "Please try rephrasing your question or selecting a topic above."
        )
    }

    # ------------------------------------------------------------------ #
    def __init__(self, csv_path: str = None) -> None:
        logging.basicConfig(format="%(levelname)s | %(message)s", level=logging.INFO)

        # Use provided path or default
        self.csv_path = csv_path or str(self.FAQ_PATH)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.pipe_dev = 0 if self.device.type == "cuda" else -1

        self._load_data(self.csv_path)
        self._setup_models()
        self._setup_embeddings()

        logging.info("Jupiter FAQ Bot ready ✔")

    # ------------------------ Text Processing ------------------------- #
    @staticmethod
    def _clean(text: str) -> str:
        """Clean and normalize text"""
        if pd.isna(text):
            return ""
        text = str(text)
        text = unicodedata.normalize("NFC", text)
        # Remove extra whitespace but keep sentence structure
        text = re.sub(r'\s+', ' ', text)
        # Keep bullet points and formatting
        text = re.sub(r'•\s*', '\n• ', text)
        return text.strip()

    @staticmethod
    def _preprocess_query(query: str) -> str:
        """Preprocess user query for better matching"""
        # Expand common abbreviations
        abbreviations = {
            'kyc': 'know your customer verification',
            'upi': 'unified payments interface',
            'fd': 'fixed deposit',
            'sip': 'systematic investment plan',
            'neft': 'national electronic funds transfer',
            'rtgs': 'real time gross settlement',
            'imps': 'immediate payment service',
            'emi': 'equated monthly installment',
            'apr': 'annual percentage rate',
            'atm': 'automated teller machine',
            'pin': 'personal identification number',
        }

        query_lower = query.lower()
        for abbr, full in abbreviations.items():
            if abbr in query_lower.split():
                query_lower = query_lower.replace(abbr, full)

        return query_lower

    # ------------------------ Initialization -------------------------- #
    def _load_data(self, path: str):
        """Load and preprocess FAQ data"""
        if not Path(path).exists():
            raise RequirementError(f"CSV not found: {path}")

        df = pd.read_csv(path)

        # Clean all text fields
        df["question"] = df["question"].apply(self._clean)
        df["answer"] = df["answer"].apply(self._clean)
        df["category"] = df["category"].fillna("General")

        # Create searchable text combining question and category
        df["searchable"] = df["question"].str.lower() + " " + df["category"].str.lower()

        # Remove duplicates
        df = df.drop_duplicates(subset=["question"]).reset_index(drop=True)

        self.faq = df
        logging.info(f"Loaded {len(self.faq)} FAQ entries from {len(df['category'].unique())} categories")

    def _setup_models(self):
        """Initialize all models"""
        logging.info("Loading models...")

        # Sentence transformer for embeddings
        self.bi = SentenceTransformer(self.MODELS["bi"], device=self.device)

        # Cross-encoder for reranking
        self.cross = CrossEncoder(self.MODELS["cross"], device=self.device)

        # QA model
        self.qa = pipeline(
            "question-answering",
            model=self.MODELS["qa"],
            device=self.pipe_dev,
            handle_impossible_answer=True
        )

        # Summarization model - using BART for better quality
        self.summarizer = pipeline(
            "summarization",
            model=self.MODELS["summarizer"],
            device=self.pipe_dev,
            max_length=150,
            min_length=50
        )

        logging.info("All models loaded successfully")

    def _setup_embeddings(self):
        """Create or load embeddings"""
        questions = self.faq["searchable"].tolist()

        if self.EMB_CACHE.exists():
            emb = np.load(self.EMB_CACHE)
            if len(emb) != len(questions):
                logging.info("Regenerating embeddings due to data change...")
                emb = self.bi.encode(questions, show_progress_bar=True, convert_to_tensor=False)
                np.save(self.EMB_CACHE, emb)
        else:
            logging.info("Creating embeddings for the first time...")
            emb = self.bi.encode(questions, show_progress_bar=True, convert_to_tensor=False)
            self.EMB_CACHE.parent.mkdir(parents=True, exist_ok=True)
            np.save(self.EMB_CACHE, emb)

        self.embeddings = emb

    # ------------------------- Retrieval ------------------------------ #
    def _retrieve_candidates(self, query: str, top_k: int = None) -> List[Dict]:
        """Retrieve top candidates using semantic search"""
        if top_k is None:
            top_k = self.TOP_K

        # Preprocess query
        processed_query = self._preprocess_query(query)

        # Encode query
        query_emb = self.bi.encode([processed_query])

        # Calculate similarities
        similarities = cosine_similarity(query_emb, self.embeddings)[0]

        # Get top indices
        top_indices = similarities.argsort()[-top_k:][::-1]

        # Filter by minimum similarity
        candidates = []
        for idx in top_indices:
            if similarities[idx] >= self.MIN_SIM:
                candidates.append({
                    "idx": int(idx),
                    "question": self.faq.iloc[idx]["question"],
                    "answer": self.faq.iloc[idx]["answer"],
                    "category": self.faq.iloc[idx]["category"],
                    "similarity": float(similarities[idx])
                })

        return candidates

    def _rerank_candidates(self, query: str, candidates: List[Dict]) -> List[Dict]:
        """Rerank candidates using cross-encoder"""
        if not candidates:
            return []

        # Prepare pairs for cross-encoder
        pairs = [[query, c["question"]] for c in candidates]

        # Get cross-encoder scores
        scores = self.cross.predict(pairs, convert_to_numpy=True)

        # Add scores to candidates
        for c, score in zip(candidates, scores):
            c["cross_score"] = float(score)

        # Filter and sort by cross-encoder score
        reranked = [c for c in candidates if c["cross_score"] >= self.CROSS_OK]
        reranked.sort(key=lambda x: x["cross_score"], reverse=True)

        return reranked

    def _extract_answer(self, query: str, context: str) -> Dict[str, Any]:
        """Extract specific answer using QA model"""
        try:
            result = self.qa(question=query, context=context)
            return {
                "answer": result["answer"],
                "score": result["score"],
                "start": result.get("start", 0),
                "end": result.get("end", len(result["answer"]))
            }
        except Exception as e:
            logging.warning(f"QA extraction failed: {e}")
            return {"answer": context, "score": 0.5}

    def _create_friendly_response(self, answers: List[str], confidence: str = "medium") -> str:
        """Create a user-friendly response from multiple answers"""
        if not answers:
            return self.CONFIDENCE_LEVELS["none"]

        # Remove duplicates while preserving order
        unique_answers = []
        seen = set()
        for ans in answers:
            normalized = ans.lower().strip()
            if normalized not in seen:
                seen.add(normalized)
                unique_answers.append(ans)

        if len(unique_answers) == 1:
            # Single answer - return as is with confidence prefix
            return self.CONFIDENCE_LEVELS[confidence] + unique_answers[0]

        # Multiple answers - need to summarize
        combined_text = " ".join(unique_answers)

        # If text is short enough, format it nicely
        if len(combined_text) < 300:
            response = self.CONFIDENCE_LEVELS[confidence]
            for i, answer in enumerate(unique_answers):
                if "•" in answer:
                    # Already has bullets
                    response += answer + "\n\n"
                else:
                    # Add as paragraph
                    response += answer + "\n\n"
            return response.strip()

        # Long text - summarize it
        try:
            # Prepare text for summarization
            summary_input = f"Summarize the following information about Jupiter banking services: {combined_text}"

            # Generate summary
            summary = self.summarizer(summary_input, max_length=150, min_length=50, do_sample=False)
            summarized_text = summary[0]['summary_text']

            # Make it more conversational
            response = self.CONFIDENCE_LEVELS[confidence]
            response += self._make_conversational(summarized_text)

            return response

        except Exception as e:
            logging.warning(f"Summarization failed: {e}")
            # Fallback to formatted response
            return self._format_multiple_answers(unique_answers, confidence)

    def _make_conversational(self, text: str) -> str:
        """Make response more conversational and friendly"""
        # Add appropriate punctuation if missing
        if text and text[-1] not in '.!?':
            text += '.'

        # Replace robotic phrases
        replacements = {
            "The user": "You",
            "the user": "you",
            "It is": "It's",
            "You will": "You'll",
            "You can not": "You can't",
            "Do not": "Don't",
        }

        for old, new in replacements.items():
            text = text.replace(old, new)

        return text

    def _format_multiple_answers(self, answers: List[str], confidence: str) -> str:
        """Format multiple answers nicely"""
        response = self.CONFIDENCE_LEVELS[confidence]

        if len(answers) <= 3:
            # Few answers - show all
            for answer in answers:
                if "•" in answer:
                    response += answer + "\n\n"
                else:
                    response += f"• {answer}\n\n"
        else:
            # Many answers - group by category
            response += "Here are the key points:\n\n"
            for i, answer in enumerate(answers[:5]):  # Limit to 5
                response += f"{i+1}. {answer}\n\n"

        return response.strip()

    # ------------------------- Main API ------------------------------- #
    def generate_response(self, query: str) -> str:
        """Generate response for user query"""
        query = self._clean(query)

        # Step 1: Retrieve candidates
        candidates = self._retrieve_candidates(query)

        if not candidates:
            return self.CONFIDENCE_LEVELS["none"]

        # Step 2: Check for high similarity match
        if candidates[0]["similarity"] >= self.HIGH_SIM:
            return self.CONFIDENCE_LEVELS["high"] + candidates[0]["answer"]

        # Step 3: Rerank candidates
        reranked = self._rerank_candidates(query, candidates)

        if not reranked:
            # Use original candidates with lower confidence
            reranked = candidates[:3]
            confidence = "low"
        else:
            confidence = "high" if reranked[0]["cross_score"] > 0.8 else "medium"

        # Step 4: Extract relevant answers
        relevant_answers = []

        for candidate in reranked[:5]:  # Top 5 reranked
            # Try QA extraction for more specific answer
            qa_result = self._extract_answer(query, candidate["answer"])

            if qa_result["score"] > 0.3:
                # Good QA match
                relevant_answers.append(qa_result["answer"])
            else:
                # Use full answer if QA didn't find specific part
                relevant_answers.append(candidate["answer"])

        # Step 5: Create final response
        final_response = self._create_friendly_response(relevant_answers, confidence)

        return final_response

    def suggest_related_queries(self, query: str) -> List[str]:
        """Suggest related queries based on similar questions"""
        candidates = self._retrieve_candidates(query, top_k=10)

        related = []
        seen = set()

        for candidate in candidates:
            if candidate["similarity"] >= 0.5 and candidate["similarity"] < 0.9:
                # Clean question for display
                clean_q = candidate["question"].strip()
                if clean_q.lower() not in seen and clean_q.lower() != query.lower():
                    seen.add(clean_q.lower())
                    related.append(clean_q)

        # Return top 5 related queries
        return related[:5]

    def get_categories(self) -> List[str]:
        """Get all available FAQ categories"""
        return sorted(self.faq["category"].unique().tolist())

    def get_faqs_by_category(self, category: str) -> List[Dict[str, str]]:
        """Get all FAQs for a specific category"""
        cat_faqs = self.faq[self.faq["category"].str.lower() == category.lower()]

        return [
            {
                "question": row["question"],
                "answer": row["answer"]
            }
            for _, row in cat_faqs.iterrows()
        ]

    def search_faqs(self, keyword: str) -> List[Dict[str, str]]:
        """Simple keyword search in FAQs"""
        keyword_lower = keyword.lower()

        matches = []
        for _, row in self.faq.iterrows():
            if (keyword_lower in row["question"].lower() or
                keyword_lower in row["answer"].lower()):
                matches.append({
                    "question": row["question"],
                    "answer": row["answer"],
                    "category": row["category"]
                })

        return matches[:10]  # Limit to 10 results


# Evaluation module
class BotEvaluator:
    """Evaluate bot performance"""

    def __init__(self, bot: JupiterFAQBot):
        self.bot = bot

    def create_test_queries(self) -> List[Dict[str, str]]:
        """Create test queries based on FAQ categories"""
        test_queries = [
            # Account queries
            {"query": "How do I open an account?", "expected_category": "Account"},
            {"query": "What is Jupiter savings account?", "expected_category": "Account"},

            # Payment queries
            {"query": "How to make UPI payment?", "expected_category": "Payments"},
            {"query": "What is the daily transaction limit?", "expected_category": "Payments"},

            # Rewards queries
            {"query": "How do I earn cashback?", "expected_category": "Rewards"},
            {"query": "What are Jewels?", "expected_category": "Rewards"},

            # Investment queries
            {"query": "Can I invest in mutual funds?", "expected_category": "Investments"},
            {"query": "What is Magic Spends?", "expected_category": "Magic Spends"},

            # Loan queries
            {"query": "How to apply for personal loan?", "expected_category": "Jupiter Loans"},
            {"query": "What is the interest rate?", "expected_category": "Jupiter Loans"},

            # Card queries
            {"query": "How to get credit card?", "expected_category": "Edge+ Credit Card"},
            {"query": "Is there any annual fee?", "expected_category": "Edge+ Credit Card"},
        ]

        return test_queries

    def evaluate_retrieval_accuracy(self) -> Dict[str, float]:
        """Evaluate how well the bot retrieves relevant information"""
        test_queries = self.create_test_queries()

        correct = 0
        total = len(test_queries)

        results = []

        for test in test_queries:
            response = self.bot.generate_response(test["query"])

            # Check if response mentions expected category content
            is_correct = test["expected_category"].lower() in response.lower()

            if is_correct:
                correct += 1

            results.append({
                "query": test["query"],
                "expected_category": test["expected_category"],
                "response": response[:200] + "..." if len(response) > 200 else response,
                "correct": is_correct
            })

        accuracy = correct / total if total > 0 else 0

        return {
            "accuracy": accuracy,
            "correct": correct,
            "total": total,
            "results": results
        }

    def evaluate_response_quality(self) -> Dict[str, Any]:
        """Evaluate the quality of responses"""
        test_queries = [
            "What is Jupiter?",
            "How do I earn rewards?",
            "Tell me about credit cards",
            "Can I get a loan?",
            "How to invest money?"
        ]

        quality_metrics = []

        for query in test_queries:
            response = self.bot.generate_response(query)

            # Check quality indicators; the greeting check looks for the
            # CONFIDENCE_LEVELS prefixes the bot actually emits
            has_greeting = any(phrase in response for phrase in
                               ["This information", "This appears", "This may be related", "We couldn't find"])
            has_structure = "\n" in response or "•" in response
            appropriate_length = 50 < len(response) < 500

            quality_score = sum([has_greeting, has_structure, appropriate_length]) / 3

            quality_metrics.append({
                "query": query,
                "response_length": len(response),
                "has_greeting": has_greeting,
                "has_structure": has_structure,
                "appropriate_length": appropriate_length,
                "quality_score": quality_score
            })

        avg_quality = sum(m["quality_score"] for m in quality_metrics) / len(quality_metrics)

        return {
            "average_quality_score": avg_quality,
            "metrics": quality_metrics
        }


# Utility functions for data preparation
def prepare_faq_data(csv_path: str = "data/faqs.csv") -> pd.DataFrame:
    """Prepare and validate FAQ data"""
    df = pd.read_csv(csv_path)

    # Ensure required columns exist
    required_cols = ["question", "answer", "category"]
    if not all(col in df.columns for col in required_cols):
        raise ValueError(f"CSV must contain columns: {required_cols}")

    # Basic stats
    print(f"Total FAQs: {len(df)}")
    print(f"Categories: {df['category'].nunique()}")
    print(f"\nCategory distribution:")
    print(df['category'].value_counts())

    return df


# Main execution example
if __name__ == "__main__":
    # Initialize bot
    bot = JupiterFAQBot()

    # Test some queries
    test_queries = [
        "How do I open a savings account?",
        "What are the cashback rates?",
        "Can I get a personal loan?",
        "How to use UPI?",
        "Tell me about investments"
    ]

    print("\n" + "="*50)
    print("Testing Jupiter FAQ Bot")
    print("="*50 + "\n")

    for query in test_queries:
        print(f"Q: {query}")
        response = bot.generate_response(query)
        print(f"A: {response}\n")

        # Show related queries
        related = bot.suggest_related_queries(query)
        if related:
            print("Related questions:")
            for r in related[:3]:
                print(f"  - {r}")
        print("\n" + "-"*50 + "\n")

    # Run evaluation
    print("\n" + "="*50)
    print("Running Evaluation")
    print("="*50 + "\n")

    evaluator = BotEvaluator(bot)

    # Retrieval accuracy
    accuracy_results = evaluator.evaluate_retrieval_accuracy()
    print(f"Retrieval Accuracy: {accuracy_results['accuracy']:.2%}")
    print(f"Correct: {accuracy_results['correct']}/{accuracy_results['total']}")

    # Response quality
    quality_results = evaluator.evaluate_response_quality()
    print(f"\nAverage Response Quality: {quality_results['average_quality_score']:.2%}")
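
For quick local verification outside the Space, a minimal usage sketch is shown below. It is not part of this commit: the module path app.bot, the file name sample_faqs.csv, and the placeholder rows are assumptions for illustration. The sketch only exercises the public API defined above (the constructor's csv_path argument, get_categories, generate_response, suggest_related_queries) and expects the same dependencies as the imports above (pandas, numpy, torch, transformers, sentence-transformers, scikit-learn, nltk), plus network access for the first model download.

# usage_sketch.py — illustrative only, not part of app/bot.py.
# Assumes app/bot.py is importable as app.bot in the current environment.
import pandas as pd

from app.bot import JupiterFAQBot

# A throwaway CSV with the columns _load_data() expects: question, answer, category.
sample = pd.DataFrame([
    {"question": "How do I open a savings account?",
     "answer": "Placeholder answer about account opening and KYC.",
     "category": "Account"},
    {"question": "How do I earn cashback?",
     "answer": "Placeholder answer about rewards and cashback.",
     "category": "Rewards"},
])
sample.to_csv("sample_faqs.csv", index=False)

# Point the constructor at the sample file instead of the default data/faqs.csv.
bot = JupiterFAQBot(csv_path="sample_faqs.csv")

print(bot.get_categories())                                # sorted unique categories
print(bot.generate_response("How do I open an account?"))  # templated answer or the "none" fallback
print(bot.suggest_related_queries("cashback"))             # up to 5 related FAQ questions

Note that running this also writes data/faq_embeddings.npy for the two sample rows (per _setup_embeddings); the length check there will regenerate the cache the next time the bot is started against the full data/faqs.csv.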