#神经网络 #神经元 #神经网络 #神经科学 #神经传递 #生物物理

spiking_neural_networks

用于设计和模拟具有神经传递的生物神经网络动力学的软件包

25个版本 (14个重大更新)

新增 0.23.2 2024年8月22日
0.22.0 2024年8月11日
0.19.3 2024年7月30日

#44 in 生物学

Download history 361/week @ 2024-06-22 573/week @ 2024-06-29 462/week @ 2024-07-06 315/week @ 2024-07-13 172/week @ 2024-07-20 288/week @ 2024-07-27 264/week @ 2024-08-03 246/week @ 2024-08-10 255/week @ 2024-08-17

1,063 每月下载量

Apache-2.0

465KB
7.5K SLoC

Spiking Neural Networks

spiking_neural_networks 是一个专注于设计具有神经传递的神经元模型,并计算神经元之间随时间变化的动力学的软件包。神经元动力学通过 trait(特质)实现,因此可以借助类型系统进行扩展,为不同的神经递质、受体或神经元模型添加新的动力学。目前已实现以下系统:脉冲序列(spike train)、脉冲时间依赖性可塑性(STDP)、基本吸引子、奖励调节的动力学,以及晶格中相互连接的神经元的动力学。下面给出示例以及如何添加自定义模型的说明。

快速示例

Morris-Lecar 模型与静态输入

Morris-Lecar with static current input

耦合的 Izhikevich 神经元

Coupled Izhikevich models

具有神经传递的 Hodgkin-Huxley 模型

Hodgkin Huxley model voltage and neurotransmitter over time

随时间变化的 spike 时间依赖性可塑性权重

STDP weights over time

Hopfield 网络模式重建

Discrete Hopfield network pattern reconstruction

晶格

Voltage over time

示例代码

请参阅 examples 文件夹 以获取更多示例。

具有抑制性和兴奋性输入的神经元相互作用的晶格

extern crate spiking_neural_networks;
use rand::Rng;
use spiking_neural_networks::{
    neuron::{
        integrate_and_fire::IzhikevichNeuron, 
        plasticity::STDP,
        Lattice, LatticeNetwork, AverageVoltageHistory
    },
    graph::AdjacencyMatrix,
    error::SpikingNeuralNetworksError, 
};


/// Creates two pools of neurons, one inhibitory and one excitatory, connects them
/// into a single network, and runs that network for 1,000 iterations with grid
/// history recording enabled on both pools.
///
/// NOTE(review): no code in this function writes the recorded histories to .csv
/// files; add an export step after `run_lattices` if files are expected.
///
/// # Errors
///
/// Propagates any `SpikingNeuralNetworksError` returned while adding the lattices,
/// connecting them, or running the network.
fn main() -> Result<(), SpikingNeuralNetworksError> {
    let base_neuron = IzhikevichNeuron::default_impl();

    // creates smaller inhibitory lattice to stabilize excitatory feedback
    let mut inh_lattice: Lattice<_, AdjacencyMatrix<_, _>, AverageVoltageHistory, STDP, _> = Lattice::default();
    inh_lattice.populate(&base_neuron, 4, 4);
    // connects every pair of distinct positions; the second closure supplies -1.
    // for each connection (presumably the synaptic weight — confirm against the API)
    inh_lattice.connect(&|x, y| x != y, Some(&|_, _| -1.));
    // randomizes each starting voltage between the neuron's `v_init` and `v_th`
    inh_lattice.apply(|n| {
        let mut rng = rand::thread_rng();
        n.current_voltage = rng.gen_range(n.v_init..=n.v_th);
    });
    inh_lattice.update_grid_history = true;

    // creates larger excitatory lattice
    let mut exc_lattice: Lattice<_, AdjacencyMatrix<_, _>, AverageVoltageHistory, STDP, _> = Lattice::default();
    // the id is what `network.connect` below refers to as `1`; the inhibitory
    // lattice keeps its default id (presumably 0 — confirm)
    exc_lattice.set_id(1);
    exc_lattice.populate(&base_neuron, 7, 7);
    exc_lattice.connect(&|x, y| x != y, Some(&|_, _| 1.));
    exc_lattice.apply(|n| {
        let mut rng = rand::thread_rng();
        n.current_voltage = rng.gen_range(n.v_init..=n.v_th);
    });
    exc_lattice.update_grid_history = true;

    // sets up network
    let mut network = LatticeNetwork::default_impl();
    network.parallel = true;
    network.add_lattice(inh_lattice)?;
    network.add_lattice(exc_lattice)?;

    // lattice 0 -> lattice 1 with the value -1. on every connection, then
    // lattice 1 -> lattice 0 with `None` (presumably library default weights)
    network.connect(0, 1, &|_, _| true, Some(&|_, _| -1.))?;
    network.connect(1, 0, &|_, _| true, None)?;

    network.set_dt(1.);

    network.run_lattices(1_000)?;

    Ok(())
}

