From 67878f4fe2574ba167038358ffc073c098223ab0 Mon Sep 17 00:00:00 2001 From: WiemKhlifi Date: Fri, 25 Oct 2024 18:34:16 +0100 Subject: [PATCH] chore: small changes based on review --- jumanji/environments/routing/lbf/generator.py | 8 ++----- jumanji/environments/routing/lbf/observer.py | 24 +++++++++++++------ jumanji/environments/routing/lbf/utils.py | 7 +----- 3 files changed, 20 insertions(+), 19 deletions(-) diff --git a/jumanji/environments/routing/lbf/generator.py b/jumanji/environments/routing/lbf/generator.py index 5158c77e8..55f89da90 100644 --- a/jumanji/environments/routing/lbf/generator.py +++ b/jumanji/environments/routing/lbf/generator.py @@ -56,14 +56,10 @@ def __init__( assert max_agent_level >= 2, "Maximum agent level must be at least 2." min_required_cells = num_agents + num_food * 3 - assert ( - grid_size**2 >= min_required_cells - ), "Grid is too small for this many agents and food items." assert ( grid_size**2 - ) * 0.6 >= min_required_cells, ( - r"Make sure 40% of the grid is empty to allow agents move freely." - ) + ) * 0.4 <= min_required_cells, r"""Ensure at least 40% of the grid cells remain unoccupied + to facilitate smooth placement and movement of agents and food items.""" self.grid_size = grid_size self.fov = grid_size if fov is None else fov diff --git a/jumanji/environments/routing/lbf/observer.py b/jumanji/environments/routing/lbf/observer.py index c1f9f1836..7565ee9ce 100644 --- a/jumanji/environments/routing/lbf/observer.py +++ b/jumanji/environments/routing/lbf/observer.py @@ -225,21 +225,31 @@ def make_agents_view(self, agent: Agent, state: State) -> chex.Array: ) # Assign the foods and agents infos. - agent_view = agent_view.at[jnp.arange(0, 3 * self.num_food, 3)].set(food_xs) - agent_view = agent_view.at[jnp.arange(1, 3 * self.num_food, 3)].set(food_ys) - agent_view = agent_view.at[jnp.arange(2, 3 * self.num_food, 3)].set(food_levels) + agent_view = agent_view.at[jnp.arange(0, 3 * self.num_food, 3)].set( + food_xs, indices_are_sorted=True, unique_indices=True + ) + agent_view = agent_view.at[jnp.arange(1, 3 * self.num_food, 3)].set( + food_ys, indices_are_sorted=True, unique_indices=True + ) + agent_view = agent_view.at[jnp.arange(2, 3 * self.num_food, 3)].set( + food_levels, indices_are_sorted=True, unique_indices=True + ) # Always place the current agent's info first. agent_view = agent_view.at[ jnp.arange(3 * self.num_food, 3 * self.num_food + 3) - ].set(agent_i_infos) + ].set(agent_i_infos, indices_are_sorted=True, unique_indices=True) start_idx = 3 * self.num_food + 3 end_idx = start_idx + 3 * (self.num_agents - 1) - agent_view = agent_view.at[jnp.arange(start_idx, end_idx, 3)].set(agent_xs) - agent_view = agent_view.at[jnp.arange(start_idx + 1, end_idx, 3)].set(agent_ys) + agent_view = agent_view.at[jnp.arange(start_idx, end_idx, 3)].set( + agent_xs, indices_are_sorted=True, unique_indices=True + ) + agent_view = agent_view.at[jnp.arange(start_idx + 1, end_idx, 3)].set( + agent_ys, indices_are_sorted=True, unique_indices=True + ) agent_view = agent_view.at[jnp.arange(start_idx + 2, end_idx, 3)].set( - agent_levels + agent_levels, indices_are_sorted=True, unique_indices=True ) return agent_view diff --git a/jumanji/environments/routing/lbf/utils.py b/jumanji/environments/routing/lbf/utils.py index c7956bd99..1e45ea180 100644 --- a/jumanji/environments/routing/lbf/utils.py +++ b/jumanji/environments/routing/lbf/utils.py @@ -90,12 +90,7 @@ def simulate_agent_movement( ) # Return the agent with the updated position - return Agent( - id=agent.id, - position=new_agent_position, - level=agent.level, - loading=jnp.asarray(False), - ) + return agent.replace(position=new_agent_position) # type: ignore def update_agent_positions(