|
import torch |
|
import torch.nn as nn |
|
from torchvision import transforms as T |
|
from omegaconf import OmegaConf |
|
from typing import List |
|
from mmseg import datasets as mmseg_datasets |
|
|
|
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD |
|
|
|
import numpy as np |
|
from PIL import Image |
|
from detectron2.data import MetadataCatalog |
|
from detectron2.utils.visualizer import Visualizer |
|
|
|
|
|
from models import build_model |
|
from models.tcl.pamr import PAMR |
|
from datasets.builder import build_text_transform |
|
from segmentation.evaluation.builder import build_dataset_class_tokens |
|
|
|
# Visualization palette: concatenate the Pascal VOC and COCO-Stuff palettes,
# then tile the result 5x so there are enough distinct colors for arbitrary
# user-supplied class lists in the demo.
PALETTE = mmseg_datasets.PascalVOCDataset.PALETTE + mmseg_datasets.COCOStuffDataset.PALETTE
PALETTE *= 5
|
|
|
|
|
def build_demo_model(ckpt_path="./tcl.pth", size=224):
    """Restore a TCL model from a checkpoint and wrap it in a TCLDemo.

    Args:
        ckpt_path: path to the checkpoint file; expected to be a dict with a
            "model" key holding the state dict.
        size: shorter-side resize target used by the demo preprocessing.

    Returns:
        TCLDemo instance in eval mode (on CPU; move it with `.to(...)` if needed).
    """
    print(f"Load {ckpt_path} ...")
    # map_location="cpu" so loading works on CPU-only machines even when the
    # checkpoint was saved from CUDA tensors; callers pick the device afterwards.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    cfg = OmegaConf.load("./tcl/configs/tcl.yml")
    model = build_model(cfg.model)

    # strict=False: the training checkpoint may carry extra keys not present in
    # the inference graph.  NOTE(review): this also silently ignores *missing*
    # keys — confirm the checkpoint actually matches the config.
    model.load_state_dict(ckpt['model'], strict=False)
    model.eval()

    demo = TCLDemo(model, size)

    return demo
|
|
|
|
|
def _convert_image_to_rgb(image): |
|
return image.convert("RGB") |
|
|
|
|
|
def _transform(n_px):
    """Build the demo preprocessing pipeline.

    Resizes the shorter image side to `n_px` (bicubic), forces RGB, converts
    to a tensor, and applies ImageNet mean/std normalization.
    """
    steps = [
        T.Resize(n_px, interpolation=T.InterpolationMode.BICUBIC),
        _convert_image_to_rgb,
        T.ToTensor(),
        T.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
    ]
    return T.Compose(steps)
|
|
|
|
|
class TCLDemo(nn.Module):
    """
    Open-vocabulary segmentation demo wrapper around a TCL model.

    Args:
        model: TCL model
        size: resize shorter side of image to `size`
    """
    def __init__(self, model, size=224):
        super().__init__()
        self.model = model  # wrapped TCL model (mask + text-embedding provider)
        self.size = size    # shorter-side resize target for preprocessing

        self.preprocess = _transform(size)       # PIL image -> normalized tensor
        self.tokenizer = build_text_transform()  # class names -> token ids
        # PAMR mask refinement: 10 iterations over the listed dilation radii,
        # used to sharpen mask boundaries at inference time.
        self.pamr = PAMR(10, [1, 2, 4, 8, 12, 24]).eval()

    @property
    def device(self):
        # Device of the wrapped model's parameters; inputs are moved here.
        return next(self.model.parameters()).device

    def build_text_embedding(self, texts: List[str]):
        """Tokenize the class names and embed them with the model's text encoder."""
        text_tokens = build_dataset_class_tokens(self.tokenizer, "custom", texts)
        text_embeddings = self.model.build_text_embedding(text_tokens)
        return text_embeddings

    def forward(self, image, texts: List[str], apply_pamr=True):
        """
        Compute per-class mask scores for `image`.

        Args:
            image: PIL.Image
            texts: List[str]; if the first entry is "bg"/"background", a
                constant-threshold background channel is prepended to the result
                instead of being predicted.
            apply_pamr: refine the masks with PAMR before returning.

        Returns:
            Tensor of shape [1, C, H, W] mask scores, where C = len(texts)
            (the synthesized bg channel, when present, is channel 0).
        """
        with_bg = False
        if texts[0] in ["bg", "background"]:
            with_bg = True
            texts = texts[1:]  # bg is synthesized below, not sent to the model

        image = self.preprocess(image).unsqueeze(0).to(self.device)
        text_embs = self.build_text_embedding(texts)

        mask, simmap = self.model.generate_masks(
            image,
            text_embs,
        )

        if apply_pamr:
            mask = self.pamr(image, mask)

        # NOTE: `T` here shadows the torchvision-transforms alias inside this
        # method only; it is unpacked but never used.
        I, T, H, W = mask.shape
        if with_bg:
            # Background is a constant score map: a pixel becomes "background"
            # at argmax time when no class mask exceeds this threshold (looser
            # threshold when PAMR has sharpened the masks).
            bg_thresh = 0.4 if apply_pamr else 0.5
            bg = torch.full(
                [I, 1, H, W],
                bg_thresh,
                dtype=torch.float,
                device=mask.device
            )
            mask = torch.cat([bg, mask], dim=1)

        return mask

    def visualize(self, image, texts, mask):
        """
        Overlay the argmax segmentation of `mask` on `image`.

        Args:
            image (PIL.Image)
            texts (List[str]): same list passed to `forward`
            mask (Tensor): output of `forward`, shape [1, C, H, W]
        """
        with_bg = texts[0] in ["bg", "background"]

        N = len(texts)
        if with_bg:
            palette = PALETTE
        else:
            # Skip the first palette entry; presumably reserved for background —
            # NOTE(review): confirm this offset is intentional.
            palette = PALETTE[1:]

        # MetadataCatalog is a process-global registry; drop any stale entry so
        # repeated demo calls with different class lists do not collide.
        MetadataCatalog.pop("__unused", None)
        md = MetadataCatalog.get("__unused")
        md.set(
            thing_classes=texts,
            thing_colors=palette,
            stuff_classes=texts,
            stuff_colors=palette,
        )

        seg_res = mask.squeeze(0).argmax(0).cpu()
        if with_bg:
            # NOTE(review): presumably remaps background pixels (label 0) to an
            # out-of-range label so the visualizer leaves them unshaded —
            # confirm against detectron2's draw_sem_seg behavior.
            seg_res[seg_res == 0] = N + 10

        # mask is [., ., H, W] while PIL's resize expects (width, height),
        # hence the reversal of the trailing dims.
        image = image.resize(mask.shape[2:][::-1])
        image = np.asarray(image)

        visualizer = Visualizer(image, md)
        r = visualizer.draw_sem_seg(seg_res)

        res = Image.fromarray(r.get_image())

        return res

    def forward_vis(self, image, texts, apply_pamr=True):
        # Convenience wrapper: predict masks, then render the overlay image.
        mask = self(image, texts, apply_pamr=apply_pamr)
        res = self.visualize(image, texts, mask)

        return res
|
|