# Spaces:
# Sleeping
# Sleeping
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
from annoy import AnnoyIndex
from flashrank import Ranker, RerankRequest
from sentence_transformers import SentenceTransformer
@dataclass
class CourseSearchResult:
    """One course returned by a search, with its retrieval scores.

    Instances are created with keyword arguments, so the class must be a
    dataclass to get a generated ``__init__``. It is deliberately left
    mutable: reranking assigns ``rerank_score`` after construction.
    """

    course_code: str
    course_title: str
    description: str
    credits: float
    category: str
    department: str
    track: str
    prerequisite: str
    remark: str
    # Score from the initial retrieval stage (semantic, keyword or hybrid).
    score: float
    # Stays None until a reranking pass fills it in.
    rerank_score: Optional[float] = None
class DDICourseSearch:
    """Course search over a catalogue DataFrame.

    Offers three retrieval strategies -- semantic (query embedding looked up
    in a pre-built Annoy index), keyword (query-term overlap) and a weighted
    hybrid of the two -- followed by optional FlashRank reranking.

    Annoy item ids are used as positional row numbers of ``courses_df`` (see
    ``_create_result``), so the index file must have been built from the same
    DataFrame in the same row order.
    """

    # Columns concatenated into the text used for keyword matching/reranking.
    _SEARCH_TEXT_COLUMNS = (
        'Course Code',
        'Course Title',
        'Course Description',
        'Track',
        'Category',
    )

    def __init__(self, courses_df, ann_file: str = 'ddi_courses_index.ann',
                 bi_encoder_name: str = 'all-MiniLM-L6-v2',
                 flash_rank_name: str = "claudecc/flash-rank-reranker",
                 embedding_dim: int = 384):
        """Load the query encoder, the Annoy index and the reranker.

        Args:
            courses_df: Course catalogue; must provide the columns read by
                ``_create_search_text`` and ``_create_result_from_row``.
            ann_file: Path of the saved Annoy index to load.
            bi_encoder_name: SentenceTransformer model used to embed queries.
            flash_rank_name: Accepted for backward compatibility but not
                used; the reranker is created without a model name.
            embedding_dim: Vector size passed to ``AnnoyIndex``.
        """
        self.courses_df = courses_df
        self.embedding_dim = embedding_dim
        self.bi_encoder = SentenceTransformer(bi_encoder_name)
        self.index = AnnoyIndex(embedding_dim, 'angular')
        # NOTE(review): this flag is never set to True in this class, not
        # even after the index is loaded -- confirm what readers expect.
        self.index_built = False
        self.index.load(ann_file)
        # Initialize FlashRank reranker
        self.reranker = Ranker(max_length=128)

    def _create_search_text(self, row: pd.Series) -> str:
        """Join the searchable fields of a course into one string.

        Missing values are skipped *before* string conversion. A field whose
        text is exactly ``'nan'`` (a missing value that was stringified
        earlier, e.g. on a ``CourseSearchResult``) is skipped as well. Words
        that merely contain "nan" (such as "Finance") are left intact.
        """
        parts = []
        for column in self._SEARCH_TEXT_COLUMNS:
            value = row[column]
            if pd.isna(value):
                continue
            text = str(value).strip()
            if text and text != 'nan':
                parts.append(text)
        return ' '.join(parts)

    def search(self,
               query: str,
               k: int = 5,
               search_type: str = 'hybrid',
               alpha: float = 0.7,
               rerank: bool = True,
               rerank_cutoff: int = 100) -> List["CourseSearchResult"]:
        """Search for courses using the specified method.

        Args:
            query: Search query.
            k: Number of final results to return.
            search_type: ``'semantic'`` or ``'keyword'``; any other value
                (including the default ``'hybrid'``) selects hybrid search.
            alpha: Weight of the semantic score in hybrid mode; the keyword
                score is weighted by ``1 - alpha``.
            rerank: Whether to apply FlashRank reranking.
            rerank_cutoff: Number of initial results to rerank.

        Returns:
            At most ``k`` CourseSearchResult objects, best first.
        """
        # Never fetch fewer candidates than the caller asked to get back,
        # otherwise a small rerank_cutoff would silently shrink the output.
        candidate_count = max(k, rerank_cutoff) if rerank else k

        if search_type == 'semantic':
            results = self._semantic_search(query, candidate_count)
        elif search_type == 'keyword':
            results = self._keyword_search(query, candidate_count)
        else:
            results = self._hybrid_search(query, candidate_count, alpha)

        if rerank:
            results = self._rerank_results(query, results, k)
        return results[:k]

    def _rerank_results(self, query: str, results: List["CourseSearchResult"],
                        k: int) -> List["CourseSearchResult"]:
        """Rerank results using FlashRank.

        Sets ``rerank_score`` on every result and sorts ``results`` in place
        by that score, highest first. ``k`` is kept for interface
        compatibility; truncation is done by the caller.
        """
        if not results:
            return results

        passages = []
        for position, result in enumerate(results):
            text = self._create_search_text(pd.Series({
                'Course Code': result.course_code,
                'Course Title': result.course_title,
                'Course Description': result.description,
                'Track': result.track,
                'Category': result.category
            }))
            # The id lets each returned score be matched to its result.
            passages.append({'id': position, 'text': text})

        rerankrequest = RerankRequest(query=query, passages=passages)
        # The reranker hands passages back ordered by score, not in input
        # order, so scores must be assigned via the id rather than by zip.
        for item in self.reranker.rerank(rerankrequest):
            results[item['id']].rerank_score = float(item['score'])

        results.sort(key=lambda x: x.rerank_score, reverse=True)
        return results

    def _semantic_search(self, query: str, k: int) -> List["CourseSearchResult"]:
        """Return the ``k`` nearest courses to the query embedding."""
        query_embedding = self.bi_encoder.encode(query)
        indices, distances = self.index.get_nns_by_vector(
            query_embedding, k, include_distances=True
        )
        # Annoy's angular distance is sqrt(2 * (1 - cos)); invert it to get
        # the cosine similarity.
        similarities = [1 - (distance ** 2) / 2 for distance in distances]
        return [self._create_result(idx, sim) for idx, sim in zip(indices, similarities)]

    def _keyword_scores(self, query: str) -> List[Tuple[int, float]]:
        """Score every course by the fraction of query terms it contains.

        Returns:
            ``(row position, score)`` pairs in DataFrame row order. Positions
            are positional (0-based), independent of the DataFrame's index
            labels, so they are safe to use with ``iloc``.
        """
        query_terms = set(query.lower().split())
        scores = []
        for position, (_, row) in enumerate(self.courses_df.iterrows()):
            text_terms = set(self._create_search_text(row).lower().split())
            matches = len(query_terms & text_terms)
            score = matches / len(query_terms) if query_terms else 0.0
            scores.append((position, score))
        return scores

    def _keyword_search(self, query: str, k: int) -> List["CourseSearchResult"]:
        """Return the ``k`` courses sharing the most terms with the query."""
        scores = self._keyword_scores(query)
        # Stable sort: equally scored courses keep their catalogue order.
        scores.sort(key=lambda x: x[1], reverse=True)
        return [self._create_result(idx, score) for idx, score in scores[:k]]

    def _hybrid_search(self, query: str, k: int, alpha: float) -> List["CourseSearchResult"]:
        """Combine semantic and keyword search results.

        Each course's score is ``alpha * semantic + (1 - alpha) * keyword``,
        where a course missing from one of the two result lists contributes
        zero for that part. Courses are keyed by course code.
        """
        combined = {}
        for result in self._semantic_search(query, k):
            result.score = alpha * result.score
            combined[result.course_code] = result
        for result in self._keyword_search(query, k):
            weighted = (1 - alpha) * result.score
            existing = combined.get(result.course_code)
            if existing is not None:
                existing.score += weighted
            else:
                result.score = weighted
                combined[result.course_code] = result
        # Reuse the result objects already built instead of scanning the
        # DataFrame by course code again for every hit.
        ranked = sorted(combined.values(), key=lambda x: x.score, reverse=True)
        return ranked[:k]

    def _create_result(self, idx: int, score: float) -> "CourseSearchResult":
        """Create a CourseSearchResult from a positional row index and score."""
        row = self.courses_df.iloc[idx]
        return self._create_result_from_row(row, score)

    def _create_result_from_row(self, row: pd.Series, score: float) -> "CourseSearchResult":
        """Create a CourseSearchResult from a catalogue row and score."""
        return CourseSearchResult(
            course_code=row['Course Code'],
            course_title=row['Course Title'],
            description=str(row['Course Description']),
            credits=row['Credits'],
            category=str(row['Category']),
            department=str(row['Department']),
            track=str(row['Track']),
            prerequisite=str(row['Prerequisite']),
            remark=str(row['Remark']),
            score=score
        )