deanna-emery's picture
updates
93528c6
raw
history blame
13.5 kB
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Preprocessing functions for images."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import List, Optional, Text, Tuple
import tensorflow as tf, tf_keras
from official.legacy.image_classification import augment
# Per-channel RGB statistics computed on the ImageNet training set, expressed
# on the [0, 255] pixel scale (each fraction is multiplied by 255).
MEAN_RGB = tuple(frac * 255 for frac in (0.485, 0.456, 0.406))
STDDEV_RGB = tuple(frac * 255 for frac in (0.229, 0.224, 0.225))
# Default square output resolution and center-crop padding, in pixels.
IMAGE_SIZE = 224
CROP_PADDING = 32
def mean_image_subtraction(
    image_bytes: tf.Tensor,
    means: Tuple[float, ...],
    num_channels: int = 3,
    dtype: tf.dtypes.DType = tf.float32,
) -> tf.Tensor:
  """Subtracts a per-channel mean from an image.

  Example:
    means = [123.68, 116.779, 103.939]
    image_bytes = mean_image_subtraction(image_bytes, means)

  The static rank of `image_bytes` must be known.

  Args:
    image_bytes: a rank-3 tensor of shape [height, width, C].
    means: C values, one subtracted from each channel.
    num_channels: the expected number of channels C; must equal `len(means)`.
    dtype: the dtype the mean tensor is cast to before subtraction. `None`
      skips the cast. `image_bytes` itself is never cast here.

  Returns:
    The mean-centered image.

  Raises:
    ValueError: if the static rank of `image_bytes` is unknown or not 3, or
      if `len(means)` differs from `num_channels`.
  """
  rank = image_bytes.get_shape().ndims
  if rank != 3:
    raise ValueError('Input must be of size [height, width, C>0]')
  if num_channels != len(means):
    raise ValueError('len(means) must match the number of channels')
  # Materialize the 1-D means at the full image shape up front.
  # Note(b/130245863): an explicit `broadcast` is used instead of expanding
  # dimensions because it performs better.
  mean_tensor = tf.broadcast_to(means, tf.shape(image_bytes))
  if dtype is not None:
    mean_tensor = tf.cast(mean_tensor, dtype=dtype)
  return image_bytes - mean_tensor
def standardize_image(
    image_bytes: tf.Tensor,
    stddev: Tuple[float, ...],
    num_channels: int = 3,
    dtype: tf.dtypes.DType = tf.float32,
) -> tf.Tensor:
  """Divides each channel of an image by a per-channel standard deviation.

  Example:
    stddev = [58.395, 57.12, 57.375]
    image_bytes = standardize_image(image_bytes, stddev)

  The static rank of `image_bytes` must be known.

  Args:
    image_bytes: a rank-3 tensor of shape [height, width, C].
    stddev: C divisors, one applied to each channel.
    num_channels: the expected number of channels C; must equal
      `len(stddev)`.
    dtype: the dtype the divisor tensor is cast to before dividing. `None`
      skips the cast. `image_bytes` itself is never cast here.

  Returns:
    The standardized image.

  Raises:
    ValueError: if the static rank of `image_bytes` is unknown or not 3, or
      if `len(stddev)` differs from `num_channels`.
  """
  if image_bytes.get_shape().ndims != 3:
    raise ValueError('Input must be of size [height, width, C>0]')
  if len(stddev) != num_channels:
    raise ValueError('len(stddev) must match the number of channels')
  # Note(b/130245863): broadcasting the 1-D stddev to the full image shape
  # explicitly is preferred over expanding its dimensions for performance.
  divisor = tf.broadcast_to(stddev, tf.shape(image_bytes))
  divisor = divisor if dtype is None else tf.cast(divisor, dtype=dtype)
  return image_bytes / divisor
def normalize_images(features: tf.Tensor,
                     mean_rgb: Tuple[float, ...] = MEAN_RGB,
                     stddev_rgb: Tuple[float, ...] = STDDEV_RGB,
                     num_channels: int = 3,
                     dtype: tf.dtypes.DType = tf.float32,
                     data_format: Text = 'channels_last') -> tf.Tensor:
  """Normalizes image channels with a per-channel mean and stddev.

  The statistics are shaped to broadcast along the channel axis selected by
  `data_format`. Any value other than 'channels_first' is treated as
  'channels_last'.

  Args:
    features: `Tensor` of decoded images in float format.
    mean_rgb: per-channel values to subtract, or `None` to skip subtraction.
    stddev_rgb: per-channel values to divide by, or `None` to skip division.
    num_channels: the number of channels in `features`.
    dtype: the dtype `features` is converted to (via
      `tf.image.convert_image_dtype`) before normalizing. `None` skips the
      conversion.
    data_format: the layout of `features`, either 'channels_first' or
      'channels_last'.

  Returns:
    The normalized image `Tensor`.
  """
  # TODO(allencwang) - figure out how to use mean_image_subtraction and
  # standardize_image on batches of images and replace the following.
  stats_shape = ([num_channels, 1, 1] if data_format == 'channels_first'
                 else [1, 1, num_channels])

  if dtype is not None:
    features = tf.image.convert_image_dtype(features, dtype=dtype)

  def _as_stats(values):
    # Builds a statistics tensor in the current dtype of `features` and
    # broadcasts it to the current full shape of `features`.
    stats = tf.constant(values, shape=stats_shape, dtype=features.dtype)
    return tf.broadcast_to(stats, tf.shape(features))

  if mean_rgb is not None:
    features = features - _as_stats(mean_rgb)
  if stddev_rgb is not None:
    features = features / _as_stats(stddev_rgb)
  return features
