import torch
import warnings

from inference.kv_memory_store import KeyValueMemoryStore
from model.memory_util import *


class MemoryManager:
    """
    Manages all three memory stores and the transition between working/long-term memory
    """

    def __init__(self, config):
        self.hidden_dim = config["hidden_dim"]
        self.top_k = config["top_k"]

        self.enable_long_term = config["enable_long_term"]
        self.enable_long_term_usage = config["enable_long_term_count_usage"]
        if self.enable_long_term:
            # mid-term frame budget and long-term consolidation parameters
            self.max_mt_frames = config["max_mid_term_frames"]
            self.min_mt_frames = config["min_mid_term_frames"]
            self.num_prototypes = config["num_prototypes"]
            self.max_long_elements = config["max_long_term_elements"]

        # dimensions will be inferred from input later
        self.CK = self.CV = None
        self.H = self.W = None

        # The hidden state will be stored in a single tensor for all objects
        # B x num_objects x CH x H x W
        self.hidden = None

        self.work_mem = KeyValueMemoryStore(count_usage=self.enable_long_term)
        if self.enable_long_term:
            self.long_mem = KeyValueMemoryStore(count_usage=self.enable_long_term_usage)

        # forces add_memory to re-derive the element budgets on the next call
        self.reset_config = True

    def update_config(self, config):
        """Update the tunable parameters mid-run; the long-term flags cannot change."""
        self.reset_config = True
        self.hidden_dim = config["hidden_dim"]
        self.top_k = config["top_k"]

        assert self.enable_long_term == config["enable_long_term"], "cannot update this"
        assert (
            self.enable_long_term_usage == config["enable_long_term_count_usage"]
        ), "cannot update this"

        self.enable_long_term_usage = config["enable_long_term_count_usage"]
        if self.enable_long_term:
            self.max_mt_frames = config["max_mid_term_frames"]
            self.min_mt_frames = config["min_mid_term_frames"]
            self.num_prototypes = config["num_prototypes"]
            self.max_long_elements = config["max_long_term_elements"]

    def _readout(self, affinity, v):
        # this function is for a single object group
        # v: (num_objects) x CV x N, affinity: 1 x N x HW -> readout (num_objects) x CV x HW
        return v @ affinity

    def match_memory(self, query_key, selection):
        """
        Read memory with a query key.

        query_key: B x C^k x H x W
        selection: B x C^k x H x W (or None when key selection is unused)
        Returns the readout feature, shape (total num_objects) x CV x H x W.
        """
        num_groups = self.work_mem.num_groups
        h, w = query_key.shape[-2:]

        query_key = query_key.flatten(start_dim=2)
        selection = selection.flatten(start_dim=2) if selection is not None else None

        """
        Memory readout using keys
        """
        if self.enable_long_term and self.long_mem.engaged():
            # Use long-term memory: concatenate long-term keys before working keys
            # so column index < long_mem_size refers to long-term entries.
            long_mem_size = self.long_mem.size
            memory_key = torch.cat([self.long_mem.key, self.work_mem.key], -1)
            shrinkage = torch.cat(
                [self.long_mem.shrinkage, self.work_mem.shrinkage], -1
            )

            similarity = get_similarity(memory_key, shrinkage, query_key, selection)
            work_mem_similarity = similarity[:, long_mem_size:]
            long_mem_similarity = similarity[:, :long_mem_size]

            # get the usage with the first group
            # the first group always has all the keys valid
            affinity, usage = do_softmax(
                torch.cat(
                    [
                        long_mem_similarity[:, -self.long_mem.get_v_size(0) :],
                        work_mem_similarity,
                    ],
                    1,
                ),
                top_k=self.top_k,
                inplace=True,
                return_usage=True,
            )
            affinity = [affinity]

            # compute affinity group by group as later groups only have a subset of keys
            for gi in range(1, num_groups):
                if gi < self.long_mem.num_groups:
                    # merge working and lt similarities before softmax
                    affinity_one_group = do_softmax(
                        torch.cat(
                            [
                                long_mem_similarity[:, -self.long_mem.get_v_size(gi) :],
                                work_mem_similarity[:, -self.work_mem.get_v_size(gi) :],
                            ],
                            1,
                        ),
                        top_k=self.top_k,
                        inplace=True,
                    )
                else:
                    # no long-term memory for this group
                    affinity_one_group = do_softmax(
                        work_mem_similarity[:, -self.work_mem.get_v_size(gi) :],
                        top_k=self.top_k,
                        inplace=(gi == num_groups - 1),
                    )
                affinity.append(affinity_one_group)

            all_memory_value = []
            for gi, gv in enumerate(self.work_mem.value):
                # merge the working and lt values before readout
                if gi < self.long_mem.num_groups:
                    all_memory_value.append(
                        torch.cat(
                            [self.long_mem.value[gi], self.work_mem.value[gi]], -1
                        )
                    )
                else:
                    all_memory_value.append(gv)

            """
            Record memory usage for working and long-term memory
            """
            # ignore the index return for long-term memory
            work_usage = usage[:, long_mem_size:]
            self.work_mem.update_usage(work_usage.flatten())

            if self.enable_long_term_usage:
                # ignore the index return for working memory
                long_usage = usage[:, :long_mem_size]
                self.long_mem.update_usage(long_usage.flatten())
        else:
            # No long-term memory
            similarity = get_similarity(
                self.work_mem.key, self.work_mem.shrinkage, query_key, selection
            )

            if self.enable_long_term:
                affinity, usage = do_softmax(
                    similarity,
                    inplace=(num_groups == 1),
                    top_k=self.top_k,
                    return_usage=True,
                )

                # Record memory usage for working memory
                self.work_mem.update_usage(usage.flatten())
            else:
                affinity = do_softmax(
                    similarity,
                    inplace=(num_groups == 1),
                    top_k=self.top_k,
                    return_usage=False,
                )

            affinity = [affinity]

            # compute affinity group by group as later groups only have a subset of keys
            for gi in range(1, num_groups):
                affinity_one_group = do_softmax(
                    similarity[:, -self.work_mem.get_v_size(gi) :],
                    top_k=self.top_k,
                    inplace=(gi == num_groups - 1),
                )
                affinity.append(affinity_one_group)

            all_memory_value = self.work_mem.value

        # Shared affinity within each group
        all_readout_mem = torch.cat(
            [self._readout(affinity[gi], gv) for gi, gv in enumerate(all_memory_value)],
            0,
        )

        return all_readout_mem.view(all_readout_mem.shape[0], self.CV, h, w)

    def add_memory(self, key, shrinkage, value, objects, selection=None):
        """
        Insert one frame of features into working memory, then (in long-term mode)
        consolidate working memory into long-term memory if it is over budget.
        """
        # key: 1*C*H*W
        # value: 1*num_objects*C*H*W
        # objects contain a list of object indices
        if self.H is None or self.reset_config:
            self.reset_config = False
            self.H, self.W = key.shape[-2:]
            self.HW = self.H * self.W
            if self.enable_long_term:
                # convert from num. frames to num. nodes
                self.min_work_elements = self.min_mt_frames * self.HW
                self.max_work_elements = self.max_mt_frames * self.HW

        # key: 1*C*N
        # value: num_objects*C*N
        key = key.flatten(start_dim=2)
        shrinkage = shrinkage.flatten(start_dim=2)
        value = value[0].flatten(start_dim=2)

        self.CK = key.shape[1]
        self.CV = value.shape[1]

        if selection is not None:
            if not self.enable_long_term:
                warnings.warn(
                    "the selection factor is only needed in long-term mode", UserWarning
                )
            selection = selection.flatten(start_dim=2)

        self.work_mem.add(key, value, shrinkage, selection, objects)

        # long-term memory cleanup
        if self.enable_long_term:
            # Do memory compression if needed
            if self.work_mem.size >= self.max_work_elements:
                # Remove obsolete features if needed
                if self.long_mem.size >= (self.max_long_elements - self.num_prototypes):
                    self.long_mem.remove_obsolete_features(
                        self.max_long_elements - self.num_prototypes
                    )

                self.compress_features()

    def create_hidden_state(self, n, sample_key):
        """Allocate (or grow) the per-object hidden state to cover n objects."""
        # n is the TOTAL number of objects
        h, w = sample_key.shape[-2:]
        if self.hidden is None:
            self.hidden = torch.zeros(
                (1, n, self.hidden_dim, h, w), device=sample_key.device
            )
        elif self.hidden.shape[1] != n:
            # new objects appeared: pad with zero-initialized states
            self.hidden = torch.cat(
                [
                    self.hidden,
                    torch.zeros(
                        (1, n - self.hidden.shape[1], self.hidden_dim, h, w),
                        device=sample_key.device,
                    ),
                ],
                1,
            )

        assert self.hidden.shape[1] == n

    def set_hidden(self, hidden):
        self.hidden = hidden

    def get_hidden(self):
        return self.hidden

    def compress_features(self):
        """Move the middle slice of working memory into long-term prototypes."""
        HW = self.HW
        candidate_value = []
        total_work_mem_size = self.work_mem.size
        for gv in self.work_mem.value:
            # Some object groups might be added later in the video
            # So not all keys have values associated with all objects
            # We need to keep track of the key->value validity
            mem_size_in_this_group = gv.shape[-1]
            if mem_size_in_this_group == total_work_mem_size:
                # full LT
                candidate_value.append(gv[:, :, HW : -self.min_work_elements + HW])
            else:
                # mem_size is smaller than total_work_mem_size, but at least HW
                assert HW <= mem_size_in_this_group < total_work_mem_size
                if mem_size_in_this_group > self.min_work_elements + HW:
                    # part of this object group still goes into LT
                    candidate_value.append(gv[:, :, HW : -self.min_work_elements + HW])
                else:
                    # this object group cannot go to the LT at all
                    candidate_value.append(None)

        # perform memory consolidation
        prototype_key, prototype_value, prototype_shrinkage = self.consolidation(
            *self.work_mem.get_all_sliced(HW, -self.min_work_elements + HW),
            candidate_value
        )

        # remove consolidated working memory
        self.work_mem.sieve_by_range(
            HW, -self.min_work_elements + HW, min_size=self.min_work_elements + HW
        )

        # add to long-term memory
        self.long_mem.add(
            prototype_key,
            prototype_value,
            prototype_shrinkage,
            selection=None,
            objects=None,
        )

    def consolidation(
        self,
        candidate_key,
        candidate_shrinkage,
        candidate_selection,
        usage,
        candidate_value,
    ):
        """
        Summarize candidate memory entries into num_prototypes prototypes,
        selected by highest usage; returns (key, per-group values, shrinkage).
        """
        # keys: 1*C*N
        # values: num_objects*C*N
        N = candidate_key.shape[-1]

        # find the indices with max usage
        _, max_usage_indices = torch.topk(
            usage, k=self.num_prototypes, dim=-1, sorted=True
        )
        prototype_indices = max_usage_indices.flatten()

        # Prototypes are invalid for out-of-bound groups
        validity = [
            prototype_indices >= (N - gv.shape[2]) if gv is not None else None
            for gv in candidate_value
        ]

        prototype_key = candidate_key[:, :, prototype_indices]
        prototype_selection = (
            candidate_selection[:, :, prototype_indices]
            if candidate_selection is not None
            else None
        )

        """
        Potentiation step
        """
        similarity = get_similarity(
            candidate_key, candidate_shrinkage, prototype_key, prototype_selection
        )

        # convert similarity to affinity
        # need to do it group by group since the softmax normalization would be different
        affinity = [
            do_softmax(similarity[:, -gv.shape[2] :, validity[gi]])
            if gv is not None
            else None
            for gi, gv in enumerate(candidate_value)
        ]

        # some values can have all-False validity. Weed them out.
        affinity = [
            aff if aff is None or aff.shape[-1] > 0 else None for aff in affinity
        ]

        # readout the values
        prototype_value = [
            self._readout(affinity[gi], gv) if affinity[gi] is not None else None
            for gi, gv in enumerate(candidate_value)
        ]

        # readout the shrinkage term
        prototype_shrinkage = (
            self._readout(affinity[0], candidate_shrinkage)
            if candidate_shrinkage is not None
            else None
        )

        return prototype_key, prototype_value, prototype_shrinkage