
arXiv: 2106.02073
The recently discovered Neural Collapse (NC) phenomenon occurs pervasively in today's deep net training paradigm of driving cross-entropy (CE) loss towards zero. During NC, last-layer features collapse to their class-means, both classifiers and class-means collapse to the same Simplex Equiangular Tight Frame, and classifier behavior collapses to the nearest-class-mean decision rule. Recent works demonstrated that deep nets trained with mean squared error (MSE) loss perform comparably to those trained with CE. As a preliminary, we empirically establish that NC emerges in such MSE-trained deep nets as well, through experiments on three canonical networks and five benchmark datasets. We provide, in a Google Colab notebook, PyTorch code for reproducing MSE-NC and CE-NC at https://colab.research.google.com/github/neuralcollapse/neuralcollapse/blob/main/neuralcollapse.ipynb. The analytically tractable MSE loss offers more mathematical opportunities than the hard-to-analyze CE loss, inspiring us to leverage MSE loss towards the theoretical investigation of NC. We develop three main contributions: (I) We show a new decomposition of the MSE loss into (A) terms directly interpretable through the lens of NC, which assume the last-layer classifier is exactly the least-squares classifier, and (B) a term capturing the deviation from this least-squares classifier. (II) We exhibit experiments on canonical datasets and networks demonstrating that term (B) is negligible during training. This motivates us to introduce a new theoretical construct: the central path, where the linear classifier stays MSE-optimal for the feature activations throughout the dynamics. (III) By studying renormalized gradient flow along the central path, we derive exact dynamics that predict NC.
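To make two quantities from the abstract concrete, below is a minimal NumPy sketch (separate from the paper's PyTorch notebook; the synthetic data, shapes, and variable names are illustrative assumptions). It fits the least-squares linear classifier to last-layer features, whose distance from the trained classifier is what term (B) measures, and computes a standard within-class variability (NC1) statistic that tends to zero under Neural Collapse.

```python
# Minimal sketch, not the authors' code: least-squares classifier fit to
# last-layer features (the reference point for the term-(B) deviation) and an
# NC1-style within-class variability metric.
import numpy as np

rng = np.random.default_rng(0)
C, n, d = 4, 50, 16            # classes, samples per class, feature dimension

# Synthetic stand-in for last-layer features H (d x N) and one-hot targets Y (C x N).
labels = np.repeat(np.arange(C), n)
H = rng.normal(size=(d, C * n)) + 3.0 * rng.normal(size=(d, C))[:, labels]
Y = np.eye(C)[:, labels]

# (i) Least-squares linear classifier (with bias) for the current features:
# minimize ||W_aug @ H_aug - Y||_F^2 over W_aug = [W, b], in closed form.
H_aug = np.vstack([H, np.ones((1, H.shape[1]))])       # append bias row
W_aug = Y @ np.linalg.pinv(H_aug)                      # closed-form solution
W_ls, b_ls = W_aug[:, :d], W_aug[:, d]

# For a trained net, term (B) would measure how far the actual (W, b) sits
# from (W_ls, b_ls); here we just report the residual MSE of the optimal fit.
mse_ls = np.mean(np.sum((W_aug @ H_aug - Y) ** 2, axis=0))

# (ii) NC1-style metric: within-class scatter relative to between-class scatter.
mu_G = H.mean(axis=1, keepdims=True)                   # global feature mean
mu_c = np.stack([H[:, labels == c].mean(axis=1) for c in range(C)], axis=1)
Sigma_W = sum((H[:, labels == c] - mu_c[:, [c]]) @ (H[:, labels == c] - mu_c[:, [c]]).T
              for c in range(C)) / (C * n)
Sigma_B = (mu_c - mu_G) @ (mu_c - mu_G).T / C
nc1 = np.trace(Sigma_W @ np.linalg.pinv(Sigma_B)) / C  # -> 0 under collapse

print(f"least-squares residual MSE: {mse_ls:.4f}, NC1 metric: {nc1:.4f}")
```

Tracking these two statistics across training epochs is how one would check, as the paper does empirically, that the trained classifier hugs the least-squares classifier (the central path) while within-class variability collapses.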
ICLR 2022 Outstanding Paper Prize & Oral. The appendix contains [A] empirical experiments, [B-D] proofs of the theoretical results, and [E] a survey of related works examining Neural Collapse.
Subjects: Machine Learning (cs.LG); Artificial Intelligence (cs.AI); Differential Geometry (math.DG); Optimization and Control (math.OC); Machine Learning (stat.ML)
| Citation indicator | Value |
| --- | --- |
| Selected citations (from selected sources) | 1 |
| Popularity (current attention in the citation network) | Average |
| Influence (overall impact in the citation network, diachronically) | Average |
| Impulse (initial momentum directly after publication) | Average |
