"""Utilities for dumping model predictions and ground-truth annotations to TSV
files (image id, JSON annotation string, base64-encoded JPEG) so they can be
browsed in a TSV viewer."""

import base64
import json
import os
import os.path as op

import cv2
import numpy as np

def ensure_file(file_name):
    # make sure the directory that will contain file_name exists
    dir_name = os.path.dirname(file_name)
    if dir_name and not os.path.exists(dir_name):
        os.makedirs(dir_name, exist_ok=True)

class TSVResultWriter(object):
    """Accumulates per-image results and periodically flushes them to a TSV file
    together with a matching .lineidx offset file."""

    def __init__(self, tokenizer=None, max_visualize_num=-1, dataset_length=-1,
                 threshold=-1.0, in_order=True, write_freq=100, file_name=None):
        self.tokenizer = tokenizer
        self.max_visualize_num = max_visualize_num
        self.dataset_length = dataset_length
        self.threshold = threshold
        self.in_order = in_order
        self.file_name = file_name
        self.write_freq = write_freq
        self.predictions = []
        if not self.in_order:
            raise NotImplementedError("out-of-order result writing is not supported")

    @staticmethod
    def imagelist_to_b64(imgs):
        # NCHW tensor batch -> NHWC numpy array
        imgs = imgs.tensors.permute(0, 2, 3, 1).cpu().numpy()
        # undo ImageNet normalization; the channel order is BGR, so the std/mean
        # values are listed in reversed (B, G, R) order
        imgs = ((imgs * [0.225, 0.224, 0.229] + [0.406, 0.456, 0.485]) * 255).astype(np.uint8)
        # cv2.imencode expects BGR input, so no color conversion is needed here
        imgs = [base64.b64encode(cv2.imencode('.jpg', image)[1]) for image in imgs]
        return imgs

    def update(self, imgs, results):
        if self.max_visualize_num > 0 and len(self.predictions) >= self.max_visualize_num:
            return

        imgs = self.imagelist_to_b64(imgs)

        for img_encoded_str, result in zip(imgs, results):
            # result: (img_id, {"scores": ..., "labels_text": ..., "raw_boxes": ..., ...})
            annotations = result[1]
            boxes = annotations["raw_boxes"]
            pred = {}
            pred["objects"] = []

            for s, rect, l in zip(annotations["scores"], boxes.tolist(), annotations["labels_text"]):
                pred["objects"].append({"rect": rect,
                                        "class": l,
                                        "conf": float(s)})
            pred["num_boxes"] = len(pred["objects"])
            if "caption" in annotations and pred["objects"]:
                # record the caption in the first object; a workaround for the tsv viewer
                pred["objects"][0]["caption"] = annotations["caption"]

            pred["predicates"] = []
            pred["relations"] = []
            pred = [str(result[0]), json.dumps(pred, sort_keys=False), img_encoded_str]
            self.predictions.append(pred)

        if len(self.predictions) % self.write_freq == 0 or \
                (self.max_visualize_num > 0 and len(self.predictions) >= self.max_visualize_num):
            self.tsv_writer(self.predictions, self.file_name)


    def update_train_data(self, imgs, targets):
        if self.max_visualize_num > 0 and len(self.predictions) >= self.max_visualize_num:
            return

        imgs = self.imagelist_to_b64(imgs)
        for img_encoded_str, target in zip(imgs, targets):
            boxes = target.bbox
            pred = {}
            pred["objects"] = []
            pred["caption"] = [target.extra_fields["caption"]]
            caption_tokenized = self.tokenizer.tokenize(target.extra_fields["caption"])
            for rect, positive_map in zip(boxes.tolist(), target.extra_fields["positive_map"]):
                non_zero_indexes = positive_map.nonzero().squeeze(1).tolist()
                # positive_map indices are offset by one relative to the tokenized caption
                label = [caption_tokenized[i - 1] for i in non_zero_indexes]
                label = " ".join(label).replace(" ##", "")
                pred["objects"].append({"rect": rect,
                                        "class": label,
                                        "conf": 1.0})
            pred["num_boxes"] = len(pred["objects"])
            if pred["objects"]:
                # record the caption in the first object; a workaround for the tsv viewer
                pred["objects"][0]["caption"] = target.extra_fields["caption"]
            pred["predicates"] = []
            pred["relations"] = []
            pred = [str(0), json.dumps(pred, sort_keys=False), img_encoded_str]
            self.predictions.append(pred)
        if len(self.predictions) % self.write_freq == 0 or \
                (self.max_visualize_num > 0 and len(self.predictions) >= self.max_visualize_num):
            ensure_file(self.file_name)
            self.tsv_writer(self.predictions, self.file_name)

    def update_gold_od_data(self, imgs, targets, categories):
        if self.max_visualize_num > 0 and len(self.predictions) >= self.max_visualize_num:
            return

        imgs = self.imagelist_to_b64(imgs)
        for img_encoded_str, target in zip(imgs, targets):
            boxes = target["boxes"]
            pred = {}
            pred["objects"] = []

            for rect, label in zip(boxes.tolist(), target["labels"].tolist()):
                cat = categories[label]
                label_text = "{}_{}".format(cat["name"], cat["frequency"])
                pred["objects"].append({"rect": rect,
                                        "class": label_text,
                                        "conf": 1.0})
            pred["num_boxes"] = len(pred["objects"])
            pred["predicates"] = []
            pred["relations"] = []
            pred = [str(0), json.dumps(pred, sort_keys=False), img_encoded_str]
            self.predictions.append(pred)

        if len(self.predictions) % self.write_freq == 0 or \
                (self.max_visualize_num > 0 and len(self.predictions) >= self.max_visualize_num):
            ensure_file(self.file_name)
            print("Writing to {}".format(self.file_name))
            self.tsv_writer(self.predictions, self.file_name)

    @staticmethod
    def tsv_writer(values, tsv_file, sep='\t'):
        out_dir = op.dirname(tsv_file)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        lineidx_file = op.splitext(tsv_file)[0] + '.lineidx'
        idx = 0
        # write to temporary files first, then rename them into place at the end
        tsv_file_tmp = tsv_file + '.tmp'
        lineidx_file_tmp = lineidx_file + '.tmp'
        with open(tsv_file_tmp, 'w') as fp, open(lineidx_file_tmp, 'w') as fpidx:
            assert values is not None
            for value in values:
                assert value is not None
                # make python2- and python3-encoded image strings identical on disk:
                # python2 base64 output is a str starting with "/", while python3 output
                # is a bytes object starting with b"/", so bytes are decoded to str here.
                value = [v if not isinstance(v, bytes) else v.decode('utf-8') for v in value]
                v = '{0}\n'.format(sep.join(map(str, value)))
                fp.write(v)
                # the .lineidx file stores the starting offset of each TSV row
                fpidx.write(str(idx) + '\n')
                idx = idx + len(v)
        os.rename(tsv_file_tmp, tsv_file)
        os.rename(lineidx_file_tmp, lineidx_file)
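

if __name__ == "__main__":
    # Minimal smoke-test sketch, added for illustration only (not part of the original
    # training/evaluation pipeline): it fabricates one normalized image batch and one
    # fake detection result, then writes them out with TSVResultWriter. The tensor
    # shapes, the output path, and the single "person" detection are all made up.
    import torch
    from types import SimpleNamespace

    fake_images = SimpleNamespace(tensors=torch.zeros(1, 3, 32, 32))  # stands in for an ImageList
    fake_result = (0, {"scores": [0.9],
                       "labels_text": ["person"],
                       "raw_boxes": torch.tensor([[0.0, 0.0, 10.0, 10.0]])})

    writer = TSVResultWriter(max_visualize_num=1, write_freq=1,
                             file_name="output/visualize/predictions.tsv")
    writer.update(fake_images, [fake_result])
    # produces output/visualize/predictions.tsv plus the matching .lineidx file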