# Generalized Advantage Estimation (GAE)

def compute_gae(rewards, values, dones, next_value=0.0, gamma=0.99, lambda_=0.95):
    """Compute Generalized Advantage Estimation (GAE) advantages and returns.

    Walks the trajectory backwards, accumulating the discounted sum of
    one-step TD residuals::

        delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
        A_t     = delta_t + gamma * lambda_ * (1 - done_t) * A_{t+1}

    Args:
        rewards: Indexable sequence of per-step rewards; its length defines
            the number of steps processed.
        values: Indexable sequence of value estimates, one per step. Entries
            beyond ``len(rewards)`` are ignored.
        dones: Indexable sequence of per-step termination flags (booleans or
            0/1). A set flag at step ``t`` cuts both the bootstrap and the
            advantage accumulation across that step.
        next_value: Value estimate for the state following the final step.
            Only used when the final step is not done.
        gamma: Discount factor.
        lambda_: GAE smoothing coefficient.

    Returns:
        A tuple ``(advantages, returns)`` of lists, where
        ``returns[t] == advantages[t] + values[t]``.

    Raises:
        IndexError: If ``values`` or ``dones`` is shorter than ``rewards``.
    """
    num_steps = len(rewards)
    last_step = num_steps - 1

    # Preallocate and fill by index: the loop runs back-to-front, and
    # inserting at the head of a list on every step would be O(n^2).
    advantages = [0.0] * num_steps
    running_advantage = 0
    trace_decay = gamma * lambda_

    for t in range(last_step, -1, -1):
        not_done = 1 - dones[t]

        # Value of the successor state: the caller-supplied bootstrap for the
        # final step (zeroed if that step ended the episode), otherwise the
        # next entry in `values`. Kept in a local so the `next_value`
        # argument is never overwritten.
        if t == last_step:
            bootstrap = 0 if dones[t] else next_value
        else:
            bootstrap = values[t + 1]

        delta = rewards[t] + gamma * bootstrap * not_done - values[t]
        running_advantage = delta + trace_decay * not_done * running_advantage
        advantages[t] = running_advantage

    # Returns (value-function targets) are advantages plus the baseline values.
    returns = [adv + val for adv, val in zip(advantages, values)]

    return advantages, returns