| import logging |
| import os |
| import random |
| from dataclasses import dataclass |
| from multiprocessing import Value |
| import numpy as np |
| from training.utils import mask2box |
| import torch |
| from PIL import Image |
| from torch.utils.data import Dataset, DataLoader |
| from torch.utils.data.distributed import DistributedSampler |
| from open_clip.transform import get_scale |
| from pycocotools.coco import COCO |
| from training.coco_api import COCOPanoptic |
| from panopticapi import utils |
| import io |
| |
try:
    # Optional dependency: the petrel (Ceph) client is only used by datasets
    # configured with a non-empty ceph root.
    from petrel_client.client import Client
except ImportError:
    # Narrowed from a bare ``except:`` so that KeyboardInterrupt/SystemExit
    # and unrelated errors are not silently swallowed at import time.
    Client = None
| from open_clip.transform import ResizeLongest |
|
|
| |
| from torchvision.transforms import RandomHorizontalFlip, Compose |
| from training.custom_transforms import CustomRandomResize, CustomRandomCrop |
|
|
|
|
class ProposalDistillDataset(Dataset):
    """COCO-format dataset yielding an image plus up to ``max_anns`` region crops.

    Each item is ``(new_image, boxes_template, image_crops)``:

    * ``new_image``: ``transforms[0]`` applied to the full image.
    * ``boxes_template``: ``(max_anns, 5)`` tensor. Columns 0-3 hold the
      annotation box as ``x0, y0, x1, y1``, multiplied by ``get_scale`` and
      divided by the transformed image width/height. Column 4 is 1.0 for a
      filled slot and stays 0.0 for unused slots.
    * ``image_crops``: ``(max_anns, 3, *crop_size)`` tensor holding
      ``transforms[1]`` applied to a crop enlarged 1.5x around each box.
    """

    def __init__(self, input_filename, transforms, image_root,
                 crop_size=224,
                 tokenizer=None, args=None):
        """
        Args:
            input_filename: path of the COCO-format annotation file.
            transforms: indexable pair; ``transforms[0]`` is applied to the full
                image and ``transforms[1]`` to each crop.
            image_root: local directory used when ``args.train_ceph_root`` is "".
            crop_size: int or 2-element sequence giving the spatial size of the
                crop tensors.
            tokenizer: stored as ``self.tokenize``; not used by this class.
            args: namespace providing ``min_size``, ``max_size`` and
                ``train_ceph_root``.
        """
        logging.debug(f'Loading coco style data from {input_filename}.')
        self.coco = COCO(input_filename)
        logging.debug('Done loading data.')
        self.transforms = transforms
        self.tokenize = tokenizer
        self.image_root = image_root
        self.image_ids = list(self.coco.imgs.keys())
        # Fixed number of box slots per sample; unused slots stay zero.
        self.max_anns = 20
        if not isinstance(crop_size, (tuple, list)):
            crop_size = [crop_size, crop_size]
        self.crop_size = crop_size
        self.args = args

        # Boxes whose area is outside [min_size**2, max_size**2] are skipped.
        self.min_size = args.min_size
        self.max_size = args.max_size

        # A non-empty ceph root switches image reads to the petrel client.
        self.ceph_root = args.train_ceph_root
        self.use_ceph = (self.ceph_root != "")
        self.FILE_CLIENT = None

    def read_image(self, image_name):
        """Open ``image_name`` from Ceph or from ``image_root``.

        Returns:
            The PIL image, or ``None`` if it cannot be opened or either side is
            shorter than 10 pixels (a message is printed in both cases).
        """
        if self.use_ceph:
            image_path = os.path.join(self.ceph_root, image_name)
            if self.FILE_CLIENT is None:
                # Built on first use rather than in __init__, presumably so each
                # DataLoader worker creates its own client -- TODO confirm.
                self.FILE_CLIENT = Client()
            try:
                img_bytes = self.FILE_CLIENT.get(image_path)
                buff = io.BytesIO(img_bytes)
                image = Image.open(buff)
            except Exception:
                # Best effort: the caller resamples another index on None.
                print(f"Cannot load {image_path}", flush=True)
                return None
        else:
            image_path = os.path.join(self.image_root, image_name)
            try:
                image = Image.open(image_path)
            except Exception:
                print(f"Cannot load {image_path}", flush=True)
                return None

        width, height = image.size
        if width < 10 or height < 10:
            print(f"Invalid image, size {image.size}", flush=True)
            return None

        return image

    def __len__(self):
        return len(self.image_ids)

    def _image_name(self, image_id):
        """Return the relative path of ``image_id``.

        Uses ``file_name`` when present, otherwise the last two components of
        ``coco_url``.

        Raises:
            KeyError: if the image record has neither field.
        """
        image_info = self.coco.imgs[image_id]
        if 'file_name' in image_info:
            return image_info['file_name']
        if 'coco_url' not in image_info:
            raise KeyError(f"Image {image_id} has neither 'file_name' nor 'coco_url'")
        coco_url = image_info['coco_url'].split('/')
        return os.path.join(coco_url[-2], coco_url[-1])

    @staticmethod
    def _expanded_box(x, y, w, h, img_w, img_h, factor=0.75):
        """Return ``(x0, y0, x1, y1)`` centred on box ``(x, y, w, h)``.

        The result spans ``2 * factor`` times the box size and is clipped to
        ``[0, img_w] x [0, img_h]``.
        """
        cx, cy = x + w * 0.5, y + h * 0.5
        return (max(cx - w * factor, 0), max(cy - h * factor, 0),
                min(cx + w * factor, img_w), min(cy + h * factor, img_h))

    def __getitem__(self, idx):
        # Unreadable images are replaced by a uniformly drawn other index. This
        # is a loop rather than recursion, so a long run of bad files cannot
        # exhaust Python's recursion limit.
        while True:
            image_id = self.image_ids[idx]
            old_image = self.read_image(self._image_name(image_id))
            if old_image is not None:
                break
            idx = random.choice(range(len(self)))

        img_w, img_h = old_image.width, old_image.height
        new_image = self.transforms[0](old_image)
        scale = get_scale(old_image, new_image)

        anns = self.coco.imgToAnns[image_id]
        boxes_template = torch.zeros(self.max_anns, 4 + 1)
        image_crops = torch.zeros(self.max_anns, 3, *self.crop_size)

        # Visit at most max_anns annotations in random order. Slots whose box
        # fails the area filter are left as zeros (validity flag 0).
        indices = list(range(len(anns)))
        random.shuffle(indices)
        num_valid_boxes = 0
        min_area, max_area = self.min_size ** 2, self.max_size ** 2
        for slot, ann_idx in enumerate(indices[:self.max_anns]):
            x, y, w, h = anns[ann_idx]['bbox']
            if w * h < min_area or w * h > max_area:
                continue
            num_valid_boxes += 1
            crop_box = self._expanded_box(x, y, w, h, img_w, img_h)
            image_crops[slot] = self.transforms[1](old_image.crop(crop_box))
            boxes_template[slot] = torch.tensor([x, y, x + w, y + h, 1.0])

        if num_valid_boxes == 0:
            # Guarantee at least one box: fall back to the top-left quarter.
            boxes_template[0] = torch.tensor([0, 0, img_w / 4, img_h / 4, 1.0])
            image_crops[0] = self.transforms[1](old_image.crop((0, 0, img_w // 4, img_h // 4)))

        _, h, w = new_image.shape

        # Apply the factor from get_scale, then divide by the transformed
        # image's width/height.
        boxes_template[:, :4] *= scale
        boxes_template[:, [0, 2]] /= w
        boxes_template[:, [1, 3]] /= h

        return new_image, boxes_template, image_crops
|
|
|
|
class GridDistillDataset(Dataset):
    """COCO-format dataset pairing each image with crops of a random grid.

    For every item a grid shape ``(M, N)`` (M rows, N columns) is drawn from
    ``self.choices``. Up to ``max_anns`` of its cells are cropped from the
    decoded image and passed through ``transforms[1]``.

    Items are ``(new_image, boxes_template, image_crops_template)``, where
    ``boxes_template`` is ``(max_anns, 5)``. Its columns are the cell
    ``x0, y0, x1, y1``, multiplied by ``get_scale`` and divided by the
    transformed width/height, followed by a 1.0 validity flag (0.0 in
    unused rows).
    """

    def __init__(self,
                 input_filename, transforms, image_root,
                 max_split=16,
                 crop_size=224,
                 pre_transforms=False,
                 ceph_root="", args=None):
        """
        Args:
            input_filename: path of the COCO-format annotation file.
            transforms: indexable; ``transforms[0]`` is applied to the full image
                and ``transforms[1]`` to each crop.
            image_root: local directory used when ``ceph_root`` is "".
            max_split: largest grid side considered by ``_init_choices``.
            crop_size: int or 2-element sequence giving the crop tensor size.
            pre_transforms: whether to build ``self.pre_transforms``.
            ceph_root: petrel/Ceph prefix; a non-empty value enables Ceph reads.
            args: namespace providing ``train_ratio``, ``max_boxes``,
                ``crop_scale`` and optionally ``seed``.
        """
        self._init_choices(max_split)
        logging.debug(f'Loading coco caption style data from {input_filename}.')
        self.coco = COCO(input_filename)
        logging.debug('Done loading data.')
        self.transforms = transforms
        self.image_root = image_root
        self.args = args
        # Subsample with a dedicated RNG seeded from args, not the global one.
        # That way every process building this dataset (e.g. each DDP rank)
        # keeps the same images, as DistributedSampler assumes.
        self.image_ids = self._subsample_ids(
            list(self.coco.imgs.keys()), args.train_ratio, getattr(args, 'seed', 0))
        self.max_anns = args.max_boxes
        if not isinstance(crop_size, (tuple, list)):
            crop_size = [crop_size, crop_size]
        self.crop_size = crop_size
        self._init_boxes()
        self.ceph_root = ceph_root
        self.use_ceph = (ceph_root != "")
        self.FILE_CLIENT = None
        if pre_transforms:
            # NOTE(review): self.pre_transforms is built here but never applied
            # in __getitem__, so this option currently has no effect. Confirm
            # whether it should run on the decoded image.
            self.pre_transforms = Compose([
                CustomRandomResize(scale=(0.5, 2.0)),
                CustomRandomCrop(size=self.transforms[0].transforms[0].max_size),
                RandomHorizontalFlip()])
        else:
            self.pre_transforms = None

    @staticmethod
    def _subsample_ids(image_ids, train_ratio, seed=0):
        """Return a seeded-random ``train_ratio`` fraction of ``image_ids``.

        The global ``random`` state is neither used nor modified. When
        ``train_ratio >= 1.0`` the input list is returned unchanged.
        """
        if train_ratio >= 1.0:
            return image_ids
        num_images = int(len(image_ids) * train_ratio)
        shuffled = list(image_ids)
        random.Random(seed).shuffle(shuffled)
        return shuffled[:num_images]

    def read_image(self, image_name):
        """Open ``image_name`` from Ceph or from ``image_root``.

        Returns:
            The PIL image, or ``None`` if it cannot be opened or either side is
            shorter than 10 pixels (a message is printed in both cases).
        """
        if self.use_ceph:
            image_path = os.path.join(self.ceph_root, image_name)
            if self.FILE_CLIENT is None:
                # Built on first use rather than in __init__, presumably so each
                # DataLoader worker creates its own client -- TODO confirm.
                self.FILE_CLIENT = Client()
            try:
                img_bytes = self.FILE_CLIENT.get(image_path)
                buff = io.BytesIO(img_bytes)
                image = Image.open(buff)
            except Exception:
                # Best effort: the caller resamples another index on None.
                print(f"Cannot load {image_path}", flush=True)
                return None
        else:
            image_path = os.path.join(self.image_root, image_name)
            try:
                image = Image.open(image_path)
            except Exception:
                print(f"Cannot load {image_path}", flush=True)
                return None

        width, height = image.size
        if width < 10 or height < 10:
            print(f"Invalid image, size {image.size}", flush=True)
            return None

        return image

    def _init_choices(self, M=16):
        """Enumerate the grid shapes ``(m, n)`` into ``self.choices``.

        ``m`` runs over ``2..M``. For each ``m``, ``n`` runs from
        ``(m + 1) // 2 + 1`` to ``min(2 * m, M)`` inclusive.
        """
        self.choices = [
            (m, n)
            for m in range(2, M + 1)
            for n in range((m + 1) // 2 + 1, min(m * 2 + 1, M + 1))
        ]

    def __len__(self):
        return len(self.image_ids)

    def _init_boxes(self):
        """Precompute normalised cell boxes for every shape in ``self.choices``.

        ``self.box_templates[(M, N)]`` is an ``(M * N, 4)`` tensor of
        ``x0, y0, x1, y1`` in ``[0, 1]``, ordered row by row.
        """
        box_templates = {}
        for choice in self.choices:
            M, N = choice
            # 'xy' indexing: both grids have shape (M + 1, N + 1); x varies
            # along columns and y along rows.
            grid_x, grid_y = torch.meshgrid(torch.linspace(0, 1, N + 1), torch.linspace(0, 1, M + 1),
                                            indexing='xy')
            x0y0s = torch.stack([grid_x[:M, :N], grid_y[:M, :N]], dim=-1)
            x1y1s = torch.stack([grid_x[1:, 1:], grid_y[1:, 1:]], dim=-1)
            pseudo_boxes = torch.cat([x0y0s, x1y1s], dim=-1).view(-1, 4)
            assert pseudo_boxes.shape[0] == M * N
            box_templates[choice] = pseudo_boxes
        self.box_templates = box_templates

    def _obtain_image_crops(self, image, choice):
        """Crop a random subset (at most ``max_anns``) of grid ``choice``'s cells.

        Returns:
            ``(crops, boxes)``. ``crops`` stacks ``transforms[1]`` applied to
            each crop. ``boxes`` holds the matching cells in pixel
            ``x0, y0, x1, y1``, before any ``crop_scale`` enlargement.
        """
        img_w, img_h = image.size
        normed_boxes = self.box_templates[choice]
        indices = list(range(len(normed_boxes)))
        random.shuffle(indices)
        indices = indices[:self.max_anns]
        boxes = normed_boxes * torch.tensor([img_w, img_h, img_w, img_h])
        crop_scale = self.args.crop_scale
        image_crops = []
        for idx in indices:
            x0, y0, x1, y1 = boxes[idx].tolist()
            if crop_scale > 1.0:
                # Enlarge the cell by crop_scale around its centre, clipped to the image.
                box_w, box_h = x1 - x0, y1 - y0
                cx, cy = (x1 + x0) / 2, (y1 + y0) / 2
                delta_factor = 0.5 * crop_scale
                x0, y0, x1, y1 = max(cx - box_w * delta_factor, 0), max(cy - box_h * delta_factor, 0), \
                    min(cx + box_w * delta_factor, img_w), min(cy + box_h * delta_factor, img_h)
            image_crops.append(self.transforms[1](image.crop((x0, y0, x1, y1))))

        return torch.stack(image_crops), boxes[indices]

    def _image_name(self, image_id):
        """Return the relative path of ``image_id``.

        Uses ``file_name`` when present, otherwise the last two components of
        ``coco_url``.

        Raises:
            KeyError: if the image record has neither field.
        """
        image_info = self.coco.imgs[image_id]
        if 'file_name' in image_info:
            return image_info['file_name']
        if 'coco_url' not in image_info:
            raise KeyError(f"Image {image_id} has neither 'file_name' nor 'coco_url'")
        coco_url = image_info['coco_url'].split('/')
        return os.path.join(coco_url[-2], coco_url[-1])

    def __getitem__(self, idx):
        # Unreadable images are replaced by a uniformly drawn other index. A
        # loop (not recursion) keeps a run of bad files from hitting the
        # recursion limit.
        while True:
            image_id = self.image_ids[idx]
            old_image = self.read_image(self._image_name(image_id))
            if old_image is not None:
                break
            idx = random.choice(range(len(self)))
        new_image = self.transforms[0](old_image)

        scale = get_scale(old_image, new_image)
        boxes_template = torch.zeros(self.max_anns, 4 + 1)
        image_crops_template = torch.zeros(self.max_anns, 3, *self.crop_size)
        image_crops, boxes = self._obtain_image_crops(old_image,
                                                      random.choice(self.choices))
        assert image_crops.shape[0] == boxes.shape[0]
        _, h, w = new_image.shape

        # Apply the factor from get_scale, then divide by the transformed
        # image's width/height.
        boxes[:, :4] *= scale
        boxes[:, [0, 2]] /= w
        boxes[:, [1, 3]] /= h

        num_boxes = boxes.shape[0]
        boxes_template[:num_boxes, :4] = boxes
        boxes_template[:num_boxes, 4] = 1.0
        image_crops_template[:num_boxes] = image_crops

        return new_image, boxes_template, image_crops_template
|
|
|
|
class COCOPanopticDataset(Dataset):
    """COCO panoptic dataset with per-segment boxes, crops and masks.

    Each item is ``(new_image, boxes_template, image_crops, gt_masks,
    masked_image_crops)``, with one row per segment (at most ``max_anns``):

    * ``boxes_template`` ``(max_anns, 8)``: ``x0, y0, x1, y1`` (multiplied by
      ``get_scale`` and divided by the transformed width/height), class label,
      validity flag 1.0, box area ``w * h`` in original-image units, and
      ``isthing``.
    * ``image_crops`` / ``masked_image_crops``: ``transforms[1]`` applied to a
      crop around the segment. In the masked variant, pixels outside the
      segment are set to 114.
    * ``gt_masks``: the segment mask after ``segm_transform``, stored in an
      ``S x S`` slot with ``S = segm_transform.max_size``.

    Rows for segments rejected by the area filter stay zero.
    """

    def __init__(self, input_filename, transforms, image_root, embed_path,
                 segm_root,
                 crop_size=224,
                 tokenizer=None,
                 downsample_factor=16,
                 min_size=8, max_size=1024):
        """
        Args:
            input_filename: panoptic annotation file read by ``COCOPanoptic``.
            transforms: indexable; ``transforms[0]`` (whose first sub-transform
                exposes ``max_size``) is applied to the image, ``transforms[1]``
                to crops.
            image_root: directory joined with each image's ``file_name``.
            embed_path: ``.npy`` file loaded into ``self.embeddings`` (unused here).
            segm_root: directory joined with each image's ``segm_file``.
            crop_size: int or 2-element sequence giving the crop tensor size.
            tokenizer: stored as ``self.tokenize``; not used by this class.
            downsample_factor: mask resolution divisor relative to the image size.
            min_size, max_size: segments whose box area is outside
                ``[min_size**2, max_size**2]`` are skipped.
        """
        logging.debug(f'Loading coco caption style data from {input_filename}.')
        self.coco = COCOPanoptic(input_filename)
        logging.debug('Done loading data.')
        self.transforms = transforms
        self.tokenize = tokenizer
        self.image_root = image_root
        self.embeddings = np.load(embed_path)
        self.image_ids = list(self.coco.imgs.keys())
        num_annos = [len(anns) for anns in self.coco.imgToAnns.values()]
        # default=0 avoids a ValueError on an annotation file with no segments.
        self.max_anns = min(max(num_annos, default=0), 100)
        if not isinstance(crop_size, (tuple, list)):
            crop_size = [crop_size, crop_size]
        self.crop_size = crop_size
        # Honour the caller's area bounds; the factory passes args.min_size/max_size.
        self.min_size = min_size
        self.max_size = max_size
        self.segm_root = segm_root
        self.downsample_factor = downsample_factor
        self.segm_transform = ResizeLongest(max_size=self.transforms[0].transforms[0].max_size // downsample_factor,
                                            fill=0)

        # Map (possibly sparse) category ids to contiguous labels 0..K-1.
        cat_ids = sorted([cat['id'] for cat in self.coco.cats.values()])
        self.cat_id2label = {cat_id: label for label, cat_id in enumerate(cat_ids)}

    def __len__(self):
        return len(self.image_ids)

    @staticmethod
    def _load_segm(segm_path):
        """Load the segmentation image and map it to per-pixel segment ids via ``rgb2id``."""
        # ``with`` closes the file once the pixels have been copied into numpy.
        with Image.open(segm_path) as segm_image:
            segmentation = np.array(segm_image, dtype=np.uint8)
        return utils.rgb2id(segmentation)

    @staticmethod
    def _expanded_box(x, y, w, h, img_w, img_h, factor=0.75):
        """Return ``(x0, y0, x1, y1)`` centred on ``(x, y, w, h)``.

        The result is ``2 * factor`` times the box size, clipped to the image.
        """
        cx, cy = x + w * 0.5, y + h * 0.5
        return (max(cx - w * factor, 0), max(cy - h * factor, 0),
                min(cx + w * factor, img_w), min(cy + h * factor, img_h))

    @staticmethod
    def _mask_outside(np_image, mask, fill=114):
        """Return a copy of ``np_image`` with every pixel outside ``mask`` set to ``fill``."""
        masked = np_image.copy()
        masked[~mask] = fill
        return masked

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        image_info = self.coco.imgs[image_id]
        image_path = os.path.join(self.image_root, image_info['file_name'])
        segm_path = os.path.join(self.segm_root, image_info['segm_file'])
        segm_map = self._load_segm(segm_path)

        old_image = Image.open(image_path)
        img_w, img_h = old_image.width, old_image.height
        new_image = self.transforms[0](old_image)
        scale = get_scale(old_image, new_image)
        # Converted once here instead of once (plus two copies) per segment.
        np_image = np.asarray(old_image)

        anns = self.coco.imgToAnns[image_id]
        boxes_template = torch.zeros(self.max_anns, 4 + 2 + 1 + 1)
        image_crops = torch.zeros(self.max_anns, 3, *self.crop_size)
        gt_masks = torch.zeros(self.max_anns, self.segm_transform.max_size,
                               self.segm_transform.max_size)
        masked_image_crops = torch.zeros(self.max_anns, 3, *self.crop_size)

        min_area, max_area = self.min_size ** 2, self.max_size ** 2
        # zip with range() caps the loop at max_anns segments.
        for i, ann in zip(range(self.max_anns), anns):
            cat_id = ann['category_id']
            is_thing = self.coco.cats[cat_id]['isthing']
            segment_mask = segm_map == ann['id']
            if is_thing > 0:
                # Things: annotation bbox, cropped with a 1.5x context margin.
                x, y, w, h = ann['bbox']
                x0, y0, x1, y1 = self._expanded_box(x, y, w, h, img_w, img_h)
            else:
                # Stuff: tight box around the mask, no margin.
                x0, y0, x1, y1 = mask2box(segment_mask)
                x, y, w, h = x0, y0, x1 - x0, y1 - y0
            if w * h < min_area or w * h > max_area:
                continue
            crop_box = (x0, y0, x1, y1)
            image_crops[i] = self.transforms[1](old_image.crop(crop_box))

            masked_old_image = Image.fromarray(self._mask_outside(np_image, segment_mask))
            masked_image_crops[i] = self.transforms[1](masked_old_image.crop(crop_box))

            gt_mask = torch.from_numpy(segment_mask).float()
            gt_mask = self.segm_transform(gt_mask[None]) > 0.0
            cls_label = self.cat_id2label[cat_id]
            boxes_template[i] = torch.tensor([x, y, x + w, y + h, cls_label, 1.0, w * h, is_thing])
            gt_masks[i] = gt_mask[0]

        _, h, w = new_image.shape

        # Apply the factor from get_scale, then divide by the transformed
        # image's width/height.
        boxes_template[:, :4] *= scale
        boxes_template[:, [0, 2]] /= w
        boxes_template[:, [1, 3]] /= h

        return new_image, boxes_template, image_crops, gt_masks, masked_image_crops
|
|
class ADEPanopticDataset(Dataset):
    """Panoptic dataset (ADE annotations in COCO panoptic format) with per-segment crops and masks.

    Each item is ``(new_image, boxes_template, image_crops, gt_masks,
    masked_image_crops)``, with one row per segment (at most ``max_anns``):

    * ``boxes_template`` ``(max_anns, 8)``: ``x0, y0, x1, y1`` (multiplied by
      ``get_scale`` and divided by the transformed width/height), class label,
      validity flag 1.0, box area ``w * h`` in original-image units, and
      ``isthing``.
    * ``image_crops`` / ``masked_image_crops``: ``transforms[1]`` applied to a
      crop around the segment. In the masked variant, pixels outside the
      segment are set to 114.
    * ``gt_masks``: the segment mask after ``segm_transform``, stored in an
      ``S x S`` slot with ``S = segm_transform.max_size``.

    Rows for segments rejected by the area filter stay zero.
    """

    def __init__(self, input_filename, transforms, image_root, embed_path,
                 segm_root,
                 crop_size=224,
                 tokenizer=None,
                 downsample_factor=16,
                 min_size=8, max_size=1024):
        """
        Args:
            input_filename: panoptic annotation file read by ``COCOPanoptic``.
            transforms: indexable; ``transforms[0]`` (whose first sub-transform
                exposes ``max_size``) is applied to the image, ``transforms[1]``
                to crops.
            image_root: directory joined with each image's ``file_name``.
            embed_path: ``.npy`` file loaded into ``self.embeddings`` (unused here).
            segm_root: directory joined with each image's ``segm_file``.
            crop_size: int or 2-element sequence giving the crop tensor size.
            tokenizer: stored as ``self.tokenize``; not used by this class.
            downsample_factor: mask resolution divisor relative to the image size.
            min_size, max_size: segments whose box area is outside
                ``[min_size**2, max_size**2]`` are skipped.
        """
        logging.debug(f'Loading coco caption style data from {input_filename}.')
        self.coco = COCOPanoptic(input_filename)
        logging.debug('Done loading data.')
        self.transforms = transforms
        self.tokenize = tokenizer
        self.image_root = image_root
        self.embeddings = np.load(embed_path)
        self.image_ids = list(self.coco.imgs.keys())

        num_annos = [len(anns) for anns in self.coco.imgToAnns.values()]
        # default=0 avoids a ValueError on an annotation file with no segments.
        self.max_anns = min(max(num_annos, default=0), 100)
        if not isinstance(crop_size, (tuple, list)):
            crop_size = [crop_size, crop_size]
        self.crop_size = crop_size
        # Honour the caller's area bounds; the factory passes args.min_size/max_size.
        self.min_size = min_size
        self.max_size = max_size
        self.segm_root = segm_root

        self.downsample_factor = downsample_factor
        self.segm_transform = ResizeLongest(max_size=self.transforms[0].transforms[0].max_size // downsample_factor,
                                            fill=0)

        # Map (possibly sparse) category ids to contiguous labels 0..K-1.
        cat_ids = sorted([cat['id'] for cat in self.coco.cats.values()])
        self.cat_id2label = {cat_id: label for label, cat_id in enumerate(cat_ids)}

    def __len__(self):
        return len(self.image_ids)

    @staticmethod
    def _load_segm(segm_path):
        """Load the segmentation image and map it to per-pixel segment ids via ``rgb2id``."""
        # ``with`` closes the file once the pixels have been copied into numpy.
        with Image.open(segm_path) as segm_image:
            segmentation = np.array(segm_image, dtype=np.uint8)
        return utils.rgb2id(segmentation)

    @staticmethod
    def _expanded_box(x, y, w, h, img_w, img_h, factor=0.75):
        """Return ``(x0, y0, x1, y1)`` centred on ``(x, y, w, h)``.

        The result is ``2 * factor`` times the box size, clipped to the image.
        """
        cx, cy = x + w * 0.5, y + h * 0.5
        return (max(cx - w * factor, 0), max(cy - h * factor, 0),
                min(cx + w * factor, img_w), min(cy + h * factor, img_h))

    @staticmethod
    def _mask_outside(np_image, mask, fill=114):
        """Return a copy of ``np_image`` with every pixel outside ``mask`` set to ``fill``."""
        masked = np_image.copy()
        masked[~mask] = fill
        return masked

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        image_info = self.coco.imgs[image_id]
        image_path = os.path.join(self.image_root, image_info['file_name'])
        segm_path = os.path.join(self.segm_root, image_info['segm_file'])
        segm_map = self._load_segm(segm_path)

        old_image = Image.open(image_path)
        img_w, img_h = old_image.width, old_image.height
        new_image = self.transforms[0](old_image)
        scale = get_scale(old_image, new_image)
        # Converted once here instead of once (plus two copies) per segment.
        np_image = np.asarray(old_image)

        anns = self.coco.imgToAnns[image_id]
        boxes_template = torch.zeros(self.max_anns, 4 + 2 + 1 + 1)
        image_crops = torch.zeros(self.max_anns, 3, *self.crop_size)
        gt_masks = torch.zeros(self.max_anns, self.segm_transform.max_size,
                               self.segm_transform.max_size)
        masked_image_crops = torch.zeros(self.max_anns, 3, *self.crop_size)

        min_area, max_area = self.min_size ** 2, self.max_size ** 2
        # zip with range() caps the loop at max_anns segments.
        for i, ann in zip(range(self.max_anns), anns):
            cat_id = ann['category_id']
            is_thing = self.coco.cats[cat_id]['isthing']
            segment_mask = segm_map == ann['id']
            if is_thing > 0:
                # Things: annotation bbox, cropped with a 1.5x context margin.
                x, y, w, h = ann['bbox']
                x0, y0, x1, y1 = self._expanded_box(x, y, w, h, img_w, img_h)
            else:
                # Stuff: tight box around the mask, no margin.
                x0, y0, x1, y1 = mask2box(segment_mask)
                x, y, w, h = x0, y0, x1 - x0, y1 - y0
            if w * h < min_area or w * h > max_area:
                continue
            crop_box = (x0, y0, x1, y1)
            image_crops[i] = self.transforms[1](old_image.crop(crop_box))

            masked_old_image = Image.fromarray(self._mask_outside(np_image, segment_mask))
            masked_image_crops[i] = self.transforms[1](masked_old_image.crop(crop_box))

            gt_mask = torch.from_numpy(segment_mask).float()
            gt_mask = self.segm_transform(gt_mask[None]) > 0.0
            cls_label = self.cat_id2label[cat_id]
            boxes_template[i] = torch.tensor([x, y, x + w, y + h, cls_label, 1.0, w * h, is_thing])
            gt_masks[i] = gt_mask[0]

        _, h, w = new_image.shape

        # Apply the factor from get_scale, then divide by the transformed
        # image's width/height.
        boxes_template[:, :4] *= scale
        boxes_template[:, [0, 2]] /= w
        boxes_template[:, [1, 3]] /= h

        return new_image, boxes_template, image_crops, gt_masks, masked_image_crops
|
|
|
|
class COCORegionCLIPDataset(Dataset):
    """COCO detection dataset yielding an image and its class-labelled boxes.

    Items are ``(new_image, boxes_template)``, with ``boxes_template`` of
    shape ``(max_anns, 6)``. Its columns are ``x0, y0, x1, y1`` (multiplied
    by ``get_scale`` and divided by the transformed width/height), a
    contiguous class label, and a 1.0 validity flag. Rows beyond the image's
    annotations stay zero.
    """

    def __init__(self, input_filename, transforms, image_root, args):
        """
        Args:
            input_filename: path of the COCO-format annotation file.
            transforms: indexable; only ``transforms[0]`` is used.
            image_root: local directory used when ``args.train_ceph_root`` is "".
            args: namespace providing ``train_ratio``, ``train_ceph_root`` and
                optionally ``seed``.
        """
        logging.debug(f'Loading coco caption style data from {input_filename}.')
        self.coco = COCO(input_filename)
        logging.debug('Done loading data.')
        self.transforms = transforms
        self.image_root = image_root
        # Subsample with a dedicated RNG seeded from args, not the global one.
        # That way every process building this dataset (e.g. each DDP rank)
        # keeps the same images, as DistributedSampler assumes.
        self.image_ids = self._subsample_ids(
            list(self.coco.imgToAnns.keys()), args.train_ratio, getattr(args, 'seed', 0))

        num_annos = [len(anns) for anns in self.coco.imgToAnns.values()]
        # default=0 avoids a ValueError on an annotation file with no boxes.
        self.max_anns = min(max(num_annos, default=0), 20)
        self.args = args
        self.ceph_root = args.train_ceph_root
        self.use_ceph = (self.ceph_root != "")
        self.FILE_CLIENT = None
        # Map (possibly sparse) category ids to contiguous labels 0..K-1.
        cat_ids = sorted([cat['id'] for cat in self.coco.cats.values()])
        self.cat_id2label = {cat_id: label for label, cat_id in enumerate(cat_ids)}

    @staticmethod
    def _subsample_ids(image_ids, train_ratio, seed=0):
        """Return a seeded-random ``train_ratio`` fraction of ``image_ids``.

        The global ``random`` state is neither used nor modified. When
        ``train_ratio >= 1.0`` the input list is returned unchanged.
        """
        if train_ratio >= 1.0:
            return image_ids
        num_images = int(len(image_ids) * train_ratio)
        shuffled = list(image_ids)
        random.Random(seed).shuffle(shuffled)
        return shuffled[:num_images]

    def __len__(self):
        return len(self.image_ids)

    def read_image(self, image_name):
        """Open ``image_name`` from Ceph or ``image_root``; errors propagate to the caller."""
        if self.use_ceph:
            image_path = os.path.join(self.ceph_root, image_name)
            if self.FILE_CLIENT is None:
                # Built on first use rather than in __init__, presumably so each
                # DataLoader worker creates its own client -- TODO confirm.
                self.FILE_CLIENT = Client()
            img_bytes = self.FILE_CLIENT.get(image_path)
            buff = io.BytesIO(img_bytes)
            image = Image.open(buff)
        else:
            image_path = os.path.join(self.image_root, image_name)
            image = Image.open(image_path)
        return image

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        image_info = self.coco.imgs[image_id]
        old_image = self.read_image(image_info['file_name'])
        new_image = self.transforms[0](old_image)

        scale = get_scale(old_image, new_image)
        anns = self.coco.imgToAnns[image_id]
        boxes_template = torch.zeros(self.max_anns, 4 + 2)

        # zip with range() caps the loop at max_anns boxes.
        for i, ann in zip(range(self.max_anns), anns):
            x, y, w, h = ann['bbox']
            cls_label = self.cat_id2label[ann['category_id']]
            boxes_template[i] = torch.tensor([x, y, x + w, y + h, cls_label, 1.0])

        _, h, w = new_image.shape

        # Apply the factor from get_scale, then divide by the transformed
        # image's width/height.
        boxes_template[:, :4] *= scale
        boxes_template[:, [0, 2]] /= w
        boxes_template[:, [1, 3]] /= h

        return new_image, boxes_template
|
|
|
|
def _build_data_info(dataset, args, is_train):
    """Wrap ``dataset`` in a DataLoader and return it as a ``DataInfo``.

    Shared by every ``get_*_dataset`` factory below.

    * Under ``args.distributed`` a ``DistributedSampler`` is used, and
      DataLoader-level shuffling is disabled because a sampler and
      ``shuffle=True`` are mutually exclusive. Otherwise training loaders
      shuffle.
    * Training uses ``args.batch_size`` and drops the last incomplete batch.
      Evaluation uses ``min(args.batch_size, 1)``.
    * The loader is annotated with ``num_samples`` and ``num_batches``.
    """
    num_samples = len(dataset)
    sampler = DistributedSampler(dataset) if args.distributed else None
    shuffle = is_train and sampler is None
    batch_size = args.batch_size if is_train else min(args.batch_size, 1)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=args.workers,
        pin_memory=True,
        sampler=sampler,
        drop_last=is_train,
    )
    dataloader.num_samples = num_samples
    dataloader.num_batches = len(dataloader)
    return DataInfo(dataloader, sampler)


def get_coco_panoptic_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None):
    """Build the COCO panoptic ``DataInfo``. ``epoch`` is accepted but unused."""
    input_filename = args.train_data if is_train else args.val_data
    assert input_filename
    # NOTE(review): the val segmentation/image roots are used even when
    # is_train is True -- confirm this factory is only used for evaluation.
    dataset = COCOPanopticDataset(
        input_filename,
        preprocess_fn,
        segm_root=args.val_segm_root,
        image_root=args.val_image_root,
        embed_path=args.embed_path,
        tokenizer=tokenizer,
        crop_size=args.input_size,
        min_size=args.min_size,
        max_size=args.max_size,
        downsample_factor=args.downsample_factor
    )
    return _build_data_info(dataset, args, is_train)


def get_ade_panoptic_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None):
    """Build the ADE panoptic ``DataInfo``. ``epoch`` is accepted but unused."""
    input_filename = args.train_data if is_train else args.val_data
    assert input_filename
    # NOTE(review): the val segmentation/image roots are used even when
    # is_train is True -- confirm this factory is only used for evaluation.
    dataset = ADEPanopticDataset(
        input_filename,
        preprocess_fn,
        segm_root=args.val_segm_root,
        image_root=args.val_image_root,
        embed_path=args.embed_path,
        tokenizer=tokenizer,
        crop_size=args.input_size,
        min_size=args.min_size,
        max_size=args.max_size,
        downsample_factor=args.downsample_factor
    )
    return _build_data_info(dataset, args, is_train)


def get_proposal_distill_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None):
    """Build the training ``DataInfo`` for ``ProposalDistillDataset`` (training only)."""
    assert is_train
    input_filename = args.train_data
    assert input_filename
    dataset = ProposalDistillDataset(
        input_filename,
        preprocess_fn,
        image_root=args.train_image_root,
        tokenizer=tokenizer,
        crop_size=args.input_size,
        args=args
    )
    return _build_data_info(dataset, args, is_train)


