|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import numpy as np
|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
|
from .clip import build_model, build_promptlearner, build_modified_model, PromptLearner, build_lclip_model
|
|
from torch.cuda.amp import autocast as autocast
|
|
from timm.models.layers import trunc_normal_ as __call_trunc_normal_
|
|
from timm.models.layers import variance_scaling_
|
|
from einops import rearrange, repeat
|
|
from loguru import logger
|
|
from transformers import AlignProcessor, AlignModel
|
|
from sklearn.metrics import classification_report
|
|
from huggingface_hub import PyTorchModelHubMixin
|
|
from .layers import FPN, TransformerDecoder, ViTFPN, AdaptiveSpatialFeatureFusion, Text_Projector, Image_Projector, Adapter, GAP
|
|
from cisen.model.clip import CLIP
|
|
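# Weight-initialization helpers built on timm: lecun_normal_ applies a fan-in,
# truncated-normal variance-scaling init (used for the conv/deconv adaptors below),
# and trunc_normal_ draws from a normal truncated to [-std, std].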
def lecun_normal_(tensor):
|
|
variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
|
|
|
|
def trunc_normal_(tensor, mean=0.0, std=1.0):
|
|
__call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)
|
|
|
|
class CISEN_vit(nn.Module, PyTorchModelHubMixin):
|
|
def __init__(self, cfg):
|
|
super().__init__()
|
|
|
|
clip_model = torch.jit.load(cfg.clip_pretrain,
|
|
map_location="cpu").eval()
|
|
|
|
backbone, image_resolution, vision_heads, embed_dim, vision_width, patch_size = build_model(clip_model.state_dict(), cfg.word_len)
|
|
self.backbone = backbone.float()
|
|
self.patch_emb = image_resolution // patch_size
|
|
cfg.image_resolution = image_resolution
|
|
cfg.input_size = image_resolution
|
|
cfg.heads = vision_heads // 32
|
|
cfg.emb_dim = vision_width
|
|
cfg.output_dim = embed_dim
|
|
|
|
|
|
|
|
self.FPN = ViTFPN(image_resolution, in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
self.ADP = Adapter(cfg.output_dim, 4)
|
|
|
|
self.ratio = cfg.ratio
|
|
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
|
self.share_temperature = True
|
|
self.ce = nn.CrossEntropyLoss()
|
|
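# Multi-scale adaptor: four parallel branches resample the (reshaped) ViT patch
# features to 4x, 2x, 1x and 0.5x of the patch-grid resolution, so the single-scale
# transformer backbone can feed the multi-scale ViTFPN neck.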
self.ms_adaptor = nn.ModuleList(
|
|
[
|
|
nn.Sequential(
|
|
nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2),
|
|
nn.GroupNorm(32, cfg.emb_dim),
|
|
nn.GELU(),
|
|
nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2),
|
|
),
|
|
nn.Sequential(
|
|
nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2),
|
|
),
|
|
nn.Sequential(
|
|
nn.Identity(),
|
|
),
|
|
nn.Sequential(
|
|
nn.MaxPool2d(2),
|
|
),
|
|
|
|
]
|
|
)
|
|
|
|
self.ms_adaptor.apply(self.init_adaptor)
|
|
def init_adaptor(self, m):
|
|
if isinstance(m, nn.Conv2d):
|
|
lecun_normal_(m.weight)
|
|
if m.bias is not None:
|
|
nn.init.constant_(m.bias, 0)
|
|
elif isinstance(m, nn.GroupNorm):
|
|
nn.init.constant_(m.bias, 0)
|
|
nn.init.constant_(m.weight, 1.0)
|
|
elif isinstance(m, nn.ConvTranspose2d):
|
|
lecun_normal_(m.weight)
|
|
if m.bias is not None:
|
|
nn.init.zeros_(m.bias)
|
|
|
|
|
|
|
|
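# Symmetric CLIP-style contrastive loss: both feature sets are L2-normalized, the
# cosine-similarity matrix is scaled by the learnable temperature (logit_scale), and
# cross-entropy over matching indices is averaged across the image->text and
# text->image directions.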
def IT_loss(self, image_features, text_features):
|
|
|
|
batch = image_features.shape[0]
|
|
|
|
image_features = image_features / image_features.norm(dim=-1,
|
|
keepdim=True)
|
|
text_features = text_features / text_features.norm(dim=-1,
|
|
keepdim=True)
|
|
|
|
|
|
logit_scale = self.logit_scale.exp()
|
|
logits_per_image = logit_scale * image_features @ text_features.t()
|
|
logits_per_text = logits_per_image.t()
|
|
|
|
|
|
contrastive_labels = torch.arange(batch).to(logits_per_image.device)
|
|
contrastive_loss = (self.ce(logits_per_image, contrastive_labels) + self.ce(logits_per_text, contrastive_labels)) * 0.5
|
|
|
|
|
|
return contrastive_loss
|
|
|
|
def forward(self, img, txt, stage):
|
|
|
|
if stage == '1st':
|
|
'''
|
|
img: b, 3, h, w
|
|
word: b, words
|
|
word_mask: b, words
|
|
mask: b, 1, h, w
|
|
stage: 1st or 2nd stage
|
|
'''
|
|
|
|
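# pad_mask marks zero-padded token positions; it is kept for interface parity with
# the decoder-based variants but is not used in this branch.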
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
vis, image = self.backbone.encode_image(img)
|
|
|
|
word, text = self.backbone.encode_text(txt)
|
|
|
|
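# Residual adapter: a lightweight bottleneck (Adapter) refines the CLIP image
# embedding, and cfg.ratio controls how much of the adapted feature is blended back
# into the original embedding.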
x = self.ADP(image)
|
|
|
|
x = self.ratio * x + (1-self.ratio) * image
loss1 = self.IT_loss(x, text)
|
|
|
|
loss = loss1
|
|
|
|
ft = text
|
|
fi = x
|
|
fv = None
|
|
elif stage == '2nd':
|
|
'''
|
|
img: b, 3, h, w
|
|
word: b, words
|
|
word_mask: b, words
|
|
mask: b, 1, h, w
|
|
stage: 1st or 2nd stage
|
|
'''
|
|
|
|
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
vis, image = self.backbone.encode_image(img)
|
|
|
|
word, text = self.backbone.encode_text(txt)
|
|
|
|
x = self.ADP(image)
|
|
|
|
x = self.ratio * x + (1 - self.ratio) * image
|
|
|
|
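# Rebuild 2D feature maps: each ViT stage's patch tokens (b, h*w, c) are reshaped to
# (b, c, h, w) and passed through the matching ms_adaptor branch to form the
# multi-scale pyramid consumed by the FPN.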
vis_trans = []
|
|
for i in range(len(self.ms_adaptor)):
|
|
x_ = rearrange(
|
|
vis[i],
|
|
"b (h w) c -> b c h w",
|
|
h=self.patch_emb,
|
|
w=self.patch_emb,
|
|
).contiguous()
|
|
|
|
feats = self.ms_adaptor[i](x_)
|
|
|
|
vis_trans.append(feats)
|
|
|
|
|
|
fv_t = self.FPN(vis_trans[1:], x, False)
loss2 = self.IT_loss(fv_t, text)
|
|
|
|
loss = loss2
|
|
fv = fv_t
|
|
ft = text
|
|
fi = x
|
|
|
|
|
|
return loss, fv, fi, ft
|
|
|
|
def visualize(self, img, txt):
|
|
vis, image = self.backbone.encode_image(img)
|
|
word, text = self.backbone.encode_text(txt)
|
|
|
|
x = self.ADP(image)
|
|
|
|
x = self.ratio * x + (1 - self.ratio) * image
|
|
|
|
vis_trans = []
|
|
for i in range(len(self.ms_adaptor)):
|
|
x_ = rearrange(
|
|
vis[i],
|
|
"b (h w) c -> b c h w",
|
|
h=self.patch_emb,
|
|
w=self.patch_emb,
|
|
).contiguous()
|
|
|
|
feats = self.ms_adaptor[i](x_)
|
|
|
|
vis_trans.append(feats)
|
|
|
|
|
|
fv_t = self.FPN(vis_trans[1:], x, True)
|
|
ft_t = self.FPN(vis_trans[1:], text, True)
|
|
return vis, fv_t, ft_t
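# Minimal usage sketch for CISEN_vit (illustrative only; assumes cfg provides at
# least clip_pretrain, word_len, fpn_in, fpn_out and ratio, and that txt holds
# CLIP-tokenized ids padded with zeros):
#
#   model = CISEN_vit(cfg)
#   img = torch.randn(2, 3, cfg.input_size, cfg.input_size)
#   txt = torch.randint(1, 49408, (2, cfg.word_len))
#   loss, fv, fi, ft = model(img, txt, stage='1st')   # adapter + contrastive stage
#   loss, fv, fi, ft = model(img, txt, stage='2nd')   # FPN-enhanced stage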
|
|
|
|
class CISEN_rsvit(nn.Module, PyTorchModelHubMixin):
|
|
def __init__(self, cfg):
|
|
super().__init__()
|
|
|
|
clip_model = torch.load(cfg.clip_pretrain,
|
|
map_location="cpu")
|
|
|
|
backbone, image_resolution, vision_heads, embed_dim, vision_width, patch_size = build_model(clip_model, cfg.word_len)
|
|
self.backbone = backbone.float()
|
|
self.patch_emb = image_resolution // patch_size
|
|
|
|
cfg.image_resolution = image_resolution
|
|
cfg.input_size = image_resolution
|
|
cfg.heads = vision_heads // 32
|
|
cfg.emb_dim = vision_width
|
|
cfg.output_dim = embed_dim
|
|
|
|
|
|
|
|
self.FPN = ViTFPN(image_resolution, in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
self.ADP = Adapter(cfg.output_dim, 4)
|
|
|
|
self.ratio = cfg.ratio
|
|
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
|
self.share_temperature = True
|
|
self.ce = nn.CrossEntropyLoss()
|
|
self.ms_adaptor = nn.ModuleList(
|
|
[
|
|
nn.Sequential(
|
|
nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2),
|
|
nn.GroupNorm(32, cfg.emb_dim),
|
|
nn.GELU(),
|
|
nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2),
|
|
),
|
|
nn.Sequential(
|
|
nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2),
|
|
),
|
|
nn.Sequential(
|
|
nn.Identity(),
|
|
),
|
|
nn.Sequential(
|
|
nn.MaxPool2d(2),
|
|
),
|
|
|
|
]
|
|
)
|
|
|
|
self.ms_adaptor.apply(self.init_adaptor)
|
|
def init_adaptor(self, m):
|
|
if isinstance(m, nn.Conv2d):
|
|
lecun_normal_(m.weight)
|
|
if m.bias is not None:
|
|
nn.init.constant_(m.bias, 0)
|
|
elif isinstance(m, nn.GroupNorm):
|
|
nn.init.constant_(m.bias, 0)
|
|
nn.init.constant_(m.weight, 1.0)
|
|
elif isinstance(m, nn.ConvTranspose2d):
|
|
lecun_normal_(m.weight)
|
|
if m.bias is not None:
|
|
nn.init.zeros_(m.bias)
|
|
|
|
|
|
|
|
def IT_loss(self, image_features, text_features):
|
|
|
|
batch = image_features.shape[0]
|
|
|
|
image_features = image_features / image_features.norm(dim=-1,
|
|
keepdim=True)
|
|
text_features = text_features / text_features.norm(dim=-1,
|
|
keepdim=True)
|
|
|
|
|
|
logit_scale = self.logit_scale.exp()
|
|
logits_per_image = logit_scale * image_features @ text_features.t()
|
|
logits_per_text = logits_per_image.t()
|
|
|
|
|
|
contrastive_labels = torch.arange(batch).to(logits_per_image.device)
|
|
contrastive_loss = (self.ce(logits_per_image, contrastive_labels) + self.ce(logits_per_text, contrastive_labels)) * 0.5
|
|
|
|
|
|
return contrastive_loss
|
|
def image_encode(self, img):
|
|
vis, image = self.backbone.encode_image(img)
|
|
|
|
x = self.ADP(image)
|
|
|
|
x = self.ratio * x + (1 - self.ratio) * image
|
|
return x
|
|
|
|
def text_encode(self, txt):
|
|
|
|
word, text = self.backbone.encode_text(txt)
|
|
|
|
return text
|
|
|
|
def forward(self, img, txt, stage):
|
|
|
|
if stage == '1st':
|
|
'''
|
|
img: b, 3, h, w
|
|
word: b, words
|
|
word_mask: b, words
|
|
mask: b, 1, h, w
|
|
stage: 1st or 2nd stage
|
|
'''
|
|
|
|
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
vis, image = self.backbone.encode_image(img)
|
|
|
|
word, text = self.backbone.encode_text(txt)
|
|
|
|
x = self.ADP(image)
|
|
|
|
x = self.ratio * x + (1-self.ratio) * image
loss1 = self.IT_loss(x, text)
|
|
|
|
loss = loss1
|
|
|
|
ft = text
|
|
fi = x
|
|
fv = None
|
|
elif stage == '2nd':
|
|
'''
|
|
img: b, 3, h, w
|
|
word: b, words
|
|
word_mask: b, words
|
|
mask: b, 1, h, w
|
|
stage: 1st or 2nd stage
|
|
'''
|
|
|
|
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
vis, image = self.backbone.encode_image(img)
|
|
|
|
word, text = self.backbone.encode_text(txt)
|
|
|
|
x = self.ADP(image)
|
|
|
|
x = self.ratio * x + (1 - self.ratio) * image
|
|
|
|
vis_trans = []
|
|
for i in range(len(self.ms_adaptor)):
|
|
x_ = rearrange(
|
|
vis[i],
|
|
"b (h w) c -> b c h w",
|
|
h=self.patch_emb,
|
|
w=self.patch_emb,
|
|
).contiguous()
|
|
|
|
feats = self.ms_adaptor[i](x_)
|
|
|
|
vis_trans.append(feats)
|
|
|
|
|
|
fv_t = self.FPN(vis_trans[1:], x, False)
loss2 = self.IT_loss(fv_t, text)
|
|
|
|
loss = loss2
|
|
fv = fv_t
|
|
ft = text
|
|
fi = x
|
|
|
|
|
|
return loss, fv, fi, ft
|
|
|
|
def visualize(self, img):
|
|
vis, image = self.backbone.encode_image(img)
|
|
|
|
|
|
x = self.ADP(image)
|
|
|
|
x = self.ratio * x + (1 - self.ratio) * image
|
|
|
|
vis_trans = []
|
|
for i in range(len(self.ms_adaptor)):
|
|
x_ = rearrange(
|
|
vis[i],
|
|
"b (h w) c -> b c h w",
|
|
h=self.patch_emb,
|
|
w=self.patch_emb,
|
|
).contiguous()
|
|
|
|
feats = self.ms_adaptor[i](x_)
|
|
|
|
vis_trans.append(feats)
|
|
|
|
|
|
fv_t = self.FPN(vis_trans[1:], x, True)
|
|
return vis, fv_t
|
|
|
|
class CISEN_vit(nn.Module, PyTorchModelHubMixin):
|
|
def __init__(self, cfg):
|
|
super().__init__()
|
|
|
|
clip_model = torch.jit.load(cfg.clip_pretrain,
|
|
map_location="cpu").eval()
|
|
|
|
backbone, image_resolution, vision_heads, embed_dim, vision_width, patch_size = build_model(clip_model.state_dict(), cfg.word_len)
|
|
self.backbone = backbone.float()
|
|
self.patch_emb = image_resolution // patch_size
|
|
cfg.image_resolution = image_resolution
|
|
cfg.input_size = image_resolution
|
|
cfg.heads = vision_heads // 32
|
|
cfg.emb_dim = vision_width
|
|
cfg.output_dim = embed_dim
|
|
|
|
|
|
|
|
self.FPN = ViTFPN(cfg, in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
self.ADP = Adapter(cfg.output_dim, 4)
|
|
|
|
self.ratio = cfg.ratio
|
|
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
|
self.share_temperature = True
|
|
self.ce = nn.CrossEntropyLoss()
|
|
self.ms_adaptor = nn.ModuleList(
|
|
[
|
|
nn.Sequential(
|
|
nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2),
|
|
nn.GroupNorm(32, cfg.emb_dim),
|
|
nn.GELU(),
|
|
nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2),
|
|
),
|
|
nn.Sequential(
|
|
nn.ConvTranspose2d(cfg.emb_dim, cfg.emb_dim, 2, 2),
|
|
),
|
|
nn.Sequential(
|
|
nn.Identity(),
|
|
),
|
|
nn.Sequential(
|
|
nn.MaxPool2d(2),
|
|
),
|
|
|
|
]
|
|
)
|
|
|
|
self.ms_adaptor.apply(self.init_adaptor)
|
|
def init_adaptor(self, m):
|
|
if isinstance(m, nn.Conv2d):
|
|
lecun_normal_(m.weight)
|
|
if m.bias is not None:
|
|
nn.init.constant_(m.bias, 0)
|
|
elif isinstance(m, nn.GroupNorm):
|
|
nn.init.constant_(m.bias, 0)
|
|
nn.init.constant_(m.weight, 1.0)
|
|
elif isinstance(m, nn.ConvTranspose2d):
|
|
lecun_normal_(m.weight)
|
|
if m.bias is not None:
|
|
nn.init.zeros_(m.bias)
|
|
|
|
|
|
|
|
def IT_loss(self, image_features, text_features):
|
|
|
|
batch = image_features.shape[0]
|
|
|
|
image_features = image_features / image_features.norm(dim=-1,
|
|
keepdim=True)
|
|
text_features = text_features / text_features.norm(dim=-1,
|
|
keepdim=True)
|
|
|
|
|
|
logit_scale = self.logit_scale.exp()
|
|
logits_per_image = logit_scale * image_features @ text_features.t()
|
|
logits_per_text = logits_per_image.t()
|
|
|
|
|
|
contrastive_labels = torch.arange(batch).to(logits_per_image.device)
|
|
contrastive_loss = (self.ce(logits_per_image, contrastive_labels) + self.ce(logits_per_text, contrastive_labels)) * 0.5
|
|
|
|
|
|
return contrastive_loss
|
|
|
|
def forward(self, img, txt, stage):
|
|
|
|
if stage == '1st':
|
|
'''
|
|
img: b, 3, h, w
|
|
word: b, words
|
|
word_mask: b, words
|
|
mask: b, 1, h, w
|
|
stage: 1st or 2nd stage
|
|
'''
|
|
|
|
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
vis, image = self.backbone.encode_image(img)
|
|
|
|
word, text = self.backbone.encode_text(txt)
|
|
|
|
x = self.ADP(image)
|
|
|
|
x = self.ratio * x + (1-self.ratio) * image
loss1 = self.IT_loss(x, text)
|
|
|
|
loss = loss1
|
|
|
|
ft = text
|
|
fi = x
|
|
fv = None
|
|
elif stage == '2nd':
|
|
'''
|
|
img: b, 3, h, w
|
|
word: b, words
|
|
word_mask: b, words
|
|
mask: b, 1, h, w
|
|
stage: 1st or 2nd stage
|
|
'''
|
|
|
|
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
vis, image = self.backbone.encode_image(img)
|
|
|
|
word, text = self.backbone.encode_text(txt)
|
|
|
|
x = self.ADP(image)
|
|
|
|
x = self.ratio * x + (1 - self.ratio) * image
|
|
|
|
vis_trans = []
|
|
for i in range(len(self.ms_adaptor)):
|
|
x_ = rearrange(
|
|
vis[i],
|
|
"b (h w) c -> b c h w",
|
|
h=self.patch_emb,
|
|
w=self.patch_emb,
|
|
).contiguous()
|
|
|
|
feats = self.ms_adaptor[i](x_)
|
|
|
|
vis_trans.append(feats)
|
|
|
|
|
|
fv_t = self.FPN(vis_trans[1:], x, False)
loss2 = self.IT_loss(fv_t, text)
|
|
|
|
loss = loss2
|
|
fv = fv_t
|
|
ft = text
|
|
fi = x
|
|
|
|
|
|
return loss, fv, fi, ft
|
|
|
|
def visualize(self, img, txt):
|
|
vis, image = self.backbone.encode_image(img)
|
|
word, text = self.backbone.encode_text(txt)
|
|
|
|
x = self.ADP(image)
|
|
|
|
x = self.ratio * x + (1 - self.ratio) * image
|
|
|
|
vis_trans = []
|
|
for i in range(len(self.ms_adaptor)):
|
|
x_ = rearrange(
|
|
vis[i],
|
|
"b (h w) c -> b c h w",
|
|
h=self.patch_emb,
|
|
w=self.patch_emb,
|
|
).contiguous()
|
|
|
|
feats = self.ms_adaptor[i](x_)
|
|
|
|
vis_trans.append(feats)
|
|
|
|
|
|
fv_t = self.FPN(vis_trans[1:], x, True)
|
|
ft_t = self.FPN(vis_trans[1:], text, True)
|
|
return vis, fv_t, ft_t
|
|
|
|
class CISEN_rsvit_classification(nn.Module):
|
|
def __init__(self, cfg):
|
|
super().__init__()
|
|
|
|
clip_model = torch.load(cfg.clip_pretrain,
|
|
map_location="cpu")
|
|
|
|
backbone, image_resolution, vision_heads, embed_dim, vision_width, patch_size = build_model(clip_model, cfg.word_len)
|
|
self.backbone = backbone.float()
|
|
self.patch_emb = image_resolution // patch_size
|
|
num_classes_fc = 512
|
|
num_classes_output = 10
|
|
self.num_classes_fc = num_classes_fc
|
|
self.num_classes_output = num_classes_output
|
|
|
|
|
|
self.fc = nn.Linear(in_features=cfg.vis_dim, out_features=num_classes_fc)
|
|
|
|
|
|
self.output_layer = nn.Linear(in_features=num_classes_fc, out_features=num_classes_output)
|
|
self.criterion = nn.BCEWithLogitsLoss()
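# BCEWithLogitsLoss makes this a multi-label head: `labels` is expected to be a
# float multi-hot tensor of shape (b, 1, num_classes_output) before the squeeze
# applied in IT_loss.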
|
|
cfg.image_resolution = image_resolution
|
|
cfg.input_size = image_resolution
|
|
cfg.heads = vision_heads // 32
|
|
cfg.emb_dim = vision_width
|
|
cfg.output_dim = embed_dim
|
|
|
|
|
|
def IT_loss(self, labels, labels_pre):
|
|
|
|
labels = labels.squeeze(1)
|
|
|
|
loss = self.criterion(labels_pre, labels)
|
|
return loss
|
|
|
|
def forward(self, img, labels):
|
|
_, image_features = self.backbone.encode_image(img)
|
|
|
|
fc_output = self.fc(image_features)
|
|
|
|
fc_output = F.relu(fc_output)
|
|
|
|
|
|
labels_pre = self.output_layer(fc_output)
|
|
|
|
loss2 = self.IT_loss(labels, labels_pre)
|
|
|
|
return labels_pre, loss2
|
|
|
|
|
|
class CISEN_new(nn.Module):
|
|
def __init__(self, cfg):
|
|
super().__init__()
|
|
|
|
clip_model = torch.jit.load(cfg.clip_pretrain,
|
|
map_location="cpu").eval()
|
|
|
|
backbone, image_resolution, vision_heads, embed_dim, vision_width, _ = build_model(clip_model.state_dict(), cfg.word_len)
|
|
self.backbone = backbone.float()
|
|
cfg.input_size = image_resolution
|
|
cfg.heads = vision_heads
|
|
cfg.emb_dim = vision_width * 32
|
|
cfg.output_dim = embed_dim
|
|
|
|
self.FPN = FPN(cfg, in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
self.ADP = Adapter(cfg.output_dim, 4)
# The '5th' stage below also adapts the text feature via self.ADP_t, which was never
# defined; a text adapter mirroring the image adapter is assumed here.
self.ADP_t = Adapter(cfg.output_dim, 4)
|
|
self.gap = GAP((1,1))
|
|
|
|
self.ratio = cfg.ratio
|
|
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
|
self.share_temperature = True
|
|
self.margin = 1
|
|
self.eps = 1e-3
|
|
self.ce = nn.CrossEntropyLoss()
|
|
|
|
self.lamda1 = cfg.lamda1
|
|
self.lamda2 = cfg.lamda2
|
|
self.avg = nn.AdaptiveAvgPool2d((1,1))
|
|
|
|
|
|
|
|
def IT_loss(self, image_features, text_features):
|
|
|
|
batch = image_features.shape[0]
|
|
|
|
image_features = image_features / image_features.norm(dim=-1,
|
|
keepdim=True)
|
|
text_features = text_features / text_features.norm(dim=-1,
|
|
keepdim=True)
|
|
|
|
|
|
logit_scale = self.logit_scale.exp()
|
|
logits_per_image = logit_scale * image_features @ text_features.t()
|
|
logits_per_text = logits_per_image.t()
|
|
|
|
|
|
contrastive_labels = torch.arange(batch).to(logits_per_image.device)
|
|
contrastive_loss = (self.ce(logits_per_image, contrastive_labels) + self.ce(logits_per_text, contrastive_labels)) * 0.5
|
|
|
|
|
|
return contrastive_loss
|
|
|
|
def forward(self, img, txt, stage):
|
|
|
|
if stage == '1st':
|
|
'''
|
|
img: b, 3, h, w
|
|
word: b, words
|
|
word_mask: b, words
|
|
mask: b, 1, h, w
|
|
stage: 1st or 2nd stage
|
|
'''
|
|
|
|
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
vis, image = self.backbone.encode_image(img)
|
|
|
|
word, text = self.backbone.encode_text(txt)
|
|
|
|
x = self.ADP(image)
|
|
|
|
x = self.ratio * x + (1-self.ratio) * image
loss1 = self.IT_loss(x, text)
|
|
|
|
loss = loss1
|
|
|
|
ft = text
|
|
fi = x
|
|
fv = None
|
|
elif stage == '2nd':
|
|
'''
|
|
img: b, 3, h, w
|
|
word: b, words
|
|
word_mask: b, words
|
|
mask: b, 1, h, w
|
|
stage: 1st or 2nd stage
|
|
'''
|
|
|
|
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
vis, image = self.backbone.encode_image(img)
|
|
|
|
word, text = self.backbone.encode_text(txt)
|
|
|
|
x = self.ADP(image)
|
|
|
|
x = self.ratio * x + (1 - self.ratio) * image
|
|
|
|
|
|
|
|
fq_t = self.FPN(vis, x)
|
|
|
|
fv_t = self.gap(fq_t)
|
|
|
|
|
|
|
|
loss2 = self.IT_loss(fv_t, text)
|
|
|
|
loss = loss2
|
|
fv = fv_t
|
|
ft = text
|
|
fi = x
|
|
elif stage == '3rd':
|
|
'''
|
|
img: b, 3, h, w
|
|
word: b, words
|
|
word_mask: b, words
|
|
mask: b, 1, h, w
|
|
stage: 1st or 2nd stage
|
|
'''
|
|
|
|
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
vis, image = self.backbone.encode_image(img)
|
|
|
|
word, text = self.backbone.encode_text(txt)
|
|
|
|
x = self.ADP(text)
|
|
ratio = 0.2
|
|
x = ratio * x + (1 - ratio) * text
loss1 = self.IT_loss(image, x)
|
|
|
|
|
|
loss = loss1
|
|
fv = None
|
|
ft = x
|
|
fi = image
|
|
elif stage == '4th':
|
|
'''
|
|
img: b, 3, h, w
|
|
word: b, words
|
|
word_mask: b, words
|
|
mask: b, 1, h, w
|
|
stage: 1st or 2nd stage
|
|
'''
|
|
|
|
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
vis, image = self.backbone.encode_image(img)
|
|
word, text = self.backbone.encode_text(txt)
|
|
|
|
|
|
|
|
fq_t = self.FPN(vis, image)
|
|
|
|
fv_t = self.gap(fq_t)
|
|
ratio_1 = 0.2
|
|
|
|
loss2 = self.IT_loss(fv_t, text)
|
|
|
|
loss = loss2
|
|
fv = fv_t
|
|
fi = None
|
|
ft = text
|
|
elif stage == '5th':
|
|
'''
|
|
img: b, 3, h, w
|
|
word: b, words
|
|
word_mask: b, words
|
|
mask: b, 1, h, w
|
|
stage: 1st or 2nd stage
|
|
'''
|
|
|
|
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
vis, image = self.backbone.encode_image(img)
|
|
word, text = self.backbone.encode_text(txt)
|
|
x = self.ADP(image)
|
|
ratio = 0.2
|
|
x = ratio * x + (1 - ratio) * image
|
|
|
|
y = self.ADP_t(text)
|
|
ratio_1 = 0.2
|
|
y = ratio * y + (1 - ratio_1) * text
|
|
|
|
fq_t = self.FPN(vis, image)
|
|
|
|
fv_t = self.gap(fq_t)
loss2 = self.IT_loss(fv_t, y)
|
|
|
|
loss = loss2
|
|
fv = fv_t
|
|
fi = x
|
|
ft = y
|
|
|
|
return loss, fv, fi, ft
|
|
|
|
class CISEN_lclip(nn.Module):
|
|
def __init__(self, cfg):
|
|
super().__init__()
|
|
|
|
clip_model = torch.load(cfg.clip_pretrain,
|
|
map_location="cpu")
|
|
|
|
backbone, image_resolution, vision_heads, embed_dim, vision_width, _ = build_lclip_model(clip_model, load_from_clip=True)
|
|
self.backbone = backbone.float()
|
|
cfg.input_size = image_resolution
|
|
cfg.heads = vision_heads // 32
|
|
cfg.emb_dim = vision_width
|
|
cfg.output_dim = embed_dim
|
|
|
|
self.FPN = FPN(cfg, in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
self.ADP = Adapter(cfg.output_dim, 4)
# The '5th' stage below also adapts the text feature via self.ADP_t, which was never
# defined; a text adapter mirroring the image adapter is assumed here.
self.ADP_t = Adapter(cfg.output_dim, 4)
|
|
self.gap = GAP((1,1))
|
|
|
|
self.ratio = cfg.ratio
|
|
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
|
self.share_temperature = True
|
|
self.margin = 1
|
|
self.eps = 1e-3
|
|
self.ce = nn.CrossEntropyLoss()
|
|
|
|
self.lamda1 = cfg.lamda1
|
|
self.lamda2 = cfg.lamda2
|
|
self.avg = nn.AdaptiveAvgPool2d((1,1))
|
|
|
|
|
|
|
|
def IT_loss(self, image_features, text_features):
|
|
|
|
batch = image_features.shape[0]
|
|
|
|
image_features = image_features / image_features.norm(dim=-1,
|
|
keepdim=True)
|
|
text_features = text_features / text_features.norm(dim=-1,
|
|
keepdim=True)
|
|
|
|
|
|
logit_scale = self.logit_scale.exp()
|
|
logits_per_image = logit_scale * image_features @ text_features.t()
|
|
logits_per_text = logits_per_image.t()
|
|
|
|
|
|
contrastive_labels = torch.arange(batch).to(logits_per_image.device)
|
|
contrastive_loss = (self.ce(logits_per_image, contrastive_labels) + self.ce(logits_per_text, contrastive_labels)) * 0.5
|
|
|
|
|
|
return contrastive_loss
|
|
|
|
def forward(self, img, txt, stage):
|
|
|
|
if stage == '1st':
|
|
'''
|
|
img: b, 3, h, w
|
|
word: b, words
|
|
word_mask: b, words
|
|
mask: b, 1, h, w
|
|
stage: 1st or 2nd stage
|
|
'''
|
|
|
|
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
vis, image = self.backbone.encode_image(img)
|
|
|
|
text = self.backbone.encode_text(txt)
|
|
|
|
x = self.ADP(image)
|
|
|
|
x = self.ratio * x + (1-self.ratio) * image
loss1 = self.IT_loss(x, text)
|
|
|
|
loss = loss1
|
|
|
|
ft = text
|
|
fi = x
|
|
fv = None
|
|
elif stage == '2nd':
|
|
'''
|
|
img: b, 3, h, w
|
|
word: b, words
|
|
word_mask: b, words
|
|
mask: b, 1, h, w
|
|
stage: 1st or 2nd stage
|
|
'''
|
|
|
|
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
vis, image = self.backbone.encode_image(img)
|
|
|
|
word, text = self.backbone.encode_text(txt)
|
|
|
|
x = self.ADP(image)
|
|
|
|
x = self.ratio * x + (1 - self.ratio) * image
|
|
|
|
|
|
|
|
fq_t = self.FPN(vis, x)
|
|
|
|
fv_t = self.gap(fq_t)
|
|
|
|
|
|
|
|
loss2 = self.IT_loss(fv_t, text)
|
|
|
|
loss = loss2
|
|
fv = fv_t
|
|
ft = text
|
|
fi = x
|
|
elif stage == '3rd':
|
|
'''
|
|
img: b, 3, h, w
|
|
word: b, words
|
|
word_mask: b, words
|
|
mask: b, 1, h, w
|
|
stage: 1st or 2nd stage
|
|
'''
|
|
|
|
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
vis, image = self.backbone.encode_image(img)
|
|
|
|
text = self.backbone.encode_text(txt)
|
|
|
|
x = self.ADP(text)
|
|
ratio = 0.2
|
|
x = ratio * x + (1 - ratio) * text
loss1 = self.IT_loss(image, x)
|
|
|
|
|
|
loss = loss1
|
|
fv = None
|
|
ft = x
|
|
fi = image
|
|
elif stage == '4th':
|
|
'''
|
|
img: b, 3, h, w
|
|
word: b, words
|
|
word_mask: b, words
|
|
mask: b, 1, h, w
|
|
stage: 1st or 2nd stage
|
|
'''
|
|
|
|
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
vis, image = self.backbone.encode_image(img)
|
|
word, text = self.backbone.encode_text(txt)
|
|
|
|
|
|
|
|
fq_t = self.FPN(vis, image)
|
|
|
|
fv_t = self.gap(fq_t)
|
|
ratio_1 = 0.2
|
|
|
|
loss2 = self.IT_loss(fv_t, text)
|
|
|
|
loss = loss2
|
|
fv = fv_t
|
|
fi = None
|
|
ft = text
|
|
elif stage == '5th':
|
|
'''
|
|
img: b, 3, h, w
|
|
word: b, words
|
|
word_mask: b, words
|
|
mask: b, 1, h, w
|
|
stage: 1st or 2nd stage
|
|
'''
|
|
|
|
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
vis, image = self.backbone.encode_image(img)
|
|
word, text = self.backbone.encode_text(txt)
|
|
x = self.ADP(image)
|
|
ratio = 0.2
|
|
x = ratio * x + (1 - ratio) * image
|
|
|
|
y = self.ADP_t(text)
|
|
ratio_1 = 0.2
|
|
y = ratio * y + (1 - ratio_1) * text
|
|
|
|
fq_t = self.FPN(vis, image)
|
|
|
|
fv_t = self.gap(fq_t)
loss2 = self.IT_loss(fv_t, y)
|
|
|
|
loss = loss2
|
|
fv = fv_t
|
|
fi = x
|
|
ft = y
|
|
|
|
return loss, fv, fi, ft
|
|
|
|
class GeoRSCLIP(nn.Module):
|
|
def __init__(self, cfg):
|
|
super().__init__()
|
|
|
|
clip_model = torch.load(cfg.clip_pretrain,
|
|
map_location="cpu")
|
|
|
|
backbone, image_resolution, vision_heads, embed_dim, vision_width, patch_size = build_model(clip_model, cfg.word_len)
|
|
self.backbone = backbone.float()
|
|
|
|
def forward(self, img, txt, stage):
|
|
|
|
|
|
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
vis, image = self.backbone.encode_image(img)
|
|
|
|
word, text = self.backbone.encode_text(txt)
|
|
|
|
loss = None
|
|
|
|
ft = text
|
|
fi = image
|
|
fv = None
|
|
return loss, fv, fi, ft
|
|
|
|
class CISEN(nn.Module):
|
|
def __init__(self, cfg):
|
|
super().__init__()
|
|
|
|
clip_model = torch.jit.load(cfg.clip_pretrain,
|
|
map_location="cpu").eval()
|
|
|
|
self.backbone = build_model(clip_model.state_dict(), cfg.word_len).float()
|
|
|
|
self.FPN = FPN(cfg, in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
|
|
|
|
self.FGFusion = TransformerDecoder(num_layers=cfg.num_layers,
|
|
d_model=cfg.vis_dim,
|
|
nhead=cfg.num_head,
|
|
dim_ffn=cfg.dim_ffn,
|
|
dropout=cfg.dropout,
|
|
return_intermediate=cfg.intermediate)
|
|
|
|
self.ASFF = AdaptiveSpatialFeatureFusion(cfg, in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
|
|
|
|
self.projT = Text_Projector(cfg, in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
|
|
|
|
|
|
|
|
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
|
self.multi_label_logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
|
self.share_temperature = True
|
|
self.margin = 1
|
|
self.eps = 1e-3
|
|
self.ce = nn.CrossEntropyLoss()
|
|
|
|
self.lamda1 = cfg.lamda1
|
|
self.lamda2 = cfg.lamda2
|
|
self.beta1 = cfg.beta1
|
|
self.beta2 = cfg.beta2
|
|
self.avg = nn.AdaptiveAvgPool2d((1,1))
|
|
|
|
|
|
self.pos_samples = cfg.pos_samples
|
|
self.neg_samples = cfg.neg_samples
|
|
|
|
def IT_loss(self, image_features, text_features):
|
|
|
|
batch = image_features.shape[0]
|
|
|
|
image_features = image_features / image_features.norm(dim=-1,
|
|
keepdim=True)
|
|
text_features = text_features / text_features.norm(dim=-1,
|
|
keepdim=True)
|
|
|
|
|
|
logit_scale = self.logit_scale.exp()
|
|
logits_per_image = logit_scale * image_features @ text_features.t()
|
|
logits_per_text = logits_per_image.t()
|
|
|
|
|
|
contrastive_labels = torch.arange(batch).to(logits_per_image.device)
|
|
contrastive_loss = (self.ce(logits_per_image, contrastive_labels) + self.ce(logits_per_text, contrastive_labels)) * 0.5
|
|
|
|
|
|
return contrastive_loss
|
|
|
|
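# Image-enhanced-text ranking loss: image_features is a list of candidate features
# whose first pos_samples+1 entries are treated as positives and the remainder as
# negatives; a unit-margin hinge penalizes any negative scoring above a positive
# (normalized by the number of violations), plus an L2 term pulling the first
# (anchor) feature toward the text feature, weighted by beta.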
def IET_loss(self, image_features, text_features, pos_samples, beta):
|
|
|
|
|
|
image_features = [image_feature / image_feature.norm(dim=-1,
|
|
keepdim=True) for image_feature in image_features]
|
|
text_features = text_features / text_features.norm(dim=-1,
|
|
keepdim=True)
|
|
|
|
|
|
logit_scale = self.logit_scale.exp()
|
|
|
|
|
|
logits_per_image = [logit_scale * torch.sum(torch.mul(image_feature, text_features),1) for image_feature in image_features]
|
|
logits_per_image = torch.stack(logits_per_image).t()
|
|
b = logits_per_image.shape[0]
|
|
loss1 = torch.norm(text_features - image_features[0])
|
|
positive_tagsT = torch.zeros(b,len(image_features)).to(text_features.device)
|
|
negative_tagsT = torch.zeros(b,len(image_features)).to(text_features.device)
|
|
positive_tagsT[:, 0 : pos_samples + 1] = 1
|
|
negative_tagsT[:, pos_samples + 1 : -1] = 1
|
|
|
|
maskT = positive_tagsT.unsqueeze(1) * negative_tagsT.unsqueeze(-1)
|
|
pos_score_matT = logits_per_image * positive_tagsT
|
|
neg_score_matT = logits_per_image * negative_tagsT
|
|
IW_pos3T = pos_score_matT.unsqueeze(1)
|
|
IW_neg3T = neg_score_matT.unsqueeze(-1)
|
|
OT = 1 + IW_neg3T - IW_pos3T
|
|
O_maskT = maskT * OT
|
|
diffT = torch.clamp(O_maskT, 0)
|
|
violationT = torch.sign(diffT).sum(1).sum(1)
|
|
diffT = diffT.sum(1).sum(1)
|
|
lossT = torch.mean(diffT / (violationT + self.eps))
|
|
loss = beta * loss1 + lossT
|
|
|
|
return loss
|
|
|
|
def test_IET_loss(self, image_features, text_features, pos_samples, beta1, beta2):
|
|
|
|
|
|
|
|
image_features = image_features / image_features.norm(dim=-1,
|
|
keepdim=True)
|
|
text_features = text_features / text_features.norm(dim=-1,
|
|
keepdim=True)
|
|
image_features = image_features.unsqueeze(1)
|
|
|
|
logit_scale = self.logit_scale.exp()
|
|
|
|
logits_per_image = logit_scale * torch.matmul(image_features, text_features.transpose(1, 2))
|
|
logits_per_image = logits_per_image.squeeze(1)
|
|
|
|
|
|
|
|
b = logits_per_image.shape[0]
|
|
|
|
|
|
|
|
positive_tagsT = torch.zeros(b, text_features.shape[1]).to(text_features.device)
|
|
negative_tagsT = torch.zeros(b, text_features.shape[1]).to(text_features.device)
|
|
positive_tagsT[:, 0 : pos_samples + 1] = 1
|
|
negative_tagsT[:, pos_samples + 1 : -1] = 1
|
|
|
|
maskT = positive_tagsT.unsqueeze(1) * negative_tagsT.unsqueeze(-1)
|
|
pos_score_matT = logits_per_image * positive_tagsT
|
|
neg_score_matT = logits_per_image * negative_tagsT
|
|
IW_pos3T = pos_score_matT.unsqueeze(1)
|
|
IW_neg3T = neg_score_matT.unsqueeze(-1)
|
|
OT = 1 + IW_neg3T - IW_pos3T
|
|
O_maskT = maskT * OT
|
|
diffT = torch.clamp(O_maskT, 0)
|
|
violationT = torch.sign(diffT).sum(1).sum(1)
|
|
diffT = diffT.sum(1).sum(1)
|
|
lossT = torch.mean(diffT / (violationT + self.eps))
|
|
|
|
loss = lossT
|
|
return loss
|
|
|
|
def test_IT_loss(self, image_features, text_features):
|
|
|
|
batch = image_features.shape[0]
|
|
|
|
image_features = image_features / image_features.norm(dim=-1,
|
|
keepdim=True)
|
|
text_features = text_features / text_features.norm(dim=-1,
|
|
keepdim=True)
|
|
image_features = image_features.unsqueeze(1)
|
|
|
|
logit_scale = self.logit_scale.exp()
|
|
logits_per_image = logit_scale * torch.matmul(image_features, text_features.transpose(1, 2))
|
|
logits_per_image = logits_per_image.squeeze(1)
|
|
|
|
|
|
contrastive_labels = torch.arange(batch).to(logits_per_image.device)
|
|
contrastive_loss = self.ce(logits_per_image, contrastive_labels)
|
|
|
|
|
|
return contrastive_loss
|
|
|
|
def test_forward(self, img, txt):
|
|
'''
|
|
img: b, 3, h, w
|
|
word: b, words
|
|
word_mask: b, words
|
|
mask: b, 1, h, w
|
|
stage: 1st or 2nd stage
|
|
'''
|
|
|
|
|
|
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
vis, image = self.backbone.encode_image(img)
|
|
|
|
word, text = self.backbone.encode_text(txt)
|
|
|
|
fq = self.FPN(vis, text)
|
|
|
|
b, c, h, w = fq.size()
|
|
|
|
ff = self.FGFusion(fq, word, pad_mask)
|
|
ff = ff.reshape(b, c, h, w)
|
|
|
|
f2 = self.avg(ff)
|
|
fi = image.unsqueeze(-1).unsqueeze(-1)
|
|
fv = self.ASFF(fi, f2)
|
|
fi = fi.squeeze(-1).squeeze(-1)
|
|
|
|
ft = self.projT(text)
|
|
loss1 = self.IT_loss(fi, ft)
|
|
loss2 = self.IT_loss(fv, ft)
|
|
loss = self.lamda1 * loss1 + self.lamda2 * loss2
|
|
|
|
return loss, fv, ft, fi
|
|
|
|
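# Training pipeline: the FPN fuses multi-scale visual features with the sentence
# embedding, FGFusion (a transformer decoder) lets word tokens attend over the fused
# map, ASFF merges the global CLIP image feature with the pooled fused feature, and
# the resulting features are trained with the contrastive / ranking losses above,
# depending on the stage.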
def forward(self, img, txt, stage):
|
|
|
|
if stage == '1st':
|
|
'''
|
|
img: b, 3, h, w
|
|
word: b, words
|
|
word_mask: b, words
|
|
mask: b, 1, h, w
|
|
stage: 1st or 2nd stage
|
|
'''
|
|
|
|
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
vis, image = self.backbone.encode_image(img)
|
|
|
|
word, text = self.backbone.encode_text(txt)
|
|
|
|
fq = self.FPN(vis, text)
|
|
|
|
b, c, h, w = fq.size()
|
|
|
|
ff = self.FGFusion(fq, word, pad_mask)
|
|
ff = ff.reshape(b, c, h, w)
|
|
|
|
f2 = self.avg(ff)
|
|
fi = image.unsqueeze(-1).unsqueeze(-1)
|
|
fv = self.ASFF(fi, f2)
|
|
fi = fi.squeeze(-1).squeeze(-1)
|
|
|
|
ft = self.projT(text)
|
|
loss1 = self.IT_loss(fi, ft)
|
|
loss2 = self.IT_loss(fv, ft)
|
|
loss = self.lamda1 * loss1 + self.lamda2 * loss2
|
|
|
|
elif stage == '2nd':
|
|
"""
|
|
txt: b, num, words
|
|
img: b, 3, h, w
|
|
"""
|
|
|
|
|
|
b, num, l = txt.shape
|
|
txt = txt.view(-1, txt.size(-1))
|
|
|
|
|
|
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
|
|
|
|
b = img.shape[0]
|
|
vis, image = self.backbone.encode_image(img)
|
|
word, text = self.backbone.encode_text(txt)
|
|
|
|
fq = self.FPN(vis, text)
|
|
|
|
|
|
b, c, h, w = fq.size()
|
|
|
|
ff = self.FGFusion(fq, word, pad_mask)
|
|
ff = ff.reshape(b, c, h, w)
|
|
|
|
f2 = self.avg(ff)
|
|
fi = image.unsqueeze(-1).unsqueeze(-1)
|
|
fi_ = fi.repeat(int(f2.shape[0] / fi.shape[0]), 1, 1, 1)
|
|
|
|
fv = self.ASFF(fi_, f2)
|
|
fi = fi.squeeze(-1).squeeze(-1)
|
|
|
|
|
|
ft = text.view(img.shape[0], int(text.shape[0] / img.shape[0]), -1)[:, 0, :]
|
|
fv = fv.view(ft.shape[0], int(text.shape[0] / ft.shape[0]), fv.shape[1])
|
|
loss = self.test_IET_loss(fi, fv, self.pos_samples, self.beta1, self.beta2)
|
|
|
|
|
|
elif stage == 'test':
|
|
"""
|
|
txt: b, num, words
|
|
img: b, 3, h, w
|
|
"""
|
|
txt = txt.permute(1, 0, 2)
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
b = img.shape[0]
|
|
words = []
|
|
texts = []
|
|
vis, image = self.backbone.encode_image(img)
|
|
for i in range(txt.shape[0]):
|
|
word, text = self.backbone.encode_text(txt[i])
|
|
words.append(word)
|
|
texts.append(text)
|
|
|
|
fvn = []
|
|
|
|
for i in range(txt.shape[0]):
|
|
fq = self.FPN(vis, texts[i])
|
|
|
|
b, c, h, w = fq.size()
|
|
|
|
ff = self.FGFusion(fq, words[i], pad_mask[i, :, :])
|
|
ff = ff.reshape(b, c, h, w)
|
|
|
|
f2 = self.avg(ff)
|
|
fi = image.unsqueeze(-1).unsqueeze(-1)
|
|
fv = self.ASFF(fi, f2)
|
|
fi = fi.squeeze(-1).squeeze(-1)
|
|
fvn.append(fv)
|
|
|
|
|
|
ft = self.projT(texts[0])
|
|
# Only beta1/beta2 are defined in __init__ (self.beta does not exist); beta1 is
# assumed as the weighting here.
loss = self.IET_loss(fvn, ft, self.pos_samples, self.beta1)
|
|
fv = fvn
|
|
|
|
|
|
else:
|
|
raise ValueError("stage should be '1st', '2nd' or 'test'")
return loss, fv, fi, ft
|
|
|
|
|
|
|
|
class CRIS(nn.Module):
|
|
def __init__(self, cfg):
|
|
super().__init__()
|
|
|
|
clip_model = torch.jit.load(cfg.clip_pretrain,
|
|
map_location="cpu").eval()
|
|
|
|
self.backbone, _, _, _, _ = build_model(clip_model.state_dict(), cfg.word_len)
|
|
self.backbone = self.backbone.float()
|
|
self.Label_encoder = build_promptlearner(clip_model.state_dict()).float()
|
|
self.Label_encoder.init_label_emb(cfg.label_path)
|
|
|
|
|
|
self.FPN = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
|
|
|
|
self.FGFusion = TransformerDecoder(num_layers=cfg.num_layers,
|
|
d_model=cfg.vis_dim,
|
|
nhead=cfg.num_head,
|
|
dim_ffn=cfg.dim_ffn,
|
|
dropout=cfg.dropout,
|
|
return_intermediate=cfg.intermediate)
|
|
|
|
self.ASFF = AdaptiveSpatialFeatureFusion(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
|
|
|
|
self.projT = Text_Projector(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
|
|
|
|
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
|
self.multi_label_logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
|
self.share_temperature = True
|
|
self.margin = 1
|
|
self.eps = 1e-3
|
|
self.ce = nn.CrossEntropyLoss()
|
|
self.avg = nn.AdaptiveAvgPool2d((1,1))
|
|
self.fc = nn.Linear(512, cfg.num_classes)
|
|
|
|
|
|
|
|
def IT_loss(self, image_features, text_features):
|
|
|
|
batch = image_features.shape[0]
|
|
|
|
image_features = image_features / image_features.norm(dim=-1,
|
|
keepdim=True)
|
|
text_features = text_features / text_features.norm(dim=-1,
|
|
keepdim=True)
|
|
|
|
|
|
logit_scale = self.logit_scale.exp()
|
|
logits_per_image = logit_scale * image_features @ text_features.t()
|
|
logits_per_text = logits_per_image.t()
|
|
|
|
|
|
contrastive_labels = torch.arange(batch).to(logits_per_image.device)
|
|
contrastive_loss = (self.ce(logits_per_image, contrastive_labels) + self.ce(logits_per_text, contrastive_labels)) * 0.5
|
|
|
|
|
|
return contrastive_loss
|
|
|
|
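# Multi-label image-label ranking loss: `labels` is expected to carry +1 / -1 (and 0)
# tags; for every (positive, negative) label pair a margin hinge penalizes negatives
# that score higher than positives, averaged over the number of violations.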
def IL_loss(self, image_features, label_features, labels):
|
|
|
|
|
|
positive_tagsT = torch.clamp(labels,0.,1.)
|
|
negative_tagsT = torch.clamp(-labels,0.,1.)
|
|
maskT = positive_tagsT.unsqueeze(1) * negative_tagsT.unsqueeze(-1)
|
|
|
|
|
|
|
|
image_features = image_features / image_features.norm(dim=-1,
|
|
keepdim=True)
|
|
label_features = label_features / label_features.norm(dim=-1,
|
|
keepdim=True)
|
|
|
|
logit_scale = self.multi_label_logit_scale.exp()
|
|
logits_per_image = logit_scale * image_features @ label_features.t()
|
|
|
|
pos_score_matT = logits_per_image * positive_tagsT
|
|
neg_score_matT = logits_per_image * negative_tagsT
|
|
IW_pos3T = pos_score_matT.unsqueeze(1)
|
|
IW_neg3T = neg_score_matT.unsqueeze(-1)
|
|
OT = self.margin + IW_neg3T - IW_pos3T
|
|
O_maskT = maskT * OT
|
|
diffT = torch.clamp(O_maskT, 0)
|
|
violationT = torch.sign(diffT).sum(1).sum(1)
|
|
diffT = diffT.sum(1).sum(1)
|
|
lossT = torch.mean(diffT / (violationT + self.eps))
return lossT
|
|
|
|
def margin_loss(self, image_features, label_features, labels):
image_features = image_features / image_features.norm(dim=-1,
|
|
keepdim=True)
|
|
label_features = label_features / label_features.norm(dim=-1,
|
|
keepdim=True)
|
|
|
|
logit_scale = self.multi_label_logit_scale.exp()
|
|
logits_per_image = logit_scale * image_features @ label_features.t()
|
|
|
|
|
|
image_label_positive_pairs = logits_per_image * labels
|
|
image_label_mean_positive = image_label_positive_pairs.sum() / labels.sum()
|
|
image_label_negative_pairs = logits_per_image * (1 - labels)
|
|
image_label_mean_negative = image_label_negative_pairs.sum() / (logits_per_image.numel() - labels.sum() + self.eps)
|
|
|
|
contrastive_loss = torch.relu(self.margin - image_label_mean_positive + image_label_mean_negative)
|
|
|
|
return contrastive_loss
|
|
|
|
def forward(self, img, txt, target=None):
|
|
'''
|
|
img: b, 3, h, w
|
|
word: b, words
|
|
word_mask: b, words
|
|
mask: b, 1, h, w
|
|
'''
|
|
|
|
|
|
|
|
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
vis, image = self.backbone.encode_image(img)
|
|
word, text = self.backbone.encode_text(txt)
|
|
|
|
|
|
fl = self.Label_encoder(image.device)
|
|
|
|
fq = self.FPN(vis, text)
|
|
b, c, h, w = fq.size()
|
|
|
|
ff = self.FGFusion(fq, word, pad_mask)
|
|
|
|
ff = ff.reshape(b, c, h, w)
|
|
f2 = self.avg(ff)
|
|
|
|
f1 = image.unsqueeze(-1).unsqueeze(-1)
|
|
fv = self.ASFF(f1, f2)
|
|
|
|
|
|
ft = self.projT(text)
|
|
|
|
|
|
|
|
|
|
loss1 = self.IT_loss(fv, ft)
|
|
loss2 = self.IL_loss(fv, fl, target)
|
|
loss = loss1 + loss2
return loss, fv, ft, fl
|
|
|
|
class zh_clip(nn.Module):
|
|
def __init__(self, cfg):
|
|
super().__init__()
|
|
|
|
clip_model = torch.jit.load(cfg.clip_pretrain,
|
|
map_location="cpu").eval()
|
|
self.backbone = build_modified_model(clip_model.state_dict(), cfg.word_len).float()
|
|
|
|
self.text_encoder = AutoModelForSequenceClassification.from_pretrained(cfg.chinese)
|
|
self.text_lin = nn.Linear(512, 1024)
|
|
|
|
|
|
|
|
self.neck = ViTFPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
|
|
|
|
|
|
self.avg = nn.AdaptiveAvgPool2d((1,1))
|
|
self.fc = nn.Linear(512, cfg.num_classes)
|
|
def forward(self, img, word):
|
|
'''
|
|
img: b, 3, h, w
|
|
word: b, words
|
|
word_mask: b, words
|
|
mask: b, 1, h, w
|
|
'''
vis, feat, cls = self.backbone.encode_image(img)
|
|
state = self.text_encoder(word.squeeze(1)).logits
|
|
state = self.text_lin(state)
|
|
|
|
fq = self.neck(feat, state)
|
|
|
|
out = self.avg(fq)
|
|
out = out.squeeze(-1).squeeze(-1)
|
|
out = self.fc(out)
|
|
|
|
return out
|
|
|
|
class poi_clip(nn.Module):
|
|
def __init__(self, cfg):
|
|
super().__init__()
|
|
|
|
clip_model = torch.jit.load(cfg.clip_pretrain,
|
|
map_location="cpu").eval()
|
|
self.backbone = build_modified_model(clip_model.state_dict(), cfg.word_len).float()
|
|
|
|
self.text_encoder = AutoModelForSequenceClassification.from_pretrained(cfg.chinese)
|
|
self.text_lin = nn.Linear(512, 1024)
|
|
|
|
|
|
|
|
self.neck = ViTFPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
|
|
|
|
|
|
self.avg = nn.AdaptiveAvgPool2d((1,1))
|
|
self.fc = nn.Linear(512, cfg.num_classes)
|
|
def forward(self, img, word):
|
|
'''
|
|
img: b, 3, h, w
|
|
word: b, words
|
|
word_mask: b, words
|
|
mask: b, 1, h, w
|
|
'''
vis, feat, cls = self.backbone.encode_image(img)
|
|
state = self.text_encoder(word.squeeze(1)).logits
|
|
state = self.text_lin(state)
|
|
|
|
fq = self.neck(feat, state)
|
|
|
|
out = self.avg(fq)
|
|
out = out.squeeze(-1).squeeze(-1)
|
|
out = self.fc(out)
|
|
|
|
return out
|
|
|
|
class Clip_hash_model(nn.Module):
|
|
def __init__(self, cfg):
|
|
super().__init__()
|
|
|
|
|
|
clip_model = torch.jit.load(cfg.clip_pretrain,
|
|
map_location="cpu").eval()
|
|
self.backbone = build_model(clip_model.state_dict(), cfg.word_len).float()
|
|
|
|
self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
|
|
|
|
|
|
self.avg = nn.AdaptiveAvgPool2d((1, 1))
|
|
|
|
self.classifier = nn.Sequential(
|
|
nn.Linear(cfg.fpn_out[1], cfg.hash_dim, bias=True),
|
|
nn.Tanh(),
|
|
)
|
|
|
|
self.classifier2 = nn.Sequential(
|
|
nn.Linear(cfg.hash_dim, cfg.num_classes)
|
|
)
|
|
|
|
|
|
self.image_module = nn.Sequential(
|
|
nn.Linear(cfg.img_dim, cfg.hidden_dim, bias=True),
|
|
nn.BatchNorm1d(cfg.hidden_dim),
|
|
nn.ReLU(True),
|
|
nn.Linear(cfg.hidden_dim, cfg.hash_dim, bias=True),
|
|
nn.Tanh()
|
|
)
|
|
|
|
self.text_module = nn.Sequential(
|
|
nn.Linear(cfg.txt_dim, cfg.hidden_dim, bias=True),
|
|
nn.BatchNorm1d(cfg.hidden_dim),
|
|
nn.ReLU(True),
|
|
nn.Linear(cfg.hidden_dim, cfg.hash_dim, bias=True),
|
|
nn.Tanh()
|
|
)
|
|
def forward(self, img, word, mask=None):
|
|
'''
|
|
img: b, 3, h, w
|
|
word: b, words
|
|
word_mask: b, words
|
|
'''
|
|
pad_mask = torch.zeros_like(word).masked_fill_(word == 0, 1).bool()
|
|
|
|
|
|
|
|
vis, image = self.backbone.encode_image(img)
|
|
word, state = self.backbone.encode_text(word)
|
|
|
|
|
|
fq = self.neck(vis, state)
|
|
|
|
|
|
|
|
out = self.avg(fq)
|
|
out = out.squeeze(-1).squeeze(-1)
|
|
out_hash = self.classifier(out)
|
|
res = self.classifier2(out_hash)
|
|
|
|
|
|
|
|
img_hash = self.image_module(image)
|
|
txt_hash = self.text_module(state)
|
|
|
|
|
|
|
|
return img_hash, txt_hash, out_hash, res
|
|
|
|
class Clip_model(nn.Module):
|
|
def __init__(self, cfg):
|
|
super().__init__()
|
|
|
|
|
|
clip_model = torch.jit.load(cfg.clip_pretrain,
|
|
map_location="cpu").eval()
|
|
|
|
self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
|
|
self.avg = nn.AdaptiveAvgPool2d((1, 1))
|
|
self.backbone = build_model(clip_model.state_dict(), cfg.word_len).float()
|
|
|
|
def forward(self, img, word, mask=None):
|
|
'''
|
|
img: b, 3, h, w
|
|
word: b, words
|
|
word_mask: b, words
|
|
'''
|
|
|
|
|
|
|
|
pad_mask = torch.zeros_like(word).masked_fill_(word == 0, 1).bool()
|
|
vis, image = self.backbone.encode_image(img)
|
|
word, state = self.backbone.encode_text(word)
|
|
f = self.neck(vis, state)
|
|
out = self.avg(f)
|
|
out = out.squeeze(-1).squeeze(-1)
|
|
image_features = image / image.norm(dim=-1, keepdim=True)
|
|
text_features = state / state.norm(dim=-1, keepdim=True)
|
|
|
|
|
|
logit_scale = self.backbone.logit_scale.exp()
|
|
logits_per_image = logit_scale * image_features @ text_features.t()
|
|
logits_per_text = logits_per_image.t()
|
|
|
|
|
|
return logits_per_image, logits_per_text
|
|
|
|
|
|
class CISEN_rsvit_hug(nn.Module, PyTorchModelHubMixin):
|
|
def __init__(self, embed_dim, image_resolution, vision_layers, vision_width,
|
|
vision_patch_size, context_length, txt_length, vocab_size,
|
|
transformer_width, transformer_heads, transformer_layers, patch_size,
|
|
output_dim, ratio, emb_dim, fpn_in, fpn_out):
|
|
super().__init__()
|
|
|
|
vision_heads = vision_width * 32 // 64
|
|
|
|
backbone = CLIP(embed_dim, image_resolution, vision_layers, vision_width,
|
|
vision_patch_size, context_length, txt_length, vocab_size,
|
|
transformer_width, transformer_heads, transformer_layers)
|
|
self.backbone = backbone.float()
|
|
self.patch_emb = image_resolution // patch_size
|
|
|
|
self.FPN = ViTFPN(image_resolution, in_channels=fpn_in, out_channels=fpn_out)
|
|
|
|
self.ADP = Adapter(output_dim, 4)
|
|
|
|
self.ratio = ratio
|
|
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
|
self.share_temperature = True
|
|
self.ce = nn.CrossEntropyLoss()
|
|
self.ms_adaptor = nn.ModuleList(
|
|
[
|
|
nn.Sequential(
|
|
nn.ConvTranspose2d(emb_dim, emb_dim, 2, 2),
|
|
nn.GroupNorm(32, emb_dim),
|
|
nn.GELU(),
|
|
nn.ConvTranspose2d(emb_dim, emb_dim, 2, 2),
|
|
),
|
|
nn.Sequential(
|
|
nn.ConvTranspose2d(emb_dim, emb_dim, 2, 2),
|
|
),
|
|
nn.Sequential(
|
|
nn.Identity(),
|
|
),
|
|
nn.Sequential(
|
|
nn.MaxPool2d(2),
|
|
),
|
|
|
|
]
|
|
)
|
|
|
|
self.ms_adaptor.apply(self.init_adaptor)
|
|
def init_adaptor(self, m):
|
|
if isinstance(m, nn.Conv2d):
|
|
lecun_normal_(m.weight)
|
|
if m.bias is not None:
|
|
nn.init.constant_(m.bias, 0)
|
|
elif isinstance(m, nn.GroupNorm):
|
|
nn.init.constant_(m.bias, 0)
|
|
nn.init.constant_(m.weight, 1.0)
|
|
elif isinstance(m, nn.ConvTranspose2d):
|
|
lecun_normal_(m.weight)
|
|
if m.bias is not None:
|
|
nn.init.zeros_(m.bias)
|
|
|
|
|
|
def image_encode(self, img):
|
|
vis, image = self.backbone.encode_image(img)
|
|
|
|
x = self.ADP(image)
|
|
|
|
x = self.ratio * x + (1 - self.ratio) * image
|
|
return x
|
|
|
|
def text_encode(self, txt):
|
|
|
|
word, text = self.backbone.encode_text(txt)
|
|
|
|
return text
|
|
|
|
def forward(self, img, txt):
|
|
'''
|
|
img: b, 3, h, w
|
|
word: b, words
|
|
word_mask: b, words
|
|
mask: b, 1, h, w
|
|
stage: 1st or 2nd stage
|
|
'''
|
|
|
|
pad_mask = torch.zeros_like(txt).masked_fill_(txt == 0, 1).bool()
vis, image = self.backbone.encode_image(img)
|
|
|
|
word, text = self.backbone.encode_text(txt)
|
|
|
|
x = self.ADP(image)
|
|
|
|
x = self.ratio * x + (1 - self.ratio) * image
|
|
|
|
vis_trans = []
|
|
for i in range(len(self.ms_adaptor)):
|
|
x_ = rearrange(
|
|
vis[i],
|
|
"b (h w) c -> b c h w",
|
|
h=self.patch_emb,
|
|
w=self.patch_emb,
|
|
).contiguous()
|
|
|
|
feats = self.ms_adaptor[i](x_)
|
|
|
|
vis_trans.append(feats)
|
|
|
|
|
|
fv_t = self.FPN(vis_trans[1:], x, False)
|
|
|
|
|
|
|
|
fv = fv_t
|
|
ft = text
|
|
fi = x
|
|
|
|
return fv, fi, ft |
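# Minimal usage sketch for CISEN_rsvit_hug (constructor values are illustrative
# assumptions mirroring a ViT-B/32-style CLIP config; PyTorchModelHubMixin also
# provides save_pretrained / from_pretrained for Hub round-trips):
#
#   model = CISEN_rsvit_hug(embed_dim=512, image_resolution=224, vision_layers=12,
#                           vision_width=768, vision_patch_size=32, context_length=77,
#                           txt_length=77, vocab_size=49408, transformer_width=512,
#                           transformer_heads=8, transformer_layers=12, patch_size=32,
#                           output_dim=512, ratio=0.5, emb_dim=768,
#                           fpn_in=[768, 768, 768], fpn_out=[256, 512, 1024])
#   fv, fi, ft = model(torch.randn(2, 3, 224, 224), torch.randint(1, 49408, (2, 77)))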