Multi-stage graph neural network

A DNN's computational graph contains multiple motifs (reused graph patterns). Because of this, GNN-RL models the DNN as a hierarchical computational graph and embeds it with an efficient multi-stage graph embedding.
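To make the motif idea concrete, here is a minimal plain-Python sketch of a two-level graph. It assumes a motif is a small DAG of primitive operations and that the top level is a graph whose nodes are motif instances. The `Motif` and `HierarchicalGraph` classes are hypothetical illustrations, not GNN-RL's own data structures.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

Edge = Tuple[int, int]


@dataclass
class Motif:
    """Hypothetical reusable subgraph of primitive ops (e.g. conv -> bn -> relu)."""
    name: str
    ops: List[str]     # node labels, indexed 0..len(ops)-1
    edges: List[Edge]  # directed edges between op indices


@dataclass
class HierarchicalGraph:
    """Hypothetical top level: each node is an instance of a named motif."""
    motifs: Dict[str, Motif]
    instances: List[str]  # motif name for each top-level node
    edges: List[Edge] = field(default_factory=list)

    def flatten_size(self) -> Tuple[int, int]:
        """Node and edge counts of the equivalent single-level graph.
        Assumes each top-level edge becomes exactly one edge between motifs."""
        nodes = sum(len(self.motifs[m].ops) for m in self.instances)
        inner_edges = sum(len(self.motifs[m].edges) for m in self.instances)
        return nodes, inner_edges + len(self.edges)

    def stored_size(self) -> int:
        """Nodes actually stored: each distinct motif once, plus one node per instance."""
        return sum(len(m.ops) for m in self.motifs.values()) + len(self.instances)


# Four chained copies of the same conv -> bn -> relu block.
block = Motif("conv_bn_relu", ["conv", "bn", "relu"], [(0, 1), (1, 2)])
g = HierarchicalGraph({"conv_bn_relu": block}, ["conv_bn_relu"] * 4,
                      [(0, 1), (1, 2), (2, 3)])
print(g.flatten_size())  # (12, 11)
print(g.stored_size())   # 7
```

Storing the repeated block once and referring to it by name is what lets the hierarchical form stay smaller than the flattened graph as the number of repeated blocks grows.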

Define the neural network and retrieve its layer information (input and output channels):

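import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = load_model(args.model, args.data_root)  # load_model, args and graph_construction come from the GNN-RL code
net.to(device)
in_channels, out_channels = graph_construction.net_info(net)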
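The exact contents returned by `net_info` are not described here. As a hypothetical stand-in, the sketch below assumes it collects per-layer channel counts from layers that expose `in_channels` and `out_channels` attributes, as `torch.nn.Conv2d` does. `collect_channels`, `FakeConv` and `FakeReLU` are illustrative names only; the fake layers let the sketch run without PyTorch.

```python
from dataclasses import dataclass
from typing import Iterable, List, Tuple


def collect_channels(layers: Iterable[object]) -> Tuple[List[int], List[int]]:
    """Hypothetical stand-in for graph_construction.net_info: gather the
    in/out channel counts of every layer that has both attributes and
    skip all other layers."""
    in_channels: List[int] = []
    out_channels: List[int] = []
    for layer in layers:
        if hasattr(layer, "in_channels") and hasattr(layer, "out_channels"):
            in_channels.append(layer.in_channels)
            out_channels.append(layer.out_channels)
    return in_channels, out_channels


@dataclass
class FakeConv:
    """Stand-in for a convolution layer with channel attributes."""
    in_channels: int
    out_channels: int


class FakeReLU:
    """Stand-in for a layer without channel attributes (ignored)."""


layers = [FakeConv(3, 16), FakeReLU(), FakeConv(16, 32)]
ins, outs = collect_channels(layers)
print(ins, outs)  # [3, 16] [16, 32]
```

With a real PyTorch model, the same function could be called as `collect_channels(net.modules())`, since `Module.modules()` yields every submodule.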

Multi-stage graph embedding

DGL backend

To create a hierarchical computational graph with the DGL backend, call the hierarchical_computational_graph() method:

graph = dgl_g.hierarchical_computational_graph()

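Define the encoder by creating a multi_stage_graph_encoder_dgl object (in_feature, hidden_feature and out_feature set the input, hidden and output feature sizes), then apply it to the graph to obtain the hierarchical embeddings: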

encoder = multi_stage_graph_encoder_dgl(in_feature, hidden_feature, out_feature)
hierarchical_embeddings = encoder(graph)
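The internals of the trained encoder are not described here. As a rough, hypothetical illustration of what "multi-stage" means, the plain-Python sketch below replaces learned layers with unweighted mean aggregation. Stage 1 embeds each distinct motif subgraph once; stage 2 treats those motif embeddings as node features of the top-level graph. `propagate`, `embed_graph` and `multi_stage_embed` are illustrative names, not part of multi_stage_graph_encoder_dgl.

```python
from typing import Dict, List, Sequence, Tuple

Vector = List[float]
Edge = Tuple[int, int]


def mean(vectors: Sequence[Vector]) -> Vector:
    """Element-wise mean of equally sized vectors."""
    return [sum(col) / len(vectors) for col in zip(*vectors)]


def propagate(features: List[Vector], edges: List[Edge]) -> List[Vector]:
    """One round of unweighted message passing: each node averages its own
    feature with its neighbours' (edges are treated as undirected)."""
    neighbours: Dict[int, List[int]] = {i: [i] for i in range(len(features))}
    for u, v in edges:
        neighbours[u].append(v)
        neighbours[v].append(u)
    return [mean([features[j] for j in neighbours[i]]) for i in range(len(features))]


def embed_graph(features: List[Vector], edges: List[Edge]) -> Vector:
    """Propagate once, then mean-pool the node features into one vector."""
    return mean(propagate(features, edges))


def multi_stage_embed(motifs: Dict[str, Tuple[List[Vector], List[Edge]]],
                      instances: List[str],
                      top_edges: List[Edge]) -> Vector:
    # Stage 1: embed each distinct motif once, however often it is reused.
    motif_emb = {name: embed_graph(x, e) for name, (x, e) in motifs.items()}
    # Stage 2: each top-level node takes its motif's embedding as its feature.
    top_features = [motif_emb[name] for name in instances]
    return embed_graph(top_features, top_edges)


# Two hypothetical motifs, each a 2-node subgraph with 2-dimensional features.
motifs = {
    "a": ([[1.0, 0.0], [0.0, 1.0]], [(0, 1)]),
    "b": ([[2.0, 2.0], [4.0, 0.0]], [(0, 1)]),
}
embedding = multi_stage_embed(motifs, ["a", "b"], [(0, 1)])
print(embedding)  # [1.75, 0.75]
```

Because a reused motif is embedded only once in stage 1, that stage's cost depends on the number of distinct motifs rather than on the number of layers. This is one plausible reading of why a multi-stage embedding is efficient, not a description of GNN-RL's implementation.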

Torch-Geometric backend

To create a hierarchical computational graph with the Torch-Geometric backend, call the hierarchical_computational_graph() method:

graph = pyg_g.hierarchical_computational_graph()

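Define the encoder by creating a multi_stage_graph_encoder_pyg object (in_feature, hidden_feature and out_feature set the input, hidden and output feature sizes), then apply it to the graph to obtain the hierarchical embeddings: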

encoder = multi_stage_graph_encoder_pyg(in_feature, hidden_feature, out_feature)
hierarchical_embeddings = encoder(graph)
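The two backends store the same directed edges in slightly different shapes. DGL's `dgl.graph((src, dst))` takes a pair of source and destination sequences, while Torch-Geometric's `Data` object holds an `edge_index` of shape `[2, num_edges]`. The plain-Python sketch below shows both layouts for one small motif. `to_src_dst` and `to_edge_index` are illustrative helpers, not functions from either library or from GNN-RL.

```python
from typing import List, Tuple

Edge = Tuple[int, int]


def to_src_dst(edges: List[Edge]) -> Tuple[List[int], List[int]]:
    """Split an edge list into parallel source and destination lists,
    the (src, dst) pair shape that dgl.graph accepts."""
    src = [u for u, _ in edges]
    dst = [v for _, v in edges]
    return src, dst


def to_edge_index(edges: List[Edge]) -> List[List[int]]:
    """Stack sources and destinations into the 2 x num_edges COO layout
    that Torch-Geometric uses for edge_index."""
    src, dst = to_src_dst(edges)
    return [src, dst]


# conv -> bn -> relu motif: op 0 feeds op 1, which feeds op 2.
edges = [(0, 1), (1, 2)]
src, dst = to_src_dst(edges)
edge_index = to_edge_index(edges)
print(src, dst)     # [0, 1] [1, 2]
print(edge_index)   # [[0, 1], [1, 2]]
```

With the real libraries, these lists would typically be wrapped in tensors, for example `dgl.graph((torch.tensor(src), torch.tensor(dst)))` or `Data(x=x, edge_index=torch.tensor(edge_index, dtype=torch.long))`.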