reach-vb's picture
reach-vb HF staff
544099ee60f85f9b5462f64b19d52cabecc6d7f3e7dc75f037db0030e18c7d16
a7e4fab
raw
history blame
8.68 kB
"""
Contains tests and a prototype implementation for the fanout algorithm in
the LLVM refprune pass.
"""
try:
from graphviz import Digraph
except ImportError:
pass
from collections import defaultdict
# Name of the function's entry basic block; every caseNN() graph uses "A".
# A backward walk that reaches this block means some path bypasses the incref.
ENTRY = "A"
# The following caseNN() functions return a 3-tuple of
# (nodes, edges, expected).
# `nodes` maps each BB node to the list of incref/decref operations inside it.
# `edges` maps each BB node to the list of its successor BBs.
# `expected` maps each BB node holding an incref to the set of BB nodes
# holding the matching decrefs, or to None when the prune is invalid.
def case1():
    """Base CFG: incref in D, decrefs in H and F; a valid prune.

    Returns ``(nodes, edges, expected)``. New containers are built on every
    call, so derived cases may mutate the result freely.
    """
    edges = dict(
        A=["B"],
        B=["C", "D"],
        C=[],
        D=["E", "F"],
        E=["G"],
        F=[],
        G=["H", "I"],
        I=["G", "F"],
        H=["J", "K"],
        J=["L", "M"],
        K=[],
        L=["Z"],
        M=["Z", "O", "P"],
        O=["Z"],
        P=["Z"],
        Z=[],
    )
    nodes = defaultdict(
        list, D=["incref"], H=["decref"], F=["decref", "decref"]
    )
    return nodes, edges, dict(D={"H", "F"})
def case2():
    """Triangle A -> {B, C}, B -> C with decrefs in both B and C: invalid.

    The path A -> B -> C passes through two decref blocks.
    """
    nodes = defaultdict(list)
    nodes.update(A=["incref"], B=["decref"], C=["decref"])
    edges = {"A": ["B", "C"], "B": ["C"], "C": []}
    return nodes, edges, {"A": None}
def case3():
    """case1 plus an edge H -> F: invalid.

    H and F both hold decrefs, so the path D -> E -> G -> H -> F would pass
    through two decref blocks.
    """
    nodes, edges, _unused = case1()
    edges["H"] += ["F"]
    return nodes, edges, {"D": None}
def case4():
    """case1 plus an edge H -> E: invalid.

    The decref block H becomes reachable from itself (H -> E -> G -> H), so a
    loop iteration could run its decref again.
    """
    nodes, edges, _unused = case1()
    edges["H"] += ["E"]
    return nodes, edges, {"D": None}
def case5():
    """case1 plus an edge B -> I: invalid.

    I, below the incref block D, becomes reachable from B without passing
    through D, so walking backwards from the decrefs reaches the entry block.
    """
    nodes, edges, _unused = case1()
    edges["B"] += ["I"]
    return nodes, edges, {"D": None}
def case6():
    """case1 plus an edge I -> B: invalid.

    The loop block I can now jump back above the incref to B, re-entering the
    incref block D (D -> E -> G -> I -> B -> D).
    """
    nodes, edges, _unused = case1()
    edges["I"] += ["B"]
    return nodes, edges, {"D": None}
def case7():
    """case1 plus an edge I -> M: invalid.

    The region can now jump forward to M, whose paths (e.g. M -> Z) leave the
    function without reaching any decref.
    """
    nodes, edges, _unused = case1()
    edges["I"] += ["M"]
    return nodes, edges, {"D": None}
def case8():
    """Triangle A -> {B, C}, B -> C; incref in A, decref only in C: valid.

    Returns ``(nodes, edges, expected)`` built from new containers, so the
    derived cases may mutate them.
    """
    edges = dict(A=["B", "C"], B=["C"], C=[])
    nodes = defaultdict(list, A=["incref"], C=["decref"])
    return nodes, edges, dict(A={"C"})
def case9():
    """case8 plus a back edge C -> B: invalid.

    The decref block C can reach B, which leads straight back to C, so the
    decref could run repeatedly without the incref.
    """
    nodes, edges, _unused = case8()
    edges["C"] += ["B"]
    return nodes, edges, {"A": None}
def case10():
    """case8 plus a back edge C -> A to the incref block itself: still valid.

    Every trip around the loop runs the incref in A again before reaching C.
    """
    nodes, edges, _unused = case8()
    edges["C"] += ["A"]
    return nodes, edges, {"A": {"C"}}
