"""
Exploring Low-Rank Property in Multiple Instance Learning for Whole Slide Image Classification
Jinxi Xiang et al., ICLR 2023
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def initialize_weights(model):
    for m in model.modules():
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            # Zero the bias when present (Xavier init only covers the weight).
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm1d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)


class MultiHeadAttention(nn.Module):
    """
    Multi-head cross-attention block: queries Q attend to keys/values derived
    from K, followed by a residual feed-forward layer, optional LayerNorm,
    and an optional SiLU gate computed from the raw query input.
    """

    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False, gated=False):
        super(MultiHeadAttention, self).__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.multihead_attn = nn.MultiheadAttention(dim_V, num_heads)
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)

        self.gate = None
        if gated:
            self.gate = nn.Sequential(nn.Linear(dim_Q, dim_V), nn.SiLU())

    def forward(self, Q, K):
        # Keep the raw query for the optional gating branch.
        Q0 = Q

        # Project to dim_V and move to the (seq, batch, dim) layout expected by
        # nn.MultiheadAttention (inputs here are batch-first).
        Q = self.fc_q(Q).transpose(0, 1)
        K, V = self.fc_k(K).transpose(0, 1), self.fc_v(K).transpose(0, 1)

        A, _ = self.multihead_attn(Q, K, V)

        # Residual connection, then back to the batch-first layout.
        O = (Q + A).transpose(0, 1)
        O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
        O = O + F.relu(self.fc_o(O))
        O = O if getattr(self, 'ln1', None) is None else self.ln1(O)

        # Element-wise gate on the output, driven by the original query.
        if self.gate is not None:
            O = O.mul(self.gate(Q0))

        return O
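

# A minimal shape-check sketch (an addition, not part of the original model code):
# it assumes the batch-first (batch, tokens, dim) convention used above and simply
# verifies that MultiHeadAttention maps M query tokens to M output tokens of size
# dim_V. The `_demo_multihead_attention` name is illustrative only.
def _demo_multihead_attention():
    mha = MultiHeadAttention(dim_Q=256, dim_K=1024, dim_V=256, num_heads=8, ln=True, gated=True)
    Q = torch.randn(2, 16, 256)     # (batch, M query tokens, dim_Q)
    K = torch.randn(2, 1600, 1024)  # (batch, N instances, dim_K)
    out = mha(Q, K)
    assert out.shape == (2, 16, 256)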


class GAB(nn.Module):
    """
    Implements equation (16) in the paper: a learnable latent matrix attends to
    the input bag of instance features, and the bag then attends back to the
    updated latent matrix.
    """

    def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False):
        super(GAB, self).__init__()
        # Learnable latent matrix with num_inds tokens.
        self.latent = nn.Parameter(torch.Tensor(1, num_inds, dim_out))
        nn.init.xavier_uniform_(self.latent)

        self.project_forward = MultiHeadAttention(dim_out, dim_in, dim_out, num_heads, ln=ln, gated=True)
        self.project_backward = MultiHeadAttention(dim_in, dim_out, dim_out, num_heads, ln=ln, gated=True)

    def forward(self, X):
        """
        Using 'latent_mat' as a proxy keeps the computational cost low: with n
        instances and m latent tokens (m << n), the two cross-attentions cost
        O(n * m) rather than the O(n^2) of full self-attention(X, X), while
        approximating the same interaction (see the sketch after this class).
        """
        latent_mat = self.latent.repeat(X.size(0), 1, 1)
        # The latent tokens gather information from the bag ...
        H = self.project_forward(latent_mat, X)
        # ... and the bag is updated from the compressed latent representation.
        X_hat = self.project_backward(X, H)

        return X_hat
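

# A rough complexity illustration (an addition, not from the original code),
# assuming a bag of n = 1600 instances and m = 1 latent token as in the default
# configuration below: full self-attention over the bag would build an
# n x n = 2,560,000-entry attention map, while GAB only builds n x m and m x n
# maps (1,600 entries each). The `_demo_gab` name is illustrative only.
def _demo_gab():
    gab = GAB(dim_in=1024, dim_out=256, num_heads=8, num_inds=1, ln=True)
    X = torch.randn(2, 1600, 1024)  # (batch, n instances, dim_in)
    X_hat = gab(X)
    # The bag length is preserved; only the feature dimension changes.
    assert X_hat.shape == (2, 1600, 256)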


class NLP(nn.Module):
    """
    Non-Local Pooling: a small set of learnable seed vectors attends over the
    bag to obtain global features for classification. This is more effective
    than simple average pooling, which can degrade performance.
    """

    def __init__(self, dim, num_heads, num_seeds, ln=False):
        super(NLP, self).__init__()
        # Learnable seed vectors that query the whole bag.
        self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
        nn.init.xavier_uniform_(self.S)
        self.mha = MultiHeadAttention(dim, dim, dim, num_heads, ln=ln)

    def forward(self, X):
        global_embedding = self.S.repeat(X.size(0), 1, 1)
        ret = self.mha(global_embedding, X)
        return ret
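

# A small pooling sketch (an addition, not part of the original code): NLP
# reduces a bag of N instance embeddings to num_seeds pooled vectors, here a
# single global feature per bag. The `_demo_nlp` name is illustrative only.
def _demo_nlp():
    pool = NLP(dim=256, num_heads=8, num_seeds=1, ln=True)
    X = torch.randn(2, 1600, 256)  # (batch, N instances, dim)
    pooled = pool(X)
    assert pooled.shape == (2, 1, 256)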


class ILRA(nn.Module):
    def __init__(self, num_layers=2, feat_dim=768, n_classes=2, hidden_feat=256, num_heads=8, topk=1, ln=False):
        super().__init__()

        # Stack of GAB blocks; the first block maps feat_dim -> hidden_feat.
        gab_blocks = []
        for idx in range(num_layers):
            block = GAB(dim_in=feat_dim if idx == 0 else hidden_feat,
                        dim_out=hidden_feat,
                        num_heads=num_heads,
                        num_inds=topk,
                        ln=ln)
            gab_blocks.append(block)

        self.gab_blocks = nn.ModuleList(gab_blocks)

        # Non-local pooling to a bag-level feature, followed by a linear classifier.
        self.pooling = NLP(dim=hidden_feat, num_heads=num_heads, num_seeds=topk, ln=ln)
        self.classifier = nn.Linear(in_features=hidden_feat, out_features=n_classes)

        initialize_weights(self)

    def forward(self, x):
        # x: (batch, num_instances, feat_dim)
        for block in self.gab_blocks:
            x = block(x)

        feat = self.pooling(x)
        logits = self.classifier(feat)

        # (batch, num_seeds, n_classes) -> (batch, n_classes) when num_seeds == 1.
        logits = logits.squeeze(1)

        return logits


if __name__ == "__main__":
    model = ILRA(feat_dim=1024, n_classes=2, hidden_feat=256, num_heads=8, topk=1)
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"num of params: {num_params}")

    x = torch.randn((1, 1600, 1024))
    # forward() returns only the logits; class probabilities and the predicted
    # label are derived from them here.
    logits = model(x)
    prob = torch.softmax(logits, dim=-1)
    y_hat = prob.argmax(dim=-1)
    print(f"logits shape: {logits.shape}")