|
| 1 | +#include "adex_cell_group.hpp" |
| 2 | + |
| 3 | +#include <arbor/arbexcept.hpp> |
| 4 | + |
| 5 | +#include "arbor/math.hpp" |
| 6 | +#include "util/rangeutil.hpp" |
| 7 | +#include "util/span.hpp" |
| 8 | +#include "label_resolution.hpp" |
| 9 | +#include "profile/profiler_macro.hpp" |
| 10 | + |
| 11 | +using namespace arb; |
| 12 | + |
| 13 | +// Constructor containing gid of first cell in a group and a container of all cells. |
| 14 | +adex_cell_group::adex_cell_group(const std::vector<cell_gid_type>& gids, |
| 15 | + const recipe& rec, |
| 16 | + cell_label_range& cg_sources, |
| 17 | + cell_label_range& cg_targets): |
| 18 | + gids_(gids) { |
| 19 | + |
| 20 | + for (auto gid: gids_) { |
| 21 | + const auto& cell = util::any_cast<adex_cell>(rec.get_cell_description(gid)); |
| 22 | + // set up cell state |
| 23 | + cells_.push_back(cell); |
| 24 | + // tell our caller about this cell's connections |
| 25 | + cg_sources.add_cell(); |
| 26 | + cg_targets.add_cell(); |
| 27 | + cg_sources.add_label(hash_value(cell.source), {0, 1}); |
| 28 | + cg_targets.add_label(hash_value(cell.target), {0, 1}); |
| 29 | + // insert probes where needed |
| 30 | + auto probes = rec.get_probes(gid); |
| 31 | + for (const auto& probe: probes) { |
| 32 | + if (probe.address.type() == typeid(adex_probe_voltage)) { |
| 33 | + cell_address_type addr{gid, probe.tag}; |
| 34 | + if (probes_.contains(addr)) throw dup_cell_probe(cell_kind::adex, gid, probe.tag); |
| 35 | + probes_.insert_or_assign(addr, adex_probe_info{adex_probe_kind::voltage, {}}); |
| 36 | + } |
| 37 | + else if (probe.address.type() == typeid(adex_probe_adaption)) { |
| 38 | + cell_address_type addr{gid, probe.tag}; |
| 39 | + if (probes_.contains(addr)) throw dup_cell_probe(cell_kind::adex, gid, probe.tag); |
| 40 | + probes_.insert_or_assign(addr, adex_probe_info{adex_probe_kind::adaption, {}}); |
| 41 | + } |
| 42 | + else { |
| 43 | + throw bad_cell_probe{cell_kind::adex, gid}; |
| 44 | + } |
| 45 | + } |
| 46 | + // set up the internal state |
| 47 | + next_update_.push_back(0); |
| 48 | + current_time_.push_back(0); |
| 49 | + } |
| 50 | +} |
| 51 | + |
// Report the kind of cell this group simulates: adaptive exponential
// integrate-and-fire (AdEx).
cell_kind adex_cell_group::get_cell_kind() const {
    return cell_kind::adex;
}
| 55 | + |
| 56 | +void adex_cell_group::advance(epoch ep, time_type dt, const event_lane_subrange& event_lanes) { |
| 57 | + PE(adex); |
| 58 | + for (auto lid: util::make_span(gids_.size())) { |
| 59 | + // Advance each cell independently. |
| 60 | + advance_cell(ep.t1, dt, lid, event_lanes); |
| 61 | + } |
| 62 | + PL(adex); |
| 63 | +} |
| 64 | + |
// Read-only access to the spikes accumulated since the last `clear_spikes`.
const std::vector<spike>& adex_cell_group::spikes() const {
    return spikes_;
}
| 68 | + |
// Drop all spikes collected so far.
void adex_cell_group::clear_spikes() {
    spikes_.clear();
}
| 72 | + |
| 73 | +void adex_cell_group::add_sampler(sampler_association_handle h, |
| 74 | + cell_member_predicate probeset_ids, |
| 75 | + schedule sched, |
| 76 | + sampler_function fn) { |
| 77 | + std::lock_guard<std::mutex> guard(sampler_mex_); |
| 78 | + std::vector<cell_address_type> probeset; |
| 79 | + for (const auto& [k, v]: probes_) { |
| 80 | + if (probeset_ids(k)) probeset.push_back(k); |
| 81 | + } |
| 82 | + auto assoc = arb::sampler_association{std::move(sched), |
| 83 | + std::move(fn), |
| 84 | + std::move(probeset)}; |
| 85 | + auto result = samplers_.insert({h, std::move(assoc)}); |
| 86 | + arb_assert(result.second); |
| 87 | +} |
| 88 | + |
// Detach the sampler registered under handle `h`, if any.
void adex_cell_group::remove_sampler(sampler_association_handle h) {
    std::lock_guard<std::mutex> guard(sampler_mex_);
    samplers_.erase(h);
}
// Detach all registered samplers.
void adex_cell_group::remove_all_samplers() {
    std::lock_guard<std::mutex> guard(sampler_mex_);
    samplers_.clear();
}
| 97 | + |
// Reset the group's externally visible state.
// NOTE(review): only the spike buffer is cleared here; `cells_` (V_m, w),
// `current_time_` and `next_update_` keep whatever values the last `advance`
// left them with. Confirm this is intentional before relying on `reset` to
// restart a simulation from t = 0.
void adex_cell_group::reset() {
    spikes_.clear();
}
| 101 | + |
| 102 | +// integrate a single cell's state from current time `cur` tos final time `end`. |
| 103 | +// Extra parameters |
| 104 | +// * the cell cannot be updated until time `nxt`, which might be in the past or future. |
| 105 | +// |
| 106 | +// We can be in three states: |
| 107 | +// 1. nxt <= cur: we can simply update the cell without further consideration |
| 108 | +// 2. cur < nxt <= end: we perform two steps: |
| 109 | +// a. cur - nxt: refractory period, just manipulate w |
| 110 | +// b. nxt - end: normal dynamics, add spike |
| 111 | +// 3. nxt > end. Skip everything |
| 112 | +void integrate_until(adex_lowered_cell& cell, const time_type end, const time_type& nxt, time_type& cur) { |
| 113 | + // perform pre-step to skip refractory period. This _might_ put cell state beyond the epoch end. |
| 114 | + if (nxt > cur) cur = std::min(nxt, end); |
| 115 | + // if we still have time left, perform the integration. |
| 116 | + if (nxt > end) return; |
| 117 | + // dT |
| 118 | + auto delta = end - cur; |
| 119 | + // membrane potential deviation from resting value |
| 120 | + auto dE = cell.V_m - cell.E_L; |
| 121 | + // leak current |
| 122 | + auto il = cell.g*dE; |
| 123 | + // spike current |
| 124 | + auto is = cell.g*cell.delta*exp((cell.V_m - cell.V_th)/cell.delta); |
| 125 | + // potential delta |
| 126 | + auto dV = (is - il - cell.w)/cell.C_m; |
| 127 | + cell.V_m += delta*dV; |
| 128 | + |
| 129 | + auto dW = (cell.a*dE - cell.w)/cell.tau; |
| 130 | + cell.w += delta*dW; |
| 131 | + cur = end; |
| 132 | +} |
| 133 | + |
| 134 | +void check_spike(adex_lowered_cell& cell, const time_type time, time_type& nxt, const cell_gid_type gid, std::vector<spike>& spikes) { |
| 135 | + if (time > nxt && cell.V_m >= cell.V_th) { |
| 136 | + spikes.emplace_back(cell_member_type{gid, 0}, time); |
| 137 | + // reset membrane potential |
| 138 | + cell.V_m = cell.E_R; |
| 139 | + // schedule next update |
| 140 | + nxt = time + cell.t_ref; |
| 141 | + cell.w += cell.b; |
| 142 | + } |
| 143 | +} |
| 144 | + |
// Integrate the cell at local index `lid` from its current time up to `t_fin`
// in steps of at most `dt`, delivering the events in `event_lanes[lid]` at
// their exact times, recording threshold crossings as spikes, and servicing
// any samplers attached to this cell's probes.
void adex_cell_group::advance_cell(time_type t_fin,
                                   time_type dt,
                                   cell_gid_type lid,
                                   const event_lane_subrange& event_lanes) {
    auto time = current_time_[lid];
    auto gid = gids_[lid];
    // Flattened sampler map: parallel arrays, one entry per (sampler, probe)
    // pair that targets this cell.
    std::vector<probe_metadata> sample_metadata;
    std::vector<sampler_association_handle> sample_callbacks;
    std::vector<std::vector<sample_record>> sample_records;

    // One pending sample: when to take it, which quantity, and where the
    // resulting value is written (points into sample_data below).
    struct sample_event {
        time_type time;
        adex_probe_kind kind;
        double* data;
    };

    std::vector<sample_event> sample_events;
    std::vector<double> sample_data;

    if (!samplers_.empty()) {
        auto tlast = time;
        std::vector<size_t> sample_sizes;
        std::size_t total_size = 0;
        {
            // Hold the lock only while walking the sampler table; the actual
            // integration below runs unlocked.
            std::lock_guard<std::mutex> guard(sampler_mex_);
            for (auto& [hdl, assoc]: samplers_) {
                // No need to generate events
                if (assoc.probeset_ids.empty()) continue;
                // Construct sampling times, might give us the last time we sampled, so skip that.
                auto times = util::make_range(assoc.sched.events(tlast, t_fin));
                // while (!times.empty() && times.front() == tlast) times.left++;
                if (times.empty()) continue;
                for (unsigned idx = 0; idx < assoc.probeset_ids.size(); ++idx) {
                    const auto& pid = assoc.probeset_ids[idx];
                    // Skip probes that belong to other cells of this group.
                    if (pid.gid != gid) continue;
                    const auto& probe = probes_.at(pid);
                    sample_metadata.push_back({pid, idx, util::any_ptr{&probe.metadata}});
                    sample_callbacks.push_back(hdl);
                    sample_records.emplace_back();
                    auto& records = sample_records.back();
                    sample_sizes.push_back(times.size());
                    total_size += times.size();
                    // Data pointers are filled in after sample_data is sized.
                    for (auto t: times) {
                        records.push_back(sample_record{t, nullptr});
                        sample_events.push_back(sample_event{t, probe.kind, nullptr});
                    }
                }
            }
        }
        // Flat list of things to sample
        // NOTE: Need to allocate in one go, else reallocation will mess up the pointers!
        sample_data.resize(total_size);
        // Patch the null data pointers: records and events share one slot in
        // sample_data per sample, assigned in insertion order.
        auto rx = 0;
        for (unsigned ix = 0; ix < sample_sizes.size(); ++ix) {
            auto size = sample_sizes[ix];
            for (size_t kx = 0; kx < size; ++kx) {
                sample_records[ix][kx].data = const_cast<const double*>(sample_data.data() + rx);
                sample_events[rx].data = sample_data.data() + rx;
                ++rx;
            }
        }
    }
    // Process pending samples in chronological order alongside integration.
    util::sort_by(sample_events, [](const auto& s) { return s.time; });
    auto n_samples = sample_events.size();

    auto& cell = cells_[lid];
    auto n_events = static_cast<int>(!event_lanes.empty() ? event_lanes[lid].size() : 0);
    auto evt_idx = 0;
    size_t spl_idx = 0;
    while (time < t_fin) {
        // The final step may be clipped: t_end - time can be < dt.
        auto t_end = std::min(t_fin, time + dt);
        // forward progress?
        arb_assert(t_end > time);
        // Remember the state at the start of the step for interpolation.
        auto V_0 = cell.V_m;
        auto W_0 = cell.w;
        // Process events in [time, time + dt)
        // delivering each at the exact time
        for (;; ++evt_idx) {
            if (evt_idx >= n_events) break;
            if (event_lanes[lid][evt_idx].time >= t_end) break;

            const auto& evt = event_lanes[lid][evt_idx];
            // Advance to the event time, then apply the synaptic weight —
            // but only if the cell is no longer refractory at that time.
            integrate_until(cell, evt.time, next_update_[lid], current_time_[lid]);
            // NOTE we _could check here instead or in addition.
            // check_spike(cell, evt.time, next_update_[lid], gid, spikes_);
            if (next_update_[lid] <= evt.time) cell.V_m += evt.weight/cell.C_m;
            check_spike(cell, evt.time, next_update_[lid], gid, spikes_);
        }
        // if there's time left before t_end, integrate until that
        integrate_until(cell, t_end, next_update_[lid], current_time_[lid]);
        check_spike(cell, t_end, next_update_[lid], gid, spikes_);

        // now process the sampling events
        for (;; ++spl_idx) {
            if (spl_idx >= n_samples) break;
            const auto& evt = sample_events[spl_idx];
            if (evt.time > t_end) break;
            // interpolation parameter
            // NOTE(review): divisor is dt even when the step was clipped to
            // t_fin (t_end - time < dt), which compresses the interpolation
            // range — confirm whether (t_end - time) was intended instead.
            auto t = (evt.time - time)/dt;
            if (evt.kind == adex_probe_kind::voltage) *evt.data = math::lerp(V_0, cell.V_m, t);
            if (evt.kind == adex_probe_kind::adaption) *evt.data = math::lerp(W_0, cell.w, t);
        }

        time = t_end;
    }

    // All time, event, and sample cursors must have been fully consumed.
    arb_assert(time == t_fin);
    arb_assert(evt_idx == n_events);
    arb_assert(spl_idx == n_samples);

    // Finally hand the collected sample values to their callbacks.
    auto n_samplers = sample_callbacks.size();
    {
        std::lock_guard<std::mutex> guard{sampler_mex_};
        for (size_t s_idx = 0; s_idx < n_samplers; ++s_idx) {
            const auto& sd = sample_records[s_idx];
            auto hdl = sample_callbacks[s_idx];
            // NOTE(review): operator[] default-constructs if `hdl` was
            // removed while the lock was released above; the arb_assert
            // below would then fire — confirm removal cannot race advance.
            const auto& fun = samplers_[hdl].sampler;
            arb_assert(fun);
            fun(sample_metadata[s_idx], sd.size(), sd.data());
        }
    }
}
| 268 | + |
// Serialization hook: forward to the free `serialize` found via ADL.
void adex_cell_group::t_serialize(serializer& ser, const std::string& k) const {
    serialize(ser, k, *this);
}
| 272 | + |
// Deserialization hook: forward to the free `deserialize` found via ADL.
void adex_cell_group::t_deserialize(serializer& ser, const std::string& k) {
    deserialize(ser, k, *this);
}
| 276 | + |
| 277 | +std::vector<probe_metadata> adex_cell_group::get_probe_metadata(const cell_address_type& key) const { |
| 278 | + // SAFETY: Probe associations are fixed after construction, so we do not |
| 279 | + // need to grab the mutex. |
| 280 | + if (auto it = probes_.find(key); it != probes_.end()) { |
| 281 | + return {probe_metadata{key, 0, &it->second.metadata}}; |
| 282 | + } |
| 283 | + return {}; |
| 284 | +} |
0 commit comments