Spaces:
Running
Running
| """Runtime syntax-blacklist mask for live generation. | |
| A lightweight, parse-free counterpart to tools/lilylet_blacklist_gen.py's | |
| BlacklistMonitor. It does NOT call any parser/oracle and has no Node dependency: | |
| it trusts a pre-discovered variable-length n-gram blacklist and masks the | |
| forbidden next tokens during sampling. Suitable for the Gradio live-gen path. | |
| Wire-compatible with StreamingLilyletGenerator's `monitor` hook: | |
| banned() -> ids to mask for the next draw (suffix match on the context) | |
| accept(id)->bool -> always True (we trust the mask; never re-parse) | |
| commit_forced(id) -> advance the running context for a forced token | |
| Plus mark()/rollback() so the generator's `[r:0/<measures>]` priming re-sample | |
| (a probe-then-discard draw) can rewind the context before the forced redraw. | |
| The context is the last N CONTENT token ids (whitespace dropped, `[r:x/y]` stream | |
| markers excluded), maintained INCREMENTALLY in id space — no per-token tokenizer | |
| call. Markers are detected by buffering a potential `[r:…]` run and classifying | |
| the tiny buffer text with the same regexes _clean uses, so the resulting context | |
| is identical to clean_for_parse()+tokenize but at near-zero cost. | |
| """ | |
| import os | |
| import re | |
| import json | |
| # Anchored marker regexes for the id-space state machine — applied to the tiny | |
| # marker BUFFER text (a candidate `[r:…]` run), never the whole stream: | |
| # - COMPLETE `\[r:\d+/\d*\]` -> drop the buffer entirely (a finished marker) | |
| # - viable PARTIAL `\[r(:(\d+(/\d*)?)?)?` -> keep buffering; it is excluded from | |
| # the context (mirrors discovery's _MARKER_PARTIAL_TAIL at end-of-stream) | |
| # - the lone `[` is a viable partial too, but stays VISIBLE as content (a bare | |
| # `[` is a header `[composer …]` / beam `c8[`, which _clean keeps) | |
| # - anything else -> not a marker -> flush the buffer back into the context | |
# Matches a buffer text that is exactly one finished marker: `[r:` + one or more
# digits + `/` + zero or more digits + `]`, anchored at the very end (`\Z`).
_MARK_COMPLETE_FULL = re.compile(r'\[r:\d+/\d*\]\Z')
# Matches a buffer text that is still a prefix of such a marker: `[r`, optionally
# followed by `:`, digits, `/`, digits — and nothing after it (`\Z`).
_MARK_PARTIAL_FULL = re.compile(r'\[r(:(\d+(/\d*)?)?)?\Z')
def _whitespace_ids (tokenizer):
    '''Return the ids of the whitespace characters the vocab encodes as ONE token.

    Space, newline, tab and carriage return are each passed to
    `tokenizer.encode`; a character contributes its id only when the encoding
    is exactly one id long. The result preserves that probing order. Callers
    use it to drop whitespace from the content context.'''
    encodings = (tokenizer.encode(char) for char in (' ', '\n', '\t', '\r'))
    # keep single-token encodings only; anything else is not a plain whitespace id
    return [ids[0] for ids in encodings if len(ids) == 1]
def load_blacklist (path):
    '''Load a blacklist JSON file into dict[tuple(int, ...) -> set[int]].

    The file holds a mapping {"id1,...,idn": [ids...]}, either bare or nested
    under a top-level "blacklist" key. Each key is a comma-separated,
    variable-length context n-gram; an empty key string maps to the empty
    context `()`.

    Args:
        path: filesystem path of the JSON file; may be None or empty.

    Returns:
        The parsed blacklist. `{}` when `path` is falsy, is not an existing
        file, or the JSON payload is not a mapping.

    Raises:
        json.JSONDecodeError: the file is not valid JSON.
        ValueError: a key part or a listed id is not an integer.
    '''
    if not path or not os.path.isfile(path):
        return {}
    # context manager closes the handle deterministically; JSON is UTF-8 text,
    # so do not depend on the platform's default encoding
    with open(path, encoding='utf-8') as fh:
        data = json.load(fh)
    raw = data.get('blacklist', data) if isinstance(data, dict) else {}
    if not isinstance(raw, dict):
        # same best-effort result as a non-mapping top level
        return {}
    bl = {}
    for key, ids in raw.items():
        ctx = tuple(int(part) for part in key.split(',')) if key else ()
        bl[ctx] = {int(i) for i in ids}
    return bl
