File size: 8,375 Bytes
e3b816c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f3de9b9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import chat.arxiv_bot.arxiv_bot_utils as utils
import google.generativeai as genai
import json
import os
from google.generativeai.types import content_types
from collections.abc import Iterable
from IPython import display
from IPython.display import Markdown

# ----------------------- define instructions -----------------------
# System prompt for the Gemini model; init_model() passes it as
# `system_instruction`. The four sentences are joined with newlines.
system_instruction = "\n".join((
    "You are a library chatbot that help people to find relevant articles about a topic, or find a specific article with given title and authors.",
    "Your job is to analyze the user question, generate enough parameters based on the user question and use the tools that are given to you.",
    "Also, after the function call is done, you must post-process the results in a more conversational form, providing some explanation about the paper based on its summary to avoid recitation.",
    "You must provide the link to its Arxiv pdf page.",
))

# --------------------------- define tools --------------------------
def search_for_relevant_article(keywords: list[str], topic_description: str) -> str:
    """This tool is used to search the database for articles relevant to a topic, using a list of at least 3 keywords and a long, one-sentence topic description.
    If the question does not provide at least 3 keywords, the model must generate more keywords related to the topic.
    If the question has no description of the topic, the model must generate one for the function call.
    \nThe result is a string describing the records found in the database: 'Record no. - Title: <title>, Author: <authors>, Link: <link to the pdf file>, Summary: <summary of the article>'. There can be many records.
    \nIf the result is 'Information not found', an error occurred or the database has no relevant article."""
    # NOTE: this function is handed to Gemini through `tools=...`; its docstring
    # presumably serves as the tool description the model sees, so it is written
    # for the model rather than for maintainers.
    print('Keywords: {}, description: {}'.format(keywords, topic_description))

    results = utils.ArxivChroma.query_relevant(keywords=keywords, query_texts=topic_description)
    metadatas = results['metadatas'][0]  # metadata of the hits for the (single) query
    if not metadatas:
        # Nothing relevant stored locally: crawl arXiv, store the new records in
        # both the vector store and SQL, then repeat the vector query.
        new_records = utils.crawl_arxiv(keyword_list=keywords, max_results=10)
        # A str return from the crawler is treated as a failure message.
        if isinstance(new_records, str):
            return "Information not found"
        utils.ArxivChroma.add(new_records)
        utils.ArxivSQL.add(new_records)
        results = utils.ArxivChroma.query_relevant(keywords=keywords, query_texts=topic_description)
        metadatas = results['metadatas'][0]

    paper_ids = [meta['paper_id'] for meta in metadatas]
    paper_info = utils.ArxivSQL.query_id(paper_ids)
    if not paper_info:
        return "Information not found"

    parts = []
    for record_no, row in enumerate(paper_info, start=1):
        # Row columns used: 0 = paper id, 2 = title, 3 = authors, 6 = pdf link.
        parts.append("Record no.{} - Title: {}, Author: {}, Link: {}, ".format(
            record_no, row[2], row[3], row[6]))
        documents = utils.ArxivChroma.query_exact(row[0])["documents"]
        parts.append("Summary:" + "".join(doc + " " for doc in documents))
    return "".join(parts)

def search_for_specific_article(title: str, authors: list[str]) -> str:
    """This tool is used to search the database for a specific article, given its title and authors.
    \nThe result is a string describing the records found in the database: 'Record no. - Title: <title>, Author: <authors>, Link: <link to the pdf file>, Summary: <summary of the article>'. There can be many records.
    \nIf the result is 'Information not found', an error occurred or the database has no relevant article."""
    # NOTE: this function is handed to Gemini through `tools=...`; its docstring
    # presumably serves as the tool description the model sees.
    print('Title: {}, authors: {}'.format(title, authors))

    paper_info = utils.ArxivSQL.query(title=title, author=authors)
    if not paper_info:
        # Only crawl when the paper is NOT already stored. Crawling on a hit
        # would re-add duplicates, and skipping the crawl on a miss would
        # never find new papers.
        new_records = utils.crawl_exact_paper(title=title, author=authors)
        # A str return from the crawler is treated as a failure message.
        if isinstance(new_records, str):
            return "Information not found"
        utils.ArxivChroma.add(new_records)
        utils.ArxivSQL.add(new_records)
        paper_info = utils.ArxivSQL.query(title=title, author=authors)

    if not paper_info:
        return "Information not found"

    parts = []
    for record_no, row in enumerate(paper_info, start=1):
        # Row columns used: 0 = paper id, 2 = title, 3 = authors, 6 = pdf link.
        parts.append("Record no.{} - Title: {}, Author: {}, Link: {}, ".format(
            record_no, row[2], row[3], row[6]))
        documents = utils.ArxivChroma.query_exact(row[0])["documents"]
        parts.append("Summary:" + "".join(doc + " " for doc in documents))
    return "".join(parts)

