import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from functools import reduce
import operator

# Small constant keeping square-root arguments strictly positive.
eps = 1e-8

# Constants of the KL approximation used by ``get_reg`` in both layers below.
_KL_K1 = 0.6134
_KL_K2 = 0.2026
_KL_K3 = 0.7126


def _neg_kl_elementwise(log_alp):
    """Element-wise term whose negated reduction ``get_reg`` returns as the KL penalty.

    Computes ``-0.5 * log(1 + 1 / alpha) + k1 * exp(-(k2 + k3 * log_alpha) ** 2)``
    with ``alpha = exp(log_alp)``. ``log(1 + 1 / alpha)`` is evaluated as
    ``softplus(-log_alp)``: the same quantity, but it neither overflows for very
    negative log alpha nor yields NaN gradients for very positive log alpha.
    """
    return (-0.5 * F.softplus(-log_alp)
            + _KL_K1 * torch.exp(-(_KL_K2 + _KL_K3 * log_alp) ** 2))


def _kl_reg(log_alp):
    """Average the element-wise term over all dims except dim 0, sum over dim 0, negate."""
    return -_neg_kl_elementwise(log_alp).flatten(1).mean(dim=1).sum()


class LinearVDO(nn.Module):
    """
    Dense layer implementation with weights ARD-prior (arxiv:1701.05369).

    Every weight has its own learnable ``log_alp`` (log alpha). The forward pass
    samples activations with variance ``input^2 @ (alpha * W^2)^T``. Weights whose
    log alpha reaches ``thresh`` are reported as dropped by ``get_clip_mask``.
    """

    def __init__(self, in_features, out_features, bias=True, thresh=3, ard_init=-8.):
        super(LinearVDO, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.empty(out_features, in_features))
        self.thresh = thresh
        if bias:
            self.bias = Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)
        self.ard_init = ard_init
        self.log_alp = Parameter(torch.empty(out_features, in_features), requires_grad=True)
        self.reset_parameters()

    def forward(self, input):
        """
        Stochastic forward pass (Bayesian mode) via the local reparameterization trick.

        Returns ``mu + noise * sqrt(var)`` with ``mu = input @ W^T``,
        ``var = input^2 @ (exp(log_alp) * W^2)^T + eps`` and standard-normal noise,
        plus the bias when the layer has one. Noise is sampled on every call,
        including in eval mode.
        """
        mu = input.matmul(self.weight.t())
        weight_var = torch.exp(self.log_alp) * self.weight * self.weight
        # eps is added after the matmul so var > 0 even for all-zero input rows.
        # Adding it inside the matmul let var be exactly 0, and sqrt's backward then
        # produced NaN gradients.
        var = (input * input).matmul(weight_var.t()) + eps
        activation = mu + torch.randn_like(mu) * torch.sqrt(var)
        if self.bias is not None:
            activation = activation + self.bias
        return activation

    @property
    def weights_clipped(self):
        """Weights with every entry selected by ``get_clip_mask`` set to zero."""
        return self.weight.masked_fill(self.get_clip_mask(), 0.)

    def reset_parameters(self):
        """Re-initialise weight ~ N(0, 0.01^2), bias = 0 and log alpha = ``ard_init``."""
        with torch.no_grad():
            self.weight.normal_(std=0.01)
            if self.bias is not None:
                self.bias.zero_()
            self.log_alp.fill_(self.ard_init)

    @staticmethod
    def clip(tensor, to=10.):
        """
        Shrink all tensor's values to range [-to, to]
        """
        return torch.clamp(tensor, -to, to)

    @staticmethod
    def clip_alp(tensor, lwrb=20.):
        """
        Shrink all tensor's values to range [-lwrb, -eps] (strictly negative).
        """
        return torch.clamp(tensor, -lwrb, -eps)

    def get_clip_mask(self):
        """Boolean mask of the weights whose log alpha is >= ``thresh``."""
        # Compare the raw log alpha. Clamping with clip_alp first (as before) capped
        # it at -eps, so the mask could never be true for any thresh >= 0.
        return self.log_alp >= self.thresh

    def train(self, mode=True):
        """Set training mode like ``nn.Module.train`` (default True) and return ``self``."""
        return super(LinearVDO, self).train(mode)

    def get_reg(self, **kwargs):
        """
        Get weights regularization (KL(q(w)||p(w)) approximation).

        The element-wise term is averaged over input features, summed over output
        features and negated. ``kwargs`` are accepted and ignored.
        """
        return _kl_reg(self.log_alp)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )

    def get_dropped_params_cnt(self):
        """
        Get number of dropped weights (with log alpha >= "thresh" parameter).

        :returns 0-d numpy array holding the number of dropped weights
        """
        return self.get_clip_mask().sum().cpu().numpy()

    @property
    def log_alpha(self):
        """Per-weight log alpha (alias of ``log_alp``).

        The previous version derived this from ``log_sigma2``, an attribute this
        class never defines, so every access failed.
        """
        return self.log_alp


