tf.ensure_shape


Updates the shape of a tensor and checks at runtime that the shape holds.

For example:

>>> import tensorflow as tf
>>> @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
... def f(tensor):
...   return tf.ensure_shape(tensor, [3, 3])
...
>>> f(tf.zeros([3, 3]))  # Passes
<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32)>
>>> f([1, 2, 3])  # Fails
Traceback (most recent call last):
...
InvalidArgumentError:  Shape of tensor x [3] is not compatible with expected shape [3,3].

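The above example raises tf.errors.InvalidArgumentError because the input shape, (3,), is not compatible with the expected shape, (3, 3). Since the input_signature leaves the shape unknown at trace time, the mismatch is caught at run time.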
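The compatibility rule behind this error can be sketched in plain Python. This is a simplified model, not TensorFlow's implementation: a `None` shape stands for unknown rank, a `None` entry stands for an unknown dimension, and the hypothetical helper `is_compatible` roughly mirrors what `tf.TensorShape.is_compatible_with` decides.

```python
from typing import Optional, Sequence

# Illustrative model only (not TensorFlow code): a shape is None (unknown rank)
# or a sequence of dims, where a dim of None means "size unknown".
Shape = Optional[Sequence[Optional[int]]]


def is_compatible(actual: Shape, expected: Shape) -> bool:
    """Hypothetical helper: True if `actual` could satisfy `expected`."""
    if actual is None or expected is None:
        return True  # unknown rank is compatible with any shape
    if len(actual) != len(expected):
        return False  # ranks must match
    # Each pair of dims must agree unless one of them is unknown.
    return all(a is None or e is None or a == e for a, e in zip(actual, expected))


print(is_compatible((3,), (3, 3)))       # False: rank 1 vs rank 2
print(is_compatible((3, 3), (3, 3)))     # True
print(is_compatible((2, 3), (None, 3)))  # True: None matches any size
```

The first call reproduces the failure above: a rank-1 shape can never satisfy a rank-2 expectation.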

In eager execution, this is a shape assertion that returns the input unchanged:

>>> x = tf.constant([1, 2, 3])
>>> print(x.shape)
(3,)
>>> x = tf.ensure_shape(x, [3])
>>> x = tf.ensure_shape(x, [5])
Traceback (most recent call last):
...
tf.errors.InvalidArgumentError: Shape of tensor dummy_input [3] is not
  compatible with expected shape [5]. [Op:EnsureShape]

Inside a tf.function or v1.Graph context, it checks both the build-time (static) and runtime shapes. This is stricter than tf.Tensor.set_shape, which only checks the build-time shape.
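The two kinds of check can be modeled in plain Python. In this sketch (an illustration under simplifying assumptions, not TensorFlow's implementation), the hypothetical `merge_static` plays the role of the build-time step that both `set_shape` and `ensure_shape` perform: it fills in unknown dimensions and rejects shapes already known to conflict. The hypothetical `check_runtime` is the extra step that only `ensure_shape` adds; `RuntimeError` stands in for `tf.errors.InvalidArgumentError`.

```python
from typing import List, Optional, Sequence

Dim = Optional[int]  # None means "size unknown at build time"


def merge_static(static: Optional[Sequence[Dim]], expected: Sequence[Dim]) -> List[Dim]:
    """Hypothetical build-time step: refine a partially known static shape."""
    if static is None:  # unknown rank: adopt the expected shape
        return list(expected)
    if len(static) != len(expected):
        raise ValueError(f"rank mismatch: {list(static)} vs {list(expected)}")
    merged: List[Dim] = []
    for s, e in zip(static, expected):
        if s is not None and e is not None and s != e:
            raise ValueError(f"incompatible shapes: {list(static)} vs {list(expected)}")
        merged.append(s if s is not None else e)  # keep known dims, fill unknown ones
    return merged


def check_runtime(actual: Sequence[int], expected: Sequence[Dim]) -> None:
    """Hypothetical runtime step: only ensure_shape performs this check."""
    if len(actual) != len(expected) or any(
        e is not None and a != e for a, e in zip(actual, expected)
    ):
        raise RuntimeError(f"runtime shape {list(actual)} != expected {list(expected)}")


# During tracing, a decoded 3-channel PNG has static shape (None, None, 3).
print(merge_static([None, None, 3], [28, 28, 3]))  # [28, 28, 3]

# A 56x28 image is not rejected at build time (its size is unknown then),
# but the runtime step catches it.
try:
    check_runtime([56, 28, 3], [28, 28, 3])
except RuntimeError as err:
    print("runtime check failed:", err)
```

In this model, a set_shape-style call stops after `merge_static`, so a tensor whose real shape disagrees with the refined static shape would go unnoticed.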

For example, when loading images of a known size:

>>> @tf.function
... def decode_image(png):
...   image = tf.image.decode_png(png, channels=3)
...   # the `print` executes during tracing.
...   print("Initial shape: ", image.shape)
...   image = tf.ensure_shape(image, [28, 28, 3])
...   print("Final shape: ", image.shape)
...   return image

When a function is traced, no ops are executed, so shapes may be unknown. See the Concrete Functions Guide for details.

>>> concrete_decode = decode_image.get_concrete_function(
...     tf.TensorSpec([], dtype=tf.string))
Initial shape:  (None, None, 3)
Final shape:  (28, 28, 3)
>>> image = tf.random.uniform(maxval=255, shape=[28, 28, 3], dtype=tf.int32)
>>> image = tf.cast(image, tf.uint8)
>>> png = tf.image.encode_png(image)
>>> image2 = concrete_decode(png)
>>> print(image2.shape)
(28, 28, 3)
>>> image = tf.concat([image, image], axis=0)
>>> print(image.shape)
(56, 28, 3)
>>> png = tf.image.encode_png(image)
>>> image2 = concrete_decode(png)
Traceback (most recent call last):
...
tf.errors.InvalidArgumentError:  Shape of tensor DecodePng [56,28,3] is not
  compatible with expected shape [28,28,3].
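
Because tf.ensure_shape returns a new tensor rather than updating its input in place, the returned tensor must be used. If it is discarded, as below, the static shape is not refined and the runtime check does not take effect, so the (56, 28, 3) image from the previous example passes through unchecked:

>>> @tf.function
... def bad_decode_image(png):
...   image = tf.image.decode_png(png, channels=3)
...   # the `print` executes during tracing.
...   print("Initial shape: ", image.shape)
...   # BAD: forgot to use the returned tensor.
...   tf.ensure_shape(image, [28, 28, 3])
...   print("Final shape: ", image.shape)
...   return image
...
>>> image = bad_decode_image(png)
Initial shape:  (None, None, 3)
Final shape:  (None, None, 3)
>>> print(image.shape)
(56, 28, 3)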

Args
x: A Tensor.
shape: A TensorShape representing the shape of this tensor, a TensorShapeProto, a list, a tuple, or None.
name: A name for this operation (optional). Defaults to "EnsureShape".

Returns
A Tensor. Has the same type and contents as x.

Raises
tf.errors.InvalidArgumentError: If shape is incompatible with the shape of x.
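A short eager-mode sketch of the accepted `shape` forms, assuming TensorFlow 2.x with eager execution enabled (the default). Each call returns a tensor with the same dtype and contents as `x`; a `None` entry leaves that dimension unchecked, and passing `None` as the whole shape (unknown rank) accepts any input. The name `"check_xy"` is an arbitrary illustrative value.

```python
import tensorflow as tf

x = tf.zeros([2, 3])

# Accepted forms of `shape`: list, tuple, TensorShape (None = any size), or None.
a = tf.ensure_shape(x, [2, 3])
b = tf.ensure_shape(x, (2, None))
c = tf.ensure_shape(x, tf.TensorShape([None, 3]))
d = tf.ensure_shape(x, None)  # unknown rank: compatible with every shape

for t in (a, b, c, d):
    print(t.shape)  # (2, 3) each time

# An incompatible shape raises InvalidArgumentError in eager mode.
try:
    tf.ensure_shape(x, [3, 2], name="check_xy")
except tf.errors.InvalidArgumentError as err:
    print("rejected:", type(err).__name__)
```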
