"""Referring-segmentation demo.

Segments the object described by ``sentence`` in ``image_path`` using a
Swin/LAVT image backbone, a multimodal BERT and a MaskFormer head, then
saves a red overlay of the predicted mask.
"""
from PIL import Image
import torchvision.transforms as T
import numpy as np
import datetime
import os
import time
import torch
import torch.utils.data
from torch import nn
from bert.multimodal_bert import MultiModalBert
from bert.tokenization_bert import BertTokenizer
import torchvision
from lib import multimodal_segmentation_ppm
import utils
import torch.nn.functional as F
from modeling.MaskFormerModel import MaskFormerHead
from addict import Dict
import cv2
import textwrap

image_path = './image001.png'
sentence = 'spoon on the dish'
weights = 'checkpoints/gradio.pth'
device = 'cpu'

INPUT_SIZE = 480  # square resolution the image is resized to before inference
MAX_TOKENS = 20   # fixed token length the sentence is truncated/padded to


class WrapperModel(nn.Module):
    """Chain an image backbone, a multimodal BERT and a segmentation head.

    The image and language encoders are run stage by stage so that each
    stage's ``forward_pwam{k}`` modules can exchange information between the
    two modalities. The per-stage image residuals are passed to
    ``classifier`` as the feature dict ``{'s1': ..., 's4': ...}``.
    """

    # BERT-style config dict; not read anywhere in this file.
    config = Dict({
        "architectures": [
            "BertForMaskedLM"
        ],
        "attention_probs_dropout_prob": 0.1,
        "gradient_checkpointing": False,
        "hidden_act": "gelu",
        "hidden_dropout_prob": 0.1,
        "hidden_size": 512,
        "initializer_range": 0.02,
        "intermediate_size": 3072,
        "layer_norm_eps": 1e-12,
        "model_type": "bert",
        "num_attention_heads": 8,
        "num_hidden_layers": 8,
        "pad_token_id": 0,
        "position_embedding_type": "absolute",
        "transformers_version": "4.6.0.dev0",
        "type_vocab_size": 2,
        "use_cache": True,
        "vocab_size": 30522
    })

    def __init__(self, image_model, language_model, classifier, num_classes=1):
        """Store the three sub-modules.

        Args:
            image_model: backbone exposing ``forward_stem``,
                ``forward_stage{1..4}`` and ``forward_pwam{1..4}``.
            language_model: text encoder exposing the same method family.
            classifier: head called with the ``s1``..``s4`` feature dict.
            num_classes: number of foreground classes, used by
                ``_get_binary_mask``. The default of 1 matches
                ``cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES`` in this script.
        """
        super().__init__()
        self.image_model = image_model
        self.language_model = language_model
        self.classifier = classifier
        # _get_binary_mask reads this; previously it was never set, so that
        # helper raised AttributeError whenever it was called.
        self.num_classes = num_classes

    def _get_binary_mask(self, target):
        """Return one binary mask per class for a 2-D integer label map.

        ``target`` must be a 2-D tensor of class ids in ``[0, num_classes]``
        (an integer dtype, as required by ``scatter``). Label 0 is dropped
        from the result, so the output has shape ``(num_classes, H, W)``.
        """
        y, x = target.size()
        # Allocate on target's device so scatter does not mix devices.
        target_onehot = torch.zeros(self.num_classes + 1, y, x, device=target.device)
        target_onehot = target_onehot.scatter(dim=0, index=target.unsqueeze(0), value=1)
        return target_onehot[1:]

    def semantic_inference(self, mask_cls, mask_pred):
        """Combine per-query class scores and mask logits into per-class maps.

        ``mask_cls`` is ``(B, Q, C)`` and ``mask_pred`` is ``(B, Q, H, W)``
        (ranks fixed by the einsum). Returns a ``(B, C-1, H, W)`` tensor.
        """
        # NOTE(review): the softmax runs over dim=1 (the query axis) and the
        # first class column is dropped. Standard MaskFormer instead takes the
        # softmax over the class axis and drops the trailing "no object"
        # column. If the head emits a single query (NUM_OBJECT_QUERIES = 1),
        # this softmax is identically 1 and the result is just
        # mask_pred.sigmoid(). The behaviour is kept unchanged, presumably to
        # match how the checkpoint was trained — confirm before changing.
        mask_cls = F.softmax(mask_cls, dim=1)[..., 1:]
        mask_pred = mask_pred.sigmoid()
        return torch.einsum("bqc,bqhw->bchw", mask_cls, mask_pred)

    def forward(self, image, sentences, attentions):
        """Jointly encode image and text, then run the segmentation head.

        Args:
            image: image batch fed to ``image_model.forward_stem``.
            sentences: token ids fed to ``language_model.forward_stem``.
            attentions: token mask (1 = real token, 0 = padding).

        Returns:
            Whatever ``self.classifier`` returns for the ``s1``..``s4`` dict.
        """
        l_mask = attentions.unsqueeze(dim=-1)

        img_feat, wh, ww = self.image_model.forward_stem(image)
        lang_feat, extended_attention_mask = self.language_model.forward_stem(sentences, attentions)

        features = {}
        # Stages 1..4 are identical apart from the method suffix.
        for k in range(1, 5):
            image_stage = getattr(self.image_model, f"forward_stage{k}")
            lang_stage = getattr(self.language_model, f"forward_stage{k}")
            image_pwam = getattr(self.image_model, f"forward_pwam{k}")
            lang_pwam = getattr(self.language_model, f"forward_pwam{k}")

            img_feat = image_stage(img_feat, wh, ww)
            lang_feat = lang_stage(lang_feat, extended_attention_mask)
            # The image PWAM yields (residual, H, W, next_input, Wh, Ww). The
            # residual goes to the head; next_input feeds the next stage.
            residual, _, _, img_next, wh, ww = image_pwam(img_feat, wh, ww, lang_feat, l_mask)
            # The language PWAM must see this stage's output, i.e. before
            # img_feat is replaced by img_next.
            _, lang_feat = lang_pwam(img_feat, lang_feat, extended_attention_mask)
            img_feat = img_next
            features[f"s{k}"] = residual

        return self.classifier(features)


