Spaces:
Running
Running
File size: 13,571 Bytes
b69de73 7f31eab b69de73 4252956 b69de73 4252956 b69de73 7f31eab b69de73 55b687e b69de73 55b687e b69de73 7f31eab b69de73 4252956 b69de73 4252956 b69de73 4252956 b69de73 4252956 b69de73 4252956 b69de73 4252956 b69de73 54a557e b69de73 4252956 b69de73 4252956 b69de73 4252956 b69de73 4252956 b69de73 4252956 b69de73 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 | """Standalone streaming Lilylet generator (int8 ONNX + two-level KV cache).
Torch-free adaptation of deep-starry's ORTGeneratorKV
(tests/bench_lilylet_int8_ort.py): the patch-level decoder runs incrementally
through `patch_kv_int8.onnx` (one new patch per step against the cached K/V, no prefix re-encode) and the token decoder inside each
patch through `token_kv_int8.onnx`, both via onnxruntime. The token embedding
table and model geometry are loaded from vendored assets (`wte.npy`,
`geometry.json`) instead of a torch model, and sampling is reimplemented in
numpy — so the only runtime deps are onnxruntime + numpy.
`generate_stream(...)` is a Python generator: it yields `(raw, pretty, done)`
after every patch, where `raw` is the accumulated decoded text (for the run
log) and `pretty = postprocess(raw)` (for the editor, segmented by measure).
"""
import os
import json
import logging
LOG = logging.getLogger('lilyscript')  # module logger under the fixed name 'lilyscript' (not __name__); _log_thread_info reports CPU/ORT thread settings through it
import numpy as np
import onnxruntime as ort
from .postprocess import postprocess
class _StreamAborted (Exception):
    '''Internal signal: the syntax monitor ran out of redraws for a token.

    The token loop raises this instead of committing a rejected token, because
    committing it would corrupt the monitor's running text. generate_stream
    catches it and ends the stream cleanly. Only the blacklist-discovery path
    (a monitor is supplied) can trigger it; it is never raised when monitor is None.
    '''
def sample_next (logits, rng, temperature=1.0, top_k=0, top_p=1.0, banned_ids=None):
    '''Sample one token id from a logits vector with temperature/top-k/top-p.

    logits:      1-D array-like of unnormalized scores (numpy array of any float
                 dtype, or a plain list). It is copied to float64, so the
                 caller's array is never modified.
    rng:         numpy Generator used for the draw (rng.choice).
    temperature: logits are divided by max(temperature, 1e-6) when != 1.0.
    top_k:       if > 0, keep only logits >= the k-th largest (ties are kept).
    top_p:       if < 1.0, nucleus filtering over the descending-probability
                 order; the most likely token is always kept.
    banned_ids:  optional iterable of token ids to forbid (list, set, numpy
                 array, ...). Their logits are set to -inf before any other
                 filtering, so they can never be drawn (used by the
                 syntax-blacklist mask in tools/lilylet_blacklist_gen.py).
                 Default None = no-op.

    Returns the drawn id as a Python int.
    Raises ValueError when no token is left to draw (every logit is -inf,
    e.g. because banned_ids covers the whole vocabulary).
    '''
    # np.array (unlike np.asarray) always copies, so the -inf masking below can
    # never leak into the caller's logits
    logits = np.array(logits, dtype=np.float64)
    if banned_ids is not None:
        # go through list() so sets/generators/numpy arrays all work; a bare
        # `if banned_ids:` is ambiguous for numpy arrays
        banned = np.asarray(list(banned_ids), dtype=np.intp)
        if banned.size:
            logits[banned] = -np.inf
    if temperature != 1.0:
        logits = logits / max(temperature, 1e-6)
    if top_k and top_k > 0:
        k = min(top_k, logits.shape[-1])
        kth = np.sort(logits)[-k]
        logits = np.where(logits < kth, -np.inf, logits)
    if top_p and top_p < 1.0:
        order = np.argsort(logits)[::-1]
        sorted_logits = logits[order]
        cdf = np.cumsum(_softmax(sorted_logits))
        # drop a token only if the mass strictly before it already exceeds
        # top_p: shift the mask right by one so the token crossing top_p stays
        remove = cdf > top_p
        remove[1:] = remove[:-1].copy()
        remove[0] = False
        sorted_logits = np.where(remove, -np.inf, sorted_logits)
        logits = np.full_like(logits, -np.inf)
        logits[order] = sorted_logits
    if np.all(np.isneginf(logits)):
        # softmax of an all -inf vector is NaN, which rng.choice rejects with
        # an opaque message; fail with an explicit one instead
        raise ValueError('sample_next: no drawable token left (every logit is -inf after banning/filtering)')
    probs = _softmax(logits)
    return int(rng.choice(len(probs), p=probs))
