| | |
| | |
| |
|
| | |
| | |
| |
|
| | import glob |
| | import json |
| | import os |
| |
|
| | import numpy as np |
| | import pandas as pd |
| | import torch |
| |
|
| | from PIL import Image as PILImage |
| |
|
| | try: |
| | from pycocotools import mask as mask_utils |
| | except: |
| | pass |
| |
|
| |
|
class JSONSegmentLoader:
    """Load per-frame object masks stored as RLE entries in a JSON file.

    The JSON file is either a list of per-frame annotations, or a dict that
    holds that list under the key "masklet" (falling back to "masks") and,
    optionally, an "fps" entry. Each per-frame annotation is indexed by object
    id and holds, for every object, either an RLE entry or None.
    """

    def __init__(self, video_json_path, ann_every=1, frames_fps=24, valid_obj_ids=None):
        """
        Args:
            video_json_path: path of the JSON annotation file.
            ann_every: annotations are provided every `ann_every`-th frame.
                Overridden when the file carries its own "fps" entry.
            frames_fps: frame rate of the frames; combined with the file's
                "fps" entry (if any) to derive `ann_every`.
            valid_obj_ids: optional iterable of object ids; when given,
                `load` only returns objects from this set.
        Raises:
            NotImplementedError: the JSON top-level value is neither a list
                nor a dict.
            AssertionError: `frames_fps` is not a multiple of the fps stored
                in the file.
        """
        self.ann_every = ann_every
        self.valid_obj_ids = valid_obj_ids

        with open(video_json_path, "r") as f:
            data = json.load(f)

        if isinstance(data, list):
            self.frame_annots = data
        elif isinstance(data, dict):
            masklet_field_name = "masklet" if "masklet" in data else "masks"
            self.frame_annots = data[masklet_field_name]
            if "fps" in data:
                # "fps" may be stored as a scalar or as a list; only the
                # first list element is used.
                fps_field = data["fps"]
                if isinstance(fps_field, list):
                    annotations_fps = int(fps_field[0])
                else:
                    annotations_fps = int(fps_field)
                assert frames_fps % annotations_fps == 0
                self.ann_every = frames_fps // annotations_fps
        else:
            raise NotImplementedError

    def load(self, frame_id, obj_ids=None):
        """Decode the masks annotated on frame `frame_id`.

        Args:
            frame_id: frame index; must be a multiple of `self.ann_every`.
            obj_ids: optional iterable of object ids to restrict the result to.
        Return:
            dict mapping each selected object id (in ascending order) to its
            decoded mask tensor, or to None when the object has no annotation
            on this frame.
        """
        assert frame_id % self.ann_every == 0
        rle_mask = self.frame_annots[frame_id // self.ann_every]

        # Start from every object present in the annotation, then narrow down
        # by the loader-level and the call-level filters.
        valid_objs_ids = set(range(len(rle_mask)))
        if self.valid_obj_ids is not None:
            valid_objs_ids &= set(self.valid_obj_ids)
        if obj_ids is not None:
            valid_objs_ids &= set(obj_ids)
        valid_objs_ids = sorted(valid_objs_ids)

        # Every selected object defaults to None; only annotated ones are
        # decoded and filled in below.
        segments = dict.fromkeys(valid_objs_ids)
        annotated_ids = [
            obj_id for obj_id in valid_objs_ids if rle_mask[obj_id] is not None
        ]

        # Skip decoding when nothing is annotated: the decoder is never handed
        # an empty list.
        if annotated_ids:
            # All annotated masks are decoded in a single call; the decoded
            # array is permuted so that the mask index comes first.
            raw_segments = torch.from_numpy(
                mask_utils.decode([rle_mask[obj_id] for obj_id in annotated_ids])
            ).permute(2, 0, 1)
            for idx, obj_id in enumerate(annotated_ids):
                segments[obj_id] = raw_segments[idx]
        return segments

    def get_valid_obj_frames_ids(self, num_frames_min=None):
        """List, for every object, the frame ids on which it is annotated.

        Args:
            num_frames_min: when given, objects annotated on fewer than this
                many frames are left out of the result.
        Return:
            dict mapping object id to the list of frame ids (annotation index
            times `self.ann_every`) where its annotation is not None. An empty
            dict is returned when there are no annotations at all.
        """
        if not self.frame_annots:
            return {}

        # The number of objects is taken from the first annotated frame.
        num_objects = len(self.frame_annots[0])
        res = {obj_id: [] for obj_id in range(num_objects)}

        for annot_idx, annot in enumerate(self.frame_annots):
            frame_id = int(annot_idx * self.ann_every)
            for obj_id in range(num_objects):
                if annot[obj_id] is not None:
                    res[obj_id].append(frame_id)

        if num_frames_min is not None:
            res = {
                obj_id: valid_frames
                for obj_id, valid_frames in res.items()
                if len(valid_frames) >= num_frames_min
            }

        return res
| |
|
| |
|
class PalettisedPNGSegmentLoader:
    """SegmentLoader for datasets with masks stored as palettised PNGs."""

    def __init__(self, video_png_root):
        """
        Args:
            video_png_root: the folder that contains all the masks stored as
                PNG files, one file per frame, named after the frame id.
        """
        self.video_png_root = video_png_root

        # Map every frame id to the file name of its mask. The file stem is
        # parsed as an integer, so any zero-padding width is accepted.
        self.frame_id_to_png_filename = {}
        for filename in os.listdir(self.video_png_root):
            stem, _ = os.path.splitext(filename)
            try:
                frame_id = int(stem)
            except ValueError:
                # Entries whose name is not a frame id (e.g. OS metadata
                # files) cannot be masks; skip them rather than crash.
                continue
            self.frame_id_to_png_filename[frame_id] = filename

    def load(self, frame_id):
        """
        Load the single palettised mask of frame `frame_id` from disk.
        Args:
            frame_id: int, selects the mask file
        Return:
            binary_segments: dict mapping every non-zero palette index found
                in the mask to a boolean tensor that is True where the mask
                equals that index
        Raises:
            KeyError: no mask file was indexed for `frame_id`.
        """
        mask_path = os.path.join(
            self.video_png_root, self.frame_id_to_png_filename[frame_id]
        )

        # The context manager releases the file handle as soon as the pixel
        # data has been copied into the array.
        with PILImage.open(mask_path) as png:
            masks = np.array(png.convert("P"))

        # Distinct palette indices present in the mask; index 0 is excluded.
        object_id = pd.unique(masks.flatten())
        object_id = object_id[object_id != 0]

        return {i: torch.from_numpy(masks == i) for i in object_id}

    def __len__(self):
        """Return the number of frames that have a mask file."""
        return len(self.frame_id_to_png_filename)
| |
|
| |
|
class MultiplePNGSegmentLoader:
    """SegmentLoader for masks stored as one PNG per object and per frame.

    Masks live at f'{obj_folder}/{frame_id:05d}.png', where the folder name is
    an integer; the object id reported to callers is that integer plus one.
    """

    def __init__(self, video_png_root, single_object_mode=False):
        """
        Args:
            video_png_root: the folder that contains all the masks stored as
                PNG files. In single object mode this is the folder of one
                object; otherwise it is the folder holding one sub-folder per
                object.
            single_object_mode: whether to load only a single object at a time
        Raises:
            IndexError: no PNG mask could be found under `video_png_root`.
        """
        self.video_png_root = video_png_root
        self.single_object_mode = single_object_mode

        # Read one arbitrary mask to learn the mask resolution; it is needed
        # to build all-False masks for frames that have no PNG on disk.
        if self.single_object_mode:
            pattern = os.path.join(video_png_root, "*.png")
        else:
            pattern = os.path.join(video_png_root, "*", "*.png")
        candidates = glob.glob(pattern)
        if not candidates:
            raise IndexError(f"no PNG mask found matching {pattern!r}")
        with PILImage.open(candidates[0]) as png:
            tmp_mask = np.array(png)
        self.H = tmp_mask.shape[0]
        self.W = tmp_mask.shape[1]

        if self.single_object_mode:
            self.obj_id = self._obj_id_from_folder(video_png_root)
        else:
            self.obj_id = None

    @staticmethod
    def _obj_id_from_folder(obj_folder):
        """Return the object id encoded in a folder name (folder number + 1)."""
        # normpath drops a trailing separator and basename honours the
        # platform separator, unlike splitting on "/".
        return int(os.path.basename(os.path.normpath(obj_folder))) + 1

    def _read_binary_mask(self, mask_path):
        """Return the mask at `mask_path` as a boolean tensor (value > 0).

        A missing file yields an all-False mask of shape (self.H, self.W).
        """
        if os.path.exists(mask_path):
            with PILImage.open(mask_path) as png:
                mask = np.array(png)
        else:
            mask = np.zeros((self.H, self.W), dtype=bool)
        return torch.from_numpy(mask > 0)

    def load(self, frame_id):
        """Load the mask(s) of frame `frame_id`; see the two helpers below."""
        if self.single_object_mode:
            return self._load_single_png(frame_id)
        return self._load_multiple_pngs(frame_id)

    def _load_single_png(self, frame_id):
        """
        load single png from the disk (path: f'{self.video_png_root}/{frame_id:05d}.png')
        Args:
            frame_id: int, define the mask path
        Return:
            binary_segments: dict with the single key `self.obj_id`
        """
        mask_path = os.path.join(self.video_png_root, f"{frame_id:05d}.png")
        return {self.obj_id: self._read_binary_mask(mask_path)}

    def _load_multiple_pngs(self, frame_id):
        """
        load multiple png masks from the disk (path: f'{obj_folder}/{frame_id:05d}.png')
        Args:
            frame_id: int, define the mask path
        Return:
            binary_segments: dict mapping object id to its boolean mask
        """
        # Every sub-folder of the root is one object; plain files that happen
        # to sit next to the object folders are ignored.
        all_objects = sorted(
            entry
            for entry in glob.glob(os.path.join(self.video_png_root, "*"))
            if os.path.isdir(entry)
        )
        num_objects = len(all_objects)
        assert num_objects > 0

        binary_segments = {}
        for obj_folder in all_objects:
            obj_id = self._obj_id_from_folder(obj_folder)
            mask_path = os.path.join(obj_folder, f"{frame_id:05d}.png")
            binary_segments[obj_id] = self._read_binary_mask(mask_path)

        return binary_segments

    def __len__(self):
        # NOTE(review): this returns None, so len(loader) raises TypeError.
        # The intended length cannot be determined from this file; left
        # unchanged until the expected semantics are confirmed.
        return
| |
|
| |
|
| | class LazySegments: |
| | """ |
| | Only decodes segments that are actually used. |
| | """ |
| |
|
| | def __init__(self): |
| | self.segments = {} |
| | self.cache = {} |
| |
|
| | def __setitem__(self, key, item): |
| | self.segments[key] = item |
| |
|
| | def __getitem__(self, key): |
| | if key in self.cache: |
| | return self.cache[key] |
| | rle = self.segments[key] |
| | mask = torch.from_numpy(mask_utils.decode([rle])).permute(2, 0, 1)[0] |
| | self.cache[key] = mask |
| | return mask |
| |
|
| | def __contains__(self, key): |
| | return key in self.segments |
| |
|
| | def __len__(self): |
| | return len(self.segments) |
| |
|
| | def keys(self): |
| | return self.segments.keys() |
| |
|
| |
|
class SA1BSegmentLoader:
    """Segment loader for a single image whose masks come from a JSON file.

    The JSON file holds a list under the key "annotations"; every annotation
    provides an "area", a "segmentation" (RLE entry) and optionally an
    "uncertain_iou" value.
    """

    def __init__(
        self,
        video_mask_path,
        mask_area_frac_thresh=1.1,
        video_frame_path=None,
        uncertain_iou=-1,
    ):
        """
        Args:
            video_mask_path: path of the JSON annotation file.
            mask_area_frac_thresh: when <= 1.0, annotations whose area is at
                least this fraction of the image area are dropped. Values
                above 1.0 disable the area filter.
            video_frame_path: path of the image; only read (for its size)
                when the area filter is enabled.
            uncertain_iou: annotations carrying an "uncertain_iou" value
                below this threshold are dropped.
        Raises:
            ValueError: the area filter is enabled but `video_frame_path` is
                None.
        """
        with open(video_mask_path, "r") as f:
            self.frame_annots = json.load(f)

        filter_by_area = mask_area_frac_thresh <= 1.0
        if filter_by_area:
            if video_frame_path is None:
                raise ValueError(
                    "video_frame_path is required when mask_area_frac_thresh <= 1.0"
                )
            # Only the image size is needed; close the file right away.
            with PILImage.open(video_frame_path) as frame:
                orig_w, orig_h = frame.size
            area = orig_w * orig_h

        self.frame_annots = self.frame_annots["annotations"]

        rle_masks = []
        for frame_annot in self.frame_annots:
            # Written as "not (area > 0)" so that a NaN area is skipped too.
            if not frame_annot["area"] > 0:
                continue
            if ("uncertain_iou" in frame_annot) and (
                frame_annot["uncertain_iou"] < uncertain_iou
            ):
                continue
            if filter_by_area and (frame_annot["area"] / area) >= mask_area_frac_thresh:
                continue
            rle_masks.append(frame_annot["segmentation"])

        # Kept masks are re-indexed from 0 and decoded lazily on access.
        self.segments = LazySegments()
        for i, rle in enumerate(rle_masks):
            self.segments[i] = rle

    def load(self, frame_idx):
        """Return the segments container; `frame_idx` is ignored."""
        return self.segments
| |
|