| | import torch |
| | import os |
| | from enum import Enum |
| | from tqdm import tqdm |
| | import numpy as np |
| | |
| | |
| | |
| | |
| | |
| | |
| | import cv2 |
| | |
| | |
| | from psalm.train.train_datasets_eval import COCO_interactive_dataset_extrametric |
| | |
| | |
| | from typing import Dict, Optional, Sequence, List |
| | from dataclasses import dataclass, field |
| | import torch.distributed as dist |
| | import transformers |
| | from pathlib import Path |
| | from psalm.eval.segmentation_evaluation import openseg_classes |
| | from natsort import natsorted |
| | COLOR_MAP = openseg_classes.ADE20K_150_CATEGORIES |
| | import re |
| | from psalm.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, SEG_TOKEN_INDEX, CLS_TOKEN_INDEX, REGION_TOKEN_INDEX, REFER_TOKEN_INDEX |
| | import json |
| | from pycocotools import mask as mask_utils |
| |
|
| | |
@dataclass
class DataArguments:
    """Command-line options for this evaluation script.

    ``transformers.HfArgumentParser`` turns every field below into a
    ``--<field_name>`` flag. Some fields are read by the code in this file;
    the rest are presumably consumed by model / dataset setup code elsewhere
    (not visible here -- confirm before removing any of them).
    """

    # ---- data locations ---------------------------------------------------
    data_path: str = field(default=None,
                           metadata={"help": "Path to the training data."})
    lazy_preprocess: bool = False
    is_multimodal: bool = False
    # Root folder that relative image paths / per-take folders are joined to.
    image_folder: Optional[str] = '/path/to/val2017'
    model_path: Optional[str] = "/path/to/model"
    mask_config: Optional[str] = "./psalm/mask_config/maskformer2_swin_base_384_bs16_50ep.yaml"
    image_aspect_ratio: str = 'square'
    image_grid_pinpoints: Optional[str] = None
    json_path: str = '/path/to/coco'

    # ---- model / prompt configuration -------------------------------------
    model_map_name: str = 'psalm_video'
    version: str = 'llava_phi'
    segmentation: bool = True

    # ---- data loading ------------------------------------------------------
    eval_batch_size: int = 1
    dataloader_num_workers: int = 8

    # ---- segmentation task -------------------------------------------------
    seg_task: Optional[str] = "region"
    # '||'-separated list of mask types; split on '||' by the dataset code.
    region_mask_type: Optional[str] = None
    with_memory: bool = False

    # ---- run control -------------------------------------------------------
    resume: bool = False
    using_autocast: bool = False
    resume_path: Optional[str] = None
    # Output encoding for predicted masks (the script accepts 'rle' or 'png').
    save_format: Optional[str] = None
