| """Processor for Yasa2 that unifies text + media preprocessing.""" |
|
|
| from __future__ import annotations |
|
|
| import urllib.request |
| from enum import Enum |
| from typing import Any, Dict, List, Literal, Optional, Tuple, Union |
|
|
| import numpy as np |
| import torch |
| from PIL import Image |
| from transformers import AutoTokenizer, ProcessorMixin |
| from transformers.processing_utils import MultiModalData |
|
|
| from .image_processing_yasa2 import ( |
| Yasa2ImageProcessor, |
| estimate_num_tiles_llava_next, |
| estimate_num_tiles_llava_uhd, |
| image_rgb_decoder_pil, |
| image_rgb_decoder_pil_tiling, |
| process_anyres_image, |
| process_anyres_image_uhd, |
| ) |
| from .video_processing_yasa2 import ( |
| Yasa2VideoProcessor, |
| video_rgb_decoder_factory, |
| ) |
|
|
|
|
class MediaType(str, Enum):
    """Kinds of media payloads that can be attached to a chat prompt."""

    IMAGE = "image"
    VIDEO = "video"
|
|
|
|
# Placeholder token repeated once per image/video content slot in the prompt.
REKA_IMG_TOKEN = "<REKA_IMG_TOKEN>"
# Markers delimiting image/video placeholder runs in the rendered prompt text.
IMAGE_START = "<image>"
IMAGE_END = "</image>"
VIDEO_START = "<video>"
VIDEO_END = "</video>"
SEP_TOKEN = "<sep>"

# Token id used for left-padding when a fixed-length batch is requested.
PAD_ID = 100257
|
|
|
|
| def _read_bytes_from_uri(uri: str) -> bytes: |
| """Read bytes from a local path or HTTP(S) URL. |
| |
| Args: |
| uri: Local file path or HTTP(S) URL. |
| |
| Returns: |
| Raw bytes content. |
| """ |
| if uri.startswith("http://") or uri.startswith("https://"): |
| with urllib.request.urlopen(uri) as response: |
| return response.read() |
| with open(uri, "rb") as f: |
| return f.read() |
|
|
|
|
def _decode_image_payload(
    payload: Union[str, bytes],
    img_tiling: bool,
    tiling_method: str,
    tiling_size: int,
    grid_pinpoints: List[Tuple[int, int]],
    max_tiles_num: int,
    patch_size: int,
) -> Dict[str, Any]:
    """Decode an image path/URL or raw bytes into a normalized pixel dict.

    Args:
        payload: Image path/URL or raw bytes.
        img_tiling: Whether to enable tiling.
        tiling_method: Tiling method identifier.
        tiling_size: Base tile size.
        grid_pinpoints: Candidate grid pinpoints.
        max_tiles_num: Maximum tile count for UHD tiling.
        patch_size: Patch size for UHD tiling.

    Returns:
        Dict with decoded image data and tiling metadata.
    """
    # Strings are treated as locations to fetch; bytes are used directly.
    raw = payload if not isinstance(payload, str) else _read_bytes_from_uri(payload)
    if not img_tiling:
        return image_rgb_decoder_pil(raw)
    return image_rgb_decoder_pil_tiling(
        raw,
        size=tiling_size,
        grid_pinpoints=grid_pinpoints,
        max_tiles_num=max_tiles_num,
        patch_size=patch_size,
        tiling_method=tiling_method,
    )
|
|
|
|
def _decode_video_payload(
    payload: Union[str, bytes],
    num_frames: int,
    sampling: str,
) -> Dict[str, Any]:
    """Decode a video path/URL or raw bytes into sampled frames.

    Args:
        payload: Video path/URL or raw bytes.
        num_frames: Number of frames to sample.
        sampling: Sampling strategy.

    Returns:
        Dict with sampled frames and metadata.
    """
    # Strings are locations; fetch them first so the decoder only sees bytes.
    if isinstance(payload, str):
        payload = _read_bytes_from_uri(payload)
    decode = video_rgb_decoder_factory(num_frames=num_frames, sampling=sampling)
    return decode(payload)
|
|
|
|
class Yasa2Processor(ProcessorMixin):
    """Processor that applies the Yasa2 dialog formatting and media decoding."""


    # Sub-processor attribute names that ProcessorMixin saves/loads with
    # this processor, and the auto classes used to resolve each of them.
    attributes = ["tokenizer", "image_processor", "video_processor"]
    tokenizer_class = "AutoTokenizer"
    image_processor_class = "AutoImageProcessor"
    video_processor_class = "AutoVideoProcessor"
