Spaces:
Running
Running
| from __future__ import annotations | |
| import os | |
| import re | |
| from io import BytesIO | |
| from typing import Any | |
| import torch | |
| from PIL import Image | |
| from transformers import AutoModelForCausalLM, AutoProcessor | |
# --- Service configuration, overridable via environment variables. ---
MODEL_ID = os.getenv("MODEL_ID", "microsoft/Florence-2-base")
# Optional pinned revision (commit hash / tag) for reproducible loads.
MODEL_REVISION = os.getenv("MODEL_REVISION")
# Generation budget: default tokens per request and the hard upper cap.
DEFAULT_MAX_TOKENS = int(os.getenv("DEFAULT_MAX_TOKENS", "64"))
MAX_MAX_TOKENS = int(os.getenv("MAX_MAX_TOKENS", "256"))
# Longest image side allowed before downscaling in _prepare_image.
MAX_IMAGE_SIDE = int(os.getenv("MAX_IMAGE_SIDE", "896"))
# Resized dimensions are floored to a multiple of this value.
RESIZE_MULTIPLE = int(os.getenv("RESIZE_MULTIPLE", "32"))
# Beam count for deterministic beam-search decoding.
NUM_BEAMS = int(os.getenv("NUM_BEAMS", "3"))
# Task prompt used when the caller supplies no (or blank) text.
DEFAULT_PROMPT = os.getenv("DEFAULT_PROMPT", "<CAPTION>")
# Matches the leading Florence-2 task token, e.g. "<CAPTION>".
TASK_TOKEN_PATTERN = re.compile(r"^<[^>\s]+>")
# Lazily-initialized singletons populated by load_model().
_model = None
_processor = None
# CPU-only float32 execution; kept as module state so all call sites agree.
_device = torch.device("cpu")
_dtype = torch.float32
def _prepare_image(image_bytes: bytes) -> Image.Image:
    """Decode raw bytes into an RGB image, downscaling oversized inputs.

    Images whose longer side exceeds ``MAX_IMAGE_SIDE`` are shrunk so that
    side fits the cap (aspect ratio preserved), then each dimension is
    floored to the nearest ``RESIZE_MULTIPLE`` — never below the multiple
    itself — to keep tensor-core friendly shapes.
    """
    image = Image.open(BytesIO(image_bytes)).convert("RGB")
    w, h = image.size
    if max(w, h) <= MAX_IMAGE_SIDE:
        # Already within bounds: return untouched.
        return image
    # One scale factor derived from the longer side keeps aspect ratio.
    scale = MAX_IMAGE_SIDE / max(w, h)
    target_w = max(1, int(w * scale))
    target_h = max(1, int(h * scale))
    if RESIZE_MULTIPLE > 1:
        # Snap each side down to a multiple (x - x % m == (x // m) * m).
        target_w = max(RESIZE_MULTIPLE, target_w - target_w % RESIZE_MULTIPLE)
        target_h = max(RESIZE_MULTIPLE, target_h - target_h % RESIZE_MULTIPLE)
    return image.resize((target_w, target_h), Image.Resampling.LANCZOS)
def load_model() -> tuple[Any, Any]:
    """Lazily instantiate and cache the Florence-2 model and processor.

    Both objects are created at most once per process and reused on every
    subsequent call. Returns the ``(model, processor)`` pair.
    """
    global _model, _processor
    if _model is not None and _processor is not None:
        return _model, _processor
    shared_kwargs: dict[str, Any] = {"trust_remote_code": True}
    if MODEL_REVISION:
        # Pin to a specific revision when one is configured.
        shared_kwargs["revision"] = MODEL_REVISION
    _processor = AutoProcessor.from_pretrained(MODEL_ID, **shared_kwargs)
    loaded = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=_dtype,
        attn_implementation="eager",
        **shared_kwargs,
    )
    _model = loaded.to(_device)
    # Inference-only: disable dropout / training-mode layers.
    _model.eval()
    return _model, _processor
| def _build_prompt(text_input: str | None) -> str: | |
| if text_input is None: | |
| return DEFAULT_PROMPT | |
| prompt = text_input.strip() | |
| if not prompt: | |
| return DEFAULT_PROMPT | |
| if not prompt.startswith("<"): | |
| raise ValueError( | |
| "Invalid prompt in `text`: expected a Florence-2 task token like " | |
| "'<CAPTION>' or '<CAPTION_TO_PHRASE_GROUNDING>phrase'." | |
| ) | |
| return prompt | |
| def _task_token_from_prompt(prompt: str) -> str: | |
| match = TASK_TOKEN_PATTERN.match(prompt) | |
| return match.group(0) if match else DEFAULT_PROMPT | |
def generate_caption(
    image_bytes: bytes,
    text_input: str | None = None,
    max_tokens: int = DEFAULT_MAX_TOKENS,
) -> dict[str, Any]:
    """Run Florence-2 on an image and return the generated text.

    Args:
        image_bytes: Raw encoded image bytes (any format Pillow decodes).
        text_input: Optional Florence-2 task prompt; ``None``/blank falls
            back to ``DEFAULT_PROMPT``.
        max_tokens: Requested generation budget, clamped to
            ``[8, MAX_MAX_TOKENS]``.

    Returns:
        ``{"text": <raw decoded output>}``, plus a ``"parsed"`` key when
        the processor's task-specific post-processing succeeds.

    Raises:
        ValueError: if the prompt is not a valid Florence-2 task string.
    """
    model, processor = load_model()
    prompt = _build_prompt(text_input)
    # Clamp the caller's budget into a sane window.
    safe_max_tokens = min(max(int(max_tokens), 8), MAX_MAX_TOKENS)
    image = _prepare_image(image_bytes)
    try:
        inputs = processor(text=prompt, images=image, return_tensors="pt")
    except AssertionError as exc:
        # Florence-2's remote-code processor asserts on malformed task
        # strings; surface that as a caller error rather than a crash.
        raise ValueError(
            "Invalid Florence-2 task format in `text`. For plain captioning, use only "
            "'<CAPTION>' with no extra words."
        ) from exc
    input_ids = inputs["input_ids"].to(_device)
    pixel_values = inputs["pixel_values"].to(_device, _dtype)
    with torch.inference_mode():
        generated_ids = model.generate(
            input_ids=input_ids,
            pixel_values=pixel_values,
            do_sample=False,
            max_new_tokens=safe_max_tokens,
            num_beams=max(1, NUM_BEAMS),
        )
    # Keep special tokens: Florence-2's post-processor relies on them to
    # recover task-specific structure (e.g. region tokens) from the text.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0].strip()
    parsed = None
    post_process = getattr(processor, "post_process_generation", None)
    if callable(post_process):
        try:
            parsed = post_process(
                generated_text,
                task=_task_token_from_prompt(prompt),
                image_size=(image.width, image.height),
            )
        except Exception:
            # Best-effort: a parsing failure degrades to raw text only.
            parsed = None
    # Fix: compare against None rather than truthiness, so a valid but
    # falsy parse result (e.g. an empty dict) is still returned. The
    # failure path above explicitly encodes "no parse" as None.
    if parsed is not None:
        return {"text": generated_text, "parsed": parsed}
    return {"text": generated_text}