Source code for torchex.nn.modules.graph.conv

import torch
from torch.nn.parameter import Parameter

class SparseMM(torch.autograd.Function):
    """Sparse x dense matrix multiplication with autograd support.

    Implementation by Soumith Chintala:
    https://discuss.pytorch.org/t/does-pytorch-support-autograd-on-sparse-matrix/6156/7
    """

    def __init__(self, sparse):
        super(SparseMM, self).__init__()
        self.sparse = sparse

    def forward(self, dense):
        return torch.mm(self.sparse, dense)

    def backward(self, grad_output):
        grad_input = None
        if self.needs_input_grad[0]:
            grad_input = torch.mm(self.sparse.t(), grad_output)
        return grad_input
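
The SparseMM above uses PyTorch's legacy instance-style autograd Function API (state stored in __init__, non-static forward/backward), which newer PyTorch releases no longer accept. A minimal sketch of the same sparse-dense product written against the static-method Function API is shown below; the class name and structure are illustrative assumptions, not part of torchex, and recent PyTorch can also differentiate torch.sparse.mm through the dense operand directly.

import torch


class SparseDenseMM(torch.autograd.Function):
    """Hypothetical static-method rewrite of SparseMM for newer PyTorch."""

    @staticmethod
    def forward(ctx, sparse, dense):
        # Save the sparse operand so backward can reuse it.
        ctx.save_for_backward(sparse)
        return torch.mm(sparse, dense)

    @staticmethod
    def backward(ctx, grad_output):
        sparse, = ctx.saved_tensors
        grad_dense = None
        if ctx.needs_input_grad[1]:
            # d/d(dense) of (sparse @ dense) is sparse^T @ grad_output.
            grad_dense = torch.mm(sparse.t(), grad_output)
        # No gradient is propagated to the sparse adjacency matrix.
        return None, grad_dense


# Usage would then be SparseDenseMM.apply(adj, support) instead of
# SparseMM(adj)(support).
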
class GraphConv(torch.nn.Module):
    """Simple GCN layer, similar to https://arxiv.org/abs/1609.02907"""

    def __init__(self, in_features, out_features, bias=True):
        super(GraphConv, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)

    def forward(self, input, adj):
        support = torch.mm(input, self.weight)
        output = SparseMM(adj)(support)
        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.in_features) + ' -> ' \
            + str(self.out_features) + ')'
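
For reference, a minimal usage sketch of GraphConv follows. It is not part of the torchex source: the import path, shapes, adjacency matrix, and parameter initialization are assumptions, and it presumes a PyTorch version that still accepts the legacy Function call style used by SparseMM.

import torch
from torchex.nn.modules.graph.conv import GraphConv  # assumed import path

num_nodes, in_features, out_features = 4, 8, 16

layer = GraphConv(in_features, out_features)
# GraphConv allocates but does not initialize its parameters,
# so initialize them explicitly before use.
torch.nn.init.xavier_uniform_(layer.weight)
torch.nn.init.zeros_(layer.bias)

x = torch.randn(num_nodes, in_features)
# Sparse (normalized) adjacency matrix; the identity is used here for brevity.
indices = torch.arange(num_nodes).unsqueeze(0).repeat(2, 1)
values = torch.ones(num_nodes)
adj = torch.sparse_coo_tensor(indices, values, (num_nodes, num_nodes))

out = layer(x, adj)
print(out.shape)  # torch.Size([4, 16])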