# Usage as a simple differentiable segmenter base class

import os, torch, numpy, json, glob
import skimage.morphology
from collections import OrderedDict
from netdissect import upsegmodel
from netdissect import segmodel as segmodel_module
from netdissect.easydict import EasyDict
from urllib.request import urlretrieve

class BaseSegmenter:
    def get_label_and_category_names(self):
        '''
        Returns two lists: first, a list of tuples [(label, category), ...]
        where the label and category are human-readable strings indicating
        the meaning of a segmentation class. The 0th segmentation class
        should be reserved for a label ('-') that means "no prediction."
        The second list should just be a list of [category,...] listing
        all categories in a canonical order.
        '''
        raise NotImplementedError()

    def segment_batch(self, tensor_images, downsample=1):
        '''
        Returns a multilabel segmentation for the given batch of (RGB [-1...1])
        images. Each pixel of the result is a torch.long indicating a
        predicted class number. Multiple classes can be predicted for
        the same pixel: output shape is (n, multipred, y, x), where
        multipred is 3, 5, or 6, for how many different predicted labels can
        be given for each pixel (depending on whether subdivision is being
        used). If downsample is specified, then the output y and x dimensions
        are downsampled from the original image.
        '''
        raise NotImplementedError()

    def predict_single_class(self, tensor_images, classnum, downsample=1):
        '''
        Given a batch of images (RGB, normalized to [-1...1]) and
        a specific segmentation class number, returns a tuple with
        (1) a differentiable ([0..1]) prediction score for the class
        at every pixel of the input image.
        (2) a binary mask showing where in the input image the
        specified class is the best-predicted label for the pixel.
        Does not work on subdivided labels.
        '''
        raise NotImplementedError()
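
# The sketch below is an illustrative (hypothetical) subclass showing the
# minimal contract that BaseSegmenter implementations are expected to satisfy:
# a reserved 0 label, a differentiable per-pixel score, an argmax-style mask,
# and segment_batch output shaped (n, multipred, y, x). It is not part of the
# netdissect API.
class _ExampleThresholdSegmenter(BaseSegmenter):
    '''Toy segmenter that labels pixels brighter than mid-gray as class 1.'''
    def get_label_and_category_names(self):
        return [('-', '-'), ('bright', 'object')], ['object']

    def segment_batch(self, tensor_images, downsample=1):
        score, mask = self.predict_single_class(
                tensor_images, 1, downsample=downsample)
        return mask.long()[:, None]  # shape (n, 1, y, x), multipred == 1

    def predict_single_class(self, tensor_images, classnum, downsample=1):
        assert classnum == 1
        lum = tensor_images.mean(dim=1)  # images are RGB in [-1..1]
        if downsample > 1:
            lum = torch.nn.AvgPool2d(downsample)(lum[:, None])[:, 0]
        score = (lum + 1) / 2  # differentiable score in [0..1]
        return score, (score > 0.5)
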
class UnifiedParsingSegmenter(BaseSegmenter):
    '''
    This is a wrapper for a more complicated multi-class segmenter,
    as described in https://arxiv.org/pdf/1807.10221.pdf, and as
    released in https://github.com/CSAILVision/unifiedparsing.
    For our purposes and to simplify processing, we do not use
    whole-scene predictions, and we only consume part segmentations
    for the three largest object classes (sky, building, person).
    '''

    def __init__(self, segsizes=None, segdiv=None):
        # Create a segmentation model
        if segsizes is None:
            segsizes = [256]
        if segdiv is None:
            segdiv = 'undivided'
        segvocab = 'upp'
        segarch = ('resnet50', 'upernet')
        epoch = 40
        segmodel = load_unified_parsing_segmentation_model(
                segarch, segvocab, epoch)
        segmodel.cuda()
        self.segmodel = segmodel
        self.segsizes = segsizes
        self.segdiv = segdiv
        mult = 1
        if self.segdiv == 'quad':
            mult = 5
        self.divmult = mult
        # Assign class numbers for parts.
        first_partnumber = (
                (len(segmodel.labeldata['object']) - 1) * mult + 1 +
                (len(segmodel.labeldata['material']) - 1))
        # We only use parts for these three types of objects, for efficiency.
        partobjects = ['sky', 'building', 'person']
        partnumbers = {}
        partnames = []
        objectnumbers = {k: v
                for v, k in enumerate(segmodel.labeldata['object'])}
        part_index_translation = []
        # We merge some classes. For example "door" is both an object
        # and a part of a building. To avoid confusion, we just count
        # such classes as objects, and add part scores to the same index.
        for owner in partobjects:
            part_list = segmodel.labeldata['object_part'][owner]
            numeric_part_list = []
            for part in part_list:
                if part in objectnumbers:
                    numeric_part_list.append(objectnumbers[part])
                elif part in partnumbers:
                    numeric_part_list.append(partnumbers[part])
                else:
                    partnumbers[part] = len(partnames) + first_partnumber
                    partnames.append(part)
                    numeric_part_list.append(partnumbers[part])
            part_index_translation.append(torch.tensor(numeric_part_list))
        self.objects_with_parts = [objectnumbers[obj] for obj in partobjects]
        self.part_index = part_index_translation
        self.part_names = partnames
        # For now we'll just do object and material labels.
        self.num_classes = 1 + (
                len(segmodel.labeldata['object']) - 1) * mult + (
                len(segmodel.labeldata['material']) - 1) + len(partnames)
        self.num_object_classes = len(self.segmodel.labeldata['object']) - 1

    def get_label_and_category_names(self, dataset=None):
        '''
        Lists label and category names.
        '''
        # Labels are ordered as follows:
        # 0, [object labels] [divided object labels] [materials] [parts]
        # The zero label is reserved to mean 'no prediction'.
        if self.segdiv == 'quad':
            suffixes = ['t', 'l', 'b', 'r']
        else:
            suffixes = []
        divided_labels = []
        for suffix in suffixes:
            divided_labels.extend([('%s-%s' % (label, suffix), 'part')
                for label in self.segmodel.labeldata['object'][1:]])
        # Create the whole list of labels
        labelcats = (
                [(label, 'object')
                    for label in self.segmodel.labeldata['object']] +
                divided_labels +
                [(label, 'material')
                    for label in self.segmodel.labeldata['material'][1:]] +
                [(label, 'part') for label in self.part_names])
        return labelcats, ['object', 'part', 'material']

    def raw_seg_prediction(self, tensor_images, downsample=1):
        '''
        Generates a segmentation by applying multiresolution voting on
        the segmentation model, using a set of resolutions (each rounded
        to a multiple of 32 pixels) as in the example benchmark code.
        '''
        y, x = tensor_images.shape[2:]
        b = len(tensor_images)
        tensor_images = (tensor_images + 1) / 2 * 255
        tensor_images = torch.flip(tensor_images, (1,))  # RGB to BGR, as the model expects
        tensor_images -= torch.tensor([102.9801, 115.9465, 122.7717]).to(
                dtype=tensor_images.dtype, device=tensor_images.device
                )[None, :, None, None]
        seg_shape = (y // downsample, x // downsample)
        # We want these to be multiples of 32 for the model.
        sizes = [(s, s) for s in self.segsizes]
        pred = {category: torch.zeros(
            len(tensor_images), len(self.segmodel.labeldata[category]),
            seg_shape[0], seg_shape[1]).cuda()
            for category in ['object', 'material']}
        part_pred = {partobj_index: torch.zeros(
            len(tensor_images), len(partindex),
            seg_shape[0], seg_shape[1]).cuda()
            for partobj_index, partindex in enumerate(self.part_index)}
        for size in sizes:
            if size == tensor_images.shape[2:]:
                resized = tensor_images
            else:
                resized = torch.nn.AdaptiveAvgPool2d(size)(tensor_images)
            r_pred = self.segmodel(
                    dict(img=resized), seg_size=seg_shape)
            for k in pred:
                pred[k] += r_pred[k]
            for k in part_pred:
                part_pred[k] += r_pred['part'][k]
        return pred, part_pred

    def segment_batch(self, tensor_images, downsample=1):
        '''
        Returns a multilabel segmentation for the given batch of (RGB [-1...1])
        images. Each pixel of the result is a torch.long indicating a
        predicted class number. Multiple classes can be predicted for
        the same pixel: output shape is (n, multipred, y, x), where
        multipred is 3, 5, or 6, for how many different predicted labels can
        be given for each pixel (depending on whether subdivision is being
        used). If downsample is specified, then the output y and x dimensions
        are downsampled from the original image.
        '''
        pred, part_pred = self.raw_seg_prediction(tensor_images,
                downsample=downsample)
        piece_channels = 2 if self.segdiv == 'quad' else 0
        y, x = tensor_images.shape[2:]
        seg_shape = (y // downsample, x // downsample)
        segs = torch.zeros(len(tensor_images), 3 + piece_channels,
                seg_shape[0], seg_shape[1],
                dtype=torch.long, device=tensor_images.device)
        _, segs[:, 0] = torch.max(pred['object'], dim=1)
        # Get materials and translate to shared numbering scheme
        _, segs[:, 1] = torch.max(pred['material'], dim=1)
        maskout = (segs[:, 1] == 0)
        segs[:, 1] += (len(self.segmodel.labeldata['object']) - 1) * self.divmult
        segs[:, 1][maskout] = 0
        # Now deal with subparts of sky, buildings, people
        for i, object_index in enumerate(self.objects_with_parts):
            trans = self.part_index[i].to(segs.device)
            # Get the argmax, and then translate to shared numbering scheme
            seg = trans[torch.max(part_pred[i], dim=1)[1]]
            # Only trust the parts where the prediction also predicts the
            # owning object.
            mask = (segs[:, 0] == object_index)
            segs[:, 2][mask] = seg[mask]
        if self.segdiv == 'quad':
            segs = self.expand_segment_quad(segs, self.segdiv)
        return segs

    def predict_single_class(self, tensor_images, classnum, downsample=1):
        '''
        Given a batch of images (RGB, normalized to [-1...1]) and
        a specific segmentation class number, returns a tuple with
        (1) a differentiable ([0..1]) prediction score for the class
        at every pixel of the input image.
        (2) a binary mask showing where in the input image the
        specified class is the best-predicted label for the pixel.
        Does not work on subdivided labels.
        '''
        result = 0
        pred, part_pred = self.raw_seg_prediction(tensor_images,
                downsample=downsample)
        material_offset = (len(self.segmodel.labeldata['object']) - 1
                ) * self.divmult
        if material_offset < classnum < material_offset + len(
                self.segmodel.labeldata['material']):
            return (
                pred['material'][:, classnum - material_offset],
                pred['material'].max(dim=1)[1] == classnum - material_offset)
        mask = None
        if classnum < len(self.segmodel.labeldata['object']):
            result = pred['object'][:, classnum]
            mask = (pred['object'].max(dim=1)[1] == classnum)
        # Some objects, like 'door', are also a part of other objects,
        # so add the part prediction also.
        for i, object_index in enumerate(self.objects_with_parts):
            local_index = (self.part_index[i] == classnum).nonzero()
            if len(local_index) == 0:
                continue
            local_index = local_index.item()
            # Ignore part predictions outside the mask. (We could pay
            # attention to and penalize such predictions.)
            mask2 = (pred['object'].max(dim=1)[1] == object_index) * (
                    part_pred[i].max(dim=1)[1] == local_index)
            if mask is None:
                mask = mask2
            else:
                mask = torch.max(mask, mask2)
            result = result + (part_pred[i][:, local_index])
        # result is still the int 0 only if no class matched classnum.
        assert not isinstance(result, int), 'unrecognized class %d' % classnum
        return result, mask

    def expand_segment_quad(self, segs, segdiv='quad'):
        shape = segs.shape
        segs[:, 3:] = segs[:, 0:1]  # start by copying the object channel
        num_seg_labels = self.num_object_classes
        # For every connected component present (using generator)
        for i, mask in component_masks(segs[:, 0:1]):
            # Figure the bounding box of the label
            top, bottom = mask.any(dim=1).nonzero()[[0, -1], 0]
            left, right = mask.any(dim=0).nonzero()[[0, -1], 0]
            # Chop the bounding box into four parts
            vmid = (top + bottom + 1) // 2
            hmid = (left + right + 1) // 2
            # Construct top, left, bottom, right masks (matching the
            # ['t', 'l', 'b', 'r'] suffix order used for divided labels)
            quad_mask = mask[None, :, :].repeat(4, 1, 1)
            quad_mask[0, vmid:, :] = 0  # top half
            quad_mask[1, :, hmid:] = 0  # left half
            quad_mask[2, :vmid, :] = 0  # bottom half
            quad_mask[3, :, :hmid] = 0  # right half
            quad_mask = quad_mask.long()
            # Modify extra segmentation labels by offsetting
            segs[i, 3, :, :] += quad_mask[0] * num_seg_labels
            segs[i, 4, :, :] += quad_mask[1] * (2 * num_seg_labels)
            segs[i, 3, :, :] += quad_mask[2] * (3 * num_seg_labels)
            segs[i, 4, :, :] += quad_mask[3] * (4 * num_seg_labels)
        # remove any components that were too small to subdivide
        mask = segs[:, 3:] <= self.num_object_classes
        segs[:, 3:][mask] = 0
        return segs
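
# A minimal usage sketch (assuming a CUDA device and the pretrained weights
# under dataset/segmodel, as fetched by ensure_upp_segmenter_downloaded
# below): segment a batch and print the label names present in the result.
def _example_unified_parsing_usage(tensor_images):
    segmenter = UnifiedParsingSegmenter(segsizes=[256], segdiv='undivided')
    labels, categories = segmenter.get_label_and_category_names()
    segs = segmenter.segment_batch(tensor_images)  # (n, 3, y, x) torch.long
    for classnum in segs.unique().tolist():
        if classnum:  # label 0 is reserved for 'no prediction'
            name, category = labels[classnum]
            print('%s (%s)' % (name, category))
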
class SemanticSegmenter(BaseSegmenter):
    def __init__(self, modeldir=None, segarch=None, segvocab=None,
            segsizes=None, segdiv=None, epoch=None):
        # Create a segmentation model
        if modeldir is None:
            modeldir = 'dataset/segmodel'
        if segvocab is None:
            segvocab = 'baseline'
        if segarch is None:
            segarch = ('resnet50_dilated8', 'ppm_bilinear_deepsup')
        elif isinstance(segarch, str):
            segarch = segarch.split(',')
        if segdiv is None:
            segdiv = 'undivided'
        segmodel = load_segmentation_model(modeldir, segarch, segvocab, epoch)
        if segsizes is None:
            segsizes = getattr(segmodel.meta, 'segsizes', [256])
        self.segsizes = segsizes
        # Verify that the segmentation model has every out_channel labeled.
        assert len(segmodel.meta.labels) == list(c for c in segmodel.modules()
                if isinstance(c, torch.nn.Conv2d))[-1].out_channels
        segmodel.cuda()
        self.segmodel = segmodel
        self.segdiv = segdiv
        # Image normalization
        self.bgr = (segmodel.meta.imageformat.byteorder == 'BGR')
        self.imagemean = torch.tensor(segmodel.meta.imageformat.mean)
        self.imagestd = torch.tensor(segmodel.meta.imageformat.stdev)
        # Map from labels to external indexes, and labels to channel sets.
        self.labelmap = {'-': 0}
        self.channelmap = {'-': []}
        self.labels = [('-', '-')]
        num_labels = 1
        self.num_underlying_classes = len(segmodel.meta.labels)
        # labelmap maps names to external indexes.
        for i, label in enumerate(segmodel.meta.labels):
            if label.name not in self.channelmap:
                self.channelmap[label.name] = []
            self.channelmap[label.name].append(i)
            if getattr(label, 'internal', None) or label.name in self.labelmap:
                continue
            self.labelmap[label.name] = num_labels
            num_labels += 1
            self.labels.append((label.name, label.category))
        # Each category gets its own independent softmax.
        self.category_indexes = {category.name:
                [i for i, label in enumerate(segmodel.meta.labels)
                    if label.category == category.name]
                for category in segmodel.meta.categories}
        # catindexmap maps names to category internal indexes
        self.catindexmap = {}
        for catname, indexlist in self.category_indexes.items():
            for index, i in enumerate(indexlist):
                self.catindexmap[segmodel.meta.labels[i].name] = (
                        (catname, index))
        # After the softmax, each category is mapped to external indexes.
        self.category_map = {catname:
                torch.tensor([
                    self.labelmap.get(segmodel.meta.labels[ind].name, 0)
                    for ind in catindex])
                for catname, catindex in self.category_indexes.items()}
        self.category_rules = segmodel.meta.categories
        # Finally, naive subdivision can be applied.
        mult = 1
        if self.segdiv == 'quad':
            mult = 5
            suffixes = ['t', 'l', 'b', 'r']
            divided_labels = []
            for suffix in suffixes:
                divided_labels.extend([('%s-%s' % (label, suffix), cat)
                    for label, cat in self.labels[1:]])
                self.channelmap.update({
                    '%s-%s' % (label, suffix): self.channelmap[label]
                    for label, cat in self.labels[1:]})
            self.labels.extend(divided_labels)
        # For examining a single class
        self.channellist = [self.channelmap[name] for name, _ in self.labels]

    def get_label_and_category_names(self, dataset=None):
        return self.labels, self.segmodel.categories

    def segment_batch(self, tensor_images, downsample=1):
        return self.raw_segment_batch(tensor_images, downsample)[0]

    def raw_segment_batch(self, tensor_images, downsample=1):
        pred = self.raw_seg_prediction(tensor_images, downsample)
        catsegs = {}
        for catkey, catindex in self.category_indexes.items():
            _, segs = torch.max(pred[:, catindex], dim=1)
            catsegs[catkey] = segs
        masks = {}
        segs = torch.zeros(len(tensor_images), len(self.category_rules),
                pred.shape[2], pred.shape[3], device=pred.device,
                dtype=torch.long)
        for i, cat in enumerate(self.category_rules):
            catmap = self.category_map[cat.name].to(pred.device)
            translated = catmap[catsegs[cat.name]]
            if getattr(cat, 'mask', None) is not None:
                if cat.mask not in masks:
                    maskcat, maskind = self.catindexmap[cat.mask]
                    masks[cat.mask] = (catsegs[maskcat] == maskind)
                translated *= masks[cat.mask].long()
            segs[:, i] = translated
        if self.segdiv == 'quad':
            segs = self.expand_segment_quad(segs,
                    self.num_underlying_classes, self.segdiv)
        return segs, pred

    def raw_seg_prediction(self, tensor_images, downsample=1):
        '''
        Generates a segmentation by applying multiresolution voting on
        the segmentation model, using a set of resolutions (each rounded
        to a multiple of 32 pixels) as in the example benchmark code.
        '''
        y, x = tensor_images.shape[2:]
        b = len(tensor_images)
        # Flip the RGB order if specified.
        if self.bgr:
            tensor_images = torch.flip(tensor_images, (1,))
        # Transform from our [-1..1] range to torch standard [0..1] range
        # and then apply normalization.
        tensor_images = ((tensor_images + 1) / 2
                ).sub_(self.imagemean[None, :, None, None].to(tensor_images.device)
                ).div_(self.imagestd[None, :, None, None].to(tensor_images.device))
        # Output shape can be downsampled.
        seg_shape = (y // downsample, x // downsample)
        # We want these to be multiples of 32 for the model.
        sizes = [(s, s) for s in self.segsizes]
        pred = torch.zeros(
            len(tensor_images), (self.num_underlying_classes),
            seg_shape[0], seg_shape[1]).cuda()
        for size in sizes:
            if size == tensor_images.shape[2:]:
                resized = tensor_images
            else:
                resized = torch.nn.AdaptiveAvgPool2d(size)(tensor_images)
            raw_pred = self.segmodel(
                    dict(img_data=resized), segSize=seg_shape)
            softmax_pred = torch.empty_like(raw_pred)
            for catindex in self.category_indexes.values():
                softmax_pred[:, catindex] = torch.nn.functional.softmax(
                        raw_pred[:, catindex], dim=1)
            pred += softmax_pred
        return pred

    def expand_segment_quad(self, segs, num_seg_labels, segdiv='quad'):
        shape = segs.shape
        output = segs.repeat(1, 3, 1, 1)
        # For every connected component present (using generator)
        for i, mask in component_masks(segs):
            # Figure the bounding box of the label
            top, bottom = mask.any(dim=1).nonzero()[[0, -1], 0]
            left, right = mask.any(dim=0).nonzero()[[0, -1], 0]
            # Chop the bounding box into four parts
            vmid = (top + bottom + 1) // 2
            hmid = (left + right + 1) // 2
            # Construct top, left, bottom, right masks (matching the
            # ['t', 'l', 'b', 'r'] suffix order used for divided labels)
            quad_mask = mask[None, :, :].repeat(4, 1, 1)
            quad_mask[0, vmid:, :] = 0  # top half
            quad_mask[1, :, hmid:] = 0  # left half
            quad_mask[2, :vmid, :] = 0  # bottom half
            quad_mask[3, :, :hmid] = 0  # right half
            quad_mask = quad_mask.long()
            # Modify extra segmentation labels by offsetting
            output[i, 1, :, :] += quad_mask[0] * num_seg_labels
            output[i, 2, :, :] += quad_mask[1] * (2 * num_seg_labels)
            output[i, 1, :, :] += quad_mask[2] * (3 * num_seg_labels)
            output[i, 2, :, :] += quad_mask[3] * (4 * num_seg_labels)
        return output

    def predict_single_class(self, tensor_images, classnum, downsample=1):
        '''
        Given a batch of images (RGB, normalized to [-1...1]) and
        a specific segmentation class number, returns a tuple with
        (1) a differentiable ([0..1]) prediction score for the class
        at every pixel of the input image.
        (2) a binary mask showing where in the input image the
        specified class is the best-predicted label for the pixel.
        Does not work on subdivided labels.
        '''
        seg, pred = self.raw_segment_batch(tensor_images,
                downsample=downsample)
        result = pred[:, self.channellist[classnum]].sum(dim=1)
        mask = (seg == classnum).max(1)[0]
        return result, mask
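
# Because predict_single_class returns a differentiable score, it can be used
# as an objective. A minimal sketch (hypothetical helper, not part of the
# API; any segmenter defined above would work) that backpropagates the mean
# class score into the input images:
def _example_differentiable_score(segmenter, tensor_images, classnum):
    tensor_images = tensor_images.clone().requires_grad_(True)
    score, mask = segmenter.predict_single_class(tensor_images, classnum)
    score.mean().backward()  # gradients flow back to tensor_images
    return tensor_images.grad
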
def component_masks(segmentation_batch):
    '''
    Splits connected components into regions (slower, requires cpu).
    '''
    npbatch = segmentation_batch.cpu().numpy()
    for i in range(segmentation_batch.shape[0]):
        labeled, num = skimage.morphology.label(npbatch[i][0], return_num=True)
        labeled = torch.from_numpy(labeled).to(segmentation_batch.device)
        # skimage numbers components 1..num, so include the last one.
        for label in range(1, num + 1):
            yield i, (labeled == label)
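
# A small sketch of how component_masks is consumed (as in the
# expand_segment_quad methods above): iterate over each image's connected
# components and compute every component's bounding box.
def _example_component_bboxes(segmentation_batch):
    boxes = []
    for i, mask in component_masks(segmentation_batch):
        top, bottom = mask.any(dim=1).nonzero()[[0, -1], 0]
        left, right = mask.any(dim=0).nonzero()[[0, -1], 0]
        boxes.append((i, top.item(), left.item(),
                bottom.item(), right.item()))
    return boxes
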
def load_unified_parsing_segmentation_model(segmodel_arch, segvocab, epoch):
    segmodel_dir = 'dataset/segmodel/%s-%s-%s' % ((segvocab,) + segmodel_arch)
    # Load json of class names and part/object structure
    with open(os.path.join(segmodel_dir, 'labels.json')) as f:
        labeldata = json.load(f)
    nr_classes = {k: len(labeldata[k])
            for k in ['object', 'scene', 'material']}
    nr_classes['part'] = sum(len(p) for p in labeldata['object_part'].values())
    # Create a segmentation model
    segbuilder = upsegmodel.ModelBuilder()
    # example segmodel_arch = ('resnet101', 'upernet')
    seg_encoder = segbuilder.build_encoder(
            arch=segmodel_arch[0],
            fc_dim=2048,
            weights=os.path.join(segmodel_dir, 'encoder_epoch_%d.pth' % epoch))
    seg_decoder = segbuilder.build_decoder(
            arch=segmodel_arch[1],
            fc_dim=2048, use_softmax=True,
            nr_classes=nr_classes,
            weights=os.path.join(segmodel_dir, 'decoder_epoch_%d.pth' % epoch))
    segmodel = upsegmodel.SegmentationModule(
            seg_encoder, seg_decoder, labeldata)
    segmodel.categories = ['object', 'part', 'material']
    segmodel.eval()
    return segmodel
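
# A loading sketch using the same arguments that UnifiedParsingSegmenter's
# __init__ passes above (assumes the pretrained weights are present under
# dataset/segmodel/upp-resnet50-upernet):
def _example_load_upp_model():
    return load_unified_parsing_segmentation_model(
            ('resnet50', 'upernet'), 'upp', 40)
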
def load_segmentation_model(modeldir, segmodel_arch, segvocab, epoch=None):
    # Load json of class names
    segmodel_dir = 'dataset/segmodel/%s-%s-%s' % ((segvocab,) + segmodel_arch)
    with open(os.path.join(segmodel_dir, 'labels.json')) as f:
        labeldata = EasyDict(json.load(f))
    # Automatically pick the last epoch available.
    if epoch is None:
        choices = [os.path.basename(n)[14:-4] for n in
                glob.glob(os.path.join(segmodel_dir, 'encoder_epoch_*.pth'))]
        epoch = max([int(c) for c in choices if c.isdigit()])
    # Create a segmentation model
    segbuilder = segmodel_module.ModelBuilder()
    # example segmodel_arch = ('resnet101', 'upernet')
    seg_encoder = segbuilder.build_encoder(
            arch=segmodel_arch[0],
            fc_dim=2048,
            weights=os.path.join(segmodel_dir, 'encoder_epoch_%d.pth' % epoch))
    seg_decoder = segbuilder.build_decoder(
            arch=segmodel_arch[1],
            fc_dim=2048, inference=True, num_class=len(labeldata.labels),
            weights=os.path.join(segmodel_dir, 'decoder_epoch_%d.pth' % epoch))
    segmodel = segmodel_module.SegmentationModule(seg_encoder, seg_decoder,
            torch.nn.NLLLoss(ignore_index=-1))
    segmodel.categories = [cat.name for cat in labeldata.categories]
    segmodel.labels = [label.name for label in labeldata.labels]
    categories = OrderedDict()
    label_category = numpy.zeros(len(segmodel.labels), dtype=int)
    for i, label in enumerate(labeldata.labels):
        label_category[i] = segmodel.categories.index(label.category)
    segmodel.meta = labeldata
    segmodel.eval()
    return segmodel

def ensure_upp_segmenter_downloaded(directory):
    baseurl = 'http://netdissect.csail.mit.edu/data/segmodel'
    dirname = 'upp-resnet50-upernet'
    files = ['decoder_epoch_40.pth', 'encoder_epoch_40.pth', 'labels.json']
    download_dir = os.path.join(directory, dirname)
    os.makedirs(download_dir, exist_ok=True)
    for fn in files:
        if os.path.isfile(os.path.join(download_dir, fn)):
            continue  # Skip files already downloaded
        url = '%s/%s/%s' % (baseurl, dirname, fn)
        print('Downloading %s' % url)
        urlretrieve(url, os.path.join(download_dir, fn))
    assert os.path.isfile(os.path.join(directory, dirname, 'labels.json'))
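
# Typical setup sketch: fetch the pretrained weights, then construct the
# segmenter. UnifiedParsingSegmenter loads from dataset/segmodel, so the
# download directory below matches that default:
def _example_setup():
    ensure_upp_segmenter_downloaded('dataset/segmodel')
    return UnifiedParsingSegmenter()
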
def test_main():
    '''
    Test the unified segmenter.
    '''
    from PIL import Image
    testim = Image.open('script/testdata/test_church_242.jpg')
    tensor_im = (torch.from_numpy(numpy.asarray(testim)).permute(2, 0, 1)
            .float() / 255 * 2 - 1)[None, :, :, :].cuda()
    segmenter = UnifiedParsingSegmenter()
    seg = segmenter.segment_batch(tensor_im)
    bc = torch.bincount(seg.view(-1))
    labels, cats = segmenter.get_label_and_category_names()
    for label in bc.nonzero()[:, 0]:
        if label.item():
            # What is the prediction for this class?
            pred, mask = segmenter.predict_single_class(tensor_im, label.item())
            assert mask.sum().item() == bc[label].item()
            assert len(((seg == label).max(1)[0] - mask).nonzero()) == 0
            inside_pred = pred[mask].mean().item()
            outside_pred = pred[~mask].mean().item()
            print('%s (%s, #%d): %d pixels, pred %.2g inside %.2g outside' %
                    (labels[label.item()] + (label.item(), bc[label].item(),
                        inside_pred, outside_pred)))

if __name__ == '__main__':
    test_main()