|
""" |
|
Contains tests and a prototype implementation for the fanout algorithm in |
|
the LLVM refprune pass. |
|
""" |
|
|
|
try: |
|
from graphviz import Digraph |
|
except ImportError: |
|
pass |
|
from collections import defaultdict |
|
|
|
|
|
# Name of the function's entry block in every test CFG below.
# FanoutAlgorithm.find_fanout passes it to verify_non_overlapping, which
# rejects a fanout when a backward walk from a decref block reaches it.
ENTRY = "A"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def case1():
    """Base CFG: incref in D, and every path out of D ends in decref H or F.

    Returns ``(nodes, edges, expected)``: ``nodes`` maps a block to its
    refcount ops (a ``defaultdict(list)``), ``edges`` maps a block to its
    successor list, and ``expected`` is what FanoutAlgorithm.run() returns.
    """
    successors = dict(
        A=["B"],
        B=["C", "D"],
        C=[],
        D=["E", "F"],
        E=["G"],
        F=[],
        G=["H", "I"],
        I=["G", "F"],
        H=["J", "K"],
        J=["L", "M"],
        K=[],
        L=["Z"],
        M=["Z", "O", "P"],
        O=["Z"],
        P=["Z"],
        Z=[],
    )
    ops = defaultdict(
        list, D=["incref"], H=["decref"], F=["decref", "decref"]
    )
    return ops, successors, dict(D={"F", "H"})
|
|
|
|
|
def case2():
    """Incref in A with decrefs in both successors B and C.

    B also flows into C, so decref block B is a predecessor of decref
    block C and the fanout is expected to be rejected (``None``).
    """
    ops = defaultdict(list)
    for block, op in (("A", "incref"), ("B", "decref"), ("C", "decref")):
        ops[block] = [op]
    successors = {"A": ["B", "C"], "B": ["C"], "C": []}
    return ops, successors, {"A": None}
|
|
|
|
|
def case3():
    """case1 plus an edge H -> F; expected ``None``.

    The new edge makes decref block H a predecessor of decref block F.
    """
    nodes, edges, _ = case1()
    edges["H"] += ["F"]
    return nodes, edges, {"D": None}
|
|
|
|
|
def case4():
    """case1 plus an edge H -> E; expected ``None``.

    Decref block H becomes a predecessor of E, which lies upstream of
    both decref blocks.
    """
    nodes, edges, _ = case1()
    edges["H"] += ["E"]
    return nodes, edges, {"D": None}
|
|
|
|
|
def case5():
    """case1 plus an edge B -> I; expected ``None``.

    The edge opens a path from the entry into the decref region that
    bypasses the incref block D.
    """
    nodes, edges, _ = case1()
    edges["B"] += ["I"]
    return nodes, edges, {"D": None}
|
|
|
|
|
def case6():
    """case1 plus a back edge I -> B; expected ``None``.

    From B the walk can return to the incref block D.
    """
    nodes, edges, _ = case1()
    edges["I"] += ["B"]
    return nodes, edges, {"D": None}
|
|
|
|
|
def case7():
    """case1 plus an edge I -> M; expected ``None``.

    M leads to the exit block Z without passing any decref.
    """
    nodes, edges, _ = case1()
    edges["I"] += ["M"]
    return nodes, edges, {"D": None}
|
|
|
|
|
def case8():
    """Incref in A; the single decref block C is reached directly and via B.

    Returns ``(nodes, edges, expected)`` with expected ``{"A": {"C"}}``.
    """
    successors = dict(A=["B", "C"], B=["C"], C=[])
    ops = defaultdict(list, A=["incref"], C=["decref"])
    return ops, successors, dict(A={"C"})
|
|
|
|
|
def case9():
    """case8 plus an edge C -> B; expected ``None``.

    Decref block C becomes a predecessor of B, which itself leads to C.
    """
    nodes, edges, _ = case8()
    edges["C"] += ["B"]
    return nodes, edges, {"A": None}
|
|
|
|
|
def case10():
    """case8 plus a back edge C -> A into the incref block; still ``{"C"}``."""
    nodes, edges, _ = case8()
    edges["C"] += ["A"]
    return nodes, edges, {"A": {"C"}}
|
|
|
|
|
def case11():
    """case8 plus an exit block D after decref block C; still ``{"C"}``."""
    nodes, edges, _ = case8()
    edges["C"] += ["D"]
    edges.update(D=[])
    return nodes, edges, {"A": {"C"}}
