import itertools
import math
import numbers
import warnings
import torch
import torch.nn as nn
from torch.nn.utils.rnn import PackedSequence
from .base import LazyBase
class LazyRNNBase(LazyBase):
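    """Base class for recurrent modules whose weights are created lazily.

    Unlike ``torch.nn.RNNBase``, ``input_size`` is not passed to the
    constructor; it is inferred from the first input seen in ``forward``, at
    which point the per-layer weights are allocated, registered, flattened for
    cuDNN (when possible), and initialized.
    """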
def __init__(self, mode, hidden_size,
num_layers=1, bias=True, batch_first=False,
dropout=0, bidirectional=False):
super(LazyRNNBase, self).__init__()
self.mode = mode
self.hidden_size = hidden_size
self.num_layers = num_layers
self.bias = bias
self.batch_first = batch_first
self.dropout = dropout
self.dropout_state = {}
self.bidirectional = bidirectional
self.num_directions = 2 if bidirectional else 1
if not isinstance(dropout, numbers.Number) or not 0 <= dropout <= 1 or \
isinstance(dropout, bool):
raise ValueError("dropout should be a number in range [0, 1] "
"representing the probability of an element being "
"zeroed")
if dropout > 0 and num_layers == 1:
warnings.warn("dropout option adds dropout after all but last "
"recurrent layer, so non-zero dropout expects "
"num_layers greater than 1, but got dropout={} and "
"num_layers={}".format(dropout, num_layers))
if mode == 'LSTM':
self.gate_size = 4 * hidden_size
elif mode == 'GRU':
self.gate_size = 3 * hidden_size
else:
self.gate_size = hidden_size
self._all_weights = []
self._data_ptrs = None
def flatten_parameters(self):
"""Resets parameter data pointer so that they can use faster code paths.
Right now, this works only if the module is on the GPU and cuDNN is enabled.
Otherwise, it's a no-op.
"""
        # Parameters are created lazily on the first forward call; there is
        # nothing to flatten before they exist.
        if not self._all_weights:
            self._data_ptrs = []
            return
        any_param = next(self.parameters()).data
        if not any_param.is_cuda or not torch.backends.cudnn.is_acceptable(any_param):
            self._data_ptrs = []
            return
# If any parameters alias, we fall back to the slower, copying code path. This is
# a sufficient check, because overlapping parameter buffers that don't completely
# alias would break the assumptions of the uniqueness check in
# Module.named_parameters().
unique_data_ptrs = set(p.data_ptr() for l in self.all_weights for p in l)
if len(unique_data_ptrs) != sum(len(l) for l in self.all_weights):
self._data_ptrs = []
return
with torch.cuda.device_of(any_param):
import torch.backends.cudnn.rnn as rnn
weight_arr = list(itertools.chain.from_iterable(self.all_weights))
weight_stride0 = len(self.all_weights[0])
# NB: This is a temporary hack while we still don't have Tensor
# bindings for ATen functions
with torch.no_grad():
# NB: this is an INPLACE function on weight_arr, that's why the
# no_grad() is necessary.
weight_buf = torch._cudnn_rnn_flatten_weight(
weight_arr, weight_stride0,
self.input_size, rnn.get_cudnn_mode(self.mode), self.hidden_size, self.num_layers,
self.batch_first, bool(self.bidirectional))
self._param_buf_size = weight_buf.size(0)
self._data_ptrs = list(p.data.data_ptr() for p in self.parameters())
def _apply(self, fn):
        ret = super(LazyRNNBase, self)._apply(fn)
self.flatten_parameters()
return ret
def reset_parameters(self):
stdv = 1.0 / math.sqrt(self.hidden_size)
for name, p in self.named_parameters():
if 'bias' in name:
p.data.fill_(0)
if self.mode == "LSTM":
n = p.nelement()
p.data[n // 4:n // 2].fill_(1) # forget bias
else:
p.data.uniform_(-stdv, stdv)
def _initialize_weight(self, input_size):
for layer in range(self.num_layers):
for direction in range(self.num_directions):
layer_input_size = input_size if layer == 0 else self.hidden_size * self.num_directions
w_ih = torch.Tensor(self.gate_size, layer_input_size)
w_hh = torch.Tensor(self.gate_size, self.hidden_size)
b_ih = torch.Tensor(self.gate_size)
b_hh = torch.Tensor(self.gate_size)
if self.to_args:
w_ih = w_ih.to(*self.to_args, **self.to_kwargs)
w_hh = w_hh.to(*self.to_args, **self.to_kwargs)
b_ih = b_ih.to(*self.to_args, **self.to_kwargs)
b_hh = b_hh.to(*self.to_args, **self.to_kwargs)
w_ih = nn.Parameter(w_ih)
w_hh = nn.Parameter(w_hh)
b_ih = nn.Parameter(b_ih)
b_hh = nn.Parameter(b_hh)
layer_params = (w_ih, w_hh, b_ih, b_hh)
suffix = '_reverse' if direction == 1 else ''
param_names = ['weight_ih_l{}{}', 'weight_hh_l{}{}']
if self.bias:
param_names += ['bias_ih_l{}{}', 'bias_hh_l{}{}']
param_names = [x.format(layer, suffix) for x in param_names]
for name, param in zip(param_names, layer_params):
setattr(self, name, param)
self._all_weights.append(param_names)
self.flatten_parameters()
self.reset_parameters()
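    # For illustration, with num_layers=2 and bidirectional=True the loop in
    # _initialize_weight registers parameters named
    #   weight_ih_l0, weight_hh_l0, bias_ih_l0, bias_hh_l0,
    #   weight_ih_l0_reverse, ..., bias_hh_l0_reverse,
    #   weight_ih_l1, ..., bias_hh_l1_reverse
    # and records the same names, grouped per layer and direction, in
    # self._all_weights.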
def check_forward_args(self, input, hidden, batch_sizes):
is_input_packed = batch_sizes is not None
expected_input_dim = 2 if is_input_packed else 3
if input.dim() != expected_input_dim:
raise RuntimeError(
'input must have {} dimensions, got {}'.format(
expected_input_dim, input.dim()))
if self.input_size != input.size(-1):
raise RuntimeError(
'input.size(-1) must be equal to input_size. Expected {}, got {}'.format(
self.input_size, input.size(-1)))
if is_input_packed:
mini_batch = int(batch_sizes[0])
else:
mini_batch = input.size(0) if self.batch_first else input.size(1)
num_directions = 2 if self.bidirectional else 1
expected_hidden_size = (self.num_layers * num_directions,
mini_batch, self.hidden_size)
def check_hidden_size(hx, expected_hidden_size, msg='Expected hidden size {}, got {}'):
if tuple(hx.size()) != expected_hidden_size:
raise RuntimeError(msg.format(expected_hidden_size, tuple(hx.size())))
if self.mode == 'LSTM':
check_hidden_size(hidden[0], expected_hidden_size,
'Expected hidden[0] size {}, got {}')
check_hidden_size(hidden[1], expected_hidden_size,
'Expected hidden[1] size {}, got {}')
else:
check_hidden_size(hidden, expected_hidden_size)
def forward(self, input, hx=None):
is_packed = isinstance(input, PackedSequence)
        if is_packed:
            input, batch_sizes = input
            max_batch_size = int(batch_sizes[0])
        else:
            batch_sizes = None
            max_batch_size = input.size(0) if self.batch_first else input.size(1)
        if not self._all_weights:
            # Lazily create the weights the first time an input is seen,
            # inferring input_size from its last dimension (this also works for
            # the unpacked data of a PackedSequence, which is 2-D).
            self.input_size = input.size(-1)
            self._initialize_weight(self.input_size)
if hx is None:
num_directions = 2 if self.bidirectional else 1
hx = input.new_zeros(self.num_layers * num_directions,
max_batch_size, self.hidden_size,
requires_grad=False)
if self.mode == 'LSTM':
hx = (hx, hx)
has_flat_weights = list(p.data.data_ptr() for p in self.parameters()) == self._data_ptrs
if has_flat_weights:
first_data = next(self.parameters()).data
assert first_data.storage().size() == self._param_buf_size
flat_weight = first_data.new().set_(first_data.storage(), 0, torch.Size([self._param_buf_size]))
else:
flat_weight = None
self.check_forward_args(input, hx, batch_sizes)
func = self._backend.RNN(
self.mode,
self.input_size,
self.hidden_size,
num_layers=self.num_layers,
batch_first=self.batch_first,
dropout=self.dropout,
train=self.training,
bidirectional=self.bidirectional,
dropout_state=self.dropout_state,
variable_length=is_packed,
flat_weight=flat_weight
)
output, hidden = func(input, self.all_weights, hx, batch_sizes)
if is_packed:
output = PackedSequence(output, batch_sizes)
return output, hidden
def extra_repr(self):
s = '{input_size}, {hidden_size}'
if self.num_layers != 1:
s += ', num_layers={num_layers}'
if self.bias is not True:
s += ', bias={bias}'
if self.batch_first is not False:
s += ', batch_first={batch_first}'
if self.dropout != 0:
s += ', dropout={dropout}'
if self.bidirectional is not False:
s += ', bidirectional={bidirectional}'
return s.format(**self.__dict__)
def __setstate__(self, d):
        super(LazyRNNBase, self).__setstate__(d)
self.__dict__.setdefault('_data_ptrs', [])
        if 'all_weights' in d:
            self._all_weights = d['all_weights']
        # Nothing to rebuild if the weights were never created (a lazy module
        # that has not yet seen an input) or if they are already stored as
        # parameter names.
        if not self._all_weights or isinstance(self._all_weights[0][0], str):
            return
num_layers = self.num_layers
num_directions = 2 if self.bidirectional else 1
self._all_weights = []
for layer in range(num_layers):
for direction in range(num_directions):
suffix = '_reverse' if direction == 1 else ''
weights = ['weight_ih_l{}{}', 'weight_hh_l{}{}', 'bias_ih_l{}{}', 'bias_hh_l{}{}']
weights = [x.format(layer, suffix) for x in weights]
if self.bias:
self._all_weights += [weights]
else:
self._all_weights += [weights[:2]]
@property
def all_weights(self):
return [[getattr(self, weight) for weight in weights] for weights in self._all_weights]
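
# A minimal usage sketch (illustrative, not part of the module). Weights are
# created on the first forward call, so ``input_size`` never has to be given up
# front. This assumes ``LazyBase`` supplies the ``to_args``/``to_kwargs``
# bookkeeping used in ``_initialize_weight`` and that the legacy
# ``self._backend.RNN`` path used in ``forward`` is available:
#
#     rnn = LazyRNNBase('LSTM', hidden_size=32, num_layers=2, batch_first=True)
#     x = torch.randn(4, 10, 16)        # (batch, seq, features)
#     out, (h_n, c_n) = rnn(x)          # weights for input_size=16 built here
#     out.shape                         # torch.Size([4, 10, 32])
#     h_n.shape                         # torch.Size([2, 4, 32])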