Trace code

Turn your Ivy code into an efficient, fully functional graph, removing wrappers and unused parts of the code.

⚠️ If you are running this notebook in Colab, you will have to install Ivy and some dependencies manually. You can do so by running the cell below ⬇️

If you want to run the notebook locally but don’t have Ivy installed just yet, you can check out the Get Started section of the docs.

[ ]:
!pip install ivy

Let’s begin with an implementation of the normalize function using ivy’s Functional API:

[4]:
import ivy

def normalize(x):
    mean = ivy.mean(x)
    std = ivy.std(x)
    return ivy.divide(ivy.subtract(x, mean), std)
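
To make explicit what normalize computes, here is a minimal pure-Python reference. It assumes that ivy.std defaults to the population standard deviation (no Bessel correction); the output printed later in this notebook is consistent with that assumption, since its values have a mean of about 0 and a mean square of about 1. The name normalize_reference is hypothetical and not part of Ivy.

```python
import statistics


def normalize_reference(values):
    """Shift values to zero mean and scale them to unit standard deviation."""
    mean = statistics.fmean(values)
    # Assumption: population standard deviation, matching ivy.std's default.
    std = statistics.pstdev(values)
    return [(v - mean) / std for v in values]


result = normalize_reference([1.0, 2.0, 3.0, 4.0])
# The normalized values should have a mean close to 0 and a std close to 1.
print(statistics.fmean(result), statistics.pstdev(result))
```

This is only a readability aid for plain Python lists; the ivy version above is the one that works on framework arrays.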

For the purpose of illustration, we will use jax as our backend framework:

[7]:
# set ivy's backend to jax
ivy.set_backend("jax")

# Import jax
import jax

# create a random jax array for testing
key = jax.random.PRNGKey(42)
x = jax.random.uniform(key, shape=(10,))
normalize(x)
[7]:
ivy.array([ 0.58569533, -0.69083852, -1.20325196,  1.5490098 ,  1.37264228,
       -1.20946217, -0.60102183, -0.96937162,  0.53789282,  0.62870705])

When normalize is called this way, all of ivy’s function wrapping is part of its call stack, which adds runtime overhead. In general, ivy.trace_graph strips a function down to its constituent functions in the functional API of the target framework. The code can be traced like so:

[8]:
ivy.set_backend("jax")
traced = ivy.trace_graph(normalize)  # traces to jax, due to ivy.set_backend

The traced function can be executed in exactly the same manner as the non-traced function:

[9]:
traced(x)
[9]:
Array([ 0.5856953 , -0.6908385 , -1.203252  ,  1.5490098 ,  1.3726423 ,
       -1.2094622 , -0.6010218 , -0.9693716 ,  0.5378928 ,  0.62870705],      dtype=float32)

Graph tracing is lazy: it runs during the very first call of a traced function, which makes that call slower. Now that those first calls have been made, we can compare the runtime of each function:

[10]:
%%timeit
normalize(x)
138 ms ± 3.57 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
[12]:
%%timeit
traced(x)
122 µs ± 2.02 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
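
Outside a notebook the %%timeit magic is not available, but a similar comparison can be sketched with the standard library's timeit module. The helper name seconds_per_call is hypothetical. It makes one warm-up call first so that any lazy first-call work, such as tracing, is excluded from the measurement, and it reports the best repeat rather than the mean ± std. dev. that %%timeit prints.

```python
import timeit


def seconds_per_call(fn, *args, number=1000, repeat=5):
    """Return the best-of-`repeat` average wall time of one fn(*args) call."""
    fn(*args)  # warm-up: lets any lazy first-call work finish before timing
    timer = timeit.Timer(lambda: fn(*args))
    return min(timer.repeat(repeat=repeat, number=number)) / number


# Stand-in workload so the snippet runs without Ivy installed; with Ivy,
# pass normalize or traced together with x instead.
elapsed = seconds_per_call(sum, range(1000))
print(f"{elapsed * 1e6:.2f} µs per call")
```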

As expected, normalize is slower because it includes all of ivy’s wrapping overhead. In this run it takes 138 ms per call against 122 µs for traced, which has no wrapping overhead: a difference of roughly three orders of magnitude.

Fun Fact: You can use the graph tracer with pretty much any code written in one of the ML frameworks Ivy supports, such as PyTorch, TensorFlow, JAX, or NumPy. The tracer extracts an efficient computation graph stitched together in the set backend framework, which speeds the code up by removing computations that don’t contribute to the output.
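
Removing computations that do not contribute to the output can be pictured as a backward walk over the recorded operations. The sketch below is a generic illustration under assumed names (prune, and a list of (name, inputs) tuples standing in for a graph); it does not reflect Ivy's internal data structures.

```python
def prune(ops, output):
    """Keep only the ops whose results are needed to compute `output`.

    ops is a list of (name, input_names) tuples in execution order.
    """
    needed = {output}
    kept = []
    # Walk backwards so each op is seen after everything that consumes it.
    for name, inputs in reversed(ops):
        if name in needed:
            kept.append((name, inputs))
            needed.update(inputs)
    kept.reverse()
    return kept


ops = [
    ("mean", ["x"]),
    ("std", ["x"]),
    ("debug_max", ["x"]),  # computed but never used by the output
    ("centered", ["x", "mean"]),
    ("result", ["centered", "std"]),
]
print([name for name, _ in prune(ops, "result")])
# prints ['mean', 'std', 'centered', 'result']
```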

Round Up

That’s it, you can now trace ivy code for more efficient inference! However, there are several other important topics to master before you’re ready to play with ML code like a pro 🥷.