LilyScript / lilyscript /generator.py
k-l-lambda's picture
print log about runtime concurrency
7f31eab
"""Standalone streaming Lilylet generator (int8 ONNX + two-level KV cache).
Torch-free adaptation of deep-starry's ORTGeneratorKV
(tests/bench_lilylet_int8_ort.py): the patch-level decoder runs incrementally
through `patch_kv_int8.onnx` (O(1) per step) and the token decoder inside each
patch through `token_kv_int8.onnx`, both via onnxruntime. The token embedding
table and model geometry are loaded from vendored assets (`wte.npy`,
`geometry.json`) instead of a torch model, and sampling is reimplemented in
numpy — so the only runtime deps are onnxruntime + numpy.
`generate_stream(...)` is a Python generator: it yields `(raw, pretty, done)`
after every patch, where `raw` is the accumulated decoded text (for the run
log) and `pretty = postprocess(raw)` (for the editor, segmented by measure).
"""
import os
import json
import logging
LOG = logging.getLogger('lilyscript')
import numpy as np
import onnxruntime as ort
from .postprocess import postprocess
class _StreamAborted (Exception):
	'''Internal stop signal for the syntax-monitor (blacklist-discovery) path.

	_generate_patch raises it once a monitor has rejected every draw within the
	redraw budget. Committing the rejected token instead would corrupt the
	monitor's running text, so generate_stream catches this and finishes the
	stream cleanly. It is never raised when monitor is None.'''
def sample_next (logits, rng, temperature=1.0, top_k=0, top_p=1.0, banned_ids=None):
	'''Sample one token id from a 1-D logits vector (numpy) with temperature/top-k/top-p.

	Filtering order: banned_ids mask -> temperature -> top-k -> top-p (nucleus),
	then a single rng.choice draw from the softmax of whatever survives.

	banned_ids: optional iterable of token ids to forbid (set, list, numpy
		array, ...). Their logits are set to -inf before any other filtering,
		so they can never be drawn (used by the syntax-blacklist mask in
		tools/lilylet_blacklist_gen.py). Default None = no-op.
	temperature: logits are divided by max(temperature, 1e-6); 1.0 = unchanged.
	top_k: keep logits >= the k-th largest (ties are kept); 0 = disabled.
	top_p: keep the highest-probability tokens up to and including the one whose
		cumulative mass first exceeds top_p (the top token is always kept);
		>= 1.0 = disabled.
	Returns the drawn id as a Python int.
	Raises ValueError if every id is masked to -inf (nothing left to sample).
	'''
	# astype() returns a fresh array, so the in-place edits below never touch
	# the caller's logits.
	logits = logits.astype(np.float64)
	if banned_ids is not None:
		# materialize first: `if ndarray:` is ambiguous for multi-element arrays
		banned = list(banned_ids)
		if banned:
			logits[banned] = -np.inf
	if np.isneginf(logits).all():
		# otherwise softmax would produce NaNs and rng.choice would fail with an
		# opaque "probabilities contain NaN" error
		raise ValueError('sample_next: every token id is masked (-inf); nothing left to sample')
	if temperature != 1.0:
		logits = logits / max(temperature, 1e-6)
	if top_k and top_k > 0:
		k = min(top_k, logits.shape[-1])
		# O(V) partial sort; yields the same k-th largest value as np.sort(logits)[-k]
		kth = np.partition(logits, -k)[-k]
		logits = np.where(logits < kth, -np.inf, logits)
	if top_p and top_p < 1.0:
		order = np.argsort(logits)[::-1]
		sorted_logits = logits[order]
		cdf = np.cumsum(_softmax(sorted_logits))
		# shift the mask right by one so the token that crosses top_p is kept,
		# and never drop the single most likely token
		remove = cdf > top_p
		remove[1:] = remove[:-1].copy()
		remove[0] = False
		sorted_logits = np.where(remove, -np.inf, sorted_logits)
		logits = np.full_like(logits, -np.inf)
		logits[order] = sorted_logits
	probs = _softmax(logits)
	return int(rng.choice(len(probs), p=probs))
def _softmax (x):
	'''Numerically stable softmax: shift by the maximum, exponentiate, normalize.

	Entries equal to -inf get probability 0. An all -inf input yields NaNs,
	because the max shift computes -inf - (-inf).'''
	exps = np.exp(np.subtract(x, np.max(x)))
	return exps / exps.sum()