def answer_others_questions(question: str) -> str:
    """This tool is the default option for other questions that are not related to article or paper requests. The model will respond to the question with its own answer."""
    # The question is returned untouched; presumably the model then writes the
    # actual reply itself from this tool result.
    echoed_question = question
    return echoed_question

# Functions exposed to the model as callable tools, plus their names in the
# same order (derived from the functions so the two lists cannot drift apart).
tools = [search_for_relevant_article, search_for_specific_article, answer_others_questions]
tools_name = [tool_fn.__name__ for tool_fn in tools]

# load key, prepare config ------------------------
# The API key is read from the first line of a local apikey.txt when present;
# otherwise it comes from the API_KEY environment variable.
if os.path.exists('apikey.txt'):
    with open("apikey.txt", "r", encoding="utf-8") as apikey:
        # readline() keeps the trailing newline; strip it so it does not end up
        # inside the key sent to the API.
        key = apikey.readline().strip()
else:
    key = os.environ.get('API_KEY')
genai.configure(api_key=key)
# Sampling / output settings passed to genai.GenerativeModel in init_model().
generation_config = dict(
    temperature=1,
    top_p=1,
    top_k=0,
    max_output_tokens=2048,
    response_mime_type="text/plain",
)

# Every listed harm category is set to BLOCK_NONE (no safety blocking).
# NOTE(review): "HARM_CATEGORY_DANGEROUS" looks like a legacy category next to
# "HARM_CATEGORY_DANGEROUS_CONTENT"; confirm the Gemini API accepts it.
safety_settings = [
    {"category": harm_category, "threshold": "BLOCK_NONE"}
    for harm_category in (
        "HARM_CATEGORY_DANGEROUS",
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
]
# Build a tool_config whose function-calling mode is e.g. 'none', 'any' or 'auto'.
def tool_config_from_mode(mode: str, fns: Iterable[str] = ()):
    """Return a tool config with the given function-calling mode.

    Args:
        mode: function-calling mode string, e.g. 'none', 'any' or 'auto'.
        fns: function names passed through unchanged as
            ``allowed_function_names`` (empty tuple by default).
    """
    calling_config = {"mode": mode, "allowed_function_names": fns}
    return content_types.to_tool_config({"function_calling_config": calling_config})

def init_model(mode="auto"):
    """Create a Gemini model wired to this module's tools and settings, and start a chat.

    Per the original design notes, each socket session is meant to hold its own
    model, so this is called when a socket is initialised.

    Args:
        mode: function-calling mode forwarded to tool_config_from_mode().

    Returns:
        (model, chat_instance) where chat_instance was started with automatic
        function calling enabled.
    """
    model_kwargs = dict(
        model_name="gemini-1.5-flash-latest",
        safety_settings=safety_settings,
        generation_config=generation_config,
        tools=tools,
        tool_config=tool_config_from_mode(mode),
        system_instruction=system_instruction,
    )
    model = genai.GenerativeModel(**model_kwargs)
    return model, model.start_chat(enable_automatic_function_calling=True)

# handle tool call and chat session
def full_chain_history_question(user_input, chat_instance: genai.ChatSession, mode="auto"):
    """Send one user message through the chat session.

    Args:
        user_input: the message to send.
        chat_instance: the session created by init_model().
        mode: function-calling mode applied to this message only.

    Returns:
        (reply_text, chat_instance.history). On any exception the error is
        printed and its text is returned in place of the reply.
    """
    try:
        per_call_config = tool_config_from_mode(mode)
        reply_text = chat_instance.send_message(user_input, tool_config=per_call_config).text
        return reply_text, chat_instance.history
    except Exception as err:
        print(err)
        return f'Error occured during call: {err}', chat_instance.history

# for printing a chat session log
def print_history(history):
    """Print each turn of a chat history, one divider line after every turn.

    Only the first part of each turn is shown, converted with its type's
    ``to_dict`` classmethod.
    """
    divider = '-' * 80
    for turn in history:
        first_part = turn.parts[0]
        part_type = type(first_part)
        print(turn.role, "->", part_type.to_dict(first_part))
        print(divider)

# Connect the two storage backends used by the tool functions when this module
# is imported: the Chroma (vector) store first, then the SQL store.
utils.ArxivChroma.connect()
utils.ArxivSQL.connect()