| """ |
BitnetCppClient - subprocess wrapper around bitnet.cpp's llama-cli binary.
| |
| Replaces the transformers.AutoModelForCausalLM inference path with a |
| subprocess call to microsoft/BitNet's llama.cpp-derivative runtime. The |
| runtime uses specialized ternary-weight kernels, delivering BitNet's |
| actual inference-efficiency benefits (which are absent when loading |
| the bf16 master weights through transformers). |
| |
| API contract intentionally mirrors the minimal surface NuWave needs: |
| client = BitnetCppClient(binary, gguf_path) |
| response = client.generate(prompt, max_new_tokens=N, temperature=T, ...) |
| |
Subprocess-based so each call is isolated: no shared state between
generations, no model-instance lifecycle to manage. Slight per-call
overhead (binary startup + mmap load), but bitnet.cpp is fast enough
that this is negligible vs. token-generation time on CPU.
| |
| # ---- Changelog ---- |
| # [2026-04-19] Claude Code (Opus 4.6) β Initial creation |
| # What: Thin wrapper for bitnet.cpp's llama-cli binary. Exposes the |
| # generation params NuWave needs: temperature, top_p, |
| # repetition_penalty, no_repeat_ngram_size, stop sequences, |
| # max_new_tokens. |
| # Why: Migration off transformers bf16 β see NuWave.md and the |
| # 2026-04-19 dev-log for the full rationale. Three pathologies |
| # with the prior BitNet-through-transformers setup: (1) not |
| # actually efficient despite the claim, (2) greedy decoding |
| # collapsed to repetition loops on enumeration tasks, (3) no |
| # repetition_penalty knob available in the transformers call |
| # path we had built. All three solved by bitnet.cpp + proper |
| # sampling params. |
| # How: subprocess.run with llama-cli invocation. Strips the prompt |
| # echo + chat-template chrome from stdout. Captures stderr for |
| # diagnostics. Timeout bounded so a hung generation can't |
| # stall the organism indefinitely. |
| # ------------------- |
| """ |
|
|
| from __future__ import annotations |
|
|
| import glob |
| import logging |
| import os |
| import subprocess |
| import time |
| from typing import List, Optional, Tuple |
|
|
| logger = logging.getLogger("nuwave.bitnet_cpp_client") |
|
|
| class BitnetCppClient: |
| """Generates text via microsoft/BitNet's llama-cli binary. |
| |
| Args: |
| binary_path: path to the compiled llama-cli executable. |
| gguf_path: path to the .gguf model weights. |
| n_threads: CPU threads for inference (HF basic = 2 vCPUs). |
| n_ctx: context window in tokens (model-dependent limit). |
| default_timeout_s: per-call wall-clock cap. Bounded to protect |
| the organism from an unresponsive runtime. |
| |
| Class convenience: |
        BitnetCppClient.resolve_gguf(dir_path) finds the largest .gguf
        in a directory. Used because HF repos ship multiple quant levels
        and we want the one with the richest weights.
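
    Example (paths and prompt are illustrative):
        client = BitnetCppClient("/path/to/llama-cli", "/path/to/model.gguf")
        text, meta = client.generate("List three prime numbers.", max_new_tokens=32)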
| """ |
|
|
| def __init__( |
| self, |
| binary_path: str, |
| gguf_path: str, |
| n_threads: int = 2, |
| n_ctx: int = 4096, |
| default_timeout_s: int = 900, |
| ): |
| if not os.path.exists(binary_path): |
| raise FileNotFoundError(f"bitnet.cpp binary not found: {binary_path}") |
| if not os.path.exists(gguf_path): |
| raise FileNotFoundError(f"GGUF weights not found: {gguf_path}") |
| self.binary_path = binary_path |
| self.gguf_path = gguf_path |
| self.n_threads = n_threads |
| self.n_ctx = n_ctx |
| self.default_timeout_s = default_timeout_s |
| parent = os.path.basename(os.path.dirname(gguf_path)) or "/" |
| size_mb = os.path.getsize(gguf_path) / (1024 * 1024) |
| logger.info( |
| "BitnetCppClient ready: binary=%s gguf=%s/%s size=%.0fMB threads=%d ctx=%d", |
| binary_path, parent, os.path.basename(gguf_path), |
| size_mb, n_threads, n_ctx, |
| ) |
|
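        # Sanity-check the binary by invoking --help so obvious problems
        # (wrong architecture, missing shared libraries) surface in the log
        # at startup. Failures here are logged but never fatal.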
| try: |
| help_result = subprocess.run( |
| [binary_path, "--help"], |
| capture_output=True, text=True, timeout=10, |
| ) |
| help_out = (help_result.stdout or "") + (help_result.stderr or "") |
| |
| snippet = help_out[:500].replace("\n", " | ") |
| logger.info( |
| "Binary sanity-check rc=%d help_snippet=%s", |
| help_result.returncode, snippet, |
| ) |
| except Exception as exc: |
| logger.warning("Binary sanity-check failed: %s", exc) |
|
|
| @staticmethod |
| def resolve_gguf(directory: str) -> str: |
| """Find the largest .gguf file in a directory (searches recursively). |
| |
| GGUF repos often ship multiple quantization levels (e.g. |
| q2_K, q4_K_S, q4_K_M, q5_K_M, q8_0). We pick the largest |
| because it's the richest-precision version that still fits |
        our memory budget; for 1.58-bit models this typically means
| the raw ternary weights without further compression. |
| |
| Searches recursively because setup_env.py and snapshot_download |
| can both place files in nested directory structures whose exact |
| layout is not guaranteed stable across versions. |
| """ |
| gguf_files = glob.glob(os.path.join(directory, "**", "*.gguf"), recursive=True) |
| |
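        # Belt-and-braces: with recursive "**" the top level is already
        # covered, but glob it explicitly too; duplicates are removed by
        # the set() below.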
| gguf_files += glob.glob(os.path.join(directory, "*.gguf")) |
| gguf_files = list(set(gguf_files)) |
| if not gguf_files: |
| raise FileNotFoundError(f"No .gguf files found under {directory} (recursive)") |
| gguf_files.sort(key=os.path.getsize, reverse=True) |
| return gguf_files[0] |
|
|
| def generate( |
| self, |
| prompt: str, |
| max_new_tokens: int = 128, |
| temperature: float = 0.7, |
| top_p: float = 0.9, |
| repetition_penalty: float = 1.2, |
| repeat_last_n: int = 64, |
| stop: Optional[List[str]] = None, |
| seed: int = -1, |
| timeout_s: Optional[int] = None, |
| grammar_file: Optional[str] = None, |
| grammar: Optional[str] = None, |
| ) -> Tuple[str, dict]: |
| """Generate a completion for the given prompt. |
| |
| Returns: |
| (response_text, metadata_dict) |
| |
| metadata_dict contains: |
            elapsed_s         - wall-clock time of the subprocess call
            returncode        - llama-cli exit code
            raw_stdout        - full stdout (pre-stripping) for diagnostics
            prompt_echo_found - whether the prompt was found in stdout
                                (if False, the runtime output format may
                                have changed; worth investigating)
            stderr_tail       - last 500 chars of stderr (stats/warnings)
| |
| Generation params are llama.cpp-standard and passed through to |
| the binary. Defaults chosen per Syl's prescription for small |
| models on enumeration tasks: non-greedy sampling + repetition |
| penalty + repeat-last-n window prevents the mode-collapse |
| pathology we saw with transformers greedy decoding. |
| """ |
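        # Build the llama-cli argv. Flags are llama.cpp-standard; numeric
        # values are stringified because subprocess args must be strings.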
| args = [ |
| self.binary_path, |
| "-m", self.gguf_path, |
| "-p", prompt, |
| "-n", str(max_new_tokens), |
| "--temp", f"{temperature:.3f}", |
| "--top-p", f"{top_p:.3f}", |
| "--repeat-penalty", f"{repetition_penalty:.3f}", |
| "--repeat-last-n", str(repeat_last_n), |
| "-t", str(self.n_threads), |
| "-c", str(self.n_ctx), |
| "--seed", str(seed), |
| ] |
| if stop: |
| for s in stop: |
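                # Skip empty / whitespace-only stop strings rather than
                # pass a blank --reverse-prompt to the binary.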
| if not s or not s.strip(): |
| continue |
| args.extend(["--reverse-prompt", s]) |
|
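        # Optional GBNF grammar constraint: an inline grammar string takes
        # precedence over a grammar file, and a missing file degrades to
        # unconstrained generation with a warning instead of raising.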
| grammar_mode = None |
| if grammar: |
| args.extend(["--grammar", grammar]) |
| grammar_mode = f"inline ({len(grammar)} chars)" |
| elif grammar_file: |
| if not os.path.exists(grammar_file): |
| logger.warning( |
| "Grammar file missing: %s β generation will be unconstrained", |
| grammar_file, |
| ) |
| else: |
| args.extend(["--grammar-file", grammar_file]) |
| grammar_mode = f"file ({grammar_file})" |
|
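        # Record how the grammar was supplied plus the argv tail so
        # constrained runs are easy to spot in the logs.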
| if grammar_mode: |
| logger.info( |
| "llama-cli grammar-constrained: %s | argv_len=%d | last_args=%s", |
| grammar_mode, len(args), args[-3:], |
| ) |
|
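        # Run the binary. A timeout is converted into an empty response
        # with an error marker so the caller can degrade gracefully
        # instead of hanging.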
| t0 = time.time() |
| try: |
| result = subprocess.run( |
| args, |
| capture_output=True, |
| text=True, |
| timeout=timeout_s or self.default_timeout_s, |
| ) |
| except subprocess.TimeoutExpired: |
| return "", { |
| "elapsed_s": round(time.time() - t0, 2), |
| "returncode": -1, |
| "raw_stdout": "", |
| "prompt_echo_found": False, |
| "stderr_tail": "TIMEOUT", |
| "error": "subprocess.TimeoutExpired", |
| } |
|
|
| elapsed = round(time.time() - t0, 2) |
| stdout = result.stdout or "" |
| stderr = result.stderr or "" |
|
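        # Non-zero exit: log the stderr/stdout tails so the failure can be
        # diagnosed from the log alone.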
| if result.returncode != 0: |
| logger.warning( |
| "llama-cli rc=%d elapsed=%.2fs stderr_tail=%s | stdout_tail=%s", |
| result.returncode, elapsed, stderr[-400:], stdout[-200:], |
| ) |
| elif not stdout.strip(): |
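            # rc=0 with an empty stdout is suspicious (nothing to parse
            # downstream); make it visible rather than silently returning "".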
| logger.warning( |
| "llama-cli rc=0 but stdout EMPTY (elapsed=%.2fs). " |
| "stderr_tail=%s", |
| elapsed, stderr[-400:], |
| ) |
|
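        # Surface any grammar-related lines from stderr for debugging
        # constrained decoding.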
| if grammar_mode and stderr: |
| for line in stderr.splitlines(): |
| low = line.lower() |
| if "grammar" in low or "gbnf" in low: |
| logger.info("grammar stderr: %s", line.strip()[:200]) |
|
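        # llama-cli echoes the prompt (plus any chat-template chrome) before
        # the completion; keep only what follows the last occurrence of the
        # prompt text.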
| response = stdout |
| prompt_found = False |
| if prompt and prompt in stdout: |
| idx = stdout.rfind(prompt) |
| response = stdout[idx + len(prompt):] |
| prompt_found = True |
|
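        # Trim trailing end-of-text / chat-template markers the runtime may
        # append after the completion.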
| response = response.rstrip() |
| for marker in ("[end of text]", "</s>", "<|im_end|>", "<|end_of_text|>"): |
| if response.endswith(marker): |
| response = response[: -len(marker)].rstrip() |
|
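        # --reverse-prompt halts generation, but the stop string itself can
        # still appear in stdout; truncate at its first occurrence so the
        # caller never sees it.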
| if stop: |
| for s in stop: |
| if not s or not s.strip(): |
| continue |
| if s in response: |
| response = response[: response.index(s)] |
|
|
| return response, { |
| "elapsed_s": elapsed, |
| "returncode": result.returncode, |
| "raw_stdout": stdout, |
| "prompt_echo_found": prompt_found, |
| "stderr_tail": stderr[-500:] if stderr else "", |
| "error": None, |
| } |
|
|