Training Non-Differentiable Networks via Optimal Transport
TLDR
PolyStep is a gradient-free optimizer that uses optimal transport to train genuinely non-differentiable neural networks, substantially outperforming existing gradient-free methods.
Key contributions
- Introduces PolyStep, a gradient-free optimizer for genuinely non-differentiable neural networks.
- Updates parameters via optimal transport over structured polytope vertices, requiring only forward passes (see the sketch after this list).
- Achieves 93.4% test accuracy on hard-LIF spiking networks, beating all gradient-free baselines by over 60 percentage points.
- Leads every gradient-free competitor across int8 quantization, argmax attention, staircase activations, and hard MoE routing, and matches OpenAI-ES on RL policy search.
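Based only on the update described above and in the abstract below, here is a minimal NumPy sketch of one PolyStep-style step. The function name `polystep_update`, the cross-polytope vertex construction, and the hyperparameters (`n_dirs`, `radius`, `temp`, `step`) are illustrative assumptions, not the authors' implementation; see the paper's repository for the real code.

```python
# Hypothetical sketch of a PolyStep-style update, reconstructed from the
# abstract's description. The vertex layout and all hyperparameters are
# assumptions, not the authors' implementation.
import numpy as np

def polystep_update(theta, loss_fn, n_dirs=8, radius=0.1, temp=0.5,
                    step=0.5, rng=None):
    """One forward-only step: evaluate the loss at structured polytope
    vertices in a compressed random subspace, form softmax weights over
    the costs, and displace theta toward the weighted barycenter."""
    rng = np.random.default_rng() if rng is None else rng
    d = theta.size
    # Compressed subspace: n_dirs random orthonormal directions (n_dirs <= d).
    basis, _ = np.linalg.qr(rng.standard_normal((d, n_dirs)))
    # Structured vertices: a cross-polytope, +/- radius along each direction.
    vertices = theta[None, :] + radius * np.concatenate([basis.T, -basis.T])
    costs = np.array([loss_fn(v) for v in vertices])  # forward passes only
    # Softmax-weighted assignment over the costs: low-cost vertices receive
    # exponentially more mass (the one-sided entropic-OT limit, with no
    # Sinkhorn iterations needed).
    w = np.exp(-(costs - costs.min()) / temp)
    w /= w.sum()
    # Barycentric projection: move theta toward the weighted vertex mean.
    return theta + step * (w @ vertices - theta)
```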
Why it matters
This paper addresses a critical limitation in training neural networks with non-differentiable components, which are increasingly common. PolyStep provides a robust gradient-free optimizer that works where existing methods collapse to near-random accuracy, broadening the range of modern architectures that can be trained end to end. It closes a significant gap: training non-differentiable components without the bias that surrogate gradients introduce.
Original Abstract
Neural networks increasingly embed non-differentiable components (spiking neurons, quantized layers, discrete routing, blackbox simulators, etc.) where backpropagation is inapplicable and surrogate gradients introduce bias. We present PolyStep, a gradient-free optimizer that updates parameters using only forward passes. Each step evaluates the loss at structured polytope vertices in a compressed subspace, computes softmax-weighted assignments over the resulting cost matrix, and displaces particles toward low-cost vertices via barycentric projection. This update corresponds to the one-sided limit of a regularized optimal-transport problem, inheriting its geometric structure without Sinkhorn iterations. PolyStep trains genuinely non-differentiable models where existing gradient-free methods collapse to near-random accuracy. On hard-LIF spiking networks we reach 93.4% test accuracy, outperforming all gradient-free baselines by over 60 pp and closing to within 4.4 pp of a surrogate-gradient Adam ceiling. Across four additional non-differentiable architectures (int8 quantization, argmax attention, staircase activations, hard MoE routing) we lead every gradient-free competitor. On MAX-SAT scaling from 100 to 1M variables, we sustain above 92% clause satisfaction while evolution strategies drop 8–12 pp. On RL policy search, we match OpenAI-ES on classical control and retain performance under integer and binary quantization that collapses gradient-based methods. We prove convergence to conservative-stationary points at rate $O(\log T/\sqrt{T})$ on piecewise-smooth losses, upgraded to Clarke-stationary on the headline architectures and extended to the piecewise-constant regime via a hitting-time bound. These rates match the known zeroth-order query-complexity lower bounds that all forward-only methods inherit. Code is available at https://github.com/anindex/polystep.
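To make the "only forward passes" claim concrete, here is a toy use of the `polystep_update` sketch above on a genuinely non-differentiable objective: a sign-activation classifier whose 0/1 loss is piecewise constant, so backprop would return zero gradient almost everywhere. The data, model, and hyperparameters are purely illustrative and are not from the paper.

```python
# Toy demo (illustrative only): optimizing a piecewise-constant 0/1 loss,
# where gradients vanish almost everywhere, with the polystep_update
# sketch defined above.
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((200, 16))
teacher = rng.standard_normal(16)
y = np.sign(X @ teacher)            # labels from a hidden linear teacher

def loss_fn(w):
    # 0/1 classification error: non-differentiable, piecewise constant.
    return np.mean(np.sign(X @ w) != y)

theta = rng.standard_normal(16)
for _ in range(300):
    theta = polystep_update(theta, loss_fn, n_dirs=8, radius=0.3, temp=0.1,
                            rng=rng)
print("final 0/1 error:", loss_fn(theta))
```

Because the subspace is resampled each step, repeated updates explore new directions even when the loss is locally flat, which is the regime the paper's hitting-time bound is meant to cover.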