import math
import numpy as np
import torch
from typing import Optional


def get_similarity(mk, ms, qk, qe):
    """Compute the scaled, negative (selection-weighted) squared-L2 similarity.

    Used for training/inference and memory reading/memory potentiation.

    Args:
        mk: B x CK x [N]    - Memory keys
        ms: B x  1 x [N]    - Memory shrinkage, or None
        qk: B x CK x [HW/P] - Query keys
        qe: B x CK x [HW/P] - Query selection, or None
        Dimensions in [] are flattened here.

    Returns:
        B x N x [HW/P] similarity. It is multiplied by ``ms`` (per memory
        element) when given, and always divided by sqrt(CK).
    """
    CK = mk.shape[1]
    mk = mk.flatten(start_dim=2)
    # B x N x 1 so it broadcasts over the query dimension
    ms = ms.flatten(start_dim=1).unsqueeze(2) if ms is not None else None
    qk = qk.flatten(start_dim=2)
    qe = qe.flatten(start_dim=2) if qe is not None else None

    if qe is not None:
        # Selection-weighted negative squared distance, expanded so that it can
        # be computed with matrix products instead of materializing (mk - qk):
        #   -sum_c qe_c * (mk_c - qk_c)^2
        #     = -sum_c qe_c * mk_c^2 + 2 * sum_c mk_c * (qk_c * qe_c)
        #       - sum_c qe_c * qk_c^2
        mk = mk.transpose(1, 2)  # B x N x CK
        a_sq = mk.pow(2) @ qe  # B x N x HW
        two_ab = 2 * (mk @ (qk * qe))  # B x N x HW
        b_sq = (qe * qk.pow(2)).sum(1, keepdim=True)  # B x 1 x HW
        similarity = -a_sq + two_ab - b_sq
    else:
        # Similar to STCN when there is no selection term: -||mk||^2 + 2 mk.qk.
        # NOTE(review): the query-only ||qk||^2 term is not included here; it is
        # constant along N only when ms is None.
        a_sq = mk.pow(2).sum(1).unsqueeze(2)  # B x N x 1
        two_ab = 2 * (mk.transpose(1, 2) @ qk)  # B x N x HW
        similarity = -a_sq + two_ab

    if ms is not None:
        similarity = similarity * ms
    return similarity / math.sqrt(CK)  # B*N*HW


def do_softmax(
    similarity, top_k: Optional[int] = None, inplace=False, return_usage=False
):
    """Normalize a similarity tensor with an (optionally top-k) softmax over dim 1.

    Args:
        similarity: B x N x [HW/P] similarity.
        top_k: if given, for every (batch, column) keep only the ``top_k``
            largest entries along dim 1 and softmax over them; every other
            entry becomes 0. Values larger than N are clamped to N, which is
            equivalent to a full softmax.
        inplace: top-k path only. If True, ``similarity`` is zeroed and
            overwritten with the result, and that same tensor is returned.
            Use with care.
        return_usage: if True, also return ``affinity.sum(dim=2)`` (B x N).

    Returns:
        ``affinity`` with the same shape as ``similarity``, or
        ``(affinity, usage)`` when ``return_usage`` is True.
    """
    if top_k is not None:
        # torch.topk raises when k exceeds the size of the dimension
        top_k = min(top_k, similarity.shape[1])
        values, indices = torch.topk(similarity, k=top_k, dim=1)
        # softmax shifts by the max before exponentiating, so all-very-negative
        # (or very large) similarities cannot underflow/overflow into 0/0 = NaN.
        x_exp = values.softmax(dim=1)
        if inplace:
            similarity.zero_().scatter_(1, indices, x_exp)  # B*N*HW
            affinity = similarity
        else:
            affinity = torch.zeros_like(similarity).scatter_(
                1, indices, x_exp
            )  # B*N*HW
    else:
        affinity = similarity.softmax(dim=1)

    if return_usage:
        return affinity, affinity.sum(dim=2)
    return affinity


def get_affinity(mk, ms, qk, qe):
    """Shorthand used in training with no top-k: full softmax of get_similarity."""
    similarity = get_similarity(mk, ms, qk, qe)
    affinity = do_softmax(similarity)
    return affinity


def readout(affinity, mv):
    """Aggregate memory values weighted by an affinity matrix.

    Args:
        affinity: B x (T*H*W) x (H*W) affinity; the view below requires the
            last dimension to equal H*W of ``mv``.
        mv: B x CV x T x H x W memory values (need not be contiguous).

    Returns:
        B x CV x H x W readout.
    """
    B, CV, T, H, W = mv.shape
    # reshape (unlike view) also accepts a non-contiguous mv
    mo = mv.reshape(B, CV, T * H * W)
    mem = torch.bmm(mo, affinity)  # B x CV x HW, freshly allocated/contiguous
    return mem.view(B, CV, H, W)