# Spaces: Running on A10G
""" | |
Copyright (c) 2022, salesforce.com, inc. | |
All rights reserved. | |
SPDX-License-Identifier: BSD-3-Clause | |
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
""" | |
import json | |
from typing import Iterable | |
from torch.utils.data import Dataset, ConcatDataset | |
from torch.utils.data.dataloader import default_collate | |
class BaseDataset(Dataset):
    """Base dataset that loads annotation records from JSON files.

    Each annotation file is read as JSON and the list stored under its
    top-level ``"annotations"`` key is appended to ``self.annotation``.
    After loading, every record receives a string ``"instance_id"`` equal
    to its position in the combined list.
    """

    def __init__(
        self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=None
    ):
        """
        vis_processor: processor object stored on the instance for visual inputs
        text_processor: processor object stored on the instance for text inputs
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_paths (iterable of string): paths of the JSON annotation files to
            load; ``None`` (the default) loads nothing

        Raises:
            OSError: if an annotation file cannot be opened.
            KeyError: if a file has no top-level ``"annotations"`` entry.
        """
        self.vis_root = vis_root

        # ``None`` replaces the former mutable ``[]`` default; both mean
        # "no annotation files".
        if ann_paths is None:
            ann_paths = []

        self.annotation = []
        for ann_path in ann_paths:
            # The context manager closes the handle deterministically instead
            # of leaving it to the garbage collector. JSON is read as UTF-8
            # so the result does not depend on the platform's locale.
            with open(ann_path, "r", encoding="utf-8") as ann_file:
                self.annotation.extend(json.load(ann_file)["annotations"])

        self.vis_processor = vis_processor
        self.text_processor = text_processor

        self._add_instance_ids()

    def __len__(self):
        """Return the number of loaded annotation records."""
        return len(self.annotation)

    def collater(self, samples):
        """Batch ``samples`` with PyTorch's ``default_collate``."""
        return default_collate(samples)

    def set_processors(self, vis_processor, text_processor):
        """Replace the stored visual and text processors."""
        self.vis_processor = vis_processor
        self.text_processor = text_processor

    def _add_instance_ids(self, key="instance_id"):
        """Store each record's index, as a string, under ``key``.

        Records are modified in place, so they must support item assignment.
        """
        for idx, ann in enumerate(self.annotation):
            ann[key] = str(idx)
class ConcatDataset(ConcatDataset):
    """Concatenated dataset that collates batches on their shared keys."""

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__(datasets)

    def collater(self, samples):
        """Collate ``samples`` after dropping keys not present in all of them.

        Each sample is reduced to the keys common to every sample in the
        batch, keeping its own key order, and the reduced samples are handed
        to the first underlying dataset's ``collater``.
        """
        # TODO For now only supports datasets with same underlying collater implementations

        key_sets = [set(sample.keys()) for sample in samples]
        # An empty batch has no keys in common (and nothing to intersect).
        common_keys = set.intersection(*key_sets) if key_sets else set()

        trimmed = [
            {key: sample[key] for key in sample.keys() if key in common_keys}
            for sample in samples
        ]

        return self.datasets[0].collater(trimmed)