Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

📊 Improve spike delivery #2222

Merged
145 changes: 84 additions & 61 deletions arbor/communication/communicator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ void communicator::update_connections(const connectivity& rec,
PE(init:communicator:update:clear);
// Forget all lingering information
connections_.clear();
ext_connections_.clear();
connection_part_.clear();
index_divisions_.clear();
PL();
Expand Down Expand Up @@ -123,8 +124,8 @@ void communicator::update_connections(const connectivity& rec,
// to do this in place.
// NOTE: The connections are partitioned by the domain of their source gid.
PE(init:communicator:update:connections);
connections_.resize(n_cons);
ext_connections_.resize(n_ext_cons);
std::vector<connection> connections(n_cons);
std::vector<connection> ext_connections(n_ext_cons);
auto offsets = connection_part_; // Copy, as we use this as the list of current target indices to write into
std::size_t ext = 0;
auto src_domain = src_domains.begin();
Expand All @@ -140,15 +141,15 @@ void communicator::update_connections(const connectivity& rec,
auto tgt_lid = target_resolver.resolve(tgt_gid, conn.target);
auto offset = offsets[*src_domain]++;
++src_domain;
connections_[offset] = {{src_gid, src_lid}, tgt_lid, conn.weight, conn.delay, index};
connections[offset] = {{src_gid, src_lid}, tgt_lid, conn.weight, conn.delay, index};
}
for (const auto cidx: util::make_span(part_ext_connections[index], part_ext_connections[index+1])) {
const auto& conn = gid_ext_connections[cidx];
auto src = global_cell_of(conn.source);
auto src_gid = conn.source.rid;
if(is_external(src_gid)) throw arb::source_gid_exceeds_limit(tgt_gid, src_gid);
auto tgt_lid = target_resolver.resolve(tgt_gid, conn.target);
ext_connections_[ext] = {src, tgt_lid, conn.weight, conn.delay, index};
ext_connections[ext] = {src, tgt_lid, conn.weight, conn.delay, index};
++ext;
}
}
Expand All @@ -168,9 +169,14 @@ void communicator::update_connections(const connectivity& rec,
const auto& cp = connection_part_;
threading::parallel_for::apply(0, num_domains_, thread_pool_.get(),
[&](cell_size_type i) {
util::sort(util::subrange_view(connections_, cp[i], cp[i+1]));
util::sort(util::subrange_view(connections, cp[i], cp[i+1]));
});
std::sort(ext_connections_.begin(), ext_connections_.end());
std::sort(ext_connections.begin(), ext_connections.end());
PL();

PE(init:communicator:update:destructure_connections);
connections_.make(connections);
ext_connections_.make(ext_connections);
PL();
}

Expand All @@ -181,12 +187,12 @@ std::pair<cell_size_type, cell_size_type> communicator::group_queue_range(cell_s

time_type communicator::min_delay() {
time_type res = std::numeric_limits<time_type>::max();
res = std::accumulate(connections_.begin(), connections_.end(),
res,
[](auto&& acc, auto&& el) { return std::min(acc, time_type(el.delay)); });
res = std::accumulate(ext_connections_.begin(), ext_connections_.end(),
res,
[](auto&& acc, auto&& el) { return std::min(acc, time_type(el.delay)); });
res = std::accumulate(connections_.delays.begin(), connections_.delays.end(),
res,
[](auto&& acc, time_type del) { return std::min(acc, del); });
res = std::accumulate(ext_connections_.delays.begin(), ext_connections_.delays.end(),
res,
[](auto&& acc, time_type del) { return std::min(acc, del); });
res = distributed_->min(res);
return res;
}
Expand Down Expand Up @@ -228,19 +234,38 @@ void communicator::set_remote_spike_filter(const spike_predicate& p) { remote_sp
void communicator::remote_ctrl_send_continue(const epoch& e) { distributed_->remote_ctrl_send_continue(e); }
void communicator::remote_ctrl_send_done() { distributed_->remote_ctrl_send_done(); }

// Turn a run of spikes into events for one connection.
//
// cons, idx: the connection list and the index of the connection to serve.
// spk, end:  current position in, and end of, the spike range. `spk` is taken
//            by reference and is advanced past every spike consumed here.
// out:       per-cell event queues; the connection's idx_on_domain entry
//            selects the queue that receives the events.
//
// Starting at `spk`, every consecutive spike whose source equals the
// connection's source yields one event (destination, spike time + delay,
// weight). Iteration stops at the first spike with a different source or at
// `end`.
template<typename It>
void enqueue_from_source(const communicator::connection_list& cons,
                         const size_t idx,
                         It& spk,
                         const It end,
                         std::vector<pse_vector>& out) {
    // Copy the connection's fields out of the parallel arrays once, before
    // walking the spikes.
    const auto source = cons.srcs[idx];
    const auto target = cons.dests[idx];
    const auto delay  = cons.delays[idx];
    const auto weight = cons.weights[idx];
    auto& queue = out[cons.idx_on_domain[idx]];
    while (spk != end && spk->source == source) {
        queue.emplace_back(target, spk->time + delay, weight);
        ++spk;
    }
}

// Internal helper to append to the event queues
template<typename S, typename C>
void append_events_from_domain(C cons,
S spks,
template<typename S>
void append_events_from_domain(const communicator::connection_list& cons, size_t cn, const size_t ce,
const S& spks,
std::vector<pse_vector>& queues) {
// Predicate for partitioning
struct spike_pred {
bool operator()(const spike& spk, const cell_member_type& src) { return spk.source < src; }
bool operator()(const cell_member_type& src, const spike& spk) { return src < spk.source; }
};

auto sp = spks.begin(), se = spks.end();
auto cn = cons.begin(), ce = cons.end();
if (se == sp) return;
// We have a choice of whether to walk spikes or connections:
// i.e., we can iterate over the spikes, and for each spike search
// for the connections that have the same source; or alternatively
Expand All @@ -251,64 +276,62 @@ void append_events_from_domain(C cons,
// complexity of order max(S log(C), C log(S)), where S is the
// number of spikes, and C is the number of connections.
if (cons.size() < spks.size()) {
while (cn != ce && sp != se) {
auto sources = std::equal_range(sp, se, cn->source, spike_pred());
for (auto s: util::make_range(sources)) {
queues[cn->index_on_domain].push_back(make_event(*cn, s));
}
sp = sources.first;
++cn;
for (; sp != se && cn < ce; ++cn) {
// sp marks the first spike not yet consumed: either the start of the
// spike range or one past the spikes matched by the previous connection.
sp = std::lower_bound(sp, se,
cons.srcs[cn],
[](const auto& spk, const auto& src) { return spk.source < src; });
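// now, sp is at the first spike whose source is not less than
// cons.srcs[cn], i.e. the start of the equal-source range if one exists.
// NOTE(review): enqueue_from_source takes sp by reference and advances it
// past that range, so a following connection with the same source finds
// no spikes left here. The other branch resets sp for each connection;
// confirm whether the same reset is needed in this branch.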
enqueue_from_source(cons, cn, sp, se, queues);
}
}
else {
while (cn != ce && sp != se) {
auto targets = std::equal_range(cn, ce, sp->source);
for (auto c: util::make_range(targets)) {
queues[c.index_on_domain].push_back(make_event(c, *sp));
while (sp != se) {
auto beg = sp;
auto src = beg->source;
// Here, `cn` is the index of the first connection whose source
// is larger or equal to the spike's source. It may be `ce` if
// all elements compare < to spk.source.
cn = std::lower_bound(cons.srcs.begin() + cn,
cons.srcs.begin() + ce,
src)
- cons.srcs.begin();
for (; cn < ce && cons.srcs[cn] == src; ++cn) {
// Reset the spike iterator as we walk the same sub-range
// for each connection with the same source.
sp = beg;
// If we ever get multiple spikes from the same source, treat
// them all. This is mostly rare.
enqueue_from_source(cons, cn, sp, se, queues);
}
cn = targets.first;
++sp;
while (sp != se && sp->source == src) ++sp;
}
}
}

void communicator::make_event_queues(
const gathered_vector<spike>& global_spikes,
std::vector<pse_vector>& queues,
const std::vector<spike>& external_spikes) {
void communicator::make_event_queues(communicator::spikes& spikes,
std::vector<pse_vector>& queues) {
arb_assert(queues.size()==num_local_cells_);
const auto& sp = global_spikes.partition();
const auto& sp = spikes.from_local.partition();
const auto& cp = connection_part_;
for (auto dom: util::make_span(num_domains_)) {
append_events_from_domain(util::subrange_view(connections_, cp[dom], cp[dom+1]),
util::subrange_view(global_spikes.values(), sp[dom], sp[dom+1]),
append_events_from_domain(connections_, cp[dom], cp[dom+1],
util::subrange_view(spikes.from_local.values(), sp[dom], sp[dom+1]),
queues);
}
num_local_events_ = util::sum_by(queues, [](const auto& q) {return q.size();}, num_local_events_);
// Now that all local spikes have been processed, consume the remote events coming in.
// - turn all gids into externals
auto spikes = external_spikes;
std::for_each(spikes.begin(),
spikes.end(),
std::for_each(spikes.from_remote.begin(), spikes.from_remote.end(),
[](auto& s) { s.source = global_cell_of(s.source); });
append_events_from_domain(ext_connections_, spikes, queues);
append_events_from_domain(ext_connections_, 0, ext_connections_.size(), spikes.from_remote, queues);
}

std::uint64_t communicator::num_spikes() const {
return num_spikes_;
}

void communicator::set_num_spikes(std::uint64_t n) {
num_spikes_ = n;
}

cell_size_type communicator::num_local_cells() const {
return num_local_cells_;
}

const std::vector<connection>& communicator::connections() const {
return connections_;
}
std::uint64_t communicator::num_spikes() const { return num_spikes_; }
void communicator::set_num_spikes(std::uint64_t n) { num_spikes_ = n; }
cell_size_type communicator::num_local_cells() const { return num_local_cells_; }
const communicator::connection_list& communicator::connections() const { return connections_; }

void communicator::reset() {
num_spikes_ = 0;
Expand Down
45 changes: 37 additions & 8 deletions arbor/communication/communicator.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <vector>
#include <unordered_set>

#include <arbor/export.hpp>
#include <arbor/common_types.hpp>
Expand Down Expand Up @@ -63,19 +64,14 @@ class ARB_ARBOR_API communicator {
/// all events that must be delivered to targets in that cell group as a
/// result of the global spike exchange, plus any events that were already
/// in the list.
void make_event_queues(
const gathered_vector<spike>& global_spikes,
std::vector<pse_vector>& queues,
const std::vector<spike>& external_spikes={});
void make_event_queues(spikes& spks, std::vector<pse_vector>& queues);

/// Returns the total number of global spikes over the duration of the simulation
std::uint64_t num_spikes() const;
void set_num_spikes(std::uint64_t n);

cell_size_type num_local_cells() const;

const std::vector<connection>& connections() const;

void reset();

// used to communicate with coupled simulations
Expand All @@ -89,13 +85,46 @@ class ARB_ARBOR_API communicator {

void set_remote_spike_filter(const spike_predicate&);

// TODO: This is public for now.
struct connection_list {
std::vector<cell_size_type> idx_on_domain;
std::vector<cell_member_type> srcs;
std::vector<cell_lid_type> dests;
std::vector<float> weights;
std::vector<float> delays;

void make(const std::vector<connection>& cons) {
clear();
for (const auto& con: cons) {
idx_on_domain.push_back(con.index_on_domain);
srcs.push_back(con.source);
dests.push_back(con.destination);
weights.push_back(con.weight);
delays.push_back(con.delay);
}
}

void clear() {
idx_on_domain.clear();
srcs.clear();
dests.clear();
weights.clear();
delays.clear();
}

size_t size() const { return srcs.size(); }
};

const connection_list& connections() const;

private:

cell_size_type num_total_cells_ = 0;
cell_size_type num_local_cells_ = 0;
cell_size_type num_local_groups_ = 0;
cell_size_type num_domains_ = 0;
// Arbor internal connections
std::vector<connection> connections_;
connection_list connections_;
// partition of connections over the domains of the sources' ids.
std::vector<cell_size_type> connection_part_;
std::vector<cell_size_type> index_divisions_;
Expand All @@ -105,7 +134,7 @@ class ARB_ARBOR_API communicator {

// Connections from external simulators into Arbor.
// Currently we have no partitions/indices/acceleration structures
std::vector<connection> ext_connections_;
connection_list ext_connections_;

distributed_context_handle distributed_;
task_system_handle thread_pool_;
Expand Down
20 changes: 10 additions & 10 deletions arbor/include/arbor/spike_event.hpp
Original file line number Diff line number Diff line change
@@ -1,33 +1,33 @@
#pragma once

#include <arbor/arb_types.hpp>

#include <iosfwd>
#include <tuple>
#include <vector>

#include <arbor/export.hpp>
#include <arbor/serdes.hpp>
#include <arbor/common_types.hpp>
#include <arbor/util/lexcmp_def.hpp>

namespace arb {

// Events delivered to targets on cells with a cell group.

struct spike_event {
cell_lid_type target;
time_type time;
float weight;

friend bool operator==(const spike_event& l, const spike_event& r) {
return l.target==r.target && l.time==r.time && l.weight==r.weight;
}
cell_lid_type target = -1;
float weight = 0;
time_type time = -1;

friend bool operator<(const spike_event& l, const spike_event& r) {
return std::tie(l.time, l.target, l.weight) < std::tie(r.time, r.target, r.weight);
}
spike_event() = default;
constexpr spike_event(cell_lid_type tgt, time_type t, arb_weight_type w) noexcept: target(tgt), weight(w), time(t) {}

ARB_SERDES_ENABLE(spike_event, target, time, weight);
};

ARB_DEFINE_LEXICOGRAPHIC_ORDERING(spike_event,(a.time,a.target,a.weight),(b.time,b.target,b.weight))

using pse_vector = std::vector<spike_event>;

struct cell_spike_events {
Expand Down
Loading