Commit 3efe7a4
Parent(s): ccbeb57

Add Dockerfile for Streamlit deployment

Files changed:
- .gitignore +34 -0
- Data_Cleaning.py +99 -0
- Dockerfile +21 -8
- Embeddings.py +319 -0
- Logger.py +109 -0
- README.md +37 -0
- app.py +60 -0
- app_colabcode.ipynb +805 -0
- config.json +25 -0
- evaluation.py +161 -0
- requirements.txt +18 -3
.gitignore
ADDED
@@ -0,0 +1,34 @@

# Python cache
__pycache__/
*.pyc
*.pyo
*.pyd

# Logs
logs/
*.log

# Checkpoints and outputs
*.ckpt
*.idx
*.pkl
*.jsonl

# Environment files
.env
*.env
*.bak

# Jupyter/Colab
.ipynb_checkpoints/

# System files
.DS_Store
Thumbs.db

# Project Files
eval_dataset.json
test_questions.txt
experiment.py
Retrieval_Summarization.py
run_evalution.py
Data_Cleaning.py
ADDED
@@ -0,0 +1,99 @@

import pdfplumber
import os
import multiprocessing
from tqdm import tqdm

from Logger import GetLogger

class GetDataCleaning:
    def __init__(self, root_folder, excluding_folder=[], logger=None):

        if not logger:
            obj = GetLogger()
            logger = obj.get_logger()
        self.logger = logger

        self.root_folder = root_folder
        self.excluding_folder = excluding_folder

        self.folder_list = [item for item in os.listdir(self.root_folder) if (("txt" not in item.split("_")) and (item not in excluding_folder))]
        self.logger.info("all the folder list is generated successfully")

    def pdf_to_txt(self, pdf_path, txt_path):
        text = ""
        with pdfplumber.open(pdf_path) as pdf:
            for page in pdf.pages:
                page_text = page.extract_text()
                if page_text:
                    text += page_text + "\n"

        with open(txt_path, "w", encoding="utf-8") as f:
            f.write(text)

    def clean_txt(self, text):
        lines = text.split("\n")
        cleaned = []

        for line in lines:
            line = line.strip()
            if not line:
                continue
            if line.isdigit():
                continue
            if line in ["Infosys", "ICICI Bank"]:
                continue
            cleaned.append(line)

        return " ".join(cleaned)

    def process_file(self, folder, file, logger):
        """Single file processing pipeline"""

        input_pdf = os.path.join(self.root_folder, folder, file)
        output_txt = os.path.join(self.root_folder, folder + "_txt", file.replace(".pdf", ".txt"))
        output_cleaned = os.path.join(self.root_folder, folder + "_cleaned_txt", file.replace(".pdf", ".txt"))

        # Convert PDF → TXT
        self.pdf_to_txt(input_pdf, output_txt)

        # Clean text
        raw_text = open(output_txt, encoding="utf-8").read()
        cleaned_text = self.clean_txt(raw_text)

        with open(output_cleaned, "w", encoding="utf-8") as f:
            f.write(cleaned_text)

        logger.info(f"✅ Processed: {folder}/{file}")

    def run(self, workers=4):
        try:
            self.logger.info("🚀 Starting Cleaning Process")
            for folder in self.folder_list:

                os.makedirs(os.path.join(self.root_folder, folder + "_txt"), exist_ok=True)
                os.makedirs(os.path.join(self.root_folder, folder + "_cleaned_txt"), exist_ok=True)

                pdf_files = [
                    f for f in os.listdir(os.path.join(self.root_folder, folder))
                    if f.endswith(".pdf")
                ]

                # Run parallel processing
                with multiprocessing.Pool(processes=workers) as pool:
                    pool.starmap(self.process_file, [(folder, f, self.logger) for f in pdf_files])
                    pool.close()
                    pool.join()

                self.logger.info(f"Data Cleaning completed for folder: {folder}")
        except Exception as e:
            self.logger.error(f"Got Error: {e}")


# if __name__ == "__main__":
#     obj = GetDataCleaning(root_folder="financial_reports", excluding_folder=["ICICI"])
#     obj.run()
#     obj.process_file("ICICI", "icici-bank-23.pdf", obj.logger)  # for experiment only
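A minimal usage sketch for this cleaning pipeline, assuming a `financial_reports/` directory with one sub-folder of PDFs per company (the folder name is illustrative, not part of the commit):

```python
# Minimal sketch: run the PDF -> cleaned-text pipeline.
# Assumes financial_reports/<Company>/ sub-folders of PDFs already exist.
from Data_Cleaning import GetDataCleaning

if __name__ == "__main__":  # guard required for multiprocessing.Pool on spawn-based platforms
    cleaner = GetDataCleaning(root_folder="financial_reports")
    cleaner.run(workers=4)  # writes <folder>_txt/ and <folder>_cleaned_txt/ alongside each source folder
```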
Dockerfile
CHANGED
@@ -1,20 +1,33 @@
-FROM python:3.13.5-slim
+
+# Base image with Python
+FROM python:3.11-slim
+
+# Prevent Python from writing .pyc files and buffering output
+ENV PYTHONDONTWRITEBYTECODE=1
+ENV PYTHONUNBUFFERED=1
+
+# Set working directory
 WORKDIR /app
 
+# Install system dependencies (for faiss, etc.)
 RUN apt-get update && apt-get install -y \
     build-essential \
-    curl \
     git \
+    curl \
     && rm -rf /var/lib/apt/lists/*
 
-
-COPY
-
-
-
-ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
+# Copy requirements first (better cache usage)
+COPY requirements.txt .
+
+# Install Python dependencies
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Copy project files
+COPY . .
+
+# Expose Streamlit default port
+EXPOSE 7860
+
+# Run Streamlit app
+CMD ["streamlit", "run", "app.py", "--server.port=7860", "--server.address=0.0.0.0"]
Embeddings.py
ADDED
@@ -0,0 +1,319 @@

import os
import glob
import pickle, json
from tqdm import tqdm
import numpy as np

# Try imports with friendly errors
try:
    import faiss
except Exception as e:
    raise ImportError("faiss is required. Install cpu version: `pip install faiss-cpu` or install via conda for GPU (faiss-gpu).") from e

try:
    from sentence_transformers import SentenceTransformer
except Exception as e:
    raise ImportError("sentence-transformers is required. `pip install sentence-transformers`") from e

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
from dotenv import load_dotenv


from Data_Cleaning import GetDataCleaning
from Logger import GetLogger


class GetEmbeddings:
    """
    Embedding pipeline for cleaned text files.
    Generates embeddings using SentenceTransformers, builds a FAISS index,
    and allows searching queries against the vector database.
    """

    def __init__(self, config_path="config.json", logger=None):

        with open(config_path, "r") as f:
            self.config = json.load(f)

        cfg_paths = self.config["paths"]
        cfg_emb = self.config["embedding"]

        self.root = cfg_paths["root"]
        self.cleaned_suffix = "_cleaned_txt"
        self.chunk_words = cfg_emb["chunk_words"]
        self.batch_size = cfg_emb["batch_size"]
        self.faiss_index_path = cfg_paths["faiss_index"]
        self.metadata_path = cfg_paths["metadata"]
        self.embedding_model = cfg_emb["model"]

        if not logger:
            obj = GetLogger()
            logger = obj.get_logger()
        self.logger = logger
        self.logger.info("Initializing Embedding Pipeline...")

        # Device
        self.device = "cuda" if self.check_cuda() and cfg_emb["use_gpu"] else "cpu"
        load_dotenv()
        self.hf_token = os.getenv("HF_TOKEN")

    def check_cuda(self):
        """Return True if CUDA is available and usable."""
        try:
            if torch.cuda.is_available():
                _ = torch.cuda.current_device()
                self.logger.info(f"✅ CUDA available. Device: {torch.cuda.get_device_name(0)}")
                return True
            self.logger.info("⚠️ CUDA not available. Using CPU.")
            return False
        except Exception as e:
            self.logger.error(f"Error checking CUDA, defaulting to CPU. Error: {e}")
            return False

    def list_cleaned_files(self):
        """Return sorted list of cleaned text files under root/*{cleaned_suffix}/*.txt"""
        pattern = os.path.join(self.root, f"*{self.cleaned_suffix}", "*.txt")
        files = glob.glob(pattern)
        files.sort()
        return files

    def read_text_file(self, path):
        """Read a text file and return string content."""
        with open(path, "r", encoding="utf-8") as f:
            return f.read()

    def chunk_text_words(self, text):
        """
        Simple word-based chunking.
        Returns list of text chunks.
        """
        words = text.split()
        if not words:
            return []
        return [" ".join(words[i:i + self.chunk_words]) for i in range(0, len(words), self.chunk_words)]

    def save_index_and_metadata(self):
        """Save FAISS index and metadata to disk."""
        os.makedirs(os.path.dirname(self.faiss_index_path), exist_ok=True)
        faiss.write_index(self.index, self.faiss_index_path)
        with open(self.metadata_path, "wb") as f:
            pickle.dump(self.metadata, f)
        self.logger.info(f"💾 Saved FAISS index to {self.faiss_index_path}")
        self.logger.info(f"💾 Saved metadata to {self.metadata_path}")

    def load_index_and_metadata(self):
        """Load FAISS index and metadata if they exist."""
        if os.path.exists(self.faiss_index_path) and os.path.exists(self.metadata_path):
            try:
                self.index = faiss.read_index(self.faiss_index_path)
                with open(self.metadata_path, "rb") as f:
                    self.metadata = pickle.load(f)
                self.logger.info("✅ Loaded existing FAISS index + metadata from disk.")
                return True
            except Exception as e:
                self.logger.warning(f"⚠️ Failed to load FAISS index/metadata, will rebuild. Error: {e}")
                return False
        return False

    def load_encoder(self):
        """Load the sentence-transformer encoder."""
        self.encoder = SentenceTransformer(self.embedding_model, device=self.device)
        self.logger.info(f"Loaded embedding model '{self.embedding_model}' on {self.device}")
        return self.encoder

    def building_embeddings_index(self, files):
        """Build embeddings for all text chunks and return FAISS index + metadata."""

        all_embeddings, metadata = [], []
        next_id = 0
        # Iterate files and chunks
        for fp in tqdm(files, desc="Files", unit="file"):
            text = self.read_text_file(fp)

            if not text.strip():
                continue

            # metadata: infer company and file from path
            # e.g., financial_reports/Infosys_cleaned_txt/Infosys_2023_AR.txt
            rel = os.path.relpath(fp, self.root)
            folder = rel.split(os.sep)[0]
            filename = os.path.basename(fp)

            chunks = self.chunk_text_words(text)
            if not chunks:
                continue

            for i in range(0, len(chunks), self.batch_size):
                batch = chunks[i:i + self.batch_size]
                embs = self.encoder.encode(batch, show_progress_bar=False, convert_to_numpy=True)
                embs = embs.astype(np.float32)

                for j, vec in enumerate(embs):
                    all_embeddings.append(vec)
                    metadata.append({
                        "id": next_id,
                        "source_folder": folder,
                        "file": filename,
                        "chunk_id": i + j,
                        "text": batch[j]  # store chunk text for retrieval
                    })
                    next_id += 1

        if not all_embeddings:
            raise RuntimeError("No embeddings were produced. Check cleaned files and chunking.")

        emb_matrix = np.vstack(all_embeddings).astype(np.float32)
        faiss.normalize_L2(emb_matrix)

        # Build FAISS index (IndexFlatIP over normalized vectors = cosine similarity)
        dim = emb_matrix.shape[1]
        self.index = faiss.IndexFlatIP(dim)
        self.index.add(emb_matrix)
        self.metadata = metadata
        self.logger.info(f"✅ Built FAISS index with {self.index.ntotal} vectors, dim={dim}")

        return self.index, self.metadata

    def run(self):
        """Main entry: load or build embeddings + FAISS index."""
        if self.load_index_and_metadata():
            return

        files = self.list_cleaned_files()
        if not files:
            self.logger.error("❌ No cleaned text files found.")
            raise SystemExit(1)
        self.load_encoder()
        self.building_embeddings_index(files)
        self.save_index_and_metadata()

    def load_summarizer(self, model_name="google/gemma-2b"):
        """
        Load summarizer LLM once.
        If already loaded, skip.
        """
        if hasattr(self, "summarizer_pipeline"):
            self.logger.info("ℹ️ Summarizer already loaded, skipping reload.")
            return

        try:
            self.logger.info(f"⏳ Loading summarizer model '{model_name}'...")
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=self.hf_token)
            self.summarizer_model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                device_map=self.device,
                token=self.hf_token
            )
            self.summarizer_pipeline = pipeline(
                "text-generation",
                model=self.summarizer_model,
                tokenizer=self.tokenizer
            )
            self.logger.info(f"✅ Summarizer model '{model_name}' loaded successfully.")

        except RuntimeError as e:
            if "CUDA out of memory" in str(e):
                self.logger.warning("⚠️ CUDA OOM while loading summarizer. Retrying on CPU...")
                self.device = "cpu"
                torch.cuda.empty_cache()
                return self.load_summarizer(model_name=model_name)
            else:
                self.logger.error(f"❌ Failed to load summarizer: {e}")
                raise

    def summarize_chunks(self, chunks, max_content_tokens=2048, max_output_tokens=256):
        """
        Summarize a list of text chunks using the LLM.
        - Chunks are joined until they fit into max_content_tokens
        - Generates a concise summary.
        """

        if not hasattr(self, "summarizer_pipeline"):
            self.load_summarizer()
            self.logger.info("Summarizer was not initialized; called load_summarizer() with default parameters.")

        # Join chunks into one context, respecting token budget
        context = " ".join(chunks)
        input_tokens = len(self.tokenizer.encode(context))

        if input_tokens > max_content_tokens:
            # Trim to fit context window (approximate: trims by word count, not tokens)
            context = " ".join(context.split()[:max_content_tokens])
            self.logger.warning("⚠️ Context truncated to fit within model token limit.")

        # Build summarization prompt
        prompt = f"""
        Summarize the following financial report excerpts into a concise answer.
        Keep it factual, short, and grounded in the text.

        Excerpts:
        {context}

        Summary:
        """

        try:
            output = self.summarizer_pipeline(
                prompt,
                max_new_tokens=max_output_tokens,
                do_sample=False
            )[0]["generated_text"]

            if "Summary:" in output:
                summary = output.split("Summary:")[-1].strip()
            else:
                summary = output.strip()

            return summary

        except RuntimeError as e:
            if "CUDA out of memory" in str(e):
                self.logger.warning("⚠️ CUDA OOM during summarization. Retrying on CPU...")
                self.device = "cpu"
                torch.cuda.empty_cache()
                return self.summarize_chunks(chunks, max_content_tokens, max_output_tokens)
            else:
                self.logger.error(f"❌ Summarizer failed: {e}. Falling back to raw chunks.")
                return " ".join(chunks[:2])  # fallback: return first 2 chunks

    def answer_query(self, query, top_k=3):
        """
        End-to-end QA:
        - Retrieve relevant chunks from FAISS
        - Summarize into a final answer.
        """
        try:
            # Step 1: Retrieve
            self.logger.info(f"🔍 Searching vector DB for query: {query}")
            q_emb = self.encoder.encode(query, show_progress_bar=False, convert_to_numpy=True).reshape(1, -1)
            faiss.normalize_L2(q_emb)

            scores, idxs = self.index.search(q_emb, k=top_k)
            chunks = [self.metadata[idx]["text"] for idx in idxs[0]]

            # Step 2: Summarize
            summary = self.summarize_chunks(chunks)

            # Log results
            self.logger.info(f"✅ Final Answer: {summary}")
            return summary

        except Exception as e:
            self.logger.error(f"Error in answer_query: {e}")
            return None


# Example
# ge = GetEmbeddings()
# ge.run()
# # NEW STEP
# ge.load_summarizer("google/gemma-2b")
# answer = ge.answer_query("What are the key highlights from Q2 financial report?")
# print(answer)
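The pipeline reads its settings from config.json, whose contents are not shown in this commit view. A plausible minimal sketch follows: the key names are taken from what `__init__` and the Evaluator read, but every value (paths, model name, chunk size) is an illustrative assumption, not the committed config. The second half is a toy check of the design choice noted in `building_embeddings_index`: inner product over L2-normalized vectors equals cosine similarity.

```python
# Illustrative only: key names mirror what Embeddings.py reads; all values are assumptions.
import json

import numpy as np
import faiss

config = {
    "paths": {
        "root": "financial_reports",
        "faiss_index": "artifacts/index.idx",
        "metadata": "artifacts/metadata.pkl",
        "eval_dataset": "eval_dataset.json",
    },
    "embedding": {
        "model": "sentence-transformers/all-MiniLM-L6-v2",
        "chunk_words": 200,
        "batch_size": 32,
        "use_gpu": True,
    },
}
with open("config.json", "w") as f:
    json.dump(config, f, indent=2)

# Why IndexFlatIP + normalize_L2 gives cosine similarity: after L2-normalizing,
# the inner product of two vectors is exactly their cosine. Toy check:
vecs = np.random.rand(5, 8).astype(np.float32)
faiss.normalize_L2(vecs)                  # in-place row-wise normalization
index = faiss.IndexFlatIP(vecs.shape[1])  # inner-product index
index.add(vecs)
query = vecs[:1].copy()
scores, ids = index.search(query, 3)
print(ids[0], scores[0])                  # first hit is the query itself, score ~1.0
```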
Logger.py
ADDED
@@ -0,0 +1,109 @@

import logging
import os, json
from logging.handlers import TimedRotatingFileHandler
from datetime import datetime

class GetLogger:
    def __init__(self, logging_level="INFO", log_to_console=True, log_dir="logs"):
        """
        Advanced Logger
        - Logs to both file (rotating) and console
        - Default rotation: daily, keep last 3 logs
        - Safe filename (no ':' in timestamp)
        """

        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging_level.upper())

        # Avoid duplicate handlers
        if self.logger.hasHandlers():
            self.logger.handlers.clear()

        # Ensure log directory exists
        os.makedirs(log_dir, exist_ok=True)

        # File handler (rotates daily, keeps 3 backups)
        file_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".log"
        log_path = os.path.join(log_dir, file_name)
        file_handler = TimedRotatingFileHandler(
            filename=log_path, when="D", interval=1, backupCount=3, encoding="utf-8"
        )

        formatter = logging.Formatter(
            "%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(funcName)s() - %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S"
        )
        file_handler.setFormatter(formatter)
        self.logger.addHandler(file_handler)

        # Console handler (optional)
        if log_to_console:
            console_handler = logging.StreamHandler()
            console_handler.setFormatter(formatter)
            self.logger.addHandler(console_handler)

    def get_logger(self):
        return self.logger

    def delete_logger(self):
        """Remove all handlers and delete logger."""
        handlers = self.logger.handlers[:]
        for handler in handlers:
            self.logger.removeHandler(handler)
            handler.close()
        del self.logger

class MetricsLogger:
    """
    Collects evaluation metrics and saves aggregated statistics.
    """
    def __init__(self, save_path="logs/metrics_summary.json", logger=None):
        self.save_path = save_path
        self.metrics = []  # store per-query metrics
        self.logger = logger or logging.getLogger(__name__)

    def log_query_metrics(self, query, result_dict):
        """
        Log metrics for a single query.
        Example: result_dict = {"latency_sec": 0.5, "rougeL": 0.7, ...}
        """
        record = {"query": query}
        record.update(result_dict)
        self.metrics.append(record)
        self.logger.info(f"📊 Metrics logged for query: {query[:50]}...")

    def summarize(self):
        """Aggregate metrics (mean values)."""
        if not self.metrics:
            return {}

        summary = {}
        keys = [k for k in self.metrics[0].keys() if k != "query"]
        for key in keys:
            values = [m[key] for m in self.metrics if key in m and isinstance(m[key], (int, float))]
            if values:
                summary[f"avg_{key}"] = float(sum(values) / len(values))

        return summary

    def save(self):
        """Save all metrics + summary to JSON."""
        os.makedirs(os.path.dirname(self.save_path), exist_ok=True)
        data = {
            "per_query": self.metrics,
            "summary": self.summarize()
        }
        with open(self.save_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)
        self.logger.info(f"✅ Metrics saved to {self.save_path}")
        return data


# Example
if __name__ == "__main__":
    obj = GetLogger()
    logger = obj.get_logger()
    logger.info("✅ Logger initialized successfully")
    logger.warning("⚠️ This is a warning")
    logger.error("❌ This is an error")
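The `__main__` block above exercises only GetLogger; a companion sketch for MetricsLogger, with made-up metric names and values (real ones come from the evaluation pipeline):

```python
from Logger import GetLogger, MetricsLogger

logger = GetLogger().get_logger()
metrics = MetricsLogger(save_path="logs/metrics_summary.json", logger=logger)

# Hypothetical per-query results, in the shape log_query_metrics expects.
metrics.log_query_metrics("What was FY23 revenue?", {"latency_sec": 0.8, "rougeL": 0.41, "bleu": 0.12})
metrics.log_query_metrics("Summarize Q2 highlights.", {"latency_sec": 1.1, "rougeL": 0.37, "bleu": 0.09})

print(metrics.summarize())  # e.g. {'avg_latency_sec': 0.95, 'avg_rougeL': 0.39, 'avg_bleu': 0.105}
metrics.save()              # writes per-query records plus the averages to JSON
```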
README.md
CHANGED
@@ -18,3 +18,40 @@ Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :hear
 
 If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
 forums](https://discuss.streamlit.io).
+
+# 📊 Financial QA Agent
+
+An AI-powered financial report assistant built with **RAG (Retrieval-Augmented Generation)**.
+This app lets you upload financial reports, search them with semantic embeddings, and get concise answers/summaries using an open-source LLM.
+
+## 🚀 Features
+- Cleans financial report text files automatically
+- Generates sentence-transformer embeddings and indexes them with FAISS for efficient retrieval
+- Summarizes answers using `google/gemma-2b` (or lightweight models for deployment)
+- Streamlit UI for easy interaction
+- Evaluation pipeline with ROUGE, BLEU, and BERTScore
+
+## 🛠️ Tech Stack
+- **Streamlit** for UI
+- **FAISS** for vector search
+- **Sentence-Transformers** for embeddings
+- **Transformers** (Gemma/LLMs) for summarization
+- **Scikit-learn, NLTK, BERTScore** for evaluation metrics
+
+## 📂 Project Structure
+├── app.py                 # Main Streamlit app (entrypoint)
+├── Embeddings.py          # Embedding + FAISS pipeline
+├── Data_Cleaning.py       # Data cleaning utility
+├── Logger.py              # Logging utility
+├── evaluation.py          # Evaluation pipeline
+├── config.json            # Configurations
+├── eval_dataset.json      # Sample evaluation dataset
+├── requirements.txt       # Dependencies
+├── README.md              # Project documentation
+└── .gitignore             # Ignore unnecessary files
+
+
+## ⚡ Running Locally
+```bash
+pip install -r requirements.txt
+streamlit run app.py
+```
app.py
ADDED
@@ -0,0 +1,60 @@

import streamlit as st
import faiss
from Embeddings import GetEmbeddings


# Load Agent once and cache it
@st.cache_resource
def load_agent():
    agent = GetEmbeddings(config_path="config.json")
    agent.run()              # Build/load FAISS
    agent.load_summarizer()  # Load summarizer model
    encoder = agent.load_encoder()
    return agent, encoder


def main():
    st.set_page_config(page_title="📊 Financial QA Agent", layout="wide")

    st.title("📊 Financial QA Agent")
    st.markdown(
        """
        Ask questions about financial reports.
        The system retrieves relevant sections from company reports and summarizes them into concise answers.
        """
    )

    # Sidebar
    st.sidebar.header("⚙️ Settings")
    show_debug = st.sidebar.checkbox("Show retrieved chunks", value=False)

    # Load Agent
    agent, encoder = load_agent()

    # User Input
    query = st.text_area("Enter your financial question:", height=100)

    if st.button("Get Answer"):
        if query.strip() == "":
            st.warning("⚠️ Please enter a query.")
        else:
            with st.spinner("🔎 Searching and generating answer..."):
                answer = agent.answer_query(query, top_k=3)

            st.subheader("✅ Answer")
            st.write(answer)

            if show_debug:
                st.subheader("📂 Retrieved Chunks (Debug)")
                # Show top chunks used
                q_emb = encoder.encode(query, convert_to_numpy=True).reshape(1, -1)
                faiss.normalize_L2(q_emb)
                scores, idxs = agent.index.search(q_emb, k=3)
                for score, idx in zip(scores[0], idxs[0]):
                    st.markdown(f"**Score:** {score:.4f}")
                    st.write(agent.metadata[idx]["text"][:500] + "...")


if __name__ == "__main__":
    main()
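The same pipeline behind `load_agent()` can be exercised without Streamlit; a minimal smoke-test sketch, assuming `config.json` and cleaned text files exist (the question string is made up):

```python
# Hypothetical smoke test for the pipeline behind app.py (no UI involved).
from Embeddings import GetEmbeddings

agent = GetEmbeddings(config_path="config.json")
agent.run()              # load or build the FAISS index + metadata
agent.load_encoder()     # needed before answer_query can embed the question
agent.load_summarizer()  # downloads google/gemma-2b on first use

# answer_query returns a summary string, or None if anything fails.
print(agent.answer_query("What are the key highlights from the Q2 report?", top_k=3))
```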
app_colabcode.ipynb
ADDED
@@ -0,0 +1,805 @@

Colab notebook; cell sources and captured outputs rendered from the JSON:

Cell 1 — install project requirements plus the cloudflared tunnel binary:

!pip install -r requirements.txt -q
!pip install streamlit cloudflared -q
!wget -q https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64.deb
!dpkg -i cloudflared-linux-amd64.deb

!cloudflared --version

Output:
Selecting previously unselected package cloudflared.
(Reading database ... 126441 files and directories currently installed.)
Preparing to unpack cloudflared-linux-amd64.deb ...
Unpacking cloudflared (2025.9.1) ...
Setting up cloudflared (2025.9.1) ...
Processing triggers for man-db (2.10.2-1) ...
cloudflared version 2025.9.1 (built 2025-09-22-13:28 UTC)

Cell 2 — empty code cell.

Cell 3 — CUDA availability check:

import torch

if torch.cuda.is_available():
    print(f"✅ CUDA is available. Using GPU: {torch.cuda.get_device_name(0)}")
    # return True
else:
    print("⚠️ CUDA not available. Falling back to CPU.")
    # return False


# # Load the allocator
# new_alloc = torch.cuda.memory.CUDAPluggableAllocator(
#     'alloc.so', 'my_malloc', 'my_free')
# # Swap the current allocator
# torch.cuda.memory.change_current_allocator(new_alloc)

Output:
✅ CUDA is available. Using GPU: Tesla T4

Cell 4 — %%writefile Embeddings.py. A Colab variant of the Embeddings.py committed above; its source is identical except that the `from Data_Cleaning import GetDataCleaning` and `from Logger import GetLogger` imports and the logger setup in `__init__` are commented out, every `self.logger.*` call is replaced with `print`, `from google.colab import userdata` replaces `from dotenv import load_dotenv`, and the Hugging Face token is hard-coded as a placeholder (`self.hf_token = "your_token"`).

Output:
Overwriting Embeddings.py

Cell 5 — %%writefile Evaluator.py:

%%writefile Evaluator.py
import os
import json
import time
import numpy as np
from tqdm import tqdm

# from Logger import GetLogger, MetricsLogger
# from Embeddings import GetEmbeddings

# Metrics
from sklearn.metrics.pairwise import cosine_similarity
from rouge_score import rouge_scorer
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from bert_score import score as bert_score

class Evaluator:
    """
    Evaluation pipeline for financial QA Agent.
    Uses eval_dataset.json to run queries, collect answers, and compute metrics.
    """
    def __init__(self, config_path="config.json", logger=None):
        with open(config_path, "r") as f:
            self.config = json.load(f)
        self.paths = self.config["paths"]


        # if not logger:
        #     obj = GetLogger()
        #     logger = obj.get_logger()
        # self.logger = logger

        # # Metrics logger
        # self.metrics_logger = MetricsLogger(logger=self.logger)

        # Initialize Agent
        self.agent = GetEmbeddings(config_path=config_path, logger=None)
        self.agent.run()  # Load or rebuild FAISS + embeddings
        self.agent.load_summarizer()  # Load summarizer
        self.encoder = self.agent.load_encoder()

        # Load Dataset
        self.dataset = self.load_dataset()
        self.results = []
        self.failed_queries = []

    def load_dataset(self):
        path = self.paths["eval_dataset"]
        if not os.path.exists(path):
            raise FileNotFoundError(f"Dataset not found: {path}")
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    def measure_latency(self, func, *args, **kwargs):
        """Helper: measure time taken by a function call."""
        start = time.time()
        result = func(*args, **kwargs)
        latency = time.time() - start
        return result, latency

    def evaluate_query(self, query, reference):
        """Run one query, compare answer vs. reference, compute metrics."""
        # try:
        # Run pipeline
        system_answer, latency = self.measure_latency(self.agent.answer_query, query)

        # 1. Embedding similarity (proxy retrieval quality)
        ref_emb = self.encoder.encode([reference], convert_to_numpy=True)
        ans_emb = self.encoder.encode([system_answer], convert_to_numpy=True)
        retrieval_quality = float(cosine_similarity(ref_emb, ans_emb)[0][0])

        # 2. ROUGE-L
        scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
        rouge_score = scorer.score(reference, system_answer)['rougeL'].fmeasure

        # 3. BLEU (with smoothing for short texts)
        smoothie = SmoothingFunction().method4
        bleu = sentence_bleu([reference.split()], system_answer.split(), smoothing_function=smoothie)

        # 4. BERTScore (semantic similarity)
        P, R, F1 = bert_score([system_answer], [reference], lang="en")
        bert_f1 = float(F1.mean())

        metrics = {
            "query": query,
            "reference": reference,
            "system_answer": system_answer,
            "retrieval_quality": retrieval_quality,
            "rougeL": rouge_score,
            "bleu": bleu,
            "bertscore_f1": bert_f1,
            "latency_sec": latency
        }

        # Log into metrics logger
        # self.metrics_logger.log_query_metrics(query, metrics)

        return metrics

        # except Exception as e:
        #     print(f"Error evaluating query '{query}': {e}")
        #     return None


    def run(self):
        """Run evaluation on entire dataset."""
        print("Starting Evaluation...")

        for item in tqdm(self.dataset, desc="Queries"):
            query = item["query"]

Output:
Writing Evaluator.py
" query = item[\"query\"]\n",
|
| 556 |
+
" reference = item[\"reference\"]\n",
|
| 557 |
+
" result = self.evaluate_query(query, reference)\n",
|
| 558 |
+
" if result:\n",
|
| 559 |
+
" self.results.append(result)\n",
|
| 560 |
+
"\n",
|
| 561 |
+
"\n",
|
| 562 |
+
" # Save result\n",
|
| 563 |
+
" with open(self.paths[\"eval_results\"], \"w\", encoding=\"utf-8\") as f:\n",
|
| 564 |
+
" json.dump(self.results, f, indent=2)\n",
|
| 565 |
+
"\n",
|
| 566 |
+
" if self.failed_queries:\n",
|
| 567 |
+
" with open(self.paths[\"failed_queries\"], \"w\", encoding=\"utf-8\") as f:\n",
|
| 568 |
+
" json.dump(self.failed_queries, f, indent=2)\n",
|
| 569 |
+
"\n",
|
| 570 |
+
"\n",
|
| 571 |
+
" # Save metrics summary\n",
|
| 572 |
+
" # summary = self.metrics_logger.save()\n",
|
| 573 |
+
" summary = None\n",
|
| 574 |
+
" print(f\"Evaluation Complete.\")\n",
|
| 575 |
+
" print(f\"📊 Evaluation summary: {summary}\")\n",
|
| 576 |
+
"\n",
|
| 577 |
+
" return self.results, summary\n",
|
| 578 |
+
"\n",
|
| 579 |
+
"\n",
|
| 580 |
+
"if __name__ == \"__main__\":\n",
|
| 581 |
+
" evaluator = Evaluator()\n",
|
| 582 |
+
" results, summary = evaluator.run()\n",
|
| 583 |
+
"\n",
|
| 584 |
+
" print(\"\\n=== Sample Results ===\")\n",
|
| 585 |
+
" print(json.dumps(results[:2], indent=2))\n",
|
| 586 |
+
" print(\"\\n=== Summary ===\")\n",
|
| 587 |
+
" print(json.dumps(summary, indent=2))\n"
|
| 588 |
+
]
|
| 589 |
+
},
|
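As a toy illustration of the metric APIs called in evaluate_query above (all strings are made up; BERTScore is omitted here only because it pulls a large model, but bert_score([answer], [reference], lang="en") returns P, R, F1 tensors analogously):

from rouge_score import rouge_scorer
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

reference = "Revenue grew 12% driven by cloud services."
answer = "Cloud services drove revenue growth of 12%."

# ROUGE-L F1: longest-common-subsequence overlap
scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
print(scorer.score(reference, answer)["rougeL"].fmeasure)

# BLEU with smoothing, which keeps short texts from scoring zero
smoothie = SmoothingFunction().method4
print(sentence_bleu([reference.split()], answer.split(), smoothing_function=smoothie))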
| 590 |
+
{
|
| 591 |
+
"cell_type": "code",
|
| 592 |
+
"execution_count": null,
|
| 593 |
+
"metadata": {
|
| 594 |
+
"colab": {
|
| 595 |
+
"base_uri": "https://localhost:8080/"
|
| 596 |
+
},
|
| 597 |
+
"id": "_SgMUhSJbdcu",
|
| 598 |
+
"outputId": "c79fe42b-517f-40b7-cc2b-71ddaae05084"
|
| 599 |
+
},
|
| 600 |
+
"outputs": [
|
| 601 |
+
{
|
| 602 |
+
"name": "stdout",
|
| 603 |
+
"output_type": "stream",
|
| 604 |
+
"text": [
|
| 605 |
+
"Overwriting app.py\n"
|
| 606 |
+
]
|
| 607 |
+
}
|
| 608 |
+
],
|
| 609 |
+
"source": [
|
| 610 |
+
"%%writefile app.py\n",
|
| 611 |
+
"import streamlit as st\n",
|
| 612 |
+
"import json\n",
|
| 613 |
+
"import faiss\n",
|
| 614 |
+
"import numpy as np\n",
|
| 615 |
+
"import re\n",
|
| 616 |
+
"from Embeddings import GetEmbeddings\n",
|
| 617 |
+
"from Logger import GetLogger\n",
|
| 618 |
+
"\n",
|
| 619 |
+
"# ================================\n",
|
| 620 |
+
"# Load Config\n",
|
| 621 |
+
"# ================================\n",
|
| 622 |
+
"with open(\"config.json\", \"r\") as f:\n",
|
| 623 |
+
" config = json.load(f)\n",
|
| 624 |
+
"\n",
|
| 625 |
+
"# Initialize Logger\n",
|
| 626 |
+
"log_obj = GetLogger()\n",
|
| 627 |
+
"logger = log_obj.get_logger()\n",
|
| 628 |
+
"\n",
|
| 629 |
+
"# Initialize QA Agent\n",
|
| 630 |
+
"@st.cache_resource\n",
|
| 631 |
+
"def load_agent():\n",
|
| 632 |
+
" agent = GetEmbeddings(config_path=\"config.json\", logger=logger)\n",
|
| 633 |
+
" agent.run() # load or build FAISS index\n",
|
| 634 |
+
" encoder = agent.load_encoder()\n",
|
| 635 |
+
" agent.load_summarizer()\n",
|
| 636 |
+
" return agent, encoder\n",
|
| 637 |
+
"\n",
|
| 638 |
+
"agent, encoder = load_agent()\n",
|
| 639 |
+
"\n",
|
| 640 |
+
"# ================================\n",
|
| 641 |
+
"# Streamlit UI\n",
|
| 642 |
+
"# ================================\n",
|
| 643 |
+
"st.set_page_config(page_title=\"Financial QA Agent\", layout=\"wide\")\n",
|
| 644 |
+
"\n",
|
| 645 |
+
"# --- Header ---\n",
|
| 646 |
+
"st.title(\"💹 Financial Report QA Agent\")\n",
|
| 647 |
+
"st.markdown(\n",
|
| 648 |
+
" \"\"\"\n",
|
| 649 |
+
" Welcome!\n",
|
| 650 |
+
" This tool lets you **query annual financial reports** (Infosys, ICICI Bank, etc.)\n",
|
| 651 |
+
" and get **summarized answers** with supporting evidence from the text.\n",
|
| 652 |
+
" \"\"\"\n",
|
| 653 |
+
")\n",
|
| 654 |
+
"\n",
|
| 655 |
+
"# Sidebar - Settings\n",
|
| 656 |
+
"st.sidebar.header(\"⚙️ Settings\")\n",
|
| 657 |
+
"top_k = st.sidebar.slider(\"Top K Chunks\", 1, 10, 3)\n",
|
| 658 |
+
"max_output_tokens = st.sidebar.slider(\"Max Summary Tokens\", 64, 512, 256)\n",
|
| 659 |
+
"\n",
|
| 660 |
+
"# --- Keyword highlighting ---\n",
|
| 661 |
+
"def highlight_keywords(text, keywords=[\"risk\", \"revenue\", \"profit\", \"growth\", \"loss\"]):\n",
|
| 662 |
+
" pattern = re.compile(r\"\\b(\" + \"|\".join(keywords) + r\")\\b\", re.IGNORECASE)\n",
|
| 663 |
+
" return pattern.sub(lambda m: f\"**{m.group(0)}**\", text)\n",
|
| 664 |
+
"\n",
|
| 665 |
+
"# --- Session State for Query History ---\n",
|
| 666 |
+
"if \"history\" not in st.session_state:\n",
|
| 667 |
+
" st.session_state[\"history\"] = []\n",
|
| 668 |
+
"\n",
|
| 669 |
+
"# --- Query input ---\n",
|
| 670 |
+
"query = st.text_input(\"🔍 Enter your question:\", placeholder=\"e.g., What are the main risk factors in 2023?\")\n",
|
| 671 |
+
"\n",
|
| 672 |
+
"if st.button(\"Get Answer\"):\n",
|
| 673 |
+
" if query.strip() == \"\":\n",
|
| 674 |
+
" st.warning(\"Please enter a query.\")\n",
|
| 675 |
+
" else:\n",
|
| 676 |
+
" with st.spinner(\"Searching reports...\"):\n",
|
| 677 |
+
" try:\n",
|
| 678 |
+
" # Retrieve + summarize\n",
|
| 679 |
+
" answer = agent.answer_query(query, top_k=top_k)\n",
|
| 680 |
+
"\n",
|
| 681 |
+
" # --- Display final answer ---\n",
|
| 682 |
+
" st.subheader(\"📌 Answer\")\n",
|
| 683 |
+
" st.success(answer)\n",
|
| 684 |
+
"\n",
|
| 685 |
+
" # --- Show supporting chunks ---\n",
|
| 686 |
+
" st.subheader(\"📂 Supporting Chunks\")\n",
|
| 687 |
+
" q_emb = encoder.encode(query, convert_to_numpy=True).reshape(1, -1)\n",
|
| 688 |
+
" faiss.normalize_L2(q_emb)\n",
|
| 689 |
+
" scores, idxs = agent.index.search(q_emb.astype(np.float32), k=top_k)\n",
|
| 690 |
+
"\n",
|
| 691 |
+
" for score, idx in zip(scores[0], idxs[0]):\n",
|
| 692 |
+
" meta = agent.metadata[idx]\n",
|
| 693 |
+
" with st.expander(f\"📄 {meta['file']} | Chunk {meta['chunk_id']} | Score: {score:.4f}\"):\n",
|
| 694 |
+
" chunk_text = highlight_keywords(meta['text'][:1000])\n",
|
| 695 |
+
" st.markdown(chunk_text)\n",
|
| 696 |
+
"\n",
|
| 697 |
+
" # --- Save Query & Answer to History ---\n",
|
| 698 |
+
" st.session_state[\"history\"].append({\"query\": query, \"answer\": answer})\n",
|
| 699 |
+
"\n",
|
| 700 |
+
" # --- Log query + answer ---\n",
|
| 701 |
+
" logger.info(f\"User Query: {query}\")\n",
|
| 702 |
+
" logger.info(f\"System Answer: {answer}\")\n",
|
| 703 |
+
"\n",
|
| 704 |
+
" # --- Save persistent history JSON ---\n",
|
| 705 |
+
" with open(\"ui_query_history.json\", \"w\", encoding=\"utf-8\") as f:\n",
|
| 706 |
+
" json.dump(st.session_state[\"history\"], f, indent=2)\n",
|
| 707 |
+
"\n",
|
| 708 |
+
" except Exception as e:\n",
|
| 709 |
+
" st.error(f\"Error: {e}\")\n",
|
| 710 |
+
" logger.error(f\"Streamlit UI error: {e}\")\n",
|
| 711 |
+
"\n",
|
| 712 |
+
"# --- Show History in Sidebar ---\n",
|
| 713 |
+
"if st.session_state[\"history\"]:\n",
|
| 714 |
+
" st.sidebar.subheader(\"🕘 Query History\")\n",
|
| 715 |
+
" for item in st.session_state[\"history\"][-5:]: # show last 5 queries\n",
|
| 716 |
+
" st.sidebar.write(f\"**Q:** {item['query']}\")\n",
|
| 717 |
+
" st.sidebar.write(f\"**A:** {item['answer'][:100]}...\")\n",
|
| 718 |
+
" st.sidebar.markdown(\"---\")\n"
|
| 719 |
+
]
|
| 720 |
+
},
|
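A quick standalone check of the highlight_keywords helper above: it wraps matched terms in Markdown bold, which st.markdown then renders (the sample sentence is made up):

import re

def highlight_keywords(text, keywords=["risk", "revenue", "profit", "growth", "loss"]):
    pattern = re.compile(r"\b(" + "|".join(keywords) + r")\b", re.IGNORECASE)
    return pattern.sub(lambda m: f"**{m.group(0)}**", text)

print(highlight_keywords("Revenue growth offset the credit loss."))
# -> **Revenue** **growth** offset the credit **loss**.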
| 721 |
+
{
|
| 722 |
+
"cell_type": "code",
|
| 723 |
+
"execution_count": null,
|
| 724 |
+
"metadata": {
|
| 725 |
+
"colab": {
|
| 726 |
+
"base_uri": "https://localhost:8080/"
|
| 727 |
+
},
|
| 728 |
+
"id": "6UAnlclVckzM",
|
| 729 |
+
"outputId": "bb65eead-5953-4a4f-f838-14fadc1469dd"
|
| 730 |
+
},
|
| 731 |
+
"outputs": [
|
| 732 |
+
{
|
| 733 |
+
"name": "stdout",
|
| 734 |
+
"output_type": "stream",
|
| 735 |
+
"text": [
|
| 736 |
+
"\u001b[90m2025-09-29T13:35:21Z\u001b[0m \u001b[32mINF\u001b[0m Thank you for trying Cloudflare Tunnel. Doing so, without a Cloudflare account, is a quick way to experiment and try it out. However, be aware that these account-less Tunnels have no uptime guarantee, are subject to the Cloudflare Online Services Terms of Use (https://www.cloudflare.com/website-terms/), and Cloudflare reserves the right to investigate your use of Tunnels for violations of such terms. If you intend to use Tunnels in production you should use a pre-created named tunnel by following: https://developers.cloudflare.com/cloudflare-one/connections/connect-apps\n",
|
| 737 |
+
"\u001b[90m2025-09-29T13:35:21Z\u001b[0m \u001b[32mINF\u001b[0m Requesting new quick Tunnel on trycloudflare.com...\n",
|
| 738 |
+
"\u001b[90m2025-09-29T13:35:25Z\u001b[0m \u001b[32mINF\u001b[0m +--------------------------------------------------------------------------------------------+\n",
|
| 739 |
+
"\u001b[90m2025-09-29T13:35:25Z\u001b[0m \u001b[32mINF\u001b[0m | Your quick Tunnel has been created! Visit it at (it may take some time to be reachable): |\n",
|
| 740 |
+
"\u001b[90m2025-09-29T13:35:25Z\u001b[0m \u001b[32mINF\u001b[0m | https://ease-library-cases-gibraltar.trycloudflare.com |\n",
|
| 741 |
+
"\u001b[90m2025-09-29T13:35:25Z\u001b[0m \u001b[32mINF\u001b[0m +--------------------------------------------------------------------------------------------+\n",
|
| 742 |
+
"\u001b[90m2025-09-29T13:35:25Z\u001b[0m \u001b[32mINF\u001b[0m Cannot determine default configuration path. No file [config.yml config.yaml] in [~/.cloudflared ~/.cloudflare-warp ~/cloudflare-warp /etc/cloudflared /usr/local/etc/cloudflared]\n",
|
| 743 |
+
"\u001b[90m2025-09-29T13:35:25Z\u001b[0m \u001b[32mINF\u001b[0m Version 2025.9.1 (Checksum 3dc1dc4252eae3c691861f926e2b8640063a2ce534b07b7a3f4ec2de439ecfe3)\n",
|
| 744 |
+
"\u001b[90m2025-09-29T13:35:25Z\u001b[0m \u001b[32mINF\u001b[0m GOOS: linux, GOVersion: go1.24.4, GoArch: amd64\n",
|
| 745 |
+
"\u001b[90m2025-09-29T13:35:25Z\u001b[0m \u001b[32mINF\u001b[0m Settings: map[ha-connections:1 no-autoupdate:true protocol:quic url:http://localhost:8501]\n",
|
| 746 |
+
"\u001b[90m2025-09-29T13:35:25Z\u001b[0m \u001b[32mINF\u001b[0m cloudflared will not automatically update if installed by a package manager.\n",
|
| 747 |
+
"\u001b[90m2025-09-29T13:35:25Z\u001b[0m \u001b[32mINF\u001b[0m Generated Connector ID: b7e0104f-71af-4b1e-a366-b3b15b2c86d9\n",
|
| 748 |
+
"\u001b[90m2025-09-29T13:35:25Z\u001b[0m \u001b[32mINF\u001b[0m Initial protocol quic\n",
|
| 749 |
+
"\u001b[90m2025-09-29T13:35:25Z\u001b[0m \u001b[32mINF\u001b[0m ICMP proxy will use 172.28.0.12 as source for IPv4\n",
|
| 750 |
+
"\u001b[90m2025-09-29T13:35:25Z\u001b[0m \u001b[32mINF\u001b[0m ICMP proxy will use :: as source for IPv6\n",
|
| 751 |
+
"\u001b[90m2025-09-29T13:35:25Z\u001b[0m \u001b[1m\u001b[31mERR\u001b[0m\u001b[0m Cannot determine default origin certificate path. No file cert.pem in [~/.cloudflared ~/.cloudflare-warp ~/cloudflare-warp /etc/cloudflared /usr/local/etc/cloudflared]. You need to specify the origin certificate path by specifying the origincert option in the configuration file, or set TUNNEL_ORIGIN_CERT environment variable \u001b[36moriginCertPath=\u001b[0m\n",
|
| 752 |
+
"\u001b[90m2025-09-29T13:35:25Z\u001b[0m \u001b[32mINF\u001b[0m ICMP proxy will use 172.28.0.12 as source for IPv4\n",
|
| 753 |
+
"\u001b[90m2025-09-29T13:35:25Z\u001b[0m \u001b[32mINF\u001b[0m ICMP proxy will use :: as source for IPv6\n",
|
| 754 |
+
"\u001b[90m2025-09-29T13:35:25Z\u001b[0m \u001b[32mINF\u001b[0m Starting metrics server on 127.0.0.1:20241/metrics\n",
|
| 755 |
+
"\u001b[90m2025-09-29T13:35:25Z\u001b[0m \u001b[32mINF\u001b[0m Tunnel connection curve preferences: [X25519MLKEM768 CurveP256] \u001b[36mconnIndex=\u001b[0m0 \u001b[36mevent=\u001b[0m0 \u001b[36mip=\u001b[0m198.41.200.113\n",
|
| 756 |
+
"2025/09/29 13:35:25 failed to sufficiently increase receive buffer size (was: 208 kiB, wanted: 7168 kiB, got: 416 kiB). See https://github.com/quic-go/quic-go/wiki/UDP-Buffer-Sizes for details.\n",
|
| 757 |
+
"\u001b[90m2025-09-29T13:35:26Z\u001b[0m \u001b[32mINF\u001b[0m Registered tunnel connection \u001b[36mconnIndex=\u001b[0m0 \u001b[36mconnection=\u001b[0mc535a197-93c0-4941-a9ab-b32533b50549 \u001b[36mevent=\u001b[0m0 \u001b[36mip=\u001b[0m198.41.200.113 \u001b[36mlocation=\u001b[0msin02 \u001b[36mprotocol=\u001b[0mquic\n",
|
| 758 |
+
"\u001b[90m2025-09-29T13:38:58Z\u001b[0m \u001b[32mINF\u001b[0m Initiating graceful shutdown due to signal interrupt ...\n",
|
| 759 |
+
"\u001b[90m2025-09-29T13:38:58Z\u001b[0m \u001b[1m\u001b[31mERR\u001b[0m\u001b[0m failed to run the datagram handler \u001b[31merror=\u001b[0m\u001b[31m\"context canceled\"\u001b[0m \u001b[36mconnIndex=\u001b[0m0 \u001b[36mevent=\u001b[0m0 \u001b[36mip=\u001b[0m198.41.200.113\n",
|
| 760 |
+
"\u001b[90m2025-09-29T13:38:58Z\u001b[0m \u001b[1m\u001b[31mERR\u001b[0m\u001b[0m failed to serve tunnel connection \u001b[31merror=\u001b[0m\u001b[31m\"accept stream listener encountered a failure while serving\"\u001b[0m \u001b[36mconnIndex=\u001b[0m0 \u001b[36mevent=\u001b[0m0 \u001b[36mip=\u001b[0m198.41.200.113\n",
|
| 761 |
+
"\u001b[90m2025-09-29T13:38:58Z\u001b[0m \u001b[1m\u001b[31mERR\u001b[0m\u001b[0m Serve tunnel error \u001b[31merror=\u001b[0m\u001b[31m\"accept stream listener encountered a failure while serving\"\u001b[0m \u001b[36mconnIndex=\u001b[0m0 \u001b[36mevent=\u001b[0m0 \u001b[36mip=\u001b[0m198.41.200.113\n",
|
| 762 |
+
"\u001b[90m2025-09-29T13:38:58Z\u001b[0m \u001b[32mINF\u001b[0m Retrying connection in up to 1s \u001b[36mconnIndex=\u001b[0m0 \u001b[36mevent=\u001b[0m0 \u001b[36mip=\u001b[0m198.41.200.113\n",
|
| 763 |
+
"\u001b[90m2025-09-29T13:38:58Z\u001b[0m \u001b[1m\u001b[31mERR\u001b[0m\u001b[0m Connection terminated \u001b[36mconnIndex=\u001b[0m0\n",
|
| 764 |
+
"\u001b[90m2025-09-29T13:38:58Z\u001b[0m \u001b[1m\u001b[31mERR\u001b[0m\u001b[0m no more connections active and exiting\n",
|
| 765 |
+
"\u001b[90m2025-09-29T13:38:58Z\u001b[0m \u001b[32mINF\u001b[0m Tunnel server stopped\n",
|
| 766 |
+
"\u001b[90m2025-09-29T13:38:58Z\u001b[0m \u001b[32mINF\u001b[0m Metrics server stopped\n"
|
| 767 |
+
]
|
| 768 |
+
}
|
| 769 |
+
],
|
| 770 |
+
"source": [
|
| 771 |
+
"import threading, os\n",
|
| 772 |
+
"\n",
|
| 773 |
+
"# Kill anything on port 8501 (just in case)\n",
|
| 774 |
+
"os.system(\"kill -9 $(lsof -t -i:8501) 2>/dev/null\")\n",
|
| 775 |
+
"\n",
|
| 776 |
+
"# Run Streamlit in background\n",
|
| 777 |
+
"def run_app():\n",
|
| 778 |
+
" os.system(\"streamlit run app.py --server.port 8501\")\n",
|
| 779 |
+
"\n",
|
| 780 |
+
"thread = threading.Thread(target=run_app)\n",
|
| 781 |
+
"thread.start()\n",
|
| 782 |
+
"\n",
|
| 783 |
+
"# Start cloudflared tunnel\n",
|
| 784 |
+
"!cloudflared tunnel --url http://localhost:8501 --no-autoupdate\n"
|
| 785 |
+
]
|
| 786 |
+
}
|
| 787 |
+
],
|
| 788 |
+
"metadata": {
|
| 789 |
+
"accelerator": "GPU",
|
| 790 |
+
"colab": {
|
| 791 |
+
"gpuType": "T4",
|
| 792 |
+
"machine_shape": "hm",
|
| 793 |
+
"provenance": []
|
| 794 |
+
},
|
| 795 |
+
"kernelspec": {
|
| 796 |
+
"display_name": "Python 3",
|
| 797 |
+
"name": "python3"
|
| 798 |
+
},
|
| 799 |
+
"language_info": {
|
| 800 |
+
"name": "python"
|
| 801 |
+
}
|
| 802 |
+
},
|
| 803 |
+
"nbformat": 4,
|
| 804 |
+
"nbformat_minor": 0
|
| 805 |
+
}
|
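Note: the final cell starts Streamlit in a background thread and opens a Cloudflare quick tunnel only because Colab exposes no public ports. On a local machine, streamlit run app.py alone serves the UI at http://localhost:8501.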
config.json
ADDED
|
@@ -0,0 +1,25 @@
|
| 1 |
+
{
|
| 2 |
+
"paths": {
|
| 3 |
+
"root": "financial_reports",
|
| 4 |
+
"faiss_index": "financial_reports/faiss_index.idx",
|
| 5 |
+
"metadata": "financial_reports/faiss_metadata.pkl",
|
| 6 |
+
"eval_dataset": "eval_dataset.json",
|
| 7 |
+
"eval_results": "eval_results.json",
|
| 8 |
+
"failed_queries": "failed_queries.json"
|
| 9 |
+
},
|
| 10 |
+
"embedding": {
|
| 11 |
+
"model": "sentence-transformers/all-MiniLM-L6-v2",
|
| 12 |
+
"chunk_words": 600,
|
| 13 |
+
"batch_size": 64,
|
| 14 |
+
"use_gpu": true
|
| 15 |
+
},
|
| 16 |
+
"summarizer": {
|
| 17 |
+
"model": "google/gemma-2b",
|
| 18 |
+
"max_content_tokens": 2048,
|
| 19 |
+
"max_output_tokens": 256
|
| 20 |
+
},
|
| 21 |
+
"logging": {
|
| 22 |
+
"level": "INFO",
|
| 23 |
+
"log_dir": "logs"
|
| 24 |
+
}
|
| 25 |
+
}
|
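A minimal sketch of how the modules above consume this file (key names match the config as committed):

import json

with open("config.json", "r") as f:
    config = json.load(f)

paths = config["paths"]          # index, metadata, and evaluation file locations
emb_cfg = config["embedding"]    # encoder model plus chunking/batch settings
print(paths["faiss_index"], emb_cfg["model"])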
evaluation.py
ADDED
|
@@ -0,0 +1,161 @@
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import time
|
| 4 |
+
import numpy as np
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
import nltk
|
| 7 |
+
|
| 8 |
+
from Logger import GetLogger, MetricsLogger
|
| 9 |
+
from Embeddings import GetEmbeddings
|
| 10 |
+
|
| 11 |
+
# Metrics
|
| 12 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
| 13 |
+
from rouge_score import rouge_scorer
|
| 14 |
+
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
|
| 15 |
+
from bert_score import score as bert_score
|
| 16 |
+
|
| 17 |
+
class Evaluator:
|
| 18 |
+
"""
|
| 19 |
+
Evaluation pipeline for the financial QA agent.
|
| 20 |
+
Uses eval_dataset.json to run queries, collect answers, and compute metrics.
|
| 21 |
+
"""
|
| 22 |
+
def __init__(self, config_path="config.json", logger=None):
|
| 23 |
+
with open(config_path, "r") as f:
|
| 24 |
+
self.config = json.load(f)
|
| 25 |
+
self.paths = self.config["paths"]
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
if not logger:
|
| 29 |
+
obj = GetLogger()
|
| 30 |
+
logger = obj.get_logger()
|
| 31 |
+
self.logger = logger
|
| 32 |
+
|
| 33 |
+
# Metrics logger
|
| 34 |
+
self.metrics_logger = MetricsLogger(logger=self.logger)
|
| 35 |
+
|
| 36 |
+
# Initialize Agent
|
| 37 |
+
self.agent = GetEmbeddings(config_path=config_path, logger=self.logger)
|
| 38 |
+
self.agent.run() # Load or rebuild FAISS + embeddings
|
| 39 |
+
self.agent.load_summarizer() # Load summarizer
|
| 40 |
+
self.encoder = self.agent.load_encoder()
|
| 41 |
+
|
| 42 |
+
# Load Dataset
|
| 43 |
+
self.dataset = self.load_dataset()
|
| 44 |
+
self.results = []
|
| 45 |
+
self.failed_queries = []
|
| 46 |
+
|
| 47 |
+
nltk.download('punkt', quiet=True)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def load_dataset(self):
|
| 51 |
+
path = self.paths["eval_dataset"]
|
| 52 |
+
if not os.path.exists(path):
|
| 53 |
+
raise FileNotFoundError(f"Dataset not found: {path}")
|
| 54 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 55 |
+
return json.load(f)
|
| 56 |
+
|
| 57 |
+
def measure_latency(self, func, *args, **kwargs):
|
| 58 |
+
"""Helper: measure time taken by a function call."""
|
| 59 |
+
start = time.time()
|
| 60 |
+
result = func(*args, **kwargs)
|
| 61 |
+
latency = time.time() - start
|
| 62 |
+
return result, latency
|
| 63 |
+
|
| 64 |
+
def evaluate_query(self, query, reference):
|
| 65 |
+
"""Run one query, compare answer vs. reference, compute metrics."""
|
| 66 |
+
try:
|
| 67 |
+
# Run pipeline
|
| 68 |
+
system_answer, latency = self.measure_latency(self.agent.answer_query, query)
|
| 69 |
+
|
| 70 |
+
# 1. Embedding similarity (proxy retrieval quality)
|
| 71 |
+
ref_emb = self.encoder.encode([reference], convert_to_numpy=True)
|
| 72 |
+
ans_emb = self.encoder.encode([system_answer], convert_to_numpy=True)
|
| 73 |
+
retrieval_quality = float(cosine_similarity(ref_emb, ans_emb)[0][0])
|
| 74 |
+
|
| 75 |
+
# 2. ROUGE-L
|
| 76 |
+
scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
|
| 77 |
+
rouge_score = scorer.score(reference, system_answer)['rougeL'].fmeasure
|
| 78 |
+
|
| 79 |
+
# 3. BLEU (with smoothing for short texts)
|
| 80 |
+
smoothie = SmoothingFunction().method4
|
| 81 |
+
bleu = sentence_bleu([reference.split()], system_answer.split(), smoothing_function=smoothie)
|
| 82 |
+
|
| 83 |
+
# 4. BERTScore (semantic similarity)
|
| 84 |
+
P, R, F1 = bert_score([system_answer], [reference], lang="en")
|
| 85 |
+
bert_f1 = float(F1.mean())
|
| 86 |
+
|
| 87 |
+
metrics = {
|
| 88 |
+
"query": query,
|
| 89 |
+
"reference": reference,
|
| 90 |
+
"system_answer": system_answer,
|
| 91 |
+
"retrieval_quality": retrieval_quality,
|
| 92 |
+
"rougeL": rouge_score,
|
| 93 |
+
"bleu": bleu,
|
| 94 |
+
"bertscore_f1": bert_f1,
|
| 95 |
+
"latency_sec": latency
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
# Log into metrics logger
|
| 99 |
+
self.metrics_logger.log_query_metrics(query, metrics)
|
| 100 |
+
|
| 101 |
+
return metrics
|
| 102 |
+
|
| 103 |
+
except Exception as e:
|
| 104 |
+
self.logger.error(f"Error evaluating query '{query}': {e}")
|
| 105 |
+
return None
|
| 106 |
+
|
| 107 |
+
def aggregate_summary(self):
|
| 108 |
+
"""Aggregate metrics across all queries for global averages."""
|
| 109 |
+
if not self.results:
|
| 110 |
+
return {}
|
| 111 |
+
|
| 112 |
+
summary = {
|
| 113 |
+
"avg_retrieval_quality": float(np.mean([r["retrieval_quality"] for r in self.results])),
|
| 114 |
+
"avg_rougeL": float(np.mean([r["rougeL"] for r in self.results])),
|
| 115 |
+
"avg_bleu": float(np.mean([r["bleu"] for r in self.results])),
|
| 116 |
+
"avg_bertscore_f1": float(np.mean([r["bertscore_f1"] for r in self.results])),
|
| 117 |
+
"avg_latency_sec": float(np.mean([r["latency_sec"] for r in self.results])),
|
| 118 |
+
"num_queries": len(self.results)
|
| 119 |
+
}
|
| 120 |
+
return summary
|
| 121 |
+
|
| 122 |
+
def run(self):
|
| 123 |
+
"""Run evaluation on entire dataset."""
|
| 124 |
+
self.logger.info("Starting Evaluation...")
|
| 125 |
+
|
| 126 |
+
for item in tqdm(self.dataset, desc="Queries"):
|
| 127 |
+
query = item["query"]
|
| 128 |
+
reference = item["reference"]
|
| 129 |
+
result = self.evaluate_query(query, reference)
|
| 130 |
+
if result:
|
| 131 |
+
self.results.append(result)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
# Save result
|
| 135 |
+
with open(self.paths["eval_results"], "w", encoding="utf-8") as f:
|
| 136 |
+
json.dump(self.results, f, indent=2)
|
| 137 |
+
|
| 138 |
+
if self.failed_queries:
|
| 139 |
+
with open(self.paths["failed_queries"], "w", encoding="utf-8") as f:
|
| 140 |
+
json.dump(self.failed_queries, f, indent=2)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
# Save metrics summary
|
| 144 |
+
summary = self.aggregate_summary() # NEW: aggregated averages
|
| 145 |
+
self.logger.info(f"📊 Evaluation summary: {summary}")
|
| 146 |
+
|
| 147 |
+
# Also save aggregated summary separately
|
| 148 |
+
with open(self.paths.get("eval_summary", "eval_summary.json"), "w", encoding="utf-8") as f:
|
| 149 |
+
json.dump(summary, f, indent=2)
|
| 150 |
+
|
| 151 |
+
return self.results, summary
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
if __name__ == "__main__":
|
| 155 |
+
evaluator = Evaluator()
|
| 156 |
+
results, summary = evaluator.run()
|
| 157 |
+
|
| 158 |
+
print("\n=== Sample Results ===")
|
| 159 |
+
print(json.dumps(results[:2], indent=2))
|
| 160 |
+
print("\n=== Summary ===")
|
| 161 |
+
print(json.dumps(summary, indent=2))
|
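evaluation.py assumes eval_dataset.json is a JSON list of objects with "query" and "reference" fields (see load_dataset and run above); the file itself is not part of this commit. A minimal sketch that writes a valid file with made-up rows:

import json

sample = [
    {"query": "What are the main risk factors in 2023?",
     "reference": "Key risks include credit concentration and currency exposure."},
    {"query": "How did revenue change year over year?",
     "reference": "Revenue grew 12%, driven by cloud services."},
]

with open("eval_dataset.json", "w", encoding="utf-8") as f:
    json.dump(sample, f, indent=2)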
requirements.txt
CHANGED
|
@@ -1,3 +1,18 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 1 |
+
pdfplumber
|
| 2 |
+
tqdm
|
| 3 |
+
transformers
|
| 4 |
+
sentence-transformers
|
| 5 |
+
numpy
|
| 6 |
+
faiss-cpu
|
| 7 |
+
python-dotenv
|
| 8 |
+
accelerate
|
| 9 |
+
protobuf
|
| 10 |
+
tiktoken
|
| 11 |
+
SentencePiece
|
| 12 |
+
bitsandbytes
|
| 13 |
+
nltk
|
| 14 |
+
rouge-score
|
| 15 |
+
bert-score
|
| 16 |
+
streamlit
|
| 17 |
+
python-dateutil
|
| 18 |
+
protobuf<4.0.0
|