# tcl/predictor.py
import torch
import torch.nn as nn
from torchvision import transforms as T
from omegaconf import OmegaConf
from typing import List
from mmseg import datasets as mmseg_datasets
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
import numpy as np
from PIL import Image
from detectron2.data import MetadataCatalog
from detectron2.utils.visualizer import Visualizer

# TCL
from models import build_model
from models.tcl.pamr import PAMR
from datasets.builder import build_text_transform
from segmentation.evaluation.builder import build_dataset_class_tokens

# Color palette for visualization, repeated so there are enough colors for many classes.
PALETTE = mmseg_datasets.PascalVOCDataset.PALETTE + mmseg_datasets.COCOStuffDataset.PALETTE
PALETTE *= 5


def build_demo_model(ckpt_path="./tcl.pth", size=224):
    # Load the TCL config and checkpoint, and build the model.
    print(f"Load {ckpt_path} ...")
    ckpt = torch.load(ckpt_path)
    cfg = OmegaConf.load("./tcl/configs/tcl.yml")
    model = build_model(cfg.model)
    # The (minimal) checkpoint only contains the learned parameters; the frozen
    # CLIP parameters are not included, hence strict=False.
    model.load_state_dict(ckpt['model'], strict=False)
    model.eval()

    # Wrap the model in the demo pipeline.
    demo = TCLDemo(model, size)
    return demo


def _convert_image_to_rgb(image):
    return image.convert("RGB")


def _transform(n_px):
    return T.Compose([
        T.Resize(n_px, interpolation=T.InterpolationMode.BICUBIC),
        _convert_image_to_rgb,
        T.ToTensor(),
        T.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
    ])


class TCLDemo(nn.Module):
    """Demo wrapper around the TCL model for text-driven segmentation.

    Args:
        model: TCL model
        size: resize the shorter side of the input image to `size`
    """

    def __init__(self, model, size=224):
        super().__init__()
        self.model = model
        self.size = size
        self.preprocess = _transform(size)
        self.tokenizer = build_text_transform()
        # PAMR (pixel-adaptive mask refinement) used to sharpen mask boundaries
        self.pamr = PAMR(10, [1, 2, 4, 8, 12, 24]).eval()

    @property
    def device(self):
        return next(self.model.parameters()).device

    def build_text_embedding(self, texts: List[str]):
        # Tokenize the class names and encode them into text embeddings.
        text_tokens = build_dataset_class_tokens(self.tokenizer, "custom", texts)
        text_embeddings = self.model.build_text_embedding(text_tokens)
        return text_embeddings

    def forward(self, image, texts: List[str], apply_pamr=True):
        """
        Args:
            image: PIL.Image
            texts: List[str]

        Returns:
            mask (Tensor): soft masks of shape [1, T, H, W], one channel per
                class name (plus a leading background channel if requested).
        """
        # A leading "bg"/"background" class is handled by thresholding rather
        # than by a text embedding.
        with_bg = False
        if texts[0] in ["bg", "background"]:
            with_bg = True
            texts = texts[1:]

        # preprocess
        image = self.preprocess(image).unsqueeze(0).to(self.device)
        text_embs = self.build_text_embedding(texts)

        # forward
        mask, simmap = self.model.generate_masks(
            image,
            text_embs,
        )

        # refinement
        if apply_pamr:
            mask = self.pamr(image, mask)

        I, T, H, W = mask.shape
        if with_bg:
            # Pixels whose class scores all fall below the threshold are
            # assigned to the background channel after argmax.
            bg_thresh = 0.4 if apply_pamr else 0.5
            bg = torch.full(
                [I, 1, H, W],
                bg_thresh,
                dtype=torch.float,
                device=mask.device,
            )
            mask = torch.cat([bg, mask], dim=1)

        return mask

    def visualize(self, image, texts, mask):
        """
        Args:
            image (PIL.Image)
            texts (List[str])
            mask (Tensor)
        """
        with_bg = texts[0] in ["bg", "background"]
        N = len(texts)
        if with_bg:
            palette = PALETTE
        else:
            palette = PALETTE[1:]

        # Register a throwaway metadata entry for the detectron2 visualizer.
        MetadataCatalog.pop("__unused", None)
        md = MetadataCatalog.get("__unused")
        md.set(
            thing_classes=texts,
            thing_colors=palette,
            stuff_classes=texts,
            stuff_colors=palette,
        )

        # Per-pixel class index map
        seg_res = mask.squeeze(0).argmax(0).cpu()
        if with_bg:
            # Move background pixels to a label outside the class range so
            # draw_sem_seg skips them.
            seg_res[seg_res == 0] = N + 10

        image = image.resize(mask.shape[2:][::-1])
        image = np.asarray(image)
        visualizer = Visualizer(image, md)
        r = visualizer.draw_sem_seg(seg_res)
        res = Image.fromarray(r.get_image())

        return res

    def forward_vis(self, image, texts, apply_pamr=True):
        mask = self(image, texts, apply_pamr=apply_pamr)
        res = self.visualize(image, texts, mask)
        return res
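

if __name__ == "__main__":
    # Example usage sketch: assumes the default checkpoint at ./tcl.pth and the
    # config at ./tcl/configs/tcl.yml are present, and "demo.jpg" stands in for
    # any local image. Starting the class list with "background" enables the
    # thresholded background channel.
    demo = build_demo_model()
    if torch.cuda.is_available():
        demo = demo.cuda()

    img = Image.open("demo.jpg")
    classes = ["background", "dog", "cat"]
    with torch.no_grad():
        vis = demo.forward_vis(img, classes)
    vis.save("demo_result.png")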