[Preprint]. 2025 Jan 16:arXiv:2411.03840v2. [Version 2]

Flexible task abstractions emerge in linear networks with fast and bounded units

Kai Sandbrink 1,*, Jan P Bauer 2,*, Alexandra M Proca 3,*, Andrew M Saxe 4, Christopher Summerfield 5, Ali Hummos 6,*

Abstract

Animals survive in dynamic environments changing at arbitrary timescales, but such data distribution shifts are a challenge to neural networks. To adapt to change, neural systems may change a large number of parameters, which is a slow process involving forgetting past information. In contrast, animals leverage distribution changes to segment their stream of experience into tasks and associate them with internal task abstractions. Animals can then respond flexibly by selecting the appropriate task abstraction. However, how such flexible task abstractions may arise in neural systems remains unknown. Here, we analyze a linear gated network where the weights and gates are jointly optimized via gradient descent, but with neuron-like constraints on the gates including a faster timescale, nonnegativity, and bounded activity. We observe that the weights self-organize into modules specialized for tasks or sub-tasks encountered, while the gates layer forms unique representations that switch the appropriate weight modules (task abstractions). We analytically reduce the learning dynamics to an effective eigenspace, revealing a virtuous cycle: fast adapting gates drive weight specialization by protecting previous knowledge, while weight specialization in turn increases the update rate of the gating layer. Task switching in the gating layer accelerates as a function of curriculum block size and task training, mirroring key findings in cognitive neuroscience. We show that the discovered task abstractions support generalization through both task and subtask composition, and we extend our findings to a non-linear network switching between two tasks. Overall, our work offers a theory of cognitive flexibility in animals as arising from joint gradient descent on synaptic and neural gating in a neural network architecture.

1. Introduction

Humans and other animals show a remarkable capacity for flexible and adaptive behavior in the face of changes in the environment. Brains leverage change to discover latent factors underlying their sensory experience [Gershman and Niv, 2010, Yu et al., 2021, Castañón et al., 2021]: they segment the computations to be learned into discrete units or ‘tasks’. After learning multiple tasks, low-dimensional task representations emerge that are abstract (represent the task invariant to the current input) [Bernardi et al., 2020] and compositional [Tafazoli et al., 2024]. The discovery of these useful task abstractions relies on the temporal experience of change, and in fact, brains struggle when trained on randomly shuffled interleaved data [Flesch et al., 2018, Beukers et al., 2024].

In contrast, while artificial neural networks have become important models of cognition, they perform well in environments with large, shuffled datasets but struggle with temporally correlated data and distribution shifts. To adapt to changing data distributions (or ‘tasks’), neural networks rely on updating their high-dimensional parameter space, even when revisiting previously learned tasks – leading to catastrophic forgetting [McCloskey and Cohen, 1989, Hadsell et al., 2020]. One way to limit this forgetting is through task abstractions, either provided to the models [Hummos et al., 2024] or discovered from data [Hummos, 2023]. In addition, adapting a model entirely by updating its weights is data-intensive due to high dimensionality. Task abstractions simplify this process by allowing updates to a low-dimensional set of parameters, which can be switched rapidly between known tasks, and recomposed for new ones. However, despite the advantages of task abstractions, simple algorithms for segmenting tasks from a stream of data in neural systems remain an open challenge.

This paper examines a novel setting where task abstractions emerge in a linear gated network model with several neural pathways, each gated by a corresponding gating variable. We jointly optimize the weight layer and gating layer through gradient descent, but impose faster timescale, nonnegativity, and bounded activity on the gating layer units, making them conceptually closer to biological neurons. We find two discrete learning regimes for such networks based on hyperparameters, a flexible learning regime in which knowledge is preserved and task structure is integrated flexibly, and a forgetful learning regime in which knowledge is overwritten in each successive task. In the flexible regime, the gating layer units align to represent tasks and subtasks encountered while the weights separate into modules that align with the computations required. Later on, gradient descent dynamics in the gating layer neurons can retrieve or combine existing representations to switch between previous tasks or solve new ones. Such flexible gating-based adaptation offers a parsimonious mechanism for continual learning and compositional generalization [Butz et al., 2019, Hummos, 2023, Qihong Lu et al., 2024, Schug et al., 2024]. Our key contributions thus are as follows:

  • We describe flexible and forgetful modes of task-switching in neural networks and analytically identify the effective dynamics that induce the flexible regime.

  • The model, to our knowledge, is the first simple neural network model that benefits from data distribution shifts and longer task blocks rather than interleaved training [Flesch et al., 2018, Beukers et al., 2024]. We also provide a direct comparison to human behavior where task switching accelerates with further task practice [Steyvers et al., 2019].

  • We generalize our findings to fully-connected deep linear networks. We find that differential learning rates and regularization on the second layer weights are necessary and sufficient for earlier layers to form task-relevant modules and later layers to implement a gating-based solution that selects the relevant modules for each task.

  • We extend our findings to non-linear networks. As a limited proof of concept, we embed such a layer in a non-linear convolutional network learning two-digit classification tasks.

2. Related work

Cognitive flexibility allows brains to adapt behavior in response to change [Miller and Cohen, 2001, Egner, 2023, Sandbrink and Summerfield, 2024]. Neural network models of cognitive flexibility frequently represent knowledge for different tasks in distinct neural populations, or modules, which then need to be additionally gated or combined [Musslick and Cohen, 2021, Yang et al., 2019]. Several models assumed access to ground truth task identifiers and used them to provide information about the current task demands to the network [Kirkpatrick et al., 2017, Masse et al., 2018, Yang et al., 2019, Wang and Zhang, 2022, Driscoll et al., 2024, Hummos et al., 2024]. Indeed having access to task identifiers facilitates learning, decreases forgetting, and enables compositional generalization [Yang et al., 2019, Masse et al., 2022, Hummos et al., 2024]. Such works sidestep the problem of discovering these task representations from the data stream.

Other models train modular neural structures end-to-end, such as mixture-of-experts [Jacobs et al., 1991, Jordan and Jacobs, 1994, Tsuda et al., 2020], or modular networks [Andreas et al., 2016, Kirsch et al., 2018, Goyal et al., 2019]. A fundamental issue is the ‘identification problem’ where different assignments of experts to tasks do not significantly influence how well the model can fit the data, making identification of useful sub-computations via specialized experts difficult [Geweke, 2007]. Practically, this results in a lack of modularity with tasks learned across many experts [Mittal et al., 2020] or expert collapse, where few experts are utilized [Krishnamurthy et al., 2023]. Recent work used a surprise signal to allow temporal experience to adapt learning [Barry and Gerstner, 2022]. Our model proposes simple dynamics that benefit from the temporal structure to assign sub-tasks to modules.

Our work builds on the theoretical study of linear networks which exhibit complex learning dynamics, but are analytically tractable [Saxe et al., 2013, 2019]. Prior work examined how gating alleviates interference [Saxe et al., 2022], but gating was static and provided as data to the network. We generalize this line of work by showing how appropriate gating emerges dynamically. More recently, Shi et al. [2022] analyzed specialization of a linear network with multiple paths when tasks are presented without blocking and gates, and Lee et al. [2024] studied the effects of a pretraining period. We consider continual learning with a blocked curriculum. Schug et al. [2024] proved that learning a linear number of (connected) task module combinations is sufficient for compositional generalization to an exponential number of module combinations in a modular architecture similar to ours. Instead, we explicitly study the interaction between task learning and gating variable update dynamics.

3. Approach

3.1. Setting

We formulate a dynamic learning problem consisting of M distinct tasks. At each time step t the network is presented with an input-output pair $(x_t, y^m(x_t))$ sampled from the current task m. Tasks are presented in blocks lasting a period of time $\tau_B$ before switching sequentially to another task (Fig. 1A). Models are never given the task identity m or task boundaries.

Figure 1: The open-ended learning setting and the modeling approach.


A. Example of the blocked curriculum with two tasks. B. Neural Task Abstraction (NTA) model updates Wp through gradient descent, but also the gating variables cp, leading to task abstractions emerging in the gating layer.

Specifically, we consider a multitask teacher-student setup in which each task is defined by a teacher $W^m$, which generates a ground-truth label $y^m = W^m x$ from a Gaussian i.i.d. input $x$ at every point in time. We randomly generate the teachers to produce orthogonal responses to the same input. While orthogonal tasks simplify the theoretical analysis, we generalize to non-orthogonal tasks in Appendix A.9.
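To make the setup concrete, the following minimal sketch generates teachers and a blocked curriculum. It is illustrative only: the function names and the Gram-Schmidt construction of Frobenius-orthogonal teachers (one plausible way to make responses to Gaussian inputs orthogonal in expectation) are our assumptions, not necessarily the authors' exact procedure.

    import numpy as np

    def make_orthogonal_teachers(M, d_out, d_in, rng):
        """Gram-Schmidt on the flattened teacher matrices so that <W^m, W^m'>_F = 0,
        which makes the teachers' responses to Gaussian i.i.d. inputs orthogonal
        in expectation (an illustrative construction)."""
        teachers = []
        for _ in range(M):
            W = rng.standard_normal((d_out, d_in))
            for T in teachers:                      # remove overlap with previous teachers
                W -= (np.sum(W * T) / np.sum(T * T)) * T
            teachers.append(W / np.linalg.norm(W) * np.sqrt(d_out))  # fix an arbitrary scale
        return teachers

    def blocked_curriculum(teachers, tau_B, dt, n_blocks, batch, rng):
        """Yield (x, y) batches; the active teacher changes every tau_B time units."""
        d_in = teachers[0].shape[1]
        steps_per_block = int(tau_B / dt)
        for b in range(n_blocks):
            W_m = teachers[b % len(teachers)]       # cycle through tasks sequentially
            for _ in range(steps_per_block):
                x = rng.standard_normal((d_in, batch))
                yield x, W_m @ x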

3.2. Model

We study the ability of linear gated neural networks [Saxe et al., 2013, 2022] to adapt to teachers sequentially. We use a one-layer linear gated network with P student weight matrices $W^p \in \mathbb{R}^{d_\text{out} \times d_\text{in}}$, together with P scalar variables $c_p$ which gate a cluster of neurons in the hidden layer (Fig. 1B).

The model output $y \in \mathbb{R}^{d_\text{out}}$ reads

$y = \sum_{p=1}^{P} c_p\, W^p x.$ (1, NTA)

Since the cp variables will learn to reflect which task is currently active, we refer to their activation patterns as task abstractions. We refer to a student weight matrix together with its corresponding gating variable as a path.

We refer to this architecture as the Neural Task Abstractions (NTA) model when the following two conditions are met during training: first, we update both the weights Wp and the gating variables cp via gradient descent, but on a regularized loss function 𝓛=𝓛task+𝓛reg. Second, we impose a shorter timescale for the gates τc than for the weights τw, i.e. τc<τw (although this condition becomes unnecessary if the task is sufficiently high-dimensional, see Appendix A.2).

The task loss is a mean-squared error $\mathcal{L}_\text{task} = \frac{1}{2}\sum_i^{d_\text{out}} \langle (y^m_i - y_i)^2 \rangle$, where the average is taken over a batch of samples. The regularization loss contains two components, $\mathcal{L}_\text{reg} = \lambda_\text{norm}\,\mathcal{L}_\text{norm} + \lambda_\text{nonneg}\,\mathcal{L}_\text{nonneg}$, weighted by coefficients $\lambda_\text{norm}$, $\lambda_\text{nonneg}$. The normalization term bounds gate activity, $\mathcal{L}_\text{norm} = \frac{1}{2}\big(\lVert c\rVert_k - 1\big)^2$, where we consider $k = 1, 2$. The nonnegativity term favors positive gates, $\mathcal{L}_\text{nonneg} = \sum_{p=1}^{P} \max(0, -c_p)$. Together, these regularizers incentivize the model to function as an approximate mixture model by driving solutions towards any convex combination of students without favoring specialization, and they reflect constraints of biological neurons (see Appendix B.3 for details).

Assuming small learning rates (gradient flow), this approach implies updates of

$\tau_c \frac{d}{dt} c_p = -\nabla_{c_p} \mathcal{L}, \qquad \tau_w \frac{d}{dt} W^p = -\nabla_{W^p} \mathcal{L},$

where $\tau_c$ and $\tau_w$ are the time constants of the model parameters. We initialize $W^p$ as i.i.d. Gaussian with small variance $\sigma^2/d_\text{in}$, $\sigma = 0.01$, and $c_p = \frac{1}{2}$.

Code for model and simulations at: https://github.com/aproca/neural_task_abstraction
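For orientation, the following is a minimal Euler-discretized sketch of these update rules (Eq. (1, NTA) with the task loss and both regularizers). Variable and hyperparameter names are illustrative; the linked repository above contains the reference implementation.

    import numpy as np

    class NTA:
        """Minimal sketch of the NTA gradient-flow updates (not the reference code)."""

        def __init__(self, P, d_out, d_in, tau_c=0.03, tau_w=1.3,
                     lam_norm=1.0, lam_nonneg=1.0, sigma=0.01, rng=None):
            rng = rng or np.random.default_rng(0)
            self.W = rng.normal(0.0, sigma / np.sqrt(d_in), size=(P, d_out, d_in))
            self.c = np.full(P, 0.5)                # gates initialized at 1/2
            self.tau_c, self.tau_w = tau_c, tau_w
            self.lam_norm, self.lam_nonneg = lam_norm, lam_nonneg

        def forward(self, x):                       # x: (d_in, batch)
            return np.einsum('p,pij,jb->ib', self.c, self.W, x)

        def step(self, x, y_target, dt=1e-2):
            B, d_out = x.shape[1], y_target.shape[0]
            err = y_target - self.forward(x)        # (d_out, batch)
            # task-loss gradients (MSE averaged over batch and output dimensions)
            grad_W = -np.einsum('p,ib,jb->pij', self.c, err, x) / (B * d_out)
            grad_c = -np.einsum('pij,jb,ib->p', self.W, x, err) / (B * d_out)
            # regularization: L1 normalization towards sum(c)=1 and nonnegativity
            grad_c += self.lam_norm * (np.sum(np.abs(self.c)) - 1.0) * np.sign(self.c)
            grad_c += self.lam_nonneg * np.where(self.c < 0, -1.0, 0.0)
            # gradient flow with separate timescales (tau_c < tau_w)
            self.c -= dt / self.tau_c * grad_c
            self.W -= dt / self.tau_w * grad_W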

4. Task abstractions emerge through joint gradient descent

We train the model with fast and bounded gates on M=2 alternating tasks (Fig. 1A) and use P=2 paths for simplicity (for cases with more paths and tasks, see Fig. 3 and Appendix A.3). As a baseline, we compare to the same model but without gate regularization and timescale difference.

Figure 3: Flexible model generalizes to compositional tasks.


A. Task composition consists of new tasks that sum sets of teachers previously encountered. B. Subtask composition consists of new tasks that concatenate alternating rows of sets of teachers previously encountered. Loss of models trained on generalization to task composition (C.) and subtask composition (D.) for the flexible (black) and forgetful (gray) NTA. ‘New tasks’ indicates the start of the generalization phase when the task curriculum is changed to cycle through the compositional tasks.

Both models reach low loss in early blocks, but only flexible NTA starts to adapt to task switches increasingly fast after several block changes (Fig. 2A,F). Analyzing the model components reveals what underlies this accelerated adaptation (Fig. 2C,D): in early blocks of training, zero loss is reached through re-aligning both students Wp to the active teacher Wm in every block, while the gates cp are mostly drifting (Fig. 2B). Reaching low loss is furthermore only achieved towards the end of a block. Later, the weights stabilize to each align with one of the teachers (Fig. 2C,D), and the appropriate student is selected via gate changes (Fig. 2B), reducing loss quickly. The rate at which gates change correlates with the alignment and magnitude Wp of the learned weights (Fig. 2C,E). Overall, this points towards a transition between two learning regimes: first, learning happens by aligning student weight matrices with the current teacher, which we call forgetful because it overwrites previous weight changes. Later, as the weights specialize, the learned representations Wp can be rapidly selected by the gates according to the task at hand, reflecting adaptation that is flexible. Only the model equipped with fast and bounded gates (flexible NTA) is able to enter this flexible regime (Fig. 2A,F).

Figure 2: Joint gradient descent on gates and weights enables fast adaptation through gradual specialization.


Learning on the blocked curriculum from Fig. 1 with $\tau_c = 0.03$, $\tau_w = 1.3$, and block length $\tau_B = 1.0$. The x-axis indicates time as multiples of $\tau_B$. (Black) Flexible NTA model, Eq. (1, NTA); (gray) forgetful NTA model with $\tau_c = \tau_w$ and $\lambda_\text{nonneg} = \lambda_\text{norm} = 0$. Simulation averaged over 10 random seeds with standard error indicated. A. Loss of both models over time. B. Gate activity of flexible NTA. C. Student-teacher weight alignment $W^m \cdot W^p$, normalized and averaged over rows (cosine similarity), for each student-teacher pair. D., E. Norm of updates to $W^p$ and $c$. Dashed: norm of students, correlating with update size of $c$. F. Time to reach $\mathcal{L}_\text{task} = 0.1$ for both models over blocks.

Next, we verify that the task abstractions in the gating variables are general, in the sense that they support compositional generalization. We consider two settings that begin by training a model with three paths on three teachers A, B, and C in alternating blocks, before training on novel conditions. In task composition, the novel conditions are the teachers' additive compositions A+B, A+C, B+C (Fig. 3A); the flexible NTA model trains on these combinations much faster (Fig. 3C). In subtask composition, the novel conditions are combinations of the rows of different teachers, i.e. we break the teachers A, B, C into rows and select from these rows to compose new tasks (Fig. 3B). In the subtask composition case, we use a more expressive form of the gates that can control each row of the student matrices Wp individually. We find that, in the flexible regime, the model quickly adapts to the introduction of compositional tasks, while the forgetful model with regularization removed does not (Fig. 3C,D). For more details and extended analysis, see Appendix A.8.

We devote the next section to identifying what factors support the flexible regime of learning.

5. Mechanisms of learning flexible task abstractions

We observed in Fig. 2 that simultaneous gradient descent on weights and gates converges to a flexible regime capable of rapid adaptation to task changes. But what mechanisms facilitate this solution? We here leverage the linearity of the model to identify the effective dynamics in the SVD space of the teachers, in which we describe the emergence and functioning of the flexible regime.

5.1. Reduction to effective 2D model dynamics

For simplicity, we consider the case with only M = P = 2 teachers and students. We take a similar approach to Saxe et al. [2013] and project the student weights into the singular value space of the teachers for each mode α individually, yielding a scalar $w^p_{m\alpha} := u_\alpha^{m\top} W^p v_\alpha^m$. Each pair of components α thus reduces to a 2D state vector $y = c_1 w^1 + c_2 w^2 \in \mathbb{R}^2$, where we stack $w^p_m$ along the index m and omit α in the following for readability (Fig. 4A). A similar projection is possible in terms of the row vectors of both teachers (Appendix A.1.1).

Figure 4: Mechanism of gradual task specialization in effective 2D subspace.


A. Sketch of the reduced model and dynamic feedback. Out-of-subspace students gradually align to teacher axes. B. Trajectories of student weight matrices (blue, orange) in the teacher subspace during complete adaptation following a context switch from teacher 1 to teacher 2 in the flexible regime. Gray stripes indicate associated gate activation. The student weight matrices move little. C. Like (B), but for the forgetful regime. Student weight matrices entirely remap and gates do not turn off. D. Gradient of the task loss on $c_p$ as a function of the weight alignment. E. Trajectories in the specialization subspace as a function of gate timescale for values $\tau_c$ = 0.1, 0.18, 0.32, 0.56, 1.00, comparing (color) simulations and (dashed black) analytical predictions from exact solutions under symmetry in the flexible regime. Simulations begin from initial conditions of complete specialization and separation, $w^p_m = \delta_{pm}$, $c_p = \delta_{p1}$, and follow a complete adaptation from teacher 1 to teacher 2 over the course of a block, reaching $\mathcal{L}_\text{task} < 10^{-2}$ for all $\tau_c$.

The essential learning dynamics of the system can therefore be described as

$\tau_w \frac{d}{dt} w^p = c_p\,(y^m - y),$ (1)
$\tau_c \frac{d}{dt} c_p = w^{p\top} (y^m - y) - \lambda\,\nabla_{c_p}\mathcal{L}_\text{reg},$ (2)

where ym describes the output of the currently-active teacher m. In Appendix A.1, we show analytically and through simulations that this reduction is exact when gradients are calculated over many samples.
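A small simulation of the reduced dynamics can be written directly from Eqs. (1) and (2). The sketch below takes the two teachers as the standard basis vectors of the reduced space and uses a simple L1 gate regularizer as a stand-in for $\mathcal{L}_\text{reg}$; both choices are assumptions for illustration.

    import numpy as np

    def simulate_reduced(tau_c=0.03, tau_w=1.3, tau_B=1.0, n_blocks=20,
                         dt=1e-3, lam=1.0):
        """Euler integration of the reduced 2D dynamics, Eqs. (1)-(2)."""
        rng = np.random.default_rng(0)
        w = rng.normal(0.0, 0.01, size=(2, 2))      # w[p] = student p in teacher space
        c = np.array([0.5, 0.5])
        teachers = np.eye(2)                        # y^m = e_m in the reduced space
        for b in range(n_blocks):
            y_m = teachers[b % 2]
            for _ in range(int(tau_B / dt)):
                err = y_m - (c[0] * w[0] + c[1] * w[1])
                grad_reg = lam * (np.sum(np.abs(c)) - 1.0) * np.sign(c)
                c += dt / tau_c * (w @ err - grad_reg)   # Eq. (2)
                w += dt / tau_w * np.outer(c, err)       # Eq. (1)
        return w, c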

5.2. Specialization emerges due to self-reinforcing feedback loops

The flexible regime is characterized by students that are each attuned to a single teacher (Fig. 4B), whereas in the forgetful regime, both students track the active teacher together (Fig. 4C). We can describe this difference by studying the specialization of the students. We define this by considering the difference in how represented the teachers are in the two paths: for teacher 1, $\bar{w}_1 := w^{p=1}_{m=1} - w^{p=2}_{m=1}$ and, for teacher 2, $\bar{w}_2 := w^{p=2}_{m=2} - w^{p=1}_{m=2}$. Similarly, a hallmark of the flexible regime are separated gates. Together, this defines the specialization subspace

$\bar{w} := \frac{1}{2}\big(\bar{w}_1 + \bar{w}_2\big),$ (3)
$\bar{c} := c_1 - c_2.$ (4)

The system is in the flexible regime when absolute values of w¯ and c¯ are high (approaching 1), and in the forgetful regime when they are low (around 0). In this section, we study the emergence of the flexible regime through self-reinforcing feedback loops, with specialized students and normalizing regularization leading to more separated gates, and separated gates in turn leading to more specialized students. In each subsection, we first describe the effect of the feedback loops on the paths individually, before considering the combined effect on specialization. Without loss of generality, we consider cases where the student p specializes to teacher m=p.

5.2.1. Specialized students and regularization encourage fast and separated gates

We first investigate the influence of $w^p$ on $\frac{d}{dt}c_p$. From the gate update in Eq. (2), we get

$\tau_c \frac{d}{dt} c_p = \varepsilon_1\,\big(w^p \cdot w^1\big) + \varepsilon_2\,\big(w^p \cdot w^2\big) - \nabla_{c_p}\mathcal{L}_\text{reg},$ (5)

where we decomposed the error $\varepsilon := y^m - y$ into the teacher basis as coefficients $\varepsilon_m := \varepsilon \cdot w^m$. The feedback between students and gates enters here in two terms, as can be seen by expressing $w^p \cdot w^m = \lVert w^p\rVert\,\lVert w^m\rVert\cos\big(\angle(w^p, w^m)\big)$, where $\angle$ denotes the angle between two vectors. As observed in Fig. 2, both the alignment between students and teachers, $\cos\big(\angle(w^p, w^m)\big)$, and the magnitude of the students, $\lVert w^p\rVert$, control the gate switching speed.

As the vectors $w^p$ are formed from the students' singular values, they scale proportionally to the bare matrix entries $W^p_{ij}$ for random initialization (Marcenko-Pastur distribution). Early in learning, the small initialization will therefore attenuate gate changes by prolonging their effective timescale $\tau_c/\lVert w^p\rVert$ (or equivalently, lowering their learning rate).

As we demonstrate in Fig. 4D, these effects apply to both activation and inactivation of the gates, depending on the sign of the current error component, $\varepsilon \cdot w^m \gtrless 0$.

The regularization in the system introduces a feedback loop between $c_1$ and $c_2$. In practice, the system quickly reaches a regime where both gates $c_p$ are positive. In this case, the regularization term using the L1-norm becomes $\nabla_{c_p}\mathcal{L}_\text{reg} \propto \sum_p c_p - 1$, reaching a minimum along the line $\sum_p c_p = 1$. In order to minimize the regularization loss, the upscaling of one gate $c_p$ past 0.5 will result in the downscaling of the other gate $c_{p'}$, and vice versa. We note that this regularization term does not favor specialization by itself, since the network can also attain zero loss in the unspecialized forgetful solution with, for instance, $c_1 = c_2 = 0.5$.

The above dynamics mean that differences in student alignment separate the gates, as described by

$\tau_c \frac{d\bar{c}}{dt} = \bar{w}_1\,\varepsilon_1 - \bar{w}_2\,\varepsilon_2.$ (6)

We therefore see that the differences in gate activation are driven by the differences in specialization of the two components, $\bar{w}_1$ and $\bar{w}_2$, and the corresponding error components $\varepsilon_1$ and $\varepsilon_2$. Since the error components are of opposite sign following a context switch, $\frac{d\bar{c}}{dt}$ is maximized when the students are maximally specialized.
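For completeness, one way to see Eq. (6) (dropping the regularization term and writing students and error in an orthonormal teacher basis, $w^p = \sum_m w^p_m\,\hat{w}^m$, $\varepsilon = \sum_m \varepsilon_m\,\hat{w}^m$) is

$\tau_c \frac{d\bar{c}}{dt} = \tau_c\frac{d c_1}{dt} - \tau_c\frac{d c_2}{dt} = \big(w^1 - w^2\big)\cdot\varepsilon = \big(w^1_1 - w^2_1\big)\,\varepsilon_1 + \big(w^1_2 - w^2_2\big)\,\varepsilon_2 = \bar{w}_1\,\varepsilon_1 - \bar{w}_2\,\varepsilon_2,$

using the definitions $\bar{w}_1 = w^1_1 - w^2_1$ and $\bar{w}_2 = w^2_2 - w^1_2$ from above.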

5.2.2. Flexible gates protect learned specialization

We now study the influence of cp on ddtwp. The gates allow for a switching mechanism that does not require a change in the weights. When continuing gradient descent on all parameters, however, Eq. (1) will also entail a finite update to the wrong student.

If we Taylor-expand to second order, this update reads

$\tau_w \frac{d}{dt} w^p \approx c_p\,\varepsilon + \frac{1}{2}\Big(\frac{d c_p}{dt}\,\varepsilon + c_p\,\frac{d\varepsilon}{dt}\Big)\,dt.$ (7)

The first summand of the second term reflects the protection that arises from changes in gating, $\frac{d}{dt}c_p = w^p \cdot \varepsilon$: a task switch to $y^m = (0, 1)$ incurs an error $\varepsilon \approx (-1, 1)$. For a specialized, but now incorrect, student $w^p \approx (1, 0)$, this term becomes $\frac{d}{dt}c_p = w^p\cdot\varepsilon < 0$ for the incorrect student. Together with the decreasing error in the last term, $\frac{d\varepsilon}{dt}$, this reduces the student update relative to the leading-order first term $c_p\,\varepsilon$. Importantly, this protection effect grows over training as the student's contribution to the error, $w^p\cdot\varepsilon$, increases.

Alongside protection, flexible gates also accelerate specialization, as can be seen by considering $\bar{w}$ in the specialization space,

$\tau_w \frac{d\bar{w}}{dt} = \frac{1}{2}\,\bar{c}\,\big(\varepsilon_1 - \varepsilon_2\big).$ (8)

This equation shows that the students specialize through two factors: the difference in error between the two components, $\varepsilon_1 - \varepsilon_2$, and the difference in gate activation, $\bar{c} = c_1 - c_2$.
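The same bookkeeping yields Eq. (8): componentwise, Eq. (1) gives $\tau_w \frac{d}{dt} w^p_m = c_p\,\varepsilon_m$, so that

$\tau_w \frac{d\bar{w}_1}{dt} = (c_1 - c_2)\,\varepsilon_1 = \bar{c}\,\varepsilon_1, \qquad \tau_w \frac{d\bar{w}_2}{dt} = (c_2 - c_1)\,\varepsilon_2 = -\bar{c}\,\varepsilon_2,$

and averaging the two components gives $\tau_w \frac{d\bar{w}}{dt} = \frac{1}{2}\big(\tau_w\frac{d\bar{w}_1}{dt} + \tau_w\frac{d\bar{w}_2}{dt}\big) = \frac{1}{2}\,\bar{c}\,(\varepsilon_1 - \varepsilon_2)$.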

5.3. Exact solutions to the learning dynamics describe protection and adaptation under symmetry in the flexible regime

In this section, we study exact solutions to the learning dynamics in Eq. (6) and Eq. (8) to describe the behavior of the model as it switches between tasks when it is already in the flexible regime. To solve the differential equations, we require the condition of symmetry, $\bar{w} = \bar{w}_1 = \bar{w}_2$. This condition is approximately true for specialized states in the flexible regime (see Fig. A.10), and it persists as long as $\varepsilon_1 = -\varepsilon_2$ holds, or in the limit of strong L1 regularization.

We use the method presented in Shi et al. [2022] to solve the resulting dynamics of the ratio between the expressions for $\frac{d\bar{c}}{dt}$ and $\frac{d\bar{w}}{dt}$,

$\frac{\tau_c}{\tau_w}\frac{d\bar{c}}{d\bar{w}} = \frac{2\,\bar{w}\,(\varepsilon_1 - \varepsilon_2)}{\bar{c}\,(\varepsilon_1 - \varepsilon_2)},$ (9)

which is a separable differential equation that can be solved up to an integration constant (see Appendix A.7.2). Plugging in initial conditions that correspond to complete specialization in the flexible regime, $\bar{c}(0) = \bar{w}(0) = 1$, we obtain the exact dynamics of $\bar{w}$ as a function of $\bar{c}$ over the course of a block,

$\bar{w} = \sqrt{1 - \frac{1}{2}\frac{\tau_c}{\tau_w}\big(1 - \bar{c}^2\big)}.$ (10)

This analytical solution accurately describes adaptation in the flexible regime (Fig. 4E). The relationship highlights the role of a shorter gate timescale τc in protecting the student’s knowledge. Learning that comes from both students specializing towards the current teacher occurs outside of this specialization space and becomes more important for low τc (see Appendix A.7.3).
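For reference, the integration behind Eq. (10) is elementary: separating variables in Eq. (9) gives

$\tau_c\,\bar{c}\;d\bar{c} = 2\,\tau_w\,\bar{w}\;d\bar{w} \;\;\Longrightarrow\;\; \tau_c\,\bar{c}^2 = 2\,\tau_w\,\bar{w}^2 + C,$

and with $\bar{c}(0) = \bar{w}(0) = 1$ the constant is $C = \tau_c - 2\tau_w$. Solving for $\bar{w}$ yields $\bar{w}^2 = 1 - \frac{1}{2}\frac{\tau_c}{\tau_w}\big(1 - \bar{c}^2\big)$, which is Eq. (10) (cf. Eq. (23) in Appendix A.7.2).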

6. Quantifying the range of the flexible regime across block length, regularization strength, and gate speed

To assess the roles of block length, regularization, and a fast gate timescale (inverse gate learning rate) in establishing the flexible regime, we run two grid searches: one over the gate learning rate and the block length each task is trained on, and one over the regularization strength and block length, keeping the total training time constant (such that models trained on shorter block lengths experience more block switches but equal amounts of data). For each set of hyperparameters, we compute the total alignment (cosine similarity) between the entire concatenated set of teachers and students as a single overall measure of specialization in the network weights at the end of learning. We identify the boundaries of the flexible regime, where specialization emerges in our model, as a function of block length, gate timescale, and regularization strength (Fig. 5). A priori, the block length dependence is surprising, as one might expect additional time spent in a block to be reversed by the equally-long subsequent block. However, we show in Appendix A.6 that gating breaks this time-reversal symmetry, and specialization grows with block length τB for fixed overall learning time t.
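As one concrete reading of this measure (the exact definition used for Fig. 5 is given in Appendix B.4), the total alignment can be computed as a single cosine similarity between the concatenated, flattened teachers and students:

    import numpy as np

    def total_alignment(teachers, students):
        """Cosine similarity between the concatenated teachers and the concatenated
        students; one plausible reading of the 'total alignment' score."""
        T = np.concatenate([W.ravel() for W in teachers])
        S = np.concatenate([W.ravel() for W in students])
        return float(T @ S / (np.linalg.norm(T) * np.linalg.norm(S)))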

Figure 5: Model specialization emerges as a function of block length, gate learning rate, and regularization strength.


The colorbar indicates total alignment (cosine similarity) between all sets of students and teachers considered collectively.

7. Inducing the flexible regime in a deep fully-connected neural network

Our NTA model uses a low-dimensional gating layer that gates computations from student networks. We sought to understand the necessity and role of this structure by considering a more general form of the model in a deep linear network with no architectural assumptions. Based on the analysis and results so far (Fig. 4D,5), we impose regularization and faster learning rates on the second layer of a 2-layer fully-connected network. Behaviorally, this network also shows the signatures of the flexible regime with adaptation accelerating with each task switch experienced (Fig. A.4A).

To quantify specialization and gating behavior, we compute the cosine similarity between each row of the first hidden layer and the teachers and use this to sort the network into two students that align to the teachers they match best. We also permute the second layer columns to match the sorting of the first layer. We then visualize the specialization of the sorted first hidden layer using the same measures as in the original NTA model. We also take the mean of each sorted student’s second hidden layer to be its corresponding gate. Using this analysis, we find emergent gating in the second layer (Fig. A.4B) and specialization in the first (Fig. A.4C). Adaptation to later task switches takes place primarily in the second layer (Fig. A.4E).
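A sketch of this sorting procedure is shown below; it is an illustrative reconstruction under the assumption that each hidden unit is assigned to its best-matching teacher, and the actual procedure is specified in Appendix B.4.1.

    import numpy as np

    def sort_into_students(W1, W2, teachers):
        """Assign each row of the first layer W1 to the teacher it best matches
        (cosine similarity), and permute the columns of the second layer W2 with
        the same ordering."""
        def cos(a, b):
            return a @ b / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12)

        # best-matching teacher (over that teacher's rows) for every hidden unit
        scores = np.array([[max(cos(row, t_row) for t_row in T) for T in teachers]
                           for row in W1])
        best_teacher = scores.argmax(axis=1)
        order = np.argsort(best_teacher, kind='stable')   # group rows by teacher
        return W1[order], W2[:, order], best_teacher[order]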

By visualizing the sorted second hidden layer of the fully-connected network at the last timestep of two different task blocks, we indeed observe distinct gating behavior along the diagonal, specialized for each task (Fig. 6 for one seed). We compare this to the same fully-connected network trained without regularization which remains in the forgetful learning regime. We include visualizations of the unsorted second hidden layer for fully-connected networks arriving at both the gating and non-gating solutions (Fig. A.5 for ten seeds), as well as the sorted second hidden layer (Fig. A.6 for ten seeds) as supplement. Appendix A.4.2 discusses the potential for multiplicative gates to emerge in fully-connected architectures.

Figure 6: Task-specialized gating emerges in the second layer of a 2-layer network with faster second-layer learning rate and regularization.


The sorted second layer weights at the last timestep of two different task blocks (one seed).

8. Flexible remapping of representations in nonlinear networks in two MNIST tasks

We next study whether NTA also works in larger, nonlinear systems. As a proof of concept, we investigate whether NTA can help a neural network switch between two nonlinearly-transformed versions of the MNIST dataset [Deng, 2012]. The first task is the conventional MNIST task. The second is a permuted version of MNIST in which the label y of a digit is remapped based on its parity according to the function $y \mapsto \lfloor y/2 \rfloor + 5\,(y\,\%\,2)$, where % is the modulo operation (see Fig. 7A). We pre-train a convolutional neural network (CNN) on MNIST to learn useful representations, achieving about 90% accuracy on the test set. We then train an NTA system beginning from the final hidden layer representations, which feeds into the same sigmoid nonlinearity (see Fig. 7B). We again induce the flexible regime using regularization and fast timescales, and contrast performance with a forgetful model (see Appendix B.7). We find that the flexible model learns to recover its original accuracy quickly after the first task switches, whereas the forgetful one needs to continuously re-learn the task, as evaluated on the MNIST test set (Fig. 7C). The activity in the gating units reflects selective activity (Fig. 7D). To further test the range of NTA, we examine how much these results depend on the orthogonality of the task space by formulating two tasks based on real-world groupings of clothing in fashionMNIST [Xiao et al., 2017] that have different amounts of shared structure. We find that rapid task switching occurs in both settings at a similar speed (Fig. A.15).
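For concreteness, our reading of this parity-based remapping as a label permutation is sketched below (an assumption about the exact task construction, which the text specifies only through the formula above).

    def permute_parity(y: int) -> int:
        """Parity-based label remapping for the second MNIST task:
        even digits map to 0-4, odd digits to 5-9."""
        return y // 2 + 5 * (y % 2)

    # [permute_parity(y) for y in range(10)] == [0, 5, 1, 6, 2, 7, 3, 8, 4, 9]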

Figure 7: Learning flexible neural task abstractions in a nonlinear character recognition setting.


A. We formulate two tasks, the original and a permuted version of MNIST. B. We embed the NTA system into a larger pretrained convolutional neural network architecture. C. Accuracy reached on the MNIST test set as a function of time for both (black) the NTA network and (gray) the original CNN. The two tasks are presented sequentially in blocks for both (blue shading) MNIST and (orange shading) the permuted version. D. The activation of the two gating units as a function of time. We show mean and standard error with 10 seeds.

9. Relations to multi-task learning in humans

Our model captures several aspects of human learning. Humans update task knowledge in proportion to how well each probable task explains current experience [Castañón et al., 2021]. Analogously in our model, weight updates are gated with the corresponding gating variable, whose activity in turn reflects how well the weights behind it capture the target computation.

Humans show faster task switching with more practice on the tasks involved. In our model, we saw that the gates change faster as weights specialize to the tasks, which facilitated faster adaptation after block switches. NTA shows a qualitative fit (Fig. 8) to humans trained on alternating tasks [Steyvers et al., 2019]. In contrast, a forgetful model shows a deceleration, possibly due to being far from optimal initialization [Dohare et al., 2024] after task switches.

Figure 8: Comparing performance after a task switch in humans and NTA model.


A. Steyvers et al. [2019] report performance of humans learning two alternating tasks (CC BY-NC-ND 4.0 license). B. After a block switch, loss comparison between the flexible (left) and the forgetful (right) NTA model shows opposite trends with further training on switching speed. Bars are standard error with 10 seeds.

10. Conclusions and future work

This study demonstrates how task abstraction and cognitive flexibility can emerge in neural networks trained through joint gradient descent on weights and gating variables. Simple constraints on the gating variables induce gradient descent dynamics that lead to the flexible regime. In this regime, the weights self-organize into task-specialized modules, with gating variables facilitating rapid task switching. Analytical reductions revealed a virtuous cycle: specialized weights enable faster gating, while flexible gating protects and accelerates weight specialization. This contrasts with the forgetful regime in which knowledge is continually forgotten and re-learned after task switches. The constraints necessary for reaching the flexible regime are appropriate regularization, differential learning rates, and sufficient task block length. These mirror properties of biological neurons and beneficial learning settings identified in cognitive science.

The mechanistic understanding of how task abstraction arises in neural systems might bridge artificial neural networks and biological cognition, offering a foundation for future exploration of adaptive and compositional generalization in dynamic environments. While this study focuses on simple two-layer networks, the framework is applicable to other non-linear architectures such as recurrent networks or Transformer architectures. We see future work providing additional architectures and real-world applications of the framework.

Acknowledgements

We thank Stefano Sarao Mannelli and Pedro A.M. Mediano for thoughtful discussions. We would also like to acknowledge and thank the organizers of the Analytical Connectionism Summer School, at which AP, JB, and KS first met. AH is funded by Collaborative Research in Computational Neuroscience award (R01-MH132172). AP is funded by the Imperial College London President’s PhD Scholarship. KS is funded by a Cusanuswerk Doctoral Fellowship. JB is supported by the Gatsby Charitable Foundation (GAT3850). This work was supported by a Sir Henry Dale Fellowship from the Wellcome Trust and Royal Society (216386/Z/19/Z) to AS, and the Sainsbury Wellcome Centre Core Grant from Wellcome (219627/Z/19/Z) and the Gatsby Charitable Foundation (GAT3755).

Appendix

Overview

We structure the Appendix as follows:

We first provide additional material on our results in Appendix A. Appendix A.1 derives the reduction to the 2D equivalent model. Appendix A.2 shows that even without a differential timescale between gates and weights, high-rank students will learn more slowly compared to gates. Appendix A.3 shows that networks with more paths than teachers will split their representations across paths unless a cost is associated with representation. Appendix A.4 provides additional results for the fully-connected network, along with simulations and derivations on how gating behavior emerges in an architecture without explicit pathways but with two layers, where one layer emergently takes on the role of gates and the second layer becomes compartmentalized. Appendix A.5 shows how the specialized representation incentivized by the virtuous cycle discussed in the main text leads to a faster reduction in loss compared to an unspecialized solution. Appendix A.6 provides an approximate theoretical explanation for the beneficial effect of long blocks towards specialization through symmetry breaking in an effective potential. Appendix A.7 provides approximate closed-form solutions when operating in the flexible regime. Appendix A.8 provides detail on how the model generalizes to new tasks by leveraging existing abstractions. Appendix A.9 shows that the flexible regime largely persists and only slowly decays when the orthogonality assumption between teachers is relaxed. Appendix A.10 shows that the model can adapt in a few-shot fashion after a block switch, extending the results from the main text where gradients are calculated on many samples.

We then provide additional technical details in Appendix B. In Appendix B.1, we provide a notation table. Appendix B.2 lists parameters used for simulations. Appendix B.3 discusses the regularization that we use, in particular why it does not incentivize a flexible over a forgetful solution. Appendix B.4 details how we calculate alignments between teachers and students, both for the per-student and per-neuron gating models. Appendix B.5 discusses how we choose model parameters.

A. Additional details on main text

A.1. Derivation of reduction to 2D equivalent model

We will here show that the dynamics of the model can be reduced to an effective model that acts in a 2D space spanned by the singular teacher vectors across both tasks m.

Recall the task loss as the mean-squared error

$\mathcal{L}_\text{task} = \frac{1}{2 B d_\text{out}} \sum_b^B \big\lVert y^m_b - y_b \big\rVert^2 = \frac{1}{2 B d_\text{out}} \sum_b^B \Big\lVert W^m x_b - \sum_p c_p W^p x_b \Big\rVert^2$

for a batch $X = (x_b)_{b=1}^{B}$ of size B.

Following the approach in Saxe et al. [2013], we assume that the input data is whitened, such that the batch average $\frac{1}{B} X X^\top \approx I_{d_\text{in}}$, and that the learning rate $\tau^{-1}$ is small (i.e., the gradient flow regime). Then, the batch gradient reads as a differential equation that simplifies as

$\tau_w \frac{d}{dt} W^p = \frac{1}{B d_\text{out}}\, c_p \Big(W^m X - \sum_{p'} c_{p'} W^{p'} X\Big) X^\top$ (11)
$\;\approx\; c_p \Big(W^m - \sum_{p'} c_{p'} W^{p'}\Big).$ (12)

For each teacher m and singular value decomposition along a mode α ($u_\alpha^m$, $s_\alpha^m$, $v_\alpha^m$), we can project this equation to get

$\tau_w \frac{d}{dt} \underbrace{u_\alpha^{m\top} W^p v_\alpha^m}_{=:\, s^p_{m,\alpha}} = c_p \Big( u_\alpha^{m\top} W^m v_\alpha^m - u_\alpha^{m\top} \sum_{p'} c_{p'} W^{p'} v_\alpha^m \Big) = c_p \Big( s^m_\alpha - \sum_{p'} c_{p'} s^{p'}_{m,\alpha} \Big).$ (13)

The student singular vectors $u_\alpha^p$, $v_\alpha^p$ have been shown to align to those of the current teacher, $u_\alpha^m$, $v_\alpha^m$, early in learning [Atanasov et al., 2021]. After training on both teachers, the student can therefore be fully described in terms of the coefficients $(s^p_{m,\alpha})_{m,\alpha}$ in the basis spanned by the α-singular vectors of both teachers, decoupling from the other singular value dimensions. If all singular vectors across the two teachers m are pairwise orthogonal, these projections form an orthogonal basis. The components outside of this projection have finite error in all contexts and therefore exponentially decay to 0 [Braun et al., 2022].

This reduction allows us to study learning in a simpler and more interpretable model. For the case M = P = 2, which we consider here for simplicity, we can therefore reinterpret each α-component of the model as vectors $w^1 \equiv (s^1_{m=1,\alpha},\, s^1_{m=2,\alpha})$, $w^2 \equiv (s^2_{m=1,\alpha},\, s^2_{m=2,\alpha})$ in $\mathbb{R}^2$:

$y = c_1 w^1 + c_2 w^2 = c_1 \begin{pmatrix} s^1_{1,\alpha} \\ s^1_{2,\alpha} \end{pmatrix} + c_2 \begin{pmatrix} s^2_{1,\alpha} \\ s^2_{2,\alpha} \end{pmatrix},$ (14)

and redefine the context-dependent target vector $y^m$ accordingly.

The reduced model follows the update equations

$\tau_w \frac{d}{dt} w^p = c_p\,(y^m - y),$ (15)
$\tau_c \frac{d}{dt} c_p = w^{p\top} (y^m - y).$ (16)

Notably, both updates depend on the full error term $\varepsilon := y^m - y$, with both paths entering into y. The first equation moves the student in the direction of the current total misestimation of the active teacher, $\varepsilon$. The second equation changes the gating of the current path according to the alignment of the path $w^p$ with the current vectorial error, reflecting the contribution of the path to the mismatch.

In Fig. A.1, we simulate the models side by side and show that the reduced model matches the dynamics of the full model.

A.1.1. Reduction in terms of teacher row vectors

In the main text and the previous section, we consider a reduction that follows from projecting onto the eigenspace of the matrices. However, a similar reduction is possible by considering each row β independently and using the row vectors of the two teachers, $\{w^m_\beta\}_m$, as a basis for that row. As in the projection onto the eigenspace, the out-of-projection component of the students decays exponentially. We can then consider a single row of the teacher-student system to function like a mode α above, with a row of the student path p becoming $w^p = w^p_1 w^1 + w^p_2 w^2$, so that we can write

$y = c_1 w^1 + c_2 w^2 = c_1 \begin{pmatrix} w^1_{1,\beta} \\ w^1_{2,\beta} \end{pmatrix} + c_2 \begin{pmatrix} w^2_{1,\beta} \\ w^2_{2,\beta} \end{pmatrix},$ (17)

where we aggregate over rows β. This formulation only requires pairwise orthogonality between rows, $w^1_i \cdot w^2_i = 0$, to fully decouple the dynamics of the system, but does not extend as elegantly to deeper students or low-rank solutions.

Figure A.1: Simulation of dynamics of full and reduced model.


The equivalent reduced model effectively captures the dynamics of the full model in terms of loss (A1., B1.), gates (A2., B2.), and singular value magnitude (A3., B3.).

A.2. High-dimensional students learn slower

Figure A.2: High-dimensional students learn slower.


Gate change $\tau_c \frac{d}{dt}c$ as a function of teacher dimensionality/rank (i.e., number of non-zero singular values). Weight scaling is chosen such that input and output components take unit scale, $y_i = \mathcal{O}(1)$, $x_j = \mathcal{O}(1)$.

An intuition one might have for the model dynamics is that the weight matrices comprise more parameters and may therefore respond more slowly under gradient descent. Here, we discuss the formal conditions under which this is indeed the case.

For simplicity, we consider a one-path model $y = c\,W x$, with $y \in \mathbb{R}^{d_\text{out}}$, $x \in \mathbb{R}^{d_\text{in}}$. We now choose the scaling $y_i = \mathcal{O}(1)$, $x_j = \mathcal{O}(1)$ on input and output, which means that entries do not depend on the respective vector dimensionalities. This is a natural assumption, for example, if $y_i$ are label indicators and $x_j$ are pixel brightness values of an image. Recall the model loss $\mathcal{L}_\text{task} = \frac{1}{2}(y^m - y)^2$.

Then, the batch-averaged c-gradient reads

$\tau_c \frac{d}{dt} c = -\nabla_c \mathcal{L}_\text{task} = \Big\langle \mathrm{Tr}\big[(y^m - y)\, x^\top W^\top\big] \Big\rangle_B$
$= \mathrm{Tr}\big[W^m x x^\top W^\top - c\, W x x^\top W^\top\big]_B$
$\approx \mathrm{Tr}\big[W^m W^\top\big] - c\,\mathrm{Tr}\big[W W^\top\big]$
$= \mathrm{Tr}\big[U S^m V^\top V S U^\top\big] - c\,\mathrm{Tr}\big[U S V^\top V S U^\top\big]$
$= \mathrm{Tr}\big[S^m S\big] - c\,\mathrm{Tr}\big[S^2\big] = \sum_\alpha^{\min(d_\text{out},\, d_\text{in})} \big(s^m_\alpha s_\alpha - c\, s_\alpha^2\big).$

Here, we used the Gaussian i.i.d. initialization of x to take an expectation for large batch size ($\langle x x^\top \rangle_B = I_{d_\text{in}}$), the SVD of $W = U S V^\top$, orthonormalization of singular vectors ($U^\top U = I$, $V^\top V = I$), and the cyclic property of the trace Tr. We also assumed in the fourth row that the student singular vectors have already undergone Silent Alignment [Atanasov et al., 2021] to match the teachers, as discussed in the main text.

From the last row, we observe that the updates to c tend to scale with the number of nonzero singular values, i.e. the rank of the teachers.

For the fan-in scaling $W_{ij} \sim \mathcal{N}(0, \sigma^2/d_\text{in})$ that is compatible with $x_j = \mathcal{O}(1)$, $y_i = \mathcal{O}(1)$, we have $s_\alpha = \mathcal{O}(\sigma)$ independent of dimensionality (Marcenko-Pastur distribution). If teachers and students are initialized according to this scaling, students will respond relatively more slowly compared to gates as their dimension $\min(d_\text{in}, d_\text{out})$ grows, since the student gradient Eq. (11), or its reduced form Eq. (13), does not involve a sum that scales with dimensionality.
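The rank dependence of the gate drive can be checked numerically with a few lines. The sketch below (not the authors' simulation code) constructs an aligned teacher-student pair with singular values of order σ, as implied by the fan-in scaling, and evaluates the final expression of the derivation.

    import numpy as np

    def gate_drive(rank, d=256, sigma=1.0, c=0.5, rng=None):
        """Batch-averaged drive on the gate, Tr[W^T (W^m - c W)], for an aligned
        teacher/student pair whose singular values are O(sigma)."""
        rng = rng or np.random.default_rng(0)
        s_teacher = np.abs(rng.normal(0.0, sigma, size=rank))
        s_student = np.abs(rng.normal(0.0, sigma, size=rank))
        U, _ = np.linalg.qr(rng.standard_normal((d, rank)))   # shared (aligned) modes
        V, _ = np.linalg.qr(rng.standard_normal((d, rank)))
        W_m = U @ np.diag(s_teacher) @ V.T
        W   = U @ np.diag(s_student) @ V.T
        return np.trace(W.T @ (W_m - c * W))   # = sum_a (s^m_a s_a - c s_a^2)

    # gate_drive(r) grows roughly linearly in r, matching the last line of the derivation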

A.3. Representational cost in under-specified model

Figure A.3: Redundant paths become inactive when representation is costly.


Gating variables as in Fig. 1B, but with more paths than teacher tasks (P = 4 > M = 2). A. Only under a representational cost on the weights do students that are preferentially aligned due to the random initialization specialize to the M = 2 teachers, whereas the other gates decay to 0. B. Without representational cost, the model uses multiple paths per task and thus has multiple gates active at the same time for a single teacher.

In our initial Eq. (1, NTA), we have introduced a model in which the number of paths P of the architecture matches the number of available tasks. What happens if this match is not present? In the under-specified case P < M, the model's limited expressiveness hinders adaptation. It is, however, not clear what will happen in the over-specified case P > M. In the absence of any regularization on the weights $W^p$, the model will not devote only $P' = M < P$ paths to match the task. Rather, in accordance with the theory by Shi et al. [2022], the model will in general split its paths over the available tasks. This behavior is due to the absence of a “representational cost” of having multiple paths active at the same time. We find that this effect is reduced only when introducing an L2-regularization $\frac{\lambda_W}{2 P d_\text{in}} \sum_{ijp} \big(W^p_{ij}\big)^2$ with $\lambda_W = 0.77$. This term additionally penalizes weight magnitude, leading to the decay of inactive paths. We show this behavior in Fig. A.3.

A.4. Gating-based solution emerges in a fully-connected network

As described in the main text, we induce the flexible gating regime in a fully-connected network by applying regularization and a faster learning rate to the second layer and compare to a forgetful (unregularized) fully-connected network. Details of the sorting procedure used to identify and visualize this gating-based solution are described in Appendix B.4.1. We find that the flexible fully-connected network exhibits behavior that is qualitatively similar to the flexible NTA (Fig. A.4). By visualizing the second layer at the end of training on different tasks in the flexible regime, we observe that the network upweights single units in each row (Fig. A.5A, A.6A), which act as gates for the first layer rows. Instead, in the forgetful regime, the network has multiple upweighted units in each row and the units do not change behavior across different tasks, exhibiting a lack of task-specificity and gating-like behavior (Fig. A.5B, A.6B).

Figure A.4: Gating-based solution emerges in a fully-connected network with regularized second layer weights and a faster second layer learning rate.


A. Loss during learning. B. The dynamics of the sorted gating variables. C. Alignment between the sorted students in the first layer and the teachers. D. Total alignment between the entire set of teachers and students. E. The norm of the gradient of the first (black) and second (red) hidden layer of the fully-connected network.

Figure A.5: Regularized, but not non-regularized, fully-connected network specializes single neurons in each row as ‘gates’ per task and exhibits specificity based on task.


Visualization of the unsorted second hidden layer of the flexible (left) and forgetful (right) fully-connected network for a single seed.

Figure A.6: Second hidden layer of regularized, but not non-regularized, fully-connected network exhibits clear task-specific gating across the diagonals of the matrix.


Visualization of the sorted second hidden layer of the flexible (left) and forgetful (right) fully-connected network averaged over 10 seeds.

A.4.1. Model specialization as a function of block size, gate timescale, and regularization strength in fully-connected network

We perform two hyperparameter searches to illustrate the joint effects of block length, second layer learning rate, and regularization strength on the fully-connected network, similar to the searches we perform on the NTA model in the main text. We run the fully-connected network on each set of hyperparameters and report the total alignment of sorted teachers and students at the end of training as an overall measure of specialization, fixing all other hyperparameters (see Appendix B.5 for more details). We observe that the same components of block length, fast second layer learning rate, and regularization are important for specialization to emerge in the fully-connected network (Fig. A.7), just as in the NTA model.

Figure A.7: Model specialization emerges as a function of block length, second hidden layer learning rate, and regularization strength in fully-connected network.


The colorbar indicates total alignment (cosine similarity) between all sets of students and teachers considered collectively.

A.4.2. Possibility of emergence of gating in two-layer network

In this work, we have analyzed a linear architecture with an explicit architectural gating structure,

$y_i = \sum_{p=1}^{P}\sum_{j=1}^{d_\text{in}} c_p\, W^p_{ij}\, x_j = \big[(c \odot W)\, x\big]_i,$ (18)

where we have notationally stacked the students $W := (W^p)_{p=1\ldots P}$, such that $W \in \mathbb{R}^{P \times d_\text{out} \times d_\text{in}}$, $c \in \mathbb{R}^{P}$. $\odot$ here denotes the Hadamard (element-wise) product.

Prior work has considered deep linear networks [Saxe et al., 2019, Atanasov et al., 2022, Shi et al., 2022, Braun et al., 2022], which led us to study such fully-connected network in the main text,

$y_i = \sum_{j=1}^{d_\text{in}} \sum_{h=1}^{d_\text{hid}} W^2_{ih}\, W^1_{hj}\, x_j = \big[W^2 W^1 x\big]_i.$ (19)

The gated network considers gating as a multiplicative effect on each output unit i (or equivalently, input unit j), whereas the deep network invokes an additional all-to-all weighted summation. As such, Eq. (19) does not incorporate any modular structure, yet formally resembles Eq. (18). To further analyze how these settings connect, we decompose the tasks as $W^m_{ij} = \sum_\alpha U^m_{i\alpha} s^m_\alpha V^m_{\alpha j}$, and write the student layer matrices as SVDs $W^2_{ih} = \sum_\alpha U^2_{i\alpha} s^2_\alpha V^2_{\alpha h}$, $W^1_{hj} = \sum_\alpha U^1_{h\alpha} s^1_\alpha V^1_{\alpha j}$. The overall model Eq. (19) then reads

$y_i = \sum_j^{d_\text{in}} \sum_h^{d_\text{hid}} \sum_\alpha^{\min(d_\text{out},\, d_\text{hid})} \sum_{\alpha'}^{\min(d_\text{hid},\, d_\text{in})} U^2_{i\alpha}\, s^2_\alpha\, V^2_{\alpha h}\, U^1_{h\alpha'}\, s^1_{\alpha'}\, V^1_{\alpha' j}\, x_j = \big[U^2 S^2 V^2 U^1 S^1 V^1 x\big]_i.$

If the minimum of the weight dimensions, $\min(d_\text{out}, d_\text{hid}, d_\text{in})$, exceeds the number of task modes $\sum_m \mathrm{rank}(W^m)$, it is possible to choose/learn $V^2$ and $U^1$ such that the second layer singular values $s^2_\alpha$ effectively take the role of the gates $c_p$, whereas the first layer encodes the student task representations. If we put aside the question of learnability and only ask about expressivity, this argument shows that a gating structure can emerge as a subset of a two-layer network. For the fully-connected model in the main text, $d_\text{hid} = 2\, d_\text{out}$, giving the network the capacity to learn and remember solutions for both teachers.

A.5. Adaptation speed

In this section, we derive the change in model output that is induced by the change in parameters depending on their configuration, thereby describing the model’s adaptation speed.

Neural tangent kernel

Here, we briefly review the Neural Tangent Kernel (NTK, [Jacot et al., 2020]), which we then use to directly describe the adaptation speed of the output $y(t)$. For a vector-valued model $y \in \mathbb{R}^{d_\text{out}}$ parameterized by a flattened parameter vector $\theta_k = \mathrm{flatten}(W^p_{ij}, c_p)_k$, the output evolves as

$\frac{d}{dt} y_i = \sum_k \frac{d y_i}{d\theta_k}\frac{d\theta_k}{dt} = -\sum_k \frac{d y_i}{d\theta_k}\frac{d\mathcal{L}}{d\theta_k} = -\sum_{k,j} \underbrace{\frac{d y_i}{d\theta_k}\frac{d y_j}{d\theta_k}}_{\mathrm{NTK}_{ij}}\,\frac{d\mathcal{L}}{d y_j},$

where we used the chain rule and that the parameters update according to gradient descent, $\frac{d\theta_k}{dt} = -\frac{d\mathcal{L}}{d\theta_k}$, and have set the learning rate to 1 for simplicity.

This object can be understood as a matrix operating on the output space, $\mathrm{NTK} = \frac{dy}{d\theta}\frac{dy}{d\theta}^\top \in \mathbb{R}^{d_\text{out}\times d_\text{out}}$, where the inner product represents the sum across parameters $\sum_k$ in the expression above. For the reduced model Eq. (2) with $\theta = \mathrm{flatten}(c_p, w^p_i)_{p=1}^{P}$, we readily get

$\frac{d y_i}{d c_p} = w^p_i, \qquad \frac{d y_i}{d w^p_j} = \delta_{ij}\, c_p,$

where δij is the Kronecker delta.

We then arrive at

$\mathrm{NTK} = \sum_p \big(c_p c_p + w^p w^{p\top}\big),$ (20)

where we adopt standard matrix notation and understand $c_p c_p \equiv c_p c_p\, I_{d_\text{out}}$ as proportional to the identity matrix $I_{d_\text{out}} \in \mathbb{R}^{d_\text{out}\times d_\text{out}}$.
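Eq. (20) follows by inserting the derivatives above into the definition of the NTK:

$\mathrm{NTK}_{ij} = \sum_p \bigg(\frac{d y_i}{d c_p}\frac{d y_j}{d c_p} + \sum_k \frac{d y_i}{d w^p_k}\frac{d y_j}{d w^p_k}\bigg) = \sum_p \bigg(w^p_i w^p_j + \sum_k \delta_{ik}\,c_p\,\delta_{jk}\,c_p\bigg) = \sum_p \Big(\big(w^p w^{p\top}\big)_{ij} + c_p^2\,\delta_{ij}\Big).$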

Figure A.8: Specialized students and gates accelerate adaptation.


Heatmaps of the contributions of different terms of the Neural Tangent Kernel (NTK) to the dot product $\frac{d}{dt}y \cdot \varepsilon$, depending on the specialization of the weight vectors $w^1$ (blue), $w^2$ (orange), of which three pairs corresponding to different degrees of specialization are shown here (pairs are formed by vectors that are symmetric along the diagonal). $c_1$, $c_2$ are scaled so that their sum lies on the dashed black line (given by L1 regularization). A. shows the total contribution of both terms of Eq. (20) combined, B. isolates the contribution from the $w^p w^{p\top}$ term, and C. displays the contribution from the $c_p c_p$ term. Dashed lines indicate possible solutions.

Accelerated adaptation through specialized weights and selective gates

To study the accelerated adaptation of the loss $\mathcal{L}_\text{task} = \frac{1}{2}(y^m - y)^2$, we use the Neural Tangent Kernel of the architecture, which directly describes the dynamics of the model output. To this end, we study how the model output y changes in response to a block switch $y^m = (1, 0) \to (0, 1)$, entailing an error $\varepsilon = (-1, 1)$ that drives a change in model output. To see how this change reduces the loss, we calculate its alignment with the error term as

$-\frac{d}{dt}\mathcal{L}_\text{task} = \varepsilon^\top \frac{d}{dt} y = \varepsilon^\top\, \mathrm{NTK}\;\varepsilon = \varepsilon^\top \sum_p \big(c_p c_p + w^p w^{p\top}\big)\, \varepsilon = \sum_p \Big(\lVert\varepsilon\rVert^2 c_p^2 + \big(w^p \cdot \varepsilon\big)^2\Big).$

The NTK reveals that the change in output y is accelerated through two contributions, which we illustrate in Fig. A.8: First, we observed that the student-teacher alignment entering $w^p\cdot\varepsilon$ increases towards the flexible regime. We note that this acceleration does not by itself require unique student-teacher alignment (such that no two students match the same teacher); it is the joint effect of asymmetric gates which further facilitates unique specialization. Second, selective gates accelerate adaptation because a sparser gate vector with the same L1 norm tends to have a larger sum of squares $\sum_p c_p^2$, which enters the NTK. These factors coincide in the flexible regime.

A.6. Larger blocks enable faster specialization

Figure A.9: Larger blocks enable faster specialization.


A. Two trajectories of differing block length (faint: $\tau_B$; opaque: $\tau_B' = 2\tau_B$) of two students (blue, orange) in the teacher subspace, as in Fig. 4. B. Student dynamics in an approximate loss landscape early in learning. Subpanels 1–4 are time points in a simulation. Background: active context m. The linear first-order loss does not lead to separation, as a block switch will exactly reverse any changes to specialization. In contrast, the curvature from the second-order term enables students to accumulate an initial advantage in specialization. C. Like B., but the loss is in terms of the specialization variable $\bar{w}$. The effective loss over blocks depends on block size $\tau_B$: the longer $\tau_B$ is, the longer the students have to fall down the landscape in B. The first-order term, corresponding to infinitely short blocks, does not prefer specialization.

Here, we calculate how the block length affects the specialization $\bar{w} := w^1 - w^2$ of the students, given equal total learning time t. To do so, we consider periods early in learning where the students have not yet specialized, $\bar{w} \approx 0$, and the gates consequently are indifferent, $\bar{c} = 0$. We then consider one period of the task, i.e. back-and-forth block switches $a \to b$, $b \to a$, lasting a total of $T = 2\tau_B$. We analyze the limit of small block sizes and ask how $\bar{w}$ changes over T: for short blocks, the model is time-reversal symmetric, i.e. any change $\frac{d}{dt}\bar{w}\,\tau_B$ during $a \to b$ is exactly undone during $b \to a$. We therefore calculate the second-order effects on $\bar{w}$ to analyze the dependence on $\tau_B$, where we only make an assumption on the approximate directions of w and ε that apply early in learning, but leave their scale general.

Second-order derivatives for weight and gates

To prepare, we first calculate the second-order derivatives of the updates, which we will need in what follows. By application of the product rule, we obtain for the weights $w^p$

$\frac{d}{dt} w^p = c_p\,\varepsilon, \qquad \frac{d^2}{dt^2} w^p = \frac{d c_p}{dt}\,\varepsilon + c_p\,\frac{d\varepsilon}{dt} = \big(w^p\cdot\varepsilon\big)\,\varepsilon - c_p\,\mathrm{NTK}\,\varepsilon = \big(w^p\cdot\varepsilon\big)\,\varepsilon - c_p \sum_{p'}\big(c_{p'} c_{p'} + w^{p'} w^{p'\top}\big)\,\varepsilon,$

where we use $\frac{d}{dt}\varepsilon = \frac{d}{dt}\big(y^m - y\big) = -\frac{d}{dt} y = -\mathrm{NTK}\,\varepsilon$ within a block.

For the gates cp, we get

$\frac{d}{dt} c_p = w^p\cdot\varepsilon, \qquad \frac{d^2}{dt^2} c_p = c_p\,\varepsilon\cdot\varepsilon - w^p\cdot\mathrm{NTK}\,\varepsilon.$

We then take the differences of the derivatives, $\bar{w} = w^1 - w^2$ and $\bar{c} = c_1 - c_2$, to get by linearity

$\frac{d}{dt}\bar{w} = \bar{c}\,\varepsilon \approx 0,$
$\frac{d^2}{dt^2}\bar{w} = \big(\bar{w}\cdot\varepsilon\big)\,\varepsilon - \bar{c}\sum_p\big(c_p c_p + w^p w^{p\top}\big)\,\varepsilon \approx \big(\bar{w}\cdot\varepsilon\big)\,\varepsilon,$

where we assume that the gates have approximately equal values, $\bar{c} \approx 0$, when no specialization has taken place yet.

Effect on specialization

The sum of the second derivatives after having seen a switch $a \to b$ and $b \to a$ is subsequently:

$\frac{d^2}{dt^2}\bar{w}\Big|_{a\to b} + \frac{d^2}{dt^2}\bar{w}\Big|_{b\to a} = \big(\bar{w}\cdot\varepsilon_b\big)\,\varepsilon_b + \big(\bar{w}\cdot\varepsilon_a\big)\,\varepsilon_a = \big(\bar{w}\cdot\varepsilon\big)\,\big(\varepsilon_b - \varepsilon_a\big) = 2\,\big(\bar{w}\cdot\varepsilon\big)\,\varepsilon,$

where we use that $\bar{w}$ is the component parallel to the error signal, implying $(\bar{w}\cdot\varepsilon_b)\,\varepsilon_b = (\bar{w}\cdot\varepsilon_a)\,\varepsilon_a$, for errors $\varepsilon_b, \varepsilon_a = \pm\frac{1}{2}C\,(1, -1)$ with some constant C depending on weight magnitude, and small $\bar{w}(0)$. We herein assumed that the errors between blocks only differ by a sign, $\varepsilon := \varepsilon_b = -\varepsilon_a$. In particular, we neglect the change in $c_p$ for simplicity. Note that this is a good approximation only if the block size $\tau_B$ is much shorter than the timescale $\tau_c$. While this limits quantitative predictions of this approximation for the setting we consider, we expect that it identifies the qualitative mechanism.

Introducing the period T, which is double the block length ($\tau_B = T/2$, spanning two blocks of length $\tau_B$), we get

$\bar{w}(T = 2\tau_B) = \bar{w}(0) + 2\cdot\frac{1}{2}\cdot 2\,\big(\bar{w}\cdot\varepsilon\big)\,\varepsilon\;\tau_B^2,$ (21)

where the factor $\frac{1}{2}$ in Eq. (21) is due to being at second order in the Taylor expansion of the update. Setting $\bar{w}(0) = 0$, this means that the cumulative change over two periods together lasting $t = 2T$ (i.e. spanning two pairs of blocks, totaling four blocks) is

$\bar{w}(t = 2\ \text{periods of}\ T = 2\times 2\tau_B) = 2\times 2\cdot\frac{1}{2}\cdot 2\,\big(\bar{w}\cdot\varepsilon\big)\,\varepsilon\;\tau_B^2 = 2\cdot 2\,\big(\bar{w}\cdot\varepsilon\big)\,\varepsilon\;\tau_B^2.$ (22)

In contrast, doubling the block size, $\tau_B' = 2\tau_B$ and thus $T' = 2T$, but running for the same amount of time t (i.e., two blocks of double the block length $\tau_B'$), gives

$\bar{w}(t = 1\ \text{period of}\ T' = 1\times 2\tau_B') = 1\times 2\cdot\frac{1}{2}\cdot 2\,\big(\bar{w}\cdot\varepsilon\big)\,\varepsilon\,\big(2\tau_B\big)^2 = 4\cdot 2\,\big(\bar{w}\cdot\varepsilon\big)\,\varepsilon\;\tau_B^2,$

which is twice as large as the short-block version. This explains why larger blocks can lead to faster specialization.

Mechanical analogy

The importance of the quadratic term for specialization can be understood through a mechanical analogy that treats the weights as particles: gradient flow corresponds to the dynamics of particles undergoing overdamped Newtonian motion in a potential. To this end, we consider the proxy loss potential $\mathcal{L}_n$ that induces the respective first- and second-order time dynamics described by Eq. (21) when considering gradient flow, $\tau_w \frac{d}{dt}w = -\nabla_w \mathcal{L}_n(w)$. The resulting loss potential $\mathcal{L}_n$ is polynomial and grows as $\sim w$ (first order) and $\sim w^{3/2}$ (second order), as can be verified by plugging in the solutions $w(t) \sim t$ and $w(t) \sim t^2$, respectively.

To first order, block changes result in exactly opposite gradients, which revert any changes from the previous block due to the lack of momentum effects in gradient flow (Fig. A.9B, dotted line). In contrast, the quadratic term in the time dynamics can be understood as resulting from an effective non-linear loss potential, breaking this time-reversal symmetry between blocks (Fig. A.9B, dashed line). Note that the sum $\mathcal{L}_1 + \mathcal{L}_2$ may still be monotonic.

When aggregating this effect over many block changes, it gives rise to an effective loss potential $\tilde{\mathcal{L}}$ for the specialization variable $\bar{w}$ (Fig. A.9C). As the first-order terms cancel out over blocks, the effective loss potential does not have a preferred specialized configuration. When including second-order terms, however, the preferred state becomes specialized. Moreover, the speed of this specialization depends on the block timescale $\tau_B$: the larger $\tau_B$, the further the particles increase their advantage down the non-linear loss potential in Fig. A.9B, which manifests as a steeper and thereby faster double-well loss potential in Fig. A.9C.

A.7. Exact solutions under the condition of symmetry

A.7.1. Solving the differential equations under the condition of symmetric specialization

The multiplicatively-coupled dynamics allow for emergent specialization of the students for one of the teachers. Due to the regularization on $c$ as well as the orthogonality of the teachers, these dynamics are competitive, meaning that the students will, after a number of task switches, each specialize for a different teacher. We saw before that the most rapid adaptation occurs when the system is in this state where each teacher is matched by exactly one student. We formalize this state more generally under a condition of symmetric specialization, in which $\bar{w} = \bar{w}_1 = w^{p=1}_{m=1} - w^{p=2}_{m=1} = w^{p=2}_{m=2} - w^{p=1}_{m=2} = \bar{w}_2$. In this section, we present exact solutions to the learning dynamics of this system under this condition. We assume without loss of generality that path $p$ of the NTA student aligns with teacher $m = p$. We will see that these solutions closely match the adaptation dynamics of the full NTA model.

The learning updates of the two specialization components are $\frac{d\bar{w}_1}{dt} = \bar{c}\,\varepsilon_1$ and $\frac{d\bar{w}_2}{dt} = -\bar{c}\,\varepsilon_2$. The symmetric specialization condition is therefore inherently preserved when $\varepsilon_1 = -\varepsilon_2$. We assess the accuracy of both relationships in Fig. A.10.

Figure A.10: Accuracy of assumptions used in calculating the exact solutions, only one of which is required.

A. Accuracy of the assumption $\bar{\varepsilon} = \varepsilon_1 = -\varepsilon_2$. B. Accuracy of the assumption $\bar{w} = w^{p=1}_{m=1} - w^{p=2}_{m=1} = w^{p=2}_{m=2} - w^{p=1}_{m=2}$.

A.7.2. Learning occurs dominantly within but also outside of the specialization subspace over the course of a single block

The differential equation Eq. (9) then results from dividing the two update equations for $\frac{d\bar{c}}{dt}$ and $\frac{d\bar{w}}{dt}$. We can then solve this differential equation to obtain the relationship

$\tau_c\,\bar{c}^2 = 2\,\tau_w\,\bar{w}^2 + C$ (23)

where $C$ is an integration constant, which we determine by plugging in the initial conditions $\bar{c} = \bar{w} = 1$ to represent the theoretical ideal of the flexible regime.
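
As a sanity check on Eq. (23), note that $\tau_c\,\bar{c}^2 - 2\,\tau_w\,\bar{w}^2$ is conserved by symmetric reduced dynamics of the form $\tau_w \frac{d\bar{w}}{dt} = \bar{c}\,\bar{\varepsilon}$, $\tau_c \frac{d\bar{c}}{dt} = 2\,\bar{w}\,\bar{\varepsilon}$, for any error trajectory $\bar{\varepsilon}(t)$. The minimal sketch below (ours; the error signal is an arbitrary illustrative function, and the time constants are the reduced-model values of Table 2) verifies this numerically; recall that in the full model the underlying symmetry assumptions hold only approximately (Fig. A.10).

```python
# Numerical check that tau_c*c_bar^2 - 2*tau_w*w_bar^2 is invariant under the
# symmetric reduced dynamics, for an arbitrary (made-up) error trajectory.
import jax.numpy as jnp

tau_w, tau_c, dt = 5.0, 0.7, 1e-4          # reduced-model values from Table 2
w_bar, c_bar, t = 1.0, 1.0, 0.0            # "flexible" initial conditions used above

invariant = lambda w, c: tau_c * c**2 - 2.0 * tau_w * w**2
print(invariant(w_bar, c_bar))

for _ in range(20_000):                    # integrate up to t = 2
    e = 0.5 * jnp.cos(3.0 * t) * jnp.exp(-t)   # arbitrary illustrative error signal
    w_bar, c_bar = (w_bar + dt * c_bar * e / tau_w,
                    c_bar + dt * 2.0 * w_bar * e / tau_c)
    t += dt

print(invariant(w_bar, c_bar))             # unchanged up to O(dt) integration error
```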

A.7.3. Learning in the orthogonal component

As highlighted by the fact that the analytical solution is not traversed in full in Fig. 4E, learning also occurs outside of the specialization subspace. This learning can be characterized by the co-specialization that is characteristic of the forgetful regime,

$\bar{\bar{w}} := \frac{\left(w^{p=1}_{m=1} - w^{p=1}_{m=2}\right) + \left(w^{p=2}_{m=1} - w^{p=2}_{m=2}\right)}{2}.$ (24)

Fig. A.11 shows learning along both components $\bar{w}$ and $\bar{\bar{w}}$, the two error components $\varepsilon_1$ and $\varepsilon_2$, as well as the separation of the gates $\bar{c}$, over the course of a single block beginning from the same initial conditions used in the differential equation.

Figure A.11: Learning occurs within and outside of the specialization subspace.

A. First component of the error following a task switch from task A to task B for different values of τc. B. Second component of the error across the same timeframe. C. Adaptation of weight matrices in the specialization space. D. Orthogonal component of learning that measures adaptation of both teachers for the current task. E. Gate change in the specialization subspace.

A.8. Gated model generalizes to perform task and subtask composition

Figure A.12: Gated model generalizes to compositional tasks.

A. Task composition consists of new tasks that sum sets of teachers previously encountered. B. Subtask composition consists of new tasks that concatenate alternating rows of sets of teachers previously encountered. Loss (C., D.), gating activity (E., F.), and student-teacher alignment (G., H.) of models on generalization to task composition (top) and subtask composition (bottom).

As stated in the main text, we consider two settings to evaluate whether the gate layer can recombine previous knowledge for compositional generalization. We first train the NTA model with three paths on three teachers A, B, and C individually, and then test the network on task composition (Fig. A.12A) or subtask composition (Fig. A.12B). Task composition proceeds in the same way as our standard setup.

In subtask composition, to allow our model the possibility to compose not only tasks (i.e., gate the entire student matrix Wp), but also subtasks (individual rows of Wp), we increase the expressiveness of our gating layer by using an independent gate for each neuron (or row of Wp) in the student hidden layer and allow gradient descent to update these gates individually. We call this the per-neuron gating version of our model.

In principle, the per-neuron gating NTA has $P \cdot d_{\text{out}}$ independent paths modulated by gates. Thus, in order to study whether specialization and gating occur for each teacher, we sort the $P \cdot d_{\text{out}}$ paths into $P$ paths of size $d_{\text{out}}$. We do this by computing the cosine similarity between each row of the first-layer student weights $W$ and the rows of the teachers $W^*$. We then sort the rows of the first layer to align with the rows of the teachers that they best match, such that we can identify their respective students and visualize student-teacher alignment (Fig. A.12H). Additionally, we permute the gating layer $c$ to match this sorting. We take the mean of the sorted gates for each student to visualize the task-specific gating (Fig. A.12F).

We find that both the per-student and per-neuron gating NTA models can solve the task and subtask generalization settings, respectively, and maintain their learned specialization after transitioning to compositional settings (Fig. A.12C–H). We additionally observe that the gating variables learn to appropriately match the latent structure of the generalization tasks by turning off the non-contributing gate and evenly weighting the two “on” gates. We again observe the rapid adaptation of the flexible NTA to compositional tasks compared to the forgetful regime.

A.9. Non-orthogonal teachers

Figure A.13: Robustness to relaxing orthogonality between teachers.

A. Illustration of changing teacher cosine similarity. B. Adaptation speed as measured by the loss after a block switch (black) and student specialization (gray), both as a function of the teacher similarity. 0 represents the orthogonal case studied in the main text.

Figure A.14: Flexible NTA successfully specializes to underlying teachers even when trained on non-orthogonal tasks.

A. Tasks are created by adding different pairs of teachers, such that each task is non-orthogonal to every other task. B. Loss during learning. C. Gating variables learn to appropriately match the latent structure of the tasks. D. Students learn to specialize to teacher components, despite the non-orthogonality of the tasks.

In the main text, we have worked with the assumption that different tasks are approximately orthogonal to permit our theoretical analysis. This assumption holds for randomly generated teachers when the input dimension is high. In simulations, we implemented this condition by constructing the corresponding rows of the teachers to be orthogonal, $w^{*1}_i \cdot w^{*2}_i = 0$. Still, it is unclear what happens when the teachers are not even approximately orthogonal. We investigate this question empirically in Fig. A.13 and find that specialization decays gracefully as the orthogonality assumption is relaxed.
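
A minimal sketch of such a construction (ours; dimensions follow Table 2) orthogonalizes the corresponding rows of two randomly drawn teachers with a row-wise Gram-Schmidt step:

```python
# Construct two teachers whose corresponding rows are orthogonal,
# via row-wise Gram-Schmidt on randomly drawn matrices (illustrative sketch).
import jax.numpy as jnp
from jax import random

def orthogonal_teachers(key, d_out=10, d_in=20):
    k1, k2 = random.split(key)
    W1 = random.normal(k1, (d_out, d_in))
    W2 = random.normal(k2, (d_out, d_in))
    proj = (jnp.sum(W2 * W1, axis=1, keepdims=True)
            / jnp.sum(W1 * W1, axis=1, keepdims=True)) * W1
    return W1, W2 - proj      # row i of the second teacher is now orthogonal to row i of the first

W1, W2 = orthogonal_teachers(random.PRNGKey(0))
print(jnp.sum(W1 * W2, axis=1))   # ~0 for every row
```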

We also design a set of three non-orthogonal tasks using three orthogonal teachers, where each task is created by adding a different pair of the teachers (Fig. A.14A). Thus, every task has some similarity with (and is non-orthogonal to) every other task. We find that the flexible NTA can successfully solve these tasks and identify their underlying latent teacher structure, learning to specialize to and gate all teacher components that comprise the overall set of tasks (Fig. A.14B–D).

A.9.1. Experiments on fashionMNIST dataset
Figure A.15:

NTA quickly adapts across fashionMNIST for (left) an orthogonal sorting based on upper-to-lower items of clothing and (right) a correlated sorting for warm-to-cold weather clothing. The panels show (top) accuracy on the test set and (bottom) activity of the gates. We show mean and standard error with 10 seeds.

We explicitly compare the network’s performance on two different versions of fashionMNIST based on tasks that might appear in a real-world setting. The original fashionMNIST dataset has items sorted roughly by order of commonality, with label 0 assigned to T-shirts and label 9 assigned to ankle boots. We generate two different permutations of these labels representing other real-world sortings of the items that share different amounts of structure with the original. The close-to-orthogonal ordering sorts the clothing from upper to lower body and orders the labels 0, 2, 4, 6, 8, 1, 3, 5, 7, 9. The ordering with more shared structure represents warm-to-cold-weather clothing and orders the labels 0, 1, 5, 3, 7, 6, 2, 4, 8, 9. The results show that the stereotypical NTA-like task-switching behavior and specialization emerge in both settings at a similar speed, despite baseline performance being higher on the task with shared structure (Fig. A.15).
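
The sketch below (ours) illustrates one way to build these two variants; the orderings are as listed above, while the exact remapping convention used in the experiments is an assumption on our part:

```python
# Remap fashionMNIST labels according to the two orderings described above
# (illustrative sketch; the remapping direction is our assumption).
import jax.numpy as jnp

upper_to_lower = jnp.array([0, 2, 4, 6, 8, 1, 3, 5, 7, 9])   # close-to-orthogonal ordering
warm_to_cold   = jnp.array([0, 1, 5, 3, 7, 6, 2, 4, 8, 9])   # ordering with shared structure

def remap_labels(labels, ordering):
    # position of each original class within the new ordering
    inverse = jnp.argsort(ordering)
    return inverse[labels]

labels = jnp.array([0, 9, 4, 1])              # example original labels
print(remap_labels(labels, upper_to_lower))   # [0, 9, 2, 5]
```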

A.10. Few-shot adaptation in the low sample rate regime

Figure A.16: Few-shot adaptation after block switches.

Like Fig. 2A in the main text, but with coarsely discretized time to examine the adaptation after a single sample. As this drastically reduces signal-to-noise ratio, we average over 100 samples. Markers indicate a single step of gradient descent on one sample.

In the main text, we have considered the case of large batch size (or, equivalently, small time discretization), which allows taking a sample average when passing to the theoretical, equivalent model. This averaging reduces noise in the gradient signal stemming from random samples, so it is unclear whether learning is still possible when the sample rate is low. As theoretical analysis is challenging in this case, we investigate it empirically in Fig. A.16 and find that the qualitative phenomenon is preserved even if only a single sample ($B = 1$) is used for every gradient update.

B. Technical details

B.1. Notation

Table 1:

Overview of notation used throughout the paper.

Symbol : Description
$x_b \in \mathbb{R}^{d_\text{in}}$, $b = 1,\dots,B$ : input sample of a batch
$y_b \in \mathbb{R}^{d_\text{out}}$, $b = 1,\dots,B$ : model output
$y^m \in \mathbb{R}^{d_\text{out}}$, $m = 1,\dots,M$ or $a, b, \dots$ : target label in task $m$
$\varepsilon = y^m - y \in \mathbb{R}^{d_\text{out}}$ : prediction error
$c_p$, $p = 1,\dots,P$ : gates for each pathway $p$
$W^p \in \mathbb{R}^{d_\text{out}\times d_\text{in}}$ : student weights for each pathway $p$
$W^{m*} \in \mathbb{R}^{d_\text{out}\times d_\text{in}}$, $m = 1,\dots,M$ or $a, b, \dots$ : teacher weights for each task $m$
$w^p \equiv w^p_\alpha \in \mathbb{R}^2$ : 2D vector for reduced model (for each teacher singular value $\alpha$)
$W_{ij} = \sum_{\alpha}^{\min(d_\text{out},\, d_\text{in})} U_{i\alpha}\, s_\alpha\, V_{\alpha j}$ : singular value decomposition of weights
$\tau_w = \eta_w^{-1}$, $\tau_c = \eta_c^{-1}$ : parameter time scale (inverse learning rate)
$\tau_B$ : block length
$\mathcal{L} = \mathcal{L}_\text{task} + \mathcal{L}_\text{reg}$ : loss
$\bar{w}_1 = w^{p=1}_{m=1} - w^{p=2}_{m=1}$ : specialization for teacher 1
$\bar{w}_2 = w^{p=2}_{m=2} - w^{p=1}_{m=2}$ : specialization for teacher 2
$\bar{w} = \tfrac{1}{2}\left(\bar{w}_1 + \bar{w}_2\right)$ : overall specialization
$\bar{c} = c_1 - c_2$ : separation of gates
$\bar{\bar{w}} = \left(w^{p=1}_{m=1} - w^{p=1}_{m=2} + w^{p=2}_{m=1} - w^{p=2}_{m=2}\right)/2$ : unspecialized learning

B.2. Hyperparameters

We perform the gradient calculations and the simulation of gradient flow using the JAX framework and make our implementation publicly available at https://github.com/aproca/neural_task_abstraction.
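
For orientation, the following is a minimal sketch (our own, not the released implementation) of a single Euler step of the simulated gradient flow for the gated model $y = \sum_p c_p W^p x$, with separate time constants for weights and gates; the default values follow the first column of Table 2, and the regularizers of Appendix B.3 would simply be added to the loss.

```python
# One Euler step of gradient flow for the gated teacher-student model
# (illustrative sketch; see the repository above for the actual implementation).
import jax
import jax.numpy as jnp

def task_loss(params, x, y_target):
    W, c = params["W"], params["c"]              # W: (P, d_out, d_in), c: (P,)
    y = jnp.einsum("p,pij,bj->bi", c, W, x)      # y = sum_p c_p W^p x, batched over inputs
    return 0.5 * jnp.mean(jnp.sum((y_target - y) ** 2, axis=-1))

@jax.jit
def euler_step(params, x, y_target, dt=0.001, tau_w=1.3, tau_c=0.03):
    # weights and gates share one loss but integrate with different time constants
    g = jax.grad(task_loss)(params, x, y_target)
    return {"W": params["W"] - dt / tau_w * g["W"],
            "c": params["c"] - dt / tau_c * g["c"]}
```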

Our hyperparameters for all experimental settings are listed in the tables below. We use $\mathcal{L}_\text{norm-L1}$ for most experiments, except for generalization to subtask composition and the flexible fully-connected network experiments, where we use $\mathcal{L}_\text{norm-L2}$ (see Appendix B.3 for a discussion of this choice). For the cases where we induce the forgetful regime as an experimental control, we use the same hyperparameters, set the regularization to 0 ($\lambda_\text{nonneg} = \lambda_\text{norm-L1} = \lambda_\text{norm-L2} = 0$), and may or may not adjust the learning rate of the gating layer. Differences in hyperparameters between the main model and the control are denoted in the tables below as ‘main / control.’

Table 2:

Hyperparameters.

Hyperparameter Task specialization (Fig. 2) Task composition (Fig. 3,A.12) Subtask composition (Fig. 3,A.12) Reduced model (Fig. 4) Fully-connected network (Fig. 6,A.4,A.5,A.6) MNIST (Fig. 7)
P 2 3 3 2 2 2
M 2 3 3 2 2 2
din 20 20 20 1 20 64
dhid 20
dout 10 6 6 2 10 10
λnonneg 0.091 / 0 0.5 / 0 0.023 / 0 0.091 / 0 0.2 / 0 0.5 / 0
λnorm-L1 0.456 / 0 1.25 / 0 0 0.455 / 0 0 0.25 / 0
λnorm-L2 0 0 0.011 / 0 0 0.1 / 0 0
τw 1.3 0.2 0.2 5 0.06 10
τc 0.03 / 1.3 0.03 0.005 0.7 0.01 0.005 / 10
batch size 200 200 200 200 200 100
seeds 10 10 10 1 10 10
number of blocks n 20 30 30 17 30 10
τB 1 1 1 1 1 1
dt 0.001 0.001 0.01 0.001 0.01 0.001
Table 3:

Hyperparameters II.

Hyperparameter NTA hyperparameter search (Fig. 5) Fully-connected hyperparameter search (Fig. A.7) Task switching (Fig. 8) Non-orthogonal tasks (Fig. A.14) Non-orthogonal teachers (Fig. A.13)
P 2 2 2 3 2
M 2 2 2 3 2
din 20 20 20 20 20
dhid 20
dout 10 10 10 6 10
λnonneg 0.5 0.23 0.18 / 0 0.33 0
λnorm-L1 1.25 0 0.36 / 0 0.83 0
λnorm-L2 0 0.11 0 0 0.5
τw 0.1 0.04 0.07 0.05 0.016
τc 0.005 0.01 0.01 0.03 0.016
batch size 200 200 200 200 200
seeds 10 10 10 10 1
number of blocks n 7 20 30 50 10
τB 1 1 1 1 1
dt 0.001 0.01 0.01 0.001 0.01
Table 4:

Hyperparameters III.

Hyperparameter Full vs. reduced model (Fig. A.1) Slow high-d students (Fig. A.2) Redundant paths (Fig. A.3) Few-shot adaptation (Fig. A.16) fashionMNIST (Fig. A.15)
P 2 2 4 2 2
M 2 2 2 2 2
din 20 / 1 30 20 20 64
dout 10 / 2 30 10 10 10
λnonneg 0.091 0.091 0.194 / 0.545 0.091 0.5 / 0
λnorm-L1 0.455 0.455 0.968 / 2.727 0 0
λnorm-L2 0 0 0 0.455 0.25 / 0
τw 1.3 0.5 1.3 1 10
τc 0.03 / 0.06 0.1 0.03 0.01 0.005 / 10
batch size 200 200 200 1 100
seeds 1 10 1 10 10
number of blocks n 20 20 20 6 10
τB 1 1 1 1 1
dt 0.001 0.001 0.001 0.02 0.001

B.3. Regularization

We use a combined regularizer that is motivated by biological constraints on our model parameters. The regularizers alleviate the underspecification of the solution space of our linear model and facilitate symmetry breaking to allow the model to specialize different components, while not forcing specialization (Fig. B.1). Here, we detail the effect of these regularizing terms. To this end, recall the definition of the reduced model, Eq. (14)

$y = c_1 w_1 + c_2 w_2.$
Nonnegative neural activity

The gating variables steer the model output multiplicatively. Biologically, such an interaction is mediated by a firing rate, an inherently positive variable. Computationally, this has implications for the solution space: with random initialization, almost all configurations of $w_1$ and $w_2$ form a basis of $\mathbb{R}^2$. By definition, this means that there will always be two coefficients $c_1, c_2$ that yield the correct solution. Nonnegativity constrains this set to lie in the positive quadrant of the 2D space. In particular,

$\mathcal{L}_\text{nonneg} = \sum_{p=1}^{P} \max\left(0, -c_p\right)$
Alleviating invariance via competition

Even in the desired specialized configuration, the model is invariant under

$(c_p,\, W^p) \mapsto (a\,c_p,\ W^p / a)$

for any nonzero scalar $a$. We hence bound the norm of the vector $c = (c_1, c_2)$.

We consider two regularizers, based on the L1 and L2 norms:

$\mathcal{L}_\text{norm-L1} = \tfrac{1}{2}\left(\lVert c \rVert_1 - 1\right)^2$

$\mathcal{L}_\text{norm-L2} = \tfrac{1}{2}\left(\lVert c \rVert_2 - 1\right)^2$

$\mathcal{L}_\text{norm-L1}$ encourages sparsity in the gates, which is beneficial when there are few active gating variables, as in the NTA model (with a single teacher active in each task). However, if we consider cases with potentially many active gating variables, as in the per-neuron gating NTA (see Appendix A.8) or a deep fully-connected network (or even many teachers active at once), favoring sparsity restricts the expressivity of the model. In these cases, we instead use $\mathcal{L}_\text{norm-L2}$. In practice, both regularizers facilitate specialization robustly across many settings.
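
A compact sketch of the combined regularizer as we read the equations above (ours; the exact functional forms in the released code may differ):

```python
# Combined gate regularizer: nonnegativity plus a penalty pulling the gate
# norm (L1 or L2) towards 1 (illustrative sketch of Appendix B.3).
import jax.numpy as jnp

def reg_loss(c, lam_nonneg, lam_norm, use_l1=True):
    nonneg = jnp.sum(jnp.maximum(0.0, -c))               # penalize negative gate activity
    norm = jnp.sum(jnp.abs(c)) if use_l1 else jnp.linalg.norm(c)
    return lam_nonneg * nonneg + lam_norm * 0.5 * (norm - 1.0) ** 2
```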

Applying nonnegativity and norm regularization together has the effect of inducing competition between gating variables. While there is a solution where gates are equal in magnitude (as shown in Fig. B.1), deviating from this solution while minimizing regularization loss will cause one gate to increase in magnitude and the other to decrease. Thus, this competitive effect facilitates symmetry-breaking in the gates.

Finally, we note that although these regularizers facilitate symmetry breaking through competition, they simultaneously allow for compositionality such that multiple gates can be active at once. We show this in several experiments, namely our studies of nonorthogonal tasks (Appendix A.9), compositional generalization (Appendix A.8), and fully-connected networks (Appendix A.4).

Figure B.1: The effect of regularization on gating variables.

Regularization encourages competition between gates while preventing degeneracy of solutions. Importantly, regularization does not force gating variables to be specialized, as illustrated by the red ×. This holds for the two regularizers we consider, A. $\mathcal{L}_\text{norm-L1}$ and B. $\mathcal{L}_\text{norm-L2}$.

B.4. Description of metrics used across experiments

Student-teacher alignment

We compute a metric of alignment between each student and each teacher to determine whether students are specializing and, if so, to which teachers they specialize. We do this by computing the similarity between each student $W^p$ and teacher $W^{m*}$. More specifically, we take the mean of the cosine similarities between corresponding student and teacher row vectors. We then sort each student and its gate $c_p$ to the teacher with which it has the highest cosine similarity.

Total alignment

We compute a metric of total alignment of the network students and teachers to evaluate overall specialization. After computing student-teacher alignment and sorting each student to its respective teacher, we concatenate all students and teachers and compute the overall cosine similarity between the set of students and teachers.
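
A minimal sketch of both metrics (ours; ties and the exact assignment procedure may be handled differently in the released code):

```python
# Student-teacher alignment (mean row-wise cosine similarity) and total
# alignment after assigning each student to its best-matching teacher.
import jax.numpy as jnp

def row_cosine(A, B):
    A = A / (jnp.linalg.norm(A, axis=1, keepdims=True) + 1e-8)
    B = B / (jnp.linalg.norm(B, axis=1, keepdims=True) + 1e-8)
    return jnp.mean(jnp.sum(A * B, axis=1))              # mean over corresponding rows

def student_teacher_alignment(students, teachers):
    """students: (P, d_out, d_in); teachers: (M, d_out, d_in)."""
    sim = jnp.array([[row_cosine(s, t) for t in teachers] for s in students])  # (P, M)
    return sim, jnp.argmax(sim, axis=1)                  # similarity matrix, assigned teacher per student

def total_alignment(students, teachers, assignment):
    s = jnp.concatenate([students[p].ravel() for p in range(students.shape[0])])
    t = jnp.concatenate([teachers[a].ravel() for a in assignment])
    return jnp.dot(s, t) / (jnp.linalg.norm(s) * jnp.linalg.norm(t) + 1e-8)
```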

B.4.1. Description of sorting performed in per-neuron NTA and deep fully-connected network

In the case of more expressive models, such as the per-neuron NTA and the fully-connected network, there are $P \cdot d_\text{out}$ and $d_\text{out} \times P \cdot d_\text{out}$ independent paths modulated by gates, respectively. Thus, in order to study whether specialization and gating occur for each teacher, we sort these into $P$ paths. To do this, we compute the cosine similarity between each row of the first layer $W$ and each row of the teachers $W^*$. We then sort the rows of the first layer to align with the rows of the teachers that they best match.

Additionally, we permute the second layer to match this sorting. In the per-neuron NTA, this corresponds to the scalar gates multiplied with each row. In the fully-connected network, this corresponds to the columns of the second layer. Finally, we take the mean of the sorted gates (the set of $d_\text{out}$ columns for the fully-connected network) for each student to visualize teacher-specific gating.
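
The sorting itself can be sketched as follows (ours; group sizes are assumed to come out balanced, which need not hold exactly in practice):

```python
# Sort the rows of the first layer (and their gates) by best-matching teacher row.
import jax.numpy as jnp

def sort_paths_to_teachers(W1, teachers, gates):
    """W1: (n_paths, d_in) first-layer rows; teachers: (M, d_out, d_in); gates: (n_paths,)."""
    t_rows = teachers.reshape(-1, teachers.shape[-1])                 # (M*d_out, d_in)
    Wn = W1 / (jnp.linalg.norm(W1, axis=1, keepdims=True) + 1e-8)
    tn = t_rows / (jnp.linalg.norm(t_rows, axis=1, keepdims=True) + 1e-8)
    best = jnp.argmax(Wn @ tn.T, axis=1)      # best-matching teacher row for each student row
    order = jnp.argsort(best)                 # group student rows teacher by teacher
    teacher_of_row = best[order] // teachers.shape[1]
    return W1[order], gates[order], teacher_of_row
```

The sorted gates can then be averaged within each teacher group to visualize teacher-specific gating, as described above.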

B.5. Hyperparameter search

NTA

For the two hyperparameter searches we perform, we run the NTA model on each set of hyperparameters and report the total alignment of concatenated teachers and students at the end of training as an overall measure of specialization. We fix all other hyperparameters.

When varying the gate learning rate and block length, we fix the regularization strengths to $\lambda_\text{nonneg} = 0.5$, $\lambda_\text{norm-L1} = 1.25$. When varying the regularization strength, we fix the gate learning rate ratio $\tau_w/\tau_c = 20$. The regularization strength $\lambda$ is multiplied separately for each type of regularizer such that $\lambda_\text{nonneg} = 5\lambda/3$ and $\lambda_\text{norm-L1} = 25\lambda/6$.

Fully-connected network

We also perform two hyperparameter searches on the fully-connected network. We run the fully-connected network on each set of hyperparameters and report the total alignment of sorted teachers and students at the end of training as an overall measure of specialization. We fix all other hyperparameters.

When varying second layer learning rate and block length, we fix the regularization strength to λnonneg=0.23, λnorm-L2=0.11. When varying regularization strength, we fix gate learning rate τW2/τW1=4. The regularization strength λ is multiplied separately for each type of regularizer such that λnonneg=10λ/11 and λnorm-L2=5λ/11.

B.6. Flexible fully-connected network

We randomly generate two orthogonal teachers $W^{m*} \in \mathbb{R}^{d_\text{out}\times d_\text{in}}$. We initialize our fully-connected networks to have two weight layers, $W^{(1)} \in \mathbb{R}^{2d_\text{out}\times d_\text{in}}$ and $W^{(2)} \in \mathbb{R}^{d_\text{out}\times 2d_\text{out}}$. We use a faster learning rate for the second layer and regularize it during training. We treat each weight $W^{(2)}_{ij}$ as a gate that enters the regularization terms described in Appendix B.3. Student-teacher alignment, total alignment, and gate sorting are then performed as described above.

B.7. MNIST

The original convolutional network features a single convolutional layer with three feature maps and kernel size four, followed by a MaxPool layer of kernel size 2, a ReLU nonlinearity, and fully-connected sigmoid, ReLU, and log-softmax layers of sizes 512, 64, and 10, respectively. The networks are trained using the cross-entropy loss with a one-hot encoding of the labels. The NTA portion is then trained on the representations of the final hidden layer with 64 units. Hyperparameters for the NTA portion are given in Table 2 for MNIST and Table 4 for fashionMNIST.

Footnotes

38th Conference on Neural Information Processing Systems (NeurIPS 2024).

Contributor Information

Kai Sandbrink, Exp. Psychology, Oxford; Brain Mind Institute, EPFL.

Jan P. Bauer, ELSC, HebrewU; Gatsby Unit, UCL.

Alexandra M. Proca, Department of Computing, Imperial College London.

Andrew M. Saxe, Gatsby Unit, UCL.

Christopher Summerfield, Exp. Psychology, Oxford.

Ali Hummos, Brain and Cognitive Sciences, MIT.

References

  1. Gershman Samuel J and Niv Yael. Learning latent structure: carving nature at its joints. Current Opinion in Neurobiology, 20(2):251–256, April 2010. ISSN 0959–4388. doi: 10.1016/j.conb.2010.02.008. URL https://www.sciencedirect.com/science/article/pii/S0959438810000309. [DOI] [PMC free article] [PubMed] [Google Scholar]
  2. Yu Linda Q., Wilson Robert C., and Nassar Matthew R.. Adaptive learning is structure learning in time. Neuroscience & Biobehavioral Reviews, 128:270–281, September 2021. ISSN 0149–7634. doi: 10.1016/j.neubiorev.2021.06.024. URL https://www.sciencedirect.com/science/article/pii/S0149763421002657. [DOI] [PMC free article] [PubMed] [Google Scholar]
  3. Castañón Santiago Herce, Cardoso-Leite Pedro, Altarelli Irene, Shawn Green C., Schrater Paul, and Bavelier Daphne. A mixture of generative models strategy helps humans generalize across tasks. bioRxiv, page 2021.02.16.431506, January 2021. doi: 10.1101/2021.02.16.431506. URL http://biorxiv.org/content/early/2021/02/16/2021.02.16.431506.abstract. [DOI] [Google Scholar]
  4. Bernardi Silvia, Benna Marcus K., Rigotti Mattia, Munuera Jérôme, Fusi Stefano, and Daniel Salzman C.. The geometry of abstraction in hippocampus and prefrontal cortex. Cell, 183(4): 954–967.e21, November 2020. ISSN 0092–8674. doi: 10.1016/j.cell.2020.09.031. URL https://www.ncbi.nlm.nih.gov/pmc/articles/PMC8451959/. [DOI] [PMC free article] [PubMed] [Google Scholar]
  5. Tafazoli Sina, Bouchacourt Flora M., Ardalan Adel, Markov Nikola T., Uchimura Motoaki, Mattar Marcelo G., Daw Nathaniel D., and Buschman Timothy J.. Building compositional tasks with shared neural subspaces. bioRxiv, 2024. doi: 10.1101/2024.01.31.578263. URL https://www.biorxiv.org/content/early/2024/03/22/2024.01.31.578263. Publisher: Cold Spring Harbor Laboratory _eprint: https://www.biorxiv.org/content/early/2024/03/22/2024.01.31.578263.full.pdf. [DOI] [PubMed] [Google Scholar]
  6. Flesch Timo, Balaguer Jan, Dekker Ronald, Nili Hamed, and Summerfield Christopher. Comparing continual task learning in minds and machines. Proceedings of the National Academy of Sciences, 115(44):E10313–E10322, October 2018. doi: 10.1073/pnas.1800755115. URL https://www.pnas.org/doi/full/10.1073/pnas.1800755115. Publisher: Proceedings of the National Academy of Sciences. [DOI] [PMC free article] [PubMed] [Google Scholar]
  7. Beukers Andre O., Collin Silvy H. P., Kempner Ross P., Franklin Nicholas T., Gershman Samuel J., and Norman Kenneth A.. Blocked training facilitates learning of multiple schemas. Communications Psychology, 2(1):1–17, April 2024. ISSN 2731–9121. doi: 10.1038/s44271-024-00079-4. URL https://www.nature.com/articles/s44271-024-00079-4. Publisher: Nature Publishing Group. [DOI] [PMC free article] [PubMed] [Google Scholar]
  8. McCloskey Michael and Cohen Neal J.. Catastrophic Interference in Connectionist Networks: The Sequential Learning Problem. In Bower Gordon H., editor, Psychology of Learning and Motivation, volume 24, pages 109–165. Academic Press, January 1989. doi: 10.1016/S0079-7421(08)60536-8. URL https://www.sciencedirect.com/science/article/pii/S0079742108605368. [DOI] [Google Scholar]
  9. Hadsell Raia, Rao Dushyant, Rusu Andrei A., and Pascanu Razvan. Embracing Change: Continual Learning in Deep Neural Networks. Trends in Cognitive Sciences, 24(12):1028–1040, December 2020. ISSN 1364–6613. doi: 10.1016/j.tics.2020.09.004. URL https://www.sciencedirect.com/science/article/pii/S1364661320302199. [DOI] [PubMed] [Google Scholar]
  10. Hummos Ali, del Río Felipe, Wang Brabeeba Mien, Hurtado Julio, Calderon Cristian B., and Yang Guangyu Robert. Gradient-based inference of abstract task representations for generalization in neural networks, 2024. URL https://arxiv.org/abs/2407.17356. _eprint: 2407.17356. [Google Scholar]
  11. Hummos Ali. Thalamus: a brain-inspired algorithm for biologically-plausible continual learning and disentangled representations. In The Eleventh International Conference on Learning Representations, 2023. URL https://openreview.net/forum?id=6orC5MvgPBK. [Google Scholar]
  12. Butz Martin V., Bilkey David, Humaidan Dania, Knott Alistair, and Otte Sebastian. Learning, planning, and control in a monolithic neural event inference architecture. Neural Networks, 117:135–144, September 2019. ISSN 0893–6080. doi: 10.1016/j.neunet.2019.05.001. URL https://www.sciencedirect.com/science/article/pii/S0893608019301339. [DOI] [PubMed] [Google Scholar]
  13. Lu Qihong, Hummos Ali, and Norman Kenneth A . Episodic memory supports the acquisition of structured task representations. bioRxiv, page 2024.05.06.592749, January 2024. doi: 10.1101/2024.05.06.592749. URL http://biorxiv.org/content/early/2024/05/07/2024.05.06.592749.abstract. [DOI] [Google Scholar]
  14. Schug Simon, Kobayashi Seijin, Akram Yassir, Wolczyk Maciej, Proca Alexandra Maria, Oswald Johannes Von, Pascanu Razvan, Sacramento Joao, and Steger Angelika. Discovering modular solutions that generalize compositionally. In The Twelfth International Conference on Learning Representations, 2024. URL https://openreview.net/forum?id=H98CVcX1eh. [Google Scholar]
  15. Steyvers Mark, Hawkins Guy E., Karayanidis Frini, and Brown Scott D.. A large-scale analysis of task switching practice effects across the lifespan. Proceedings of the National Academy of Sciences, 116(36):17735–17740, September 2019. doi: 10.1073/pnas.1906788116. URL https://www.pnas.org/doi/full/10.1073/pnas.1906788116. Publisher: Proceedings of the National Academy of Sciences. [DOI] [PMC free article] [PubMed] [Google Scholar]
  16. Miller E. K. and Cohen J. D.. An integrative theory of prefrontal cortex function. Annual Review of Neuroscience, 24:167–202, 2001. ISSN 0147–006X. doi: 10.1146/annurev.neuro.24.1.167. [DOI] [PubMed] [Google Scholar]
  17. Egner Tobias. Principles of cognitive control over task focus and task switching. Nature Reviews Psychology, 2(11):702–714, November 2023. ISSN 2731–0574. doi: 10.1038/s44159-023-00234-4. URL https://www.nature.com/articles/s44159-023-00234-4. Publisher: Nature Publishing Group. [DOI] [PMC free article] [PubMed] [Google Scholar]
  18. Sandbrink Kai and Summerfield Christopher. Modelling cognitive flexibility with deep neural networks. Current Opinion in Behavioral Sciences, 57:101361, June 2024. ISSN 2352–1546. doi: 10.1016/j.cobeha.2024.101361. URL https://www.sciencedirect.com/science/article/pii/S2352154624000123. [DOI] [PMC free article] [PubMed] [Google Scholar]
  19. Musslick Sebastian and Cohen Jonathan D.. Rationalizing constraints on the capacity for cognitive control. Trends in Cognitive Sciences, 25(9):757–775, September 2021. ISSN 1879–307X. doi: 10.1016/j.tics.2021.06.001. [DOI] [PubMed] [Google Scholar]
  20. Yang Guangyu Robert, Joglekar Madhura R., Francis Song H., Newsome William T., and Wang Xiao-Jing. Task representations in neural networks trained to perform many cognitive tasks. Nature Neuroscience, 22(2):297–306, February 2019. ISSN 1546–1726. doi: 10.1038/s41593-018-0310-2. URL https://www.nature.com/articles/s41593-018-0310-2. [DOI] [PMC free article] [PubMed] [Google Scholar]
  21. Kirkpatrick James, Pascanu Razvan, Rabinowitz Neil, Veness Joel, Desjardins Guillaume, Rusu Andrei A., Milan Kieran, Quan John, Ramalho Tiago, Grabska-Barwinska Agnieszka, Hassabis Demis, Clopath Claudia, Kumaran Dharshan, and Hadsell Raia. Overcoming catastrophic forgetting in neural networks. Proceedings of the National Academy of Sciences, 114(13):3521–3526, March 2017. doi: 10.1073/pnas.1611835114. URL https://www.pnas.org/doi/10.1073/pnas.1611835114. Publisher: Proceedings of the National Academy of Sciences. [DOI] [PMC free article] [PubMed] [Google Scholar]
  22. Masse Nicolas Y., Grant Gregory D., and Freedman David J.. Alleviating catastrophic forgetting using context-dependent gating and synaptic stabilization. Proceedings of the National Academy of Sciences, 115(44):E10467–E10475, October 2018. doi: 10.1073/pnas.1803839115. URL https://www.pnas.org/doi/10.1073/pnas.1803839115. Publisher: Proceedings of the National Academy of Sciences. [DOI] [PMC free article] [PubMed] [Google Scholar]
  23. Wang Dongkai and Zhang Shiliang. Contextual Instance Decoupling for Robust Multi-Person Pose Estimation. pages 11060–11068, 2022. URL https://openaccess.thecvf.com/content/CVPR2022/html/Wang_Contextual_Instance_Decoupling_for_Robust_Multi-Person_Pose_Estimation_CVPR_2022_paper.html. [Google Scholar]
  24. Driscoll Laura N., Shenoy Krishna, and Sussillo David. Flexible multitask computation in recurrent networks utilizes shared dynamical motifs. Nature Neuroscience, 27(7):1349–1363, July 2024. ISSN 1546–1726. doi: 10.1038/s41593-024-01668-6. URL https://www.nature.com/articles/s41593-024-01668-6. Publisher: Nature Publishing Group. [DOI] [PMC free article] [PubMed] [Google Scholar]
  25. Masse Nicolas Y., Rosen Matthew C., Tsao Doris Y., and Freedman David J.. Flexible cognition in rigid reservoir networks modulated by behavioral context, May 2022. URL https://www.biorxiv.org/content/10.1101/2022.05.09.491102v2. Pages: 2022.05.09.491102 Section: New Results. [Google Scholar]
  26. Jacobs Robert A., Jordan Michael I., and Barto Andrew G.. Task Decomposition Through Competition in a Modular Connectionist Architecture: The What and Where Vision Tasks. Cognitive Science, 15(2):219–250, 1991. doi: 10.1207/s15516709cog1502_2. URL https://onlinelibrary.wiley.com/doi/abs/10.1207/s15516709cog1502_2. _eprint: https://onlinelibrary.wiley.com/doi/pdf/10.1207/s15516709cog1502_2. [DOI] [Google Scholar]
  27. Jordan Michael I. and Jacobs Robert A.. Hierarchical Mixtures of Experts and the EM Algorithm. Neural Computation, 6(2):181–214, March 1994. ISSN 0899–7667. doi: 10.1162/neco.1994.6.2.181. URL https://doi.org/10.1162/neco.1994.6.2.181. _eprint: https://direct.mit.edu/neco/article-pdf/6/2/181/812708/neco.1994.6.2.181.pdf. [DOI] [Google Scholar]
  28. Tsuda Ben, Tye Kay M., Siegelmann Hava T., and Sejnowski Terrence J.. A modeling framework for adaptive lifelong learning with transfer and savings through gating in the prefrontal cortex. Proceedings of the National Academy of Sciences, 117(47):29872–29882, November 2020. doi: 10.1073/pnas.2009591117. URL https://www.pnas.org/doi/abs/10.1073/pnas.2009591117. Publisher: Proceedings of the National Academy of Sciences. [DOI] [PMC free article] [PubMed] [Google Scholar]
  29. Andreas Jacob, Rohrbach Marcus, Darrell Trevor, and Klein Dan. Neural module networks. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 39–48, 2016. [Google Scholar]
  30. Kirsch Louis, Kunze Julius, and Barber D.. Modular Networks: Learning to Decompose Neural Computation. November 2018. URL https://www.semanticscholar.org/paper/Modular-Networks%3A-Learning-to-Decompose-Neural-Kirsch-Kunze/0b50b9e103e19d87c2d30ed0d157d8379320ce6f. [Google Scholar]
  31. Goyal Anirudh, Lamb Alex, Hoffmann Jordan, Sodhani Shagun, Levine Sergey, Bengio Yoshua, and Schölkopf B.. Recurrent Independent Mechanisms. ArXiv, 2019. [Google Scholar]
  32. Geweke John. Interpretation and inference in mixture models: Simple MCMC works. Computational Statistics & Data Analysis, 51(7):3529–3550, April 2007. ISSN 0167–9473. doi: 10.1016/j.csda.2006.11.026. URL https://www.sciencedirect.com/science/article/pii/S0167947306004506. [DOI] [Google Scholar]
  33. Mittal Sarthak, Lamb Alex, Goyal Anirudh, Voleti Vikram, Shanahan Murray, Lajoie Guillaume, Mozer Michael, and Bengio Yoshua. Learning to Combine Top-Down and Bottom-Up Signals in Recurrent Neural Networks with Attention over Modules. arXiv:2006.16981 [cs, stat], November 2020. URL http://arxiv.org/abs/2006.16981. arXiv: 2006.16981. [Google Scholar]
  34. Krishnamurthy Yamuna, Watkins Chris, and Gaertner Thomas. Improving Expert Specialization in Mixture of Experts. arXiv preprint arXiv:2302.14703, 2023. [Google Scholar]
  35. Barry Martin and Gerstner Wulfram. Fast Adaptation to Rule Switching using Neuronal Surprise, September 2022. URL https://www.biorxiv.org/content/10.1101/2022.09.13.507727v1. Pages: 2022.09.13.507727 Section: New Results. [DOI] [PMC free article] [PubMed] [Google Scholar]
  36. Saxe Andrew M, McClelland James L, and Ganguli Surya. Exact solutions to the nonlinear dynamics of learning in deep linear neural networks. arXiv preprint arXiv:1312.6120, 2013. [Google Scholar]
  37. Saxe Andrew M., McClelland James L., and Ganguli Surya. A mathematical theory of semantic development in deep neural networks. Proceedings of the National Academy of Sciences, 116(23): 11537–11546, June 2019. doi: 10.1073/pnas.1820226116. URL https://www.pnas.org/doi/10.1073/pnas.1820226116. Publisher: Proceedings of the National Academy of Sciences. [DOI] [PMC free article] [PubMed] [Google Scholar]
  38. Saxe Andrew, Sodhani Shagun, and Lewallen Sam Jay. The Neural Race Reduction: Dynamics of Abstraction in Gated Networks. In Chaudhuri Kamalika, Jegelka Stefanie, Song Le, Szepesvari Csaba, Niu Gang, and Sabato Sivan, editors, Proceedings of the 39th International Conference on Machine Learning, volume 162 of Proceedings of Machine Learning Research, pages 19287–19309. PMLR, July 2022. URL https://proceedings.mlr.press/v162/saxe22a.html. [Google Scholar]
  39. Shi Jianghong, Shea-Brown Eric, and Buice Michael A.. Learning dynamics of deep linear networks with multiple pathways. Advances in Neural Information Processing Systems, 35:34064–34076, December 2022. ISSN 1049–5258. [PMC free article] [PubMed] [Google Scholar]
  40. Lee Jin Hwa, Mannelli Stefano Sarao, and Saxe Andrew. Why Do Animals Need Shaping? A Theory of Task Composition and Curriculum Learning. arXiv preprint arXiv:2402.18361, 2024. [Google Scholar]
  41. Deng Li. The mnist database of handwritten digit images for machine learning research. IEEE Signal Processing Magazine, 29(6):141–142, 2012. Publisher: IEEE. [Google Scholar]
  42. Xiao Han, Rasul Kashif, and Vollgraf Roland. Fashion-MNIST: a novel image dataset for benchmarking machine learning algorithms, August 2017. arXiv: cs.LG/1708.07747 [cs.LG]. [Google Scholar]
  43. Dohare Shibhansh, Fernando Hernandez-Garcia J., Lan Qingfeng, Rahman Parash, Rupam Mahmood A., and Sutton Richard S.. Loss of plasticity in deep continual learning. Nature, 632 (8026):768–774, August 2024. ISSN 1476–4687. doi: 10.1038/s41586-024-07711-7. URL https://www.nature.com/articles/s41586-024-07711-7. Publisher: Nature Publishing Group. [DOI] [PMC free article] [PubMed] [Google Scholar]
  44. Atanasov Alexander, Bordelon Blake, and Pehlevan C.. Neural Networks as Kernel Learners: The Silent Alignment Effect. ArXiv, October 2021. URL https://www.semanticscholar.org/paper/Neural-Networks-as-Kernel-Learners%3A-The-Silent-Atanasov-Bordelon/ccd3631a4509aac2d71c320a6ac677f311d94b05. [Google Scholar]
  45. Braun Lukas, Dominé Clémentine, Fitzgerald James, and Saxe Andrew. Exact learning dynamics of deep linear networks with prior knowledge. Advances in Neural Information Processing Systems, 35:6615–6629, 2022. URL https://proceedings.neurips.cc/paper_files/paper/2022/hash/2b3bb2c95195130977a51b3bb251c40a-Abstract-Conference.html. [PMC free article] [PubMed] [Google Scholar]
  46. Atanasov Alexander, Bordelon Blake, Sainathan Sabarish, and Pehlevan Cengiz. The onset of variance-limited behavior for networks in the lazy and rich regimes. December 2022. arXiv: 2212.12147 [stat.ML]. [Google Scholar]
  47. Jacot Arthur, Gabriel Franck, and Hongler Clément. Neural Tangent Kernel: Convergence and Generalization in Neural Networks. arXiv:1806.07572 [cs, math, stat], February 2020. URL http://arxiv.org/abs/1806.07572. arXiv: 1806.07572. [Google Scholar]