|
|
|
|
|
def case12():
    """case8 plus C -> D -> A, looping back to the incref block; ``{"C"}``."""
    nodes, edges, _ = case8()
    edges["C"] += ["D"]
    edges.update(D=["A"])
    return nodes, edges, {"A": {"C"}}
|
|
|
|
|
def case13():
    """case8 plus C -> D -> B; expected ``None``.

    Decref block C then reaches B (through D), and B leads back into C.
    """
    nodes, edges, _ = case8()
    edges["C"] += ["D"]
    edges.update(D=["B"])
    return nodes, edges, {"A": None}
|
|
|
|
|
def make_predecessor_map(edges):
    """Invert a successor map into ``{block: set of predecessor blocks}``.

    The result is a ``defaultdict(set)``, so looking up a block with no
    predecessors yields an empty set.
    """
    preds = defaultdict(set)
    pairs = ((dst, src) for src, dsts in edges.items() for dst in dsts)
    for dst, src in pairs:
        preds[dst].add(src)
    return preds
|
|
|
|
|
class FanoutAlgorithm:
    """Prototype of the refprune "fanout" search over a toy CFG.

    For every block holding an ``"incref"`` op, look for a set of blocks
    holding ``"decref"`` such that every forward path out of the incref
    block hits one of them. Paths may loop back to blocks already on the
    path, but must not reach an exit or return to the incref block first.
    Then check that no decref block can be reached backwards from another
    decref block, or from the function entry, without passing through the
    incref block.

    Parameters
    ----------
    nodes : mapping of block name -> list of op names
        Read for every block in ``edges``; the test cases pass a
        ``defaultdict(list)`` so blocks without ops read as empty.
    edges : mapping of block name -> list of successor block names
    verbose : bool
        When true, trace the search with ``print``.
    """

    def __init__(self, nodes, edges, verbose=False):
        self.nodes = nodes
        self.edges = edges
        # Predecessor sets, used by the backward walk in verify_non_overlapping.
        self.rev_edges = make_predecessor_map(edges)
        self.print = print if verbose else self._null_print

    def run(self):
        """Return ``{incref_block: set_of_decref_blocks_or_None}``."""
        return self.find_fanout_in_function()

    def _null_print(self, *args, **kwargs):
        """No-op stand-in for ``print`` when not verbose."""
        pass

    def find_fanout_in_function(self):
        """Run find_fanout for every block that contains an incref.

        Returns a dict mapping each such block to the set of decref blocks
        that balance it, or to ``None`` when no valid fanout was found.
        """
        got = {}
        for cur_node in self.edges:
            # The search depends only on the block, not on how many increfs
            # it holds, so run it once per block.
            if "incref" not in self.nodes[cur_node]:
                continue
            decref_blocks = self.find_fanout(cur_node)
            self.print(">>", cur_node, "===", decref_blocks)
            got[cur_node] = decref_blocks
        return got

    def find_fanout(self, head_node):
        """Return the decref blocks balancing the incref in ``head_node``.

        Returns ``None`` if no candidate set is found, or if the candidates
        fail verify_non_overlapping (checked against the global ``ENTRY``).
        """
        decref_blocks = self.find_decref_candidates(head_node)
        self.print("candidates", decref_blocks)
        if not decref_blocks:
            return None
        if not self.verify_non_overlapping(
            head_node, decref_blocks, entry=ENTRY
        ):
            return None
        return set(decref_blocks)

    def verify_non_overlapping(self, head_node, decref_blocks, entry):
        """Check that the decref blocks are only reachable via ``head_node``.

        Walks predecessors backwards from each decref block, never stepping
        through ``head_node``. Returns False if the walk reaches ``entry`` (a
        path that bypasses the incref block), or if it meets a predecessor
        that is itself a decref block. Otherwise returns True.
        """
        self.print("verify_non_overlapping".center(80, "-"))
        for start in decref_blocks:
            visited = set()
            workstack = [start]
            while workstack:
                cur_node = workstack.pop()
                self.print("cur_node", cur_node, "|", workstack)
                if cur_node in visited:
                    continue
                if cur_node == entry:
                    # Reached the function entry without passing head_node.
                    self.print(
                        "!! failed because we arrived at entry", cur_node
                    )
                    return False
                visited.add(cur_node)
                preds = self.get_predecessors(cur_node)
                self.print(f" {cur_node} preds {preds}")
                for pred in preds:
                    if pred in decref_blocks:
                        self.print(
                            "!! reject because predecessor in decref_blocks"
                        )
                        return False
                    # The backward walk stops at the incref block.
                    if pred != head_node:
                        workstack.append(pred)
        return True

    def get_successors(self, node):
        """Return the successors of ``node`` as a tuple."""
        return tuple(self.edges[node])

    def get_predecessors(self, node):
        """Return the predecessors of ``node`` as a tuple."""
        return tuple(self.rev_edges[node])

    def has_decref(self, node):
        """Return True if ``node`` holds at least one ``"decref"`` op."""
        return "decref" in self.nodes[node]

    def walk_child_for_decref(
        self, cur_node, path_stack, decref_blocks, depth=10
    ):
        """Check that every forward path from ``cur_node`` reaches a decref.

        ``path_stack`` is the tuple of blocks on the current path; its first
        element is the incref block. Decref blocks found along the way are
        added to ``decref_blocks``.

        Returns True when each path ends in a decref block, or loops back to
        a block on ``path_stack`` other than the incref block. Returns False
        on a path back to the incref block, on a block with neither a decref
        nor successors, or once ``depth`` levels are used up.
        """
        indent = " " * len(path_stack)
        self.print(indent, "walk", path_stack, cur_node)
        if depth <= 0:
            return False
        if cur_node in path_stack:
            # Returning to the incref block fails; any other loop revisits a
            # block that is already being checked further up the path.
            return cur_node != path_stack[0]
        if self.has_decref(cur_node):
            decref_blocks.add(cur_node)
            self.print(indent, "found decref")
            return True

        successors = self.get_successors(cur_node)
        path_stack += (cur_node,)
        # A block with no successors and no decref is an exit reached without
        # a decref, so it fails. The remaining depth is passed down so the
        # limit bounds the length of the whole path, not just one level.
        found = bool(successors) and all(
            self.walk_child_for_decref(
                child, path_stack, decref_blocks, depth - 1
            )
            for child in successors
        )
        self.print(indent, f"ret {found}")
        return found

    def find_decref_candidates(self, cur_node):
        """Collect decref blocks that cover every path out of ``cur_node``.

        Returns the set of decref blocks found. Returns an empty set if
        ``cur_node`` has no successors or if any path fails (see
        walk_child_for_decref).
        """
        self.print("find_decref_candidates".center(80, "-"))
        path_stack = (cur_node,)
        decref_blocks = set()
        successors = self.get_successors(cur_node)
        found = bool(successors) and all(
            self.walk_child_for_decref(child, path_stack, decref_blocks)
            for child in successors
        )
        return decref_blocks if found else set()
