Jonttup's picture
Upload models/edit_ops.py with huggingface_hub
40c6c49 verified
"""
Parametric Edit Operations Engine — 31 Operations
Executes edit operations on palette tensors with pointer-based addressing.
Operations preserve scope balance and support region-relative indexing.
Three levels:
Level 1 — Primitive (8 ops): token-level atomic edits
Level 2 — Structural (11 ops): scope-tree mutations
Level 3 — Semantic (12 ops): meaning-level transforms
Key Features:
- Parametric actions with arguments
- Pointer arithmetic (region-relative → absolute)
- Cross-region operations (MOVE, COPY, EXTRACT, INLINE, MERGE, etc.)
- Scope balance verification
- Macro pattern transformations
- Stateless execution (pure functions)
"""
import torch
from dataclasses import dataclass, field
from typing import Tuple, Optional, List, Dict
from enum import IntEnum
# Import RegionMetadata from scope_pooler
from .scope_pooler import RegionMetadata
class OpCode(IntEnum):
    """Operation codes — 31 edit operations across 3 levels.

    The numeric band of each value determines its level in ``OP_LEVEL``:
    0-99 control, 100-199 primitive, 200-299 structural, 300+ semantic.
    Member declaration order is significant: ``TRAINABLE_OPS`` iterates the
    enum, so it fixes the ``OP_TO_IDX`` / ``IDX_TO_OP`` index assignment.
    """
    # === Control (0-99) ===
    NO_OP = 0
    MOVE_NEXT = 1
    FOCUS_PARENT = 2
    DONE = 99
    # === Level 1: Primitive — token-level atomic (100-199) ===
    DELETE_RANGE = 100
    INSERT_TOKEN = 101
    REPLACE_TOKEN = 102
    SWAP_TOKENS = 103
    MOVE_RANGE = 104
    COPY_RANGE = 105
    WRAP_SCOPE = 106
    UNWRAP_SCOPE = 107
    # === Level 2: Structural — scope-tree mutations (200-299) ===
    INDENT = 200
    DEDENT = 201
    EXTRACT = 202
    INLINE = 203
    SPLIT_REGION = 204
    MERGE_REGIONS = 205
    REORDER = 206
    NEST_IN_BLOCK = 207
    UNNEST_FROM_BLOCK = 208
    HOIST = 209
    SINK = 210
    # === Level 3: Semantic — meaning-level transforms (300-399) ===
    RENAME = 300
    RETYPE = 301
    CONVERT_CONSTRUCT = 302
    SYNC_TO_ASYNC = 303
    PARAMETERIZE = 304
    SPECIALIZE = 305
    GUARD = 306
    UNGUARD = 307
    SCATTER = 308
    GATHER = 309
    MIRROR = 310
    COMPOSE = 311


# Backward-compat aliases: old numeric op ids -> current OpCode members.
# PaletteEditOps.apply translates action.op_id through this table first.
_LEGACY_OPCODES: Dict[int, OpCode] = {
    150: OpCode.DELETE_RANGE,
    151: OpCode.INSERT_TOKEN,
    152: OpCode.REPLACE_TOKEN,
    153: OpCode.SWAP_TOKENS,
    499: OpCode.DONE,
}


# ---- Op metadata for classification ----
def _op_level(op: OpCode) -> str:
    """Return 'control', 'primitive', 'structural' or 'semantic' for *op*."""
    value = int(op)
    if value < 100:
        return 'control'
    if value < 200:
        return 'primitive'
    if value < 300:
        return 'structural'
    return 'semantic'


# Built with a comprehension so no loop variables leak into the module
# namespace (the former `for op in OpCode` loop exported `op` and `v`).
OP_LEVEL: Dict[OpCode, str] = {op: _op_level(op) for op in OpCode}
# Canonical list of the 31 trainable ops (excludes control), in declaration order
TRAINABLE_OPS: List[OpCode] = [op for op in OpCode if OP_LEVEL[op] != 'control']
NUM_OPS = len(TRAINABLE_OPS)  # 31
# Classifier index <-> OpCode (index = position in TRAINABLE_OPS)
OP_TO_IDX: Dict[OpCode, int] = {op: i for i, op in enumerate(TRAINABLE_OPS)}
IDX_TO_OP: Dict[int, OpCode] = {i: op for i, op in enumerate(TRAINABLE_OPS)}
@dataclass
class EditAction:
    """
    Parametric edit operation with arguments.

    Fields:
        op_id: Operation code (an OpCode value, or a legacy id that
            PaletteEditOps.apply translates before dispatch)
        region_id: Which semantic region to operate on [0, R)
        i_start: Token index within region (relative addressing)
        i_end: End token index (for range operations, -1 if unused)
        payload_idx: Palette index to insert/replace, in [0, 4096)
        confidence: Model confidence in [0, 1]
        target_region_id: Destination region for cross-region ops (-1 if same region)
        payload_tokens: Multi-token payload (passed by apply to RETYPE,
            CONVERT_CONSTRUCT and SPECIALIZE)
        positions: Multiple target positions (passed by apply to SCATTER and GATHER)

    Raises:
        AssertionError: from ``__post_init__`` when a field is out of range.
    """
    op_id: int
    region_id: int
    i_start: int
    i_end: int
    payload_idx: int
    confidence: float = 1.0
    target_region_id: int = -1
    payload_tokens: List[int] = field(default_factory=list)
    positions: List[int] = field(default_factory=list)

    def __post_init__(self):
        """Validate field ranges.

        The checks raise explicitly instead of using ``assert`` so that they
        still run under ``python -O``. The exception type stays
        AssertionError for compatibility with the previous assert-based
        version.
        """
        def require(ok, message: str) -> None:
            if not ok:
                raise AssertionError(message)

        require(self.op_id >= 0, f"Invalid op_id: {self.op_id}")
        require(self.region_id >= 0, f"Invalid region_id: {self.region_id}")
        require(self.i_start >= 0, f"Invalid i_start: {self.i_start}")
        require(self.i_end >= -1, f"Invalid i_end: {self.i_end}")
        if self.i_end != -1:
            require(self.i_end >= self.i_start,
                    f"i_end ({self.i_end}) < i_start ({self.i_start})")
        require(0 <= self.payload_idx < 4096, f"Invalid payload_idx: {self.payload_idx}")
        require(0 <= self.confidence <= 1, f"Invalid confidence: {self.confidence}")
# ---- Exceptions ----
class EditError(Exception):
    """Root of every edit-operation failure; PaletteEditOps.apply catches it."""


class ScopeBalanceError(EditError):
    """The edit would leave START/END scope markers unbalanced."""


class InvalidPointerError(EditError):
    """A region-relative pointer falls outside the addressed range."""


class RegionNotFoundError(EditError):
    """The requested region_id does not exist."""


class PatternNotFoundError(EditError):
    """A macro pattern or reference token is missing from the region."""


class CrossRegionError(EditError):
    """An operation spanning two regions could not be carried out."""
