Artteiv tosanoob commited on
Commit
9b87900
1 Parent(s): e46a8b9

Fix model retrieval and improve answering by adding summary context (#5)

Browse files

- Fix model retrieval and improve answering by adding summary context (681a1427df5b7b3e66349fd6d4caafd23aee82ae)


Co-authored-by: Trương Tấn Cường <tosanoob@users.noreply.huggingface.co>

Files changed (1) hide show
  1. chat/model_manage.py +133 -64
chat/model_manage.py CHANGED
@@ -5,6 +5,25 @@ import json
5
 
6
  model = None
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  def create_model():
9
  with open("apikey.txt","r") as apikey:
10
  key = apikey.readline()
@@ -14,7 +33,7 @@ def create_model():
14
  print(m.name)
15
  print("He was there")
16
  config = genai.GenerationConfig(max_output_tokens=2048,
17
- temperature=0.7)
18
  safety_settings = [
19
  {
20
  "category": "HARM_CATEGORY_DANGEROUS",
@@ -37,53 +56,71 @@ def create_model():
37
  "threshold": "BLOCK_NONE",
38
  },
39
  ]
40
- global model
41
- model = genai.GenerativeModel("gemini-pro",
42
  generation_config=config,
43
  safety_settings=safety_settings)
44
- return model
 
 
 
 
 
 
 
 
45
 
46
  def get_model():
47
- global model
48
  if model is None:
49
  # Khởi tạo model ở đây
50
- model = create_model() # Giả sử create_model là hàm tạo model của bạn
51
- return model
52
 
53
  def extract_keyword_prompt(query):
54
  """A prompt that return a JSON block as arguments for querying database"""
55
 
56
- prompt = (
57
- """[INST] SYSTEM: You are an assistant that choose only one action below based on guest question.
58
- 1. If the guest question is asking for a single specific document or article with explicit title, you need to respond the information in JSON format with 2 keys "title", "author" if found any above. The authors are separated with the word 'and'.
59
- 2. If the guest question is asking for relevant informations about a topic, you need to respond the information in JSON format with 2 keys "keywords", "description", include a list of keywords represent the main academic topic, \
60
- and a description about the main topic. You may paraphrase the keywords to add more. \
61
- 3. If the guest is not asking for any informations or documents, you need to respond with a polite answer in JSON format with 1 key "answer".
62
- QUESTION: '{query}'
63
- [/INST]
64
- ANSWER:
65
- """
66
- ).format(query=query)
67
-
68
  return prompt
69
 
70
  def make_answer_prompt(input, contexts):
71
  """A prompt that return the final answer, based on the queried context"""
72
 
73
  prompt = (
74
- """[INST] You are a library assistant that help to search articles and documents based on user's question.
75
- From guest's question, you have found some records and documents that may help. Now you need to answer the guest with the information found.
76
- If no information found in the database, you may generate some other recommendation related to user's question using your own knowledge. Each article or paper must have a link to the pdf download page.
77
- You should answer in a conversational form politely.
78
- QUESTION: '{input}'
79
  INFORMATION: '{contexts}'
80
  [/INST]
81
  ANSWER:
82
  """
83
  ).format(input=input, contexts=contexts)
84
-
85
  return prompt
86
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  def response(args, db_instance):
88
  """Create response context, based on input arguments"""
89
  keys = list(dict.keys(args))
@@ -115,41 +152,48 @@ def response(args, db_instance):
115
  result_string = ""
116
  if paper_info:
117
  for i in range(len(paper_info)):
118
- result_string += "Title: {}, Author: {}, Link: {}".format(paper_info[i][2],paper_info[i][3],paper_info[i][6])
 
 
 
 
 
 
119
  records.append([paper_info[i][2],paper_info[i][3],paper_info[i][6]])
120
- return result_string, records
121
  else:
122
  return "Information not found", "Information not found"
123
  # invoke llm and return result
124
 
125
- if "title" in keys:
126
- title = args['title']
127
- authors = utils.authors_str_to_list(args['author'])
128
- paper_info = db_instance.query(title = title,author = authors)
129
- # if query not found then go crawl brh
130
- # print(paper_info)
131
 
132
- if len(paper_info) == 0:
133
- new_records = utils.crawl_exact_paper(title=title,author=authors)
134
- print("Got new records: ",len(new_records))
135
- if type(new_records) == str:
136
- # print(new_records)
137
- return "Error occured, information not found", "Information not found"
138
- utils.db.add(new_records)
139
- db_instance.add(new_records)
140
- paper_info = db_instance.query(title = title,author = authors)
141
- print("Re-queried on chromadb, results: ",paper_info)
142
- # -------------------------------------
143
- records = [] # get title (2), author (3), link (6)
144
- result_string = ""
145
- for i in range(len(paper_info)):
146
- result_string += "Title: {}, Author: {}, Link: {}".format(paper_info[i][2],paper_info[i][3],paper_info[i][6])
147
- records.append([paper_info[i][2],paper_info[i][3],paper_info[i][6]])
148
- # process results:
149
- if len(result_string) == 0:
150
- return "Information not found", "Information not found"
151
- return result_string, records
152
  # invoke llm and return result
 
153
  def full_chain_single_question(input_prompt, db_instance):
154
  try:
155
  first_prompt = extract_keyword_prompt(input_prompt)
@@ -180,23 +224,48 @@ def format_chat_history_from_web(chat_history: list):
180
  )
