# Source code for torchex.nn.modules.lazy.conv

import math

import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F

from .utils import _single, _pair, _triple
from .base import LazyBase
from ...init import feedforward_init

class _ConvNd(LazyBase):
    '''Shared base for the lazily-initialized convolution modules.

    ``weight`` (and ``bias`` when enabled) start out as empty parameters.
    ``_initialize_weight`` gives them their real shapes once the number of
    input channels is known. ``_lazy_load_state_dict_hook`` resizes them so
    that a state dict can be loaded into a module that has not run yet.

    :param in_channels: number of input channels, or None when it is to be
        inferred later
    :param out_channels: number of output channels; must be divisible by
        ``groups``
    :param kernel_size: kernel size, one entry per spatial dimension
    :param stride: convolution stride (tuple)
    :param padding: implicit zero padding (tuple)
    :param dilation: kernel dilation (tuple)
    :param transposed: if True the weight is laid out as
        ``(in_channels, out_channels // groups, *kernel_size)``, otherwise as
        ``(out_channels, in_channels // groups, *kernel_size)``
    :param output_padding: output padding (tuple); only stored here
    :param groups: number of channel groups
    :param bias: if falsy, no bias parameter is created (``self.bias`` is None)
    :raises ValueError: if ``out_channels`` is not divisible by ``groups``
    '''

    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 padding, dilation, transposed, output_padding, groups, bias):
        super(_ConvNd, self).__init__()
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size

        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = transposed
        self.output_padding = output_padding
        self.groups = groups

        # Empty placeholders; the real shapes are assigned in
        # _initialize_weight or when a state dict is loaded.
        self.weight = nn.Parameter(None)
        if bias:
            self.bias = nn.Parameter(None)
        else:
            # Register the bias as None (as torch.nn.Conv*d does) so that
            # F.conv*d skips it instead of receiving an empty tensor.
            self.register_parameter('bias', None)
        self.bias_flag = bias
        self._register_load_state_dict_pre_hook(self._lazy_load_state_dict_hook)

    def _initialize_weight(self, in_channels, xavier_init: bool):
        '''Allocate and initialize the parameters for ``in_channels`` inputs.

        :param in_channels: number of input channels seen at run time
        :param xavier_init: if True, ``feedforward_init`` is applied to the
            module after the default reset
        :raises ValueError: if ``in_channels`` is not divisible by ``groups``
        '''
        self.in_channels = in_channels
        if self.in_channels % self.groups != 0:
            raise ValueError('in_channels must be divisible by groups')

        if self.transposed:
            shape = (self.in_channels,
                     self.out_channels // self.groups,
                     *self.kernel_size)
        else:
            shape = (self.out_channels,
                     self.in_channels // self.groups,
                     *self.kernel_size)
        self.weight.data = torch.empty(shape)
        if self.bias is not None:
            self.bias.data = torch.empty(self.out_channels)

        self.weight = self._to_device(self.weight)
        if self.bias is not None:
            self.bias = self._to_device(self.bias)
        self._reset_parameters()

        if xavier_init:
            feedforward_init(self)

    def _reset_parameters(self):
        '''Re-initialize the parameters in place.

        The weight gets Kaiming-uniform init with ``a=sqrt(5)``. The bias,
        if present, is drawn uniformly from ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]``.
        '''
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            # Guard against a degenerate zero-size weight (division by zero).
            if fan_in > 0:
                bound = 1 / math.sqrt(fan_in)
                init.uniform_(self.bias, -bound, bound)

    def _lazy_load_state_dict_hook(self, state_dict, prefix, local_metadata, strict,
                                   missing_keys, unexpected_keys, error_msgs):
        '''Pre-hook for ``load_state_dict``: resize parameters to the incoming shapes.

        Only keys that start with ``prefix`` belong to this module. The
        parameters are reallocated (same dtype/device) with the checkpoint's
        shapes, so PyTorch's loader passes its shape check and copies the
        values in. The module does not alias the state-dict tensors.
        ``in_channels`` is recovered from the weight layout.

        :raises ValueError: for a key under ``prefix`` other than
            ``weight``/``bias``
        '''
        for key, data in state_dict.items():
            if not key.startswith(prefix):
                continue
            local_name = key[len(prefix):]
            if local_name == 'weight':
                if self.transposed:
                    # (in_channels, out_channels // groups, *kernel_size)
                    self.in_channels = data.shape[0]
                else:
                    # (out_channels, in_channels // groups, *kernel_size)
                    self.in_channels = data.shape[1] * self.groups
                self.weight.data = self.weight.new_empty(data.shape)
            elif local_name == 'bias':
                # With bias disabled, leave the key for PyTorch to report.
                if self.bias is not None:
                    self.bias.data = self.bias.new_empty(data.shape)
            else:
                raise ValueError(key)

    def extra_repr(self):
        '''Summarize the configuration; non-default options only.'''
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0,) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        return s.format(**self.__dict__)


