thecoderhere committed on
Commit 0bcee26 · verified · 1 Parent(s): 30113b5

Update app/bot.py

Files changed (1)
  1. app/bot.py +631 -625
app/bot.py CHANGED
@@ -1,626 +1,632 @@
# app/bot.py
from __future__ import annotations  # must remain the first statement in the module

import os

# Set cache directories before importing transformers
os.environ['HF_HOME'] = '/app/.cache'
os.environ['TRANSFORMERS_CACHE'] = '/app/.cache/transformers'
os.environ['SENTENCE_TRANSFORMERS_HOME'] = '/app/.cache/sentence_transformers'
os.environ['TORCH_HOME'] = '/app/.cache/torch'

import logging
import re
import unicodedata
import warnings
from pathlib import Path
from typing import Any, List, Dict, Tuple
import json

import numpy as np
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer, CrossEncoder
from sklearn.metrics.pairwise import cosine_similarity
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
import nltk

# Download required NLTK data
try:
    nltk.download('punkt', quiet=True)
    nltk.download('stopwords', quiet=True)
except Exception:
    pass

warnings.filterwarnings("ignore")


class RequirementError(RuntimeError):
    pass


class JupiterFAQBot:
    # ------------------------------------------------------------------ #
    # Free Models Configuration
    # ------------------------------------------------------------------ #
    MODELS = {
        "bi": "sentence-transformers/all-MiniLM-L6-v2",   # Fast semantic search
        "cross": "cross-encoder/ms-marco-MiniLM-L-6-v2",  # Reranking
        "qa": "deepset/roberta-base-squad2",               # Better QA model
        "summarizer": "facebook/bart-large-cnn",           # Better summarization
    }

    # Retrieval parameters
    TOP_K = 15        # More candidates for better coverage
    HIGH_SIM = 0.85   # High confidence threshold
    CROSS_OK = 0.50   # Cross-encoder threshold
    MIN_SIM = 0.40    # Minimum similarity to consider

    # Paths
    EMB_CACHE = Path("data/faq_embeddings.npy")
    FAQ_PATH = Path("data/faqs.csv")

    # Response templates for better UX
    CONFIDENCE_LEVELS = {
        "high": "This information matches your query based on our FAQs:\n\n",
        "medium": "This appears to be relevant to your question:\n\n",
        "low": "This may be related to your query and could be helpful:\n\n",
        "none": (
            "We couldn't find a direct match for your question. "
            "However, we can assist with topics such as:\n"
            "• Account opening and KYC\n"
            "• Payments and UPI\n"
            "• Rewards and cashback\n"
            "• Credit cards and loans\n"
            "• Investments and savings\n\n"
            "Please try rephrasing your question or selecting a topic above."
        )
    }

    # ------------------------------------------------------------------ #
    def __init__(self, csv_path: str = None) -> None:
        logging.basicConfig(format="%(levelname)s | %(message)s", level=logging.INFO)

        # Use provided path or default
        self.csv_path = csv_path or str(self.FAQ_PATH)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.pipe_dev = 0 if self.device.type == "cuda" else -1

        self._load_data(self.csv_path)
        self._setup_models()
        self._setup_embeddings()

        logging.info("Jupiter FAQ Bot ready ✔")

    # ------------------------ Text Processing ------------------------- #
    @staticmethod
    def _clean(text: str) -> str:
        """Clean and normalize text"""
        if pd.isna(text):
            return ""
        text = str(text)
        text = unicodedata.normalize("NFC", text)
        # Remove extra whitespace but keep sentence structure
        text = re.sub(r'\s+', ' ', text)
        # Keep bullet points and formatting
        text = re.sub(r'•\s*', '\n• ', text)
        return text.strip()

    @staticmethod
    def _preprocess_query(query: str) -> str:
        """Preprocess user query for better matching"""
        # Expand common abbreviations
        abbreviations = {
            'kyc': 'know your customer verification',
            'upi': 'unified payments interface',
            'fd': 'fixed deposit',
            'sip': 'systematic investment plan',
            'neft': 'national electronic funds transfer',
            'rtgs': 'real time gross settlement',
            'imps': 'immediate payment service',
            'emi': 'equated monthly installment',
            'apr': 'annual percentage rate',
            'atm': 'automated teller machine',
            'pin': 'personal identification number',
        }

        query_lower = query.lower()
        for abbr, full in abbreviations.items():
            if abbr in query_lower.split():
                query_lower = query_lower.replace(abbr, full)

        return query_lower

    # ------------------------ Initialization -------------------------- #
    def _load_data(self, path: str):
        """Load and preprocess FAQ data"""
        if not Path(path).exists():
            raise RequirementError(f"CSV not found: {path}")

        df = pd.read_csv(path)

        # Clean all text fields
        df["question"] = df["question"].apply(self._clean)
        df["answer"] = df["answer"].apply(self._clean)
        df["category"] = df["category"].fillna("General")

        # Create searchable text combining question and category
        df["searchable"] = df["question"].str.lower() + " " + df["category"].str.lower()

        # Remove duplicates
        df = df.drop_duplicates(subset=["question"]).reset_index(drop=True)

        self.faq = df
        logging.info(f"Loaded {len(self.faq)} FAQ entries from {len(df['category'].unique())} categories")

    def _setup_models(self):
        """Initialize all models"""
        logging.info("Loading models...")

        # Sentence transformer for embeddings
        self.bi = SentenceTransformer(self.MODELS["bi"], device=self.device)

        # Cross-encoder for reranking
        self.cross = CrossEncoder(self.MODELS["cross"], device=self.device)

        # QA model
        self.qa = pipeline(
            "question-answering",
            model=self.MODELS["qa"],
            device=self.pipe_dev,
            handle_impossible_answer=True
        )

        # Summarization model - using BART for better quality
        self.summarizer = pipeline(
            "summarization",
            model=self.MODELS["summarizer"],
            device=self.pipe_dev,
            max_length=150,
            min_length=50
        )

        logging.info("All models loaded successfully")

    def _setup_embeddings(self):
        """Create or load embeddings"""
        questions = self.faq["searchable"].tolist()

        if self.EMB_CACHE.exists():
            emb = np.load(self.EMB_CACHE)
            if len(emb) != len(questions):
                logging.info("Regenerating embeddings due to data change...")
                emb = self.bi.encode(questions, show_progress_bar=True, convert_to_tensor=False)
                np.save(self.EMB_CACHE, emb)
        else:
            logging.info("Creating embeddings for the first time...")
            emb = self.bi.encode(questions, show_progress_bar=True, convert_to_tensor=False)
            self.EMB_CACHE.parent.mkdir(parents=True, exist_ok=True)
            np.save(self.EMB_CACHE, emb)

        self.embeddings = emb

    # ------------------------- Retrieval ------------------------------ #
    def _retrieve_candidates(self, query: str, top_k: int = None) -> List[Dict]:
        """Retrieve top candidates using semantic search"""
        if top_k is None:
            top_k = self.TOP_K

        # Preprocess query
        processed_query = self._preprocess_query(query)

        # Encode query
        query_emb = self.bi.encode([processed_query])

        # Calculate similarities
        similarities = cosine_similarity(query_emb, self.embeddings)[0]

        # Get top indices
        top_indices = similarities.argsort()[-top_k:][::-1]

        # Filter by minimum similarity
        candidates = []
        for idx in top_indices:
            if similarities[idx] >= self.MIN_SIM:
                candidates.append({
                    "idx": int(idx),
                    "question": self.faq.iloc[idx]["question"],
                    "answer": self.faq.iloc[idx]["answer"],
                    "category": self.faq.iloc[idx]["category"],
                    "similarity": float(similarities[idx])
                })

        return candidates

    def _rerank_candidates(self, query: str, candidates: List[Dict]) -> List[Dict]:
        """Rerank candidates using cross-encoder"""
        if not candidates:
            return []

        # Prepare pairs for cross-encoder
        pairs = [[query, c["question"]] for c in candidates]

        # Get cross-encoder scores
        scores = self.cross.predict(pairs, convert_to_numpy=True)

        # Add scores to candidates
        for c, score in zip(candidates, scores):
            c["cross_score"] = float(score)

        # Filter and sort by cross-encoder score
        reranked = [c for c in candidates if c["cross_score"] >= self.CROSS_OK]
        reranked.sort(key=lambda x: x["cross_score"], reverse=True)

        return reranked

    def _extract_answer(self, query: str, context: str) -> Dict[str, Any]:
        """Extract specific answer using QA model"""
        try:
            result = self.qa(question=query, context=context)
            return {
                "answer": result["answer"],
                "score": result["score"],
                "start": result.get("start", 0),
                "end": result.get("end", len(result["answer"]))
            }
        except Exception as e:
            logging.warning(f"QA extraction failed: {e}")
            return {"answer": context, "score": 0.5}

    def _create_friendly_response(self, answers: List[str], confidence: str = "medium") -> str:
        """Create a user-friendly response from multiple answers"""
        if not answers:
            return self.CONFIDENCE_LEVELS["none"]

        # Remove duplicates while preserving order
        unique_answers = []
        seen = set()
        for ans in answers:
            normalized = ans.lower().strip()
            if normalized not in seen:
                seen.add(normalized)
                unique_answers.append(ans)

        if len(unique_answers) == 1:
            # Single answer - return as is with confidence prefix
            return self.CONFIDENCE_LEVELS[confidence] + unique_answers[0]

        # Multiple answers - need to summarize
        combined_text = " ".join(unique_answers)

        # If text is short enough, format it nicely
        if len(combined_text) < 300:
            response = self.CONFIDENCE_LEVELS[confidence]
            for i, answer in enumerate(unique_answers):
                if "•" in answer:
                    # Already has bullets
                    response += answer + "\n\n"
                else:
                    # Add as paragraph
                    response += answer + "\n\n"
            return response.strip()

        # Long text - summarize it
        try:
            # Prepare text for summarization
            summary_input = f"Summarize the following information about Jupiter banking services: {combined_text}"

            # Generate summary
            summary = self.summarizer(summary_input, max_length=150, min_length=50, do_sample=False)
            summarized_text = summary[0]['summary_text']

            # Make it more conversational
            response = self.CONFIDENCE_LEVELS[confidence]
            response += self._make_conversational(summarized_text)

            return response

        except Exception as e:
            logging.warning(f"Summarization failed: {e}")
            # Fallback to formatted response
            return self._format_multiple_answers(unique_answers, confidence)

    def _make_conversational(self, text: str) -> str:
        """Make response more conversational and friendly"""
        # Add appropriate punctuation if missing
        if text and text[-1] not in '.!?':
            text += '.'

        # Replace robotic phrases
        replacements = {
            "The user": "You",
            "the user": "you",
            "It is": "It's",
            "You will": "You'll",
            "You can not": "You can't",
            "Do not": "Don't",
        }

        for old, new in replacements.items():
            text = text.replace(old, new)

        return text

    def _format_multiple_answers(self, answers: List[str], confidence: str) -> str:
        """Format multiple answers nicely"""
        response = self.CONFIDENCE_LEVELS[confidence]

        if len(answers) <= 3:
            # Few answers - show all
            for answer in answers:
                if "•" in answer:
                    response += answer + "\n\n"
                else:
                    response += f"• {answer}\n\n"
        else:
            # Many answers - group by category
            response += "Here are the key points:\n\n"
            for i, answer in enumerate(answers[:5]):  # Limit to 5
                response += f"{i+1}. {answer}\n\n"

        return response.strip()

    # ------------------------- Main API ------------------------------- #
    def generate_response(self, query: str) -> str:
        """Generate response for user query"""
        query = self._clean(query)

        # Step 1: Retrieve candidates
        candidates = self._retrieve_candidates(query)

        if not candidates:
            return self.CONFIDENCE_LEVELS["none"]

        # Step 2: Check for high similarity match
        if candidates[0]["similarity"] >= self.HIGH_SIM:
            return self.CONFIDENCE_LEVELS["high"] + candidates[0]["answer"]

        # Step 3: Rerank candidates
        reranked = self._rerank_candidates(query, candidates)

        if not reranked:
            # Use original candidates with lower confidence
            reranked = candidates[:3]
            confidence = "low"
        else:
            confidence = "high" if reranked[0]["cross_score"] > 0.8 else "medium"

        # Step 4: Extract relevant answers
        relevant_answers = []

        for candidate in reranked[:5]:  # Top 5 reranked
            # Try QA extraction for more specific answer
            qa_result = self._extract_answer(query, candidate["answer"])

            if qa_result["score"] > 0.3:
                # Good QA match
                relevant_answers.append(qa_result["answer"])
            else:
                # Use full answer if QA didn't find specific part
                relevant_answers.append(candidate["answer"])

        # Step 5: Create final response
        final_response = self._create_friendly_response(relevant_answers, confidence)

        return final_response

    def suggest_related_queries(self, query: str) -> List[str]:
        """Suggest related queries based on similar questions"""
        candidates = self._retrieve_candidates(query, top_k=10)

        related = []
        seen = set()

        for candidate in candidates:
            if candidate["similarity"] >= 0.5 and candidate["similarity"] < 0.9:
                # Clean question for display
                clean_q = candidate["question"].strip()
                if clean_q.lower() not in seen and clean_q.lower() != query.lower():
                    seen.add(clean_q.lower())
                    related.append(clean_q)

        # Return top 5 related queries
        return related[:5]

    def get_categories(self) -> List[str]:
        """Get all available FAQ categories"""
        return sorted(self.faq["category"].unique().tolist())

    def get_faqs_by_category(self, category: str) -> List[Dict[str, str]]:
        """Get all FAQs for a specific category"""
        cat_faqs = self.faq[self.faq["category"].str.lower() == category.lower()]

        return [
            {
                "question": row["question"],
                "answer": row["answer"]
            }
            for _, row in cat_faqs.iterrows()
        ]

    def search_faqs(self, keyword: str) -> List[Dict[str, str]]:
        """Simple keyword search in FAQs"""
        keyword_lower = keyword.lower()

        matches = []
        for _, row in self.faq.iterrows():
            if (keyword_lower in row["question"].lower() or
                    keyword_lower in row["answer"].lower()):
                matches.append({
                    "question": row["question"],
                    "answer": row["answer"],
                    "category": row["category"]
                })

        return matches[:10]  # Limit to 10 results


