# Hosting-page residue kept for provenance (not part of the module):
# "JohanDL's picture", "initial commit", "history blame", "7.59 kB".
"""
Author: Siyuan Li
Licensed: Apache-2.0 License
"""
from typing import List, Union
import numpy as np
import pycocotools.mask as mask_util
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmengine.logging import print_log
from torch import Tensor
from projects.Detic_new.detic import Detic
def encode_mask_results(mask_results):
    """Encode bitmap masks to RLE code.

    Args:
        mask_results (list): bitmap mask results. Each mask is indexed as
            ``mask[:, :, np.newaxis]``, so it is presumably a 2-D (H, W)
            array -- TODO confirm against callers.

    Returns:
        list: RLE encoded masks, one entry per input mask.
    """
    encoded_mask_results = []
    for mask in mask_results:
        # The mask is copied into a column-major (Fortran-order) uint8 array
        # with a trailing singleton axis before being handed to pycocotools.
        bitmap = np.array(mask[:, :, np.newaxis], order="F", dtype="uint8")
        # NOTE(review): the head of this call was missing in the reviewed
        # text; `mask_util.encode(...)[0]` (take the RLE of the single
        # channel) is reconstructed -- confirm against the upstream helper.
        encoded_mask_results.append(mask_util.encode(bitmap)[0])  # encoded with RLE
    return encoded_mask_results
class CLIPTextEncoder(nn.Module):
    """Text encoder that tokenizes strings and runs them through a CLIP model.

    The CLIP model is loaded on CPU via ``clip.load`` and stored as
    ``self.clip``; ``self.tokenizer`` is a ``SimpleTokenizer`` instance.
    """

    def __init__(self, model_name="ViT-B/32"):
        """Load the CLIP model named ``model_name`` and create the tokenizer."""
        # nn.Module must be initialised before a submodule (self.clip) can be
        # assigned, otherwise the assignment below raises.
        super().__init__()
        # Local imports: `clip` is only needed once an encoder is constructed.
        import clip
        from clip.simple_tokenizer import SimpleTokenizer

        self.tokenizer = SimpleTokenizer()
        pretrained_model, _ = clip.load(model_name, device="cpu")
        self.clip = pretrained_model

    # NOTE(review): `device` and `dtype` read like property accessors, but no
    # decorator is visible in the reviewed text -- confirm whether callers
    # expect `@property` here before changing how they are accessed.
    def device(self):
        """Return ``self.clip.device``."""
        return self.clip.device

    def dtype(self):
        """Return ``self.clip.dtype``."""
        return self.clip.dtype

    def tokenize(
        self, texts: Union[str, List[str]], context_length: int = 77
    ) -> torch.LongTensor:
        """Convert text(s) into a zero-padded tensor of token ids.

        Args:
            texts: a single string or a list of strings.
            context_length: number of token slots per row.

        Returns:
            A ``torch.long`` tensor of shape ``(len(texts), context_length)``.
            Each row is ``[start] + tokens + [end]`` followed by zeros. A
            sequence longer than ``context_length`` is cropped to a randomly
            positioned contiguous window, so its start/end token may be lost
            and the result is not deterministic for such inputs.
        """
        if isinstance(texts, str):
            texts = [texts]

        sot_token = self.tokenizer.encoder["<|startoftext|>"]
        eot_token = self.tokenizer.encoder["<|endoftext|>"]
        all_tokens = [
            [sot_token] + self.tokenizer.encode(text) + [eot_token] for text in texts
        ]
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

        for i, tokens in enumerate(all_tokens):
            if len(tokens) > context_length:
                # Pick a random window start so the crop fits exactly.
                st = torch.randint(len(tokens) - context_length + 1, (1,))[0].item()
                tokens = tokens[st : st + context_length]
            result[i, : len(tokens)] = torch.tensor(tokens)

        return result

    def forward(self, text):
        """Tokenize ``text`` and return ``self.clip.encode_text`` of the tokens."""
        text = self.tokenize(text)
        text_features = self.clip.encode_text(text)
        return text_features
def get_class_weight(original_caption, prompt_prefix="a "):
    """Resolve class names and compute their CLIP text embeddings.

    Args:
        original_caption (str | iterable of str): one of the dataset keys
            ``"coco"``, ``"cityscapes"``, ``"voc"``, ``"openimages"``,
            ``"lvis"`` (class names are read from that dataset's
            ``METAINFO["classes"]``); any other string is treated as a
            caption whose class names are separated by ``" . "``; a
            non-string is converted with ``list(...)`` and used as-is.
        prompt_prefix (str): text prepended to every class name before
            encoding.

    Returns:
        tuple: ``(class_names, embeddings)`` where ``embeddings`` is the
        detached text-encoder output with its two axes swapped
        (``permute(1, 0)``), made contiguous and moved to CPU.
    """
    # Dataset key -> name of the dataset class exported by mmdet.datasets.
    dataset_classes = {
        "coco": "CocoDataset",
        "cityscapes": "CityscapesDataset",
        "voc": "VOCDataset",
        "openimages": "OpenImagesDataset",
        "lvis": "LVISV1Dataset",
    }

    if isinstance(original_caption, str):
        if original_caption in dataset_classes:
            # Local import, as in the original per-dataset imports.
            import mmdet.datasets as mmdet_datasets

            dataset_cls = getattr(mmdet_datasets, dataset_classes[original_caption])
            class_names = dataset_cls.METAINFO["classes"]
        else:
            # Free-form caption such as "cat . dog".
            # NOTE(review): a caption that already ends with "." is split
            # unchanged, so its last name keeps the trailing " ." -- confirm
            # whether that is intended.
            if not original_caption.endswith("."):
                original_caption = original_caption + " . "
            class_names = [name for name in original_caption.split(" . ") if name]
    else:
        # Already a collection of class names.
        class_names = list(original_caption)

    text_encoder = CLIPTextEncoder()
    texts = [prompt_prefix + x for x in class_names]
    print_log(f"Computing text embeddings for {len(class_names)} classes.")
    embeddings = text_encoder(texts).detach().permute(1, 0).contiguous().cpu()
    return class_names, embeddings
