import tensorflow as tf
from typing import List
import copy as cp
[docs]
def move_axis(x: tf.Tensor, from_index: int, to_index: int) -> tf.Tensor:
    """Moves an axis from from_index to to_index.

    :param x: A tensor of shape [..., k, ...] where k is at from_index.
    :type x: :class:`tensorflow.Tensor`
    :param from_index: The index of the axis before transposition. Negative values count from the last axis, i.e. -1
        refers to the last axis, -2 to the one before it and so on.
    :type from_index: int
    :param to_index: The index of the axis after transposition. Negative values count from the last axis in the same
        way as for ``from_index``.
    :type to_index: int
    :return: x_new (:class:`tensorflow.Tensor`): The tensor x transposed such that shape [..., k, ...] now has k at to_index.
    :raises IndexError: If ``from_index`` does not address an existing axis of ``x``.
    """
    axis_count = len(x.shape)

    # Input validity: map every in-range negative index onto its non-negative counterpart. Handling only -1 is not
    # enough, because a negative from_index would end up as a negative entry in the permutation and a negative
    # to_index would be interpreted by list.insert relative to the already shortened list (off by one).
    if -axis_count <= from_index < 0:
        from_index += axis_count
    if -axis_count <= to_index < 0:
        to_index += axis_count

    # Move axis: take from_index out of the identity permutation and re-insert it at its destination
    new_order = list(range(axis_count))
    del new_order[from_index]  # An out-of-range from_index surfaces here as IndexError
    new_order.insert(to_index, from_index)
    x_new = tf.transpose(a=x, perm=new_order)

    # Outputs
    return x_new
[docs]
def expand_axes(x: tf.Tensor, axes: List[int]) -> tf.Tensor:
    """Expands x with singleton axes.

    :param x: The tensor to be expanded.
    :type x: :class:`tensorflow.Tensor`
    :param axes: The axes along which to expand. Their indices are assumed to be valid in the shape of ``x_new``. This
        means if, e.g. ``x`` has two axes then ``axes`` may be, e.g. [0,1,3,5,6,7] where axes 2 and 4 are filled in
        order by ``x`` but ``axes`` must not be, e.g. [0,1,3,5,6,10] because of the gap between 6 and 10 that would be
        introduced in the shape of ``x_new``. Negative entries count from the last axis of ``x_new``. Each entry is
        converted with ``int()`` before use and must not occur more than once.
    :type axes: :class:`List[int]`
    :return: x_new (:class:`tensorflow.Tensor`) - The reshaped version of x with singletons along ``axes``.
    :raises AssertionError: If an entry of ``axes`` lies outside the axes of ``x_new``.
    :raises ValueError: If an entry of ``axes`` occurs more than once.
    """
    # Initialize
    new_axis_count = len(x.shape) + len(axes)

    # Input validity: collect the positions of the new singleton axes.
    singleton_axes = set()
    for axis in axes:
        position = int(axis)
        if position < 0:
            position += new_axis_count
        if not 0 <= position < new_axis_count:
            # Raised explicitly (instead of via an assert statement) so that the check is not stripped under
            # ``python -O``; the exception type is kept for callers that already expect an AssertionError.
            raise AssertionError(
                f"The axis {axis} must be in the interval [-{new_axis_count},{new_axis_count})."
            )
        if position in singleton_axes:
            raise ValueError(f"The axis {axis} occurs more than once in axes.")
        singleton_axes.add(position)

    # Compatibility of new and old axes: every position that is not a new singleton is filled by an axis of x
    old_axes = [position for position in range(new_axis_count) if position not in singleton_axes]

    # Set new shape: singleton everywhere except where the axes of x are placed, in their original order
    new_shape = [1] * new_axis_count
    for o, position in enumerate(old_axes):
        new_shape[position] = x.shape[o]
    x_new = tf.keras.ops.reshape(x, new_shape)

    # Outputs
    return x_new
[docs]
def flatten_along_axes(x: tf.Tensor, axes: List[int]) -> tf.Tensor:
    """Flattens an input ``x`` along axes ``axes``.

    The sizes of all ``axes`` are multiplied into the position of ``axes[0]`` and the remaining ``axes`` are removed
    from the shape. The operation is a plain reshape, i.e. no elements are transposed.

    :param x: The input to be flattened. Assumed to have at least as many axes as indicated by ``axes``.
    :type x: :class:`tensorflow.Tensor`
    :param axes: The axes along which the input shall be flattened. Negative entries count from the last axis. The
        list is not modified. If it is empty, ``x`` is returned unchanged.
    :type axes: :class:`List[int]`
    :return: x_new (:class:`tensorflow.Tensor`) - The reshaped tensor ``x`` flattened along ``axes``.
    :raises IndexError: If an entry of ``axes`` does not address an existing axis of ``x``.
    """
    # Exception handling
    if len(axes) == 0:
        return x

    # Query the shape once instead of once per axis
    old_shape = tf.keras.ops.shape(x)
    axis_count = len(old_shape)

    # Input validity: express every axis as a non-negative index so that axes can be ordered reliably below
    normalized_axes = []
    for a in axes:
        if not -axis_count <= a < axis_count:
            raise IndexError(f"The axis {a} is out of range for an input with {axis_count} axes.")
        normalized_axes.append(a % axis_count)

    # Reshape: accumulate the sizes of all axes at the position of the first one
    target = normalized_axes[0]
    new_shape = list(old_shape)
    new_shape[target] = 1
    for a in normalized_axes:
        new_shape[target] *= old_shape[a]

    # Remove the merged axes from the highest index downwards so that earlier deletions do not shift the indices of
    # the axes that still have to be deleted (this also holds if ``axes`` was not given in ascending order).
    for a in sorted(normalized_axes[1:], reverse=True):
        del new_shape[a]
    x_new = tf.keras.ops.reshape(x, new_shape)  # Now has original shape except for axes which have been flattened

    # Outputs
    return x_new
[docs]
def swop_axes(x: tf.Tensor, from_axis: int, to_axis: int) -> tf.Tensor:
    """Exchanges two axes of ``x`` while leaving all other axes in place.

    :param x: The input whose axes shall be swopped. Assumed to have at least as many axes as indicated by
        ``from_axis`` and ``to_axis``.
    :type x: :class:`tensorflow.Tensor`
    :param from_axis: The axis that trades places with ``to_axis``.
    :type from_axis: int
    :param to_axis: The axis that trades places with ``from_axis``.
    :type to_axis: int
    :return: x_new (:class:`tensorflow.Tensor`) - The input with ``from_axis`` and ``to_axis`` swopped.
    """
    # Start from the identity permutation and exchange the two requested entries
    permutation = list(range(len(x.shape)))
    permutation[from_axis], permutation[to_axis] = permutation[to_axis], permutation[from_axis]

    # Outputs
    return tf.keras.ops.transpose(x, axes=permutation)