def _softmax (x):
    '''Numerically stable softmax of a 1-D array (the max is subtracted before exp).
    An all -inf input yields NaNs; sample_next rules that case out before its draw.'''
    shifted = x - np.max(x)
    e = np.exp(shifted)
    return e / e.sum()
def _physical_cores (path='/proc/cpuinfo'):
    '''Best-effort physical (not logical/HT) core count; None if unavailable.

    Parses a Linux /proc/cpuinfo-style file. Entries are separated by blank
    lines, and each physical core is one unique ("physical id", "core id")
    pair. Returns None when the file cannot be read or contains no such pairs.
    ORT's intra_op default (=0) maps to this on most CPU builds.

    path: file to parse; defaults to /proc/cpuinfo (overridable, e.g. for tests).
    '''
    cores = set()
    entry = {}

    def flush ():
        # record the current entry's core (if it has both ids), then start a new entry
        if 'physical id' in entry and 'core id' in entry:
            cores.add((entry['physical id'], entry['core id']))
        entry.clear()

    try:
        # `with` closes the handle deterministically; errors='replace' keeps a
        # stray non-UTF-8 byte from aborting the parse
        with open(path, encoding='utf-8', errors='replace') as fh:
            for raw in fh:
                line = raw.strip()
                if not line:
                    flush()
                    continue
                key, sep, value = line.partition(':')
                if sep:
                    entry[key.strip()] = value.strip()
    except OSError:
        return None
    # the final entry has no terminating blank line when the file does not end with one
    flush()
    return len(cores) or None
def _log_thread_info (so, sess):
    '''Report host CPU capacity and the ONNX Runtime thread settings in effect.

    so:   the SessionOptions the sessions were built with. Its
          intra_op_num_threads / inter_op_num_threads are read, and 0 means
          "ORT auto" (for intra-op, ORT picks the number of physical cores).
    sess: accepted for call compatibility; not used.
    Emits two LOG.info lines and returns None.
    '''
    n_logical = os.cpu_count()
    if hasattr(os, 'sched_getaffinity'):
        n_affinity = len(os.sched_getaffinity(0))
    else:
        n_affinity = n_logical
    n_physical = _physical_cores()
    intra_cfg = so.intra_op_num_threads
    inter_cfg = so.inter_op_num_threads
    if intra_cfg == 0:
        # auto: estimate what ORT will choose, falling back to coarser counts
        intra_note = 'auto -> ~%s' % (n_physical or n_affinity or n_logical)
    else:
        intra_note = 'explicit'
    inter_note = 'auto' if inter_cfg == 0 else 'explicit'
    shown_physical = '?' if n_physical is None else n_physical
    LOG.info('CPU: %s logical / %s physical cores, %s available (affinity)',
             n_logical, shown_physical, n_affinity)
    LOG.info('ONNX Runtime threads: intra_op=%s (%s), inter_op=%s (%s) | execution_mode=%s',
             intra_cfg, intra_note, inter_cfg, inter_note,
             getattr(so, 'execution_mode', '?'))
