Skip to content

Commit ef8123b

Browse files
Add AdEx cells. (#2230)
Add Adaptive Exponential Cells to Arbor. Closes #1832 More on request.
1 parent 2366842 commit ef8123b

41 files changed

Lines changed: 3212 additions & 1075 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

arbor/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ set(arbor_sources
2323
iexpr.cpp
2424
label_resolution.cpp
2525
lif_cell_group.cpp
26+
adex_cell_group.cpp
2627
cable_cell_group.cpp
2728
mechcat.cpp
2829
mechinfo.cpp

arbor/adex_cell_group.cpp

Lines changed: 284 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,284 @@
1+
#include "adex_cell_group.hpp"
2+
3+
#include <arbor/arbexcept.hpp>
4+
5+
#include "arbor/math.hpp"
6+
#include "util/rangeutil.hpp"
7+
#include "util/span.hpp"
8+
#include "label_resolution.hpp"
9+
#include "profile/profiler_macro.hpp"
10+
11+
using namespace arb;
12+
13+
// Constructor containing gid of first cell in a group and a container of all cells.
14+
adex_cell_group::adex_cell_group(const std::vector<cell_gid_type>& gids,
15+
const recipe& rec,
16+
cell_label_range& cg_sources,
17+
cell_label_range& cg_targets):
18+
gids_(gids) {
19+
20+
for (auto gid: gids_) {
21+
const auto& cell = util::any_cast<adex_cell>(rec.get_cell_description(gid));
22+
// set up cell state
23+
cells_.push_back(cell);
24+
// tell our caller about this cell's connections
25+
cg_sources.add_cell();
26+
cg_targets.add_cell();
27+
cg_sources.add_label(hash_value(cell.source), {0, 1});
28+
cg_targets.add_label(hash_value(cell.target), {0, 1});
29+
// insert probes where needed
30+
auto probes = rec.get_probes(gid);
31+
for (const auto& probe: probes) {
32+
if (probe.address.type() == typeid(adex_probe_voltage)) {
33+
cell_address_type addr{gid, probe.tag};
34+
if (probes_.contains(addr)) throw dup_cell_probe(cell_kind::adex, gid, probe.tag);
35+
probes_.insert_or_assign(addr, adex_probe_info{adex_probe_kind::voltage, {}});
36+
}
37+
else if (probe.address.type() == typeid(adex_probe_adaption)) {
38+
cell_address_type addr{gid, probe.tag};
39+
if (probes_.contains(addr)) throw dup_cell_probe(cell_kind::adex, gid, probe.tag);
40+
probes_.insert_or_assign(addr, adex_probe_info{adex_probe_kind::adaption, {}});
41+
}
42+
else {
43+
throw bad_cell_probe{cell_kind::adex, gid};
44+
}
45+
}
46+
// set up the internal state
47+
next_update_.push_back(0);
48+
current_time_.push_back(0);
49+
}
50+
}
51+
52+
cell_kind adex_cell_group::get_cell_kind() const {
53+
return cell_kind::adex;
54+
}
55+
56+
void adex_cell_group::advance(epoch ep, time_type dt, const event_lane_subrange& event_lanes) {
57+
PE(adex);
58+
for (auto lid: util::make_span(gids_.size())) {
59+
// Advance each cell independently.
60+
advance_cell(ep.t1, dt, lid, event_lanes);
61+
}
62+
PL(adex);
63+
}
64+
65+
const std::vector<spike>& adex_cell_group::spikes() const {
66+
return spikes_;
67+
}
68+
69+
void adex_cell_group::clear_spikes() {
70+
spikes_.clear();
71+
}
72+
73+
void adex_cell_group::add_sampler(sampler_association_handle h,
74+
cell_member_predicate probeset_ids,
75+
schedule sched,
76+
sampler_function fn) {
77+
std::lock_guard<std::mutex> guard(sampler_mex_);
78+
std::vector<cell_address_type> probeset;
79+
for (const auto& [k, v]: probes_) {
80+
if (probeset_ids(k)) probeset.push_back(k);
81+
}
82+
auto assoc = arb::sampler_association{std::move(sched),
83+
std::move(fn),
84+
std::move(probeset)};
85+
auto result = samplers_.insert({h, std::move(assoc)});
86+
arb_assert(result.second);
87+
}
88+
89+
void adex_cell_group::remove_sampler(sampler_association_handle h) {
90+
std::lock_guard<std::mutex> guard(sampler_mex_);
91+
samplers_.erase(h);
92+
}
93+
void adex_cell_group::remove_all_samplers() {
94+
std::lock_guard<std::mutex> guard(sampler_mex_);
95+
samplers_.clear();
96+
}
97+
98+
void adex_cell_group::reset() {
99+
spikes_.clear();
100+
}
101+
102+
// integrate a single cell's state from current time `cur` tos final time `end`.
103+
// Extra parameters
104+
// * the cell cannot be updated until time `nxt`, which might be in the past or future.
105+
//
106+
// We can be in three states:
107+
// 1. nxt <= cur: we can simply update the cell without further consideration
108+
// 2. cur < nxt <= end: we perform two steps:
109+
// a. cur - nxt: refractory period, just manipulate w
110+
// b. nxt - end: normal dynamics, add spike
111+
// 3. nxt > end. Skip everything
112+
void integrate_until(adex_lowered_cell& cell, const time_type end, const time_type& nxt, time_type& cur) {
113+
// perform pre-step to skip refractory period. This _might_ put cell state beyond the epoch end.
114+
if (nxt > cur) cur = std::min(nxt, end);
115+
// if we still have time left, perform the integration.
116+
if (nxt > end) return;
117+
// dT
118+
auto delta = end - cur;
119+
// membrane potential deviation from resting value
120+
auto dE = cell.V_m - cell.E_L;
121+
// leak current
122+
auto il = cell.g*dE;
123+
// spike current
124+
auto is = cell.g*cell.delta*exp((cell.V_m - cell.V_th)/cell.delta);
125+
// potential delta
126+
auto dV = (is - il - cell.w)/cell.C_m;
127+
cell.V_m += delta*dV;
128+
129+
auto dW = (cell.a*dE - cell.w)/cell.tau;
130+
cell.w += delta*dW;
131+
cur = end;
132+
}
133+
134+
void check_spike(adex_lowered_cell& cell, const time_type time, time_type& nxt, const cell_gid_type gid, std::vector<spike>& spikes) {
135+
if (time > nxt && cell.V_m >= cell.V_th) {
136+
spikes.emplace_back(cell_member_type{gid, 0}, time);
137+
// reset membrane potential
138+
cell.V_m = cell.E_R;
139+
// schedule next update
140+
nxt = time + cell.t_ref;
141+
cell.w += cell.b;
142+
}
143+
}
144+
145+
void adex_cell_group::advance_cell(time_type t_fin,
146+
time_type dt,
147+
cell_gid_type lid,
148+
const event_lane_subrange& event_lanes) {
149+
auto time = current_time_[lid];
150+
auto gid = gids_[lid];
151+
// Flattened sampler map
152+
std::vector<probe_metadata> sample_metadata;
153+
std::vector<sampler_association_handle> sample_callbacks;
154+
std::vector<std::vector<sample_record>> sample_records;
155+
156+
struct sample_event {
157+
time_type time;
158+
adex_probe_kind kind;
159+
double* data;
160+
};
161+
162+
std::vector<sample_event> sample_events;
163+
std::vector<double> sample_data;
164+
165+
if (!samplers_.empty()) {
166+
auto tlast = time;
167+
std::vector<size_t> sample_sizes;
168+
std::size_t total_size = 0;
169+
{
170+
std::lock_guard<std::mutex> guard(sampler_mex_);
171+
for (auto& [hdl, assoc]: samplers_) {
172+
// No need to generate events
173+
if (assoc.probeset_ids.empty()) continue;
174+
// Construct sampling times, might give us the last time we sampled, so skip that.
175+
auto times = util::make_range(assoc.sched.events(tlast, t_fin));
176+
// while (!times.empty() && times.front() == tlast) times.left++;
177+
if (times.empty()) continue;
178+
for (unsigned idx = 0; idx < assoc.probeset_ids.size(); ++idx) {
179+
const auto& pid = assoc.probeset_ids[idx];
180+
if (pid.gid != gid) continue;
181+
const auto& probe = probes_.at(pid);
182+
sample_metadata.push_back({pid, idx, util::any_ptr{&probe.metadata}});
183+
sample_callbacks.push_back(hdl);
184+
sample_records.emplace_back();
185+
auto& records = sample_records.back();
186+
sample_sizes.push_back(times.size());
187+
total_size += times.size();
188+
for (auto t: times) {
189+
records.push_back(sample_record{t, nullptr});
190+
sample_events.push_back(sample_event{t, probe.kind, nullptr});
191+
}
192+
}
193+
}
194+
}
195+
// Flat list of things to sample
196+
// NOTE: Need to allocate in one go, else reallocation will mess up the pointers!
197+
sample_data.resize(total_size);
198+
auto rx = 0;
199+
for (unsigned ix = 0; ix < sample_sizes.size(); ++ix) {
200+
auto size = sample_sizes[ix];
201+
for (size_t kx = 0; kx < size; ++kx) {
202+
sample_records[ix][kx].data = const_cast<const double*>(sample_data.data() + rx);
203+
sample_events[rx].data = sample_data.data() + rx;
204+
++rx;
205+
}
206+
}
207+
}
208+
util::sort_by(sample_events, [](const auto& s) { return s.time; });
209+
auto n_samples = sample_events.size();
210+
211+
auto& cell = cells_[lid];
212+
auto n_events = static_cast<int>(!event_lanes.empty() ? event_lanes[lid].size() : 0);
213+
auto evt_idx = 0;
214+
size_t spl_idx = 0;
215+
while (time < t_fin) {
216+
auto t_end = std::min(t_fin, time + dt);
217+
// forward progress?
218+
arb_assert(t_end > time);
219+
auto V_0 = cell.V_m;
220+
auto W_0 = cell.w;
221+
// Process events in [time, time + dt)
222+
// delivering each at the exact time
223+
for (;; ++evt_idx) {
224+
if (evt_idx >= n_events) break;
225+
if (event_lanes[lid][evt_idx].time >= t_end) break;
226+
227+
const auto& evt = event_lanes[lid][evt_idx];
228+
integrate_until(cell, evt.time, next_update_[lid], current_time_[lid]);
229+
// NOTE we _could check here instead or in addition.
230+
// check_spike(cell, evt.time, next_update_[lid], gid, spikes_);
231+
if (next_update_[lid] <= evt.time) cell.V_m += evt.weight/cell.C_m;
232+
check_spike(cell, evt.time, next_update_[lid], gid, spikes_);
233+
}
234+
// if there's time left before t_end, integrate until that
235+
integrate_until(cell, t_end, next_update_[lid], current_time_[lid]);
236+
check_spike(cell, t_end, next_update_[lid], gid, spikes_);
237+
238+
// now process the sampling events
239+
for (;; ++spl_idx) {
240+
if (spl_idx >= n_samples) break;
241+
const auto& evt = sample_events[spl_idx];
242+
if (evt.time > t_end) break;
243+
// interpolation paramter
244+
auto t = (evt.time - time)/dt;
245+
if (evt.kind == adex_probe_kind::voltage) *evt.data = math::lerp(V_0, cell.V_m, t);
246+
if (evt.kind == adex_probe_kind::adaption) *evt.data = math::lerp(W_0, cell.w, t);
247+
}
248+
249+
time = t_end;
250+
}
251+
252+
arb_assert(time == t_fin);
253+
arb_assert(evt_idx == n_events);
254+
arb_assert(spl_idx == n_samples);
255+
256+
auto n_samplers = sample_callbacks.size();
257+
{
258+
std::lock_guard<std::mutex> guard{sampler_mex_};
259+
for (size_t s_idx = 0; s_idx < n_samplers; ++s_idx) {
260+
const auto& sd = sample_records[s_idx];
261+
auto hdl = sample_callbacks[s_idx];
262+
const auto& fun = samplers_[hdl].sampler;
263+
arb_assert(fun);
264+
fun(sample_metadata[s_idx], sd.size(), sd.data());
265+
}
266+
}
267+
}
268+
269+
void adex_cell_group::t_serialize(serializer& ser, const std::string& k) const {
270+
serialize(ser, k, *this);
271+
}
272+
273+
void adex_cell_group::t_deserialize(serializer& ser, const std::string& k) {
274+
deserialize(ser, k, *this);
275+
}
276+
277+
std::vector<probe_metadata> adex_cell_group::get_probe_metadata(const cell_address_type& key) const {
278+
// SAFETY: Probe associations are fixed after construction, so we do not
279+
// need to grab the mutex.
280+
if (auto it = probes_.find(key); it != probes_.end()) {
281+
return {probe_metadata{key, 0, &it->second.metadata}};
282+
}
283+
return {};
284+
}

0 commit comments

Comments
 (0)