# arxiv_chatbot/chat/model_manage.py
# (Hugging Face file-view residue) author: tosanoob; last commit e3b816c (verified):
# "Update to use environment variables apikey"; views: raw / history / blame; size 8.38 kB
import chat.arxiv_bot.arxiv_bot_utils as utils
import google.generativeai as genai
import json
import os
from google.generativeai.types import content_types
from collections.abc import Iterable
from IPython import display
from IPython.display import Markdown
# ----------------------- define instructions -----------------------
# System prompt given to the Gemini model in init_model(). One entry per line;
# the entries are joined with "\n" exactly as a triple-quoted literal would be.
system_instruction = "\n".join([
    "You are a library chatbot that help people to find relevant articles about a topic, or find a specific article with given title and authors.",
    "Your job is to analyze the user question, generate enough parameters based on the user question and use the tools that are given to you.",
    "Also, after the function call is done, you must post-process the results in a more conversational form, providing some explanation about the paper based on its summary to avoid recitation.",
    "You must provide the link to its Arxiv pdf page.",
])
# --------------------------- define tools --------------------------
def search_for_relevant_article(keywords: list['str'], topic_description: str) -> str:
    """This tool is used to search for articles from the database which is relevant to a topic, using a list of more than 3 keywords and a long sentence topic description.
    If there is not enough 3 keywords from the question, the model must generate more keywords related to the topic.
    If there is no description about the topic, the model must generate a description for the function call.
    \nThe result is a string describe the records found from the database: 'Record no. - Title: <title>, Author: <authors>, Link: <link to the pdf file>, Summary: <summary of the article>'. There can be many records.
    \nIf the result is 'Information not found' it means some error has occured, or the database has no relevant article"""
    # NOTE: this function is registered in `tools` and handed to the Gemini SDK,
    # which builds the tool declaration from its signature and docstring. The
    # docstring above is therefore part of the prompt and is kept verbatim.
    #
    # Flow: semantic lookup in Chroma -> if nothing comes back, crawl arXiv for
    # the keywords, index the new records in Chroma and SQL, and query again
    # -> resolve hits to SQL rows -> render one text record per row.
    print('Keywords: {}, description: {}'.format(keywords, topic_description))
    results = utils.ArxivChroma.query_relevant(keywords=keywords, query_texts=topic_description)
    hits = results['metadatas'][0]
    if not hits:
        # Nothing indexed for this topic yet: crawl some papers and retry once.
        new_records = utils.crawl_arxiv(keyword_list=keywords, max_results=10)
        # A str result from the crawler is treated as a failure.
        if isinstance(new_records, str):
            return "Information not found"
        utils.ArxivChroma.add(new_records)
        utils.ArxivSQL.add(new_records)
        results = utils.ArxivChroma.query_relevant(keywords=keywords, query_texts=topic_description)
        hits = results['metadatas'][0]

    paper_ids = [hit['paper_id'] for hit in hits]
    paper_info = utils.ArxivSQL.query_id(paper_ids)
    if not paper_info:
        return "Information not found"

    parts = []
    for number, row in enumerate(paper_info, start=1):
        # Columns used: [0] paper id (looked up in Chroma), [2] title,
        # [3] authors, [6] pdf link.
        parts.append("Record no.{} - Title: {}, Author: {}, Link: {}, ".format(number, row[2], row[3], row[6]))
        documents = utils.ArxivChroma.query_exact(row[0])["documents"]
        # Each stored document is followed by a single space, as before.
        parts.append("Summary:" + "".join(doc + " " for doc in documents))
    return "".join(parts)
def search_for_specific_article(title: str, authors: list['str']) -> str:
    """This tool is used to search for a specific article from the database, with its name and authors given.
    \nThe result is a string describe the records found from the database: 'Record no. - Title: <title>, Author: <authors>, Link: <link to the pdf file>, Summary: <summary of the article>'. There can be many records.
    \nIf the result is 'Information not found' it means some error has occured, or the database has no relevant article"""
    # NOTE: this function is registered in `tools` and handed to the Gemini SDK,
    # which builds the tool declaration from its signature and docstring. The
    # docstring above is therefore part of the prompt and is kept verbatim.
    #
    # Flow: look the paper up in SQL by title/authors -> only on a miss, crawl
    # arXiv for that exact paper, index it in Chroma and SQL, and look it up
    # again -> render one text record per matching row.
    print('Title: {}, authors: {}'.format(title, authors))
    paper_info = utils.ArxivSQL.query(title=title, author=authors)
    if not paper_info:
        # Crawl only when the local lookup found nothing. A hit needs no network
        # round-trip, and re-adding an already stored paper would duplicate it.
        new_records = utils.crawl_exact_paper(title=title, author=authors)
        # A str result from the crawler is treated as a failure.
        if isinstance(new_records, str):
            return "Information not found"
        utils.ArxivChroma.add(new_records)
        utils.ArxivSQL.add(new_records)
        paper_info = utils.ArxivSQL.query(title=title, author=authors)

    if not paper_info:
        return "Information not found"

    parts = []
    for number, row in enumerate(paper_info, start=1):
        # Columns used: [0] paper id (looked up in Chroma), [2] title,
        # [3] authors, [6] pdf link.
        parts.append("Record no.{} - Title: {}, Author: {}, Link: {}, ".format(number, row[2], row[3], row[6]))
        documents = utils.ArxivChroma.query_exact(row[0])["documents"]
        # Each stored document is followed by a single space, as before.
        parts.append("Summary:" + "".join(doc + " " for doc in documents))
    return "".join(parts)
