Spaces:
Running
on
L4
Running
on
L4
from collections import defaultdict
from typing import Dict, List, Literal, Optional, Tuple

import torch
def _add_last_dim(dictionary, key, new_value, prepend=False): | |
# append/prepend a new value to the last dimension of a tensor in a dictionary | |
# if the key does not exist, put the new value in | |
# append by default | |
if key in dictionary: | |
dictionary[key] = torch.cat([dictionary[key], new_value], -1) | |
else: | |
dictionary[key] = new_value | |
class KeyValueMemoryStore: | |
""" | |
Works for key/value pairs type storage | |
e.g., working and long-term memory | |
""" | |
def __init__(self, save_selection: bool = False, save_usage: bool = False): | |
""" | |
We store keys and values of objects that first appear in the same frame in a bucket. | |
Each bucket contains a set of object ids. | |
Each bucket is associated with a single key tensor | |
and a dictionary of value tensors indexed by object id. | |
The keys and values are stored as the concatenation of a permanent part and a temporary part. | |
""" | |
self.save_selection = save_selection | |
self.save_usage = save_usage | |
self.global_bucket_id = 0 # does not reduce even if buckets are removed | |
self.buckets: Dict[int, List[int]] = {} # indexed by bucket id | |
self.k: Dict[int, torch.Tensor] = {} # indexed by bucket id | |
self.v: Dict[int, torch.Tensor] = {} # indexed by object id | |
# indexed by bucket id; the end point of permanent memory | |
self.perm_end_pt: Dict[int, int] = defaultdict(int) | |
# shrinkage and selection are just like the keys | |
self.s = {} | |
if self.save_selection: | |
self.e = {} # does not contain the permanent memory part | |
# usage | |
if self.save_usage: | |
self.use_cnt = {} # indexed by bucket id, does not contain the permanent memory part | |
self.life_cnt = {} # indexed by bucket id, does not contain the permanent memory part | |
def add(self, | |
key: torch.Tensor, | |
values: Dict[int, torch.Tensor], | |
shrinkage: torch.Tensor, | |
selection: torch.Tensor, | |
supposed_bucket_id: int = -1, | |
as_permanent: Literal['no', 'first', 'all'] = 'no') -> None: | |
""" | |
key: (1/2)*C*N | |
values: dict of values ((1/2)*C*N), object ids are used as keys | |
shrinkage: (1/2)*1*N | |
selection: (1/2)*C*N | |
supposed_bucket_id: used to sync the bucket id between working and long-term memory | |
if provided, the input should all be in a single bucket indexed by this id | |
as_permanent: whether to store the input as permanent memory | |
'no': don't | |
'first': only store it as permanent memory if the bucket is empty | |
'all': always store it as permanent memory | |
""" | |
bs = key.shape[0] | |
ne = key.shape[-1] | |
assert len(key.shape) == 3 | |
assert len(shrinkage.shape) == 3 | |
assert not self.save_selection or len(selection.shape) == 3 | |
assert as_permanent in ['no', 'first', 'all'] | |
# add the value and create new buckets if necessary | |
if supposed_bucket_id >= 0: | |
enabled_buckets = [supposed_bucket_id] | |
bucket_exist = supposed_bucket_id in self.buckets | |
for obj, value in values.items(): | |
if bucket_exist: | |
assert obj in self.v | |
assert obj in self.buckets[supposed_bucket_id] | |
_add_last_dim(self.v, obj, value, prepend=(as_permanent == 'all')) | |
else: | |
assert obj not in self.v | |
self.v[obj] = value | |
self.buckets[supposed_bucket_id] = list(values.keys()) | |
else: | |
new_bucket_id = None | |
enabled_buckets = set() | |
for obj, value in values.items(): | |
assert len(value.shape) == 3 | |
if obj in self.v: | |
_add_last_dim(self.v, obj, value, prepend=(as_permanent == 'all')) | |
bucket_used = [ | |
bucket_id for bucket_id, object_ids in self.buckets.items() | |
if obj in object_ids | |
] | |
assert len(bucket_used) == 1 # each object should only be in one bucket | |
enabled_buckets.add(bucket_used[0]) | |
else: | |
self.v[obj] = value | |
if new_bucket_id is None: | |
# create new bucket | |
new_bucket_id = self.global_bucket_id | |
self.global_bucket_id += 1 | |
self.buckets[new_bucket_id] = [] | |
# put the new object into the corresponding bucket | |
self.buckets[new_bucket_id].append(obj) | |
enabled_buckets.add(new_bucket_id) | |
# increment the permanent size if necessary | |
add_as_permanent = {} # indexed by bucket id | |
for bucket_id in enabled_buckets: | |
add_as_permanent[bucket_id] = False | |
if as_permanent == 'all': | |
self.perm_end_pt[bucket_id] += ne | |
add_as_permanent[bucket_id] = True | |
elif as_permanent == 'first': | |
if self.perm_end_pt[bucket_id] == 0: | |
self.perm_end_pt[bucket_id] = ne | |
add_as_permanent[bucket_id] = True | |
# create new counters for usage if necessary | |
if self.save_usage and as_permanent != 'all': | |
new_count = torch.zeros((bs, ne), device=key.device, dtype=torch.float32) | |
new_life = torch.zeros((bs, ne), device=key.device, dtype=torch.float32) + 1e-7 | |
# add the key to every bucket | |
for bucket_id in self.buckets: | |
if bucket_id not in enabled_buckets: | |
# if we are not adding new values to a bucket, we should skip it | |
continue | |
_add_last_dim(self.k, bucket_id, key, prepend=add_as_permanent[bucket_id]) | |
_add_last_dim(self.s, bucket_id, shrinkage, prepend=add_as_permanent[bucket_id]) | |
if not add_as_permanent[bucket_id]: | |
if self.save_selection: | |
_add_last_dim(self.e, bucket_id, selection) | |
if self.save_usage: | |
_add_last_dim(self.use_cnt, bucket_id, new_count) | |
_add_last_dim(self.life_cnt, bucket_id, new_life) | |
def update_bucket_usage(self, bucket_id: int, usage: torch.Tensor) -> None: | |
# increase all life count by 1 | |
# increase use of indexed elements | |
if not self.save_usage: | |
return | |
usage = usage[:, self.perm_end_pt[bucket_id]:] | |
if usage.shape[-1] == 0: | |
# if there is no temporary memory, we don't need to update | |
return | |
self.use_cnt[bucket_id] += usage.view_as(self.use_cnt[bucket_id]) | |
self.life_cnt[bucket_id] += 1 | |
def sieve_by_range(self, bucket_id: int, start: int, end: int, min_size: int) -> None: | |
# keep only the temporary elements *outside* of this range (with some boundary conditions) | |
# the permanent elements are ignored in this computation | |
# i.e., concat (a[:start], a[end:]) | |
# bucket with size <= min_size are not modified | |
assert start >= 0 | |
assert end <= 0 | |
object_ids = self.buckets[bucket_id] | |
bucket_num_elements = self.k[bucket_id].shape[-1] - self.perm_end_pt[bucket_id] | |
if bucket_num_elements <= min_size: | |
return | |
if end == 0: | |
# negative 0 would not work as the end index! | |
# effectively make the second part an empty slice | |
end = self.k[bucket_id].shape[-1] + 1 | |
p_size = self.perm_end_pt[bucket_id] | |
start = start + p_size | |
k = self.k[bucket_id] | |
s = self.s[bucket_id] | |
if self.save_selection: | |
e = self.e[bucket_id] | |
if self.save_usage: | |
use_cnt = self.use_cnt[bucket_id] | |
life_cnt = self.life_cnt[bucket_id] | |
self.k[bucket_id] = torch.cat([k[:, :, :start], k[:, :, end:]], -1) | |
self.s[bucket_id] = torch.cat([s[:, :, :start], s[:, :, end:]], -1) | |
if self.save_selection: | |
self.e[bucket_id] = torch.cat([e[:, :, :start - p_size], e[:, :, end:]], -1) | |
if self.save_usage: | |
self.use_cnt[bucket_id] = torch.cat([use_cnt[:, :start - p_size], use_cnt[:, end:]], -1) | |
self.life_cnt[bucket_id] = torch.cat([life_cnt[:, :start - p_size], life_cnt[:, end:]], | |
-1) | |
for obj_id in object_ids: | |
v = self.v[obj_id] | |
self.v[obj_id] = torch.cat([v[:, :, :start], v[:, :, end:]], -1) | |
def remove_old_memory(self, bucket_id: int, max_len: int) -> None: | |
self.sieve_by_range(bucket_id, 0, -max_len, max_len) | |
def remove_obsolete_features(self, bucket_id: int, max_size: int) -> None: | |
# for long-term memory only | |
object_ids = self.buckets[bucket_id] | |
assert self.perm_end_pt[bucket_id] == 0 # permanent memory should be empty in LT memory | |
# normalize with life duration | |
usage = self.get_usage(bucket_id) | |
bs = usage.shape[0] | |
survivals = [] | |
for bi in range(bs): | |
_, survived = torch.topk(usage[bi], k=max_size) | |
survivals.append(survived.flatten()) | |
assert survived.shape[-1] == survivals[0].shape[-1] | |
self.k[bucket_id] = torch.stack( | |
[self.k[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0) | |
self.s[bucket_id] = torch.stack( | |
[self.s[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0) | |
if self.save_selection: | |
# Long-term memory does not store selection so this should not be needed | |
self.e[bucket_id] = torch.stack( | |
[self.e[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0) | |
for obj_id in object_ids: | |
self.v[obj_id] = torch.stack( | |
[self.v[obj_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0) | |
self.use_cnt[bucket_id] = torch.stack( | |
[self.use_cnt[bucket_id][bi, survived] for bi, survived in enumerate(survivals)], 0) | |
self.life_cnt[bucket_id] = torch.stack( | |
[self.life_cnt[bucket_id][bi, survived] for bi, survived in enumerate(survivals)], 0) | |
def get_usage(self, bucket_id: int) -> torch.Tensor: | |
# return normalized usage | |
if not self.save_usage: | |
raise RuntimeError('I did not count usage!') | |
else: | |
usage = self.use_cnt[bucket_id] / self.life_cnt[bucket_id] | |
return usage | |
def get_all_sliced( | |
self, bucket_id: int, start: int, end: int | |
) -> (torch.Tensor, torch.Tensor, torch.Tensor, Dict[int, torch.Tensor], torch.Tensor): | |
# return k, sk, ek, value, normalized usage in order, sliced by start and end | |
# this only queries the temporary memory | |
assert start >= 0 | |
assert end <= 0 | |
p_size = self.perm_end_pt[bucket_id] | |
start = start + p_size | |
if end == 0: | |
# negative 0 would not work as the end index! | |
k = self.k[bucket_id][:, :, start:] | |
sk = self.s[bucket_id][:, :, start:] | |
ek = self.e[bucket_id][:, :, start - p_size:] if self.save_selection else None | |
value = {obj_id: self.v[obj_id][:, :, start:] for obj_id in self.buckets[bucket_id]} | |
usage = self.get_usage(bucket_id)[:, start - p_size:] if self.save_usage else None | |
else: | |
k = self.k[bucket_id][:, :, start:end] | |
sk = self.s[bucket_id][:, :, start:end] | |
ek = self.e[bucket_id][:, :, start - p_size:end] if self.save_selection else None | |
value = {obj_id: self.v[obj_id][:, :, start:end] for obj_id in self.buckets[bucket_id]} | |
usage = self.get_usage(bucket_id)[:, start - p_size:end] if self.save_usage else None | |
return k, sk, ek, value, usage | |
def purge_except(self, obj_keep_idx: List[int]): | |
# purge certain objects from the memory except the one listed | |
obj_keep_idx = set(obj_keep_idx) | |
# remove objects that are not in the keep list from the buckets | |
buckets_to_remove = [] | |
for bucket_id, object_ids in self.buckets.items(): | |
self.buckets[bucket_id] = [obj_id for obj_id in object_ids if obj_id in obj_keep_idx] | |
if len(self.buckets[bucket_id]) == 0: | |
buckets_to_remove.append(bucket_id) | |
# remove object values that are not in the keep list | |
self.v = {k: v for k, v in self.v.items() if k in obj_keep_idx} | |
# remove buckets that are empty | |
for bucket_id in buckets_to_remove: | |
del self.buckets[bucket_id] | |
del self.k[bucket_id] | |
del self.s[bucket_id] | |
if self.save_selection: | |
del self.e[bucket_id] | |
if self.save_usage: | |
del self.use_cnt[bucket_id] | |
del self.life_cnt[bucket_id] | |
def clear_non_permanent_memory(self): | |
# clear all non-permanent memory | |
for bucket_id in self.buckets: | |
self.sieve_by_range(bucket_id, 0, 0, 0) | |
def get_v_size(self, obj_id: int) -> int: | |
return self.v[obj_id].shape[-1] | |
def size(self, bucket_id: int) -> int: | |
if bucket_id not in self.k: | |
return 0 | |
else: | |
return self.k[bucket_id].shape[-1] | |
def perm_size(self, bucket_id: int) -> int: | |
return self.perm_end_pt[bucket_id] | |
def non_perm_size(self, bucket_id: int) -> int: | |
return self.size(bucket_id) - self.perm_size(bucket_id) | |
def engaged(self, bucket_id: Optional[int] = None) -> bool: | |
if bucket_id is None: | |
return len(self.buckets) > 0 | |
else: | |
return bucket_id in self.buckets | |
def num_objects(self) -> int: | |
return len(self.v) | |
def key(self) -> Dict[int, torch.Tensor]: | |
return self.k | |
def value(self) -> Dict[int, torch.Tensor]: | |
return self.v | |
def shrinkage(self) -> Dict[int, torch.Tensor]: | |
return self.s | |
def selection(self) -> Dict[int, torch.Tensor]: | |
return self.e | |
def __contains__(self, key): | |
return key in self.v | |