# File size: 11,531 Bytes
# eaa3d8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
# src.kg.knowledge_graph.py
import itertools
import logging
import re
from collections import defaultdict
from itertools import combinations, product
from pathlib import Path

import networkx as nx

from .utils import strip_and_remove_empty_strings

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Path of the story prompt template.
# NOTE(review): no use of this path is visible in the functions below;
# presumably it is read by the code that builds the prompt — confirm.
PROMPT_FILE_PATH = Path('templates/story-prompt.txt')
# `parse_response_text` stops reading numbered edge lines at the first line
# whose number is greater than this value.
MAX_RESPONSE_EDGE_COUNT = 15
# `initialize_knowledge_graph` passes this to `add_edge_to_knowledge_graph`,
# which skips edges whose predicate has more words than this.
MAX_PREDICATE_WORD_COUNT = 5
# Maximum number of words in the "<possession>" part recognized by
# `expand_contracted_possessive`.
MAX_POSSESSION_WORD_COUNT = 2
# `merge_same_entity_nodes` does not merge two nodes when both of them have
# more edges than this.
MAX_MERGEABLE_NODE_EDGE_COUNT = 2
# `remove_nodes_with_few_edges` removes nodes with fewer edges than this.
MIN_NODE_EDGE_COUNT = 1


class NamedEntity:
    """
    Knowledge graph node for a named entity, identified by a collection of
    alternative names.
    """

    def __init__(self, names):
        # The collection is stored by reference, not copied, so later changes
        # to it (e.g. `names.update(...)` on a set) are visible on the node.
        self.names = names

    def __repr__(self):
        # All names of the entity, in iteration order, separated by " / ".
        separator = ' / '
        return separator.join(self.names)

def remove_number_prefix(text):
    """
    Return `text` without a leading "<digits>." marker and any whitespace
    that directly follows it. Text without such a marker at its very start is
    returned unchanged.
    """
    prefix = re.match(r'^\d+\.\s*', text)
    if prefix is None:
        return text
    return text[prefix.end():]

def parse_response_text(response_text, identifier, are_edges_numbered=True):
    """
    Parse a response text from the OpenAI model into names (a list of names for
    each entity) and edges (relations between entities).

    The first non-empty line must contain "Named entities". Every following
    line is read as one entity whose names are separated by " / " (a leading
    "-" on the line and a leading "<number>." on each name are dropped), until
    a line containing "Knowledge graph edges" is found. The lines after that
    are read as edges of the form "<subjects>; <predicate>[; <objects>]",
    where subjects and objects are comma-separated; one edge is produced for
    every subject/object combination. Edge lines that do not have two or
    three ";"-separated components are skipped.

    Args:
        response_text: The raw response text.
        identifier: A string used to identify the response text in error
            messages.
        are_edges_numbered: Whether every edge line starts with "<number>. "
            (one or two digits). If so, parsing stops at the first edge line
            without such a prefix or whose number is greater than
            `MAX_RESPONSE_EDGE_COUNT`.

    Returns:
        A `(names, edges)` tuple. `names` is a list of name groups (one list
        of strings per entity line) and `edges` is a list of
        `(subject, predicate, object)` tuples in which `object` is None when
        the edge line has no object component. Both lists are empty when the
        text has no content or its first line is not the expected header.
    """
    lines = strip_and_remove_empty_strings(response_text.split('\n'))

    # An empty or whitespace-only response has no first line to inspect;
    # report it like any other malformed response instead of raising.
    if not lines:
        logger.error(f'{identifier}: Response text is empty.')
        return [], []
    if 'Named entities' not in lines[0]:
        logger.error(f'{identifier}: First line of response text does not '
                     f'start with "Named entities:". ("{lines[0]}")')
        return [], []
    mode = 'names'
    names = []
    edges = []
    for line in lines[1:]:
        if 'Knowledge graph edges' in line:
            mode = 'edges'
            continue
        if mode == 'names':
            if line.startswith('-'):
                line = line[1:]
            name_group = strip_and_remove_empty_strings(line.split(' / '))
            name_group = [remove_number_prefix(name) for name in name_group]
            names.append(name_group)
        elif mode == 'edges':
            if are_edges_numbered:
                number_prefix = re.match(r'^(\d{1,2})\. ', line)
                if not number_prefix:
                    break
                if int(number_prefix.group(1)) > MAX_RESPONSE_EDGE_COUNT:
                    break
                # Drop exactly the matched prefix: it is three characters
                # long for one-digit numbers but four for two-digit numbers.
                line = line[number_prefix.end():]
            edge_components = strip_and_remove_empty_strings(line.split(';'))
            if len(edge_components) not in (2, 3):
                continue
            subjects = strip_and_remove_empty_strings(
                edge_components[0].split(','))
            predicate = edge_components[1]
            if len(edge_components) == 3:
                objects = strip_and_remove_empty_strings(
                    edge_components[2].split(','))
            else:
                # No object component: keep a single edge per subject with a
                # None object.
                objects = [None]
            for subject, object_ in product(subjects, objects):
                edges.append((subject, predicate, object_))
    if not names:
        logger.error(f'{identifier}: No names were parsed from the response '
                     f'text.')
    if not edges:
        logger.error(f'{identifier}: No edges were parsed from the response '
                     f'text.')

    return names, edges


