Deterministic Policy
Deterministic Policy Class
- class modularl.policies.DeterministicPolicy(observation_shape: int, action_shape: int, high_action: float, low_action: float, network: Module | None = None, use_xavier: bool = True, **kwargs: Any)
Bases: AbstractPolicy
Deterministic policy for continuous action spaces.
- Parameters:
observation_shape (int) – Dimension of the observation space.
action_shape (int) – Dimension of the action space.
high_action (float) – Upper bound of the action space.
low_action (float) – Lower bound of the action space.
network (nn.Module, optional) – Custom neural network to represent the policy. If None, a default network is used. Defaults to None.
use_xavier (bool, optional) – Whether to use Xavier initialization for weights. Defaults to True.
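The parameter list does not say how `high_action` and `low_action` are applied to the network output. For Tanh-bounded policies, a common convention is an affine rescale from [-1, 1] onto [low_action, high_action]. The sketch below shows that convention. It is an assumption, not confirmed modularl behavior, and `scale_action` is a hypothetical helper. For multi-dimensional actions it would be applied element-wise.

```python
import math


def scale_action(squashed, low_action, high_action):
    """Hypothetical helper (assumed convention): map a value in [-1, 1],
    e.g. a tanh output, affinely onto [low_action, high_action]."""
    return low_action + (squashed + 1.0) * 0.5 * (high_action - low_action)


# A raw pre-activation squashed by tanh, then rescaled to the range [-2, 2]
raw = 0.5
action = scale_action(math.tanh(raw), low_action=-2.0, high_action=2.0)
print(action)
```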
Note
If no custom network is provided, a default network is created with three linear layers and ReLU activations. The output layer uses a Tanh activation to bound the actions. If a custom network is provided, it must be headless: this class adds a final linear layer on top of the provided network to produce the policy output. The input size of that layer equals the output features of the last layer in the provided network.
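The architecture described in the note can be illustrated without PyTorch. The pure-Python stand-in below is not modularl's implementation. The `Linear` class, `build_policy`, the hidden size of 64, the Xavier *uniform* variant, the fan-in fallback and the zero biases are all assumptions. The sketch shows three things: a headless body, a head whose input size is read from the last `Linear` layer of the body, and a final Tanh that keeps every action component in [-1, 1].

```python
import math
import random


class Linear:
    """Minimal fully connected layer y = W x + b (illustrative stand-in for nn.Linear)."""

    def __init__(self, in_features, out_features, use_xavier=True, rng=None):
        rng = rng or random.Random(0)
        self.in_features = in_features
        self.out_features = out_features
        if use_xavier:
            # Xavier/Glorot uniform bound (assumed variant; the docs only say "Xavier")
            bound = math.sqrt(6.0 / (in_features + out_features))
        else:
            # Simple fan-in bound used here as a stand-in for a non-Xavier default
            bound = 1.0 / math.sqrt(in_features)
        self.weight = [[rng.uniform(-bound, bound) for _ in range(in_features)]
                       for _ in range(out_features)]
        self.bias = [0.0] * out_features

    def __call__(self, x):
        return [sum(w * xi for w, xi in zip(row, x)) + b
                for row, b in zip(self.weight, self.bias)]


def relu(x):
    return [max(0.0, v) for v in x]


def tanh(x):
    return [math.tanh(v) for v in x]


def build_policy(observation_shape, action_shape, body=None, hidden=64,
                 use_xavier=True, seed=0):
    """Hypothetical helper: headless body + linear head + Tanh."""
    rng = random.Random(seed)
    if body is None:
        # Default body: two Linear+ReLU blocks; with the head, three Linear layers total
        body = [Linear(observation_shape, hidden, use_xavier, rng), relu,
                Linear(hidden, hidden, use_xavier, rng), relu]
    # The head's input size is the out_features of the last Linear in the body
    last_linear = [layer for layer in body if isinstance(layer, Linear)][-1]
    head = Linear(last_linear.out_features, action_shape, use_xavier, rng)
    layers = list(body) + [head, tanh]

    def forward(observation):
        x = observation
        for layer in layers:
            x = layer(x)
        return x

    return forward, head


# Default body: the head reads 64 input features (the assumed hidden size)
policy, head = build_policy(observation_shape=10, action_shape=2)
action = policy([0.1] * 10)
print(head.in_features, action)

# Custom headless body ending in 128 features, like CustomNetwork in the example
custom_body = [Linear(10, 128), relu, Linear(128, 128), relu]
custom_policy, custom_head = build_policy(10, 2, body=custom_body)
print(custom_head.in_features)
```

Because the forward pass involves no sampling, the same observation always yields the same action, which is what makes the policy deterministic.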
Example Usage
Here’s an example of how to use the DeterministicPolicy:
```python
import torch
import torch.nn as nn
from modularl.policies.deterministic_policy import DeterministicPolicy


# Define a headless custom network (no output layer; the policy adds it)
class CustomNetwork(nn.Module):
    def __init__(self, observation_shape):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(observation_shape, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.network(x)


# Observation and action space dimensions
observation_shape = 10
action_shape = 2
high_action = 1.0
low_action = -1.0

# Custom network instance
custom_network = CustomNetwork(observation_shape)

# Create a DeterministicPolicy with the custom network
policy = DeterministicPolicy(
    observation_shape=observation_shape,
    action_shape=action_shape,
    high_action=high_action,
    low_action=low_action,
    network=custom_network,
)

# Or without a custom network (the default network is used)
default_policy = DeterministicPolicy(
    observation_shape=observation_shape,
    action_shape=action_shape,
    high_action=high_action,
    low_action=low_action,
)

# Example observation (batch of one)
observation = torch.randn((1, observation_shape))

# Get an action from the policy
action = policy.get_action(observation)
print("Action:", action)
```