|
"""Modified from https://github.com/CSAILVision/semantic-segmentation-pytorch""" |
|
|
|
import os |
|
import sys |
|
|
|
import numpy as np |
|
import torch |
|
|
|
try: |
|
from urllib import urlretrieve |
|
except ImportError: |
|
from urllib.request import urlretrieve |
|
|
|
|
|
def load_url(url, model_dir='./pretrained', map_location=None):
    """Download a serialized checkpoint once and load it with ``torch.load``.

    The file is cached in ``model_dir`` under the last ``/``-separated
    component of ``url``; later calls with the same URL and directory reuse
    the cached copy instead of downloading again.

    Args:
        url: Location of the checkpoint to fetch.
        model_dir: Directory used as the download cache. It is created
            (including parents) when it does not exist yet.
        map_location: Forwarded unchanged to ``torch.load``.

    Returns:
        Whatever ``torch.load`` returns for the cached file.

    Raises:
        OSError: If ``model_dir`` cannot be created, or the download cannot
            be moved into place.
    """
    try:
        os.makedirs(model_dir)
    except OSError:
        # Either the directory already exists or another process created it
        # between our attempt and now; only a genuinely missing directory
        # is an error.
        if not os.path.isdir(model_dir):
            raise

    filename = url.split('/')[-1]
    cached_file = os.path.join(model_dir, filename)

    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        # Download next to the final path and rename afterwards so that an
        # interrupted transfer never leaves a truncated file that later calls
        # would mistake for a valid cache entry.
        partial_file = cached_file + '.part'
        try:
            urlretrieve(url, partial_file)
            os.rename(partial_file, cached_file)
        finally:
            # After a successful rename there is nothing left to remove;
            # after a failure this discards the incomplete download.
            if os.path.exists(partial_file):
                os.remove(partial_file)

    # NOTE(security): torch.load deserializes with pickle, which can execute
    # arbitrary code from the file. Only pass URLs from sources you trust.
    return torch.load(cached_file, map_location=map_location)
|
|
|
|
|
def color_encode(labelmap, colors, mode='RGB'):
    """Turn a 2-D map of integer class labels into a colour image.

    Args:
        labelmap: Array of class labels; it is cast to ``int`` and its first
            two dimensions give the output height and width.
        colors: Palette indexed by label, where ``colors[label]`` yields the
            three channel values for that class. Any integer-like sequence or
            array works; values are stored as ``uint8``.
        mode: ``'BGR'`` returns the channels in reversed order; any other
            value returns them in palette order.

    Returns:
        A ``uint8`` array of shape ``(H, W, 3)``. Pixels with a negative
        label are left at zero. For ``mode='BGR'`` the result is a
        reversed-channel view of the RGB buffer, not a copy.

    Raises:
        IndexError: If a non-negative label has no entry in ``colors``.
    """
    labelmap = labelmap.astype('int')
    labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3),
                            dtype=np.uint8)

    for label in np.unique(labelmap):
        if label < 0:
            continue
        # Each pixel carries exactly one label, so writing the colour through
        # a boolean mask equals the additive blend while avoiding a full
        # H x W x 3 temporary per label and the uint8 in-place casting error
        # that non-uint8 palettes used to trigger.
        labelmap_rgb[labelmap == label] = colors[label]

    if mode == 'BGR':
        return labelmap_rgb[:, :, ::-1]
    return labelmap_rgb
|
|