deeplab2 / data /build_cityscapes_data_test.py
akhaliq3
spaces demo
506da10
# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for build_cityscapes_data."""
import os

from absl import flags
from absl.testing import flagsaver
import numpy as np
from PIL import Image
import tensorflow as tf

from deeplab2.data import build_cityscapes_data
# Handle to the global absl flag registry. The tests below override
# `treat_crowd_as_ignore`, presumably defined by `build_cityscapes_data`.
FLAGS = flags.FLAGS
# Directory holding the dummy (dataset_split='dummy') fixtures and the golden
# PNG `dummy_gt_for_vps.png`.
# NOTE(review): this is a relative path, so it is resolved against the current
# working directory; the test presumably must be launched from the directory
# that contains `deeplab2/` — confirm against the test runner setup.
_TEST_DATA_DIR = 'deeplab2/data/testdata'
# Key of the dummy sample in the dict returned by `_read_segments`.
_TEST_FILE_PREFIX = 'dummy_000000_000000'
class BuildCityscapesDataTest(tf.test.TestCase):
  """Tests the private helpers of `build_cityscapes_data` on dummy data."""

  def test_read_segments(self):
    """`_read_segments` finds the dummy sample and its 10 segments."""
    segments_dict = build_cityscapes_data._read_segments(
        _TEST_DATA_DIR, dataset_split='dummy')
    self.assertIn(_TEST_FILE_PREFIX, segments_dict)
    _, segments = segments_dict[_TEST_FILE_PREFIX]
    self.assertLen(segments, 10)

  # Test a more complicated setting (crowd regions are not ignored).
  # `flagsaver` restores the flag when the test finishes; assigning to FLAGS
  # directly would leak the override into every test that runs afterwards.
  @flagsaver.flagsaver(treat_crowd_as_ignore=False)
  def test_generate_panoptic_label(self):
    """The generated panoptic label matches the golden PNG."""
    cityscapes_root = _TEST_DATA_DIR
    segments_dict = build_cityscapes_data._read_segments(
        cityscapes_root, dataset_split='dummy')
    annotation_file_name, segments = segments_dict[_TEST_FILE_PREFIX]
    panoptic_annotation_file = build_cityscapes_data._get_panoptic_annotation(
        cityscapes_root, dataset_split='dummy',
        annotation_file_name=annotation_file_name)
    panoptic_label = build_cityscapes_data._generate_panoptic_label(
        panoptic_annotation_file, segments)

    # Check panoptic label matches golden file.
    golden_file_path = os.path.join(_TEST_DATA_DIR, 'dummy_gt_for_vps.png')
    with tf.io.gfile.GFile(golden_file_path, 'rb') as f:
      # PIL decodes lazily; materialize the pixels while `f` is still open.
      golden_rgb = np.asarray(Image.open(f))
    # The PNG file is encoded by:
    #   color = [segmentId % 256, segmentId // 256, segmentId // 256 // 256]
    # so it is decoded as segmentId = R + 256 * G + 256 * 256 * B. The dot
    # product requires the last axis to have exactly 3 channels (RGB).
    golden_label = np.dot(golden_rgb, [1, 256, 256 * 256])
    np.testing.assert_array_equal(panoptic_label, golden_label)
if __name__ == '__main__':
  # Run the test cases above with TensorFlow's test runner.
  tf.test.main()