Training PyTorch ResNet in your TensorFlow Projects#
Framework Incompatibility#
Practitioners with large codebases written in other frameworks, such as PyTorch, cannot easily take advantage of TensorFlow’s rich ecosystem of state-of-the-art (SOTA) deployment tooling, because doing so requires converting their code by hand, which is slow and error-prone.
Ivy’s transpiler allows ML practitioners to dynamically connect libraries, layers and models from different frameworks together. For PyTorch users, the transpiler provides a seamless and accurate way to introduce code written in PyTorch to TensorFlow pipelines.
In this blog post, we’ll go through an example of how the transpiler can be used to convert a model from PyTorch to TensorFlow and train the converted model in TensorFlow.
Transpiling a PyTorch model to TensorFlow#
About the transpiled model#
To illustrate a typical transpilation workflow, we’ll convert a pre-trained ResNet model from PyTorch to TensorFlow, use the transpiled model to run inference, and then fine-tune it in TensorFlow.
ResNet owes its name to its residual blocks, whose skip connections add each block’s input back to its output so that the block only has to learn a residual. Skip connections are a common idea in the community today, but at the time they were a revolutionary architectural choice: they allowed ResNet to scale to 152 layers while remaining trainable, something plain (non-residual) networks of that depth struggled with.
Architecturally, a ResNet block is similar to a ConvNeXt block but differs in the specific convolutional layers used, grouped convolution, normalization, activation function, and downsampling. Going through the details of the models is outside the scope of this demo; interested readers may want to read the paper.
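To make the skip-connection idea concrete, here is a minimal NumPy sketch of a residual block computing y = relu(F(x) + x), where F is a small two-layer transform. This is an illustrative assumption, not torchvision’s implementation: dense layers stand in for the 3×3 convolutions and batch normalization of a real ResNet BasicBlock, and the function and weight names are hypothetical.

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


def residual_block(x, w1, w2):
    """Hypothetical residual block: y = relu(F(x) + x) with F(x) = relu(x @ w1) @ w2.

    Dense layers stand in for the conv + batch-norm layers of a real ResNet block.
    """
    residual = relu(x @ w1) @ w2  # F(x): the learned residual
    return relu(residual + x)     # skip connection adds the input back


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))  # batch of 4 feature vectors

# With F's weights at zero the block outputs relu(x), passing non-negative inputs through
# unchanged: the residual formulation makes an identity-like mapping trivial to represent.
w_zero = np.zeros((8, 8))
y_identity = residual_block(x, w_zero, w_zero)
print(np.allclose(y_identity, relu(x)))  # True

w1 = rng.standard_normal((8, 8)) * 0.1
w2 = rng.standard_normal((8, 8)) * 0.1
y = residual_block(x, w1, w2)
print(y.shape)  # (4, 8)
```

Because the skip path is an addition, gradients also flow back through it unchanged, which is the intuition behind stacking many such blocks.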
The first cell installs the required packages; so that they are available after installation, the notebook restarts automatically once that cell has run.
After the notebook has restarted, select Runtime -> Run all to run all of the cells.
Make sure you run this demo with a GPU enabled!
!pip install -q ivy
!python3 -m pip install torchvision
!python3 -m pip install astor
Setting up the source model#
We start by importing the necessary libraries. We’ll mostly use PyTorch’s torchvision API to load the model, Ivy to transpile it from PyTorch to TensorFlow, and TensorFlow to fine-tune the transpiled model.
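import os
import warnings
import logging

# Filter TensorFlow info and warning messages (the env var must be set before importing TensorFlow)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
warnings.filterwarnings("ignore")
logging.getLogger("tensorflow").setLevel(logging.ERROR)

import tensorflow as tf

# Let TensorFlow allocate GPU memory on demand instead of reserving it all up front
gpus = tf.config.list_physical_devices("GPU")
if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)

import ivy
import torch
import torchvision
from torchvision import datasets, models, transforms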
Load the Data#
We will use the torchvision and torch.utils.data packages to load the data.
The problem we’re going to solve today is training a model to classify ants and bees. We have about 120 training images and 75 validation images for each class. This would usually be far too small a dataset to generalize from if we trained from scratch, but since we are using transfer learning, we should be able to generalize reasonably well.
This dataset is a very small subset of ImageNet.
Note: download the dataset archive and extract it into the current directory by running the following cell, after setting url to the archive’s location.
import os
import zipfile

import requests

# URL of the zip file you want to download
url = ''  # Replace with your URL

# Send a GET request to the URL
response = requests.get(url)

# Check if the request was successful (status code 200)
if response.status_code == 200:
    # Get the file name from the URL
    filename = os.path.basename(url)
    # Specify where you want to save the zip file (current working directory in Colab)
    zip_save_path = os.path.join(os.getcwd(), filename)
    # Write the content to the zip file
    with open(zip_save_path, 'wb') as f:
        f.write(response.content)
    print(f"Zip file downloaded successfully as '{filename}' in the current working directory.")
    # Extract the contents of the zip file
    with zipfile.ZipFile(zip_save_path, 'r') as zip_ref:
        zip_ref.extractall(os.getcwd())
    print("Zip file contents extracted successfully.")
    # Optionally, you can remove the zip file after extraction
    os.remove(zip_save_path)
    print(f"Zip file '{filename}' deleted.")
else:
    print(f"Failed to download zip file from '{url}'. Status code: {response.status_code}")
Zip file downloaded successfully as '' in the current working directory.
Zip file contents extracted successfully.
Zip file '' deleted.
# Data augmentation and normalization for training
# Just resizing, cropping and normalization for validation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

data_dir = 'hymenoptera_data'
image_datasets = {
    x: datasets.ImageFolder(
        os.path.join(data_dir, x), data_transforms[x]
    ) for x in ['train', 'val']
}
dataloaders = {
    x: torch.utils.data.DataLoader(
        image_datasets[x], batch_size=4, shuffle=True, num_workers=4
    ) for x in ['train', 'val']
}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
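transforms.Normalize(mean, std) rescales each channel as (x - mean) / std, and the imshow helper below undoes this with std * inp + mean before plotting. The NumPy sketch that follows shows that round trip on a channels-first (C, H, W) array; a random image stands in for a real photo, and the helper names are illustrative rather than part of torchvision.

```python
import numpy as np

# ImageNet channel statistics used by the data transforms above
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])


def normalize_chw(img):
    """Per-channel (x - mean) / std on a (C, H, W) array, like transforms.Normalize."""
    return (img - MEAN[:, None, None]) / STD[:, None, None]


def denormalize_hwc(img_chw):
    """Inverse used for display: convert to (H, W, C), then std * x + mean, clipped to [0, 1]."""
    img_hwc = img_chw.transpose((1, 2, 0))
    return np.clip(STD * img_hwc + MEAN, 0, 1)


rng = np.random.default_rng(0)
image = rng.uniform(0.0, 1.0, size=(3, 224, 224))  # stand-in for a ToTensor() output in [0, 1]

normalized = normalize_chw(image)
restored = denormalize_hwc(normalized)

print(normalized.shape, restored.shape)  # (3, 224, 224) (224, 224, 3)
print(np.allclose(restored, image.transpose((1, 2, 0))))  # True
```

The transpose is needed because PyTorch tensors are channels-first while matplotlib expects channels-last images.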
Visualize a few images#
We also grab a batch of training images; the same inputs tensor is reused later to build the transpiled model and to compare its outputs with those of the original model.
import numpy as np
from matplotlib import pyplot as plt

def imshow(inp, title=None):
    """Display image for Tensor."""
    # Convert from (C, H, W) to (H, W, C) for matplotlib
    inp = inp.numpy().transpose((1, 2, 0))
    # Undo the normalization applied by the data transforms
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated

# Get a batch of training data
inputs, classes = next(iter(dataloaders['train']))

# Make a grid from batch
out = torchvision.utils.make_grid(inputs)

imshow(out, title=[class_names[x] for x in classes])

Load the pre-trained model#
We then initialize our model through the torchvision API; specifically, we’ll use ResNet18. Although we use a model from the torchvision models API for this demonstration, the same workflow should apply to other PyTorch models regardless of how they are loaded, including models hosted on other platforms or stored locally.
model = torchvision.models.resnet18(weights='IMAGENET1K_V1')
Converting the model from PyTorch to TensorFlow#
As we explain in our Quickstart Guide, transpiling a model (an instance of a trainable class) involves passing the type/class of the model to ivy.transpile rather than the model instance itself. This is because Ivy’s transpiler eagerly converts all of the model’s source code from PyTorch to TensorFlow and makes it readily available in the output directory (controlled by passing the output_dir kwarg to ivy.transpile).
The weights of the transpiled model and the original model (if it was pretrained or loaded from a checkpoint) can then be synced with ivy.sync_models, as we do later in this demo.
TensorflowResNet = ivy.transpile(type(model), source="torch", target="tensorflow")
Initializing and building the TensorFlow model#
from ivy_transpiled_outputs.tensorflow_outputs.torchvision.models.resnet import (
    tensorflow_BasicBlock as TensorflowBasicBlock,
)

# ResNet18 configuration: BasicBlock, [2, 2, 2, 2] blocks per stage, 1000 ImageNet classes
kwargs = {"num_classes": 1000}
layers = [2, 2, 2, 2]
block = TensorflowBasicBlock

# Instantiate the model in TensorFlow
tensorflow_model = TensorflowResNet(block, layers, **kwargs)

# Call the forward pass once in order to build all the layers
import tensorflow as tf

tensorflow_model(tf.convert_to_tensor(inputs.numpy()))
<tf.Tensor: shape=(4, 1000), dtype=float32, numpy=
array([[ 0.03397159, -0.25558692, -0.24101321, ..., 0.30248338,
0.4400444 , 0.19245608],
[ 0.04251331, -0.18369699, -0.19522265, ..., 0.30382818,
0.3872773 , 0.18959698],
[ 0.01237407, -0.1483828 , -0.22826895, ..., 0.19425122,
0.33760768, 0.10684194],
[ 0.0467218 , -0.19289383, -0.22670889, ..., 0.277786 ,
0.39353883, 0.17944181]], dtype=float32)>
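The layers = [2, 2, 2, 2] argument is what makes this a ResNet18: each BasicBlock contains two 3×3 convolutions, so the four stages contribute 2 × (2 + 2 + 2 + 2) = 16 weighted layers, plus the initial 7×7 convolution and the final fully connected layer. The short sketch below reproduces that count for the standard torchvision configurations (Bottleneck blocks have three convolutions). The helper name is hypothetical, and the count follows the usual convention of ignoring 1×1 downsampling shortcuts, pooling and batch-norm layers.

```python
# Convolutions per residual block in torchvision's ResNet implementations
CONVS_PER_BLOCK = {"BasicBlock": 2, "Bottleneck": 3}


def resnet_depth(block, layers):
    """Hypothetical helper: weighted layers = stem conv + convs in all blocks + final fc."""
    return 1 + CONVS_PER_BLOCK[block] * sum(layers) + 1


configs = {
    "resnet18": ("BasicBlock", [2, 2, 2, 2]),
    "resnet34": ("BasicBlock", [3, 4, 6, 3]),
    "resnet50": ("Bottleneck", [3, 4, 6, 3]),
    "resnet101": ("Bottleneck", [3, 4, 23, 3]),
    "resnet152": ("Bottleneck", [3, 8, 36, 3]),
}

for name, (block, layers) in configs.items():
    print(name, resnet_depth(block, layers))  # 18, 34, 50, 101, 152
```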
TensorFlow (Keras) model-specific attributes and methods all remain accessible on the transpiled model.
Let’s sync the weights of both the source and the transpiled model to do a 1-to-1 comparison of the two models and validate that they are functionally equal:
import ivy
ivy.sync_models(model, tensorflow_model)
All parameters and buffers are now synced!
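ivy.sync_models handles the weight copying for us. For intuition about what a manual weight sync between the two frameworks has to account for, the NumPy sketch below shows the layout difference for a fully connected layer: PyTorch’s nn.Linear stores its weight as (out_features, in_features) and computes x @ W.T + b, while Keras’s Dense stores its kernel as (in_features, out_features) and computes x @ kernel + b. Convolution kernels differ similarly ((out, in, kH, kW) in PyTorch versus (kH, kW, in, out) in Keras). This is an illustration under those assumptions, not a description of how ivy.sync_models is implemented, and the function names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
in_features, out_features = 5, 3

# PyTorch nn.Linear layout: weight is (out_features, in_features); y = x @ weight.T + bias
torch_weight = rng.standard_normal((out_features, in_features))
bias = rng.standard_normal(out_features)


def torch_style_linear(x, weight, b):
    return x @ weight.T + b


def keras_style_dense(x, kernel, b):
    # Keras Dense layout: kernel is (in_features, out_features); y = x @ kernel + b
    return x @ kernel + b


# Copying a Linear weight into a Dense kernel means transposing it
keras_kernel = torch_weight.T

# For a conv layer the analogous mapping is (out, in, kH, kW) -> (kH, kW, in, out)
torch_conv = rng.standard_normal((16, 3, 7, 7))
keras_conv = torch_conv.transpose((2, 3, 1, 0))

x = rng.standard_normal((4, in_features))
print(keras_kernel.shape, keras_conv.shape)  # (5, 3) (7, 7, 3, 16)
print(np.allclose(torch_style_linear(x, torch_weight, bias),
                  keras_style_dense(x, keras_kernel, bias)))  # True
```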
Comparing the results#
Let’s now predict the logits for the same input with both the original and the transpiled model, and compare them at a granular level with np.allclose.
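# Put the PyTorch model in inference mode so that BatchNorm uses its running statistics,
# matching training=False on the TensorFlow side
model.eval()
with torch.no_grad():
    logits = model(inputs)
logits_np = logits.detach().numpy()

logits_transpiled = tensorflow_model(tf.convert_to_tensor(inputs.numpy()), training=False)
logits_transpiled_np = logits_transpiled.numpy()

np.allclose(logits_np, logits_transpiled_np, atol=1e-4)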
The logits produced by the transpiled model at inference time match those of the original model to within the tolerance, so the two models are indeed consistent!
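np.allclose(a, b, atol=1e-4) passes when every element satisfies |a - b| <= atol + rtol * |b|, with NumPy’s default rtol=1e-05. When such a check fails, it helps to look at the actual discrepancy. The sketch below uses synthetic arrays (not the real logits) to show the criterion written out by hand next to np.allclose, along with the maximum absolute difference.

```python
import numpy as np

rng = np.random.default_rng(0)
reference = rng.standard_normal((4, 1000)).astype(np.float32)  # stand-in for PyTorch logits
# Stand-in for transpiled logits: tiny float32-level noise on top of the reference
candidate = reference + rng.uniform(-5e-5, 5e-5, size=reference.shape).astype(np.float32)

atol, rtol = 1e-4, 1e-5
max_abs_diff = np.max(np.abs(reference - candidate))
# Element-wise criterion applied by np.allclose(candidate, reference, ...)
manual = np.all(np.abs(candidate - reference) <= atol + rtol * np.abs(reference))

print(np.allclose(candidate, reference, atol=atol, rtol=rtol), manual)  # True True
print(f"max |difference|: {max_abs_diff:.2e}")

# A uniform shift of 1e-3 exceeds the tolerance, so the check fails
shifted = reference + np.float32(1e-3)
print(np.allclose(shifted, reference, atol=atol, rtol=rtol))  # False
```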
Fine-tuning the transpiled model#
One of the key benefits of Ivy’s transpiler is that the transpiled model remains trainable, so we can fine-tune it further if required. Here’s an example of fine-tuning the transpiled model in TensorFlow on the small ants-and-bees subset of ImageNet loaded earlier.
Let’s start by writing a general function to train a model.
import time

import tensorflow as tf


def train_model(model, epochs, warmup_epochs, train_dataset, val_dataset, optimizer, loss_fn):
    # Prepare the metrics.
    train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
    val_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()

    def train_step(model, x_batch_train, y_batch_train, optimizer, loss_fn):
        with tf.GradientTape() as tape:
            logits = model(x_batch_train, training=True)
            loss_value = loss_fn(y_batch_train, logits)
        grads = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
        # Update training metric.
        train_acc_metric.update_state(y_batch_train, logits)
        return loss_value

    def train_loop(model, train_dataset, val_dataset, epochs, warmup_epochs, optimizer, loss_fn):
        for epoch in range(epochs + warmup_epochs):
            print("\nStart of epoch %d" % (epoch,))
            start_time = time.time()

            # Iterate over the batches of the dataset.
            for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
                # Convert the PyTorch batches to TensorFlow tensors
                x_batch_train = tf.convert_to_tensor(x_batch_train.detach().numpy())
                y_batch_train = tf.convert_to_tensor(y_batch_train.detach().numpy())
                loss_value = train_step(model, x_batch_train, y_batch_train, optimizer, loss_fn)

                # Log every 20 batches (logging is skipped during warm-up epochs).
                if epoch >= warmup_epochs and step % 20 == 0:
                    print(
                        "Training loss (for one batch) at step %d: %.4f"
                        % (step, float(loss_value))
                    )
                    # batch_size=4 in the dataloaders
                    print("Seen so far: %d samples" % ((step + 1) * 4))

            # Display metrics at the end of each epoch.
            if epoch >= warmup_epochs:
                train_acc = train_acc_metric.result()
                print("Training acc over epoch: %.4f" % (float(train_acc),))

            # Reset training metrics at the end of each epoch
            train_acc_metric.reset_state()

            # Run a validation loop at the end of each epoch.
            for x_batch_val, y_batch_val in val_dataset:
                x_batch_val = tf.convert_to_tensor(x_batch_val.detach().numpy())
                y_batch_val = tf.convert_to_tensor(y_batch_val.detach().numpy())
                val_logits = model(x_batch_val, training=False)
                # Update val metrics
                val_acc_metric.update_state(y_batch_val, val_logits)
            val_acc = val_acc_metric.result()
            val_acc_metric.reset_state()
            print("Validation acc: %.4f" % (float(val_acc),))
            print("Time taken: %.2fs" % (time.time() - start_time))

    # Call the training loop
    train_loop(model, train_dataset, val_dataset, epochs, warmup_epochs, optimizer, loss_fn)
    return model
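Each call to train_step records the forward pass with tf.GradientTape, computes the gradient of the loss with respect to the trainable weights, and lets the optimizer apply it; for plain SGD, as used below with learning_rate=1e-3, the update is w = w - lr * dL/dw. The NumPy sketch below spells out that same step by hand for a hypothetical linear softmax classifier, where the gradient of the cross-entropy with respect to the logits has the closed form softmax(logits) - onehot(labels). It is a conceptual illustration with made-up data, not what TensorFlow executes internally.

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def cross_entropy(logits, labels):
    """Mean sparse categorical cross-entropy computed from raw logits."""
    probs = softmax(logits)
    return -np.mean(np.log(probs[np.arange(len(labels)), labels]))


def sgd_step(w, x, labels, lr):
    """One hand-written training step for a hypothetical linear classifier logits = x @ w."""
    logits = x @ w                                     # forward pass
    grad_logits = softmax(logits)                      # dL/dlogits = (softmax - onehot) / batch
    grad_logits[np.arange(len(labels)), labels] -= 1.0
    grad_logits /= len(labels)
    grad_w = x.T @ grad_logits                         # chain rule back to the weights
    return w - lr * grad_w                             # plain SGD update


rng = np.random.default_rng(0)
x = rng.standard_normal((8, 5))       # batch of 8 examples with 5 features
labels = rng.integers(0, 2, size=8)   # two classes, like ants vs. bees
w = np.zeros((5, 2))

loss_before = cross_entropy(x @ w, labels)  # log(2) with all-zero weights
for _ in range(10):
    w = sgd_step(w, x, labels, lr=0.1)
loss_after = cross_entropy(x @ w, labels)
print(f"loss before: {loss_before:.4f}, after 10 steps: {loss_after:.4f}")
```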
# Instantiate an optimizer to train the model.
optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# Prepare the datasets
train_dataset = dataloaders["train"]
val_dataset = dataloaders["val"]
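With from_logits=True, SparseCategoricalCrossentropy takes raw, unnormalized model outputs together with integer class labels and, with its default reduction, averages -log softmax(logits)[label] over the batch; SparseCategoricalAccuracy counts how often argmax(logits) equals the label. Note that the model still has 1000 output logits while our labels are only 0 (ants) and 1 (bees), so a prediction can land on a class that is neither. The NumPy sketch below computes both quantities for a small synthetic batch; it mirrors the definitions, not Keras’s internal code, and the function names are hypothetical.

```python
import numpy as np


def sparse_ce_from_logits(logits, labels):
    """Mean of -log softmax(logits)[label], computed stably via log-sum-exp."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -np.mean(log_probs[np.arange(len(labels)), labels])


def sparse_accuracy(logits, labels):
    """Fraction of examples whose highest-scoring class equals the integer label."""
    return np.mean(np.argmax(logits, axis=1) == labels)


# Synthetic batch of 4 examples with 1000 logits each, labels in {0: ants, 1: bees}
logits = np.zeros((4, 1000))
logits[0, 0] = 5.0  # confident "ants", label 0 -> correct
logits[1, 1] = 5.0  # confident "bees", label 1 -> correct
logits[2, 0] = 5.0  # confident "ants", label 1 -> wrong
logits[3, 7] = 5.0  # argmax is class 7, neither ants nor bees -> wrong
labels = np.array([0, 1, 1, 0])

print(sparse_accuracy(logits, labels))  # 0.5
print(sparse_ce_from_logits(logits, labels))
```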
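# Train the model (the epoch counts are illustrative; adjust them to your compute budget)
transpiled_model = train_model(
    tensorflow_model, epochs=5, warmup_epochs=1, train_dataset=train_dataset,
    val_dataset=val_dataset, optimizer=optimizer, loss_fn=loss_fn,
)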
And that’s it! We’ve trained the transpiled model with a standard TensorFlow training loop, and it can now be plugged into any TensorFlow workflow.
Let’s now visualize the predictions of the trained model on some sample images from the validation set.
def visualize_model(model, num_images=6):
    images_so_far = 0
    fig = plt.figure()

    for i, (inputs, labels) in enumerate(dataloaders['val']):
        # Convert the PyTorch batch to TensorFlow tensors
        inputs = tf.convert_to_tensor(inputs.detach().numpy())
        labels = tf.convert_to_tensor(labels.detach().numpy())

        # Run inference; training=False makes BatchNorm use its moving statistics
        outputs = model(inputs, training=False)
        # The model still outputs 1000 logits; only the first len(class_names) map to our classes
        preds = tf.argmax(outputs[:, :len(class_names)], axis=1)

        for j in range(inputs.shape[0]):
            images_so_far += 1
            ax = plt.subplot(num_images // 2, 2, images_so_far)
            ax.axis('off')
            ax.set_title(f'predicted: {class_names[int(preds[j])]}')
            imshow(inputs[j])

            if images_so_far == num_images:
                return

visualize_model(transpiled_model)
We’ve just seen how the transpiler can be used to convert a model from PyTorch to TensorFlow and train the converted model in TensorFlow.
Head over to the Examples and Demos section in our documentation if you’d like to explore other demos like this. You can also run demos locally on your own machine by following our getting started guide, or sign up for a transpiler API key with higher usage quotas for local development.
If you have any questions or suggestions for other interesting demos you’d like to see, feel free to ask on our Discord community server. We look forward to seeing you there!