def generate_names_graph(names):
    """
    Build an undirected graph whose nodes are names and whose edges connect
    every pair of names listed together in one name group, i.e. names that
    refer to the same entity. `names` is an iterable of name groups.
    """
    same_entity_graph = nx.Graph()
    for group in names:
        # Register every name first so that single-name groups still
        # produce a node, then link all pairs within the group.
        same_entity_graph.add_nodes_from(group)
        same_entity_graph.add_edges_from(combinations(group, 2))
    return same_entity_graph


def expand_contracted_possessive(predicate, names,
                                 max_possession_word_count=None):
    """
    Check if a predicate is of the form "<owner>'s <possession>", where the
    owner is a named entity. If so, return a predicate of the form
    "<possession> of" and an object of the form "<owner>".

    Any text in front of the owner is kept at the start of the new predicate,
    e.g. "visited Alice's house" becomes ("visited house of", "Alice").

    Args:
        predicate: The predicate text to inspect.
        names: An iterable of known entity names.
        max_possession_word_count: Maximum number of words allowed in the
            possession; must be at least 1. Defaults to
            `MAX_POSSESSION_WORD_COUNT`.

    Returns:
        A `(predicate, object)` tuple. When the predicate does not end with
        a possessive whose owner is a known name, the predicate is returned
        unchanged and the object is None.
    """
    if max_possession_word_count is None:
        max_possession_word_count = MAX_POSSESSION_WORD_COUNT
    match = re.search(
        fr"'s\s\w+(?:\s\w+){{0,{max_possession_word_count - 1}}}$",
        predicate)
    if not match:
        return predicate, None
    before_apostrophe = predicate[:match.start()]
    owner = None
    for name in names:
        if not name or not before_apostrophe.endswith(name):
            continue
        # The name has to be a whole word (or words), not the tail of a
        # longer word: "Joann's" must not match the name "ann".
        preceding_text = before_apostrophe[:-len(name)]
        if preceding_text and not preceding_text[-1].isspace():
            continue
        # Prefer the longest matching name ("Mary Ann" over "Ann") so that
        # the result does not depend on the iteration order of `names`.
        if owner is None or len(name) > len(owner):
            owner = name
    if owner is None:
        return predicate, None
    possession = predicate[match.start() + 2:].strip()
    remainder = before_apostrophe[:-len(owner)].strip()
    # Join only the non-empty parts so that a predicate starting with the
    # owner does not yield a predicate with a leading space.
    expanded_predicate = ' '.join(
        part for part in (remainder, possession, 'of') if part)
    return expanded_predicate, owner


def does_duplicate_edge_exist(knowledge_graph, subject, predicate, object_):
    """
    Look for an edge in a knowledge graph that goes from `subject` to
    `object_` and carries the given `predicate`.

    Returns the first matching edge as the `(subject, object, data)` tuple
    yielded by `knowledge_graph.edges(subject, data=True)`, or None when no
    such edge exists.
    """
    matching_edges = (
        candidate
        for candidate in knowledge_graph.edges(subject, data=True)
        if candidate[1] == object_ and candidate[2]['predicate'] == predicate)
    return next(matching_edges, None)


