KL for a KL: On-Policy Distillation with Control Variate Baseline
Minjae Oh, Sangjun Song, Gyubin Choi, Yunho Choi, Yohan Jo
TLDR
vOPD stabilizes On-Policy Distillation (OPD) for LLMs by subtracting an RL-style control variate baseline from its gradient estimator. This improves reasoning performance over vanilla OPD without requiring an extra critic or inference pass.
Key contributions
- Introduces vOPD, stabilizing On-Policy Distillation (OPD) for LLMs using an RL control variate baseline.
- Shows that the OPD value function has a closed form, the per-token negative reverse KL divergence between student and teacher, computable from the already-computed forward pass.
- Subtracts this value as a detached baseline, which reduces gradient variance while keeping the gradient unbiased and preserving the lightweight single-sample estimator.
- Consistently outperforms vanilla OPD and matches the most expensive full-vocabulary baseline on mathematical and scientific reasoning benchmarks.
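As a rough illustration of the first two contributions, the sketch below works on a single token position with a toy three-token vocabulary. It assumes the per-token OPD reward for a sampled token y is r(y) = log p_T(y) - log p_S(y). Under that assumption, E_{y~p_S}[r(y)] = -KL(p_S || p_T), which is the closed-form value. The sketch then enumerates the vocabulary to compute the exact mean and total variance of the single-sample score-function gradient with respect to the student's logits. It does this once with no baseline and once with the value subtracted as a constant (detached) baseline. The helper names and probabilities are hypothetical. This is not the authors' implementation.

```python
import math


def log_softmax(logits):
    # Numerically stable log-softmax, standing in for a student/teacher forward pass.
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - lse for z in logits]


def per_token_reward(logp_s, logp_t, y):
    # Assumed single-sample OPD reward for sampled token y: log p_T(y) - log p_S(y).
    return logp_t[y] - logp_s[y]


def value_closed_form(logp_s, logp_t):
    # Closed-form per-token value V = -KL(p_S || p_T), using only log-probs already in hand.
    return -sum(math.exp(ls) * (ls - lt) for ls, lt in zip(logp_s, logp_t))


def grad_stats(logp_s, logp_t, baseline):
    # Exact mean and total (trace) variance of g(y) = (r(y) - baseline) * dlog p_S(y)/dlogits
    # with y ~ p_S, obtained by enumerating the toy vocabulary instead of sampling.
    p = [math.exp(lp) for lp in logp_s]
    n = len(p)
    mean = [0.0] * n
    second_moment = 0.0
    for y in range(n):
        adv = per_token_reward(logp_s, logp_t, y) - baseline  # baseline is a detached constant
        score = [(1.0 if j == y else 0.0) - p[j] for j in range(n)]  # grad of log-softmax at y
        g = [adv * s for s in score]
        mean = [mu + p[y] * gj for mu, gj in zip(mean, g)]
        second_moment += p[y] * sum(gj * gj for gj in g)
    return mean, second_moment - sum(mu * mu for mu in mean)


# Hypothetical student and teacher distributions over a 3-token vocabulary.
logp_s = log_softmax([math.log(0.5), math.log(0.3), math.log(0.2)])
logp_t = log_softmax([math.log(0.8), math.log(0.15), math.log(0.05)])

v = value_closed_form(logp_s, logp_t)
expected_r = sum(math.exp(logp_s[y]) * per_token_reward(logp_s, logp_t, y) for y in range(3))
mean_plain, var_plain = grad_stats(logp_s, logp_t, baseline=0.0)
mean_vopd, var_vopd = grad_stats(logp_s, logp_t, baseline=v)
print(f"V = {v:.4f}, E[r] = {expected_r:.4f}")
print(f"trace variance: plain {var_plain:.4f}, with baseline {var_vopd:.4f}")
```

In this particular example, the baseline leaves the mean gradient (the gradient of -KL with respect to the logits) unchanged and lowers its variance. Subtracting the mean reward does not lower variance for every pair of distributions. Read the example as an illustration, not a proof.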
Why it matters
On-Policy Distillation (OPD) has become a dominant post-training paradigm for LLM reasoning, but it is unstable in practice because its single-sample Monte Carlo gradient estimator has high variance. This work offers a principled, efficient fix by borrowing variance reduction from RL. The baseline comes from the existing forward pass, so OPD becomes more reliable without an extra critic, extra inference, or a full-vocabulary loss.
Original Abstract
On-Policy Distillation (OPD) has emerged as a dominant post-training paradigm for large language models, especially for reasoning domains. However, OPD remains unstable in practice due to the high gradient variance of its single-sample Monte Carlo estimator, and recipes for stable training are still immature. We propose vOPD (On-Policy Distillation with a control variate baseline), which casts OPD as policy-gradient RL and stabilizes it by introducing a control variate baseline (canonically a value function) from the RL literature. We show that the OPD value function admits a closed form as the per-token negative reverse KL divergence between the student and the teacher, available directly from the already-computed forward pass with no additional critic or inference. Existing stabilization methods either compute the full token-level reverse KL over the entire vocabulary, adding significant overhead, or restrict it to a top-k support, biasing the objective. vOPD instead preserves the lightweight single-sample estimator, subtracting the value function as a detached baseline to keep the gradient unbiased while reducing variance. Furthermore, we show that a top-k approximation of the baseline further lowers cost without compromising performance. Across mathematical and scientific reasoning benchmarks, vOPD consistently outperforms vanilla OPD and matches the most expensive full-vocabulary baseline, offering an efficient stabilization of On-Policy Distillation through principled RL variance reduction.
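The baseline depends only on the context, not on the sampled token. Approximating it therefore cannot bias the gradient; it can only change how much variance is removed. This is the contrast the abstract draws with methods that restrict the loss itself to a top-k support. The sketch below checks this on a toy six-token vocabulary. It compares the exact expected gradient with no baseline, with the full reverse-KL baseline, and with a top-2 baseline. The top-k rule used here is an illustrative assumption and may differ from the paper's choice: it sums the reverse-KL terms over the student's k most likely tokens, without renormalization. The helper names and probabilities are also hypothetical.

```python
import math


def reverse_kl_terms(p_s, p_t, support):
    # Sum of p_S(j) * log(p_S(j) / p_T(j)) over the given token indices.
    return sum(p_s[j] * math.log(p_s[j] / p_t[j]) for j in support)


def topk_baseline(p_s, p_t, k):
    # Hypothetical top-k value estimate: minus the reverse KL restricted to the
    # student's k most likely tokens, without renormalization (an assumption).
    top = sorted(range(len(p_s)), key=lambda j: p_s[j], reverse=True)[:k]
    return -reverse_kl_terms(p_s, p_t, top)


def expected_grad(p_s, p_t, baseline):
    # Exact E_{y~p_S}[(r(y) - baseline) * (onehot(y) - p_S)], r(y) = log p_T(y) - log p_S(y).
    n = len(p_s)
    g = [0.0] * n
    for y in range(n):
        adv = math.log(p_t[y] / p_s[y]) - baseline
        for j in range(n):
            g[j] += p_s[y] * adv * ((1.0 if j == y else 0.0) - p_s[j])
    return g


# Hypothetical student and teacher distributions over a 6-token vocabulary.
p_s = [0.55, 0.25, 0.10, 0.05, 0.03, 0.02]
p_t = [0.40, 0.30, 0.15, 0.10, 0.03, 0.02]

b_full = -reverse_kl_terms(p_s, p_t, range(len(p_s)))
b_top2 = topk_baseline(p_s, p_t, k=2)

g_plain = expected_grad(p_s, p_t, 0.0)
g_full = expected_grad(p_s, p_t, b_full)
g_top2 = expected_grad(p_s, p_t, b_top2)
# Largest deviation between the no-baseline and top-2-baseline expected gradients (close to 0).
print(max(abs(a - b) for a, b in zip(g_plain, g_top2)))
```

All three expected gradients agree up to floating-point error. Only the variance of the single-sample estimate depends on how well the baseline approximates the true value.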