File size: 1,812 Bytes
82d5d16 bcc8459 6d6f3c6 82d5d16 bcc8459 82d5d16 6d6f3c6 82d5d16 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
from typing import Dict, List, Tuple, Union
import numpy as np
import torch
from networks import deeplabv3plus_resnet50
from networks import convert_to_separable_conv, set_bn_momentum
def get_network() -> torch.nn.Module:
    """Build a DeepLabv3+ (ResNet-50) network with the NamedMask VOC 2012 weights.

    The network is created with 21 output classes and no pretrained backbone.
    Its classifier head is converted to separable convolutions, and the
    backbone's batch-norm momentum is set to 0.01. The NamedMask VOC 2012
    checkpoint is then downloaded (or reused from the torch.hub cache) and
    loaded with ``strict=True``, so the checkpoint keys must match the
    resulting module exactly.

    Returns:
        The network with the checkpoint weights loaded. This function does
        not move it to another device or change its train/eval mode.
    """
    network = deeplabv3plus_resnet50(num_classes=21, pretrained_backbone=False)
    convert_to_separable_conv(network.classifier)
    set_bn_momentum(network.backbone, momentum=0.01)
    # load_state_dict copies values into the module's existing parameters, so
    # the checkpoint tensors are only a temporary source. Mapping them to CUDA
    # would allocate GPU memory for a throw-away copy; CPU is enough.
    state_dict = torch.hub.load_state_dict_from_url(
        "https://www.robots.ox.ac.uk/~vgg/research/namedmask/shared_files/voc2012/namedmask_voc2012.pt",
        map_location=torch.device("cpu"),
    )
    network.load_state_dict(state_dict, strict=True)
    return network
# RGB colour for each label id that colourise_mask can render: ids 0-20 plus
# 255, which is drawn in white (presumably the VOC "void"/ignore label).
_VOC2012_PALETTE = {
    0: (0, 0, 0),
    1: (128, 0, 0),
    2: (0, 128, 0),
    3: (128, 128, 0),
    4: (0, 0, 128),
    5: (128, 0, 128),
    6: (0, 128, 128),
    7: (128, 128, 128),
    8: (64, 0, 0),
    9: (192, 0, 0),
    10: (64, 128, 0),
    11: (192, 128, 0),
    12: (64, 0, 128),
    13: (192, 0, 128),
    14: (64, 128, 128),
    15: (192, 128, 128),
    16: (0, 64, 0),
    17: (128, 64, 0),
    18: (0, 192, 0),
    19: (128, 192, 0),
    20: (0, 64, 128),
    255: (255, 255, 255),
}


def colourise_mask(
    mask: np.ndarray,
) -> np.ndarray:
    """Render a 2-D label mask as an RGB image using the VOC 2012 palette.

    Args:
        mask: 2-D array of label ids, shape ``(H, W)``.

    Returns:
        A ``uint8`` array of shape ``(H, W, 3)``. Each pixel holds the palette
        colour of its label.

    Raises:
        ValueError: if ``mask`` is not 2-D.
        KeyError: if ``mask`` contains a label id with no palette entry.
    """
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if mask.ndim != 2:
        raise ValueError(mask.shape)
    h, w = mask.shape
    grid = np.zeros((h, w, 3), dtype=np.uint8)
    for label in np.unique(mask):
        try:
            colour = _VOC2012_PALETTE[label]
        except KeyError:
            # A missing dict key raises KeyError, not IndexError; keep that
            # type for callers but give a descriptive message.
            raise KeyError(f"No colour is found for a label id: {label}") from None
        grid[mask == label] = colour
    return grid