Scalable Hyperparameter-Divergent Ensemble Training with Automatic Learning Rate Exploration for Large Models
Hailing Cheng, Tao Huang, Chen Zhu, Antonio Alonso
TLDR
HDET repurposes data-parallel replicas for simultaneous, automatic learning rate exploration, improving large model training without extra cost.
Key contributions
- Repurposes data-parallel replicas for simultaneous learning rate exploration.
- Uses alternating fan-out (independent training) and converge (parameter averaging) phases (see the sketch after this list).
- An auto-LR controller adapts the shared learning rate schedule based on replica performance.
- Framework generalizes to explore other scalar hyperparameters like dropout or weight decay.
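The fan-out/converge protocol from the contributions above amounts to a small change to the data-parallel training loop. Below is a minimal sketch, assuming a `torch.distributed` process group and a model that is not wrapped in `DistributedDataParallel` (gradients must stay local during fan-out); the function name `hdet_step` and the `spread`/`converge_every` parameters are illustrative, not the authors' API.

```python
import torch
import torch.distributed as dist

def hdet_step(model, optimizer, loss_fn, inputs, targets, step,
              base_lr, spread=0.2, converge_every=100):
    """One HDET-style step (illustrative): each replica trains under its own
    learning rate drawn from a symmetric spread around base_lr (fan-out), and
    every `converge_every` steps parameters are averaged via AllReduce
    (converge)."""
    rank, world = dist.get_rank(), dist.get_world_size()

    # Fan-out: structured, symmetric spread of LRs across replicas,
    # e.g. multiplicative offsets in [base_lr*(1-spread), base_lr*(1+spread)].
    offset = (rank - (world - 1) / 2) / max(world - 1, 1)  # in [-0.5, 0.5]
    replica_lr = base_lr * (1.0 + 2.0 * spread * offset)
    for group in optimizer.param_groups:
        group["lr"] = replica_lr

    # Independent local update: no gradient synchronization, unlike DDP.
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    optimizer.step()

    # Converge: average parameters across all replicas via AllReduce.
    if (step + 1) % converge_every == 0:
        with torch.no_grad():
            for p in model.parameters():
                dist.all_reduce(p.data, op=dist.ReduceOp.SUM)
                p.data.div_(world)
    return loss.detach(), replica_lr
```

Because parameters are averaged only every T steps rather than gradients every step, the communication cost stays negligible relative to standard data-parallel SGD, which is the overhead claim the abstract makes.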
Why it matters
HDET improves large model training by automatically exploring learning rates across the replicas already allocated for data parallelism, yielding better optimization quality and generalization. It requires no extra training budget or hyperparameter sweeps and works as a drop-in replacement for a standard learning rate scheduler (PyTorch's OneCycleLR).
Original Abstract
Training large neural networks with data-parallel stochastic gradient descent allocates N GPU replicas to compute effectively identical updates -- a practice that leaves the rich space of learning rate configurations entirely unexplored during training. We propose Hyperparameter-Divergent Ensemble Training (HDET), a method that repurposes these replicas for simultaneous learning rate exploration at negligible communication overhead. HDET operates in alternating phases: a fan-out stage in which replicas train independently under a structured, symmetric spread of learning rates, and a converge stage in which parameters are averaged across all replicas via AllReduce every T steps. Building on this ensemble substrate, we further propose an automatic learning rate (auto-LR) controller that treats the relative training loss across replicas as a performance signal, updating the shared base schedule toward higher-performing configurations via a momentum-based gradient-free meta-update. The combined method produces a self-adapting learning rate schedule that improves both optimization quality and generalization without additional hyperparameter sweeps or training budget. Crucially, the framework generalizes beyond learning rate: any scalar hyperparameter that does not alter model architecture -- such as dropout rate, attention scale temperature, or weight-decay coefficient -- can be explored across replicas using the same fan-out/converge protocol, with inter-replica loss differences serving as zero-order hypergradients that guide the search direction. HDET is implemented as a drop-in replacement for PyTorch's OneCycleLR scheduler, requiring no changes to model architecture, optimizer, or data pipeline.
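The abstract describes the auto-LR controller only at a high level: relative training loss across replicas is the performance signal, inter-replica loss differences act as zero-order hypergradients, and the shared base schedule is moved by a momentum-based gradient-free meta-update. One plausible reading is sketched below; this is not the authors' implementation, and the class name `AutoLRController` and the `meta_lr`/`momentum` parameters are assumptions.

```python
import torch
import torch.distributed as dist

class AutoLRController:
    """Illustrative gradient-free controller: nudges the shared base LR
    toward the better-performing side of the per-replica LR spread."""

    def __init__(self, base_lr, meta_lr=0.05, momentum=0.9):
        self.base_lr = base_lr
        self.meta_lr = meta_lr      # step size of the meta-update
        self.momentum = momentum    # momentum on the meta-update
        self.velocity = 0.0

    def update(self, replica_lr: float, replica_loss: float) -> float:
        world = dist.get_world_size()
        # Gather every replica's (lr, loss) pair. Tensors must live on the
        # device expected by the backend (e.g. CUDA tensors for NCCL).
        pair = torch.tensor([replica_lr, replica_loss], dtype=torch.float32)
        gathered = [torch.zeros_like(pair) for _ in range(world)]
        dist.all_gather(gathered, pair)
        lrs = torch.stack([g[0] for g in gathered])
        losses = torch.stack([g[1] for g in gathered])

        # Zero-order hypergradient: correlate LR deviations with loss
        # deviations; a negative value means larger LRs performed better.
        lr_dev = (lrs - lrs.mean()) / self.base_lr
        loss_dev = losses - losses.mean()
        hypergrad = (lr_dev * loss_dev).mean().item()

        # Momentum-based, gradient-free meta-update of the shared schedule.
        self.velocity = self.momentum * self.velocity - self.meta_lr * hypergrad
        self.base_lr = max(self.base_lr * (1.0 + self.velocity), 1e-8)
        return self.base_lr
```

The updated `base_lr` would then feed the fan-out spread at the next phase, and the same gather-and-correlate pattern extends to any scalar hyperparameter that leaves the architecture unchanged, such as dropout rate or weight-decay coefficient.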