| from __future__ import annotations |
|
|
| import ctypes |
| import json |
| from pathlib import Path |
|
|
|
|
# Fourth ancestor of this file's resolved path; the DLL location is derived
# from it.
ROOT = Path(__file__).resolve().parents[3]
DLL = ROOT / "runtime_rust_dll" / "target" / "release" / "acestep_runtime.dll"

# Loading happens at import time; ctypes raises OSError if the library
# cannot be found or loaded.
lib = ctypes.CDLL(str(DLL))

# Shared alias for the float-buffer parameters declared below.
_FLOAT_PTR = ctypes.POINTER(ctypes.c_float)

# Context lifecycle: creation takes a byte string (main() passes UTF-8 JSON)
# and returns an opaque handle.
lib.ace_create_context.argtypes = [ctypes.c_char_p]
lib.ace_create_context.restype = ctypes.c_void_p

# The release functions' results are never used.  Without an explicit restype
# ctypes falls back to c_int and would read a meaningless value, so the
# return value is discarded explicitly.
# NOTE(review): presumably both return void on the native side -- confirm
# against the DLL's exported signatures.
lib.ace_free_context.argtypes = [ctypes.c_void_p]
lib.ace_free_context.restype = None
lib.ace_string_free.argtypes = [ctypes.c_void_p]
lib.ace_string_free.restype = None

# As called from main(): (context, state JSON bytes, input floats,
# element count, out-pointer receiving a string) -> int32 status.
lib.ace_prepare_step_inputs.argtypes = [
    ctypes.c_void_p,
    ctypes.c_char_p,
    _FLOAT_PTR,
    ctypes.c_size_t,
    ctypes.POINTER(ctypes.c_void_p),
]
lib.ace_prepare_step_inputs.restype = ctypes.c_int32

# As called from main(): (context, xt floats, vt floats, element count,
# a float scalar, output floats) -> int32 status.
lib.ace_scheduler_step.argtypes = [
    ctypes.c_void_p,
    _FLOAT_PTR,
    _FLOAT_PTR,
    ctypes.c_size_t,
    ctypes.c_float,
    _FLOAT_PTR,
]
lib.ace_scheduler_step.restype = ctypes.c_int32
|
|
|
|
def cstr(ptr: int) -> str:
    """Decode the NUL-terminated C string at address *ptr* as UTF-8.

    The bytes are copied into a Python ``str``; the native buffer is not
    freed here and remains the caller's responsibility.

    Args:
        ptr: Address of the C string, e.g. the ``.value`` of a
            ``ctypes.c_void_p``.

    Returns:
        The decoded text.

    Raises:
        ValueError: If *ptr* is NULL (``0`` or ``None``).
        UnicodeDecodeError: If the bytes are not valid UTF-8.
    """
    raw = ctypes.cast(ptr, ctypes.c_char_p).value
    # A NULL char pointer reads back as None; fail with a clear message
    # instead of an AttributeError on ``None.decode``.
    if raw is None:
        raise ValueError("cannot decode a NULL C string pointer")
    return raw.decode("utf-8")
|
|
|
|
def main() -> int:
    """Exercise the runtime DLL once: prepare step inputs, then one scheduler step.

    Creates a context from a fixed JSON config, calls
    ``ace_prepare_step_inputs`` and prints the string it returns, then calls
    ``ace_scheduler_step`` on fixed four-element buffers and prints the
    output values.  The context is always released, and the returned string
    is released even when decoding or printing it fails.

    Returns:
        ``0`` when both calls succeed.

    Raises:
        RuntimeError: If the context cannot be created or either native call
            returns a non-zero status.
    """
    cfg = {"seed": 42, "blocked_token_ids": [0]}
    ctx = lib.ace_create_context(json.dumps(cfg).encode("utf-8"))
    if not ctx:
        raise RuntimeError("failed to create context")

    try:
        state = {"shift": 3.0, "inference_steps": 8, "current_step": 0}
        in_buf = (ctypes.c_float * 4)(1.0, 2.0, 3.0, 4.0)
        out_json = ctypes.c_void_p()
        try:
            rc = lib.ace_prepare_step_inputs(
                ctx,
                json.dumps(state).encode("utf-8"),
                in_buf,
                len(in_buf),  # keep the count in sync with the buffer
                ctypes.byref(out_json),
            )
            if rc != 0:
                raise RuntimeError("ace_prepare_step_inputs failed")
            print("prepare:", cstr(out_json.value))
        finally:
            # Release the returned string on every path, but only if the
            # library actually stored a pointer in the out-parameter.
            if out_json.value:
                lib.ace_string_free(out_json)

        xt = (ctypes.c_float * 4)(1.0, 1.0, 1.0, 1.0)
        vt = (ctypes.c_float * 4)(0.1, 0.2, 0.3, 0.4)
        out = (ctypes.c_float * len(xt))()
        rc = lib.ace_scheduler_step(ctx, xt, vt, len(xt), ctypes.c_float(0.5), out)
        if rc != 0:
            raise RuntimeError("ace_scheduler_step failed")
        print("scheduler:", list(out))
    finally:
        lib.ace_free_context(ctx)
    return 0
|
|
|
|
# Script entry point: run the smoke test and use its result as the process
# exit status.
if __name__ == "__main__":
    exit_code = main()
    raise SystemExit(exit_code)
|
|