Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import torch | |
def all_to_onehot(masks, labels):
    """Split an index mask into a stack of binary masks, one per label.

    Args:
        masks: numpy array of label indices. Any rank is accepted; the
            leading axis of the result is added in front of its shape.
        labels: sequence of label values. Channel ``k`` of the result marks
            the positions where ``masks == labels[k]``.

    Returns:
        ``np.uint8`` array of shape ``(len(labels),) + masks.shape`` holding
        1 where the mask equals the corresponding label and 0 elsewhere.
        An empty ``labels`` yields an array with a zero-length first axis.
    """
    # Derive the output shape from the input instead of special-casing 2-D
    # and 3-D masks, so every rank is handled by the same code path.
    onehot = np.zeros((len(labels),) + tuple(masks.shape), dtype=np.uint8)
    for channel, label in enumerate(labels):
        # The boolean comparison is cast to 0/1 on assignment into uint8.
        onehot[channel] = masks == label
    return onehot
class MaskMapper:
    """
    This class is used to convert an indexed mask to a one-hot representation.
    It also takes care of remapping non-continuous indices.

    It has two modes:
    1. Default. Only masks with new indices are supposed to go into the remapper.
       This is also the case for YouTubeVOS.
       i.e., regions with index 0 are not "background", but "don't care".
    2. Exhaustive. Regions with index 0 are considered "background".
       Every single pixel is considered to be "labeled".

    State (reset by ``clear_labels``):
        labels: original label values seen so far, in one-hot channel order.
        remappings: dict mapping an original label value to its 1-based
            internal index.
        coherent: True while every label equals its internal index, in which
            case no remapping is required.
    """

    def __init__(self):
        # Single source of truth for the initial state.
        self.clear_labels()

    def clear_labels(self):
        """Forget every label seen so far and return to the initial state."""
        self.labels = []
        self.remappings = {}
        # if coherent, no mapping is required
        self.coherent = True

    def convert_mask(self, mask, exhaustive=False):
        """Register the labels found in ``mask`` and return it as one-hot.

        Args:
            mask: index-representation numpy array (H*W). Value 0 is never
                registered as a label.
            exhaustive: when False, ``mask`` may only contain labels that
                have not been registered before.

        Returns:
            A tuple ``(onehot, new_mapped_labels)``. ``onehot`` is a float
            tensor of shape num_objects*H*W with one channel per entry of
            ``self.labels`` (all labels registered so far, not only the new
            ones). ``new_mapped_labels`` is: in exhaustive mode, the range of
            all internal indices ``1..len(self.labels)``; otherwise the new
            labels themselves while the mapping is coherent, else the range
            of internal indices assigned to the new labels.

        Raises:
            AssertionError: in non-exhaustive mode, if ``mask`` contains a
                label that is already registered.
        """
        # NOTE: the ids are cast to uint8, so values above 255 wrap around.
        labels = np.unique(mask).astype(np.uint8)
        labels = labels[labels != 0].tolist()

        # Sort the new labels: set iteration order is arbitrary (e.g. {7, 8}
        # iterates as 8, 7 in CPython), which would make the index assigned
        # to each label unpredictable and could mark a perfectly contiguous
        # labelling as non-coherent.
        new_labels = sorted(set(labels) - set(self.labels))
        if not exhaustive:
            assert len(new_labels) == len(labels), 'Old labels found in non-exhaustive mode'

        # add new remappings; internal indices continue after the known labels
        first_index = len(self.labels) + 1
        for index, label in enumerate(new_labels, start=first_index):
            self.remappings[label] = index
            if index != label:
                self.coherent = False

        last_index = len(self.labels) + len(new_labels)
        if exhaustive:
            new_mapped_labels = range(1, last_index + 1)
        elif self.coherent:
            new_mapped_labels = new_labels
        else:
            new_mapped_labels = range(first_index, last_index + 1)

        self.labels.extend(new_labels)
        mask = torch.from_numpy(all_to_onehot(mask, self.labels)).float()

        # mask num_objects*H*W
        return mask, new_mapped_labels

    def remap_index_mask(self, mask):
        """Translate a mask of internal indices back to original label values.

        Args:
            mask: index-representation numpy array (H*W) holding internal
                indices.

        Returns:
            ``mask`` itself (not a copy) while the mapping is coherent.
            Otherwise a new array of the same shape and dtype in which every
            internal index is replaced by its original label; positions whose
            value is not a registered internal index are left at 0.
        """
        if self.coherent:
            return mask

        restored = np.zeros_like(mask)
        for original_label, internal_index in self.remappings.items():
            restored[mask == internal_index] = original_label
        return restored