Decoupled Descent: Exact Test Error Tracking Via Approximate Message Passing
TLDR
Decoupled Descent (DD) is a new training algorithm that uses approximate message passing to make the train error track the test error, closing the generalization gap.
Key contributions
- DD enforces a train-test identity: the train error asymptotically tracks the test error for stylized Gaussian mixture models.
- Uses approximate message passing to cancel the biases introduced by data reuse, enabling zero-cost validation and 100% data utilization (see the AMP sketch below).
- Is governed by a low-dimensional state evolution recursion, making the algorithm's dynamics transparent and tractable.
- Outperforms GD on XOR classification and narrows the generalization gap on noisy MNIST and non-linear probing of CIFAR-10, even when the stylized assumptions are relaxed.
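The paper's exact DD updates are not reproduced here, but the AMP mechanism the contributions refer to can be illustrated on a textbook linear model y = Xw + noise. In the minimal sketch below, the function names (`amp_linear_regression`, `soft_threshold`), the soft-threshold denoiser, and all parameter values are illustrative assumptions rather than the paper's construction; the point is the Onsager correction term, which cancels data-reuse bias so the effective observation behaves like the true signal plus fresh Gaussian noise.

```python
import numpy as np

def soft_threshold(v, t):
    """Soft-threshold denoiser eta(v; t) and the empirical mean of its derivative."""
    out = np.sign(v) * np.maximum(np.abs(v) - t, 0.0)
    deriv = np.mean(np.abs(v) > t)  # <eta'(v; t)>: fraction of surviving coordinates
    return out, deriv

def amp_linear_regression(X, y, n_iter=30, theta=1.0):
    """Illustrative AMP iteration for y = X w + noise (not the paper's DD algorithm).

    The Onsager term (deriv / delta) * z cancels the correlation created by
    reusing X across iterations, so x + X.T @ z stays approximately distributed
    as (true w + i.i.d. Gaussian noise). This decoupling of the training
    residual from the data is the mechanism the paper builds on.
    """
    n, p = X.shape
    delta = n / p
    x = np.zeros(p)
    z = y.copy()      # initial residual (no Onsager term at the first step)
    deriv = 0.0
    for _ in range(n_iter):
        z = y - X @ x + (deriv / delta) * z            # residual + Onsager correction
        x, deriv = soft_threshold(x + X.T @ z, theta)  # denoise the effective observation
    return x
```

Without the `(deriv / delta) * z` term this reduces to plain iterative soft thresholding, whose residual statistics drift away from the fresh-data distribution; with it, the residuals stay decoupled, which is the property that lets test-error-like quantities be read off the training data.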
Why it matters
This paper introduces a novel approach to the generalization gap, where train error becomes an unreliable proxy for test error. Decoupled Descent (DD) enforces that the train error tracks the test error, potentially eliminating the need for a held-out validation set and allowing all data to be used for training. This could substantially improve data efficiency in model training.
Original Abstract
In modern parametric model training, full-batch gradient descent (and its variants) suffers due to progressively stronger biasing towards the exact realization of training data; this drives the systematic ``generalization gap'', where the train error becomes an unreliable proxy for test error. Existing approaches either argue this gap is benign through complex analysis or sacrifice data to a validation set. In contrast, we introduce decoupled descent (DD), a novel theory-based training algorithm that satisfies a train-test identity -- enforcing the train error to asymptotically track the test error for stylized Gaussian mixture models. Within this specific regime, leveraging approximate message passing theory, DD iteratively cancels the biases due to data reuse, rigorously demonstrating the feasibility of zero-cost validation and $100\%$ data utilization. Moreover, DD is governed by a low-dimensional state evolution recursion, rendering the dynamics of the algorithm transparent and tractable. We validate DD on XOR classification, yielding superior performance compared to GD; additionally, we implement noisy MNIST and non-linear probing of CIFAR-10, demonstrating that even when our stylized assumptions are relaxed, DD narrows the generalization gap compared to GD.
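The abstract's "low-dimensional state evolution recursion" can also be made concrete for the sketch above: a single scalar recursion predicts the per-iteration error of the high-dimensional AMP run. The version below is the generic textbook recursion for the linear model under an assumed sparse prior (the function name `state_evolution`, the prior, and all parameters are illustrative), not the paper's specific recursion for Gaussian mixture classification.

```python
import numpy as np

def state_evolution(delta, sigma2, theta, prior_sampler, n_iter=30, n_mc=200_000, seed=0):
    """Scalar state-evolution recursion matching the AMP sketch above.

    tau_{t+1}^2 = sigma^2 + (1/delta) * E[(eta(W0 + tau_t * Z; theta) - W0)^2],
    with W0 drawn from the signal prior and Z ~ N(0, 1). The sequence tau_t
    predicts the per-iteration estimation error of the high-dimensional run,
    which is what makes the dynamics transparent and tractable.
    """
    rng = np.random.default_rng(seed)
    w0 = prior_sampler(n_mc, rng)                 # Monte Carlo samples of the signal prior
    z = rng.standard_normal(n_mc)
    tau2 = sigma2 + np.mean(w0 ** 2) / delta      # tau_0^2 from the all-zero initialization
    trajectory = [tau2]
    for _ in range(n_iter):
        pseudo_obs = w0 + np.sqrt(tau2) * z       # effective observation W0 + tau_t * Z
        eta = np.sign(pseudo_obs) * np.maximum(np.abs(pseudo_obs) - theta, 0.0)
        tau2 = sigma2 + np.mean((eta - w0) ** 2) / delta
        trajectory.append(tau2)
    return trajectory

# Example with an illustrative sparse Gaussian prior (10% nonzero coordinates):
# taus = state_evolution(delta=0.5, sigma2=0.1, theta=1.0,
#                        prior_sampler=lambda m, rng: rng.standard_normal(m) * (rng.random(m) < 0.1))
```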