from math import ceil

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, reduce

# helper functions

def exists(val):
    return val is not None

def moore_penrose_iter_pinv(x, iters = 6):
    # iterative approximation of the Moore-Penrose pseudoinverse (used on the
    # landmark-landmark attention matrix); the initial guess is the transpose
    # scaled by the 1- and infinity-norms, which guarantees the iteration converges
    device = x.device

    abs_x = torch.abs(x)
    col = abs_x.sum(dim = -1)
    row = abs_x.sum(dim = -2)
    z = rearrange(x, '... i j -> ... j i') / (torch.max(col) * torch.max(row))

    I = torch.eye(x.shape[-1], device = device)
    I = rearrange(I, 'i j -> () i j')

    for _ in range(iters):
        xz = x @ z
        z = 0.25 * z @ (13 * I - (xz @ (15 * I - (xz @ (7 * I - xz)))))

    return z
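# Illustrative example (not part of the original file): the iteration above approximates
# torch.linalg.pinv, and the approximation error shrinks toward zero as `iters` grows.
# The default of 6 iterations follows the Nystromformer paper. Sizes below are arbitrary.
def _pinv_example(n = 64, iters = 6):
    a = torch.randn(1, 8, n, n).softmax(dim = -1)       # attention-like, row-stochastic matrices
    approx = moore_penrose_iter_pinv(a, iters = iters)
    exact = torch.linalg.pinv(a)
    return (approx - exact).abs().max()                 # decreases as iters increases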
# main attention class
class NystromAttention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        num_landmarks = 256,
        pinv_iterations = 6,
        residual = True,
        residual_conv_kernel = 33,
        eps = 1e-8,
        dropout = 0.,
        n_token = 1  # number of leading tokens (e.g. a class token) whose attention map is returned when return_attn = True
    ):
        super().__init__()
        self.eps = eps
        inner_dim = heads * dim_head

        self.n_token = n_token
        self.num_landmarks = num_landmarks
        self.pinv_iterations = pinv_iterations

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

        self.residual = residual
        if residual:
            kernel_size = residual_conv_kernel
            padding = residual_conv_kernel // 2
            self.res_conv = nn.Conv2d(heads, heads, (kernel_size, 1), padding = (padding, 0), groups = heads, bias = False)
    def forward(self, x, mask = None, return_attn = False):
        b, n, _, h, m, iters, eps = *x.shape, self.heads, self.num_landmarks, self.pinv_iterations, self.eps

        # pad so that sequence can be evenly divided into m landmarks

        remainder = n % m
        if remainder > 0:
            padding = m - (n % m)
            x = F.pad(x, (0, 0, padding, 0), value = 0)

            if exists(mask):
                mask = F.pad(mask, (padding, 0), value = False)

        # derive queries, keys, values

        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # set masked positions to 0 in queries, keys, values

        if exists(mask):
            mask = rearrange(mask, 'b n -> b () n')
            q, k, v = map(lambda t: t * mask[..., None], (q, k, v))

        q = q * self.scale

        # generate landmarks by sum reduction, and then calculate mean using the mask

        l = ceil(n / m)
        landmark_einops_eq = '... (n l) d -> ... n d'
        q_landmarks = reduce(q, landmark_einops_eq, 'sum', l = l)
        k_landmarks = reduce(k, landmark_einops_eq, 'sum', l = l)

        # calculate landmark mask, and also get sum of non-masked elements in preparation for masked mean

        divisor = l
        if exists(mask):
            mask_landmarks_sum = reduce(mask, '... (n l) -> ... n', 'sum', l = l)
            divisor = mask_landmarks_sum[..., None] + eps
            mask_landmarks = mask_landmarks_sum > 0

        # masked mean (if mask exists)

        q_landmarks /= divisor
        k_landmarks /= divisor

        # similarities

        einops_eq = '... i d, ... j d -> ... i j'
        attn1 = einsum(einops_eq, q, k_landmarks)
        attn2 = einsum(einops_eq, q_landmarks, k_landmarks)
        attn3 = einsum(einops_eq, q_landmarks, k)

        # masking

        if exists(mask):
            mask_value = -torch.finfo(q.dtype).max
            attn1.masked_fill_(~(mask[..., None] * mask_landmarks[..., None, :]), mask_value)
            attn2.masked_fill_(~(mask_landmarks[..., None] * mask_landmarks[..., None, :]), mask_value)
            attn3.masked_fill_(~(mask_landmarks[..., None] * mask[..., None, :]), mask_value)

        # eq (15) in the paper and aggregate values

        attn1, attn2, attn3 = map(lambda t: t.softmax(dim = -1), (attn1, attn2, attn3))
        attn2 = moore_penrose_iter_pinv(attn2, iters)  # pseudoinverse of the landmark-landmark attention

        out = (attn1 @ attn2) @ (attn3 @ v)

        # add depth-wise conv residual of values

        if self.residual:
            out += self.res_conv(v)

        # merge and combine heads

        out = rearrange(out, 'b h n d -> b n (h d)', h = h)
        out = self.to_out(out)
        out = out[:, -n:]  # drop the left padding added above

        if return_attn:
            # approximate attention of the first n_token queries over all positions, averaged over heads
            attn1 = attn1[:, :, :self.n_token] @ attn2
            attn1 = attn1 @ attn3
            return out, attn1.mean(1)

        return out
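# Minimal usage sketch (illustrative, not part of the original file). The sizes below are
# hypothetical; any sequence length works, since forward() left-pads the sequence up to a
# multiple of num_landmarks and strips that padding before returning.
def _nystrom_attention_example():
    attn = NystromAttention(dim = 512, dim_head = 64, heads = 8, num_landmarks = 256, n_token = 1)
    x = torch.randn(1, 1024, 512)                       # (batch, seq_len, dim)
    mask = torch.ones(1, 1024, dtype = torch.bool)      # True marks valid positions
    out = attn(x, mask = mask)                          # (1, 1024, 512)
    # return_attn = True additionally returns the head-averaged approximate attention of the
    # first n_token queries over every position: shape (1, n_token, 1024) here
    out, attn_map = attn(x, mask = mask, return_attn = True)
    return out, attn_map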
# transformer
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x):
        return self.net(x)
class Nystromformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        num_landmarks = 256,
        pinv_iterations = 6,
        attn_values_residual = True,
        attn_values_residual_conv_kernel = 33,
        attn_dropout = 0.,
        ff_dropout = 0.
    ):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, NystromAttention(dim = dim, dim_head = dim_head, heads = heads, num_landmarks = num_landmarks, pinv_iterations = pinv_iterations, residual = attn_values_residual, residual_conv_kernel = attn_values_residual_conv_kernel, dropout = attn_dropout)),
                PreNorm(dim, FeedForward(dim = dim, dropout = ff_dropout))
            ]))

    def forward(self, x, mask = None):
        for attn, ff in self.layers:
            x = attn(x, mask = mask) + x
            x = ff(x) + x
        return x
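
# Quick smoke test (illustrative, not part of the original file): run a tiny Nystromformer
# on random data with a boolean mask. The sizes here are arbitrary choices for the example.
if __name__ == '__main__':
    model = Nystromformer(dim = 512, depth = 2, heads = 8, num_landmarks = 256)
    x = torch.randn(1, 1024, 512)
    mask = torch.ones(1, 1024, dtype = torch.bool)
    y = model(x, mask = mask)
    print(y.shape)  # torch.Size([1, 1024, 512])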