"""Literature-review chatbot.

Searches Semantic Scholar for papers matching a user query, optionally lets
Gemini rewrite the query first, then asks Gemini to produce a review of
related literature. Served as a Gradio chat interface.
"""

import argparse
import os
import re
import textwrap

import google.generativeai as genai
import gradio as gr
import requests
from IPython.display import Markdown

# Configure Gemini from the environment; '-1' sentinel mirrors the original
# behaviour of always calling configure() even when no key is set.
gemini_api_key = os.environ.get('GEMINI_API_KEY', '-1')
genai.configure(api_key=gemini_api_key)

S2_API_KEY = os.getenv('S2_API_KEY')

initial_result_limit = 10  # papers fetched from Semantic Scholar
final_result_limit = 5     # papers kept after filtering and ranking

# Fields to request from the Semantic Scholar API.
# (Fixed: 'journal' was listed twice in the original.)
fields = 'title,url,abstract,citationCount,journal,isOpenAccess,fieldsOfStudy,year'

# Whether Gemini should rewrite the user's query before searching.
# Module-level default so `predict` works when this file is imported;
# overridden from the CLI in the __main__ block below.
optimize_query = False


def raw_to_markdown(text):
    """Render plain text as an indented IPython Markdown blockquote.

    Args:
        text: Plain text; bullet characters are converted to markdown bullets.

    Returns:
        An IPython.display.Markdown object with every line quoted with '> '.
    """
    text = text.replace('•', ' *')
    return Markdown(textwrap.indent(text, '> ', predicate=lambda _: True))


def markdown_to_raw(markdown_text):
    """Convert basic markdown text to raw text.

    Args:
        markdown_text: The markdown text string to be converted.

    Returns:
        A string containing the raw text equivalent of the markdown text.
    """
    # Remove headers
    text = re.sub(r'#+ ?', '', markdown_text)
    # Remove bold and italics (can be adjusted based on needs)
    text = re.sub(r'\*\*(.+?)\*\*', r'\1', text)  # Bold
    text = re.sub(r'_(.+?)_', r'\1', text)  # Italics
    # Remove code blocks (drops the code content, matching original behaviour)
    text = re.sub(r'`(.*?)`', '', text, flags=re.DOTALL)
    # Remove unordered-list markers
    text = re.sub(r'\*+ (.*?)$', r'\1', text, flags=re.MULTILINE)
    # BUG FIX: str.strip() returns a new string; the original discarded the
    # result, so surrounding whitespace was never removed.
    text = text.strip()
    return text


def find_basis_papers(query):
    """Search Semantic Scholar and return the best matching open-access papers.

    Args:
        query: Free-text search query. Falsy values short-circuit.

    Returns:
        A list of up to `final_result_limit` paper dicts, ranked by year then
        citation count (both descending), or None when there is no query or
        no results.

    Raises:
        requests.HTTPError: If the Semantic Scholar API call fails.
    """
    if not query:
        print('No query given')
        return None
    rsp = requests.get(
        'https://api.semanticscholar.org/graph/v1/paper/search',
        headers={'X-API-KEY': S2_API_KEY},
        params={'query': query, 'limit': initial_result_limit, 'fields': fields},
    )
    rsp.raise_for_status()
    results = rsp.json()
    total = results["total"]
    if not total:
        print('No matches found. Please try another query.')
        return None
    print(f'Found {total} initial results. \nShowing up to {initial_result_limit}.')
    papers = results['data']
    # Keep only open-access papers that actually have an abstract.
    filtered_papers = list(filter(isValidPaper, papers))
    # Rank newest-first, then most-cited. ROBUSTNESS FIX: the API may return
    # null for year/citationCount; `or 0` prevents a TypeError during sort.
    ranked_papers = sorted(
        filtered_papers,
        key=lambda p: (p['year'] or 0, p['citationCount'] or 0),
        reverse=True,
    )
    # CONSISTENCY FIX: use the declared final_result_limit instead of a
    # hard-coded 5.
    return ranked_papers[:final_result_limit]


