| """ |
| Parametric Edit Operations Engine β 31 Operations |
| |
| Executes edit operations on palette tensors with pointer-based addressing. |
| Operations preserve scope balance and support region-relative indexing. |
| |
| Three levels: |
| Level 1 β Primitive (8 ops): token-level atomic edits |
| Level 2 β Structural (11 ops): scope-tree mutations |
| Level 3 β Semantic (12 ops): meaning-level transforms |
| |
| Key Features: |
| - Parametric actions with arguments |
| - Pointer arithmetic (region-relative β absolute) |
| - Cross-region operations (MOVE, COPY, EXTRACT, INLINE, MERGE, etc.) |
| - Scope balance verification |
| - Macro pattern transformations |
| - Stateless execution (pure functions) |
| """ |
|
|
| import torch |
| from dataclasses import dataclass, field |
| from typing import Tuple, Optional, List, Dict |
| from enum import IntEnum |
|
|
| |
| from .scope_pooler import RegionMetadata |
|
|
|
|
class OpCode(IntEnum):
    """Operation codes — 31 edit operations across 3 levels.

    The numeric band encodes the tier (see OP_LEVEL):
    < 100 control, 1xx primitive, 2xx structural, 3xx semantic.
    """

    # -- Control flow (not trainable) --
    NO_OP = 0
    MOVE_NEXT = 1
    FOCUS_PARENT = 2
    DONE = 99

    # -- Level 1: primitive token-level edits (1xx) --
    DELETE_RANGE = 100
    INSERT_TOKEN = 101
    REPLACE_TOKEN = 102
    SWAP_TOKENS = 103
    MOVE_RANGE = 104
    COPY_RANGE = 105
    WRAP_SCOPE = 106
    UNWRAP_SCOPE = 107

    # -- Level 2: structural scope-tree mutations (2xx) --
    INDENT = 200
    DEDENT = 201
    EXTRACT = 202
    INLINE = 203
    SPLIT_REGION = 204
    MERGE_REGIONS = 205
    REORDER = 206
    NEST_IN_BLOCK = 207
    UNNEST_FROM_BLOCK = 208
    HOIST = 209
    SINK = 210

    # -- Level 3: semantic meaning-level transforms (3xx) --
    RENAME = 300
    RETYPE = 301
    CONVERT_CONSTRUCT = 302
    SYNC_TO_ASYNC = 303
    PARAMETERIZE = 304
    SPECIALIZE = 305
    GUARD = 306
    UNGUARD = 307
    SCATTER = 308
    GATHER = 309
    MIRROR = 310
    COMPOSE = 311
|
|
|
|
| |
# Backward-compat: retired opcode values mapped onto their modern equivalents.
_LEGACY_OPCODES = {
    150: OpCode.DELETE_RANGE,
    151: OpCode.INSERT_TOKEN,
    152: OpCode.REPLACE_TOKEN,
    153: OpCode.SWAP_TOKENS,
    499: OpCode.DONE,
}


def _level_of(code: OpCode) -> str:
    """Classify an opcode into its tier by numeric band."""
    if code.value < 100:
        return 'control'
    if code.value < 200:
        return 'primitive'
    if code.value < 300:
        return 'structural'
    return 'semantic'


# Tier lookup for every opcode.
OP_LEVEL: Dict[OpCode, str] = {code: _level_of(code) for code in OpCode}

# Control ops are excluded from the trainable action space; the remaining
# ops get a fixed, stable index mapping in enum declaration order.
TRAINABLE_OPS: List[OpCode] = [code for code in OpCode if OP_LEVEL[code] != 'control']
NUM_OPS = len(TRAINABLE_OPS)
OP_TO_IDX: Dict[OpCode, int] = {code: idx for idx, code in enumerate(TRAINABLE_OPS)}
IDX_TO_OP: Dict[int, OpCode] = {idx: code for idx, code in enumerate(TRAINABLE_OPS)}
|
|
|
|
@dataclass
class EditAction:
    """
    Parametric edit operation with arguments.

    Fields:
        op_id: Operation code from OpCode enum
        region_id: Which semantic region to operate on [0, R)
        i_start: Token index within region (relative addressing)
        i_end: End token index (for range operations, -1 if unused)
        payload_idx: Palette index to insert/replace (0-4095)
        confidence: Model confidence in [0, 1]
        target_region_id: Destination region for cross-region ops (-1 if same region)
        payload_tokens: Multi-token payload for WRAP, NEST, etc.
        positions: Multiple target positions for SCATTER
    """
    op_id: int
    region_id: int
    i_start: int
    i_end: int
    payload_idx: int
    confidence: float = 1.0
    target_region_id: int = -1
    payload_tokens: List[int] = field(default_factory=list)
    positions: List[int] = field(default_factory=list)

    def __post_init__(self) -> None:
        # Validate eagerly so a malformed action fails at construction time.
        assert self.op_id >= 0, f"Invalid op_id: {self.op_id}"
        assert self.region_id >= 0, f"Invalid region_id: {self.region_id}"
        assert self.i_start >= 0, f"Invalid i_start: {self.i_start}"
        assert self.i_end >= -1, f"Invalid i_end: {self.i_end}"
        # i_end == -1 marks "no range"; otherwise the range must not be inverted.
        assert self.i_end == -1 or self.i_end >= self.i_start, \
            f"i_end ({self.i_end}) < i_start ({self.i_start})"
        assert 0 <= self.payload_idx < 4096, f"Invalid payload_idx: {self.payload_idx}"
        assert 0 <= self.confidence <= 1, f"Invalid confidence: {self.confidence}"
|
|
|
|
| |
|
|
class EditError(Exception):
    """Base class for all edit-operation failures raised by PaletteEditOps."""
    pass


class ScopeBalanceError(EditError):
    """Operation would break scope balance (unpaired START/END markers)."""
    pass


class InvalidPointerError(EditError):
    """Pointer (region-relative token index) out of bounds."""
    pass


class RegionNotFoundError(EditError):
    """region_id invalid (does not refer to a known region)."""
    pass


class PatternNotFoundError(EditError):
    """Macro pattern not found in region."""
    pass


class CrossRegionError(EditError):
    """Cross-region operation failed (missing or invalid target region)."""
    pass
|
|
|
|
| |
|
|
class PaletteEditOps:
    """
    Stateless edit operation executor — 31 operations.

    All methods are pure functions (no internal state).
    Thread-safe and deterministic.

    Constants:
        START_OF_SCOPE: 0
        END_OF_SCOPE: 1
        NOOP: 2
    """

    # Reserved palette values: scope delimiters and the padding/no-op token.
    START_OF_SCOPE = 0
    END_OF_SCOPE = 1
    NOOP = 2

    # Macro rewrite templates: a hue sequence 'pattern' found in a region is
    # replaced by 'target'.
    # NOTE(review): the 'β' in the name string looks like mojibake for '→';
    # left untouched here because it is a runtime value.
    MACRO_PATTERNS = {
        'py_for_to_js_for': {
            'pattern': [20, 220, 220],
            'target': [20, 201, 220],
            'name': 'Python for β JavaScript for'
        },
    }
