import logging
from ...models.sam.processing_sam import SamProcessor
from transformers import PreTrainedTokenizer
from typing import Optional, Union, List, Dict, Any
import torch
import pycocotools.mask
from collections import defaultdict
from dataclasses import dataclass
from PIL import Image
import random
from . import detectron2_transforms as T
import numpy as np


logger = logging.getLogger(__name__)


@dataclass
class SCADataTransformedFeature:
    # num_regions = 35
    images: Image.Image
    pixel_values: torch.Tensor  # (3, 1024, 1024)
    original_sizes: torch.Tensor  # (2,)
    reshaped_input_sizes: torch.Tensor  # (2,)
    input_ids: List[List[int]]  # (num_regions, DIFFERENT_length)
    attention_mask: List[List[int]]  # (num_regions, DIFFERENT_length)
    input_boxes: torch.Tensor  # (num_regions, 4)
    metadata_input_boxes: torch.Tensor  # (num_regions, 4)
    metadata_captions: List[str]  # (num_regions,)
    metadata_image_id: torch.Tensor  # (num_regions,)
    metadata_region_id: torch.Tensor  # (num_regions,)


# NOTE(xiaoke): used in collator to batchify the regions
# based on the minimum number of regions in the batch
REGION_KEYS = (
    "input_points",
    "input_labels",
    "input_boxes",
    "metadata_input_boxes",
    "metadata_image_id",
    "metadata_region_id",
)
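
# Illustrative sketch only (hypothetical; the real trimming lives in the collator, not here):
#   min_regions = min(s["input_boxes"].shape[0] for s in batch if s["input_boxes"] is not None)
#   for key in REGION_KEYS:
#       if sample[key] is not None:
#           sample[key] = sample[key][:min_regions]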

# NOTE(xiaoke): used in `generate()` to remove unused keys
UNUSED_KEYS_IN_GENERATE = (
    "metadata_captions",
    "metadata_input_boxes",
    "metadata_image_id",
    "metadata_region_id",
    "task_type",
)
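
# Illustrative sketch only (hypothetical call site, not part of this file):
#   generate_inputs = {k: v for k, v in batch.items() if k not in UNUSED_KEYS_IN_GENERATE}
#   outputs = model.generate(**generate_inputs)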


class SamCaptionerDataTransform:
    def __init__(
        self,
        sam_processor: SamProcessor,
        tokenizer: PreTrainedTokenizer,
        split: Optional[str] = None,
        num_masks_per_sample: Optional[int] = None,
        data_transforms: Optional[Any] = None,
    ):
        if split is None:
            raise ValueError("split in DataTransform must be provided")
        if split not in ["train", "inference"]:
            raise ValueError(f"split in DataTransform must be one of ['train', 'inference'], got {split}")

        if num_masks_per_sample is None and split == "train":
            num_masks_per_sample = 64
            logger.info(f"num_masks_per_sample not provided, defaulting to {num_masks_per_sample}")

        self.sam_processor = sam_processor
        self.tokenizer = tokenizer
        self.split = split
        self.num_masks_per_sample = num_masks_per_sample

        logger.info(f"[{self.__class__}] split: {split}")
        max_length = self.tokenizer.model_max_length
        logger.info(f"max_length is {max_length} for caption tokens")

        if data_transforms is None:
            logger.info("data_transforms not provided, defaulting to no augmentation.")

        self.augmentation = None  # NOTE: stays None when split == "inference" or data_transforms is None.
        if data_transforms is not None and split == "train":
            logger.info("Applying data augmentations during training")
            logger.info(f"data_transforms is {data_transforms}")
            min_scale = data_transforms.min_scale
            max_scale = data_transforms.max_scale
            image_size = data_transforms.image_size

            augmentation = []
            augmentation.append(T.RandomFlip(horizontal=True))
            augmentation.append(
                T.ResizeScale(
                    min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size
                )
            )
            # NOTE: use pad_value=0 since the images are already normalized (with the ImageNet mean and std)
            augmentation.append(T.FixedSizeCrop(crop_size=(image_size, image_size), pad_value=0))
            self.augmentation = augmentation
            logger.info(f"augmentation is {augmentation}")

            image_mean, image_std = (
                sam_processor.image_processor.image_mean,
                sam_processor.image_processor.image_std,
            )
            image_mean = np.array(image_mean).reshape(1, 1, 3)
            image_std = np.array(image_std).reshape(1, 1, 3)

            self._normalize_np_image = lambda x: (x / 255 - image_mean) / image_std
            self._denormalize_np_image = lambda x: ((x * image_std + image_mean) * 255).astype(np.uint8)
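            # Illustrative round trip (assuming `img` is a uint8 RGB array of shape (H, W, 3)):
            #   self._denormalize_np_image(self._normalize_np_image(img))  # ~= img, up to uint8 rounding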

    def __call__(self, examples):
        """Transform a batch of raw dataset examples into model-ready tensors.

        Args:
            examples (dict): a batch with at least the keys "image", "regions",
                "image_id", and "task_type".

        Returns:
            dict: tensors
                # sam related
                - images: (batch_size, 3, H, W)
                - pixel_values: (batch_size, 3, H, W)
                - input_points
                - input_labels
                - input_boxes

                # used to crop and post-processing
                - original_sizes (batch_size, 2)
                - reshaped_input_sizes (batch_size, 2)

                # caption related
                - input_ids
                - attention_mask
                - labels

                # used for post-processing, or evaluation
                # maybe need to be removed for `.generate()`
                - metadata_input_boxes (batch_size, num_regions, 4)
                - metadata_captions: List[List[str]]
                - metadata_image_id: (batch_size,)
                - metadata_region_id: (batch_size, num_regions)
        """
        batch_image = examples["image"]
        batch_transforms = None
        # NOTE: Apply augmentations to images from detectron2
        if self.augmentation is not None:
            # NOTE: convert image to RGB, since there may be grayscale images.
            # transformers/image_transforms.py
            # transformers/models/sam/image_processing_sam.py
            # NOTE: convert to np array and normalize with image_mean and image_std
            np_image = [self._normalize_np_image(np.array(image_.convert("RGB"))) for image_ in batch_image]
            batch_aug_image = []
            batch_aug_transforms = []
            for image_ in np_image:
                aug_image_, aug_transforms_ = T.apply_transform_gens(self.augmentation, image_)
                batch_aug_image.append(self._denormalize_np_image(aug_image_))
                batch_aug_transforms.append(aug_transforms_)
            batch_image = batch_aug_image
            batch_transforms = batch_aug_transforms

        # NOTE: Process images with huggingface's preprocessor
        encoding_sam_processor = self.sam_processor(batch_image, return_tensors="pt")
        original_sizes = encoding_sam_processor["original_sizes"]
        # print(examples)
        encoding_region_processor = defaultdict(list)
        batch_regions = examples["regions"]
        batch_image_id = examples["image_id"]
        for idx, regions in enumerate(batch_regions):
            if batch_transforms is not None:
                transforms = batch_transforms[idx]
            else:
                transforms = None
            _encoding_region_processor = self.process_regions(
                original_sizes[idx : idx + 1], regions, batch_image_id[idx], transforms
            )
            for k, v in _encoding_region_processor.items():
                encoding_region_processor[k].append(v)

        return {
            "images": batch_image,
            **encoding_sam_processor,
            **encoding_region_processor,
            "task_type": examples["task_type"],
        }

    def process_regions(self, original_sizes, regions, image_id_to_be_checked, transforms=None):
        if self.split == "train":
            num_regions = len(regions)
            region_index_ls = torch.randperm(num_regions)
        else:
            region_index_ls = torch.arange(len(regions))

        # boxes, original_caption_ls, mask_ls = self.prepare_regions(regions, region_index_ls)
        # TODO: support mask and point. Now only support box
        original_boxes, original_caption_ls = self.load_regions_by_index_ls(regions, region_index_ls)

        # NOTE: Apply augmentations to regions from detectron2
        if transforms is not None:
            boxes = transforms.apply_box(original_boxes)
            # NOTE: detectron2 transforms the boxes as numpy arrays, so we need to convert them back to a tensor.
            boxes = torch.from_numpy(boxes)
        else:
            boxes = original_boxes
        # print(f"boxes after transform: {boxes}")
        # NOTE: clip and clamp, de-empty the boxes
        # TODO: support mask and point. Now only support box
        # NOTE: the input shape comes from the transformed images, so we normalize the boxes against the transformed images.
        # https://github.com/UX-Decoder/Semantic-SAM/blob/e3b9e758ba8832eebdf30e00bc4fe4b45963f010/datasets/dataset_mappers/coco_instance_new_baseline_dataset_mapper.py#L144
        # NOTE: `original_sizes` are the image sizes after the transform if `transforms` is not None.
        # NOTE: filter boxes, region_index_ls, original_boxes, original_caption_ls
        boxes, region_index_ls, original_boxes, original_caption_ls = normalize_and_filter_transformed_boxes(
            original_sizes, boxes, region_index_ls, original_boxes, original_caption_ls, mode=self.split
        )
        # print(f'boxes after normalize_and_filter_transformed_boxes: {boxes}')
        if self.split == "train":
            # NOTE: sample at most `num_masks_per_sample` regions.
            # NOTE: the number of regions varies widely (from 3 to 250+), so a sample may not have enough. The collator further handles samples with fewer than `num_masks_per_sample` regions.
            boxes = boxes[: self.num_masks_per_sample]
            region_index_ls = region_index_ls[: self.num_masks_per_sample]
            original_boxes = original_boxes[: self.num_masks_per_sample]
            original_caption_ls = original_caption_ls[: self.num_masks_per_sample]

        # NOTE: Process regions with huggingface's preprocessor
        input_boxes = boxes.unsqueeze(0)  # (batch_size (1), num_regions, 4)
        assert len(original_sizes) == len(input_boxes)
        # NOTE: SAM's processor only converts prompts to tensors and scales them to 1024.
        # It does NOT clip, clamp, or de-empty the points, boxes, and masks.
        # NOTE: We have to make sure every sample has the same keys, or /anaconda/envs/sca-v2/lib/python3.9/site-packages/datasets/arrow_dataset.py:2799 raises an error due to missing values.
        # XXX: unused SAM keys are set to None, but empty samples are also None. Thus we use "input_ids" and "attention_mask" to filter out the empty samples.
        encoding_prompts = {
            "input_boxes": None,
            "input_points": None,
            "input_labels": None,
        }
        try:
            encoding_prompts.update(
                self.sam_processor.process_prompts(original_sizes, input_boxes=input_boxes, return_tensors="pt")
            )
        except Exception as e:
            logger.warning(
                f"Error in processing prompts: {e}. This may be due to there being no visual prompts left, "
                "as they can all be removed by the normalization and filtering of boxes "
                "(e.g., under large scale jittering)."
            )

        # NOTE(xiaoke): remove the batch dimension (=1)
        for k, v in encoding_prompts.items():
            if v is not None:
                encoding_prompts[k] = v.squeeze(0)

        # NOTE(xiaoke): first truncate based on `model_max_length`.
        # the dynamic batch padding happens in the collator.
        # NOTE: We have to make sure every sample has the same keys, or /anaconda/envs/sca-v2/lib/python3.9/site-packages/datasets/arrow_dataset.py:2799 raises an error due to missing values.
        # NOTE: an empty sample has both "input_ids": None and "attention_mask": None.
        encoding_tokenizer = {"input_ids": None, "attention_mask": None}
        try:
            encoding_tokenizer.update(self.process_tokens(original_caption_ls))
        except Exception as e:
            logger.warning(
                f"Error in processing tokens: {e}. This may be due to there being no captions left, "
                "as their regions can all be removed by the normalization and filtering of boxes "
                "(e.g., under large scale jittering)."
            )

        # NOTE(xiaoke): get metadata by index
        metadata_image_id = torch.tensor([regions[index]["image_id"] for index in region_index_ls])
        metadata_region_id = torch.tensor([regions[index]["region_id"] for index in region_index_ls])
        if (image_id_to_be_checked != metadata_image_id).any():
            logger.warning(
                f"Some regions have an image_id different from the true image_id: {image_id_to_be_checked} != {metadata_image_id}"
            )

        return {
            **encoding_tokenizer,
            **encoding_prompts,
            "metadata_input_boxes": original_boxes,
            "metadata_captions": original_caption_ls,
            "metadata_image_id": metadata_image_id,
            "metadata_region_id": metadata_region_id,
        }

    def load_regions_by_index_ls(self, regions, region_index_ls):
        # mask_ls = []
        box_ls = []
        original_caption_ls = []
        for region_index in region_index_ls:
            region = regions[region_index]

            # TODO(xiaoke): add mask to support point prompts
            # mask = torch.tensor(pycocotools.mask.decode(region["mask"]))
            box = torch.tensor(
                [
                    region["x"],
                    region["y"],
                    region["x"] + region["width"],
                    region["y"] + region["height"],
                ]
            )
            # XXX(xiaoke): How to support multiple gt captions?
            if self.split == "train":
                # NOTE(xiaoke): Now we only randomly take one caption among multiple gt captions.
                gt_caption = random.choice(region["phrases"])
            else:
                gt_caption = region["phrases"][0]

            # mask_ls.append(mask)
            box_ls.append(box)
            original_caption_ls.append(gt_caption)
        try:
            boxes = torch.stack(box_ls)  # (num_regions, 4)
        except Exception as e:
            logger.warning(f"Failed to stack region boxes. regions: {regions}")
            raise e
        # print("boxes from load_regions_by_index_ls", boxes)
        # return boxes, original_caption_ls, mask_ls
        return boxes, original_caption_ls

    def process_tokens(self, texts):
        return self.tokenizer(texts, truncation=True)


def normalize_and_filter_transformed_boxes(
    original_sizes, transformed_boxes, region_index_ls, original_boxes, original_caption_ls, mode
):
    # print(f"input boxes: {transformed_boxes}")
    # print(f"original sizes: {original_sizes}")
    # print(f"region index ls: {region_index_ls}")
    # print(f"original boxes: {original_boxes}")
    
    # NOTE: clamp the box coordinates to be non-negative; empty boxes are filtered out below (in train mode).
    transformed_boxes = transformed_boxes.clamp(min=0)
    # NOTE: original sizes are (y, x), while transformed boxes are (x, y, x, y).
    # Flip as in `transform_instance_annotations` https://github.com/facebookresearch/detectron2/blob/2409af0bf0d4bdcc685feb6d2c7fd659828acac4/detectron2/data/detection_utils.py#L263
    transformed_boxes = torch.minimum(transformed_boxes, original_sizes.expand(2, 2).flatten().flip(0))
    # NOTE(yunfei): further clamp boxes that exceed the original sizes to at most (size - 1).
    max_sizes = original_sizes.flip(1).repeat(1, 2)
    transformed_boxes = torch.min(transformed_boxes, max_sizes - 1)
    if mode == "train":
        keep_indices = _box_nonempty(transformed_boxes)

        transformed_boxes = transformed_boxes[keep_indices]
        region_index_ls = region_index_ls[keep_indices]
        original_boxes = original_boxes[keep_indices]
        # NOTE: the implicit cast from tensor to numpy array may cause an error when its size is 1:
        # it casts [True] to [1], leading to an index-out-of-range error.
        # Thus we need to explicitly cast the tensor to a numpy array.
        original_caption_ls = np.array(original_caption_ls)[keep_indices.numpy()].tolist()
    return transformed_boxes, region_index_ls, original_boxes, original_caption_ls


def _box_nonempty(box, threshold=1e-5):
    # Copy from: https://github.com/facebookresearch/detectron2/blob/2bd05b42983468c50a1c80d5e7dc3952980e1cd4/detectron2/structures/boxes.py#L199
    # The threshold: https://github.com/facebookresearch/detectron2/blob/2409af0bf0d4bdcc685feb6d2c7fd659828acac4/detectron2/data/detection_utils.py#L498
    widths = box[..., 2] - box[..., 0]
    heights = box[..., 3] - box[..., 1]
    keep = (widths > threshold) & (heights > threshold)

    return keep
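
# For example (made-up values): a box collapsed to zero width by cropping is dropped.
#   _box_nonempty(torch.tensor([[10.0, 10.0, 50.0, 40.0], [30.0, 30.0, 30.0, 60.0]]))
#   # -> tensor([ True, False])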


class SCADataTransform(SamCaptionerDataTransform):
    def process_tokens(self, texts):
        # XXX(xiaoke): We trim two tokens to leave room for the "virtual" <BOS> and the true <EOS>.
        outputs = self.tokenizer(texts, truncation=True, max_length=self.tokenizer.model_max_length - 2)
        # NOTE(xiaoke): add the eos token.
        # XXX(xiaoke): since we do not add the <BOS> token in either the training or the inference stage,
        # we need to shift the labels to the right by one. However, this leaves the labels with one more token than the input_ids,
        # and the max length of the labels exceeds tokenizer.model_max_length by 1, which causes an error in trainer eval when computing the eval loss, due to a mismatched last dim during cross-batch-region padding.
        # Our solution is to trim the input_ids and attention_mask by 1.
        # Check `src/data/collator.py:SCADataCollator:prepare_labels` for details.
        for i in range(len(outputs["input_ids"])):
            outputs["input_ids"][i] += [self.tokenizer.eos_token_id]
            outputs["attention_mask"][i] += [1]
        return outputs
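

# Minimal smoke test sketch for the box-filtering utilities above. The shapes and values are
# made up for illustration; this block is not part of the training or inference pipeline.
if __name__ == "__main__":
    # A single (h, w) = (480, 640) image, with one valid box and one degenerate (zero-width) box.
    _original_sizes = torch.tensor([[480, 640]])
    _boxes = torch.tensor([[10.0, 10.0, 50.0, 40.0], [30.0, 30.0, 30.0, 60.0]])
    _indices = torch.arange(len(_boxes))
    _captions = ["a valid region", "a degenerate region"]
    _kept_boxes, _kept_indices, _kept_original_boxes, _kept_captions = normalize_and_filter_transformed_boxes(
        _original_sizes, _boxes.clone(), _indices, _boxes, _captions, mode="train"
    )
    # In train mode the degenerate box is filtered out, leaving a single region.
    print(_kept_boxes, _kept_indices, _kept_original_boxes, _kept_captions)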