| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
"""
Processor class for Cosmos-Embed1
"""
|
|
| from typing import List, Optional, Tuple, Union |
|
|
| import numpy as np |
| import torch |
| import torchvision |
| from transformers import AutoProcessor, BatchFeature |
| from transformers.processing_utils import ProcessorMixin |
| from transformers.utils import TensorType |
|
|
| from .configuration_embed1 import CosmosEmbed1Config |
|
|
|
|
class CosmosEmbed1Processor(ProcessorMixin):
    r"""
    Constructs a processor which wraps a BertTokenizer tokenizer and a fast video resize function.

    Args:
        tokenizer ([`BertTokenizer`] or [`BertTokenizerFast`], *optional*):
            The tokenizer is a required input for text processing.
        resolution (`int` or `Tuple[int, int]`, *optional*, defaults to 448):
            Default spatial size (H, W) expected for videos; an `int` means a square resolution.
        num_video_frames (`int`, *optional*, defaults to 8):
            Default number of frames each video clip must contain.
        max_txt_len (`int`, *optional*, defaults to 128):
            Default maximum token length used to pad/truncate text.
    """

    attributes = ["tokenizer"]
    tokenizer_class = ("BertTokenizer", "BertTokenizerFast")
    config_class = CosmosEmbed1Config
    chat_template = None

    def __init__(
        self,
        tokenizer=None,
        resolution: Union[int, Tuple[int, int]] = 448,
        num_video_frames: int = 8,
        max_txt_len: int = 128,
        **kwargs,
    ) -> None:
        super().__init__(tokenizer, **kwargs)
        self.resolution = resolution
        self.num_video_frames = num_video_frames
        self.max_txt_len = max_txt_len

    def __call__(
        self,
        text: Optional[Union[str, List[str]]] = None,
        videos: Optional[Union[np.ndarray, torch.Tensor]] = None,
        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
        resolution: Optional[Union[int, Tuple[int, int]]] = None,
        num_video_frames: Optional[int] = None,
        max_txt_len: Optional[int] = None,
        **kwargs,
    ) -> BatchFeature:
        """Tokenize text and/or preprocess videos into a [`BatchFeature`].

        Args:
            text (`str` or `List[str]`, *optional*):
                Text to tokenize. Produces `input_ids` and a float `attention_mask`,
                padded/truncated to `max_txt_len`.
            videos (`np.ndarray` or `torch.Tensor`, *optional*):
                Batch of videos of shape (B, T, C, H, W) with RGB values in [0, 255].
                Resized to `resolution` if needed, converted to float, and scaled to [0, 1].
            return_tensors (`str` or [`TensorType`], *optional*, defaults to `TensorType.PYTORCH`):
                Tensor framework for the returned features.
            resolution (`int` or `Tuple[int, int]`, *optional*):
                Per-call override of the default resolution.
            num_video_frames (`int`, *optional*):
                Per-call override of the default/required frame count.
            max_txt_len (`int`, *optional*):
                Per-call override of the default maximum text length.

        Returns:
            [`BatchFeature`]: with `input_ids`/`attention_mask` and/or `videos` keys.

        Raises:
            ValueError: if neither `text` nor `videos` is given, or if `videos` does not
                have shape (B, T=num_video_frames, C=3, H, W).
        """
        inputs = {}

        if text is not None:
            max_txt_len = max_txt_len if max_txt_len is not None else self.max_txt_len
            tokenized = self.tokenizer(
                text, return_tensors="pt", padding="max_length", truncation=True, max_length=max_txt_len, **kwargs
            )
            inputs["input_ids"] = tokenized.input_ids
            # Downstream model expects a float mask rather than the tokenizer's int mask.
            inputs["attention_mask"] = tokenized.attention_mask.float()

        if videos is not None:
            if isinstance(videos, np.ndarray):
                videos = torch.from_numpy(videos)
            if not isinstance(videos, torch.Tensor) or videos.ndim != 5:
                raise ValueError("Processor expects a numpy or torch tensor of shape BTCHW from [0-255].")
            resolution = resolution if resolution is not None else self.resolution
            if isinstance(resolution, int):
                resolution = (resolution, resolution)
            _, t, c, h, w = videos.shape
            if c != 3:
                raise ValueError(f"Expected tensor of shape BTCHW with RGB channels, got channel size {c}.")
            num_video_frames = num_video_frames if num_video_frames is not None else self.num_video_frames
            if t != num_video_frames:
                raise ValueError(f"Expected tensor of shape BTCHW with {num_video_frames} frames, got {t}.")
            if h != resolution[0] or w != resolution[1]:
                videos = resize_video(videos, resolution)
            if videos.dtype == torch.uint8:
                videos = videos.float()
            # Scale [0, 255] pixel values into [0, 1].
            inputs["videos"] = videos / 255.0

        if not inputs:
            raise ValueError("Must pass either `text` or `videos` argument to __call__ function.")

        return BatchFeature(inputs, tensor_type=return_tensors)
|
|
|
|
def resize_video(video: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
    """Resize a video tensor (B, T, C, H, W) to a new height/width.

    Frames are flattened into a single (B*T, C, H, W) batch so torchvision's
    Resize processes them in one call, then restored to (B, T, C, H', W').

    Args:
        video (torch.Tensor): (B, T, C, H, W) uint8 or float32.
        size (tuple): target (H', W') size.
    Returns:
        torch.Tensor: resized video of shape (B, T, C, H', W')
    """
    target_h, target_w = size
    b, t, c, h, w = video.shape
    # reshape (not view) so non-contiguous inputs (e.g. results of permute) are accepted;
    # for contiguous tensors reshape is a view and costs nothing extra.
    frames = video.reshape(b * t, c, h, w)
    resize = torchvision.transforms.Resize(
        (target_h, target_w),
        antialias=True,
        interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
    )
    frames = resize(frames)
    # Resize with an explicit (H, W) pair always returns exactly that size.
    return frames.reshape(b, t, c, target_h, target_w)
|
|
|
|
# Register this processor with AutoProcessor so checkpoints whose config is a
# CosmosEmbed1Config can be loaded via AutoProcessor.from_pretrained(...).
AutoProcessor.register(CosmosEmbed1Config, CosmosEmbed1Processor)


# Public API of this module.
__all__ = ["CosmosEmbed1Processor"]
|
|