State-Action Q-Function
State-Action Q-Function Class
- class modularl.q_functions.SAQNetwork(observation_shape: int, action_shape: int, network: Module | None = None, use_xavier=True)
Bases: StateActionQFunction

Initializes a fully-connected Q-function network that takes a state-action pair (s, a) as input and outputs its Q-value.
- Parameters:
observation_shape (int) – The dimensionality (number of features) of the observation input.
action_shape (int) – The dimensionality (number of features) of the action input.
network (nn.Module, optional) – Custom headless network body for the Q-function (see the Note below). If None, a default network is used. Defaults to None.
use_xavier (bool, optional) – Whether to use Xavier initialization for the network weights. Defaults to True.
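To make the constructor parameters concrete, here is a minimal sketch of a fully-connected (s, a)-input Q-function in plain PyTorch. It is a hypothetical stand-in (the class name `SketchSAQNetwork` is invented), not modularl's source: the hidden width of 256, concatenating observation and action along the feature dimension, and the choice of Xavier-uniform weights with zero biases for `use_xavier=True` are all assumptions.

```python
import torch
import torch.nn as nn


class SketchSAQNetwork(nn.Module):
    # Hypothetical stand-in, not modularl's implementation. Hidden width 256
    # and Xavier-uniform weights with zero biases are assumptions.
    def __init__(self, observation_shape: int, action_shape: int,
                 hidden: int = 256, use_xavier: bool = True):
        super().__init__()
        # The first layer sees the concatenated (s, a) feature vector.
        self.network = nn.Sequential(
            nn.Linear(observation_shape + action_shape, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),  # one scalar Q-value per (s, a) pair
        )
        if use_xavier:
            for layer in self.network:
                if isinstance(layer, nn.Linear):
                    nn.init.xavier_uniform_(layer.weight)
                    nn.init.zeros_(layer.bias)

    def forward(self, observation: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        # (B, obs_dim) and (B, act_dim) -> (B, obs_dim + act_dim)
        x = torch.cat([observation, actions], dim=-1)
        return self.network(x)  # shape (B, 1)


q = SketchSAQNetwork(observation_shape=10, action_shape=2)
out = q(torch.randn(4, 10), torch.randn(4, 2))
print(out.shape)  # torch.Size([4, 1])
```

Because observation and action are concatenated, the first layer's input width is `observation_shape + action_shape` (12 here), and the output has one Q-value per batch row.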
Note
If no custom network is provided, a default network is created with three linear layers and ReLU activations; its final linear layer outputs a single Q-value for each (observation, action) pair. If a custom network is provided, it should be headless: this class adds an additional linear layer on top of the provided network to produce the Q-value, with input size equal to the output features of the last layer in the provided network.
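The "headless body plus appended head" arrangement can be sketched as follows. The wrapper `HeadedQ` is hypothetical and only illustrates the idea: it sizes a one-output linear head from the `out_features` of the body's last `nn.Linear`. That the body receives the concatenated (s, a) vector is an assumption of this sketch.

```python
import torch
import torch.nn as nn


class HeadedQ(nn.Module):
    # Hypothetical wrapper illustrating a headless body with an appended Q-value head.
    def __init__(self, body: nn.Module):
        super().__init__()
        self.body = body
        # Size the head from the out_features of the body's last nn.Linear.
        last_linear = [m for m in body.modules() if isinstance(m, nn.Linear)][-1]
        self.head = nn.Linear(last_linear.out_features, 1)

    def forward(self, observation: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        # Assumption: the body consumes the concatenated (s, a) features.
        x = torch.cat([observation, actions], dim=-1)
        return self.head(self.body(x))  # shape (B, 1)


obs_dim, act_dim = 10, 2
# Headless body: it stops at a hidden representation, with no scalar output layer.
body = nn.Sequential(
    nn.Linear(obs_dim + act_dim, 128), nn.ReLU(),
    nn.Linear(128, 64), nn.ReLU(),
)
q_fn = HeadedQ(body)
q_values = q_fn(torch.randn(3, obs_dim), torch.randn(3, act_dim))
print(q_values.shape)  # torch.Size([3, 1])
```

With the documented constructor, the analogous call would presumably be `SAQNetwork(obs_dim, act_dim, network=body)`. Whether SAQNetwork feeds the body the concatenated input exactly as this sketch does should be confirmed against its source.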
Example Usage
Here is an example of how to create and use an instance of the SAQNetwork class:
```python
import torch
from modularl.q_functions import SAQNetwork

# Define the observation and action shapes
observation_shape = 10
action_shape = 2

# Initialize the Q-network with default settings
q_network = SAQNetwork(observation_shape, action_shape)

# Example observation and action tensors (batch size 1)
observation = torch.randn(1, observation_shape)
actions = torch.randn(1, action_shape)

# Compute the Q-value
q_value = q_network(observation, actions)
print(f"Q-value: {q_value.item()}")
```
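A state-action Q-function can also score several candidate actions for one observation in a single batched call. The helper below, `best_candidate_action`, is a hypothetical sketch: it accepts any callable mapping `(B, obs_dim)` and `(B, act_dim)` tensors to `(B, 1)` Q-values, and the toy Q-function is invented purely so the example runs without modularl.

```python
import torch


def best_candidate_action(q_fn, observation: torch.Tensor, candidates: torch.Tensor) -> torch.Tensor:
    """Return the candidate action with the highest Q-value for one observation.

    observation: shape (obs_dim,); candidates: shape (N, act_dim).
    q_fn: any callable mapping (N, obs_dim), (N, act_dim) -> (N, 1).
    """
    n = candidates.shape[0]
    # Repeat the observation so each row pairs with one candidate action.
    obs_batch = observation.unsqueeze(0).expand(n, -1)
    with torch.no_grad():
        q_values = q_fn(obs_batch, candidates).squeeze(-1)  # shape (N,)
    return candidates[q_values.argmax()]


def toy_q(obs: torch.Tensor, act: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in Q-function: prefers actions close to zero.
    return -(act ** 2).sum(dim=-1, keepdim=True)


cands = torch.tensor([[1.0, 1.0], [0.1, -0.2], [-2.0, 0.5]])
chosen = best_candidate_action(toy_q, torch.randn(10), cands)
print(chosen)
```

With the network from the example above, `best_candidate_action(q_network, observation[0], torch.randn(64, action_shape))` would score 64 random candidate actions, assuming `q_network` returns one Q-value per row.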