# Copyright (c) Facebook, Inc. and its affiliates.
from typing import Tuple

import torch
from torch import nn
from torch.nn import functional as F

from detectron2.config import configurable
from detectron2.data import MetadataCatalog
from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, build_sem_seg_head
from detectron2.modeling.backbone import Backbone
from detectron2.modeling.postprocessing import sem_seg_postprocess
from detectron2.structures import ImageList
from detectron2.utils.memory import _ignore_torch_cuda_oom

import numpy as np
from einops import rearrange
from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator


@META_ARCH_REGISTRY.register()
class CATSeg(nn.Module):
    """CAT-Seg semantic segmentation meta-architecture with optional SAM refinement.

    CLIP dense image features (from ``sem_seg_head.predictor.clip_model``) and
    backbone features are passed to ``sem_seg_head``, which produces per-class
    prediction maps. At inference time the per-pixel predictions can optionally
    be snapped onto masks from a Segment Anything automatic mask generator
    (``use_sam``), and a simple panoptic result can be built from those masks
    (``panoptic_on``). Both switches are plain attributes, off by default.
    """

    # SAM variant / checkpoint loaded at construction time. The checkpoint is a
    # relative path, so it is resolved against the current working directory.
    SAM_MODEL_TYPE = "vit_h"
    SAM_CHECKPOINT = "sam_vit_h_4b8939.pth"

    @configurable
    def __init__(
        self,
        *,
        backbone: Backbone,
        sem_seg_head: nn.Module,
        size_divisibility: int,
        pixel_mean: Tuple[float],
        pixel_std: Tuple[float],
        clip_pixel_mean: Tuple[float],
        clip_pixel_std: Tuple[float],
        train_class_json: str,
        test_class_json: str,
        sliding_window: bool,
        clip_finetune: str,
        backbone_multiplier: float,
        clip_pretrained: str,
    ):
        """
        Args:
            backbone: a backbone module, must follow detectron2's backbone interface
            sem_seg_head: a module that predicts semantic segmentation from backbone features
            size_divisibility: padding multiple used by ``ImageList.from_tensors``;
                a negative value means "use ``backbone.size_divisibility``".
            pixel_mean, pixel_std: per-channel normalization of the backbone input.
            clip_pixel_mean, clip_pixel_std: per-channel normalization of the CLIP input.
            train_class_json, test_class_json: stored as attributes.
            sliding_window: if True, evaluation goes through :meth:`inference_sliding_window`.
            clip_finetune: which CLIP *visual* parameters stay trainable:
                "prompt" (names containing "prompt"), "attention" (names containing
                "attn" or "position"), "full" (all visual parameters); any other
                value freezes them. Non-visual CLIP parameters are always frozen.
            backbone_multiplier: the backbone is trainable only when this is > 0;
                parameters whose name contains "norm0" are always frozen.
            clip_pretrained: CLIP model name; "ViT-B/16" selects a 384x384 CLIP
                input resolution, anything else 336x336.
        """
        super().__init__()
        self.backbone = backbone
        self.sem_seg_head = sem_seg_head
        if size_divisibility < 0:
            size_divisibility = self.backbone.size_divisibility
        self.size_divisibility = size_divisibility

        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
        self.register_buffer("clip_pixel_mean", torch.Tensor(clip_pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("clip_pixel_std", torch.Tensor(clip_pixel_std).view(-1, 1, 1), False)

        self.train_class_json = train_class_json
        self.test_class_json = test_class_json

        self.clip_finetune = clip_finetune
        for name, params in self.sem_seg_head.predictor.clip_model.named_parameters():
            if "visual" not in name:
                params.requires_grad = False
            elif clip_finetune == "prompt":
                params.requires_grad = "prompt" in name
            elif clip_finetune == "attention":
                params.requires_grad = "attn" in name or "position" in name
            elif clip_finetune == "full":
                params.requires_grad = True
            else:
                params.requires_grad = False

        finetune_backbone = backbone_multiplier > 0.0
        for name, params in self.backbone.named_parameters():
            params.requires_grad = finetune_backbone and "norm0" not in name

        self.sliding_window = sliding_window
        self.clip_resolution = (384, 384) if clip_pretrained == "ViT-B/16" else (336, 336)
        # Switched to True by forward() after a CUDA OOM in batched window inference.
        self.sequential = False

        # SAM-based post-processing switches (not exposed through the config).
        self.use_sam = False
        self.panoptic_on = False
        # Minimum (visible area / original mask area) for a mask to become a segment.
        self.overlap_threshold = 0.8
        # Category ids treated as "things" (never merged) in panoptic_inference.
        # NOTE(review): hard-coded; presumably dataset-specific -- confirm.
        self.panoptic_thing_classes = (3, 6)

        self.sam = sam_model_registry[self.SAM_MODEL_TYPE](checkpoint=self.SAM_CHECKPOINT).to(self.device)
        # SAM only supplies inference-time mask proposals. Freeze it so its
        # parameters are neither handed to the optimizer nor reported as
        # gradient-less "unused" parameters during (distributed) training.
        self.sam.requires_grad_(False)
        # Only points_per_side is overridden; every other generator option keeps
        # the SamAutomaticMaskGenerator default.
        self.mask = SamAutomaticMaskGenerator(self.sam, output_mode="binary_mask", points_per_side=32)

    @classmethod
    def from_config(cls, cfg):
        """Map a detectron2 config onto the keyword arguments of ``__init__``."""
        backbone = build_backbone(cfg)
        sem_seg_head = build_sem_seg_head(cfg, backbone.output_shape())

        return {
            "backbone": backbone,
            "sem_seg_head": sem_seg_head,
            "size_divisibility": cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY,
            "pixel_mean": cfg.MODEL.PIXEL_MEAN,
            "pixel_std": cfg.MODEL.PIXEL_STD,
            "clip_pixel_mean": cfg.MODEL.CLIP_PIXEL_MEAN,
            "clip_pixel_std": cfg.MODEL.CLIP_PIXEL_STD,
            "train_class_json": cfg.MODEL.SEM_SEG_HEAD.TRAIN_CLASS_JSON,
            "test_class_json": cfg.MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON,
            "sliding_window": cfg.TEST.SLIDING_WINDOW,
            "clip_finetune": cfg.MODEL.SEM_SEG_HEAD.CLIP_FINETUNE,
            "backbone_multiplier": cfg.SOLVER.BACKBONE_MULTIPLIER,
            "clip_pretrained": cfg.MODEL.SEM_SEG_HEAD.CLIP_PRETRAINED,
        }

    @property
    def device(self):
        """Device of the model, taken from the ``pixel_mean`` buffer."""
        return self.pixel_mean.device

    def _generate_sam_masks(self, image):
        """Run SAM's automatic mask generator on a (C, H, W) image tensor.

        The tensor is moved to CPU, transposed to (H, W, C) and cast with
        ``np.uint8`` (no rescaling) before being passed to ``self.mask.generate``.
        """
        return self.mask.generate(np.uint8(image.permute(1, 2, 0).cpu().numpy()))

    def forward(self, batched_inputs):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
                Each item in the list contains the inputs for one image.

                For now, each item in the list is a dict that contains:
                   * "image": Tensor, image in (C, H, W) format.
                   * "sem_seg": per-pixel ground truth class ids (used in training).
                   * Other information that's included in the original dicts, such as:
                     "height", "width" (int): the output resolution of the model (may be different
                     from input resolution), used in inference.

        Returns:
            In training: ``{"loss_sem_seg": loss}``, a one-vs-all binary
            cross-entropy over all classes and pixels.

            In inference: list[dict] with one dict for the first image, containing
                * "sem_seg": A Tensor that represents the per-pixel segmentation predicted
                  by the head. The prediction has shape KxHxW that represents the logits of
                  each class for each pixel.
        """
        images = [x["image"].to(self.device) for x in batched_inputs]
        # Un-normalized images are kept for SAM's mask generator.
        sam_images = images

        if not self.training and self.sliding_window:
            if not self.sequential:
                # Try batched window inference first; if it hits a CUDA OOM,
                # fall through and switch to processing windows one by one.
                with _ignore_torch_cuda_oom():
                    return self.inference_sliding_window(batched_inputs)
                self.sequential = True
            return self.inference_sliding_window(batched_inputs)

        clip_images = [(x - self.clip_pixel_mean) / self.clip_pixel_std for x in images]
        clip_images = ImageList.from_tensors(clip_images, self.size_divisibility)

        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)

        clip_images = F.interpolate(
            clip_images.tensor, size=self.clip_resolution, mode="bilinear", align_corners=False
        )
        clip_features = self.sem_seg_head.predictor.clip_model.encode_image(clip_images, dense=True)

        images_resized = F.interpolate(images.tensor, size=(384, 384), mode="bilinear", align_corners=False)
        features = self.backbone(images_resized)

        outputs = self.sem_seg_head(clip_features, features)

        if self.training:
            targets = torch.stack([x["sem_seg"].to(self.device) for x in batched_inputs], dim=0)
            outputs = F.interpolate(
                outputs, size=(targets.shape[-2], targets.shape[-1]), mode="bilinear", align_corners=False
            )

            num_classes = outputs.shape[1]
            mask = targets != self.sem_seg_head.ignore_value

            outputs = outputs.permute(0, 2, 3, 1)
            # One-hot targets on valid pixels; ignored pixels keep an all-zero
            # target and therefore still enter the loss as negatives.
            _targets = torch.zeros(outputs.shape, device=self.device)
            _onehot = F.one_hot(targets[mask], num_classes=num_classes).float()
            _targets[mask] = _onehot

            loss = F.binary_cross_entropy_with_logits(outputs, _targets)
            return {"loss_sem_seg": loss}

        image_size = images.image_sizes[0]
        if self.use_sam:
            masks = self._generate_sam_masks(sam_images[0])
            # Alternative refinement: self.continuous_semantic_inference(outputs, masks, image_size)
            outputs, _ = self.discrete_semantic_inference(outputs, masks, image_size)

        height = batched_inputs[0].get("height", image_size[0])
        width = batched_inputs[0].get("width", image_size[1])

        output = sem_seg_postprocess(outputs[0], image_size, height, width)
        return [{"sem_seg": output}]

    @torch.no_grad()
    def inference_sliding_window(self, batched_inputs, kernel=384, overlap=0.333, out_res=(640, 640)):
        """Sliding-window inference for ``batched_inputs[0]``.

        The image is resized to ``out_res``, cut into ``kernel`` x ``kernel``
        windows with stride ``int(kernel * (1 - overlap))``, and a copy of the
        whole image resized to ``kernel`` x ``kernel`` is appended as a global
        view. Per-window sigmoid scores are folded back (overlapping regions are
        averaged) and then averaged with the upsampled global-view scores.

        Args:
            batched_inputs: same format as in :meth:`forward`; only the first
                item is processed.
            kernel: window size in pixels.
            overlap: fraction of overlap between neighbouring windows.
            out_res: (H, W) working resolution; any sequence of two ints.

        Returns:
            list[dict] with one dict containing "sem_seg" and, when
            ``panoptic_on`` is set, "panoptic_seg" as returned by
            :meth:`panoptic_inference`.
        """
        # Tuple default avoids a shared mutable default; a list is needed for
        # the ``[1] + out_res`` concatenation below.
        out_res = list(out_res)
        images = [x["image"].to(self.device, dtype=torch.float32) for x in batched_inputs]
        stride = int(kernel * (1 - overlap))
        unfold = nn.Unfold(kernel_size=kernel, stride=stride)
        fold = nn.Fold(out_res, kernel_size=kernel, stride=stride)

        image = F.interpolate(images[0].unsqueeze(0), size=out_res, mode="bilinear", align_corners=False).squeeze()
        sam_image = image
        image = rearrange(unfold(image), "(C H W) L-> L C H W", C=3, H=kernel)
        global_image = F.interpolate(
            images[0].unsqueeze(0), size=(kernel, kernel), mode="bilinear", align_corners=False
        )
        image = torch.cat((image, global_image), dim=0)

        images = (image - self.pixel_mean) / self.pixel_std
        clip_images = (image - self.clip_pixel_mean) / self.clip_pixel_std
        clip_images = F.interpolate(clip_images, size=self.clip_resolution, mode="bilinear", align_corners=False)
        clip_features = self.sem_seg_head.predictor.clip_model.encode_image(clip_images, dense=True)

        if self.sequential:
            outputs = []
            for clip_feat, window in zip(clip_features, images):
                feature = self.backbone(window.unsqueeze(0))
                output = self.sem_seg_head(clip_feat.unsqueeze(0), feature)
                outputs.append(output[0])
            outputs = torch.stack(outputs, dim=0)
        else:
            features = self.backbone(images)
            outputs = self.sem_seg_head(clip_features, features)

        outputs = F.interpolate(outputs, size=kernel, mode="bilinear", align_corners=False)
        outputs = outputs.sigmoid()

        # The last entry is the global view; the rest are the windows.
        global_output = outputs[-1:]
        global_output = F.interpolate(global_output, size=out_res, mode="bilinear", align_corners=False)
        outputs = outputs[:-1]
        # Dividing by the folded all-ones map averages overlapping windows.
        outputs = fold(outputs.flatten(1).T) / fold(unfold(torch.ones([1] + out_res, device=self.device)))
        outputs = (outputs + global_output) / 2.0

        height = batched_inputs[0].get("height", out_res[0])
        width = batched_inputs[0].get("width", out_res[1])
        catseg_outputs = sem_seg_postprocess(outputs[0], out_res, height, width)

        # SAM is expensive: only run it when one of its consumers is enabled.
        if not (self.use_sam or self.panoptic_on):
            return [{"sem_seg": catseg_outputs}]

        masks = self._generate_sam_masks(sam_image)
        if self.use_sam:
            # Alternative refinement: self.continuous_semantic_inference(outputs, masks, out_res)
            outputs, sam_cls = self.discrete_semantic_inference(outputs, masks, out_res)
            output = sem_seg_postprocess(outputs[0], out_res, height, width)
        else:
            # Panoptic output still needs one class per SAM mask.
            _, sam_cls = self.discrete_semantic_inference(outputs, masks, out_res)
            output = catseg_outputs

        ret = [{"sem_seg": output}]
        if self.panoptic_on:
            ret[0]["panoptic_seg"] = self.panoptic_inference(
                catseg_outputs, masks, sam_cls, size=output.shape[-2:]
            )
        return ret

    def discrete_semantic_inference(self, outputs, masks, image_size):
        """Assign each SAM mask the majority predicted class of its pixels.

        Args:
            outputs: per-class prediction map with a leading batch dimension;
                only batch element 0 is used for voting.
            masks: SAM mask records; each provides a boolean "segmentation"
                array (assumed to be ``image_size``) and a "stability_score".
            image_size: (H, W) to which ``outputs`` is bilinearly resized.

        Returns:
            (sam_outputs, sam_classes): ``sam_outputs`` is a CPU tensor shaped like
            the resized ``outputs``, zero except that, for each mask, its
            stability score is written into the channel of its majority class
            over the mask's pixels (a later mask of the same class overwrites an
            earlier one where they overlap). ``sam_classes`` is a float tensor
            holding the majority class of each mask (0 for empty masks).
        """
        catseg_outputs = F.interpolate(outputs, size=image_size, mode="bilinear", align_corners=True)
        sam_outputs = torch.zeros_like(catseg_outputs).cpu()
        pixel_classes = catseg_outputs.argmax(dim=1)[0].cpu()
        sam_classes = torch.zeros(len(masks))
        for i, record in enumerate(masks):
            m = torch.as_tensor(record["segmentation"], dtype=torch.bool)
            votes = pixel_classes[m]
            if votes.numel() == 0:
                # Empty mask: bincount().argmax() would raise; nothing to paint.
                continue
            idx = int(votes.bincount().argmax())
            sam_outputs[0, idx][m] = record["stability_score"]
            sam_classes[i] = idx
        return sam_outputs, sam_classes

    def continuous_semantic_inference(self, outputs, masks, image_size, scale=100 / 7.0):
        """Re-render the prediction as a score-weighted sum of SAM masks.

        Each mask gets a class vector: the mask-averaged prediction of batch
        element 0, L1-normalized per mask. The output at each pixel is the sum,
        over masks covering it, of ``predicted_iou * class_vector``.

        Args:
            outputs: per-class prediction map with a leading batch dimension.
            masks: SAM mask records with "segmentation" and "predicted_iou".
            image_size: (H, W) to which ``outputs`` is bilinearly resized.
            scale: unused; kept for interface compatibility.

        Returns:
            (output, mask_cls): ``output`` has shape (1, C, H, W); ``mask_cls``
            has shape (num_masks, C).
        """
        catseg_outputs = F.interpolate(outputs, size=image_size, mode="bilinear", align_corners=True)[0].cpu()
        mask_pred = torch.tensor(np.asarray([x["segmentation"] for x in masks]), dtype=torch.float32)  # N H W
        mask_score = torch.tensor(np.asarray([x["predicted_iou"] for x in masks]), dtype=torch.float32)  # N

        mask_cls = torch.einsum("nhw, chw -> nc", mask_pred, catseg_outputs)
        # Mean over each mask's pixels; clamps keep degenerate (empty / all-zero)
        # masks at 0 instead of producing NaN.
        mask_norm = mask_pred.sum(-1).sum(-1).clamp_min(1.0)
        mask_cls = mask_cls / mask_norm[:, None]
        mask_cls = mask_cls / mask_cls.norm(p=1, dim=1).clamp_min(1e-12)[:, None]

        mask_logits = mask_pred * mask_score[:, None, None]
        output = torch.einsum("nhw, nc -> chw", mask_logits, mask_cls)
        return output.unsqueeze(0), mask_cls

    def continuous_semantic_inference2(self, outputs, masks, image_size, scale=100 / 7.0, img=None, text=None):
        """Like :meth:`continuous_semantic_inference`, but classifies masks via ``text``.

        Each mask pools the (resized) per-pixel embedding ``img``; the pooled
        vector is L2-normalized, scored against ``text`` with a factor of 100,
        softmaxed over classes and L1-normalized per mask.

        Args:
            outputs: unused; kept for interface compatibility.
            masks: SAM mask records with "segmentation" and "predicted_iou".
            image_size: (H, W) to which ``img`` is bilinearly resized.
            scale: unused; kept for interface compatibility.
            img: dense embedding with a leading batch dimension; batch element 0
                is used. Assumed (1, D, h, w) -- TODO confirm with caller.
            text: class embeddings, assumed (C, D) -- TODO confirm with caller.

        Returns:
            (output, mask_cls): ``output`` has shape (1, C, H, W); ``mask_cls``
            has shape (num_masks, C).

        Raises:
            ValueError: if ``img`` or ``text`` is missing.
        """
        if img is None or text is None:
            raise ValueError("continuous_semantic_inference2 requires both `img` and `text`")
        img = F.interpolate(img, size=image_size, mode="bilinear", align_corners=True)[0].cpu()
        img = img.permute(1, 2, 0)

        mask_pred = torch.tensor(np.asarray([x["segmentation"] for x in masks]), dtype=torch.float32)  # N H W
        mask_score = torch.tensor(np.asarray([x["predicted_iou"] for x in masks]), dtype=torch.float32)  # N

        mask_pool = torch.einsum("nhw, hwd -> nd", mask_pred, img)
        mask_pool = mask_pool / mask_pool.norm(dim=1, keepdim=True)
        mask_cls = torch.einsum("nd, cd -> nc", 100 * mask_pool, text.cpu())
        mask_cls = mask_cls.softmax(dim=1)

        # The per-mask area division is a row scale that the L1 normalization
        # below cancels; it is kept to mirror continuous_semantic_inference.
        mask_norm = mask_pred.sum(-1).sum(-1).clamp_min(1.0)
        mask_cls = mask_cls / mask_norm[:, None]
        mask_cls = mask_cls / mask_cls.norm(p=1, dim=1).clamp_min(1e-12)[:, None]

        mask_logits = mask_pred * mask_score[:, None, None]
        output = torch.einsum("nhw, nc -> chw", mask_logits, mask_cls)
        return output.unsqueeze(0), mask_cls

    def panoptic_inference(self, outputs, masks, sam_classes, size=None):
        """Build a panoptic segmentation from SAM masks and their classes.

        Each pixel goes to the mask with the highest ``predicted_iou`` among the
        masks covering it. A mask becomes a segment only if it keeps at least
        ``overlap_threshold`` of its original area. Non-"thing" classes (see
        ``panoptic_thing_classes``) are merged into one segment per class.

        Args:
            outputs: semantic prediction; only its last two dims (H, W) are used.
            masks: SAM mask records with "segmentation" and "predicted_iou".
            sam_classes: either a 1-D tensor of per-mask class ids (as returned by
                :meth:`discrete_semantic_inference`) or a 2-D (num_masks, C)
                score tensor (as returned by :meth:`continuous_semantic_inference`),
                which is reduced with argmax.
            size: unused; kept for interface compatibility.

        Returns:
            (panoptic_seg, segments_info): an int32 (H, W) segment-id map
            (0 = unassigned) and a list of dicts with keys "id", "isthing",
            "category_id".
        """
        h, w = outputs.shape[-2:]
        if len(masks) == 0:
            # Nothing detected; np.asarray([]) would not have the (N, H, W) shape.
            return torch.zeros((h, w), dtype=torch.int32), []

        cur_scores = torch.tensor(np.asarray([x["predicted_iou"] for x in masks]))
        cur_masks = torch.tensor(np.asarray([x["segmentation"] for x in masks]))
        cur_masks = F.interpolate(cur_masks.unsqueeze(0).float(), size=outputs.shape[-2:], mode="nearest")[0]

        if sam_classes.dim() > 1:
            cur_classes = sam_classes.argmax(dim=-1)
        else:
            # Already one class id per mask; argmax here would collapse to a scalar.
            cur_classes = sam_classes.long()

        cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks
        panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device)
        segments_info = []
        current_segment_id = 0

        cur_mask_ids = cur_prob_masks.argmax(0)
        stuff_memory_list = {}
        for k in range(cur_classes.shape[0]):
            pred_class = int(cur_classes[k].item())
            isthing = pred_class in self.panoptic_thing_classes
            mask = cur_mask_ids == k
            mask_area = mask.sum().item()
            original_area = (cur_masks[k] >= 0.5).sum().item()

            if mask_area == 0 or original_area == 0:
                continue
            if mask_area / original_area < self.overlap_threshold:
                continue

            # Merge all regions of the same "stuff" class into one segment.
            if not isthing:
                if pred_class in stuff_memory_list:
                    panoptic_seg[mask] = stuff_memory_list[pred_class]
                    continue
                stuff_memory_list[pred_class] = current_segment_id + 1

            current_segment_id += 1
            panoptic_seg[mask] = current_segment_id
            segments_info.append(
                {
                    "id": current_segment_id,
                    "isthing": bool(isthing),
                    "category_id": pred_class,
                }
            )

        return panoptic_seg, segments_info