File size: 2,871 Bytes
831e906
67b695a
592fd3e
831e906
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
592fd3e
 
831e906
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
592fd3e
c54ec59
831e906
 
 
20b4ac9
831e906
 
 
 
 
 
 
 
 
 
 
c54ec59
 
831e906
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from sentence_transformers import CrossEncoder

import json
import math
import numpy as np
from middlewares.search_client import SearchClient
import os
from dotenv import load_dotenv


# Load variables from a local .env file (if any) into the process environment
# before the os.getenv() lookups below run.
load_dotenv()


# Search-provider credentials. os.getenv() returns None for any variable that
# is not set, and that None is passed straight to the clients below.
GOOGLE_SEARCH_ENGINE_ID = os.getenv("GOOGLE_SEARCH_ENGINE_ID")
GOOGLE_SEARCH_API_KEY = os.getenv("GOOGLE_SEARCH_API_KEY")
BING_SEARCH_API_KEY = os.getenv("BING_SEARCH_API_KEY")

# Cross-encoder used by rerank() to score (query, chunk) pairs. It is
# constructed once, at import time, and shared by every call.
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

# One search client per supported vendor, also created at import time.
# NOTE(review): presumably SearchClient needs non-None credentials to actually
# search -- confirm how it behaves when the env vars above are missing.
googleSearchClient = SearchClient(
    "google", api_key=GOOGLE_SEARCH_API_KEY, engine_id=GOOGLE_SEARCH_ENGINE_ID
)
bingSearchClient = SearchClient("bing", api_key=BING_SEARCH_API_KEY, engine_id=None)




def rerank(query, top_k, search_results, chunk_size=512):
    """Split search results into word chunks and return the best-scoring ones.

    Each result's ``"text"`` is split on whitespace and grouped into chunks of
    at most ``chunk_size`` words. Every chunk is paired with ``query`` and
    scored by the module-level ``reranker``; chunks are then ordered from the
    highest score to the lowest.

    Args:
        query: Query string paired with every chunk for scoring.
        top_k: Upper slice bound applied to the ranked chunks, i.e. the number
            of best chunks to keep.
        search_results: Iterable of mappings providing ``"link"`` and
            ``"text"`` keys.
        chunk_size: Maximum number of whitespace-separated words per chunk.

    Returns:
        A list of ``{"link": ..., "text": ...}`` dicts, best match first.
        The list is empty when there is no text to score.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1.
    """
    if chunk_size < 1:
        raise ValueError(
            f"chunk_size must be a positive integer, got {chunk_size!r}"
        )

    # Collect (link, chunk_text) pairs; a result whose text has no words
    # contributes no chunks at all.
    chunks = []
    for result in search_results:
        words = result["text"].split()
        chunks.extend(
            (result["link"], " ".join(words[start : start + chunk_size]))
            for start in range(0, len(words), chunk_size)
        )

    # No chunks to score (e.g. every result had empty text): return early
    # instead of handing the model an empty batch.
    if not chunks:
        return []

    # Score every (query, chunk) pair with the cross-encoder.
    similarity_scores = reranker.predict([[query, text] for _, text in chunks])

    # argsort is ascending, so reverse it for best-first order and keep only
    # the requested number of indices before building the result dicts.
    ranked_indices = np.argsort(similarity_scores)[::-1][:top_k]

    return [
        {"link": chunks[idx][0], "text": chunks[idx][1]} for idx in ranked_indices
    ]


def gen_augmented_prompt_via_websearch(
    prompt,
    search_vendor,
    n_crawl,
    top_k,
    pre_context="",
    post_context="",
    pre_prompt="",
    post_prompt="",
    pass_prev=False,
    prev_output="",
    chunk_size=512,
):
    """Build a prompt augmented with reranked web-search context.

    The prompt is sent to the selected search client, the returned results are
    chunked and ranked with ``rerank``, and the text of the kept chunks is
    placed between the caller-supplied wrapper strings.

    Args:
        prompt: User prompt; also used as the search query and rerank query.
        search_vendor: ``"Google"`` or ``"Bing"`` (exact match). Any other
            value performs no search, so the context stays empty.
        n_crawl: Passed as the second argument to the client's ``search``.
        top_k: Number of top-ranked chunks to include as context.
        pre_context: Text placed before the retrieved context.
        post_context: Text placed after the retrieved context.
        pre_prompt: Text placed before the prompt.
        post_prompt: Text placed after the prompt.
        pass_prev: When true, ``prev_output`` is appended to the prompt.
        prev_output: Earlier output; ignored unless ``pass_prev`` is true.
        chunk_size: Maximum words per chunk, forwarded to ``rerank``.

    Returns:
        A ``(augmented_prompt, links)`` tuple. ``links`` holds the unique
        source links of the chunks used, in ranking order.
    """
    search_results = []
    if search_vendor == "Google":
        search_results = googleSearchClient.search(prompt, n_crawl)
    elif search_vendor == "Bing":
        search_results = bingSearchClient.search(prompt, n_crawl)

    # Only rerank when the search actually produced something.
    reranked_results = []
    if search_results:
        reranked_results = rerank(prompt, top_k, search_results, chunk_size)

    # Every chunk is followed by a blank line, including the last one.
    context = "".join(res["text"] + "\n\n" for res in reranked_results)

    # De-duplicate links while keeping best-first order; dict keys preserve
    # insertion order, whereas a set would shuffle the links.
    links = list(dict.fromkeys(res["link"] for res in reranked_results))

    prev_output = prev_output if pass_prev else ""

    # The template layout (indentation and trailing spaces included) is part
    # of the prompt text, so it is spelled out line by line to keep that
    # whitespace visible and safe from editors that strip trailing blanks.
    augmented_prompt = "\n".join(
        [
            "",
            "    ",
            f"    {pre_context}",
            "",
            f"    {context}",
            "",
            "    ",
            f"    {post_context}",
            "    ",
            f"    {pre_prompt} ",
            "    ",
            f"    {prompt} ",
            "    ",
            f"    {post_prompt}",
            "",
            f"    {prev_output}",
            "",
            "    ",
        ]
    )

    print(augmented_prompt)
    return augmented_prompt, links