def add_edge_to_knowledge_graph(knowledge_graph, names, edge, max_predicate_word_count, **edge_attributes):
    """
    Add an edge to a knowledge graph, updating count if the edge already exists.

    The edge is normalized before it is added:

    * An object that is not a known name is appended to the predicate.
    * When there is no object, a known name at the end of the predicate is
      used as the object; otherwise a contracted possessive
      ("<owner>'s <possession>") is expanded via
      `expand_contracted_possessive`.
    * Trailing ".", ",", "!" and "?" are removed from the predicate.
    * An edge that still has no object is stored as a self-loop on the
      subject node.

    The edge is silently skipped when the subject is not a known name, when
    the predicate has more than `max_predicate_word_count` words, when the
    subject equals the object, or when no node holds the subject or object
    name.

    Args:
        knowledge_graph: The graph to update; its nodes expose a `names`
            collection.
        names: The collection of known entity names.
        edge: A `(subject, predicate, object)` tuple; `object` may be None.
        max_predicate_word_count: Maximum number of words allowed in the
            final predicate; a falsy value disables the limit.
        **edge_attributes: Extra attributes stored on a newly added edge.
    """
    subject, predicate, object_ = edge
    if subject not in names:
        return
    if object_ is not None and object_ not in names:
        # The object is not a known entity: fold it back into the predicate.
        predicate += f' {object_}'
        object_ = None
    # Strip trailing punctuation before looking for an entity in the
    # predicate, so that "visited Bob." still links to "Bob".
    predicate = predicate.rstrip('.,!?')
    if object_ is None:
        trailing_names = [name for name in names
                          if name and predicate.endswith(' ' + name)]
        if trailing_names:
            # Several names can match (e.g. "Ann" and "Mary Ann"); take the
            # longest so the result does not depend on the iteration order
            # of `names`.
            object_ = max(trailing_names, key=len)
            predicate = predicate[:-len(object_)].strip()
        else:
            predicate, object_ = expand_contracted_possessive(predicate, names)
    # Removing the object may have exposed further trailing punctuation.
    predicate = predicate.rstrip('.,!?')
    if (max_predicate_word_count and len(predicate.split()) > max_predicate_word_count):
        return
    if subject == object_:
        return
    if object_ is None:
        # An edge without an object is stored as a self-loop on the subject.
        object_ = subject
    subject_node = next((node for node in knowledge_graph.nodes if subject in node.names), None)
    object_node = next((node for node in knowledge_graph.nodes if object_ in node.names), None)

    if subject_node is None or object_node is None:
        return

    existing_edge = does_duplicate_edge_exist(knowledge_graph, subject_node, predicate, object_node)
    if existing_edge is not None:
        existing_edge[2]['count'] += 1
    else:
        knowledge_graph.add_edge(subject_node, object_node, predicate=predicate, count=1, **edge_attributes)


def initialize_knowledge_graph(names_graph, edges):
    """
    Create a knowledge graph with one single-name `NamedEntity` node for every
    name in `names_graph`, then add the given edges to it.

    `edges` maps a chapter index to the edges parsed for that chapter; the
    chapter index is stored on every added edge as `chapter_index`. Predicates
    are limited to `MAX_PREDICATE_WORD_COUNT` words.
    """
    entity_names = set(names_graph.nodes)
    graph = nx.MultiDiGraph()
    for entity_name in entity_names:
        graph.add_node(NamedEntity({entity_name}))
    # Flatten the per-chapter edge lists into (chapter index, edge) pairs.
    indexed_edges = (
        (index, chapter_edge)
        for index, chapter_edges in edges.items()
        for chapter_edge in chapter_edges)
    for index, chapter_edge in indexed_edges:
        add_edge_to_knowledge_graph(
            graph, entity_names, chapter_edge,
            max_predicate_word_count=MAX_PREDICATE_WORD_COUNT,
            chapter_index=index)
    return graph


def get_node_edge_count(knowledge_graph, node):
    """
    Count the edges that touch `node` in a knowledge graph, ignoring
    self-loops.

    Incoming and outgoing edges are collected into one set of
    (source, target) pairs, so several edges between the same ordered pair
    of nodes count once.
    """
    incoming = knowledge_graph.in_edges(node)
    outgoing = knowledge_graph.out_edges(node)
    distinct_pairs = set(incoming).union(outgoing)
    return len([pair for pair in distinct_pairs if pair[0] is not pair[1]])