class Conv2dVDO(nn.Conv2d):
    """2-D convolution with the same per-weight log-alpha parameterization as LinearVDO."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1,
                 ard_init=-1., thresh=3, weight_prob_fwd=True):
        # Bias is always disabled (original author's note: "Goes to nan if bias = True").
        super(Conv2dVDO, self).__init__(in_channels, out_channels, kernel_size,
                                        stride=stride, padding=padding,
                                        dilation=dilation, groups=groups, bias=False)
        self.thresh = thresh
        self.ard_init = ard_init
        self.log_alp = Parameter(torch.full_like(self.weight, ard_init), requires_grad=True)
        # When True, forward samples activations even in eval mode. When False,
        # eval uses the deterministic clipped weights.
        self.weight_prob_fwd = weight_prob_fwd

    @staticmethod
    def clip(tensor, to=8):
        """
        Shrink all tensor's values to range [-to, to]
        """
        return torch.clamp(tensor, -to, to)

    @staticmethod
    def clip_alp(tensor, lwrb=10.):
        """
        Shrink all tensor's values to range [-lwrb, -eps] (strictly negative).
        """
        return torch.clamp(tensor, -lwrb, -eps)

    def set_weight_prob_fwd(self, weight_prob_fwd):
        """Toggle stochastic forward in eval mode.

        :raises TypeError: if ``weight_prob_fwd`` is not a bool
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not isinstance(weight_prob_fwd, bool):
            raise TypeError("weight_prob_fwd must be a bool")
        self.weight_prob_fwd = weight_prob_fwd

    def forward(self, input):
        """
        Forward pass.

        In eval mode with ``weight_prob_fwd`` False, this is a plain convolution with
        ``weights_clipped``. Otherwise it samples activations as
        ``conv(x, W) + noise * sqrt(conv(x^2, exp(log_alp) * W^2) + eps)``.
        """
        if not self.training and not self.weight_prob_fwd:
            return F.conv2d(input, self.weights_clipped, self.bias, self.stride,
                            self.padding, self.dilation, self.groups)
        W = self.weight
        conved_mu = F.conv2d(input, W, self.bias, self.stride,
                             self.padding, self.dilation, self.groups)
        # The variance convolution never carries a bias term.
        conved_var = F.conv2d(input * input, torch.exp(self.log_alp) * W * W, None,
                              self.stride, self.padding, self.dilation, self.groups)
        conved_si = torch.sqrt(conved_var + eps)
        return conved_mu + conved_si * torch.randn_like(conved_mu)

    @property
    def weights_clipped(self):
        """Weights with every entry selected by ``get_clip_mask`` set to zero."""
        return self.weight.masked_fill(self.get_clip_mask(), 0.)

    def get_clip_mask(self):
        """Boolean mask of the weights whose log alpha is >= ``thresh``."""
        # Raw comparison. clip_alp caps values at -eps, so using it made the mask
        # permanently False for thresh >= 0.
        return self.log_alp >= self.thresh

    def train(self, mode=True):
        """Set training mode like ``nn.Module.train`` (default True) and return ``self``."""
        return super(Conv2dVDO, self).train(mode)

    def get_reg(self, **kwargs):
        """
        Get weights regularization (KL(q(w)||p(w)) approximation).

        The element-wise term is averaged over (in_channels, kh, kw), summed over
        output channels and negated. ``kwargs`` are accepted and ignored.
        """
        return _kl_reg(self.log_alp)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_channels, self.out_channels, self.bias is not None
        )

    def get_dropped_params_cnt(self):
        """
        Get number of dropped weights (log alpha >= "thresh" parameter).

        :returns 0-d numpy array holding the number of dropped weights
        """
        return self.get_clip_mask().sum().cpu().numpy()

    @property
    def log_alpha(self):
        """Per-weight log alpha (alias of ``log_alp``).

        The previous version derived this from ``log_sigma2``, an attribute this
        class never defines, so every access failed.
        """
        return self.log_alp