Accelerating PyTorch models with JAX

Accelerate your PyTorch models by converting them to JAX for faster inference.

⚠️ If you are running this notebook in Colab, you will have to install Ivy and some dependencies manually. You can do so by running the cell below ⬇️

If you want to run the notebook locally but don’t have Ivy installed just yet, you can check out the Get Started section of the docs.

Make sure you run this demo with GPU enabled!

[1]:
!pip install -q ivy
!pip install -q transformers
!pip install -q flax
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv

Let’s now import Ivy and the libraries we’ll use in this example:

[1]:
import jax
# Should list a GPU device; if only a CPU device appears, JAX is using a CPU-only jaxlib
print(jax.devices())
import ivy
ivy.set_default_device("gpu:0")
import torch
import requests
import numpy as np
from PIL import Image
from transformers import AutoModel, AutoFeatureExtractor
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
WARNING:root:   Some binaries seem to be missing in your system. This could be either because we don't have compatible binaries for your system or that newer binaries were available. In the latter case, calling ivy.utils.cleanup_and_fetch_binaries() should fetch the binaries binaries. Feel free to create an issue on https://github.com/ivy-llc/ivy.git in case of the former

WARNING:root:
Following are the supported configurations :
compiler : cp38-cp38-manylinux_2_17_x86_64, cp38-cp38-win_amd64, cp39-cp39-manylinux_2_17_x86_64, cp39-cp39-win_amd64, cp310-cp310-manylinux_2_17_x86_64, cp310-cp310-win_amd64, cp310-cp310-macosx_12_0_arm64, cp311-cp311-manylinux_2_17_x86_64, cp311-cp311-win_amd64, cp311-cp311-macosx_12_0_arm64, cp312-cp312-manylinux_2_17_x86_64, cp312-cp312-win_amd64, cp312-cp312-macosx_12_0_arm64


Now we can load a ResNet model and its corresponding feature extractor from the Hugging Face transformers library:

[2]:
jax.config.update("jax_enable_x64", True)

arch_name = "ResNet"
checkpoint_name = "microsoft/resnet-50"

feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint_name)
model = AutoModel.from_pretrained(checkpoint_name).to('cuda')
2024-10-21 20:32:14.781022: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-10-21 20:32:14.799336: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-10-21 20:32:14.805088: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-10-21 20:32:15.897178: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT

Let’s use the feature extractor to get the corresponding inputs.

[3]:
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = feature_extractor(
    images=image, return_tensors="pt"
).to('cuda')
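The feature extractor turns the PIL image into a normalized pixel_values tensor. Besides resizing and cropping, it rescales raw 0-255 pixel values to [0, 1] and normalizes each RGB channel with a mean and standard deviation. A minimal sketch of that last step, assuming the common ImageNet statistics (the values this checkpoint actually uses can be read from feature_extractor.image_mean and feature_extractor.image_std); normalize_pixel and normalize_rgb are hypothetical helpers, not part of transformers:

```python
# Assumed ImageNet statistics; read feature_extractor.image_mean / .image_std for the real ones
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(value: int, channel: int, rescale_factor: float = 1 / 255) -> float:
    """Hypothetical helper: map a raw 0-255 value in one RGB channel to the model's input scale."""
    scaled = value * rescale_factor  # rescale to [0, 1]
    return (scaled - IMAGENET_MEAN[channel]) / IMAGENET_STD[channel]


def normalize_rgb(pixel: tuple) -> tuple:
    """Normalize an (R, G, B) pixel channel by channel."""
    return tuple(normalize_pixel(v, c) for c, v in enumerate(pixel))


print(normalize_rgb((255, 255, 255)))  # a white pixel
print(normalize_rgb((0, 0, 0)))        # a black pixel
```

Because of this normalization, the model inputs are roughly zero-centred floats rather than raw pixel intensities.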

Now, let’s transpile the ResNetModel class to JAX!

[4]:
from transformers.models.resnet.modeling_resnet import ResNetModel


JaxResNetModel = ivy.transpile(ResNetModel, source="torch", target="jax")

Next, instantiate the transpiled JAX/Flax model, building its config from the original PyTorch model’s config:

[19]:
from ivy_transpiled_outputs.jax_outputs.transformers.models.resnet.modeling_resnet import (
    jax_ResNetConfig as JaxResNetConfig,
)


jax_config = JaxResNetConfig.from_dict(model.config.to_dict())
jax_model = JaxResNetModel(config=jax_config)

Let’s sync the weights of the source and transpiled models so we can compare them one-to-one and validate that they are functionally equivalent:

[20]:
ivy.sync_models(model, jax_model)
All parameters and buffers are now synced!
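Conceptually, syncing overwrites the freshly initialized parameters of jax_model with the pretrained ones from model, matched by name, which is what makes a one-to-one output comparison meaningful. The toy sketch below uses plain dicts of flat lists standing in for tensors; sync_params is a hypothetical helper and is not how ivy.sync_models is implemented (a real implementation would also have to map between framework-specific parameter names and tensor layouts):

```python
def sync_params(source: dict, target: dict) -> dict:
    """Hypothetical helper: return target's parameters overwritten by source's, matched by name."""
    mismatched = source.keys() ^ target.keys()  # names present in only one of the two models
    if mismatched:
        raise KeyError(f"parameter names differ: {sorted(mismatched)}")
    synced = {}
    for name, value in source.items():
        if len(value) != len(target[name]):
            raise ValueError(f"size mismatch for {name!r}")
        synced[name] = list(value)  # copy the pretrained values
    return synced


pretrained = {"conv.weight": [0.1, -0.2, 0.3], "bn.bias": [0.0, 0.5]}
fresh = {"conv.weight": [0.9, 0.9, 0.9], "bn.bias": [0.0, 0.0]}  # stands in for a random init
print(sync_params(pretrained, fresh))
```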
[21]:
torch_inputs = feature_extractor(image, return_tensors="pt").to("cuda")
with torch.no_grad():
    torch_outputs = model(**torch_inputs)
logits_np = torch_outputs.last_hidden_state.detach().cpu().numpy()
[23]:
jax_inputs = feature_extractor(image, return_tensors="jax")
logits_transpiled = jax_model(**jax_inputs)
logits_transpiled_np = logits_transpiled.last_hidden_state

np.allclose(logits_np, logits_transpiled_np, atol=1e-1)
[23]:
True
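np.allclose returns True only if every element satisfies |a - b| <= atol + rtol * |b| (NumPy's documented defaults are rtol=1e-05 and atol=1e-08). A pure-Python sketch of that test on flat sequences, with NaN/inf handling and broadcasting omitted, plus a max_abs_diff helper (hypothetical, for illustration) that reports how far apart the outputs actually are:

```python
def allclose(a, b, rtol=1e-05, atol=1e-08):
    """Pure-Python sketch of np.allclose's documented test: |a - b| <= atol + rtol * |b|."""
    if len(a) != len(b):
        raise ValueError("inputs must have the same length")
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))


def max_abs_diff(a, b):
    """Largest elementwise absolute difference, a more informative number than a bool."""
    return max(abs(x - y) for x, y in zip(a, b))


reference = [1.0, 2.0, 3.0]
candidate = [1.05, 2.0, 2.96]
print(allclose(candidate, reference))             # False with the default tolerances
print(allclose(candidate, reference, atol=1e-1))  # True with atol=1e-1, as in the cell above
print(max_abs_diff(candidate, reference))
```

With atol=1e-1, elements may differ by up to about 0.1 and still pass, so it can be worth also printing the largest deviation between the real outputs, e.g. np.abs(logits_np - np.asarray(logits_transpiled_np)).max().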

Now that we know the transpiled model produces the same outputs, let’s see how much it improves runtime efficiency. For a fair comparison, we first compile the original PyTorch model using torch.compile:

[54]:
# ref : https://github.com/pytorch/pytorch/issues/107960
# Each `!` command runs in its own subshell, so `!export` would not persist; %env sets the variables for the kernel
%env LC_ALL=en_US.UTF-8
%env LD_LIBRARY_PATH=/usr/lib64-nvidia
%env LIBRARY_PATH=/usr/local/cuda/lib64/stubs
!ldconfig /usr/lib64-nvidia
/sbin/ldconfig.real: /usr/lib/wsl/lib/libcuda.so.1 is not a symbolic link

/sbin/ldconfig.real: Can't create temporary cache file /etc/ld.so.cache~: Permission denied
[25]:
WARMUPS = 25

inputs = feature_extractor(
    images=image, return_tensors="pt"
).to("cuda")

comp_model = torch.compile(model)

# The first calls trigger compilation, so warm up before timing
for _ in range(WARMUPS):
    _ = comp_model(**inputs)
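Why the warm-up loop? Compilers such as torch.compile and jax.jit do their expensive work on the first call and cache the result, and they may compile again when they see inputs with new shapes. The toy sketch below only imitates that caching behaviour in plain Python so the effect is visible; toy_compile is a hypothetical decorator that performs no real compilation and keys its cache on input length as a stand-in for shape:

```python
import functools


def toy_compile(fn):
    """Hypothetical stand-in for a JIT compiler: 'compiles' once per input shape."""
    cache = {}

    @functools.wraps(fn)
    def wrapper(x):
        key = (len(x),)  # a real compiler keys on shapes, dtypes and more
        if key not in cache:
            wrapper.compilations += 1  # the slow path that warm-up calls pay for
            cache[key] = fn  # a real compiler would store optimized code here
        return cache[key](x)

    wrapper.compilations = 0
    return wrapper


@toy_compile
def double(x):
    return [2 * v for v in x]


double([1, 2, 3])  # first call with length 3: "compiles"
double([4, 5, 6])  # same shape: cache hit
double([1, 2])     # new shape: "compiles" again
print(double.compilations)  # 2
```

Timing only after the warm-up calls means the measurements reflect the cached, steady-state path rather than compilation.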

Let’s now apply the equivalent transformation to our new Flax model by using JAX just-in-time (JIT) compilation through nnx.jit:

[27]:
from flax import nnx

inputs_jax = feature_extractor(
    images=image, return_tensors="jax"
)

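def _forward(**kwargs):
    # Return just the final feature map (a single JAX array) rather than the full output object
    return jax_model(**kwargs).last_hidden_state

comp_model_jax = nnx.jit(_forward)

# Warm up so that tracing and XLA compilation happen before timing
for _ in range(WARMUPS):
    _ = comp_model_jax(**inputs_jax)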

Now that we have both models optimized, let’s see how their runtime speeds compare to each other!

[ ]:
%%timeit
_ = comp_model(**inputs)
6.63 ms ± 122 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
[ ]:
%%timeit
_ = comp_model_jax(**inputs_jax)
1.18 ms ± 134 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
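One caveat when reading these numbers: PyTorch CUDA kernels and JAX computations are both dispatched asynchronously, so a call can return before the device has finished the work, and the timings above may partly reflect dispatch rather than full execution. A more careful benchmark synchronizes after every call. A minimal, framework-agnostic sketch; benchmark is a hypothetical helper, and the commented calls show how it could be wired to torch.cuda.synchronize and to the block_until_ready method of JAX arrays:

```python
import statistics
import time
from typing import Any, Callable


def benchmark(fn: Callable[[], Any],
              sync: Callable[[Any], Any] = lambda out: None,
              warmups: int = 3,
              iters: int = 20) -> tuple:
    """Hypothetical helper: return (mean_ms, stdev_ms) per call of fn.

    sync receives each output and should block until the device work behind
    it has finished, so the timer measures execution rather than dispatch.
    """
    if iters < 2:
        raise ValueError("need at least two timed iterations for a standard deviation")
    for _ in range(warmups):
        sync(fn())  # compilation and caching happen here, outside the timed loop
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        sync(fn())
        samples.append((time.perf_counter() - start) * 1e3)
    return statistics.mean(samples), statistics.stdev(samples)


# In this notebook it could be called like this (not executed here):
#   benchmark(lambda: comp_model(**inputs), sync=lambda out: torch.cuda.synchronize())
#   benchmark(lambda: comp_model_jax(**inputs_jax), sync=lambda out: out.block_until_ready())
mean_ms, std_ms = benchmark(lambda: sum(range(100_000)))
print(f"{mean_ms:.3f} ms ± {std_ms:.3f} ms per call")
```

For inference-only timing, the PyTorch call could also be wrapped in torch.no_grad() to avoid autograd bookkeeping.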

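As expected, the transpiled model is significantly faster: in this run, the JIT-compiled JAX version takes about 1.18 ms per call versus 6.63 ms for the compiled PyTorch model, roughly a 5.6x speedup, and the conversion itself took just one line of code! 🚀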
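For reference, the speedup is the ratio of the two mean runtimes reported by %%timeit. The sketch below computes it from those numbers, together with a crude range obtained by shifting each mean by one standard deviation in opposite directions (a rough indication of spread, not a rigorous confidence interval):

```python
# Mean and standard deviation per call reported by %%timeit above, in milliseconds
torch_ms, torch_std = 6.63, 0.122
jax_ms, jax_std = 1.18, 0.134

speedup = torch_ms / jax_ms
low = (torch_ms - torch_std) / (jax_ms + jax_std)   # PyTorch at its fastest, JAX at its slowest
high = (torch_ms + torch_std) / (jax_ms - jax_std)  # PyTorch at its slowest, JAX at its fastest

print(f"speedup: {speedup:.1f}x")  # prints "speedup: 5.6x"
print(f"rough range: {low:.1f}x to {high:.1f}x")
```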

Finally, as a sanity check, let’s load a different image and make sure that both models produce the same results:

[33]:
url = "http://images.cocodataset.org/train2017/000000283921.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = feature_extractor(
    images=image, return_tensors="pt"
).to("cuda")
inputs_jax = feature_extractor(
    images=image, return_tensors="jax"
)
out_torch = comp_model(**inputs)
out_jax = comp_model_jax(**inputs_jax)

np.allclose(out_torch.last_hidden_state.detach().cpu().numpy(), out_jax, atol=1e-1)
[33]:
True

That’s pretty much it! The results from both models match within the chosen tolerance, and we achieved a solid speedup by using Ivy’s transpiler to convert the model to JAX!