File size: 3,448 Bytes
0b8359d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Operations for image patches."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v1 as tf


def get_patch_mask(y, x, patch_size, image_shape):
  """Builds an image-sized boolean mask that is True inside a square patch.

  The patch is centered on (y, x) and spans the rows
  [y - floor(patch_size / 2), y + ceil(patch_size / 2)) together with the
  analogous column range. For an even patch_size this places the extra
  row/column on the lower-valued side (top and left) of the center. Any part
  of the patch that would fall outside the image is clipped away, so only the
  center itself is required to lie inside the image.

  Args:
    y: An integer or scalar int32 tensor giving the row of the patch center.
      Must be in the range [0, image_height).
    x: An integer or scalar int32 tensor giving the column of the patch
      center. Must be in the range [0, image_width).
    patch_size: An integer or scalar int32 tensor giving the side length of
      the square patch. Must be at least 1.
    image_shape: A list or 1D int32 tensor whose first two entries are the
      image height and width, e.g. [image_height, image_width] or
      [image_height, image_width, image_channels]. Only those first two
      entries are used.

  Returns:
    A boolean tensor of shape [image_height, image_width] that is True for
    pixels covered by the (clipped) patch and False everywhere else.

  Raises:
    tf.errors.InvalidArgumentError: if patch_size is less than 1, if y is
      outside [0, image_height), or if x is outside [0, image_width).
  """
  height_width = image_shape[:2]
  center_yx = tf.stack([y, x])

  # Input validation. The center is passed through tf.identity under these
  # control dependencies so that the ops built from it below depend on the
  # checks.
  input_checks = [
      tf.debugging.assert_greater_equal(
          patch_size, 1,
          message='Patch size must be >= 1'),
      tf.debugging.assert_greater_equal(
          center_yx, 0,
          message='Patch center (y, x) must be >= (0, 0)'),
      tf.debugging.assert_less(
          center_yx, height_width,
          message='Patch center (y, x) must be < image (h, w)'),
  ]
  with tf.control_dependencies(input_checks):
    center_yx = tf.identity(center_yx)

  # Split the patch size into the extent before the center (rounded down) and
  # the extent from the center onward (rounded up).
  half_size = tf.cast(patch_size, dtype=tf.float32) / 2
  extent_before = tf.cast(tf.floor(half_size), dtype=tf.int32)
  extent_after = tf.cast(tf.ceil(half_size), dtype=tf.int32)

  # Patch corners as [begin, end) coordinates, clipped to the image bounds.
  top_left = tf.maximum(center_yx - extent_before, 0)
  bottom_right = tf.minimum(center_yx + extent_after, height_width)

  top, left = top_left[0], top_left[1]
  bottom, right = bottom_right[0], bottom_right[1]

  # Create an all-True block for the clipped patch, then pad it (with False)
  # on every side so that the result has the full image height and width.
  patch = tf.ones([bottom - top, right - left], dtype=tf.bool)
  paddings = [
      [top, height_width[0] - bottom],
      [left, height_width[1] - right],
  ]
  return tf.pad(patch, paddings)