|
import logging |
|
from typing import Any |
|
|
|
import numpy as np |
|
|
|
from rag_utils.rag_utils import OpenAI, do_1_embed |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
def do_sort( |
|
embed_q: np.ndarray, embed_talks: np.ndarray, list_talk_ids: list[str] |
|
) -> list[dict[str, str | float]]: |
|
""" |
|
Sort documents based on their cosine similarity to the query embedding. |
|
|
|
Args: |
|
embed_dict (dict[str, np.ndarray]): Dictionary containing document embeddings. |
|
arr_q (np.ndarray): Query embedding. |
|
|
|
Returns: |
|
pd.DataFrame: Sorted dataframe containing document IDs and similarity scores. |
|
""" |
|
|
|
|
|
cos_sims = np.dot(embed_talks, embed_q) |
|
|
|
|
|
best_match_video_ids = np.argsort(-cos_sims) |
|
|
|
|
|
sorted_vids = [ |
|
{"id0": list_talk_ids[i], "score": -cs} |
|
for i, cs in zip(best_match_video_ids, np.sort(-cos_sims)) |
|
] |
|
|
|
return sorted_vids |
|
|
|
|
|
def limit_docs( |
|
sorted_vids: list[dict[str, str | float]], talk_info: dict[str, str | int], n_results: int |
|
) -> list[dict[str, Any]]: |
|
""" |
|
Limit the retrieved documents based on a score threshold and return the top documents. |
|
|
|
Args: |
|
df_sorted (pd.DataFrame): Sorted dataframe containing document IDs and similarity scores. |
|
df_talks (pd.DataFrame): Dataframe containing talk information. |
|
n_results (int): Number of top documents to retrieve. |
|
transcript_dicts (dict[str, dict]): Dictionary containing transcript text for each document ID. |
|
|
|
Returns: |
|
dict[str, dict]: Dictionary containing the top documents with their IDs, scores, and text. |
|
""" |
|
|
|
|
|
top_vids = sorted_vids[:n_results] |
|
|
|
|
|
top_score = top_vids[0]["score"] |
|
score_thresh = max(min(0.6, top_score - 0.05), 0.2) |
|
|
|
|
|
keep_texts = [] |
|
for my_vid in top_vids: |
|
if my_vid["score"] >= score_thresh: |
|
vid_data = talk_info[my_vid["id0"]] |
|
vid_data = {**vid_data, **my_vid} |
|
keep_texts.append(vid_data) |
|
|
|
logger.info(f"{len(keep_texts)} videos kept") |
|
|
|
return keep_texts |
|
|
|
|
|
def do_retrieval( |
|
query0: str, |
|
n_results: int, |
|
api_client: OpenAI, |
|
talk_ids: list[str], |
|
embeds: np.ndarray, |
|
talk_info: dict[str, str | int], |
|
) -> list[dict[str, Any]]: |
|
""" |
|
Retrieve relevant documents based on the user's query. |
|
|
|
Args: |
|
query0 (str): The user's query. |
|
n_results (int): The number of documents to retrieve. |
|
api_client (OpenAI): The API client (OpenAI) for generating embeddings. |
|
|
|
Returns: |
|
dict[str, dict]: The retrieved documents. |
|
""" |
|
logger.info(f"Starting document retrieval for query: {query0}") |
|
try: |
|
|
|
arr_q = do_1_embed(query0, api_client) |
|
|
|
|
|
sorted_vids = do_sort(embed_q=arr_q, embed_talks=embeds, list_talk_ids=talk_ids) |
|
|
|
|
|
keep_texts = limit_docs(sorted_vids=sorted_vids, talk_info=talk_info, n_results=n_results) |
|
|
|
logger.info(f"Retrieved {len(keep_texts)} documents for query: {query0}") |
|
|
|
return keep_texts |
|
except Exception as e: |
|
logger.error(f"Error during document retrieval for query: {query0}, Error: {str(e)}") |
|
raise e |
|
|