Spaces:
Runtime error
Runtime error
Update GPT4KG.py
Browse files
GPT4KG.py
CHANGED
@@ -150,9 +150,37 @@ class KnowledgeGraph:
|
|
150 |
top_indices = np.argsort(similarities[0])[-n:][::-1]
|
151 |
results = [facts[index] for index in top_indices]
|
152 |
return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
153 |
|
154 |
def chat_qa(self,query):
|
155 |
-
results = self.
|
156 |
system = {"role":"system","content":"You are a helpful chatbot that answers questions based on data in your fact database."}
|
157 |
messages = [system]
|
158 |
text = f"Question: {query}\n\nFact Data: \n{results}"
|
|
|
150 |
top_indices = np.argsort(similarities[0])[-n:][::-1]
|
151 |
results = [facts[index] for index in top_indices]
|
152 |
return results
|
153 |
+
|
154 |
+
def fact_search(self,query,facts,n=5):
|
155 |
+
query_embedding = self.model.encode(query)
|
156 |
+
query_tensor = torch.tensor([query_embedding])
|
157 |
+
fact_tensor = self.model.encode(facts)
|
158 |
+
similarities = util.cos_sim(query_tensor, fact_tensor).numpy()
|
159 |
+
if len(similarities[0])<n:
|
160 |
+
n = len(similarities)
|
161 |
+
top_indices = np.argsort(similarities[0])[-n:][::-1]
|
162 |
+
return top_indices
|
163 |
+
|
164 |
+
def branch_search(self,query,num_branches=2,window_size=5):
|
165 |
+
results = self.search(query, window_size)
|
166 |
+
keys = [r[0] for r in results]
|
167 |
+
facts = [key+": "+str(rel).replace("description","is")+" "+str(self.entities[key][rel]) for key in keys for rel in self.entities[key]]
|
168 |
+
top_indices = self.fact_search(query,facts)
|
169 |
+
next_ents = [str(self.entities[key][rel]) for key in keys for rel in self.entities[key]]
|
170 |
+
top_ents = [next_ents[index] for index in top_indices]
|
171 |
+
if num_branches == 0:
|
172 |
+
return [facts[index] for index in top_indices]
|
173 |
+
for b in range(num_branches):
|
174 |
+
next_facts = [facts[i]+". "+keys[i]+": "+str(rel).replace("description","is")+" "+str(self.entities[keys[i]][rel]) for i in range(len(keys)) for rel in self.entities[keys[i]] if facts[i].find(keys[i])!=0]
|
175 |
+
facts = next_facts
|
176 |
+
top_indices = self.fact_search(query,facts)
|
177 |
+
answers = [facts[index] for index in top_indices]
|
178 |
+
next_ents = [str(self.entities[key][rel]) for key in keys for rel in self.entities[key]]
|
179 |
+
top_ents = [next_ents[index] for index in top_indices]
|
180 |
+
return answers
|
181 |
|
182 |
def chat_qa(self,query):
|
183 |
+
results = self.branch_search(query)
|
184 |
system = {"role":"system","content":"You are a helpful chatbot that answers questions based on data in your fact database."}
|
185 |
messages = [system]
|
186 |
text = f"Question: {query}\n\nFact Data: \n{results}"
|