
Generalized Advantage Estimation (GAE)

How do we estimate the advantage $\hat{A}_t$? There are many options:

Monte Carlo (high variance, no bias):

$$\hat{A}_t^{MC} = G_t - V(s_t) = \sum_{k=0}^{T-t} \gamma^k r_{t+k} - V(s_t)$$
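Here is a minimal sketch of this estimator in plain Python. It assumes one episode that terminates after its last reward, so nothing is bootstrapped. It also assumes lists where `rewards[t]` is $r_t$ and `values[t]` is $V(s_t)$. The name `mc_advantages` is hypothetical.

```python
def mc_advantages(rewards, values, gamma):
    """Monte Carlo advantage A_t = G_t - V(s_t) for one terminated episode.

    rewards[t] is r_t and values[t] is V(s_t) for t = 0..T.
    """
    advantages = [0.0] * len(rewards)
    G = 0.0
    # Walk backwards so each return reuses the next: G_t = r_t + gamma * G_{t+1}
    for t in reversed(range(len(rewards))):
        G = rewards[t] + gamma * G
        advantages[t] = G - values[t]
    return advantages


# Reward arrives only at the final step; the critic predicts 0.5 everywhere.
adv = mc_advantages([0.0, 0.0, 1.0], [0.5, 0.5, 0.5], gamma=0.9)
print(adv)
```

In this example $G_0 = 0.9^2 \cdot 1 = 0.81$, so $\hat{A}_0 = 0.81 - 0.5 = 0.31$.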

TD(0) (low variance, high bias):

$$\hat{A}_t^{TD} = r_t + \gamma V(s_{t+1}) - V(s_t) = \delta_t$$

where $\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$ is the TD residual.
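The TD residuals can be sketched the same way. Both `td_residuals` and its `last_value` argument are hypothetical. `last_value` stands in for $V(s_{T+1})$. It is taken as 0 when the episode terminated and as the critic's estimate when a rollout is cut off mid-episode. That truncation handling is an assumption; the text does not specify it.

```python
def td_residuals(rewards, values, gamma, last_value=0.0):
    """TD residuals delta_t = r_t + gamma * V(s_{t+1}) - V(s_t) for t = 0..T.

    last_value (hypothetical) plays the role of V(s_{T+1}): 0.0 when the
    episode terminated, or the critic's estimate when the rollout was cut off.
    """
    next_values = list(values[1:]) + [last_value]
    return [r + gamma * v_next - v for r, v, v_next in zip(rewards, values, next_values)]


deltas = td_residuals([0.0, 0.0, 1.0], [0.5, 0.5, 0.5], gamma=0.9)
# The TD(0) advantage estimate is simply A_t = delta_t
print(deltas)
```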

Monte Carlo uses the actual sampled return, which is unbiased but noisy. TD bootstraps from the value estimate $V(s_{t+1})$, which is smoother but biased by any errors in $V$.

Generalized Advantage Estimation (Schulman et al., 2016) interpolates between these extremes with a parameter $\lambda \in [0, 1]$:

$$\hat{A}_t^{GAE(\gamma, \lambda)} = \sum_{k=0}^{T-t} (\gamma \lambda)^k \delta_{t+k}$$
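Read literally, the sum can be evaluated term by term. The hypothetical `gae_direct` below does exactly that, assuming `deltas[t]` holds $\delta_t$ for $t = 0, \dots, T$. It costs $O(T^2)$ and is meant only to mirror the formula.

```python
def gae_direct(deltas, gamma, lam):
    """GAE by the definition: A_t = sum_{k=0}^{T-t} (gamma*lam)^k * delta_{t+k}.

    deltas[t] is the TD residual delta_t for t = 0..T. Costs O(T^2).
    """
    n = len(deltas)
    return [
        sum((gamma * lam) ** k * deltas[t + k] for k in range(n - t))
        for t in range(n)
    ]


deltas = [-0.05, -0.05, 0.5]
adv = gae_direct(deltas, gamma=0.9, lam=0.95)
print(adv)
```

Setting `lam=0.0` leaves only the $k = 0$ term, so the output equals `deltas`. This matches the TD(0) case.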

This can be written recursively:

$$\hat{A}_t^{GAE} = \delta_t + \gamma \lambda \hat{A}_{t+1}^{GAE}, \qquad \hat{A}_{T+1}^{GAE} = 0$$
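The recursion allows a single $O(T)$ backward pass. The sketch below starts from $\hat{A}_{T+1}^{GAE} = 0$, since the series stops at step $T$. It then cross-checks the result against the explicit sum. `gae_recursive` is a hypothetical name.

```python
def gae_recursive(deltas, gamma, lam):
    """GAE via A_t = delta_t + gamma*lam*A_{t+1}, starting from A_{T+1} = 0.

    One backward pass, O(T), instead of re-summing the series for every t.
    """
    advantages = [0.0] * len(deltas)
    running = 0.0  # holds A_{t+1} while step t is processed
    for t in reversed(range(len(deltas))):
        running = deltas[t] + gamma * lam * running
        advantages[t] = running
    return advantages


deltas = [-0.05, -0.05, 0.5]
gamma, lam = 0.9, 0.95
fast = gae_recursive(deltas, gamma, lam)
# Cross-check against the explicit sum of (gamma*lam)^k * delta_{t+k}
slow = [
    sum((gamma * lam) ** k * deltas[t + k] for k in range(len(deltas) - t))
    for t in range(len(deltas))
]
assert all(abs(a - b) < 1e-12 for a, b in zip(fast, slow))
```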

| $\lambda$ | Equivalent to | Bias | Variance |
|---|---|---|---|
| $\lambda = 0$ | TD(0): $\hat{A}_t = \delta_t$ | High | Low |
| $\lambda = 1$ | Monte Carlo: $\hat{A}_t = G_t - V(s_t)$ | None | High |
| $0 < \lambda < 1$ | Weighted mix of $n$-step returns | Medium | Medium |
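The first two table rows can be checked numerically. This sketch uses a hypothetical `gae` helper, an arbitrary critic, and a terminated episode, so $V(s_{T+1}) = 0$. With $\lambda = 0$ only $\delta_t$ survives. With $\lambda = 1$ the residuals telescope: $\sum_{k} \gamma^k \delta_{t+k} = \sum_{k} \gamma^k r_{t+k} + \gamma^{T-t+1} V(s_{T+1}) - V(s_t) = G_t - V(s_t)$.

```python
def gae(rewards, values, gamma, lam, last_value=0.0):
    """Return (GAE advantages, TD residuals) for one trajectory (hypothetical helper).

    last_value stands in for V(s_{T+1}); 0.0 means the episode terminated.
    """
    next_values = list(values[1:]) + [last_value]
    deltas = [r + gamma * vn - v for r, v, vn in zip(rewards, values, next_values)]
    advantages, running = [0.0] * len(deltas), 0.0
    for t in reversed(range(len(deltas))):
        running = deltas[t] + gamma * lam * running
        advantages[t] = running
    return advantages, deltas


rewards = [0.0, 1.0, 0.0, 2.0]
values = [0.3, 0.8, 0.1, 1.5]  # an arbitrary, imperfect critic
gamma = 0.99

adv_td, deltas = gae(rewards, values, gamma, lam=0.0)
adv_mc, _ = gae(rewards, values, gamma, lam=1.0)

# Monte Carlo returns G_t for the terminated episode
returns, G = [0.0] * len(rewards), 0.0
for t in reversed(range(len(rewards))):
    G = rewards[t] + gamma * G
    returns[t] = G

# lambda = 0 row: the advantage is just the TD residual
assert all(abs(a - d) < 1e-12 for a, d in zip(adv_td, deltas))
# lambda = 1 row: the residuals telescope to G_t - V(s_t)
assert all(abs(a - (g - v)) < 1e-12 for a, g, v in zip(adv_mc, returns, values))
```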

The full PPO training loop:

  1. Collect rollout data with the current policy $\pi_{\theta_\text{old}}$.
  2. Compute TD residuals: $\delta_t = r_t + \gamma V_\phi(s_{t+1}) - V_\phi(s_t)$.
  3. Compute GAE advantages: $\hat{A}_t = \sum_k (\gamma\lambda)^k \delta_{t+k}$.
  4. Compute targets for the value function: $G_t^\lambda = \hat{A}_t + V_\phi(s_t)$. This is the $\lambda$-return, which equals the Monte Carlo return $G_t$ only when $\lambda = 1$.
  5. Run multiple epochs of mini-batch updates on $L^{CLIP}$ and the value loss.
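Steps 2 to 4 can be sketched as one backward pass over a rollout buffer. Several pieces are assumptions for illustration: the helper `compute_gae_and_targets`, its `dones` flags, the `last_value` bootstrap, and the defaults `gamma=0.99` and `lam=0.95`. Steps 1 and 5 depend on the policy, value network, and optimizer, so they are left out.

```python
def compute_gae_and_targets(rewards, values, dones, last_value, gamma=0.99, lam=0.95):
    """Steps 2-4 of the PPO loop for one rollout buffer (hypothetical helper).

    rewards[t] = r_t, values[t] = V_phi(s_t), and dones[t] is True if an
    episode ended at step t. last_value = V_phi of the state reached after the
    final step; it is ignored when dones[-1] is True.
    """
    n = len(rewards)
    advantages = [0.0] * n
    running = 0.0
    for t in reversed(range(n)):
        next_value = last_value if t == n - 1 else values[t + 1]
        not_done = 0.0 if dones[t] else 1.0
        # Step 2: TD residual, with no bootstrapping across an episode boundary
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        # Step 3: GAE recursion, also cut at episode boundaries
        running = delta + gamma * lam * not_done * running
        advantages[t] = running
    # Step 4: value targets G_t^lambda = A_t + V_phi(s_t)
    targets = [a + v for a, v in zip(advantages, values)]
    return advantages, targets


# Two episodes share one buffer: the first ends at t = 0, the second is
# still running when the buffer ends.
adv, targets = compute_gae_and_targets(
    rewards=[1.0, 5.0, 5.0],
    values=[0.0, 0.0, 0.0],
    dones=[True, False, False],
    last_value=0.0,
    gamma=1.0,
    lam=1.0,
)
print(adv, targets)
```

In the example, the large advantage at $t = 1$ (10.0) does not leak into $t = 0$. There, `not_done` is 0, which cuts both the bootstrap term and the recursion at the episode boundary.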

In LLM training (RLHF), episodes can be long (hundreds of tokens). Monte Carlo returns have very high variance there. The return from each token includes the randomness of every token sampled after it, and that noise buries the token's own contribution to the reward.

GAE with $\lambda < 1$ exponentially downweights distant TD residuals. This gives a much cleaner signal for credit assignment: which token actually contributed to the reward?
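To put numbers on "exponentially downweights", the sketch below prints the weight $(\gamma\lambda)^k$ for a few distances $k$. It assumes $\gamma = 1$ and $\lambda = 0.95$ purely for illustration; the text does not fix these values. With $\lambda = 1$, every residual keeps weight 1. With $\lambda = 0.95$, the weight is about 0.60 at $k = 10$, 0.077 at $k = 50$, and 0.006 at $k = 100$. Residuals a hundred tokens away therefore barely move $\hat{A}_t$. The helper `residual_weights` is hypothetical.

```python
def residual_weights(gamma, lam, distances):
    """Weight (gamma*lam)^k that GAE puts on the TD residual k steps ahead."""
    return {k: (gamma * lam) ** k for k in distances}


distances = [1, 10, 50, 100]
for lam in (1.0, 0.95):
    w = residual_weights(gamma=1.0, lam=lam, distances=distances)
    print(f"lambda={lam}: " + ", ".join(f"k={k}: {w[k]:.4f}" for k in distances))
```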