import torch
import torch.nn as nn
import torch.nn.init as init
from modularl.policies.policy import AbstractPolicy
from typing import Any, Optional, Tuple
LOG_STD_MAX = 2
LOG_STD_MIN = -20
[docs]
class GaussianPolicy(AbstractPolicy):
"""
Gaussian Policy for continuous action spaces.
:param observation_shape: Dimension of the observation space.
:type observation_shape: int
:param action_shape: Dimension of the action space.
:type action_shape: int
:param high_action: Upper bound of the action space.
:type high_action: float
:param low_action: Lower bound of the action space.
:type low_action: float
:param network: Custom neural network to represent the policy. If None, a default network is used. Defaults to None.
:type network: nn.Module, optional
:param use_xavier: Whether to use Xavier initialization for weights. Defaults to True.
:type use_xavier: bool, optional
Note:
If a custom network is provided, it should be headless, meaning that this class will add additional
linear layers on top of the provided network. Specifically, the class appends two `nn.Linear` layers for
mean and log_std, with input size equal to the output features of the last layer in the provided network.
The head of the network consists of two nn.Linear layers for mean and log_std, with input size equal to
the output features of the last layer in the provided network.
""" # noqa
def __init__(
self,
observation_shape: int,
action_shape: int,
high_action: float,
low_action: float,
network: Optional[nn.Module] = None,
use_xavier: bool = True,
**kwargs: Any,
):
super().__init__(**kwargs)
self.high_action = high_action
self.low_action = low_action
self.action_shape = action_shape
self.observation_shape = observation_shape
if network is None:
self.network = nn.Sequential(
nn.Linear(observation_shape, 16 * observation_shape),
nn.ReLU(),
nn.Linear(16 * observation_shape, 16 * observation_shape),
nn.ReLU(),
)
else:
self.network = network
self.fc_mean = nn.Linear(
self.network[-2].out_features, self.action_shape
)
self.fc_logstd = nn.Linear(
self.network[-2].out_features, self.action_shape
)
# action rescaling
self.register_buffer(
"action_scale",
torch.tensor(
(self.high_action - self.low_action) / 2.0, dtype=torch.float32
),
)
self.register_buffer(
"action_bias",
torch.tensor(
(self.high_action + self.low_action) / 2.0, dtype=torch.float32
),
)
if use_xavier:
self._initialize_weights()
[docs]
def forward(
self, batch_observation: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
x = self.network(batch_observation)
mean = self.fc_mean(x)
log_std = self.fc_logstd(x)
log_std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
return mean, log_std
[docs]
def get_action(self, batch_observation: torch.Tensor):
"""
Get action from the policy
Args:
observation (torch.Tensor): Observation from the environment
Returns:
action (torch.Tensor): Sampled action from the policy distribution (only if deterministic is False)
log_prob (torch.Tensor): Log probability of the action (only if deterministic is False)
mean (torch.Tensor): Mean of the action distribution
""" # noqa
mean, log_std = self(batch_observation)
std = log_std.exp()
normal = torch.distributions.Normal(mean, std)
x_t = (
normal.rsample()
) # for reparameterization trick (mean + std * N(0,1))
y_t = torch.tanh(x_t)
action = y_t * self.action_scale + self.action_bias
log_prob = normal.log_prob(x_t)
# Enforcing Action Bound
log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + 1e-6)
log_prob = log_prob.sum(1, keepdim=True)
mean = torch.tanh(mean) * self.action_scale + self.action_bias
return action, log_prob, mean
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Linear):
init.xavier_uniform_(m.weight)
if m.bias is not None:
init.zeros_(m.bias)