import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from model_.clip import build_model
from .layers import FPN, Projector, TransformerDecoder
class CRIS_PosOnly(nn.Module):
def __init__(self, cfg):
super().__init__()
# Vision & Text Encoder
clip_model = torch.jit.load(cfg.clip_pretrain,
map_location="cpu").eval()
self.backbone = build_model(clip_model.state_dict(), cfg.word_len, cfg.freeze).float()
# Multi-Modal FPN
self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
# Decoder
self.decoder = TransformerDecoder(num_layers=cfg.num_layers,
d_model=cfg.vis_dim,
nhead=cfg.num_head,
dim_ffn=cfg.dim_ffn,
dropout=cfg.dropout,
return_intermediate=cfg.intermediate)
# Projector
self.proj = Projector(cfg.word_dim, cfg.vis_dim // 2, 3)
        self.metric_learning = False  # disabled in this variant; cfg.metric_learning is not consulted
self.metric_loss_weight = cfg.metric_loss_weight
self.cfg = cfg
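    # cfg fields consumed above (as used in this file): clip_pretrain, word_len,
    # freeze, fpn_in, fpn_out, num_layers, vis_dim, num_head, dim_ffn, dropout,
    # intermediate, word_dim, metric_loss_weight.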
def forward(self, image, text, target=None, verb=None):
        '''
        image:  b, 3, h, w
        text:   b, words
        target: b, 1, h, w
        verb:   b, words (optional; consumed only in training mode, for contrastive learning)
        Returns (pred, target, loss) in training mode and pred alone in eval mode.
        '''
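        # Training-time batch construction: every sample contributes its full
        # sentence; samples that also carry a non-empty verb phrase are appended
        # a second time (same image/target, verb tokens as text) so that each
        # (sentence, verb) pair occupies adjacent batch slots. verb_masks flags
        # pair members for the contrastive loss; cl_masks selects the entries
        # that feed the BCE segmentation loss.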
sentences, images, targets, pad_masks = [], [], [], []
if self.training:
verb_masks = []
cl_masks = []
for idx in range(len(text)):
sentences.append(text[idx])
images.append(image[idx])
targets.append(target[idx])
pad_masks.append(torch.zeros_like(text[idx]).masked_fill_(text[idx] == 0, 1).bool())
                # If a non-empty verb phrase exists, append it as an extra entry
                if verb[idx].numel() > 0 and verb[idx].sum().item() > 0:
                    verb_masks.extend([1, 1])  # both the original sentence and its verb phrase are marked
                    cl_masks.extend([1, 0])    # only the original sentence is marked for the BCE loss
sentences.append(verb[idx])
images.append(image[idx])
targets.append(target[idx])
pad_masks.append(torch.zeros_like(verb[idx]).masked_fill_(verb[idx] == 0, 1).bool())
else:
verb_masks.append(0)
cl_masks.append(1)
sentences = torch.stack(sentences)
images = torch.stack(images)
targets = torch.stack(targets)
pad_masks = torch.stack(pad_masks)
            verb_masks = torch.tensor(verb_masks, dtype=torch.bool, device=images.device)
            cl_masks = torch.tensor(cl_masks, dtype=torch.bool, device=images.device)
else:
sentences = text
images = image
targets = target
pad_masks = torch.zeros_like(text).masked_fill_(text == 0, 1).bool()
# Encoding images and text
vis = self.backbone.encode_image(images)
word, state = self.backbone.encode_text(sentences)
# FPN neck and decoder
fq, f5 = self.neck(vis, state)
b, c, h, w = fq.size()
fq = self.decoder(fq, word, pad_masks)
metric_tensor = fq # b, c, h*w
fq = fq.reshape(b, c, h, w)
# Final prediction
pred = self.proj(fq, state)
if self.training:
if pred.shape[-2:] != targets.shape[-2:]:
targets = F.interpolate(targets, pred.shape[-2:], mode='nearest').detach()
loss = F.binary_cross_entropy_with_logits(pred[cl_masks], targets[cl_masks])
if self.metric_learning:
metric_loss = self.compute_metric_loss(metric_tensor, verb_masks, args=self.cfg)
loss = (loss + self.metric_loss_weight * metric_loss) / (1 + self.metric_loss_weight)
return pred[cl_masks].detach(), targets[cl_masks], loss
return pred.detach() # In eval mode, only return the predictions
    def compute_metric_loss(self, metric_tensor, verb_mask, args):
        if args.loss_option == "ACL_verbonly":
            metric_loss = self.UniAngularContrastLoss(metric_tensor, verb_mask,
                                                      m=args.margin_value, tau=args.temperature,
                                                      verbonly=True, args=args)
        elif args.loss_option == "ACL":
            metric_loss = self.UniAngularContrastLoss(metric_tensor, verb_mask,
                                                      m=args.margin_value, tau=args.temperature,
                                                      verbonly=False, args=args)
        else:
            raise ValueError(f"Unsupported loss_option: {args.loss_option}")
        return metric_loss
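    # Config fields consumed by the metric loss (inferred from the calls above):
    # loss_option in {"ACL", "ACL_verbonly"}, margin_value (an angle in degrees,
    # converted to radians inside UniAngularContrastLoss), and temperature.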
def return_mask(self, emb_distance, verb_mask=None):
        B_ = emb_distance.shape[0]  # emb_distance is a square (B_, B_) similarity matrix
        positive_mask = torch.zeros_like(emb_distance)
        positive_mask.fill_diagonal_(1)  # every embedding is a positive of itself
        if B_ < len(verb_mask):
            # verbonly case: rows hold only the verb-marked entries, which arrive
            # as adjacent (sentence, verb) pairs, so B_ = 2 * K
            for i in range(B_ // 2):
positive_mask[2 * i, 2 * i + 1] = 1
positive_mask[2 * i + 1, 2 * i] = 1
else:
# Process the case where we have a mix of sentences with and without verbs
i = 0
while i < B_:
if verb_mask[i] == 1:
positive_mask[i, i + 1] = 1
positive_mask[i + 1, i] = 1
i += 2
else:
i += 1
negative_mask = torch.ones_like(emb_distance) - positive_mask
return positive_mask, negative_mask
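    # Example (mixed case): for verb_mask = [1, 1, 0], return_mask yields the
    # 3x3 positive mask
    #   [[1, 1, 0],
    #    [1, 1, 0],
    #    [0, 0, 1]]
    # so the (sentence, verb) pair in slots 0/1 is mutually positive and every
    # entry is positive with itself; negative_mask is its elementwise complement.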
    def UniAngularContrastLoss(self, total_fq, verb_mask, verbonly=True, m=0.5, tau=0.05, args=None):
        _, C, HW = total_fq.shape
        if verbonly:
            # keep only the verb-marked entries; they arrive as (sentence, verb) pairs
            emb = torch.mean(total_fq[verb_mask], dim=-1)
            assert emb.shape[0] % 2 == 0, f"Embedding count {emb.shape[0]} is not divisible by 2."
        else:
            emb = torch.mean(total_fq, dim=-1)
        B_ = emb.shape[0]
# emb = F.normalize(emb, p=2, dim=1)
emb_i = emb.unsqueeze(1).repeat(1, B_, 1) # (B_, B_, C)
emb_j = emb.unsqueeze(0).repeat(B_, 1, 1) # (B_, B_, C)
sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
sim_matrix = sim(emb_i, emb_j).reshape(B_, B_) # (B_, B_)
sim_matrix = torch.clamp(sim_matrix, min=-0.9999, max=0.9999)
positive_mask, negative_mask = self.return_mask(sim_matrix, verb_mask)
        # Apply the additive angular margin to positive pairs:
        # cos(theta) -> cos(theta + m), with m given in degrees (1 rad = 57.2958 deg)
        sim_matrix_with_margin = sim_matrix.clone()
        pos_idx = positive_mask.bool()
        sim_matrix_with_margin[pos_idx] = torch.cos(torch.acos(sim_matrix[pos_idx]) + m / 57.2958)
# Scale logits with temperature
logits = sim_matrix_with_margin / tau
        # Softmax-style contrastive loss over all pairs
        exp_logits = torch.exp(logits)
        pos_exp_logits = exp_logits[pos_idx]
        total_exp_logits = exp_logits.sum(dim=-1)  # row-wise denominator, shape (B_,)
        # Final loss: L_arc = -log(e^(cos(theta + m)/tau) / sum_j e^(cos(theta_j)/tau));
        # each positive entry (i, j) is normalized by the sum over its own row i.
        # (A 2-D boolean mask cannot index the 1-D row sums directly, so gather
        # the row index of every positive entry first.)
        row_idx = pos_idx.nonzero(as_tuple=True)[0]
        positive_loss = -torch.log(pos_exp_logits / total_exp_logits[row_idx])
        angular_loss = positive_loss.mean()
        return angular_loss
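

if __name__ == "__main__":
    # Minimal smoke test for the metric-loss path only: a sketch, not part of the
    # training pipeline. The full model is not constructed here because __init__
    # needs a CLIP checkpoint (cfg.clip_pretrain); __new__ deliberately skips
    # __init__, which is safe since UniAngularContrastLoss and return_mask read
    # no attributes set there. Shapes and the verb_mask layout follow the
    # conventions of forward() above.
    torch.manual_seed(0)
    model = CRIS_PosOnly.__new__(CRIS_PosOnly)
    B, C, HW = 6, 512, 26 * 26                      # fake decoder output: b, c, h*w
    total_fq = torch.randn(B, C, HW)
    # (sentence, verb) pairs occupy adjacent slots; slots 2 and 5 carry no verb
    verb_mask = torch.tensor([1, 1, 0, 1, 1, 0], dtype=torch.bool)
    loss = model.UniAngularContrastLoss(total_fq, verb_mask, verbonly=True)
    print(f"angular contrastive loss: {loss.item():.4f}")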