Neural spiking for causal inference and learning [1]
Benjamin James Lansdell (Department of Bioengineering, University of Pennsylvania, Philadelphia, Pennsylvania, United States of America) and Konrad Paul Kording (Department of Neuroscience)
Date: 2023-04
When a neuron is driven beyond its threshold, it spikes. The fact that it does not communicate its continuous membrane potential is usually seen as a computational liability. Here we show that this spiking mechanism allows neurons to produce an unbiased estimate of their causal influence, and a way of approximating gradient descent-based learning. Importantly, neither activity of upstream neurons, which act as confounders, nor downstream non-linearities bias the results. We show how spiking enables neurons to solve causal estimation problems and that local plasticity can approximate gradient descent using spike discontinuity learning.
Despite significant research, models of spiking neural networks still lag behind artificial neural networks in terms of performance in machine learning and modeling cognitive tasks. Given this, we may wonder, why do neurons spike? A key problem that must be solved in any neural network is the credit assignment problem. That is, how does a neuron know its effect on downstream computation and rewards, and thus how it should change its synaptic weights to improve? Artificial neural networks solve this problem with the back-propagation algorithm. We are still seeking to understand how biological neural networks effectively solve this problem. In this work we show that the discontinuous, all-or-none spiking response of a neuron can in fact be used to estimate a neuron’s causal effect on downstream processes. Inspired by methods from econometrics, we show that the thresholded response of a neuron can be used to get at that neuron’s unique contribution to a reward signal, separating it from other neurons whose activity it may be correlated with. This proposal provides insights into a novel function of spiking that we explore in simple networks and learning tasks.
Here we propose that the spiking discontinuity is used by a neuron to efficiently estimate its causal effect. Once a neuron can estimate its causal effect, it can use this knowledge to calculate gradients and adjust its synaptic strengths. We show that this idea suggests learning rules that allow a network of neurons to learn to maximize reward, particularly in the presence of confounded inputs. We demonstrate the rule in simple models. The discontinuity-based method provides a novel and plausible account of how neurons learn their causal effect.
However, the key insight in this paper is that the story is different when comparing the average reward in times when the neuron barely spikes versus times when it almost spikes. The difference in the state of the network between these two cases is negligible; the only difference is that in one case the neuron spiked and in the other it did not. Any difference in observed reward can therefore only be attributed to the neuron's activity. In this way the spiking discontinuity may allow neurons to estimate their causal effect.
How else could neurons estimate their causal effect? Over a short time window, a neuron either does or does not spike. Comparing the average reward when the neuron spikes versus when it does not gives a confounded estimate of the neuron's effect. Because neurons are correlated, a given neuron spiking is associated with a different network state than that neuron not spiking, and it is this difference in network state that may account for an observed difference in reward, rather than the neuron's activity specifically. Simple correlations will give wrong causal estimates.
The gold-standard approach to causal inference is randomized perturbation. If a neuron occasionally adds an extra spike (or removes one), it could readily estimate its causal effect by correlating the extra spikes with performance. Such perturbations come at a cost, since the noise can degrade performance. This class of learning methods has been extensively explored [16–22]. However, despite important special cases [17, 19, 23], in general it is not clear how a neuron may know its own noise level. Thus we may wonder whether neurons can estimate their causal effect without random perturbations.
A key computational problem in both biological and artificial settings is the credit assignment problem [10]. When performance is sub-optimal, the brain needs to decide which activities or weights should be different. Credit assignment is fundamentally a causal estimation problem: which neurons are responsible for the bad performance, and not just correlated with bad performance? Solving such problems is difficult because of confounding: if a neuron of interest was active during bad performance, it could be that it was responsible, or it could be that another neuron whose activity is correlated with the neuron of interest was responsible. In general, confounding happens if a variable affects both another variable of interest and the performance. Even when a fixed stimulus is presented repeatedly, neurons exhibit complicated correlation structures [11–15], which confound a neuron's estimate of its causal effect. This prompts us to ask how neurons can solve causal estimation problems.
There are, of course, pragmatic reasons for spiking: spiking may be more energy efficient [ 6 , 7 ], spiking allows for reliable transmission over long distances [ 8 ], and spike timing codes may allow for more transmission bandwidth [ 9 ]. Yet, despite these ideas, we may still wonder if there are computational benefits of spikes that balance the apparent disparity in the learning abilities of spiking and artificial networks.
Most nervous systems communicate and process information utilizing spiking. Yet machine learning mostly uses artificial neural networks with continuous activities. Computationally, despite a lot of recent progress [ 1 – 5 ], it remains challenging to create spiking neural networks that perform comparably to continuous artificial networks. Instead, spiking is generally seen as a disadvantage—it is difficult to propagate gradients through a discontinuity, and thus to train spiking networks. This disparity between biological neurons that spike and artificial neurons that are continuous raises the question, what are the computational benefits of spiking?
We also want to know whether spiking discontinuity can estimate causal effects in deep neural networks. It effectively estimates the causal effect in a spiking neural network with two hidden layers (Fig 5B). We compare a network simulated with correlated inputs to one with uncorrelated inputs. The estimates of causal effect in the uncorrelated case, obtained using the observed-dependence estimator, provide an unbiased estimate of the true causal effect (blue dashed line). The spiking discontinuity estimate in the correlated-inputs case is indeed close to this unbiased value. In contrast, the observed-dependence estimator applied to the confounded inputs deviates significantly from the true causal effect. These results show spiking discontinuity can estimate causal effects in both wide and deep neural networks.
Furthermore, when considering a network's estimates as a whole, we can compare the vector of estimated causal effects to the true causal effects β (Fig 5A, bottom panels). The angle between these two vectors indicates how a learning algorithm will perform when using these estimates of the causal effect: if the estimate is used as a gradient, then any angle well below ninety degrees represents a descent direction in the reward landscape, and shifting parameters in this direction will lead to improvements. Here there is a more striking difference between the spiking discontinuity and observed-dependence estimators. Only for extremely high correlation values, or networks with a layer of one thousand neurons, does the spiking discontinuity estimator fail to produce estimates that are well aligned with the true causal effects. In contrast, the observed-dependence estimator is only well aligned with the true gradient for small networks and small correlation coefficients. This suggests the SDE provides a more scalable and robust estimator of causal effect for the purposes of learning in the presence of confounders.
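This alignment can be computed directly; a minimal sketch (names illustrative):

```python
import numpy as np

def alignment_deg(beta_hat, beta_true):
    """Angle (degrees) between estimated and true causal-effect vectors.
    Well below ninety means the estimate is a useful update direction
    for learning, even if its magnitude is off."""
    cos = beta_hat @ beta_true / (np.linalg.norm(beta_hat) * np.linalg.norm(beta_true))
    return np.degrees(np.arccos(np.clip(cos, -1.0, 1.0)))
```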
(A) Mean square error (MSE) as a function of network size and noise correlation coefficient, c. MSE is computed as squared difference from the true causal effect, where the true causal effect is estimated using the observed dependence estimator with c = 0 (unconfounded). Alignment between the true causal effects β and the estimated effects is the angle between these two vectors. (B) A two hidden layer neural network, with hidden layers of width 10. Plots show the causal effect of each of the first hidden layer neurons on the reward signal. Dashed lines show the observed-dependence estimator, solid lines show the spiking discontinuity estimator, for correlated and uncorrelated (unconfounded) inputs, over a range of window sizes p. The observed dependence estimator is significantly biased with confounded inputs.
Having validated spiking discontinuity-based causal inference in a small network, we investigate how well we can estimate causal effects in wider and deeper networks. First we investigate the effect of network width on performance. We simulate a single hidden layer neural network of varying width (Fig 5A; refer to the Methods for implementation details). The mean squared error in estimating causal effects shows an approximately linear dependence on the number of neurons in the layer, for both the observed-dependence and spiking discontinuity estimators. For low correlation coefficients, representing low confounding, the observed-dependence estimator has a lower error. However, once confounding is introduced, its error increases dramatically, varying over three orders of magnitude as a function of the correlation coefficient. In contrast, the spiking discontinuity error is roughly constant as a function of the correlation coefficient, except for the most extreme case of c = 0.99. This shows that over a range of network sizes and confounding levels, a spiking discontinuity estimator is robust to confounding.
(A) Parameters for the causal effect model, u, are updated based on whether the neuron is driven marginally below or above threshold. (B) Applying the rule to estimate β for two sample neurons shows convergence within 10 s (red curves). Error bars represent standard error of the mean over 50 simulations. (C) Convergence of the observed dependence (left) and spike discontinuity (right) learning rules in an unconfounded network (c = 0.01). Observed dependence converges more directly to the bottom of the valley, while spiking discontinuity learning trajectories meander more, as the initial estimate of the causal effect takes more inputs to update. (D,E) Convergence of the observed dependence (D) and spiking discontinuity (E) learning rules in a confounded network (c = 0.5). Right panels: error as a function of time for individual traces (blue curves) and the mean (black curve). With confounding, learning based on the observed dependence converges slowly or not at all, whereas spike discontinuity learning succeeds.
When applied to the toy LIF network, the online learning rule (Fig 4A) estimates β over the course of seconds (Fig 4B). The estimated β is then used to update weights to maximize expected reward in an unconfounded network (uncorrelated noise, c = 0.01). In these simulations updates to β are made when the neuron is close to threshold, while updates to w_i are made for all time periods of length T. Learning exhibits trajectories that initially meander while the estimate of β settles down (Fig 4C). When a confounded network (correlated noise, c = 0.5) is used, spike discontinuity learning exhibits similar performance, and it converges faster than learning based on the observed dependence, which sometimes fails to converge at all due to the bias in its gradient estimate (Fig 4D and 4E). Thus the spike discontinuity learning rule allows a network to be trained even in the presence of confounded inputs.
We just showed that the spiking discontinuity allows neurons to estimate their causal effect. We can implement this as a simple learning rule that illustrates how knowing the causal effect impacts learning. We derive an online learning rule that estimates β and, if needed, the linear model parameters. That is, let u_i be the vector of parameters required to estimate β_i for neuron i. For the piece-wise constant reward model that is u_i = [γ_i, β_i] and for the piece-wise linear model it is u_i = [γ_i, β_i, α_ri, α_li]. Then the learning rule takes the form of a delta rule:

$$u_i \leftarrow u_i + \eta \,\big(R - u_i^\top a_i\big)\, a_i, \tag{6}$$

where η is a learning rate and a_i are drive-dependent terms (see Methods for details and the derivation). Using this learning rule to update u_i, along with the relation between β_i and the reward gradient (Eq (7)), allows us to update the weights according to a stochastic gradient-like update rule in order to maximize reward. This allows us to use the causal effect to estimate the reward gradient (Fig 3E and 3F), and thus gives a local learning rule that approximates gradient descent.
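To make this concrete, a minimal sketch of the update in Eq (6), assuming the delta-rule form given above and the piece-wise constant features a_i = [1, H_i], so that u_i = [γ_i, β_i]. The gating of updates to near-threshold trials (noted above) is omitted here, and all numbers are illustrative:

```python
import numpy as np

def update_estimate(u, a, reward, eta=0.01):
    # One step of Eq (6): nudge u so that u . a better predicts the reward.
    return u + eta * (reward - u @ a) * a

# Illustrative use: features a = [1, H], parameters u = [gamma, beta].
rng = np.random.default_rng(1)
u = np.zeros(2)
for _ in range(5000):
    h = float(rng.random() < 0.5)              # spike indicator (unconfounded here)
    r = 1.0 + 0.3 * h + rng.normal(0.0, 0.1)   # gamma = 1.0, true beta = 0.3
    u = update_estimate(u, np.array([1.0, h]), r)
print(u)  # approaches [1.0, 0.3]
```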
The same simple model can be used to estimate how the quality of spike discontinuity estimates depends on network parameters. To investigate the robustness of this estimator, we systematically vary the weights, w_i, of the network. This allows for an exploration of the SDE's performance in a range of network states. SDE works better when activity is fluctuation-driven and at a lower firing rate (Fig 3C). Thus spiking discontinuity is most applicable in irregular but synchronous activity regimes [26]. Over this range of network weights, spiking discontinuity is less biased than the observed dependence (Fig 3D). Overall, corrected estimates based on spiking considerably improve on the naive implementation.
(A) Estimates of causal effect (black line) using a constant spiking discontinuity model (difference in mean reward when the neuron is within a window p of threshold) reveal confounding for high p values and highly correlated activity. p = 1 represents the observed dependence, revealing the extent of confounding (dashed lines). Curves show mean plus/minus standard deviation over 50 simulations. (B) The linear model is unbiased over larger window sizes and more highly correlated activity (high c). (C) Relative error in estimates of causal effect over a range of weights (1 ≤ w_i ≤ 20) shows lower error with higher coefficient of variability (CV; top panel) and lower error with lower firing rate (bottom panel). (D) Over this range of weights, spiking discontinuity estimates are less biased than the naive observed dependence. (E,F) Approximation to the reward gradient overlaid on the expected reward landscape. The white vector field corresponds to the true gradient field; the black fields correspond to the spiking discontinuity (E) and observed dependence (F) estimates. The observed dependence is biased by correlations between neurons 1 and 2: changes in reward caused by neuron 1 are also attributed to neuron 2.
Simulating this simple two-neuron network shows how a neuron can estimate its causal effect using the SDE (Fig 3A and 3B). To show how it removes confounding, we implement both the piece-wise constant and piece-wise linear models for a range of window sizes p. When p is large, the piece-wise constant model corresponds to the biased observed-dependence estimator, while small p values approximate the SDE estimator and result in an unbiased estimate (Fig 3A). The window size p determines the variance of the estimator, as expected from theory [33]. The piece-wise linear model, Eq (5), is more robust to confounding (Fig 3B), allowing larger p values to be used. Thus the linear correction that is the basis of many RDD implementations [28] allows neurons to more readily estimate their causal effect. This linear dependence on the maximal voltage of the neuron may be approximated by plasticity that depends on calcium concentration; such implementation issues are considered in the discussion.
To summarize the idea: for a neuron to apply spiking discontinuity estimation, it simply must track whether it was close to spiking, whether it spiked or not, and observe the reward signal. Comparing the reward between time periods when the neuron almost reaches its firing threshold and periods when it just reaches its threshold then gives an unbiased estimate of its own causal effect (Fig 2D and 2E). Below we gain intuition about how the estimator works and how it differs from the naive estimate.
This approach relies on some assumptions. First, a neuron assumes its effect on the expected reward can be written as a function of Z_i that has a discontinuity at Z_i = θ, such that, in the neighborhood of Z_i = θ, the function can be approximated by either its zeroth-order (piece-wise constant) or first-order (piece-wise linear) Taylor expansion. This approach also assumes that the input variable Z_i is itself a continuous variable. This is the case in the simulations explored here.
A one-order-higher model of the reward adds a linear correction, resulting in the piece-wise linear model of the reward function:

$$\mathbb{E}[R \mid Z_i, H_i] = \gamma_i + \beta_i H_i + \alpha_{li}\,(Z_i - \theta)(1 - H_i) + \alpha_{ri}\,(Z_i - \theta)H_i, \tag{5}$$

where γ_i, α_li and α_ri are the linear regression parameters. This higher-order model can allow for larger window sizes p, and thus a lower variance estimator. Both the piece-wise constant and piece-wise linear estimators of the causal effect are explored in the simulations below.
Instead, we can estimate β_i only for inputs that place the neuron close to its threshold. That is, let Z_i be the maximum integrated neural drive to the neuron over the trial period. If θ is the neuron's spiking threshold, then a maximum drive above θ results in a spike, and below θ results in no spike. The neural drive used here is the leaky, integrated input to the neuron, which obeys the same dynamics as the membrane potential except without a reset mechanism. By tracking the maximum drive attained over the trial period, we can track when inputs placed the neuron close to its threshold, and marginally super-threshold inputs can be distinguished from well-above-threshold inputs, as required for SDE (Fig 2C). Let p be a window size within which we call the integrated input Z_i 'close' to threshold; then the SDE estimator of β_i is:

$$\hat{\beta}_i = \mathbb{E}[R \mid \theta \le Z_i < \theta + p] - \mathbb{E}[R \mid \theta - p \le Z_i < \theta]. \tag{4}$$
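As a minimal sketch, both estimators can be written in a few lines, assuming arrays z (maximum drives Z_i) and r (rewards) collected over many trials; the names and the use of a least-squares line fit for Eq (5) are illustrative:

```python
import numpy as np

def sde_constant(z, r, theta, p):
    """Piece-wise constant SDE (Eq 4): difference in mean reward just
    above versus just below threshold, within a window of size p."""
    above = (z >= theta) & (z < theta + p)
    below = (z >= theta - p) & (z < theta)
    return r[above].mean() - r[below].mean()

def sde_linear(z, r, theta, p):
    """Piece-wise linear SDE (Eq 5): fit a separate line on each side of
    the threshold and take the gap between the intercepts at theta."""
    above = (z >= theta) & (z < theta + p)
    below = (z >= theta - p) & (z < theta)
    slope_r, intercept_r = np.polyfit(z[above] - theta, r[above], 1)
    slope_l, intercept_l = np.polyfit(z[below] - theta, r[below], 1)
    return intercept_r - intercept_l
```

Setting p large in sde_constant recovers the biased observed-dependence comparison, while shrinking p trades bias for variance, as in Fig 3A.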
(A) Graphical model describing the neural network. Neuron H_i receives input X, which contributes to drive Z_i. If the drive is above the spiking threshold, then H_i is active. The activity contributes to reward R. Though not shown, this relationship may be mediated through downstream layers of a neural network and complicated interactions with the environment. From neuron H_i's perspective, the activity of the other neurons H, which also contribute to R, is unobserved. For clarity, other neurons' Z variables have been omitted from this graph. Unobserved dependencies X → R and Z_i → R are shown here; even though they are not part of the underlying dynamical model, such dependencies in the graphical model may still exist, as discussed in the text. (B) The reward may be tightly correlated with other neurons' activity, which act as confounders. However, any discontinuity in reward at the neuron's spiking threshold can only be attributed to that neuron. The discontinuity at the threshold is thus a meaningful estimate of the causal effect (left). The effect of a spike on a reward function can be determined by considering data when the neuron is driven to be just above or just below threshold (right). (C) This is judged by looking at the neural drive to the neuron over a short time period. Marginal sub- and super-threshold cases can be distinguished by considering the maximum drive throughout this period. (D) Schematic showing how spiking discontinuity operates in a network of neurons. Each neuron contributes to the output and observes a resulting reward signal. Learning takes place at the end of windows of length T. Only neurons whose input drive brought them close to, or just above, threshold (gray bar in voltage traces; compare neuron 1 to neuron 2) update their estimate of β. (E) Model notation.
To give intuition into how this confounding problem manifests in a neural learning setting, consider how a neuron can estimate its causal effect. A basic model is that a neuron's effect on reward can be estimated from the following piece-wise constant model of the reward function:

$$\mathbb{E}[R \mid H_i] = \gamma_i + \beta_i H_i. \tag{2}$$

That is, there is some baseline expected reward, γ_i, and a neuron-specific contribution β_i, where H_i represents the spiking indicator function for neuron i over the trial of period T. We denote by β_i the causal effect of neuron i on the resulting reward R.
To demonstrate the idea that a neuron can use its spiking non-linearity to estimate causal effects, here we analyze a simple two-neuron network obeying leaky integrate-and-fire (LIF) dynamics. The neurons receive a shared scalar input signal x(t), with added separate noise inputs η_i(t) that are correlated with coefficient c. Each neuron weighs the noisy input by w_i. The correlation in input noise induces a correlation in the output spike trains of the two neurons [31], thereby introducing confounding. At the end of a trial period T, the neural output determines a reward signal R. Most aspects of causal inference can be investigated in a simple, few-variable model such as this [32]; thus demonstrating that a neuron can estimate a causal effect in this simple case is an important first step toward understanding how it can do so in a larger network.
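A minimal sketch of this two-neuron setup, assuming Euler integration, reset to zero, and illustrative parameters (the paper's values are given in its Methods); it returns the spike indicators H_i and the maximum drives Z_i consumed by the estimators above:

```python
import numpy as np

def simulate_trial(w, c, T=100, dt=1.0, tau=10.0, theta=1.0, rng=None):
    """Two LIF neurons share an input x(t) plus noise correlated with
    coefficient c; returns spike indicators H and maximum drives Z."""
    rng = rng or np.random.default_rng()
    v = np.zeros(2)                    # membrane potential (with reset)
    u = np.zeros(2)                    # drive: same dynamics, no reset
    spiked = np.zeros(2, dtype=bool)
    z = np.full(2, -np.inf)
    for _ in range(int(T / dt)):
        x = rng.normal()               # shared input signal
        common = rng.normal()          # shared noise component
        eta = np.sqrt(c) * common + np.sqrt(1 - c) * rng.normal(size=2)
        drive = w * (x + eta)
        v += dt / tau * (-v + drive)
        u += dt / tau * (-u + drive)
        fired = v >= theta
        v[fired] = 0.0                 # reset (v_r = 0 assumed)
        spiked |= fired
        z = np.maximum(z, u)           # max integrated drive Z_i
    return spiked.astype(float), z
```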
As outlined in the introduction, the idea is that inputs placing a neuron close to its spiking threshold can be used in an unbiased causal effect estimator. For inputs that place the neuron just below or just above its spiking threshold, the difference in the state of the rest of the network becomes negligible; the only difference is that in one case the neuron spiked and in the other case it did not. Any difference in observed reward can therefore only be attributed to the neuron's activity. Statistically, within a small interval around the threshold, spiking becomes as good as random [28–30]. In this way the spiking discontinuity may allow neurons to estimate their causal effect.
Having understood how the causal effect of a neuron can be defined, and how it is relevant to learning, we now consider how to estimate it. The key observation of this paper is that a discontinuity can be used to estimate causal effects without randomization, while retaining the benefits of randomization. We will refer to this approach as the Spiking Discontinuity Estimator (SDE). This estimator is equivalent to the regression discontinuity design (RDD) approach to causal inference, which is popular in economics [28, 29].
The causal effect β_i is an important quantity for learning: if we know how a neuron contributes to the reward, the neuron can change its behavior to increase it. More specifically, in a spiking neural network, the causal effect can be seen as a type of finite difference approximation of a partial derivative (reward with a spike versus reward without a spike). That is, since we assume that R is a function of the filtered neural output of the network, it can be shown that the causal effect β_i, under certain assumptions, does indeed approximate the reward gradient. That is, there is a sense in which:

$$\beta_i = \mathbb{E}[D_i R] \approx \Delta_s \, \mathbb{E}\!\left[\frac{\partial R}{\partial S_i}\right], \tag{1}$$

where D_i R is a random variable that represents the finite difference operator of R with respect to neuron i's firing, and Δ_s is a constant that depends on the spike kernel κ and acts here like a kind of finite step size. Refer to the methods section for the derivation. This result establishes a connection between causal inference and gradient-based learning. It suggests that methods from causal inference may provide efficient algorithms to estimate reward gradients, and thus can be used to optimize reward. In this way the causal effect is a relevant quantity for learning.
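A small numeric illustration of this finite-difference view, under assumed stand-in forms (an exponential spike kernel and a smooth quadratic reward of the end-of-window filtered output; neither is taken from the paper):

```python
import numpy as np

dt, tau_s, T = 0.1, 5.0, 50.0
t = np.arange(0.0, T, dt)
kappa = np.exp(-t / tau_s) * dt / tau_s        # assumed spike kernel

def s_end(spikes):                             # filtered output at end of window
    return (kappa[::-1] * spikes).sum()

def reward(s):                                 # assumed smooth reward of s
    return -(s - 2.0) ** 2

spikes = np.zeros_like(t)
spikes[[100, 250, 400]] = 1.0                  # baseline spike train
with_extra = spikes.copy()
with_extra[495] = 1.0                          # one extra spike near the end

s0, s1 = s_end(spikes), s_end(with_extra)
finite_diff = reward(s1) - reward(s0)          # D_i R for this trial
gradient_term = -2.0 * (s0 - 2.0) * (s1 - s0)  # dR/ds at s0, times the change in s
print(finite_diff, gradient_term)              # nearly equal: the effect of a
                                               # spike tracks the reward gradient
```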
Given this causal network, we can then define a neuron's causal effect. In causal inference, the causal effect can be understood as the expected difference in an outcome R when a 'treatment' H_i is exogenously assigned. The causal effect of neuron i on reward R is defined as:

$$\beta_i = \mathbb{E}[R \mid \mathrm{do}(H_i = 1)] - \mathbb{E}[R \mid \mathrm{do}(H_i = 0)],$$

where do represents the do-operator, notation for an intervention [27]. Because the aggregated variables maintain the same ordering as the underlying dynamic variables, there is a well-defined sense in which R is indeed an effect of H_i, not the other way around, and therefore the causal effect β_i is a sensible, and not necessarily zero, quantity (Fig 1C and 1D). That is to say, it makes sense to associate with each neuron, for each stimulus, a causal effect on the output and thus reward.
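A minimal sketch of this distinction, mirroring the confounded case of Fig 1D with hypothetical numbers: H_2 drives both H_1 and the reward, so the observed dependence overstates H_1's true effect, while exogenously assigning H_1 recovers it:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 100_000

h2 = rng.random(n) < 0.5                                     # confounder
h1 = np.where(h2, rng.random(n) < 0.9, rng.random(n) < 0.1)  # H2 causes H1
r = 2.0 * h2 + 0.5 * h1 + rng.normal(0.0, 0.1, n)            # true effect of H1: 0.5

# Observed dependence: E[R | H1 = 1] - E[R | H1 = 0], biased by H2.
observed = r[h1].mean() - r[~h1].mean()                      # ~2.1

# Intervention: assign H1 exogenously, severing the H2 -> H1 edge.
h1_do = rng.random(n) < 0.5
r_do = 2.0 * h2 + 0.5 * h1_do + rng.normal(0.0, 0.1, n)
causal = r_do[h1_do].mean() - r_do[~h1_do].mean()            # ~0.5
print(observed, causal)
```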
The second criterion is that the graph can be used to describe what happens when interventions are made. Fully spelling this out is beyond the scope of this study, so here we assume that the interventional distributions on the nodes factor the distribution ρ as required in the definition above. This is reasonable since, for instance, intervening on the underlying variable h_i(t) (to enforce a spike at a given time) would sever the relation between Z_i and H_i, as dictated by the graph topology. Taken together, this means the graph describes a causal Bayesian network over the distribution ρ.
(A) The dynamic spiking network model. The state at time bin t depends on both the previous state and a hierarchical dependence between inputs x_t, neuron activities h_t, and the reward signal r. Omitted for clarity are the extra variables that determine the network state (v(t) and s(t)). (B) To formulate the supervised learning problem, these variables are aggregated in time to produce summary variables of the state of the network during the simulated window. We consider these variables as being drawn IID from a distribution (X, Z, H, S, R) ∼ ρ. Intervening on the underlying dynamic variables changes the distribution ρ accordingly; e.g. severing the connection from x_t to h_t for all t renders H independent of X. Thus the graphical model over (X, Z, H, R) has the same hierarchy (ordering) as the underlying dynamical model. However, the aggregate variables do not fully summarize the state of the network throughout the simulation. Therefore, unlike the structure in the underlying dynamics, H may not fully separate X from R; we must allow for the possibility of a direct connection from X to R. (C) and (D) are simple examples illustrating the difference between observed dependence and causal effect. We have omitted the dependence on X for simplicity. Violin plots show reward when H_1 is active or inactive, without (left subplot) and with (right) intervening on H_1. (C) If H_1 and H_2 are independent, the observed dependence matches the causal effect. (D) If H_2 causes H_1, then H_2 is an unobserved confounder, and the observed dependence and causal effect differ.
To use this theory, first, we describe a graph such that ρ is compatible with the conditional independence requirement of the above definition. Consider the ordering of the variables Φ that matches the feedforward structure of the underlying dynamical network (Fig 1A). From this ordering we construct the graph over the variables Φ (Fig 1B). This graph respects the order of variables implied in Fig 1A, but it is over-complete, in the sense that it also contains a direct link between X and R. This direct link, though absent in the underlying dynamical model, cannot be ruled out in a distribution over the aggregate variables, so it must be included. The graph is directed, acyclic and fully connected. Being fully connected in this way guarantees that we can factor the distribution ρ with the graph and it will obey the conditional independence criterion described above.
Definition 1. In a causal Bayesian network, the probability distribution ρ is factored according to a graph [27]. The edges in the graph represent causal relationships between the nodes; the graph is both directed and acyclic (a DAG). The standard definition of a causal Bayesian model imposes two constraints on the distribution ρ, relating to: (1) conditional independence, so that ρ factors according to the graph, with each variable independent of its non-descendants given its parents; and (2) interventions, so that the graph describes the distributions that result when its nodes are intervened on.
The functionals are required to depend on only one underlying dynamical variable each. The spiking discontinuity approach requires that H is an indicator functional, simply indicating the occurrence of a spike or not within the window T; it could instead be defined directly in terms of Z. The random variable Z is required to have the form defined above, a maximum of the integrated drive. The choice of functionals is also required to be such that, if there is a dependence between two underlying dynamical variables (e.g. h and x), then there is also some statistical dependence between the corresponding aggregated variables; e.g. a trivial functional f_R(r) = 0 would destroy any dependence between X and R. Given these considerations, functionals satisfying these requirements are used for the subsequent analysis; these choices were made since they showed better empirical performance than, e.g., taking the mean over T for S and R. Along with the parameters of the underlying dynamical neural network, these choices determine the form of the distribution ρ. The learning problem can now be framed as: how can parameters Θ be adjusted such that the expected reward E[R] is maximized? Below we show how the causal effect of a neuron on reward can be defined and used to maximize this reward.
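As a sketch, one set of aggregating functionals satisfying these requirements is below; these are hypothetical stand-ins with the stated structure, not necessarily the exact choices used in the paper's Methods:

```python
import numpy as np

def aggregate(x, z_drive, h, s, r):
    """Hypothetical aggregating functionals over one window of length T.
    H indicates whether any spike occurred, Z is the maximum of the
    no-reset integrated drive, and S, R are end-of-window values
    (rather than means, which performed worse empirically)."""
    X = x.mean()               # stand-in summary of the input
    Z = z_drive.max()
    H = float(h.sum() > 0)
    S = s[-1]
    R = r[-1]
    return X, Z, H, S, R
```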
To compute what the network's output is for the given input over T seconds, we define a set of random variables, X, Z, H, S, R, that aggregate the underlying dynamical (and spiking) variables, x(t), z(t), h(t), s(t) and r(t), respectively. Specifically, we define aggregating functionals f[⋅] to be summaries of the underlying dynamical variable activities. The dynamics of the noisy LIF network generate an ergodic Markov process with a stationary distribution, and when the network dynamics are irregular [26], these aggregate variables will be approximately independent across subsequent windows of sufficient duration T. That is, from a dynamical network we obtain a set of random variables Φ that summarize the state of the network and can be considered IID draws from some distribution, which depends on the network's weights and other parameters Θ (e.g. noise magnitude and correlation): Φ ≔ (X, Z, H, S, R) ∼ ρ(⋅; Θ).
As described in the introduction, to apply the spiking discontinuity method to estimate causal effects, we have to track how close a neuron is to spiking. To distinguish just-above-threshold inputs from well-above-threshold inputs over the time period T, we also consider the input drive to the neuron, u_i(t), which is the integrated input to the neuron without the reset mechanism. By tracking this integrated input without a reset, the value Z_i = max_{0≤t≤T} u_i(t) tells us whether neuron i received inputs that placed it well above threshold, or just above threshold.
To accommodate these differences, we consider the following learning problem. Consider a population of N neurons. Let v_i(t) denote the membrane potential of neuron i at time t, with leaky integrate-and-fire dynamics:

$$\frac{dv_i}{dt} = -g_L v_i + \sum_j w_{ij}\,\big(x_j(t) + \eta_j(t)\big), \qquad v_i \ge \theta \;\Rightarrow\; v_i \to v_r,$$

for leak term g_L, reset voltage v_r and threshold θ. The network is assumed to have a feedforward structure. The neurons receive inputs from an input layer x(t), along with a noise process η_j(t), weighted by synaptic weights w_ij. The network is presented with this input stimulus for a fixed period of T seconds. Let the variable h_i(t) denote the neuron's spiking indicator function: h_i(t) = Σ δ(t − t_s) if neuron i spikes at times t_s. Post-synaptic current, s_i(t), is generated according to the dynamics

$$\tau_s \frac{ds_i}{dt} = -s_i + h_i(t).$$

On the basis of the output of the neurons, a reward signal r is generated, assumed to be a function of the filtered currents s(t): r(t) = r(s(t)). Here we assume that T is sufficiently long for the network to have received an input, produced an output, and for feedback to have been distributed to the system (e.g. in the form of dopamine signaling a reward prediction error [25]). Given this network, the learning problem is for each neuron to adjust its weights to maximize reward, using an estimate of its causal effect on reward.
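To make the read-out side concrete, a minimal sketch of the synaptic filtering and reward step, assuming the exponential dynamics above; the quadratic reward is an illustrative stand-in for r(s(t)):

```python
import numpy as np

def filtered_currents(spike_trains, dt=0.1, tau_s=5.0):
    """Integrate tau_s ds/dt = -s + h(t) for each neuron; spike_trains is
    an (n_neurons, n_steps) array of 0/1 spike indicators per time bin."""
    s = np.zeros(spike_trains.shape[0])
    for step in range(spike_trains.shape[1]):
        s += dt / tau_s * (-s + spike_trains[:, step] / dt)
    return s                                 # filtered currents at end of T

def reward(s, target=1.0):                   # assumed form of r(s)
    return -np.sum((s - target) ** 2)
```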
We develop this idea in the context of a supervised learning setting. To formalize causal effects in this setting, we thus first have to think about how supervised learning might be performed by a spiking, dynamically integrating network of neurons (see, for example, the solution by Guergiuev et al 2016 [ 24 ]). That is, for simplicity, to avoid problems to do with temporal credit assignment, we consider a neural network that receives immediate feedback/reward on the quality of the computation, potentially provided by an internal critic (similar to the setup of [ 24 ]). Spiking neural networks generally operate dynamically where activities unfold over time, yet supervised learning in an artificial neural network typically has no explicit dynamics—the state of a neuron is only a function of its current inputs, not its previous inputs.
Causal effects are formally defined in the context of a certain type of probabilistic graphical model—the causal Bayesian network—while a spiking neural network is a dynamical, stochastic process. Thus before we can understand how a neuron can use its spiking discontinuity to do causal inference we must first understand how a causal effect can be defined for a neural network. We present two results:
Though not previously recognized as such, the credit assignment problem is a causal inference problem: how can a neuron know its causal effect on an output and subsequent reward? This section shows how this idea can be made more precise.
Discussion
Here we focused on the relation between gradient-based learning and causal inference. Our approach is inspired by the regression discontinuity design commonly used in econometrics [34]. We cast neural learning explicitly as a causal inference problem, and have shown that neurons can estimate their causal effect using their spiking mechanism. In this way we found that spiking can be an advantage, allowing neurons to quantify their causal effect in an unbiased way.
It is important to note that other neural learning rules also perform causal inference. Thus the spiking discontinuity learning rule can be placed in the context of other neural learning mechanisms. First, as many authors have noted, any reinforcement learning algorithm relies on estimating the effect of an agent’s/neuron’s activity on a reward signal. Learning by operant conditioning relies on learning a causal relationship (compared to classical conditioning, which only relies on learning a correlation) [27, 35–37]. Causal inference is, at least implicitly, the basis of reinforcement learning.
There is a large literature on how reinforcement learning algorithms can be implemented in the brain. It is well known that there are many neuromodulators which may represent reward or expected reward, including dopaminergic projections from the substantia nigra to the ventral striatum representing a reward prediction error [25, 38]. Many of these methods use something like the REINFORCE algorithm [39], a policy gradient method in which locally added noise is correlated with reward, and this correlation is used to update weights. This gives an unbiased estimate of the causal effect because the noise is assumed to be independent and private to each neuron. These ideas have been used extensively to model learning in brains [16–22].
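For contrast with the spiking discontinuity approach, a minimal sketch of such a perturbation-based (node perturbation) update with a running reward baseline; the task, targets, and parameters are hypothetical. Note that the update is unbiased only because each ξ_i is private and independent across neurons:

```python
import numpy as np

rng = np.random.default_rng(2)
w = np.zeros(3)
x = np.array([1.0, 0.5, -0.5])          # fixed input pattern (illustrative)
r_bar = 0.0                             # running reward baseline
for _ in range(20000):
    xi = rng.normal(0.0, 0.1, 3)        # private noise, one per neuron
    y = w * x + xi                      # perturbed activities
    r = -np.sum((y - 1.0) ** 2)         # reward: drive each activity to 1
    w += 0.1 * (r - r_bar) * xi * x     # correlate noise with reward
    r_bar += 0.05 * (r - r_bar)
print(w)  # approaches [1.0, 2.0, -2.0], i.e. w_i * x_i = 1 for each neuron
```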
Learning in birdsong is a particularly well-developed example of this form of learning [17]. In birdsong learning in zebra finches, neurons from area LMAN synapse onto neurons in area RA. These synapses are referred to as 'empiric' synapses and are treated by the neurons as an 'experimenter', producing random perturbations which can be used to estimate causal effects. This is a compelling account of learning in birdsong; however, it relies on the specific structural form of the learning circuit. It is unknown, more broadly, how a neuron may identify what is perturbative noise without these structural specifics, and thus whether this mechanism can provide an account of learning in general.
There are two factors that cast doubt on the broad use of reinforcement learning-type algorithms in neural circuits. First, even for a fixed stimulus, noise is correlated across neurons [11–15]. If the noise a neuron uses for learning is correlated with that of other neurons, then it cannot know which neuron's changes in output are responsible for changes in reward. In such a case, the synchronizing presynaptic activity acts as a confounder. Thus, as discussed, such algorithms require biophysical mechanisms to distinguish independent perturbative noise from correlated input signals in presynaptic activity, and in general it is unclear how a neuron can do this. Though well characterized in sensory coding, the role of noise correlations in learning has been less studied. This work suggests that understanding learning as a causal inference problem can provide insight into that role.
Second, learning with perturbations scales poorly with network size [17, 22, 40]. Thus neurons may use alternatives to these reinforcement-learning algorithms. A number of authors have looked to learning in artificial neural networks for inspiration. In artificial neural networks, the credit assignment problem is solved using the backpropagation algorithm, which allows gradients to be calculated efficiently. Backpropagation requires differentiable systems, which spiking neurons are not; indeed, cortical networks often operate at low firing rates, in which the stochastic and discontinuous nature of spiking output cannot be neglected [41]. It also requires full knowledge of the system, which is often not the case if parts of the system relate to the outside world. No known structures exist in the brain that could exactly implement backpropagation. Yet backpropagation is significantly more efficient than perturbation-based methods; it is the only known algorithm able to solve large-scale problems at a human level [42]. The success of backpropagation suggests that efficient methods for computing gradients are needed for solving large-scale learning problems.
An important disclaimer is that the performance of local update rules like SDE-based learning is likely to scale similarly to REINFORCE-based methods, i.e. significantly more slowly than backpropagation. SDE-based learning, on its own, is not a learning rule that is significantly more efficient than REINFORCE; instead, it is a rule that is more robust to the structure of the noise that REINFORCE-based methods utilize. Thus it is an approach that can be used in more neural circuits than just those with special circuitry for independent noise perturbations. Such an approach can be built into neural architectures alongside backpropagation-like learning mechanisms to solve the credit assignment problem; indeed, this exact approach is taken by [43]. Thus SDE-based learning is relevant both to spiking neural networks in the presence of noise correlations and as part of a biologically plausible solution to the credit assignment problem.
Further, the insights made here are relevant to models that attempt to mimic backpropagation through time. Many learning methods for recurrent networks, when applied to spiking neural networks, have to deal with propagating gradients through the discontinuous spiking function [5, 44–48]. A common strategy is to replace the true derivative of the spiking response function (either zero or undefined) with a pseudo-derivative. A number of replacements have been explored: [44] uses a boxcar function, [47] uses the negative slope of the sigmoid function, and [45] explores the so-called straight-through estimator, pretending the spike response function is the identity function, with a gradient of 1. As a supplementary analysis (S1 Text and S3 Fig), we demonstrated that the width of the non-zero component of this pseudo-derivative can be adjusted to account for correlated inputs. We found that in highly correlated cases, learning is more efficient when the window is smaller. Thus the same considerations raised by framing learning as a causal inference problem provide insight into other biologically plausible, spiking learning models.
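For concreteness, a sketch of such a pseudo-derivative with an adjustable width, loosely following the boxcar choice of [44]; the 1/(2·width) normalization is an illustrative assumption. Shrinking `width` plays the same role as shrinking the SDE window p under strong correlations:

```python
import numpy as np

def pseudo_derivative(v, theta, width=0.5):
    """Boxcar surrogate for the derivative of the spike non-linearity:
    nonzero only within `width` of the threshold theta."""
    return (np.abs(v - theta) < width) / (2.0 * width)

# The straight-through estimator of [45] instead sets the derivative
# to 1 everywhere, ignoring the distance to threshold entirely.
```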
Finally, it must be noted how the learning rule derived here relates to the dominant spike-based learning paradigm, spike-timing-dependent plasticity (STDP) [49]. STDP performs unsupervised learning, so it is not directly related to the type of optimization considered here. Reward-modulated STDP (R-STDP) can be shown to approximate the reinforcement learning policy gradient methods described above [50, 51]. Thus R-STDP can be cast as performing a type of causal inference on a reward signal, and it shares the same features and caveats outlined above. We thus see that learning rules that aim to maximize some reward either implicitly or explicitly involve a neuron estimating its causal effect on that reward signal. Explicitly recognizing this can lead to new methods and understanding.
Compatibility with known physiology

There are a number of ways in which the simulations presented here simplify true learning circuits. For instance, the LIF neural network implemented in Figs 2–4 has a fixed threshold, yet cortical neurons do not have a fixed threshold, but rather one that adapts to recent inputs [52]. We note that our exploration of learning in a more complicated case, the delayed XOR model (S1 Text), uses populations of LIF and adaptive LIF neurons; the adaptive LIF neurons have a threshold that adapts based on recent spiking activity. The results in this model exhibit the same behavior as that observed in previous sections: for sufficiently highly correlated activity, performance is better for a narrow spiking discontinuity parameter p (cf. Fig 2A). This suggests populations of adaptive spiking-threshold neurons show the same behavior as non-adaptive ones. This makes sense, since what is important for the spike discontinuity method is that the learning rule use a signal that is unique to an individual neuron. The stochastic, all-or-none spiking response provides this, regardless of the exact value of the threshold; the neuron just needs to know if it was close to its threshold or not.

Of course, given that our simulations are based on a simplified model, it makes sense to ask what neurophysiological features may allow spiking discontinuity learning in more realistic learning circuits. If neurons perform something like spiking discontinuity learning, we should expect them to exhibit certain physiological properties. In this section we discuss the concrete demands of such learning and how they relate to past experiments.

First, the spiking regime is required to be irregular [26] in order to produce spike trains with randomness in their response to repeated inputs. Irregular spiking regimes are common in cortical networks (e.g. [53]). Further, for spiking discontinuity learning, plasticity should be confined to cases where a neuron's membrane potential is close to threshold, regardless of spiking. This means inputs that place a neuron close to threshold, but do not elicit a spike, still result in plasticity. This type of sub-threshold-dependent plasticity is known to occur [54–56]. It also means that plasticity will not occur for inputs that place a neuron too far below threshold. In fact, in past models and experiments testing voltage-dependent plasticity, changes do not occur when postsynaptic voltages are too low [57, 58]. Spiking discontinuity further predicts that plasticity does not occur when postsynaptic voltages are too high. However, in many voltage-dependent plasticity models, potentiation does occur for inputs well above the spiking threshold. To test if this is an important difference between what is statistically correct and what is more readily implementable in neurophysiology, we experimented with a modification of the learning rule that does not distinguish between barely-above-threshold inputs and well-above-threshold inputs. For the piece-wise linear model, this modification did not adversely affect the neurons' ability to estimate causal effects (S2 Fig). Thus threshold-adjacent plasticity, as required for spike discontinuity learning, appears to be compatible with neuronal physiology.

The learning rules presented here also require knowledge of outcomes and would thus likely depend on neuromodulation. Neuromodulated STDP is well studied in models [51, 59, 60].
However, this learning requires reward-dependent plasticity that differs depending on whether the neuron spiked or not. This may be communicated by neuromodulation. For instance, there is some evidence that the relative balance between adrenergic and M1 muscarinic agonists alters both the sign and magnitude of STDP in layer II/III visual cortical neurons [59]. To the best of our knowledge, how such behavior interacts with the postsynaptic voltage dependence required by spike discontinuity is unknown. Overall, the structure of the learning rule is that a global reward signal (potentially transmitted through neuromodulators) drives learning of a variable, β, inside single neurons. This internal variable is combined with a term to update synaptic weights. The update rule for the weights depends only on pre- and post-synaptic terms, with the post-synaptic term updated over time, independently of the weight updates. Such an interpretation is interestingly in line with recently proposed ideas on inter-neuron learning, e.g. Gershman 2023 [61], who proposes that an interaction of intra-cellular variables and synaptic learning rules can provide a substrate for memory. Taken together, these factors show that SDE-based learning may well be compatible with known neuronal physiology.
[1] URL: https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1011005