# ---- Main Engine ----
class PaletteEditOps:
"""
Stateless edit operation executor β€” 31 operations.
All methods are pure functions (no internal state).
Thread-safe and deterministic.
Constants:
START_OF_SCOPE: 0
END_OF_SCOPE: 1
NOOP: 2
"""
START_OF_SCOPE = 0
END_OF_SCOPE = 1
NOOP = 2
# Macro pattern definitions for CONVERT_CONSTRUCT
MACRO_PATTERNS = {
'py_for_to_js_for': {
'pattern': [20, 220, 220],
'target': [20, 201, 220],
'name': 'Python for β†’ JavaScript for'
},
}
# ------------------------------------------------------------------ #
# Main dispatch #
# ------------------------------------------------------------------ #
@staticmethod
def apply(
    palette_img: torch.Tensor,
    action: EditAction,
    metadata: RegionMetadata
) -> Tuple[torch.Tensor, bool]:
    """
    Apply a single edit action to a palette.

    Args:
        palette_img: (H, W) palette tensor. It is cloned before dispatch and
            never modified in place.
        action: The edit to perform. Legacy op ids are first translated via
            ``_LEGACY_OPCODES``.
        metadata: Region layout of ``palette_img``. ``region_id`` (and, for
            ops that address a second region, ``target_region_id``) must be
            below ``len(metadata.starts)``.

    Returns:
        ``(new_palette, success)``. When ``success`` is False the returned
        tensor is ``palette_img`` itself. Op ids without a handler (including
        the control ops MOVE_NEXT, FOCUS_PARENT and DONE) give ``success=False``.

    Both the input and the result must pass ``verify_scope_balance``. Any
    :class:`EditError` raised by an operation becomes ``success=False``;
    other exception types are not caught.
    """
    # Resolve legacy OpCode values
    op_id = _LEGACY_OPCODES.get(action.op_id, action.op_id)
    n_regions = len(metadata.starts)
    if action.region_id >= n_regions:
        return palette_img, False
    # These ops look up a second region through metadata (e.g. merge_regions
    # indexes metadata.starts[target]). An out-of-range target raises
    # IndexError, which is not an EditError and would escape the try below.
    # Reject it here. Ops that ignore target_region_id are unaffected, and -1
    # ("same region"/unset) is still allowed.
    cross_region_ops = (OpCode.MOVE_RANGE, OpCode.COPY_RANGE, OpCode.INLINE,
                        OpCode.MERGE_REGIONS, OpCode.SINK)
    if op_id in cross_region_ops and action.target_region_id >= n_regions:
        return palette_img, False
    palette = palette_img.clone()
    try:
        if not PaletteEditOps.verify_scope_balance(palette):
            raise ScopeBalanceError("Input palette has unbalanced scopes")
        # --- Control ---
        if op_id == OpCode.NO_OP:
            new_palette = palette
        # --- Level 1: Primitive ---
        elif op_id == OpCode.DELETE_RANGE:
            new_palette = PaletteEditOps.delete_range(
                palette, action.region_id, action.i_start, action.i_end, metadata)
        elif op_id == OpCode.INSERT_TOKEN:
            new_palette = PaletteEditOps.insert_token(
                palette, action.region_id, action.i_start, action.payload_idx, metadata)
        elif op_id == OpCode.REPLACE_TOKEN:
            new_palette = PaletteEditOps.replace_token(
                palette, action.region_id, action.i_start, action.payload_idx, metadata)
        elif op_id == OpCode.SWAP_TOKENS:
            new_palette = PaletteEditOps.swap_tokens(
                palette, action.region_id, action.i_start, action.i_end, metadata)
        elif op_id == OpCode.MOVE_RANGE:
            new_palette = PaletteEditOps.move_range(
                palette, action.region_id, action.i_start, action.i_end,
                action.target_region_id, action.payload_idx, metadata)
        elif op_id == OpCode.COPY_RANGE:
            new_palette = PaletteEditOps.copy_range(
                palette, action.region_id, action.i_start, action.i_end,
                action.target_region_id, action.payload_idx, metadata)
        elif op_id == OpCode.WRAP_SCOPE:
            new_palette = PaletteEditOps.wrap_scope(
                palette, action.region_id, action.i_start, action.i_end, metadata)
        elif op_id == OpCode.UNWRAP_SCOPE:
            new_palette = PaletteEditOps.unwrap_scope(
                palette, action.region_id, metadata)
        # --- Level 2: Structural ---
        elif op_id == OpCode.INDENT:
            new_palette = PaletteEditOps.indent(
                palette, action.region_id, action.i_start, action.i_end, metadata)
        elif op_id == OpCode.DEDENT:
            new_palette = PaletteEditOps.dedent(
                palette, action.region_id, action.i_start, action.i_end, metadata)
        elif op_id == OpCode.EXTRACT:
            new_palette = PaletteEditOps.extract(
                palette, action.region_id, action.i_start, action.i_end, metadata)
        elif op_id == OpCode.INLINE:
            new_palette = PaletteEditOps.inline(
                palette, action.region_id, action.target_region_id, metadata)
        elif op_id == OpCode.SPLIT_REGION:
            new_palette = PaletteEditOps.split_region(
                palette, action.region_id, action.i_start, metadata)
        elif op_id == OpCode.MERGE_REGIONS:
            new_palette = PaletteEditOps.merge_regions(
                palette, action.region_id, action.target_region_id, metadata)
        elif op_id == OpCode.REORDER:
            new_palette = PaletteEditOps.reorder(
                palette, action.region_id, action.i_start, action.i_end, metadata)
        elif op_id == OpCode.NEST_IN_BLOCK:
            new_palette = PaletteEditOps.nest_in_block(
                palette, action.region_id, action.i_start, action.i_end,
                action.payload_idx, metadata)
        elif op_id == OpCode.UNNEST_FROM_BLOCK:
            new_palette = PaletteEditOps.unnest_from_block(
                palette, action.region_id, metadata)
        elif op_id == OpCode.HOIST:
            new_palette = PaletteEditOps.hoist(
                palette, action.region_id, action.i_start, action.i_end, metadata)
        elif op_id == OpCode.SINK:
            new_palette = PaletteEditOps.sink(
                palette, action.region_id, action.i_start, action.i_end,
                action.target_region_id, metadata)
        # --- Level 3: Semantic ---
        elif op_id == OpCode.RENAME:
            new_palette = PaletteEditOps.rename(
                palette, action.region_id, action.i_start, action.payload_idx, metadata)
        elif op_id == OpCode.RETYPE:
            new_palette = PaletteEditOps.retype(
                palette, action.region_id, action.i_start, action.i_end,
                action.payload_tokens, metadata)
        elif op_id == OpCode.CONVERT_CONSTRUCT:
            new_palette = PaletteEditOps.convert_construct(
                palette, action.region_id, action.payload_tokens, metadata)
        elif op_id == OpCode.SYNC_TO_ASYNC:
            new_palette = PaletteEditOps.sync_to_async(
                palette, action.region_id, metadata)
        elif op_id == OpCode.PARAMETERIZE:
            new_palette = PaletteEditOps.parameterize(
                palette, action.region_id, action.i_start, action.payload_idx, metadata)
        elif op_id == OpCode.SPECIALIZE:
            new_palette = PaletteEditOps.specialize(
                palette, action.region_id, action.i_start, action.i_end,
                action.payload_tokens, metadata)
        elif op_id == OpCode.GUARD:
            new_palette = PaletteEditOps.guard(
                palette, action.region_id, action.i_start, action.i_end,
                action.payload_idx, metadata)
        elif op_id == OpCode.UNGUARD:
            new_palette = PaletteEditOps.unguard(
                palette, action.region_id, metadata)
        elif op_id == OpCode.SCATTER:
            new_palette = PaletteEditOps.scatter(
                palette, action.region_id, action.payload_idx,
                action.positions, metadata)
        elif op_id == OpCode.GATHER:
            new_palette = PaletteEditOps.gather(
                palette, action.positions, action.region_id,
                action.i_start, metadata)
        elif op_id == OpCode.MIRROR:
            new_palette = PaletteEditOps.mirror(
                palette, action.region_id, action.target_region_id,
                action.i_start, action.i_end, action.payload_idx, metadata)
        elif op_id == OpCode.COMPOSE:
            new_palette = PaletteEditOps.compose(
                palette, action.region_id, action.i_start, action.i_end, metadata)
        else:
            return palette_img, False
        # Post-check balance
        if not PaletteEditOps.verify_scope_balance(new_palette):
            raise ScopeBalanceError("Operation broke scope balance")
        return new_palette, True
    except EditError:
        return palette_img, False
# ================================================================== #
# LEVEL 1 — Primitive (token-level) #
# ================================================================== #
@staticmethod
def delete_range(
    palette: torch.Tensor, region_id: int,
    i_start: int, i_end: int, metadata: RegionMetadata
) -> torch.Tensor:
    """Delete the region-relative tokens [i_start, i_end] (inclusive).

    Every cell after a deleted one shifts left across the whole flattened
    palette, and the freed tail is filled with NOOP, so the shape stays (H, W).
    The result lives on the same device as ``palette``.

    Raises:
        InvalidPointerError: the range is empty or outside the region.
        ScopeBalanceError: the range contains a START/END marker.
    """
    H, W = palette.shape
    positions = PaletteEditOps._get_content_positions(palette, metadata, region_id)
    if i_start < 0 or i_start >= len(positions):
        raise InvalidPointerError(f"i_start={i_start} out of bounds (size={len(positions)})")
    if i_end < i_start or i_end >= len(positions):
        raise InvalidPointerError(f"i_end={i_end} out of bounds")
    abs_positions = [positions[i] for i in range(i_start, i_end + 1)]
    palette_flat = palette.flatten()
    markers = (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE)
    if any(palette_flat[pos].item() in markers for pos in abs_positions):
        raise ScopeBalanceError("Cannot delete scope markers")
    # Scratch tensors must share the palette's device: torch.cat below refuses
    # to mix a CUDA `kept` with a CPU `pad`.
    delete_mask = torch.zeros(H * W, dtype=torch.bool, device=palette.device)
    for pos in abs_positions:
        delete_mask[pos] = True
    kept = palette_flat[~delete_mask]
    pad = torch.full((H * W - kept.numel(),), PaletteEditOps.NOOP,
                     dtype=palette.dtype, device=palette.device)
    return torch.cat([kept, pad]).view(H, W)
