
fix td3 algo and doc
zakaria-narjis committed Aug 9, 2024
1 parent 2d932c3 commit cefb4a2
Showing 7 changed files with 31 additions and 33 deletions.
7 changes: 4 additions & 3 deletions modularl/agents/td3.py
@@ -146,7 +146,7 @@ def act_train(self, batch_obs: torch.Tensor) -> torch.Tensor:
         ):
             return self.burning_action_func(batch_obs).to(self.device)
         else:
-            actions = self.actor(batch_obs.to(self.device))
+            actions = self.actor.get_action(batch_obs.to(self.device))
             actions = actions + torch.normal(
                 0,
                 self.actor.action_scale * self.exploration_noise,
@@ -162,7 +162,7 @@ def act_eval(self, batch_obs: torch.Tensor) -> torch.Tensor:
 
         self.actor.eval()
         with torch.no_grad():
-            actions = self.actor(batch_obs.to(self.device))
+            actions = self.actor.get_action(batch_obs.to(self.device))
         self.actor.train()
         return actions
 
@@ -216,7 +216,8 @@ def update(self) -> None:
 
         if self.global_step % self.policy_frequency == 0:
             actor_loss = -self.qf1(
-                data["observations"], self.actor(data["observations"])
+                data["observations"],
+                self.actor.get_action(data["observations"]),
             ).mean()
 
             self.actor_optimizer.zero_grad()
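Why this change matters: the actor's forward pass returns the raw tanh-squashed network output, while get_action rescales it by action_scale and action_bias into the environment's action range (see deterministic_policy.py below). Calling self.actor(...) directly therefore fed unscaled actions to the exploration noise and to the actor loss. The snippet below is a minimal, self-contained sketch of that difference; _StandInActor and its bounds are hypothetical, not the library's class.

import torch
import torch.nn as nn


class _StandInActor(nn.Module):
    """Stand-in actor, only to contrast forward() with get_action()."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(3, 32), nn.ReLU(), nn.Linear(32, 1), nn.Tanh()
        )
        # Hypothetical environment with actions in [-2, 2].
        self.action_scale = torch.tensor([2.0])
        self.action_bias = torch.tensor([0.0])

    def forward(self, batch_obs):
        return self.net(batch_obs)  # raw output in (-1, 1)

    def get_action(self, batch_obs):
        return self(batch_obs) * self.action_scale + self.action_bias


actor = _StandInActor()
batch_obs = torch.randn(8, 3)

raw_actions = actor(batch_obs)          # old behaviour: unscaled output
actions = actor.get_action(batch_obs)   # fixed behaviour: env-scaled actions
# Exploration noise scaled as in act_train (exploration_noise = 0.1 here).
actions = actions + torch.normal(0.0, actor.action_scale * 0.1)
print(raw_actions.abs().max().item() <= 1.0, actions.shape)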
18 changes: 6 additions & 12 deletions modularl/policies/deterministic_policy.py
@@ -78,20 +78,14 @@ def __init__(
         if use_xavier:
             self._initialize_weights()
 
-    def forward(self, observation):
-        output = self.network(observation)
+    def forward(self, batch_observation):
+        output = self.network(batch_observation)
         return output
 
-    def get_action(self, observation):
-        """
-        Get action from the policy
-        Args:
-            observation (torch.Tensor): Observation from the environment
-        return:
-            action (torch.Tensor): Action to be taken
-        """  # noqa
-        actions = self(observation) * self.action_scale + self.action_bias
+    def get_action(self, batch_observation):
+        actions = (
+            self(batch_observation) * self.action_scale + self.action_bias
+        )
         return actions
 
     def _initialize_weights(self):
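For reference, action_scale and action_bias are typically derived from the environment's action bounds so that the tanh output of forward lands inside [low, high] after get_action. A small illustrative sketch with made-up bounds (not taken from the library):

import torch

# Hypothetical per-dimension action bounds.
action_low = torch.tensor([-2.0, 0.0])
action_high = torch.tensor([2.0, 1.0])

action_scale = (action_high - action_low) / 2.0  # tensor([2.0, 0.5])
action_bias = (action_high + action_low) / 2.0   # tensor([0.0, 0.5])

raw = torch.tanh(torch.randn(4, 2))          # forward()-style output in (-1, 1)
actions = raw * action_scale + action_bias   # rescaled into (low, high)
assert torch.all(actions >= action_low)
assert torch.all(actions <= action_high)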
8 changes: 4 additions & 4 deletions modularl/policies/gaussian_policy.py
@@ -89,15 +89,15 @@ def __init__(
             self._initialize_weights()
 
     def forward(
-        self, observation: torch.Tensor
+        self, batch_observation: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        x = self.network(observation)
+        x = self.network(batch_observation)
         mean = self.fc_mean(x)
         log_std = self.fc_logstd(x)
         log_std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
         return mean, log_std
 
-    def get_action(self, observation: torch.Tensor):
+    def get_action(self, batch_observation: torch.Tensor):
         """
         Get action from the policy
@@ -108,7 +108,7 @@ def get_action(self, observation: torch.Tensor):
             log_prob (torch.Tensor): Log probability of the action (only if deterministic is False)
             mean (torch.Tensor): Mean of the action distribution
         """  # noqa
-        mean, log_std = self(observation)
+        mean, log_std = self(batch_observation)
         std = log_std.exp()
         normal = torch.distributions.Normal(mean, std)
         x_t = (
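The diff cuts off get_action just after the Normal distribution is built (x_t = ...). For context, Gaussian policies of this kind usually continue with the reparameterisation trick plus a tanh squash and log-probability correction; the sketch below shows that generic pattern with dummy tensors and is not copied from the library.

import torch

mean = torch.zeros(4, 2)
log_std = torch.full((4, 2), -0.5)
action_scale = torch.tensor([1.0, 2.0])  # hypothetical scaling
action_bias = torch.tensor([0.0, 0.0])

std = log_std.exp()
normal = torch.distributions.Normal(mean, std)
x_t = normal.rsample()                     # reparameterised sample (keeps gradients)
y_t = torch.tanh(x_t)                      # squash into (-1, 1)
action = y_t * action_scale + action_bias  # rescale into the action range
# Change-of-variables correction for the tanh squash.
log_prob = normal.log_prob(x_t) - torch.log(action_scale * (1 - y_t.pow(2)) + 1e-6)
log_prob = log_prob.sum(dim=1, keepdim=True)
print(action.shape, log_prob.shape)  # torch.Size([4, 2]) torch.Size([4, 1])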
14 changes: 6 additions & 8 deletions modularl/policies/policy.py
@@ -2,33 +2,31 @@
 import torch.nn as nn
 from abc import ABC, abstractmethod
 
-from typing import Any
-
 
 class AbstractPolicy(nn.Module, ABC):
 
     def __init__(self, **kwargs):
         super(AbstractPolicy, self).__init__()
 
     @abstractmethod
-    def forward(self, observation: torch.Tensor) -> torch.Tensor:
+    def forward(self, batch_observation: torch.Tensor) -> torch.Tensor:
         """
         Forward pass of the policy network
         Args:
-            observation (torch.Tensor): Observation from the environment
+            batch_observation (torch.Tensor): Batch observation from the environment
         """  # noqa
 
     @abstractmethod
-    def get_action(self, observation: torch.Tensor) -> Any:
+    def get_action(self, batch_observation: torch.Tensor) -> torch.Tensor:
         """
         Get action from the policy
         Args:
-            observation (torch.Tensor): Observation from the environment
+            batch_observation (torch.Tensor): Batch observation from the environment
         return:
-            action (torch.Tensor): Action to be taken
-        """
+            batch_action (torch.Tensor): Batch action to be taken
+        """  # noqa
 
     @abstractmethod
     def _initialize_weights(self) -> None:
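With the interface above, a concrete policy only has to implement forward, get_action, and _initialize_weights with the batched signatures. A toy example, assuming the package is installed so that modularl.policies.policy is importable; LinearPolicy itself is not part of the library:

import torch
import torch.nn as nn

from modularl.policies.policy import AbstractPolicy


class LinearPolicy(AbstractPolicy):
    """Toy deterministic policy implementing the abstract interface."""

    def __init__(self, obs_dim: int, action_dim: int):
        super().__init__()
        self.network = nn.Linear(obs_dim, action_dim)
        self._initialize_weights()

    def forward(self, batch_observation: torch.Tensor) -> torch.Tensor:
        return self.network(batch_observation)

    def get_action(self, batch_observation: torch.Tensor) -> torch.Tensor:
        return torch.tanh(self(batch_observation))

    def _initialize_weights(self) -> None:
        nn.init.xavier_uniform_(self.network.weight)
        nn.init.zeros_(self.network.bias)


policy = LinearPolicy(obs_dim=4, action_dim=2)
print(policy.get_action(torch.randn(8, 4)).shape)  # torch.Size([8, 2])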
10 changes: 5 additions & 5 deletions modularl/q_functions/q_functions.py
@@ -7,10 +7,10 @@ class StateQFunction(nn.Module, ABC):
     """Abstract Q-function with state input."""
 
     @abstractmethod
-    def forward(self, observation: torch.Tensor) -> torch.Tensor:
+    def forward(self, batch_observation: torch.Tensor) -> torch.Tensor:
         """
         Forward pass of the Q-function network.
-        :param observation: Batch Observation tensor.
+        :param batch_observation: Batch Observation tensor.
         :return: Q-value tensor.
         """
         raise NotImplementedError()
@@ -21,12 +21,12 @@ class StateActionQFunction(nn.Module, ABC):
 
     @abstractmethod
     def forward(
-        self, observation: torch.Tensor, actions: torch.Tensor
+        self, batch_observation: torch.Tensor, actions: torch.Tensor
     ) -> torch.Tensor:
         """
         Forward pass of the Q-function network.
-        :param observation: Batch Observation tensor.
-        :param actions: Batch Action tensor.
+        :param batch_observation: Batch Observation tensor.
+        :param batch_actions: Batch Action tensor.
         :return: Q-value tensor.
         """
         raise NotImplementedError()
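A concrete StateActionQFunction just needs a forward that maps a batch of observations and actions to one Q-value per row. A minimal sketch, assuming forward is the only abstract method (as shown here) and that the class is importable from modularl.q_functions.q_functions; MLPQFunction is illustrative only:

import torch
import torch.nn as nn

from modularl.q_functions.q_functions import StateActionQFunction


class MLPQFunction(StateActionQFunction):
    """Toy Q(s, a) network, not part of the library."""

    def __init__(self, obs_dim: int, action_dim: int, hidden: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim + action_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(
        self, batch_observation: torch.Tensor, actions: torch.Tensor
    ) -> torch.Tensor:
        # Concatenate observations and actions along the feature dimension.
        return self.net(torch.cat([batch_observation, actions], dim=1))


qf = MLPQFunction(obs_dim=4, action_dim=2)
print(qf(torch.randn(8, 4), torch.randn(8, 2)).shape)  # torch.Size([8, 1])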
2 changes: 1 addition & 1 deletion setup.py
@@ -5,7 +5,7 @@
 
 setup(
     name="modularl",
-    version="0.1.3",
+    version="0.1.4",
     author="Zakaria Narjis",
     author_email="zakaria.narjis.97@gmail.com",
     description="A modular reinforcement learning library",
5 changes: 5 additions & 0 deletions tests/agents/test_td3.py
@@ -46,6 +46,11 @@ def forward(self, x):
         x = F.relu(self.fc2(x))
         return torch.tanh(self.fc3(x))
 
+    def get_action(self, observation):
+
+        actions = self(observation) * self.action_scale + self.action_bias
+        return actions
+
     def _initialize_weights(self):
         for m in self.modules():
             if isinstance(m, nn.Linear):