def answer_others_questions(question: str) -> str:
    """This tool is the default option for other questions that are not related to article or paper request. The model will response the question with its own answer."""
    # Fallback tool: it does no lookup and hands the question back unchanged,
    # leaving the model to answer it. The docstring is the tool description the
    # Gemini SDK sees (the function is listed in `tools`), so it stays verbatim.
    echoed = question
    return echoed
# Functions exposed to the model as callable tools (passed to GenerativeModel).
tools = [
    search_for_relevant_article,
    search_for_specific_article,
    answer_others_questions,
]
# Their names in the same order, taken from the functions themselves so the
# two lists cannot drift apart.
tools_name = [tool.__name__ for tool in tools]
# load key, prepare config ------------------------
# The API key comes from the first line of a local apikey.txt when that file
# exists. Otherwise it comes from the API_KEY environment variable (None if unset).
if os.path.exists('apikey.txt'):
    with open("apikey.txt", "r", encoding="utf-8") as apikey:
        # readline() keeps the trailing newline (and any stray whitespace).
        # Strip it so it is not sent as part of the key.
        key = apikey.readline().strip()
else:
    key = os.environ.get('API_KEY')
genai.configure(api_key=key)
# Sampling / output settings passed to GenerativeModel in init_model().
# NOTE(review): top_k=0 is passed through as-is; confirm how the model treats it.
generation_config = dict(
    temperature=1,
    top_p=1,
    top_k=0,
    max_output_tokens=2048,
    response_mime_type="text/plain",
)
# One entry per harm category, each with threshold "BLOCK_NONE".
# NOTE(review): "HARM_CATEGORY_DANGEROUS" looks like a legacy (PaLM-era) category
# listed alongside "HARM_CATEGORY_DANGEROUS_CONTENT". Confirm that the Gemini model
# accepts it.
safety_settings = [
    {"category": category, "threshold": "BLOCK_NONE"}
    for category in (
        "HARM_CATEGORY_DANGEROUS",
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
]
# this function return a tool_config with mode 'none', 'any', 'auto'
def tool_config_from_mode(mode: str, fns: Iterable[str] = ()):
    """Create a tool config with the specified function calling mode.

    Args:
        mode: function-calling mode, one of 'none', 'any' or 'auto'.
        fns: function names forwarded as ``allowed_function_names``
            (empty by default).

    Returns:
        The object produced by ``content_types.to_tool_config`` for the
        assembled ``function_calling_config`` dict.
    """
    calling_config = {"mode": mode, "allowed_function_names": fns}
    return content_types.to_tool_config({"function_calling_config": calling_config})
def init_model(mode="auto"):
    """Create a Gemini model wired to this module's tools and open a chat on it.

    Every socket session is meant to hold its own model. Call this when the
    socket is initialised; it also calls ``start_chat()`` to begin the chat.

    Args:
        mode: function-calling mode handed to ``tool_config_from_mode``.

    Returns:
        ``(model, chat_instance)``, where the chat has automatic function
        calling enabled.
    """
    model = genai.GenerativeModel(
        model_name="gemini-1.5-flash-latest",
        safety_settings=safety_settings,
        generation_config=generation_config,
        tools=tools,
        tool_config=tool_config_from_mode(mode),
        system_instruction=system_instruction,
    )
    return model, model.start_chat(enable_automatic_function_calling=True)
# handle tool call and chatsession
def full_chain_history_question(user_input, chat_instance: genai.ChatSession, mode="auto"):
    """Send one user message through the chat and return the reply with history.

    Args:
        user_input: the message forwarded to ``chat_instance.send_message``.
        chat_instance: the chat session created by ``init_model``.
        mode: function-calling mode for this turn (see ``tool_config_from_mode``).

    Returns:
        ``(text, chat_instance.history)``. On any exception, ``text`` is an
        error message instead of the model's reply.
    """
    try:
        turn_config = tool_config_from_mode(mode)
        reply = chat_instance.send_message(user_input, tool_config=turn_config)
        return reply.text, chat_instance.history
    except Exception as err:
        # Boundary handler: log the failure and report it as the reply text
        # instead of raising, so the caller still gets the chat history.
        print(err)
        return f'Error occured during call: {err}', chat_instance.history
# for printing log session
def print_history(history):
    """Print a chat history for logging, one entry per turn.

    For each turn, print ``role -> <first part as dict>`` and then a line of
    80 dashes. Only the first part of each turn is shown.
    """
    separator = '-' * 80
    for turn in history:
        first_part = turn.parts[0]
        # Convert via the part's own class method ``to_dict``.
        as_dict = type(first_part).to_dict(first_part)
        print(turn.role, "->", as_dict)
        print(separator)
# Connect both stores at import time, Chroma first and then SQL. The tool
# functions in this module query them.
for _store in (utils.ArxivChroma, utils.ArxivSQL):
    _store.connect()
del _store