Gradients#
Overview#
Gradients are a crucial aspect of all modern deep learning workflows.
Different frameworks provide different APIs for gradient computation, so a few design considerations had to be made while building a unified gradient API in Ivy.
Ivy provides a number of functions for gradient computation, but we'll mainly focus on the most commonly used and the most general one, ivy.execute_with_gradients().
This is because the other gradient functions, such as ivy.value_and_grad() and ivy.grad(), can be considered as providing a subset of the functionality that ivy.execute_with_gradients() provides.
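For instance, here is a minimal sketch of how these two simpler functions are typically used (assuming the torch backend; the exact return structures may vary by version):

import ivy

ivy.set_backend("torch")

x = ivy.array([1., 2., 3.])

# ivy.grad(func) returns a new function that computes the gradient
# of func's scalar output with respect to the input
grad_fn = ivy.grad(lambda x: ivy.mean(x ** 2))
grads = grad_fn(x)

# ivy.value_and_grad(func) additionally returns the forward-pass output
value, grads = ivy.value_and_grad(lambda x: ivy.mean(x ** 2))(x)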
Example Usage of the Gradient API#
The ivy.execute_with_gradients() function signature#
Following is the pseudo function signature of the ivy.execute_with_gradients() function,
def execute_with_gradients(
    func: Callable,
    xs: Any arbitrary nest,
    xs_grad_idxs: Input indices,
    ret_grad_idxs: Output indices,
):
    return func_ret, grads
The func in the input can be any user-defined function that returns a single scalar or any arbitrary nest of scalars.
By scalars, we are referring to zero-dimensional arrays.
So, for example, the following are some valid outputs of func,
ivy.array(12.)

# OR

ivy.Container(
    a=ivy.array(12.),
    b=ivy.Container(
        c=ivy.array(15.),
        d=ivy.array(32.)
    )
)

# OR

[ivy.array(25.), {'x': (ivy.array(21.), ivy.array(11.))}, (ivy.array(9.),)]
xs can be any arbitrary nest of arrays and refers to the inputs passed to func, so we suggest designing your func based on the inputs you pass in xs.
The arrays in xs can have any arbitrary number of dimensions; the only constraint is on the output of func, as explained above.
The xs_grad_idxs and ret_grad_idxs arguments provide finer control over which arrays gradients are computed with.
xs_grad_idxs accepts the indices of the input arrays to compute gradients for, and ret_grad_idxs accepts the indices of the output arrays to compute gradients with respect to.
An example using ivy.execute_with_gradients()#
def func(xs):
    return ivy.mean(xs[0] + xs[1].b)

x = ivy.array([1., 2., 3.])
x = ivy.Container(a=x, b=x)
y = ivy.array([4., 5., 6.])
y = ivy.Container(b=y, c=x)
xs = [x, y]

ret, grads = ivy.execute_with_gradients(
    func,
    xs,
    xs_grad_idxs=[[0]],
    ret_grad_idxs=[["a"]]
)
Here, xs_grad_idxs=[[0]] computes gradients only with respect to the first input xs[0], and ret_grad_idxs=[["a"]] computes gradients only for the output nested under the key "a".
Custom Gradient Functions#
There are various scenarios where users may want to define custom gradient computation rules for their functions.
Common examples include numerical stability, smoothing, and clipping of the computed gradients.
Ivy provides the ivy.bind_custom_gradient_function() function to allow users to bind custom gradient computation logic to their functions.
Following is an example of the usage of ivy.bind_custom_gradient_function(),
import ivy

ivy.set_backend("torch")

x = ivy.array(50.0)
inter_func = lambda x: ivy.log1p(ivy.exp(x))

# args -> ((xs, ret), upstream)
def custom_grad_fn(*args):
    args1 = (1 - 10 / (1 + args[0][0]))
    return (args[1] * args1)

inter_func = ivy.bind_custom_gradient_function(
    inter_func, custom_grad_fn
)
func = lambda x: ivy.sum(inter_func(x) ** 2)

ret, grad = ivy.execute_with_gradients(func, x)
The custom_grad_fn here accepts *args, which has the structure ((xs, ret), upstream), where:

- xs is the input, similar to the one accepted by ivy.execute_with_gradients()
- ret is the output of the forward pass of inter_func()
- upstream refers to the previously computed gradients while back-propagating

In the example above, the upstream gradient reaching inter_func() is 2 * inter_func(x), since func squares and sums the output of inter_func().
Design of the Gradient API#
Our policy on gradients#
The gradient API is fully functional in ivy.
There is no explicit variable class or any public-facing function for adding gradient support to an ivy.Array.
The gradient functions in ivy implicitly convert all arrays to support gradient computation before computing gradients, and detach all arrays after computing gradients.
We don't retain any computations previously tracked in arrays by frameworks such as torch.
This makes our gradient API unambiguous, flexible, and easy to debug.
Any framework-specific tracking of computations or variable classes should be handled in the corresponding frontends.
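To illustrate this (a minimal sketch, assuming the torch backend), nothing returned by the gradient functions carries any autograd state:

import ivy

ivy.set_backend("torch")

x = ivy.array([1., 2., 3.])
ret, grads = ivy.execute_with_gradients(lambda x: ivy.mean(x ** 2), x)

# both the returned value and the gradients are plain, detached arrays;
# with the torch backend, no requires_grad flag or autograd graph remains
assert not ivy.to_native(ret).requires_grad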
Gradient APIs of frameworks#
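For reference, the commonly used gradient APIs in some of the backend frameworks are:

- JAX: jax.grad, jax.value_and_grad, jax.jacfwd, jax.jacrev
- PyTorch: torch.autograd.grad, torch.autograd.backward
- TensorFlow: tf.GradientTape, tf.gradients (graph mode only)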
General Structure of Backend-specific implementations#
Here's a high-level description of the steps followed in the backend-specific implementations of ivy.execute_with_gradients():
1. Get duplicate index chains: indices of arrays that share the same id
2. Convert integer arrays to floats: only for ease of use; it's not recommended to pass integer arrays to gradient functions
3. Get relevant inputs: based on the xs_grad_idxs, we collect the relevant inputs for gradient computation
4. Enable gradient support: we implicitly make use of framework-specific APIs to enable gradients in arrays; ivy doesn't need an explicit variable class as the gradient API is fully functional
5. Compute results: we do the forward pass by passing the input as it is to the function
6. Get relevant outputs: based on the ret_grad_idxs, we collect the relevant outputs for gradient computation
7. Compute gradients: we make use of the framework-specific APIs to compute the gradients of the relevant outputs with respect to the relevant inputs
8. Handle duplicates: we explicitly handle duplicate instances using the index chains captured above, as different frameworks treat duplicates differently
9. Post-process and detach: finally, all computed gradients are updated to deal with NaN and inf values, and the input arrays are detached (i.e. gradient propagation is stopped)
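To make these steps concrete, here is a much-simplified sketch of how they could look with the torch backend, for a single input array and a scalar output. This is illustrative only, not Ivy's actual backend source; the function name and the reduced signature are hypothetical:

import torch

def execute_with_gradients_torch(func, x: torch.Tensor):
    # step 2: convert integer inputs to floats
    if not torch.is_floating_point(x):
        x = x.float()
    # step 4: enable gradient support on a detached copy of the input
    x = x.detach().requires_grad_()
    # step 5: forward pass
    ret = func(x)
    # step 7: compute the gradient of the output w.r.t. the input
    grad = torch.autograd.grad(ret, x)[0]
    # step 9: post-process (map nan/inf to 0) and detach
    grad = torch.nan_to_num(grad, nan=0.0, posinf=0.0, neginf=0.0).detach()
    return ret.detach(), grad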
Framework-specific Considerations#
JAX treats duplicate arrays as distinct while computing gradients, so we need additional logic to replicate the gradients computed w.r.t. one array over all of its duplicates.
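As a quick illustration in pure JAX, the same array passed in two argument slots gets two separate gradients:

import jax
import jax.numpy as jnp

x = jnp.array([1., 2., 3.])

# jax computes a distinct gradient for each argument slot,
# even though both slots hold the same array
g_a, g_b = jax.grad(lambda a, b: jnp.mean(a * b), argnums=(0, 1))(x, x)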
Gradients computed for functions with undefined results are inconsistent across backends (NaN, Inf, 0). We handle all these inconsistencies by returning 0 for all backends. So if you're debugging gradients and find a 0, there's a possibility that it was a NaN or an Inf before this post-processing step.
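For example (a sketch assuming the torch backend), the derivative of ivy.sqrt() at zero, 1 / (2 * sqrt(0)), is unbounded; torch natively yields Inf here, which ivy maps to 0:

import ivy

ivy.set_backend("torch")

# the native gradient at x = 0 is inf; ivy post-processes it to 0
ret, grad = ivy.execute_with_gradients(lambda x: ivy.sqrt(x), ivy.array(0.))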
Round Up#
This should have hopefully given you a good feel for how the gradient API is implemented in Ivy.
If you have any questions, please feel free to reach out on Discord in the gradients thread!