具有 spike 时间依赖性可塑性的神经元耦合

use std::collections::HashMap;
extern crate spiking_neural_networks;
use spiking_neural_networks::{
    neuron::{
        integrate_and_fire::IzhikevichNeuron,
        iterate_and_spike::{
            IterateAndSpike, GaussianParameters, 
            ApproximateNeurotransmitter, NeurotransmitterType}, 
        spike_train::{DeltaDiracRefractoriness, PresetSpikeTrain}, 
        plasticity::STDP,
        Lattice, LatticeNetwork, SpikeTrainGridHistory, SpikeTrainLattice
    },
    error::SpikingNeuralNetworksError, 
};


/// Tests STDP dynamics over time given a set of input firing rates to a postsynaptic neuron
/// and updates the weights between the spike trains and given postsynaptic neuron, returns
/// the voltage and weight history over time
///
/// One preset spike train is created per entry of `firing_rates` (lattice id 0) and
/// every spike train is connected to a single copy of `postsynaptic_neuron`
/// (lattice id 1) with weights drawn from `weight_params`.
///
/// # Returns
///
/// A tuple of:
/// - a map with the key `"postsynaptic_voltage"` and one `"presynaptic_voltage_{i}"`
///   key per firing rate, each holding one value per recorded step
/// - a clone of the connecting graph history of the network
///
/// # Errors
///
/// Propagates any `SpikingNeuralNetworksError` raised while generating the network,
/// connecting the lattices, or running them.
///
/// # Panics
///
/// Panics if the lattice with id 1 or the spike train lattice with id 0 cannot be
/// retrieved from the network after the run.
pub fn test_stdp<N, T>(
    firing_rates: &[f32],
    postsynaptic_neuron: &T,
    iterations: usize,
    stdp_params: &STDP,
    weight_params: &GaussianParameters,
    electrical_synapse: bool,
    chemical_synapse: bool,
) -> Result<(HashMap<String, Vec<f32>>, Vec<Vec<Vec<Option<f32>>>>), SpikingNeuralNetworksError>
where
    N: NeurotransmitterType,
    T: IterateAndSpike<N=N>,
{
    type SpikeTrainType<N> = PresetSpikeTrain<N, ApproximateNeurotransmitter, DeltaDiracRefractoriness>;

    // sets up line of spike trains depending on number of firing rates
    let mut spike_train_lattice: SpikeTrainLattice<
        N, 
        SpikeTrainType<N>, 
        SpikeTrainGridHistory,
    > = SpikeTrainLattice::default();
    let preset_spike_train = PresetSpikeTrain::default();
    spike_train_lattice.populate(&preset_spike_train, firing_rates.len(), 1);
    spike_train_lattice.apply_given_position(
        &(|pos: (usize, usize), spike_train: &mut SpikeTrainType<N>| { 
            spike_train.firing_times = vec![firing_rates[pos.0]]; 
        })
    );
    spike_train_lattice.update_grid_history = true;
    spike_train_lattice.set_id(0);

    // generates postsynaptic neuron; `populate` only needs a borrow of the base
    // neuron, so the reference is passed straight through without cloning first
    let mut lattice = Lattice::default_impl();
    lattice.populate(postsynaptic_neuron, 1, 1);
    lattice.plasticity = *stdp_params;
    lattice.do_plasticity = true;
    lattice.update_grid_history = true;
    lattice.set_id(1);

    // connects spike trains to neuron and runs
    let lattices = vec![lattice];
    let spike_train_lattices = vec![spike_train_lattice];
    let mut network = LatticeNetwork::generate_network(lattices, spike_train_lattices)?;
    network.connect(
        0, 1, &(|_, _| true), Some(&(|_, _| weight_params.get_random_number()))
    )?;
    network.update_connecting_graph_history = true;
    network.electrical_synapse = electrical_synapse;
    network.chemical_synapse = chemical_synapse;

    network.run_lattices(iterations)?;

    // collects the postsynaptic voltage at grid position (0, 0) for every step
    let mut output_hashmap: HashMap<String, Vec<f32>> = HashMap::new();
    output_hashmap.insert(
        String::from("postsynaptic_voltage"),
        network.get_lattice(&1).unwrap().grid_history
            .history
            .iter()
            .map(|i| i[0][0])
            .collect(),
    );

    // collects the history of each spike train; every key is new, so the series
    // is collected once and inserted directly
    let spike_train_history = &network.get_spike_train_lattice(&0).unwrap()
        .grid_history.history;
    for i in 0..firing_rates.len() {
        output_hashmap.insert(
            format!("presynaptic_voltage_{}", i),
            spike_train_history.iter().map(|step| step[i][0]).collect(),
        );
    }

    Ok((output_hashmap, network.get_connecting_graph().history.clone()))
}