@staticmethod
def insert_token(
    palette: torch.Tensor, region_id: int,
    i_start: int, payload_idx: int, metadata: RegionMetadata
) -> torch.Tensor:
    """Insert payload_idx before region-relative index i_start.

    ``i_start == len(content)`` appends right after the region's last content
    token. Everything from the insertion cell onward shifts right by one, and
    the final cell of the flattened palette is dropped. The result has the
    same dtype and device as ``palette``.

    Raises:
        InvalidPointerError: i_start is outside [0, len(content)], or the
            insertion cell would lie past the end of the palette.
        ScopeBalanceError: payload_idx is a START/END marker.
    """
    H, W = palette.shape
    positions = PaletteEditOps._get_content_positions(palette, metadata, region_id)
    if i_start < 0 or i_start > len(positions):
        raise InvalidPointerError(f"i_start={i_start} out of bounds")
    if payload_idx in (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE):
        raise ScopeBalanceError("Cannot insert unpaired scope marker")
    if i_start < len(positions):
        abs_pos = positions[i_start]
    elif positions:
        abs_pos = positions[-1] + 1  # right after the region's last token
    else:
        # NOTE(review): an empty region falls back to flat index 0 (start of
        # the whole palette), not the region's own start. Confirm intended.
        abs_pos = 0
    if abs_pos >= H * W:
        # Appending after a token that sits in the final cell would index past
        # the palette (previously an uncaught IndexError).
        raise InvalidPointerError(f"insert position {abs_pos} beyond palette ({H * W} cells)")
    flat = palette.flatten()
    new_flat = torch.empty_like(flat)  # same dtype/device as the input
    new_flat[:abs_pos] = flat[:abs_pos]
    new_flat[abs_pos] = payload_idx
    # Shift right by one; the last original cell falls off the end.
    new_flat[abs_pos + 1:] = flat[abs_pos:H * W - 1]
    return new_flat.view(H, W)
@staticmethod
def replace_token(
    palette: torch.Tensor, region_id: int,
    i_start: int, payload_idx: int, metadata: RegionMetadata
) -> torch.Tensor:
    """Overwrite the token at region-relative index i_start with payload_idx.

    A scope marker can only be "replaced" by the same marker (a no-op), and
    an ordinary token can never become a marker. Every other replacement
    changes the START/END counts. Returns a new tensor.

    Raises:
        InvalidPointerError: i_start is outside the region.
        ScopeBalanceError: the replacement would add, remove, or flip a marker.
    """
    H, W = palette.shape
    positions = PaletteEditOps._get_content_positions(palette, metadata, region_id)
    if i_start < 0 or i_start >= len(positions):
        raise InvalidPointerError(f"i_start={i_start} out of bounds")
    abs_pos = positions[i_start]
    h, w = abs_pos // W, abs_pos % W
    old_value = palette[h, w].item()
    markers = (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE)
    is_old_scope = old_value in markers
    is_new_scope = payload_idx in markers
    if is_old_scope and not is_new_scope:
        raise ScopeBalanceError("Cannot replace scope marker with non-marker")
    if is_old_scope and is_new_scope and old_value != payload_idx:
        raise ScopeBalanceError("Cannot replace START with END or vice versa")
    if is_new_scope and not is_old_scope:
        # Turning an ordinary token into a lone START/END always makes the
        # START and END counts differ.
        raise ScopeBalanceError("Cannot replace non-marker with scope marker")
    new_palette = palette.clone()
    new_palette[h, w] = payload_idx
    return new_palette
@staticmethod
def swap_tokens(
    palette: torch.Tensor, region_id: int,
    i_start: int, i_end: int, metadata: RegionMetadata
) -> torch.Tensor:
    """Exchange the tokens found at region-relative indices i_start and i_end.

    Both indices must address the region's content. The only refused pair is
    a START with an END; any other pair is swapped in a fresh clone.
    """
    _, width = palette.shape
    content = PaletteEditOps._get_content_positions(palette, metadata, region_id)
    n_content = len(content)
    if not 0 <= i_start < n_content:
        raise InvalidPointerError(f"i_start={i_start} out of bounds")
    if not 0 <= i_end < n_content:
        raise InvalidPointerError(f"i_end={i_end} out of bounds")
    cell_a = content[i_start]
    cell_b = content[i_end]
    row_a, col_a = cell_a // width, cell_a % width
    row_b, col_b = cell_b // width, cell_b % width
    first = palette[row_a, col_a].item()
    second = palette[row_b, col_b].item()
    marker_pair = {PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE}
    if {first, second} == marker_pair:
        raise ScopeBalanceError("Cannot swap START ↔ END")
    swapped = palette.clone()
    swapped[row_a, col_a] = second
    swapped[row_b, col_b] = first
    return swapped
@staticmethod
def move_range(
    palette: torch.Tensor, src_region: int,
    i_start: int, i_end: int,
    dst_region: int, dst_pos: int,
    metadata: RegionMetadata
) -> torch.Tensor:
    """Cut tokens [i_start, i_end] of src_region and paste them at dst_pos of dst_region.

    A cross-region cut-paste. ``dst_region < 0`` means "same as src_region".
    ``dst_pos`` is a region-relative index into the ORIGINAL palette;
    ``len(content)`` pastes after the region's last token.

    Both pointers are resolved against the input palette and ``metadata`` and
    applied in a single pass. The previous copy-then-delete implementation
    located the source tokens again through ``metadata`` after the copy had
    already shifted them. Whenever the paste landed before the source (e.g.
    an earlier position in the same region), it deleted the freshly pasted
    copy instead of the original. The n cut cells are exactly replaced by the
    n pasted cells, so the palette length is preserved.

    Raises:
        InvalidPointerError: source range or dst_pos out of bounds.
        ScopeBalanceError: the source range contains a START/END marker.
    """
    if dst_region < 0:
        dst_region = src_region
    H, W = palette.shape
    values = palette.flatten().tolist()
    markers = (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE)

    src_positions = PaletteEditOps._get_content_positions(palette, metadata, src_region)
    if i_start < 0 or i_end < i_start or i_end >= len(src_positions):
        raise InvalidPointerError(f"Source range [{i_start},{i_end}] out of bounds")
    moved_cells = [int(src_positions[i]) for i in range(i_start, i_end + 1)]
    moved_tokens = [values[pos] for pos in moved_cells]
    if any(tok in markers for tok in moved_tokens):
        raise ScopeBalanceError("Cannot move scope markers without pairing")

    dst_positions = PaletteEditOps._get_content_positions(palette, metadata, dst_region)
    if dst_pos < 0 or dst_pos > len(dst_positions):
        raise InvalidPointerError(f"dst_pos={dst_pos} out of bounds")
    if dst_pos < len(dst_positions):
        abs_dst = int(dst_positions[dst_pos])
    elif dst_positions:
        abs_dst = int(dst_positions[-1]) + 1
    else:
        # NOTE(review): empty destination region falls back to flat index 0
        # (start of the whole palette), matching copy_range. Confirm intended.
        abs_dst = 0

    cut = set(moved_cells)
    new_seq = []
    for idx, val in enumerate(values):
        if idx == abs_dst:
            new_seq.extend(moved_tokens)  # paste before the original cell abs_dst
        if idx not in cut:
            new_seq.append(val)
    if abs_dst >= len(values):
        new_seq.extend(moved_tokens)  # paste after the final cell
    return torch.tensor(new_seq, dtype=palette.dtype, device=palette.device).view(H, W)