def merge_nodes(knowledge_graph, nodes_to_merge):
    """
    Merge a list of nodes in a knowledge graph into one node, combining their
    sets of names and preserving their edges.

    A new `NamedEntity` holding the union of all names is added to the graph
    and the original nodes are removed. Every edge of an original node is
    re-attached to the merged node:

    * a self-loop becomes a self-loop on the merged node;
    * an edge to or from a node outside `nodes_to_merge` keeps its other
      endpoint;
    * an edge between two different nodes of `nodes_to_merge` is dropped.

    An edge is not re-added when the merged graph already has an edge with
    the same endpoints and predicate; the attributes of the first such edge
    are kept.

    Args:
        knowledge_graph: The graph to modify in place.
        nodes_to_merge: The nodes to replace by a single merged node.
    """
    nodes_to_merge = list(nodes_to_merge)
    merged_node = NamedEntity(set())
    for node in nodes_to_merge:
        merged_node.names.update(node.names)
    knowledge_graph.add_node(merged_node)

    def is_being_merged(candidate):
        # Nodes are compared by identity, as everywhere else in this module.
        return any(candidate is node for node in nodes_to_merge)

    for node in nodes_to_merge:
        # Take a snapshot: edges are added to the graph inside the loop.
        node_edges = list(itertools.chain(
            knowledge_graph.out_edges(node, data=True),
            knowledge_graph.in_edges(node, data=True)))
        for subject, object_, attributes in node_edges:
            if subject is object_:
                new_subject = new_object = merged_node
            elif subject is node:
                if is_being_merged(object_):
                    continue
                new_subject, new_object = merged_node, object_
            else:
                if is_being_merged(subject):
                    continue
                new_subject, new_object = subject, merged_node
            # Check for a duplicate between the *new* endpoints. This also
            # covers self-loops, which show up among both the outgoing and
            # the incoming edges of a node.
            if does_duplicate_edge_exist(knowledge_graph, new_subject,
                                         attributes['predicate'],
                                         new_object) is not None:
                continue
            knowledge_graph.add_edge(new_subject, new_object, **attributes)
        knowledge_graph.remove_node(node)

def merge_same_entity_nodes(knowledge_graph, names_graph):
    """
    Merge the knowledge graph nodes of names that `names_graph` links as
    referring to the same entity.

    For every linked pair of names, the two nodes holding the names are
    merged unless one of the names has no node, both names already belong to
    the same node, the graph has an edge from the first node to the second,
    or both nodes have more than `MAX_MERGEABLE_NODE_EDGE_COUNT` edges.

    NOTE(review): only an edge from the first to the second node blocks a
    merge; an edge in the opposite direction does not. Confirm whether that
    asymmetry is intended.
    """
    def find_node(name):
        # First node of the knowledge graph that carries `name`, if any.
        for candidate in knowledge_graph.nodes:
            if name in candidate.names:
                return candidate
        return None

    for first_name, second_name in names_graph.edges:
        first_node = find_node(first_name)
        if first_node is None or second_name in first_node.names:
            continue
        second_node = find_node(second_name)
        if (second_node is None
                or knowledge_graph.has_edge(first_node, second_node)):
            continue
        edge_counts = [get_node_edge_count(knowledge_graph, node)
                       for node in (first_node, second_node)]
        # Skip the merge only when *both* nodes are well connected.
        if min(edge_counts) > MAX_MERGEABLE_NODE_EDGE_COUNT:
            continue
        merge_nodes(knowledge_graph, [first_node, second_node])



def remove_nodes_with_few_edges(knowledge_graph):
    """
    Prune a knowledge graph in place: drop every node whose edge count
    (self-loops excluded) is below `MIN_NODE_EDGE_COUNT`, and keep doing so
    until a pass finds no such node, since a removal can lower the edge
    count of the remaining nodes.
    """
    def find_sparse_nodes():
        return [node for node in knowledge_graph.nodes
                if get_node_edge_count(knowledge_graph, node)
                < MIN_NODE_EDGE_COUNT]

    sparse_nodes = find_sparse_nodes()
    while sparse_nodes:
        knowledge_graph.remove_nodes_from(sparse_nodes)
        sparse_nodes = find_sparse_nodes()


def generate_knowledge_graph(response_texts, project_gutenberg_index):
    """
    Build the knowledge graph of a book from OpenAI API response texts.

    `response_texts` maps a chapter index to the response texts of that
    chapter. All responses are parsed, the parsed names and edges are turned
    into a graph, nodes of the same entity are merged and nodes with too few
    edges are removed. `project_gutenberg_index` is only used to label error
    messages.
    """
    all_names = []
    edges_by_chapter = defaultdict(list)
    for chapter_index, chapter_response_texts in response_texts.items():
        # The same label applies to every response of the chapter.
        identifier = f'Book {project_gutenberg_index}, chapter {chapter_index}'
        for response_text in chapter_response_texts:
            parsed_names, parsed_edges = parse_response_text(
                response_text, identifier)
            all_names.extend(parsed_names)
            edges_by_chapter[chapter_index].extend(parsed_edges)
    names_graph = generate_names_graph(all_names)
    knowledge_graph = initialize_knowledge_graph(names_graph,
                                                 edges_by_chapter)
    merge_same_entity_nodes(knowledge_graph, names_graph)
    remove_nodes_with_few_edges(knowledge_graph)
    return knowledge_graph