Skip to content

Commit

Permalink
feat(flatpak): change flatpak specs to cached properties
Browse files (browse the repository at this point in the history)
  • Loading branch information
aar65537 committed Mar 13, 2024
1 parent 1442e6d commit 063b108
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 7 deletions.
7 changes: 5 additions & 2 deletions jumanji/environments/packing/flat_pack/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import cached_property
from typing import Optional, Sequence, Tuple

import chex
Expand All @@ -34,7 +35,7 @@
from jumanji.viewer import Viewer


class FlatPack(Environment[State]):
class FlatPack(Environment[State, specs.MultiDiscreteArray, Observation]):

"""The FlatPack environment with a configurable number of row and column blocks.
Here the goal of an agent is to completely fill an empty grid by placing all
Expand Down Expand Up @@ -129,6 +130,7 @@ def __init__(
self.viewer = viewer or FlatPackViewer(
"FlatPack", self.num_blocks, render_mode="human"
)
super().__init__()

def __repr__(self) -> str:
return (
Expand All @@ -141,7 +143,6 @@ def reset(
self,
key: chex.PRNGKey,
) -> Tuple[State, TimeStep[Observation]]:

"""Resets the environment.
Args:
Expand Down Expand Up @@ -259,6 +260,7 @@ def close(self) -> None:

self.viewer.close()

@cached_property
def observation_spec(self) -> specs.Spec[Observation]:
"""Returns the observation spec of the environment.
Expand Down Expand Up @@ -307,6 +309,7 @@ def observation_spec(self) -> specs.Spec[Observation]:
action_mask=action_mask,
)

@cached_property
def action_spec(self) -> specs.MultiDiscreteArray:
"""Specifications of the action expected by the `FlatPack` environment.
Expand Down
10 changes: 9 additions & 1 deletion jumanji/environments/packing/flat_pack/env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@
CellDenseReward,
)
from jumanji.environments.packing.flat_pack.types import State
from jumanji.testing.env_not_smoke import check_env_does_not_smoke
from jumanji.testing.env_not_smoke import (
check_env_does_not_smoke,
check_env_specs_does_not_smoke,
)
from jumanji.testing.pytrees import assert_is_jax_array_tree
from jumanji.types import StepType, TimeStep

Expand Down Expand Up @@ -182,6 +185,11 @@ def test_flat_pack__does_not_smoke(flat_pack: FlatPack) -> None:
check_env_does_not_smoke(flat_pack)


def test_flat_pack__specs_does_not_smoke(flat_pack: FlatPack) -> None:
    """Test that we can access specs without any errors.

    Hands the `flat_pack` argument straight to
    `check_env_specs_does_not_smoke`; this test makes no assertions of its
    own, so it fails only if that helper raises.

    Args:
        flat_pack: the `FlatPack` environment under test (presumably supplied
            by a pytest fixture of the same name -- confirm in this module).
    """
    check_env_specs_does_not_smoke(flat_pack)


def test_flat_pack__is_done(flat_pack: FlatPack, key: chex.PRNGKey) -> None:
"""Test that the is_done method works as expected."""

Expand Down
10 changes: 9 additions & 1 deletion jumanji/environments/routing/sokoban/env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@
SimpleSolveGenerator,
)
from jumanji.environments.routing.sokoban.types import State
from jumanji.testing.env_not_smoke import check_env_does_not_smoke
from jumanji.testing.env_not_smoke import (
check_env_does_not_smoke,
check_env_specs_does_not_smoke,
)
from jumanji.types import TimeStep


Expand Down Expand Up @@ -215,3 +218,8 @@ def test_sokoban__reward_function_solved(sokoban_simple: Sokoban) -> None:
def test_sokoban__does_not_smoke(sokoban: Sokoban) -> None:
    """Test that we can run an episode without any errors.

    Hands the `sokoban` argument straight to `check_env_does_not_smoke`; this
    test makes no assertions of its own, so it fails only if that helper
    raises.

    Args:
        sokoban: the `Sokoban` environment under test (presumably supplied by
            a pytest fixture of the same name -- confirm in this module).
    """
    check_env_does_not_smoke(sokoban)


def test_sokoban__specs_does_not_smoke(sokoban: Sokoban) -> None:
    """Test that we can access specs without any errors.

    Hands the `sokoban` argument straight to
    `check_env_specs_does_not_smoke`; this test makes no assertions of its
    own, so it fails only if that helper raises.

    Args:
        sokoban: the `Sokoban` environment under test (presumably supplied by
            a pytest fixture of the same name -- confirm in this module).
    """
    check_env_specs_does_not_smoke(sokoban)
3 changes: 1 addition & 2 deletions jumanji/training/networks/flat_pack/actor_critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def make_actor_critic_networks_flat_pack(
hidden_size: int,
) -> ActorCriticNetworks:
"""Make actor-critic networks for the `FlatPack` environment."""
num_values = np.asarray(flat_pack.action_spec().num_values)
num_values = np.asarray(flat_pack.action_spec.num_values)
parametric_action_distribution = FactorisedActionSpaceParametricDistribution(
action_spec_num_values=num_values
)
Expand Down Expand Up @@ -171,7 +171,6 @@ def __call__(self, observation: Observation) -> Tuple[chex.Array, chex.Array]:
) # (B, model_size), (B, num_rows-2, num_cols-2, hidden_size)

for block_id in range(self.num_transformer_layers):

(
self_attention_mask, # (B, 1, num_blocks, num_blocks)
cross_attention_mask, # (B, 1, num_blocks, 1)
Expand Down
2 changes: 1 addition & 1 deletion jumanji/training/networks/flat_pack/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

def make_random_policy_flat_pack(flat_pack: FlatPack) -> RandomPolicy:
"""Make random policy for FlatPack."""
action_spec_num_values = flat_pack.action_spec().num_values
action_spec_num_values = flat_pack.action_spec.num_values

return make_masked_categorical_random_ndim(
action_spec_num_values=action_spec_num_values
Expand Down

0 comments on commit 063b108

Please sign in to comment.