def isValidPaper(paper):
    """Return True if the paper is open access and has a non-empty abstract."""
    return bool(paper['isOpenAccess'] and paper['abstract'])


def GEMINI_optimize_query(initial_query: str):
    """Ask Gemini to rewrite a search query for academic-paper retrieval.

    Args:
        initial_query: The user's raw search query.

    Returns:
        The optimized query as plain text.
    """
    model = genai.GenerativeModel('gemini-pro')
    chat = model.start_chat(history=[])
    prompt = f"""Given a search query, return an optimized version of the query to find related academic papers QUERY: {initial_query}. Only return the optimized query"""
    response = chat.send_message(prompt)
    optimized_query = markdown_to_raw(response.text)
    return optimized_query


def GEMINI_summarize_abstracts(initial_query: str, papers: str):
    """Ask Gemini for a review of related literature over the given papers.

    Args:
        initial_query: The search query the papers were retrieved for.
        papers: The papers (stringified) to summarize.

    Returns:
        The literature review as plain text.
    """
    model = genai.GenerativeModel('gemini-pro')
    chat = model.start_chat(history=[])
    prompt = f"""Given the following academic papers, return a review of related literature for the search query: {initial_query}. Ignore papers without abstracts. 
Here are the papers {papers} """
    response = chat.send_message(prompt)
    abstract_summary = markdown_to_raw(response.text)
    return abstract_summary


def create_gemini_model():
    """Create a Gemini model and a fresh chat session.

    Returns:
        A (model, chat) tuple.
    """
    model = genai.GenerativeModel('gemini-pro')
    chat = model.start_chat(history=[])
    return model, chat


# Persistent chat sessions: one for summarization (carries conversation
# history across chat turns), one for query optimization.
summarizer_model, summarizer_chat = create_gemini_model()
query_optimizer_model, query_optimizer_chat = create_gemini_model()


def predict(message, history):
    """Gradio chat handler.

    On the first turn (empty history) the message is treated as a research
    query: optionally optimized, used to fetch papers, and summarized into a
    literature review. Later turns are forwarded to the summarizer chat.

    Args:
        message: The user's chat message.
        history: Gradio chat history; [] on the first turn.

    Returns:
        The assistant's reply as plain text.
    """
    if history == []:
        query = message
        print(f"INITIAL QUERY: {query}")
        if optimize_query:
            optimizer_prompt = f"""Given a search query, return an optimized version of the query to find related academic papers QUERY: {query}. Only return the optimized query"""
            response = query_optimizer_chat.send_message(optimizer_prompt)
            query = markdown_to_raw(response.text)
            print(f"OPTIMIZED QUERY: {query}")
        # (Possibly optimized) query used to search Semantic Scholar.
        papers = find_basis_papers(query)
        summarizer_prompt = f"""Given the following academic papers, return a review of related literature for the search query: {query}. Focus on data/key factors and methodologies considered. Here are the papers {papers} Include the paper urls at the end of the review of related literature. """
        response = summarizer_chat.send_message(summarizer_prompt)
        abstract_summary = markdown_to_raw(response.text)
        return abstract_summary
    # Follow-up turn: continue the existing summarizer conversation.
    response = summarizer_chat.send_message(message)
    response_text = markdown_to_raw(response.text)
    return response_text


def main():
    """Launch the Gradio chat interface."""
    gr.ChatInterface(
        predict,
        title="LLM Research Helper",
        description="""Start by inputting a brief description/title of your research and our assistant will return a review of related literature ex. 
Finding optimal site locations for solar farms""",
        examples=['Finding optimal site locations for solar farms',
                  'Wildfire prediction',
                  'Fish yield prediction'],
    ).launch()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Literature review chatbot")
    parser.add_argument("-o", "--optimize_query",
                        help="Use query optimization (True, False)",
                        default=False)
    args = parser.parse_args()
    # BUG FIX: argparse delivers user-supplied values as strings, so the
    # original test `args.optimize_query in [True, False]` was False for any
    # CLI value, making -o a no-op. Accept common truthy spellings instead.
    optimize_query = str(args.optimize_query).lower() in ('true', '1', 'yes')
    main()