from sentence_transformers import CrossEncoder
import math
import numpy as np
from middlewares.search_client import SearchClient
import os
from dotenv import load_dotenv

load_dotenv()

GOOGLE_SEARCH_ENGINE_ID = os.getenv("GOOGLE_SEARCH_ENGINE_ID")
GOOGLE_SEARCH_API_KEY = os.getenv("GOOGLE_SEARCH_API_KEY")
BING_SEARCH_API_KEY = os.getenv("BING_SEARCH_API_KEY")

reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

googleSearchClient = SearchClient(
    "google", api_key=GOOGLE_SEARCH_API_KEY, engine_id=GOOGLE_SEARCH_ENGINE_ID
)
bingSearchClient = SearchClient("bing", api_key=BING_SEARCH_API_KEY, engine_id=None)


def rerank(query, top_k, search_results, chunk_size):
    """Chunks the search results and reranks them with the cross-encoder reranker.

    Args:
        query (str): The query to rerank against.
        top_k (int): The number of top reranked results to return.
        search_results (list[dict]): A list of dictionaries containing "link" and
            "text" keys for each search result.
        chunk_size (int): The size of each chunk (in words) for reranking.

    Returns:
        list[dict]: A list of dictionaries containing the top reranked results
            with "link" and "text" keys.
    """
    # Split every result's text into word-level chunks, keeping the source link.
    chunks = []
    for result in search_results:
        text = result["text"]
        words = text.split()
        num_chunks = math.ceil(len(words) / chunk_size)
        for i in range(num_chunks):
            start = i * chunk_size
            end = (i + 1) * chunk_size
            chunk = " ".join(words[start:end])
            chunks.append((result["link"], chunk))

    # Create (query, chunk) pairs for the cross-encoder
    sentence_combinations = [[query, chunk[1]] for chunk in chunks]

    # Compute similarity scores for these pairs
    similarity_scores = reranker.predict(sentence_combinations)

    # Sort chunk indexes by score in decreasing order
    sim_scores_argsort = np.argsort(similarity_scores)[::-1]

    # Rearrange the chunks based on the reranked scores
    reranked_results = []
    for idx in sim_scores_argsort:
        link = chunks[idx][0]
        chunk = chunks[idx][1]
        reranked_results.append({"link": link, "text": chunk})

    # Return the top_k highest-scoring chunks
    return reranked_results[:top_k]


def gen_augmented_prompt_via_websearch(
    prompt,
    vendor,
    n_crawl,
    top_k,
    pre_context="",
    post_context="",
    pre_prompt="",
    post_prompt="",
    pass_prev=False,
    prev_output="",
    chunk_size=512,
):
    """Generates an augmented prompt and a list of links by performing a web search
    and reranking the results.

    Args:
        prompt (str): The prompt for the web search.
        vendor (str): The search engine to use, either 'Google' or 'Bing'.
        n_crawl (int): The number of search results to retrieve.
        top_k (int): The number of top reranked results to return.
        pre_context (str, optional): The pre-context to include before the retrieved web context. Defaults to "".
        post_context (str, optional): The post-context to include after the retrieved web context. Defaults to "".
        pre_prompt (str, optional): The pre-prompt to be included in the augmented prompt. Defaults to "".
        post_prompt (str, optional): The post-prompt to be included in the augmented prompt. Defaults to "".
        pass_prev (bool, optional): Whether to include the previous output in the augmented prompt. Defaults to False.
        prev_output (str, optional): The previous output to include when pass_prev is True. Defaults to "".
        chunk_size (int, optional): The size of each chunk (in words) for reranking. Defaults to 512.

    Returns:
        tuple: A tuple containing the augmented prompt and a list of links.
    """
    # Fetch raw search results from the selected vendor; unknown vendors fall
    # through with no results, so the prompt is returned without web context.
    search_results = []
    reranked_results = []
    if vendor == "Google":
        search_results = googleSearchClient.search(prompt, n_crawl)
    elif vendor == "Bing":
        search_results = bingSearchClient.search(prompt, n_crawl)

    if search_results:
        reranked_results = rerank(prompt, top_k, search_results, chunk_size)

    # Concatenate the reranked chunks into a context block and gather their links.
    links = []
    context = ""
    for res in reranked_results:
        context += res["text"] + "\n\n"
        links.append(res["link"])

    # Remove duplicate links while preserving rank order
    links = list(dict.fromkeys(links))

    prev_output = prev_output if pass_prev else ""

    augmented_prompt = f"""
    {pre_context}

    {context}

    {post_context}

    {pre_prompt}

    {prompt}

    {post_prompt}

    {prev_output}
    """
    return augmented_prompt, links