# --- pre-process the input image ---
img = Image.open(image_path).convert("RGB")
img_ndarray = np.array(img)  # (orig_h, orig_w, 3); for visualization
original_w, original_h = img.size  # PIL .size returns width first and height second

image_transforms = T.Compose([
    T.Resize((INPUT_SIZE, INPUT_SIZE)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
img = image_transforms(img).unsqueeze(0).to(device)  # (1, 3, INPUT_SIZE, INPUT_SIZE)

# --- pre-process the raw sentence ---
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
sentence_tokenized = tokenizer.encode(text=sentence, add_special_tokens=True)
# NOTE: a sentence longer than MAX_TOKENS is cut here, which presumably also
# drops the trailing special token added by the tokenizer. This is kept as-is,
# assuming it mirrors training-time preprocessing — confirm before changing.
sentence_tokenized = sentence_tokenized[:MAX_TOKENS]

num_pad = MAX_TOKENS - len(sentence_tokenized)
# Pad ids with 0; the attention mask is 1 for real tokens and 0 for padding.
padded_sent_toks = sentence_tokenized + [0] * num_pad
attention_mask = [1] * len(sentence_tokenized) + [0] * num_pad
padded_sent_toks = torch.tensor(padded_sent_toks).unsqueeze(0).to(device)  # (1, MAX_TOKENS)
attention_mask = torch.tensor(attention_mask).unsqueeze(0).to(device)  # (1, MAX_TOKENS)


# --- build the model and load weights ---
class args:
    """Minimal stand-in for a parsed config, consumed by the 'lavt' builder."""
    swin_type = 'base'
    window12 = True
    mha = ''
    fusion_drop = 0.0


single_model = multimodal_segmentation_ppm.__dict__['lavt'](pretrained='', args=args)
single_model.to(device)
model_class = MultiModalBert
single_bert_model = model_class.from_pretrained('bert-base-uncased', embed_dim=single_model.backbone.embed_dim)
single_bert_model.pooler = None

# Channels/strides of the four backbone feature maps fed to the head.
input_shape = {
    's1': Dict({'channel': 128, 'stride': 4}),
    's2': Dict({'channel': 256, 'stride': 8}),
    's3': Dict({'channel': 512, 'stride': 16}),
    's4': Dict({'channel': 1024, 'stride': 32}),
}

cfg = Dict()
cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE = 4
cfg.MODEL.MASK_FORMER.DROPOUT = 0.0
cfg.MODEL.MASK_FORMER.NHEADS = 8
cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 4
cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM = 256
cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256
cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["s1", "s2", "s3", "s4"]
cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 1
cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256
cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 1
cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048
cfg.MODEL.MASK_FORMER.DEC_LAYERS = 10
cfg.MODEL.MASK_FORMER.PRE_NORM = False

maskformer_head = MaskFormerHead(cfg, input_shape)
model = WrapperModel(single_model.backbone, single_bert_model, maskformer_head)

checkpoint = torch.load(weights, map_location='cpu')
# strict=False tolerates key mismatches (presumably needed because the BERT
# pooler was removed above). Report them so a wrong checkpoint does not
# silently leave layers at their initial values.
load_result = model.load_state_dict(checkpoint['model'], strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"load_state_dict: {len(load_result.missing_keys)} missing keys, "
          f"{len(load_result.unexpected_keys)} unexpected keys")
model.to(device)
model.eval()

# --- inference ---
# no_grad: pure inference, so do not build an autograd graph.
with torch.no_grad():
    output = model(img, padded_sent_toks, attention_mask)[0]
    mask_cls_results = output["pred_logits"]
    mask_pred_results = output["pred_masks"]
    target_shape = img_ndarray.shape[:2]
    mask_pred_results = F.interpolate(mask_pred_results, size=(INPUT_SIZE, INPUT_SIZE),
                                      mode='bilinear', align_corners=True)
    pred_masks = model.semantic_inference(mask_cls_results, mask_pred_results)
    # Default 'nearest' resize back to the original image size.
    output = F.interpolate(pred_masks, target_shape)

output = (output > 0.5).cpu().numpy()


# --- show/save results ---
def overlay_davis(image, mask, colors=((0, 0, 0), (255, 0, 0)), cscale=1, alpha=0.4):
    """Tint every non-zero label of ``mask`` onto ``image`` and outline it.

    Args:
        image: ``(H, W, 3)`` array.
        mask: ``(H, W)`` integer label map. Label 0 is left untouched; every
            other label ``k`` is blended with ``colors[k]``.
        colors: RGB triples indexed by label.
        cscale: multiplier applied to ``colors``.
        alpha: weight of the original pixel in the blend; the color gets
            ``1 - alpha``.

    Returns:
        The overlaid image, cast to ``image.dtype``. Pixels on each object's
        one-pixel dilation border are set to 0.
    """
    from scipy.ndimage import binary_dilation

    colors = np.atleast_2d(np.reshape(colors, (-1, 3))) * cscale
    im_overlay = image.copy()
    object_ids = np.unique(mask)
    # Skip label 0 explicitly; slicing off the first unique id would wrongly
    # drop a real object whenever the mask contains no background pixels.
    for object_id in object_ids[object_ids != 0]:
        binary_mask = mask == object_id
        # Blend the color only where the object is.
        im_overlay[binary_mask] = image[binary_mask] * alpha + (1 - alpha) * colors[object_id]
        contours = binary_dilation(binary_mask) ^ binary_mask
        im_overlay[contours, :] = 0
    return im_overlay.astype(image.dtype)


output = output.astype(np.uint8)  # (B, C, orig_h, orig_w), np.uint8

# Overlay the mask on the image.
visualization = overlay_davis(img_ndarray, output[0][0])  # red
visualization = Image.fromarray(visualization)
# visualization.show()

# Save the visualization, creating the output directory if needed.
out_path = './demo/spoon_on_the_dish.jpg'
os.makedirs(os.path.dirname(out_path), exist_ok=True)
visualization.save(out_path)