|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
import random |
|
from copy import deepcopy |
|
|
|
import numpy as np |
|
|
|
import torch |
|
from iopath.common.file_io import g_pathmgr |
|
from PIL import Image as PILImage |
|
from torchvision.datasets.vision import VisionDataset |
|
|
|
from training.dataset.vos_raw_dataset import VOSRawDataset |
|
from training.dataset.vos_sampler import VOSSampler |
|
from training.dataset.vos_segment_loader import JSONSegmentLoader |
|
|
|
from training.utils.data_utils import Frame, Object, VideoDatapoint |
|
|
|
# Maximum number of attempts VOSDataset._get_datapoint makes (in training mode)
# to fetch and sample a video, switching to a random other video after each
# failure.
MAX_RETRIES: int = 100
|
|
|
|
|
class VOSDataset(VisionDataset):
    """
    Map-style dataset that turns raw VOS videos into ``VideoDatapoint`` samples.

    For each index a video is fetched from ``video_dataset``, a subset of its
    frames and object ids is drawn by ``sampler``, the selected frames are
    loaded as RGB images together with one segment per (frame, object), and the
    result is passed through every callable in ``transforms`` in order.
    """

    def __init__(
        self,
        transforms,
        training: bool,
        video_dataset: VOSRawDataset,
        sampler: VOSSampler,
        multiplier: int,
        always_target=True,
        target_segments_available=True,
    ):
        """
        Args:
            transforms: iterable of callables, each invoked as
                ``transform(datapoint, epoch=self.curr_epoch)`` and returning
                the (possibly new) datapoint.
            training: if True, a failed fetch/sample is retried on a random
                other video (up to ``MAX_RETRIES`` attempts); if False, the
                first failure is re-raised.
            video_dataset: raw video source; ``len()`` and ``get_video(idx)``
                are used.
            sampler: object whose ``sample(video, segment_loader, epoch=...)``
                returns the frames and object ids to load.
            multiplier: value every entry of ``repeat_factors`` is set to
                (NOTE(review): presumably consumed by a repeat-factor sampler
                elsewhere — confirm).
            always_target: if True, an object missing from a frame's segments
                gets an all-zero segment; if False, it is omitted from that frame.
            target_segments_available: stored as an attribute only; it is not
                read anywhere in this class (presumably used elsewhere — confirm).
        """
        self._transforms = transforms
        self.training = training
        self.video_dataset = video_dataset
        self.sampler = sampler

        # One repeat factor per raw video, all equal to `multiplier`.
        self.repeat_factors = torch.ones(len(self.video_dataset), dtype=torch.float32)
        self.repeat_factors *= multiplier
        print(f"Raw dataset length = {len(self.video_dataset)}")

        # Forwarded as `epoch` to the sampler and transforms; presumably advanced
        # externally between epochs (not updated in this class).
        self.curr_epoch = 0
        self.always_target = always_target
        self.target_segments_available = target_segments_available

    def _get_datapoint(self, idx):
        """
        Fetch, sample, construct and transform the datapoint at ``idx``.

        In training mode a failure while fetching or sampling the video is
        logged and retried on a uniformly random other video.

        Raises:
            Exception: in eval mode, whatever the fetch/sample step raised.
            RuntimeError: in training mode, when all ``MAX_RETRIES`` attempts
                failed (chained to the last underlying exception).
        """
        last_exc = None
        for retry in range(MAX_RETRIES):
            try:
                if isinstance(idx, torch.Tensor):
                    idx = idx.item()

                video, segment_loader = self.video_dataset.get_video(idx)
                sampled_frms_and_objs = self.sampler.sample(
                    video, segment_loader, epoch=self.curr_epoch
                )
                break
            except Exception as e:
                if not self.training:
                    # Evaluation must be deterministic: never substitute a video.
                    raise
                # `e` is unbound after the except block; keep it for chaining.
                last_exc = e
                logging.warning(
                    "Loading failed (id=%s); Retry %s with exception: %s",
                    idx,
                    retry,
                    e,
                )
                idx = random.randrange(0, len(self.video_dataset))
        else:
            # Loop finished without `break`: every attempt failed. Without this,
            # the code below would hit unbound locals (UnboundLocalError).
            raise RuntimeError(
                f"Failed to load a datapoint after {MAX_RETRIES} attempts"
            ) from last_exc

        datapoint = self.construct(video, sampled_frms_and_objs, segment_loader)
        for transform in self._transforms:
            datapoint = transform(datapoint, epoch=self.curr_epoch)
        return datapoint

    def construct(self, video, sampled_frms_and_objs, segment_loader):
        """
        Constructs a VideoDatapoint sample to pass to transforms.

        Args:
            video: raw video; only its ``video_id`` is read here.
            sampled_frms_and_objs: sampler output exposing ``frames`` (frames
                to load, each with a ``frame_idx``) and ``object_ids``.
            segment_loader: object whose ``load(frame_idx, ...)`` returns a
                mapping from object id to a segment tensor.

        Returns:
            A ``VideoDatapoint`` with one ``Frame`` per sampled frame and
            ``size=(h, w)`` taken from the last loaded image.

        Raises:
            ValueError: if the sampler produced no frames.
        """
        sampled_frames = sampled_frms_and_objs.frames
        sampled_object_ids = sampled_frms_and_objs.object_ids
        if len(sampled_frames) == 0:
            # Without at least one image there is no (h, w) to report.
            raise ValueError(f"No frames were sampled for video {video.video_id}")

        images = []
        rgb_images = load_images(sampled_frames)

        for frame_idx, frame in enumerate(sampled_frames):
            w, h = rgb_images[frame_idx].size  # PIL size is (width, height)
            frame_datapoint = Frame(data=rgb_images[frame_idx], objects=[])
            images.append(frame_datapoint)

            if isinstance(segment_loader, JSONSegmentLoader):
                # This loader accepts the requested object ids (presumably so it
                # can skip decoding the others).
                segments = segment_loader.load(
                    frame.frame_idx, obj_ids=sampled_object_ids
                )
            else:
                segments = segment_loader.load(frame.frame_idx)

            for obj_id in sampled_object_ids:
                if obj_id in segments:
                    assert (
                        segments[obj_id] is not None
                    ), "None targets are not supported"
                    segment = segments[obj_id].to(torch.uint8)
                else:
                    if not self.always_target:
                        continue
                    # Object absent in this frame: use an empty target.
                    segment = torch.zeros(h, w, dtype=torch.uint8)

                frame_datapoint.objects.append(
                    Object(
                        object_id=obj_id,
                        frame_index=frame.frame_idx,
                        segment=segment,
                    )
                )
        return VideoDatapoint(
            frames=images,
            video_id=video.video_id,
            size=(h, w),
        )

    def __getitem__(self, idx):
        """Return the fully transformed datapoint for ``idx``."""
        return self._get_datapoint(idx)

    def __len__(self):
        """Number of raw videos (repeat factors are not applied here)."""
        return len(self.video_dataset)
