Transpile code
Convert a torch function to jax with just one line of code.
You can install the dependencies required for this notebook by running the cell below ⬇️, or check out the Get Started section of the docs to find out more about installing ivy.
[ ]:
!pip install ivy
!pip install torch
!pip install jax jaxlib
Building on what we learnt in the previous two notebooks, Unify and Trace, converting directly from torch to jax takes two steps: first unify the function to ivy code, then trace that code with the jax backend.
[ ]:
import ivy
import torch
ivy.set_backend("jax")
def normalize(x):
    mean = torch.mean(x)
    std = torch.std(x)
    return torch.div(torch.sub(x, mean), std)
# convert the function to Ivy code
ivy_normalize = ivy.unify(normalize)
# trace the Ivy code into jax functions
jax_normalize = ivy.trace_graph(ivy_normalize)
normalize is now traced to jax, ready to be integrated into your wider jax project.
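For comparison, here is a rough sketch of what a hand-written jax version of the same function might look like. It is not the code that ivy generates. The name normalize_reference and the example input are purely illustrative. The one detail worth noting is that torch.std applies Bessel's correction by default while jnp.std does not, so the sketch passes ddof=1 to match.

```python
import jax
import jax.numpy as jnp


def normalize_reference(x):
    # Hypothetical hand-written equivalent of the torch function above.
    mean = jnp.mean(x)
    # torch.std divides by n - 1 by default; jnp.std divides by n unless ddof=1.
    std = jnp.std(x, ddof=1)
    return (x - mean) / std


# Illustrative input; any 1-D float array would do.
x = jax.random.uniform(jax.random.PRNGKey(0), shape=(10,))
y = normalize_reference(x)
print(y)
```

A function like this can serve as a baseline when checking that a traced or transpiled function produces the values you expect.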
This workflow is common, so to avoid repeated calls to ivy.unify followed by ivy.trace_graph, there is a convenience function, ivy.transpile, which acts as shorthand for this pair of calls:
[4]:
jax_normalize = ivy.transpile(normalize, source="torch", to="jax")
Again, normalize is now a jax function, ready to be integrated into your jax project.
[6]:
import jax
key = jax.random.PRNGKey(42)
jax.config.update('jax_enable_x64', True)
x = jax.random.uniform(key, shape=(10,))
print(jax_normalize(x))
[-0.93968587 0.26075466 -0.22723222 -1.06276492 -0.47426987 1.72835908
1.71737559 -0.50411096 -0.65419174 0.15576624]
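A quick way to sanity-check output like this is to confirm that it has zero mean and unit standard deviation. The sketch below does so with the standard library only, using the values copied from the printed output. The variable names are illustrative, and the check assumes the n - 1 divisor is the one in play.

```python
import statistics

# Values copied from the printed output of jax_normalize(x).
printed = [
    -0.93968587, 0.26075466, -0.22723222, -1.06276492, -0.47426987,
    1.72835908, 1.71737559, -0.50411096, -0.65419174, 0.15576624,
]

mean = statistics.fmean(printed)
sample_std = statistics.stdev(printed)       # divides by n - 1
population_std = statistics.pstdev(printed)  # divides by n

print(f"mean={mean:.6f}")
print(f"sample_std={sample_std:.6f}")
print(f"population_std={population_std:.6f}")
```

The sample standard deviation comes out at 1 to within rounding, while the population value is closer to 0.949. This suggests the transpiled function kept the n - 1 divisor that torch.std uses by default.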
Round Up
That’s it, you can now transpile code from one framework to another with one line of code! However, there are still other important topics to master before you’re ready to unify ML code like a pro 🥷. In the next notebooks we’ll learn about the different ways that ivy.unify, ivy.trace_graph and ivy.transpile can be called, and the implications of each!