class Conv1d(_ConvNd):
    '''1D convolution whose number of input channels is inferred lazily.

    When ``kernel_size`` is omitted, the first two positional arguments are
    read as ``(out_channels, kernel_size)``. ``in_channels`` is then taken
    from ``x.size(1)`` on the first forward pass.

    :param in_channels: number of input channels, or ``out_channels`` when
        ``kernel_size`` is omitted
    :param out_channels: number of output channels, or ``kernel_size`` when
        ``kernel_size`` is omitted
    :param kernel_size: size of the convolving kernel
    :param stride: stride of the convolution
    :param padding: implicit zero padding added to both sides of the input
    :param dilation: spacing between kernel elements
    :param groups: number of blocked connections from input to output channels
    :param bias: if True, a learnable bias is added to the output
    :param xavier_init: if True, ``feedforward_init`` is applied when the
        weights are first created
    :type kernel_size: int or list
    :type stride: int or list

    Examples::

        import torch
        import torchex.nn as exnn

        net = exnn.Conv1d(10, 2)
        # You don't need to give the size of input for this module.
        # This network is equivalent to `nn.Conv1d(3, 10, 2)`.
        x = torch.randn(10, 3, 28)
        y = net(x)
    '''

    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int or list = None, stride: int or list = 1,
                 padding: int = 0, dilation: int = 1, groups: int = 1,
                 bias: bool = True, xavier_init: bool = True):
        if kernel_size is None:
            # Lazy form Conv1d(out_channels, kernel_size): shift the
            # positional arguments and defer in_channels to forward().
            in_channels, out_channels, kernel_size = None, in_channels, out_channels
        self.xavier_init = xavier_init
        super(Conv1d, self).__init__(
            in_channels, out_channels, _single(kernel_size), _single(stride),
            _single(padding), _single(dilation), False, _single(0), groups, bias)

    def forward(self, x):
        '''Apply the convolution, creating the weights on the first call.

        ``x.size(1)`` is taken as the channel dimension.
        '''
        if len(self.weight) == 0:
            self._initialize_weight(x.size(1), self.xavier_init)
        return F.conv1d(x, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
class Conv2d(_ConvNd):
    '''2D convolution whose number of input channels is inferred lazily.

    When ``kernel_size`` is omitted, the first two positional arguments are
    read as ``(out_channels, kernel_size)``. ``in_channels`` is then taken
    from ``x.size(1)`` on the first forward pass.

    :param in_channels: number of input channels, or ``out_channels`` when
        ``kernel_size`` is omitted
    :param out_channels: number of output channels, or ``kernel_size`` when
        ``kernel_size`` is omitted
    :param kernel_size: size of the convolving kernel
    :param stride: stride of the convolution
    :param padding: implicit zero padding added to both sides of the input
    :param dilation: spacing between kernel elements
    :param groups: number of blocked connections from input to output channels
    :param bias: if True, a learnable bias is added to the output
    :param xavier_init: if True, ``feedforward_init`` is applied when the
        weights are first created
    :type kernel_size: int or list
    :type stride: int or list

    Examples::

        import torch
        import torchex.nn as exnn

        net = exnn.Conv2d(10, 2)
        # You don't need to give the size of input for this module.
        # This network is equivalent to `nn.Conv2d(3, 10, 2)`.
        x = torch.randn(10, 3, 28, 28)
        y = net(x)
    '''

    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int or list = None, stride: int or list = 1,
                 padding: int = 0, dilation: int = 1, groups: int = 1,
                 bias: bool = True, xavier_init: bool = True):
        if kernel_size is None:
            # Lazy form Conv2d(out_channels, kernel_size): shift the
            # positional arguments and defer in_channels to forward().
            in_channels, out_channels, kernel_size = None, in_channels, out_channels
        self.xavier_init = xavier_init
        super(Conv2d, self).__init__(
            in_channels, out_channels, _pair(kernel_size), _pair(stride),
            _pair(padding), _pair(dilation), False, _pair(0), groups, bias)

    def forward(self, x):
        '''Apply the convolution, creating the weights on the first call.

        ``x.size(1)`` is taken as the channel dimension.
        '''
        if len(self.weight) == 0:
            self._initialize_weight(x.size(1), self.xavier_init)
        return F.conv2d(x, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
class Conv3d(_ConvNd):
    '''3D convolution whose number of input channels is inferred lazily.

    When ``kernel_size`` is omitted, the first two positional arguments are
    read as ``(out_channels, kernel_size)``. ``in_channels`` is then taken
    from ``x.size(1)`` on the first forward pass.

    :param in_channels: number of input channels, or ``out_channels`` when
        ``kernel_size`` is omitted
    :param out_channels: number of output channels, or ``kernel_size`` when
        ``kernel_size`` is omitted
    :param kernel_size: size of the convolving kernel
    :param stride: stride of the convolution
    :param padding: implicit zero padding added to both sides of the input
    :param dilation: spacing between kernel elements
    :param groups: number of blocked connections from input to output channels
    :param bias: if True, a learnable bias is added to the output
    :param xavier_init: if True, ``feedforward_init`` is applied when the
        weights are first created
    :type kernel_size: int or list
    :type stride: int or list

    Examples::

        import torch
        import torchex.nn as exnn

        net = exnn.Conv3d(10, 2)
        # You don't need to give the size of input for this module.
        # This network is equivalent to `nn.Conv3d(3, 10, 2)`.
        x = torch.randn(10, 3, 100, 28, 28)
        y = net(x)
    '''

    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int or list = None, stride: int or list = 1,
                 padding: int = 0, dilation: int = 1, groups: int = 1,
                 bias: bool = True, xavier_init: bool = True):
        if kernel_size is None:
            # Lazy form Conv3d(out_channels, kernel_size): shift the
            # positional arguments and defer in_channels to forward().
            in_channels, out_channels, kernel_size = None, in_channels, out_channels
        self.xavier_init = xavier_init
        super(Conv3d, self).__init__(
            in_channels, out_channels, _triple(kernel_size), _triple(stride),
            _triple(padding), _triple(dilation), False, _triple(0), groups, bias)

    def forward(self, x):
        '''Apply the convolution, creating the weights on the first call.

        ``x.size(1)`` is taken as the channel dimension.
        '''
        if len(self.weight) == 0:
            self._initialize_weight(x.size(1), self.xavier_init)
        return F.conv3d(x, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)