Skip to content

Commit

Permalink
feat: switch to ruff and upgrade pre-commit hooks (#260)
Browse files Browse the repository at this point in the history
  • Loading branch information
sash-a authored Nov 5, 2024
1 parent 66dfc93 commit 5ab7166
Show file tree
Hide file tree
Showing 199 changed files with 906 additions and 2,290 deletions.
43 changes: 14 additions & 29 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
default_stages: [ "commit", "commit-msg", "push" ]
default_stages: [ "pre-commit", "commit-msg", "pre-push" ]
default_language_version:
python: python3


repos:
- repo: https://github.com/timothycrosley/isort
rev: 5.11.5
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.7.2
hooks:
- id: isort

- repo: https://github.com/psf/black
rev: 22.3.0
hooks:
- id: black
name: "Code formatter"
# Run the linter.
- id: ruff
types_or: [ python ]
args: [ --fix ]
# Run the formatter.
- id: ruff-format
types_or: [ python, pyi, jupyter ]

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.1.0
rev: v5.0.0
hooks:
- id: end-of-file-fixer
name: "End of file fixer"
Expand All @@ -32,22 +33,6 @@ repos:
- id: trailing-whitespace
name: "Trailing whitespace fixer"

- repo: https://github.com/PyCQA/flake8
rev: 7.1.1
hooks:
- id: flake8
name: "Linter"
args:
- --config=setup.cfg
additional_dependencies:
- pep8-naming
- flake8-builtins
- flake8-comprehensions
- flake8-bugbear
- flake8-pytest-style
- flake8-cognitive-complexity
- importlib-metadata<5.0

- repo: local
hooks:
- id: mypy
Expand All @@ -58,15 +43,15 @@ repos:
pass_filenames: false

- repo: https://github.com/alessandrojcm/commitlint-pre-commit-hook
rev: v4.1.0
rev: v9.18.0
hooks:
- id: commitlint
name: "Commit linter"
stages: [ commit-msg ]
additional_dependencies: [ '@commitlint/config-conventional' ]

- repo: https://github.com/Lucas-C/pre-commit-hooks
rev: v1.3.0
rev: v1.5.5
hooks:
- id: insert-license
name: "License inserter"
Expand Down
8 changes: 3 additions & 5 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,14 @@ Before sending your pull request for review, make sure your changes are consiste
### Coding Style
In general, we follow the [Google Style Guide](https://google.github.io/styleguide/pyguide.html).
We use [conventional commit messages](https://www.conventionalcommits.org/en/v1.0.0/) for commit messages.
In addition, to guarantee the quality and uniformity of the code, we use various linters:
In addition, to guarantee the quality and uniformity of the code, we use two tools:

- [Black](https://black.readthedocs.io/en/stable/#) is a deterministic code formatter that is compliant with PEP8 standards.
- [Isort](https://pycqa.github.io/isort/) sorts imports alphabetically and separates them into sections.
- [Flake8](https://flake8.pycqa.org/en/latest/) is a library that wraps PyFlakes and PyCodeStyle. It is a great toolkit for checking your codebase against coding style (PEP8), programming, and syntax errors. Flake8 also benefits from an ecosystem of plugins developed by the community that extend its capabilities. You can read more about Flake8 plugins on the documentation and find a curated list of plugins here.
- [Ruff](https://docs.astral.sh/ruff/) is an extremely fast Python linter and code formatter.
- [MyPy](https://mypy.readthedocs.io/en/stable/#) is a static type checker that can help you detect inconsistent typing of variables.


#### Pre-Commit
To help in automating the quality of the code, we use [pre-commit](https://pre-commit.com/), a framework that manages the installation and execution of git hooks that will be run before every commit. These hooks help to automatically point out issues in code such as formatting mistakes, unused variables, trailing whitespace, debug statements, etc. By pointing these issues out before code review, it allows a code reviewer to focus on the architecture of a change while not wasting time with trivial style nitpicks. Each commit should be preceded by a call to pre-commit to ensure code quality and formatting. The configuration is in .pre-commit-config.yaml and includes Black, Flake8, MyPy and checks for the yaml formatting, trimming trailing whitespace, etc.
To help in automating the quality of the code, we use [pre-commit](https://pre-commit.com/), a framework that manages the installation and execution of git hooks that will be run before every commit. These hooks help to automatically point out issues in code such as formatting mistakes, unused variables, trailing whitespace, debug statements, etc. By pointing these issues out before code review, it allows a code reviewer to focus on the architecture of a change while not wasting time with trivial style nitpicks. Each commit should be preceded by a call to pre-commit to ensure code quality and formatting. The configuration is in .pre-commit-config.yaml and includes Ruff, MyPy and checks for the yaml formatting, trimming trailing whitespace, etc.
Try running: `pre-commit run --all-files`. All linters must pass before committing your change.

### Code of Conduct
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
[![Python Versions](https://img.shields.io/pypi/pyversions/jumanji.svg?style=flat-square)](https://www.python.org/doc/versions/)
[![PyPI Version](https://badge.fury.io/py/jumanji.svg)](https://badge.fury.io/py/jumanji)
[![Tests](https://github.com/instadeepai/jumanji/actions/workflows/tests_linters.yml/badge.svg)](https://github.com/instadeepai/jumanji/actions/workflows/tests_linters.yml)
[![Code Style](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
[![MyPy](http://www.mypy-lang.org/static/mypy_badge.svg)](http://mypy-lang.org/)
[![License](https://img.shields.io/badge/License-Apache%202.0-orange.svg)](https://opensource.org/licenses/Apache-2.0)
[![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97-Hugging%20Face-F8D521)](https://huggingface.co/InstaDeepAI)
Expand Down
120 changes: 61 additions & 59 deletions examples/load_checkpoints.ipynb

Large diffs are not rendered by default.

15 changes: 12 additions & 3 deletions examples/training.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
"\n",
"# Based on https://stackoverflow.com/questions/67504079/how-to-check-if-an-nvidia-gpu-is-available-on-my-system\n",
"try:\n",
" subprocess.check_output('nvidia-smi')\n",
" subprocess.check_output(\"nvidia-smi\")\n",
" print(\"a GPU is connected.\")\n",
"except Exception:\n",
" # TPU or CPU\n",
Expand Down Expand Up @@ -82,6 +82,7 @@
"outputs": [],
"source": [
"import warnings\n",
"\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"from jumanji.training.train import train\n",
Expand Down Expand Up @@ -117,7 +118,7 @@
},
"outputs": [],
"source": [
"#@title Download Jumanji Configs (run me) { display-mode: \"form\" }\n",
"# @title Download Jumanji Configs (run me) { display-mode: \"form\" }\n",
"\n",
"import os\n",
"import requests\n",
Expand Down Expand Up @@ -407,7 +408,15 @@
],
"source": [
"with initialize(version_base=None, config_path=\"configs\"):\n",
" cfg = compose(config_name=\"config.yaml\", overrides=[f\"env={env}\", f\"agent={agent}\", \"logger.type=terminal\", \"logger.save_checkpoint=true\"])\n",
" cfg = compose(\n",
" config_name=\"config.yaml\",\n",
" overrides=[\n",
" f\"env={env}\",\n",
" f\"agent={agent}\",\n",
" \"logger.type=terminal\",\n",
" \"logger.save_checkpoint=true\",\n",
" ],\n",
" )\n",
"\n",
"train(cfg)"
]
Expand Down
8 changes: 2 additions & 6 deletions jumanji/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,8 @@
register(id="PacMan-v1", entry_point="jumanji.environments:PacMan")

# SlidingTilePuzzle - A sliding tile puzzle environment with the default grid size of 5x5.
register(
id="SlidingTilePuzzle-v0", entry_point="jumanji.environments:SlidingTilePuzzle"
)
register(id="SlidingTilePuzzle-v0", entry_point="jumanji.environments:SlidingTilePuzzle")

# LevelBasedForaging with a random generator with 8 grid size,
# 2 agents and 2 food items and the maximum agent's level is 2.
register(
id="LevelBasedForaging-v0", entry_point="jumanji.environments:LevelBasedForaging"
)
register(id="LevelBasedForaging-v0", entry_point="jumanji.environments:LevelBasedForaging")
16 changes: 6 additions & 10 deletions jumanji/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ def __repr__(self) -> str:

def __init__(self) -> None:
"""Initialize environment."""
self.observation_spec
self.action_spec
self.reward_spec
self.discount_spec
self.observation_spec # noqa: B018
self.action_spec # noqa: B018
self.reward_spec # noqa: B018
self.discount_spec # noqa: B018

@abc.abstractmethod
def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
Expand All @@ -67,9 +67,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
"""

@abc.abstractmethod
def step(
self, state: State, action: chex.Array
) -> Tuple[State, TimeStep[Observation]]:
def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]:
"""Run one timestep of the environment's dynamics.
Args:
Expand Down Expand Up @@ -115,9 +113,7 @@ def discount_spec(self) -> specs.BoundedArray:
Returns:
discount_spec: a `specs.BoundedArray` spec.
"""
return specs.BoundedArray(
shape=(), dtype=float, minimum=0.0, maximum=1.0, name="discount"
)
return specs.BoundedArray(shape=(), dtype=float, minimum=0.0, maximum=1.0, name="discount")

@property
def unwrapped(self) -> Environment[State, ActionSpec, Observation]:
Expand Down
9 changes: 3 additions & 6 deletions jumanji/environments/commons/maze_utils/maze_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
nodes) through a vertical wall must be at an even y coordinate while a passage through a horizontal
wall must be at an even x coordinate.
"""

from typing import NamedTuple, Tuple

import chex
Expand Down Expand Up @@ -123,9 +124,7 @@ def create_chamber(chambers: Stack, x: int, y: int, width: int, height: int) ->
return new_stack


def split_vertically(
state: MazeGenerationState, chamber: chex.Array
) -> MazeGenerationState:
def split_vertically(state: MazeGenerationState, chamber: chex.Array) -> MazeGenerationState:
"""Split the chamber vertically.
Randomly draw a horizontal wall to split the chamber vertically. Randomly open a passage
Expand Down Expand Up @@ -215,8 +214,6 @@ def generate_maze(width: int, height: int, key: chex.PRNGKey) -> chex.Array:

initial_state = MazeGenerationState(maze, chambers, key)

final_state = jax.lax.while_loop(
chambers_remaining, split_next_chamber, initial_state
)
final_state = jax.lax.while_loop(chambers_remaining, split_next_chamber, initial_state)

return final_state.maze
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,7 @@ def test_random_odd(self, key: chex.PRNGKey) -> None:
assert i % 2 == 1
assert 0 <= i < max_val

def test_split_vertically(
self, maze: chex.Array, chambers: Stack, key: chex.PRNGKey
) -> None:
def test_split_vertically(self, maze: chex.Array, chambers: Stack, key: chex.PRNGKey) -> None:
"""Test that a horizontal wall is drawn and that subchambers are added to stack."""
chambers, chamber = stack_pop(chambers)
state = MazeGenerationState(maze, chambers, key)
Expand All @@ -124,9 +122,7 @@ def test_split_vertically(

assert chambers.insertion_index >= 1

def test_split_horizontally(
self, maze: chex.Array, chambers: Stack, key: chex.PRNGKey
) -> None:
def test_split_horizontally(self, maze: chex.Array, chambers: Stack, key: chex.PRNGKey) -> None:
"""Test that a vertical wall is drawn and that subchambers are added to stack."""
chambers, chamber = stack_pop(chambers)
state = MazeGenerationState(maze, chambers, key)
Expand Down
4 changes: 2 additions & 2 deletions jumanji/environments/commons/maze_utils/maze_rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional, Sequence, Tuple
from typing import Callable, ClassVar, Dict, List, Optional, Sequence, Tuple

import chex
import matplotlib.animation
Expand All @@ -32,7 +32,7 @@ class MazeViewer(Viewer):
FONT_STYLE = "monospace"
FIGURE_SIZE = (10.0, 10.0)
# EMPTY is white, WALL is black
COLORS = {EMPTY: [1, 1, 1], WALL: [0, 0, 0]}
COLORS: ClassVar[Dict[int, List[int]]] = {EMPTY: [1, 1, 1], WALL: [0, 0, 0]}

def __init__(self, name: str, render_mode: str = "human") -> None:
"""Viewer for a maze environment.
Expand Down
1 change: 1 addition & 0 deletions jumanji/environments/commons/maze_utils/stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
[. . . .]]
"""

from typing import NamedTuple, Tuple

import chex
Expand Down
16 changes: 4 additions & 12 deletions jumanji/environments/logic/game_2048/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ class Game2048(Environment[State, specs.DiscreteArray, Observation]):
```
"""

def __init__(
self, board_size: int = 4, viewer: Optional[Viewer[State]] = None
) -> None:
def __init__(self, board_size: int = 4, viewer: Optional[Viewer[State]] = None) -> None:
"""Initialize the 2048 game.
Args:
Expand Down Expand Up @@ -166,9 +164,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:

return state, timestep

def step(
self, state: State, action: chex.Array
) -> Tuple[State, TimeStep[Observation]]:
def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]:
"""Updates the environment state after the agent takes an action.
Args:
Expand Down Expand Up @@ -279,9 +275,7 @@ def _add_random_cell(self, board: Board, key: chex.PRNGKey) -> Board:
position = jnp.divmod(tile_idx, self.board_size)

# Choose the value of the new cell: 1 with probability 90% or 2 with probability of 10%
cell_value = jax.random.choice(
subkey, jnp.array([1, 2]), p=jnp.array([0.9, 0.1])
)
cell_value = jax.random.choice(subkey, jnp.array([1, 2]), p=jnp.array([0.9, 0.1]))
board = board.at[position].set(cell_value)

return board
Expand Down Expand Up @@ -325,9 +319,7 @@ def animate(
Returns:
animation.FuncAnimation: the animation object that was created.
"""
return self._viewer.animate(
states=states, interval=interval, save_path=save_path
)
return self._viewer.animate(states=states, interval=interval, save_path=save_path)

def close(self) -> None:
"""Perform any necessary cleanup.
Expand Down
12 changes: 3 additions & 9 deletions jumanji/environments/logic/game_2048/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,7 @@ def can_move_left_row_cond(carry: CanMoveCarry) -> chex.Numeric:
def can_move_left_row_body(carry: CanMoveCarry) -> CanMoveCarry:
"""Check if the current tiles can move and increment the indices."""
# Check if tiles can move
can_move = (carry.origin != 0) & (
(carry.target == 0) | (carry.target == carry.origin)
)
can_move = (carry.origin != 0) & ((carry.target == 0) | (carry.target == carry.origin))

# Increment indices as if performed a no op
# If not performing no op, loop will be terminated anyways
Expand All @@ -75,17 +73,13 @@ def can_move_left_row_body(carry: CanMoveCarry) -> CanMoveCarry:
)

# Return updated carry
return carry._replace(
can_move=can_move, target_idx=target_idx, origin_idx=origin_idx
)
return carry._replace(can_move=can_move, target_idx=target_idx, origin_idx=origin_idx)


def can_move_left_row(row: chex.Array) -> bool:
"""Check if row can move left."""
carry = CanMoveCarry(can_move=False, row=row, target_idx=0, origin_idx=1)
can_move: bool = jax.lax.while_loop(
can_move_left_row_cond, can_move_left_row_body, carry
)[0]
can_move: bool = jax.lax.while_loop(can_move_left_row_cond, can_move_left_row_body, carry)[0]
return can_move


Expand Down
12 changes: 4 additions & 8 deletions jumanji/environments/logic/game_2048/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Sequence, Tuple
from typing import ClassVar, Dict, Optional, Sequence, Tuple

import jax.numpy as jnp
import matplotlib.animation
Expand All @@ -24,7 +24,7 @@


class Game2048Viewer(Viewer):
COLORS = {
COLORS: ClassVar[Dict[int | str, str]] = {
1: "#ccc0b3",
2: "#eee4da",
4: "#ede0c8",
Expand Down Expand Up @@ -158,13 +158,9 @@ def render_tile(self, tile_value: int, ax: plt.Axes, row: int, col: int) -> None
"""
# Set the background color of the tile based on its value.
if tile_value <= 16384:
rect = plt.Rectangle(
[col - 0.5, row - 0.5], 1, 1, color=self.COLORS[int(tile_value)]
)
rect = plt.Rectangle([col - 0.5, row - 0.5], 1, 1, color=self.COLORS[int(tile_value)])
else:
rect = plt.Rectangle(
[col - 0.5, row - 0.5], 1, 1, color=self.COLORS["other"]
)
rect = plt.Rectangle([col - 0.5, row - 0.5], 1, 1, color=self.COLORS["other"])
ax.add_patch(rect)

if tile_value in [2, 4]:
Expand Down
Loading

0 comments on commit 5ab7166

Please sign in to comment.