# projects/llava_sam2/datasets/ReVOS_Dataset_eval.py
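"""Evaluation datasets for referring video object segmentation.

`VideoReVOSEvalDataset` parses ReVOS-style expression/mask annotations,
builds per-expression conversations with image placeholder tokens, and dumps
prediction masks to JSON for offline scoring; `DAVISEval` adapts the parsing
for DAVIS-style annotations.
"""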
import logging
import os
import torch
from datasets import Dataset as HFDataset
from datasets import DatasetDict
from mmengine import print_log
import mmengine
from PIL import Image
import numpy as np
from mmengine.dist import master_only
from xtuner.registry import BUILDER
from xtuner.dataset.huggingface import build_origin_dataset
import copy
from vlm.datasets.evaluation.base_eval_dataset import BaseEvalDataset
from .encode_fn import video_lisa_encode_multi_conv_fn
import json
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
SEG_QUESTIONS = [
"Can you segment the {class_name} in this image?",
"Please segment {class_name} in this image.",
"What is {class_name} in this image? Please respond with segmentation mask.",
"What is {class_name} in this image? Please output segmentation mask.",
"Can you segment the {class_name} in this image",
"Please segment {class_name} in this image",
"What is {class_name} in this image? Please respond with segmentation mask",
"What is {class_name} in this image? Please output segmentation mask",
"Could you provide a segmentation mask for the {class_name} in this image?",
"Please identify and segment the {class_name} in this image.",
"Where is the {class_name} in this picture? Please respond with a segmentation mask.",
"Can you highlight the {class_name} in this image with a segmentation mask?",
"Could you provide a segmentation mask for the {class_name} in this image",
"Please identify and segment the {class_name} in this image",
"Where is the {class_name} in this picture? Please respond with a segmentation mask",
"Can you highlight the {class_name} in this image with a segmentation mask",
]
ANSWER_LIST = [
"It is [SEG].",
"Sure, [SEG].",
"Sure, it is [SEG].",
"Sure, the segmentation result is [SEG].",
"[SEG].",
]
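# Note: ANSWER_LIST is not referenced below; at evaluation time the model's
# turn is left empty so the answer is generated.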
def multi_template_fn(conversations, template_map):
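    """Render each conversation turn through the prompt template, in place.

    `template_map` is expected to provide `INSTRUCTION` and, optionally,
    `SYSTEM`, `SUFFIX`, `SUFFIX_AS_EOS`, and `SEP` (mmengine-config style,
    hence the `.get(...)` access).
    """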
for conv in conversations:
for i, single_turn_conversation in enumerate(conv):
input = single_turn_conversation.get('input', '')
if input is None:
input = ''
input_text = template_map.INSTRUCTION.format(input=input, round=i + 1)
system = single_turn_conversation.get('system', '')
            if system:
system = template_map.SYSTEM.format(system=system)
input_text = system + input_text
single_turn_conversation['input'] = input_text
if template_map.get('SUFFIX', None):
output_text = single_turn_conversation.get('output', '')
output_text += template_map.SUFFIX
single_turn_conversation['output'] = output_text
# SUFFIX_AS_EOS is False ==> need_eos_token is True
single_turn_conversation['need_eos_token'] = \
not template_map.get('SUFFIX_AS_EOS', False)
single_turn_conversation['sep'] = template_map.get('SEP', '')
class VideoReVOSEvalDataset(BaseEvalDataset):
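    """Evaluation dataset for referring video object segmentation (ReVOS).

    Each item bundles a video's frames with one tokenizer-ready conversation
    per referring expression; `evaluate` dumps predicted masks to JSON so the
    benchmark metrics can be computed offline.
    """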
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
IMG_START_TOKEN = '<img>'
IMG_END_TOKEN = '</img>'
FAST_IMG_CONTEXT_TOKEN = '<FAST_IMG_CONTEXT>'
FAST_IMG_START_TOKEN = '<fast_img>'
FAST_IMG_END_TOKEN = '</fast_img>'
METAINFO: dict = dict(name='revos')
def __init__(self,
image_folder,
expression_file,
mask_file,
extra_image_processor=None,
tokenizer=None,
offline_processed_text_folder=None,
template_map_fn=None,
max_length=2048,
lazy=True,
special_tokens=None,
# eval settings
num_frames=5,
# eval name
eval_name=None,
use_fast=False,
fast_pool_size=2,
):
super().__init__()
assert lazy is True
        assert offline_processed_text_folder or (expression_file and tokenizer)
        self.tokenizer = BUILDER.build(tokenizer)
self.lazy = lazy
self.max_length = max_length
self.template_map = template_map_fn['template']
if offline_processed_text_folder and expression_file:
print_log(
'Both `offline_processed_text_folder` and '
                '`expression_file` are set, and we load dataset from '
'`offline_processed_text_folder` '
f'({offline_processed_text_folder})',
logger='current',
level=logging.WARNING)
if offline_processed_text_folder is not None:
raise NotImplementedError
else:
vid2metaid, metas, mask_dict = self.json_file_preprocess(expression_file, mask_file)
self.vid2metaid = vid2metaid
self.videos = list(self.vid2metaid.keys())
self.mask_dict = mask_dict
            self.json_datas = metas
            self.text_data = metas
# json_data = DatasetDict({'train': HFDataset.from_list(json_datas)})
# if self.lazy:
# self.text_data = build_origin_dataset(json_data, 'train')
# else:
# raise NotImplementedError
self.image_folder = image_folder
        if extra_image_processor is not None:
            self.extra_image_processor = BUILDER.build(extra_image_processor)
        else:
            # Default to None so the `is not None` checks in __getitem__ do
            # not raise AttributeError when no extra processor is configured.
            self.extra_image_processor = None
        self.down_ratio = 1
self.repeats = 1
self._system = ''
self.downsample_ratio = 0.5
self.image_size = 448
patch_size = 14
self.patch_token = int((self.image_size // patch_size) ** 2 * (self.downsample_ratio ** 2))
self.transformer = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
])
if special_tokens is not None:
self.tokenizer.add_tokens(special_tokens, special_tokens=True)
self.num_frames = num_frames
self.use_fast = use_fast
self.fast_pool_size = fast_pool_size
# save to json
if eval_name is None:
eval_name = 'results'
self.eval_name = eval_name
def __len__(self):
return len(self.vid2metaid) * self.repeats
@property
def modality_length(self):
        # Constant placeholder length per video; exact token lengths are not
        # needed at evaluation time.
        return [10000] * len(self.vid2metaid)
def real_len(self):
return len(self.vid2metaid)
def json_file_preprocess(self, expression_file, mask_file):
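        """Flatten the ReVOS expression JSON into per-expression metas.

        Returns:
            vid2metaid: video name -> indices into `metas`.
            metas: one dict per (video, expression) pair.
            mask_dict: raw mask annotations loaded from `mask_file`.
        """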
# prepare expression annotation files
with open(expression_file, 'r') as f:
expression_datas = json.load(f)['videos']
metas = []
anno_count = 0 # serve as anno_id
vid2metaid = {}
for vid_name in expression_datas:
vid_express_data = expression_datas[vid_name]
vid_frames = sorted(vid_express_data['frames'])
vid_len = len(vid_frames)
exp_id_list = sorted(list(vid_express_data['expressions'].keys()))
for exp_id in exp_id_list:
exp_dict = vid_express_data['expressions'][exp_id]
meta = {}
meta['video'] = vid_name
meta['exp'] = exp_dict['exp'] # str
meta['mask_anno_id'] = exp_dict['anno_id']
if 'obj_id' in exp_dict.keys():
meta['obj_id'] = exp_dict['obj_id']
else:
meta['obj_id'] = [0, ] # Ref-Youtube-VOS only has one object per expression
meta['anno_id'] = [str(anno_count), ]
anno_count += 1
meta['frames'] = vid_frames
meta['exp_id'] = exp_id
meta['length'] = vid_len
metas.append(meta)
if vid_name not in vid2metaid.keys():
vid2metaid[vid_name] = []
vid2metaid[vid_name].append(len(metas) - 1)
# process mask annotation files
with open(mask_file, 'rb') as f:
mask_dict = json.load(f)
return vid2metaid, metas, mask_dict
def create_img_to_refs_mapping(self, refs_train):
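        """Group refs by `image_id` (not used by the video evaluation path
        in this file)."""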
img2refs = {}
for ref in refs_train:
img2refs[ref["image_id"]] = img2refs.get(ref["image_id"], []) + [ref, ]
return img2refs
def dataset_map_fn(self, data_dict):
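        """Collect every frame path for one video and build one evaluation
        conversation per expression in `data_dict`."""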
images = []
len_frames = len(data_dict[0]['frames'])
        for object_info in data_dict:
            assert len_frames == len(object_info['frames'])
selected_frame_indexes = range(len_frames)
for selected_frame_index in selected_frame_indexes:
frame_id = data_dict[0]['frames'][selected_frame_index]
images.append(os.path.join(data_dict[0]['video'], frame_id + '.jpg'))
        num_frames = min(len(images), self.num_frames)
num_fast_frames = len(images)
# prepare text
expressions = [object_info['exp'] for object_info in data_dict]
# Modify: To n dialogues
text_dict = self.prepare_text(num_frames, expressions, num_image_tokens=self.patch_token,
num_fast_frames=num_fast_frames)
ret = {'images': images, 'video_masks': None, 'conversation': text_dict['conversation']}
return ret
def prepare_text(self, n_frames, expressions, num_image_tokens=256, num_fast_frames=0):
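        """Build a two-turn conversation per expression: the human turn
        carries `n_frames` image-token spans (plus an optional fast-frame
        span) followed by the question; the model turn is left empty for
        generation."""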
if self.use_fast:
fast_frame_token_str = f'{self.FAST_IMG_START_TOKEN}' \
f'{self.FAST_IMG_CONTEXT_TOKEN * num_fast_frames * self.fast_pool_size * self.fast_pool_size}' \
f'{self.FAST_IMG_END_TOKEN}' + '\n'
else:
fast_frame_token_str = ''
frame_token_str = f'{self.IMG_START_TOKEN}' \
f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
f'{self.IMG_END_TOKEN}'
questions = []
for i, exp in enumerate(expressions):
# the exp is a question
if '?' in exp:
questions.append(exp)
else:
exp = exp.replace('.', '').strip()
# EVAL: Use the first question all the time.
# question_template = random.choice(SEG_QUESTIONS)
question_template = SEG_QUESTIONS[0]
questions.append(question_template.format(class_name=exp.lower()))
eval_conversation_list = []
for i, question in enumerate(questions):
qa_list = []
frame_tokens = frame_token_str + '\n'
frame_tokens = frame_tokens * n_frames
frame_tokens = frame_tokens.strip()
qa_list.append(
{'from': 'human', 'value': fast_frame_token_str + frame_tokens + question}
)
qa_list.append(
{'from': 'gpt', 'value': ''}
)
assert len(qa_list) == 2
input = ''
conversation = []
for msg in qa_list:
if msg['from'] == 'human':
input += msg['value']
elif msg['from'] == 'gpt':
if msg['value'] == '':
conversation.append({'input': input,})
else:
conversation.append({'input': input, 'output': msg['value']})
input = ''
else:
raise NotImplementedError
# add system information
conversation[0].update({'system': self._system})
eval_conversation_list.append(conversation)
return {'conversation': eval_conversation_list}
def __getitem__(self, index):
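        """Load one video: normalized frames in `pixel_values` (first
        `num_frames` frames only), all frames in `fast_pixel_values` when
        `use_fast` is set, grounding frames in `g_pixel_values` when an extra
        image processor is configured, plus the encoded conversations."""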
index = index % self.real_len()
selected_video_objects = self.vid2metaid[self.videos[index]]
video_objects_infos = [copy.deepcopy(self.text_data[idx]) for idx in selected_video_objects]
selected_objects = video_objects_infos
text_prompts = [copy.deepcopy(item['exp']) for item in selected_objects]
data_dict = self.dataset_map_fn(selected_objects)
multi_template_fn(data_dict['conversation'], self.template_map)
result = video_lisa_encode_multi_conv_fn(data_dict, input_ids_with_output=False, tokenizer=self.tokenizer, max_length=self.max_length)
data_dict.update(result)
assert 'images' in data_dict.keys()
pixel_values = []
if self.use_fast:
fast_pixel_values = []
extra_pixel_values = []
if data_dict.get('images', None) is not None:
frames_files = data_dict['images']
frames_files = [os.path.join(self.image_folder, frame_file) for frame_file in frames_files]
ori_width, ori_height = None, None
for frame_idx, frame_path in enumerate(frames_files):
frame_image = Image.open(frame_path).convert('RGB')
if ori_height is None:
ori_width, ori_height = frame_image.size
else:
assert ori_width == frame_image.size[0]
assert ori_height == frame_image.size[1]
if self.extra_image_processor is not None:
g_image = np.array(frame_image) # for grounding
g_image = self.extra_image_processor.apply_image(g_image)
g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
extra_pixel_values.append(g_pixel_values)
if self.use_fast:
frame_image = self.transformer(frame_image)
fast_pixel_values.append(frame_image)
if frame_idx < self.num_frames:
pixel_values.append(frame_image)
else:
if frame_idx < self.num_frames:
frame_image = self.transformer(frame_image)
pixel_values.append(frame_image)
pixel_values = torch.stack(pixel_values, dim=0) # (n_f, 3, h, w)
data_dict['pixel_values'] = pixel_values
if self.use_fast:
fast_pixel_values = torch.stack(fast_pixel_values, dim=0) # (n_f, 3, h, w)
data_dict['fast_pixel_values'] = fast_pixel_values
if self.extra_image_processor is not None:
data_dict['g_pixel_values'] = extra_pixel_values
else:
data_dict['pixel_values'] = torch.zeros(0, 3, self.image_size, self.image_size)
ori_width, ori_height = None, None
data_dict['type'] = 'video'
data_dict['video_id'] = index
data_dict['text_prompts'] = text_prompts
data_dict['image_folder'] = self.image_folder
data_dict['ori_height'] = ori_height
data_dict['ori_width'] = ori_width
data_dict['id'] = index
return data_dict
@master_only
def evaluate(self, results, work_dir):
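        """Regroup predicted masks by video and expression id and dump them
        to `<work_dir>/<eval_name>.json`; actual metrics are computed
        offline, so only a placeholder score is returned."""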
final_results = {}
for idx, item in enumerate(results):
_id = item['id']
# vid_id = self.videos[idx]
vid_id = self.videos[_id]
selected_video_objects = self.vid2metaid[vid_id]
            video_objects_infos = [copy.deepcopy(self.text_data[meta_idx]) for meta_idx in selected_video_objects]
            text_prompts = [copy.deepcopy(info['exp']) for info in video_objects_infos]
            exp_ids = [copy.deepcopy(info['exp_id']) for info in video_objects_infos]
final_results[vid_id] = {}
assert len(text_prompts) == len(item['prediction_masks']), f"{len(text_prompts)}-----{len(item['prediction_masks'])}"
for idt, text in enumerate(text_prompts):
exp_id = exp_ids[idt]
final_results[vid_id][exp_id] = {
'exp': text,
'prediction_masks': item['prediction_masks'][idt],
}
mmengine.dump(final_results, os.path.join(work_dir, f'{self.eval_name}.json'))
return {"Dummy": 0}
class DAVISEval(VideoReVOSEvalDataset):
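    """DAVIS variant: `mask_anno_id` is assigned sequentially rather than
    read from the expression file, and masks are loaded with `mmengine.load`
    (so the mask file need not be JSON)."""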
def json_file_preprocess(self, expression_file, mask_file):
# prepare expression annotation files
with open(expression_file, 'r') as f:
expression_datas = json.load(f)['videos']
metas = []
anno_count = 0 # serve as anno_id
vid2metaid = {}
for vid_name in expression_datas:
vid_express_data = expression_datas[vid_name]
vid_frames = sorted(vid_express_data['frames'])
vid_len = len(vid_frames)
exp_id_list = sorted(list(vid_express_data['expressions'].keys()))
for exp_id in exp_id_list:
exp_dict = vid_express_data['expressions'][exp_id]
meta = {}
meta['video'] = vid_name
meta['exp'] = exp_dict['exp'] # str
meta['mask_anno_id'] = [str(anno_count), ]
if 'obj_id' in exp_dict.keys():
meta['obj_id'] = exp_dict['obj_id']
else:
meta['obj_id'] = [0, ] # Ref-Youtube-VOS only has one object per expression
meta['anno_id'] = [str(anno_count), ]
anno_count += 1
meta['frames'] = vid_frames
meta['exp_id'] = exp_id
meta['length'] = vid_len
metas.append(meta)
if vid_name not in vid2metaid.keys():
vid2metaid[vid_name] = []
vid2metaid[vid_name].append(len(metas) - 1)
mask_dict = mmengine.load(mask_file)
return vid2metaid, metas, mask_dict
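# A minimal construction sketch (hedged: the paths, file names, special-token
# set, and chat template below are assumptions; in practice these fields come
# from an xtuner config that BUILDER resolves):
#
#   from transformers import AutoTokenizer
#   from xtuner.utils import PROMPT_TEMPLATE
#
#   dataset = VideoReVOSEvalDataset(
#       image_folder='data/revos',                           # assumed layout
#       expression_file='data/revos/meta_expressions.json',  # assumed name
#       mask_file='data/revos/mask_dict.json',               # assumed name
#       tokenizer=dict(type=AutoTokenizer.from_pretrained,
#                      pretrained_model_name_or_path='<your-model>'),
#       template_map_fn=dict(template=PROMPT_TEMPLATE.internlm2_chat),
#       special_tokens=['[SEG]'],                            # assumed token set
#   )
#   sample = dataset[0]  # dict with 'pixel_values', 'conversation', 'text_prompts', ...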