Navigating Soft Actor-Critic Reinforcement Learning

By Mohammed AbuSadeh · December 2024



The code implemented in this article is taken from the following GitHub repository (quantumiracle, 2023). The required dependencies can be installed with:

pip install gymnasium torch

SAC relies on environments with continuous action spaces, so the simulation mostly uses the robotic-arm ‘Reacher’ environment, together with the Pendulum-v1 environment from the gymnasium package.
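As a quick sanity check (a minimal sketch, assuming gymnasium is installed as shown above), you can confirm that Pendulum-v1 exposes a continuous Box action space, which is exactly what SAC needs:

import gymnasium as gym

# Pendulum-v1: a single continuous torque action in [-2, 2]
env = gym.make("Pendulum-v1")
print(env.observation_space)  # Box(3,): cos(theta), sin(theta), angular velocity
print(env.action_space)       # Box(1,): continuous torque, as SAC requires

state, info = env.reset(seed=0)
action = env.action_space.sample()  # random continuous action
next_state, reward, terminated, truncated, info = env.step(action)
env.close()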

The Pendulum environment was run on a different repository that implements the same algorithm but relies on fewer deprecated libraries (MrSyee, 2020):

In terms of the network architectures, as mentioned in the Theory Explanation, there are three main components:

Policy Network: implements a Gaussian actor network that computes the mean and log standard deviation of the action distribution.

import torch
import torch.nn as nn
import torch.nn.functional as F

class PolicyNetwork(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.mean = nn.Linear(hidden_dim, action_dim)
        self.log_std = nn.Linear(hidden_dim, action_dim)

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        mean = self.mean(x)
        log_std = torch.clamp(self.log_std(x), -20, 2)  # limit log_std to prevent instability
        return mean, log_std
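The update step shown later calls policy_net.evaluate(state) to draw actions from this Gaussian and obtain their log-probabilities. The exact implementation lives in the referenced repository; a minimal sketch of such a method (the five-value return signature is an assumption chosen to match how it is unpacked below) could look like this:

# Inside PolicyNetwork (requires: from torch.distributions import Normal)
    def evaluate(self, state, epsilon=1e-6):
        """Sample a squashed Gaussian action and its log-probability."""
        mean, log_std = self.forward(state)
        std = log_std.exp()

        normal = Normal(mean, std)
        z = normal.rsample()    # reparameterization trick: differentiable sample
        action = torch.tanh(z)  # squash the action into [-1, 1]

        # change-of-variables correction for the tanh squashing
        log_prob = normal.log_prob(z) - torch.log(1 - action.pow(2) + epsilon)
        log_prob = log_prob.sum(dim=-1, keepdim=True)

        return action, log_prob, z, mean, log_std

The tanh squashing keeps actions bounded, and the log-probability correction accounts for that change of variables so the entropy term in the losses stays correct.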

Soft Q-Network: estimates the expected future reward of a state–action pair under the current policy.

class SoftQNetwork(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim):
        super(SoftQNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, 1)

    def forward(self, state, action):
        x = torch.cat([state, action], dim=-1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.out(x)

Value Network: estimates the state value.

class ValueNetwork(nn.Module):
    def __init__(self, state_dim, hidden_dim):
        super(ValueNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, 1)

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        return self.out(x)
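The update routine below assumes that the three networks, a target copy of the value network, their optimizers, an entropy temperature alpha and a replay buffer already exist. A minimal setup sketch follows; the hyperparameter values, the fixed alpha and the simple ReplayBuffer class are illustrative assumptions, not the repository's exact code:

import copy
import random
from collections import deque

import numpy as np
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

state_dim, action_dim, hidden_dim = 3, 1, 256  # e.g. Pendulum-v1; sizes are illustrative
alpha = 0.2                                    # entropy temperature, assumed fixed here

policy_net = PolicyNetwork(state_dim, action_dim, hidden_dim).to(device)
soft_q_net1 = SoftQNetwork(state_dim, action_dim, hidden_dim).to(device)
soft_q_net2 = SoftQNetwork(state_dim, action_dim, hidden_dim).to(device)
value_net = ValueNetwork(state_dim, hidden_dim).to(device)
target_value_net = copy.deepcopy(value_net)    # target network starts as an exact copy

policy_optimizer = optim.Adam(policy_net.parameters(), lr=3e-4)
soft_q_optimizer1 = optim.Adam(soft_q_net1.parameters(), lr=3e-4)
soft_q_optimizer2 = optim.Adam(soft_q_net2.parameters(), lr=3e-4)
value_optimizer = optim.Adam(value_net.parameters(), lr=3e-4)

class ReplayBuffer:
    """Fixed-size buffer of (state, action, reward, next_state, done) tuples."""
    def __init__(self, capacity=100_000):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        # store done as a float so it can be used directly in the Bellman target
        self.buffer.append((state, action, reward, next_state, float(done)))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = map(np.stack, zip(*batch))
        return state, action, reward, next_state, done

    def __len__(self):
        return len(self.buffer)

replay_buffer = ReplayBuffer()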

The following snippet shows the key steps for updating the different networks in the SAC algorithm. It starts by sampling a batch from the replay buffer for experience replay. Before the gradients of each loss are computed, they are zeroed so that gradients from previous batches do not accumulate; backpropagation is then performed and the corresponding network's weights are updated. The targets and losses are computed first for the two Q-networks, then for the value network and the policy network, and finally the target value network is softly updated.

def update(batch_size, reward_scale, gamma=0.99, soft_tau=1e-2):
    # Sample a batch from the replay buffer (experience replay)
    state, action, reward, next_state, done = replay_buffer.sample(batch_size)
    state, next_state, action, reward, done = map(lambda x: torch.FloatTensor(x).to(device),
                                                  [state, next_state, action, reward, done])
    reward = reward_scale * reward.unsqueeze(1)  # scale rewards and match the critics' output shape
    done = done.unsqueeze(1)

    # Sample fresh actions and their log-probabilities from the current policy
    new_action, log_prob, _, _, _ = policy_net.evaluate(state)

    # Update Q-networks
    target_value = target_value_net(next_state)
    target_q = reward + (1 - done) * gamma * target_value
    q1_loss = F.mse_loss(soft_q_net1(state, action), target_q.detach())
    q2_loss = F.mse_loss(soft_q_net2(state, action), target_q.detach())

    soft_q_optimizer1.zero_grad()
    q1_loss.backward()
    soft_q_optimizer1.step()

    soft_q_optimizer2.zero_grad()
    q2_loss.backward()
    soft_q_optimizer2.step()

    # Update Value Network (minimum of the two Q-estimates for the freshly sampled action)
    predicted_new_q = torch.min(soft_q_net1(state, new_action), soft_q_net2(state, new_action))
    value_loss = F.mse_loss(value_net(state), (predicted_new_q - alpha * log_prob).detach())
    value_optimizer.zero_grad()
    value_loss.backward()
    value_optimizer.step()

    # Update Policy Network
    policy_loss = (alpha * log_prob - predicted_new_q).mean()
    policy_optimizer.zero_grad()
    policy_loss.backward()
    policy_optimizer.step()

    # Soft (Polyak) update of the target value network
    for target_param, param in zip(target_value_net.parameters(), value_net.parameters()):
        target_param.data.copy_(soft_tau * param.data + (1 - soft_tau) * target_param.data)
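To see how these pieces fit together, here is a rough training-loop sketch for Pendulum-v1 under the setup assumed above; the repository's actual training script differs in its details, so treat this as an outline rather than the real code:

max_episodes, max_steps, batch_size = 100, 200, 128  # illustrative values

env = gym.make("Pendulum-v1")
for episode in range(max_episodes):
    state, info = env.reset()
    episode_reward = 0.0

    for step in range(max_steps):
        # act with the current stochastic policy
        with torch.no_grad():
            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
            action, _, _, _, _ = policy_net.evaluate(state_tensor)
            action = action.squeeze(0).cpu().numpy()

        # rescale from [-1, 1] to the environment's action range only when stepping
        next_state, reward, terminated, truncated, info = env.step(action * env.action_space.high)
        done = terminated or truncated

        replay_buffer.push(state, action, reward, next_state, done)
        state = next_state
        episode_reward += reward

        # learn once the buffer holds enough transitions
        if len(replay_buffer) > batch_size:
            update(batch_size, reward_scale=10.0)  # reward_scale value is illustrative

        if done:
            break

    print(f"Episode {episode}: return {episode_reward:.1f}")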

Finally, to train or test the agent with the code in the sac.py file, run one of the following commands:

python sac.py --train
python sac.py --test