自定义 IterateAndSpike 实现

use spiking_neural_networks::neuron::iterate_and_spike_traits::IterateAndSpikeBase;
use spiking_neural_networks::neuron::iterate_and_spike::{
    GaussianFactor, GaussianParameters, IsSpiking, Timestep,
    CurrentVoltage, GapConductance, IterateAndSpike, 
    LastFiringTime, NeurotransmitterConcentrations, LigandGatedChannels, 
    ReceptorKinetics, NeurotransmitterKinetics, Neurotransmitters,
    ApproximateNeurotransmitter, ApproximateReceptor,
    IonotropicNeurotransmitterType,
};
use spiking_neural_networks::neuron::ion_channels::{
    BasicGatingVariable, IonChannel, TimestepIndependentIonChannel,
};
 

/// Calcium channel whose gating variable is always taken at its steady state,
/// so no timestep is needed to update it
#[derive(Debug, Clone, Copy)]
pub struct ReducedCalciumChannel {
    /// Calcium conductance (nS)
    pub g_ca: f32,
    /// Calcium reversal potential (mV)
    pub v_ca: f32,
    /// Steady state value of the gating variable
    pub m_ss: f32,
    /// Tuning parameter subtracted from the voltage
    pub v_1: f32,
    /// Tuning parameter dividing the shifted voltage
    pub v_2: f32,
    /// Most recently calculated current
    pub current: f32,
}

impl TimestepIndependentIonChannel for ReducedCalciumChannel {
    /// Recomputes the steady state gating value for `voltage` and stores the
    /// resulting current
    fn update_current(&mut self, voltage: f32) {
        let scaled_offset = (voltage - self.v_1) / self.v_2;
        self.m_ss = 0.5 * (1. + scaled_offset.tanh());

        let driving_force = voltage - self.v_ca;
        self.current = self.g_ca * self.m_ss * driving_force;
    }

    /// Returns the current stored by the last call to `update_current`
    fn get_current(&self) -> f32 {
        self.current
    }
}

/// Potassium channel whose gating variable relaxes towards a voltage dependent
/// steady state
#[derive(Debug, Clone, Copy)]
pub struct KSteadyStateChannel {
    /// Potassium conductance (nS)
    pub g_k: f32,
    /// Potassium reversal potential (mV)
    pub v_k: f32,
    /// Present value of the gating variable
    pub n: f32,
    /// Steady state the gating variable relaxes towards
    pub n_ss: f32,
    /// Gating decay, the divisor applied to the distance from steady state
    pub t_n: f32,
    /// Reference frequency
    pub phi: f32,
    /// Tuning parameter subtracted from the voltage
    pub v_3: f32,
    /// Tuning parameter dividing the shifted voltage
    pub v_4: f32,
    /// Most recently calculated current
    pub current: f32
}

