import math
import os
from pprint import pformat

import numpy as np
from dotenv import load_dotenv
from notion_client import Client
from sentence_transformers import CrossEncoder

from middlewares.search_client import SearchClient


load_dotenv()


GOOGLE_SEARCH_ENGINE_ID = os.getenv("GOOGLE_SEARCH_ENGINE_ID")
GOOGLE_SEARCH_API_KEY = os.getenv("GOOGLE_SEARCH_API_KEY")
BING_SEARCH_API_KEY = os.getenv("BING_SEARCH_API_KEY")

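# Cross-encoder used to score (query, chunk) pairs when reranking search results.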
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

googleSearchClient = SearchClient(
    "google", api_key=GOOGLE_SEARCH_API_KEY, engine_id=GOOGLE_SEARCH_ENGINE_ID
)
bingSearchClient = SearchClient("bing", api_key=BING_SEARCH_API_KEY, engine_id=None)

def safe_get(data, dot_chained_keys):
    """Safely traverse nested dicts/lists using a dot-chained key string.

    Example:
        data = {'a': {'b': [{'c': 1}]}}
        safe_get(data, 'a.b.0.c') -> 1

    Returns None if any key or index along the path is missing.
    """
    keys = dot_chained_keys.split('.')
    for key in keys:
        try:
            if isinstance(data, list):
                data = data[int(key)]
            else:
                data = data[key]
        except (KeyError, TypeError, IndexError):
            return None
    return data

def get_notion_data():
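    """Query the Notion inventory database and return its rows as a
    pretty-printed string suitable for embedding in a prompt."""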
    # Credentials are read from the environment (variable names assumed here)
    # rather than committed to source, matching the other keys above.
    integration_token = os.getenv("NOTION_INTEGRATION_TOKEN")
    notion_database_id = os.getenv("NOTION_DATABASE_ID")

    client = Client(auth=integration_token)

    first_db_rows = client.databases.query(notion_database_id)
    rows = []

    for row in first_db_rows['results']:
        price = safe_get(row, 'properties.($) Per Unit.number')
        store_link = safe_get(row, 'properties.Store Link.url')
        supplier_email = safe_get(row, 'properties.Supplier Email.email')
        exp_del = safe_get(row, 'properties.Expected Delivery.date')

        # safe_get returns None for a missing property, so default to an
        # empty list before iterating over multi-select values.
        collections = safe_get(row, 'properties.Collection.multi_select') or []
        collection_names = [collection['name'] for collection in collections]

        status = safe_get(row, 'properties.Status.select.name')
        sup_phone = safe_get(row, 'properties.Supplier Phone.phone_number')
        stock_alert = safe_get(row, 'properties.Stock Alert.status.name')
        # The trailing space in 'Product ' matches the property name in Notion.
        prod_name = safe_get(row, 'properties.Product .title.0.text.content')
        sku = safe_get(row, 'properties.SKU.number')
        shipped_date = safe_get(row, 'properties.Shipped On.date')
        on_order = safe_get(row, 'properties.On Order.number')
        on_hand = safe_get(row, 'properties.On Hand.number')

        sizes = safe_get(row, 'properties.Size.multi_select') or []
        size_names = [size['name'] for size in sizes]
        
        rows.append({
            'Price Per Unit': price,
            'Store Link': store_link,
            'Supplier Email': supplier_email,
            'Expected Delivery': exp_del,
            'Collection': collection_names,
            'Status': status,
            'Supplier Phone': sup_phone,
            'Stock Alert': stock_alert,
            'Product Name': prod_name,
            'SKU': sku,
            'Sizes': size_names,
            'Shipped Date': shipped_date,
            'On Order': on_order,
            'On Hand': on_hand,
        })

    notion_data_string = pformat(rows)
    return notion_data_string



def rerank(query, top_k, search_results, chunk_size=512):
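    """Chunk the search results, score each (query, chunk) pair with the
    cross-encoder, and return the top_k chunks by similarity score."""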
    chunks = []
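    # Split each result's text into windows of chunk_size words, keeping
    # the source link alongside each chunk.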
    for result in search_results:
        text = result["text"]
        words = text.split()
        num_chunks = math.ceil(len(words) / chunk_size)
        for i in range(num_chunks):
            start = i * chunk_size
            end = (i + 1) * chunk_size
            chunk = " ".join(words[start:end])
            chunks.append((result["link"], chunk))

    # Create sentence combinations with the query
    sentence_combinations = [[query, chunk[1]] for chunk in chunks]

    # Compute similarity scores for these combinations
    similarity_scores = reranker.predict(sentence_combinations)

    # Indices of chunks sorted by similarity score, highest first
    sim_scores_argsort = np.argsort(similarity_scores)[::-1]

    # Rearrange search_results based on the reranked scores
    reranked_results = []
    for idx in sim_scores_argsort:
        link = chunks[idx][0]
        chunk = chunks[idx][1]
        reranked_results.append({"link": link, "text": chunk})

    # Return the top K ranks
    return reranked_results[:top_k]


def gen_augmented_prompt_via_websearch(
    prompt,
    search_vendor,
    n_crawl,
    top_k,
    pre_context="",
    post_context="",
    pre_prompt="",
    post_prompt="",
    pass_prev=False,
    prev_output="",
    chunk_size=512,
):
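    """Build an augmented prompt from reranked web-search results and
    Notion inventory data.

    Returns a tuple of (augmented_prompt, deduplicated source links).
    """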
    notion_data = get_notion_data()

    search_results = []
    reranked_results = []
    if search_vendor == "Google":
        search_results = googleSearchClient.search(prompt, n_crawl)
    elif search_vendor == "Bing":
        search_results = bingSearchClient.search(prompt, n_crawl)

    if len(search_results) > 0:
        reranked_results = rerank(prompt, top_k, search_results, chunk_size)

    links = []
    context = ""
    for res in reranked_results:
        context += res["text"] + "\n\n"
        link = res["link"]
        links.append(link)

    # remove duplicate links
    links = list(set(links))

    prev_output = prev_output if pass_prev else ""

    augmented_prompt = f"""
    
    {pre_context}

    {context}

    {notion_data}
    
    {post_context}
    
    {pre_prompt} 
    
    {prompt} 
    
    {post_prompt}

    {prev_output}

    """

    print(augmented_prompt)
    return augmented_prompt, links
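

# Minimal usage sketch (assumes the .env file provides the credentials loaded
# above; the query string and parameter values are illustrative, not from the
# original source):
if __name__ == "__main__":
    augmented, source_links = gen_augmented_prompt_via_websearch(
        prompt="best wholesale suppliers for cotton t-shirts",
        search_vendor="Google",
        n_crawl=5,
        top_k=3,
    )
    print(source_links)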