@staticmethod
def copy_range(
    palette: torch.Tensor, src_region: int,
    i_start: int, i_end: int,
    dst_region: int, dst_pos: int,
    metadata: RegionMetadata
) -> torch.Tensor:
    """Copy tokens [i_start, i_end] of src_region and insert them at dst_pos of dst_region.

    ``dst_region < 0`` means "same as src_region". ``dst_pos == len(content)``
    inserts after the region's last token. The copied tokens shift later
    cells right, and the flattened palette is then truncated back to H*W, so
    the last ``i_end - i_start + 1`` cells fall off. The result has the same
    dtype and device as ``palette``.

    Raises:
        InvalidPointerError: empty or out-of-bounds source range, or dst_pos
            out of bounds.
        ScopeBalanceError: the source range contains a START/END marker.
    """
    if dst_region < 0:
        dst_region = src_region
    src_positions = PaletteEditOps._get_content_positions(palette, metadata, src_region)
    # i_end < i_start (e.g. the "unused" -1) used to silently copy nothing.
    if i_start < 0 or i_end < i_start or i_end >= len(src_positions):
        raise InvalidPointerError(f"Source range [{i_start},{i_end}] out of bounds")
    values = palette.flatten().tolist()
    copied_tokens = [values[int(src_positions[i])] for i in range(i_start, i_end + 1)]
    markers = (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE)
    if any(tok in markers for tok in copied_tokens):
        raise ScopeBalanceError("Cannot copy scope markers without pairing")
    dst_positions = PaletteEditOps._get_content_positions(palette, metadata, dst_region)
    if dst_pos < 0 or dst_pos > len(dst_positions):
        raise InvalidPointerError(f"dst_pos={dst_pos} out of bounds")
    if dst_pos < len(dst_positions):
        abs_dst = int(dst_positions[dst_pos])
    elif dst_positions:
        abs_dst = int(dst_positions[-1]) + 1
    else:
        # NOTE(review): empty destination region falls back to flat index 0
        # (start of the whole palette). Confirm intended.
        abs_dst = 0
    H, W = palette.shape
    new_seq = (values[:abs_dst] + copied_tokens + values[abs_dst:])[:H * W]
    return torch.tensor(new_seq, dtype=palette.dtype, device=palette.device).view(H, W)
@staticmethod
def wrap_scope(
    palette: torch.Tensor, region_id: int,
    i_start: int, i_end: int,
    metadata: RegionMetadata
) -> torch.Tensor:
    """Wrap tokens [i_start, i_end] in a new START ... END pair.

    START goes directly before the cell of i_start and END directly after the
    cell of i_end; everything between (in flat order) ends up inside. The two
    added cells push the last two cells of the flattened palette off the end.
    If those are scope markers, apply's balance post-check rejects the edit.
    The result has the same dtype and device as ``palette``.

    Raises:
        InvalidPointerError: the range is empty or outside the region.
    """
    H, W = palette.shape
    positions = PaletteEditOps._get_content_positions(palette, metadata, region_id)
    if i_start < 0 or i_end >= len(positions) or i_end < i_start:
        raise InvalidPointerError(f"Range [{i_start},{i_end}] invalid")
    abs_start = int(positions[i_start])
    abs_end = int(positions[i_end])
    values = palette.flatten().tolist()
    new_seq = (values[:abs_start]
               + [PaletteEditOps.START_OF_SCOPE]
               + values[abs_start:abs_end + 1]
               + [PaletteEditOps.END_OF_SCOPE]
               + values[abs_end + 1:])
    # Always exactly two cells too long: drop the overflow.
    new_seq = new_seq[:H * W]
    return torch.tensor(new_seq, dtype=palette.dtype, device=palette.device).view(H, W)
@staticmethod
def unwrap_scope(
    palette: torch.Tensor, region_id: int,
    metadata: RegionMetadata
) -> torch.Tensor:
    """Remove the scope markers at metadata.starts/ends[region_id], keeping the content.

    Both cells are deleted, later cells shift left, and two NOOPs pad the
    tail. The result has the same dtype and device as ``palette``.

    Raises:
        ScopeBalanceError: the cell at the region start is not START, or the
            cell at the region end is not END.
    """
    H, W = palette.shape
    flat = palette.flatten()
    start_pos = metadata.starts[region_id]
    end_pos = metadata.ends[region_id]
    # Verify markers exist
    if flat[start_pos].item() != PaletteEditOps.START_OF_SCOPE:
        raise ScopeBalanceError("Region start is not START_OF_SCOPE")
    if flat[end_pos].item() != PaletteEditOps.END_OF_SCOPE:
        raise ScopeBalanceError("Region end is not END_OF_SCOPE")
    # Mask/pad on the palette's device: torch.cat cannot mix CUDA and CPU.
    keep = torch.ones(H * W, dtype=torch.bool, device=palette.device)
    keep[start_pos] = False
    keep[end_pos] = False
    kept = flat[keep]
    pad = torch.full((H * W - kept.numel(),), PaletteEditOps.NOOP,
                     dtype=palette.dtype, device=palette.device)
    return torch.cat([kept, pad]).view(H, W)
# ================================================================== #
# LEVEL 2 — Structural (scope-tree mutations) #
# ================================================================== #
@staticmethod
def indent(
    palette: torch.Tensor, region_id: int,
    i_start: int, i_end: int,
    metadata: RegionMetadata
) -> torch.Tensor:
    """Nest content tokens [i_start, i_end] one level deeper.

    A pure alias: the work, and every validation error, comes from
    ``wrap_scope``, which brackets the range with a new START/END pair.
    """
    return PaletteEditOps.wrap_scope(
        palette=palette,
        region_id=region_id,
        i_start=i_start,
        i_end=i_end,
        metadata=metadata,
    )
@staticmethod
def dedent(
    palette: torch.Tensor, region_id: int,
    i_start: int, i_end: int,
    metadata: RegionMetadata
) -> torch.Tensor:
    """Decrease scope depth by removing the innermost scope around [i_start, i_end].

    The enclosing START is found by walking left from the first token of the
    range, skipping sibling scopes that are already closed. Its matching END
    must lie at or after the last token of the range. Both markers are
    removed, later cells shift left, and the tail is NOOP-padded. The result
    has the same dtype and device as ``palette``.

    Raises:
        InvalidPointerError: invalid range. This includes i_end < i_start,
            which previously wrapped to the last token via negative indexing.
        ScopeBalanceError: no enclosing scope exists, or it closes before the
            range ends.
    """
    H, W = palette.shape
    flat = palette.flatten()
    positions = PaletteEditOps._get_content_positions(palette, metadata, region_id)
    if i_start < 0 or i_end < i_start or i_end >= len(positions):
        raise InvalidPointerError(f"Range [{i_start},{i_end}] invalid")
    abs_start = int(positions[i_start])
    abs_end = int(positions[i_end])
    start_tok = PaletteEditOps.START_OF_SCOPE
    end_tok = PaletteEditOps.END_OF_SCOPE
    values = flat.tolist()

    # Walk left to the nearest START that is still open at abs_start. Each END
    # met on the way closes an earlier START, which must be skipped. Without
    # this, a preceding sibling scope (START..END) was taken for the parent.
    enclosing_start = -1
    pending_ends = 0
    for i in range(abs_start - 1, -1, -1):
        v = values[i]
        if v == end_tok:
            pending_ends += 1
        elif v == start_tok:
            if pending_ends == 0:
                enclosing_start = i
                break
            pending_ends -= 1
    if enclosing_start < 0:
        raise ScopeBalanceError("No enclosing scope to dedent from")

    # Find the END that matches enclosing_start.
    depth = 0
    enclosing_end = -1
    for i in range(enclosing_start, len(values)):
        v = values[i]
        if v == start_tok:
            depth += 1
        elif v == end_tok:
            depth -= 1
            if depth == 0:
                enclosing_end = i
                break
    if enclosing_end < 0 or enclosing_end < abs_end:
        raise ScopeBalanceError("Cannot find matching END for enclosing scope")

    # Remove the enclosing pair (mask/pad on the palette's device).
    keep = torch.ones(H * W, dtype=torch.bool, device=palette.device)
    keep[enclosing_start] = False
    keep[enclosing_end] = False
    kept = flat[keep]
    pad = torch.full((H * W - kept.numel(),), PaletteEditOps.NOOP,
                     dtype=palette.dtype, device=palette.device)
    return torch.cat([kept, pad]).view(H, W)