|
|
| |
| |
| |
|
|
| @staticmethod |
| def apply( |
| palette_img: torch.Tensor, |
| action: EditAction, |
| metadata: RegionMetadata |
| ) -> Tuple[torch.Tensor, bool]: |
| """ |
| Apply edit action to palette. |
| |
| Returns (new_palette, success). |
| Guarantees: original unchanged; if success=False, new == original. |
| """ |
| |
| op_id = _LEGACY_OPCODES.get(action.op_id, action.op_id) |
|
|
| if action.region_id >= len(metadata.starts): |
| return palette_img, False |
|
|
| palette = palette_img.clone() |
|
|
| try: |
| if not PaletteEditOps.verify_scope_balance(palette): |
| raise ScopeBalanceError("Input palette has unbalanced scopes") |
|
|
| |
| if op_id == OpCode.NO_OP: |
| new_palette = palette |
|
|
| |
| elif op_id == OpCode.DELETE_RANGE: |
| new_palette = PaletteEditOps.delete_range( |
| palette, action.region_id, action.i_start, action.i_end, metadata) |
|
|
| elif op_id == OpCode.INSERT_TOKEN: |
| new_palette = PaletteEditOps.insert_token( |
| palette, action.region_id, action.i_start, action.payload_idx, metadata) |
|
|
| elif op_id == OpCode.REPLACE_TOKEN: |
| new_palette = PaletteEditOps.replace_token( |
| palette, action.region_id, action.i_start, action.payload_idx, metadata) |
|
|
| elif op_id == OpCode.SWAP_TOKENS: |
| new_palette = PaletteEditOps.swap_tokens( |
| palette, action.region_id, action.i_start, action.i_end, metadata) |
|
|
| elif op_id == OpCode.MOVE_RANGE: |
| new_palette = PaletteEditOps.move_range( |
| palette, action.region_id, action.i_start, action.i_end, |
| action.target_region_id, action.payload_idx, metadata) |
|
|
| elif op_id == OpCode.COPY_RANGE: |
| new_palette = PaletteEditOps.copy_range( |
| palette, action.region_id, action.i_start, action.i_end, |
| action.target_region_id, action.payload_idx, metadata) |
|
|
| elif op_id == OpCode.WRAP_SCOPE: |
| new_palette = PaletteEditOps.wrap_scope( |
| palette, action.region_id, action.i_start, action.i_end, metadata) |
|
|
| elif op_id == OpCode.UNWRAP_SCOPE: |
| new_palette = PaletteEditOps.unwrap_scope( |
| palette, action.region_id, metadata) |
|
|
| |
| elif op_id == OpCode.INDENT: |
| new_palette = PaletteEditOps.indent( |
| palette, action.region_id, action.i_start, action.i_end, metadata) |
|
|
| elif op_id == OpCode.DEDENT: |
| new_palette = PaletteEditOps.dedent( |
| palette, action.region_id, action.i_start, action.i_end, metadata) |
|
|
| elif op_id == OpCode.EXTRACT: |
| new_palette = PaletteEditOps.extract( |
| palette, action.region_id, action.i_start, action.i_end, metadata) |
|
|
| elif op_id == OpCode.INLINE: |
| new_palette = PaletteEditOps.inline( |
| palette, action.region_id, action.target_region_id, metadata) |
|
|
| elif op_id == OpCode.SPLIT_REGION: |
| new_palette = PaletteEditOps.split_region( |
| palette, action.region_id, action.i_start, metadata) |
|
|
| elif op_id == OpCode.MERGE_REGIONS: |
| new_palette = PaletteEditOps.merge_regions( |
| palette, action.region_id, action.target_region_id, metadata) |
|
|
| elif op_id == OpCode.REORDER: |
| new_palette = PaletteEditOps.reorder( |
| palette, action.region_id, action.i_start, action.i_end, metadata) |
|
|
| elif op_id == OpCode.NEST_IN_BLOCK: |
| new_palette = PaletteEditOps.nest_in_block( |
| palette, action.region_id, action.i_start, action.i_end, |
| action.payload_idx, metadata) |
|
|
| elif op_id == OpCode.UNNEST_FROM_BLOCK: |
| new_palette = PaletteEditOps.unnest_from_block( |
| palette, action.region_id, metadata) |
|
|
| elif op_id == OpCode.HOIST: |
| new_palette = PaletteEditOps.hoist( |
| palette, action.region_id, action.i_start, action.i_end, metadata) |
|
|
| elif op_id == OpCode.SINK: |
| new_palette = PaletteEditOps.sink( |
| palette, action.region_id, action.i_start, action.i_end, |
| action.target_region_id, metadata) |
|
|
| |
| elif op_id == OpCode.RENAME: |
| new_palette = PaletteEditOps.rename( |
| palette, action.region_id, action.i_start, action.payload_idx, metadata) |
|
|
| elif op_id == OpCode.RETYPE: |
| new_palette = PaletteEditOps.retype( |
| palette, action.region_id, action.i_start, action.i_end, |
| action.payload_tokens, metadata) |
|
|
| elif op_id == OpCode.CONVERT_CONSTRUCT: |
| new_palette = PaletteEditOps.convert_construct( |
| palette, action.region_id, action.payload_tokens, metadata) |
|
|
| elif op_id == OpCode.SYNC_TO_ASYNC: |
| new_palette = PaletteEditOps.sync_to_async( |
| palette, action.region_id, metadata) |
|
|
| elif op_id == OpCode.PARAMETERIZE: |
| new_palette = PaletteEditOps.parameterize( |
| palette, action.region_id, action.i_start, action.payload_idx, metadata) |
|
|
| elif op_id == OpCode.SPECIALIZE: |
| new_palette = PaletteEditOps.specialize( |
| palette, action.region_id, action.i_start, action.i_end, |
| action.payload_tokens, metadata) |
|
|
| elif op_id == OpCode.GUARD: |
| new_palette = PaletteEditOps.guard( |
| palette, action.region_id, action.i_start, action.i_end, |
| action.payload_idx, metadata) |
|
|
| elif op_id == OpCode.UNGUARD: |
| new_palette = PaletteEditOps.unguard( |
| palette, action.region_id, metadata) |
|
|
| elif op_id == OpCode.SCATTER: |
| new_palette = PaletteEditOps.scatter( |
| palette, action.region_id, action.payload_idx, |
| action.positions, metadata) |
|
|
| elif op_id == OpCode.GATHER: |
| new_palette = PaletteEditOps.gather( |
| palette, action.positions, action.region_id, |
| action.i_start, metadata) |
|
|
| elif op_id == OpCode.MIRROR: |
| new_palette = PaletteEditOps.mirror( |
| palette, action.region_id, action.target_region_id, |
| action.i_start, action.i_end, action.payload_idx, metadata) |
|
|
| elif op_id == OpCode.COMPOSE: |
| new_palette = PaletteEditOps.compose( |
| palette, action.region_id, action.i_start, action.i_end, metadata) |
|
|
| else: |
| return palette_img, False |
|
|
| |
| if not PaletteEditOps.verify_scope_balance(new_palette): |
| raise ScopeBalanceError("Operation broke scope balance") |
|
|
| return new_palette, True |
|
|
| except EditError: |
| return palette_img, False |
|
|
| |
| |
| |
|
|
| @staticmethod |
| def delete_range( |
| palette: torch.Tensor, region_id: int, |
| i_start: int, i_end: int, metadata: RegionMetadata |
| ) -> torch.Tensor: |
| """Delete tokens [i_start, i_end] within region. Shift left, pad NOOP.""" |
| H, W = palette.shape |
| positions = PaletteEditOps._get_content_positions(palette, metadata, region_id) |
|
|
| if i_start < 0 or i_start >= len(positions): |
| raise InvalidPointerError(f"i_start={i_start} out of bounds (size={len(positions)})") |
| if i_end < i_start or i_end >= len(positions): |
| raise InvalidPointerError(f"i_end={i_end} out of bounds") |
|
|
| abs_positions = [positions[i] for i in range(i_start, i_end + 1)] |
| palette_flat = palette.flatten() |
|
|
| for pos in abs_positions: |
| if palette_flat[pos].item() in (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE): |
| raise ScopeBalanceError("Cannot delete scope markers") |
|
|
| delete_mask = torch.zeros(H * W, dtype=torch.bool) |
| for pos in abs_positions: |
| delete_mask[pos] = True |
|
|
| kept = palette_flat[~delete_mask] |
| pad = torch.full((H * W - len(kept),), PaletteEditOps.NOOP, dtype=palette.dtype) |
| return torch.cat([kept, pad]).view(H, W) |
|
|
| @staticmethod |
| def insert_token( |
| palette: torch.Tensor, region_id: int, |
| i_start: int, payload_idx: int, metadata: RegionMetadata |
| ) -> torch.Tensor: |
| """Insert payload_idx at position i_start within region. Shift right, drop last.""" |
| H, W = palette.shape |
| positions = PaletteEditOps._get_content_positions(palette, metadata, region_id) |
|
|
| if i_start < 0 or i_start > len(positions): |
| raise InvalidPointerError(f"i_start={i_start} out of bounds") |
| if payload_idx in (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE): |
| raise ScopeBalanceError("Cannot insert unpaired scope marker") |
|
|
| abs_pos = positions[i_start] if i_start < len(positions) else (positions[-1] + 1 if positions else 0) |
| flat = palette.flatten() |
| new_flat = torch.zeros(H * W, dtype=palette.dtype) |
| new_flat[:abs_pos] = flat[:abs_pos] |
| new_flat[abs_pos] = payload_idx |
| if abs_pos < H * W - 1: |
| new_flat[abs_pos + 1:] = flat[abs_pos:H * W - 1] |
| return new_flat.view(H, W) |
|
|
| @staticmethod |
| def replace_token( |
| palette: torch.Tensor, region_id: int, |
| i_start: int, payload_idx: int, metadata: RegionMetadata |
| ) -> torch.Tensor: |
| """Replace token at i_start with payload_idx.""" |
| H, W = palette.shape |
| positions = PaletteEditOps._get_content_positions(palette, metadata, region_id) |
|
|
| if i_start < 0 or i_start >= len(positions): |
| raise InvalidPointerError(f"i_start={i_start} out of bounds") |
|
|
| abs_pos = positions[i_start] |
| h, w = abs_pos // W, abs_pos % W |
| old_value = palette[h, w].item() |
|
|
| is_old_scope = old_value in (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE) |
| is_new_scope = payload_idx in (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE) |
|
|
| if is_old_scope and not is_new_scope: |
| raise ScopeBalanceError("Cannot replace scope marker with non-marker") |
| if is_old_scope and is_new_scope and old_value != payload_idx: |
| raise ScopeBalanceError("Cannot replace START with END or vice versa") |
|
|
| new_palette = palette.clone() |
| new_palette[h, w] = payload_idx |
| return new_palette |
|
|
| @staticmethod |
| def swap_tokens( |
| palette: torch.Tensor, region_id: int, |
| i_start: int, i_end: int, metadata: RegionMetadata |
| ) -> torch.Tensor: |
| """Swap tokens at i_start and i_end within region.""" |
| H, W = palette.shape |
| positions = PaletteEditOps._get_content_positions(palette, metadata, region_id) |
|
|
| if i_start < 0 or i_start >= len(positions): |
| raise InvalidPointerError(f"i_start={i_start} out of bounds") |
| if i_end < 0 or i_end >= len(positions): |
| raise InvalidPointerError(f"i_end={i_end} out of bounds") |
|
|
| p1, p2 = positions[i_start], positions[i_end] |
| h1, w1 = p1 // W, p1 % W |
| h2, w2 = p2 // W, p2 % W |
| v1, v2 = palette[h1, w1].item(), palette[h2, w2].item() |
|
|
| if {v1, v2} == {PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE}: |
| raise ScopeBalanceError("Cannot swap START β END") |
|
|
| new_palette = palette.clone() |
| new_palette[h1, w1], new_palette[h2, w2] = palette[h2, w2], palette[h1, w1] |
| return new_palette |
|
|
    @staticmethod
    def move_range(
        palette: torch.Tensor, src_region: int,
        i_start: int, i_end: int,
        dst_region: int, dst_pos: int,
        metadata: RegionMetadata
    ) -> torch.Tensor:
        """Move tokens [i_start,i_end] from src_region to dst_pos in dst_region.
        = copy + delete source. Cross-region cut-paste.

        Implemented as copy_range followed by a value-matching deletion pass,
        because region metadata is not recomputed after the insert shifts
        absolute positions.

        NOTE(review): the deletion pass matches tokens by VALUE in order; with
        duplicate values in the source region it may delete an equal-but-
        different occurrence. Net content is equivalent; exact cells chosen
        are heuristic.

        Raises:
            InvalidPointerError: source tokens could not be re-located.
        """
        # -1 means "same region" for the destination.
        if dst_region < 0:
            dst_region = src_region

        # Step 1: insert the copy (validates bounds and rejects lone markers).
        result = PaletteEditOps.copy_range(
            palette, src_region, i_start, i_end, dst_region, dst_pos, metadata)

        # Step 2: delete the originals from the post-copy tensor.
        n_copied = i_end - i_start + 1
        src_positions = PaletteEditOps._get_content_positions(result, metadata, src_region)

        # The values to remove, read from the pre-copy tensor.
        orig_positions = PaletteEditOps._get_content_positions(palette, metadata, src_region)
        orig_flat = palette.flatten()
        src_values = [orig_flat[orig_positions[i]].item() for i in range(i_start, i_end + 1)]

        # Greedy first-match scan: mark n_copied matching tokens for deletion.
        result_flat = result.flatten()
        H, W = result.shape
        delete_mask = torch.zeros(H * W, dtype=torch.bool)
        deleted = 0
        for pos in src_positions:
            val = result_flat[pos].item()
            if deleted < n_copied and val == src_values[deleted]:
                # Markers are never masked (copy_range already excludes them).
                if val not in (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE):
                    delete_mask[pos] = True
                deleted += 1

        if deleted == 0:
            raise InvalidPointerError("Could not locate source tokens for deletion")

        # Compact left; pad the tail with NOOP to keep the (H, W) shape.
        kept = result_flat[~delete_mask]
        pad = torch.full((H * W - len(kept),), PaletteEditOps.NOOP, dtype=palette.dtype)
        return torch.cat([kept, pad]).view(H, W)
