| """execute_crops node — Gemini code_execution for agentic cropping (PoC 1 style)."""
|
| from __future__ import annotations
|
|
|
| import io
|
| import logging
|
| import time
|
| import uuid
|
| from collections.abc import Callable
|
| from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
|
| from google import genai
|
| from google.genai import types
|
| from PIL import Image
|
|
|
| from config import CROPPER_MODEL, GOOGLE_API_KEY
|
| from prompts.cropper import CROPPER_PROMPT_TEMPLATE
|
| from state import CropTask, DrawingReaderState, ImageRef
|
| from tools.crop_cache import CropCache
|
| from tools.image_store import ImageStore
|
| from tools.pdf_processor import get_page_image_bytes
|
|
|
| logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
# Progress hook signature: (image_ref, crop_task, source, completed_count,
# total_count). Invoked on the main thread as each crop completes or is
# served from cache; *source* is "cached", "completed", or "fallback".
ProgressCallback = Callable[[ImageRef, CropTask, str, int, int], None]


# Transient API errors (503/429/UNAVAILABLE) are retried up to MAX_RETRIES
# times with exponential backoff starting at RETRY_BASE_DELAY seconds
# (2s, 4s, 8s, ...).
MAX_RETRIES = 3
RETRY_BASE_DELAY = 2.0
|
|
|
|
|
| def _extract_last_image(response) -> Image.Image | None:
|
| """Extract the last generated image from a Gemini code_execution response."""
|
| last_image = None
|
| for part in response.candidates[0].content.parts:
|
|
|
| try:
|
| img_data = part.as_image()
|
| if img_data is not None:
|
| last_image = Image.open(io.BytesIO(img_data.image_bytes))
|
| continue
|
| except Exception:
|
| pass
|
|
|
| try:
|
| if hasattr(part, "inline_data") and part.inline_data is not None:
|
| img_bytes = part.inline_data.data
|
| last_image = Image.open(io.BytesIO(img_bytes))
|
| except Exception:
|
| pass
|
| return last_image
|
|
|
|
|
def _execute_single_crop_sync(
    client: genai.Client,
    page_image_bytes: bytes,
    crop_task: CropTask,
    image_store: ImageStore,
) -> tuple[ImageRef, bool]:
    """Execute one crop via Gemini code_execution (synchronous).

    Includes retry logic for transient 503/429 errors.

    Returns
    -------
    (image_ref, is_fallback)
        ``is_fallback`` is True when Gemini failed to produce a crop and the
        full page image was returned instead. Fallbacks should NOT be cached.
    """
    prompt = CROPPER_PROMPT_TEMPLATE.format(
        crop_instruction=crop_task["crop_instruction"],
    )

    image_part = types.Part.from_bytes(data=page_image_bytes, mime_type="image/png")

    response = None
    for attempt in range(MAX_RETRIES):
        try:
            response = client.models.generate_content(
                model=CROPPER_MODEL,
                contents=[image_part, prompt],
                config=types.GenerateContentConfig(
                    tools=[types.Tool(code_execution=types.ToolCodeExecution)]
                ),
            )
            break
        except Exception as e:
            err_str = str(e)
            # Only transient capacity/rate errors are retried; anything else
            # propagates to the caller (and is reported by execute_crops).
            if not ("503" in err_str or "429" in err_str or "UNAVAILABLE" in err_str):
                raise
            if attempt < MAX_RETRIES - 1:
                # Exponential backoff: RETRY_BASE_DELAY * 2^attempt seconds.
                delay = RETRY_BASE_DELAY * (2 ** attempt)
                logger.warning(
                    "Crop API error (attempt %d/%d): %s - retrying in %.1fs",
                    attempt + 1, MAX_RETRIES, err_str[:120], delay,
                )
                time.sleep(delay)
            else:
                # Last attempt: do not sleep pointlessly before giving up;
                # the fallback path below returns the full page instead.
                logger.warning(
                    "Crop API error (attempt %d/%d): %s - giving up",
                    attempt + 1, MAX_RETRIES, err_str[:120],
                )

    # Fall back to the full page image when Gemini produced nothing usable.
    # Fallbacks are flagged so callers know not to cache them.
    is_fallback = True
    if response is not None:
        final_image = _extract_last_image(response)
        if final_image is not None:
            is_fallback = False
        else:
            final_image = Image.open(io.BytesIO(page_image_bytes))
    else:
        final_image = Image.open(io.BytesIO(page_image_bytes))

    # Short random id is enough: crops are namespaced by page in the store.
    crop_id = f"crop_{uuid.uuid4().hex[:6]}"
    ref = image_store.save_crop(
        page_num=crop_task["page_num"],
        crop_id=crop_id,
        image=final_image,
        label=crop_task["label"],
    )
    return ref, is_fallback