def get_grid_distill_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None):
    """Build the training ``DataInfo`` for ``GridDistillDataset`` (training only)."""
    assert is_train
    input_filename = args.train_data
    assert input_filename
    dataset = GridDistillDataset(
        input_filename=input_filename,
        transforms=preprocess_fn,
        image_root=args.train_image_root,
        crop_size=args.input_size,
        max_split=args.max_split,
        ceph_root=args.train_ceph_root,
        pre_transforms=args.pre_transforms,
        args=args
    )
    return _build_data_info(dataset, args, is_train)


def get_region_clip_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None):
    """Build the training ``DataInfo`` for ``COCORegionCLIPDataset`` (training only)."""
    assert is_train
    input_filename = args.train_data
    assert input_filename
    dataset = COCORegionCLIPDataset(
        input_filename=input_filename,
        transforms=preprocess_fn,
        image_root=args.train_image_root,
        args=args,
    )
    return _build_data_info(dataset, args, is_train)
|
|
|
|
|
|
class SharedEpoch:
    """Integer epoch counter backed by a ``multiprocessing.Value``.

    The value lives in shared memory, so child processes that inherit this
    object observe updates made through ``set_value``. These are presumably
    DataLoader workers -- confirm with the callers.
    """

    def __init__(self, epoch: int = 0):
        # Typecode 'i' is a C signed int.
        self.shared_epoch = Value('i', epoch)

    def set_value(self, epoch) -> None:
        """Overwrite the stored epoch."""
        counter = self.shared_epoch
        counter.value = epoch

    def get_value(self) -> int:
        """Return the stored epoch."""
        counter = self.shared_epoch
        return counter.value