# Evaluation module
class BotEvaluator:
    """Evaluate bot performance"""

    def __init__(self, bot: JupiterFAQBot):
        self.bot = bot

    def create_test_queries(self) -> List[Dict[str, str]]:
        """Create test queries based on FAQ categories"""
        test_queries = [
            # Account queries
            {"query": "How do I open an account?", "expected_category": "Account"},
            {"query": "What is Jupiter savings account?", "expected_category": "Account"},

            # Payment queries
            {"query": "How to make UPI payment?", "expected_category": "Payments"},
            {"query": "What is the daily transaction limit?", "expected_category": "Payments"},

            # Rewards queries
            {"query": "How do I earn cashback?", "expected_category": "Rewards"},
            {"query": "What are Jewels?", "expected_category": "Rewards"},

            # Investment queries
            {"query": "Can I invest in mutual funds?", "expected_category": "Investments"},
            {"query": "What is Magic Spends?", "expected_category": "Magic Spends"},

            # Loan queries
            {"query": "How to apply for personal loan?", "expected_category": "Jupiter Loans"},
            {"query": "What is the interest rate?", "expected_category": "Jupiter Loans"},

            # Card queries
            {"query": "How to get credit card?", "expected_category": "Edge+ Credit Card"},
            {"query": "Is there any annual fee?", "expected_category": "Edge+ Credit Card"},
        ]

        return test_queries

    def evaluate_retrieval_accuracy(self) -> Dict[str, float]:
        """Evaluate how well the bot retrieves relevant information"""
        test_queries = self.create_test_queries()

        correct = 0
        total = len(test_queries)

        results = []

        for test in test_queries:
            response = self.bot.generate_response(test["query"])

            # Check if response mentions expected category content
            is_correct = test["expected_category"].lower() in response.lower()

            if is_correct:
                correct += 1

            results.append({
                "query": test["query"],
                "expected_category": test["expected_category"],
                "response": response[:200] + "..." if len(response) > 200 else response,
                "correct": is_correct
            })

        accuracy = correct / total if total > 0 else 0

        return {
            "accuracy": accuracy,
            "correct": correct,
            "total": total,
            "results": results
        }

    def evaluate_response_quality(self) -> Dict[str, Any]:
        """Evaluate the quality of responses"""
        test_queries = [
            "What is Jupiter?",
            "How do I earn rewards?",
            "Tell me about credit cards",
            "Can I get a loan?",
            "How to invest money?"
        ]

        quality_metrics = []

        for query in test_queries:
            response = self.bot.generate_response(query)

            # Check quality indicators
            has_greeting = any(phrase in response for phrase in ["Based on", "Here's", "I found"])
            has_structure = "\n" in response or "•" in response
            appropriate_length = 50 < len(response) < 500

            quality_score = sum([has_greeting, has_structure, appropriate_length]) / 3

            quality_metrics.append({
                "query": query,
                "response_length": len(response),
                "has_greeting": has_greeting,
                "has_structure": has_structure,
                "appropriate_length": appropriate_length,
                "quality_score": quality_score
            })

        avg_quality = sum(m["quality_score"] for m in quality_metrics) / len(quality_metrics)

        return {
            "average_quality_score": avg_quality,
            "metrics": quality_metrics
        }


