File size: 3,173 Bytes
47af768
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import json
import os
from .burst_helpers.burst_ow_base import BURST_OW_Base
from .burst_helpers.format_converter import GroundTruthBURSTFormatToTAOFormatConverter, PredictionBURSTFormatToTAOFormatConverter
from .. import utils


class BURST_OW(BURST_OW_Base):
    """BURST open-world dataset, evaluated on segmentation masks.

    Ground truth and BURST-format predictions are converted to TAO format with
    the ``*BURSTFormatToTAOFormatConverter`` helpers before evaluation.
    """

    @staticmethod
    def get_default_dataset_config():
        """Return the base defaults with the BURST open-world GT/tracker folders."""
        config = BURST_OW_Base.get_default_dataset_config()
        root = utils.get_code_path()
        # Ground-truth annotations (all classes, validation split).
        config['GT_FOLDER'] = os.path.join(
            root, 'data/gt/burst/all_classes/val/')
        # Tracker outputs for the open-world validation setting.
        config['TRACKERS_FOLDER'] = os.path.join(
            root, 'data/trackers/burst/open-world/val/')
        return config

    def _iou_type(self):
        """Similarities for this dataset are computed between masks."""
        return 'mask'

    def _box_or_mask_from_det(self, det):
        """Return ``det['segmentation']`` if that key exists, else ``det['mask']``."""
        return det["segmentation"] if "segmentation" in det else det["mask"]

    def _calculate_area_for_ann(self, ann):
        """Return the pycocotools mask area of an annotation's segmentation."""
        # Imported lazily so pycocotools is only needed when areas are computed.
        from pycocotools import mask as mask_utils
        return mask_utils.area(self._box_or_mask_from_det(ann))

    def _calculate_similarities(self, gt_dets_t, tracker_dets_t):
        """Return mask IoU (not IoA) between already-encoded GT and tracker masks."""
        return self._calculate_mask_ious(
            gt_dets_t, tracker_dets_t, is_encoded=True, do_ioa=False)

    def _postproc_ground_truth_data(self, data):
        """Convert BURST-format ground truth into TAO format."""
        converter = GroundTruthBURSTFormatToTAOFormatConverter(data)
        return converter.convert()

    def _postproc_prediction_data(self, data):
        """Bring predictions into TAO format keyed by the ground-truth image ids.

        A non-list input is treated as BURST-format predictions and converted
        (non-exemplar-guided).  A list is treated as TAO-format predictions
        whose ``image_id`` values are TAO ids; those are remapped in place to
        the ids used by ``self.gt_data`` and the same list is returned.
        """
        if not isinstance(data, list):
            converter = PredictionBURSTFormatToTAOFormatConverter(
                self.gt_data, data, exemplar_guided=False)
            return converter.convert()

        # Already TAO format, but the image ids do not match the GT ids.
        _remap_image_ids(data, self.gt_data)
        return data


def _remap_image_ids(pred_data, ali_gt_data, tao_gt_path=None):
    """Rewrite TAO image ids in ``pred_data`` to the matching BURST ("Ali") ids.

    Images are matched by file name.  The TAO ground-truth JSON supplies the
    TAO id -> file name mapping.  ``ali_gt_data['images']`` supplies the
    file name -> BURST id mapping; "validation" in BURST file names is
    rewritten to "val" so the names line up with TAO's.

    Args:
        pred_data: list of detection dicts, each with an ``'image_id'`` key
            holding a TAO image id.  Modified in place.
        ali_gt_data: BURST ground-truth dict.  It must contain ``'images'``
            (dicts with ``'id'`` and ``'file_name'``) and may contain
            ``'split'`` (treated as ``'val'`` when absent).
        tao_gt_path: optional path to the TAO ground-truth JSON.  When None,
            the file is chosen under the code path from the split: the
            validation GT for ``'val'``/``'validation'``, otherwise the test
            file without annotations.

    Raises:
        KeyError: if a prediction references an image id that is not in the
            TAO ground truth, or whose file name has no BURST counterpart.
            All ids are resolved before any detection is modified, so
            ``pred_data`` is left untouched when this is raised.
    """
    if tao_gt_path is None:
        split = ali_gt_data.get('split', 'val')
        if split in ('val', 'validation'):
            rel_path = 'data/gt/tao/tao_validation/gt.json'
        else:
            rel_path = 'data/gt/tao/tao_test/test_without_annotations.json'
        tao_gt_path = os.path.join(utils.get_code_path(), rel_path)

    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with open(tao_gt_path, encoding='utf-8') as f:
        tao_gt = json.load(f)

    tao_file_name_by_id = {img['id']: img['file_name'] for img in tao_gt['images']}
    # Normalise BURST file names to TAO's spelling ("validation" -> "val").
    ali_img_id_by_filename = {
        img['file_name'].replace("validation", "val"): img['id']
        for img in ali_gt_data['images']
    }

    # Only images actually referenced by predictions are looked up, so TAO
    # images missing from the BURST GT do not cause spurious failures.
    # Every id is resolved before any assignment to avoid partial remapping.
    new_ids = []
    for det in pred_data:
        tao_img_id = det['image_id']
        try:
            file_name = tao_file_name_by_id[tao_img_id]
        except KeyError:
            raise KeyError(
                f"prediction references image id {tao_img_id!r}, which is not "
                f"in the TAO ground truth {tao_gt_path!r}") from None
        try:
            new_ids.append(ali_img_id_by_filename[file_name])
        except KeyError:
            raise KeyError(
                f"TAO image {tao_img_id!r} ({file_name!r}) has no matching "
                f"image in the BURST ground truth") from None

    for det, ali_img_id in zip(pred_data, new_ids):
        det['image_id'] = ali_img_id