import time
import torch
import torch.nn.functional as F
import torch.optim as optim
import copy
from tensordict import TensorDict
from modularl.agents.agent import AbstractAgent
from torchrl.data import TensorDictReplayBuffer
from typing import Callable, Optional, Any
from torch.utils.tensorboard import SummaryWriter
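# Soft Actor-Critic agent (the "[docs]" link text left over from the rendered documentation page was not code).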
class SAC(AbstractAgent):
    """
    Soft Actor-Critic (SAC) agent.

    The agent owns one actor, two Q-functions and a deep-copied target
    network for each Q-function. Every call to :meth:`observe` appends the
    given transitions to the replay buffer and then calls :meth:`update`.

    Interfaces this class relies on (as they are called in the code):

    - ``actor.get_action(obs)`` returns a 3-tuple; the first item is used as
      the action and the second as its log-probability, the third is ignored.
    - ``qf(obs, actions=...)`` returns the Q-value estimate for the pair.

    :meth:`init` must be called before :meth:`observe`, :meth:`act_train` or
    :meth:`update`, because it creates ``global_step`` and ``start_time``.

    :param actor: The actor network (policy) to be used.
    :type actor: torch.nn.Module
    :param qf1: The first Q-function network.
    :type qf1: torch.nn.Module
    :param qf2: The second Q-function network.
    :type qf2: torch.nn.Module
    :param actor_optimizer: Optimizer for the actor network.
    :type actor_optimizer: torch.optim.Optimizer
    :param qf_optimizer: Optimizer for both Q-function networks.
    :type qf_optimizer: torch.optim.Optimizer
    :param replay_buffer: Replay buffer for storing experiences.
    :type replay_buffer: TensorDictReplayBuffer
    :param gamma: Discount factor for future rewards. With ``gamma == 0`` the
        critic target is the reward alone and the target networks are never
        updated. Defaults to 0.99.
    :type gamma: float, optional
    :param entropy_lr: Learning rate for the entropy temperature. Defaults to 1e-3.
    :type entropy_lr: float, optional
    :param batch_size: Number of samples per batch for training. Defaults to 32.
    :type batch_size: int, optional
    :param learning_starts: Number of steps before learning starts. Defaults to 0.
    :type learning_starts: int, optional
    :param entropy_temperature: Entropy temperature used when it is not tuned
        automatically. Defaults to 0.2.
    :type entropy_temperature: float, optional
    :param target_entropy: Target entropy for adaptive temperature adjustment.
        When not ``None`` the temperature is tuned automatically and starts at
        ``exp(0) = 1.0``, replacing ``entropy_temperature``. Defaults to None.
    :type target_entropy: float, optional
    :param tau: Soft update coefficient for target networks. Defaults to 0.005.
    :type tau: float, optional
    :param policy_frequency: Frequency of policy updates. Defaults to 1.
    :type policy_frequency: int, optional
    :param target_network_frequency: Frequency of target network updates. Defaults to 2.
    :type target_network_frequency: int, optional
    :param device: Device to run the agent on (e.g., "cpu" or "cuda"). Defaults to "cpu".
    :type device: str, optional
    :param burning_action_func: Function for generating initial exploratory actions. Defaults to None.
    :type burning_action_func: Callable, optional
    :param writer: Tensorboard writer for logging. Defaults to None.
    :type writer: SummaryWriter, optional
    """  # noqa: E501

    def __init__(
        self,
        actor: torch.nn.Module,
        qf1: torch.nn.Module,
        qf2: torch.nn.Module,
        actor_optimizer: torch.optim.Optimizer,
        qf_optimizer: torch.optim.Optimizer,
        replay_buffer: TensorDictReplayBuffer,
        gamma: float = 0.99,
        entropy_lr: float = 1e-3,
        batch_size: int = 32,
        learning_starts: int = 0,
        entropy_temperature: float = 0.2,
        target_entropy: Optional[float] = None,
        tau: float = 0.005,
        policy_frequency: int = 1,
        target_network_frequency: int = 2,
        device: str = "cpu",
        burning_action_func: Optional[Callable] = None,
        writer: Optional[SummaryWriter] = None,
        **kwargs: Any
    ) -> None:
        super().__init__(**kwargs)
        self.device = device
        self.writer = writer
        self.rb = replay_buffer
        self.batch_size = batch_size
        self.tau = tau
        self.burning_action_func = burning_action_func
        self.learning_starts = learning_starts
        self.gamma = gamma
        self.entropy_lr = entropy_lr
        self.policy_frequency = policy_frequency
        self.target_network_frequency = target_network_frequency

        # networks
        self.actor = actor.to(self.device)
        self.qf1 = qf1.to(self.device)
        self.qf2 = qf2.to(self.device)
        # deepcopy already duplicates the weights and their device placement,
        # so no extra load_state_dict()/to() is needed for the targets.
        self.qf1_target = copy.deepcopy(self.qf1)
        self.qf2_target = copy.deepcopy(self.qf2)
        self.actor_optimizer = actor_optimizer
        self.qf_optimizer = qf_optimizer

        # entropy temperature: fixed unless a target entropy is supplied
        self.alpha = entropy_temperature
        self.target_entropy = target_entropy
        self.auto_tune_temp = target_entropy is not None
        if self.auto_tune_temp:
            # log(alpha) is the learnable quantity; it starts at 0, so the
            # initial alpha is 1.0 regardless of entropy_temperature.
            self.log_alpha = torch.zeros(
                1, requires_grad=True, device=self.device
            )
            self.alpha = self.log_alpha.exp().item()
            self.a_optimizer = optim.Adam(
                [self.log_alpha], lr=self.entropy_lr
            )

    def init(self) -> None:
        """Start the wall-clock timer and reset the global step counter."""
        self.start_time = time.time()
        self.global_step = 0

    def observe(
        self,
        batch_obs: torch.Tensor,
        batch_actions: torch.Tensor,
        batch_rewards: torch.Tensor,
        batch_next_obs: torch.Tensor,
        batch_dones: torch.Tensor,
    ) -> None:
        """
        Store a batch of transitions and run one learning step.

        Increments ``global_step`` by one, appends clones of the given tensors
        to the replay buffer and calls :meth:`update`.

        :param batch_obs: Observations; the size of its first dimension is
            used as the batch size of the stored transitions.
        :param batch_actions: Actions taken for ``batch_obs``.
        :param batch_rewards: Rewards received.
        :param batch_next_obs: Observations that followed the actions.
        :param batch_dones: Episode-termination flags.
        """
        self.global_step += 1
        batch_transition = TensorDict(
            {
                "observations": batch_obs.clone(),
                "next_observations": batch_next_obs.clone(),
                "actions": batch_actions.clone(),
                "rewards": batch_rewards.clone(),
                "dones": batch_dones.clone(),
            },
            batch_size=[batch_obs.shape[0]],
        )
        self.rb.extend(batch_transition)
        self.update()

    def act_train(self, batch_obs: torch.Tensor) -> torch.Tensor:
        """
        Generate actions for training based on the current policy.

        It uses a burning action function for initial exploration if specified,
        then switches to the learned policy.

        :param batch_obs: (torch.Tensor) A batch of observations from the environment.
        :return: (torch.Tensor) A batch of actions to be taken in the environment.

        Notes:
            - If the global step is less than `learning_starts` and a burning action
              function is provided, it uses that function for exploration.
            - Otherwise, it uses the current policy (actor) to generate actions.
              The observations are moved to the agent's device first and no
              autograd graph is recorded.
        """  # noqa: E501
        use_burn_in = (
            self.global_step < self.learning_starts
            and self.burning_action_func is not None
        )
        if use_burn_in:
            return self.burning_action_func(batch_obs).to(self.device)

        # Acting never needs gradients; no_grad avoids building a graph that
        # would only be thrown away by detach().
        with torch.no_grad():
            actions, _, _ = self.actor.get_action(batch_obs.to(self.device))
        return actions.detach()

    def act_eval(self, batch_obs: torch.Tensor) -> torch.Tensor:
        """
        Generate actions with the networks in evaluation mode.

        The actor and both Q-functions are switched to ``eval()`` for the
        duration of the call and are restored to whatever mode they were in
        before, even if the actor raises. Parameter ``requires_grad`` flags
        are left untouched; gradient tracking is disabled with
        ``torch.no_grad()`` instead.

        :param batch_obs: (torch.Tensor) A batch of observations.
        :return: (torch.Tensor) A batch of actions from the actor.
        """
        modules = (self.qf1, self.qf2, self.actor)
        previous_modes = [module.training for module in modules]
        for module in modules:
            module.eval()
        try:
            with torch.no_grad():
                actions, _, _ = self.actor.get_action(
                    batch_obs.to(self.device)
                )
        finally:
            for module, was_training in zip(modules, previous_modes):
                module.train(was_training)
        return actions

    def update(self) -> None:
        """
        Run one SAC learning step once ``global_step > learning_starts``.

        The step samples ``batch_size`` transitions, updates both critics,
        updates the actor (and the temperature when auto-tuning) on steps
        divisible by ``policy_frequency``, soft-updates the target networks
        on steps divisible by ``target_network_frequency`` (skipped when
        ``gamma == 0``), and writes scalars to ``writer`` every 100 steps.
        """
        if self.global_step <= self.learning_starts:
            return

        data = self.rb.sample(self.batch_size).to(self.device)
        critic_stats = self._update_critics(data)

        # These stay None on steps where the policy is not updated, so that
        # logging can skip them instead of reading unbound names.
        actor_loss = None
        alpha_loss = None
        if (
            self.global_step % self.policy_frequency == 0
        ):  # TD 3 Delayed update support
            actor_loss, alpha_loss = self._update_actor_and_alpha(data)

        # update the target networks
        if (
            self.gamma != 0
            and self.global_step % self.target_network_frequency == 0
        ):
            self._soft_update(self.qf1, self.qf1_target)
            self._soft_update(self.qf2, self.qf2_target)

        if self.global_step % 100 == 0 and self.writer is not None:
            self._log_metrics(critic_stats, actor_loss, alpha_loss)

    def _compute_td_target(self, data: TensorDict) -> torch.Tensor:
        """Return the flattened critic regression target for ``data``."""
        with torch.no_grad():
            rewards = data["rewards"].flatten()
            if self.gamma == 0:
                return rewards

            next_obs = data["next_observations"]
            next_actions, next_log_pi, _ = self.actor.get_action(next_obs)
            qf1_next = self.qf1_target(next_obs, actions=next_actions)
            qf2_next = self.qf2_target(next_obs, actions=next_actions)
            # NOTE(review): assumes the Q outputs and next_log_pi have
            # matching shapes; a mismatch would broadcast silently here.
            soft_value = (
                torch.min(qf1_next, qf2_next) - self.alpha * next_log_pi
            )
            not_done = 1 - data["dones"].to(torch.float32).flatten()
            return rewards + not_done * self.gamma * soft_value.view(-1)

    def _update_critics(self, data: TensorDict) -> dict:
        """Take one optimizer step on both Q-functions; return log tensors."""
        target = self._compute_td_target(data)
        qf1_values = self.qf1(
            data["observations"], actions=data["actions"]
        ).view(-1)
        qf2_values = self.qf2(
            data["observations"], actions=data["actions"]
        ).view(-1)
        qf1_loss = F.mse_loss(qf1_values, target)
        qf2_loss = F.mse_loss(qf2_values, target)
        qf_loss = qf1_loss + qf2_loss

        # optimize the model
        self.qf_optimizer.zero_grad()
        qf_loss.backward()
        self.qf_optimizer.step()

        # Tensors are returned as-is; .item() is only called when logging.
        return {
            "qf1_values": qf1_values,
            "qf2_values": qf2_values,
            "qf1_loss": qf1_loss,
            "qf2_loss": qf2_loss,
            "qf_loss": qf_loss,
        }

    def _update_actor_and_alpha(self, data: TensorDict) -> tuple:
        """Update the actor (and temperature); return the last two losses."""
        observations = data["observations"]
        actor_loss = None
        alpha_loss = None
        # compensate for the delayed update by doing `policy_frequency`
        # actor updates instead of 1
        for _ in range(self.policy_frequency):
            pi, log_pi, _ = self.actor.get_action(observations)
            qf1_pi = self.qf1(observations, actions=pi)
            qf2_pi = self.qf2(observations, actions=pi)
            min_qf_pi = torch.min(qf1_pi, qf2_pi)
            actor_loss = ((self.alpha * log_pi) - min_qf_pi).mean()

            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            if self.auto_tune_temp:
                # log_pi is detached so only log_alpha receives a gradient.
                alpha_loss = (
                    -self.log_alpha.exp()
                    * (log_pi + self.target_entropy).detach()
                ).mean()
                self.a_optimizer.zero_grad()
                alpha_loss.backward()
                self.a_optimizer.step()
                self.alpha = self.log_alpha.exp().item()
        return actor_loss, alpha_loss

    def _soft_update(
        self, source: torch.nn.Module, target: torch.nn.Module
    ) -> None:
        """Blend ``source`` parameters into ``target`` with weight ``tau``."""
        for param, target_param in zip(
            source.parameters(), target.parameters()
        ):
            target_param.data.copy_(
                self.tau * param.data + (1 - self.tau) * target_param.data
            )

    def _log_metrics(
        self,
        critic_stats: dict,
        actor_loss: Optional[torch.Tensor],
        alpha_loss: Optional[torch.Tensor],
    ) -> None:
        """Write the scalars of the current step to ``self.writer``."""
        step = self.global_step
        add_scalar = self.writer.add_scalar

        add_scalar(
            "losses/qf1_values",
            critic_stats["qf1_values"].mean().item(),
            step,
        )
        add_scalar(
            "losses/qf2_values",
            critic_stats["qf2_values"].mean().item(),
            step,
        )
        add_scalar("losses/qf1_loss", critic_stats["qf1_loss"].item(), step)
        add_scalar("losses/qf2_loss", critic_stats["qf2_loss"].item(), step)
        add_scalar(
            "losses/qf_loss", critic_stats["qf_loss"].item() / 2.0, step
        )
        # Absent on logging steps that were not policy-update steps.
        if actor_loss is not None:
            add_scalar("losses/actor_loss", actor_loss.item(), step)
        add_scalar("losses/alpha", self.alpha, step)
        # Guard against a zero elapsed time on coarse clocks.
        elapsed = max(time.time() - self.start_time, 1e-8)
        add_scalar("charts/SPS", int(step / elapsed), step)
        if alpha_loss is not None:
            add_scalar("losses/alpha_loss", alpha_loss.item(), step)