GCN-based graph encoder

GNN-RL uses a GCN-based policy network to learn topology directly from graphs. GNN-RL supports both Torch-Geometric (PyG) and DGL as graph neural network backends.
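
For intuition about the "GCN-based" part, here is a minimal sketch of one graph convolution layer using the standard symmetric-normalized propagation rule H' = ReLU(D^-1/2 (A + I) D^-1/2 H W). It uses a dense adjacency matrix and plain PyTorch for readability. `DenseGCNLayer` is a hypothetical name and is not part of GNN-RL. PyG's `GCNConv` applies the same normalization (with self-loops) on sparse edge lists.

```python
import torch
import torch.nn as nn


class DenseGCNLayer(nn.Module):
    """One GCN layer on a dense adjacency matrix (hypothetical helper)."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        # Weight matrix W, without bias so the code matches the formula exactly.
        self.weight = nn.Linear(in_features, out_features, bias=False)

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        # A + I: self-loops let each node keep its own features.
        a_hat = adj + torch.eye(adj.size(0), dtype=adj.dtype)
        # D^-1/2 (A + I) D^-1/2: symmetric degree normalization.
        deg_inv_sqrt = a_hat.sum(dim=1).pow(-0.5)
        norm_adj = deg_inv_sqrt.unsqueeze(1) * a_hat * deg_inv_sqrt.unsqueeze(0)
        # Aggregate the projected features H W over neighbors, then apply ReLU.
        return torch.relu(norm_adj @ self.weight(x))


# Undirected three-node chain 0 - 1 - 2 with 4 input features per node.
adj = torch.tensor([[0.0, 1.0, 0.0],
                    [1.0, 0.0, 1.0],
                    [0.0, 1.0, 0.0]])
x = torch.randn(3, 4)
layer = DenseGCNLayer(4, 8)
h = layer(x, adj)
print(h.shape)  # torch.Size([3, 8])
```

Each output row mixes a node's own features with those of its neighbors, which is how a stack of such layers picks up graph topology.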

Embedding a computational graph with DGL

First, load a neural network and extract its channel information.

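# args holds the parsed command-line options (model name and dataset root)
net = load_model(args.model, args.data_root)
# input/output channel counts of the network's layers
in_channels, out_channels = graph_construction.net_info(net)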
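
The internals of `graph_construction.net_info` are not shown here. As an assumption only, the sketch below illustrates one way to gather per-layer channel counts from a PyTorch model, by walking its `nn.Conv2d` modules in registration order. `net_info_sketch` and the toy network are hypothetical stand-ins for `net_info` and `load_model`.

```python
import torch.nn as nn


def net_info_sketch(net: nn.Module):
    """Hypothetical stand-in for net_info: collect the in/out channels of
    every Conv2d layer, in the order the modules were registered."""
    in_channels, out_channels = [], []
    for module in net.modules():
        if isinstance(module, nn.Conv2d):
            in_channels.append(module.in_channels)
            out_channels.append(module.out_channels)
    return in_channels, out_channels


# Toy network standing in for load_model(...).
net = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
)
in_channels, out_channels = net_info_sketch(net)
print(in_channels, out_channels)  # [3, 16] [16, 32]
```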

Then, construct the corresponding computational graph by creating a computational_graph_dgl object, and build a plain (simplified) computational graph by calling plain_computational_graph():

# feature_size: length of each node's feature vector
dgl_g = computational_graph_dgl(in_channels, out_channels, feature_size)
graph = dgl_g.plain_computational_graph()
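
The exact structure built by `plain_computational_graph()` is defined by GNN-RL. The sketch below assumes a simple chain (one node per layer, an edge from each layer to the next, channel counts as node features) only to show what such a graph looks like as tensors. `plain_graph_sketch` is hypothetical.

```python
import torch


def plain_graph_sketch(in_channels, out_channels, feature_size):
    """Hypothetical simplified graph: node i is layer i, with a directed edge
    i -> i+1 between consecutive layers. Node features hold the layer's
    in/out channel counts, zero-padded to feature_size (must be >= 2)."""
    num_nodes = len(in_channels)
    src = torch.arange(num_nodes - 1)   # source node of each edge
    dst = torch.arange(1, num_nodes)    # target node of each edge
    feats = torch.zeros(num_nodes, feature_size)
    feats[:, 0] = torch.tensor(in_channels, dtype=torch.float)
    feats[:, 1] = torch.tensor(out_channels, dtype=torch.float)
    return src, dst, feats


src, dst, feats = plain_graph_sketch([3, 16, 32], [16, 32, 64], feature_size=8)
print(src.tolist(), dst.tolist())  # [0, 1] [1, 2]
print(feats.shape)                 # torch.Size([3, 8])
```

With DGL installed, the same tensors could be wrapped as `g = dgl.graph((src, dst), num_nodes=3)` followed by `g.ndata['feat'] = feats`. DGL's `GraphConv` rejects nodes with zero in-degree (node 0 here) by default, so self-loops are typically added first, for example with `dgl.add_self_loop(g)`.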

Finally, embed the computational graph:

model = graph_encoder_dgl(feature_size, hidden_feature, out_feature)
embedding = model(graph)

Embedding a computational graph with PyG

Construct the corresponding computational graph by creating a computational_graph_pyg object, and build a plain (simplified) computational graph by calling plain_computational_graph():

# feature_size: length of each node's feature vector
pyg_g = computational_graph_pyg(in_channels, out_channels, feature_size)
graph = pyg_g.plain_computational_graph()
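
PyG stores edges as a `[2, num_edges]` `edge_index` tensor, and it forms mini-batches by merging several graphs into one disconnected graph: node indices are shifted by the number of nodes already placed, and a `batch` vector records which graph each node belongs to. The plain-PyTorch sketch below reproduces that bookkeeping for illustration. In PyG itself, `Batch.from_data_list` or `torch_geometric.loader.DataLoader` does this for you. `to_edge_index` and `batch_graphs` are hypothetical helpers.

```python
import torch


def to_edge_index(src, dst):
    # PyG layout: row 0 holds source nodes, row 1 holds target nodes.
    return torch.stack([src, dst], dim=0).long()


def batch_graphs(graphs):
    """Merge (x, edge_index) pairs into one disconnected graph, PyG-style."""
    xs, edge_indices, batch = [], [], []
    offset = 0
    for graph_id, (x, edge_index) in enumerate(graphs):
        xs.append(x)
        # Shift node indices so they point into the concatenated node list.
        edge_indices.append(edge_index + offset)
        # Every node of this graph is tagged with the graph's index.
        batch.append(torch.full((x.size(0),), graph_id, dtype=torch.long))
        offset += x.size(0)
    return torch.cat(xs), torch.cat(edge_indices, dim=1), torch.cat(batch)


g1 = (torch.randn(3, 8), to_edge_index(torch.tensor([0, 1]), torch.tensor([1, 2])))
g2 = (torch.randn(2, 8), to_edge_index(torch.tensor([0]), torch.tensor([1])))
x, edge_index, batch = batch_graphs([g1, g2])
print(edge_index.tolist())  # [[0, 1, 3], [1, 2, 4]]
print(batch.tolist())       # [0, 0, 0, 1, 1]
```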

Finally, embed the computational graph:

model = graph_encoder_pyg(feature_size, hidden_feature, out_feature)
embedding = model(graph)

Build your own GCN encoder

Define the graph encoder using PyG:

import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv, global_mean_pool


class graph_encoder_pyg(nn.Module):
    def __init__(self, in_feature, hidden_feature, out_feature):
        super().__init__()

        self.in_feature = in_feature
        self.hidden_feature = hidden_feature
        self.out_feature = out_feature

        # One GCN layer maps node features from in_feature to hidden_feature.
        self.conv1 = GCNConv(in_feature, hidden_feature)
        # Linear head maps the pooled graph vector to out_feature.
        self.linear1 = nn.Linear(hidden_feature, out_feature)
        self.tanh = nn.Tanh()
        self.relu = torch.relu

    def forward(self, graph):
        # batch assigns each node to its graph within a mini-batch.
        x, edge_index, batch = graph.x, graph.edge_index, graph.batch
        x = self.conv1(x, edge_index)         # node-level embeddings
        x = self.relu(x)
        # Average the node embeddings of each graph into one vector.
        embedding = global_mean_pool(x, batch)
        embedding = self.linear1(embedding)
        embedding = self.tanh(embedding)      # values in (-1, 1)

        return embedding
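
`global_mean_pool` averages node embeddings per graph, using `batch` to decide which nodes belong together. The pure-PyTorch sketch below computes the same per-graph mean with `index_add_`, to show why the encoder returns one row per graph. `mean_pool` is a hypothetical helper, not a PyG function.

```python
import torch


def mean_pool(x, batch, num_graphs):
    """Per-graph mean of node features (same result as global_mean_pool)."""
    # Sum the feature rows of each graph.
    sums = torch.zeros(num_graphs, x.size(1)).index_add_(0, batch, x)
    # Count the nodes in each graph.
    counts = torch.zeros(num_graphs).index_add_(0, batch, torch.ones(x.size(0)))
    return sums / counts.unsqueeze(1)


# Five nodes: the first three belong to graph 0, the last two to graph 1.
x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [10.0, 0.0], [20.0, 0.0]])
batch = torch.tensor([0, 0, 0, 1, 1])
print(mean_pool(x, batch, num_graphs=2).tolist())  # [[3.0, 4.0], [15.0, 0.0]]
```

After pooling, `linear1` and `tanh` map each row to `out_feature` values in (-1, 1), so every graph receives an embedding of the same size regardless of how many nodes it has.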