import math
import numbers
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from utils.utils import initialize_weights
from architecture.linear_vdo import LinearVDO, Conv2dVDO
import numpy as np
from torch.distributions import kl
EPS_1 = 1e-16
# EPS_2 = 1e-28
"""
Attention Network without Gating (2 fc layers)
args:
L: input feature dimension
D: hidden layer dimension
dropout: whether to use dropout (p = 0.25)
    n_classes: number of attention scores per instance (output dimension)
"""
class Attn_Net(nn.Module):
def __init__(self, L=1024, D=256, dropout=False, n_classes=1):
super(Attn_Net, self).__init__()
self.module = [
nn.Linear(L, D),
nn.Tanh()]
if dropout:
self.module.append(nn.Dropout(0.25))
self.module.append(nn.Linear(D, n_classes))
self.module = nn.Sequential(*self.module)
def forward(self, x):
return self.module(x), x # N x n_classes
"""
Attention Network with Sigmoid Gating (3 fc layers, implemented with
variational-dropout LinearVDO layers)
args:
    L: input feature dimension
    D: hidden layer dimension
    dropout: whether to use dropout (p = 0.25)
    n_classes: number of attention scores per instance (output dimension)
"""
class Attn_Net_Gated(nn.Module):
def __init__(self, L=1024, D=256, dropout=False, n_classes=1):
super(Attn_Net_Gated, self).__init__()
ard_init = -1.
self.attention_a = [
LinearVDO(L, D, ard_init=ard_init),
nn.Tanh()]
self.attention_b = [LinearVDO(L, D, ard_init=ard_init),
nn.Sigmoid()]
if dropout:
self.attention_a.append(nn.Dropout(0.25))
self.attention_b.append(nn.Dropout(0.25))
self.attention_a = nn.Sequential(*self.attention_a)
self.attention_b = nn.Sequential(*self.attention_b)
self.attention_c = LinearVDO(D, n_classes, ard_init=ard_init)
def forward(self, x):
a = self.attention_a(x)
b = self.attention_b(x)
A = a.mul(b)
A = self.attention_c(A) # N x n_classes
return A, x
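# Deterministic variant of the gated attention network: same structure as
# Attn_Net_Gated, but built from plain nn.Linear layers instead of LinearVDO.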
class DAttn_Net_Gated(nn.Module):
def __init__(self, L=1024, D=256, dropout=False, n_classes=1):
super(DAttn_Net_Gated, self).__init__()
self.attention_a = [
nn.Linear(L, D),
nn.Tanh()]
self.attention_b = [nn.Linear(L, D),
nn.Sigmoid()]
if dropout:
self.attention_a.append(nn.Dropout(0.25))
self.attention_b.append(nn.Dropout(0.25))
self.attention_a = nn.Sequential(*self.attention_a)
self.attention_b = nn.Sequential(*self.attention_b)
self.attention_c = nn.Linear(D, n_classes)
def forward(self, x):
a = self.attention_a(x)
b = self.attention_b(x)
A = a.mul(b)
A = self.attention_c(A) # N x n_classes
# print(x.shape)
return A, x
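# A small shape sketch for the attention heads above (illustrative values, not
# part of the original pipeline):
#
#     attn = Attn_Net_Gated(L=512, D=256, dropout=True, n_classes=2)
#     feats = torch.randn(100, 512)       # N instances x L features
#     scores, feats_out = attn(feats)     # scores: N x n_classes, feats_out: N x L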
class GaussianSmoothing(nn.Module):
"""
    Apply Gaussian smoothing on a 1d, 2d or 3d tensor. Filtering is performed
    separately for each channel in the input using a depthwise convolution.
Arguments:
channels (int, sequence): Number of channels of the input tensors. Output will
have this number of channels as well.
kernel_size (int, sequence): Size of the gaussian kernel.
sigma (float, sequence): Standard deviation of the gaussian kernel.
dim (int, optional): The number of dimensions of the data.
Default value is 2 (spatial).
"""
def __init__(self, channels, kernel_size, sigma, dim=2):
super(GaussianSmoothing, self).__init__()
if isinstance(kernel_size, numbers.Number):
kernel_size = [kernel_size] * dim
if isinstance(sigma, numbers.Number):
sigma = [sigma] * dim
# The gaussian kernel is the product of the
# gaussian function of each dimension.
kernel = 1
meshgrids = torch.meshgrid(
[
torch.arange(size, dtype=torch.float32)
for size in kernel_size
]
)
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
mean = (size - 1) / 2
kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
torch.exp(-((mgrid - mean) / std) ** 2 / 2)
# Make sure sum of values in gaussian kernel equals 1.
kernel = kernel / torch.sum(kernel)
# Reshape to depthwise convolutional weight
kernel = kernel.view(1, 1, *kernel.size())
kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
self.register_buffer('weight', kernel)
self.groups = channels
if dim == 1:
self.conv = F.conv1d
elif dim == 2:
self.conv = F.conv2d
elif dim == 3:
self.conv = F.conv3d
else:
raise RuntimeError(
'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
)
def forward(self, input):
"""
Apply gaussian filter to input.
Arguments:
input (torch.Tensor): Input to apply gaussian filter on.
Returns:
filtered (torch.Tensor): Filtered output.
"""
# return self.conv(input, weight=self.weight, groups=self.groups, dilation=2)
return self.conv(input, weight=self.weight, groups=self.groups)
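# Usage sketch for GaussianSmoothing (assumed shapes; the filter applies no
# padding, so pad the input first if the spatial size must be preserved):
#
#     smoothing = GaussianSmoothing(channels=1, kernel_size=3, sigma=0.5, dim=2)
#     x = torch.rand(1, 1, 64, 64)                          # N x C x H x W
#     x = F.pad(x, (1, 1, 1, 1), mode='constant', value=0)
#     y = smoothing(x)                                       # 1 x 1 x 64 x 64 again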
class probabilistic_MIL_Bayes_vis(nn.Module):
def __init__(self, gate=True, size_arg="small", dropout=False, n_classes=2, top_k=1):
super(probabilistic_MIL_Bayes_vis, self).__init__()
self.size_dict = {"small": [1024, 512, 256], "big": [1024, 512, 384]}
size = self.size_dict[size_arg]
fc = [nn.Linear(size[0], size[1]), nn.ReLU()]
if dropout:
fc.append(nn.Dropout(0.25))
if gate:
attention_net = Attn_Net_Gated(L=size[1], D=size[2], dropout=dropout, n_classes=2)
else:
attention_net = Attn_Net(L=size[1], D=size[2], dropout=dropout, n_classes=1)
fc.append(attention_net)
self.attention_net = nn.Sequential(*fc)
self.classifiers = LinearVDO(size[1], n_classes, ard_init=-3.)
self.n_classes = n_classes
self.print_sample_trigger = False
self.num_samples = 16
self.temperature = torch.tensor([1.0])
self.fixed_b = torch.tensor([5.], requires_grad=False)
initialize_weights(self)
self.top_k = top_k
def reparameterize(self, mu, logvar):
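        # Reparameterization trick: draw z = mu + std * eps with eps ~ N(0, I),
        # keeping the sample differentiable w.r.t. mu and logvar.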
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def relocate(self):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.attention_net = self.attention_net.to(device)
self.classifiers = self.classifiers.to(device)
self.temperature = self.temperature.to(device)
def forward(self, h, validation=False):
device = h.device
# *-*# A, h = self.attention_net(h) # NxK
A, h = self.attention_net(h)
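        # The gated attention head is configured with n_classes=2, so each instance
        # gets a mean (column 0) and log-variance (column 1) of a Gaussian over its
        # attention logit.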
mu = A[:, 0]
logvar = A[:, 1]
gaus_samples = self.reparameterize(mu, logvar)
        beta_samples = torch.sigmoid(gaus_samples)
A = beta_samples.unsqueeze(0)
# print('gaus max: {0:.4f}, gaus min: {1:.4f}.'.format(torch.max(gaus_samples), torch.min(gaus_samples)))
# print('sample max: {0:.4f}, sample min: {1:.4f}.'.format(torch.max(A), torch.min(A)))
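        # Attention-weighted bag representation, normalized by the total attention mass.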
M = torch.mm(A, h) / A.sum()
logits = self.classifiers(M)
y_probs = F.softmax(logits, dim=1)
top_instance_idx = torch.topk(y_probs[:, 1], self.top_k, dim=0)[1].view(1, )
top_instance = torch.index_select(logits, dim=0, index=top_instance_idx)
Y_hat = torch.topk(top_instance, 1, dim=1)[1]
Y_prob = F.softmax(top_instance, dim=1)
# results_dict = {}
# if return_features:
# top_features = torch.index_select(h, dim=0, index=top_instance_idx)
# results_dict.update({'features': top_features})
return top_instance, Y_prob, Y_hat, y_probs, A
class probabilistic_MIL_Bayes_enc(nn.Module):
def __init__(self, gate=True, size_arg="small", dropout=False, n_classes=2, top_k=1):
super(probabilistic_MIL_Bayes_enc, self).__init__()
self.size_dict = {"small": [1024, 512, 256], "big": [1024, 512, 384]}
size = self.size_dict[size_arg]
first_transform = nn.Linear(size[0], size[1])
fc1 = [first_transform, nn.ReLU()]
if dropout:
fc1.append(nn.Dropout(0.25))
if gate:
postr_net = Attn_Net_Gated(L=size[1], D=size[2], dropout=dropout, n_classes=2)
else:
postr_net = Attn_Net(L=size[1], D=size[2], dropout=dropout, n_classes=1)
fc1.append(postr_net)
self.postr_net = nn.Sequential(*fc1)
self.classifiers = LinearVDO(size[1], n_classes, ard_init=-3.)
self.n_classes = n_classes
self.print_sample_trigger = False
self.num_samples = 16
self.temperature = torch.tensor([1.0])
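        # Class-conditional prior over the attention logits: index 0 is the prior for
        # negative slides, index 1 for positive slides (selected via slide_label in forward()).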
self.prior_mu = torch.tensor([-5., 0.])
self.prior_logvar = torch.tensor([-1., 3.])
initialize_weights(self)
self.top_k = top_k
def reparameterize(self, mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def relocate(self):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# self.attention_net = self.attention_net.to(device)
self.postr_net = self.postr_net.to(device)
# self.prior_net = self.prior_net.to(device)
self.classifiers = self.classifiers.to(device)
self.temperature = self.temperature.to(device)
self.prior_mu = self.prior_mu.to(device)
self.prior_logvar = self.prior_logvar.to(device)
def kl_logistic_normal(self, mu_pr, mu_pos, logvar_pr, logvar_pos):
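        # Closed-form KL between the per-instance posterior N(mu_pos, .) and the class
        # prior N(mu_pr, .) over the pre-sigmoid attention logits (logistic-normal);
        # note the variance terms use logvar ** 2 rather than exp(logvar).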
return (logvar_pr - logvar_pos) / 2. + (logvar_pos ** 2 + (mu_pr - mu_pos) ** 2) / (2. * logvar_pr ** 2) - 0.5
def forward(self, h, return_features=False, slide_label=None, validation=False):
device = h.device
# *-*# A, h = self.attention_net(h) # NxK
param, h = self.postr_net(h)
mu = param[:, 0]
logvar = param[:, 1]
gaus_samples = self.reparameterize(mu, logvar)
        beta_samples = torch.sigmoid(gaus_samples)
A = beta_samples.unsqueeze(0)
if not validation:
mu_pr = self.prior_mu[slide_label.item()].expand(h.shape[0])
logvar_pr = self.prior_logvar[slide_label.item()]
kl_div = self.kl_logistic_normal(mu_pr, mu, logvar_pr, logvar)
else:
kl_div = None
M = torch.mm(A, h) / A.sum()
logits = self.classifiers(M)
y_probs = F.softmax(logits, dim=1)
top_instance_idx = torch.topk(y_probs[:, 1], self.top_k, dim=0)[1].view(1, )
top_instance = torch.index_select(logits, dim=0, index=top_instance_idx)
Y_hat = torch.topk(top_instance, 1, dim=1)[1]
Y_prob = F.softmax(top_instance, dim=1)
results_dict = {}
if return_features:
top_features = torch.index_select(h, dim=0, index=top_instance_idx)
results_dict.update({'features': top_features})
if not validation:
return top_instance, Y_prob, Y_hat, kl_div, y_probs, A
else:
return top_instance, Y_prob, Y_hat, y_probs, A
class probabilistic_MIL_Bayes_spvis(nn.Module):
def __init__(self, conf, size_arg="small", top_k=1):
super(probabilistic_MIL_Bayes_spvis, self).__init__()
# self.size_dict = {"small": [1024, 512, 256], "big": [1024, 512, 384]}
self.size_dict = {"small": [conf.feat_d, 512, 256], "big": [conf.feat_d, 512, 384]}
size = self.size_dict[size_arg]
ard_init = -4.
self.linear1 = nn.Linear(size[0], size[1])
self.linear2a = LinearVDO(size[1], size[2], ard_init=ard_init)
self.linear2b = LinearVDO(size[1], size[2], ard_init=ard_init)
self.linear3 = LinearVDO(size[2], 2, ard_init=ard_init)
self.gaus_smoothing = GaussianSmoothing(1, 3, 0.5)
self.classifiers = LinearVDO(size[1], conf.n_class, ard_init=-3.)
self.dp_0 = nn.Dropout(0.25)
self.dp_a = nn.Dropout(0.25)
self.dp_b = nn.Dropout(0.25)
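        # Same class-conditional prior over attention logits as in
        # probabilistic_MIL_Bayes_enc (index 0: negative slides, index 1: positive).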
self.prior_mu = torch.tensor([-5., 0.])
self.prior_logvar = torch.tensor([-1., 3.])
initialize_weights(self)
self.top_k = top_k
self.patch_size = conf.patch_size
def reparameterize(self, mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def kl_logistic_normal(self, mu_pr, mu_pos, logvar_pr, logvar_pos):
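        # Same logistic-normal KL as in probabilistic_MIL_Bayes_enc.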
return (logvar_pr - logvar_pos) / 2. + (logvar_pos ** 2 + (mu_pr - mu_pos) ** 2) / (2. * logvar_pr ** 2) - 0.5
def relocate(self):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.linear1 = self.linear1.to(device)
self.linear2a = self.linear2a.to(device)
self.linear2b = self.linear2b.to(device)
self.linear3 = self.linear3.to(device)
self.dp_0 = self.dp_0.to(device)
self.dp_a = self.dp_a.to(device)
self.dp_b = self.dp_b.to(device)
self.gaus_smoothing = self.gaus_smoothing.to(device)
self.prior_mu = self.prior_mu.to(device)
self.prior_logvar = self.prior_logvar.to(device)
self.classifiers = self.classifiers.to(device)
def forward(self, h, coords, height, width, slide_label=None, validation=False):
h = h[0]
device = h.device
h = F.relu(self.dp_0(self.linear1(h)))
feat_a = self.dp_a(torch.sigmoid(self.linear2a(h)))
feat_b = self.dp_b(torch.tanh(self.linear2b(h)))
feat = feat_a.mul(feat_b)
params = self.linear3(feat)
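        # Scatter each patch's (mu, logvar) onto a flattened grid of the slide at
        # patch-size resolution, using the linear index x + y * (width // patch_size).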
coords = coords // self.patch_size
asign = lambda coord: coord[:, 0] + coord[:, 1] * (width // self.patch_size)
coords = asign(coords)
coords = torch.from_numpy(coords).to(device)
mu = torch.zeros([1, (height // self.patch_size + 1) * (width // self.patch_size + 1)]).to(device)
logvar = torch.zeros([1, (height // self.patch_size + 1) * (width // self.patch_size + 1)]).to(device)
mu[:, coords.long()] = params[:, 0]
logvar[:, coords.long()] = params[:, 1]
mu = mu.view(1, height // self.patch_size + 1, width // self.patch_size + 1)
logvar = logvar.view(1, height // self.patch_size + 1, width // self.patch_size + 1)
if not validation:
mu_pr = self.prior_mu[slide_label.item()].expand_as(mu)
logvar_pr = self.prior_logvar[slide_label.item()]
kl_div = self.kl_logistic_normal(mu_pr, mu, logvar_pr, logvar)
else:
kl_div = None
        # no branch: pad and smooth the mu map with the kernel-3 Gaussian (the padding
        # keeps the spatial size) before drawing the attention samples.
mu = F.pad(mu, (1, 1, 1, 1), mode='constant', value=0)
mu = torch.unsqueeze(mu, dim=0)
mu = self.gaus_smoothing(mu)
gaus_samples = self.reparameterize(mu, logvar)
gaus_samples = torch.squeeze(gaus_samples, dim=0)
        A = torch.sigmoid(gaus_samples)
A = A.view(1, -1)
patch_A = torch.index_select(A, dim=1, index=coords)
M = torch.mm(patch_A, h) / patch_A.sum()
logits = self.classifiers(M)
y_probs = F.softmax(logits, dim=1)
top_instance_idx = torch.topk(y_probs[:, 1], self.top_k, dim=0)[1].view(1, )
top_instance = torch.index_select(logits, dim=0, index=top_instance_idx)
Y_hat = torch.topk(top_instance, 1, dim=1)[1]
Y_prob = F.softmax(top_instance, dim=1)
if not validation:
return top_instance, Y_prob, Y_hat, kl_div, y_probs, patch_A.view((1, -1))
else:
return top_instance, Y_prob, Y_hat, y_probs, patch_A.view((1, -1))
def get_ard_reg_vdo(module, reg=0):
"""
:param module: model to evaluate ard regularization for
:param reg: auxilary cumulative variable for recursion
:return: total regularization for module
"""
    if isinstance(module, (LinearVDO, Conv2dVDO)):
        return reg + module.get_reg()
    if hasattr(module, 'children'):
        return reg + sum(get_ard_reg_vdo(submodule) for submodule in module.children())
return reg
bMIL_model_dict = {
'vis': probabilistic_MIL_Bayes_vis,
'enc': probabilistic_MIL_Bayes_enc,
'spvis': probabilistic_MIL_Bayes_spvis,
}
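if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the original training pipeline;
    # assumes the repo's utils and architecture modules are importable): run the
    # attention-only variant on a random bag of 1024-d instance features.
    model = bMIL_model_dict['vis'](gate=True, size_arg="small", dropout=True, n_classes=2)
    model.eval()
    bag = torch.randn(100, 1024)  # 100 instances x 1024 features
    with torch.no_grad():
        top_instance, Y_prob, Y_hat, y_probs, A = model(bag)
    print('logits:', top_instance.shape, 'Y_hat:', Y_hat.item(), 'attention:', A.shape)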