# src.kg.knowledge_graph.py

import itertools
import logging
import re
from collections import defaultdict
from itertools import combinations, product
from pathlib import Path

import networkx as nx

from .utils import strip_and_remove_empty_strings

logger = logging.getLogger(__name__)

PROMPT_FILE_PATH = Path('templates/story-prompt.txt')
MAX_RESPONSE_EDGE_COUNT = 15
MAX_PREDICATE_WORD_COUNT = 5
MAX_POSSESSION_WORD_COUNT = 2
MAX_MERGEABLE_NODE_EDGE_COUNT = 2
MIN_NODE_EDGE_COUNT = 1


class NamedEntity:
    """A knowledge graph node representing a named entity."""

    def __init__(self, names):
        self.names = names

    def __repr__(self):
        return ' / '.join(self.names)


def remove_number_prefix(text):
    clean_text = re.sub(r'^\d+\.\s*', '', text)
    return clean_text
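
# Illustrative examples (not part of the original module):
#
#     remove_number_prefix('3. Mr. Darcy')        # -> 'Mr. Darcy'
#     repr(NamedEntity({'Mr. Darcy', 'Darcy'}))   # -> e.g. 'Mr. Darcy / Darcy'
#                                                 #    (order depends on the set)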


def parse_response_text(response_text, identifier, are_edges_numbered=True):
    """
    Parse a response text from the OpenAI model into names (a list of names
    for each entity) and edges (relations between entities). `identifier` is
    a string used to identify the response text in error messages.
    """
    lines = strip_and_remove_empty_strings(response_text.split('\n'))
    if 'Named entities' not in lines[0]:
        logger.error(f'{identifier}: First line of response text does not '
                     f'start with "Named entities:". ("{lines[0]}")')
        return [], []
    mode = 'names'
    names = []
    edges = []
    for line in lines[1:]:
        if 'Knowledge graph edges' in line:
            mode = 'edges'
            continue
        if mode == 'names':
            if line.startswith('-'):
                line = line[1:]
            name_group = strip_and_remove_empty_strings(line.split(' / '))
            name_group = [remove_number_prefix(name) for name in name_group]
            names.append(name_group)
        elif mode == 'edges':
            if are_edges_numbered:
                if not re.match(r'^\d{1,2}\. ', line):
                    break
                if int(line.split('.')[0]) > MAX_RESPONSE_EDGE_COUNT:
                    break
                line = line[3:]
            edge_components = strip_and_remove_empty_strings(line.split(';'))
            if len(edge_components) not in (2, 3):
                continue
            subjects = strip_and_remove_empty_strings(
                edge_components[0].split(','))
            predicate = edge_components[1]
            if len(edge_components) == 3:
                objects = strip_and_remove_empty_strings(
                    edge_components[2].split(','))
            else:
                objects = [None]
            for subject, object_ in product(subjects, objects):
                edge = (subject, predicate, object_)
                edges.append(edge)
    if not names:
        logger.error(f'{identifier}: No names were parsed from the response '
                     f'text.')
    if not edges:
        logger.error(f'{identifier}: No edges were parsed from the response '
                     f'text.')
    return names, edges
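
# Illustrative example (not part of the original module; the book, names, and
# relations are made up). Assuming `strip_and_remove_empty_strings` strips
# whitespace and drops empty items, a response such as
#
#     response_text = ('Named entities:\n'
#                      '1. Elizabeth Bennet / Elizabeth / Lizzy\n'
#                      '2. Mr. Darcy / Darcy\n'
#                      'Knowledge graph edges:\n'
#                      '1. Elizabeth; admires; Darcy\n'
#                      '2. Darcy; proposes to; Elizabeth')
#
# would be parsed as
#
#     names, edges = parse_response_text(response_text, 'Book 1342, chapter 1')
#     # names -> [['Elizabeth Bennet', 'Elizabeth', 'Lizzy'],
#     #           ['Mr. Darcy', 'Darcy']]
#     # edges -> [('Elizabeth', 'admires', 'Darcy'),
#     #           ('Darcy', 'proposes to', 'Elizabeth')]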


def generate_names_graph(names):
    """
    Generate a graph of names where the nodes are names and the edges
    indicate that two names refer to the same entity.
    """
    names_graph = nx.Graph()
    for name_group in names:
        for name in name_group:
            names_graph.add_node(name)
        for name_pair in combinations(name_group, 2):
            names_graph.add_edge(*name_pair)
    return names_graph
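
# Illustrative example (hypothetical data): each name group becomes a clique,
# so aliases of the same entity end up in one connected component.
#
#     g = generate_names_graph([['Elizabeth Bennet', 'Elizabeth', 'Lizzy'],
#                               ['Mr. Darcy', 'Darcy']])
#     # g has 5 nodes; 'Elizabeth' and 'Lizzy' are connected, while
#     # 'Elizabeth' and 'Darcy' are not.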


def expand_contracted_possessive(predicate, names):
    """
    Check if a predicate is of the form "<owner>'s <possession>", where the
    owner is a named entity. If so, return a predicate of the form
    "<possession> of" and an object of the form "<owner>".
    """
    match = re.search(
        fr'\'s\s\w+(?:\s\w+)'
        fr'{{0,{MAX_POSSESSION_WORD_COUNT - 1}}}$', predicate)
    if not match:
        return predicate, None
    apostrophe_index = match.start()
    owner = next(
        (name for name in names
         if predicate[:apostrophe_index].endswith(name)), None)
    if owner is None:
        return predicate, None
    possession = predicate[apostrophe_index + 2:].strip()
    predicate = (f'{predicate[:apostrophe_index - len(owner)].strip()} '
                 f'{possession} of')
    object_ = owner
    return predicate, object_
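
# Illustrative examples (hypothetical names):
#
#     expand_contracted_possessive("holds Tom's hand", {'Tom', 'Huck'})
#     # -> ('holds hand of', 'Tom')
#     expand_contracted_possessive("at death's door", {'Tom', 'Huck'})
#     # -> ("at death's door", None), since 'death' is not a known name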


def does_duplicate_edge_exist(knowledge_graph, subject, predicate, object_):
    """
    Check if an edge with a given subject, predicate, and object already
    exists in a knowledge graph. If it exists, return the edge (including its
    attributes); otherwise, return None.
    """
    for edge in knowledge_graph.edges(subject, data=True):
        if edge[1] == object_ and edge[2]['predicate'] == predicate:
            return edge
    return None


def add_edge_to_knowledge_graph(knowledge_graph, names, edge,
                                max_predicate_word_count, **edge_attributes):
    """
    Add an edge to a knowledge graph, updating count if the edge already
    exists.
    """
    subject, predicate, object_ = edge
    if subject not in names:
        return
    if object_ is not None and object_ not in names:
        predicate += f' {object_}'
        object_ = None
    if object_ is None:
        object_at_end_of_predicate = next(
            (name for name in names if predicate.endswith(' ' + name)), None)
        if object_at_end_of_predicate is not None:
            object_ = object_at_end_of_predicate
            predicate = predicate[:-len(object_)].strip()
        else:
            predicate, object_ = expand_contracted_possessive(
                predicate, names)
    while predicate.endswith(('.', ',', '!', '?')):
        predicate = predicate[:-1]
    if (max_predicate_word_count
            and len(predicate.split()) > max_predicate_word_count):
        return
    if subject == object_:
        return
    if object_ is None:
        object_ = subject
    subject_node = next(
        (node for node in knowledge_graph.nodes if subject in node.names),
        None)
    object_node = next(
        (node for node in knowledge_graph.nodes if object_ in node.names),
        None)
    if subject_node is None or object_node is None:
        return
    existing_edge = does_duplicate_edge_exist(
        knowledge_graph, subject_node, predicate, object_node)
    if existing_edge:
        existing_edge[2]['count'] += 1
    else:
        knowledge_graph.add_edge(subject_node, object_node,
                                 predicate=predicate, count=1,
                                 **edge_attributes)
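
# Illustrative sketch (hypothetical data; assumes the graph already contains
# NamedEntity nodes for 'Tom' and 'Huck'): the object is either kept, pulled
# out of the predicate, or folded into it.
#
#     names = {'Tom', 'Huck'}
#     # ('Tom', 'travels with', 'Huck')     -> edge Tom -[travels with]-> Huck
#     # ('Tom', "carries Huck's bag", None) -> edge Tom -[carries bag of]-> Huck
#     # ('Tom', 'lives in', 'the village')  -> 'the village' is not a known
#     #     name, so it is folded into the predicate and the edge becomes a
#     #     self-loop on Tom with predicate 'lives in the village'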


def initialize_knowledge_graph(names_graph, edges):
    """
    Initialize a knowledge graph from a graph of names and a dictionary of
    edges grouped by chapter index.
    """
    names = set(names_graph.nodes)
    knowledge_graph = nx.MultiDiGraph()
    for name in names:
        knowledge_graph.add_node(NamedEntity({name}))
    for chapter_index, chapter_edges in edges.items():
        for edge in chapter_edges:
            add_edge_to_knowledge_graph(
                knowledge_graph, names, edge,
                max_predicate_word_count=MAX_PREDICATE_WORD_COUNT,
                chapter_index=chapter_index)
    return knowledge_graph
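
# Illustrative sketch (hypothetical data): every name starts as its own
# single-name NamedEntity node, and edges are added chapter by chapter.
#
#     names_graph = generate_names_graph([['Tom Sawyer', 'Tom'], ['Huck']])
#     kg = initialize_knowledge_graph(
#         names_graph, {0: [('Tom', 'travels with', 'Huck')]})
#     # kg is a MultiDiGraph with three nodes ({'Tom Sawyer'}, {'Tom'},
#     # {'Huck'}) and one edge carrying
#     # {'predicate': 'travels with', 'count': 1, 'chapter_index': 0}.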


def get_node_edge_count(knowledge_graph, node):
    """
    Get the number of edges for a node in a knowledge graph, excluding
    self-loops.
    """
    edges = (set(knowledge_graph.in_edges(node))
             | set(knowledge_graph.out_edges(node)))
    edge_count = sum(1 for edge in edges if edge[0] is not edge[1])
    return edge_count


def merge_nodes(knowledge_graph, nodes_to_merge):
    """
    Merge a list of nodes in a knowledge graph into one node, combining their
    sets of names and preserving their edges.
    """
    merged_node = NamedEntity(set())
    for node in nodes_to_merge:
        merged_node.names.update(node.names)
    knowledge_graph.add_node(merged_node)
    for node in nodes_to_merge:
        for edge in itertools.chain(
                knowledge_graph.out_edges(node, data=True),
                knowledge_graph.in_edges(node, data=True)):
            subject, object_, attributes = edge
            if (does_duplicate_edge_exist(knowledge_graph, merged_node,
                                          attributes['predicate'], object_)
                    or does_duplicate_edge_exist(knowledge_graph, subject,
                                                 attributes['predicate'],
                                                 merged_node)):
                continue
            if subject is object_:
                knowledge_graph.add_edge(merged_node, merged_node,
                                         **attributes)
            elif subject is node:
                knowledge_graph.add_edge(merged_node, object_, **attributes)
            else:
                knowledge_graph.add_edge(subject, merged_node, **attributes)
        knowledge_graph.remove_node(node)
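
# Illustrative sketch (hypothetical node references, continuing the example
# above):
#
#     merge_nodes(kg, [tom_sawyer_node, tom_node])
#     # -> a single node with names {'Tom Sawyer', 'Tom'}; the 'travels with'
#     #    edge now starts at the merged node, and duplicate edges are
#     #    skipped.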


def merge_same_entity_nodes(knowledge_graph, names_graph):
    """
    Using a graph of names, merge nodes in a knowledge graph corresponding to
    the same entity.
    """
    for name_pair in names_graph.edges:
        first_node = next((node for node in knowledge_graph.nodes
                           if name_pair[0] in node.names), None)
        if first_node is None:
            continue
        if name_pair[1] in first_node.names:
            continue
        second_node = next((node for node in knowledge_graph.nodes
                            if name_pair[1] in node.names), None)
        if second_node is None:
            continue
        if knowledge_graph.has_edge(first_node, second_node):
            continue
        first_node_edge_count = get_node_edge_count(knowledge_graph,
                                                    first_node)
        second_node_edge_count = get_node_edge_count(knowledge_graph,
                                                     second_node)
        if (first_node_edge_count > MAX_MERGEABLE_NODE_EDGE_COUNT
                and second_node_edge_count > MAX_MERGEABLE_NODE_EDGE_COUNT):
            continue
        merge_nodes(knowledge_graph, [first_node, second_node])


def remove_nodes_with_few_edges(knowledge_graph):
    """
    Remove nodes that have fewer than `MIN_NODE_EDGE_COUNT` edges (excluding
    self-loops) from a knowledge graph. Repeat until no more nodes are
    removed.
    """
    while True:
        nodes_to_remove = []
        for node in knowledge_graph.nodes:
            edge_count = get_node_edge_count(knowledge_graph, node)
            if edge_count < MIN_NODE_EDGE_COUNT:
                nodes_to_remove.append(node)
        if not nodes_to_remove:
            break
        knowledge_graph.remove_nodes_from(nodes_to_remove)


def generate_knowledge_graph(response_texts, project_gutenberg_index):
    """
    Use OpenAI API response texts grouped by chapter index to generate a
    knowledge graph for a book.
    """
    names = []
    edges = defaultdict(list)
    for chapter_index, chapter_response_texts in response_texts.items():
        for response_text in chapter_response_texts:
            identifier = (f'Book {project_gutenberg_index}, chapter '
                          f'{chapter_index}')
            chapter_segment_names, chapter_segment_edges = (
                parse_response_text(response_text, identifier))
            names.extend(chapter_segment_names)
            edges[chapter_index].extend(chapter_segment_edges)
    names_graph = generate_names_graph(names)
    knowledge_graph = initialize_knowledge_graph(names_graph, edges)
    merge_same_entity_nodes(knowledge_graph, names_graph)
    remove_nodes_with_few_edges(knowledge_graph)
    return knowledge_graph
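
# Illustrative end-to-end sketch (hypothetical data): `response_texts` maps a
# chapter index to the list of model responses for that chapter's segments,
# and `project_gutenberg_index` is only used in log messages.
#
#     response_texts = {
#         0: ['Named entities:\n'
#             '1. Elizabeth Bennet / Elizabeth\n'
#             '2. Mr. Darcy / Darcy\n'
#             'Knowledge graph edges:\n'
#             '1. Elizabeth; admires; Darcy'],
#     }
#     kg = generate_knowledge_graph(response_texts, 1342)
#     # kg ends up with two nodes (Elizabeth Bennet / Elizabeth and
#     # Mr. Darcy / Darcy) connected by an 'admires' edge; nodes with fewer
#     # than MIN_NODE_EDGE_COUNT edges are pruned.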