Transpiling Models from PyTorch to TensorFlow

You can install the dependencies required for this notebook by running the cell below ⬇️, or check out the Get Started section of the docs to find out more about installing ivy.

[ ]:
!pip install ivy
!pip install torch
!pip install tensorflow

Here we’ll go through an example of how a model written in PyTorch can be converted to TensorFlow via ivy.transpile and then used like any other TensorFlow model. First, let’s import a simple torch model.

[1]:
from example_models import SimpleModel

print("""
This model is defined as follows:

class SimpleModel(torch.nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 3, kernel_size=3)
        self.relu = torch.nn.ReLU()
        self.fc = torch.nn.Linear(3 * 26 * 26, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
"""
)

This model is defined as follows:

class SimpleModel(torch.nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 3, kernel_size=3)
        self.relu = torch.nn.ReLU()
        self.fc = torch.nn.Linear(3 * 26 * 26, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
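Because self.fc is declared with 3 * 26 * 26 input features, SimpleModel only accepts single-channel 28x28 inputs. The sketch below reproduces that arithmetic with the spatial output-size formula from the PyTorch Conv2d documentation. It assumes stride 1, no padding and dilation 1, which are the defaults used by self.conv1. The helper names conv2d_out and fc_in_features are hypothetical and exist only for this illustration.

```python
def conv2d_out(size: int, kernel_size: int, stride: int = 1, padding: int = 0, dilation: int = 1) -> int:
    """Output size of torch.nn.Conv2d along one spatial dimension."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def fc_in_features(height: int, width: int, out_channels: int = 3, kernel_size: int = 3) -> int:
    """Number of features SimpleModel's flatten step produces for an HxW input (hypothetical helper)."""
    return out_channels * conv2d_out(height, kernel_size) * conv2d_out(width, kernel_size)


print(conv2d_out(28, 3))       # 26
print(fc_in_features(28, 28))  # 2028 == 3 * 26 * 26, matching nn.Linear(3 * 26 * 26, 10)
print(fc_in_features(32, 32))  # 2700, so a 32x32 input would not fit self.fc
```

This is also why the calls later in this notebook use inputs of shape (1, 1, 28, 28).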

Next, we can convert the model to TensorFlow with ivy.transpile, specifying the source and target frameworks:

[2]:
import ivy

TFSimpleModel = ivy.transpile(SimpleModel, source="torch", target="tensorflow")
Transpilation of SimpleModel complete.

Now we can instantiate the transpiled model and call it on a TensorFlow tensor:

[3]:
import tensorflow as tf

tf_model = TFSimpleModel()
tf_model(tf.random.normal((1, 1, 28, 28))).shape
[3]:
TensorShape([1, 10])
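The [1, 10] output shape follows from the architecture alone: one row per batch element and one column per output of self.fc. As a framework-independent reference, here is a minimal NumPy sketch of the computation SimpleModel defines. It assumes NCHW inputs (as in the call above) and the cross-correlation convention used by torch.nn.Conv2d. The function simple_model_forward and its randomly initialised weights are hypothetical. The sketch does not read the transpiled model's weights and says nothing about how ivy.transpile works internally.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def simple_model_forward(x, conv_w, conv_b, fc_w, fc_b):
    """NumPy sketch of SimpleModel.forward for an NCHW input x (hypothetical helper)."""
    # Conv2d(1, 3, kernel_size=3), stride 1, no padding:
    # windows[n, c, h, w, i, j] == x[n, c, h + i, w + j], shape (N, C_in, H_out, W_out, kH, kW)
    windows = sliding_window_view(x, conv_w.shape[2:], axis=(2, 3))
    # Cross-correlation (no kernel flip): sum over input channels and the 3x3 window
    conv = np.einsum("nchwij,ocij->nohw", windows, conv_w) + conv_b[None, :, None, None]
    act = np.maximum(conv, 0.0)           # ReLU
    flat = act.reshape(act.shape[0], -1)  # torch.flatten(x, 1): (N, 3 * 26 * 26)
    return flat @ fc_w.T + fc_b           # Linear weight stored as (out_features, in_features)


# Hypothetical random weights with the same shapes as SimpleModel's parameters
rng = np.random.default_rng(0)
x = rng.standard_normal((1, 1, 28, 28))
conv_w = rng.standard_normal((3, 1, 3, 3))
conv_b = rng.standard_normal(3)
fc_w = rng.standard_normal((10, 3 * 26 * 26))
fc_b = rng.standard_normal(10)

out = simple_model_forward(x, conv_w, conv_b, fc_w, fc_b)
print(out.shape)  # (1, 10)
```

Comparing numerical values against a reference like this only makes sense once the same weights are loaded into both implementations; with independent random weights, only the shapes are expected to agree.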

We can also take advantage of TensorFlow-specific features, such as compiling the model into a graph with tf.function:

[4]:
compiled_model = tf.function(tf_model)
compiled_model(tf.random.normal((1, 1, 28, 28))).shape
[4]:
TensorShape([1, 10])