# Source code for torchex.nn.modules.util

import collections.abc
from itertools import repeat

import torch.nn as nn

def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))
    return parse

# Ready-made parsers for the arities 1 through 4.
_single, _pair, _triple, _quadruple = (
    _ntuple(arity) for arity in (1, 2, 3, 4)
)


class Flatten(nn.Module):
    """Collapse every dimension after the first into a single dimension.

    The size of the first dimension (``input.size(0)``) is kept; all
    remaining dimensions are merged into one.
    """

    def forward(self, input):
        """Return ``input`` reshaped to ``(input.size(0), -1)``.

        Args:
            input: Tensor to flatten.

        Returns:
            A two-dimensional tensor whose first dimension equals
            ``input.size(0)``.
        """
        # contiguous() first: view() needs a compatible memory layout and
        # would fail on e.g. a permuted tensor.
        return input.contiguous().view(input.size(0), -1)


class Pass(nn.Module):
    """Identity module: nothing to do, the input is returned unchanged."""

    def forward(self, input):
        """Return ``input`` as is."""
        return input