Skip to content

Commit

Permalink
feat(2048): environment performance improvements (#172)
Browse files Browse the repository at this point in the history
Co-authored-by: Clément Bonnet <56230714+clement-bonnet@users.noreply.github.com>
  • Loading branch information
aar65537 and clement-bonnet authored Jun 20, 2023
1 parent 96e8e52 commit 32685cb
Show file tree
Hide file tree
Showing 3 changed files with 258 additions and 224 deletions.
29 changes: 6 additions & 23 deletions jumanji/environments/logic/game_2048/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,7 @@
from jumanji import specs
from jumanji.env import Environment
from jumanji.environments.logic.game_2048.types import Board, Observation, State
from jumanji.environments.logic.game_2048.utils import (
move_down,
move_left,
move_right,
move_up,
)
from jumanji.environments.logic.game_2048.utils import can_move, move
from jumanji.environments.logic.game_2048.viewer import Game2048Viewer
from jumanji.types import TimeStep, restart, termination, transition
from jumanji.viewer import Viewer
Expand Down Expand Up @@ -181,11 +176,7 @@ def step(
timestep: the next timestep.
"""
# Take the action in the environment: Up, Right, Down, Left.
updated_board, additional_reward = jax.lax.switch(
action,
[move_up, move_right, move_down, move_left],
state.board,
)
updated_board, reward = move(state.board, action)

# Generate new key.
random_cell_key, new_state_key = jax.random.split(state.key)
Expand All @@ -209,7 +200,7 @@ def step(
action_mask=action_mask,
step_count=state.step_count + 1,
key=new_state_key,
score=state.score + additional_reward.astype(float),
score=state.score + reward,
)

# Generate the observation from the environment state.
Expand All @@ -227,12 +218,12 @@ def step(
timestep = jax.lax.cond(
done,
lambda: termination(
reward=additional_reward,
reward=reward,
observation=observation,
extras=extras,
),
lambda: transition(
reward=additional_reward,
reward=reward,
observation=observation,
extras=extras,
),
Expand Down Expand Up @@ -303,15 +294,7 @@ def _get_action_mask(self, board: Board) -> chex.Array:
Returns:
action_mask: action mask for the current state of the environment.
"""
action_mask = jnp.array(
[
jnp.any(move_up(board, final_shift=False)[0] != board),
jnp.any(move_right(board, final_shift=False)[0] != board),
jnp.any(move_down(board, final_shift=False)[0] != board),
jnp.any(move_left(board, final_shift=False)[0] != board),
],
)
return action_mask
return jax.vmap(can_move, (None, 0))(board, jnp.arange(4))

def render(self, state: State) -> Optional[NDArray]:
"""Renders the current state of the game board.
Expand Down
Loading

0 comments on commit 32685cb

Please sign in to comment.