Spaces:
Starting
on
T4
Starting
on
T4
import os | |
from glob import glob | |
import random | |
import numpy as np | |
from PIL import Image | |
import cv2 | |
import itertools | |
import torch | |
import copy | |
from torch.utils.data import Dataset | |
import torchvision.datasets.folder | |
import torchvision.transforms as transforms | |
from einops import rearrange | |
def compute_distance_transform(mask):
    """Compute per-item inside/outside Euclidean distance transforms of a mask batch.

    Only the first channel of every item is used.  For each item the result
    stacks the distance transform of the mask and of its complement, so the
    output has two channels per item (``Bx2xHxW`` per the original comment).
    Mask values are cast with ``np.uint8`` (truncation), so they are expected
    to already be 0/1.
    """
    def _dist(binary):
        # cv2.distanceTransform returns a float32 HxW array.
        return torch.FloatTensor(cv2.distanceTransform(binary, cv2.DIST_L2, cv2.DIST_MASK_PRECISE))

    return torch.stack(
        [torch.stack([_dist(np.uint8(m[0])), _dist(np.uint8(1 - m[0]))], 0) for m in mask],
        0,
    )  # Bx2xHxW
def crop_image(image, boxs, size):
    """Crop ``image`` once per box, resize each crop to ``size`` and stack as tensors.

    Each box is ``(x0, y0, w, h)``; torchvision's ``resized_crop`` takes
    ``(top, left, height, width)``, hence the reordering below.
    """
    def _crop_one(box):
        left, top, width, height = box
        patch = transforms.functional.resized_crop(image, top, left, height, width, size)
        return transforms.functional.to_tensor(patch)

    return torch.stack([_crop_one(b) for b in boxs], 0)
def box_loader(fpath):
    """Load a one-line, whitespace-separated ``box.txt`` into a float32 array.

    The first field may carry an ``_``-separated suffix; only the part before
    the first ``_`` is kept so it can be parsed as a number.
    """
    fields = np.loadtxt(fpath, dtype='str')
    fields[0] = fields[0].partition('_')[0]
    return fields.astype(np.float32)
def read_feat_from_img(path, n_channels):
    """Read a tiled feature image from ``path`` and decode it to ``n_channels`` channels.

    Returns the array produced by ``dencode_feat_from_img`` (channels first,
    float32 in [0, 1]).
    """
    # Image.open is lazy and keeps the underlying file open; the context
    # manager closes it as soon as the pixels have been copied into numpy,
    # instead of leaking one handle per loaded frame until garbage collection.
    with Image.open(path) as img:
        feat = np.array(img)
    return dencode_feat_from_img(feat, n_channels)
def dencode_feat_from_img(img, n_channels):
    """Decode an ``H x (T*W) x 3`` uint8 image of T side-by-side RGB tiles into features.

    The tiles are stacked along the channel axis (``T*3`` channels).  The
    surplus channels that padded ``n_channels`` up to a multiple of 3 are
    dropped, and values are scaled to [0, 1].  The result is channels first
    (``n_channels x H x W``) and float32.
    """
    n_tiles = int(np.ceil(n_channels / 3))
    n_pad = n_tiles * 3 - n_channels
    stacked = rearrange(img, 'h (t w) c -> h w (t c)', t=n_tiles, c=3)
    if n_pad:
        stacked = stacked[..., :-n_pad]
    scaled = stacked.astype('float32') / 255
    return scaled.transpose(2, 0, 1)
def dino_loader(fpath, n_channels):
    """Loader hook for DINO feature PNGs; delegates to ``read_feat_from_img``."""
    return read_feat_from_img(fpath, n_channels)
def get_valid_mask(boxs, image_size):
    """Return an ``N x H x W`` mask of crop pixels that lie inside the original image.

    Each row of ``boxs`` is read at columns 1..6 as
    ``(crop_x0, crop_y0, crop_w, crop_h, full_w, full_h)``.  A 2% margin
    (relative to the crop size) is also marked invalid along the image
    border.  Masks are resized to ``image_size`` with nearest-neighbour
    interpolation.
    """
    masks = []
    for box in boxs:
        x0, y0, cw, ch, fw, fh = box[1:7].int().numpy()
        mw, mh = int(cw * 0.02), int(ch * 0.02)
        # Ones over the image interior; zero-pad so that padded[y + ch, x + cw]
        # corresponds to full-image pixel (y, x), which lets crops extend up to
        # one crop size beyond every image border.
        interior = torch.ones(fh - 2 * mh, fw - 2 * mw)
        padded = torch.nn.functional.pad(
            interior, (cw + mw, cw + mw, ch + mh, ch + mh), mode='constant', value=0.0
        )
        window = padded[y0 + ch:y0 + 2 * ch, x0 + cw:x0 + 2 * cw]
        resized = torch.nn.functional.interpolate(window[None, None], image_size, mode='nearest')
        masks.append(resized[0, 0])
    return torch.stack(masks, 0)  # NxHxW
def horizontal_flip_box(box):
    """Mirror crop boxes horizontally, in place, and return the same tensor.

    ``box`` is ``N x 9`` with columns
    ``(frame_id, crop_x0, crop_y0, crop_w, crop_h, full_w, full_h, sharpness, label)``.
    Only the crop's left edge changes: ``x0' = full_w - x0 - w``.
    """
    # The 9-way unpack also enforces the expected column count.
    _, left, _, width, _, image_width, _, _, _ = box.unbind(1)
    mirrored_left = image_width - left - width
    box[:, 1] = mirrored_left  # x0
    return box
def horizontal_flip_all(images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features=None, dino_clusters=None):
    """Horizontally mirror every per-frame tensor of one sample.

    Image-like tensors are flipped along their last (width) axis: dim 3 for
    the 4-D ``NxCxHxW`` inputs and dim 2 for the ``NxHxW`` ``mask_valid``.
    Flow (``(N-1)x(x,y)xHxW``) is flipped and its x-component negated.
    ``bboxs`` is mirrored *in place* by ``horizontal_flip_box``.  Optional
    inputs that are ``None`` or a 1-element placeholder such as
    ``torch.zeros(1)`` are returned unchanged.

    Returns:
        The same nine values, in argument order.
    """
    def _flip_optional(x):
        # Previously the ``None`` defaults crashed on ``.dim()``.
        if x is None or x.dim() <= 1:
            return x
        return x.flip(3)

    images = images.flip(3)  # NxCxHxW
    masks = masks.flip(3)  # NxCxHxW
    mask_dt = mask_dt.flip(3)  # NxCxHxW
    mask_valid = mask_valid.flip(2)  # NxHxW
    if flows.dim() > 1:
        flows = flows.flip(3)  # (N-1)x(x,y)xHxW
        flows[:, 0] *= -1  # mirroring inverts horizontal motion
    bboxs = horizontal_flip_box(bboxs)  # NxK
    bg_images = bg_images.flip(3)  # NxCxHxW
    dino_features = _flip_optional(dino_features)
    dino_clusters = _flip_optional(dino_clusters)
    return images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters
def none_to_nan(x):
    """Return ``x`` unchanged, or a 1-element NaN float tensor when ``x`` is ``None``.

    Used so that default collation never sees ``None`` in a sample tuple.
    """
    if x is not None:
        return x
    return torch.FloatTensor([float('nan')])
class BaseSequenceDataset(Dataset):
    """Base class for datasets made of per-sequence frame folders.

    ``root`` is scanned for sub-directories, each of which is one sequence.
    Frames are discovered by globbing for ``'*' + self.image_loaders[0][0]``
    and stored as ``str.format`` path patterns in which that filename suffix
    is replaced by ``'{}'`` (e.g. ``.../seq/0001_rgb.png`` ->
    ``.../seq/0001_{}``), so other modalities can be loaded by formatting the
    pattern with their own filename.

    Subclasses must set ``self.image_loaders`` *before* calling
    ``super().__init__`` (it is read while scanning folders), fill
    ``self.samples``, and implement ``__getitem__``.
    """

    def __init__(self, root, skip_beginning=4, skip_end=4, min_seq_len=10, debug_seq=False):
        """Scan ``root`` for sequences.

        Args:
            root: directory whose sub-directories are sequences.
            skip_beginning: frames dropped from the start of every sequence.
            skip_end: frames dropped from the end of every sequence.
            min_seq_len: sequences with fewer remaining frames are discarded.
            debug_seq: keep a single randomly chosen sequence (for debugging).

        Raises:
            ValueError: if ``debug_seq`` is set but no sequence survived the
                length filter.
        """
        super().__init__()
        self.skip_beginning = skip_beginning
        self.skip_end = skip_end
        self.min_seq_len = min_seq_len
        self.sequences = self._make_sequences(root)
        if debug_seq:
            if not self.sequences:
                raise ValueError(
                    f"debug_seq requested but no sequence with at least {min_seq_len} frames was found under {root!r}"
                )
            seq_len = 0
            while seq_len < min_seq_len:
                i = np.random.randint(len(self.sequences))
                rand_seq = self.sequences[i]
                seq_len = len(rand_seq)
            self.sequences = [rand_seq]
        self.samples = []

    def _make_sequences(self, path):
        """Return the frame-pattern lists of all sub-directories of ``path`` (sorted by
        name) that have at least ``self.min_seq_len`` frames after skipping."""
        result = []
        for d in sorted(os.scandir(path), key=lambda e: e.name):
            if d.is_dir():
                files = self._parse_folder(d)
                if len(files) >= self.min_seq_len:
                    result.append(files)
        return result

    def _parse_folder(self, path):
        """Return the sorted frame patterns of one folder, without the first
        ``skip_beginning`` and last ``skip_end`` frames ([] if too short)."""
        suffix = self.image_loaders[0][0]
        result = sorted(glob(os.path.join(path, '*' + suffix)))
        # Every match ends with ``suffix`` (it is the glob pattern's tail).
        # Replace only that tail, so a directory whose name happens to contain
        # the same text is not corrupted into an extra ``{}`` placeholder.
        result = [p[:len(p) - len(suffix)] + '{}' for p in result]
        if len(result) <= self.skip_beginning + self.skip_end:
            return []
        if self.skip_end == 0:
            # result[a:-0] would be empty, so handle the no-skip case separately.
            return result[self.skip_beginning:]
        return result[self.skip_beginning:-self.skip_end]

    def _load_ids(self, path_patterns, loaders, transform=None):
        """Load every pattern with every loader, in loader-major order.

        Each loader is ``(filename, load_fn, *extra_args)``; the file
        ``pattern.format(filename)`` is read with
        ``load_fn(path, *extra_args)`` and, if given, passed through
        ``transform``.  Returns a tuple.
        """
        result = []
        for loader in loaders:
            for p in path_patterns:
                x = loader[1](p.format(loader[0]), *loader[2:])
                if transform:
                    x = transform(x)
                result.append(x)
        return tuple(result)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        # ``NotImplemented`` is a constant, not an exception; calling it raised
        # an unrelated TypeError instead of signalling the abstract method.
        raise NotImplementedError("This is a base class and should not be used directly")
