import math import os import torch from torch import nn, Tensor import torch.nn.functional as F from architecture.network import Classifier_1fc, DimReduction, DimReduction1 from einops import repeat from .nystrom_attention import NystromAttention from .emb_position import * def pos_enc_1d(D, len_seq): if D % 2 != 0: raise ValueError("Cannot use sin/cos positional encoding with " "odd dim (got dim={:d})".format(D)) pe = torch.zeros(len_seq, D) position = torch.arange(0, len_seq).unsqueeze(1) div_term = torch.exp((torch.arange(0, D, 2, dtype=torch.float) * -(math.log(10000.0) / D))) pe[:, 0::2] = torch.sin(position.float() * div_term) pe[:, 1::2] = torch.cos(position.float() * div_term) return pe class MLP(nn.Module): def __init__(self, input_dim, hidden_dim, output_dim, dropout_rate): super(MLP, self).__init__() self.fc1 = nn.Linear(input_dim, hidden_dim) self.fc2 = nn.Linear(hidden_dim, output_dim) self.dropout = nn.Dropout(dropout_rate) def forward(self, x): x = self.fc1(x) x = torch.relu(x) x = self.dropout(x) x = self.fc2(x) return x class MLP_single_layer(nn.Module): def __init__(self, input_dim, output_dim): super(MLP_single_layer, self).__init__() self.fc = nn.Linear(input_dim, output_dim) def forward(self, x): x = self.fc(x) return x class ACMIL_MHA(nn.Module): def __init__(self, conf, n_token=1, n_masked_patch=0, mask_drop=0): super(ACMIL_MHA, self).__init__() self.dimreduction = DimReduction(conf.D_feat, conf.D_inner) self.sub_attention = nn.ModuleList() for i in range(n_token): self.sub_attention.append(MutiHeadAttention(conf.D_inner, 8, n_masked_patch=n_masked_patch, mask_drop=mask_drop)) self.bag_attention = MutiHeadAttention_modify(conf.D_inner, 8) self.q = nn.Parameter(torch.zeros((1, n_token, conf.D_inner))) nn.init.normal_(self.q, std=1e-6) self.n_class = conf.n_class self.classifier = nn.ModuleList() for i in range(n_token): self.classifier.append(Classifier_1fc(conf.D_inner, conf.n_class, 0.0)) self.n_token = n_token self.Slide_classifier = Classifier_1fc(conf.D_inner, conf.n_class, 0.0) def forward(self, input): input = self.dimreduction(input) q = self.q k = input v = input outputs = [] attns = [] for i in range(self.n_token): feat_i, attn_i = self.sub_attention[i](q[:, i].unsqueeze(0), k, v) outputs.append(self.classifier[i](feat_i)) attns.append(attn_i) attns = torch.cat(attns, 1) feat_bag = self.bag_attention(v, attns.softmax(dim=-1).mean(1, keepdim=True)) return torch.cat(outputs, dim=0), self.Slide_classifier(feat_bag), attns class MHA(nn.Module): def __init__(self, conf): super(MHA, self).__init__() self.dimreduction = DimReduction(conf.D_feat, conf.D_inner) self.attention = MutiHeadAttention(conf.D_inner, 8) self.q = nn.Parameter(torch.zeros((1, 1, conf.D_inner))) nn.init.normal_(self.q, std=1e-6) self.n_class = conf.n_class self.classifier = Classifier_1fc(conf.D_inner, conf.n_class, 0.0) def forward(self, input): input = self.dimreduction(input) q = self.q k = input v = input feat, attn = self.attention(q, k, v) output = self.classifier(feat) return output class MutiHeadAttention(nn.Module): """ An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and values. """ def __init__( self, embedding_dim: int, num_heads: int, downsample_rate: int = 1, dropout: float = 0.1, n_masked_patch: int = 0, mask_drop: float = 0.0 ) -> None: super().__init__() self.n_masked_patch = n_masked_patch self.mask_drop = mask_drop self.embedding_dim = embedding_dim self.internal_dim = embedding_dim // downsample_rate self.num_heads = num_heads assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." self.q_proj = nn.Linear(embedding_dim, self.internal_dim) self.k_proj = nn.Linear(embedding_dim, self.internal_dim) self.v_proj = nn.Linear(embedding_dim, self.internal_dim) self.out_proj = nn.Linear(self.internal_dim, embedding_dim) self.layer_norm = nn.LayerNorm(embedding_dim, eps=1e-6) self.dropout = nn.Dropout(dropout) def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: b, n, c = x.shape x = x.reshape(b, n, num_heads, c // num_heads) return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head def _recombine_heads(self, x: Tensor) -> Tensor: b, n_heads, n_tokens, c_per_head = x.shape x = x.transpose(1, 2) return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: # Input projections q = self.q_proj(q) k = self.k_proj(k) v = self.v_proj(v) # Separate into heads q = self._separate_heads(q, self.num_heads) k = self._separate_heads(k, self.num_heads) v = self._separate_heads(v, self.num_heads) # Attention _, _, _, c_per_head = q.shape attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens attn = attn / math.sqrt(c_per_head) if self.n_masked_patch > 0 and self.training: # Get the indices of the top-k largest values b, h, q, c = attn.shape n_masked_patch = min(self.n_masked_patch, c) _, indices = torch.topk(attn, n_masked_patch, dim=-1) indices = indices.reshape(b * h * q, -1) rand_selected = torch.argsort(torch.rand(*indices.shape), dim=-1)[:,:int(n_masked_patch * self.mask_drop)] masked_indices = indices[torch.arange(indices.shape[0]).unsqueeze(-1), rand_selected] random_mask = torch.ones(b*h*q, c).to(attn.device) random_mask.scatter_(-1, masked_indices, 0) attn = attn.masked_fill(random_mask.reshape(b, h, q, -1) == 0, -1e9) attn_out = attn attn = torch.softmax(attn, dim=-1) # Get output out1 = attn @ v out1 = self._recombine_heads(out1) out1 = self.out_proj(out1) out1 = self.dropout(out1) out1 = self.layer_norm(out1) return out1[0], attn_out[0] class MutiHeadAttention_modify(nn.Module): """ An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and values. """ def __init__( self, embedding_dim: int, num_heads: int, downsample_rate: int = 1, dropout: float = 0.1, ) -> None: super().__init__() self.embedding_dim = embedding_dim self.internal_dim = embedding_dim // downsample_rate self.num_heads = num_heads assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." self.v_proj = nn.Linear(embedding_dim, self.internal_dim) self.out_proj = nn.Linear(self.internal_dim, embedding_dim) self.layer_norm = nn.LayerNorm(embedding_dim, eps=1e-6) self.dropout = nn.Dropout(dropout) def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: b, n, c = x.shape x = x.reshape(b, n, num_heads, c // num_heads) return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head def _recombine_heads(self, x: Tensor) -> Tensor: b, n_heads, n_tokens, c_per_head = x.shape x = x.transpose(1, 2) return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C def forward(self, v: Tensor, attn: Tensor) -> Tensor: # Input projections v = self.v_proj(v) # Separate into heads v = self._separate_heads(v, self.num_heads) # Get output out1 = attn @ v out1 = self._recombine_heads(out1) out1 = self.out_proj(out1) out1 = self.dropout(out1) out1 = self.layer_norm(out1) return out1[0] class Attention_Gated(nn.Module): def __init__(self, L=512, D=128, K=1): super(Attention_Gated, self).__init__() self.L = L self.D = D self.K = K self.attention_V = nn.Sequential( nn.Linear(self.L, self.D), nn.Tanh() ) self.attention_U = nn.Sequential( nn.Linear(self.L, self.D), nn.Sigmoid() ) self.attention_weights = nn.Linear(self.D, self.K) def forward(self, x): ## x: N x L A_V = self.attention_V(x) # NxD A_U = self.attention_U(x) # NxD A = self.attention_weights(A_V * A_U) # NxK A = torch.transpose(A, 1, 0) # KxN return A ### K x N class ABMIL(nn.Module): def __init__(self, conf, D=128, droprate=0): super(ABMIL, self).__init__() self.dimreduction = DimReduction(conf.D_feat, conf.D_inner) self.attention = Attention_Gated(conf.D_inner, D, 1) self.classifier = Classifier_1fc(conf.D_inner, conf.n_class, droprate) def forward(self, x): ## x: N x L x = x[0] med_feat = self.dimreduction(x) A = self.attention(med_feat) ## K x N A_out = A A = F.softmax(A, dim=1) # softmax over N afeat = torch.mm(A, med_feat) ## K x L outputs = self.classifier(afeat) return outputs class ACMIL_GA(nn.Module): def __init__(self, conf, D=128, droprate=0, n_token=1, n_masked_patch=0, mask_drop=0): super(ACMIL_GA, self).__init__() self.dimreduction = DimReduction(conf.D_feat, conf.D_inner) self.attention = Attention_Gated(conf.D_inner, D, n_token) self.classifier = nn.ModuleList() for i in range(n_token): self.classifier.append(Classifier_1fc(conf.D_inner, conf.n_class, droprate)) self.n_masked_patch = n_masked_patch self.n_token = conf.n_token self.Slide_classifier = Classifier_1fc(conf.D_inner, conf.n_class, droprate) self.mask_drop = mask_drop def forward(self, x): ## x: N x L x = x[0] x = self.dimreduction(x) A = self.attention(x) ## K x N if self.n_masked_patch > 0 and self.training: # Get the indices of the top-k largest values k, n = A.shape n_masked_patch = min(self.n_masked_patch, n) _, indices = torch.topk(A, n_masked_patch, dim=-1) rand_selected = torch.argsort(torch.rand(*indices.shape), dim=-1)[:,:int(n_masked_patch * self.mask_drop)] masked_indices = indices[torch.arange(indices.shape[0]).unsqueeze(-1), rand_selected] random_mask = torch.ones(k, n).to(A.device) random_mask.scatter_(-1, masked_indices, 0) A = A.masked_fill(random_mask == 0, -1e9) A_out = A A = F.softmax(A, dim=1) # softmax over N afeat = torch.mm(A, x) ## K x L outputs = [] for i, head in enumerate(self.classifier): outputs.append(head(afeat[i])) bag_A = F.softmax(A_out, dim=1).mean(0, keepdim=True) bag_feat = torch.mm(bag_A, x) return torch.stack(outputs, dim=0), self.Slide_classifier(bag_feat), A_out.unsqueeze(0) def forward_feature(self, x, use_attention_mask=False): ## x: N x L x = x[0] x = self.dimreduction(x) A = self.attention(x) ## K x N if self.n_masked_patch > 0 and use_attention_mask: # Get the indices of the top-k largest values k, n = A.shape n_masked_patch = min(self.n_masked_patch, n) _, indices = torch.topk(A, n_masked_patch, dim=-1) rand_selected = torch.argsort(torch.rand(*indices.shape), dim=-1)[:,:int(n_masked_patch * self.mask_drop)] masked_indices = indices[torch.arange(indices.shape[0]).unsqueeze(-1), rand_selected] random_mask = torch.ones(k, n).to(A.device) random_mask.scatter_(-1, masked_indices, 0) A = A.masked_fill(random_mask == 0, -1e9) A_out = A bag_A = F.softmax(A_out, dim=1).mean(0, keepdim=True) bag_feat = torch.mm(bag_A, x) return bag_feat class ACMIL_MHA_NoClassifier(nn.Module): def __init__(self, conf, n_token=1, n_masked_patch=0, mask_drop=0): super(ACMIL_MHA_NoClassifier, self).__init__() self.dimreduction = DimReduction(conf.D_feat, conf.D_inner) self.sub_attention = nn.ModuleList([ MutiHeadAttention(conf.D_inner, 8, n_masked_patch=n_masked_patch, mask_drop=mask_drop) for _ in range(n_token) ]) self.bag_attention = MutiHeadAttention_modify(conf.D_inner, 8) self.q = nn.Parameter(torch.zeros((1, n_token, conf.D_inner))) nn.init.normal_(self.q, std=1e-6) self.n_token = n_token def forward(self, input): x = self.dimreduction(input) q = self.q k = x v = x feats = [] attns = [] for i in range(self.n_token): feat_i, attn_i = self.sub_attention[i](q[:, i].unsqueeze(0), k, v) feats.append(feat_i) attns.append(attn_i) attns_tensor = torch.cat(attns, 1) bag_feat = self.bag_attention(v, attns_tensor.softmax(dim=-1).mean(1, keepdim=True)) # Return the raw instance features, bag feature, and attention map return torch.cat(feats, dim=0), bag_feat, attns_tensor class ACMIL_GA_NoClassifier(nn.Module): def __init__(self, conf, D=128, droprate=0, n_token=1, n_masked_patch=0, mask_drop=0): super(ACMIL_GA_NoClassifier, self).__init__() self.dimreduction = DimReduction(conf.D_feat, conf.D_inner) self.attention = Attention_Gated(conf.D_inner, D, n_token) self.n_token = n_token self.n_masked_patch = n_masked_patch self.mask_drop = mask_drop def forward(self, x): # x: list or tuple with x[0] = tensor N x L x = x[0] x_red = self.dimreduction(x) A = self.attention(x_red) # K x N if self.n_masked_patch > 0 and self.training: k, n = A.shape n_masked = min(self.n_masked_patch, n) _, indices = torch.topk(A, n_masked, dim=-1) rand_sel = torch.argsort(torch.rand(*indices.shape), dim=-1) rand_sel = rand_sel[:, :int(n_masked * self.mask_drop)] masked_inds = indices[torch.arange(k).unsqueeze(-1), rand_sel] mask = torch.ones_like(A) mask.scatter_(-1, masked_inds, 0) A = A.masked_fill(mask == 0, -1e9) A_out = A bag_A = F.softmax(A_out, dim=1).mean(0, keepdim=True) bag_feat = torch.mm(bag_A, x_red) # 1 x L # Return token-level attention, bag feature, and raw attention scores return A_out, bag_feat, None # Remove import of Classifier_1fc and related classifier usages.