def decode_and_center_crop(image_bytes: tf.Tensor,
                           image_size: int = IMAGE_SIZE,
                           crop_padding: int = CROP_PADDING) -> tf.Tensor:
  """Takes a padded center crop of an image and resizes it.

  The crop is square, with side
  `int(image_size / (image_size + crop_padding) * min(height, width))`, and
  is centered in the image. It is then resized to `image_size` x
  `image_size` with `resize_image`.

  Args:
    image_bytes: either a JPEG-encoded string (decoded here with 3 channels)
      or an already-decoded image tensor (any non-string dtype).
    image_size: height/width of the returned image.
    crop_padding: padding used when computing the center crop size.

  Returns:
    The cropped and resized image `Tensor`.
  """
  is_decoded = image_bytes.dtype != tf.string
  if is_decoded:
    shape = tf.shape(image_bytes)
  else:
    shape = tf.image.extract_jpeg_shape(image_bytes)
  height, width = shape[0], shape[1]

  shorter_side = tf.cast(tf.minimum(height, width), tf.float32)
  crop_size = tf.cast(
      (image_size / (image_size + crop_padding)) * shorter_side, tf.int32)
  top = ((height - crop_size) + 1) // 2
  left = ((width - crop_size) + 1) // 2

  if is_decoded:
    cropped = tf.image.crop_to_bounding_box(
        image_bytes,
        offset_height=top,
        offset_width=left,
        target_height=crop_size,
        target_width=crop_size)
  else:
    crop_window = tf.stack([top, left, crop_size, crop_size])
    cropped = tf.image.decode_and_crop_jpeg(
        image_bytes, crop_window, channels=3)
  return resize_image(image_bytes=cropped, height=image_size, width=image_size)
