# vstar/VisualSearch/utils/dataset.py
import glob
import os
import random
from PIL import Image
import cv2
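# Keep OpenCV single-threaded; image decoding already runs in parallel DataLoader
# workers, so extra OpenCV threads would only oversubscribe the CPU (assumed rationale).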
cv2.setNumThreads(1)
import numpy as np
import torch
import torch.nn.functional as F
from pycocotools import mask
from transformers import CLIPImageProcessor
from transformers import OwlViTProcessor
from VisualSearch.model.llava import conversation as conversation_lib
from VisualSearch.model.llava.constants import (DEFAULT_IMAGE_TOKEN, IGNORE_INDEX,
IMAGE_TOKEN_INDEX)
from VisualSearch.model.llava.mm_utils import tokenizer_image_token
from VisualSearch.utils.data_processing import get_mask_from_json
from VisualSearch.utils.refer import REFER
from VisualSearch.utils.refer_seg_dataset import ReferSegDataset
from VisualSearch.utils.general_segdet_dataset import SegDetDataset
from VisualSearch.utils.mixed_grounding_dataset import MixedGroundingDataset
from VisualSearch.utils.vqa_dataset import VQADataset
from VisualSearch.utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
                                      box_xyxy_to_cxcywh, expand2square)


def collate_fn(
batch, tokenizer=None, conv_type="llava_v1", use_mm_start_end=True, local_rank=-1
):
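    """Collate a list of dataset samples into a single training/eval batch.

    Tokenizes each conversation, pads input_ids to a common length, builds the
    loss targets by masking non-response tokens with IGNORE_INDEX, and gathers
    image tensors, box/mask targets, and validity flags into one dict. `offset`
    marks where each sample's conversations begin in the flattened list.
    """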
image_path_list = []
images_list = []
images_clip_list = []
conversation_list = []
masks_list = []
label_list = []
bboxes_labels_list = []
bboxes_valid_list = []
masks_valid_list = []
resize_list = []
questions_list = []
sampled_classes_list = []
offset_list = [0]
cnt = 0
inferences = []
for (
image_path,
images,
images_clip,
conversations,
masks,
label,
bboxes_labels,
bboxes_valid,
masks_valid,
resize,
questions,
sampled_classes,
inference,
) in batch:
image_path_list.append(image_path)
images_list.append(images)
images_clip_list.append(images_clip)
conversation_list.extend(conversations)
label_list.append(label)
masks_list.append(masks.float())
bboxes_labels_list.extend(bboxes_labels)
bboxes_valid_list.extend(bboxes_valid)
masks_valid_list.append(torch.tensor(masks_valid))
resize_list.append(resize)
questions_list.append(questions)
sampled_classes_list.append(sampled_classes)
cnt += len(conversations)
offset_list.append(cnt)
inferences.append(inference)
if use_mm_start_end:
# replace <image> token
for i in range(len(conversation_list)):
replace_token = DEFAULT_IMAGE_TOKEN
replace_token = (
DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
)
conversation_list[i] = conversation_list[i].replace(
DEFAULT_IMAGE_TOKEN, replace_token
)
input_ids = [
tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
for prompt in conversation_list
]
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids, batch_first=True, padding_value=tokenizer.pad_token_id
)
attention_masks = input_ids.ne(tokenizer.pad_token_id)
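    # For samples without valid boxes, drop the [LOC] token positions from the
    # attention mask so they are ignored downstream.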
for i in range(len(bboxes_valid_list)):
bboxes_valid = bboxes_valid_list[i]
attention_mask = attention_masks[i]
if not bboxes_valid:
            attention_mask = attention_mask & input_ids[i].ne(
                tokenizer("[LOC]", add_special_tokens=False).input_ids[0]
            )
attention_masks[i] = attention_mask
conv = conversation_lib.default_conversation.copy()
targets = input_ids.clone()
if conv_type == "llava_v1":
sep = conv.sep + conv.roles[1] + ": "
else:
sep = "[/INST] "
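    # Mask out everything except the assistant replies so the language-model loss
    # is computed only on the model's own responses.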
for conversation, target in zip(conversation_list, targets):
total_len = int(target.ne(tokenizer.pad_token_id).sum())
rounds = conversation.split(conv.sep2)
cur_len = 1
target[:cur_len] = IGNORE_INDEX
for i, rou in enumerate(rounds):
if rou == "":
break
parts = rou.split(sep)
# if len(parts) != 2:
# break
assert len(parts) == 2, (len(parts), rou)
parts[0] += sep
if DEFAULT_IMAGE_TOKEN in conversation:
round_len = len(tokenizer_image_token(rou, tokenizer))
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
else:
round_len = len(tokenizer(rou).input_ids)
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
cur_len += round_len
target[cur_len:] = IGNORE_INDEX
if False:
z = target.clone()
z = torch.where(z == IGNORE_INDEX, tokenizer.unk_token_id, z)
if local_rank == 0:
print(
"conversation: ",
conversation,
"tokenizer.decode(z): ",
tokenizer.decode(z),
)
if cur_len < tokenizer.model_max_length:
assert cur_len == total_len
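    # During training, truncate overly long batches; 255 positions are presumably
    # reserved for the image patch embeddings that get spliced in later.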
if inferences[0] == False:
truncate_len = tokenizer.model_max_length - 255
if input_ids.shape[1] > truncate_len:
input_ids = input_ids[:, :truncate_len]
targets = targets[:, :truncate_len]
attention_masks = attention_masks[:, :truncate_len]
return {
"image_paths": image_path_list,
"images": torch.stack(images_list, dim=0),
"images_clip": torch.stack(images_clip_list, dim=0),
"input_ids": input_ids,
"labels": targets,
"bboxes_labels_list": bboxes_labels_list,
"bboxes_valid_list": torch.tensor(bboxes_valid_list),
"masks_valid_list": masks_valid_list,
"attention_masks": attention_masks,
"masks_list": masks_list,
"label_list": label_list,
"resize_list": resize_list,
"offset": torch.LongTensor(offset_list),
"questions_list": questions_list,
"sampled_classes_list": sampled_classes_list,
"inference": inferences[0],
"conversation_list": conversation_list,
}


class HybridDataset(torch.utils.data.Dataset):
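    """Training dataset that mixes several sub-datasets (general segmentation/
    detection, referring segmentation, VQA, mixed grounding) and samples among
    them with probabilities given by `sample_rate`."""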
pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
img_size = 1024
ignore_label = 255
def __init__(
self,
base_dir,
tokenizer,
vision_tower,
samples_per_epoch=500 * 8 * 2 * 10,
precision: str = "fp32",
num_classes_per_sample: int = 3,
exclude_val=False,
dataset="general_segdet||refer_seg||vqa||reason_seg",
sample_rate=[9, 3, 3, 1],
general_segdet_data="objects365||cocostuff||paco_lvis",
general_segdet_sample_rate=[2,1,1],
refer_seg_data="refclef||refcoco||refcoco+||refcocog",
vqa_data="possible_locations_conv_86k||llava_instruct_80k",
vqa_sample_rate=[2,1],
):
self.exclude_val = exclude_val
self.dataset = dataset
self.samples_per_epoch = samples_per_epoch
self.num_classes_per_sample = num_classes_per_sample
sample_rate = np.array(sample_rate)
self.sample_rate = sample_rate / sample_rate.sum()
self.base_dir = base_dir
self.tokenizer = tokenizer
self.precision = precision
self.datasets = dataset.split("||")
self.all_datasets = []
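        # Instantiate one sub-dataset per entry of the "||"-separated `dataset` string.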
for dataset in self.datasets:
if dataset == "general_segdet":
self.all_datasets.append(
SegDetDataset(
base_dir,
tokenizer,
vision_tower,
samples_per_epoch,
precision,
num_classes_per_sample,
exclude_val,
general_segdet_data,
general_segdet_sample_rate,
)
)
elif dataset == "refer_seg":
self.all_datasets.append(
ReferSegDataset(
base_dir,
tokenizer,
vision_tower,
samples_per_epoch,
precision,
num_classes_per_sample,
exclude_val,
refer_seg_data,
)
)
elif dataset == "vqa":
self.all_datasets.append(
VQADataset(
base_dir,
tokenizer,
vision_tower,
samples_per_epoch,
precision,
num_classes_per_sample,
exclude_val,
vqa_data,
vqa_sample_rate,
)
)
elif dataset == "mixed_grounding":
self.all_datasets.append(
MixedGroundingDataset(
base_dir,
tokenizer,
vision_tower,
samples_per_epoch,
precision,
num_classes_per_sample,
exclude_val,
)
)
def __len__(self):
return self.samples_per_epoch
def __getitem__(self, idx):
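        # The incoming idx is ignored: pick a sub-dataset according to sample_rate
        # and draw an example from it; inference is always False for training data.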
ind = np.random.choice(list(range(len(self.datasets))), p=self.sample_rate)
data = self.all_datasets[ind]
inference = False
return *data[0], inference


class ValDataset(torch.utils.data.Dataset):
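    """Validation dataset.

    `val_dataset` is either "<ds>|<split>" for reason_seg-style data (images with
    JSON mask annotations) or "<ds>|<splitBy>|<split>" for referring-segmentation
    data loaded through the REFER API.
    """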
pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
img_size = 1024
ignore_label = 255
def __init__(
self,
base_dir,
tokenizer,
vision_tower,
val_dataset,
):
self.base_dir = base_dir
splits = val_dataset.split("|")
if len(splits) == 2:
ds, split = splits
images = glob.glob(
os.path.join(self.base_dir, "reason_seg", ds, split, "*.jpg")
)
self.images = images
self.data_type = "reason_seg"
elif len(splits) == 3:
self.base_dir = os.path.join(self.base_dir, 'refer_seg')
ds, splitBy, split = splits
refer_api = REFER(self.base_dir, ds, splitBy)
ref_ids_val = refer_api.getRefIds(split=split)
images_ids_val = refer_api.getImgIds(ref_ids=ref_ids_val)
refs_val = refer_api.loadRefs(ref_ids=ref_ids_val)
refer_seg_ds = {}
refer_seg_ds["images"] = []
loaded_images = refer_api.loadImgs(image_ids=images_ids_val)
for item in loaded_images:
item = item.copy()
if ds == "refclef":
item["file_name"] = os.path.join(
self.base_dir, "images/saiapr_tc-12", item["file_name"]
)
elif ds in ["refcoco", "refcoco+", "refcocog", "grefcoco"]:
item["file_name"] = os.path.join(
self.base_dir,
"images/mscoco/images/train2014",
item["file_name"],
)
refer_seg_ds["images"].append(item)
refer_seg_ds["annotations"] = refer_api.Anns # anns_val
img2refs = {}
for ref in refs_val:
image_id = ref["image_id"]
img2refs[image_id] = img2refs.get(image_id, []) + [
ref,
]
refer_seg_ds["img2refs"] = img2refs
self.refer_seg_ds = refer_seg_ds
self.data_type = "refer_seg"
self.ds = ds
self.tokenizer = tokenizer
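        # The OwlViT processor prepares the detection-branch input (`image` below),
        # while the CLIP processor prepares the LLM's visual input (`image_clip`).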
self.transform = OwlViTProcessor.from_pretrained("google/owlvit-base-patch16")
self.clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
def __len__(self):
if self.data_type == "refer_seg":
return len(self.refer_seg_ds["images"])
else:
return len(self.images)
def preprocess(self, x: torch.Tensor) -> torch.Tensor:
"""Normalize pixel values and pad to a square input."""
# Normalize colors
x = (x - self.pixel_mean) / self.pixel_std
# Pad
h, w = x.shape[-2:]
padh = self.img_size - h
padw = self.img_size - w
x = F.pad(x, (0, padw, 0, padh))
return x
def __getitem__(self, idx):
if self.data_type == "refer_seg":
refer_seg_ds = self.refer_seg_ds
images = refer_seg_ds["images"]
annotations = refer_seg_ds["annotations"]
img2refs = refer_seg_ds["img2refs"]
image_info = images[idx]
image_path = image_info["file_name"]
image_id = image_info["id"]
refs = img2refs[image_id]
if len(refs) == 0:
raise ValueError("image {} has no refs".format(image_id))
sents = []
ann_ids = []
for ref in refs:
for sent in ref["sentences"]:
sents.append(sent["sent"].strip().lower())
ann_ids.append(ref["ann_id"])
sampled_sents = sents
sampled_ann_ids = ann_ids
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
is_sentence = False
else:
image_path = self.images[idx]
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
json_path = image_path.replace(".jpg", ".json")
mask_json, sampled_sents, is_sentence = get_mask_from_json(json_path, image)
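            # reason_seg evaluation uses only the first query sentence per image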
sampled_sents = [sampled_sents[0]]
conversations = []
conv = conversation_lib.default_conversation.copy()
i = 0
while i < len(sampled_sents):
conv.messages = []
text = sampled_sents[i].strip()
if is_sentence:
conv.append_message(
conv.roles[0],
DEFAULT_IMAGE_TOKEN
+ "\n {} Please output segmentation mask.".format(text),
)
conv.append_message(conv.roles[1], "[LOC].")
else:
conv.append_message(
conv.roles[0],
DEFAULT_IMAGE_TOKEN
+ "\n Please locate the {} in this image.".format(
text
),
)
conv.append_message(conv.roles[1], "Sure, [LOC].")
conversations.append(conv.get_prompt())
i += 1
# preprocess image for clip
        image_clip = self.clip_image_processor.preprocess(
            expand2square(
                Image.open(image_path).convert("RGB"),
                tuple(int(x * 255) for x in self.clip_image_processor.image_mean),
            ),
            return_tensors="pt",
        )["pixel_values"][0]
original_size = image.shape[:2]
image = self.transform(images=image, return_tensors="pt")['pixel_values'][0]
resize = image.shape[:2]
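        # Build ground-truth targets: boxes (normalized cxcywh) and binary masks per
        # referred object; reason_seg samples carry only the JSON mask.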
if self.data_type == "refer_seg":
masks = []
bboxes_labels = []
for i, ann_id in enumerate(sampled_ann_ids):
ann = annotations[ann_id]
cur_bboxes = [ann['bbox']]
cur_bboxes = torch.tensor(cur_bboxes).view(-1, 4)
# xywh to x1y1x2y2
cur_bboxes[:, 2:] += cur_bboxes[:, :2]
cur_bboxes[:, 0::2].clamp_(min=0, max=original_size[1])
cur_bboxes[:, 1::2].clamp_(min=0, max=original_size[0])
keep = (cur_bboxes[:, 3] > cur_bboxes[:, 1]) & (cur_bboxes[:, 2] > cur_bboxes[:, 0])
cur_bboxes = cur_bboxes[keep]
cur_bboxes = box_xyxy_to_cxcywh(cur_bboxes)
cur_bboxes = cur_bboxes / torch.tensor([original_size[1], original_size[0], original_size[1], original_size[0]], dtype=torch.float32)
if len(cur_bboxes) == 0:
return self.__getitem__(0)
bboxes_labels.append(cur_bboxes)
if len(ann["segmentation"]) == 0 and sampled_sents[i] != "":
m = np.zeros((image_info["height"], image_info["width"], 1))
else:
if type(ann["segmentation"][0]) == list: # polygon
rle = mask.frPyObjects(
ann["segmentation"],
image_info["height"],
image_info["width"],
)
else:
rle = ann["segmentation"]
                    for j in range(len(rle)):
                        if not isinstance(rle[j]["counts"], bytes):
                            rle[j]["counts"] = rle[j]["counts"].encode()
m = mask.decode(rle)
m = np.sum(
m, axis=2
) # sometimes there are multiple binary map (corresponding to multiple segs)
m = m.astype(np.uint8) # convert to np.uint8
masks.append(m)
        else:
            # reason_seg samples provide only a mask; there are no box targets,
            # so keep bboxes_labels empty (it would otherwise be undefined here)
            masks = [mask_json]
            bboxes_labels = []
        bboxes_valid = [1]*len(bboxes_labels)
        masks_valid = [1]*len(bboxes_labels)
masks = np.stack(masks, axis=0)
masks = torch.from_numpy(masks)
labels = torch.ones(masks.shape[1], masks.shape[2]) * self.ignore_label
inference = True
return (
image_path,
image,
image_clip,
conversations,
masks,
labels,
bboxes_labels,
bboxes_valid,
masks_valid,
resize,
None,
None,
inference,
)