"""Standalone streaming Lilylet generator (int8 ONNX + two-level KV cache). Torch-free adaptation of deep-starry's ORTGeneratorKV (tests/bench_lilylet_int8_ort.py): the patch-level decoder runs incrementally through `patch_kv_int8.onnx` (O(1) per step) and the token decoder inside each patch through `token_kv_int8.onnx`, both via onnxruntime. The token embedding table and model geometry are loaded from vendored assets (`wte.npy`, `geometry.json`) instead of a torch model, and sampling is reimplemented in numpy — so the only runtime deps are onnxruntime + numpy. `generate_stream(...)` is a Python generator: it yields `(raw, pretty, done)` after every patch, where `raw` is the accumulated decoded text (for the run log) and `pretty = postprocess(raw)` (for the editor, segmented by measure). """ import os import json import logging LOG = logging.getLogger('lilyscript') import numpy as np import onnxruntime as ort from .postprocess import postprocess class _StreamAborted (Exception): '''Raised by the token loop when a syntax monitor exhausts its redraw budget, signalling generate_stream to end the stream cleanly rather than commit a bad token (which would corrupt the monitor's running text). Internal to the blacklist-discovery path; never raised when monitor is None.''' pass def sample_next (logits, rng, temperature=1.0, top_k=0, top_p=1.0, banned_ids=None): '''Sample one token id from a logits vector (numpy) with temperature/top-k/top-p. banned_ids: optional iterable of token ids to forbid; their logits are set to -inf before any other filtering, so they can never be drawn (used by the syntax-blacklist mask in tools/lilylet_blacklist_gen.py). Default None = no-op. ''' logits = logits.astype(np.float64) if banned_ids: logits = logits.copy() logits[list(banned_ids)] = -np.inf if temperature != 1.0: logits = logits / max(temperature, 1e-6) if top_k and top_k > 0: k = min(top_k, logits.shape[-1]) kth = np.sort(logits)[-k] logits = np.where(logits < kth, -np.inf, logits) if top_p and top_p < 1.0: order = np.argsort(logits)[::-1] sorted_logits = logits[order] probs = _softmax(sorted_logits) cdf = np.cumsum(probs) remove = cdf > top_p remove[1:] = remove[:-1].copy() remove[0] = False sorted_logits = np.where(remove, -np.inf, sorted_logits) logits = np.full_like(logits, -np.inf) logits[order] = sorted_logits probs = _softmax(logits) return int(rng.choice(len(probs), p=probs)) def _softmax (x): x = x - np.max(x) e = np.exp(x) return e / e.sum() def _physical_cores (): '''Best-effort physical (not logical/HT) core count via /proc/cpuinfo; None if unavailable. ORT's intra_op default (=0) maps to this on most CPU builds.''' try: phys, cur = set(), {} for line in open('/proc/cpuinfo'): line = line.strip() if not line: if 'physical id' in cur and 'core id' in cur: phys.add((cur['physical id'], cur['core id'])) cur = {} continue if ':' in line: k, v = line.split(':', 1) cur[k.strip()] = v.strip() return len(phys) or None except Exception: return None def _log_thread_info (so, sess): '''Log host CPU capacity + the ONNX Runtime intra/inter-op thread settings that are actually in effect. intra_op_num_threads/inter_op_num_threads == 0 means "ORT auto" — it picks the number of physical cores for the intra-op pool.''' logical = os.cpu_count() affinity = len(os.sched_getaffinity(0)) if hasattr(os, 'sched_getaffinity') else logical physical = _physical_cores() intra = so.intra_op_num_threads inter = so.inter_op_num_threads effective_intra = intra if intra else (physical or affinity or logical) LOG.info('CPU: %s logical / %s physical cores, %s available (affinity)', logical, physical if physical is not None else '?', affinity) LOG.info('ONNX Runtime threads: intra_op=%s (%s), inter_op=%s (%s) | execution_mode=%s', intra, 'auto -> ~%s' % effective_intra if intra == 0 else 'explicit', inter, 'auto' if inter == 0 else 'explicit', getattr(so, 'execution_mode', '?')) class StreamingLilyletGenerator: '''Loads the int8 KV ONNX sessions + vendored assets and streams generation.''' def __init__ (self, model_dir, asset_dir, threads=None): from .tokenizer import LilyletTokenizer geo = json.load(open(os.path.join(model_dir, 'geometry.json'))) self.patch_size = geo['patch_size'] self.pad_id = geo['pad_id'] self.bos_id = geo['bos_id'] self.eos_id = geo['eos_id'] # patch-level KV geometry self.n_layers = geo['patch']['n_layers'] self.n_kv = geo['patch']['n_kv_heads'] self.head_dim = geo['patch']['head_dim'] # token-level KV geometry self.t_layers = geo['token']['n_layers'] self.t_kv = geo['token']['n_kv_heads'] self.t_head_dim = geo['token']['head_dim'] self.tokenizer = LilyletTokenizer(os.path.join(asset_dir, 'lilylet-tokenizer.json')) self.wte = np.load(os.path.join(model_dir, 'wte.npy')) # [vocab, hidden] — model weight, lives with the onnx so = ort.SessionOptions() if threads: so.intra_op_num_threads = threads self.patch_kv_sess = ort.InferenceSession( os.path.join(model_dir, 'patch_kv_int8.onnx'), so, providers=['CPUExecutionProvider']) self.token_kv_sess = ort.InferenceSession( os.path.join(model_dir, 'token_kv_int8.onnx'), so, providers=['CPUExecutionProvider']) self.patch_out_names = [o.name for o in self.patch_kv_sess.get_outputs()] self.token_out_names = [o.name for o in self.token_kv_sess.get_outputs()] _log_thread_info(so, self.patch_kv_sess) # ---- text helpers (mirror LilyletPatchyGenerator.patch_to_text) ---- def patch_to_text (self, patch): out = [] for tid in patch: tid = int(tid) if tid == self.eos_id: break if tid in (self.pad_id, self.bos_id): continue out.append(self.tokenizer.text_by_id.get(tid, '')) return ''.join(out) def patches_to_text (self, patches): return ''.join(self.patch_to_text(p) for p in patches) # ---- KV plumbing ---- def _empty_patch_past (self): return [np.zeros((1, self.n_kv, 0, self.head_dim), dtype=np.float32) for _ in range(2 * self.n_layers)] def _empty_token_past (self): return [np.zeros((1, self.t_kv, 0, self.t_head_dim), dtype=np.float32) for _ in range(2 * self.t_layers)] def _patch_kv_step (self, patch_rows, past): '''Feed L new patches (list of patch_size-length id rows) + past KV. Returns (last_hidden [hidden], new_past list).''' feed = {'patches': np.asarray([patch_rows], dtype=np.int64)} for i in range(self.n_layers): feed[f'past_k_{i}'] = past[2 * i] feed[f'past_v_{i}'] = past[2 * i + 1] out = dict(zip(self.patch_out_names, self.patch_kv_sess.run(None, feed))) new_past = [] for i in range(self.n_layers): new_past.append(out[f'new_k_{i}']) new_past.append(out[f'new_v_{i}']) return out['hidden'][0, -1], new_past def _token_kv_step (self, emb_np, past): '''Feed L new token embeddings [1,L,hidden] + past KV. Returns (logits[-1], new_past).''' feed = {'inputs_embeds': emb_np.astype(np.float32)} for i in range(self.t_layers): feed[f'past_k_{i}'] = past[2 * i] feed[f'past_v_{i}'] = past[2 * i + 1] out = dict(zip(self.token_out_names, self.token_kv_sess.run(None, feed))) new_past = [] for i in range(self.t_layers): new_past.append(out[f'new_k_{i}']) new_past.append(out[f'new_v_{i}']) return out['logits'][0, -1], new_past def _generate_patch (self, last_hidden, rng, prefix_ids=None, temperature=1.0, top_k=0, top_p=1.0, monitor=None, max_redraws=16): '''Token-level decode for one patch through the token-KV session: feed the patch state, then bos+prefix embeddings, then sample until the patch is full. monitor: optional syntax-blacklist hook (default None -> unchanged behavior). - monitor.banned() -> iterable of token ids to mask for the next draw, - monitor.accept(id) -> bool: True to commit the drawn token, False to redraw (the monitor records the violation + grows its ban set internally), - monitor.commit_forced(id): notified of each forced (non-sampled) token so its running context/stream stay correct. Forced tokens (bos + prefix_ids) bypass accept()/redraw — only commit_forced. ''' generated = list(prefix_ids or []) tokens = [self.bos_id] + list(prefix_ids or []) if monitor is not None: for tid in prefix_ids or []: monitor.commit_forced(tid) past = self._empty_token_past() enc = last_hidden.reshape(1, 1, -1).astype(np.float32) logits, past = self._token_kv_step(enc, past) for i in range(1, len(tokens)): emb = self.wte[tokens[i]].reshape(1, 1, -1) logits, past = self._token_kv_step(emb, past) while len(generated) < self.patch_size: if monitor is None: nxt = sample_next(logits, rng, temperature=temperature, top_k=top_k, top_p=top_p) else: # draw with the current ban mask; on a rejected (illegal) token the # monitor records it and we redraw from the SAME logits with the now- # larger mask. The token KV cache is only advanced for the committed # token below, so redrawing here costs nothing. for _ in range(max_redraws): nxt = sample_next(logits, rng, temperature=temperature, top_k=top_k, top_p=top_p, banned_ids=monitor.banned()) if monitor.accept(nxt): break else: # every redraw was rejected. Committing a bad token would corrupt # the monitor's running text (every subsequent token would then # parse-fail and be mis-recorded), so abort the stream cleanly # instead — the caller (generate_stream) ends generation here. raise _StreamAborted() generated.append(nxt) if len(generated) < self.patch_size: emb = self.wte[nxt].reshape(1, 1, -1) logits, past = self._token_kv_step(emb, past) return generated def generate_stream (self, prompt_text='', max_patches=256, temperature=1.0, top_k=0, top_p=0.9, measures=None, seed=0, monitor=None): '''Autoregressive generation, yielding after every patch. Yields (raw, pretty, done): raw -- accumulated decoded patch text (with `[r:x/y]` markers), for the log pretty -- postprocess(raw), measure-segmented, for the editor done -- True on the final yield (EOS patch or max_patches reached) monitor: optional syntax-blacklist hook threaded into the token decode loop (default None -> behavior identical to before). When a monitor is given the `[r:0/]` priming re-sample is disabled: priming is a probe-then- discard draw that would corrupt the monitor's running stream, and the blacklist harness wants free generation anyway. ''' rng = np.random.default_rng(seed) bos_patch = [self.bos_id] * (self.patch_size - 1) + [self.eos_id] patches = [bos_patch] if prompt_text: for line in prompt_text.splitlines(): ids = self.tokenizer.encode(line + '\n') for i in range(0, len(ids), self.patch_size): chunk = ids[i:i + self.patch_size] patches.append(chunk + [self.pad_id] * (self.patch_size - len(chunk))) out_text = self.patches_to_text(patches[1:]) prime_ids = self.tokenizer.encode(f'[r:0/{measures}]') if measures is not None else None primed = False # seed the monitor's running context/stream from the prompt patches (if any) if monitor is not None and len(patches) > 1: for p in patches[1:]: for tid in p: monitor.commit_forced(tid) # prefill: run all seed patches through the patch-KV decoder in one call past = self._empty_patch_past() last, past = self._patch_kv_step(patches, past) yield out_text, postprocess(out_text), False for _ in range(max_patches): # If a monitor supports priming-rewind (mark/rollback), snapshot before the # draw so a discarded probe patch can be undone (see priming below). can_prime_with_monitor = monitor is not None and hasattr(monitor, 'mark') if can_prime_with_monitor: monitor.mark() try: patch_ids = self._generate_patch(last, rng, temperature=temperature, top_k=top_k, top_p=top_p, monitor=monitor) except _StreamAborted: break # first time the model emits a stream patch, re-sample with the forced # `[r:0/]` prefix so the body starts at the requested measure count. # With a priming-capable monitor we rewind its running text first, so the # discarded probe patch doesn't pollute the 2-gram context. (The discovery # harness's monitor lacks mark/rollback, so priming stays disabled there.) if prime_ids is not None and not primed and self.patch_to_text(patch_ids).startswith('[r:') \ and (monitor is None or can_prime_with_monitor): primed = True if can_prime_with_monitor: monitor.rollback() patch_ids = self._generate_patch(last, rng, prefix_ids=prime_ids, temperature=temperature, top_k=top_k, top_p=top_p, monitor=monitor) # EOS patch -> done if patch_ids[0] == self.bos_id and patch_ids[1] == self.eos_id: break out_text += self.patch_to_text(patch_ids) # mask tokens after the first EOS inside the patch to PAD before caching clean = list(patch_ids) seen_eos = False for j in range(len(clean)): if seen_eos: clean[j] = self.pad_id if clean[j] == self.eos_id: seen_eos = True # advance the patch-level cache by the one new patch -> next hidden state last, past = self._patch_kv_step([clean], past) yield out_text, postprocess(out_text), False yield out_text, postprocess(out_text), True