181
  return temp_chat
182
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  def full_chain_history_question(chat_history: list, db_instance):
184
  try:
185
  temp_chat = format_chat_history_from_web(chat_history)
186
- first_prompt = extract_keyword_prompt(temp_chat[-1]["parts"][0])
187
- temp_answer = model.generate_content(first_prompt).text
 
 
 
188
 
189
- args = json.loads(utils.trimming(temp_answer))
190
  contexts, results = response(args, db_instance)
191
  if not results:
192
- # print(contexts)
193
  return "Random question, direct return", contexts
194
  else:
195
- QA_Prompt = make_answer_prompt(temp_chat[-1]["parts"][0], contexts)
196
- temp_chat[-1]["parts"] = QA_Prompt
197
- print(temp_chat)
198
- answer = model.generate_content(temp_chat).text
199
- return temp_answer, answer
200
  except Exception as e:
201
- # print(e)
202
- return temp_answer, "Error occured: " + str(e)
 
 
 
5
 
6
  model = None
7
 
8
+ model_retrieval = None
9
+
10
+ model_answer = None
11
+
12
+ RETRIEVAL_INSTRUCT = """You are an auto chatbot that response with only one action below based on user question.
13
+ 1. If the guest question is asking about a science topic, you need to respond the information in JSON schema below:
14
+ {
15
+ "keywords": [a list of string keywords about the topic],
16
+ "description": "a paragraph describing the topic in about 50 to 100 words"
17
+ }
18
+ 2. If the guest is not asking for any informations or documents, you need to respond in JSON schema below:
19
+ {
20
+ "answer": "your answer to the user question"
21
+ }"""
22
+
23
+ ANSWER_INSTRUCT = """You are a library assistant that help answering customer question based on the information given.
24
+ You always answer in a conversational form naturally and politely.
25
+ You must introduce all the records given, each must contain title, authors and the link to the pdf file."""
26
+
27
  def create_model():
28
  with open("apikey.txt","r") as apikey:
29
  key = apikey.readline()
 
33
  print(m.name)
34
  print("He was there")
35
  config = genai.GenerationConfig(max_output_tokens=2048,
36
+ temperature=1.0)
37
  safety_settings = [
38
  {
39
  "category": "HARM_CATEGORY_DANGEROUS",
 
56
  "threshold": "BLOCK_NONE",
57
  },
58
  ]
59
+ global model, model_retrieval, model_answer
60
+ model = genai.GenerativeModel("gemini-1.5-pro-latest",
61
  generation_config=config,
62
  safety_settings=safety_settings)
63
+ model_retrieval = genai.GenerativeModel("gemini-1.5-pro-latest",
64
+ generation_config=config,
65
+ safety_settings=safety_settings,
66
+ system_instruction=RETRIEVAL_INSTRUCT)
67
+ model_answer = genai.GenerativeModel("gemini-1.5-pro-latest",
68
+ generation_config=config,
69
+ safety_settings=safety_settings,
70
+ system_instruction=ANSWER_INSTRUCT)
71
+ return model, model_answer, model_retrieval
72
 
73
  def get_model():
74
+ global model, model_answer, model_retrieval
75
  if model is None:
76
  # Khởi tạo model ở đây
77
+ model, model_answer, model_retrieval = create_model() # Giả sử create_model là hàm tạo model của bạn
78
+ return model, model_answer, model_retrieval
79
 
80
  def extract_keyword_prompt(query):
81
  """A prompt that return a JSON block as arguments for querying database"""
82
 
83
+ prompt = """[INST] SYSTEM: You are an auto chatbot that response with only one action below based on user question.
84
+ 1. If the guest question is asking about a science topic, you need to respond the information in JSON schema below:
85
+ {
86
+ "keywords": [a list of string keywords about the topic],
87
+ "description": "a paragraph describing the topic in about 50 to 100 words"
88
+ }
89
+ 2. If the guest is not asking for any informations or documents, you need to respond in JSON schema below:
90
+ {
91
+ "answer": "your answer to the user question"
92
+ }
93
+ QUESTION: """ + query + """[/INST]
94
+ ANSWER: """
95
  return prompt
96
 
97
  def make_answer_prompt(input, contexts):
98
  """A prompt that return the final answer, based on the queried context"""
99
 
100
  prompt = (
101
+ """[INST] You are a library assistant that help answering customer QUESTION based on the INFORMATION given.
102
+ You always answer in a conversational form naturally and politely.
103
+ You must introduce all the records given, each must contain title, authors and the link to the pdf file.
104
+ QUESTION: {input}
 
105
  INFORMATION: '{contexts}'
106
  [/INST]
107
  ANSWER:
108
  """
109
  ).format(input=input, contexts=contexts)
 
