sourxbhh committed
Commit
0f34fb9
1 Parent(s): 1355f5d

Add model directory

models/ACMDM.py ADDED
@@ -0,0 +1,437 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import clip
5
+ import math
6
+ from functools import partial
7
+ from timm.models.vision_transformer import Attention
8
+ from models.ROPE import RopeND
9
+ from utils.eval_utils import eval_decorator
10
+ from utils.train_utils import lengths_to_mask
11
+ from diffusions.diffusion import create_diffusion
12
+ from diffusions.transport import create_transport, Sampler
13
+
14
+ #################################################################################
15
+ # ACMDM #
16
+ #################################################################################
17
+ class ACMDM(nn.Module):
18
+ def __init__(self, input_dim, cond_mode, latent_dim=256, ff_size=1024, num_layers=8,
19
+ num_heads=4, dropout=0, clip_dim=512,
20
+ diff_model='Flow', cond_drop_prob=0.1, max_length=49,
21
+ patch_size=(1, 22), stride_size=(1, 22), num_joint=22,
22
+ clip_version='ViT-B/32', **kargs):
23
+ super(ACMDM, self).__init__()
24
+
25
+ self.input_dim = input_dim
26
+ self.latent_dim = latent_dim
27
+ self.clip_dim = clip_dim
28
+ self.dropout = dropout
29
+
30
+ self.cond_mode = cond_mode
31
+ self.cond_drop_prob = cond_drop_prob
32
+
33
+ if self.cond_mode == 'action':
34
+ assert 'num_actions' in kargs
35
+ self.num_actions = kargs.get('num_actions', 1)
36
+ self.encode_action = partial(F.one_hot, num_classes=self.num_actions)
37
+ # --------------------------------------------------------------------------
38
+ # Diffusion
39
+ self.diff_model = diff_model
40
+ if self.diff_model == 'Flow':
41
+ self.train_diffusion = create_transport() # default to linear, velocity prediction
42
+ self.gen_diffusion = Sampler(self.train_diffusion)
43
+ else:
44
+ self.train_diffusion = create_diffusion(timestep_respacing="", noise_schedule="linear")
45
+ self.gen_diffusion = create_diffusion(timestep_respacing="", noise_schedule="linear")
46
+ # --------------------------------------------------------------------------
47
+ # ACMDM
48
+ print('Loading ACMDM...')
49
+ self.t_embedder = TimestepEmbedder(self.latent_dim)
50
+ self.patch_size = patch_size
51
+ self.stride_size = stride_size
52
+ self.patches_per_frame = (num_joint - patch_size[1]) // stride_size[1] + 1
53
+
54
+ # Patchification
55
+ self.x_embedder = nn.Conv2d(self.input_dim, self.latent_dim, kernel_size=self.patch_size, stride=self.stride_size, bias=True)
56
+
57
+ # Positional Encoding
58
+ max_length = max_length * self.patches_per_frame
59
+ self.max_lens = [max_length]
60
+ self.rope = RopeND(nd=1, nd_split=[1], max_lens=self.max_lens)
61
+ self.position_ids_precompute = torch.arange(max_length).unsqueeze(0)
62
+
63
+ self.ACMDMTransformer = nn.ModuleList([
64
+ ACMDMTransBlock(self.latent_dim, num_heads, mlp_size=ff_size, rope=self.rope, qk_norm=True) for _ in range(num_layers)
65
+ ])
66
+
67
+ if self.cond_mode == 'text':
68
+ self.y_embedder = nn.Linear(self.clip_dim, self.latent_dim)
69
+ elif self.cond_mode == 'action':
70
+ self.y_embedder = nn.Linear(self.num_actions, self.latent_dim)
71
+ elif self.cond_mode == 'uncond':
72
+ self.y_embedder = nn.Identity()
73
+ else:
74
+ raise KeyError("Unsupported condition mode!!!")
75
+
76
+ self.final_layer = FinalLayer(self.latent_dim, self.input_dim, patch_size=patch_size, stride_size=stride_size, patches=self.patches_per_frame, joint=num_joint)
77
+
78
+ self.initialize_weights()
79
+
80
+ if self.cond_mode == 'text':
81
+ print('Loading CLIP...')
82
+ self.clip_version = clip_version
83
+ self.clip_model = self.load_and_freeze_clip(clip_version)
84
+
85
+ def initialize_weights(self):
86
+ # Initialize transformer layers:
87
+ def _basic_init(module):
88
+ if isinstance(module, nn.Linear):
89
+ torch.nn.init.xavier_uniform_(module.weight)
90
+ if module.bias is not None:
91
+ nn.init.constant_(module.bias, 0)
92
+
93
+ self.apply(_basic_init)
94
+
95
+ # Initialize timestep embedding MLP:
96
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
97
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
98
+
99
+ # Zero-out adaLN modulation layers in ACMDM blocks:
100
+ for block in self.ACMDMTransformer:
101
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
102
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
103
+
104
+ # Zero-out output layers:
105
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
106
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
107
+ nn.init.constant_(self.final_layer.linear.weight, 0)
108
+ nn.init.constant_(self.final_layer.linear.bias, 0)
109
+
110
+ def load_and_freeze_clip(self, clip_version):
111
+ clip_model, clip_preprocess = clip.load(clip_version, device='cpu', jit=False)
112
+ assert torch.cuda.is_available()
113
+ clip.model.convert_weights(clip_model)
114
+
115
+ clip_model.eval()
116
+ for p in clip_model.parameters():
117
+ p.requires_grad = False
118
+ return clip_model
119
+
120
+ def encode_text(self, raw_text):
121
+ device = next(self.parameters()).device
122
+ text = clip.tokenize(raw_text, truncate=True).to(device)
123
+ feat_clip_text = self.clip_model.encode_text(text).float()
124
+ return feat_clip_text
125
+
126
+ def mask_cond(self, cond, force_mask=False):
127
+ bs, d = cond.shape
128
+ if force_mask:
129
+ return torch.zeros_like(cond)
130
+ elif self.training and self.cond_drop_prob > 0.:
131
+ mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_drop_prob).view(bs, 1)
132
+ return cond * (1. - mask)
133
+ else:
134
+ return cond
135
+
136
+ def forward(self, x, t, conds, attention_mask, force_mask=False):
137
+ t = self.t_embedder(t, dtype=x.dtype)
138
+ conds = self.mask_cond(conds, force_mask=force_mask)
139
+ x = self.x_embedder(x)
140
+ x = x.flatten(2).transpose(1, 2)
141
+ conds = self.y_embedder(conds)
142
+ y = t.unsqueeze(1) + conds.unsqueeze(1)
143
+ position_ids = self.position_ids_precompute[:, :x.shape[1]]
144
+ for block in self.ACMDMTransformer:
145
+ x = block(x, y, attention_mask, position_ids=position_ids)
146
+ x = self.final_layer(x, y)
147
+ return x
148
+
149
+ def forward_with_CFG(self, x, t, conds, attention_mask, cfg=1.0):
150
+ if not cfg == 1.0:
151
+ half = x[: len(x) // 2]
152
+ x = torch.cat([half, half], dim=0)
153
+ x = self.forward(x, t, conds, attention_mask)
154
+ if not cfg == 1.0:
155
+ cond_eps, uncond_eps = torch.split(x, len(x) // 2, dim=0)
156
+ half_eps = uncond_eps + cfg * (cond_eps - uncond_eps)
157
+ x = torch.cat([half_eps, half_eps], dim=0)
158
+ return x
159
+
160
+ def forward_loss(self, latents, y, m_lens):
161
+ latents = latents.permute(0, 2, 3, 1)
162
+ b, l, j, d = latents.shape
163
+ device = latents.device
164
+
165
+ non_pad_mask = lengths_to_mask(m_lens, l)
166
+ latents = torch.where(non_pad_mask.unsqueeze(-1).unsqueeze(-1), latents, torch.zeros_like(latents))
167
+
168
+ target = latents.clone().permute(0, 3, 1, 2).detach()
169
+
170
+ force_mask = False
171
+ if self.cond_mode == 'text':
172
+ with torch.no_grad():
173
+ cond_vector = self.encode_text(y)
174
+ elif self.cond_mode == 'action':
175
+ cond_vector = self.encode_action(y).to(device).float()
176
+ elif self.cond_mode == 'uncond':
177
+ cond_vector = torch.zeros(b, self.latent_dim).float().to(device)
178
+ force_mask = True
179
+ else:
180
+ raise NotImplementedError("Unsupported condition mode!!!")
181
+
182
+ attention_mask = non_pad_mask.unsqueeze(-1).repeat(1, 1, self.patches_per_frame).flatten(1).unsqueeze(1).unsqueeze(1)
183
+
184
+ model_kwargs = dict(conds=cond_vector, force_mask=force_mask, attention_mask=attention_mask)
185
+ if self.diff_model == "Flow":
186
+ loss_dict = self.train_diffusion.training_losses(self.forward, target, model_kwargs)
187
+ else:
188
+ t = torch.randint(0, self.train_diffusion.num_timesteps, (target.shape[0],), device=target.device)
189
+ loss_dict = self.train_diffusion.training_losses(self.forward, target, t, model_kwargs)
190
+ loss = loss_dict["loss"]
191
+ loss = (loss * non_pad_mask).sum() / non_pad_mask.sum()
192
+
193
+ return loss
194
+
195
+ @torch.no_grad()
196
+ @eval_decorator
197
+ def generate(self,
198
+ conds,
199
+ m_lens,
200
+ cond_scale: int,
201
+ temperature=1,
202
+ j=22,
203
+ ):
204
+ device = next(self.parameters()).device
205
+ l = max(m_lens)
206
+ b = len(m_lens)
207
+
208
+ if self.cond_mode == 'text':
209
+ with torch.no_grad():
210
+ cond_vector = self.encode_text(conds)
211
+ elif self.cond_mode == 'action':
212
+ cond_vector = self.encode_action(conds).to(device)
213
+ elif self.cond_mode == 'uncond':
214
+ cond_vector = torch.zeros(b, self.latent_dim).float().to(device)
215
+ else:
216
+ raise NotImplementedError("Unsupported condition mode!!!")
217
+
218
+ padding_mask = ~lengths_to_mask(m_lens, l)
219
+
220
+ noise = torch.randn(b, self.input_dim, l, j).to(device)
221
+ if not cond_scale == 1.0:
222
+ cond_vector = torch.cat([cond_vector, torch.zeros_like(cond_vector)], dim=0)
223
+ noise = torch.cat([noise, noise], dim=0)
224
+
225
+ attention_mask = (~padding_mask).unsqueeze(-1).repeat(1,1,self.patches_per_frame).flatten(1).unsqueeze(1).unsqueeze(1)
226
+ model_kwargs = dict(conds=cond_vector, attention_mask=attention_mask, cfg=cond_scale)
227
+ sample_fn = self.forward_with_CFG
228
+
229
+ if not cond_scale == 1:
230
+ model_kwargs["attention_mask"] = attention_mask.repeat(2, 1, 1, 1)
231
+
232
+ if self.diff_model == "Flow":
233
+ model_fn = self.gen_diffusion.sample_ode() # default to ode sampling
234
+ sampled_token_latent = model_fn(noise, sample_fn, **model_kwargs)[-1]
235
+ else:
236
+ sampled_token_latent = self.gen_diffusion.p_sample_loop(
237
+ sample_fn, noise.shape, noise, clip_denoised=False, model_kwargs=model_kwargs,
238
+ progress=False,
239
+ temperature=temperature
240
+ )
241
+ if not cond_scale == 1:
242
+ sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0)
243
+ sampled_token_latent = sampled_token_latent.permute(0,2,3,1)
244
+
245
+ latents = torch.where(padding_mask.unsqueeze(-1).unsqueeze(-1), torch.zeros_like(sampled_token_latent), sampled_token_latent)
246
+ return latents.permute(0,3,1,2)
247
+
248
+ #################################################################################
249
+ # ACMDM Zoos #
250
+ #################################################################################
251
+ def acmdm_raw_flow_s_ps22(**kwargs):
252
+ layer = 8
253
+ return ACMDM(latent_dim=layer*64, ff_size=layer*64*4, num_layers=layer, num_heads=layer, dropout=0, clip_dim=512,
254
+ diff_model="Flow", cond_drop_prob=0.1, max_length=196,
255
+ patch_size=(1, 22), stride_size=(1, 22), **kwargs)
256
+ def acmdm_flow_s_ps22(**kwargs):
257
+ layer = 8
258
+ return ACMDM(latent_dim=layer*64, ff_size=layer*64*4, num_layers=layer, num_heads=layer, dropout=0, clip_dim=512,
259
+ diff_model="Flow", cond_drop_prob=0.1, max_length=49,
260
+ patch_size=(1, 22), stride_size=(1, 22), **kwargs)
261
+ def acmdm_flow_xl_ps2(**kwargs):
262
+ layer = 20
263
+ return ACMDM(latent_dim=layer*64, ff_size=layer*64*4, num_layers=layer, num_heads=layer, dropout=0, clip_dim=512,
264
+ diff_model="Flow", cond_drop_prob=0.1, max_length=49,
265
+ patch_size=(1, 2), stride_size=(1, 2), **kwargs)
266
+ def acmdm_mesh_flow_s_ps28(**kwargs):
267
+ layer = 8
268
+ return ACMDM(latent_dim=layer*64, ff_size=layer*64*4, num_layers=layer, num_heads=layer, dropout=0, clip_dim=512,
269
+ diff_model="Flow", cond_drop_prob=0.1, max_length=196, num_joint=28,
270
+ patch_size=(1, 28), stride_size=(1, 28), **kwargs)
271
+ ACMDM_models = {
272
+ 'ACMDM-Raw-Flow-S-PatchSize22': acmdm_raw_flow_s_ps22, 'ACMDM-Flow-S-PatchSize22': acmdm_flow_s_ps22,
273
+ 'ACMDM-Flow-XL-PatchSize2': acmdm_flow_xl_ps2, 'ACMDM-Mesh-Flow-S-PatchSize28': acmdm_mesh_flow_s_ps28,
274
+ }
275
+
276
+ #################################################################################
277
+ # Inner Architectures #
278
+ #################################################################################
279
+ def modulate(x, shift, scale):
280
+ return x * (1 + scale) + shift
281
+
282
+
283
+ class ACMDMAttention(Attention):
284
+ def __init__(
285
+ self,
286
+ dim,
287
+ num_heads=8,
288
+ qkv_bias=True,
289
+ rope=None,
290
+ qk_norm=True,
291
+ **block_kwargs,
292
+ ):
293
+ super().__init__(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm, **block_kwargs)
294
+ self.rope = rope
295
+
296
+ def forward(self, x, position_ids=None, attention_mask=None):
297
+ B, N, C = x.shape
298
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
299
+ q, k, v = qkv.unbind(0)
300
+ q, k = self.q_norm(q), self.k_norm(k)
301
+
302
+ if self.rope is not None:
303
+ q, k = self.rope(q, k, position_ids)
304
+
305
+ x = torch.nn.functional.scaled_dot_product_attention(
306
+ q, k, v,
307
+ attn_mask=attention_mask,
308
+ dropout_p=self.attn_drop.p
309
+ )
310
+ x = x.transpose(1, 2).reshape(B, N, C)
311
+ x = self.proj(x)
312
+ x = self.proj_drop(x)
313
+ return x
314
+
315
+
316
+ class SwiGLUFFN(nn.Module):
317
+ def __init__(
318
+ self,
319
+ in_features: int,
320
+ hidden_features,
321
+ bias: bool = True,
322
+ ) -> None:
323
+ super().__init__()
324
+ out_features = in_features
325
+ hidden_features = hidden_features
326
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
327
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
328
+
329
+ def forward(self, x):
330
+ x12 = self.w12(x)
331
+ x1, x2 = x12.chunk(2, dim=-1)
332
+ hidden = F.silu(x1) * x2
333
+ return self.w3(hidden)
334
+
335
+
336
+ class ACMDMTransBlock(nn.Module):
337
+ def __init__(self, hidden_size, num_heads, mlp_size=1024, rope=None, qk_norm=True):
338
+ super().__init__()
339
+ self.norm1 = LlamaRMSNorm(hidden_size, eps=1e-6)
340
+ self.attn = ACMDMAttention(hidden_size, num_heads=num_heads, qkv_bias=True, norm_layer=LlamaRMSNorm,
341
+ qk_norm=qk_norm, rope=rope)
342
+ self.norm2 = LlamaRMSNorm(hidden_size, eps=1e-6)
343
+ self.mlp = SwiGLUFFN(hidden_size, int(2 / 3 * mlp_size))
344
+ self.adaLN_modulation = nn.Sequential(
345
+ nn.SiLU(),
346
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
347
+ )
348
+
349
+ def forward(self, x, c, attention_mask=None, position_ids=None):
350
+ dtype = x.dtype
351
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
352
+ norm_x1 = self.norm1(x.to(torch.float32)).to(dtype)
353
+ attn_input_x = modulate(norm_x1, shift_msa, scale_msa)
354
+ attn_output_x = self.attn(attn_input_x, attention_mask=attention_mask, position_ids=position_ids)
355
+ x = x + gate_msa * attn_output_x
356
+
357
+ norm_x2 = self.norm2(x.to(torch.float32)).to(dtype)
358
+ gate_input_x = modulate(norm_x2, shift_mlp, scale_mlp)
359
+ gate_output_x = self.mlp(gate_input_x)
360
+ x = x + gate_mlp * gate_output_x
361
+ return x
362
+
363
+
364
+ class FinalLayer(nn.Module):
365
+ def __init__(self, hidden_size, output_size, patch_size=(1, 22), stride_size=(1,22), patches=1, joint=22):
366
+ super().__init__()
367
+ self.norm_final = LlamaRMSNorm(hidden_size, eps=1e-6)
368
+ self.patch_size = patch_size
369
+ self.stride_size = stride_size
370
+ self.patches = patches
371
+ self.joint=joint
372
+ self.linear = nn.Linear(hidden_size, output_size*patch_size[0]*patch_size[1], bias=True)
373
+ self.adaLN_modulation = nn.Sequential(
374
+ nn.SiLU(),
375
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True)
376
+ )
377
+
378
+ def forward(self, x, c):
379
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
380
+ norm_x = self.norm_final(x.to(torch.float32)).to(x.dtype)
381
+ x = modulate(norm_x, shift, scale)
382
+ x = self.linear(x)
383
+ x = x.reshape(shape=(x.shape[0], x.shape[1]//self.patches, self.patches, self.patch_size[0], self.patch_size[1], x.shape[-1] // self.patch_size[1]))
384
+ x = torch.einsum('nljpqc->nclpjq', x)
385
+ x = x.reshape(shape=(x.shape[0], x.shape[1], -1, self.joint))
386
+ return x
387
+
388
+
389
+ class TimestepEmbedder(nn.Module):
390
+ def __init__(self, hidden_size, frequency_embedding_size=256):
391
+ super().__init__()
392
+ self.mlp = nn.Sequential(
393
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
394
+ nn.SiLU(),
395
+ nn.Linear(hidden_size, hidden_size, bias=True),
396
+ )
397
+ self.frequency_embedding_size = frequency_embedding_size
398
+
399
+ @staticmethod
400
+ def timestep_embedding(t, dim, max_period=10000, dtype=torch.float32):
401
+ """
402
+ Create sinusoidal timestep embeddings.
403
+ :param t: a 1-D Tensor of N indices, one per batch element.
404
+ These may be fractional.
405
+ :param dim: the dimension of the output.
406
+ :param max_period: controls the minimum frequency of the embeddings.
407
+ :return: an (N, D) Tensor of positional embeddings.
408
+ """
409
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
410
+ half = dim // 2
411
+ freqs = torch.exp(
412
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=dtype) / half
413
+ ).to(device=t.device, dtype=dtype)
414
+ args = t[:, None] * freqs[None]
415
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
416
+ if dim % 2:
417
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
418
+ return embedding
419
+
420
+ def forward(self, t, dtype=torch.bfloat16):
421
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size, dtype=dtype)
422
+ t_emb = self.mlp(t_freq)
423
+ return t_emb
424
+
425
+
426
+ class LlamaRMSNorm(nn.Module):
427
+ def __init__(self, hidden_size, eps=1e-6):
428
+ super().__init__()
429
+ self.weight = nn.Parameter(torch.ones(hidden_size))
430
+ self.variance_epsilon = eps
431
+
432
+ def forward(self, hidden_states):
433
+ input_dtype = hidden_states.dtype
434
+ hidden_states = hidden_states.to(torch.float32)
435
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
436
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
437
+ return (self.weight * hidden_states).to(input_dtype)
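A minimal, hypothetical usage sketch for the model zoo defined above (not part of the committed diff). It assumes 4-channel motion latents (input_dim=4), a CUDA device with the CLIP ViT-B/32 weights downloadable, and that the utils/ and diffusions/ packages referenced by the imports are on the path; the caption, sequence length, and guidance scale below are placeholder values.

# Hypothetical usage sketch: text-conditioned sampling from a registered ACMDM variant.
import torch
from models.ACMDM import ACMDM_models

# input_dim=4 is an assumed latent channel count; cond_mode='text' enables CLIP conditioning.
model = ACMDM_models['ACMDM-Flow-S-PatchSize22'](input_dim=4, cond_mode='text').cuda().eval()

captions = ['a person walks forward and waves']      # placeholder prompt
m_lens = torch.tensor([49]).cuda()                    # latent-space lengths, one per caption
with torch.no_grad():
    latents = model.generate(captions, m_lens, cond_scale=4.5)
# latents has shape (B, input_dim, L, 22), zeroed beyond each sequence's length.
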
models/ACMDM_ControlNet.py ADDED
@@ -0,0 +1,314 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from models.ACMDM import ACMDM
4
+ from models.ACMDM import TimestepEmbedder, ACMDMTransBlock, LlamaRMSNorm
5
+ from models.ROPE import RopeND
6
+ from utils.eval_utils import eval_decorator
7
+ from utils.train_utils import lengths_to_mask
8
+
9
+
10
+ #################################################################################
11
+ # ACMDM+ControlNet #
12
+ #################################################################################
13
+ class ACMDM_ControlNet(ACMDM):
14
+ def __init__(self, input_dim, cond_mode, base_checkpoint, latent_dim=256, ff_size=1024, num_layers=8,
15
+ num_heads=4, dropout=0.2, clip_dim=512,
16
+ diff_model='Flow', cond_drop_prob=0.1, max_length=49,
17
+ patch_size=(1, 22), stride_size=(1, 22),
18
+ clip_version='ViT-B/32', freeze_base=True, need_base=True, **kargs):
19
+ # --------------------------------------------------------------------------
20
+ # ACMDM
21
+ super().__init__(input_dim, cond_mode, latent_dim=latent_dim, ff_size=ff_size, num_layers=num_layers,
22
+ num_heads=num_heads, dropout=dropout, clip_dim=clip_dim,
23
+ diff_model=diff_model, cond_drop_prob=cond_drop_prob, max_length=max_length,
24
+ patch_size=patch_size, stride_size=stride_size,
25
+ clip_version=clip_version, **kargs)
26
+
27
+ # --------------------------------------------------------------------------
28
+ # ControlNet
29
+ self.c_t_embedder = TimestepEmbedder(self.latent_dim)
30
+ self.c_control_embedder = c_control_embedder(3, self.latent_dim, patch_size=self.patch_size,
31
+ stride_size=self.stride_size)
32
+ self.c_x_embedder = nn.Conv2d(self.input_dim, self.latent_dim, kernel_size=self.patch_size,
33
+ stride=self.stride_size, bias=True)
34
+ self.c_y_embedder = nn.Linear(self.clip_dim, self.latent_dim)
35
+ self.c_rope = RopeND(nd=1, nd_split=[1], max_lens=self.max_lens)
36
+ self.ControlNet = nn.ModuleList([
37
+ ACMDMTransBlock(self.latent_dim, num_heads, mlp_size=ff_size, rope=self.c_rope, qk_norm=True) for _ in
38
+ range(num_layers)
39
+ ])
40
+ self.zero_Linear = nn.ModuleList([
41
+ nn.Linear(self.latent_dim, self.latent_dim) for _ in range(num_layers)
42
+ ])
43
+ self.initialize_weights_control()
44
+ if need_base:
45
+ for key, value in list(base_checkpoint['ema_acmdm'].items()):
46
+ if key.startswith('ACMDMTransformer.'):
47
+ new_key = key.replace('ACMDMTransformer.', 'ControlNet.')
48
+ base_checkpoint['ema_acmdm'][new_key] = value.clone()
49
+ missing_keys, unexpected_keys = self.load_state_dict(base_checkpoint['ema_acmdm'], strict=False)
50
+ assert len(unexpected_keys) == 0
51
+
52
+ if self.cond_mode == 'text':
53
+ print('ReLoading CLIP...')
54
+ self.clip_version = clip_version
55
+ self.clip_model = self.load_and_freeze_clip(clip_version)
56
+
57
+ if freeze_base:
58
+ for param in self.t_embedder.parameters():
59
+ param.requires_grad = False
60
+ for param in self.x_embedder.parameters():
61
+ param.requires_grad = False
62
+ for param in self.y_embedder.parameters():
63
+ param.requires_grad = False
64
+ for param in self.final_layer.parameters():
65
+ param.requires_grad = False
66
+ for param in self.ACMDMTransformer.parameters():
67
+ param.requires_grad = False
68
+
69
+ def initialize_weights_control(self):
70
+ # Initialize transformer layers:
71
+ def _basic_init(module):
72
+ if isinstance(module, nn.Linear):
73
+ torch.nn.init.xavier_uniform_(module.weight)
74
+ if module.bias is not None:
75
+ nn.init.constant_(module.bias, 0)
76
+
77
+ self.apply(_basic_init)
78
+
79
+ # Initialize timestep embedding MLP:
80
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
81
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
82
+
83
+ # Zero-out adaLN modulation layers in ACMDM blocks:
84
+ for block in self.ACMDMTransformer:
85
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
86
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
87
+
88
+ # Zero-out output layers:
89
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
90
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
91
+ nn.init.constant_(self.final_layer.linear.weight, 0)
92
+ nn.init.constant_(self.final_layer.linear.bias, 0)
93
+
94
+ # Initialize timestep embedding MLP:
95
+ nn.init.normal_(self.c_t_embedder.mlp[0].weight, std=0.02)
96
+ nn.init.normal_(self.c_t_embedder.mlp[2].weight, std=0.02)
97
+
98
+ # Zero-out adaLN modulation layers in ControlNet blocks:
99
+ for block in self.ControlNet:
100
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
101
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
102
+
103
+ nn.init.constant_(self.c_control_embedder.zero_linear.weight, 0)
104
+ nn.init.constant_(self.c_control_embedder.zero_linear.bias, 0)
105
+
106
+ for block in self.zero_Linear:
107
+ nn.init.constant_(block.weight, 0)
108
+ nn.init.constant_(block.bias, 0)
109
+
110
+ def forward_with_control(self, x, t, conds, attention_mask, cfg1=1.0, cfg2=1.0, control=None, index=None,
111
+ force_mask=False):
112
+ if not (cfg1 == 1.0 and cfg2 == 1.0):
113
+ half = x[: len(x) // 3]
114
+ x = torch.cat([half, half, half], dim=0)
115
+ # controlnet
116
+ c_t = self.c_t_embedder(t, dtype=x.dtype)
117
+ conds = self.mask_cond(conds, force_mask=force_mask)
118
+ c_control = self.c_control_embedder(control * index)
119
+ if self.training and self.cond_drop_prob > 0.:
120
+ mask = torch.bernoulli(torch.ones(c_control.shape[0], device=c_control.device) * self.cond_drop_prob).view(c_control.shape[0], 1, 1)
121
+ c_control = c_control * (1. - mask)
122
+ if not (cfg1 == 1.0 and cfg2 == 1.0):
123
+ c_control = torch.cat([c_control, c_control, torch.zeros_like(c_control)], dim=0)
124
+ c_x = self.c_x_embedder(x).flatten(2).transpose(1, 2)
125
+ c_y = self.c_y_embedder(conds)
126
+ c_y = c_t.unsqueeze(1) + c_y.unsqueeze(1)
127
+ c_x = c_x + c_control
128
+ c_position_ids = self.position_ids_precompute[:, :c_x.shape[1]]
129
+ c_out = []
130
+ for c_block, c_linear in zip(self.ControlNet, self.zero_Linear):
131
+ c_x = c_block(c_x, c_y, attention_mask, position_ids=c_position_ids)
132
+ c_out.append(c_linear(c_x))
133
+ # main branch
134
+ tt = self.t_embedder(t, dtype=x.dtype)
135
+ x = self.x_embedder(x)
136
+ x = x.flatten(2).transpose(1, 2)
137
+ conds = self.y_embedder(conds)
138
+ y = tt.unsqueeze(1) + conds.unsqueeze(1)
139
+ position_ids = self.position_ids_precompute[:, :x.shape[1]]
140
+ # merging
141
+ for block, c in zip(self.ACMDMTransformer, c_out):
142
+ x = block(x, y, attention_mask, position_ids=position_ids)
143
+ x = x + c
144
+ x = self.final_layer(x, y)
145
+ if not (cfg1 == 1.0 and cfg2 == 1.0):
146
+ cond_eps, uncond_eps1, uncond_eps2 = torch.split(x, len(x) // 3, dim=0)
147
+ half_eps = cond_eps + (cfg1-1) * (cond_eps - uncond_eps1) + (cfg2-1) * (cond_eps - uncond_eps2)
148
+ x = torch.cat([half_eps, half_eps, half_eps], dim=0)
149
+ return x
150
+
151
+ def forward_control_loss(self, latents, y, m_lens, original, index, ae, mean_std):
152
+ latents = latents.permute(0, 2, 3, 1)
153
+ b, l, j, d = latents.shape
154
+ device = latents.device
155
+
156
+ non_pad_mask = lengths_to_mask(m_lens, l)
157
+ latents = torch.where(non_pad_mask.unsqueeze(-1).unsqueeze(-1), latents, torch.zeros_like(latents))
158
+
159
+ target = latents.clone().permute(0, 3, 1, 2).detach()
160
+ original = original.clone().detach()
161
+
162
+ force_mask = False
163
+ if self.cond_mode == 'text':
164
+ with torch.no_grad():
165
+ cond_vector = self.encode_text(y)
166
+ elif self.cond_mode == 'action':
167
+ cond_vector = self.enc_action(y).to(device).float()
168
+ elif self.cond_mode == 'uncond':
169
+ cond_vector = torch.zeros(b, self.latent_dim).float().to(device)
170
+ force_mask = True
171
+ else:
172
+ raise NotImplementedError("Unsupported condition mode!!!")
173
+
174
+ attention_mask = non_pad_mask.unsqueeze(-1).repeat(1, 1, self.patches_per_frame).flatten(1).unsqueeze(1).unsqueeze(1)
175
+
176
+ random_indices = torch.randint(0, len(index), (b,)).to(device)
177
+ indexx = torch.tensor(index, device=device)[random_indices]
178
+ mask_seq = torch.zeros((b, 3, l*4, j), device=device)
179
+ for i in range(b):
180
+ seq_num = torch.randint(1, m_lens[i]*4, (1,))
181
+ choose_seq = torch.sort(torch.randperm(m_lens[i]*4)[:seq_num.item()]).values
182
+ mask_seq[i, :, choose_seq, indexx[i]] = 1.0
183
+
184
+ model_kwargs = dict(conds=cond_vector, attention_mask=attention_mask, control=original, index=mask_seq,
185
+ force_mask=force_mask, mean_std=mean_std)
186
+ if self.diff_model == "Flow":
187
+ loss_dict = self.train_diffusion.training_losses(self.forward_with_control, target, ae=ae,
188
+ model_kwargs=model_kwargs)
189
+ else:
190
+ t = torch.randint(0, self.train_diffusion.num_timesteps, (target.shape[0],), device=target.device)
191
+ loss_dict = self.train_diffusion.training_losses(self.forward_with_control, target, t, model_kwargs)
192
+ loss = loss_dict["loss"]
193
+ loss = (loss * non_pad_mask).sum() / non_pad_mask.sum()
194
+
195
+ return loss, loss_dict["loss_control"]
196
+
197
+
198
+ @torch.no_grad()
199
+ @eval_decorator
200
+ def generate_control(self,
201
+ conds,
202
+ m_lens,
203
+ control,
204
+ index,
205
+ density,
206
+ cond_scale,
207
+ temperature=1,
208
+ j=22
209
+ ):
210
+ device = next(self.parameters()).device
211
+ l = control.shape[2]//4
212
+ b = len(m_lens)
213
+
214
+ if self.cond_mode == 'text':
215
+ with torch.no_grad():
216
+ cond_vector = self.encode_text(conds)
217
+ elif self.cond_mode == 'action':
218
+ cond_vector = self.enc_action(conds).to(device)
219
+ elif self.cond_mode == 'uncond':
220
+ cond_vector = torch.zeros(b, self.latent_dim).float().to(device)
221
+ else:
222
+ raise NotImplementedError("Unsupported condition mode!!!")
223
+
224
+ padding_mask = ~lengths_to_mask(m_lens, l)
225
+
226
+ noise = torch.randn(b, self.input_dim, l, j).to(device)
227
+ control = control.clone()
228
+ cfg1 = cond_scale[0]
229
+ cfg2 = cond_scale[1]
230
+ if not (cfg1 == 1.0 and cfg2 == 1.0):
231
+ # (1) with text and with control (2) no text and with control (3) with text and no control
232
+ cond_vector = torch.cat([cond_vector, torch.zeros_like(cond_vector), cond_vector], dim=0)
233
+
234
+ random_indices = torch.tensor(0, device=device).repeat(b) # no random in inference
235
+ indexx = torch.tensor(index, device=device)[random_indices]
236
+ mask_seq = torch.zeros((b, 3, l * 4, j), device=device)
237
+ for i in range(b):
238
+ if density in [1, 2, 5]:
239
+ seq_num = density
240
+ else:
241
+ seq_num = int(m_lens[i] *4* density / 100)
242
+ choose_seq = torch.sort(torch.randperm(m_lens[i] * 4)[:seq_num]).values
243
+ mask_seq[i, :, choose_seq, indexx[i]] = 1.0
244
+
245
+ attention_mask = (~padding_mask).unsqueeze(-1).repeat(1, 1, self.patches_per_frame).flatten(1).unsqueeze(1).unsqueeze(1)
246
+ model_kwargs = dict(conds=cond_vector, attention_mask=attention_mask, cfg1=cfg1, cfg2=cfg2, index=mask_seq,
247
+ control=control)
248
+ sample_fn = self.forward_with_control
249
+
250
+ if not (cfg1 == 1.0 and cfg2 == 1.0):
251
+ model_kwargs["attention_mask"] = attention_mask.repeat(3, 1, 1, 1)
252
+ noise = torch.cat([noise, noise, noise], dim=0)
253
+
254
+ if self.diff_model == "Flow":
255
+ model_fn = self.gen_diffusion.sample_ode() # default to ode sampling
256
+ sampled_token_latent = model_fn(noise, sample_fn, **model_kwargs)[-1]
257
+ else:
258
+ sampled_token_latent = self.gen_diffusion.p_sample_loop(
259
+ sample_fn, noise.shape, noise, clip_denoised=False, model_kwargs=model_kwargs,
260
+ progress=False,
261
+ temperature=temperature
262
+ )
263
+ if not (cfg1 == 1.0 and cfg2 == 1.0):
264
+ sampled_token_latent, _, _ = sampled_token_latent.chunk(3, dim=0)
265
+ sampled_token_latent = sampled_token_latent.permute(0, 2, 3, 1)
266
+
267
+ latents = torch.where(padding_mask.unsqueeze(-1).unsqueeze(-1), torch.zeros_like(sampled_token_latent),
268
+ sampled_token_latent)
269
+ return latents.permute(0, 3, 1, 2), mask_seq
270
+
271
+ #################################################################################
272
+ # ACMDM Zoos #
273
+ #################################################################################
274
+ def acmdm_raw_flow_s_ps22_control(**kwargs):
275
+ layer = 8
276
+ return ACMDM_ControlNet(latent_dim=layer*64, ff_size=layer*64*4, num_layers=layer, num_heads=layer, dropout=0, clip_dim=512,
277
+ diff_model="Flow", cond_drop_prob=0.1, max_length=49,
278
+ patch_size=(1, 22), stride_size=(1, 22), freeze_base=True, **kwargs)
279
+
280
+
281
+ ACMDM_ControlNet_Models = {
282
+ 'ACMDM-Flow-S-PatchSize22-ControlNet': acmdm_raw_flow_s_ps22_control,
283
+ }
284
+
285
+ #################################################################################
286
+ # Inner Architectures #
287
+ #################################################################################
288
+ def modulate(x, shift, scale):
289
+ return x * (1 + scale) + shift
290
+
291
+
292
+ def zero_module(module):
293
+ for p in module.parameters():
294
+ p.detach().zero_()
295
+ return module
296
+
297
+ class c_control_embedder(nn.Module):
298
+ def __init__(
299
+ self,
300
+ in_features: int,
301
+ hidden_features,
302
+ patch_size,
303
+ stride_size,
304
+ ) -> None:
305
+ super().__init__()
306
+ self.patch_embed = nn.Conv2d(in_features, hidden_features, kernel_size=(4,patch_size[1]), stride=(4,stride_size[1]), bias=True)
307
+ self.norm = LlamaRMSNorm(hidden_features, eps=1e-6)
308
+ self.zero_linear = nn.Linear(hidden_features, hidden_features)
309
+
310
+ def forward(self, x):
311
+ x = self.patch_embed(x).flatten(2).transpose(1, 2)
312
+ x = self.norm(x)
313
+ x = self.zero_linear(x)
314
+ return x
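A hypothetical construction sketch for the ControlNet variant above (not part of the committed diff). It assumes a pretrained base ACMDM checkpoint saved as a dict with an 'ema_acmdm' state dict, which is the layout __init__ expects when it copies the ACMDMTransformer weights into the ControlNet branch; the file path and input_dim are placeholders.

# Hypothetical sketch: wrapping a pretrained base ACMDM with the zero-initialized ControlNet branch.
import torch
from models.ACMDM_ControlNet import ACMDM_ControlNet_Models

ckpt = torch.load('checkpoints/acmdm_flow_s_ps22.pt', map_location='cpu')  # assumed path

control_model = ACMDM_ControlNet_Models['ACMDM-Flow-S-PatchSize22-ControlNet'](
    input_dim=4, cond_mode='text', base_checkpoint=ckpt,
)
# With freeze_base=True, only the c_* embedders, ControlNet blocks, and zero_Linear layers train.
trainable = [name for name, p in control_model.named_parameters() if p.requires_grad]
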
models/ACMDM_NoisyPrefix_AR.py ADDED
@@ -0,0 +1,556 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import clip
5
+ import math
6
+ from functools import partial
7
+ from timm.models.vision_transformer import Attention
8
+ from models.ROPE import RopeND
9
+ from utils.eval_utils import eval_decorator
10
+ from utils.train_utils import lengths_to_mask
11
+ from diffusions.diffusion import create_diffusion
12
+ from diffusions.transport import create_transport, Sampler
13
+
14
+ #################################################################################
15
+ # ACMDM #
16
+ #################################################################################
17
+ class ACMDM(nn.Module):
18
+ def __init__(self, input_dim, cond_mode, latent_dim=256, ff_size=1024, num_layers=8,
19
+ num_heads=4, dropout=0, clip_dim=512,
20
+ diff_model='Flow', cond_drop_prob=0.1, max_length=49,
21
+ patch_size=(1, 22), stride_size=(1, 22), num_joint=22, cluster=5,
22
+ clip_version='ViT-B/32', **kargs):
23
+ super(ACMDM, self).__init__()
24
+
25
+ self.input_dim = input_dim
26
+ self.latent_dim = latent_dim
27
+ self.clip_dim = clip_dim
28
+ self.dropout = dropout
29
+ self.cluster = cluster
30
+
31
+ self.cond_mode = cond_mode
32
+ self.cond_drop_prob = cond_drop_prob
33
+
34
+ if self.cond_mode == 'action':
35
+ assert 'num_actions' in kargs
36
+ self.num_actions = kargs.get('num_actions', 1)
37
+ self.encode_action = partial(F.one_hot, num_classes=self.num_actions)
38
+ # --------------------------------------------------------------------------
39
+ # Diffusion
40
+ self.diff_model = diff_model
41
+ if self.diff_model == 'Flow':
42
+ self.train_diffusion = create_transport() # default to linear, velocity prediction
43
+ self.gen_diffusion = Sampler(self.train_diffusion)
44
+ else:
45
+ self.train_diffusion = create_diffusion(timestep_respacing="", noise_schedule="linear")
46
+ self.gen_diffusion = create_diffusion(timestep_respacing="", noise_schedule="linear")
47
+ # --------------------------------------------------------------------------
48
+ # ACMDM
49
+ print('Loading ACMDM...')
50
+ self.t_embedder = TimestepEmbedder(self.latent_dim)
51
+ self.patch_size = patch_size
52
+ self.stride_size = stride_size
53
+ self.patches_per_frame = (num_joint - patch_size[1]) // stride_size[1] + 1
54
+
55
+ # Patchification
56
+ self.x_embedder = nn.Linear(self.input_dim*self.patch_size[0]*self.patch_size[1], self.latent_dim, bias=True)
57
+
58
+ # Positional Encoding
59
+ max_length = max_length * self.patches_per_frame
60
+ self.max_lens = [max_length]
61
+ self.rope = RopeND(nd=1, nd_split=[1], max_lens=self.max_lens)
62
+ self.position_ids_precompute = torch.arange(max_length).unsqueeze(0)
63
+ self.cluster_patches = max_length // self.cluster
64
+
65
+ self.ACMDMTransformer = nn.ModuleList([
66
+ ACMDMTransBlock(self.latent_dim, num_heads, mlp_size=ff_size, rope=self.rope, qk_norm=True) for _ in range(num_layers)
67
+ ])
68
+
69
+ if self.cond_mode == 'text':
70
+ self.y_embedder = nn.Linear(self.clip_dim, self.latent_dim)
71
+ elif self.cond_mode == 'action':
72
+ self.y_embedder = nn.Linear(self.num_actions, self.latent_dim)
73
+ elif self.cond_mode == 'uncond':
74
+ self.y_embedder = nn.Identity()
75
+ else:
76
+ raise KeyError("Unsupported condition mode!!!")
77
+
78
+ self.final_layer = FinalLayer(self.latent_dim, self.input_dim*self.patch_size[0]*self.patch_size[1])
79
+
80
+ self.initialize_weights()
81
+
82
+ if self.cond_mode == 'text':
83
+ print('Loading CLIP...')
84
+ self.clip_version = clip_version
85
+ self.clip_model = self.load_and_freeze_clip(clip_version)
86
+
87
+ attention_mask = []
88
+ start = 0
89
+ total_length = max_length
90
+ for idx in range(max_length):
91
+ if idx in [self.cluster_patches * i for i in range(self.cluster)]:
92
+ start += self.cluster_patches * self.patches_per_frame
93
+ attention_mask.append(torch.cat([torch.ones((1, start)),
94
+ torch.zeros((1, total_length - start))], dim=-1))
95
+ attention_mask = torch.cat(attention_mask, dim=0)
96
+ attention_mask = torch.where(attention_mask == 0, -torch.inf, attention_mask)
97
+ attention_mask = torch.where(attention_mask == 1, 0, attention_mask)
98
+ attention_mask = attention_mask.unsqueeze(0).unsqueeze(0)
99
+ self.register_buffer('attention_mask', attention_mask.contiguous())
100
+
101
+ def initialize_weights(self):
102
+ # Initialize transformer layers:
103
+ def _basic_init(module):
104
+ if isinstance(module, nn.Linear):
105
+ torch.nn.init.xavier_uniform_(module.weight)
106
+ if module.bias is not None:
107
+ nn.init.constant_(module.bias, 0)
108
+
109
+ self.apply(_basic_init)
110
+
111
+ # Initialize timestep embedding MLP:
112
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
113
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
114
+
115
+ # Zero-out adaLN modulation layers in ACMDM blocks:
116
+ for block in self.ACMDMTransformer:
117
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
118
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
119
+
120
+ # Zero-out output layers:
121
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
122
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
123
+ nn.init.constant_(self.final_layer.linear.weight, 0)
124
+ nn.init.constant_(self.final_layer.linear.bias, 0)
125
+
126
+ def load_and_freeze_clip(self, clip_version):
127
+ clip_model, clip_preprocess = clip.load(clip_version, device='cpu', jit=False)
128
+ assert torch.cuda.is_available()
129
+ clip.model.convert_weights(clip_model)
130
+
131
+ clip_model.eval()
132
+ for p in clip_model.parameters():
133
+ p.requires_grad = False
134
+ return clip_model
135
+
136
+ def encode_text(self, raw_text):
137
+ device = next(self.parameters()).device
138
+ text = clip.tokenize(raw_text, truncate=True).to(device)
139
+ feat_clip_text = self.clip_model.encode_text(text).float()
140
+ return feat_clip_text
141
+
142
+ def mask_cond(self, cond, force_mask=False):
143
+ bs, d = cond.shape
144
+ if force_mask:
145
+ return torch.zeros_like(cond)
146
+ elif self.training and self.cond_drop_prob > 0.:
147
+ mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_drop_prob).view(bs, 1)
148
+ return cond * (1. - mask)
149
+ else:
150
+ return cond
151
+
152
+ def patchify(self, x):
153
+ b, c, l, j = x.shape
154
+ p = self.patch_size[0]
155
+ q = self.patch_size[1]
156
+ l_, j_ = l // p, j // q
157
+
158
+ x = x.reshape(b, c, l_, p, j_, q)
159
+ x = torch.einsum('nclpjq->nljcpq', x)
160
+ x = x.reshape(b, l_ * j_, c * p *q)
161
+ return x
162
+
163
+ def patchify_mask(self, mask):
164
+ b, l = mask.shape
165
+ p = self.patch_size[0]
166
+ l_ = l//self.patch_size[0]
167
+ q = self.patch_size[1]
168
+ j_ = self.patches_per_frame
169
+ mask = mask.unsqueeze(1).unsqueeze(-1).expand(-1, self.input_dim, -1, j_*q)
170
+ mask = mask.reshape(b, self.input_dim, l_, p, j_, q)
171
+ mask = torch.einsum('nclpjq->nljcpq', mask)
172
+ mask = mask.reshape(b, l_ * j_, self.input_dim*p * q)
173
+ mask = mask.any(dim=-1)
174
+ return mask
175
+
176
+ def unpatchify(self, x):
177
+ b = x.shape[0]
178
+ p = self.patch_size[0]
179
+ q = self.patch_size[1]
180
+ c = self.input_dim
181
+ l_, j_ = x.shape[1]//self.patches_per_frame, self.patches_per_frame
182
+
183
+ x = x.reshape(b, l_, j_, c, p, q)
184
+ x = torch.einsum('nljcpq->nclpjq', x)
185
+ x = x.reshape(b, c, l_ * p, j_ * q)
186
+ return x
187
+
188
+ def forward(self, x, t, conds, attention_mask, force_mask=False, ids=None, block_size=None, cache=False):
189
+ t = self.t_embedder(t, dtype=x.dtype).unsqueeze(1).repeat(1, self.cluster_patches * self.patches_per_frame, 1)
190
+ t = t.chunk(self.cluster, dim=0)
191
+ t = torch.cat(t, dim=1)
192
+ conds = self.mask_cond(conds, force_mask=force_mask)
193
+ x = x.chunk(self.cluster, dim=0)
194
+ x = torch.cat(x, dim=1)
195
+ x = self.x_embedder(x)
196
+ conds = self.y_embedder(conds)
197
+ y = t + conds.unsqueeze(1)
198
+ if ids is not None:
199
+ position_ids = ids
200
+ else:
201
+ position_ids = self.position_ids_precompute[:, :x.shape[1]]
202
+ for block in self.ACMDMTransformer:
203
+ x = block(x, y, attention_mask, position_ids=position_ids, block_size=block_size, cache=cache)
204
+ x = self.final_layer(x, y)
205
+ x = x.chunk(self.cluster, dim=1)
206
+ x = torch.cat(x, dim=0)
207
+ return x
208
+
209
+ def forward_with_CFG(self, x, t, conds, attention_mask, cfg=1.0, context=None, cache=True, block_id=0):
210
+ if cache:
211
+ if self.ACMDMTransformer[0].attn.cached_k is None:
212
+ cache = True
213
+ elif block_id * self.cluster_patches == self.ACMDMTransformer[0].attn.cached_k.shape[2]:
214
+ cache = False
215
+ if not cfg == 1.0:
216
+ half = x[: len(x) // 2]
217
+ x = torch.cat([half, half], dim=0)
218
+ if context is not None and cache:
219
+ ids = self.position_ids_precompute[:, (block_id - 1) * self.cluster_patches * self.patches_per_frame:(block_id + 1) * self.cluster_patches * self.patches_per_frame]
220
+ x = torch.cat([context, x], dim=1)
221
+ t = torch.cat([torch.ones_like(t).unsqueeze(-1).repeat(1, self.patches_per_frame * self.cluster_patches),
222
+ t.unsqueeze(-1).repeat(1, self.patches_per_frame * self.cluster_patches)], dim=1)
223
+ am_idx = block_id if block_id == 0 else block_id - 1
224
+ attention_mask = attention_mask[:, :, am_idx * self.cluster_patches * self.patches_per_frame: (block_id + 1) * self.cluster_patches * self.patches_per_frame,
225
+ :(block_id + 1) * self.cluster_patches * self.patches_per_frame]
226
+ else:
227
+ ids = self.position_ids_precompute[:,
228
+ (block_id) * self.cluster_patches * self.patches_per_frame:(block_id + 1) * self.cluster_patches * self.patches_per_frame]
229
+ t = t.unsqueeze(-1).repeat(1, self.patches_per_frame * self.cluster_patches)
230
+ attention_mask = attention_mask[:, :, :(block_id + 1) * self.cluster_patches * self.patches_per_frame,
231
+ :(block_id + 1) * self.cluster_patches * self.patches_per_frame]
232
+ attention_mask = attention_mask[:, :, -self.patches_per_frame * self.cluster_patches:, :]
233
+ t = t.reshape(-1)
234
+ t = self.t_embedder(t, dtype=x.dtype)
235
+ t = t.reshape(x.shape[0], x.shape[1], -1)
236
+ conds = self.mask_cond(conds)
237
+ x = self.x_embedder(x)
238
+ conds = self.y_embedder(conds)
239
+ y = t + conds.unsqueeze(1)
240
+ position_ids = ids
241
+ for block in self.ACMDMTransformer:
242
+ x = block(x, y, attention_mask, position_ids=position_ids, block_size=self.patches_per_frame * self.cluster_patches,
243
+ cache=cache)
244
+ x = self.final_layer(x, y)
245
+ x = x[:, -self.patches_per_frame * self.cluster_patches:, :]
246
+ if not cfg == 1.0:
247
+ cond_eps, uncond_eps = torch.split(x, len(x) // 2, dim=0)
248
+ half_eps = uncond_eps + cfg * (cond_eps - uncond_eps)
249
+ x = torch.cat([half_eps, half_eps], dim=0)
250
+ return x
251
+
252
+ def forward_loss(self, latents, y, m_lens):
253
+ b, d, l, j = latents.shape
254
+ device = latents.device
255
+
256
+ non_pad_mask = lengths_to_mask(m_lens, l)
257
+ non_pad_mask = self.patchify_mask(non_pad_mask)
258
+ latents = self.patchify(latents)
259
+ b, l, d = latents.shape
260
+ latents = torch.where(non_pad_mask.unsqueeze(-1), latents, torch.zeros_like(latents))
261
+
262
+ target = latents.clone().detach().chunk(self.cluster, dim=1)
263
+ target = torch.cat(target, dim=0)
264
+
265
+ force_mask = False
266
+ if self.cond_mode == 'text':
267
+ with torch.no_grad():
268
+ cond_vector = self.encode_text(y)
269
+ elif self.cond_mode == 'action':
270
+ cond_vector = self.encode_action(y).to(device).float()
271
+ elif self.cond_mode == 'uncond':
272
+ cond_vector = torch.zeros(b, self.latent_dim).float().to(device)
273
+ force_mask = True
274
+ else:
275
+ raise NotImplementedError("Unsupported condition mode!!!")
276
+
277
+ attention_mask = []
278
+ for i in range(b):
279
+ a_mask = self.attention_mask.clone()
280
+ a_mask[:, :, :, m_lens[i] * self.patches_per_frame:] = -torch.inf
281
+ attention_mask.append(a_mask)
282
+ attention_mask = torch.cat(attention_mask)
283
+
284
+ model_kwargs = dict(conds=cond_vector, force_mask=force_mask, attention_mask=attention_mask)
285
+ if self.diff_model == "Flow":
286
+ loss_dict = self.train_diffusion.training_losses(self.forward, target, model_kwargs, dim=(2))
287
+ else:
288
+ t = torch.randint(0, self.train_diffusion.num_timesteps, (target.shape[0],), device=target.device)
289
+ loss_dict = self.train_diffusion.training_losses(self.forward, target, t, model_kwargs)
290
+ loss = loss_dict["loss"]
291
+ loss = loss.chunk(self.cluster, dim=0)
292
+ loss = torch.cat(loss, dim=1)
293
+ loss = (loss * non_pad_mask).sum() / non_pad_mask.sum()
294
+
295
+ return loss
296
+
297
+ @torch.no_grad()
298
+ @eval_decorator
299
+ def generate(self,
300
+ conds,
301
+ m_lens,
302
+ cond_scale: int,
303
+ temperature=1,
304
+ ):
305
+ device = next(self.parameters()).device
306
+ l = max(m_lens)
307
+ b = len(m_lens)
308
+
309
+ if self.cond_mode == 'text':
310
+ with torch.no_grad():
311
+ cond_vector = self.encode_text(conds)
312
+ elif self.cond_mode == 'action':
313
+ cond_vector = self.encode_action(conds).to(device)
314
+ elif self.cond_mode == 'uncond':
315
+ cond_vector = torch.zeros(b, self.latent_dim).float().to(device)
316
+ else:
317
+ raise NotImplementedError("Unsupported condition mode!!!")
318
+
319
+ padding_mask = ~lengths_to_mask(m_lens, l)
320
+ if not cond_scale == 1.0:
321
+ cond_vector = torch.cat([cond_vector, torch.zeros_like(cond_vector)], dim=0)
322
+ for block in self.ACMDMTransformer:
323
+ block.set_caching(True)
324
+
325
+ output = []
326
+ attention_mask = []
327
+ for i in range(b):
328
+ a_mask = self.attention_mask.clone()
329
+ a_mask[:, :, :, m_lens[i] * self.patches_per_frame:] = -torch.inf
330
+ attention_mask.append(a_mask)
331
+ attention_mask = torch.cat(attention_mask)
332
+ if not cond_scale == 1.0:
333
+ attention_mask = torch.cat([attention_mask, attention_mask], dim=0)
334
+ for step in range(self.cluster):
335
+ clean_x = output[-1] if len(output) > 0 else None
336
+ cache_flag = step > 0
337
+ noise = torch.randn(b, self.cluster_patches * self.patches_per_frame,
338
+ self.input_dim * self.patch_size[0] * self.patch_size[1]).to(device)
339
+ if not cond_scale == 1.0:
340
+ noise = torch.cat([noise, noise], dim=0)
341
+ if clean_x is not None:
342
+ clean_x = torch.cat([clean_x, clean_x], dim=0)
343
+ # cfg scale
344
+ # cond_scale2 = (cond_scale - 1) * (step+1) / (m_lens//self.cluster_patches + 1) + 1
345
+ model_kwargs = dict(conds=cond_vector, context=clean_x, block_id=step, cache=cache_flag,
346
+ attention_mask=attention_mask, cfg=cond_scale)
347
+ sample_fn = self.forward_with_CFG
348
+
349
+ if self.diff_model == "Flow":
350
+ model_fn = self.gen_diffusion.sample_ode() # default to ode sampling
351
+ sampled_token_latent = model_fn(noise, sample_fn, **model_kwargs)[-1]
352
+ else:
353
+ sampled_token_latent = self.gen_diffusion.p_sample_loop(
354
+ sample_fn, noise.shape, noise, clip_denoised=False, model_kwargs=model_kwargs,
355
+ progress=False,
356
+ temperature=temperature
357
+ )
358
+ if not cond_scale == 1:
359
+ sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0)
360
+ output.append(sampled_token_latent.detach().clone())
361
+
362
+ latents = torch.cat(output, dim=1)
363
+ latents = self.unpatchify(latents[:, :l * self.patches_per_frame, :])
364
+ latents = torch.where(padding_mask.unsqueeze(1).unsqueeze(-1), torch.zeros_like(latents), latents)
365
+ for block in self.ACMDMTransformer:
366
+ block.set_caching(False)
367
+ return latents
368
+
369
+ #################################################################################
370
+ # ACMDM Zoos #
371
+ #################################################################################
372
+ def acmdm_noisyprefixar_flow_s_ps22(**kwargs):
373
+ layer = 8
374
+ return ACMDM(latent_dim=layer*64, ff_size=layer*64*4, num_layers=layer, num_heads=layer, dropout=0, clip_dim=512,
375
+ diff_model="Flow", cond_drop_prob=0.1, max_length=50,
376
+ patch_size=(1, 22), stride_size=(1, 22), **kwargs)
377
+ ACMDM_models = {
378
+ 'ACMDM-NoisyPrefixAR-Flow-S-PatchSize22': acmdm_noisyprefixar_flow_s_ps22,
379
+ }
380
+
381
+ #################################################################################
382
+ # Inner Architectures #
383
+ #################################################################################
384
+ def modulate(x, shift, scale):
385
+ return x * (1 + scale) + shift
386
+
387
+
388
+ class ACMDMAttention(Attention):
389
+ def __init__(
390
+ self,
391
+ dim,
392
+ num_heads=8,
393
+ qkv_bias=True,
394
+ rope=None,
395
+ qk_norm=True,
396
+ **block_kwargs,
397
+ ):
398
+ super().__init__(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm, **block_kwargs)
399
+ self.caching, self.cached_k, self.cached_v = False, None, None
400
+ self.rope = rope
401
+
402
+ def set_caching(self, flag):
403
+ self.caching, self.cached_k, self.cached_v = flag, None, None
404
+
405
+ def forward(self, x, position_ids=None, attention_mask=None, block_size=None, cache=False):
406
+ B, N, C = x.shape
407
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
408
+ q, k, v = qkv.unbind(0)
409
+ q, k = self.q_norm(q), self.k_norm(k)
410
+
411
+ if self.rope is not None:
412
+ q, k = self.rope(q, k, position_ids)
413
+
414
+ if self.caching:
415
+ if cache:
416
+ if self.cached_k is None:
417
+ self.cached_k = k[:, :, :block_size, :]
418
+ self.cached_v = v[:, :, :block_size, :]
419
+ self.cached_x = x
420
+ else:
421
+ self.cached_k = torch.cat((self.cached_k, k[:, :, :block_size, :]), dim=2)
422
+ self.cached_v = torch.cat((self.cached_v, v[:, :, :block_size, :]), dim=2)
423
+
424
+ if self.cached_k is not None:
425
+ k = torch.cat((self.cached_k, k[:, :, -block_size:, :]), dim=2)
426
+ v = torch.cat((self.cached_v, v[:, :, -block_size:, :]), dim=2)
427
+
428
+ x = torch.nn.functional.scaled_dot_product_attention(
429
+ q, k, v,
430
+ attn_mask=attention_mask,
431
+ dropout_p=self.attn_drop.p
432
+ )
433
+ x = x.transpose(1, 2).reshape(B, N, C)
434
+ x = self.proj(x)
435
+ x = self.proj_drop(x)
436
+ return x
437
+
438
+
439
+ class SwiGLUFFN(nn.Module):
440
+ def __init__(
441
+ self,
442
+ in_features: int,
443
+ hidden_features,
444
+ bias: bool = True,
445
+ ) -> None:
446
+ super().__init__()
447
+ out_features = in_features
448
+ hidden_features = hidden_features
449
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
450
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
451
+
452
+ def forward(self, x):
453
+ x12 = self.w12(x)
454
+ x1, x2 = x12.chunk(2, dim=-1)
455
+ hidden = F.silu(x1) * x2
456
+ return self.w3(hidden)
457
+
458
+
459
+ class ACMDMTransBlock(nn.Module):
460
+ def __init__(self, hidden_size, num_heads, mlp_size=1024, rope=None, qk_norm=True):
461
+ super().__init__()
462
+ self.norm1 = LlamaRMSNorm(hidden_size, eps=1e-6)
463
+ self.attn = ACMDMAttention(hidden_size, num_heads=num_heads, qkv_bias=True, norm_layer=LlamaRMSNorm,
464
+ qk_norm=qk_norm, rope=rope)
465
+ self.norm2 = LlamaRMSNorm(hidden_size, eps=1e-6)
466
+ self.mlp = SwiGLUFFN(hidden_size, int(2 / 3 * mlp_size))
467
+ self.adaLN_modulation = nn.Sequential(
468
+ nn.SiLU(),
469
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
470
+ )
471
+
472
+ def set_caching(self, flag):
473
+ self.attn.set_caching(flag)
474
+
475
+ def forward(self, x, c, attention_mask=None, position_ids=None, block_size=None, cache=False):
476
+ dtype = x.dtype
477
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
478
+ norm_x1 = self.norm1(x.to(torch.float32)).to(dtype)
479
+ attn_input_x = modulate(norm_x1, shift_msa, scale_msa)
480
+ attn_output_x = self.attn(attn_input_x, attention_mask=attention_mask, position_ids=position_ids, block_size=block_size, cache=cache)
481
+ x = x + gate_msa * attn_output_x
482
+
483
+ norm_x2 = self.norm2(x.to(torch.float32)).to(dtype)
484
+ gate_input_x = modulate(norm_x2, shift_mlp, scale_mlp)
485
+ gate_output_x = self.mlp(gate_input_x)
486
+ x = x + gate_mlp * gate_output_x
487
+ return x
488
+
489
+
490
+ class FinalLayer(nn.Module):
491
+ def __init__(self, hidden_size, output_size):
492
+ super().__init__()
493
+ self.norm_final = LlamaRMSNorm(hidden_size, eps=1e-6)
494
+ self.linear = nn.Linear(hidden_size, output_size, bias=True)
495
+ self.adaLN_modulation = nn.Sequential(
496
+ nn.SiLU(),
497
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True)
498
+ )
499
+
500
+ def forward(self, x, c):
501
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
502
+ norm_x = self.norm_final(x.to(torch.float32)).to(x.dtype)
503
+ x = modulate(norm_x, shift, scale)
504
+ x = self.linear(x)
505
+ return x
506
+
507
+
508
+ class TimestepEmbedder(nn.Module):
509
+ def __init__(self, hidden_size, frequency_embedding_size=256):
510
+ super().__init__()
511
+ self.mlp = nn.Sequential(
512
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
513
+ nn.SiLU(),
514
+ nn.Linear(hidden_size, hidden_size, bias=True),
515
+ )
516
+ self.frequency_embedding_size = frequency_embedding_size
517
+
518
+ @staticmethod
519
+ def timestep_embedding(t, dim, max_period=10000, dtype=torch.float32):
520
+ """
521
+ Create sinusoidal timestep embeddings.
522
+ :param t: a 1-D Tensor of N indices, one per batch element.
523
+ These may be fractional.
524
+ :param dim: the dimension of the output.
525
+ :param max_period: controls the minimum frequency of the embeddings.
526
+ :return: an (N, D) Tensor of positional embeddings.
527
+ """
528
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
529
+ half = dim // 2
530
+ freqs = torch.exp(
531
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=dtype) / half
532
+ ).to(device=t.device, dtype=dtype)
533
+ args = t[:, None] * freqs[None]
534
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
535
+ if dim % 2:
536
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
537
+ return embedding
538
+
539
+ def forward(self, t, dtype=torch.bfloat16):
540
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size, dtype=dtype)
541
+ t_emb = self.mlp(t_freq)
542
+ return t_emb
543
+
544
+
545
+ class LlamaRMSNorm(nn.Module):
546
+ def __init__(self, hidden_size, eps=1e-6):
547
+ super().__init__()
548
+ self.weight = nn.Parameter(torch.ones(hidden_size))
549
+ self.variance_epsilon = eps
550
+
551
+ def forward(self, hidden_states):
552
+ input_dtype = hidden_states.dtype
553
+ hidden_states = hidden_states.to(torch.float32)
554
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
555
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
556
+ return (self.weight * hidden_states).to(input_dtype)
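# Illustrative sketch (editor's addition, not part of the committed file): the adaLN-Zero
# pattern used by ACMDMTransBlock/FinalLayer above. The hidden size and tensor shapes here
# are arbitrary assumptions for the example; zero-initialising the modulation head makes
# each block start out as an identity mapping.
import torch
import torch.nn as nn

def modulate(x, shift, scale):
    return x * (1 + scale) + shift

hidden = 256
adaLN = nn.Sequential(nn.SiLU(), nn.Linear(hidden, 6 * hidden))
nn.init.constant_(adaLN[-1].weight, 0)  # zero-init, matching the adaLN init in the model code
nn.init.constant_(adaLN[-1].bias, 0)

x = torch.randn(2, 10, hidden)   # (batch, tokens, hidden)
c = torch.randn(2, 1, hidden)    # conditioning = timestep embedding + text embedding
shift, scale, gate, *_ = adaLN(c).chunk(6, dim=-1)
out = x + gate * modulate(x, shift, scale)
assert torch.allclose(out, x)    # with zero-init, the block is an identity at step 0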
models/ACMDM_Prefix_AR.py ADDED
@@ -0,0 +1,434 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import clip
5
+ import math
6
+ from functools import partial
7
+ from timm.models.vision_transformer import Attention
8
+ from models.ROPE import RopeND
9
+ from utils.eval_utils import eval_decorator
10
+ from utils.train_utils import lengths_to_mask
11
+ from diffusions.diffusion import create_diffusion
12
+ from diffusions.transport import create_transport, Sampler
13
+
14
+ #################################################################################
15
+ # ACMDM #
16
+ #################################################################################
17
+ class ACMDM(nn.Module):
18
+ def __init__(self, input_dim, cond_mode, latent_dim=256, ff_size=1024, num_layers=8,
19
+ num_heads=4, dropout=0, clip_dim=512,
20
+ diff_model='Flow', cond_drop_prob=0.1, max_length=49,
21
+ patch_size=(1, 22), stride_size=(1, 22), num_joint=22,
22
+ clip_version='ViT-B/32', **kargs):
23
+ super(ACMDM, self).__init__()
24
+
25
+ self.input_dim = input_dim
26
+ self.latent_dim = latent_dim
27
+ self.clip_dim = clip_dim
28
+ self.dropout = dropout
29
+
30
+ self.cond_mode = cond_mode
31
+ self.cond_drop_prob = cond_drop_prob
32
+
33
+ if self.cond_mode == 'action':
34
+ assert 'num_actions' in kargs
35
+ self.num_actions = kargs.get('num_actions', 1)
36
+ self.encode_action = partial(F.one_hot, num_classes=self.num_actions)
37
+ # --------------------------------------------------------------------------
38
+ # Diffusion
39
+ self.diff_model = diff_model
40
+ if self.diff_model == 'Flow':
41
+ self.train_diffusion = create_transport() # default to linear, velocity prediction
42
+ self.gen_diffusion = Sampler(self.train_diffusion)
43
+ else:
44
+ self.train_diffusion = create_diffusion(timestep_respacing="", noise_schedule="linear")
45
+ self.gen_diffusion = create_diffusion(timestep_respacing="", noise_schedule="linear")
46
+ # --------------------------------------------------------------------------
47
+ # ACMDM
48
+ print('Loading ACMDM...')
49
+ self.t_embedder = TimestepEmbedder(self.latent_dim)
50
+ self.patch_size = patch_size
51
+ self.stride_size = stride_size
52
+ self.patches_per_frame = (num_joint - patch_size[1]) // stride_size[1] + 1
53
+
54
+ # Patchification
55
+ self.x_embedder = nn.Conv2d(self.input_dim, self.latent_dim, kernel_size=self.patch_size, stride=self.stride_size, bias=True)
56
+
57
+ # Positional Encoding
58
+ max_length = max_length * self.patches_per_frame
59
+ self.max_lens = [max_length]
60
+ self.rope = RopeND(nd=1, nd_split=[1], max_lens=self.max_lens)
61
+ self.position_ids_precompute = torch.arange(max_length).unsqueeze(0)
62
+
63
+ self.ACMDMTransformer = nn.ModuleList([
64
+ ACMDMTransBlock(self.latent_dim, num_heads, mlp_size=ff_size, rope=self.rope, qk_norm=True) for _ in range(num_layers)
65
+ ])
66
+
67
+ if self.cond_mode == 'text':
68
+ self.y_embedder = nn.Linear(self.clip_dim, self.latent_dim)
69
+ elif self.cond_mode == 'action':
70
+ self.y_embedder = nn.Linear(self.num_actions, self.latent_dim)
71
+ elif self.cond_mode == 'uncond':
72
+ self.y_embedder = nn.Identity()
73
+ else:
74
+ raise KeyError("Unsupported condition mode!!!")
75
+
76
+ self.final_layer = FinalLayer(self.latent_dim, self.input_dim, patch_size=patch_size, stride_size=stride_size, patches=self.patches_per_frame)
77
+
78
+ self.initialize_weights()
79
+
80
+ if self.cond_mode == 'text':
81
+ print('Loading CLIP...')
82
+ self.clip_version = clip_version
83
+ self.clip_model = self.load_and_freeze_clip(clip_version)
84
+
85
+ def initialize_weights(self):
86
+ # Initialize transformer layers:
87
+ def _basic_init(module):
88
+ if isinstance(module, nn.Linear):
89
+ torch.nn.init.xavier_uniform_(module.weight)
90
+ if module.bias is not None:
91
+ nn.init.constant_(module.bias, 0)
92
+
93
+ self.apply(_basic_init)
94
+
95
+ # Initialize timestep embedding MLP:
96
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
97
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
98
+
99
+ # Zero-out adaLN modulation layers in ACMDM blocks:
100
+ for block in self.ACMDMTransformer:
101
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
102
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
103
+
104
+ # Zero-out output layers:
105
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
106
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
107
+ nn.init.constant_(self.final_layer.linear.weight, 0)
108
+ nn.init.constant_(self.final_layer.linear.bias, 0)
109
+
110
+ def load_and_freeze_clip(self, clip_version):
111
+ clip_model, clip_preprocess = clip.load(clip_version, device='cpu', jit=False)
112
+ assert torch.cuda.is_available()
113
+ clip.model.convert_weights(clip_model)
114
+
115
+ clip_model.eval()
116
+ for p in clip_model.parameters():
117
+ p.requires_grad = False
118
+ return clip_model
119
+
120
+ def encode_text(self, raw_text):
121
+ device = next(self.parameters()).device
122
+ text = clip.tokenize(raw_text, truncate=True).to(device)
123
+ feat_clip_text = self.clip_model.encode_text(text).float()
124
+ return feat_clip_text
125
+
126
+ def mask_cond(self, cond, force_mask=False):
127
+ bs, d = cond.shape
128
+ if force_mask:
129
+ return torch.zeros_like(cond)
130
+ elif self.training and self.cond_drop_prob > 0.:
131
+ mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_drop_prob).view(bs, 1)
132
+ return cond * (1. - mask)
133
+ else:
134
+ return cond
135
+
136
+ def forward(self, x, t, conds, attention_mask, context, force_mask=False):
137
+ t = self.t_embedder(t, dtype=x.dtype)
138
+ conds = self.mask_cond(conds, force_mask=force_mask)
139
+ x = torch.cat([context, x], dim=2)
140
+ x = self.x_embedder(x)
141
+ x = x.flatten(2).transpose(1, 2)
142
+ conds = self.y_embedder(conds)
143
+ y = t.unsqueeze(1) + conds.unsqueeze(1)
144
+ position_ids = self.position_ids_precompute[:, :x.shape[1]]
145
+ for block in self.ACMDMTransformer:
146
+ x = block(x, y, attention_mask, position_ids=position_ids)
147
+ x = self.final_layer(x, y)[:, :, 5:, :]
148
+ return x
149
+
150
+ def forward_with_CFG(self, x, t, conds, attention_mask, context, cfg=1.0):
151
+ if not cfg == 1.0:
152
+ half = x[: len(x) // 2]
153
+ x = torch.cat([half, half], dim=0)
154
+ context = torch.cat([context, context], dim=0)
155
+ x = self.forward(x, t, conds, attention_mask, context)
156
+ if not cfg == 1.0:
157
+ cond_eps, uncond_eps = torch.split(x, len(x) // 2, dim=0)
158
+ half_eps = uncond_eps + cfg * (cond_eps - uncond_eps)
159
+ x = torch.cat([half_eps, half_eps], dim=0)
160
+ return x
161
+
162
+ def forward_loss(self, latents, y, m_lens):
163
+ latents = latents.permute(0, 2, 3, 1)
164
+ b, l, j, d = latents.shape
165
+ device = latents.device
166
+
167
+ non_pad_mask = lengths_to_mask(m_lens, l)
168
+ latents = torch.where(non_pad_mask.unsqueeze(-1).unsqueeze(-1), latents, torch.zeros_like(latents))
169
+
170
+ # prefix/prediction split ("prefix 20, predict 40" in motion frames): the first 5 latent frames are the context, the rest is the prediction target
171
+ target = latents.clone().permute(0, 3, 1, 2).detach()[:, :, 5:, :]
172
+ context = latents.clone().permute(0, 3, 1, 2).detach()[:, :, :5, :]
173
+
174
+ force_mask = False
175
+ if self.cond_mode == 'text':
176
+ with torch.no_grad():
177
+ cond_vector = self.encode_text(y)
178
+ elif self.cond_mode == 'action':
179
+ cond_vector = self.encode_action(y).to(device).float()
180
+ elif self.cond_mode == 'uncond':
181
+ cond_vector = torch.zeros(b, self.latent_dim).float().to(device)
182
+ force_mask = True
183
+ else:
184
+ raise NotImplementedError("Unsupported condition mode!!!")
185
+
186
+ attention_mask = non_pad_mask.unsqueeze(-1).repeat(1, 1, self.patches_per_frame).flatten(1).unsqueeze(
187
+ 1).unsqueeze(1)
188
+
189
+ model_kwargs = dict(conds=cond_vector, force_mask=force_mask, attention_mask=attention_mask, context=context)
190
+ if self.diff_model == "Flow":
191
+ loss_dict = self.train_diffusion.training_losses(self.forward, target, model_kwargs)
192
+ else:
193
+ t = torch.randint(0, self.train_diffusion.num_timesteps, (target.shape[0],), device=target.device)
194
+ loss_dict = self.train_diffusion.training_losses(self.forward, target, t, model_kwargs)
195
+ loss = loss_dict["loss"]
196
+ non_pad_mask = non_pad_mask[:, 5:]
197
+ loss = (loss * non_pad_mask).sum() / non_pad_mask.sum()
198
+
199
+ return loss
200
+
201
+ @torch.no_grad()
202
+ @eval_decorator
203
+ def generate(self,
204
+ conds,
205
+ m_lens,
206
+ cond_scale: float,
207
+ context,
208
+ temperature=1,
209
+ j=22,
210
+ ):
211
+ device = next(self.parameters()).device
212
+ l = max(m_lens)
213
+ b = len(m_lens)
214
+
215
+ if self.cond_mode == 'text':
216
+ with torch.no_grad():
217
+ cond_vector = self.encode_text(conds)
218
+ elif self.cond_mode == 'action':
219
+ cond_vector = self.encode_action(conds).to(device)
220
+ elif self.cond_mode == 'uncond':
221
+ cond_vector = torch.zeros(b, self.latent_dim).float().to(device)
222
+ else:
223
+ raise NotImplementedError("Unsupported condition mode!!!")
224
+
225
+ padding_mask = ~lengths_to_mask(m_lens, l)
226
+ if not cond_scale == 1.0:
227
+ cond_vector = torch.cat([cond_vector, torch.zeros_like(cond_vector)], dim=0)
228
+
229
+ # a fairly naive way to write the PrefixAR inference loop; to be improved
230
+ iter = [(0,15),(10,25),(20, 35), (30, 45), (40, l.item())]
231
+ out = [context.clone().detach()]
232
+ for i in range(len(iter)):
233
+ noise = torch.randn(b, self.input_dim, iter[i][1]-iter[i][0]-5, j).to(device)
234
+ if not cond_scale == 1.0:
235
+ noise = torch.cat([noise, noise], dim=0)
236
+
237
+ attention_mask = ((~padding_mask)[:, iter[i][0]:iter[i][1]]).unsqueeze(-1).repeat(1,1,self.patches_per_frame).flatten(1).unsqueeze(1).unsqueeze(1)
238
+ model_kwargs = dict(conds=cond_vector, attention_mask=attention_mask, context=context, cfg=cond_scale)
239
+ sample_fn = self.forward_with_CFG
240
+
241
+ if not cond_scale == 1:
242
+ model_kwargs["attention_mask"] = attention_mask.repeat(2, 1, 1, 1)
243
+
244
+ if self.diff_model == "Flow":
245
+ model_fn = self.gen_diffusion.sample_ode(sampling_method="euler")  # default to ODE sampling; use Euler to prevent underflow, since the current window can contain padding
246
+ sampled_token_latent = model_fn(noise, sample_fn, **model_kwargs)[-1]
247
+ else:
248
+ sampled_token_latent = self.gen_diffusion.p_sample_loop(
249
+ sample_fn, noise.shape, noise, clip_denoised=False, model_kwargs=model_kwargs,
250
+ progress=False,
251
+ temperature=temperature
252
+ )
253
+ if not cond_scale == 1:
254
+ sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0)
255
+ out.append(sampled_token_latent.clone().detach())
256
+ context = sampled_token_latent[:, :, 5:, :].clone().detach()
257
+ sampled_token_latent = torch.cat(out, dim=2).permute(0,2,3,1)
258
+
259
+ latents = torch.where(padding_mask.unsqueeze(-1).unsqueeze(-1), torch.zeros_like(sampled_token_latent), sampled_token_latent)
260
+ return latents.permute(0,3,1,2)
261
+
262
+ #################################################################################
263
+ # ACMDM Zoos #
264
+ #################################################################################
265
+ def acmdm_prefixar_flow_s_ps22(**kwargs):
266
+ layer = 8
267
+ return ACMDM(latent_dim=layer*64, ff_size=layer*64*4, num_layers=layer, num_heads=layer, dropout=0, clip_dim=512,
268
+ diff_model="Flow", cond_drop_prob=0.1, max_length=15,
269
+ patch_size=(1, 22), stride_size=(1, 22), **kwargs)
270
+ ACMDM_models = {
271
+ 'ACMDM-PrefixAR-Flow-S-PatchSize22': acmdm_prefixar_flow_s_ps22,
272
+ }
273
+
274
+ #################################################################################
275
+ # Inner Architectures #
276
+ #################################################################################
277
+ def modulate(x, shift, scale):
278
+ return x * (1 + scale) + shift
279
+
280
+
281
+ class ACMDMAttention(Attention):
282
+ def __init__(
283
+ self,
284
+ dim,
285
+ num_heads=8,
286
+ qkv_bias=True,
287
+ rope=None,
288
+ qk_norm=True,
289
+ **block_kwargs,
290
+ ):
291
+ super().__init__(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm, **block_kwargs)
292
+ self.rope = rope
293
+
294
+ def forward(self, x, position_ids=None, attention_mask=None):
295
+ B, N, C = x.shape
296
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
297
+ q, k, v = qkv.unbind(0)
298
+ q, k = self.q_norm(q), self.k_norm(k)
299
+
300
+ if self.rope is not None:
301
+ q, k = self.rope(q, k, position_ids)
302
+
303
+ x = torch.nn.functional.scaled_dot_product_attention(
304
+ q, k, v,
305
+ attn_mask=attention_mask,
306
+ dropout_p=self.attn_drop.p
307
+ )
308
+ x = x.transpose(1, 2).reshape(B, N, C)
309
+ x = self.proj(x)
310
+ x = self.proj_drop(x)
311
+ return x
312
+
313
+
314
+ class SwiGLUFFN(nn.Module):
315
+ def __init__(
316
+ self,
317
+ in_features: int,
318
+ hidden_features,
319
+ bias: bool = True,
320
+ ) -> None:
321
+ super().__init__()
322
+ out_features = in_features
323
+ hidden_features = hidden_features
324
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
325
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
326
+
327
+ def forward(self, x):
328
+ x12 = self.w12(x)
329
+ x1, x2 = x12.chunk(2, dim=-1)
330
+ hidden = F.silu(x1) * x2
331
+ return self.w3(hidden)
332
+
333
+
334
+ class ACMDMTransBlock(nn.Module):
335
+ def __init__(self, hidden_size, num_heads, mlp_size=1024, rope=None, qk_norm=True):
336
+ super().__init__()
337
+ self.norm1 = LlamaRMSNorm(hidden_size, eps=1e-6)
338
+ self.attn = ACMDMAttention(hidden_size, num_heads=num_heads, qkv_bias=True, norm_layer=LlamaRMSNorm,
339
+ qk_norm=qk_norm, rope=rope)
340
+ self.norm2 = LlamaRMSNorm(hidden_size, eps=1e-6)
341
+ self.mlp = SwiGLUFFN(hidden_size, int(2 / 3 * mlp_size))
342
+ self.adaLN_modulation = nn.Sequential(
343
+ nn.SiLU(),
344
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
345
+ )
346
+
347
+ def forward(self, x, c, attention_mask=None, position_ids=None):
348
+ dtype = x.dtype
349
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
350
+ norm_x1 = self.norm1(x.to(torch.float32)).to(dtype)
351
+ attn_input_x = modulate(norm_x1, shift_msa, scale_msa)
352
+ attn_output_x = self.attn(attn_input_x, attention_mask=attention_mask, position_ids=position_ids)
353
+ x = x + gate_msa * attn_output_x
354
+
355
+ norm_x2 = self.norm2(x.to(torch.float32)).to(dtype)
356
+ gate_input_x = modulate(norm_x2, shift_mlp, scale_mlp)
357
+ gate_output_x = self.mlp(gate_input_x)
358
+ x = x + gate_mlp * gate_output_x
359
+ return x
360
+
361
+
362
+ class FinalLayer(nn.Module):
363
+ def __init__(self, hidden_size, output_size, patch_size=(1, 22), stride_size=(1,22), patches=1):
364
+ super().__init__()
365
+ self.norm_final = LlamaRMSNorm(hidden_size, eps=1e-6)
366
+ self.patch_size = patch_size
367
+ self.stride_size = stride_size
368
+ self.patches = patches
369
+ self.linear = nn.Linear(hidden_size, output_size*patch_size[0]*patch_size[1], bias=True)
370
+ self.adaLN_modulation = nn.Sequential(
371
+ nn.SiLU(),
372
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True)
373
+ )
374
+
375
+ def forward(self, x, c):
376
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
377
+ norm_x = self.norm_final(x.to(torch.float32)).to(x.dtype)
378
+ x = modulate(norm_x, shift, scale)
379
+ x = self.linear(x)
380
+ x = x.reshape(shape=(x.shape[0], x.shape[1]//self.patches, self.patches, self.patch_size[0], self.patch_size[1], x.shape[-1] // self.patch_size[1]))
381
+ x = torch.einsum('nljpqc->nclpjq', x)
382
+ x = x.reshape(shape=(x.shape[0], x.shape[1], -1, 22))
383
+ return x
384
+
385
+
386
+ class TimestepEmbedder(nn.Module):
387
+ def __init__(self, hidden_size, frequency_embedding_size=256):
388
+ super().__init__()
389
+ self.mlp = nn.Sequential(
390
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
391
+ nn.SiLU(),
392
+ nn.Linear(hidden_size, hidden_size, bias=True),
393
+ )
394
+ self.frequency_embedding_size = frequency_embedding_size
395
+
396
+ @staticmethod
397
+ def timestep_embedding(t, dim, max_period=10000, dtype=torch.float32):
398
+ """
399
+ Create sinusoidal timestep embeddings.
400
+ :param t: a 1-D Tensor of N indices, one per batch element.
401
+ These may be fractional.
402
+ :param dim: the dimension of the output.
403
+ :param max_period: controls the minimum frequency of the embeddings.
404
+ :return: an (N, D) Tensor of positional embeddings.
405
+ """
406
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
407
+ half = dim // 2
408
+ freqs = torch.exp(
409
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=dtype) / half
410
+ ).to(device=t.device, dtype=dtype)
411
+ args = t[:, None] * freqs[None]
412
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
413
+ if dim % 2:
414
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
415
+ return embedding
416
+
417
+ def forward(self, t, dtype=torch.bfloat16):
418
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size, dtype=dtype)
419
+ t_emb = self.mlp(t_freq)
420
+ return t_emb
421
+
422
+
423
+ class LlamaRMSNorm(nn.Module):
424
+ def __init__(self, hidden_size, eps=1e-6):
425
+ super().__init__()
426
+ self.weight = nn.Parameter(torch.ones(hidden_size))
427
+ self.variance_epsilon = eps
428
+
429
+ def forward(self, hidden_states):
430
+ input_dtype = hidden_states.dtype
431
+ hidden_states = hidden_states.to(torch.float32)
432
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
433
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
434
+ return (self.weight * hidden_states).to(input_dtype)
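# Illustrative sketch (editor's addition, not part of the committed file): the sliding-window
# schedule that generate() above hard-codes as [(0,15), (10,25), (20,35), (30,45), (40,l)].
# The helper below is an assumption about the intended pattern (15-latent-frame windows that
# re-use the last 5 frames of the previous window as the prefix context), not repository code.
def prefix_ar_windows(total_len, window=15, prefix=5):
    stride = window - prefix
    windows, start = [], 0
    while start + window < total_len:
        windows.append((start, start + window))
        start += stride
    windows.append((start, total_len))
    return windows

# For a 49-latent-frame sequence this reproduces the hard-coded list:
assert prefix_ar_windows(49) == [(0, 15), (10, 25), (20, 35), (30, 45), (40, 49)]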
models/AE_2D_Causal.py ADDED
@@ -0,0 +1,245 @@
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ #################################################################################
7
+ # AE #
8
+ #################################################################################
9
+ class AE(nn.Module):
10
+ def __init__(self, input_width=3, output_emb_width=4, width=512, depth=3, ch_mult=(1,1,1)):
11
+ super().__init__()
12
+ self.output_emb_width = output_emb_width
13
+ self.encoder = Encoder(input_width, output_emb_width, width, depth, in_ch_mult=ch_mult[:-1], ch_mult=ch_mult[1:])
14
+ self.decoder = Decoder(input_width, output_emb_width, width, depth, in_ch_mult=ch_mult[::-1][1:], ch_mult=ch_mult[::-1][:-1])
15
+
16
+ def preprocess(self, x):
17
+ x = x.permute(0, 3, 1, 2).float()
18
+ return x
19
+
20
+ def encode(self, x):
21
+ x_in = self.preprocess(x)
22
+ x_encoder = self.encoder(x_in)
23
+ return x_encoder
24
+
25
+ def forward(self, x):
26
+ x_in = self.preprocess(x)
27
+ x_encoder = self.encoder(x_in)
28
+ x_out = self.decoder(x_encoder)
29
+ return x_out
30
+
31
+ def decode(self, x):
32
+ x_out = self.decoder(x)
33
+ return x_out
34
+
35
+ #################################################################################
36
+ # VAE #
37
+ #################################################################################
38
+ class VAE(nn.Module):
39
+ def __init__(self, input_width=3, output_emb_width=4, width=512, depth=3, ch_mult=(1,1,1)):
40
+ super().__init__()
41
+ self.output_emb_width = output_emb_width
42
+ self.encoder = Encoder(input_width, output_emb_width*2, width, depth, in_ch_mult=ch_mult[:-1], ch_mult=ch_mult[1:])
43
+ self.decoder = Decoder(input_width, output_emb_width, width, depth, in_ch_mult=ch_mult[::-1][1:], ch_mult=ch_mult[::-1][:-1])
44
+
45
+ def preprocess(self, x):
46
+ x = x.permute(0, 3, 1, 2).float()
47
+ return x
48
+
49
+ def encode(self, x):
50
+ x_in = self.preprocess(x)
51
+ x_encoder = self.encoder(x_in)
52
+ x_encoder = DiagonalGaussianDistribution(x_encoder)
53
+ x_encoder = x_encoder.sample()
54
+ return x_encoder
55
+
56
+ def forward(self, x, need_loss=False):
57
+ x_in = self.preprocess(x)
58
+ x_encoder = self.encoder(x_in)
59
+ x_encoder = DiagonalGaussianDistribution(x_encoder)
60
+ kl_loss = x_encoder.kl()
61
+ x_encoder = x_encoder.sample()
62
+ x_out = self.decoder(x_encoder)
63
+ if need_loss:
64
+ # sigma-VAE style reconstruction loss (per-sample optimal sigma) for better quality
65
+ log_sigma = ((x - x_out) ** 2).mean([1,2,3], keepdim=True).sqrt().log()
66
+ log_sigma = -6 + F.softplus(log_sigma - (-6))
67
+ rec = 0.5 * torch.pow((x - x_out) / log_sigma.exp(), 2) + log_sigma
68
+ rec = rec.sum(dim=(1,2,3))
69
+ loss = {
70
+ "rec": rec.mean(),
71
+ "kl": kl_loss.mean()}
72
+ return x_out, loss
73
+ else:
74
+ return x_out
75
+
76
+ def decode(self, x):
77
+ x_out = self.decoder(x)
78
+ return x_out
79
+
80
+ #################################################################################
81
+ # AE Zoos #
82
+ #################################################################################
83
+ def ae(**kwargs):
84
+ return AE(output_emb_width=4, width=512, depth=3, ch_mult=(1,1,1), **kwargs)
85
+ def vae(**kwargs):
86
+ return VAE(output_emb_width=4, width=512, depth=3, ch_mult=(1,1,1), **kwargs)
87
+ AE_models = {
88
+ 'AE_Model': ae, 'VAE_Model': vae
89
+ }
90
+ #################################################################################
91
+ # Inner Architectures #
92
+ #################################################################################
93
+ class Encoder(nn.Module):
94
+ def __init__(self, input_emb_width=3, output_emb_width=4, width=512, depth=3, in_ch_mult=(1,1), ch_mult=(1,1)):
95
+ super().__init__()
96
+ self.model = nn.ModuleList()
97
+ self.conv_in = nn.Conv2d(input_emb_width, width, (3, 1), (1, 1), (0, 0))
98
+
99
+ block_in = width * in_ch_mult[0]
100
+ for i in range(len(in_ch_mult)):
101
+ block_in = width * in_ch_mult[i]
102
+ block_out = width * ch_mult[i]
103
+ self.model.append(CausalPad2d((0, 0, 2, 0)))
104
+ self.model.append(nn.Conv2d(width, width, (4, 1), (2, 1), (0, 0)))
105
+ for j in range(depth):
106
+ self.model.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dil=2-j))
107
+ block_in = block_out
108
+
109
+ self.conv_out = torch.nn.Conv2d(block_in, output_emb_width, (3, 1), (1, 1), (0, 0))
110
+ def forward(self, x):
111
+ x = F.pad(x, (0, 0, 2, 0))
112
+ x = self.conv_in(x)
113
+ for layer in self.model:
114
+ x = layer(x)
115
+ x = F.pad(x, (0, 0, 2, 0))
116
+ x = self.conv_out(x)
117
+ return x
118
+
119
+
120
+ class Decoder(nn.Module):
121
+ def __init__(self, input_emb_width=3, output_emb_width=4, width=512, depth=3, in_ch_mult=(1,1), ch_mult=(1,1)):
122
+ super().__init__()
123
+ self.model = nn.ModuleList()
124
+ block_in = width * ch_mult[0]
125
+ self.conv_in = nn.Conv2d(output_emb_width, block_in, (3,1), (1,1), (0,0))
126
+
127
+ for i in range(len(in_ch_mult)):
128
+ block_in = width * ch_mult[i]
129
+ block_out = width * in_ch_mult[i]
130
+ for j in range(depth):
131
+ self.model.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dil=2-j))
132
+ block_in = block_out
133
+ self.model.append(Upsample(block_in))
134
+
135
+ self.conv_out1 = torch.nn.Conv2d(block_in, block_in, (3, 1), (1,1), (0,0))
136
+ self.conv_out2 = torch.nn.Conv2d(block_in, input_emb_width, (3, 1), (1, 1), (0, 0))
137
+
138
+ def forward(self, x):
139
+ x = F.pad(x, (0, 0, 2, 0))
140
+ x = self.conv_in(x)
141
+ for layer in self.model:
142
+ x = layer(x)
143
+ x = F.pad(x, (0, 0, 2, 0))
144
+ x = self.conv_out1(x)
145
+ x = x * torch.sigmoid(x)
146
+ x = F.pad(x, (0, 0, 2, 0))
147
+ x = self.conv_out2(x)
148
+ return x.permute(0,2,3,1)
149
+
150
+
151
+ class Upsample(nn.Module):
152
+ def __init__(self, in_channels):
153
+ super().__init__()
154
+ self.conv = torch.nn.Conv2d(in_channels, in_channels,(3, 1), (1, 1), (0, 0))
155
+
156
+ def forward(self, x):
157
+ x = torch.nn.functional.interpolate(x, scale_factor=(2.0, 1.0), mode="nearest")
158
+ x = F.pad(x, (0, 0, 2, 0))
159
+ x = self.conv(x)
160
+ return x
161
+
162
+
163
+ class ResnetBlock(nn.Module):
164
+ def __init__(self, *, in_channels, out_channels=None, dil=0, conv_shortcut=False, dropout=0.2):
165
+ super().__init__()
166
+ self.in_channels = in_channels
167
+ out_channels = in_channels if out_channels is None else out_channels
168
+ self.out_channels = out_channels
169
+ self.use_conv_shortcut = conv_shortcut
170
+ self.padd = CausalPad2d((0, 0, 2*(3 ** dil), 0))
171
+
172
+ self.conv1 = torch.nn.Conv2d(in_channels,
173
+ out_channels,
174
+ kernel_size=(3, 1),
175
+ stride=(1, 1),
176
+ padding=(0, 0),
177
+ dilation=(3 ** dil, 1),
178
+ )
179
+ self.dropout = torch.nn.Dropout(dropout)
180
+ self.conv2 = torch.nn.Conv2d(out_channels,
181
+ out_channels,
182
+ kernel_size=(1, 1),
183
+ stride=(1, 1),
184
+ padding=(0, 0),
185
+ )
186
+
187
+ def forward(self, x):
188
+ h = x
189
+ h = h*torch.sigmoid(h)
190
+ h = self.padd(h)
191
+ h = self.conv1(h)
192
+
193
+ h = h*torch.sigmoid(h)
194
+ h = self.conv2(h)
195
+ h = self.dropout(h)
196
+ return x+h
197
+
198
+
199
+ class DiagonalGaussianDistribution(object):
200
+ def __init__(self, parameters, deterministic=False):
201
+ self.parameters = parameters
202
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
203
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
204
+ self.deterministic = deterministic
205
+ self.std = torch.exp(0.5 * self.logvar)
206
+ self.var = torch.exp(self.logvar)
207
+ if self.deterministic:
208
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
209
+
210
+ def sample(self):
211
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
212
+ return x
213
+
214
+ def kl(self, other=None):
215
+ if self.deterministic:
216
+ return torch.Tensor([0.])
217
+ else:
218
+ if other is None:
219
+ return 0.5 * torch.sum(torch.pow(self.mean, 2)
220
+ + self.var - 1.0 - self.logvar,
221
+ dim=[1, 2, 3])
222
+ else:
223
+ return 0.5 * torch.sum(
224
+ torch.pow(self.mean - other.mean, 2) / other.var
225
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
226
+ dim=[1, 2, 3])
227
+
228
+ def nll(self, sample, dims=[1,2,3]):
229
+ if self.deterministic:
230
+ return torch.Tensor([0.])
231
+ logtwopi = np.log(2.0 * np.pi)
232
+ return 0.5 * torch.sum(
233
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
234
+ dim=dims)
235
+
236
+ def mode(self):
237
+ return self.mean
238
+
239
+
240
+ class CausalPad2d(nn.Module):
241
+ def __init__(self, pad):
242
+ super().__init__()
243
+ self.pad = pad
244
+ def forward(self, x):
245
+ return F.pad(x, self.pad)
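# Illustrative sketch (editor's addition, not part of the committed file): why the causal
# pad (0, 0, 2, 0) followed by a (4, 1)-kernel, (2, 1)-stride conv halves the temporal axis
# while only looking at past frames. The concrete shapes are assumptions for the example.
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 512, 64, 22)      # (batch, channels, frames, joints)
x = F.pad(x, (0, 0, 2, 0))           # pad 2 frames at the start of the time axis only
down = nn.Conv2d(512, 512, kernel_size=(4, 1), stride=(2, 1))
print(down(x).shape)                 # torch.Size([1, 512, 32, 22]): 2x temporal downsampling per stage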
models/AE_2D_NonCausal.py ADDED
@@ -0,0 +1,228 @@
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ #################################################################################
7
+ # AE #
8
+ #################################################################################
9
+ class AE(nn.Module):
10
+ def __init__(self, input_width=3, output_emb_width=4, width=512, depth=3, ch_mult=(1,1,1)):
11
+ super().__init__()
12
+ self.output_emb_width = output_emb_width
13
+ self.encoder = Encoder(input_width, output_emb_width, width, depth, in_ch_mult=ch_mult[:-1], ch_mult=ch_mult[1:])
14
+ self.decoder = Decoder(input_width, output_emb_width, width, depth, in_ch_mult=ch_mult[::-1][1:], ch_mult=ch_mult[::-1][:-1])
15
+
16
+ def preprocess(self, x):
17
+ x = x.permute(0, 3, 1, 2).float()
18
+ return x
19
+
20
+ def encode(self, x):
21
+ x_in = self.preprocess(x)
22
+ x_encoder = self.encoder(x_in)
23
+ return x_encoder
24
+
25
+ def forward(self, x):
26
+ x_in = self.preprocess(x)
27
+ x_encoder = self.encoder(x_in)
28
+ x_out = self.decoder(x_encoder)
29
+ return x_out
30
+
31
+ def decode(self, x):
32
+ x_out = self.decoder(x)
33
+ return x_out
34
+
35
+ #################################################################################
36
+ # VAE #
37
+ #################################################################################
38
+ class VAE(nn.Module):
39
+ def __init__(self, input_width=3, output_emb_width=4, width=512, depth=3, ch_mult=(1,1,1)):
40
+ super().__init__()
41
+ self.output_emb_width = output_emb_width
42
+ self.encoder = Encoder(input_width, output_emb_width*2, width, depth, in_ch_mult=ch_mult[:-1], ch_mult=ch_mult[1:])
43
+ self.decoder = Decoder(input_width, output_emb_width, width, depth, in_ch_mult=ch_mult[::-1][1:], ch_mult=ch_mult[::-1][:-1])
44
+
45
+ def preprocess(self, x):
46
+ x = x.permute(0, 3, 1, 2).float()
47
+ return x
48
+
49
+ def encode(self, x):
50
+ x_in = self.preprocess(x)
51
+ x_encoder = self.encoder(x_in)
52
+ x_encoder = DiagonalGaussianDistribution(x_encoder)
53
+ x_encoder = x_encoder.sample()
54
+ return x_encoder
55
+
56
+ def forward(self, x, need_loss=False):
57
+ x_in = self.preprocess(x)
58
+ x_encoder = self.encoder(x_in)
59
+ x_encoder = DiagonalGaussianDistribution(x_encoder)
60
+ kl_loss = x_encoder.kl()
61
+ x_encoder = x_encoder.sample()
62
+ x_out = self.decoder(x_encoder)
63
+ if need_loss:
64
+ # sigma vae for better quality
65
+ log_sigma = ((x - x_out) ** 2).mean([1,2,3], keepdim=True).sqrt().log()
66
+ log_sigma = -6 + F.softplus(log_sigma - (-6))
67
+ rec = 0.5 * torch.pow((x - x_out) / log_sigma.exp(), 2) + log_sigma
68
+ rec = rec.sum(dim=(1,2,3))
69
+ loss = {
70
+ "rec": rec.mean(),
71
+ "kl": kl_loss.mean()}
72
+ return x_out, loss
73
+ else:
74
+ return x_out
75
+
76
+ def decode(self, x):
77
+ x_out = self.decoder(x)
78
+ return x_out
79
+
80
+ #################################################################################
81
+ # AE Zoos #
82
+ #################################################################################
83
+ def ae(**kwargs):
84
+ return AE(output_emb_width=4, width=512, depth=3, ch_mult=(1,1,1), **kwargs)
85
+ def vae(**kwargs):
86
+ return VAE(output_emb_width=4, width=512, depth=3, ch_mult=(1,1,1), **kwargs)
87
+ AE_models = {
88
+ 'AE_Model': ae, 'VAE_Model': vae
89
+ }
90
+ #################################################################################
91
+ # Inner Architectures #
92
+ #################################################################################
93
+ class Encoder(nn.Module):
94
+ def __init__(self, input_emb_width=3, output_emb_width=4, width=512, depth=3, in_ch_mult=(1,1), ch_mult=(1,1)):
95
+ super().__init__()
96
+ self.model = nn.ModuleList()
97
+ self.conv_in = nn.Conv2d(input_emb_width, width, (3, 1), (1, 1), (1, 1))
98
+
99
+ block_in = width * in_ch_mult[0]
100
+ for i in range(len(in_ch_mult)):
101
+ block_in = width * in_ch_mult[i]
102
+ block_out = width * ch_mult[i]
103
+ self.model.append(nn.Conv2d(width, width, (4, 1), (2, 1), (1, 1)))
104
+ for j in range(depth):
105
+ self.model.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dil=2-j))
106
+ block_in = block_out
107
+
108
+ self.conv_out = torch.nn.Conv2d(block_in, output_emb_width, (3, 1), (1, 1), (1, 1))
109
+ def forward(self, x):
110
+ x = self.conv_in(x)
111
+ for layer in self.model:
112
+ x = layer(x)
113
+ x = self.conv_out(x)
114
+ return x
115
+
116
+
117
+ class Decoder(nn.Module):
118
+ def __init__(self, input_emb_width=3, output_emb_width=4, width=512, depth=3, in_ch_mult=(1,1), ch_mult=(1,1)):
119
+ super().__init__()
120
+ self.model = nn.ModuleList()
121
+ block_in = width * ch_mult[0]
122
+ self.conv_in = nn.Conv2d(output_emb_width, block_in, (3,1), (1,1), (1,1))
123
+
124
+ for i in range(len(in_ch_mult)):
125
+ block_in = width * ch_mult[i]
126
+ block_out = width * in_ch_mult[i]
127
+ for j in range(depth):
128
+ self.model.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dil=2-j))
129
+ block_in = block_out
130
+ self.model.append(Upsample(block_in))
131
+
132
+ self.conv_out1 = torch.nn.Conv2d(block_in, block_in, (3, 1), (1,1), (1,1))
133
+ self.conv_out2 = torch.nn.Conv2d(block_in, input_emb_width, (3, 1), (1, 1), (1, 1))
134
+
135
+ def forward(self, x):
136
+ x = self.conv_in(x)
137
+ for layer in self.model:
138
+ x = layer(x)
139
+ x = self.conv_out1(x)
140
+ x = x * torch.sigmoid(x)
141
+ x = self.conv_out2(x)
142
+ return x.permute(0,2,3,1)
143
+
144
+
145
+ class Upsample(nn.Module):
146
+ def __init__(self, in_channels):
147
+ super().__init__()
148
+ self.conv = torch.nn.Conv2d(in_channels, in_channels,(3, 1), (1, 1), (1, 1))
149
+
150
+ def forward(self, x):
151
+ x = torch.nn.functional.interpolate(x, scale_factor=(2.0, 1.0), mode="nearest")
152
+ x = self.conv(x)
153
+ return x
154
+
155
+
156
+ class ResnetBlock(nn.Module):
157
+ def __init__(self, *, in_channels, out_channels=None, dil=0, conv_shortcut=False, dropout=0.2):
158
+ super().__init__()
159
+ self.in_channels = in_channels
160
+ out_channels = in_channels if out_channels is None else out_channels
161
+ self.out_channels = out_channels
162
+ self.use_conv_shortcut = conv_shortcut
163
+
164
+ self.conv1 = torch.nn.Conv2d(in_channels,
165
+ out_channels,
166
+ kernel_size=(3, 1),
167
+ stride=(1, 1),
168
+ padding=(3 ** dil, 0),
169
+ dilation=(3 ** dil, 1),
170
+ )
171
+ self.dropout = torch.nn.Dropout(dropout)
172
+ self.conv2 = torch.nn.Conv2d(out_channels,
173
+ out_channels,
174
+ kernel_size=(1, 1),
175
+ stride=(1, 1),
176
+ padding=(0, 0),
177
+ )
178
+
179
+ def forward(self, x):
180
+ h = x
181
+ h = h*torch.sigmoid(h)
182
+ h = self.conv1(h)
183
+
184
+ h = h*torch.sigmoid(h)
185
+ h = self.conv2(h)
186
+ h = self.dropout(h)
187
+ return x+h
188
+
189
+
190
+ class DiagonalGaussianDistribution(object):
191
+ def __init__(self, parameters, deterministic=False):
192
+ self.parameters = parameters
193
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
194
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
195
+ self.deterministic = deterministic
196
+ self.std = torch.exp(0.5 * self.logvar)
197
+ self.var = torch.exp(self.logvar)
198
+ if self.deterministic:
199
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
200
+
201
+ def sample(self):
202
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
203
+ return x
204
+
205
+ def kl(self, other=None):
206
+ if self.deterministic:
207
+ return torch.Tensor([0.])
208
+ else:
209
+ if other is None:
210
+ return 0.5 * torch.sum(torch.pow(self.mean, 2)
211
+ + self.var - 1.0 - self.logvar,
212
+ dim=[1, 2, 3])
213
+ else:
214
+ return 0.5 * torch.sum(
215
+ torch.pow(self.mean - other.mean, 2) / other.var
216
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
217
+ dim=[1, 2, 3])
218
+
219
+ def nll(self, sample, dims=[1,2,3]):
220
+ if self.deterministic:
221
+ return torch.Tensor([0.])
222
+ logtwopi = np.log(2.0 * np.pi)
223
+ return 0.5 * torch.sum(
224
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
225
+ dim=dims)
226
+
227
+ def mode(self):
228
+ return self.mean
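# Illustrative sketch (editor's addition, not part of the committed file): the
# reparameterisation trick behind DiagonalGaussianDistribution above. The encoder emits
# twice the latent channels, split into mean and log-variance; the example shapes are assumptions.
import torch

params = torch.randn(2, 8, 16, 22)            # pretend encoder output, 2 * latent_dim = 8 channels
mean, logvar = torch.chunk(params, 2, dim=1)
std = torch.exp(0.5 * logvar)
z = mean + std * torch.randn_like(mean)       # differentiable sample: randomness enters only via eps
kl = 0.5 * torch.sum(mean.pow(2) + logvar.exp() - 1.0 - logvar, dim=[1, 2, 3])
print(z.shape, kl.shape)                      # torch.Size([2, 4, 16, 22]) torch.Size([2])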
models/AE_Mesh.py ADDED
@@ -0,0 +1,601 @@
1
+ # A modified version of "Fully Convolutional Mesh Autoencoder using Efficient Spatially Varying Kernels"
2
+ # https://arxiv.org/abs/2006.04325
3
+ # and thanks to this more modern implementation as well
4
+ # https://github.com/g-fiche/Mesh-VQ-VAE
5
+ # https://arxiv.org/abs/2312.08291
6
+ import torch
7
+ import torch.nn as nn
8
+ import numpy as np
9
+ import os
10
+
11
+ #################################################################################
12
+ # AE #
13
+ #################################################################################
14
+ class AE(nn.Module):
15
+ def __init__(self, model, bs=16, num_vertices=6890):
16
+ super().__init__()
17
+ # currently the only supported setup is SMPL-H
18
+ self.num_vertices = num_vertices
19
+ self.bs=bs
20
+ self.encoder = Encoder(model)
21
+ self.decoder = Decoder(model)
22
+
23
+ def encode(self, x):
24
+ B, L = x.shape[0], x.shape[1]
25
+ x = x.view(B * L, self.num_vertices, 3)
26
+ x_encoder = self.encoder(x)
27
+ return x_encoder
28
+
29
+ def forward(self, x):
30
+ B, L = x.shape[0], x.shape[1]
31
+ x = x.view(B * L, self.num_vertices, 3)
32
+ x_encoder = self.encoder(x)
33
+ x_out = self.decoder(x_encoder)
34
+ x_out = x_out.view(B, L, self.num_vertices, 3)
35
+ return x_out
36
+
37
+
38
+ def decode(self, x):
39
+ T = x.shape[1]
40
+ if x.shape[1] % self.bs != 0:
41
+ x = torch.cat([x, torch.zeros_like(x[:, :self.bs-x.shape[1] % self.bs])], dim=1)
42
+ outputs = []
43
+ for i in range(x.shape[0]):
44
+ outputss = []
45
+ for j in range(0, x.shape[1], self.bs):
46
+ chunk = x[i, j:j + self.bs]
47
+ out = self.decoder(chunk)
48
+ outputss.append(out)
49
+ outputs.append(torch.cat(outputss, dim=0)[:T])
50
+ x_out = torch.stack(outputs, dim=0)
51
+
52
+ return x_out
53
+
54
+ #################################################################################
55
+ # AE Zoos #
56
+ #################################################################################
57
+ def ae(**kwargs):
58
+ config_model = {"batch": 16,
59
+ "connection_folder": "body_models/ConnectionMatrices/",
60
+ "initial_connection_fn": "body_models/ConnectionMatrices/_pool0.npy",
61
+ "connection_layer_lst": ["pool0", "pool1", "pool2", "pool3", "pool4", "pool5", "pool6", "pool7_28",
62
+ "unpool7_28", "unpool6", "unpool5", "unpool4", "unpool3", "unpool2",
63
+ "unpool1", "unpool0"],
64
+ "channel_lst": [64, 64, 128, 128, 256, 256, 512, 12, 512, 256, 256, 128, 128, 64, 64, 3],
65
+ "weight_num_lst": [9, 0, 9, 0, 9, 0, 9, 0, 0, 9, 0, 9, 0, 9, 0, 9],
66
+ "residual_rate_lst": [0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0],
67
+ }
68
+ return AE(FullyConvAE(config_model, **kwargs), bs=config_model["batch"])
69
+ AE_models = {
70
+ 'AE_Model': ae
71
+ }
72
+
73
+
74
+ class Encoder(nn.Module):
75
+ def __init__(self, model):
76
+ super(Encoder, self).__init__()
77
+ self.model = model
78
+
79
+ def forward(self, x):
80
+ out = self.model.forward_till_layer_n(x, len(self.model.channel_lst) // 2)
81
+ return out
82
+
83
+
84
+ class Decoder(nn.Module):
85
+ def __init__(self, model):
86
+ super(Decoder, self).__init__()
87
+ self.model = model
88
+
89
+ def forward(self, x):
90
+ out = self.model.forward_from_layer_n(x, len(self.model.channel_lst) // 2)
91
+ return out
92
+
93
+ class FullyConvAE(nn.Module):
94
+ def __init__(
95
+ self, config_model=None, test_mode=False
96
+ ): # layer_info_lst= [(point_num, feature_dim)]
97
+ super(FullyConvAE, self).__init__()
98
+
99
+ self.test_mode = test_mode
100
+
101
+ self.channel_lst = config_model["channel_lst"]
102
+
103
+ self.residual_rate_lst = config_model["residual_rate_lst"]
104
+
105
+ self.weight_num_lst = config_model["weight_num_lst"]
106
+
107
+ self.initial_connection_fn = config_model["initial_connection_fn"]
108
+
109
+ data = np.load(self.initial_connection_fn)
110
+ neighbor_id_dist_lstlst = data[:, 1:] # point_num*(1+2*neighbor_num)
111
+ self.point_num = data.shape[0]
112
+ self.neighbor_id_lstlst = neighbor_id_dist_lstlst.reshape(
113
+ (self.point_num, -1, 2)
114
+ )[
115
+ :, :, 0
116
+ ] # point_num*neighbor_num
117
+ self.neighbor_num_lst = np.array(data[:, 0]) # point_num
118
+
119
+ self.relu = nn.ELU()
120
+
121
+ self.batch = config_model["batch"]
122
+
123
+ #####For Laplace computation######
124
+ self.initial_neighbor_id_lstlst = torch.LongTensor(
125
+ self.neighbor_id_lstlst
126
+ ).cuda() # point_num*max_neighbor_num
127
+ self.initial_neighbor_num_lst = torch.FloatTensor(
128
+ self.neighbor_num_lst
129
+ ).cuda() # point_num
130
+
131
+ self.connection_folder = config_model["connection_folder"]
132
+ self.connection_layer_fn_lst = []
133
+ fn_lst = os.listdir(self.connection_folder)
134
+ self.connection_layer_lst = config_model["connection_layer_lst"]
135
+ for layer_name in self.connection_layer_lst:
136
+ layer_name = "_" + layer_name + "."
137
+
138
+ find_fn = False
139
+ for fn in fn_lst:
140
+ if (layer_name in fn) and ((".npy" in fn) or (".npz" in fn)):
141
+ self.connection_layer_fn_lst += [self.connection_folder + fn]
142
+ find_fn = True
143
+ break
144
+ if find_fn == False:
145
+ print("!!!ERROR: cannot find the connection layer fn")
146
+
147
+ self.init_layers(self.batch)
148
+
149
+ self.initial_max_neighbor_num = self.initial_neighbor_id_lstlst.shape[1]
150
+
151
+ def init_layers(self, batch):
152
+ self.layer_lst = (
153
+ []
154
+ ) ##[in_channel, out_channel, in_pn, out_pn, max_neighbor_num, neighbor_num_lst,neighbor_id_lstlst,conv_layer, residual_layer]
155
+
156
+ self.layer_num = len(self.channel_lst)
157
+
158
+ in_point_num = self.point_num
159
+ in_channel = 3
160
+
161
+ for l in range(self.layer_num):
162
+ out_channel = self.channel_lst[l]
163
+ weight_num = self.weight_num_lst[l]
164
+ residual_rate = self.residual_rate_lst[l]
165
+
166
+ connection_info = np.load(self.connection_layer_fn_lst[l])
167
+ out_point_num = connection_info.shape[0]
168
+ neighbor_num_lst = torch.FloatTensor(
169
+ connection_info[:, 0].astype(float)
170
+ ).cuda() # out_point_num*1
171
+ neighbor_id_dist_lstlst = connection_info[
172
+ :, 1:
173
+ ] # out_point_num*(max_neighbor_num*2)
174
+ print(self.connection_layer_fn_lst[l])
175
+ print()
176
+ neighbor_id_lstlst = neighbor_id_dist_lstlst.reshape(
177
+ (out_point_num, -1, 2)
178
+ )[
179
+ :, :, 0
180
+ ] # out_point_num*max_neighbor_num
181
+ neighbor_id_lstlst = torch.LongTensor(neighbor_id_lstlst).cuda()
182
+ max_neighbor_num = neighbor_id_lstlst.shape[1]
183
+ avg_neighbor_num = round(neighbor_num_lst.mean().item())
184
+ effective_w_weights_rate = neighbor_num_lst.sum() / float(
185
+ max_neighbor_num * out_point_num
186
+ )
187
+ effective_w_weights_rate = round(effective_w_weights_rate.item(), 3)
188
+
189
+ pc_mask = torch.ones(in_point_num + 1).cuda()
190
+ pc_mask[in_point_num] = 0
191
+ neighbor_mask_lst = pc_mask[
192
+ neighbor_id_lstlst
193
+ ].contiguous() # out_pn*max_neighbor_num neighbor is 1 otherwise 0
194
+
195
+ zeros_batch_outpn_outchannel = torch.zeros(
196
+ (batch, out_point_num, out_channel)
197
+ ).cuda()
198
+
199
+ if (residual_rate < 0) or (residual_rate > 1):
200
+ print("Invalid residual rate", residual_rate)
201
+ ####parameters for conv###############
202
+ conv_layer = ""
203
+
204
+ if residual_rate < 1:
205
+ weights = torch.randn(weight_num, out_channel * in_channel).cuda()
206
+
207
+ weights = nn.Parameter(weights).cuda()
208
+
209
+ self.register_parameter("weights" + str(l), weights)
210
+
211
+ bias = nn.Parameter(torch.zeros(out_channel).cuda())
212
+ self.register_parameter("bias" + str(l), bias)
213
+
214
+ w_weights = torch.randn(out_point_num, max_neighbor_num, weight_num) / (
215
+ avg_neighbor_num * weight_num
216
+ )
217
+
218
+ w_weights = nn.Parameter(w_weights.cuda())
219
+ self.register_parameter("w_weights" + str(l), w_weights)
220
+
221
+ conv_layer = (weights, bias, w_weights)
222
+
223
+ ####parameters for residual###############
224
+
225
+ ## a residual layer with out_point_num==in_point_num and residual_rate==1 is a pooling or unpooling layer
226
+
227
+ residual_layer = ""
228
+
229
+ if residual_rate > 0:
230
+ p_neighbors = ""
231
+ weight_res = ""
232
+
233
+ if out_point_num != in_point_num:
234
+ p_neighbors = nn.Parameter(
235
+ (
236
+ torch.randn(out_point_num, max_neighbor_num)
237
+ / (avg_neighbor_num)
238
+ ).cuda()
239
+ )
240
+ self.register_parameter("p_neighbors" + str(l), p_neighbors)
241
+
242
+ if out_channel != in_channel:
243
+ weight_res = torch.randn(out_channel, in_channel)
244
+ # self.normalize_weights(weight_res)
245
+ weight_res = weight_res / out_channel
246
+ weight_res = nn.Parameter(weight_res.cuda())
247
+ self.register_parameter("weight_res" + str(l), weight_res)
248
+
249
+ residual_layer = (weight_res, p_neighbors)
250
+
251
+ ##### put everything together
252
+
253
+ layer = (
254
+ in_channel,
255
+ out_channel,
256
+ in_point_num,
257
+ out_point_num,
258
+ weight_num,
259
+ max_neighbor_num,
260
+ neighbor_num_lst,
261
+ neighbor_id_lstlst,
262
+ conv_layer,
263
+ residual_layer,
264
+ residual_rate,
265
+ neighbor_mask_lst,
266
+ zeros_batch_outpn_outchannel,
267
+ )
268
+
269
+ self.layer_lst += [layer]
270
+
271
+ in_point_num = out_point_num
272
+ in_channel = out_channel
273
+
274
+ # precompute the parameters so as to accelerate forwarding in testing mode
275
+ def init_test_mode(self):
276
+ for l in range(len(self.layer_lst)):
277
+ layer_info = self.layer_lst[l]
278
+
279
+ (
280
+ in_channel,
281
+ out_channel,
282
+ in_pn,
283
+ out_pn,
284
+ weight_num,
285
+ max_neighbor_num,
286
+ neighbor_num_lst,
287
+ neighbor_id_lstlst,
288
+ conv_layer,
289
+ residual_layer,
290
+ residual_rate,
291
+ neighbor_mask_lst,
292
+ zeros_batch_outpn_outchannel,
293
+ ) = layer_info
294
+
295
+ if len(conv_layer) != 0:
296
+ (
297
+ weights,
298
+ bias,
299
+ raw_w_weights,
300
+ ) = conv_layer # weight_num*(out_channel*in_channel) out_point_num* max_neighbor_num* weight_num
301
+
302
+ w_weights = ""
303
+
304
+ w_weights = raw_w_weights * neighbor_mask_lst.view(
305
+ out_pn, max_neighbor_num, 1
306
+ ).repeat(
307
+ 1, 1, weight_num
308
+ ) # out_pn*max_neighbor_num*weight_num
309
+
310
+ weights = torch.einsum(
311
+ "pmw,wc->pmc", [w_weights, weights]
312
+ ) # out_pn*max_neighbor_num*(out_channel*in_channel)
313
+ weights = weights.view(
314
+ out_pn, max_neighbor_num, out_channel, in_channel
315
+ )
316
+
317
+ conv_layer = weights, bias
318
+
319
+ ####compute output of residual layer####
320
+
321
+ if len(residual_layer) != 0:
322
+ (
323
+ weight_res,
324
+ p_neighbors_raw,
325
+ ) = residual_layer # out_channel*in_channel out_pn*max_neighbor_num
326
+ if in_pn != out_pn:
327
+ p_neighbors = torch.abs(p_neighbors_raw) * neighbor_mask_lst
328
+ p_neighbors_sum = p_neighbors.sum(1) + 1e-8 # out_pn
329
+ p_neighbors = p_neighbors / p_neighbors_sum.view(out_pn, 1).repeat(
330
+ 1, max_neighbor_num
331
+ )
332
+
333
+ residual_layer = weight_res, p_neighbors
334
+
335
+ self.layer_lst[l] = (
336
+ in_channel,
337
+ out_channel,
338
+ in_pn,
339
+ out_pn,
340
+ weight_num,
341
+ max_neighbor_num,
342
+ neighbor_num_lst,
343
+ neighbor_id_lstlst,
344
+ conv_layer,
345
+ residual_layer,
346
+ residual_rate,
347
+ neighbor_mask_lst,
348
+ zeros_batch_outpn_outchannel,
349
+ )
350
+
351
+ # a faster mode for testing
352
+ # input_pc batch*in_pn*in_channel
353
+ # out_pc batch*out_pn*out_channel
354
+ def forward_one_conv_layer_batch_during_test(
355
+ self, in_pc, layer_info, is_final_layer=False
356
+ ):
357
+ batch = in_pc.shape[0]
358
+
359
+ (
360
+ in_channel,
361
+ out_channel,
362
+ in_pn,
363
+ out_pn,
364
+ weight_num,
365
+ max_neighbor_num,
366
+ neighbor_num_lst,
367
+ neighbor_id_lstlst,
368
+ conv_layer,
369
+ residual_layer,
370
+ residual_rate,
371
+ neighbor_mask_lst,
372
+ zeros_batch_outpn_outchannel,
373
+ ) = layer_info
374
+
375
+ device = in_pc.get_device()
376
+ if device < 0:
377
+ device = "cpu"
378
+
379
+ in_pc_pad = torch.cat(
380
+ (in_pc, torch.zeros(batch, 1, in_channel).to(device)), 1
381
+ ) # batch*(in_pn+1)*in_channel
382
+
383
+ in_neighbors = in_pc_pad[
384
+ :, neighbor_id_lstlst.to(device)
385
+ ] # batch*out_pn*max_neighbor_num*in_channel
386
+
387
+ ####compute output of convolution layer####
388
+ out_pc_conv = zeros_batch_outpn_outchannel.clone()
389
+
390
+ if len(conv_layer) != 0:
391
+ (
392
+ weights,
393
+ bias,
394
+ ) = conv_layer # weight_num*(out_channel*in_channel) out_point_num* max_neighbor_num* weight_num
395
+
396
+ out_neighbors = torch.einsum(
397
+ "pmoi,bpmi->bpmo", [weights.to(device), in_neighbors]
398
+ ) # batch*out_pn*max_neighbor_num*out_channel
399
+
400
+ out_pc_conv = out_neighbors.sum(2)
401
+
402
+ out_pc_conv = out_pc_conv + bias
403
+
404
+ if is_final_layer == False:
405
+ out_pc_conv = self.relu(
406
+ out_pc_conv
407
+ ) ##self.relu is defined in the init function
408
+
409
+ # if(self.residual_rate==0):
410
+ # return out_pc
411
+ ####compute output of residual layer####
412
+ out_pc_res = zeros_batch_outpn_outchannel.clone()
413
+
414
+ if len(residual_layer) != 0:
415
+ (
416
+ weight_res,
417
+ p_neighbors,
418
+ ) = residual_layer # out_channel*in_channel out_pn*max_neighbor_num
419
+
420
+ if in_channel != out_channel:
421
+ in_pc_pad = torch.einsum("oi,bpi->bpo", [weight_res, in_pc_pad])
422
+
423
+ out_pc_res = []
424
+ if in_pn == out_pn:
425
+ out_pc_res = in_pc_pad[:, 0:in_pn].clone()
426
+ else:
427
+ in_neighbors = in_pc_pad[
428
+ :, neighbor_id_lstlst.to(device)
429
+ ] # batch*out_pn*max_neighbor_num*out_channel
430
+ out_pc_res = torch.einsum(
431
+ "pm,bpmo->bpo", [p_neighbors.to(device), in_neighbors]
432
+ )
433
+
434
+ out_pc = out_pc_conv.to(device) * np.sqrt(1 - residual_rate) + out_pc_res.to(
435
+ device
436
+ ) * np.sqrt(residual_rate)
437
+
438
+ return out_pc
439
+
440
+ # use in train mode. Slower than test mode
441
+ # input_pc batch*in_pn*in_channel
442
+ # out_pc batch*out_pn*out_channel
443
+ def forward_one_conv_layer_batch(self, in_pc, layer_info, is_final_layer=False):
444
+ batch = in_pc.shape[0]
445
+
446
+ (
447
+ in_channel,
448
+ out_channel,
449
+ in_pn,
450
+ out_pn,
451
+ weight_num,
452
+ max_neighbor_num,
453
+ neighbor_num_lst,
454
+ neighbor_id_lstlst,
455
+ conv_layer,
456
+ residual_layer,
457
+ residual_rate,
458
+ neighbor_mask_lst,
459
+ zeros_batch_outpn_outchannel,
460
+ ) = layer_info
461
+
462
+ in_pc_pad = torch.cat(
463
+ (in_pc, torch.zeros(batch, 1, in_channel).cuda()), 1
464
+ ) # batch*(in_pn+1)*in_channel
465
+
466
+ in_neighbors = in_pc_pad[
467
+ :, neighbor_id_lstlst
468
+ ] # batch*out_pn*max_neighbor_num*in_channel
469
+
470
+ ####compute output of convolution layer####
471
+ out_pc_conv = zeros_batch_outpn_outchannel.clone()
472
+
473
+ if len(conv_layer) != 0:
474
+ (
475
+ weights,
476
+ bias,
477
+ raw_w_weights,
478
+         ) = conv_layer  # weight_num*(out_channel*in_channel), out_point_num*max_neighbor_num*weight_num
+
+         w_weights = raw_w_weights * neighbor_mask_lst.view(
+             out_pn, max_neighbor_num, 1
+         ).repeat(
+             1, 1, weight_num
+         )  # out_pn*max_neighbor_num*weight_num
+
+         weights = torch.einsum(
+             "pmw,wc->pmc", [w_weights, weights]
+         )  # out_pn*max_neighbor_num*(out_channel*in_channel)
+         weights = weights.view(out_pn, max_neighbor_num, out_channel, in_channel)
+
+         out_neighbors = torch.einsum(
+             "pmoi,bpmi->bpmo", [weights, in_neighbors]
+         )  # batch*out_pn*max_neighbor_num*out_channel
+
+         out_pc_conv = out_neighbors.sum(2)
+
+         out_pc_conv = out_pc_conv + bias
+
+         if not is_final_layer:
+             out_pc_conv = self.relu(
+                 out_pc_conv
+             )  # self.relu is defined in __init__
+
+         # compute output of the residual layer
+         out_pc_res = zeros_batch_outpn_outchannel.clone()
+
+         if len(residual_layer) != 0:
+             (
+                 weight_res,
+                 p_neighbors_raw,
+             ) = residual_layer  # out_channel*in_channel, out_pn*max_neighbor_num
+
+             if in_channel != out_channel:
+                 in_pc_pad = torch.einsum("oi,bpi->bpo", [weight_res, in_pc_pad])
+
+             if in_pn == out_pn:
+                 out_pc_res = in_pc_pad[:, 0:in_pn].clone()
+             else:
+                 in_neighbors = in_pc_pad[
+                     :, neighbor_id_lstlst
+                 ]  # batch*out_pn*max_neighbor_num*out_channel
+
+                 p_neighbors = torch.abs(p_neighbors_raw) * neighbor_mask_lst
+                 p_neighbors_sum = p_neighbors.sum(1) + 1e-8  # out_pn
+                 p_neighbors = p_neighbors / p_neighbors_sum.view(out_pn, 1).repeat(
+                     1, max_neighbor_num
+                 )
+
+                 out_pc_res = torch.einsum("pm,bpmo->bpo", [p_neighbors, in_neighbors])
+
+         # blend the convolution and residual branches
+         out_pc = out_pc_conv * np.sqrt(1 - residual_rate) + out_pc_res * np.sqrt(
+             residual_rate
+         )
+
+         return out_pc
+
+     def forward_till_layer_n(self, in_pc, layer_n):
+         out_pc = in_pc.clone()
+
+         for i in range(layer_n):
+             if not self.test_mode:
+                 out_pc = self.forward_one_conv_layer_batch(out_pc, self.layer_lst[i])
+             else:
+                 out_pc = self.forward_one_conv_layer_batch_during_test(
+                     out_pc, self.layer_lst[i]
+                 )
+
+         # out_pc = self.final_linear(out_pc.transpose(1, 2)).transpose(1, 2)  # batch*3*point_num
+
+         return out_pc
+
+     def forward_from_layer_n(self, in_pc, layer_n):
+         out_pc = in_pc.clone()
+
+         for i in range(layer_n, self.layer_num):
+             if i < (self.layer_num - 1):
+                 if not self.test_mode:
+                     out_pc = self.forward_one_conv_layer_batch(
+                         out_pc, self.layer_lst[i]
+                     )
+                 else:
+                     out_pc = self.forward_one_conv_layer_batch_during_test(
+                         out_pc, self.layer_lst[i]
+                     )
+             else:
+                 if not self.test_mode:
+                     out_pc = self.forward_one_conv_layer_batch(
+                         out_pc, self.layer_lst[i], is_final_layer=True
+                     )
+                 else:
+                     out_pc = self.forward_one_conv_layer_batch_during_test(
+                         out_pc, self.layer_lst[i], is_final_layer=True
+                     )
+
+         return out_pc
+
+     def forward_layer_n(self, in_pc, layer_n):
+         out_pc = in_pc.clone()
+
+         if layer_n < (self.layer_num - 1):
+             if not self.test_mode:
+                 out_pc = self.forward_one_conv_layer_batch(
+                     out_pc, self.layer_lst[layer_n]
+                 )
+             else:
+                 out_pc = self.forward_one_conv_layer_batch_during_test(
+                     out_pc, self.layer_lst[layer_n]
+                 )
+         else:
+             if not self.test_mode:
+                 out_pc = self.forward_one_conv_layer_batch(
+                     out_pc, self.layer_lst[layer_n], is_final_layer=True
+                 )
+             else:
+                 out_pc = self.forward_one_conv_layer_batch_during_test(
+                     out_pc, self.layer_lst[layer_n], is_final_layer=True
+                 )
+
+         return out_pc
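
Usage sketch (not part of the commit): the two split-forward helpers above compose back into a full pass, which is handy for inspecting intermediate point features. The function name and `net` / `mid` are hypothetical; `net` stands for the autoencoder class these methods belong to (defined earlier in this file).

def reconstruct_through_split(net, in_pc, mid):
    # For 0 <= mid < net.layer_num this is equivalent to net.forward_from_layer_n(in_pc, 0):
    # layers [0, mid) are run first, then layers [mid, layer_num), with the final-layer
    # handling applied only on the last conv layer.
    hidden = net.forward_till_layer_n(in_pc, mid)   # batch * point_num * channel features
    out_pc = net.forward_from_layer_n(hidden, mid)
    return out_pc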
models/LengthEstimator.py ADDED
@@ -0,0 +1,40 @@
+ import torch.nn as nn
+
+ #################################################################################
+ #                               Length Estimator                                #
+ #################################################################################
+ class LengthEstimator(nn.Module):
+     def __init__(self, input_size, output_size):
+         super(LengthEstimator, self).__init__()
+         nd = 512
+         self.output = nn.Sequential(
+             nn.Linear(input_size, nd),
+             nn.LayerNorm(nd),
+             nn.LeakyReLU(0.2, inplace=True),
+
+             nn.Dropout(0.2),
+             nn.Linear(nd, nd // 2),
+             nn.LayerNorm(nd // 2),
+             nn.LeakyReLU(0.2, inplace=True),
+
+             nn.Dropout(0.2),
+             nn.Linear(nd // 2, nd // 4),
+             nn.LayerNorm(nd // 4),
+             nn.LeakyReLU(0.2, inplace=True),
+
+             nn.Linear(nd // 4, output_size)
+         )
+
+         self.output.apply(self.__init_weights)
+
+     def __init_weights(self, module):
+         if isinstance(module, (nn.Linear, nn.Embedding)):
+             module.weight.data.normal_(mean=0.0, std=0.02)
+             if isinstance(module, nn.Linear) and module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.LayerNorm):
+             module.bias.data.zero_()
+             module.weight.data.fill_(1.0)
+
+     def forward(self, text_emb):
+         return self.output(text_emb)
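
Usage sketch (not part of the commit): feeding a text embedding through the estimator head above. The 512-dim input and the 50 output bins are assumptions for illustration, not values fixed by this file.

import torch
from models.LengthEstimator import LengthEstimator

estimator = LengthEstimator(input_size=512, output_size=50)  # sizes are assumed
text_emb = torch.randn(4, 512)             # e.g. a batch of text embeddings
logits = estimator(text_emb)               # shape (4, 50), unnormalized scores
length_probs = logits.softmax(dim=-1)      # distribution over candidate motion lengths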
models/ROPE.py ADDED
@@ -0,0 +1,91 @@
+ import torch
+ import math
+
+ class RopeND:
+     def __init__(self, head_dim=64, nd=3, max_lens=[1024, 64, 64], nd_split=[2, 1, 1], bases=[1000, 1000, 1000],
+                  auto_base=True, cache_longer=1):
+         self.nd = nd
+         self.head_dim = head_dim
+         self.max_lens = max_lens
+         self.nd_split = nd_split
+         self.split_dims = [2 * i * (head_dim // 2 // sum(nd_split)) for i in nd_split]
+         assert sum(self.split_dims) == head_dim
+         self.auto_base = auto_base
+         if auto_base:
+             # empirical: make cos(theta) = -1 when length is kL, so base = kL/pi
+             # and at L=1 the difference (1/base)**(1/32) ~ 0.7-0.8 ~ pi/4
+             # for traditional L = 4096, 8L/pi = 10.4k, so base is set to 10k
+             self.bases = [(int(8 * l / math.pi) // 100 + 1) * 100 for l in self.max_lens]
+             print(f"Bases for rope: {self.bases}")
+         else:
+             self.bases = bases
+         self.cache_longer = cache_longer
+
+     def generated_cos_sin_mix2d(self, max_len, dim, device, base=1000):
+         inv_freq = 1.0 / (base ** \
+                           (torch.linspace(start=0, end=self.head_dim, steps=dim // 2,
+                                           device=device).float() / self.head_dim))
+         assert inv_freq.size(0) * 2 == dim, f"inv_freq.size(0) = {inv_freq.size(0)}, required dim = {dim}"
+
+         t = torch.arange(max_len * self.cache_longer, device=device).type_as(inv_freq)
+         freqs = torch.einsum("i,j->ij", t, inv_freq)
+         freqs = torch.cat([freqs, freqs], dim=1)
+         return freqs.cos().to(torch.float), freqs.sin().to(torch.float)
+
+     def generate_pos_embs_mix2d(self, position_ids, device=None):
+         if device is None:
+             device = position_ids.device
+
+         if position_ids.dim() == 1:
+             position_ids = position_ids.unsqueeze(0)
+
+         cos_emb_all, sin_emb_all = [], []
+         for i in range(self.nd):
+             dim_i = self.split_dims[i]
+             base_i = self.bases[i]
+             max_len_i = self.max_lens[i]
+             if not hasattr(self, f"cos_{i}"):
+                 _cos, _sin = self.generated_cos_sin_mix2d(max_len=max_len_i, dim=dim_i, device=device, base=base_i)
+                 setattr(self, f"cos_{i}", _cos)
+                 setattr(self, f"sin_{i}", _sin)
+             cos_emb_all.append(getattr(self, f'cos_{i}')[position_ids[i, :], :])
+             sin_emb_all.append(getattr(self, f'sin_{i}')[position_ids[i, :], :])
+         cos_emb = torch.cat(cos_emb_all, dim=-1)
+         sin_emb = torch.cat(sin_emb_all, dim=-1)
+         return cos_emb, sin_emb
+
+     def __call__(self, q, k, position_ids):
+         '''q: N N_head L C
+         '''
+         cos_emb, sin_emb = self.generate_pos_embs_mix2d(position_ids, device=q.device)
+
+         def rotate_half(x):
+             """Rotates half the hidden dims of the input."""
+             x1 = x[..., : x.shape[-1] // 2]
+             x2 = x[..., x.shape[-1] // 2:]
+             return torch.cat((-x2, x1), dim=-1)
+
+         def apply_rotary_pos_emb(q, k, cos, sin):
+             """Applies Rotary Position Embedding to the query and key tensors.
+
+             Args:
+                 q (`torch.Tensor`): The query tensor.
+                 k (`torch.Tensor`): The key tensor.
+                 cos (`torch.Tensor`): The cosine part of the rotary embedding.
+                 sin (`torch.Tensor`): The sine part of the rotary embedding.
+             Returns:
+                 `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+             """
+             cos = cos.unsqueeze(0).unsqueeze(0)
+             sin = sin.unsqueeze(0).unsqueeze(0)
+             dtype = q.dtype
+             q = q.to(torch.float)
+             k = k.to(torch.float)
+             q_embed = (q * cos) + (rotate_half(q) * sin)
+             k_embed = (k * cos) + (rotate_half(k) * sin)
+             q_embed = q_embed.to(dtype)
+             k_embed = k_embed.to(dtype)
+             return q_embed, k_embed
+
+         q, k = apply_rotary_pos_emb(q, k, cos_emb, sin_emb)
+         return q, k
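
Usage sketch (not part of the commit): applying RopeND as a 1-D rotary embedding to attention queries and keys. Shapes follow the `__call__` docstring (batch, heads, length, head_dim); the concrete sizes below are arbitrary.

import torch
from models.ROPE import RopeND

rope = RopeND(head_dim=64, nd=1, nd_split=[1], max_lens=[1024])  # single positional axis
q = torch.randn(2, 4, 16, 64)             # batch, heads, sequence length, head_dim
k = torch.randn(2, 4, 16, 64)
pos = torch.arange(16).unsqueeze(0)       # one row of position ids per axis (here nd=1)
q_rot, k_rot = rope(q, k, pos)            # same shapes, with rotary phases applied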
models/__pycache__/ACMDM.cpython-310.pyc ADDED
Binary file (14.9 kB)
models/__pycache__/ACMDM.cpython-313.pyc ADDED
Binary file (28.7 kB)
models/__pycache__/AE_2D_Causal.cpython-310.pyc ADDED
Binary file (8.63 kB)
models/__pycache__/AE_2D_Causal.cpython-313.pyc ADDED
Binary file (15.7 kB)
models/__pycache__/LengthEstimator.cpython-310.pyc ADDED
Binary file (1.44 kB)
models/__pycache__/ROPE.cpython-310.pyc ADDED
Binary file (3.85 kB)