"""
patch_inference.py
──────────────────
Fast patch‑text similarity ranking on top of the existing PaintingCLIP
inference pipeline.
Public API
----------
rank_sentences_for_cell(...)
list_grid_scores(...) # optional diagnostic helper
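
Example (illustrative sketch; the image path is a placeholder)
---------------------------------------------------------------
    hits = rank_sentences_for_cell("path/to/painting.jpg", cell_row=3, cell_col=2)
    print(hits[0]["id"], hits[0]["score"])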
"""
from __future__ import annotations

import math
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Tuple

import torch
import torch.nn.functional as F
from PIL import Image

# Local import: reuse the heavyweight initialiser & sentence metadata
from .inference import _initialize_pipeline  # same package, no circular import

# ════════════════════════════════════════════════════════════════════════════
# Internal helpers
# ════════════════════════════════════════════════════════════════════════════
def _infer_patch_hw(num_patches: int) -> Tuple[int, int]:
"""
Infer the ViT patch grid (H, W) from the flat token count.
Works for square layouts only (ViT‑B/32 β†’ 7Γ—7; ViT‑B/16 β†’ 14Γ—14).
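
    For example, a ViT‑B/32 image yields 49 patch tokens:

    >>> _infer_patch_hw(49)
    (7, 7)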
"""
root = int(math.sqrt(num_patches))
if root * root == num_patches:
return root, root
raise ValueError(f"Unexpected non‑square patch layout: {num_patches}")
@lru_cache(maxsize=8)  # cache a few recent paintings × grid sizes
def _prepare_image(image_path: str | Path, grid_size: Tuple[int, int]) -> torch.Tensor:
    """
    Generate cell embeddings for the entire image.
    Uses ViT patch embeddings directly for efficiency.
    """
    # Load resources from main inference pipeline
    processor, model, _, _, _, device = _initialize_pipeline()
    # Load and process image
    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    # Ensure inputs are on the correct device
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        # 1. Get patch embeddings from the vision model
        vision_out = model.vision_model(**inputs, output_hidden_states=True)
        # Exclude CLS (token‑0), keep patch tokens
        patch_tokens = vision_out.last_hidden_state[:, 1:, :]  # (1, N, 768)
        patch_tokens = model.vision_model.post_layernorm(patch_tokens)  # LayerNorm
        patch_feats = model.visual_projection(patch_tokens)  # (1, N, 512)
        patch_feats = F.normalize(patch_feats.squeeze(0), dim=-1)  # (N, 512)
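        # The projected patch vectors live in the same joint image-text space as
        # the sentence embeddings, so plain dot products later act as cosine
        # similarities between a grid cell and each sentence.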
    # 2. Reshape → (D, H, W) to pool channel‑wise
    num_patches, dim = patch_feats.shape
    H, W = _infer_patch_hw(num_patches)
    patch_grid = (
        patch_feats.view(H, W, dim)
        .permute(2, 0, 1)  # (D, H, W)
        .unsqueeze(0)      # (1, D, H, W) for pooling
    )
    # 3. Adaptive average‑pool down to UI grid resolution
    grid_h, grid_w = grid_size
    # Special case: if grid size matches patch grid, no pooling needed
    if (grid_h, grid_w) == (H, W):
        cell_grid = patch_grid.squeeze(0)  # Just remove batch dimension
    else:
        cell_grid = F.adaptive_avg_pool2d(
            patch_grid, output_size=(grid_h, grid_w)
        ).squeeze(0)
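    # Averaging unit-length patch vectors shrinks their norm, so the pooled
    # cell vectors are re-normalised before being returned.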
    # 4. Flatten → (cells, D) & L2‑normalise
    cell_vecs = cell_grid.permute(1, 2, 0).reshape(-1, dim)  # (g², 512)
    return F.normalize(cell_vecs, dim=-1).to(device)


# ════════════════════════════════════════════════════════════════════════════
# Public API
# ════════════════════════════════════════════════════════════════════════════
def rank_sentences_for_cell(
    image_path: str | Path,
    cell_row: int,
    cell_col: int,
    grid_size: Tuple[int, int] = (7, 7),
    top_k: int = 25,
    filter_topics: List[str] | None = None,
    filter_creators: List[str] | None = None,
) -> List[Dict[str, Any]]:
    """
    Retrieve the *top‑k* sentences whose text embeddings align most strongly
    with a specific grid cell of the painting.

    Parameters
    ----------
    image_path : str | Path
        Path (local or mounted) to the RGB image file.
    cell_row, cell_col : int
        Zero‑indexed row/column of the clicked grid cell.
    grid_size : (int, int), default (7, 7)
        Resolution of the UI grid (7x7 matches the ViT-B/32 patch grid).
    top_k : int, default 25
        How many sentences to return.
    filter_topics : List[str], optional
        Topic codes to restrict the results to.
    filter_creators : List[str], optional
        Creator names to restrict the results to.

    Returns
    -------
    List[dict]
        Each item follows the same schema as `run_inference`, facilitating
        front‑end reuse:
        { "id", "score", "english_original", "work", "rank" }
    """
    # Shared resources
    _proc, _model, sent_mat, sentence_ids, sent_meta, device = _initialize_pipeline()
    sent_mat = F.normalize(sent_mat.to(device), dim=-1)
    # Apply filtering if needed
    if filter_topics or filter_creators:
        from .filtering import get_filtered_sentence_ids

        valid_sentence_ids = get_filtered_sentence_ids(filter_topics, filter_creators)
        # Create mask for valid sentences
        valid_indices = [
            i for i, sid in enumerate(sentence_ids) if sid in valid_sentence_ids
        ]
        if not valid_indices:
            return []
        # Filter embeddings and sentence_ids
        sent_mat = sent_mat[valid_indices]
        sentence_ids = [sentence_ids[i] for i in valid_indices]
    # Validate cell indices
    grid_h, grid_w = grid_size
    if not (0 <= cell_row < grid_h and 0 <= cell_col < grid_w):
        raise ValueError(f"Cell ({cell_row}, {cell_col}) outside grid {grid_size}")
    # Cell feature vector
    cell_idx = cell_row * grid_w + cell_col
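    # _prepare_image flattens the grid row-major, so this index selects the
    # vector for cell (cell_row, cell_col).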
    cell_vecs = _prepare_image(image_path, grid_size)
    cell_vec = cell_vecs[cell_idx]
    # Cosine similarity and ranking
    scores = torch.matmul(sent_mat, cell_vec)
    k = min(top_k, scores.size(0))
    top_scores, top_idx = torch.topk(scores, k)
    # Assemble output
    out: List[Dict[str, Any]] = []
    for rank, (idx, sc) in enumerate(zip(top_idx.tolist(), top_scores.tolist()), 1):
        sid = sentence_ids[idx]
        meta = sent_meta.get(
            sid,
            {"English Original": f"[Sentence data not found for {sid}]"},
        )
        work_id = sid.split("_")[0]
        out.append(
            {
                "id": sid,  # Frontend expects "id", not "sentence_id"
                "score": float(sc),
                "english_original": meta.get("English Original", ""),
                "work": work_id,
                "rank": rank,
            }
        )
    return out


# Optional helper for debugging / heat‑map pre‑computation
def list_grid_scores(
    image_path: str | Path,
    grid_size: Tuple[int, int] = (7, 7),
) -> torch.Tensor:
    """
    Return the full similarity matrix of shape (sentences, cells).
    Primarily for diagnostics or off‑line analysis.
    """
    _p, _m, sent_mat, *_, device = _initialize_pipeline()
    sent_mat = F.normalize(sent_mat.to(device), dim=-1)
    cell_vecs = _prepare_image(image_path, grid_size)
    return sent_mat @ cell_vecs.T  # (S, g²)
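

# Minimal manual smoke test: an illustrative sketch, not part of the public API.
# Because this module uses a package-relative import, run it as a module, e.g.
# `python -m <package>.patch_inference path/to/painting.jpg --row 3 --col 2`,
# where <package> is this module's package and the image path is any local file.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Rank sentences for one grid cell.")
    parser.add_argument("image", help="Path to a local RGB image file")
    parser.add_argument("--row", type=int, default=0, help="Zero-indexed grid row")
    parser.add_argument("--col", type=int, default=0, help="Zero-indexed grid column")
    parser.add_argument("--top-k", type=int, default=5, help="Number of sentences to print")
    args = parser.parse_args()

    for hit in rank_sentences_for_cell(args.image, args.row, args.col, top_k=args.top_k):
        print(f"{hit['rank']:>2}. {hit['score']:.3f}  {hit['english_original'][:80]}")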