import math
import numbers
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributions import kl

from utils.utils import initialize_weights
from architecture.linear_vdo import LinearVDO, Conv2dVDO

EPS_1 = 1e-16


class Attn_Net(nn.Module):
    """
    Attention Network without Gating (2 fc layers)

    args:
        L: input feature dimension
        D: hidden layer dimension
        dropout: whether to use dropout (p = 0.25)
        n_classes: number of classes
    """

    def __init__(self, L=1024, D=256, dropout=False, n_classes=1):
        super(Attn_Net, self).__init__()
        self.module = [
            nn.Linear(L, D),
            nn.Tanh()]

        if dropout:
            self.module.append(nn.Dropout(0.25))

        self.module.append(nn.Linear(D, n_classes))
        self.module = nn.Sequential(*self.module)

    def forward(self, x):
        return self.module(x), x
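

# Minimal usage sketch (illustrative only; the bag size below is an
# assumption, not a value fixed by this module).
def _demo_attn_net():
    net = Attn_Net(L=1024, D=256, dropout=True, n_classes=1)
    bag = torch.rand(100, 1024)        # a bag of 100 instance features
    scores, feats = net(bag)           # scores: (100, 1) raw attention logits
    attn = F.softmax(scores, dim=0)    # normalize over instances if desired
    return attn, feats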


class Attn_Net_Gated(nn.Module):
    """
    Attention Network with Sigmoid Gating (3 fc layers)

    args:
        L: input feature dimension
        D: hidden layer dimension
        dropout: whether to use dropout (p = 0.25)
        n_classes: number of classes
    """

    def __init__(self, L=1024, D=256, dropout=False, n_classes=1):
        super(Attn_Net_Gated, self).__init__()
        ard_init = -1.
        self.attention_a = [
            LinearVDO(L, D, ard_init=ard_init),
            nn.Tanh()]
        self.attention_b = [
            LinearVDO(L, D, ard_init=ard_init),
            nn.Sigmoid()]

        if dropout:
            self.attention_a.append(nn.Dropout(0.25))
            self.attention_b.append(nn.Dropout(0.25))

        self.attention_a = nn.Sequential(*self.attention_a)
        self.attention_b = nn.Sequential(*self.attention_b)
        self.attention_c = LinearVDO(D, n_classes, ard_init=ard_init)

    def forward(self, x):
        a = self.attention_a(x)
        b = self.attention_b(x)
        A = a.mul(b)
        A = self.attention_c(A)
        return A, x
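

# Illustrative sketch (the two-column reading below is an assumption that
# mirrors how the MIL models in this file use a gated net with n_classes=2).
def _demo_attn_net_gated():
    net = Attn_Net_Gated(L=512, D=256, dropout=False, n_classes=2)
    bag = torch.rand(50, 512)
    A, feats = net(bag)            # A: (50, 2)
    mu, logvar = A[:, 0], A[:, 1]  # per-instance Gaussian parameters
    return mu, logvar, feats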


class DAttn_Net_Gated(nn.Module):
    """
    Deterministic variant of Attn_Net_Gated: the same sigmoid-gated attention,
    built from standard nn.Linear layers instead of LinearVDO layers.
    """

    def __init__(self, L=1024, D=256, dropout=False, n_classes=1):
        super(DAttn_Net_Gated, self).__init__()
        self.attention_a = [
            nn.Linear(L, D),
            nn.Tanh()]
        self.attention_b = [
            nn.Linear(L, D),
            nn.Sigmoid()]

        if dropout:
            self.attention_a.append(nn.Dropout(0.25))
            self.attention_b.append(nn.Dropout(0.25))

        self.attention_a = nn.Sequential(*self.attention_a)
        self.attention_b = nn.Sequential(*self.attention_b)
        self.attention_c = nn.Linear(D, n_classes)

    def forward(self, x):
        a = self.attention_a(x)
        b = self.attention_b(x)
        A = a.mul(b)
        A = self.attention_c(A)
        return A, x


class GaussianSmoothing(nn.Module):
    """
    Apply gaussian smoothing on a 1d, 2d or 3d tensor. Filtering is performed
    separately for each channel in the input using a depthwise convolution.

    Arguments:
        channels (int, sequence): Number of channels of the input tensors.
            Output will have this number of channels as well.
        kernel_size (int, sequence): Size of the gaussian kernel.
        sigma (float, sequence): Standard deviation of the gaussian kernel.
        dim (int, optional): The number of dimensions of the data.
            Default value is 2 (spatial).
    """

    def __init__(self, channels, kernel_size, sigma, dim=2):
        super(GaussianSmoothing, self).__init__()
        if isinstance(kernel_size, numbers.Number):
            kernel_size = [kernel_size] * dim
        if isinstance(sigma, numbers.Number):
            sigma = [sigma] * dim

        # The kernel is the product of one 1d Gaussian per dimension.
        kernel = 1
        meshgrids = torch.meshgrid(
            [
                torch.arange(size, dtype=torch.float32)
                for size in kernel_size
            ]
        )
        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
            mean = (size - 1) / 2
            kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
                torch.exp(-((mgrid - mean) / std) ** 2 / 2)

        # Make sure the sum of values in the kernel equals 1.
        kernel = kernel / torch.sum(kernel)

        # Reshape to a depthwise convolutional weight: (channels, 1, *kernel_size).
        kernel = kernel.view(1, 1, *kernel.size())
        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))

        self.register_buffer('weight', kernel)
        self.groups = channels

        if dim == 1:
            self.conv = F.conv1d
        elif dim == 2:
            self.conv = F.conv2d
        elif dim == 3:
            self.conv = F.conv3d
        else:
            raise RuntimeError(
                'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
            )

    def forward(self, input):
        """
        Apply gaussian filter to input.

        Arguments:
            input (torch.Tensor): Input to apply gaussian filter on.
        Returns:
            filtered (torch.Tensor): Filtered output. No padding is applied,
                so each spatial dimension shrinks by kernel_size - 1.
        """
        return self.conv(input, weight=self.weight, groups=self.groups)
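

# Usage sketch (illustrative): forward() applies no padding, so pad the input
# by kernel_size // 2 on each side to keep the spatial size unchanged, as
# probabilistic_MIL_Bayes_spvis does below with its 3x3 kernel.
def _demo_gaussian_smoothing():
    smoothing = GaussianSmoothing(channels=1, kernel_size=3, sigma=0.5)
    img = torch.rand(1, 1, 32, 32)               # (batch, channels, H, W)
    padded = F.pad(img, (1, 1, 1, 1), mode='constant', value=0)
    out = smoothing(padded)                      # back to (1, 1, 32, 32)
    return out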


class probabilistic_MIL_Bayes_vis(nn.Module):

    def __init__(self, gate=True, size_arg="small", dropout=False, n_classes=2, top_k=1):
        super(probabilistic_MIL_Bayes_vis, self).__init__()
        self.size_dict = {"small": [1024, 512, 256], "big": [1024, 512, 384]}
        size = self.size_dict[size_arg]
        fc = [nn.Linear(size[0], size[1]), nn.ReLU()]
        if dropout:
            fc.append(nn.Dropout(0.25))
        if gate:
            # Two attention heads: one for the mean, one for the log-variance.
            attention_net = Attn_Net_Gated(L=size[1], D=size[2], dropout=dropout, n_classes=2)
        else:
            attention_net = Attn_Net(L=size[1], D=size[2], dropout=dropout, n_classes=1)
        fc.append(attention_net)
        self.attention_net = nn.Sequential(*fc)
        self.classifiers = LinearVDO(size[1], n_classes, ard_init=-3.)
        self.n_classes = n_classes
        self.print_sample_trigger = False
        self.num_samples = 16
        self.temperature = torch.tensor([1.0])
        self.fixed_b = torch.tensor([5.], requires_grad=False)

        initialize_weights(self)
        self.top_k = top_k

    def reparameterize(self, mu, logvar):
        # Reparameterization trick: a differentiable sample from N(mu, sigma^2).
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def relocate(self):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.attention_net = self.attention_net.to(device)
        self.classifiers = self.classifiers.to(device)
        self.temperature = self.temperature.to(device)

    def forward(self, h, validation=False):
        device = h.device

        A, h = self.attention_net(h)

        # The attention head emits two columns per instance: the mean and
        # log-variance of a Gaussian over that instance's attention logit
        # (this forward pass therefore assumes the gate=True configuration).
        mu = A[:, 0]
        logvar = A[:, 1]
        gaus_samples = self.reparameterize(mu, logvar)
        beta_samples = torch.sigmoid(gaus_samples)
        A = beta_samples.unsqueeze(0)

        # Attention-weighted average of the instance features.
        M = torch.mm(A, h) / A.sum()
        logits = self.classifiers(M)
        y_probs = F.softmax(logits, dim=1)
        top_instance_idx = torch.topk(y_probs[:, 1], self.top_k, dim=0)[1].view(1, )
        top_instance = torch.index_select(logits, dim=0, index=top_instance_idx)
        Y_hat = torch.topk(top_instance, 1, dim=1)[1]
        Y_prob = F.softmax(top_instance, dim=1)

        return top_instance, Y_prob, Y_hat, y_probs, A
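

# End-to-end sketch (illustrative; the bag size is an assumption): a bag of
# 1024-d instance features goes in; logits, probabilities, the predicted label
# and the sampled attention weights come out.
def _demo_bayes_vis():
    model = probabilistic_MIL_Bayes_vis(gate=True, n_classes=2)
    bag = torch.rand(100, 1024)
    top_instance, Y_prob, Y_hat, y_probs, A = model(bag)
    return Y_hat, A                    # A: (1, 100) attention weights in (0, 1)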


class probabilistic_MIL_Bayes_enc(nn.Module):

    def __init__(self, gate=True, size_arg="small", dropout=False, n_classes=2, top_k=1):
        super(probabilistic_MIL_Bayes_enc, self).__init__()
        self.size_dict = {"small": [1024, 512, 256], "big": [1024, 512, 384]}
        size = self.size_dict[size_arg]
        first_transform = nn.Linear(size[0], size[1])
        fc1 = [first_transform, nn.ReLU()]

        if dropout:
            fc1.append(nn.Dropout(0.25))

        if gate:
            postr_net = Attn_Net_Gated(L=size[1], D=size[2], dropout=dropout, n_classes=2)
        else:
            postr_net = Attn_Net(L=size[1], D=size[2], dropout=dropout, n_classes=1)

        fc1.append(postr_net)

        self.postr_net = nn.Sequential(*fc1)
        self.classifiers = LinearVDO(size[1], n_classes, ard_init=-3.)

        self.n_classes = n_classes
        self.print_sample_trigger = False
        self.num_samples = 16
        self.temperature = torch.tensor([1.0])
        # Class-conditional priors over the attention logits, indexed by slide label.
        self.prior_mu = torch.tensor([-5., 0.])
        self.prior_logvar = torch.tensor([-1., 3.])

        initialize_weights(self)
        self.top_k = top_k

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def relocate(self):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.postr_net = self.postr_net.to(device)
        self.classifiers = self.classifiers.to(device)
        self.temperature = self.temperature.to(device)
        self.prior_mu = self.prior_mu.to(device)
        self.prior_logvar = self.prior_logvar.to(device)

    def kl_logistic_normal(self, mu_pr, mu_pos, logvar_pr, logvar_pos):
        # KL term between the posterior and prior over the attention logits.
        # Note the logvar arguments enter squared here, rather than through
        # exp(logvar) as in the standard closed form for Gaussian KL.
        return (logvar_pr - logvar_pos) / 2. + (logvar_pos ** 2 + (mu_pr - mu_pos) ** 2) / (2. * logvar_pr ** 2) - 0.5

    def forward(self, h, return_features=False, slide_label=None, validation=False):
        device = h.device

        param, h = self.postr_net(h)

        mu = param[:, 0]
        logvar = param[:, 1]
        gaus_samples = self.reparameterize(mu, logvar)
        beta_samples = torch.sigmoid(gaus_samples)
        A = beta_samples.unsqueeze(0)

        if not validation:
            # Compare the per-instance posterior against the label-specific prior.
            mu_pr = self.prior_mu[slide_label.item()].expand(h.shape[0])
            logvar_pr = self.prior_logvar[slide_label.item()]
            kl_div = self.kl_logistic_normal(mu_pr, mu, logvar_pr, logvar)
        else:
            kl_div = None

        M = torch.mm(A, h) / A.sum()
        logits = self.classifiers(M)

        y_probs = F.softmax(logits, dim=1)
        top_instance_idx = torch.topk(y_probs[:, 1], self.top_k, dim=0)[1].view(1, )
        top_instance = torch.index_select(logits, dim=0, index=top_instance_idx)
        Y_hat = torch.topk(top_instance, 1, dim=1)[1]
        Y_prob = F.softmax(top_instance, dim=1)
        results_dict = {}

        if return_features:
            top_features = torch.index_select(h, dim=0, index=top_instance_idx)
            results_dict.update({'features': top_features})

        if not validation:
            return top_instance, Y_prob, Y_hat, kl_div, y_probs, A
        else:
            return top_instance, Y_prob, Y_hat, y_probs, A
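

# Training-mode sketch (illustrative; how the KL term enters the loss is an
# assumption): a slide-level label selects the prior, and the per-instance KL
# divergence is returned alongside the predictions.
def _demo_bayes_enc():
    model = probabilistic_MIL_Bayes_enc(n_classes=2)
    bag = torch.rand(100, 1024)
    label = torch.tensor(1)
    top_instance, Y_prob, Y_hat, kl_div, y_probs, A = model(bag, slide_label=label)
    return kl_div.mean()               # e.g. weighted and added to the CE loss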


class probabilistic_MIL_Bayes_spvis(nn.Module):

    def __init__(self, conf, size_arg="small", top_k=1):
        super(probabilistic_MIL_Bayes_spvis, self).__init__()

        self.size_dict = {"small": [conf.feat_d, 512, 256], "big": [conf.feat_d, 512, 384]}
        size = self.size_dict[size_arg]

        ard_init = -4.
        self.linear1 = nn.Linear(size[0], size[1])
        self.linear2a = LinearVDO(size[1], size[2], ard_init=ard_init)
        self.linear2b = LinearVDO(size[1], size[2], ard_init=ard_init)
        self.linear3 = LinearVDO(size[2], 2, ard_init=ard_init)

        self.gaus_smoothing = GaussianSmoothing(1, 3, 0.5)

        self.classifiers = LinearVDO(size[1], conf.n_class, ard_init=-3.)

        self.dp_0 = nn.Dropout(0.25)
        self.dp_a = nn.Dropout(0.25)
        self.dp_b = nn.Dropout(0.25)

        self.prior_mu = torch.tensor([-5., 0.])
        self.prior_logvar = torch.tensor([-1., 3.])

        initialize_weights(self)
        self.top_k = top_k
        self.patch_size = conf.patch_size

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def kl_logistic_normal(self, mu_pr, mu_pos, logvar_pr, logvar_pos):
        # Same form as in probabilistic_MIL_Bayes_enc; see the note there.
        return (logvar_pr - logvar_pos) / 2. + (logvar_pos ** 2 + (mu_pr - mu_pos) ** 2) / (2. * logvar_pr ** 2) - 0.5

    def relocate(self):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.linear1 = self.linear1.to(device)
        self.linear2a = self.linear2a.to(device)
        self.linear2b = self.linear2b.to(device)
        self.linear3 = self.linear3.to(device)

        self.dp_0 = self.dp_0.to(device)
        self.dp_a = self.dp_a.to(device)
        self.dp_b = self.dp_b.to(device)
        self.gaus_smoothing = self.gaus_smoothing.to(device)

        self.prior_mu = self.prior_mu.to(device)
        self.prior_logvar = self.prior_logvar.to(device)

        self.classifiers = self.classifiers.to(device)

    def forward(self, h, coords, height, width, slide_label=None, validation=False):
        h = h[0]
        device = h.device
        h = F.relu(self.dp_0(self.linear1(h)))

        # Gated attention: sigmoid gate times tanh features.
        feat_a = self.dp_a(torch.sigmoid(self.linear2a(h)))
        feat_b = self.dp_b(torch.tanh(self.linear2b(h)))
        feat = feat_a.mul(feat_b)
        params = self.linear3(feat)

        # Flatten each patch's (x, y) grid position into a single index.
        coords = coords // self.patch_size
        coords = coords[:, 0] + coords[:, 1] * (width // self.patch_size)
        coords = torch.from_numpy(coords).to(device)

        # Scatter the per-patch Gaussian parameters onto the full slide grid.
        grid_len = (height // self.patch_size + 1) * (width // self.patch_size + 1)
        mu = torch.zeros([1, grid_len]).to(device)
        logvar = torch.zeros([1, grid_len]).to(device)

        mu[:, coords.long()] = params[:, 0]
        logvar[:, coords.long()] = params[:, 1]

        mu = mu.view(1, height // self.patch_size + 1, width // self.patch_size + 1)
        logvar = logvar.view(1, height // self.patch_size + 1, width // self.patch_size + 1)

        if not validation:
            mu_pr = self.prior_mu[slide_label.item()].expand_as(mu)
            logvar_pr = self.prior_logvar[slide_label.item()]
            kl_div = self.kl_logistic_normal(mu_pr, mu, logvar_pr, logvar)
        else:
            kl_div = None

        # Spatially smooth the attention means before sampling; the padding
        # compensates for the 3x3 kernel so the grid size is preserved.
        mu = F.pad(mu, (1, 1, 1, 1), mode='constant', value=0)
        mu = torch.unsqueeze(mu, dim=0)
        mu = self.gaus_smoothing(mu)

        gaus_samples = self.reparameterize(mu, logvar)
        gaus_samples = torch.squeeze(gaus_samples, dim=0)

        A = torch.sigmoid(gaus_samples)
        A = A.view(1, -1)

        # Keep only the grid cells that correspond to actual patches.
        patch_A = torch.index_select(A, dim=1, index=coords)
        M = torch.mm(patch_A, h) / patch_A.sum()

        logits = self.classifiers(M)

        y_probs = F.softmax(logits, dim=1)
        top_instance_idx = torch.topk(y_probs[:, 1], self.top_k, dim=0)[1].view(1, )
        top_instance = torch.index_select(logits, dim=0, index=top_instance_idx)
        Y_hat = torch.topk(top_instance, 1, dim=1)[1]
        Y_prob = F.softmax(top_instance, dim=1)

        if not validation:
            return top_instance, Y_prob, Y_hat, kl_div, y_probs, patch_A.view((1, -1))
        else:
            return top_instance, Y_prob, Y_hat, y_probs, patch_A.view((1, -1))
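

# Input-format sketch (illustrative; the config stand-in and all sizes are
# assumptions): h is a batched bag (1, N, feat_d), coords holds each patch's
# top-left pixel coordinates as a numpy integer array, and height/width are
# the slide dimensions in pixels.
def _demo_bayes_spvis():
    class _Conf:                       # hypothetical stand-in for the real conf
        feat_d = 1024
        n_class = 2
        patch_size = 256

    model = probabilistic_MIL_Bayes_spvis(_Conf())
    n = 10
    h = torch.rand(1, n, 1024)
    coords = np.stack([np.arange(n), np.zeros(n, dtype=np.int64)], axis=1) * 256
    return model(h, coords, height=512, width=256 * n,
                 slide_label=torch.tensor(0))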


def get_ard_reg_vdo(module, reg=0):
    """
    :param module: model to evaluate ARD regularization for
    :param reg: auxiliary cumulative variable for recursion
    :return: total regularization for module
    """
    if isinstance(module, (LinearVDO, Conv2dVDO)):
        return reg + module.get_reg()
    if hasattr(module, 'children'):
        return reg + sum([get_ard_reg_vdo(submodule) for submodule in module.children()])
    return reg


bMIL_model_dict = {
    'vis': probabilistic_MIL_Bayes_vis,
    'enc': probabilistic_MIL_Bayes_enc,
    'spvis': probabilistic_MIL_Bayes_spvis,
}
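

# Usage sketch (illustrative): build a model by key and add the accumulated
# ARD regularizer from every VDO layer to the training objective.
def _demo_model_dict():
    model = bMIL_model_dict['vis'](gate=True, n_classes=2)
    reg = get_ard_reg_vdo(model)       # sum of get_reg() over all VDO layers
    return reg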