import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
class IndRNNCell(nn.Module):
r"""An IndRNN cell with tanh or ReLU non-linearity.
* https://arxiv.org/abs/1804.04849
.. math::
h' = \tanh(w_{ih} * x + b_{ih} + w_{hh} (*) h)
With (*) being element-wise vector multiplication.
If nonlinearity='relu', then ReLU is used in place of tanh.
Args:
input_size: The number of expected features in the input x
hidden_size: The number of features in the hidden state h
bias: If ``False``, then the layer does not use bias weights b_ih and b_hh.
Default: ``True``
nonlinearity: The non-linearity to use ['tanh'|'relu']. Default: 'relu'
hidden_min_abs: Minimal absolute inital value for hidden weights. Default: 0
hidden_max_abs: Maximal absolute inital value for hidden weights. Default: None
Inputs: input, hidden
- **input** (batch, input_size): tensor containing input features
- **hidden** (batch, hidden_size): tensor containing the initial hidden
state for each element in the batch.
Outputs: h'
- **h'** (batch, hidden_size): tensor containing the next hidden state
for each element in the batch
Attributes:
weight_ih: the learnable input-hidden weights, of shape
`(input_size x hidden_size)`
weight_hh: the learnable hidden-hidden weights, of shape
`(hidden_size)`
bias_ih: the learnable input-hidden bias, of shape `(hidden_size)`
Examples::
>>> rnn = nn.IndRNNCell(10, 20)
>>> input = Variable(torch.randn(6, 3, 10))
>>> hx = Variable(torch.randn(3, 20))
>>> output = []
>>> for i in range(6):
... hx = rnn(input[i], hx)
... output.append(hx)
"""
def __init__(self, input_size, hidden_size, bias=True, nonlinearity="relu",
hidden_min_abs=0, hidden_max_abs=None):
super(IndRNNCell, self).__init__()
self.hidden_max_abs = hidden_max_abs
self.hidden_min_abs = hidden_min_abs
self.input_size = input_size
self.hidden_size = hidden_size
self.bias = bias
self.nonlinearity = nonlinearity
self.weight_ih = Parameter(torch.Tensor(hidden_size, input_size))
self.weight_hh = Parameter(torch.Tensor(hidden_size))
if bias:
self.bias_ih = Parameter(torch.Tensor(hidden_size))
else:
self.register_parameter('bias_ih', None)
self.reset_parameters()
def reset_parameters(self):
stdv = 1.0 / math.sqrt(self.hidden_size)
for name, weight in self.named_parameters():
if "bias" in name:
weight.data.zero_()
elif "weight_hh" in name:
if self.hidden_max_abs:
stdv_ = self.hidden_max_abs
else:
stdv_ = stdv
weight.data.uniform_(-stdv_, stdv_)
elif "weight_ih" in name:
weight.data.normal_(0, 0.01)
else:
weight.data.normal_(0, 0.01)
# weight.data.uniform_(-stdv, stdv)
self.check_bounds()
    def check_bounds(self):
        # Clamp the recurrent weights so hidden_min_abs <= |w_hh| <= hidden_max_abs.
        if self.hidden_min_abs:
abs_kernel = torch.abs(self.weight_hh.data)
min_abs_kernel = torch.clamp(abs_kernel, min=self.hidden_min_abs)
self.weight_hh.data.copy_(
torch.mul(torch.sign(self.weight_hh.data), min_abs_kernel))
if self.hidden_max_abs:
self.weight_hh.data.copy_(
torch.clamp(self.weight_hh.data, max=self.hidden_max_abs,
min=-self.hidden_max_abs))
def forward(self, input, hx):
if self.nonlinearity == "tanh":
func = IndRNNTanhCell
elif self.nonlinearity == "relu":
func = IndRNNReLuCell
else:
raise RuntimeError(
"Unknown nonlinearity: {}".format(self.nonlinearity))
return func(input, hx, self.weight_ih, self.weight_hh, self.bias_ih)
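

# Illustrative sketch (not part of the original module): check_bounds() keeps every
# recurrent weight u_k in the band hidden_min_abs <= |u_k| <= hidden_max_abs; per the
# IndRNN paper, bounding |u_k| controls how gradients scale across T steps (~|u_k|**T).
# The cell sizes and bound values below are arbitrary assumptions for the demo.
def _demo_check_bounds():
    cell = IndRNNCell(10, 20, hidden_min_abs=0.1, hidden_max_abs=2.0)
    u = cell.weight_hh.data.abs()
    assert bool(((u >= 0.1) & (u <= 2.0)).all())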
def IndRNNTanhCell(input, hidden, w_ih, w_hh, b_ih=None):
    hy = torch.tanh(F.linear(input, w_ih, b_ih) + torch.mul(w_hh, hidden))
    return hy
def IndRNNReLuCell(input, hidden, w_ih, w_hh, b_ih=None):
    hy = F.relu(F.linear(input, w_ih, b_ih) + torch.mul(w_hh, hidden))
    return hy
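

# Minimal numerical sketch (an assumption for illustration, not original code):
# both functional cells compute h' = act(x @ W_ih^T + b_ih + u * h), where the
# recurrent weight u = w_hh multiplies the hidden state element-wise, giving each
# hidden unit an independent scalar recurrence instead of a full recurrent matrix.
def _demo_elementwise_recurrence():
    x = torch.randn(3, 10)      # (batch, input_size)
    h = torch.randn(3, 20)      # (batch, hidden_size)
    w_ih = torch.randn(20, 10)  # (hidden_size, input_size)
    w_hh = torch.randn(20)      # one recurrent weight per hidden unit
    b_ih = torch.zeros(20)
    manual = torch.relu(x.matmul(w_ih.t()) + b_ih + w_hh * h)
    assert torch.allclose(IndRNNReLuCell(x, h, w_ih, w_hh, b_ih), manual)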
class IndRNN(nn.Module):
r"""Applies a multi-layer IndRNN with `tanh` or `ReLU` non-linearity to an
input sequence.
For each element in the input sequence, each layer computes the following
function:
.. math::
h_t = \tanh(w_{ih} x_t + b_{ih} + w_{hh} (*) h_{(t-1)})
where :math:`h_t` is the hidden state at time `t`, and :math:`x_t` is
the hidden state of the previous layer at time `t` or :math:`input_t`
for the first layer. (*) is element-wise multiplication.
If :attr:`nonlinearity`='relu', then `ReLU` is used instead of `tanh`.
Args:
input_size: The number of expected features in the input `x`
hidden_size: The number of features in the hidden state `h`
num_layers: Number of recurrent layers.
nonlinearity: The non-linearity to use. Can be either 'tanh' or 'relu'. Default: 'tanh'
bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
Default: ``True``
batch_first: If ``True``, then the input and output tensors are provided
as `(batch, seq, feature)`
Inputs: input, h_0
- **input** of shape `(seq_len, batch, input_size)`: tensor containing the features
of the input sequence. The input can also be a packed variable length
sequence. See :func:`torch.nn.utils.rnn.pack_padded_sequence`
or :func:`torch.nn.utils.rnn.pack_sequence`
for details.
- **h_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
containing the initial hidden state for each element in the batch.
Defaults to zero if not provided.
Outputs: output, h_n
- **output** of shape `(seq_len, batch, hidden_size * num_directions)`: tensor
containing the output features (`h_k`) from the last layer of the RNN,
for each `k`. If a :class:`torch.nn.utils.rnn.PackedSequence` has
been given as the input, the output will also be a packed sequence.
- **h_n** (num_layers * num_directions, batch, hidden_size): tensor
containing the hidden state for `k = seq_len`.
Attributes:
cells[k]: individual IndRNNCells containing the weights
Examples::
>>> rnn = nn.IndRNN(10, 20, 2)
>>> input = torch.randn(5, 3, 10)
>>> h0 = torch.randn(2, 3, 20)
>>> output = rnn(input, h0)
"""
def __init__(self, input_size, hidden_size, n_layer=1, batch_norm=False,
step_size=None, **kwargs):
super(IndRNN, self).__init__()
self.hidden_size = hidden_size
        if batch_norm and step_size is None:
            raise ValueError("Frame-wise batch norm needs to know the step size")
self.batch_norm = batch_norm
self.step_size = step_size
self.n_layer = n_layer
cells = []
for i in range(n_layer):
if i == 0:
cells += [IndRNNCell(input_size, hidden_size, **kwargs)]
else:
cells += [IndRNNCell(hidden_size, hidden_size, **kwargs)]
self.cells = nn.ModuleList(cells)
        if batch_norm:
            # Frame-wise batch norm over the time axis: BatchNorm1d treats the
            # step (frame) dimension as channels, so it expects inputs of shape
            # (batch, step_size, hidden_size).
            bns = []
            for i in range(n_layer):
                bns += [nn.BatchNorm1d(step_size)]
            self.bns = nn.ModuleList(bns)
        self.register_buffer('h0', torch.zeros(hidden_size))
    def forward(self, x, hidden=None):
        # x: (batch, seq_len, input_size)
        # hidden: optional initial states of shape (n_layer, batch, hidden_size)
        for i, cell in enumerate(self.cells):
            cell.check_bounds()
            if hidden is None:
                hx = self.h0.unsqueeze(0).expand(
                    x.size(0), self.hidden_size).contiguous()
            else:
                hx = hidden[i]
            outputs = []
            for t in range(x.size(1)):
                x_t = x[:, t]
                hx = cell(x_t, hx)
                outputs += [hx]
            # the outputs of this layer become the inputs of the next
            x = torch.stack(outputs, 1)
            if self.batch_norm:
                x = self.bns[i](x)
        return x
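

if __name__ == "__main__":
    # Usage sketch; the shapes below are assumptions matching the batch-first
    # forward() above: a batch of 5 sequences, each 8 steps of 10 features.
    _demo_check_bounds()
    _demo_elementwise_recurrence()
    rnn = IndRNN(10, 20, n_layer=2)
    x = torch.randn(5, 8, 10)   # (batch, seq_len, input_size)
    h0 = torch.zeros(2, 5, 20)  # (n_layer, batch, hidden_size)
    print(rnn(x).shape)         # torch.Size([5, 8, 20])
    print(rnn(x, h0).shape)     # torch.Size([5, 8, 20])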