def case11():
    """case8 plus an exit block D after the decref block C: still valid."""
    nodes, edges, _unused = case8()
    edges["C"] += ["D"]
    edges.update(D=[])
    return nodes, edges, {"A": {"C"}}
def case12():
    """case8 plus a loop C -> D -> A back to the incref block: still valid."""
    nodes, edges, _unused = case8()
    edges["C"] += ["D"]
    edges.update(D=["A"])
    return nodes, edges, {"A": {"C"}}
def case13():
    """case8 plus a loop C -> D -> B: invalid.

    B leads back to the decref block C without passing the incref in A.
    """
    nodes, edges, _unused = case8()
    edges["C"] += ["D"]
    edges.update(D=["B"])
    return nodes, edges, {"A": None}
def make_predecessor_map(edges):
    """Invert a successor map into a predecessor map.

    Returns a ``defaultdict(set)`` mapping each block to the set of blocks
    with an edge into it. Blocks without predecessors are absent, and looking
    one up yields an empty set.
    """
    arcs = ((dst, src) for src, dsts in edges.items() for dst in dsts)
    preds = defaultdict(set)
    for dst, src in arcs:
        preds[dst].add(src)
    return preds
class FanoutAlgorithm:
    """Prototype of the refprune "fanout" analysis.

    For every basic block holding an ``incref``, look for a set of decref
    blocks such that:

    * every path out of the incref block reaches a decref block within a
      bounded search depth, without looping back to the incref block and
      without leaving the function first; and
    * walking the CFG backwards from each decref block, stopping at the
      incref block, never reaches the entry block or another decref block.

    Parameters
    ----------
    nodes : mapping
        Block name -> list of ``"incref"`` / ``"decref"`` operations. Blocks
        missing from the mapping are treated as having no operations.
    edges : mapping
        Block name -> list of successor block names.
    verbose : bool
        When true, trace the search with ``print``.
    """

    def __init__(self, nodes, edges, verbose=False):
        self.nodes = nodes
        self.edges = edges
        # Reverse CFG, used by the backward walk in verify_non_overlapping.
        self.rev_edges = make_predecessor_map(edges)
        self.print = print if verbose else self._null_print

    def run(self):
        """Run the analysis over every block; see find_fanout_in_function."""
        return self.find_fanout_in_function()

    def _null_print(self, *args, **kwargs):
        """Discard trace output (installed as ``self.print`` when not verbose)."""

    def _ops(self, node):
        """Return the operations of ``node``, or an empty tuple if it has none.

        Uses ``.get`` so that a plain-dict ``nodes`` works and a defaultdict
        is not grown as a side effect of the analysis.
        """
        return self.nodes.get(node, ())

    def find_fanout_in_function(self):
        """Map each incref block to its decref blocks, or to None if invalid."""
        got = {}
        for cur_node in self.edges:
            # The result depends only on the block, so a block holding
            # several increfs is analysed once.
            if "incref" in self._ops(cur_node):
                decref_blocks = self.find_fanout(cur_node)
                self.print(">>", cur_node, "===", decref_blocks)
                got[cur_node] = decref_blocks
        return got

    def find_fanout(self, head_node):
        """Return the set of decref blocks pairing with the incref in
        ``head_node``, or None if no valid prune exists."""
        decref_blocks = self.find_decref_candidates(head_node)
        self.print("candidates", decref_blocks)
        if not decref_blocks:
            return None
        if not self.verify_non_overlapping(
            head_node, decref_blocks, entry=ENTRY
        ):
            return None
        return set(decref_blocks)

    def verify_non_overlapping(self, head_node, decref_blocks, entry):
        """Check that every decref block is only reachable via ``head_node``.

        Walks predecessors backwards from each decref block, not expanding
        past ``head_node``. Returns False if the walk reaches ``entry`` (a
        path bypasses the incref) or meets a predecessor that is itself a
        decref block (one path could run two decrefs); True otherwise.
        """
        self.print("verify_non_overlapping".center(80, "-"))
        for start in decref_blocks:
            visited = set()
            workstack = [start]
            while workstack:
                cur_node = workstack.pop()
                self.print("cur_node", cur_node, "|", workstack)
                if cur_node in visited:
                    continue
                if cur_node == entry:
                    self.print(
                        "!! failed because we arrived at entry", cur_node
                    )
                    return False
                visited.add(cur_node)
                preds = self.get_predecessors(cur_node)
                self.print(f" {cur_node} preds {preds}")
                for pred in preds:
                    if pred in decref_blocks:
                        self.print(
                            "!! reject because predecessor in decref_blocks"
                        )
                        return False
                    if pred != head_node:
                        workstack.append(pred)
        return True

    def get_successors(self, node):
        """Return the successors of ``node`` as a tuple."""
        return tuple(self.edges[node])

    def get_predecessors(self, node):
        """Return the predecessors of ``node`` as a tuple."""
        return tuple(self.rev_edges[node])

    def has_decref(self, node):
        """Return True if ``node`` contains a decref."""
        return "decref" in self._ops(node)

    def walk_child_for_decref(
        self, cur_node, path_stack, decref_blocks, depth=10
    ):
        """Depth-first search from ``cur_node`` for covering decref blocks.

        Decref blocks that are found are added to ``decref_blocks``. Returns
        True if every path from ``cur_node`` reaches a decref block or
        re-enters a node already on ``path_stack`` (other than its head).
        Returns False if a path loops back to the head node
        (``path_stack[0]``), leaves the function without a decref, or needs
        more than ``depth`` blocks.
        """
        indent = " " * len(path_stack)
        self.print(indent, "walk", path_stack, cur_node)
        if depth <= 0:
            return False  # search budget exhausted: treat as missing
        if cur_node in path_stack:
            if cur_node == path_stack[0]:
                return False  # reject interior node backedge
            # Loop inside the region; the ancestor frame that first visited
            # cur_node explores its remaining successors.
            return True
        if self.has_decref(cur_node):
            decref_blocks.add(cur_node)
            self.print(indent, "found decref")
            return True
        successors = self.get_successors(cur_node)
        if not successors:
            # An exit reached without any decref: on this path the incref
            # is never balanced, so removing it would change refcounts.
            self.print(indent, "exit without decref")
            return False
        path_stack += (cur_node,)
        # Pass the reduced budget down; all() stops at the first failure.
        found = all(
            self.walk_child_for_decref(
                child, path_stack, decref_blocks, depth - 1
            )
            for child in successors
        )
        self.print(indent, f"ret {found}")
        return found

    def find_decref_candidates(self, cur_node):
        """Forward pass: collect decref blocks covering all paths out of
        ``cur_node``.

        Returns the set of decref blocks, or an empty set if any path fails
        (see walk_child_for_decref).
        """
        self.print("find_decref_candidates".center(80, "-"))
        path_stack = (cur_node,)
        decref_blocks = set()
        found = all(
            self.walk_child_for_decref(child, path_stack, decref_blocks)
            for child in self.get_successors(cur_node)
        )
        return decref_blocks if found else set()
