Spaces:
Runtime error
Runtime error
import math | |
import numpy as np | |
import torch | |
from typing import Optional | |
def get_similarity(mk, ms, qk, qe):
    """Compute the scaled similarity between memory keys and query keys.

    Used for training/inference and memory reading/memory potentiation.

    Shapes (dimensions in [] are flattened inside this function):
        mk: B x CK x [N]     - memory keys
        ms: B x 1  x [N]     - memory shrinkage, or None
        qk: B x CK x [HW/P]  - query keys
        qe: B x CK x [HW/P]  - query selection, or None

    Returns:
        A B x N x HW/P tensor: the (optionally selection-weighted) negative
        squared-distance term, multiplied by the shrinkage when it is given
        and divided by sqrt(CK).
    """
    key_dim = mk.shape[1]

    memory_key = mk.flatten(start_dim=2)
    query_key = qk.flatten(start_dim=2)
    shrinkage = None if ms is None else ms.flatten(start_dim=1).unsqueeze(2)
    selection = None if qe is None else qe.flatten(start_dim=2)

    memory_key_t = memory_key.transpose(1, 2)  # B x N x CK

    if selection is None:
        # Without the selection term this is similar to STCN; the query-only
        # squared term is omitted in this branch.
        a_sq = memory_key.pow(2).sum(1).unsqueeze(2)
        two_ab = 2 * (memory_key_t @ query_key)
        similarity = -a_sq + two_ab
    else:
        # Expanded form of the selection-weighted squared distance
        # sum_c qe_c * (mk_c - qk_c)^2, negated.
        a_sq = memory_key_t.pow(2) @ selection
        two_ab = 2 * (memory_key_t @ (query_key * selection))
        b_sq = (selection * query_key.pow(2)).sum(1, keepdim=True)
        similarity = -a_sq + two_ab - b_sq

    if shrinkage is not None:
        similarity = similarity * shrinkage

    return similarity / math.sqrt(key_dim)  # B*N*HW
def do_softmax(similarity, top_k: Optional[int]=None, inplace=False, return_usage=False):
    """Normalize similarity along dim 1 with a (optionally top-k) softmax.

    Args:
        similarity: B x N x [HW/P] tensor of similarity scores.
        top_k: when given, only the top_k largest entries along dim 1 take
            part in the softmax; every other entry of the result is zero.
        inplace: only used with top_k. When True the result is written into
            `similarity` itself (its previous contents are destroyed) and that
            same tensor is returned. Use with care.
        return_usage: when True, also return the affinity summed over dim 2.

    Returns:
        The affinity tensor (same shape as `similarity`), or the tuple
        (affinity, affinity.sum(dim=2)) when return_usage is True.
    """
    if top_k is not None:
        # sorted=True (the default, made explicit here) puts the largest value
        # of each column first, so values[:, :1] is the per-column maximum.
        values, indices = torch.topk(similarity, k=top_k, dim=1, sorted=True)

        # Subtract the maximum before exponentiating, as the full-softmax
        # branch below does. Softmax is shift-invariant, so the result is
        # unchanged mathematically, but this avoids exp() overflowing to inf
        # (inf/inf -> NaN) or underflowing to zero everywhere (0/0 -> NaN).
        x_exp = (values - values[:, :1]).exp_()
        x_exp /= torch.sum(x_exp, dim=1, keepdim=True)

        if inplace:
            similarity.zero_().scatter_(1, indices, x_exp)  # B*N*HW
            affinity = similarity
        else:
            affinity = torch.zeros_like(similarity).scatter_(1, indices, x_exp)  # B*N*HW
    else:
        maxes = torch.max(similarity, dim=1, keepdim=True)[0]
        x_exp = torch.exp(similarity - maxes)
        x_exp_sum = torch.sum(x_exp, dim=1, keepdim=True)
        affinity = x_exp / x_exp_sum

    if return_usage:
        return affinity, affinity.sum(dim=2)

    return affinity
def get_affinity(mk, ms, qk, qe):
    """Shorthand used in training: similarity followed by a full (no top-k) softmax.

    Arguments are forwarded unchanged to get_similarity; the result of
    do_softmax with its default options is returned.
    """
    return do_softmax(get_similarity(mk, ms, qk, qe))
def readout(affinity, mv):
    """Read a value map out of memory by weighting memory values with the affinity.

    Args:
        affinity: B x (T*H*W) x (H*W) tensor; it is batch-multiplied with the
            flattened memory values and the result is viewed as B x CV x H x W.
        mv: B x CV x T x H x W memory values.

    Returns:
        B x CV x H x W tensor.
    """
    B, CV, T, H, W = mv.shape

    # reshape instead of view: identical for contiguous input, but also works
    # when mv is non-contiguous (e.g. a permuted or sliced memory bank), where
    # view would raise.
    mo = mv.reshape(B, CV, T*H*W)
    mem = torch.bmm(mo, affinity)
    mem = mem.view(B, CV, H, W)

    return mem