@staticmethod
def extract(
    palette: torch.Tensor, region_id: int,
    i_start: int, i_end: int,
    metadata: RegionMetadata
) -> torch.Tensor:
    """Move tokens [i_start, i_end] into a new scope placed after the current region.

    The extracted cells are removed, and a single EXTRACTED_REF token (the
    hard-coded sentinel value 3; no payload argument is involved) is left in
    the cell of i_start. ``START + extracted + END`` is then inserted right
    after the region's end cell ``metadata.ends[region_id]``. Only the
    addressed content cells are removed, which keeps the end-cell adjustment
    below exact even when content positions are not contiguous. The previous
    version cut the whole flat slice between the first and last cell. The
    result is truncated to H*W cells and has the same dtype and device as
    ``palette``.

    Raises:
        InvalidPointerError: the range is empty or outside the region.
        ScopeBalanceError: the range contains a START/END marker.
    """
    EXTRACTED_REF = 3  # Sentinel: "content was extracted here"
    H, W = palette.shape
    positions = PaletteEditOps._get_content_positions(palette, metadata, region_id)
    if i_start < 0 or i_end >= len(positions) or i_end < i_start:
        raise InvalidPointerError(f"Range [{i_start},{i_end}] invalid")
    values = palette.flatten().tolist()
    moved_cells = [int(positions[i]) for i in range(i_start, i_end + 1)]
    extracted = [values[pos] for pos in moved_cells]
    markers = (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE)
    if any(tok in markers for tok in extracted):
        raise ScopeBalanceError("Cannot extract scope markers")

    # Drop exactly the extracted cells, leaving one ref at the first of them.
    ref_cell = moved_cells[0]
    removed = set(moved_cells)
    seq = []
    for idx, val in enumerate(values):
        if idx == ref_cell:
            seq.append(EXTRACTED_REF)
        elif idx not in removed:
            seq.append(val)

    # The region's end cell moved left by (n removed - 1 ref). This assumes the
    # extracted cells all precede metadata.ends[region_id], as the original
    # formula did.
    region_end = int(metadata.ends[region_id])
    adj_end = region_end - len(moved_cells) + 1
    insert_at = min(adj_end + 1, len(seq))
    new_scope = [PaletteEditOps.START_OF_SCOPE] + extracted + [PaletteEditOps.END_OF_SCOPE]
    seq = seq[:insert_at] + new_scope + seq[insert_at:]
    # Net growth is always +3 cells (ref + START + END): truncate the overflow.
    seq = seq[:H * W]
    return torch.tensor(seq, dtype=palette.dtype, device=palette.device).view(H, W)
@staticmethod
def inline(
    palette: torch.Tensor, src_region: int,
    target_region: int, metadata: RegionMetadata
) -> torch.Tensor:
    """Inline src_region into target_region (the inverse of ``extract``).

    The first EXTRACTED_REF token (value 3) among target_region's content is
    replaced by src_region's content tokens. src_region is then removed
    entirely: flat cells ``metadata.starts[src_region]`` through
    ``metadata.ends[src_region]``, inclusive. Only that one reference is
    expanded. The previous version replaced every EXTRACTED_REF in the whole
    palette with this region's content. The result is truncated or
    NOOP-padded to H*W cells and has the same dtype and device as
    ``palette``.

    Raises:
        CrossRegionError: target_region < 0, or the reference lies inside
            src_region itself (its expansion would be deleted with it).
        PatternNotFoundError: target_region contains no EXTRACTED_REF.
    """
    if target_region < 0:
        raise CrossRegionError("target_region_id required for INLINE")
    EXTRACTED_REF = 3
    H, W = palette.shape
    values = palette.flatten().tolist()
    # Src region content (content positions only)
    src_content = PaletteEditOps._get_content_positions(palette, metadata, src_region)
    content_tokens = [values[int(pos)] for pos in src_content]
    # First reference token inside the target region
    target_positions = PaletteEditOps._get_content_positions(palette, metadata, target_region)
    ref_pos = next(
        (int(pos) for pos in target_positions if values[int(pos)] == EXTRACTED_REF), -1)
    if ref_pos < 0:
        raise PatternNotFoundError("No EXTRACTED_REF found in target region")
    src_start = int(metadata.starts[src_region])
    src_end = int(metadata.ends[src_region])
    if src_start <= ref_pos <= src_end:
        raise CrossRegionError("Reference lies inside the region being inlined")
    new_seq = []
    for idx, val in enumerate(values):
        if src_start <= idx <= src_end:
            continue  # drop the source region (markers and content)
        if idx == ref_pos:
            new_seq.extend(content_tokens)
        else:
            new_seq.append(val)
    if len(new_seq) > H * W:
        new_seq = new_seq[:H * W]
    else:
        new_seq.extend([PaletteEditOps.NOOP] * (H * W - len(new_seq)))
    return torch.tensor(new_seq, dtype=palette.dtype, device=palette.device).view(H, W)
@staticmethod
def split_region(
    palette: torch.Tensor, region_id: int,
    split_at: int, metadata: RegionMetadata
) -> torch.Tensor:
    """Split a region in two just before region-relative index split_at.

    ``END, START`` is inserted directly before the cell of split_at, so
    split_at must leave at least one token on each side. The two added cells
    push the last two cells of the flattened palette off the end. The result
    has the same dtype and device as ``palette``.

    Raises:
        InvalidPointerError: split_at is not in (0, len(content)).
    """
    H, W = palette.shape
    positions = PaletteEditOps._get_content_positions(palette, metadata, region_id)
    if split_at <= 0 or split_at >= len(positions):
        raise InvalidPointerError(f"split_at={split_at} must be in (0, {len(positions)})")
    abs_split = int(positions[split_at])
    values = palette.flatten().tolist()
    # Insert END then START to create two regions
    new_seq = (values[:abs_split]
               + [PaletteEditOps.END_OF_SCOPE, PaletteEditOps.START_OF_SCOPE]
               + values[abs_split:])
    new_seq = new_seq[:H * W]  # always exactly two cells too long
    return torch.tensor(new_seq, dtype=palette.dtype, device=palette.device).view(H, W)
@staticmethod
def merge_regions(
    palette: torch.Tensor, region_a: int,
    region_b: int, metadata: RegionMetadata
) -> torch.Tensor:
    """Merge region_b into region_a by deleting A's END and B's START.

    The regions must be adjacent. B's START has to come after A's END, with
    nothing but NOOP cells between them. The previous version only checked
    the two markers, so merging non-neighbours (A, C, B) silently swallowed C
    into one scope. The result has the same dtype and device as ``palette``.

    Raises:
        CrossRegionError: region_b < 0, or the regions are not adjacent in
            A-then-B order.
        ScopeBalanceError: A's end cell is not END, or B's start cell is not START.
    """
    if region_b < 0:
        raise CrossRegionError("target_region_id required for MERGE")
    H, W = palette.shape
    flat = palette.flatten()
    end_a = int(metadata.ends[region_a])
    start_b = int(metadata.starts[region_b])
    if flat[end_a].item() != PaletteEditOps.END_OF_SCOPE:
        raise ScopeBalanceError("Region A end is not END_OF_SCOPE")
    if flat[start_b].item() != PaletteEditOps.START_OF_SCOPE:
        raise ScopeBalanceError("Region B start is not START_OF_SCOPE")
    # Verify adjacency (the docstring's precondition, previously unchecked).
    if start_b <= end_a:
        raise CrossRegionError("Region B must come after region A")
    gap = flat[end_a + 1:start_b]
    if bool((gap != PaletteEditOps.NOOP).any()):
        raise CrossRegionError("Regions are not adjacent")
    # Remove both markers (mask/pad on the palette's device).
    keep = torch.ones(H * W, dtype=torch.bool, device=palette.device)
    keep[end_a] = False
    keep[start_b] = False
    kept = flat[keep]
    pad = torch.full((H * W - kept.numel(),), PaletteEditOps.NOOP,
                     dtype=palette.dtype, device=palette.device)
    return torch.cat([kept, pad]).view(H, W)
@staticmethod
def reorder(
    palette: torch.Tensor, region_id: int,
    i_start: int, i_end: int,
    metadata: RegionMetadata
) -> torch.Tensor:
    """Reverse the region-relative tokens i_start..i_end (inclusive) in place on a clone.

    Only the addressed cells change value; the input tensor is not modified.
    Only reversal is implemented (no arbitrary permutations).
    """
    H, W = palette.shape
    content = PaletteEditOps._get_content_positions(palette, metadata, region_id)
    if not 0 <= i_start <= i_end < len(content):
        raise InvalidPointerError(f"Range [{i_start},{i_end}] invalid")
    source = palette.flatten()
    result = palette.clone().flatten()
    targets = [content[k] for k in range(i_start, i_end + 1)]
    # Read from the untouched source so overlapping writes cannot interfere.
    for dst, src in zip(targets, reversed(targets)):
        result[dst] = source[src].item()
    return result.view(H, W)
