Spaces:
Sleeping
Sleeping
| import json | |
| import re | |
| import nltk | |
| from nltk.tokenize import sent_tokenize | |
| import torch | |
| from sentence_transformers import SentenceTransformer, util | |
| import faiss | |
| import numpy as np | |
| from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer | |
| from rank_bm25 import BM25Okapi # BM25 for hybrid search | |
| import logging | |
# Fetch the Punkt sentence-tokenizer data that sent_tokenize relies on.
# Recent NLTK releases load the 'punkt_tab' resource instead of the pickled
# 'punkt' models, so both are requested; quiet=True silences the downloader.
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)
# Emit timestamped, level-tagged log lines at INFO and above.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
class Hogragger:
    """Hybrid retrieval + extractive question answering over a JSON article corpus.

    The pipeline implemented by :meth:`process_query` is:

    1. predict a question type with a sequence classifier,
    2. retrieve candidate passages with a FAISS L2 index combined with BM25,
    3. re-rank the candidates by cosine similarity and de-duplicate by title,
    4. run an extractive QA model over the top passages,
    5. post-process the answer and attach supporting evidence.
    """

    # Classifier output index -> question-type label.
    _QUESTION_TYPE_LABELS = {
        0: 'inference_query',
        1: 'comparison_query',
        2: 'null_query',
        3: 'temporal_query',
        4: 'fact_query',
    }
    # Answers whose QA score is below this value are reported as uncertain.
    _MIN_CONFIDENCE = 0.2

    def __init__(self, corpus_path, model_name='sentence-transformers/all-MiniLM-L12-v2',
                 qa_model='deepset/roberta-large-squad2',
                 classifier_model='deepset/roberta-large-squad2'):
        """Load the corpus and build every model/index the pipeline needs.

        Args:
            corpus_path: Path to a JSON file containing a list of article dicts.
            model_name: SentenceTransformer checkpoint used for embeddings.
            qa_model: Checkpoint passed to the 'question-answering' pipeline.
            classifier_model: Checkpoint loaded as a sequence classifier for
                question-type prediction.
        """
        self.corpus = self.load_corpus(corpus_path)
        self.cleaned_passages = self.preprocess_corpus()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        logging.info(f"Using device: {self.device}")

        # Dense retrieval: embedding model + FAISS index over all passages.
        self.model = SentenceTransformer(model_name).to(self.device)
        self.index = self.build_faiss_index()

        # Lexical retrieval: BM25 over whitespace tokens.
        self.bm25 = self.build_bm25_index()

        # Question-type classifier.
        # NOTE(review): the default checkpoint is an extractive-QA model; its
        # classification head is presumably not trained for the five labels
        # in _QUESTION_TYPE_LABELS -- confirm or supply a fine-tuned classifier.
        self.tokenizer = AutoTokenizer.from_pretrained(classifier_model)
        self.classifier = AutoModelForSequenceClassification.from_pretrained(classifier_model).to(self.device)

        # Extractive QA pipeline (device 0 = first GPU, -1 = CPU).
        self.qa_model = pipeline('question-answering', model=qa_model,
                                 device=0 if self.device == 'cuda' else -1)

    def load_corpus(self, path):
        """Read the JSON corpus at ``path`` and return the parsed object."""
        logging.info(f"Loading corpus from {path}")
        # Explicit encoding so the result does not depend on the platform locale.
        with open(path, "r", encoding="utf-8") as f:
            corpus = json.load(f)
        logging.info(f"Loaded {len(corpus)} documents")
        return corpus

    def preprocess_corpus(self):
        """Turn every article in ``self.corpus`` into one cleaned passage.

        HTML tags are stripped and runs of whitespace collapsed. Each article
        body becomes a single passage; an earlier sentence-aligned chunking
        scheme (chunks of at most 300 words) is intentionally not applied.

        Returns:
            A list of passage dicts as produced by :meth:`create_passage`.
        """
        cleaned_passages = []
        for article in self.corpus:
            body = article.get('body') or ''  # tolerate a missing or null body
            clean_body = re.sub(r'<.*?>', '', body)  # drop HTML tags
            clean_body = re.sub(r'\s+', ' ', clean_body).strip()  # collapse whitespace
            cleaned_passages.append(self.create_passage(article, clean_body))
        logging.info(f"Created {len(cleaned_passages)} passages")
        return cleaned_passages

    def create_passage(self, article, chunk):
        """Create a passage dictionary from an article and a chunk of its text.

        ``author`` falls back to ``'Unknown'``; every other metadata field is
        required and raises ``KeyError`` when absent.
        """
        return {
            "title": article['title'],
            "author": article.get('author', 'Unknown'),
            "published_at": article['published_at'],
            "category": article['category'],
            "url": article['url'],
            "source": article['source'],
            "passage": chunk.strip(),
        }

    def build_faiss_index(self):
        """Embed all passages and return a flat L2 FAISS index over them.

        A GPU index is attempted when CUDA is available; any failure there
        falls back to the CPU index.
        """
        logging.info("Building FAISS index...")
        texts = [p['passage'] for p in self.cleaned_passages]
        embeddings = self.model.encode(texts, convert_to_tensor=True, device=self.device)
        embeddings = np.array(embeddings.cpu()).astype('float32')
        logging.info(f"Shape of embeddings: {embeddings.shape}")

        index = faiss.IndexFlatL2(embeddings.shape[1])
        if self.device == 'cuda':
            try:
                res = faiss.StandardGpuResources()
                gpu_index = faiss.index_cpu_to_gpu(res, 0, index)
                gpu_index.add(embeddings)
                logging.info("Successfully created GPU index")
                return gpu_index
            except (RuntimeError, AttributeError) as e:
                # AttributeError: a CPU-only faiss build exposes no GPU API.
                logging.error(f"GPU index creation failed: {e}")
                logging.info("Falling back to CPU index")

        index.add(embeddings)
        logging.info("Successfully created CPU index")
        return index

    def build_bm25_index(self):
        """Return a BM25Okapi index over whitespace-tokenised passages."""
        logging.info("Building BM25 index...")
        tokenized_corpus = [p['passage'].split() for p in self.cleaned_passages]
        bm25 = BM25Okapi(tokenized_corpus)
        logging.info("Successfully built BM25 index")
        return bm25

    def predict_question_type(self, query):
        """Return the question-type label predicted for ``query``.

        Returns ``'unknown_query'`` when the predicted class index has no
        entry in the label table.
        """
        # truncation=True keeps over-long queries within the model's limit.
        inputs = self.tokenizer(query, return_tensors='pt', truncation=True).to(self.device)
        with torch.no_grad():  # inference only; no autograd graph needed
            outputs = self.classifier(**inputs)
        prediction = torch.argmax(outputs.logits, dim=1).item()
        return self._QUESTION_TYPE_LABELS.get(prediction, 'unknown_query')

    def retrieve_passages(self, query, k=100, threshold=0.7):
        """Retrieve passages whose hybrid FAISS/BM25 score exceeds ``threshold``.

        Args:
            query: The user question.
            k: Maximum number of nearest neighbours to request from FAISS;
                clamped to the number of indexed passages.
            threshold: Minimum combined score (in [0, 1]) for a passage to be kept.

        Returns:
            A list of passage dicts, in FAISS rank order. Any error is logged
            and results in an empty list.
        """
        try:
            # Never ask for more neighbours than exist: FAISS pads missing
            # results with label -1, which would silently alias the last passage.
            k = min(k, len(self.cleaned_passages))
            if k <= 0:
                return []

            query_embedding = self.model.encode([query], convert_to_tensor=True, device=self.device)
            query_vectors = np.asarray(query_embedding.cpu(), dtype='float32')
            D, I = self.index.search(query_vectors, k)

            distances = np.asarray(D[0])
            indices = np.asarray(I[0])
            valid = indices >= 0  # defensive: drop any padded results
            distances, indices = distances[valid], indices[valid]
            if indices.size == 0:
                return []

            bm25_scores = self.bm25.get_scores(query.split())
            hybrid_scores = self.combine_faiss_bm25_scores(distances, bm25_scores, indices)

            passages = [
                self.cleaned_passages[i]
                for i, score in zip(indices, hybrid_scores)
                if score > threshold
            ]
            logging.info(f"Retrieved {len(passages)} passages using hybrid search for query.")
            return passages
        except Exception as e:
            logging.exception(f"Error in retrieving passages: {e}")
            return []

    @staticmethod
    def _min_max_normalize(values):
        """Scale ``values`` into [0, 1); a constant array maps to all zeros."""
        low = np.min(values)
        return (values - low) / (np.max(values) - low + 1e-6)

    def combine_faiss_bm25_scores(self, faiss_scores, bm25_scores, passage_indices):
        """Blend FAISS distances and BM25 scores into one score per candidate.

        Args:
            faiss_scores: L2 distances of the candidates (smaller = closer).
            bm25_scores: BM25 scores for *every* passage in the corpus.
            passage_indices: Corpus indices of the candidates, aligned with
                ``faiss_scores``; any shape, flattened internally.

        Returns:
            A 1-D array (one entry per candidate, even for a single candidate)
            of ``0.7 * dense + 0.3 * lexical`` after min-max normalisation.
            Because of that normalisation a lone candidate always scores 0.
        """
        indices = np.asarray(passage_indices).reshape(-1)
        distances = np.asarray(faiss_scores, dtype=np.float64).reshape(-1)
        if indices.size == 0:
            return np.zeros(0)

        lexical = np.asarray(bm25_scores, dtype=np.float64)[indices]
        # Invert distances into similarities; the epsilon avoids division by zero.
        dense = 1 / (distances + 1e-6)

        lexical = self._min_max_normalize(lexical)
        dense = self._min_max_normalize(dense)

        return 0.7 * dense + 0.3 * lexical

    def filter_passages(self, query, passages):
        """Keep up to the 10 most query-similar passages, one per title.

        Returns an empty list when ``passages`` is empty or on any error.
        """
        if not passages:
            return []
        try:
            query_embedding = self.model.encode(query, convert_to_tensor=True)
            passage_embeddings = self.model.encode([p['passage'] for p in passages], convert_to_tensor=True)
            similarities = util.pytorch_cos_sim(query_embedding, passage_embeddings)
            top_k = min(10, len(passages))
            top_indices = similarities.topk(k=top_k)[1].tolist()[0]

            selected_passages = []
            used_titles = set()
            for i in top_indices:
                title = passages[i]['title']
                if title not in used_titles:
                    selected_passages.append(passages[i])
                    used_titles.add(title)
            return selected_passages
        except Exception as e:
            logging.exception(f"Error in filtering passages: {e}")
            return []

    def _answer_with_score(self, query, passages):
        """Run the QA model over the first five passages.

        Returns:
            ``(answer, score)``; ``("Insufficient information.", None)`` when
            there is no context or the model call fails.
        """
        if not passages:
            return "Insufficient information.", None
        try:
            context = " ".join(p['passage'] for p in passages[:5])
            result = self.qa_model(question=query, context=context)
            logging.info(f"Generated answer: {result['answer']}")
            return result['answer'], result.get('score')
        except Exception as e:
            logging.exception(f"Error in generating answer: {e}")
            return "Insufficient information.", None

    def generate_answer(self, query, passages):
        """Return the QA model's answer string for ``query`` given ``passages``.

        Falls back to ``"Insufficient information."`` on empty input or error.
        """
        return self._answer_with_score(query, passages)[0]

    def post_process_answer(self, answer, confidence=0.2):
        """Tidy a raw answer string, or reject it when confidence is too low.

        Any text up to and including the last ``?`` on the first line is
        removed, the first character is upper-cased (the rest is left
        untouched so proper nouns and acronyms survive), and answers longer
        than 100 characters are cut to their first sentence.

        Args:
            answer: Raw answer text.
            confidence: Model score for the answer; ``None`` skips the check.

        Returns:
            The cleaned answer, or ``"I'm unsure about this answer."`` when
            ``confidence`` is below 0.2.
        """
        if confidence is not None and confidence < self._MIN_CONFIDENCE:
            logging.warning(f"Answer confidence too low: {confidence}")
            return "I'm unsure about this answer."

        answer = re.sub(r'^.*\?', '', answer).strip()
        # str.capitalize() would lower-case everything after the first letter.
        answer = answer[:1].upper() + answer[1:]
        if len(answer) > 100:
            truncated = re.match(r'^(.*?[.!?])\s', answer)
            if truncated:
                answer = truncated.group(1)
        return answer

    def process_query(self, query):
        """Answer ``query`` end to end.

        Returns:
            A dict with keys ``query``, ``answer``, ``question_type`` and
            ``evidence_list`` (up to four evidence dicts, each carrying the
            passage metadata plus the most relevant sentence as ``fact``).
        """
        question_type = self.predict_question_type(query)
        retrieved_passages = self.retrieve_passages(query, k=100, threshold=0.7)
        if not retrieved_passages:
            return {
                "query": query,
                "answer": "No relevant information found",
                "question_type": question_type,
                "evidence_list": [],
            }

        filtered_passages = self.filter_passages(query, retrieved_passages)
        raw_answer, score = self._answer_with_score(query, filtered_passages)

        evidence_list = [
            {
                "title": p['title'],
                "author": p['author'],
                "url": p['url'],
                "source": p['source'],
                "category": p['category'],
                "published_at": p['published_at'],
                "fact": self.extract_fact(p['passage'], query),
            }
            for p in filtered_passages[:4]
        ]

        # Forward the model's score so the low-confidence guard can actually fire.
        if score is None:
            final_answer = self.post_process_answer(raw_answer)
        else:
            final_answer = self.post_process_answer(raw_answer, confidence=score)

        return {
            "query": query,
            "answer": final_answer,
            "question_type": question_type,
            "evidence_list": evidence_list,
        }

    def extract_fact(self, passage, query):
        """Return the sentence of ``passage`` sharing the most words with ``query``.

        Words are compared case-insensitively after whitespace splitting; ties
        go to the earliest sentence. Returns ``""`` for a passage with no
        sentences.
        """
        sentences = sent_tokenize(passage)
        if not sentences:
            return ""
        query_keywords = set(query.lower().split())
        return max(sentences, key=lambda s: len(set(s.lower().split()) & query_keywords))