File size: 11,531 Bytes
eaa3d8a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 |
# src.kg.knowledge_graph.py
import itertools
import logging
import re
from collections import defaultdict
from itertools import combinations, product
from pathlib import Path
import networkx as nx
from .utils import strip_and_remove_empty_strings
# Module-level logger; parsing problems in this module are reported through it.
logger = logging.getLogger(__name__)
# Path of the story prompt template. It is not referenced by the code visible
# in this module -- presumably read by prompt-building code elsewhere (verify).
PROMPT_FILE_PATH = Path('templates/story-prompt.txt')
# When edges are numbered, parsing of a response stops at the first edge line
# whose number is greater than this.
MAX_RESPONSE_EDGE_COUNT = 15
# Edges whose predicate has more words than this are not added to the graph.
MAX_PREDICATE_WORD_COUNT = 5
# Maximum number of words in the possession part of a
# "<owner>'s <possession>" predicate for it to be expanded.
MAX_POSSESSION_WORD_COUNT = 2
# Two nodes are not merged when both have more than this many edges.
MAX_MERGEABLE_NODE_EDGE_COUNT = 2
# Nodes with fewer edges than this (self-loops excluded) are removed.
MIN_NODE_EDGE_COUNT = 1
class NamedEntity:
    """
    Knowledge graph node for a single named entity.

    `names` is the collection of names that refer to the entity. It is stored
    as given, not copied.
    """

    def __init__(self, names):
        self.names = names

    def __repr__(self):
        # All names of the entity, separated by slashes.
        separator = ' / '
        return separator.join(self.names)
def remove_number_prefix(text):
    """
    Return `text` without a leading "<digits>." marker and any whitespace that
    directly follows it. Text without such a marker is returned unchanged.
    """
    return re.sub(pattern=r'^\d+\.\s*', repl='', string=text)
def parse_response_text(response_text, identifier, are_edges_numbered=True):
    """
    Parse a response text from the OpenAI model into names (a list of names for
    each entity) and edges (relations between entities). `identifier` is a
    string used to identify the response text in error messages.

    The first non-empty line must contain "Named entities". The following
    lines are name groups (names of one entity separated by " / ") until a
    line containing "Knowledge graph edges" is reached; after that each line
    is an edge of the form "subjects; predicate[; objects]", where subjects
    and objects are comma-separated.

    Returns a tuple `(names, edges)`: `names` is a list of lists of names and
    `edges` is a list of `(subject, predicate, object)` tuples, where the
    object is None for two-component edges. Returns `([], [])` if the response
    is empty or does not start with the expected header. Problems are logged,
    not raised.
    """
    lines = strip_and_remove_empty_strings(response_text.split('\n'))
    # An empty response has no first line to inspect; report it instead of
    # failing with an IndexError.
    if not lines:
        logger.error(f'{identifier}: Response text is empty.')
        return [], []
    if 'Named entities' not in lines[0]:
        logger.error(f'{identifier}: First line of response text does not '
                     f'start with "Named entities:". ("{lines[0]}")')
        return [], []
    mode = 'names'
    names = []
    edges = []
    for line in lines[1:]:
        if 'Knowledge graph edges' in line:
            mode = 'edges'
            continue
        if mode == 'names':
            if line.startswith('-'):
                line = line[1:]
            name_group = strip_and_remove_empty_strings(line.split(' / '))
            name_group = [remove_number_prefix(name) for name in name_group]
            # Removing a number prefix can leave an empty string (e.g. "1.");
            # an empty name would later match any `endswith` test.
            name_group = [name for name in name_group if name]
            if name_group:
                names.append(name_group)
        elif mode == 'edges':
            if are_edges_numbered:
                number_match = re.match(r'^(\d{1,2})\. ', line)
                # The first line that is not a numbered edge ends the list.
                if not number_match:
                    break
                if int(number_match.group(1)) > MAX_RESPONSE_EDGE_COUNT:
                    break
                # Drop the whole "<number>. " prefix, whether the number has
                # one digit or two.
                line = line[number_match.end():]
            edge_components = strip_and_remove_empty_strings(line.split(';'))
            if len(edge_components) not in (2, 3):
                continue
            subjects = strip_and_remove_empty_strings(
                edge_components[0].split(','))
            predicate = edge_components[1]
            if len(edge_components) == 3:
                objects = strip_and_remove_empty_strings(
                    edge_components[2].split(','))
            else:
                objects = [None]
            # One edge per (subject, object) combination.
            for subject, object_ in product(subjects, objects):
                edges.append((subject, predicate, object_))
    if not names:
        logger.error(f'{identifier}: No names were parsed from the response '
                     f'text.')
    if not edges:
        logger.error(f'{identifier}: No edges were parsed from the response '
                     f'text.')
    return names, edges