@staticmethod
def nest_in_block(
    palette: torch.Tensor, region_id: int,
    i_start: int, i_end: int,
    block_type_hue: int, metadata: RegionMetadata
) -> torch.Tensor:
    """Wrap [i_start, i_end] in a new control block.

    Inserts ``START, block_type_hue`` before the cell of i_start and ``END``
    after the cell of i_end. block_type_hue is the construct tag; per the
    original notes 20=for and 24=if, etc. The three added cells push the
    last three cells of the flattened palette off the end. The result has
    the same dtype and device as ``palette``.

    Raises:
        InvalidPointerError: the range is empty or outside the region.
        ScopeBalanceError: block_type_hue is itself a START/END marker.
    """
    H, W = palette.shape
    positions = PaletteEditOps._get_content_positions(palette, metadata, region_id)
    if i_start < 0 or i_end >= len(positions) or i_end < i_start:
        raise InvalidPointerError(f"Range [{i_start},{i_end}] invalid")
    if block_type_hue in (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE):
        raise ScopeBalanceError("block_type_hue cannot be a scope marker")
    abs_start = int(positions[i_start])
    abs_end = int(positions[i_end])
    values = palette.flatten().tolist()
    new_seq = (values[:abs_start]
               + [PaletteEditOps.START_OF_SCOPE, block_type_hue]
               + values[abs_start:abs_end + 1]
               + [PaletteEditOps.END_OF_SCOPE]
               + values[abs_end + 1:])
    new_seq = new_seq[:H * W]  # always exactly three cells too long
    return torch.tensor(new_seq, dtype=palette.dtype, device=palette.device).view(H, W)
@staticmethod
def unnest_from_block(
    palette: torch.Tensor, region_id: int,
    metadata: RegionMetadata
) -> torch.Tensor:
    """Dissolve a control block: remove its START, block-type token, and END.

    The block-type token is assumed to be the cell right after START, as
    written by ``nest_in_block``. The remaining content shifts left into the
    parent scope, and the tail is NOOP-padded. The result has the same dtype
    and device as ``palette``.

    Raises:
        ScopeBalanceError: the region's start/end cells are not START/END, or
            the cell after START is a scope marker of a nested scope rather
            than a block-type token. Deleting that marker would unbalance
            the palette.
    """
    H, W = palette.shape
    flat = palette.flatten()
    start_pos = int(metadata.starts[region_id])
    end_pos = int(metadata.ends[region_id])
    markers = (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE)
    if flat[start_pos].item() != PaletteEditOps.START_OF_SCOPE:
        raise ScopeBalanceError("Region start is not START_OF_SCOPE")
    if flat[end_pos].item() != PaletteEditOps.END_OF_SCOPE:
        raise ScopeBalanceError("Region end is not END_OF_SCOPE")
    type_pos = start_pos + 1
    if type_pos < end_pos and flat[type_pos].item() in markers:
        raise ScopeBalanceError("Region has no block-type token after START")
    keep = torch.ones(H * W, dtype=torch.bool, device=palette.device)
    keep[start_pos] = False
    if type_pos < H * W:
        keep[type_pos] = False  # block type hue (END itself for an empty scope)
    keep[end_pos] = False
    kept = flat[keep]
    pad = torch.full((H * W - kept.numel(),), PaletteEditOps.NOOP,
                     dtype=palette.dtype, device=palette.device)
    return torch.cat([kept, pad]).view(H, W)
@staticmethod
def hoist(
    palette: torch.Tensor, region_id: int,
    i_start: int, i_end: int,
    metadata: RegionMetadata
) -> torch.Tensor:
    """Move tokens [i_start, i_end] out of the region to just before its start cell.

    The tokens are removed from their cells and re-inserted, in order,
    directly before ``metadata.starts[region_id]``, i.e. one scope up when
    that cell is the region's START. The number of cells is unchanged. The
    result has the same dtype and device as ``palette``.

    Raises:
        InvalidPointerError: the range is empty or outside the region.
        ScopeBalanceError: the range contains a START/END marker.
    """
    H, W = palette.shape
    values = palette.flatten().tolist()
    positions = PaletteEditOps._get_content_positions(palette, metadata, region_id)
    if i_start < 0 or i_end >= len(positions) or i_end < i_start:
        raise InvalidPointerError(f"Range [{i_start},{i_end}] invalid")
    abs_positions = [int(positions[i]) for i in range(i_start, i_end + 1)]
    hoisted = [values[pos] for pos in abs_positions]
    markers = (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE)
    if any(tok in markers for tok in hoisted):
        raise ScopeBalanceError("Cannot hoist scope markers")
    removed = set(abs_positions)
    cleaned = [val for idx, val in enumerate(values) if idx not in removed]
    # Insert before region's START marker, adjusting for any removed cells before it.
    region_start = int(metadata.starts[region_id])
    adj = sum(1 for pos in abs_positions if pos < region_start)
    insert_at = region_start - adj
    # n cells removed and n re-inserted: length is exactly H*W again.
    new_seq = cleaned[:insert_at] + hoisted + cleaned[insert_at:]
    return torch.tensor(new_seq, dtype=palette.dtype, device=palette.device).view(H, W)
@staticmethod
def sink(
    palette: torch.Tensor, region_id: int,
    i_start: int, i_end: int,
    target_region: int, metadata: RegionMetadata
) -> torch.Tensor:
    """Push tokens [i_start, i_end] down into target_region (the inverse of hoist).

    Implemented as a ``move_range`` whose destination is the front (index 0)
    of target_region. No check is made that target_region is actually a
    child of region_id.

    Raises:
        CrossRegionError: target_region is negative (unset).
    """
    if target_region < 0:
        raise CrossRegionError("target_region_id required for SINK")
    return PaletteEditOps.move_range(
        palette,
        src_region=region_id,
        i_start=i_start,
        i_end=i_end,
        dst_region=target_region,
        dst_pos=0,
        metadata=metadata,
    )
# ================================================================== #
# LEVEL 3 — Semantic (meaning-level transforms) #
# ================================================================== #
@staticmethod
def rename(
    palette: torch.Tensor, region_id: int,
    i_start: int, new_hue: int, metadata: RegionMetadata
) -> torch.Tensor:
    """Rename an identifier by overwriting content token ``i_start`` with ``new_hue``.

    A thin alias for ``replace_token``. The "identifier" restriction is
    semantic only; no identifier-specific check is performed here.
    """
    replace = PaletteEditOps.replace_token
    return replace(palette, region_id, i_start, new_hue, metadata)
@staticmethod
def retype(
    palette: torch.Tensor, region_id: int,
    i_start: int, i_end: int,
    new_type_tokens: List[int], metadata: RegionMetadata
) -> torch.Tensor:
    """Replace a type annotation span with ``new_type_tokens`` (length may change, e.g. int -> List[int]).

    The span runs over the flat palette from the position of content token
    ``i_start`` through the position of content token ``i_end``, inclusive.
    Every flat slot in between is replaced, including any tokens not in the
    region's content list. The tail of the palette shifts accordingly and is
    NOOP-padded, or truncated, back to ``H * W`` entries.

    Raises:
        InvalidPointerError: if the range is reversed or out of bounds.
    """
    rows, cols = palette.shape
    content = PaletteEditOps._get_content_positions(palette, metadata, region_id)
    if not (0 <= i_start <= i_end < len(content)):
        raise InvalidPointerError(f"Range [{i_start},{i_end}] invalid")

    lo, hi = content[i_start], content[i_end]
    tokens = palette.flatten().tolist()
    spliced = tokens[:lo] + list(new_type_tokens) + tokens[hi + 1:]

    capacity = rows * cols
    spliced = (spliced + [PaletteEditOps.NOOP] * (capacity - len(spliced)))[:capacity]
    return torch.tensor(spliced, dtype=palette.dtype).view(rows, cols)
