# Web file-viewer residue: Spaces status "Runtime error"; file size 2,586 bytes; revision 4d1ebf3.
import math
import numpy as np
import torch
from typing import Optional
def get_similarity(mk, ms, qk, qe):
    """Compute the scaled similarity between memory keys and query keys.

    Used for training/inference and for memory reading/memory potentiation.
    Dimensions written in [] are flattened into a single axis.

    Args:
        mk: B x CK x [N]    - Memory keys
        ms: B x  1 x [N]    - Memory shrinkage, or None to skip the scaling
        qk: B x CK x [HW/P] - Query keys
        qe: B x CK x [HW/P] - Query selection, or None to drop the term

    Returns:
        B x N x [HW/P] similarity, multiplied by the shrinkage (when given)
        and divided by sqrt(CK).
    """
    num_channels = mk.shape[1]
    mem_keys = mk.flatten(start_dim=2)
    query_keys = qk.flatten(start_dim=2)

    if qe is None:
        # similar to STCN if we don't have the selection term:
        # only -|mk|^2 + 2 * mk.qk is computed, the |qk|^2 term is left out
        mem_sq = mem_keys.pow(2).sum(1).unsqueeze(2)
        cross = 2 * (mem_keys.transpose(1, 2) @ query_keys)
        similarity = -mem_sq + cross
    else:
        # Expanded form of -sum_c qe * (mk - qk)^2 (see appendix for derivation)
        selection = qe.flatten(start_dim=2)
        mem_keys_t = mem_keys.transpose(1, 2)
        mem_sq = mem_keys_t.pow(2) @ selection
        cross = 2 * (mem_keys_t @ (query_keys * selection))
        query_sq = (selection * query_keys.pow(2)).sum(1, keepdim=True)
        similarity = -mem_sq + cross - query_sq

    if ms is not None:
        # B x N x 1 so the shrinkage broadcasts over the query axis
        shrinkage = ms.flatten(start_dim=1).unsqueeze(2)
        similarity = similarity * shrinkage

    return similarity / math.sqrt(num_channels)  # B*N*HW
def do_softmax(similarity, top_k: Optional[int]=None, inplace=False, return_usage=False):
    """Normalize similarity along dim 1 with an (optionally top-k) softmax.

    Args:
        similarity: B x N x [HW/P] similarity scores.
        top_k: when given, only the top_k largest entries along dim 1 are
            kept for each query position and every other entry of the
            result is zero. Values larger than N are clamped to N.
        inplace: only consulted when top_k is given. If True, `similarity`
            itself is overwritten with the result and returned - use with care.
        return_usage: if True, additionally return the affinity summed
            over dim 2.

    Returns:
        The affinity (same shape as `similarity`), or the tuple
        (affinity, usage) when return_usage is True.
    """
    if top_k is not None:
        # Asking for more entries than exist degenerates to a full softmax
        k = min(top_k, similarity.shape[1])
        values, indices = torch.topk(similarity, k=k, dim=1)

        # topk returns values sorted in descending order, so index 0 holds the
        # per-column maximum. Shifting by it leaves the softmax unchanged but
        # keeps exp() from overflowing to inf (which would yield NaN).
        # The slice is cloned because it aliases the tensor being written.
        values.sub_(values[:, :1].clone())
        x_exp = values.exp_()
        x_exp /= torch.sum(x_exp, dim=1, keepdim=True)
        if inplace:
            similarity.zero_().scatter_(1, indices, x_exp)  # B*N*HW
            affinity = similarity
        else:
            affinity = torch.zeros_like(similarity).scatter_(1, indices, x_exp)  # B*N*HW
    else:
        maxes = torch.max(similarity, dim=1, keepdim=True)[0]
        x_exp = torch.exp(similarity - maxes)
        x_exp_sum = torch.sum(x_exp, dim=1, keepdim=True)
        affinity = x_exp / x_exp_sum

    if return_usage:
        return affinity, affinity.sum(dim=2)

    return affinity
def get_affinity(mk, ms, qk, qe):
    """Shorthand used in training: similarity followed by a softmax with no top-k.

    Arguments are forwarded unchanged to get_similarity; the result of
    do_softmax with its default options is returned.
    """
    return do_softmax(get_similarity(mk, ms, qk, qe))
def readout(affinity, mv):
    """Read out memory values weighted by the affinity.

    Args:
        affinity: B x [T*H*W] x [H*W] weights; batch-multiplied with the
            flattened memory values.
        mv: B x CV x T x H x W memory values. May be non-contiguous.

    Returns:
        B x CV x H x W tensor holding the weighted memory readout.
    """
    B, CV, T, H, W = mv.shape
    # reshape instead of view: view raises on non-contiguous input (e.g. after
    # permute or slicing), reshape only copies when it actually has to
    flat_values = mv.reshape(B, CV, T * H * W)
    mem = torch.bmm(flat_values, affinity)
    return mem.view(B, CV, H, W)
|