Source code for torchex.nn.modules.pooling

import torch.nn as nn
import torch.nn.functional as F

class GlobalAvgPool1d(nn.Module):
    '''
    Average pooling whose window spans the entire last dimension.

    The input must be a 3-dimensional tensor; the pooling window is set to
    the size of its last dimension, so that dimension is reduced to 1.
    '''

    def __init__(self):
        super(GlobalAvgPool1d, self).__init__()

    def forward(self, x):
        '''
        :param x: 3-dimensional input tensor
        :return: ``x`` average-pooled over its full last dimension
        :raises AssertionError: if ``x`` is not 3-dimensional
        '''
        shape = x.size()
        assert len(shape) == 3, shape
        # The kernel covers the whole last axis.
        length = shape[2]
        return F.avg_pool1d(x, length)
class GlobalAvgPool2d(nn.Module):
    '''
    Average pooling whose window spans the last two dimensions entirely.

    The input must be a 4-dimensional tensor; the pooling window is set to
    the sizes of its last two dimensions, so each of them is reduced to 1.
    '''

    def __init__(self):
        super(GlobalAvgPool2d, self).__init__()

    def forward(self, x):
        '''
        :param x: 4-dimensional input tensor
        :return: ``x`` average-pooled over its full last two dimensions
        :raises AssertionError: if ``x`` is not 4-dimensional
        '''
        shape = x.size()
        assert len(shape) == 4, shape
        # The kernel covers the whole of the two trailing axes.
        window = (shape[2], shape[3])
        return F.avg_pool2d(x, window)
class GlobalMaxPool1d(nn.Module):
    '''
    Max pooling whose window spans the entire last dimension.

    The input must be a 3-dimensional tensor; the pooling window is set to
    the size of its last dimension, so that dimension is reduced to 1.
    '''

    def __init__(self):
        super(GlobalMaxPool1d, self).__init__()

    def forward(self, x):
        '''
        :param x: 3-dimensional input tensor
        :return: ``x`` max-pooled over its full last dimension
        :raises AssertionError: if ``x`` is not 3-dimensional
        '''
        assert len(x.size()) == 3, x.size()
        length = x.size(2)
        # Must be max pooling: the previous implementation called
        # F.avg_pool1d here and therefore returned the mean.
        return F.max_pool1d(x, length)
class GlobalMaxPool2d(nn.Module):
    '''
    Max pooling whose window spans the last two dimensions entirely.

    The input must be a 4-dimensional tensor; the pooling window is set to
    the sizes of its last two dimensions, so each of them is reduced to 1.
    '''

    def __init__(self):
        super(GlobalMaxPool2d, self).__init__()

    def forward(self, x):
        '''
        :param x: 4-dimensional input tensor
        :return: ``x`` max-pooled over its full last two dimensions
        :raises AssertionError: if ``x`` is not 4-dimensional
        '''
        shape = x.size()
        assert len(shape) == 4, shape
        # The kernel covers the whole of the two trailing axes.
        window = (shape[2], shape[3])
        return F.max_pool2d(x, window)
class MaxAvgPool2d(nn.Module):
    '''
    Sum of a 2D max pooling and a 2D average pooling of the same input.

    Both pooling layers share ``kernel_size``, ``stride``, ``padding`` and
    ``ceil_mode``. ``dilation`` and ``return_indices`` are passed to the max
    pooling layer only; ``count_include_pad`` is passed to the average
    pooling layer only.

    NOTE: because the average pooling layer does not receive ``dilation``,
    a ``dilation`` other than 1 generally gives the two branches different
    output sizes, in which case they cannot be summed.

    :param kernel_size: the size of the window to take a max and average over
    :param stride: the size of stride to move kernel
    :param padding: implicit padding to be added on both sides
    :param dilation: a parameter that controls the stride of elements in the
        window (max pooling only)
    :param return_indices: if True, will return the max indices along with
        the outputs. Useful when Unpooling later
    :param ceil_mode: when True, will use ceil instead of floor to compute
        the output shape
    :param count_include_pad: forwarded to the average pooling layer
    :type kernel_size: int or list
    :type stride: int or list
    '''

    def __init__(self, kernel_size, stride=None, padding=0, dilation=1,
                 return_indices=False, ceil_mode=False,
                 count_include_pad=True):
        super(MaxAvgPool2d, self).__init__()
        # Arguments understood by both pooling layers.
        shared = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            ceil_mode=ceil_mode)
        self.max_pool = nn.MaxPool2d(
            dilation=dilation, return_indices=return_indices, **shared)
        self.avg_pool = nn.AvgPool2d(
            count_include_pad=count_include_pad, **shared)

    def forward(self, x):
        '''
        :param x: input tensor accepted by both pooling layers
        :return: the sum of the max-pooled and average-pooled outputs; when
            the max pooling layer returns indices, a tuple
            ``(sum, indices)``
        '''
        ax = self.avg_pool(x)
        if self.max_pool.return_indices:
            # MaxPool2d yields (output, indices) in this mode; adding that
            # tuple to a tensor directly does not produce the pooled sum.
            mx, indices = self.max_pool(x)
            return mx + ax, indices
        return self.max_pool(x) + ax