Source code for modularl.agents.td3

import time
import torch
import torch.nn.functional as F
import copy
from tensordict import TensorDict
from modularl.agents.agent import AbstractAgent
from torchrl.data import TensorDictReplayBuffer
from typing import Callable, Optional, Any
from torch.utils.tensorboard import SummaryWriter


[docs] class TD3(AbstractAgent): """ Twin Delayed Deep Deterministic Policy Gradient (TD3) Agent :param actor: The actor network. :type actor: torch.nn.Module :param qf1: The first Q-function network. :type qf1: torch.nn.Module :param qf2: The second Q-function network. :type qf2: torch.nn.Module :param actor_optimizer: Optimizer for the actor network. :type actor_optimizer: torch.optim.Optimizer :param qf_optimizer: Optimizer for the Q-function networks. :type qf_optimizer: torch.optim.Optimizer :param replay_buffer: Replay buffer for storing experiences. :type replay_buffer: TensorDictReplayBuffer :param gamma: Discount factor for future rewards. Defaults to 0.99. :type gamma: float, optional :param batch_size: Number of samples per batch for training. Defaults to 32. :type batch_size: int, optional :param learning_starts: Number of steps before learning starts. Defaults to 0. :type learning_starts: int, optional :param tau: Soft update coefficient for target networks. Defaults to 0.005. :type tau: float, optional :param exploration_noise: Noise added to the actor policy during training. Defaults to 0.1. :type exploration_noise: float, optional :param policy_noise: Noise added to the target policy during critic updates. Defaults to 0.2. :type policy_noise: float, optional :param noise_clip: Range to clip the target policy noise. Defaults to 0.5. :type noise_clip: float, optional :param policy_frequency: Frequency of delayed policy updates. Defaults to 2. :type policy_frequency: int, optional :param device: Device to run the agent on (e.g., "cpu" or "cuda"). Defaults to "cpu". :type device: str, optional :param burning_action_func: Function for generating initial exploratory actions. Defaults to None. :type burning_action_func: Callable, optional :param writer: Tensorboard writer for logging. Defaults to None. :type writer: SummaryWriter, optional """ # noqa: E501 def __init__( self, actor: torch.nn.Module, qf1: torch.nn.Module, qf2: torch.nn.Module, actor_optimizer: torch.optim.Optimizer, qf_optimizer: torch.optim.Optimizer, replay_buffer: TensorDictReplayBuffer, gamma: float = 0.99, batch_size: int = 32, learning_starts: int = 0, tau: float = 0.005, exploration_noise: float = 0.1, policy_noise: float = 0.2, noise_clip: float = 0.5, policy_frequency: int = 2, device: str = "cpu", burning_action_func: Optional[Callable] = None, writer: Optional[SummaryWriter] = None, **kwargs: Any ) -> None: super().__init__(**kwargs) self.device = device self.writer = writer self.rb = replay_buffer self.batch_size = batch_size self.tau = tau self.burning_action_func = burning_action_func self.learning_starts = learning_starts self.gamma = gamma self.policy_noise = policy_noise self.noise_clip = noise_clip self.exploration_noise = exploration_noise self.policy_frequency = policy_frequency # Networks self.actor = actor.to(self.device) self.qf1 = qf1.to(self.device) self.qf2 = qf2.to(self.device) self.target_actor = copy.deepcopy(self.actor).to(self.device) self.qf1_target = copy.deepcopy(self.qf1).to(self.device) self.qf2_target = copy.deepcopy(self.qf2).to(self.device) self.actor_optimizer = actor_optimizer self.qf_optimizer = qf_optimizer
[docs] def init(self) -> None: self.start_time = time.time() self.global_step = 0
[docs] def observe( self, batch_obs: torch.Tensor, batch_actions: torch.Tensor, batch_rewards: torch.Tensor, batch_next_obs: torch.Tensor, batch_dones: torch.Tensor, ) -> None: self.global_step += 1 batch_transition = TensorDict( { "observations": batch_obs.clone(), "next_observations": batch_next_obs.clone(), "actions": batch_actions.clone(), "rewards": batch_rewards.clone(), "dones": batch_dones.clone(), }, batch_size=[batch_obs.shape[0]], ) self.rb.extend(batch_transition) self.update()
[docs] def act_train(self, batch_obs: torch.Tensor) -> torch.Tensor: with torch.no_grad(): if ( self.global_step < self.learning_starts and self.burning_action_func is not None ): return self.burning_action_func(batch_obs).to(self.device) else: actions = self.actor.get_action(batch_obs.to(self.device)) actions = actions + torch.normal( 0, self.actor.action_scale * self.exploration_noise, size=actions.shape, device=self.device, ) actions = actions.clamp( self.actor.low_action, self.actor.high_action ) return actions
[docs] def act_eval(self, batch_obs: torch.Tensor) -> torch.Tensor: self.actor.eval() with torch.no_grad(): actions = self.actor.get_action(batch_obs.to(self.device)) self.actor.train() return actions
[docs] def update(self) -> None: if self.global_step > self.learning_starts: data = self.rb.sample(self.batch_size).to(self.device) with torch.no_grad(): if self.gamma != 0: clipped_noise = ( torch.randn_like(data["actions"], device=self.device) * self.policy_noise ).clamp( -self.noise_clip, self.noise_clip ) * self.target_actor.action_scale next_state_actions = ( self.target_actor(data["next_observations"]) + clipped_noise ).clamp(self.actor.low_action, self.actor.high_action) qf1_next_target = self.qf1_target( data["next_observations"], next_state_actions ) qf2_next_target = self.qf2_target( data["next_observations"], next_state_actions ) min_qf_next_target = torch.min( qf1_next_target, qf2_next_target ) next_q_value = data["rewards"].flatten() + ( 1 - data["dones"].flatten() ) * self.gamma * min_qf_next_target.view(-1) else: next_q_value = data["rewards"].flatten() qf1_a_values = self.qf1( data["observations"], data["actions"] ).view(-1) qf2_a_values = self.qf2( data["observations"], data["actions"] ).view(-1) qf1_loss = F.mse_loss(qf1_a_values, next_q_value) qf2_loss = F.mse_loss(qf2_a_values, next_q_value) qf_loss = qf1_loss + qf2_loss self.qf_optimizer.zero_grad() qf_loss.backward() self.qf_optimizer.step() if self.global_step % self.policy_frequency == 0: actor_loss = -self.qf1( data["observations"], self.actor.get_action(data["observations"]), ).mean() self.actor_optimizer.zero_grad() actor_loss.backward() self.actor_optimizer.step() if self.gamma != 0: # Update target networks for param, target_param in zip( self.actor.parameters(), self.target_actor.parameters() ): target_param.data.copy_( self.tau * param.data + (1 - self.tau) * target_param.data ) for param, target_param in zip( self.qf1.parameters(), self.qf1_target.parameters() ): target_param.data.copy_( self.tau * param.data + (1 - self.tau) * target_param.data ) for param, target_param in zip( self.qf2.parameters(), self.qf2_target.parameters() ): target_param.data.copy_( self.tau * param.data + (1 - self.tau) * target_param.data ) if self.global_step % 100 == 0 and self.writer is not None: self.writer.add_scalar( "losses/qf1_values", qf1_a_values.mean().item(), self.global_step, ) self.writer.add_scalar( "losses/qf2_values", qf2_a_values.mean().item(), self.global_step, ) self.writer.add_scalar( "losses/qf1_loss", qf1_loss.item(), self.global_step ) self.writer.add_scalar( "losses/qf2_loss", qf2_loss.item(), self.global_step ) self.writer.add_scalar( "losses/qf_loss", qf_loss.item() / 2.0, self.global_step ) self.writer.add_scalar( "losses/actor_loss", actor_loss.item(), self.global_step ) self.writer.add_scalar( "charts/SPS", int(self.global_step / (time.time() - self.start_time)), self.global_step, )