def _physical_cores (path='/proc/cpuinfo'):
	'''Best-effort physical (not logical/HT) core count; None if unavailable.

	Parses a Linux /proc/cpuinfo-style file made of blank-line-separated
	per-processor records of "key : value" lines. It counts the distinct
	(physical id, core id) pairs.

	Returns None when the file cannot be read, or when no record carries both
	keys (some kernels/architectures omit them). ORT's intra_op default (=0)
	maps to this count on most CPU builds.

	path: file to parse; defaults to /proc/cpuinfo (overridable for tests or
		non-standard mounts).
	'''
	try:
		# `with` closes the handle deterministically; errors='replace' keeps a
		# stray non-UTF-8 byte from aborting the whole parse.
		with open(path, encoding='utf-8', errors='replace') as f:
			lines = f.read().splitlines()
	except OSError:
		return None
	phys, cur = set(), {}
	# the trailing '' sentinel flushes a final record that has no blank line after it
	for line in lines + ['']:
		line = line.strip()
		if not line:
			if 'physical id' in cur and 'core id' in cur:
				phys.add((cur['physical id'], cur['core id']))
			cur = {}
			continue
		if ':' in line:
			key, value = line.split(':', 1)
			cur[key.strip()] = value.strip()
	return len(phys) or None
def _log_thread_info (so, sess):
	'''Log the host's CPU capacity and the ONNX Runtime thread settings in force.

	so: the ort.SessionOptions the sessions were created with. Its
		intra_op_num_threads / inter_op_num_threads are reported; 0 means
		"ORT auto", in which case ORT sizes the intra-op pool to the number of
		physical cores.
	sess: accepted for call-site compatibility; not read here.
	'''
	n_logical = os.cpu_count()
	getaffinity = getattr(os, 'sched_getaffinity', None)
	n_affinity = n_logical if getaffinity is None else len(getaffinity(0))
	n_physical = _physical_cores()
	intra = so.intra_op_num_threads
	inter = so.inter_op_num_threads

	LOG.info('CPU: %s logical / %s physical cores, %s available (affinity)',
		n_logical, '?' if n_physical is None else n_physical, n_affinity)

	if intra == 0:
		# estimate ORT's auto choice: physical cores, else usable cores, else logical
		intra_desc = 'auto -> ~%s' % (n_physical or n_affinity or n_logical)
	else:
		intra_desc = 'explicit'
	inter_desc = 'auto' if inter == 0 else 'explicit'
	LOG.info('ONNX Runtime threads: intra_op=%s (%s), inter_op=%s (%s) | execution_mode=%s',
		intra, intra_desc, inter, inter_desc, getattr(so, 'execution_mode', '?'))
