File size: 7,686 Bytes
174ae06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.

# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.

from dataclasses import dataclass
from typing import Any, Dict, Sequence

import torch
from transformers import PreTrainedTokenizer

from llava.constants import IGNORE_INDEX
from llava.utils.logging import logger

__all__ = ["DataCollator"]


@dataclass
class DataCollator:
    """Collate per-sample (or pre-packed) examples into a padded training batch.

    Pads ``input_ids``/``labels``, builds the attention mask, gathers media
    objects (one list per media-token name exposed by the tokenizer), and
    truncates media objects whose tokens were cut off by the
    ``model_max_length`` truncation.
    """

    # Must expose `media_tokens`, `media_token_ids`, `pad_token_id`, and
    # `model_max_length` (the project's extended tokenizer does).
    tokenizer: PreTrainedTokenizer

    # NOTE: no explicit __init__ — @dataclass generates an identical
    # __init__(self, tokenizer), so the hand-written one was redundant.

    def __call__(self, instances: Sequence[Dict[str, Any]]) -> Dict[str, Any]:
        """Build a batch from ``instances``.

        Each instance carries either a single sample (``input_ids`` is a
        Tensor) or a pre-packed list of samples (``input_ids`` is a list of
        Tensors); both are flattened into one batch dimension.

        Returns:
            Dict with ``input_ids``, ``labels``, ``attention_mask``,
            ``media``, ``media_config`` and ``media_meta``.

        Raises:
            ValueError: if the number of media objects for a sample does not
                match the number of corresponding media tokens.
        """
        # Gather everything from the batch.
        input_ids, labels, block_sizes = [], [], []
        media = {name: [] for name in self.tokenizer.media_tokens}
        # Metadata lists are only filled for instances that actually carry
        # sound/frames, so they may end up shorter than the batch size.
        media_meta: Dict[str, list] = {
            "sound_feature_masks": [],
            "sound_embed_masks": [],
            "frame_times": [],
        }

        for instance in instances:
            if isinstance(instance["input_ids"], torch.Tensor):
                # Single-sample instance.
                input_ids.append(instance["input_ids"])
                labels.append(instance["labels"])
                for name in media:
                    objs = instance.get(name)
                    media[name].append(list(objs) if objs is not None else [])
                if instance.get("sound") is not None:
                    for name_k in media_meta:
                        if "sound" in name_k:
                            media_meta[name_k].append(list(instance.get(name_k)))
                if instance.get("video") is not None or instance.get("image") is not None:
                    for name_k in media_meta:
                        if "frame" in name_k:
                            media_meta[name_k].append(list(instance.get(name_k)))
                if "block_sizes" in instance:
                    block_sizes.append(instance["block_sizes"])
                else:
                    # One None per image means "no block expansion" downstream.
                    block_sizes.append(
                        [None for _ in range(len(instance.get("image")))] if instance.get("image") is not None else []
                    )
            else:
                # Pre-packed instance: parallel lists of per-sample tensors/objects.
                input_ids.extend(instance["input_ids"])
                labels.extend(instance["labels"])
                for name in media:
                    objs = instance.get(name)
                    if objs is None:
                        objs = [[] for _ in range(len(instance["input_ids"]))]
                    media[name].extend(objs)
                if instance.get("sound") is not None:
                    for name_k in media_meta:
                        if "sound" in name_k:
                            media_meta[name_k].extend(instance.get(name_k))
                if instance.get("video") is not None or instance.get("image") is not None:
                    for name_k in media_meta:
                        if "frame" in name_k:
                            # NOTE(review): this `append` wraps the whole list, while
                            # the sound branch above uses `extend`. Kept as-is to
                            # preserve behavior, but it looks like a copy-paste
                            # inconsistency — confirm against the packed-video path.
                            media_meta[name_k].append(list(instance.get(name_k)))
                if "block_sizes" in instance:
                    block_sizes.extend(instance["block_sizes"])
                else:
                    block_sizes.extend(
                        [[None for _ in range(len(objs))] for objs in instance.get("image")]
                        if instance.get("image") is not None
                        else [[] for _ in range(len(instance["input_ids"]))]
                    )

        batch_size = len(input_ids)

        # Check that the number of media objects (or block sizes) matches the
        # number of media tokens for every sample, before any truncation.
        for name in media:
            for k in range(batch_size):
                if name == "image" and not all(bs is None for bs in block_sizes[k]):
                    actual = len(block_sizes[k])
                else:
                    actual = len(media[name][k])
                expected = (input_ids[k] == self.tokenizer.media_token_ids[name]).sum().item()
                if actual != expected:
                    raise ValueError(
                        f"Number mismatch between {name} objects and {name} tokens. "
                        f"There are {expected} {name} tokens but {actual} {name} objects."
                    )

        # Batchify the inputs.
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id,
        )
        labels = torch.nn.utils.rnn.pad_sequence(
            labels,
            batch_first=True,
            padding_value=IGNORE_INDEX,
        )
        input_ids = input_ids[:, : self.tokenizer.model_max_length]
        labels = labels[:, : self.tokenizer.model_max_length]
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id)

        # Truncate media objects whose tokens were dropped by the
        # model_max_length truncation above.
        for name in media:
            objects = []
            for k in range(batch_size):
                if name == "image" and not all(bs is None for bs in block_sizes[k]):
                    # Block-expanded images: each image contributes x*y
                    # large-scale blocks (per block_sizes entry) plus an equal
                    # share of the remaining small-scale blocks.
                    actual = len(media[name][k])
                    num_large_scale_blocks = sum(x * y for x, y in block_sizes[k])
                    num_small_scale_blocks = actual - num_large_scale_blocks
                    num_small_scale_blocks_each_img = num_small_scale_blocks // len(block_sizes[k])
                    expected_full_image = (input_ids[k] == self.tokenizer.media_token_ids[name]).sum().item()
                    expected = (
                        sum(x * y for x, y in block_sizes[k][:expected_full_image])
                        + num_small_scale_blocks_each_img * expected_full_image
                    )
                    if actual > expected:
                        logger.warning(f"Truncating the number of {name} objects from {actual} to {expected}")
                        media[name][k] = media[name][k][:expected]
                    objects.extend(media[name][k])
                    block_sizes[k] = block_sizes[k][:expected_full_image]
                else:
                    actual = len(media[name][k])
                    expected = (input_ids[k] == self.tokenizer.media_token_ids[name]).sum().item()
                    if actual > expected:
                        logger.warning(f"Truncating the number of {name} objects from {actual} to {expected}")
                        media[name][k] = media[name][k][:expected]
                    objects.extend(media[name][k])
                    if name == "image":
                        block_sizes[k] = block_sizes[k][:expected]
            media[name] = objects

        # Flatten the metadata lists. These may be shorter than batch_size
        # (only some instances contribute), so missing entries are skipped.
        # Narrowed from a bare `except:` — only the expected failure modes
        # (missing index, non-iterable entry) are swallowed.
        for name in media_meta:
            objects = []
            for k in range(batch_size):
                try:
                    objects.extend(media_meta[name][k])
                except (IndexError, TypeError):
                    continue
            media_meta[name] = objects

        # Flatten block sizes from [[bls_im1_inst1, bls_im2_inst1], ...] to
        # [bls_im1_inst1, bls_im2_inst1, bls_im1_inst2, ...].
        # (Comprehension instead of sum(block_sizes, []), which is quadratic.)
        block_sizes = [bs for per_sample in block_sizes for bs in per_sample]
        return {
            "input_ids": input_ids,
            "media": media,
            "media_config": {"image": {"block_sizes": block_sizes}, "video": {}, "speech": {}, "sound": {}},
            "labels": labels,
            "attention_mask": attention_mask,
            "media_meta": media_meta,
        }