Skip to content

Commit

Permalink
chore: small changes based on review
Browse files Browse the repository at this point in the history
  • Loading branch information
WiemKhlifi committed Oct 25, 2024
1 parent 3f8297d commit 67878f4
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 19 deletions.
8 changes: 2 additions & 6 deletions jumanji/environments/routing/lbf/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,10 @@ def __init__(
assert max_agent_level >= 2, "Maximum agent level must be at least 2."

min_required_cells = num_agents + num_food * 3
assert (
grid_size**2 >= min_required_cells
), "Grid is too small for this many agents and food items."
assert (
grid_size**2
) * 0.6 >= min_required_cells, (
r"Make sure 40% of the grid is empty to allow agents move freely."
)
) * 0.4 <= min_required_cells, r"""Ensure at least 40% of the grid cells remain unoccupied
to facilitate smooth placement and movement of agents and food items."""

self.grid_size = grid_size
self.fov = grid_size if fov is None else fov
Expand Down
24 changes: 17 additions & 7 deletions jumanji/environments/routing/lbf/observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,21 +225,31 @@ def make_agents_view(self, agent: Agent, state: State) -> chex.Array:
)

# Assign the foods and agents infos.
agent_view = agent_view.at[jnp.arange(0, 3 * self.num_food, 3)].set(food_xs)
agent_view = agent_view.at[jnp.arange(1, 3 * self.num_food, 3)].set(food_ys)
agent_view = agent_view.at[jnp.arange(2, 3 * self.num_food, 3)].set(food_levels)
agent_view = agent_view.at[jnp.arange(0, 3 * self.num_food, 3)].set(
food_xs, indices_are_sorted=True, unique_indices=True
)
agent_view = agent_view.at[jnp.arange(1, 3 * self.num_food, 3)].set(
food_ys, indices_are_sorted=True, unique_indices=True
)
agent_view = agent_view.at[jnp.arange(2, 3 * self.num_food, 3)].set(
food_levels, indices_are_sorted=True, unique_indices=True
)

# Always place the current agent's info first.
agent_view = agent_view.at[
jnp.arange(3 * self.num_food, 3 * self.num_food + 3)
].set(agent_i_infos)
].set(agent_i_infos, indices_are_sorted=True, unique_indices=True)

start_idx = 3 * self.num_food + 3
end_idx = start_idx + 3 * (self.num_agents - 1)
agent_view = agent_view.at[jnp.arange(start_idx, end_idx, 3)].set(agent_xs)
agent_view = agent_view.at[jnp.arange(start_idx + 1, end_idx, 3)].set(agent_ys)
agent_view = agent_view.at[jnp.arange(start_idx, end_idx, 3)].set(
agent_xs, indices_are_sorted=True, unique_indices=True
)
agent_view = agent_view.at[jnp.arange(start_idx + 1, end_idx, 3)].set(
agent_ys, indices_are_sorted=True, unique_indices=True
)
agent_view = agent_view.at[jnp.arange(start_idx + 2, end_idx, 3)].set(
agent_levels
agent_levels, indices_are_sorted=True, unique_indices=True
)

return agent_view
Expand Down
7 changes: 1 addition & 6 deletions jumanji/environments/routing/lbf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,7 @@ def simulate_agent_movement(
)

# Return the agent with the updated position
return Agent(
id=agent.id,
position=new_agent_position,
level=agent.level,
loading=jnp.asarray(False),
)
return agent.replace(position=new_agent_position) # type: ignore


def update_agent_positions(
Expand Down

0 comments on commit 67878f4

Please sign in to comment.