class StreamingLilyletGenerator:
    '''Loads the int8 KV ONNX sessions + vendored assets and streams generation.

    Two ONNX Runtime CPU sessions are created from one SessionOptions:
      * patch_kv_int8.onnx: the patch-level decoder. It is fed rows of
        ``patch_size`` token ids plus its past K/V. Its last hidden state seeds
        the token decoder for the next patch.
      * token_kv_int8.onnx: the token-level decoder run inside each patch. It is
        fed token embeddings (rows of ``wte``) plus its own past K/V, and it
        yields next-token logits.
    Past K/V for either session is a flat list [k_0, v_0, k_1, v_1, ...],
    with one (k, v) pair per layer.
    '''

    def __init__ (self, model_dir, asset_dir, threads=None):
        '''Load geometry, tokenizer, embedding table and both ONNX sessions.

        model_dir: directory with geometry.json, wte.npy, patch_kv_int8.onnx and
                   token_kv_int8.onnx.
        asset_dir: directory with lilylet-tokenizer.json.
        threads:   if truthy, used as ORT intra_op_num_threads; otherwise the
                   SessionOptions default is left as is.
        '''
        from .tokenizer import LilyletTokenizer
        # context manager so the geometry file handle is closed deterministically
        with open(os.path.join(model_dir, 'geometry.json'), encoding='utf-8') as fh:
            geo = json.load(fh)
        self.patch_size = geo['patch_size']
        self.pad_id = geo['pad_id']
        self.bos_id = geo['bos_id']
        self.eos_id = geo['eos_id']
        # patch-level KV geometry
        self.n_layers = geo['patch']['n_layers']
        self.n_kv = geo['patch']['n_kv_heads']
        self.head_dim = geo['patch']['head_dim']
        # token-level KV geometry
        self.t_layers = geo['token']['n_layers']
        self.t_kv = geo['token']['n_kv_heads']
        self.t_head_dim = geo['token']['head_dim']
        self.tokenizer = LilyletTokenizer(os.path.join(asset_dir, 'lilylet-tokenizer.json'))
        self.wte = np.load(os.path.join(model_dir, 'wte.npy'))  # [vocab, hidden] — model weight, lives with the onnx
        so = ort.SessionOptions()
        if threads:
            so.intra_op_num_threads = threads
        self.patch_kv_sess = ort.InferenceSession(
            os.path.join(model_dir, 'patch_kv_int8.onnx'), so, providers=['CPUExecutionProvider'])
        self.token_kv_sess = ort.InferenceSession(
            os.path.join(model_dir, 'token_kv_int8.onnx'), so, providers=['CPUExecutionProvider'])
        # output order as reported by each session; used to name run() results
        self.patch_out_names = [o.name for o in self.patch_kv_sess.get_outputs()]
        self.token_out_names = [o.name for o in self.token_kv_sess.get_outputs()]
        _log_thread_info(so, self.patch_kv_sess)

    # ---- text helpers (mirror LilyletPatchyGenerator.patch_to_text) ----
    def patch_to_text (self, patch):
        '''Decode one patch of ids. Stops at the first eos, skips pad/bos, and
        maps ids missing from the tokenizer table to the empty string.'''
        out = []
        for tid in patch:
            tid = int(tid)
            if tid == self.eos_id:
                break
            if tid in (self.pad_id, self.bos_id):
                continue
            out.append(self.tokenizer.text_by_id.get(tid, ''))
        return ''.join(out)

    def patches_to_text (self, patches):
        '''Concatenate patch_to_text over a sequence of patches.'''
        return ''.join(self.patch_to_text(p) for p in patches)

    def _mask_after_eos (self, patch):
        '''Return a list copy of *patch* with every id after the first eos set to
        pad (the eos itself is kept). A patch without eos is copied unchanged.'''
        clean = list(patch)
        try:
            cut = clean.index(self.eos_id) + 1
        except ValueError:
            return clean
        clean[cut:] = [self.pad_id] * (len(clean) - cut)
        return clean

    def _is_eos_patch (self, patch_ids):
        '''True for the end-of-stream patch, i.e. one whose first two ids are
        (bos, eos). Patches shorter than two ids never match.'''
        return len(patch_ids) >= 2 and patch_ids[0] == self.bos_id and patch_ids[1] == self.eos_id

    # ---- KV plumbing ----
    @staticmethod
    def _empty_past (n_layers, n_heads, head_dim):
        '''Zero-length past K/V: 2 * n_layers float32 arrays of shape (1, n_heads, 0, head_dim).'''
        return [np.zeros((1, n_heads, 0, head_dim), dtype=np.float32) for _ in range(2 * n_layers)]

    def _empty_patch_past (self):
        return self._empty_past(self.n_layers, self.n_kv, self.head_dim)

    def _empty_token_past (self):
        return self._empty_past(self.t_layers, self.t_kv, self.t_head_dim)

    @staticmethod
    def _run_kv (sess, out_names, n_layers, feed, past):
        '''Add past_k_i/past_v_i inputs to *feed*, run *sess*, and return
        (outputs keyed by name, new past list [new_k_0, new_v_0, ...]).'''
        for i in range(n_layers):
            feed[f'past_k_{i}'] = past[2 * i]
            feed[f'past_v_{i}'] = past[2 * i + 1]
        out = dict(zip(out_names, sess.run(None, feed)))
        new_past = []
        for i in range(n_layers):
            new_past.append(out[f'new_k_{i}'])
            new_past.append(out[f'new_v_{i}'])
        return out, new_past

    def _patch_kv_step (self, patch_rows, past):
        '''Feed L new patches (list of patch_size-length id rows) + past KV.
        Returns (last_hidden [hidden], new_past list).'''
        feed = {'patches': np.asarray([patch_rows], dtype=np.int64)}
        out, new_past = self._run_kv(self.patch_kv_sess, self.patch_out_names, self.n_layers, feed, past)
        return out['hidden'][0, -1], new_past

    def _token_kv_step (self, emb_np, past):
        '''Feed L new token embeddings [1,L,hidden] + past KV. Returns (logits[-1], new_past).'''
        feed = {'inputs_embeds': emb_np.astype(np.float32)}
        out, new_past = self._run_kv(self.token_kv_sess, self.token_out_names, self.t_layers, feed, past)
        return out['logits'][0, -1], new_past

    def _generate_patch (self, last_hidden, rng, prefix_ids=None, temperature=1.0, top_k=0, top_p=1.0,
            monitor=None, max_redraws=16):
        '''Token-level decode of one patch through the token-KV session.

        The patch hidden state is fed first, then the embeddings of prefix_ids.
        The bos id is never looked up in wte or fed. Tokens are then sampled
        until the patch holds patch_size ids. Returns the list of ids, with the
        forced prefix included.

        monitor: optional syntax-blacklist hook (default None -> unchanged behavior).
          - monitor.banned() -> iterable of token ids to mask for the next draw,
          - monitor.accept(id) -> bool: True to commit the drawn token, False to
            redraw (the monitor records the violation + grows its ban set internally),
          - monitor.commit_forced(id): notified of each forced (non-sampled)
            prefix token so its running context/stream stay correct.
        Forced prefix tokens bypass accept()/redraw and only go through commit_forced.
        max_redraws: draws attempted for one position before giving up.
        Raises _StreamAborted if the monitor rejects all max_redraws draws for a position.
        '''
        prefix = list(prefix_ids or [])
        generated = list(prefix)
        if monitor is not None:
            for tid in prefix:
                monitor.commit_forced(tid)
        past = self._empty_token_past()
        # the hidden state occupies input position 0 (presumably standing in for
        # bos, which is skipped -- TODO confirm against the ONNX export)
        enc = last_hidden.reshape(1, 1, -1).astype(np.float32)
        logits, past = self._token_kv_step(enc, past)
        for tid in prefix:
            logits, past = self._token_kv_step(self.wte[tid].reshape(1, 1, -1), past)
        while len(generated) < self.patch_size:
            if monitor is None:
                nxt = sample_next(logits, rng, temperature=temperature, top_k=top_k, top_p=top_p)
            else:
                # On a rejected (illegal) token the monitor records it, and we
                # redraw from the SAME logits with the now-larger mask. The token
                # KV cache only advances for the committed token, so a redraw
                # costs no model call.
                for _ in range(max_redraws):
                    nxt = sample_next(logits, rng, temperature=temperature, top_k=top_k, top_p=top_p,
                        banned_ids=monitor.banned())
                    if monitor.accept(nxt):
                        break
                else:
                    # Every redraw was rejected. Committing a bad token would
                    # corrupt the monitor's running text, so abort the stream instead.
                    raise _StreamAborted()
            generated.append(nxt)
            # the logits after the last slot would never be used, so skip that step
            if len(generated) < self.patch_size:
                logits, past = self._token_kv_step(self.wte[nxt].reshape(1, 1, -1), past)
        return generated

    def generate_stream (self, prompt_text='', max_patches=256, temperature=1.0, top_k=0,
            top_p=0.9, measures=None, seed=0, monitor=None):
        '''Autoregressive generation, yielding after every patch.

        Yields (raw, pretty, done):
          raw    -- accumulated decoded patch text (with `[r:x/y]` markers), for the log
          pretty -- postprocess(raw), measure-segmented, for the editor
          done   -- True on the final yield (EOS patch, max_patches reached, or
                    the monitor aborting the stream)
        There is one done=False yield after the prompt prefill and one per
        generated patch. A final done=True yield repeats the last text.

        prompt_text: optional seed text. Each line plus a newline is tokenized,
            split into pad-filled patches, and prefilled in a single call.
        measures: if >= 1, the first generated patch whose text starts with `[r:`
            is discarded and re-sampled with the forced prefix
            `[r:0/<measures-1>]`.
        seed: seed of the numpy Generator driving all sampling.
        monitor: optional syntax-blacklist hook threaded into the token decode loop
            (default None -> unchanged behavior).
            - Priming needs monitor.mark()/rollback() to undo the discarded
              probe patch. A monitor without mark() disables priming (the
              blacklist harness wants free generation anyway).
            - If the monitor exhausts its redraw budget, in either the probe
              or the primed re-sample, the stream ends cleanly with the
              done=True yield.
        '''
        rng = np.random.default_rng(seed)
        bos_patch = [self.bos_id] * (self.patch_size - 1) + [self.eos_id]
        patches = [bos_patch]
        if prompt_text:
            for line in prompt_text.splitlines():
                ids = self.tokenizer.encode(line + '\n')
                for i in range(0, len(ids), self.patch_size):
                    chunk = ids[i:i + self.patch_size]
                    patches.append(chunk + [self.pad_id] * (self.patch_size - len(chunk)))
        out_text = self.patches_to_text(patches[1:])
        # 0-based marker: `y` counts measures remaining AFTER this one (patchifier:
        # y = total - i - 1), so `[r:0/{measures-1}]` yields exactly `measures` total.
        prime_ids = None
        if measures is not None and measures >= 1:
            prime_ids = self.tokenizer.encode(f'[r:0/{measures - 1}]')
        primed = False
        # seed the monitor's running context/stream from the prompt patches (if any)
        if monitor is not None and len(patches) > 1:
            for p in patches[1:]:
                for tid in p:
                    monitor.commit_forced(tid)
        # prefill: run all seed patches through the patch-KV decoder in one call
        past = self._empty_patch_past()
        last, past = self._patch_kv_step(patches, past)
        yield out_text, postprocess(out_text), False
        can_prime_with_monitor = monitor is not None and hasattr(monitor, 'mark')
        for _ in range(max_patches):
            # snapshot before the draw so a discarded probe patch can be undone
            if can_prime_with_monitor:
                monitor.mark()
            # Both draws share one try: the primed re-sample can exhaust the
            # monitor's redraw budget too, and _StreamAborted is internal, so it
            # must never escape to the caller.
            try:
                patch_ids = self._generate_patch(last, rng, temperature=temperature, top_k=top_k, top_p=top_p,
                    monitor=monitor)
                # The first time the model emits a stream patch, re-sample with the
                # forced `[r:0/<measures-1>]` prefix so the body starts at the
                # requested measure count. A priming-capable monitor is rewound
                # first so the probe patch doesn't pollute its context.
                if prime_ids is not None and not primed and self.patch_to_text(patch_ids).startswith('[r:') \
                        and (monitor is None or can_prime_with_monitor):
                    primed = True
                    if can_prime_with_monitor:
                        monitor.rollback()
                    patch_ids = self._generate_patch(last, rng, prefix_ids=prime_ids,
                        temperature=temperature, top_k=top_k, top_p=top_p, monitor=monitor)
            except _StreamAborted:
                break
            if self._is_eos_patch(patch_ids):
                break
            out_text += self.patch_to_text(patch_ids)
            # advance the patch-level cache by the one new patch (ids after the
            # first EOS masked to PAD) -> next hidden state
            last, past = self._patch_kv_step([self._mask_after_eos(patch_ids)], past)
            yield out_text, postprocess(out_text), False
        yield out_text, postprocess(out_text), True
|