|
| def __init__( |
| self, |
| tokenizer: AutoTokenizer | None = None, |
| image_processor: Yasa2ImageProcessor | None = None, |
| video_processor: Yasa2VideoProcessor | None = None, |
| num_img_tokens: int = 64, |
| image_token_id: int = 100278, |
| num_video_frames: int = 6, |
| video_sampling: str = "chunk", |
| max_tokens: int = 8192, |
| **kwargs, |
| ) -> None: |
| """Initialize the processor with tokenizer and media processors. |
| |
| Args: |
| tokenizer: Tokenizer for text encoding. |
| image_processor: Image processor for ConvNeXt inputs. |
| video_processor: Video processor for sampled frames. |
| num_img_tokens: Number of image content tokens per image. |
| image_token_id: Token ID for image content tokens. |
| num_video_frames: Number of frames to sample per video. |
| video_sampling: Video sampling strategy. |
| max_tokens: Maximum text token budget. |
| **kwargs: Passed to ProcessorMixin. |
| """ |
| if image_processor is None: |
| image_processor = Yasa2ImageProcessor() |
| if video_processor is None: |
| video_processor = Yasa2VideoProcessor( |
| num_frames=num_video_frames, |
| frame_sample_mode=video_sampling, |
| max_num_frames=num_video_frames, |
| ) |
| super().__init__( |
| tokenizer=tokenizer, |
| image_processor=image_processor, |
| video_processor=video_processor, |
| **kwargs, |
| ) |
| self.num_img_tokens = num_img_tokens |
| self.num_video_frames = num_video_frames |
| self.video_sampling = video_sampling |
| self.max_tokens = max_tokens |
| self.image_token_id = image_token_id |
|
|
| def _build_prompt_and_media( |
| self, |
| messages: List[Dict[str, Any]], |
| num_img_tokens: int, |
| num_video_frames: int, |
| video_sampling: str, |
| img_tiling: bool, |
| tiling_method: str, |
| tiling_size: int, |
| grid_pinpoints: List[Tuple[int, int]], |
| max_tiles_num: int, |
| patch_size: int, |
| add_generation_prompt: bool, |
| tools: Optional[List[Dict[str, Any]]] = None, |
| enable_thinking: Optional[bool] = None, |
| ) -> Tuple[str, List[Tuple[MediaType, Dict[str, Any]]]]: |
| """Build Yasa2 prompt text and decode media payloads in prompt order. |
| |
| Prompt formatting is delegated to the tokenizer's shared chat template. |
| |
| Args: |
| messages: Conversation messages in HF format. |
| num_img_tokens: Content tokens per image. |
| num_video_frames: Frames to sample per video. |
| video_sampling: Sampling strategy for videos. |
| img_tiling: Whether to enable tiling. |
| tiling_method: Tiling method identifier. |
| tiling_size: Base tile size. |
| grid_pinpoints: Candidate grid pinpoints. |
| max_tiles_num: Maximum tile count for UHD tiling. |
| patch_size: Patch size for UHD tiling. |
| add_generation_prompt: Whether to append an assistant prefix. |
| tools: Optional tool schema list for system prompt injection. |
| enable_thinking: Unused compatibility flag. |
| Returns: |
| Tuple of prompt string and list of decoded media items. |
| """ |
| media_items: List[Tuple[MediaType, Dict[str, Any]]] = [] |
|
|
| def image_builder(item: Dict[str, Any]) -> List[str]: |
| """Serialize an image placeholder sequence for the chat prompt. |
| |
| Args: |
| item: Raw message dict with image metadata. |
| |
| Returns: |
| List[str]: Tokens that represent the image placeholder. |
| """ |
| payload = item.get("image") or item.get("image_url") |
| if payload is None: |
| raise ValueError("Image content requires an 'image' field.") |
| image_datum = _decode_image_payload( |
| payload, |
| img_tiling=img_tiling, |
| tiling_method=tiling_method, |
| tiling_size=tiling_size, |
| grid_pinpoints=grid_pinpoints, |
| max_tiles_num=max_tiles_num, |
| patch_size=patch_size, |
| ) |
| num_tiles = image_datum.get("num_tiles", 1) |
| repeat_tokens = num_img_tokens * num_tiles |
| media_items.append((MediaType.IMAGE, image_datum)) |
| return ( |
| [IMAGE_START] + [REKA_IMG_TOKEN] * repeat_tokens + [IMAGE_END] |
| ) |
|
|
| def video_builder(item: Dict[str, Any]) -> List[str]: |
| """Serialize a video placeholder sequence for the chat prompt. |
| |
| Args: |
| item: Raw message dict with video metadata. |
| |
| Returns: |
| List[str]: Tokens that represent the video placeholder. |
| """ |
| payload = item.get("video") or item.get("video_url") |
| if payload is None: |
| raise ValueError("Video content requires a 'video' field.") |
| video_datum = _decode_video_payload( |
| payload, |
| num_frames=num_video_frames, |
| sampling=video_sampling, |
| ) |
| repeat_tokens = num_img_tokens * video_datum.get( |
| "num_frames", num_video_frames |
| ) |
| media_items.append((MediaType.VIDEO, video_datum)) |
| return ( |
| [VIDEO_START] + [REKA_IMG_TOKEN] * repeat_tokens + [VIDEO_END] |
| ) |
|
|
| if self.tokenizer is None: |
| raise ValueError( |
| "Yasa2Processor requires a tokenizer to build prompts." |
| ) |
| prompt = self.tokenizer.build_chat_prompt( |
| messages, |
| add_generation_prompt=add_generation_prompt, |
| continue_final_message=False, |
| tools=tools, |
| image_token_builder=image_builder, |
| video_token_builder=video_builder, |
| enable_thinking=enable_thinking, |
| ) |
| return prompt, media_items |
|
|
| def apply_chat_template( |
| self, |
| messages: List[Dict[str, Any]], |
| tokenize: bool = False, |
| add_generation_prompt: bool = True, |
| tools: Optional[List[Dict[str, Any]]] = None, |
| return_tensors: Optional[str] = None, |
| return_dict: bool = False, |
| max_length: Optional[int] = None, |
| padding: Union[bool, Literal["longest", "max_length"]] = False, |
| num_img_tokens: Optional[int] = None, |
| num_video_frames: Optional[int] = None, |
| video_sampling: Optional[str] = None, |
| enable_thinking: Optional[bool] = None, |
| img_tiling: bool = True, |
| tiling_method: str = "llava-uhd", |
| tiling_size: int = 512, |
| grid_pinpoints: Optional[List[Tuple[int, int]]] = None, |
| max_tiles_num: int = 4, |
| patch_size: int = 14, |
| return_prompt: bool = False, |
| **kwargs, |
| ) -> Union[str, Dict[str, Any]]: |
| """Apply the Yasa2 dialog template and optionally tokenize + decode media. |
| |
| The chat template is produced via the tokenizer for consistency with |
| text-only formatting. |
| |
| Args: |
| messages: Conversation messages in HF format. |
| tokenize: Whether to tokenize and return tensors. |
| add_generation_prompt: Whether to append an assistant prefix. |
| tools: Optional tool schema list for system prompt injection. |
| return_tensors: Tensor type for outputs (e.g., "pt"). |
| return_dict: Whether to return a dict payload. |
| max_length: Optional max token length. |
| padding: Padding strategy (False/True/"longest"/"max_length"). |
| num_img_tokens: Override for image content tokens. |
| num_video_frames: Override for video frame count. |
| video_sampling: Override for video sampling strategy. |
| enable_thinking: Unused compatibility flag. |
| img_tiling: Whether to enable tiling for images. |
| tiling_method: Tiling method identifier. |
| tiling_size: Base tile size. |
| grid_pinpoints: Candidate grid pinpoints. |
| max_tiles_num: Maximum tile count for UHD tiling. |
| patch_size: Patch size for UHD tiling. |
| return_prompt: Whether to include the prompt string in output. |
| **kwargs: Unused extra arguments for compatibility. |
| |
| Returns: |
| Prompt string if tokenize is False, otherwise a dict of tensors. |
| """ |
| if grid_pinpoints is None: |
| grid_pinpoints = [ |
| (2, 2), |
| (1, 2), |
| (2, 1), |
| (1, 3), |
| (3, 1), |
| (1, 4), |
| (4, 1), |
| ] |
| num_img_tokens = num_img_tokens or self.num_img_tokens |
| num_video_frames = num_video_frames or self.num_video_frames |
| video_sampling = video_sampling or self.video_sampling |
| user_max_length = max_length |
| max_tokens = user_max_length or self.max_tokens |
|
|
| prompt, media_items = self._build_prompt_and_media( |
| messages=messages, |
| num_img_tokens=num_img_tokens, |
| num_video_frames=num_video_frames, |
| video_sampling=video_sampling, |
| img_tiling=img_tiling, |
| tiling_method=tiling_method, |
| tiling_size=tiling_size, |
| grid_pinpoints=grid_pinpoints, |
| max_tiles_num=max_tiles_num, |
| patch_size=patch_size, |
| add_generation_prompt=add_generation_prompt, |
| tools=tools, |
| enable_thinking=enable_thinking, |
| ) |
|
|
| if not tokenize: |
| return prompt |
|
|
| expected_img_tokens = 0 |
| for media_type, media_datum in media_items: |
| if media_type == MediaType.IMAGE: |
| expected_img_tokens += num_img_tokens * media_datum.get( |
| "num_tiles", 1 |
| ) |
| elif media_type == MediaType.VIDEO: |
| expected_img_tokens += num_img_tokens * media_datum.get( |
| "num_frames", num_video_frames |
| ) |
|
|
| input_ids = self.tokenizer.tiktoken.encode( |
| prompt, allowed_special="all" |
| ) |
| input_ids = input_ids[:max_tokens] |
| if expected_img_tokens: |
| actual_img_tokens = sum( |
| 1 for token_id in input_ids if token_id == self.image_token_id |
| ) |
| |
| if actual_img_tokens != expected_img_tokens: |
| raise ValueError( |
| "Prompt truncation dropped image placeholder tokens. " |
| "Increase max_length/max_tokens or reduce media inputs." |
| ) |
|
|
| attention_mask = [1] * len(input_ids) |
| token_type_ids, mm_token_type_ids = self._build_mm_token_type_ids( |
| input_ids |
| ) |
|
|
| if padding not in (False, True, "longest", "max_length"): |
| raise ValueError(f"Unsupported padding value: {padding}") |
| if padding in (True, "longest", "max_length"): |
| pad_to_length = ( |
| max_tokens |
| if (padding == "max_length" or user_max_length) |
| else len(input_ids) |
| ) |
| pad_len = pad_to_length - len(input_ids) |
| if pad_len > 0: |
| |
| |
| input_ids = [PAD_ID] * pad_len + input_ids |
| attention_mask = [0] * pad_len + attention_mask |
| token_type_ids = [0] * pad_len + token_type_ids |
| mm_token_type_ids = [0] * pad_len + mm_token_type_ids |
|
|
| pixel_values_list = [] |
| patch_attention_list = [] |
| for media_type, media_datum in media_items: |
| if media_type == MediaType.IMAGE: |
| image_outputs = self.image_processor( |
| images=media_datum["pixel_values"], return_tensors="pt" |
| ) |
| pixel_values_list.append(image_outputs["pixel_values"]) |
| if "patch_attention_mask" in image_outputs: |
| patch_attention_list.append( |
| image_outputs["patch_attention_mask"] |
| ) |
| elif media_type == MediaType.VIDEO: |
| video_outputs = self.video_processor.preprocess( |
| videos=media_datum["pixel_values"], return_tensors="pt" |
| ) |
| pixel_values_list.append(video_outputs["pixel_values"]) |
| patch_attention_list.append( |
| video_outputs["patch_attention_mask"] |
| ) |
| else: |
| raise ValueError(f"Unsupported media type: {media_type}") |
|
|
| if pixel_values_list: |
| pixel_values = torch.cat(pixel_values_list, dim=0) |
| else: |
| pixel_values = torch.tensor([]) |
| if patch_attention_list: |
| patch_attention_mask = torch.cat(patch_attention_list, dim=0) |
| else: |
| patch_attention_mask = torch.tensor([]) |
|
|
| if return_tensors == "pt": |
| input_ids = torch.tensor(input_ids, dtype=torch.long) |
| attention_mask = torch.tensor(attention_mask, dtype=torch.long) |
| token_type_ids = torch.tensor(token_type_ids, dtype=torch.long) |
| mm_token_type_ids = torch.tensor( |
| mm_token_type_ids, dtype=torch.long |
| ) |
| if input_ids.dim() == 1: |
| input_ids = input_ids.unsqueeze(0) |
| if attention_mask.dim() == 1: |
| attention_mask = attention_mask.unsqueeze(0) |
| if token_type_ids.dim() == 1: |
| token_type_ids = token_type_ids.unsqueeze(0) |
| if mm_token_type_ids.dim() == 1: |
| mm_token_type_ids = mm_token_type_ids.unsqueeze(0) |
|
|
| output = { |
| "input_ids": input_ids, |
| "attention_mask": attention_mask, |
| "token_type_ids": token_type_ids, |
| "mm_token_type_ids": mm_token_type_ids, |
| "pixel_values": pixel_values, |
| "patch_attention_mask": patch_attention_mask, |
| } |
| if return_prompt: |
| output["prompt"] = prompt |
|
|
| return output if return_dict else output |
|
|
    def __call__(
        self,
        images: Optional[Any] = None,
        text: Optional[Union[str, List[str]]] = None,
        videos: Optional[Any] = None,
        audio: Optional[Any] = None,
        **kwargs: Any,
    ) -> Any:
        """Run the processor and ensure multimodal token identifiers are present.

        Resolves tiling settings from kwargs (falling back to the image
        processor's attributes), inserts/expands media placeholder markers in
        a plain-string prompt, tiles PIL/ndarray images, then delegates to
        the base ProcessorMixin call.

        Args:
            images: Optional image inputs.
            text: Optional textual inputs.
            videos: Optional video inputs.
            audio: Optional audio inputs.
            **kwargs: Additional keyword arguments forwarded to the base processor.

        Returns:
            Any: Processor outputs augmented with token type ids when needed.
        """
        kwargs.pop("return_mm_token_type_ids", None)
        # Tiling knobs are read with `get`, not `pop`, so they remain in
        # kwargs and are also forwarded to super().__call__ below.
        # NOTE(review): confirm the base processor tolerates these extras.
        image_processor = getattr(self, "image_processor", None)
        img_tiling = kwargs.get("img_tiling", True)
        tiling_method = kwargs.get(
            "tiling_method",
            getattr(image_processor, "tiling_method", "llava-uhd"),
        )
        tiling_size = kwargs.get("tiling_size")
        if tiling_size is None and image_processor is not None:
            # Derive the tile size from the image processor's configured size.
            size = getattr(image_processor, "size", None)
            if isinstance(size, dict) and "shortest_edge" in size:
                tiling_size = int(size["shortest_edge"])
            elif isinstance(size, int):
                tiling_size = size
        tiling_size = tiling_size or 512
        grid_pinpoints = kwargs.get("grid_pinpoints")
        if grid_pinpoints is None:
            grid_pinpoints = [
                (2, 2),
                (1, 2),
                (2, 1),
                (1, 3),
                (3, 1),
                (1, 4),
                (4, 1),
            ]
        max_tiles_num = kwargs.get(
            "max_tiles_num", getattr(image_processor, "max_tiles_num", 4)
        )
        patch_size = kwargs.get(
            "patch_size", getattr(image_processor, "patch_size", 14)
        )

        # A plain string prompt with media but no markers gets placeholders
        # prepended; one that already has markers gets them expanded in place.
        if isinstance(text, str) and (
            images is not None or videos is not None
        ):
            if (
                REKA_IMG_TOKEN not in text
                and IMAGE_START not in text
                and VIDEO_START not in text
            ):
                text = self._prepend_mm_placeholders(
                    text=text, images=images, videos=videos, **kwargs
                )
            else:
                text = self._expand_image_placeholders(
                    text=text, images=images, **kwargs
                )

        # Tile the raw images so the tile count matches the placeholder runs.
        if images is not None and img_tiling:
            images = self._tile_images(
                images=images,
                tiling_method=tiling_method,
                tiling_size=tiling_size,
                grid_pinpoints=grid_pinpoints,
                max_tiles_num=max_tiles_num,
                patch_size=patch_size,
            )

        # Wrap a single (text, image-list) pair as a batch of one.
        if isinstance(text, str) and isinstance(images, list):
            text = [text]
            images = [images]
        outputs = super().__call__(
            images=images, text=text, videos=videos, audio=audio, **kwargs
        )
        if "input_ids" in outputs and "token_type_ids" not in outputs:
            token_type_ids, mm_token_type_ids = self._build_mm_token_type_ids(
                outputs["input_ids"]
            )
            outputs["token_type_ids"] = token_type_ids
            outputs["mm_token_type_ids"] = mm_token_type_ids
        return outputs
