File size: 1,552 Bytes
c65f48d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
from llvmlite.ir import CallInstr
class Visitor(object):
    """Base class that walks a module down to its individual instructions.

    The traversal order is module -> functions -> basic blocks ->
    instructions.  Subclasses must override :meth:`visit_Instruction`,
    which is called once for every instruction encountered.

    While a traversal is running, the ``module``, ``function`` and
    ``basic_block`` properties expose the object currently being visited
    at each level.  They are only assigned during a traversal, so reading
    them before the corresponding ``visit*`` method has run raises
    ``AttributeError``.
    """

    def visit(self, module):
        """Visit every function in ``module.functions``, in order."""
        self._module = module
        for func in module.functions:
            self.visit_Function(func)

    def visit_Function(self, func):
        """Visit every basic block in ``func.blocks``, in order."""
        self._function = func
        for bb in func.blocks:
            self.visit_BasicBlock(bb)

    def visit_BasicBlock(self, bb):
        """Visit every instruction in ``bb.instructions``, in order."""
        self._basic_block = bb
        for instr in bb.instructions:
            self.visit_Instruction(instr)

    def visit_Instruction(self, instr):
        """Handle a single instruction; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

    @property
    def module(self):
        """The module most recently passed to :meth:`visit`."""
        return self._module

    @property
    def function(self):
        """The function most recently passed to :meth:`visit_Function`."""
        return self._function

    @property
    def basic_block(self):
        """The block most recently passed to :meth:`visit_BasicBlock`."""
        return self._basic_block
class CallVisitor(Visitor):
    """Visitor that forwards only call instructions to :meth:`visit_Call`.

    Instructions that are not instances of ``CallInstr`` are ignored.
    Subclasses must override :meth:`visit_Call`.
    """

    def visit_Instruction(self, instr):
        """Dispatch ``instr`` to :meth:`visit_Call` if it is a ``CallInstr``."""
        if not isinstance(instr, CallInstr):
            return
        self.visit_Call(instr)

    def visit_Call(self, instr):
        """Handle a single call instruction; must be overridden.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError
class ReplaceCalls(CallVisitor):
    """Visitor that retargets calls from one callee to another.

    Every visited call whose ``callee`` compares equal to ``orig`` has
    ``replace_callee(repl)`` invoked on it and is then recorded in
    ``calls``.
    """

    def __init__(self, orig, repl):
        """Store the callee to look for and its replacement.

        Args:
            orig: callee to match against each call's ``callee``.
            repl: value passed to ``replace_callee`` on matching calls.
        """
        super(ReplaceCalls, self).__init__()
        self.orig = orig
        self.repl = repl
        # Call instructions that were rewritten, in visiting order.
        self.calls = []

    def visit_Call(self, instr):
        """Rewrite ``instr`` if it calls ``orig`` and remember it."""
        if instr.callee != self.orig:
            return
        instr.replace_callee(self.repl)
        self.calls.append(instr)
def replace_all_calls(mod, orig, repl):
    """Replace all calls to `orig` with calls to `repl` in module `mod`.

    Args:
        mod: module whose functions are traversed.
        orig: callee to be replaced.
        repl: callee to substitute for `orig`.

    Returns:
        The list of call instructions that were replaced, in the order
        they were visited.
    """
    replacer = ReplaceCalls(orig, repl)
    replacer.visit(mod)
    return replacer.calls
|