SivaResearch's picture
demo
b6d5990
raw
history blame
5.71 kB
import torch
import torch.nn as nn
from models import image
import torch.nn.functional as F
# loss function
def KL(alpha, c):
if torch.cuda.is_available():
beta = torch.ones((1, c)).cuda()
else:
beta = torch.ones((1, c))
S_alpha = torch.sum(alpha, dim=1, keepdim=True)
S_beta = torch.sum(beta, dim=1, keepdim=True)
lnB = torch.lgamma(S_alpha) - torch.sum(torch.lgamma(alpha), dim=1, keepdim=True)
lnB_uni = torch.sum(torch.lgamma(beta), dim=1, keepdim=True) - torch.lgamma(S_beta)
dg0 = torch.digamma(S_alpha)
dg1 = torch.digamma(alpha)
kl = torch.sum((alpha - beta) * (dg1 - dg0), dim=1, keepdim=True) + lnB + lnB_uni
return kl
def ce_loss(p, alpha, c, global_step, annealing_step):
S = torch.sum(alpha, dim=1, keepdim=True)
E = alpha - 1
label = p
A = torch.sum(label * (torch.digamma(S) - torch.digamma(alpha)), dim=1, keepdim=True)
annealing_coef = min(1, global_step / annealing_step)
alp = E * (1 - label) + 1
B = annealing_coef * KL(alp, c)
return torch.mean((A + B))
class TMC(nn.Module):
def __init__(self, args):
super(TMC, self).__init__()
self.args = args
self.rgbenc = image.ImageEncoder(args)
self.specenc = image.RawNet(args)
spec_last_size = args.img_hidden_sz * 1
rgb_last_size = args.img_hidden_sz * args.num_image_embeds
self.spec_depth = nn.ModuleList()
self.clf_rgb = nn.ModuleList()
for hidden in args.hidden:
self.spec_depth.append(nn.Linear(spec_last_size, hidden))
self.spec_depth.append(nn.ReLU())
self.spec_depth.append(nn.Dropout(args.dropout))
spec_last_size = hidden
self.spec_depth.append(nn.Linear(spec_last_size, args.n_classes))
for hidden in args.hidden:
self.clf_rgb.append(nn.Linear(rgb_last_size, hidden))
self.clf_rgb.append(nn.ReLU())
self.clf_rgb.append(nn.Dropout(args.dropout))
rgb_last_size = hidden
self.clf_rgb.append(nn.Linear(rgb_last_size, args.n_classes))
def DS_Combin_two(self, alpha1, alpha2):
# Calculate the merger of two DS evidences
alpha = dict()
alpha[0], alpha[1] = alpha1, alpha2
b, S, E, u = dict(), dict(), dict(), dict()
for v in range(2):
S[v] = torch.sum(alpha[v], dim=1, keepdim=True)
E[v] = alpha[v] - 1
b[v] = E[v] / (S[v].expand(E[v].shape))
u[v] = self.args.n_classes / S[v]
# b^0 @ b^(0+1)
bb = torch.bmm(b[0].view(-1, self.args.n_classes, 1), b[1].view(-1, 1, self.args.n_classes))
# b^0 * u^1
uv1_expand = u[1].expand(b[0].shape)
bu = torch.mul(b[0], uv1_expand)
# b^1 * u^0
uv_expand = u[0].expand(b[0].shape)
ub = torch.mul(b[1], uv_expand)
# calculate K
bb_sum = torch.sum(bb, dim=(1, 2), out=None)
bb_diag = torch.diagonal(bb, dim1=-2, dim2=-1).sum(-1)
# bb_diag1 = torch.diag(torch.mm(b[v], torch.transpose(b[v+1], 0, 1)))
K = bb_sum - bb_diag
# calculate b^a
b_a = (torch.mul(b[0], b[1]) + bu + ub) / ((1 - K).view(-1, 1).expand(b[0].shape))
# calculate u^a
u_a = torch.mul(u[0], u[1]) / ((1 - K).view(-1, 1).expand(u[0].shape))
# test = torch.sum(b_a, dim = 1, keepdim = True) + u_a #Verify programming errors
# calculate new S
S_a = self.args.n_classes / u_a
# calculate new e_k
e_a = torch.mul(b_a, S_a.expand(b_a.shape))
alpha_a = e_a + 1
return alpha_a
def forward(self, rgb, spec):
spec = self.specenc(spec)
spec = torch.flatten(spec, start_dim=1)
rgb = self.rgbenc(rgb)
rgb = torch.flatten(rgb, start_dim=1)
spec_out = spec
for layer in self.spec_depth:
spec_out = layer(spec_out)
rgb_out = rgb
for layer in self.clf_rgb:
rgb_out = layer(rgb_out)
spec_evidence, rgb_evidence = F.softplus(spec_out), F.softplus(rgb_out)
spec_alpha, rgb_alpha = spec_evidence+1, rgb_evidence+1
spec_rgb_alpha = self.DS_Combin_two(spec_alpha, rgb_alpha)
return spec_alpha, rgb_alpha, spec_rgb_alpha
class ETMC(TMC):
def __init__(self, args):
super(ETMC, self).__init__(args)
last_size = args.img_hidden_sz * args.num_image_embeds + args.img_hidden_sz * args.num_image_embeds
self.clf = nn.ModuleList()
for hidden in args.hidden:
self.clf.append(nn.Linear(last_size, hidden))
self.clf.append(nn.ReLU())
self.clf.append(nn.Dropout(args.dropout))
last_size = hidden
self.clf.append(nn.Linear(last_size, args.n_classes))
def forward(self, rgb, spec):
spec = self.specenc(spec)
spec = torch.flatten(spec, start_dim=1)
rgb = self.rgbenc(rgb)
rgb = torch.flatten(rgb, start_dim=1)
spec_out = spec
for layer in self.spec_depth:
spec_out = layer(spec_out)
rgb_out = rgb
for layer in self.clf_rgb:
rgb_out = layer(rgb_out)
pseudo_out = torch.cat([rgb, spec], -1)
for layer in self.clf:
pseudo_out = layer(pseudo_out)
depth_evidence, rgb_evidence, pseudo_evidence = F.softplus(spec_out), F.softplus(rgb_out), F.softplus(pseudo_out)
depth_alpha, rgb_alpha, pseudo_alpha = depth_evidence+1, rgb_evidence+1, pseudo_evidence+1
depth_rgb_alpha = self.DS_Combin_two(self.DS_Combin_two(depth_alpha, rgb_alpha), pseudo_alpha)
return depth_alpha, rgb_alpha, pseudo_alpha, depth_rgb_alpha