Spaces:
Runtime error
Runtime error
from abc import ABC, abstractmethod | |
from dataclasses import dataclass | |
import pandas as pd | |
from openai.embeddings_utils import cosine_similarity | |
# Source-name constant. It is not referenced anywhere in the visible code;
# presumably it is the value callers pass as ``source`` to select every
# source -- TODO confirm against the get_documents implementations.
ALL_SOURCES = "All"
class Retriever(ABC):
    """Base class for retrievers that rank documents by embedding similarity.

    Subclasses provide the documents through ``get_documents``; ``retrieve``
    scores them against a query embedding and returns the best matches.

    NOTE(review): ``get_documents`` and ``get_source_display_name`` have
    placeholder bodies but are not marked ``@abstractmethod``, so a subclass
    that forgets to override them is only caught when ``retrieve`` runs.
    """

    def get_documents(self, source: str) -> pd.DataFrame:
        """Get all current documents from a given source."""
        ...

    def get_source_display_name(self, source: str) -> str:
        """Get the display name of a source."""
        ...

    def retrieve(self, query_embedding: list[float], top_k: int, source: str = None) -> pd.DataFrame:
        """Return the documents of ``source`` most similar to ``query_embedding``.

        Args:
            query_embedding: Embedding vector of the query; compared against
                each row's ``embedding`` value with ``cosine_similarity``.
            top_k: Number of rows to keep. ``-1`` keeps every row; any other
                value is handed to ``DataFrame.head`` unchanged.
            source: Forwarded as-is to ``get_documents`` (``None`` by default).

        Returns:
            A new DataFrame holding the documents plus a ``similarity``
            column, sorted by ``similarity`` in descending order and limited
            to ``top_k`` rows. The frame returned by ``get_documents`` is not
            modified.
        """
        documents = self.get_documents(source)
        similarity = documents.embedding.apply(lambda x: cosine_similarity(x, query_embedding))

        # assign() builds a new frame, so the provider's DataFrame (which may
        # be cached or shared between calls) never gains a "similarity" column.
        scored_documents = documents.assign(similarity=similarity)

        # Best matches first.
        matched_documents = scored_documents.sort_values("similarity", ascending=False)

        # -1 is the "no limit" sentinel.
        if top_k == -1:
            return matched_documents
        return matched_documents.head(top_k)