# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains definitions of box sampler."""

# Import libraries
import tensorflow as tf, tf_keras

from official.vision.ops import sampling_ops


@tf_keras.utils.register_keras_serializable(package='Vision')
class BoxSampler(tf_keras.layers.Layer):
  """Creates a BoxSampler to sample positive and negative boxes."""

  def __init__(self,
               num_samples: int = 512,
               foreground_fraction: float = 0.25,
               **kwargs):
    """Initializes a box sampler.

    Args:
      num_samples: An `int` of the number of sampled boxes per image.
      foreground_fraction: A `float` in [0, 1], what percentage of boxes should
        be sampled from the positive examples.
      **kwargs: Additional keyword arguments passed to Layer.
    """
    # The config dict is the single source of truth for the hyperparameters;
    # `call` reads from it and `get_config` serializes it.
    self._config_dict = {
        'num_samples': num_samples,
        'foreground_fraction': foreground_fraction,
    }
    super().__init__(**kwargs)

  def call(self,
           positive_matches: tf.Tensor,
           negative_matches: tf.Tensor,
           ignored_matches: tf.Tensor) -> tf.Tensor:
    """Samples and selects positive and negative instances.

    The batch dimension of the inputs must be statically known: the per-image
    sampling is unrolled with a Python loop over `shape[0]`.

    Args:
      positive_matches: A `bool` tensor of shape of [batch, N] where N is the
        number of instances. For each element, `True` means the instance
        corresponds to a positive example.
      negative_matches: A `bool` tensor of shape of [batch, N] where N is the
        number of instances. For each element, `True` means the instance
        corresponds to a negative example.
      ignored_matches: A `bool` tensor of shape of [batch, N] where N is the
        number of instances. For each element, `True` means the instance should
        be ignored.

    Returns:
      A `tf.tensor` of shape of [batch_size, K], storing the indices of the
        sampled examples, where K is `num_samples`.
    """
    # An instance is eligible for sampling when it is labeled (positive or
    # negative) and is not marked as ignored.
    sample_candidates = tf.logical_and(
        tf.logical_or(positive_matches, negative_matches),
        tf.logical_not(ignored_matches))

    sampler = sampling_ops.BalancedPositiveNegativeSampler(
        positive_fraction=self._config_dict['foreground_fraction'],
        is_static=True)

    num_samples = self._config_dict['num_samples']
    batch_size = sample_candidates.shape[0]

    # Sample each image independently, then re-assemble the batch.
    sampled_indicators = tf.stack([
        sampler.subsample(
            sample_candidates[i], num_samples, positive_matches[i])
        for i in range(batch_size)
    ])

    # Turn the per-instance boolean indicators into indices: after casting to
    # 0/1, `top_k` ranks the selected (1) entries ahead of the unselected (0)
    # ones and returns their positions.
    _, selected_indices = tf.nn.top_k(
        tf.cast(sampled_indicators, dtype=tf.int32),
        k=num_samples,
        sorted=True)

    return selected_indices

  def get_config(self) -> dict:
    """Returns the constructor arguments needed to re-create this layer.

    A shallow copy is returned so that callers mutating the result cannot
    alter the parameters this layer samples with.
    """
    return dict(self._config_dict)

  @classmethod
  def from_config(cls, config):
    """Builds a `BoxSampler` from a dict produced by `get_config`."""
    return cls(**config)