def check_once(case=None):
    """Render one test case with graphviz, then run it with verbose tracing.

    ``case`` is one of the caseNN() builders; it defaults to case13. The
    rendered CFG is opened with ``Digraph.view()``. An AssertionError is
    raised if the algorithm's result differs from the case's expectation.
    """
    # Import here: the module-level import silently ignores a missing
    # graphviz, which would otherwise surface as a NameError on Digraph.
    from graphviz import Digraph

    if case is None:
        case = case13
    nodes, edges, expected = case()
    graph = Digraph()
    for node in edges:
        # r"\l" is graphviz's left-justified line break inside a label.
        label = f"{node}\n" + r"\l".join(nodes[node])
        graph.node(node, shape="rect", label=label)
    for node, children in edges.items():
        for child in children:
            graph.edge(node, child)
    graph.view()
    algo = FanoutAlgorithm(nodes, edges, verbose=True)
    got = algo.run()
    assert expected == got, f"expected {expected!r}, got {got!r}"
def check_all():
    """Run every module-level caseNN() builder through FanoutAlgorithm.

    Raises AssertionError naming the first case whose result differs from
    its expectation; prints "ALL PASSED" otherwise.
    """
    for name, fn in list(globals().items()):
        # Only callable caseNN() builders; skip other globals sharing the
        # prefix.
        if not (name.startswith("case") and callable(fn)):
            continue
        print(f"{fn}".center(80, "-"))
        nodes, edges, expected = fn()
        got = FanoutAlgorithm(nodes, edges).run()
        # Explicit raise: a bare assert is stripped under ``python -O`` and
        # the run would then report success without checking anything.
        if expected != got:
            raise AssertionError(
                f"{name}: expected {expected!r}, got {got!r}"
            )
    print("ALL PASSED")
if __name__ == "__main__":
    # Run every caseNN() scenario. To debug a single case, call check_once()
    # instead: it renders the CFG with graphviz and runs the algorithm with
    # verbose tracing.
    check_all()