MDN: Parallelizing Stepwise Momentum for Delta Linear Attention
Yulong Huang, Xiang Liu, Hongxiang Huang, Xiaopeng Lin, Zunchang Liu, et al.
TLDR
MDN (Momentum DeltaNet) adds a stepwise momentum rule to Linear Attention via a chunkwise parallel algorithm, improving LLM performance and stability on long sequences.
Key contributions
- Introduces MDN, a chunkwise parallel algorithm for Linear Attention with a stepwise momentum update rule (a minimal sketch of the rule follows this list).
- Analyzes the momentum recurrence as a second-order dynamical system with complex conjugate eigenvalues, guiding the design of stable gating constraints for LA.
- Achieves comparable training throughput to Mamba2 and KDA using optimized Triton kernels.
- Outperforms Transformers, Mamba2, and GDN on 400M/1.3B LLMs across diverse benchmarks.
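To make the momentum rule concrete, here is a minimal sequential sketch in PyTorch. It treats the delta rule as online SGD on the fast-weight loss 0.5 * ||S^T k_t - v_t||^2 and adds a heavy-ball momentum buffer; the function name, the scalar gates `beta`/`eta`, and the sequential loop are illustrative assumptions for exposition, not the paper's chunkwise-parallel Triton kernel.

```python
import torch

def momentum_delta_step(S, M, k, v, beta, eta):
    """One recurrent step of a momentum delta rule (illustrative sketch).

    S:    (d_k, d_v) fast-weight state
    M:    (d_k, d_v) momentum buffer
    k:    (d_k,) key;  v: (d_v,) value
    beta: scalar step size (write strength), assumed here
    eta:  scalar momentum decay (gate), assumed here

    DeltaNet-style updates are online SGD on L(S) = 0.5 * ||S^T k - v||^2;
    heavy-ball momentum accumulates past gradients in M.
    """
    grad = torch.outer(k, S.T @ k - v)   # dL/dS = k (S^T k - v)^T
    M = eta * M - beta * grad            # momentum accumulation
    S = S + M                            # second-order state update
    return S, M

# Usage: roll the recurrence over a short sequence. This sequential loop
# is the reference semantics; the paper's contribution is reorganizing it
# into a chunkwise-parallel form by reordering the update coefficients.
d_k, d_v, T = 4, 4, 8
S = torch.zeros(d_k, d_v)
M = torch.zeros(d_k, d_v)
for t in range(T):
    k_t = torch.randn(d_k); k_t = k_t / k_t.norm()
    v_t = torch.randn(d_v)
    S, M = momentum_delta_step(S, M, k_t, v_t, beta=0.5, eta=0.9)
```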
Why it matters
Current Linear Attention models suffer from rapid information decay and suboptimal convergence, because their recurrences amount to naive online SGD updates. MDN addresses this by integrating momentum into the update, making LA models more stable and effective on long sequences. This matters for scaling LLMs to much longer contexts efficiently.
Original Abstract
Linear Attention (LA) offers a promising paradigm for scaling large language models (LLMs) to long sequences by avoiding the quadratic complexity of self-attention. Recent LA models such as Mamba2 and GDN interpret linear recurrences as closed-form online stochastic gradient descent (SGD), but naive SGD updates suffer from rapid information decay and suboptimal convergence in optimization. While momentum-based optimizers provide a natural remedy, they pose challenges in simultaneously achieving training efficiency and effectiveness. To address this, we develop a chunkwise parallel algorithm for LA with a stepwise momentum rule by geometrically reordering the update coefficients. Further, from a dynamical systems perspective, we analyze the momentum-based recurrence as a second-order system that introduces complex conjugate eigenvalues. This analysis guides the design of stable gating constraints. The resulting model, Momentum DeltaNet (MDN), leverages Triton kernels to achieve training throughput comparable to competitive linear models such as Mamba2 and KDA. Extensive experiments on 400M and 1.3B parameter models demonstrate consistent performance improvements over strong baselines, including Transformers, Mamba2, and GDN, across diverse downstream evaluation benchmarks. Code: https://github.com/HuuYuLong/MomentumDeltaNet .
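For readers unfamiliar with the dynamical-systems framing in the abstract, the following is the classical heavy-ball derivation, stated generically; the paper's exact gating parameterization may differ. Eliminating the momentum buffer from $M_t = \eta M_{t-1} - \beta \nabla L(S_{t-1})$ and $S_t = S_{t-1} + M_t$ gives a second-order recurrence in the state alone:

$$ S_t = (1+\eta)\,S_{t-1} - \eta\,S_{t-2} - \beta\,\nabla L(S_{t-1}). $$

On a quadratic loss with curvature $h$, the error dynamics have characteristic polynomial

$$ \lambda^2 - (1 + \eta - \beta h)\,\lambda + \eta = 0, $$

whose roots are complex conjugates whenever $(1+\eta-\beta h)^2 < 4\eta$. In that regime both eigenvalues have modulus $\sqrt{\eta}$, so the recurrence is stable exactly when the momentum gate satisfies $\eta < 1$; constraints of this form are what the abstract means by stable gating constraints.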