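# NOTE: page-header residue from the hosting site's file view (not source
# code): avatar "deanna-emery", commit "updates" (5672777), links "raw" /
# "history blame", file size 3.62 kB.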
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils used to manipulate tensor shapes."""
import tensorflow as tf, tf_keras
def assert_shape_equal(shape_a, shape_b):
  """Asserts that shape_a and shape_b are equal.

  If the shapes are static (every dimension is a Python int), raises a
  ValueError when the shapes mismatch.

  If the shapes are dynamic, the comparison is delegated to tf.assert_equal,
  which reports a mismatch as a tf InvalidArgumentError.

  Args:
    shape_a: a list (or tuple) containing shape of the first tensor.
    shape_b: a list (or tuple) containing shape of the second tensor.

  Returns:
    A tf.no_op() when the shapes are all static and equal, otherwise the
    result of tf.assert_equal(shape_a, shape_b).

  Raises:
    ValueError: When shapes are both static and unequal.
  """
  shapes_are_static = (
      all(isinstance(dim, int) for dim in shape_a) and
      all(isinstance(dim, int) for dim in shape_b))
  if not shapes_are_static:
    return tf.assert_equal(shape_a, shape_b)

  # Compare as lists so that a list and a tuple holding the same dimensions
  # (e.g. [2, 3] and (2, 3)) are not reported as unequal.
  if list(shape_a) != list(shape_b):
    raise ValueError('Unequal shapes {}, {}'.format(shape_a, shape_b))
  return tf.no_op()
def combined_static_and_dynamic_shape(tensor):
  """Builds a per-dimension shape list, preferring statically known sizes.

  Each entry is the static size of that dimension when it is known, and the
  corresponding element of the dynamic shape otherwise. This is useful to
  preserve static shapes when available in reshape operation.

  Args:
    tensor: A tensor of any type.

  Returns:
    A list with one entry per dimension of the tensor's static shape; each
    entry is either an integer or a scalar tensor.
  """
  known_dims = tensor.shape.as_list()
  runtime_dims = tf.shape(input=tensor)
  # Fall back to the runtime value only where the static size is unknown.
  return [
      runtime_dims[axis] if size is None else size
      for axis, size in enumerate(known_dims)
  ]
def pad_or_clip_nd(tensor, output_shape):
  """Pads or clips the given tensor so that it matches the output shape.

  Args:
    tensor: Input tensor to pad or clip.
    output_shape: A list of integers / scalar tensors (or None for dynamic dim)
      representing the size to pad or clip each dimension of the input tensor.

  Returns:
    Input tensor padded and clipped to the output shape.
  """
  input_dims = tf.shape(input=tensor)

  # Step 1: clip. A slice size of -1 keeps every remaining element along that
  # dimension; it is used when no target is given or when the input is not
  # larger than the target.
  slice_sizes = []
  for axis, target in enumerate(output_shape):
    if target is None:
      slice_sizes.append(-1)
    else:
      slice_sizes.append(tf.where(input_dims[axis] - target > 0, target, -1))
  clipped = tf.slice(
      tensor, begin=tf.zeros(len(slice_sizes), dtype=tf.int32),
      size=slice_sizes)

  # Step 2: pad. Append trailing padding wherever the clipped tensor is still
  # smaller than the requested size; dimensions without a target get none.
  clipped_dims = tf.shape(input=clipped)
  pad_after = []
  for axis, target in enumerate(output_shape):
    if target is None:
      pad_after.append(0)
    else:
      pad_after.append(target - clipped_dims[axis])
  # One (before, after) pair per dimension; nothing is padded at the front.
  pad_pairs = tf.stack(
      [tf.zeros(len(pad_after), dtype=tf.int32), pad_after], axis=1)
  result = tf.pad(tensor=clipped, paddings=pad_pairs)

  # Step 3: record the static shape. Entries given as tensors are only known
  # at run time, so they are left as None.
  static_dims = [
      None if isinstance(dim, tf.Tensor) else dim for dim in output_shape
  ]
  result.set_shape(static_dims)
  return result