impl KSteadyStateChannel {
    // recalculates the steady state and the gating decay for the given voltage
    fn update_gating_variables(&mut self, voltage: f32) {
        let offset = voltage - self.v_3;

        self.n_ss = 0.5 * (1. + (offset / self.v_4).tanh());
        self.t_n = 1. / (self.phi * (offset / (2. * self.v_4)).cosh());
    }
}

impl IonChannel for KSteadyStateChannel { 
    /// Advances the gating variable by one step of length `dt` and stores the
    /// resulting current for `voltage`
    fn update_current(&mut self, voltage: f32, dt: f32) {
        self.update_gating_variables(voltage);

        let relaxation_rate = (self.n_ss - self.n) / self.t_n;
        self.n += relaxation_rate * dt;

        let driving_force = voltage - self.v_k;
        self.current = self.g_k * self.n * driving_force;
    }

    /// Returns the current stored by the last call to `update_current`
    fn get_current(&self) -> f32 {
        self.current
    }
}

/// Leak channel whose current is proportional to the distance of the voltage
/// from the reversal potential
#[derive(Debug, Clone, Copy)]
pub struct LeakChannel {
    /// Leak conductance (nS)
    pub g_l: f32,
    /// Leak reversal potential (mV)
    pub v_l: f32,
    /// Most recently calculated current
    pub current: f32
}

impl TimestepIndependentIonChannel for LeakChannel {
    /// Stores the leak current for `voltage`
    fn update_current(&mut self, voltage: f32) {
        let driving_force = voltage - self.v_l;
        self.current = self.g_l * driving_force;
    }

    /// Returns the current stored by the last call to `update_current`
    fn get_current(&self) -> f32 {
        self.current
    }
}

/// A Morris-Lecar neuron built from a reduced calcium channel, a steady state
/// potassium channel and a leak channel, together with ionotropic neurotransmitters
/// and ligand gated receptor channels. `T` selects the neurotransmitter kinetics
/// and `R` the receptor kinetics.
#[derive(Debug, Clone, IterateAndSpikeBase)]
pub struct MorrisLecarNeuron<T: NeurotransmitterKinetics, R: ReceptorKinetics> {
    /// Membrane potential (mV)
    pub current_voltage: f32, 
    /// Voltage threshold (mV)
    pub v_th: f32,
    /// Initial voltage value (mV)
    pub v_init: f32,
    /// Controls conductance of input gap junctions
    pub gap_conductance: f32,
    /// Calcium channel
    pub ca_channel: ReducedCalciumChannel,
    /// Potassium channel
    pub k_channel: KSteadyStateChannel,
    /// Leak channel
    pub leak_channel: LeakChannel,
    /// Membrane capacitance (nF)
    pub c_m: f32,
    /// Timestep in (ms)
    pub dt: f32,
    /// Whether the neuron is spiking
    pub is_spiking: bool,
    /// Whether the voltage was increasing in the last step
    pub was_increasing: bool,
    /// Last timestep the neuron has spiked
    pub last_firing_time: Option<usize>,
    /// Parameters used in generating noise
    pub gaussian_params: GaussianParameters,
    /// Postsynaptic neurotransmitters in cleft
    pub synaptic_neurotransmitters: Neurotransmitters<IonotropicNeurotransmitterType, T>,
    /// Ionotropic receptor ligand gated channels
    pub ligand_gates: LigandGatedChannels<R>,
}

impl<T: NeurotransmitterKinetics, R: ReceptorKinetics> MorrisLecarNeuron<T, R> {
    /// Refreshes the calcium, potassium and leak currents from the present
    /// membrane voltage; only the potassium channel also uses the timestep
    pub fn update_channels(&mut self) {
        let voltage = self.current_voltage;

        self.ca_channel.update_current(voltage);
        self.k_channel.update_current(voltage, self.dt);
        self.leak_channel.update_current(voltage);
    }
    
    /// Returns the voltage change for one timestep given the input current `i`,
    /// using the channel currents stored by the last channel update
    pub fn get_dv_change(&self, i: f32) -> f32 {
        let net_current = i
            - self.leak_channel.current
            - self.ca_channel.current
            - self.k_channel.current;

        net_current * (self.dt / self.c_m)
    }