|
|
| @staticmethod |
| def copy_range( |
| palette: torch.Tensor, src_region: int, |
| i_start: int, i_end: int, |
| dst_region: int, dst_pos: int, |
| metadata: RegionMetadata |
| ) -> torch.Tensor: |
| """Copy tokens [i_start,i_end] from src_region, insert at dst_pos in dst_region.""" |
| if dst_region < 0: |
| dst_region = src_region |
|
|
| src_positions = PaletteEditOps._get_content_positions(palette, metadata, src_region) |
| if i_start < 0 or i_end >= len(src_positions): |
| raise InvalidPointerError(f"Source range [{i_start},{i_end}] out of bounds") |
|
|
| flat = palette.flatten() |
| copied_tokens = [flat[src_positions[i]].item() for i in range(i_start, i_end + 1)] |
|
|
| |
| for t in copied_tokens: |
| if t in (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE): |
| raise ScopeBalanceError("Cannot copy scope markers without pairing") |
|
|
| dst_positions = PaletteEditOps._get_content_positions(palette, metadata, dst_region) |
| if dst_pos < 0 or dst_pos > len(dst_positions): |
| raise InvalidPointerError(f"dst_pos={dst_pos} out of bounds") |
|
|
| abs_dst = dst_positions[dst_pos] if dst_pos < len(dst_positions) else ( |
| dst_positions[-1] + 1 if dst_positions else 0) |
|
|
| H, W = palette.shape |
| new_flat = torch.full((H * W,), PaletteEditOps.NOOP, dtype=palette.dtype) |
| n = len(copied_tokens) |
|
|
| new_flat[:abs_dst] = flat[:abs_dst] |
| for i, t in enumerate(copied_tokens): |
| if abs_dst + i < H * W: |
| new_flat[abs_dst + i] = t |
| remaining = min(H * W - abs_dst - n, H * W - abs_dst) |
| if remaining > 0: |
| new_flat[abs_dst + n:abs_dst + n + remaining] = flat[abs_dst:abs_dst + remaining] |
|
|
| return new_flat.view(H, W) |
|
|
| @staticmethod |
| def wrap_scope( |
| palette: torch.Tensor, region_id: int, |
| i_start: int, i_end: int, |
| metadata: RegionMetadata |
| ) -> torch.Tensor: |
| """Wrap tokens [i_start,i_end] in new scope markers (START...END). |
| Inserts START before i_start, END after i_end. Balanced by construction.""" |
| H, W = palette.shape |
| positions = PaletteEditOps._get_content_positions(palette, metadata, region_id) |
|
|
| if i_start < 0 or i_end >= len(positions) or i_end < i_start: |
| raise InvalidPointerError(f"Range [{i_start},{i_end}] invalid") |
|
|
| abs_start = positions[i_start] |
| abs_end = positions[i_end] |
| flat = palette.flatten() |
|
|
| |
| |
| before = flat[:abs_start].tolist() |
| wrapped = flat[abs_start:abs_end + 1].tolist() |
| after = flat[abs_end + 1:].tolist() |
|
|
| new_seq = before + [PaletteEditOps.START_OF_SCOPE] + wrapped + [PaletteEditOps.END_OF_SCOPE] + after |
| |
| if len(new_seq) > H * W: |
| new_seq = new_seq[:H * W] |
| elif len(new_seq) < H * W: |
| new_seq.extend([PaletteEditOps.NOOP] * (H * W - len(new_seq))) |
|
|
| return torch.tensor(new_seq, dtype=palette.dtype).view(H, W) |
|
|
| @staticmethod |
| def unwrap_scope( |
| palette: torch.Tensor, region_id: int, |
| metadata: RegionMetadata |
| ) -> torch.Tensor: |
| """Remove the outermost scope markers of a region. Content preserved, scope removed. |
| Removes START at region start and END at region end.""" |
| H, W = palette.shape |
| flat = palette.flatten() |
|
|
| start_pos = metadata.starts[region_id] |
| end_pos = metadata.ends[region_id] |
|
|
| |
| if flat[start_pos].item() != PaletteEditOps.START_OF_SCOPE: |
| raise ScopeBalanceError("Region start is not START_OF_SCOPE") |
| if flat[end_pos].item() != PaletteEditOps.END_OF_SCOPE: |
| raise ScopeBalanceError("Region end is not END_OF_SCOPE") |
|
|
| |
| delete_mask = torch.zeros(H * W, dtype=torch.bool) |
| delete_mask[start_pos] = True |
| delete_mask[end_pos] = True |
|
|
| kept = flat[~delete_mask] |
| pad = torch.full((H * W - len(kept),), PaletteEditOps.NOOP, dtype=palette.dtype) |
| return torch.cat([kept, pad]).view(H, W) |
|
|
| |
| |
| |
|
|
| @staticmethod |
| def indent( |
| palette: torch.Tensor, region_id: int, |
| i_start: int, i_end: int, |
| metadata: RegionMetadata |
| ) -> torch.Tensor: |
| """Increase scope depth: wrap [i_start,i_end] in new scope. |
| Equivalent to wrap_scope β increases nesting by 1.""" |
| return PaletteEditOps.wrap_scope(palette, region_id, i_start, i_end, metadata) |
|
|
| @staticmethod |
| def dedent( |
| palette: torch.Tensor, region_id: int, |
| i_start: int, i_end: int, |
| metadata: RegionMetadata |
| ) -> torch.Tensor: |
| """Decrease scope depth: remove innermost scope around [i_start,i_end]. |
| Finds the tightest enclosing scope and removes its markers.""" |
| H, W = palette.shape |
| flat = palette.flatten() |
| positions = PaletteEditOps._get_content_positions(palette, metadata, region_id) |
|
|
| if i_start < 0 or i_end >= len(positions): |
| raise InvalidPointerError(f"Range [{i_start},{i_end}] invalid") |
|
|
| abs_start = positions[i_start] |
| abs_end = positions[i_end] |
|
|
| |
| enclosing_start = -1 |
| for i in range(abs_start - 1, -1, -1): |
| if flat[i].item() == PaletteEditOps.START_OF_SCOPE: |
| enclosing_start = i |
| break |
|
|
| if enclosing_start < 0: |
| raise ScopeBalanceError("No enclosing scope to dedent from") |
|
|
| |
| depth = 0 |
| enclosing_end = -1 |
| for i in range(enclosing_start, H * W): |
| v = flat[i].item() |
| if v == PaletteEditOps.START_OF_SCOPE: |
| depth += 1 |
| elif v == PaletteEditOps.END_OF_SCOPE: |
| depth -= 1 |
| if depth == 0: |
| enclosing_end = i |
| break |
|
|
| if enclosing_end < 0 or enclosing_end < abs_end: |
| raise ScopeBalanceError("Cannot find matching END for enclosing scope") |
|
|
| |
| delete_mask = torch.zeros(H * W, dtype=torch.bool) |
| delete_mask[enclosing_start] = True |
| delete_mask[enclosing_end] = True |
|
|
| kept = flat[~delete_mask] |
| pad = torch.full((H * W - len(kept),), PaletteEditOps.NOOP, dtype=palette.dtype) |
| return torch.cat([kept, pad]).view(H, W) |
|
|
    @staticmethod
    def extract(
        palette: torch.Tensor, region_id: int,
        i_start: int, i_end: int,
        metadata: RegionMetadata
    ) -> torch.Tensor:
        """Extract tokens [i_start,i_end] into a new scope appended after current region.
        Source range replaced with a reference token (payload_idx=3 = EXTRACTED_REF).
        New scope with extracted content appears after current region's END.

        Raises:
            InvalidPointerError: malformed range.
            ScopeBalanceError: range covers scope markers.
        """
        # Sentinel hue left in place of the extracted content.
        EXTRACTED_REF = 3
        H, W = palette.shape
        flat = palette.flatten()
        positions = PaletteEditOps._get_content_positions(palette, metadata, region_id)

        if i_start < 0 or i_end >= len(positions) or i_end < i_start:
            raise InvalidPointerError(f"Range [{i_start},{i_end}] invalid")

        # Markers inside the range would leave the origin scope unbalanced.
        extracted = [flat[positions[i]].item() for i in range(i_start, i_end + 1)]
        for t in extracted:
            if t in (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE):
                raise ScopeBalanceError("Cannot extract scope markers")

        # Absolute span of the extracted range.
        abs_start = positions[i_start]
        abs_end = positions[i_end]

        before = flat[:abs_start].tolist()
        after = flat[abs_end + 1:].tolist()
        middle = [EXTRACTED_REF]

        # After collapsing the range to one reference token, the region's END
        # shifts left by (n_removed - 1) cells.
        # NOTE(review): assumes the extracted range lies before the region's
        # END marker — TODO confirm upstream guarantees this.
        region_end = metadata.ends[region_id]

        n_removed = (i_end - i_start + 1)
        adj_end = region_end - n_removed + 1

        new_scope = [PaletteEditOps.START_OF_SCOPE] + extracted + [PaletteEditOps.END_OF_SCOPE]

        seq = before + middle + after
        # Place the new scope directly after the (shifted) region END.
        insert_at = min(adj_end + 1, len(seq))
        seq = seq[:insert_at] + new_scope + seq[insert_at:]

        # Clamp/pad to the fixed grid size.
        if len(seq) > H * W:
            seq = seq[:H * W]
        else:
            seq.extend([PaletteEditOps.NOOP] * (H * W - len(seq)))

        return torch.tensor(seq, dtype=palette.dtype).view(H, W)
