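# MHIM-style masked hard instance mining for multiple instance learning (MIL):
# an AB-MIL (DAttention) encoder classifies bags of instance features while
# high-/low-attention instances are masked during training. Bags are expected
# as [1, N, 1024] feature tensors.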
import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
class DAttention(nn.Module):
    """Attention-based MIL pooling (AB-MIL): embeds instances, scores each one,
    softmax-pools the bag, and maps the pooled embedding to n_classes outputs."""

    def __init__(self, n_classes, dropout, act, input_dim=1024):
        super(DAttention, self).__init__()
        self.L = 512  # instance embedding dim
        self.D = 128  # attention hidden dim
        self.K = 1    # number of attention branches

        self.feature = [nn.Linear(input_dim, 512)]
        if act.lower() == 'gelu':
            self.feature += [nn.GELU()]
        else:
            self.feature += [nn.ReLU()]
        if dropout:
            self.feature += [nn.Dropout(0.25)]
        self.feature = nn.Sequential(*self.feature)

        self.attention = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Tanh(),
            nn.Linear(self.D, self.K)
        )
        self.classifier = nn.Sequential(
            nn.Linear(self.L * self.K, n_classes),
        )
        self.apply(initialize_weights)
    def forward(self, x, return_attn=False, no_norm=False,
                mask_enable=False, mask_ids=None, len_keep=None):
        feature = self.feature(x)
        # Optionally drop masked instances: mask_ids is a permutation of the bag
        # indices with the len_keep kept instances first (see select_mask_fn).
        if mask_enable and mask_ids is not None:
            ids_keep = mask_ids[:, :len_keep]
            feature = torch.gather(
                feature, dim=1,
                index=ids_keep.unsqueeze(-1).repeat(1, 1, feature.size(-1)))
        feature = feature.squeeze(0)
        A = self.attention(feature)     # N x K instance scores
        A_ori = A.clone()               # unnormalized scores
        A = torch.transpose(A, -1, -2)  # K x N
        A = F.softmax(A, dim=-1)        # softmax over the N instances
        M = torch.mm(A, feature)        # K x L pooled bag embedding
        Y_prob = self.classifier(M)
        if return_attn:
            if no_norm:
                return Y_prob, A_ori
            else:
                return Y_prob, A
        else:
            return Y_prob
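# Usage sketch for the standalone encoder (shapes assumed, not from the original
# file): one bag of 100 instances with 1024-d features.
#   enc = DAttention(n_classes=2, dropout=True, act='gelu')
#   logits, attn = enc(torch.randn(1, 100, 1024), return_attn=True)
#   # logits: [1, 2]; attn: [1, 100] softmax-normalized instance weights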
def initialize_weights(module):
    """Xavier-init Linear layers; standard init for LayerNorm."""
    for m in module.modules():
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
class SoftTargetCrossEntropy_v2(nn.Module):
    """Cross-entropy against a softened target distribution, with separate
    temperatures for the target (teacher) and input (student) logits."""

    def __init__(self, temp_t=1., temp_s=1.):
        super(SoftTargetCrossEntropy_v2, self).__init__()
        self.temp_t = temp_t
        self.temp_s = temp_s

    def forward(self, x: torch.Tensor, target: torch.Tensor, mean: bool = True) -> torch.Tensor:
        loss = torch.sum(-F.softmax(target / self.temp_t, dim=-1)
                         * F.log_softmax(x / self.temp_s, dim=-1), dim=-1)
        if mean:
            return loss.mean()
        else:
            return loss
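# Numeric sketch (assumed values): with both temperatures at 1 and x == target,
# the loss is the entropy of softmax(target); for a uniform 2-way softmax:
#   crit = SoftTargetCrossEntropy_v2()
#   t = torch.tensor([[0., 0.]])
#   crit(t, t)  # -> ln(2) ~= 0.6931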
class MHIM(nn.Module):
    """Masked hard instance mining MIL: the online AB-MIL encoder is trained on
    attention-masked bags and optionally distilled against teacher bag features."""

    def __init__(self, mlp_dim=512, mask_ratio=0, n_classes=2, temp_t=1., temp_s=1.,
                 dropout=0.25, act='relu', select_mask=True, select_inv=False,
                 msa_fusion='vote', mask_ratio_h=0., mrh_sche=None,
                 mask_ratio_hr=0., mask_ratio_l=0., da_act='gelu',
                 baseline='selfattn', head=8, attn_layer=0):
        super(MHIM, self).__init__()
        self.mask_ratio = mask_ratio        # random masking ratio
        self.mask_ratio_h = mask_ratio_h    # high-attention (hard instance) masking ratio
        self.mask_ratio_hr = mask_ratio_hr  # random subsampling ratio within the high-attention mask
        self.mask_ratio_l = mask_ratio_l    # low-attention masking ratio
        self.select_mask = select_mask
        self.select_inv = select_inv        # if True, keep the selected instances instead of dropping them
        self.msa_fusion = msa_fusion        # multi-head attention fusion: 'vote' or 'mean'
        self.mrh_sche = mrh_sche            # optional per-iteration schedule for mask_ratio_h
        self.attn_layer = attn_layer        # which attention map to use when several are returned
        # baseline and head are accepted for config compatibility but unused here

        self.patch_to_emb = [nn.Linear(1024, 512)]
        if act.lower() == 'relu':
            self.patch_to_emb += [nn.ReLU()]
        elif act.lower() == 'gelu':
            self.patch_to_emb += [nn.GELU()]
        self.dp = nn.Dropout(dropout) if dropout > 0. else nn.Identity()
        self.patch_to_emb = nn.Sequential(*self.patch_to_emb)

        # instances are already embedded to 512-d by patch_to_emb
        self.online_encoder = DAttention(mlp_dim, dropout, da_act, input_dim=512)
        self.predictor = nn.Linear(mlp_dim, n_classes)

        self.temp_t = temp_t
        self.temp_s = temp_s
        self.cl_loss = SoftTargetCrossEntropy_v2(self.temp_t, self.temp_s)
        self.predictor_cl = nn.Identity()      # unused identity placeholder
        self.target_predictor = nn.Identity()  # unused identity placeholder

        self.apply(initialize_weights)
    def select_mask_fn(self, ps, attn, largest, mask_ratio, mask_ids_other=None,
                       len_keep_other=None, cls_attn_topk_idx_other=None,
                       random_ratio=1., select_inv=False):
        """Select instances to mask according to their attention scores.

        Returns (len_keep, mask_ids), where mask_ids is a [1, ps] permutation of
        the bag indices with the len_keep kept indices first.
        """
        ps_tmp = ps
        mask_ratio_ori = mask_ratio
        # guard: a non-positive random_ratio would divide by zero below; treat
        # it as "no random subsampling"
        if random_ratio <= 0.:
            random_ratio = 1.
        # over-select by 1/random_ratio, then randomly subsample back down
        mask_ratio = mask_ratio / random_ratio
        if mask_ratio > 1:
            random_ratio = mask_ratio_ori
            mask_ratio = 1.
        # exclude indices already masked in a previous round
        if mask_ids_other is not None:
            if cls_attn_topk_idx_other is None:
                cls_attn_topk_idx_other = mask_ids_other[:, len_keep_other:].squeeze()
            ps_tmp = ps - cls_attn_topk_idx_other.size(0)
        if len(attn.size()) > 2:
            # multi-head attention [B, H, N]: fuse the per-head selections
            if self.msa_fusion == 'mean':
                _, cls_attn_topk_idx = torch.topk(
                    attn, int(np.ceil(ps_tmp * mask_ratio) // attn.size(1)),
                    largest=largest)
                cls_attn_topk_idx = torch.unique(cls_attn_topk_idx.flatten(-3, -1))
            elif self.msa_fusion == 'vote':
                # each head votes for its top-k instances; keep the most-voted
                vote = attn.clone()
                vote[:] = 0
                _, idx = torch.topk(attn, k=int(np.ceil(ps_tmp * mask_ratio)),
                                    sorted=False, largest=largest)
                mask = vote.clone()
                mask = mask.scatter_(2, idx, 1) == 1
                vote[mask] = 1
                vote = vote.sum(dim=1)
                _, cls_attn_topk_idx = torch.topk(
                    vote, k=int(np.ceil(ps_tmp * mask_ratio)), sorted=False)
                cls_attn_topk_idx = cls_attn_topk_idx[0]
        else:
            # single attention map [1, N]
            k = int(np.ceil(ps_tmp * mask_ratio))
            _, cls_attn_topk_idx = torch.topk(attn, k, largest=largest)
            cls_attn_topk_idx = cls_attn_topk_idx.squeeze(0)
        # randomly subsample the selection back to the requested ratio
        if random_ratio < 1.:
            random_idx = torch.randperm(cls_attn_topk_idx.size(0),
                                        device=cls_attn_topk_idx.device)
            cls_attn_topk_idx = torch.gather(
                cls_attn_topk_idx, dim=0,
                index=random_idx[:int(np.ceil(cls_attn_topk_idx.size(0) * random_ratio))])
        # merge with the indices masked in previous rounds
        if mask_ids_other is not None:
            cls_attn_topk_idx = torch.cat([cls_attn_topk_idx, cls_attn_topk_idx_other]).unique()
        len_keep = ps - cls_attn_topk_idx.size(0)
        a = set(cls_attn_topk_idx.tolist())
        b = set(range(ps))
        mask_ids = torch.tensor(list(b.difference(a)), device=attn.device)
        if select_inv:
            # keep the selected instances and drop the rest
            mask_ids = torch.cat([cls_attn_topk_idx, mask_ids]).unsqueeze(0)
            len_keep = ps - len_keep
        else:
            # keep the rest and drop the selected instances
            mask_ids = torch.cat([mask_ids, cls_attn_topk_idx]).unsqueeze(0)
        return len_keep, mask_ids
    def get_mask(self, ps, i, attn, mrh=None):
        """Compose random, low-attention, and high-attention masks for a bag of
        ps instances at training iteration i."""
        if attn is not None and isinstance(attn, (list, tuple)):
            attn = attn[1] if self.attn_layer == -1 else attn[self.attn_layer]
        # random mask (implemented as near-total selection followed by random
        # subsampling, via the small random_ratio)
        if attn is not None and self.mask_ratio > 0.:
            len_keep, mask_ids = self.select_mask_fn(
                ps, attn, False, self.mask_ratio,
                select_inv=self.select_inv, random_ratio=0.001)
        else:
            len_keep, mask_ids = ps, None
        # low-attention mask
        if attn is not None and self.mask_ratio_l > 0.:
            if mask_ids is None:
                len_keep, mask_ids = self.select_mask_fn(
                    ps, attn, False, self.mask_ratio_l, select_inv=self.select_inv)
            else:
                cls_attn_topk_idx_other = (mask_ids[:, :len_keep].squeeze()
                                           if self.select_inv else mask_ids[:, len_keep:].squeeze())
                len_keep, mask_ids = self.select_mask_fn(
                    ps, attn, False, self.mask_ratio_l, select_inv=self.select_inv,
                    mask_ids_other=mask_ids, len_keep_other=ps,
                    cls_attn_topk_idx_other=cls_attn_topk_idx_other)
        # high-attention (hard instance) mask
        mask_ratio_h = self.mask_ratio_h
        if self.mrh_sche is not None:
            mask_ratio_h = self.mrh_sche[i]
        if mrh is not None:
            mask_ratio_h = mrh
        if mask_ratio_h > 0.:
            # mask high-confidence patches
            if mask_ids is None:
                len_keep, mask_ids = self.select_mask_fn(
                    ps, attn, largest=True, mask_ratio=mask_ratio_h,
                    len_keep_other=ps, random_ratio=self.mask_ratio_hr,
                    select_inv=self.select_inv)
            else:
                cls_attn_topk_idx_other = (mask_ids[:, :len_keep].squeeze()
                                           if self.select_inv else mask_ids[:, len_keep:].squeeze())
                len_keep, mask_ids = self.select_mask_fn(
                    ps, attn, largest=True, mask_ratio=mask_ratio_h,
                    mask_ids_other=mask_ids, len_keep_other=ps,
                    cls_attn_topk_idx_other=cls_attn_topk_idx_other,
                    random_ratio=self.mask_ratio_hr, select_inv=self.select_inv)
        return len_keep, mask_ids
    @torch.no_grad()
    def forward_teacher(self, x, return_attn=False):
        x = self.patch_to_emb(x)
        x = self.dp(x)
        if return_attn:
            x, attn = self.online_encoder(x, return_attn=True)
        else:
            x = self.online_encoder(x)
            attn = None
        return x, attn

    @torch.no_grad()
    def forward_test(self, x, return_attn=False, no_norm=False):
        x = self.patch_to_emb(x)
        x = self.dp(x)
        if return_attn:
            x, a = self.online_encoder(x, return_attn=True, no_norm=no_norm)
        else:
            x = self.online_encoder(x)
        x = self.predictor(x)
        if return_attn:
            return x, a
        else:
            return x
    def pure(self, x, return_attn=False):
        # plain forward pass without any masking
        x = self.patch_to_emb(x)
        x = self.dp(x)
        ps = x.size(1)
        if return_attn:
            x, attn = self.online_encoder(x, return_attn=True)
        else:
            x = self.online_encoder(x)
        x = self.predictor(x)
        if self.training:
            # mirror forward()'s (logits, cls_loss, ps, len_keep) return shape
            if return_attn:
                return x, 0, ps, ps, attn
            else:
                return x, 0, ps, ps
        else:
            if return_attn:
                return x, attn
            else:
                return x
    def forward_loss(self, student_cls_feat, teacher_cls_feat):
        if teacher_cls_feat is not None:
            cls_loss = self.cl_loss(student_cls_feat, teacher_cls_feat.detach())
        else:
            cls_loss = 0.
        return cls_loss
    def forward(self, x, attn=None, teacher_cls_feat=None, i=None):
        x = self.patch_to_emb(x)
        x = self.dp(x)
        ps = x.size(1)
        # build the instance mask from the (teacher) attention
        if self.select_mask:
            len_keep, mask_ids = self.get_mask(ps, i, attn)
        else:
            len_keep, mask_ids = ps, None
        # forward the online network on the masked bag
        student_cls_feat = self.online_encoder(x, len_keep=len_keep,
                                               mask_ids=mask_ids, mask_enable=True)
        # class prediction
        student_logit = self.predictor(student_cls_feat)
        # distillation loss against the teacher bag feature
        cls_loss = self.forward_loss(student_cls_feat=student_cls_feat,
                                     teacher_cls_feat=teacher_cls_feat)
        return student_logit, cls_loss, ps, len_keep
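
# Smoke-test sketch (bag size, seed, and mask ratios are assumed values): a
# no-grad teacher pass supplies the attention map and bag feature, then the
# student runs on the high-attention-masked bag.
if __name__ == '__main__':
    torch.manual_seed(0)
    model = MHIM(mlp_dim=512, n_classes=2, mask_ratio_h=0.03, mask_ratio_hr=1.)
    model.train()
    bag = torch.randn(1, 100, 1024)  # one bag of 100 instance features
    teacher_feat, attn = model.forward_teacher(bag, return_attn=True)
    logits, cls_loss, ps, len_keep = model(bag, attn=attn, teacher_cls_feat=teacher_feat)
    print(logits.shape, float(cls_loss), ps, len_keep)  # [1, 2], scalar, 100, 97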