Source code for torchex.nn.modules.padding

import torch
import torch.nn as nn

class PeriodicPad2d(nn.Module):
    """Pad the last two dimensions of a 4-D tensor with periodic values.

    The input is treated as one period of a signal that repeats along its
    height (dim 2) and width (dim 3); the padding is filled with the values
    found by wrapping around to the opposite edge.

    Adapted from
    https://github.com/ZichaoLong/aTEAM/blob/master/nn/functional/utils.py

    :param int pad_left: number of columns added before the first column
    :param int pad_right: number of columns added after the last column
    :param int pad_top: number of rows added before the first row
    :param int pad_bottom: number of rows added after the last row
    :raises ValueError: if any pad width is negative
    """

    def __init__(self, pad_left: int = 0, pad_right: int = 0,
                 pad_top: int = 0, pad_bottom: int = 0):
        super(PeriodicPad2d, self).__init__()
        widths = (pad_left, pad_right, pad_top, pad_bottom)
        if any(width < 0 for width in widths):
            raise ValueError(
                'pad widths must be non-negative, got {}'.format(widths))
        self.pad_left = pad_left
        self.pad_right = pad_right
        self.pad_top = pad_top
        self.pad_bottom = pad_bottom

    @staticmethod
    def _wrap_dim(tensor: torch.Tensor, dim: int,
                  before: int, after: int) -> torch.Tensor:
        """Return ``tensor`` periodically extended along ``dim``."""
        size = tensor.size(dim)
        if size == 0:
            # An empty dimension has nothing to wrap around; keep it empty.
            return tensor.clone()
        # Positions -before .. size+after-1 folded back into [0, size).
        # The modulo keeps the result correct even when a pad width is
        # larger than the dimension (the signal wraps more than once).
        index = torch.arange(-before, size + after,
                             device=tensor.device) % size
        return tensor.index_select(dim, index)

    def forward(self, input):
        """Apply the periodic padding.

        :param torch.Tensor input: Input(B, C, H, W)
        :return: tensor of shape
            (B, C, pad_top + H + pad_bottom, pad_left + W + pad_right)
        """
        assert input.dim() == 4, 'only supports Input(B, C, H, W)'
        # Width first, then height, so the corner regions are filled with
        # values that are wrapped along both axes.
        input = self._wrap_dim(input, 3, self.pad_left, self.pad_right)
        input = self._wrap_dim(input, 2, self.pad_top, self.pad_bottom)
        return input
class PeriodicPad3d(nn.Module):
    """Pad the last three dimensions of a 5-D tensor with periodic values.

    Only isotropic padding is supported: the same width ``pad`` is added
    on both sides of the depth (dim 2), height (dim 3) and width (dim 4)
    axes, filled with the values found by wrapping around to the opposite
    edge.

    :param int pad: number of elements added on each side of every padded axis
    :raises ValueError: if ``pad`` is negative
    """

    def __init__(self, pad: int = 0):
        super(PeriodicPad3d, self).__init__()
        if pad < 0:
            raise ValueError(
                'pad must be non-negative, got {}'.format(pad))
        self.pad = pad

    @staticmethod
    def _wrap_dim(tensor: torch.Tensor, dim: int,
                  before: int, after: int) -> torch.Tensor:
        """Return ``tensor`` periodically extended along ``dim``."""
        size = tensor.size(dim)
        if size == 0:
            # An empty dimension has nothing to wrap around; keep it empty.
            return tensor.clone()
        # Positions -before .. size+after-1 folded back into [0, size).
        # The modulo keeps the result correct even when the pad width is
        # larger than the dimension (the signal wraps more than once).
        index = torch.arange(-before, size + after,
                             device=tensor.device) % size
        return tensor.index_select(dim, index)

    def forward(self, input):
        """Apply the periodic padding.

        :param torch.Tensor input: Input(B, C, D, H, W)
        :return: tensor of shape (B, C, D + 2*pad, H + 2*pad, W + 2*pad)
        """
        assert input.dim() == 5, 'only support Input(B, C, D, H, W)'
        # Extend one spatial axis at a time; later axes see the already
        # padded tensor, which fills edges and corners consistently.
        for dim in (2, 3, 4):
            input = self._wrap_dim(input, dim, self.pad, self.pad)
        return input
if __name__ == '__main__':
    # Manual smoke test: print a small tensor before and after padding.
    # torch.arange with float endpoints replaces the deprecated torch.range
    # (whose upper bound was inclusive), producing the same 1..N values.
    x = torch.arange(1., 26.).view(1, 1, 5, 5)
    print(x)
    pad = PeriodicPad2d(2, 2, 2, 1)
    print(pad(x))
    print(pad(x).shape)

    x = torch.arange(1., 28.).view(1, 1, 3, 3, 3)
    pad = PeriodicPad3d(1)
    print(pad(x))