def generate_names_graph(names):
    """
    Build an undirected graph whose nodes are names and whose edges connect
    every pair of names listed together in one group, i.e. names that refer
    to the same entity.
    """
    graph = nx.Graph()
    for aliases in names:
        graph.add_nodes_from(aliases)
        # Every pair of names within a group refers to the same entity.
        graph.add_edges_from(combinations(aliases, 2))
    return graph
def expand_contracted_possessive(predicate, names):
    """
    Check if a predicate is of the form "<owner>'s <possession>", where the
    owner is a named entity. If so, return a predicate of the form
    "<possession> of" and an object of the form "<owner>".

    The possession may have at most `MAX_POSSESSION_WORD_COUNT` words and must
    end the predicate. The owner must be a whole name from `names` that
    directly precedes the "'s": either the entire text before it, or preceded
    by a space. If several names qualify, the longest one is used. Any text
    before the owner is kept at the start of the new predicate.

    Returns `(predicate, None)` unchanged if the predicate does not have this
    form or no known name is the owner.
    """
    match = re.search(
        fr'\'s\s\w+(?:\s\w+)'
        fr'{{0,{MAX_POSSESSION_WORD_COUNT - 1}}}$', predicate)
    if not match:
        return predicate, None
    apostrophe_index = match.start()
    text_before_apostrophe = predicate[:apostrophe_index]
    # Require a word boundary before the name so that e.g. "Ann" is not found
    # inside "Joann's".
    owner_candidates = [
        name for name in names
        if name and (text_before_apostrophe == name
                     or text_before_apostrophe.endswith(' ' + name))]
    if not owner_candidates:
        return predicate, None
    # Prefer the longest name ("Sherlock Holmes" over "Holmes") so the result
    # does not depend on the iteration order of `names`.
    owner = max(owner_candidates, key=len)
    possession = predicate[apostrophe_index + 2:].strip()
    leading_text = text_before_apostrophe[
        :len(text_before_apostrophe) - len(owner)].strip()
    # `strip` removes the leading space left when nothing precedes the owner.
    predicate = f'{leading_text} {possession} of'.strip()
    return predicate, owner
def does_duplicate_edge_exist(knowledge_graph, subject, predicate, object_):
    """
    Look for an edge from `subject` to `object_` that carries `predicate` in a
    knowledge graph. Return the first such edge as yielded by
    `knowledge_graph.edges(subject, data=True)`, or None if there is none.
    """
    edges_from_subject = knowledge_graph.edges(subject, data=True)
    return next(
        (candidate for candidate in edges_from_subject
         if candidate[1] == object_
         and candidate[2]['predicate'] == predicate),
        None)
