File size: 4,489 Bytes
0676ee0
 
 
b615916
0676ee0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac9e241
0676ee0
 
ac9e241
 
 
 
0676ee0
 
ac9e241
0676ee0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d466514
 
0676ee0
 
 
 
 
 
 
d466514
0676ee0
 
d466514
 
 
 
 
 
 
 
 
 
 
0676ee0
 
d466514
0676ee0
d466514
0676ee0
e51667a
0676ee0
 
 
 
 
e51667a
 
0676ee0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
from sentence_transformers import CrossEncoder
import math
import numpy as np
from middlewares.search_client import SearchClient
import os
from dotenv import load_dotenv


# Load variables from a .env file (python-dotenv) into the process environment.
# This must run before the os.getenv() lookups below so they can see those values.
load_dotenv()


# Search API credentials read from the environment; each is None when unset.
GOOGLE_SEARCH_ENGINE_ID = os.getenv("GOOGLE_SEARCH_ENGINE_ID")
GOOGLE_SEARCH_API_KEY = os.getenv("GOOGLE_SEARCH_API_KEY")
BING_SEARCH_API_KEY = os.getenv("BING_SEARCH_API_KEY")

# Cross-encoder used by rerank() to score (query, chunk) pairs.
# Instantiated once at import time and shared by every call.
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

# One search client per supported vendor; the Bing client is built without an engine id.
googleSearchClient = SearchClient(
    "google", api_key=GOOGLE_SEARCH_API_KEY, engine_id=GOOGLE_SEARCH_ENGINE_ID
)
bingSearchClient = SearchClient("bing", api_key=BING_SEARCH_API_KEY, engine_id=None)


def rerank(query, top_k, search_results, chunk_size):
    """Chunk the search results and rerank the chunks against the query.

    Each result's text is split on whitespace into consecutive chunks of at
    most ``chunk_size`` words. Every chunk is scored against ``query`` with
    the module-level ``reranker`` and the best-scoring chunks are returned.

    Args:
        query (str): the query for reranking
        top_k (int): the number of top reranked results to return
        search_results (list[dict]): a list of dictionaries containing "link" and "text" keys for each search result
        chunk_size (int): the maximum number of words in each chunk

    Returns:
        list[dict]: up to ``top_k`` dictionaries with "link" and "text" keys,
        ordered from highest to lowest score. "text" holds the chunk, not the
        full document. Empty when ``top_k`` is not positive or when there is
        no text to rank.

    Raises:
        ValueError: if ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")
    # A non-positive top_k asks for nothing; skip the model call instead of
    # letting a negative value slice off the tail of the ranking.
    if top_k <= 0:
        return []

    chunks = []
    for result in search_results:
        # Treat a missing or None text as empty so one bad result cannot
        # abort the whole rerank.
        words = (result.get("text") or "").split()
        for start in range(0, len(words), chunk_size):
            chunk = " ".join(words[start:start + chunk_size])
            chunks.append((result["link"], chunk))

    # Nothing to score: avoid handing the model an empty batch.
    if not chunks:
        return []

    # Score every (query, chunk) pair with the cross-encoder.
    similarity_scores = reranker.predict([[query, chunk] for _, chunk in chunks])

    # argsort is ascending, so reverse it and keep only the top_k indexes.
    top_indexes = np.argsort(similarity_scores)[::-1][:top_k]

    return [{"link": chunks[idx][0], "text": chunks[idx][1]} for idx in top_indexes]


def gen_augmented_prompt_via_websearch(
    prompt,
    vendor,
    n_crawl,
    top_k,
    pre_context="",
    post_context="",
    pre_prompt="",
    post_prompt="",
    pass_prev=False,
    prev_output="",
    chunk_size=512,
):
    """
    Generates an augmented prompt and a list of links by performing a web search and reranking the results.

    Args:
        prompt (str, required): The prompt for the web search.
        vendor (str): The search engine to use, either 'Google' or 'Bing' (case-sensitive).
            Any other value skips the web search and yields an empty context.
        n_crawl (int): The number of search results to retrieve.
        top_k (int): The number of top reranked results to return.
        pre_context (str, optional): Text placed before the retrieved context. Defaults to "".
        post_context (str, optional): Text placed after the retrieved context. Defaults to "".
        pre_prompt (str, optional): The pre-prompt to be included in the augmented prompt. Defaults to "".
        post_prompt (str, optional): The post-prompt to be included in the augmented prompt. Defaults to "".
        pass_prev (bool, optional): Whether to include the previous output in the augmented prompt. Defaults to False.
        prev_output (str, optional): The previous output to be included in the augmented prompt. Defaults to "".
        chunk_size (int, optional): The size of each chunk for reranking. Defaults to 512.

    Returns:
        tuple: A tuple containing the augmented prompt and a list of unique
        links, ordered by the rank of the first chunk taken from each link.
    """

    search_results = []
    if vendor == "Google":
        search_results = googleSearchClient.search(prompt, n_crawl)
    elif vendor == "Bing":
        search_results = bingSearchClient.search(prompt, n_crawl)

    # Truthiness check also covers a client returning None instead of a list.
    reranked_results = []
    if search_results:
        reranked_results = rerank(prompt, top_k, search_results, chunk_size)

    # Each chunk is followed by a blank line in the context section.
    context = "".join(res["text"] + "\n\n" for res in reranked_results)

    # Remove duplicate links while keeping rank order; dict keys preserve
    # insertion order, whereas a set would shuffle the links.
    links = list(dict.fromkeys(res["link"] for res in reranked_results))

    prev_output = prev_output if pass_prev else ""

    augmented_prompt = f"""
    {pre_context}

    {context}

    {post_context}
    
    {pre_prompt} 
    
    {prompt} 
    
    {post_prompt}

    {prev_output}

    """
    return augmented_prompt, links