|
|
| @staticmethod |
| def inline( |
| palette: torch.Tensor, src_region: int, |
| target_region: int, metadata: RegionMetadata |
| ) -> torch.Tensor: |
| """Inline: replace a reference in target_region with contents of src_region. |
| Opposite of extract. Removes src_region scope, inserts content at ref position.""" |
| if target_region < 0: |
| raise CrossRegionError("target_region_id required for INLINE") |
|
|
| EXTRACTED_REF = 3 |
| H, W = palette.shape |
| flat = palette.flatten() |
|
|
| |
| src_content = PaletteEditOps._get_content_positions(palette, metadata, src_region) |
| content_tokens = [flat[pos].item() for pos in src_content] |
|
|
| |
| target_positions = PaletteEditOps._get_content_positions(palette, metadata, target_region) |
| ref_pos = -1 |
| for pos in target_positions: |
| if flat[pos].item() == EXTRACTED_REF: |
| ref_pos = pos |
| break |
|
|
| if ref_pos < 0: |
| raise PatternNotFoundError("No EXTRACTED_REF found in target region") |
|
|
| |
| src_start = metadata.starts[src_region] |
| src_end = metadata.ends[src_region] |
|
|
| seq = flat.tolist() |
| |
| ref_idx = seq.index(EXTRACTED_REF) if EXTRACTED_REF in seq else ref_pos |
| seq = seq[:ref_idx] + content_tokens + seq[ref_idx + 1:] |
|
|
| |
| |
| |
| |
| flat_list = flat.tolist() |
|
|
| |
| remove = set(range(src_start, src_end + 1)) |
| cleaned = [(i, v) for i, v in enumerate(flat_list) if i not in remove] |
|
|
| |
| new_seq = [] |
| for _, v in cleaned: |
| if v == EXTRACTED_REF: |
| new_seq.extend(content_tokens) |
| else: |
| new_seq.append(v) |
|
|
| if len(new_seq) > H * W: |
| new_seq = new_seq[:H * W] |
| else: |
| new_seq.extend([PaletteEditOps.NOOP] * (H * W - len(new_seq))) |
|
|
| return torch.tensor(new_seq, dtype=palette.dtype).view(H, W) |
|
|
| @staticmethod |
| def split_region( |
| palette: torch.Tensor, region_id: int, |
| split_at: int, metadata: RegionMetadata |
| ) -> torch.Tensor: |
| """Split region into two at position split_at. |
| Inserts END + START between positions split_at-1 and split_at.""" |
| H, W = palette.shape |
| positions = PaletteEditOps._get_content_positions(palette, metadata, region_id) |
|
|
| if split_at <= 0 or split_at >= len(positions): |
| raise InvalidPointerError(f"split_at={split_at} must be in (0, {len(positions)})") |
|
|
| abs_split = positions[split_at] |
| flat = palette.flatten() |
|
|
| before = flat[:abs_split].tolist() |
| after = flat[abs_split:].tolist() |
| |
| new_seq = before + [PaletteEditOps.END_OF_SCOPE, PaletteEditOps.START_OF_SCOPE] + after |
|
|
| if len(new_seq) > H * W: |
| new_seq = new_seq[:H * W] |
| else: |
| new_seq.extend([PaletteEditOps.NOOP] * (H * W - len(new_seq))) |
|
|
| return torch.tensor(new_seq, dtype=palette.dtype).view(H, W) |
|
|
| @staticmethod |
| def merge_regions( |
| palette: torch.Tensor, region_a: int, |
| region_b: int, metadata: RegionMetadata |
| ) -> torch.Tensor: |
| """Merge two adjacent regions by removing the END of A and START of B. |
| Regions must be adjacent (A's END directly before B's START).""" |
| if region_b < 0: |
| raise CrossRegionError("target_region_id required for MERGE") |
|
|
| H, W = palette.shape |
| flat = palette.flatten() |
|
|
| end_a = metadata.ends[region_a] |
| start_b = metadata.starts[region_b] |
|
|
| |
| if flat[end_a].item() != PaletteEditOps.END_OF_SCOPE: |
| raise ScopeBalanceError("Region A end is not END_OF_SCOPE") |
| if flat[start_b].item() != PaletteEditOps.START_OF_SCOPE: |
| raise ScopeBalanceError("Region B start is not START_OF_SCOPE") |
|
|
| |
| delete_mask = torch.zeros(H * W, dtype=torch.bool) |
| delete_mask[end_a] = True |
| delete_mask[start_b] = True |
|
|
| kept = flat[~delete_mask] |
| pad = torch.full((H * W - len(kept),), PaletteEditOps.NOOP, dtype=palette.dtype) |
| return torch.cat([kept, pad]).view(H, W) |
|
|
| @staticmethod |
| def reorder( |
| palette: torch.Tensor, region_id: int, |
| i_start: int, i_end: int, |
| metadata: RegionMetadata |
| ) -> torch.Tensor: |
| """Reverse the order of tokens [i_start,i_end] within region. |
| Generalizable to arbitrary permutations via payload_tokens.""" |
| H, W = palette.shape |
| positions = PaletteEditOps._get_content_positions(palette, metadata, region_id) |
|
|
| if i_start < 0 or i_end >= len(positions) or i_end < i_start: |
| raise InvalidPointerError(f"Range [{i_start},{i_end}] invalid") |
|
|
| flat = palette.flatten() |
| new_palette = palette.clone() |
| new_flat = new_palette.flatten() |
|
|
| |
| vals = [flat[positions[i]].item() for i in range(i_start, i_end + 1)] |
| vals.reverse() |
|
|
| for i, val in enumerate(vals): |
| pos = positions[i_start + i] |
| new_flat[pos] = val |
|
|
| return new_flat.view(H, W) |
|
|
| @staticmethod |
| def nest_in_block( |
| palette: torch.Tensor, region_id: int, |
| i_start: int, i_end: int, |
| block_type_hue: int, metadata: RegionMetadata |
| ) -> torch.Tensor: |
| """Wrap [i_start,i_end] in a new control block (if/for/try/function). |
| Inserts: START + block_type_hue + [content] + END. |
| The block_type_hue identifies the construct type (20=for, 24=if, etc.).""" |
| H, W = palette.shape |
| positions = PaletteEditOps._get_content_positions(palette, metadata, region_id) |
|
|
| if i_start < 0 or i_end >= len(positions) or i_end < i_start: |
| raise InvalidPointerError(f"Range [{i_start},{i_end}] invalid") |
|
|
| abs_start = positions[i_start] |
| abs_end = positions[i_end] |
| flat = palette.flatten() |
|
|
| before = flat[:abs_start].tolist() |
| content = flat[abs_start:abs_end + 1].tolist() |
| after = flat[abs_end + 1:].tolist() |
|
|
| new_seq = (before |
| + [PaletteEditOps.START_OF_SCOPE, block_type_hue] |
| + content |
| + [PaletteEditOps.END_OF_SCOPE] |
| + after) |
|
|
| if len(new_seq) > H * W: |
| new_seq = new_seq[:H * W] |
| else: |
| new_seq.extend([PaletteEditOps.NOOP] * (H * W - len(new_seq))) |
|
|
| return torch.tensor(new_seq, dtype=palette.dtype).view(H, W) |
|
|
    @staticmethod
    def unnest_from_block(
        palette: torch.Tensor, region_id: int,
        metadata: RegionMetadata
    ) -> torch.Tensor:
        """Remove control block scope: remove START, block_type token, and matching END.
        Content is preserved and lifted to parent scope.

        NOTE(review): the cell at start_pos + 1 is deleted unconditionally on
        the assumption it holds the block-type hue (see nest_in_block); if the
        region was not created by nest_in_block that cell is ordinary content
        and gets dropped too — confirm callers only apply this to typed blocks.

        Raises:
            ScopeBalanceError: region boundaries are not scope markers.
        """
        H, W = palette.shape
        flat = palette.flatten()

        start_pos = metadata.starts[region_id]
        end_pos = metadata.ends[region_id]

        if flat[start_pos].item() != PaletteEditOps.START_OF_SCOPE:
            raise ScopeBalanceError("Region start is not START_OF_SCOPE")
        if flat[end_pos].item() != PaletteEditOps.END_OF_SCOPE:
            raise ScopeBalanceError("Region end is not END_OF_SCOPE")

        # Drop START, the block-type token after it, and the END marker.
        delete_mask = torch.zeros(H * W, dtype=torch.bool)
        delete_mask[start_pos] = True
        if start_pos + 1 < H * W:
            delete_mask[start_pos + 1] = True
        delete_mask[end_pos] = True

        kept = flat[~delete_mask]
        pad = torch.full((H * W - len(kept),), PaletteEditOps.NOOP, dtype=palette.dtype)
        return torch.cat([kept, pad]).view(H, W)
