File size: 7,519 Bytes
4252956
 
 
 
a8785dd
 
4252956
 
a8785dd
4252956
a8785dd
4252956
a8785dd
4252956
a8785dd
 
 
 
 
4252956
 
 
 
 
 
a8785dd
 
 
 
 
 
 
 
 
 
4252956
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a8785dd
 
 
 
 
 
 
 
 
 
4252956
 
 
 
 
 
 
 
 
 
 
 
a8785dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4252956
 
 
 
a8785dd
 
 
 
4252956
 
 
a8785dd
 
4252956
 
 
 
 
 
 
 
 
 
 
 
 
 
a8785dd
 
4252956
 
 
 
a8785dd
 
4252956
 
 
 
a8785dd
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
"""Runtime syntax-blacklist mask for live generation.

A lightweight, parse-free counterpart to tools/lilylet_blacklist_gen.py's
BlacklistMonitor. It does NOT call any parser/oracle and has no Node dependency:
it trusts a pre-discovered variable-length n-gram blacklist and masks the
forbidden next tokens during sampling. Suitable for the Gradio live-gen path.

Wire-compatible with StreamingLilyletGenerator's `monitor` hook:
  banned()          -> ids to mask for the next draw (suffix match on the context)
  accept(id)->bool  -> always True (we trust the mask; never re-parse)
  commit_forced(id) -> advance the running context for a forced token
Plus mark()/rollback() so the generator's `[r:0/<measures>]` priming re-sample
(a probe-then-discard draw) can rewind the context before the forced redraw.

The context is the last N CONTENT token ids (whitespace dropped, `[r:x/y]` stream
markers excluded), maintained INCREMENTALLY in id space — no per-token tokenizer
call. Markers are detected by buffering a potential `[r:…]` run and classifying
the tiny buffer text with the same regexes _clean uses, so the resulting context
is identical to clean_for_parse()+tokenize but at near-zero cost.
"""

import os
import re
import json

# Anchored marker regexes for the id-space state machine. They are applied with
# `.match()` (anchored at the start) and end in `\Z` (anchored at the end), so the
# WHOLE candidate text must fit the pattern. The candidate is the tiny marker
# BUFFER text (a possible `[r:…]` run), never the whole stream:
#   - COMPLETE `\[r:\d+/\d*\]` -> drop the buffer entirely (a finished marker;
#       note the digits after `/` are optional, so `[r:5/]` also counts)
#   - viable PARTIAL `\[r(:(\d+(/\d*)?)?)?` -> keep buffering; it is excluded from
#       the context (mirrors discovery's _MARKER_PARTIAL_TAIL at end-of-stream)
#   - the lone `[` does NOT match the partial regex (an `r` is required); it is
#       still buffered as a possible marker start but stays VISIBLE as content
#       (a bare `[` is a header `[composer …]` / beam `c8[`, which _clean keeps)
#   - anything else -> not a marker -> flush the buffer back into the context
_MARK_COMPLETE_FULL = re.compile(r'\[r:\d+/\d*\]\Z')
_MARK_PARTIAL_FULL = re.compile(r'\[r(:(\d+(/\d*)?)?)?\Z')


def _whitespace_ids (tokenizer):
	'''Collect the token ids of the whitespace characters known to the vocab.

	Each of space, newline, tab and carriage return is encoded on its own, in
	that order; a character contributes its id only when it encodes to exactly
	one token. The returned list is what the content context drops so that it
	lines up with the index space the blacklist was keyed in.'''
	# Lazily encode each character; the filter below consumes them in order.
	encodings = (tokenizer.encode(ch) for ch in (' ', '\n', '\t', '\r'))
	return [ids[0] for ids in encodings if len(ids) == 1]


def load_blacklist (path):
	'''Load a blacklist JSON file into dict[tuple(int,...) -> set[int]].

	The file holds a mapping {"id1,...,idn": [ids...]}, either bare or nested
	under a top-level "blacklist" key. Each key is a comma-separated,
	variable-length context n-gram; the empty string key becomes the empty
	tuple. Values are converted to sets of ints.

	Returns {} when `path` is falsy, does not name an existing file, or the
	top-level JSON value is not an object. Malformed JSON or non-integer ids
	still raise (json.JSONDecodeError / ValueError), as before.'''
	if not path or not os.path.isfile(path):
		return {}
	# Context manager so the handle is closed deterministically (the previous
	# `json.load(open(path))` leaked it until garbage collection).
	with open(path, encoding='utf-8') as fh:
		data = json.load(fh)
	raw = data.get('blacklist', data) if isinstance(data, dict) else {}
	return {
		(tuple(int(x) for x in key.split(',')) if key else ()): {int(i) for i in ids}
		for key, ids in raw.items()
	}


