AlanFeder's picture
Upload folder using huggingface_hub
1c4216d verified
import logging
from typing import Any
import numpy as np
from rag_utils.rag_utils import OpenAI, do_1_embed
logger = logging.getLogger(__name__)
def do_sort(
embed_q: np.ndarray, embed_talks: np.ndarray, list_talk_ids: list[str]
) -> list[dict[str, str | float]]:
"""
Sort documents based on their cosine similarity to the query embedding.
Args:
embed_dict (dict[str, np.ndarray]): Dictionary containing document embeddings.
arr_q (np.ndarray): Query embedding.
Returns:
pd.DataFrame: Sorted dataframe containing document IDs and similarity scores.
"""
# Calculate cosine similarities between query embedding and document embeddings
cos_sims = np.dot(embed_talks, embed_q)
# Get the indices of the best matching video IDs
best_match_video_ids = np.argsort(-cos_sims)
# Get the sorted video IDs based on the best match indices
sorted_vids = [
{"id0": list_talk_ids[i], "score": -cs}
for i, cs in zip(best_match_video_ids, np.sort(-cos_sims))
]
return sorted_vids
def limit_docs(
sorted_vids: list[dict[str, str | float]], talk_info: dict[str, str | int], n_results: int
) -> list[dict[str, Any]]:
"""
Limit the retrieved documents based on a score threshold and return the top documents.
Args:
df_sorted (pd.DataFrame): Sorted dataframe containing document IDs and similarity scores.
df_talks (pd.DataFrame): Dataframe containing talk information.
n_results (int): Number of top documents to retrieve.
transcript_dicts (dict[str, dict]): Dictionary containing transcript text for each document ID.
Returns:
dict[str, dict]: Dictionary containing the top documents with their IDs, scores, and text.
"""
# Get the top n_results documents
top_vids = sorted_vids[:n_results]
# Get the top score and calculate the score threshold
top_score = top_vids[0]["score"]
score_thresh = max(min(0.6, top_score - 0.05), 0.2)
# Filter the top documents based on the score threshold
keep_texts = []
for my_vid in top_vids:
if my_vid["score"] >= score_thresh:
vid_data = talk_info[my_vid["id0"]]
vid_data = {**vid_data, **my_vid}
keep_texts.append(vid_data)
logger.info(f"{len(keep_texts)} videos kept")
return keep_texts
def do_retrieval(
    query0: str,
    n_results: int,
    api_client: OpenAI,
    talk_ids: list[str],
    embeds: np.ndarray,
    talk_info: dict[str, dict[str, Any]],
) -> list[dict[str, Any]]:
    """
    Retrieve the most relevant talk documents for a user query.

    Embeds the query, ranks all talks by similarity, then applies the
    score-threshold filter from ``limit_docs``.

    Args:
        query0 (str): The user's query.
        n_results (int): Maximum number of documents to retrieve.
        api_client (OpenAI): API client used to embed the query.
        talk_ids (list[str]): Talk ID for each row of ``embeds``.
        embeds (np.ndarray): Talk embeddings, shape (n_talks, dim).
        talk_info (dict[str, dict[str, Any]]): Per-talk metadata keyed by ID.

    Returns:
        list[dict[str, Any]]: The retrieved documents, best match first.

    Raises:
        Exception: Re-raises any failure from embedding or ranking after
            logging it with a traceback.
    """
    logger.info(f"Starting document retrieval for query: {query0}")
    try:
        # Generate embeddings for the query
        arr_q = do_1_embed(query0, api_client)
        # Sort documents based on their cosine similarity to the query embedding
        sorted_vids = do_sort(embed_q=arr_q, embed_talks=embeds, list_talk_ids=talk_ids)
        # Limit the retrieved documents based on a score threshold
        keep_texts = limit_docs(sorted_vids=sorted_vids, talk_info=talk_info, n_results=n_results)
        logger.info(f"Retrieved {len(keep_texts)} documents for query: {query0}")
        return keep_texts
    except Exception:
        # logger.exception records the full traceback; bare `raise`
        # re-raises the original exception without resetting its context.
        logger.exception(f"Error during document retrieval for query: {query0}")
        raise