@@ -207,7 +207,7 @@ class Intro extends Component {
},
{
title: "Debugging Policies",
- image: "intro/debugpolicies.gif",
+ image: "intro/debugpolicies_updated.gif",
content: (
diff --git a/python/griddly/GymWrapper.py b/python/griddly/GymWrapper.py
index 8e48911d6..36295ec95 100644
--- a/python/griddly/GymWrapper.py
+++ b/python/griddly/GymWrapper.py
@@ -17,8 +17,8 @@ def __init__(self):
self.player_observation_space = None
self.global_observation_space = None
self.action_space_parts = None
- self.max_action_ids = 0
- self.num_action_ids = {}
+ self.max_action_ids = None
+ self.num_action_ids = None
self.action_space = None
self.object_names = None
self.variable_names = None
@@ -125,19 +125,19 @@ def __init__(
@property
def player_count(self):
- if not self._cache.player_count:
+ if self._cache.player_count is None:
self._cache.player_count = self.gdy.get_player_count()
return self._cache.player_count
@property
def level_count(self):
- if not self._cache.level_count:
+ if self._cache.level_count is None:
self._cache.level_count = self.gdy.get_level_count()
return self._cache.level_count
@property
def avatar_object(self):
- if not self._cache.avatar_object:
+ if self._cache.avatar_object is None:
self._cache.avatar_object = self.gdy.get_avatar_object()
return self._cache.avatar_object
@@ -147,13 +147,13 @@ def has_avatar(self):
@property
def action_input_mappings(self):
- if not self._cache.action_input_mappings:
+ if self._cache.action_input_mappings is None:
self._cache.action_input_mappings = self.gdy.get_action_input_mappings()
return self._cache.action_input_mappings
@property
def action_names(self):
- if not self._cache.action_names:
+ if self._cache.action_names is None:
self._cache.action_names = self.gdy.get_action_names()
return self._cache.action_names
@@ -167,25 +167,25 @@ def default_action_name(self):
@property
def object_names(self):
- if not self._cache.object_names:
+ if self._cache.object_names is None:
self._cache.object_names = self.game.get_object_names()
return self._cache.object_names
@property
def variable_names(self):
- if not self._cache.variable_names:
+ if self._cache.variable_names is None:
self._cache.variable_names = self.game.get_object_variable_names()
return self._cache.variable_names
@property
def _vector2rgb(self):
- if not self._cache.vector2rgb:
+ if self._cache.vector2rgb is None:
self._cache.vector2rgb = Vector2RGB(10, len(self.object_names))
return self._cache.vector2rgb
@property
def global_observation_space(self):
- if not self._cache.global_observation_space:
+ if self._cache.global_observation_space is None:
self._cache.global_observation_space = self._get_obs_space(
self.game.get_global_observation_description(),
self._global_observer_type
@@ -194,7 +194,7 @@ def global_observation_space(self):
@property
def player_observation_space(self):
- if not self._cache.player_observation_space:
+ if self._cache.player_observation_space is None:
if self.player_count == 1:
self._cache.player_observation_space = self._get_obs_space(
self._players[0].get_observation_description(),
@@ -203,7 +203,6 @@ def player_observation_space(self):
else:
observation_spaces = []
for p in range(self.player_count):
- observation_description = self._players[p].get_observation_description()
observation_spaces.append(
self._get_obs_space(
self._players[p].get_observation_description(),
@@ -220,25 +219,25 @@ def observation_space(self):
@property
def max_action_ids(self):
- if not self._cache.max_action_ids:
+ if self._cache.max_action_ids is None:
self._init_action_variable_cache()
return self._cache.max_action_ids
@property
def num_action_ids(self):
- if not self._cache.num_action_ids:
+ if self._cache.num_action_ids is None:
self._init_action_variable_cache()
return self._cache.num_action_ids
@property
def action_space_parts(self):
- if not self._cache.action_space_parts:
+ if self._cache.action_space_parts is None:
self._init_action_variable_cache()
return self._cache.action_space_parts
@property
def action_space(self):
- if not self._cache.action_space:
+ if self._cache.action_space is None:
self._cache.action_space = self._create_action_space()
return self._cache.action_space
@@ -518,14 +517,18 @@ def _init_action_variable_cache(self):
if self.action_count > 1:
self._cache.action_space_parts.append(self.action_count)
+ self._cache.num_action_ids = {}
+ max_action_ids = 0
+
for action_name, mapping in sorted(self.action_input_mappings.items()):
if not mapping["Internal"]:
num_action_ids = len(mapping["InputMappings"]) + 1
self._cache.num_action_ids[action_name] = num_action_ids
- if self._cache.max_action_ids < num_action_ids:
- self._cache.max_action_ids = num_action_ids
+ if max_action_ids < num_action_ids:
+ max_action_ids = num_action_ids
- self._cache.action_space_parts.append(self.max_action_ids)
+ self._cache.max_action_ids = max_action_ids
+ self._cache.action_space_parts.append(max_action_ids)
def clone(self):
"""
diff --git a/python/griddly/util/rllib/environment/core.py b/python/griddly/util/rllib/environment/core.py
index 860466a32..a984bae27 100644
--- a/python/griddly/util/rllib/environment/core.py
+++ b/python/griddly/util/rllib/environment/core.py
@@ -12,6 +12,15 @@
)
+class _RLlibEnvCache:
+ def __init__(self):
+ self.action_space = None
+ self.observation_space = None
+
+ def reset(self):
+ self.__init__()
+
+
class RLlibEnv(GymWrapper):
"""
Wraps a Griddly environment for compatibility with RLLib.
@@ -48,6 +57,7 @@ class RLlibEnv(GymWrapper):
def __init__(self, env_config):
super().__init__(**env_config)
+ self._rllib_cache = _RLlibEnvCache()
self.env_steps = 0
self._env_idx = None
self._worker_idx = None
@@ -128,23 +138,30 @@ def _after_step(self, observation, reward, done, info):
return extra_info
- def set_transform(self):
- """
- Create the transform for rllib based on the observation space
- """
-
- if self.player_count > 1:
- self.observation_space = self.observation_space[0]
- self.action_space = self.action_space[0]
+ @property
+ def action_space(self):
+ if self._rllib_cache.action_space is None:
+ self._rllib_cache.action_space = super().action_space[0] if self.player_count > 1 else super().action_space
+ return self._rllib_cache.action_space
+
+ @property
+ def observation_space(self):
+ if self._rllib_cache.observation_space is None:
+ obs_space = super().observation_space[0] if self.player_count > 1 else super().observation_space
+ self._rllib_cache.observation_space = gym.spaces.Box(
+                obs_space.low.transpose((1, 2, 0)).astype(float),
+                obs_space.high.transpose((1, 2, 0)).astype(float),
+                dtype=float,
+ )
+ return self._rllib_cache.observation_space
- self.observation_space = gym.spaces.Box(
- self.observation_space.low.transpose((1, 2, 0)).astype(np.float),
- self.observation_space.high.transpose((1, 2, 0)).astype(np.float),
- dtype=np.float,
- )
+ @property
+ def width(self):
+ return self.observation_space.shape[0]
- self.height = self.observation_space.shape[1]
- self.width = self.observation_space.shape[0]
+ @property
+ def height(self):
+ return self.observation_space.shape[1]
def _get_valid_action_trees(self):
valid_action_trees = self.game.build_valid_action_trees()
@@ -159,8 +176,8 @@ def reset(self, **kwargs):
elif self._random_level_on_reset:
kwargs["level_id"] = np.random.choice(self.level_count)
+ self._rllib_cache.reset()
observation = super().reset(**kwargs)
- self.set_transform()
if self.generate_valid_action_trees:
self.last_valid_action_trees = self._get_valid_action_trees()
diff --git a/python/requirements.txt b/python/requirements.txt
index 447dc7f60..bd865736e 100644
--- a/python/requirements.txt
+++ b/python/requirements.txt
@@ -7,4 +7,5 @@ matplotlib>=3.3.3
pyglet
pytest>=6.2.1
black
-
+ray[rllib]
+torch
diff --git a/python/setup.py b/python/setup.py
index 719530047..baab413d8 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -71,7 +71,7 @@ def griddly_package_data(config='Debug'):
setup(
name='griddly',
- version="1.4.3",
+ version="1.5.0",
author_email="chrisbam4d@gmail.com",
description="Griddly Python Libraries",
long_description=long_description,
diff --git a/python/tests/rllib_test.py b/python/tests/rllib_test.py
new file mode 100644
index 000000000..eb203eb12
--- /dev/null
+++ b/python/tests/rllib_test.py
@@ -0,0 +1,60 @@
+import os
+import sys
+import ray
+
+from ray import tune
+from ray.rllib.agents.impala import ImpalaTrainer
+from ray.rllib.models import ModelCatalog
+from ray.tune.registry import register_env
+
+from griddly import gd
+from griddly.util.rllib.torch import GAPAgent
+from griddly.util.rllib.environment.core import RLlibEnv
+
+
+def test_rllib_env():
+ sep = os.pathsep
+ os.environ['PYTHONPATH'] = sep.join(sys.path)
+
+ ray.init(num_gpus=0)
+
+ env_name = "ray-griddly-env"
+
+ register_env(env_name, RLlibEnv)
+ ModelCatalog.register_custom_model("GAP", GAPAgent)
+
+ max_training_steps = 1
+
+ config = {
+ 'framework': 'torch',
+ 'num_workers': 1,
+ 'num_envs_per_worker': 1,
+ 'num_gpus': 0,
+ 'model': {
+ 'custom_model': 'GAP',
+ 'custom_model_config': {}
+ },
+ 'env': env_name,
+ 'env_config': {
+ 'random_level_on_reset': True,
+ 'yaml_file': 'Single-Player/GVGAI/clusters_partially_observable.yaml',
+ 'global_observer_type': gd.ObserverType.VECTOR,
+ 'max_steps': 100,
+ },
+ 'entropy_coeff_schedule': [
+ [0, 0.01],
+ [max_training_steps, 0.0]
+ ],
+ 'lr_schedule': [
+ [0, 0.0005],
+ [max_training_steps, 0.0]
+ ]
+ }
+
+ stop = {
+ "timesteps_total": max_training_steps,
+ }
+
+ result = tune.run(ImpalaTrainer, config=config, stop=stop)
+
+ assert result is not None
diff --git a/resources/griddlybear192x192.png b/resources/griddlybear192x192.png
deleted file mode 100644
index 91ed820f0..000000000
Binary files a/resources/griddlybear192x192.png and /dev/null differ
diff --git a/resources/griddlybear512x512.png b/resources/griddlybear512x512.png
deleted file mode 100644
index 600f9875c..000000000
Binary files a/resources/griddlybear512x512.png and /dev/null differ
diff --git a/resources/griddlybear64x64.png b/resources/griddlybear64x64.png
deleted file mode 100644
index 0178c0ef3..000000000
Binary files a/resources/griddlybear64x64.png and /dev/null differ
diff --git a/resources/logo.png b/resources/logo.png
index 3a6008166..f3df04f21 100644
Binary files a/resources/logo.png and b/resources/logo.png differ
diff --git a/resources/logo1200x630.png b/resources/logo1200x630.png
deleted file mode 100644
index e80720e3e..000000000
Binary files a/resources/logo1200x630.png and /dev/null differ
diff --git a/resources/logo800x418.png b/resources/logo800x418.png
deleted file mode 100644
index ad985d718..000000000
Binary files a/resources/logo800x418.png and /dev/null differ