Utilities#

ivy.stateful.utilities.sync_models(original_model, translated_model)[source]#

Synchronizes the weights and buffers between a native PyTorch model (torch.nn.Module) and its translated version in TensorFlow or Flax.

Args:#

original_model (torch.nn.Module): The PyTorch model to synchronize from.

translated_model (tf.keras.Model or nnx.Module): The target model to synchronize to, either a TensorFlow or Flax model.
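A minimal usage sketch is shown below. It assumes `translated_model` was produced by Ivy's transpiler; the `ivy.transpile` call is illustrative (see the Ivy transpiler docs for the exact signature), not part of this utility's API:

```python
# A minimal sketch, assuming an Ivy transpilation step produced
# `translated_model` from the PyTorch model with TensorFlow as the target.
import torch.nn as nn
import ivy

from ivy.stateful.utilities import sync_models

original_model = nn.Linear(10, 5)

# Illustrative transpilation call; the exact signature and supported
# targets are documented by the Ivy transpiler, not assumed here.
translated_model = ivy.transpile(original_model, source="torch", target="tensorflow")

# Copy the PyTorch weights and buffers into the translated model.
sync_models(original_model, translated_model)
```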

ivy.stateful.utilities.sync_models_torch_and_jax(model_pt, model_jax)[source]#

Synchronizes the weights and buffers between a PyTorch model (torch.nn.Module) and a Flax model (flax.nnx.Module).

This function ensures both models have identical parameters and buffers by iterating through their submodules and synchronizing them. The Flax model must either be an instance of FlaxModel or have submodules that inherit from the translated FlaxModel, and expose interfaces similar to torch.nn.Module, including named_parameters() and named_buffers().

Args:#

model_pt (torch.nn.Module): The PyTorch model to synchronize from.

model_jax (flax.nnx.Module): The Flax model to synchronize to, with submodules inheriting from the custom FlaxModel class.

Returns:#

None

Example:#

```python
import torch.nn as nn
import jax
import flax.nnx as nnx

from ivy.stateful.utilities import sync_models_torch_and_jax

# `CustomFlaxLinear` is a subclass of FlaxModel that exposes a similar
# interface to torch.nn.Module (with named_parameters and named_buffers).
# `FlaxModel` is the base class used by Ivy's translated Flax modules and
# is assumed to be in scope here.
class CustomFlaxLinear(FlaxModel):
    def __init__(self, in_features, out_features):
        super(CustomFlaxLinear, self).__init__()
        self.weight = nnx.Param(
            jax.random.normal(jax.random.key(0), [out_features, in_features])
        )
        self.bias = nnx.Param(jax.random.normal(jax.random.key(0), [out_features]))

    def __call__(self, x):
        # Weight is stored in PyTorch layout (out_features, in_features),
        # so transpose it for the matmul.
        return x @ self.weight.value.T + self.bias.value

    def named_parameters(self):
        return [("weight", self.weight), ("bias", self.bias)]

    def named_buffers(self):
        return []

    def eval(self):
        return False

# `NativeFlaxModel` is a subclass of nnx.Module and does NOT expose a similar
# interface to torch.nn.Module (with named_parameters and named_buffers).
class NativeFlaxModel(nnx.Module):
    def __init__(self):
        super(NativeFlaxModel, self).__init__()
        self.linear = CustomFlaxLinear(10, 5)

    def __call__(self, x):
        return self.linear(x)

class PyTorchModel(nn.Module):
    def __init__(self):
        super(PyTorchModel, self).__init__()
        self.linear = nn.Linear(10, 5)

    def forward(self, x):
        return self.linear(x)

# Instantiate both models
model_pt = PyTorchModel()       # PyTorch model
model_flax = NativeFlaxModel()  # Native Flax model inheriting from nnx.Module

# Sync all submodules between the PyTorch and Flax models
sync_models_torch_and_jax(model_pt, model_flax)
```

ivy.stateful.utilities.sync_models_torch_and_tf(model_pt, model_tf)[source]#

Synchronizes the weights and buffers between a PyTorch model (torch.nn.Module) and a TensorFlow model (keras.Model).

This function ensures that both models have identical parameters and buffers by iterating through their submodules and synchronizing them. The TensorFlow model must either be an instance of KerasModel or have submodules that inherit from the translated KerasModel/KerasLayer, and expose interfaces similar to torch.nn.Module, including named_parameters() and named_buffers().

Args:#

model_pt (torch.nn.Module): The PyTorch model to synchronize from.

model_tf (keras.Model): The TensorFlow model to synchronize to, with submodules inheriting from the custom KerasModel/KerasLayer class.

Returns:#

None

Example:#

```python
import torch.nn as nn
import tensorflow as tf
import keras
from keras.layers import Layer

from ivy.stateful.utilities import sync_models_torch_and_tf

# `CustomKerasLinear` is a subclass of Layer that exposes a similar
# interface to torch.nn.Module (with named_parameters and named_buffers).
class CustomKerasLinear(Layer):
    def __init__(self, in_features, out_features):
        super(CustomKerasLinear, self).__init__()
        self.weight = tf.Variable(tf.random.normal([out_features, in_features]))
        self.bias = tf.Variable(tf.random.normal([out_features]))

    def call(self, x):
        # Weight is stored in PyTorch layout (out_features, in_features),
        # so transpose it for the matmul.
        return tf.matmul(x, self.weight, transpose_b=True) + self.bias

    def named_parameters(self):
        return [("weight", self.weight), ("bias", self.bias)]

    def named_buffers(self):
        return []

    def eval(self):
        return False

# `NativeKerasModel` is a subclass of keras.Model and does NOT expose a similar
# interface to torch.nn.Module (with named_parameters and named_buffers).
class NativeKerasModel(keras.Model):
    def __init__(self):
        super(NativeKerasModel, self).__init__()
        self.linear = CustomKerasLinear(10, 5)

    def call(self, x):
        return self.linear(x)

class PyTorchModel(nn.Module):
    def __init__(self):
        super(PyTorchModel, self).__init__()
        self.linear = nn.Linear(10, 5)

    def forward(self, x):
        return self.linear(x)

# Instantiate both models
model_pt = PyTorchModel()      # PyTorch model
model_tf = NativeKerasModel()  # Native Keras model inheriting from keras.Model

# Sync all submodules between the PyTorch and Keras models
sync_models_torch_and_tf(model_pt, model_tf)
```

ivy.stateful.utilities.transpose_weights_pt_to_tf_jax(layer, params_np, transpose_weights, fw)[source]#

Transpose weights from PyTorch to TensorFlow/JAX format.

Args:

- layer: The layer object.
- params_np: The weights in NumPy format.
- transpose_weights: Flag to enable weight transposition.
- fw: The target framework (TensorFlow or JAX).

Returns:

- The transposed weights.
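To illustrate the layout change involved, here is a minimal sketch using hypothetical NumPy arrays; the actual transposition this utility applies depends on the layer type and the `transpose_weights` flag:

```python
import numpy as np

# Hypothetical arrays showing the usual layout conventions.
conv_w_pt = np.zeros((8, 3, 5, 5))           # PyTorch Conv2d: (out_ch, in_ch, kH, kW)
conv_w_tf = conv_w_pt.transpose(2, 3, 1, 0)  # TF/Flax Conv: (kH, kW, in_ch, out_ch)

dense_w_pt = np.zeros((16, 32))              # PyTorch Linear: (out_features, in_features)
dense_w_tf = dense_w_pt.T                    # Keras Dense / Flax Linear: (in, out)
```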

ivy.stateful.utilities.transpose_weights_tf_jax_to_pt(layer, params_np, transpose_weights, fw)[source]#

Transpose weights from TensorFlow/JAX to PyTorch format.

Args:

- layer: The layer object.
- params_np: The weights in NumPy format.
- transpose_weights: Flag to enable weight transposition.
- fw: The source framework (TensorFlow or JAX).

Returns:

- The transposed weights.
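This direction undoes the same layout change; a matching sketch with hypothetical arrays:

```python
import numpy as np

conv_w_tf = np.zeros((5, 5, 3, 8))           # TF/Flax Conv: (kH, kW, in_ch, out_ch)
conv_w_pt = conv_w_tf.transpose(3, 2, 0, 1)  # PyTorch Conv2d: (out_ch, in_ch, kH, kW)

dense_w_tf = np.zeros((32, 16))              # Keras Dense / Flax Linear: (in, out)
dense_w_pt = dense_w_tf.T                    # PyTorch Linear: (out_features, in_features)
```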

This should have given you an overview of the utilities submodule. If you have any questions, please feel free to reach out on our Discord!