# VisualSearch/utils/refer_seg_dataset.py
import os
import random
from PIL import Image
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from pycocotools import mask
from transformers import CLIPImageProcessor
from transformers import OwlViTProcessor
from VisualSearch.model.llava import conversation as conversation_lib
from VisualSearch.utils.grefer import G_REFER
from VisualSearch.utils.refer import REFER
from VisualSearch.utils.utils import box_xyxy_to_cxcywh, expand2square
from VisualSearch.utils.utils import ANSWER_LIST, SHORT_QUESTION_LIST


class ReferSegDataset(torch.utils.data.Dataset):
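    """Referring-expression segmentation dataset (refclef/refcoco/refcoco+/refcocog).

    Each sample yields the OWL-ViT pixel tensor, a CLIP pixel tensor,
    LLaVA-style single-turn conversations, binary masks, and normalized
    cxcywh box targets for up to `num_classes_per_sample` expressions.
    """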
    pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
    pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
    img_size = 1024
    ignore_label = 255

    def __init__(
        self,
        base_dir,
        tokenizer,
        vision_tower,
        samples_per_epoch=500 * 8 * 2 * 10,
        precision: str = "fp32",
        num_classes_per_sample: int = 3,
        exclude_val=False,
        refer_seg_data="refclef||refcoco||refcoco+||refcocog",
    ):
        self.exclude_val = exclude_val
        self.samples_per_epoch = samples_per_epoch
        self.num_classes_per_sample = num_classes_per_sample
        self.base_dir = base_dir
        self.tokenizer = tokenizer
        self.precision = precision
        self.transform = OwlViTProcessor.from_pretrained("google/owlvit-base-patch16")
        self.clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
        self.short_question_list = SHORT_QUESTION_LIST
        self.answer_list = ANSWER_LIST
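
        # Build an in-memory index for each requested dataset: image records
        # (with absolute file paths), the COCO-style annotation table, and an
        # image_id -> refs lookup.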
        DATA_DIR = os.path.join(base_dir, "refer_seg")
        self.refer_seg_ds_list = refer_seg_data.split(
            "||"
        )  # ['refclef', 'refcoco', 'refcoco+', 'refcocog']
        self.refer_seg_data = {}
        for ds in self.refer_seg_ds_list:
            if ds == "refcocog":
                splitBy = "umd"
            else:
                splitBy = "unc"

            if ds == "grefcoco":
                refer_api = G_REFER(DATA_DIR, ds, splitBy)
            else:
                refer_api = REFER(DATA_DIR, ds, splitBy)
            ref_ids_train = refer_api.getRefIds(split="train")
            images_ids_train = refer_api.getImgIds(ref_ids=ref_ids_train)
            refs_train = refer_api.loadRefs(ref_ids=ref_ids_train)

            refer_seg_ds = {}
            refer_seg_ds["images"] = []
            loaded_images = refer_api.loadImgs(image_ids=images_ids_train)
            for item in loaded_images:
                item = item.copy()
                if ds == "refclef":
                    item["file_name"] = os.path.join(
                        DATA_DIR, "images/saiapr_tc-12", item["file_name"]
                    )
                else:
                    item["file_name"] = os.path.join(
                        DATA_DIR, "images/mscoco/images/train2014", item["file_name"]
                    )
                refer_seg_ds["images"].append(item)
            refer_seg_ds["annotations"] = refer_api.Anns  # anns_train

            print(
                "dataset {} (refs {}) (train split) has {} images and {} annotations.".format(
                    ds,
                    splitBy,
                    len(refer_seg_ds["images"]),
                    len(refer_seg_ds["annotations"]),
                )
            )

            img2refs = {}
            for ref in refs_train:
                image_id = ref["image_id"]
                img2refs[image_id] = img2refs.get(image_id, []) + [
                    ref,
                ]
            refer_seg_ds["img2refs"] = img2refs
            self.refer_seg_data[ds] = refer_seg_ds

    def __len__(self):
        return self.samples_per_epoch

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize pixel values and pad to a square input."""
        # Normalize colors
        x = (x - self.pixel_mean) / self.pixel_std

        # Pad
        h, w = x.shape[-2:]
        padh = self.img_size - h
        padw = self.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x
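
    # NOTE: __getitem__ ignores `idx` and draws a random dataset, then a random
    # image from it; the epoch length is fixed by `samples_per_epoch`.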
    def __getitem__(self, idx):
        ds = random.randint(0, len(self.refer_seg_ds_list) - 1)
        ds = self.refer_seg_ds_list[ds]
        refer_seg_ds = self.refer_seg_data[ds]
        images = refer_seg_ds["images"]
        annotations = refer_seg_ds["annotations"]
        img2refs = refer_seg_ds["img2refs"]
        idx = random.randint(0, len(images) - 1)
        image_info = images[idx]
        image_path = image_info["file_name"]
        image_id = image_info["id"]
        refs = img2refs[image_id]
        if len(refs) == 0:
            return self.__getitem__(0)
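
        # Collect every referring sentence for this image, then sample at most
        # num_classes_per_sample of them together with their annotation ids.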
        sents = []
        ann_ids = []
        for ref in refs:
            for sent in ref["sentences"]:
                text = sent["sent"]
                sents.append(text)
                ann_ids.append(ref["ann_id"])
        if len(sents) >= self.num_classes_per_sample:
            sampled_inds = np.random.choice(
                list(range(len(sents))), size=self.num_classes_per_sample, replace=False
            )
        else:
            sampled_inds = list(range(len(sents)))
        sampled_sents = np.vectorize(sents.__getitem__)(sampled_inds).tolist()
        # sampled_ann_ids = np.vectorize(ann_ids.__getitem__)(sampled_inds).tolist()
        sampled_ann_ids = [ann_ids[ind] for ind in sampled_inds]
        sampled_classes = sampled_sents

        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # preprocess image for clip
        image_clip = self.clip_image_processor.preprocess(
            expand2square(
                Image.open(image_path).convert("RGB"),
                tuple(int(x * 255) for x in self.clip_image_processor.image_mean),
            ),
            return_tensors="pt",
        )["pixel_values"][0]
        original_size = image.shape[:2]

        image = self.transform(images=image, return_tensors="pt")["pixel_values"][0]
        resize = image.shape[:2]
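
        # Build one single-turn conversation per sampled expression, using a
        # random question template and a random answer from the predefined lists.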
        questions = []
        answers = []
        for text in sampled_classes:
            text = text.strip()
            assert len(text.split("||")) == 1
            question_template = random.choice(self.short_question_list)
            questions.append(question_template.format(class_name=text.lower()))
            answers.append(random.choice(self.answer_list))

        conversations = []
        conv = conversation_lib.default_conversation.copy()
        i = 0
        while i < len(questions):
            conv.messages = []
            conv.append_message(conv.roles[0], questions[i])
            conv.append_message(conv.roles[1], answers[i])
            conversations.append(conv.get_prompt())
            i += 1
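
        # For every sampled annotation id, build the normalized cxcywh box
        # target and decode the binary segmentation mask (polygon or RLE).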
        flag = False
        masks = []
        bboxes_labels = []
        for ann_id in sampled_ann_ids:
            if isinstance(ann_id, list):
                assert False
                flag = True
                if -1 in ann_id:
                    assert len(ann_id) == 1
                    m = np.zeros((image_info["height"], image_info["width"])).astype(
                        np.uint8
                    )
                else:
                    m_final = np.zeros(
                        (image_info["height"], image_info["width"])
                    ).astype(np.uint8)
                    for ann_id_i in ann_id:
                        ann = annotations[ann_id_i]
                        if len(ann["segmentation"]) == 0:
                            m = np.zeros(
                                (image_info["height"], image_info["width"])
                            ).astype(np.uint8)
                        else:
                            if type(ann["segmentation"][0]) == list:  # polygon
                                rle = mask.frPyObjects(
                                    ann["segmentation"],
                                    image_info["height"],
                                    image_info["width"],
                                )
                            else:
                                rle = ann["segmentation"]
                                for i in range(len(rle)):
                                    if not isinstance(rle[i]["counts"], bytes):
                                        rle[i]["counts"] = rle[i]["counts"].encode()
                            m = mask.decode(rle)
                            m = np.sum(
                                m, axis=2
                            )  # sometimes there are multiple binary map (corresponding to multiple segs)
                            m = m.astype(np.uint8)  # convert to np.uint8
                        m_final = m_final | m
                    m = m_final
                masks.append(m)
                continue
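
            # Single-annotation case: one COCO annotation gives one box and one mask.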
            ann = annotations[ann_id]
            cur_bboxes = [ann["bbox"]]
            cur_bboxes = torch.tensor(cur_bboxes).view(-1, 4)
            # xywh to x1y1x2y2
            cur_bboxes[:, 2:] += cur_bboxes[:, :2]
            cur_bboxes[:, 0::2].clamp_(min=0, max=original_size[1])
            cur_bboxes[:, 1::2].clamp_(min=0, max=original_size[0])
            keep = (cur_bboxes[:, 3] > cur_bboxes[:, 1]) & (
                cur_bboxes[:, 2] > cur_bboxes[:, 0]
            )
            cur_bboxes = cur_bboxes[keep]
            cur_bboxes = box_xyxy_to_cxcywh(cur_bboxes)
            cur_bboxes = cur_bboxes / torch.tensor(
                [original_size[1], original_size[0], original_size[1], original_size[0]],
                dtype=torch.float32,
            )
            if len(cur_bboxes) == 0:
                return self.__getitem__(0)
            bboxes_labels.append(cur_bboxes)

            if len(ann["segmentation"]) == 0:
                m = np.zeros((image_info["height"], image_info["width"])).astype(
                    np.uint8
                )
                masks.append(m)
                continue

            if type(ann["segmentation"][0]) == list:  # polygon
                rle = mask.frPyObjects(
                    ann["segmentation"], image_info["height"], image_info["width"]
                )
            else:
                rle = ann["segmentation"]
                for i in range(len(rle)):
                    if not isinstance(rle[i]["counts"], bytes):
                        rle[i]["counts"] = rle[i]["counts"].encode()
            m = mask.decode(rle)
            m = np.sum(
                m, axis=2
            )  # sometimes there are multiple binary map (corresponding to multiple segs)
            m = m.astype(np.uint8)  # convert to np.uint8
            masks.append(m)
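
        # Stack per-expression masks into (N, H, W); `label` is a placeholder
        # map filled with ignore_label.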
        bboxes_valid = [1] * len(bboxes_labels)
        masks_valid = [1] * len(bboxes_labels)
        masks = np.stack(masks, axis=0)
        masks = torch.from_numpy(masks)
        label = torch.ones(masks.shape[1], masks.shape[2]) * self.ignore_label

        return (
            image_path,
            image,
            image_clip,
            conversations,
            masks,
            label,
            bboxes_labels,
            bboxes_valid,
            masks_valid,
            resize,
            questions,
            sampled_classes,
        )
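

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original training pipeline).
    # Assumptions: `./data` contains a `refer_seg/` directory laid out as the
    # paths above expect, and the CLIP checkpoint name below is a placeholder
    # for whichever vision tower the project actually uses.
    dataset = ReferSegDataset(
        base_dir="./data",
        tokenizer=None,  # stored but not used inside __getitem__
        vision_tower="openai/clip-vit-large-patch14",
        samples_per_epoch=4,
        num_classes_per_sample=3,
        refer_seg_data="refcoco",
    )
    (image_path, image, image_clip, conversations, masks, label,
     bboxes_labels, bboxes_valid, masks_valid, resize, questions,
     sampled_classes) = dataset[0]
    print(image_path, image.shape, image_clip.shape, masks.shape, len(conversations))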