HubRouter: A Pluggable Sub-Quadratic Routing Primitive for Hybrid Sequence Models
TL;DR
HubRouter is a pluggable module that replaces O(n^2) attention with O(nM) hub-mediated routing through M learned hub tokens, trading a small perplexity cost for large training-throughput gains.
Key contributions
- Replaces O(n^2) attention with O(nM) hub-mediated routing using M learned hub tokens.
- Implements an encode-decode-score-council pipeline for efficient token routing.
- Hub-Jamba shows a nominal 4.2% PPL improvement (single seed, possibly within seed noise) and up to ~90x training throughput against matched PyTorch-native baselines (an estimated 10-15x versus optimized baselines).
- Graduated replacement of 25% of Transformer attention layers gives the best perplexity in a matched-budget sweep (268.0 vs 282.4 for the pure Transformer).
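The paper's code has not yet been released, so the encode-decode-score-council pipeline above can only be sketched from the abstract's description. The following numpy sketch is an illustrative assumption, not the authors' implementation: the function name `hub_router`, the mean-pooled score head, and the residual-style council update are all stand-ins.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def hub_router(tokens, hubs, k):
    """Hedged sketch of one encode-decode-score-council pass.

    tokens: (n, d) sequence representations
    hubs:   (M, d) learned hub tokens (random here, for illustration)
    k:      number of tokens the score head forwards to the council
    """
    n, d = tokens.shape

    # 1. Encode: hubs cross-attend to all tokens -- O(n*M) scores, not O(n^2).
    hub_attn = softmax(hubs @ tokens.T / np.sqrt(d))   # (M, n)
    hub_states = hub_attn @ tokens                     # (M, d)

    # 2. Decode: tokens project against hub states to get routing fingerprints.
    fingerprints = tokens @ hub_states.T               # (n, M)

    # 3. Score: a stand-in score head ranks tokens; keep the top-k.
    scores = fingerprints.mean(axis=1)                 # (n,)
    topk = np.argsort(scores)[-k:]

    # 4. Council: dense attention only over the selected subset -- O(k^2).
    sel = tokens[topk]
    council_attn = softmax(sel @ sel.T / np.sqrt(d))
    out = tokens.copy()
    out[topk] = council_attn @ sel                     # unselected tokens pass through
    return out, topk

rng = np.random.default_rng(0)
n, M, d, k = 64, 8, 16, 12
out, chosen = hub_router(rng.normal(size=(n, d)), rng.normal(size=(M, d)), k)
```

Note this toy version is bidirectional; the abstract stresses that Hub-GPT needs a strictly causal council (their pre-fix chunk-size benefit came from exactly this kind of bidirectional leak), so a real decoder variant would mask both the hub cross-attention and the council.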
Why it matters
This paper introduces a sub-quadratic routing primitive that targets the O(n^2) attention bottleneck in sequence models. It delivers large throughput gains at a small, measured perplexity cost, pointing toward more efficient and scalable hybrid architectures.
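A back-of-envelope count makes the O(n^2) vs O(nM) claim concrete. This sketch counts only attention-score FLOPs and ignores the score head and council, which are assumptions on my part, not figures from the paper:

```python
def dense_score_flops(n, d):
    """Pairwise attention scores: n^2 dot products of length d."""
    return n * n * d

def hub_score_flops(n, M, d):
    """Hub-mediated scores: encode (M x n) plus decode (n x M) projections."""
    return 2 * n * M * d

# Sequence length and hub count from the paper (n=1024; M=8 sits in the
# reliably-converging M=8-14 band); head dimension d=64 is an assumption.
n, M, d = 1024, 8, 64
ratio = dense_score_flops(n, d) / hub_score_flops(n, M, d)  # n / (2*M) = 64x fewer score FLOPs
```

The 64x score-FLOP reduction at n=1024 is the same order as the reported 10-15x wall-clock speedup versus an optimized baseline; wall-clock gains are smaller because non-attention layers and memory traffic do not shrink.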
Original Abstract
We introduce HubRouter, a pluggable module that replaces O(n^2) attention layers with O(nM) hub-mediated routing, where M << n is a small number of learned hub tokens. We demonstrate it in two from-scratch architectures: a Jamba-style hybrid and a 12-layer Transformer; retrofit into pretrained models is a tested negative case. HubRouter implements an encode-decode-score-council pipeline: M learned hubs cross-attend to all tokens, tokens project against hubs for routing fingerprints, a score head selects top-k tokens, and a sparse council attends only to the selected subset. We validate HubRouter in three settings. (1) Hub-Jamba yields a nominal 4.2% PPL improvement (200.2 vs 209.0, single seed; possibly within seed noise) and up to ~90x training throughput at sequence length 1024 in matched PyTorch-native baselines; an optimised baseline would narrow this to ~10-15x. (2) Graduated replacement of 25% of Transformer attention layers gives the best perplexity in our matched-budget sweep (268.0 vs 282.4 pure Transformer). (3) Hub-GPT provides strictly causal routing, achieving PPL 211.5 +/- 0.4 over 3 seeds (post council-causal fix); approximately 3 PPL worse than Jamba's 208.5 +/- 0.7, a measurable quality cost for avoiding O(n^2) computation. Post-fix, chunk size C has little effect; the pre-fix chunk-size benefit was an artifact of a bidirectional-council leak we found in adversarial review. A multi-seed hub-count sweep (~105 runs across M=1-32) reveals M=8-14 as the reliably-converging sub-band (4-5/5 seeds); M=6 is rescued to 5/5 by orthogonal regularization, while M>=20 shows increasing seed sensitivity. Companion paper arXiv:2603.20997 (Basu, 2026) defines the routing diagnostic task. Code and scripts will be released.