yonikremer
commited on
Commit
•
27e2360
1
Parent(s):
32e4e72
using google search api
Browse files- prompt_engeneering.py +37 -19
- requirements.txt +4 -3
prompt_engeneering.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1 |
from dataclasses import dataclass
|
2 |
from datetime import datetime
|
3 |
-
from typing import Generator
|
4 |
|
5 |
-
import
|
|
|
6 |
|
7 |
INSTRUCTIONS = "Instructions: " \
|
8 |
"Using the provided web search results, " \
|
@@ -12,30 +13,47 @@ INSTRUCTIONS = "Instructions: " \
|
|
12 |
"write separate answers for each subject."
|
13 |
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
@dataclass
|
16 |
class SearchResult:
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
|
22 |
|
23 |
def get_web_search_results(
|
24 |
-
|
25 |
num_results: int,
|
26 |
) -> Generator[SearchResult, None, None]:
|
27 |
-
"""
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
yield SearchResult(
|
36 |
title=result["title"],
|
37 |
-
body=result["
|
38 |
-
url=result["
|
39 |
)
|
40 |
|
41 |
|
@@ -52,10 +70,10 @@ def rewrite_prompt(
|
|
52 |
) -> str:
|
53 |
"""Rewrites the prompt by adding web search results to it."""
|
54 |
raw_results = get_web_search_results(
|
55 |
-
|
56 |
num_results=5,
|
57 |
)
|
58 |
-
formatted_results = "Web search results
|
59 |
formatted_date = "Current date: " + datetime.now().strftime("%d/%m/%Y")
|
60 |
formatted_prompt = f"Query: {prompt}"
|
61 |
return "\n".join([formatted_results, formatted_date, INSTRUCTIONS, formatted_prompt])
|
|
|
1 |
from dataclasses import dataclass
|
2 |
from datetime import datetime
|
3 |
+
from typing import Generator, Dict, List
|
4 |
|
5 |
+
from googleapiclient.discovery import build
|
6 |
+
from streamlit import secrets
|
7 |
|
8 |
INSTRUCTIONS = "Instructions: " \
|
9 |
"Using the provided web search results, " \
|
|
|
13 |
"write separate answers for each subject."
|
14 |
|
15 |
|
16 |
+
def get_google_api_key():
|
17 |
+
"""Returns the Google API key from streamlit's secrets"""
|
18 |
+
return secrets["google_search_api_key"]
|
19 |
+
|
20 |
+
|
21 |
+
def get_google_cse_id():
|
22 |
+
"""Returns the Google CSE ID from streamlit's secrets"""
|
23 |
+
return secrets["google_cse_id"]
|
24 |
+
|
25 |
+
|
26 |
+
def google_search(search_term, **kwargs) -> list:
|
27 |
+
service = build("customsearch", "v1", developerKey=get_google_api_key())
|
28 |
+
search_engine = service.cse()
|
29 |
+
res = search_engine.list(q=search_term, cx=get_google_cse_id(), **kwargs).execute()
|
30 |
+
return res['items']
|
31 |
+
|
32 |
+
|
33 |
@dataclass
|
34 |
class SearchResult:
|
35 |
+
__slots__ = ["title", "body", "url"]
|
36 |
+
title: str
|
37 |
+
body: str
|
38 |
+
url: str
|
39 |
|
40 |
|
41 |
def get_web_search_results(
|
42 |
+
query: str,
|
43 |
num_results: int,
|
44 |
) -> Generator[SearchResult, None, None]:
|
45 |
+
"""Gets a list of web search results using the Google search API"""
|
46 |
+
rew_results: List[Dict[str, str]] = google_search(
|
47 |
+
search_term=query,
|
48 |
+
num=num_results
|
49 |
+
)[:num_results]
|
50 |
+
for result in rew_results:
|
51 |
+
if result["snippet"].endswith("\xa0..."):
|
52 |
+
result["snippet"] = result["snippet"][:-4]
|
53 |
yield SearchResult(
|
54 |
title=result["title"],
|
55 |
+
body=result["snippet"],
|
56 |
+
url=result["link"],
|
57 |
)
|
58 |
|
59 |
|
|
|
70 |
) -> str:
|
71 |
"""Rewrites the prompt by adding web search results to it."""
|
72 |
raw_results = get_web_search_results(
|
73 |
+
query=prompt,
|
74 |
num_results=5,
|
75 |
)
|
76 |
+
formatted_results = "Web search results:\n" + format_search_result(raw_results)
|
77 |
formatted_date = "Current date: " + datetime.now().strftime("%d/%m/%Y")
|
78 |
formatted_prompt = f"Query: {prompt}"
|
79 |
return "\n".join([formatted_results, formatted_date, INSTRUCTIONS, formatted_prompt])
|
requirements.txt
CHANGED
@@ -4,9 +4,10 @@ torch>1.12.1
|
|
4 |
transformers~=4.26.0
|
5 |
hatchling
|
6 |
beautifulsoup4~=4.11.2
|
7 |
-
urllib3
|
8 |
requests~=2.28.2
|
9 |
accelerate
|
10 |
bitsandbytes
|
11 |
-
pytest
|
12 |
-
sentencepiece
|
|
|
|
4 |
transformers~=4.26.0
|
5 |
hatchling
|
6 |
beautifulsoup4~=4.11.2
|
7 |
+
urllib3~=1.26.14
|
8 |
requests~=2.28.2
|
9 |
accelerate
|
10 |
bitsandbytes
|
11 |
+
pytest~=7.1.2
|
12 |
+
sentencepiece
|
13 |
+
google-api-python-client
|