|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import os |
|
import sys |
|
from ldm.util import instantiate_from_config |
|
from transformers.models.clip.modeling_clip import CLIPTextModel |
|
from omegaconf import OmegaConf |
|
from lib.mask_predictor import SimpleDecoding |
|
from transformers import PreTrainedModel |
|
from .models import UNetWrapper, TextAdapterRefer |
|
from .evpconfig import EVPConfig |
|
from transformers import CLIPTokenizer |
|
import torchvision.transforms as transforms |
|
|
|
|
|
def icnr(x, scale=2, init=nn.init.kaiming_normal_):
    """
    Checkerboard artifact free sub-pixel convolution
    https://arxiv.org/abs/1707.02937

    Overwrite the 4-D tensor ``x`` (e.g. a ``Conv2d`` weight) in place with
    an ICNR-style kernel: a reduced kernel with ``x.shape[0] / scale**2``
    leading channels is drawn with ``init``, tiled ``scale**2`` times and
    reshaped back to the shape of ``x``.

    Args:
        x: 4-D tensor to initialise; its first dimension must be a multiple
            of ``scale**2``, otherwise the final ``view`` fails.
        scale: upscaling factor of the pixel shuffle that follows the conv.
        init: in-place initialiser applied to the reduced kernel; it must
            return the tensor it was given.

    Returns:
        None. ``x`` is modified in place.
    """
    out_ch, in_ch, height, width = x.shape
    copies = scale ** 2
    base_ch = int(out_ch / copies)

    # Draw the reduced kernel, then flatten its spatial dimensions.
    seed = init(torch.zeros([base_ch, in_ch, height, width]))
    flat = seed.transpose(0, 1).contiguous().view(base_ch, in_ch, -1)

    # Tile along the flattened axis and fold back into the shape of ``x``.
    tiled = flat.repeat(1, 1, copies)
    kernel = tiled.contiguous().view([in_ch, out_ch, height, width]).transpose(0, 1)

    # Write through ``.data`` so the copy is not recorded by autograd.
    x.data.copy_(kernel)
|
|
|
|
|
class PixelShuffle(nn.Module):
    """
    Sub-pixel upsampling block: 1x1 convolution -> ReLU -> ``nn.PixelShuffle``.

    Real-Time Single Image and Video Super-Resolution
    https://arxiv.org/abs/1609.05158

    The convolution expands ``n_channels`` to ``n_channels * scale**2``
    channels and the shuffle rearranges them into a feature map that is
    ``scale`` times larger in each spatial dimension.

    Args:
        n_channels: number of channels of the input feature map.
        scale: spatial upscaling factor.
    """

    def __init__(self, n_channels, scale):
        super().__init__()
        self.conv = nn.Conv2d(n_channels, n_channels * (scale ** 2), kernel_size=1)
        # The ICNR initialisation has to use the same factor as the shuffle.
        # Relying on icnr's default (2) initialised the kernel for the wrong
        # factor, or failed outright, whenever ``scale`` was not 2.
        icnr(self.conv.weight, scale=scale)
        self.shuf = nn.PixelShuffle(scale)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Expand channels, apply ReLU, then shuffle channels into space."""
        return self.shuf(self.relu(self.conv(x)))
|
|
|
|
|
class AttentionModule(nn.Module):
    """Spatially gated 3x3 convolution block.

    A 1x1 convolution followed by a sigmoid produces a single-channel gate
    that rescales the input at every spatial position; the gated features
    then pass through a 3x3 convolution, group normalisation and ReLU.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced; it is normalised with 20
            groups, so it has to be a multiple of 20.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Same-resolution 3x3 convolution (stride 1, padding 1).
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.group_norm = nn.GroupNorm(20, out_channels)
        self.relu = nn.ReLU()

        # One gate value per spatial location, broadcast over channels.
        gate_conv = nn.Conv2d(in_channels, 1, kernel_size=1)
        self.spatial_attention = nn.Sequential(gate_conv, nn.Sigmoid())

    def forward(self, x):
        """Gate ``x`` spatially, then apply conv -> group norm -> ReLU."""
        gate = self.spatial_attention(x)
        gated = x * gate
        return self.relu(self.group_norm(self.conv1(gated)))
