# File size: 5,240 Bytes (e5e24c9)
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
"""
Exploring Low-Rank Property in Multiple Instance Learning for Whole Slide Image Classification
Jinxi Xiang et al. ICLR 2023
"""
def initialize_weights(model):
    """Re-initialise the Linear and BatchNorm1d layers contained in ``model``.

    Every sub-module yielded by ``model.modules()`` is visited:

    * ``nn.Linear`` (and subclasses): the weight is redrawn with Xavier/Glorot
      normal initialisation. The bias is deliberately left as it is.
    * ``nn.BatchNorm1d``: weight is set to 1 and bias to 0.

    Any other module type is left untouched. The model is modified in place
    and nothing is returned.
    """
    for layer in model.modules():
        if isinstance(layer, nn.BatchNorm1d):
            # identity affine transform: scale 1, shift 0
            nn.init.constant_(layer.weight, 1)
            nn.init.constant_(layer.bias, 0)
            continue
        if isinstance(layer, nn.Linear):
            nn.init.xavier_normal_(layer.weight)
class MultiHeadAttention(nn.Module):
    """Multi-head cross-attention block with a residual feed-forward step.

    ``Q`` is projected from ``dim_Q`` to ``dim_V`` and ``K`` is projected twice
    (keys and values) from ``dim_K`` to ``dim_V``. The projected query attends
    to the projected keys/values through ``nn.MultiheadAttention``; the result
    is added back to the projected query, passed through an optional
    LayerNorm, a residual ``ReLU(Linear)`` step and a second optional
    LayerNorm. When ``gated`` is true the output is finally multiplied
    element-wise by ``SiLU(Linear(Q))`` computed from the *unprojected* query.

    Args:
        dim_Q: feature size of the query input.
        dim_K: feature size of the key/value input.
        dim_V: feature size of the attention space and of the output.
        num_heads: number of attention heads handed to ``nn.MultiheadAttention``.
        ln: when true, create the two LayerNorm layers ``ln0`` and ``ln1``;
            when false those attributes are not created at all.
        gated: when true, create the output gate; otherwise ``gate`` is None.
    """

    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False, gated=False):
        super().__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads

        # Built with the default layout of nn.MultiheadAttention, which is why
        # forward() swaps the first two axes before and after the call.
        self.multihead_attn = nn.MultiheadAttention(dim_V, num_heads)

        # input projections into the shared dim_V space
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)

        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)

        self.fc_o = nn.Linear(dim_V, dim_V)

        self.gate = nn.Sequential(nn.Linear(dim_Q, dim_V), nn.SiLU()) if gated else None

    def forward(self, Q, K):
        """Attend from ``Q`` to ``K`` and return the refined query features.

        Assumes both inputs are 3-D with the batch on axis 0 and the sequence
        on axis 1, since axes 0 and 1 are swapped before the attention call.
        The output has the same leading two axes as ``Q`` and ``dim_V``
        features.
        """
        # project, then move the sequence axis to the front for the attention module
        query = self.fc_q(Q).transpose(0, 1)
        key = self.fc_k(K).transpose(0, 1)
        value = self.fc_v(K).transpose(0, 1)
        attended, _ = self.multihead_attn(query, key, value)

        # residual connection, then back to the batch-leading layout
        out = (query + attended).transpose(0, 1)

        first_norm = getattr(self, 'ln0', None)
        if first_norm is not None:
            out = first_norm(out)

        out = out + F.relu(self.fc_o(out))

        second_norm = getattr(self, 'ln1', None)
        if second_norm is not None:
            out = second_norm(out)

        if self.gate is None:
            return out
        # the gate is driven by the raw (unprojected) query input
        return out.mul(self.gate(Q))
class GAB(nn.Module):
    """Gated attention block built around a small learnable latent matrix.

    Corresponds to equation (16) in the paper. The block owns a latent
    parameter of shape ``(1, num_inds, dim_out)`` and two gated
    ``MultiHeadAttention`` layers: one in which the latent queries the input,
    and one in which the input queries the result.

    Args:
        dim_in: feature size of the input ``X``.
        dim_out: feature size of the latent and of the output.
        num_heads: number of attention heads in both attention layers.
        num_inds: number of rows in the latent matrix.
        ln: forwarded to both attention layers to enable LayerNorm.
    """

    def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False):
        super().__init__()

        # low-rank matrix L, Xavier-uniform initialised
        latent_init = torch.Tensor(1, num_inds, dim_out)
        nn.init.xavier_uniform_(latent_init)
        self.latent = nn.Parameter(latent_init)

        # latent (query, dim_out) attends to X (key/value, dim_in)
        self.project_forward = MultiHeadAttention(
            dim_out, dim_in, dim_out, num_heads, ln=ln, gated=True
        )
        # X (query, dim_in) attends to the compressed summary (key/value, dim_out)
        self.project_backward = MultiHeadAttention(
            dim_in, dim_out, dim_out, num_heads, ln=ln, gated=True
        )

    def forward(self, X):
        """Route ``X`` through the latent bottleneck and back.

        The latent matrix stands in for one side of a self-attention over
        ``X``: each attention call pairs ``X`` with only ``num_inds`` latent
        rows rather than with all of ``X``, avoiding the O(n^2) cost of
        self-attention(X, X).
        """
        batch_size = X.size(0)
        # one copy of the latent per sample in the batch
        latent_mat = self.latent.repeat(batch_size, 1, 1)
        # summarise X into the low-dimensional latent space
        compressed = self.project_forward(latent_mat, X)
        # map back so that every row of X receives an updated representation
        return self.project_backward(X, compressed)
