🔀 Merge pull request #6 from alvarobartt/extend-usage
⚡️ Overall improvements, bug fixes, more unit tests, and `dm-haiku` compatibility tested
alvarobartt authored Dec 26, 2022
2 parents 1ecc777 + 9a0b658 commit aaee170
Showing 20 changed files with 564 additions and 262 deletions.
53 changes: 41 additions & 12 deletions README.md
@@ -1,9 +1,17 @@
# 🔐 Serialize JAX/Flax models with `safetensors`
# 🔐 Serialize JAX, Flax, or Haiku model params with `safetensors`

`safejax` is a Python package to serialize JAX and Flax models using `safetensors`
`safejax` is a Python package to serialize JAX, Flax, or Haiku model params using `safetensors`
as the tensor storage format, instead of relying on `pickle`. For more details on why
`safetensors` is safer than `pickle`, please check https://github.com/huggingface/safetensors.

Note that `safejax` supports the serialization of `jax`, `flax`, and `dm-haiku` model
parameters and has been tested with all of those frameworks. Support for `objax` is
still pending, since the `VarCollection` it uses internally to store the tensors in
memory follows a different naming convention, e.g. `(EfficientNet).stem(ConvBnAct).conv(Conv2d).w`
instead of `params.stem.conv.w`, as the former can be more useful when debugging.
`objax` does provide some built-in rename functionality for loading weights from
other frameworks, but supporting that in `safejax` is still a work in progress.

## 🛠️ Requirements & Installation

`safejax` requires Python 3.7 or above
@@ -14,12 +22,16 @@ pip install safejax --upgrade

## 💻 Usage

Let's create a `flax` model using the Linen API and, once initialized, save the
model params with `safejax` (using the `safetensors` storage format).

```python
import jax
from flax import linen as nn
from jax import numpy as jnp

from safejax.flax import serialize
from safejax import serialize


class SingleLayerModel(nn.Module):
@@ -36,21 +48,38 @@ model = SingleLayerModel(features=1)
rng = jax.random.PRNGKey(0)
params = model.init(rng, jnp.ones((1, 1)))

serialized = serialize(params=params)
assert isinstance(serialized, bytes)
assert len(serialized) > 0
serialized_params = serialize(params=params)
```

Those params can later be loaded back using `safejax.deserialize` and used
to run inference with the model.

```python
from safejax import deserialize

params = deserialize(path_or_buf=serialized_params, freeze_dict=True)
```

And, finally, run the inference as follows:

```python
x = jnp.ones((1, 1))
y = model.apply(params, x)
```
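
Since `serialize` returns raw `bytes`, persisting them is plain Python file I/O. A
minimal sketch, assuming a `params.safetensors` filename (the name is just for
illustration):

```python
# Write the serialized params to disk...
with open("params.safetensors", "wb") as f:
    f.write(serialized_params)

# ...and later read them back before deserializing.
with open("params.safetensors", "rb") as f:
    params = deserialize(path_or_buf=f.read(), freeze_dict=True)
```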

More examples can be found at [`examples/`](./examples).
More detailed examples can be found at [`examples/`](./examples) for both `flax` and `dm-haiku`.

## 🤔 Why `safejax`?

`safetensors` defines an easy and fast (zero-copy) format to store tensors,
while `pickle` has some known weaknesses and security issues. `safetensors`
is also a storage format that is intended to be agnostic to the framework
used to load the tensors. More in depth information can be found at
used to load the tensors. More in-depth information can be found at
https://github.com/huggingface/safetensors.

Both `jax` and `haiku` use `pytrees` to store the model parameters in memory, i.e.
a dictionary-like structure containing nested `jnp.DeviceArray` tensors.
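
For illustration, a minimal sketch of such a pytree and how it can be traversed with
`jax.tree_util` (the layout below is an assumed, typical `params` structure, not tied
to any specific model):

```python
import jax
from jax import numpy as jnp

# A typical params pytree: nested dicts with tensors at the leaves.
params = {"dense": {"kernel": jnp.ones((1, 1)), "bias": jnp.zeros((1,))}}

# Map over the leaves, e.g. to inspect their shapes.
shapes = jax.tree_util.tree_map(lambda leaf: leaf.shape, params)
print(shapes)  # {'dense': {'bias': (1,), 'kernel': (1, 1)}}
```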

`flax` defines a dictionary-like class named `FrozenDict` that is used to
store the tensors in memory; it can be dumped either into `bytes` in `MessagePack`
format or into a `state_dict`.
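
For reference, a minimal sketch of those two built-in dump formats from
`flax.serialization`, assuming `params` is the `FrozenDict` returned by `.init`
in the usage example above:

```python
from flax.serialization import to_bytes, to_state_dict

# Dump the FrozenDict into MessagePack-encoded bytes...
msgpack_bytes = to_bytes(params)

# ...or into a plain nested dict of tensors (a `state_dict`).
state_dict = to_state_dict(params)
```
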
@@ -60,21 +89,21 @@ there are no plans from HuggingFace to extend `safetensors` to support anything
more than tensors, e.g. `FrozenDict`s; see their response at
https://github.com/huggingface/safetensors/discussions/138.

So `safejax` was created so as to easily provide a way to serialize `FrozenDict`s
So `safejax` was created to easily provide a way to serialize `FrozenDict`s
using `safetensors` as the tensor storage format instead of `pickle`.

### 📄 Main differences with `flax.serialization`

* `flax.serialization.to_bytes` uses `MessagePack` as the tensor storage format, while
`safejax.flax.serialize` uses `safetensors`
`safejax.serialize` uses `safetensors`
* `flax.serialization.from_bytes` requires the `target` to be instantiated, while
`safejax.flax.deserialize` just needs the encoded bytes
`safejax.deserialize` just needs the encoded bytes, as sketched below
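
A minimal sketch contrasting both round-trips, assuming the `params` from the
usage example above:

```python
from flax.serialization import from_bytes, to_bytes
from safejax import deserialize, serialize

# flax: `from_bytes` needs an instantiated `target` to restore into.
flax_params = from_bytes(params, to_bytes(params))

# safejax: `deserialize` only needs the encoded bytes.
safejax_params = deserialize(path_or_buf=serialize(params=params), freeze_dict=True)
```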

## 🏋🏼 Benchmark

Benchmarks are no longer run with [`hyperfine`](https://github.com/sharkdp/hyperfine),
as most of the elapsed time is not during the actual serialization but in the imports and
in the model parameter initialization. So we've refactored those so as to run with pure
the model parameter initialization. So we've refactored those to run with pure
Python code using `time.perf_counter` to measure the elapsed time in seconds.
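
A minimal sketch of that measurement pattern, assuming `params` have already been
initialized as in the usage example above:

```python
import time

from safejax import serialize

# Time just the serialization itself, excluding imports and model init.
start = time.perf_counter()
encoded_bytes = serialize(params=params)
print(f"serialization took {time.perf_counter() - start:.6f}s")
```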

2 changes: 1 addition & 1 deletion benchmarks/hyperfine/resnet50.py
@@ -5,7 +5,7 @@
from flaxmodels.resnet import ResNet50
from jax import numpy as jnp

from safejax.flax import serialize
from safejax import serialize

resnet50 = ResNet50()
params = resnet50.init(jax.random.PRNGKey(42), jnp.ones((1, 224, 224, 3)))
2 changes: 1 addition & 1 deletion benchmarks/hyperfine/single_layer.py
@@ -5,7 +5,7 @@
from flax import linen as nn
from flax.serialization import to_bytes

from safejax.flax import serialize
from safejax import serialize


class SingleLayerModel(nn.Module):
2 changes: 1 addition & 1 deletion benchmarks/resnet50.py
@@ -5,7 +5,7 @@
from flaxmodels.resnet import ResNet50
from jax import numpy as jnp

from safejax.flax import serialize
from safejax import serialize

resnet50 = ResNet50()
params = resnet50.init(jax.random.PRNGKey(42), jnp.ones((1, 224, 224, 3)))
2 changes: 1 addition & 1 deletion benchmarks/single_layer.py
@@ -5,7 +5,7 @@
from flax.serialization import to_bytes
from jax import numpy as jnp

from safejax.flax import serialize
from safejax import serialize


class SingleLayerModel(nn.Module):
130 changes: 130 additions & 0 deletions examples/README.md
@@ -0,0 +1,130 @@
# 💻 Examples

Here you will find some detailed examples of how to use `safejax` to serialize
model parameters, as opposed to the default way of storing them, which relies on
`pickle` rather than `safetensors` as the tensor storage format.

## Flax - [`flax_ft_safejax.py`](./flax_ft_safejax.py)

To run this Python script you won't need to install anything other than
`safejax`, as both `jax` and `flax` are installed as dependencies of it.

In this case, a single-layer model will be created. For now, `flax`
doesn't provide any pre-defined architectures such as ResNet, but you can use
[`flaxmodels`](https://github.com/matthias-wright/flaxmodels) for that, as
it defines some well-known architectures written in `flax`.

```python
from flax import linen as nn

class SingleLayerModel(nn.Module):
features: int

@nn.compact
def __call__(self, x):
x = nn.Dense(features=self.features)(x)
return x
```

Once the network has been defined, we can instantiate and initialize it
to retrieve the `params` from the forward pass performed during
`.init`.

```python
import jax
from jax import numpy as jnp

network = SingleLayerModel(features=1)

rng_key = jax.random.PRNGKey(seed=0)
initial_params = network.init(rng_key, jnp.ones((1, 1)))
```

Right after getting the `params` from the `.init` method's output, we can
use `safejax.serialize` to encode them using `safetensors`; they can later
be loaded back using `safejax.deserialize`.

```python
from safejax import deserialize, serialize

encoded_bytes = serialize(params=initial_params)
decoded_params = deserialize(path_or_buf=encoded_bytes, freeze_dict=True)
```

As seen in the code above, we're passing `freeze_dict=True` since its default
value is `False`, because we want to freeze the `dict` with the params before
`safejax.deserialize` returns it; this transforms the plain `dict` into a
`FrozenDict`.
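
To illustrate what that flag implies, here's a minimal sketch using
`flax.core.frozen_dict.freeze`; whether `safejax` calls `freeze` exactly like this
internally is an assumption:

```python
from flax.core.frozen_dict import FrozenDict, freeze

# With freeze_dict=False (the default), a plain nested dict comes back...
plain_params = deserialize(path_or_buf=encoded_bytes)
assert isinstance(plain_params, dict)

# ...which `freeze` turns into an immutable FrozenDict.
frozen_params = freeze(plain_params)
assert isinstance(frozen_params, FrozenDict)
```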

Finally, we can use those `decoded_params` to run a forward pass
with the previously defined single-layer network.

```python
x = jnp.ones((1, 1))
y = network.apply(decoded_params, x)
```


## Haiku - [`haiku_ft_safejax.py`](./haiku_ft_safejax.py)

To run this Python script you'll need to have both `safejax` and [`dm-haiku`](https://github.com/deepmind/dm-haiku)
installed.

A ResNet50 architecture will be used from `haiku.nets.imagenet.resnet`, and since
the purpose of the example is to show the integration between `dm-haiku` and
`safejax`, we won't use pre-trained weights.

If you're not familiar with `dm-haiku`, please visit [Haiku Basics](https://dm-haiku.readthedocs.io/en/latest/notebooks/basics.html).

First of all, let's create the network instance for the ResNet50 using `dm-haiku`
with the following code:

```python
import haiku as hk
from jax import numpy as jnp

def resnet_fn(x: jnp.DeviceArray, is_training: bool):
resnet = hk.nets.ResNet50(num_classes=10)
return resnet(x, is_training=is_training)

network = hk.without_apply_rng(hk.transform_with_state(resnet_fn))
```

Some notes on the code above:
* `haiku.nets.ResNet50` requires `num_classes` as a mandatory parameter
* `haiku.nets.ResNet50.__call__` requires `is_training` as a mandatory parameter
* It needs to be initialized with `hk.transform_with_state` as we want to preserve
the state e.g. ExponentialMovingAverage in BatchNorm. More information at https://dm-haiku.readthedocs.io/en/latest/api.html#transform-with-state.
* Using `hk.without_apply_rng` removes the `rng` arg in the `.apply` function, as sketched right after this list. More information at https://dm-haiku.readthedocs.io/en/latest/api.html#without-apply-rng.
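
A minimal sketch of the difference in the resulting `.apply` signature, assuming
`params`, `state`, and an input `x` are already defined:

```python
import jax

# Without `hk.without_apply_rng`, `.apply` expects an rng key after the state
# (`resnet_fn` is the function defined in the snippet above):
with_rng = hk.transform_with_state(resnet_fn)
y, _ = with_rng.apply(params, state, jax.random.PRNGKey(0), x, is_training=False)

# With `hk.without_apply_rng`, the `rng` arg is dropped altogether:
without_rng = hk.without_apply_rng(hk.transform_with_state(resnet_fn))
y, _ = without_rng.apply(params, state, x, is_training=False)
```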

Then we just initialize the network to retrieve both the `params` and the `state`,
which, again, are random.

```python
import jax

rng_key = jax.random.PRNGKey(seed=0)
initial_params, initial_state = network.init(
rng_key, jnp.ones([1, 224, 224, 3]), is_training=True
)
```

Now, once we have the `params`, we can use `safejax.serialize` to serialize them
with `safetensors` as the tensor storage format; they can later be loaded back
using `safejax.deserialize` and used for the network's inference.

```python
from safejax import deserialize, serialize

encoded_bytes = serialize(params=initial_params)
decoded_params = deserialize(path_or_buf=encoded_bytes)
```

Finally, let's use those `decoded_params` to run inference over the network.

```python
x = jnp.ones([1, 224, 224, 3])
y, _ = network.apply(decoded_params, initial_state, x, is_training=False)
```
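
Note that only the `params` are serialized in this example, while `initial_state` is
kept in memory. Since the haiku `state` is also a nested dict of tensors, the same
round-trip should in principle work for it too; a sketch, not something the example
file exercises:

```python
# Round-trip the state the same way as the params (sketch).
encoded_state = serialize(params=initial_state)
decoded_state = deserialize(path_or_buf=encoded_state)

y, _ = network.apply(decoded_params, decoded_state, x, is_training=False)
```
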
34 changes: 34 additions & 0 deletions examples/flax_ft_safejax.py
@@ -0,0 +1,34 @@
import jax
from flax import linen as nn
from flax.core.frozen_dict import FrozenDict
from jax import numpy as jnp

from safejax import deserialize, serialize


class SingleLayerModel(nn.Module):
features: int

@nn.compact
def __call__(self, x):
x = nn.Dense(features=self.features)(x)
return x


network = SingleLayerModel(features=1)

rng_key = jax.random.PRNGKey(seed=0)
initial_params = network.init(rng_key, jnp.ones((1, 1)))

encoded_bytes = serialize(params=initial_params)
assert isinstance(encoded_bytes, bytes)
assert len(encoded_bytes) > 0

decoded_params = deserialize(path_or_buf=encoded_bytes, freeze_dict=True)
assert isinstance(decoded_params, FrozenDict)
assert len(decoded_params) > 0
assert decoded_params.keys() == initial_params.keys()

x = jnp.ones((1, 1))
y = network.apply(decoded_params, x)
assert y.shape == (1, 1)
31 changes: 31 additions & 0 deletions examples/haiku_ft_safejax.py
@@ -0,0 +1,31 @@
import haiku as hk
import jax
from jax import numpy as jnp

from safejax import deserialize, serialize


def resnet_fn(x: jnp.DeviceArray, is_training: bool):
resnet = hk.nets.ResNet50(num_classes=10)
return resnet(x, is_training=is_training)


network = hk.without_apply_rng(hk.transform_with_state(resnet_fn))

rng_key = jax.random.PRNGKey(seed=0)
initial_params, initial_state = network.init(
rng_key, jnp.ones([1, 224, 224, 3]), is_training=True
)

encoded_bytes = serialize(params=initial_params)
assert isinstance(encoded_bytes, bytes)
assert len(encoded_bytes) > 0

decoded_params = deserialize(path_or_buf=encoded_bytes)
assert isinstance(decoded_params, dict)
assert len(decoded_params) > 0
assert decoded_params.keys() == initial_params.keys()

x = jnp.ones([1, 224, 224, 3])
y, _ = network.apply(decoded_params, initial_state, x, is_training=False)
assert y.shape == (1, 10)
27 changes: 0 additions & 27 deletions examples/serialization_with_flax.py

This file was deleted.

29 changes: 0 additions & 29 deletions examples/serialization_with_safejax.py

This file was deleted.
