from typing import Any, Dict, List, Optional, Union

import torch
from transformers import ProcessorMixin


class XVLAProcessor(ProcessorMixin):
    """
    XVLAProcessor: Unified multimodal processor for XVLA models.

    Handles:
    - Multi-view image inputs (e.g., from multiple cameras).
    - Batch processing for multiple samples.
    - Joint tokenization and image tensor preparation.

    This processor combines an image processor and a tokenizer under a single
    interface so that users can call it directly:

    >>> processor = XVLAProcessor.from_pretrained("path/to/xvla")
    >>> inputs = processor(images=batch_images, language_instruction=batch_texts)

    It is fully compatible with the Hugging Face AutoProcessor API.

    Attributes
    ----------
    num_views : int, default=3
        Expected number of image views per sample. Missing views are padded
        with zeros.
    language_max_length : int, default=50
        Maximum token length for text encoding.
    attributes : list
        Required by ProcessorMixin to know which submodules are stored and
        reloaded.
    image_processor_class : str
        The name of the associated image processor class.
    tokenizer_class : tuple of str
        The names of compatible tokenizer classes.
    """

    num_views: int = 3
    language_max_length: int = 50

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("BartTokenizer", "BartTokenizerFast")
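
    # Hedged sketch of what the wiring above enables (paths are placeholders,
    # and AutoProcessor resolution assumes this class is importable or
    # registered on the loading side):
    #
    #     processor.save_pretrained("path/to/xvla")
    #     reloaded = AutoProcessor.from_pretrained("path/to/xvla")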

    def __init__(self, image_processor=None, tokenizer=None):
        """
        Initialize XVLAProcessor.

        Parameters
        ----------
        image_processor : BaseImageProcessor, optional
            The image processor used to normalize/resize images.
        tokenizer : PreTrainedTokenizer, optional
            The tokenizer used for text tokenization.
        """
        super().__init__(image_processor, tokenizer)

    def encode_language(self, language_instruction: Union[str, List[str]]) -> Dict[str, torch.Tensor]:
        """
        Tokenize one or more language instructions.

        Parameters
        ----------
        language_instruction : str or List[str]
            A single instruction or a batch of instructions.

        Returns
        -------
        Dict[str, torch.Tensor]
            {"input_ids": tensor of shape [B, L]}
        """
        # Promote a single instruction to a batch of one.
        if isinstance(language_instruction, str):
            language_instruction = [language_instruction]

        inputs = self.tokenizer(
            language_instruction,
            return_tensors="pt",
            padding="max_length",
            max_length=self.language_max_length,
            truncation=True,
        )
        return {"input_ids": inputs["input_ids"]}

    def encode_image(
        self,
        images: Union[List, List[List]],
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """
        Preprocess one or more sets of multi-view images.

        Parameters
        ----------
        images : List or List[List]
            Single sample: [img1, img2, ...]
            Batch: [[img1a, img1b], [img2a, img2b, img2c], ...]
            Each image may be a PIL.Image, NumPy array, or torch.Tensor.
        kwargs : dict
            Extra arguments passed to the underlying image processor
            (e.g., `do_resize=False`, `size=(224, 224)`).

        Returns
        -------
        Dict[str, torch.Tensor]
            {
                "image_input": tensor [B, num_views, C, H, W],
                "image_mask": tensor [B, num_views]
            }
        """
        # Promote a single sample [img1, img2, ...] to a batch of one.
        if not isinstance(images[0], (list, tuple)):
            images = [images]

        batch_imgs, batch_masks = [], []

        for sample_imgs in images:
            processed = self.image_processor(sample_imgs, return_tensors="pt", **kwargs)["pixel_values"]
            num_existing = processed.size(0)

            if num_existing > self.num_views:
                raise ValueError(
                    f"Got {num_existing} views but this processor expects at most {self.num_views}."
                )

            # Zero-pad missing views so every sample has exactly num_views.
            if num_existing < self.num_views:
                processed = torch.cat(
                    [
                        processed,
                        processed.new_zeros(self.num_views - num_existing, *processed.shape[1:]),
                    ],
                    dim=0,
                )

            # True for real views, False for zero-padded ones.
            image_mask = torch.zeros(self.num_views, dtype=torch.bool, device=processed.device)
            image_mask[:num_existing] = True

            batch_imgs.append(processed)
            batch_masks.append(image_mask)

        image_input = torch.stack(batch_imgs, dim=0)
        image_mask = torch.stack(batch_masks, dim=0)

        return {"image_input": image_input, "image_mask": image_mask}

    def __call__(
        self,
        images: Optional[Union[List, List[List]]] = None,
        language_instruction: Optional[Union[str, List[str]]] = None,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """
        Combine image and text encoding into a unified multimodal input.

        Parameters
        ----------
        images : List or List[List], optional
            Single-sample or batched multi-view images.
        language_instruction : str or List[str], optional
            Corresponding text instructions.
        kwargs : dict
            Extra arguments passed to the image processor.

        Returns
        -------
        Dict[str, torch.Tensor]
            {
                "input_ids": [B, L], optional,
                "image_input": [B, num_views, C, H, W], optional,
                "image_mask": [B, num_views], optional
            }
        """
        outputs: Dict[str, Any] = {}

        if language_instruction is not None:
            outputs.update(self.encode_language(language_instruction))

        if images is not None:
            outputs.update(self.encode_image(images, **kwargs))

        # Validate with an explicit error rather than `assert`, which is
        # silently stripped when Python runs with optimizations (-O).
        if "input_ids" in outputs and "image_input" in outputs:
            if outputs["input_ids"].size(0) != outputs["image_input"].size(0):
                raise ValueError(
                    f"Batch mismatch: text batch {outputs['input_ids'].size(0)} "
                    f"!= image batch {outputs['image_input'].size(0)}"
                )
        return outputs
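

if __name__ == "__main__":
    # Hedged smoke test: the sub-components below are stand-ins chosen only so
    # the example runs (downloads facebook/bart-base); a real XVLA checkpoint
    # would supply its own image processor and tokenizer via from_pretrained.
    from PIL import Image
    from transformers import BartTokenizer, CLIPImageProcessor

    processor = XVLAProcessor(
        image_processor=CLIPImageProcessor(),
        tokenizer=BartTokenizer.from_pretrained("facebook/bart-base"),
    )
    views = [Image.new("RGB", (224, 224)) for _ in range(2)]  # 2 of 3 views
    batch = processor(images=views, language_instruction="pick up the red block")
    print(batch["input_ids"].shape)    # torch.Size([1, 50])
    print(batch["image_input"].shape)  # torch.Size([1, 3, 3, 224, 224])
    print(batch["image_mask"])         # tensor([[ True,  True, False]])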