# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.

# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.

from dataclasses import dataclass
from typing import Any, Dict, Sequence

import torch
from transformers import PreTrainedTokenizer

from llava.constants import IGNORE_INDEX
from llava.utils.logging import logger

__all__ = ["DataCollator"]


class DataCollator:
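    """Collate multimodal (image / video / speech / sound) examples into a padded training batch."""
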
    tokenizer: PreTrainedTokenizer

    def __init__(self, tokenizer: PreTrainedTokenizer):
        super().__init__()
        self.tokenizer = tokenizer

    def __call__(self, instances: Sequence[Dict[str, Any]]) -> Dict[str, Any]:
        # Gather everything from the batch
        input_ids, labels, media, block_sizes = [], [], {name: [] for name in self.tokenizer.media_tokens}, []
        media_meta: Dict[str, Any] = {
            "sound_feature_masks": [],
            "sound_embed_masks": [],
            "frame_times": [],
        }
        for instance in instances:
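            # Each instance is either a single example (tensor "input_ids") or a
            # pre-packed group of examples (list of tensors); handle both layouts.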
            if isinstance(instance["input_ids"], torch.Tensor):
                input_ids.append(instance["input_ids"])
                labels.append(instance["labels"])
                for name in media:
                    objs = instance.get(name)
                    objs = objs if objs is not None else []
                    media[name].append([obj for obj in objs])
                if instance.get("sound") is not None:
                    for name_k in media_meta:
                        if "sound" in name_k:
                            objs = instance.get(name_k)
                            media_meta[name_k].append([obj for obj in objs])
                if instance.get("video") is not None or instance.get("image") is not None:
                    for name_k in media_meta:
                        if "frame" in name_k:
                            objs = instance.get(name_k)
                            media_meta[name_k].append([obj for obj in objs])
                if "block_sizes" in instance:
                    block_sizes.append(instance["block_sizes"])
                else:
                    block_sizes.append(
                        [None for _ in range(len(instance.get("image")))] if instance.get("image") is not None else []
                    )
            else:
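                # Pre-packed layout: every per-sample field is already a list aligned
                # with the list of "input_ids" tensors, so extend instead of append.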
                input_ids.extend(instance["input_ids"])
                labels.extend(instance["labels"])
                for name in media:
                    objs = instance.get(name)
                    objs = objs if objs is not None else [[] for _ in range(len(instance["input_ids"]))]
                    media[name].extend(objs)
                if instance.get("sound") is not None:
                    for name_k in media_meta:
                        if "sound" in name_k:
                            objs = instance.get(name_k)
                            media_meta[name_k].extend(objs)
                if instance.get("video") is not None or instance.get("image") is not None:
                    for name_k in media_meta:
                        if "frame" in name_k:
                            objs = instance.get(name_k)
                            media_meta[name_k].append([obj for obj in objs])
                if "block_sizes" in instance:
                    block_sizes.extend(instance["block_sizes"])
                else:
                    block_sizes.extend(
                        [[None for _ in range(len(objs))] for objs in instance.get("image")]
                        if instance.get("image") is not None
                        else [[] for _ in range(len(instance["input_ids"]))]
                    )
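
        # At this point input_ids, labels, media[name], and block_sizes each hold one entry per sample.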
        batch_size = len(input_ids)

        # Check if the number of media objects (or the number of block sizes) matches the number of media tokens
        for name in media:
            for k in range(batch_size):
                if name == "image" and not all([_ is None for _ in block_sizes[k]]):
                    actual = len(block_sizes[k])
                else:
                    actual = len(media[name][k])
                expected = (input_ids[k] == self.tokenizer.media_token_ids[name]).sum().item()
                if actual != expected:
                    raise ValueError(
                        f"Number mismatch between {name} objects and {name} tokens. "
                        f"There are {expected} {name} tokens but {actual} {name} objects."
                    )

        # Batchify the inputs
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id,
        )
        labels = torch.nn.utils.rnn.pad_sequence(
            labels,
            batch_first=True,
            padding_value=IGNORE_INDEX,
        )
        input_ids = input_ids[:, : self.tokenizer.model_max_length]
        labels = labels[:, : self.tokenizer.model_max_length]
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id)

        # Truncate media objects if necessary
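        # (input_ids were cut at model_max_length above, which can drop media tokens from
        # the end of a sample; drop the corresponding media objects to stay aligned.)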
        for name in media:
            objects = []
            for k in range(batch_size):
                if name == "image" and not all([_ is None for _ in block_sizes[k]]):
                    actual = len(media[name][k])
                    num_large_scale_blocks = sum([x * y for x, y in block_sizes[k]])
                    num_small_scale_blocks = actual - num_large_scale_blocks
                    num_small_scale_blocks_each_img = num_small_scale_blocks // len(block_sizes[k])
                    expected_full_image = (input_ids[k] == self.tokenizer.media_token_ids[name]).sum().item()
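                    # Each surviving image token keeps its x*y large-scale blocks plus an
                    # equal share of the small-scale blocks.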
                    expected = (
                        sum([x * y for x, y in block_sizes[k][:expected_full_image]])
                        + num_small_scale_blocks_each_img * expected_full_image
                    )
                    if actual > expected:
                        logger.warning(f"Truncating the number of {name} objects from {actual} to {expected}")
                        media[name][k] = media[name][k][:expected]
                    objects.extend(media[name][k])
                    block_sizes[k] = block_sizes[k][:expected_full_image]
                else:
                    actual = len(media[name][k])
                    expected = (input_ids[k] == self.tokenizer.media_token_ids[name]).sum().item()
                    if actual > expected:
                        logger.warning(f"Truncating the number of {name} objects from {actual} to {expected}")
                        media[name][k] = media[name][k][:expected]
                    objects.extend(media[name][k])
                    if name == "image":
                        block_sizes[k] = block_sizes[k][:expected]
            media[name] = objects
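
        # Flatten per-sample metadata; samples without a given modality simply have no
        # entry for it, which the except below skips over.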
        for name in media_meta:
            objects = []
            for k in range(batch_size):
                try:
                    objects.extend(media_meta[name][k])
                except Exception:
                    continue
            media_meta[name] = objects

        # Flatten block sizes from [[bls_im1_instance1, bls_im2_instance1], [bls_im1_instance2, bls_im2_instance2], ...]
        # to [bls_im1_instance1, bls_im2_instance1, bls_im1_instance2, bls_im2_instance2, ...]
        block_sizes = sum(block_sizes, [])
        return {
            "input_ids": input_ids,
            "media": media,
            "media_config": {"image": {"block_sizes": block_sizes}, "video": {}, "speech": {}, "sound": {}},
            "labels": labels,
            "attention_mask": attention_mask,
            "media_meta": media_meta,
        }
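

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the library): it stubs out only the
    # tokenizer attributes this collator actually reads (media_tokens, media_token_ids,
    # pad_token_id, model_max_length) and feeds two hypothetical image-only samples.
    # The token id 32000 and the tensor shapes below are illustrative assumptions.
    from types import SimpleNamespace

    IMAGE_TOKEN_ID = 32000
    fake_tokenizer = SimpleNamespace(
        media_tokens=["image"],
        media_token_ids={"image": IMAGE_TOKEN_ID},
        pad_token_id=0,
        model_max_length=4096,
    )
    collator = DataCollator(fake_tokenizer)

    instances = [
        {
            "input_ids": torch.tensor([1, IMAGE_TOKEN_ID, 5, 6]),
            "labels": torch.tensor([IGNORE_INDEX, IGNORE_INDEX, 5, 6]),
            "image": [torch.zeros(3, 4, 4)],
            "frame_times": [],
        },
        {
            "input_ids": torch.tensor([1, 7]),
            "labels": torch.tensor([IGNORE_INDEX, 7]),
        },
    ]
    batch = collator(instances)
    print(batch["input_ids"].shape, batch["attention_mask"].shape, len(batch["media"]["image"]))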