from collections import defaultdict
import heapq
import random
import unittest
# Assumes problem.py exists in the same directory as per the original structure
from problem import (
Engine,
DebugInfo,
SLOT_LIMITS, # Note: Scheduler re-defines this, but we keep import for safety
VLEN,
N_CORES,
SCRATCH_SIZE,
Machine,
Tree,
Input,
HASH_STAGES,
reference_kernel,
build_mem_image,
reference_kernel2,
)
# --- Integrated Scheduler Code ---
# Redefining locally to ensure scheduler uses these exact limits
SCHEDULER_SLOT_LIMITS = {
"alu": 12,
"valu": 6,
"load": 2,
"store": 2,
"flow": 1,
"debug": 64,
}
class Node:
def __init__(self, id, engine, args, desc=""):
self.id = id
self.engine = engine
self.args = args # Tuple of args
self.desc = desc
self.parents = []
self.children = []
self.priority = 0
        self.latency = 1  # default latency (not consumed by the greedy scheduler below)
def add_child(self, node):
self.children.append(node)
node.parents.append(self)
class Scheduler:
def __init__(self):
self.nodes = []
self.id_counter = 0
self.scratch_reads = defaultdict(list) # addr -> [nodes reading it]
self.scratch_writes = defaultdict(list) # addr -> [nodes writing it]
def add_op(self, engine, args, desc=""):
node = Node(self.id_counter, engine, args, desc)
self.nodes.append(node)
self.id_counter += 1
# Analyze dependencies
reads, writes = self._get_rw(engine, args)
# RAW (Read After Write): Current node reads from a previous write
for r in reads:
if r in self.scratch_writes and self.scratch_writes[r]:
# Depend on the LAST writer
last_writer = self.scratch_writes[r][-1]
last_writer.add_child(node)
# WAW (Write After Write): Current node writes to same addr as previous write
for w in writes:
if w in self.scratch_writes and self.scratch_writes[w]:
last_writer = self.scratch_writes[w][-1]
last_writer.add_child(node)
# WAR (Write After Read): Current node writes to addr that was read previously
# We must not write until previous reads are done.
for w in writes:
if w in self.scratch_reads and self.scratch_reads[w]:
for reader in self.scratch_reads[w]:
if reader != node: # Don't depend on self
reader.add_child(node)
# Register Access updates
for r in reads:
self.scratch_reads[r].append(node)
for w in writes:
self.scratch_writes[w].append(node)
return node
def _get_rw(self, engine, args):
reads = []
writes = []
# Helpers
def is_addr(x): return isinstance(x, int)
if engine == "alu":
            # ALU ops are (op, dest, src1, src2): one scratch write, two scratch reads
op, dest, a1, a2 = args
writes.append(dest)
reads.append(a1)
reads.append(a2)
elif engine == "valu":
# varargs
op = args[0]
if op == "vbroadcast":
# dest, src
writes.extend([args[1] + i for i in range(VLEN)])
reads.append(args[2])
elif op == "multiply_add":
# dest, a, b, c
writes.extend([args[1] + i for i in range(VLEN)])
reads.extend([args[2] + i for i in range(VLEN)])
reads.extend([args[3] + i for i in range(VLEN)])
reads.extend([args[4] + i for i in range(VLEN)])
else:
# Generic VALU op: op, dest, a1, a2
# e.g. ^, >>, +, <, &
writes.extend([args[1] + i for i in range(VLEN)])
reads.extend([args[2] + i for i in range(VLEN)])
reads.extend([args[3] + i for i in range(VLEN)])
elif engine == "load":
op = args[0]
if op == "const":
writes.append(args[1])
elif op == "load":
writes.append(args[1])
reads.append(args[2])
elif op == "vload":
writes.extend([args[1] + i for i in range(VLEN)])
reads.append(args[2]) # scalar addr
elif engine == "store":
op = args[0]
if op == "vstore":
reads.append(args[1]) # addr
reads.extend([args[2] + i for i in range(VLEN)]) # val
elif engine == "flow":
op = args[0]
if op == "vselect":
# dest, cond, a, b
writes.extend([args[1] + i for i in range(VLEN)])
reads.extend([args[2] + i for i in range(VLEN)])
reads.extend([args[3] + i for i in range(VLEN)])
reads.extend([args[4] + i for i in range(VLEN)])
elif op == "select":
# dest, cond, a, b
writes.append(args[1])
reads.append(args[2])
reads.append(args[3])
reads.append(args[4])
elif op == "add_imm":
# dest, a, imm
writes.append(args[1])
reads.append(args[2])
elif op == "cond_jump" or op == "cond_jump_rel":
# cond, dest
reads.append(args[1])
elif op == "pause":
pass
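        # Any op shape not matched above contributes no scratch reads/writes,
        # so the scheduler creates no ordering edges for it.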
return reads, writes
def schedule(self):
# Calculate priorities (longest path)
self._calc_priorities()
ready = [] # Heap of (-priority, node)
in_degree = defaultdict(int)
for node in self.nodes:
in_degree[node] = len(node.parents)
if in_degree[node] == 0:
heapq.heappush(ready, (-node.priority, node.id, node))
instructions = []
# Main Scheduling Loop
while ready or any(count > 0 for count in in_degree.values()):
cycle_ops = defaultdict(list)
deferred = []
usage = {k:0 for k in SCHEDULER_SLOT_LIMITS}
curr_cycle_nodes = []
# Greedy allocation for this cycle
while ready:
prio, nid, node = heapq.heappop(ready)
if usage[node.engine] < SCHEDULER_SLOT_LIMITS[node.engine]:
usage[node.engine] += 1
cycle_ops[node.engine].append(node.args)
curr_cycle_nodes.append(node)
else:
deferred.append((prio, nid, node))
# Push back deferred for next cycle
for item in deferred:
heapq.heappush(ready, item)
# Check for termination or deadlock
if not curr_cycle_nodes and not ready:
if any(in_degree.values()):
raise Exception("Deadlock detected in scheduler")
break
instructions.append(dict(cycle_ops))
# Update children for NEXT cycle
for node in curr_cycle_nodes:
for child in node.children:
in_degree[child] -= 1
if in_degree[child] == 0:
heapq.heappush(ready, (-child.priority, child.id, child))
return instructions
def _calc_priorities(self):
memo = {}
def get_dist(node):
if node in memo: return memo[node]
max_d = 0
for child in node.children:
max_d = max(max_d, get_dist(child))
memo[node] = max_d + 1
return max_d + 1
for node in self.nodes:
node.priority = get_dist(node)
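# Illustrative only (not used by the kernel): a minimal sketch of how Scheduler
# packs independent ops into one bundle while a dependent op is pushed to the
# next cycle by its RAW edges. The scratch addresses are arbitrary examples.
def _scheduler_packing_demo():
    s = Scheduler()
    s.add_op("alu", ("+", 10, 0, 1))    # writes scratch[10]
    s.add_op("alu", ("+", 11, 2, 3))    # independent: shares the first cycle
    s.add_op("alu", ("+", 12, 10, 11))  # reads 10 and 11 -> second cycle
    bundles = s.schedule()
    assert len(bundles) == 2 and len(bundles[0]["alu"]) == 2
    return bundles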
# --- Main Kernel Logic ---
class KernelBuilder:
def __init__(self):
self.scheduler = Scheduler()
self.scratch = {}
self.scratch_debug = {}
self.scratch_ptr = 0
self.const_map = {}
def debug_info(self):
return DebugInfo(scratch_map=self.scratch_debug)
def finalize(self):
return self.scheduler.schedule()
def add_instr(self, instr_dict):
# Compatibility wrapper
for engine, slots in instr_dict.items():
for args in slots:
self.scheduler.add_op(engine, args)
def alloc_scratch(self, name=None, length=1):
addr = self.scratch_ptr
if name is not None:
self.scratch[name] = addr
self.scratch_debug[addr] = (name, length)
self.scratch_ptr += length
assert self.scratch_ptr <= SCRATCH_SIZE, f"Out of scratch space: {self.scratch_ptr}"
return addr
def scratch_const(self, val, name=None):
if val not in self.const_map:
addr = self.alloc_scratch(name)
self.scheduler.add_op("load", ("const", addr, val))
self.const_map[val] = addr
return self.const_map[val]
def scratch_vec_const(self, val, name=None):
key = (val, "vec")
if key not in self.const_map:
addr = self.alloc_scratch(name if name else f"vconst_{val}", VLEN)
scalar_addr = self.scratch_const(val)
self.scheduler.add_op("valu", ("vbroadcast", addr, scalar_addr))
self.const_map[key] = addr
return self.const_map[key]
def add_hash_opt(self, val_vec, tmp1_vec, tmp2_vec):
"""
Adds slots for the strength-reduced hash function to scheduler.
"""
# Stage 0: MAD
c1 = self.scratch_vec_const(0x7ED55D16, "h0_c")
m1 = self.scratch_vec_const(1 + (1<<12), "h0_m")
self.scheduler.add_op("valu", ("multiply_add", val_vec, val_vec, m1, c1))
# Stage 1: Xor, Shift, Xor
c2 = self.scratch_vec_const(0xC761C23C, "h1_c")
s2 = self.scratch_vec_const(19, "h1_s")
# 1a
self.scheduler.add_op("valu", ("^", tmp1_vec, val_vec, c2))
self.scheduler.add_op("valu", (">>", tmp2_vec, val_vec, s2))
# 1b
self.scheduler.add_op("valu", ("^", val_vec, tmp1_vec, tmp2_vec))
# Stage 2: MAD
c3 = self.scratch_vec_const(0x165667B1, "h2_c")
m3 = self.scratch_vec_const(1 + (1<<5), "h2_m")
self.scheduler.add_op("valu", ("multiply_add", val_vec, val_vec, m3, c3))
# Stage 3: Add, Shift, Xor
c4 = self.scratch_vec_const(0xD3A2646C, "h3_c")
s4 = self.scratch_vec_const(9, "h3_s")
self.scheduler.add_op("valu", ("+", tmp1_vec, val_vec, c4))
self.scheduler.add_op("valu", ("<<", tmp2_vec, val_vec, s4))
self.scheduler.add_op("valu", ("^", val_vec, tmp1_vec, tmp2_vec))
# Stage 4: MAD
c5 = self.scratch_vec_const(0xFD7046C5, "h4_c")
m5 = self.scratch_vec_const(1 + (1<<3), "h4_m")
self.scheduler.add_op("valu", ("multiply_add", val_vec, val_vec, m5, c5))
# Stage 5: Xor, Shift, Xor
c6 = self.scratch_vec_const(0xB55A4F09, "h5_c")
s6 = self.scratch_vec_const(16, "h5_s")
self.scheduler.add_op("valu", ("^", tmp1_vec, val_vec, c6))
self.scheduler.add_op("valu", (">>", tmp2_vec, val_vec, s6))
self.scheduler.add_op("valu", ("^", val_vec, tmp1_vec, tmp2_vec))
def add_hash_opt_scalar(self, val_vec, tmp1_vec, tmp2_vec):
"""
Scalarized version of hash optimization.
Unrolls loop over 8 lanes and uses ALU engine.
"""
def add_alu_lanes(op, dest_vec, src1_vec, src2_vec, s2_is_const=False):
for lane in range(VLEN):
s2_addr = src2_vec if s2_is_const else src2_vec + lane
self.scheduler.add_op("alu", (op, dest_vec + lane, src1_vec + lane, s2_addr))
def add_mad_lanes(dest_vec, a_vec, b_vec, c_vec, b_is_const=False, c_is_const=False):
for lane in range(VLEN):
b_addr = b_vec if b_is_const else b_vec + lane
c_addr = c_vec if c_is_const else c_vec + lane
# dest = a*b
self.scheduler.add_op("alu", ("*", dest_vec + lane, a_vec + lane, b_addr))
# dest = dest+c
self.scheduler.add_op("alu", ("+", dest_vec + lane, dest_vec + lane, c_addr))
# Stage 0: MAD
c1 = self.scratch_const(0x7ED55D16, "h0_c")
m1 = self.scratch_const(1 + (1<<12), "h0_m")
add_mad_lanes(val_vec, val_vec, m1, c1, True, True)
# Stage 1: Xor, Shift, Xor
c2 = self.scratch_const(0xC761C23C, "h1_c")
s2 = self.scratch_const(19, "h1_s")
add_alu_lanes("^", tmp1_vec, val_vec, c2, True)
add_alu_lanes(">>", tmp2_vec, val_vec, s2, True)
add_alu_lanes("^", val_vec, tmp1_vec, tmp2_vec, False)
# Stage 2: MAD
c3 = self.scratch_const(0x165667B1, "h2_c")
m3 = self.scratch_const(1 + (1<<5), "h2_m")
add_mad_lanes(val_vec, val_vec, m3, c3, True, True)
# Stage 3: Add, Shift, Xor
c4 = self.scratch_const(0xD3A2646C, "h3_c")
s4 = self.scratch_const(9, "h3_s")
add_alu_lanes("+", tmp1_vec, val_vec, c4, True)
add_alu_lanes("<<", tmp2_vec, val_vec, s4, True)
add_alu_lanes("^", val_vec, tmp1_vec, tmp2_vec, False)
# Stage 4: MAD
c5 = self.scratch_const(0xFD7046C5, "h4_c")
m5 = self.scratch_const(1 + (1<<3), "h4_m")
add_mad_lanes(val_vec, val_vec, m5, c5, True, True)
# Stage 5: Xor, Shift, Xor
c6 = self.scratch_const(0xB55A4F09, "h5_c")
s6 = self.scratch_const(16, "h5_s")
add_alu_lanes("^", tmp1_vec, val_vec, c6, True)
add_alu_lanes(">>", tmp2_vec, val_vec, s6, True)
add_alu_lanes("^", val_vec, tmp1_vec, tmp2_vec, False)
def build_kernel(
self, forest_height: int, n_nodes: int, batch_size: int, rounds: int,
active_threshold=4, mask_skip=True, scalar_offload=2
):
result_scalar_offload = scalar_offload
# --- Memory Pointers ---
init_vars = [
"rounds", "n_nodes", "batch_size", "forest_height",
"forest_values_p", "inp_indices_p", "inp_values_p"
]
ptr_map = {}
tmp_load = self.alloc_scratch("tmp_load")
for i, v in enumerate(init_vars):
addr = self.alloc_scratch(v)
ptr_map[v] = addr
self.scheduler.add_op("load", ("const", tmp_load, i))
self.scheduler.add_op("load", ("load", addr, tmp_load))
indices_base = self.alloc_scratch("indices_cache", batch_size)
values_base = self.alloc_scratch("values_cache", batch_size)
# Memory Optimization: Reuse Scratch
block_x = self.alloc_scratch("block_x", batch_size)
block_y = self.alloc_scratch("block_y", batch_size)
num_vecs = batch_size // VLEN
tmp_addrs_base = block_x
node_vals_base = block_x
vtmp1_base = block_x
vtmp2_base = block_y
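        # block_x is time-shared: it first holds gather addresses, the loaded node
        # values then overwrite them in place, and the same words later serve as
        # hash temporary tmp1; block_y provides tmp2.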
# Constants
const_0_vec = self.scratch_vec_const(0)
const_1_vec = self.scratch_vec_const(1)
global_n_nodes_vec = self.alloc_scratch("n_nodes_vec", VLEN)
self.scheduler.add_op("valu", ("vbroadcast", global_n_nodes_vec, ptr_map["n_nodes"]))
active_temp_base = self.alloc_scratch("active_temp", 200)
# --- 1. Load Input Data (Wavefront) ---
for i in range(0, batch_size, VLEN):
i_const = self.scratch_const(i)
# Indices Addr
self.scheduler.add_op("alu", ("+", tmp_load, ptr_map["inp_indices_p"], i_const))
self.scheduler.add_op("load", ("vload", indices_base + i, tmp_load))
self.scheduler.add_op("alu", ("+", tmp_load, ptr_map["inp_values_p"], i_const))
self.scheduler.add_op("load", ("vload", values_base + i, tmp_load))
# --- 2. Main Loop ---
self.scheduler.add_op("flow", ("pause",))
active_indices = []
for r in range(rounds):
# Collect register pointers for all vectors
vecs = []
for vec_i in range(num_vecs):
offset = vec_i * VLEN
vecs.append({
'idx': indices_base + offset,
'val': values_base + offset,
'node': node_vals_base + offset,
'tmp1': vtmp1_base + offset,
'tmp2': vtmp2_base + offset,
'addr': tmp_addrs_base + offset
})
if r == 0:
# Round 0: 1 Node (0)
scalar_node = self.alloc_scratch("scalar_node_r0")
self.scheduler.add_op("load", ("load", scalar_node, ptr_map["forest_values_p"]))
for vec in vecs:
self.scheduler.add_op("valu", ("vbroadcast", vec['node'], scalar_node))
active_indices = [0]
            elif len(active_indices) * 2 <= 8:  # few enough distinct nodes to load each once and vselect between
# Reuse Scratch
active_dev_ptr = active_temp_base
def alloc_temp(length=1):
nonlocal active_dev_ptr
addr = active_dev_ptr
active_dev_ptr += length
                    assert active_dev_ptr <= active_temp_base + 200  # stay within the active_temp allocation
return addr
# Update active indices
new_actives = []
for x in active_indices:
new_actives.append(2*x + 1)
new_actives.append(2*x + 2)
active_indices = new_actives
# Active Load Strategy
node_map = {}
for uidx in active_indices:
s_node = alloc_temp(1)
s_addr = alloc_temp(1)
idx_c = self.scratch_const(uidx)
# Calc Addr
self.scheduler.add_op("alu", ("+", s_addr, ptr_map["forest_values_p"], idx_c))
# Load
self.scheduler.add_op("load", ("load", s_node, s_addr))
# Broadcast
v_node = alloc_temp(VLEN)
self.scheduler.add_op("valu", ("vbroadcast", v_node, s_node))
node_map[uidx] = v_node
tree_temp_start = active_dev_ptr
# Select Tree for each vector
for vec in vecs:
active_dev_ptr = tree_temp_start
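                    # Binary-search the (sorted) active_indices with vselects: each
                    # level splits the candidate set on idx < split and keeps the
                    # matching half, using len(indices) - 1 vselects per vector.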
def build_tree(indices):
if len(indices) == 1:
return node_map[indices[0]]
mid = len(indices) // 2
left = indices[:mid]
right = indices[mid:]
split_val = right[0]
split_c = self.scratch_vec_const(split_val)
cond = alloc_temp(VLEN)
self.scheduler.add_op("valu", ("<", cond, vec['idx'], split_c))
l_res = build_tree(left)
r_res = build_tree(right)
res = alloc_temp(VLEN)
self.scheduler.add_op("flow", ("vselect", res, cond, l_res, r_res))
return res
final_res = build_tree(active_indices)
self.scheduler.add_op("valu", ("|", vec['node'], final_res, final_res))
else:
# Generic Wavefront Load
for vec in vecs:
for lane in range(VLEN):
self.scheduler.add_op("alu", ("+", vec['addr'] + lane, ptr_map["forest_values_p"], vec['idx'] + lane))
for vec in vecs:
for lane in range(VLEN):
self.scheduler.add_op("load", ("load", vec['node'] + lane, vec['addr'] + lane))
do_wrap = True
if mask_skip and (1<<(r+2)) < n_nodes:
do_wrap = False
use_offload = (r >= active_threshold) and (not do_wrap)
scalar_vectors = vecs[:result_scalar_offload] if use_offload else []
vector_vectors = vecs[result_scalar_offload:] if use_offload else vecs
# --- VECTORIZED VECTORS ---
# Mixed Hash
for vec in vector_vectors:
self.scheduler.add_op("valu", ("^", vec['val'], vec['val'], vec['node']))
for vec in vector_vectors:
self.add_hash_opt(vec['val'], vec['tmp1'], vec['tmp2'])
            # Index Update: idx = 2*idx + 1 + (val & 1), i.e. step to the left or right child
for vec in vector_vectors:
self.scheduler.add_op("valu", ("&", vec['tmp1'], vec['val'], const_1_vec))
self.scheduler.add_op("valu", ("+", vec['tmp1'], vec['tmp1'], const_1_vec))
self.scheduler.add_op("valu", ("+", vec['idx'], vec['idx'], vec['idx']))
self.scheduler.add_op("valu", ("+", vec['idx'], vec['idx'], vec['tmp1']))
# Wrap
if do_wrap:
for vec in vector_vectors:
self.scheduler.add_op("valu", ("<", vec['tmp1'], vec['idx'], global_n_nodes_vec))
for vec in vector_vectors:
self.scheduler.add_op("flow", ("vselect", vec['idx'], vec['tmp1'], vec['idx'], const_0_vec))
# --- SCALARIZED VECTORS ---
def alu_lanes(op, dest, s1, s2, s2_c=False):
for l in range(VLEN):
                    s2_addr = s2 if s2_c else s2 + l
                    self.scheduler.add_op("alu", (op, dest + l, s1 + l, s2_addr))
# Mixed Hash
for vec in scalar_vectors:
alu_lanes("^", vec['val'], vec['val'], vec['node'], False)
for vec in scalar_vectors:
self.add_hash_opt_scalar(vec['val'], vec['tmp1'], vec['tmp2'])
            # Index Update (scalar lanes): idx = 2*idx + 1 + (val & 1)
const_1 = self.scratch_const(1)
for vec in scalar_vectors:
alu_lanes("&", vec['tmp1'], vec['val'], const_1, True)
alu_lanes("+", vec['tmp1'], vec['tmp1'], const_1, True)
alu_lanes("+", vec['idx'], vec['idx'], vec['idx'], False)
alu_lanes("+", vec['idx'], vec['idx'], vec['tmp1'], False)
# Wrap
if do_wrap:
const_0 = self.scratch_const(0)
n_nodes_c = ptr_map["n_nodes"]
for vec in scalar_vectors:
alu_lanes("<", vec['tmp1'], vec['idx'], n_nodes_c, True)
for vec in scalar_vectors:
for l in range(VLEN):
self.scheduler.add_op("flow", ("select", vec['idx']+l, vec['tmp1']+l, vec['idx']+l, const_0))
# --- 3. Final Store ---
for i in range(0, batch_size, VLEN):
i_const = self.scratch_const(i)
self.scheduler.add_op("alu", ("+", tmp_load, ptr_map["inp_indices_p"], i_const))
self.scheduler.add_op("store", ("vstore", tmp_load, indices_base + i))
self.scheduler.add_op("alu", ("+", tmp_load, ptr_map["inp_values_p"], i_const))
self.scheduler.add_op("store", ("vstore", tmp_load, values_base + i))
self.scheduler.add_op("flow", ("pause",))
self.instrs = self.scheduler.schedule()
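# For reference only: a pure-Python model of what add_hash_opt computes per lane,
# assuming the VALU wraps results to 32 bits. The multiply_add stages are the
# strength-reduced form of "x + (x << k) + c"; the imported HASH_STAGES remains
# the source of truth.
def _hash_lane_model(x: int) -> int:
    M = 0xFFFFFFFF
    x = (x * (1 + (1 << 12)) + 0x7ED55D16) & M   # stage 0: MAD
    x = ((x ^ 0xC761C23C) ^ (x >> 19)) & M       # stage 1: xor / shift / xor
    x = (x * (1 + (1 << 5)) + 0x165667B1) & M    # stage 2: MAD
    x = ((x + 0xD3A2646C) ^ ((x << 9) & M)) & M  # stage 3: add / shift / xor
    x = (x * (1 + (1 << 3)) + 0xFD7046C5) & M    # stage 4: MAD
    x = ((x ^ 0xB55A4F09) ^ (x >> 16)) & M       # stage 5: xor / shift / xor
    return x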
BASELINE = 147734  # reference cycle count for comparison
def do_kernel_test(
forest_height: int,
rounds: int,
batch_size: int,
seed: int = 123,
trace: bool = False,
prints: bool = False,
):
print(f"{forest_height=}, {rounds=}, {batch_size=}")
random.seed(seed)
forest = Tree.generate(forest_height)
inp = Input.generate(forest, batch_size, rounds)
mem = build_mem_image(forest, inp)
kb = KernelBuilder()
kb.build_kernel(forest.height, len(forest.values), len(inp.indices), rounds)
value_trace = {}
machine = Machine(
mem,
kb.instrs,
kb.debug_info(),
n_cores=N_CORES,
value_trace=value_trace,
trace=trace,
)
machine.prints = prints
while machine.cores[0].state.value != 3: # STOPPED
machine.run()
if machine.cores[0].state.value == 2: # PAUSED
machine.cores[0].state = machine.cores[0].state.__class__(1) # RUNNING
continue
break
# Check FINAL result
machine.enable_pause = False
for ref_mem in reference_kernel2(mem, value_trace):
pass
inp_values_p = ref_mem[6]
    # Always report the cycle count (and the trace buffer when available)
print("CYCLES: ", machine.cycle)
if hasattr(machine.cores[0], 'trace_buf'):
print("TRACE BUF:", machine.cores[0].trace_buf[:64])
assert (
machine.mem[inp_values_p : inp_values_p + len(inp.values)]
== ref_mem[inp_values_p : inp_values_p + len(inp.values)]
), f"Incorrect result on final round"
return machine.cycle
class Tests(unittest.TestCase):
def test_ref_kernels(self):
random.seed(123)
for i in range(10):
f = Tree.generate(4)
inp = Input.generate(f, 10, 6)
mem = build_mem_image(f, inp)
reference_kernel(f, inp)
for _ in reference_kernel2(mem, {}):
pass
assert inp.indices == mem[mem[5] : mem[5] + len(inp.indices)]
assert inp.values == mem[mem[6] : mem[6] + len(inp.values)]
def test_kernel_trace(self):
do_kernel_test(10, 16, 256, trace=True, prints=False)
def test_kernel_cycles(self):
do_kernel_test(10, 16, 256, prints=False)
if __name__ == "__main__":
unittest.main()