helliun commited on
Commit
b2c82cb
1 Parent(s): 7a36a89

Update GPT4KG.py

Browse files
Files changed (1) hide show
  1. GPT4KG.py +29 -1
GPT4KG.py CHANGED
@@ -150,9 +150,37 @@ class KnowledgeGraph:
150
  top_indices = np.argsort(similarities[0])[-n:][::-1]
151
  results = [facts[index] for index in top_indices]
152
  return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
  def chat_qa(self,query):
155
- results = self.qa_search(query)
156
  system = {"role":"system","content":"You are a helpful chatbot that answers questions based on data in your fact database."}
157
  messages = [system]
158
  text = f"Question: {query}\n\nFact Data: \n{results}"
 
150
  top_indices = np.argsort(similarities[0])[-n:][::-1]
151
  results = [facts[index] for index in top_indices]
152
  return results
153
+
154
+ def fact_search(self,query,facts,n=5):
155
+ query_embedding = self.model.encode(query)
156
+ query_tensor = torch.tensor([query_embedding])
157
+ fact_tensor = self.model.encode(facts)
158
+ similarities = util.cos_sim(query_tensor, fact_tensor).numpy()
159
+ if len(similarities[0])<n:
160
+ n = len(similarities)
161
+ top_indices = np.argsort(similarities[0])[-n:][::-1]
162
+ return top_indices
163
+
164
+ def branch_search(self,query,num_branches=2,window_size=5):
165
+ results = self.search(query, window_size)
166
+ keys = [r[0] for r in results]
167
+ facts = [key+": "+str(rel).replace("description","is")+" "+str(self.entities[key][rel]) for key in keys for rel in self.entities[key]]
168
+ top_indices = self.fact_search(query,facts)
169
+ next_ents = [str(self.entities[key][rel]) for key in keys for rel in self.entities[key]]
170
+ top_ents = [next_ents[index] for index in top_indices]
171
+ if num_branches == 0:
172
+ return [facts[index] for index in top_indices]
173
+ for b in range(num_branches):
174
+ next_facts = [facts[i]+". "+keys[i]+": "+str(rel).replace("description","is")+" "+str(self.entities[keys[i]][rel]) for i in range(len(keys)) for rel in self.entities[keys[i]] if facts[i].find(keys[i])!=0]
175
+ facts = next_facts
176
+ top_indices = self.fact_search(query,facts)
177
+ answers = [facts[index] for index in top_indices]
178
+ next_ents = [str(self.entities[key][rel]) for key in keys for rel in self.entities[key]]
179
+ top_ents = [next_ents[index] for index in top_indices]
180
+ return answers
181
 
182
  def chat_qa(self,query):
183
+ results = self.branch_search(query)
184
  system = {"role":"system","content":"You are a helpful chatbot that answers questions based on data in your fact database."}
185
  messages = [system]
186
  text = f"Question: {query}\n\nFact Data: \n{results}"