# Generalized Advantage Estimation (GAE)
def compute_gae(rewards, values, dones, next_value=0.0, gamma=0.99, lambda_=0.95):
    """Compute Generalized Advantage Estimation (GAE) advantages and returns.

    Walks the trajectory backwards. At each step it computes the one-step TD error
    ``delta_t = r_t + gamma * V_next * (1 - done_t) - V(s_t)`` and accumulates
    ``A_t = delta_t + gamma * lambda_ * (1 - done_t) * A_{t+1}``.
    A truthy ``dones[t]`` zeroes both the bootstrap term and the carried-over
    advantage, so advantages do not propagate across a terminal step.

    Args:
        rewards: Per-step rewards, indexed by time step.
        values: Per-step value estimates, read at the same indices as ``rewards``
            (``values[t + 1]`` is used as the next-state value for every step
            except the last).
        dones: Per-step terminal flags (bool or 0/1), used as ``1 - dones[t]``.
        next_value: Bootstrap value used for the step after the last one; it is
            replaced by 0 if the last step is terminal.
        gamma: Discount factor.
        lambda_: GAE smoothing parameter.

    Returns:
        A tuple ``(advantages, returns)`` of lists, where
        ``returns[t] = advantages[t] + values[t]``.
    """
    n = len(rewards)
    advantages = []
    advantage = 0
    for t in reversed(range(n)):
        not_done = 1 - dones[t]
        if t == n - 1:
            # Last step: bootstrap from the caller-supplied value unless the
            # episode terminated here.
            next_val = 0 if dones[t] else next_value
        else:
            next_val = values[t + 1]
        delta = rewards[t] + gamma * next_val * not_done - values[t]
        advantage = delta + gamma * lambda_ * not_done * advantage
        advantages.append(advantage)
    # Values were produced back-to-front; one reverse is O(n), unlike the
    # O(n) per-step cost of list.insert(0, ...).
    advantages.reverse()
    returns = [adv + val for adv, val in zip(advantages, values)]
    return advantages, returns