| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | """ |
| | Processor class for Cosmos-Embed1 |
| | """ |
| |
|
| | from typing import List, Optional, Tuple, Union |
| |
|
| | import numpy as np |
| | import torch |
| | import torchvision |
| | from transformers import AutoProcessor, BatchFeature |
| | from transformers.processing_utils import ProcessorMixin |
| | from transformers.utils import TensorType |
| |
|
| | from .configuration_embed1 import CosmosEmbed1Config |
| |
|
| |
|
class CosmosEmbed1Processor(ProcessorMixin):
    r"""
    Constructs a processor which wraps a BertTokenizer tokenizer and a fast video resize function.

    Args:
        tokenizer ([`BertTokenizerFast`], *optional*):
            The tokenizer is a required input for text processing.
        resolution (`int` or `Tuple[int, int]`, *optional*, defaults to 448):
            Default target video size as `(height, width)`. An `int` is used for both sides.
        num_video_frames (`int`, *optional*, defaults to 8):
            Default number of frames each input video must contain.
        max_txt_len (`int`, *optional*, defaults to 128):
            Default `max_length` handed to the tokenizer.
    """

    attributes = ["tokenizer"]
    tokenizer_class = ("BertTokenizer", "BertTokenizerFast")
    config_class = CosmosEmbed1Config
    chat_template = None

    def __init__(
        self,
        tokenizer=None,
        resolution: Union[int, Tuple[int, int]] = 448,
        num_video_frames: int = 8,
        max_txt_len: int = 128,
        **kwargs,
    ) -> None:
        super().__init__(tokenizer, **kwargs)
        self.resolution = resolution
        self.num_video_frames = num_video_frames
        self.max_txt_len = max_txt_len

    def __call__(
        self,
        text: Optional[Union[str, List[str]]] = None,
        videos: Optional[Union[np.ndarray, torch.Tensor]] = None,
        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
        resolution: Optional[Union[int, Tuple[int, int]]] = None,
        num_video_frames: Optional[int] = None,
        max_txt_len: Optional[int] = None,
        **kwargs,
    ) -> BatchFeature:
        r"""
        Tokenize `text` and/or validate, resize and rescale `videos`.

        Args:
            text (`str` or `List[str]`, *optional*):
                Text to tokenize.
            videos (`np.ndarray` or `torch.Tensor`, *optional*):
                Batch of videos of shape BTCHW with 3 channels and values in [0-255].
            return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `TensorType.PYTORCH`):
                Tensor type of the returned [`BatchFeature`].
            resolution (`int` or `Tuple[int, int]`, *optional*):
                Per-call override of `self.resolution`.
            num_video_frames (`int`, *optional*):
                Per-call override of `self.num_video_frames`.
            max_txt_len (`int`, *optional*):
                Per-call override of `self.max_txt_len`.
            **kwargs:
                Extra keyword arguments forwarded to the tokenizer. They take precedence over the
                defaults used here (`padding="max_length"`, `truncation=True`, `max_length=max_txt_len`).

        Returns:
            [`BatchFeature`]: contains `input_ids` and a float `attention_mask` when `text` is given,
            and `videos` (divided by 255) when `videos` is given.

        Raises:
            ValueError: if neither `text` nor `videos` is passed, or if `videos` is not a 5-D
                array/tensor with 3 channels and the expected number of frames.
        """
        inputs = {}

        if text is not None:
            if max_txt_len is None:
                max_txt_len = self.max_txt_len
            # Defaults first, caller-supplied kwargs last: an explicit `padding=...` etc. overrides
            # the default instead of raising "got multiple values for keyword argument".
            tokenizer_kwargs = {
                "padding": "max_length",
                "truncation": True,
                "max_length": max_txt_len,
                **kwargs,
            }
            # Tokenize to torch tensors regardless of `return_tensors` (the mask is cast with
            # `.float()` below); `BatchFeature` applies the requested tensor type at the end.
            tokenized = self.tokenizer(text, return_tensors="pt", **tokenizer_kwargs)
            inputs["input_ids"] = tokenized.input_ids
            inputs["attention_mask"] = tokenized.attention_mask.float()

        if videos is not None:
            if isinstance(videos, np.ndarray):
                videos = torch.from_numpy(videos)
            if not isinstance(videos, torch.Tensor) or videos.ndim != 5:
                raise ValueError("Processor expects a numpy or torch tensor of shape BTCHW from [0-255].")

            if resolution is None:
                resolution = self.resolution
            if isinstance(resolution, int):
                resolution = (resolution, resolution)
            if num_video_frames is None:
                num_video_frames = self.num_video_frames

            _, t, c, h, w = videos.shape
            if c != 3:
                raise ValueError(f"Expected tensor of shape BTCHW with RGB channels, got channel size {c}.")
            if t != num_video_frames:
                raise ValueError(f"Expected tensor of shape BTCHW with {num_video_frames} frames, got {t}.")

            # Only resize when the spatial size actually differs from the target.
            if h != resolution[0] or w != resolution[1]:
                videos = resize_video(videos, resolution)
            if videos.dtype == torch.uint8:
                videos = videos.float()
            # Map the [0-255] pixel range onto [0, 1].
            inputs["videos"] = videos / 255.0

        if not inputs:
            raise ValueError("Must pass either `text` or `videos` argument to __call__ function.")

        return BatchFeature(inputs, tensor_type=return_tensors)
| |
|
| |
|
def resize_video(video: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
    """Resize a video tensor (B, T, C, H, W) to a new height/width.

    Frames are resized with bilinear interpolation and antialiasing.

    Args:
        video (torch.Tensor): (B, T, C, H, W) uint8 or float32. It does not need to be
            contiguous (e.g. it may be a temporal slice of a longer clip).
        size (tuple): target (H', W') size.
    Returns:
        torch.Tensor: resized video of shape (B, T, C, H', W'). When the video already has
        the target size it is returned unchanged.
    """
    target_h, target_w = size
    batch, frames, channels, height, width = video.shape

    # Nothing to do if the spatial size already matches the target.
    if (height, width) == (target_h, target_w):
        return video

    # `reshape` instead of `view`: merging the batch and time axes is not always possible
    # without a copy (e.g. for `clip[:, :8]`), and `view` raises in that case.
    flat = video.reshape(batch * frames, channels, height, width)
    resize = torchvision.transforms.Resize(
        (target_h, target_w),
        antialias=True,
        interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
    )
    flat = resize(flat)
    new_h, new_w = flat.shape[-2:]
    return flat.reshape(batch, frames, channels, new_h, new_w)
| |
|
| |
|
# Register this processor with `AutoProcessor` under its configuration class.
AutoProcessor.register(
    config_class=CosmosEmbed1Config,
    processor_class=CosmosEmbed1Processor,
)

# Names exported by `from <module> import *`.
__all__ = ["CosmosEmbed1Processor"]
| |
|