Source code for torchex.nn.modules.lazy.rnn

import itertools
import math
import numbers
import warnings

import torch
import torch.nn as nn
from torch.nn.utils.rnn import PackedSequence

from .base import LazyBase
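
# NOTE: ``LazyBase`` (imported above) is assumed to record the arguments of a
# prior ``.to(...)`` call in ``self.to_args`` / ``self.to_kwargs`` so that
# lazily created parameters can be moved to the right device/dtype in
# ``_initialize_weight`` below.  A minimal, purely illustrative sketch of that
# assumed interface (the real implementation lives in ``.base``):
#
#     class LazyBase(nn.Module):
#         def __init__(self):
#             super(LazyBase, self).__init__()
#             self.to_args = None
#             self.to_kwargs = {}
#
#         def to(self, *args, **kwargs):
#             # remember the target device/dtype for parameters created later
#             self.to_args = args
#             self.to_kwargs = kwargs
#             return super(LazyBase, self).to(*args, **kwargs)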


class LazyRNNBase(LazyBase):
    """Base class for lazily-initialized RNN modules.

    ``input_size`` is not passed to the constructor; it is inferred from the
    last dimension of the first input seen by :meth:`forward`, at which point
    the weights are created.
    """

    def __init__(self, mode, hidden_size, num_layers=1, bias=True,
                 batch_first=False, dropout=0, bidirectional=False):
        super(LazyRNNBase, self).__init__()
        self.mode = mode
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = dropout
        self.dropout_state = {}
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1

        if not isinstance(dropout, numbers.Number) or not 0 <= dropout <= 1 or \
                isinstance(dropout, bool):
            raise ValueError("dropout should be a number in range [0, 1] "
                             "representing the probability of an element being "
                             "zeroed")
        if dropout > 0 and num_layers == 1:
            warnings.warn("dropout option adds dropout after all but last "
                          "recurrent layer, so non-zero dropout expects "
                          "num_layers greater than 1, but got dropout={} and "
                          "num_layers={}".format(dropout, num_layers))

        if mode == 'LSTM':
            self.gate_size = 4 * hidden_size
        elif mode == 'GRU':
            self.gate_size = 3 * hidden_size
        else:
            self.gate_size = hidden_size

        self._all_weights = []
        self._data_ptrs = None

    def flatten_parameters(self):
        """Resets parameter data pointer so that they can use faster code paths.

        Right now, this works only if the module is on the GPU and cuDNN is
        enabled. Otherwise, it's a no-op.
        """
        any_param = next(self.parameters()).data
        if not any_param.is_cuda or not torch.backends.cudnn.is_acceptable(any_param):
            self._data_ptrs = []
            return

        # If any parameters alias, we fall back to the slower, copying code path.
        # This is a sufficient check, because overlapping parameter buffers that
        # don't completely alias would break the assumptions of the uniqueness
        # check in Module.named_parameters().
        unique_data_ptrs = set(p.data_ptr() for l in self.all_weights for p in l)
        if len(unique_data_ptrs) != sum(len(l) for l in self.all_weights):
            self._data_ptrs = []
            return

        with torch.cuda.device_of(any_param):
            import torch.backends.cudnn.rnn as rnn

            weight_arr = list(itertools.chain.from_iterable(self.all_weights))
            weight_stride0 = len(self.all_weights[0])

            # NB: This is a temporary hack while we still don't have Tensor
            # bindings for ATen functions
            with torch.no_grad():
                # NB: this is an INPLACE function on weight_arr, that's why the
                # no_grad() is necessary.
                weight_buf = torch._cudnn_rnn_flatten_weight(
                    weight_arr, weight_stride0,
                    self.input_size, rnn.get_cudnn_mode(self.mode), self.hidden_size,
                    self.num_layers, self.batch_first, bool(self.bidirectional))

        self._param_buf_size = weight_buf.size(0)
        self._data_ptrs = list(p.data.data_ptr() for p in self.parameters())

    def _apply(self, fn):
        ret = super(LazyRNNBase, self)._apply(fn)
        self.flatten_parameters()
        return ret

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for name, p in self.named_parameters():
            if 'bias' in name:
                p.data.fill_(0)
                if self.mode == "LSTM":
                    n = p.nelement()
                    p.data[n // 4:n // 2].fill_(1)  # forget bias
            else:
                p.data.uniform_(-stdv, stdv)

    def _initialize_weight(self, input_size):
        # Called from the first forward pass, once the input size is known.
        for layer in range(self.num_layers):
            for direction in range(self.num_directions):
                layer_input_size = input_size if layer == 0 \
                    else self.hidden_size * self.num_directions
                w_ih = torch.Tensor(self.gate_size, layer_input_size)
                w_hh = torch.Tensor(self.gate_size, self.hidden_size)
                b_ih = torch.Tensor(self.gate_size)
                b_hh = torch.Tensor(self.gate_size)
                if self.to_args:
                    # Honor any earlier .to(...) call recorded by LazyBase.
                    w_ih = w_ih.to(*self.to_args, **self.to_kwargs)
                    w_hh = w_hh.to(*self.to_args, **self.to_kwargs)
                    b_ih = b_ih.to(*self.to_args, **self.to_kwargs)
                    b_hh = b_hh.to(*self.to_args, **self.to_kwargs)
                w_ih = nn.Parameter(w_ih)
                w_hh = nn.Parameter(w_hh)
                b_ih = nn.Parameter(b_ih)
                b_hh = nn.Parameter(b_hh)
                layer_params = (w_ih, w_hh, b_ih, b_hh)

                suffix = '_reverse' if direction == 1 else ''
                param_names = ['weight_ih_l{}{}', 'weight_hh_l{}{}']
                if self.bias:
                    param_names += ['bias_ih_l{}{}', 'bias_hh_l{}{}']
                param_names = [x.format(layer, suffix) for x in param_names]

                for name, param in zip(param_names, layer_params):
                    setattr(self, name, param)
                self._all_weights.append(param_names)

        self.flatten_parameters()
        self.reset_parameters()

    def check_forward_args(self, input, hidden, batch_sizes):
        is_input_packed = batch_sizes is not None
        expected_input_dim = 2 if is_input_packed else 3
        if input.dim() != expected_input_dim:
            raise RuntimeError(
                'input must have {} dimensions, got {}'.format(
                    expected_input_dim, input.dim()))
        if self.input_size != input.size(-1):
            raise RuntimeError(
                'input.size(-1) must be equal to input_size. Expected {}, got {}'.format(
                    self.input_size, input.size(-1)))

        if is_input_packed:
            mini_batch = int(batch_sizes[0])
        else:
            mini_batch = input.size(0) if self.batch_first else input.size(1)

        num_directions = 2 if self.bidirectional else 1
        expected_hidden_size = (self.num_layers * num_directions,
                                mini_batch, self.hidden_size)

        def check_hidden_size(hx, expected_hidden_size,
                              msg='Expected hidden size {}, got {}'):
            if tuple(hx.size()) != expected_hidden_size:
                raise RuntimeError(msg.format(expected_hidden_size, tuple(hx.size())))

        if self.mode == 'LSTM':
            check_hidden_size(hidden[0], expected_hidden_size,
                              'Expected hidden[0] size {}, got {}')
            check_hidden_size(hidden[1], expected_hidden_size,
                              'Expected hidden[1] size {}, got {}')
        else:
            check_hidden_size(hidden, expected_hidden_size)

    def forward(self, input, hx=None):
        is_packed = isinstance(input, PackedSequence)

        if not self._all_weights:
            # Lazy initialization: infer input_size from the first input.
            self.input_size = input.size(-1)
            self._initialize_weight(self.input_size)

        if is_packed:
            input, batch_sizes = input
            max_batch_size = int(batch_sizes[0])
        else:
            batch_sizes = None
            max_batch_size = input.size(0) if self.batch_first else input.size(1)

        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            hx = input.new_zeros(self.num_layers * num_directions,
                                 max_batch_size, self.hidden_size,
                                 requires_grad=False)
            if self.mode == 'LSTM':
                hx = (hx, hx)

        has_flat_weights = \
            list(p.data.data_ptr() for p in self.parameters()) == self._data_ptrs
        if has_flat_weights:
            first_data = next(self.parameters()).data
            assert first_data.storage().size() == self._param_buf_size
            flat_weight = first_data.new().set_(
                first_data.storage(), 0, torch.Size([self._param_buf_size]))
        else:
            flat_weight = None

        self.check_forward_args(input, hx, batch_sizes)
        func = self._backend.RNN(
            self.mode,
            self.input_size,
            self.hidden_size,
            num_layers=self.num_layers,
            batch_first=self.batch_first,
            dropout=self.dropout,
            train=self.training,
            bidirectional=self.bidirectional,
            dropout_state=self.dropout_state,
            variable_length=is_packed,
            flat_weight=flat_weight
        )
        output, hidden = func(input, self.all_weights, hx, batch_sizes)
        if is_packed:
            output = PackedSequence(output, batch_sizes)
        return output, hidden

    def extra_repr(self):
        s = '{input_size}, {hidden_size}'
        if self.num_layers != 1:
            s += ', num_layers={num_layers}'
        if self.bias is not True:
            s += ', bias={bias}'
        if self.batch_first is not False:
            s += ', batch_first={batch_first}'
        if self.dropout != 0:
            s += ', dropout={dropout}'
        if self.bidirectional is not False:
            s += ', bidirectional={bidirectional}'
        return s.format(**self.__dict__)

    def __setstate__(self, d):
        super(LazyRNNBase, self).__setstate__(d)
        self.__dict__.setdefault('_data_ptrs', [])
        if 'all_weights' in d:
            self._all_weights = d['all_weights']
        if isinstance(self._all_weights[0][0], str):
            return
        num_layers = self.num_layers
        num_directions = 2 if self.bidirectional else 1
        self._all_weights = []
        for layer in range(num_layers):
            for direction in range(num_directions):
                suffix = '_reverse' if direction == 1 else ''
                weights = ['weight_ih_l{}{}', 'weight_hh_l{}{}',
                           'bias_ih_l{}{}', 'bias_hh_l{}{}']
                weights = [x.format(layer, suffix) for x in weights]
                if self.bias:
                    self._all_weights += [weights]
                else:
                    self._all_weights += [weights[:2]]

    @property
    def all_weights(self):
        return [[getattr(self, weight) for weight in weights]
                for weights in self._all_weights]
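

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): because
# ``input_size`` is inferred from the first batch, no input dimension is passed
# at construction time.  The ``LazyLSTM`` subclass below is hypothetical and
# assumes a PyTorch version that still provides ``Module._backend``:
#
#     class LazyLSTM(LazyRNNBase):
#         def __init__(self, hidden_size, **kwargs):
#             super(LazyLSTM, self).__init__('LSTM', hidden_size, **kwargs)
#
#     rnn = LazyLSTM(hidden_size=32, num_layers=2, batch_first=True)
#     x = torch.randn(8, 15, 10)       # (batch, seq_len, features)
#     output, (h_n, c_n) = rnn(x)      # weights are created on this first call
#     assert output.shape == (8, 15, 32)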