./mpic <pyfiles> -l <log level>
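For example, assuming a source file named basic_layer.py and an "info" log level (both placeholder values, not confirmed by the tool's help text):

./mpic basic_layer.py -l info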
"""
See https://docs.dgl.ai/tutorials/blitz/3_message_passing.html
"""
class BasicLayer(mpi.Module):
    def __init__(self, input_dim: int, output_dim: int):
        super().__init__(input_dim, output_dim)
        # Concatenating h with the aggregated neighborhood doubles the input width.
        self.linear = mpi.Linear(input_dim * 2, output_dim)
        self.reset_parameters()

    def reset_parameters(self):
        self.linear.reset_parameters()

    def forward(self, graph: mpi.DENSEGraph, h: mpi.Tensor) -> mpi.Tensor:
        with graph.local_scope():
            graph.ndata["h"] = h
            # Copy each node's "h" onto its edges, then mean-reduce into "h_N".
            graph.update_all(
                message_func=mpi.copy_u("h", "m"), reduce_func=mpi.mean("m", "h_N")
            )
            h_N = graph.ndata["h_N"]
            h_total = mpi.cat(h, h_N, dim=1)
            return self.linear(h_total)
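The layer follows the GraphSAGE-style update from the DGL tutorial: each node concatenates its own embedding with the mean of its neighbors' embeddings, then applies a linear layer. The generated C++ below realizes the mean as a segmented sum over flattened neighbor embeddings followed by a division by neighbor counts. The following standalone libtorch sketch illustrates that pattern only; it uses index_add_ with explicit segment ids in place of the project's segmented_sum_with_offsets helper, so all names and values here are illustrative.

#include <torch/torch.h>
#include <iostream>

// Illustrative sketch: segmented mean over flattened neighbor embeddings,
// the same pattern the generated update_all<MeanFunc> implements below.
int main() {
    // Four neighbor embeddings (dim 2); node 0 owns the first three rows, node 1 the last.
    torch::Tensor embeds = torch::tensor({{1.f, 1.f}, {2.f, 2.f}, {3.f, 3.f}, {4.f, 4.f}});
    torch::Tensor seg_ids = torch::tensor({0, 0, 0, 1});  // destination node of each neighbor row
    torch::Tensor sums = torch::zeros({2, 2}, embeds.options()).index_add_(0, seg_ids, embeds);
    torch::Tensor counts = torch::tensor({3.f, 1.f}).unsqueeze(-1);
    std::cout << sums / counts << std::endl;  // rows: [2, 2] and [4, 4]
}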
basic_layer.h
//
// Autogenerated file!
//
#pragma once
#include "common/datatypes.h"
#include "configuration/options.h"
#include "configuration/config.h"
#include "data/graph.h"
#include "nn/initialization.h"
#include "gnn_layer.h"
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#include "torch/torch.h"
#pragma GCC diagnostic pop
struct BasicLayerOptions : GNNLayerOptions {
};

class BasicLayer : public GNNLayer {
   public:
    shared_ptr<BasicLayerOptions> options_;
    torch::Tensor linear_;
    int _mpic_linear_input_dim_;
    int _mpic_linear_output_dim_;

    BasicLayer(shared_ptr<LayerConfig> layer_config, torch::Device device);

    void reset() override;

    torch::Tensor forward(torch::Tensor inputs, DENSEGraph dense_graph, bool train = true) override;
};
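Note how the compiler lowers the mpi.Linear submodule: rather than a nested module object, the header holds a raw weight tensor linear_ plus its recorded dimensions (_mpic_linear_input_dim_, _mpic_linear_output_dim_), and the graph that was a Python argument becomes an explicit DENSEGraph parameter of forward.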
basic_layer.cpp
//
// Autogenerated file!
//
#include "nn/layers/gnn/basic_layer.h"
#include "nn/layers/gnn/layer_helpers.h"
#include "reporting/logger.h"
#include "nn/initialization.h"
// XXX: GCC warns on always_inline
#pragma GCC diagnostic ignored "-Wattributes"
// XXX: always_inline to mirror original code (unsure if more efficient)
#define FUNCGEN __attribute__((always_inline))
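// Each functor below pairs a segmented reduction with an optional
// post-processing step; update_all<ReduceFunc> is the compiled
// counterpart of graph.update_all(message_func, reduce_func) above.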
namespace {

struct SumFunc {
    static constexpr bool useNumNbrs = false;
    FUNCGEN static torch::Tensor segmented_reduce(torch::Tensor const& embeds, Indices const& offsets) {
        return segmented_sum_with_offsets(embeds, offsets);
    }
};

struct MaxFunc {
    static constexpr bool useNumNbrs = false;
    FUNCGEN static torch::Tensor segmented_reduce(torch::Tensor const& embeds, Indices const& offsets) {
        return segmented_max_with_offsets(embeds, offsets);
    }
};

struct MeanFunc {
    // Mean is a segmented sum followed by division by the neighbor counts.
    static constexpr bool useNumNbrs = true;
    FUNCGEN static torch::Tensor segmented_reduce(torch::Tensor const& embeds, Indices const& offsets) {
        return segmented_sum_with_offsets(embeds, offsets);
    }
    FUNCGEN static torch::Tensor applyNumNbrs(torch::Tensor const& a_i, torch::Tensor const& num_nbrs) {
        torch::Tensor denominator = torch::where(torch::not_equal(
            num_nbrs, 0), num_nbrs, 1).to(a_i.dtype()).unsqueeze(-1);
        return a_i / denominator;
    }
};
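// Example (illustrative values): with neighbor counts {2, 0}, applyNumNbrs
// divides the first row of a_i by 2 and the second by 1 (the zero count is
// clamped so that nodes with no neighbors yield 0 instead of NaN).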
template <class ReduceFunc>
FUNCGEN torch::Tensor update_all(DENSEGraph& dense_graph, torch::Tensor const& u) {
    constexpr bool useNumNbrs = ReduceFunc::useNumNbrs;

    torch::Tensor a_i;
    [[maybe_unused]] torch::Tensor total_num_neighbors;

    if (dense_graph.out_neighbors_mapping_.defined()) {
        Indices outgoing_neighbors = dense_graph.getNeighborIDs(false, false);
        Indices outgoing_neighbor_offsets = dense_graph.getNeighborOffsets(false);
        torch::Tensor outgoing_num = dense_graph.getNumNeighbors(false);

        torch::Tensor outgoing_embeddings = u.index_select(0, outgoing_neighbors);
        a_i = ReduceFunc::segmented_reduce(outgoing_embeddings, outgoing_neighbor_offsets);

        // Some aggregation functions (e.g. mean) need the number of neighbors.
        if constexpr (useNumNbrs) {
            total_num_neighbors = outgoing_num;
        }
    }

    if (dense_graph.in_neighbors_mapping_.defined()) {
        Indices incoming_neighbors = dense_graph.getNeighborIDs(true, false);
        Indices incoming_neighbor_offsets = dense_graph.getNeighborOffsets(true);
        torch::Tensor incoming_num = dense_graph.getNumNeighbors(true);

        torch::Tensor incoming_embeddings = u.index_select(0, incoming_neighbors);
        // Contributions from the two edge directions are combined additively.
        if (a_i.defined()) {
            a_i = a_i + ReduceFunc::segmented_reduce(incoming_embeddings, incoming_neighbor_offsets);
        } else {
            a_i = ReduceFunc::segmented_reduce(incoming_embeddings, incoming_neighbor_offsets);
        }

        // Some aggregation functions (e.g. mean) need the number of neighbors.
        if constexpr (useNumNbrs) {
            if (total_num_neighbors.defined()) {
                total_num_neighbors = total_num_neighbors + incoming_num;
            } else {
                total_num_neighbors = incoming_num;
            }
        }
    }

    if constexpr (useNumNbrs) {
        return ReduceFunc::applyNumNbrs(a_i, total_num_neighbors);
    } else {
        return a_i;
    }
}
} // anonymous namespace
BasicLayer::BasicLayer(shared_ptr<LayerConfig> layer_config, torch::Device device) {
    config_ = layer_config;
    options_ = std::dynamic_pointer_cast<BasicLayerOptions>(config_->options);
    input_dim_ = config_->input_dim;
    output_dim_ = config_->output_dim;
    device_ = device;

    // Mirrors mpi.Linear(input_dim * 2, output_dim) from the Python module.
    _mpic_linear_input_dim_ = input_dim_ * 2;
    _mpic_linear_output_dim_ = output_dim_;

    reset();
}
void BasicLayer::reset() {
    [[maybe_unused]] auto tensor_options = torch::TensorOptions().dtype(torch::kFloat32).device(device_);
    linear_ = initialize_tensor(config_->init, {_mpic_linear_output_dim_, _mpic_linear_input_dim_},
                                tensor_options).set_requires_grad(true);
    if (config_->bias) {
        init_bias();
    }
}
torch::Tensor BasicLayer::forward(torch::Tensor h, DENSEGraph graph, bool train) {
    torch::Tensor _mpic_ndata_h;
    torch::Tensor _mpic_ndata_h_N;
    torch::Tensor h_N;
    torch::Tensor h_total;

    _mpic_ndata_h = h;
    _mpic_ndata_h_N = update_all<MeanFunc>(graph, _mpic_ndata_h);
    h_N = _mpic_ndata_h_N;
    h_total = torch::cat({h, h_N}, 1);
    // (linear_ * h_total^T)^T == h_total * linear_^T, i.e. the linear layer applied row-wise.
    return torch::matmul(linear_, h_total.transpose(0, -1)).transpose(0, -1);
}
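For context, a minimal sketch of how the generated layer might be constructed and invoked; only the LayerConfig fields and signatures shown above are grounded, and num_nodes and dense_graph stand in for values prepared elsewhere:

// Hypothetical setup; everything beyond the fields referenced above is assumed.
auto config = std::make_shared<LayerConfig>();
config->input_dim = 64;
config->output_dim = 32;
config->bias = false;
// config->init and config->options would also need to be populated.

BasicLayer layer(config, torch::Device(torch::kCPU));

torch::Tensor h = torch::randn({num_nodes, 64});              // node embeddings
torch::Tensor out = layer.forward(h, dense_graph, /*train=*/true);  // [num_nodes, 32]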