def add_edge_to_knowledge_graph(knowledge_graph, names, edge, max_predicate_word_count, **edge_attributes):
    """
    Add a `(subject, predicate, object)` edge to a knowledge graph, or
    increment the count of the matching edge if it already exists.

    The edge is ignored when the subject is not a known name, when the
    predicate has more than `max_predicate_word_count` words (if that limit is
    truthy), when subject and object are the same name, or when no graph node
    carries the subject or object name. An edge without an object becomes a
    self-loop on the subject. `edge_attributes` are stored on newly added
    edges.
    """
    subject, predicate, object_ = edge
    if subject not in names:
        return

    # An object that is not a known name is treated as part of the predicate.
    if not (object_ is None or object_ in names):
        predicate = f'{predicate} {object_}'
        object_ = None

    if object_ is None:
        # Try to recover an object from a name at the end of the predicate,
        # falling back to an "<owner>'s <possession>" form.
        trailing_name = next(
            (name for name in names if predicate.endswith(' ' + name)), None)
        if trailing_name is None:
            predicate, object_ = expand_contracted_possessive(predicate, names)
        else:
            object_ = trailing_name
            predicate = predicate[:-len(trailing_name)].strip()

    # Drop any run of trailing punctuation.
    predicate = predicate.rstrip('.,!?')

    is_too_long = (max_predicate_word_count
                   and len(predicate.split()) > max_predicate_word_count)
    if is_too_long or subject == object_:
        return

    # Edges without an object are stored as self-loops.
    if object_ is None:
        object_ = subject

    def find_node(name):
        return next(
            (node for node in knowledge_graph.nodes if name in node.names),
            None)

    subject_node = find_node(subject)
    object_node = find_node(object_)
    if subject_node is None or object_node is None:
        return

    duplicate = does_duplicate_edge_exist(
        knowledge_graph, subject_node, predicate, object_node)
    if not duplicate:
        knowledge_graph.add_edge(subject_node, object_node,
                                 predicate=predicate, count=1,
                                 **edge_attributes)
        return
    duplicate[2]['count'] += 1
def initialize_knowledge_graph(names_graph, edges):
    """
    Create a knowledge graph with one single-name node per name in
    `names_graph`, then add the edges from `edges`, a dictionary mapping a
    chapter index to the list of edges of that chapter. Each added edge
    records its chapter index.
    """
    names = set(names_graph.nodes)
    graph = nx.MultiDiGraph()
    graph.add_nodes_from(NamedEntity({name}) for name in names)
    chapter_edge_pairs = (
        (chapter_index, edge)
        for chapter_index, chapter_edges in edges.items()
        for edge in chapter_edges)
    for chapter_index, edge in chapter_edge_pairs:
        add_edge_to_knowledge_graph(
            graph, names, edge,
            max_predicate_word_count=MAX_PREDICATE_WORD_COUNT,
            chapter_index=chapter_index)
    return graph
def get_node_edge_count(knowledge_graph, node):
    """
    Count the distinct incoming and outgoing edges of a node in a knowledge
    graph, as reported by `in_edges` and `out_edges`, ignoring self-loops.
    """
    incoming = set(knowledge_graph.in_edges(node))
    outgoing = set(knowledge_graph.out_edges(node))
    non_loop_edges = [edge for edge in incoming.union(outgoing)
                      if edge[0] is not edge[1]]
    return len(non_loop_edges)
def merge_nodes(knowledge_graph, nodes_to_merge):
    """
    Merge a list of nodes in a knowledge graph into one node, combining their
    sets of names and preserving their edges.

    A new node holding the union of the names is added, the edges of each
    merged node are re-attached to it, and the merged nodes are removed.
    Re-attachment rules:

    - A self-loop on a merged node becomes a self-loop on the new node.
    - An edge to or from a node outside the merge keeps that endpoint.
    - An edge between two different merged nodes is dropped.
    - An edge is not added if the new node already has an edge with the same
      endpoints and predicate; the attributes of the first edge are kept.
    """
    nodes_to_merge = list(nodes_to_merge)
    merged_node = NamedEntity(set())
    for node in nodes_to_merge:
        merged_node.names.update(node.names)
    knowledge_graph.add_node(merged_node)

    def is_being_merged(candidate):
        # Identity comparison: nodes are distinct objects even if their names
        # overlap.
        return any(candidate is member for member in nodes_to_merge)

    for node in nodes_to_merge:
        # Snapshot the incident edges so that adding edges below cannot
        # interfere with the iteration over the graph's edge views.
        incident_edges = list(itertools.chain(
            knowledge_graph.out_edges(node, data=True),
            knowledge_graph.in_edges(node, data=True)))
        for subject, object_, attributes in incident_edges:
            if subject is object_:
                # Self-loops are listed as both an out-edge and an in-edge;
                # the duplicate check below keeps only one copy.
                new_subject = new_object = merged_node
            elif subject is node:
                if is_being_merged(object_):
                    continue
                new_subject, new_object = merged_node, object_
            else:
                if is_being_merged(subject):
                    continue
                new_subject, new_object = subject, merged_node
            if does_duplicate_edge_exist(knowledge_graph, new_subject,
                                         attributes['predicate'], new_object):
                continue
            knowledge_graph.add_edge(new_subject, new_object, **attributes)
        knowledge_graph.remove_node(node)