|
|
|
|
|
def load_images(frames):
    """
    Materialize each sampled frame as a PIL image, preserving order.

    Frames with in-memory data (``frame.data`` not None) are converted via
    ``tensor_2_PIL``. All others are read from ``frame.image_path`` and
    converted to RGB. A path that occurs several times is read from storage
    only once; each later occurrence receives an independent deep copy of the
    first decoded image.
    """
    decoded_by_path = {}
    images = []
    for frm in frames:
        if frm.data is not None:
            images.append(tensor_2_PIL(frm.data))
            continue

        img_path = frm.image_path
        if img_path in decoded_by_path:
            # Copy so later in-place transforms cannot alias the first image.
            images.append(deepcopy(decoded_by_path[img_path]))
            continue

        with g_pathmgr.open(img_path, "rb") as fh:
            # convert() reads the pixel data, so the image is usable after close.
            first_image = PILImage.open(fh).convert("RGB")
        decoded_by_path[img_path] = first_image
        images.append(first_image)
    return images
|
|
|
|
|
def tensor_2_PIL(data: torch.Tensor) -> PILImage.Image:
    """
    Convert a channel-first float image tensor to an 8-bit PIL image.

    ``data`` must be 3-D and channel-first (C, H, W); the transpose below moves
    channels last, as ``PIL.Image.fromarray`` expects. Values are presumably in
    [0, 1] (they are scaled by 255 into the uint8 range).

    The tensor is detached and moved to CPU, scaled by 255, rounded to the
    nearest integer, clipped to [0, 255], and cast to uint8.
    """
    array = data.detach().cpu().numpy().transpose((1, 2, 0)) * 255.0
    # Round instead of truncating: after an x / 255 normalization, float error
    # can land just below an integer (e.g. 28.99998) and truncation would lose
    # a level. Clip so out-of-range values saturate instead of wrapping
    # (float -> uint8 casts of out-of-range values are undefined in NumPy).
    array = np.clip(np.rint(array), 0, 255).astype(np.uint8)
    return PILImage.fromarray(array)
|
|