class NFrameSequenceDataset(BaseSequenceDataset):
    """Clips of up to ``num_sample_frames`` consecutive frames from per-sequence folders.

    Each frame is stored as a path pattern in which the ``rgb<rgb_suffix>``
    part of its RGB filename is replaced by ``{}``.  Other modalities are
    loaded by formatting that pattern with ``mask.png``, ``box.txt`` and, when
    enabled, ``flow.png``, ``feat<dino_feature_dim>.png`` or ``clusters.png``.

    ``__getitem__`` returns the 12-tuple ``(images, masks, mask_dt,
    mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters,
    seq_idx, frame_idx, cat_name)``.  Disabled modalities are the placeholder
    ``torch.zeros(1)``, and ``None`` entries (e.g. an unset ``cat_name``)
    become a NaN tensor via ``none_to_nan``.
    """
    def __init__(self, root, cat_name=None, num_sample_frames=2, skip_beginning=4, skip_end=4, min_seq_len=10, in_image_size=256, out_image_size=256, debug_seq=False, random_sample=False, shuffle=False, dense_sample=True, color_jitter=None, load_background=False, random_flip=False, rgb_suffix='.png', load_dino_feature=False, load_dino_cluster=False, dino_feature_dim=64, flow_bool=False, **kwargs):
        """Scan ``root`` for sequences and build the sample index.

        Args:
            root: directory whose sub-directories are sequences.
            cat_name: category label returned as the last tuple element.
            num_sample_frames: clip length.
            random_sample: if True, one sample per sequence with a start frame
                drawn in ``__getitem__``; otherwise one ``(sequence index,
                start frame)`` sample every ``stride`` frames.
            shuffle: shuffle the sample list once, at construction time.
            dense_sample: stride 1 if True, else ``num_sample_frames`` (only
                used when ``random_sample`` is False).
            color_jitter: ``None`` or a ``(prob, brightness, hue)`` triple.
            load_background: crop ``background_frame.jpg`` from the sequence
                folder with each frame's box.
            random_flip: horizontally flip the whole clip with probability 0.5.
            flow_bool: load optical flow (only when ``num_sample_frames > 1``).
            **kwargs: accepted and ignored.
        """
        self.cat_name = cat_name
        self.flow_bool=flow_bool
        # Loader entries are (filename, load_fn, *extra_args).  image_loaders
        # must exist before super().__init__(), which reads it to find frames.
        self.image_loaders = [("rgb"+rgb_suffix, torchvision.datasets.folder.default_loader)]
        self.mask_loaders = [("mask.png", torchvision.datasets.folder.default_loader)]
        self.bbox_loaders = [("box.txt", box_loader)]
        super().__init__(root, skip_beginning, skip_end, min_seq_len, debug_seq)
        if flow_bool and num_sample_frames > 1:
            # Read with cv2.imread(path, cv2.IMREAD_UNCHANGED) to keep the stored bit depth.
            self.flow_loaders = [("flow.png", cv2.imread, cv2.IMREAD_UNCHANGED)]
        else:
            self.flow_loaders = None
        self.num_sample_frames = num_sample_frames
        self.random_sample = random_sample
        if self.random_sample:
            # One sample per sequence; samples aliases sequences.
            if shuffle:
                random.shuffle(self.sequences)
            self.samples = self.sequences
        else:
            # One (sequence index, start frame) pair every `stride` frames.
            for i, s in enumerate(self.sequences):
                stride = 1 if dense_sample else self.num_sample_frames
                self.samples += [(i, k) for k in range(0, len(s), stride)]
            if shuffle:
                random.shuffle(self.samples)
        self.in_image_size = in_image_size
        self.out_image_size = out_image_size
        self.load_background = load_background
        self.color_jitter = color_jitter
        self.image_transform = transforms.Compose([transforms.Resize(self.in_image_size), transforms.ToTensor()])
        # Nearest-neighbour resize keeps mask values binary.
        self.mask_transform = transforms.Compose([transforms.Resize(self.out_image_size, interpolation=Image.NEAREST), transforms.ToTensor()])
        if self.flow_loaders is not None:
            # HxWxC array -> flip(2) reverses the channel axis (cv2 reads BGR),
            # keep the first two channels, map [0, 65535] to [-1, 1].
            # NOTE(review): assumes 16-bit flow PNGs — confirm.
            self.flow_transform = lambda x: (torch.FloatTensor(x.astype(np.float32)).flip(2)[:,:,:2] / 65535. ) *2 -1
        self.random_flip = random_flip
        self.load_dino_feature = load_dino_feature
        if load_dino_feature:
            self.dino_feature_loaders = [(f"feat{dino_feature_dim}.png", dino_loader, dino_feature_dim)]
        self.load_dino_cluster = load_dino_cluster
        if load_dino_cluster:
            self.dino_cluster_loaders = [("clusters.png", torchvision.datasets.folder.default_loader)]
    def __getitem__(self, index):
        """Load one clip and return the 12-tuple described in the class docstring.

        In ``random_sample`` mode ``index`` selects a sequence (modulo the
        number of sequences) and the start frame is drawn with ``np.random``.
        Otherwise ``index`` selects a precomputed ``(sequence, start frame)``
        sample.
        """
        if self.random_sample:
            seq_idx = index % len(self.sequences)
            seq = self.sequences[seq_idx]
            # Sequences shorter than the clip are loaded whole and front-padded below.
            if len(seq) < self.num_sample_frames:
                start_frame_idx = 0
            else:
                start_frame_idx = np.random.randint(len(seq)-self.num_sample_frames+1)
            paths = seq[start_frame_idx:start_frame_idx+self.num_sample_frames]
        else:
            seq_idx, start_frame_idx = self.samples[index % len(self.samples)]
            seq = self.sequences[seq_idx]
            # Handle edge case: when only last frame is left, sample last two frames, except if the sequence only has one frame
            if len(seq) <= start_frame_idx +1:
                start_frame_idx = max(0, start_frame_idx-1)
            paths = seq[start_frame_idx:start_frame_idx+self.num_sample_frames]
        masks = torch.stack(self._load_ids(paths, self.mask_loaders, transform=self.mask_transform), 0) # load all images
        mask_dt = compute_distance_transform(masks)
        jitter = False
        if self.color_jitter is not None:
            prob, b, h = self.color_jitter
            if np.random.rand() < prob:
                jitter = True
                # Independent jitter for foreground and background, composited
                # through the mask below.
                # NOTE(review): in torchvision >= 0.8 ColorJitter.get_params returns a
                # tuple of sampled factors, not a callable transform; this code
                # presumably targets an older torchvision — confirm.
                color_jitter_tsf_fg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h))
                image_transform_fg = transforms.Compose([transforms.Resize(self.in_image_size), color_jitter_tsf_fg, transforms.ToTensor()])
                color_jitter_tsf_bg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h))
                image_transform_bg = transforms.Compose([transforms.Resize(self.in_image_size), color_jitter_tsf_bg, transforms.ToTensor()])
        if jitter:
            # NOTE(review): images use in_image_size and masks out_image_size; the
            # composite only broadcasts when both sizes are equal — confirm.
            images_fg = torch.stack(self._load_ids(paths, self.image_loaders, transform=image_transform_fg), 0) # load all images
            images_bg = torch.stack(self._load_ids(paths, self.image_loaders, transform=image_transform_bg), 0) # load all images
            images = images_fg * masks + images_bg * (1-masks)
        else:
            images = torch.stack(self._load_ids(paths, self.image_loaders, transform=self.image_transform), 0) # load all images
        # len(paths) > 1 implies num_sample_frames > 1, so flow_loaders/flow_transform exist when flow_bool is set.
        if self.flow_bool==True and len(paths) > 1:
            flows = torch.stack(self._load_ids(paths[:-1], self.flow_loaders, transform=self.flow_transform), 0).permute(0,3,1,2) # load flow for first image, (N-1)x(x,y)xHxW, -1~1
            flows = torch.nn.functional.interpolate(flows, size=self.out_image_size, mode="bilinear")
        else:
            flows = torch.zeros(1)
        bboxs = torch.stack(self._load_ids(paths, self.bbox_loaders, transform=torch.FloatTensor), 0) # load bounding boxes for all images
        mask_valid = get_valid_mask(bboxs, (self.out_image_size, self.out_image_size)) # exclude pixels cropped outside the original image
        if self.load_background:
            # One background image per sequence folder, cropped with each frame's box (columns 1..4).
            bg_image = torchvision.datasets.folder.default_loader(os.path.join(os.path.dirname(paths[0]), 'background_frame.jpg'))
            if jitter:
                bg_image = color_jitter_tsf_bg(bg_image)
            bg_images = crop_image(bg_image, bboxs[:, 1:5].int().numpy(), (self.in_image_size, self.in_image_size))
        else:
            bg_images = torch.zeros_like(images)
        if self.load_dino_feature:
            # NOTE(review): hard-coded, machine-specific path rewrite; paths that do
            # not contain the first prefix are loaded unchanged. Consider making this configurable.
            dino_paths = [
                x.replace(
                    "/viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new",
                    "/viscam/projects/articulated/zzli/data_dino_5000/7_cat"
                )
                for x in paths
            ]
            dino_features = torch.stack(self._load_ids(dino_paths, self.dino_feature_loaders, transform=torch.FloatTensor), 0)
        else:
            dino_features = torch.zeros(1)
        if self.load_dino_cluster:
            dino_clusters = torch.stack(self._load_ids(paths, self.dino_cluster_loaders, transform=transforms.ToTensor()), 0) # BxFx3x55x55
        else:
            dino_clusters = torch.zeros(1)
        seq_idx = torch.LongTensor([seq_idx])
        frame_idx = torch.arange(start_frame_idx, start_frame_idx+len(paths)).long()
        if self.random_flip and np.random.rand() < 0.5:
            images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters = horizontal_flip_all(images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters)
        ## pad shorter sequence
        # Front-pad by repeating the first frame; padded flow entries are zeroed.
        if len(paths) < self.num_sample_frames:
            num_pad = self.num_sample_frames - len(paths)
            images = torch.cat([images[:1]] *num_pad + [images], 0)
            masks = torch.cat([masks[:1]] *num_pad + [masks], 0)
            mask_dt = torch.cat([mask_dt[:1]] *num_pad + [mask_dt], 0)
            mask_valid = torch.cat([mask_valid[:1]] *num_pad + [mask_valid], 0)
            if flows.dim() > 1:
                flows = torch.cat([flows[:1]*0] *num_pad + [flows], 0)
            bboxs = torch.cat([bboxs[:1]] * num_pad + [bboxs], 0)
            bg_images = torch.cat([bg_images[:1]] *num_pad + [bg_images], 0)
            if dino_features.dim() > 1:
                dino_features = torch.cat([dino_features[:1]] *num_pad + [dino_features], 0)
            if dino_clusters.dim() > 1:
                dino_clusters = torch.cat([dino_clusters[:1]] *num_pad + [dino_clusters], 0)
            frame_idx = torch.cat([frame_idx[:1]] *num_pad + [frame_idx], 0)
        out = (*map(none_to_nan, (images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters, seq_idx, frame_idx, self.cat_name)), )
        return out