|
|
|
|
|
def check_once(case=None):
    """Debug helper: draw one test case with graphviz and run it verbosely.

    ``case`` is one of the ``caseN`` functions; it defaults to ``case13``.
    The CFG drawing (which opens a viewer) is skipped when graphviz is not
    installed, but the verbose check still runs.

    Raises AssertionError if the algorithm's result differs from the
    case's expected value.
    """
    if case is None:
        case = case13
    nodes, edges, expected = case()

    # The module tolerates a missing graphviz at import time; mirror that here
    # instead of failing with a NameError on the unbound global.
    try:
        from graphviz import Digraph
    except ImportError:
        Digraph = None
        print("graphviz not available; skipping the CFG drawing")

    if Digraph is not None:
        graph = Digraph()
        for node in edges:
            graph.node(
                node, shape="rect", label=f"{node}\n" + r"\l".join(nodes[node])
            )
        for node, children in edges.items():
            for child in children:
                graph.edge(node, child)
        graph.view()

    algo = FanoutAlgorithm(nodes, edges, verbose=True)
    got = algo.run()
    # Explicit check so it still runs under ``python -O``.
    if expected != got:
        raise AssertionError(f"expected {expected!r}, got {got!r}")
|
|
|
|
|
def check_all():
    """Run every module-level ``caseN`` function through FanoutAlgorithm.

    Prints a separator per case and "ALL PASSED" at the end. Raises
    AssertionError naming the first case whose result differs from its
    expected value.
    """
    for name, fn in list(globals().items()):
        if not name.startswith("case"):
            continue
        print(f"{fn}".center(80, "-"))
        nodes, edges, expected = fn()
        got = FanoutAlgorithm(nodes, edges).run()
        # Explicit check: a bare ``assert`` is stripped under ``python -O``,
        # which would let this harness report success without checking.
        if expected != got:
            raise AssertionError(
                f"{name}: expected {expected!r}, got {got!r}"
            )
    print("ALL PASSED")
|
|
|
|
|
if __name__ == "__main__":
    # Run every caseN scenario through the algorithm when executed as a script.
    check_all()
|
|