# Source code for torchex.nn.modules.lazy.linear

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from .base import LazyBase
from ...init import feedforward_init

class Linear(LazyBase):
    '''Lazily-initialized linear layer.

    ``in_features`` is inferred from the last dimension of the first
    input seen in :meth:`forward`, so callers only need to give the
    output size.

    Examples::

        import torch
        import torchex.nn as exnn

        net = exnn.Linear(10)
        # You don't need to give the size of input for this module.
        # This network is equivalent to `nn.Linear(100, 10)`.
        x = torch.randn(10, 100)
        y = net(x)
    '''

    def __init__(self, in_features, out_features=None, use_bias=True,
                 xavier_init=True):
        super(Linear, self).__init__()
        # Single-argument form: the argument is the *output* size and
        # the input size is deferred until the first forward pass.
        if out_features is None:
            self.in_features, self.out_features = None, in_features
        else:
            self.in_features = in_features
            self.out_features = out_features
        self.use_bias = use_bias
        self.xavier_init = xavier_init
        # Empty parameters act as "not yet materialized" sentinels;
        # forward() detects them via len(...) == 0.
        self.weight = nn.Parameter(None)
        self.bias = nn.Parameter(None)
        self._register_load_state_dict_pre_hook(self._lazy_load_state_dict_hook)

    def _lazy_load_state_dict_hook(self, state_dict, prefix, local_metadata,
                                   strict, missing_keys, unexpected_keys,
                                   error_msgs):
        # Pre-hook so a checkpoint can materialize the lazy parameters
        # (and recover in_features) before the normal load runs.
        for name, data in state_dict.items():
            if prefix in name:
                if 'weight' in name:
                    self.in_features = data.shape[-1]
                    self.weight.data = data
                elif 'bias' in name:
                    self.bias.data = data
                else:
                    raise ValueError(name)

    def forward(self, x):
        if len(self.weight.data) == 0:
            # First call: materialize parameters from the input width.
            self.in_features = x.shape[-1]
            self.weight.data = torch.Tensor(self.out_features, self.in_features)
            stdv = 1. / math.sqrt(self.weight.size(1))
            self.weight.data.uniform_(-stdv, stdv)
            if self.use_bias:
                self.bias.data = torch.Tensor(self.out_features)
                self.bias.data.uniform_(-stdv, stdv)
            # NOTE(review): _to_device comes from LazyBase (not visible
            # here); presumably it moves the fresh parameters to the
            # module's device — confirm against base class.
            self.weight = self._to_device(self.weight)
            self.bias = self._to_device(self.bias)
            if self.xavier_init:
                feedforward_init(self)
        # BUG FIX: the original always passed self.bias; with
        # use_bias=False that is an empty tensor and F.linear fails.
        # Pass None instead, which F.linear treats as "no bias".
        bias = self.bias if self.use_bias else None
        return F.linear(x, self.weight, bias)

    def extra_repr(self):
        # BUG FIX: `self.bias is not None` was always True because bias
        # is always a Parameter object; report the configured flag.
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.use_bias
        )
if __name__ == "__main__":
    # Smoke test for the lazy Linear layer.
    # ROBUSTNESS FIX: the original hard-coded 'cuda' and crashed on
    # machines without a GPU; fall back to CPU when CUDA is unavailable.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x = torch.randn(10, 2).to(device)
    net = Linear(10)
    net = net.to(device)
    y = net(x)
    print(y)