from sentence_transformers import CrossEncoder
import json
import math
import numpy as np
from middlewares.search_client import SearchClient
import os
from dotenv import load_dotenv

load_dotenv()

GOOGLE_SEARCH_ENGINE_ID = os.getenv("GOOGLE_SEARCH_ENGINE_ID")
GOOGLE_SEARCH_API_KEY = os.getenv("GOOGLE_SEARCH_API_KEY")
BING_SEARCH_API_KEY = os.getenv("BING_SEARCH_API_KEY")

# Cross-encoder model used to score (query, chunk) relevance for reranking.
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

googleSearchClient = SearchClient(
    "google", api_key=GOOGLE_SEARCH_API_KEY, engine_id=GOOGLE_SEARCH_ENGINE_ID
)
bingSearchClient = SearchClient("bing", api_key=BING_SEARCH_API_KEY, engine_id=None)


def rerank(query, top_k, search_results, chunk_size=512):
    """Rerank search-result text chunks by cross-encoder relevance to *query*.

    Each result's text is split into chunks of at most ``chunk_size`` words;
    every chunk is scored against the query with the cross-encoder, and the
    chunks are returned in decreasing score order.

    Args:
        query: The user query to score chunks against.
        top_k: Number of highest-scoring chunks to return.
        search_results: Iterable of dicts with ``"link"`` and ``"text"`` keys.
        chunk_size: Maximum number of words per chunk.

    Returns:
        List of up to ``top_k`` dicts ``{"link": ..., "text": ...}``,
        most relevant first.
    """
    chunks = []
    for result in search_results:
        words = result["text"].split()
        num_chunks = math.ceil(len(words) / chunk_size)
        for i in range(num_chunks):
            start = i * chunk_size
            end = (i + 1) * chunk_size
            chunks.append((result["link"], " ".join(words[start:end])))

    # Guard: results whose text is empty/whitespace yield no chunks at all;
    # avoid handing the cross-encoder an empty batch.
    if not chunks:
        return []

    # Pair the query with every chunk and score the pairs in one batch.
    sentence_combinations = [[query, text] for _, text in chunks]
    similarity_scores = reranker.predict(sentence_combinations)

    # Chunk indexes sorted by decreasing similarity score.
    ranked_indexes = np.argsort(similarity_scores)[::-1]

    reranked_results = [
        {"link": chunks[idx][0], "text": chunks[idx][1]} for idx in ranked_indexes
    ]
    return reranked_results[:top_k]


def gen_augmented_prompt_via_websearch(
    prompt,
    search_vendor,
    n_crawl,
    top_k,
    pre_context="",
    post_context="",
    pre_prompt="",
    post_prompt="",
    pass_prev=False,
    prev_output="",
    chunk_size=512,
):
    """Build a web-search-augmented prompt.

    Searches the chosen vendor for *prompt*, reranks the crawled results,
    and assembles an augmented prompt that sandwiches the retrieved context
    and the original prompt between the caller-supplied pre/post strings.

    Args:
        prompt: The user prompt / search query.
        search_vendor: ``"Google"`` or ``"Bing"``; any other value skips search.
        n_crawl: Number of results to request from the search client.
        top_k: Number of reranked chunks to include as context.
        pre_context / post_context: Text placed around the retrieved context.
        pre_prompt / post_prompt: Text placed around the original prompt.
        pass_prev: When True, append ``prev_output`` to the prompt.
        prev_output: Previous model output, used only if ``pass_prev`` is True.
        chunk_size: Word-chunk size forwarded to :func:`rerank`.

    Returns:
        Tuple ``(augmented_prompt, links)`` where ``links`` is the
        deduplicated list of source URLs, in relevance order.
    """
    search_results = []
    if search_vendor == "Google":
        search_results = googleSearchClient.search(prompt, n_crawl)
    elif search_vendor == "Bing":
        search_results = bingSearchClient.search(prompt, n_crawl)

    reranked_results = []
    if search_results:
        reranked_results = rerank(prompt, top_k, search_results, chunk_size)

    links = []
    context = ""
    for res in reranked_results:
        context += res["text"] + "\n\n"
        links.append(res["link"])

    # Deduplicate links while preserving relevance order
    # (dict.fromkeys keeps first-seen insertion order, unlike set()).
    links = list(dict.fromkeys(links))

    prev_output = prev_output if pass_prev else ""

    augmented_prompt = f"""
{pre_context}
{context}
{post_context}
{pre_prompt}
{prompt}
{post_prompt}
{prev_output}
"""
    print(augmented_prompt)
    return augmented_prompt, links