|
|
|
|
|
class AttentionDownsamplingModule(nn.Module):
    """Attention-gated block that reduces spatial resolution.

    The input is rescaled by a spatial gate (1x1 conv + sigmoid) and then by
    a channel gate (global average pool -> bottleneck of ``in_channels // 8``
    -> sigmoid). Two 3x3 convolutions follow, each one followed by group
    normalisation and ReLU. The second convolution always has stride 2; the
    first has stride 1 for ``scale_factor=2`` and stride 2 for
    ``scale_factor=4``.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced; it is normalised with 20
            groups, so it has to be a multiple of 20.
        scale_factor: overall downsampling factor, either 2 or 4.

    Raises:
        ValueError: if ``scale_factor`` is neither 2 nor 4.
    """

    # Stride of the first convolution for each supported overall factor.
    _FIRST_STRIDE = {2: 1, 4: 2}

    def __init__(self, in_channels, out_channels, scale_factor=2):
        super().__init__()

        # Fail at construction time. Previously an unsupported factor left
        # ``conv1`` undefined and only surfaced as an AttributeError in
        # ``forward``.
        if scale_factor not in self._FIRST_STRIDE:
            raise ValueError(
                f"scale_factor must be 2 or 4, got {scale_factor!r}"
            )
        first_stride = self._FIRST_STRIDE[scale_factor]

        # One gate value per spatial location, broadcast over channels.
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(in_channels, 1, kernel_size=1),
            nn.Sigmoid(),
        )

        # One gate value per channel, computed from globally pooled features.
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // 8, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // 8, in_channels, kernel_size=1),
            nn.Sigmoid(),
        )

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=first_stride, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=2, padding=1)

        # A single GroupNorm instance (one set of affine parameters) is
        # applied after both convolutions in ``forward``.
        self.group_norm = nn.GroupNorm(20, out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply both gates, then the two conv -> group norm -> ReLU stages."""
        x = x * self.spatial_attention(x)
        # The channel gate is computed from the spatially gated features.
        x = x * self.channel_attention(x)

        x = self.relu(self.group_norm(self.conv1(x)))
        x = self.relu(self.group_norm(self.conv2(x)))
        return x
|
|
|
|
|
class AttentionUpsamplingModule(nn.Module):
    """Attention-gated block that doubles spatial resolution.

    The input is rescaled by a spatial gate (1x1 conv + sigmoid) and then by
    a channel gate (global average pool -> bottleneck of ``in_channels // 8``
    -> sigmoid). Two stride-1 3x3 convolutions follow, each one followed by
    group normalisation and ReLU, and the result is upsampled by a factor of
    2 with the file's ``PixelShuffle`` block.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced; it is normalised with 20
            groups, so it has to be a multiple of 20.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # One gate value per spatial location, broadcast over channels.
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(in_channels, 1, kernel_size=1),
            nn.Sigmoid(),
        )

        # One gate value per channel, computed from globally pooled features.
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // 8, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(in_channels // 8, in_channels, kernel_size=1),
            nn.Sigmoid(),
        )

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=1, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1)

        # A single GroupNorm instance (one set of affine parameters) is
        # applied after both convolutions in ``forward``.
        self.group_norm = nn.GroupNorm(20, out_channels)
        self.relu = nn.ReLU()

        # The upscaler consumes the output of ``conv2``, which has
        # ``out_channels`` channels. Sizing it with ``in_channels`` only
        # worked when both counts happened to be equal.
        self.upscale = PixelShuffle(out_channels, 2)

    def forward(self, x):
        """Apply both gates, the two conv stages, then upsample by 2."""
        x = x * self.spatial_attention(x)
        # The channel gate is computed from the spatially gated features.
        x = x * self.channel_attention(x)

        x = self.relu(self.group_norm(self.conv1(x)))
        x = self.relu(self.group_norm(self.conv2(x)))

        return self.upscale(x)