# Utility functions for data preparation
def prepare_faq_data(csv_path: str = "data/faqs.csv") -> pd.DataFrame:
    """Prepare and validate FAQ data"""
    df = pd.read_csv(csv_path)

    # Ensure required columns exist
    required_cols = ["question", "answer", "category"]
    if not all(col in df.columns for col in required_cols):
        raise ValueError(f"CSV must contain columns: {required_cols}")

    # Basic stats
    print(f"Total FAQs: {len(df)}")
    print(f"Categories: {df['category'].nunique()}")
    print(f"\nCategory distribution:")
    print(df['category'].value_counts())

    return df


# Main execution example
if __name__ == "__main__":
    # Initialize bot
    bot = JupiterFAQBot()

    # Test some queries
    test_queries = [
        "How do I open a savings account?",
        "What are the cashback rates?",
        "Can I get a personal loan?",
        "How to use UPI?",
        "Tell me about investments"
    ]

    print("\n" + "="*50)
    print("Testing Jupiter FAQ Bot")
    print("="*50 + "\n")

    for query in test_queries:
        print(f"Q: {query}")
        response = bot.generate_response(query)
        print(f"A: {response}\n")

        # Show related queries
        related = bot.suggest_related_queries(query)
        if related:
            print("Related questions:")
            for r in related[:3]:
                print(f"  - {r}")
        print("\n" + "-"*50 + "\n")

    # Run evaluation
    print("\n" + "="*50)
    print("Running Evaluation")
    print("="*50 + "\n")

    evaluator = BotEvaluator(bot)

    # Retrieval accuracy
    accuracy_results = evaluator.evaluate_retrieval_accuracy()
    print(f"Retrieval Accuracy: {accuracy_results['accuracy']:.2%}")
    print(f"Correct: {accuracy_results['correct']}/{accuracy_results['total']}")

    # Response quality
    quality_results = evaluator.evaluate_response_quality()
    print(f"\nAverage Response Quality: {quality_results['average_quality_score']:.2%}")