|
|
| def _expand_image_placeholders( |
| self, |
| text: str, |
| images: Optional[Any], |
| **kwargs: Any, |
| ) -> str: |
| if images is None or IMAGE_START not in text or IMAGE_END not in text: |
| return text |
| image_list = ( |
| list(images) if isinstance(images, (list, tuple)) else [images] |
| ) |
| image_processor = getattr(self, "image_processor", None) |
| img_tiling = kwargs.get("img_tiling", True) |
| tiling_method = kwargs.get( |
| "tiling_method", |
| getattr(image_processor, "tiling_method", "llava-uhd"), |
| ) |
| tiling_size = kwargs.get("tiling_size") |
| if tiling_size is None and image_processor is not None: |
| size = getattr(image_processor, "size", None) |
| if isinstance(size, dict) and "shortest_edge" in size: |
| tiling_size = int(size["shortest_edge"]) |
| elif isinstance(size, int): |
| tiling_size = size |
| tiling_size = tiling_size or 512 |
| grid_pinpoints = kwargs.get("grid_pinpoints") |
| if grid_pinpoints is None: |
| grid_pinpoints = [ |
| (2, 2), |
| (1, 2), |
| (2, 1), |
| (1, 3), |
| (3, 1), |
| (1, 4), |
| (4, 1), |
| ] |
| max_tiles_num = kwargs.get( |
| "max_tiles_num", getattr(image_processor, "max_tiles_num", 4) |
| ) |
| patch_size = kwargs.get( |
| "patch_size", getattr(image_processor, "patch_size", 14) |
| ) |
|
|
| expected_tokens = [] |
| for image in image_list: |
| width = height = 0 |
| if hasattr(image, "size"): |
| width, height = image.size |
| elif isinstance(image, (list, tuple)) and len(image) >= 2: |
| height, width = int(image[0]), int(image[1]) |
| if img_tiling and width > 0 and height > 0: |
| if str(tiling_method).lower() == "llava-next": |
| tiles = estimate_num_tiles_llava_next( |
| (width, height), |
| size=tiling_size, |
| grid_pinpoints=grid_pinpoints, |
| ) |
| else: |
| tiles = estimate_num_tiles_llava_uhd( |
| (width, height), |
| max_tiles_num=max_tiles_num, |
| scale_resolution=tiling_size, |
| patch_size=patch_size, |
| never_split=False, |
| ) |
| else: |
| tiles = 1 |
| expected_tokens.append(self.num_img_tokens * tiles) |
|
|
| parts = [] |
| remaining = text |
| for tokens in expected_tokens: |
| start = remaining.find(IMAGE_START) |
| end = remaining.find(IMAGE_END, start + len(IMAGE_START)) |
| if start == -1 or end == -1: |
| return text |
| parts.append(remaining[:start]) |
| parts.append(IMAGE_START + (REKA_IMG_TOKEN * tokens) + IMAGE_END) |
| remaining = remaining[end + len(IMAGE_END) :] |
| parts.append(remaining) |
| new_text = "".join(parts) |
| return new_text |
|
|
| def _tile_images( |
| self, |
| images: Any, |
| tiling_method: str, |
| tiling_size: int, |
| grid_pinpoints: List[Tuple[int, int]], |
| max_tiles_num: int, |
| patch_size: int, |
| ) -> Any: |
| |
| image_list = ( |
| list(images) if isinstance(images, (list, tuple)) else [images] |
| ) |
| tiled_images: List[Any] = [] |
| for image in image_list: |
| if image is None: |
| continue |
| if isinstance(image, torch.Tensor): |
| tiled_images.append(image) |
| continue |
| if isinstance(image, np.ndarray): |
| image = Image.fromarray(image) |
| if isinstance(image, Image.Image): |
| |
| if str(tiling_method).lower() == "llava-next": |
| tiles = process_anyres_image( |
| image, size=tiling_size, grid_pinpoints=grid_pinpoints |
| ) |
| else: |
| tiles = process_anyres_image_uhd( |
| image, |
| max_tiles_num=max_tiles_num, |
| scale_resolution=tiling_size, |
| patch_size=patch_size, |
| never_split=False, |
| ) |
| tiled_images.extend(tiles) |
| continue |
| tiled_images.append(image) |
| return ( |
| tiled_images |
| if isinstance(images, (list, tuple)) |
| else tiled_images[0] |
| ) |
|
|
| def _prepend_mm_placeholders( |
| self, |
| text: str, |
| images: Optional[Any], |
| videos: Optional[Any], |
| **kwargs: Any, |
| ) -> str: |
| """Prepend placeholder tokens when media is provided without markers.""" |
| |
| image_list = ( |
| list(images) |
| if isinstance(images, (list, tuple)) |
| else ([images] if images is not None else []) |
| ) |
| num_images = len(image_list) |
| num_videos = self._count_media_items(videos) |
| if num_images == 0 and num_videos == 0: |
| return text |
|
|
| image_processor = getattr(self, "image_processor", None) |
| img_tiling = kwargs.get("img_tiling", True) |
| tiling_method = kwargs.get( |
| "tiling_method", |
| getattr(image_processor, "tiling_method", "llava-uhd"), |
| ) |
| tiling_size = kwargs.get("tiling_size") |
| if tiling_size is None and image_processor is not None: |
| size = getattr(image_processor, "size", None) |
| if isinstance(size, dict) and "shortest_edge" in size: |
| tiling_size = int(size["shortest_edge"]) |
| elif isinstance(size, int): |
| tiling_size = size |
| tiling_size = tiling_size or 512 |
| grid_pinpoints = kwargs.get("grid_pinpoints") |
| if grid_pinpoints is None: |
| grid_pinpoints = [ |
| (2, 2), |
| (1, 2), |
| (2, 1), |
| (1, 3), |
| (3, 1), |
| (1, 4), |
| (4, 1), |
| ] |
| max_tiles_num = kwargs.get( |
| "max_tiles_num", getattr(image_processor, "max_tiles_num", 4) |
| ) |
| patch_size = kwargs.get( |
| "patch_size", getattr(image_processor, "patch_size", 14) |
| ) |
|
|
| def _get_image_size(image: Any) -> Tuple[int, int]: |
| if hasattr(image, "size"): |
| size = image.size |
| if isinstance(size, (list, tuple)) and len(size) >= 2: |
| return int(size[0]), int(size[1]) |
| if hasattr(image, "shape"): |
| shape = image.shape |
| if isinstance(shape, (list, tuple)) and len(shape) >= 2: |
| return int(shape[1]), int(shape[0]) |
| if isinstance(image, (list, tuple)) and len(image) >= 2: |
| return int(image[1]), int(image[0]) |
| return 0, 0 |
|
|
| placeholder = "" |
| for image in image_list: |
| tiles = 1 |
| if img_tiling: |
| width, height = _get_image_size(image) |
| if width > 0 and height > 0: |
| if str(tiling_method).lower() == "llava-next": |
| tiles = estimate_num_tiles_llava_next( |
| (width, height), |
| size=tiling_size, |
| grid_pinpoints=grid_pinpoints, |
| ) |
| else: |
| tiles = estimate_num_tiles_llava_uhd( |
| (width, height), |
| max_tiles_num=max_tiles_num, |
| scale_resolution=tiling_size, |
| patch_size=patch_size, |
| never_split=False, |
| ) |
| placeholder += IMAGE_START |
| placeholder += REKA_IMG_TOKEN * (self.num_img_tokens * tiles) |
| placeholder += IMAGE_END |
| for _ in range(num_videos): |
| placeholder += VIDEO_START |
| placeholder += REKA_IMG_TOKEN * ( |
| self.num_img_tokens * self.num_video_frames |
| ) |
| placeholder += VIDEO_END |
| return f"{placeholder}{text}" |
|
|
| @staticmethod |
| def _count_media_items(payload: Optional[Any]) -> int: |
| """Best-effort count of media items for placeholder insertion.""" |
| if payload is None: |
| return 0 |
| if isinstance(payload, (list, tuple)): |
| return len(payload) |
| return 1 |
|
|
| def _build_mm_token_type_ids(self, input_ids: Any) -> Tuple[Any, Any]: |
| """Compute token_type_ids that mark multimodal placeholders. |
| |
| Args: |
| input_ids: Input IDs or sequences containing tokenizer ids. |
| |
| Returns: |
| Tuple[Any, Any]: Regular and multimodal token type ids detected from placeholders. |
| """ |
| if self.tokenizer is None: |
| return input_ids, input_ids |
| img_token_id = self.image_token_id |
|
|
| if isinstance(input_ids, torch.Tensor): |
| mm_token_type_ids = (input_ids == img_token_id).long() |
| token_type_ids = mm_token_type_ids.clone() |
| return token_type_ids, mm_token_type_ids |
|
|
| if isinstance(input_ids, (list, tuple)): |
| if input_ids and isinstance(input_ids[0], (list, tuple)): |
| mm_token_type_ids = [ |
| [1 if token_id == img_token_id else 0 for token_id in seq] |
| for seq in input_ids |
| ] |
| else: |
| mm_token_type_ids = [ |
| 1 if token_id == img_token_id else 0 |
| for token_id in input_ids |
| ] |
| token_type_ids = list(mm_token_type_ids) |
| return token_type_ids, mm_token_type_ids |
|
|
| if hasattr(input_ids, "tolist"): |
| ids = input_ids.tolist() |
| token_type_ids, mm_token_type_ids = self._build_mm_token_type_ids( |
| ids |
| ) |
| return token_type_ids, mm_token_type_ids |
|
|
| return input_ids, input_ids |
|
|
| def _get_num_multimodal_tokens( |
| self, |
| image_sizes: Optional[List[List[int]]] = None, |
| video_sizes: Optional[List[List[int]]] = None, |
| **kwargs: Any, |
| ) -> MultiModalData: |
| """Estimate the count of multimodal tokens from provided media sizes. |
| |
| Args: |
| image_sizes: Per-image sizes as (height, width) tuples. |
| video_sizes: Per-video sizes as (num_frames, height, width) tuples. |
| **kwargs: Ignored compatibility arguments accepted by parent helpers. |
| |
| Returns: |
| MultiModalData: Token counts for the vision modalities. |
| """ |
| vision_data: Dict[str, List[int]] = {} |
| if image_sizes is not None: |
| image_processor = getattr(self, "image_processor", None) |
| img_tiling = kwargs.get("img_tiling", True) |
| tiling_method = kwargs.get( |
| "tiling_method", |
| getattr(image_processor, "tiling_method", "llava-uhd"), |
| ) |
| tiling_size = kwargs.get("tiling_size") |
| if tiling_size is None and image_processor is not None: |
| size = getattr(image_processor, "size", None) |
| if isinstance(size, dict) and "shortest_edge" in size: |
| tiling_size = int(size["shortest_edge"]) |
| elif isinstance(size, int): |
| tiling_size = size |
| tiling_size = tiling_size or 512 |
| grid_pinpoints = kwargs.get("grid_pinpoints") |
| if grid_pinpoints is None: |
| grid_pinpoints = [ |
| (2, 2), |
| (1, 2), |
| (2, 1), |
| (1, 3), |
| (3, 1), |
| (1, 4), |
| (4, 1), |
| ] |
| max_tiles_num = kwargs.get( |
| "max_tiles_num", getattr(image_processor, "max_tiles_num", 4) |
| ) |
| patch_size = kwargs.get( |
| "patch_size", getattr(image_processor, "patch_size", 14) |
| ) |
|
|
| |
| num_image_tokens: List[int] = [] |
| num_image_patches: List[int] = [] |
| for image_size in image_sizes: |
| height = width = 0 |
| if image_size and len(image_size) >= 2: |
| height, width = int(image_size[0]), int(image_size[1]) |
| tiles = 1 |
| if img_tiling and width > 0 and height > 0: |
| if str(tiling_method).lower() == "llava-next": |
| tiles = estimate_num_tiles_llava_next( |
| (width, height), |
| size=tiling_size, |
| grid_pinpoints=grid_pinpoints, |
| ) |
| else: |
| tiles = estimate_num_tiles_llava_uhd( |
| (width, height), |
| max_tiles_num=max_tiles_num, |
| scale_resolution=tiling_size, |
| patch_size=patch_size, |
| never_split=False, |
| ) |
| num_image_tokens.append(self.num_img_tokens * tiles) |
| num_image_patches.append(tiles) |
|
|
| vision_data["num_image_tokens"] = num_image_tokens |
| vision_data["num_image_patches"] = num_image_patches |
| else: |
| vision_data["num_image_tokens"] = [] |
| vision_data["num_image_patches"] = [] |
| if video_sizes is not None: |
| video_tokens: List[int] = [] |
| for video_size in video_sizes: |
| num_frames = video_size[0] if video_size else 0 |
| num_frames = min( |
| num_frames or self.num_video_frames, self.num_video_frames |
| ) |
| video_tokens.append(self.num_img_tokens * num_frames) |
| vision_data["num_video_tokens"] = video_tokens |
|
|
| return MultiModalData(**vision_data) |
|
|
|
|
# Make the processor discoverable via AutoProcessor when this repo is loaded
# with trust_remote_code=True.
Yasa2Processor.register_for_auto_class()
|
|