import torch
from torch import nn
import numpy as np


class PPEG(nn.Module):
    """Convolutional position encoding over a token sequence folded into a square grid.

    The (B, N, C) input is padded to a square number of tokens, reshaped to
    (B, C, H, W), passed through three depthwise convolutions (kernel sizes
    ``k``, 5 and 3), and the three outputs are summed with the unmodified
    grid. The padding is stripped again before returning, so the output
    has the same shape as the input.

    Args:
        dim: Number of channels ``C``; also used as ``groups`` so every
            convolution is depthwise.
        k: Kernel size of the first convolution.
        conv_1d: If true, use ``(kernel, 1)`` kernels, i.e. convolve along
            the grid height only.
        bias: Whether the convolutions carry a bias term.
    """

    def __init__(self, dim=512, k=7, conv_1d=False, bias=True):
        super(PPEG, self).__init__()
        if not conv_1d:
            self.proj = nn.Conv2d(dim, dim, k, 1, k // 2, groups=dim, bias=bias)
            self.proj1 = nn.Conv2d(dim, dim, 5, 1, 5 // 2, groups=dim, bias=bias)
            self.proj2 = nn.Conv2d(dim, dim, 3, 1, 3 // 2, groups=dim, bias=bias)
        else:
            self.proj = nn.Conv2d(dim, dim, (k, 1), 1, (k // 2, 0), groups=dim, bias=bias)
            self.proj1 = nn.Conv2d(dim, dim, (5, 1), 1, (5 // 2, 0), groups=dim, bias=bias)
            self.proj2 = nn.Conv2d(dim, dim, (3, 1), 1, (3 // 2, 0), groups=dim, bias=bias)

    def forward(self, x):
        """Apply the encoding to ``x`` of shape (B, N, C); returns (B, N, C)."""
        B, N, C = x.shape

        # Pad to the next square by repeating the first tokens of the sequence.
        H = W = int(np.ceil(np.sqrt(N)))
        add_length = H * W - N
        x = torch.cat([x, x[:, :add_length, :]], dim=1)

        # Grids smaller than 7x7 are additionally zero-padded up to 7x7.
        if H < 7:
            H, W = 7, 7
            zero_pad = H * W - (N + add_length)
            # Match the input dtype: a default float32 pad would promote a
            # half/bfloat16 sequence and then mismatch the conv weights.
            zeros = torch.zeros((B, zero_pad, C), device=x.device, dtype=x.dtype)
            x = torch.cat([x, zeros], dim=1)
            add_length += zero_pad

        cnn_feat = x.transpose(1, 2).view(B, C, H, W)
        x = self.proj(cnn_feat) + cnn_feat + self.proj1(cnn_feat) + self.proj2(cnn_feat)
        x = x.flatten(2).transpose(1, 2)

        # Drop every padded token so the output length equals N again.
        if add_length > 0:
            x = x[:, :-add_length]
        return x


class PEG(nn.Module):
    """Single-convolution variant of the grid position encoding.

    The (B, N, C) input is padded to a square number of tokens, reshaped to
    (B, C, H, W), and the output of one depthwise convolution is added to
    the grid. Padding is stripped before returning.

    Args:
        dim: Number of channels ``C``; also used as ``groups``.
        k: Kernel size of the convolution.
        bias: Whether the convolution carries a bias term.
        conv_1d: If true, use a ``(k, 1)`` kernel instead of ``k x k``.
    """

    def __init__(self, dim=512, k=7, bias=True, conv_1d=False):
        super(PEG, self).__init__()
        if not conv_1d:
            self.proj = nn.Conv2d(dim, dim, k, 1, k // 2, groups=dim, bias=bias)
        else:
            self.proj = nn.Conv2d(dim, dim, (k, 1), 1, (k // 2, 0), groups=dim, bias=bias)

    def forward(self, x):
        """Apply the encoding to ``x`` of shape (B, N, C); returns (B, N, C)."""
        B, N, C = x.shape

        # Pad to the next square by repeating the first tokens of the sequence.
        H = W = int(np.ceil(np.sqrt(N)))
        add_length = H * W - N
        x = torch.cat([x, x[:, :add_length, :]], dim=1)

        cnn_feat = x.transpose(1, 2).view(B, C, H, W)
        x = self.proj(cnn_feat) + cnn_feat
        x = x.flatten(2).transpose(1, 2)

        if add_length > 0:
            x = x[:, :-add_length]
        return x


class SINCOS(nn.Module):
    """Adds a fixed 2D sine/cosine position table to a 4D input.

    The table is computed once at construction for an 8x8 grid and stored
    as a NumPy array of shape (64, embed_dim) in ``self.pos_embed``.

    Args:
        embed_dim: Width of each position vector; must be divisible by 4
            (it is halved per axis and each half is split into sin/cos).
    """

    def __init__(self, embed_dim=512):
        super(SINCOS, self).__init__()
        self.embed_dim = embed_dim
        self.pos_embed = self.get_2d_sincos_pos_embed(embed_dim, 8)

    def get_1d_sincos_pos_embed_from_grid(self, embed_dim, pos):
        """
        embed_dim: output dimension for each position
        pos: a list of positions to be encoded: size (M,)
        out: (M, D)
        """
        assert embed_dim % 2 == 0
        # np.float was removed in NumPy 1.24; float64 is what it aliased.
        omega = np.arange(embed_dim // 2, dtype=np.float64)
        omega /= embed_dim / 2.
        omega = 1. / 10000**omega  # (D/2,)

        pos = pos.reshape(-1)  # (M,)
        out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

        emb_sin = np.sin(out)  # (M, D/2)
        emb_cos = np.cos(out)  # (M, D/2)

        emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
        return emb

    def get_2d_sincos_pos_embed_from_grid(self, embed_dim, grid):
        """Encode both grid axes with half of ``embed_dim`` each and concatenate."""
        assert embed_dim % 2 == 0

        # grid[0] fills the first half of the channels, grid[1] the second.
        emb_h = self.get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
        emb_w = self.get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

        emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
        return emb

    def get_2d_sincos_pos_embed(self, embed_dim, grid_size, cls_token=False):
        """
        grid_size: int of the grid height and width
        return:
        pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
        """
        grid_h = np.arange(grid_size, dtype=np.float32)
        grid_w = np.arange(grid_size, dtype=np.float32)
        grid = np.meshgrid(grid_w, grid_h)  # here w goes first
        grid = np.stack(grid, axis=0)

        grid = grid.reshape([2, 1, grid_size, grid_size])
        pos_embed = self.get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
        if cls_token:
            # A zero vector is prepended as the position of the extra token.
            pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
        return pos_embed

    def forward(self, x):
        """Add the stored table to ``x`` of shape (B, H, W, C).

        The (64, C) table is expanded to (64, H, W, C) and added with
        broadcasting, so every spatial location of entry ``i`` along the
        first dimension receives table row ``i``.

        NOTE(review): the addition only broadcasts when B is 1 or 64 and C
        matches ``embed_dim``; presumably the first dimension indexes the
        8x8 grid positions -- confirm against the caller.
        """
        B, H, W, C = x.shape

        pos_embed = torch.from_numpy(self.pos_embed).float().to(x.device)
        x = x + pos_embed.unsqueeze(1).unsqueeze(1).repeat(1, H, W, 1)
        return x