import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from .clip import build_model, build_promptlearner, build_modified_model, PromptLearner, build_lclip_model
from torch.cuda.amp import autocast as autocast
from timm.models.layers import trunc_normal_ as __call_trunc_normal_
from timm.models.layers import variance_scaling_
from einops import rearrange, repeat
from loguru import logger
from transformers import AlignProcessor, AlignModel
from sklearn.metrics import classification_report
from huggingface_hub import PyTorchModelHubMixin
from .layers import FPN, TransformerDecoder, ViTFPN, AdaptiveSpatialFeatureFusion, Text_Projector, Image_Projector, Adapter, GAP
from cisen.model.clip import CLIP


def lecun_normal_(tensor):
    variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")


def trunc_normal_(tensor, mean=0.0, std=1.0):
    __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)


class CISEN_vit(nn.Module, PyTorchModelHubMixin):
    def __init__(self, cfg):
        super().__init__()
        # Vision & Text Encoder & Label Encoder
        clip_model = torch.jit.load(cfg.clip_pretrain, map_location="cpu").eval()
        backbone, image_resolution, vision_heads, embed_dim, vision_width, patch_size = build_model(clip_model.state_dict(), cfg.word_len)
        self.backbone = backbone.float()
        self.patch_emb = image_resolution // patch_size
        cfg.image_resolution = image_resolution
        cfg.input_size = image_resolution
        cfg.heads = vision_heads // 32
        cfg.emb_dim = vision_width
        cfg.output_dim = embed_dim
        # multi-scale adapter
        # Multi-Modal FPN
        self.FPN = ViTFPN(image_resolution, in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        # Fine-grained Fusion
        # self.FGFusion = TransformerDecoder(num_layers=cfg.num_layers, d_model=cfg.vis_dim,
        #                                    nhead=cfg.num_head, dim_ffn=cfg.dim_ffn,
        #                                    dropout=cfg.dropout, return_intermediate=cfg.intermediate)
        # image-text transformer
        # self.trans = nn.Linear(1024, 1024)
        self.ADP = Adapter(cfg.output_dim, 4)
        # parameter
        self.ratio = cfg.ratio
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.share_temperature = True
        self.ce = nn.CrossEntropyLoss()
        self.ms_adaptor = nn.ModuleList(
            [
                nn.Sequential(
                    nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2),
                    nn.GroupNorm(32, cfg.emb_dim),
                    nn.GELU(),
                    nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2),
                ),
                nn.Sequential(
                    nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2),
                ),
                nn.Sequential(
                    nn.Identity(),
                ),
                nn.Sequential(
                    nn.MaxPool2d(2),
                ),
            ]
        )
        self.ms_adaptor.apply(self.init_adaptor)

    def init_adaptor(self, m):
        if isinstance(m, nn.Conv2d):
            lecun_normal_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.GroupNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.ConvTranspose2d):
            lecun_normal_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    # self.fc = nn.Linear(512, cfg.num_classes)

    def IT_loss(self, image_features, text_features):
        # b, 1024 / b, 1024
        batch = image_features.shape[0]
        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()
        # shape = [global_batch_size, global_batch_size]
        contrastive_labels = torch.arange(batch).to(logits_per_image.device)
        contrastive_loss = (self.ce(logits_per_image, contrastive_labels) +
                            self.ce(logits_per_text, contrastive_labels)) * 0.5
        return contrastive_loss

    def forward(self, img, txt, stage):
        if stage == '1st':
            '''
            img: b, 3, h, w
            word: b, words
            word_mask: b, words
            mask: b, 1, h, w
            stage: 1st or 2nd stage
            '''
            # padding mask used in decoder
            pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
            # vis: C3 / C4 / C5 / b, 512, 28, 28 / b, 1024, 14, 14 / b, 1024, 7, 7
            # word: b, length, 512
            # text: b, 1024
            # image: b, 1024
            vis, image = self.backbone.encode_image(img)
            word, text = self.backbone.encode_text(txt)
            x = self.ADP(image)
            x = self.ratio * x + (1 - self.ratio) * image
            # b, 1024
            # fq_t = self.FPN(vis, x)
            # fv_t = self.gap(fq_t)
            loss1 = self.IT_loss(x, text)
            loss = loss1
            ft = text
            fi = x
            fv = None
        elif stage == '2nd':
            '''
            img: b, 3, h, w
            word: b, words
            word_mask: b, words
            mask: b, 1, h, w
            stage: 1st or 2nd stage
            '''
            # padding mask used in decoder
            pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
            # vis: C3 / C4 / C5 / b, 512, 28, 28 / b, 1024, 14, 14 / b, 1024, 7, 7
            # word: b, length, 512
            # text: b, 1024
            # image: b, 1024
            vis, image = self.backbone.encode_image(img)
            word, text = self.backbone.encode_text(txt)
            x = self.ADP(image)
            x = self.ratio * x + (1 - self.ratio) * image
            # Construct multi-scale feats
            vis_trans = []
            for i in range(len(self.ms_adaptor)):
                x_ = rearrange(
                    vis[i],
                    "b (h w) c -> b c h w",
                    h=self.patch_emb,
                    w=self.patch_emb,
                ).contiguous()
                feats = self.ms_adaptor[i](x_)
                vis_trans.append(feats)
            # fq = self.FPN(vis, x_t)
            fv_t = self.FPN(vis_trans[1:], x, False)
            # fv_t = self.gap(fq_t)
            # b, 1024
            loss2 = self.IT_loss(fv_t, text)
            loss = loss2
            fv = fv_t
            ft = text
            fi = x
        return loss, fv, fi, ft

    def visualize(self, img, txt):
        vis, image = self.backbone.encode_image(img)
        word, text = self.backbone.encode_text(txt)
        x = self.ADP(image)
        x = self.ratio * x + (1 - self.ratio) * image
        # Construct multi-scale feats
        vis_trans = []
        for i in range(len(self.ms_adaptor)):
            x_ = rearrange(
                vis[i],
                "b (h w) c -> b c h w",
                h=self.patch_emb,
                w=self.patch_emb,
            ).contiguous()
            feats = self.ms_adaptor[i](x_)
            vis_trans.append(feats)
        # fq = self.FPN(vis, x_t)
        fv_t = self.FPN(vis_trans[1:], x, True)
        ft_t = self.FPN(vis_trans[1:], text, True)
        return vis, fv_t, ft_t
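
# --- Illustrative sketch (not part of the original training code) ----------
# IT_loss above is the symmetric CLIP-style contrastive objective: within a
# batch of paired features, the i-th image should match the i-th caption, so
# the targets are the diagonal of the scaled cosine-similarity matrix. The
# standalone helper below reproduces that computation on random features;
# the batch size and feature dimension are arbitrary assumptions.
def _demo_contrastive_loss(batch: int = 4, dim: int = 1024) -> torch.Tensor:
    image_features = F.normalize(torch.randn(batch, dim), dim=-1)
    text_features = F.normalize(torch.randn(batch, dim), dim=-1)
    logit_scale = torch.ones([]).mul(np.log(1 / 0.07)).exp()  # same init as self.logit_scale
    logits_per_image = logit_scale * image_features @ text_features.t()
    logits_per_text = logits_per_image.t()
    targets = torch.arange(batch)  # positives sit on the diagonal
    return 0.5 * (F.cross_entropy(logits_per_image, targets) +
                  F.cross_entropy(logits_per_text, targets))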


class CISEN_rsvit(nn.Module, PyTorchModelHubMixin):
    def __init__(self, cfg):
        super().__init__()
        # Vision & Text Encoder & Label Encoder
        clip_model = torch.load(cfg.clip_pretrain, map_location="cpu")
        backbone, image_resolution, vision_heads, embed_dim, vision_width, patch_size = build_model(clip_model, cfg.word_len)
        self.backbone = backbone.float()
        self.patch_emb = image_resolution // patch_size
        cfg.image_resolution = image_resolution
        cfg.input_size = image_resolution
        cfg.heads = vision_heads // 32
        cfg.emb_dim = vision_width
        cfg.output_dim = embed_dim
        # multi-scale adapter
        # Multi-Modal FPN
        self.FPN = ViTFPN(image_resolution, in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        # Fine-grained Fusion
        # self.FGFusion = TransformerDecoder(num_layers=cfg.num_layers, d_model=cfg.vis_dim,
        #                                    nhead=cfg.num_head, dim_ffn=cfg.dim_ffn,
        #                                    dropout=cfg.dropout, return_intermediate=cfg.intermediate)
        # image-text transformer
        # self.trans = nn.Linear(1024, 1024)
        self.ADP = Adapter(cfg.output_dim, 4)
        # parameter
        self.ratio = cfg.ratio
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.share_temperature = True
        self.ce = nn.CrossEntropyLoss()
        self.ms_adaptor = nn.ModuleList(
            [
                nn.Sequential(
                    nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2),
                    nn.GroupNorm(32, cfg.emb_dim),
                    nn.GELU(),
                    nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2),
                ),
                nn.Sequential(
                    nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2),
                ),
                nn.Sequential(
                    nn.Identity(),
                ),
                nn.Sequential(
                    nn.MaxPool2d(2),
                ),
            ]
        )
        self.ms_adaptor.apply(self.init_adaptor)

    def init_adaptor(self, m):
        if isinstance(m, nn.Conv2d):
            lecun_normal_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.GroupNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.ConvTranspose2d):
            lecun_normal_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    # self.fc = nn.Linear(512, cfg.num_classes)

    def IT_loss(self, image_features, text_features):
        # b, 1024 / b, 1024
        batch = image_features.shape[0]
        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()
        # shape = [global_batch_size, global_batch_size]
        contrastive_labels = torch.arange(batch).to(logits_per_image.device)
        contrastive_loss = (self.ce(logits_per_image, contrastive_labels) +
                            self.ce(logits_per_text, contrastive_labels)) * 0.5
        return contrastive_loss

    def image_encode(self, img):
        vis, image = self.backbone.encode_image(img)
        x = self.ADP(image)
        x = self.ratio * x + (1 - self.ratio) * image
        return x

    def text_encode(self, txt):
        word, text = self.backbone.encode_text(txt)
        return text

    def forward(self, img, txt, stage):
        if stage == '1st':
            '''
            img: b, 3, h, w
            word: b, words
            word_mask: b, words
            mask: b, 1, h, w
            stage: 1st or 2nd stage
            '''
            # padding mask used in decoder
            pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
            # vis: C3 / C4 / C5 / b, 512, 28, 28 / b, 1024, 14, 14 / b, 1024, 7, 7
            # word: b, length, 512
            # text: b, 1024
            # image: b, 1024
            vis, image = self.backbone.encode_image(img)
            word, text = self.backbone.encode_text(txt)
            x = self.ADP(image)
            x = self.ratio * x + (1 - self.ratio) * image
            # b, 1024
            # fq_t = self.FPN(vis, x)
            # fv_t = self.gap(fq_t)
            loss1 = self.IT_loss(x, text)
            loss = loss1
            ft = text
            fi = x
            fv = None
        elif stage == '2nd':
            '''
            img: b, 3, h, w
            word: b, words
            word_mask: b, words
            mask: b, 1, h, w
            stage: 1st or 2nd stage
            '''
            # padding mask used in decoder
            pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
            # vis: C3 / C4 / C5 / b, 512, 28, 28 / b, 1024, 14, 14 / b, 1024, 7, 7
            # word: b, length, 512
            # text: b, 1024
            # image: b, 1024
            vis, image = self.backbone.encode_image(img)
            word, text = self.backbone.encode_text(txt)
            x = self.ADP(image)
            x = self.ratio * x + (1 - self.ratio) * image
            # Construct multi-scale feats
            vis_trans = []
            for i in range(len(self.ms_adaptor)):
                x_ = rearrange(
                    vis[i],
                    "b (h w) c -> b c h w",
                    h=self.patch_emb,
                    w=self.patch_emb,
                ).contiguous()
                feats = self.ms_adaptor[i](x_)
                vis_trans.append(feats)
            # fq = self.FPN(vis, x_t)
            fv_t = self.FPN(vis_trans[1:], x, False)
            # fv_t = self.gap(fq_t)
            # b, 1024
            loss2 = self.IT_loss(fv_t, text)
            loss = loss2
            fv = fv_t
            ft = text
            fi = x
        return loss, fv, fi, ft

    def visualize(self, img):
        vis, image = self.backbone.encode_image(img)
        x = self.ADP(image)
        x = self.ratio * x + (1 - self.ratio) * image
        # Construct multi-scale feats
        vis_trans = []
        for i in range(len(self.ms_adaptor)):
            x_ = rearrange(
                vis[i],
                "b (h w) c -> b c h w",
                h=self.patch_emb,
                w=self.patch_emb,
            ).contiguous()
            feats = self.ms_adaptor[i](x_)
            vis_trans.append(feats)
        fv_t = self.FPN(vis_trans[1:], x, True)
        return vis, fv_t


# NOTE: this second definition of CISEN_vit shadows the one above; the only
# difference is that ViTFPN receives cfg instead of image_resolution.
class CISEN_vit(nn.Module, PyTorchModelHubMixin):
    def __init__(self, cfg):
        super().__init__()
        # Vision & Text Encoder & Label Encoder
        clip_model = torch.jit.load(cfg.clip_pretrain, map_location="cpu").eval()
        backbone, image_resolution, vision_heads, embed_dim, vision_width, patch_size = build_model(clip_model.state_dict(), cfg.word_len)
        self.backbone = backbone.float()
        self.patch_emb = image_resolution // patch_size
        cfg.image_resolution = image_resolution
        cfg.input_size = image_resolution
        cfg.heads = vision_heads // 32
        cfg.emb_dim = vision_width
        cfg.output_dim = embed_dim
        # multi-scale adapter
        # Multi-Modal FPN
        self.FPN = ViTFPN(cfg, in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        # Fine-grained Fusion
        # self.FGFusion = TransformerDecoder(num_layers=cfg.num_layers, d_model=cfg.vis_dim,
        #                                    nhead=cfg.num_head, dim_ffn=cfg.dim_ffn,
        #                                    dropout=cfg.dropout, return_intermediate=cfg.intermediate)
        # image-text transformer
        # self.trans = nn.Linear(1024, 1024)
        self.ADP = Adapter(cfg.output_dim, 4)
        # parameter
        self.ratio = cfg.ratio
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.share_temperature = True
        self.ce = nn.CrossEntropyLoss()
        self.ms_adaptor = nn.ModuleList(
            [
                nn.Sequential(
                    nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2),
                    nn.GroupNorm(32, cfg.emb_dim),
                    nn.GELU(),
                    nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2),
                ),
                nn.Sequential(
                    nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2),
                ),
                nn.Sequential(
                    nn.Identity(),
                ),
                nn.Sequential(
                    nn.MaxPool2d(2),
                ),
            ]
        )
        self.ms_adaptor.apply(self.init_adaptor)

    def init_adaptor(self, m):
        if isinstance(m, nn.Conv2d):
            lecun_normal_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.GroupNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.ConvTranspose2d):
            lecun_normal_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    # self.fc = nn.Linear(512, cfg.num_classes)

    def IT_loss(self, image_features, text_features):
        # b, 1024 / b, 1024
        batch = image_features.shape[0]
        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()
        # shape = [global_batch_size, global_batch_size]
        contrastive_labels = torch.arange(batch).to(logits_per_image.device)
        contrastive_loss = (self.ce(logits_per_image, contrastive_labels) +
                            self.ce(logits_per_text, contrastive_labels)) * 0.5
        return contrastive_loss

    def forward(self, img, txt, stage):
        if stage == '1st':
            '''
            img: b, 3, h, w
            word: b, words
            word_mask: b, words
            mask: b, 1, h, w
            stage: 1st or 2nd stage
            '''
            # padding mask used in decoder
            pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
            # vis: C3 / C4 / C5 / b, 512, 28, 28 / b, 1024, 14, 14 / b, 1024, 7, 7
            # word: b, length, 512
            # text: b, 1024
            # image: b, 1024
            vis, image = self.backbone.encode_image(img)
            word, text = self.backbone.encode_text(txt)
            x = self.ADP(image)
            x = self.ratio * x + (1 - self.ratio) * image
            # b, 1024
            # fq_t = self.FPN(vis, x)
            # fv_t = self.gap(fq_t)
            loss1 = self.IT_loss(x, text)
            loss = loss1
            ft = text
            fi = x
            fv = None
        elif stage == '2nd':
            '''
            img: b, 3, h, w
            word: b, words
            word_mask: b, words
            mask: b, 1, h, w
            stage: 1st or 2nd stage
            '''
            # padding mask used in decoder
            pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
            # vis: C3 / C4 / C5 / b, 512, 28, 28 / b, 1024, 14, 14 / b, 1024, 7, 7
            # word: b, length, 512
            # text: b, 1024
            # image: b, 1024
            vis, image = self.backbone.encode_image(img)
            word, text = self.backbone.encode_text(txt)
            x = self.ADP(image)
            x = self.ratio * x + (1 - self.ratio) * image
            # Construct multi-scale feats
            vis_trans = []
            for i in range(len(self.ms_adaptor)):
                x_ = rearrange(
                    vis[i],
                    "b (h w) c -> b c h w",
                    h=self.patch_emb,
                    w=self.patch_emb,
                ).contiguous()
                feats = self.ms_adaptor[i](x_)
                vis_trans.append(feats)
            # fq = self.FPN(vis, x_t)
            fv_t = self.FPN(vis_trans[1:], x, False)
            # fv_t = self.gap(fq_t)
            # b, 1024
            loss2 = self.IT_loss(fv_t, text)
            loss = loss2
            fv = fv_t
            ft = text
            fi = x
        return loss, fv, fi, ft

    def visualize(self, img, txt):
        vis, image = self.backbone.encode_image(img)
        word, text = self.backbone.encode_text(txt)
        x = self.ADP(image)
        x = self.ratio * x + (1 - self.ratio) * image
        # Construct multi-scale feats
        vis_trans = []
        for i in range(len(self.ms_adaptor)):
            x_ = rearrange(
                vis[i],
                "b (h w) c -> b c h w",
                h=self.patch_emb,
                w=self.patch_emb,
            ).contiguous()
            feats = self.ms_adaptor[i](x_)
            vis_trans.append(feats)
        # fq = self.FPN(vis, x_t)
        fv_t = self.FPN(vis_trans[1:], x, True)
        ft_t = self.FPN(vis_trans[1:], text, True)
        return vis, fv_t, ft_t
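
# --- Illustrative sketch (not part of the original training code) ----------
# CISEN_rsvit above exposes image_encode/text_encode, so at inference time it
# can be used as a plain dual encoder for image-text retrieval. Only the two
# method names come from the class; the pre-tokenized caption tensor and the
# cosine ranking are assumptions about how it would typically be used.
@torch.no_grad()
def _demo_retrieval(model: "CISEN_rsvit", images: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor:
    """images: (B, 3, H, W); token_ids: (N, word_len) pre-tokenized captions."""
    model.eval()
    img_feat = F.normalize(model.image_encode(images), dim=-1)    # (B, D)
    txt_feat = F.normalize(model.text_encode(token_ids), dim=-1)  # (N, D)
    # entry (i, j) scores image i against caption j; argmax per row gives the best caption
    return img_feat @ txt_feat.t()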


class CISEN_rsvit_classification(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # Vision & Text Encoder & Label Encoder
        clip_model = torch.load(cfg.clip_pretrain, map_location="cpu")
        backbone, image_resolution, vision_heads, embed_dim, vision_width, patch_size = build_model(clip_model, cfg.word_len)
        self.backbone = backbone.float()
        self.patch_emb = image_resolution // patch_size
        num_classes_fc = 512
        num_classes_output = 10
        self.num_classes_fc = num_classes_fc          # hidden width of the fully connected layer
        self.num_classes_output = num_classes_output  # number of output classes
        # Add a fully connected layer
        self.fc = nn.Linear(in_features=cfg.vis_dim, out_features=num_classes_fc)
        # Add an output layer for multi-label classification
        self.output_layer = nn.Linear(in_features=num_classes_fc, out_features=num_classes_output)
        self.criterion = nn.BCEWithLogitsLoss()
        cfg.image_resolution = image_resolution
        cfg.input_size = image_resolution
        cfg.heads = vision_heads // 32
        cfg.emb_dim = vision_width
        cfg.output_dim = embed_dim

    def IT_loss(self, labels, labels_pre):
        labels = labels.squeeze(1)
        loss = self.criterion(labels_pre, labels)
        return loss

    def forward(self, img, labels):
        _, image_features = self.backbone.encode_image(img)
        # Fully connected layer
        fc_output = self.fc(image_features)
        # Apply ReLU activation function
        fc_output = F.relu(fc_output)
        # Output layer for multi-label classification
        labels_pre = self.output_layer(fc_output)
        loss2 = self.IT_loss(labels, labels_pre)
        return labels_pre, loss2
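
# --- Illustrative sketch (not part of the original training code) ----------
# CISEN_rsvit_classification returns raw logits plus a BCEWithLogits loss, so
# multi-label predictions are obtained by applying a sigmoid and thresholding.
# The 0.5 threshold and the dummy zero labels are assumptions, not fixed by
# the class above.
@torch.no_grad()
def _demo_multilabel_predict(model: "CISEN_rsvit_classification", images: torch.Tensor,
                             threshold: float = 0.5) -> torch.Tensor:
    model.eval()
    dummy_labels = torch.zeros(images.shape[0], 1, model.num_classes_output)
    logits, _ = model(images, dummy_labels)        # (B, num_classes_output) logits
    return (logits.sigmoid() > threshold).long()   # (B, num_classes_output) binary predictions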


class CISEN_new(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # Vision & Text Encoder & Label Encoder
        clip_model = torch.jit.load(cfg.clip_pretrain, map_location="cpu").eval()
        backbone, image_resolution, vision_heads, embed_dim, vision_width, _ = build_model(clip_model.state_dict(), cfg.word_len)
        self.backbone = backbone.float()
        cfg.input_size = image_resolution
        cfg.heads = vision_heads
        cfg.emb_dim = vision_width * 32
        cfg.output_dim = embed_dim
        # Multi-Modal FPN
        self.FPN = FPN(cfg, in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        # Fine-grained Fusion
        # self.FGFusion = TransformerDecoder(num_layers=cfg.num_layers, d_model=cfg.vis_dim,
        #                                    nhead=cfg.num_head, dim_ffn=cfg.dim_ffn,
        #                                    dropout=cfg.dropout, return_intermediate=cfg.intermediate)
        # image-text transformer
        # self.trans = nn.Linear(1024, 1024)
        self.ADP = Adapter(cfg.output_dim, 4)
        self.gap = GAP((1, 1))
        # parameter
        self.ratio = cfg.ratio
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.share_temperature = True
        self.margin = 1
        self.eps = 1e-3
        self.ce = nn.CrossEntropyLoss()
        # 1st stage
        self.lamda1 = cfg.lamda1
        self.lamda2 = cfg.lamda2
        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        # self.fc = nn.Linear(512, cfg.num_classes)

    def IT_loss(self, image_features, text_features):
        # b, 1024 / b, 1024
        batch = image_features.shape[0]
        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()
        # shape = [global_batch_size, global_batch_size]
        contrastive_labels = torch.arange(batch).to(logits_per_image.device)
        contrastive_loss = (self.ce(logits_per_image, contrastive_labels) +
                            self.ce(logits_per_text, contrastive_labels)) * 0.5
        return contrastive_loss

    def forward(self, img, txt, stage):
        if stage == '1st':
            '''
            img: b, 3, h, w
            word: b, words
            word_mask: b, words
            mask: b, 1, h, w
            stage: 1st or 2nd stage
            '''
            # padding mask used in decoder
            pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
            # vis: C3 / C4 / C5 / b, 512, 28, 28 / b, 1024, 14, 14 / b, 1024, 7, 7
            # word: b, length, 512
            # text: b, 1024
            # image: b, 1024
            vis, image = self.backbone.encode_image(img)
            word, text = self.backbone.encode_text(txt)
            x = self.ADP(image)
            x = self.ratio * x + (1 - self.ratio) * image
            # b, 1024
            # fq_t = self.FPN(vis, x)
            # fv_t = self.gap(fq_t)
            loss1 = self.IT_loss(x, text)
            loss = loss1
            ft = text
            fi = x
            fv = None
        elif stage == '2nd':
            '''
            img: b, 3, h, w
            word: b, words
            word_mask: b, words
            mask: b, 1, h, w
            stage: 1st or 2nd stage
            '''
            # padding mask used in decoder
            pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
            # vis: C3 / C4 / C5 / b, 512, 28, 28 / b, 1024, 14, 14 / b, 1024, 7, 7
            # word: b, length, 512
            # text: b, 1024
            # image: b, 1024
            vis, image = self.backbone.encode_image(img)
            word, text = self.backbone.encode_text(txt)
            x = self.ADP(image)
            x = self.ratio * x + (1 - self.ratio) * image
            # x_t = self.trans(x)
            # fq = self.FPN(vis, x_t)
            fq_t = self.FPN(vis, x)
            fv_t = self.gap(fq_t)
            # b, 1024
            loss2 = self.IT_loss(fv_t, text)
            loss = loss2
            fv = fv_t
            ft = text
            fi = x
        elif stage == '3rd':
            '''
            img: b, 3, h, w
            word: b, words
            word_mask: b, words
            mask: b, 1, h, w
            stage: 1st or 2nd stage
            '''
            # padding mask used in decoder
            pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
            # vis: C3 / C4 / C5 / b, 512, 28, 28 / b, 1024, 14, 14 / b, 1024, 7, 7
            # word: b, length, 512
            # text: b, 1024
            # image: b, 1024
            vis, image = self.backbone.encode_image(img)
            word, text = self.backbone.encode_text(txt)
            x = self.ADP(text)
            ratio = 0.2
            x = ratio * x + (1 - ratio) * text
            # x_t = self.trans(x)
            # fq = self.FPN(vis, x_t)
            # b, 1024
            loss1 = self.IT_loss(image, x)
            loss = loss1
            fv = None
            ft = x
            fi = image
        elif stage == '4th':
            '''
            img: b, 3, h, w
            word: b, words
            word_mask: b, words
            mask: b, 1, h, w
            stage: 1st or 2nd stage
            '''
            # padding mask used in decoder
            pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
            # vis: C3 / C4 / C5 / b, 512, 28, 28 / b, 1024, 14, 14 / b, 1024, 7, 7
            # word: b, length, 512
            # text: b, 1024
            # image: b, 1024
            vis, image = self.backbone.encode_image(img)
            word, text = self.backbone.encode_text(txt)
            # x = self.ADP(image)
            # ratio = 0.2
            # x = ratio * x + (1 - ratio) * text
            fq_t = self.FPN(vis, image)
            fv_t = self.gap(fq_t)
            ratio_1 = 0.2
            # b, 1024
            loss2 = self.IT_loss(fv_t, text)
            loss = loss2
            fv = fv_t
            fi = None
            ft = text
        elif stage == '5th':
            '''
            img: b, 3, h, w
            word: b, words
            word_mask: b, words
            mask: b, 1, h, w
            stage: 1st or 2nd stage
            '''
            # padding mask used in decoder
            pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
            # vis: C3 / C4 / C5 / b, 512, 28, 28 / b, 1024, 14, 14 / b, 1024, 7, 7
            # word: b, length, 512
            # text: b, 1024
            # image: b, 1024
            vis, image = self.backbone.encode_image(img)
            word, text = self.backbone.encode_text(txt)
            x = self.ADP(image)
            ratio = 0.2
            x = ratio * x + (1 - ratio) * image
            # NOTE: self.ADP_t is never defined in __init__, so the '5th' stage
            # will fail unless a text adapter is added.
            y = self.ADP_t(text)
            ratio_1 = 0.2
            y = ratio * y + (1 - ratio_1) * text
            fq_t = self.FPN(vis, image)
            fv_t = self.gap(fq_t)
            # b, 1024
            loss2 = self.IT_loss(fv_t, y)
            loss = loss2
            fv = fv_t
            fi = x
            ft = y
        return loss, fv, fi, ft
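
# --- Illustrative sketch (not part of the original training code) ----------
# Every CISEN variant above refines the frozen CLIP embedding with
# x = ratio * ADP(feat) + (1 - ratio) * feat, a residual blend between the
# adapted and the original feature. The Adapter imported from .layers is not
# shown in this file; a common bottleneck-style stand-in (reduction factor 4,
# matching Adapter(cfg.output_dim, 4)) could look like this. It is only an
# assumption about the adapter's internals, not the project's implementation.
class _BottleneckAdapterSketch(nn.Module):
    def __init__(self, dim: int, reduction: int = 4):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(dim, dim // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(dim // reduction, dim, bias=False),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # the caller blends this output back into the frozen feature
        return self.fc(x)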


class CISEN_lclip(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # Vision & Text Encoder & Label Encoder
        clip_model = torch.load(cfg.clip_pretrain, map_location="cpu")
        # print(type(clip_model))
        backbone, image_resolution, vision_heads, embed_dim, vision_width, _ = build_lclip_model(clip_model, load_from_clip=True)
        self.backbone = backbone.float()
        cfg.input_size = image_resolution
        cfg.heads = vision_heads // 32
        cfg.emb_dim = vision_width
        cfg.output_dim = embed_dim
        # Multi-Modal FPN
        self.FPN = FPN(cfg, in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        # Fine-grained Fusion
        # self.FGFusion = TransformerDecoder(num_layers=cfg.num_layers, d_model=cfg.vis_dim,
        #                                    nhead=cfg.num_head, dim_ffn=cfg.dim_ffn,
        #                                    dropout=cfg.dropout, return_intermediate=cfg.intermediate)
        # image-text transformer
        # self.trans = nn.Linear(1024, 1024)
        self.ADP = Adapter(cfg.output_dim, 4)
        self.gap = GAP((1, 1))
        # parameter
        self.ratio = cfg.ratio
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.share_temperature = True
        self.margin = 1
        self.eps = 1e-3
        self.ce = nn.CrossEntropyLoss()
        # 1st stage
        self.lamda1 = cfg.lamda1
        self.lamda2 = cfg.lamda2
        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        # self.fc = nn.Linear(512, cfg.num_classes)

    def IT_loss(self, image_features, text_features):
        # b, 1024 / b, 1024
        batch = image_features.shape[0]
        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()
        # shape = [global_batch_size, global_batch_size]
        contrastive_labels = torch.arange(batch).to(logits_per_image.device)
        contrastive_loss = (self.ce(logits_per_image, contrastive_labels) +
                            self.ce(logits_per_text, contrastive_labels)) * 0.5
        return contrastive_loss

    def forward(self, img, txt, stage):
        if stage == '1st':
            '''
            img: b, 3, h, w
            word: b, words
            word_mask: b, words
            mask: b, 1, h, w
            stage: 1st or 2nd stage
            '''
            # padding mask used in decoder
            pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
            # vis: C3 / C4 / C5 / b, 512, 28, 28 / b, 1024, 14, 14 / b, 1024, 7, 7
            # word: b, length, 512
            # text: b, 1024
            # image: b, 1024
            vis, image = self.backbone.encode_image(img)
            text = self.backbone.encode_text(txt)
            x = self.ADP(image)
            x = self.ratio * x + (1 - self.ratio) * image
            # b, 1024
            # fq_t = self.FPN(vis, x)
            # fv_t = self.gap(fq_t)
            loss1 = self.IT_loss(x, text)
            loss = loss1
            ft = text
            fi = x
            fv = None
        elif stage == '2nd':
            '''
            img: b, 3, h, w
            word: b, words
            word_mask: b, words
            mask: b, 1, h, w
            stage: 1st or 2nd stage
            '''
            # padding mask used in decoder
            pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
            # vis: C3 / C4 / C5 / b, 512, 28, 28 / b, 1024, 14, 14 / b, 1024, 7, 7
            # word: b, length, 512
            # text: b, 1024
            # image: b, 1024
            vis, image = self.backbone.encode_image(img)
            word, text = self.backbone.encode_text(txt)
            x = self.ADP(image)
            x = self.ratio * x + (1 - self.ratio) * image
            # x_t = self.trans(x)
            # fq = self.FPN(vis, x_t)
            fq_t = self.FPN(vis, x)
            fv_t = self.gap(fq_t)
            # b, 1024
            loss2 = self.IT_loss(fv_t, text)
            loss = loss2
            fv = fv_t
            ft = text
            fi = x
        elif stage == '3rd':
            '''
            img: b, 3, h, w
            word: b, words
            word_mask: b, words
            mask: b, 1, h, w
            stage: 1st or 2nd stage
            '''
            # padding mask used in decoder
            pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
            # vis: C3 / C4 / C5 / b, 512, 28, 28 / b, 1024, 14, 14 / b, 1024, 7, 7
            # word: b, length, 512
            # text: b, 1024
            # image: b, 1024
            vis, image = self.backbone.encode_image(img)
            text = self.backbone.encode_text(txt)
            x = self.ADP(text)
            ratio = 0.2
            x = ratio * x + (1 - ratio) * text
            # x_t = self.trans(x)
            # fq = self.FPN(vis, x_t)
            # b, 1024
            loss1 = self.IT_loss(image, x)
            loss = loss1
            fv = None
            ft = x
            fi = image
        elif stage == '4th':
            '''
            img: b, 3, h, w
            word: b, words
            word_mask: b, words
            mask: b, 1, h, w
            stage: 1st or 2nd stage
            '''
            # padding mask used in decoder
            pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
            # vis: C3 / C4 / C5 / b, 512, 28, 28 / b, 1024, 14, 14 / b, 1024, 7, 7
            # word: b, length, 512
            # text: b, 1024
            # image: b, 1024
            vis, image = self.backbone.encode_image(img)
            word, text = self.backbone.encode_text(txt)
            # x = self.ADP(image)
            # ratio = 0.2
            # x = ratio * x + (1 - ratio) * text
            fq_t = self.FPN(vis, image)
            fv_t = self.gap(fq_t)
            ratio_1 = 0.2
            # b, 1024
            loss2 = self.IT_loss(fv_t, text)
            loss = loss2
            fv = fv_t
            fi = None
            ft = text
        elif stage == '5th':
            '''
            img: b, 3, h, w
            word: b, words
            word_mask: b, words
            mask: b, 1, h, w
            stage: 1st or 2nd stage
            '''
            # padding mask used in decoder
            pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
            # vis: C3 / C4 / C5 / b, 512, 28, 28 / b, 1024, 14, 14 / b, 1024, 7, 7
            # word: b, length, 512
            # text: b, 1024
            # image: b, 1024
            vis, image = self.backbone.encode_image(img)
            word, text = self.backbone.encode_text(txt)
            x = self.ADP(image)
            ratio = 0.2
            x = ratio * x + (1 - ratio) * image
            # NOTE: self.ADP_t is never defined in __init__, so the '5th' stage
            # will fail unless a text adapter is added.
            y = self.ADP_t(text)
            ratio_1 = 0.2
            y = ratio * y + (1 - ratio_1) * text
            fq_t = self.FPN(vis, image)
            fv_t = self.gap(fq_t)
            # b, 1024
            loss2 = self.IT_loss(fv_t, y)
            loss = loss2
            fv = fv_t
            fi = x
            ft = y
        return loss, fv, fi, ft


class GeoRSCLIP(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # Vision & Text Encoder & Label Encoder
        clip_model = torch.load(cfg.clip_pretrain, map_location="cpu")
        backbone, image_resolution, vision_heads, embed_dim, vision_width, patch_size = build_model(clip_model, cfg.word_len)
        self.backbone = backbone.float()

    def forward(self, img, txt, stage):
        pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
        # vis: C3 / C4 / C5 / b, 512, 28, 28 / b, 1024, 14, 14 / b, 1024, 7, 7
        # word: b, length, 512
        # text: b, 1024
        # image: b, 1024
        vis, image = self.backbone.encode_image(img)
        word, text = self.backbone.encode_text(txt)
        loss = None
        ft = text
        fi = image
        fv = None
        return loss, fv, fi, ft
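
# --- Illustrative sketch (not part of the original training code) ----------
# GeoRSCLIP above is a thin wrapper that only runs the frozen CLIP backbone
# and returns (None, None, image_feature, text_feature). A typical use is
# zero-shot matching of images against a fixed set of tokenized prompts; the
# prompt tensor is assumed to be pre-tokenized to cfg.word_len.
@torch.no_grad()
def _demo_georsclip_similarity(model: "GeoRSCLIP", images: torch.Tensor, prompt_ids: torch.Tensor) -> torch.Tensor:
    model.eval()
    _, _, fi, ft = model(images, prompt_ids, stage=None)
    fi = F.normalize(fi, dim=-1)
    ft = F.normalize(ft, dim=-1)
    return fi @ ft.t()  # (num_images, num_prompts) cosine similarities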


class CISEN(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # Vision & Text Encoder & Label Encoder
        clip_model = torch.jit.load(cfg.clip_pretrain, map_location="cpu").eval()
        # NOTE: build_model is unpacked into six values elsewhere in this file;
        # verify its return signature before using this class.
        self.backbone = build_model(clip_model.state_dict(), cfg.word_len).float()
        # Multi-Modal FPN
        self.FPN = FPN(cfg, in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        # Fine-grained Fusion
        self.FGFusion = TransformerDecoder(num_layers=cfg.num_layers,
                                           d_model=cfg.vis_dim,
                                           nhead=cfg.num_head,
                                           dim_ffn=cfg.dim_ffn,
                                           dropout=cfg.dropout,
                                           return_intermediate=cfg.intermediate)
        # adaptive aggregation
        self.ASFF = AdaptiveSpatialFeatureFusion(cfg, in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        # text projector
        self.projT = Text_Projector(cfg, in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        # image projector
        # self.projI = Image_Projector(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        # parameter
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.multi_label_logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.share_temperature = True
        self.margin = 1
        self.eps = 1e-3
        self.ce = nn.CrossEntropyLoss()
        # 1st stage
        self.lamda1 = cfg.lamda1
        self.lamda2 = cfg.lamda2
        self.beta1 = cfg.beta1
        self.beta2 = cfg.beta2
        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        # self.fc = nn.Linear(512, cfg.num_classes)
        # 2nd stage
        self.pos_samples = cfg.pos_samples
        self.neg_samples = cfg.neg_samples

    def IT_loss(self, image_features, text_features):
        # b, 1024 / b, 1024
        batch = image_features.shape[0]
        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()
        # shape = [global_batch_size, global_batch_size]
        contrastive_labels = torch.arange(batch).to(logits_per_image.device)
        contrastive_loss = (self.ce(logits_per_image, contrastive_labels) +
                            self.ce(logits_per_text, contrastive_labels)) * 0.5
        return contrastive_loss

    def IET_loss(self, image_features, text_features, pos_samples, beta):
        # b, 1024 / b, 1024
        # normalized features
        image_features = [image_feature / image_feature.norm(dim=-1, keepdim=True) for image_feature in image_features]
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        # logits_per_image = [logit_scale * image_feature @ text_features.t() for image_feature in image_features]
        logits_per_image = [logit_scale * torch.sum(torch.mul(image_feature, text_features), 1) for image_feature in image_features]
        logits_per_image = torch.stack(logits_per_image).t()
        b = logits_per_image.shape[0]
        loss1 = torch.norm(text_features - image_features[0])
        positive_tagsT = torch.zeros(b, len(image_features)).to(text_features.device)
        negative_tagsT = torch.zeros(b, len(image_features)).to(text_features.device)
        positive_tagsT[:, 0:pos_samples + 1] = 1
        negative_tagsT[:, pos_samples + 1:-1] = 1
        maskT = positive_tagsT.unsqueeze(1) * negative_tagsT.unsqueeze(-1)
        pos_score_matT = logits_per_image * positive_tagsT
        neg_score_matT = logits_per_image * negative_tagsT
        IW_pos3T = pos_score_matT.unsqueeze(1)
        IW_neg3T = neg_score_matT.unsqueeze(-1)
        OT = 1 + IW_neg3T - IW_pos3T
        O_maskT = maskT * OT
        diffT = torch.clamp(O_maskT, 0)
        violationT = torch.sign(diffT).sum(1).sum(1)
        diffT = diffT.sum(1).sum(1)
        lossT = torch.mean(diffT / (violationT + self.eps))
        loss = beta * loss1 + lossT
        return loss

    def test_IET_loss(self, image_features, text_features, pos_samples, beta1, beta2):
        # text_features: enhanced_features
        # b, 1024 / b, 1024
        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        image_features = image_features.unsqueeze(1)
        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        # image_features = image_features.expand(-1, text_features.shape[1], -1)
        logits_per_image = logit_scale * torch.matmul(image_features, text_features.transpose(1, 2))
        logits_per_image = logits_per_image.squeeze(1)
        # logits_per_image = logit_scale * image_features @ text_features.t()
        # logits_per_image = [logit_scale * image_feature @ text_features.t() for image_feature in image_features]
        b = logits_per_image.shape[0]
        # loss1 = torch.norm(text_features[:, 0, :] - image_features.squeeze(1))
        positive_tagsT = torch.zeros(b, text_features.shape[1]).to(text_features.device)
        negative_tagsT = torch.zeros(b, text_features.shape[1]).to(text_features.device)
        positive_tagsT[:, 0:pos_samples + 1] = 1
        negative_tagsT[:, pos_samples + 1:-1] = 1
        maskT = positive_tagsT.unsqueeze(1) * negative_tagsT.unsqueeze(-1)
        pos_score_matT = logits_per_image * positive_tagsT
        neg_score_matT = logits_per_image * negative_tagsT
        IW_pos3T = pos_score_matT.unsqueeze(1)
        IW_neg3T = neg_score_matT.unsqueeze(-1)
        OT = 1 + IW_neg3T - IW_pos3T
        O_maskT = maskT * OT
        diffT = torch.clamp(O_maskT, 0)
        violationT = torch.sign(diffT).sum(1).sum(1)
        diffT = diffT.sum(1).sum(1)
        lossT = torch.mean(diffT / (violationT + self.eps))
        # loss = beta1 * loss1 + beta2 * lossT
        loss = lossT
        return loss

    def test_IT_loss(self, image_features, text_features):
        # b, 1024 / b, 1024
        batch = image_features.shape[0]
        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        image_features = image_features.unsqueeze(1)
        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * torch.matmul(image_features, text_features.transpose(1, 2))
        logits_per_image = logits_per_image.squeeze(1)
        # shape = [global_batch_size, global_batch_size]
        contrastive_labels = torch.arange(batch).to(logits_per_image.device)
        contrastive_loss = self.ce(logits_per_image, contrastive_labels)
        return contrastive_loss

    def test_forward(self, img, txt):
        '''
        img: b, 3, h, w
        word: b, words
        word_mask: b, words
        mask: b, 1, h, w
        stage: 1st or 2nd stage
        '''
        # padding mask used in decoder
        pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
        # vis: C3 / C4 / C5 / b, 512, 28, 28 / b, 1024, 14, 14 / b, 1024, 7, 7
        # word: b, length, 512
        # state: b, 1024
        # image: b, 512
        vis, image = self.backbone.encode_image(img)
        word, text = self.backbone.encode_text(txt)
        fq = self.FPN(vis, text)
        b, c, h, w = fq.size()
        # b, 512, 14, 14
        ff = self.FGFusion(fq, word, pad_mask)
        ff = ff.reshape(b, c, h, w)
        f2 = self.avg(ff)
        fi = image.unsqueeze(-1).unsqueeze(-1)
        fv = self.ASFF(fi, f2)
        fi = fi.squeeze(-1).squeeze(-1)
        # b, 1024
        ft = self.projT(text)
        loss1 = self.IT_loss(fi, ft)
        loss2 = self.IT_loss(fv, ft)
        loss = self.lamda1 * loss1 + self.lamda2 * loss2
        return loss, fv, ft, fi

    def forward(self, img, txt, stage):
        if stage == '1st':
            '''
            img: b, 3, h, w
            word: b, words
            word_mask: b, words
            mask: b, 1, h, w
            stage: 1st or 2nd stage
            '''
            # padding mask used in decoder
            pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
            # vis: C3 / C4 / C5 / b, 512, 28, 28 / b, 1024, 14, 14 / b, 1024, 7, 7
            # word: b, length, 512
            # state: b, 1024
            # image: b, 512
            vis, image = self.backbone.encode_image(img)
            word, text = self.backbone.encode_text(txt)
            fq = self.FPN(vis, text)
            b, c, h, w = fq.size()
            # b, 512, 14, 14
            ff = self.FGFusion(fq, word, pad_mask)
            ff = ff.reshape(b, c, h, w)
            f2 = self.avg(ff)
            fi = image.unsqueeze(-1).unsqueeze(-1)
            fv = self.ASFF(fi, f2)
            fi = fi.squeeze(-1).squeeze(-1)
            # b, 1024
            ft = self.projT(text)
            loss1 = self.IT_loss(fi, ft)
            loss2 = self.IT_loss(fv, ft)
            loss = self.lamda1 * loss1 + self.lamda2 * loss2
        elif stage == '2nd':
            """
            txt: b, num, words
            img: b, 3, h, w
            """
            # txt = b * num, word
            b, num, l = txt.shape[0], txt.shape[1], txt.shape[2]
            txt = txt.view(-1, txt.size(-1))
            # padding mask used in decoder
            pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
            b = img.shape[0]
            vis, image = self.backbone.encode_image(img)
            word, text = self.backbone.encode_text(txt)
            fq = self.FPN(vis, text)
            # b, 512, 14, 14 (C4)
            b, c, h, w = fq.size()
            # b, 512, 14, 14
            ff = self.FGFusion(fq, word, pad_mask)
            ff = ff.reshape(b, c, h, w)
            f2 = self.avg(ff)
            fi = image.unsqueeze(-1).unsqueeze(-1)
            fi_ = fi.repeat(int(f2.shape[0] / fi.shape[0]), 1, 1, 1)
            fv = self.ASFF(fi_, f2)
            fi = fi.squeeze(-1).squeeze(-1)
            # fi_ = fi_.squeeze(-1).squeeze(-1)
            # b, 1024
            ft = text.view(img.shape[0], int(text.shape[0] / img.shape[0]), -1)[:, 0, :]
            fv = fv.view(ft.shape[0], int(text.shape[0] / ft.shape[0]), fv.shape[1])
            loss = self.test_IET_loss(fi, fv, self.pos_samples, self.beta1, self.beta2)
        elif stage == 'test':
            """
            txt: b, num, words
            img: b, 3, h, w
            """
            txt = txt.permute(1, 0, 2)
            # txt = b * num, word
            # txt = txt.view(-1, txt.size(-1))
            # padding mask used in decoder
            pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
            # vis: C3 / C4 / C5 / b, 512, 28, 28 / b, 1024, 14, 14 / b, 1024, 7, 7
            # word: b, length, 512
            # state: b, 1024
            # image: b, 512
            b = img.shape[0]
            words = []
            texts = []
            vis, image = self.backbone.encode_image(img)
            for i in range(txt.shape[0]):
                word, text = self.backbone.encode_text(txt[i])
                words.append(word)
                texts.append(text)
            fvn = []
            # b, 512, 14, 14 (C4)
            for i in range(txt.shape[0]):
                fq = self.FPN(vis, texts[i])
                b, c, h, w = fq.size()
                # b, 512, 14, 14
                ff = self.FGFusion(fq, words[i], pad_mask[i, :, :])
                ff = ff.reshape(b, c, h, w)
                f2 = self.avg(ff)
                fi = image.unsqueeze(-1).unsqueeze(-1)
                fv = self.ASFF(fi, f2)
                fi = fi.squeeze(-1).squeeze(-1)
                fvn.append(fv)
            # b, 1024
            ft = self.projT(texts[0])
            # NOTE: self.beta is not defined in __init__ (only beta1/beta2), so
            # this branch will fail as written.
            loss = self.IET_loss(fvn, ft, self.pos_samples, self.beta)
            fv = fvn
        else:
            print('stage should be either 1st or 2nd or test')
        # labels = torch.ones(image.shape[0], image.shape[0]).to(image.device)
        # labels[:,-1] = 0
        # labels[3, :] = 0
        # out = self.avg(fq)
        # out = out.squeeze(-1).squeeze(-1)
        # out = self.fc(out)
        return loss, fv, fi, ft
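
# --- Illustrative sketch (not part of the original training code) ----------
# The IET-style losses in CISEN above are pairwise hinge ranking losses: for
# every (positive, negative) caption pair of an image they penalise
# max(0, margin + s_neg - s_pos), then average over the violated pairs. The
# helper below spells that out for a single image with hand-picked scores;
# the values are arbitrary, while the margin of 1 and eps match the code above.
def _demo_pairwise_ranking_loss() -> torch.Tensor:
    scores = torch.tensor([0.9, 0.2, 0.6])  # caption 0 is the positive, 1-2 are negatives
    pos, neg = scores[:1], scores[1:]
    margin = 1.0
    hinge = torch.clamp(margin + neg.unsqueeze(0) - pos.unsqueeze(1), min=0)  # (1 pos, 2 neg)
    violations = torch.sign(hinge).sum()
    return hinge.sum() / (violations + 1e-3)  # here: (0.3 + 0.7) / 2.001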


class CRIS(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # Vision & Text Encoder & Label Encoder
        clip_model = torch.jit.load(cfg.clip_pretrain, map_location="cpu").eval()
        # NOTE: build_model is unpacked into six values elsewhere in this file;
        # this five-way unpacking may need updating to match.
        self.backbone, _, _, _, _ = build_model(clip_model.state_dict(), cfg.word_len)
        self.backbone = self.backbone.float()
        self.Label_encoder = build_promptlearner(clip_model.state_dict()).float()
        self.Label_encoder.init_label_emb(cfg.label_path)
        # Multi-Modal FPN
        self.FPN = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        # Fine-grained Fusion
        self.FGFusion = TransformerDecoder(num_layers=cfg.num_layers,
                                           d_model=cfg.vis_dim,
                                           nhead=cfg.num_head,
                                           dim_ffn=cfg.dim_ffn,
                                           dropout=cfg.dropout,
                                           return_intermediate=cfg.intermediate)
        # adaptive aggregation
        self.ASFF = AdaptiveSpatialFeatureFusion(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        # text projector
        self.projT = Text_Projector(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        # parameter
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.multi_label_logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.share_temperature = True
        self.margin = 1
        self.eps = 1e-3
        self.ce = nn.CrossEntropyLoss()
        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, cfg.num_classes)

    def IT_loss(self, image_features, text_features):
        # b, 1024 / b, 1024
        batch = image_features.shape[0]
        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()
        # shape = [global_batch_size, global_batch_size]
        contrastive_labels = torch.arange(batch).to(logits_per_image.device)
        contrastive_loss = (self.ce(logits_per_image, contrastive_labels) +
                            self.ce(logits_per_text, contrastive_labels)) * 0.5
        return contrastive_loss

    def IL_loss(self, image_features, label_features, labels):
        # b, 1024 / K, 1024 / b, K
        positive_tagsT = torch.clamp(labels, 0., 1.)
        negative_tagsT = torch.clamp(-labels, 0., 1.)
        maskT = positive_tagsT.unsqueeze(1) * negative_tagsT.unsqueeze(-1)
        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        label_features = label_features / label_features.norm(dim=-1, keepdim=True)
        # cosine similarity as logits
        logit_scale = self.multi_label_logit_scale.exp()
        logits_per_image = logit_scale * image_features @ label_features.t()
        # logits_per_label = logit_scale * label_features @ image_features.t()
        pos_score_matT = logits_per_image * positive_tagsT
        neg_score_matT = logits_per_image * negative_tagsT
        IW_pos3T = pos_score_matT.unsqueeze(1)
        IW_neg3T = neg_score_matT.unsqueeze(-1)
        OT = self.margin + IW_neg3T - IW_pos3T
        O_maskT = maskT * OT
        diffT = torch.clamp(O_maskT, 0)
        violationT = torch.sign(diffT).sum(1).sum(1)
        diffT = diffT.sum(1).sum(1)
        lossT = torch.mean(diffT / (violationT + self.eps))
        return lossT

    def margin_loss(self, image_features, label_features, labels):
        # b, 1024 / K, 1024 / b, K
        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        label_features = label_features / label_features.norm(dim=-1, keepdim=True)
        # cosine similarity as logits
        logit_scale = self.multi_label_logit_scale.exp()
        logits_per_image = logit_scale * image_features @ label_features.t()
        # logits_per_label = logit_scale * label_features @ image_features.t()
        image_label_positive_pairs = logits_per_image * labels
        image_label_mean_positive = image_label_positive_pairs.sum() / labels.sum()
        image_label_negative_pairs = logits_per_image * (1 - labels)
        image_label_mean_negative = image_label_negative_pairs.sum() / (logits_per_image.numel() - labels.sum() + self.eps)
        contrastive_loss = torch.relu(self.margin - image_label_mean_positive + image_label_mean_negative)
        return contrastive_loss

    def forward(self, img, txt, target=None):
        '''
        img: b, 3, h, w
        word: b, words
        word_mask: b, words
        mask: b, 1, h, w
        '''
        # padding mask used in decoder
        pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
        # vis: C3 / C4 / C5 / b, 512, 28, 28 / b, 1024, 14, 14 / b, 1024, 7, 7
        # word: b, length, 512
        # state: b, 1024
        # image: b, 512
        vis, image = self.backbone.encode_image(img)
        word, text = self.backbone.encode_text(txt)
        fl = self.Label_encoder(image.device)
        # b, 512, 14, 14 (C4)
        fq = self.FPN(vis, text)
        b, c, h, w = fq.size()
        # b, 512, 14, 14
        ff = self.FGFusion(fq, word, pad_mask)
        # b, 512, 196
        ff = ff.reshape(b, c, h, w)
        f2 = self.avg(ff)
        # b, 1024
        f1 = image.unsqueeze(-1).unsqueeze(-1)
        fv = self.ASFF(f1, f2)
        # b, 1024
        ft = self.projT(text)
        # labels = torch.ones(image.shape[0], image.shape[0]).to(image.device)
        # labels[:,-1] = 0
        # labels[3, :] = 0
        loss1 = self.IT_loss(fv, ft)
        loss2 = self.IL_loss(fv, fl, target)
        loss = loss1 + loss2
        # out = self.avg(fq)
        # out = out.squeeze(-1).squeeze(-1)
        # out = self.fc(out)
        return loss, fv, ft, fl


class zh_clip(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # Vision & Text Encoder
        clip_model = torch.jit.load(cfg.clip_pretrain, map_location="cpu").eval()
        self.backbone = build_modified_model(clip_model.state_dict(), cfg.word_len).float()
        self.text_encoder = AutoModelForSequenceClassification.from_pretrained(cfg.chinese)
        self.text_lin = nn.Linear(512, 1024)
        # Multi-Modal FPN
        self.neck = ViTFPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        # Decoder
        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, cfg.num_classes)

    def forward(self, img, word):
        '''
        img: b, 3, h, w
        word: b, words
        word_mask: b, words
        mask: b, 1, h, w
        '''
        # padding mask used in decoder
        # vis: v1 / v2 / b, 49, 1024 / b, 196, 512
        # state: b, 1024
        # feat: f1 / f2 / b, 1024, 7, 7 / b, 1024, 7, 7
        # cls: c1 / c2 / b, 1024 / b, 512
        vis, feat, cls = self.backbone.encode_image(img)
        state = self.text_encoder(word.squeeze(1)).logits
        state = self.text_lin(state)
        # b, 1024, 7, 7 (C5)
        fq = self.neck(feat, state)
        out = self.avg(fq)
        out = out.squeeze(-1).squeeze(-1)
        out = self.fc(out)
        return out


class poi_clip(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # Vision & Text Encoder
        clip_model = torch.jit.load(cfg.clip_pretrain, map_location="cpu").eval()
        self.backbone = build_modified_model(clip_model.state_dict(), cfg.word_len).float()
        self.text_encoder = AutoModelForSequenceClassification.from_pretrained(cfg.chinese)
        self.text_lin = nn.Linear(512, 1024)
        # Multi-Modal FPN
        self.neck = ViTFPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        # Decoder
        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, cfg.num_classes)

    def forward(self, img, word):
        '''
        img: b, 3, h, w
        word: b, words
        word_mask: b, words
        mask: b, 1, h, w
        '''
        # padding mask used in decoder
        # vis: v1 / v2 / b, 49, 1024 / b, 196, 512
        # state: b, 1024
        # feat: f1 / f2 / b, 1024, 7, 7 / b, 1024, 7, 7
        # cls: c1 / c2 / b, 1024 / b, 512
        vis, feat, cls = self.backbone.encode_image(img)
        state = self.text_encoder(word.squeeze(1)).logits
        state = self.text_lin(state)
        # b, 1024, 7, 7 (C5)
        fq = self.neck(feat, state)
        out = self.avg(fq)
        out = out.squeeze(-1).squeeze(-1)
        out = self.fc(out)
        return out


class Clip_hash_model(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # Vision & Text Encoder
        clip_model = torch.jit.load(cfg.clip_pretrain, map_location="cpu").eval()
        self.backbone = build_model(clip_model.state_dict(), cfg.word_len).float()
        # Multi-Modal FPN
        self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        # Decoder
        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Linear(cfg.fpn_out[1], cfg.hash_dim, bias=True),
            nn.Tanh(),
        )
        self.classifier2 = nn.Sequential(
            nn.Linear(cfg.hash_dim, cfg.num_classes)
        )
        # Hash Module
        self.image_module = nn.Sequential(
            nn.Linear(cfg.img_dim, cfg.hidden_dim, bias=True),
            nn.BatchNorm1d(cfg.hidden_dim),
            nn.ReLU(True),
            nn.Linear(cfg.hidden_dim, cfg.hash_dim, bias=True),
            nn.Tanh()
        )
        self.text_module = nn.Sequential(
            nn.Linear(cfg.txt_dim, cfg.hidden_dim, bias=True),
            nn.BatchNorm1d(cfg.hidden_dim),
            nn.ReLU(True),
            nn.Linear(cfg.hidden_dim, cfg.hash_dim, bias=True),
            nn.Tanh()
        )

    def forward(self, img, word, mask=None):
        '''
        img: b, 3, h, w
        word: b, words
        word_mask: b, words
        '''
        pad_mask = torch.zeros_like(word).masked_fill_(word == 0, 1).bool()
        # vis: C3 / C4 / C5
        # word: b, length, 512
        # state: b, 1024
        vis, image = self.backbone.encode_image(img)
        word, state = self.backbone.encode_text(word)
        # b, 512, 26, 26 (C4)
        fq = self.neck(vis, state)
        # out_hash: b, code_length
        # res: b, classes
        out = self.avg(fq)
        out = out.squeeze(-1).squeeze(-1)
        out_hash = self.classifier(out)
        res = self.classifier2(out_hash)
        # img_hash: b, code_length
        # txt_hash: b, code_length
        img_hash = self.image_module(image)
        txt_hash = self.text_module(state)
        return img_hash, txt_hash, out_hash, res


class Clip_model(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # Vision & Text Encoder
        clip_model = torch.jit.load(cfg.clip_pretrain, map_location="cpu").eval()
        self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        self.backbone = build_model(clip_model.state_dict(), cfg.word_len).float()

    def forward(self, img, word, mask=None):
        '''
        img: b, 3, h, w
        word: b, words
        word_mask: b, words
        '''
        # vis: C3 / C4 / C5
        # word: b, length, 512
        # state: b, 1024
        pad_mask = torch.zeros_like(word).masked_fill_(word == 0, 1).bool()
        vis, image = self.backbone.encode_image(img)
        word, state = self.backbone.encode_text(word)
        f = self.neck(vis, state)
        out = self.avg(f)
        out = out.squeeze(-1).squeeze(-1)
        image_features = image / image.norm(dim=-1, keepdim=True)
        text_features = state / state.norm(dim=-1, keepdim=True)
        # cosine similarity as logits
        logit_scale = self.backbone.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()
        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text


class CISEN_rsvit_hug(nn.Module, PyTorchModelHubMixin):
    def __init__(self, embed_dim, image_resolution, vision_layers, vision_width, vision_patch_size,
                 context_length, txt_length, vocab_size, transformer_width, transformer_heads,
                 transformer_layers, patch_size, output_dim, ratio, emb_dim, fpn_in, fpn_out):
        super().__init__()
        # Vision & Text Encoder & Label Encoder
        vision_heads = vision_width * 32 // 64
        backbone = CLIP(embed_dim, image_resolution, vision_layers, vision_width, vision_patch_size,
                        context_length, txt_length, vocab_size, transformer_width, transformer_heads,
                        transformer_layers)
        self.backbone = backbone.float()
        self.patch_emb = image_resolution // patch_size
        self.FPN = ViTFPN(image_resolution, in_channels=fpn_in, out_channels=fpn_out)
        self.ADP = Adapter(output_dim, 4)
        # parameter
        self.ratio = ratio
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.share_temperature = True
        self.ce = nn.CrossEntropyLoss()
        self.ms_adaptor = nn.ModuleList(
            [
                nn.Sequential(
                    nn.ConvTranspose2d(emb_dim, emb_dim, 2, 2),
                    nn.GroupNorm(32, emb_dim),
                    nn.GELU(),
                    nn.ConvTranspose2d(emb_dim, emb_dim, 2, 2),
                ),
                nn.Sequential(
                    nn.ConvTranspose2d(emb_dim, emb_dim, 2, 2),
                ),
                nn.Sequential(
                    nn.Identity(),
                ),
                nn.Sequential(
                    nn.MaxPool2d(2),
                ),
            ]
        )
        self.ms_adaptor.apply(self.init_adaptor)

    def init_adaptor(self, m):
        if isinstance(m, nn.Conv2d):
            lecun_normal_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.GroupNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.ConvTranspose2d):
            lecun_normal_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    # self.fc = nn.Linear(512, cfg.num_classes)

    def image_encode(self, img):
        vis, image = self.backbone.encode_image(img)
        x = self.ADP(image)
        x = self.ratio * x + (1 - self.ratio) * image
        return x

    def text_encode(self, txt):
        word, text = self.backbone.encode_text(txt)
        return text

    def forward(self, img, txt):
        '''
        img: b, 3, h, w
        word: b, words
        word_mask: b, words
        mask: b, 1, h, w
        stage: 1st or 2nd stage
        '''
        # padding mask used in decoder
        pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
        # vis: C3 / C4 / C5 / b, 512, 28, 28 / b, 1024, 14, 14 / b, 1024, 7, 7
        # word: b, length, 512
        # text: b, 1024
        # image: b, 1024
        vis, image = self.backbone.encode_image(img)
        word, text = self.backbone.encode_text(txt)
        x = self.ADP(image)
        x = self.ratio * x + (1 - self.ratio) * image
        # Construct multi-scale feats
        vis_trans = []
        for i in range(len(self.ms_adaptor)):
            x_ = rearrange(
                vis[i],
                "b (h w) c -> b c h w",
                h=self.patch_emb,
                w=self.patch_emb,
            ).contiguous()
            feats = self.ms_adaptor[i](x_)
            vis_trans.append(feats)
        # fq = self.FPN(vis, x_t)
        fv_t = self.FPN(vis_trans[1:], x, False)
        # fv_t = self.gap(fq_t)
        # b, 1024
        fv = fv_t
        ft = text
        fi = x
        return fv, fi, ft
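
# --- Illustrative sketch (not part of the original training code) ----------
# Because CISEN_rsvit_hug mixes in PyTorchModelHubMixin, a trained instance
# can be serialised with save_pretrained() and re-created with
# from_pretrained(); both methods come from huggingface_hub. The local
# directory name below is a placeholder, not a real artifact.
def _demo_hub_roundtrip(model: "CISEN_rsvit_hug") -> "CISEN_rsvit_hug":
    model.save_pretrained("cisen_rsvit_hug_ckpt")  # writes config + weights locally
    return CISEN_rsvit_hug.from_pretrained("cisen_rsvit_hug_ckpt")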