Natural Riemannian gradient for learning functional tensor networks
Nikolas Klug, Michael Ulbrich, André Uschmajew, Marius Willner
TLDR
This paper introduces a natural Riemannian gradient descent method for functional tree tensor networks that handles arbitrary loss functions and markedly improves convergence over standard Riemannian gradient methods on classification tasks.
Key contributions
- Proposes a natural Riemannian gradient descent for optimizing functional tree tensor networks (TTNs).
- Extends TTN optimization beyond least-squares regression, where alternating optimization already suffices, to arbitrary losses such as multinomial logistic regression.
- Ensures the gradient search direction is independent of the choice of basis of the underlying functional tensor product space (see the sketch after this list).
- Introduces a hierarchy of efficient approximations to the true natural gradient for practical updates, with improved convergence over standard Riemannian gradient methods.
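To make the basis-independence point concrete, here is a minimal sketch of a damped natural-gradient step for a generic parametric model, assuming the Gram (Fisher-type) matrix is estimated from the model Jacobian on a sample batch. The function name and the toy least-squares usage are illustrative assumptions; the paper's method exploits TTN structure and cheaper approximations rather than forming this matrix explicitly.

```python
import numpy as np

def natural_gradient_step(theta, grad, jac, lr=0.1, damping=1e-6):
    """One damped natural-gradient update (illustrative sketch).

    theta : (p,) current parameters
    grad  : (p,) Euclidean gradient of the loss at theta
    jac   : (n, p) Jacobian of the model outputs w.r.t. theta on n samples,
            a Monte Carlo stand-in for the function-space inner product
    """
    n = jac.shape[0]
    # Gram matrix of the parametrization, G = J^T J / n. Preconditioning
    # the gradient by G^{-1} makes the resulting search direction invariant
    # under linear changes of basis in function space.
    G = jac.T @ jac / n
    direction = np.linalg.solve(G + damping * np.eye(theta.size), grad)
    return theta - lr * direction

# Toy usage on linear least squares, where the Jacobian is just the data
# matrix and one natural-gradient step with lr=1.0 is a Gauss-Newton step.
rng = np.random.default_rng(0)
X, y = rng.normal(size=(100, 5)), rng.normal(size=100)
theta = np.zeros(5)
grad = X.T @ (X @ theta - y) / 100
theta = natural_gradient_step(theta, grad, X, lr=1.0)
```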
Why it matters
This work addresses a key limitation in optimizing functional tensor networks, making them applicable to a wider range of machine learning problems beyond least-squares. By improving convergence, it enhances the practical utility and efficiency of TTNs for tasks like classification.
Original Abstract
We consider machine learning tasks with low-rank functional tree tensor networks (TTN) as the learning model. While in the case of least-squares regression, low-rank functional TTNs can be efficiently optimized using alternating optimization, this is not directly possible in other problems, such as multinomial logistic regression. We propose a natural Riemannian gradient descent type approach applicable to arbitrary losses which is based on the natural gradient by Amari. In particular, the search direction obtained by the natural gradient is independent of the choice of basis of the underlying functional tensor product space. Our framework applies to both the factorized and manifold-based approach for representing the functional TTN. For practical application, we propose a hierarchy of efficient approximations to the true natural Riemannian gradient for computing the updates in the parameter space. Numerical experiments confirm our theoretical findings on common classification datasets and show that using natural Riemannian gradient descent for learning considerably improves convergence behavior when compared to standard Riemannian gradient methods.
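For readers unfamiliar with the model class: a functional tensor network represents a multivariate function by contracting low-rank tensor cores with univariate feature maps. The sketch below assumes the simplest tree topology (a tensor train) and monomial features to show how such a function is evaluated; the core shapes, ranks, and feature map are illustrative assumptions, not taken from the paper.

```python
import numpy as np

def tt_function_eval(cores, feature_maps, x):
    """Evaluate a functional tensor-train model f(x_1, ..., x_d).

    cores        : list of d arrays, cores[k] of shape (r_k, m_k, r_{k+1})
                   with boundary ranks r_0 = r_d = 1
    feature_maps : list of d callables, feature_maps[k](x_k) -> (m_k,)
                   univariate basis evaluations
    x            : length-d input point
    """
    v = np.ones(1)
    for core, phi, xk in zip(cores, feature_maps, x):
        # Contract the univariate basis evaluation into the core's middle
        # index, then carry the rank vector on to the next core.
        v = v @ np.einsum('imj,m->ij', core, phi(xk))
    return v.item()

# Illustrative setup: d = 3, quadratic monomial features, TT ranks (1,2,2,1).
phi = lambda t: np.array([1.0, t, t * t])
cores = [np.random.randn(1, 3, 2),
         np.random.randn(2, 3, 2),
         np.random.randn(2, 3, 1)]
print(tt_function_eval(cores, [phi] * 3, [0.1, 0.5, -0.3]))
```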