# Image2Caption — src/model.py
# Florence-2 based image-captioning helpers: lazy model/processor loading,
# image preprocessing, and caption generation.
from __future__ import annotations
import os
import re
from io import BytesIO
from typing import Any
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
MODEL_ID = os.getenv("MODEL_ID", "microsoft/Florence-2-base")
MODEL_REVISION = os.getenv("MODEL_REVISION")
DEFAULT_MAX_TOKENS = int(os.getenv("DEFAULT_MAX_TOKENS", "64"))
MAX_MAX_TOKENS = int(os.getenv("MAX_MAX_TOKENS", "256"))
MAX_IMAGE_SIDE = int(os.getenv("MAX_IMAGE_SIDE", "896"))
RESIZE_MULTIPLE = int(os.getenv("RESIZE_MULTIPLE", "32"))
NUM_BEAMS = int(os.getenv("NUM_BEAMS", "3"))
DEFAULT_PROMPT = os.getenv("DEFAULT_PROMPT", "<CAPTION>")
TASK_TOKEN_PATTERN = re.compile(r"^<[^>\s]+>")
_model = None
_processor = None
_device = torch.device("cpu")
_dtype = torch.float32
def _prepare_image(image_bytes: bytes) -> Image.Image:
    """Decode raw image bytes to RGB, downscaling so no side exceeds MAX_IMAGE_SIDE.

    Aspect ratio is preserved. When a resize happens, both dimensions are
    floored to a multiple of RESIZE_MULTIPLE (tensor-core friendly shapes),
    but never below one multiple.
    """
    image = Image.open(BytesIO(image_bytes)).convert("RGB")
    width, height = image.size
    longest = max(width, height)
    if longest <= MAX_IMAGE_SIDE:
        # Already within bounds — no resize needed.
        return image
    # Scale the longest side down to the cap; the other side follows the ratio.
    scale = MAX_IMAGE_SIDE / longest
    target_w = max(1, int(width * scale))
    target_h = max(1, int(height * scale))
    if RESIZE_MULTIPLE > 1:
        target_w = max(RESIZE_MULTIPLE, (target_w // RESIZE_MULTIPLE) * RESIZE_MULTIPLE)
        target_h = max(RESIZE_MULTIPLE, (target_h // RESIZE_MULTIPLE) * RESIZE_MULTIPLE)
    return image.resize((target_w, target_h), Image.Resampling.LANCZOS)
def load_model() -> tuple[Any, Any]:
    """Lazily load and cache the Florence-2 model and processor.

    Both objects are created once per process and reused on every later call.
    The model is moved to the module-level device and put into eval mode.
    """
    global _model, _processor
    if _model is not None and _processor is not None:
        return _model, _processor
    kwargs: dict[str, Any] = {"trust_remote_code": True}
    if MODEL_REVISION:
        # Pin to a specific revision when configured via env.
        kwargs["revision"] = MODEL_REVISION
    _processor = AutoProcessor.from_pretrained(MODEL_ID, **kwargs)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=_dtype,
        attn_implementation="eager",
        **kwargs,
    )
    _model = model.to(_device)
    _model.eval()
    return _model, _processor
def _build_prompt(text_input: str | None) -> str:
if text_input is None:
return DEFAULT_PROMPT
prompt = text_input.strip()
if not prompt:
return DEFAULT_PROMPT
if not prompt.startswith("<"):
raise ValueError(
"Invalid prompt in `text`: expected a Florence-2 task token like "
"'<CAPTION>' or '<CAPTION_TO_PHRASE_GROUNDING>phrase'."
)
return prompt
def _task_token_from_prompt(prompt: str) -> str:
    """Return the leading '<...>' task token of `prompt`, or DEFAULT_PROMPT if none."""
    found = TASK_TOKEN_PATTERN.match(prompt)
    if found is None:
        return DEFAULT_PROMPT
    return found.group(0)
def generate_caption(
    image_bytes: bytes,
    text_input: str | None = None,
    max_tokens: int = DEFAULT_MAX_TOKENS,
) -> dict[str, Any]:
    """Run Florence-2 captioning on raw image bytes.

    Args:
        image_bytes: Encoded image bytes (any format PIL can decode).
        text_input: Optional Florence-2 task prompt; defaults to DEFAULT_PROMPT.
        max_tokens: Requested generation budget; clamped to [8, MAX_MAX_TOKENS].

    Returns:
        {"text": raw decoded output} and, when task post-processing succeeds,
        an additional "parsed" key with the structured result.

    Raises:
        ValueError: if `text_input` is not a valid Florence-2 task prompt, or
            the processor rejects the task format.
    """
    model, processor = load_model()
    prompt = _build_prompt(text_input)
    # Clamp the budget: at least 8 new tokens, at most MAX_MAX_TOKENS.
    safe_max_tokens = min(max(int(max_tokens), 8), MAX_MAX_TOKENS)
    image = _prepare_image(image_bytes)
    try:
        inputs = processor(text=prompt, images=image, return_tensors="pt")
    except AssertionError as exc:
        # Florence-2's processor asserts on malformed task strings; surface a
        # caller-friendly error instead.
        raise ValueError(
            "Invalid Florence-2 task format in `text`. For plain captioning, use only "
            "'<CAPTION>' with no extra words."
        ) from exc
    input_ids = inputs["input_ids"].to(_device)
    pixel_values = inputs["pixel_values"].to(_device, _dtype)
    with torch.inference_mode():
        generated_ids = model.generate(
            input_ids=input_ids,
            pixel_values=pixel_values,
            do_sample=False,  # deterministic decoding
            max_new_tokens=safe_max_tokens,
            num_beams=max(1, NUM_BEAMS),
        )
    # Keep special tokens: Florence-2 post-processing needs the task markers.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0].strip()
    parsed = None
    post_process = getattr(processor, "post_process_generation", None)
    if callable(post_process):
        try:
            parsed = post_process(
                generated_text,
                task=_task_token_from_prompt(prompt),
                image_size=(image.width, image.height),
            )
        except Exception:
            # Best-effort: parsing is optional; fall back to the raw text.
            parsed = None
    # BUGFIX: test `is not None` rather than truthiness so a falsy-but-valid
    # parse result (e.g. an empty dict) is still returned to the caller.
    if parsed is not None:
        return {"text": generated_text, "parsed": parsed}
    return {"text": generated_text}