# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utils used to manipulate tensor shapes."""

import tensorflow as tf, tf_keras


def assert_shape_equal(shape_a, shape_b):
  """Asserts that shape_a and shape_b are equal.

  If the shapes are static, raises a ValueError when the shapes
  mismatch.

  If the shapes are dynamic, raises a tf InvalidArgumentError when the shapes
  mismatch.

  Args:
    shape_a: a list containing shape of the first tensor.
    shape_b: a list containing shape of the second tensor.

  Returns:
    Either a tf.no_op() when shapes are all static or a tf.assert_equal() op
    when the shapes are dynamic.

  Raises:
    ValueError: When shapes are both static and unequal.
  """
  if (all(isinstance(dim, int) for dim in shape_a) and
      all(isinstance(dim, int) for dim in shape_b)):
    if shape_a != shape_b:
      raise ValueError('Unequal shapes {}, {}'.format(shape_a, shape_b))
    else:
      return tf.no_op()
  else:
    return tf.assert_equal(shape_a, shape_b)
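

# Example usage (an illustrative sketch, not part of the original module):
# fully static shapes are checked eagerly at call time, while any dynamic
# dimension defers the check to a tf.assert_equal op evaluated at run time.
#
#   assert_shape_equal([2, 3], [2, 3])  # returns tf.no_op()
#   assert_shape_equal([2, 3], [2, 4])  # raises ValueError
#   dynamic = tf.shape(tf.ones([2, 3]))[0]
#   assert_shape_equal([dynamic, 3], [2, 3])  # checked at run time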


def combined_static_and_dynamic_shape(tensor):
  """Returns a list containing static and dynamic values for the dimensions.

  Returns a list of static and dynamic values for shape dimensions. This is
  useful to preserve static shapes when available in a reshape operation.

  Args:
    tensor: A tensor of any type.

  Returns:
    A list of size tensor.shape.ndims containing integers or scalar tensors.
  """
  static_tensor_shape = tensor.shape.as_list()
  dynamic_tensor_shape = tf.shape(input=tensor)
  combined_shape = []
  for index, dim in enumerate(static_tensor_shape):
    if dim is not None:
      combined_shape.append(dim)
    else:
      combined_shape.append(dynamic_tensor_shape[index])
  return combined_shape
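

# Example usage (an illustrative sketch, not part of the original module):
# inside a tf.function with an unknown batch dimension, the helper mixes the
# dynamic batch size in with the known static sizes, so tf.reshape keeps the
# static dimensions instead of discarding them.
#
#   @tf.function(input_signature=[tf.TensorSpec([None, 4, 8], tf.float32)])
#   def flatten_last_two(t):
#     shape = combined_static_and_dynamic_shape(t)  # [<scalar tensor>, 4, 8]
#     return tf.reshape(t, [shape[0], shape[1] * shape[2]])  # shape [None, 32]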


def pad_or_clip_nd(tensor, output_shape):
  """Pads or clips the given tensor to the output shape.

  Args:
    tensor: Input tensor to pad or clip.
    output_shape: A list of integers / scalar tensors (or None for dynamic dim)
      representing the size to pad or clip each dimension of the input tensor.

  Returns:
    Input tensor padded and clipped to the output shape.
  """
  tensor_shape = tf.shape(input=tensor)
  clip_size = [
      tf.where(tensor_shape[i] - shape > 0, shape, -1)
      if shape is not None else -1 for i, shape in enumerate(output_shape)
  ]
  clipped_tensor = tf.slice(
      tensor, begin=tf.zeros(len(clip_size), dtype=tf.int32), size=clip_size)

  # Pad tensor if the shape of clipped tensor is smaller than the expected
  # shape.
  clipped_tensor_shape = tf.shape(input=clipped_tensor)
  trailing_paddings = [
      shape - clipped_tensor_shape[i] if shape is not None else 0
      for i, shape in enumerate(output_shape)
  ]
  paddings = tf.stack(
      [tf.zeros(len(trailing_paddings), dtype=tf.int32), trailing_paddings],
      axis=1)
  padded_tensor = tf.pad(tensor=clipped_tensor, paddings=paddings)
  output_static_shape = [
      dim if not isinstance(dim, tf.Tensor) else None for dim in output_shape
  ]
  padded_tensor.set_shape(output_static_shape)
  return padded_tensor
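

if __name__ == '__main__':
  # Minimal demo (not part of the original module; assumes TF2 eager mode):
  # a [2, 5] tensor is clipped to 3 columns and zero-padded to 4 rows.
  t = tf.reshape(tf.range(10), [2, 5])
  out = pad_or_clip_nd(t, [4, 3])
  print(out.shape)  # (4, 3): rows 2-3 are zeros; columns 3-4 of `t` dropped.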