| |
|
| |
|
| | |
@dataclass
class DataCollatorForCOCODatasetV2(object):
    """Collate per-sample dicts into a single batch dict.

    * ``input_ids`` are right-padded with ``tokenizer.pad_token_id`` and
      ``labels`` with ``IGNORE_INDEX``; both are then truncated to
      ``tokenizer.model_max_length``. ``attention_mask`` marks non-pad tokens.
    * ``image`` / ``vp_image`` tensors are stacked when every sample has one of
      identical shape; otherwise they are passed through as a list.
    * Everything that remains in each sample dict is exposed as
      ``batch['seg_info']``.

    NOTE: the input dicts are mutated in place. Their ``input_ids``, ``labels``
    and ``image`` entries are removed before the dicts are stored in
    ``seg_info``; ``vp_image`` and other keys are kept.
    """

    tokenizer: transformers.PreTrainedTokenizer

    @staticmethod
    def _stack_or_list(tensors):
        """Stack ``tensors`` if all are non-None with one shared shape, else return the list unchanged."""
        if all(t is not None and t.shape == tensors[0].shape for t in tensors):
            return torch.stack(tensors)
        return tensors

    @staticmethod
    def _stack_or_pad(tensors, padding_value=-1):
        """Stack equal-shaped tensors; right-pad differing ones along dim 0 with ``padding_value``."""
        if any(t.shape != tensors[0].shape for t in tensors):
            return torch.nn.utils.rnn.pad_sequence(
                tensors, batch_first=True, padding_value=padding_value)
        return torch.stack(tensors, dim=0)

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        max_len = self.tokenizer.model_max_length
        pad_id = self.tokenizer.pad_token_id

        input_ids = torch.nn.utils.rnn.pad_sequence(
            [instance["input_ids"] for instance in instances],
            batch_first=True,
            padding_value=pad_id)[:, :max_len]
        labels = torch.nn.utils.rnn.pad_sequence(
            [instance["labels"] for instance in instances],
            batch_first=True,
            padding_value=IGNORE_INDEX)[:, :max_len]
        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(pad_id),
        )

        # Presence of optional keys is decided by the first sample only.
        if 'image' in instances[0]:
            batch['images'] = self._stack_or_list([instance['image'] for instance in instances])
        if 'vp_image' in instances[0]:
            batch['vp_images'] = self._stack_or_list([instance['vp_image'] for instance in instances])

        for instance in instances:
            for key in ('input_ids', 'labels', 'image'):
                # pop() instead of del: samples without an 'image' entry
                # previously raised KeyError here.
                instance.pop(key, None)
        batch['seg_info'] = list(instances)

        if 'dataset_type' in instances[0]:
            batch['dataset_type'] = [instance['dataset_type'] for instance in instances]

        if 'class_name_ids' in instances[0]:
            batch['class_name_ids'] = self._stack_or_pad(
                [instance['class_name_ids'] for instance in instances], padding_value=-1)
        if 'token_refer_id' in instances[0]:
            batch['token_refer_id'] = [instance['token_refer_id'] for instance in instances]
        if 'cls_indices' in instances[0]:
            batch['cls_indices'] = self._stack_or_pad(
                [instance['cls_indices'] for instance in instances], padding_value=-1)
        if 'random_idx' in instances[0]:
            batch['random_idx'] = torch.stack(
                [instance['random_idx'] for instance in instances], dim=0)
        if 'class_name_embedding_indices' in instances[0]:
            batch['class_name_embedding_indices'] = torch.nn.utils.rnn.pad_sequence(
                [instance['class_name_embedding_indices'] for instance in instances],
                batch_first=True,
                padding_value=0)
        if 'refer_embedding_indices' in instances[0]:
            batch['refer_embedding_indices'] = torch.nn.utils.rnn.pad_sequence(
                [instance['refer_embedding_indices'] for instance in instances],
                batch_first=True,
                padding_value=0)

        return batch
| | |
| | |
def parse_outputs(outputs, gt_mask):
    """Flatten model outputs into one record per ground-truth mask.

    Args:
        outputs: iterable of dicts. Each has an ``'instances'`` object exposing
            ``pred_masks`` and ``scores`` tensors (optionally ``pred_classes``),
            and a ``'gt'`` tensor whose first axis indexes ground-truth masks.
        gt_mask: unused. It is kept only for call-site compatibility; the ground
            truth is always taken from ``output['gt']``.

    Returns:
        list of dicts, one per ground-truth mask, with keys:
        ``'pred'`` (all predicted masks of that output as a numpy array),
        ``'gt'`` (the ground-truth mask as ``uint8``), ``'scores'`` (the
        matching row of the transposed score matrix) and ``'pred_cls'``
        (numpy array, or ``None`` when the instances carry no classes).

    Raises:
        AssertionError: if the transposed score matrix does not have one row
            per ground-truth mask.
    """
    res_list = []
    for output in outputs:
        instances = output['instances']
        pred_mask = instances.pred_masks.cpu().numpy()
        # scores must be 2-D (transpose(1, 0) requires it); after transposing,
        # axis 0 has to line up with the ground-truth masks.
        scores = instances.scores.transpose(1, 0).cpu().numpy()
        gt = output['gt'].cpu().numpy().astype(np.uint8)
        try:
            pred_cls = instances.pred_classes.cpu().numpy()
        except AttributeError:
            # Outputs without class predictions (was a bare ``except:``,
            # which also swallowed KeyboardInterrupt and real bugs).
            pred_cls = None
        assert scores.shape[0] == gt.shape[0], (
            f"got {scores.shape[0]} score rows for {gt.shape[0]} ground-truth masks")
        for gt_single, score_row in zip(gt, scores):
            res_list.append({
                'pred': pred_mask,
                'gt': gt_single,
                'scores': score_row,
                'pred_cls': pred_cls,
            })
    return res_list
