import os
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, Generator, Iterable, List

from googleapiclient.discovery import build
from streamlit import secrets

# Instruction text appended to every rewritten prompt.
INSTRUCTIONS = (
    "Instructions: "
    "Using the provided web search results, "
    "write a comprehensive reply to the given query. "
    "Make sure to cite results using [[number](URL)] notation after the reference. "
    "If the provided search results refer to multiple subjects with the same name, "
    "write separate answers for each subject."
)


def _get_secret(name: str) -> str:
    """Return ``name`` from Streamlit's secrets, falling back to ``os.environ``.

    The environment is consulted when Streamlit cannot read a secrets file
    (FileNotFoundError / IsADirectoryError) or when the key is absent from it
    (KeyError), so partially-populated secrets files still work.

    Raises:
        KeyError: if ``name`` is found in neither source.
    """
    try:
        return secrets[name]
    except (FileNotFoundError, IsADirectoryError, KeyError):
        return os.environ[name]


def get_google_api_key() -> str:
    """Return the Google search API key.

    Read from Streamlit's secrets, falling back to the
    ``google_search_api_key`` environment variable.
    """
    return _get_secret("google_search_api_key")


def get_google_cse_id() -> str:
    """Return the Google custom search engine (CSE) ID.

    Read from Streamlit's secrets, falling back to the ``google_cse_id``
    environment variable.
    """
    return _get_secret("google_cse_id")


def google_search(search_term: str, **kwargs) -> list:
    """Run a Google Custom Search query and return its result items.

    Args:
        search_term: The query string.
        **kwargs: Extra parameters forwarded to ``cse().list`` (e.g. ``num``).

    Returns:
        The response's ``items`` list, or ``[]`` when the response carries no
        ``items`` key (the API omits it when nothing matched).
    """
    service = build("customsearch", "v1", developerKey=get_google_api_key())
    request = service.cse().list(q=search_term, cx=get_google_cse_id(), **kwargs)
    response = request.execute()
    return response.get("items", [])


@dataclass
class SearchResult:
    """A single web search hit, reduced to the fields this module uses."""

    __slots__ = ["title", "body", "url"]
    title: str  # result title as returned by the API
    body: str  # snippet text, minus a trailing "\xa0..." ellipsis if present
    url: str  # the result's link


def get_web_search_results(
    query: str,
    num_results: int,
) -> Generator[SearchResult, None, None]:
    """Yield up to ``num_results`` web search results for ``query``.

    Uses the Google Custom Search API via ``google_search``. A trailing
    non-breaking-space ellipsis ("\xa0...") is stripped from each snippet, and
    items lacking a title or snippet yield empty strings for those fields.

    NOTE(review): the Custom Search API documents ``num`` as 1-10; larger values
    are presumably rejected by the service — confirm before passing them.
    """
    truncation_marker = "\xa0..."
    raw_results: List[Dict[str, Any]] = google_search(
        search_term=query, num=num_results
    )[:num_results]
    for item in raw_results:
        # Work on a local copy of the snippet instead of mutating the API dict.
        snippet = item.get("snippet", "")
        if snippet.endswith(truncation_marker):
            snippet = snippet[: -len(truncation_marker)]
        yield SearchResult(
            title=item.get("title", ""),
            body=snippet,
            url=item["link"],
        )


def format_search_result(search_result: Iterable[SearchResult]) -> str:
    """Format search results as a numbered text block for the prompt.

    Each result becomes a ``[i] <body>`` line, then a ``URL: <url>`` line, then a
    blank line, with ``i`` counted from 0. Returns an empty string when there
    are no results.
    """
    return "".join(
        f"[{i}] {result.body}\nURL: {result.url}\n\n"
        for i, result in enumerate(search_result)
    )


def rewrite_prompt(
    prompt: str,
    num_results: int = 5,
) -> str:
    """Rewrite ``prompt`` by prepending web search results to it.

    The returned text joins, with newlines: the formatted search results, the
    current local date (dd/mm/YYYY), ``INSTRUCTIONS`` and the original query.

    Args:
        prompt: The user's query; it is also used as the search term.
        num_results: How many search results to include (default 5).
    """
    raw_results = get_web_search_results(query=prompt, num_results=num_results)
    formatted_results = "Web search results:\n" + format_search_result(raw_results)
    # datetime.now() is naive local time, i.e. the server's date.
    formatted_date = "Current date: " + datetime.now().strftime("%d/%m/%Y")
    formatted_prompt = f"Query: {prompt}"
    return "\n".join([formatted_results, formatted_date, INSTRUCTIONS, formatted_prompt])