File size: 2,639 Bytes
4d1ebf3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import numpy as np
import torch

def all_to_onehot(masks, labels):
    """Convert an index mask into a stack of per-label binary masks.

    Works for masks of any rank (e.g. H*W or T*H*W).

    Args:
        masks: array of label indices (anything ``np.asarray`` accepts).
        labels: sequence of label values; one output channel per entry,
            in the given order.

    Returns:
        uint8 numpy array of shape ``(len(labels), *masks.shape)`` whose
        channel ``k`` is 1 exactly where ``masks == labels[k]`` and 0 elsewhere.
    """
    masks = np.asarray(masks)
    # Prepend a label axis to whatever shape the mask has, instead of
    # special-casing 2-D vs 3-D inputs.
    onehot = np.zeros((len(labels), *masks.shape), dtype=np.uint8)

    for channel, label in enumerate(labels):
        onehot[channel] = (masks == label).astype(np.uint8)

    return onehot

class MaskMapper:
    """
    This class is used to convert an indexed mask to a one-hot representation.
    It also takes care of remapping non-contiguous label values onto
    contiguous internal object indices 1..N (assigned in first-seen order).
    It has two modes:
        1. Default. Only masks with new indices are supposed to go into the remapper.
        This is also the case for YouTubeVOS.
        i.e., regions with index 0 are not "background", but "don't care".

        2. Exhaustive. Regions with index 0 are considered "background".
        Every single pixel is considered to be "labeled".
    """
    def __init__(self):
        self.clear_labels()

    def clear_labels(self):
        """Forget all registered labels and return to the identity mapping."""
        # Original label values in registration order; entry k corresponds to
        # internal object index k+1 and to one-hot channel k.
        self.labels = []
        # Original label value -> internal object index (1-based, contiguous).
        self.remappings = {}
        # if coherent, no mapping is required (every registered label equals
        # its internal index)
        self.coherent = True

    def convert_mask(self, mask, exhaustive=False):
        """Register the labels in ``mask`` and return its one-hot encoding.

        Args:
            mask: index mask as a numpy array (e.g. H*W); value 0 is never
                registered as an object.
            exhaustive: if True, the mask may contain labels registered by
                earlier calls; otherwise every non-zero label must be new.

        Returns:
            ``(onehot, new_mapped_labels)``. ``onehot`` is a float torch tensor
            with one channel per entry of ``self.labels`` (all labels registered
            so far), where channel k is 1 where ``mask == self.labels[k]``.
            ``new_mapped_labels`` holds internal object indices: all indices
            registered so far in exhaustive mode, otherwise only those added by
            this call.

        Raises:
            AssertionError: in non-exhaustive mode, if the mask contains a
                label that was already registered.
        """
        # Widen to int64 rather than uint8 so labels > 255 (or < 0) are not
        # wrapped onto other label values (e.g. 256 -> 0 would be dropped).
        labels = np.unique(mask).astype(np.int64)
        labels = labels[labels != 0].tolist()

        # Keep np.unique's ascending order (a set difference would not), so
        # internal indices are assigned deterministically.
        known = set(self.labels)
        new_labels = [label for label in labels if label not in known]
        if not exhaustive and len(new_labels) != len(labels):
            # Explicit raise instead of `assert` so the check survives
            # `python -O`; AssertionError is kept for existing callers.
            raise AssertionError('Old labels found in non-exhaustive mode')

        num_old = len(self.labels)
        # add new remappings: new labels take the next contiguous indices
        for index, label in enumerate(new_labels, start=num_old + 1):
            self.remappings[label] = index
            if index != label:
                self.coherent = False

        num_total = num_old + len(new_labels)
        if exhaustive:
            new_mapped_labels = range(1, num_total + 1)
        elif self.coherent:
            # identity mapping: internal indices equal the original labels
            new_mapped_labels = new_labels
        else:
            new_mapped_labels = range(num_old + 1, num_total + 1)

        self.labels.extend(new_labels)
        mask = torch.from_numpy(all_to_onehot(mask, self.labels)).float()

        # mask num_objects*H*W
        return mask, new_mapped_labels

    def remap_index_mask(self, mask):
        """Translate a mask of internal object indices back to original labels.

        Args:
            mask: numpy index mask whose values are internal object indices.

        Returns:
            ``mask`` itself when the mapping is the identity (coherent);
            otherwise a new array of the same shape and dtype where each
            internal index is replaced by its original label and every other
            value becomes 0.
        """
        if self.coherent:
            return mask

        # NOTE: the result shares mask's dtype (zeros_like), so original label
        # values must fit in that dtype.
        new_mask = np.zeros_like(mask)
        for label, index in self.remappings.items():
            new_mask[mask == index] = label
        return new_mask