Source code for torchex.nn.modules.rnn.basic

from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as TF
from torch import Tensor
from torch.nn.utils.rnn import PackedSequence

from ...init import feedforward_init, rnn_init


class SequenceLinear(nn.Module):
    '''Applies a linear layer to every time step of a sequence.

    x[T, B, F] -> nn.Linear over the flattened x[TxB, F] -> x[T, B, out_features].
    If batch_first is True, the input x[B, T, F] is transposed to x[T, B, F] first
    and the output is transposed back to x[B, T, out_features].
    '''

    def __init__(self,
                 in_features,
                 out_features,
                 batch_first=True,
                 nonlinearity='sigmoid',
                 init_mean=0,
                 init_std=1,
                 init_xavier: bool=True,
                 init_normal: bool=True,
                 init_gain=None,
                 dropout=0.0,
                 bias=True,
                 ):
        super(SequenceLinear, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias)
        self.batch_first = batch_first
        self.out_features = out_features
        self.nonlinearity = nonlinearity
        if not init_gain:
            init_gain = nn.init.calculate_gain(nonlinearity)
        feedforward_init(self.linear,
                         init_mean,
                         init_std,
                         init_xavier,
                         init_normal,
                         init_gain)

    def forward(self, x: Tensor):
        assert len(x.size()) == 3, x.size()
        if self.batch_first:
            B, T, F = x.size()
            x = torch.transpose(x, 0, 1)
        else:
            T, B, F = x.size()
        # Flatten the time and batch dimensions so a single nn.Linear call
        # covers every time step, then restore the sequence layout.
        x = x.contiguous()
        x = x.view(T * B, F)
        x = self.linear(x)
        x = getattr(TF, self.nonlinearity)(x)
        x = x.view(T, B, self.out_features)
        if self.batch_first:
            x = torch.transpose(x, 0, 1)
        return x
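
# A minimal usage sketch for SequenceLinear (not part of the original module;
# the helper name and shapes below are illustrative assumptions). It projects
# a padded, batch-first batch of sequences and checks the output shape.
def _sequence_linear_demo():
    x = torch.randn(4, 10, 16)  # (batch, time, features), batch_first=True
    layer = SequenceLinear(16, 32, batch_first=True, nonlinearity='sigmoid')
    y = layer(x)
    assert y.size() == (4, 10, 32)  # same batch/time layout, new feature size
    return y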
class LSTM(nn.Module):
    """Basic LSTM model."""

    def __init__(self,
                 input_size,
                 hidden_size,
                 num_layers=1,
                 bidirectional=False,
                 dropout=0.0,
                 batch_first=True,
                 init_xavier: bool=True,
                 init_normal: bool=True,
                 init_gain: float=1.0,
                 concat: bool=True,
                 init_mean=0,
                 init_std=0.1,
                 init_lower=0,
                 init_upper=0.04,
                 concat_between_layers=False,
                 ):
        super(LSTM, self).__init__()
        self.in_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.concat = concat
        self.concat_between_layers = concat_between_layers

        if concat_between_layers and bidirectional:
            # Stack single-layer bidirectional LSTMs explicitly so each layer
            # receives the concatenated (2 * hidden_size) output of the previous one.
            _rnn = []
            _rnn.append(nn.LSTM(input_size, hidden_size,
                                num_layers=1,
                                bidirectional=bidirectional,
                                dropout=dropout,
                                batch_first=batch_first))
            for i in range(num_layers - 1):
                _rnn.append(nn.LSTM(2 * hidden_size, hidden_size,
                                    num_layers=1,
                                    bidirectional=bidirectional,
                                    dropout=dropout,
                                    batch_first=batch_first))
            for irnn in _rnn:
                rnn_init(irnn,
                         init_xavier=init_xavier,
                         init_normal=init_normal,
                         init_gain=init_gain,
                         init_mean=init_mean,
                         init_std=init_std)
            self.rnn = nn.ModuleList(_rnn)
        else:
            self.rnn = nn.LSTM(input_size, hidden_size, num_layers,
                               bidirectional=bidirectional,
                               dropout=dropout,
                               batch_first=batch_first,
                               )
            rnn_init(self.rnn,
                     init_xavier=init_xavier,
                     init_normal=init_normal,
                     init_gain=init_gain,
                     init_mean=init_mean,
                     init_std=init_std)

    def forward(self,
                x: Union[Tensor, PackedSequence],
                hx: Optional[Union[Tensor, Tuple[Tensor, ...]]] = None
                ) -> Tuple[Union[Tensor, PackedSequence], Tuple[Tensor, Tensor]]:
        assert isinstance(x, Tensor) or isinstance(x, PackedSequence), type(x)

        if self.concat_between_layers and self.bidirectional:
            # Run each single-layer LSTM in turn; the initial state hx is not
            # used on this path, and the states of the last layer are returned.
            output = x
            for idx, _rnn in enumerate(self.rnn):
                _rnn.flatten_parameters()
                output, hx = _rnn(output)
        else:
            self.rnn.flatten_parameters()
            output, hx = self.rnn(x, hx)

        if (not self.concat) and self.bidirectional:
            # Sum the forward and backward halves instead of concatenating them.
            # This assumes a padded (Tensor) output, not a PackedSequence.
            B, T, F = output.size()
            output = output[:, :, :F // 2] + output[:, :, F // 2:]
        return output, hx
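
# A minimal usage sketch for LSTM (not part of the original module; the helper
# name and shapes are illustrative assumptions). With concat=False the forward
# and backward directions are summed, so the output keeps hidden_size features.
def _lstm_demo():
    x = torch.randn(4, 10, 16)  # (batch, time, features), batch_first=True
    rnn = LSTM(16, 32, num_layers=2, bidirectional=True, concat=False)
    output, (h_n, c_n) = rnn(x)
    assert output.size() == (4, 10, 32)   # directions summed, not concatenated
    assert h_n.size() == (4, 4, 32)       # (num_layers * num_directions, batch, hidden)
    return output, h_n, c_n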