# algorembrant's picture
# Upload 39 files
# f3ce0b0 verified
import os
import sys

import ray
from ray import tune
from ray.tune.search.optuna import OptunaSearch

# Make the parent directory importable so `perf_takehome` resolves
# when this script is run from inside its own subdirectory.
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
sys.path.insert(0, parent_dir)
# Prefer a vendored Ray checkout (parent/ray/python) over any installed
# one — NOTE(review): harmless if the directory does not exist.
ray_path = os.path.join(parent_dir, "ray", "python")
sys.path.insert(0, ray_path)
def objective(config):
    """Ray Tune trial function: build, run, and validate the kernel.

    Builds the kernel with the tuned parameters from ``config``
    ("active_threshold", "mask_skip"), runs it on the simulated machine,
    and checks the output region against ``reference_kernel2``.

    Returns a dict with:
      - "cycles": simulated cycle count (999999 sentinel on any failure)
      - "correct": whether the output matched the reference
    """
    try:
        # Import inside the trial: Ray workers do not inherit the
        # driver's sys.path mutations, and the original code never
        # imported these names at all (every trial died with NameError,
        # silently scored as 999999 by the except below).
        from perf_takehome import (  # TODO(review): confirm exported names
            Tree,
            Input,
            build_mem_image,
            KernelBuilder,
            Machine,
            N_CORES,
            reference_kernel2,
        )

        # Fixed problem size for every trial (mirrors do_kernel_test).
        forest_height = 10
        rounds = 16
        batch_size = 256
        forest = Tree.generate(forest_height)
        inp = Input.generate(forest, batch_size, rounds)
        mem = build_mem_image(forest, inp)
        kb = KernelBuilder()
        # Inject the tuned parameters into kernel construction.
        kb.build_kernel(
            forest.height,
            len(forest.values),
            len(inp.indices),
            rounds,
            active_threshold=config["active_threshold"],
            mask_skip=config["mask_skip"]
        )
        value_trace = {}
        machine = Machine(
            mem,
            kb.instrs,
            kb.debug_info(),
            n_cores=N_CORES,
            value_trace=value_trace,
            trace=False,
        )
        machine.prints = False
        # Run until core 0 reports STOPPED (state value 3). If it pauses
        # (state value 2), flip it back to RUNNING (state value 1) and
        # keep going; any other non-stopped state ends the loop.
        while machine.cores[0].state.value != 3:  # STOPPED
            machine.run()
            if machine.cores[0].state.value == 2:  # PAUSED
                machine.cores[0].state = machine.cores[0].state.__class__(1)
                continue
            break
        machine.enable_pause = False
        # Drive the reference generator to completion; its final yield is
        # the reference memory image used for validation below.
        for ref_mem in reference_kernel2(mem, value_trace):
            pass
        # Compare the output region (pointer stored at ref_mem[6])
        # against the reference memory image.
        inp_values_p = ref_mem[6]
        if machine.mem[inp_values_p : inp_values_p + len(inp.values)] != ref_mem[inp_values_p : inp_values_p + len(inp.values)]:
            return {"cycles": 999999, "correct": False}
        return {"cycles": machine.cycle, "correct": True}
    except Exception as e:
        # Any failure scores as worst-case so the search continues;
        # print the error so failures are at least visible in trial logs.
        print(f"Error: {e}")
        return {"cycles": 999999, "correct": False}
if __name__ == "__main__":
    ray.init()
    # Sweep active_threshold over a small grid; mask_skip is pinned on
    # for this run.
    search_space = {
        "active_threshold": tune.grid_search([4, 8, 16]),
        # "mask_skip": tune.grid_search([True, False]), # We know True is better? Or maybe overhead logic is buggy?
        "mask_skip": True
    }
    analysis = tune.run(
        objective,
        config=search_space,
        metric="cycles",
        mode="min",
        num_samples=1,
    )
    best_config = analysis.get_best_config(metric="cycles", mode="min")
    print("Best config: ", best_config)
    print("Best cycles: ", analysis.best_result["cycles"])