File size: 4,117 Bytes
9a393e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Records previous preprocessing operations and allows them to be repeated.

Used with object_detection.core.preprocessor. Passing a PreprocessorCache
into individual data augmentation functions or the general preprocess() function
will store all randomly generated variables in the PreprocessorCache. When
a preprocessor function is called multiple times with the same
PreprocessorCache object, that function will perform the same augmentation
on all calls.
"""

from collections import defaultdict


class PreprocessorCache(object):
  """Dictionary wrapper storing random variables generated during preprocessing.
  """

  # Constant keys representing different preprocessing functions
  ROTATION90 = 'rotation90'
  HORIZONTAL_FLIP = 'horizontal_flip'
  VERTICAL_FLIP = 'vertical_flip'
  PIXEL_VALUE_SCALE = 'pixel_value_scale'
  IMAGE_SCALE = 'image_scale'
  RGB_TO_GRAY = 'rgb_to_gray'
  ADJUST_BRIGHTNESS = 'adjust_brightness'
  ADJUST_CONTRAST = 'adjust_contrast'
  ADJUST_HUE = 'adjust_hue'
  ADJUST_SATURATION = 'adjust_saturation'
  DISTORT_COLOR = 'distort_color'
  STRICT_CROP_IMAGE = 'strict_crop_image'
  CROP_IMAGE = 'crop_image'
  PAD_IMAGE = 'pad_image'
  CROP_TO_ASPECT_RATIO = 'crop_to_aspect_ratio'
  RESIZE_METHOD = 'resize_method'
  PAD_TO_ASPECT_RATIO = 'pad_to_aspect_ratio'
  BLACK_PATCHES = 'black_patches'
  ADD_BLACK_PATCH = 'add_black_patch'
  SELECTOR = 'selector'
  SELECTOR_TUPLES = 'selector_tuples'
  SELF_CONCAT_IMAGE = 'self_concat_image'
  SSD_CROP_SELECTOR_ID = 'ssd_crop_selector_id'
  SSD_CROP_PAD_SELECTOR_ID = 'ssd_crop_pad_selector_id'

  # 23 permitted function ids
  _VALID_FNS = [ROTATION90, HORIZONTAL_FLIP, VERTICAL_FLIP, PIXEL_VALUE_SCALE,
                IMAGE_SCALE, RGB_TO_GRAY, ADJUST_BRIGHTNESS, ADJUST_CONTRAST,
                ADJUST_HUE, ADJUST_SATURATION, DISTORT_COLOR, STRICT_CROP_IMAGE,
                CROP_IMAGE, PAD_IMAGE, CROP_TO_ASPECT_RATIO, RESIZE_METHOD,
                PAD_TO_ASPECT_RATIO, BLACK_PATCHES, ADD_BLACK_PATCH, SELECTOR,
                SELECTOR_TUPLES, SELF_CONCAT_IMAGE, SSD_CROP_SELECTOR_ID,
                SSD_CROP_PAD_SELECTOR_ID]

  def __init__(self):
    self._history = defaultdict(dict)

  def clear(self):
    """Resets cache."""
    self._history = defaultdict(dict)

  def get(self, function_id, key):
    """Gets stored value given a function id and key.

    Args:
      function_id: identifier for the preprocessing function used.
      key: identifier for the variable stored.
    Returns:
      value: the corresponding value, expected to be a tensor or
             nested structure of tensors.
    Raises:
      ValueError: if function_id is not one of the 23 valid function ids.
    """
    if function_id not in self._VALID_FNS:
      raise ValueError('Function id not recognized: %s.' % str(function_id))
    return self._history[function_id].get(key)

  def update(self, function_id, key, value):
    """Adds a value to the dictionary.

    Args:
      function_id: identifier for the preprocessing function used.
      key: identifier for the variable stored.
      value: the value to store, expected to be a tensor or nested structure
             of tensors.
    Raises:
      ValueError: if function_id is not one of the 23 valid function ids.
    """
    if function_id not in self._VALID_FNS:
      raise ValueError('Function id not recognized: %s.' % str(function_id))
    self._history[function_id][key] = value