import torch
from torch.nn.parameter import Parameter
class SparseMM(torch.autograd.Function):
    """
    Sparse x dense matrix multiplication with autograd support.

    Adapted from Soumith Chintala's recipe:
    https://discuss.pytorch.org/t/
    does-pytorch-support-autograd-on-sparse-matrix/6156/7

    The original recipe used the legacy, instance-based autograd Function
    API (state in ``__init__``, non-static ``forward``/``backward``).
    Current PyTorch raises a RuntimeError when such an instance is called.
    This version implements the static ``forward``/``backward`` API. It keeps
    the historical ``SparseMM(sparse)(dense)`` calling convention as a thin
    shim over ``SparseMM.apply(sparse, dense)``, so both forms work.

    Only the dense operand receives a gradient. The sparse operand is
    treated as a constant: its gradient is always ``None``.
    """

    def __init__(self, sparse):
        # Intentionally not calling super().__init__(): newer PyTorch emits a
        # DeprecationWarning when an autograd Function is instantiated. The
        # instance exists only to carry the sparse operand for the legacy
        # ``SparseMM(sparse)(dense)`` call style.
        self.sparse = sparse

    def __call__(self, dense):
        """Legacy entry point; equivalent to ``SparseMM.apply(self.sparse, dense)``."""
        return type(self).apply(self.sparse, dense)

    @staticmethod
    def forward(ctx, sparse, dense):
        """
        Compute ``torch.mm(sparse, dense)``.

        Args:
            ctx: autograd context supplied by ``Function.apply``.
            sparse: left operand (2-D; intended to be a sparse tensor).
            dense: right operand (2-D dense tensor).

        Returns:
            The 2-D product ``sparse @ dense``.
        """
        # The sparse operand is needed again in backward to form sparse^T.
        ctx.save_for_backward(sparse)
        return torch.mm(sparse, dense)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Return gradients for ``(sparse, dense)``.

        d(sparse @ dense)/d(dense) applied to ``grad_output`` is
        ``sparse^T @ grad_output``. No gradient is produced for ``sparse``.
        """
        sparse, = ctx.saved_tensors
        grad_dense = None
        # Index 1 is ``dense``; skip the matmul when no gradient is wanted.
        if ctx.needs_input_grad[1]:
            grad_dense = torch.mm(sparse.t(), grad_output)
        return None, grad_dense
class GraphConv(torch.nn.Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907

    Computes ``adj @ (input @ weight) + bias``. The multiplication by
    ``adj`` goes through ``SparseMM`` so that a sparse adjacency matrix can
    be used.

    Args:
        in_features: size of each input node feature vector.
        out_features: size of each output node feature vector.
        bias: if True, a learnable bias of shape ``(out_features,)`` is added
            to the output; otherwise the ``bias`` parameter is registered as
            ``None``.
    """

    def __init__(self, in_features, out_features, bias=True):
        super(GraphConv, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # torch.empty only allocates; reset_parameters() below gives the
        # values. Previously nothing initialized these tensors, so training
        # started from whatever happened to be in memory.
        self.weight = Parameter(torch.empty(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """
        Re-initialize weight (and bias) uniformly in [-stdv, stdv].

        ``stdv = 1 / sqrt(out_features)``, following the reference GCN
        implementation by Kipf & Welling.
        """
        stdv = 1.0 / (self.weight.size(1) ** 0.5)
        torch.nn.init.uniform_(self.weight, -stdv, stdv)
        if self.bias is not None:
            torch.nn.init.uniform_(self.bias, -stdv, stdv)

    def forward(self, input, adj):
        """
        Apply the graph convolution.

        Args:
            input: 2-D node feature matrix with ``in_features`` columns
                (``torch.mm`` requires 2-D operands).
            adj: 2-D matrix left-multiplied onto ``input @ weight``. It is
                presumably the (normalized) sparse adjacency matrix;
                normalization is assumed to be done by the caller
                (TODO confirm).

        Returns:
            ``adj @ (input @ weight)``, plus ``bias`` when one is registered.
        """
        support = torch.mm(input, self.weight)
        output = SparseMM(adj)(support)
        if self.bias is None:
            return output
        return output + self.bias

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.in_features) + ' -> ' \
            + str(self.out_features) + ')'