class MaskMonitor:
    '''Parse-free runtime mask driven by a variable-length n-gram blacklist.

    Construct with the generator (source of the tokenizer and the pad/bos/eos
    ids) and a blacklist dict keyed by context tuples; pass the instance as
    `monitor=` to generate_stream. A key (a, b, c) masks its forbidden tokens
    whenever the running content context ends with a, b, c. The forbidden sets
    of ALL matching keys are unioned, and a zero-length key `()` is a suffix of
    every context, so it applies on every draw.

    The context holds the most recent content token ids: pad/bos/eos and
    whitespace ids are dropped, and `[r:x/y]` stream markers are excluded by
    buffering a candidate marker from its opening `[` and classifying the
    buffered text with the module's anchored marker regexes.'''

    def __init__ (self, gen, blacklist):
        '''Bind tokenizer and special ids from `gen`; index the blacklist by key length.

        Args:
            gen: object exposing `tokenizer`, `pad_id`, `bos_id` and `eos_id`.
            blacklist: dict[tuple[int, ...] -> collection of int]; None or
                empty disables masking.
        '''
        self.tk = gen.tokenizer
        self.blacklist = blacklist or {}
        self.pad_id, self.bos_id, self.eos_id = gen.pad_id, gen.bos_id, gen.eos_id
        self._ws = set(_whitespace_ids(self.tk))
        # distinct key lengths, longest first; empty when there is no blacklist
        self._key_lengths = sorted({len(k) for k in self.blacklist}, reverse=True)
        self._max_ctx = self._key_lengths[0] if self._key_lengths else 0
        self._lbrack = self.tk.encode('[')[0]
        # Context is maintained INCREMENTALLY in id space — no per-token tokenizer
        # call. The only thing that needs care is excluding `[r:x/y]` stream markers,
        # whose chars share ids with real content: a *potential* marker (always
        # starting at `[`) is buffered and classified from its tiny buffer text.
        self._ctx_ids = []      # confirmed visible content ids (whitespace dropped)
        self._buf = []          # token ids of an in-progress potential marker
        self._buf_text = ''     # their concatenated text (starts with '[')
        self._mark = None       # snapshot taken by mark(), restored by rollback()

    def _is_content (self, tid):
        '''True unless `tid` is one of the pad/bos/eos special ids.'''
        return tid not in (self.pad_id, self.bos_id, self.eos_id)

    def _is_ctx (self, tid):
        '''True when `tid` belongs in the context: content and not whitespace.'''
        return self._is_content(tid) and int(tid) not in self._ws

    def _text (self, tid):
        '''Text of a content token via the tokenizer's `text_by_id` map; '' for
        special ids and for ids missing from the map.'''
        tid = int(tid)
        return '' if not self._is_content(tid) else self.tk.text_by_id.get(tid, '')

    def _push (self, tid):
        '''Append one id to the context, keeping at most `_max_ctx` trailing ids.'''
        self._ctx_ids.append(int(tid))
        # Trim by explicit excess: `del lst[:-0]` would delete nothing, letting the
        # context grow without bound when the longest key has length 0.
        excess = len(self._ctx_ids) - self._max_ctx
        if excess > 0:
            del self._ctx_ids[:excess]

    def _push_content (self, tid):
        '''Route a non-marker token into the context: drop whitespace, keep content.'''
        if self._is_ctx(tid):
            self._push(tid)

    def _flush_buf (self):
        '''The buffered tokens are NOT a marker after all -> they are real content;
        replay them through the content path, then clear the buffer.'''
        pending = self._buf
        # reset first so the replay below starts from an empty buffer
        self._buf = []
        self._buf_text = ''
        for tid in pending:
            self._push_content(tid)

    def _feed (self, tid):
        '''Advance the incremental context with one committed token id.'''
        tid = int(tid)
        if not self._is_content(tid):
            # pad/bos/eos: not content, and breaks any pending marker buffer.
            if self._buf:
                self._flush_buf()
            return
        if self._buf:
            # extend the potential marker and re-classify the buffer text.
            text = self._buf_text + self._text(tid)
            if _MARK_COMPLETE_FULL.match(text):
                # a full `[r:x/y]` -> the whole buffer contributes nothing; drop it.
                self._buf = []
                self._buf_text = ''
                return
            if _MARK_PARTIAL_FULL.match(text):
                # still a viable marker prefix (`[r`, `[r:`, `[r:5`, `[r:5/`, `[r:5/3`)
                self._buf.append(tid)
                self._buf_text = text
                return
            # not a marker: flush the buffer as content, then process tid afresh below.
            self._flush_buf()
        if tid == self._lbrack:
            # start a potential marker (a lone `[` is itself real content until/unless
            # an `r` follows; that visibility is handled in banned()).
            self._buf = [tid]
            self._buf_text = '['
            return
        self._push_content(tid)

    # ---- generator-facing API ----
    def banned (self):
        '''Union the forbidden sets of every stored key that is a suffix of the context.

        The visible context is `_ctx_ids` plus a pending lone `[`; a longer
        `[r…` partial buffer is left out of it.

        Returns:
            `()` when the blacklist is empty, otherwise a fresh set of
            forbidden token ids (possibly empty).
        '''
        if not self._key_lengths:
            return ()
        ctx = self._ctx_ids
        if self._buf_text == '[':
            ctx = ctx + [self._lbrack]
        size = len(ctx)
        out = set()
        for n in self._key_lengths:
            if n <= size:
                # Slice from an explicit start index: `ctx[-n:]` with n == 0 is the
                # WHOLE context, which would stop the empty key `()` from matching
                # once the context is non-empty.
                hit = self.blacklist.get(tuple(ctx[size - n:]))
                if hit:
                    out.update(hit)
        return out

    def accept (self, tid):
        '''Commit a sampled token and report it accepted.

        The mask is trusted, so a non-banned draw is always committed (no
        re-parse) and the result is always True.'''
        self.commit_forced(tid)
        return True

    def commit_forced (self, tid):
        '''Advance the running context with `tid`; a no-op when no key needs context.'''
        if self._max_ctx:
            self._feed(tid)

    # ---- priming support (probe-then-discard rewind) ----
    def mark (self):
        '''Snapshot the context so a discarded priming-probe patch can be undone.'''
        self._mark = (list(self._ctx_ids), list(self._buf), self._buf_text)

    def rollback (self):
        '''Rewind to the last mark() (drops a probe patch's effect on the context).
        Does nothing when mark() was never called; the snapshot stays reusable.'''
        if self._mark is not None:
            ctx_ids, buf, buf_text = self._mark
            self._ctx_ids = list(ctx_ids)
            self._buf = list(buf)
            self._buf_text = buf_text