import torch
import torch.nn as nn
import torch.nn.functional as F

from architecture.network import Classifier_1fc


class Attention2(nn.Module):
    """Plain (non-gated) attention pooling: scores each of the N instance
    embeddings and normalises the scores with a softmax over instances."""

    def __init__(self, L=512, D=128, K=1):
        super(Attention2, self).__init__()
        self.L = L
        self.D = D
        self.K = K
        self.attention = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Tanh(),
            nn.Linear(self.D, self.K)
        )

    def forward(self, x, isNorm=True):
        ## x: N x L
        A = self.attention(x)            # N x K
        A = torch.transpose(A, 1, 0)     # K x N
        if isNorm:
            A = F.softmax(A, dim=1)      # softmax over N
        return A  ## K x N


class Attention_Gated(nn.Module):
    """Gated attention pooling: a tanh branch is gated element-wise by a
    sigmoid branch before the final scoring layer."""

    def __init__(self, L=512, D=128, K=1):
        super(Attention_Gated, self).__init__()
        self.L = L
        self.D = D
        self.K = K
        self.attention_V = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Tanh()
        )
        self.attention_U = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Sigmoid()
        )
        self.attention_weights = nn.Linear(self.D, self.K)

    def forward(self, x, isNorm=True):
        ## x: N x L
        A_V = self.attention_V(x)                # N x D
        A_U = self.attention_U(x)                # N x D
        A = self.attention_weights(A_V * A_U)    # N x K
        A = torch.transpose(A, 1, 0)             # K x N
        if isNorm:
            A = F.softmax(A, dim=1)              # softmax over N
        return A  ## K x N


class Attention_with_Classifier(nn.Module):
    """Attention-pools the N instance embeddings into K bag features,
    then classifies the pooled features with a single-layer classifier."""

    def __init__(self, L=512, D=128, K=1, num_cls=2, droprate=0):
        super(Attention_with_Classifier, self).__init__()
        self.attention = Attention_Gated(L, D, K)
        self.classifier = Classifier_1fc(L, num_cls, droprate)

    def forward(self, x):
        ## x: N x L
        AA = self.attention(x)           # K x N attention weights
        afeat = torch.mm(AA, x)          # K x L pooled bag feature(s)
        pred = self.classifier(afeat)    # K x num_cls logits
        return pred
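

# --- Usage sketch (added for illustration, not part of the original module) ---
# A minimal smoke test for the gated attention pooling above, assuming a bag of
# N=100 instance embeddings of dimension L=512. It stands in for Classifier_1fc
# with a plain nn.Linear, since architecture.network is an external dependency
# whose exact definition is not shown here.
if __name__ == "__main__":
    N, L, D, K, num_cls = 100, 512, 128, 1, 2
    bag = torch.randn(N, L)                      # N instance embeddings

    attention = Attention_Gated(L, D, K)
    A = attention(bag)                           # K x N attention weights
    assert A.shape == (K, N)
    assert torch.allclose(A.sum(dim=1), torch.ones(K))  # rows sum to 1 (softmax over N)

    bag_feature = torch.mm(A, bag)               # K x L pooled bag feature(s)
    stand_in_classifier = nn.Linear(L, num_cls)  # hypothetical stand-in for Classifier_1fc
    logits = stand_in_classifier(bag_feature)    # K x num_cls
    print(logits.shape)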