ArXiv TLDR

AdaSplash-2: Faster Differentiable Sparse Attention

arXiv:2604.15180

Nuno Gonçalves, Hugo Pitorro, Vlad Niculae, Edoardo Ponti, Lei Li + 2 more

cs.LG cs.CL

TLDR

AdaSplash-2 introduces a novel histogram-based initialization for computing the α-entmax normalizer, significantly speeding up training of transformers with sparse attention.

Key contributions

  • Introduces AdaSplash-2, a faster differentiable sparse attention mechanism based on α-entmax.
  • Uses a novel histogram-based initialization that reduces the computation of the normalizer τ to typically 1-2 iterations (a baseline sketch of this τ computation follows this list).
  • Matches or improves on FlashAttention-2 per-step training time when block sparsity is moderate to high (>60%), which often occurs at long context lengths.
  • Matches softmax baselines at short contexts and shows substantial gains in long-context settings.
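
For reference, here is what the normalizer τ is, as a minimal sketch (our own illustration, not the paper's fused GPU kernel): α-entmax zeroes out scores below a per-row threshold, and τ is found by an iterative search, here plain bisection in PyTorch. The function name entmax_bisect and the iteration count are illustrative choices.

```python
import torch

def entmax_bisect(scores: torch.Tensor, alpha: float = 1.5, n_iter: int = 30) -> torch.Tensor:
    """Row-wise alpha-entmax: p_i = max((alpha-1)*s_i - tau, 0) ** (1/(alpha-1)),
    with tau chosen per row so the p_i sum to 1 (low scores get exactly 0)."""
    x = (alpha - 1) * scores
    # tau is bracketed by [max(x) - 1, max(x)]: the total mass is >= 1 at the
    # lower end, 0 at the upper end, and decreases monotonically in tau.
    hi = x.max(dim=-1, keepdim=True).values
    lo = hi - 1.0
    for _ in range(n_iter):
        tau = (lo + hi) / 2
        mass = (x - tau).clamp(min=0).pow(1.0 / (alpha - 1)).sum(-1, keepdim=True)
        lo = torch.where(mass >= 1.0, tau, lo)  # too much mass -> raise tau
        hi = torch.where(mass < 1.0, tau, hi)   # too little mass -> lower tau
    return (x - (lo + hi) / 2).clamp(min=0).pow(1.0 / (alpha - 1))
```

With α → 1 this recovers softmax and with α = 2 it is sparsemax; the cost AdaSplash-2 attacks is precisely this iterative search for τ inside the attention kernel.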

Why it matters

This paper significantly improves the efficiency of α-entmax sparse attention, which targets the quadratic attention cost that bottlenecks long-context transformers but has so far lagged behind softmax in speed. By closing that gap, it makes training models with input-dependent sparsity practical, matching softmax baselines at short context lengths while matching or beating FlashAttention-2 speed and improving downstream quality at long ones.

Original Abstract

Sparse attention has been proposed as a way to alleviate the quadratic cost of transformers, a central bottleneck in long-context training. A promising line of work is $α$-entmax attention, a differentiable sparse alternative to softmax that enables input-dependent sparsity yet has lagged behind softmax due to the computational overhead necessary to compute the normalizer $τ$. In this paper, we introduce AdaSplash-2, which addresses this limitation through a novel histogram-based initialization that reduces the number of iterations needed to compute $τ$ to typically 1--2. The key idea is to compute a coarse histogram of attention scores on the fly and store it in on-chip SRAM, yielding a more accurate initialization that enables fast forward and backward computation. Combined with a sparsity-aware GPU implementation that skips zero blocks with low overhead, AdaSplash-2 matches or improves per-step training time relative to FlashAttention-2 when block sparsity is moderate-to-high (e.g., $>$60\%), which often occurs at long-context lengths. On downstream tasks, models trained with our efficient $α$-entmax attention match softmax baselines at short-context lengths and achieve substantial gains in long-context settings.
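
To make the histogram idea concrete, here is a rough reconstruction based only on the abstract (our own sketch, not the paper's kernel): build a coarse histogram of the shifted scores, use the bin counts to bracket τ tightly, then finish with a couple of exact refinement steps. The function name entmax_tau_hist, the bin count, the bin-centre approximation, and the bisection refinement are all assumptions; the actual method runs fused inside the attention kernel with the histogram held in on-chip SRAM.

```python
import torch

def entmax_tau_hist(scores: torch.Tensor, alpha: float = 1.5,
                    n_bins: int = 16, refine_iter: int = 2) -> torch.Tensor:
    """Return a per-row estimate of the entmax normalizer tau, shape (..., 1)."""
    x = (alpha - 1) * scores
    inv = 1.0 / (alpha - 1)
    hi = x.max(dim=-1, keepdim=True).values   # coarse bracket: tau in [max(x)-1, max(x)]
    lo = hi - 1.0
    width = (hi - lo) / n_bins

    # 1) Coarse histogram of x over the bracket (kept in on-chip SRAM in the paper).
    idx = ((x - lo) / width).clamp(0, n_bins - 1).long()
    counts = torch.zeros(*x.shape[:-1], n_bins, device=x.device, dtype=x.dtype)
    counts.scatter_add_(-1, idx, torch.ones_like(x))

    # 2) Approximate the probability mass at every bin edge, treating each score
    #    as sitting at its bin centre: mass(t) ~ sum_b counts[b] * max(c_b - t, 0)^inv.
    centres = lo + width * (torch.arange(n_bins, device=x.device) + 0.5)
    edges = lo + width * torch.arange(n_bins + 1, device=x.device)
    approx = (counts.unsqueeze(-1)
              * (centres.unsqueeze(-1) - edges.unsqueeze(-2)).clamp(min=0).pow(inv)).sum(-2)

    # 3) tau sits where the mass crosses 1; widen the bracket by one bin on each
    #    side to absorb the bin-centre approximation error.
    k = (approx >= 1.0).sum(-1, keepdim=True)      # number of edges with mass >= 1
    lo = edges.gather(-1, (k - 2).clamp(min=0))
    hi = edges.gather(-1, (k + 1).clamp(max=n_bins))

    # 4) One or two exact bisection steps finish the job.
    for _ in range(refine_iter):
        tau = (lo + hi) / 2
        mass = (x - tau).clamp(min=0).pow(inv).sum(-1, keepdim=True)
        lo = torch.where(mass >= 1.0, tau, lo)
        hi = torch.where(mass < 1.0, tau, hi)
    return (lo + hi) / 2
```

In the paper this search is fused into the attention kernel and combined with a sparsity-aware scheme that skips all-zero blocks, which is what lets it match or beat FlashAttention-2 once block sparsity is high enough.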
