File size: 4,963 Bytes
505e401 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import torch
import utils
def get_loss_function(params):
if params['loss'] == 'an_full':
return an_full
elif params['loss'] == 'an_slds':
return an_slds
elif params['loss'] == 'an_ssdl':
return an_ssdl
elif params['loss'] == 'an_full_me':
return an_full_me
elif params['loss'] == 'an_slds_me':
return an_slds_me
elif params['loss'] == 'an_ssdl_me':
return an_ssdl_me
def neg_log(x):
return -torch.log(x + 1e-5)
def bernoulli_entropy(p):
entropy = p * neg_log(p) + (1-p) * neg_log(1-p)
return entropy
def an_ssdl(batch, model, params, loc_to_feats, neg_type='hard'):
inds = torch.arange(params['batch_size'])
loc_feat, _, class_id = batch
loc_feat = loc_feat.to(params['device'])
class_id = class_id.to(params['device'])
assert model.inc_bias == False
batch_size = loc_feat.shape[0]
# create random background samples and extract features
rand_loc = utils.rand_samples(batch_size, params['device'], rand_type='spherical')
rand_feat = loc_to_feats(rand_loc, normalize=False)
# get location embeddings
loc_cat = torch.cat((loc_feat, rand_feat), 0) # stack vertically
loc_emb_cat = model(loc_cat, return_feats=True)
loc_emb = loc_emb_cat[:batch_size, :]
loc_emb_rand = loc_emb_cat[batch_size:, :]
loc_pred = torch.sigmoid(model.class_emb(loc_emb))
loc_pred_rand = torch.sigmoid(model.class_emb(loc_emb_rand))
# data loss
loss_pos = neg_log(loc_pred[inds[:batch_size], class_id])
if neg_type == 'hard':
loss_bg = neg_log(1.0 - loc_pred_rand[inds[:batch_size], class_id]) # assume negative
elif neg_type == 'entropy':
loss_bg = -1 * bernoulli_entropy(1.0 - loc_pred_rand[inds[:batch_size], class_id]) # entropy
else:
raise NotImplementedError
# total loss
loss = loss_pos.mean() + loss_bg.mean()
return loss
def an_slds(batch, model, params, loc_to_feats, neg_type='hard'):
inds = torch.arange(params['batch_size'])
loc_feat, _, class_id = batch
loc_feat = loc_feat.to(params['device'])
class_id = class_id.to(params['device'])
assert model.inc_bias == False
batch_size = loc_feat.shape[0]
loc_emb = model(loc_feat, return_feats=True)
loc_pred = torch.sigmoid(model.class_emb(loc_emb))
num_classes = loc_pred.shape[1]
bg_class = torch.randint(low=0, high=num_classes-1, size=(batch_size,), device=params['device'])
bg_class[bg_class >= class_id[:batch_size]] += 1
# data loss
loss_pos = neg_log(loc_pred[inds[:batch_size], class_id])
if neg_type == 'hard':
loss_bg = neg_log(1.0 - loc_pred[inds[:batch_size], bg_class]) # assume negative
elif neg_type == 'entropy':
loss_bg = -1 * bernoulli_entropy(1.0 - loc_pred[inds[:batch_size], bg_class]) # entropy
else:
raise NotImplementedError
# total loss
loss = loss_pos.mean() + loss_bg.mean()
return loss
def an_full(batch, model, params, loc_to_feats, neg_type='hard'):
inds = torch.arange(params['batch_size'])
loc_feat, _, class_id = batch
loc_feat = loc_feat.to(params['device'])
class_id = class_id.to(params['device'])
assert model.inc_bias == False
batch_size = loc_feat.shape[0]
# create random background samples and extract features
rand_loc = utils.rand_samples(batch_size, params['device'], rand_type='spherical')
rand_feat = loc_to_feats(rand_loc, normalize=False)
# get location embeddings
loc_cat = torch.cat((loc_feat, rand_feat), 0) # stack vertically
loc_emb_cat = model(loc_cat, return_feats=True)
loc_emb = loc_emb_cat[:batch_size, :]
loc_emb_rand = loc_emb_cat[batch_size:, :]
# get predictions for locations and background locations
loc_pred = torch.sigmoid(model.class_emb(loc_emb))
loc_pred_rand = torch.sigmoid(model.class_emb(loc_emb_rand))
# data loss
if neg_type == 'hard':
loss_pos = neg_log(1.0 - loc_pred) # assume negative
loss_bg = neg_log(1.0 - loc_pred_rand) # assume negative
elif neg_type == 'entropy':
loss_pos = -1 * bernoulli_entropy(1.0 - loc_pred) # entropy
loss_bg = -1 * bernoulli_entropy(1.0 - loc_pred_rand) # entropy
else:
raise NotImplementedError
loss_pos[inds[:batch_size], class_id] = params['pos_weight'] * neg_log(loc_pred[inds[:batch_size], class_id])
# total loss
loss = loss_pos.mean() + loss_bg.mean()
return loss
def an_full_me(batch, model, params, loc_to_feats):
return an_full(batch, model, params, loc_to_feats, neg_type='entropy')
def an_ssdl_me(batch, model, params, loc_to_feats):
return an_ssdl(batch, model, params, loc_to_feats, neg_type='entropy')
def an_slds_me(batch, model, params, loc_to_feats):
return an_slds(batch, model, params, loc_to_feats, neg_type='entropy') |