Spaces:
Running
Running
from sentence_transformers import CrossEncoder | |
import math | |
import numpy as np | |
from middlewares.search_client import SearchClient | |
import os | |
from dotenv import load_dotenv | |
# Populate os.environ from a .env file (python-dotenv) before the lookups below.
load_dotenv()

# Credentials for the two web-search backends; each is None when unset.
GOOGLE_SEARCH_ENGINE_ID = os.environ.get("GOOGLE_SEARCH_ENGINE_ID")
GOOGLE_SEARCH_API_KEY = os.environ.get("GOOGLE_SEARCH_API_KEY")
BING_SEARCH_API_KEY = os.environ.get("BING_SEARCH_API_KEY")

# Cross-encoder used by `rerank` to score (query, chunk) pairs.
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

# One shared search client per vendor; the Bing client is given no engine id.
googleSearchClient = SearchClient(
    "google",
    api_key=GOOGLE_SEARCH_API_KEY,
    engine_id=GOOGLE_SEARCH_ENGINE_ID,
)
bingSearchClient = SearchClient(
    "bing",
    api_key=BING_SEARCH_API_KEY,
    engine_id=None,
)
def rerank(query, top_k, search_results, chunk_size):
    """Chunk search results and rank the chunks against a query.

    Each result's text is split on whitespace into consecutive chunks of at
    most ``chunk_size`` words. Every chunk is scored against ``query`` with the
    module-level cross-encoder ``reranker``. Chunks are returned from the
    highest to the lowest score.

    Args:
        query (str): the query for reranking
        top_k (int): the number of top reranked chunks to return
        search_results (list[dict]): a list of dictionaries containing "link"
            and "text" keys for each search result
        chunk_size (int): the maximum number of words per chunk; must be > 0

    Returns:
        list[dict]: up to ``top_k`` dictionaries with "link" and "text" keys,
        one per chunk, ordered by decreasing similarity score. The list is
        empty when no search result contains any words.

    Raises:
        ValueError: if ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        # 0 would divide by zero; a negative size would silently yield no chunks.
        raise ValueError(f"chunk_size must be positive, got {chunk_size!r}")

    chunks = []  # (link, chunk_text) pairs in document order
    for result in search_results:
        words = result["text"].split()
        for start in range(0, len(words), chunk_size):
            chunks.append((result["link"], " ".join(words[start:start + chunk_size])))

    # Nothing to score. Skip the model call, since some sentence-transformers
    # versions raise on an empty batch.
    if not chunks:
        return []

    # Compute similarity scores for every (query, chunk) pair.
    similarity_scores = reranker.predict([[query, text] for _, text in chunks])

    # argsort is ascending; reverse for best-first, then keep only the top_k
    # indices so dicts are built only for chunks that are returned.
    ranked_indices = np.argsort(similarity_scores)[::-1][:top_k]
    return [{"link": chunks[idx][0], "text": chunks[idx][1]} for idx in ranked_indices]
def gen_augmented_prompt_via_websearch(
    prompt,
    vendor,
    n_crawl,
    top_k,
    pre_context="",
    post_context="",
    pre_prompt="",
    post_prompt="",
    pass_prev=False,
    prev_output="",
    chunk_size=512,
):
    """
    Generate an augmented prompt and a list of links via web search and reranking.

    The prompt is sent to the selected search vendor. The returned results are
    chunked and reranked with ``rerank``, and the top chunks become the context
    that is spliced into the prompt.

    Args:
        prompt (str, required): The prompt for the web search.
        vendor (str): The search engine to use, either 'Google' or 'Bing'
            (exact, case-sensitive match). Any other value skips the search,
            so the context is empty and no links are returned.
        n_crawl (int): The number of search results to retrieve.
        top_k (int): The number of top reranked chunks to include as context.
        pre_context (str): Text placed before the retrieved context.
        post_context (str): Text placed after the retrieved context.
        pre_prompt (str, optional): Text placed before the prompt. Defaults to "".
        post_prompt (str, optional): Text placed after the prompt. Defaults to "".
        pass_prev (bool, optional): Whether to append ``prev_output``. Defaults to False.
        prev_output (str, optional): The previous output to append when
            ``pass_prev`` is true. Defaults to "".
        chunk_size (int, optional): Words per chunk for reranking. Defaults to 512.

    Returns:
        tuple: ``(augmented_prompt, links)``.

        ``augmented_prompt`` joins pre_context, context, post_context,
        pre_prompt, prompt, post_prompt and prev_output (or "" when
        ``pass_prev`` is false) with newlines, plus a leading and a trailing
        newline. The context is each retrieved chunk followed by a blank line.

        ``links`` lists each distinct source link once, ordered by its
        best-ranked chunk.
    """
    search_results = []
    if vendor == "Google":
        search_results = googleSearchClient.search(prompt, n_crawl)
    elif vendor == "Bing":
        search_results = bingSearchClient.search(prompt, n_crawl)

    # A truthiness test (rather than len()) also tolerates a client that
    # returns None. Whether the clients ever do is not visible here.
    reranked_results = (
        rerank(prompt, top_k, search_results, chunk_size) if search_results else []
    )

    # Each chunk is followed by a blank line.
    context = "".join(res["text"] + "\n\n" for res in reranked_results)

    # dict.fromkeys removes duplicates while keeping first-seen (best-ranked)
    # order. A set would yield links in arbitrary, run-dependent order.
    links = list(dict.fromkeys(res["link"] for res in reranked_results))

    if not pass_prev:
        prev_output = ""

    sections = (
        pre_context,
        context,
        post_context,
        pre_prompt,
        prompt,
        post_prompt,
        prev_output,
    )
    # f-string formatting of each section keeps the original template's
    # conversion semantics for non-str arguments.
    augmented_prompt = "\n" + "\n".join(f"{section}" for section in sections) + "\n"
    return augmented_prompt, links