Adversarial Label Invariant Graph Data Augmentations for Out-of-Distribution Generalization
Simon Zhang, Ryan P. DeMilt, Kun Jin, Cathy H. Xia
TLDR
RIA introduces adversarial label-invariant graph data augmentations to improve out-of-distribution generalization under covariate shift.
Key contributions
- Proposes RIA, a new method for out-of-distribution (OoD) generalization under covariate shift.
- Uses adversarial label-invariant data augmentations to explore new training environments.
- Compatible with many existing OoD methods formulated as constrained optimization problems.
- Achieves high accuracy on OoD graph classification tasks with various distribution shifts.
Why it matters
Out-of-distribution generalization is crucial for real-world applications where training and testing data differ. This paper offers a novel adversarial training approach that enhances robustness and performance, making models more reliable in diverse environments.
Original Abstract
Out-of-distribution (OoD) generalization occurs when representation learning encounters a distribution shift. This occurs frequently in practice when training and testing data come from different environments. Covariate shift is a type of distribution shift that occurs only in the input data, while the concept distribution stays invariant. We propose RIA - Regularization for Invariance with Adversarial training, a new method for OoD generalization under covariate shift. Motivated by an analogy to $Q$-learning, it performs an adversarial exploration for training data environments. These new environments are induced by adversarial label-invariant data augmentations that prevent a collapse to an in-distribution trained learner. It works with many existing OoD generalization methods for covariate shift that can be formulated as constrained optimization problems. We develop an alternating gradient descent-ascent algorithm to solve the problem, and perform extensive experiments on OoD graph classification for various kinds of synthetic and natural distribution shifts. We demonstrate that our method can achieve high accuracy compared with OoD baselines.
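The abstract mentions an alternating gradient descent-ascent algorithm but does not give its details here. As a rough illustration of the general technique only, the sketch below runs alternating descent-ascent on a toy saddle-point problem. The objective, the function names (`alternating_gda`, `grad_x`, `grad_y`), the step size, and the iteration count are all assumptions made for this example and are not RIA's actual objective or implementation. Loosely, `x` plays the role of the learner being minimized and `y` the adversary being maximized.

```python
def alternating_gda(grad_x, grad_y, x0, y0, lr=0.1, steps=500):
    """Alternating gradient descent-ascent for min_x max_y f(x, y).

    Hypothetical helper for illustration; not taken from the paper.
    """
    x, y = x0, y0
    for _ in range(steps):
        # Descent step for the minimizing player.
        x = x - lr * grad_x(x, y)
        # Ascent step for the maximizing player, using the updated x.
        y = y + lr * grad_y(x, y)
    return x, y


# Toy objective (assumed): f(x, y) = 0.5*x**2 + x*y - 0.5*y**2,
# convex in x and concave in y, with its saddle point at (0, 0).
def grad_x(x, y):
    return x + y  # partial derivative of f with respect to x


def grad_y(x, y):
    return x - y  # partial derivative of f with respect to y


x_star, y_star = alternating_gda(grad_x, grad_y, x0=1.0, y0=-2.0)
```

On this toy problem the iterates spiral in toward the saddle point. In a setting like the one the abstract describes, the scalar updates would presumably be replaced by optimizer steps on model and augmentation parameters, but that mapping is an assumption rather than something stated in the text.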