Supported packages

GNN-RL support popular deep graph neural network package, such as Torch-Geometric and DGL. In this section we will give examples to modeling DNN's topology to computational graph, and embedding them using graph neural network.

Create computational graph using Torch-geometric

GNN-RL also support Torch-Geometric packadge, and also provid easy-to-use function to create Torch-Geometric graph object. First, get the information of target DNN. The build-in function graph_construction.net_info(net_name) can automatically process the DNN and return the input and output channels of convolutional layers for constructing a graph.

from gnnrl.utils.load_networks import load_model
from gnnrl.graph_env import graph_construction
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = load_model('resnet20', '.')
net = net.to(device)
in_channels, out_channels, _ = graph_construction.net_info('resnet20')

Or you can write your own function to get the network's convolutional layers' input and output channels,

import torch.nn as nn
in_channels = []
out_channels=[]
for name,layer in net.named_modules():
    if isinstance(layer,nn.Conv2d):
        in_channels.append(layer.in_channels)
        out_channels.append(layer.out_channels)

Then, construct the graph by create an computational_graph_pyg object,

from gnnrl.graph_env.graph_construction import computational_graph_pyg
pyg_g = computational_graph_pyg(in_channels,out_channels,feature_size=10)

Plain Computational Graph

Import pyg backend method and convert DNN to simplified computational graph. To create a plain simplified computational graph, call plain_computational_graph(),

graph = pyg_g.plain_computational_graph()

Create computational graph using DGL

GNN-RL provide build in graph construction methods for ResNet-20/32/54, MobileNet-v1/v2 and VGG-16. First, get the information of target DNN. The build-in function graph_construction.net_info(net) can automatically process the DNN and return the input and output channels of DNN for constructing a graph.

from gnnrl.utils.load_networks import load_model
from gnnrl.graph_env import graph_construction
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = load_model('resnet20')
net.to(device)
in_channels, out_channels = graph_construction.net_info(net)

Then, construct the graph by create an computational_graph_dgl object,

from gnnrl.graph_env.graph_construction import computational_graph_dgl
dgl_g = computational_graph_dgl(in_channels,out_channels,feature_size=10)

Plain Computational Graph

Import DGL backend method and convert DNN to simplified computational graph. To create a plain simplified computational graph, call plain_computational_graph(),

graph = dgl_g.plain_computational_graph()