Skip to content

bvskp-projects/cs744-mpi

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

11 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Marius Script

Usage

./mpic <pyfiles> -l <log level>

Example

Input

"""
See https://docs.dgl.ai/tutorials/blitz/3_message_passing.html
"""


class BasicLayer(mpi.Module):
    def __init__(self, input_dim: int, output_dim: int):
        super(mpi.Module, self).__init__(input_dim, output_dim)
        self.linear = mpi.Linear(input_dim * 2, output_dim)
        self.reset_parameters()

    def reset_parameters(self):
        self.linear.reset_parameters()

    def forward(self, graph: mpi.DENSEGraph, h: mpi.Tensor) -> mpi.Tensor:
        with graph.local_scope():
            graph.ndata["h"] = h
            graph.update_all(
                message_func=mpi.copy_u("h", "m"), reduce_func=mpi.mean("m", "h_N")
            )
            h_N = graph.ndata["h_N"]
            h_total = mpi.cat(h, h_N, dim=1)
            return self.linear(h_total)

Outputs

basic_layer.h

//
// Autogenerated file!
//

#pragma once

#include "common/datatypes.h"
#include "configuration/options.h"
#include "configuration/config.h"
#include "data/graph.h"
#include "nn/initialization.h"
#include "gnn_layer.h"
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#include "torch/torch.h"
#pragma GCC diagnostic pop

struct BasicLayerOptions : GNNLayerOptions {
};

class BasicLayer : public GNNLayer {
   public:
    shared_ptr<BasicLayerOptions> options_;
    torch::Tensor linear_;
    int _mpic_linear_input_dim_;
    int _mpic_linear_output_dim_;

    BasicLayer(shared_ptr<LayerConfig> layer_config, torch::Device device);

    void reset() override;

    torch::Tensor forward(torch::Tensor inputs, DENSEGraph dense_graph, bool train = true) override;
};

basic_layer.cpp

//
// Autogenerated file!
//

#include "nn/layers/gnn/basic_layer.h"

#include "nn/layers/gnn/layer_helpers.h"
#include "reporting/logger.h"
#include "nn/initialization.h"

// XXX: GCC warns on always_inline
#pragma GCC diagnostic ignored "-Wattributes"

// XXX: always_inline to mirror original code (unsure if more efficient)
#define FUNCGEN __attribute__((always_inline))

namespace {

struct SumFunc {
    static constexpr bool useNumNbrs = false;

    FUNCGEN static torch::Tensor segmented_reduce(torch::Tensor const& embeds, Indices const& offsets) {
        return segmented_sum_with_offsets(embeds, offsets);
    }
};

struct MaxFunc {
    static constexpr bool useNumNbrs = false;

    FUNCGEN static torch::Tensor segmented_reduce(torch::Tensor const& embeds, Indices const& offsets) {
        return segmented_max_with_offsets(embeds, offsets);
    }
};

struct MeanFunc {
    static constexpr bool useNumNbrs = true;

    FUNCGEN static torch::Tensor segmented_reduce(torch::Tensor const& embeds, Indices const& offsets) {
        return segmented_sum_with_offsets(embeds, offsets);
    }

    FUNCGEN static torch::Tensor applyNumNbrs(torch::Tensor const& a_i, torch::Tensor const& num_nbrs) {
        torch::Tensor denominator = torch::where(torch::not_equal(
                    num_nbrs, 0), num_nbrs, 1).to(a_i.dtype()).unsqueeze(-1);
        return a_i / denominator;
    }
};

template <class ReduceFunc>
FUNCGEN torch::Tensor update_all(DENSEGraph& dense_graph, torch::Tensor const& u) {
    constexpr bool useNumNbrs = ReduceFunc::useNumNbrs;
    torch::Tensor a_i;
    [[maybe_unused]] torch::Tensor total_num_neighbors;

    if (dense_graph.out_neighbors_mapping_.defined()) {
        Indices outgoing_neighbors = dense_graph.getNeighborIDs(false, false);
        Indices outgoing_neighbor_offsets = dense_graph.getNeighborOffsets(false);
        torch::Tensor outgoing_num = dense_graph.getNumNeighbors(false);

        torch::Tensor outgoing_embeddings = u.index_select(0, outgoing_neighbors);
        a_i = ReduceFunc::segmented_reduce(outgoing_embeddings, outgoing_neighbor_offsets);

        // often, aggregation functions require the number of neighbors
        if constexpr(useNumNbrs) {
            total_num_neighbors = outgoing_num;
        }
    }

    if (dense_graph.in_neighbors_mapping_.defined()) {
        Indices incoming_neighbors = dense_graph.getNeighborIDs(true, false);
        Indices incoming_neighbor_offsets = dense_graph.getNeighborOffsets(true);
        torch::Tensor incoming_num = dense_graph.getNumNeighbors(true);

        torch::Tensor incoming_embeddings = u.index_select(0, incoming_neighbors);

        if (a_i.defined()) {
            a_i = a_i + segmented_sum_with_offsets(incoming_embeddings, incoming_neighbor_offsets);
        } else {
            a_i = segmented_sum_with_offsets(incoming_embeddings, incoming_neighbor_offsets);
        }

        // often, aggregation functions require the number of neighbors
        if constexpr(useNumNbrs) {
            if (total_num_neighbors.defined()) {
                total_num_neighbors = total_num_neighbors + incoming_num;
            } else {
                total_num_neighbors = incoming_num;
            }
        }
    }

    if constexpr(useNumNbrs) {
        return ReduceFunc::applyNumNbrs(a_i, total_num_neighbors);
    } else {
        return a_i;
    }
}

} // anonymous namespace

BasicLayer::BasicLayer(shared_ptr<LayerConfig> layer_config, torch::Device device) {
    config_ = layer_config;
    options_ = std::dynamic_pointer_cast<BasicLayerOptions>(config_->options);
    input_dim_ = config_->input_dim;
    output_dim_ = config_->output_dim;
    device_ = device;

    _mpic_linear_input_dim_ = input_dim_ * 2;
    _mpic_linear_output_dim_ = output_dim_;
    reset();
}

void BasicLayer::reset() {
    [[maybe_unused]] auto tensor_options = torch::TensorOptions().dtype(torch::kFloat32).device(device_);

    linear_ = initialize_tensor(config_->init, {_mpic_linear_output_dim_, _mpic_linear_input_dim_}, tensor_options).set_requires_grad(true);;

    if (config_->bias) {
        init_bias();
    }
}

torch::Tensor BasicLayer::forward(torch::Tensor h, DENSEGraph graph, bool train) {
    torch::Tensor _mpic_ndata_h;
    torch::Tensor _mpic_ndata_h_N;
    torch::Tensor h_N;
    torch::Tensor h_total;

    _mpic_ndata_h = h;
    _mpic_ndata_h_N = update_all<MeanFunc>(graph, _mpic_ndata_h);;
    h_N = _mpic_ndata_h_N;
    h_total = torch::cat({h, h_N}, 1);
    return torch::matmul(linear_, h_total.transpose(0, -1)).transpose(0, -1);
}

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published