def merge_same_entity_nodes(knowledge_graph, names_graph):
    """
    Merge knowledge graph nodes that the graph of names marks as the same
    entity.

    For each pair of linked names, the two nodes carrying them are merged
    unless a node is missing, both names already belong to one node, there is
    an edge from the first node to the second, or both nodes have more than
    `MAX_MERGEABLE_NODE_EDGE_COUNT` edges.
    """
    def find_node(name):
        return next(
            (node for node in knowledge_graph.nodes if name in node.names),
            None)

    for first_name, second_name in names_graph.edges:
        first_node = find_node(first_name)
        if first_node is None or second_name in first_node.names:
            continue
        second_node = find_node(second_name)
        if second_node is None:
            continue
        if knowledge_graph.has_edge(first_node, second_node):
            continue
        # Merging is allowed as long as at least one node has few edges.
        smaller_edge_count = min(
            get_node_edge_count(knowledge_graph, first_node),
            get_node_edge_count(knowledge_graph, second_node))
        if smaller_edge_count > MAX_MERGEABLE_NODE_EDGE_COUNT:
            continue
        merge_nodes(knowledge_graph, [first_node, second_node])
def remove_nodes_with_few_edges(knowledge_graph):
    """
    Repeatedly remove every node with fewer than `MIN_NODE_EDGE_COUNT` edges
    (self-loops not counted) from a knowledge graph, stopping once a pass
    finds no such node.
    """
    def find_sparse_nodes():
        return [node for node in knowledge_graph.nodes
                if get_node_edge_count(knowledge_graph, node)
                < MIN_NODE_EDGE_COUNT]

    sparse_nodes = find_sparse_nodes()
    while sparse_nodes:
        knowledge_graph.remove_nodes_from(sparse_nodes)
        # Removing nodes removes their edges, which may leave further nodes
        # below the threshold.
        sparse_nodes = find_sparse_nodes()
def generate_knowledge_graph(response_texts, project_gutenberg_index):
    """
    Build the knowledge graph of a book from OpenAI API response texts.

    `response_texts` maps a chapter index to the response texts of that
    chapter. `project_gutenberg_index` is only used to label log messages.
    Every response is parsed, a node is created per name, the parsed edges
    are added, nodes naming the same entity are merged, and nodes left with
    too few edges are removed.
    """
    all_names = []
    edges_by_chapter = defaultdict(list)
    for chapter_index, chapter_response_texts in response_texts.items():
        identifier = (f'Book {project_gutenberg_index}, chapter '
                      f'{chapter_index}')
        for response_text in chapter_response_texts:
            segment_names, segment_edges = parse_response_text(
                response_text, identifier)
            all_names.extend(segment_names)
            edges_by_chapter[chapter_index].extend(segment_edges)
    names_graph = generate_names_graph(all_names)
    graph = initialize_knowledge_graph(names_graph, edges_by_chapter)
    merge_same_entity_nodes(graph, names_graph)
    remove_nodes_with_few_edges(graph)
    return graph
|