@staticmethod
def convert_construct(
    palette: torch.Tensor, region_id: int,
    pattern_target: List[int], metadata: RegionMetadata
) -> torch.Tensor:
    """Replace the first occurrence of a token pattern in a region with a target sequence.

    ``pattern_target`` is ``[*pattern_tokens, -1, *target_tokens]``, where ``-1``
    separates the pattern from its replacement. If ``pattern_target`` is empty,
    the first entry of ``PaletteEditOps.MACRO_PATTERNS`` (its ``'pattern'`` and
    ``'target'`` values) is used instead.

    Matching runs over the region's content tokens (scope markers excluded), and
    only the first match is rewritten:

    * a same-length target overwrites the matched slots in place;
    * otherwise the flat span from the first to the last matched position is
      spliced out and replaced by ``target``, shifting the palette tail and
      NOOP-padding, or truncating, back to ``H * W`` entries.

    Raises:
        PatternNotFoundError: if the separator is missing, the pattern is empty,
            no built-in macro is available, or the pattern does not occur.
    """
    if not pattern_target:
        macro = next(iter(PaletteEditOps.MACRO_PATTERNS.values()), None)
        if macro is None:
            raise PatternNotFoundError("No built-in MACRO_PATTERNS available")
        pattern = macro['pattern']
        target = macro['target']
    else:
        if -1 not in pattern_target:
            raise PatternNotFoundError("pattern_target must contain -1 separator")
        sep = pattern_target.index(-1)
        pattern = pattern_target[:sep]
        target = pattern_target[sep + 1:]

    # Normalize to lists: a tuple pattern never compares equal to a list slice,
    # and a tuple target cannot be concatenated with the flat token list.
    pattern = list(pattern)
    target = list(target)
    if not pattern:
        # An empty pattern would "match" at index 0. The splice would then take
        # positions[-1] as the match end and overwrite the whole region, or
        # raise IndexError on an empty region.
        raise PatternNotFoundError("Pattern must contain at least one token")

    H, W = palette.shape
    positions = PaletteEditOps._get_content_positions(palette, metadata, region_id)
    flat = palette.flatten()
    region_tokens = [flat[pos].item() for pos in positions]
    plen = len(pattern)

    for i in range(len(region_tokens) - plen + 1):
        if region_tokens[i:i + plen] != pattern:
            continue
        if len(target) == plen:
            # Same length: overwrite the matched content slots in place.
            new_palette = palette.clone()
            for pos, t in zip(positions[i:i + plen], target):
                h, w = divmod(pos, W)
                new_palette[h, w] = t
            return new_palette
        # Different length: splice the flat span and re-fit to H * W.
        abs_start = positions[i]
        abs_end = positions[i + plen - 1]
        flat_list = flat.tolist()
        new_seq = flat_list[:abs_start] + target + flat_list[abs_end + 1:]
        new_seq = (new_seq + [PaletteEditOps.NOOP] * (H * W - len(new_seq)))[:H * W]
        return torch.tensor(new_seq, dtype=palette.dtype).view(H, W)

    raise PatternNotFoundError(f"Pattern {pattern} not found in region")
@staticmethod
def sync_to_async(
    palette: torch.Tensor, region_id: int,
    metadata: RegionMetadata
) -> torch.Tensor:
    """Insert async/await marker hues into a region.

    Each content token equal to hue 12 (treated as a function-def token) gets
    hue 46 (async) inserted immediately before it. Each content token equal to
    hue 60 (call) gets hue 47 (await) inserted before it. Insertions shift the
    rest of the palette right; anything pushed past ``H * W`` entries is
    dropped, and a short result is NOOP-padded.

    NOTE(review): every hue-12 token in the region is marked, not only the
    first, and re-running the op inserts the markers again -- confirm both are
    intended.

    Raises:
        PatternNotFoundError: if the region contains neither hue 12 nor hue 60.
    """
    ASYNC_HUE, AWAIT_HUE = 46, 47
    FUNC_DEF_HUE, CALL_HUE = 12, 60
    prefix_hue_for = {FUNC_DEF_HUE: ASYNC_HUE, CALL_HUE: AWAIT_HUE}

    rows, cols = palette.shape
    content = PaletteEditOps._get_content_positions(palette, metadata, region_id)
    tokens = palette.flatten().tolist()

    # flat index -> hue to emit just before the token at that index
    prefixes = {
        pos: prefix_hue_for[tokens[pos]]
        for pos in content
        if tokens[pos] in prefix_hue_for
    }
    if not prefixes:
        raise PatternNotFoundError("No function defs or calls found to make async")

    rebuilt = []
    for idx, tok in enumerate(tokens):
        if idx in prefixes:
            rebuilt.append(prefixes[idx])
        rebuilt.append(tok)

    capacity = rows * cols
    rebuilt = (rebuilt + [PaletteEditOps.NOOP] * (capacity - len(rebuilt)))[:capacity]
    return torch.tensor(rebuilt, dtype=palette.dtype).view(rows, cols)
@staticmethod
def parameterize(
    palette: torch.Tensor, region_id: int,
    i_start: int, param_hue: int, metadata: RegionMetadata
) -> torch.Tensor:
    """Swap the hard-coded literal at content index ``i_start`` for ``param_hue``.

    Delegates to ``replace_token``. No check is made here that the replaced
    token really is a literal.
    """
    token_replacer = PaletteEditOps.replace_token
    return token_replacer(palette, region_id, i_start, param_hue, metadata)
@staticmethod
def specialize(
    palette: torch.Tensor, region_id: int,
    i_start: int, i_end: int,
    concrete_tokens: List[int], metadata: RegionMetadata
) -> torch.Tensor:
    """Replace generic type tokens ``[i_start, i_end]`` with a concrete form (List[T] -> List[int]).

    Type-level counterpart of parameterize; delegates unchanged to ``retype``.
    """
    splice_type = PaletteEditOps.retype
    return splice_type(palette, region_id, i_start, i_end, concrete_tokens, metadata)
@staticmethod
def guard(
    palette: torch.Tensor, region_id: int,
    i_start: int, i_end: int,
    guard_hue: int, metadata: RegionMetadata
) -> torch.Tensor:
    """Wrap content tokens ``[i_start, i_end]`` in a conditional guard (if/try/...).

    Delegates to ``nest_in_block``, passing ``guard_hue`` in its hue slot
    (presumably the new block's type hue -- confirm against nest_in_block).
    """
    wrap = PaletteEditOps.nest_in_block
    return wrap(palette, region_id, i_start, i_end, guard_hue, metadata)
@staticmethod
def unguard(
    palette: torch.Tensor, region_id: int,
    metadata: RegionMetadata
) -> torch.Tensor:
    """Remove a conditional guard, lifting the region's content to its parent scope.

    Delegates unchanged to ``unnest_from_block``.
    """
    lift = PaletteEditOps.unnest_from_block
    return lift(palette, region_id, metadata)
@staticmethod
def scatter(
    palette: torch.Tensor, region_id: int,
    new_hue: int, target_positions: List[int],
    metadata: RegionMetadata
) -> torch.Tensor:
    """Write ``new_hue`` at several content indices of one region.

    Applies the same change at N locations (rename-all, update-all-call-sites).
    ``target_positions`` are region-relative content indices (scope markers are
    not counted). The input palette is not modified; a clone is returned.

    Raises:
        ScopeBalanceError: if ``new_hue`` is itself a scope marker (writing it
            would unbalance the palette), or a target slot holds a marker.
        InvalidPointerError: if any index is outside the region's content.
    """
    markers = (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE)
    if new_hue in markers:
        raise ScopeBalanceError("Cannot scatter scope markers")

    positions = PaletteEditOps._get_content_positions(palette, metadata, region_id)
    new_palette = palette.clone()
    _, W = palette.shape
    for pos_idx in target_positions:
        if not 0 <= pos_idx < len(positions):
            raise InvalidPointerError(f"Position {pos_idx} out of bounds")
        h, w = divmod(positions[pos_idx], W)
        # Defensive: content positions exclude markers, and new_hue was checked
        # above, so this should not trigger.
        if new_palette[h, w].item() in markers:
            raise ScopeBalanceError("Cannot scatter over scope markers")
        new_palette[h, w] = new_hue
    return new_palette
