Gradients#

Overview#

Gradients are a crucial aspect of all modern deep learning workflows. Different frameworks provide different APIs for gradient computation, so a few design decisions had to be made while building a unified gradients API in Ivy. Ivy provides a number of functions for gradient computation, but we’ll mainly focus on the most commonly used and most general one, ivy.execute_with_gradients(). The other gradient functions, such as ivy.value_and_grad() and ivy.grad(), provide a subset of the functionality that ivy.execute_with_gradients() offers.

Example Usage of the Gradient API#

The ivy.execute_with_gradients() function signature#

Following is the pseudo function signature for the ivy.execute_with_gradients() function,

def execute_with_gradients(
    func: Callable,
    xs,               # any arbitrary nest of arrays (the inputs to func)
    xs_grad_idxs,     # indices of the input arrays to compute gradients with respect to
    ret_grad_idxs,    # indices of the output arrays for which gradients are returned
):
    return func_ret, grads

The func in the input can be any user-defined function that returns a single scalar or any arbitrary nest of scalars. By scalars, we are referring to zero-dimensional arrays.

So, for example, the following are some valid outputs of the func,

ivy.array(12.)

# OR

ivy.Container(
    a=ivy.array(12.),
    b=ivy.Container(
        c=ivy.array(15.),
        d=ivy.array(32.)
    )
)

# OR

[ivy.array(25.), {'x': (ivy.array(21.), ivy.array(11.))}, (ivy.array(9.),)]

xs can be any arbitrary nest of arrays and refers to the inputs passed to the func, so we suggest designing your func based on what inputs you pass in xs. The arrays in xs can have any number of dimensions; the only constraint is on the output of the func, as explained above.

The xs_grad_idxs and ret_grad_idxs arguments provide finer control over which arrays gradients are computed for. xs_grad_idxs accepts the indices of the input arrays to compute gradients with respect to, and ret_grad_idxs accepts the indices of the output arrays for which gradients should be computed.

An example using ivy.execute_with_gradients()#

def func(xs):
    # xs is the nest of inputs passed to ivy.execute_with_gradients() below
    return ivy.mean(xs[0] + xs[1].b)

x = ivy.array([1., 2., 3.])
x = ivy.Container(a=x, b=x)
y = ivy.array([4., 5., 6.])
y = ivy.Container(b=y, c=x)
xs = [x, y]

# compute gradients of the "a" key of the output, with respect to xs[0] only
ret, grads = ivy.execute_with_gradients(
    func,
    xs,
    xs_grad_idxs=[[0]],
    ret_grad_idxs=[["a"]]
)

Custom Gradient Functions#

There are various scenarios where users may want to define custom gradient computation rules for their functions, for example to improve numerical stability or to smooth or clip the computed gradients. Ivy provides the ivy.bind_custom_gradient_function() function to allow users to bind custom gradient computation logic to their functions.

Following is an example usage of ivy.bind_custom_gradient_function(),

import ivy

ivy.set_backend("torch")
x = ivy.array(50.0)
inter_func = lambda x: ivy.log1p(ivy.exp(x))

# args -> ((xs, ret), upstream)
def custom_grad_fn(*args):
    # custom gradient: scale the upstream gradient by (1 - 10 / (1 + xs))
    args1 = (1 - 10 / (1 + args[0][0]))
    return (args[1] * args1)

inter_func = ivy.bind_custom_gradient_function(
    inter_func, custom_grad_fn
)
func = lambda x: ivy.sum(inter_func(x) ** 2)

ret, grad = ivy.execute_with_gradients(func, x)

The custom_grad_fn here accepts *args which has the structure ((xs, ret), upstream) where,

  • xs is the input similar to the one accepted in ivy.execute_with_gradients()

  • ret is the output of the forward pass of the inter_func()

  • upstream refers to the previously computed gradients while back-propagating

Design of the Gradient API#

Our policy on gradients#

  • The gradient API is fully-functional in ivy.

  • There is no explicit variable class or any public-facing function for adding gradient support to an ivy.Array.

  • The gradient functions in ivy implicitly convert all arrays to support gradient computation before computing gradients and detach all arrays after computing gradients (see the short sketch after this list).

  • We don’t retain any computations previously tracked in arrays by frameworks such as torch.

  • This makes our gradient API unambiguous, flexible, and easy to debug.

  • Any framework-specific tracking of computations or variable classes should be handled in the corresponding frontends.
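
To make this policy concrete, here is a minimal sketch (the values and the lambda are purely illustrative) showing that a plain ivy.Array can be passed in directly, with gradient support enabled and removed again internally,

import ivy

ivy.set_backend("torch")

# a plain ivy.Array; no variable class or requires_grad-style flag is needed
x = ivy.array([1., 2., 3.])

# gradient support is enabled implicitly, gradients are computed,
# and everything is detached again before being returned
ret, grads = ivy.execute_with_gradients(lambda xs: ivy.mean(xs ** 2), x)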

Gradient APIs of frameworks#
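
For context, the sketch below (purely illustrative, and not Ivy code) shows the native gradient APIs that the Ivy backends build upon: torch.autograd via requires_grad and backward() in PyTorch, tf.GradientTape in TensorFlow, and jax.grad in JAX,

import torch
import tensorflow as tf
import jax
import jax.numpy as jnp

# PyTorch: gradients are tracked on tensors flagged with requires_grad
t = torch.tensor([1., 2., 3.], requires_grad=True)
torch.mean(t ** 2).backward()
torch_grad = t.grad

# TensorFlow: computations are recorded on a GradientTape
v = tf.Variable([1., 2., 3.])
with tf.GradientTape() as tape:
    loss = tf.reduce_mean(v ** 2)
tf_grad = tape.gradient(loss, v)

# JAX: gradients are computed via function transformations
jax_grad = jax.grad(lambda a: jnp.mean(a ** 2))(jnp.array([1., 2., 3.]))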

General Structure of Backend-specific implementations#

Here’s a high-level description of the steps followed in the backend-specific implementations of ivy.execute_with_gradients() (a simplified sketch follows the list):

  1. Get duplicate index chains: collect the indices of arrays that share the same id

  2. Convert integer arrays to floats: done only for ease of use; it’s not recommended to pass integer arrays to gradient functions

  3. Get relevant inputs: based on the xs_grad_idxs, we collect the relevant inputs for gradient computation

  4. Enable gradient support: we implicitly make use of framework-specific APIs to enable gradients in arrays. Ivy doesn’t need an explicit variable class as the gradient API is fully functional

  5. Compute results: we do the forward pass by passing the input as it is to the function

  6. Get relevant outputs: based on the ret_grad_idxs, we collect the relevant outputs for gradient computation

  7. Compute gradients: we make use of the framework-specific APIs to compute the gradients of the relevant outputs with respect to the relevant inputs

  8. Handle duplicates: we explicitly handle duplicate instances using the index chains captured above, as different frameworks treat duplicates differently

  9. Post-process and detach: finally, all computed gradients are updated to deal with NaN and Inf values, and the input arrays are detached (i.e. gradient propagation is stopped)
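
To make these steps concrete, below is a heavily simplified, illustrative sketch of what a torch backend implementation could look like. This is not Ivy’s actual backend code: it assumes xs is a flat list of torch tensors, xs_grad_idxs is a flat list of integer positions, and func returns a single scalar tensor, and it skips duplicate handling and ret_grad_idxs,

import torch

def simplified_execute_with_gradients(func, xs, xs_grad_idxs=None):
    # 2. convert integer tensors to floats
    xs = [x.float() if not x.is_floating_point() else x for x in xs]

    # 3. + 4. pick the relevant inputs and enable gradient support on them
    idxs = xs_grad_idxs if xs_grad_idxs is not None else list(range(len(xs)))
    xs = [x.detach().requires_grad_(True) if i in idxs else x
          for i, x in enumerate(xs)]

    # 5. forward pass, passing the inputs to the function as they are
    ret = func(xs)

    # 7. compute gradients of the scalar output w.r.t. the relevant inputs
    grads = torch.autograd.grad(ret, [xs[i] for i in idxs], allow_unused=True)

    # 9. post-process: replace NaN / Inf gradients with 0 and detach everything
    grads = [
        torch.zeros_like(xs[i]) if g is None
        else torch.nan_to_num(g.detach(), nan=0.0, posinf=0.0, neginf=0.0)
        for i, g in zip(idxs, grads)
    ]
    return ret.detach(), grads

In the real backends, the duplicate index chains from step 1 and the nested xs_grad_idxs / ret_grad_idxs handling replace the simplified flat indexing used here.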

Framework-specific Considerations#

  • JAX treats duplicate arrays as distinct while computing gradients, so we need additional logic to replicate gradients computed with respect to one array across all of its duplicates.

  • Gradients computed for functions with undefined results are inconsistent across backends (NaN, Inf, 0). We handle all these inconsistencies by returning 0 for all backends. So if you’re debugging gradients and find a 0, there’s a possibility that it was a NaN or an Inf before Ivy’s post-processing step.
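
As a small illustration of the second point (the backend and values here are just for demonstration), the derivative of sqrt is unbounded at 0, so the backend would natively produce an Inf or NaN gradient, which Ivy’s post-processing replaces with 0,

import ivy

ivy.set_backend("jax")

x = ivy.array(0.0)

# d/dx sqrt(x) is unbounded at x = 0; the native gradient would be Inf or NaN,
# and Ivy's post-processing returns 0 instead
ret, grad = ivy.execute_with_gradients(lambda x: ivy.sqrt(x), x)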

Round Up

This should have hopefully given you a good feel for how the gradient API is implemented in Ivy.

If you have any questions, please feel free to reach out on Discord in the gradients thread!