import torch
import torch.nn as nn
from torchvision import transforms as T
from omegaconf import OmegaConf
from typing import List
from mmseg import datasets as mmseg_datasets
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
import numpy as np
from PIL import Image

from detectron2.data import MetadataCatalog
from detectron2.utils.visualizer import Visualizer

# TCL
from models import build_model
from models.tcl.pamr import PAMR
from datasets.builder import build_text_transform
from segmentation.evaluation.builder import build_dataset_class_tokens

# Color palette for visualization; repeated so there are enough colors for
# arbitrary numbers of query classes.
PALETTE = mmseg_datasets.PascalVOCDataset.PALETTE + mmseg_datasets.COCOStuffDataset.PALETTE
PALETTE *= 5


def build_demo_model(ckpt_path="./tcl.pth", size=224):
    # Load TCL model
    print(f"Load {ckpt_path} ...")
    ckpt = torch.load(ckpt_path)
    cfg = OmegaConf.load("./tcl/configs/tcl.yml")
    model = build_model(cfg.model)
    # The (minimal) checkpoint only contains the learned parameters; the frozen
    # CLIP parameters are not included, hence strict=False.
    model.load_state_dict(ckpt["model"], strict=False)
    model.eval()

    # build TCLDemo
    demo = TCLDemo(model, size)

    return demo


def _convert_image_to_rgb(image):
    return image.convert("RGB")


def _transform(n_px):
    return T.Compose([
        T.Resize(n_px, interpolation=T.InterpolationMode.BICUBIC),
        _convert_image_to_rgb,
        T.ToTensor(),
        T.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
    ])


class TCLDemo(nn.Module):
    """
    Args:
        model: TCL model
        size: resize the shorter side of the image to `size`
    """
    def __init__(self, model, size=224):
        super().__init__()
        self.model = model
        self.size = size
        self.preprocess = _transform(size)
        self.tokenizer = build_text_transform()
        self.pamr = PAMR(10, [1, 2, 4, 8, 12, 24]).eval()

    @property
    def device(self):
        return next(self.model.parameters()).device

    def build_text_embedding(self, texts: List[str]):
        text_tokens = build_dataset_class_tokens(self.tokenizer, "custom", texts)
        text_embeddings = self.model.build_text_embedding(text_tokens)
        return text_embeddings

    def forward(self, image, texts: List[str], apply_pamr=True):
        """
        Args:
            image: PIL.Image
            texts: List[str]
        """
        # A leading "bg"/"background" entry requests an explicit background channel.
        with_bg = False
        if texts[0] in ["bg", "background"]:
            with_bg = True
            texts = texts[1:]

        # preprocess
        image = self.preprocess(image).unsqueeze(0).to(self.device)
        text_embs = self.build_text_embedding(texts)

        # forward
        mask, simmap = self.model.generate_masks(image, text_embs)

        # refinement
        if apply_pamr:
            mask = self.pamr(image, mask)

        I, N, H, W = mask.shape  # (batch, num_texts, height, width)
        if with_bg:
            # Prepend a constant-threshold background channel: pixels whose
            # foreground scores all fall below the threshold become background.
            bg_thresh = 0.4 if apply_pamr else 0.5
            bg = torch.full(
                [I, 1, H, W], bg_thresh, dtype=torch.float, device=mask.device
            )
            mask = torch.cat([bg, mask], dim=1)

        return mask

    def visualize(self, image, texts, mask):
        """
        Args:
            image (PIL.Image)
            texts (List[str])
            mask (Tensor)
        """
        with_bg = texts[0] in ["bg", "background"]
        N = len(texts)
        if with_bg:
            palette = PALETTE
        else:
            palette = PALETTE[1:]

        # Register (or refresh) a throwaway metadata entry for the visualizer.
        MetadataCatalog.pop("__unused", None)
        md = MetadataCatalog.get("__unused")
        md.set(
            thing_classes=texts,
            thing_colors=palette,
            stuff_classes=texts,
            stuff_colors=palette,
        )

        seg_res = mask.squeeze(0).argmax(0).cpu()
        if with_bg:
            # Map background pixels to an out-of-range label so draw_sem_seg skips them.
            seg_res[seg_res == 0] = N + 10

        image = image.resize(mask.shape[2:][::-1])  # PIL expects (width, height)
        image = np.asarray(image)

        visualizer = Visualizer(image, md)

        r = visualizer.draw_sem_seg(seg_res)
        res = Image.fromarray(r.get_image())

        return res

    def forward_vis(self, image, texts, apply_pamr=True):
        mask = self(image, texts, apply_pamr=apply_pamr)
        res = self.visualize(image, texts, mask)
        return res
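

if __name__ == "__main__":
    # Usage sketch: run the demo end-to-end on a single image.
    # "./example.jpg", "./demo_result.png", and the class list below are
    # placeholder values; the checkpoint and config paths come from the
    # build_demo_model() defaults above.
    demo = build_demo_model(ckpt_path="./tcl.pth", size=224)
    demo = demo.to("cuda" if torch.cuda.is_available() else "cpu")

    img = Image.open("./example.jpg")
    # A leading "background" entry adds an explicit background channel to the mask.
    class_names = ["background", "cat", "dog"]

    with torch.no_grad():
        result = demo.forward_vis(img, class_names, apply_pamr=True)
    result.save("./demo_result.png")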