diff --git a/README.md b/README.md index ec7c9b2..d75f0e0 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,17 @@ -# 🔐 Serialize JAX/Flax models with `safetensors` +# 🔐 Serialize JAX, Flax, or Haiku model params with `safetensors` -`safejax` is a Python package to serialize JAX and Flax models using `safetensors` +`safejax` is a Python package to serialize JAX, Flax, or Haiku model params using `safetensors` as the tensor storage format, instead of relying on `pickle`. For more details on why `safetensors` is safer than `pickle` please check https://github.com/huggingface/safetensors. +Note that `safejax` supports the serialization of `jax`, `flax`, and `dm-haiku` model +parameters and has been tested with all those frameworks. Anyway, `objax` is still pending +as the `VarCollection` that it uses internally to store the tensors in memory is restricted +to another naming convention e.g. `(EfficientNet).stem(ConvBnAct).conv(Conv2d).w` +instead of `params.stem.conv.w` because the first can be more useful when debugging, +even though there's some built-in rename functionality to allow loading weights from +other frameworks, but that's still WIP in `safejax`. + ## 🛠️ Requirements & Installation `safejax` requires Python 3.7 or above @@ -14,12 +22,16 @@ pip install safejax --upgrade ## 💻 Usage +Let's create a `flax` model using the Linen API and once initialized, +we can save the model params with `safejax` (using `safetensors` +storage format). + ```python import jax from flax import linen as nn from jax import numpy as jnp -from safejax.flax import serialize +from safejax import serialize class SingleLayerModel(nn.Module): @@ -36,21 +48,38 @@ model = SingleLayerModel(features=1) rng = jax.random.PRNGKey(0) params = model.init(rng, jnp.ones((1, 1))) -serialized = serialize(params=params) -assert isinstance(serialized, bytes) -assert len(serialized) > 0 +serialized_params = serialize(params=params) +``` + +Those params can be later loaded using `safejax.deserialize` and used +to run the inference over the model using those weights. + +```python +from safejax import deserialize + +params = deserialize(path_or_buf=serialized_params, freeze_dict=True) +``` + +And, finally, running the inference as: + +```python +x = jnp.ones((1, 28, 28, 1)) +y = model.apply(params, x) ``` -More examples can be found at [`examples/`](./examples). +More in-detail examples can be found at [`examples/`](./examples) for both `flax` and `dm-haiku`. ## 🤔 Why `safejax`? `safetensors` defines an easy and fast (zero-copy) format to store tensors, while `pickle` has some known weaknesses and security issues. `safetensors` is also a storage format that is intended to be trivial to the framework -used to load the tensors. More in depth information can be found at +used to load the tensors. More in-depth information can be found at https://github.com/huggingface/safetensors. +Both `jax` and `haiku` use `pytrees` to store the model parameters in memory, so +it's a dictionary-like class containing nested `jnp.DeviceArray` tensors. + `flax` defines a dictionary-like class named `FrozenDict` that is used to store the tensors in memory, it can be dumped either into `bytes` in `MessagePack` format or as a `state_dict`. @@ -60,21 +89,21 @@ there are no plans from HuggingFace to extend `safetensors` to support anything more than tensors e.g. `FrozenDict`s, see their response at https://github.com/huggingface/safetensors/discussions/138. -So `safejax` was created so as to easily provide a way to serialize `FrozenDict`s +So `safejax` was created to easily provide a way to serialize `FrozenDict`s using `safetensors` as the tensor storage format instead of `pickle`. ### 📄 Main differences with `flax.serialization` * `flax.serialization.to_bytes` uses `pickle` as the tensor storage format, while -`safejax.flax.serialize` uses `safetensors` +`safejax.serialize` uses `safetensors` * `flax.serialization.from_bytes` requires the `target` to be instantiated, while -`safejax.flax.deserialize` just needs the encoded bytes +`safejax.deserialize` just needs the encoded bytes ## 🏋🏼 Benchmark Benchmarks are no longer running with [`hyperfine`](https://github.com/sharkdp/hyperfine), as most of the elapsed time is not during the actual serialization but in the imports and -in the model parameter initialization. So we've refactored those so as to run with pure +the model parameter initialization. So we've refactored those to run with pure Python code using `time.perf_counter` to measure the elapsed time in seconds. ```bash diff --git a/benchmarks/hyperfine/resnet50.py b/benchmarks/hyperfine/resnet50.py index 038655d..34d0962 100644 --- a/benchmarks/hyperfine/resnet50.py +++ b/benchmarks/hyperfine/resnet50.py @@ -5,7 +5,7 @@ from flaxmodels.resnet import ResNet50 from jax import numpy as jnp -from safejax.flax import serialize +from safejax import serialize resnet50 = ResNet50() params = resnet50.init(jax.random.PRNGKey(42), jnp.ones((1, 224, 224, 3))) diff --git a/benchmarks/hyperfine/single_layer.py b/benchmarks/hyperfine/single_layer.py index c58f428..f385b9c 100644 --- a/benchmarks/hyperfine/single_layer.py +++ b/benchmarks/hyperfine/single_layer.py @@ -5,7 +5,7 @@ from flax import linen as nn from flax.serialization import to_bytes -from safejax.flax import serialize +from safejax import serialize class SingleLayerModel(nn.Module): diff --git a/benchmarks/resnet50.py b/benchmarks/resnet50.py index 77dd874..2089519 100644 --- a/benchmarks/resnet50.py +++ b/benchmarks/resnet50.py @@ -5,7 +5,7 @@ from flaxmodels.resnet import ResNet50 from jax import numpy as jnp -from safejax.flax import serialize +from safejax import serialize resnet50 = ResNet50() params = resnet50.init(jax.random.PRNGKey(42), jnp.ones((1, 224, 224, 3))) diff --git a/benchmarks/single_layer.py b/benchmarks/single_layer.py index a1f1e0d..124a9b0 100644 --- a/benchmarks/single_layer.py +++ b/benchmarks/single_layer.py @@ -5,7 +5,7 @@ from flax.serialization import to_bytes from jax import numpy as jnp -from safejax.flax import serialize +from safejax import serialize class SingleLayerModel(nn.Module): diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000..d862eb7 --- /dev/null +++ b/examples/README.md @@ -0,0 +1,130 @@ +# 💻 Examples + +Here you will find some detailed examples of how to use `safejax` to serialize +model parameters, in opposition to the default way to store those, which uses +`pickle` as the format to store the tensors instead of `safetensors`. + +## Flax - [`flax_ft_safejax`](./examples/flax_ft_safejax.py) + +To run this Python script you won't need to install anything else than +`safejax`, as both `jax` and `flax` are installed as part of it. + +In this case, a single-layer model will be created, for now, `flax` +doesn't have any pre-defined architecture such as ResNet, but you can use +[`flaxmodels`](https://github.com/matthias-wright/flaxmodels) for that, as +it defines some well-known architectures written in `flax`. + +```python +import jax +from flax import linen as nn + +class SingleLayerModel(nn.Module): + features: int + + @nn.compact + def __call__(self, x): + x = nn.Dense(features=self.features)(x) + return x +``` + +Once the network has been defined, we can instantiate and initialize it, +to retrieve the `params` out of the forward pass performed during +`.init`. + +```python +import jax +from jax import numpy as jnp + +network = SingleLayerModel(features=1) + +rng_key = jax.random.PRNGKey(seed=0) +initial_params = network.init(rng_key, jnp.ones((1, 1))) +``` + +Right after getting the `params` from the `.init` method's output, we can +use `safejax.serialize` to encode those using `safetensors`, that later on +can be loaded back using `safejax.deserialize`. + +```python +from safejax import deserialize, serialize + +encoded_bytes = serialize(params=initial_params) +decoded_params = deserialize(path_or_buf=encoded_bytes, freeze_dict=True) +``` + +As seen in the code above, we're using `freeze_dict=True` since its default +value is False, as we want to freeze the `dict` with the params before actually +returning it during `safejax.deserialize`, this transforms the `Dict` +into a `FrozenDict`. + +Finally, we can use those `decoded_params` to run a forward pass +with the previously defined single-layer network. + +```python +x = jnp.ones((1, 1)) +y = network.apply(decoded_params, x) +``` + + +## Haiku - [`haiku_ft_safejax.py`](./examples/haiku_ft_safejax.py) + +To run this Python script you'll need to have both `safejax` and [`dm-haiku`](https://github.com/deepmind/dm-haiku) +installed. + +A ResNet50 architecture will be used from `haiku.nets.imagenet.resnet` and since +the purpose of the example is to show the integration of both `dm-haiku` and +`safejax`, we won't use pre-trained weights. + +If you're not familiar with `dm-haiku`, please visit [Haiku Basics](https://dm-haiku.readthedocs.io/en/latest/notebooks/basics.html). + +First of all, let's create the network instance for the ResNet50 using `dm-haiku` +with the following code: + +```python +import haiku as hk +from jax import numpy as jnp + +def resnet_fn(x: jnp.DeviceArray, is_training: bool): + resnet = hk.nets.ResNet50(num_classes=10) + return resnet(x, is_training=is_training) + +network = hk.without_apply_rng(hk.transform_with_state(resnet_fn)) +``` + +Some notes on the code above: +* `haiku.nets.ResNet50` requires `num_classes` as a mandatory parameter +* `haiku.nets.ResNet50.__call__` requires `is_training` as a mandatory parameter +* It needs to be initialized with `hk.transform_with_state` as we want to preserve +the state e.g. ExponentialMovingAverage in BatchNorm. More information at https://dm-haiku.readthedocs.io/en/latest/api.html#transform-with-state. +* Using `hk.without_apply_rng` removes the `rng` arg in the `.apply` function. More information at https://dm-haiku.readthedocs.io/en/latest/api.html#without-apply-rng. + +Then we just initialize the network to retrieve both the `params` and the `state`, +which again, are random. + +```python +import jax + +rng_key = jax.random.PRNGKey(seed=0) +initial_params, initial_state = network.init( + rng_key, jnp.ones([1, 224, 224, 3]), is_training=True +) +``` + +Now once we have the `params`, we can import `safejax.serialize` to serialize the +params using `safetensors` as the tensor storage format, that later on can be loaded +back using `safejax.deserialize` and used for the network's inference. + +```python +from safejax import deserialize, serialize + +encoded_bytes = serialize(params=initial_params) +decoded_params = deserialize(path_or_buf=encoded_bytes) +``` + +Finally, let's just use those `decoded_params` to run the inference over the network +using those weights. + +```python +x = jnp.ones([1, 224, 224, 3]) +y, _ = network.apply(decoded_params, initial_state, x, is_training=False) +``` diff --git a/examples/flax_ft_safejax.py b/examples/flax_ft_safejax.py new file mode 100644 index 0000000..5696afe --- /dev/null +++ b/examples/flax_ft_safejax.py @@ -0,0 +1,34 @@ +import jax +from flax import linen as nn +from flax.core.frozen_dict import FrozenDict +from jax import numpy as jnp + +from safejax import deserialize, serialize + + +class SingleLayerModel(nn.Module): + features: int + + @nn.compact + def __call__(self, x): + x = nn.Dense(features=self.features)(x) + return x + + +network = SingleLayerModel(features=1) + +rng_key = jax.random.PRNGKey(seed=0) +initial_params = network.init(rng_key, jnp.ones((1, 1))) + +encoded_bytes = serialize(params=initial_params) +assert isinstance(encoded_bytes, bytes) +assert len(encoded_bytes) > 0 + +decoded_params = deserialize(path_or_buf=encoded_bytes, freeze_dict=True) +assert isinstance(decoded_params, FrozenDict) +assert len(decoded_params) > 0 +assert decoded_params.keys() == initial_params.keys() + +x = jnp.ones((1, 1)) +y = network.apply(decoded_params, x) +assert y.shape == (1, 1) diff --git a/examples/haiku_ft_safejax.py b/examples/haiku_ft_safejax.py new file mode 100644 index 0000000..7f14e36 --- /dev/null +++ b/examples/haiku_ft_safejax.py @@ -0,0 +1,31 @@ +import haiku as hk +import jax +from jax import numpy as jnp + +from safejax import deserialize, serialize + + +def resnet_fn(x: jnp.DeviceArray, is_training: bool): + resnet = hk.nets.ResNet50(num_classes=10) + return resnet(x, is_training=is_training) + + +network = hk.without_apply_rng(hk.transform_with_state(resnet_fn)) + +rng_key = jax.random.PRNGKey(seed=0) +initial_params, initial_state = network.init( + rng_key, jnp.ones([1, 224, 224, 3]), is_training=True +) + +encoded_bytes = serialize(params=initial_params) +assert isinstance(encoded_bytes, bytes) +assert len(encoded_bytes) > 0 + +decoded_params = deserialize(path_or_buf=encoded_bytes) +assert isinstance(decoded_params, dict) +assert len(decoded_params) > 0 +assert decoded_params.keys() == initial_params.keys() + +x = jnp.ones([1, 224, 224, 3]) +y, _ = network.apply(decoded_params, initial_state, x, is_training=False) +assert y.shape == (1, 10) diff --git a/examples/serialization_with_flax.py b/examples/serialization_with_flax.py deleted file mode 100644 index 2d63b3c..0000000 --- a/examples/serialization_with_flax.py +++ /dev/null @@ -1,27 +0,0 @@ -import jax -from flax import linen as nn -from flax.serialization import from_bytes, to_bytes -from jax import numpy as jnp - - -class SingleLayerModel(nn.Module): - features: int - - @nn.compact - def __call__(self, x): - x = nn.Dense(features=self.features)(x) - return x - - -model = SingleLayerModel(features=1) - -rng = jax.random.PRNGKey(0) -params = model.init(rng, jnp.ones((1, 1))) - -serialized = to_bytes(target=params) -assert isinstance(serialized, bytes) -assert len(serialized) > 0 - -deserialized = from_bytes(target=model, encoded_bytes=serialized) -assert isinstance(deserialized, dict) -assert len(deserialized) > 0 diff --git a/examples/serialization_with_safejax.py b/examples/serialization_with_safejax.py deleted file mode 100644 index 1b70b2e..0000000 --- a/examples/serialization_with_safejax.py +++ /dev/null @@ -1,29 +0,0 @@ -import jax -from flax import linen as nn -from flax.core.frozen_dict import FrozenDict -from jax import numpy as jnp - -from safejax.flax import deserialize, serialize - - -class SingleLayerModel(nn.Module): - features: int - - @nn.compact - def __call__(self, x): - x = nn.Dense(features=self.features)(x) - return x - - -model = SingleLayerModel(features=1) - -rng = jax.random.PRNGKey(0) -params = model.init(rng, jnp.ones((1, 1))) - -serialized = serialize(params=params) -assert isinstance(serialized, bytes) -assert len(serialized) > 0 - -deserialized = deserialize(path_or_buf=serialized) -assert isinstance(deserialized, FrozenDict) -assert len(deserialized) > 0 diff --git a/pyproject.toml b/pyproject.toml index 19e6a4d..1708d45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ build-backend = "hatchling.build" requires = ["hatchling"] [project] -authors = [{name = "Alvaro Bartolome"}] +authors = [{name = "Alvaro Bartolome", email = "alvarobartt@yahoo.com"}] classifiers = [ "Development Status :: 4 - Beta", "Programming Language :: Python", @@ -20,7 +20,7 @@ dependencies = [ "flax~=0.6.2", "safetensors~=0.2.5", ] -description = "Serialize JAX/Flax models with `safetensors`" +description = "Serialize JAX, Flax, or Haiku model params with `safetensors`" dynamic = ["version"] keywords = [] license = "MIT" @@ -50,7 +50,7 @@ tests = [ [tool.hatch.envs.quality] features = [ - "quality" + "quality", ] [tool.hatch.envs.quality.scripts] @@ -68,18 +68,18 @@ format = [ profile = "black" [tool.ruff] -select = [ - "E", # pycodestyle errors - "W", # pycodestyle warnings - "F", # pyflakes - "I", # isort - "C", # flake8-comprehensions - "B", # flake8-bugbear -] ignore = [ - "E501", # line too long, handled by black - "B008", # do not perform function calls in argument defaults - "C901", # too complex + "E501", # line too long, handled by black + "B008", # do not perform function calls in argument defaults + "C901", # too complex +] +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # pyflakes + "I", # isort + "C", # flake8-comprehensions + "B", # flake8-bugbear ] [tool.ruff.isort] @@ -87,7 +87,7 @@ known-first-party = ["safejax"] [tool.hatch.envs.test] features = [ - "tests" + "tests", ] [tool.hatch.envs.test.scripts] diff --git a/src/safejax/__init__.py b/src/safejax/__init__.py index e1e898a..a35c4c6 100644 --- a/src/safejax/__init__.py +++ b/src/safejax/__init__.py @@ -1,4 +1,7 @@ -"""`safejax `: Serialize JAX/Flax models with `safetensors`""" +"""`safejax `: Serialize JAX, Flax, or Haiku model params with `safetensors`""" __author__ = "Alvaro Bartolome " -__version__ = "0.1.1" +__version__ = "0.2.0" + +from safejax.load import deserialize # noqa: F401 +from safejax.save import serialize # noqa: F401 diff --git a/src/safejax/flax.py b/src/safejax/flax.py deleted file mode 100644 index 2b19cbf..0000000 --- a/src/safejax/flax.py +++ /dev/null @@ -1,111 +0,0 @@ -import os -from pathlib import Path -from typing import Any, Dict, Union - -import numpy as np -from flax.core.frozen_dict import FrozenDict, freeze -from jax import numpy as jnp -from safetensors.flax import load, load_file, save, save_file - -from safejax.typing import PathLike - -__all__ = ["serialize", "deserialize"] - - -def flatten_dict( - params: Union[Dict[str, Any], FrozenDict], - key_prefix: Union[str, None] = None, -) -> Union[Dict[str, jnp.DeviceArray], Dict[str, np.ndarray]]: - """ - Flatten a `FrozenDict` or a `Dict` containing Flax model parameters. - - Note: - This function is recursive to explore all the nested dictionaries, - and the keys are being flattened using the `.` character. So that the - later de-nesting can be done using the `.` character as a separator. - - Reference at https://gist.github.com/Narsil/d5b0d747e5c8c299eb6d82709e480e3d - - Args: - params: A `FrozenDict` or a `Dict` containing the model parameters. - key_prefix: A prefix to prepend to the keys of the flattened dictionary. - - Returns: - A flattened dictionary containing the model parameters. - """ - weights = {} - for key, value in params.items(): - key = f"{key_prefix}.{key}" if key_prefix else key - if isinstance(value, jnp.DeviceArray) or isinstance(value, np.ndarray): - weights[key] = value - continue - if isinstance(value, FrozenDict) or isinstance(value, Dict): - weights.update(flatten_dict(params=value, key_prefix=key)) - return weights - - -def serialize( - params: Union[Dict[str, Any], FrozenDict], - filename: Union[PathLike, None] = None, -) -> Union[bytes, PathLike]: - """ - Serialize a Flax model from either a `FrozenDict` or a `Dict`. - - If `filename` is not provided, the serialized model is returned as a `bytes` object, - otherwise the model is saved to the provided `filename` and the `filename` is returned. - - Args: - params: A `FrozenDict` or a `Dict` containing the model parameters. - filename: The path to the file where the model will be saved. - - Returns: - The serialized model as a `bytes` object or the path to the file where the model was saved. - """ - flattened_dict = flatten_dict(params=params) - if not filename: - return save(tensors=flattened_dict) - else: - save_file(tensors=flattened_dict, filename=filename) - return filename - - -def unflatten_dict(tensors: Dict[str, jnp.DeviceArray]) -> FrozenDict: - """ - Unflatten a `FrozenDict` from a `Dict` of tensors. - - Reference at https://stackoverflow.com/a/63545677. - - Args: - tensors: A `Dict` of tensors containing the model parameters. - - Returns: - A `FrozenDict` containing the model parameters. - """ - weights = {} - for key, value in tensors.items(): - subkeys = key.split(".") - for subkey in subkeys[:-1]: - weights = weights.setdefault(subkey, {}) - weights[subkeys[-1]] = value - return freeze(weights) - - -def deserialize(path_or_buf: Union[PathLike, bytes]) -> FrozenDict: - """ - Deserialize a Flax model from either a `bytes` object or a file path. - - Args: - path_or_buf: A `bytes` object or a file path containing the serialized model. - - Returns: - A `FrozenDict` containing the model parameters. - """ - if isinstance(path_or_buf, bytes): - loaded_dict = load(data=path_or_buf) - if ( - isinstance(path_or_buf, str) - or isinstance(path_or_buf, Path) - or isinstance(path_or_buf, os.PathLike) - ): - loaded_dict = load_file(filename=path_or_buf) - return unflatten_dict(tensors=loaded_dict) diff --git a/src/safejax/load.py b/src/safejax/load.py new file mode 100644 index 0000000..a40e7f0 --- /dev/null +++ b/src/safejax/load.py @@ -0,0 +1,36 @@ +import os +from pathlib import Path +from typing import Union + +from flax.core.frozen_dict import FrozenDict, freeze +from safetensors.flax import load, load_file + +from safejax.typing import PathLike +from safejax.utils import unflatten_dict + + +def deserialize( + path_or_buf: Union[PathLike, bytes], freeze_dict: bool = False +) -> FrozenDict: + """ + Deserialize a JAX, Haiku, or Flax model params from either a `bytes` object or a file path, + stored using `safetensors.flax.save_file` or directly saved using `safejax.save.serialize` with + the `filename` parameter. + + Args: + path_or_buf: A `bytes` object or a file path containing the serialized model params. + freeze_dict: Whether to freeze the output `Dict` to be a `FrozenDict` or not. + + Returns: + An unflattened `Dict` or `FrozenDict` containing the model params. + """ + if isinstance(path_or_buf, bytes): + decoded_params = load(data=path_or_buf) + if ( + isinstance(path_or_buf, str) + or isinstance(path_or_buf, Path) + or isinstance(path_or_buf, os.PathLike) + ): + decoded_params = load_file(filename=path_or_buf) + decoded_params_dict = unflatten_dict(tensors=decoded_params) + return freeze(decoded_params_dict) if freeze_dict else decoded_params_dict diff --git a/src/safejax/save.py b/src/safejax/save.py new file mode 100644 index 0000000..ec1fe6f --- /dev/null +++ b/src/safejax/save.py @@ -0,0 +1,32 @@ +from typing import Any, Dict, Union + +from flax.core.frozen_dict import FrozenDict +from safetensors.flax import save, save_file + +from safejax.typing import PathLike +from safejax.utils import flatten_dict + + +def serialize( + params: Union[Dict[str, Any], FrozenDict], + filename: Union[PathLike, None] = None, +) -> Union[bytes, PathLike]: + """ + Serialize a JAX/Flax/Haiku model params from either a `FrozenDict` or a `Dict`. + + If `filename` is not provided, the serialized model is returned as a `bytes` object, + otherwise the model is saved to the provided `filename` and the `filename` is returned. + + Args: + params: A `FrozenDict` or a `Dict` containing the model params. + filename: The path to the file where the model will be saved. + + Returns: + The serialized model as a `bytes` object or the path to the file where the model was saved. + """ + flattened_dict = flatten_dict(params=params) + if not filename: + return save(tensors=flattened_dict) + else: + save_file(tensors=flattened_dict, filename=filename) + return filename diff --git a/src/safejax/utils.py b/src/safejax/utils.py new file mode 100644 index 0000000..95d6bba --- /dev/null +++ b/src/safejax/utils.py @@ -0,0 +1,65 @@ +from typing import Any, Dict, Union + +import numpy as np +from flax.core.frozen_dict import FrozenDict +from jax import numpy as jnp + + +def flatten_dict( + params: Union[Dict[str, Any], FrozenDict], + key_prefix: Union[str, None] = None, +) -> Dict[str, Any]: + """ + Flatten a `FrozenDict` or a `Dict` containing either `jnp.DeviceArray` or + `np.ndarray` as values. + + Note: + This function is recursive to explore all the nested dictionaries, + and the keys are being flattened using the `.` character. So that the + later de-nesting can be done using the `.` character as a separator. + + Reference at https://gist.github.com/Narsil/d5b0d747e5c8c299eb6d82709e480e3d + + Args: + params: A `FrozenDict` or a `Dict` with the params to flatten. + key_prefix: A prefix to prepend to the keys of the flattened dictionary. + + Returns: + A `Dict` containing the flattened params. + """ + flattened_params = {} + for key, value in params.items(): + key = f"{key_prefix}.{key}" if key_prefix else key + if isinstance(value, (jnp.DeviceArray, np.ndarray)): + flattened_params[key] = value + continue + if isinstance(value, (Dict, FrozenDict)): + flattened_params.update( + flatten_dict( + params=value, + key_prefix=key, + ) + ) + return flattened_params + + +def unflatten_dict(tensors: Dict[str, Any]) -> Dict[str, Any]: + """ + Unflatten a `Dict` of tensors stored as a flattened dictionary. + + Reference at https://stackoverflow.com/a/63545677. + + Args: + tensors: A `Dict` of tensors stored as a flattened dictionary. + + Returns: + An unflattened `Dict` of tensors. + """ + params = {} + for key, value in tensors.items(): + params_tmp = params + subkeys = key.split(".") + for subkey in subkeys[:-1]: + params_tmp = params_tmp.setdefault(subkey, {}) + params_tmp[subkeys[-1]] = value + return params diff --git a/tests/test_flax.py b/tests/test_flax.py deleted file mode 100644 index 67e44c8..0000000 --- a/tests/test_flax.py +++ /dev/null @@ -1,62 +0,0 @@ -from pathlib import Path - -import pytest -from flax.core.frozen_dict import FrozenDict - -from safejax.flax import deserialize, serialize - - -@pytest.mark.parametrize( - "params", - [ - pytest.lazy_fixture("single_layer_params"), - pytest.lazy_fixture("resnet50_params"), - ], -) -def test_serialize(params: FrozenDict) -> None: - serialized = serialize(params=params) - assert isinstance(serialized, bytes) - assert len(serialized) > 0 - - -@pytest.mark.parametrize( - "params", - [ - pytest.lazy_fixture("single_layer_params"), - pytest.lazy_fixture("resnet50_params"), - ], -) -@pytest.mark.usefixtures("safetensors_file") -def test_serialize_to_file(params: FrozenDict, safetensors_file: Path) -> None: - safetensors_file = serialize(params=params, filename=safetensors_file) - assert isinstance(safetensors_file, Path) - assert safetensors_file.exists() - - -@pytest.mark.parametrize( - "params", - [ - pytest.lazy_fixture("single_layer_params"), - pytest.lazy_fixture("resnet50_params"), - ], -) -def test_deserialize(params: FrozenDict) -> None: - serialized = serialize(params=params) - deserialized = deserialize(path_or_buf=serialized) - assert isinstance(deserialized, FrozenDict) - assert len(deserialized) > 0 - - -@pytest.mark.parametrize( - "params", - [ - pytest.lazy_fixture("single_layer_params"), - pytest.lazy_fixture("resnet50_params"), - ], -) -@pytest.mark.usefixtures("safetensors_file") -def test_deserialize_from_file(params: FrozenDict, safetensors_file: Path) -> None: - safetensors_file = serialize(params=params, filename=safetensors_file) - deserialized = deserialize(path_or_buf=safetensors_file) - assert isinstance(deserialized, FrozenDict) - assert len(deserialized) > 0 diff --git a/tests/test_load.py b/tests/test_load.py new file mode 100644 index 0000000..ad7791c --- /dev/null +++ b/tests/test_load.py @@ -0,0 +1,36 @@ +from pathlib import Path + +import pytest +from flax.core.frozen_dict import FrozenDict + +from safejax.load import deserialize +from safejax.save import serialize + + +@pytest.mark.parametrize( + "params", + [ + pytest.lazy_fixture("single_layer_params"), + pytest.lazy_fixture("resnet50_params"), + ], +) +def test_deserialize(params: FrozenDict) -> None: + encoded_params = serialize(params=params) + decoded_params = deserialize(path_or_buf=encoded_params, freeze_dict=True) + assert isinstance(decoded_params, FrozenDict) + assert len(decoded_params) > 0 + + +@pytest.mark.parametrize( + "params", + [ + pytest.lazy_fixture("single_layer_params"), + pytest.lazy_fixture("resnet50_params"), + ], +) +@pytest.mark.usefixtures("safetensors_file") +def test_deserialize_from_file(params: FrozenDict, safetensors_file: Path) -> None: + safetensors_file = serialize(params=params, filename=safetensors_file) + decoded_params = deserialize(path_or_buf=safetensors_file, freeze_dict=True) + assert isinstance(decoded_params, FrozenDict) + assert len(decoded_params) > 0 diff --git a/tests/test_save.py b/tests/test_save.py new file mode 100644 index 0000000..011478e --- /dev/null +++ b/tests/test_save.py @@ -0,0 +1,33 @@ +from pathlib import Path + +import pytest +from flax.core.frozen_dict import FrozenDict + +from safejax.save import serialize + + +@pytest.mark.parametrize( + "params", + [ + pytest.lazy_fixture("single_layer_params"), + pytest.lazy_fixture("resnet50_params"), + ], +) +def test_serialize(params: FrozenDict) -> None: + encoded_params = serialize(params=params) + assert isinstance(encoded_params, bytes) + assert len(encoded_params) > 0 + + +@pytest.mark.parametrize( + "params", + [ + pytest.lazy_fixture("single_layer_params"), + pytest.lazy_fixture("resnet50_params"), + ], +) +@pytest.mark.usefixtures("safetensors_file") +def test_serialize_to_file(params: FrozenDict, safetensors_file: Path) -> None: + safetensors_file = serialize(params=params, filename=safetensors_file) + assert isinstance(safetensors_file, Path) + assert safetensors_file.exists() diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..b94221d --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,102 @@ +from typing import Any, Dict + +import pytest +from jax import numpy as jnp + +from safejax.utils import flatten_dict, unflatten_dict + + +@pytest.mark.parametrize( + "input_dict, expected_output_dict", + [ + ( + {"a": jnp.zeros(1), "b": jnp.zeros(1)}, + {"a": jnp.zeros(1), "b": jnp.zeros(1)}, + ), + ( + {"a.b": jnp.zeros(1), "b": jnp.zeros(1)}, + {"a": {"b": jnp.zeros(1)}, "b": jnp.zeros(1)}, + ), + ( + {"a.b": jnp.zeros(1), "a.c": jnp.zeros(1), "b": jnp.zeros(1)}, + {"a": {"b": jnp.zeros(1), "c": jnp.zeros(1)}, "b": jnp.zeros(1)}, + ), + ( + { + "a.b.c": jnp.zeros(1), + "a.b.d": jnp.zeros(1), + "a.e": jnp.zeros(1), + "b": jnp.zeros(1), + }, + { + "a": {"b": {"c": jnp.zeros(1), "d": jnp.zeros(1)}, "e": jnp.zeros(1)}, + "b": jnp.zeros(1), + }, + ), + ( + { + "a.b.c": jnp.zeros(1), + "a.b.d": jnp.zeros(1), + "a.e": jnp.zeros(1), + "b": jnp.zeros(1), + "c": jnp.zeros(1), + }, + { + "a": {"b": {"c": jnp.zeros(1), "d": jnp.zeros(1)}, "e": jnp.zeros(1)}, + "b": jnp.zeros(1), + "c": jnp.zeros(1), + }, + ), + ], +) +def test_unflatten_dict( + input_dict: Dict[str, Any], expected_output_dict: Dict[str, Any] +) -> None: + unflattened_dict = unflatten_dict(tensors=input_dict) + assert unflattened_dict == expected_output_dict + + +@pytest.mark.parametrize( + "input_dict, expected_output_dict", + [ + ( + {"a": {"b": jnp.zeros(1)}, "b": jnp.zeros(1)}, + {"a.b": jnp.zeros(1), "b": jnp.zeros(1)}, + ), + ( + {"a": {"b": jnp.zeros(1), "c": jnp.zeros(1)}, "b": jnp.zeros(1)}, + {"a.b": jnp.zeros(1), "a.c": jnp.zeros(1), "b": jnp.zeros(1)}, + ), + ( + { + "a": {"b": {"c": jnp.zeros(1), "d": jnp.zeros(1)}, "e": jnp.zeros(1)}, + "b": jnp.zeros(1), + }, + { + "a.b.c": jnp.zeros(1), + "a.b.d": jnp.zeros(1), + "a.e": jnp.zeros(1), + "b": jnp.zeros(1), + }, + ), + ( + { + "a": {"b": {"c": jnp.zeros(1), "d": jnp.zeros(1)}, "e": jnp.zeros(1)}, + "b": jnp.zeros(1), + "c": jnp.zeros(1), + }, + { + "a.b.c": jnp.zeros(1), + "a.b.d": jnp.zeros(1), + "a.e": jnp.zeros(1), + "b": jnp.zeros(1), + "c": jnp.zeros(1), + }, + ), + ], +) +def test_flatten_dict( + input_dict: Dict[str, Any], expected_output_dict: Dict[str, Any] +) -> None: + flattened_dict = flatten_dict(params=input_dict) + assert flattened_dict == expected_output_dict