class StreamingLilyletGenerator:
	'''Loads the int8 KV ONNX sessions + vendored assets and streams generation.

	model_dir must contain geometry.json, wte.npy, patch_kv_int8.onnx and
	token_kv_int8.onnx. asset_dir must contain lilylet-tokenizer.json. When
	`threads` is truthy it sets ORT's intra_op_num_threads for both sessions;
	otherwise ORT's default is used.
	'''

	def __init__ (self, model_dir, asset_dir, threads=None):
		from .tokenizer import LilyletTokenizer

		# `with` closes the handle deterministically (a bare open() leaks it
		# until garbage collection).
		with open(os.path.join(model_dir, 'geometry.json'), encoding='utf-8') as f:
			geo = json.load(f)
		self.patch_size = geo['patch_size']
		self.pad_id = geo['pad_id']
		self.bos_id = geo['bos_id']
		self.eos_id = geo['eos_id']
		# patch-level KV geometry
		self.n_layers = geo['patch']['n_layers']
		self.n_kv = geo['patch']['n_kv_heads']
		self.head_dim = geo['patch']['head_dim']
		# token-level KV geometry
		self.t_layers = geo['token']['n_layers']
		self.t_kv = geo['token']['n_kv_heads']
		self.t_head_dim = geo['token']['head_dim']

		self.tokenizer = LilyletTokenizer(os.path.join(asset_dir, 'lilylet-tokenizer.json'))
		self.wte = np.load(os.path.join(model_dir, 'wte.npy'))	# [vocab, hidden] — model weight, lives with the onnx

		so = ort.SessionOptions()
		if threads:
			so.intra_op_num_threads = threads
		self.patch_kv_sess = ort.InferenceSession(
			os.path.join(model_dir, 'patch_kv_int8.onnx'), so, providers=['CPUExecutionProvider'])
		self.token_kv_sess = ort.InferenceSession(
			os.path.join(model_dir, 'token_kv_int8.onnx'), so, providers=['CPUExecutionProvider'])
		self.patch_out_names = [o.name for o in self.patch_kv_sess.get_outputs()]
		self.token_out_names = [o.name for o in self.token_kv_sess.get_outputs()]

		_log_thread_info(so, self.patch_kv_sess)

	# ---- text helpers (mirror LilyletPatchyGenerator.patch_to_text) ----
	def patch_to_text (self, patch):
		'''Decode one patch: skip PAD/BOS, stop at the first EOS, unknown ids -> ''.'''
		out = []
		for tid in patch:
			tid = int(tid)
			if tid == self.eos_id:
				break
			if tid in (self.pad_id, self.bos_id):
				continue
			out.append(self.tokenizer.text_by_id.get(tid, ''))
		return ''.join(out)

	def patches_to_text (self, patches):
		'''Concatenate patch_to_text over a sequence of patches.'''
		return ''.join(self.patch_to_text(p) for p in patches)

	def _prompt_patches (self, prompt_text):
		'''Encode the prompt line by line (each line newline-terminated) into
		PAD-padded rows of exactly patch_size ids. Empty/None prompt -> [].'''
		rows = []
		for line in (prompt_text or '').splitlines():
			ids = self.tokenizer.encode(line + '\n')
			for start in range(0, len(ids), self.patch_size):
				chunk = list(ids[start:start + self.patch_size])
				rows.append(chunk + [self.pad_id] * (self.patch_size - len(chunk)))
		return rows

	def _mask_after_eos (self, patch_ids):
		'''Return a list copy of patch_ids with every token after the first EOS
		replaced by PAD (the EOS itself is kept). Used before caching a patch.'''
		clean = list(patch_ids)
		if self.eos_id in clean:
			cut = clean.index(self.eos_id) + 1
			clean[cut:] = [self.pad_id] * (len(clean) - cut)
		return clean

	# ---- KV plumbing ----
	# past/new_past lists are laid out as [k_0, v_0, k_1, v_1, ...].
	def _empty_patch_past (self):
		return [np.zeros((1, self.n_kv, 0, self.head_dim), dtype=np.float32) for _ in range(2 * self.n_layers)]

	def _empty_token_past (self):
		return [np.zeros((1, self.t_kv, 0, self.t_head_dim), dtype=np.float32) for _ in range(2 * self.t_layers)]

	@staticmethod
	def _run_kv (sess, out_names, n_layers, feed, past):
		'''Add past_k_i/past_v_i from `past` to `feed`, run `sess`, and return
		(outputs dict keyed by output name, new_past list from new_k_i/new_v_i).'''
		for i in range(n_layers):
			feed[f'past_k_{i}'] = past[2 * i]
			feed[f'past_v_{i}'] = past[2 * i + 1]
		out = dict(zip(out_names, sess.run(None, feed)))
		new_past = []
		for i in range(n_layers):
			new_past.append(out[f'new_k_{i}'])
			new_past.append(out[f'new_v_{i}'])
		return out, new_past

	def _patch_kv_step (self, patch_rows, past):
		'''Feed L new patches (list of patch_size-length id rows) + past KV.
		Returns (hidden state at the last fed position, new_past list).'''
		feed = {'patches': np.asarray([patch_rows], dtype=np.int64)}
		out, new_past = self._run_kv(self.patch_kv_sess, self.patch_out_names, self.n_layers, feed, past)
		return out['hidden'][0, -1], new_past

	def _token_kv_step (self, emb_np, past):
		'''Feed L new token embeddings [1,L,hidden] + past KV. Returns (logits[-1], new_past).'''
		feed = {'inputs_embeds': emb_np.astype(np.float32)}
		out, new_past = self._run_kv(self.token_kv_sess, self.token_out_names, self.t_layers, feed, past)
		return out['logits'][0, -1], new_past

	def _draw_token (self, logits, rng, temperature, top_k, top_p, monitor, max_redraws):
		'''Sample one token id from `logits`, consulting the optional monitor.

		Without a monitor this is a single sample_next draw. With one, each draw
		uses monitor.banned() as the mask. A draw rejected by monitor.accept() is
		redrawn from the SAME logits (the monitor grows its ban set internally).
		Raises _StreamAborted after max_redraws rejections.
		'''
		if monitor is None:
			return sample_next(logits, rng, temperature=temperature, top_k=top_k, top_p=top_p)
		for _ in range(max_redraws):
			nxt = sample_next(logits, rng, temperature=temperature, top_k=top_k, top_p=top_p,
				banned_ids=monitor.banned())
			if monitor.accept(nxt):
				return nxt
		# Every redraw was rejected. Committing a bad token would corrupt the
		# monitor's running text (every later token would parse-fail and be
		# mis-recorded), so signal generate_stream to end the stream cleanly.
		raise _StreamAborted()

	def _generate_patch (self, last_hidden, rng, prefix_ids=None, temperature=1.0, top_k=0, top_p=1.0,
			monitor=None, max_redraws=16):
		'''Token-level decode for one patch through the token-KV session.

		The patch hidden state is fed first and takes the bos position (the bos
		id itself is never embedded). The prefix_ids embeddings follow, then
		tokens are sampled until the patch holds patch_size ids. Returns the
		list of ids, prefix included.

		monitor: optional syntax-blacklist hook (default None -> unchanged behavior).
		  - monitor.banned() -> iterable of token ids to mask for the next draw,
		  - monitor.accept(id) -> bool: True to commit the drawn token, False to
		    redraw (the monitor records the violation + grows its ban set internally),
		  - monitor.commit_forced(id): called once per prefix_ids token so its
		    running context/stream stay correct (bos is not reported).
		Forced prefix tokens bypass accept()/redraw.
		Raises _StreamAborted when the monitor rejects max_redraws draws in a row.
		'''
		prefix = list(prefix_ids or [])
		generated = list(prefix)
		if monitor is not None:
			for tid in prefix:
				monitor.commit_forced(tid)

		past = self._empty_token_past()
		enc = last_hidden.reshape(1, 1, -1).astype(np.float32)
		logits, past = self._token_kv_step(enc, past)
		for tid in prefix:
			logits, past = self._token_kv_step(self.wte[tid].reshape(1, 1, -1), past)

		while len(generated) < self.patch_size:
			nxt = self._draw_token(logits, rng, temperature, top_k, top_p, monitor, max_redraws)
			generated.append(nxt)
			# the token KV cache only advances for committed tokens, and not at all
			# after the patch's last slot (nothing further is sampled from it)
			if len(generated) < self.patch_size:
				logits, past = self._token_kv_step(self.wte[nxt].reshape(1, 1, -1), past)
		return generated

	def generate_stream (self, prompt_text='', max_patches=256, temperature=1.0, top_k=0,
			top_p=0.9, measures=None, seed=0, monitor=None):
		'''Autoregressive generation, yielding after every patch.

		Yields (raw, pretty, done):
			raw    -- accumulated decoded patch text (with `[r:x/y]` markers), for the log
			pretty -- postprocess(raw), measure-segmented, for the editor
			done   -- True on the final yield (EOS patch, max_patches reached, or
			          the monitor aborted the stream)

		monitor: optional syntax-blacklist hook threaded into the token decode loop
		(default None -> behavior identical to before). When a monitor is given,
		the `[r:0/<measures>]` priming re-sample only runs if the monitor supports
		mark()/rollback(). Priming is a probe-then-discard draw, and without
		rollback the discarded probe would corrupt the monitor's running stream;
		the blacklist harness wants free generation anyway.
		If the monitor exhausts its redraw budget, on the probe draw or on the
		primed re-sample, the stream ends cleanly with a done=True yield.
		'''
		rng = np.random.default_rng(seed)
		bos_patch = [self.bos_id] * (self.patch_size - 1) + [self.eos_id]
		prompt_patches = self._prompt_patches(prompt_text)
		patches = [bos_patch] + prompt_patches
		out_text = self.patches_to_text(prompt_patches)
		prime_ids = self.tokenizer.encode(f'[r:0/{measures}]') if measures is not None else None
		primed = False

		# seed the monitor's running context/stream from the prompt patches (PAD ids included)
		if monitor is not None:
			for row in prompt_patches:
				for tid in row:
					monitor.commit_forced(tid)

		# prefill: run all seed patches through the patch-KV decoder in one call
		last, past = self._patch_kv_step(patches, self._empty_patch_past())
		yield out_text, postprocess(out_text), False

		sample_kw = dict(temperature=temperature, top_k=top_k, top_p=top_p, monitor=monitor)
		can_prime_with_monitor = monitor is not None and hasattr(monitor, 'mark')
		for _ in range(max_patches):
			# snapshot before the draw so a discarded probe patch can be undone
			if can_prime_with_monitor:
				monitor.mark()
			try:
				patch_ids = self._generate_patch(last, rng, **sample_kw)
				# The first time the model emits a stream patch, re-sample with the
				# forced `[r:0/<measures>]` prefix so the body starts at the requested
				# measure count. A priming-capable monitor is rewound first so the
				# discarded probe doesn't pollute its context.
				if prime_ids is not None and not primed \
						and self.patch_to_text(patch_ids).startswith('[r:') \
						and (monitor is None or can_prime_with_monitor):
					primed = True
					if can_prime_with_monitor:
						monitor.rollback()
					patch_ids = self._generate_patch(last, rng, prefix_ids=prime_ids, **sample_kw)
			except _StreamAborted:
				# covers BOTH draws above; previously an abort during the primed
				# re-sample escaped to the caller as an internal exception
				break

			# EOS patch -> done (slice comparison is safe for tiny patch sizes)
			if patch_ids[:2] == [self.bos_id, self.eos_id]:
				break
			out_text += self.patch_to_text(patch_ids)
			# advance the patch-level cache by the one new (EOS-masked) patch
			last, past = self._patch_kv_step([self._mask_after_eos(patch_ids)], past)
			yield out_text, postprocess(out_text), False

		yield out_text, postprocess(out_text), True