#####################################################
### DOCUMENT PROCESSOR [MERGER]
#####################################################
# Jonathan Wang
# ABOUT:
# This project creates an app to chat with PDFs.
# This is the MERGER,
# which defines how two lists with scores
# should be merged together into one list.
# (Useful for fusing things like keywords or text nodes.)
#####################################################
## TODOS:
# We're looping through A/B more than necessary.
#####################################################
## IMPORTS:
from __future__ import annotations

from typing import TYPE_CHECKING, Sequence, TypeVar

import numpy as np

if TYPE_CHECKING:
    from numpy.typing import NDArray

#####################################################
## CODE:
GenericType = TypeVar("GenericType")
### TODO(Jonathan Wang): Implement Maximum Marginal Relevance (MMR)
# https://en.wikipedia.org/wiki/Maximum_marginal_relevance
# def mmr(documents, query, scores, lambda_param=0.5):
#     """
#     Calculate Maximum Marginal Relevance (MMR) for a list of documents.
#
#     Parameters:
#         documents (list of np.array): List of document vectors.
#         query (np.array): Query vector.
#         scores (list of float): Relevance scores for each document.
#         lambda_param (float): Trade-off parameter between relevance and diversity.
#
#     Returns:
#         list of int: Indices of selected documents in order of selection.
#     """
#     selected = []
#     remaining = list(range(len(documents)))
#     while remaining:
#         if not selected:
#             # Select the document with the highest relevance score.
#             idx = np.argmax(scores)
#         else:
#             # Calculate MMR for the remaining documents.
#             mmr_scores = []
#             for i in remaining:
#                 relevance = scores[i]
#                 diversity = max(np.dot(documents[i], documents[j]) for j in selected)
#                 mmr_score = lambda_param * relevance - (1 - lambda_param) * diversity
#                 mmr_scores.append(mmr_score)
#             idx = remaining[np.argmax(mmr_scores)]
#         selected.append(idx)
#         remaining.remove(idx)
#     return selected
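# A quick sanity check for the sketch above (toy vectors I made up; not part of the app):
# >>> docs = [np.array([1.0, 0.0]), np.array([0.9, 0.1]), np.array([0.0, 1.0])]
# >>> mmr(docs, query=np.array([1.0, 0.0]), scores=[0.9, 0.85, 0.3], lambda_param=0.5)
# [0, 2, 1]
# It takes the top-scoring doc first, then prefers the orthogonal doc over the
# near-duplicate, which is the point of MMR.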
def _merge_on_scores(
    a_list: Sequence[GenericType],
    b_list: Sequence[GenericType],
    a_scores_input: Sequence[float | np.float64 | None],
    b_scores_input: Sequence[float | np.float64 | None],
    use_distribution: bool = True,
    a_weight: float = 0.5,
    top_k: int = 5,
) -> Sequence[GenericType]:
""" | |
Given two lists of elements with scores, fuse them together using "Distribution-Based Score Fusion". | |
Elements which have high scores in both lists are given even higher ranking here. | |
Inputs: | |
a_list: list of elements for A | |
a_scores: list of scores for each element in A. Assume higher is better. Share the same index. | |
b_list: list of elements for B | |
b_scores: list of scores for each element in B. Assume higher is better. Share the same index. | |
use_distribution: Whether to fuse using Min-Max Scaling (FALSE) or Distribution Based Score Fusion (TRUE) | |
Outputs: | |
List: List of elements that passed the merge. | |
""" | |
    # Guard Clauses
    if ((len(a_list) != len(a_scores_input)) or (len(b_list) != len(b_scores_input))):
        msg = (
            f"""_merge_on_scores: Differing number of elements and scores!
            a_list: {a_list}
            a_scores: {a_scores_input}
            b_list: {b_list}
            b_scores: {b_scores_input}
            """
        )
        raise ValueError(msg)
    if (a_weight > 1 or a_weight < 0):
        msg = "_merge_on_scores: weight for the A list should be between 0 and 1."
        raise ValueError(msg)
    if (top_k < 0):
        # TODO(Jonathan Wang): Also cap top_k at the number of unique elements.
        # That needs a nice way to count unique elements in a list where those
        # elements are potentially unhashable AND unorderable. I know about the
        # n^2 solution with two lists and (if not in x), but it's a bit annoying.
        msg = "_merge_on_scores: top_k must be between 0 and the total number of elements."
        raise ValueError(msg)
    # 0. Convert the scores to numpy arrays.
    # NOTE: When using a SubQuestionQueryEngine, the subanswers are saved as
    # NodesWithScores, but their score is None. Casting to np.float64 turns
    # those None scores into NaN, and the nan-aware statistics below ignore
    # them, so we still get citations when the two texts are very similar.
    a_scores: NDArray[np.float64] = np.array(a_scores_input, dtype=np.float64)
    b_scores: NDArray[np.float64] = np.array(b_scores_input, dtype=np.float64)
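    # For instance (illustrative values, not from the app):
    # np.array([0.5, None], dtype=np.float64) -> array([0.5, nan])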
    # 1. Calculate the mean of each score distribution.
    a_mean = np.nanmean(a_scores)  # np.nan if empty
    b_mean = np.nanmean(b_scores)
    # 2. Calculate the standard deviations.
    a_stdev = np.nanstd(a_scores)
    b_stdev = np.nanstd(b_scores)
    # 3. Take the bands 3 standard deviations from the mean as the min and max;
    # alternatively, use the actual min and max (Min-Max scaling).
    a_min = a_mean - 3 * a_stdev if use_distribution else np.nanmin(a_scores)
    a_max = a_mean + 3 * a_stdev if use_distribution else np.nanmax(a_scores)
    b_min = b_mean - 3 * b_stdev if use_distribution else np.nanmin(b_scores)
    b_max = b_mean + 3 * b_stdev if use_distribution else np.nanmax(b_scores)
    # 4. Rescale each distribution onto [0, 1] using its band.
    # (Vectorized; no need to loop element by element.)
    if (a_max > a_min):
        a_scores = (a_scores - a_min) / (a_max - a_min)
    if (b_max > b_min):
        b_scores = (b_scores - b_min) / (b_max - b_min)
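    # Worked example for the DBSF branch (made-up numbers): scores [0.2, 0.5, 0.8]
    # have mean 0.5 and std ~0.245, giving a band of roughly [-0.235, 1.235];
    # rescaling maps them to about [0.296, 0.500, 0.704].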
    # 5. Fuse the scores together.
    # NOTE: (element in b_list) and b_list.index() make this O(n^2); see TODOS.
    fused_pairs: list[tuple[GenericType, float]] = []
    for index, element in enumerate(a_list):
        a_score = a_scores[index]
        if (element in b_list):
            # In both A and B: fuse the two scores.
            b_score = b_scores[b_list.index(element)]
            fused_score = a_weight * a_score + (1 - a_weight) * b_score
            fused_pairs.append((element, fused_score))
        else:
            # Only in A.
            fused_pairs.append((element, a_weight * a_score))
    for index, element in enumerate(b_list):
        if (element not in a_list):
            # Only in B.
            b_score = b_scores[index]
            fused_pairs.append((element, (1 - a_weight) * b_score))

    # 6. Sort by fused score, best first, and return at most top_k elements.
    fused_pairs = sorted(fused_pairs, key=lambda item: item[1], reverse=True)
    output_list = [item[0] for item in fused_pairs]
    return output_list[:top_k]
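
# A minimal smoke test (toy lists and scores I made up; run this file directly):
if __name__ == "__main__":
    vector_hits = ["llama", "index", "pdf"]
    vector_scores = [0.9, 0.7, 0.2]
    keyword_hits = ["pdf", "chat", "llama"]
    keyword_scores = [3.0, 2.5, 1.0]  # deliberately on a different scale

    merged = _merge_on_scores(
        vector_hits,
        keyword_hits,
        vector_scores,
        keyword_scores,
        use_distribution=True,
        a_weight=0.5,
        top_k=3,
    )
    # DBSF puts both score sets on a comparable scale, so "llama" and "pdf"
    # (present in both lists) outrank the single-list hits.
    print(merged)  # -> ['llama', 'pdf', 'chat']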