Long Context Pre-Training with Lighthouse Attention
Bowen Peng, Subho Ghosh, Jeffrey Quesnelle
TLDR
Lighthouse Attention enables efficient long-context transformer pre-training with a subquadratic, gradient-free hierarchical attention. It is used for most of training and then removed, and a short recovery phase at the end restores a full-attention model.
Key contributions
- Introduces a subquadratic hierarchical pre- and post-processing step that adaptively compresses the sequence before attention and decompresses it afterwards.
- Employs symmetrical compression of queries, keys, and values, preserving causality and boosting parallelism.
- Proposes a two-stage training recipe: pre-train with Lighthouse Attention for most of the run, then recover a full-attention model with a short final training phase.
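The symmetric-compression idea in the second bullet can be illustrated with a toy, stdlib-only Python sketch. It is not the authors' code: fixed-size mean pooling stands in for the paper's adaptive compression, and the rule that each token reads the output of the previous block is an assumption made here only because it makes causality easy to verify. The names `mean_pool`, `causal_sdpa` and `pooled_attention` are hypothetical.

```python
import math

def mean_pool(x, block):
    """Mean-pool consecutive groups of `block` vectors (the last group may be shorter)."""
    pooled = []
    for start in range(0, len(x), block):
        chunk = x[start:start + block]
        pooled.append([sum(col) / len(chunk) for col in zip(*chunk)])
    return pooled

def causal_sdpa(q, k, v):
    """Ordinary causal scaled dot-product attention over lists of vectors."""
    d = len(q[0])
    out = []
    for i, qi in enumerate(q):
        # Row i may only look at positions 0..i.
        scores = [sum(a * b for a, b in zip(qi, k[j])) / math.sqrt(d) for j in range(i + 1)]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append([sum(w[j] * v[j][c] for j in range(i + 1)) / z for c in range(len(v[0]))])
    return out

def pooled_attention(q, k, v, block):
    """Hypothetical sketch: pool Q, K and V with the same blocks, attend, expand back."""
    qp, kp, vp = mean_pool(q, block), mean_pool(k, block), mean_pool(v, block)
    yp = causal_sdpa(qp, kp, vp)  # about (n/block)^2 scores instead of n^2
    out = []
    for t in range(len(q)):
        b = t // block
        # Assumption: token t reads the output of the previous block, whose pooled
        # inputs all end before t, so no future token leaks into position t.
        out.append(list(yp[b - 1]) if b > 0 else [0.0] * len(v[0]))
    return out
```

With block size B the inner attention call scores roughly (n/B)^2 pairs instead of n^2. Tokens in the first block receive zeros in this sketch, a gap a real model would have to cover some other way.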
Why it matters
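Long-context transformers are bottlenecked by the quadratic time and memory cost of attention. Lighthouse Attention is a subquadratic, training-only wrapper around standard attention, and in the authors' preliminary small-scale experiments it gave a faster total training time and a lower final loss than full-attention training with all other settings matched.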
Original Abstract
Training causal transformers at extreme sequence lengths is bottlenecked by the quadratic time and memory of scaled dot-product attention (SDPA). In this work, we propose Lighthouse Attention, a training-only symmetrical selection-based hierarchical attention algorithm that wraps around ordinary SDPA and can be easily removed towards the end of training. Our hierarchical selection is also gradient-free, which exempts us from dealing with a complicated and potentially inefficient backward-pass kernel. Our contribution is three-fold: (i) A subquadratic hierarchical pre- and post-processing step that does adaptive compression and decompression of the sequence. (ii) A symmetrical compression strategy that pools queries, keys and values at the same time, while preserving left-to-right causality, which greatly improves parallelism. (iii) A two-stage training approach in which we pre-train for the majority of the time with Lighthouse Attention and recover a full attention model at the end with a short training phase. We run preliminary small-scale LLM pre-training experiments that show the effectiveness of our method compared to full attention training with all other settings matched, where we achieve a faster total training time and lower final loss after the recovery phase. Full code is available at: https://github.com/ighoshsubho/lighthouse-attention
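The gradient-free, selection-based wrapping of ordinary SDPA described above can be sketched in stdlib-only Python. This is an illustrative two-level scheme built on assumptions of our own (mean-key block scores, top-k block selection per query, plus a causal local block), not the paper's algorithm; every function name here is hypothetical. The point it shows is that selection only produces indices, so the only part that needs a backward pass is the plain attention over the gathered rows.

```python
import math

def attend(q_t, keys, values):
    """Ordinary scaled dot-product attention for one query over the given rows."""
    scale = 1.0 / math.sqrt(len(q_t))
    scores = [scale * sum(a * b for a, b in zip(q_t, kk)) for kk in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(wi * vv[c] for wi, vv in zip(w, values)) / z for c in range(len(values[0]))]

def block_centroids(k, block):
    """Coarse level of a two-level hierarchy: mean key of every complete block."""
    return [[sum(col) / block for col in zip(*k[s:s + block])]
            for s in range(0, len(k) - block + 1, block)]

def select_blocks(q_t, centroids, current_block, top_k):
    """Rank earlier blocks by q . centroid and keep the top_k (a discrete choice).

    Assumption: in an autodiff framework this step would run without gradient
    tracking (e.g. torch.no_grad() or jax.lax.stop_gradient), so the selection
    itself needs no backward kernel.
    """
    scored = [(sum(a * c for a, c in zip(q_t, centroids[b])), b) for b in range(current_block)]
    scored.sort(reverse=True)
    return sorted(b for _, b in scored[:top_k])

def selected_attention(q, k, v, block, top_k):
    """Hypothetical sketch: select blocks per query, then run plain SDPA on them."""
    cents = block_centroids(k, block)
    out = []
    for t in range(len(q)):
        cur = t // block
        chosen = select_blocks(q[t], cents, cur, top_k)
        # Tokens of the chosen earlier blocks, then the current block up to t (causal).
        idx = [j for b in chosen for j in range(b * block, (b + 1) * block)]
        idx += list(range(cur * block, t + 1))
        out.append(attend(q[t], [k[j] for j in idx], [v[j] for j in idx]))
    return out
```

Each query attends to at most `(top_k + 1) * block` tokens. Setting `top_k` to at least the number of blocks makes the selection keep everything, and the sketch then reduces exactly to full causal attention, which loosely mirrors how such a wrapper can be dropped to fall back to ordinary SDPA.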