Source code for torchex.nn.modules.highway

import torch
import torch.nn as nn
import torch.nn.functional as F


class Highway(nn.Module):
    r"""Highway module.

    In a highway network, two gates are added to the ordinary non-linear
    transformation :math:`H(x) = activate(W_h x + b_h)`. One gate is the
    transform gate :math:`T(x) = \sigma(W_t x + b_t)`, and the other is the
    carry gate :math:`C(x)`. For simplicity, the authors define
    :math:`C = 1 - T`. The module returns :math:`y` defined as

    .. math::

        y = activate(W_h x + b_h) \odot \sigma(W_t x + b_t)
            + x \odot (1 - \sigma(W_t x + b_t))

    The output has the same shape as the input. To satisfy this,
    :math:`W_h` and :math:`W_t` must be square matrices.

    Args:
        in_out_features (int): Size of the last dimension of both the
            input and the output.
        bias (bool): If ``True``, both linear layers (``plain`` and
            ``transform``) have a learnable bias.
        activate (callable): Activation applied to the plain path
            :math:`W_h x + b_h`. Defaults to
            :func:`torch.nn.functional.relu`; e.g. :func:`torch.tanh` can be
            used instead.
        transform_bias_init (float, optional): Keyword-only. If given, the
            transform-gate bias :math:`b_t` is initialized to this constant.
            The paper suggests a negative value (e.g. ``-1``) so that the
            layer initially favours carrying its input through. If ``None``
            (default), the default :class:`~torch.nn.Linear` initialization
            is kept. Requires ``bias=True``.

    Raises:
        ValueError: If ``transform_bias_init`` is given while ``bias`` is
            ``False``.

    See: `Highway Networks <https://arxiv.org/abs/1505.00387>`_.
    """

    def __init__(self, in_out_features, bias=True, activate=F.relu, *,
                 transform_bias_init=None):
        super().__init__()
        self.in_out_features = in_out_features
        self.bias = bias
        self.activate = activate
        # Both layers are square so the output keeps the input's shape and
        # can be mixed element-wise with ``x`` in ``forward``.
        self.plain = nn.Linear(in_out_features, in_out_features, bias=bias)
        self.transform = nn.Linear(in_out_features, in_out_features, bias=bias)
        if transform_bias_init is not None:
            if not bias:
                raise ValueError("transform_bias_init requires bias=True")
            # Bias the gate toward "carry" at the start of training.
            nn.init.constant_(self.transform.bias, transform_bias_init)

    def forward(self, x):
        """Compute the output of the Highway module.

        Args:
            x (~torch.Tensor): Input whose last dimension is
                ``in_out_features`` (the linear layers act on it).

        Returns:
            ~torch.Tensor: Output with the same shape as ``x``.
        """
        out_plain = self.activate(self.plain(x))
        out_transform = torch.sigmoid(self.transform(x))
        # Gated mix: transformed signal where T is high, raw input where the
        # carry gate (1 - T) is high.
        return out_plain * out_transform + x * (1 - out_transform)