// ltmarx / core/detector.ts
// Uploaded by harelcain ("Upload 13 files"), commit 53bf5b7 (verified).
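/**
* High-level watermark detector
*
* Takes one or more Y planes + key + config (or tries every preset via
* autoDetect) → returns a detection result. An optional crop-resilient
* search recovers the tile-grid phase of cropped content.
*/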
import type { WatermarkConfig, DetectionResult, Buffer2D } from './types.js';
import { createBuffer2D, yPlaneToBuffer, dwtForward, extractSubband } from './dwt.js';
import { dctForward8x8, extractBlock, ZIGZAG_ORDER } from './dct.js';
import { dmqimExtractSoft } from './dmqim.js';
import { crcVerify } from './crc.js';
import { BchCodec } from './bch.js';
import { generateDithers, generatePermutation } from './keygen.js';
import { computeTileGrid, recoverTileGrid, getTileOrigin, getTileBlocks, type TileGrid } from './tiling.js';
import { blockAcEnergy, computeMaskingFactors } from './masking.js';
import { bitsToPayload } from './embedder.js';
import { PRESETS } from './presets.js';
import type { PresetName } from './types.js';
/**
 * Detect and extract a watermark from a single Y (luma) plane.
 *
 * Convenience wrapper around {@link detectWatermarkMultiFrame} with a
 * one-element frame list.
 *
 * @param yPlane - Luma samples, row-major, `width * height` bytes.
 * @param width - Frame width in pixels.
 * @param height - Frame height in pixels.
 * @param key - Secret key used to derive dithers and the bit permutation.
 * @param config - Watermark configuration used at embed time.
 * @param options - Optional detection options (e.g. `cropResilient`). When
 *   omitted, only the zero-phase (uncropped) grid is tried, exactly as before.
 * @returns The detection result (`detected: false` when nothing decodes).
 */
export function detectWatermark(
  yPlane: Uint8Array,
  width: number,
  height: number,
  key: string,
  config: WatermarkConfig,
  options?: DetectOptions,
): DetectionResult {
  return detectWatermarkMultiFrame([yPlane], width, height, key, config, options);
}
// ── Per-frame soft-decision extraction ──
// computeFrameDWT + extractSoftBitsFromSubband turn one Y plane into per-tile
// soft-bit vectors (one Float64Array of length config.bch.n per tile).
/**
 * Cached per-frame analysis used by detection: the HL subband taken from the
 * last DWT level, plus the tile period scaled to subband resolution.
 */
interface FrameDWT {
  /** Tile period in subband samples: `floor(config.tilePeriod / 2^dwtLevels)`. */
  subbandTilePeriod: number;
  /** HL subband extracted at the last DWT decomposition level. */
  hlSubband: Buffer2D;
}
/**
 * Run the forward DWT on one Y plane and keep only what detection needs:
 * the HL subband at the last decomposition level and the tile period
 * scaled down by `2^dwtLevels` to subband resolution.
 */
function computeFrameDWT(
  yPlane: Uint8Array,
  width: number,
  height: number,
  config: WatermarkConfig
): FrameDWT {
  const levels = config.dwtLevels;
  const transformed = dwtForward(yPlaneToBuffer(yPlane, width, height), levels);
  const lastDims = transformed.dims[transformed.dims.length - 1];
  const scale = 1 << levels;
  return {
    hlSubband: extractSubband(transformed.buf, lastDims.w, lastDims.h, 'HL'),
    subbandTilePeriod: Math.floor(config.tilePeriod / scale),
  };
}
/**
 * Extract per-tile soft decisions from an HL subband.
 *
 * For each tile in `tileGrid`, every 8×8 block is DCT-transformed once. The
 * configured zigzag coefficients are soft-demodulated with DM-QIM and
 * accumulated into `config.bch.n` bins, which are then averaged.
 *
 * @param hlSubband - Subband to read 8×8 blocks from.
 * @param tileGrid - Tile layout over `hlSubband`.
 * @param key - Secret key; seeds the dither sequence.
 * @param config - Watermark configuration (delta, zigzag positions, BCH n, masking).
 * @param ditherOffX - Block-column offset mapping this grid's tile phase back to
 *   the embedder's (0 = aligned).
 * @param ditherOffY - Block-row offset, as for `ditherOffX`.
 * @param blocksPerSide - Blocks per tile side, needed for offset remapping. When
 *   0, offsets are ignored and blocks are walked sequentially.
 * @returns One averaged soft-bit vector per tile, or `null` if the grid has no tiles.
 */
function extractSoftBitsFromSubband(
  hlSubband: Buffer2D,
  tileGrid: TileGrid,
  key: string,
  config: WatermarkConfig,
  ditherOffX: number = 0,
  ditherOffY: number = 0,
  blocksPerSide: number = 0,
): { tileSoftBits: Float64Array[]; totalTiles: number } | null {
  if (tileGrid.totalTiles === 0) return null;

  const codedLength = config.bch.n;
  const maxCoeffsPerTile = 1024;
  const dithers = generateDithers(key, maxCoeffsPerTile, config.delta);
  const numZig = config.zigzagPositions.length;

  // Zigzag position → row-major index into the 8×8 DCT block.
  const zigCoeffIdx = new Int32Array(numZig);
  for (let z = 0; z < numZig; z++) {
    const [r, c] = ZIGZAG_ORDER[config.zigzagPositions[z]];
    zigCoeffIdx[z] = r * 8 + c;
  }

  // Remapping only applies when an offset is requested AND the tile size in
  // blocks is known; otherwise blocks are walked in sequential order.
  const remap = (ditherOffX !== 0 || ditherOffY !== 0) && blocksPerSide > 0;

  const blockBuf = new Float64Array(64);
  const tileSoftBits: Float64Array[] = [];

  for (let tileIdx = 0; tileIdx < tileGrid.totalTiles; tileIdx++) {
    const origin = getTileOrigin(tileGrid, tileIdx);
    const blocks = getTileBlocks(origin.x, origin.y, tileGrid.tilePeriod, hlSubband.width, hlSubband.height);
    const nBlocks = blocks.length;
    const tileOriginBlockRow = Math.floor(origin.y / 8);
    const tileOriginBlockCol = Math.floor(origin.x / 8);

    // One DCT per block: cache the selected coefficients, plus the AC energy
    // when perceptual masking needs it (previously each block was DCT'd twice).
    const coeffs = new Float64Array(nBlocks * numZig);
    const energies = config.perceptualMasking ? new Float64Array(nBlocks) : null;
    for (let bi = 0; bi < nBlocks; bi++) {
      extractBlock(hlSubband.data, hlSubband.width, blocks[bi].row, blocks[bi].col, blockBuf);
      dctForward8x8(blockBuf);
      for (let z = 0; z < numZig; z++) {
        coeffs[bi * numZig + z] = blockBuf[zigCoeffIdx[z]];
      }
      if (energies) energies[bi] = blockAcEnergy(blockBuf);
    }
    const maskingFactors = energies && nBlocks > 0 ? computeMaskingFactors(energies) : null;

    const softBits = new Float64Array(codedLength);
    const bitCounts = new Float64Array(codedLength);
    let bitIdx = 0;
    for (let bi = 0; bi < nBlocks; bi++) {
      const maskFactor = maskingFactors ? maskingFactors[bi] : 1.0;
      const effectiveDelta = config.delta * maskFactor;

      let ditherBase: number;
      if (remap) {
        // Map this block's position inside the tile back to the embedder's
        // block position; that fixes both its dither slice and its bit slot.
        // Sign-safe modulo, matching the candidate-scoring path.
        const relBr = blocks[bi].row - tileOriginBlockRow;
        const relBc = blocks[bi].col - tileOriginBlockCol;
        const origR = (((relBr + ditherOffY) % blocksPerSide) + blocksPerSide) % blocksPerSide;
        const origC = (((relBc + ditherOffX) % blocksPerSide) + blocksPerSide) % blocksPerSide;
        ditherBase = (origR * blocksPerSide + origC) * numZig;
        bitIdx = ditherBase % codedLength;
      } else {
        // Sequential walk: dither indices restart at 0 for each tile (matching
        // the embedder) and bit slots continue across blocks.
        ditherBase = bi * numZig;
      }

      for (let z = 0; z < numZig; z++) {
        if (bitIdx >= codedLength) bitIdx = 0;
        const soft = dmqimExtractSoft(coeffs[bi * numZig + z], effectiveDelta, dithers[ditherBase + z]);
        softBits[bitIdx] += soft;
        bitCounts[bitIdx]++;
        bitIdx++;
      }
    }

    // Average repeated embeddings of each coded bit.
    for (let i = 0; i < codedLength; i++) {
      if (bitCounts[i] > 0) softBits[i] /= bitCounts[i];
    }
    tileSoftBits.push(softBits);
  }
  return { tileSoftBits, totalTiles: tileGrid.totalTiles };
}
/**
 * Options for crop-resilient detection, accepted by
 * {@link detectWatermarkMultiFrame}, {@link autoDetect} and
 * {@link autoDetectMultiFrame}.
 */
export interface DetectOptions {
  /**
   * When true and the zero-phase (uncropped) attempt fails, run a joint
   * search over DWT padding, subband shift and tile dither offset to recover
   * cropped content. Off when omitted.
   */
  cropResilient?: boolean;
}
/**
 * Detect a watermark from one or more Y planes of identical size.
 *
 * Soft decisions are extracted from each frame independently, then combined
 * across frames and tiles (raw pixels are never averaged).
 *
 * A zero-phase attempt (tile grid anchored at the subband origin) is always
 * made first. If it fails and `options.cropResilient` is true, a joint search
 * is run over:
 *  - DWT padding: 0..2^dwtLevels−1 pixels per axis (16 combos for dwtLevels=2)
 *  - subband shift: 0..7 samples per axis, re-aligning 8×8 DCT blocks (64 combos)
 *  - tile dither offset: 0..N−1 blocks per axis, N = blocks per tile side
 * Candidates are ranked by soft-signal energy on tile 0 of frame 0. The top
 * 50 are then fully decoded with all frames, stopping early at confidence ≥ 0.95.
 *
 * @param yPlanes - Luma planes, each row-major with `width * height` samples.
 * @param width - Frame width in pixels.
 * @param height - Frame height in pixels.
 * @param key - Secret key (dithers and bit permutation).
 * @param config - Watermark configuration used at embed time.
 * @param options - Detection options; see {@link DetectOptions}.
 * @returns The best detection found, or a `detected: false` result.
 */
export function detectWatermarkMultiFrame(
  yPlanes: Uint8Array[],
  width: number,
  height: number,
  key: string,
  config: WatermarkConfig,
  options?: DetectOptions,
): DetectionResult {
  const noResult: DetectionResult = {
    detected: false,
    payload: null,
    confidence: 0,
    tilesDecoded: 0,
    tilesTotal: 0,
  };
  if (yPlanes.length === 0) return noResult;

  const codedLength = config.bch.n;
  const bch = new BchCodec(config.bch);
  const perm = generatePermutation(key, codedLength);

  // A (possibly transformed) subband plus the tile grid and dither remapping
  // to apply to it.
  interface SubbandAndGrid {
    subband: Buffer2D;
    grid: TileGrid;
    ditherOffX?: number;
    ditherOffY?: number;
    blocksPerSide?: number;
  }

  // Extract soft bits from every frame using the layout produced by
  // makeSubbandAndGrid, pool all tiles, and attempt a decode.
  const tryDetect = (
    frames: FrameDWT[],
    makeSubbandAndGrid: (hlSubband: Buffer2D, stp: number) => SubbandAndGrid,
  ): DetectionResult | null => {
    const softBits: Float64Array[] = [];
    for (const { hlSubband, subbandTilePeriod } of frames) {
      const { subband, grid, ditherOffX, ditherOffY, blocksPerSide: bps } = makeSubbandAndGrid(hlSubband, subbandTilePeriod);
      const frameResult = extractSoftBitsFromSubband(subband, grid, key, config, ditherOffX ?? 0, ditherOffY ?? 0, bps ?? 0);
      if (frameResult) softBits.push(...frameResult.tileSoftBits);
    }
    if (softBits.length === 0) return null;
    return decodeFromSoftBits(softBits, codedLength, perm, bch, config);
  };

  // Fast path: zero-phase grid (uncropped frames)
  const frameDWTs = yPlanes.map((yp) => computeFrameDWT(yp, width, height, config));
  const fast = tryDetect(frameDWTs, (hl, stp) => ({
    subband: hl,
    grid: computeTileGrid(hl.width, hl.height, stp),
  }));
  if (fast) return fast;
  if (!options?.cropResilient) return noResult;

  // ── Crop-resilient: joint search over DWT-pad × pixel-shift × dither-offset ──
  //
  // A crop of C pixels causes three alignment problems:
  //   1. DWT pixel pairing: pad by C % 2^dwtLevels → search 0..2^dwtLevels−1 per axis
  //   2. DCT block alignment: subband shift % 8 → search 0..7 per axis (64)
  //   3. Tile dither offset: which block within the tile period does the
  //      detector's block 0 correspond to? Search 0..blocksPerSide−1 per axis.
  //
  // All three must be correct simultaneously for signal to emerge, so they
  // are searched jointly. For each (pad, shift), DCT blocks are computed once
  // per scoring frame; dither offsets are then swept cheaply (DMQIM
  // re-indexing only, no DCT recomputation).
  //
  // Ranking uses frame 0 only; the top candidates are decoded with ALL frames.
  const subbandTilePeriod = Math.floor(config.tilePeriod / (1 << config.dwtLevels));
  const effectiveTP = Math.max(8, Math.floor(subbandTilePeriod / 8) * 8);
  const blocksPerSide = effectiveTP / 8;
  const dwtPads = 1 << config.dwtLevels; // 4 for dwtLevels=2

  // The candidate space is dwtPads² × 64 × blocksPerSide², so scoring is kept
  // to a single frame.
  const nScoringFrames = 1;

  interface Candidate {
    padTop: number;
    padLeft: number;
    shiftX: number;
    shiftY: number;
    ditherOffX: number;
    ditherOffY: number;
    signalMag: number;
  }
  const candidates: Candidate[] = [];

  // DWTs of the scoring frames, cached by pad
  const scoringDWTCache = new Map<string, FrameDWT[]>();
  const getScoringDWTs = (padTop: number, padLeft: number): FrameDWT[] => {
    const cacheKey = `${padTop},${padLeft}`;
    let cached = scoringDWTCache.get(cacheKey);
    if (!cached) {
      cached = [];
      for (let fi = 0; fi < nScoringFrames; fi++) {
        if (padTop === 0 && padLeft === 0) {
          cached.push(frameDWTs[fi]);
        } else {
          const { padded, paddedW, paddedH } = padYPlane(yPlanes[fi], width, height, padLeft, padTop);
          cached.push(computeFrameDWT(padded, paddedW, paddedH, config));
        }
      }
      scoringDWTCache.set(cacheKey, cached);
    }
    return cached;
  };

  // Zigzag position → row-major index into the 8×8 DCT block
  const numZig = config.zigzagPositions.length;
  const zigCoeffIdx = new Int32Array(numZig);
  for (let z = 0; z < numZig; z++) {
    const [r, c] = ZIGZAG_ORDER[config.zigzagPositions[z]];
    zigCoeffIdx[z] = r * 8 + c;
  }
  const scoreDithers = generateDithers(key, 1024, config.delta);
  const blockBuf = new Float64Array(64);

  // Phase 1: score all candidates with DCT caching.
  for (let padTop = 0; padTop < dwtPads; padTop++) {
    for (let padLeft = 0; padLeft < dwtPads; padLeft++) {
      const scoreDWTs = getScoringDWTs(padTop, padLeft);
      for (let shiftY = 0; shiftY < 8; shiftY++) {
        for (let shiftX = 0; shiftX < 8; shiftX++) {
          const hl0 = scoreDWTs[0].hlSubband;
          const newW = hl0.width - shiftX;
          const newH = hl0.height - shiftY;
          if (newW < effectiveTP || newH < effectiveTP) continue;
          const grid = computeTileGrid(newW, newH, subbandTilePeriod);
          if (grid.totalTiles === 0) continue;

          // Score on tile 0 only (fast; sufficient for ranking)
          const tile0Origin = getTileOrigin(grid, 0);
          const tile0Blocks = getTileBlocks(
            tile0Origin.x, tile0Origin.y, grid.tilePeriod, newW, newH
          );
          const nBlocks = tile0Blocks.length;
          if (nBlocks === 0) continue;
          const tile0OriginBR = Math.floor(tile0Origin.y / 8);
          const tile0OriginBC = Math.floor(tile0Origin.x / 8);
          const relBR = new Int32Array(nBlocks);
          const relBC = new Int32Array(nBlocks);
          for (let bi = 0; bi < nBlocks; bi++) {
            relBR[bi] = tile0Blocks[bi].row - tile0OriginBR;
            relBC[bi] = tile0Blocks[bi].col - tile0OriginBC;
          }

          // DCT coefficients + effective deltas per scoring frame, computed once
          const frameCoeffs: Float64Array[] = [];
          const frameDeltas: Float64Array[] = [];
          for (let fi = 0; fi < scoreDWTs.length; fi++) {
            const shifted = createShiftedSubband(scoreDWTs[fi].hlSubband, shiftX, shiftY);
            const coeffs = new Float64Array(nBlocks * numZig);
            const deltas = new Float64Array(nBlocks);
            if (config.perceptualMasking) {
              const energies = new Float64Array(nBlocks);
              for (let bi = 0; bi < nBlocks; bi++) {
                extractBlock(shifted.data, newW, tile0Blocks[bi].row, tile0Blocks[bi].col, blockBuf);
                dctForward8x8(blockBuf);
                energies[bi] = blockAcEnergy(blockBuf);
                for (let z = 0; z < numZig; z++) {
                  coeffs[bi * numZig + z] = blockBuf[zigCoeffIdx[z]];
                }
              }
              const factors = computeMaskingFactors(energies);
              for (let bi = 0; bi < nBlocks; bi++) {
                deltas[bi] = config.delta * factors[bi];
              }
            } else {
              for (let bi = 0; bi < nBlocks; bi++) {
                extractBlock(shifted.data, newW, tile0Blocks[bi].row, tile0Blocks[bi].col, blockBuf);
                dctForward8x8(blockBuf);
                for (let z = 0; z < numZig; z++) {
                  coeffs[bi * numZig + z] = blockBuf[zigCoeffIdx[z]];
                }
                deltas[bi] = config.delta;
              }
            }
            frameCoeffs.push(coeffs);
            frameDeltas.push(deltas);
          }

          // Sweep all dither offsets over the cached coefficients (DMQIM only)
          for (let ditherOffY = 0; ditherOffY < blocksPerSide; ditherOffY++) {
            for (let ditherOffX = 0; ditherOffX < blocksPerSide; ditherOffX++) {
              if (padTop === 0 && padLeft === 0 && shiftX === 0 && shiftY === 0
                && ditherOffX === 0 && ditherOffY === 0) {
                continue; // Already tried in fast path
              }
              const avg = new Float64Array(codedLength);
              let nSamples = 0;
              for (let fi = 0; fi < frameCoeffs.length; fi++) {
                const coeffs = frameCoeffs[fi];
                const deltas = frameDeltas[fi];
                const softBits = new Float64Array(codedLength);
                const bitCounts = new Float64Array(codedLength);
                for (let bi = 0; bi < nBlocks; bi++) {
                  const origR = ((relBR[bi] + ditherOffY) % blocksPerSide + blocksPerSide) % blocksPerSide;
                  const origC = ((relBC[bi] + ditherOffX) % blocksPerSide + blocksPerSide) % blocksPerSide;
                  const blockDitherBase = (origR * blocksPerSide + origC) * numZig;
                  const ed = deltas[bi];
                  // Remap bitIdx to match the embedder's bit assignment at the original position
                  let bitIdx = ((origR * blocksPerSide + origC) * numZig) % codedLength;
                  for (let z = 0; z < numZig; z++) {
                    if (bitIdx >= codedLength) bitIdx = 0;
                    const soft = dmqimExtractSoft(coeffs[bi * numZig + z], ed, scoreDithers[blockDitherBase + z]);
                    softBits[bitIdx] += soft;
                    bitCounts[bitIdx]++;
                    bitIdx++;
                  }
                }
                for (let i = 0; i < codedLength; i++) {
                  if (bitCounts[i] > 0) softBits[i] /= bitCounts[i];
                  avg[i] += softBits[i];
                }
                nSamples++;
              }
              let mag = 0;
              for (let i = 0; i < codedLength; i++) {
                avg[i] /= nSamples;
                mag += avg[i] * avg[i];
              }
              // A NaN score would make the sort comparator inconsistent; drop it.
              if (Number.isFinite(mag)) {
                candidates.push({ padTop, padLeft, shiftX, shiftY, ditherOffX, ditherOffY, signalMag: mag });
              }
            }
          }
        }
      }
    }
  }

  // Phase 2: decode the top-ranked candidates with all frames.
  candidates.sort((a, b) => b.signalMag - a.signalMag);
  const MAX_DECODE = 50;

  // Full-frame DWTs per pad, reused across candidates that share a pad instead
  // of re-padding and re-transforming every frame for every candidate. Bounded
  // to keep memory proportional to a few pads × nFrames.
  const MAX_CACHED_PADS = 4;
  const decodeDWTCache = new Map<string, FrameDWT[]>();
  const getDecodeDWTs = (padTop: number, padLeft: number): FrameDWT[] => {
    if (padTop === 0 && padLeft === 0) return frameDWTs;
    const cacheKey = `${padTop},${padLeft}`;
    const hit = decodeDWTCache.get(cacheKey);
    if (hit) return hit;
    const dwts = yPlanes.map((yp) => {
      const { padded, paddedW, paddedH } = padYPlane(yp, width, height, padLeft, padTop);
      return computeFrameDWT(padded, paddedW, paddedH, config);
    });
    if (decodeDWTCache.size >= MAX_CACHED_PADS) {
      // Evict the oldest entry (Map iterates in insertion order).
      const oldestKey = decodeDWTCache.keys().next().value;
      if (oldestKey !== undefined) decodeDWTCache.delete(oldestKey);
    }
    decodeDWTCache.set(cacheKey, dwts);
    return dwts;
  };

  let bestResult: DetectionResult | null = null;
  for (let i = 0; i < Math.min(MAX_DECODE, candidates.length); i++) {
    const { padTop, padLeft, shiftX, shiftY, ditherOffX, ditherOffY } = candidates[i];
    const result = tryDetect(getDecodeDWTs(padTop, padLeft), (hl) => {
      const shifted = createShiftedSubband(hl, shiftX, shiftY);
      const grid = computeTileGrid(shifted.width, shifted.height, subbandTilePeriod);
      return { subband: shifted, grid, ditherOffX, ditherOffY, blocksPerSide };
    });
    if (result && (!bestResult || result.confidence > bestResult.confidence)) {
      bestResult = result;
    }
    if (bestResult && bestResult.confidence >= 0.95) break;
  }
  return bestResult ?? noResult;
}
/**
 * Prepend `padTop` rows and `padLeft` columns to a row-major Y plane. The new
 * rows copy source row 0 and the new columns copy each row's first pixel.
 * Used to re-align the DWT's pixel pairing after a crop.
 *
 * @returns The padded plane with dimensions `width + padLeft` × `height + padTop`.
 */
function padYPlane(
  yPlane: Uint8Array,
  width: number,
  height: number,
  padLeft: number,
  padTop: number,
): { padded: Uint8Array; paddedW: number; paddedH: number } {
  const paddedW = width + padLeft;
  const paddedH = height + padTop;
  const padded = new Uint8Array(paddedW * paddedH);
  for (let row = 0; row < paddedH; row++) {
    // Rows inside the top margin all read from source row 0.
    const srcBase = (row < padTop ? 0 : row - padTop) * width;
    const dstBase = row * paddedW;
    for (let col = 0; col < paddedW; col++) {
      // Columns inside the left margin all read from source column 0.
      padded[dstBase + col] = yPlane[srcBase + (col < padLeft ? 0 : col - padLeft)];
    }
  }
  return { padded, paddedW, paddedH };
}
/**
 * Copy the window of `hl` that starts at column `shiftX`, row `shiftY` and
 * runs to the bottom-right edge. This is a new buffer, not a view. Used to
 * re-align 8×8 DCT blocks after a crop.
 */
function createShiftedSubband(hl: Buffer2D, shiftX: number, shiftY: number): Buffer2D {
  const outW = hl.width - shiftX;
  const outH = hl.height - shiftY;
  const out = createBuffer2D(outW, outH);
  for (let row = 0; row < outH; row++) {
    const srcRow = (row + shiftY) * hl.width + shiftX;
    const dstRow = row * outW;
    for (let col = 0; col < outW; col++) {
      out.data[dstRow + col] = hl.data[srcRow + col];
    }
  }
  return out;
}
/**
 * Average soft bits over all tiles, decode the result, and score the detection.
 *
 * Confidence is the larger of two scores:
 *  - a statistical score: how many individual tiles agree with the re-encoded
 *    codeword on more than 65% of positions, compared with the number
 *    expected by chance;
 *  - the soft correlation reported by `tryDecode`.
 *
 * @param allTileSoftBits - One soft-bit vector (length `codedLength`) per tile,
 *   in interleaved order.
 * @param codedLength - BCH codeword length n.
 * @param perm - Interleaver permutation (codeword bit `i` is at `perm[i]`).
 * @returns The detection, or `null` when there are no tiles, decoding fails,
 *   or confidence is below `MIN_CONFIDENCE`.
 */
function decodeFromSoftBits(
  allTileSoftBits: Float64Array[],
  codedLength: number,
  perm: Uint32Array,
  bch: BchCodec,
  config: WatermarkConfig
): DetectionResult | null {
  const totalTileCount = allTileSoftBits.length;
  // Nothing to average (and would divide by zero below).
  if (totalTileCount === 0) return null;

  const combined = new Float64Array(codedLength);
  for (const tileSoft of allTileSoftBits) {
    for (let i = 0; i < codedLength; i++) {
      combined[i] += tileSoft[i];
    }
  }
  for (let i = 0; i < codedLength; i++) {
    combined[i] /= totalTileCount;
  }

  const decoded = tryDecode(combined, perm, bch, config);
  if (!decoded) return null;

  // Cross-validate: count tiles whose own hard decisions (de-interleaved via
  // perm) agree with the re-encoded codeword on > 65% of positions.
  const reEncoded = bch.encode(decoded.rawMessage);
  let agreeTiles = 0;
  for (const tileSoft of allTileSoftBits) {
    let matching = 0;
    for (let i = 0; i < codedLength; i++) {
      const hardBit = tileSoft[perm[i]] > 0 ? 1 : 0;
      if (hardBit === reEncoded[i]) matching++;
    }
    if (matching / codedLength > 0.65) agreeTiles++;
  }

  // Chance model: probability that a noise tile clears the 65% bar.
  // Presumably 0.3·√n is the z-score of a 0.15 excess over a per-bit std of
  // 0.5/√n, bounded by 0.5·exp(−z²/2). NOTE(review): confirm intent.
  const zSingle = 0.3 * Math.sqrt(codedLength);
  const pChance = Math.max(1e-10, 0.5 * Math.exp(-0.5 * zSingle * zSingle));
  const expected = totalTileCount * pChance;
  const stddev = Math.sqrt(totalTileCount * pChance * (1 - pChance));
  const z = stddev > 0 ? (agreeTiles - expected) / stddev : agreeTiles > expected ? 100 : 0;
  const statsConfidence = Math.max(0, Math.min(1.0, 1 - Math.exp(-z * 0.5)));
  const confidence = Math.max(statsConfidence, decoded.softConfidence);
  if (confidence < MIN_CONFIDENCE) return null;
  return {
    detected: true,
    payload: decoded.payload,
    confidence,
    tilesDecoded: agreeTiles,
    tilesTotal: totalTileCount,
  };
}
/**
 * Minimum confidence (in [0, 1]) `decodeFromSoftBits` requires before it
 * reports a detection. The statistical model is intended to keep pure-noise
 * inputs scoring near 0, so a moderate threshold is sufficient.
 * Declared after `decodeFromSoftBits`; this is safe because it is only read
 * when that function runs, after module initialisation.
 */
const MIN_CONFIDENCE = 0.75;
/**
 * De-interleave averaged soft bits, BCH soft-decode them, verify the CRC, and
 * score how well the soft values correlate with the decoded codeword.
 *
 * @param softBits - Averaged soft decisions in interleaved order,
 *   length `config.bch.n`.
 * @param perm - Interleaver permutation; codeword bit `i` is read from
 *   `softBits[perm[i]]`.
 * @param bch - BCH codec for `config.bch`.
 * @param config - Watermark configuration (BCH n, CRC width).
 * @returns Payload bytes, the raw BCH message and a soft confidence clamped to
 *   [0, 1], or `null` if BCH decoding or the CRC check fails.
 */
function tryDecode(
  softBits: Float64Array,
  perm: Uint32Array,
  bch: BchCodec,
  config: WatermarkConfig
): { payload: Uint8Array; rawMessage: Uint8Array; softConfidence: number } | null {
  const codedLength = config.bch.n;

  // De-interleave
  const deinterleaved = new Float64Array(codedLength);
  for (let i = 0; i < codedLength; i++) {
    deinterleaved[i] = softBits[perm[i]];
  }

  // BCH soft decode. decodeSoft also returns `reliable`, which is not
  // consulted here. NOTE(review): confirm the CRC check alone is the intended gate.
  const { message } = bch.decodeSoft(deinterleaved);
  if (!message) return null;

  // The first PAYLOAD_BITS + crc.bits message bits are the CRC-protected payload.
  const PAYLOAD_BITS = 32;
  const crcProtectedLen = PAYLOAD_BITS + config.crc.bits;
  const verified = crcVerify(message.subarray(0, crcProtectedLen), config.crc.bits);
  if (!verified) return null;

  // Convert the verified payload bits to bytes
  const payload = bitsToPayload(verified);

  // Soft confidence: mean over positions of (±1 codeword sign × soft value).
  // A real signal has large soft values that agree with the codeword; noise
  // that BCH happened to decode has small, random-signed values.
  const reEncoded = bch.encode(message);
  let correlation = 0;
  for (let i = 0; i < codedLength; i++) {
    const sign = reEncoded[i] === 1 ? 1 : -1;
    correlation += sign * deinterleaved[i];
  }
  correlation /= codedLength;
  const softConfidence = Math.max(0, Math.min(1.0, correlation * 2));
  return { payload, rawMessage: message, softConfidence };
}
/**
 * Detection result returned by {@link autoDetect} / {@link autoDetectMultiFrame},
 * extended with the preset whose configuration produced it.
 */
export interface AutoDetectResult extends DetectionResult {
  /** The preset that produced the detection (`null` when nothing was detected). */
  presetUsed: PresetName | null;
}
/**
 * Single-frame auto-detection. Tries every preset in `PRESETS` and returns
 * the most confident detection, so the caller does not need to know which
 * preset was used at embed time.
 *
 * Equivalent to {@link autoDetectMultiFrame} with a one-element frame list.
 */
export function autoDetect(
  yPlane: Uint8Array,
  width: number,
  height: number,
  key: string,
  options?: DetectOptions,
): AutoDetectResult {
  const frames = [yPlane];
  return autoDetectMultiFrame(frames, width, height, key, options);
}
/**
 * Multi-frame auto-detection: run detection with every preset in `PRESETS`
 * and keep the most confident positive result.
 *
 * Detection confidence is clamped to at most 1.0 by the decoder, so once a
 * preset reaches 1.0 no later preset can strictly beat it. The remaining
 * searches, which may be expensive crop-resilient ones, are then skipped; the
 * returned result is unchanged.
 *
 * @returns The best result tagged with `presetUsed`, or a `detected: false`
 *   result with `presetUsed: null`.
 */
export function autoDetectMultiFrame(
  yPlanes: Uint8Array[],
  width: number,
  height: number,
  key: string,
  options?: DetectOptions,
): AutoDetectResult {
  let best: AutoDetectResult = {
    detected: false,
    payload: null,
    confidence: 0,
    tilesDecoded: 0,
    tilesTotal: 0,
    presetUsed: null,
  };
  for (const [name, config] of Object.entries(PRESETS)) {
    const result = detectWatermarkMultiFrame(yPlanes, width, height, key, config, options);
    if (result.detected && result.confidence > best.confidence) {
      best = { ...result, presetUsed: name as PresetName };
      // Maximum confidence reached: later presets cannot strictly exceed it.
      if (best.confidence >= 1) break;
    }
  }
  return best;
}