class MaskMonitor:
	'''Parse-free runtime mask. Construct with the generator (for tokenizer +
	special ids) and a variable-length-keyed blacklist dict; pass as `monitor=` to
	generate_stream. A key (a,b,c) masks its forbidden tokens whenever the running
	content context ends with the sequence a,b,c (longest/any suffix match). The
	empty key () is a suffix of every context, so its tokens are always masked.

	Generator-facing methods: banned(), accept(), commit_forced(), plus
	mark()/rollback() for rewinding a discarded probe draw.'''

	def __init__ (self, gen, blacklist):
		'''Capture tokenizer/special ids from `gen` and index the blacklist.

		gen       -- object exposing `tokenizer`, `pad_id`, `bos_id`, `eos_id`; the
		             tokenizer must provide `encode(str)` and a `text_by_id` mapping.
		blacklist -- dict[tuple[int,...] -> set[int]] (None/empty disables masking).'''
		self.tk = gen.tokenizer
		self.blacklist = blacklist or {}
		self.pad_id, self.bos_id, self.eos_id = gen.pad_id, gen.bos_id, gen.eos_id
		self._ws = set(_whitespace_ids(self.tk))
		# Distinct key lengths, longest first; empty blacklist -> [] and max 0.
		self._key_lengths = sorted({len(k) for k in self.blacklist}, reverse=True)
		self._max_ctx = self._key_lengths[0] if self._key_lengths else 0
		self._lbrack = self.tk.encode('[')[0]
		# Context is maintained INCREMENTALLY in id space — no per-token tokenizer
		# call. The only thing that needs care is excluding `[r:x/y]` stream markers,
		# whose chars share ids with real content. We buffer a *potential* marker
		# (always starting at `[`) and classify the tiny buffer text with the
		# module-level marker regexes (id->text is a cheap dict lookup); the intent
		# is a context identical to _clean()+encode of the full stream.
		self._ctx_ids = []        # confirmed visible content ids (whitespace dropped)
		self._buf = []            # token ids of an in-progress potential marker
		self._buf_text = ''       # their concatenated text (starts with '[')
		self._mark = None         # last mark() snapshot, or None if never marked

	def _is_content (self, tid):
		'''True unless `tid` is one of the pad/bos/eos special ids.'''
		return tid not in (self.pad_id, self.bos_id, self.eos_id)

	def _is_ctx (self, tid):
		'''True when `tid` belongs in the context: content and not whitespace.'''
		return self._is_content(tid) and int(tid) not in self._ws

	def _text (self, tid):
		'''Text of a token id via tokenizer.text_by_id ('' for specials/unknown).'''
		tid = int(tid)
		return '' if not self._is_content(tid) else self.tk.text_by_id.get(tid, '')

	def _push (self, tid):
		'''Append one id to the context and trim it to the last `_max_ctx` ids.'''
		self._ctx_ids.append(int(tid))
		# Count the surplus explicitly: `del ids[:-max]` is a no-op when max == 0
		# (because -0 == 0), which would let the context grow without bound.
		overflow = len(self._ctx_ids) - self._max_ctx
		if overflow > 0:
			del self._ctx_ids[:overflow]

	def _push_content (self, tid):
		'''Route a non-marker token into the context: drop whitespace, keep content.'''
		if self._is_ctx(tid):
			self._push(tid)

	def _flush_buf (self):
		'''The buffered tokens are NOT a marker after all -> they are real content;
		replay them through the content path, then clear the buffer.'''
		pending = self._buf
		# Reset first so the replay below sees an empty buffer.
		self._buf = []
		self._buf_text = ''
		for tid in pending:
			self._push_content(tid)

	def _feed (self, tid):
		'''Advance the incremental context with one committed token id.'''
		tid = int(tid)
		if not self._is_content(tid):
			# pad/bos/eos: not content, and breaks any pending marker buffer.
			if self._buf:
				self._flush_buf()
			return

		if self._buf:
			# extend the potential marker and re-classify the buffer text.
			text = self._buf_text + self._text(tid)
			if _MARK_COMPLETE_FULL.match(text):
				# a full `[r:x/y]` -> the whole buffer contributes nothing; drop it.
				self._buf = []
				self._buf_text = ''
				return
			if _MARK_PARTIAL_FULL.match(text):
				# still a viable marker prefix (`[r`, `[r:`, `[r:5`, `[r:5/`, `[r:5/3`)
				self._buf.append(tid)
				self._buf_text = text
				return
			# not a marker: flush the buffer as content, then process tid afresh below.
			self._flush_buf()

		if tid == self._lbrack:
			# start a potential marker (a lone `[` is itself real content until/unless
			# an `r` follows; that visibility is handled in banned()).
			self._buf = [tid]
			self._buf_text = '['
			return

		self._push_content(tid)

	# ---- generator-facing API ----

	def banned (self):
		'''Union forbidden sets of every stored key that is a suffix of the context.

		The visible context is _ctx_ids plus a pending lone `[` (which _clean keeps);
		a longer `[r…` partial buffer is stripped (matches _MARKER_PARTIAL_TAIL).

		Returns an empty tuple when the blacklist has no keys, otherwise a fresh
		set of token ids (possibly empty).'''
		if not self._key_lengths:
			return ()
		ctx = self._ctx_ids
		if self._buf_text == '[':
			# new list: the pending bracket must not leak into the stored context
			ctx = ctx + [self._lbrack]
		size = len(ctx)
		out = set()
		for n in self._key_lengths:
			if n <= size:
				# Slice from an explicit start index: `ctx[-n:]` with n == 0 would
				# yield the WHOLE context instead of the empty suffix, so the ()
				# key would only ever match an empty context.
				hit = self.blacklist.get(tuple(ctx[size - n:]))
				if hit:
					out |= hit
		return out

	def accept (self, tid):
		'''Commit a sampled token and report it as accepted (always True).'''
		# trust the mask: a non-banned draw is always committed (no re-parse)
		self.commit_forced(tid)
		return True

	def commit_forced (self, tid):
		'''Advance the context for a token; a no-op when no key needs context.'''
		if self._max_ctx:
			self._feed(tid)

	# ---- priming support (probe-then-discard rewind) ----

	def mark (self):
		'''Snapshot the context so a discarded priming-probe patch can be undone.'''
		# copies, so later in-place edits of the live lists leave the snapshot intact
		self._mark = (list(self._ctx_ids), list(self._buf), self._buf_text)

	def rollback (self):
		'''Rewind to the last mark() (drops a probe patch's effect on the context).
		Does nothing if mark() was never called; the snapshot stays reusable.'''
		if self._mark is not None:
			ctx_ids, buf, buf_text = self._mark
			self._ctx_ids = list(ctx_ids)
			self._buf = list(buf)
			self._buf_text = buf_text