|
|
|
|
|
class ConvLayer(nn.Module):
    """Pointwise projection: 1x1 convolution -> group norm -> ReLU.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced; it is normalised with 20
            groups, so it has to be a multiple of 20.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        projection = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        normalisation = nn.GroupNorm(20, out_channels)
        activation = nn.ReLU()
        # Kept as a single Sequential named ``conv1`` so parameter names
        # stay ``conv1.0.*`` / ``conv1.1.*``.
        self.conv1 = nn.Sequential(projection, normalisation, activation)

    def forward(self, x):
        """Return the projected, normalised and rectified feature map."""
        return self.conv1(x)
|
|
|
|
|
class InverseMultiAttentiveFeatureRefinement(nn.Module):
    """Refine a four-level feature pyramid from the first level downwards.

    The first feature map is refined by an ``AttentionModule``. Each later
    level is then fused with a downsampled copy of the previously refined
    level: an ``AttentionDownsamplingModule`` halves the channel count and
    the resolution (its ``scale_factor=2`` configuration), the result is
    concatenated with the next input along the channel axis, and a
    ``ConvLayer`` projects the concatenation back to that level's width.

    Args:
        in_channels_list: channel counts of the four input feature maps, in
            the order they are passed to ``forward``.
    """

    def __init__(self, in_channels_list):
        super().__init__()

        c0 = in_channels_list[0]
        c1 = in_channels_list[1]
        c2 = in_channels_list[2]
        c3 = in_channels_list[3]

        self.layer1 = AttentionModule(c0, c0)

        # level 0 -> level 1
        self.layer2 = AttentionDownsamplingModule(c0, c0 // 2, scale_factor=2)
        self.layer3 = ConvLayer(c0 // 2 + c1, c1)

        # level 1 -> level 2
        self.layer4 = AttentionDownsamplingModule(c1, c1 // 2, scale_factor=2)
        self.layer5 = ConvLayer(c1 // 2 + c2, c2)

        # level 2 -> level 3
        self.layer6 = AttentionDownsamplingModule(c2, c2 // 2, scale_factor=2)
        self.layer7 = ConvLayer(c2 // 2 + c3, c3)

        # NOTE: a reverse (upsampling) path, layer8-layer13, built from
        # AttentionUpsamplingModule and ConvLayer, is currently disabled.

    def forward(self, inputs):
        """Return the four refined feature maps, in input order.

        ``inputs`` must unpack into exactly four tensors. Every tensor after
        the first has to match the spatial size of the downsampled previous
        level, because the two are concatenated along ``dim=1``.
        """
        first, second, third, fourth = inputs

        refined = [self.layer1(first)]
        stages = (
            (self.layer2, self.layer3, second),
            (self.layer4, self.layer5, third),
            (self.layer6, self.layer7, fourth),
        )
        for downsample, fuse, skip in stages:
            carried = downsample(refined[-1])
            merged = torch.cat([carried, skip], dim=1)
            refined.append(fuse(merged))

        return refined
|
|
|
|
|
class EVPRefer_warp(PreTrainedModel):
    """``PreTrainedModel`` shell that owns an :class:`EVPRefer` instance.

    ``config`` is forwarded to the ``PreTrainedModel`` constructor only; the
    wrapped ``EVPRefer`` is always created with its default arguments.
    """

    config_class = EVPConfig

    def __init__(self, config):
        super().__init__(config)
        # Stored as ``model`` so wrapped parameters are prefixed ``model.``.
        self.model = EVPRefer()

    def forward(self, img, sentences):
        """Delegate to the wrapped model and return its result unchanged."""
        prediction = self.model(img, sentences)
        return prediction
|
|
|
|
|
class EVPRefer(nn.Module):
    """Text-conditioned mask predictor built on a latent-diffusion backbone.

    Construction loads ``./v1-inference.yaml`` (resolved against the current
    working directory), instantiates the model it describes and keeps two of
    its parts: ``first_stage_model`` as the image encoder (its decoder is
    deleted) and ``model`` wrapped in ``UNetWrapper``. A CLIP text model and
    tokenizer ("openai/clip-vit-large-patch14"), a ``TextAdapterRefer``, the
    ``InverseMultiAttentiveFeatureRefinement`` aggregator and a
    ``SimpleDecoding`` head complete the pipeline.

    Args:
        sd_path: optional checkpoint path. It is written to the config as
            ``ckpt_path`` when it names an existing path; otherwise
            ``ckpt_path`` is set to ``None``.
        base_size: forwarded to ``UNetWrapper``.
        token_embed_dim: text embedding width; sizes ``alpha`` and is passed
            to ``TextAdapterRefer`` as ``text_dim``.
        neck_dim: channel sizes passed to ``SimpleDecoding`` as ``dims``.
            ``None`` selects ``[320, 680, 1320, 1280]``.
        **args: accepted for call-site compatibility and ignored.
    """

    def __init__(self,
                 sd_path=None,
                 base_size=512,
                 token_embed_dim=768,
                 neck_dim=None,
                 **args):
        super().__init__()

        # Build the default per call instead of sharing one mutable list
        # object across every instance through the signature.
        if neck_dim is None:
            neck_dim = [320, 680, 1320, 1280]

        config = OmegaConf.load('./v1-inference.yaml')
        # Check for None explicitly: formatting None into a string used to
        # probe the filesystem for a file literally called "None".
        if sd_path is not None and os.path.exists(str(sd_path)):
            config.model.params.ckpt_path = f'{sd_path}'
        else:
            config.model.params.ckpt_path = None

        sd_model = instantiate_from_config(config.model)
        self.encoder_vq = sd_model.first_stage_model
        self.unet = UNetWrapper(sd_model.model, base_size=base_size)
        # The built-in text conditioner and the image decoder are not used
        # by this model, so they are dropped.
        del sd_model.cond_stage_model
        del self.encoder_vq.decoder
        for param in self.encoder_vq.parameters():
            param.requires_grad = True

        self.text_adapter = TextAdapterRefer(text_dim=token_embed_dim)

        self.classifier = SimpleDecoding(dims=neck_dim)

        # Learnable per-dimension scale handed to the text adapter.
        self.alpha = nn.Parameter(torch.ones(token_embed_dim) * 1e-4)

        # NOTE(review): these widths are fixed and do not follow ``neck_dim``;
        # presumably both must match the UNet feature widths - confirm.
        self.aggregation = InverseMultiAttentiveFeatureRefinement([320, 680, 1320, 1280])
        self.clip_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

        for param in self.clip_model.parameters():
            param.requires_grad = True

    def forward(self, img, sentences):
        """Predict a class-index mask for ``img`` given ``sentences``.

        Args:
            img: image batch; assumed to be a 4-D
                (batch, 3, height, width) float tensor, since it is
                normalised with three-channel statistics and resized with
                bilinear interpolation - TODO confirm the value range
                expected by callers.
            sentences: text passed to the CLIP tokenizer.

        Returns:
            A NumPy array: the argmax over dimension 1 of the prediction,
            resized to the spatial size of ``img``, moved to the CPU and
            squeezed.
        """
        image_t = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])(img)
        shape = image_t.shape
        # The backbone always runs at a fixed 512x512 resolution.
        img = F.interpolate(image_t, (512, 512), mode='bilinear', align_corners=True)

        input_ids = self.tokenizer(text=sentences, truncation=True, max_length=40, return_length=True,
                                   return_overflowing_tokens=False, padding="max_length",
                                   return_tensors="pt")['input_ids'].to(image_t.device)

        input_shape = img.shape[-2:]

        latents = self.encoder_vq.encode(img).mode()
        # Fixed latent scaling constant; its origin is not shown in this file.
        latents = latents / 4.7164

        l_feats = self.clip_model(input_ids=input_ids).last_hidden_state
        c_crossattn = self.text_adapter(latents, l_feats, self.alpha)
        # A constant timestep of 1 is used for every sample in the batch.
        t = torch.ones((img.shape[0],), device=img.device).long()
        outs = self.unet(latents, t, c_crossattn=[c_crossattn])

        outs = self.aggregation(outs)

        # The aggregator returns its maps in input order; the classifier
        # receives them in reverse.
        x_c1, x_c2, x_c3, x_c4 = outs
        x = self.classifier(x_c4, x_c3, x_c2, x_c1)
        # Upsample to the 512x512 working size, then to the original size.
        x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=True)
        pred = F.interpolate(x, shape[2:], mode='bilinear', align_corners=True)
        # ``detach()`` already yields a tensor without autograd history, so
        # the former extra ``.data`` access was redundant.
        # NOTE(review): the result is a NumPy array, so the pass above could
        # run under ``torch.no_grad()`` to save memory - confirm no caller
        # relies on autograd state.
        output_mask = pred.detach().cpu().argmax(1).numpy().squeeze()
        return output_mask

    def get_latent(self, x):
        """Return the mode of the encoder's latent distribution for ``x``."""
        return self.encoder_vq.encode(x).mode()
|
|