def reset_cls_layer_weight(roi_head, weight):
    """Replace the ``zs_weight`` of every bbox head classifier in ``roi_head``.

    Args:
        roi_head: object whose ``bbox_head`` is iterable; each head exposes
            ``num_classes`` and ``fc_cls.zs_weight``.
        weight (str | Tensor): either a path loadable with ``np.load``
            (the loaded array is transposed to D x C) or a tensor that is
            used as-is (expected to already be D x C).

    Side effects:
        For every head, ``num_classes`` is set to the last dimension of the
        new weight (which includes the appended zero column) and
        ``fc_cls.zs_weight`` is deleted and re-assigned.
    """
    if isinstance(weight, str):
        print_log(f"Resetting cls_layer_weight from file: {weight}")
        zs_weight = (
            torch.tensor(np.load(weight), dtype=torch.float32)
            .permute(1, 0)
            .contiguous()
        )  # D x C
    else:
        zs_weight = weight

    # Append one all-zero column.
    zs_weight = torch.cat(
        [zs_weight, zs_weight.new_zeros((zs_weight.shape[0], 1))], dim=1
    )  # D x (C + 1)
    # L2-normalise each column; the zero column stays zero.
    zs_weight = F.normalize(zs_weight, p=2, dim=0)
    # Same placement as before when CUDA is present; fall back to CPU instead
    # of failing on machines without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    zs_weight = zs_weight.to(device)
    num_classes = zs_weight.shape[-1]

    for bbox_head in roi_head.bbox_head:
        bbox_head.num_classes = num_classes
        # Delete first so the old attribute does not block the re-assignment.
        del bbox_head.fc_cls.zs_weight
        bbox_head.fc_cls.zs_weight = zs_weight
class DeticMasa(Detic):
    """Detic variant whose ``predict`` consumes pre-computed features."""

    # NOTE(review): `MODELS` is imported at file level but no registration
    # decorator is visible on this class in the reviewed text -- confirm
    # whether `@MODELS.register_module()` is expected here.

    def predict(
        self,
        batch_inputs: Tensor,
        detection_features: Tensor,
        batch_data_samples: SampleList,
        rescale: bool = True,
    ) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            batch_inputs (Tensor): Inputs with shape (N, C, H, W). Not read
                by this method; features come from ``detection_features``.
            detection_features (Tensor): Pre-extracted features passed
                directly to the RPN head and the RoI head in place of
                ``self.extract_feat(batch_inputs)``.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool): Whether to rescale the results.
                Defaults to True.

        Returns:
            list[:obj:`DetDataSample`]: Return the detection results of the
            input images. The returns value is DetDataSample,
            which usually contain 'pred_instances'. And the
            ``pred_instances`` usually contains following keys.

                - scores (Tensor): Classification scores, has a shape
                    (num_instance, )
                - labels (Tensor): Labels of bboxes, has a shape
                    (num_instances, ).
                - bboxes (Tensor): Has a shape (num_instances, 4),
                    the last dimension 4 arrange as (x1, y1, x2, y2).
                - masks (Tensor): Has a shape (num_instances, H, W).
        """
        # For single image inference
        if "custom_entities" in batch_data_samples[0]:
            text_prompts = batch_data_samples[0].text
            # Only recompute the text embeddings when the prompt changed.
            if text_prompts != self._text_prompts:
                self._text_prompts = text_prompts
                class_names, zs_weight = get_class_weight(text_prompts)
                self._entities = class_names
                reset_cls_layer_weight(self.roi_head, zs_weight)

        assert self.with_bbox, "Bbox head must be implemented."

        # Features are supplied by the caller instead of
        # self.extract_feat(batch_inputs).
        x = detection_features

        # If there are no pre-defined proposals, use RPN to get proposals
        if batch_data_samples[0].get("proposals", None) is None:
            rpn_results_list = self.rpn_head.predict(
                x, batch_data_samples, rescale=False
            )
        else:
            rpn_results_list = [
                data_sample.proposals for data_sample in batch_data_samples
            ]

        results_list = self.roi_head.predict(
            x, rpn_results_list, batch_data_samples, rescale=rescale
        )

        for data_sample, pred_instances in zip(batch_data_samples, results_list):
            if len(pred_instances) > 0:
                # NOTE(review): the loop body was missing in the reviewed
                # text; mapping each label through `self._entities` is
                # reconstructed from the assignment above -- confirm.
                label_names = [
                    self._entities[labels] for labels in pred_instances.labels
                ]
                # for visualization
                pred_instances.label_names = label_names
            data_sample.pred_instances = pred_instances

        return batch_data_samples