import torch
import torch.nn as nn
import torch.nn.functional as F
class Highway(nn.Module):
    r"""Highway module.

    In a highway network, two gates are added to the ordinary non-linear
    transformation (:math:`H(x) = activate(W_h x + b_h)`).
    One gate is the transform gate :math:`T(x) = \sigma(W_t x + b_t)`, and the
    other is the carry gate :math:`C(x)`.
    For simplicity, the authors defined :math:`C = 1 - T`.
    The Highway module returns :math:`y` defined as

    .. math::

        y = activate(W_h x + b_h) \odot \sigma(W_t x + b_t) +
            x \odot (1 - \sigma(W_t x + b_t))

    The output has the same size as the input. To satisfy this,
    :math:`W_h` and :math:`W_t` must be square matrices, so both linear
    layers map ``in_out_features`` to ``in_out_features``.

    Args:
        in_out_features (int): Dimension of input and output vectors.
        bias (bool): If ``True``, both linear layers (``plain`` and
            ``transform``) learn an additive bias.
        activate (callable): Activation function applied to the output of
            the ``plain`` linear layer. Defaults to
            :func:`torch.nn.functional.relu`; :func:`torch.tanh` can also be
            used.

    See:
        `Highway Networks <https://arxiv.org/abs/1505.00387>`_.
    """

    def __init__(self, in_out_features, bias=True, activate=F.relu):
        super().__init__()
        self.in_out_features = in_out_features
        self.bias = bias
        self.activate = activate
        # W_h, b_h: the ordinary (non-gated) transformation H(x).
        self.plain = nn.Linear(in_out_features, in_out_features, bias=bias)
        # W_t, b_t: the transform gate T(x); the carry gate is 1 - T(x).
        self.transform = nn.Linear(in_out_features, in_out_features, bias=bias)

    def forward(self, x):
        """Compute the output of the Highway module.

        Args:
            x (~torch.Tensor): Input tensor whose last dimension has size
                ``in_out_features`` (``nn.Linear`` acts on the last dimension).

        Returns:
            ~torch.Tensor: Output tensor with the same shape as ``x``
            (assuming ``activate`` is an element-wise function).
        """
        out_plain = self.activate(self.plain(x))
        out_transform = torch.sigmoid(self.transform(x))
        # Blend the transformed signal with the unchanged input (carry path).
        return out_plain * out_transform + x * (1 - out_transform)