Skip to content

Commit

Permalink
Merge branch 'main' into feat/lbf-truncate
Browse files Browse the repository at this point in the history
  • Loading branch information
WiemKhlifi authored Jul 11, 2024
2 parents 2945c3b + fd511b4 commit 6c51988
Show file tree
Hide file tree
Showing 18 changed files with 174 additions and 39 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ jobs:
python-version: "3.x"
- name: Install dependencies
run: |
pip install --upgrade pip setuptools twine
pip install --upgrade pip hatch twine
- name: Build and publish
env:
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
run: |
python setup.py sdist
hatch build
twine upload dist/*
46 changes: 23 additions & 23 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,39 +26,44 @@
<img src="docs/env_anim/cleaner.gif" alt="Cleaner" width="16%">
<img src="docs/env_anim/connector.gif" alt="Connector" width="16%">
<img src="docs/env_anim/cvrp.gif" alt="CVRP" width="16%">
<img src="docs/env_anim/flat_pack.gif" alt="FlatPack" width="16%">
<img src="docs/env_anim/game_2048.gif" alt="Game2048" width="16%">
<img src="docs/env_anim/graph_coloring.gif" alt="GraphColoring" width="16%">
</div>
<div class="row" align="center">
<img src="docs/env_anim/graph_coloring.gif" alt="GraphColoring" width="16%">
<img src="docs/env_anim/job_shop.gif" alt="JobShop" width="16%">
<img src="docs/env_anim/knapsack.gif" alt="Knapsack" width="16%">
<img src="docs/env_anim/maze.gif" alt="Maze" width="16%">
<img src="docs/env_anim/minesweeper.gif" alt="Minesweeper" width="16%">
<img src="docs/env_anim/mmst.gif" alt="MMST" width="16%">
<img src="docs/env_anim/multi_cvrp.gif" alt="MultiCVRP" width="16%">
</div>
<div class="row" align="center">
<img src="docs/env_anim/multi_cvrp.gif" alt="MultiCVRP" width="16%">
<img src="docs/env_anim/pac_man.gif" alt="PacMan" width="16%">
<img src="docs/env_anim/robot_warehouse.gif" alt="RobotWarehouse" width="16%">
<img src="docs/env_anim/rubiks_cube.gif" alt="RubiksCube" width="16%">
<img src="docs/env_anim/sliding_tile_puzzle.gif" alt="SlidingTilePuzzle" width="16%">
<img src="docs/env_anim/snake.gif" alt="Snake" width="16%">
<img src="docs/env_anim/sudoku.gif" alt="Sudoku" width="16%">
<img src="docs/env_anim/tetris.gif" alt="Tetris" width="16%">
<img src="docs/env_anim/tsp.gif" alt="Tetris" width="16%">
</div>
<div class="row" align="center">
<img src="docs/env_anim/pac_man.gif" alt="RobotWarehouse" width="16%">
<img src="docs/env_anim/sokoban.gif" alt="RobotWarehouse" width="16%">
<img src="docs/env_anim/sudoku.gif" alt="Sudoku" width="16%">
<img src="docs/env_anim/tetris.gif" alt="Tetris" width="16%">
<img src="docs/env_anim/tsp.gif" alt="Tetris" width="16%">
</div>
</div>

## Jumanji @ ICLR 2024

Jumanji has been accepted at [ICLR 2024](https://iclr.cc/), check out our [research paper](https://arxiv.org/abs/2306.09884).

## Welcome to the Jungle! 🌴

Jumanji is a diverse suite of scalable reinforcement learning environments written in JAX.
Jumanji is a diverse suite of scalable reinforcement learning environments written in JAX. It now features 22 environments!

Jumanji is helping pioneer a new wave of hardware-accelerated research and development in the
field of RL. Jumanji's high-speed environments enable faster iteration and large-scale
experimentation while simultaneously reducing complexity. Originating in the Research Team at
experimentation while simultaneously reducing complexity. Originating in the research team at
[InstaDeep](https://www.instadeep.com/), Jumanji is now developed jointly with the open-source
community. To join us in these efforts, reach out, raise issues and read our
[contribution guidelines](https://github.com/instadeepai/jumanji/blob/main/CONTRIBUTING.md) or just
Expand Down Expand Up @@ -98,8 +103,10 @@ problems.
| 🎨 GraphColoring | Logic | `GraphColoring-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/logic/graph_coloring/) | [doc](https://instadeepai.github.io/jumanji/environments/graph_coloring/) |
| 💣 Minesweeper | Logic | `Minesweeper-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/logic/minesweeper/) | [doc](https://instadeepai.github.io/jumanji/environments/minesweeper/) |
| 🎲 RubiksCube | Logic | `RubiksCube-v0`<br/>`RubiksCube-partly-scrambled-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/logic/rubiks_cube/) | [doc](https://instadeepai.github.io/jumanji/environments/rubiks_cube/) |
| ✏️ Sudoku | Logic | `Sudoku-v0` <br/>`Sudoku-very-easy-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/logic/sudoku/) | [doc](https://instadeepai.github.io/jumanji/environments/sudoku/) |
| 📦 BinPack (3D BinPacking Problem) | Packing | `BinPack-v2` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/bin_pack/) | [doc](https://instadeepai.github.io/jumanji/environments/bin_pack/) |
| 🔀 SlidingTilePuzzle | Logic | `SlidingTilePuzzle-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/logic/sliding_tile_puzzle/) | [doc](https://instadeepai.github.io/jumanji/environments/sliding_tile_puzzle/) |
| ✏️ Sudoku | Logic | `Sudoku-v0` <br/>`Sudoku-very-easy-v0`| [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/logic/sudoku/) | [doc](https://instadeepai.github.io/jumanji/environments/sudoku/) |
| 📦 BinPack (3D BinPacking Problem) | Packing | `BinPack-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/bin_pack/) | [doc](https://instadeepai.github.io/jumanji/environments/bin_pack/) |
| 🧩 FlatPack (2D Grid Filling Problem) | Packing | `FlatPack-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/flat_pack/) | [doc](https://instadeepai.github.io/jumanji/environments/flat_pack/) |
| 🏭 JobShop (Job Shop Scheduling Problem) | Packing | `JobShop-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/job_shop/) | [doc](https://instadeepai.github.io/jumanji/environments/job_shop/) |
| 🎒 Knapsack | Packing | `Knapsack-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/knapsack/) | [doc](https://instadeepai.github.io/jumanji/environments/knapsack/) |
| ▒ Tetris | Packing | `Tetris-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/tetris/) | [doc](https://instadeepai.github.io/jumanji/environments/tetris/) |
Expand All @@ -112,15 +119,15 @@ problems.
| 🐍 Snake | Routing | `Snake-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/snake/) | [doc](https://instadeepai.github.io/jumanji/environments/snake/) |
| 📬 TSP (Travelling Salesman Problem) | Routing | `TSP-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/tsp/) | [doc](https://instadeepai.github.io/jumanji/environments/tsp/) |
| Multi Minimum Spanning Tree Problem | Routing | `MMST-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/mmst) | [doc](https://instadeepai.github.io/jumanji/environments/mmst/) |
| ᗧ•••ᗣ•• PacMan | Routing | `PacMan-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/pacman/) | [doc](https://instadeepai.github.io/jumanji/environments/pacman/)
| ᗧ•••ᗣ•• PacMan | Routing | `PacMan-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/pac_man/) | [doc](https://instadeepai.github.io/jumanji/environments/pac_man/)
| 👾 Sokoban | Routing | `Sokoban-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/sokoban/) | [doc](https://instadeepai.github.io/jumanji/environments/sokoban/) |

<h2 name="install" id="install">Installation 🎬</h2>

You can install the latest release of Jumanji from PyPI:

```bash
pip install jumanji
pip install -U jumanji
```

Alternatively, you can install the latest development version directly from GitHub:
Expand Down Expand Up @@ -165,7 +172,7 @@ state, timestep = jax.jit(env.reset)(key)
env.render(state)

# Interact with the (jit-able) environment
action = env.action_spec().generate_value() # Action selection (dummy value here)
action = env.action_spec.generate_value() # Action selection (dummy value here)
state, timestep = jax.jit(env.step)(state, action) # Take a step and observe the next state and time step
```

Expand Down Expand Up @@ -228,17 +235,10 @@ details on how to submit pull requests, our Contributor License Agreement, and c
If you use Jumanji in your work, please cite the library using:

```
@misc{bonnet2023jumanji,
@misc{bonnet2024jumanji,
title={Jumanji: a Diverse Suite of Scalable Reinforcement Learning Environments in JAX},
author={
Clément Bonnet and Daniel Luo and Donal Byrne and Shikha Surana and Vincent Coyette and
Paul Duckworth and Laurence I. Midgley and Tristan Kalloniatis and Sasha Abramowitz and
Cemlyn N. Waters and Andries P. Smit and Nathan Grinsztajn and Ulrich A. Mbou Sob and
Omayma Mahjoub and Elshadai Tegegn and Mohamed A. Mimouni and Raphael Boige and
Ruan de Kock and Daniel Furelos-Blanco and Victor Le and Arnu Pretorius and
Alexandre Laterre
},
year={2023},
author={Clément Bonnet and Daniel Luo and Donal Byrne and Shikha Surana and Sasha Abramowitz and Paul Duckworth and Vincent Coyette and Laurence I. Midgley and Elshadai Tegegn and Tristan Kalloniatis and Omayma Mahjoub and Matthew Macfarlane and Andries P. Smit and Nathan Grinsztajn and Raphael Boige and Cemlyn N. Waters and Mohamed A. Mimouni and Ulrich A. Mbou Sob and Ruan de Kock and Siddarth Singh and Daniel Furelos-Blanco and Victor Le and Arnu Pretorius and Alexandre Laterre},
year={2024},
eprint={2306.09884},
url={https://arxiv.org/abs/2306.09884},
archivePrefix={arXiv},
Expand Down
8 changes: 8 additions & 0 deletions docs/api/environments/flat_pack.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
::: jumanji.environments.packing.flat_pack.env.FlatPack
selection:
members:
- __init__
- reset
- step
- observation_spec
- action_spec
8 changes: 8 additions & 0 deletions docs/api/environments/sliding_tile_puzzle.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
::: jumanji.environments.logic.sliding_tile_puzzle.env.SlidingTilePuzzle
selection:
members:
- __init__
- reset
- step
- observation_spec
- action_spec
Binary file added docs/env_anim/flat_pack.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/env_anim/sliding_tile_puzzle.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/env_img/flat_pack.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/env_img/sliding_tile_puzzle.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
57 changes: 57 additions & 0 deletions docs/environments/flat_pack.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# FlatPack Environment

<p align="center">
<img src="../env_anim/flat_pack.gif" width="500"/>
</p>

We provide here a Jax JIT-able implementation of a packing environment named _flat pack_. The goal of
the agent is to place all the available blocks on an empty 2D grid.
Each time an episode resets a new set of blocks is created and the grid is emptied. Blocks are randomly
shuffled and rotated and all have shape (3, 3).

## Observation
The observation given to the agent gives a view of the current state of the grid as well as
all blocks that can be placed.

- `current_grid`: jax array (float32) of shape `(num_rows, num_cols)` with values in the range
`[0, num_blocks]` (corresponding to the number of each block). This grid will have zeros
where no blocks have been placed and numbers corresponding to each block where that particular
block has been placed.

- `blocks`: jax array (float32) of shape `(num_blocks, 3, 3)` of all possible blocks in
that can fit in the current grid. These blocks are shuffled, rotated and will always have shape `(3, 3)`.

- `action_mask`: jax array (bool) of shape `(num_blocks, 4, num_rows-2, num_cols-2)`, representing
which actions are possible given the current state of the grid. The first index indicates the
number of blocks associated with a given grid. The second index indicates the number of times a block may be rotated.
The third and fourth indices indicate the row and column coordinate of where a blocks top left-most corner may be placed
respectively. Blocks are placed by an agent by specifying the row and column coordinate on the grid where the top left corner
of the selected block should be placed. These values will always be `num_rows-2` and `num_cols-2`
respectively to make it impossible for an agent to place a block outside the current grid.


## Action
The action space is a `MultiDiscreteArray`, specifically a tuple of an index between 0 and `num_blocks - 1`,
an index between 0 and 4 (since there are 4 possible rotations), an index between 0 and `num_rows-2`
(the possible row coordinates for placing a block) and an index between 0 and `num_cols-2`
(the possible column coordinates for placing a block). An action thus consists of four pieces of
information:

- Block to place,

- Number of 90 degree rotations to make to a chosen block ({0, 90, 180, 270} degrees),

- Row coordinate for placing the rotated block's top left corner,

- Column coordinate for placing the rotated block's top left corner.


## Reward
The reward function is configurable, but by default is a fully dense reward giving the sum of the number of non-zero
cells in a placed block normalised by the total number of cells in the grid at each timestep. The episode
terminates if either the grid is filled or `num_blocks` steps have been taken by an agent.


## Registered Versions 📖
- `FlatPack-v0`, a flat pack environment grid with 11 rows and 11 columns containing 5 row blocks and 5 column blocks
for a total of 25 blocks that can be placed on the grid. This version has a dense reward.
2 changes: 1 addition & 1 deletion docs/environments/pac_man.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,4 @@ Eating a ghost when scatter mode is enabled also awards +200 points but, points


## Registered Versions 📖
- `PacMan-v0`, PacMan in a 31x28 map with simple grid observations.
- `PacMan-v1`, PacMan in a 31x28 map with simple grid observations.
52 changes: 52 additions & 0 deletions docs/environments/sliding_tile_puzzle.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Sliding Tile Puzzle Environment

<p align="center">
<img src="../env_anim/sliding_tile_puzzle.gif" width="500"/>
</p>

This is a Jax JIT-able implementation of the classic [Sliding Tile Puzzle game](https://en.wikipedia.org/wiki/Sliding_puzzle).

The Sliding Tile Puzzle game is a classic puzzle that challenges a player to slide (typically flat) pieces along certain routes (usually on a board) to establish a certain end-configuration. The pieces to be moved may consist of simple shapes, or they may be imprinted with colors, patterns, sections of a larger picture (like a jigsaw puzzle), numbers, or letters.

The puzzle is often 3×3, 4×4 or 5×5 in size and made up of square tiles that are slid into a square base, larger than the tiles by one tile space, in a specific large configuration. Tiles are moved/arranged by sliding an adjacent tile into a position occupied by the missing tile, which creates a new space. The sliding puzzle is mechanical and requires the use of no other equipment or tools.

## Observation

The observation in the Sliding Tile Puzzle game includes information about the puzzle, the position of the empty tile, and the action mask.

- `puzzle`: jax array (int32) of shape `(grid_size, grid_size)`, representing the current game state. Each element in the array corresponds to a puzzle tile. The tile represented by 0 is the empty tile.

- Here is an example of a random observation of the game board:

```
[[ 1 2 3 4]
[ 5 6 7 8]
[ 9 10 0 12]
[ 13 14 15 11]]
```
- In this array, the tile represented by 0 is the empty tile that can be moved.

- `empty_tile_position`: a tuple (int32) of shape `(2,)` representing the position of the empty tile in the grid. For example, (2, 2) would represent the third row and the third column in a zero-indexed grid.

- `action_mask`: jax array (bool) of shape `(4,)`, indicating which actions are valid in the current state of the environment. The actions include moving the empty tile up, right, down, or left. For example, an action mask `[True, False, True, False]` means that the valid actions are to move the empty tile upward or downward.

- `step_count`: jax array (int32) of shape `()`, current number of steps in the episode.

## Action

The action space is a `DiscreteArray` of integer values in `[0, 1, 2, 3]`. Specifically, these four actions correspond to moving the empty tile: up (0), right (1), down (2), or left (3).

## Reward

The reward could be either:

- **DenseRewardFn**: This reward function provides a dense reward based on the difference of correctly placed tiles between the current state and the next state. The reward is positive for each newly correctly placed tile and negative for each newly incorrectly placed tile.

- **SparseRewardFn**: This reward function provides a sparse reward, only rewarding when the puzzle is solved.
The reward is 1 if the puzzle is solved, and 0 otherwise.

The goal in all cases is to solve the puzzle in a way that maximizes the reward.

## Registered Versions 📖

- `SlidingTilePuzzle-v0`, the Sliding Tile Puzzle with a grid size of 5x5.
2 changes: 1 addition & 1 deletion docs/guides/advanced_usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ env = AutoResetWrapper(env) # Automatically reset the environment when an ep

batch_size = 7
rollout_length = 5
num_actions = env.action_spec().num_values
num_actions = env.action_spec.num_values

random_key = jax.random.PRNGKey(0)
key1, key2 = jax.random.split(random_key)
Expand Down
4 changes: 2 additions & 2 deletions docs/guides/wrappers.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ env = jumanji.make("Snake-6x6-v0")
dm_env = jumanji.wrappers.JumanjiToDMEnvWrapper(env)

timestep = dm_env.reset()
action = dm_env.action_spec().generate_value()
action = dm_env.action_spec.generate_value()
next_timestep = dm_env.step(action)
...
```
Expand Down Expand Up @@ -52,7 +52,7 @@ key = jax.random.PRNGKey(0)
state, timestep = env.reset(key)
print("New episode")
for i in range(100):
action = env.action_spec().generate_value() # Returns jnp.array(0) when using Snake.
action = env.action_spec.generate_value() # Returns jnp.array(0) when using Snake.
state, timestep = env.step(state, action)
if timestep.first():
print("New episode")
Expand Down
8 changes: 7 additions & 1 deletion examples/load_checkpoints.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,11 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"metadata": {
"collapsed": false
},
"source": [
"## Load configs"
]
Expand Down Expand Up @@ -194,6 +197,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand Down Expand Up @@ -243,6 +247,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand Down Expand Up @@ -279,6 +284,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand Down
Loading

0 comments on commit 6c51988

Please sign in to comment.