Spaces:
Runtime error
Runtime error
# my_app/model_manager.py | |
import google.generativeai as genai | |
import chat.arxiv_bot.arxiv_bot_utils as utils | |
import json | |
# Module-level cache for the generative model; stays None until create_model() assigns it.
model = None
def create_model():
    """Configure the Gemini client and build the shared ``gemini-pro`` model.

    Reads the API key from the first line of ``apikey.txt`` (opened relative
    to the current working directory), configures ``google.generativeai``
    with it, prints the name of every model that supports
    ``generateContent``, and stores the new model in the module-level
    ``model`` variable.

    Returns:
        The ``genai.GenerativeModel`` instance that was just created.

    Raises:
        OSError: If ``apikey.txt`` cannot be opened.
    """
    global model

    with open("apikey.txt", "r") as apikey:
        # readline() keeps the trailing newline; strip it so the key that is
        # handed to the client contains no stray whitespace.
        key = apikey.readline().strip()
    genai.configure(api_key=key)

    # Diagnostic listing of the models usable for content generation.
    for m in genai.list_models():
        if 'generateContent' in m.supported_generation_methods:
            print(m.name)
    print("He was there")

    config = genai.GenerationConfig(max_output_tokens=2048,
                                    temperature=0.7)
    # Every listed harm category is configured with the BLOCK_NONE threshold.
    safety_settings = [
        {"category": category, "threshold": "BLOCK_NONE"}
        for category in (
            "HARM_CATEGORY_DANGEROUS",
            "HARM_CATEGORY_HARASSMENT",
            "HARM_CATEGORY_HATE_SPEECH",
            "HARM_CATEGORY_SEXUALLY_EXPLICIT",
            "HARM_CATEGORY_DANGEROUS_CONTENT",
        )
    ]
    model = genai.GenerativeModel("gemini-pro",
                                  generation_config=config,
                                  safety_settings=safety_settings)
    return model
def get_model():
    """Return the shared model, creating it on first use.

    Returns:
        The module-level ``model``; when it is still ``None`` it is first
        populated with the result of ``create_model()``.
    """
    global model
    if model is not None:
        return model
    # Nothing cached yet: build the model now and remember it.
    model = create_model()
    return model
def extract_keyword_prompt(query):
    """Build the first-stage prompt that asks the LLM for a JSON argument block.

    Args:
        query: The guest's question; it is inserted verbatim after
            ``QUESTION:`` in the prompt.

    Returns:
        The complete prompt string, instructing the model to answer with
        JSON holding either "title"/"author", "keywords"/"description",
        or "answer".
    """
    template = """[INST] SYSTEM: You are an assistant that choose only one action below based on guest question.
1. If the guest question is asking for a single specific document or article with explicit title, you need to respond the information in JSON format with 2 keys "title", "author" if found any above. The authors are separated with the word 'and'.
2. If the guest question is asking for relevant informations about a topic, you need to respond the information in JSON format with 2 keys "keywords", "description", include a list of keywords represent the main academic topic, \
and a description about the main topic. You may paraphrase the keywords to add more. \
3. If the guest is not asking for any informations or documents, you need to respond with a polite answer in JSON format with 1 key "answer".
QUESTION: '{query}'
[/INST]
ANSWER:
"""
    # Only the template is parsed for placeholders, so braces inside the
    # query itself are inserted literally.
    return template.format(query=query)
def make_answer_prompt(input, contexts):
    """Build the second-stage prompt that asks the LLM for the final answer.

    Args:
        input: The guest's question, inserted after ``QUESTION:``.
        contexts: The information retrieved for the question, inserted
            after ``INFORMATION:``.

    Returns:
        The complete prompt string.
    """
    template = """[INST] You are a library assistant that help to search articles and documents based on user's question.
From guest's question, you have found some records and documents that may help. Now you need to answer the guest with the information found.
If no information found in the database, you may generate some other recommendation related to user's question using your own knowledge. Each article or paper must have a link to the pdf download page.
You should answer in a conversational form politely.
QUESTION: '{input}'
INFORMATION: '{contexts}'
[/INST]
ANSWER:
"""
    return template.format(input=input, contexts=contexts)
def _summarize_papers(paper_info):
    """Turn database rows into the (context string, records) pair.

    Columns 2, 3 and 6 of each row are used as title, author and link.
    Returns the "Information not found" pair when there are no rows.
    """
    if not paper_info:
        return "Information not found", "Information not found"
    records = [[row[2], row[3], row[6]] for row in paper_info]
    # One record per line so a link never runs into the next record's title.
    result_string = "\n".join(
        "Title: {}, Author: {}, Link: {}".format(*record) for record in records
    )
    return result_string, records


