Significance
One key ingredient in deep learning is the stochastic gradient descent (SGD) algorithm, which allows neural nets to find generalizable solutions at flat minima of the high-dimensional loss function. However, it is unclear how SGD finds flat minima. Here, by analyzing SGD-based learning dynamics together with the loss function landscape, we discovered a robust inverse relation between weight fluctuation and loss landscape flatness, opposite to the fluctuation–dissipation relation in physics. The reason for this inverse relationship is that both the SGD noise strength and its correlation time depend inversely on the landscape flatness. Essentially, SGD serves as a landscape-dependent annealing algorithm that searches for flat minima. These theoretical insights can lead to more efficient algorithms, e.g., for preventing catastrophic forgetting.
Keywords: statistical physics, machine learning, stochastic gradient descent, loss landscape, generalization
Abstract
Despite the tremendous success of the stochastic gradient descent (SGD) algorithm in deep learning, little is known about how SGD finds generalizable solutions at flat minima of the loss function in high-dimensional weight space. Here, we investigate the connection between SGD learning dynamics and the loss function landscape. A principal component analysis (PCA) shows that SGD dynamics follow a low-dimensional drift–diffusion motion in the weight space. Around a solution found by SGD, the loss function landscape can be characterized by its flatness in each PCA direction. Remarkably, our study reveals a robust inverse relation between the weight variance and the landscape flatness in all PCA directions, which is the opposite of the fluctuation–response relation (aka the Einstein relation) in equilibrium statistical physics. To understand the inverse variance–flatness relation, we develop a phenomenological theory of SGD based on statistical properties of the ensemble of minibatch loss functions. We find that both the anisotropic SGD noise strength (temperature) and its correlation time depend inversely on the landscape flatness in each PCA direction. Our results suggest that SGD serves as a landscape-dependent annealing algorithm: the effective temperature decreases with the landscape flatness, so the system seeks out (prefers) flat minima over sharp ones. Based on these insights, an algorithm with landscape-dependent constraints is developed to mitigate catastrophic forgetting efficiently when learning multiple tasks sequentially. In general, our work provides a theoretical framework to understand learning dynamics, which may eventually lead to better algorithms for different learning tasks.
One key ingredient of the powerful deep neural network (DNN)-based machine-learning paradigm—deep learning (1)—is a relatively simple iterative method called stochastic gradient descent (SGD) (2, 3). However, despite the tremendous successes of deep learning, the reason why SGD is so effective at learning in a high-dimensional nonconvex loss function (energy) landscape remains poorly understood. Randomness seems key to SGD's success, yet it also makes SGD harder to understand. Fortunately, many physical systems include such a random element, e.g., Brownian motion, and powerful tools have been developed for understanding collective behaviors in stochastic systems with many degrees of freedom. Here, we use concepts and methods from statistical physics to investigate the SGD dynamics, the loss function landscape, and, more importantly, their relationship.
We start by introducing the SGD-based learning process as a stochastic dynamical system. A learning system such as a neural network (NN), especially a DNN, has a large number ($N$) of weight parameters $\vec{w} = (w_1, w_2, \ldots, w_N)$. For supervised learning, there is a set of $M$ training samples, each with an input $X_k$ and a correct output $Z_k$ for $k = 1, 2, \ldots, M$. For each input $X_k$, the learning system predicts an output $Y_k = Y(X_k, \vec{w})$, where the output function $Y$ depends on the architecture of the NN as well as its weights $\vec{w}$. The goal of learning is to find the weight parameters that minimize the difference between the predicted and correct outputs, characterized by an overall loss function (or energy function)
$$L(\vec{w}) \;=\; \frac{1}{M}\sum_{k=1}^{M} d\big(Y(X_k, \vec{w}),\, Z_k\big), \qquad [1]$$
where $d(Y, Z)$ is a measure of the distance between the predicted output $Y$ and the correct output $Z$. In our study, the cross-entropy loss is used for $d$.
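As a concrete illustration of Eq. 1 with the cross-entropy distance, here is a minimal Python sketch (our own illustration, not the authors' code); the predictor function predict(x, w), standing in for the network output $Y(X, \vec{w})$, and the one-hot encoding of $Z$ are assumptions of this sketch.

    import numpy as np

    def cross_entropy(y_pred, z_true):
        # d(Y, Z): cross-entropy between predicted class probabilities Y
        # and the one-hot correct output Z; the small epsilon avoids log(0).
        return -np.sum(z_true * np.log(y_pred + 1e-12))

    def overall_loss(predict, w, X, Z):
        # Eq. 1: L(w) = (1/M) * sum_k d(Y(X_k, w), Z_k) over all M samples.
        return np.mean([cross_entropy(predict(x, w), z) for x, z in zip(X, Z)])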
One learning strategy is to update the weights by following the gradient of $L(\vec{w})$ directly. However, this batch learning scheme is computationally prohibitive for large datasets, and it also has the obvious shortfall of being easily trapped by local minima. SGD was first introduced to circumvent the large-dataset problem by updating the weights according to a subset (minibatch) of samples randomly chosen at each iteration (2). Specifically, the change of weight $w_i$ at iteration $t$ in SGD is given by
$$\Delta w_i(t) \;=\; -\alpha\, \frac{\partial L^{\mu(t)}(\vec{w})}{\partial w_i}, \qquad [2]$$
where $\alpha$ is the learning rate and $\mu(t)$ represents the random minibatch used for iteration $t$. The minibatch loss function (MLF) for minibatch $\mu$ of size $B$ is defined as
$$L^{\mu}(\vec{w}) \;=\; \frac{1}{B}\sum_{l=1}^{B} d\big(Y(X_{\mu_l}, \vec{w}),\, Z_{\mu_l}\big), \qquad [3]$$
where $\mu_l$ ($l = 1, 2, \ldots, B$) labels the $B$ randomly chosen samples.
Here, we introduce the key concept of a MLF ensemble $\{L^{\mu}(\vec{w})\}$, i.e., an ensemble of energy landscapes, each from a random minibatch. The overall loss function is just the ensemble average of the MLF: $L(\vec{w}) = \langle L^{\mu}(\vec{w})\rangle_{\mu}$. The SGD noise comes from the variation between a MLF and its ensemble average: $\delta L^{\mu}(\vec{w}) \equiv L^{\mu}(\vec{w}) - L(\vec{w})$.
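To make this decomposition concrete, the following minimal sketch (again our own illustration) implements one SGD iteration (Eqs. 2 and 3) and the SGD noise as the difference between the minibatch gradient and the full-batch (mean) gradient; grad_sample(w, x, z), a per-sample gradient function, is an assumption of this sketch.

    import numpy as np

    rng = np.random.default_rng(0)

    def sgd_step(w, grad_sample, X, Z, alpha=0.1, B=32):
        # Eq. 2: step along the gradient of the minibatch loss function
        # (Eq. 3) for a randomly chosen minibatch mu of size B.
        mu = rng.choice(len(X), size=B, replace=False)
        g_mu = np.mean([grad_sample(w, X[k], Z[k]) for k in mu], axis=0)
        return w - alpha * g_mu, mu

    def sgd_noise(w, grad_sample, X, Z, mu, alpha=0.1):
        # SGD noise eta = -alpha * grad(delta L^mu), where delta L^mu =
        # L^mu - L is the deviation of the MLF from its ensemble average.
        g_mu = np.mean([grad_sample(w, X[k], Z[k]) for k in mu], axis=0)
        g_all = np.mean([grad_sample(w, x, z) for x, z in zip(X, Z)], axis=0)
        return -alpha * (g_mu - g_all)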
By taking the continuous-time approximation and keeping the first-order time derivative term in Eq. 2, we obtain the following stochastic differential equation for SGD,
$$\frac{dw_i}{dt} \;=\; -\alpha\, \frac{\partial L(\vec{w})}{\partial w_i} \;+\; \eta_i(\vec{w}, t), \qquad [4]$$
where time $t$ and all timescales in this study are measured in units of the minibatch iteration time $\Delta t$. The continuous-time limit amounts to considering timescales that are much larger than $\Delta t$; e.g., one epoch corresponds to $M/B$ iterations. Eq. 4 is analogous to the Langevin equation in statistical physics. The first term is the deterministic gradient descent governed by the overall loss function $L(\vec{w})$, analogous to the energy function in physics. The second term, $\eta_i(\vec{w}, t) = -\alpha\, \partial \delta L^{\mu(t)}/\partial w_i$, is the SGD noise term with zero mean, $\langle \eta_i \rangle = 0$, and an equal-time correlation $\langle \eta_i(\vec{w}, t)\,\eta_j(\vec{w}, t)\rangle$ that depends explicitly on $\vec{w}$.
Recently, there has been increasing evidence in support of the notion that “good” (generalizable) solutions exist at the flat (shallow) minima of the loss function (4–10); however, there is still little understanding of how SGD-based algorithms can find these flat minima in the high-dimensional weight space. The original gradient descent algorithm searches for loss function minima independent of their flatness and it also has the significant disadvantage of being easily trapped by local minima and saddle points in the high-dimensional weight space. Adding isotropic noise to gradient descent (GD) leads to a Langevin equation analogous to those used to describe stochastic dynamics in equilibrium physical systems. However, although the added noise can help GD escape local traps, it does not seem to improve generalization (11). The intuitive reason is that a “useful noise” should depend on the loss landscape. In particular, a useful noise should be larger in directions where the landscape is rougher to help escape from local traps and smaller in directions where the landscape is flatter to find and stay at good solutions. As first pointed out by Chaudhari and Soatto (12), unlike equilibrium physical systems where the noise has a constant strength given by the thermal temperature, the SGD dynamics are highly nonequilibrium as the SGD noise is anisotropic and varies in the weight space. Our working hypothesis is that SGD may serve as an efficient annealing strategy for varying the noise (or effective temperature) “intelligently” according to the loss function landscape to find the shallow (flat) minima. In this paper, we focus on studying the relation between stochastic learning dynamics and the loss landscape by adopting the nonequilibrium stochastic dynamics framework and using key concepts from statistical physics such as the fluctuation–dissipation relation to test this hypothesis.
Learning via Low-Dimensional Drift–Diffusion Dynamics in SGD
In general, SGD-based DNN learning dynamics can be roughly divided into an initial fast learning phase where both the training error and the training loss decrease rapidly, followed by an “exploration” phase where the training error becomes almost 0 but the loss still decreases albeit much slower. The weight vectors sampled in the exploration phase can all be considered solutions to the problem given their vanishing training error and small testing error; see SI Appendix for details of the two learning phases in SGD.
The original weights are highly coupled in neural networks, which makes them an inconvenient and unnatural basis for studying the high-dimensional learning dynamics. To circumvent this problem, we first use principal component analysis (PCA) to study weight variations in the exploration phase of SGD (see Methods for details of PCA). We then analyze the learning dynamics in the principal component basis because to the leading order the principal components can be considered as independent collective variables of the system around the minima of its loss function.
Within a large time window $[t_s, t_s + T]$, where $t_s$ is a time in the exploration phase and $T$ (= 10 epochs used here) is the window size,* the weight dynamics can be decomposed into variations along different principal components
$$\vec{w}(t) \;=\; \bar{\vec{w}} \;+\; \sum_{i=1}^{N} \theta_i(t)\, \vec{p}_i, \qquad [5]$$
where $\bar{\vec{w}}$ is the average weight vector in the time window $[t_s, t_s + T]$, and $\vec{p}_i$ is the $i$th principal component basis vector with unit norm, $\|\vec{p}_i\| = 1$. The projection of the weight vector along the $i$th PCA direction is given by $\theta_i(t) = (\vec{w}(t) - \bar{\vec{w}})\cdot \vec{p}_i$, which is a linear combination of the individual weights, and $\vec{\theta}(t) = (\theta_1, \theta_2, \ldots, \theta_N)$ is the weight vector in the PCA coordinates.
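The decomposition of Eq. 5 can be carried out directly with the scikit-learn package used in Methods; a minimal sketch, assuming the weight trajectory has already been flattened into an array W of shape (number of sampled iterations, $N$):

    import numpy as np
    from sklearn.decomposition import PCA

    def pca_of_weights(W):
        # W[t, :]: flattened weight vector at sampled iteration t within
        # the exploration-phase window [t_s, t_s + T].
        pca = PCA()
        theta = pca.fit_transform(W)      # theta[t, i] = (w(t) - w_bar) . p_i
        p = pca.components_               # rows are the PCA directions p_i
        sigma2 = pca.explained_variance_  # variances sigma_i^2 (descending)
        w_bar = pca.mean_                 # window-averaged weight vector
        return theta, p, sigma2, w_bar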
The results reported here are for a simple NN with two hidden layers, each with 50 neurons, for classification tasks using the Modified National Institute of Standards and Technology (MNIST) database (see Methods for details and other NN architectures used). The PCA was done for the weights between the two hidden layers (results for other NN architectures and databases are included in SI Appendix). In Fig. 1A, we show the PCA spectrum, i.e., the variance $\sigma_i^2$ versus its rank $i$ in descending order. We found that the variance in the first PCA direction ($i = 1$) is much larger than the variances in other directions because the motion along $\vec{p}_1$ has a net drift velocity (see discussion below and Fig. 1C). For the other PCA directions, after a small number of leading PCA directions, the variance decays rapidly with its rank, approximately as a power law $\sigma_i^2 \sim i^{-\gamma}$ with a large exponent $\gamma$, before an even faster decay at higher $i$. This means that most of the variations (dynamics) of the weights are concentrated in a relatively small number of PCA directions (dimensions). Quantitatively, as shown in Fig. 1B, even excluding $i = 1$, most of the total variance occurs in the first 35 PCA modes, far fewer than the total number of weights ($N$ = 2,500), which suggests that the SGD dynamics are embedded in a low-dimensional space (13).
Next, we studied the network dynamics along different PCA directions. We found that along the first PCA direction $\vec{p}_1$ there is a net drift velocity with a persistence time much longer than 1 epoch, as clearly shown in Fig. 1C, where the SGD dynamics projected onto the $(\theta_1, \theta_2)$ space are shown. For all other PCA directions, the dynamics are random walks with a short velocity correlation time (shorter than 1 epoch), as clearly demonstrated in Fig. 1D, where the SGD dynamics projected onto a randomly chosen pair of PCA directions are shown.
The persistent drift in the first PCA direction can be understood by moving a solution $\vec{w}_0$ found by SGD along $\vec{p}_1$ by $\delta\theta$ to a new weight vector $\vec{w} = \vec{w}_0 + \delta\theta\,\vec{p}_1$. We find that $\vec{p}_1$ is highly aligned with $\vec{w}_0$. Therefore, moving along $\vec{p}_1$ results roughly in an overall amplification of the weights and of the difference between the outputs for the right class and the wrong classes, which leads to a cross-entropy loss that decays with $\delta\theta$. Even though the training error at $\vec{w}_0$ is already 0 (or close to 0), this dependence of the loss on $\delta\theta$ leads to persistent motion along the $\vec{p}_1$ direction with a low speed proportional to the loss itself, which slowly decreases with time (see SI Appendix, section S2 for details). The slow net drift along $\vec{p}_1$ does not improve the training error (it is already zero), but it may improve the robustness of the solution by increasing the margin around the decision boundary and thus enhance generalization. This result is consistent with a previous study (14) in a simpler setting (using gradient descent to find homogeneous linear predictors on linearly separable datasets). Specifically, it was shown in ref. 14 that the predictor converges to the direction of the max-margin solution, which corresponds to the first PCA direction in our study.
In the rest of this paper, we study the majority of the PCA modes ($i \ge 2$), which are diffusive. Our focus is to understand the relation between fluctuations in these diffusive modes and the loss function landscape.
The Loss Function Landscape and the Inverse Variance–Flatness Relation
In the exploration phase, the loss function is small, and all of the weight vectors along the SGD trajectory can be considered valid solutions. However, the solutions found by a single SGD trajectory represent only a small subset of valid solutions. To gain insight into the full solution space, we study the loss function landscape around a specific solution $\vec{w}_0$ reached by SGD. Specifically, we compute the loss function profile along the $i$th PCA direction $\vec{p}_i$ determined from PCA of the SGD dynamics:
$$L_i(\delta\theta) \;\equiv\; L(\vec{w}_0 + \delta\theta\, \vec{p}_i). \qquad [6]$$
In Fig. 2A, we show the loss function landscape profiles $L_i(\delta\theta)$ for several diffusive PCA directions ($i \ge 2$). To characterize the one-dimensional loss function landscape along a given PCA direction near its minimum, we define a flatness parameter $F_i$ as the width of the region within which $L_i(\delta\theta) \le e\,L_0$, where $e$ is Euler's number (another constant, e.g., 2, can be used without affecting the results) and $L_0 = L_i(0)$ is the minimum loss at $\delta\theta = 0$. As shown in Fig. 2B, we determine $F_i$ by finding the two closest points $\theta_i^l < 0$ and $\theta_i^r > 0$ on each side of the minimum that satisfy $L_i(\theta_i^l) = L_i(\theta_i^r) = e\,L_0$. The flatness parameter is simply defined as their difference:
$$F_i \;\equiv\; \theta_i^r - \theta_i^l. \qquad [7]$$
A larger value of $F_i$ means a flatter landscape in the $i$th PCA direction.
The loss landscape has been studied previously by computing the Hessian matrix of the loss function (15–17). Even though the flatness parameter $F_i$ defined here is related to the Hessian, they are not the same. $F_i$ is a more robust measure of the landscape flatness, as it contains nonlocal information about the landscape in a finite neighborhood of the minimum. In particular, even when the local curvature vanishes or becomes slightly negative, $F_i$ is still well defined (see SI Appendix, section S3 for a detailed comparison). As shown in Fig. 2B, for the MNIST data, $\ln L_i(\delta\theta)$ can be fitted well by a quadratic function in a finite region near its minimum, $\ln L_i(\delta\theta) \approx \ln L_0 + \delta\theta^2/(2\Sigma_i^2)$. Of course, this simple quadratic fit for $\ln L_i$ (or, equivalently, an inverse Gaussian fit for $L_i$) may not apply to all networks. Nevertheless, the "nonlocal" flatness parameter $F_i$ is well defined regardless of the exact fit, and it is used to characterize the loss landscape around its minima in the rest of this paper.
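Given a solution w0 and a PCA direction p_i, the flatness of Eq. 7 can be extracted numerically from the one-dimensional profile of Eq. 6; a minimal sketch (our own illustration), where loss_fn stands in for the overall loss $L(\vec{w})$:

    import numpy as np

    def loss_profile(loss_fn, w0, p_i, grid):
        # Eq. 6: L_i(dtheta) = L(w0 + dtheta * p_i) scanned over grid.
        return np.array([loss_fn(w0 + dt * p_i) for dt in grid])

    def flatness(L_i, grid):
        # Eq. 7: width of the region where L_i <= e * L_0 around the minimum.
        i0 = L_i.argmin()
        above = L_i > np.e * L_i[i0]
        left = np.flatnonzero(above[:i0])    # crossings left of the minimum
        right = np.flatnonzero(above[i0:])   # crossings right of the minimum
        theta_l = grid[left[-1]] if left.size else grid[0]
        theta_r = grid[i0 + right[0]] if right.size else grid[-1]
        return theta_r - theta_l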
We computed the flatness $F_i$ in each PCA direction and found that the flatness increases with the rank $i$, as shown in Fig. 2C. Given that the SGD variance $\sigma_i^2$ decreases with $i$, as shown in Fig. 1A, this immediately suggests an inverse relationship between the loss function landscape flatness and the SGD variance. Indeed, as shown in Fig. 2D for the MNIST data, the inverse variance–flatness relation follows approximately a power law
$$\sigma_i^2 \;\propto\; F_i^{-\psi}, \qquad [8]$$
where the exponent $\psi > 0$ is roughly the same for different choices of the minibatch size $B$ and the learning rate $\alpha$. Although the power-law dependence may be specific to MNIST, the inverse dependence of $\sigma_i^2$ on $F_i$ holds generally true in all other NN architectures and datasets we studied (see SI Appendix, section S6 for details).
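The exponent in Eq. 8 can be estimated by a simple linear fit in log–log space; a short sketch (our own illustration), assuming sigma2 and F are arrays for the diffusive modes ($i \ge 2$):

    import numpy as np

    def fit_psi(sigma2, F):
        # Fit sigma_i^2 ~ F_i^(-psi): slope of log(sigma^2) vs. log(F).
        slope, _ = np.polyfit(np.log(F), np.log(sigma2), 1)
        return -slope  # psi > 0 signals the inverse variance-flatness relation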
The inverse variance–flatness relation is the key finding of our study. Previous work has studied either variations of the weights (13) or the landscape of the loss function (7, 17), but not the strong relation between the two that is discovered here. The inverse variance–flatness relation is highly unusual; it goes against physics intuition. In particular, according to equilibrium statistical physics, the fluctuation of a variable around its equilibrium value is proportional to the change of the variable in response to an external perturbation, which is known as the fluctuation–response (or fluctuation–dissipation) relation, aka the Einstein relation (18). A generalized fluctuation–response relation holds true even for nonequilibrium systems linearized near a fixed point (19). However, for SGD-based learning dynamics, the fluctuation–response relation would imply that the variance of a variable (PCA weight) should be larger for a flatter landscape, which is the opposite of the observed inverse relation shown in Fig. 2D. Therefore, the inverse variance–flatness relation in SGD can also be called the "inverse Einstein relation."
What is the reason for the inverse Einstein relation in SGD? Unlike generic stochastic systems where the noise strength (e.g., temperature) is a constant, the SGD noise comes from the difference between the gradient of a random MLF and that of the overall (mean) loss function. Therefore, the noise is anisotropic and it varies in the weight space and in time. In the next section, we explain the inverse variance–flatness relation based on the dependence of the SGD noise on statistical properties of the MLF ensemble.
The Random Landscape Theory and Origin of the Inverse Variance–Flatness Relation
The most distinctive feature of SGD is that at any given iteration (time), the learning dynamics are driven by a random minibatch out of an ensemble of minibatches, each with its own random MLF. To understand the SGD dynamics, we develop a random landscape theory to describe the statistical properties of the MLF ensemble near a solution $\vec{w}_0$ (we set the origin of the PCA coordinates at $\vec{w}_0$ for convenience).
As shown in Fig. 3A, $\ln L^{\mu}$ can be approximated by a quadratic function near its minimum, or, equivalently, $L^{\mu}$ can be approximated by an inverse Gaussian function
$$L^{\mu}(\vec{\theta}) \;\approx\; L^{\mu}_0 \exp\Big[\tfrac{1}{2}\big(\vec{\theta} - \vec{\theta}^{\mu}\big)^{T} K^{\mu} \big(\vec{\theta} - \vec{\theta}^{\mu}\big)\Big], \qquad [9]$$
where $\theta_i$ is the weight parameter projected onto the $i$th PCA direction. The MLF for minibatch $\mu$ is characterized by its minimum $L^{\mu}_0$, its minimum location $\vec{\theta}^{\mu}$, and the symmetric Hessian matrix $K^{\mu}$ of $\ln L^{\mu}$ at the minimum.
Within the quadratic approximation (Eq. 9), statistical properties of the MLF ensemble are determined by the joint distribution of the parameters $L^{\mu}_0$, $\vec{\theta}^{\mu}$, and the elements of $K^{\mu}$. Based on our simulation results, and as a first-order approximation, we treat these parameters as independent random variables with normal distributions. By using this mean-field approximation, we can obtain the overall loss function
$$L(\vec{\theta}) \;=\; \big\langle L^{\mu}(\vec{\theta})\big\rangle_{\mu} \;\approx\; L_0 \exp\Big[\sum_i \frac{\theta_i^2}{2\Sigma_i^2}\Big], \qquad [10]$$
where $L_0 = L(\vec{0})$ is the minimum loss, and the width parameters $\Sigma_i$ are constants that depend on statistical properties of the MLF ensemble. The overall loss function, Eq. 10, has an inverse Gaussian form that is consistent with the empirical results shown in Fig. 2 A and B, and the flatness parameter can be expressed as $F_i = 2\sqrt{2}\,\Sigma_i$. Details of the derivation of Eq. 10 (including expressions for the constants $\Sigma_i$) and empirical justification of the approximations can be found in SI Appendix, section S4.
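For completeness, the flatness expression follows in one line from applying the definition in Eq. 7 to the (reconstructed) inverse Gaussian profile of Eq. 10:

    L_i(\delta\theta) \;=\; L_0\, e^{\delta\theta^2/(2\Sigma_i^2)} \;=\; e\, L_0
    \;\Longrightarrow\; \delta\theta \;=\; \pm\sqrt{2}\,\Sigma_i
    \;\Longrightarrow\; F_i \;=\; \theta_i^r - \theta_i^l \;=\; 2\sqrt{2}\,\Sigma_i .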
From Eq. 9, we can now study the SGD dynamics analytically. By keeping terms only up to linear order in $\vec{\theta} - \vec{\theta}^{\mu}$, the SGD Langevin equation for $\vec{\theta}$ becomes
$$\frac{d\vec{\theta}}{dt} \;=\; -\alpha\, L^{\mu(t)}_0\, K^{\mu(t)} \big(\vec{\theta} - \vec{\theta}^{\mu(t)}\big) \;\equiv\; \vec{v}^{\mu(t)}, \qquad [11]$$
where $\vec{v}^{\mu(t)}$ is the velocity of $\vec{\theta}$ at time $t$. The equation above has an intuitive interpretation: At time $t$, the weight vector is pulled by a random minibatch $\mu(t)$, whose MLF acts as a spring with a spring tensor $\alpha L^{\mu(t)}_0 K^{\mu(t)}$ and its force center positioned at $\vec{\theta}^{\mu(t)}$.
For the diffusive PCA directions ($i \ge 2$), the dynamics of $\theta_i$ are driven by a random velocity $v_i$ with zero mean, $\langle v_i \rangle_{\mu} = 0$. The autocorrelation function of $v_i$ can be written as $\langle v_i(t)\, v_i(t')\rangle = D_i\, C_i(t - t')$, where $C_i$ is the normalized correlation function ($C_i(0) = 1$) and $D_i$ is the velocity variance,
$$D_i \;\equiv\; \big\langle (v_i^{\mu})^2 \big\rangle_{\mu} \;\approx\; c\,\alpha^2\, \bar{K}_i^2\, \Delta_i^2, \qquad [12]$$
where $c$ is a constant, $\bar{K}_i$ is the mean effective spring constant along the $i$th PCA direction, and $\Delta_i^2$ is the variance of the minimum location of the MLF projected onto the $i$th PCA direction.
From the velocity–velocity correlation and the variance of the weight variable within the PCA time window, we can define a time scale $\tau_i$ that characterizes the velocity (or gradient) autocorrelation,
$$\tau_i \;\equiv\; \frac{\sigma_i^2(T)}{D_i\, \Delta t}, \qquad [13]$$
where $T$ is the PCA window size over which the weight variance $\sigma_i^2(T)$ is measured. Note that the minibatch time step $\Delta t$ is included explicitly in the definition above to make $\tau_i$ a time scale.
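Both quantities can be estimated directly from the trajectory of a single PCA coordinate; a minimal sketch, assuming the reconstructed form of Eq. 13 above and a trajectory theta_i sampled once per minibatch iteration:

    import numpy as np

    def noise_strength_and_corr_time(theta_i, dt=1.0):
        # theta_i: trajectory of one diffusive PCA coordinate over the
        # PCA window (one sample per minibatch iteration of duration dt).
        v = np.diff(theta_i) / dt       # per-iteration velocity
        D = v.var()                     # noise strength D_i (velocity variance)
        tau = theta_i.var() / (D * dt)  # correlation time tau_i (Eq. 13)
        return D, tau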
In generic stochastic dynamical systems (19), the noise is independent of the loss landscape. In SGD, however, the noise strength characterized by $D_i$ depends on the landscape flatness. According to Eq. 12, a flatter landscape has a smaller value of $\bar{K}_i$, which leads to a smaller $D_i$. Both $\bar{K}_i$ and $\Delta_i^2$ can be determined from the MLF ensemble statistics. As shown in SI Appendix, Fig. S7, both depend on the flatness $F_i$ as power laws, and therefore $D_i \propto F_i^{-\phi_D}$ with a positive exponent $\phi_D$, as shown in Fig. 3B.
The time scale $\tau_i$ can be determined from Eq. 13; it reflects the normalized velocity correlation function $C_i$, since the weight variance $\sigma_i^2(T)$ is set by the integrated velocity correlations within the window. In the absence of velocity correlations, i.e., $C_i(t) = 0$ for $t \ne 0$, $\tau_i$ is a constant (isotropic) time scale determined by the PCA time window. However, we find that there are significant velocity correlations in the SGD dynamics, which lead to a much smaller $\tau_i$. Remarkably, the correlation time $\tau_i$ also depends inversely on the flatness (see SI Appendix, section S5 for details of calculating $\tau_i$). As shown in Fig. 3B, this inverse dependence follows approximately a power law, $\tau_i \propto F_i^{-\phi_\tau}$, with a positive exponent $\phi_\tau$. The exact reason for the inverse dependence of $\tau_i$ on $F_i$ is not yet clear. However, since $\tau_i$ (in units of $\Delta t$) can be interpreted as the number of minibatches needed to estimate the gradient of the overall loss function, our results indicate that fewer minibatches are needed to infer gradients in flatter directions. This may explain why the Markov chain Monte Carlo (MCMC) method with a relatively small number of updates can accurately estimate the local gradients in algorithms such as stochastic gradient Langevin dynamics in Bayesian learning (20) and entropy-SGD in deep learning (7).
Putting it all together, the inverse power-law dependence of $D_i$ and $\tau_i$ on the landscape flatness leads to the inverse power law $\sigma_i^2 = D_i\,\tau_i\,\Delta t \propto F_i^{-\psi}$ with an exponent $\psi = \phi_D + \phi_\tau$, which is in quantitative agreement with the direct simulation result shown in Fig. 2D. Although the power-law dependence may not be universal, the inverse dependence of $\sigma_i^2$, $D_i$, and $\tau_i$ on $F_i$ holds true in general for all NN architectures and datasets we studied (see SI Appendix, section S6 for details).
Preventing Catastrophic Forgetting by Using Landscape-Dependent Constraints
To demonstrate the utility of the theoretical insights gained so far, we tackle a long-standing challenge in machine learning: how to prevent catastrophic forgetting (CF) (21, 22). After a DNN learns to perform a particular task, it is trained on another task. Although the DNN can readjust its weights to perform well on the new task, it may forget what it learned before and thus fail catastrophically on the previous task. To prevent forgetting, a recent study by Kirkpatrick et al. (23) proposed the elastic weight consolidation (EWC) algorithm, which trains a new task while enforcing constraints on individual weights based on their effects on the performance of the previous task.
Here, by following the same general strategy but using the new insights into the geometry of the loss landscape near a solution, we propose a landscape-dependent constraints (LDC) algorithm to train for the new task with constraints applied to the collective PCA coordinates. More specifically, once we find a solution $\vec{w}_1$ in the weight space for the first task (task 1) by SGD, we can characterize the loss function landscape around $\vec{w}_1$ in the PCA coordinate system by the flatness parameter $F_i$ along the $i$th PCA direction for task 1. Based on the inverse variance–flatness relation (Eq. 8), the flatness parameter can be obtained directly (cheaply) from the weight variance $\sigma_i^2$ instead of computing and diagonalizing the Hessian matrix, which is computationally expensive. When learning the new task, we use a modified loss function for the second task (task 2) by introducing an additional cost term that penalizes the network for going out of the attraction basin of the task-1 solution ($\vec{w}_1$) along a small number ($n$) of PCA directions with the lowest values of flatness:
$$\tilde{L}_2(\vec{w}) \;=\; L_2(\vec{w}) \;+\; \lambda \sum_{i=1}^{n} \frac{\big[(\vec{w} - \vec{w}_1)\cdot \vec{p}_i\big]^2}{F_i^2}, \qquad [14]$$
where $L_2$ is the original loss function for task 2, $\lambda$ is the overall strength of the constraints, and $n$ is the number of constrained PCA modes from task 1.
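In PyTorch (the framework used in Methods), the LDC penalty of Eq. 14 amounts to a few extra lines on top of the task-2 loss. The sketch below is our own illustration: P is an (n, N) tensor whose rows are the task-1 PCA directions with the lowest flatness, F holds the corresponding flatness values (estimated from the variances via Eq. 8), and w and w1 are the flattened current and task-1 weights.

    import torch

    def ldc_loss(task2_loss, w, w1, P, F, lam=1.0):
        # Eq. 14: penalize displacement from the task-1 solution w1 along
        # the n lowest-flatness PCA directions, each weighted by 1/F_i^2.
        dtheta = P @ (w - w1)   # projections (w - w1) . p_i
        return task2_loss + lam * ((dtheta / F) ** 2).sum()

    # Usage sketch: w = torch.cat([p.view(-1) for p in model.parameters()])
    # keeps the computation graph, so the penalty backpropagates to the model.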
Based on the large attraction basin for a given task, as evidenced by the large flatness parameters shown in Fig. 2, we expect that solutions for task 2 exist within the basin of solutions for task 1, so the performance for task 1 should not degrade significantly after learning task 2. We first tested this idea in the simplest case by using the MNIST database with task 1 and task 2 corresponding to classifying two disjoint subsets of digits, e.g., (0,1) for task 1 and (2,3) for task 2 (see Methods for details). To speed up our analysis, we first train the network on all tasks and fix the output and input layers; only the hidden layer(s) is (re)initialized and trained for sequential learning with different algorithms. As shown in Fig. 4A, in the absence of the constraints ($\lambda = 0$), starting with a task-1 solution $\vec{w}_1$, the weights evolve quickly to a solution for task 2 with a small task-2 test error (red line). However, the performance for task 1 deteriorates quickly, with a fast-increasing task-1 test error (blue line). The reason can be understood from Fig. 4B, where the projections $\delta w_i = |(\vec{w} - \vec{w}_1)\cdot \vec{p}_i|$ of the weight displacement vector onto different PCA directions of task 1 are shown. Without constraints, the displacement becomes unbounded for many high-ranking PCA modes (smaller $i$), which leads to the large task-1 test error after learning task 2, i.e., catastrophic forgetting.
The performance improves significantly when task 2 is learned with the modified loss function given in Eq. 14. As shown in Fig. 4C, with constraints ($\lambda > 0$) on the top $n$ modes, although the learning process for task 2 is slightly slower, the system is able to learn a solution for task 2 with an error comparable to that of the unconstrained case ($\lambda = 0$). The significant advantage here is that the performance for task 1 (blue line) remains roughly the same as before; i.e., the system has avoided catastrophic forgetting. As shown in Fig. 4D, $\delta w_i$ now has an upper bound (dashed red line) for all of the top modes ($i \le n$) due to the constraints. The upper bound is found to be proportional to the flatness ($\propto F_i$), which means that the task-2 solution remains within the basin of task-1 solutions.
There is a tradeoff between the two test errors ($\epsilon_1$ for task 1, $\epsilon_2$ for task 2) when varying $\lambda$. As shown in Fig. 4E, the performance of LDC is better than that of EWC. This is not surprising, as LDC uses the full landscape information, whereas EWC uses only the diagonal elements of the Fisher information matrix (effectively the Hessian matrix). More interestingly, as shown in Fig. 4F, the overall performance ($\epsilon_1 + \epsilon_2$) of LDC reaches its optimal level when a relatively small number ($n$) of the top PCA modes are constrained. For EWC, however, all individual weights contribute to the performance, and thus its optimal performance is reached only when all individual weights are constrained. The results from the LDC algorithm suggest that memory of the previous task is encoded in the top PCA modes, which can be used to estimate the capacity of the network for sequential learning of multiple tasks.
In LDC, the landscape flatness used in the constraints can be estimated efficiently from the weight variance by using the variance–flatness relation $F_i \propto \sigma_i^{-2/\psi}$ (Eq. 8). To test whether LDC is sensitive to the accuracy of this estimate, we used different values of $\psi$ to estimate $F_i$. We find that the results do not depend on the exact choice of $\psi$ (see SI Appendix, section S7 for details), which suggests that LDC is robust as long as the constraints are added to the top PCA modes of the previous tasks.
We verified the advantage of the LDC algorithm by considering more complex sequential learning tasks such as more digits (5 instead of 2) in each task from the MNIST dataset and sequential learning of all of the animals and all man-made objects in the Canadian Institute For Advanced Research (CIFAR)10 dataset. The results are consistent with the simple case shown in Fig. 4 and confirm the advantage of using LDC for preventing catastrophic forgetting (see SI Appendix, section S8 for details).
SGD as a Self-Tuned Landscape-Dependent Annealing Strategy for Learning
In the final section of this paper, we return to our initial working hypothesis on the learning strategy deployed by SGD. In an equilibrium system with state variables $\vec{s}$ and a free energy function $E(\vec{s})$, the statistics of $\vec{s}$ follow the Boltzmann distribution $P_{eq}(\vec{s}) \propto \exp[-E(\vec{s})/T]$, where $T$ is the constant temperature that characterizes the strength of the thermal fluctuations (we set the Boltzmann constant $k_B = 1$ here). By expanding the loss function to second order around a minimum, it is easy to show that the variance of $\theta_i$ would be proportional to the squared flatness and the temperature, $\sigma_i^2 \propto T\,F_i^2$, which is a direct consequence of the fluctuation–response (aka fluctuation–dissipation) relation in equilibrium statistical physics.
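For direct comparison with Eq. 8, the equilibrium prediction follows from a Gaussian integral over the quadratic profile (using $F_i = 2\sqrt{2}\,\Sigma_i$ from above):

    P_{eq}(\delta\theta_i) \;\propto\; e^{-L(\delta\theta_i)/T}, \qquad
    L \;\approx\; L_0 + \frac{\delta\theta_i^2}{2\Sigma_i^2}
    \;\Longrightarrow\;
    \sigma_i^2 \;=\; \langle \delta\theta_i^2 \rangle \;=\; T\,\Sigma_i^2 \;=\; \tfrac{1}{8}\, T\, F_i^2 ,

i.e., in equilibrium, a flatter direction would fluctuate more, not less.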
Remarkably, for the SGD-based learning dynamics, we found an inverse relation between fluctuations of the variables and the flatness of the loss function landscape (Eq. 8), which is the opposite of the fluctuation–response relation in equilibrium systems. We have tested it with different variants of the SGD algorithm, such as adaptive moment estimation (Adam) and momentum-based algorithms, different databases (MNIST and CIFAR10), and different DNN architectures (see SI Appendix, Fig. S3 for details). In all cases we studied, the inverse variance–flatness relation holds, suggesting that it is a universal property of SGD-based learning algorithms.
Unlike thermal noise in equilibrium systems, which represents a passive random driving force with a constant strength (temperature), the SGD "noise" represents an active learning/searching process that varies in "space" ($\vec{w}$). The intensity of this learning (searching) activity along the $i$th PCA direction can be characterized by an active local temperature $T_i$:
$$T_i(\delta\theta) \;\equiv\; \tfrac{1}{2}\,\big\langle \big(v_i^{\mu}\big)^2 \big\rangle_{\mu}\Big|_{\vec{w} = \vec{w}_0 + \delta\theta\,\vec{p}_i}, \qquad [15]$$
where $\delta\theta$ is the displacement from $\vec{w}_0$ along the $\vec{p}_i$ direction.
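The active temperature profile can be estimated by sampling minibatch velocities at displaced weights; a minimal sketch (our own construction, consistent with the reconstructed Eq. 15), where grad_mlfs is an assumed list of minibatch-gradient functions, one per minibatch mu:

    import numpy as np

    def active_temperature(grad_mlfs, w0, p_i, grid, alpha=0.1):
        # Eq. 15: T_i(dtheta) = (1/2) <v_i^2> over the MLF ensemble,
        # evaluated at w = w0 + dtheta * p_i, with v_i = -alpha*(grad L^mu).p_i.
        T_profile = []
        for dt in grid:
            w = w0 + dt * p_i
            v = np.array([-alpha * (g(w) @ p_i) for g in grad_mlfs])
            T_profile.append(0.5 * np.mean(v ** 2))
        return np.array(T_profile)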
As shown in Fig. 5A, the active temperature $T_i(\delta\theta)$ has a spatial profile similar to that of the loss function $L_i(\delta\theta)$, with the active temperature higher away from the minimum. In regions of weight space where the overall loss function is high, the active temperature is also high, which drives the system away from those high-loss regions. The learning intensity is anisotropic; it differs in different PCA directions, as shown in Fig. 5B. For a flatter direction with a larger value of $F_i$, $T_i$ is lower (Fig. 5 B, Inset), as the basin of solutions is wide and thus no strong active learning is needed. However, for a steeper direction with a smaller value of $F_i$, the solutions exist only in smaller regions, and thus more intensive learning (a higher active temperature) is required. Therefore, the MLF ensemble can sense both the local loss and the nonlocal flatness of the landscape in different directions and use this information to drive active learning.
The active temperature also varies with time. As learning progresses, the active temperature profile decreases with time, as shown in Fig. 5C. In Fig. 5D, the dynamics of the active temperature and the overall loss function along an SGD trajectory are shown together. The active temperature and the overall loss function are highly correlated, as shown directly in Fig. 5 D, Inset, which means that the SGD system cools down as it learns. This is reminiscent of the well-known simulated annealing algorithm for optimization (24), where temperature is decreased from a high value to zero with some prescribed cooling schedule. However, the SGD algorithm seems to deploy a more intelligent landscape-dependent annealing strategy, in which the active temperature (learning intensity), driven by the MLF ensemble, is self-tuned according to the local and nonlocal properties of the loss landscape that are sensed by the MLF ensemble. This landscape-dependent annealing strategy drives the system toward the flat minima of the loss function landscape and keeps it there by lowering the active temperature once a flat minimum is reached.
To verify the effect of landscape-dependent noise on generalization, we studied a simple algorithm in which landscape-dependent or "flatness-detecting" noise is introduced into the deterministic GD dynamics. In particular, we added an anisotropic noise term whose strength depends explicitly on the flatness of the landscape. We find that only with this flatness-detecting noise can the system enter the exploration phase with flat minima and low generalization error (see SI Appendix, section S9 for details). These results (see also ref. 25 for similar results) support the conclusion that the anisotropic landscape-dependent noise in SGD is responsible for finding generalizable solutions.
Discussion
Modern DNNs often contain more parameters than training samples, which allows them to interpolate (memorize) all of the training samples, even if their labels are replaced by pure noise (26). Remarkably, despite their huge capacity, DNNs can achieve small generalization errors on real data. This phenomenon has been formalized in the so-called "double-descent" curve (27). As the model capacity (complexity) increases, the test error first follows the usual U-shaped curve, decreasing and then peaking near the interpolation threshold where the model achieves vanishing training error. However, it descends again as the model capacity exceeds this interpolation threshold, with the test error reaching its (global) minimum in the overparameterized regime where the number of parameters is much larger than the number of samples. Rapid progress has been made in understanding this double-descent behavior by using simple models. For example, both optimization and generalization guarantees have been proved for overparameterized simple two-layer networks with leaky rectified linear unit (ReLU) activation on linearly separable data (28). This result has subsequently been extended to two-layer networks with ReLU activation (29) and to two- and three-layer networks with smooth activation functions (30). In a different approach using the neural tangent kernel (31), which connects large (wide) neural nets to kernel methods, it was shown that the generalization error decreases toward a plateau value in a power-law fashion with the number of parameters in the overparameterized regime (32). In simple synthetic learning models, such as the random features model with a ridge regression loss function, the double-descent behavior has been shown analytically (33), and this analytical result has been extended to other synthetic learning models (e.g., the random manifold model) and to more general loss functions by using the replica method (34).
However, despite this recent progress, how to characterize the complexity of the solutions found in DNNs, and how SGD seeks out simple, more generalizable solutions for realistic learning tasks, remain not well understood. The results in this paper shed light on both questions. We found that in the overparameterized regime, starting with different random initializations, SGD reaches different solutions with the same statistical properties. Around each solution, the loss landscape is flat in most PCA directions, with only a small number of relevant directions in which the loss landscape is sharp. The complexity of the solution can thus be characterized by an effective dimension $d_s$ of the solution, defined as the number of sharp directions (small flatness) in the loss landscape around this solution. In Fig. 6A, the rank-ordered flatness spectra of the loss landscape are shown for solutions found by networks of different sizes. It is clear that as the network size (width) increases, the number of sharp directions (small flatness) does not change, and the landscape along the additional degrees of freedom is flat, with large values of flatness. As shown in Fig. 6B, the effective dimension $d_s$, defined with a flatness threshold set by the norm of the solution, is much smaller than the number of parameters $N$. Most importantly, $d_s$ remains roughly constant as $N$ increases. This means that the complexity of the solution found by SGD does not increase with the number of parameters, and the solution remains "simple," with good generalization performance, in the overparameterized regime.
Our study also provides evidence on how SGD finds these low-dimensional simple solutions. In particular, we found that the SGD learning activity (temperature) is high only along those directions where the loss landscape is sharp, while learning activity along the other, flat directions becomes quickly frozen during the SGD learning process. An effective learning (searching) dimension $d_l$ can be defined as the number of PCA directions that contain most of the total weight variance (above a fixed threshold fraction). Due to the inverse variance–flatness relation, $d_l$ is found to have a similar dependence on the network size to that of $d_s$. Our results show that SGD searches for solutions only in a small subspace after the initial transient, and the dimension of this search space has only a weak dependence on the network size in the overparameterized regime (see SI Appendix, Fig. S17 for details).
In summary, a careful study of the SGD dynamics and the loss function landscape in this paper reveals a robust inverse relation between fluctuations in SGD and flatness of the loss landscape, which is critical for deciphering the learning strategy in deep learning and for designing more efficient algorithms. Our study demonstrates that ideas and techniques based on statistical physics provide an additional theoretical framework (alternative to the traditional theorem-proving approach) for studying machine learning. It would be interesting to use this framework to address other fundamental questions in machine learning, such as generalization (35, 36), the relation between task complexity and network architecture, information flow in DNNs (37, 38), transfer learning (39), and continuous learning (40–42).
Methods
Neural Network Architecture and Simulations.
Two types of DNNs were studied. 1) Two fully connected neural networks were used for classifying digits in the MNIST database, one with two hidden layers (each with 50 neurons; main text) and the other with four hidden layers (SI Appendix, Fig. S10A). The activation of the hidden-layer neurons is ReLU, the activation of the output neurons is Softmax, and no bias neuron is used for convenience. 2) Two convolutional neural networks (CNNs) were also studied. One was trained on the MNIST dataset; it has two convolution layers and one fully connected layer, with a convolution stride of 1 and zero padding to keep the data dimension unchanged, and each convolution layer is followed by a max pooling layer. The other CNN was trained on the CIFAR10 dataset (SI Appendix, Fig. S10C); it has two convolution layers and three fully connected layers, with a convolution stride of 1, max pooling after each convolution layer, and no zero padding. All numerical experiments were done with the PyTorch (torch) neural network framework.
Principal Component Analysis in Exploration Phase.
For a given time in the exploration phase, we extract the weight matrix between the two hidden layers and reshape (flatten) it into a weight vector; e.g., a 50 × 50 weight matrix is flattened into a 2,500-dimensional vector. We then stack these row vectors from different times to form a matrix (one row per time point) and perform PCA on this matrix. The time step is one minibatch iteration, and the total window size is $T$ = 10 epochs. The PCA is performed using the scikit-learn (sklearn) package in Python 3.7.
Multitask Learning.
We divided the MNIST dataset into five groups, each containing two digits; we call these groups task 1, task 2, etc. We use the fully connected neural network with two hidden layers (each with 50 neurons). The size of the output layer is 10, so it works for all tasks. In the main text, we choose the group containing (0,1) as task 1 and the group containing (2,3) as task 2; see SI Appendix, section S8 for two more complex cases. The LDC algorithm consists of the following steps: 1) Train the network on task 1; 2) when the learning dynamics for task 1 reach the exploration phase, perform PCA to obtain the PCA directions $\vec{p}_i$ and variances $\sigma_i^2$, from which the flatness $F_i$ is estimated via Eq. 8; and 3) train on task 2 using the modified loss function, Eq. 14.
Supplementary Material
Acknowledgments
We thank Mattia Rigotti, Irina Rish, Matthew Riemer, Robert Ajemain, Yunfei Teng, and Alberto Sassi for discussions. We also thank Jerry Tersoff, Tom Theis, and Youssef Mrouef for comments on the manuscript. The work by Y.F. was done when he was an IBM intern.
Footnotes
The authors declare no competing interest.
This article is a PNAS Direct Submission.
*Each epoch has $M/B$ iterations, which covers all training samples once.
This article contains supporting information online at https://www.pnas.org/lookup/suppl/doi:10.1073/pnas.2015617118/-/DCSupplemental.
Data Availability
There are no data underlying this work.
References
- 1. LeCun Y., Bengio Y., Hinton G., Deep learning. Nature 521, 436–444 (2015).
- 2. Robbins H., Monro S., A stochastic approximation method. Ann. Math. Stat. 22, 400–407 (1951).
- 3. Bottou L., "Large-scale machine learning with stochastic gradient descent" in Proceedings of COMPSTAT'2010, Lechevallier Y., Saporta G., Eds. (Physica-Verlag HD, Heidelberg, Germany, 2010), pp. 177–186.
- 4. Hinton G. E., van Camp D., "Keeping the neural networks simple by minimizing the description length of the weights" in Proceedings of the Sixth Annual Conference on Computational Learning Theory, COLT '93, L. Pitt, Ed. (ACM, New York, NY, 1993), pp. 5–13.
- 5. Hochreiter S., Schmidhuber J., Flat minima. Neural Comput. 9, 1–42 (1997).
- 6. Baldassi C., et al., Unreasonable effectiveness of learning neural networks: From accessible states and robust ensembles to basic algorithmic schemes. Proc. Natl. Acad. Sci. U.S.A. 113, E7655–E7662 (2016).
- 7. Chaudhari P., et al., Entropy-SGD: Biasing gradient descent into wide valleys. arXiv:1611.01838 (6 November 2016).
- 8. Zhang Y., Saxe A. M., Advani M. S., Lee A. A., Energy–entropy competition and the effectiveness of stochastic gradient descent in machine learning. Mol. Phys. 116, 3214–3223 (2018).
- 9. Mei S., Montanari A., Nguyen P.-M., A mean field view of the landscape of two-layer neural networks. Proc. Natl. Acad. Sci. U.S.A. 115, E7665–E7671 (2018).
- 10. Baldassi C., Pittorino F., Zecchina R., Shaping the learning landscape in neural networks around wide flat minima. Proc. Natl. Acad. Sci. U.S.A. 117, 161–170 (2020).
- 11. An G., The effects of adding noise during backpropagation training on a generalization performance. Neural Comput. 8, 643–674 (1996).
- 12. Chaudhari P., Soatto S., "Stochastic gradient descent performs variational inference, converges to limit cycles for deep networks" in 2018 Information Theory and Applications Workshop (ITA) (IEEE, San Diego, CA, 2018), pp. 1–9.
- 13. Gur-Ari G., Roberts D. A., Dyer E., Gradient descent happens in a tiny subspace. arXiv:1812.04754 (12 December 2018).
- 14. Soudry D., Hoffer E., Nacson M. S., Gunasekar S., Srebro N., The implicit bias of gradient descent on separable data. J. Mach. Learn. Res. 19, 2822–2878 (2018).
- 15. Sagun L., Evci U., Guney V. U., Dauphin Y., Bottou L., Empirical analysis of the Hessian of over-parametrized neural networks. arXiv:1706.04454 (14 June 2017).
- 16. Papyan V., Measurements of three-level hierarchical structure in the outliers in the spectrum of deepnet Hessians. arXiv:1901.08244 (24 January 2019).
- 17. Ghorbani B., Krishnan S., Xiao Y., An investigation into neural net optimization via Hessian eigenvalue density. arXiv:1901.10159 (29 January 2019).
- 18. Forster D., Hydrodynamic Fluctuations, Broken Symmetry, and Correlation Functions (CRC Press, 2018).
- 19. Kwon C., Ao P., Thouless D. J., Structure of stochastic dynamics near fixed points. Proc. Natl. Acad. Sci. U.S.A. 102, 13029–13033 (2005).
- 20. Welling M., Teh Y., "Bayesian learning via stochastic gradient Langevin dynamics" in Proceedings of the 28th International Conference on Machine Learning, ICML 2011, L. Getoor, T. Scheffer, Eds. (Omnipress, Bellevue, WA, 2011), pp. 681–688.
- 21. McCloskey M., Cohen N. J., Catastrophic interference in connectionist networks: The sequential learning problem. Psychol. Learn. Motiv. 24, 109–165 (1989).
- 22. Robins A., Catastrophic forgetting, rehearsal and pseudorehearsal. Connect. Sci. 7, 123–146 (1995).
- 23. Kirkpatrick J., et al., Overcoming catastrophic forgetting in neural networks. Proc. Natl. Acad. Sci. U.S.A. 114, 3521–3526 (2017).
- 24. Kirkpatrick S., Gelatt C. D., Vecchi M. P., Optimization by simulated annealing. Science 220, 671–680 (1983).
- 25. Zhu Z., Wu J., Yu B., Wu L., Ma J., "The anisotropic noise in stochastic gradient descent: Its behavior of escaping from sharp minima and regularization effects" in Proceedings of the International Conference on Machine Learning, K. Chaudhuri, R. Salakhutdinov, Eds. (Omnipress, Long Beach, CA, 2019), pp. 7654–7663.
- 26. Zhang C., Bengio S., Hardt M., Recht B., Vinyals O., Understanding deep learning requires rethinking generalization. arXiv:1611.03530 (10 November 2016).
- 27. Belkin M., Hsu D., Ma S., Mandal S., Reconciling modern machine-learning practice and the classical bias–variance trade-off. Proc. Natl. Acad. Sci. U.S.A. 116, 15849–15854 (2019).
- 28. Brutzkus A., Globerson A., Malach E., Shalev-Shwartz S., SGD learns over-parameterized networks that provably generalize on linearly separable data. arXiv:1710.10174 (27 October 2017).
- 29. Li Y., Liang Y., Learning overparameterized neural networks via stochastic gradient descent on structured data. Adv. Neural Inf. Process. Syst. 31, 8157–8166 (2018).
- 30. Allen-Zhu Z., Li Y., Song Z., "A convergence theory for deep learning via over-parameterization" in International Conference on Machine Learning, Chaudhuri K., Salakhutdinov R., Eds. (PMLR, 2019), pp. 242–252.
- 31. Jacot A., Gabriel F., Hongler C., Neural tangent kernel: Convergence and generalization in neural networks. Adv. Neural Inf. Process. Syst. 31, 8571–8580 (2018).
- 32. Geiger M., et al., Scaling description of generalization with number of parameters in deep learning. J. Stat. Mech. Theor. Exp. 2020, 023401 (2020).
- 33. Mei S., Montanari A., The generalization error of random features regression: Precise asymptotics and double descent curve. arXiv:1908.05355 (14 August 2019).
- 34. Gerace F., Loureiro B., Krzakala F., Mézard M., Zdeborová L., Generalisation error in learning with random features and the hidden manifold model. arXiv:2002.09339 (21 February 2020).
- 35. Neyshabur B., Bhojanapalli S., McAllester D., Srebro N., "Exploring generalization in deep learning" in NIPS (Curran Associates Inc., Long Beach, CA, 2017), pp. 5949–5958.
- 36. Advani M. S., Saxe A. M., High-dimensional dynamics of generalization error in neural networks. arXiv:1710.03667 (10 October 2017).
- 37. Shwartz-Ziv R., Tishby N., Opening the black box of deep neural networks via information. arXiv:1703.00810 (2 March 2017).
- 38. Tishby N., Zaslavsky N., "Deep learning and the information bottleneck principle" in 2015 IEEE Information Theory Workshop (ITW) (IEEE, Jerusalem, Israel, 2015), pp. 1–5.
- 39. Yosinski J., Clune J., Bengio Y., Lipson H., How transferable are features in deep neural networks? arXiv:1411.1792 (6 November 2014).
- 40. Ring M. B., "Continual learning in reinforcement environments," PhD thesis, University of Texas at Austin, Austin, TX (1994).
- 41. Lopez-Paz D., Ranzato M., "Gradient episodic memory for continuum learning" in NIPS (Curran Associates Inc., Long Beach, CA, 2017), pp. 5967–5976.
- 42. Riemer M., et al., Learning to learn without forgetting by maximizing transfer and minimizing interference. arXiv:1810.11910 (29 October 2018).