|
|
|
|
|
def execute_crops(
    state: DrawingReaderState,
    image_store: ImageStore,
    crop_cache: CropCache | None = None,
    progress_callback: ProgressCallback | None = None,
) -> dict:
    """Execute all crop tasks concurrently, reusing cached crops when possible.

    Reads ``crop_tasks`` and ``page_image_dir`` from *state*. Cache hits are
    served synchronously first; remaining tasks fan out to a small thread
    pool (max 4 workers). Returns ``{"image_refs": [...], "status_message":
    [...]}`` (status only when there are no tasks).

    Parameters
    ----------
    progress_callback
        Optional callback invoked on the **main thread** each time a crop
        completes (or is served from cache). Called with
        ``(image_ref, crop_task, source, completed_count, total_count)``
        where *source* is ``"cached"``, ``"completed"``, or ``"fallback"``.
    """
    crop_tasks = state.get("crop_tasks", [])
    page_image_dir = state["page_image_dir"]

    if not crop_tasks:
        return {"status_message": ["No crop tasks to execute."]}

    total_count = len(crop_tasks)
    completed_count = 0

    # NOTE(review): refs are appended in cache-hit order, then thread
    # completion order -- not necessarily crop_tasks order. Confirm that
    # downstream consumers do not rely on positional alignment with
    # crop_tasks.
    image_refs: list[ImageRef] = []
    tasks_to_execute: list[tuple[int, CropTask]] = []
    cache_hits = 0

    # Phase 1: serve what we can from the cache, on the main thread.
    for i, ct in enumerate(crop_tasks):
        if crop_cache is not None:
            cached_ref = crop_cache.lookup(ct["page_num"], ct["crop_instruction"])
            if cached_ref is not None:
                image_refs.append(cached_ref)
                cache_hits += 1
                completed_count += 1
                logger.info(
                    "Reusing cached crop for '%s' (page %d)",
                    ct["label"], ct["page_num"],
                )

                if progress_callback is not None:
                    progress_callback(
                        cached_ref, ct, "cached", completed_count, total_count,
                    )
                continue

        tasks_to_execute.append((i, ct))

    errors: list[str] = []

    # Phase 2: fan the remaining tasks out to a thread pool.
    if tasks_to_execute:
        client = genai.Client(api_key=GOOGLE_API_KEY)

        with ThreadPoolExecutor(max_workers=min(len(tasks_to_execute), 4)) as pool:
            future_to_idx: dict = {}
            for exec_idx, (_, ct) in enumerate(tasks_to_execute):
                # Page bytes are loaded on the main thread, before submission.
                page_bytes = get_page_image_bytes(page_image_dir, ct["page_num"])
                future = pool.submit(
                    _execute_single_crop_sync, client, page_bytes, ct, image_store,
                )
                future_to_idx[future] = exec_idx

            # Collect results (and fire the progress callback) on the main
            # thread as futures finish.
            for future in as_completed(future_to_idx):
                exec_idx = future_to_idx[future]
                orig_idx, ct = tasks_to_execute[exec_idx]
                try:
                    ref, is_fallback = future.result()
                    image_refs.append(ref)
                    completed_count += 1

                    # Register with the cache; is_fallback tells the cache
                    # this was a full-page stand-in, not a real crop.
                    if crop_cache is not None:
                        crop_cache.register(
                            page_num=ct["page_num"],
                            crop_instruction=ct["crop_instruction"],
                            label=ct["label"],
                            image_ref=ref,
                            is_fallback=is_fallback,
                        )

                    if progress_callback is not None:
                        source = "fallback" if is_fallback else "completed"
                        progress_callback(
                            ref, ct, source, completed_count, total_count,
                        )

                except Exception as e:
                    # Failed tasks still advance the counter but yield no
                    # image_ref and no progress callback.
                    completed_count += 1
                    errors.append(f"Crop task {orig_idx} failed: {e}")
                    logger.error("Crop task %d failed: %s", orig_idx, e)

    # Build a human-readable status summary for the graph state.
    api_count = len(tasks_to_execute) - len(errors)
    parts = [f"Completed {len(image_refs)} of {total_count} crops"]
    if cache_hits:
        parts.append(f"({cache_hits} from cache, {api_count} new)")
    if errors:
        parts.append(f"Errors: {'; '.join(errors)}")
    status = ". ".join(parts) + "."

    if crop_cache is not None:
        logger.info(crop_cache.stats)

    return {
        "image_refs": image_refs,
        "status_message": [status],
    }
|
|
|