def response(args, db_instance):
    """Create the response context for the arguments extracted by the LLM.

    Args:
        args: Parsed JSON from the first LLM call. Three shapes are handled:
            an "answer" key (direct reply), "keywords" + "description"
            (topic search), or "title" + "author" (exact paper lookup).
        db_instance: Record database exposing ``add``, ``query`` and
            ``query_id``.

    Returns:
        A ``(contexts, records)`` pair. ``records`` is ``None`` for a direct
        answer, a list of ``[title, author, link]`` lists when papers were
        found, and a non-empty string when nothing was found or the crawler
        reported an error.

    Raises:
        KeyError: If a "keywords" request lacks "description" or a "title"
            request lacks "author".
    """
    not_found = ("Information not found", "Information not found")

    # Anything that is not a JSON object cannot carry the expected keys.
    if not isinstance(args, dict):
        return not_found

    if "answer" in args:
        return args['answer'], None  # direct reply, nothing to look up

    if "keywords" in args:
        query_texts = args["description"]
        keywords = args["keywords"]
        results = utils.db.query_relevant(keywords=keywords, query_texts=query_texts)
        hits = results['metadatas'][0]
        if not hits:
            # Nothing stored for this topic yet: crawl, store, then re-query.
            new_records = utils.crawl_arxiv(keyword_list=keywords, max_results=10)
            # A string result from the crawler is treated as an error message.
            if isinstance(new_records, str):
                return "Error occured, information not found", new_records
            print("Got new records: ", len(new_records))
            utils.db.add(new_records)
            db_instance.add(new_records)
            results = utils.db.query_relevant(keywords=keywords, query_texts=query_texts)
            hits = results['metadatas'][0]
            print("Re-queried on chromadb, results: ", hits)
        paper_ids = [hit['paper_id'] for hit in hits]
        paper_info = db_instance.query_id(paper_ids)
        print(paper_info)
        return _summarize_papers(paper_info)

    if "title" in args:
        title = args['title']
        authors = utils.authors_str_to_list(args['author'])
        paper_info = db_instance.query(title=title, author=authors)
        if not paper_info:
            # Not stored yet: crawl for this exact paper, store, then re-query.
            new_records = utils.crawl_exact_paper(title=title, author=authors)
            if isinstance(new_records, str):
                return "Error occured, information not found", "Information not found"
            print("Got new records: ", len(new_records))
            utils.db.add(new_records)
            db_instance.add(new_records)
            paper_info = db_instance.query(title=title, author=authors)
            print("Re-queried on chromadb, results: ", paper_info)
        return _summarize_papers(paper_info)

    # None of the expected keys: return a pair so callers can always unpack.
    return not_found
def full_chain_single_question(input_prompt, db_instance):
    """Answer a single question through the extract -> retrieve -> answer chain.

    Args:
        input_prompt: The guest's question.
        db_instance: Record database forwarded to ``response``.

    Returns:
        A pair of strings:
        - ``("Random question, direct return", contexts)`` when ``response``
          yields no records;
        - ``(raw first-stage LLM output, final answer)`` on success;
        - ``(raw first-stage LLM output or "", "Error occured: ...")`` when
          any step raises.
    """
    # Bound before the try so the error path can always return it, even when
    # the failure happens before the first LLM call completes.
    temp_answer = ""
    try:
        # Use the lazy accessor so an uninitialised module-level model is
        # created on demand instead of failing on None.
        llm = get_model()
        first_prompt = extract_keyword_prompt(input_prompt)
        temp_answer = llm.generate_content(first_prompt).text
        args = json.loads(utils.trimming(temp_answer))
        contexts, results = response(args, db_instance)
        if not results:
            return "Random question, direct return", contexts
        output_prompt = make_answer_prompt(input_prompt, contexts)
        answer = llm.generate_content(output_prompt).text
        return temp_answer, answer
    except Exception as e:
        # Top-level boundary: report the failure as text instead of raising.
        return temp_answer, "Error occured: " + str(e)
def format_chat_history_from_web(chat_history: list):
    """Convert web chat messages into ``role``/``parts`` dictionaries.

    Args:
        chat_history: Messages, each a mapping with "role" and "content".

    Returns:
        A new list with one ``{"role": ..., "parts": [content]}`` dict per
        message, in the same order.
    """
    return [
        {"role": entry["role"], "parts": [entry["content"]]}
        for entry in chat_history
    ]
def full_chain_history_question(chat_history: list, db_instance):
    """Answer the latest question of a chat, sending the whole history to the LLM.

    Args:
        chat_history: Web chat messages (mappings with "role" and
            "content"); the last one is treated as the current question.
        db_instance: Record database forwarded to ``response``.

    Returns:
        A pair of strings:
        - ``("Random question, direct return", contexts)`` when ``response``
          yields no records;
        - ``(raw first-stage LLM output, final answer)`` on success;
        - ``(raw first-stage LLM output or "", "Error occured: ...")`` when
          any step raises, including an empty ``chat_history``.
    """
    # Bound before the try so the error path can always return it, even when
    # the failure happens before the first LLM call completes.
    temp_answer = ""
    try:
        # Use the lazy accessor so an uninitialised module-level model is
        # created on demand instead of failing on None.
        llm = get_model()
        temp_chat = format_chat_history_from_web(chat_history)
        question = temp_chat[-1]["parts"][0]
        first_prompt = extract_keyword_prompt(question)
        temp_answer = llm.generate_content(first_prompt).text
        args = json.loads(utils.trimming(temp_answer))
        contexts, results = response(args, db_instance)
        if not results:
            return "Random question, direct return", contexts
        # Swap the last turn for the retrieval-augmented prompt, keeping
        # "parts" a list like every other turn in the history.
        temp_chat[-1]["parts"] = [make_answer_prompt(question, contexts)]
        print(temp_chat)
        answer = llm.generate_content(temp_chat).text
        return temp_answer, answer
    except Exception as e:
        # Top-level boundary: report the failure as text instead of raising.
        return temp_answer, "Error occured: " + str(e)