class NLP(nn.Module):
    """Non-Local Pooling: attention-based pooling with learnable seed vectors.

    Instead of averaging the rows of the input, a learnable seed matrix of
    shape ``(1, num_seeds, dim)`` queries the input through a (non-gated)
    ``MultiHeadAttention`` layer, giving ``num_seeds`` global feature vectors
    per sample. The authors use this in place of plain average pooling, which
    they note may degrade performance.

    Args:
        dim: feature size of both the input and the pooled output.
        num_heads: number of attention heads.
        num_seeds: number of seed (query) vectors, i.e. pooled outputs.
        ln: forwarded to the attention layer to enable LayerNorm.
    """

    def __init__(self, dim, num_heads, num_seeds, ln=False):
        super().__init__()

        # learnable pooling queries, Xavier-uniform initialised
        seed_init = torch.Tensor(1, num_seeds, dim)
        nn.init.xavier_uniform_(seed_init)
        self.S = nn.Parameter(seed_init)

        self.mha = MultiHeadAttention(dim, dim, dim, num_heads, ln=ln)

    def forward(self, X):
        """Pool ``X`` by letting the seed vectors attend to it."""
        batch_size = X.size(0)
        # one copy of the seeds per sample in the batch
        global_embedding = self.S.repeat(batch_size, 1, 1)
        return self.mha(global_embedding, X)
class ILRA(nn.Module):
    """Classifier made of stacked GAB blocks, non-local pooling and a linear head.

    Args:
        num_layers: number of GAB blocks to stack.
        feat_dim: feature size of the input; consumed by the first GAB block.
        n_classes: number of output logits.
        hidden_feat: feature size used after the first GAB block and by the
            pooling layer and classifier.
        num_heads: number of attention heads used in every attention layer.
        topk: used both as the number of latent rows in each GAB block and as
            the number of seed vectors in the pooling layer.
        ln: enable LayerNorm inside the attention layers.
    """

    def __init__(self, num_layers=2, feat_dim=768, n_classes=2, hidden_feat=256, num_heads=8, topk=1, ln=False):
        super().__init__()

        # Stack of GAB blocks: only the first one sees the raw feature size,
        # every later block works on hidden_feat.
        self.gab_blocks = nn.ModuleList([
            GAB(
                dim_in=hidden_feat if layer_idx else feat_dim,
                dim_out=hidden_feat,
                num_heads=num_heads,
                num_inds=topk,
                ln=ln,
            )
            for layer_idx in range(num_layers)
        ])

        # non-local pooling for classification
        self.pooling = NLP(dim=hidden_feat, num_heads=num_heads, num_seeds=topk, ln=ln)

        # classifier head
        self.classifier = nn.Linear(in_features=hidden_feat, out_features=n_classes)

        initialize_weights(self)
        print(f"ilra2~")

    def forward(self, x):
        """Return the raw classification logits for ``x``.

        ``x`` passes through every GAB block in order, is pooled, and is then
        mapped to ``n_classes`` logits. Axis 1 of the result is squeezed, which
        removes the seed axis when it has size 1. Only logits are returned;
        softmax / arg-max are left to the caller.
        """
        for gab in self.gab_blocks:
            x = gab(x)
        pooled = self.pooling(x)
        return self.classifier(pooled).squeeze(1)
if __name__ == "__main__":
    # Smoke test: build the model, report its trainable parameter count and
    # push one random input through it.
    model = ILRA(feat_dim=1024, n_classes=2, hidden_feat=256, num_heads=8, topk=1)
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"num of params: {num_params}")

    x = torch.randn((1, 1600, 1024))
    # ILRA.forward returns a single logits tensor. Unpacking it into three
    # names (logits, prob, y_hat) raised ValueError, so bind it to one name.
    logits = model(x)
    print(f"y shape: {logits.shape}")