"""
Exploring Low-Rank Property in Multiple Instance Learning for Whole Slide Image Classification
Jinxi Xiang et al., ICLR 2023
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


def initialize_weights(model):
    for m in model.modules():
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            # m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm1d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
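
# Note (added comment): initialize_weights() walks every submodule recursively, so all
# nn.Linear and nn.BatchNorm1d layers are covered; ILRA calls it once at the end of __init__.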


class MultiHeadAttention(nn.Module):
    """
    Multi-head attention block with residual connections, optional LayerNorm,
    and an optional SiLU gate applied to the query input.
    """
    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False, gated=False):
        super(MultiHeadAttention, self).__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.multihead_attn = nn.MultiheadAttention(dim_V, num_heads)
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)
        self.gate = None
        if gated:
            self.gate = nn.Sequential(nn.Linear(dim_Q, dim_V), nn.SiLU())

    def forward(self, Q, K):
        Q0 = Q  # keep the raw query for the gating branch
        # nn.MultiheadAttention expects (seq_len, batch, dim), hence the transposes
        Q = self.fc_q(Q).transpose(0, 1)
        K, V = self.fc_k(K).transpose(0, 1), self.fc_v(K).transpose(0, 1)
        A, _ = self.multihead_attn(Q, K, V)
        O = (Q + A).transpose(0, 1)  # residual connection, back to (batch, seq_len, dim)
        O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
        O = O + F.relu(self.fc_o(O))  # position-wise feed-forward with residual
        O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
        if self.gate is not None:
            O = O.mul(self.gate(Q0))
        return O
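
# Shape sketch for the block above (added illustration; the concrete sizes are assumptions):
# with Q of shape (B, n_q, dim_Q) and K of shape (B, n_k, dim_K), the output is (B, n_q, dim_V), e.g.
#   mha = MultiHeadAttention(dim_Q=256, dim_K=1024, dim_V=256, num_heads=8, ln=True, gated=True)
#   out = mha(torch.randn(2, 4, 256), torch.randn(2, 1600, 1024))   # -> (2, 4, 256)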


class GAB(nn.Module):
    """
    Equation (16) in the paper.
    """
    def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False):
        super(GAB, self).__init__()
        self.latent = nn.Parameter(torch.Tensor(1, num_inds, dim_out))  # low-rank matrix L
        nn.init.xavier_uniform_(self.latent)
        self.project_forward = MultiHeadAttention(dim_out, dim_in, dim_out, num_heads, ln=ln, gated=True)
        self.project_backward = MultiHeadAttention(dim_in, dim_out, dim_out, num_heads, ln=ln, gated=True)

    def forward(self, X):
        """
        Using 'latent_mat' as a proxy keeps the computational complexity relatively low.
        The result roughly corresponds to applying self-attention of 'X' with itself,
        i.e. self-attention(X, X), which would cost O(n^2).
        """
        latent_mat = self.latent.repeat(X.size(0), 1, 1)
        H = self.project_forward(latent_mat, X)  # project the high-dimensional X into low-dimensional H
        X_hat = self.project_backward(X, H)      # recover to the high-dimensional space X_hat
        return X_hat
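
# Minimal usage sketch for GAB (added illustration; shapes are assumptions, not from the paper):
#   gab = GAB(dim_in=1024, dim_out=256, num_heads=8, num_inds=1)
#   X_hat = gab(torch.randn(1, 1600, 1024))   # (1, 1600, 1024) -> (1, 1600, 256)
# With n instances, k = num_inds latent rows, and hidden dim d, the two cross-attentions cost
# roughly O(n * k * d) instead of the O(n^2 * d) of full self-attention over the bag.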


class NLP(nn.Module):
    """
    Non-Local Pooling. For obtaining a global feature for classification, it is more
    effective than simple average pooling, which can degrade performance.
    """
    def __init__(self, dim, num_heads, num_seeds, ln=False):
        super(NLP, self).__init__()
        self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))  # learnable seed vectors
        nn.init.xavier_uniform_(self.S)
        self.mha = MultiHeadAttention(dim, dim, dim, num_heads, ln=ln)

    def forward(self, X):
        global_embedding = self.S.repeat(X.size(0), 1, 1)
        ret = self.mha(global_embedding, X)  # attend from the seeds to all instances
        return ret
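
# Shape sketch for NLP (added illustration): a bag of n instance embeddings is pooled into
# num_seeds global tokens, e.g.
#   pool = NLP(dim=256, num_heads=8, num_seeds=1)
#   g = pool(torch.randn(1, 1600, 256))   # -> (1, 1, 256)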


class ILRA(nn.Module):
    def __init__(self, num_layers=2, feat_dim=768, n_classes=2, hidden_feat=256, num_heads=8, topk=1, ln=False):
        super().__init__()
        # stack multiple GAB blocks
        gab_blocks = []
        for idx in range(num_layers):
            block = GAB(dim_in=feat_dim if idx == 0 else hidden_feat,
                        dim_out=hidden_feat,
                        num_heads=num_heads,
                        num_inds=topk,
                        ln=ln)
            gab_blocks.append(block)
        self.gab_blocks = nn.ModuleList(gab_blocks)

        # non-local pooling for classification
        self.pooling = NLP(dim=hidden_feat, num_heads=num_heads, num_seeds=topk, ln=ln)

        # classifier
        self.classifier = nn.Linear(in_features=hidden_feat, out_features=n_classes)

        initialize_weights(self)

    def forward(self, x):
        for block in self.gab_blocks:
            x = block(x)
        feat = self.pooling(x)
        logits = self.classifier(feat)
        logits = logits.squeeze(1)  # (B, num_seeds, n_classes) -> (B, n_classes) when num_seeds == 1
        # Y_hat = torch.topk(logits, 1, dim=1)[1]
        # Y_prob = F.softmax(logits, dim=1)
        return logits


if __name__ == "__main__":
    model = ILRA(feat_dim=1024, n_classes=2, hidden_feat=256, num_heads=8, topk=1)
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"num of params: {num_params}")

    x = torch.randn((1, 1600, 1024))
    logits = model(x)  # forward() returns only the logits
    print(f"logits shape: {logits.shape}")