Spaces:
Runtime error
Runtime error
import re | |
from collections import namedtuple | |
from typing import Any, Dict, List, Optional, Tuple | |
# One permitted relationship in the graph schema. Consumers in this file read
# the fields positionally: index 0 is matched against the labels of the node a
# relationship starts from, index 1 against the relationship type, and index 2
# against the labels of the node the relationship points to.
Schema = namedtuple("Schema", "left_node relation right_node")
class CypherQueryCorrector:
    """
    Used to correct relationship direction in generated Cypher statements.
    This code is copied from the winner's submission to the Cypher competition:
    https://github.com/sakusaku-rich/cypher-direction-competition

    The corrector is regex based: it extracts ``(node)-[rel]-(node)`` hops from
    the query, looks up the labels of both nodes and checks the hop against the
    list of allowed ``(from_label, relation_type, to_label)`` schemas. A hop
    that is only valid in the opposite direction gets its arrow flipped; a hop
    that is valid in neither direction makes the whole correction fail.
    """

    # ``{...}`` property map inside a node, e.g. ``{name: "Tom"}``.
    property_pattern = re.compile(r"\{.+?\}")
    # Any parenthesised chunk; used to collect node variables and their labels.
    node_pattern = re.compile(r"\(.+?\)")
    # One hop: node, optional "<", "-", optional "[...]", "-", optional ">", node.
    # Groups 1 and 6 are the two nodes; groups 2 and 7 are their property maps.
    path_pattern = re.compile(
        r"(\([^\,\(\)]*?(\{.+\})?[^\,\(\)]*?\))(<?-)(\[.*?\])?(->?)(\([^\,\(\)]*?(\{.+\})?[^\,\(\)]*?\))"
    )
    # Splits a hop into the inner text of both nodes and the relation between.
    node_relation_node_pattern = re.compile(
        r"(\()+(?P<left_node>[^()]*?)\)(?P<relation>.*?)\((?P<right_node>[^()]*?)(\))+"
    )
    # Text between ":" and "]" in a relation, minus an optional property map.
    relation_type_pattern = re.compile(r":(?P<relation_type>.+?)?(\{.+\})?]")

    def __init__(self, schemas: "List[Schema]"):
        """
        Args:
            schemas: list of schemas; each item is indexed positionally as
                (from-node label, relation type, to-node label)
        """
        self.schemas = schemas

    def clean_node(self, node: str) -> str:
        """
        Strip the property map, the parentheses and surrounding whitespace.

        Args:
            node: node in string format

        Returns:
            The bare ``variable:Label1:Label2`` text of the node.
        """
        node = self.property_pattern.sub("", node)
        node = node.replace("(", "").replace(")", "")
        return node.strip()

    def _split_node(self, node: str) -> Tuple[str, List[str]]:
        """Return ``(variable, labels)`` for a node; variable is "" if anonymous."""
        variable, *labels = (
            part.strip() for part in self.clean_node(node).split(":")
        )
        return variable, labels

    def detect_node_variables(self, query: str) -> Dict[str, List[str]]:
        """
        Collect the labels attached to every named node variable in the query.

        Args:
            query: cypher query

        Returns:
            Mapping of variable name to the labels seen for it, in order of
            appearance. Anonymous nodes are not included.
        """
        res: Dict[str, List[str]] = {}
        for raw_node in self.node_pattern.findall(query):
            variable, labels = self._split_node(raw_node)
            if variable == "":
                # Anonymous nodes share no identity with each other; recording
                # them under "" would merge the labels of unrelated nodes.
                continue
            res.setdefault(variable, []).extend(labels)
        return res

    def extract_paths(self, query: str) -> "List[str]":
        """
        Extract every single-hop path, including overlapping hops of a chain.

        Args:
            query: cypher query

        Returns:
            The matched hops in order of appearance.
        """
        paths = []
        idx = 0
        while match := self.path_pattern.search(query, idx):
            paths.append(match.group(0))
            # Resume at the right-hand node so that it can serve as the
            # left-hand node of the next hop. This position is always past the
            # start of the current match, so the scan is guaranteed to advance.
            idx = match.start(6)
        return paths

    def judge_direction(self, relation: str) -> str:
        """
        Args:
            relation: relation in string format

        Returns:
            "OUTGOING" if the relation ends with ">", otherwise "INCOMING" if
            it starts with "<", otherwise "BIDIRECTIONAL".
        """
        if relation.endswith(">"):
            return "OUTGOING"
        if relation.startswith("<"):
            return "INCOMING"
        return "BIDIRECTIONAL"

    def extract_node_variable(self, part: str) -> Optional[str]:
        """
        Args:
            part: node in string format

        Returns:
            The text before the first ":" (parentheses stripped), or None when
            that text is empty.
        """
        inner = part.lstrip("(").rstrip(")")
        variable = inner.split(":", 1)[0]
        return variable or None

    def detect_labels(
        self, str_node: str, node_variable_dict: Dict[str, Any]
    ) -> List[str]:
        """
        Args:
            str_node: node in string format
            node_variable_dict: dictionary of node variables

        Returns:
            The labels recorded for the node's variable when it is known,
            the node's inline labels when it is anonymous, otherwise [].
        """
        variable, inline_labels = self._split_node(str_node)
        if variable in node_variable_dict:
            return node_variable_dict[variable]
        if variable == "":
            return inline_labels
        return []

    def verify_schema(
        self,
        from_node_labels: List[str],
        relation_types: List[str],
        to_node_labels: List[str],
    ) -> bool:
        """
        Check whether at least one schema is compatible with the given hop.

        An empty list for any argument means "unconstrained" for that position.

        Args:
            from_node_labels: labels of the from node
            relation_types: types of the relation
            to_node_labels: labels of the to node

        Returns:
            True if any schema survives all non-empty filters.
        """
        candidates = self.schemas
        constraints = (
            (0, from_node_labels),
            (2, to_node_labels),
            (1, relation_types),
        )
        for position, names in constraints:
            if not names:
                continue
            # Backticks are only quoting in Cypher, not part of the name.
            allowed = [name.strip("`") for name in names]
            candidates = [
                schema for schema in candidates if schema[position] in allowed
            ]
        return bool(candidates)

    def detect_relation_types(self, str_relation: str) -> Tuple[str, List[str]]:
        """
        Args:
            str_relation: relation in string format

        Returns:
            The direction of the relation and the list of its types
            ("|"-separated in the source, trimmed, with "!" stripped).
        """
        relation_direction = self.judge_direction(str_relation)
        found = self.relation_type_pattern.search(str_relation)
        if found is None or found.group("relation_type") is None:
            return relation_direction, []
        relation_types = [
            t.strip().strip("!") for t in found.group("relation_type").split("|")
        ]
        return relation_direction, relation_types

    def correct_query(self, query: str) -> str:
        """
        Flip relationship arrows that contradict the schema.

        Args:
            query: cypher query

        Returns:
            The query with wrongly directed relationships reversed, or "" when
            some hop matches the schema in neither direction.
        """
        node_variable_dict = self.detect_node_variables(query)
        for path in self.extract_paths(query):
            start_idx = 0
            while start_idx < len(path):
                match_res = self.node_relation_node_pattern.match(path[start_idx:])
                if match_res is None:
                    break
                left_node = match_res.group("left_node")
                relation = match_res.group("relation")
                right_node = match_res.group("right_node")

                # A hop is "(" left ")" relation "(" right ")": the inner
                # texts plus four parentheses.
                hop_len = 4 + len(left_node) + len(relation) + len(right_node)
                original_partial_path = path[start_idx : start_idx + hop_len]
                # The next hop starts at the opening "(" of the right node.
                next_idx = start_idx + len(left_node) + len(relation) + 2

                relation_direction, relation_types = self.detect_relation_types(
                    relation
                )
                if any("*" in t for t in relation_types):
                    # Variable-length relationships are not validated.
                    start_idx = next_idx
                    continue

                left_node_labels = self.detect_labels(left_node, node_variable_dict)
                right_node_labels = self.detect_labels(
                    right_node, node_variable_dict
                )

                if relation_direction == "BIDIRECTIONAL":
                    if not (
                        self.verify_schema(
                            left_node_labels, relation_types, right_node_labels
                        )
                        or self.verify_schema(
                            right_node_labels, relation_types, left_node_labels
                        )
                    ):
                        return ""
                else:
                    if relation_direction == "OUTGOING":
                        source, target = left_node_labels, right_node_labels
                        corrected_relation = "<" + relation[:-1]
                    else:  # INCOMING
                        source, target = right_node_labels, left_node_labels
                        corrected_relation = relation[1:] + ">"
                    if not self.verify_schema(source, relation_types, target):
                        if not self.verify_schema(target, relation_types, source):
                            return ""
                        # Only valid the other way round: reverse the arrow.
                        corrected_partial_path = original_partial_path.replace(
                            relation, corrected_relation
                        )
                        query = query.replace(
                            original_partial_path, corrected_partial_path
                        )
                start_idx = next_idx
        return query

    def __call__(self, query: str) -> str:
        """Correct the query to make it valid.

        If some relationship matches the schema in neither direction, an empty
        string is returned instead.

        Args:
            query: cypher query
        """
        return self.correct_query(query)