diff --git a/arbor/backends/multicore/shared_state.cpp b/arbor/backends/multicore/shared_state.cpp index f6c0b20571..1afc67910c 100644 --- a/arbor/backends/multicore/shared_state.cpp +++ b/arbor/backends/multicore/shared_state.cpp @@ -7,6 +7,8 @@ #include #include +#include + #include #include #include @@ -479,6 +481,7 @@ void shared_state::instantiate(arb::mechanism& m, unsigned id, const mechanism_o m.ppack_.vec_di = cv_to_intdom.data(); m.ppack_.vec_dt = dt_cv.data(); m.ppack_.vec_v = voltage.data(); + m.ppack_.vec_v_peer = voltage.data(); m.ppack_.vec_i = current_density.data(); m.ppack_.vec_g = conductivity.data(); m.ppack_.temperature_degC = temperature_degC.data(); @@ -549,7 +552,7 @@ void shared_state::instantiate(arb::mechanism& m, unsigned id, const mechanism_o { // Allocate bulk storage std::size_t index_width_padded = extend_width(m, pos_data.cv.size()); - std::size_t count = mult_in_place + peer_indices + m.mech_.n_ions + 1; + std::size_t count = mult_in_place + peer_indices + m.mech_.n_ions + 2; store.indices_ = iarray(count*index_width_padded, 0, pad); chunk_writer writer(store.indices_.data(), index_width_padded); // Setup node indices @@ -587,6 +590,7 @@ void shared_state::instantiate(arb::mechanism& m, unsigned id, const mechanism_o // Peer CVs are only filled for gap junction mechanisms. They are used // to index the voltage at the other side of a gap-junction connection. if (peer_indices) m.ppack_.peer_index = writer.append(pos_data.peer_cv, pos_data.peer_cv.back()); + if (peer_indices) m.ppack_.peer_cg = writer.append(pos_data.peer_cg, 0); } } diff --git a/arbor/cell_group_factory.cpp b/arbor/cell_group_factory.cpp index 69af0ffdcb..a891261361 100644 --- a/arbor/cell_group_factory.cpp +++ b/arbor/cell_group_factory.cpp @@ -20,14 +20,14 @@ cell_group_ptr make_cell_group(Args&&... args) { } ARB_ARBOR_API cell_group_factory cell_kind_implementation( - cell_kind ck, backend_kind bk, const execution_context& ctx) + cell_kind ck, cell_gid_type cg, backend_kind bk, const execution_context& ctx) { using gid_vector = std::vector; switch (ck) { case cell_kind::cable: - return [bk, ctx](const gid_vector& gids, const recipe& rec, cell_label_range& cg_sources, cell_label_range& cg_targets) { - return make_cell_group(gids, rec, cg_sources, cg_targets, make_fvm_lowered_cell(bk, ctx)); + return [bk, ctx, cg](const gid_vector& gids, const recipe& rec, cell_label_range& cg_sources, cell_label_range& cg_targets) { + return make_cell_group(gids, rec, cg_sources, cg_targets, cg, make_fvm_lowered_cell(bk, ctx, cg)); }; case cell_kind::spike_source: diff --git a/arbor/cell_group_factory.hpp b/arbor/cell_group_factory.hpp index b8a321c423..3cf6cb2966 100644 --- a/arbor/cell_group_factory.hpp +++ b/arbor/cell_group_factory.hpp @@ -19,15 +19,15 @@ namespace arb { using cell_group_factory = std::function< - cell_group_ptr(const std::vector&, const recipe&, cell_label_range& cg_sources, cell_label_range& cg_targets)>; + cell_group_ptr(const std::vector&, const recipe&, cell_label_range& cg_sources, cell_label_range& cg_targets)>; ARB_ARBOR_API cell_group_factory cell_kind_implementation( - cell_kind, backend_kind, const execution_context&); + cell_kind, cell_gid_type, backend_kind, const execution_context&); inline bool cell_kind_supported( cell_kind c, backend_kind b, const execution_context& ctx) { - return static_cast(cell_kind_implementation(c, b, ctx)); + return static_cast(cell_kind_implementation(c, 0, b, ctx)); } } // namespace arb diff --git a/arbor/communication/dry_run_context.cpp b/arbor/communication/dry_run_context.cpp index fcfe4f9f56..1b9019eef5 100644 --- a/arbor/communication/dry_run_context.cpp +++ b/arbor/communication/dry_run_context.cpp @@ -70,6 +70,16 @@ struct dry_run_context_impl { return gathered_vector(std::move(gathered_gids), std::move(partition)); } + std::vector + gather_cg_cv_map(const std::vector& local_map) const { + return {}; + } + + std::vector + gather_trace(const std::vector& trace) const { + return {}; + } + std::vector> gather_gj_connections(const std::vector> & local_connections) const { auto local_size = local_connections.size(); diff --git a/arbor/communication/mpi_context.cpp b/arbor/communication/mpi_context.cpp index 22a2428343..5d67d9ec00 100644 --- a/arbor/communication/mpi_context.cpp +++ b/arbor/communication/mpi_context.cpp @@ -39,6 +39,16 @@ struct mpi_context_impl { return mpi::gather_all_with_partition(local_gids, comm_); } + std::vector + gather_cg_cv_map(const std::vector& cg_cv_map) const { + return mpi::gather_all(cg_cv_map, comm_); + } + + std::vector + gather_trace(const std::vector& trace) const { + return mpi::gather_all(trace, comm_); + } + std::vector> gather_gj_connections(const std::vector>& local_connections) const { return mpi::gather_all(local_connections, comm_); diff --git a/arbor/distributed_context.hpp b/arbor/distributed_context.hpp index 58ca11c318..762257de37 100644 --- a/arbor/distributed_context.hpp +++ b/arbor/distributed_context.hpp @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -66,6 +67,14 @@ class distributed_context { return impl_->gather_gids(local_gids); } + std::vector gather_cg_cv_map(const std::vector& cg_cv_map) const { + return impl_->gather_cg_cv_map(cg_cv_map); + } + + std::vector gather_trace(const std::vector& trace) const { + return impl_->gather_trace(trace); + } + gj_connection_vector gather_gj_connections(const gj_connection_vector& local_connections) const { return impl_->gather_gj_connections(local_connections); } @@ -106,6 +115,10 @@ class distributed_context { gather_spikes(const spike_vector& local_spikes) const = 0; virtual gathered_vector gather_gids(const gid_vector& local_gids) const = 0; + virtual std::vector + gather_cg_cv_map(const std::vector& cg_cv_map) const = 0; + virtual std::vector + gather_trace(const std::vector& trace) const = 0; virtual gj_connection_vector gather_gj_connections(const gj_connection_vector& local_connections) const = 0; virtual cell_label_range @@ -137,10 +150,21 @@ class distributed_context { gather_gids(const gid_vector& local_gids) const override { return wrapped.gather_gids(local_gids); } + std::vector + gather_cg_cv_map(const std::vector& cg_cv_map) const override { + //std::cout << wrapped.id() << " Gather CG CV Map\n"; + return wrapped.gather_cg_cv_map(cg_cv_map); + } + std::vector + gather_trace(const std::vector& trace) const override { + return wrapped.gather_trace(trace); + } std::vector> gather_gj_connections(const gj_connection_vector& local_connections) const override { + //std::cout << wrapped.id() << " Gather GJ Map\n"; return wrapped.gather_gj_connections(local_connections); } + //cell_label_range includes sizes, labels, ranges cell_label_range gather_cell_label_range(const cell_label_range& local_ranges) const override { return wrapped.gather_cell_label_range(local_ranges); @@ -191,6 +215,14 @@ struct local_context { {0u, static_cast(local_gids.size())} ); } + std::vector + gather_cg_cv_map(const std::vector& cg_cv_map) const { + return {}; + } + std::vector + gather_trace(const std::vector& cg_cv_map) const { + return {}; + } std::vector> gather_gj_connections(const std::vector>& local_connections) const { return local_connections; diff --git a/arbor/domain_decomposition.cpp b/arbor/domain_decomposition.cpp index 7a58d804d4..9e99d32e09 100644 --- a/arbor/domain_decomposition.cpp +++ b/arbor/domain_decomposition.cpp @@ -54,7 +54,8 @@ domain_decomposition::domain_decomposition( } for (const auto& gj: rec.gap_junctions_on(gid)) { if (!gid_set.count(gj.peer.gid)) { - throw invalid_gj_cell_group(gid, gj.peer.gid); + //throw invalid_gj_cell_group(gid, gj.peer.gid); + std::cerr << "Warning: Need to use Waveform Relaxation.\n"; } } } diff --git a/arbor/fvm_layout.cpp b/arbor/fvm_layout.cpp index 7ede9f7c8f..218735e84d 100644 --- a/arbor/fvm_layout.cpp +++ b/arbor/fvm_layout.cpp @@ -23,6 +23,8 @@ #include "util/transform.hpp" #include "util/unique.hpp" +#include + namespace arb { using util::assign; @@ -626,6 +628,7 @@ fvm_mechanism_data& append(fvm_mechanism_data& left, const fvm_mechanism_data& r append(L.multiplicity, R.multiplicity); append(L.norm_area, R.norm_area); append(L.local_weight, R.local_weight); + append(L.peer_cg, R.peer_cg); append_offset(L.target, target_offset, R.target); arb_assert(util::equal(L.param_values, R.param_values, @@ -655,6 +658,124 @@ fvm_mechanism_data& append(fvm_mechanism_data& left, const fvm_mechanism_data& r return left; } +// build cg_cv_map = { {gid, lid} -> {cg, cv} } +// gather function: +// 1) take cg_cv_map & split into 4 arrays (gids, lids, cgs, cvs) +// 2) gather separately +// 3) update index map function { {gid, cg, cv} ->index } with gathered gids, cgs, cvs +// 4) use index map as gj_cvs map equivalent + + +// 1) split cg_cv_map into 4 arrays instead of unordered_map +ARB_ARBOR_API std::vector> fvm_build_gap_junction_cv_arr( + const std::vector& cells, + const std::vector& gids, + unsigned cg, + const fvm_cv_discretization& D) +{ + std::vector gid, lids, cgs, cvs; + arb_assert(cells.size() == gids.size()); + std::unordered_map gj_cg_cvs; + for (auto cell_idx: util::make_span(0, cells.size())) { + for (const auto& mech : cells[cell_idx].junctions()) { + for (const auto& gj: mech.second) { + gid.push_back(gids[cell_idx]); + lids.push_back(gj.lid); + cgs.push_back(cg); + cvs.push_back(D.geometry.location_cv(cell_idx, gj.loc, cv_prefer::cv_nonempty)); + } + } + } + return {gid, lids, cgs, cvs}; +} + +//3) index map function +using cell_id = std::tuple; + +ARB_ARBOR_API std::map, int> fvm_cell_to_index_lowered( + const std::vector& cgs, + const std::vector& cvs +) +{ + std::map, int> cell_to_index; + + for (int i = 0; i element{cvs[i], cgs[i]}; + cell_to_index[element] = i; + } + + return cell_to_index; +} + +ARB_ARBOR_API std::map fvm_cell_to_index( + const std::vector& gids, + const std::vector& cgs, + const std::vector& cvs, + const std::vector& lids +) +{ + std::map cell_to_index; + + for (int i = 0; i fvm_index_to_cell( + std::map& cell_to_index +) +{ + std::map index_to_cell; + + for (const auto& [value, index]: cell_to_index) { + index_to_cell[index] = value; + } + + return index_to_cell; +} + +/* +ARB_ARBOR_API std::unordered_map fvm_index_to_cv_map( + const std::vector& gids, + const std::vector& lids, + const std::vector& cgs, + const std::vector& cvs, + const std::map& cell_to_index +) +{ + std::unordered_map gj_cvs_index; + for (int i = 0; i fvm_index_to_cv_map( + const std::vector& gids, + const std::vector& lids, + const std::vector& cgs, + const std::vector& cvs, + const std::map& cell_to_index +) +{ + std::unordered_map gj_cvs_index; + for (int i = 0; i fvm_build_gap_junction_cv_map( const std::vector& cells, const std::vector& gids, @@ -670,8 +791,10 @@ ARB_ARBOR_API std::unordered_map fvm_build_gap_ } } return gj_cvs; -} +}*/ +//make resolution map with gj_data and gids global -> maybe try to gather cell_label_range gj_data +/* ARB_ARBOR_API std::unordered_map> fvm_resolve_gj_connections( const std::vector& gids, const cell_label_range& gj_data, @@ -692,10 +815,44 @@ ARB_ARBOR_API std::unordered_map> f auto peer_cv = gj_cvs.at({conn.peer.gid, peer_idx}); local_conns.push_back({local_idx, local_cv, peer_cv, conn.weight}); + //std::cout <<"local_idx = " << local_idx << " local_cv = " << local_cv << " peer_cv = " << peer_cv << " conn.weight = " << conn.weight << std::endl; + } + // Sort local_conns by local_cv. + util::sort(local_conns); + gj_conns[gid] = std::move(local_conns); + } + return gj_conns; +}*/ + + +ARB_ARBOR_API std::unordered_map> fvm_resolve_gj_connections( + const std::vector& gids, + const cell_label_range& gj_data, + const std::unordered_map& gj_cvs, + const recipe& rec) +{ + // Construct and resolve all gj_connections. + std::unordered_map> gj_conns; + label_resolution_map resolution_map({gj_data, gids}); + auto gj_resolver = resolver(&resolution_map); + for (const auto& gid: gids) { + std::vector local_conns; + for (const auto& conn: rec.gap_junctions_on(gid)) { + auto local_idx = gj_resolver.resolve({gid, conn.local}); + auto peer_idx = gj_resolver.resolve(conn.peer); + + auto local_cv = gj_cvs.at({gid, local_idx}).gid; + auto peer_cv = gj_cvs.at({conn.peer.gid, peer_idx}).gid; + + auto peer_cg = gj_cvs.at({conn.peer.gid, peer_idx}).index; + + local_conns.push_back({local_idx, local_cv, peer_cv, conn.weight, peer_cg}); + //std::cout <<"local_idx = " << local_idx << " local_cv = " << local_cv << " peer_cv = " << peer_cv << " peer_cg = " << peer_cg << " conn.weight = " << conn.weight << std::endl; } // Sort local_conns by local_cv. util::sort(local_conns); gj_conns[gid] = std::move(local_conns); + } return gj_conns; } @@ -1129,10 +1286,13 @@ fvm_mechanism_data fvm_build_mechanism_data( config.cv.push_back(conn.local_cv); config.peer_cv.push_back(conn.peer_cv); config.local_weight.push_back(conn.weight); + config.peer_cg.push_back(conn.peer_cg); + //std::cout << "node cv = " << conn.local_cv <<"peer cv = " << conn.peer_cv <<"peer cg = " << config.peer_cg.back() << std::endl; for (unsigned i = 0; i < local_junction_desc.param_values.size(); ++i) { config.param_values[i].second.push_back(local_junction_desc.param_values[i]); } } + // Add non-empty fvm_mechanism_config to the fvm_mechanism_data for (auto [name, config]: junction_configs) { diff --git a/arbor/fvm_layout.hpp b/arbor/fvm_layout.hpp index b81debcd81..31024cad84 100644 --- a/arbor/fvm_layout.hpp +++ b/arbor/fvm_layout.hpp @@ -198,6 +198,9 @@ struct fvm_mechanism_config { // duplicates for point mechanisms. std::vector cv; + // Cell group for peer index + std::vector peer_cg; + // Coalesced synapse multiplier (point mechanisms only). std::vector multiplicity; @@ -262,6 +265,7 @@ struct fvm_stimulus_config { std::vector> envelope_amplitude; // [A/m²] }; +/* // Maps gj {gid, lid} locations on a cell to their CV indices. ARB_ARBOR_API std::unordered_map fvm_build_gap_junction_cv_map( const std::vector& cells, @@ -274,6 +278,57 @@ ARB_ARBOR_API std::unordered_map> f const cell_label_range& gj_data, const std::unordered_map& gj_cv, const recipe& rec); + */ + +// Maps gj {gid, lid} locations on a cell to their CV indices. +ARB_ARBOR_API std::unordered_map fvm_build_gap_junction_cv_map( + const std::vector& cells, + const std::vector& gids, + const fvm_cv_discretization& D); + +// Resolves gj_connections into {gid, lid} pairs, then to CV indices and a weight. +ARB_ARBOR_API std::unordered_map> fvm_resolve_gj_connections( + const std::vector& gids, + const cell_label_range& gj_data, + const std::unordered_map& gj_cv, + const recipe& rec); + +using cell_id = std::tuple; + +ARB_ARBOR_API std::map fvm_cell_to_index( + const std::vector& gids, + const std::vector& cgs, + const std::vector& cvs, + const std::vector& lids); + +ARB_ARBOR_API std::map fvm_index_to_cell( + std::map& cell_to_index); + +ARB_ARBOR_API std::unordered_map fvm_index_to_cv_map( + const std::vector& gids, + const std::vector& lids, + const std::vector& cgs, + const std::vector& cvs, + const std::map& cell_to_index); + +ARB_ARBOR_API std::map, int> fvm_cell_to_index_lowered( + const std::vector& cgs, + const std::vector& cvs +); +/* +ARB_ARBOR_API std::unordered_map fvm_index_to_cv_map( + const std::vector& gids, + const std::vector& lids, + const std::vector& cgs, + const std::vector& cvs, + const std::map& cell_to_index);*/ + +// 1) split cg cv map into 4 arrays +ARB_ARBOR_API std::vector> fvm_build_gap_junction_cv_arr( + const std::vector& cells, + const std::vector& gids, + unsigned cg, + const fvm_cv_discretization& D); struct fvm_mechanism_data { // Mechanism config, indexed by mechanism name. diff --git a/arbor/fvm_lowered_cell.hpp b/arbor/fvm_lowered_cell.hpp index a3bff57d86..17f5392f7f 100644 --- a/arbor/fvm_lowered_cell.hpp +++ b/arbor/fvm_lowered_cell.hpp @@ -236,6 +236,6 @@ struct fvm_lowered_cell { using fvm_lowered_cell_ptr = std::unique_ptr; -ARB_ARBOR_API fvm_lowered_cell_ptr make_fvm_lowered_cell(backend_kind p, const execution_context& ctx); +ARB_ARBOR_API fvm_lowered_cell_ptr make_fvm_lowered_cell(backend_kind p, const execution_context& ctx, cell_gid_type cg); } // namespace arb diff --git a/arbor/fvm_lowered_cell_impl.cpp b/arbor/fvm_lowered_cell_impl.cpp index 01b0ee0fd7..f1d54969f3 100644 --- a/arbor/fvm_lowered_cell_impl.cpp +++ b/arbor/fvm_lowered_cell_impl.cpp @@ -12,10 +12,10 @@ namespace arb { -fvm_lowered_cell_ptr make_fvm_lowered_cell(backend_kind p, const execution_context& ctx) { +fvm_lowered_cell_ptr make_fvm_lowered_cell(backend_kind p, const execution_context& ctx, cell_gid_type cg) { switch (p) { case backend_kind::multicore: - return fvm_lowered_cell_ptr(new fvm_lowered_cell_impl(ctx)); + return fvm_lowered_cell_ptr(new fvm_lowered_cell_impl(ctx, cg)); case backend_kind::gpu: #ifdef ARB_HAVE_GPU return fvm_lowered_cell_ptr(new fvm_lowered_cell_impl(ctx)); diff --git a/arbor/fvm_lowered_cell_impl.hpp b/arbor/fvm_lowered_cell_impl.hpp index fcd7c4ee0b..501a481f66 100644 --- a/arbor/fvm_lowered_cell_impl.hpp +++ b/arbor/fvm_lowered_cell_impl.hpp @@ -7,6 +7,7 @@ // implementation details may be tested in the unit tests. // It should otherwise only be used in `fvm_lowered_cell.cpp`. +#include #include #include #include @@ -17,6 +18,12 @@ #include #include + +#include +#include +#include +#include + #include #include #include @@ -37,6 +44,8 @@ #include "util/strprintf.hpp" #include "util/transform.hpp" +#include + namespace arb { template @@ -47,7 +56,14 @@ class fvm_lowered_cell_impl: public fvm_lowered_cell { using index_type = fvm_index_type; using size_type = fvm_size_type; - fvm_lowered_cell_impl(execution_context ctx): context_(ctx), threshold_watcher_(ctx) {}; + cell_gid_type cell_group; + std::map, int> cell_to_index; + std::map> index_to_cell; + std::map, int> cell_to_index_lowered; + std::vector cvs_local; + + + fvm_lowered_cell_impl(execution_context ctx, cell_gid_type cg): context_(ctx), threshold_watcher_(ctx) {cell_group = context_.distributed->id();}; void reset() override; @@ -76,7 +92,7 @@ class fvm_lowered_cell_impl: public fvm_lowered_cell { } private: - // Host or GPU-side back-end dependent storage. + // Host or GPU-side back-end dependent storage.x using array = typename backend::array; using shared_state = typename backend::shared_state; using sample_event_stream = typename backend::sample_event_stream; @@ -188,6 +204,7 @@ void fvm_lowered_cell_impl::reset() { threshold_watcher_.reset(state_->voltage); } + template fvm_integration_result fvm_lowered_cell_impl::integrate( value_type tfinal, @@ -198,7 +215,7 @@ fvm_integration_result fvm_lowered_cell_impl::integrate( set_gpu(); // Integration setup - PE(advance:integrate:setup); + PE(advance_integrate_setup); threshold_watcher_.clear_crossings(); auto n_samples = staged_samples.size(); @@ -218,116 +235,779 @@ fvm_integration_result fvm_lowered_cell_impl::integrate( // per-compartment dt probably not a win on GPU), possibly rumbling // complete fvm state into shared state object. - while (remaining_steps) { - // Update any required reversal potentials based on ionic concs. + // Reset state + multicore::iarray cv_to_intdom_r = state_->cv_to_intdom; + multicore::iarray cv_to_cell_r = state_->cv_to_cell; - for (auto& m: revpot_mechanisms_) { - m->update_current(); - } + array time_r = state_->time; + array time_to_r = state_->time_to; + array dt_intdom_r = state_->dt_intdom; + array dt_cv_r = state_->dt_cv; + array voltage_r = state_->voltage; + array current_density_r = state_->current_density; - // Deliver events and accumulate mechanism current contributions. - - PE(advance:integrate:events); - state_->deliverable_events.mark_until_after(state_->time); - PL(); - - PE(advance:integrate:current:zero); - state_->zero_currents(); - PL(); - for (auto& m: mechanisms_) { - auto state = state_->deliverable_events.marked_events(); - arb_deliverable_event_stream events; - events.n_streams = state.n; - events.begin = state.begin_offset; - events.end = state.end_offset; - events.events = (arb_deliverable_event_data*) state.ev_data; // FIXME(TH): This relies on bit-castability - m->deliver_events(events); - m->update_current(); + std::cout << "state_->current_density at reset = "; + for (auto i = 0; i< state_->current_density.size(); ++i) { + std::cout << state_->current_density[i] << " "; + } + std::cout << std::endl; + + array conductivity_r = state_->conductivity; + + array time_since_spike_r = state_->time_since_spike; + multicore::deliverable_event_stream deliverable_events_r = state_->deliverable_events; + multicore::iarray src_to_spike_r = state_->src_to_spike; + + multicore::istim_state stim_data_r = state_->stim_data; + std::unordered_map ion_data_r = state_->ion_data; + + std::unordered_map storage_r; + std::unordered_map> parameters_r; + std::unordered_map> state_vars_r; + + for (auto store: state_->storage) { + storage_r[store.first].data_ = store.second.data_; + storage_r[store.first].indices_ = store.second.indices_; + storage_r[store.first].constraints_ = store.second.constraints_;// + storage_r[store.first].ion_states_ = store.second.ion_states_; + for (auto i = 0; ideliverable_events.drop_marked_events(); + std::ofstream file0; + std::ofstream file1; + std::ofstream file2; + + //trace vectors for recording voltages: voltages for corresponding CVs in cvs_local in same order + // step = 0 step = 1 ... + //trace = [ [v[cvs_local[0]], v[cvs_local[1]], ...], [v[cvs_local[0]], v[cvs_local[1]], ...], ... ] + std::vector> trace, trace_clean, trace_prev; + std::vector trace_t, trace_t_clean; + std::vector trace_wr_it_clean; + + threshold_watcher threshold_watcher_reset = threshold_watcher_; + + // Save starting times + value_type tmin_reset = tmin_; + std::cout << "tmin_reset = " << tmin_reset << std::endl; + std::cout << "state_->time_bounds().first = " << state_->time_bounds().first << std::endl; + std::cout << "len(trace) = " << trace.size() << std::endl; + + value_type dt_max_reset = dt_max; + value_type tfinal_reset = tfinal; + + std::vector peer_ix_group; + std::vector node_ix_group; + std::unordered_map> peer_ix_map; + std::unordered_map peer_ix_reset_map; + std::unordered_map node_ix_reset_map; + + // Global, sorted list of all CVs that are part of a gap junction + auto cvs_global = context_.distributed->gather_cg_cv_map(cvs_local); + + int exp_nr; + std::cin >> exp_nr; + + std::cout << "cvs local: "; + for (int i = 0; i max_remaining_steps) { + std::cout << "ERROR TOO MANY STEPS" << std::endl; + std::cout << "step = " << step << " remaining steps = " << remaining_steps << std::endl; + } + + + // Update any required reversal potentials based on ionic concs. + for (auto& m: revpot_mechanisms_) { + m->update_current(); + } - // Update event list and integration step times. + // Deliver events and accumulate mechanism current contributions. - state_->update_time_to(dt_max, tfinal); - state_->deliverable_events.event_time_if_before(state_->time_to); - state_->set_dt(); - PL(); + PE(advance_integrate_events); + state_->deliverable_events.mark_until_after(state_->time); + PL(); - // Add stimulus current contributions. - // (Note: performed after dt, time_to calculation, in case we - // want to use mean current contributions as opposed to point - // sample.) + PE(advance_integrate_current_zero); + state_->zero_currents(); + PL(); + for (auto& m: mechanisms_) { + auto state = state_->deliverable_events.marked_events(); + arb_deliverable_event_stream events; + events.n_streams = state.n; + events.begin = state.begin_offset; + events.end = state.end_offset; + events.events = (arb_deliverable_event_data*) state.ev_data; // FIXME(TH): This relies on bit-castability + m->deliver_events(events); + + if (m->kind() == arb_mechanism_kind_gap_junction) { + + auto& ppack = m->ppack_; + auto& pidx = ppack.peer_index; + auto& nidx = ppack.node_index; + + auto m_id = m->mechanism_id(); + auto& peer_ix = peer_ix_map[m_id]; + + if (step > 0 && (state_->time_bounds().first - tmin_prev) < 0.1*dt_max) { + // std::cout << "double found in step = " << step << std::endl + // << " tmin_prev= " << tmin_prev + // << " tmin = " << state_->time_bounds().first << std::endl; + doubles_count += 1; + // std::cout << "doubles count = " << doubles_count << std::endl; + } - PE(advance:integrate:stimuli) - state_->add_stimulus_current(); - PL(); + if (wr_it == 0 && step == 0) { + + peer_ix.clear(); + peer_ix.resize(ppack.width); + peer_ix_reset_map[m_id] = ppack.peer_index; + node_ix_reset_map[m_id] = ppack.node_index; + + for (int i = 0; ippack_.peer_index[i]; + auto node = m->ppack_.node_index[i]; + if (cell_group == m->ppack_.peer_cg[i]) { + peer_ix_group.push_back(peer); + } + else { + peer_ix_group.push_back(node); + } + } + for (auto i=0; ippack_.width; ++i) { + std::tuple cell = {m->ppack_.peer_index[i], m->ppack_.peer_cg[i]}; + peer_ix[i] = cell_to_index_lowered.at({m->ppack_.peer_index[i], m->ppack_.peer_cg[i]}); + } + } - // Take samples at cell time if sample time in this step interval. + // in iteration=0 we do not have a trace yet + if (wr_it == 0) { + // Q: Didn't we want to have the identity here, ie i --> i, instead of the peer index? + // A: Yes, indeed peer_ix_group is either the local peer index or the identity, if not local + + // if (step == 0) { + // std::cout << std::endl; + // std::cout << "----------- before -----------" << std::endl; + // std::cout << "cell group = " << cell_group << std::endl; + // std::cout << "node index = "; + // for (auto i = 0; i< m->ppack_.width; ++i) { + // std::cout << m->ppack_.node_index[i] << " "; + // } + // std::cout << std::endl; + + // std::cout << "peer index = "; + // for (auto i = 0; i< m->ppack_.width; ++i) { + // std::cout << m->ppack_.peer_index[i] << " "; + // } + // std::cout << std::endl; + + // std::cout << "peer cg = "; + // for (auto i = 0; i< m->ppack_.width; ++i) { + // std::cout << m->ppack_.peer_cg[i] << " "; + // } + // std::cout << std::endl; + // std::cout << std::endl; + + // } + + pidx = peer_ix_group.data(); + + // if (step == 0) { + // std::cout << std::endl; + // std::cout << "----------- it 0 -----------" << std::endl; + // std::cout << "cell group = " << cell_group << std::endl; + // std::cout << "node index = "; + // for (auto i = 0; i< m->ppack_.width; ++i) { + // std::cout << m->ppack_.node_index[i] << " "; + // } + // std::cout << std::endl; + + // std::cout << "peer index = "; + // for (auto i = 0; i< m->ppack_.width; ++i) { + // std::cout << m->ppack_.peer_index[i] << " "; + // } + // std::cout << std::endl; + + // std::cout << "peer cg = "; + // for (auto i = 0; i< m->ppack_.width; ++i) { + // std::cout << m->ppack_.peer_cg[i] << " "; + // } + // std::cout << std::endl; + // std::cout << std::endl; + // } + } + if (wr_it > 0) { + // Here we expect peer_ix to be the index into the trace structure, regardless of local + // or not. However, the selection of the trace must take locality into account. + // Q: How to select? + + auto trace_prev_step = trace_prev[step - doubles_count]; + if (cell_group == 0 && err_count == 0) { + // std::cout << "state_->voltage at it " << wr_it << " step " << step << " = "; + // for(int i = 0; ivoltage.size(); ++i){ + // std::cout << state_->voltage[i]<< " "; + // } + // std::cout << std::endl; + // std::cout << " index to cell : " << std::endl; + // i -> {gid, cg, cv, lid} + // for (auto item : index_to_cell) { + // std::cout << item.first << " -> {" << std::get<0>(item.second) << ", " << std::get<1>(item.second) << ", "<< std::get<2>(item.second) << ", "<< std::get<3>(item.second) << "}"<< std::endl; + // } + } + + if (step == 0) { + std::cout << "peer ix = "; + for (auto i = 0; i< peer_ix.size(); ++i) { + std::cout << peer_ix[i] << " "; + } + std::cout << std::endl; + std::cout << std::endl; + + for (auto item : index_to_cell) { + std::cout << item.first << " -> {" << std::get<0>(item.second) << ", " << std::get<1>(item.second) << ", "<< std::get<2>(item.second) << ", "<< std::get<3>(item.second) << "}"<< std::endl; + } + } + + for (auto i = 0; ippack_.width; ++i) { + if ( m->ppack_.peer_cg[i] == cell_group ) { + //std::cout << "peer ix[" << i << "] = " << peer_ix[i] << " -> " << std::get<2>(index_to_cell[m->ppack_.peer_index[i]]) << std::endl; + //trace_prev[step][peer_ix[i]] = state_->voltage[std::get<2>(index_to_cell[m->ppack_.peer_index[i]])]; + trace_prev_step[peer_ix[i]] = state_->voltage[std::get<2>(index_to_cell[peer_ix[i]])]; + //trace_prev_step[peer_ix[i]] = state_->voltage[std::get<2>(index_to_cell[m->ppack_.peer_index[i]])]; + + if (cell_group == 0 && step == 0) { + std::cout << "peer_ix = " << peer_ix[i] << " volt index = " << std::get<2>(index_to_cell[peer_ix[i]]) << std::endl; + } + + } + } + + m->ppack_.peer_index = peer_ix.data(); + m->ppack_.vec_v_peer = trace_prev_step.data(); + + if (step < 3) { + std::cout << std::endl; + std::cout << "----------- it "<< wr_it << " -----------" << std::endl; + std::cout << "step = " << step << std::endl; + std::cout << "cell group = " << cell_group << std::endl; + std::cout << "wr_it = " << wr_it << std::endl; + // std::cout << "node index = "; + // for (auto i = 0; i< m->ppack_.width; ++i) { + // std::cout << m->ppack_.node_index[i] << " "; + // } + // std::cout << std::endl; + + // std::cout << "peer index = "; + // for (auto i = 0; i< m->ppack_.width; ++i) { + // std::cout << m->ppack_.peer_index[i] << " "; + // } + // std::cout << std::endl; + + // std::cout << "state_->voltage = "; + // for (auto i = 0; i< state_->voltage.size(); ++i) { + // std::cout << state_->voltage[i] << " "; + // } + // std::cout << std::endl; + + // std::cout << "state_->current_density = "; + // for (auto i = 0; i< state_->current_density.size(); ++i) { + // std::cout << state_->current_density[i] << " "; + // } + // std::cout << std::endl; + + // std::cout << "vec_v_peer[peer_index] = "; + // for (auto i = 0; ippack_.width; ++i) { + // std::cout << m->ppack_.vec_v_peer[m->ppack_.peer_index[i]] << " "; + // } + // std::cout << std::endl; + + // std::cout << "vec_v[node_index] = "; + // for (auto i = 0; ippack_.width; ++i) { + // std::cout << m->ppack_.vec_v[m->ppack_.node_index[i]] << " "; + // } + // std::cout << std::endl; + + // std::cout << "vec_g[i] = "; + // for (auto i = 0; ippack_.width; ++i) { + // std::cout << m->ppack_.parameters[0][i] << " "; + // } + // std::cout << std::endl; + // std::cout << "DELTA V = "; + // for (int i = 0; ippack_.width; ++i) { + // auto dv = m->ppack_.vec_v[m->ppack_.node_index[i]] - m->ppack_.vec_v_peer[m->ppack_.peer_index[i]]; + // std::cout << dv << " "; + // } + // std::cout << std::endl; + + // std::cout << std::endl; + // } + + // if (err_count == 0 && (cell_group == 0 && wr_it > 0)) { + // std::cout << "wr_it = " << wr_it << " step = " << step << " voltage = ["; + // for (int i = 0; i < state_->voltage.size(); ++i) { + // std::cout << state_->voltage[i] << " "; + // } + // std::cout << std::endl; + // } + + + // if (err_count == 0 && (cell_group == 0 && wr_it > 0)) { + // std::cout << std::endl; + // std::cout << " ------- " << std::endl; + // std::cout << "wr_it = " << wr_it << " step = " << step << " trace_prev_step before mod = ["; + // for (int i = 0; i < trace_prev_step.size(); ++i) { + // std::cout << trace_prev[step - doubles_count][i] << " "; + // } + // std::cout << std::endl; + // std::cout << std::endl; + // std::cout << "wr_it = " << wr_it << " step = " << step << " vec_v_peer = ["; + // for (int i = 0; i < trace_prev_step.size(); ++i) { + // std::cout << m->ppack_.vec_v_peer[i] << " "; + // } + // std::cout << std::endl; + // std::cout << " ------- " << std::endl; + // std::cout << std::endl; + + // } + + + // if (wr_it == 1 && (step == 60 && cell_group == 0)) { + // std::cout << "cell group = " << cell_group << std::endl; + // std::cout << "peer index = "; + // for (int i = 0; ippack_.width; ++i){ + // std::cout << m->ppack_.peer_index[i] << ", "; + // } + // std::cout << std::endl; + + // std::cout << "vec_v_peer = "; + // for (int i = 0; ippack_.vec_v_peer[i] << ", "; + // } + // std::cout << std::endl; + } - PE(advance:integrate:samples); - sample_events_.mark_until(state_->time_to); - state_->take_samples(sample_events_.marked_events(), sample_time_, sample_value_); - sample_events_.drop_marked_events(); - PL(); + } + } - // Integrate voltage by matrix solve. + + + + m->update_current(); + + // if (cell_group == 0 && wr_it == 1 && (m->kind() == arb_mechanism_kind_density && density_counter == 0)) { + // std::cout << "density step " << step << " state_->current_density = "; + // for (auto i = 0; i< state_->current_density.size(); ++i) { + // std::cout << state_->current_density[i] << " "; + // } + // std::cout << std::endl; + // } + // if (cell_group == 0 && wr_it == 1 && (m->kind() == arb_mechanism_kind_gap_junction && junction_counter == 0)) { + // std::cout << "junction step " << step << " state_->current_density = "; + // for (auto i = 0; i< state_->current_density.size(); ++i) { + // std::cout << state_->current_density[i] << " "; + // } + // std::cout << std::endl; + // } + + // if ( wr_it > 0 && (step < 3 && cell_group == 0 )){ + // std::cout << "state_->voltage after current update of step" << step << " = "; + // for (auto i = 0; i< state_->voltage.size(); ++i) { + // std::cout << state_->voltage[i] << " "; + // } + // std::cout << std::endl; + + // std::cout << "state_->current_density after current update of step" << step << " = "; + // for (auto i = 0; i< state_->current_density.size(); ++i) { + // std::cout << state_->current_density[i] << " "; + // } + // std::cout << std::endl; + + // std::cout << "weight = "; + // for (auto i = 0; ippack_.width; ++i) { + // std::cout << m->ppack_.weight[i] << " "; + // } + // std::cout << std::endl; + // } + + //update traces + if (m->kind() == arb_mechanism_kind_gap_junction) { + + //if(max_remaining_steps-step > 0){ + std::vector v_step; + + // Only record trace if time step of dt has been made + + for (auto ix = 0; ixvoltage[cv]); + if (err_count == 0 && (cell_group == 0 && (state_->voltage[cv] > 100 or isnan(state_->voltage[cv])))) { + std::cout << std::endl; + std::cout << " ------- " << std::endl; + std::cout << "voltage too large at it " << wr_it << " step " << step << std::endl; + std::cout << "peer voltages: "; + for (auto i = 0; i < m->ppack_.width; ++i) { + auto n_temp = m->ppack_.node_index[i]; + auto p_temp = m->ppack_.peer_index[i]; + auto peer_v_temp = m->ppack_.vec_v_peer[p_temp]; + std::cout << "n = " << n_temp << " p = " << p_temp << " vec_v_peer[p_temp] = " << peer_v_temp << std::endl; + } + std::cout << " ------- " << std::endl; + std::cout << std::endl; + + err_count +=1; + } + } + + trace.push_back(v_step); + trace_t.push_back(state_->time_bounds().first); + + // if (cell_group == 0 && (step > 0 && ((state_->time_bounds().first - tmin_prev) < 0.1*dt_max))){ + // std::cout << " step = " << step + // << " tmin_prev= " << tmin_prev + // << " tmin = " << state_->time_bounds().first + // << " tfinal = " << tfinal + // << " remaining steps = " << remaining_steps + // << " tmin real = " << state_->time_bounds().first + dt_max + // << " new remaining steps = " << dt_steps(state_->time_bounds().first+dt_max, tfinal, dt_max) + // << " dtmax = "<< dt_max + // << " diff = " << state_->time_bounds().first - tmin_prev << std::endl; + // } + //} + + //todo eliminate mechanism id + auto gj = m->mechanism_id(); + m->ppack_.peer_index = peer_ix_reset_map[gj]; + m->ppack_.node_index = node_ix_reset_map[gj]; + } + } + tmin_prev = state_->time_bounds().first; - PE(advance:integrate:matrix:build); - matrix_.assemble(state_->dt_intdom, state_->voltage, state_->current_density, state_->conductivity); - PL(); - PE(advance:integrate:matrix:solve); - matrix_.solve(state_->voltage); - PL(); + PE(advance_integrate_events); + state_->deliverable_events.drop_marked_events(); - // Integrate mechanism state. + // Update event list and integration step times. - for (auto& m: mechanisms_) { - m->update_state(); - } + state_->update_time_to(dt_max, tfinal); + state_->deliverable_events.event_time_if_before(state_->time_to); + state_->set_dt(); + PL(); - // Update ion concentrations. + // Add stimulus current contributions. + // (Note: performed after dt, time_to calculation, in case we + // want to use mean current contributions as opposed to point + // sample.) - PE(advance:integrate:ionupdate); - update_ion_state(); - PL(); + PE(advance_integrate_stimuli) + state_->add_stimulus_current(); + PL(); + + // Take samples at cell time if sample time in this step interval. - // Update time and test for spike threshold crossings. + PE(advance_integrate_samples); + sample_events_.mark_until(state_->time_to); + state_->take_samples(sample_events_.marked_events(), sample_time_, sample_value_); + sample_events_.drop_marked_events(); + PL(); + + // Integrate voltage by matrix solve. + + if (step < 3 && cell_group == 0) { + std::cout << "dt_intdom = "; + for (auto i = 0; idt_intdom.size(); ++i) { + std::cout << state_->dt_intdom[i] << " "; + } + std::cout << std::endl; + + std::cout << "voltage = "; + for (auto i = 0; ivoltage.size(); ++i) { + std::cout << state_->voltage[i] << " "; + } + std::cout << std::endl; - PE(advance:integrate:threshold); - threshold_watcher_.test(&state_->time_since_spike); - PL(); + std::cout << "current_density = "; + for (auto i = 0; icurrent_density.size(); ++i) { + std::cout << state_->current_density[i] << " "; + } + std::cout << std::endl; + + std::cout << "conductivity = "; + for (auto i = 0; iconductivity.size(); ++i) { + std::cout << state_->conductivity[i] << " "; + } + std::cout << std::endl; + + + + } + + PE(advance_integrate_matrix_build); + matrix_.assemble(state_->dt_intdom, state_->voltage, state_->current_density, state_->conductivity); + PL(); + PE(advance_integrate_matrix_solve); + matrix_.solve(state_->voltage); + PL(); + + if (step <3 && cell_group == 0 ){ + std::cout << std::endl; + std::cout << "state_->voltage after matrix solve of step" << step << " = "; + for (auto i = 0; i< state_->voltage.size(); ++i) { + std::cout << state_->voltage[i] << " "; + } + std::cout << std::endl; + } + + // Integrate mechanism state. - PE(advance:integrate:post) - if (post_events_) { for (auto& m: mechanisms_) { - m->post_event(); + m->update_state(); } - } - PL(); - std::swap(state_->time_to, state_->time); - state_->time_ptr = state_->time.data(); + // Update ion concentrations. + + PE(advance_integrate_ionupdate); + update_ion_state(); + PL(); - // Check for non-physical solutions: + // Update time and test for spike threshold crossings. - if (check_voltage_mV_>0) { - PE(advance:integrate:physicalcheck); - assert_voltage_bounded(check_voltage_mV_); + PE(advance_integrate_threshold); + threshold_watcher_.test(&state_->time_since_spike); PL(); + + PE(advance_integrate_post) + if (post_events_) { + for (auto& m: mechanisms_) { + m->post_event(); + } + } + PL(); + + std::swap(state_->time_to, state_->time); + state_->time_ptr = state_->time.data(); + + // Check for non-physical solutions: + + if (check_voltage_mV_>0) { + PE(advance_integrate_physicalcheck); + assert_voltage_bounded(check_voltage_mV_); + PL(); + } + + // Check for end of integration. + PE(advance_integrate_stepsupdate); + if (!--remaining_steps) { + auto tmin_old = tmin_; + tmin_ = state_->time_bounds().first; + remaining_steps = dt_steps(tmin_, tfinal, dt_max); + // if (cell_group == 0) { + // std::cout << " id=" << context_.distributed->id() + // << " tmin_=" << tmin_ + // << " tmin_old=" << tmin_old + // << " remaining steps=" << remaining_steps + // << " step=" << step + // << " tfinal=" << tfinal + // << " dt_max=" << dt_max << std::endl; + // } + } + if (cell_group == 0 && remaining_steps == 0) { + // std::cout << " tmin_ = " << tmin_ << std::endl; + // std::cout << " tmin_prev = " << tmin_prev << std::endl; + // std::cout << " state_->time_bounds() = " << state_->time_bounds().first << std::endl; + if (tmin_prev != tfinal) { + std::cout << "tmin != tfinal -> need another step" << std::endl; + remaining_steps = 1; + } + } + PL(); + step++; + } // end of integration + std::cout << "group " << cell_group << " end of it " << wr_it << std::endl; + + //Gather traces from all groups for next iteration + trace_prev = {}; + auto count_clean = 1; + auto count_total = 1; + std::cout << " len trace = " << trace.size() << std::endl; + std::cout << " len trace t = " << trace_t.size() << std::endl; + + + trace_clean.push_back(trace[0]); + trace_t_clean.push_back(trace_t[0]); + trace_wr_it_clean.push_back(wr_it); + + std::cout << "wr it = " << wr_it << std::endl; + + for (int t=1; t 0.1*dt_max){ + trace_clean.push_back(trace[t]); + trace_t_clean.push_back(trace_t[t]); + trace_wr_it_clean.push_back(wr_it); + count_clean += 1; + } + count_total +=1; + } + if (cell_group == 0){ + std::cout << "group = " << cell_group << " trace.size() = " << trace.size() << std::endl; + std::cout << "group = " << cell_group << " trace_clean.size() = " << trace_clean.size() << std::endl; + std::cout << "group = " << cell_group << " trace_t.size() = " << trace_t.size() << std::endl; + std::cout << "group = " << cell_group << " trace_t_clean.size() = " << trace_t_clean.size() << std::endl; + } + // if (cell_group == 0) { + // std::cout << "trace_t_clean = " << trace_t_clean[0] << ", " << trace_t_clean[1] << ",..." << std::endl; + // } + + for (int i = 1; itime_bounds().first; - remaining_steps = dt_steps(tmin_, tfinal, dt_max); + for (int st = 0; stdt_intdom.begin()); + std::copy(dt_cv_r.begin(), dt_cv_r.end(), state_->dt_cv.begin()); + + std::copy(time_r.begin(), time_r.end(), state_->time.begin()); + std::copy(time_to_r.begin(), time_to_r.end(), state_->time_to.begin()); + state_->dt_intdom = dt_intdom_r; + state_->dt_cv = dt_cv_r; + state_->current_density = current_density_r; + state_->conductivity = conductivity_r; + + std::copy(voltage_r.begin(), voltage_r.end(), state_->voltage.begin()); + + std::copy(cv_to_intdom_r.begin(), cv_to_intdom_r.end(), state_->cv_to_intdom.begin()); + std::copy(cv_to_cell_r.begin(), cv_to_cell_r.end(), state_->cv_to_cell.begin()); + + std::copy(current_density_r.begin(), current_density_r.end(), state_->current_density.begin()); + std::copy(conductivity_r.begin(), conductivity_r.end(), state_->conductivity.begin()); + + std::copy(time_since_spike_r.begin(), time_since_spike_r.end(), state_->time_since_spike.begin()); + std::copy(src_to_spike_r.begin(), src_to_spike_r.end(), state_->src_to_spike.begin()); + + //state_->stim_data = stim_data_r; + //state_->ion_data = ion_data_r;// + state_->deliverable_events = deliverable_events_r; + + for (auto store: state_->storage) { + std::copy(storage_r[store.first].data_.begin(), storage_r[store.first].data_.end(), store.second.data_.begin()); + std::copy(storage_r[store.first].indices_.begin(), storage_r[store.first].indices_.end(), store.second.indices_.begin()); + //store.second.constraints_ = storage_r[store.first].constraints_;// + std::copy(storage_r[store.first].ion_states_.begin(), storage_r[store.first].ion_states_.end(), store.second.ion_states_.begin()); + for (auto i = 0; i::initialize( sample_events_ = sample_event_stream(nintdom); // Discretize and build gap junction info. - - auto gj_cvs = fvm_build_gap_junction_cv_map(cells, gids, D); - auto gj_conns = fvm_resolve_gj_connections(gids, fvm_info.gap_junction_data, gj_cvs, rec); + // 1) split cg cv map into 4 arrays instead of one unordered map + std::vector> cv_cg_arr = fvm_build_gap_junction_cv_arr(cells, gids, cell_group, D); + + // 2) gather separately + auto gids_gathered = context_.distributed->gather_cg_cv_map(cv_cg_arr[0]); + auto lids_gathered = context_.distributed->gather_cg_cv_map(cv_cg_arr[1]); + auto cgs_gathered = context_.distributed->gather_cg_cv_map(cv_cg_arr[2]); + auto cvs_gathered = context_.distributed->gather_cg_cv_map(cv_cg_arr[3]); + + cvs_local = cv_cg_arr[3]; + + //3) cell to index map {gid, cg, cv} -> ix and ix -> {gid, cg, cv} + cell_to_index = fvm_cell_to_index(gids_gathered, cgs_gathered, cvs_gathered, lids_gathered); + index_to_cell = fvm_index_to_cell(cell_to_index); + cell_to_index_lowered = fvm_cell_to_index_lowered(cgs_gathered, cvs_gathered); + + //equivalent to previous gj_cvs = { {gid, lid} : cv } but with a global index + std::unordered_map gj_cvs_index = fvm_index_to_cv_map(gids_gathered, lids_gathered, cgs_gathered, cvs_gathered, cell_to_index); + //std::unordered_map gj_cvs_index = fvm_index_to_cv_map(gids_gathered, lids_gathered, cgs_gathered, cvs_gathered, cell_to_index); + + //resolve gap junctions with global index map + //resolution map = combination of cell_label_range and gids -> gather gids and cell_label_ranges before feeding to fvm_resolve_gap_junction_connection + //resolution map = cell_labels_and_gids + + //gather gids + std::vector resolution_gids_gathered = context_.distributed->gather_gids(gids).values(); + + //gather cell_label_ranges + cell_label_range gj_data_gathered = context_.distributed->gather_cell_label_range(fvm_info.gap_junction_data); + + //create resolution map with gathered label_ranges and gids + std::unordered_map> gj_conns = fvm_resolve_gj_connections(resolution_gids_gathered, gj_data_gathered, gj_cvs_index, rec); // Discretize mechanism data. @@ -531,6 +1240,7 @@ fvm_initialization_data fvm_lowered_cell_impl::initialize( layout.multiplicity = config.multiplicity; layout.peer_cv = config.peer_cv; layout.weight.resize(layout.cv.size()); + layout.peer_cg = config.peer_cg; std::vector multiplicity_divs; auto multiplicity_part = util::make_partition(multiplicity_divs, layout.multiplicity); @@ -1116,4 +1826,4 @@ void resolve_probe(const cable_probe_ion_ext_concentration_cell& p, probe_resolu resolve_ion_conc_common(R.M.ions.at(p.ion).cv, R.state->ion_data.at(p.ion).Xo_.data(), R); } -} // namespace arb +} // namespace arb \ No newline at end of file diff --git a/arbor/include/arbor/fvm_types.hpp b/arbor/include/arbor/fvm_types.hpp index 1b8849656b..c3f983c5b1 100644 --- a/arbor/include/arbor/fvm_types.hpp +++ b/arbor/include/arbor/fvm_types.hpp @@ -14,7 +14,8 @@ struct fvm_gap_junction { fvm_size_type local_cv; // CV index of the local gap junction site. fvm_size_type peer_cv; // CV index of the peer gap junction site. fvm_value_type weight; // unit-less local weight of the connection. + fvm_size_type peer_cg; // Cell group of peer idx }; -ARB_DEFINE_LEXICOGRAPHIC_ORDERING(fvm_gap_junction, (a.local_cv, a.peer_cv, a.local_idx, a.weight), (b.local_cv, b.peer_cv, b.local_idx, b.weight)) +ARB_DEFINE_LEXICOGRAPHIC_ORDERING(fvm_gap_junction, (a.local_cv, a.peer_cv, a.local_idx, a.weight, a.peer_cg), (b.local_cv, b.peer_cv, b.local_idx, b.weight, b.peer_cg)) } // namespace arb diff --git a/arbor/include/arbor/mechanism.hpp b/arbor/include/arbor/mechanism.hpp index d5cc221dcc..f0f92ca972 100644 --- a/arbor/include/arbor/mechanism.hpp +++ b/arbor/include/arbor/mechanism.hpp @@ -96,6 +96,9 @@ struct mechanism_layout { // Maps in-instance index to peer CV index (only for gap-junctions). std::vector peer_cv; + // Cell group of peer index + std::vector peer_cg; + // Maps in-instance index to compartment contribution. std::vector weight; diff --git a/arbor/include/arbor/mechanism_abi.h b/arbor/include/arbor/mechanism_abi.h index 69e11c7c47..88aefe1e9b 100644 --- a/arbor/include/arbor/mechanism_abi.h +++ b/arbor/include/arbor/mechanism_abi.h @@ -86,6 +86,7 @@ typedef struct arb_mechanism_ppack { const arb_value_type* vec_t; arb_value_type* vec_dt; arb_value_type* vec_v; + arb_value_type* vec_v_peer; // Other side of gap junction, default to vec_v. arb_value_type* vec_i; arb_value_type* vec_g; arb_value_type* temperature_degC; @@ -93,6 +94,7 @@ typedef struct arb_mechanism_ppack { arb_value_type* time_since_spike; arb_index_type* node_index; arb_index_type* peer_index; + arb_index_type* peer_cg; // Cell group for peer index arb_index_type* multiplicity; arb_value_type* weight; arb_size_type mechanism_id; diff --git a/arbor/mc_cell_group.cpp b/arbor/mc_cell_group.cpp index 8d140e6edc..90a41989cc 100644 --- a/arbor/mc_cell_group.cpp +++ b/arbor/mc_cell_group.cpp @@ -34,9 +34,11 @@ mc_cell_group::mc_cell_group(const std::vector& gids, const recipe& rec, cell_label_range& cg_sources, cell_label_range& cg_targets, + cell_gid_type cg, fvm_lowered_cell_ptr lowered): gids_(gids), lowered_(std::move(lowered)) { + //std::cout << "MC Cell Group in da House\n"; // Default to no binning of events set_binning_policy(binning_kind::none, 0); diff --git a/arbor/mc_cell_group.hpp b/arbor/mc_cell_group.hpp index 3640c8cf24..d57fc2dc27 100644 --- a/arbor/mc_cell_group.hpp +++ b/arbor/mc_cell_group.hpp @@ -32,6 +32,7 @@ class ARB_ARBOR_API mc_cell_group: public cell_group { const recipe& rec, cell_label_range& cg_sources, cell_label_range& cg_targets, + cell_gid_type cg, fvm_lowered_cell_ptr lowered); cell_kind get_cell_kind() const override { diff --git a/arbor/simulation.cpp b/arbor/simulation.cpp index 940de4ba2b..d80ef17882 100644 --- a/arbor/simulation.cpp +++ b/arbor/simulation.cpp @@ -192,26 +192,33 @@ simulation_state::simulation_state( cell_groups_.resize(decomp.num_groups()); std::vector cg_sources(cell_groups_.size()); std::vector cg_targets(cell_groups_.size()); - foreach_group_index( - [&](cell_group_ptr& group, int i) { - const auto& group_info = decomp.group(i); + + //std::cout << "Rank=" << ctx.distributed->id() << " will build " << cell_groups_.size() << " groups\n"; + + for (int ix = 0; ix < cell_groups_.size(); ++ix) { + auto fn = [&](cell_group_ptr& group, int cg) { + const auto& group_info = decomp.group(cg); cell_label_range sources, targets; - auto factory = cell_kind_implementation(group_info.kind, group_info.backend, ctx); + + //std::cout << "Rank=" << ctx.distributed->id() << " seizing the means of production.\n"; + auto factory = cell_kind_implementation(group_info.kind, cg, group_info.backend, ctx); group = factory(group_info.gids, rec, sources, targets); - cg_sources[i] = cell_labels_and_gids(std::move(sources), group_info.gids); - cg_targets[i] = cell_labels_and_gids(std::move(targets), group_info.gids); - }); - + cg_sources[cg] = cell_labels_and_gids(std::move(sources), group_info.gids); + cg_targets[cg] = cell_labels_and_gids(std::move(targets), group_info.gids); + }; + fn(cell_groups_[ix], ix); + } cell_labels_and_gids local_sources, local_targets; for(const auto& i: util::make_span(cell_groups_.size())) { local_sources.append(cg_sources.at(i)); local_targets.append(cg_targets.at(i)); } auto global_sources = ctx.distributed->gather_cell_labels_and_gids(local_sources); + auto global_targets = ctx.distributed->gather_cell_labels_and_gids(local_targets); auto source_resolution_map = label_resolution_map(std::move(global_sources)); - auto target_resolution_map = label_resolution_map(std::move(local_targets)); + auto target_resolution_map = label_resolution_map(std::move(global_targets)); communicator_ = arb::communicator(rec, decomp, source_resolution_map, target_resolution_map, ctx); @@ -514,6 +521,7 @@ simulation::simulation( const domain_decomposition& decomp, const context& ctx) { + //std::cout << "Rank=" << ctx->distributed->id() << " setting up state.\n"; impl_.reset(new simulation_state(rec, decomp, *ctx)); } diff --git a/modcc/printer/cprinter.cpp b/modcc/printer/cprinter.cpp index 2758e43cfc..ea8055fc4b 100644 --- a/modcc/printer/cprinter.cpp +++ b/modcc/printer/cprinter.cpp @@ -248,6 +248,8 @@ ARB_LIBMODCC_API std::string emit_cpp_source(const Module& module_, const printe "[[maybe_unused]] auto* {0}vec_t = pp->vec_t;\\\n" "[[maybe_unused]] auto* {0}vec_dt = pp->vec_dt;\\\n" "[[maybe_unused]] auto* {0}vec_v = pp->vec_v;\\\n" + "[[maybe_unused]] auto* {0}vec_v_peer = pp->vec_v_peer;\\\n" + "[[maybe_unused]] auto* {0}peer_cg = pp->peer_cg;\\\n" "[[maybe_unused]] auto* {0}vec_i = pp->vec_i;\\\n" "[[maybe_unused]] auto* {0}vec_g = pp->vec_g;\\\n" "[[maybe_unused]] auto* {0}temperature_degC = pp->temperature_degC;\\\n" diff --git a/modcc/printer/printerutil.cpp b/modcc/printer/printerutil.cpp index 72f208e998..9283103a94 100644 --- a/modcc/printer/printerutil.cpp +++ b/modcc/printer/printerutil.cpp @@ -155,7 +155,7 @@ ARB_LIBMODCC_API indexed_variable_info decode_indexed_variable(IndexedVariable* v.readonly = true; break; case sourceKind::peer_voltage: - v.data_var="vec_v"; + v.data_var="vec_v_peer"; v.other_index_var = "peer_index"; v.node_index_var = ""; v.index_var_kind = index_kind::other; diff --git a/python/example/gap_junctions.py b/python/example/gap_junctions.py index 8985b3f330..20f50c4c1f 100644 --- a/python/example/gap_junctions.py +++ b/python/example/gap_junctions.py @@ -3,6 +3,7 @@ import arbor import pandas, seaborn import matplotlib.pyplot as plt +import mpi4py.MPI as mpi # Construct chains of cells linked with gap junctions, # Chains are connected by synapses. @@ -22,11 +23,11 @@ def make_cable_cell(gid): # Build a segment tree tree = arbor.segment_tree() - # Soma with radius 5 μm and length 2 * radius = 10 μm, (tag = 1) - s = tree.append(arbor.mnpos, arbor.mpoint(-10, 0, 0, 5), arbor.mpoint(0, 0, 0, 5), tag=1) + # Soma with radius 5 μm and length 2 * radius = 10 m, (tag = 1) + s = tree.append(arbor.mnpos, arbor.mpoint(-10, 0, 0, 5), arbor.mpoint(10, 0, 0, 5), tag=1) # Single dendrite with radius 2 μm and length 40 μm, (tag = 2) - b = tree.append(s, arbor.mpoint(0, 0, 0, 2), arbor.mpoint(40, 0, 0, 2), tag=2) + b = tree.append(s, arbor.mpoint(0, 0, 0, 2), arbor.mpoint(2000, 0, 0, 2), tag=2) # Label dictionary for cell components labels = arbor.label_dict() @@ -37,7 +38,8 @@ def make_cable_cell(gid): labels['synapse_site'] = '(location 0 0.6)' # Gap junction site at connection point of soma and dendrite - labels['gj_site'] = '(location 0 0.2)' + labels['gj_site_0'] = '(location 0 0.1)' + labels['gj_site_1'] = '(location 0 0.99)' # Label root of the tree labels['root'] = '(root)' @@ -47,9 +49,16 @@ def make_cable_cell(gid): decor.paint('"soma"', arbor.density("hh")) decor.paint('"dend"', arbor.density("pas")) + #split into multiple cvs + policy = arbor.cv_policy_explicit('(location 0 0.35)') + policy = arbor.cv_policy_single() + #decor.discretization(policy) + #decor.discretization("(max-extent 9)") + # Attach one synapse and gap junction each on their labeled sites decor.place('"synapse_site"', arbor.synapse('expsyn'), 'syn') - decor.place('"gj_site"', arbor.junction('gj'), 'gj') + decor.place('"gj_site_0"', arbor.junction('gj'), 'gj_0') + decor.place('"gj_site_1"', arbor.junction('gj'), 'gj_1') # Attach spike detector to cell root decor.place('"root"', arbor.spike_detector(-10), 'detector') @@ -78,13 +87,14 @@ def cell_kind(self, gid): # Create synapse connection between last cell of one chain and first cell of following chain def connections_on(self, gid): - if (gid == 0) or (gid % self.ncells_per_chain > 0): - return [] - else: - src = gid-1 - w = 0.05 - d = 10 + #return [] + if (gid == 0): + src = gid+1 + w = 0.0 # 0.01 μS on expsyn + d = 50 # ms delay return [arbor.connection((src,'detector'), 'syn', w, d)] + else: + return [] # Create gap junction connections between a cell within a chain and its neighbor(s) def gap_junctions_on(self, gid): @@ -96,33 +106,49 @@ def gap_junctions_on(self, gid): next_cell = gid + 1 prev_cell = gid - 1 - if next_cell < chain_end: - conns.append(arbor.gap_junction_connection((gid+1, 'gj'), 'gj', 0.015)) - if prev_cell >= chain_begin: - conns.append(arbor.gap_junction_connection((gid-1, 'gj'), 'gj', 0.015)) + if (gid < self.ncells_per_chain - 1): + conns.append(arbor.gap_junction_connection((gid+1, 'gj_0'), 'gj_1', 0.7)) + if (gid > 0): + conns.append(arbor.gap_junction_connection((gid-1, 'gj_1'), 'gj_0', 0.7)) + if gid == 0: + conns.append(arbor.gap_junction_connection((ncells_per_chain-1, 'gj_1'), 'gj_0', 0.7)) + if gid == ncells_per_chain-1: + conns.append(arbor.gap_junction_connection((0, 'gj_0'), 'gj_1', 0.7)) + + + + # if next_cell < chain_end: + # conns.append(arbor.gap_junction_connection((gid+1, 'gj_0'), 'gj_0', 0.15)) + # if prev_cell >= chain_begin: + # conns.append(arbor.gap_junction_connection((gid-1, 'gj_0'), 'gj_0', 0.15)) + + return conns # Event generator at first cell def event_generators(self, gid): - if gid==0: - sched = arbor.explicit_schedule([1]) - weight = 0.1 + if (gid == 0): + #sched = arbor.explicit_schedule([0, 40, 80, 110, 160, 200]) + sched = arbor.explicit_schedule([0]) + #sched = arbor.regular_schedule(0, 10, 71) + weight = 0.5 return [arbor.event_generator('syn', weight, sched)] return [] # Place a probe at the root of each cell def probes(self, gid): + return [] return [arbor.cable_probe_membrane_voltage('"root"')] def global_properties(self, kind): return self.props # Number of cells per chain -ncells_per_chain = 5 +ncells_per_chain = 400 # Number of chains -nchains = 3 +nchains = 1 # Total number of cells ncells = nchains * ncells_per_chain @@ -130,33 +156,63 @@ def global_properties(self, kind): #Instantiate recipe recipe = chain_recipe(ncells_per_chain, nchains) -# Create a default execution context, domain decomposition and simulation -context = arbor.context() -decomp = arbor.partition_load_balance(recipe, context) +alloc = arbor.proc_allocation(1, None) +comm = mpi.COMM_WORLD +#print(f"rank={comm.rank} size={comm.size}") + +xs = [comm.rank]*(comm.rank + 1) +gxs = comm.allgather(xs) +print(gxs) + +context = arbor.context(alloc, comm) +print(context) + +g = [] +for i in range(ncells_per_chain) : + g.append(i) + +if comm.rank == 0: + # gs = [[0,1]] + gs = [g] +# elif comm.rank == 1: +# gs = [[2,3]] + + +groups = [arbor.group_description(arbor.cell_kind.cable, g, arbor.backend.multicore) for g in gs] +decomp = arbor.partition_by_group(recipe, context, groups) + sim = arbor.simulation(recipe, decomp, context) +dt = 0.025 + +sim.set_binning_policy(arbor.binning.regular, dt) + # Set spike generators to record -sim.record(arbor.spike_recording.all) +#sim.record(arbor.spike_recording.all) # Sampler -handles = [sim.sample((gid, 0), arbor.regular_schedule(0.1)) for gid in range(ncells)] +#handles = [sim.sample((gid, 0), arbor.regular_schedule(0.1)) for gid in range(ncells)] # Run simulation for 100 ms -sim.run(100) -print('Simulation finished') +sim.run(100, dt=dt) +if comm.rank == 0: + print('Simulation finished rank 0') +else: + print('Simulation finished rank 1') + # Print spike times -print('spikes:') -for sp in sim.spikes(): - print(' ', sp) +#print('spikes:') +#for sp in sim.spikes(): +# print(' ', sp) # Plot the results -print("Plotting results ...") -df_list = [] -for gid in range(ncells): - samples, meta = sim.samples(handles[gid])[0] - df_list.append(pandas.DataFrame({'t/ms': samples[:, 0], 'U/mV': samples[:, 1], 'Cell': f"cell {gid}"})) - -df = pandas.concat(df_list,ignore_index=True) -seaborn.relplot(data=df, kind="line", x="t/ms", y="U/mV",hue="Cell",ci=None) -plt.show() +#print("Plotting results ...") +#df_list = [] +#for gid in range(ncells): +# samples, meta = sim.samples(handles[gid])[0] +# df_list.append(pandas.DataFrame({'t/ms': samples[:, 0], 'U/mV': samples[:, 1], 'Cell': f"cell {gid}"})) + +#df = pandas.concat(df_list,ignore_index=True) +#seaborn.relplot(data=df, kind="line", x="t/ms", y="U/mV",hue="Cell",ci=None) +#plt.show() diff --git a/test/unit/test_backend.cpp b/test/unit/test_backend.cpp index b43b755428..3de0a0535d 100644 --- a/test/unit/test_backend.cpp +++ b/test/unit/test_backend.cpp @@ -10,9 +10,10 @@ using namespace arb; TEST(backends, gpu_test) { execution_context context; + cell_gid_type cg; #ifdef ARB_GPU_ENABLED - EXPECT_NO_THROW(make_fvm_lowered_cell(backend_kind::gpu, context)); + EXPECT_NO_THROW(make_fvm_lowered_cell(backend_kind::gpu, context, cg)); #else - EXPECT_ANY_THROW(make_fvm_lowered_cell(backend_kind::gpu, context)); + EXPECT_ANY_THROW(make_fvm_lowered_cell(backend_kind::gpu, context, cg)); #endif } diff --git a/test/unit/test_mc_cell_group.cpp b/test/unit/test_mc_cell_group.cpp index 3744968227..d9bab8d50f 100644 --- a/test/unit/test_mc_cell_group.cpp +++ b/test/unit/test_mc_cell_group.cpp @@ -18,9 +18,10 @@ using namespace arborio::literals; namespace { execution_context context; + cell_gid_type cg; fvm_lowered_cell_ptr lowered_cell() { - return make_fvm_lowered_cell(backend_kind::multicore, context); + return make_fvm_lowered_cell(backend_kind::multicore, context, cg); } cable_cell_description make_cell() { @@ -43,7 +44,8 @@ ACCESS_BIND( TEST(mc_cell_group, get_kind) { cable_cell cell = make_cell(); cell_label_range srcs, tgts; - mc_cell_group group{{0}, cable1d_recipe({cell}), srcs, tgts, lowered_cell()}; + cell_gid_type cg; + mc_cell_group group{{0}, cable1d_recipe({cell}), srcs, tgts, cg, lowered_cell()}; EXPECT_EQ(cell_kind::cable, group.get_cell_kind()); } @@ -56,7 +58,8 @@ TEST(mc_cell_group, test) { rec.nernst_ion("k"); cell_label_range srcs, tgts; - mc_cell_group group{{0}, rec, srcs, tgts, lowered_cell()}; + cell_gid_type cg; + mc_cell_group group{{0}, rec, srcs, tgts, cg, lowered_cell()}; group.advance(epoch(0, 0., 50.), 0.01, {}); // Model is expected to generate 4 spikes as a result of the @@ -86,7 +89,8 @@ TEST(mc_cell_group, sources) { rec.nernst_ion("k"); cell_label_range srcs, tgts; - mc_cell_group group{gids, rec, srcs, tgts, lowered_cell()}; + cell_gid_type cg; + mc_cell_group group{gids, rec, srcs, tgts, cg, lowered_cell()}; // Expect group sources to be lexicographically sorted by source id // with gids in cell group's range and indices starting from zero.