| |
|
| |
|
| | |
class DAVIS_Dataset(COCO_interactive_dataset_extrametric):
    """Frame-level dataset pairing a target frame with first-frame (query) annotations.

    Each record is converted into the dict expected by the ``'null_mask'``
    image processor, then paired with a region-segmentation prompt containing
    one ``<region>`` placeholder per instance returned by the processor.
    """

    def __getitem__(self, idx):
        record = self.data[idx]
        folder = self.data_args.image_folder
        frame_id = record['new_img_id']

        data_dict = {
            'file_name': os.path.join(folder, record['image']),
            'height': record['image_info']['height'],
            'width': record['image_info']['width'],
            'image_id': frame_id,
            'annotations': record['anns'],
            'vp_annotations': record['first_frame_anns'],
            'vp_image': os.path.join(folder, record['first_frame_image']),
        }

        # Both target and query (visual-prompt) annotations get a dummy
        # absolute-XYXY box and the frame's id. Note that this mutates the
        # dicts stored in self.data.
        # NOTE(review): BoxMode is not imported anywhere in this file, so this
        # raises NameError unless it is brought into scope (presumably
        # detectron2.structures.BoxMode) -- confirm and add the import.
        for ann_list in (data_dict['annotations'], data_dict['vp_annotations']):
            for ann in ann_list:
                ann['bbox_mode'] = BoxMode.XYXY_ABS
                ann['bbox'] = [0, 0, 0, 0]
                ann['image_id'] = frame_id

        processor = self.data_args.image_processor['null_mask']
        mask_types = getattr(self.data_args, 'region_mask_type', None)
        if mask_types is not None:
            mask_types = mask_types.split('||')
        data_dict = processor.preprocess(data_dict, region_mask_type=mask_types, mask_format='bitmask')

        # One ' <region>' per instance, comma-separated, ending with '.'.
        # Zero instances still yields a single ' <region>.' placeholder.
        num_target = len(data_dict['instances'])
        regions = ','.join([' <region>'] * max(num_target, 1)) + '.'
        prompt = ('This is an image <image>, Please segment by given regions'
                  + f'\nThis is all regions: {regions}\n')

        conversation = [[
            {'from': 'human', 'value': prompt},
            {'from': 'gpt', 'value': '\n[SEG]<seg>'},
        ]]
        tokenized = self.preprocess_llama2(conversation, self.tokenizer)

        data_dict['input_ids'] = tokenized['input_ids'][0]
        data_dict['labels'] = tokenized['labels'][0]
        data_dict['dataset_type'] = 'region_coco'
        return data_dict
| |
|
| |
|
def _load_json(path):
    """Read and return the JSON document stored at ``path``."""
    with open(path, "r") as fp:
        return json.load(fp)


def _dump_json_atomic(obj, path):
    """Write ``obj`` as JSON to ``path`` via a temp file + rename so a crash never leaves a truncated file."""
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w") as fp:
        json.dump(obj, fp)
    os.replace(tmp_path, path)


def _group_by_take(datas):
    """Bucket frame-level records by ``record['video_name']`` (the take id), preserving input order."""
    grouped = {}
    for record in datas:
        grouped.setdefault(record['video_name'], []).append(record)
    return grouped


def _make_placeholder_mask(h, w, save_format, png_path):
    """Create a random binary (h, w) mask and serialise it.

    Returns the value stored under ``'pred_mask'``: a COCO RLE dict with
    ASCII ``counts`` for ``'rle'``, or ``png_path`` (after writing the PNG)
    for ``'png'``.
    """
    cur_pred = np.random.randint(0, 2, (h, w), dtype=np.uint8)
    if save_format == 'rle':
        rle = mask_utils.encode(np.asfortranarray(cur_pred))
        rle['counts'] = rle['counts'].decode('ascii')
        return rle
    if save_format == 'png':
        os.makedirs(os.path.dirname(png_path), exist_ok=True)
        cv2.imwrite(png_path, cur_pred)
        return png_path
    raise ValueError(f"Unsupported save format: {save_format}")


