Spaces:
Sleeping
Sleeping
import cv2 | |
import torch | |
import re | |
import numpy as np | |
from typing import List, Union | |
import nltk | |
import inflect | |
from transformers import AutoTokenizer | |
from torchvision import transforms as T | |
from maskrcnn_benchmark.modeling.detector import build_detection_model | |
from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer | |
from maskrcnn_benchmark.structures.image_list import to_image_list | |
from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou | |
from maskrcnn_benchmark.structures.bounding_box import BoxList | |
from maskrcnn_benchmark import layers as L | |
from maskrcnn_benchmark.modeling.roi_heads.mask_head.inference import Masker | |
from maskrcnn_benchmark.utils import cv2_util | |
engine = inflect.engine() | |
nltk.download("punkt") | |
nltk.download("averaged_perceptron_tagger") | |
import timeit | |
class GLIPDemo(object): | |
def __init__( | |
self, | |
cfg, | |
confidence_threshold=0.7, | |
min_image_size=None, | |
show_mask_heatmaps=False, | |
masks_per_dim=5, | |
): | |
self.cfg = cfg.clone() | |
self.model = build_detection_model(cfg) | |
self.model.eval() | |
self.device = torch.device(cfg.MODEL.DEVICE) | |
self.model.to(self.device) | |
self.min_image_size = min_image_size | |
self.show_mask_heatmaps = show_mask_heatmaps | |
self.masks_per_dim = masks_per_dim | |
save_dir = cfg.OUTPUT_DIR | |
checkpointer = DetectronCheckpointer(cfg, self.model, save_dir=save_dir) | |
_ = checkpointer.load(cfg.MODEL.WEIGHT) | |
self.transforms = self.build_transform() | |
# used to make colors for each tokens | |
mask_threshold = -1 if show_mask_heatmaps else 0.5 | |
self.masker = Masker(threshold=mask_threshold, padding=1) | |
self.palette = torch.tensor([2**25 - 1, 2**15 - 1, 2**21 - 1]) | |
self.cpu_device = torch.device("cpu") | |
self.confidence_threshold = confidence_threshold | |
self.tokenizer = self.build_tokenizer() | |
def build_transform(self): | |
""" | |
Creates a basic transformation that was used to train the models | |
""" | |
cfg = self.cfg | |
# we are loading images with OpenCV, so we don't need to convert them | |
# to BGR, they are already! So all we need to do is to normalize | |
# by 255 if we want to convert to BGR255 format, or flip the channels | |
# if we want it to be in RGB in [0-1] range. | |
if cfg.INPUT.TO_BGR255: | |
to_bgr_transform = T.Lambda(lambda x: x * 255) | |
else: | |
to_bgr_transform = T.Lambda(lambda x: x[[2, 1, 0]]) | |
normalize_transform = T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD) | |
transform = T.Compose( | |
[ | |
T.ToPILImage(), | |
T.Resize(self.min_image_size) if self.min_image_size is not None else lambda x: x, | |
T.ToTensor(), | |
to_bgr_transform, | |
normalize_transform, | |
] | |
) | |
return transform | |
def build_tokenizer(self): | |
cfg = self.cfg | |
tokenizer = None | |
if cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "bert-base-uncased": | |
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") | |
if cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "roberta-base": | |
tokenizer = AutoTokenizer.from_pretrained("roberta-base") | |
elif cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "clip": | |
from transformers import CLIPTokenizerFast | |
if cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS: | |
tokenizer = CLIPTokenizerFast.from_pretrained( | |
"openai/clip-vit-base-patch32", from_slow=True, mask_token="ðŁĴij</w>" | |
) | |
else: | |
tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32", from_slow=True) | |
return tokenizer | |
def run_ner(self, caption, refexp_mode=False): | |
noun_phrases = find_noun_phrases(caption) | |
noun_phrases = [remove_punctuation(phrase) for phrase in noun_phrases] | |
noun_phrases = [phrase for phrase in noun_phrases if phrase != ""] | |
relevant_phrases = noun_phrases | |
if refexp_mode: | |
noun_phrases = [caption] | |
relevant_phrases = [caption] | |
labels = noun_phrases | |
self.entities = labels | |
tokens_positive = [] | |
for entity, label in zip(relevant_phrases, labels): | |
try: | |
# search all occurrences and mark them as different entities | |
for m in re.finditer(entity, caption.lower()): | |
tokens_positive.append([[m.start(), m.end()]]) | |
except: | |
print("noun entities:", noun_phrases) | |
print("entity:", entity) | |
print("caption:", caption.lower()) | |
return tokens_positive | |
def inference(self, original_image, original_caption, refexp_mode=False): | |
predictions = self.compute_prediction(original_image, original_caption, refexp_mode) | |
top_predictions = self._post_process_fixed_thresh(predictions) | |
return top_predictions | |
def run_on_web_image(self, original_image, original_caption, thresh=0.5, refexp_mode=False): | |
predictions = self.compute_prediction(original_image, original_caption, refexp_mode) | |
top_predictions = self._post_process(predictions, thresh) | |
result = original_image.copy() | |
if self.show_mask_heatmaps: | |
return self.create_mask_montage(result, top_predictions) | |
result = self.overlay_boxes(result, top_predictions) | |
result = self.overlay_entity_names(result, top_predictions) | |
if self.cfg.MODEL.MASK_ON: | |
result = self.overlay_mask(result, top_predictions) | |
return result, top_predictions | |
def compute_prediction(self, original_image, original_caption, refexp_mode=False): | |
# image | |
image = self.transforms(original_image) | |
image_list = to_image_list(image, self.cfg.DATALOADER.SIZE_DIVISIBILITY) | |
image_list = image_list.to(self.device) | |
# caption | |
tokenized = self.tokenizer([original_caption], return_tensors="pt") | |
tokens_positive = self.run_ner(original_caption, refexp_mode) | |
# process positive map | |
positive_map = create_positive_map(tokenized, tokens_positive) | |
if self.cfg.MODEL.RPN_ARCHITECTURE == "VLDYHEAD": | |
plus = 1 | |
else: | |
plus = 0 | |
positive_map_label_to_token = create_positive_map_label_to_token_from_positive_map(positive_map, plus=plus) | |
self.plus = plus | |
self.positive_map_label_to_token = positive_map_label_to_token | |
tic = timeit.time.perf_counter() | |
# compute predictions | |
with torch.no_grad(): | |
predictions = self.model(image_list, captions=[original_caption], positive_map=positive_map_label_to_token) | |
predictions = [o.to(self.cpu_device) for o in predictions] | |
print("inference time per image: {}".format(timeit.time.perf_counter() - tic)) | |
# always single image is passed at a time | |
prediction = predictions[0] | |
# reshape prediction (a BoxList) into the original image size | |
height, width = original_image.shape[:-1] | |
prediction = prediction.resize((width, height)) | |
if prediction.has_field("mask"): | |
# if we have masks, paste the masks in the right position | |
# in the image, as defined by the bounding boxes | |
masks = prediction.get_field("mask") | |
# always single image is passed at a time | |
masks = self.masker([masks], [prediction])[0] | |
prediction.add_field("mask", masks) | |
return prediction | |
def _post_process_fixed_thresh(self, predictions): | |
scores = predictions.get_field("scores") | |
labels = predictions.get_field("labels").tolist() | |
thresh = scores.clone() | |
for i, lb in enumerate(labels): | |
if isinstance(self.confidence_threshold, float): | |
thresh[i] = self.confidence_threshold | |
elif len(self.confidence_threshold) == 1: | |
thresh[i] = self.confidence_threshold[0] | |
else: | |
thresh[i] = self.confidence_threshold[lb - 1] | |
keep = torch.nonzero(scores > thresh).squeeze(1) | |
predictions = predictions[keep] | |
scores = predictions.get_field("scores") | |
_, idx = scores.sort(0, descending=True) | |
return predictions[idx] | |
def _post_process(self, predictions, threshold=0.5): | |
scores = predictions.get_field("scores") | |
labels = predictions.get_field("labels").tolist() | |
thresh = scores.clone() | |
for i, lb in enumerate(labels): | |
if isinstance(self.confidence_threshold, float): | |
thresh[i] = threshold | |
elif len(self.confidence_threshold) == 1: | |
thresh[i] = threshold | |
else: | |
thresh[i] = self.confidence_threshold[lb - 1] | |
keep = torch.nonzero(scores > thresh).squeeze(1) | |
predictions = predictions[keep] | |
scores = predictions.get_field("scores") | |
_, idx = scores.sort(0, descending=True) | |
return predictions[idx] | |
def compute_colors_for_labels(self, labels): | |
""" | |
Simple function that adds fixed colors depending on the class | |
""" | |
colors = (30 * (labels[:, None] - 1) + 1) * self.palette | |
colors = (colors % 255).numpy().astype("uint8") | |
return colors | |
def overlay_boxes(self, image, predictions): | |
labels = predictions.get_field("labels") | |
boxes = predictions.bbox | |
colors = self.compute_colors_for_labels(labels).tolist() | |
for box, color in zip(boxes, colors): | |
box = box.to(torch.int64) | |
top_left, bottom_right = box[:2].tolist(), box[2:].tolist() | |
image = cv2.rectangle(image, tuple(top_left), tuple(bottom_right), tuple(color), 2) | |
return image | |
def overlay_scores(self, image, predictions): | |
scores = predictions.get_field("scores") | |
boxes = predictions.bbox | |
for box, score in zip(boxes, scores): | |
box = box.to(torch.int64) | |
image = cv2.putText( | |
image, | |
"%.3f" % score, | |
(int(box[0]), int((box[1] + box[3]) / 2)), | |
cv2.FONT_HERSHEY_DUPLEX , | |
0.3, | |
(255, 255, 255), | |
1, | |
) | |
return image | |
def overlay_entity_names(self, image, predictions, names=None): | |
scores = predictions.get_field("scores").tolist() | |
labels = predictions.get_field("labels").tolist() | |
new_labels = [] | |
if self.entities and self.plus: | |
for i in labels: | |
if i <= len(self.entities): | |
new_labels.append(self.entities[i - self.plus]) | |
else: | |
new_labels.append("object") | |
# labels = [self.entities[i - self.plus] for i in labels ] | |
else: | |
new_labels = ["object" for i in labels] | |
boxes = predictions.bbox | |
template = "{}: {:.2f}" | |
for box, score, label in zip(boxes, scores, new_labels): | |
x, y = box[:2] | |
s = template.format(label, score) | |
cv2.putText(image, s, (int(x), int(y + 10)), cv2.FONT_HERSHEY_DUPLEX, 0.5, (255, 255, 255), 1) | |
return image | |
def overlay_mask(self, image, predictions): | |
masks = predictions.get_field("mask").numpy() | |
labels = predictions.get_field("labels") | |
colors = self.compute_colors_for_labels(labels).tolist() | |
# import pdb | |
# pdb.set_trace() | |
# masks = masks > 0.1 | |
for mask, color in zip(masks, colors): | |
thresh = mask[0, :, :, None].astype(np.uint8) | |
contours, hierarchy = cv2_util.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) | |
image = cv2.drawContours(image, contours, -1, color, 2) | |
composite = image | |
return composite | |
def create_mask_montage(self, image, predictions): | |
masks = predictions.get_field("mask") | |
masks_per_dim = self.masks_per_dim | |
masks = L.interpolate(masks.float(), scale_factor=1 / masks_per_dim).byte() | |
height, width = masks.shape[-2:] | |
max_masks = masks_per_dim**2 | |
masks = masks[:max_masks] | |
# handle case where we have less detections than max_masks | |
if len(masks) < max_masks: | |
masks_padded = torch.zeros(max_masks, 1, height, width, dtype=torch.uint8) | |
masks_padded[: len(masks)] = masks | |
masks = masks_padded | |
masks = masks.reshape(masks_per_dim, masks_per_dim, height, width) | |
result = torch.zeros((masks_per_dim * height, masks_per_dim * width), dtype=torch.uint8) | |
for y in range(masks_per_dim): | |
start_y = y * height | |
end_y = (y + 1) * height | |
for x in range(masks_per_dim): | |
start_x = x * width | |
end_x = (x + 1) * width | |
result[start_y:end_y, start_x:end_x] = masks[y, x] | |
return cv2.applyColorMap(result.numpy(), cv2.COLORMAP_JET), None | |
def create_positive_map_label_to_token_from_positive_map(positive_map, plus=0): | |
positive_map_label_to_token = {} | |
for i in range(len(positive_map)): | |
positive_map_label_to_token[i + plus] = torch.nonzero(positive_map[i], as_tuple=True)[0].tolist() | |
return positive_map_label_to_token | |
def create_positive_map(tokenized, tokens_positive): | |
"""construct a map such that positive_map[i,j] = True iff box i is associated to token j""" | |
positive_map = torch.zeros((len(tokens_positive), 256), dtype=torch.float) | |
for j, tok_list in enumerate(tokens_positive): | |
for (beg, end) in tok_list: | |
try: | |
beg_pos = tokenized.char_to_token(beg) | |
end_pos = tokenized.char_to_token(end - 1) | |
except Exception as e: | |
print("beg:", beg, "end:", end) | |
print("token_positive:", tokens_positive) | |
# print("beg_pos:", beg_pos, "end_pos:", end_pos) | |
raise e | |
if beg_pos is None: | |
try: | |
beg_pos = tokenized.char_to_token(beg + 1) | |
if beg_pos is None: | |
beg_pos = tokenized.char_to_token(beg + 2) | |
except: | |
beg_pos = None | |
if end_pos is None: | |
try: | |
end_pos = tokenized.char_to_token(end - 2) | |
if end_pos is None: | |
end_pos = tokenized.char_to_token(end - 3) | |
except: | |
end_pos = None | |
if beg_pos is None or end_pos is None: | |
continue | |
assert beg_pos is not None and end_pos is not None | |
positive_map[j, beg_pos : end_pos + 1].fill_(1) | |
return positive_map / (positive_map.sum(-1)[:, None] + 1e-6) | |
def find_noun_phrases(caption: str) -> List[str]: | |
caption = caption.lower() | |
tokens = nltk.word_tokenize(caption) | |
pos_tags = nltk.pos_tag(tokens) | |
grammar = "NP: {<DT>?<JJ.*>*<NN.*>+}" | |
cp = nltk.RegexpParser(grammar) | |
result = cp.parse(pos_tags) | |
noun_phrases = list() | |
for subtree in result.subtrees(): | |
if subtree.label() == "NP": | |
noun_phrases.append(" ".join(t[0] for t in subtree.leaves())) | |
return noun_phrases | |
def remove_punctuation(text: str) -> str: | |
punct = [ | |
"|", | |
":", | |
";", | |
"@", | |
"(", | |
")", | |
"[", | |
"]", | |
"{", | |
"}", | |
"^", | |
"'", | |
'"', | |
"’", | |
"`", | |
"?", | |
"$", | |
"%", | |
"#", | |
"!", | |
"&", | |
"*", | |
"+", | |
",", | |
".", | |
] | |
for p in punct: | |
text = text.replace(p, "") | |
return text.strip() | |