110
  return prompt
111
 
112
+ def retrieval_chat_template(question):
113
+ return {
114
+ "role":"user",
115
+ "parts":[f"QUESTION: {question} \n ANSWER:"]
116
+ }
117
+
118
+ def answer_chat_template(question, contexts):
119
+ return {
120
+ "role":"user",
121
+ "parts":[f"QUESTION: {question} \n INFORMATION: {contexts} \n ANSWER:"]
122
+ }
123
+
124
  def response(args, db_instance):
125
  """Create response context, based on input arguments"""
126
  keys = list(dict.keys(args))
 
152
  result_string = ""
153
  if paper_info:
154
  for i in range(len(paper_info)):
155
+ result_string += "Record no.{} - Title: {}, Author: {}, Link: {}, ".format(i+1,paper_info[i][2],paper_info[i][3],paper_info[i][6])
156
+ id = paper_info[i][0]
157
+ selected_document = utils.db.query_exact(id)["documents"]
158
+ doc_str = "Summary:"
159
+ for doc in selected_document:
160
+ doc_str+= doc + " "
161
+ result_string += doc_str
162
  records.append([paper_info[i][2],paper_info[i][3],paper_info[i][6]])
163
+ return result_string, records
164
  else:
165
  return "Information not found", "Information not found"
166
  # invoke llm and return result
167
 
168
+ # if "title" in keys:
169
+ # title = args['title']
170
+ # authors = utils.authors_str_to_list(args['author'])
171
+ # paper_info = db_instance.query(title = title,author = authors)
172
+ # # if query not found then go crawl brh
173
+ # # print(paper_info)
174
 
175
+ # if len(paper_info) == 0:
176
+ # new_records = utils.crawl_exact_paper(title=title,author=authors)
177
+ # print("Got new records: ",len(new_records))
178
+ # if type(new_records) == str:
179
+ # # print(new_records)
180
+ # return "Error occured, information not found", "Information not found"
181
+ # utils.db.add(new_records)
182
+ # db_instance.add(new_records)
183
+ # paper_info = db_instance.query(title = title,author = authors)
184
+ # print("Re-queried on chromadb, results: ",paper_info)
185
+ # # -------------------------------------
186
+ # records = [] # get title (2), author (3), link (6)
187
+ # result_string = ""
188
+ # for i in range(len(paper_info)):
189
+ # result_string += "Title: {}, Author: {}, Link: {}".format(paper_info[i][2],paper_info[i][3],paper_info[i][6])
190
+ # records.append([paper_info[i][2],paper_info[i][3],paper_info[i][6]])
191
+ # # process results:
192
+ # if len(result_string) == 0:
193
+ # return "Information not found", "Information not found"
194
+ # return result_string, records
195
  # invoke llm and return result
196
+
197
  def full_chain_single_question(input_prompt, db_instance):
198
  try:
199
  first_prompt = extract_keyword_prompt(input_prompt)
 
224
  )
225
  return temp_chat
226
 
227
+ # def full_chain_history_question(chat_history: list, db_instance):
228
+ # try:
229
+ # temp_chat = format_chat_history_from_web(chat_history)
230
+ # print('Extracted temp chat: ',temp_chat)
231
+ # first_prompt = extract_keyword_prompt(temp_chat[-1]["parts"][0])
232
+ # temp_answer = model.generate_content(first_prompt).text
233
+
234
+ # args = json.loads(utils.trimming(temp_answer))
235
+ # contexts, results = response(args, db_instance)
236
+ # print('Context extracted: ',contexts)
237
+ # if not results:
238
+ # return "Random question, direct return", contexts
239
+ # else:
240
+ # QA_Prompt = make_answer_prompt(temp_chat[-1]["parts"][0], contexts)
241
+ # temp_chat[-1]["parts"] = QA_Prompt
242
+ # print(temp_chat)
243
+ # answer = model.generate_content(temp_chat).text
244
+ # return temp_answer, answer
245
+ # except Exception as e:
246
+ # # print(e)
247
+ # return temp_answer, "Error occured: " + str(e)
248
+
249
  def full_chain_history_question(chat_history: list, db_instance):
250
  try:
251
  temp_chat = format_chat_history_from_web(chat_history)
252
+ question = temp_chat[-1]['parts'][0]
253
+ first_answer = model_retrieval.generate_content(temp_chat).text
254
+
255
+ print(first_answer)
256
+ args = json.loads(utils.trimming(first_answer))
257
 
 
258
  contexts, results = response(args, db_instance)
259
  if not results:
 
260
  return "Random question, direct return", contexts
261
  else:
262
+ print('Context to answers: ',contexts)
263
+ answer_chat = answer_chat_template(question, contexts)
264
+ temp_chat[-1] = answer_chat
265
+ answer = model_answer.generate_content(temp_chat).text
266
+ return first_answer, answer
267
  except Exception as e:
268
+ if first_answer:
269
+ return first_answer, "Error occured: " + str(e)
270
+ else:
271
+ return "No answer", "Error occured: " + str(e)