Posterior Augmented Flow Matching
George Stoica, Sayak Paul, Matthew Wallingford, Vivek Ramanujan, Abhay Nori + 4 more
TLDR
PAFM improves flow matching by supervising each intermediate point with an importance-weighted mixture of plausible targets instead of a single one. This reduces training variance, mitigates flow collapse, and improves image generation quality.
Key contributions
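- Introduces Posterior-Augmented Flow Matching (PAFM) to counter flow collapse on high-dimensional data, where the learned dynamics memorize specific source-target pairings and map diverse inputs to overly similar outputs.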
- Replaces single-target supervision with an expectation over an approximate posterior of valid target completions.
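- Proves that PAFM yields an unbiased estimator of the original FM objective while substantially reducing gradient variance, by aggregating information from many plausible continuation trajectories per intermediate.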
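- Improves FID50K by up to 3.4 points over FM across model scales (SiT-B/2, SiT-XL/2), architectures (SiT, MMDiT), and class- and text-conditioned benchmarks (ImageNet, CC12M), with negligible compute overhead.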
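The posterior-augmented target described in the bullets above can be sketched in a few lines. This is a minimal sketch under stated assumptions, not the authors' implementation: it assumes the common linear path x_t = (1 - t)·x0 + t·x1 with Gaussian source x0 ~ N(0, I). Under that path, the likelihood of an intermediate given a hypothesized endpoint x1 is N(x_t; t·x1, (1 - t)²I), and the usual FM target x1 - x0 can be rewritten as (x1 - x_t)/(1 - t). Each candidate endpoint is weighted by its likelihood (plus an optional prior log-weight) using self-normalized importance sampling, and the training target becomes the weighted mixture of the candidates' velocities. All function names are hypothetical. With exact posterior weights the mixture equals E[u | x_t], which gives the same expected gradient as the single-target loss. Self-normalized weights over a finite candidate set only approximate this, so the paper's own weighting scheme, which it proves unbiased, may differ.

```python
import math
from typing import List, Optional, Sequence

Vector = List[float]


def log_likelihood(x_t: Sequence[float], x1: Sequence[float], t: float) -> float:
    """log N(x_t; t * x1, (1 - t)^2 I): likelihood of the intermediate under a
    hypothesized endpoint x1, ASSUMING x_t = (1 - t) * x0 + t * x1, x0 ~ N(0, I)."""
    sigma = 1.0 - t
    d = len(x_t)
    sq = sum((a - t * b) ** 2 for a, b in zip(x_t, x1))
    return -0.5 * sq / sigma ** 2 - d * math.log(sigma) - 0.5 * d * math.log(2 * math.pi)


def posterior_weights(x_t: Sequence[float], t: float,
                      candidates: Sequence[Sequence[float]],
                      log_prior: Optional[Sequence[float]] = None) -> Vector:
    """Self-normalized importance weights over candidate endpoints.

    log_prior[k] is an optional extra log-weight (e.g. log p(x1_k | c) minus the
    log-density of whatever proposal produced x1_k). Leave it as None when the
    candidates are already samples from p(x1 | c), where it cancels out."""
    logs = [log_likelihood(x_t, c, t) for c in candidates]
    if log_prior is not None:
        logs = [l + p for l, p in zip(logs, log_prior)]
    m = max(logs)  # log-sum-exp shift for numerical stability
    w = [math.exp(l - m) for l in logs]
    z = sum(w)
    return [wi / z for wi in w]


def conditional_velocity(x_t: Sequence[float], x1: Sequence[float], t: float) -> Vector:
    """FM target u = x1 - x0, rewritten as (x1 - x_t) / (1 - t) on the linear path."""
    return [(b - a) / (1.0 - t) for a, b in zip(x_t, x1)]


def posterior_augmented_target(x_t: Sequence[float], t: float,
                               candidates: Sequence[Sequence[float]],
                               log_prior: Optional[Sequence[float]] = None) -> Vector:
    """Mixture of conditional velocities, weighted by the approximate posterior."""
    w = posterior_weights(x_t, t, candidates, log_prior)
    vels = [conditional_velocity(x_t, c, t) for c in candidates]
    return [sum(wk * v[i] for wk, v in zip(w, vels)) for i in range(len(x_t))]


# Hypothetical usage: one intermediate point and three candidate endpoints,
# e.g. the sample's own endpoint plus two other same-class images (flattened).
x_t = [0.3, -0.1]
candidates = [[0.5, 0.0], [0.4, -0.3], [-0.8, 0.9]]
print(posterior_augmented_target(x_t, 0.6, candidates))
```

When one candidate dominates the posterior (for example as t approaches 1 and the likelihood sharpens), the mixture collapses to that candidate's velocity, i.e. ordinary single-target FM.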
Why it matters
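In Flow Matching (FM), each training sample supervises only a single trajectory and intermediate point, so on high-dimensional data such as images the training signal is sparse and high-variance. This can cause flow collapse and poor generalization. PAFM addresses this by averaging supervision over many plausible targets, which lowers gradient variance without biasing the objective and leads to more stable, better-generalizing generative models. The reported gains hold across model scales, architectures, and both class- and text-conditioned image benchmarks.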
Original Abstract
Flow matching (FM) trains a time-dependent vector field that transports samples from a simple prior to a complex data distribution. However, for high-dimensional images, each training sample supervises only a single trajectory and intermediate point, yielding an extremely sparse and high-variance training signal. This under-constrained supervision can cause flow collapse, where the learned dynamics memorize specific source-target pairings, mapping diverse inputs to overly similar outputs, failing to generalize. We introduce Posterior-Augmented Flow Matching (PAFM), a theoretically grounded generalization of FM that replaces single-target supervision with an expectation over an approximate posterior of valid target completions for a given intermediate state and condition. PAFM factorizes this intractable posterior into (i) the likelihood of the intermediate under a hypothesized endpoint and (ii) the prior probability of that endpoint under the condition, and uses an importance sampling scheme to construct a mixture over multiple candidate targets. We prove that PAFM yields an unbiased estimator of the original FM objective while substantially reducing gradient variance during training by aggregating information from many plausible continuation trajectories per intermediate. Finally, we show that PAFM improves over FM by up to 3.4 FID50K across different model scales (SiT-B/2 and SiT-XL/2), different architectures (SiT and MMDiT), and in both class and text conditioned benchmarks (ImageNet and CC12M), with a negligible increase in the compute overhead. Code: https://github.com/gstoica27/PAFM.git.
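The variance-reduction claim in the abstract can be illustrated with a toy 1-D example. This is an illustrative sketch built on assumptions, not the paper's experiment: it uses the linear path x_t = (1 - t)·x0 + t·x1 with x0 ~ N(0, 1), and a data distribution with two equally likely points, so weighting both points by their likelihood gives the exact posterior. For a fixed model prediction v, the per-sample gradient of the squared loss with respect to v is 2(v - target), so comparing the spread of the two targets compares gradient noise. Both targets share the same mean (law of total expectation), while the posterior mixture has lower variance (law of total variance). The names `DATA`, `mixture_target`, and `compare_targets` are hypothetical.

```python
import math
import random

DATA = [-1.0, 1.0]  # toy 1-D data distribution: two equally likely points


def mixture_target(x_t, t, candidates):
    """Posterior-weighted velocity target under the ASSUMED linear path
    x_t = (1 - t) * x0 + t * x1, x0 ~ N(0, 1), uniform prior over candidates."""
    sigma = 1.0 - t
    # log N(x_t; t * c, sigma^2), dropping constants shared by all candidates
    logs = [-0.5 * ((x_t - t * c) / sigma) ** 2 for c in candidates]
    m = max(logs)
    w = [math.exp(l - m) for l in logs]
    z = sum(w)
    return sum(wi / z * (c - x_t) / sigma for wi, c in zip(w, candidates))


def mean_var(xs):
    mu = sum(xs) / len(xs)
    return mu, sum((x - mu) ** 2 for x in xs) / len(xs)


def compare_targets(n=20000, seed=0):
    rng = random.Random(seed)
    single, mixed = [], []
    for _ in range(n):
        x0 = rng.gauss(0.0, 1.0)
        x1 = rng.choice(DATA)
        t = rng.uniform(0.0, 0.95)  # stay away from t = 1, where 1 - t -> 0
        x_t = (1.0 - t) * x0 + t * x1
        single.append(x1 - x0)  # standard FM: one target per sample
        mixed.append(mixture_target(x_t, t, DATA))  # expectation over endpoints
    return mean_var(single), mean_var(mixed)


(mu_s, var_s), (mu_m, var_m) = compare_targets()
print(f"single-target FM:  mean={mu_s:.3f} var={var_s:.3f}")
print(f"posterior mixture: mean={mu_m:.3f} var={var_m:.3f}")
```

In this setup the single-target variance is Var(x1) + Var(x0) = 2, while the mixture's variance is smaller because it averages out the uncertainty about which endpoint produced x_t.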