diff --git a/docs/source/agents/agents.rst b/docs/source/agents/agents.rst index a802fff..7ade085 100644 --- a/docs/source/agents/agents.rst +++ b/docs/source/agents/agents.rst @@ -6,4 +6,5 @@ Agents :caption: Agents prebuilt + curriculum_agents base \ No newline at end of file diff --git a/docs/source/agents/curriculum_agents.rst b/docs/source/agents/curriculum_agents.rst new file mode 100644 index 0000000..cb8ebe3 --- /dev/null +++ b/docs/source/agents/curriculum_agents.rst @@ -0,0 +1,4 @@ +.. automodule:: conformer_rl.agents.curriculum_agents + :members: + :show-inheritance: + :inherited-members: \ No newline at end of file diff --git a/docs/source/conf.py b/docs/source/conf.py index e039134..d012408 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -24,7 +24,7 @@ author = 'Runxuan Jiang' # The full version, including alpha/beta/rc tags -release = '1.0.0' +release = '1.1.0' # -- General configuration --------------------------------------------------- diff --git a/docs/source/environments/curriculum_conformer_env.rst b/docs/source/environments/curriculum_conformer_env.rst new file mode 100644 index 0000000..ebda149 --- /dev/null +++ b/docs/source/environments/curriculum_conformer_env.rst @@ -0,0 +1,3 @@ +.. automodule:: conformer_rl.environments.curriculum_conformer_env + :members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/environments/environments.rst b/docs/source/environments/environments.rst index c4cfd61..7750927 100644 --- a/docs/source/environments/environments.rst +++ b/docs/source/environments/environments.rst @@ -7,5 +7,6 @@ Environments prebuilt_environments conformer_env + curriculum_conformer_env components/environment_components environment_wrapper \ No newline at end of file diff --git a/docs/source/index.rst b/docs/source/index.rst index c03afeb..7ca2293 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -17,6 +17,7 @@ Introduction tutorial/model_tuning tutorial/customizing_env_1 tutorial/customizing_env_2 + tutorial/curriculum .. toctree:: :caption: API Reference diff --git a/docs/source/tutorial/curriculum.rst b/docs/source/tutorial/curriculum.rst new file mode 100644 index 0000000..fdbf9a5 --- /dev/null +++ b/docs/source/tutorial/curriculum.rst @@ -0,0 +1,72 @@ +Utilizing Curriculum Learning +============================= +This section walks through how to train an agent using curriculum learning. + +What is Curriculum Learning? +---------------------------- +Curriculum learning can be viewed as applying transfer learning iteratively. In order to train an agent on a specific task, the agent +is first on a similar but easier task. Once the agent has learned that task, it will then be trained on a slightly more difficult task. This continues until the agent is trained directly on the original task. + +Previous empirical results have shown that through curriculum learning, an agent can learn difficult tasks that it is not able to learn by training directly on the task itself. Even if it is able to learn a task by training directly on that task, curriculum learning often makes the training process more efficient (it reduces the training time required). + +:mod:`conformer_rl` contains implementations of mixin classes that can make any of the included environments and agents compatible with curriculum learning. + +Curriculum Learning Example Training Script +------------------------------------------- +The full code for this example can be found in `examples/curriculum_example.py `_. + +In this example, we want to train an agent to generate conformers for a branched alkane molecule with 16 carbon atoms. However, instead of training directly on this molecule, we will utilize a curriculum where the agent begins by training on a branched alkane with 8 atoms, and then iteratively moves up to a branched alkane with 15 atoms. + +We first generate the :class:`~conformer_rl.config.mol_config.MolConfig` objects for the training and evaluation environments. For the training environment, we want a list of :class:`~conformer_rl.config.mol_config.MolConfig` objects starting with a branched alkane with 8 carbon atoms, up to a branched alkane with 15 carbon atoms:: + + # Create mol_configs for the curriculum + mol_configs = [config_from_rdkit(generate_branched_alkane(i), num_conformers=200, calc_normalizers=True) for i in range(8, 16)] + +Next, we create a mol_config for the evaluation environment. Note that the evaluation environment will not be a curriculum environment since we are only evaluating the agent on a single conformer:: + + eval_mol_config = config_from_rdkit(generate_branched_alkane(16), num_conformers=200, calc_normalizers=True) + +Next, we will set up the :class:`~conformer_rl.config.agent_config.Config` object for the agent and hyperparameters as we have done in the previous sections:: + + config = Config() + config.tag = 'curriculum_test' + config.network = RTGNRecurrent(6, 128, edge_dim=6, node_dim=5).to(device) + + # Batch Hyperparameters + config.max_steps = 100000 + + # training Hyperparameters + lr = 5e-6 * np.sqrt(10) + config.optimizer_fn = lambda params: torch.optim.Adam(params, lr=lr, eps=1e-5) + + # curriculum Hyperparameters + config.curriculum_agent_buffer_len = 20 + config.curriculum_agent_reward_thresh = 0.7 + config.curriculum_agent_success_rate = 0.7 + config.curriculum_agent_fail_rate = 0.2 + +We will now create the environments for training and evaluation. :mod:`conformer_rl` already has pre-built environments for curriculum learning. We will use the :class:`~conformer_rl.environments.environments.GibbsScorePruningCurriculumEnv` environment which is the same as the :class:`~conformer_rl.environments.environments.GibbsScorePruningEnv` we used previously except it is now compatible with curriculum learning. We will set the evaluation env to :class:`~conformer_rl.environments.environments.GibbsScorePruningEnv`:: + + # Task Settings + config.train_env = Task('GibbsScorePruningCurriculumEnv-v0', concurrency=True, num_envs=10, seed=np.random.randint(0,1e5), mol_configs=mol_configs) + config.eval_env = Task('GibbsScorePruningEnv-v0', seed=np.random.randint(0,7e4), mol_config=eval_mol_config) + config.eval_interval = 20000 + +Next, we need to specify hyperaparameters specific to the curriculum. The specific meaning of each hyperparameter is discussed in :ref:`Curriculum-Supported Agents` and :ref:`Curriculum Conformer_env`:: + + # curriculum Hyperparameters + config.curriculum_agent_buffer_len = 20 + config.curriculum_agent_reward_thresh = 0.7 + config.curriculum_agent_success_rate = 0.7 + config.curriculum_agent_fail_rate = 0.2 + +Finally, we initiate our agent. Each of the pre-built agents in :mod:`conformer_rl` has a curriculum version as well. In this example we will use :class:`~conformer_rl.agents.curriculum_agents.PPORecurrentExternalCurriculumAgent`:: + + agent = PPORecurrentExternalCurriculumAgent(config) + agent.run_steps() + +We can now run the script to train the agent. + +For more information on how the curriculum environments and agents work, see the sections :ref:`Curriculum Conformer_env` and :ref:`Curriculum-Supported Agents`. + + diff --git a/docs/source/tutorial/customizing_env_1.rst b/docs/source/tutorial/customizing_env_1.rst index 37968c5..d548fe3 100644 --- a/docs/source/tutorial/customizing_env_1.rst +++ b/docs/source/tutorial/customizing_env_1.rst @@ -74,7 +74,7 @@ we must initialize the neural network with the correct ``node_dim``. In :ref:`Ge Finally, when setting the ``train_env`` and ``eval_env``, we must specify the name of the environment to be the ``'Test-Env-v0'`` we registered:: # Set the environment to the test env - config.train_env = Task('TestEnv-v0', concurrency=True, num_envs=5, seed=np.random.randint(0,1e5), mol_config=mol_config, max_steps=200) - config.eval_env = Task('TestEnv-v0', seed=np.random.randint(0,7e4), mol_config=mol_config, max_steps=200) + config.train_env = Task('TestEnv-v0', concurrency=True, num_envs=5, seed=np.random.randint(0,1e5), mol_config=mol_config) + config.eval_env = Task('TestEnv-v0', seed=np.random.randint(0,7e4), mol_config=mol_config) \ No newline at end of file diff --git a/docs/source/tutorial/getting_started.rst b/docs/source/tutorial/getting_started.rst index efbe861..c52fb87 100644 --- a/docs/source/tutorial/getting_started.rst +++ b/docs/source/tutorial/getting_started.rst @@ -26,9 +26,9 @@ Suppose we want to generate conformers for a branched alkane molecule with 14 ca Next, we can use the function :func:`~conformer_rl.molecule_generation.generate_molecule_config.config_from_rdkit`:: - mol_config = config_from_rdkit(mol, calc_normalizers=True, save_file='alkane') + mol_config = config_from_rdkit(mol, num_conformers=200, calc_normalizers=True, save_file='alkane') -which will create a :class:`~conformer_rl.config.mol_config.MolConfig` for our branched alkane. By setting ``calc_normalizeres=True``, the function will calculate the normalizing constants which will be later used by the environment for calculating rewards. The ``ep_steps`` parameter specifies the number of conformers we want to generate in each environment episode and is used for calculating the normalizing constants. We also set ``save_file='alkane'``, so that the generated :class:`~conformer_rl.config.mol_config.MolConfig` object is dumped as a binary `Pickle `_ file named ``alkane.pkl``, so that it can be reused later. +which will create a :class:`~conformer_rl.config.mol_config.MolConfig` for our branched alkane. The ``num_conformers`` parameter specifies the number of conformers we want to generate in each environment episode, in this case 200. By setting ``calc_normalizeres=True``, the function will calculate the normalizing constants which will be later used by the environment for calculating rewards. We also set ``save_file='alkane'``, so that the generated :class:`~conformer_rl.config.mol_config.MolConfig` object is dumped as a binary `Pickle `_ file named ``alkane.pkl``, so that it can be reused later. There are two main benefits for saving the generated :class:`~conformer_rl.config.mol_config.MolConfig` object. Firstly, the normalizing constants generated by setting ``calc_normalizers=True`` are not deterministic and relies on rdkit's conformer generation functionality (which uses random initialization). As discussed above, if we wish to compare the performance of two separate models on the same environment, the same set of normalizing constants should be used for both models, and a new set of normalizing constants should not be generated. Secondly, the generation of normalizing constants can be time consuming for large molecules, and it is therefore unnecessary to re-generate these constants for the same molecule for multiple experiments. @@ -41,11 +41,11 @@ Custom Molecules """""""""""""""" If you have prepared your own molecule for conformer generation, and it is not in a rdkit mol format, :mod:`conformer_rl` also has functions to create :class:`~conformer_rl.config.mol_config.MolConfig` for other formats. For example, if your molecule can be expressed as a SMILES string, you can use the :func:`~conformer_rl.molecule_generation.generate_molecule_config.config_from_smiles` function, such as in the following example:: - mol_config = config_from_smiles('CC(CCC)CC', calc_normalizers=True, save_file='alkane') + mol_config = config_from_smiles('CC(CCC)CC', num_conformers=200, calc_normalizers=True, save_file='alkane') The molecule can also be in the form of a MOL file, in which the function :func:`~conformer_rl.molecule_generation.generate_molecule_config.config_from_molFile` can be used:: - mol_config = config_from_molFile('name_of_mol_file.mol', calc_normalizers=True, save_file='alkane') + mol_config = config_from_molFile('name_of_mol_file.mol', num_conformers=200, calc_normalizers=True, save_file='alkane') Configuring the Agent ^^^^^^^^^^^^^^^^^^^^^ @@ -60,12 +60,12 @@ Training Environment """""""""""""""""""" Next, we will set the training environment for the agent:: - config.train_env = Task('GibbsScorePruningEnv-v0', concurrency=True, num_envs=5, seed=np.random.randint(0,1e5), mol_config=mol_config, max_steps=200) + config.train_env = Task('GibbsScorePruningEnv-v0', concurrency=True, num_envs=5, seed=np.random.randint(0,1e5), mol_config=mol_config) :func:`~conformer_rl.environments.environment_wrapper.Task` is a function that generates an environment wrapper compatible with the agent. Its main functionality is to generate multiple environments that the agent can interact with concurrently, which speeds up training if there are multiple CPU cores available. The first parameter, ``'GibbsScorePruningEnv-v0'``, specifies the name of the environment implementation to be used. In this case it represents the class :class:`~conformer_rl.environments.environments.GibbsScorePruningEnv`, which has empirically produced good results for several organic molecules. To learn more about how environments are registered and how to create custom environments, see :ref:`Customizing Environment - Part One` and :ref:`Customizing Environment - Part Two`. -We set ``concurrency=True`` to utilize multithreading across each of the parallel environments during training. The ``num_envs`` parameter specifies the number of environments to be run in parallel. Next we pass in the :class:`~conformer_rl.config.mol_config.MolConfig` object we created earlier by setting ``mol_config=mol_config`` to specify molecule specific parameters when initiating the environments. Finally, we set the ``max_steps`` parameter, which specifies the number of conformers to generate (i.e., the number of environment steps) before the end of an episode in the environment. This parameter should be set to the same number as the ``ep_steps`` parameter when generating the normalizing constants for the :class:`~conformer_rl.config.mol_config.MolConfig` object using :func:`~conformer_rl.molecule_generation.generate_molecule_config.config_from_rdkit`, as described in :ref:`Configuring the Environment`. +We set ``concurrency=True`` to utilize multithreading across each of the parallel environments during training. The ``num_envs`` parameter specifies the number of environments to be run in parallel. Next we pass in the :class:`~conformer_rl.config.mol_config.MolConfig` object we created earlier by setting ``mol_config=mol_config`` to specify molecule specific parameters when initiating the environments. Evaluation Environment """""""""""""""""""""" @@ -73,7 +73,7 @@ Optionally, we can specify an evaluation environment, which is an environment in For simplicity, we will simply use the same molecule config for the evaluation environment in this example. We specify the evaluation environment in a similar way as the training environment, except that we do not require parallel environments so we use the default values for the ``concurrency`` and ``num_envs`` parameters:: - config.eval_env = Task('GibbsScorePruningEnv-v0', seed=np.random.randint(0,7e4), mol_config=mol_config, max_steps=200) + config.eval_env = Task('GibbsScorePruningEnv-v0', seed=np.random.randint(0,7e4), mol_config=mol_config) config.eval_episodes=10000 ``config.eval_episodes`` specifies how often (in number of episodes) the agent should be evaluated on the evaluation environment. If this is set to 0, the agent will not be evaluated on the evaluation environment. diff --git a/docs/source/tutorial/model_tuning.rst b/docs/source/tutorial/model_tuning.rst index d479d5f..506a584 100644 --- a/docs/source/tutorial/model_tuning.rst +++ b/docs/source/tutorial/model_tuning.rst @@ -13,17 +13,18 @@ As in :ref:`Getting Started - Training a Conformer Generation Agent`, we set up # configure molecule mol = generate_lignin(3) - mol_config = config_from_rdkit(mol, calc_normalizers=True, save_file='lignin') + mol_config = config_from_rdkit(mol, num_conformers=200, calc_normalizers=True, save_file='lignin') # create agent config and set environment config = Config() config.tag = 'example2' - config.train_env = Task('GibbsScorePruningEnv-v0', concurrency=True, num_envs=20, mol_config=mol_config, max_steps=200) + config.train_env = Task('GibbsScorePruningEnv-v0', concurrency=True, num_envs=10, mol_config=mol_config) Configuring the Neural Network ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ :mod:`conformer_rl` contains implementations of several graph neural network models, which can be found in :ref:`models`. One neural network architecture that has performed well empirically in conformer generation is :class:`~conformer_rl.models.RTGN.RTGN`, which we will use in this example:: + # Neural Network config.network = RTGN(6, 128, edge_dim=6, node_dim=5).to(device) Notice that the observation from :class:`~conformer_rl.environments.environments.GibbsScorePruningEnv` @@ -34,6 +35,7 @@ Configuring Logging ^^^^^^^^^^^^^^^^^^^ Next, we configure logging options:: + # Logging Parameters config.save_interval = 20000 config.data_dir = 'data' config.use_tensorboard = True @@ -45,9 +47,10 @@ Configuring the Evaluation Environment Next, we can set up evaluation of the agent. In this example, we will have the agent be evaluated every 20000 steps, and we will set the eval environment to be conformer generation for a lignin polymer with four monomers (instead of three). Thus, the evaluation environment will allow us to see whether the agent is able to generalize from three monomer lignin to four monomer lignin. We will also have the agent evaluate for 2 episodes during each evaluation:: + # Set up evaluation eval_mol = generate_lignin(4) - eval_mol_config = config_from_rdkit(mol, calc_normalizers=True, ep_steps=200, save_file='lignin_eval') - config.eval_env = Task('GibbsScorePruningEnv-v0', num_envs=1, mol_config=eval_mol_config, max_steps=200) + eval_mol_config = config_from_rdkit(mol, num_conformers=200, calc_normalizers=True, save_file='lignin_eval') + config.eval_env = Task('GibbsScorePruningEnv-v0', num_envs=1, mol_config=eval_mol_config) config.eval_interval = 20000 config.eval_episodes = 2 @@ -59,7 +62,7 @@ Finally, we can set the other hyperparameters. For more information on what each config.rollout_length = 20 config.recurrence = 5 config.optimization_epochs = 4 - config.max_steps = 200000 + config.max_steps = 80000 config.mini_batch_size = 50 # Training Hyperparameters diff --git a/examples/curriculum_example.py b/examples/curriculum_example.py new file mode 100644 index 0000000..9606268 --- /dev/null +++ b/examples/curriculum_example.py @@ -0,0 +1,50 @@ +import numpy as np +import torch +import pickle + +from conformer_rl import utils +from conformer_rl.config import Config +from conformer_rl.environments import Task +from conformer_rl.models import RTGNRecurrent + +from conformer_rl.molecule_generation.generate_alkanes import generate_branched_alkane +from conformer_rl.molecule_generation.generate_molecule_config import config_from_rdkit +from conformer_rl.agents import PPORecurrentExternalCurriculumAgent + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +import logging +logging.basicConfig(level=logging.DEBUG) + + +if __name__ == '__main__': + utils.set_one_thread() + + # Create mol_configs for the curriculum + mol_configs = [config_from_rdkit(generate_branched_alkane(i), num_conformers=200, calc_normalizers=True) for i in range(8, 16)] + eval_mol_config = config_from_rdkit(generate_branched_alkane(16), num_conformers=200, calc_normalizers=True) + + config = Config() + config.tag = 'curriculum_test' + config.network = RTGNRecurrent(6, 128, edge_dim=6, node_dim=5).to(device) + + # Batch Hyperparameters + config.max_steps = 100000 + + # training Hyperparameters + lr = 5e-6 * np.sqrt(10) + config.optimizer_fn = lambda params: torch.optim.Adam(params, lr=lr, eps=1e-5) + + # Task Settings + config.train_env = Task('GibbsScorePruningCurriculumEnv-v0', concurrency=True, num_envs=10, seed=np.random.randint(0,1e5), mol_configs=mol_configs) + config.eval_env = Task('GibbsScorePruningEnv-v0', seed=np.random.randint(0,7e4), mol_config=eval_mol_config) + config.eval_interval = 20000 + + # curriculum Hyperparameters + config.curriculum_agent_buffer_len = 20 + config.curriculum_agent_reward_thresh = 0.4 + config.curriculum_agent_success_rate = 0.7 + config.curriculum_agent_fail_rate = 0.2 + + agent = PPORecurrentExternalCurriculumAgent(config) + agent.run_steps() \ No newline at end of file diff --git a/examples/custom_env_basic/run.py b/examples/custom_env_basic/run.py index 4a3c113..b8744cc 100644 --- a/examples/custom_env_basic/run.py +++ b/examples/custom_env_basic/run.py @@ -23,7 +23,7 @@ # Create config object mol = generate_branched_alkane(14) - mol_config = config_from_rdkit(mol, calc_normalizers=True, ep_steps=200, save_file='alkane') + mol_config = config_from_rdkit(mol, num_conformers=200, calc_normalizers=True, save_file='alkane') # Create agent training config object config = Config() @@ -35,8 +35,8 @@ config.network = RTGNRecurrent(6, 128, edge_dim=6, node_dim=2).to(device) # Set the environment to the test env - config.train_env = Task('TestEnv-v0', concurrency=True, num_envs=5, seed=np.random.randint(0,1e5), mol_config=mol_config, max_steps=200) - config.eval_env = Task('TestEnv-v0', seed=np.random.randint(0,7e4), mol_config=mol_config, max_steps=200) + config.train_env = Task('TestEnv-v0', concurrency=True, num_envs=5, seed=np.random.randint(0,1e5), mol_config=mol_config) + config.eval_env = Task('TestEnv-v0', seed=np.random.randint(0,7e4), mol_config=mol_config) config.eval_episodes=10000 agent = PPORecurrentAgent(config) diff --git a/examples/example1.py b/examples/example1.py index 42c42f0..7fcf818 100644 --- a/examples/example1.py +++ b/examples/example1.py @@ -17,15 +17,15 @@ # Create config object mol = generate_branched_alkane(14) - mol_config = config_from_rdkit(mol, calc_normalizers=True, ep_steps=200, save_file='alkane') + mol_config = config_from_rdkit(mol, num_conformers=200, calc_normalizers=True, save_file='alkane') # Create agent training config object config = Config() config.tag = 'example1' # Configure Environment - config.train_env = Task('GibbsScorePruningEnv-v0', concurrency=True, num_envs=5, seed=np.random.randint(0,1e5), mol_config=mol_config, max_steps=200) - config.eval_env = Task('GibbsScorePruningEnv-v0', seed=np.random.randint(0,7e4), mol_config=mol_config, max_steps=200) + config.train_env = Task('GibbsScorePruningEnv-v0', concurrency=True, num_envs=5, seed=np.random.randint(0,1e5), mol_config=mol_config) + config.eval_env = Task('GibbsScorePruningEnv-v0', seed=np.random.randint(0,7e4), mol_config=mol_config) config.eval_episodes=10000 agent = PPORecurrentAgent(config) diff --git a/setup.py b/setup.py index 6427b7f..bed5242 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ setuptools.setup( name="conformer-rl", - version="1.0.0", + version="1.1.0", description="Deep Reinforcement Library for Conformer Generation", long_description=README, long_description_content_type="text/markdown", diff --git a/src/conformer_rl/agents/PPO/PPO_recurrent_agent.py b/src/conformer_rl/agents/PPO/PPO_recurrent_agent.py index e949a0c..a392522 100644 --- a/src/conformer_rl/agents/PPO/PPO_recurrent_agent.py +++ b/src/conformer_rl/agents/PPO/PPO_recurrent_agent.py @@ -23,7 +23,7 @@ class PPORecurrentAgent(BaseACAgentRecurrent): Notes ----- - *Config parameters:* The following parameters are required in the `config` object. See :class:`~conformer_rl.config.agent_config.Config` + *Config parameters:* The following parameters are required in the ``config`` object. See :class:`~conformer_rl.config.agent_config.Config` for more details on the parameters. * tag diff --git a/src/conformer_rl/agents/__init__.py b/src/conformer_rl/agents/__init__.py index 4f45ca4..892be82 100644 --- a/src/conformer_rl/agents/__init__.py +++ b/src/conformer_rl/agents/__init__.py @@ -1,2 +1,5 @@ from .A2C import A2CAgent, A2CRecurrentAgent -from .PPO import PPOAgent, PPORecurrentAgent \ No newline at end of file +from .PPO import PPOAgent, PPORecurrentAgent + +from .curriculum_agents import A2CExternalCurriculumAgent, A2CRecurrentExternalCurriculumAgent +from .curriculum_agents import PPOExternalCurriculumAgent, PPORecurrentExternalCurriculumAgent \ No newline at end of file diff --git a/src/conformer_rl/agents/base_ac_agent.py b/src/conformer_rl/agents/base_ac_agent.py index 27e4edd..769d6a0 100644 --- a/src/conformer_rl/agents/base_ac_agent.py +++ b/src/conformer_rl/agents/base_ac_agent.py @@ -71,6 +71,7 @@ def _sample(self) -> None: storage.append(prediction) storage.append({ 'states': states, + 'terminals': torch.tensor(terminals).unsqueeze(-1).to(device), 'r': torch.tensor(rewards).unsqueeze(-1).to(device), 'm': torch.tensor(1 - terminals).unsqueeze(-1).to(device) }) diff --git a/src/conformer_rl/agents/base_ac_agent_recurrent.py b/src/conformer_rl/agents/base_ac_agent_recurrent.py index 5076353..ce3fc54 100644 --- a/src/conformer_rl/agents/base_ac_agent_recurrent.py +++ b/src/conformer_rl/agents/base_ac_agent_recurrent.py @@ -34,20 +34,6 @@ def __init__(self, config: Config): state.zero_() self.recurrence = self.config.recurrence - - def step(self) -> None: - """Performs one iteration of acquiring samples on the environment - and then trains on the acquired samples. - """ - self.storage.reset() - with torch.no_grad(): - sample_start = time.time() - self._sample() - logging.debug(f'sample time: {time.time() - sample_start} seconds') - train_start = time.time() - self._calculate_advantages() - self._train() - logging.debug(f'train time: {time.time() - train_start} seconds') def _sample(self) -> None: """Collects samples from the training environment. @@ -55,41 +41,44 @@ def _sample(self) -> None: config = self.config states = self.states storage = self.storage + + with torch.no_grad(): ############################################################################################## #Sampling Loop ############################################################################################## - for _ in range(config.rollout_length): - self.total_steps += self.num_workers - - #add recurrent states (lstm hidden and lstm cell states) to storage - storage.append({f'recurrent_states_{i}' : rstate for i, rstate in enumerate(self.recurrent_states)}) - - #run the neural net once to get prediction - prediction, self.recurrent_states = self.network(states, self.recurrent_states) - - #step the environment with the action determined by the prediction - next_states, rewards, terminals, _ = self.task.step(to_np(prediction['a'])) - - self.total_rewards += np.asarray(rewards) - - for idx, done in enumerate(terminals): - if done: - logging.info(f'logging episodic return train... {self.total_steps}') - self.train_logger.add_scalar('episodic_return_train', self.total_rewards[idx], self.total_steps) - self.total_rewards[idx] = 0. - - # zero out lstm states for finished environments - for rstate in self.recurrent_states: - rstate[:, idx].zero_() - - #add everything to storage - storage.append(prediction) - storage.append({ - 'states': states, - 'r': torch.tensor(rewards).unsqueeze(-1).to(device), - 'm': torch.tensor(1 - terminals).unsqueeze(-1).to(device) - }) - states = next_states + for _ in range(config.rollout_length): + self.total_steps += self.num_workers + + #add recurrent states (lstm hidden and lstm cell states) to storage + storage.append({f'recurrent_states_{i}' : rstate for i, rstate in enumerate(self.recurrent_states)}) + + #run the neural net once to get prediction + prediction, self.recurrent_states = self.network(states, self.recurrent_states) + + #step the environment with the action determined by the prediction + next_states, rewards, terminals, _ = self.task.step(to_np(prediction['a'])) + + self.total_rewards += np.asarray(rewards) + + for idx, done in enumerate(terminals): + if done: + logging.info(f'logging episodic return train... {self.total_steps}') + self.train_logger.add_scalar('episodic_return_train', self.total_rewards[idx], self.total_steps) + self.total_rewards[idx] = 0. + + # zero out lstm states for finished environments + for rstate in self.recurrent_states: + rstate[:, idx].zero_() + + #add everything to storage + storage.append(prediction) + storage.append({ + 'states': states, + 'terminals': torch.tensor(terminals).unsqueeze(-1).to(device), + 'r': torch.tensor(rewards).unsqueeze(-1).to(device), + 'm': torch.tensor(1 - terminals).unsqueeze(-1).to(device) + }) + states = next_states self.states = states diff --git a/src/conformer_rl/agents/curriculum_agents.py b/src/conformer_rl/agents/curriculum_agents.py new file mode 100644 index 0000000..8d11715 --- /dev/null +++ b/src/conformer_rl/agents/curriculum_agents.py @@ -0,0 +1,122 @@ +""" +Curriculum-Supported Agents +=========================== +""" + +import logging +import time +from collections import deque + +import numpy as np +import torch + +class ExternalCurriculumAgentMixin(): + """General mixin class to enable curriculum + + Adds functionality to an existing agent for externally interacting with an environment supporting curriculum learning. + + + Parameters + ---------- + config : :class:`~conformer_rl.config.agent_config.Config` + Configuration object for the agent. See notes for a list of config + parameters used by this agent. + + Notes + ----- + In addition to the config parameters required for the base agent class, use of this mixin + requires the following additional parameters in the :class:`~conformer_rl.config.agent_config.Config` object: + + * curriculum_agent_buffer_len + * curriculum_agent_reward_thresh + * curriculum_agent_success_rate + * curriculum_agent_fail_rate + """ + def __init__(self, config): + super().__init__(config) + self.reward_buffer = deque([], maxlen=config.curriculum_agent_buffer_len) + self.curriculum_buffer_len = config.curriculum_agent_buffer_len + self.curriculum_reward_thresh = config.curriculum_agent_reward_thresh + self.curriculum_success_rate = config.curriculum_agent_success_rate + self.curriculum_fail_rate = config.curriculum_agent_fail_rate + + def step(self) -> None: + """Performs one iteration of acquiring samples on the environment + and then trains on the acquired samples. + """ + # sample + self.storage.reset() + sample_start = time.time() + self._sample() + logging.debug(f'sample time: {time.time() - sample_start} seconds') + + # update curriculum + self.update_curriculum() + + # train + train_start = time.time() + self._calculate_advantages() + self._train() + logging.debug(f'train time: {time.time() - train_start} seconds') + + def update_curriculum(self) -> None: + """Evaluates the current performance of the agent and signals the environment to + increase the level (difficulty) or decrease it depending on the agent's performance. + + The agent is evaluated only when the number of episodes elapsed since the last evaluation + has exceeded the parameter ``curriculum_agent_buffer_len`` assigned in the :class:`~conformer_rl.config.agent_config.Config` object. + During the evaluation, the ratio of episodes (out of the last ``curriculum_agent_buffer_len`` episodes) which have a reward exceeding + the ``curriculum_agent_reward_thresh`` parameter defined in the :class:`~conformer_rl.config.agent_config.Config` + is calculated. If this ratio exceeds the ``curriculum_agent_success_rate`` parameter, the environment is signaled + to increase the difficulty of the curriculum. This is done by calling the ``increase_level`` method of the environment. + If the ratio is less than the ``curriculum_agent_fail_rate`` parameter, the environment is told to decrease the difficulty. + """ + current_terminals = torch.cat(self.storage['terminals']).squeeze() + current_rewards = torch.cat(self.storage['r']).squeeze() + self.reward_buffer.extend(current_rewards[current_terminals == True].tolist()) + + if len(self.reward_buffer) >= self.curriculum_buffer_len: + rewbuf = np.array(self.reward_buffer) + pass_rate = (rewbuf >= self.curriculum_reward_thresh).mean() + + if pass_rate > self.curriculum_success_rate: + self.task.env_method('increase_level') + self.reward_buffer.clear() + elif pass_rate < self.curriculum_fail_rate: + self.task.env_method('decrease_level') + self.reward_buffer.clear() + +from conformer_rl.agents import PPOAgent, PPORecurrentAgent +from conformer_rl.agents import A2CAgent, A2CRecurrentAgent + +class PPOExternalCurriculumAgent(ExternalCurriculumAgentMixin, PPOAgent): + """Implementation of :mod:`~conformer_rl.agents.PPO.PPO_agent.PPOAgent` compatible with + environments that use curriculum learning. See :meth:`~conformer_rl.agents.curriculum_agents.ExternalCurriculumAgentMixin.update_curriculum` + for more details. + + """ + pass + +class PPORecurrentExternalCurriculumAgent(ExternalCurriculumAgentMixin, PPORecurrentAgent): + """Implementation of :mod:`~conformer_rl.agents.PPO.PPO_recurrent_agent.PPORecurrentAgent` compatible with + environments that use curriculum learning. See :meth:`~conformer_rl.agents.curriculum_agents.ExternalCurriculumAgentMixin.update_curriculum` + for more details. + + """ + pass + +class A2CExternalCurriculumAgent(ExternalCurriculumAgentMixin, A2CAgent): + """Implementation of :mod:`~conformer_rl.agents.A2C.A2C_agent.A2CAgent` compatible with + environments that use curriculum learning. See :meth:`~conformer_rl.agents.curriculum_agents.ExternalCurriculumAgentMixin.update_curriculum` + for more details. + + """ + pass + +class A2CRecurrentExternalCurriculumAgent(ExternalCurriculumAgentMixin, A2CRecurrentAgent): + """Implementation of :mod:`~conformer_rl.agents.A2C.A2C_recurrent_agent.A2CRecurrentAgent` compatible with + environments that use curriculum learning. See :meth:`~conformer_rl.agents.curriculum_agents.ExternalCurriculumAgentMixin.update_curriculum` + for more details. + + """ + pass \ No newline at end of file diff --git a/src/conformer_rl/config/agent_config.py b/src/conformer_rl/config/agent_config.py index 7968e78..a2e5976 100644 --- a/src/conformer_rl/config/agent_config.py +++ b/src/conformer_rl/config/agent_config.py @@ -64,15 +64,30 @@ class Config: ppo_ratio_clip : float, required by PPO and PPORecurrent agents. Clipping parameter ε for PPO algorithm, see [2]_ for details. + curriculum_agent_buffer_len : int, required by all curriculum agents + The number of most recent completed episodes in which to evaluate the agent on for curriculum learning. + See :meth:`~conformer_rl.agents.curriculum_agents.ExternalCurriculumAgentMixin.update_curriculum` for more details on how + curriculum learning is implemented. + curriculum_agent_reward_thresh : float, required by all curriculum agents + The reward threshold for considering the agent to have "succeeded" in an episode. Used for evaluating the agent + for curriculum learning. + See :meth:`~conformer_rl.agents.curriculum_agents.ExternalCurriculumAgentMixin.update_curriculum` for more details on how + curriculum learning is implemented. + curriculum_agent_success_rate : float, required by all curriculum agents + The minimum success rate for the agent to signal the environment to increase the level/difficulty for the curriculum. + See :meth:`~conformer_rl.agents.curriculum_agents.ExternalCurriculumAgentMixin.update_curriculum` for more details on how + curriculum learning is implemented. + curriculum_agent_fail_rate : float, required by all curriculum agents + The maximum success rate for the agent to signal the environment to decrease the level/difficulty of the curriculum. + See :meth:`~conformer_rl.agents.curriculum_agents.ExternalCurriculumAgentMixin.update_curriculum` for more details on how + curriculum learning is implemented. + data_dir : str, required by all agents Directory path for saving log files. use_tensorboard : bool, required by all agents Whether or not to save agent information to Tensorboard. - - - References ---------- .. [1] `Generalized Advantage Estimation (GAE) paper `_ @@ -92,8 +107,6 @@ def __init__(self): self.network = RTGNGatRecurrent(6, 128, node_dim=5).to(self.device) self.optimizer_fn = lambda params : torch.optim.Adam(params, lr=1e-5, eps=1e-5) - # self.curriculum = None - # batch hyperparameters self.rollout_length = 20 self.max_steps = 50000 @@ -113,6 +126,12 @@ def __init__(self): self.gradient_clip = 0.5 self.ppo_ratio_clip = 0.2 + # curriculum hyperparameters + self.curriculum_agent_buffer_len = 20 + self.curriculum_agent_reward_thresh = 0.7 + self.curriculum_agent_success_rate = 0.7 + self.curriculum_agent_fail_rate = 0.2 + # logging config self.data_dir = 'data' self.use_tensorboard = True diff --git a/src/conformer_rl/config/mol_config.py b/src/conformer_rl/config/mol_config.py index 7b5d950..63096dc 100644 --- a/src/conformer_rl/config/mol_config.py +++ b/src/conformer_rl/config/mol_config.py @@ -13,6 +13,9 @@ class MolConfig: ---------- mol : rdkit Mol, required for all environments The molecule to be used by the environment. + num_conformers : int + The number of conformers to be generated in each episode of the environment for this molecule. In other words, + the number of steps the environment will iterate through in each episode before resetting and entering a new episode. seed: int, required for all environments Seed for generating initial conformers for the molecule. If set to -1, the seed is randomized. @@ -34,6 +37,7 @@ class MolConfig: def __init__(self): self.mol = None + self.num_conformers = 200 self.seed = -1 # Parameters for using Gibbs Score diff --git a/src/conformer_rl/environments/conformer_env.py b/src/conformer_rl/environments/conformer_env.py index 23efe8f..7f99dce 100644 --- a/src/conformer_rl/environments/conformer_env.py +++ b/src/conformer_rl/environments/conformer_env.py @@ -21,8 +21,6 @@ class ConformerEnv(gym.Env): ---------- mol_config : :class:`~conformer_rl.config.mol_config.MolConfig` Configuration object specifying molecule and parameters to be used in the environment. - max_steps : int - The number of steps before the end of an episode. Attributes ---------- @@ -41,11 +39,11 @@ class ConformerEnv(gym.Env): """ metadata = {'render.modes': ['human']} - def __init__(self, mol_config: MolConfig, max_steps = 200): - super(ConformerEnv, self).__init__() + def __init__(self, mol_config: MolConfig): + gym.Env.__init__(self) logging.debug('initializing conformer environment') self.config = copy.deepcopy(mol_config) - self.max_steps = max_steps + self.max_steps = mol_config.num_conformers self.total_reward = 0 self.current_step = 0 @@ -121,6 +119,7 @@ def reset(self) -> object: self.current_step = 0 self.step_info = {} + self.episode_info = {} self.episode_info['mol'] = Chem.Mol(self.mol) self.episode_info['mol'].RemoveAllConformers() diff --git a/src/conformer_rl/environments/curriculum_conformer_env.py b/src/conformer_rl/environments/curriculum_conformer_env.py new file mode 100644 index 0000000..c4f55c6 --- /dev/null +++ b/src/conformer_rl/environments/curriculum_conformer_env.py @@ -0,0 +1,113 @@ +""" +Curriculum Conformer_env +======================== +""" + +import logging +from typing import List +import copy + +import numpy as np +from rdkit.Chem import AllChem as Chem +from rdkit.Chem import TorsionFingerprints +import gym + +from conformer_rl.config import MolConfig +from conformer_rl.environments.conformer_env import ConformerEnv + +class CurriculumConformerEnv(ConformerEnv): + """Base interface for building conformer generation environments with support for curriculum learning. + + Parameters + ---------- + mol_configs : list of :class:`~conformer_rl.config.mol_config.MolConfig` + List of configuration object specifying the molecules and their corresponding parameters to be trained on + as part of the curriculum. The list should be sorted in order of increasing task difficulty. + + Attributes + ---------- + configs : list of :class:`~conformer_rl.config.mol_config.MolConfig` + Configuration objects specifying molecules and corresponding parameters to be used in the environment, + in the order of the designated curriculum (ordered from least to most difficult). + total_reward : float + Keeps track of the total reward for the current episode. + current_step : int + Keeps track of the number of elapsed steps in the current episode. + step_info : dict from str to list + Used for keeping track of data obtained at each step of an episode for logging. + episode_info : dict from str to Any + Used for keeping track of data useful at the end of an episode, such as total_reward, for logging. + curriculum_max_index : int + One plus the maximum index in which a molecule/task from the input list of ``mol_configs`` can be selected to be trained on. + This attribute will be increased as the agent gets better at the current tasks in the curriculum and is ready to move on to + more difficult tasks. + """ + + def __init__(self, mol_configs: List[MolConfig]): + gym.Env.__init__(self) + logging.debug('initializing curriculum conformer environment') + self.configs = copy.deepcopy(mol_configs) + self.curriculum_max_index = 1 + + self.config = self.configs[0] + self.mol = self.config.mol + self.mol.RemoveAllConformers() + if Chem.EmbedMolecule(self.mol, randomSeed=self.config.seed, useRandomCoords=True) == -1: + raise Exception('Unable to embed molecule with conformer using rdkit') + self.conf = self.mol.GetConformer() + nonring, ring = TorsionFingerprints.CalculateTorsionLists(self.mol) + self.nonring = [list(atoms[0]) for atoms, ang in nonring] + + self.reset() + + def reset(self) -> object: + """Resets the environment and returns the observation of the environment. + """ + logging.debug('reset called') + + self.total_reward = 0 + self.current_step = 0 + self.step_info = {} + self.episode_info = {} + + # set index for the next molecule based on curriculum + if self.curriculum_max_index == 1: + index = 0 + else: + p = 0.5 * np.ones(self.curriculum_max_index) / (self.curriculum_max_index - 1) + p[-1] = 0.5 + index = np.random.choice(self.curriculum_max_index, p=p) + + logging.debug(f'Current Curriculum Molecule Index: {index}') + + # set up current molecule + mol_config = self.configs[index] + self.config = mol_config + self.max_steps = mol_config.num_conformers + self.mol = mol_config.mol + self.mol.RemoveAllConformers() + if Chem.EmbedMolecule(self.mol, randomSeed=self.config.seed, useRandomCoords=True) == -1: + raise Exception('Unable to embed molecule with conformer using rdkit') + self.conf = self.mol.GetConformer() + nonring, ring = TorsionFingerprints.CalculateTorsionLists(self.mol) + self.nonring = [list(atoms[0]) for atoms, ang in nonring] + + self.episode_info['mol'] = Chem.Mol(self.mol) + self.episode_info['mol'].RemoveAllConformers() + + obs = self._obs() + return obs + + + def increase_level(self): + """Updates the ``curriculum_max_index`` attribute after obtaining signal from the agent that a favorable + reward threshold has been achieved. + """ + self.curriculum_max_index = min(self.curriculum_max_index * 2, len(self.configs)) + + def decrease_level(self): + """Updates the ``curriculum_max_index`` attribute after obtaining signal that the agent is performing + poorly on the current curriclum range. + """ + if self.curriculum_max_index > 1: + self.curriculum_max_index = self.curriculum_max_index // 2 \ No newline at end of file diff --git a/src/conformer_rl/environments/environment_components/reward_mixins.py b/src/conformer_rl/environments/environment_components/reward_mixins.py index fbb65be..23cf61e 100644 --- a/src/conformer_rl/environments/environment_components/reward_mixins.py +++ b/src/conformer_rl/environments/environment_components/reward_mixins.py @@ -21,10 +21,11 @@ class GibbsRewardMixin: .. [1] `TorsionNet paper `_ """ def reset(self): + obs = super().reset() self.seen = set() self.repeats = 0 self.episode_info['repeats'] = 0 - return super().reset() + return obs def _reward(self) -> float: """ @@ -55,9 +56,10 @@ class GibbsEndPruningRewardMixin: """ def reset(self): + obs = super().reset() self.backup_mol = Chem.Mol(self.mol) self.backup_mol.RemoveAllConformers() - return super().reset() + return obs def _reward(self) -> float: """ @@ -94,10 +96,11 @@ class GibbsPruningRewardMixin: """ def reset(self): + obs = super().reset() self.backup_mol = Chem.Mol(self.mol) self.backup_mol.RemoveAllConformers() self.backup_energys = [] - return super().reset() + return obs def _reward(self) -> float: """ diff --git a/src/conformer_rl/environments/environments.py b/src/conformer_rl/environments/environments.py index e5ef3cf..9c16051 100644 --- a/src/conformer_rl/environments/environments.py +++ b/src/conformer_rl/environments/environments.py @@ -19,12 +19,14 @@ """ from conformer_rl.environments.conformer_env import ConformerEnv +from conformer_rl.environments.curriculum_conformer_env import CurriculumConformerEnv from conformer_rl.environments.environment_components.action_mixins import ContinuousActionMixin, DiscreteActionMixin from conformer_rl.environments.environment_components.reward_mixins import GibbsRewardMixin, GibbsPruningRewardMixin, GibbsEndPruningRewardMixin, GibbsLogPruningRewardMixin from conformer_rl.environments.environment_components.obs_mixins import GraphObsMixin, AtomCoordsTypeGraphObsMixin + class DiscreteActionEnv(DiscreteActionMixin, GraphObsMixin, ConformerEnv): """ * Action Handler: :class:`~conformer_rl.environments.environment_components.action_mixins.DiscreteActionMixin` @@ -67,4 +69,10 @@ class GibbsScoreLogPruningEnv(GibbsLogPruningRewardMixin, DiscreteActionMixin, A * Reward Handler: :class:`~conformer_rl.environments.environment_components.reward_mixins.GibbsLogPruningRewardMixin` * Observation Handler: :class:`~conformer_rl.environments.environment_components.obs_mixins.AtomCoordsTypeGraphObsMixin` """ - pass \ No newline at end of file + pass + +class GibbsScorePruningCurriculumEnv(GibbsPruningRewardMixin, AtomCoordsTypeGraphObsMixin, DiscreteActionMixin, CurriculumConformerEnv): + """Same handlers as the :class:`~conformer_rl.environments.environment.GibbsScorePruningEnv` but with support for curriculum learning.""" + +class GibbsScoreLogPruningCurriculumEnv(GibbsLogPruningRewardMixin, DiscreteActionMixin, AtomCoordsTypeGraphObsMixin, CurriculumConformerEnv): + """Same handlers as the :class:`~conformer_rl.environments.environment.GibbsScoreLogPruningEnv` but with support for curriculum learning.""" \ No newline at end of file diff --git a/src/conformer_rl/molecule_generation/generate_molecule_config.py b/src/conformer_rl/molecule_generation/generate_molecule_config.py index 0edbb1f..874945a 100644 --- a/src/conformer_rl/molecule_generation/generate_molecule_config.py +++ b/src/conformer_rl/molecule_generation/generate_molecule_config.py @@ -11,13 +11,13 @@ import pickle def test_alkane_config() -> MolConfig: - config = config_from_smiles("CC(CCC)CCCC(CCCC)CC", calc_normalizers=False) + config = config_from_smiles("CC(CCC)CCCC(CCCC)CC", num_conformers=200, calc_normalizers=False) config.E0 = 7.668625034772399 config.Z0 = 13.263723987526067 config.tau = 503 return config -def config_from_molFile(file: str, calc_normalizers: bool = False, ep_steps: int = 200, pruning_thresh: float = 0.05, save_file: str = "") -> MolConfig: +def config_from_molFile(file: str, num_conformers: int, calc_normalizers: bool = False, pruning_thresh: float = 0.05, save_file: str = "") -> MolConfig: """Generates a :class:`~conformer_rl.config.mol_config.MolConfig` object for a molecule specified by the location of a `MOL `_ file containing the molecule. @@ -27,14 +27,13 @@ def config_from_molFile(file: str, calc_normalizers: bool = False, ep_steps: int file : str Name of the MOL file containing the molecule to be converted into a :class:`~conformer_rl.config.mol_config.MolConfig` object. + num_conformers : int + Number of conformers to be generated. This parameter is also used for calculating normalizers. + calc_normalizers : bool Whether to calculate normalizing constants used in the Gibbs score reward. See :class:`~conformer_rl.config.mol_config.MolConfig` for more details. - ep_steps : int - Number of conformers to be generated. This parameter is only used for calculating normalizers and is ignored - if ``calc_normalizers`` is set to ``False``. - pruning_thresh : float Torsional fingerprint distance (TFD) threshold for pruning similar conformers when calculating normalizers. This parameter is only used for calculating normalizers and is ignored @@ -52,9 +51,9 @@ def config_from_molFile(file: str, calc_normalizers: bool = False, ep_steps: int """ mol = Chem.MolFromMolFile(file) - return config_from_rdkit(mol, calc_normalizers, ep_steps, pruning_thresh, save_file) + return config_from_rdkit(mol, num_conformers, calc_normalizers, pruning_thresh, save_file) -def config_from_smiles(smiles: str, calc_normalizers: bool = False, ep_steps: int = 200, pruning_thresh: float = 0.05, save_file: str = "") -> MolConfig: +def config_from_smiles(smiles: str, num_conformers: int, calc_normalizers: bool = False, pruning_thresh: float = 0.05, save_file: str = "") -> MolConfig: """Generates a :class:`~conformer_rl.config.mol_config.MolConfig` object for a molecule specified by a `SMILES `_ string. @@ -63,14 +62,13 @@ def config_from_smiles(smiles: str, calc_normalizers: bool = False, ep_steps: in smiles : str A SMILES string representing the molecule. + num_conformers : int + Number of conformers to be generated. This parameter is also used for calculating normalizers. + calc_normalizers : bool Whether to calculate normalizing constants used in the Gibbs score reward. See :class:`~conformer_rl.config.mol_config.MolConfig` for more details. - ep_steps : int - Number of conformers to be generated. This parameter is only used for calculating normalizers and is ignored - if ``calc_normalizers`` is set to ``False``. - pruning_thresh : float Torsional fingerprint distance (TFD) threshold for pruning similar conformers when calculating normalizers. This parameter is only used for calculating normalizers and is ignored @@ -87,9 +85,9 @@ def config_from_smiles(smiles: str, calc_normalizers: bool = False, ep_steps: in normalizing constants if ``calc_normalizers`` is set to ``True``. """ mol = Chem.MolFromSmiles(smiles) - return config_from_rdkit(mol, calc_normalizers, ep_steps, pruning_thresh, save_file) + return config_from_rdkit(mol, num_conformers, calc_normalizers, pruning_thresh, save_file) -def config_from_rdkit(mol: Chem.rdchem.Mol, calc_normalizers: bool = False, ep_steps: int=200, pruning_thresh: float=0.05, save_file: str = "") -> MolConfig: +def config_from_rdkit(mol: Chem.rdchem.Mol, num_conformers: int, calc_normalizers: bool = False, pruning_thresh: float=0.05, save_file: str = "") -> MolConfig: """Generates a :class:`~conformer_rl.config.mol_config.MolConfig` object for a molecule specified by an rdkit molecule object. Parameters @@ -97,14 +95,13 @@ def config_from_rdkit(mol: Chem.rdchem.Mol, calc_normalizers: bool = False, ep_s mol: rdkit.Chem.rdchem.Mol A rdkit molecule object. + num_conformers : int + Number of conformers to be generated. This parameter is also used for calculating normalizers. + calc_normalizers : bool Whether to calculate normalizing constants used in the Gibbs score reward. See :class:`~conformer_rl.config.mol_config.MolConfig` for more details. - ep_steps : int - Number of conformers to be generated. This parameter is only used for calculating normalizers and is ignored - if ``calc_normalizers`` is set to ``False``. - pruning_thresh : float Torsional fingerprint distance (TFD) threshold for pruning similar conformers when calculating normalizers. This parameter is only used for calculating normalizers and is ignored @@ -124,8 +121,9 @@ def config_from_rdkit(mol: Chem.rdchem.Mol, calc_normalizers: bool = False, ep_s config = MolConfig() mol = _preprocess_mol(mol) config.mol = mol + config.num_conformers = num_conformers if calc_normalizers: - config.E0, config.Z0 = calculate_normalizers(mol, ep_steps, pruning_thresh) + config.E0, config.Z0 = calculate_normalizers(mol, num_conformers, pruning_thresh) logging.info('mol_config object constructed for the following molecule:') logging.info(Chem.MolToMolBlock(mol)) diff --git a/tests/model_integration/test_gat.py b/tests/model_integration/test_gat.py index e4d0932..6bacf74 100644 --- a/tests/model_integration/test_gat.py +++ b/tests/model_integration/test_gat.py @@ -18,7 +18,7 @@ def test_gat(mocker): utils.set_one_thread() - mol_config = config_from_rdkit(generate_lignin(2), calc_normalizers=True) + mol_config = config_from_rdkit(generate_lignin(2), num_conformers=8, calc_normalizers=True) config = Config() config.tag = 'example1' @@ -39,8 +39,8 @@ def test_gat(mocker): config.optimizer_fn = lambda params: torch.optim.Adam(params, lr=lr, eps=1e-5) # Task Settings - config.train_env = Task('GibbsScorePruningEnv-v0', concurrency=False, num_envs=config.num_workers, seed=np.random.randint(0,1e5), mol_config=mol_config, max_steps=4) - config.eval_env = Task('GibbsScorePruningEnv-v0', seed=np.random.randint(0,7e4), mol_config=mol_config, max_steps=20) + config.train_env = Task('GibbsScorePruningEnv-v0', concurrency=False, num_envs=config.num_workers, seed=np.random.randint(0,1e5), mol_config=mol_config) + config.eval_env = Task('GibbsScorePruningEnv-v0', seed=np.random.randint(0,7e4), mol_config=mol_config) config.curriculum = None agent = PPOAgent(config) diff --git a/tests/model_integration/test_gat_recurrent.py b/tests/model_integration/test_gat_recurrent.py index d9ecc76..7580a9d 100644 --- a/tests/model_integration/test_gat_recurrent.py +++ b/tests/model_integration/test_gat_recurrent.py @@ -18,7 +18,7 @@ def test_gat_recurrent(mocker): utils.set_one_thread() - mol_config = config_from_rdkit(generate_lignin(2), calc_normalizers=True) + mol_config = config_from_rdkit(generate_lignin(2), num_conformers=8, calc_normalizers=True) config = Config() config.tag = 'example1' @@ -40,8 +40,8 @@ def test_gat_recurrent(mocker): config.optimizer_fn = lambda params: torch.optim.Adam(params, lr=lr, eps=1e-5) # Task Settings - config.train_env = Task('GibbsScorePruningEnv-v0', concurrency=False, num_envs=config.num_workers, seed=np.random.randint(0,1e5), mol_config=mol_config, max_steps=4) - config.eval_env = Task('GibbsScorePruningEnv-v0', seed=np.random.randint(0,7e4), mol_config=mol_config, max_steps=20) + config.train_env = Task('GibbsScorePruningEnv-v0', concurrency=False, num_envs=config.num_workers, seed=np.random.randint(0,1e5), mol_config=mol_config) + config.eval_env = Task('GibbsScorePruningEnv-v0', seed=np.random.randint(0,7e4), mol_config=mol_config) config.curriculum = None agent = PPORecurrentAgent(config) diff --git a/tests/model_integration/test_rtgn.py b/tests/model_integration/test_rtgn.py index 2da92d4..613be67 100644 --- a/tests/model_integration/test_rtgn.py +++ b/tests/model_integration/test_rtgn.py @@ -18,7 +18,7 @@ def test_rtgn(mocker): utils.set_one_thread() - mol_config = config_from_rdkit(generate_lignin(2), calc_normalizers=True) + mol_config = config_from_rdkit(generate_lignin(2), num_conformers=8, calc_normalizers=True) config = Config() config.tag = 'example1' @@ -39,8 +39,8 @@ def test_rtgn(mocker): config.optimizer_fn = lambda params: torch.optim.Adam(params, lr=lr, eps=1e-5) # Task Settings - config.train_env = Task('GibbsScorePruningEnv-v0', concurrency=False, num_envs=config.num_workers, seed=np.random.randint(0,1e5), mol_config=mol_config, max_steps=4) - config.eval_env = Task('GibbsScorePruningEnv-v0', seed=np.random.randint(0,7e4), mol_config=mol_config, max_steps=20) + config.train_env = Task('GibbsScorePruningEnv-v0', concurrency=False, num_envs=config.num_workers, seed=np.random.randint(0,1e5), mol_config=mol_config) + config.eval_env = Task('GibbsScorePruningEnv-v0', seed=np.random.randint(0,7e4), mol_config=mol_config) config.curriculum = None agent = PPOAgent(config) diff --git a/tests/model_integration/test_rtgn_recurrent.py b/tests/model_integration/test_rtgn_recurrent.py index 7a3f2c5..230c4d6 100644 --- a/tests/model_integration/test_rtgn_recurrent.py +++ b/tests/model_integration/test_rtgn_recurrent.py @@ -18,7 +18,7 @@ def test_rtgn_recurrent(mocker): utils.set_one_thread() - mol_config = config_from_rdkit(generate_lignin(2), calc_normalizers=True) + mol_config = config_from_rdkit(generate_lignin(2), num_conformers=8, calc_normalizers=True) config = Config() config.tag = 'example1' @@ -39,8 +39,8 @@ def test_rtgn_recurrent(mocker): config.optimizer_fn = lambda params: torch.optim.Adam(params, lr=lr, eps=1e-5) # Task Settings - config.train_env = Task('GibbsScorePruningEnv-v0', concurrency=False, num_envs=config.num_workers, seed=np.random.randint(0,1e5), mol_config=mol_config, max_steps=4) - config.eval_env = Task('GibbsScorePruningEnv-v0', seed=np.random.randint(0,7e4), mol_config=mol_config, max_steps=20) + config.train_env = Task('GibbsScorePruningEnv-v0', concurrency=False, num_envs=config.num_workers, seed=np.random.randint(0,1e5), mol_config=mol_config) + config.eval_env = Task('GibbsScorePruningEnv-v0', seed=np.random.randint(0,7e4), mol_config=mol_config) config.curriculum = None agent = PPORecurrentAgent(config)