Spaces:
Runtime error
Runtime error
File size: 8,306 Bytes
4d1ebf3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 |
import torch
from typing import List
class KeyValueMemoryStore:
    """
    Works for key/value pairs type storage
    e.g., working and long-term memory

    An object group is created when new objects enter the video
    Objects in the same group share the same temporal extent
    i.e., objects initialized in the same frame are in the same group
    For DAVIS/interactive, there is only one object group
    For YouTubeVOS, there can be multiple object groups

    Storage layout (all tensors grow along their LAST dimension, the
    memory-element axis; concatenation below is always on dim -1):
      k  -- single key tensor shared by all groups/objects
      v  -- list of value tensors, one entry per object group
      s  -- shrinkage tensor (optional, may stay None)
      e  -- selection tensor (optional, may stay None)
    """
    def __init__(self, count_usage: bool):
        # Whether to maintain per-element usage statistics; required for
        # usage-based eviction in remove_obsolete_features / get_usage.
        self.count_usage = count_usage
        # keys are stored in a single tensor and are shared between groups/objects
        # values are stored as a list indexed by object groups
        self.k = None
        self.v = []
        self.obj_groups = []
        # for debugging only
        self.all_objects = []
        # shrinkage and selection are also single tensors
        self.s = self.e = None
        # usage
        # NOTE: use_count/life_count are only ever defined when count_usage is
        # True; other methods guard on self.count_usage before touching them.
        if self.count_usage:
            self.use_count = self.life_count = None
    def add(self, key, value, shrinkage, selection, objects: List[int]):
        """
        Append new elements to the memory bank.

        key:       3-D tensor; dims 0 and 2 are read below (batch and element
                   axes), so the element axis is the last one. Presumably
                   (B, C_k, N) -- middle dim not checked here, confirm at caller.
        value:     a tensor when `objects` is given (working memory), indexed
                   by object id along dim 0; otherwise a list with one entry
                   per existing object group (long-term memory), entries may
                   be None to skip a group.
        shrinkage: tensor shaped like key along the shared dims, or None.
        selection: tensor shaped like key along the shared dims, or None.
        objects:   1-based object ids covered by `value` (background is id 0
                   and is NOT stored), or None for the list-valued form.
        """
        # Fresh usage counters for the new elements: zero uses, and a tiny
        # positive life so get_usage() never divides by zero.
        new_count = torch.zeros((key.shape[0], 1, key.shape[2]), device=key.device, dtype=torch.float32)
        new_life = torch.zeros((key.shape[0], 1, key.shape[2]), device=key.device, dtype=torch.float32) + 1e-7
        # add the key
        if self.k is None:
            # First insertion: adopt the tensors directly.
            self.k = key
            self.s = shrinkage
            self.e = selection
            if self.count_usage:
                self.use_count = new_count
                self.life_count = new_life
        else:
            # Subsequent insertions: grow along the element axis.
            self.k = torch.cat([self.k, key], -1)
            if shrinkage is not None:
                self.s = torch.cat([self.s, shrinkage], -1)
            if selection is not None:
                self.e = torch.cat([self.e, selection], -1)
            if self.count_usage:
                self.use_count = torch.cat([self.use_count, new_count], -1)
                self.life_count = torch.cat([self.life_count, new_life], -1)
        # add the value
        if objects is not None:
            # When objects is given, v is a tensor; used in working memory
            assert isinstance(value, torch.Tensor)
            # First consume objects that are already in the memory bank
            # cannot use set here because we need to preserve order
            # shift by one as background is not part of value
            remaining_objects = [obj-1 for obj in objects]
            for gi, group in enumerate(self.obj_groups):
                for obj in group:
                    # should properly raise an error if there are overlaps in obj_groups
                    # (list.remove raises ValueError when obj is absent)
                    remaining_objects.remove(obj)
                # value[group] picks this group's objects along dim 0
                self.v[gi] = torch.cat([self.v[gi], value[group]], -1)
            # If there are remaining objects, add them as a new group
            if len(remaining_objects) > 0:
                new_group = list(remaining_objects)
                self.v.append(value[new_group])
                self.obj_groups.append(new_group)
                self.all_objects.extend(new_group)
            assert sorted(self.all_objects) == self.all_objects, 'Objects MUST be inserted in sorted order '
        else:
            # When objects is not given, v is a list that already has the object groups sorted
            # used in long-term memory
            assert isinstance(value, list)
            for gi, gv in enumerate(value):
                if gv is None:
                    continue
                if gi < self.num_groups:
                    # Existing group: extend along the element axis.
                    self.v[gi] = torch.cat([self.v[gi], gv], -1)
                else:
                    # Brand-new group (first time this group reaches long-term
                    # memory); NOTE obj_groups is not updated on this path.
                    self.v.append(gv)
    def update_usage(self, usage) -> None:
        """Accumulate one step of usage: `usage` adds to use_count (reshaped
        to match), and every element ages by one life tick. No-op unless
        usage counting was enabled at construction."""
        # increase all life count by 1
        # increase use of indexed elements
        if not self.count_usage:
            return
        self.use_count += usage.view_as(self.use_count)
        self.life_count += 1
    def sieve_by_range(self, start: int, end: int, min_size: int):
        """Drop elements inside [start, end) of the element axis, keeping
        everything outside. Values in groups smaller than min_size are left
        untouched (they have not been consolidated yet)."""
        # keep only the elements *outside* of this range (with some boundary conditions)
        # i.e., concat (a[:start], a[end:])
        # min_size is only used for values, we do not sieve values under this size
        # (because they are not consolidated)
        # NOTE: `end` is expected to be 0 or negative here (an offset from the
        # tail) -- presumably callers pass it that way; confirm at call sites.
        if end == 0:
            # negative 0 would not work as the end index!
            self.k = self.k[:,:,:start]
            if self.count_usage:
                self.use_count = self.use_count[:,:,:start]
                self.life_count = self.life_count[:,:,:start]
            if self.s is not None:
                self.s = self.s[:,:,:start]
            if self.e is not None:
                self.e = self.e[:,:,:start]
            for gi in range(self.num_groups):
                if self.v[gi].shape[-1] >= min_size:
                    self.v[gi] = self.v[gi][:,:,:start]
        else:
            self.k = torch.cat([self.k[:,:,:start], self.k[:,:,end:]], -1)
            if self.count_usage:
                self.use_count = torch.cat([self.use_count[:,:,:start], self.use_count[:,:,end:]], -1)
                self.life_count = torch.cat([self.life_count[:,:,:start], self.life_count[:,:,end:]], -1)
            if self.s is not None:
                self.s = torch.cat([self.s[:,:,:start], self.s[:,:,end:]], -1)
            if self.e is not None:
                self.e = torch.cat([self.e[:,:,:start], self.e[:,:,end:]], -1)
            for gi in range(self.num_groups):
                if self.v[gi].shape[-1] >= min_size:
                    self.v[gi] = torch.cat([self.v[gi][:,:,:start], self.v[gi][:,:,end:]], -1)
    def remove_obsolete_features(self, max_size: int):
        """Evict the least-used elements so roughly max_size remain.

        Eviction threshold is the (size - max_size)-th smallest normalized
        usage; the strict `>` below means elements TIED with that threshold
        are also dropped, so the result can be smaller than max_size.
        Requires count_usage=True and self.size > max_size (topk needs k > 0
        -- presumably guaranteed by the caller; confirm).
        Only supports a single object group (see the raise below).
        """
        # normalize with life duration
        usage = self.get_usage().flatten()
        values, _ = torch.topk(usage, k=(self.size-max_size), largest=False, sorted=True)
        survived = (usage > values[-1])
        self.k = self.k[:, :, survived]
        self.s = self.s[:, :, survived] if self.s is not None else None
        # Long-term memory does not store ek so this should not be needed
        self.e = self.e[:, :, survived] if self.e is not None else None
        if self.num_groups > 1:
            raise NotImplementedError("""The current data structure does not support feature removal with
            multiple object groups (e.g., some objects start to appear later in the video)
            The indices for "survived" is based on keys but not all values are present for every key
            Basically we need to remap the indices for keys to values
            """)
        for gi in range(self.num_groups):
            self.v[gi] = self.v[gi][:, :, survived]
        self.use_count = self.use_count[:, :, survived]
        self.life_count = self.life_count[:, :, survived]
    def get_usage(self):
        """Return usage normalized by life duration (use_count / life_count).
        The 1e-7 added to new life counts in add() keeps this finite.
        Raises RuntimeError when usage counting is disabled."""
        # return normalized usage
        if not self.count_usage:
            raise RuntimeError('I did not count usage!')
        else:
            usage = self.use_count / self.life_count
            return usage
    def get_all_sliced(self, start: int, end: int):
        """Return (k, sk, ek, usage) sliced to [start, end) on the element
        axis; end == 0 means "to the end" (see note below). sk/ek are None
        when shrinkage/selection were never stored. Raises via get_usage()
        if usage counting is disabled."""
        # return k, sk, ek, usage in order, sliced by start and end
        if end == 0:
            # negative 0 would not work as the end index!
            k = self.k[:,:,start:]
            sk = self.s[:,:,start:] if self.s is not None else None
            ek = self.e[:,:,start:] if self.e is not None else None
            usage = self.get_usage()[:,:,start:]
        else:
            k = self.k[:,:,start:end]
            sk = self.s[:,:,start:end] if self.s is not None else None
            ek = self.e[:,:,start:end] if self.e is not None else None
            usage = self.get_usage()[:,:,start:end]
        return k, sk, ek, usage
    def get_v_size(self, ni: int) -> int:
        # Element-axis length of group ni's value tensor (dim 2 == dim -1 here).
        return self.v[ni].shape[2]
    def engaged(self) -> bool:
        # True once at least one add() has stored a key.
        return self.k is not None
    @property
    def size(self) -> int:
        # Number of stored memory elements (length of the key's last dim).
        if self.k is None:
            return 0
        else:
            return self.k.shape[-1]
    @property
    def num_groups(self) -> int:
        # Number of object groups currently holding values.
        return len(self.v)
    @property
    def key(self):
        # Shared key tensor (None before the first add).
        return self.k
    @property
    def value(self):
        # List of per-group value tensors.
        return self.v
    @property
    def shrinkage(self):
        # Shrinkage tensor, or None if never provided.
        return self.s
    @property
    def selection(self):
        # Selection tensor, or None if never provided.
        return self.e
|