# NOTE(review): removed page-scrape artifacts ("Spaces:" / "Runtime error" banners)
# that were not part of the original source file.
import torch
from typing import List
class KeyValueMemoryStore:
    """
    Works for key/value pair type storage, e.g., working and long-term memory.

    Keys (and shrinkage/selection terms) are stored in a single tensor shared
    between all groups/objects; values are stored as a list indexed by object
    group.

    An object group is created when new objects enter the video.
    Objects in the same group share the same temporal extent, i.e., objects
    initialized in the same frame are in the same group.
    For DAVIS/interactive, there is only one object group.
    For YouTubeVOS, there can be multiple object groups.
    """

    def __init__(self, count_usage: bool):
        # whether to maintain per-element usage statistics (used for eviction)
        self.count_usage = count_usage
        # keys are stored in a single tensor and are shared between groups/objects
        # values are stored as a list indexed by object groups
        self.k = None
        self.v = []
        self.obj_groups = []
        # for debugging only
        self.all_objects = []
        # shrinkage and selection are also single tensors
        self.s = self.e = None
        # usage counters; lazily initialized on the first add()
        if self.count_usage:
            self.use_count = self.life_count = None

    def add(self, key, value, shrinkage, selection, objects: List[int]):
        """
        Append new elements to the memory bank.

        key/shrinkage/selection: tensors concatenated along the last dim
            (assumed (B, C, N) — TODO confirm against caller).
        value: a tensor indexable by object when `objects` is given (working
            memory), or a list of per-group tensors when `objects` is None
            (long-term memory).
        objects: 1-based object ids present in `value`, or None.
        """
        # fresh usage stats for the new elements; life starts near zero
        # (1e-7 avoids division by zero in get_usage before any update)
        new_count = torch.zeros((key.shape[0], 1, key.shape[2]), device=key.device, dtype=torch.float32)
        new_life = torch.zeros((key.shape[0], 1, key.shape[2]), device=key.device, dtype=torch.float32) + 1e-7

        # add the key
        if self.k is None:
            self.k = key
            self.s = shrinkage
            self.e = selection
            if self.count_usage:
                self.use_count = new_count
                self.life_count = new_life
        else:
            self.k = torch.cat([self.k, key], -1)
            if shrinkage is not None:
                self.s = torch.cat([self.s, shrinkage], -1)
            if selection is not None:
                self.e = torch.cat([self.e, selection], -1)
            if self.count_usage:
                self.use_count = torch.cat([self.use_count, new_count], -1)
                self.life_count = torch.cat([self.life_count, new_life], -1)

        # add the value
        if objects is not None:
            # When objects is given, v is a tensor; used in working memory
            assert isinstance(value, torch.Tensor)
            # First consume objects that are already in the memory bank
            # cannot use set here because we need to preserve order
            # shift by one as background is not part of value
            remaining_objects = [obj-1 for obj in objects]
            for gi, group in enumerate(self.obj_groups):
                for obj in group:
                    # should properly raise an error if there are overlaps in obj_groups
                    remaining_objects.remove(obj)
                self.v[gi] = torch.cat([self.v[gi], value[group]], -1)

            # If there are remaining objects, add them as a new group
            if len(remaining_objects) > 0:
                new_group = list(remaining_objects)
                self.v.append(value[new_group])
                self.obj_groups.append(new_group)
                self.all_objects.extend(new_group)

            assert sorted(self.all_objects) == self.all_objects, 'Objects MUST be inserted in sorted order '
        else:
            # When objects is not given, v is a list that already has the object groups sorted
            # used in long-term memory
            assert isinstance(value, list)
            for gi, gv in enumerate(value):
                if gv is None:
                    continue
                if gi < self.num_groups:
                    self.v[gi] = torch.cat([self.v[gi], gv], -1)
                else:
                    self.v.append(gv)

    def update_usage(self, usage):
        # increase all life count by 1
        # increase use of indexed elements
        if not self.count_usage:
            return
        self.use_count += usage.view_as(self.use_count)
        self.life_count += 1

    def sieve_by_range(self, start: int, end: int, min_size: int):
        # keep only the elements *outside* of this range (with some boundary conditions)
        # i.e., concat (a[:start], a[end:])
        # min_size is only used for values, we do not sieve values under this size
        # (because they are not consolidated)
        if end == 0:
            # negative 0 would not work as the end index!
            self.k = self.k[:, :, :start]
            if self.count_usage:
                self.use_count = self.use_count[:, :, :start]
                self.life_count = self.life_count[:, :, :start]
            if self.s is not None:
                self.s = self.s[:, :, :start]
            if self.e is not None:
                self.e = self.e[:, :, :start]

            for gi in range(self.num_groups):
                if self.v[gi].shape[-1] >= min_size:
                    self.v[gi] = self.v[gi][:, :, :start]
        else:
            self.k = torch.cat([self.k[:, :, :start], self.k[:, :, end:]], -1)
            if self.count_usage:
                self.use_count = torch.cat([self.use_count[:, :, :start], self.use_count[:, :, end:]], -1)
                self.life_count = torch.cat([self.life_count[:, :, :start], self.life_count[:, :, end:]], -1)
            if self.s is not None:
                self.s = torch.cat([self.s[:, :, :start], self.s[:, :, end:]], -1)
            if self.e is not None:
                self.e = torch.cat([self.e[:, :, :start], self.e[:, :, end:]], -1)

            for gi in range(self.num_groups):
                if self.v[gi].shape[-1] >= min_size:
                    self.v[gi] = torch.cat([self.v[gi][:, :, :start], self.v[gi][:, :, end:]], -1)

    def remove_obsolete_features(self, max_size: int):
        """Drop the least-used elements so that at most max_size remain."""
        # normalize with life duration
        usage = self.get_usage().flatten()

        # smallest (size - max_size) usages are the eviction candidates;
        # only elements strictly above the largest evicted usage survive
        values, _ = torch.topk(usage, k=(self.size-max_size), largest=False, sorted=True)
        survived = (usage > values[-1])

        self.k = self.k[:, :, survived]
        self.s = self.s[:, :, survived] if self.s is not None else None
        # Long-term memory does not store ek so this should not be needed
        self.e = self.e[:, :, survived] if self.e is not None else None

        if self.num_groups > 1:
            raise NotImplementedError("""The current data structure does not support feature removal with
            multiple object groups (e.g., some objects start to appear later in the video)
            The indices for "survived" is based on keys but not all values are present for every key
            Basically we need to remap the indices for keys to values
            """)
        for gi in range(self.num_groups):
            self.v[gi] = self.v[gi][:, :, survived]

        self.use_count = self.use_count[:, :, survived]
        self.life_count = self.life_count[:, :, survived]

    def get_usage(self):
        # return normalized usage
        if not self.count_usage:
            raise RuntimeError('I did not count usage!')
        else:
            usage = self.use_count / self.life_count
            return usage

    def get_all_sliced(self, start: int, end: int):
        # return k, sk, ek, usage in order, sliced by start and end
        if end == 0:
            # negative 0 would not work as the end index!
            k = self.k[:, :, start:]
            sk = self.s[:, :, start:] if self.s is not None else None
            ek = self.e[:, :, start:] if self.e is not None else None
            usage = self.get_usage()[:, :, start:]
        else:
            k = self.k[:, :, start:end]
            sk = self.s[:, :, start:end] if self.s is not None else None
            ek = self.e[:, :, start:end] if self.e is not None else None
            usage = self.get_usage()[:, :, start:end]

        return k, sk, ek, usage

    def get_v_size(self, ni: int):
        return self.v[ni].shape[2]

    def engaged(self):
        return self.k is not None

    # NOTE: these must be properties — the class itself reads them as
    # attributes (self.size, self.num_groups); the decorators were missing.
    @property
    def size(self):
        if self.k is None:
            return 0
        else:
            return self.k.shape[-1]

    @property
    def num_groups(self):
        return len(self.v)

    @property
    def key(self):
        return self.k

    @property
    def value(self):
        return self.v

    @property
    def shrinkage(self):
        return self.s

    @property
    def selection(self):
        return self.e