Inplace Updates

Inplace updates enable users to overwrite the contents of existing arrays with new data. This enables much more control over the memory-efficiency of the program, preventing old unused arrays from being kept in memory for any longer than is strictly necessary.

The function ivy.inplace_update() enables explicit inplace updates. ivy.inplace_update() is a primary function, and the backend-specific implementations for each backend are presented below. We also explain the rationale for why each implementation is the way it is, and the important differences.

This is one particular area of the Ivy code where, technically speaking, the function ivy.inplace_update() will result in subtly different behaviour for each backend, unless the ensure_in_backend flag is set to True.

While ivy.Array instances will always be inplace updated consistently, in some cases it is simply not possible to also inplace update the ivy.NativeArray which ivy.Array wraps, due to design choices made by each backend.

NOTE: Native inplace updates do not modify the dtype of the array being updated. As such, the keep_input_dtype flag should normally be set to True, so that inplace updating behavior is consistent between backends.

JAX:

def inplace_update(
    x: Union[ivy.Array, JaxArray],
    val: Union[ivy.Array, JaxArray],
    /,
    *,
    ensure_in_backend: bool = False,
    keep_input_dtype: bool = False,
) -> ivy.Array:
    if ivy.is_array(x) and ivy.is_array(val):
        if ensure_in_backend:
            raise ivy.utils.exceptions.IvyException(
                "JAX does not natively support inplace updates"
            )
        if keep_input_dtype:
            val = ivy.astype(val, x.dtype)
        (x_native, val_native), _ = ivy.args_to_native(x, val)
        if ivy.is_ivy_array(x):
            x.data = val_native
            # Handle view updates
            if ivy.exists(x._base):
                base = x._base
                base_idx = ivy.arange(base.size).reshape(base.shape)
                for fn, args, kwargs, index in x._manipulation_stack:
                    kwargs["copy"] = True
                    base_idx = ivy.__dict__[fn](base_idx, *args, **kwargs)
                    base_idx = base_idx[index] if ivy.exists(index) else base_idx
                base_flat = base.data.flatten()
                base_flat = base_flat.at[base_idx.data.flatten()].set(
                    val_native.flatten()
                )

                base.data = base_flat.reshape(base.shape)

                for ref in base._view_refs:
                    view = ref()
                    if ivy.exists(view) and view is not x:
                        _update_view(view, base)

            else:
                for ref in x._view_refs:
                    view = ref()
                    if ivy.exists(view):
                        _update_view(view, x)
        else:
            raise ivy.utils.exceptions.IvyException(
                "JAX does not natively support inplace updates"
            )
        return x
    else:
        return val

JAX does not natively support inplace updates, and so there is no way of actually inplace updating the JaxArray instance x_native. Therefore, an inplace update is only performed on ivy.Array instances provided in the input. If x is a raw JaxArray, or if ensure_in_backend=True, an IvyException is raised instead.

JAX functions also never return views, so additional logic is added to functionally support multiple variables referencing the same memory (explained further in the Views section below).
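Here is a minimal pure-Python sketch of the JAX behaviour described above. It assumes a hypothetical Wrapper class in place of ivy.Array and immutable tuples in place of JaxArray values. Since the native value can never be mutated, the only "inplace" update available is rebinding the wrapper's data attribute. Anything holding the wrapper sees the new value, but the old native value stays unchanged.

```python
class Wrapper:
    """Hypothetical stand-in for ivy.Array: holds an immutable native value."""

    def __init__(self, data):
        self.data = data


class NativeUpdateError(Exception):
    """Hypothetical stand-in for the IvyException raised by the JAX backend."""


def functional_inplace_update(x, val, ensure_in_backend=False):
    # Immutable natives (tuples here, JaxArrays in Ivy) can never be mutated.
    if ensure_in_backend:
        raise NativeUpdateError("native inplace updates are not supported")
    if not isinstance(x, Wrapper):
        # A raw native value has no wrapper whose data could be rebound.
        raise NativeUpdateError("native inplace updates are not supported")
    # Rebind the wrapper's data: the only "inplace" update possible.
    x.data = val.data if isinstance(val, Wrapper) else val
    return x


old_native = (1.0, 2.0, 3.0)
x = Wrapper(old_native)
alias = x  # a second reference to the same wrapper
functional_inplace_update(x, (4.0, 5.0, 6.0))
print(alias.data)  # (4.0, 5.0, 6.0): wrapper holders see the update
print(old_native)  # (1.0, 2.0, 3.0): the native value itself is unchanged
```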

NumPy:

def inplace_update(
    x: Union[ivy.Array, np.ndarray],
    val: Union[ivy.Array, np.ndarray],
    /,
    *,
    ensure_in_backend: bool = False,
    keep_input_dtype: bool = False,
) -> ivy.Array:
    ivy.utils.assertions.check_inplace_sizes_valid(x, val)
    if ivy.is_array(x) and ivy.is_array(val):
        if keep_input_dtype:
            val = ivy.astype(val, x.dtype)
        (x_native, val_native), _ = ivy.args_to_native(x, val)

        # make both arrays contiguous if not already
        if not x_native.flags.c_contiguous:
            x_native = np.ascontiguousarray(x_native)
        if not val_native.flags.c_contiguous:
            val_native = np.ascontiguousarray(val_native)

        if val_native.shape == x_native.shape:
            if x_native.dtype != val_native.dtype:
                x_native = x_native.astype(val_native.dtype)
            np.copyto(x_native, val_native)
        else:
            x_native = val_native
        if ivy.is_ivy_array(x):
            x.data = x_native
        else:
            x = ivy.Array(x_native)
        return x
    else:
        return val

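NumPy does natively support inplace updates, and so x_native is updated inplace with val_native. This uses np.copyto() when the two arrays have the same shape. Otherwise, x_native is simply replaced by val_native. Following this, an inplace update is then also performed on the ivy.Array instance, if provided in the input.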
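For contrast with the JAX case, here is a rough sketch of the NumPy path. It uses Python lists as a stand-in for mutable np.ndarray buffers and a hypothetical Wrapper class for ivy.Array. Slice assignment plays the role of np.copyto(): it overwrites the existing buffer, so every reference to that buffer sees the change. When the lengths ("shapes") differ, the buffer is replaced rather than mutated, which mirrors the else branch above.

```python
class Wrapper:
    """Hypothetical stand-in for ivy.Array wrapping a mutable native buffer."""

    def __init__(self, data):
        self.data = data


def buffer_inplace_update(x, val):
    """Sketch of a NumPy-style update; Python lists stand in for np.ndarray."""
    x_native = x.data if isinstance(x, Wrapper) else x
    val_native = val.data if isinstance(val, Wrapper) else val
    if len(x_native) == len(val_native):
        # Same "shape": overwrite the existing buffer, like np.copyto().
        x_native[:] = val_native
    else:
        # Different "shape": the native buffer is replaced, not mutated.
        x_native = val_native
    if isinstance(x, Wrapper):
        x.data = x_native  # also update the wrapper
        return x
    return Wrapper(x_native)


native = [1, 2, 3]
x = Wrapper(native)
buffer_inplace_update(x, [7, 8, 9])
print(native)            # [7, 8, 9]: the native buffer itself was mutated
print(x.data is native)  # True
```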

TensorFlow:

def inplace_update(
    x: Union[ivy.Array, tf.Tensor],
    val: Union[ivy.Array, tf.Tensor],
    /,
    *,
    ensure_in_backend: bool = False,
    keep_input_dtype: bool = False,
) -> ivy.Array:
    if ivy.is_array(x) and ivy.is_array(val):
        if keep_input_dtype:
            val = ivy.astype(val, x.dtype)
        (x_native, val_native), _ = ivy.args_to_native(x, val)
        if _is_variable(x_native):
            x_native.assign(val_native)
            if ivy.is_ivy_array(x):
                x.data = x_native
            else:
                x = ivy.Array(x_native)
        elif ensure_in_backend:
            raise ivy.utils.exceptions.IvyException(
                "TensorFlow does not support inplace updates of the tf.Tensor"
            )
        elif ivy.is_ivy_array(x):
            x.data = val_native
            # Handle view updates
            if ivy.exists(x._base):
                base = x._base
                base_idx = ivy.arange(base.size).reshape(base.shape)
                for fn, args, kwargs, index in x._manipulation_stack:
                    kwargs["copy"] = True
                    base_idx = ivy.__dict__[fn](base_idx, *args, **kwargs)
                    base_idx = base_idx[index] if ivy.exists(index) else base_idx
                base_flat = tf.reshape(base.data, -1)
                base_flat = tf.tensor_scatter_nd_update(
                    base_flat,
                    tf.reshape(base_idx.data, (-1, 1)),
                    tf.reshape(val_native, -1),
                )

                base.data = tf.reshape(base_flat, base.shape)
                for ref in base._view_refs:
                    view = ref()
                    if ivy.exists(view) and view is not x:
                        _update_view(view, base)
            else:
                for ref in x._view_refs:
                    view = ref()
                    if ivy.exists(view):
                        _update_view(view, x)
        else:
            x = ivy.to_ivy(x_native)
        return x
    else:
        return val

TensorFlow does not natively support inplace updates for tf.Tensor instances, and so in such cases there is no way of actually inplace updating the tf.Tensor instance x_native. However, TensorFlow does natively support inplace updates for tf.Variable instances. Therefore, if x_native is a tf.Variable, then x_native is updated inplace with val_native. If x_native is a tf.Tensor and ensure_in_backend=True, an IvyException is raised instead. Irrespective of whether the native array is a tf.Tensor or a tf.Variable, an inplace update is then also performed on the ivy.Array instance, if provided in the input.

TensorFlow functions also never return views, so additional logic is added to functionally support multiple variables referencing the same memory (explained further in the Views section below).
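The branching in the TensorFlow implementation can be sketched in plain Python as follows. This assumes a hypothetical Variable class with an assign() method in place of tf.Variable, tuples in place of immutable tf.Tensor values, and a hypothetical Wrapper class for ivy.Array. Only mutable variables get a genuine native update. For wrapped immutable values, the wrapper is rebound. The final branch mirrors the last else of the code above, where a raw tensor is simply wrapped unchanged.

```python
class Variable:
    """Hypothetical stand-in for tf.Variable: mutable via assign()."""

    def __init__(self, value):
        self.value = tuple(value)

    def assign(self, value):
        self.value = tuple(value)  # mutate this object inplace
        return self


class Wrapper:
    """Hypothetical stand-in for ivy.Array."""

    def __init__(self, data):
        self.data = data


def tf_style_inplace_update(x, val, ensure_in_backend=False):
    """Sketch of the TensorFlow branch logic; tuples stand in for tf.Tensor."""
    x_native = x.data if isinstance(x, Wrapper) else x
    val_native = val.data if isinstance(val, Wrapper) else val
    if isinstance(x_native, Variable):
        # Mutable native: a genuine inplace update, visible to all holders.
        x_native.assign(val_native)
        if isinstance(x, Wrapper):
            x.data = x_native
            return x
        return Wrapper(x_native)
    if ensure_in_backend:
        raise TypeError("immutable tensors cannot be updated inplace")
    if isinstance(x, Wrapper):
        # Immutable native: only the wrapper's data can be rebound.
        x.data = tuple(val_native)
        return x
    # Raw immutable native without a wrapper: nothing can be updated.
    return Wrapper(x_native)


var = Variable([1, 2])
tf_style_inplace_update(var, (3, 4))
print(var.value)  # (3, 4)
t = Wrapper((1, 2))
tf_style_inplace_update(t, (5, 6))
print(t.data)  # (5, 6)
```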

PyTorch:

def inplace_update(
    x: Union[ivy.Array, torch.Tensor],
    val: Union[ivy.Array, torch.Tensor],
    /,
    *,
    ensure_in_backend: bool = False,
    keep_input_dtype: bool = False,
) -> ivy.Array:
    ivy.utils.assertions.check_inplace_sizes_valid(x, val)
    if ivy.is_array(x) and ivy.is_array(val):
        if keep_input_dtype:
            val = ivy.astype(val, x.dtype)
        (x_native, val_native), _ = ivy.args_to_native(x, val)
        if is_variable(x_native):
            x_native.data = val_native
        else:
            x_native[()] = val_native
        if ivy.is_ivy_array(x):
            x.data = x_native
            _update_torch_views(x)
        else:
            x = ivy.to_ivy(x_native)
        if ensure_in_backend:
            x._data = val_native
        return x
    else:
        return val

PyTorch does natively support inplace updates, and so x_native is updated inplace with val_native. Following this, an inplace update is then also performed on the ivy.Array instance, if provided in the input.

As with NumPy, PyTorch supports views for most manipulation and indexing operations, but it lacks this functionality for a few functions such as flip(). Additional logic had to be added to support view functionality for those functions (described in the Views section below).

The function ivy.inplace_update() is also nestable, meaning it can accept ivy.Container instances in the input. If an ivy.Container instance is provided for the argument x, then along with the arrays at all of the leaves, the container x is also inplace updated, meaning that a new ivy.Container instance is not created for the function return.
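To illustrate what "the container x is also inplace updated" means, here is a small recursive sketch. It assumes nested dicts as a stand-in for ivy.Container and lists as a stand-in for arrays, and the function name is hypothetical. Every leaf buffer is overwritten in place, and the container objects themselves are reused rather than rebuilt. As a result, references to inner sub-containers remain valid after the update.

```python
def nested_inplace_update(x, val):
    """Recursively inplace update every leaf of a nested dict ``x``.

    Dicts stand in for ivy.Container and lists for arrays (assumptions for
    illustration). The container objects themselves are reused, never rebuilt.
    """
    for key, leaf in x.items():
        new = val[key]
        if isinstance(leaf, dict):
            nested_inplace_update(leaf, new)
        else:
            leaf[:] = new  # overwrite the leaf buffer inplace
    return x


cont = {"a": [1, 2], "b": {"c": [3]}}
inner = cont["b"]
out = nested_inplace_update(cont, {"a": [5, 6], "b": {"c": [7]}})
print(out is cont, cont["b"] is inner)  # True True
```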

out argument

Most functions in Ivy support inplace updates via the inclusion of a keyword-only out argument. This enables users to specify the array to which they would like the output of a function to be written. This could for example be the input array itself, but can also be any other array of choice.

All Ivy functions which return a single array should support inplace updates via the out argument. The type hint of the out argument is Optional[ivy.Array]. However, as discussed above, if the function is nestable then ivy.Container instances are also supported. ivy.Container is omitted from the type hint in such cases, as explained in the Function Arguments section.

When the out argument is unspecified, the return is simply provided in a newly created ivy.Array (or ivy.Container if nestable). However, when out is specified, the return is provided as an inplace update of the out argument. This can, for example, be the same array as the input to the function, resulting in a simple inplace update of the input.
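The contract for out can be shown with a toy elementwise function. Here square is purely illustrative (not an Ivy API), and lists stand in for arrays. Without out, a new result is returned. With out, the result is written into out and out itself is returned. Passing the input as out gives a simple inplace update of the input.

```python
def square(x, *, out=None):
    """Elementwise square with a keyword-only ``out`` argument (illustrative)."""
    result = [v * v for v in x]
    if out is None:
        return result  # no out: return a newly created "array"
    out[:] = result  # out given: write the result into it inplace
    return out


x = [1, 2, 3]
y = square(x)  # y is a new list; x is untouched
square(x, out=x)  # out=x: a simple inplace update of the input
print(y, x)  # [1, 4, 9] [1, 4, 9]
```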

In the case of ivy.Array return types, the out argument is predominantly handled in handle_out_argument. As explained in the Function Wrapping section, this wrapping is applied to every function with the @handle_out_argument decorator dynamically during backend setting.

Primary Functions

In the case of primary functions, handle_out_argument does not handle the backend-specific inplace updates in cases where the backend function being wrapped supports them directly, such as torch.tan and numpy.tan, which both support the out argument directly. When implementing backend-specific functions, the attribute support_native_out should be added to all functions which wrap a function in the backend supporting inplace updates directly. tf.math.tan and jax.numpy.tan for example do not support inplace updates, and so the support_native_out attribute should not be added to the tan implementations.

The implementations of ivy.tan() for each backend are as follows.

JAX (no support_native_out attribute):

def tan(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
    return jnp.tan(x)

NumPy (includes support_native_out attribute):

@_scalar_output_to_0d_array
def tan(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray:
    return np.tan(x, out=out)


tan.support_native_out = True

TensorFlow (no support_native_out attribute):

def tan(
    x: Union[tf.Tensor, tf.Variable],
    /,
    *,
    out: Optional[Union[tf.Tensor, tf.Variable]] = None,
) -> Union[tf.Tensor, tf.Variable]:
    return tf.tan(x)

PyTorch (includes support_native_out attribute):

def tan(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    x = _cast_for_unary_op(x)
    return torch.tan(x, out=out)

tan.support_native_out = True

It’s very important to ensure the support_native_out attribute is not added to backend implementations that do not handle the out argument, as the presence of this attribute dictates whether the argument should be handled by the backend function or by the wrapper.
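The role of the support_native_out attribute can be sketched with a simplified out-handling decorator. This is not Ivy's handle_out_argument. The names handle_out, inplace_update, neg_native and neg_plain are hypothetical, and lists stand in for arrays. If the wrapped function carries the attribute, out is passed straight through. Otherwise, the wrapper computes a fresh result and copies it into out afterwards. Wrapping is applied after the attribute is set, much as Ivy applies its wrappers during backend setting.

```python
import functools


def inplace_update(x, val):
    """Hypothetical stand-in for ivy.inplace_update, using lists."""
    x[:] = val
    return x


def handle_out(fn):
    """Simplified out-handling wrapper; a sketch, not Ivy's handle_out_argument."""

    @functools.wraps(fn)
    def wrapper(*args, out=None, **kwargs):
        if out is None:
            return fn(*args, **kwargs)
        if getattr(fn, "support_native_out", False):
            # The backend function writes into out itself.
            return fn(*args, out=out, **kwargs)
        # Otherwise compute a fresh result and copy it into out afterwards.
        return inplace_update(out, fn(*args, **kwargs))

    return wrapper


def neg_native(x, *, out=None):
    """Hypothetical backend function that supports out natively."""
    res = [-v for v in x]
    if out is None:
        return res
    out[:] = res
    return out


neg_native.support_native_out = True


def neg_plain(x):
    """Hypothetical backend function without out support."""
    return [-v for v in x]


# Wrapping happens after the attribute is set, as during backend setting.
neg_a = handle_out(neg_native)
neg_b = handle_out(neg_plain)

buf = [0, 0]
neg_a([1, 2], out=buf)  # native path: neg_native writes into buf
neg_b([3, 4], out=buf)  # wrapper path: result copied into buf afterwards
print(buf)  # [-3, -4]
```

If support_native_out were mistakenly set on neg_plain, this wrapper would pass out to a function that cannot accept it and raise TypeError. A backend function that accepts but ignores out would instead silently leave out unwritten.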

This distinction only concerns how the inplace update is applied to the native array, which is operated upon directly by the backend. If out is specified in an Ivy function, then an inplace update is always also performed on the ivy.Array instance itself, which is how out is provided to the function originally. The inplace update of this ivy.Array is always handled by the wrapper.

Alternatively, if out is an ivy.Container, then the inplace update is always handled by _wrap_fn in the container wrapping module.

Special Case

Take a function which has multiple possible “paths” through the code:

def cholesky(
    x: torch.Tensor, /, *, upper: bool = False, out: Optional[torch.Tensor] = None
) -> torch.Tensor:
    if not upper:
        return torch.linalg.cholesky(x, out=out)
    else:
        ret = torch.transpose(
            torch.linalg.cholesky(
                torch.transpose(x, dim0=len(x.shape) - 1, dim1=len(x.shape) - 2)
            ),
            dim0=len(x.shape) - 1,
            dim1=len(x.shape) - 2,
        )
        if ivy.exists(out):
            return ivy.inplace_update(out, ret)
        return ret


cholesky.support_native_out = True

Here we still have the support_native_out attribute since we want to take advantage of the native inplace update enabled by torch.linalg.cholesky() in the first condition. However, in the else statement, the last operation is torch.transpose() which does not support the out argument, and so the native inplace update can’t be performed by torch here. This is why we need to call ivy.inplace_update() explicitly here, to ensure the native inplace update is performed, as well as the ivy.Array inplace update.

Another case where we need to use ivy.inplace_update() inside a function that has support_native_out is the torch backend implementation of ivy.remainder():

def remainder(
    x1: Union[float, torch.Tensor],
    x2: Union[float, torch.Tensor],
    /,
    *,
    modulus: bool = True,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    x1, x2 = ivy.promote_types_of_inputs(x1, x2)
    if not modulus:
        res = x1 / x2
        res_floored = torch.where(res >= 0, torch.floor(res), torch.ceil(res))
        diff = res - res_floored
        diff, x2 = ivy.promote_types_of_inputs(diff, x2)
        if ivy.exists(out):
            if out.dtype != x2.dtype:
                return ivy.inplace_update(
                    out, torch.round(torch.mul(diff, x2)).to(out.dtype)
                )
        return torch.round(torch.mul(diff, x2), out=out).to(x1.dtype)
    return torch.remainder(x1, x2, out=out).to(x1.dtype)


remainder.support_native_out = True

Here, although torch.round() natively supports the out argument, we need to fall back to ivy.inplace_update() when the dtype of out differs from the dtype of the result. Whenever the dtypes match, the native out argument is still used, so that native inplace updates are exploited as far as possible.
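The same two-path pattern can be reproduced with the standard library. This sketch uses array.array typecodes as a stand-in for dtypes ('d' for float64, 'l' for a C long integer), and double is a hypothetical function. array.array rejects slice assignment between different typecodes. So when out has a different dtype, the result is first built with out's typecode and then copied in, which mirrors the ivy.inplace_update() fallback in remainder. When the dtypes match, values are written directly into out, standing in for the native out path.

```python
from array import array

# Typecode -> Python cast; array typecodes stand in for dtypes here.
_CASTS = {"d": float, "l": int}


def double(x, *, out=None):
    """Hypothetical op that writes into ``out`` directly when dtypes match."""
    if out is None:
        return array(x.typecode, (v * 2 for v in x))
    if out.typecode == x.typecode:
        # Matching dtype: use the "native" out path, writing elementwise.
        for i, v in enumerate(x):
            out[i] = v * 2
        return out
    # Mismatched dtype: build a result cast to out's dtype, then copy it in.
    cast = _CASTS[out.typecode]
    out[:] = array(out.typecode, (cast(v * 2) for v in x))
    return out


x = array("d", [1.5, 2.0])
out_f = array("d", [0.0, 0.0])
out_i = array("l", [0, 0])
double(x, out=out_f)  # native path: array('d', [3.0, 4.0])
double(x, out=out_i)  # fallback path: array('l', [3, 4])
```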

Compositional Functions

For compositional functions, the out argument should always be handled in the compositional implementation, with no wrapping applied at all. This is for a few reasons:

  1. we need to show the out argument in the compositional function signature, as this is the only function implementation in the codebase. Adding an argument unused in the implementation could cause some confusion.

  2. generally, inplace updates are performed because memory management is an area of concern for the user. By handling the out argument in the compositional implementation itself, we can maximize the memory efficiency of the function, using inplace updates in as many of the inner Ivy functions as possible.

  3. this enables us to make use of backend-specific out argument handling.

The second and third points are the most important ones.

We’ll use ivy.cross_entropy() as an example:

def cross_entropy(
    true: Union[ivy.Array, ivy.NativeArray],
    pred: Union[ivy.Array, ivy.NativeArray],
    /,
    *,
    axis: int = -1,
    epsilon: float = 1e-7,
    reduction: str = "mean",
    out: Optional[ivy.Array] = None,
) -> ivy.Array:
    ivy.utils.assertions.check_elem_in_list(reduction, ["none", "sum", "mean"])
    pred = ivy.clip(pred, epsilon, 1 - epsilon)
    log_pred = ivy.log(pred)
    return _reduce_loss(reduction, log_pred * true, axis, out=out)

By handling the out argument in the function itself (here, by forwarding it to the internal helper _reduce_loss, which performs the reduction and the final negation), we are able to get the benefits outlined above. Firstly, the return of ivy.sum() is the same shape and type as the return of the entire function, and so we can also write this output to the out argument inplace. We can then subsequently overwrite the contents of out again with the return of the ivy.negative() function. This minimizes the number of arrays created during the execution of the function, which is generally the intention when specifying the out argument. Additionally, with a PyTorch backend, the ivy.negative() function defers to the out argument of the torch.negative() function directly, which is the most efficient inplace update possible, making use of backend-specific optimizations.

If we had instead simply used the handle_out_argument wrapper, we would not gain any of these benefits, and would instead simply call ivy.inplace_update() at the very end of the function call.
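Threading out through the inner calls can be sketched with lists as a stand-in for arrays. The functions row_sums, negative and neg_row_sums are hypothetical and loosely modelled on the reduce-then-negate structure described above. The intermediate reduction is written straight into out, and then out is overwritten again with the negation. This avoids creating a separate result array that would later have to be copied into out.

```python
def row_sums(rows, *, out=None):
    """Hypothetical reduction with out support (sums each row)."""
    res = [sum(r) for r in rows]
    if out is None:
        return res
    out[:] = res
    return out


def negative(x, *, out=None):
    """Hypothetical elementwise negation with out support."""
    if out is None:
        return [-v for v in x]
    for i, v in enumerate(x):
        out[i] = -v  # safe even when x is out: each slot is read, then written
    return out


def neg_row_sums(rows, *, out=None):
    """Hypothetical compositional function threading ``out`` through."""
    # Write the intermediate reduction straight into out ...
    summed = row_sums(rows, out=out)
    # ... then overwrite out again with the negation.
    return negative(summed, out=out)


buf = [0, 0]
neg_row_sums([[1, 2], [3, 4]], out=buf)
print(buf)  # [-3, -7]
```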

For some compositional functions, the internal function which generates the final return value does not itself support the out argument. For example, ivy.multi_head_attention includes support for arbitrary functions passed in the input, including to_out_fn which, if specified, is applied to the outputs before returning. For such functions, the inplace update should simply be performed using ivy.inplace_update() at the end of the function.

Technically, this could be handled using the handle_out_argument wrapping, but we opt to implement this in the compositional function itself, due to point 1 mentioned above.
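The pattern for such functions is sketched below. It uses a hypothetical compositional function attention_like (not ivy.multi_head_attention), a hypothetical stand-in for ivy.inplace_update, and lists as a stand-in for arrays. A user-supplied to_out_fn cannot be assumed to accept out, so the result is copied into out once, at the very end.

```python
def inplace_update(x, val):
    """Hypothetical stand-in for ivy.inplace_update, using lists."""
    x[:] = val
    return x


def attention_like(q, k, *, to_out_fn=None, out=None):
    """Hypothetical compositional function with a user-supplied output fn."""
    scores = [a * b for a, b in zip(q, k)]
    if to_out_fn is not None:
        # Arbitrary user callables cannot be assumed to accept out.
        scores = to_out_fn(scores)
    if out is not None:
        # So the inplace update is applied once, at the very end.
        return inplace_update(out, scores)
    return scores


buf = [0, 0]
attention_like([1, 2], [3, 4], to_out_fn=lambda s: [v + 1 for v in s], out=buf)
print(buf)  # [4, 9]
```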

Mixed Functions

As explained in the Function Types section, mixed functions can effectively behave as either compositional or primary functions, depending on the backend that is selected. We must add handle_out_argument to the add_wrappers key of the mixed_backend_wrappers attribute, so that the decorator gets added to the primary implementation when the backend is set. The linear function is one example of this.

copy argument

As well as the out argument, many functions also support the copy argument. Functions that support copy fall into one of two groups. Some belong to the Array API Standard, which mandates the inclusion of copy in each case. Others are expected to return views with specific backends (hence being decorated with the @handle_view wrapper), and copy is added to provide a way of preventing views from being created.

The copy argument dictates whether a new copy should be created, or whether the input array should be updated inplace. When copy is not specified explicitly, then an inplace update is performed with the same behaviour as copy=False. Setting copy=False is equivalent to passing out=input_array. If only one of copy or out is specified, then this specified argument is given priority. If both are specified, then priority is given to the more general out argument. As with the out argument, the copy argument is also handled by the wrapper.
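The priority rules stated above can be encoded in a toy function. reverse is hypothetical, and lists stand in for arrays. Per the description, out wins whenever it is given. copy=True returns a new array. copy=False, or leaving copy unspecified, behaves like out=x, so the input itself is updated inplace.

```python
def reverse(x, *, copy=None, out=None):
    """Hypothetical op following the copy/out rules described above."""
    result = x[::-1]  # computed as a fresh list first
    if out is not None:
        target = out  # out takes priority, even when copy is also given
    elif copy:
        return result  # copy=True: return a new array, leave x untouched
    else:
        target = x  # copy=False or unspecified: behaves like out=x
    target[:] = result
    return target


a = [1, 2, 3]
b = reverse(a, copy=True)  # b == [3, 2, 1], a unchanged
reverse(a)  # a is now [3, 2, 1]
```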

Views

Many functions in NumPy and PyTorch return views instead of copies; these are mostly manipulation or indexing routines. Views are arrays which access the same data buffer as another array but view it with different metadata, such as strides. More information about these arrays can be found in NumPy’s documentation. This means that any inplace update on the original array, or on any of its views, will cause all the other views to be updated as well, since they reference the same memory buffer.
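Python's standard library offers the same buffer-sharing semantics through memoryview, so the behaviour of views can be shown without NumPy. A memoryview slice shares memory with the underlying bytearray. Writes through either one are visible in the other. Slicing a bytearray directly produces an independent copy.

```python
# A bytearray plays the role of the data buffer; memoryview slices are views.
base = bytearray(b"\x00\x01\x02\x03\x04\x05")
view = memoryview(base)[2:5]  # shares memory with base, no copy
view[0] = 42  # inplace update through the view ...
print(base[2])  # 42: ... is visible in the base
base[4] = 99  # inplace update of the base ...
print(view[2])  # 99: ... is visible in the view
snapshot = bytes(base[2:5])  # slicing a bytearray copies instead
base[2] = 7
print(snapshot[0])  # 42: the copy did not change
```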

We want to keep supporting NumPy and PyTorch inplace update semantics wherever we can, and to superset backend behaviour. However, it is not trivial to replicate this in JAX and TensorFlow. The main reason is that these frameworks do not natively support inplace updates. Even if multiple native arrays were referencing the same memory buffer, a single update could never be applied to all of them at once. Therefore, views and their updates must be tracked through Ivy. Extra logic has been added to update views whenever an inplace update happens to any array which is meant to be referencing the same memory. We call the views tracked and updated by Ivy functional views, as they work within a functional paradigm.

Which functions return views is mostly dictated by NumPy, since it has the most expansive support for them. Any function which returns views in NumPy or PyTorch should be decorated with the @handle_view wrapper, except get_item(), which has its own @handle_view_indexing wrapper. Every function with this wrapper should also have a copy argument, so that Ivy maintains a way to prevent views from being created if necessary. The wrapper updates a few ivy.Array attributes which keep track of views, how they were created, and which arrays should be updated together. These attributes are then used in ivy.inplace_update() to update all the arrays which are meant to be referencing the same memory, at least to NumPy’s standard. In practice, these attributes are normally only used with the JAX and TensorFlow backends. NumPy and PyTorch natively update their views, so Ivy does not need to do any extra handling, except for a few functions where PyTorch fails to return views when NumPy does. At the time of writing, the functions in the Ivy API where PyTorch fails to return views are ivy.flip(), ivy.rot90(), ivy.flipud() and ivy.fliplr(). When one of these functions is used with a PyTorch backend, additional logic makes its return behave as a view of the original array that created it.

Here’s a brief description of the additional attributes added to ivy.Array and their usage:

  1. Base (._base): the original array being referenced (array all views stem from)

  2. Manipulation stack (._manipulation_stack): store of operations that were done on the original to get to the current shape (manipulation or indexing)

  3. Reference stack (._view_refs): weak references to the arrays that reference the original as a view, only populated for base arrays.

  4. PyTorch Base (._torch_base): keeps track of the functional view (an array created by one of the functions listed above) that made it, otherwise stores the original array

  5. PyTorch reference stack (._torch_view_refs): Functional views referencing this array in its PyTorch base, only populated for original arrays or functional views.

  6. PyTorch manipulation cache (._torch_manipulation): Tuple storing array or view and function which made the functional view, only populated for functional views

Note

Parts of an array’s metadata, such as strides, are attributed to the low-level memory layout of the array, while views in Ivy operate at a higher level of abstraction. As a result, ivy.strides() isn’t guaranteed to produce an output reflective of the underlying memory layout if the ivy.Array passed in is a view (in other words, if it has a _base).

Here’s a brief description of how the @handle_view wrapper populates these attributes:

  1. When an array is made using a function decorated by this wrapper, its base becomes the array that made it or, if that array is itself a view, that array’s base.

  2. The view is then added (weakly) to the reference stack of the base array, and the operation that created the array is added to the manipulation stack of the array.

  3. The PyTorch-specific attributes are updated as described in the attribute list above.
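The bookkeeping in the first two steps can be sketched in plain Python. This is a simplified illustration, not Ivy's @handle_view. TrackedArray, handle_view_sketch, reverse and take_first are hypothetical names. Stack entries here are (function, args, kwargs) tuples rather than Ivy's exact format. A view inherits its creator's manipulation stack, so the stack always describes the full path from the base. Only weak references are stored, so the base does not keep otherwise-unused views alive.

```python
import weakref


class TrackedArray:
    """Toy stand-in for ivy.Array carrying Ivy's view-tracking attributes."""

    def __init__(self, data):
        self.data = data
        self._base = None  # array all views stem from
        self._manipulation_stack = []  # ops applied to the base to get here
        self._view_refs = []  # weak refs to views (bases only)


def handle_view_sketch(fn):
    """Simplified view-tracking wrapper (a sketch, not Ivy's actual code)."""

    def wrapper(x, *args, **kwargs):
        out = TrackedArray(fn(x.data, *args, **kwargs))
        # Step 1: the base is the creating array, or that array's own base.
        base = x._base if x._base is not None else x
        out._base = base
        # Step 2: inherit the creator's stack, then record this operation.
        out._manipulation_stack = x._manipulation_stack + [(fn, args, kwargs)]
        base._view_refs.append(weakref.ref(out))
        return out

    return wrapper


@handle_view_sketch
def reverse(data):
    return data[::-1]


@handle_view_sketch
def take_first(data, n):
    return data[:n]


a = TrackedArray([1, 2, 3, 4])
v1 = reverse(a)  # data [4, 3, 2, 1], base a
v2 = take_first(v1, 2)  # data [4, 3], base a (not v1)
print(v2._base is a, len(v2._manipulation_stack))  # True 2
```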

Here’s a brief description of what happens during an inplace operation with a JAX or TensorFlow backend:

  1. If the base is inplace updated, Ivy goes through all the arrays in its reference stack, replays each one’s manipulation stack on the updated base, and inplace updates each view with the result.

  2. If a view gets inplace updated, an index array with the shape of the base array is created and passed through the manipulation stack of the updated array.

  3. The updated array and the index array are then flattened and used to perform a scatter update on a flattened version of the original array, which is then reshaped back into the original shape.

  4. Finally, all other views stemming from the original are updated as described in the first point.
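Steps 2 to 4 can be sketched with nested Python lists in place of 2D arrays. The manipulation stack is modelled as a list of one-argument callables, which is a simplification of Ivy's stored function names and arguments. All function names are hypothetical. An index array shaped like the base is pushed through the view's stack, telling us which flat base position each view element came from. The flattened new view values are then scattered into a flattened copy of the base, which is reshaped back. Consistent with the functional setting, the original base list is not mutated and a new base is returned instead.

```python
def transpose(m):
    return [list(col) for col in zip(*m)]


def flatten(m):
    return [v for row in m for v in row]


def reshape(flat, shape):
    rows, cols = shape
    return [flat[r * cols:(r + 1) * cols] for r in range(rows)]


def apply_stack(m, stack):
    """Replay a manipulation stack (one-argument callables) on a 2D list."""
    for fn in stack:
        m = fn(m)
    return m


def update_base_from_view(base, stack, new_view):
    """Scatter a view's new values back into its base, returning the new base."""
    shape = (len(base), len(base[0]))
    # Step 2: an index array with the base's shape, passed through the stack.
    base_idx = reshape(list(range(shape[0] * shape[1])), shape)
    view_idx = flatten(apply_stack(base_idx, stack))
    # Step 3: scatter the flattened view into the flattened base, then reshape.
    base_flat = flatten(base)
    for i, v in zip(view_idx, flatten(new_view)):
        base_flat[i] = v
    return reshape(base_flat, shape)


base = [[1, 2, 3], [4, 5, 6]]
view_stack = [transpose, lambda m: m[:2]]  # transpose, then take 2 rows
other_stack = [lambda m: m[:1]]  # another view: the first row
new_base = update_base_from_view(base, view_stack, [[10, 40], [20, 50]])
# Step 4: other views are regenerated by replaying their stacks on the base.
other_view = apply_stack(new_base, other_stack)
print(new_base)  # [[10, 20, 3], [40, 50, 6]]
print(other_view)  # [[10, 20, 3]]
```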

Here’s a brief description of what happens during an inplace operation with a PyTorch backend:

  1. The array being updated checks whether it has a populated reference stack. If it does, it inplace updates each functional view in the stack with the output of the stored function called on the array that made it. It then checks whether that functional view has a reference stack of its own, and continues recursively until all reference stacks are exhausted.

  2. If the reference stack is empty or exhausted, it checks whether it has a manipulation stack. If so, it performs the reverse functional operation with itself as the input and inplace updates the view that made it (reversing the operation that made it). If the manipulation stack is empty or already exhausted, it moves to the array’s PyTorch base and repeats the process recursively until everything is exhausted and the base is None.

  3. All other views are expected to be updated automatically through PyTorch’s native view handling.
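The recursive propagation in the first two steps can be sketched for a flip-style functional view, using lists as a stand-in for torch.Tensor. This is a simplification. Each toy Node stores its creating array along with a (forward, reverse) pair of functions, rather than Ivy's actual attributes. Children are held by strong references, not weak ones. Native PyTorch view handling (step 3) is not modelled. All names are hypothetical. Updating one node refreshes every functional view made from it, then applies the reverse operation to update its creator, whose other views are refreshed in turn.

```python
class Node:
    """Toy array tracking flip-style functional views (PyTorch-like case)."""

    def __init__(self, data, parent=None, ops=None):
        self.data = data
        self._torch_base = parent  # array that made this functional view
        self._torch_manipulation = ops  # (forward fn, reverse fn) that made it
        self._torch_view_refs = []  # functional views made from this array
        if parent is not None:
            parent._torch_view_refs.append(self)


def flip(data):
    return data[::-1]


FLIP = (flip, flip)  # flip is its own reverse operation


def propagate_down(node, skip=None):
    # Step 1: refresh each functional view made from node, recursively.
    for child in node._torch_view_refs:
        if child is not skip:
            forward, _ = child._torch_manipulation
            child.data[:] = forward(node.data)
            propagate_down(child)


def propagate_up(node):
    # Step 2: apply the reverse op to update the array that made node.
    parent = node._torch_base
    if parent is not None:
        _, reverse = node._torch_manipulation
        parent.data[:] = reverse(node.data)
        propagate_down(parent, skip=node)  # refresh the parent's other views
        propagate_up(parent)


def functional_view_update(node, val):
    node.data[:] = val
    propagate_down(node)
    propagate_up(node)


a = Node([1, 2, 3])
b = Node(flip(a.data), a, FLIP)  # [3, 2, 1]
c = Node(flip(b.data), b, FLIP)  # [1, 2, 3]
d = Node(flip(a.data), a, FLIP)  # sibling of b: [3, 2, 1]
functional_view_update(b, [9, 8, 7])
print(a.data, c.data, d.data)  # [7, 8, 9] [7, 8, 9] [9, 8, 7]
```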

Round Up

This should have hopefully given you a good feel for inplace updates, and how these are handled in Ivy.

If you have any questions, please feel free to reach out on Discord in the inplace updates thread!