File size: 12,242 Bytes
8968917
 
 
 
 
 
 
bc6f5ad
8968917
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5bde97a
8968917
 
 
 
 
 
 
7a36a89
8968917
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5bde97a
8968917
 
 
 
 
 
 
bc6f5ad
8968917
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2c82cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8968917
 
b2c82cb
8968917
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de3d112
8968917
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
import ast
import json

import numpy as np
import openai
import pydot
import torch
from openai import OpenAI
from PIL import Image
from sentence_transformers import SentenceTransformer, util
from sklearn.metrics.pairwise import cosine_similarity

# System prompt for the extraction model (KnowledgeGraph.text_to_data). It fixes
# the JSON shape that KnowledgeGraph.update_graph parses: a list of
# {entity: [{relationship: value}, ...]} dicts, with "description" first.
system_text = """You are an expert AI that extracts knowledge graphs from text and outputs JSON files with the extracted knowledge, and nothing more. Here's how the JSON is broken down.
Entity dictionaries are organized in a list.
Every entity mentioned in the text has its own entity dictionary, in which the name of the entity is the key, and the value is a list of relationships.
Each relationship contains a short word or two accurately describing the relationship to the other entity as the key, and then the other entity as a value.
All inverses of these relationships are represented in the relationship list of the other entities. This is REALLY IMPORTANT. For example if Apple created the iPhone, it is also important to note that the iPhone was created by Apple (each entity should have this relationship from their perspective).
Non specified relationships are also inferred (if person X is the son of person Y, and person Z is person X's sibling, person Z is also the child of person Y).
The JSON contains NO NEW LINES. All the data should be on one line.
Every entity has a "description" relationship, always the first one in its list, which provides a short description of what it is in a few words. If the description references another entity, then this relationship MUST be graphed, even if it is redundant.
Relationships are only created about facts, not just any connection between two entities mentioned in the text.
Example output:
[{"Toki Pona": [{"description": "philosophical artistic constructed language"}, {"translated as": "the language of good"}, {"created by": "Sonja Lang"}, {"first published": "2001"}, {"complete form published in": "Toki Pona: The Language of Good"}, {"supplementary dictionary": "Toki Pona Dictionary"}], "Sonja Lang": [{"description": "Canadian linguist and translator"}, {"creator of": "Toki Pona"}], "Toki Pona: The Language of Good": [{"description": "book"}, {"published in": "2014"}, {"language": "Toki Pona"}], "Toki Pona Dictionary": [{"description": "dictionary"}, {"released in": "July 2021"}, {"based on": "community usage"}]}]"""