|
|
|
|
@dataclass
class DataInfo:
    """A DataLoader bundled with the objects that must follow epoch changes."""

    dataloader: DataLoader
    sampler: DistributedSampler = None
    shared_epoch: SharedEpoch = None

    def set_epoch(self, epoch):
        """Propagate ``epoch`` to the shared counter and to a DistributedSampler, if present."""
        shared = self.shared_epoch
        if shared is not None:
            shared.set_value(epoch)
        # isinstance() is already False for None, so no separate None check is needed.
        if isinstance(self.sampler, DistributedSampler):
            self.sampler.set_epoch(epoch)
|
|
|
|
def get_dataset_fn(data_path, dataset_type):
    """Return the dataset factory registered under ``dataset_type``.

    ``data_path`` is accepted for interface compatibility but not used.

    Raises:
        ValueError: if ``dataset_type`` is not a known type.
    """
    factories = {
        'coco_panoptic': get_coco_panoptic_dataset,
        'ade_panoptic': get_ade_panoptic_dataset,
        'proposals_distill': get_proposal_distill_dataset,
        'grid_distill': get_grid_distill_dataset,
        'region_clip': get_region_clip_dataset,
    }
    factory = factories.get(dataset_type)
    if factory is None:
        raise ValueError(f"Unsupported dataset type: {dataset_type}")
    return factory
|
|
|
|
def get_data(args, preprocess_fns, epoch=0, tokenizer=None):
    """Build the ``"train"`` and/or ``"val"`` DataInfo objects requested by ``args``.

    Training uses ``args.dataset_type``; validation uses ``args.test_type``.
    """
    preprocess_train, preprocess_val = preprocess_fns
    data = {}

    if args.train_data:
        build_train = get_dataset_fn(args.train_data, args.dataset_type)
        data["train"] = build_train(
            args, preprocess_train, is_train=True, epoch=epoch, tokenizer=tokenizer)

    if args.val_data:
        build_val = get_dataset_fn(args.val_data, dataset_type=args.test_type)
        data["val"] = build_val(
            args, preprocess_val, is_train=False, tokenizer=tokenizer)

    return data
|
|