from typing import Dict, List, Tuple, Union

import torch

from matanyone2.inference.object_info import ObjectInfo
|
|
|
|
class ObjectManager:
    """
    Track the set of objects and map between object IDs and temporary IDs.

    Object IDs are immutable: the same ID always represents the same object.
    Temporary IDs are the positions of each object in the tensor. They are
    reassigned (compacted) whenever objects are removed.
    Temporary IDs start from 1 and are contiguous.
    """

    def __init__(self):
        # ObjectInfo -> temporary ID (1-based, contiguous).
        self.obj_to_tmp_id: Dict[ObjectInfo, int] = {}
        # Temporary ID -> ObjectInfo. Insertion order always follows increasing
        # temporary ID (add appends len+1; delete rebuilds in order 1..n), which
        # realize_dict / make_one_hot rely on when iterating.
        self.tmp_id_to_obj: Dict[int, ObjectInfo] = {}
        # Object ID -> ObjectInfo. Rebuilt from obj_to_tmp_id after every change.
        self.obj_id_to_obj: Dict[int, ObjectInfo] = {}

        # Every object ID ever added, in order of first addition. Never pruned
        # by delete_objects.
        self.all_historical_object_ids: List[int] = []

    def _recompute_obj_id_to_obj_mapping(self) -> None:
        """Rebuild ``obj_id_to_obj`` from the currently tracked objects."""
        self.obj_id_to_obj = {obj.id: obj for obj in self.obj_to_tmp_id}

    def add_new_objects(
        self, objects: Union[List[ObjectInfo], ObjectInfo, List[int], int]
    ) -> Tuple[List[int], List[int]]:
        """
        Register objects, assigning temporary IDs to those not yet tracked.

        Args:
            objects: a single ObjectInfo or object ID, or a list of them.
                Plain ints are interpreted as object IDs.

        Returns:
            (tmp_ids, obj_ids): for each input object, in input order, its
            temporary ID and its object ID. Already-tracked objects keep their
            existing temporary ID. New objects are appended after the current
            last temporary ID.

        Note:
            For a new object, the stored ObjectInfo is constructed from the id
            alone (``ObjectInfo(id=obj.id)``). Any other state on the passed-in
            ObjectInfo is not copied.
            An assert requires the returned temporary IDs to be non-decreasing.
        """
        if not isinstance(objects, list):
            objects = [objects]

        corresponding_tmp_ids = []
        corresponding_obj_ids = []
        for obj in objects:
            if isinstance(obj, int):
                obj = ObjectInfo(id=obj)

            if obj in self.obj_to_tmp_id:
                # Already tracked: reuse its temporary ID.
                corresponding_tmp_ids.append(self.obj_to_tmp_id[obj])
                corresponding_obj_ids.append(obj.id)
            else:
                # Store our own instance so the caller's object is not aliased.
                new_obj = ObjectInfo(id=obj.id)
                # Temporary IDs are contiguous from 1, so the next one is len + 1.
                new_tmp_id = len(self.obj_to_tmp_id) + 1
                self.obj_to_tmp_id[new_obj] = new_tmp_id
                self.tmp_id_to_obj[new_tmp_id] = new_obj
                self.all_historical_object_ids.append(new_obj.id)
                corresponding_tmp_ids.append(new_tmp_id)
                corresponding_obj_ids.append(new_obj.id)

        self._recompute_obj_id_to_obj_mapping()
        assert corresponding_tmp_ids == sorted(corresponding_tmp_ids)
        return corresponding_tmp_ids, corresponding_obj_ids

    def delete_objects(self, obj_ids_to_remove: Union[int, List[int]]) -> None:
        """
        Stop tracking the given object IDs and compact the temporary IDs.

        Remaining objects keep their relative order and are renumbered 1..n.
        IDs that are not currently tracked are ignored. Deleted IDs remain in
        ``all_historical_object_ids``.

        Args:
            obj_ids_to_remove: one object ID or a list of object IDs.
        """
        if isinstance(obj_ids_to_remove, int):
            obj_ids_to_remove = [obj_ids_to_remove]
        # Use a set for O(1) membership checks instead of O(n) list scans.
        ids_to_remove = set(obj_ids_to_remove)

        # Walk the old temporary IDs in order so survivors keep their order.
        survivors = [
            self.tmp_id_to_obj[tmp_id]
            for tmp_id in range(1, len(self.obj_to_tmp_id) + 1)
            if self.tmp_id_to_obj[tmp_id].id not in ids_to_remove
        ]

        self.obj_to_tmp_id = {obj: tmp_id for tmp_id, obj in enumerate(survivors, start=1)}
        self.tmp_id_to_obj = {tmp_id: obj for tmp_id, obj in enumerate(survivors, start=1)}
        self._recompute_obj_id_to_obj_mapping()

    def purge_inactive_objects(
        self, max_missed_detection_count: int
    ) -> Tuple[bool, List[int], List[int]]:
        """
        Delete every object whose ``poke_count`` exceeds the given threshold.

        Args:
            max_missed_detection_count: objects with
                ``poke_count > max_missed_detection_count`` are deleted.

        Returns:
            (purge_activated, tmp_id_to_keep, obj_id_to_keep):
            - purge_activated: whether at least one object was deleted.
            - tmp_id_to_keep: temporary IDs of the kept objects, expressed in
              the numbering *before* deletion. This lets callers select the
              kept entries from existing tensors.
            - obj_id_to_keep: object IDs of the kept objects.
        """
        obj_id_to_be_deleted = []
        tmp_id_to_keep = []
        obj_id_to_keep = []

        for obj, tmp_id in self.obj_to_tmp_id.items():
            if obj.poke_count > max_missed_detection_count:
                obj_id_to_be_deleted.append(obj.id)
            else:
                tmp_id_to_keep.append(tmp_id)
                obj_id_to_keep.append(obj.id)

        purge_activated = len(obj_id_to_be_deleted) > 0
        if purge_activated:
            self.delete_objects(obj_id_to_be_deleted)
        return purge_activated, tmp_id_to_keep, obj_id_to_keep

    def tmp_to_obj_cls(self, mask) -> torch.Tensor:
        """
        Convert a mask of temporary IDs into a mask of object IDs.

        Elements whose value is not a tracked temporary ID become 0.
        """
        new_mask = torch.zeros_like(mask)
        for tmp_id, obj in self.tmp_id_to_obj.items():
            # Compare against the input mask, not new_mask, so an object ID
            # already written can never be matched as a temporary ID later.
            new_mask[mask == tmp_id] = obj.id
        return new_mask

    def get_tmp_to_obj_mapping(self) -> Dict[int, ObjectInfo]:
        """
        Return a shallow copy of the temporary ID -> ObjectInfo mapping.

        A copy is returned so callers can keep or modify it without affecting
        the manager's internal state.
        """
        # Keys of tmp_id_to_obj are temporary IDs and values are ObjectInfo.
        # Unpacking them the other way round would access ``.id`` on an int.
        return {tmp_id: obj for tmp_id, obj in self.tmp_id_to_obj.items()}

    def realize_dict(self, obj_dict, dim=1) -> torch.Tensor:
        """
        Stack per-object values in temporary-ID order.

        Args:
            obj_dict: mapping from object ID to a tensor. Every tracked object
                must be present; extra keys are ignored.
            dim: dimension along which the values are stacked.

        Raises:
            NotImplementedError: if a tracked object ID is missing from obj_dict.
        """
        output = []
        for obj in self.tmp_id_to_obj.values():
            if obj.id not in obj_dict:
                raise NotImplementedError(
                    f"realize_dict: no entry for tracked object ID {obj.id}")
            output.append(obj_dict[obj.id])
        # NOTE(review): torch.stack raises if no objects are tracked (empty list).
        return torch.stack(output, dim=dim)

    def make_one_hot(self, cls_mask) -> torch.Tensor:
        """
        Build a boolean one-hot tensor from a mask of object IDs.

        Plane i (in temporary-ID order) is ``cls_mask == object_id``. The result
        has shape ``(num_obj, *cls_mask.shape)``, with a zero-length first
        dimension when no objects are tracked.
        """
        planes = [cls_mask == obj.id for obj in self.tmp_id_to_obj.values()]
        if not planes:
            return torch.zeros((0, *cls_mask.shape), dtype=torch.bool, device=cls_mask.device)
        return torch.stack(planes, dim=0)

    @property
    def all_obj_ids(self) -> List[int]:
        """Object IDs currently tracked, in temporary-ID order."""
        return [k.id for k in self.obj_to_tmp_id]

    @property
    def num_obj(self) -> int:
        """Number of objects currently tracked."""
        return len(self.obj_to_tmp_id)

    def has_all(self, objects: List[int]) -> bool:
        """Return True if every given object (ID or ObjectInfo) is currently tracked."""
        for obj in objects:
            # Int IDs are looked up in the int-keyed map directly, so the check
            # does not depend on ObjectInfo comparing equal to plain ints.
            lookup = self.obj_id_to_obj if isinstance(obj, int) else self.obj_to_tmp_id
            if obj not in lookup:
                return False
        return True

    def find_object_by_id(self, obj_id) -> ObjectInfo:
        """Return the tracked ObjectInfo for ``obj_id``; raises KeyError if untracked."""
        return self.obj_id_to_obj[obj_id]

    def find_tmp_by_id(self, obj_id) -> int:
        """Return the temporary ID of ``obj_id``; raises KeyError if untracked."""
        return self.obj_to_tmp_id[self.obj_id_to_obj[obj_id]]
|
|