Detecting overfitting in Neural Networks during long-horizon grokking using Random Matrix Theory
Hari K. Prakash, Charles H Martin
TLDR
A new Random Matrix Theory method detects overfitting in neural networks, including foundation-scale LLMs, by identifying "Correlation Traps" in weight matrices.
Key contributions
- Introduces a novel Random Matrix Theory method to detect neural network overfitting without train/test data.
- Identifies "Correlation Traps" in weight matrices as spectral outliers indicating the onset of overfitting.
- Proposes an empirical way to distinguish benign from harmful traps: pass random data through the model and measure the Jensen-Shannon (JS) divergence of its output logits.
- Reveals "anti-grokking" as a distinct overfitting phase characterized by the formation of these traps.
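The trap-detection step described in the abstract (randomize each weight matrix element-wise, compare its spectrum with a Marchenko-Pastur distribution, flag large outliers) can be sketched in NumPy. This is a minimal sketch under our own assumptions, not the authors' implementation: element-wise randomization is taken to mean shuffling all entries of $\mathbf{W}$, the Marchenko-Pastur fit is replaced by a single bulk-edge estimate $\lambda_+ = \sigma^2 (1 + 1/\sqrt{Q})^2$ with $Q = N/M$, and `tol` is a hypothetical relative margin standing in for a proper finite-size (Tracy-Widom) correction. The function name `correlation_traps` is also hypothetical.

```python
import numpy as np

def correlation_traps(W, tol=0.1, seed=0):
    """Hypothetical helper: return eigenvalues of the element-wise randomized W
    that exceed the Marchenko-Pastur bulk edge by more than a relative margin tol,
    together with the estimated edge lambda_plus."""
    rng = np.random.default_rng(seed)
    W = np.asarray(W, dtype=float)
    if W.shape[0] < W.shape[1]:
        W = W.T  # orient so that N >= M, i.e. Q = N / M >= 1
    N, M = W.shape

    # Assumed randomization: shuffle all entries, which keeps their distribution
    # but destroys any correlations between rows and columns.
    W_rand = rng.permutation(W.ravel()).reshape(N, M)
    # Remove the mean so a nonzero average weight does not create a rank-one spike.
    W_rand = W_rand - W_rand.mean()

    # Empirical spectral distribution of the M x M correlation matrix.
    evals = np.linalg.eigvalsh(W_rand.T @ W_rand / N)

    # Marchenko-Pastur upper edge for i.i.d. entries with variance sigma^2.
    sigma2 = W_rand.var()
    Q = N / M
    lambda_plus = sigma2 * (1.0 + 1.0 / np.sqrt(Q)) ** 2

    # Outliers well beyond the bulk edge are candidate Correlation Traps.
    traps = evals[evals > (1.0 + tol) * lambda_plus]
    return np.sort(traps)[::-1], lambda_plus
```

For a Gaussian matrix the returned array should usually be empty, while planting a single very large entry (which shuffling moves but does not shrink) typically yields one eigenvalue well above $\lambda_+$, illustrating how a few extreme weights can violate self-averaging.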
Why it matters
Detecting overfitting in neural networks is important but difficult, especially when training and test data are unavailable. This paper offers a method based on Random Matrix Theory that needs no access to training or test data and identifies "Correlation Traps" in weight matrices. It gives a structural account of overfitting that extends to large language models and could help practitioners find layers where overfitting harms generalization.
Original Abstract
Training Neural Networks (NNs) without overfitting is difficult; detecting that overfitting is difficult as well. We present a novel Random Matrix Theory method that detects the onset of overfitting in deep learning models without access to train or test data. For each model layer, we randomize each weight matrix element-wise, $\mathbf{W} \to \mathbf{W}_{\mathrm{rand}}$, fit the randomized empirical spectral distribution with a Marchenko-Pastur distribution, and identify large outliers that violate self-averaging. We call these outliers Correlation Traps. During the onset of overfitting, which we call the "anti-grokking" phase in long-horizon grokking, Correlation Traps form and grow in number and scale as test accuracy decreases while train accuracy remains high. Traps may be benign or may harm generalization; we provide an empirical approach to distinguish between them by passing random data through the trained model and evaluating the JS divergence of output logits. Our findings show that anti-grokking is an additional grokking phase with high train accuracy and decreasing test accuracy, structurally distinct from pre-grokking through its Correlation Traps. More broadly, we find that some foundation-scale LLMs exhibit the same Correlation Traps, indicating potentially harmful overfitting.
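The JS-divergence check in the abstract can also be sketched in NumPy. The abstract does not say which two output distributions are compared, so this sketch assumes a comparison between the trained model and a modified copy of it (for instance, one in which a suspected trap has been removed); that pairing, the Gaussian random inputs, and the names `model_a`, `model_b`, and `mean_js_on_random_inputs` are our assumptions, not the paper's procedure.

```python
import numpy as np

def softmax(logits):
    # Subtract the row maximum for numerical stability before exponentiating.
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def js_divergence(p, q, eps=1e-12):
    """Row-wise Jensen-Shannon divergence (natural log, so at most ln 2)."""
    p = np.clip(p, eps, 1.0)
    q = np.clip(q, eps, 1.0)
    m = 0.5 * (p + q)
    kl_pm = np.sum(p * np.log(p / m), axis=-1)
    kl_qm = np.sum(q * np.log(q / m), axis=-1)
    return 0.5 * (kl_pm + kl_qm)

def mean_js_on_random_inputs(model_a, model_b, input_shape, n_samples=1000, seed=0):
    """Hypothetical helper: feed the same random inputs to two models and
    return the mean JS divergence between their softmax outputs."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((n_samples, *input_shape))  # data-free random inputs
    p = softmax(model_a(x))
    q = softmax(model_b(x))
    return float(js_divergence(p, q).mean())
```

Under this assumption, a divergence near zero would suggest the modification barely changes the model's behavior, while a large value would suggest the altered weights matter for the outputs.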