# Page banner captured along with the source ("Spaces: Sleeping / Sleeping"); not part of the module.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import numpy as np
import pandas as pd
from annoy import AnnoyIndex
from flashrank import Ranker, RerankRequest
from sentence_transformers import SentenceTransformer
@dataclass
class CourseSearchResult:
    """One course returned by a search, together with its relevance scores.

    The descriptive fields are copied from a course row by
    ``DDICourseSearch._create_result_from_row``.  ``score`` is the value
    produced by the initial retrieval step; ``rerank_score`` stays ``None``
    until a reranking pass assigns it.

    The class must be a dataclass: results are constructed with keyword
    arguments, which a plain class with bare annotations would reject.
    """

    course_code: str
    course_title: str
    description: str
    credits: float
    category: str
    department: str
    track: str
    prerequisite: str
    remark: str
    # Score from the initial (semantic / keyword / hybrid) retrieval.
    score: float
    # Score assigned by the reranker; None when no reranking was applied.
    rerank_score: Optional[float] = None
class DDICourseSearch:
    """Semantic, keyword and hybrid search over a course DataFrame.

    Semantic search encodes the query with a SentenceTransformer bi-encoder
    and queries a pre-built Annoy index loaded from disk.  Keyword search
    scores each course by the fraction of query terms found in its text.
    Results can optionally be reranked with FlashRank.

    Annoy item ids are used as positional row numbers of ``courses_df``
    (see ``_create_result``), so the index file is assumed to have been
    built in the same row order as the DataFrame -- verify against the
    code that builds the ``.ann`` file.
    """

    # Columns concatenated to form the text used for keyword matching and
    # for reranking.
    _SEARCH_TEXT_FIELDS = (
        'Course Code',
        'Course Title',
        'Course Description',
        'Track',
        'Category',
    )

    def __init__(self, courses_df: pd.DataFrame, ann_file: str = 'ddi_courses_index.ann',
                 bi_encoder_name: str = 'all-MiniLM-L6-v2',
                 flash_rank_name: str = "claudecc/flash-rank-reranker",
                 embedding_dim: int = 384):
        """Load the encoder, the Annoy index and the reranker.

        Args:
            courses_df: Course table; rows are addressed positionally.
            ann_file: Path of the pre-built Annoy index to load.
            bi_encoder_name: SentenceTransformer model used to embed queries.
            flash_rank_name: Accepted for backward compatibility; it is not
                used -- the reranker is created with FlashRank's default model.
            embedding_dim: Dimensionality of the vectors stored in the index.
        """
        self.courses_df = courses_df
        self.embedding_dim = embedding_dim
        self.bi_encoder = SentenceTransformer(bi_encoder_name)
        self.index = AnnoyIndex(embedding_dim, 'angular')
        # NOTE(review): this flag is never set to True in this class, even
        # after a successful load; left unchanged in case callers read it.
        self.index_built = False
        self.index.load(ann_file)
        # Initialize FlashRank reranker
        self.reranker = Ranker(max_length=128)

    def _create_search_text(self, row: pd.Series) -> str:
        """Join the searchable fields of a course into one string.

        Missing values are skipped *before* string conversion.  The literal
        string ``'nan'`` is skipped as well, because result objects store
        missing values as ``str(nan)``.  Only whole-field matches are
        dropped, so words that merely contain "nan" (e.g. "Finance") are
        left intact.
        """
        parts = []
        for field in self._SEARCH_TEXT_FIELDS:
            value = row[field]
            if pd.api.types.is_scalar(value) and pd.isna(value):
                continue
            text = str(value).strip()
            if not text or text == 'nan':
                continue
            parts.append(text)
        return ' '.join(parts)

    def search(self,
               query: str,
               k: int = 5,
               search_type: str = 'hybrid',
               alpha: float = 0.7,
               rerank: bool = True,
               rerank_cutoff: int = 100) -> List["CourseSearchResult"]:
        """
        Search for courses using the specified method.

        Args:
            query: Search query
            k: Number of final results to return
            search_type: One of 'semantic', 'keyword', or 'hybrid'; any
                other value falls back to hybrid search
            alpha: Weight for semantic search in hybrid mode
            rerank: Whether to apply FlashRank reranking
            rerank_cutoff: Number of initial results to rerank

        Returns:
            List of CourseSearchResult objects (at most k)
        """
        if k <= 0:
            return []

        # Never fetch fewer than k candidates, even if rerank_cutoff < k.
        candidates = max(rerank_cutoff, k) if rerank else k

        # Get initial results
        if search_type == 'semantic':
            results = self._semantic_search(query, candidates)
        elif search_type == 'keyword':
            results = self._keyword_search(query, candidates)
        else:
            results = self._hybrid_search(query, candidates, alpha)

        # Apply FlashRank reranking if requested
        if rerank:
            results = self._rerank_results(query, results, k)

        return results[:k]

    def _rerank_results(self, query: str, results: List["CourseSearchResult"],
                        k: int) -> List["CourseSearchResult"]:
        """Rerank results in place using FlashRank and return them.

        FlashRank returns the passages ordered by score, not in input order,
        so every passage carries its position as ``id`` and scores are
        mapped back through that id.  ``k`` is unused here; the caller
        truncates the list.
        """
        if not results:
            return results

        # Prepare texts for reranking
        passages = []
        for position, result in enumerate(results):
            text = self._create_search_text(pd.Series({
                'Course Code': result.course_code,
                'Course Title': result.course_title,
                'Course Description': result.description,
                'Track': result.track,
                'Category': result.category
            }))
            passages.append({'id': position, 'text': text})

        # Get reranking scores
        rerankrequest = RerankRequest(query=query, passages=passages)
        for item in self.reranker.rerank(rerankrequest):
            results[item['id']].rerank_score = float(item['score'])

        # Sort by rerank score; anything left unscored goes last.
        results.sort(
            key=lambda r: float('-inf') if r.rerank_score is None else r.rerank_score,
            reverse=True,
        )
        return results

    def _semantic_search(self, query: str, k: int) -> List["CourseSearchResult"]:
        """Return the k nearest courses to the query in embedding space."""
        query_embedding = self.bi_encoder.encode(query)
        indices, distances = self.index.get_nns_by_vector(
            query_embedding, k, include_distances=True
        )
        # Annoy's angular distance is sqrt(2 * (1 - cos)), hence
        # cos = 1 - d^2 / 2.
        similarities = [1 - (distance ** 2) / 2 for distance in distances]
        return [self._create_result(idx, sim) for idx, sim in zip(indices, similarities)]

    def _keyword_search(self, query: str, k: int) -> List["CourseSearchResult"]:
        """Score every course by the fraction of query terms it contains."""
        query_terms = set(query.lower().split())
        scores = []
        # Enumerate positions explicitly: _create_result indexes with iloc,
        # and the DataFrame's index labels need not be 0..n-1.
        for position, (_, row) in enumerate(self.courses_df.iterrows()):
            text_terms = set(self._create_search_text(row).lower().split())
            matches = len(query_terms & text_terms)
            score = matches / len(query_terms) if query_terms else 0
            scores.append((position, score))

        # Sort and get top k (stable sort: ties keep DataFrame order)
        scores.sort(key=lambda x: x[1], reverse=True)
        return [self._create_result(position, score) for position, score in scores[:k]]

    def _hybrid_search(self, query: str, k: int, alpha: float) -> List["CourseSearchResult"]:
        """Combine semantic and keyword scores, keyed by course code.

        The combined score is ``alpha * semantic + (1 - alpha) * keyword``;
        a course found by only one method gets just that method's share.
        """
        semantic_results = self._semantic_search(query, k)
        keyword_results = self._keyword_search(query, k)

        # Combine scores
        combined_scores = {}
        for result in semantic_results:
            combined_scores[result.course_code] = alpha * result.score
        for result in keyword_results:
            combined_scores[result.course_code] = (
                combined_scores.get(result.course_code, 0) + (1 - alpha) * result.score
            )

        ranked = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)[:k]
        results = []
        for code, score in ranked:
            # First row carrying this course code.
            row = self.courses_df[self.courses_df['Course Code'] == code].iloc[0]
            results.append(self._create_result_from_row(row, score))
        return results

    def _create_result(self, idx: int, score: float) -> "CourseSearchResult":
        """Create a CourseSearchResult from a positional row index and score."""
        row = self.courses_df.iloc[idx]
        return self._create_result_from_row(row, score)

    def _create_result_from_row(self, row: pd.Series, score: float) -> "CourseSearchResult":
        """Create a CourseSearchResult from a course row and score.

        Most text columns are passed through ``str()``, so a missing value
        ends up as the string ``'nan'`` in the result.
        """
        return CourseSearchResult(
            course_code=row['Course Code'],
            course_title=row['Course Title'],
            description=str(row['Course Description']),
            credits=row['Credits'],
            category=str(row['Category']),
            department=str(row['Department']),
            track=str(row['Track']),
            prerequisite=str(row['Prerequisite']),
            remark=str(row['Remark']),
            score=score
        )