# GPT4KG.py by helliun
# Last commit: bc6f5ad "Update GPT4KG.py" (file size 12.2 kB)
import ast
import itertools
import json

import numpy as np
import openai
import pydot
import torch
from openai import OpenAI
from PIL import Image
from sentence_transformers import SentenceTransformer, util
from sklearn.metrics.pairwise import cosine_similarity
# System prompt for the extraction call in KnowledgeGraph.text_to_data. It fixes
# the JSON shape (a list of {entity: [{relationship: target}, ...]} dicts) that
# KnowledgeGraph.update_graph parses.
system_text = """You are an expert AI that extracts knowledge graphs from text and outputs JSON files with the extracted knowledge, and nothing more. Here's how the JSON is broken down.
Entity dictionaries are organized in a list.
Every entity mentioned in the text has its own entity dictionary, in which the name of the entity is the key, and the value is a list of relationships.
Each relationship contains a short word or two accurately describing the relationship to the other entity as the key, and then the other entity as a value.
All inverses of these relationships are represented in the relationship list of the other entities. This is REALLY IMPORTANT. For example, if Apple created the iPhone, it is also important to note that the iPhone was created by Apple (each entity should have this relationship from their perspective).
Unspecified relationships are also inferred (if person X is the son of person Y, and person Z is person X's sibling, person Z is also the child of person Y).
The JSON contains NO NEW LINES. All the data should be on one line.
Every entity has a "description" relationship which provides a short description of what it is in a few words. If the description references another entity, then this relationship MUST be graphed, even if it is redundant.
Relationships are only created about facts, not just any connection between two entities mentioned in the text.
Example output:
[{"Toki Pona": [{"description": "philosophical artistic constructed language"}, {"translated as": "the language of good"}, {"created by": "Sonja Lang"}, {"first published": "2001"}, {"complete form published in": "Toki Pona: The Language of Good"}, {"supplementary dictionary": "Toki Pona Dictionary"}], "Sonja Lang": [{"description": "Canadian linguist and translator"}, {"creator of": "Toki Pona"}], "Toki Pona: The Language of Good": [{"description": "book"}, {"published in": "2014"}, {"language": "Toki Pona"}], "Toki Pona Dictionary": [{"description": "dictionary"}, {"released in": "July 2021"}, {"based on": "community usage"}]}]"""
class KnowledgeGraph:
    """Knowledge graph extracted from free text with GPT-4 and searched with embeddings.

    State kept on the instance:
      * ``entities``: ``{entity name: {relationship: target}}``. Every entity
        has a ``"description"`` key. Several targets recorded for the same
        relationship are stored as one ``", "``-joined string.
      * ``graph``: a ``pydot.Dot`` digraph mirroring entities and relationships.
        It is rendered by ``display_graph`` and persisted by ``save_graph``.
      * ``entity_embeddings``: SentenceTransformer embedding of each entity name.
      * ``fact_scores``: ``{(entity, relationship): count}`` of how often a fact
        was confirmed as the surviving copy by ``clean_graph``.

    Chat calls use the module-level client of the ``openai`` 1.x package. The
    file imports ``openai.OpenAI``, which only exists in 1.x, where the old
    ``openai.ChatCompletion`` API has been removed.
    """

    def __init__(self, api_key=None, kg_file="", model=None):
        """Create an empty graph, optionally loading one written by ``save_graph``.

        Args:
            api_key: OpenAI API key. When given it is stored in ``openai.api_key``.
                When omitted, the openai package's own configuration is used
                (normally the ``OPENAI_API_KEY`` environment variable).
            kg_file: Path of a saved graph to load. An empty string starts empty.
            model: SentenceTransformer to reuse. By default ``all-MiniLM-L6-v2``
                is loaded on CUDA when available, else on CPU.
        """
        if api_key is not None:
            openai.api_key = api_key
        self.graph = pydot.Dot(graph_type="digraph")
        self.entities = {}
        self.fact_scores = {}
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if model is None:
            model = SentenceTransformer("all-MiniLM-L6-v2").to(self.device)
        self.model = model
        self.entity_embeddings = {}
        if kg_file:
            self.load_graph(kg_file)

    def _chat(self, model, messages):
        """Send ``messages`` to chat model ``model`` and return the reply text.

        May be ``None`` when the API returns a message without text content.
        """
        response = openai.chat.completions.create(model=model, messages=messages)
        return response.choices[0].message.content

    def add_entity(self, name, description):
        """Add entity ``name`` with ``description``. Existing entities are left untouched.

        Creates the entity record, a labelled pydot node and the embedding of
        the name that ``search`` compares against.
        """
        if name in self.entities:
            return
        self.entities[name] = {"description": description}
        self.graph.add_node(pydot.Node(name, label=f"{name}\n({description})"))
        self.entity_embeddings[name] = self.model.encode(name)

    def add_relationship(self, entity1, relationship, entity2):
        """Record ``entity1 --relationship--> entity2`` and add the matching edge.

        A second, different target for an existing relationship is appended to
        the stored value as ``", target"``. A target that is already recorded
        is ignored, so learning the same fact twice adds neither a duplicate
        value nor a duplicate edge. The edge is added even when ``entity1`` is
        not a known entity (pydot then creates the node implicitly).
        """
        if entity1 in self.entities:
            relations = self.entities[entity1]
            if relationship in relations:
                if str(entity2) in str(relations[relationship]).split(", "):
                    return
                relations[relationship] = f"{relations[relationship]}, {entity2}"
            else:
                relations[relationship] = entity2
        self.graph.add_edge(pydot.Edge(entity1, entity2, label=relationship))

    @staticmethod
    def _strip_code_fences(text):
        """Remove a surrounding Markdown code fence (```json ... ```) if present.

        Non-string input is returned unchanged so that ``json.loads`` reports it.
        """
        if not isinstance(text, str):
            return text
        body = text.strip()
        if not body.startswith("```"):
            return body
        body = body[3:]
        if "\n" in body:
            first, rest = body.split("\n", 1)
            # Drop an info string such as "json" on the opening fence line.
            if not first.strip() or first.strip().isalnum():
                body = rest
        body = body.rstrip()
        if body.endswith("```"):
            body = body[:-3]
        return body.strip()

    @staticmethod
    def _parse_entities(data):
        """Normalize decoded extraction JSON into ``(entity, description, pairs)`` tuples.

        ``pairs`` is a list of ``(relationship, target)`` string pairs. The
        ``"description"`` entry may appear anywhere in an entity's list; several
        descriptions are joined with ``", "``. List-valued targets are expanded
        into one pair per item, ``None`` targets are skipped and other scalars
        are converted with ``str``. A bare dict (instead of a list) is accepted
        at both the top level and the per-entity level. A bare string for an
        entity is taken as its description.
        """
        if isinstance(data, dict):
            data = [data]
        if not isinstance(data, list):
            return []
        parsed = []
        for entity_dict in data:
            if not isinstance(entity_dict, dict):
                continue
            for entity, relationships in entity_dict.items():
                if isinstance(relationships, str):
                    parsed.append((entity, relationships, []))
                    continue
                if isinstance(relationships, dict):
                    relationships = [relationships]
                if not isinstance(relationships, list):
                    relationships = []
                descriptions = []
                pairs = []
                for rel in relationships:
                    if not isinstance(rel, dict):
                        continue
                    for relationship, target in rel.items():
                        targets = target if isinstance(target, list) else [target]
                        values = [str(t) for t in targets if t is not None]
                        if relationship == "description":
                            descriptions.extend(values)
                        else:
                            pairs.extend((relationship, value) for value in values)
                parsed.append((entity, ", ".join(descriptions), pairs))
        return parsed

    def update_graph(self, json_str, clean=True):
        """Merge extraction output ``json_str`` into the graph.

        Prints a message and returns without changes when the text (after any
        Markdown code fence is removed) is not valid JSON. With ``clean`` set,
        ``clean_graph`` is run once for each entity mentioned in the input.
        """
        try:
            data = json.loads(self._strip_code_fences(json_str))
        except (TypeError, ValueError):  # ValueError covers json.JSONDecodeError
            print("GPT4 failed to create a valid JSON. Input may be too long for processing.")
            return
        parsed = self._parse_entities(data)
        for entity, description, pairs in parsed:
            self.add_entity(entity, description)
            for relationship, target in pairs:
                self.add_relationship(entity, relationship, target)
        if clean:
            # dict.fromkeys keeps first-seen order while de-duplicating names.
            for entity in dict.fromkeys(name for name, _, _ in parsed):
                self.clean_graph(entity)

    def display_graph(self, output_file="knowledge_graph.png"):
        """Render the graph to PNG ``output_file``, open it with PIL and return the image."""
        self.graph.write_png(output_file)
        img = Image.open(output_file)
        img.show()
        return img

    @staticmethod
    def _top_indices(scores, n):
        """Return indices of the ``n`` highest ``scores`` (highest first).

        ``n`` is clamped to the number of scores. ``n <= 0`` yields an empty array.
        """
        scores = np.asarray(scores)
        n = min(n, len(scores))
        if n <= 0:
            return np.array([], dtype=int)
        return np.argsort(scores)[-n:][::-1]

    def _rank(self, query, texts, n):
        """Return indices of the ``n`` ``texts`` most cosine-similar to ``query``."""
        if not texts or n <= 0:
            return self._top_indices([], 0)
        query_embedding = self.model.encode(query)
        text_embeddings = self.model.encode(texts)
        scores = util.cos_sim(query_embedding, text_embeddings).numpy()[0]
        return self._top_indices(scores, n)

    def search(self, query, n=5):
        """Return up to ``n`` ``(entity name, cosine similarity)`` pairs, best first.

        Compares ``query`` against the stored entity-name embeddings. Returns
        an empty list for an empty graph or ``n <= 0``.
        """
        names = list(self.entity_embeddings)
        if not names or n <= 0:
            return []
        query_embedding = self.model.encode(query)
        name_embeddings = np.stack([self.entity_embeddings[name] for name in names])
        scores = util.cos_sim(query_embedding, name_embeddings).numpy()[0]
        return [(names[i], scores[i]) for i in self._top_indices(scores, n)]

    def related_entities(self, query, n=5):
        """Return up to ``n`` ``"name: description"`` strings most similar to ``query``.

        Every ``"name: description"`` string is re-embedded on each call.
        """
        candidates = [
            f"{name}: {relations.get('description', '')}"
            for name, relations in self.entities.items()
        ]
        return [candidates[i] for i in self._rank(query, candidates, n)]

    def text_to_data(self, text):
        """Ask GPT-4 to turn ``text`` into the JSON format described by ``system_text``.

        Up to 5 similar existing entities are listed in the prompt so that the
        model reuses their names. Fetching them is best-effort: a failure there
        only omits the hint.
        """
        try:
            related = self.related_entities(text)
        except Exception:  # the hint is optional; never block extraction on it
            related = []
        prompt = (
            text
            + "\n\nGenerate the JSON for the text above, remembering to add inverse relationships and inferences."
        )
        if related:
            prompt += (
                " Here are some related entities already in the graph. If you are adding information about any of them, refer to them by the names below (otherwise ignore this information):\n\n"
                + str(related)
            )
        messages = [
            {"role": "system", "content": system_text},
            {"role": "user", "content": prompt},
        ]
        return self._chat("gpt-4-1106-preview", messages)

    def learn(self, text, show_output=False):
        """Extract knowledge from ``text`` and merge it into the graph (with cleaning)."""
        json_str = self.text_to_data(text)
        if show_output:
            print(json_str)
        self.update_graph(json_str)

    def graph_search(self, query, n=5, path="subgraph.png"):
        """Render the ``n`` entities closest to ``query`` (and their edges) to ``path``.

        Builds a throwaway graph that reuses this instance's embedding model.
        No de-duplication is run, so no OpenAI calls are made. Returns the
        rendered PIL image.
        """
        top_ents = [name for name, _ in self.search(query, n)]
        data = [
            {ent: [{rel: value} for rel, value in self.entities[ent].items()]}
            for ent in top_ents
        ]
        subgraph = KnowledgeGraph(model=self.model)
        subgraph.update_graph(json.dumps(data), clean=False)
        return subgraph.display_graph(path)

    def text_search(self, query, n=3):
        """Print ``"name: {relationships}"`` for the ``n`` entities closest to ``query``."""
        for name, _ in self.search(query, n):
            print(f"{name}: {self.entities[name]}")

    def _fact_text(self, key, rel):
        """Render one stored fact as text, e.g. ``"Apple: is company"``.

        Only the ``"description"`` relationship is shown as ``"is"``.
        """
        label = "is" if rel == "description" else rel
        return f"{key}: {label} {self.entities[key][rel]}"

    def qa_search(self, query, n=5):
        """Return up to ``n`` fact strings most similar to ``query``.

        Candidates are drawn from the ``n`` entities closest to ``query``.
        """
        keys = [name for name, _ in self.search(query, n)]
        facts = [self._fact_text(key, rel) for key in keys for rel in self.entities[key]]
        return [facts[i] for i in self._rank(query, facts, n)]

    def fact_search(self, query, facts, n=5):
        """Return indices of the (up to) ``n`` ``facts`` most similar to ``query``, best first."""
        return self._rank(query, facts, n)

    def _linked_entities(self, value):
        """Return the graph entities named by a stored relationship value.

        Values holding several targets are ``", "``-joined (see
        ``add_relationship``). Such a value is split when it is not itself an
        entity name.
        """
        value = str(value)
        if value in self.entities:
            return [value]
        return [part for part in value.split(", ") if part in self.entities]

    def _best_paths(self, query, paths):
        """Keep the (up to 5) ``(text, value)`` paths whose text best matches ``query``."""
        indices = self.fact_search(query, [text for text, _ in paths])
        return [paths[i] for i in indices]

    def branch_search(self, query, num_branches=2, window_size=5):
        """Multi-hop fact retrieval: follow relationships out of the best facts.

        Starts from every fact of the ``window_size`` entities closest to
        ``query`` and keeps the 5 facts most similar to it. Then, up to
        ``num_branches`` times, every kept path whose final value names a graph
        entity is extended with each fact of that entity. The 5 best paths
        among the kept ones and their extensions survive. The loop stops early
        when no path can be extended.

        Returns:
            List of path strings such as ``"A: created by B. B: is person"``.
        """
        keys = [name for name, _ in self.search(query, window_size)]
        # Each path is (text so far, value at the end of the path).
        paths = [
            (self._fact_text(key, rel), self.entities[key][rel])
            for key in keys
            for rel in self.entities[key]
        ]
        best = self._best_paths(query, paths)
        for _ in range(num_branches):
            extended = [
                (f"{text}. {self._fact_text(entity, rel)}", self.entities[entity][rel])
                for text, value in best
                for entity in self._linked_entities(value)
                for rel in self.entities[entity]
            ]
            if not extended:
                break
            best = self._best_paths(query, best + extended)
        return [text for text, _ in best]

    def chat_qa(self, query):
        """Answer ``query`` with gpt-3.5-turbo, grounded on ``branch_search`` results."""
        results = self.branch_search(query)
        messages = [
            {"role": "system", "content": "You are a helpful chatbot that answers questions based on data in your fact database."},
            {"role": "user", "content": f"Question: {query}\n\nFact Data: \n{results}"},
        ]
        return self._chat("gpt-3.5-turbo", messages)

    def clean_graph(self, key, threshold=0.7):
        """Remove redundant facts of entity ``key``, confirmed one pair at a time by GPT-4.

        Every pair of the entity's facts whose embedding cosine similarity
        exceeds ``threshold`` is sent to GPT-4. On a "yes" answer the second
        fact of the pair is deleted and the kept fact's ``fact_scores`` count is
        incremented. ``"description"`` is never the one deleted, and pairs
        involving an already-deleted fact are skipped.

        NOTE(review): the pydot edge of a deleted fact is left in ``self.graph``.
        """
        if key not in self.entities:
            return
        rels = list(self.entities[key])
        if len(rels) < 2:
            return
        facts = [self._fact_text(key, rel) for rel in rels]
        fact_embs = self.model.encode(facts)
        scores = util.cos_sim(fact_embs, fact_embs)
        system = {"role": "system", "content": "You are a helpful chatbot that only outputs YES or NO"}
        removed = set()
        # i < j: each unordered pair once, never a fact with itself.
        for i, j in itertools.combinations(range(len(rels)), 2):
            if float(scores[i][j]) <= threshold:
                continue
            if rels[i] in removed or rels[j] in removed:
                continue
            pair = (facts[i], facts[j])
            question = {"role": "user", "content": f"Do these two facts in our database express the same thing?: {pair}"}
            answer = self._chat("gpt-4-1106-preview", [system, question])
            if "yes" not in (answer or "").lower():
                continue
            keep, drop = rels[i], rels[j]
            if drop == "description":
                keep, drop = drop, keep
            del self.entities[key][drop]
            removed.add(drop)
            self.fact_scores[(key, keep)] = self.fact_scores.get((key, keep), 0) + 1

    def load_graph(self, kg_file):
        """Load a graph written by ``save_graph``: DOT text, then a final entities line.

        The entities line is parsed with ``ast.literal_eval``, which accepts
        Python literals only and, unlike ``eval``, never executes code from the
        file. No de-duplication (and so no OpenAI call) runs on load.

        Raises:
            ValueError: if ``kg_file`` is empty or its last line is not a literal.
        """
        with open(kg_file) as f:
            lines = f.readlines()
        if not lines:
            raise ValueError(f"{kg_file} is empty")
        graph_data = "".join(lines[:-1])  # readlines() already keeps the newlines
        ents = ast.literal_eval(lines[-1])
        data = [{ent: [{rel: value} for rel, value in rels.items()]} for ent, rels in ents.items()]
        self.update_graph(json.dumps(data), clean=False)
        self.graph = pydot.graph_from_dot_data(graph_data)[0]

    def save_graph(self, filename="mygraph.kg"):
        """Write the pydot graph as DOT to ``filename``, then ``repr(self.entities)`` on a final line."""
        self.graph.write_dot(filename)
        with open(filename, "a") as f:
            f.write("\n" + repr(self.entities))