def evaluation():
    """Produce placeholder Ego-Exo4D correspondence predictions for test takes.

    For every take in the first eighth of the test split, one random binary
    mask is generated per (frame, query object). Each mask is stored as COCO
    RLE or as a PNG, depending on ``--save_format``, and collected into
    ``result[take_id]``. ``result`` is written to ``save_path_json`` after
    every take, so an interrupted run can be continued with
    ``--resume --resume_path <save_path_json>``.

    NOTE: no model is run. Predictions are ``np.random.randint`` noise, so
    ``--using_autocast`` has no effect and this only exercises the output
    format.

    Raises:
        ValueError: for an unsupported ``--save_format``, ``--resume``
            without ``--resume_path``, or when a query object / target camera
            is missing from a take's annotations.
    """
    parser = transformers.HfArgumentParser(DataArguments)
    data_args = parser.parse_args_into_dataclasses()[0]

    # Fail fast instead of after the first take has been partially processed.
    if data_args.save_format not in ('rle', 'png'):
        raise ValueError(f"Unsupported save format: {data_args.save_format}")

    save_path_json = "/scratch/yuqian_fu/competition_test_20250518_hardcode_v2_exoego_new.json"
    data_path = "/home/yuqian_fu/Projects/PSALM/egoexo_test_framelevel.json"
    splits_path = "/home/yuqian_fu/Projects/ego-exo4d-relation/correspondence/split.json"
    png_root = "/scratch/yuqian_fu/results_v2"

    # Group once up front instead of rescanning every record for each take.
    datas_by_take = _group_by_take(_load_json(data_path))
    takes_all = _load_json(splits_path)["test"]
    # Hard-coded: only the first eighth of the test takes is processed.
    takes_all = takes_all[:len(takes_all) // 8]

    if data_args.resume:
        if data_args.resume_path is None:
            raise ValueError("--resume requires --resume_path")
        result = _load_json(data_args.resume_path)
        # Skip takes already present in the resumed result file.
        takes_all = [take_id for take_id in takes_all if take_id not in result]
    else:
        result = {}

    anno_miss_num = 0

    with torch.no_grad():
        for take_id in tqdm(takes_all):
            print("current take_id:", take_id)
            annotations = _load_json(f'{data_args.image_folder}/{take_id}/annotation.json')

            # Record category ids are 1-based positions into the naturally
            # sorted object names of this take.
            objs = natsorted(annotations["masks"].keys())
            cont_id_to_obj = {cont_id: obj for cont_id, obj in enumerate(objs, start=1)}

            pred_json = {'masks': {}, 'subsample_idx': annotations['subsample_idx']}
            objs_after = set()

            for record in datas_by_take.get(take_id, []):
                # Camera names are the parent directory of each frame path.
                target_cam = record['image'].split('/')[-2]
                query_cam = record['first_frame_image'].split('/')[-2]
                pair_key = f'{query_cam}_{target_cam}'
                frame_id = record['image'].split('/')[-1].split('.')[0]
                h = record['image_info']['height']
                w = record['image_info']['width']

                for ann in record['first_frame_anns']:
                    cur_fill_number = ann['category_id']
                    if cur_fill_number not in cont_id_to_obj:
                        print(f"cur_fill_number {cur_fill_number} not in id_range, skipping...")
                        raise ValueError(f"cur_fill_number {cur_fill_number} not in id_range, skipping...")

                    obj_name = cont_id_to_obj[cur_fill_number]
                    objs_after.add(obj_name)

                    obj_masks = annotations['masks'][obj_name]
                    if target_cam not in obj_masks:
                        print(f"target_cam {target_cam} not in {obj_name}, skipping...")
                        raise ValueError(f"target_cam {target_cam} not in {obj_name}, skipping...")
                    if frame_id not in obj_masks[target_cam]:
                        # Counted only; a prediction is still emitted for this frame.
                        anno_miss_num += 1

                    png_path = f'{png_root}/{take_id}/{target_cam}/{obj_name}/{frame_id}.png'
                    pred_mask = _make_placeholder_mask(h, w, data_args.save_format, png_path)
                    pred_json['masks'].setdefault(obj_name, {}).setdefault(pair_key, {})[frame_id] = {
                        'pred_mask': pred_mask,
                        'confidence': 1.0,
                    }

            if len(pred_json['masks']) == 0:
                # Despite the message, the take is still recorded below.
                print(f"pred_json['masks'] is empty for take_id {take_id}, skipping...")

            # Objects never queried in this take get an empty entry so every
            # annotated object appears in the output.
            for obj in set(objs) - objs_after:
                print(f"{take_id}缺失物体{obj}")  # "<take_id> is missing object <obj>"
                pred_json['masks'][obj] = {}

            result[take_id] = pred_json
            # Previously `result` was never written and all output was lost.
            # Persist after every take so `--resume` can pick up from here.
            _dump_json_atomic(result, save_path_json)

    print(f"Total number of missing annotations: {anno_miss_num}")
| | |
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not on import.
    evaluation()
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|