|
|
| @staticmethod |
| def hoist( |
| palette: torch.Tensor, region_id: int, |
| i_start: int, i_end: int, |
| metadata: RegionMetadata |
| ) -> torch.Tensor: |
| """Hoist: move tokens [i_start,i_end] from current region to before region's START. |
| Declaration moves to higher scope.""" |
| H, W = palette.shape |
| flat = palette.flatten() |
| positions = PaletteEditOps._get_content_positions(palette, metadata, region_id) |
|
|
| if i_start < 0 or i_end >= len(positions) or i_end < i_start: |
| raise InvalidPointerError(f"Range [{i_start},{i_end}] invalid") |
|
|
| |
| hoisted = [flat[positions[i]].item() for i in range(i_start, i_end + 1)] |
| for t in hoisted: |
| if t in (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE): |
| raise ScopeBalanceError("Cannot hoist scope markers") |
|
|
| |
| abs_positions = [positions[i] for i in range(i_start, i_end + 1)] |
| delete_mask = torch.zeros(H * W, dtype=torch.bool) |
| for pos in abs_positions: |
| delete_mask[pos] = True |
|
|
| cleaned = flat[~delete_mask].tolist() |
|
|
| |
| region_start = metadata.starts[region_id] |
| |
| adj = sum(1 for p in abs_positions if p < region_start) |
| insert_at = region_start - adj |
|
|
| new_seq = cleaned[:insert_at] + hoisted + cleaned[insert_at:] |
|
|
| if len(new_seq) > H * W: |
| new_seq = new_seq[:H * W] |
| else: |
| new_seq.extend([PaletteEditOps.NOOP] * (H * W - len(new_seq))) |
|
|
| return torch.tensor(new_seq, dtype=palette.dtype).view(H, W) |
|
|
| @staticmethod |
| def sink( |
| palette: torch.Tensor, region_id: int, |
| i_start: int, i_end: int, |
| target_region: int, metadata: RegionMetadata |
| ) -> torch.Tensor: |
| """Sink: move tokens from current region into a deeper (child) region. |
| Opposite of hoist. Tokens move from parent scope to target child scope.""" |
| if target_region < 0: |
| raise CrossRegionError("target_region_id required for SINK") |
|
|
| |
| return PaletteEditOps.move_range( |
| palette, region_id, i_start, i_end, |
| target_region, 0, metadata) |
|
|
| |
| |
| |
|
|
| @staticmethod |
| def rename( |
| palette: torch.Tensor, region_id: int, |
| i_start: int, new_hue: int, metadata: RegionMetadata |
| ) -> torch.Tensor: |
| """Rename: replace identifier hue at i_start with new_hue. |
| Same as REPLACE but semantically constrained to identifiers.""" |
| return PaletteEditOps.replace_token(palette, region_id, i_start, new_hue, metadata) |
|
|
| @staticmethod |
| def retype( |
| palette: torch.Tensor, region_id: int, |
| i_start: int, i_end: int, |
| new_type_tokens: List[int], metadata: RegionMetadata |
| ) -> torch.Tensor: |
| """Retype: replace type annotation range [i_start,i_end] with new tokens. |
| Handles type annotations that may change length (int β List[int]).""" |
| H, W = palette.shape |
| positions = PaletteEditOps._get_content_positions(palette, metadata, region_id) |
|
|
| if i_start < 0 or i_end >= len(positions) or i_end < i_start: |
| raise InvalidPointerError(f"Range [{i_start},{i_end}] invalid") |
|
|
| abs_start = positions[i_start] |
| abs_end = positions[i_end] |
| flat = palette.flatten() |
|
|
| before = flat[:abs_start].tolist() |
| after = flat[abs_end + 1:].tolist() |
| new_seq = before + list(new_type_tokens) + after |
|
|
| if len(new_seq) > H * W: |
| new_seq = new_seq[:H * W] |
| else: |
| new_seq.extend([PaletteEditOps.NOOP] * (H * W - len(new_seq))) |
|
|
| return torch.tensor(new_seq, dtype=palette.dtype).view(H, W) |
|
|
| @staticmethod |
| def convert_construct( |
| palette: torch.Tensor, region_id: int, |
| pattern_target: List[int], metadata: RegionMetadata |
| ) -> torch.Tensor: |
| """Convert construct: pattern match and replace within region. |
| pattern_target = [*pattern_tokens, -1, *target_tokens] where -1 is separator. |
| If empty, falls back to built-in MACRO_PATTERNS.""" |
| if not pattern_target: |
| |
| macro = list(PaletteEditOps.MACRO_PATTERNS.values())[0] |
| pattern = macro['pattern'] |
| target = macro['target'] |
| else: |
| if -1 not in pattern_target: |
| raise PatternNotFoundError("pattern_target must contain -1 separator") |
| sep = pattern_target.index(-1) |
| pattern = pattern_target[:sep] |
| target = pattern_target[sep + 1:] |
|
|
| positions = PaletteEditOps._get_content_positions(palette, metadata, region_id) |
| flat = palette.flatten() |
| region_tokens = [flat[pos].item() for pos in positions] |
|
|
| plen = len(pattern) |
| found = False |
| new_palette = palette.clone() |
|
|
| for i in range(len(region_tokens) - plen + 1): |
| if region_tokens[i:i + plen] == pattern: |
| |
| if len(target) == plen: |
| |
| for j, t in enumerate(target): |
| pos = positions[i + j] |
| h, w = pos // palette.shape[1], pos % palette.shape[1] |
| new_palette[h, w] = t |
| else: |
| |
| abs_start = positions[i] |
| abs_end = positions[i + plen - 1] |
| H, W = palette.shape |
| flat_list = flat.tolist() |
| new_seq = flat_list[:abs_start] + target + flat_list[abs_end + 1:] |
| if len(new_seq) > H * W: |
| new_seq = new_seq[:H * W] |
| else: |
| new_seq.extend([PaletteEditOps.NOOP] * (H * W - len(new_seq))) |
| new_palette = torch.tensor(new_seq, dtype=palette.dtype).view(H, W) |
| found = True |
| break |
|
|
| if not found: |
| raise PatternNotFoundError(f"Pattern {pattern} not found in region") |
|
|
| return new_palette |
|
|
| @staticmethod |
| def sync_to_async( |
| palette: torch.Tensor, region_id: int, |
| metadata: RegionMetadata |
| ) -> torch.Tensor: |
| """Add async/await markers to region. |
| Inserts async hue (hue 46) before region's first function-def token (hue 12), |
| and await hue (hue 47) before call tokens (hue 60).""" |
| ASYNC_HUE = 46 |
| AWAIT_HUE = 47 |
| FUNC_DEF_HUE = 12 |
| CALL_HUE = 60 |
|
|
| H, W = palette.shape |
| positions = PaletteEditOps._get_content_positions(palette, metadata, region_id) |
| flat = palette.flatten() |
|
|
| insertions = [] |
| for pos in positions: |
| val = flat[pos].item() |
| if val == FUNC_DEF_HUE: |
| insertions.append((pos, ASYNC_HUE)) |
| elif val == CALL_HUE: |
| insertions.append((pos, AWAIT_HUE)) |
|
|
| if not insertions: |
| raise PatternNotFoundError("No function defs or calls found to make async") |
|
|
| |
| seq = flat.tolist() |
| offset = 0 |
| for abs_pos, hue in sorted(insertions): |
| seq.insert(abs_pos + offset, hue) |
| offset += 1 |
|
|
| if len(seq) > H * W: |
| seq = seq[:H * W] |
| else: |
| seq.extend([PaletteEditOps.NOOP] * (H * W - len(seq))) |
|
|
| return torch.tensor(seq, dtype=palette.dtype).view(H, W) |
|
|
| @staticmethod |
| def parameterize( |
| palette: torch.Tensor, region_id: int, |
| i_start: int, param_hue: int, metadata: RegionMetadata |
| ) -> torch.Tensor: |
| """Replace a hardcoded literal at i_start with a parameter reference (param_hue). |
| The literal hue becomes a variable/parameter hue.""" |
| return PaletteEditOps.replace_token(palette, region_id, i_start, param_hue, metadata) |
|
|
| @staticmethod |
| def specialize( |
| palette: torch.Tensor, region_id: int, |
| i_start: int, i_end: int, |
| concrete_tokens: List[int], metadata: RegionMetadata |
| ) -> torch.Tensor: |
| """Replace generic type tokens [i_start,i_end] with concrete specialization. |
| Opposite of parameterize for types: List[T] β List[int].""" |
| return PaletteEditOps.retype(palette, region_id, i_start, i_end, concrete_tokens, metadata) |
|
|
| @staticmethod |
| def guard( |
| palette: torch.Tensor, region_id: int, |
| i_start: int, i_end: int, |
| guard_hue: int, metadata: RegionMetadata |
| ) -> torch.Tensor: |
| """Wrap [i_start,i_end] in a conditional guard (if/try/etc). |
| Like nest_in_block with a guard-specific hue.""" |
| return PaletteEditOps.nest_in_block( |
| palette, region_id, i_start, i_end, guard_hue, metadata) |
|
|
| @staticmethod |
| def unguard( |
| palette: torch.Tensor, region_id: int, |
| metadata: RegionMetadata |
| ) -> torch.Tensor: |
| """Remove conditional guard from region. Content lifted to parent scope. |
| Like unnest_from_block.""" |
| return PaletteEditOps.unnest_from_block(palette, region_id, metadata) |
|
|
| @staticmethod |
| def scatter( |
| palette: torch.Tensor, region_id: int, |
| new_hue: int, target_positions: List[int], |
| metadata: RegionMetadata |
| ) -> torch.Tensor: |
| """Replace token at multiple positions with new_hue. |
| Same change applied to N locations (rename-all, update-all-call-sites).""" |
| positions = PaletteEditOps._get_content_positions(palette, metadata, region_id) |
| new_palette = palette.clone() |
| H, W = palette.shape |
|
|
| for pos_idx in target_positions: |
| if pos_idx < 0 or pos_idx >= len(positions): |
| raise InvalidPointerError(f"Position {pos_idx} out of bounds") |
| abs_pos = positions[pos_idx] |
| h, w = abs_pos // W, abs_pos % W |
| val = new_palette[h, w].item() |
| if val in (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE): |
| raise ScopeBalanceError("Cannot scatter over scope markers") |
| new_palette[h, w] = new_hue |
|
|
| return new_palette |
|
|
| @staticmethod |
| def gather( |
| palette: torch.Tensor, source_positions: List[int], |
| target_region: int, target_pos: int, |
| metadata: RegionMetadata |
| ) -> torch.Tensor: |
| """Gather: collect tokens from multiple positions into a single location. |
| Tokens at source_positions are removed and concatenated at target_pos in target_region. |
| Opposite of scatter.""" |
| H, W = palette.shape |
| flat = palette.flatten() |
|
|
| |
| |
| gathered_vals = [] |
| for pos in source_positions: |
| if pos < 0 or pos >= H * W: |
| raise InvalidPointerError(f"Source position {pos} out of bounds") |
| val = flat[pos].item() |
| if val in (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE): |
| raise ScopeBalanceError("Cannot gather scope markers") |
| gathered_vals.append(val) |
|
|
| |
| delete_mask = torch.zeros(H * W, dtype=torch.bool) |
| for pos in source_positions: |
| delete_mask[pos] = True |
| cleaned = flat[~delete_mask].tolist() |
|
|
| |
| target_positions_list = PaletteEditOps._get_content_positions(palette, metadata, target_region) |
| if target_pos < 0 or target_pos > len(target_positions_list): |
| raise InvalidPointerError(f"target_pos={target_pos} out of bounds") |
|
|
| |
| abs_target = target_positions_list[target_pos] if target_pos < len(target_positions_list) else ( |
| target_positions_list[-1] + 1 if target_positions_list else 0) |
| adj = sum(1 for p in source_positions if p < abs_target) |
| abs_target -= adj |
|
|
| new_seq = cleaned[:abs_target] + gathered_vals + cleaned[abs_target:] |
|
|
| if len(new_seq) > H * W: |
| new_seq = new_seq[:H * W] |
| else: |
| new_seq.extend([PaletteEditOps.NOOP] * (H * W - len(new_seq))) |
|
|
| return torch.tensor(new_seq, dtype=palette.dtype).view(H, W) |
|
|
| @staticmethod |
| def mirror( |
| palette: torch.Tensor, region_a: int, |
| region_b: int, i_start: int, i_end: int, |
| new_hue: int, metadata: RegionMetadata |
| ) -> torch.Tensor: |
| """Apply symmetric change to paired regions A and B. |
| Replace tokens at [i_start,i_end] in BOTH regions with new_hue. |
| For getter/setter pairs, request/response symmetry, etc.""" |
| if region_b < 0: |
| raise CrossRegionError("target_region_id required for MIRROR") |
|
|
| new_palette = palette.clone() |
| H, W = palette.shape |
|
|
| for rid in [region_a, region_b]: |
| positions = PaletteEditOps._get_content_positions(new_palette, metadata, rid) |
| if i_start < 0 or i_end >= len(positions): |
| raise InvalidPointerError(f"Range [{i_start},{i_end}] out of bounds in region {rid}") |
| for i in range(i_start, i_end + 1): |
| abs_pos = positions[i] |
| h, w = abs_pos // W, abs_pos % W |
| val = new_palette[h, w].item() |
| if val in (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE): |
| raise ScopeBalanceError("Cannot mirror over scope markers") |
| new_palette[h, w] = new_hue |
|
|
| return new_palette |
|
|
    @staticmethod
    def compose(
        palette: torch.Tensor, region_id: int,
        i_start: int, i_end: int,
        metadata: RegionMetadata
    ) -> torch.Tensor:
        """Compose: fuse sequential statements [i_start,i_end] into a single expression.
        Removes intermediate scope boundaries within the range.
        Tokens are kept, internal START/END pairs are removed."""
        H, W = palette.shape
        flat = palette.flatten()

        # Index the FULL region, markers included: unlike most ops, compose
        # addresses raw region cells because the scope markers themselves are
        # the edit target.
        mask = metadata.masks[region_id]
        all_positions = mask.nonzero(as_tuple=False)
        all_flat = sorted((all_positions[:, 0] * W + all_positions[:, 1]).tolist())

        if i_start < 0 or i_end >= len(all_flat) or i_end < i_start:
            raise InvalidPointerError(f"Range [{i_start},{i_end}] invalid")

        # Absolute cells covered by the requested range.
        range_positions = all_flat[i_start:i_end + 1]
        delete_mask = torch.zeros(H * W, dtype=torch.bool)

        # Walk the range tracking nesting depth; only markers at depth > 1 are
        # deleted, so the outermost START/END pair inside the range survives.
        # NOTE(review): a range that opens inside an already-open nested scope
        # drives depth negative at the first END; behavior for such unbalanced
        # ranges is unspecified here — presumably callers pass ranges that
        # begin at a statement boundary. Confirm against callers.
        depth = 0
        for pos in range_positions:
            val = flat[pos].item()
            if val == PaletteEditOps.START_OF_SCOPE:
                depth += 1
                if depth > 1:
                    delete_mask[pos] = True
            elif val == PaletteEditOps.END_OF_SCOPE:
                # Depth is checked BEFORE decrementing so each deleted END
                # pairs with a START that was also deleted at depth > 1.
                if depth > 1:
                    delete_mask[pos] = True
                depth -= 1

        # Compact survivors to the left; pad the tail with NOOP to keep H x W.
        kept = flat[~delete_mask]
        pad = torch.full((H * W - len(kept),), PaletteEditOps.NOOP, dtype=palette.dtype)
        return torch.cat([kept, pad]).view(H, W)
