import os
import json
import argparse
from tqdm import tqdm
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from pycocotools import mask as maskUtils
from pycocoevalcap.eval import COCOEvalCap
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity
import torch
import numpy as np
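
# This script scores Grounded Conversation Generation (GCG) predictions against the
# GranD-f ground truth on four axes: mask AP, caption quality, mask mIoU, and recall.
# Example invocation (the script filename and results path here are illustrative):
#   python evaluate_gcg.py --split val --prediction_dir_path ./results/<model_name>_val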

def parse_args():
    parser = argparse.ArgumentParser(description="GCG Evaluation")

    parser.add_argument("--split", required=True, help="Evaluation split; options are 'val' and 'test'.")
    parser.add_argument("--prediction_dir_path", required=True, help="The path where the inference results are stored.")
    parser.add_argument("--gt_dir_path", required=False, default="./data/glamm_data/annotations/gcg_val_test/",
                        help="The path containing GranD-f evaluation annotations.")

    args = parser.parse_args()

    return args


# Load pre-trained model tokenizer and model for evaluation
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")


def get_bert_embedding(text):
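    """Embed text with BERT and mean-pool the last hidden states into a sentence vector."""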
    inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
    outputs = model(**inputs)
    # Use the mean of the last hidden states as sentence embedding
    sentence_embedding = torch.mean(outputs.last_hidden_state[0], dim=0).detach().numpy()

    return sentence_embedding

def compute_iou(mask1, mask2):
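    """Intersection-over-Union between two binary masks."""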
    intersection = np.logical_and(mask1, mask2)
    union = np.logical_or(mask1, mask2)
    union_sum = np.sum(union)
    # Guard against division by zero when both masks are empty
    iou = np.sum(intersection) / union_sum if union_sum > 0 else 0.0

    return iou

def bbox_to_x1y1x2y2(bbox):
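    """Convert a COCO-style [x, y, w, h] box to [x1, y1, x2, y2]."""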
    x1, y1, w, h = bbox
    bbox = [x1, y1, x1 + w, y1 + h]

    return bbox

def compute_miou(pred_masks, gt_masks):
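    """Greedily pair predicted and ground-truth masks by IoU (one-to-one) and return the mean paired IoU."""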
    # Computing mIoU between predicted masks and ground truth masks
    iou_matrix = np.zeros((len(pred_masks), len(gt_masks)))
    for i, pred_mask in enumerate(pred_masks):
        for j, gt_mask in enumerate(gt_masks):
            iou_matrix[i, j] = compute_iou(pred_mask, gt_mask)

    # One-to-one pairing and mean IoU calculation
    paired_iou = []
    while iou_matrix.size > 0 and np.max(iou_matrix) > 0:
        max_iou_idx = np.unravel_index(np.argmax(iou_matrix, axis=None), iou_matrix.shape)
        paired_iou.append(iou_matrix[max_iou_idx])
        iou_matrix = np.delete(iou_matrix, max_iou_idx[0], axis=0)
        iou_matrix = np.delete(iou_matrix, max_iou_idx[1], axis=1)

    return np.mean(paired_iou) if paired_iou else 0.0


def evaluate_mask_miou(coco_gt, image_ids, pred_save_path):
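    """Compute the per-image mask mIoU over the given image ids and report the mean."""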
    # Load predictions
    coco_dt = coco_gt.loadRes(pred_save_path)

    mious = []
    for image_id in tqdm(image_ids):
        # Getting ground truth masks
        matching_anns = [ann for ann in coco_gt.anns.values() if ann['image_id'] == image_id]
        ann_ids = [ann['id'] for ann in matching_anns]

        gt_anns = coco_gt.loadAnns(ann_ids)
        gt_masks = [maskUtils.decode(ann['segmentation']) for ann in gt_anns if 'segmentation' in ann]

        # Getting predicted masks
        matching_anns = [ann for ann in coco_dt.anns.values() if ann['image_id'] == image_id]
        dt_ann_ids = [ann['id'] for ann in matching_anns]
        pred_anns = coco_dt.loadAnns(dt_ann_ids)
        pred_masks = [maskUtils.decode(ann['segmentation']) for ann in pred_anns if 'segmentation' in ann]

        # Compute and save the mIoU for the current image
        mious.append(compute_miou(pred_masks, gt_masks))

    # Report mean IoU across all images
    mean_miou = np.mean(mious) if mious else 0.0  # If list is empty, return 0.0

    print(f"Mean IoU (mIoU) across all images: {mean_miou:.3f}")


def compute_iou_matrix(pred_masks, gt_masks):
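    """Return the len(pred_masks) x len(gt_masks) matrix of pair-wise mask IoUs."""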
    iou_matrix = np.zeros((len(pred_masks), len(gt_masks)))
    for i, pred_mask in enumerate(pred_masks):
        for j, gt_mask in enumerate(gt_masks):
            iou_matrix[i, j] = compute_iou(pred_mask, gt_mask)

    return iou_matrix


def text_similarity_bert(str1, str2):
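    """Cosine similarity between the BERT embeddings of two phrases."""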
    emb1 = get_bert_embedding(str1)
    emb2 = get_bert_embedding(str2)

    return cosine_similarity([emb1], [emb2])[0, 0]


def find_best_matches(gt_anns, gt_labels, dt_anns, dt_labels, iou_threshold, text_sim_threshold, vectorizer=None):
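    """Greedily match ground-truth and predicted regions, keeping a pair only when both
    its mask IoU and its phrase similarity clear the given thresholds."""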
    best_matches = []

    # Compute pair-wise IoU (rows index ground-truth regions, columns index predictions)
    pred_masks = [maskUtils.decode(ann['segmentation']) for ann in dt_anns]
    gt_masks = [maskUtils.decode(ann['segmentation']) for ann in gt_anns]
    ious = compute_iou_matrix(gt_masks, pred_masks)

    text_sims = np.zeros((len(gt_labels), len(dt_labels)))

    for i, gt_label in enumerate(gt_labels):
        for j, dt_label in enumerate(dt_labels):
            text_sims[i, j] = text_similarity_bert(gt_label, dt_label)

    # Find one-to-one matches satisfying both IoU and text similarity thresholds
    while ious.size > 0:
        max_iou_idx = np.unravel_index(np.argmax(ious), ious.shape)
        if ious[max_iou_idx] < iou_threshold or text_sims[max_iou_idx] < text_sim_threshold:
            break  # No admissible pair found

        best_matches.append(max_iou_idx)

        # Remove selected annotations from consideration
        ious[max_iou_idx[0], :] = 0
        ious[:, max_iou_idx[1]] = 0
        text_sims[max_iou_idx[0], :] = 0
        text_sims[:, max_iou_idx[1]] = 0

    return best_matches  # List of index pairs [(gt_idx, dt_idx), ...]


def evaluate_recall_with_mapping(coco_gt, coco_cap_gt, image_ids, pred_save_path, cap_pred_save_path, iou_threshold=0.5,
                                 text_sim_threshold=0.5):
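    """Compute grounded recall: the fraction of ground-truth phrases whose region and
    label are recovered by a prediction under the IoU and text-similarity thresholds."""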
    coco_dt = coco_gt.loadRes(pred_save_path)
    coco_cap_dt = coco_cap_gt.loadRes(cap_pred_save_path)

    true_positives = 0
    actual_positives = 0

    for image_id in tqdm(image_ids):
        try:
            # gt_ann_ids = coco_gt.getAnnIds(imgIds=image_id, iscrowd=None)
            matching_anns = [ann for ann in coco_gt.anns.values() if ann['image_id'] == image_id]
            gt_ann_ids = [ann['id'] for ann in matching_anns]
            gt_anns = coco_gt.loadAnns(gt_ann_ids)

            # dt_ann_ids = coco_dt.getAnnIds(imgIds=image_id, iscrowd=None)
            matching_anns = [ann for ann in coco_dt.anns.values() if ann['image_id'] == image_id]
            dt_ann_ids = [ann['id'] for ann in matching_anns]
            dt_anns = coco_dt.loadAnns(dt_ann_ids)

            # gt_cap_ann_ids = coco_cap_gt.getAnnIds(imgIds=image_id)
            matching_anns = [ann for ann in coco_cap_gt.anns.values() if ann['image_id'] == image_id]
            gt_cap_ann_ids = [ann['id'] for ann in matching_anns]
            gt_cap_ann = coco_cap_gt.loadAnns(gt_cap_ann_ids)[0]

            # dt_cap_ann_ids = coco_cap_dt.getAnnIds(imgIds=image_id)
            matching_anns = [ann for ann in coco_cap_dt.anns.values() if ann['image_id'] == image_id]
            dt_cap_ann_ids = [ann['id'] for ann in matching_anns]
            dt_cap_ann = coco_cap_dt.loadAnns(dt_cap_ann_ids)[0]

            gt_labels = gt_cap_ann['labels']
            dt_labels = dt_cap_ann['labels']

            actual_positives += len(gt_labels)

            # Find best matching pairs
            best_matches = find_best_matches(gt_anns, gt_labels, dt_anns, dt_labels, iou_threshold, text_sim_threshold)

            true_positives += len(best_matches)
        except Exception as e:
            print(f"Skipping image {image_id} due to error: {e}")

    recall = true_positives / actual_positives if actual_positives > 0 else 0

    print(f"Recall: {recall:.3f}")


def main():
    args = parse_args()

    # Set the correct split
    split = args.split
    assert split in ("val", "test"), "GCG evaluation has only 'val' and 'test' splits"
    gt_mask_path = f"{args.gt_dir_path}/{split}_gcg_coco_mask_gt.json"
    gt_cap_path = f"{args.gt_dir_path}/{split}_gcg_coco_caption_gt.json"

    print(f"Starting evaluation on {split} split.")

    # Get the image names of the split
    all_images_ids = []
    with open(gt_cap_path, 'r') as f:
        contents = json.load(f)
        for image in contents['images']:
            all_images_ids.append(image['id'])

    # The directory is used to store intermediate files
    tmp_dir_path = f"tmp/{os.path.basename(args.prediction_dir_path)}_{split}"
    os.makedirs(tmp_dir_path, exist_ok=True)  # create the directory if it does not already exist

    # Create predictions
    pred_save_path = f"{tmp_dir_path}/mask_pred_tmp_save.json"
    cap_pred_save_path = f"{tmp_dir_path}/cap_pred_tmp_save.json"
    coco_pred_file = []
    caption_pred_dict = {}
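    # Each per-image prediction file is expected to hold a dict with 'caption', 'phrases',
    # and 'pred_masks' (COCO RLEs), possibly nested under a single extra top-level key.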
    for image_id in all_images_ids:
        prediction_path = f"{args.prediction_dir_path}/{image_id}.json"
        with open(prediction_path, 'r') as f:
            pred = json.load(f)
            # Predictions may be nested under a single top-level key; keep the original
            # dict around so we can fall back to it if the nested lookup fails.
            top_level_pred = pred
            key = list(pred.keys())[0]
            pred = pred[key]
            try:
                caption_pred_dict[image_id] = {'caption': pred['caption'], 'labels': pred['phrases']}
            except Exception:
                pred = top_level_pred
                caption_pred_dict[image_id] = {'caption': pred['caption'], 'labels': pred['phrases']}
            for rle_mask in pred['pred_masks']:
                coco_pred_file.append({"image_id": image_id, "category_id": 1, "segmentation": rle_mask, "score": 1.0})

    # Save gcg_coco_predictions
    with open(pred_save_path, 'w') as f:
        json.dump(coco_pred_file, f)

    # Prepare the CAPTION predictions in COCO format
    cap_image_ids = []
    coco_cap_pred_file = []
    for image_id, values in caption_pred_dict.items():
        cap_image_ids.append(image_id)
        coco_cap_pred_file.append({"image_id": image_id, "caption": values['caption'], "labels": values['labels']})

    # Save gcg_caption_coco_predictions
    with open(cap_pred_save_path, 'w') as f:
        json.dump(coco_cap_pred_file, f)

    # # -------------------------------#
    # 1. Evaluate AP
    # Calculate mask mAP
    # Load the ground truth and predictions in COCO format
    coco_gt = COCO(gt_mask_path)
    coco_dt = coco_gt.loadRes(pred_save_path)  # load predictions
    # Initialize COCOEval and specify the metric you want to use
    coco_eval = COCOeval(coco_gt, coco_dt, "segm")  # "segm" for segmentation
    # Evaluate on a specific category
    coco_eval.params.catIds = [1]  # all GCG masks share a single category with ID 1
    # Evaluate
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()

    # # -------------------------------#
    # 2. Evaluate Caption Quality
    try:
        coco_cap_gt = COCO(gt_cap_path)
        coco_cap_result = coco_cap_gt.loadRes(cap_pred_save_path)
        # create coco_eval object by taking coco and coco_result
        coco_eval = COCOEvalCap(coco_cap_gt, coco_cap_result)
        coco_eval.params['image_id'] = coco_cap_result.getImgIds()
        coco_eval.evaluate()
        for metric, score in coco_eval.eval.items():
            print(f'{metric}: {score:.3f}')
    except Exception as e:
        print(f"Caption evaluation failed: {e}")

    # # -------------------------------#
    # 3. Evaluate Mean Mask IoU (mIoU)
    coco_gt = COCO(gt_mask_path)  # Load ground truth annotations
    evaluate_mask_miou(coco_gt, all_images_ids, pred_save_path)

    # # -------------------------------#
    # 4. Evaluate Recall
    evaluate_recall_with_mapping(coco_gt, coco_cap_gt, all_images_ids, pred_save_path, cap_pred_save_path,
                                 iou_threshold=0.5, text_sim_threshold=0.5)


if __name__ == "__main__":
    main()