pragneshbarik's picture
moved search_client, chat_client, utils to middleware
b615916
raw
history blame
4.49 kB
from sentence_transformers import CrossEncoder
import math
import numpy as np
from middlewares.search_client import SearchClient
import os
from dotenv import load_dotenv

# Pull API credentials from a local .env file into the process environment.
load_dotenv()

# Search-provider credentials; os.getenv returns None when a key is absent —
# NOTE(review): presumably SearchClient tolerates None keys; confirm.
GOOGLE_SEARCH_ENGINE_ID = os.getenv("GOOGLE_SEARCH_ENGINE_ID")
GOOGLE_SEARCH_API_KEY = os.getenv("GOOGLE_SEARCH_API_KEY")
BING_SEARCH_API_KEY = os.getenv("BING_SEARCH_API_KEY")

# Cross-encoder used to score (query, chunk) pairs when reranking search hits.
# Loaded once at import time (downloads the model on first use).
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

# One client per supported vendor; Bing does not use an engine id.
googleSearchClient = SearchClient(
    "google", api_key=GOOGLE_SEARCH_API_KEY, engine_id=GOOGLE_SEARCH_ENGINE_ID
)
bingSearchClient = SearchClient("bing", api_key=BING_SEARCH_API_KEY, engine_id=None)
def rerank(query, top_k, search_results, chunk_size):
    """Chunk search-result texts and rerank the chunks with the cross-encoder.

    Args:
        query (str): the query to score chunks against.
        top_k (int): number of top reranked chunks to return.
        search_results (list[dict]): dicts with "link" and "text" keys.
        chunk_size (int): number of words per chunk.

    Returns:
        list[dict]: up to ``top_k`` dicts with "link" and "text" keys,
        ordered by descending cross-encoder relevance score.
    """
    # Split each result text into fixed-size word chunks, remembering which
    # link every chunk came from.
    chunks = []
    for result in search_results:
        words = result["text"].split()
        num_chunks = math.ceil(len(words) / chunk_size)
        for i in range(num_chunks):
            chunk = " ".join(words[i * chunk_size : (i + 1) * chunk_size])
            chunks.append((result["link"], chunk))

    # Robustness fix: if every text was empty there is nothing to rank —
    # avoid calling the model on an empty batch.
    if not chunks:
        return []

    # Score each (query, chunk) pair with the cross-encoder.
    sentence_combinations = [[query, chunk] for _, chunk in chunks]
    similarity_scores = reranker.predict(sentence_combinations)

    # Indices sorted by descending score ([::-1] is the idiomatic NumPy
    # reversal; equivalent to the previous reversed(np.argsort(...))).
    sim_scores_argsort = np.argsort(similarity_scores)[::-1]

    # Rebuild result dicts in relevance order; slice to top_k first so we
    # only materialize what we return.
    return [
        {"link": chunks[idx][0], "text": chunks[idx][1]}
        for idx in sim_scores_argsort[:top_k]
    ]
def gen_augmented_prompt_via_websearch(
    prompt,
    vendor,
    n_crawl,
    top_k,
    pre_context="",
    post_context="",
    pre_prompt="",
    post_prompt="",
    pass_prev=False,
    prev_output="",
    chunk_size=512,
):
    """
    Generates an augmented prompt and a list of links by performing a web search and reranking the results.

    Args:
        prompt (str, required): The prompt for the web search.
        vendor (str): The search engine to use, either 'Google' or 'Bing'.
        n_crawl (int): The number of search results to retrieve.
        top_k (int): The number of top reranked results to return.
        pre_context (str): The pre-context to be included in the augmented prompt.
        post_context (str): The post-context to be included in the augmented prompt.
        pre_prompt (str, optional): The pre-prompt to be included in the augmented prompt. Defaults to "".
        post_prompt (str, optional): The post-prompt to be included in the augmented prompt. Defaults to "".
        pass_prev (bool, optional): Whether to include the previous output in the augmented prompt. Defaults to False.
        prev_output (str, optional): The previous output to be included in the augmented prompt. Defaults to "".
        chunk_size (int, optional): The size of each chunk for reranking. Defaults to 512.

    Returns:
        tuple: A tuple containing the augmented prompt and a list of unique links
        (deduplicated, relevance order preserved).
    """
    # Dispatch to the configured vendor; any other value silently yields no
    # results (best-effort behavior preserved from the original).
    search_results = []
    if vendor == "Google":
        search_results = googleSearchClient.search(prompt, n_crawl)
    elif vendor == "Bing":
        search_results = bingSearchClient.search(prompt, n_crawl)

    reranked_results = []
    if search_results:
        reranked_results = rerank(prompt, top_k, search_results, chunk_size)

    # Join chunk texts into the context block (linear, vs. quadratic +=).
    context = "".join(res["text"] + "\n\n" for res in reranked_results)

    # Fix: deduplicate links while PRESERVING relevance order — list(set(...))
    # returned the links in arbitrary order.
    links = list(dict.fromkeys(res["link"] for res in reranked_results))

    # Only forward the previous step's output when explicitly requested.
    prev_output = prev_output if pass_prev else ""

    augmented_prompt = f"""
{pre_context}
{context}
{post_context}
{pre_prompt}
{prompt}
{post_prompt}
{prev_output}
"""
    return augmented_prompt, links