parthmodi22 commited on
Commit
42b54f3
1 Parent(s): 1b8b75e
Files changed (1) hide show
  1. hogragger.py +251 -0
hogragger.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+ import nltk
4
+ from nltk.tokenize import sent_tokenize
5
+ import torch
6
+ from sentence_transformers import SentenceTransformer, util
7
+ import faiss
8
+ import numpy as np
9
+ from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
10
+ from rank_bm25 import BM25Okapi # BM25 for hybrid search
11
+ import logging
12
+
13
+
14
+ nltk.download('punkt', quiet=True)
15
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
16
+
17
+
18
+ class Hogragger:
19
+ def __init__(self, corpus_path, model_name='sentence-transformers/all-MiniLM-L12-v2', qa_model='deepset/roberta-large-squad2', classifier_model='deepset/roberta-large-squad2'):
20
+ self.corpus = self.load_corpus(corpus_path)
21
+ self.cleaned_passages = self.preprocess_corpus()
22
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
23
+ logging.info(f"Using device: {self.device}")
24
+
25
+ # Initialize embedding model and build FAISS index
26
+ self.model = SentenceTransformer(model_name).to(self.device)
27
+ self.index = self.build_faiss_index()
28
+
29
+ # Initialize BM25 for lexical matching
30
+ self.bm25 = self.build_bm25_index()
31
+
32
+ # Initialize classifier for question type prediction
33
+ self.tokenizer = AutoTokenizer.from_pretrained(classifier_model)
34
+ self.classifier = AutoModelForSequenceClassification.from_pretrained(classifier_model).to(self.device)
35
+
36
+ # QA Model
37
+ self.qa_model = pipeline('question-answering', model=qa_model, device=0 if self.device == 'cuda' else -1)
38
+
39
+ def load_corpus(self, path):
40
+ logging.info(f"Loading corpus from {path}")
41
+ with open(path, "r") as f:
42
+ corpus = json.load(f)
43
+ logging.info(f"Loaded {len(corpus)} documents")
44
+ return corpus
45
+
46
+ # def preprocess_corpus(self):
47
+ # cleaned_passages = []
48
+ # for article in self.corpus:
49
+ # body = article.get('body', '')
50
+ # clean_body = re.sub(r'<.*?>', '', body) # Clean HTML tags
51
+ # clean_body = re.sub(r'\s+', ' ', clean_body).strip() # Clean extra spaces
52
+ # sentences = sent_tokenize(clean_body)
53
+
54
+ # chunk = ""
55
+ # for sentence in sentences:
56
+ # if len(chunk.split()) + len(sentence.split()) <= 300:
57
+ # chunk += " " + sentence
58
+ # else:
59
+ # cleaned_passages.append(self.create_passage(article, chunk))
60
+ # chunk = sentence
61
+
62
+ # if chunk:
63
+ # cleaned_passages.append(self.create_passage(article, chunk))
64
+ # logging.info(f"Created {len(cleaned_passages)} passages")
65
+ # return cleaned_passages
66
+ def preprocess_corpus(self):
67
+ cleaned_passages = []
68
+ for article in self.corpus:
69
+ body = article.get('body', '')
70
+ clean_body = re.sub(r'<.*?>', '', body) # Clean HTML tags
71
+ clean_body = re.sub(r'\s+', ' ', clean_body).strip() # Clean extra spaces
72
+
73
+ # Simply take the full cleaned text as a passage without chunking or sentence splitting
74
+ cleaned_passages.append(self.create_passage(article, clean_body))
75
+
76
+ logging.info(f"Created {len(cleaned_passages)} passages")
77
+ return cleaned_passages
78
+
79
+ def create_passage(self, article, chunk):
80
+ """Creates a passage dictionary from an article and chunk of text."""
81
+ return {
82
+ "title": article['title'],
83
+ "author": article.get('author', 'Unknown'),
84
+ "published_at": article['published_at'],
85
+ "category": article['category'],
86
+ "url": article['url'],
87
+ "source": article['source'],
88
+ "passage": chunk.strip()
89
+ }
90
+
91
+ def build_faiss_index(self):
92
+ logging.info("Building FAISS index...")
93
+ embeddings = self.model.encode([p['passage'] for p in self.cleaned_passages], convert_to_tensor=True, device=self.device)
94
+ embeddings = np.array(embeddings.cpu()).astype('float32')
95
+ logging.info(f"Shape of embeddings: {embeddings.shape}")
96
+
97
+ index = faiss.IndexFlatL2(embeddings.shape[1]) # Initialize FAISS index
98
+
99
+ if self.device == 'cuda':
100
+ try:
101
+ res = faiss.StandardGpuResources()
102
+ gpu_index = faiss.index_cpu_to_gpu(res, 0, index)
103
+ gpu_index.add(embeddings)
104
+ logging.info("Successfully created GPU index")
105
+ return gpu_index
106
+ except RuntimeError as e:
107
+ logging.error(f"GPU index creation failed: {e}")
108
+ logging.info("Falling back to CPU index")
109
+
110
+ index.add(embeddings) # Add embeddings to CPU index
111
+ logging.info("Successfully created CPU index")
112
+ return index
113
+
114
+ def build_bm25_index(self):
115
+ logging.info("Building BM25 index...")
116
+ tokenized_corpus = [p['passage'].split() for p in self.cleaned_passages]
117
+ bm25 = BM25Okapi(tokenized_corpus)
118
+ logging.info("Successfully built BM25 index")
119
+ return bm25
120
+
121
+ def predict_question_type(self, query):
122
+ inputs = self.tokenizer(query, return_tensors='pt').to(self.device)
123
+ outputs = self.classifier(**inputs)
124
+ prediction = torch.argmax(outputs.logits, dim=1).item()
125
+
126
+ labels = {0: 'inference_query', 1: 'comparison_query', 2: 'null_query', 3: 'temporal_query', 4: 'fact_query'}
127
+ return labels.get(prediction, 'unknown_query')
128
+
129
+ def retrieve_passages(self, query, k=100, threshold=0.7):
130
+ try:
131
+ # FAISS retrieval
132
+ query_embedding = self.model.encode([query], convert_to_tensor=True, device=self.device)
133
+ D, I = self.index.search(np.array(query_embedding.cpu()), k)
134
+
135
+ # BM25 retrieval
136
+ tokenized_query = query.split()
137
+ bm25_scores = self.bm25.get_scores(tokenized_query)
138
+
139
+ # Combine FAISS and BM25 results
140
+ hybrid_scores = self.combine_faiss_bm25_scores(D[0], bm25_scores, I)
141
+
142
+ # Filter passages based on hybrid score
143
+ passages = [self.cleaned_passages[i] for i, score in zip(I[0], hybrid_scores) if score > threshold]
144
+
145
+ logging.info(f"Retrieved {len(passages)} passages using hybrid search for query.")
146
+ return passages
147
+ except Exception as e:
148
+ logging.error(f"Error in retrieving passages: {e}")
149
+ return []
150
+
151
+ def combine_faiss_bm25_scores(self, faiss_scores, bm25_scores, passage_indices):
152
+ # Normalize and combine FAISS and BM25 scores
153
+ bm25_scores = np.array(bm25_scores)[passage_indices]
154
+ faiss_scores = np.array(faiss_scores)
155
+
156
+ # Convert FAISS distances into similarities by inverting the scale
157
+ faiss_similarities = 1 / (faiss_scores + 1e-6) # Avoid division by zero
158
+
159
+ # Normalize scores (scale between 0 and 1)
160
+ bm25_scores = (bm25_scores - np.min(bm25_scores)) / (np.max(bm25_scores) - np.min(bm25_scores) + 1e-6)
161
+ faiss_similarities = (faiss_similarities - np.min(faiss_similarities)) / (np.max(faiss_similarities) - np.min(faiss_similarities) + 1e-6)
162
+
163
+ # Weighted combination (you can adjust weights)
164
+ combined_scores = 0.7 * faiss_similarities + 0.3 * bm25_scores
165
+ combined_scores = np.squeeze(combined_scores) # Ensure it's a single-dimensional array
166
+
167
+ return combined_scores
168
+
169
+ def filter_passages(self, query, passages):
170
+ try:
171
+ query_embedding = self.model.encode(query, convert_to_tensor=True)
172
+ passage_embeddings = self.model.encode([p['passage'] for p in passages], convert_to_tensor=True)
173
+
174
+ similarities = util.pytorch_cos_sim(query_embedding, passage_embeddings)
175
+ top_k = min(10, len(passages))
176
+ top_indices = similarities.topk(k=top_k)[1].tolist()[0]
177
+
178
+ selected_passages = []
179
+ used_titles = set()
180
+ for i in top_indices:
181
+ if passages[i]['title'] not in used_titles:
182
+ selected_passages.append(passages[i])
183
+ used_titles.add(passages[i]['title'])
184
+
185
+ return selected_passages
186
+ except Exception as e:
187
+ logging.error(f"Error in filtering passages: {e}")
188
+ return []
189
+
190
+ def generate_answer(self, query, passages):
191
+ try:
192
+ context = " ".join([p['passage'] for p in passages[:5]])
193
+ answer = self.qa_model(question=query, context=context)
194
+ logging.info(f"Generated answer: {answer['answer']}")
195
+ return answer['answer']
196
+ except Exception as e:
197
+ logging.error(f"Error in generating answer: {e}")
198
+ return "Insufficient information."
199
+
200
+ def post_process_answer(self, answer, confidence=0.2):
201
+ answer = re.sub(r'^.*\?', '', answer).strip()
202
+ answer = answer.capitalize()
203
+
204
+ if len(answer) > 100:
205
+ truncated = re.match(r'^(.*?[.!?])\s', answer)
206
+ if truncated:
207
+ answer = truncated.group(1)
208
+
209
+ if confidence < 0.2:
210
+ logging.warning(f"Answer confidence too low: {confidence}")
211
+ return "I'm unsure about this answer."
212
+
213
+ return answer
214
+
215
+ def process_query(self, query):
216
+ question_type = self.predict_question_type(query)
217
+ retrieved_passages = self.retrieve_passages(query, k=100, threshold=0.7)
218
+ if not retrieved_passages:
219
+ return {"query": query, "answer": "No relevant information found", "question_type": question_type, "evidence_list": []}
220
+
221
+ filtered_passages = self.filter_passages(query, retrieved_passages)
222
+ raw_answer = self.generate_answer(query, filtered_passages)
223
+
224
+ evidence_count = min(len(filtered_passages), 4)
225
+ evidence_list = [
226
+ {
227
+ "title": p['title'],
228
+ "author": p['author'],
229
+ "url": p['url'],
230
+ "source": p['source'],
231
+ "category": p['category'],
232
+ "published_at": p['published_at'],
233
+ "fact": self.extract_fact(p['passage'], query)
234
+ } for p in filtered_passages[:evidence_count]
235
+ ]
236
+ final_answer = self.post_process_answer(raw_answer)
237
+
238
+ return {
239
+ "query": query,
240
+ "answer": final_answer,
241
+ "question_type": question_type,
242
+ "evidence_list": evidence_list
243
+ }
244
+
245
+ def extract_fact(self, passage, query):
246
+ # Extracting most relevant sentence from passage
247
+ sentences = sent_tokenize(passage)
248
+ query_keywords = set(query.lower().split())
249
+
250
+ best_sentence = max(sentences, key=lambda s: len(set(s.lower().split()) & query_keywords), default="")
251
+ return best_sentence if best_sentence else (sentences[0] if sentences else "")