import random

import torch
import torch.nn as nn
import torch.nn.functional as F

from model.clip import build_model

from .layers import FPN, Projector, TransformerDecoder


class CRIS_VerbOnly(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # Vision & Text Encoder
        clip_model = torch.jit.load(cfg.clip_pretrain,
                                    map_location="cpu").eval()
        self.backbone = build_model(clip_model.state_dict(),
                                    cfg.word_len).float()
        # Multi-Modal FPN
        self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        # Decoder
        self.decoder = TransformerDecoder(num_layers=cfg.num_layers,
                                          d_model=cfg.vis_dim,
                                          nhead=cfg.num_head,
                                          dim_ffn=cfg.dim_ffn,
                                          dropout=cfg.dropout,
                                          return_intermediate=cfg.intermediate)
        # Projector
        self.proj = Projector(cfg.word_dim, cfg.vis_dim // 2, 3)
        self.metric_learning = False  # cfg.metric_learning
        self.positive_strength = cfg.positive_strength
        self.metric_loss_weight = cfg.metric_loss_weight
        self.eps = cfg.ptb_rate
        self.cfg = cfg

    def forward(self, image, text, target=None, verb=None):
        '''
        image: b, 3, h, w
        text: b, words
        target: b, 1, h, w
        verb: b, words (only used in training mode, for contrastive learning)
        '''
        sentences, images, targets, pad_masks = [], [], [], []

        if self.training:
            verb_masks = []
            cl_masks = []
            for idx in range(len(text)):
                sentences.append(text[idx])
                images.append(image[idx])
                targets.append(target[idx])
                pad_masks.append(
                    torch.zeros_like(text[idx]).masked_fill_(text[idx] == 0, 1).bool())

                # If a verb phrase exists, append it as an extra sample
                if verb[idx].numel() > 0 and verb[idx].sum().item() > 0:
                    verb_masks.extend([1, 1])  # Both the original sentence and the verb are marked
                    cl_masks.extend([0, 1])    # Only the verb is excluded from the CE loss
                    sentences.append(verb[idx])
                    images.append(image[idx])
                    targets.append(target[idx])
                    pad_masks.append(
                        torch.zeros_like(verb[idx]).masked_fill_(verb[idx] == 0, 1).bool())
                else:
                    verb_masks.append(0)
                    cl_masks.append(0)

            sentences = torch.stack(sentences)
            images = torch.stack(images)
            targets = torch.stack(targets)
            pad_masks = torch.stack(pad_masks)
            verb_masks = torch.tensor(verb_masks, dtype=torch.bool)
            cl_masks = torch.tensor(cl_masks, dtype=torch.bool)
        else:
            sentences = text
            images = image
            targets = target
            pad_masks = torch.zeros_like(text).masked_fill_(text == 0, 1).bool()

        # Encode images and text
        vis = self.backbone.encode_image(images)
        word, state = self.backbone.encode_text(sentences)

        # FPN neck and decoder
        fq, f5 = self.neck(vis, state)
        b, c, h, w = fq.size()
        fq = self.decoder(fq, word, pad_masks)
        metric_tensor = fq  # b, c, h*w
        fq = fq.reshape(b, c, h, w)

        # Final prediction
        pred = self.proj(fq, state)

        if self.training:
            if pred.shape[-2:] != targets.shape[-2:]:
                targets = F.interpolate(targets, pred.shape[-2:],
                                        mode='nearest').detach()
            loss = F.binary_cross_entropy_with_logits(pred[~cl_masks],
                                                      targets[~cl_masks])
            if self.metric_learning:
                metric_loss = self.compute_metric_loss(metric_tensor,
                                                       verb_masks,
                                                       args=self.cfg)
                loss = (loss + self.metric_loss_weight * metric_loss) / (1 + self.metric_loss_weight)
            return pred.detach(), targets, loss

        return pred.detach()  # In eval mode, only return the predictions
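    # NOTE: compute_metric_loss is called from forward() when metric_learning
    # is enabled, but it is not defined in this file. The sketch below is an
    # assumption, not the original implementation: it is a minimal stand-in
    # that dispatches to AngularContrastiveLoss_1 on the fused decoder
    # features; the actual dispatch and hyperparameters may differ.
    def compute_metric_loss(self, metric_tensor, verb_masks, args=None):
        # metric_tensor: (b, c, h*w) fused features from the decoder
        # verb_masks: boolean mask selecting the (sentence, verb) samples
        return self.AngularContrastiveLoss_1(metric_tensor, verb_masks, args=args)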
    def return_mask_hponly(self, emb_distance, verb_mask=None):
        B_, B_ = emb_distance.shape
        positive_mask = torch.zeros_like(emb_distance)
        positive_mask.fill_diagonal_(1)  # The diagonal is always positive

        if B_ < len(verb_mask):
            # Here B_ equals 2*K (twice the number of verb phrases): the
            # matrix was built only from the (sentence, verb) pairs
            for i in range(B_ // 2):
                positive_mask[2 * i, 2 * i + 1] = 1
                positive_mask[2 * i + 1, 2 * i] = 1
        else:
            # Handle a mix of sentences with and without verb phrases
            i = 0
            while i < B_:
                if verb_mask[i] == 1:
                    positive_mask[i, i + 1] = 1
                    positive_mask[i + 1, i] = 1
                    i += 2
                else:
                    i += 1

        negative_mask = torch.ones_like(emb_distance) - positive_mask
        return positive_mask, negative_mask

    def return_mask_hphn(self, emb_distance, positive_verbs, negative_verbs, verb_mask):
        B_, B_ = emb_distance.shape
        positive_mask = torch.zeros_like(emb_distance)
        negative_mask = torch.ones_like(emb_distance)
        positive_mask.fill_diagonal_(1)

        if B_ < len(verb_mask):
            # Consider only verbs that pass the verb_mask filter
            positive_verbs = torch.tensor(positive_verbs)[verb_mask]
            negative_verbs = torch.tensor(negative_verbs)[verb_mask]

            # Exclude hard negatives from both masks (diagonal)
            for i in range(B_):
                if negative_verbs[i] == 1:
                    positive_mask[i, i] = 0
                    negative_mask[i, i] = 0

            i = 0
            while i < B_:
                if positive_verbs[i] == 1:
                    if i + 1 < B_ and positive_verbs[i + 1] == 1:
                        positive_mask[i, i + 1] = 1
                        positive_mask[i + 1, i] = 1
                    i += 2
                else:
                    i += 1
        else:
            # Exclude hard negatives from both masks (diagonal)
            for i in range(B_):
                if negative_verbs[i] == 1:
                    positive_mask[i, i] = 0
                    negative_mask[i, i] = 0

            # Apply the positive-pair logic as above
            i = 0
            while i < B_:
                if positive_verbs[i] == 1 and i + 1 < B_ and positive_verbs[i + 1] == 1:
                    positive_mask[i, i + 1] = 1
                    positive_mask[i + 1, i] = 1
                    i += 2
                else:
                    i += 1

        negative_mask = negative_mask - positive_mask
        return positive_mask, negative_mask

    def compute_contrastive_loss(self, fq, state, verb_masks, temperature=0.05):
        """
        Compute a contrastive (InfoNCE) loss only for the samples with verb phrases.
        fq: (b, c, h*w)  -> encoded image features
        state: (b, d)    -> encoded text features (sentence embeddings)
        verb_masks: boolean mask indicating samples containing verb phrases
        temperature: scaling factor for the contrastive loss
        """
        # Keep only the verb-containing samples
        fq_verb_samples = fq[verb_masks]        # (num_verbs, c, h*w)
        state_verb_samples = state[verb_masks]  # (num_verbs, d)

        # Mean-pool the spatial dimension so both modalities live in (num_verbs, d).
        # (The original flattened fq to (num_verbs, c*h*w), which cannot be
        # multiplied against (d, num_verbs) unless c*h*w == d; pooling matches
        # the treatment in AngularContrastiveLoss_2 and assumes c == d, as with
        # the CLIP backbone.)
        fq_verb_samples = torch.mean(fq_verb_samples, dim=-1)  # (num_verbs, c)
        fq_verb_samples = F.normalize(fq_verb_samples, p=2, dim=1)
        state_verb_samples = F.normalize(state_verb_samples, p=2, dim=1)

        # Inner product between the language-conditioned features and the
        # encoded text (verb phrases)
        logits = torch.matmul(fq_verb_samples, state_verb_samples.t())
        logits = logits / temperature

        # Positive pairs sit on the diagonal
        labels = torch.arange(logits.size(0), device=logits.device)
        contrastive_loss = F.cross_entropy(logits, labels)
        return contrastive_loss
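    # Illustrative example (not from the original source): with a training
    # batch assembled by forward() as [sent_0, verb_0, sent_1, sent_2, verb_2],
    # the verb mask is [1, 1, 0, 1, 1] and return_mask_hponly marks each
    # (sentence, verb) pair as mutual positives on top of the diagonal:
    #
    #   positive_mask = [[1, 1, 0, 0, 0],
    #                    [1, 1, 0, 0, 0],
    #                    [0, 0, 1, 0, 0],
    #                    [0, 0, 0, 1, 1],
    #                    [0, 0, 0, 1, 1]]
    #
    # and negative_mask = 1 - positive_mask.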
    # Cosine similarity computed on metric_tensor only
    def AngularContrastiveLoss_1(self, embeddings, verb_mask, alpha=0.5, m=0.5, tau=0.05, args=None):
        r"""
        Angular-margin contrastive loss.

        - \( \theta_{i,i^*} \) is the angle between the anchor \( h_i \) and
          the positive sample \( h_{i^*} \).
        - An angular margin \( m \) is added to increase the distance between
          the positive and negative pairs.
        - \( \tau \) is a temperature that controls the sharpness of the
          probability distribution.

        https://aclanthology.org/2022.acl-long.336.pdf

        \[
        \mathcal{L}_{arc} = -\log \frac{\exp\left(\cos(\theta_{i,i^*} + m)/\tau\right)}
        {\exp\left(\cos(\theta_{i,i^*} + m)/\tau\right) + \sum_{j \neq i} \exp\left(\cos(\theta_{i,j})/\tau\right)}
        \]

        Args:
            embeddings: Encoded embeddings of shape (B, C, H*W), the image-text fused features.
            verb_mask: Mask indicating the samples with verb phrases.
            alpha: Weight balancing the positive and negative loss components.
            m: Angular margin added to the cosine similarity of positive pairs.
            tau: Temperature scaling factor for the softmax.
            args: Optional arguments for additional control.
        Returns:
            angular_loss: The computed angular-margin contrastive loss.
        """
        # Mean-pool the spatial dimension (H*W) and L2-normalize the embeddings.
        # Note: the pairwise matrices below must use the number of *masked*
        # samples, not the full batch size; using B_ from embeddings.shape
        # crashes whenever verb_mask filters anything out.
        emb = torch.mean(embeddings[verb_mask], dim=-1)  # (N, C)
        emb = F.normalize(emb, p=2, dim=1)
        N = emb.size(0)

        # Pairwise cosine-similarity matrix
        sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
        emb_i = emb.unsqueeze(1).repeat(1, N, 1)  # pair each embedding with all others
        emb_j = emb.unsqueeze(0).repeat(N, 1, 1)
        sim_matrix = sim(emb_i, emb_j).reshape(N, N)

        # Clamp to avoid numerical instability in acos
        sim_matrix = torch.clamp(sim_matrix, min=-0.9999, max=0.9999)

        # The diagonal holds the positive pairs
        positive_mask = torch.eye(N, device=embeddings.device).bool()
        sim_matrix_with_margin = sim_matrix.clone()

        # Apply the angular margin m only to the positive pairs (diagonal)
        sim_matrix_with_margin[positive_mask] = torch.cos(
            torch.acos(sim_matrix[positive_mask]) + m)

        # Temperature-scaled logits
        logits = sim_matrix_with_margin / tau

        # Softmax over all pairs:
        # L_arc = -log( e^{cos(theta + m)/tau} / sum_j e^{cos(theta_j)/tau} )
        exp_logits = torch.exp(logits)
        pos_exp_logits = exp_logits[positive_mask]
        total_exp_logits = exp_logits.sum(dim=-1)
        positive_loss = -torch.log(pos_exp_logits / total_exp_logits)

        # Average over the batch
        angular_loss = positive_loss.mean()
        return angular_loss
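    # Worked example (illustrative numbers, not from the original source): for
    # a positive pair with cosine similarity s = 0.8 and margin m = 0.5,
    # acos(0.8) ≈ 0.6435 rad, so cos(0.6435 + 0.5) ≈ 0.41. The positive logit
    # is thus deflated from 0.8 to ~0.41 before the softmax, forcing the model
    # to push the true similarity well above the negatives to compensate.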
    # Cosine similarity between metric_tensor (image-text) and state (text EOS)
    def AngularContrastiveLoss_2(self, fq, state, verb_masks, alpha=1.0, margin=0.5, temperature=0.05):
        r"""
        Margin-based contrastive loss between fused features and text features.

        - \( \theta_{i,i^*} \) is the angle between the anchor \( h_i \) and
          the positive sample \( h_{i^*} \).
        - A margin \( m \) is added to increase the separation between the
          positive and negative pairs.
        - \( \tau \) is a temperature that controls the sharpness of the
          probability distribution.

        https://aclanthology.org/2022.acl-long.336.pdf

        \[
        \mathcal{L}_{arc} = -\log \frac{\exp\left(\cos(\theta_{i,i^*} + m)/\tau\right)}
        {\exp\left(\cos(\theta_{i,i^*} + m)/\tau\right) + \sum_{j \neq i} \exp\left(\cos(\theta_{i,j})/\tau\right)}
        \]

        Note: unlike AngularContrastiveLoss_1, the margin here is added to the
        cosine similarity itself (an additive, CosFace-style margin) rather
        than to the angle, so the formula above is approximated rather than
        applied literally.

        fq: (b, c, h*w)  -> encoded language-fused multimodal features (metric_tensor)
        state: (b, d)    -> encoded text features (sentence embeddings)
        verb_masks: boolean mask indicating samples containing verb phrases
        alpha: weight for the positive samples
        margin: the margin enforced on the positive pairs
        temperature: scaling factor for the contrastive loss
        """
        # Select only the verb-containing samples.
        # Assumes c == d (CLIP backbone).
        fq_verb_samples = torch.mean(fq[verb_masks], dim=-1)  # (num_verbs, d)
        state_verb_samples = state[verb_masks]                # (num_verbs, d)
        fq_verb_samples = F.normalize(fq_verb_samples, p=2, dim=1)
        state_verb_samples = F.normalize(state_verb_samples, p=2, dim=1)

        # Cosine similarities (logits) between image and text features
        logits = torch.matmul(fq_verb_samples, state_verb_samples.t())  # (num_verbs, num_verbs)

        # Add the margin to the positive pairs (diagonal entries)
        diagonal_indices = torch.arange(logits.size(0), device=logits.device)
        positive_logits_with_margin = logits[diagonal_indices, diagonal_indices] + margin
        logits[diagonal_indices, diagonal_indices] = positive_logits_with_margin
        logits = logits / temperature

        # Positive mask (diagonal) and negative mask (off-diagonal)
        positive_mask = torch.eye(logits.size(0), device=logits.device).bool()
        negative_mask = ~positive_mask

        exp_logits = torch.exp(logits)
        pos_exp_logits = exp_logits[positive_mask].view(-1)
        neg_exp_logits = exp_logits[negative_mask].view(logits.size(0), -1).sum(dim=1)

        # L = -log( e^{(s_pos + m)/tau} / (e^{(s_pos + m)/tau} + sum_j e^{s_neg_j/tau}) )
        positive_loss = -torch.log(pos_exp_logits / (pos_exp_logits + neg_exp_logits))
        loss = positive_loss.mean()
        return loss

    # def AngularMetricLoss_Seunghoon(self, embeddings, n_pos, verb_mask, alpha=0.5, args=None):
    #     # embeddings: ((2*B), C, (H*W))
    #     # n_pos: chunk size of positive pairs
    #     # args: args
    #     # returns: loss
    #     geometric_loss = 0
    #     # flatten embeddings
    #     B_, C, HW = embeddings.shape
    #     emb = torch.mean(embeddings[verb_mask], dim=-1)
    #     emb_i = emb.unsqueeze(1).repeat(1, B_, 1)
    #     emb_j = emb.unsqueeze(0).repeat(B_, 1, 1)
    #     sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
    #     sim_matrix = sim(emb_i, emb_j).reshape(B_, B_)
    #     sim_matrix = torch.clamp(sim_matrix, min=-0.9999, max=0.9999)
    #     phi = torch.acos(sim_matrix)
    #     # positive pairs and loss
    #     positive_mask = torch.zeros_like(sim_matrix)
    #     positive_mask.fill_diagonal_(1)
    #     positive_loss = torch.sum((phi**2) * positive_mask) / B_
    #     # negative pairs and loss
    #     negative_mask = torch.ones_like(sim_matrix) - positive_mask
    #     phi_mask = phi < args.phi_threshold
    #     negative_loss = (args.phi_threshold - phi)**2
    #     # print(negative_mask * phi_mask)
    #     negative_loss = torch.sum(negative_loss * negative_mask * phi_mask) / (B_**2 - 2*B_)
    #     # print("pos loss, neg loss : ", positive_loss, negative_loss)
    #     geometric_loss = alpha * positive_loss + (1 - alpha) * negative_loss
    #     return geometric_loss
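# A minimal smoke test for the loss heads (a sketch, not part of the original
# training pipeline). It bypasses __init__ via __new__ because constructing
# the full model requires the CLIP checkpoint at cfg.clip_pretrain; the loss
# methods themselves do not read any state set up in __init__.
if __name__ == "__main__":
    torch.manual_seed(0)
    model = CRIS_VerbOnly.__new__(CRIS_VerbOnly)

    # Dummy shapes, assuming c == d == 512 as with a CLIP backbone
    b, c, hw, d = 4, 512, 196, 512
    fq = torch.randn(b, c, hw)   # fused decoder features (metric_tensor)
    state = torch.randn(b, d)    # text (EOS) embeddings
    verb_masks = torch.tensor([1, 1, 0, 1], dtype=torch.bool)

    print("ACL_1:", model.AngularContrastiveLoss_1(fq, verb_masks).item())
    print("ACL_2:", model.AngularContrastiveLoss_2(fq, state, verb_masks).item())
    print("NCE:  ", model.compute_contrastive_loss(fq, state, verb_masks).item())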