class KnowledgeGraph:
    """Knowledge graph built from LLM-extracted JSON, with embedding search.

    State kept side by side:

    * ``self.entities`` maps an entity name to a ``{relationship: value}``
      dict. The ``"description"`` key holds the entity's short description;
      several targets added for the same relationship are stored as one
      ``", "``-joined string.
    * ``self.graph`` is a ``pydot.Dot`` digraph with one node per entity and
      one labelled edge per added relationship, used for rendering.
    * ``self.entity_embeddings`` maps each entity name to the embedding of the
      name, used by :meth:`search`.
    * ``self.fact_scores`` counts, per ``(entity, relationship)``, how often
      :meth:`clean_graph` found another fact restating it.
    """

    # Chat models used for extraction/cleaning and for question answering.
    extraction_model = "gpt-4-1106-preview"
    qa_model = "gpt-3.5-turbo"

    def __init__(self, api_key=None, kg_file=""):
        """Create an empty graph, optionally loading one saved by save_graph().

        Args:
            api_key: OpenAI API key. May be ``None`` when no API call will be
                made; the OpenAI client is only created on first use, and then
                falls back to the SDK's own default configuration.
            kg_file: Path of a saved graph to load; ignored when empty.
        """
        if api_key is not None:
            openai.api_key = api_key
        self._api_key = api_key
        self._client = None  # OpenAI client, created lazily by _chat()
        self.graph = pydot.Dot(graph_type="digraph")
        self.entities = {}
        self.fact_scores = {}
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = SentenceTransformer("all-MiniLM-L6-v2").to(self.device)
        self.entity_embeddings = {}
        if kg_file:
            self.load_graph(kg_file)

    # ------------------------------------------------------------------ helpers

    def _chat(self, model, messages):
        """Send *messages* to the chat-completions endpoint and return the reply text."""
        if self._client is None:
            self._client = OpenAI(api_key=self._api_key)
        response = self._client.chat.completions.create(model=model, messages=messages)
        # The SDK may return None content; callers always expect a string.
        return response.choices[0].message.content or ""

    @staticmethod
    def _top_k(scores, n):
        """Return indices of the ``min(n, len(scores))`` highest *scores*, best first."""
        scores = np.asarray(scores)
        n = max(0, min(int(n), len(scores)))
        # Reversing the ascending argsort and taking a prefix avoids the
        # ``[-0:]`` slice, which would return every index when n == 0.
        return np.argsort(scores)[::-1][:n]

    @staticmethod
    def _fact_text(entity, relationship, value):
        """Render one fact as text, e.g. ``"Sonja Lang: is Canadian linguist"``."""
        return f"{entity}: {str(relationship).replace('description', 'is')} {value}"

    @staticmethod
    def _merge_value(existing, new):
        """Append *new* to an existing ``", "``-joined relationship value."""
        if existing is None or existing == "":
            return new
        return f"{existing}, {new}"

    @staticmethod
    def _parse_json(json_str):
        """Parse LLM output into a list of entity dicts; return None if it is unusable.

        Tolerates surrounding whitespace, a markdown code fence around the
        JSON, and a single top-level dict instead of a list.
        """
        if isinstance(json_str, (bytes, bytearray)):
            json_str = json_str.decode("utf-8", errors="replace")
        if not isinstance(json_str, str):
            return None
        text = json_str.strip()
        if text.startswith("```"):
            # Drop the opening ``` / ```json line and a trailing ``` fence.
            text = text.split("\n", 1)[1] if "\n" in text else ""
            if text.rstrip().endswith("```"):
                text = text.rstrip()[:-3]
        try:
            data = json.loads(text)
        except json.JSONDecodeError:
            return None
        if isinstance(data, dict):
            data = [data]
        return data if isinstance(data, list) else None

    @staticmethod
    def _relationship_pairs(relationships):
        """Flatten an entity's relationships into ``(relationship, value)`` pairs.

        Accepts the documented list of one-item dicts or a single dict;
        anything else is treated as having no relationships.
        """
        if isinstance(relationships, dict):
            relationships = [relationships]
        if not isinstance(relationships, list):
            return []
        return [
            (rel, value)
            for item in relationships
            if isinstance(item, dict)
            for rel, value in item.items()
        ]

    @staticmethod
    def _similar_pairs(scores, threshold=0.7):
        """Return index pairs ``(i, j)`` with ``i < j`` whose score exceeds *threshold*."""
        scores = np.asarray(scores)
        size = len(scores)
        return [
            (i, j)
            for i in range(size)
            for j in range(i + 1, size)
            if scores[i][j] > threshold
        ]

    def _linked_entities(self, value):
        """Return the known entity names referenced by a relationship *value*."""
        value = str(value)
        if value in self.entities:
            return [value]
        # Values accumulated by add_relationship are ", "-joined.
        return [part for part in value.split(", ") if part in self.entities]

    @staticmethod
    def _render_png(graph, output_file):
        """Write *graph* as a PNG to *output_file*, show it, and return the PIL image."""
        graph.write_png(output_file)
        img = Image.open(output_file)
        img.show()
        return img

    # ------------------------------------------------------------ graph editing

    def add_entity(self, name, description):
        """Add entity *name* with *description*; do nothing if it already exists.

        Adds a node labelled ``"name\\n(description)"`` to ``self.graph`` and
        stores the embedding of *name* for :meth:`search`.
        """
        if name in self.entities:
            return
        self.entities[name] = {"description": description}
        # NOTE(review): pydot may read a ':' in a node name as a port separator,
        # so names like "Toki Pona: The Language of Good" may render oddly — confirm.
        self.graph.add_node(pydot.Node(name, label=f"{name}\n({description})"))
        self.entity_embeddings[name] = self.model.encode(name)

    def add_relationship(self, entity1, relationship, entity2):
        """Record ``entity1 --relationship--> entity2`` and draw the edge.

        Does nothing if *entity1* is unknown. *entity2* may be a single value
        or a list/tuple of values, each added separately; values are stored
        as strings, and repeated targets of one relationship are ", "-joined.
        """
        if entity1 not in self.entities:
            return
        if isinstance(entity2, (list, tuple)):
            for target in entity2:
                self.add_relationship(entity1, relationship, target)
            return
        target = str(entity2)
        relations = self.entities[entity1]
        relations[relationship] = self._merge_value(relations.get(relationship), target)
        self.graph.add_edge(pydot.Edge(entity1, target, label=relationship))

    def update_graph(self, json_str, clean=True):
        """Merge LLM-produced JSON (format described by ``system_text``) into the graph.

        Each entity's ``"description"`` entry (wherever it appears in its list)
        becomes the entity description; every other entry is added as a
        relationship. Unparseable input is reported and ignored.

        Args:
            json_str: JSON text, optionally wrapped in a markdown code fence.
            clean: If true, run :meth:`clean_graph` on every entity mentioned.
        """
        data = self._parse_json(json_str)
        if data is None:
            print("GPT4 failed to create a valid JSON. Input may be too long for processing.")
            return
        touched = []
        for entity_dict in data:
            if not isinstance(entity_dict, dict):
                continue
            for entity, relationships in entity_dict.items():
                pairs = self._relationship_pairs(relationships)
                description = next((value for rel, value in pairs if rel == "description"), "")
                self.add_entity(entity, description)
                for rel, value in pairs:
                    if rel != "description":
                        self.add_relationship(entity, rel, value)
                touched.append(entity)
        if clean:
            # dict.fromkeys de-duplicates while keeping first-seen order.
            for entity in dict.fromkeys(touched):
                self.clean_graph(entity)

    def clean_graph(self, key, threshold=0.7):
        """Delete relationships of entity *key* that restate another of its facts.

        Fact pairs whose embeddings have cosine similarity above *threshold*
        are sent to the extraction model, which answers YES/NO to "same
        thing?". On YES the later fact's relationship is deleted (unless it is
        the description, in which case the earlier one is) and the kept
        relationship's ``self.fact_scores`` count is incremented.

        NOTE: only ``self.entities`` is updated; the drawn edge of a deleted
        relationship stays in ``self.graph``.
        """
        if key not in self.entities:
            return
        rels = list(self.entities[key])
        if len(rels) < 2:
            return
        facts = [self._fact_text(key, rel, self.entities[key][rel]) for rel in rels]
        fact_embs = self.model.encode(facts)
        scores = util.cos_sim(fact_embs, fact_embs)
        removed = set()
        for i, j in self._similar_pairs(scores, threshold):
            if rels[i] in removed or rels[j] in removed:
                continue  # one side was already deleted as redundant
            messages = [
                {"role": "system", "content": "You are a helpful chatbot that only outputs YES or NO"},
                {"role": "user", "content": f"Do these two facts in our database express the same thing?: {(facts[i], facts[j])}"},
            ]
            if "yes" not in self._chat(self.extraction_model, messages).lower():
                continue
            # Never drop the description; other code reads it for every entity.
            keep, drop = (j, i) if rels[j] == "description" else (i, j)
            del self.entities[key][rels[drop]]
            removed.add(rels[drop])
            score_key = (key, rels[keep])
            self.fact_scores[score_key] = self.fact_scores.get(score_key, 0) + 1

    # ---------------------------------------------------------------- rendering

    def display_graph(self, output_file="knowledge_graph.png"):
        """Render the whole graph to *output_file*, show it, and return the PIL image."""
        return self._render_png(self.graph, output_file)

    def graph_search(self, query, n=5, path="subgraph.png"):
        """Render the *n* entities nearest *query* (with their edges) to *path*.

        Shows the image and returns the opened PIL image.
        """
        subgraph = pydot.Dot(graph_type="digraph")
        for name, _ in self.search(query, n):
            relations = self.entities[name]
            description = relations.get("description", "")
            subgraph.add_node(pydot.Node(name, label=f"{name}\n({description})"))
            for rel, value in relations.items():
                if rel != "description":
                    subgraph.add_edge(pydot.Edge(name, str(value), label=rel))
        return self._render_png(subgraph, path)

    # ------------------------------------------------------------------ search

    def search(self, query, n=5):
        """Return up to *n* ``(entity_name, similarity)`` pairs nearest *query*, best first.

        Similarity is the cosine similarity between the embedding of *query*
        and each entity-name embedding. Returns ``[]`` for an empty graph.
        """
        if not self.entity_embeddings:
            return []
        names = list(self.entity_embeddings)
        query_tensor = torch.tensor(np.asarray([self.model.encode(query)]))
        entity_tensor = torch.tensor(np.stack([self.entity_embeddings[name] for name in names]))
        similarities = util.cos_sim(query_tensor, entity_tensor).numpy()[0]
        return [(names[i], similarities[i]) for i in self._top_k(similarities, n)]

    def related_entities(self, query, n=5):
        """Return up to *n* ``"name: description"`` strings most similar to *query*.

        Unlike :meth:`search`, each candidate is embedded together with its
        description. Returns ``[]`` for an empty graph.
        """
        candidates = [
            f"{name}: {relations.get('description', '')}"
            for name, relations in self.entities.items()
        ]
        if not candidates:
            return []
        query_tensor = torch.tensor(np.asarray([self.model.encode(query)]))
        entity_tensor = self.model.encode(candidates)
        similarities = util.cos_sim(query_tensor, entity_tensor).numpy()[0]
        return [candidates[i] for i in self._top_k(similarities, n)]

    def text_search(self, query, n=3):
        """Print ``name: relationships`` for the *n* entities nearest *query*."""
        for name, _ in self.search(query, n):
            print(f"{name}: {self.entities[name]}")

    def fact_search(self, query, facts, n=5):
        """Return indices into *facts* of the (at most) *n* most similar to *query*, best first."""
        if not facts:
            return self._top_k([], n)
        query_tensor = torch.tensor(np.asarray([self.model.encode(query)]))
        fact_tensor = self.model.encode(facts)
        similarities = util.cos_sim(query_tensor, fact_tensor).numpy()[0]
        return self._top_k(similarities, n)

    def qa_search(self, query, n=5):
        """Return the *n* facts, about the *n* entities nearest *query*, that best match it."""
        keys = [name for name, _ in self.search(query, n)]
        facts = [
            self._fact_text(key, rel, value)
            for key in keys
            for rel, value in self.entities[key].items()
        ]
        return [facts[i] for i in self.fact_search(query, facts, n)]

    def branch_search(self, query, num_branches=2, window_size=5):
        """Collect chains of facts relevant to *query* by following relationships.

        Starts from the facts of the *window_size* entities nearest *query*
        and keeps the best matches (see :meth:`fact_search`). Then, up to
        *num_branches* times, each kept chain is extended with the facts of
        the entity its last fact points to (when the value names a known
        entity, or a ", "-joined list of them) and the best chains are kept
        again. Chains that cannot be extended remain candidates unchanged.

        Returns the chain texts, facts joined by ``". "``.
        """
        keys = [name for name, _ in self.search(query, window_size)]
        # Each chain is (text so far, value of its last fact).
        chains = [
            (self._fact_text(key, rel, value), value)
            for key in keys
            for rel, value in self.entities[key].items()
        ]
        best = [chains[i] for i in self.fact_search(query, [text for text, _ in chains])]
        for _ in range(max(0, num_branches)):
            candidates = []
            grew = False
            for text, target in best:
                follow_ups = [
                    (f"{text}. {self._fact_text(entity, rel, value)}", value)
                    for entity in self._linked_entities(target)
                    for rel, value in self.entities[entity].items()
                ]
                if follow_ups:
                    grew = True
                    candidates.extend(follow_ups)
                else:
                    candidates.append((text, target))  # dead end: keep as an answer
            if not grew:
                break
            best = [candidates[i] for i in self.fact_search(query, [text for text, _ in candidates])]
        return [text for text, _ in best]

    # ------------------------------------------------------------------- LLM I/O

    def text_to_data(self, text):
        """Ask the extraction model to turn *text* into knowledge-graph JSON.

        When the graph already has entities, the most related ones are listed
        in the prompt so the model reuses their names. Returns the raw reply.
        """
        messages = [{"role": "system", "content": system_text}]
        try:
            related = self.related_entities(text)
        except Exception as err:  # context is best-effort; extraction works without it
            print(f"Could not look up related entities: {err}")
            related = []
        if related:
            text = text + f"\n\nGenerate the JSON for the text above, remembering to add inverse relationships and inferences. Here are some related entities already in the graph. If you are adding information about any of them, refer to them by the names below (otherwise ignore this information):\n\n{str(related)}"
        messages.append({"role": "user", "content": text})
        return self._chat(self.extraction_model, messages)

    def learn(self, text, show_output=False):
        """Extract knowledge from *text* with the LLM and merge it into the graph."""
        json_str = self.text_to_data(text)
        if show_output:
            print(json_str)
        self.update_graph(json_str)

    def chat_qa(self, query):
        """Answer *query* with the QA model, grounded in :meth:`branch_search` facts."""
        results = self.branch_search(query)
        messages = [
            {"role": "system", "content": "You are a helpful chatbot that answers questions based on data in your fact database."},
            {"role": "user", "content": f"Question: {query}\n\nFact Data: \n{results}"},
        ]
        return self._chat(self.qa_model, messages)

    # -------------------------------------------------------------- persistence

    def load_graph(self, kg_file):
        """Load a file written by :meth:`save_graph`.

        The file holds DOT source followed by a last line containing
        ``repr(self.entities)``. Entities are re-added (recomputing their
        embeddings) without running :meth:`clean_graph`, then ``self.graph``
        is replaced by the saved DOT graph.
        """
        with open(kg_file, encoding="utf-8") as f:
            lines = f.readlines()
        graph_data = "".join(lines[:-1])  # readlines() keeps each "\n"
        # literal_eval accepts only Python literals, so a tampered file cannot run code.
        ents = ast.literal_eval(lines[-1])
        data = [
            {ent: [{key: value} for key, value in relations.items()]}
            for ent, relations in ents.items()
        ]
        self.update_graph(json.dumps(data), clean=False)
        self.graph = pydot.graph_from_dot_data(graph_data)[0]

    def save_graph(self, filename="mygraph.kg"):
        """Write the DOT graph to *filename*, then one line with ``repr(self.entities)``."""
        self.graph.write_dot(filename)
        with open(filename, "a", encoding="utf-8") as f:
            f.write("\n")
            f.write(str(self.entities))