@staticmethod
def gather(
    palette: torch.Tensor, source_positions: List[int],
    target_region: int, target_pos: int,
    metadata: RegionMetadata
) -> torch.Tensor:
    """Collect tokens from several positions and re-insert them as one run in a target region.

    Inverse of scatter. ``source_positions`` are ABSOLUTE flat indices into
    ``palette.flatten()``, not region-relative. Duplicate indices are collapsed
    to their first occurrence so each token is moved exactly once. Values are
    gathered in the given order, removed from the palette, and inserted before
    content index ``target_pos`` of ``target_region``. ``target_pos ==
    len(content)`` appends after the last content token. For an empty target
    region, the run is inserted right after the region's START
    (``metadata.starts[target_region]``).

    Raises:
        InvalidPointerError: if a source index or ``target_pos`` is out of range.
        ScopeBalanceError: if a source position holds a scope marker.
    """
    H, W = palette.shape
    flat = palette.flatten()
    markers = (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE)

    # A duplicate would be gathered twice but deleted once, growing the token
    # count and double-counting in the insertion-point adjustment below.
    unique_sources = list(dict.fromkeys(source_positions))

    gathered_vals = []
    for pos in unique_sources:
        if pos < 0 or pos >= H * W:
            raise InvalidPointerError(f"Source position {pos} out of bounds")
        val = flat[pos].item()
        if val in markers:
            raise ScopeBalanceError("Cannot gather scope markers")
        gathered_vals.append(val)

    delete_mask = torch.zeros(H * W, dtype=torch.bool)
    for pos in unique_sources:
        delete_mask[pos] = True
    cleaned = flat[~delete_mask].tolist()

    target_positions_list = PaletteEditOps._get_content_positions(palette, metadata, target_region)
    if target_pos < 0 or target_pos > len(target_positions_list):
        raise InvalidPointerError(f"target_pos={target_pos} out of bounds")

    if target_pos < len(target_positions_list):
        abs_target = target_positions_list[target_pos]
    elif target_positions_list:
        abs_target = target_positions_list[-1] + 1
    else:
        # Empty target region: insert just after its START marker. Flat index 0
        # would lie outside the target region.
        abs_target = int(metadata.starts[target_region]) + 1

    # abs_target refers to the un-edited palette; shift it for removed tokens.
    abs_target -= sum(1 for p in unique_sources if p < abs_target)

    new_seq = cleaned[:abs_target] + gathered_vals + cleaned[abs_target:]
    new_seq = (new_seq + [PaletteEditOps.NOOP] * (H * W - len(new_seq)))[:H * W]
    return torch.tensor(new_seq, dtype=palette.dtype).view(H, W)
@staticmethod
def mirror(
    palette: torch.Tensor, region_a: int,
    region_b: int, i_start: int, i_end: int,
    new_hue: int, metadata: RegionMetadata
) -> torch.Tensor:
    """Apply the same change to two paired regions (getter/setter, request/response, ...).

    Content tokens ``[i_start, i_end]`` of BOTH ``region_a`` and ``region_b``
    are overwritten with ``new_hue``. The input palette is not modified; a
    clone is returned.

    Raises:
        CrossRegionError: if ``region_b`` is negative (no partner supplied).
        InvalidPointerError: if the range is reversed, or out of bounds in
            either region.
        ScopeBalanceError: if ``new_hue`` is a scope marker, or a target slot
            holds one.
    """
    if region_b < 0:
        raise CrossRegionError("target_region_id required for MIRROR")
    # Other ranged ops reject a reversed range; previously this silently no-op'd.
    if i_end < i_start:
        raise InvalidPointerError(f"Range [{i_start},{i_end}] invalid")
    markers = (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE)
    if new_hue in markers:
        raise ScopeBalanceError("Cannot mirror scope markers")

    new_palette = palette.clone()
    _, W = palette.shape
    for rid in (region_a, region_b):
        positions = PaletteEditOps._get_content_positions(new_palette, metadata, rid)
        if i_start < 0 or i_end >= len(positions):
            raise InvalidPointerError(f"Range [{i_start},{i_end}] out of bounds in region {rid}")
        for abs_pos in positions[i_start:i_end + 1]:
            h, w = divmod(abs_pos, W)
            if new_palette[h, w].item() in markers:
                raise ScopeBalanceError("Cannot mirror over scope markers")
            new_palette[h, w] = new_hue
    return new_palette
@staticmethod
def compose(
    palette: torch.Tensor, region_id: int,
    i_start: int, i_end: int,
    metadata: RegionMetadata
) -> torch.Tensor:
    """Fuse the span ``[i_start, i_end]`` into one expression by removing nested scope pairs.

    Unlike most ops, the indices count ALL of the region's mask positions
    (``metadata.masks[region_id]``), scope markers included.

    Within the span, a START opened while another START from the span is still
    open (nesting depth > 1) is deleted together with its matching END. This
    happens only when that END also lies inside the span, so the result stays
    balanced. Outermost markers, and markers whose partner lies outside the
    span, are kept. Remaining tokens shift left and the tail is NOOP-padded.

    Raises:
        InvalidPointerError: if the range is reversed or out of bounds.
    """
    H, W = palette.shape
    flat = palette.flatten()
    mask = metadata.masks[region_id]
    coords = mask.nonzero(as_tuple=False)
    all_flat = sorted((coords[:, 0] * W + coords[:, 1]).tolist())
    if i_start < 0 or i_end >= len(all_flat) or i_end < i_start:
        raise InvalidPointerError(f"Range [{i_start},{i_end}] invalid")

    delete_mask = torch.zeros(H * W, dtype=torch.bool, device=palette.device)
    # STARTs seen in the span still awaiting their END: (flat_pos, removable)
    open_starts = []
    for pos in all_flat[i_start:i_end + 1]:
        val = flat[pos].item()
        if val == PaletteEditOps.START_OF_SCOPE:
            # Removable only when nested inside another START opened in the span.
            open_starts.append((pos, len(open_starts) >= 1))
        elif val == PaletteEditOps.END_OF_SCOPE and open_starts:
            start_pos, removable = open_starts.pop()
            if removable:
                delete_mask[start_pos] = True
                delete_mask[pos] = True
        # An END with no open START closes a scope opened before the span: keep it.
        # STARTs left in open_starts close after the span: keep them too.

    kept = flat[~delete_mask]
    pad = torch.full(
        (H * W - kept.numel(),), PaletteEditOps.NOOP,
        dtype=palette.dtype, device=palette.device,
    )
    return torch.cat([kept, pad]).view(H, W)
# ================================================================== #
# Helpers #
# ================================================================== #
@staticmethod
def verify_scope_balance(palette: torch.Tensor) -> bool:
    """Return True when the palette holds as many START_OF_SCOPE as END_OF_SCOPE tokens.

    Only the two counts are compared. Ordering and nesting are not checked, so
    an END appearing before its START still passes.
    """
    opens = int((palette == PaletteEditOps.START_OF_SCOPE).sum())
    closes = int((palette == PaletteEditOps.END_OF_SCOPE).sum())
    return opens == closes
@staticmethod
def _get_content_positions(
    palette: torch.Tensor, metadata: RegionMetadata, region_id: int
) -> List[int]:
    """Return the sorted flat indices of a region's non-marker tokens.

    Mask coordinates ``(h, w)`` from ``metadata.masks[region_id]`` are
    flattened as ``h * W + w`` using the palette's width. Positions holding
    START_OF_SCOPE or END_OF_SCOPE are dropped.
    """
    _, W = palette.shape
    coords = metadata.masks[region_id].nonzero(as_tuple=False)
    flat_idx = (coords[:, 0] * W + coords[:, 1]).tolist()
    # One bulk conversion instead of a per-position .item() round trip.
    values = palette.flatten().tolist()
    markers = (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE)
    return sorted(p for p in flat_idx if values[p] not in markers)
@staticmethod
def _get_region_positions(mask: torch.Tensor, W: int, palette: torch.Tensor = None) -> List[int]:
    """Legacy helper -- prefer _get_content_positions.

    Flattens the nonzero coordinates of ``mask`` as ``h * W + w`` using the
    caller-supplied ``W``. Without ``palette``, all positions are returned
    sorted. With ``palette``, positions whose token is START_OF_SCOPE or
    END_OF_SCOPE are dropped first.
    """
    coords = mask.nonzero(as_tuple=False)
    flat_idx = (coords[:, 0] * W + coords[:, 1]).tolist()
    if palette is None:
        return sorted(flat_idx)
    markers = (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE)
    return sorted(
        pos for pos in flat_idx
        if palette[divmod(pos, W)].item() not in markers
    )