Spaces:
Runtime error
Runtime error
File size: 12,242 Bytes
8968917 bc6f5ad 8968917 5bde97a 8968917 7a36a89 8968917 5bde97a 8968917 bc6f5ad 8968917 b2c82cb 8968917 b2c82cb 8968917 de3d112 8968917 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 |
import ast
import json

import numpy as np
import openai
import pydot
import torch
from openai import OpenAI
from PIL import Image
from sentence_transformers import SentenceTransformer, util
from sklearn.metrics.pairwise import cosine_similarity
# System prompt sent as the "system" message by KnowledgeGraph.text_to_data.
# It asks the model for JSON only: a list of {entity: [{relationship: target}, ...]}
# dicts, as in the "Example output" at the end, where each entity's list carries a
# {"description": ...} entry. KnowledgeGraph.update_graph parses the reply with
# json.loads and builds entities/edges from that shape.
# NOTE(review): the prompt contains typos ("REALY", "relsationship"); they are left
# as-is because editing this text changes what the model is asked.
system_text = """You are an expert AI that extracts knowledge graphs from text and outputs JSON files with the extracted knowledge, and nothing more. Here's how the JSON is broken down.
Entity dictionaries are organized in a list
Every entity mentioned in the text has its own entity dictionary, in which the name of the entity is the key, and the value is a list of relationships.
Each relationship contains a short word or two accurately describing the relationship to the other entity as the key, and then the other entity as a value.
All inverses of these relationships are represented in the relationship list of the other entities. This is REALY IMPORTANT. For example if Apple created the iPhone, it is also important to note that the iPhone was created by Apple (each entity should have this relsationship from their perspective).
Non specified relationships are also inferred (if person X is the son of person Y, and person Z is person X's sibling, person Z is also the child of person Y).
The JSON contains NO NEW LINES. All the data should be on one line.
Every entity has a "description" relationship which provides a short description of what it is in a few words. If the description references another entity, then this relationship MUST be graphed, even if it is redundant.
Relationships are only created about facts, not just any connection between two entities mentioned in the text.
Example output:
[{"Toki Pona": [{"description": "philosophical artistic constructed language"}, {"translated as": "the language of good"}, {"created by": "Sonja Lang"}, {"first published": "2001"}, {"complete form published in": "Toki Pona: The Language of Good"}, {"supplementary dictionary": "Toki Pona Dictionary"}], "Sonja Lang": [{"description": "Canadian linguist and translator"}, {"creator of": "Toki Pona"}], "Toki Pona: The Language of Good": [{"description": "book"}, {"published in": "2014"}, {"language": "Toki Pona"}], "Toki Pona Dictionary": [{"description": "dictionary"}, {"released in": "July 2021"}, {"based on": "community usage"}]}]"""
class KnowledgeGraph:
    """Knowledge graph built from LLM-extracted JSON, with embedding-based search.

    Facts are kept in two parallel forms: ``self.entities``
    (``{entity: {relationship: target}}``, used for search and persistence) and
    ``self.graph`` (a ``pydot.Dot`` digraph used for rendering). Entity names are
    embedded with a SentenceTransformer for similarity search; OpenAI chat models
    are used for extraction, de-duplication and question answering.
    """

    def __init__(self, api_key, kg_file="", model=None):
        """Create an empty graph, optionally loading a saved one.

        Args:
            api_key: OpenAI API key used for all chat-completion calls.
            kg_file: Path of a file written by ``save_graph``; loaded if non-empty.
            model: Optional already-loaded SentenceTransformer to reuse instead
                of loading ``all-MiniLM-L6-v2`` again (used for sub-graphs).
        """
        # Keep the module-level key set for any code that relies on it.
        openai.api_key = api_key
        self.api_key = api_key
        # Created on first use by _chat, so purely local work (search, loading,
        # rendering) does not need a working OpenAI configuration.
        self._client = None
        self.graph = pydot.Dot(graph_type="digraph")
        self.entities = {}
        self.fact_scores = {}
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if model is None:
            model = SentenceTransformer("all-MiniLM-L6-v2").to(self.device)
        self.model = model
        self.entity_embeddings = {}
        if kg_file != "":
            self.load_graph(kg_file)

    # ------------------------------------------------------------------ helpers

    def _chat(self, model, messages):
        """Send *messages* to the OpenAI chat *model*; return the reply text ("" if empty)."""
        # openai>=1.0 (required by ``from openai import OpenAI``) removed
        # ``openai.ChatCompletion``; the client API is the supported replacement.
        if getattr(self, "_client", None) is None:
            self._client = OpenAI(api_key=getattr(self, "api_key", None))
        response = self._client.chat.completions.create(model=model, messages=messages)
        return response.choices[0].message.content or ""

    @staticmethod
    def _strip_code_fence(text):
        """Remove a surrounding Markdown code fence (e.g. ```json ... ```) from *text*.

        Non-strings and unfenced text are returned unchanged.
        """
        if not isinstance(text, str):
            return text
        stripped = text.strip()
        if not (len(stripped) >= 6 and stripped.startswith("```") and stripped.endswith("```")):
            return text
        inner = stripped[3:-3]
        first_line, sep, rest = inner.partition("\n")
        # Drop a language tag such as "json" on the opening fence line.
        if sep and not first_line.lstrip().startswith(("[", "{")):
            inner = rest
        return inner.strip()

    @staticmethod
    def _fact_text(key, rel, value):
        """Render one fact as ``"key: rel value"`` with "description" read as "is"."""
        return f"{key}: {str(rel).replace('description', 'is')} {value}"

    @staticmethod
    def _top_indices(scores, n):
        """Return indices of the ``n`` highest *scores*, highest first (none if ``n <= 0``)."""
        return np.argsort(scores)[::-1][:max(int(n), 0)]

    def _query_scores(self, query, candidates):
        """Cosine similarity of *query*'s embedding to each candidate embedding (1-D array)."""
        query_tensor = torch.tensor(np.stack([self.model.encode(query)]))
        return util.cos_sim(query_tensor, candidates).numpy()[0]

    def _linked_entities(self, value):
        """Return the known entities referenced by a relationship target *value*.

        Targets may be a single name, a ", "-joined string of names (as built by
        ``add_relationship``) or a list of names; unknown names are ignored.
        """
        if isinstance(value, (list, tuple)):
            return [name for name in map(str, value) if name in self.entities]
        text = str(value)
        if text in self.entities:
            return [text]
        return [part for part in text.split(", ") if part in self.entities]

    # ------------------------------------------------------------ construction

    def add_entity(self, name, description):
        """Register *name* with *description*, add its node and cache its name embedding.

        Does nothing if *name* already exists (the description is not updated).
        """
        if name in self.entities:
            return
        self.entities[name] = {"description": description}
        self.graph.add_node(pydot.Node(name, label=f"{name}\n({description})"))
        self.entity_embeddings[name] = self.model.encode(name)
        print("added embedding")

    def add_relationship(self, entity1, relationship, entity2):
        """Record ``entity1 --relationship--> entity2`` and add the matching edge.

        Repeated relationships accumulate as a ", "-joined string. A list/tuple
        *entity2* adds one relationship per element; other non-string targets
        (e.g. numbers from the JSON) are stored as strings.
        """
        if isinstance(entity2, (list, tuple)):
            for target in entity2:
                self.add_relationship(entity1, relationship, target)
            return
        entity2 = str(entity2)
        if entity1 in self.entities:
            existing = self.entities[entity1].get(relationship)
            self.entities[entity1][relationship] = (
                entity2 if existing is None else f"{existing}, {entity2}"
            )
        self.graph.add_edge(pydot.Edge(entity1, entity2, label=relationship))

    def update_graph(self, json_str, clean=True):
        """Merge model-produced JSON into the graph.

        Args:
            json_str: JSON list of ``{entity: [{relationship: target}, ...]}``
                dicts (a single such dict, or a Markdown code fence around the
                JSON, is tolerated). The entity's "description" entry may appear
                anywhere in its list.
            clean: If true, run ``clean_graph`` once on every entity mentioned.

        Invalid JSON is reported with a message and otherwise ignored.
        """
        try:
            data = json.loads(self._strip_code_fence(json_str))
        except (TypeError, ValueError):
            print("GPT4 failed to create a valid JSON. Input may be too long for processing.")
            return
        if isinstance(data, dict):
            entity_dicts = [data]
        elif isinstance(data, list):
            entity_dicts = data
        else:
            entity_dicts = []

        touched = []
        for entity_dict in entity_dicts:
            if not isinstance(entity_dict, dict):
                continue
            for entity, relationships in entity_dict.items():
                if isinstance(relationships, dict):
                    relationships = [{key: value} for key, value in relationships.items()]
                elif not isinstance(relationships, list):
                    relationships = []
                relationships = [rel for rel in relationships if isinstance(rel, dict)]
                description = next(
                    (rel["description"] for rel in relationships if "description" in rel), ""
                )
                self.add_entity(entity, description)
                for rel in relationships:
                    for relationship, other_entity in rel.items():
                        # The description is node metadata, not an edge.
                        if relationship != "description":
                            self.add_relationship(entity, relationship, other_entity)
                touched.append(entity)

        if clean:
            for entity in dict.fromkeys(touched):
                self.clean_graph(entity)

    def display_graph(self, output_file="knowledge_graph.png"):
        """Render the graph to *output_file* as PNG, open it in a viewer and return the image."""
        self.graph.write_png(output_file)
        img = Image.open(output_file)
        img.show()
        return img

    # ------------------------------------------------------------------ search

    def search(self, query, n=5):
        """Return up to *n* ``(entity, similarity)`` pairs whose names best match *query*, best first."""
        if not self.entity_embeddings:
            return []
        names = list(self.entity_embeddings)
        scores = self._query_scores(query, np.stack(list(self.entity_embeddings.values())))
        return [(names[index], scores[index]) for index in self._top_indices(scores, n)]

    def related_entities(self, query, n=5):
        """Return up to *n* ``"entity: description"`` strings most similar to *query*, best first."""
        if not self.entities:
            return []
        candidates = [f"{key}: {rels.get('description', '')}" for key, rels in self.entities.items()]
        scores = self._query_scores(query, self.model.encode(candidates))
        return [candidates[index] for index in self._top_indices(scores, n)]

    def text_to_data(self, text):
        """Ask GPT-4 to turn *text* into graph JSON (see ``system_text``); return the raw reply.

        When the graph already has entities, the most related ones are appended to
        the prompt so the model reuses their names.
        """
        try:
            related = self.related_entities(text)
        except Exception as err:  # best effort: extraction still works without this context
            print(f"Related-entity lookup failed: {err}")
            related = []
        if related:
            text = text+f"\n\nGenerate the JSON for the text above, remembering to add inverse relationships and inferences. Here are some related entities already in the graph. If you are adding information about any of them, refer to them by the names below (otherwise ignore this information):\n\n{str(related)}"
        messages = [
            {"role": "system", "content": system_text},
            {"role": "user", "content": text},
        ]
        return self._chat("gpt-4-1106-preview", messages)

    def learn(self, text, show_output=False):
        """Extract facts from *text* with the LLM and merge them into the graph."""
        json_str = self.text_to_data(text)
        if show_output:
            print(json_str)
        self.update_graph(json_str)

    def graph_search(self, query, n=5, path="subgraph.png"):
        """Render the sub-graph of the *n* entities most similar to *query* to *path*.

        Returns the opened image from ``display_graph``.
        """
        top_ents = [name for name, _ in self.search(query, n)]
        data = [
            {ent: [{key: value} for key, value in self.entities[ent].items()]}
            for ent in top_ents
        ]
        # __init__ requires the API key; reuse the loaded embedding model too.
        subgraph = KnowledgeGraph(self.api_key, model=self.model)
        subgraph.update_graph(json.dumps(data), clean=False)
        return subgraph.display_graph(path)

    def text_search(self, query, n=3):
        """Print ``entity: {relationships}`` for the *n* entities most similar to *query*."""
        for key, _ in self.search(query, n):
            print(key + ": " + str(self.entities[key]))

    def qa_search(self, query, n=5):
        """Return up to *n* fact strings about the *n* best-matching entities, best match first."""
        keys = [name for name, _ in self.search(query, n)]
        facts = [
            self._fact_text(key, rel, value)
            for key in keys
            for rel, value in self.entities[key].items()
        ]
        if not facts:
            return []
        return [facts[index] for index in self.fact_search(query, facts, n)]

    def fact_search(self, query, facts, n=5):
        """Return indices (numpy array) of the *n* *facts* most similar to *query*, best first."""
        if not facts:
            return np.array([], dtype=int)
        scores = self._query_scores(query, self.model.encode(facts))
        return self._top_indices(scores, n)

    def branch_search(self, query, num_branches=2, window_size=5):
        """Collect fact chains relevant to *query* by following edges between entities.

        Starts from the facts of the *window_size* entities most similar to
        *query* and keeps the 5 best. Each branch step extends every kept fact
        with the facts of the entities it points to (``"fact. next fact"``) and
        keeps the 5 best chains again. Stops early when no kept fact links to a
        known entity. Returns the final list of fact strings.
        """
        facts, targets = [], []
        for key, _ in self.search(query, window_size):
            for rel, value in self.entities.get(key, {}).items():
                facts.append(self._fact_text(key, rel, value))
                targets.append(value)
        if not facts:
            return []
        top_indices = self.fact_search(query, facts)

        for _ in range(num_branches):
            chains, chain_targets = [], []
            for index in top_indices:
                for ent in self._linked_entities(targets[index]):
                    for rel, value in self.entities[ent].items():
                        chains.append(facts[index] + ". " + self._fact_text(ent, rel, value))
                        chain_targets.append(value)
            if not chains:
                break
            facts, targets = chains, chain_targets
            top_indices = self.fact_search(query, facts)

        return [facts[index] for index in top_indices]

    def chat_qa(self, query):
        """Answer *query* with gpt-3.5-turbo, grounded on the facts from ``branch_search``."""
        results = self.branch_search(query)
        messages = [
            {"role": "system", "content": "You are a helpful chatbot that answers questions based on data in your fact database."},
            {"role": "user", "content": f"Question: {query}\n\nFact Data: \n{results}"},
        ]
        return self._chat("gpt-3.5-turbo", messages)

    # ------------------------------------------------------------- maintenance

    def clean_graph(self, key):
        """Ask the LLM to remove near-duplicate facts about entity *key*.

        Every pair of *key*'s facts whose embeddings have cosine similarity above
        0.7 is sent to GPT-4. When it answers yes, the second fact is deleted
        from ``self.entities`` (the "description" fact is never the one deleted)
        and the kept fact's counter in ``self.fact_scores`` is incremented.
        Facts already deleted are not compared again.

        NOTE(review): the corresponding pydot edge is not removed from ``self.graph``.
        """
        rels = list(self.entities.get(key, {}))
        if len(rels) < 2:
            return
        facts = [self._fact_text(key, rel, self.entities[key][rel]) for rel in rels]
        fact_embs = self.model.encode(facts)
        scores = util.cos_sim(fact_embs, fact_embs)
        system = {"role": "system", "content": "You are a helpful chatbot that only outputs YES or NO"}
        removed = set()
        for i in range(len(rels)):
            for j in range(i + 1, len(rels)):
                if i in removed or j in removed or float(scores[i][j]) <= 0.7:
                    continue
                pair = (facts[i], facts[j])
                messages = [
                    system,
                    {"role": "user", "content": f"Do these two facts in our database express the same thing?: {pair}"},
                ]
                if "yes" not in self._chat("gpt-4-1106-preview", messages).lower():
                    continue
                keep, drop = (j, i) if rels[j] == "description" else (i, j)
                del self.entities[key][rels[drop]]
                removed.add(drop)
                validated = (key, rels[keep])
                self.fact_scores[validated] = self.fact_scores.get(validated, 0) + 1

    def load_graph(self, kg_file):
        """Load a file written by ``save_graph``: DOT text, then ``str(self.entities)`` as the last line."""
        with open(kg_file) as f:
            lines = f.readlines()
        # readlines() keeps the newlines, so join without adding more.
        graph_data = "".join(lines[:-1])
        # literal_eval only parses Python literals; eval() would execute
        # arbitrary code from the file.
        ents = ast.literal_eval(lines[-1])
        data = [
            {ent: [{key: value} for key, value in rels.items()]}
            for ent, rels in ents.items()
        ]
        # clean=False: facts added through learn() were already cleaned by
        # update_graph, and cleaning again would issue GPT calls on every load.
        self.update_graph(json.dumps(data), clean=False)
        self.graph = pydot.graph_from_dot_data(graph_data)[0]

    def save_graph(self, filename="mygraph.kg"):
        """Write the DOT graph to *filename*, then append ``str(self.entities)`` as the last line."""
        self.graph.write_dot(filename)
        with open(filename, "a") as f:
            f.write("\n" + str(self.entities))