    // reports a spike when the voltage is above the threshold and has just
    // stopped increasing, that is it was rising on the previous step and is not
    // rising relative to `last_voltage` on this one; the spiking flag and the
    // rising flag are both stored on the neuron before the result is returned
    fn handle_spiking(&mut self, last_voltage: f32) -> bool {
        let rising_now = last_voltage < self.current_voltage;
        let stopped_rising = self.was_increasing && !rising_now;

        self.is_spiking = self.current_voltage > self.v_th && stopped_rising;
        self.was_increasing = rising_now;

        self.is_spiking
    }
}

impl<T: NeurotransmitterKinetics, R: ReceptorKinetics> IterateAndSpike for MorrisLecarNeuron<T, R> {
    type N = IonotropicNeurotransmitterType;

    /// Returns the concentrations of the neuron's synaptic neurotransmitters
    fn get_neurotransmitter_concentrations(&self) -> NeurotransmitterConcentrations<Self::N> {
        self.synaptic_neurotransmitters.get_concentrations()
    }

    /// Advances the neuron by one timestep without any receptor current: the
    /// channels are refreshed, the voltage is stepped with `input_current`, the
    /// neurotransmitters are updated from the new voltage, and the spike check
    /// result is returned
    fn iterate_and_spike(&mut self, input_current: f32) -> bool {
        self.update_channels();

        let previous_voltage = self.current_voltage;
        let dv = self.get_dv_change(input_current);
        self.current_voltage += dv;

        self.synaptic_neurotransmitters.apply_t_changes(self.current_voltage, self.dt);

        self.handle_spiking(previous_voltage)
    }

    /// Advances the neuron by one timestep with receptor current included: the
    /// receptor gating is updated from `t_total` first, then the channels are
    /// refreshed, the voltage is stepped with `input_current` plus the receptor
    /// current, the neurotransmitters are updated from the new voltage, and the
    /// spike check result is returned
    fn iterate_with_neurotransmitter_and_spike(
        &mut self, 
        input_current: f32, 
        t_total: &NeurotransmitterConcentrations<Self::N>,
    ) -> bool {
        // receptors are updated before the channels, using the voltage from
        // before this step
        self.ligand_gates.update_receptor_kinetics(t_total, self.dt);
        self.ligand_gates.set_receptor_currents(self.current_voltage, self.dt);
        
        self.update_channels();

        let previous_voltage = self.current_voltage;
        let receptor_current = self.ligand_gates.get_receptor_currents(self.dt, self.c_m);
        let dv = self.get_dv_change(input_current);
        // both contributions are summed before being added to the voltage
        self.current_voltage += dv + receptor_current;

        self.synaptic_neurotransmitters.apply_t_changes(self.current_voltage, self.dt);

        self.handle_spiking(previous_voltage)
    }
}

自定义 NeurotransmitterKinetics 实现

use spiking_neural_networks::neuron::iterate_and_spike::NeurotransmitterKinetics;

/// An approximation of neurotransmitter kinetics that raises the concentration by the
/// maximal value when a spike is detected (input `voltage` is greater than `v_th`) and
/// applies a decay term on every update that depends on the `decay_constant` and `dt`;
/// the concentration is always kept between 0 and `t_max`
#[derive(Debug, Clone, Copy)]
pub struct ExponentialDecayNeurotransmitter {
    /// Maximal neurotransmitter concentration (mM)
    pub t_max: f32,
    /// Current neurotransmitter concentration (mM)
    pub t: f32,
    /// Voltage threshold for detecting spikes (mV)
    pub v_th: f32,
    /// Decay constant used when calculating how much the concentration decays
    pub decay_constant: f32,
}

// step function used to detect a voltage spike: 1 for strictly positive input,
// 0 for everything else (including zero and NaN)
fn heaviside(x: f32) -> f32 {
    f32::from(u8::from(x > 0.))
}

// calculates the change in concentration `x` over one timestep `dt` for
// exponential decay with decay constant `l`
//
// the exact solution of dx/dt = -x / l is x(t + dt) = x(t) * exp(-dt / l), so the
// change to add to `x` is x * (exp(-dt / l) - 1); this is zero when `dt` is zero
// and approaches `-x` as `dt / l` grows, so a longer timestep always decays more
// and the result never overshoots past zero
fn exp_decay(x: f32, l: f32, dt: f32) -> f32 {
    // exp_m1 computes exp(v) - 1 without losing precision for small `dt / l`
    x * (-dt / l).exp_m1()
}

