DenysPetro committed
Commit • 2fa7106
Parent(s): 0619696
changed structure
Browse files
- app.py +22 -503
- assistant.py +97 -0
- citation.py +99 -0
- sherlock.pdf → data/sherlock.pdf +0 -0
- embeddings.py +58 -0
- reranker.py +32 -0
- retrievers.py +174 -0
- text_processing.py +33 -0
app.py
CHANGED
@@ -1,520 +1,37 @@
-import numpy as np
-import pandas as pd
-import faiss
-import fitz
 import nltk
-import re
-import os
-from abc import ABC, abstractmethod
-from sentence_transformers import SentenceTransformer
-from sentence_transformers import CrossEncoder
-from rank_bm25 import BM25Okapi
-from sklearn.preprocessing import normalize
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from FlagEmbedding import FlagReranker
 import gradio as gr
-from litellm import completion
-import textwrap
-

+from assistant import Assistant
+from citation import Citation
+from retrievers import KeyWordRetriever, SemanticRetriever, HybridRetriever
+from text_processing import extract_text_from_pdf, clean_text
+from embeddings import process_pdf_for_rag
+from reranker import Reranker

 nltk.download('punkt', quiet=True)
 nltk.download('punkt')
 nltk.download('punkt_tab')

-def extract_text_from_pdf(pdf_path):
-    doc = fitz.open(pdf_path)
-    text = ""
-    for page in doc:
-        text += page.get_text()
-    return text
-
-
-def clean_text(text):
-    text = re.sub(r'(\s?\.\s?)+', '', text)
-    text = re.sub(r'\s+', ' ', text)
-    text = re.sub(r'[^\x00-\x7F]+', '', text)
-    return text.strip()
-
-
-def chunk_text(text, chunk_size=1000, chunk_overlap=150):
-    text_splitter = RecursiveCharacterTextSplitter(
-        chunk_size=chunk_size,
-        chunk_overlap=chunk_overlap,
-        separators=["\n\n", "\n", " "]
-    )
-    return text_splitter.split_text(text)
-
-
-def generate_embeddings(chunks, model_name='all-MiniLM-L6-v2'):
-    model = SentenceTransformer(model_name)
-    embeddings = model.encode(chunks, convert_to_numpy=True)
-    return embeddings
-
-
-def process_pdf_for_rag(pdf_path, chunk_size=500, model_name='all-MiniLM-L6-v2'):
-    print("Extracting text from PDF...")
-    raw_text = extract_text_from_pdf(pdf_path)
-    print("Cleaning text...")
-    clean_text_content = clean_text(raw_text)
-    print("Chunking text...")
-    chunks = chunk_text(clean_text_content, chunk_size)
-    print("Generating embeddings...")
-    embeddings = generate_embeddings(chunks, model_name)
-    print("Processing complete!")
-    return chunks
-
-class Citation:
-    def __init__(self, cleaned_text):
-        """
-        Initialize with cleaned text.
-        :param cleaned_text: The cleaned text from the PDF.
-        """
-        self.cleaned_text = cleaned_text
-        self.chapters = self.extract_chapter_names()
-        self.chapter_data = self.extract_chapter_data()
-        self.df = self.to_dataframe()
-
-    def extract_chapter_names(self):
-        """
-        Extract chapter names using the provided `extract_chapter_names` function.
-        :return: List of chapter names.
-        """
-        toc_match = re.search(r"Table of contents(.*?)Table of contents", self.cleaned_text, re.DOTALL)
-        if not toc_match:
-            raise ValueError("Table of Contents section not found.")
-
-        toc_section = toc_match.group(1)
-        chapter_names = re.findall(r"(.*?)\d+", toc_section)
-        return [name.strip() for name in chapter_names if name.strip()]
-
-    def extract_chapter_data(self):
-        """
-        Extract full text for each chapter from the second occurrence of the chapter name
-        to the second occurrence of the next chapter name, or to the end if it's the last chapter.
-        :return: A list of tuples (chapter_name, plain_text).
-        Second occurence because first in text is in table of contents :)
-        """
-        chapter_data = []
-
-        for i, chapter in enumerate(self.chapters):
-            current_chapter_pattern = re.escape(chapter)
-            next_chapter_pattern = re.escape(self.chapters[i + 1]) if i + 1 < len(self.chapters) else None
-
-            if next_chapter_pattern:
-                matches = list(re.finditer(rf"{current_chapter_pattern}(.*?){next_chapter_pattern}", self.cleaned_text, re.DOTALL))
-            else:
-                matches = list(re.finditer(rf"{current_chapter_pattern}(.*)", self.cleaned_text, re.DOTALL))
-
-            if len(matches) >= 2:
-                start = matches[1].start(1)
-                end = matches[1].end(1)
-
-                if next_chapter_pattern:
-                    next_match = re.search(rf"{next_chapter_pattern}", self.cleaned_text[end:])
-                    if next_match:
-                        end += next_match.start()
-
-                chapter_content = self.cleaned_text[start:end].strip()
-            else:
-                chapter_content = ""
-
-            if i == len(self.chapters) - 1:
-                last_chapter_matches = list(re.finditer(rf"{current_chapter_pattern}", self.cleaned_text))
-
-                if len(last_chapter_matches) >= 2:
-                    start = last_chapter_matches[1].start()
-                    chapter_content = self.cleaned_text[start:].strip()
-
-            chapter_data.append((chapter, chapter_content))
-
-        return chapter_data
-
-    def to_dataframe(self):
-        """
-        Converts the chapter data into a Pandas DataFrame.
-        :return: DataFrame with 'Chapter' and 'Content' columns.
-        """
-        df = pd.DataFrame(self.chapter_data, columns=['Chapter', 'Content'])
-        return df
-
-    def search_citate(self, retrieved_chunks):
-        """
-        Finds the chapter name for each chunk in the content and returns a formatted string.
-
-        Args:
-            retrieved_chunks (list): List of text chunks to search in the chapters' content.
-
-        Returns:
-            str: A formatted string of citations, each chunk starting on a new line.
-        """
-        citations = []
-
-        for idx, chunk in enumerate(retrieved_chunks, start=1):
-            for _, row in self.df.iterrows():
-                chapter_name = row['Chapter']
-                chapter_content = row['Content']
-
-                if chunk in chapter_content:
-                    citations.append(f"chunk [{idx}] from chapter: {chapter_name}")
-                    break
-
-        return "\n".join(citations)
-
-
-# pdf_text = extract_text_from_pdf("sherlock.pdf")
-# cleaned_text = clean_text(pdf_text)
-# citation = Citation(cleaned_text)
-
-# df = citation.to_dataframe()
-# df.to_csv('chapters_text.csv', index=True)
-
-
-# # chapter_name, chapter_content = citation.chapter_data[0]
-# # print(f"Chapter: {chapter_name}\nContent:\n{chapter_content}")
-
-# # Assuming 'retrieved_chunks' is the list of chunks you want to search for
-# retrieved_chunks = ['peculiarities of the typewriter', 'Irene Adler']
-
-# # Create a Citation object
-# citation = Citation(cleaned_text)
-
-# # Get the citations for the retrieved chunks
-# citations = citation.search_citate(retrieved_chunks)
-
-# # Print the results
-# print(citations)
-
-class Retriever(ABC):
-    """
-    Abstract base class for retrievers.
-    """
-
-    @abstractmethod
-    def retrieve(self, query, top_k=5):
-        """
-        Retrieve top-k relevant chunks for a query.
-
-        Args:
-            query (str): The search query.
-            top_k (int): Number of top results to retrieve.
-
-        Returns:
-            list: List of (chunk, score) tuples sorted by relevance.
-        """
-        pass
-
-
-class KeyWordRetriever(Retriever):
-    """
-    Keyword-based retriever using BM25.
-    """
-
-    def __init__(self, chunks):
-        """
-        Initialize the BM25 retriever with pre-tokenized chunks.
-
-        Args:
-            chunks (list): List of text chunks to index.
-        """
-        self.tokenized_chunks = [nltk.word_tokenize(chunk) for chunk in chunks]
-        self.bm25 = BM25Okapi(self.tokenized_chunks)
-        self.chunks = chunks
-
-    def retrieve(self, query, top_k=5):
-        """
-        Retrieve top-k chunks based on BM25 scores.
-
-        Args:
-            query (str): The search query.
-            top_k (int): Number of top results to retrieve.
-
-        Returns:
-            list: List of (chunk, score) tuples sorted by relevance.
-        """
-        query_tokens = nltk.word_tokenize(query)
-        scores = self.bm25.get_scores(query_tokens)
-        ranked_indices = np.argsort(scores)[::-1][:top_k]
-        return [(self.chunks[i], scores[i]) for i in ranked_indices]
-
-
-class SemanticRetriever(Retriever):
-    """
-    Semantic retriever using SentenceTransformers and FAISS.
-    """
-
-    def __init__(self, chunks, model_name='all-MiniLM-L6-v2', index_path="faiss_index"):
-        """
-        Initialize the semantic retriever with SentenceTransformers and FAISS.
-
-        Args:
-            chunks (list): List of text chunks.
-            model_name (str): Model name for SentenceTransformers.
-            index_path (str): Path to save/load the FAISS index.
-        """
-        self.chunks = chunks
-        self.model = SentenceTransformer(model_name)
-        self.index_path = index_path
-        self.index = self._create_faiss_index(chunks)
-
-    def _create_faiss_index(self, chunks):
-        """
-        Create a FAISS index from text chunks.
-
-        Args:
-            chunks (list): List of text chunks.
-
-        Returns:
-            faiss.Index: Trained FAISS index.
-        """
-
-        embeddings = generate_embeddings(chunks)
-        dimension = embeddings.shape[1]
-        index = faiss.IndexFlatL2(dimension)
-        index.add(embeddings)
-
-        faiss.write_index(index, self.index_path)
-        print(f"FAISS index saved to {self.index_path}")
-        return index
-
-    def retrieve(self, query, top_k=5):
-        """
-        Retrieve top-k chunks based on semantic similarity.
-
-        Args:
-            query (str): The search query.
-            top_k (int): Number of top results to retrieve.
-
-        Returns:
-            list: List of (chunk, score) tuples sorted by relevance.
-        """
-        query_embedding = self.model.encode([query], convert_to_numpy=True).astype("float32")
-        faiss.normalize_L2(query_embedding)
-        distances, indices = self.index.search(query_embedding, top_k)
-        results = [(self.chunks[idx], 1 / (1 + distances[0][i])) for i, idx in enumerate(indices[0])]
-        return results
-
-class HybridRetriever(Retriever):
-    """
-    Hybrid retriever that combines results from keyword-based and semantic retrievers.
-    """
-
-    def __init__(self, keyword_retriever, semantic_retriever):
-        """
-        Initialize the HybridRetriever.
-
-        Args:
-            keyword_retriever (KeyWordRetriever): An instance of KeyWordRetriever.
-            semantic_retriever (SemanticRetriever): An instance of SemanticRetriever.
-        """
-        self.keyword_retriever = keyword_retriever
-        self.semantic_retriever = semantic_retriever
-
-    def normalize_scores(self, scores):
-        """
-        Normalize a list of scores to a [0, 1] range.
-
-        Args:
-            scores (list): List of scores.
-
-        Returns:
-            list: Normalized scores.
-        """
-        min_score = min(scores)
-        max_score = max(scores)
-        if max_score == min_score:
-            return [0.5] * len(scores)  # Avoid division by zero if all scores are the same
-        return [(score - min_score) / (max_score - min_score) for score in scores]
-
-    def retrieve(self, query, top_k=5):
-        """
-        Retrieve top-k chunks by combining keyword and semantic relevance.
-
-        Args:
-            query (str): The search query.
-            top_k (int): Number of top results to retrieve.
-
-        Returns:
-            list: List of (chunk, combined_score) tuples sorted by combined relevance.
-        """
-        # Retrieve results from both retrievers
-        keyword_results = self.keyword_retriever.retrieve(query, top_k=top_k)
-        semantic_results = self.semantic_retriever.retrieve(query, top_k=top_k)
-
-        # Extract chunks and scores from both retrievers
-        keyword_chunks, keyword_scores = zip(*keyword_results) if keyword_results else ([], [])
-        semantic_chunks, semantic_scores = zip(*semantic_results) if semantic_results else ([], [])
-
-        # Normalize scores for both retrievers
-        normalized_keyword_scores = self.normalize_scores(keyword_scores) if keyword_scores else []
-        normalized_semantic_scores = self.normalize_scores(semantic_scores) if semantic_scores else []
-
-        # Combine results by creating a mapping of chunk -> combined score
-        score_map = {}
-
-        # Add keyword scores to the map
-        for chunk, score in zip(keyword_chunks, normalized_keyword_scores):
-            score_map[chunk] = score_map.get(chunk, 0) + score
-
-        # Add semantic scores to the map
-        for chunk, score in zip(semantic_chunks, normalized_semantic_scores):
-            score_map[chunk] = score_map.get(chunk, 0) + score
-
-        # Sort the results by combined score
-        sorted_results = sorted(score_map.items(), key=lambda x: x[1], reverse=True)
-
-        # Return top-k results
-        return sorted_results[:top_k]
-
-class Reranker:
-    def __init__(self, model_name='BAAI/bge-reranker-large'):
-        self.model = FlagReranker(model_name, use_fp16=True)
-
-    def rerank(self, query, retrieved_results, filter_num=1):
-        """
-        Rerank the retrieved results based on query relevance.
-
-        Args:
-            query (str): Query string.
-            retrieved_results (list): List of chunks in text format.
-
-        Returns:
-            list: Reranked results as a list of chunks (sorted by relevance).
-        """
-        if not retrieved_results:
-            return []
-
-        input_pairs = [(query, chunk) for chunk in retrieved_results]
-        scores = self.model.compute_score(input_pairs)
-
-        reranked_results = sorted(
-            zip(retrieved_results, scores),
-            key=lambda x: x[1],
-            reverse=True
-        )
-
-        reranked_chunks = [chunk for chunk, _ in reranked_results]
-        reranked_chunks = reranked_chunks[:filter_num]
-        return reranked_chunks
-
-class Assistant:
-    def __init__(self, retriever, reranker=None, model="groq/llama3-8b-8192", citation=None):
-        """
-        Initialize the Assistant.
-
-        Args:
-            retriever (Retriever): An instance of a Retriever class (KeyWordRetriever or SemanticRetriever).
-            model (str): The name of the LLM model to use (default is "groq/llama3-8b-8192").
-        """
-        self.retriever = retriever
-        self.model = model
-        self.reranker = reranker
-        self.citation = citation
-
-    def simulate_llm_response(self, prompt, context, api_key):
-        """
-        Simulate an LLM response for demonstration purposes.
-
-        Args:
-            prompt (str): The prompt to send to the simulated LLM.
-            context (str): The context to include in the prompt.
-            api_key (str): The API key for Groq.
-
-        Returns:
-            str: The generated completion text.
-        """
-        os.environ['GROQ_API_KEY'] = api_key
-        instruction = """
-
-        Contextual AI Assistant
-
-        You are an AI assistant designed to provide concise, accurate, and clear responses. Always adhere to the following principles:
-
-        Core Principles:
-
-        - Truthfulness: Prioritize accuracy. If unsure, acknowledge the limitation without guessing.
-        - Contextual Understanding: Analyze the conversation history to understand the user's intent.
-        - Clarity and Conciseness: Provide brief, direct answers without unnecessary elaboration.
-        - Helpful Guidance: Offer practical suggestions when relevant, but keep it concise.
-        - Error Handling: Acknowledge limitations and suggest alternatives when unable to answer.
-        Important! Maximum length of your answer can be of 3-4 sentences.
-        """
-
-
-        response = completion(
-            model=self.model,
-            messages=[
-                {"role": "system", "content": instruction},
-                {"role": "system", "content": context},
-                {"role": "user", "content": prompt}
-            ],
-            stream=True
-        )
-
-        generated_text = ""
-        for chunk in response:
-            generated_text += str(chunk["choices"][0]['delta']['content'])
-        generated_text = generated_text[:-4]
-
-        # max_line_length = 160
-        # final_text = textwrap.fill(generated_text, width=max_line_length)
-        return generated_text
-
-    def handle_query(self, query, api_key, retriever_type="semantic", top_k=5, use_reranker=False):
-        """
-        Handle the user's query by retrieving relevant chunks and generating a simulated LLM response.
-
-        Args:
-            query (str): The user's query.
-            retriever_type (str): Type of retriever to use ("semantic" or "keyword").
-            top_k (int): Number of top results to retrieve.
-
-        Returns:
-            str: The generated response from the simulated LLM.
-        """
-        if retriever_type.lower() == "keyword":
-            retrieved_chunks = [chunk for chunk, _ in self.retriever.retrieve(query, top_k=top_k)]
-        elif retriever_type.lower() == "semantic":
-            retrieved_chunks = [chunk for chunk, _ in self.retriever.retrieve(query, top_k=top_k)]
-        elif retriever_type.lower() == "hybrid":
-            retrieved_chunks = [chunk for chunk, _ in self.retriever.retrieve(query, top_k=top_k)]
-        else:
-            raise ValueError(f"Unknown retriever type: {retriever_type}")
-
-        if use_reranker and self.reranker:
-            reranked_results = self.reranker.rerank(query, retrieved_chunks)
-            citations = self.citation.search_citate(reranked_results)
-            reranked_chunks = " ".join(reranked_results)
-            return self.simulate_llm_response(query, reranked_chunks, api_key), reranked_chunks, citations
-
-        citations = self.citation.search_citate(retrieved_chunks)
-        retrieved_chunks_string = " ".join(retrieved_chunks)
-        print(retrieved_chunks_string)
-        return self.simulate_llm_response(query, retrieved_chunks_string, api_key), retrieved_chunks_string, citations
-
-pdf_path = "sherlock.pdf"
+# Load and preprocess PDF
+data_path = "data"
+pdf_path = data_path + "/sherlock.pdf"

 chunks = process_pdf_for_rag(pdf_path, chunk_size=500)
-pdf_text = extract_text_from_pdf(
+pdf_text = extract_text_from_pdf(pdf_path)

+cleaned_text = clean_text(pdf_text)
+citation = Citation(cleaned_text)
+
+# Initialize retrievers
 keyword_retriever = KeyWordRetriever(chunks)
 semantic_retriever = SemanticRetriever(chunks)
 hybrid_retriever = HybridRetriever(keyword_retriever, semantic_retriever)

+# Initialize assistant
 reranker = Reranker()
-
-cleaned_text = clean_text(pdf_text)
-citation = Citation(cleaned_text)
-query = 'Who is Irene?'
-
-key = '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'
 assistant = Assistant(hybrid_retriever, reranker, citation=citation)
-response = assistant.handle_query(query, api_key=key, retriever_type="hybrid", top_k=5, use_reranker=False)
-
-
-print("Assistant's Response:")
-print(response)

+# Gradio UI
 def run_rag_ui(api_key, query, retriever_type, top_k, use_reranker):
     if retriever_type.lower() == "keyword":
         retriever = keyword_retriever
@@ -530,11 +47,15 @@ def run_rag_ui(api_key, query, retriever_type, top_k, use_reranker):
     cleaned_text = clean_text(pdf_text)
     citation = Citation(cleaned_text)
     assistant = Assistant(retriever, reranker, citation=citation)
-    response, retrieved_chunks, citations = assistant.handle_query(query, api_key, retriever_type=retriever_type, top_k=top_k, use_reranker=use_reranker)
+    response, retrieved_chunks, citations = assistant.handle_query(
+        query, api_key,
+        retriever_type=retriever_type,
+        top_k=top_k,
+        use_reranker=use_reranker
+    )

     return response, citations, retrieved_chunks

-
 iface = gr.Interface(
     fn=run_rag_ui,
     inputs=[
@@ -553,6 +74,4 @@ iface = gr.Interface(
     description="Enter your query, select the retrieval method, and get retrieved chunks along with LLM responses."
 )

-
 iface.launch(share=True)
-
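The refactored entry point can also be exercised without the Gradio UI. The following is a minimal sketch, assuming the new modules are importable from the Space's root, that data/sherlock.pdf is present, and that a Groq key is supplied via the GROQ_API_KEY environment variable rather than the hard-coded key this commit removes:

    import os

    from assistant import Assistant
    from citation import Citation
    from embeddings import process_pdf_for_rag
    from reranker import Reranker
    from retrievers import KeyWordRetriever, SemanticRetriever, HybridRetriever
    from text_processing import extract_text_from_pdf, clean_text

    pdf_path = "data/sherlock.pdf"
    chunks = process_pdf_for_rag(pdf_path, chunk_size=500)
    citation = Citation(clean_text(extract_text_from_pdf(pdf_path)))

    hybrid = HybridRetriever(KeyWordRetriever(chunks), SemanticRetriever(chunks))
    assistant = Assistant(hybrid, Reranker(), citation=citation)

    # handle_query returns a 3-tuple, not just the response string
    response, context_used, cites = assistant.handle_query(
        "Who is Irene Adler?",
        api_key=os.environ["GROQ_API_KEY"],  # supply via env, never hard-code keys
        retriever_type="hybrid",
        top_k=5,
        use_reranker=True,
    )
    print(response)
    print(cites)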
assistant.py
ADDED
@@ -0,0 +1,97 @@
+import os
+from litellm import completion
+
+class Assistant:
+    def __init__(self, retriever, reranker=None, model="groq/llama3-8b-8192", citation=None):
+        """
+        Initialize the Assistant.
+
+        Args:
+            retriever (Retriever): An instance of a Retriever class (KeyWordRetriever or SemanticRetriever).
+            model (str): The name of the LLM model to use (default is "groq/llama3-8b-8192").
+        """
+        self.retriever = retriever
+        self.model = model
+        self.reranker = reranker
+        self.citation = citation
+
+    def simulate_llm_response(self, prompt, context, api_key):
+        """
+        Simulate an LLM response for demonstration purposes.
+
+        Args:
+            prompt (str): The prompt to send to the simulated LLM.
+            context (str): The context to include in the prompt.
+            api_key (str): The API key for Groq.
+
+        Returns:
+            str: The generated completion text.
+        """
+        os.environ['GROQ_API_KEY'] = api_key
+        instruction = """
+
+        Contextual AI Assistant
+
+        You are an AI assistant designed to provide concise, accurate, and clear responses. Always adhere to the following principles:
+
+        Core Principles:
+
+        - Truthfulness: Prioritize accuracy. If unsure, acknowledge the limitation without guessing.
+        - Contextual Understanding: Analyze the conversation history to understand the user's intent.
+        - Clarity and Conciseness: Provide brief, direct answers without unnecessary elaboration.
+        - Helpful Guidance: Offer practical suggestions when relevant, but keep it concise.
+        - Error Handling: Acknowledge limitations and suggest alternatives when unable to answer.
+        Important! Maximum length of your answer can be of 3-4 sentences.
+        """
+
+
+        response = completion(
+            model=self.model,
+            messages=[
+                {"role": "system", "content": instruction},
+                {"role": "system", "content": context},
+                {"role": "user", "content": prompt}
+            ],
+            stream=True
+        )
+
+        generated_text = ""
+        for chunk in response:
+            generated_text += str(chunk["choices"][0]['delta']['content'])
+        generated_text = generated_text[:-4]
+
+        # max_line_length = 160
+        # final_text = textwrap.fill(generated_text, width=max_line_length)
+        return generated_text
+
+    def handle_query(self, query, api_key, retriever_type="semantic", top_k=5, use_reranker=False):
+        """
+        Handle the user's query by retrieving relevant chunks and generating a simulated LLM response.
+
+        Args:
+            query (str): The user's query.
+            retriever_type (str): Type of retriever to use ("semantic" or "keyword").
+            top_k (int): Number of top results to retrieve.
+
+        Returns:
+            str: The generated response from the simulated LLM.
+        """
+        if retriever_type.lower() == "keyword":
+            retrieved_chunks = [chunk for chunk, _ in self.retriever.retrieve(query, top_k=top_k)]
+        elif retriever_type.lower() == "semantic":
+            retrieved_chunks = [chunk for chunk, _ in self.retriever.retrieve(query, top_k=top_k)]
+        elif retriever_type.lower() == "hybrid":
+            retrieved_chunks = [chunk for chunk, _ in self.retriever.retrieve(query, top_k=top_k)]
+        else:
+            raise ValueError(f"Unknown retriever type: {retriever_type}")
+
+        if use_reranker and self.reranker:
+            reranked_results = self.reranker.rerank(query, retrieved_chunks)
+            citations = self.citation.search_citate(reranked_results)
+            reranked_chunks = " ".join(reranked_results)
+            return self.simulate_llm_response(query, reranked_chunks, api_key), reranked_chunks, citations
+
+        citations = self.citation.search_citate(retrieved_chunks)
+        retrieved_chunks_string = " ".join(retrieved_chunks)
+        print(retrieved_chunks_string)
+        return self.simulate_llm_response(query, retrieved_chunks_string, api_key), retrieved_chunks_string, citations
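A note on simulate_llm_response: the trailing generated_text[:-4] appears to compensate for the final streamed delta, whose content is typically None; str(None) appends the four characters "None". A dependency-free sketch of the same accumulation pattern (fake_stream below mimics the shape of litellm's streaming chunks and is illustrative only):

    # Final streamed delta carries content=None; str(None) adds "None".
    fake_stream = [
        {"choices": [{"delta": {"content": "Irene Adler is "}}]},
        {"choices": [{"delta": {"content": "the woman."}}]},
        {"choices": [{"delta": {"content": None}}]},  # end-of-stream delta
    ]

    generated_text = ""
    for chunk in fake_stream:
        generated_text += str(chunk["choices"][0]["delta"]["content"])

    assert generated_text.endswith("None")
    print(generated_text[:-4])  # -> Irene Adler is the woman.

    # A safer pattern skips None deltas instead of slicing off four characters:
    text = "".join(c["choices"][0]["delta"]["content"] or "" for c in fake_stream)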
citation.py
ADDED
@@ -0,0 +1,99 @@
+import re
+import pandas as pd
+
+class Citation:
+    def __init__(self, cleaned_text):
+        """
+        Initialize with cleaned text.
+        :param cleaned_text: The cleaned text from the PDF.
+        """
+        self.cleaned_text = cleaned_text
+        self.chapters = self.extract_chapter_names()
+        self.chapter_data = self.extract_chapter_data()
+        self.df = self.to_dataframe()
+
+    def extract_chapter_names(self):
+        """
+        Extract chapter names using the provided `extract_chapter_names` function.
+        :return: List of chapter names.
+        """
+        toc_match = re.search(r"Table of contents(.*?)Table of contents", self.cleaned_text, re.DOTALL)
+        if not toc_match:
+            raise ValueError("Table of Contents section not found.")
+
+        toc_section = toc_match.group(1)
+        chapter_names = re.findall(r"(.*?)\d+", toc_section)
+        return [name.strip() for name in chapter_names if name.strip()]
+
+    def extract_chapter_data(self):
+        """
+        Extract full text for each chapter from the second occurrence of the chapter name
+        to the second occurrence of the next chapter name, or to the end if it's the last chapter.
+        :return: A list of tuples (chapter_name, plain_text).
+        Second occurence because first in text is in table of contents :)
+        """
+        chapter_data = []
+
+        for i, chapter in enumerate(self.chapters):
+            current_chapter_pattern = re.escape(chapter)
+            next_chapter_pattern = re.escape(self.chapters[i + 1]) if i + 1 < len(self.chapters) else None
+
+            if next_chapter_pattern:
+                matches = list(re.finditer(rf"{current_chapter_pattern}(.*?){next_chapter_pattern}", self.cleaned_text, re.DOTALL))
+            else:
+                matches = list(re.finditer(rf"{current_chapter_pattern}(.*)", self.cleaned_text, re.DOTALL))
+
+            if len(matches) >= 2:
+                start = matches[1].start(1)
+                end = matches[1].end(1)
+
+                if next_chapter_pattern:
+                    next_match = re.search(rf"{next_chapter_pattern}", self.cleaned_text[end:])
+                    if next_match:
+                        end += next_match.start()
+
+                chapter_content = self.cleaned_text[start:end].strip()
+            else:
+                chapter_content = ""
+
+            if i == len(self.chapters) - 1:
+                last_chapter_matches = list(re.finditer(rf"{current_chapter_pattern}", self.cleaned_text))
+
+                if len(last_chapter_matches) >= 2:
+                    start = last_chapter_matches[1].start()
+                    chapter_content = self.cleaned_text[start:].strip()
+
+            chapter_data.append((chapter, chapter_content))
+
+        return chapter_data
+
+    def to_dataframe(self):
+        """
+        Converts the chapter data into a Pandas DataFrame.
+        :return: DataFrame with 'Chapter' and 'Content' columns.
+        """
+        df = pd.DataFrame(self.chapter_data, columns=['Chapter', 'Content'])
+        return df
+
+    def search_citate(self, retrieved_chunks):
+        """
+        Finds the chapter name for each chunk in the content and returns a formatted string.
+
+        Args:
+            retrieved_chunks (list): List of text chunks to search in the chapters' content.
+
+        Returns:
+            str: A formatted string of citations, each chunk starting on a new line.
+        """
+        citations = []
+
+        for idx, chunk in enumerate(retrieved_chunks, start=1):
+            for _, row in self.df.iterrows():
+                chapter_name = row['Chapter']
+                chapter_content = row['Content']
+
+                if chunk in chapter_content:
+                    citations.append(f"chunk [{idx}] from chapter: {chapter_name}")
+                    break
+
+        return "\n".join(citations)
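A self-contained sketch of how Citation can be exercised on synthetic text (toy_text below is invented for illustration): the class expects the chapter list to sit between two occurrences of the literal string "Table of contents", each entry followed by a page number, and each chapter title to reappear where its body begins.

    from citation import Citation

    toy_text = (
        "Table of contents A Scandal In Bohemia 1 The Red-Headed League 2 "
        "Table of contents "
        "A Scandal In Bohemia To Sherlock Holmes she is always the woman. "
        "The Red-Headed League I had called upon my friend one day in autumn."
    )

    citation = Citation(toy_text)
    print(citation.chapters)
    # ['A Scandal In Bohemia', 'The Red-Headed League']
    print(citation.search_citate(["she is always the woman"]))
    # chunk [1] from chapter: A Scandal In Bohemia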
sherlock.pdf → data/sherlock.pdf
RENAMED
File without changes
embeddings.py
ADDED
@@ -0,0 +1,58 @@
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from sentence_transformers import SentenceTransformer
+from text_processing import extract_text_from_pdf, clean_text
+
+def chunk_text(text, chunk_size=1000, chunk_overlap=150):
+    """
+    Split text into overlapping chunks.
+
+    Args:
+        text (str): Input text.
+        chunk_size (int): Size of each chunk.
+        chunk_overlap (int): Overlap between chunks.
+
+    Returns:
+        list: List of text chunks.
+    """
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=chunk_size,
+        chunk_overlap=chunk_overlap,
+        separators=["\n\n", "\n", " "]
+    )
+    return text_splitter.split_text(text)
+
+def generate_embeddings(chunks, model_name='all-MiniLM-L6-v2'):
+    """
+    Generate embeddings for text chunks.
+
+    Args:
+        chunks (list): List of text chunks.
+        model_name (str): SentenceTransformer model name.
+
+    Returns:
+        np.ndarray: Array of embeddings.
+    """
+    model = SentenceTransformer(model_name)
+    return model.encode(chunks, convert_to_numpy=True)
+
+def process_pdf_for_rag(pdf_path, chunk_size=500):
+    """
+    Process a PDF for RAG by extracting, cleaning, and chunking.
+
+    Args:
+        pdf_path (str): Path to the PDF file.
+        chunk_size (int): Size of each chunk.
+
+    Returns:
+        list: List of text chunks.
+    """
+
+    print("Extracting text from PDF...")
+    raw_text = extract_text_from_pdf(pdf_path)
+    print("Cleaning text...")
+    clean_text_content = clean_text(raw_text)
+    print("Chunking text...")
+    chunks = chunk_text(clean_text_content, chunk_size)
+    print("Processing complete!")
+
+    return chunks
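Note that process_pdf_for_rag now returns only chunks; embeddings are computed later by SemanticRetriever. A small sketch of the two helpers on an in-memory string, assuming langchain and sentence-transformers are installed (the model downloads on first use):

    from embeddings import chunk_text, generate_embeddings

    text = "The quick brown fox jumps over the lazy dog. " * 100
    chunks = chunk_text(text, chunk_size=200, chunk_overlap=20)
    print(len(chunks), "chunks, first chunk:", chunks[0][:60])

    embeddings = generate_embeddings(chunks)  # all-MiniLM-L6-v2 by default
    print(embeddings.shape)                   # (num_chunks, 384)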
reranker.py
ADDED
@@ -0,0 +1,32 @@
+from FlagEmbedding import FlagReranker
+
+class Reranker:
+    def __init__(self, model_name='BAAI/bge-reranker-large'):
+        self.model = FlagReranker(model_name, use_fp16=True)
+
+    def rerank(self, query, retrieved_results, filter_num=1):
+        """
+        Rerank the retrieved results based on query relevance.
+
+        Args:
+            query (str): Query string.
+            retrieved_results (list): List of chunks in text format.
+
+        Returns:
+            list: Reranked results as a list of chunks (sorted by relevance).
+        """
+        if not retrieved_results:
+            return []
+
+        input_pairs = [(query, chunk) for chunk in retrieved_results]
+        scores = self.model.compute_score(input_pairs)
+
+        reranked_results = sorted(
+            zip(retrieved_results, scores),
+            key=lambda x: x[1],
+            reverse=True
+        )
+
+        reranked_chunks = [chunk for chunk, _ in reranked_results]
+        reranked_chunks = reranked_chunks[:filter_num]
+        return reranked_chunks
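A usage sketch for the reranker, assuming FlagEmbedding is installed (BAAI/bge-reranker-large downloads on first use). Worth noting: the default filter_num=1 keeps only the single best chunk, so callers who want more must pass a larger value. The candidate strings below are invented for illustration:

    from reranker import Reranker

    reranker = Reranker()
    candidates = [
        "Irene Adler is always the woman to Sherlock Holmes.",
        "The league was founded by an American millionaire.",
        "Watson had returned to civil practice.",
    ]
    top = reranker.rerank("Who is Irene Adler?", candidates, filter_num=2)
    print(top)  # the two most query-relevant chunks, best first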
retrievers.py
ADDED
@@ -0,0 +1,174 @@
+import numpy as np
+import faiss
+from rank_bm25 import BM25Okapi
+from nltk.tokenize import word_tokenize
+from sentence_transformers import SentenceTransformer
+from embeddings import generate_embeddings
+
+class Retriever:
+    """
+    Abstract base class for retrievers.
+    """
+    def retrieve(self, query, top_k=5):
+        raise NotImplementedError
+
+class KeyWordRetriever(Retriever):
+    """
+    Keyword-based retriever using BM25.
+    """
+
+    def __init__(self, chunks):
+        """
+        Initialize the BM25 retriever with pre-tokenized chunks.
+
+        Args:
+            chunks (list): List of text chunks to index.
+        """
+        self.tokenized_chunks = [word_tokenize(chunk) for chunk in chunks]
+        self.bm25 = BM25Okapi(self.tokenized_chunks)
+        self.chunks = chunks
+
+    def retrieve(self, query, top_k=5):
+        """
+        Retrieve top-k chunks based on BM25 scores.
+
+        Args:
+            query (str): The search query.
+            top_k (int): Number of top results to retrieve.
+
+        Returns:
+            list: List of (chunk, score) tuples sorted by relevance.
+        """
+        query_tokens = word_tokenize(query)
+        scores = self.bm25.get_scores(query_tokens)
+        ranked_indices = np.argsort(scores)[::-1][:top_k]
+        return [(self.chunks[i], scores[i]) for i in ranked_indices]
+
+
+class SemanticRetriever(Retriever):
+    """
+    Semantic retriever using SentenceTransformers and FAISS.
+    """
+
+    def __init__(self, chunks, model_name='all-MiniLM-L6-v2', index_path="faiss_index"):
+        """
+        Initialize the semantic retriever with SentenceTransformers and FAISS.
+
+        Args:
+            chunks (list): List of text chunks.
+            model_name (str): Model name for SentenceTransformers.
+            index_path (str): Path to save/load the FAISS index.
+        """
+        self.chunks = chunks
+        self.model = SentenceTransformer(model_name)
+        self.index_path = index_path
+        self.index = self._create_faiss_index(chunks)
+
+    def _create_faiss_index(self, chunks):
+        """
+        Create a FAISS index from text chunks.
+
+        Args:
+            chunks (list): List of text chunks.
+
+        Returns:
+            faiss.Index: Trained FAISS index.
+        """
+
+        embeddings = generate_embeddings(chunks)
+        dimension = embeddings.shape[1]
+        index = faiss.IndexFlatL2(dimension)
+        index.add(embeddings)
+
+        faiss.write_index(index, self.index_path)
+        print(f"FAISS index saved to {self.index_path}")
+        return index
+
+    def retrieve(self, query, top_k=5):
+        """
+        Retrieve top-k chunks based on semantic similarity.
+
+        Args:
+            query (str): The search query.
+            top_k (int): Number of top results to retrieve.
+
+        Returns:
+            list: List of (chunk, score) tuples sorted by relevance.
+        """
+        query_embedding = self.model.encode([query], convert_to_numpy=True).astype("float32")
+        faiss.normalize_L2(query_embedding)
+        distances, indices = self.index.search(query_embedding, top_k)
+        results = [(self.chunks[idx], 1 / (1 + distances[0][i])) for i, idx in enumerate(indices[0])]
+        return results
+
+class HybridRetriever(Retriever):
+    """
+    Hybrid retriever that combines results from keyword-based and semantic retrievers.
+    """
+
+    def __init__(self, keyword_retriever, semantic_retriever):
+        """
+        Initialize the HybridRetriever.
+
+        Args:
+            keyword_retriever (KeyWordRetriever): An instance of KeyWordRetriever.
+            semantic_retriever (SemanticRetriever): An instance of SemanticRetriever.
+        """
+        self.keyword_retriever = keyword_retriever
+        self.semantic_retriever = semantic_retriever
+
+    def normalize_scores(self, scores):
+        """
+        Normalize a list of scores to a [0, 1] range.
+
+        Args:
+            scores (list): List of scores.
+
+        Returns:
+            list: Normalized scores.
+        """
+        min_score = min(scores)
+        max_score = max(scores)
+        if max_score == min_score:
+            return [0.5] * len(scores)  # Avoid division by zero if all scores are the same
+        return [(score - min_score) / (max_score - min_score) for score in scores]
+
+    def retrieve(self, query, top_k=5):
+        """
+        Retrieve top-k chunks by combining keyword and semantic relevance.
+
+        Args:
+            query (str): The search query.
+            top_k (int): Number of top results to retrieve.
+
+        Returns:
+            list: List of (chunk, combined_score) tuples sorted by combined relevance.
+        """
+        # Retrieve results from both retrievers
+        keyword_results = self.keyword_retriever.retrieve(query, top_k=top_k)
+        semantic_results = self.semantic_retriever.retrieve(query, top_k=top_k)
+
+        # Extract chunks and scores from both retrievers
+        keyword_chunks, keyword_scores = zip(*keyword_results) if keyword_results else ([], [])
+        semantic_chunks, semantic_scores = zip(*semantic_results) if semantic_results else ([], [])
+
+        # Normalize scores for both retrievers
+        normalized_keyword_scores = self.normalize_scores(keyword_scores) if keyword_scores else []
+        normalized_semantic_scores = self.normalize_scores(semantic_scores) if semantic_scores else []
+
+        # Combine results by creating a mapping of chunk -> combined score
+        score_map = {}
+
+        # Add keyword scores to the map
+        for chunk, score in zip(keyword_chunks, normalized_keyword_scores):
+            score_map[chunk] = score_map.get(chunk, 0) + score
+
+        # Add semantic scores to the map
+        for chunk, score in zip(semantic_chunks, normalized_semantic_scores):
+            score_map[chunk] = score_map.get(chunk, 0) + score
+
+        # Sort the results by combined score
+        sorted_results = sorted(score_map.items(), key=lambda x: x[1], reverse=True)
+
+        # Return top-k results
+        return sorted_results[:top_k]
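The hybrid fusion in HybridRetriever.retrieve amounts to min-max normalization per retriever followed by summing the normalized scores of every chunk that appears in either list. A dependency-free sketch of that arithmetic on made-up scores:

    def normalize(scores):
        lo, hi = min(scores), max(scores)
        return [0.5] * len(scores) if hi == lo else [(s - lo) / (hi - lo) for s in scores]

    keyword_results = [("chunk A", 7.2), ("chunk B", 3.1), ("chunk C", 1.0)]   # raw BM25 scores
    semantic_results = [("chunk B", 0.81), ("chunk C", 0.70), ("chunk D", 0.64)]  # similarity scores

    score_map = {}
    for results in (keyword_results, semantic_results):
        chunks, scores = zip(*results)
        for chunk, s in zip(chunks, normalize(list(scores))):
            score_map[chunk] = score_map.get(chunk, 0) + s

    print(sorted(score_map.items(), key=lambda x: x[1], reverse=True))
    # chunk B ranks first: it scores well in both retrievers

One caveat of min-max scaling: each retriever's weakest result is always forced to 0, even if it was a strong match in absolute terms.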
text_processing.py
ADDED
@@ -0,0 +1,33 @@
+import fitz
+import re
+
+def extract_text_from_pdf(pdf_path):
+    """
+    Extract text from a PDF file.
+
+    Args:
+        pdf_path (str): Path to the PDF file.
+
+    Returns:
+        str: Extracted text from the PDF.
+    """
+    doc = fitz.open(pdf_path)
+    text = ""
+    for page in doc:
+        text += page.get_text()
+    return text
+
+def clean_text(text):
+    """
+    Clean and normalize text.
+
+    Args:
+        text (str): Raw text.
+
+    Returns:
+        str: Cleaned text.
+    """
+    text = re.sub(r'(\s?\.\s?)+', '', text)
+    text = re.sub(r'\s+', ' ', text)
+    text = re.sub(r'[^\x00-\x7F]+', '', text)
+    return text.strip()
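One behavior worth flagging: the first pattern in clean_text, (\s?\.\s?)+, matches a single period as well as the dot leaders in the table of contents, so every full stop in the book is stripped. A quick dependency-free check (clean_text copied verbatim so the snippet runs without PyMuPDF installed):

    import re

    def clean_text(text):
        text = re.sub(r'(\s?\.\s?)+', '', text)
        text = re.sub(r'\s+', ' ', text)
        text = re.sub(r'[^\x00-\x7F]+', '', text)
        return text.strip()

    print(clean_text("Chapter I . . . . . 3"))   # -> Chapter I3
    print(clean_text("Mr. Holmes smiled."))      # -> MrHolmes smiled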