import math
import os.path
from types import MethodType

import cv2
import mmengine
import torch
from mmengine.dist import (barrier, broadcast_object_list, collect_results,
                           get_rank, get_world_size, is_main_process)
from mmengine.model import is_model_wrapper
from mmengine.runner import ValLoop as MMENGINE_ValLoop
from mmengine.visualization.visualizer import Visualizer
from pycocotools import mask as _mask
from transformers import GenerationConfig
from xtuner.tools.utils import get_stop_criteria
from xtuner.utils import PROMPT_TEMPLATE

from vlm.utils import VideoReader

TORCH_DTYPE_MAP = dict(fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto')

# Stride used when sampling frames from a raw video for visualization.
VID_INTERVAL = 4


def visualize(data_batch, prediction, visualize_path='work_dirs/visualize'):
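    """Dump qualitative results for one sample.

    Two input layouts are supported: a raw video file (``video_path`` in
    ``data_batch``) or a pre-extracted frame folder (``images`` plus
    ``image_folder``). Predicted masks are expected as COCO RLE dicts in
    ``prediction['prediction_masks'][text_idx][frame_idx]`` and are drawn in
    green on each frame, with one output folder per text prompt.
    """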
    if 'video_path' in data_batch:
        # Raw video input: sub-sample frames and key the output folder by the
        # video file name.
        vid_frames = VideoReader(data_batch['video_path'])[::VID_INTERVAL]
        vid_id = os.path.basename(data_batch['video_path']).split('.')[0]
        text_prompts = data_batch['text_prompts']
        mmengine.mkdir_or_exist(os.path.join(visualize_path, vid_id))
        visualizer = Visualizer()

        # Dump the sampled frames themselves for reference.
        mmengine.mkdir_or_exist(os.path.join(visualize_path, vid_id, "vid"))
        for id_frame, img in enumerate(vid_frames):
            out_path = os.path.join(visualize_path, vid_id, "vid", "{:06d}.jpg".format(id_frame))
            cv2.imwrite(out_path, img)

        # Overlay the predicted mask on every frame, one folder per text prompt.
        for id_text, text in enumerate(text_prompts):
            mmengine.mkdir_or_exist(os.path.join(visualize_path, vid_id, "sample_{:06d}".format(id_text)))
            mmengine.put_text(text, os.path.join(visualize_path, vid_id, "sample_{:06d}".format(id_text), 'text.txt'))
            for id_frame, img in enumerate(vid_frames):
                visualizer.set_image(img)
                mask = prediction['prediction_masks'][id_text][id_frame]
                mask = _mask.decode(mask).astype(bool)
                visualizer.draw_binary_masks(mask, colors='g')
                visual_result = visualizer.get_image()
                out_path = os.path.join(visualize_path, vid_id, "sample_{:06d}".format(id_text),
                                        "{:06d}.jpg".format(id_frame))
                cv2.imwrite(out_path, visual_result)
    else:
        # Frame-folder input: frames are read from ``image_folder`` and the
        # output folder is keyed by the numeric ``video_id``.
        images_files = data_batch['images']
        vid_id = data_batch['video_id']
        text_prompts = data_batch['text_prompts']
        image_folder = data_batch['image_folder']
        mmengine.mkdir_or_exist(os.path.join(visualize_path, "{:06d}".format(vid_id)))
        visualizer = Visualizer()

        mmengine.mkdir_or_exist(os.path.join(visualize_path, "{:06d}".format(vid_id), "vid"))
        for id_frame, img_file in enumerate(images_files):
            img = cv2.imread(os.path.join(image_folder, img_file))
            out_path = os.path.join(visualize_path, "{:06d}".format(vid_id), "vid", os.path.basename(img_file))
            cv2.imwrite(out_path, img)

        for id_text, text in enumerate(text_prompts):
            mmengine.mkdir_or_exist(os.path.join(visualize_path, "{:06d}".format(vid_id), "sample_{:06d}".format(id_text)))
            mmengine.put_text(text, os.path.join(visualize_path, "{:06d}".format(vid_id), "sample_{:06d}".format(id_text),
                                                 'text.txt'))
            for id_frame, img_file in enumerate(images_files):
                img = cv2.imread(os.path.join(image_folder, img_file))
                visualizer.set_image(img)
                mask = prediction['prediction_masks'][id_text][id_frame]
                mask = _mask.decode(mask).astype(bool)
                visualizer.draw_binary_masks(mask, colors='g')
                visual_result = visualizer.get_image()
                out_path = os.path.join(visualize_path, "{:06d}".format(vid_id), "sample_{:06d}".format(id_text),
                                        os.path.basename(img_file))
                cv2.imwrite(out_path, visual_result)


class VideoTestLoop(MMENGINE_ValLoop):
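    """Test-time loop for video referring/segmentation datasets.

    Each rank iterates its shard of every sub-dataset sample by sample, runs
    the model's ``predict_forward``, and the gathered predictions are scored
    with each dataset's own ``evaluate`` method rather than an mmengine
    evaluator. Only the metrics chosen by ``select_metric`` are returned.
    """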
    def __init__(self, runner, dataloader, torch_dtype='fp16', select_metric='first',
                 visualize=None, evaluator=None) -> None:
        # Skip ValLoop.__init__ (which would build an evaluator) and initialize
        # only the BaseLoop parts: runner and dataloader.
        super(MMENGINE_ValLoop, self).__init__(runner, dataloader)
        self._runner = runner
        self.torch_dtype = TORCH_DTYPE_MAP[torch_dtype] if torch_dtype is not None else None
        self.select_metric = select_metric

        # Directory to dump qualitative results to; falsy values disable dumping.
        self.visualize = visualize
        self.evaluator = evaluator

    def run(self) -> dict:
        """Run every test dataset once and return the selected metrics."""
        self.runner.logger.info('==================== Start test loop ===================')
        self.runner.call_hook('before_test')
        self.runner.call_hook('before_test_epoch')

        if is_model_wrapper(self.runner.model):
            model = self.runner.model.module
        else:
            model = self.runner.model

        model.gradient_checkpointing_disable()
        model.eval()
        model.cuda()

        rank = get_rank()
        metrics = []

        current_run_total_ids = 0
        for dataset in self.dataloader.dataset.datasets:
            if not hasattr(model, 'preparing_for_generation'):
                model.preparing_for_generation = MethodType(default_preparing_for_generation, model)
                self.runner.logger.warning(
                    'The model does not define `preparing_for_generation`; '
                    'falling back to the default implementation.')
            model.preparing_for_generation(dataset.metainfo)

            results = []
            n_samples = len(dataset)
            # Shard the dataset evenly across ranks; a rank that runs past the
            # end repeats the last sample so every rank performs the same number
            # of iterations (the padded duplicates are dropped by
            # collect_results, which truncates to n_samples).
            per_rank_samples = math.ceil(n_samples / get_world_size())
            assert per_rank_samples * get_world_size() >= n_samples
            per_rank_ids = range(per_rank_samples * rank, per_rank_samples * (rank + 1))
            for idx in per_rank_ids:
                if idx >= n_samples:
                    data_batch = dataset[n_samples - 1]
                else:
                    data_batch = dataset[idx]
                self.run_iter(current_run_total_ids, data_batch, results, model)
                current_run_total_ids += 1

            barrier()
            self.runner.logger.info('==================== Start collect results ===================')
            results = collect_results(results, n_samples)
            self.runner.logger.info('==================== Start dataset evaluation ===================')
            if is_main_process():
                metric = dataset.evaluate(results, self.runner.work_dir)
                objects = [metric]
            else:
                objects = [None]
            # Share the metrics computed on rank 0 with every other rank.
            broadcast_object_list(objects)
            metric = objects[0]
            metrics.append(metric)

        if self.select_metric == 'first':
            metrics = metrics[0]
        else:
            raise NotImplementedError

        self.runner.logger.info('================ Ending test loop ================')
        self.runner.call_hook('after_test_epoch', metrics=metrics)
        self.runner.call_hook('after_test')
        return metrics

    @torch.no_grad()
    def run_iter(self, idx, data_batch, results, model):
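        """Run the model on a single sample and append its prediction.

        ``results`` is mutated in place; the visualization dump only happens
        when ``self.visualize`` is set to an output directory.
        """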
        prediction = {'video_id': data_batch['video_id']}

        self.runner.call_hook(
            'before_test_iter', batch_idx=idx, data_batch=data_batch)

        outputs = model.predict_forward(**data_batch)
        prediction.update(outputs)
        results.append(prediction)

        if self.visualize:
            # Dump qualitative results for this sample.
            visualize(data_batch=data_batch, prediction=prediction, visualize_path=self.visualize)

        self.runner.call_hook(
            'after_test_iter',
            batch_idx=idx,
            data_batch=data_batch,
            outputs=outputs)


def default_preparing_for_generation(self, metainfo):
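    """Fallback generation setup bound to models lacking ``preparing_for_generation``.

    Assumes an InternLM2-chat prompt template and configures greedy decoding
    (up to 2048 new tokens) together with the template's stop-word criteria.
    """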
    assert hasattr(self, 'tokenizer'), 'The model must have a tokenizer to configure generation.'

    self.bot_name = 'BOT'
    template = PROMPT_TEMPLATE['internlm2_chat']
    self.template = template
    stop_words = []
    stop_words += template.get('STOP_WORDS', [])
    self.stop_criteria = get_stop_criteria(
        tokenizer=self.tokenizer, stop_words=stop_words)

    default_generation_kwargs = dict(
        max_new_tokens=2048,
        do_sample=False,
        eos_token_id=self.tokenizer.eos_token_id,
        pad_token_id=(
            self.tokenizer.pad_token_id
            if self.tokenizer.pad_token_id is not None
            else self.tokenizer.eos_token_id
        ),
    )
    # Dataset metainfo may override any of the defaults above.
    default_generation_kwargs.update(metainfo.get('generation_kwargs', {}))
    self.gen_config = GenerationConfig(**default_generation_kwargs)
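
# Minimal usage sketch (not taken from the original project config): this loop
# is typically registered as the test loop in an mmengine/xtuner config. The
# keys shown here map directly onto VideoTestLoop.__init__; the output
# directory is a placeholder.
#
# test_cfg = dict(
#     type=VideoTestLoop,
#     select_metric='first',
#     torch_dtype='fp16',
#     visualize='work_dirs/visualize',  # set to None to skip dumping images
# )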