execute_with_gradients
- ivy.execute_with_gradients(func, xs, /, *, retain_grads=False, xs_grad_idxs=((0,),), ret_grad_idxs=((0,),))
Call function func with the input variables xs, and return the function result func_ret together with the gradients of each output variable with respect to each input variable.
- Parameters:
func – Function whose output gradients are computed with respect to the input xs.
xs (Union[Array, NativeArray]) – Variables with respect to which the function gradients are computed. This can be a single array or an arbitrary nest of arrays.
retain_grads (bool, default: False) – Whether to retain the gradients of the returned values.
xs_grad_idxs (Sequence[Sequence[Union[str, int]]], default: ((0,),)) – Indices of the input arrays to compute gradients with respect to. If None, gradients are returned with respect to all input arrays. If xs is an ivy.Array or ivy.Container, the default value is None; otherwise it is [[0]].
ret_grad_idxs (Sequence[Sequence[Union[str, int]]], default: ((0,),)) – Indices of the returned arrays for which to return computed gradients. If None, gradients are returned for all returned arrays. If the object returned by func is an ivy.Array or ivy.Container, the default value is None; otherwise it is [[0]].
- Return type:
- Returns:
ret – a tuple of the function result func_ret and the gradients of each output variable with respect to each input variable.
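The xs_grad_idxs and ret_grad_idxs arguments are sequences of index paths. The following is a minimal pure-Python sketch of how such paths could select sub-nests, assuming each inner sequence is a chain of dictionary keys (str) or positions (int) applied from the outermost level inward. follow_path and select_by_idxs are hypothetical helpers written for illustration, not part of Ivy.

```python
from typing import Any, List, Optional, Sequence, Union

Index = Union[str, int]


def follow_path(nest: Any, path: Sequence[Index]) -> Any:
    # Hypothetical helper: apply each key/position in turn, outermost first.
    for key in path:
        nest = nest[key]
    return nest


def select_by_idxs(nest: Any, idxs: Optional[Sequence[Sequence[Index]]]) -> List[Any]:
    # Hypothetical helper: None stands for "everything", so the whole nest
    # is returned; otherwise each index path picks out one sub-nest.
    if idxs is None:
        return [nest]
    return [follow_path(nest, path) for path in idxs]


# A nested input: a list whose second entry is a dict of arrays (plain lists here).
xs = [[1.0, 2.0], {"w": [3.0], "b": [4.0]}]
print(select_by_idxs(xs, ((0,),)))               # [[1.0, 2.0]]
print(select_by_idxs(xs, ((1, "w"), (1, "b"))))  # [[3.0], [4.0]]
```

Under this reading, the signature default ((0,),) selects only the first top-level entry of a nested input.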
Examples
With ivy.Array input:
>>> x = ivy.array([[1, 4, 6], [2, 6, 9]])
>>> func = lambda x: ivy.mean(ivy.square(x))
>>> func_ret = ivy.execute_with_gradients(func, x, retain_grads=True)
>>> print(func_ret)
(ivy.array(29.), ivy.array([[0.33333334, 1.33333337, 2.        ],
       [0.66666669, 2.        , 3.        ]]))
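To make the returned pair concrete, here is a minimal framework-free sketch that reproduces the ivy.Array example above with plain Python lists. It swaps in central finite differences for the backend automatic differentiation that Ivy delegates to, and the names execute_with_numeric_gradients and mean_of_squares are hypothetical.

```python
from typing import Callable, List, Tuple

Matrix = List[List[float]]


def execute_with_numeric_gradients(
    func: Callable[[Matrix], float], xs: Matrix, eps: float = 1e-6
) -> Tuple[float, Matrix]:
    """Return (func(xs), d func / d xs) via central finite differences.

    Hypothetical stand-in for illustration: it perturbs each input element
    instead of using automatic differentiation.
    """
    func_ret = func(xs)
    grads: Matrix = []
    for i, row in enumerate(xs):
        grad_row = []
        for j in range(len(row)):
            # Copy the input twice and nudge element (i, j) up and down.
            plus = [list(r) for r in xs]
            minus = [list(r) for r in xs]
            plus[i][j] += eps
            minus[i][j] -= eps
            grad_row.append((func(plus) - func(minus)) / (2 * eps))
        grads.append(grad_row)
    return func_ret, grads


def mean_of_squares(x: Matrix) -> float:
    # Plain-Python equivalent of ivy.mean(ivy.square(x)).
    flat = [v for row in x for v in row]
    return sum(v * v for v in flat) / len(flat)


x = [[1.0, 4.0, 6.0], [2.0, 6.0, 9.0]]
func_ret, grads = execute_with_numeric_gradients(mean_of_squares, x)
print(func_ret)  # 29.0
```

Each gradient entry is approximately 2 * x / 6 = x / 3, matching the array returned in the example above.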
With ivy.Container input:
>>> x = ivy.Container(a = ivy.array([1, 4, 6]),
...                   b = ivy.array([2, 6, 9]))
>>> func = lambda x: ivy.mean(ivy.square(x))
>>> func_ret = ivy.execute_with_gradients(func, x, retain_grads=True)
>>> print(func_ret)
({
    a: ivy.array(17.666666),
    b: ivy.array(40.333332)
}, {
    a: {
        a: ivy.array([0.66666669, 2.66666675, 4.]),
        b: ivy.array([0., 0., 0.])
    },
    b: {
        a: ivy.array([0., 0., 0.]),
        b: ivy.array([1.33333337, 4., 6.])
    }
})
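In the Container example the gradients form a nest of nests: for each output leaf there is one gradient per input leaf, and the cross entries are zero because each leaf's mean depends only on its own array. The sketch below reproduces that structure, assuming a Container behaves like a dict of 1-D arrays (plain lists here) and again using finite differences in place of autodiff; leafwise_mean_of_squares and nested_numeric_grads are hypothetical helpers.

```python
from typing import Callable, Dict, List

Vec = List[float]
Nest = Dict[str, Vec]


def leafwise_mean_of_squares(xs: Nest) -> Dict[str, float]:
    # Assumed Container semantics: the function is applied to each leaf separately.
    return {k: sum(v * v for v in vals) / len(vals) for k, vals in xs.items()}


def nested_numeric_grads(
    func: Callable[[Nest], Dict[str, float]], xs: Nest, eps: float = 1e-6
) -> Dict[str, Nest]:
    """grads[out_key][in_key][j] ~ d func(xs)[out_key] / d xs[in_key][j]."""
    grads: Dict[str, Nest] = {}
    for out_key in func(xs):
        grads[out_key] = {}
        for in_key, vals in xs.items():
            row = []
            for j in range(len(vals)):
                # Perturb one element of one input leaf, keep the rest fixed.
                plus = {k: list(v) for k, v in xs.items()}
                minus = {k: list(v) for k, v in xs.items()}
                plus[in_key][j] += eps
                minus[in_key][j] -= eps
                diff = func(plus)[out_key] - func(minus)[out_key]
                row.append(diff / (2 * eps))
            grads[out_key][in_key] = row
    return grads


xs = {"a": [1.0, 4.0, 6.0], "b": [2.0, 6.0, 9.0]}
grads = nested_numeric_grads(leafwise_mean_of_squares, xs)
print(grads["a"]["b"])  # [0.0, 0.0, 0.0]: output a does not depend on input b
```

The diagonal entries grads["a"]["a"] and grads["b"]["b"] are approximately 2 * x / 3, matching the non-zero arrays in the example output.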