Source code for torchex.nn.modules.graph.batchnorm

import torch

class GraphBatchNorm(torch.nn.Module):
    """Graph Batch Normalization layer.

    Args:
        num_features (int): Number of channels of each atom feature
            vector, passed to :class:`torch.nn.BatchNorm1d`.

    .. seealso:: :class:`torch.nn.BatchNorm1d`
    """

    def __init__(self, num_features):
        super(GraphBatchNorm, self).__init__()
        self.bn = torch.nn.BatchNorm1d(num_features)

    # Defined as ``forward`` rather than ``__call__`` so that
    # ``torch.nn.Module.__call__`` can still run registered hooks.
    def forward(self, x):
        """Forward propagation.

        Args:
            x (:class:`torch.autograd.Variable`):
                Input array that should be a float array whose ``ndim``
                is 3. It represents a minibatch of molecules, each of
                which consists of a sequence of atoms. Each atom is
                represented by a feature vector. The first axis is an
                index of molecules (i.e. minibatch dimension), the
                second one an index of atoms, and the third one an
                index of feature channels.

        Returns:
            :class:`torch.autograd.Variable`: A 3-dimensional array
            with the same shape as the input.
        """
        # (minibatch, atom, ch)
        # The implementation of batch normalization for graph convolution
        # below is rather naive. To be precise, it is necessary to consider
        # the difference in the number of atoms for each graph. However, the
        # implementation below does not take it into account and assumes
        # that all graphs have the same number of atoms, hence the extra
        # zero-padded rows are included when the average is computed. In
        # other words, the result of batch normalization below is biased.
        s0, s1, s2 = x.size()
        # Flatten to (minibatch * atom, ch) so that BatchNorm1d normalizes
        # each channel over every atom row in the minibatch.
        x = x.view(s0 * s1, s2)
        x = self.bn(x)
        # Restore the original (minibatch, atom, ch) shape.
        x = x.view(s0, s1, s2)
        return x
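
The bias described in the comment above can be illustrated with plain arithmetic. The following is a minimal sketch that uses neither torchex nor PyTorch: it compares the per-channel mean taken over all rows, zero padding included (which is what flattening to (minibatch * atom, ch) amounts to), with the mean taken over real atoms only. The toy batch, the `num_atoms` list, and both helper functions are hypothetical and exist only for illustration; they are not part of the module.

```python
# Toy batch with shape (minibatch=2, atom=3, ch=1).
# Graph 0 has 3 real atoms; graph 1 has 1 real atom and 2 zero-padded rows.
batch = [
    [[2.0], [4.0], [6.0]],
    [[8.0], [0.0], [0.0]],
]
num_atoms = [3, 1]  # hypothetical per-graph atom counts (not used by the layer)


def padded_mean(batch, channel=0):
    """Mean over every row, padding included, as after view(s0 * s1, s2)."""
    rows = [atom for graph in batch for atom in graph]
    return sum(row[channel] for row in rows) / len(rows)


def masked_mean(batch, num_atoms, channel=0):
    """Mean over real atoms only, skipping the zero-padded rows."""
    rows = [atom for graph, n in zip(batch, num_atoms) for atom in graph[:n]]
    return sum(row[channel] for row in rows) / len(rows)


biased = padded_mean(batch)               # sum 20.0 divided by 6 rows
unbiased = masked_mean(batch, num_atoms)  # sum 20.0 divided by 4 real atoms
print(biased, unbiased)
```

In this toy case the padded rows pull the mean toward zero, and the variance estimate is distorted in the same way. Correcting it would require the per-graph atom counts (or an equivalent mask), which the layer above does not receive.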