Skip to content

Commit

Permalink
refactor: Simplify Schelling code (#222)
Browse files Browse the repository at this point in the history
* refactor: Simplify Schelling code

1. Remove unused model attributes
2. Make `similar` calculation more natural language readable

* Remove unused argument doc

* Add type hints to agent class

* refactor: Simplify self.running expression
  • Loading branch information
rht authored Oct 14, 2024
1 parent 5739e84 commit d63ce06
Showing 1 changed file with 8 additions and 16 deletions.
24 changes: 8 additions & 16 deletions examples/schelling/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,21 @@ class SchellingAgent(mesa.Agent):
Schelling segregation agent
"""

def __init__(self, model, agent_type):
def __init__(self, model: mesa.Model, agent_type: int) -> None:
"""
Create a new Schelling agent.
Args:
x, y: Agent initial location.
agent_type: Indicator for the agent's type (minority=1, majority=0)
"""
super().__init__(model)
self.type = agent_type

def step(self):
similar = 0
for neighbor in self.model.grid.iter_neighbors(
def step(self) -> None:
neighbors = self.model.grid.iter_neighbors(
self.pos, moore=True, radius=self.model.radius
):
if neighbor.type == self.type:
similar += 1
)
similar = sum(1 for neighbor in neighbors if neighbor.type == self.type)

# If unhappy, move:
if similar < self.model.homophily:
Expand Down Expand Up @@ -60,10 +57,6 @@ def __init__(
"""

super().__init__(seed=seed)
self.height = height
self.width = width
self.density = density
self.minority_pc = minority_pc
self.homophily = homophily
self.radius = radius

Expand All @@ -79,8 +72,8 @@ def __init__(
# the coordinates of a cell as well as
# its contents. (coord_iter)
for _, pos in self.grid.coord_iter():
if self.random.random() < self.density:
agent_type = 1 if self.random.random() < self.minority_pc else 0
if self.random.random() < density:
agent_type = 1 if self.random.random() < minority_pc else 0
agent = SchellingAgent(self, agent_type)
self.grid.place_agent(agent, pos)

Expand All @@ -95,5 +88,4 @@ def step(self):

self.datacollector.collect(self)

if self.happy == len(self.agents):
self.running = False
self.running = self.happy != len(self.agents)

0 comments on commit d63ce06

Please sign in to comment.