def few_shot_box_loader(fpath):
    """Load a one-line, whitespace-separated few-shot ``box.txt`` as float32.

    Unlike ``box_loader``, the first field is converted as-is (no ``_``
    suffix stripping).
    """
    return np.loadtxt(fpath, dtype='str').astype(np.float32)
class FewShotImageDataset(Dataset):
    """Single-image dataset over a flat folder of ``*rgb<rgb_suffix>`` frames.

    Every image is treated as a length-1 sequence and returned as the
    12-tuple ``(images, masks, mask_dt, mask_valid, flows, bboxs, bg_images,
    dino_features, dino_clusters, seq_idx, frame_idx, cat_name)``.  Flow and
    DINO clusters are never loaded and are the placeholder ``torch.zeros(1)``.
    The single frame is front-padded to ``num_sample_frames`` by repetition.
    """
    def __init__(self, root, cat_name=None, cat_num=0, num_sample_frames=2, in_image_size=256, out_image_size=256, shuffle=False, color_jitter=None, load_background=False, random_flip=False, rgb_suffix='.png', load_dino_feature=False, dino_feature_dim=64, flow_bool=False, **kwargs):
        """Collect the frame patterns found directly in ``root``.

        Args:
            root: folder containing ``*rgb<rgb_suffix>`` frames and their
                per-frame ``mask.png`` / ``box.txt`` (and optional features).
            cat_name: category label returned as the last tuple element.
            cat_num: numeric label appended as an extra column to every bbox row.
            num_sample_frames: length the single frame is padded to.
            shuffle: shuffle the frame list once, at construction time.
            color_jitter: ``None`` or a ``(prob, brightness, hue)`` triple.
            flow_bool: stored only; flow is never loaded by this dataset.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.cat_name = cat_name
        self.cat_num = cat_num # appended as a trailing label column to every bbox row in __getitem__
        self.flow_bool=flow_bool
        # Loader entries are (filename, load_fn, *extra_args).
        self.image_loaders = [("rgb"+rgb_suffix, torchvision.datasets.folder.default_loader)]
        self.mask_loaders = [("mask.png", torchvision.datasets.folder.default_loader)]
        self.bbox_loaders = [("box.txt", few_shot_box_loader)]
        self.flow_loaders = None
        # get all the valid paths, since it's just image-wise, in get_item, we will make it like a len=1 sequence
        result = sorted(glob(os.path.join(root, '*'+self.image_loaders[0][0])))
        result = [p.replace(self.image_loaders[0][0], '{}') for p in result]
        self.sequences = result
        self.num_sample_frames = num_sample_frames
        if shuffle:
            random.shuffle(self.sequences)
        self.samples = self.sequences
        self.in_image_size = in_image_size
        self.out_image_size = out_image_size
        self.load_background = load_background
        self.color_jitter = color_jitter
        self.image_transform = transforms.Compose([transforms.Resize(self.in_image_size), transforms.ToTensor()])
        # Nearest-neighbour resize keeps mask values binary.
        self.mask_transform = transforms.Compose([transforms.Resize(self.out_image_size, interpolation=Image.NEAREST), transforms.ToTensor()])
        self.random_flip = random_flip
        self.load_dino_feature = load_dino_feature
        if load_dino_feature:
            self.dino_feature_loaders = [(f"feat{dino_feature_dim}.png", dino_loader, dino_feature_dim)]
    def _load_ids(self, path_patterns, loaders, transform=None):
        """Load every pattern with every loader (loader-major order) and return a tuple.

        Each loader is ``(filename, load_fn, *extra_args)``; the file
        ``pattern.format(filename)`` is read with ``load_fn(path, *extra_args)``
        and, if given, passed through ``transform``.
        """
        result = []
        for loader in loaders:
            for p in path_patterns:
                x = loader[1](p.format(loader[0]), *loader[2:])
                if transform:
                    x = transform(x)
                result.append(x)
        return tuple(result)
    def __len__(self):
        return len(self.samples)
    def __getitem__(self, index):
        """Load one image as a length-1 clip and return the 12-tuple described above."""
        paths = [self.samples[index]] # len 1 sequence
        masks = torch.stack(self._load_ids(paths, self.mask_loaders, transform=self.mask_transform), 0) # load all images
        mask_dt = compute_distance_transform(masks)
        jitter = False
        if self.color_jitter is not None:
            prob, b, h = self.color_jitter
            if np.random.rand() < prob:
                jitter = True
                # NOTE(review): in torchvision >= 0.8 ColorJitter.get_params returns a
                # tuple of sampled factors, not a callable transform; this code
                # presumably targets an older torchvision — confirm.
                color_jitter_tsf_fg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h))
                image_transform_fg = transforms.Compose([transforms.Resize(self.in_image_size), color_jitter_tsf_fg, transforms.ToTensor()])
                color_jitter_tsf_bg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h))
                image_transform_bg = transforms.Compose([transforms.Resize(self.in_image_size), color_jitter_tsf_bg, transforms.ToTensor()])
        if jitter:
            # Foreground and background jittered independently, composited through the mask.
            images_fg = torch.stack(self._load_ids(paths, self.image_loaders, transform=image_transform_fg), 0) # load all images
            images_bg = torch.stack(self._load_ids(paths, self.image_loaders, transform=image_transform_bg), 0) # load all images
            images = images_fg * masks + images_bg * (1-masks)
        else:
            images = torch.stack(self._load_ids(paths, self.image_loaders, transform=self.image_transform), 0) # load all images
        flows = torch.zeros(1)
        bboxs = torch.stack(self._load_ids(paths, self.bbox_loaders, transform=torch.FloatTensor), 0) # load bounding boxes for all images
        # NOTE(review): horizontal_flip_box unpacks exactly 9 columns, so random_flip
        # presumably assumes box.txt has 8 fields before this label column — confirm.
        bboxs=torch.cat([bboxs, torch.Tensor([[self.cat_num]]).float()],dim=-1) # pad a label number
        mask_valid = get_valid_mask(bboxs, (self.out_image_size, self.out_image_size)) # exclude pixels cropped outside the original image
        if self.load_background:
            bg_image = torchvision.datasets.folder.default_loader(os.path.join(os.path.dirname(paths[0]), 'background_frame.jpg'))
            if jitter:
                bg_image = color_jitter_tsf_bg(bg_image)
            bg_images = crop_image(bg_image, bboxs[:, 1:5].int().numpy(), (self.in_image_size, self.in_image_size))
        else:
            bg_images = torch.zeros_like(images)
        if self.load_dino_feature:
            dino_features = torch.stack(self._load_ids(paths, self.dino_feature_loaders, transform=torch.FloatTensor), 0) # BxFx64x224x224
        else:
            dino_features = torch.zeros(1)
        dino_clusters = torch.zeros(1)
        # seq_idx / frame_idx are fixed placeholders here: every sample is a single image.
        seq_idx = 0
        seq_idx = torch.LongTensor([seq_idx])
        frame_idx = torch.arange(0, 1).long()
        if self.random_flip and np.random.rand() < 0.5:
            images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters = horizontal_flip_all(images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters)
        ## pad shorter sequence
        # Front-pad by repeating the (only) frame up to num_sample_frames.
        if len(paths) < self.num_sample_frames:
            num_pad = self.num_sample_frames - len(paths)
            images = torch.cat([images[:1]] *num_pad + [images], 0)
            masks = torch.cat([masks[:1]] *num_pad + [masks], 0)
            mask_dt = torch.cat([mask_dt[:1]] *num_pad + [mask_dt], 0)
            mask_valid = torch.cat([mask_valid[:1]] *num_pad + [mask_valid], 0)
            if flows.dim() > 1:
                flows = torch.cat([flows[:1]*0] *num_pad + [flows], 0)
            bboxs = torch.cat([bboxs[:1]] * num_pad + [bboxs], 0)
            bg_images = torch.cat([bg_images[:1]] *num_pad + [bg_images], 0)
            if dino_features.dim() > 1:
                dino_features = torch.cat([dino_features[:1]] *num_pad + [dino_features], 0)
            if dino_clusters.dim() > 1:
                dino_clusters = torch.cat([dino_clusters[:1]] *num_pad + [dino_clusters], 0)
            frame_idx = torch.cat([frame_idx[:1]] *num_pad + [frame_idx], 0)
        out = (*map(none_to_nan, (images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters, seq_idx, frame_idx, self.cat_name)), )
        return out
class Quadrupeds_Image_Dataset(Dataset): | |
def __init__(self, original_data_dirs, few_shot_data_dirs, original_num=7, few_shot_num=93, num_sample_frames=2,
             in_image_size=256, out_image_size=256, is_validation=False, val_image_num=5, shuffle=False, color_jitter=None,
             load_background=False, random_flip=False, rgb_suffix='.png', load_dino_feature=False, dino_feature_dim=64,
             flow_bool=False, disable_fewshot=False, dataset_split_num=-1, **kwargs):
    """Build a category-balanced image index from sequence folders and few-shot folders.

    Args:
        original_data_dirs: mapping category -> directory whose
            sub-directories are sequences of frames.
        few_shot_data_dirs: mapping category -> directory (or list of
            directories) of flat ``*rgb<rgb_suffix>`` frames.  A leading
            ``'_'`` in the category name is stripped from its path(s); it
            marks a name that collides with an original category.
        original_num, few_shot_num: expected sizes of the two mappings
            (asserted).
        is_validation, val_image_num: few-shot categories keep their last
            ``val_image_num`` frames for validation and the rest for training.
        dataset_split_num: ``-1`` pads every category to the longest one.
            A positive value pads/splits categories to that size
            (``_pad_paths_withnum``).  Forced to ``-1`` for validation.
        **kwargs: must contain ``batch_size``.  May contain
            ``override_categories`` (only load these names),
            ``enhance_back_view`` (default False) and
            ``enhance_back_view_path`` (required when it is true).
    """
    self.original_data_dirs = original_data_dirs
    self.few_shot_data_dirs = few_shot_data_dirs
    self.original_num = original_num
    self.few_shot_num = few_shot_num
    self.image_loaders = [("rgb"+rgb_suffix, torchvision.datasets.folder.default_loader)]
    self.mask_loaders = [("mask.png", torchvision.datasets.folder.default_loader)]
    self.original_bbox_loaders = [("box.txt", box_loader)]
    self.few_shot_bbox_loaders = [("box.txt", few_shot_box_loader)]
    assert len(self.original_data_dirs.keys()) == self.original_num
    assert len(self.few_shot_data_dirs.keys()) == self.few_shot_num
    self.num_sample_frames = num_sample_frames
    # Required: __getitem__ maps each run of batch_size consecutive indices to one category.
    self.batch_size = kwargs['batch_size']
    # Debug helper: restrict loading to a subset of categories.
    self.override_categories = kwargs.get("override_categories")

    def _selected(name):
        return self.override_categories is None or name in self.override_categories

    # Original dataset: flatten every sequence of a category into one frame list.
    original_data_paths = {}
    for k, v in self.original_data_dirs.items():
        if not _selected(k):
            continue
        samples = [frame for seq in self._make_sequences(v) for frame in seq]
        if shuffle:
            random.shuffle(samples)
        original_data_paths[k] = samples

    # Few-shot dataset.
    enhance_back_view = kwargs.get('enhance_back_view', False)
    if enhance_back_view:
        enhance_back_view_path = kwargs['enhance_back_view_path']
    suffix = self.image_loaders[0][0]
    few_shot_data_paths = {}
    for k, v in self.few_shot_data_dirs.items():
        if not _selected(k):
            continue
        if k.startswith('_'):
            # The name collides with a 7-cat category; on disk the path uses the
            # un-prefixed name.  Lists are handled too (previously v.replace crashed).
            if isinstance(v, str):
                v = v.replace(k, k[1:])
            elif isinstance(v, list):
                v = [_v.replace(k, k[1:]) for _v in v]
        if isinstance(v, str):
            result = sorted(glob(os.path.join(v, '*' + suffix)))
        elif isinstance(v, list):
            result = []
            for _v in v:
                result += sorted(glob(os.path.join(_v, '*' + suffix)))
        else:
            raise NotImplementedError
        result = [p.replace(suffix, '{}') for p in result]
        sequences = result
        # Back views (if enabled) are prepended, so the validation slice below
        # never contains them.
        if enhance_back_view:
            back_view_dir = os.path.join(enhance_back_view_path, k, 'train')
            back_view_result = sorted(glob(os.path.join(back_view_dir, '*' + suffix)))
            back_view_result = [p.replace(suffix, '{}') for p in back_view_result]
            mul_bv_sequences = self._more_back_views(back_view_result, result)
            sequences = mul_bv_sequences + sequences
        # The original 7 categories use pre-defined train/test folders; few-shot
        # categories hold out their last val_image_num frames.  An explicit split
        # index (rather than seq[:-n]) keeps val_image_num=0 from emptying the
        # training split.
        split_at = max(len(sequences) - val_image_num, 0)
        if is_validation:
            sequences = sequences[split_at:]
        else:
            sequences = sequences[:split_at]
        if shuffle:
            random.shuffle(sequences)
        few_shot_data_paths[k] = sequences

    # Un-padded copies, kept for visualization.
    self.pure_ori_data_path = original_data_paths
    self.pure_fs_data_path = few_shot_data_paths
    self.few_shot_data_length = self._get_data_length(few_shot_data_paths)  # original length of each few-shot category
    if disable_fewshot:
        few_shot_data_paths = {}
    self.dataset_split_num = dataset_split_num  # -1: pad to longest; otherwise pad and split to this size
    if is_validation:
        self.dataset_split_num = -1  # validation never splits categories
    if self.dataset_split_num == -1:
        self.all_data_paths, self.one_category_num = self._pad_paths(original_data_paths, few_shot_data_paths)
        self.all_category_num = len(self.all_data_paths)
        self.all_category_names = list(self.all_data_paths.keys())
        self.original_category_names = list(self.original_data_dirs.keys())
    elif self.dataset_split_num > 0:
        self.all_data_paths, self.one_category_num, self.original_category_names = self._pad_paths_withnum(original_data_paths, few_shot_data_paths, self.dataset_split_num)
        self.all_category_num = len(self.all_data_paths)
        self.all_category_names = list(self.all_data_paths.keys())
    else:
        raise NotImplementedError
    self.in_image_size = in_image_size
    self.out_image_size = out_image_size
    self.load_background = load_background
    self.color_jitter = color_jitter
    self.image_transform = transforms.Compose([transforms.Resize(self.in_image_size), transforms.ToTensor()])
    # Nearest-neighbour resize keeps mask values binary.
    self.mask_transform = transforms.Compose([transforms.Resize(self.out_image_size, interpolation=Image.NEAREST), transforms.ToTensor()])
    self.random_flip = random_flip
    self.load_dino_feature = load_dino_feature
    if load_dino_feature:
        self.dino_feature_loaders = [(f"feat{dino_feature_dim}.png", dino_loader, dino_feature_dim)]
def _more_back_views(self, back_view_seq, seq): | |
if len(back_view_seq) == 0: | |
# for category without back views | |
return [] | |
factor = 5 | |
# length = (len(seq) // factor) * factor | |
length = (len(seq) // factor) * (factor - 1) | |
mul_f = length // len(back_view_seq) | |
pad_f = length % len(back_view_seq) | |
new_seq = mul_f * back_view_seq + back_view_seq[:pad_f] | |
return new_seq | |
def _get_data_length(self, paths): | |
data_length = {} | |
for k,v in paths.items(): | |
length = len(v) | |
data_length.update({k: length}) | |
return data_length | |
def _make_sequences(self, path):
    """Return the non-empty frame-pattern lists of all sub-directories of ``path``, sorted by name."""
    sequences = []
    for entry in sorted(os.scandir(path), key=lambda e: e.name):
        if not entry.is_dir():
            continue
        frames = self._parse_folder(entry)
        if frames:
            sequences.append(frames)
    return sequences
def _parse_folder(self, path): | |
result = sorted(glob(os.path.join(path, '*'+self.image_loaders[0][0]))) | |
result = [p.replace(self.image_loaders[0][0], '{}') for p in result] | |
if len(result) <= 0: | |
return [] | |
return result | |
def _pad_paths(self, ori_paths, fs_paths): | |
img_nums = [] | |
all_paths = copy.deepcopy(ori_paths) | |
all_paths.update(fs_paths) | |
for _, v in all_paths.items(): | |
img_nums.append(len(v)) | |
img_num = max(img_nums) | |
img_num = (img_num // self.batch_size) * self.batch_size | |
for k,v in all_paths.items(): | |
if len(v) < img_num: | |
mul_time = img_num // len(v) | |
pad_time = img_num % len(v) | |
# for each v, shuffle it | |
shuffle_v = copy.deepcopy(v) | |
new_v = [] | |
for i in range(mul_time): | |
new_v = new_v + shuffle_v | |
random.shuffle(shuffle_v) | |
del shuffle_v | |
new_v = new_v + v[0:pad_time] | |
# new_v = mul_time * v + v[0:pad_time] | |
all_paths[k] = new_v | |
elif len(v) > img_num: | |
all_paths[k] = v[:img_num] | |
else: | |
continue | |
return all_paths, img_num | |
def _pad_paths_withnum(self, ori_paths, fs_paths, split_num=1000): | |
img_num = (split_num // self.batch_size) * self.batch_size | |
all_paths = {} | |
orig_cat_names = [] | |
for k, v in ori_paths.items(): | |
total_num = ((len(v) // img_num) + 1) * img_num | |
pad_num = total_num - len(v) | |
split_num = total_num // img_num | |
new_v = copy.deepcopy(v) | |
random.shuffle(new_v) | |
all_v = v + new_v[:pad_num] | |
del new_v | |
for sn in range(split_num): | |
split_cat_name = f'{k}_' + '%03d' % sn | |
all_paths.update({ | |
split_cat_name: all_v[sn*img_num: (sn+1)*img_num] | |
}) | |
orig_cat_names.append(split_cat_name) | |
for k, v in fs_paths.items(): | |
if len(v) < img_num: | |
mul_time = img_num // len(v) | |
pad_time = img_num % len(v) | |
# for each v, shuffle it | |
shuffle_v = copy.deepcopy(v) | |
new_v = [] | |
for i in range(mul_time): | |
new_v = new_v + shuffle_v | |
random.shuffle(shuffle_v) | |
del shuffle_v | |
new_v = new_v + v[0:pad_time] | |
# new_v = mul_time * v + v[0:pad_time] | |
all_paths.update({ | |
k: new_v | |
}) | |
elif len(v) > img_num: | |
all_paths.update({ | |
k: v[:img_num] | |
}) | |
else: | |
continue | |
return all_paths, img_num, orig_cat_names | |
def _load_ids(self, path_patterns, loaders, transform=None): | |
result = [] | |
for loader in loaders: | |
for p in path_patterns: | |
x = loader[1](p.format(loader[0]), *loader[2:]) | |
if transform: | |
x = transform(x) | |
result.append(x) | |
return tuple(result) | |
def _shuffle_all(self): | |
for k,v in self.all_data_paths.items(): | |
new_v = copy.deepcopy(v) | |
random.shuffle(new_v) | |
self.all_data_paths[k] = new_v | |
return None | |
def __len__(self): | |
return self.all_category_num * self.one_category_num | |
def __getitem__(self, index):
    '''
    Return one single-frame sample. This dataset must have non-shuffled index!!

    Indices are laid out in rounds of ``batch_size * all_category_num``:
    within a round, each consecutive run of ``batch_size`` indices reads from
    one category, so a sequential sampler yields single-category batches.

    Returns a tuple ``(images, masks, mask_dt, mask_valid, flows, bboxs,
    bg_images, dino_features, dino_clusters, seq_idx, frame_idx,
    category_name)`` with every element passed through ``none_to_nan``.
    ``flows`` and ``dino_clusters`` are ``torch.zeros(1)`` placeholders.

    Raises:
        NotImplementedError: if DINO features are requested for an image path
            that matches none of the known dataset roots.
    '''
    # Flat index -> (round, position in round) -> (category, offset in batch).
    round_idx, in_round = divmod(index, self.batch_size * self.all_category_num)
    category_idx, offset = divmod(in_round, self.batch_size)
    path_idx = round_idx * self.batch_size + offset
    category_name = self.all_category_names[category_idx]
    paths = [self.all_data_paths[category_name][path_idx]]  # len 1 sequence

    if category_name in self.original_category_names:
        bbox_loaders = self.original_bbox_loaders
        use_original_bbox = True
    else:
        bbox_loaders = self.few_shot_bbox_loaders
        use_original_bbox = False

    masks = torch.stack(self._load_ids(paths, self.mask_loaders, transform=self.mask_transform), 0)  # load all images
    mask_dt = compute_distance_transform(masks)

    # Optional colour jitter with independent foreground/background parameters.
    # NOTE(review): ColorJitter.get_params is used as a callable transform,
    # which matches older torchvision releases only -- confirm pinned version.
    jitter = False
    if self.color_jitter is not None:
        prob, b, h = self.color_jitter
        if np.random.rand() < prob:
            jitter = True
            color_jitter_tsf_fg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h))
            image_transform_fg = transforms.Compose([transforms.Resize(self.in_image_size), color_jitter_tsf_fg, transforms.ToTensor()])
            color_jitter_tsf_bg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h))
            image_transform_bg = transforms.Compose([transforms.Resize(self.in_image_size), color_jitter_tsf_bg, transforms.ToTensor()])
    if jitter:
        images_fg = torch.stack(self._load_ids(paths, self.image_loaders, transform=image_transform_fg), 0)  # load all images
        images_bg = torch.stack(self._load_ids(paths, self.image_loaders, transform=image_transform_bg), 0)  # load all images
        # Recombine the two jittered versions through the mask.
        images = images_fg * masks + images_bg * (1 - masks)
    else:
        images = torch.stack(self._load_ids(paths, self.image_loaders, transform=self.image_transform), 0)  # load all images
    flows = torch.zeros(1)

    bboxs = torch.stack(self._load_ids(paths, bbox_loaders, transform=torch.FloatTensor), 0)  # load bounding boxes for all images
    if not use_original_bbox:
        # Few-shot boxes get the category index appended as a label column.
        bboxs = torch.cat([bboxs, torch.Tensor([[category_idx]]).float()], dim=-1)
    mask_valid = get_valid_mask(bboxs, (self.out_image_size, self.out_image_size))  # exclude pixels cropped outside the original image

    if self.load_background:
        bg_image = torchvision.datasets.folder.default_loader(os.path.join(os.path.dirname(paths[0]), 'background_frame.jpg'))
        if jitter:
            bg_image = color_jitter_tsf_bg(bg_image)
        bg_images = crop_image(bg_image, bboxs[:, 1:5].int().numpy(), (self.in_image_size, self.in_image_size))
    else:
        bg_images = torch.zeros_like(images)

    if self.load_dino_feature:
        # TODO: use another version of DINO here by changing the path
        new_dino_data_path = os.path.join(
            "/viscam/projects/articulated/dor/combine_all_data_for_ablation_magicpony", "data_dino_5000")
        # (image-path prefix, feature-path prefix, drop second-to-last component?)
        # The dropped component is described by the original code as "/train/".
        dino_path_rules = (
            # 7 cat data
            ("/viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new",
             "/viscam/projects/articulated/zzli/data_dino_5000/7_cat", False),
            # 100 cat
            ("/viscam/u/zzli/workspace/Animal-Data-Engine/data/data_resize_update/few_shot_data_all",
             os.path.join(new_dino_data_path, "100_cat"), True),
            # 100 cat
            ("/viscam/projects/articulated/zzli/fs_data/data_resize_update/few_shot_data_all",
             os.path.join(new_dino_data_path, "100_cat"), True),
            # back 100 cat
            ("/viscam/u/zzli/workspace/Animal-Data-Engine/data/data_resize_update/segmented_back_view_data",
             os.path.join(new_dino_data_path, "back_100_cat"), True),
            # animal3d
            ("/viscam/projects/articulated/dor/Animal-Data-Engine/data/data_resize_update/train_with_classes_filtered",
             os.path.join(new_dino_data_path, "animal3D"), True),
        )
        for src_prefix, dst_prefix, drop_parent in dino_path_rules:
            if paths[0].startswith(src_prefix):
                # Substitute only the leading prefix, not every occurrence.
                dino_path = dst_prefix + paths[0][len(src_prefix):]
                if drop_parent:
                    parts = dino_path.split("/")
                    dino_path = "/".join(parts[:-2] + parts[-1:])
                break
        else:
            raise NotImplementedError(f"no DINO feature location configured for image path {paths[0]!r}")
        dino_features = torch.stack(self._load_ids([dino_path], self.dino_feature_loaders, transform=torch.FloatTensor), 0)
    else:
        dino_features = torch.zeros(1)
    dino_clusters = torch.zeros(1)

    # Per the original author these indices are not actually used downstream.
    seq_idx = torch.LongTensor([0])
    frame_idx = torch.arange(0, 1).long()

    if self.random_flip and np.random.rand() < 0.5:
        images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters = horizontal_flip_all(images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters)

    ## pad shorter sequence by repeating the first frame at the front
    if len(paths) < self.num_sample_frames:
        num_pad = self.num_sample_frames - len(paths)
        images = torch.cat([images[:1]] * num_pad + [images], 0)
        masks = torch.cat([masks[:1]] * num_pad + [masks], 0)
        mask_dt = torch.cat([mask_dt[:1]] * num_pad + [mask_dt], 0)
        mask_valid = torch.cat([mask_valid[:1]] * num_pad + [mask_valid], 0)
        if flows.dim() > 1:
            flows = torch.cat([flows[:1] * 0] * num_pad + [flows], 0)
        bboxs = torch.cat([bboxs[:1]] * num_pad + [bboxs], 0)
        bg_images = torch.cat([bg_images[:1]] * num_pad + [bg_images], 0)
        if dino_features.dim() > 1:
            dino_features = torch.cat([dino_features[:1]] * num_pad + [dino_features], 0)
        if dino_clusters.dim() > 1:
            dino_clusters = torch.cat([dino_clusters[:1]] * num_pad + [dino_clusters], 0)
        frame_idx = torch.cat([frame_idx[:1]] * num_pad + [frame_idx], 0)

    out = (*map(none_to_nan, (images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters, seq_idx, frame_idx, category_name)), )
    return out
def get_sequence_loader_quadrupeds(original_data_dirs, few_shot_data_dirs, original_num, few_shot_num, rank, world_size, **kwargs):
    """Build the training loader over original + few-shot quadruped categories.

    ``kwargs`` is forwarded to ``Quadrupeds_Image_Dataset`` and must also
    contain ``batch_size`` and ``num_workers`` for the DataLoader. Sampling
    uses a non-shuffling DistributedSampler, and incomplete batches are dropped.

    NOTE(review): the dataset expects consecutive indices to form
    single-category batches. A non-shuffling DistributedSampler gives each
    rank every ``world_size``-th index, so with world_size > 1 a batch may
    mix categories. Confirm this is intended.

    Returns:
        list with a single DataLoader.
    """
    dataset = Quadrupeds_Image_Dataset(original_data_dirs, few_shot_data_dirs, original_num, few_shot_num, **kwargs)
    dist_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=False)
    return [
        torch.utils.data.DataLoader(
            dataset,
            sampler=dist_sampler,
            batch_size=kwargs['batch_size'],
            shuffle=False,
            drop_last=True,
            num_workers=kwargs['num_workers'],
            pin_memory=True,
        )
    ]
class Quadrupeds_Image_Test_Dataset(Dataset):
    """Evaluation dataset over few-shot quadruped categories.

    ``test_data_dirs`` maps a category name to a directory (or list of
    directories) containing ``*rgb<rgb_suffix>`` images. Siblings such as
    ``mask.png`` and ``box.txt`` (and ``feat<dim>.png`` when DINO features are
    loaded) share the same prefix. All categories are padded or truncated to
    one common length, a multiple of ``batch_size``, so consecutive runs of
    ``batch_size`` indices come from one category.
    """

    def __init__(self, test_data_dirs, num_sample_frames=2, in_image_size=256, out_image_size=256, shuffle=False, color_jitter=None, load_background=False, random_flip=False, rgb_suffix='.png', load_dino_feature=False, dino_feature_dim=64, flow_bool=False, **kwargs):
        self.few_shot_data_dirs = test_data_dirs
        # (file suffix, load function, *extra positional args)
        self.image_loaders = [("rgb"+rgb_suffix, torchvision.datasets.folder.default_loader)]
        self.mask_loaders = [("mask.png", torchvision.datasets.folder.default_loader)]
        self.original_bbox_loaders = [("box.txt", box_loader)]
        self.few_shot_bbox_loaders = [("box.txt", few_shot_box_loader)]
        self.num_sample_frames = num_sample_frames
        self.batch_size = kwargs['batch_size']  # a hack way here: required via **kwargs

        image_suffix = self.image_loaders[0][0]
        few_shot_data_paths = {}
        for k, v in self.few_shot_data_dirs.items():
            if isinstance(v, str):
                dirs = [v]
            elif isinstance(v, list):
                dirs = v
            else:
                raise NotImplementedError
            if k.startswith('_'):
                # A leading '_' marks a new-data category that shares its name
                # with a 7-cat category; on disk the name appears without '_'.
                # Applied per directory so list values work too.
                dirs = [d.replace(k, k[1:]) for d in dirs]
            result = []
            for d in dirs:
                result = result + sorted(glob(os.path.join(d, '*' + image_suffix)))
            sequences = [p.replace(image_suffix, '{}') for p in result]
            if shuffle:
                random.shuffle(sequences)
            few_shot_data_paths[k] = sequences
        # for visualization purpose
        self.pure_fs_data_path = few_shot_data_paths
        self.all_data_paths, self.one_category_num = self._pad_paths(few_shot_data_paths)
        self.all_category_num = len(self.all_data_paths.keys())
        self.all_category_names = list(self.all_data_paths.keys())
        self.in_image_size = in_image_size
        self.out_image_size = out_image_size
        self.load_background = load_background
        self.color_jitter = color_jitter
        self.image_transform = transforms.Compose([transforms.Resize(self.in_image_size), transforms.ToTensor()])
        self.mask_transform = transforms.Compose([transforms.Resize(self.out_image_size, interpolation=Image.NEAREST), transforms.ToTensor()])
        self.random_flip = random_flip
        self.load_dino_feature = load_dino_feature
        if load_dino_feature:
            self.dino_feature_loaders = [(f"feat{dino_feature_dim}.png", dino_loader, dino_feature_dim)]

    def _pad_paths(self, fs_paths):
        """Make every category list the same length.

        The common length ``img_num`` is the largest category size rounded down
        to a multiple of ``self.batch_size``. Shorter lists are repeated (the
        first copy in order, later copies reshuffled) and topped up with their
        head; longer lists are truncated. ``fs_paths`` itself is not modified.

        Returns:
            ``(all_paths, img_num)``.

        Raises:
            ValueError: if ``fs_paths`` is empty, or a category has no images
                while ``img_num > 0``.
        """
        if not fs_paths:
            raise ValueError('no test categories were given')
        all_paths = copy.deepcopy(fs_paths)
        img_num = max(len(v) for v in all_paths.values())
        img_num = (img_num // self.batch_size) * self.batch_size
        for k, v in all_paths.items():
            if len(v) < img_num:
                if not v:
                    raise ValueError(f'test category {k!r} has no images')
                mul_time, pad_time = divmod(img_num, len(v))
                shuffle_v = copy.deepcopy(v)
                new_v = []
                for _ in range(mul_time):
                    new_v = new_v + shuffle_v
                    random.shuffle(shuffle_v)
                all_paths[k] = new_v + v[0:pad_time]
            elif len(v) > img_num:
                all_paths[k] = v[:img_num]
        return all_paths, img_num

    def _load_ids(self, path_patterns, loaders, transform=None):
        """Load one item per (loader, pattern) pair, loader-major; see class doc for loader tuples."""
        result = []
        for loader in loaders:
            for p in path_patterns:
                x = loader[1](p.format(loader[0]), *loader[2:])
                if transform:
                    x = transform(x)
                result.append(x)
        return tuple(result)

    def _shuffle_all(self):
        """Rebind every category list to a shuffled deep copy."""
        for k, v in self.all_data_paths.items():
            new_v = copy.deepcopy(v)
            random.shuffle(new_v)
            self.all_data_paths[k] = new_v
        return None

    def __len__(self):
        return self.all_category_num * self.one_category_num

    def __getitem__(self, index):
        '''
        Return one single-frame sample. This dataset must have non-shuffled index!!

        Returns ``(images, masks, mask_dt, mask_valid, flows, bboxs, bg_images,
        dino_features, dino_clusters, seq_idx, frame_idx, category_name)``, each
        passed through ``none_to_nan``. Boxes always use the few-shot loader
        and get the category index appended as a label column.
        '''
        # Flat index -> (round, position in round) -> (category, offset in batch).
        round_idx, in_round = divmod(index, self.batch_size * self.all_category_num)
        category_idx, offset = divmod(in_round, self.batch_size)
        path_idx = round_idx * self.batch_size + offset
        category_name = self.all_category_names[category_idx]
        paths = [self.all_data_paths[category_name][path_idx]]  # len 1 sequence

        masks = torch.stack(self._load_ids(paths, self.mask_loaders, transform=self.mask_transform), 0)  # load all images
        mask_dt = compute_distance_transform(masks)

        # NOTE(review): ColorJitter.get_params is used as a callable transform,
        # which matches older torchvision releases only -- confirm pinned version.
        jitter = False
        if self.color_jitter is not None:
            prob, b, h = self.color_jitter
            if np.random.rand() < prob:
                jitter = True
                color_jitter_tsf_fg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h))
                image_transform_fg = transforms.Compose([transforms.Resize(self.in_image_size), color_jitter_tsf_fg, transforms.ToTensor()])
                color_jitter_tsf_bg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h))
                image_transform_bg = transforms.Compose([transforms.Resize(self.in_image_size), color_jitter_tsf_bg, transforms.ToTensor()])
        if jitter:
            images_fg = torch.stack(self._load_ids(paths, self.image_loaders, transform=image_transform_fg), 0)  # load all images
            images_bg = torch.stack(self._load_ids(paths, self.image_loaders, transform=image_transform_bg), 0)  # load all images
            images = images_fg * masks + images_bg * (1 - masks)
        else:
            images = torch.stack(self._load_ids(paths, self.image_loaders, transform=self.image_transform), 0)  # load all images
        flows = torch.zeros(1)

        bboxs = torch.stack(self._load_ids(paths, self.few_shot_bbox_loaders, transform=torch.FloatTensor), 0)  # load bounding boxes for all images
        bboxs = torch.cat([bboxs, torch.Tensor([[category_idx]]).float()], dim=-1)  # pad a label number
        mask_valid = get_valid_mask(bboxs, (self.out_image_size, self.out_image_size))  # exclude pixels cropped outside the original image

        if self.load_background:
            bg_image = torchvision.datasets.folder.default_loader(os.path.join(os.path.dirname(paths[0]), 'background_frame.jpg'))
            if jitter:
                bg_image = color_jitter_tsf_bg(bg_image)
            bg_images = crop_image(bg_image, bboxs[:, 1:5].int().numpy(), (self.in_image_size, self.in_image_size))
        else:
            bg_images = torch.zeros_like(images)

        if self.load_dino_feature:
            dino_features = torch.stack(self._load_ids(paths, self.dino_feature_loaders, transform=torch.FloatTensor), 0)
        else:
            dino_features = torch.zeros(1)
        dino_clusters = torch.zeros(1)

        # Per the original author these indices are not actually used downstream.
        seq_idx = torch.LongTensor([0])
        frame_idx = torch.arange(0, 1).long()

        if self.random_flip and np.random.rand() < 0.5:
            images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters = horizontal_flip_all(images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters)

        ## pad shorter sequence by repeating the first frame at the front
        if len(paths) < self.num_sample_frames:
            num_pad = self.num_sample_frames - len(paths)
            images = torch.cat([images[:1]] * num_pad + [images], 0)
            masks = torch.cat([masks[:1]] * num_pad + [masks], 0)
            mask_dt = torch.cat([mask_dt[:1]] * num_pad + [mask_dt], 0)
            mask_valid = torch.cat([mask_valid[:1]] * num_pad + [mask_valid], 0)
            if flows.dim() > 1:
                flows = torch.cat([flows[:1] * 0] * num_pad + [flows], 0)
            bboxs = torch.cat([bboxs[:1]] * num_pad + [bboxs], 0)
            bg_images = torch.cat([bg_images[:1]] * num_pad + [bg_images], 0)
            if dino_features.dim() > 1:
                dino_features = torch.cat([dino_features[:1]] * num_pad + [dino_features], 0)
            if dino_clusters.dim() > 1:
                dino_clusters = torch.cat([dino_clusters[:1]] * num_pad + [dino_clusters], 0)
            frame_idx = torch.cat([frame_idx[:1]] * num_pad + [frame_idx], 0)

        out = (*map(none_to_nan, (images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters, seq_idx, frame_idx, category_name)), )
        return out
def get_test_loader_quadrupeds(test_data_dirs, rank, world_size, **kwargs):
    """Build the evaluation loader for ``Quadrupeds_Image_Test_Dataset``.

    ``kwargs`` is forwarded to the dataset and must contain ``batch_size`` and
    ``num_workers`` for the DataLoader. Sampling is split across
    ``world_size`` ranks without shuffling, and incomplete batches are dropped.

    Returns:
        list with a single DataLoader.
    """
    test_set = Quadrupeds_Image_Test_Dataset(test_data_dirs, **kwargs)
    per_rank_sampler = torch.utils.data.distributed.DistributedSampler(
        test_set, num_replicas=world_size, rank=rank, shuffle=False)
    test_loader = torch.utils.data.DataLoader(
        test_set,
        sampler=per_rank_sampler,
        batch_size=kwargs['batch_size'],
        shuffle=False,
        drop_last=True,
        num_workers=kwargs['num_workers'],
        pin_memory=True,
    )
    return [test_loader]
def get_sequence_loader(data_dir, **kwargs):
    """Return a list of DataLoaders for ``data_dir``.

    For a dict, one ``NFrameSequenceDataset`` loader is built per
    ``(category name, directory)`` item. ``kwargs`` is forwarded to each
    dataset, and ``kwargs['batch_size']``, ``kwargs['shuffle']`` and
    ``kwargs['num_workers']`` configure the loader. Any other ``data_dir`` is
    delegated to ``get_sequence_loader_single`` and wrapped in a list.
    """
    if not isinstance(data_dir, dict):
        return [get_sequence_loader_single(data_dir, **kwargs)]
    return [
        torch.utils.data.DataLoader(
            NFrameSequenceDataset(cat_dir, cat_name=cat_name, **kwargs),
            batch_size=kwargs['batch_size'],
            shuffle=kwargs['shuffle'],
            num_workers=kwargs['num_workers'],
            pin_memory=True,
        )
        for cat_name, cat_dir in data_dir.items()
    ]
def get_sequence_loader_single(data_dir, mode='all_frame', is_validation=False, batch_size=256, num_workers=4,
                               in_image_size=256, out_image_size=256, debug_seq=False, num_sample_frames=2,
                               skip_beginning=4, skip_end=4, min_seq_len=10, max_seq_len=256, random_sample=False,
                               shuffle=False, dense_sample=True, color_jitter=None, load_background=False,
                               random_flip=False, rgb_suffix='.jpg', load_dino_feature=False,
                               load_dino_cluster=False, dino_feature_dim=64):
    """Build a (non-distributed) DataLoader over one ``NFrameSequenceDataset``.

    Only ``mode='n_frame'`` is supported; any other mode, including the
    default ``'all_frame'``, raises ``NotImplementedError``. ``shuffle`` is
    forwarded to the dataset, while the loader itself shuffles unless
    ``is_validation`` is set. ``max_seq_len`` is accepted but not used.
    """
    if mode != 'n_frame':
        raise NotImplementedError
    dataset = NFrameSequenceDataset(
        data_dir,
        num_sample_frames=num_sample_frames,
        skip_beginning=skip_beginning,
        skip_end=skip_end,
        min_seq_len=min_seq_len,
        in_image_size=in_image_size,
        out_image_size=out_image_size,
        debug_seq=debug_seq,
        random_sample=random_sample,
        shuffle=shuffle,
        dense_sample=dense_sample,
        color_jitter=color_jitter,
        load_background=load_background,
        random_flip=random_flip,
        rgb_suffix=rgb_suffix,
        load_dino_feature=load_dino_feature,
        load_dino_cluster=load_dino_cluster,
        dino_feature_dim=dino_feature_dim,
    )
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=not is_validation,
        num_workers=num_workers,
        pin_memory=True,
    )
def get_sequence_loader_ddp(data_dir, world_size, rank, use_few_shot=False, **kwargs):
    """Return a list of distributed DataLoaders for ``data_dir``.

    ``data_dir`` may be:
      * ``[num_original_classes, {name: dir}]`` -- a hack for the few-shot
        experiment. The dict is used, and few-shot labels start at
        ``num_original_classes``.
      * ``{name: dir}`` -- one loader per category. With ``use_few_shot`` each
        category becomes a ``FewShotImageDataset`` with a consecutive
        ``cat_num``; otherwise an ``NFrameSequenceDataset``.
      * anything else -- delegated to ``get_sequence_loader_single_ddp``.

    ``kwargs`` is forwarded to the datasets and must provide ``batch_size`` and
    ``num_workers`` for the loaders.
    """
    next_label = 0
    if isinstance(data_dir, list) and len(data_dir) == 2 and isinstance(data_dir[-1], dict):
        next_label, data_dir = data_dir[0], data_dir[-1]

    if not isinstance(data_dir, dict):
        return [get_sequence_loader_single_ddp(data_dir, world_size, rank, **kwargs)]

    loaders = []
    for cat_name, cat_dir in data_dir.items():
        if use_few_shot:
            dataset = FewShotImageDataset(cat_dir, cat_name=cat_name, cat_num=next_label, **kwargs)
            next_label += 1
        else:
            dataset = NFrameSequenceDataset(cat_dir, cat_name=cat_name, **kwargs)
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset,
            num_replicas=world_size,
            rank=rank,
        )
        loaders.append(torch.utils.data.DataLoader(
            dataset,
            sampler=sampler,
            batch_size=kwargs['batch_size'],
            shuffle=False,
            drop_last=True,
            num_workers=kwargs['num_workers'],
            pin_memory=True,
        ))
    return loaders
def get_sequence_loader_single_ddp(data_dir, world_size, rank, mode='all_frame', is_validation=False, batch_size=256,
                                   num_workers=4, in_image_size=256, out_image_size=256, debug_seq=False,
                                   num_sample_frames=2, skip_beginning=4, skip_end=4, min_seq_len=10, max_seq_len=256,
                                   random_sample=False, shuffle=False, dense_sample=True, color_jitter=None,
                                   load_background=False, random_flip=False, rgb_suffix='.jpg',
                                   load_dino_feature=False, load_dino_cluster=False, dino_feature_dim=64,
                                   flow_bool=False):
    """Build a distributed DataLoader over one ``NFrameSequenceDataset``.

    Only ``mode='n_frame'`` is supported; any other mode, including the
    default ``'all_frame'``, raises ``NotImplementedError``. A
    DistributedSampler with default arguments splits indices across
    ``world_size`` ranks, and incomplete batches are dropped.
    ``is_validation`` and ``max_seq_len`` are accepted but not used.
    """
    if mode != 'n_frame':
        raise NotImplementedError
    dataset = NFrameSequenceDataset(
        data_dir,
        num_sample_frames=num_sample_frames,
        skip_beginning=skip_beginning,
        skip_end=skip_end,
        min_seq_len=min_seq_len,
        in_image_size=in_image_size,
        out_image_size=out_image_size,
        debug_seq=debug_seq,
        random_sample=random_sample,
        shuffle=shuffle,
        dense_sample=dense_sample,
        color_jitter=color_jitter,
        load_background=load_background,
        random_flip=random_flip,
        rgb_suffix=rgb_suffix,
        load_dino_feature=load_dino_feature,
        load_dino_cluster=load_dino_cluster,
        dino_feature_dim=dino_feature_dim,
        flow_bool=flow_bool,
    )
    rank_sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    return torch.utils.data.DataLoader(
        dataset,
        sampler=rank_sampler,
        batch_size=batch_size,
        shuffle=False,
        drop_last=True,
        num_workers=num_workers,
        pin_memory=True,
    )
class ImageDataset(Dataset):
    """Single-image dataset built from every ``*rgb.jpg`` found under ``root``.

    Each sample is stored as a path pattern in which ``rgb.jpg`` is replaced by
    ``{}``, so the sibling ``mask.png`` and ``box.txt`` load from the same
    prefix. ``is_validation`` is accepted but not used.
    """

    def __init__(self, root, is_validation=False, image_size=256, color_jitter=None):
        super().__init__()
        # (file suffix, load function, *extra positional args for the loader)
        self.image_loader = ("rgb.jpg", torchvision.datasets.folder.default_loader)
        self.mask_loader = ("mask.png", torchvision.datasets.folder.default_loader)
        self.bbox_loader = ("box.txt", np.loadtxt, 'str')
        # _parse_folder reads self.image_loader, so that must be set first.
        self.samples = self._parse_folder(root)
        self.image_size = image_size
        self.color_jitter = color_jitter
        self.image_transform = transforms.Compose([
            transforms.Resize(self.image_size),
            transforms.ToTensor(),
        ])
        self.mask_transform = transforms.Compose([
            transforms.Resize(self.image_size, interpolation=Image.NEAREST),
            transforms.ToTensor(),
        ])

    def _parse_folder(self, path):
        """Recursively find ``*rgb.jpg`` files under ``path``; return sorted ``{}`` patterns."""
        suffix = self.image_loader[0]
        matches = sorted(glob(os.path.join(path, '**/*' + suffix), recursive=True))
        return [m.replace(suffix, '{}') for m in matches]

    def _load_ids(self, path, loader, transform=None):
        """Load ``path.format(loader[0])`` with ``loader[1]``; apply ``transform`` if truthy."""
        load_fn, extra_args = loader[1], loader[2:]
        value = load_fn(path.format(loader[0]), *extra_args)
        return transform(value) if transform else value

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        """Return ``(images, masks, mask_dt, flows, bboxs, bg_image, seq_idx, frame_idx)``.

        ``images`` and ``masks`` get a leading axis of size 1 via ``unsqueeze(0)``.
        ``flows`` is a ``torch.zeros(1)`` placeholder. ``bboxs`` is the
        ``box.txt`` row with its first field forced to ``'0'``. ``bg_image`` is
        ``background_frame.jpg`` next to the sample when it exists, otherwise
        ``images[0]``.
        """
        pattern = self.samples[index % len(self.samples)]
        masks = self._load_ids(pattern, self.mask_loader, transform=self.mask_transform).unsqueeze(0)
        mask_dt = compute_distance_transform(masks)

        # NOTE(review): ColorJitter.get_params is used as a callable transform,
        # which matches older torchvision releases only -- confirm pinned version.
        apply_jitter = False
        if self.color_jitter is not None:
            prob, b, h = self.color_jitter
            if np.random.rand() < prob:
                apply_jitter = True
                fg_jitter = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h))
                fg_pipeline = transforms.Compose([transforms.Resize(self.image_size), fg_jitter, transforms.ToTensor()])
                bg_jitter = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h))
                bg_pipeline = transforms.Compose([transforms.Resize(self.image_size), bg_jitter, transforms.ToTensor()])

        if not apply_jitter:
            images = self._load_ids(pattern, self.image_loader, transform=self.image_transform).unsqueeze(0)
        else:
            # Jitter foreground and background independently, then blend via the mask.
            foreground = self._load_ids(pattern, self.image_loader, transform=fg_pipeline).unsqueeze(0)
            background = self._load_ids(pattern, self.image_loader, transform=bg_pipeline).unsqueeze(0)
            images = foreground * masks + background * (1 - masks)
        flows = torch.zeros(1)

        box_fields = self._load_ids(pattern, self.bbox_loader, transform=None)
        box_fields[0] = '0'  # overwrite the leading field before float conversion
        bboxs = torch.FloatTensor(box_fields.astype('float')).unsqueeze(0)

        bg_fpath = os.path.join(os.path.dirname(pattern), 'background_frame.jpg')
        if not os.path.isfile(bg_fpath):
            bg_image = images[0]
        else:
            bg_image = torchvision.datasets.folder.default_loader(bg_fpath)
            if apply_jitter:
                bg_image = bg_jitter(bg_image)
            # NOTE(review): unlike images[0], this is not resized to image_size;
            # confirm background frames already match the sample resolution.
            bg_image = transforms.ToTensor()(bg_image)

        seq_idx = torch.LongTensor([index])
        frame_idx = torch.LongTensor([0])
        return images, masks, mask_dt, flows, bboxs, bg_image, seq_idx, frame_idx
def get_image_loader(data_dir, is_validation=False, batch_size=256, num_workers=4, image_size=256, color_jitter=None):
    """Return a non-shuffling, non-distributed DataLoader over ``ImageDataset(data_dir)``."""
    image_set = ImageDataset(
        data_dir,
        is_validation=is_validation,
        image_size=image_size,
        color_jitter=color_jitter,
    )
    return torch.utils.data.DataLoader(
        image_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )
def get_image_loader_ddp(data_dir, world_size, rank, is_validation=False, batch_size=256, num_workers=4, image_size=256, color_jitter=None):
    """Return a distributed DataLoader over ``ImageDataset(data_dir)``.

    The DistributedSampler uses its default arguments apart from
    ``num_replicas``/``rank``, and incomplete batches are dropped.
    """
    image_set = ImageDataset(
        data_dir,
        is_validation=is_validation,
        image_size=image_size,
        color_jitter=color_jitter,
    )
    rank_sampler = torch.utils.data.distributed.DistributedSampler(image_set, num_replicas=world_size, rank=rank)
    return torch.utils.data.DataLoader(
        image_set,
        sampler=rank_sampler,
        batch_size=batch_size,
        shuffle=False,
        drop_last=True,
        num_workers=num_workers,
        pin_memory=True,
    )