Inplace Updates
Inplace updates enable users to overwrite the contents of existing arrays with new data. This gives much more control over the memory efficiency of the program, preventing old, unused arrays from being kept in memory for any longer than is strictly necessary.
The function ivy.inplace_update()
enables explicit inplace updates.
ivy.inplace_update()
is a primary function, and the backend-specific implementations for each backend are presented below.
We also explain the rationale for why each implementation is the way it is, and the important differences.
This is one particular area of the Ivy code where, technically speaking, the function ivy.inplace_update()
will result in subtly different behaviour for each backend, unless the ensure_in_backend
flag is set to True
.
While ivy.Array
instances will always be inplace updated consistently, in some cases it is simply not possible to also inplace update the ivy.NativeArray
which ivy.Array
wraps, due to design choices made by each backend.
NOTE: Native inplace updates do not modify the dtype of the array being updated. As such, the keep_input_dtype
flag should normally be set to True
so that inplace updating behaviour is consistent between backends.
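To make the general behaviour concrete, here is a minimal usage sketch (the backend choice and array values are arbitrary):

import ivy

ivy.set_backend("numpy")  # any backend; only the native-level behaviour differs

x = ivy.array([1.0, 2.0, 3.0])
val = ivy.array([4.0, 5.0, 6.0])

# overwrite the contents of x with val, rather than allocating a new array
ivy.inplace_update(x, val, keep_input_dtype=True)
print(x)  # ivy.array([4., 5., 6.])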
JAX:
def inplace_update(
    x: Union[ivy.Array, JaxArray],
    val: Union[ivy.Array, JaxArray],
    /,
    *,
    ensure_in_backend: bool = False,
    keep_input_dtype: bool = False,
) -> ivy.Array:
    if ivy.is_array(x) and ivy.is_array(val):
        if ensure_in_backend:
            raise ivy.utils.exceptions.IvyException(
                "JAX does not natively support inplace updates"
            )
        if keep_input_dtype:
            val = ivy.astype(val, x.dtype)
        (x_native, val_native), _ = ivy.args_to_native(x, val)
        if ivy.is_ivy_array(x):
            x.data = val_native
            # Handle view updates
            if ivy.exists(x._base):
                base = x._base
                base_idx = ivy.arange(base.size).reshape(base.shape)
                for fn, args, kwargs, index in x._manipulation_stack:
                    kwargs["copy"] = True
                    base_idx = ivy.__dict__[fn](base_idx, *args, **kwargs)
                    base_idx = base_idx[index] if ivy.exists(index) else base_idx
                base_flat = base.data.flatten()
                base_flat = base_flat.at[base_idx.data.flatten()].set(
                    val_native.flatten()
                )
                base.data = base_flat.reshape(base.shape)
                for ref in base._view_refs:
                    view = ref()
                    if ivy.exists(view) and view is not x:
                        _update_view(view, base)
            else:
                for ref in x._view_refs:
                    view = ref()
                    if ivy.exists(view):
                        _update_view(view, x)
        else:
            raise ivy.utils.exceptions.IvyException(
                "JAX does not natively support inplace updates"
            )
        return x
    else:
        return val
JAX does not natively support inplace updates, and so there is no way of actually inplace updating the JaxArray
instance x_native
.
Therefore, an inplace update is only performed on ivy.Array
instances provided in the input.
JAX functions also never return views, so additional logic is added to functionally support multiple variables referencing the same memory (further explained in a later section).
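To illustrate why, here is a small pure-JAX sketch: JaxArray buffers are immutable, so "updating" one always produces a new array via the functional .at[...].set(...) syntax (values chosen purely for illustration):

import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])

# x[0] = 5.0 would raise a TypeError: JAX arrays are immutable
y = x.at[0].set(5.0)  # returns a *new* array; x is left untouched

print(x)  # [1. 2. 3.]
print(y)  # [5. 2. 3.]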
NumPy:
def inplace_update(
    x: Union[ivy.Array, np.ndarray],
    val: Union[ivy.Array, np.ndarray],
    /,
    *,
    ensure_in_backend: bool = False,
    keep_input_dtype: bool = False,
) -> ivy.Array:
    ivy.utils.assertions.check_inplace_sizes_valid(x, val)
    if ivy.is_array(x) and ivy.is_array(val):
        if keep_input_dtype:
            val = ivy.astype(val, x.dtype)
        (x_native, val_native), _ = ivy.args_to_native(x, val)
        # make both arrays contiguous if not already
        if not x_native.flags.c_contiguous:
            x_native = np.ascontiguousarray(x_native)
        if not val_native.flags.c_contiguous:
            val_native = np.ascontiguousarray(val_native)
        if val_native.shape == x_native.shape:
            if x_native.dtype != val_native.dtype:
                x_native = x_native.astype(val_native.dtype)
            np.copyto(x_native, val_native)
        else:
            x_native = val_native
        if ivy.is_ivy_array(x):
            x.data = x_native
        else:
            x = ivy.Array(x_native)
        return x
    else:
        return val
NumPy does natively support inplace updates, and so x_native
is updated inplace with val_native
.
Following this, an inplace update is then also performed on the ivy.Array
instance, if provided in the input.
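As a quick illustration of the mechanism used above, np.copyto writes the new values into the existing buffer of the destination array, so every other reference to that buffer observes the change (example values are illustrative only):

import numpy as np

x = np.array([1.0, 2.0, 3.0])
alias = x  # another reference to the same underlying buffer

np.copyto(x, np.array([4.0, 5.0, 6.0]))  # true inplace write into x's buffer

print(alias)  # [4. 5. 6.] -- the shared buffer was modified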
TensorFlow:
def inplace_update(
    x: Union[ivy.Array, tf.Tensor],
    val: Union[ivy.Array, tf.Tensor],
    /,
    *,
    ensure_in_backend: bool = False,
    keep_input_dtype: bool = False,
) -> ivy.Array:
    if ivy.is_array(x) and ivy.is_array(val):
        if keep_input_dtype:
            val = ivy.astype(val, x.dtype)
        (x_native, val_native), _ = ivy.args_to_native(x, val)
        if _is_variable(x_native):
            x_native.assign(val_native)
            if ivy.is_ivy_array(x):
                x.data = x_native
            else:
                x = ivy.Array(x_native)
        elif ensure_in_backend:
            raise ivy.utils.exceptions.IvyException(
                "TensorFlow does not support inplace updates of the tf.Tensor"
            )
        elif ivy.is_ivy_array(x):
            x.data = val_native
            # Handle view updates
            if ivy.exists(x._base):
                base = x._base
                base_idx = ivy.arange(base.size).reshape(base.shape)
                for fn, args, kwargs, index in x._manipulation_stack:
                    kwargs["copy"] = True
                    base_idx = ivy.__dict__[fn](base_idx, *args, **kwargs)
                    base_idx = base_idx[index] if ivy.exists(index) else base_idx
                base_flat = tf.reshape(base.data, -1)
                base_flat = tf.tensor_scatter_nd_update(
                    base_flat,
                    tf.reshape(base_idx.data, (-1, 1)),
                    tf.reshape(val_native, -1),
                )
                base.data = tf.reshape(base_flat, base.shape)
                for ref in base._view_refs:
                    view = ref()
                    if ivy.exists(view) and view is not x:
                        _update_view(view, base)
            else:
                for ref in x._view_refs:
                    view = ref()
                    if ivy.exists(view):
                        _update_view(view, x)
        else:
            x = ivy.to_ivy(x_native)
        return x
    else:
        return val
TensorFlow does not natively support inplace updates for tf.Tensor
instances, and so in such cases there is no way of actually inplace updating the tf.Tensor
instance x_native
.
However, TensorFlow does natively support inplace updates for tf.Variable
instances.
Therefore, if x_native
is a tf.Variable
, then x_native
is updated inplace with val_native
.
Irrespective of whether the native array is a tf.Tensor
or a tf.Variable
, an inplace update is then also performed on the ivy.Array
instance, if provided in the input.
TensorFlow functions also never return views, so additional logic is added to functionally support multiple variables referencing the same memory (further explained in a later section).
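A small pure-TensorFlow sketch of the distinction relied upon above (illustrative values only):

import tensorflow as tf

v = tf.Variable([1.0, 2.0, 3.0])
v.assign([4.0, 5.0, 6.0])  # tf.Variable supports true inplace assignment
print(v.numpy())           # [4. 5. 6.]

t = tf.constant([1.0, 2.0, 3.0])
# t.assign(...) would fail: tf.Tensor has no assign method and is immutable,
# so a new tensor must be created instead
t = tf.constant([4.0, 5.0, 6.0])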
PyTorch:
def inplace_update(
    x: Union[ivy.Array, torch.Tensor],
    val: Union[ivy.Array, torch.Tensor],
    /,
    *,
    ensure_in_backend: bool = False,
    keep_input_dtype: bool = False,
) -> ivy.Array:
    ivy.utils.assertions.check_inplace_sizes_valid(x, val)
    if ivy.is_array(x) and ivy.is_array(val):
        if keep_input_dtype:
            val = ivy.astype(val, x.dtype)
        (x_native, val_native), _ = ivy.args_to_native(x, val)
        if is_variable(x_native):
            x_native.data = val_native
        else:
            x_native[()] = val_native
        if ivy.is_ivy_array(x):
            x.data = x_native
            _update_torch_views(x)
        else:
            x = ivy.to_ivy(x_native)
        if ensure_in_backend:
            x._data = val_native
        return x
    else:
        return val
PyTorch does natively support inplace updates, and so x_native
is updated inplace with val_native
.
Following this, an inplace update is then also performed on the ivy.Array
instance, if provided in the input.
PyTorch also supports views for most manipulation and indexing operations, as NumPy does, but it lacks that functionality for a few functions such as flip()
.
Additional logic had to be added to support view functionality for those functions (described in a section below).
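For reference, the native PyTorch mechanism relied upon above, a full-tensor indexing assignment, modifies the tensor's existing storage, so native views created from it see the update (values are illustrative):

import torch

x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
view = x.t()  # a native PyTorch view sharing x's storage

x[()] = torch.tensor([[5.0, 6.0], [7.0, 8.0]])  # inplace write into x's storage

print(view)  # tensor([[5., 7.], [6., 8.]]) -- the view reflects the update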
The function ivy.inplace_update()
is also nestable, meaning it can accept ivy.Container
instances in the input.
If an ivy.Container
instance is provided for the argument x
, then along with the arrays at all of the leaves, the container x
is also inplace updated, meaning that a new ivy.Container
instance is not created for the function return.
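A minimal sketch of this nestable behaviour (container keys and values are arbitrary):

import ivy

ivy.set_backend("numpy")

x = ivy.Container(a=ivy.array([1.0, 2.0]), b=ivy.array([3.0, 4.0]))
val = ivy.Container(a=ivy.array([5.0, 6.0]), b=ivy.array([7.0, 8.0]))

ret = ivy.inplace_update(x, val)

# the container itself is updated inplace; no new container is created, as described above
assert ret is x
print(x.a)  # ivy.array([5., 6.])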
out argument
Most functions in Ivy support inplace updates via the inclusion of a keyword-only out
argument.
This enables users to specify the array to which they would like the output of a function to be written.
This could for example be the input array itself, but can also be any other array of choice.
All Ivy functions which return a single array should support inplace updates via the out
argument.
The type hint of the out
argument is Optional[ivy.Array]
.
However, as discussed above, if the function is nestable then ivy.Container
instances are also supported.
ivy.Container
is omitted from the type hint in such cases, as explained in the Function Arguments section.
When the out
argument is unspecified, then the return is simply provided in a newly created ivy.Array
(or ivy.Container
if nestable).
However, when out
is specified, then the return is provided as an inplace update of the out
argument provided.
This can for example be the same as the input to the function, resulting in a simple inplace update of the input.
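For example, a brief sketch of both usage patterns, writing into the input itself or into another pre-existing array (values are arbitrary):

import ivy

ivy.set_backend("numpy")

x = ivy.array([0.0, 1.0, 2.0])
buf = ivy.zeros_like(x)

ivy.tan(x, out=buf)  # result written into buf; no new array is created for the return
ivy.tan(x, out=x)    # inplace update of the input array itself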
In the case of ivy.Array
return types, the out
argument is predominantly handled in handle_out_argument.
As explained in the Function Wrapping section, this wrapping is applied to every function with the @handle_out_argument
decorator dynamically during backend setting.
Primary Functions
In the case of primary functions, handle_out_argument does not handle the backend-specific inplace updates in cases where the backend function being wrapped supports them directly, such as torch.tan and numpy.tan, which both support the out
argument directly.
When implementing backend-specific functions, the attribute support_native_out
should be added to all functions which wrap a function in the backend supporting inplace updates directly.
tf.math.tan and jax.numpy.tan for example do not support inplace updates, and so the support_native_out
attribute should not be added to the tan
implementations.
The implementations of ivy.tan()
for each backend are as follows.
JAX (no support_native_out
attribute):
def tan(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
    return jnp.tan(x)
NumPy (includes support_native_out
attribute):
@_scalar_output_to_0d_array
def tan(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray:
    return np.tan(x, out=out)


tan.support_native_out = True
TensorFlow (no support_native_out
attribute):
def tan(
    x: Union[tf.Tensor, tf.Variable],
    /,
    *,
    out: Optional[Union[tf.Tensor, tf.Variable]] = None,
) -> Union[tf.Tensor, tf.Variable]:
    return tf.tan(x)
PyTorch (includes support_native_out
attribute):
def tan(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    x = _cast_for_unary_op(x)
    return torch.tan(x, out=out)


tan.support_native_out = True
It’s very important to ensure the support_native_out
attribute is not added to backend implementations that do not handle the out
argument, as the presence of this attribute dictates whether the argument should be handled by the backend function or by the wrapper.
This distinction only concerns how the inplace update is applied to the native array, which is operated upon directly by the backend.
If out
is specified in an Ivy function, then an inplace update is always also performed on the ivy.Array
instance itself, which is how out
is provided to the function originally.
The inplace update of this ivy.Array
is always handled by the wrapper.
Alternatively, if out
is an ivy.Container
, then the inplace update is always handled by _wrap_fn in the container wrapping module.
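Conceptually, the array-level wrapper behaviour described above amounts to something like the following simplified sketch for single-array returns (an approximation for illustration, not the actual handle_out_argument source):

import ivy


def handle_out_argument_sketch(fn):
    def wrapped(*args, out=None, **kwargs):
        if out is None:
            return fn(*args, **kwargs)
        if getattr(fn, "support_native_out", False):
            # the backend function writes into the native array of out directly
            return fn(*args, out=out, **kwargs)
        # otherwise compute normally, then copy the result into out,
        # updating the ivy.Array and (where possible) its native array
        ret = fn(*args, **kwargs)
        return ivy.inplace_update(out, ret)

    return wrapped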
Special Case
Take a function which has multiple possible “paths” through the code:
def cholesky(
    x: torch.Tensor, /, *, upper: bool = False, out: Optional[torch.Tensor] = None
) -> torch.Tensor:
    if not upper:
        return torch.linalg.cholesky(x, out=out)
    else:
        ret = torch.transpose(
            torch.linalg.cholesky(
                torch.transpose(x, dim0=len(x.shape) - 1, dim1=len(x.shape) - 2)
            ),
            dim0=len(x.shape) - 1,
            dim1=len(x.shape) - 2,
        )
        if ivy.exists(out):
            return ivy.inplace_update(out, ret)
        return ret


cholesky.support_native_out = True
Here we still have the support_native_out
attribute since we want to take advantage of the native inplace update enabled by torch.linalg.cholesky()
in the first condition.
However, in the else
statement, the last operation is torch.transpose()
which does not support the out
argument, and so the native inplace update can’t be performed by torch here.
This is why we need to call ivy.inplace_update()
explicitly here, to ensure the native inplace update is performed, as well as the ivy.Array
inplace update.
Another case where we need to use ivy.inplace_update()
with a function that has support_native_out
is for the example of the torch
backend implementation of the ivy.remainder()
function:
def remainder(
    x1: Union[float, torch.Tensor],
    x2: Union[float, torch.Tensor],
    /,
    *,
    modulus: bool = True,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    x1, x2 = ivy.promote_types_of_inputs(x1, x2)
    if not modulus:
        res = x1 / x2
        res_floored = torch.where(res >= 0, torch.floor(res), torch.ceil(res))
        diff = res - res_floored
        diff, x2 = ivy.promote_types_of_inputs(diff, x2)
        if ivy.exists(out):
            if out.dtype != x2.dtype:
                return ivy.inplace_update(
                    out, torch.round(torch.mul(diff, x2)).to(out.dtype)
                )
        return torch.round(torch.mul(diff, x2), out=out).to(x1.dtype)
    return torch.remainder(x1, x2, out=out).to(x1.dtype)


remainder.support_native_out = True
Here, even though the torch.round()
function natively supports the out
argument, we need to use ivy.inplace_update()
in the case where the dtype of the out
argument differs from the dtype of the function's result, while still utilizing the native out
argument whenever the dtypes match, so that the native inplace update is used to the greatest possible extent.
Compositional Functions
For compositional functions, the out
argument should always be handled in the compositional implementation, with no wrapping applied at all.
This is for a few reasons:
1. We need to show the out argument in the compositional function signature, as this is the only function implementation in the codebase, and adding an argument that is unused in the implementation could cause some confusion.
2. Generally, inplace updates are performed because memory management is an area of concern for the user. By handling the out argument in the compositional implementation itself, we can maximize the memory efficiency of the function, using inplace updates in as many of the inner Ivy functions as possible.
3. This enables us to make use of backend-specific out argument handling.
The second and third points are the most important points.
We’ll use ivy.cross_entropy()
as an example:
def cross_entropy(
    true: Union[ivy.Array, ivy.NativeArray],
    pred: Union[ivy.Array, ivy.NativeArray],
    /,
    *,
    axis: int = -1,
    epsilon: float = 1e-7,
    reduction: str = "mean",
    out: Optional[ivy.Array] = None,
) -> ivy.Array:
    ivy.utils.assertions.check_elem_in_list(reduction, ["none", "sum", "mean"])
    pred = ivy.clip(pred, epsilon, 1 - epsilon)
    log_pred = ivy.log(pred)
    return _reduce_loss(reduction, log_pred * true, axis, out=out)
By handling the out
argument in the function, we are able to get the benefits outlined above.
Firstly, the return of ivy.sum()
(called within the _reduce_loss helper) is the same shape and type as the return of the entire function, and so we can also write this output to the out
argument inplace.
We can then subsequently overwrite the contents of out
again with the return of the ivy.negative()
function.
This minimizes the number of arrays created during the execution of the function, which is generally the intention when specifying the out
argument.
Additionally, with a PyTorch backend, the ivy.negative()
function defers to the out
argument of torch.negative()
function directly, which is the most efficient inplace update possible, making use of backend-specific optimizations.
If we had instead simply used the wrapper handle_out_argument, then we would not leverage any of these benefits, and instead simply call ivy.inplace_update()
at the very end of the function call.
For some compositional functions, the internal function which generates the final return value does not itself support the out
argument.
For example, ivy.multi_head_attention includes support for arbitrary functions passed in the input, including to_out_fn
which, if specified, is applied to the outputs before returning.
For such functions, the inplace update should just be performed using ivy.inplace_update()
at the end of the function, as sketched below.
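This is only a simplified illustration of the pattern; the function name, the to_out_fn handling and the body are placeholders rather than the real multi_head_attention code:

import ivy


def composite_with_out_fn(x, /, *, to_out_fn=None, out=None):
    # ... main computation composed from inner Ivy functions ...
    ret = ivy.tanh(x)
    if to_out_fn is not None:
        # an arbitrary user-provided function cannot accept an out argument
        ret = to_out_fn(ret)
    if ivy.exists(out):
        # final inplace update performed explicitly at the end of the function
        return ivy.inplace_update(out, ret)
    return ret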
Technically, this could be handled using the handle_out_argument wrapping, but we opt to implement this in the compositional function itself, due to point 1 mentioned above.
Mixed Functions
As explained in the Function Types section, mixed functions can effectively behave as either compositional or primary functions, depending on the backend that is selected. We must add handle_out_argument
to the add_wrappers
key of the mixed_backend_wrappers
attribute so that the decorator gets added to the primary implementation when the backend is set; the linear function provides an example of this.
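As a rough, hypothetical sketch of what such an attribute looks like (the exact key names and wrapper tuple should be checked against the actual linear implementation in the codebase):

def linear(x, weight, /, *, bias=None, out=None):
    # compositional implementation lives here (omitted)
    ...


# hypothetical sketch -- not the verbatim attribute from the codebase
linear.mixed_backend_wrappers = {
    "add_wrappers": ("handle_out_argument",),
}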
copy argument
As well as the out
argument, many also support the copy
argument.
The functions with support for the copy
argument either are in the Array API Standard, which mandates the inclusion of copy
in each case, or they are expected to return views with specific backends (and are hence decorated with the @handle_view
wrapper), with copy
added as a way to prevent views from being created.
The copy
argument dictates whether a new copy should be created, or whether the input array should be updated inplace.
When copy
is not specified explicitly, then an inplace update is performed with the same behaviour as copy=False
.
Setting copy=False
is equivalent to passing out=input_array
.
If only one of copy
or out
is specified, then this specified argument is given priority.
If both are specified, then priority is given to the more general out
argument.
As with the out
argument, the copy
argument is also handled by the wrapper.
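A brief usage sketch, using ivy.reshape() (for which the Array API Standard mandates a copy argument); values are arbitrary:

import ivy

ivy.set_backend("numpy")

x = ivy.array([1.0, 2.0, 3.0, 4.0])

y = ivy.reshape(x, (2, 2), copy=True)  # forces a fresh copy of the data
z = ivy.reshape(x, (2, 2))             # default behaviour, equivalent to copy=False as described above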
Views
Many functions in NumPy and PyTorch return views instead of copies; these functions are mostly manipulation or indexing routines.
Views are arrays which access the same data buffer as another array but view it with different metadata like stride
.
More information about these arrays can be found in NumPy’s documentation.
This essentially means that any inplace update on the original array or any of its views will cause all the other views to be updated as well since they reference the same memory buffer.
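A small NumPy illustration of this shared-buffer behaviour (values are arbitrary):

import numpy as np

base = np.arange(6).reshape(2, 3)
view = base[0]           # a view into base's buffer
transposed = base.T      # another view of the same buffer

base[0, 0] = 99          # inplace update of the base

print(view[0])           # 99 -- the view sees the update
print(transposed[0, 0])  # 99 -- so does every other view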
We want to keep supporting NumPy and PyTorch inplace updates whenever we can, and to superset backend behaviour; however, it is not trivial to replicate this in JAX and TensorFlow. The main reason is that these frameworks do not natively support inplace updates, so even if multiple native arrays reference the same memory buffer, there is no way to update it once for all of them. Therefore, views and their updates must be tracked through Ivy, and extra logic has been added to update views whenever an inplace update happens to any array which is meant to be referencing the same memory. We call views tracked and updated by Ivy functional views, as they work within a functional paradigm.
Which functions return views is mostly dictated by NumPy, since it has the most expansive support for them: any function which returns views in NumPy or PyTorch should be decorated with the @handle_view
wrapper, except get_item()
, which has its own @handle_view_indexing
wrapper.
Every function with this wrapper should also have a copy
argument such that Ivy maintains a way to prevent views from being created if necessary.
What that wrapper does is update a few ivy.Array
attributes which help keep track of views, how they were created, and which arrays should be updated together.
These attributes are then used in the ivy.inplace_update()
to update all the arrays which are meant to be referencing the same memory, at least to NumPy’s standard.
Of course, these are normally only used with a JAX and TensorFlow backend since NumPy and PyTorch natively update their views and Ivy does not need to do any extra handling except for a few functions where PyTorch fails to return views when NumPy does.
The functions currently implemented in the Ivy API where PyTorch fails to return views at the time of writing are ivy.flip()
, ivy.rot90()
, ivy.flipud()
, ivy.fliplr()
.
When one of those functions is used with a PyTorch backend, additional logic has been added to make the returns of those functions behave as views of the original array that made them.
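In terms of intended user-facing behaviour, a snippet like the following should behave the same way regardless of whether the view is a native one (NumPy, most of PyTorch) or a functional view tracked by Ivy (JAX, TensorFlow, and the PyTorch exceptions above); the values are arbitrary:

import ivy

ivy.set_backend("tensorflow")  # a backend with no native view support

x = ivy.array([1.0, 2.0, 3.0])
y = ivy.flip(x)  # tracked by Ivy as a functional view of x

ivy.inplace_update(x, ivy.array([4.0, 5.0, 6.0]))
print(y)  # ivy.array([6., 5., 4.]) -- the functional view is updated alongside x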
Here’s a brief description of the additional attributes added to ivy.Array
and their usage:
- Base (._base): the original array being referenced (the array all views stem from).
- Manipulation stack (._manipulation_stack): a store of the operations that were done on the original to get to the current shape (manipulation or indexing).
- Reference stack (._view_refs): weak references to the arrays that reference the original as a view; only populated for base arrays.
- PyTorch base (._torch_base): keeps track of the functional view (array created from the listed functions above) that made it, otherwise stores the original array.
- PyTorch reference stack (._torch_view_refs): functional views referencing this array in its PyTorch base; only populated for original arrays or functional views.
- PyTorch manipulation cache (._torch_manipulation): a tuple storing the array or view and the function which made the functional view; only populated for functional views.
Note
Parts of an array's metadata, like stride
, pertain to the low-level memory layout of arrays, while views in ivy
operate at a higher level of abstraction.
As a result, ivy.strides()
isn’t guaranteed to produce an output reflective of the underlying memory layout if the ivy.Array
passed in is a view (or in other words has a _base
).
Here’s a brief description of how the @handle_view
wrapper populates these attributes:
When an array is made using a function decorated by this wrapper, its base becomes the array that made it or, if the array that made it is itself a view, that view's base.
The view is then added to the reference stack of the base array (weakly), and the operation that created the array is added to the manipulation stack of the new array.
The way the PyTorch-specific attributes are updated should be adequately explained above.
Here’s a brief description of what happens during an inplace operation with a JAX and TensorFlow backend:
If the base is inplace updated, then Ivy goes through all the arrays in its reference stack, passes the new data through each of their manipulation stacks, and inplace updates every view accordingly.
If a view gets inplace updated, an index array is created of the shape of the base array, which then is passed through the manipulation stack of the updated array.
The updated array and the index array are then flattened and they then update the original array by performing a scatter update on a flattened version of the original array, which then gets reshaped into the correct shape.
Then all the views stemming from the original are updated as described in the first point.
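To illustrate the index-array trick in isolation, here is a small jax.numpy sketch (the base shape and the transpose manipulation are arbitrary choices made for the example):

import jax.numpy as jnp


def view_fn(a):
    # the manipulation that originally created the view (a transpose, for this example)
    return jnp.transpose(a)


base = jnp.arange(6.0).reshape(2, 3)
new_view_vals = jnp.full((3, 2), 9.0)  # the values the view was inplace updated with

# build an index array of the base's shape and pass it through the same manipulation
base_idx = jnp.arange(base.size).reshape(base.shape)
view_idx = view_fn(base_idx)

# scatter the flattened view values into the flattened base, then reshape back
base_flat = base.flatten().at[view_idx.flatten()].set(new_view_vals.flatten())
new_base = base_flat.reshape(base.shape)
print(new_base)  # the base now reflects the update made through the view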
Here’s a brief description of what happens during an inplace operation with a PyTorch backend:
The array being updated checks whether it has a populated reference stack; if it does, it inplace updates each functional view in the stack with the output of the stored function called on the array that made it. It then checks whether each functional view has a reference stack of its own, and continues recursively until all reference stacks are exhausted.
If the reference stack is empty or exhausted, it checks whether it has a manipulation stack. If that is populated, it performs the reverse functional operation with itself as the input and inplace updates the view that made it (reversing the operation that made it). If the manipulation stack is empty or already exhausted, it moves to the array's PyTorch base and repeats this process recursively until everything is exhausted and the base is None.
All other views are expected to be updated automatically through PyTorch’s native view handling.
Round Up
This should have hopefully given you a good feel for inplace updates, and how these are handled in Ivy.
If you have any questions, please feel free to reach out on Discord in the inplace updates thread!