impl NeurotransmitterKinetics for ExponentialDecayNeurotransmitter {
    /// Applies one timestep to the concentration: the decay term and, when
    /// `voltage` is above `v_th`, a release of `t_max` are summed and added, and
    /// the result is kept between 0 and `t_max`
    fn apply_t_change(&mut self, voltage: f32, dt: f32) {
        let decay = exp_decay(self.t, self.decay_constant, dt);
        // release is either `t_max` or zero depending on the spike check
        let release = heaviside(voltage - self.v_th) * self.t_max;
        self.t += decay + release;

        // clamp values
        let floored = self.t.max(0.);
        self.t = self.t_max.min(floored);
    }

    /// Returns the current concentration
    fn get_t(&self) -> f32 {
        self.t
    }

    /// Overwrites the current concentration with `t`
    fn set_t(&mut self, t: f32) {
        self.t = t;
    }
}

自定义 ReceptorKinetics 实现

use spiking_neural_networks::neuron::iterate_and_spike::{
    ReceptorKinetics, AMPADefault, GABAaDefault, GABAbDefault, NMDADefault,
};

/// Receptor dynamics approximation that adds the inputted neurotransmitter
/// concentration to the receptor gating value on every update and applies a
/// decay term to the gating value over time; the gating value is always kept
/// between 0 and `r_max`
#[derive(Debug, Clone, Copy)]
pub struct ExponentialDecayReceptor {
    /// Maximal receptor gating value
    pub r_max: f32,
    /// Receptor gating value
    pub r: f32,
    /// Decay constant used when calculating how much the gating value decays
    pub decay_constant: f32,
}

// calculates the change in the receptor gating variable `x` over one timestep
// `dt` for exponential decay with decay constant `l`
//
// the exact solution of dx/dt = -x / l is x(t + dt) = x(t) * exp(-dt / l), so the
// change to add to `x` is x * (exp(-dt / l) - 1); this is zero when `dt` is zero
// and approaches `-x` as `dt / l` grows, so a longer timestep always decays more
// and the result never overshoots past zero
fn exp_decay(x: f32, l: f32, dt: f32) -> f32 {
    // exp_m1 computes exp(v) - 1 without losing precision for small `dt / l`
    x * (-dt / l).exp_m1()
}

impl ReceptorKinetics for ExponentialDecayReceptor {
    /// Applies one timestep to the gating value: the decay term and the
    /// neurotransmitter concentration `t` are summed and added, and the result
    /// is kept between 0 and `r_max`
    fn apply_r_change(&mut self, t: f32, dt: f32) {
        let decay = exp_decay(self.r, self.decay_constant, dt);
        self.r += decay + t;

        // clamp values
        let floored = self.r.max(0.);
        self.r = self.r_max.min(floored);
    }

    /// Returns the current gating value
    fn get_r(&self) -> f32 {
        self.r
    }

    /// Overwrites the current gating value with `r`
    fn set_r(&mut self, r: f32) {
        self.r = r;
    }
}
// implements a constructor trait for `ExponentialDecayReceptor` given the trait
// name and the name of its constructor method; every generated constructor
// returns the same settings, which lets `LigandGatedChannels` build its default
// receptors from this type
macro_rules! impl_exp_decay_receptor_default {
    ($default_trait:ident, $constructor:ident) => {
        impl $default_trait for ExponentialDecayReceptor {
            fn $constructor() -> Self {
                Self {
                    r_max: 1.0,
                    r: 0.,
                    decay_constant: 2.,
                }
            }
        }
    };
}

// the standard `Default` plus one constructor per receptor type
impl_exp_decay_receptor_default!(Default, default);
impl_exp_decay_receptor_default!(AMPADefault, ampa_default);
impl_exp_decay_receptor_default!(GABAaDefault, gabaa_default);
impl_exp_decay_receptor_default!(GABAbDefault, gabab_default);
impl_exp_decay_receptor_default!(NMDADefault, nmda_default);

依赖关系

~8MB
~156K SLoC