Add model directory

- models/ACMDM.py +437 -0
- models/ACMDM_ControlNet.py +314 -0
- models/ACMDM_NoisyPrefix_AR.py +556 -0
- models/ACMDM_Prefix_AR.py +434 -0
- models/AE_2D_Causal.py +245 -0
- models/AE_2D_NonCausal.py +228 -0
- models/AE_Mesh.py +601 -0
- models/LengthEstimator.py +40 -0
- models/ROPE.py +91 -0
- models/__pycache__/ACMDM.cpython-310.pyc +0 -0
- models/__pycache__/ACMDM.cpython-313.pyc +0 -0
- models/__pycache__/AE_2D_Causal.cpython-310.pyc +0 -0
- models/__pycache__/AE_2D_Causal.cpython-313.pyc +0 -0
- models/__pycache__/LengthEstimator.cpython-310.pyc +0 -0
- models/__pycache__/ROPE.cpython-310.pyc +0 -0
models/ACMDM.py
ADDED
@@ -0,0 +1,437 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import clip
import math
from functools import partial
from timm.models.vision_transformer import Attention
from models.ROPE import RopeND
from utils.eval_utils import eval_decorator
from utils.train_utils import lengths_to_mask
from diffusions.diffusion import create_diffusion
from diffusions.transport import create_transport, Sampler

#################################################################################
#                                     ACMDM                                     #
#################################################################################
class ACMDM(nn.Module):
    def __init__(self, input_dim, cond_mode, latent_dim=256, ff_size=1024, num_layers=8,
                 num_heads=4, dropout=0, clip_dim=512,
                 diff_model='Flow', cond_drop_prob=0.1, max_length=49,
                 patch_size=(1, 22), stride_size=(1, 22), num_joint=22,
                 clip_version='ViT-B/32', **kargs):
        super(ACMDM, self).__init__()

        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.clip_dim = clip_dim
        self.dropout = dropout

        self.cond_mode = cond_mode
        self.cond_drop_prob = cond_drop_prob

        if self.cond_mode == 'action':
            assert 'num_actions' in kargs
            self.num_actions = kargs.get('num_actions', 1)
            self.encode_action = partial(F.one_hot, num_classes=self.num_actions)
        # --------------------------------------------------------------------------
        # Diffusion
        self.diff_model = diff_model
        if self.diff_model == 'Flow':
            self.train_diffusion = create_transport()  # default to linear, velocity prediction
            self.gen_diffusion = Sampler(self.train_diffusion)
        else:
            self.train_diffusion = create_diffusion(timestep_respacing="", noise_schedule="linear")
            self.gen_diffusion = create_diffusion(timestep_respacing="", noise_schedule="linear")
        # --------------------------------------------------------------------------
        # ACMDM
        print('Loading ACMDM...')
        self.t_embedder = TimestepEmbedder(self.latent_dim)
        self.patch_size = patch_size
        self.stride_size = stride_size
        self.patches_per_frame = (num_joint - patch_size[1]) // stride_size[1] + 1

        # Patchification
        self.x_embedder = nn.Conv2d(self.input_dim, self.latent_dim, kernel_size=self.patch_size, stride=self.stride_size, bias=True)

        # Positional Encoding
        max_length = max_length * self.patches_per_frame
        self.max_lens = [max_length]
        self.rope = RopeND(nd=1, nd_split=[1], max_lens=self.max_lens)
        self.position_ids_precompute = torch.arange(max_length).unsqueeze(0)

        self.ACMDMTransformer = nn.ModuleList([
            ACMDMTransBlock(self.latent_dim, num_heads, mlp_size=ff_size, rope=self.rope, qk_norm=True) for _ in range(num_layers)
        ])

        if self.cond_mode == 'text':
            self.y_embedder = nn.Linear(self.clip_dim, self.latent_dim)
        elif self.cond_mode == 'action':
            self.y_embedder = nn.Linear(self.num_actions, self.latent_dim)
        elif self.cond_mode == 'uncond':
            self.y_embedder = nn.Identity()
        else:
            raise KeyError("Unsupported condition mode!!!")

        self.final_layer = FinalLayer(self.latent_dim, self.input_dim, patch_size=patch_size, stride_size=stride_size, patches=self.patches_per_frame, joint=num_joint)

        self.initialize_weights()

        if self.cond_mode == 'text':
            print('Loading CLIP...')
            self.clip_version = clip_version
            self.clip_model = self.load_and_freeze_clip(clip_version)

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in ACMDM blocks:
        for block in self.ACMDMTransformer:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def load_and_freeze_clip(self, clip_version):
        clip_model, clip_preprocess = clip.load(clip_version, device='cpu', jit=False)
        assert torch.cuda.is_available()
        clip.model.convert_weights(clip_model)

        clip_model.eval()
        for p in clip_model.parameters():
            p.requires_grad = False
        return clip_model

    def encode_text(self, raw_text):
        device = next(self.parameters()).device
        text = clip.tokenize(raw_text, truncate=True).to(device)
        feat_clip_text = self.clip_model.encode_text(text).float()
        return feat_clip_text

    def mask_cond(self, cond, force_mask=False):
        bs, d = cond.shape
        if force_mask:
            return torch.zeros_like(cond)
        elif self.training and self.cond_drop_prob > 0.:
            mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_drop_prob).view(bs, 1)
            return cond * (1. - mask)
        else:
            return cond

    def forward(self, x, t, conds, attention_mask, force_mask=False):
        # x: (b, input_dim, l, num_joint) -> patch tokens (b, l * patches_per_frame, latent_dim)
        t = self.t_embedder(t, dtype=x.dtype)
        conds = self.mask_cond(conds, force_mask=force_mask)
        x = self.x_embedder(x)
        x = x.flatten(2).transpose(1, 2)
        conds = self.y_embedder(conds)
        y = t.unsqueeze(1) + conds.unsqueeze(1)
        position_ids = self.position_ids_precompute[:, :x.shape[1]]
        for block in self.ACMDMTransformer:
            x = block(x, y, attention_mask, position_ids=position_ids)
        x = self.final_layer(x, y)
        return x

    def forward_with_CFG(self, x, t, conds, attention_mask, cfg=1.0):
        if not cfg == 1.0:
            half = x[: len(x) // 2]
            x = torch.cat([half, half], dim=0)
        x = self.forward(x, t, conds, attention_mask)
        if not cfg == 1.0:
            cond_eps, uncond_eps = torch.split(x, len(x) // 2, dim=0)
            half_eps = uncond_eps + cfg * (cond_eps - uncond_eps)
            x = torch.cat([half_eps, half_eps], dim=0)
        return x

    def forward_loss(self, latents, y, m_lens):
        latents = latents.permute(0, 2, 3, 1)
        b, l, j, d = latents.shape
        device = latents.device

        non_pad_mask = lengths_to_mask(m_lens, l)
        latents = torch.where(non_pad_mask.unsqueeze(-1).unsqueeze(-1), latents, torch.zeros_like(latents))

        target = latents.clone().permute(0, 3, 1, 2).detach()

        force_mask = False
        if self.cond_mode == 'text':
            with torch.no_grad():
                cond_vector = self.encode_text(y)
        elif self.cond_mode == 'action':
            cond_vector = self.encode_action(y).to(device).float()
        elif self.cond_mode == 'uncond':
            cond_vector = torch.zeros(b, self.latent_dim).float().to(device)
            force_mask = True
        else:
            raise NotImplementedError("Unsupported condition mode!!!")

        attention_mask = non_pad_mask.unsqueeze(-1).repeat(1, 1, self.patches_per_frame).flatten(1).unsqueeze(1).unsqueeze(1)

        model_kwargs = dict(conds=cond_vector, force_mask=force_mask, attention_mask=attention_mask)
        if self.diff_model == "Flow":
            loss_dict = self.train_diffusion.training_losses(self.forward, target, model_kwargs)
        else:
            t = torch.randint(0, self.train_diffusion.num_timesteps, (target.shape[0],), device=target.device)
            loss_dict = self.train_diffusion.training_losses(self.forward, target, t, model_kwargs)
        loss = loss_dict["loss"]
        loss = (loss * non_pad_mask).sum() / non_pad_mask.sum()

        return loss

    @torch.no_grad()
    @eval_decorator
    def generate(self,
                 conds,
                 m_lens,
                 cond_scale: int,
                 temperature=1,
                 j=22,
                 ):
        device = next(self.parameters()).device
        l = max(m_lens)
        b = len(m_lens)

        if self.cond_mode == 'text':
            with torch.no_grad():
                cond_vector = self.encode_text(conds)
        elif self.cond_mode == 'action':
            cond_vector = self.encode_action(conds).to(device)
        elif self.cond_mode == 'uncond':
            cond_vector = torch.zeros(b, self.latent_dim).float().to(device)
        else:
            raise NotImplementedError("Unsupported condition mode!!!")

        padding_mask = ~lengths_to_mask(m_lens, l)

        noise = torch.randn(b, self.input_dim, l, j).to(device)
        if not cond_scale == 1.0:
            cond_vector = torch.cat([cond_vector, torch.zeros_like(cond_vector)], dim=0)
            noise = torch.cat([noise, noise], dim=0)

        attention_mask = (~padding_mask).unsqueeze(-1).repeat(1, 1, self.patches_per_frame).flatten(1).unsqueeze(1).unsqueeze(1)
        model_kwargs = dict(conds=cond_vector, attention_mask=attention_mask, cfg=cond_scale)
        sample_fn = self.forward_with_CFG

        if not cond_scale == 1:
            model_kwargs["attention_mask"] = attention_mask.repeat(2, 1, 1, 1)

        if self.diff_model == "Flow":
            model_fn = self.gen_diffusion.sample_ode()  # default to ode sampling
            sampled_token_latent = model_fn(noise, sample_fn, **model_kwargs)[-1]
        else:
            sampled_token_latent = self.gen_diffusion.p_sample_loop(
                sample_fn, noise.shape, noise, clip_denoised=False, model_kwargs=model_kwargs,
                progress=False,
                temperature=temperature
            )
        if not cond_scale == 1:
            sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0)
        sampled_token_latent = sampled_token_latent.permute(0, 2, 3, 1)

        latents = torch.where(padding_mask.unsqueeze(-1).unsqueeze(-1), torch.zeros_like(sampled_token_latent), sampled_token_latent)
        return latents.permute(0, 3, 1, 2)

#################################################################################
#                                  ACMDM Zoos                                   #
#################################################################################
def acmdm_raw_flow_s_ps22(**kwargs):
    layer = 8
    return ACMDM(latent_dim=layer*64, ff_size=layer*64*4, num_layers=layer, num_heads=layer, dropout=0, clip_dim=512,
                 diff_model="Flow", cond_drop_prob=0.1, max_length=196,
                 patch_size=(1, 22), stride_size=(1, 22), **kwargs)

def acmdm_flow_s_ps22(**kwargs):
    layer = 8
    return ACMDM(latent_dim=layer*64, ff_size=layer*64*4, num_layers=layer, num_heads=layer, dropout=0, clip_dim=512,
                 diff_model="Flow", cond_drop_prob=0.1, max_length=49,
                 patch_size=(1, 22), stride_size=(1, 22), **kwargs)

def acmdm_flow_xl_ps2(**kwargs):
    layer = 20
    return ACMDM(latent_dim=layer*64, ff_size=layer*64*4, num_layers=layer, num_heads=layer, dropout=0, clip_dim=512,
                 diff_model="Flow", cond_drop_prob=0.1, max_length=49,
                 patch_size=(1, 2), stride_size=(1, 2), **kwargs)

def acmdm_mesh_flow_s_ps28(**kwargs):
    layer = 8
    return ACMDM(latent_dim=layer*64, ff_size=layer*64*4, num_layers=layer, num_heads=layer, dropout=0, clip_dim=512,
                 diff_model="Flow", cond_drop_prob=0.1, max_length=196, num_joint=28,
                 patch_size=(1, 28), stride_size=(1, 28), **kwargs)

ACMDM_models = {
    'ACMDM-Raw-Flow-S-PatchSize22': acmdm_raw_flow_s_ps22, 'ACMDM-Flow-S-PatchSize22': acmdm_flow_s_ps22,
    'ACMDM-Flow-XL-PatchSize2': acmdm_flow_xl_ps2, 'ACMDM-Mesh-Flow-S-PatchSize28': acmdm_mesh_flow_s_ps28,
}

#################################################################################
#                             Inner Architectures                               #
#################################################################################
def modulate(x, shift, scale):
    return x * (1 + scale) + shift


class ACMDMAttention(Attention):
    def __init__(
            self,
            dim,
            num_heads=8,
            qkv_bias=True,
            rope=None,
            qk_norm=True,
            **block_kwargs,
    ):
        super().__init__(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm, **block_kwargs)
        self.rope = rope

    def forward(self, x, position_ids=None, attention_mask=None):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.rope is not None:
            q, k = self.rope(q, k, position_ids)

        x = torch.nn.functional.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attention_mask,
            dropout_p=self.attn_drop.p
        )
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class SwiGLUFFN(nn.Module):
    def __init__(
            self,
            in_features: int,
            hidden_features,
            bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = in_features
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x):
        x12 = self.w12(x)
        x1, x2 = x12.chunk(2, dim=-1)
        hidden = F.silu(x1) * x2
        return self.w3(hidden)


class ACMDMTransBlock(nn.Module):
    def __init__(self, hidden_size, num_heads, mlp_size=1024, rope=None, qk_norm=True):
        super().__init__()
        self.norm1 = LlamaRMSNorm(hidden_size, eps=1e-6)
        self.attn = ACMDMAttention(hidden_size, num_heads=num_heads, qkv_bias=True, norm_layer=LlamaRMSNorm,
                                   qk_norm=qk_norm, rope=rope)
        self.norm2 = LlamaRMSNorm(hidden_size, eps=1e-6)
        self.mlp = SwiGLUFFN(hidden_size, int(2 / 3 * mlp_size))
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c, attention_mask=None, position_ids=None):
        dtype = x.dtype
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
        norm_x1 = self.norm1(x.to(torch.float32)).to(dtype)
        attn_input_x = modulate(norm_x1, shift_msa, scale_msa)
        attn_output_x = self.attn(attn_input_x, attention_mask=attention_mask, position_ids=position_ids)
        x = x + gate_msa * attn_output_x

        norm_x2 = self.norm2(x.to(torch.float32)).to(dtype)
        gate_input_x = modulate(norm_x2, shift_mlp, scale_mlp)
        gate_output_x = self.mlp(gate_input_x)
        x = x + gate_mlp * gate_output_x
        return x


class FinalLayer(nn.Module):
    def __init__(self, hidden_size, output_size, patch_size=(1, 22), stride_size=(1, 22), patches=1, joint=22):
        super().__init__()
        self.norm_final = LlamaRMSNorm(hidden_size, eps=1e-6)
        self.patch_size = patch_size
        self.stride_size = stride_size
        self.patches = patches
        self.joint = joint
        self.linear = nn.Linear(hidden_size, output_size * patch_size[0] * patch_size[1], bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        norm_x = self.norm_final(x.to(torch.float32)).to(x.dtype)
        x = modulate(norm_x, shift, scale)
        x = self.linear(x)
        # un-patchify tokens back to a (b, c, l, joint) map
        x = x.reshape(shape=(x.shape[0], x.shape[1] // self.patches, self.patches, self.patch_size[0], self.patch_size[1], x.shape[-1] // self.patch_size[1]))
        x = torch.einsum('nljpqc->nclpjq', x)
        x = x.reshape(shape=(x.shape[0], x.shape[1], -1, self.joint))
        return x


class TimestepEmbedder(nn.Module):
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000, dtype=torch.float32):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=dtype) / half
        ).to(device=t.device, dtype=dtype)
        args = t[:, None] * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t, dtype=torch.bfloat16):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size, dtype=dtype)
        t_emb = self.mlp(t_freq)
        return t_emb


class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return (self.weight * hidden_states).to(input_dtype)
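
For orientation, a minimal usage sketch of the registry above. It is not part of the commit: the input_dim=4 latent width and the cond_scale value are placeholder assumptions (the real values come from the AE and evaluation configs), and a CUDA device is required because load_and_freeze_clip asserts one.

import torch
from models.ACMDM import ACMDM_models

model = ACMDM_models['ACMDM-Flow-S-PatchSize22'](input_dim=4, cond_mode='text').cuda()
m_lens = torch.tensor([49], device='cuda')                      # per-sample latent lengths
latents = model.generate(['a person walks forward and waves'],  # text conditions
                         m_lens, cond_scale=4.5)                # CFG scale; 1.0 disables CFG
# latents: (1, input_dim, 49, 22), zeroed beyond each sample's length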
models/ACMDM_ControlNet.py
ADDED
@@ -0,0 +1,314 @@
import torch
import torch.nn as nn
from models.ACMDM import ACMDM
from models.ACMDM import TimestepEmbedder, ACMDMTransBlock, LlamaRMSNorm
from models.ROPE import RopeND
from utils.eval_utils import eval_decorator
from utils.train_utils import lengths_to_mask


#################################################################################
#                               ACMDM+ControlNet                                #
#################################################################################
class ACMDM_ControlNet(ACMDM):
    def __init__(self, input_dim, cond_mode, base_checkpoint, latent_dim=256, ff_size=1024, num_layers=8,
                 num_heads=4, dropout=0.2, clip_dim=512,
                 diff_model='Flow', cond_drop_prob=0.1, max_length=49,
                 patch_size=(1, 22), stride_size=(1, 22),
                 clip_version='ViT-B/32', freeze_base=True, need_base=True, **kargs):
        # --------------------------------------------------------------------------
        # ACMDM
        super().__init__(input_dim, cond_mode, latent_dim=latent_dim, ff_size=ff_size, num_layers=num_layers,
                         num_heads=num_heads, dropout=dropout, clip_dim=clip_dim,
                         diff_model=diff_model, cond_drop_prob=cond_drop_prob, max_length=max_length,
                         patch_size=patch_size, stride_size=stride_size,
                         clip_version=clip_version, **kargs)

        # --------------------------------------------------------------------------
        # ControlNet
        self.c_t_embedder = TimestepEmbedder(self.latent_dim)
        self.c_control_embedder = c_control_embedder(3, self.latent_dim, patch_size=self.patch_size,
                                                     stride_size=self.stride_size)
        self.c_x_embedder = nn.Conv2d(self.input_dim, self.latent_dim, kernel_size=self.patch_size,
                                      stride=self.stride_size, bias=True)
        self.c_y_embedder = nn.Linear(self.clip_dim, self.latent_dim)
        self.c_rope = RopeND(nd=1, nd_split=[1], max_lens=self.max_lens)
        self.ControlNet = nn.ModuleList([
            ACMDMTransBlock(self.latent_dim, num_heads, mlp_size=ff_size, rope=self.c_rope, qk_norm=True) for _ in
            range(num_layers)
        ])
        self.zero_Linear = nn.ModuleList([
            nn.Linear(self.latent_dim, self.latent_dim) for _ in range(num_layers)
        ])
        self.initialize_weights_control()
        if need_base:
            # Initialize the ControlNet branch from the base model's transformer weights
            for key, value in list(base_checkpoint['ema_acmdm'].items()):
                if key.startswith('ACMDMTransformer.'):
                    new_key = key.replace('ACMDMTransformer.', 'ControlNet.')
                    base_checkpoint['ema_acmdm'][new_key] = value.clone()
            missing_keys, unexpected_keys = self.load_state_dict(base_checkpoint['ema_acmdm'], strict=False)
            assert len(unexpected_keys) == 0

        if self.cond_mode == 'text':
            print('ReLoading CLIP...')
            self.clip_version = clip_version
            self.clip_model = self.load_and_freeze_clip(clip_version)

        if freeze_base:
            for param in self.t_embedder.parameters():
                param.requires_grad = False
            for param in self.x_embedder.parameters():
                param.requires_grad = False
            for param in self.y_embedder.parameters():
                param.requires_grad = False
            for param in self.final_layer.parameters():
                param.requires_grad = False
            for param in self.ACMDMTransformer.parameters():
                param.requires_grad = False

    def initialize_weights_control(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        for block in self.ACMDMTransformer:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.c_t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.c_t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        for block in self.ControlNet:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        nn.init.constant_(self.c_control_embedder.zero_linear.weight, 0)
        nn.init.constant_(self.c_control_embedder.zero_linear.bias, 0)

        for block in self.zero_Linear:
            nn.init.constant_(block.weight, 0)
            nn.init.constant_(block.bias, 0)

    def forward_with_control(self, x, t, conds, attention_mask, cfg1=1.0, cfg2=1.0, control=None, index=None,
                             force_mask=False):
        if not (cfg1 == 1.0 and cfg2 == 1.0):
            half = x[: len(x) // 3]
            x = torch.cat([half, half, half], dim=0)
        # controlnet
        c_t = self.c_t_embedder(t, dtype=x.dtype)
        conds = self.mask_cond(conds, force_mask=force_mask)
        c_control = self.c_control_embedder(control * index)
        if self.training and self.cond_drop_prob > 0.:
            mask = torch.bernoulli(torch.ones(c_control.shape[0], device=c_control.device) * self.cond_drop_prob).view(c_control.shape[0], 1, 1)
            c_control = c_control * (1. - mask)
        if not (cfg1 == 1.0 and cfg2 == 1.0):
            c_control = torch.cat([c_control, c_control, torch.zeros_like(c_control)], dim=0)
        c_x = self.c_x_embedder(x).flatten(2).transpose(1, 2)
        c_y = self.c_y_embedder(conds)
        c_y = c_t.unsqueeze(1) + c_y.unsqueeze(1)
        c_x = c_x + c_control
        c_position_ids = self.position_ids_precompute[:, :c_x.shape[1]]
        c_out = []
        for c_block, c_linear in zip(self.ControlNet, self.zero_Linear):
            c_x = c_block(c_x, c_y, attention_mask, position_ids=c_position_ids)
            c_out.append(c_linear(c_x))
        # main branch
        tt = self.t_embedder(t, dtype=x.dtype)
        x = self.x_embedder(x)
        x = x.flatten(2).transpose(1, 2)
        conds = self.y_embedder(conds)
        y = tt.unsqueeze(1) + conds.unsqueeze(1)
        position_ids = self.position_ids_precompute[:, :x.shape[1]]
        # merging
        for block, c in zip(self.ACMDMTransformer, c_out):
            x = block(x, y, attention_mask, position_ids=position_ids)
            x = x + c
        x = self.final_layer(x, y)
        if not (cfg1 == 1.0 and cfg2 == 1.0):
            cond_eps, uncond_eps1, uncond_eps2 = torch.split(x, len(x) // 3, dim=0)
            half_eps = cond_eps + (cfg1 - 1) * (cond_eps - uncond_eps1) + (cfg2 - 1) * (cond_eps - uncond_eps2)
            x = torch.cat([half_eps, half_eps, half_eps], dim=0)
        return x

    def forward_control_loss(self, latents, y, m_lens, original, index, ae, mean_std):
        latents = latents.permute(0, 2, 3, 1)
        b, l, j, d = latents.shape
        device = latents.device

        non_pad_mask = lengths_to_mask(m_lens, l)
        latents = torch.where(non_pad_mask.unsqueeze(-1).unsqueeze(-1), latents, torch.zeros_like(latents))

        target = latents.clone().permute(0, 3, 1, 2).detach()
        original = original.clone().detach()

        force_mask = False
        if self.cond_mode == 'text':
            with torch.no_grad():
                cond_vector = self.encode_text(y)
        elif self.cond_mode == 'action':
            cond_vector = self.encode_action(y).to(device).float()
        elif self.cond_mode == 'uncond':
            cond_vector = torch.zeros(b, self.latent_dim).float().to(device)
            force_mask = True
        else:
            raise NotImplementedError("Unsupported condition mode!!!")

        attention_mask = non_pad_mask.unsqueeze(-1).repeat(1, 1, self.patches_per_frame).flatten(1).unsqueeze(1).unsqueeze(1)

        # Randomly pick controlled joints and a random subset of frames per sample
        random_indices = torch.randint(0, len(index), (b,)).to(device)
        indexx = torch.tensor(index, device=device)[random_indices]
        mask_seq = torch.zeros((b, 3, l * 4, j), device=device)
        for i in range(b):
            seq_num = torch.randint(1, m_lens[i] * 4, (1,))
            choose_seq = torch.sort(torch.randperm(m_lens[i] * 4)[:seq_num.item()]).values
            mask_seq[i, :, choose_seq, indexx[i]] = 1.0

        model_kwargs = dict(conds=cond_vector, attention_mask=attention_mask, control=original, index=mask_seq,
                            force_mask=force_mask, mean_std=mean_std)
        if self.diff_model == "Flow":
            loss_dict = self.train_diffusion.training_losses(self.forward_with_control, target, ae=ae,
                                                             model_kwargs=model_kwargs)
        else:
            t = torch.randint(0, self.train_diffusion.num_timesteps, (target.shape[0],), device=target.device)
            loss_dict = self.train_diffusion.training_losses(self.forward_with_control, target, t, model_kwargs)
        loss = loss_dict["loss"]
        loss = (loss * non_pad_mask).sum() / non_pad_mask.sum()

        return loss, loss_dict["loss_control"]

    @torch.no_grad()
    @eval_decorator
    def generate_control(self,
                         conds,
                         m_lens,
                         control,
                         index,
                         density,
                         cond_scale,
                         temperature=1,
                         j=22
                         ):
        device = next(self.parameters()).device
        l = control.shape[2] // 4
        b = len(m_lens)

        if self.cond_mode == 'text':
            with torch.no_grad():
                cond_vector = self.encode_text(conds)
        elif self.cond_mode == 'action':
            cond_vector = self.encode_action(conds).to(device)
        elif self.cond_mode == 'uncond':
            cond_vector = torch.zeros(b, self.latent_dim).float().to(device)
        else:
            raise NotImplementedError("Unsupported condition mode!!!")

        padding_mask = ~lengths_to_mask(m_lens, l)

        noise = torch.randn(b, self.input_dim, l, j).to(device)
        control = control.clone()
        cfg1 = cond_scale[0]
        cfg2 = cond_scale[1]
        if not (cfg1 == 1.0 and cfg2 == 1.0):
            # (1) with text and with control (2) no text and with control (3) with text and no control
            cond_vector = torch.cat([cond_vector, torch.zeros_like(cond_vector), cond_vector], dim=0)

        random_indices = torch.tensor(0, device=device).repeat(b)  # no random in inference
        indexx = torch.tensor(index, device=device)[random_indices]
        mask_seq = torch.zeros((b, 3, l * 4, j), device=device)
        for i in range(b):
            if density in [1, 2, 5]:
                seq_num = density
            else:
                seq_num = int(m_lens[i] * 4 * density / 100)
            choose_seq = torch.sort(torch.randperm(m_lens[i] * 4)[:seq_num]).values
            mask_seq[i, :, choose_seq, indexx[i]] = 1.0

        attention_mask = (~padding_mask).unsqueeze(-1).repeat(1, 1, self.patches_per_frame).flatten(1).unsqueeze(1).unsqueeze(1)
        model_kwargs = dict(conds=cond_vector, attention_mask=attention_mask, cfg1=cfg1, cfg2=cfg2, index=mask_seq,
                            control=control)
        sample_fn = self.forward_with_control

        if not (cfg1 == 1.0 and cfg2 == 1.0):
            model_kwargs["attention_mask"] = attention_mask.repeat(3, 1, 1, 1)
            noise = torch.cat([noise, noise, noise], dim=0)

        if self.diff_model == "Flow":
            model_fn = self.gen_diffusion.sample_ode()  # default to ode sampling
            sampled_token_latent = model_fn(noise, sample_fn, **model_kwargs)[-1]
        else:
            sampled_token_latent = self.gen_diffusion.p_sample_loop(
                sample_fn, noise.shape, noise, clip_denoised=False, model_kwargs=model_kwargs,
                progress=False,
                temperature=temperature
            )
        if not (cfg1 == 1.0 and cfg2 == 1.0):
            sampled_token_latent, _, _ = sampled_token_latent.chunk(3, dim=0)
        sampled_token_latent = sampled_token_latent.permute(0, 2, 3, 1)

        latents = torch.where(padding_mask.unsqueeze(-1).unsqueeze(-1), torch.zeros_like(sampled_token_latent),
                              sampled_token_latent)
        return latents.permute(0, 3, 1, 2), mask_seq

#################################################################################
#                                  ACMDM Zoos                                   #
#################################################################################
def acmdm_raw_flow_s_ps22_control(**kwargs):
    layer = 8
    return ACMDM_ControlNet(latent_dim=layer*64, ff_size=layer*64*4, num_layers=layer, num_heads=layer, dropout=0, clip_dim=512,
                            diff_model="Flow", cond_drop_prob=0.1, max_length=49,
                            patch_size=(1, 22), stride_size=(1, 22), freeze_base=True, **kwargs)


ACMDM_ControlNet_Models = {
    'ACMDM-Flow-S-PatchSize22-ControlNet': acmdm_raw_flow_s_ps22_control,
}

#################################################################################
#                             Inner Architectures                               #
#################################################################################
def modulate(x, shift, scale):
    return x * (1 + scale) + shift


def zero_module(module):
    for p in module.parameters():
        p.detach().zero_()
    return module


class c_control_embedder(nn.Module):
    def __init__(
            self,
            in_features: int,
            hidden_features,
            patch_size,
            stride_size,
    ) -> None:
        super().__init__()
        self.patch_embed = nn.Conv2d(in_features, hidden_features, kernel_size=(4, patch_size[1]), stride=(4, stride_size[1]), bias=True)
        self.norm = LlamaRMSNorm(hidden_features, eps=1e-6)
        self.zero_linear = nn.Linear(hidden_features, hidden_features)

    def forward(self, x):
        x = self.patch_embed(x).flatten(2).transpose(1, 2)
        x = self.norm(x)
        x = self.zero_linear(x)
        return x
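
The zero_Linear list and the zero_linear inside c_control_embedder follow the usual ControlNet recipe: both are zero-initialized in initialize_weights_control, so at the start of fine-tuning the control branch adds exactly nothing and the model reproduces the frozen base. A standalone sketch of that property, with toy shapes independent of this repo:

import torch
import torch.nn as nn

hidden = 512
zero_proj = nn.Linear(hidden, hidden)      # per-layer merge projection
nn.init.constant_(zero_proj.weight, 0)
nn.init.constant_(zero_proj.bias, 0)

x = torch.randn(2, 49, hidden)             # main-branch activations
c = torch.randn(2, 49, hidden)             # control-branch activations
assert torch.equal(x + zero_proj(c), x)    # control is a no-op at init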
models/ACMDM_NoisyPrefix_AR.py
ADDED
@@ -0,0 +1,556 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import clip
import math
from functools import partial
from timm.models.vision_transformer import Attention
from models.ROPE import RopeND
from utils.eval_utils import eval_decorator
from utils.train_utils import lengths_to_mask
from diffusions.diffusion import create_diffusion
from diffusions.transport import create_transport, Sampler

#################################################################################
#                                     ACMDM                                     #
#################################################################################
class ACMDM(nn.Module):
    def __init__(self, input_dim, cond_mode, latent_dim=256, ff_size=1024, num_layers=8,
                 num_heads=4, dropout=0, clip_dim=512,
                 diff_model='Flow', cond_drop_prob=0.1, max_length=49,
                 patch_size=(1, 22), stride_size=(1, 22), num_joint=22, cluster=5,
                 clip_version='ViT-B/32', **kargs):
        super(ACMDM, self).__init__()

        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.clip_dim = clip_dim
        self.dropout = dropout
        self.cluster = cluster

        self.cond_mode = cond_mode
        self.cond_drop_prob = cond_drop_prob

        if self.cond_mode == 'action':
            assert 'num_actions' in kargs
            self.num_actions = kargs.get('num_actions', 1)
            self.encode_action = partial(F.one_hot, num_classes=self.num_actions)
        # --------------------------------------------------------------------------
        # Diffusion
        self.diff_model = diff_model
        if self.diff_model == 'Flow':
            self.train_diffusion = create_transport()  # default to linear, velocity prediction
            self.gen_diffusion = Sampler(self.train_diffusion)
        else:
            self.train_diffusion = create_diffusion(timestep_respacing="", noise_schedule="linear")
            self.gen_diffusion = create_diffusion(timestep_respacing="", noise_schedule="linear")
        # --------------------------------------------------------------------------
        # ACMDM
        print('Loading ACMDM...')
        self.t_embedder = TimestepEmbedder(self.latent_dim)
        self.patch_size = patch_size
        self.stride_size = stride_size
        self.patches_per_frame = (num_joint - patch_size[1]) // stride_size[1] + 1

        # Patchification
        self.x_embedder = nn.Linear(self.input_dim * self.patch_size[0] * self.patch_size[1], self.latent_dim, bias=True)

        # Positional Encoding
        max_length = max_length * self.patches_per_frame
        self.max_lens = [max_length]
        self.rope = RopeND(nd=1, nd_split=[1], max_lens=self.max_lens)
        self.position_ids_precompute = torch.arange(max_length).unsqueeze(0)
        self.cluster_patches = max_length // self.cluster

        self.ACMDMTransformer = nn.ModuleList([
            ACMDMTransBlock(self.latent_dim, num_heads, mlp_size=ff_size, rope=self.rope, qk_norm=True) for _ in range(num_layers)
        ])

        if self.cond_mode == 'text':
            self.y_embedder = nn.Linear(self.clip_dim, self.latent_dim)
        elif self.cond_mode == 'action':
            self.y_embedder = nn.Linear(self.num_actions, self.latent_dim)
        elif self.cond_mode == 'uncond':
            self.y_embedder = nn.Identity()
        else:
            raise KeyError("Unsupported condition mode!!!")

        self.final_layer = FinalLayer(self.latent_dim, self.input_dim * self.patch_size[0] * self.patch_size[1])

        self.initialize_weights()

        if self.cond_mode == 'text':
            print('Loading CLIP...')
            self.clip_version = clip_version
            self.clip_model = self.load_and_freeze_clip(clip_version)

        # Blockwise-causal additive attention mask over the cluster blocks
        attention_mask = []
        start = 0
        total_length = max_length
        for idx in range(max_length):
            if idx in [self.cluster_patches * i for i in range(self.cluster)]:
                start += self.cluster_patches * self.patches_per_frame
            attention_mask.append(torch.cat([torch.ones((1, start)),
                                             torch.zeros((1, total_length - start))], dim=-1))
        attention_mask = torch.cat(attention_mask, dim=0)
        attention_mask = torch.where(attention_mask == 0, -torch.inf, attention_mask)
        attention_mask = torch.where(attention_mask == 1, 0, attention_mask)
        attention_mask = attention_mask.unsqueeze(0).unsqueeze(0)
        self.register_buffer('attention_mask', attention_mask.contiguous())

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in ACMDM blocks:
        for block in self.ACMDMTransformer:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def load_and_freeze_clip(self, clip_version):
        clip_model, clip_preprocess = clip.load(clip_version, device='cpu', jit=False)
        assert torch.cuda.is_available()
        clip.model.convert_weights(clip_model)

        clip_model.eval()
        for p in clip_model.parameters():
            p.requires_grad = False
        return clip_model

    def encode_text(self, raw_text):
        device = next(self.parameters()).device
        text = clip.tokenize(raw_text, truncate=True).to(device)
        feat_clip_text = self.clip_model.encode_text(text).float()
        return feat_clip_text

    def mask_cond(self, cond, force_mask=False):
        bs, d = cond.shape
        if force_mask:
            return torch.zeros_like(cond)
        elif self.training and self.cond_drop_prob > 0.:
            mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_drop_prob).view(bs, 1)
            return cond * (1. - mask)
        else:
            return cond

    def patchify(self, x):
        b, c, l, j = x.shape
        p = self.patch_size[0]
        q = self.patch_size[1]
        l_, j_ = l // p, j // q

        x = x.reshape(b, c, l_, p, j_, q)
        x = torch.einsum('nclpjq->nljcpq', x)
        x = x.reshape(b, l_ * j_, c * p * q)
        return x

    def patchify_mask(self, mask):
        b, l = mask.shape
        p = self.patch_size[0]
        l_ = l // self.patch_size[0]
        q = self.patch_size[1]
        j_ = self.patches_per_frame
        mask = mask.unsqueeze(1).unsqueeze(-1).expand(-1, self.input_dim, -1, j_ * q)
        mask = mask.reshape(b, self.input_dim, l_, p, j_, q)
        mask = torch.einsum('nclpjq->nljcpq', mask)
        mask = mask.reshape(b, l_ * j_, self.input_dim * p * q)
        mask = mask.any(dim=-1)
        return mask

    def unpatchify(self, x):
        b = x.shape[0]
        p = self.patch_size[0]
        q = self.patch_size[1]
        c = self.input_dim
        l_, j_ = x.shape[1] // self.patches_per_frame, self.patches_per_frame

        x = x.reshape(b, l_, j_, c, p, q)
        x = torch.einsum('nljcpq->nclpjq', x)
        x = x.reshape(b, c, l_ * p, j_ * q)
        return x

    def forward(self, x, t, conds, attention_mask, force_mask=False, ids=None, block_size=None, cache=False):
        t = self.t_embedder(t, dtype=x.dtype).unsqueeze(1).repeat(1, self.cluster_patches * self.patches_per_frame, 1)
        t = t.chunk(self.cluster, dim=0)
        t = torch.cat(t, dim=1)
        conds = self.mask_cond(conds, force_mask=force_mask)
        x = x.chunk(self.cluster, dim=0)
        x = torch.cat(x, dim=1)
        x = self.x_embedder(x)
        conds = self.y_embedder(conds)
        y = t + conds.unsqueeze(1)
        if ids is not None:
            position_ids = ids
        else:
            position_ids = self.position_ids_precompute[:, :x.shape[1]]
        for block in self.ACMDMTransformer:
            x = block(x, y, attention_mask, position_ids=position_ids, block_size=block_size, cache=cache)
        x = self.final_layer(x, y)
        x = x.chunk(self.cluster, dim=1)
        x = torch.cat(x, dim=0)
        return x

    def forward_with_CFG(self, x, t, conds, attention_mask, cfg=1.0, context=None, cache=True, block_id=0):
        if cache:
            if self.ACMDMTransformer[0].attn.cached_k is None:
                cache = True
            elif block_id * self.cluster_patches == self.ACMDMTransformer[0].attn.cached_k.shape[2]:
                cache = False
        if not cfg == 1.0:
            half = x[: len(x) // 2]
            x = torch.cat([half, half], dim=0)
        if context is not None and cache:
            ids = self.position_ids_precompute[:, (block_id - 1) * self.cluster_patches * self.patches_per_frame:(block_id + 1) * self.cluster_patches * self.patches_per_frame]
            x = torch.cat([context, x], dim=1)
            t = torch.cat([torch.ones_like(t).unsqueeze(-1).repeat(1, self.patches_per_frame * self.cluster_patches),
                           t.unsqueeze(-1).repeat(1, self.patches_per_frame * self.cluster_patches)], dim=1)
            am_idx = block_id if block_id == 0 else block_id - 1
            attention_mask = attention_mask[:, :, am_idx * self.cluster_patches * self.patches_per_frame: (block_id + 1) * self.cluster_patches * self.patches_per_frame,
                             :(block_id + 1) * self.cluster_patches * self.patches_per_frame]
        else:
            ids = self.position_ids_precompute[:,
                  (block_id) * self.cluster_patches * self.patches_per_frame:(block_id + 1) * self.cluster_patches * self.patches_per_frame]
            t = t.unsqueeze(-1).repeat(1, self.patches_per_frame * self.cluster_patches)
            attention_mask = attention_mask[:, :, :(block_id + 1) * self.cluster_patches * self.patches_per_frame,
                             :(block_id + 1) * self.cluster_patches * self.patches_per_frame]
            attention_mask = attention_mask[:, :, -self.patches_per_frame * self.cluster_patches:, :]
        t = t.reshape(-1)
        t = self.t_embedder(t, dtype=x.dtype)
        t = t.reshape(x.shape[0], x.shape[1], -1)
        conds = self.mask_cond(conds)
        x = self.x_embedder(x)
        conds = self.y_embedder(conds)
        y = t + conds.unsqueeze(1)
        position_ids = ids
        for block in self.ACMDMTransformer:
            x = block(x, y, attention_mask, position_ids=position_ids, block_size=self.patches_per_frame * self.cluster_patches,
                      cache=cache)
        x = self.final_layer(x, y)
        x = x[:, -self.patches_per_frame * self.cluster_patches:, :]
        if not cfg == 1.0:
            cond_eps, uncond_eps = torch.split(x, len(x) // 2, dim=0)
            half_eps = uncond_eps + cfg * (cond_eps - uncond_eps)
            x = torch.cat([half_eps, half_eps], dim=0)
        return x

    def forward_loss(self, latents, y, m_lens):
        b, d, l, j = latents.shape
        device = latents.device

        non_pad_mask = lengths_to_mask(m_lens, l)
        non_pad_mask = self.patchify_mask(non_pad_mask)
        latents = self.patchify(latents)
        b, l, d = latents.shape
        latents = torch.where(non_pad_mask.unsqueeze(-1), latents, torch.zeros_like(latents))

        target = latents.clone().detach().chunk(self.cluster, dim=1)
        target = torch.cat(target, dim=0)

        force_mask = False
        if self.cond_mode == 'text':
            with torch.no_grad():
                cond_vector = self.encode_text(y)
        elif self.cond_mode == 'action':
            cond_vector = self.encode_action(y).to(device).float()
        elif self.cond_mode == 'uncond':
            cond_vector = torch.zeros(b, self.latent_dim).float().to(device)
            force_mask = True
        else:
            raise NotImplementedError("Unsupported condition mode!!!")

        attention_mask = []
        for i in range(b):
            a_mask = self.attention_mask.clone()
            a_mask[:, :, :, m_lens[i] * self.patches_per_frame:] = -torch.inf
            attention_mask.append(a_mask)
        attention_mask = torch.cat(attention_mask)

        model_kwargs = dict(conds=cond_vector, force_mask=force_mask, attention_mask=attention_mask)
        if self.diff_model == "Flow":
            loss_dict = self.train_diffusion.training_losses(self.forward, target, model_kwargs, dim=2)
        else:
            t = torch.randint(0, self.train_diffusion.num_timesteps, (target.shape[0],), device=target.device)
            loss_dict = self.train_diffusion.training_losses(self.forward, target, t, model_kwargs)
        loss = loss_dict["loss"]
        loss = loss.chunk(self.cluster, dim=0)
        loss = torch.cat(loss, dim=1)
        loss = (loss * non_pad_mask).sum() / non_pad_mask.sum()

        return loss

    @torch.no_grad()
    @eval_decorator
    def generate(self,
                 conds,
                 m_lens,
                 cond_scale: int,
                 temperature=1,
                 ):
        device = next(self.parameters()).device
        l = max(m_lens)
        b = len(m_lens)

        if self.cond_mode == 'text':
            with torch.no_grad():
                cond_vector = self.encode_text(conds)
        elif self.cond_mode == 'action':
            cond_vector = self.encode_action(conds).to(device)
        elif self.cond_mode == 'uncond':
            cond_vector = torch.zeros(b, self.latent_dim).float().to(device)
        else:
            raise NotImplementedError("Unsupported condition mode!!!")

        padding_mask = ~lengths_to_mask(m_lens, l)
        if not cond_scale == 1.0:
            cond_vector = torch.cat([cond_vector, torch.zeros_like(cond_vector)], dim=0)
        for block in self.ACMDMTransformer:
            block.set_caching(True)

        output = []
        attention_mask = []
        # [diff truncated here: lines 327-556 of models/ACMDM_NoisyPrefix_AR.py were not captured]
| 327 |
+
for i in range(b):
|
| 328 |
+
a_mask = self.attention_mask.clone()
|
| 329 |
+
a_mask[:, :, :, m_lens[i] * self.patches_per_frame:] = -torch.inf
|
| 330 |
+
attention_mask.append(a_mask)
|
| 331 |
+
attention_mask = torch.cat(attention_mask)
|
| 332 |
+
if not cond_scale == 1.0:
|
| 333 |
+
attention_mask = torch.cat([attention_mask, attention_mask], dim=0)
|
| 334 |
+
for step in range(self.cluster):
|
| 335 |
+
clean_x = output[-1] if len(output) > 0 else None
|
| 336 |
+
cache_flag = step > 0
|
| 337 |
+
noise = torch.randn(b, self.cluster_patches * self.patches_per_frame,
|
| 338 |
+
self.input_dim * self.patch_size[0] * self.patch_size[1]).to(device)
|
| 339 |
+
if not cond_scale == 1.0:
|
| 340 |
+
noise = torch.cat([noise, noise], dim=0)
|
| 341 |
+
if clean_x is not None:
|
| 342 |
+
clean_x = torch.cat([clean_x, clean_x], dim=0)
|
| 343 |
+
# cfg scale
|
| 344 |
+
# cond_scale2 = (cond_scale - 1) * (step+1) / (m_lens//self.cluster_patches + 1) + 1
|
| 345 |
+
model_kwargs = dict(conds=cond_vector, context=clean_x, block_id=step, cache=cache_flag,
|
| 346 |
+
attention_mask=attention_mask, cfg=cond_scale)
|
| 347 |
+
sample_fn = self.forward_with_CFG
|
| 348 |
+
|
| 349 |
+
if self.diff_model == "Flow":
|
| 350 |
+
model_fn = self.gen_diffusion.sample_ode() # default to ode sampling
|
| 351 |
+
sampled_token_latent = model_fn(noise, sample_fn, **model_kwargs)[-1]
|
| 352 |
+
else:
|
| 353 |
+
sampled_token_latent = self.gen_diffusion.p_sample_loop(
|
| 354 |
+
sample_fn, noise.shape, noise, clip_denoised=False, model_kwargs=model_kwargs,
|
| 355 |
+
progress=False,
|
| 356 |
+
temperature=temperature
|
| 357 |
+
)
|
| 358 |
+
if not cond_scale == 1:
|
| 359 |
+
sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0)
|
| 360 |
+
output.append(sampled_token_latent.detach().clone())
|
| 361 |
+
|
| 362 |
+
latents = torch.cat(output, dim=1)
|
| 363 |
+
latents = self.unpatchify(latents[:, :l * self.patches_per_frame, :])
|
| 364 |
+
latents = torch.where(padding_mask.unsqueeze(1).unsqueeze(-1), torch.zeros_like(latents), latents)
|
| 365 |
+
for block in self.ACMDMTransformer:
|
| 366 |
+
block.set_caching(False)
|
| 367 |
+
return latents
|
| 368 |
+
|
| 369 |
+
#################################################################################
|
| 370 |
+
# ACMDM Zoos #
|
| 371 |
+
#################################################################################
|
| 372 |
+
def acmdm_noisyprefixar_flow_s_ps22(**kwargs):
|
| 373 |
+
layer = 8
|
| 374 |
+
return ACMDM(latent_dim=layer*64, ff_size=layer*64*4, num_layers=layer, num_heads=layer, dropout=0, clip_dim=512,
|
| 375 |
+
diff_model="Flow", cond_drop_prob=0.1, max_length=50,
|
| 376 |
+
patch_size=(1, 22), stride_size=(1, 22), **kwargs)
|
| 377 |
+
ACMDM_models = {
|
| 378 |
+
'ACMDM-NoisyPrefixAR-Flow-S-PatchSize22': acmdm_noisyprefixar_flow_s_ps22,
|
| 379 |
+
}
|
| 380 |
+
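
A minimal usage sketch for the zoo entry above (not part of the commit). The constructor keywords mirror the sibling ACMDM_Prefix_AR.py file and are assumptions here, since this file's __init__ (with its cluster/block settings) appears earlier in the diff; generate() is the method defined above, and cond_scale > 1 enables classifier-free guidance.

# Hypothetical instantiation; input_dim=4 (AE latent channels) and the prompts
# are illustrative assumptions, not values taken from this commit.
import torch

model = ACMDM_models['ACMDM-NoisyPrefixAR-Flow-S-PatchSize22'](
    input_dim=4, cond_mode='text').cuda().eval()
m_lens = torch.tensor([49, 40], device='cuda')           # target lengths in latent frames
latents = model.generate(["a person walks forward", "a person waves"],
                         m_lens, cond_scale=4.5)         # cfg > 1 doubles the batch internally
# latents: (B, input_dim, L, joints); decode with the 2D AE to recover motion
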

#################################################################################
#                             Inner Architectures                               #
#################################################################################
def modulate(x, shift, scale):
    return x * (1 + scale) + shift


class ACMDMAttention(Attention):
    def __init__(
            self,
            dim,
            num_heads=8,
            qkv_bias=True,
            rope=None,
            qk_norm=True,
            **block_kwargs,
    ):
        super().__init__(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm, **block_kwargs)
        self.caching, self.cached_k, self.cached_v = False, None, None
        self.rope = rope

    def set_caching(self, flag):
        self.caching, self.cached_k, self.cached_v = flag, None, None

    def forward(self, x, position_ids=None, attention_mask=None, block_size=None, cache=False):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.rope is not None:
            q, k = self.rope(q, k, position_ids)

        if self.caching:
            if cache:
                if self.cached_k is None:
                    self.cached_k = k[:, :, :block_size, :]
                    self.cached_v = v[:, :, :block_size, :]
                    self.cached_x = x
                else:
                    self.cached_k = torch.cat((self.cached_k, k[:, :, :block_size, :]), dim=2)
                    self.cached_v = torch.cat((self.cached_v, v[:, :, :block_size, :]), dim=2)

            if self.cached_k is not None:
                k = torch.cat((self.cached_k, k[:, :, -block_size:, :]), dim=2)
                v = torch.cat((self.cached_v, v[:, :, -block_size:, :]), dim=2)

        x = torch.nn.functional.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attention_mask,
            dropout_p=self.attn_drop.p
        )
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class SwiGLUFFN(nn.Module):
    def __init__(
            self,
            in_features: int,
            hidden_features,
            bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = in_features
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x):
        x12 = self.w12(x)
        x1, x2 = x12.chunk(2, dim=-1)
        hidden = F.silu(x1) * x2
        return self.w3(hidden)


class ACMDMTransBlock(nn.Module):
    def __init__(self, hidden_size, num_heads, mlp_size=1024, rope=None, qk_norm=True):
        super().__init__()
        self.norm1 = LlamaRMSNorm(hidden_size, eps=1e-6)
        self.attn = ACMDMAttention(hidden_size, num_heads=num_heads, qkv_bias=True, norm_layer=LlamaRMSNorm,
                                   qk_norm=qk_norm, rope=rope)
        self.norm2 = LlamaRMSNorm(hidden_size, eps=1e-6)
        self.mlp = SwiGLUFFN(hidden_size, int(2 / 3 * mlp_size))
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def set_caching(self, flag):
        self.attn.set_caching(flag)

    def forward(self, x, c, attention_mask=None, position_ids=None, block_size=None, cache=False):
        dtype = x.dtype
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
        norm_x1 = self.norm1(x.to(torch.float32)).to(dtype)
        attn_input_x = modulate(norm_x1, shift_msa, scale_msa)
        attn_output_x = self.attn(attn_input_x, attention_mask=attention_mask, position_ids=position_ids, block_size=block_size, cache=cache)
        x = x + gate_msa * attn_output_x

        norm_x2 = self.norm2(x.to(torch.float32)).to(dtype)
        gate_input_x = modulate(norm_x2, shift_mlp, scale_mlp)
        gate_output_x = self.mlp(gate_input_x)
        x = x + gate_mlp * gate_output_x
        return x


class FinalLayer(nn.Module):
    def __init__(self, hidden_size, output_size):
        super().__init__()
        self.norm_final = LlamaRMSNorm(hidden_size, eps=1e-6)
        self.linear = nn.Linear(hidden_size, output_size, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        norm_x = self.norm_final(x.to(torch.float32)).to(x.dtype)
        x = modulate(norm_x, shift, scale)
        x = self.linear(x)
        return x


class TimestepEmbedder(nn.Module):
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000, dtype=torch.float32):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=dtype) / half
        ).to(device=t.device, dtype=dtype)
        args = t[:, None] * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t, dtype=torch.bfloat16):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size, dtype=dtype)
        t_emb = self.mlp(t_freq)
        return t_emb


class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return (self.weight * hidden_states).to(input_dtype)
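
The classifier-free guidance path in forward_with_CFG above keeps the conditional and unconditional copies in one doubled batch and recombines them as uncond + cfg * (cond - uncond). A standalone sketch of just that bookkeeping, on stand-in tensors:

import torch

cfg = 4.5
x = torch.randn(8, 16)                                   # model output; rows 0-3 conditional, 4-7 unconditional
cond_eps, uncond_eps = torch.split(x, len(x) // 2, dim=0)
half_eps = uncond_eps + cfg * (cond_eps - uncond_eps)    # guided estimate
x = torch.cat([half_eps, half_eps], dim=0)               # duplicated so the doubled batch shape is preserved
assert x.shape == (8, 16)
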
models/ACMDM_Prefix_AR.py
ADDED
@@ -0,0 +1,434 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import clip
import math
from functools import partial
from timm.models.vision_transformer import Attention
from models.ROPE import RopeND
from utils.eval_utils import eval_decorator
from utils.train_utils import lengths_to_mask
from diffusions.diffusion import create_diffusion
from diffusions.transport import create_transport, Sampler

#################################################################################
#                                     ACMDM                                     #
#################################################################################
class ACMDM(nn.Module):
    def __init__(self, input_dim, cond_mode, latent_dim=256, ff_size=1024, num_layers=8,
                 num_heads=4, dropout=0, clip_dim=512,
                 diff_model='Flow', cond_drop_prob=0.1, max_length=49,
                 patch_size=(1, 22), stride_size=(1, 22), num_joint=22,
                 clip_version='ViT-B/32', **kargs):
        super(ACMDM, self).__init__()

        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.clip_dim = clip_dim
        self.dropout = dropout

        self.cond_mode = cond_mode
        self.cond_drop_prob = cond_drop_prob

        if self.cond_mode == 'action':
            assert 'num_actions' in kargs
            self.num_actions = kargs.get('num_actions', 1)
            self.encode_action = partial(F.one_hot, num_classes=self.num_actions)
        # --------------------------------------------------------------------------
        # Diffusion
        self.diff_model = diff_model
        if self.diff_model == 'Flow':
            self.train_diffusion = create_transport()  # default to linear, velocity prediction
            self.gen_diffusion = Sampler(self.train_diffusion)
        else:
            self.train_diffusion = create_diffusion(timestep_respacing="", noise_schedule="linear")
            self.gen_diffusion = create_diffusion(timestep_respacing="", noise_schedule="linear")
        # --------------------------------------------------------------------------
        # ACMDM
        print('Loading ACMDM...')
        self.t_embedder = TimestepEmbedder(self.latent_dim)
        self.patch_size = patch_size
        self.stride_size = stride_size
        self.patches_per_frame = (num_joint - patch_size[1]) // stride_size[1] + 1

        # Patchification
        self.x_embedder = nn.Conv2d(self.input_dim, self.latent_dim, kernel_size=self.patch_size, stride=self.stride_size, bias=True)

        # Positional Encoding
        max_length = max_length * self.patches_per_frame
        self.max_lens = [max_length]
        self.rope = RopeND(nd=1, nd_split=[1], max_lens=self.max_lens)
        self.position_ids_precompute = torch.arange(max_length).unsqueeze(0)

        self.ACMDMTransformer = nn.ModuleList([
            ACMDMTransBlock(self.latent_dim, num_heads, mlp_size=ff_size, rope=self.rope, qk_norm=True) for _ in range(num_layers)
        ])

        if self.cond_mode == 'text':
            self.y_embedder = nn.Linear(self.clip_dim, self.latent_dim)
        elif self.cond_mode == 'action':
            self.y_embedder = nn.Linear(self.num_actions, self.latent_dim)
        elif self.cond_mode == 'uncond':
            self.y_embedder = nn.Identity()
        else:
            raise KeyError("Unsupported condition mode!!!")

        self.final_layer = FinalLayer(self.latent_dim, self.input_dim, patch_size=patch_size, stride_size=stride_size, patches=self.patches_per_frame)

        self.initialize_weights()

        if self.cond_mode == 'text':
            print('Loading CLIP...')
            self.clip_version = clip_version
            self.clip_model = self.load_and_freeze_clip(clip_version)

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in ACMDM blocks:
        for block in self.ACMDMTransformer:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def load_and_freeze_clip(self, clip_version):
        clip_model, clip_preprocess = clip.load(clip_version, device='cpu', jit=False)
        assert torch.cuda.is_available()
        clip.model.convert_weights(clip_model)

        clip_model.eval()
        for p in clip_model.parameters():
            p.requires_grad = False
        return clip_model

    def encode_text(self, raw_text):
        device = next(self.parameters()).device
        text = clip.tokenize(raw_text, truncate=True).to(device)
        feat_clip_text = self.clip_model.encode_text(text).float()
        return feat_clip_text

    def mask_cond(self, cond, force_mask=False):
        bs, d = cond.shape
        if force_mask:
            return torch.zeros_like(cond)
        elif self.training and self.cond_drop_prob > 0.:
            mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_drop_prob).view(bs, 1)
            return cond * (1. - mask)
        else:
            return cond

    def forward(self, x, t, conds, attention_mask, context, force_mask=False):
        t = self.t_embedder(t, dtype=x.dtype)
        conds = self.mask_cond(conds, force_mask=force_mask)
        x = torch.cat([context, x], dim=2)
        x = self.x_embedder(x)
        x = x.flatten(2).transpose(1, 2)
        conds = self.y_embedder(conds)
        y = t.unsqueeze(1) + conds.unsqueeze(1)
        position_ids = self.position_ids_precompute[:, :x.shape[1]]
        for block in self.ACMDMTransformer:
            x = block(x, y, attention_mask, position_ids=position_ids)
        x = self.final_layer(x, y)[:, :, 5:, :]
        return x

    def forward_with_CFG(self, x, t, conds, attention_mask, context, cfg=1.0):
        if not cfg == 1.0:
            half = x[: len(x) // 2]
            x = torch.cat([half, half], dim=0)
            context = torch.cat([context, context], dim=0)
        x = self.forward(x, t, conds, attention_mask, context)
        if not cfg == 1.0:
            cond_eps, uncond_eps = torch.split(x, len(x) // 2, dim=0)
            half_eps = uncond_eps + cfg * (cond_eps - uncond_eps)
            x = torch.cat([half_eps, half_eps], dim=0)
        return x

    def forward_loss(self, latents, y, m_lens):
        latents = latents.permute(0, 2, 3, 1)
        b, l, j, d = latents.shape
        device = latents.device

        non_pad_mask = lengths_to_mask(m_lens, l)
        latents = torch.where(non_pad_mask.unsqueeze(-1).unsqueeze(-1), latents, torch.zeros_like(latents))

        # prefix 20, prediction 40 style
        target = latents.clone().permute(0, 3, 1, 2).detach()[:, :, 5:, :]
        context = latents.clone().permute(0, 3, 1, 2).detach()[:, :, :5, :]

        force_mask = False
        if self.cond_mode == 'text':
            with torch.no_grad():
                cond_vector = self.encode_text(y)
        elif self.cond_mode == 'action':
            cond_vector = self.encode_action(y).to(device).float()
        elif self.cond_mode == 'uncond':
            cond_vector = torch.zeros(b, self.latent_dim).float().to(device)
            force_mask = True
        else:
            raise NotImplementedError("Unsupported condition mode!!!")

        attention_mask = non_pad_mask.unsqueeze(-1).repeat(1, 1, self.patches_per_frame).flatten(1).unsqueeze(
            1).unsqueeze(1)

        model_kwargs = dict(conds=cond_vector, force_mask=force_mask, attention_mask=attention_mask, context=context)
        if self.diff_model == "Flow":
            loss_dict = self.train_diffusion.training_losses(self.forward, target, model_kwargs)
        else:
            t = torch.randint(0, self.train_diffusion.num_timesteps, (target.shape[0],), device=target.device)
            loss_dict = self.train_diffusion.training_losses(self.forward, target, t, model_kwargs)
        loss = loss_dict["loss"]
        non_pad_mask = non_pad_mask[:, 5:]
        loss = (loss * non_pad_mask).sum() / non_pad_mask.sum()

        return loss

    @torch.no_grad()
    @eval_decorator
    def generate(self,
                 conds,
                 m_lens,
                 cond_scale: int,
                 context,
                 temperature=1,
                 j=22,
                 ):
        device = next(self.parameters()).device
        l = max(m_lens)
        b = len(m_lens)

        if self.cond_mode == 'text':
            with torch.no_grad():
                cond_vector = self.encode_text(conds)
        elif self.cond_mode == 'action':
            cond_vector = self.encode_action(conds).to(device)
        elif self.cond_mode == 'uncond':
            cond_vector = torch.zeros(b, self.latent_dim).float().to(device)
        else:
            raise NotImplementedError("Unsupported condition mode!!!")

        padding_mask = ~lengths_to_mask(m_lens, l)
        if not cond_scale == 1.0:
            cond_vector = torch.cat([cond_vector, torch.zeros_like(cond_vector)], dim=0)

        # really naive way to write the PrefixAR inference loop, to be improved
        iter = [(0, 15), (10, 25), (20, 35), (30, 45), (40, l.item())]
        out = [context.clone().detach()]
        for i in range(len(iter)):
            noise = torch.randn(b, self.input_dim, iter[i][1] - iter[i][0] - 5, j).to(device)
            if not cond_scale == 1.0:
                noise = torch.cat([noise, noise], dim=0)

            attention_mask = ((~padding_mask)[:, iter[i][0]:iter[i][1]]).unsqueeze(-1).repeat(1, 1, self.patches_per_frame).flatten(1).unsqueeze(1).unsqueeze(1)
            model_kwargs = dict(conds=cond_vector, attention_mask=attention_mask, context=context, cfg=cond_scale)
            sample_fn = self.forward_with_CFG

            if not cond_scale == 1:
                model_kwargs["attention_mask"] = attention_mask.repeat(2, 1, 1, 1)

            if self.diff_model == "Flow":
                model_fn = self.gen_diffusion.sample_ode(sampling_method="euler")  # default to ode sampling; use euler to prevent underflow, as the current iter can contain paddings
                sampled_token_latent = model_fn(noise, sample_fn, **model_kwargs)[-1]
            else:
                sampled_token_latent = self.gen_diffusion.p_sample_loop(
                    sample_fn, noise.shape, noise, clip_denoised=False, model_kwargs=model_kwargs,
                    progress=False,
                    temperature=temperature
                )
            if not cond_scale == 1:
                sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0)
            out.append(sampled_token_latent.clone().detach())
            context = sampled_token_latent[:, :, 5:, :].clone().detach()
        sampled_token_latent = torch.cat(out, dim=2).permute(0, 2, 3, 1)

        latents = torch.where(padding_mask.unsqueeze(-1).unsqueeze(-1), torch.zeros_like(sampled_token_latent), sampled_token_latent)
        return latents.permute(0, 3, 1, 2)

#################################################################################
#                                  ACMDM Zoos                                   #
#################################################################################
def acmdm_prefixar_flow_s_ps22(**kwargs):
    layer = 8
    return ACMDM(latent_dim=layer*64, ff_size=layer*64*4, num_layers=layer, num_heads=layer, dropout=0, clip_dim=512,
                 diff_model="Flow", cond_drop_prob=0.1, max_length=15,
                 patch_size=(1, 22), stride_size=(1, 22), **kwargs)
ACMDM_models = {
    'ACMDM-PrefixAR-Flow-S-PatchSize22': acmdm_prefixar_flow_s_ps22,
}

#################################################################################
#                             Inner Architectures                               #
#################################################################################
def modulate(x, shift, scale):
    return x * (1 + scale) + shift


class ACMDMAttention(Attention):
    def __init__(
            self,
            dim,
            num_heads=8,
            qkv_bias=True,
            rope=None,
            qk_norm=True,
            **block_kwargs,
    ):
        super().__init__(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm, **block_kwargs)
        self.rope = rope

    def forward(self, x, position_ids=None, attention_mask=None):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.rope is not None:
            q, k = self.rope(q, k, position_ids)

        x = torch.nn.functional.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attention_mask,
            dropout_p=self.attn_drop.p
        )
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class SwiGLUFFN(nn.Module):
    def __init__(
            self,
            in_features: int,
            hidden_features,
            bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = in_features
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x):
        x12 = self.w12(x)
        x1, x2 = x12.chunk(2, dim=-1)
        hidden = F.silu(x1) * x2
        return self.w3(hidden)


class ACMDMTransBlock(nn.Module):
    def __init__(self, hidden_size, num_heads, mlp_size=1024, rope=None, qk_norm=True):
        super().__init__()
        self.norm1 = LlamaRMSNorm(hidden_size, eps=1e-6)
        self.attn = ACMDMAttention(hidden_size, num_heads=num_heads, qkv_bias=True, norm_layer=LlamaRMSNorm,
                                   qk_norm=qk_norm, rope=rope)
        self.norm2 = LlamaRMSNorm(hidden_size, eps=1e-6)
        self.mlp = SwiGLUFFN(hidden_size, int(2 / 3 * mlp_size))
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c, attention_mask=None, position_ids=None):
        dtype = x.dtype
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
        norm_x1 = self.norm1(x.to(torch.float32)).to(dtype)
        attn_input_x = modulate(norm_x1, shift_msa, scale_msa)
        attn_output_x = self.attn(attn_input_x, attention_mask=attention_mask, position_ids=position_ids)
        x = x + gate_msa * attn_output_x

        norm_x2 = self.norm2(x.to(torch.float32)).to(dtype)
        gate_input_x = modulate(norm_x2, shift_mlp, scale_mlp)
        gate_output_x = self.mlp(gate_input_x)
        x = x + gate_mlp * gate_output_x
        return x


class FinalLayer(nn.Module):
    def __init__(self, hidden_size, output_size, patch_size=(1, 22), stride_size=(1, 22), patches=1):
        super().__init__()
        self.norm_final = LlamaRMSNorm(hidden_size, eps=1e-6)
        self.patch_size = patch_size
        self.stride_size = stride_size
        self.patches = patches
        self.linear = nn.Linear(hidden_size, output_size*patch_size[0]*patch_size[1], bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        norm_x = self.norm_final(x.to(torch.float32)).to(x.dtype)
        x = modulate(norm_x, shift, scale)
        x = self.linear(x)
        x = x.reshape(shape=(x.shape[0], x.shape[1]//self.patches, self.patches, self.patch_size[0], self.patch_size[1], x.shape[-1] // self.patch_size[1]))
        x = torch.einsum('nljpqc->nclpjq', x)
        x = x.reshape(shape=(x.shape[0], x.shape[1], -1, 22))
        return x


class TimestepEmbedder(nn.Module):
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000, dtype=torch.float32):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=dtype) / half
        ).to(device=t.device, dtype=dtype)
        args = t[:, None] * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t, dtype=torch.bfloat16):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size, dtype=dtype)
        t_emb = self.mlp(t_freq)
        return t_emb


class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return (self.weight * hidden_states).to(input_dtype)
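
A quick shape check (a sketch, not repo code) for FinalLayer's un-patchify above, with the default patch_size=(1, 22) and patches=1: each of the N tokens is projected to output_size * 22 values and folded back into one frame of 22 joints. The sizes below are illustrative assumptions.

import torch

B, N, hidden, out_ch = 2, 49, 512, 4                     # illustrative sizes
layer = FinalLayer(hidden, out_ch)                       # defaults: patch_size=(1, 22), patches=1
x = torch.randn(B, N, hidden)                            # token features
c = torch.randn(B, 1, hidden)                            # adaLN conditioning, broadcast over tokens
y = layer(x, c)
assert y.shape == (B, out_ch, N, 22)                     # (B, channels, frames, joints)
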
models/AE_2D_Causal.py
ADDED
@@ -0,0 +1,245 @@
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

#################################################################################
#                                      AE                                       #
#################################################################################
class AE(nn.Module):
    def __init__(self, input_width=3, output_emb_width=4, width=512, depth=3, ch_mult=(1, 1, 1)):
        super().__init__()
        self.output_emb_width = output_emb_width
        self.encoder = Encoder(input_width, output_emb_width, width, depth, in_ch_mult=ch_mult[:-1], ch_mult=ch_mult[1:])
        self.decoder = Decoder(input_width, output_emb_width, width, depth, in_ch_mult=ch_mult[::-1][1:], ch_mult=ch_mult[::-1][:-1])

    def preprocess(self, x):
        x = x.permute(0, 3, 1, 2).float()
        return x

    def encode(self, x):
        x_in = self.preprocess(x)
        x_encoder = self.encoder(x_in)
        return x_encoder

    def forward(self, x):
        x_in = self.preprocess(x)
        x_encoder = self.encoder(x_in)
        x_out = self.decoder(x_encoder)
        return x_out

    def decode(self, x):
        x_out = self.decoder(x)
        return x_out

#################################################################################
#                                      VAE                                      #
#################################################################################
class VAE(nn.Module):
    def __init__(self, input_width=3, output_emb_width=4, width=512, depth=3, ch_mult=(1, 1, 1)):
        super().__init__()
        self.output_emb_width = output_emb_width
        self.encoder = Encoder(input_width, output_emb_width*2, width, depth, in_ch_mult=ch_mult[:-1], ch_mult=ch_mult[1:])
        self.decoder = Decoder(input_width, output_emb_width, width, depth, in_ch_mult=ch_mult[::-1][1:], ch_mult=ch_mult[::-1][:-1])

    def preprocess(self, x):
        x = x.permute(0, 3, 1, 2).float()
        return x

    def encode(self, x):
        x_in = self.preprocess(x)
        x_encoder = self.encoder(x_in)
        x_encoder = DiagonalGaussianDistribution(x_encoder)
        x_encoder = x_encoder.sample()
        return x_encoder

    def forward(self, x, need_loss=False):
        x_in = self.preprocess(x)
        x_encoder = self.encoder(x_in)
        x_encoder = DiagonalGaussianDistribution(x_encoder)
        kl_loss = x_encoder.kl()
        x_encoder = x_encoder.sample()
        x_out = self.decoder(x_encoder)
        if need_loss:
            # sigma vae for better quality
            log_sigma = ((x - x_out) ** 2).mean([1, 2, 3], keepdim=True).sqrt().log()
            log_sigma = -6 + F.softplus(log_sigma - (-6))
            rec = 0.5 * torch.pow((x - x_out) / log_sigma.exp(), 2) + log_sigma
            rec = rec.sum(dim=(1, 2, 3))
            loss = {
                "rec": rec.mean(),
                "kl": kl_loss.mean()}
            return x_out, loss
        else:
            return x_out

    def decode(self, x):
        x_out = self.decoder(x)
        return x_out

#################################################################################
#                                    AE Zoos                                    #
#################################################################################
def ae(**kwargs):
    return AE(output_emb_width=4, width=512, depth=3, ch_mult=(1, 1, 1), **kwargs)
def vae(**kwargs):
    return VAE(output_emb_width=4, width=512, depth=3, ch_mult=(1, 1, 1), **kwargs)
AE_models = {
    'AE_Model': ae, 'VAE_Model': vae
}
#################################################################################
#                             Inner Architectures                               #
#################################################################################
class Encoder(nn.Module):
    def __init__(self, input_emb_width=3, output_emb_width=4, width=512, depth=3, in_ch_mult=(1, 1), ch_mult=(1, 1)):
        super().__init__()
        self.model = nn.ModuleList()
        self.conv_in = nn.Conv2d(input_emb_width, width, (3, 1), (1, 1), (0, 0))

        block_in = width * in_ch_mult[0]
        for i in range(len(in_ch_mult)):
            block_in = width * in_ch_mult[i]
            block_out = width * ch_mult[i]
            self.model.append(CausalPad2d((0, 0, 2, 0)))
            self.model.append(nn.Conv2d(width, width, (4, 1), (2, 1), (0, 0)))
            for j in range(depth):
                self.model.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dil=2-j))
                block_in = block_out

        self.conv_out = torch.nn.Conv2d(block_in, output_emb_width, (3, 1), (1, 1), (0, 0))

    def forward(self, x):
        x = F.pad(x, (0, 0, 2, 0))
        x = self.conv_in(x)
        for layer in self.model:
            x = layer(x)
        x = F.pad(x, (0, 0, 2, 0))
        x = self.conv_out(x)
        return x


class Decoder(nn.Module):
    def __init__(self, input_emb_width=3, output_emb_width=4, width=512, depth=3, in_ch_mult=(1, 1), ch_mult=(1, 1)):
        super().__init__()
        self.model = nn.ModuleList()
        block_in = width * ch_mult[0]
        self.conv_in = nn.Conv2d(output_emb_width, block_in, (3, 1), (1, 1), (0, 0))

        for i in range(len(in_ch_mult)):
            block_in = width * ch_mult[i]
            block_out = width * in_ch_mult[i]
            for j in range(depth):
                self.model.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dil=2-j))
                block_in = block_out
            self.model.append(Upsample(block_in))

        self.conv_out1 = torch.nn.Conv2d(block_in, block_in, (3, 1), (1, 1), (0, 0))
        self.conv_out2 = torch.nn.Conv2d(block_in, input_emb_width, (3, 1), (1, 1), (0, 0))

    def forward(self, x):
        x = F.pad(x, (0, 0, 2, 0))
        x = self.conv_in(x)
        for layer in self.model:
            x = layer(x)
        x = F.pad(x, (0, 0, 2, 0))
        x = self.conv_out1(x)
        x = x * torch.sigmoid(x)
        x = F.pad(x, (0, 0, 2, 0))
        x = self.conv_out2(x)
        return x.permute(0, 2, 3, 1)


class Upsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, in_channels, (3, 1), (1, 1), (0, 0))

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=(2.0, 1.0), mode="nearest")
        x = F.pad(x, (0, 0, 2, 0))
        x = self.conv(x)
        return x


class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, dil=0, conv_shortcut=False, dropout=0.2):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.padd = CausalPad2d((0, 0, 2*(3 ** dil), 0))

        self.conv1 = torch.nn.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=(3, 1),
                                     stride=(1, 1),
                                     padding=(0, 0),
                                     dilation=(3 ** dil, 1),
                                     )
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels,
                                     out_channels,
                                     kernel_size=(1, 1),
                                     stride=(1, 1),
                                     padding=(0, 0),
                                     )

    def forward(self, x):
        h = x
        h = h*torch.sigmoid(h)
        h = self.padd(h)
        h = self.conv1(h)

        h = h*torch.sigmoid(h)
        h = self.conv2(h)
        h = self.dropout(h)
        return x+h


class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

    def sample(self):
        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.])
        else:
            if other is None:
                return 0.5 * torch.sum(torch.pow(self.mean, 2)
                                       + self.var - 1.0 - self.logvar,
                                       dim=[1, 2, 3])
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
                    dim=[1, 2, 3])

    def nll(self, sample, dims=[1, 2, 3]):
        if self.deterministic:
            return torch.Tensor([0.])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        return self.mean


class CausalPad2d(nn.Module):
    def __init__(self, pad):
        super().__init__()
        self.pad = pad

    def forward(self, x):
        return F.pad(x, self.pad)
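
Because every temporal pad above is left-only (F.pad(x, (0, 0, 2, 0)) and CausalPad2d), encoder outputs at a given latent frame depend only on current and past input frames. A sketch of that property, using this file's AE_models registry (input_width=3 and the 64-frame clip are assumptions for illustration):

import torch

ae = AE_models['AE_Model'](input_width=3).eval()
x = torch.randn(1, 64, 22, 3)                            # (B, frames, joints, channels), as preprocess() expects
x2 = x.clone()
x2[:, 32:] = torch.randn_like(x2[:, 32:])                # perturb only the future frames
with torch.no_grad():
    z, z2 = ae.encode(x), ae.encode(x2)                  # two stride-2 convs: (B, 4, 16, 22)
# latent frame i sees input frames <= 4*i + 3, so the first 8 latent frames match
assert torch.allclose(z[:, :, :8], z2[:, :, :8])
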
models/AE_2D_NonCausal.py
ADDED
@@ -0,0 +1,228 @@
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
#################################################################################
|
| 7 |
+
# AE #
|
| 8 |
+
#################################################################################
|
| 9 |
+
class AE(nn.Module):
|
| 10 |
+
def __init__(self, input_width=3, output_emb_width=4, width=512, depth=3, ch_mult=(1,1,1)):
|
| 11 |
+
super().__init__()
|
| 12 |
+
self.output_emb_width = output_emb_width
|
| 13 |
+
self.encoder = Encoder(input_width, output_emb_width, width, depth, in_ch_mult=ch_mult[:-1], ch_mult=ch_mult[1:])
|
| 14 |
+
self.decoder = Decoder(input_width, output_emb_width, width, depth, in_ch_mult=ch_mult[::-1][1:], ch_mult=ch_mult[::-1][:-1])
|
| 15 |
+
|
| 16 |
+
def preprocess(self, x):
|
| 17 |
+
x = x.permute(0, 3, 1, 2).float()
|
| 18 |
+
return x
|
| 19 |
+
|
| 20 |
+
def encode(self, x):
|
| 21 |
+
x_in = self.preprocess(x)
|
| 22 |
+
x_encoder = self.encoder(x_in)
|
| 23 |
+
return x_encoder
|
| 24 |
+
|
| 25 |
+
def forward(self, x):
|
| 26 |
+
x_in = self.preprocess(x)
|
| 27 |
+
x_encoder = self.encoder(x_in)
|
| 28 |
+
x_out = self.decoder(x_encoder)
|
| 29 |
+
return x_out
|
| 30 |
+
|
| 31 |
+
def decode(self, x):
|
| 32 |
+
x_out = self.decoder(x)
|
| 33 |
+
return x_out
|
| 34 |
+
|
| 35 |
+
#################################################################################
|
| 36 |
+
# VAE #
|
| 37 |
+
#################################################################################
|
| 38 |
+
class VAE(nn.Module):
|
| 39 |
+
def __init__(self, input_width=3, output_emb_width=4, width=512, depth=3, ch_mult=(1,1,1)):
|
| 40 |
+
super().__init__()
|
| 41 |
+
self.output_emb_width = output_emb_width
|
| 42 |
+
self.encoder = Encoder(input_width, output_emb_width*2, width, depth, in_ch_mult=ch_mult[:-1], ch_mult=ch_mult[1:])
|
| 43 |
+
self.decoder = Decoder(input_width, output_emb_width, width, depth, in_ch_mult=ch_mult[::-1][1:], ch_mult=ch_mult[::-1][:-1])
|
| 44 |
+
|
| 45 |
+
def preprocess(self, x):
|
| 46 |
+
x = x.permute(0, 3, 1, 2).float()
|
| 47 |
+
return x
|
| 48 |
+
|
| 49 |
+
def encode(self, x):
|
| 50 |
+
x_in = self.preprocess(x)
|
| 51 |
+
x_encoder = self.encoder(x_in)
|
| 52 |
+
x_encoder = DiagonalGaussianDistribution(x_encoder)
|
| 53 |
+
x_encoder = x_encoder.sample()
|
| 54 |
+
return x_encoder
|
| 55 |
+
|
| 56 |
+
def forward(self, x, need_loss=False):
|
| 57 |
+
x_in = self.preprocess(x)
|
| 58 |
+
x_encoder = self.encoder(x_in)
|
| 59 |
+
x_encoder = DiagonalGaussianDistribution(x_encoder)
|
| 60 |
+
kl_loss = x_encoder.kl()
|
| 61 |
+
x_encoder = x_encoder.sample()
|
| 62 |
+
x_out = self.decoder(x_encoder)
|
| 63 |
+
if need_loss:
|
| 64 |
+
# sigma vae for better quality
|
| 65 |
+
log_sigma = ((x - x_out) ** 2).mean([1,2,3], keepdim=True).sqrt().log()
|
| 66 |
+
            log_sigma = -6 + F.softplus(log_sigma - (-6))  # soft lower bound on log sigma at -6
            rec = 0.5 * torch.pow((x - x_out) / log_sigma.exp(), 2) + log_sigma
            rec = rec.sum(dim=(1, 2, 3))
            loss = {
                "rec": rec.mean(),
                "kl": kl_loss.mean()}
            return x_out, loss
        else:
            return x_out

    def decode(self, x):
        x_out = self.decoder(x)
        return x_out

#################################################################################
#                                   AE Zoos                                     #
#################################################################################
def ae(**kwargs):
    return AE(output_emb_width=4, width=512, depth=3, ch_mult=(1,1,1), **kwargs)

def vae(**kwargs):
    return VAE(output_emb_width=4, width=512, depth=3, ch_mult=(1,1,1), **kwargs)

AE_models = {
    'AE_Model': ae, 'VAE_Model': vae
}

#################################################################################
#                             Inner Architectures                               #
#################################################################################
class Encoder(nn.Module):
    def __init__(self, input_emb_width=3, output_emb_width=4, width=512, depth=3, in_ch_mult=(1,1), ch_mult=(1,1)):
        super().__init__()
        self.model = nn.ModuleList()
        self.conv_in = nn.Conv2d(input_emb_width, width, (3, 1), (1, 1), (1, 1))

        block_in = width * in_ch_mult[0]
        for i in range(len(in_ch_mult)):
            block_in = width * in_ch_mult[i]
            block_out = width * ch_mult[i]
            self.model.append(nn.Conv2d(width, width, (4, 1), (2, 1), (1, 1)))
            for j in range(depth):
                self.model.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dil=2 - j))
                block_in = block_out

        self.conv_out = torch.nn.Conv2d(block_in, output_emb_width, (3, 1), (1, 1), (1, 1))

    def forward(self, x):
        x = self.conv_in(x)
        for layer in self.model:
            x = layer(x)
        x = self.conv_out(x)
        return x


class Decoder(nn.Module):
    def __init__(self, input_emb_width=3, output_emb_width=4, width=512, depth=3, in_ch_mult=(1,1), ch_mult=(1,1)):
        super().__init__()
        self.model = nn.ModuleList()
        block_in = width * ch_mult[0]
        self.conv_in = nn.Conv2d(output_emb_width, block_in, (3, 1), (1, 1), (1, 1))

        for i in range(len(in_ch_mult)):
            block_in = width * ch_mult[i]
            block_out = width * in_ch_mult[i]
            for j in range(depth):
                self.model.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dil=2 - j))
                block_in = block_out
            self.model.append(Upsample(block_in))

        self.conv_out1 = torch.nn.Conv2d(block_in, block_in, (3, 1), (1, 1), (1, 1))
        self.conv_out2 = torch.nn.Conv2d(block_in, input_emb_width, (3, 1), (1, 1), (1, 1))

    def forward(self, x):
        x = self.conv_in(x)
        for layer in self.model:
            x = layer(x)
        x = self.conv_out1(x)
        x = x * torch.sigmoid(x)  # SiLU / swish activation
        x = self.conv_out2(x)
        return x.permute(0, 2, 3, 1)


class Upsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, in_channels, (3, 1), (1, 1), (1, 1))

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=(2.0, 1.0), mode="nearest")
        x = self.conv(x)
        return x


class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, dil=0, conv_shortcut=False, dropout=0.2):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.conv1 = torch.nn.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=(3, 1),
                                     stride=(1, 1),
                                     padding=(3 ** dil, 0),
                                     dilation=(3 ** dil, 1),
                                     )
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels,
                                     out_channels,
                                     kernel_size=(1, 1),
                                     stride=(1, 1),
                                     padding=(0, 0),
                                     )

    def forward(self, x):
        h = x
        h = h * torch.sigmoid(h)
        h = self.conv1(h)

        h = h * torch.sigmoid(h)
        h = self.conv2(h)
        h = self.dropout(h)
        return x + h


class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

    def sample(self):
        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.])
        else:
            if other is None:
                return 0.5 * torch.sum(torch.pow(self.mean, 2)
                                       + self.var - 1.0 - self.logvar,
                                       dim=[1, 2, 3])
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
                    dim=[1, 2, 3])

    def nll(self, sample, dims=[1, 2, 3]):
        # NOTE: relies on numpy being imported as np at the top of this module
        if self.deterministic:
            return torch.Tensor([0.])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        return self.mean
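For reference, the two terms assembled into the loss dict above have standard closed forms: the reconstruction term is a per-element Gaussian negative log-likelihood with a learned, softplus-floored log-sigma (additive constants dropped), and kl() with other=None is the closed-form KL divergence from the diagonal Gaussian to a standard normal:

\mathcal{L}_{\mathrm{rec}} = \sum \Big[ \tfrac{1}{2}\big((x - \hat{x})\, e^{-\log\sigma}\big)^{2} + \log\sigma \Big],
\qquad
\mathrm{KL}\big(\mathcal{N}(\mu,\sigma^{2}) \,\|\, \mathcal{N}(0,1)\big) = \tfrac{1}{2} \sum \big(\mu^{2} + \sigma^{2} - 1 - \log\sigma^{2}\big),

which matches the code once you substitute logvar = log(sigma^2).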
models/AE_Mesh.py ADDED
@@ -0,0 +1,601 @@
# A modified version of "Fully Convolutional Mesh Autoencoder using Efficient Spatially Varying Kernels"
# https://arxiv.org/abs/2006.04325
# and thanks to this more modern implementation as well
# https://github.com/g-fiche/Mesh-VQ-VAE
# https://arxiv.org/abs/2312.08291
import torch
import torch.nn as nn
import numpy as np
import os

#################################################################################
#                                      AE                                       #
#################################################################################
class AE(nn.Module):
    def __init__(self, model, bs=16, num_vertices=6890):
        super().__init__()
        # currently the only supported setup is SMPL-H
        self.num_vertices = num_vertices
        self.bs = bs
        self.encoder = Encoder(model)
        self.decoder = Decoder(model)

    def encode(self, x):
        B, L = x.shape[0], x.shape[1]
        x = x.view(B * L, self.num_vertices, 3)
        x_encoder = self.encoder(x)
        return x_encoder

    def forward(self, x):
        B, L = x.shape[0], x.shape[1]
        x = x.view(B * L, self.num_vertices, 3)
        x_encoder = self.encoder(x)
        x_out = self.decoder(x_encoder)
        x_out = x_out.view(B, L, self.num_vertices, 3)
        return x_out

    def decode(self, x):
        # pad the sequence up to a multiple of self.bs so the decoder always sees
        # fixed-size chunks, then trim the output back to the true length T
        T = x.shape[1]
        if x.shape[1] % self.bs != 0:
            x = torch.cat([x, torch.zeros_like(x[:, :self.bs - x.shape[1] % self.bs])], dim=1)
        outputs = []
        for i in range(x.shape[0]):
            outputss = []
            for j in range(0, x.shape[1], self.bs):
                chunk = x[i, j:j + self.bs]
                out = self.decoder(chunk)
                outputss.append(out)
            outputs.append(torch.cat(outputss, dim=0)[:T])
        x_out = torch.stack(outputs, dim=0)

        return x_out

#################################################################################
#                                   AE Zoos                                     #
#################################################################################
def ae(**kwargs):
    config_model = {"batch": 16,
                    "connection_folder": "body_models/ConnectionMatrices/",
                    "initial_connection_fn": "body_models/ConnectionMatrices/_pool0.npy",
                    "connection_layer_lst": ["pool0", "pool1", "pool2", "pool3", "pool4", "pool5", "pool6", "pool7_28",
                                             "unpool7_28", "unpool6", "unpool5", "unpool4", "unpool3", "unpool2",
                                             "unpool1", "unpool0"],
                    "channel_lst": [64, 64, 128, 128, 256, 256, 512, 12, 512, 256, 256, 128, 128, 64, 64, 3],
                    "weight_num_lst": [9, 0, 9, 0, 9, 0, 9, 0, 0, 9, 0, 9, 0, 9, 0, 9],
                    "residual_rate_lst": [0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0],
                    }
    return AE(FullyConvAE(config_model, **kwargs), bs=config_model["batch"])

AE_models = {
    'AE_Model': ae
}


class Encoder(nn.Module):
    def __init__(self, model):
        super(Encoder, self).__init__()
        self.model = model

    def forward(self, x):
        out = self.model.forward_till_layer_n(x, len(self.model.channel_lst) // 2)
        return out


class Decoder(nn.Module):
    def __init__(self, model):
        super(Decoder, self).__init__()
        self.model = model

    def forward(self, x):
        out = self.model.forward_from_layer_n(x, len(self.model.channel_lst) // 2)
        return out

class FullyConvAE(nn.Module):
    def __init__(
        self, config_model=None, test_mode=False
    ):  # layer_info_lst = [(point_num, feature_dim)]
        super(FullyConvAE, self).__init__()

        self.test_mode = test_mode

        self.channel_lst = config_model["channel_lst"]

        self.residual_rate_lst = config_model["residual_rate_lst"]

        self.weight_num_lst = config_model["weight_num_lst"]

        self.initial_connection_fn = config_model["initial_connection_fn"]

        data = np.load(self.initial_connection_fn)
        neighbor_id_dist_lstlst = data[:, 1:]  # point_num*(1+2*neighbor_num)
        self.point_num = data.shape[0]
        self.neighbor_id_lstlst = neighbor_id_dist_lstlst.reshape(
            (self.point_num, -1, 2)
        )[
            :, :, 0
        ]  # point_num*neighbor_num
        self.neighbor_num_lst = np.array(data[:, 0])  # point_num

        self.relu = nn.ELU()

        self.batch = config_model["batch"]

        ##### For Laplace computation ######
        self.initial_neighbor_id_lstlst = torch.LongTensor(
            self.neighbor_id_lstlst
        ).cuda()  # point_num*max_neighbor_num
        self.initial_neighbor_num_lst = torch.FloatTensor(
            self.neighbor_num_lst
        ).cuda()  # point_num

        self.connection_folder = config_model["connection_folder"]
        self.connection_layer_fn_lst = []
        fn_lst = os.listdir(self.connection_folder)
        self.connection_layer_lst = config_model["connection_layer_lst"]
        for layer_name in self.connection_layer_lst:
            layer_name = "_" + layer_name + "."

            find_fn = False
            for fn in fn_lst:
                if (layer_name in fn) and ((".npy" in fn) or (".npz" in fn)):
                    self.connection_layer_fn_lst += [self.connection_folder + fn]
                    find_fn = True
                    break
            if find_fn == False:
                print("!!!ERROR: cannot find the connection layer fn")

        self.init_layers(self.batch)

        self.initial_max_neighbor_num = self.initial_neighbor_id_lstlst.shape[1]

    def init_layers(self, batch):
        # each entry: [in_channel, out_channel, in_pn, out_pn, weight_num,
        # max_neighbor_num, neighbor_num_lst, neighbor_id_lstlst, conv_layer,
        # residual_layer, residual_rate, neighbor_mask_lst, zeros_batch_outpn_outchannel]
        self.layer_lst = []

        self.layer_num = len(self.channel_lst)

        in_point_num = self.point_num
        in_channel = 3

        for l in range(self.layer_num):
            out_channel = self.channel_lst[l]
            weight_num = self.weight_num_lst[l]
            residual_rate = self.residual_rate_lst[l]

            connection_info = np.load(self.connection_layer_fn_lst[l])
            out_point_num = connection_info.shape[0]
            neighbor_num_lst = torch.FloatTensor(
                connection_info[:, 0].astype(float)
            ).cuda()  # out_point_num*1
            neighbor_id_dist_lstlst = connection_info[
                :, 1:
            ]  # out_point_num*(max_neighbor_num*2)
            print(self.connection_layer_fn_lst[l])
            print()
            neighbor_id_lstlst = neighbor_id_dist_lstlst.reshape(
                (out_point_num, -1, 2)
            )[
                :, :, 0
            ]  # out_point_num*max_neighbor_num
            neighbor_id_lstlst = torch.LongTensor(neighbor_id_lstlst).cuda()
            max_neighbor_num = neighbor_id_lstlst.shape[1]
            avg_neighbor_num = round(neighbor_num_lst.mean().item())
            effective_w_weights_rate = neighbor_num_lst.sum() / float(
                max_neighbor_num * out_point_num
            )
            effective_w_weights_rate = round(effective_w_weights_rate.item(), 3)

            pc_mask = torch.ones(in_point_num + 1).cuda()
            pc_mask[in_point_num] = 0
            neighbor_mask_lst = pc_mask[
                neighbor_id_lstlst
            ].contiguous()  # out_pn*max_neighbor_num, 1 for a real neighbor, otherwise 0

            zeros_batch_outpn_outchannel = torch.zeros(
                (batch, out_point_num, out_channel)
            ).cuda()

            if (residual_rate < 0) or (residual_rate > 1):
                print("Invalid residual rate", residual_rate)

            #### parameters for the convolution ####
            conv_layer = ""

            if residual_rate < 1:
                weights = torch.randn(weight_num, out_channel * in_channel).cuda()
                weights = nn.Parameter(weights).cuda()
                self.register_parameter("weights" + str(l), weights)

                bias = nn.Parameter(torch.zeros(out_channel).cuda())
                self.register_parameter("bias" + str(l), bias)

                w_weights = torch.randn(out_point_num, max_neighbor_num, weight_num) / (
                    avg_neighbor_num * weight_num
                )
                w_weights = nn.Parameter(w_weights.cuda())
                self.register_parameter("w_weights" + str(l), w_weights)

                conv_layer = (weights, bias, w_weights)

            #### parameters for the residual branch ####
            ## a residual layer with out_point_num == in_point_num and residual_rate == 1
            ## is a pooling or unpooling layer
            residual_layer = ""

            if residual_rate > 0:
                p_neighbors = ""
                weight_res = ""

                if out_point_num != in_point_num:
                    p_neighbors = nn.Parameter(
                        (
                            torch.randn(out_point_num, max_neighbor_num)
                            / (avg_neighbor_num)
                        ).cuda()
                    )
                    self.register_parameter("p_neighbors" + str(l), p_neighbors)

                if out_channel != in_channel:
                    weight_res = torch.randn(out_channel, in_channel)
                    # self.normalize_weights(weight_res)
                    weight_res = weight_res / out_channel
                    weight_res = nn.Parameter(weight_res.cuda())
                    self.register_parameter("weight_res" + str(l), weight_res)

                residual_layer = (weight_res, p_neighbors)

            ##### put everything together
            layer = (
                in_channel,
                out_channel,
                in_point_num,
                out_point_num,
                weight_num,
                max_neighbor_num,
                neighbor_num_lst,
                neighbor_id_lstlst,
                conv_layer,
                residual_layer,
                residual_rate,
                neighbor_mask_lst,
                zeros_batch_outpn_outchannel,
            )

            self.layer_lst += [layer]

            in_point_num = out_point_num
            in_channel = out_channel

    # precompute the parameters so as to accelerate forwarding in testing mode
    def init_test_mode(self):
        for l in range(len(self.layer_lst)):
            layer_info = self.layer_lst[l]

            (in_channel, out_channel, in_pn, out_pn, weight_num, max_neighbor_num,
             neighbor_num_lst, neighbor_id_lstlst, conv_layer, residual_layer,
             residual_rate, neighbor_mask_lst, zeros_batch_outpn_outchannel) = layer_info

            if len(conv_layer) != 0:
                (
                    weights,
                    bias,
                    raw_w_weights,
                ) = conv_layer  # weight_num*(out_channel*in_channel), out_point_num*max_neighbor_num*weight_num

                w_weights = raw_w_weights * neighbor_mask_lst.view(
                    out_pn, max_neighbor_num, 1
                ).repeat(
                    1, 1, weight_num
                )  # out_pn*max_neighbor_num*weight_num

                weights = torch.einsum(
                    "pmw,wc->pmc", [w_weights, weights]
                )  # out_pn*max_neighbor_num*(out_channel*in_channel)
                weights = weights.view(
                    out_pn, max_neighbor_num, out_channel, in_channel
                )

                conv_layer = weights, bias

            #### precompute the residual layer ####
            if len(residual_layer) != 0:
                (
                    weight_res,
                    p_neighbors_raw,
                ) = residual_layer  # out_channel*in_channel, out_pn*max_neighbor_num
                if in_pn != out_pn:
                    p_neighbors = torch.abs(p_neighbors_raw) * neighbor_mask_lst
                    p_neighbors_sum = p_neighbors.sum(1) + 1e-8  # out_pn
                    p_neighbors = p_neighbors / p_neighbors_sum.view(out_pn, 1).repeat(
                        1, max_neighbor_num
                    )

                    residual_layer = weight_res, p_neighbors

            self.layer_lst[l] = (
                in_channel, out_channel, in_pn, out_pn, weight_num, max_neighbor_num,
                neighbor_num_lst, neighbor_id_lstlst, conv_layer, residual_layer,
                residual_rate, neighbor_mask_lst, zeros_batch_outpn_outchannel,
            )

    # a faster mode for testing
    # input_pc: batch*in_pn*in_channel
    # out_pc: batch*out_pn*out_channel
    def forward_one_conv_layer_batch_during_test(
        self, in_pc, layer_info, is_final_layer=False
    ):
        batch = in_pc.shape[0]

        (in_channel, out_channel, in_pn, out_pn, weight_num, max_neighbor_num,
         neighbor_num_lst, neighbor_id_lstlst, conv_layer, residual_layer,
         residual_rate, neighbor_mask_lst, zeros_batch_outpn_outchannel) = layer_info

        device = in_pc.get_device()
        if device < 0:
            device = "cpu"

        in_pc_pad = torch.cat(
            (in_pc, torch.zeros(batch, 1, in_channel).to(device)), 1
        )  # batch*(in_pn+1)*in_channel

        in_neighbors = in_pc_pad[
            :, neighbor_id_lstlst.to(device)
        ]  # batch*out_pn*max_neighbor_num*in_channel

        #### compute the output of the convolution layer ####
        out_pc_conv = zeros_batch_outpn_outchannel.clone()

        if len(conv_layer) != 0:
            (
                weights,
                bias,
            ) = conv_layer  # precomputed in init_test_mode

            out_neighbors = torch.einsum(
                "pmoi,bpmi->bpmo", [weights.to(device), in_neighbors]
            )  # batch*out_pn*max_neighbor_num*out_channel

            out_pc_conv = out_neighbors.sum(2)

            out_pc_conv = out_pc_conv + bias

            if is_final_layer == False:
                out_pc_conv = self.relu(
                    out_pc_conv
                )  ## self.relu is defined in the init function

        # if(self.residual_rate==0):
        #     return out_pc
        #### compute the output of the residual layer ####
        out_pc_res = zeros_batch_outpn_outchannel.clone()

        if len(residual_layer) != 0:
            (
                weight_res,
                p_neighbors,
            ) = residual_layer  # out_channel*in_channel, out_pn*max_neighbor_num

            if in_channel != out_channel:
                in_pc_pad = torch.einsum("oi,bpi->bpo", [weight_res, in_pc_pad])

            out_pc_res = []
            if in_pn == out_pn:
                out_pc_res = in_pc_pad[:, 0:in_pn].clone()
            else:
                in_neighbors = in_pc_pad[
                    :, neighbor_id_lstlst.to(device)
                ]  # batch*out_pn*max_neighbor_num*out_channel
                out_pc_res = torch.einsum(
                    "pm,bpmo->bpo", [p_neighbors.to(device), in_neighbors]
                )

        out_pc = out_pc_conv.to(device) * np.sqrt(1 - residual_rate) + out_pc_res.to(
            device
        ) * np.sqrt(residual_rate)

        return out_pc

    # used in train mode; slower than test mode
    # input_pc: batch*in_pn*in_channel
    # out_pc: batch*out_pn*out_channel
    def forward_one_conv_layer_batch(self, in_pc, layer_info, is_final_layer=False):
        batch = in_pc.shape[0]

        (in_channel, out_channel, in_pn, out_pn, weight_num, max_neighbor_num,
         neighbor_num_lst, neighbor_id_lstlst, conv_layer, residual_layer,
         residual_rate, neighbor_mask_lst, zeros_batch_outpn_outchannel) = layer_info

        in_pc_pad = torch.cat(
            (in_pc, torch.zeros(batch, 1, in_channel).cuda()), 1
        )  # batch*(in_pn+1)*in_channel

        in_neighbors = in_pc_pad[
            :, neighbor_id_lstlst
        ]  # batch*out_pn*max_neighbor_num*in_channel

        #### compute the output of the convolution layer ####
        out_pc_conv = zeros_batch_outpn_outchannel.clone()

        if len(conv_layer) != 0:
            (
                weights,
                bias,
                raw_w_weights,
            ) = conv_layer  # weight_num*(out_channel*in_channel), out_point_num*max_neighbor_num*weight_num

            w_weights = raw_w_weights * neighbor_mask_lst.view(
                out_pn, max_neighbor_num, 1
            ).repeat(
                1, 1, weight_num
            )  # out_pn*max_neighbor_num*weight_num

            weights = torch.einsum(
                "pmw,wc->pmc", [w_weights, weights]
            )  # out_pn*max_neighbor_num*(out_channel*in_channel)
            weights = weights.view(out_pn, max_neighbor_num, out_channel, in_channel)

            out_neighbors = torch.einsum(
                "pmoi,bpmi->bpmo", [weights, in_neighbors]
            )  # batch*out_pn*max_neighbor_num*out_channel

            out_pc_conv = out_neighbors.sum(2)

            out_pc_conv = out_pc_conv + bias

            if is_final_layer == False:
                out_pc_conv = self.relu(
                    out_pc_conv
                )  ## self.relu is defined in the init function

        #### compute the output of the residual layer ####
        out_pc_res = zeros_batch_outpn_outchannel.clone()

        if len(residual_layer) != 0:
            (
                weight_res,
                p_neighbors_raw,
            ) = residual_layer  # out_channel*in_channel, out_pn*max_neighbor_num

            if in_channel != out_channel:
                in_pc_pad = torch.einsum("oi,bpi->bpo", [weight_res, in_pc_pad])

            out_pc_res = []
            if in_pn == out_pn:
                out_pc_res = in_pc_pad[:, 0:in_pn].clone()
            else:
                in_neighbors = in_pc_pad[
                    :, neighbor_id_lstlst
                ]  # batch*out_pn*max_neighbor_num*out_channel

                p_neighbors = torch.abs(p_neighbors_raw) * neighbor_mask_lst
                p_neighbors_sum = p_neighbors.sum(1) + 1e-8  # out_pn
                p_neighbors = p_neighbors / p_neighbors_sum.view(out_pn, 1).repeat(
                    1, max_neighbor_num
                )

                out_pc_res = torch.einsum("pm,bpmo->bpo", [p_neighbors, in_neighbors])

        # print(out_pc_conv.shape, out_pc_res.shape)
        out_pc = out_pc_conv * np.sqrt(1 - residual_rate) + out_pc_res * np.sqrt(
            residual_rate
        )

        return out_pc

    def forward_till_layer_n(self, in_pc, layer_n):
        out_pc = in_pc.clone()

        for i in range(layer_n):
            if self.test_mode == False:
                out_pc = self.forward_one_conv_layer_batch(out_pc, self.layer_lst[i])
            else:
                out_pc = self.forward_one_conv_layer_batch_during_test(
                    out_pc, self.layer_lst[i]
                )

        # out_pc = self.final_linear(out_pc.transpose(1,2)).transpose(1,2)  # batch*3*point_num

        return out_pc

    def forward_from_layer_n(self, in_pc, layer_n):
        out_pc = in_pc.clone()

        for i in range(layer_n, self.layer_num):
            if i < (self.layer_num - 1):
                if self.test_mode == False:
                    out_pc = self.forward_one_conv_layer_batch(
                        out_pc, self.layer_lst[i]
                    )
                else:
                    out_pc = self.forward_one_conv_layer_batch_during_test(
                        out_pc, self.layer_lst[i]
                    )
            else:
                if self.test_mode == False:
                    out_pc = self.forward_one_conv_layer_batch(
                        out_pc, self.layer_lst[i], is_final_layer=True
                    )
                else:
                    out_pc = self.forward_one_conv_layer_batch_during_test(
                        out_pc, self.layer_lst[i], is_final_layer=True
                    )

        return out_pc

    def forward_layer_n(self, in_pc, layer_n):
        out_pc = in_pc.clone()

        if layer_n < (self.layer_num - 1):
            if self.test_mode == False:
                out_pc = self.forward_one_conv_layer_batch(
                    out_pc, self.layer_lst[layer_n]
                )
            else:
                out_pc = self.forward_one_conv_layer_batch_during_test(
                    out_pc, self.layer_lst[layer_n]
                )
        else:
            if self.test_mode == False:
                out_pc = self.forward_one_conv_layer_batch(
                    out_pc, self.layer_lst[layer_n], is_final_layer=True
                )
            else:
                out_pc = self.forward_one_conv_layer_batch_during_test(
                    out_pc, self.layer_lst[layer_n], is_final_layer=True
                )

        return out_pc
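As a quick orientation, here is a hedged usage sketch of the mesh autoencoder. It assumes the precomputed connection matrices exist under body_models/ConnectionMatrices/ and that a CUDA device is available (FullyConvAE moves its buffers to CUDA in __init__). Note the per-layer zero buffers are sized to the configured batch of 16, so each encoder/decoder call should see exactly B*L == 16 frames; AE.decode chunks longer sequences for this reason. Shapes in the comments are inferred from the config above, not guaranteed:

import torch
from models.AE_Mesh import AE_models

mesh_ae = AE_models['AE_Model']().eval()          # builds FullyConvAE from the config above (requires CUDA)
verts = torch.randn(1, 16, 6890, 3).cuda()        # (batch, frames, SMPL-H vertices, xyz); 1*16 matches the chunk size
latents = mesh_ae.encode(verts)                   # per-frame latents at the pool7_28 / 12-channel bottleneck
recon = mesh_ae(verts)                            # (1, 16, 6890, 3) reconstruction
decoded = mesh_ae.decode(latents.unsqueeze(0))    # chunked decode back to vertices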
models/LengthEstimator.py ADDED
@@ -0,0 +1,40 @@
import torch.nn as nn

#################################################################################
#                               Length Estimator                                #
#################################################################################
class LengthEstimator(nn.Module):
    def __init__(self, input_size, output_size):
        super(LengthEstimator, self).__init__()
        nd = 512
        self.output = nn.Sequential(
            nn.Linear(input_size, nd),
            nn.LayerNorm(nd),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Dropout(0.2),
            nn.Linear(nd, nd // 2),
            nn.LayerNorm(nd // 2),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Dropout(0.2),
            nn.Linear(nd // 2, nd // 4),
            nn.LayerNorm(nd // 4),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(nd // 4, output_size)
        )

        self.output.apply(self.__init_weights)

    def __init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, text_emb):
        return self.output(text_emb)
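A minimal sketch of how this head is typically driven. The 512-d input (matching a CLIP text embedding) and the 50-way output over candidate lengths are illustrative assumptions, not values fixed by this file:

import torch
from models.LengthEstimator import LengthEstimator

est = LengthEstimator(input_size=512, output_size=50)  # sizes are illustrative
text_emb = torch.randn(8, 512)                         # e.g. CLIP text features
length_logits = est(text_emb)                          # (8, 50) logits over candidate lengths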
models/ROPE.py ADDED
@@ -0,0 +1,91 @@
import torch
import math

class RopeND:
    def __init__(self, head_dim=64, nd=3, max_lens=[1024, 64, 64], nd_split=[2, 1, 1], bases=[1000, 1000, 1000],
                 auto_base=True, cache_longer=1):
        self.nd = nd
        self.head_dim = head_dim
        self.max_lens = max_lens
        self.nd_split = nd_split
        self.split_dims = [2 * i * (head_dim // 2 // sum(nd_split)) for i in nd_split]
        assert sum(self.split_dims) == head_dim
        self.auto_base = auto_base
        if auto_base:
            # empirical: make cos(theta) = -1 when the length is kL, i.e. base = kL/pi,
            # and at L=1 the phase difference (1/base)**(1/32) ~ 0.7-0.8 ~ pi/4.
            # For the traditional L = 4096, 8L/pi = 10.4k, so base is conventionally set to ~10k.
            self.bases = [(int(8 * l / math.pi) // 100 + 1) * 100 for l in self.max_lens]
            print(f"Bases for rope: {self.bases}")
        else:
            self.bases = bases
        self.cache_longer = cache_longer

    def generated_cos_sin_mix2d(self, max_len, dim, device, base=1000):
        inv_freq = 1.0 / (base ** \
                          (torch.linspace(start=0, end=self.head_dim, steps=dim // 2,
                                          device=device).float() / self.head_dim))
        assert inv_freq.size(0) * 2 == dim, f"inv_freq.size(0) = {inv_freq.size(0)}, required dim = {dim}"

        t = torch.arange(max_len * self.cache_longer, device=device).type_as(inv_freq)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        freqs = torch.cat([freqs, freqs], dim=1)
        return freqs.cos().to(torch.float), freqs.sin().to(torch.float)

    def generate_pos_embs_mix2d(self, position_ids, device=None):
        if device is None:
            device = position_ids.device

        if position_ids.dim() == 1:
            position_ids = position_ids.unsqueeze(0)

        cos_emb_all, sin_emb_all = [], []
        for i in range(self.nd):
            dim_i = self.split_dims[i]
            base_i = self.bases[i]
            max_len_i = self.max_lens[i]
            if not hasattr(self, f"cos_{i}"):
                _cos, _sin = self.generated_cos_sin_mix2d(max_len=max_len_i, dim=dim_i, device=device, base=base_i)
                setattr(self, f"cos_{i}", _cos)
                setattr(self, f"sin_{i}", _sin)
            cos_emb_all.append(getattr(self, f'cos_{i}')[position_ids[i, :], :])
            sin_emb_all.append(getattr(self, f'sin_{i}')[position_ids[i, :], :])
        cos_emb = torch.cat(cos_emb_all, dim=-1)
        sin_emb = torch.cat(sin_emb_all, dim=-1)
        return cos_emb, sin_emb

    def __call__(self, q, k, position_ids):
        '''q: N N_head L C
        '''
        cos_emb, sin_emb = self.generate_pos_embs_mix2d(position_ids, device=q.device)

        def rotate_half(x):
            """Rotates half the hidden dims of the input."""
            x1 = x[..., : x.shape[-1] // 2]
            x2 = x[..., x.shape[-1] // 2:]
            return torch.cat((-x2, x1), dim=-1)

        def apply_rotary_pos_emb(q, k, cos, sin):
            """Applies Rotary Position Embedding to the query and key tensors.

            Args:
                q (`torch.Tensor`): The query tensor.
                k (`torch.Tensor`): The key tensor.
                cos (`torch.Tensor`): The cosine part of the rotary embedding.
                sin (`torch.Tensor`): The sine part of the rotary embedding.
            Returns:
                `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
            """
            cos = cos.unsqueeze(0).unsqueeze(0)
            sin = sin.unsqueeze(0).unsqueeze(0)
            dtype = q.dtype
            q = q.to(torch.float)
            k = k.to(torch.float)
            q_embed = (q * cos) + (rotate_half(q) * sin)
            k_embed = (k * cos) + (rotate_half(k) * sin)
            q_embed = q_embed.to(dtype)
            k_embed = k_embed.to(dtype)
            return q_embed, k_embed

        q, k = apply_rotary_pos_emb(q, k, cos_emb, sin_emb)
        return q, k
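A small, self-contained sketch of calling RopeND inside an attention block; the 7x7 spatial grid and the position layout are illustrative assumptions. Since generate_pos_embs_mix2d indexes position_ids per axis, the ids must carry one row per dimension (nd rows):

import torch
from models.ROPE import RopeND

rope = RopeND(head_dim=64, nd=3, max_lens=[1024, 64, 64], nd_split=[2, 1, 1])
B, H, L, C = 2, 8, 49, 64                       # batch, heads, tokens, head_dim
q = torch.randn(B, H, L, C)
k = torch.randn(B, H, L, C)
t = torch.arange(L)
position_ids = torch.stack([t, t % 7, t // 7])  # (nd, L): a time axis plus a 7x7 grid
q, k = rope(q, k, position_ids)                 # rotated q/k, same shapes as the inputs

With nd_split=[2, 1, 1] the 64-dim head is split 32/16/16 across the three axes, and the cos/sin tables are cached lazily on first use (printing the auto-chosen bases once).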
models/__pycache__/ACMDM.cpython-310.pyc ADDED
Binary file (14.9 kB)

models/__pycache__/ACMDM.cpython-313.pyc ADDED
Binary file (28.7 kB)

models/__pycache__/AE_2D_Causal.cpython-310.pyc ADDED
Binary file (8.63 kB)

models/__pycache__/AE_2D_Causal.cpython-313.pyc ADDED
Binary file (15.7 kB)

models/__pycache__/LengthEstimator.cpython-310.pyc ADDED
Binary file (1.44 kB)

models/__pycache__/ROPE.cpython-310.pyc ADDED
Binary file (3.85 kB)