|
|
| |
| |
| |
|
|
| @staticmethod |
| def verify_scope_balance(palette: torch.Tensor) -> bool: |
| """Check START_OF_SCOPE count == END_OF_SCOPE count.""" |
| num_starts = (palette == PaletteEditOps.START_OF_SCOPE).sum().item() |
| num_ends = (palette == PaletteEditOps.END_OF_SCOPE).sum().item() |
| return num_starts == num_ends |
|
|
| @staticmethod |
| def _get_content_positions( |
| palette: torch.Tensor, metadata: RegionMetadata, region_id: int |
| ) -> List[int]: |
| """Get flattened positions of content tokens in region (excluding scope markers).""" |
| H, W = palette.shape |
| mask = metadata.masks[region_id] |
| positions = mask.nonzero(as_tuple=False) |
| flat_positions = (positions[:, 0] * W + positions[:, 1]).tolist() |
|
|
| filtered = [] |
| for pos in flat_positions: |
| h, w = pos // W, pos % W |
| token = palette[h, w].item() |
| if token not in (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE): |
| filtered.append(pos) |
| return sorted(filtered) |
|
|
| @staticmethod |
| def _get_region_positions(mask: torch.Tensor, W: int, palette: torch.Tensor = None) -> List[int]: |
| """Legacy helper β use _get_content_positions instead.""" |
| positions = mask.nonzero(as_tuple=False) |
| flat_positions = (positions[:, 0] * W + positions[:, 1]).tolist() |
| if palette is not None: |
| filtered = [] |
| for pos in flat_positions: |
| h, w = pos // W, pos % W |
| token = palette[h, w].item() |
| if token not in (PaletteEditOps.START_OF_SCOPE, PaletteEditOps.END_OF_SCOPE): |
| filtered.append(pos) |
| return sorted(filtered) |
| return sorted(flat_positions) |
|
|