def decode_crop_and_flip(image_bytes: tf.Tensor) -> tf.Tensor:
  """Randomly crops an image, then randomly flips it left-right.

  The crop window comes from `tf.image.sample_distorted_bounding_box`. The
  whole image is used as the object box. The window must cover at least 10%
  of that box, have an aspect ratio in [0.75, 1.33], and cover 5%-100% of
  the image area. Up to 100 attempts are made.

  Args:
    image_bytes: either a JPEG-encoded string (decoded here with 3 channels)
      or an already-decoded image tensor (any non-string dtype).

  Returns:
    The cropped (and possibly flipped) image `Tensor`.
  """
  is_decoded = image_bytes.dtype != tf.string
  if is_decoded:
    image_shape = tf.shape(image_bytes)
  else:
    image_shape = tf.image.extract_jpeg_shape(image_bytes)
  # One box spanning the full image, in normalized [ymin, xmin, ymax, xmax].
  whole_image_box = tf.constant(
      [0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
  begin, size, _ = tf.image.sample_distorted_bounding_box(
      image_shape,
      bounding_boxes=whole_image_box,
      min_object_covered=0.1,
      aspect_ratio_range=[0.75, 1.33],
      area_range=[0.05, 1.0],
      max_attempts=100,
      use_image_if_no_bounding_boxes=True)
  top, left, _ = tf.unstack(begin)
  crop_height, crop_width, _ = tf.unstack(size)

  if is_decoded:
    cropped = tf.image.crop_to_bounding_box(
        image_bytes,
        offset_height=top,
        offset_width=left,
        target_height=crop_height,
        target_width=crop_width)
  else:
    # decode_and_crop_jpeg expects [offset_h, offset_w, height, width].
    window = tf.stack([top, left, crop_height, crop_width])
    cropped = tf.image.decode_and_crop_jpeg(image_bytes, window, channels=3)
  # A random horizontal flip adds a little more distortion.
  return tf.image.random_flip_left_right(cropped)
def resize_image(image_bytes: tf.Tensor,
                 height: int = IMAGE_SIZE,
                 width: int = IMAGE_SIZE) -> tf.Tensor:
  """Bilinearly resizes an image to a given height and width.

  Uses `tf.compat.v1.image.resize` with `align_corners=False`.

  Args:
    image_bytes: a decoded image `Tensor`. Despite the name, this is not an
      encoded string; callers in this module pass already-decoded images.
    height: target height in pixels.
    width: target width in pixels.

  Returns:
    The resized image tensor. See `tf.compat.v1.image.resize` for the
    output dtype rules.
  """
  return tf.compat.v1.image.resize(
      image_bytes,
      tf.convert_to_tensor([height, width]),
      method=tf.image.ResizeMethod.BILINEAR,
      align_corners=False)
def preprocess_for_eval(
    image_bytes: tf.Tensor,
    image_size: int = IMAGE_SIZE,
    num_channels: int = 3,
    mean_subtract: bool = False,
    standardize: bool = False,
    dtype: tf.dtypes.DType = tf.float32
) -> tf.Tensor:
  """Preprocesses the given image for evaluation.

  The steps are:
    1. Take a padded center crop, resized to `image_size` x `image_size`.
    2. Optionally mean-subtract and/or standardize with the module's
       ImageNet statistics.
    3. Convert to `dtype`.

  Args:
    image_bytes: `Tensor` holding either an encoded JPEG string or an
      already-decoded image.
    image_size: image height/width dimension.
    num_channels: number of image input channels.
    mean_subtract: whether or not to apply mean subtraction.
    standardize: whether or not to apply standardization.
    dtype: the dtype to convert the images to. Set to `None` to skip conversion.

  Returns:
    A preprocessed and normalized image `Tensor`.

  Raises:
    ValueError: may be raised by the normalization helpers when
      `mean_subtract` or `standardize` is set and `num_channels` is not 3
      (the length of `MEAN_RGB` / `STDDEV_RGB`).
  """
  images = decode_and_center_crop(image_bytes, image_size)
  # Give the result a fully static [image_size, image_size, num_channels]
  # shape.
  images = tf.reshape(images, [image_size, image_size, num_channels])
  # Forward `num_channels` so that a mismatch with the 3-channel statistics
  # hits the helpers' explicit ValueError, not an opaque broadcast failure.
  if mean_subtract:
    images = mean_image_subtraction(
        image_bytes=images, means=MEAN_RGB, num_channels=num_channels)
  if standardize:
    images = standardize_image(
        image_bytes=images, stddev=STDDEV_RGB, num_channels=num_channels)
  if dtype is not None:
    images = tf.image.convert_image_dtype(images, dtype=dtype)
  return images
def load_eval_image(filename: Text, image_size: int = IMAGE_SIZE) -> tf.Tensor:
  """Reads an image file and runs the evaluation preprocessing on it.

  Args:
    filename: path of the image file to read.
    image_size: height/width of the returned square image.

  Returns:
    The image produced by `preprocess_for_eval` with its remaining options
    left at their defaults.
  """
  return preprocess_for_eval(tf.io.read_file(filename), image_size)
def build_eval_dataset(filenames: List[Text],
                       labels: Optional[List[int]] = None,
                       image_size: int = IMAGE_SIZE,
                       batch_size: int = 1) -> tf.data.Dataset:
  """Builds a batched evaluation `tf.data.Dataset` from image files.

  Each unbatched element is `(image, label)`, where `image` is
  `load_eval_image(filename, image_size)`.

  Args:
    filenames: a list of filename paths of images.
    labels: a list of labels, one per filename. When `None`, every label is
      0.
    image_size: image height/width dimension.
    batch_size: the batch size used by the dataset.

  Returns:
    A `tf.data.Dataset` yielding `(images, labels)` batches.
  """
  if labels is None:
    labels = [0] * len(filenames)
  # Pin the string dtype. Otherwise an empty `filenames` list becomes a
  # float32 tensor and `tf.io.read_file` fails while `map` traces the
  # preprocessing function.
  filename_tensor = tf.constant(filenames, dtype=tf.string)
  label_tensor = tf.constant(labels)
  dataset = tf.data.Dataset.from_tensor_slices((filename_tensor, label_tensor))
  dataset = dataset.map(
      lambda filename, label: (load_eval_image(filename, image_size), label))
  return dataset.batch(batch_size)
def preprocess_for_train(image_bytes: tf.Tensor,
                         image_size: int = IMAGE_SIZE,
                         augmenter: Optional[augment.ImageAugment] = None,
                         mean_subtract: bool = False,
                         standardize: bool = False,
                         dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor:
  """Preprocesses the given image for training.

  The steps are:
    1. Random crop plus random left-right flip.
    2. Resize to `image_size` x `image_size`.
    3. Optional augmentation with `augmenter.distort`.
    4. Optional mean subtraction and/or standardization with the module's
       ImageNet statistics.
    5. Final conversion to `dtype`.

  Args:
    image_bytes: `Tensor` holding either an encoded JPEG string or an
      already-decoded image.
    image_size: image height/width dimension.
    augmenter: the image augmenter to apply, or `None` for no augmentation.
    mean_subtract: whether or not to apply mean subtraction.
    standardize: whether or not to apply standardization.
    dtype: the dtype to convert the images to. Set to `None` to skip conversion.

  Returns:
    A preprocessed and normalized image `Tensor`.
  """
  cropped = decode_crop_and_flip(image_bytes=image_bytes)
  image = resize_image(cropped, height=image_size, width=image_size)
  if augmenter is not None:
    image = augmenter.distort(image)
  if mean_subtract:
    image = mean_image_subtraction(image_bytes=image, means=MEAN_RGB)
  if standardize:
    image = standardize_image(image_bytes=image, stddev=STDDEV_RGB)
  if dtype is None:
    return image
  return tf.image.convert_image_dtype(image, dtype)