sourxbhh committed
Commit 82a6034 · 1 Parent(s): b24a5b7

Initial deployment: ACMDM Motion Generation
Files changed (37)
  1. .gitattributes +1 -0
  2. README.md +76 -5
  3. app.py +604 -0
  4. checkpoints/t2m/ACMDM_Flow_S_PatchSize22/model/latest.tar +3 -0
  5. checkpoints/t2m/AE_2D_Causal/AE_2D_Causal_Post_Mean.npy +3 -0
  6. checkpoints/t2m/AE_2D_Causal/AE_2D_Causal_Post_Std.npy +3 -0
  7. checkpoints/t2m/AE_2D_Causal/model/latest.tar +3 -0
  8. checkpoints/t2m/length_estimator/model/finest.tar +3 -0
  9. requirements.txt +30 -0
  10. utils/22x3_mean_std/t2m/22x3_mean.npy +3 -0
  11. utils/22x3_mean_std/t2m/22x3_std.npy +3 -0
  12. utils/__pycache__/back_process.cpython-310.pyc +0 -0
  13. utils/__pycache__/eval_utils.cpython-310.pyc +0 -0
  14. utils/__pycache__/evaluators.cpython-310.pyc +0 -0
  15. utils/__pycache__/glove.cpython-310.pyc +0 -0
  16. utils/__pycache__/motion_process.cpython-310.pyc +0 -0
  17. utils/__pycache__/motion_process.cpython-313.pyc +0 -0
  18. utils/__pycache__/quaternion.cpython-310.pyc +0 -0
  19. utils/__pycache__/skeleton.cpython-310.pyc +0 -0
  20. utils/__pycache__/train_utils.cpython-310.pyc +0 -0
  21. utils/back_process.py +255 -0
  22. utils/cal_ae_post_mean_std.py +64 -0
  23. utils/cal_mean_std.py +35 -0
  24. utils/cal_mesh_ae_post_mean_std.py +107 -0
  25. utils/datasets.py +452 -0
  26. utils/eval_mean_std/t2m/eval_mean.npy +3 -0
  27. utils/eval_mean_std/t2m/eval_std.npy +3 -0
  28. utils/eval_utils.py +928 -0
  29. utils/evaluators.py +392 -0
  30. utils/glove.py +83 -0
  31. utils/mesh_mean_std/t2m/mesh_mean.npy +3 -0
  32. utils/mesh_mean_std/t2m/mesh_std.npy +3 -0
  33. utils/motion_process.py +429 -0
  34. utils/quaternion.py +519 -0
  35. utils/skeleton.py +194 -0
  36. utils/train_utils.py +105 -0
  37. utils/wandb_utils.py +53 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+ checkpoints/**/*.tar filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,83 @@
---
title: ACMDM Motion Generation
- emoji: 🌖
- colorFrom: yellow
- colorTo: indigo
+ emoji: 🎭
+ colorFrom: blue
+ colorTo: purple
sdk: gradio
- sdk_version: 5.49.1
+ sdk_version: 4.0.0
app_file: app.py
pinned: false
+ license: mit
+ hardware: gpu-t4-small
---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # ACMDM Motion Generation
+
+ Generate human motion animations from text descriptions using the ACMDM (Absolute Coordinates Make Motion Generation Easy) model.
+
+ ## 🎯 Features
+
+ - **Text-to-Motion Generation**: Create realistic human motion from natural-language descriptions
+ - **Batch Processing**: Generate multiple motions at once
+ - **Auto-Length Estimation**: A learned length estimator predicts a suitable motion length from your text
+ - **Flexible Parameters**: Adjust CFG scale, motion length, and more
+ - **In-Browser Preview**: View generated motions directly in the browser
+
+ ## 🚀 Usage
+
+ 1. **Enter a text description** of the motion you want (e.g., "A person is running on a treadmill.")
+ 2. **Adjust parameters** (optional):
+    - Motion length (40-196 frames)
+    - CFG scale (controls text alignment)
+    - Auto-length estimation
+ 3. **Click "Generate Motion"**
+ 4. **View and download** your generated motion video
+
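The steps above can also be driven programmatically. Below is a minimal, hedged sketch using `gradio_client` (installed separately with `pip install gradio_client`); the Space id and the endpoint's exact argument order are assumptions, so check the Space's "Use via API" panel for the real signature. The arguments mirror the inputs of the Single Motion Generation tab in `app.py`.

```python
# Minimal sketch; the Space id and default endpoint mapping are assumptions.
from gradio_client import Client

client = Client("sourxbhh/ACMDM-Motion-Generation")  # hypothetical Space id
video_path, status = client.predict(
    "A person walks forward and then turns around.",  # motion description
    120,      # motion length in frames (rounded to a multiple of 4)
    3.0,      # CFG scale
    False,    # auto-estimate length
    0,        # GPU id (-1 for CPU)
    3407,     # random seed
)
print(status)
print("Video saved at:", video_path)
```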
+ ## 📝 Example Prompts
+
+ - "A person is running on a treadmill."
+ - "Someone is doing jumping jacks."
+ - "A person walks forward and then turns around."
+ - "A person is dancing energetically."
+
+ ## ⚙️ Parameters
+
+ - **Motion Length**: Number of frames (40-196). Automatically rounded to the nearest multiple of 4, as shown in the snippet below.
+ - **CFG Scale**: Classifier-free guidance scale (1.0-10.0). Higher values follow the text more closely; lower values produce more diverse motions.
+ - **Auto-length**: Let the length estimator pick a suitable motion length from your text.
+
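For reference, the length handling in `app.py` rounds a requested length to the nearest multiple of 4, and estimated lengths are additionally clamped to the 40-196 frame range. A small illustrative snippet (the helper name is ours, not part of the app):

```python
# Rounding/clamping rule used by app.py (helper name is illustrative only).
def normalize_length(frames: int) -> int:
    frames = ((frames + 2) // 4) * 4      # round to the nearest multiple of 4
    return max(40, min(196, frames))      # clamp to the supported range

assert normalize_length(121) == 120
assert normalize_length(10) == 40
```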
+ ## 🔧 Technical Details
+
+ This Space uses pre-trained ACMDM models:
+ - **Autoencoder**: AE_2D_Causal
+ - **Diffusion Model**: ACMDM_Flow_S_PatchSize22
+ - **Dataset**: HumanML3D (t2m)
+
+ ## 📚 Paper
+
+ [Absolute Coordinates Make Motion Generation Easy](https://arxiv.org/abs/2505.19377)
+
+ ## 🤝 Citation
+
+ ```bibtex
+ @article{meng2025absolute,
+   title={Absolute Coordinates Make Motion Generation Easy},
+   author={Meng, Zichong and Han, Zeyu and Peng, Xiaogang and Xie, Yiming and Jiang, Huaizu},
+   journal={arXiv preprint arXiv:2505.19377},
+   year={2025}
+ }
+ ```
+
+ ## ⚠️ Notes
+
+ - First generation may take 30-60 seconds (model loading)
+ - Subsequent generations are faster (5-15 seconds)
+ - GPU recommended for best performance
+ - Works on CPU but slower
+
+ ## 🔗 Links
+
+ - [GitHub Repository](https://github.com/neu-vi/ACMDM)
+ - [Project Page](https://neu-vi.github.io/ACMDM/)
+ - [Paper](https://arxiv.org/abs/2505.19377)
app.py ADDED
@@ -0,0 +1,604 @@
1
+ """
2
+ Gradio Web Interface for ACMDM Motion Generation
3
+ Milestone 3: User-friendly web interface for text-to-motion generation
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ from os.path import join as pjoin
9
+ import torch
10
+ import numpy as np
11
+ import gradio as gr
12
+ from typing import Optional, Tuple, List
13
+ import tempfile
14
+ import random
15
+
16
+ # Import from sample.py
17
+ from models.AE_2D_Causal import AE_models
18
+ from models.ACMDM import ACMDM_models
19
+ from models.LengthEstimator import LengthEstimator
20
+ from utils.back_process import back_process
21
+ from utils.motion_process import plot_3d_motion, t2m_kinematic_chain
22
+
23
+ # Global variables for model caching
24
+ _models_cache = {
25
+ 'ae': None,
26
+ 'acmdm': None,
27
+ 'length_estimator': None,
28
+ 'stats': None,
29
+ 'device': None,
30
+ 'loaded': False
31
+ }
32
+
33
+
34
+ def set_seed(seed):
35
+ """Set random seed for reproducibility"""
36
+ random.seed(seed)
37
+ np.random.seed(seed)
38
+ torch.manual_seed(seed)
39
+ torch.backends.cudnn.benchmark = False
40
+
41
+
42
+ def load_models_cached(
43
+ gpu_id: int = 0,
44
+ model_name: str = 'ACMDM_Flow_S_PatchSize22',
45
+ ae_name: str = 'AE_2D_Causal',
46
+ ae_model: str = 'AE_Model',
47
+ model_type: str = 'ACMDM-Flow-S-PatchSize22',
48
+ dataset_name: str = 't2m',
49
+ checkpoints_dir: str = './checkpoints',
50
+ use_length_estimator: bool = False
51
+ ) -> Tuple[dict, str]:
52
+ """
53
+ Load models with caching to avoid reloading on every request.
54
+ Returns (models_dict, status_message)
55
+ """
56
+ global _models_cache
57
+
58
+ # Check if models are already loaded
59
+ if _models_cache['loaded']:
60
+ return _models_cache, "Models already loaded (using cache)"
61
+
62
+ try:
63
+ # Determine device
64
+ if gpu_id >= 0 and torch.cuda.is_available():
65
+ device = torch.device(f"cuda:{gpu_id}")
66
+ else:
67
+ device = torch.device("cpu")
68
+
69
+ _models_cache['device'] = device
70
+ status_messages = [f"Using device: {device}"]
71
+
72
+ # Load AE
73
+ status_messages.append("Loading AE model...")
74
+ ae = AE_models[ae_model](input_width=3)
75
+ ae_ckpt_path = pjoin(checkpoints_dir, dataset_name, ae_name, 'model', 'latest.tar')
76
+
77
+ if not os.path.exists(ae_ckpt_path):
78
+ return None, f"Error: AE checkpoint not found at {ae_ckpt_path}"
79
+
80
+ ae_ckpt = torch.load(ae_ckpt_path, map_location='cpu')
81
+ ae.load_state_dict(ae_ckpt['ae'])
82
+ ae.eval()
83
+ ae.to(device)
84
+ _models_cache['ae'] = ae
85
+ status_messages.append(f"✓ Loaded AE from {ae_ckpt_path}")
86
+
87
+ # Load ACMDM
88
+ status_messages.append("Loading ACMDM model...")
89
+ acmdm = ACMDM_models[model_type](input_dim=ae.output_emb_width, cond_mode='text')
90
+ acmdm_ckpt_path = pjoin(checkpoints_dir, dataset_name, model_name, 'model', 'latest.tar')
91
+
92
+ if not os.path.exists(acmdm_ckpt_path):
93
+ return None, f"Error: ACMDM checkpoint not found at {acmdm_ckpt_path}"
94
+
95
+ acmdm_ckpt = torch.load(acmdm_ckpt_path, map_location='cpu')
96
+ missing_keys, unexpected_keys = acmdm.load_state_dict(acmdm_ckpt['ema_acmdm'], strict=False)
97
+ assert len(unexpected_keys) == 0
98
+ assert all([k.startswith('clip_model.') for k in missing_keys])
99
+ acmdm.eval()
100
+ acmdm.to(device)
101
+ _models_cache['acmdm'] = acmdm
102
+ status_messages.append(f"✓ Loaded ACMDM from {acmdm_ckpt_path}")
103
+
104
+ # Load LengthEstimator if needed
105
+ length_estimator = None
106
+ if use_length_estimator:
107
+ status_messages.append("Loading LengthEstimator...")
108
+ length_estimator = LengthEstimator(input_size=512, output_size=1)
109
+ length_estimator_path = pjoin(checkpoints_dir, dataset_name, 'length_estimator', 'model', 'finest.tar')
110
+ if os.path.exists(length_estimator_path):
111
+ length_ckpt = torch.load(length_estimator_path, map_location='cpu')
112
+ length_estimator.load_state_dict(length_ckpt['model'])
113
+ length_estimator.eval()
114
+ length_estimator.to(device)
115
+ _models_cache['length_estimator'] = length_estimator
116
+ status_messages.append(f"✓ Loaded LengthEstimator")
117
+ else:
118
+ status_messages.append(f"⚠ LengthEstimator not found, will use default length")
119
+
120
+ # Load normalization stats
121
+ status_messages.append("Loading normalization statistics...")
122
+ after_mean = np.load(pjoin(checkpoints_dir, dataset_name, ae_name, 'AE_2D_Causal_Post_Mean.npy'))
123
+ after_std = np.load(pjoin(checkpoints_dir, dataset_name, ae_name, 'AE_2D_Causal_Post_Std.npy'))
124
+ joint_mean = np.load(f'utils/22x3_mean_std/{dataset_name}/22x3_mean.npy')
125
+ joint_std = np.load(f'utils/22x3_mean_std/{dataset_name}/22x3_std.npy')
126
+ eval_mean = np.load(f'utils/eval_mean_std/{dataset_name}/eval_mean.npy')
127
+ eval_std = np.load(f'utils/eval_mean_std/{dataset_name}/eval_std.npy')
128
+
129
+ _models_cache['stats'] = {
130
+ 'after_mean': after_mean,
131
+ 'after_std': after_std,
132
+ 'joint_mean': joint_mean,
133
+ 'joint_std': joint_std,
134
+ 'eval_mean': eval_mean,
135
+ 'eval_std': eval_std
136
+ }
137
+
138
+ _models_cache['loaded'] = True
139
+ status_message = "\n".join(status_messages)
140
+ return _models_cache, status_message
141
+
142
+ except Exception as e:
143
+ error_msg = f"Error loading models: {str(e)}"
144
+ import traceback
145
+ error_msg += f"\n\nTraceback:\n{traceback.format_exc()}"
146
+ return None, error_msg
147
+
148
+
149
+ def estimate_motion_length(text: str, models_cache: dict) -> int:
150
+ """Estimate motion length from text using LengthEstimator"""
151
+ if models_cache['length_estimator'] is None:
152
+ return None
153
+
154
+ device = models_cache['device']
155
+ acmdm = models_cache['acmdm']
156
+ length_estimator = models_cache['length_estimator']
157
+
158
+ with torch.no_grad():
159
+ text_emb = acmdm.encode_text([text])
160
+ pred_length = length_estimator(text_emb)
161
+ pred_length = int(pred_length.item() * 4)
162
+ pred_length = ((pred_length + 2) // 4) * 4
163
+ pred_length = max(40, min(196, pred_length))
164
+ return pred_length
165
+
166
+
167
+ def generate_motion_single(
168
+ text: str,
169
+ motion_length: Optional[int],
170
+ cfg_scale: float,
171
+ use_auto_length: bool,
172
+ gpu_id: int,
173
+ seed: int
174
+ ) -> Tuple[Optional[str], str]:
175
+ """
176
+ Generate a single motion from text.
177
+ Returns (video_path, status_message)
178
+ """
179
+ global _models_cache
180
+
181
+ try:
182
+ set_seed(seed)
183
+
184
+ # Load models if not cached
185
+ if not _models_cache['loaded']:
186
+ models_cache, load_msg = load_models_cached(
187
+ gpu_id=gpu_id,
188
+ use_length_estimator=use_auto_length
189
+ )
190
+ if models_cache is None:
191
+ return None, f"Failed to load models:\n{load_msg}"
192
+ else:
193
+ models_cache = _models_cache
194
+
195
+ device = models_cache['device']
196
+ ae = models_cache['ae']
197
+ acmdm = models_cache['acmdm']
198
+ stats = models_cache['stats']
199
+
200
+ # Estimate length if needed
201
+ if motion_length is None or (motion_length == 0 and use_auto_length):
202
+ if use_auto_length and models_cache['length_estimator'] is not None:
203
+ motion_length = estimate_motion_length(text, models_cache)
204
+ status_msg = f"Estimated motion length: {motion_length} frames\n"
205
+ else:
206
+ motion_length = 120 # Default
207
+ status_msg = f"Using default motion length: {motion_length} frames\n"
208
+ else:
209
+ # Round to multiple of 4
210
+ motion_length = ((motion_length + 2) // 4) * 4
211
+ status_msg = f"Using specified motion length: {motion_length} frames\n"
212
+
213
+ status_msg += f"Generating motion for: '{text}'...\n"
214
+
215
+ # Generate motion
216
+ with torch.no_grad():
217
+ latent_length = motion_length // 4
218
+ m_lens = torch.tensor([latent_length], device=device)
219
+ pred_latents = acmdm.generate([text], m_lens, cfg_scale)
220
+
221
+ # Denormalize latents
222
+ pred_latents_np = pred_latents.permute(0, 2, 3, 1).detach().cpu().numpy()
223
+ pred_latents_np = pred_latents_np * stats['after_std'] + stats['after_mean']
224
+ pred_latents_tensor = torch.from_numpy(pred_latents_np).to(device)
225
+
226
+ # Decode through AE
227
+ pred_motions = ae.decode(pred_latents_tensor.permute(0, 3, 1, 2))
228
+
229
+ # Denormalize motions
230
+ pred_motions_np = pred_motions.permute(0, 2, 3, 1).detach().cpu().numpy()
231
+ if stats['joint_mean'].ndim == 1:
232
+ pred_motions_np = pred_motions_np * stats['joint_std'][np.newaxis, np.newaxis, :, np.newaxis] + stats['joint_mean'][np.newaxis, np.newaxis, :, np.newaxis]
233
+ else:
234
+ pred_motions_np = pred_motions_np * stats['joint_std'][np.newaxis, ..., np.newaxis] + stats['joint_mean'][np.newaxis, ..., np.newaxis]
235
+
236
+ # Back process to get RIC format, then recover joint positions
237
+ from utils.motion_process import recover_from_ric
238
+
239
+ motion = pred_motions_np[0] # (22, 3, seq_len)
240
+ motion = motion[:, :, :motion_length].transpose(2, 0, 1) # (seq_len, 22, 3)
241
+ ric_data = back_process(motion, is_mesh=False)
242
+ ric_tensor = torch.from_numpy(ric_data).float()
243
+ joints = recover_from_ric(ric_tensor, joints_num=22).numpy() # (seq_len, 22, 3)
244
+
245
+ # Create temporary file for video
246
+ temp_dir = tempfile.mkdtemp()
247
+ video_path = pjoin(temp_dir, 'motion.mp4')
248
+
249
+ # Generate video
250
+ plot_3d_motion(
251
+ video_path,
252
+ t2m_kinematic_chain,
253
+ joints,
254
+ title=text,
255
+ fps=20,
256
+ radius=4
257
+ )
258
+
259
+ status_msg += "✓ Motion generated successfully!"
260
+ return video_path, status_msg
261
+
262
+ except Exception as e:
263
+ error_msg = f"Error generating motion: {str(e)}"
264
+ import traceback
265
+ error_msg += f"\n\nTraceback:\n{traceback.format_exc()}"
266
+ return None, error_msg
267
+
268
+
269
+ def generate_motion_batch(
270
+ text_file_content: str,
271
+ cfg_scale: float,
272
+ use_auto_length: bool,
273
+ gpu_id: int,
274
+ seed: int
275
+ ) -> Tuple[List[Optional[str]], str]:
276
+ """
277
+ Generate motions from a batch of text prompts.
278
+ Returns (list_of_video_paths, status_message)
279
+ """
280
+ # Parse text file
281
+ text_prompts = []
282
+ for line in text_file_content.strip().split('\n'):
283
+ line = line.strip()
284
+ if not line:
285
+ continue
286
+ if '#' in line:
287
+ parts = line.split('#')
288
+ text = parts[0].strip()
289
+ length_str = parts[1].strip() if len(parts) > 1 else 'NA'
290
+ else:
291
+ text = line
292
+ length_str = 'NA'
293
+
294
+ if length_str.upper() == 'NA':
295
+ motion_length = None if use_auto_length else 120
296
+ else:
297
+ try:
298
+ motion_length = int(length_str)
299
+ motion_length = ((motion_length + 2) // 4) * 4
300
+ except:
301
+ motion_length = None if use_auto_length else 120
302
+
303
+ text_prompts.append((text, motion_length))
304
+
305
+ if not text_prompts:
306
+ return [], "No valid prompts found in file"
307
+
308
+ status_msg = f"Processing {len(text_prompts)} prompts...\n"
309
+ video_paths = []
310
+
311
+ for idx, (text, motion_length) in enumerate(text_prompts):
312
+ video_path, gen_msg = generate_motion_single(
313
+ text, motion_length, cfg_scale, use_auto_length, gpu_id, seed + idx
314
+ )
315
+ video_paths.append(video_path)
316
+ status_msg += f"\n[{idx+1}/{len(text_prompts)}] {gen_msg}\n"
317
+
318
+ return video_paths, status_msg
319
+
320
+
321
+ # Gradio Interface
322
+ def create_interface():
323
+ """Create and configure the Gradio interface"""
324
+
325
+ with gr.Blocks(title="ACMDM Motion Generation", theme=gr.themes.Soft()) as app:
326
+ gr.Markdown("""
327
+ # 🎭 ACMDM Motion Generation
328
+
329
+ Generate human motion from text descriptions using the ACMDM (Absolute Coordinates Make Motion Generation Easy) model.
330
+
331
+ **How to use:**
332
+ 1. Enter a text description of the motion you want to generate
333
+ 2. Adjust motion length (or use auto-estimate)
334
+ 3. Click "Generate Motion" to create the animation
335
+ 4. View and download the generated video
336
+
337
+ **Example prompts:**
338
+ - "A person is running on a treadmill."
339
+ - "Someone is doing jumping jacks."
340
+ - "A person walks forward and then turns around."
341
+ """)
342
+
343
+ with gr.Tabs():
344
+ # Single Generation Tab
345
+ with gr.Tab("Single Motion Generation"):
346
+ with gr.Row():
347
+ with gr.Column(scale=1):
348
+ text_input = gr.Textbox(
349
+ label="Motion Description",
350
+ placeholder="A person is running on a treadmill.",
351
+ lines=3,
352
+ value="A person is running on a treadmill."
353
+ )
354
+
355
+ with gr.Row():
356
+ motion_length = gr.Slider(
357
+ label="Motion Length (frames)",
358
+ minimum=40,
359
+ maximum=196,
360
+ value=120,
361
+ step=4,
362
+ info="Will be rounded to nearest multiple of 4"
363
+ )
364
+ use_auto_length = gr.Checkbox(
365
+ label="Auto-estimate length",
366
+ value=False,
367
+ info="Use length estimator (ignores manual length if checked)"
368
+ )
369
+
370
+ cfg_scale = gr.Slider(
371
+ label="CFG Scale",
372
+ minimum=1.0,
373
+ maximum=10.0,
374
+ value=3.0,
375
+ step=0.5,
376
+ info="Classifier-free guidance scale (higher = more aligned with text)"
377
+ )
378
+
379
+ with gr.Row():
380
+ gpu_id = gr.Number(
381
+ label="GPU ID",
382
+ value=0,
383
+ precision=0,
384
+ info="Use -1 for CPU"
385
+ )
386
+ seed = gr.Number(
387
+ label="Random Seed",
388
+ value=3407,
389
+ precision=0
390
+ )
391
+
392
+ generate_btn = gr.Button("Generate Motion", variant="primary", size="lg")
393
+
394
+ with gr.Column(scale=1):
395
+ video_output = gr.Video(
396
+ label="Generated Motion",
397
+ format="mp4"
398
+ )
399
+ status_output = gr.Textbox(
400
+ label="Status",
401
+ lines=10,
402
+ interactive=False
403
+ )
404
+
405
+ # Update model status after generation
406
+ def generate_and_update_status(text, motion_length, cfg_scale, use_auto_length, gpu_id, seed):
407
+ video_path, status_msg = generate_motion_single(
408
+ text, motion_length, cfg_scale, use_auto_length, gpu_id, seed
409
+ )
410
+ # Return video and status, plus trigger status update
411
+ return video_path, status_msg
412
+
413
+ generate_btn.click(
414
+ fn=generate_and_update_status,
415
+ inputs=[text_input, motion_length, cfg_scale, use_auto_length, gpu_id, seed],
416
+ outputs=[video_output, status_output]
417
+ )
418
+
419
+ # Batch Generation Tab
420
+ with gr.Tab("Batch Generation"):
421
+ with gr.Row():
422
+ with gr.Column(scale=1):
423
+ batch_text_input = gr.Textbox(
424
+ label="Text Prompts (one per line, format: text#length or text#NA)",
425
+ placeholder="A person is running on a treadmill.#120\nSomeone is doing jumping jacks.#NA",
426
+ lines=10,
427
+ info="Each line: 'text#length' or 'text#NA' for auto-estimate"
428
+ )
429
+
430
+ batch_cfg_scale = gr.Slider(
431
+ label="CFG Scale",
432
+ minimum=1.0,
433
+ maximum=10.0,
434
+ value=3.0,
435
+ step=0.5
436
+ )
437
+
438
+ batch_use_auto_length = gr.Checkbox(
439
+ label="Auto-estimate length for NA",
440
+ value=True
441
+ )
442
+
443
+ batch_gpu_id = gr.Number(
444
+ label="GPU ID",
445
+ value=0,
446
+ precision=0
447
+ )
448
+
449
+ batch_seed = gr.Number(
450
+ label="Random Seed",
451
+ value=3407,
452
+ precision=0
453
+ )
454
+
455
+ batch_generate_btn = gr.Button("Generate Batch", variant="primary", size="lg")
456
+
457
+ with gr.Column(scale=1):
458
+ batch_status_output = gr.Textbox(
459
+ label="Batch Status",
460
+ lines=15,
461
+ interactive=False
462
+ )
463
+ batch_video_gallery = gr.Gallery(
464
+ label="Generated Motions",
465
+ show_label=True,
466
+ elem_id="gallery",
467
+ columns=2,
468
+ rows=2,
469
+ height="auto"
470
+ )
471
+
472
+ batch_generate_btn.click(
473
+ fn=generate_motion_batch,
474
+ inputs=[batch_text_input, batch_cfg_scale, batch_use_auto_length, batch_gpu_id, batch_seed],
475
+ outputs=[batch_video_gallery, batch_status_output]
476
+ )
477
+
478
+ # Model Management Tab
479
+ with gr.Tab("Model Management"):
480
+ with gr.Row():
481
+ with gr.Column():
482
+ model_status = gr.Textbox(
483
+ label="Model Status",
484
+ lines=15,
485
+ interactive=False,
486
+ value="⏳ Models not loaded yet. They will be loaded automatically on first generation."
487
+ )
488
+ with gr.Row():
489
+ refresh_status_btn = gr.Button("🔄 Refresh Status", variant="primary")
490
+ reload_models_btn = gr.Button("🗑️ Clear Cache", variant="secondary")
491
+
492
+ gr.Markdown("""
493
+ **Model Configuration:**
494
+ - Model Name: `ACMDM_Flow_S_PatchSize22`
495
+ - AE Name: `AE_2D_Causal`
496
+ - Dataset: `t2m`
497
+ - Checkpoints Directory: `./checkpoints`
498
+ """)
499
+
500
+ def check_model_status():
501
+ """Check and display current model status"""
502
+ global _models_cache
503
+ if _models_cache['loaded']:
504
+ device = _models_cache['device']
505
+ status_lines = [
506
+ "✓ MODELS LOADED AND READY",
507
+ "=" * 50,
508
+ f"📱 Device: {device}",
509
+ f"💾 CUDA Available: {torch.cuda.is_available()}",
510
+ ]
511
+
512
+ if device.type == 'cuda':
513
+ status_lines.append(f"🎮 GPU: {torch.cuda.get_device_name(device.index if device.index is not None else 0)}")
514
+ status_lines.append(f"💾 GPU Memory: {torch.cuda.get_device_properties(device.index if device.index is not None else 0).total_memory / 1e9:.2f} GB")
515
+
516
+ status_lines.extend([
517
+ "",
518
+ "📦 Loaded Models:",
519
+ " ✓ Autoencoder (AE_2D_Causal)",
520
+ " ✓ ACMDM Diffusion Model",
521
+ ])
522
+
523
+ if _models_cache['length_estimator'] is not None:
524
+ status_lines.append(" ✓ Length Estimator")
525
+ else:
526
+ status_lines.append(" ⚠ Length Estimator (not loaded)")
527
+
528
+ status_lines.extend([
529
+ "",
530
+ "📊 Statistics Loaded:",
531
+ " ✓ Post-AE Mean/Std",
532
+ " ✓ Joint Mean/Std",
533
+ " ✓ Eval Mean/Std",
534
+ "",
535
+ "✨ Models are cached and ready for generation!",
536
+ "💡 Tip: Use 'Clear Cache' to force reload on next generation."
537
+ ])
538
+
539
+ return "\n".join(status_lines)
540
+ else:
541
+ return (
542
+ "⏳ Models not loaded yet. They will be loaded automatically on first generation.\n"
543
+ "📝 To load models now, go to 'Single Motion Generation' tab and click 'Generate Motion'.\n"
544
+ "⏱️ First generation will take 30-60 seconds (model loading time).\n"
545
+ "⚡ Subsequent generations will be much faster (5-15 seconds)."
546
+ )
547
+
548
+ def reload_models():
549
+ """Clear model cache"""
550
+ global _models_cache
551
+ _models_cache['loaded'] = False
552
+ _models_cache['ae'] = None
553
+ _models_cache['acmdm'] = None
554
+ _models_cache['length_estimator'] = None
555
+ _models_cache['stats'] = None
556
+ _models_cache['device'] = None
557
+ return (
558
+ "🗑️ Model cache cleared. Models will be automatically reloaded on your next generation request.\n"
559
+ "💡 Click 'Refresh Status' after generating to see updated status."
560
+ )
561
+
562
+ # Button callbacks
563
+ refresh_status_btn.click(
564
+ fn=check_model_status,
565
+ outputs=model_status
566
+ )
567
+
568
+ reload_models_btn.click(
569
+ fn=reload_models,
570
+ outputs=model_status
571
+ )
572
+
573
+ # Update status on tab load
574
+ app.load(
575
+ fn=check_model_status,
576
+ outputs=model_status
577
+ )
578
+
579
+ gr.Markdown("""
580
+ ---
581
+ **Note:** First generation may take longer as models need to be loaded. Subsequent generations will be faster.
582
+
583
+ **Tips:**
584
+ - Use descriptive text prompts for better results
585
+ - Adjust CFG scale: higher values (3-5) for more text alignment, lower values (1-2) for more diversity
586
+ - Motion length should be a multiple of 4 (automatically rounded)
587
+ - For batch processing, use the format: `text description#length` or `text description#NA`
588
+ """)
589
+
590
+ return app
591
+
592
+
593
+ if __name__ == "__main__":
594
+ # Create and launch the interface
595
+ app = create_interface()
596
+
597
+ # Launch with sharing option
598
+ app.launch(
599
+ server_name="0.0.0.0", # Allow external connections
600
+ server_port=7860, # Default Gradio port
601
+ share=False, # Set to True for public link (requires ngrok)
602
+ show_error=True
603
+ )
604
+
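For local debugging, the same pipeline can be driven without the Gradio UI by reusing the helpers defined in `app.py` above. The sketch below assumes the script sits next to `app.py` with the `./checkpoints` tree from this commit in place; it is an illustration, not part of the committed code.

```python
# Minimal headless sketch reusing app.py's helpers (illustrative only).
from app import load_models_cached, generate_motion_single

models, status = load_models_cached(gpu_id=0, use_length_estimator=False)
print(status)

video_path, message = generate_motion_single(
    text="A person is running on a treadmill.",
    motion_length=120,       # frames; rounded to a multiple of 4 internally
    cfg_scale=3.0,
    use_auto_length=False,
    gpu_id=0,
    seed=3407,
)
print(message)
print("Generated video:", video_path)
```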
checkpoints/t2m/ACMDM_Flow_S_PatchSize22/model/latest.tar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0b24050576ffaa92657e2cc146d4b083579a248be6757bd5f7eeaf98a1c9dd3e
+ size 312829822
checkpoints/t2m/AE_2D_Causal/AE_2D_Causal_Post_Mean.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:49354ec5e7cfc75121d3d4ddf323364debd86e8815e9a153176c9d538b1fbf17
+ size 144
checkpoints/t2m/AE_2D_Causal/AE_2D_Causal_Post_Std.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b2177734e42141c5c27f6b262b0e379df12bd664ec4a3986fd35d1933ccce4b8
+ size 144
checkpoints/t2m/AE_2D_Causal/model/latest.tar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7c9601098c69561fcc12a7e91762c3caa03e5ab03e2c0488f44b2c4edbd08ee3
+ size 204997782
checkpoints/t2m/length_estimator/model/finest.tar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:90512e0e893186e3b09b24b006bfbf51f1ad71ac4c6626c9b1309373db675d12
+ size 1745581
requirements.txt ADDED
@@ -0,0 +1,30 @@
+ # Requirements for HuggingFace Spaces Deployment
+ # This file is used by HuggingFace Spaces to install dependencies
+
+ # Core ML libraries
+ torch>=2.0.0
+ torchvision>=0.15.0
+ numpy>=1.21.0
+ scipy>=1.7.0
+
+ # Gradio for web interface
+ gradio>=4.0.0
+
+ # CLIP for text encoding
+ git+https://github.com/openai/CLIP.git
+
+ # Additional dependencies
+ einops>=0.6.0
+ timm>=0.6.0
+ tqdm>=4.64.0
+ matplotlib>=3.5.0
+ opencv-python>=4.6.0
+ Pillow>=9.0.0
+
+ # For model loading and processing
+ h5py>=3.7.0
+ scikit-learn>=1.0.0
+
+ # Utilities
+ gdown>=4.6.0
utils/22x3_mean_std/t2m/22x3_mean.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ebfe87efeb6828b33c69f9df88d7a0373b9af7ac546496a4c7701112a748f12f
+ size 140
utils/22x3_mean_std/t2m/22x3_std.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4733059903169bc95993999fde322d8e84fb95f3e6ff1b277c283a0eb877d0c5
+ size 140
utils/__pycache__/back_process.cpython-310.pyc ADDED
Binary file (5.56 kB)
utils/__pycache__/eval_utils.cpython-310.pyc ADDED
Binary file (21.2 kB)
utils/__pycache__/evaluators.cpython-310.pyc ADDED
Binary file (12.5 kB)
utils/__pycache__/glove.cpython-310.pyc ADDED
Binary file (2.62 kB)
utils/__pycache__/motion_process.cpython-310.pyc ADDED
Binary file (11.2 kB)
utils/__pycache__/motion_process.cpython-313.pyc ADDED
Binary file (16.6 kB)
utils/__pycache__/quaternion.cpython-310.pyc ADDED
Binary file (13.2 kB)
utils/__pycache__/skeleton.cpython-310.pyc ADDED
Binary file (6.09 kB)
utils/__pycache__/train_utils.cpython-310.pyc ADDED
Binary file (3.73 kB)
utils/back_process.py ADDED
@@ -0,0 +1,255 @@
1
+ from utils.skeleton import Skeleton
2
+ from utils.quaternion import *
3
+ from utils.motion_process import t2m_kinematic_chain, t2m_raw_offsets
4
+
5
+ import torch
6
+ import os
7
+
8
+ n_raw_offsets = torch.from_numpy(t2m_raw_offsets)
9
+ kinematic_chain = t2m_kinematic_chain
10
+ l_idx1, l_idx2 = 5, 8
11
+ face_joint_indx = [2, 1, 17, 16]
12
+
13
+ # Lazy loading of tgt_offsets - only needed when is_mesh=True
14
+ _tgt_offsets_cache = None
15
+
16
+ def _get_tgt_offsets():
17
+ """Lazily load target offsets for mesh processing. Only called when is_mesh=True."""
18
+ global _tgt_offsets_cache
19
+ if _tgt_offsets_cache is None:
20
+ example_data_path = os.path.join("datasets/HumanML3D/new_joints", "000021" + '.npy')
21
+ if not os.path.exists(example_data_path):
22
+ raise FileNotFoundError(
23
+ f"Example data file not found: {example_data_path}\n"
24
+ "This file is only needed for mesh-level motion generation (is_mesh=True).\n"
25
+ "For regular text-to-motion generation, use is_mesh=False."
26
+ )
27
+ example_data = np.load(example_data_path)
28
+ example_data = example_data.reshape(len(example_data), -1, 3)
29
+ example_data = torch.from_numpy(example_data)
30
+ tgt_skel = Skeleton(n_raw_offsets, kinematic_chain, 'cpu')
31
+ _tgt_offsets_cache = tgt_skel.get_offsets_joints(example_data[0])
32
+ return _tgt_offsets_cache
33
+
34
+ l_idx1, l_idx2 = 5, 8
35
+ fid_r, fid_l = [8, 11], [7, 10]
36
+ r_hip, l_hip = 2, 1
37
+ joints_num = 22
38
+
39
+ def uniform_skeleton(positions, target_offset):
40
+ src_skel = Skeleton(n_raw_offsets, kinematic_chain, 'cpu')
41
+ src_offset = src_skel.get_offsets_joints(torch.from_numpy(positions[0]))
42
+ src_offset = src_offset.numpy()
43
+ tgt_offset = target_offset.numpy()
44
+ # print(src_offset)
45
+ # print(tgt_offset)
46
+ '''Calculate Scale Ratio as the ratio of legs'''
47
+ src_leg_len = np.abs(src_offset[l_idx1]).max() + np.abs(src_offset[l_idx2]).max()
48
+ tgt_leg_len = np.abs(tgt_offset[l_idx1]).max() + np.abs(tgt_offset[l_idx2]).max()
49
+
50
+ scale_rt = tgt_leg_len / src_leg_len
51
+ # print(scale_rt)
52
+ src_root_pos = positions[:, 0]
53
+ tgt_root_pos = src_root_pos * scale_rt
54
+
55
+ '''Inverse Kinematics'''
56
+ quat_params = src_skel.inverse_kinematics_np(positions, face_joint_indx)
57
+ # print(quat_params.shape)
58
+
59
+ '''Forward Kinematics'''
60
+ src_skel.set_offset(target_offset)
61
+ new_joints = src_skel.forward_kinematics_np(quat_params, tgt_root_pos)
62
+ return new_joints
63
+
64
+
65
+ def process_file(positions, feet_thre, is_mesh=False):
66
+ # (seq_len, joints_num, 3)
67
+ # '''Down Sample'''
68
+ # positions = positions[::ds_num]
69
+
70
+ if is_mesh:
71
+ '''Uniform Skeleton'''
72
+ tgt_offsets = _get_tgt_offsets() # Lazy load only when needed
73
+ positions = uniform_skeleton(positions, tgt_offsets)
74
+
75
+ '''Put on Floor'''
76
+ floor_height = positions.min(axis=0).min(axis=0)[1]
77
+ positions[:, :, 1] -= floor_height
78
+ # print(floor_height)
79
+
80
+ # plot_3d_motion("./positions_1.mp4", kinematic_chain, positions, 'title', fps=20)
81
+
82
+ '''XZ at origin'''
83
+ root_pos_init = positions[0]
84
+ root_pose_init_xz = root_pos_init[0] * np.array([1, 0, 1])
85
+ positions = positions - root_pose_init_xz
86
+
87
+ # '''Move the first pose to origin '''
88
+ # root_pos_init = positions[0]
89
+ # positions = positions - root_pos_init[0]
90
+
91
+ '''All initially face Z+'''
92
+ r_hip, l_hip, sdr_r, sdr_l = face_joint_indx
93
+ across1 = root_pos_init[r_hip] - root_pos_init[l_hip]
94
+ across2 = root_pos_init[sdr_r] - root_pos_init[sdr_l]
95
+ across = across1 + across2
96
+ across = across / np.sqrt((across ** 2).sum(axis=-1))[..., np.newaxis]
97
+
98
+ # forward (3,), rotate around y-axis
99
+ forward_init = np.cross(np.array([[0, 1, 0]]), across, axis=-1)
100
+ # forward (3,)
101
+ forward_init = forward_init / np.sqrt((forward_init ** 2).sum(axis=-1))[..., np.newaxis]
102
+
103
+ # print(forward_init)
104
+
105
+ target = np.array([[0, 0, 1]])
106
+ root_quat_init = qbetween_np(forward_init, target)
107
+ root_quat_init = np.ones(positions.shape[:-1] + (4,)) * root_quat_init
108
+
109
+ positions_b = positions.copy()
110
+
111
+ positions = qrot_np(root_quat_init, positions)
112
+
113
+ # plot_3d_motion("./positions_2.mp4", kinematic_chain, positions, 'title', fps=20)
114
+
115
+ '''New ground truth positions'''
116
+ global_positions = positions.copy()
117
+
118
+ # plt.plot(positions_b[:, 0, 0], positions_b[:, 0, 2], marker='*')
119
+ # plt.plot(positions[:, 0, 0], positions[:, 0, 2], marker='o', color='r')
120
+ # plt.xlabel('x')
121
+ # plt.ylabel('z')
122
+ # plt.axis('equal')
123
+ # plt.show()
124
+
125
+ """ Get Foot Contacts """
126
+
127
+ def foot_detect(positions, thres):
128
+ velfactor, heightfactor = np.array([thres, thres]), np.array([3.0, 2.0])
129
+
130
+ feet_l_x = (positions[1:, fid_l, 0] - positions[:-1, fid_l, 0]) ** 2
131
+ feet_l_y = (positions[1:, fid_l, 1] - positions[:-1, fid_l, 1]) ** 2
132
+ feet_l_z = (positions[1:, fid_l, 2] - positions[:-1, fid_l, 2]) ** 2
133
+ # feet_l_h = positions[:-1,fid_l,1]
134
+ # feet_l = (((feet_l_x + feet_l_y + feet_l_z) < velfactor) & (feet_l_h < heightfactor)).astype(np.float)
135
+ feet_l = ((feet_l_x + feet_l_y + feet_l_z) < velfactor).astype(np.float32)
136
+
137
+ feet_r_x = (positions[1:, fid_r, 0] - positions[:-1, fid_r, 0]) ** 2
138
+ feet_r_y = (positions[1:, fid_r, 1] - positions[:-1, fid_r, 1]) ** 2
139
+ feet_r_z = (positions[1:, fid_r, 2] - positions[:-1, fid_r, 2]) ** 2
140
+ # feet_r_h = positions[:-1,fid_r,1]
141
+ # feet_r = (((feet_r_x + feet_r_y + feet_r_z) < velfactor) & (feet_r_h < heightfactor)).astype(np.float)
142
+ feet_r = (((feet_r_x + feet_r_y + feet_r_z) < velfactor)).astype(np.float32)
143
+ return feet_l, feet_r
144
+
145
+ #
146
+ feet_l, feet_r = foot_detect(positions, feet_thre)
147
+ # feet_l, feet_r = foot_detect(positions, 0.002)
148
+
149
+ '''Quaternion and Cartesian representation'''
150
+ r_rot = None
151
+
152
+ def get_rifke(positions):
153
+ '''Local pose'''
154
+ positions[..., 0] -= positions[:, 0:1, 0]
155
+ positions[..., 2] -= positions[:, 0:1, 2]
156
+ '''All pose face Z+'''
157
+ positions = qrot_np(np.repeat(r_rot[:, None], positions.shape[1], axis=1), positions)
158
+ return positions
159
+
160
+ def get_quaternion(positions):
161
+ skel = Skeleton(n_raw_offsets, kinematic_chain, "cpu")
162
+ # (seq_len, joints_num, 4)
163
+ quat_params = skel.inverse_kinematics_np(positions, face_joint_indx, smooth_forward=False)
164
+
165
+ '''Fix Quaternion Discontinuity'''
166
+ quat_params = qfix(quat_params)
167
+ # (seq_len, 4)
168
+ r_rot = quat_params[:, 0].copy()
169
+ # print(r_rot[0])
170
+ '''Root Linear Velocity'''
171
+ # (seq_len - 1, 3)
172
+ velocity = (positions[1:, 0] - positions[:-1, 0]).copy()
173
+ # print(r_rot.shape, velocity.shape)
174
+ velocity = qrot_np(r_rot[1:], velocity)
175
+ '''Root Angular Velocity'''
176
+ # (seq_len - 1, 4)
177
+ r_velocity = qmul_np(r_rot[1:], qinv_np(r_rot[:-1]))
178
+ quat_params[1:, 0] = r_velocity
179
+ # (seq_len, joints_num, 4)
180
+ return quat_params, r_velocity, velocity, r_rot
181
+
182
+ def get_cont6d_params(positions):
183
+ skel = Skeleton(n_raw_offsets, kinematic_chain, "cpu")
184
+ # (seq_len, joints_num, 4)
185
+ quat_params = skel.inverse_kinematics_np(positions, face_joint_indx, smooth_forward=True)
186
+
187
+ '''Quaternion to continuous 6D'''
188
+ cont_6d_params = quaternion_to_cont6d_np(quat_params)
189
+ # (seq_len, 4)
190
+ r_rot = quat_params[:, 0].copy()
191
+ # print(r_rot[0])
192
+ '''Root Linear Velocity'''
193
+ # (seq_len - 1, 3)
194
+ velocity = (positions[1:, 0] - positions[:-1, 0]).copy()
195
+ # print(r_rot.shape, velocity.shape)
196
+ velocity = qrot_np(r_rot[1:], velocity)
197
+ '''Root Angular Velocity'''
198
+ # (seq_len - 1, 4)
199
+ r_velocity = qmul_np(r_rot[1:], qinv_np(r_rot[:-1]))
200
+ # (seq_len, joints_num, 4)
201
+ return cont_6d_params, r_velocity, velocity, r_rot
202
+
203
+ cont_6d_params, r_velocity, velocity, r_rot = get_cont6d_params(positions)
204
+ positions = get_rifke(positions)
205
+
206
+ # trejec = np.cumsum(np.concatenate([np.array([[0, 0, 0]]), velocity], axis=0), axis=0)
207
+ # r_rotations, r_pos = recover_ric_glo_np(r_velocity, velocity[:, [0, 2]])
208
+
209
+ # plt.plot(positions_b[:, 0, 0], positions_b[:, 0, 2], marker='*')
210
+ # plt.plot(ground_positions[:, 0, 0], ground_positions[:, 0, 2], marker='o', color='r')
211
+ # plt.plot(trejec[:, 0], trejec[:, 2], marker='^', color='g')
212
+ # plt.plot(r_pos[:, 0], r_pos[:, 2], marker='s', color='y')
213
+ # plt.xlabel('x')
214
+ # plt.ylabel('z')
215
+ # plt.axis('equal')
216
+ # plt.show()
217
+
218
+ '''Root height'''
219
+ root_y = positions[:, 0, 1:2]
220
+
221
+ '''Root rotation and linear velocity'''
222
+ # (seq_len-1, 1) rotation velocity along y-axis
223
+ # (seq_len-1, 2) linear velocity on xz plane
224
+ r_velocity = np.arcsin(r_velocity[:, 2:3])
225
+ l_velocity = velocity[:, [0, 2]]
226
+ # print(r_velocity.shape, l_velocity.shape, root_y.shape)
227
+ root_data = np.concatenate([r_velocity, l_velocity, root_y[:-1]], axis=-1)
228
+
229
+ '''Get Joint Rotation Representation'''
230
+ # (seq_len, (joints_num-1)*6) continuous 6D rotation for skeleton joints
231
+ rot_data = cont_6d_params[:, 1:].reshape(len(cont_6d_params), -1)
232
+
233
+ '''Get Joint Rotation Invariant Position Representation'''
234
+ # (seq_len, (joints_num-1)*3) local joint position
235
+ ric_data = positions[:, 1:].reshape(len(positions), -1)
236
+
237
+ '''Get Joint Velocity Representation'''
238
+ # (seq_len-1, joints_num*3)
239
+ local_vel = qrot_np(np.repeat(r_rot[:-1, None], global_positions.shape[1], axis=1),
240
+ global_positions[1:] - global_positions[:-1])
241
+ local_vel = local_vel.reshape(len(local_vel), -1)
242
+
243
+ data = root_data
244
+ data = np.concatenate([data, ric_data[:-1]], axis=-1)
245
+ data = np.concatenate([data, rot_data[:-1]], axis=-1)
246
+ # print(data.shape, local_vel.shape)
247
+ data = np.concatenate([data, local_vel], axis=-1)
248
+ data = np.concatenate([data, feet_l, feet_r], axis=-1)
249
+
250
+ return data, global_positions, positions, l_velocity
251
+
252
+
253
+ def back_process(data, is_mesh=False):
254
+ data, ground_positions, positions, l_velocity = process_file(data, 0.002, is_mesh=is_mesh)
255
+ return data[:, :67]
utils/cal_ae_post_mean_std.py ADDED
@@ -0,0 +1,64 @@
+ import numpy as np
+ import os
+ from os.path import join as pjoin
+ from tqdm import tqdm
+ import torch
+ import argparse
+ from models.AE_2D_Causal import AE_models
+
+ #################################################################################
+ #                        Calculate Post AE/VAE Mean Std                         #
+ #################################################################################
+
+ def mean_variance(data_dir, save_dir, ae, tp='AE'):
+     file_list = os.listdir(data_dir)
+     data_list = []
+     mean = np.load(f'utils/22x3_mean_std/t2m/22x3_mean.npy')
+     std = np.load(f'utils/22x3_mean_std/t2m/22x3_std.npy')
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     ae = ae.to(device)
+
+     for file in tqdm(file_list):
+         data = np.load(pjoin(data_dir, file))
+         if len(data.shape) == 2:
+             data = np.expand_dims(data, axis=0)
+         if np.isnan(data).any():
+             print(file)
+             continue
+         data = data[:(data.shape[0]//4)*4, :, :]
+         if data.shape[0] == 0:
+             continue
+         data = (data - mean) / std
+         data = torch.from_numpy(data).to(device)
+         with torch.no_grad():
+             data_list.append(ae.encode(data.unsqueeze(0)).squeeze(0).cpu().numpy())
+
+     data = np.concatenate(data_list, axis=1)
+     data = data.reshape(4, -1)
+     print('Data Range:')
+     print(data.min(), data.max())
+     Mean = data.mean(axis=1)
+     Std = data.std(axis=1)
+     print('Mean/Std:')
+     print(Mean, Std)
+
+     np.save(pjoin(save_dir, f'{tp}_2D_Causal_Post_Mean.npy'), Mean)
+     np.save(pjoin(save_dir, f'{tp}_2D_Causal_Post_Std.npy'), Std)
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     data_dir = 'datasets/HumanML3D/new_joints/'
+
+     parser.add_argument('--is_ae', action="store_true")
+     parser.add_argument('--ae_name', type=str, default="AE_2D_Causal")
+     args = parser.parse_args()
+
+     if args.is_ae:
+         ae = AE_models["AE_Model"](input_width=3)
+     else:
+         ae = AE_models["VAE_Model"](input_width=3)
+     ckpt = torch.load(f'checkpoints/t2m/{args.ae_name}/model/latest.tar', map_location='cpu')
+     ae.load_state_dict(ckpt['ae'])
+     ae = ae.eval()
+     save_dir = f'checkpoints/t2m/{args.ae_name}'
+     mean_variance(data_dir, save_dir, ae, tp='AE' if args.is_ae else 'VAE')
utils/cal_mean_std.py ADDED
@@ -0,0 +1,35 @@
+ import numpy as np
+ import os
+ from os.path import join as pjoin
+ from tqdm import tqdm
+
+ #################################################################################
+ #                     Calculate Absolute Coordinate Mean Std                    #
+ #################################################################################
+ def mean_variance(data_dir, save_dir):
+     file_list = os.listdir(data_dir)
+     data_list = []
+
+     for file in tqdm(file_list):
+         data = np.load(pjoin(data_dir, file))
+         if len(data.shape) == 2:
+             data = np.expand_dims(data, axis=0)
+         if np.isnan(data).any():
+             print(file)
+             continue
+         data_list.append(data.reshape(-1, 3))
+
+     data = np.concatenate(data_list, axis=0)
+     print(data.shape)
+     Mean = data.mean(axis=0)
+     Std = data.std(axis=0)
+
+     np.save(pjoin(save_dir, 'Mean_22x3.npy'), Mean)
+     np.save(pjoin(save_dir, 'Std_22x3.npy'), Std)
+
+     return Mean, Std
+
+ if __name__ == '__main__':
+     data_dir1 = 'datasets/HumanML3D/new_joints/'
+     save_dir1 = 'datasets/HumanML3D/'
+     mean, std = mean_variance(data_dir1, save_dir1)
utils/cal_mesh_ae_post_mean_std.py ADDED
@@ -0,0 +1,107 @@
1
+ # Very hard-coded; will be cleaned up and refined later
2
+ import numpy as np
3
+ import os
4
+ from os.path import join as pjoin
5
+ from tqdm import tqdm
6
+ import torch
7
+ from models.AE_Mesh import AE_models
8
+ from human_body_prior.body_model.body_model import BodyModel
9
+
10
+ def downsample(data_dir, save_dir):
11
+ ae = AE_models["AE_Model"](test_mode=True)
12
+ ckpt = torch.load('checkpoints/t2m/AE_Mesh/model/latest.tar', map_location='cpu')
13
+ model_key = 'ae'
14
+ ae.load_state_dict(ckpt[model_key])
15
+ ae = ae.eval()
16
+ ae = ae.cuda()
17
+
18
+ mean = np.load(f'utils/mesh_mean_std/t2m/mesh_mean.npy') # or yours
19
+ std = np.load(f'utils/mesh_mean_std/t2m/mesh_std.npy') # or yours
20
+
21
+ bm_path = './body_models/smplh/neutral/model.npz'
22
+ dmpl_path = './body_models/dmpls/neutral/model.npz'
23
+ num_betas = 10
24
+ num_dmpls = 8
25
+ bm = BodyModel(bm_fname=bm_path, num_betas=num_betas, num_dmpls=num_dmpls, dmpl_fname=dmpl_path).cuda()
26
+ comp_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
+
28
+ file_list = os.listdir(data_dir)
29
+ file_list.sort()
30
+ for file in tqdm(file_list):
31
+ data = np.load(pjoin(data_dir, file))
32
+
33
+ if np.isnan(data).any():
34
+ print(f"Skipping {file} due to NaN")
35
+ continue
36
+
37
+ body_parms = {
38
+ 'root_orient': torch.from_numpy(data[:, :3]).to(comp_device).float(),
39
+ 'pose_body': torch.from_numpy(data[:, 3:66]).to(comp_device).float(),
40
+ 'pose_hand': torch.from_numpy(data[:, 66:52*3]).to(comp_device).float(),
41
+ 'trans': torch.from_numpy(data[:, 52*3:53*3]).to(comp_device).float(),
42
+ 'betas': torch.from_numpy(data[:, 53*3:53*3+10]).to(comp_device).float(),
43
+ 'dmpls': torch.from_numpy(data[:, 53*3+10:]).to(comp_device).float()
44
+ }
45
+
46
+ with torch.no_grad():
47
+ verts = bm(**body_parms).v
48
+ verts[:, :, 1] -= verts[:, :, 1].min()
49
+ verts = verts.detach().cpu().numpy()
50
+ "Z Normalization"
51
+ vertss = (verts - mean) / std
52
+ T = vertss.shape[0]
53
+ data = torch.from_numpy(vertss).float().to(comp_device)
54
+ if T % 16 != 0:
55
+ pad_len = 16 - (T % 16)
56
+ pad_data = torch.zeros((pad_len, 6890, 3), dtype=data.dtype, device=comp_device)
57
+ data = torch.cat([data, pad_data], dim=0)
58
+
59
+ outputs = []
60
+ with torch.no_grad():
61
+ for i in range(0, data.shape[0], 16):
62
+ chunk = data[i:i+16].unsqueeze(0)
63
+ out = ae.encode(chunk).squeeze(0).cpu().numpy()
64
+ outputs.append(out)
65
+ downsampled = np.concatenate(outputs, axis=0)[:T]
66
+
67
+ np.save(pjoin(save_dir, f"{file}"), downsampled)
68
+
69
+ #################################################################################
70
+ # Calculate Mean Std #
71
+ #################################################################################
72
+ def mean_variance(data_dir, save_dir):
73
+ file_list = os.listdir(data_dir)
74
+ data_list = []
75
+
76
+ for file in tqdm(file_list):
77
+ data = np.load(pjoin(data_dir, file))
78
+ if np.isnan(data).any():
79
+ print(file)
80
+ continue
81
+ data_list.append(data.reshape(-1, 12))
82
+
83
+ data = np.concatenate(data_list, axis=0)
84
+ print(data.shape)
85
+ Mean = data.mean(axis=0)
86
+ Std = data.std(axis=0)
87
+
88
+ np.save(pjoin(save_dir, 'AE_Mesh_Post_Mean.npy'), Mean)
89
+ np.save(pjoin(save_dir, 'AE_Mesh_Post_Std.npy'), Std)
90
+ print(data.min(),data.max())
91
+ Mean2 = data.mean()
92
+ Std2 = data.std()
93
+ print(Mean, Std)
94
+ print(Mean2, Std2)
95
+
96
+ return Mean, Std
97
+
98
+ if __name__ == '__main__':
99
+ data_dir = 'datasets/HumanML3D/meshes/'
100
+ save_dir = 'datasets/HumanML3D/meshes_after_ae/'
101
+
102
+ os.makedirs(save_dir, exist_ok=True)
103
+ downsample(data_dir, save_dir)
104
+
105
+ data_dir1 = 'datasets/HumanML3D/meshes_after_ae/'
106
+ save_dir1 = 'checkpoints/t2m/AE_Mesh/'
107
+ mean, std = mean_variance(data_dir1, save_dir1)
utils/datasets.py ADDED
@@ -0,0 +1,452 @@
1
+ from os.path import join as pjoin
2
+ from torch.utils import data
3
+ import numpy as np
4
+ import torch
5
+ from tqdm import tqdm
6
+ from torch.utils.data._utils.collate import default_collate
7
+ import random
8
+ import codecs as cs
9
+ from utils.glove import GloVe
10
+ from human_body_prior.body_model.body_model import BodyModel
11
+
12
+ #################################################################################
13
+ # Collate Function #
14
+ #################################################################################
15
+ def collate_fn(batch):
16
+ batch.sort(key=lambda x: x[3], reverse=True)
17
+ return default_collate(batch)
18
+
19
+ #################################################################################
20
+ # Datasets #
21
+ #################################################################################
22
+ class AEDataset(data.Dataset):
23
+ def __init__(self, mean, std, motion_dir, window_size, split_file):
24
+ self.data = []
25
+ self.lengths = []
26
+ id_list = []
27
+ with open(split_file, 'r') as f:
28
+ for line in f.readlines():
29
+ id_list.append(line.strip())
30
+
31
+ for name in tqdm(id_list):
32
+ try:
33
+ motion = np.load(pjoin(motion_dir, name + '.npy'))
34
+ if len(motion.shape) == 2: # B,L,J,3
35
+ motion = np.expand_dims(motion, axis=0)
36
+ if motion.shape[0] < window_size:
37
+ continue
38
+ self.lengths.append(motion.shape[0] - window_size)
39
+ self.data.append(motion)
40
+ except Exception as e:
41
+ pass
42
+ self.cumsum = np.cumsum([0] + self.lengths)
43
+ self.window_size = window_size
44
+
45
+ self.mean = mean
46
+ self.std = std
47
+ print("Total number of motions {}, snippets {}".format(len(self.data), self.cumsum[-1]))
48
+
49
+ def __len__(self):
50
+ return self.cumsum[-1]
51
+
52
+ def __getitem__(self, item):
53
+ if item != 0:
54
+ motion_id = np.searchsorted(self.cumsum, item) - 1
55
+ idx = item - self.cumsum[motion_id] - 1
56
+ else:
57
+ motion_id = 0
58
+ idx = 0
59
+ motion = self.data[motion_id][idx:idx + self.window_size]
60
+ "Z Normalization"
61
+ motion = (motion - self.mean) / self.std
62
+
63
+ return motion
64
+
65
+
66
+ class AEMeshDataset(data.Dataset):
67
+ def __init__(self, mean, std, motion_dir, window_size, split_file):
68
+ self.data = []
69
+ self.lengths = []
70
+ id_list = []
71
+ with open(split_file, 'r') as f:
72
+ for line in f.readlines():
73
+ id_list.append(line.strip())
74
+
75
+ for name in id_list:
76
+ try:
77
+ motion = np.load(pjoin(motion_dir, name + '.npy'))
78
+ if motion.shape[0] < window_size:
79
+ continue
80
+ self.lengths.append(motion.shape[0] - motion.shape[0]+1)
81
+ self.data.append(motion)
82
+ except Exception as e:
83
+ pass
84
+ self.cumsum = np.cumsum([0] + self.lengths)
85
+ self.window_size = window_size
86
+
87
+ self.mean = mean
88
+ self.std = std
89
+ num_betas = 10
90
+ num_dmpls = 8
91
+ self.bm = BodyModel(bm_fname='./body_models/smplh/neutral/model.npz', num_betas=num_betas, num_dmpls=num_dmpls, dmpl_fname='./body_models/dmpls/neutral/model.npz')
92
+ print("Total number of motions {}, snippets {}".format(len(self.data), self.cumsum[-1]))
93
+
94
+ def __len__(self):
95
+ return self.cumsum[-1]
96
+
97
+ def __getitem__(self, item):
98
+ motion = self.data[item]
99
+ body_parms = {
100
+ 'root_orient': torch.from_numpy(motion[:, :3]).float(),
101
+ 'pose_body': torch.from_numpy(motion[:, 3:66]).float(),
102
+ 'pose_hand': torch.from_numpy(motion[:, 66:52*3]).float(),
103
+ 'trans': torch.from_numpy(motion[:, 52*3:53*3]).float(),
104
+ 'betas': torch.from_numpy(motion[:, 53*3:53*3+10]).float(),
105
+ 'dmpls': torch.from_numpy(motion[:, 53*3+10:]).float()
106
+ }
107
+ body_parms['betas']= torch.zeros_like(torch.from_numpy(motion[:, 53*3:53*3+10]).float())
108
+ with torch.no_grad():
109
+ verts = self.bm(**body_parms).v
110
+ verts[:, :, 1] -= verts[:, :, 1].min()
111
+ idx = random.randint(0, verts.shape[0] - 1)
112
+ verts = verts[idx:idx + self.window_size].detach().cpu().numpy()
113
+ "Z Normalization"
114
+ verts = (verts - self.mean) / self.std
115
+
116
+ return verts
117
+
118
+
119
+
120
+ class Text2MotionDataset(data.Dataset):
121
+ def __init__(self, mean, std, split_file, dataset_name, motion_dir, text_dir, unit_length, max_motion_length,
122
+ max_text_length, evaluation=False, is_mesh=False):
123
+ self.evaluation = evaluation
124
+ self.max_length = 20
125
+ self.pointer = 0
126
+ self.max_motion_length = max_motion_length
127
+ self.max_text_len = max_text_length
128
+ self.unit_length = unit_length
129
+ min_motion_len = 40 if dataset_name =='t2m' else 24
130
+
131
+ data_dict = {}
132
+ id_list = []
133
+ with cs.open(split_file, 'r') as f:
134
+ for line in f.readlines():
135
+ id_list.append(line.strip())
136
+
137
+ new_name_list = []
138
+ length_list = []
139
+ for name in tqdm(id_list):
140
+ try:
141
+ motion = np.load(pjoin(motion_dir, name + '.npy'))
142
+ if len(motion.shape) == 2:
143
+ motion = np.expand_dims(motion, axis=0)
144
+ if is_mesh:
145
+ if (len(motion)) < min_motion_len:
146
+ continue
147
+ else:
148
+ if (len(motion)) < min_motion_len or (len(motion) >= 200):
149
+ continue
150
+ text_data = []
151
+ flag = False
152
+ with cs.open(pjoin(text_dir, name + '.txt')) as f:
153
+ for line in f.readlines():
154
+ text_dict = {}
155
+ line_split = line.strip().split('#')
156
+ caption = line_split[0]
157
+ tokens = line_split[1].split(' ')
158
+ f_tag = float(line_split[2])
159
+ to_tag = float(line_split[3])
160
+ f_tag = 0.0 if np.isnan(f_tag) else f_tag
161
+ to_tag = 0.0 if np.isnan(to_tag) else to_tag
162
+
163
+ text_dict['caption'] = caption
164
+ text_dict['tokens'] = tokens
165
+ if f_tag == 0.0 and to_tag == 0.0:
166
+ flag = True
167
+ text_data.append(text_dict)
168
+ else:
169
+ try:
170
+ n_motion = motion[int(f_tag*20) : int(to_tag*20)]
171
+ if (len(n_motion)) < min_motion_len or (len(n_motion) >= 200):
172
+ continue
173
+ new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
174
+ while new_name in data_dict:
175
+ new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
176
+ data_dict[new_name] = {'motion': n_motion,
177
+ 'length': len(n_motion),
178
+ 'text':[text_dict]}
179
+ new_name_list.append(new_name)
180
+ length_list.append(len(n_motion))
181
+ except:
182
+ print(line_split)
183
+ print(line_split[2], line_split[3], f_tag, to_tag, name)
184
+
185
+ if flag:
186
+ data_dict[name] = {'motion': motion,
187
+ 'length': len(motion),
188
+ 'text': text_data}
189
+ new_name_list.append(name)
190
+ length_list.append(len(motion))
191
+ except:
192
+ pass
193
+ if self.evaluation:
194
+ self.w_vectorizer = GloVe('./glove', 'our_vab')
195
+ name_list, length_list = zip(*sorted(zip(new_name_list, length_list), key=lambda x: x[1]))
196
+ else:
197
+ name_list, length_list = new_name_list, length_list
198
+ self.mean = mean
199
+ self.std = std
200
+ self.length_arr = np.array(length_list)
201
+ self.data_dict = data_dict
202
+ self.name_list = name_list
203
+ if self.evaluation:
204
+ self.reset_max_len(self.max_length)
205
+
206
+ def reset_max_len(self, length):
207
+ assert length <= self.max_motion_length
208
+ self.pointer = np.searchsorted(self.length_arr, length)
209
+ print("Pointer Pointing at %d"%self.pointer)
210
+ self.max_length = length
211
+
212
+ def transform(self, data, mean=None, std=None):
213
+ if mean is None and std is None:
214
+ return (data - self.mean) / self.std
215
+ else:
216
+ return (data - mean) / std
217
+
218
+ def inv_transform(self, data, mean=None, std=None):
219
+ if mean is None and std is None:
220
+ return data * self.std + self.mean
221
+ else:
222
+ return data * std + mean
223
+
224
+ def __len__(self):
225
+ return len(self.data_dict) - self.pointer
226
+
227
+ def __getitem__(self, item):
228
+ idx = self.pointer + item
229
+ data = self.data_dict[self.name_list[idx]]
230
+ motion, m_length, text_list = data['motion'], data['length'], data['text']
231
+ # Randomly select a caption
232
+ text_data = random.choice(text_list)
233
+ caption, tokens = text_data['caption'], text_data['tokens']
234
+
235
+ if self.evaluation:
236
+ if len(tokens) < self.max_text_len:
237
+ # pad with "unk"
238
+ tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
239
+ sent_len = len(tokens)
240
+ tokens = tokens + ['unk/OTHER'] * (self.max_text_len + 2 - sent_len)
241
+ else:
242
+ # crop
243
+ tokens = tokens[:self.max_text_len]
244
+ tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
245
+ sent_len = len(tokens)
246
+ pos_one_hots = []
247
+ word_embeddings = []
248
+ for token in tokens:
249
+ word_emb, pos_oh = self.w_vectorizer[token]
250
+ pos_one_hots.append(pos_oh[None, :])
251
+ word_embeddings.append(word_emb[None, :])
252
+ pos_one_hots = np.concatenate(pos_one_hots, axis=0)
253
+ word_embeddings = np.concatenate(word_embeddings, axis=0)
254
+
255
+ if self.unit_length < 10:
256
+ coin2 = np.random.choice(['single', 'single', 'double'])
257
+ else:
258
+ coin2 = 'single'
259
+
260
+ if coin2 == 'double':
261
+ m_length = (m_length // self.unit_length - 1) * self.unit_length
262
+ elif coin2 == 'single':
263
+ m_length = (m_length // self.unit_length) * self.unit_length
264
+ idx = random.randint(0, len(motion) - m_length)
265
+ motion = motion[idx:idx+m_length]
266
+
267
+ "Z Normalization"
268
+ motion = (motion - self.mean) / self.std
269
+
270
+ if m_length < self.max_motion_length:
271
+ motion = np.concatenate([motion,
272
+ np.zeros((self.max_motion_length - m_length, motion.shape[1], motion.shape[2]))
273
+ ], axis=0)
274
+ elif m_length > self.max_motion_length:
275
+ idx = random.randint(0, m_length - self.max_motion_length)
276
+ motion = motion[idx:idx + self.max_motion_length]
277
+ if self.evaluation:
278
+ return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, '_'.join(tokens)
279
+ else:
280
+ return caption, motion, m_length
281
+
282
+
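+ # Usage sketch (illustrative; the paths and hyper-parameters below are hypothetical placeholders):
+ # mean = np.load('utils/22x3_mean_std/t2m/22x3_mean.npy')
+ # std = np.load('utils/22x3_mean_std/t2m/22x3_std.npy')
+ # dataset = Text2MotionDataset(mean, std, 'data/t2m/train.txt', 't2m',
+ #                              'data/t2m/new_joints', 'data/t2m/texts',
+ #                              unit_length=4, max_motion_length=196, max_text_length=20)
+ # caption, motion, m_length = dataset[0]  # training mode returns (text, padded motion, true length)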
283
+ class Text2MotionDataset_Another_V(data.Dataset):
284
+ def __init__(self, mean, std, split_file, dataset_name, motion_dir, text_dir, unit_length, max_motion_length,
285
+ max_text_length, evaluation=False, is_mesh=False):
286
+ self.evaluation = evaluation
287
+ self.max_length = 20
288
+ self.pointer = 0
289
+ self.max_motion_length = max_motion_length
290
+ self.max_text_len = max_text_length
291
+ self.unit_length = unit_length
292
+ min_motion_len = 40 if dataset_name =='t2m' else 24
293
+
294
+ data_dict = {}
295
+ id_list = []
296
+ with cs.open(split_file, 'r') as f:
297
+ for line in f.readlines():
298
+ id_list.append(line.strip())
299
+
300
+ new_name_list = []
301
+ length_list = []
302
+ for name in tqdm(id_list):
303
+ try:
304
+ motion = np.load(pjoin(motion_dir, name + '.npy'))
305
+ if not self.evaluation:
306
+ if len(motion.shape) == 2:
307
+ motion = np.expand_dims(motion, axis=0)
308
+ if is_mesh:
309
+ if (len(motion)) < min_motion_len:
310
+ continue
311
+ else:
312
+ if (len(motion)) < min_motion_len or (len(motion) >= 200):
313
+ continue
314
+ text_data = []
315
+ flag = False
316
+ with cs.open(pjoin(text_dir, name + '.txt')) as f:
317
+ for line in f.readlines():
318
+ text_dict = {}
319
+ line_split = line.strip().split('#')
320
+ caption = line_split[0]
321
+ tokens = line_split[1].split(' ')
322
+ f_tag = float(line_split[2])
323
+ to_tag = float(line_split[3])
324
+ f_tag = 0.0 if np.isnan(f_tag) else f_tag
325
+ to_tag = 0.0 if np.isnan(to_tag) else to_tag
326
+
327
+ text_dict['caption'] = caption
328
+ text_dict['tokens'] = tokens
329
+ if f_tag == 0.0 and to_tag == 0.0:
330
+ flag = True
331
+ text_data.append(text_dict)
332
+ else:
333
+ try:
334
+ n_motion = motion[int(f_tag*20) : int(to_tag*20)]
335
+ if (len(n_motion)) < min_motion_len or (len(n_motion) >= 200):
336
+ continue
337
+ new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
338
+ while new_name in data_dict:
339
+ new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
340
+ data_dict[new_name] = {'motion': n_motion,
341
+ 'length': len(n_motion),
342
+ 'text':[text_dict]}
343
+ new_name_list.append(new_name)
344
+ length_list.append(len(n_motion))
345
+ except:
346
+ print(line_split)
347
+ print(line_split[2], line_split[3], f_tag, to_tag, name)
348
+
349
+ if flag:
350
+ data_dict[name] = {'motion': motion,
351
+ 'length': len(motion),
352
+ 'text': text_data}
353
+ new_name_list.append(name)
354
+ length_list.append(len(motion))
355
+ except:
356
+ pass
357
+ if self.evaluation:
358
+ self.w_vectorizer = GloVe('./glove', 'our_vab')
359
+ name_list, length_list = zip(*sorted(zip(new_name_list, length_list), key=lambda x: x[1]))
360
+ else:
361
+ name_list, length_list = new_name_list, length_list
362
+ self.mean = mean
363
+ self.std = std
364
+ self.length_arr = np.array(length_list)
365
+ self.data_dict = data_dict
366
+ self.name_list = name_list
367
+ if self.evaluation:
368
+ self.reset_max_len(self.max_length)
369
+
370
+ def reset_max_len(self, length):
371
+ assert length <= self.max_motion_length
372
+ self.pointer = np.searchsorted(self.length_arr, length)
373
+ print("Pointer Pointing at %d"%self.pointer)
374
+ self.max_length = length
375
+
376
+ def transform(self, data, mean=None, std=None):
377
+ if mean is None and std is None:
378
+ return (data - self.mean) / self.std
379
+ else:
380
+ return (data - mean) / std
381
+
382
+ def inv_transform(self, data, mean=None, std=None):
383
+ if mean is None and std is None:
384
+ return data * self.std + self.mean
385
+ else:
386
+ return data * std + mean
387
+
388
+ def __len__(self):
389
+ return len(self.data_dict) - self.pointer
390
+
391
+ def __getitem__(self, item):
392
+ idx = self.pointer + item
393
+ data = self.data_dict[self.name_list[idx]]
394
+ motion, m_length, text_list = data['motion'], data['length'], data['text']
395
+ # Randomly select a caption
396
+ text_data = random.choice(text_list)
397
+ caption, tokens = text_data['caption'], text_data['tokens']
398
+
399
+ if self.evaluation:
400
+ if len(tokens) < self.max_text_len:
401
+ # pad with "unk"
402
+ tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
403
+ sent_len = len(tokens)
404
+ tokens = tokens + ['unk/OTHER'] * (self.max_text_len + 2 - sent_len)
405
+ else:
406
+ # crop
407
+ tokens = tokens[:self.max_text_len]
408
+ tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
409
+ sent_len = len(tokens)
410
+ pos_one_hots = []
411
+ word_embeddings = []
412
+ for token in tokens:
413
+ word_emb, pos_oh = self.w_vectorizer[token]
414
+ pos_one_hots.append(pos_oh[None, :])
415
+ word_embeddings.append(word_emb[None, :])
416
+ pos_one_hots = np.concatenate(pos_one_hots, axis=0)
417
+ word_embeddings = np.concatenate(word_embeddings, axis=0)
418
+
419
+ if self.unit_length < 10:
420
+ coin2 = np.random.choice(['single', 'single', 'double'])
421
+ else:
422
+ coin2 = 'single'
423
+
424
+ if coin2 == 'double':
425
+ m_length = (m_length // self.unit_length - 1) * self.unit_length
426
+ elif coin2 == 'single':
427
+ m_length = (m_length // self.unit_length) * self.unit_length
428
+ idx = random.randint(0, len(motion) - m_length)
429
+ motion = motion[idx:idx+m_length]
430
+
431
+ "Z Normalization"
432
+ if self.evaluation:
433
+ motion = motion[:, :self.mean.shape[0]]
434
+ motion = (motion - self.mean) / self.std
435
+
436
+ if m_length < self.max_motion_length:
437
+ if self.evaluation:
438
+ motion = np.concatenate([motion,
439
+ np.zeros((self.max_motion_length - m_length, motion.shape[1]))
440
+ ], axis=0)
441
+ else:
442
+ motion = np.concatenate([motion,
443
+ np.zeros((self.max_motion_length - m_length, motion.shape[1], motion.shape[2]))
444
+ ], axis=0)
445
+ elif m_length > self.max_motion_length:
446
+ if not self.evaluation:
447
+ idx = random.randint(0, m_length - self.max_motion_length)
448
+ motion = motion[idx:idx + self.max_motion_length]
449
+ if self.evaluation:
450
+ return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, '_'.join(tokens)
451
+ else:
452
+ return caption, motion, m_length
utils/eval_mean_std/t2m/eval_mean.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1f866545af58e1d603b649fbf47680ffdcbf2856f8a3771c7e8aadb1098d560f
3
+ size 396
utils/eval_mean_std/t2m/eval_std.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7898ab84360e67ef06046df9e39001b02182d3aeae2ae4e3a3f46277bd748045
3
+ size 396
utils/eval_utils.py ADDED
@@ -0,0 +1,928 @@
1
+ import os
2
+ import numpy as np
3
+ from scipy import linalg
4
+ from scipy.ndimage import uniform_filter1d
5
+ import torch
6
+ from utils.motion_process import recover_from_ric
7
+ from utils.back_process import back_process
8
+ from tqdm import tqdm
9
+
10
+ #################################################################################
11
+ # Eval Function Loops #
12
+ #################################################################################
13
+ @torch.no_grad()
14
+ def evaluation_ae(out_dir, val_loader, net, writer, ep, eval_wrapper, device, best_fid=1000, best_div=0,
15
+ best_top1=0, best_top2=0, best_top3=0, best_matching=100,
16
+ eval_mean=None, eval_std=None, save=True, draw=True):
17
+ net.eval()
18
+
19
+ motion_annotation_list = []
20
+ motion_pred_list = []
21
+
22
+ R_precision_real = 0
23
+ R_precision = 0
24
+
25
+ nb_sample = 0
26
+ matching_score_real = 0
27
+ matching_score_pred = 0
28
+ mpjpe = 0
29
+ num_poses = 0
30
+
31
+ for batch in tqdm(val_loader):
32
+ word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, token = batch
33
+
34
+ bs, seq = motion.shape[0], motion.shape[1]
35
+ gt = val_loader.dataset.inv_transform(motion.detach().cpu().numpy())
36
+ bgt = []
37
+ for j in range(bs):
38
+ bgt.append(back_process(gt[j], is_mesh=False))
39
+ bgt = np.stack(bgt, axis=0)
40
+ bgt = val_loader.dataset.transform(bgt, eval_mean, eval_std)
41
+
42
+ bgt = torch.from_numpy(bgt).to(device)
43
+ (et, em), (_, _) = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, caption, bgt, m_length-1)
44
+
45
+ motion = motion.float().to(device)
46
+ with torch.no_grad():
47
+ pred_pose_eval = net.forward(motion)
48
+ pred = val_loader.dataset.inv_transform(pred_pose_eval.detach().cpu().numpy())
49
+ bpred = []
50
+ for j in range(bs):
51
+ bpred.append(back_process(pred[j], is_mesh=False))
52
+ bpred = np.stack(bpred, axis=0)
53
+ bpred = val_loader.dataset.transform(bpred, eval_mean, eval_std)
54
+
55
+ (et_pred, em_pred), (_, _) = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, caption,
56
+ torch.from_numpy(bpred).to(device), m_length-1)
57
+ for i in range(bs):
58
+ gtt = torch.from_numpy(gt[i, :m_length[i]]).float().reshape(-1, 22, 3)
59
+ predd = torch.from_numpy(pred[i, :m_length[i]]).float().reshape(-1, 22, 3)
60
+ mpjpe += torch.sum(calculate_mpjpe(gtt, predd))
61
+ num_poses += gt.shape[0]
62
+
63
+ motion_pred_list.append(em_pred)
64
+ motion_annotation_list.append(em)
65
+
66
+ temp_R = calculate_R_precision(et.cpu().numpy(), em.cpu().numpy(), top_k=3, sum_all=True)
67
+ temp_match = euclidean_distance_matrix(et.cpu().numpy(), em.cpu().numpy()).trace()
68
+ R_precision_real += temp_R
69
+ matching_score_real += temp_match
70
+ temp_R = calculate_R_precision(et_pred.cpu().numpy(), em_pred.cpu().numpy(), top_k=3, sum_all=True)
71
+ temp_match = euclidean_distance_matrix(et_pred.cpu().numpy(), em_pred.cpu().numpy()).trace()
72
+ R_precision += temp_R
73
+ matching_score_pred += temp_match
74
+
75
+ nb_sample += bs
76
+
77
+ motion_annotation_np = torch.cat(motion_annotation_list, dim=0).cpu().numpy()
78
+ motion_pred_np = torch.cat(motion_pred_list, dim=0).cpu().numpy()
79
+ gt_mu, gt_cov = calculate_activation_statistics(motion_annotation_np)
80
+ mu, cov = calculate_activation_statistics(motion_pred_np)
81
+
82
+ diversity_real = calculate_diversity(motion_annotation_np, 300 if nb_sample > 300 else 100)
83
+ diversity = calculate_diversity(motion_pred_np, 300 if nb_sample > 300 else 100)
84
+
85
+ R_precision_real = R_precision_real / nb_sample
86
+ R_precision = R_precision / nb_sample
87
+
88
+ matching_score_real = matching_score_real / nb_sample
89
+ matching_score_pred = matching_score_pred / nb_sample
90
+ mpjpe = mpjpe / num_poses
91
+
92
+ fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)
93
+
94
+ msg = "--> \t Eva. Re %d:, FID. %.4f, Diversity Real. %.4f, Diversity. %.4f, R_precision_real. (%.4f, %.4f, %.4f), R_precision. (%.4f, %.4f, %.4f), matching_real. %.4f, matching_pred. %.4f, MPJPE. %.4f" % \
95
+ (ep, fid, diversity_real, diversity, R_precision_real[0], R_precision_real[1], R_precision_real[2],
96
+ R_precision[0], R_precision[1], R_precision[2], matching_score_real, matching_score_pred, mpjpe)
97
+ print(msg)
98
+ if draw:
99
+ writer.add_scalar('./Test/FID', fid, ep)
100
+ writer.add_scalar('./Test/Diversity', diversity, ep)
101
+ writer.add_scalar('./Test/top1', R_precision[0], ep)
102
+ writer.add_scalar('./Test/top2', R_precision[1], ep)
103
+ writer.add_scalar('./Test/top3', R_precision[2], ep)
104
+ writer.add_scalar('./Test/matching_score', matching_score_pred, ep)
105
+
106
+ if fid < best_fid:
107
+ msg = "--> --> \t FID Improved from %.5f to %.5f !!!" % (best_fid, fid)
108
+ if draw: print(msg)
109
+ best_fid = fid
110
+ if save:
111
+ torch.save({'ae': net.state_dict(), 'ep': ep}, os.path.join(out_dir, 'net_best_fid.tar'))
112
+
113
+ if abs(diversity_real - diversity) < abs(diversity_real - best_div):
114
+ msg = "--> --> \t Diversity Improved from %.5f to %.5f !!!"%(best_div, diversity)
115
+ if draw: print(msg)
116
+ best_div = diversity
117
+
118
+ if R_precision[0] > best_top1:
119
+ msg = "--> --> \t Top1 Improved from %.5f to %.5f !!!" % (best_top1, R_precision[0])
120
+ if draw: print(msg)
121
+ best_top1 = R_precision[0]
122
+
123
+ if R_precision[1] > best_top2:
124
+ msg = "--> --> \t Top2 Improved from %.5f to %.5f!!!" % (best_top2, R_precision[1])
125
+ if draw: print(msg)
126
+ best_top2 = R_precision[1]
127
+
128
+ if R_precision[2] > best_top3:
129
+ msg = "--> --> \t Top3 Improved from %.5f to %.5f !!!" % (best_top3, R_precision[2])
130
+ if draw: print(msg)
131
+ best_top3 = R_precision[2]
132
+
133
+ if matching_score_pred < best_matching:
134
+ msg = "--> --> \t matching_score Improved from %.5f to %.5f !!!" % (best_matching, matching_score_pred)
135
+ if draw: print(msg)
136
+ best_matching = matching_score_pred
137
+
138
+ net.train()
139
+ return best_fid, best_div, best_top1, best_top2, best_top3, best_matching, mpjpe, writer
140
+
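+ # Call sketch (illustrative; val_loader, net, writer and eval_wrapper are assumed to be built elsewhere):
+ # best_fid, best_div, top1, top2, top3, best_matching, mpjpe, writer = evaluation_ae(
+ #     out_dir, val_loader, net, writer, ep=0, eval_wrapper=eval_wrapper, device='cuda',
+ #     eval_mean=eval_mean, eval_std=eval_std, save=False, draw=False)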
141
+ @torch.no_grad()
142
+ def evaluation_acmdm(out_dir, val_loader, ema_acmdm, ae, writer, ep, best_fid, best_div,
143
+ best_top1, best_top2, best_top3, best_matching, eval_wrapper, device, clip_score_old,
144
+ cond_scale=None, cal_mm=False, eval_mean=None, eval_std=None, after_mean=None, after_std=None, mesh_mean=None, mesh_std=None,
145
+ draw=True,
146
+ is_raw=False,
147
+ is_prefix=False,
148
+ is_control=False, index=[0], intensity=100,
149
+ is_mesh=False):
150
+
151
+ ema_acmdm.eval()
152
+ if not is_raw:
153
+ ae.eval()
154
+
155
+ save=False
156
+
157
+ motion_annotation_list = []
158
+ motion_pred_list = []
159
+ motion_multimodality = []
160
+ R_precision_real = 0
161
+ R_precision = 0
162
+ matching_score_real = 0
163
+ matching_score_pred = 0
164
+ multimodality = 0
165
+ if cond_scale is None:
166
+ if "kit" in out_dir:
167
+ cond_scale = 2.5
168
+ else:
169
+ cond_scale = 2.5
170
+ clip_score_real = 0
171
+ clip_score_gt = 0
172
+ skate_ratio_sum = 0
173
+ dist_sum = 0
174
+ traj_err = []
175
+
176
+ nb_sample = 0
177
+ if cal_mm:
178
+ num_mm_batch = 3
179
+ else:
180
+ num_mm_batch = 0
181
+
182
+ for i, batch in enumerate(tqdm(val_loader)):
183
+ word_embeddings, pos_one_hots, clip_text, sent_len, pose, m_length, token = batch
184
+ m_length = m_length.to(device)
185
+
186
+ bs, seq = pose.shape[:2]
187
+ if i < num_mm_batch:
188
+ motion_multimodality_batch = []
189
+ batch_clip_score_pred = 0
190
+ for _ in tqdm(range(30)):
191
+ pred_latents = ema_acmdm.generate(clip_text, m_length//4 if not is_raw else m_length, cond_scale)
192
+
193
+ if not is_raw:
194
+ pred_latents = val_loader.dataset.inv_transform(pred_latents.permute(0, 2, 3, 1).detach().cpu().numpy(),
195
+ after_mean, after_std)
196
+ pred_latents = torch.from_numpy(pred_latents).to(device)
197
+ with torch.no_grad():
198
+ pred_motions = ae.decode(pred_latents.permute(0,3,1,2))
199
+ else:
200
+ pred_motions = pred_latents.permute(0, 2, 3, 1)
201
+ pred_motions = val_loader.dataset.inv_transform(pred_motions.detach().cpu().numpy())
202
+ pred_motionss = []
203
+ for j in range(bs):
204
+ pred_motionss.append(back_process(pred_motions[j], is_mesh=is_mesh))
205
+ pred_motionss = np.stack(pred_motionss, axis=0)
206
+ pred_motions = val_loader.dataset.transform(pred_motionss, eval_mean, eval_std)
207
+ (et_pred, em_pred), (et_pred_clip, em_pred_clip) = eval_wrapper.get_co_embeddings(word_embeddings,
208
+ pos_one_hots,
209
+ sent_len,
210
+ clip_text,
211
+ torch.from_numpy(
212
+ pred_motions).to(
213
+ device),
214
+ m_length - 1)
215
+ motion_multimodality_batch.append(em_pred.unsqueeze(1))
216
+ motion_multimodality_batch = torch.cat(motion_multimodality_batch, dim=1) #(bs, 30, d)
217
+ motion_multimodality.append(motion_multimodality_batch)
218
+ for j in range(bs):
219
+ single_em = em_pred_clip[j]
220
+ single_et = et_pred_clip[j]
221
+ clip_score = (single_em @ single_et.T).item()
222
+ batch_clip_score_pred += clip_score
223
+ clip_score_real += batch_clip_score_pred
224
+ else:
225
+ if is_control:
226
+ pred_latents, mask_hint = ema_acmdm.generate_control(clip_text, m_length//4, pose.clone().float().to(device).permute(0,3,1,2), index, intensity,
227
+ cond_scale)
228
+ mask_hint = mask_hint.permute(0, 2, 3, 1).cpu().numpy()
229
+ elif is_prefix:
230
+ motion = pose.clone().float().to(device)
231
+ with torch.no_grad():
232
+ motion = ae.encode(motion)
233
+ amean = torch.from_numpy(after_mean).to(device)
234
+ astd = torch.from_numpy(after_std).to(device)
235
+ motion = ((motion.permute(0,2,3,1)-amean)/astd).permute(0,3,1,2)
236
+ pred_latents = ema_acmdm.generate(clip_text, m_length // 4, cond_scale, motion[:, :, :5, :]) # 20+40 style
237
+ else:
238
+ pred_latents = ema_acmdm.generate(clip_text, m_length//4 if not (is_raw or is_mesh) else m_length, cond_scale, j=22 if not is_mesh else 28)
239
+ if not is_raw:
240
+ pred_latents = val_loader.dataset.inv_transform(pred_latents.permute(0,2,3,1).detach().cpu().numpy(), after_mean, after_std)
241
+ pred_latents = torch.from_numpy(pred_latents).to(device)
242
+ with torch.no_grad():
243
+ if not is_mesh:
244
+ pred_latents = pred_latents.permute(0, 3, 1, 2)
245
+ pred_motions = ae.decode(pred_latents)
246
+ else:
247
+ pred_motions = pred_latents.permute(0, 2, 3, 1)
248
+ if not is_mesh:
249
+ pred_motions = val_loader.dataset.inv_transform(pred_motions.detach().cpu().numpy())
250
+ else:
251
+ pred_motions = val_loader.dataset.inv_transform(pred_motions.detach().cpu().numpy(), mesh_mean, mesh_std)
252
+ J_regressor = np.load('body_models/J_regressor.npy')
253
+ pred_motions = np.einsum('jk,btkc->btjc', J_regressor, pred_motions)[:, :, :22]
254
+ if is_control:
255
+ # foot skate
256
+ skate_ratio, skate_vel = calculate_skating_ratio(torch.from_numpy(pred_motions.transpose(0,2,3,1))) # [batch_size]
257
+ skate_ratio_sum += skate_ratio.sum()
258
+ # control errors
259
+ hint = val_loader.dataset.inv_transform(pose.clone().detach().cpu().numpy())
260
+ hint = hint * mask_hint
261
+ for i, (mot, h, mask) in enumerate(zip(pred_motions, hint, mask_hint)):
262
+ control_error = control_l2(np.expand_dims(mot, axis=0), np.expand_dims(h, axis=0),
263
+ np.expand_dims(mask, axis=0))
264
+ mean_error = control_error.sum() / mask.sum()
265
+ dist_sum += mean_error
266
+ control_error = control_error.reshape(-1)
267
+ mask = mask.reshape(-1)
268
+ err_np = calculate_trajectory_error(control_error, mean_error, mask)
269
+ traj_err.append(err_np)
270
+
271
+ pred_motionss = []
272
+ for j in range(bs):
273
+ pred_motionss.append(back_process(pred_motions[j], is_mesh=is_mesh))
274
+ pred_motionss = np.stack(pred_motionss, axis=0)
275
+ pred_motions = val_loader.dataset.transform(pred_motionss, eval_mean, eval_std)
276
+ (et_pred, em_pred), (et_pred_clip, em_pred_clip) = eval_wrapper.get_co_embeddings(word_embeddings,
277
+ pos_one_hots, sent_len,
278
+ clip_text,
279
+ torch.from_numpy(pred_motions).to(device),
280
+ m_length-1)
281
+ batch_clip_score_pred = 0
282
+ for j in range(bs):
283
+ single_em = em_pred_clip[j]
284
+ single_et = et_pred_clip[j]
285
+ clip_score = (single_em @ single_et.T).item()
286
+ batch_clip_score_pred += clip_score
287
+ clip_score_real += batch_clip_score_pred
288
+
289
+ pose = val_loader.dataset.inv_transform(pose.detach().cpu().numpy())
290
+ poses = []
291
+ for j in range(bs):
292
+ poses.append(back_process(pose[j], is_mesh=False))
293
+ poses = np.stack(poses, axis=0)
294
+ pose = val_loader.dataset.transform(poses, eval_mean, eval_std)
295
+ pose = torch.from_numpy(pose).cuda().float()
296
+ pose = pose.cuda().float()
297
+ (et, em), (et_clip, em_clip) = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, clip_text,
298
+ pose.clone(), m_length-1)
299
+ batch_clip_score = 0
300
+ for j in range(bs):
301
+ single_em = em_clip[j]
302
+ single_et = et_clip[j]
303
+ clip_score = (single_em @ single_et.T).item()
304
+ batch_clip_score += clip_score
305
+ clip_score_gt += batch_clip_score
306
+ motion_annotation_list.append(em)
307
+ motion_pred_list.append(em_pred)
308
+
309
+ temp_R = calculate_R_precision(et.cpu().numpy(), em.cpu().numpy(), top_k=3, sum_all=True)
310
+ temp_match = euclidean_distance_matrix(et.cpu().numpy(), em.cpu().numpy()).trace()
311
+ R_precision_real += temp_R
312
+ matching_score_real += temp_match
313
+ temp_R = calculate_R_precision(et_pred.cpu().numpy(), em_pred.cpu().numpy(), top_k=3, sum_all=True)
314
+ temp_match = euclidean_distance_matrix(et_pred.cpu().numpy(), em_pred.cpu().numpy()).trace()
315
+ R_precision += temp_R
316
+ matching_score_pred += temp_match
317
+
318
+ nb_sample += bs
319
+
320
+ motion_annotation_np = torch.cat(motion_annotation_list, dim=0).cpu().numpy()
321
+ motion_pred_np = torch.cat(motion_pred_list, dim=0).cpu().numpy()
322
+ gt_mu, gt_cov = calculate_activation_statistics(motion_annotation_np)
323
+ mu, cov = calculate_activation_statistics(motion_pred_np)
324
+
325
+ diversity_real = calculate_diversity(motion_annotation_np, 300 if nb_sample > 300 else 100)
326
+ diversity = calculate_diversity(motion_pred_np, 300 if nb_sample > 300 else 100)
327
+
328
+ R_precision_real = R_precision_real / nb_sample
329
+ R_precision = R_precision / nb_sample
330
+
331
+ clip_score_real = clip_score_real / nb_sample
332
+ clip_score_gt = clip_score_gt / nb_sample
333
+
334
+ matching_score_real = matching_score_real / nb_sample
335
+ matching_score_pred = matching_score_pred / nb_sample
336
+ if is_control:
337
+ # l2 dist
338
+ dist_mean = dist_sum / nb_sample
339
+
340
+ # Skating evaluation
341
+ skating_score = skate_ratio_sum / nb_sample
342
+
343
+ ### For trajectory evaluation from GMD ###
344
+ traj_err = np.stack(traj_err).mean(0)
345
+
346
+ if cal_mm:
347
+ motion_multimodality = torch.cat(motion_multimodality, dim=0).cpu().numpy()
348
+ multimodality = calculate_multimodality(motion_multimodality, 10)
349
+
350
+ fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)
351
+
352
+ msg = (f"--> \t Eva. Ep/Re {ep} :, FID. {fid:.4f}, Diversity Real. {diversity_real:.4f}, Diversity."
353
+ f" {diversity:.4f}, R_precision_real. {R_precision_real}, R_precision. {R_precision},"
354
+ f" matching_score_real. {matching_score_real}, matching_score_pred. {matching_score_pred}"
355
+ f" multimodality. {multimodality:.4f}, clip score. {clip_score_real}"
356
+ + (f" foot skating. {skating_score:.4f}, traj error. {traj_err[1].item():.4f}, pos error. {traj_err[3].item():.4f}, avg error. {traj_err[4].item():.4f}"
357
+ if is_control else ""))
358
+ print(msg)
359
+
360
+ if draw:
361
+ writer.add_scalar('./Test/FID', fid, ep)
362
+ writer.add_scalar('./Test/Diversity', diversity, ep)
363
+ writer.add_scalar('./Test/top1', R_precision[0], ep)
364
+ writer.add_scalar('./Test/top2', R_precision[1], ep)
365
+ writer.add_scalar('./Test/top3', R_precision[2], ep)
366
+ writer.add_scalar('./Test/matching_score', matching_score_pred, ep)
367
+ writer.add_scalar('./Test/clip_score', clip_score_real, ep)
368
+
369
+
370
+ if fid < best_fid:
371
+ msg = f"--> --> \t FID Improved from {best_fid:.5f} to {fid:.5f} !!!"
372
+ if draw: print(msg)
373
+ best_fid, best_ep = fid, ep
374
+ save=True
375
+
376
+
377
+ if matching_score_pred < best_matching:
378
+ msg = f"--> --> \t matching_score Improved from {best_matching:.5f} to {matching_score_pred:.5f} !!!"
379
+ if draw: print(msg)
380
+ best_matching = matching_score_pred
381
+
382
+ if abs(diversity_real - diversity) < abs(diversity_real - best_div):
383
+ msg = f"--> --> \t Diversity Improved from {best_div:.5f} to {diversity:.5f} !!!"
384
+ if draw: print(msg)
385
+ best_div = diversity
386
+
387
+ if R_precision[0] > best_top1:
388
+ msg = f"--> --> \t Top1 Improved from {best_top1:.4f} to {R_precision[0]:.4f} !!!"
389
+ if draw: print(msg)
390
+ best_top1 = R_precision[0]
391
+
392
+ if R_precision[1] > best_top2:
393
+ msg = f"--> --> \t Top2 Improved from {best_top2:.4f} to {R_precision[1]:.4f} !!!"
394
+ if draw: print(msg)
395
+ best_top2 = R_precision[1]
396
+
397
+ if R_precision[2] > best_top3:
398
+ msg = f"--> --> \t Top3 Improved from {best_top3:.4f} to {R_precision[2]:.4f} !!!"
399
+ if draw: print(msg)
400
+ best_top3 = R_precision[2]
401
+
402
+ if clip_score_real > clip_score_old:
403
+ msg = f"--> --> \t CLIP-score Improved from {clip_score_old:.4f} to {clip_score_real:.4f} !!!"
404
+ if draw: print(msg)
405
+ clip_score_old = clip_score_real
406
+
407
+ if cal_mm:
408
+ return best_fid, best_div, best_top1, best_top2, best_top3, best_matching, multimodality, clip_score_old, writer, save
409
+ else:
410
+ if is_control:
411
+ return best_fid, best_div, best_top1, best_top2, best_top3, best_matching, 0, clip_score_old, writer, save, dist_mean, skating_score, traj_err
412
+ else:
413
+ return best_fid, best_div, best_top1, best_top2, best_top3, best_matching, 0, clip_score_old, writer, save, None, None, None
414
+
415
+
416
+ @torch.no_grad()
417
+ def evaluation_acmdm_another_v(out_dir, val_loader, ema_acmdm, ae, writer, ep, best_fid, best_div,
418
+ best_top1, best_top2, best_top3, best_matching, eval_wrapper, device, clip_score_old,
419
+ cond_scale=None, cal_mm=False, train_mean=None, train_std=None, after_mean=None, after_std=None,
420
+ draw=True,
421
+ is_raw=False,
422
+ is_prefix=False,
423
+ is_control=False, index=[0], intensity=100,
424
+ is_mesh=False):
425
+
426
+ ema_acmdm.eval()
427
+ if not is_raw:
428
+ ae.eval()
429
+
430
+ save=False
431
+
432
+ motion_annotation_list = []
433
+ motion_pred_list = []
434
+ motion_multimodality = []
435
+ R_precision_real = 0
436
+ R_precision = 0
437
+ matching_score_real = 0
438
+ matching_score_pred = 0
439
+ multimodality = 0
440
+ if cond_scale is None:
441
+ if "kit" in out_dir:
442
+ cond_scale = 2.5
443
+ else:
444
+ cond_scale = 2.5
445
+ clip_score_real = 0
446
+ clip_score_gt = 0
447
+ skate_ratio_sum = 0
448
+ dist_sum = 0
449
+ traj_err = []
450
+
451
+ nb_sample = 0
452
+ if cal_mm:
453
+ num_mm_batch = 3
454
+ else:
455
+ num_mm_batch = 0
456
+
457
+ for i, batch in enumerate(tqdm(val_loader)):
458
+ word_embeddings, pos_one_hots, clip_text, sent_len, pose, m_length, token = batch
459
+ m_length = m_length.to(device)
460
+
461
+ bs, seq = pose.shape[:2]
462
+ if i < num_mm_batch:
463
+ motion_multimodality_batch = []
464
+ batch_clip_score_pred = 0
465
+ for _ in tqdm(range(30)):
466
+ pred_latents = ema_acmdm.generate(clip_text, m_length//4 if not is_raw else m_length, cond_scale)
467
+
468
+ if not is_raw:
469
+ pred_latents = val_loader.dataset.inv_transform(pred_latents.permute(0, 2, 3, 1).detach().cpu().numpy(),
470
+ after_mean, after_std)
471
+ pred_latents = torch.from_numpy(pred_latents).to(device)
472
+ with torch.no_grad():
473
+ pred_motions = ae.decode(pred_latents.permute(0,3,1,2))
474
+ else:
475
+ pred_motions = pred_latents.permute(0, 2, 3, 1)
476
+ pred_motions = val_loader.dataset.inv_transform(pred_motions.detach().cpu().numpy(), train_mean, train_std)
477
+ pred_motionss = []
478
+ for j in range(bs):
479
+ pred_motionss.append(back_process(pred_motions[j], is_mesh=is_mesh))
480
+ pred_motionss = np.stack(pred_motionss, axis=0)
481
+ pred_motions = val_loader.dataset.transform(pred_motionss)
482
+ (et_pred, em_pred), (et_pred_clip, em_pred_clip) = eval_wrapper.get_co_embeddings(word_embeddings,
483
+ pos_one_hots,
484
+ sent_len,
485
+ clip_text,
486
+ torch.from_numpy(
487
+ pred_motions).to(
488
+ device),
489
+ m_length - 1)
490
+ motion_multimodality_batch.append(em_pred.unsqueeze(1))
491
+ motion_multimodality_batch = torch.cat(motion_multimodality_batch, dim=1) #(bs, 30, d)
492
+ motion_multimodality.append(motion_multimodality_batch)
493
+ for j in range(bs):
494
+ single_em = em_pred_clip[j]
495
+ single_et = et_pred_clip[j]
496
+ clip_score = (single_em @ single_et.T).item()
497
+ batch_clip_score_pred += clip_score
498
+ clip_score_real += batch_clip_score_pred
499
+ else:
500
+ if is_control:
501
+ bgt = val_loader.dataset.inv_transform(pose.clone())
502
+ motion_gt = []
503
+ for j in range(bs):
504
+ motion_gt.append(recover_from_ric(bgt[j].float(), 22).numpy())
505
+ motion_gt = np.stack(motion_gt, axis=0)
506
+ motion = val_loader.dataset.transform(motion_gt, train_mean, train_std)
507
+ motion = torch.from_numpy(motion).float().to(device)
508
+ pred_latents, mask_hint = ema_acmdm.generate_control(clip_text, m_length//4, motion.clone().permute(0,3,1,2), index, intensity,
509
+ cond_scale)
510
+ mask_hint = mask_hint.permute(0, 2, 3, 1).cpu().numpy()
511
+ elif is_prefix:
512
+ bgt = val_loader.dataset.inv_transform(pose.clone())
513
+ motion_gt = []
514
+ for j in range(bs):
515
+ motion_gt.append(recover_from_ric(bgt[j].float(), 22).numpy())
516
+ motion_gt = np.stack(motion_gt, axis=0)
517
+ motion = val_loader.dataset.transform(motion_gt, train_mean, train_std)
518
+ motion = torch.from_numpy(motion).float().to(device)
519
+ with torch.no_grad():
520
+ motion = ae.encode(motion)
521
+ amean = torch.from_numpy(after_mean).to(device)
522
+ astd = torch.from_numpy(after_std).to(device)
523
+ motion = ((motion.permute(0,2,3,1)-amean)/astd).permute(0,3,1,2)
524
+ pred_latents = ema_acmdm.generate(clip_text, m_length // 4, cond_scale, motion[:, :, :5, :]) # 20+40 style
525
+ else:
526
+ pred_latents = ema_acmdm.generate(clip_text, m_length//4 if not is_raw else m_length, cond_scale)
527
+ if not is_raw:
528
+ pred_latents = val_loader.dataset.inv_transform(pred_latents.permute(0,2,3,1).detach().cpu().numpy(), after_mean, after_std)
529
+ pred_latents = torch.from_numpy(pred_latents).to(device)
530
+ with torch.no_grad():
531
+ pred_motions = ae.decode(pred_latents.permute(0,3,1,2))
532
+ else:
533
+ pred_motions = pred_latents.permute(0, 2, 3, 1)
534
+ pred_motions = val_loader.dataset.inv_transform(pred_motions.detach().cpu().numpy(), train_mean, train_std)
535
+ if is_control:
536
+ # foot skate
537
+ skate_ratio, skate_vel = calculate_skating_ratio(torch.from_numpy(pred_motions.transpose(0,2,3,1))) # [batch_size]
538
+ skate_ratio_sum += skate_ratio.sum()
539
+ # control errors
540
+ hint = motion_gt * mask_hint
541
+ for i, (mot, h, mask) in enumerate(zip(pred_motions, hint, mask_hint)):
542
+ control_error = control_l2(np.expand_dims(mot, axis=0), np.expand_dims(h, axis=0),
543
+ np.expand_dims(mask, axis=0))
544
+ mean_error = control_error.sum() / mask.sum()
545
+ dist_sum += mean_error
546
+ control_error = control_error.reshape(-1)
547
+ mask = mask.reshape(-1)
548
+ err_np = calculate_trajectory_error(control_error, mean_error, mask)
549
+ traj_err.append(err_np)
550
+
551
+ pred_motionss = []
552
+ for j in range(bs):
553
+ pred_motionss.append(back_process(pred_motions[j], is_mesh=is_mesh))
554
+ pred_motionss = np.stack(pred_motionss, axis=0)
555
+ pred_motions = val_loader.dataset.transform(pred_motionss)
556
+ (et_pred, em_pred), (et_pred_clip, em_pred_clip) = eval_wrapper.get_co_embeddings(word_embeddings,
557
+ pos_one_hots, sent_len,
558
+ clip_text,
559
+ torch.from_numpy(pred_motions).to(device),
560
+ m_length-1)
561
+ batch_clip_score_pred = 0
562
+ for j in range(bs):
563
+ single_em = em_pred_clip[j]
564
+ single_et = et_pred_clip[j]
565
+ clip_score = (single_em @ single_et.T).item()
566
+ batch_clip_score_pred += clip_score
567
+ clip_score_real += batch_clip_score_pred
568
+
569
+ pose = pose.cuda().float()
570
+ (et, em), (et_clip, em_clip) = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, clip_text,
571
+ pose.clone(), m_length-1)
572
+ batch_clip_score = 0
573
+ for j in range(bs):
574
+ single_em = em_clip[j]
575
+ single_et = et_clip[j]
576
+ clip_score = (single_em @ single_et.T).item()
577
+ batch_clip_score += clip_score
578
+ clip_score_gt += batch_clip_score
579
+ motion_annotation_list.append(em)
580
+ motion_pred_list.append(em_pred)
581
+
582
+ temp_R = calculate_R_precision(et.cpu().numpy(), em.cpu().numpy(), top_k=3, sum_all=True)
583
+ temp_match = euclidean_distance_matrix(et.cpu().numpy(), em.cpu().numpy()).trace()
584
+ R_precision_real += temp_R
585
+ matching_score_real += temp_match
586
+ temp_R = calculate_R_precision(et_pred.cpu().numpy(), em_pred.cpu().numpy(), top_k=3, sum_all=True)
587
+ temp_match = euclidean_distance_matrix(et_pred.cpu().numpy(), em_pred.cpu().numpy()).trace()
588
+ R_precision += temp_R
589
+ matching_score_pred += temp_match
590
+
591
+ nb_sample += bs
592
+
593
+ motion_annotation_np = torch.cat(motion_annotation_list, dim=0).cpu().numpy()
594
+ motion_pred_np = torch.cat(motion_pred_list, dim=0).cpu().numpy()
595
+ gt_mu, gt_cov = calculate_activation_statistics(motion_annotation_np)
596
+ mu, cov = calculate_activation_statistics(motion_pred_np)
597
+
598
+ diversity_real = calculate_diversity(motion_annotation_np, 300 if nb_sample > 300 else 100)
599
+ diversity = calculate_diversity(motion_pred_np, 300 if nb_sample > 300 else 100)
600
+
601
+ R_precision_real = R_precision_real / nb_sample
602
+ R_precision = R_precision / nb_sample
603
+
604
+ clip_score_real = clip_score_real / nb_sample
605
+ clip_score_gt = clip_score_gt / nb_sample
606
+
607
+ matching_score_real = matching_score_real / nb_sample
608
+ matching_score_pred = matching_score_pred / nb_sample
609
+ if is_control:
610
+ # l2 dist
611
+ dist_mean = dist_sum / nb_sample
612
+
613
+ # Skating evaluation
614
+ skating_score = skate_ratio_sum / nb_sample
615
+
616
+ ### For trajectory evaluation from GMD ###
617
+ traj_err = np.stack(traj_err).mean(0)
618
+
619
+ if cal_mm:
620
+ motion_multimodality = torch.cat(motion_multimodality, dim=0).cpu().numpy()
621
+ multimodality = calculate_multimodality(motion_multimodality, 10)
622
+
623
+ fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)
624
+
625
+ msg = (f"--> \t Eva. Ep/Re {ep} :, FID. {fid:.4f}, Diversity Real. {diversity_real:.4f}, Diversity."
626
+ f" {diversity:.4f}, R_precision_real. {R_precision_real}, R_precision. {R_precision},"
627
+ f" matching_score_real. {matching_score_real}, matching_score_pred. {matching_score_pred}"
628
+ f" multimodality. {multimodality:.4f}, clip score. {clip_score_real}"
629
+ + (f" foot skating. {skating_score:.4f}, traj error. {traj_err[1].item():.4f}, loc error. {traj_err[3].item():.4f}, avg error. {traj_err[4].item():.4f}"
630
+ if is_control else ""))
631
+ print(msg)
632
+
633
+ if draw:
634
+ writer.add_scalar('./Test/FID', fid, ep)
635
+ writer.add_scalar('./Test/Diversity', diversity, ep)
636
+ writer.add_scalar('./Test/top1', R_precision[0], ep)
637
+ writer.add_scalar('./Test/top2', R_precision[1], ep)
638
+ writer.add_scalar('./Test/top3', R_precision[2], ep)
639
+ writer.add_scalar('./Test/matching_score', matching_score_pred, ep)
640
+ writer.add_scalar('./Test/clip_score', clip_score_real, ep)
641
+
642
+
643
+ if fid < best_fid:
644
+ msg = f"--> --> \t FID Improved from {best_fid:.5f} to {fid:.5f} !!!"
645
+ if draw: print(msg)
646
+ best_fid, best_ep = fid, ep
647
+ save=True
648
+
649
+
650
+ if matching_score_pred < best_matching:
651
+ msg = f"--> --> \t matching_score Improved from {best_matching:.5f} to {matching_score_pred:.5f} !!!"
652
+ if draw: print(msg)
653
+ best_matching = matching_score_pred
654
+
655
+ if abs(diversity_real - diversity) < abs(diversity_real - best_div):
656
+ msg = f"--> --> \t Diversity Improved from {best_div:.5f} to {diversity:.5f} !!!"
657
+ if draw: print(msg)
658
+ best_div = diversity
659
+
660
+ if R_precision[0] > best_top1:
661
+ msg = f"--> --> \t Top1 Improved from {best_top1:.4f} to {R_precision[0]:.4f} !!!"
662
+ if draw: print(msg)
663
+ best_top1 = R_precision[0]
664
+
665
+ if R_precision[1] > best_top2:
666
+ msg = f"--> --> \t Top2 Improved from {best_top2:.4f} to {R_precision[1]:.4f} !!!"
667
+ if draw: print(msg)
668
+ best_top2 = R_precision[1]
669
+
670
+ if R_precision[2] > best_top3:
671
+ msg = f"--> --> \t Top3 Improved from {best_top3:.4f} to {R_precision[2]:.4f} !!!"
672
+ if draw: print(msg)
673
+ best_top3 = R_precision[2]
674
+
675
+ if clip_score_real > clip_score_old:
676
+ msg = f"--> --> \t CLIP-score Improved from {clip_score_old:.4f} to {clip_score_real:.4f} !!!"
677
+ if draw: print(msg)
678
+ clip_score_old = clip_score_real
679
+
680
+ if cal_mm:
681
+ return best_fid, best_div, best_top1, best_top2, best_top3, best_matching, multimodality, clip_score_old, writer, save
682
+ else:
683
+ if is_control:
684
+ return best_fid, best_div, best_top1, best_top2, best_top3, best_matching, 0, clip_score_old, writer, save, dist_mean, skating_score, traj_err
685
+ else:
686
+ return best_fid, best_div, best_top1, best_top2, best_top3, best_matching, 0, clip_score_old, writer, save, None, None, None
687
+
688
+ #################################################################################
689
+ # Util Functions #
690
+ #################################################################################
691
+ def eval_decorator(fn):
692
+ def inner(model, *args, **kwargs):
693
+ was_training = model.training
694
+ model.eval()
695
+ out = fn(model, *args, **kwargs)
696
+ model.train(was_training)
697
+ return out
698
+ return inner
699
+
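+ # Usage sketch (illustrative): wrap a sampling method so the model is switched to eval mode
+ # for the call and restored to its previous training state afterwards.
+ # @eval_decorator
+ # def sample(model, text, lengths):
+ #     return model.generate(text, lengths)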
700
+ #################################################################################
701
+ # Metrics #
702
+ #################################################################################
703
+ def calculate_mpjpe(gt_joints, pred_joints):
704
+ """
705
+ gt_joints: num_poses x num_joints(22) x 3
706
+ pred_joints: num_poses x num_joints(22) x 3
707
+ (obtained from recover_from_ric())
708
+ """
709
+ assert gt_joints.shape == pred_joints.shape, f"GT shape: {gt_joints.shape}, pred shape: {pred_joints.shape}"
710
+
711
+ # Align by root (pelvis)
712
+ pelvis = gt_joints[:, [0]].mean(1)
713
+ gt_joints = gt_joints - torch.unsqueeze(pelvis, dim=1)
714
+ pelvis = pred_joints[:, [0]].mean(1)
715
+ pred_joints = pred_joints - torch.unsqueeze(pelvis, dim=1)
716
+
717
+ # Compute MPJPE
718
+ mpjpe = torch.linalg.norm(pred_joints - gt_joints, dim=-1) # num_poses x num_joints=22
719
+ mpjpe_seq = mpjpe.mean(-1) # num_poses
720
+
721
+ return mpjpe_seq
722
+
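+ # Shape sketch (illustrative): per-frame MPJPE after root alignment.
+ # gt = torch.randn(60, 22, 3); pred = gt + 0.01 * torch.randn(60, 22, 3)
+ # per_frame = calculate_mpjpe(gt, pred)  # -> (60,), mean joint error per pose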
723
+ # (X - X_train)*(X - X_train) = -2X*X_train + X*X + X_train*X_train
724
+ def euclidean_distance_matrix(matrix1, matrix2):
725
+ """
726
+ Params:
727
+ -- matrix1: N1 x D
728
+ -- matrix2: N2 x D
729
+ Returns:
730
+ -- dist: N1 x N2
731
+ dist[i, j] == distance(matrix1[i], matrix2[j])
732
+ """
733
+ assert matrix1.shape[1] == matrix2.shape[1]
734
+ d1 = -2 * np.dot(matrix1, matrix2.T) # shape (num_test, num_train)
735
+ d2 = np.sum(np.square(matrix1), axis=1, keepdims=True) # shape (num_test, 1)
736
+ d3 = np.sum(np.square(matrix2), axis=1) # shape (num_train, )
737
+ dists = np.sqrt(d1 + d2 + d3) # broadcasting
738
+ return dists
739
+
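+ # Sanity sketch (illustrative): the expansion above matches the direct pairwise distance.
+ # a, b = np.random.randn(4, 8), np.random.randn(5, 8)
+ # d = euclidean_distance_matrix(a, b)  # (4, 5)
+ # assert np.allclose(d[0, 0], np.linalg.norm(a[0] - b[0]))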
740
+ def calculate_top_k(mat, top_k):
741
+ size = mat.shape[0]
742
+ gt_mat = np.expand_dims(np.arange(size), 1).repeat(size, 1)
743
+ bool_mat = (mat == gt_mat)
744
+ correct_vec = False
745
+ top_k_list = []
746
+ for i in range(top_k):
747
+ # print(correct_vec, bool_mat[:, i])
748
+ correct_vec = (correct_vec | bool_mat[:, i])
749
+ # print(correct_vec)
750
+ top_k_list.append(correct_vec[:, None])
751
+ top_k_mat = np.concatenate(top_k_list, axis=1)
752
+ return top_k_mat
753
+
754
+
755
+ def calculate_R_precision(embedding1, embedding2, top_k, sum_all=False):
756
+ dist_mat = euclidean_distance_matrix(embedding1, embedding2)
757
+ argmax = np.argsort(dist_mat, axis=1)
758
+ top_k_mat = calculate_top_k(argmax, top_k)
759
+ if sum_all:
760
+ return top_k_mat.sum(axis=0)
761
+ else:
762
+ return top_k_mat
763
+
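+ # R-precision sketch (illustrative): for matched text/motion embeddings, row i of the distance
+ # matrix should rank column i first, so top-1/2/3 counts are read off the diagonal.
+ # et, em = np.random.randn(32, 512), np.random.randn(32, 512)  # stand-ins for a batch of embeddings
+ # top_k_counts = calculate_R_precision(et, em, top_k=3, sum_all=True)  # -> (3,)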
764
+
765
+ def calculate_matching_score(embedding1, embedding2, sum_all=False):
766
+ assert len(embedding1.shape) == 2
767
+ assert embedding1.shape[0] == embedding2.shape[0]
768
+ assert embedding1.shape[1] == embedding2.shape[1]
769
+
770
+ dist = linalg.norm(embedding1 - embedding2, axis=1)
771
+ if sum_all:
772
+ return dist.sum(axis=0)
773
+ else:
774
+ return dist
775
+
776
+
777
+
778
+ def calculate_activation_statistics(activations):
779
+ """
780
+ Params:
781
+ -- activation: num_samples x dim_feat
782
+ Returns:
783
+ -- mu: dim_feat
784
+ -- sigma: dim_feat x dim_feat
785
+ """
786
+ mu = np.mean(activations, axis=0)
787
+ cov = np.cov(activations, rowvar=False)
788
+ return mu, cov
789
+
790
+
791
+ def calculate_diversity(activation, diversity_times):
792
+ assert len(activation.shape) == 2
793
+ assert activation.shape[0] > diversity_times
794
+ num_samples = activation.shape[0]
795
+
796
+ first_indices = np.random.choice(num_samples, diversity_times, replace=False)
797
+ second_indices = np.random.choice(num_samples, diversity_times, replace=False)
798
+ dist = linalg.norm(activation[first_indices] - activation[second_indices], axis=1)
799
+ return dist.mean()
800
+
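+ # Diversity sketch (illustrative): mean distance between two random subsets of embeddings;
+ # diversity_times must be smaller than the number of samples.
+ # acts = np.random.randn(512, 256)
+ # div = calculate_diversity(acts, diversity_times=300)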
801
+
802
+ def calculate_multimodality(activation, multimodality_times):
803
+ assert len(activation.shape) == 3
804
+ assert activation.shape[1] > multimodality_times
805
+ num_per_sent = activation.shape[1]
806
+
807
+ first_dices = np.random.choice(num_per_sent, multimodality_times, replace=False)
808
+ second_dices = np.random.choice(num_per_sent, multimodality_times, replace=False)
809
+ dist = linalg.norm(activation[:, first_dices] - activation[:, second_dices], axis=2)
810
+ return dist.mean()
811
+
812
+
813
+ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
814
+ """Numpy implementation of the Frechet Distance.
815
+ The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
816
+ and X_2 ~ N(mu_2, C_2) is
817
+ d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
818
+ Stable version by Dougal J. Sutherland.
819
+ Params:
820
+ -- mu1 : Numpy array containing the activations of a layer of the
821
+ inception net (like returned by the function 'get_predictions')
822
+ for generated samples.
823
+ -- mu2 : The sample mean over activations, precalculated on a
824
+ representative data set.
825
+ -- sigma1: The covariance matrix over activations for generated samples.
826
+ -- sigma2: The covariance matrix over activations, precalculated on a
827
+ representative data set.
828
+ Returns:
829
+ -- : The Frechet Distance.
830
+ """
831
+
832
+ mu1 = np.atleast_1d(mu1)
833
+ mu2 = np.atleast_1d(mu2)
834
+
835
+ sigma1 = np.atleast_2d(sigma1)
836
+ sigma2 = np.atleast_2d(sigma2)
837
+
838
+ assert mu1.shape == mu2.shape, \
839
+ 'Training and test mean vectors have different lengths'
840
+ assert sigma1.shape == sigma2.shape, \
841
+ 'Training and test covariances have different dimensions'
842
+
843
+ diff = mu1 - mu2
844
+
845
+ # Product might be almost singular
846
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
847
+ if not np.isfinite(covmean).all():
848
+ msg = ('fid calculation produces singular product; '
849
+ 'adding %s to diagonal of cov estimates') % eps
850
+ print(msg)
851
+ offset = np.eye(sigma1.shape[0]) * eps
852
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
853
+
854
+ # Numerical error might give slight imaginary component
855
+ if np.iscomplexobj(covmean):
856
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
857
+ m = np.max(np.abs(covmean.imag))
858
+ raise ValueError('Imaginary component {}'.format(m))
859
+ covmean = covmean.real
860
+
861
+ tr_covmean = np.trace(covmean)
862
+
863
+ return (diff.dot(diff) + np.trace(sigma1) +
864
+ np.trace(sigma2) - 2 * tr_covmean)
865
+
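+ # Sanity sketch (illustrative): identical activation statistics give an FID of roughly zero.
+ # acts = np.random.randn(256, 64)
+ # mu, cov = calculate_activation_statistics(acts)
+ # assert calculate_frechet_distance(mu, cov, mu, cov) < 1e-3  # ~0 up to numerical error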
866
+
867
+ # directly from omnicontrol
868
+ def calculate_skating_ratio(motions):
869
+ thresh_height = 0.05 # 10
870
+ fps = 20.0
871
+ thresh_vel = 0.50 # 50 cm/s
872
+ avg_window = 5 # frames
873
+
874
+ batch_size = motions.shape[0]
875
+ # 10 left, 11 right foot. XZ plane, y up
876
+ # motions [bs, 22, 3, max_len]
877
+ verts_feet = motions[:, [10, 11], :, :].detach().cpu().numpy() # [bs, 2, 3, max_len]
878
+ verts_feet_plane_vel = np.linalg.norm(verts_feet[:, :, [0, 2], 1:] - verts_feet[:, :, [0, 2], :-1],
879
+ axis=2) * fps # [bs, 2, max_len-1]
880
+ # [bs, 2, max_len-1]
881
+ vel_avg = uniform_filter1d(verts_feet_plane_vel, axis=-1, size=avg_window, mode='constant', origin=0)
882
+
883
+ verts_feet_height = verts_feet[:, :, 1, :] # [bs, 2, max_len]
884
+ # If feet touch ground in adjacent frames
885
+ feet_contact = np.logical_and((verts_feet_height[:, :, :-1] < thresh_height),
886
+ (verts_feet_height[:, :, 1:] < thresh_height)) # [bs, 2, max_len - 1]
887
+ # skate velocity
888
+ skate_vel = feet_contact * vel_avg
889
+
890
+ # it must both skating in the current frame
891
+ skating = np.logical_and(feet_contact, (verts_feet_plane_vel > thresh_vel))
892
+ # and also skate in the windows of frames
893
+ skating = np.logical_and(skating, (vel_avg > thresh_vel))
894
+
895
+ # Both feet slide
896
+ skating = np.logical_or(skating[:, 0, :], skating[:, 1, :]) # [bs, max_len -1]
897
+ skating_ratio = np.sum(skating, axis=1) / skating.shape[1]
898
+
899
+ return skating_ratio, skate_vel
900
+
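+ # Input sketch (illustrative): joints are expected as [bs, 22, 3, max_len] in metres at 20 fps;
+ # the returned skating_ratio holds one value per sequence in the batch.
+ # joints = torch.randn(2, 22, 3, 196)
+ # ratio, vel = calculate_skating_ratio(joints)  # ratio: (2,)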
901
+ # directly from omnicontrol
902
+ def control_l2(motion, hint, hint_mask):
903
+ # motion: b, seq, 22, 3
904
+ # hint: b, seq, 22, 1
905
+ loss = np.linalg.norm((motion - hint) * hint_mask, axis=-1)
906
+ return loss
907
+
908
+ # directly from omnicontrol
909
+ def calculate_trajectory_error(dist_error, mean_err_traj, mask, strict=True):
910
+ ''' dist_error shape [5]: error for each kps in metre
911
+ Two thresholds: 20 cm and 50 cm.
912
+ If the mean error over the sequence exceeds the threshold, the trajectory is counted as a failure.
913
+ return: traj_fail(0.2), traj_fail(0.5), all_kps_fail(0.2), all_kps_fail(0.5), all_mean_err.
914
+ All metrics are already averaged.
915
+ '''
916
+ # mean_err_traj = dist_error.mean(1)
917
+ if strict:
918
+ # Traj fails if any of the key frame fails
919
+ traj_fail_02 = 1.0 - (dist_error <= 0.2).all()
920
+ traj_fail_05 = 1.0 - (dist_error <= 0.5).all()
921
+ else:
922
+ # Traj fails if the mean error of all keyframes more than the threshold
923
+ traj_fail_02 = (mean_err_traj > 0.2)
924
+ traj_fail_05 = (mean_err_traj > 0.5)
925
+ all_fail_02 = (dist_error > 0.2).sum() / mask.sum()
926
+ all_fail_05 = (dist_error > 0.5).sum() / mask.sum()
927
+
928
+ return np.array([traj_fail_02, traj_fail_05, all_fail_02, all_fail_05, dist_error.sum() / mask.sum()])
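+ # Output sketch (illustrative): the five entries are
+ # [traj_fail@0.2, traj_fail@0.5, kps_fail@0.2, kps_fail@0.5, mean location error],
+ # matching how traj_err is aggregated with np.stack(traj_err).mean(0) in the eval loops above.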
utils/evaluators.py ADDED
@@ -0,0 +1,392 @@
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ import math
6
+ from torch.nn.utils.rnn import pack_padded_sequence
7
+ import torch.nn.functional as F
8
+ from utils.glove import POS_enumerator
9
+ import clip
10
+
11
+ #################################################################################
12
+ # Evaluators #
13
+ #################################################################################
14
+ def build_evaluators(dim_pose, dataset_name, dim_movement_enc_hidden, dim_movement_latent, dim_word, dim_pos_ohot, dim_text_hidden,
15
+ dim_coemb_hidden, dim_motion_hidden, checkpoints_dir, device):
16
+ movement_enc = MovementConvEncoder(dim_pose, dim_movement_enc_hidden, dim_movement_latent)
17
+ text_enc = TextEncoderBiGRUCo(word_size=dim_word,
18
+ pos_size=dim_pos_ohot,
19
+ hidden_size=dim_text_hidden,
20
+ output_size=dim_coemb_hidden,
21
+ device=device)
22
+
23
+ motion_enc = MotionEncoderBiGRUCo(input_size=dim_movement_latent,
24
+ hidden_size=dim_motion_hidden,
25
+ output_size=dim_coemb_hidden,
26
+ device=device)
27
+ contrast_model = MotionCLIP(dim_pose)
28
+
29
+ checkpoint = torch.load(os.path.join(checkpoints_dir, dataset_name, 'text_mot_match', 'model', 'finest.tar'),
30
+ map_location=device)
31
+ checkpoint_clip = torch.load(os.path.join(checkpoints_dir, dataset_name, 'text_mot_match_clip', 'model', 'finest.tar'),
32
+ map_location=device)
33
+ movement_enc.load_state_dict(checkpoint['movement_encoder'])
34
+ text_enc.load_state_dict(checkpoint['text_encoder'])
35
+ motion_enc.load_state_dict(checkpoint['motion_encoder'])
36
+ contrast_model.load_state_dict(checkpoint_clip['contrast_model'])
37
+ print('Loading Evaluators')
38
+ return text_enc, motion_enc, movement_enc, contrast_model
39
+
40
+ class Evaluators(object):
41
+
42
+ def __init__(self, dataset_name, device):
43
+ if dataset_name == 't2m':
44
+ dim_pose = 67
45
+ elif dataset_name == 'kit':
46
+ dim_pose = 64
47
+ else:
48
+ raise KeyError('Dataset not Recognized!!!')
49
+
50
+ dim_word = 300
51
+ dim_pos_ohot = len(POS_enumerator)
52
+ dim_motion_hidden = 1024
53
+ dim_movement_enc_hidden = 512
54
+ dim_movement_latent = 512
55
+ dim_text_hidden = 512
56
+ dim_coemb_hidden = 512
57
+ checkpoints_dir = 'checkpoints'
58
+ self.unit_length=4
59
+
60
+ self.text_encoder, self.motion_encoder, self.movement_encoder, self.contrast_model \
61
+ = build_evaluators(dim_pose, dataset_name, dim_movement_enc_hidden, dim_movement_latent, dim_word,
62
+ dim_pos_ohot, dim_text_hidden, dim_coemb_hidden, dim_motion_hidden, checkpoints_dir, device)
63
+ self.device = device
64
+
65
+ self.text_encoder.to(device)
66
+ self.motion_encoder.to(device)
67
+ self.movement_encoder.to(device)
68
+ self.contrast_model.to(device)
69
+
70
+ self.text_encoder.eval()
71
+ self.motion_encoder.eval()
72
+ self.movement_encoder.eval()
73
+ self.contrast_model.eval()
74
+
75
+ def get_co_embeddings(self, word_embs, pos_ohot, cap_lens, captions, motions, m_lens):
76
+ with torch.no_grad():
77
+ word_embs = word_embs.detach().to(self.device).float()
78
+ pos_ohot = pos_ohot.detach().to(self.device).float()
79
+ motions = motions.detach().to(self.device).float()
80
+
81
+ '''clip based'''
82
+ clip_em = self.contrast_model.encode_motion(motions.clone(), m_lens)
83
+ clip_et = self.contrast_model.encode_text(captions)
84
+ clip_em = clip_em / clip_em.norm(dim=1, keepdim=True)
85
+ clip_et = clip_et / clip_et.norm(dim=1, keepdim=True)
86
+
87
+ '''original architecture'''
88
+ align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
89
+ motions = motions[align_idx]
90
+ m_lens = m_lens[align_idx]
91
+
92
+ movements = self.movement_encoder(motions).detach()
93
+ m_lens = m_lens // self.unit_length
94
+ motion_embedding = self.motion_encoder(movements, m_lens)
95
+
96
+ text_embedding = self.text_encoder(word_embs, pos_ohot, cap_lens)
97
+ text_embedding = text_embedding[align_idx]
98
+ return (text_embedding, motion_embedding), (clip_et, clip_em)
99
+
100
+ def get_motion_embeddings(self, motions, m_lens):
101
+ with torch.no_grad():
102
+ motions = motions.detach().to(self.device).float()
103
+ '''clip based'''
104
+ clip_em = self.contrast_model.encode_motion(motions.clone(), m_lens)
105
+ clip_em = clip_em / clip_em.norm(dim=1, keepdim=True)
106
+
107
+ '''original architecture'''
108
+ align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
109
+ motions = motions[align_idx]
110
+ m_lens = m_lens[align_idx]
111
+
112
+ movements = self.movement_encoder(motions).detach()
113
+ m_lens = m_lens // self.unit_length
114
+ motion_embedding = self.motion_encoder(movements, m_lens)
115
+ return motion_embedding, clip_em
116
+
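+ # Usage sketch (illustrative; shapes follow the t2m setting with dim_pose = 67):
+ # evaluator = Evaluators('t2m', device='cuda')
+ # (et, em), (clip_et, clip_em) = evaluator.get_co_embeddings(
+ #     word_embs, pos_ohot, cap_lens, captions, motions, m_lens)
+ # The dot product of clip_et[i] and clip_em[i] is the per-sample CLIP score used in eval_utils.py.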
117
+ #################################################################################
118
+ # Inner Architectures #
119
+ #################################################################################
120
+ def init_weight(m):
121
+ if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose1d):
122
+ nn.init.xavier_normal_(m.weight)
123
+ if m.bias is not None:
124
+ nn.init.constant_(m.bias, 0)
125
+
126
+
127
+ class PositionalEncoding(nn.Module):
128
+ def __init__(self, d_model, max_len=300):
129
+ super(PositionalEncoding, self).__init__()
130
+ pe = torch.zeros(max_len, d_model)
131
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
132
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
133
+ pe[:, 0::2] = torch.sin(position * div_term)
134
+ pe[:, 1::2] = torch.cos(position * div_term)
135
+ self.register_buffer('pe', pe)
136
+
137
+ def forward(self, pos):
138
+ return self.pe[pos]
139
+
140
+
141
+ class PositionalEncodingCLIP(nn.Module):
142
+ def __init__(self, d_model, dropout=0.0, max_len=5000):
143
+ super(PositionalEncodingCLIP, self).__init__()
144
+ self.dropout = nn.Dropout(p=dropout)
145
+ pe = torch.zeros(max_len, d_model)
146
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
147
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
148
+ pe[:, 0::2] = torch.sin(position * div_term)
149
+ pe[:, 1::2] = torch.cos(position * div_term)
150
+ self.register_buffer('pe', pe)
151
+
152
+ def forward(self, x):
153
+ x = x + self.pe[:x.shape[1], :].unsqueeze(0)
154
+ return self.dropout(x)
155
+
156
+
157
+ def no_grad(nets):
158
+ if not isinstance(nets, list):
159
+ nets = [nets]
160
+ for net in nets:
161
+ if net is not None:
162
+ for param in net.parameters():
163
+ param.requires_grad = False
164
+
165
+ def lengths_to_mask(lengths, max_len):
166
+ mask = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths.unsqueeze(1)
167
+ return mask
168
+
169
+
170
+ class MovementConvEncoder(nn.Module):
171
+ def __init__(self, input_size, hidden_size, output_size):
172
+ super(MovementConvEncoder, self).__init__()
173
+ self.main = nn.Sequential(
174
+ nn.Conv1d(input_size, hidden_size, 4, 2, 1),
175
+ nn.Dropout(0.2, inplace=True),
176
+ nn.LeakyReLU(0.2, inplace=True),
177
+ nn.Conv1d(hidden_size, output_size, 4, 2, 1),
178
+ nn.Dropout(0.2, inplace=True),
179
+ nn.LeakyReLU(0.2, inplace=True),
180
+ )
181
+ self.out_net = nn.Linear(output_size, output_size)
182
+ self.main.apply(init_weight)
183
+ self.out_net.apply(init_weight)
184
+
185
+ def forward(self, inputs):
186
+ inputs = inputs.permute(0, 2, 1)
187
+ outputs = self.main(inputs).permute(0, 2, 1)
188
+ return self.out_net(outputs)
189
+
190
+
191
+ class MovementConvDecoder(nn.Module):
192
+ def __init__(self, input_size, hidden_size, output_size):
193
+ super(MovementConvDecoder, self).__init__()
194
+ self.main = nn.Sequential(
195
+ nn.ConvTranspose1d(input_size, hidden_size, 4, 2, 1),
196
+ nn.LeakyReLU(0.2, inplace=True),
197
+ nn.ConvTranspose1d(hidden_size, output_size, 4, 2, 1),
198
+ nn.LeakyReLU(0.2, inplace=True),
199
+ )
200
+ self.out_net = nn.Linear(output_size, output_size)
201
+
202
+ self.main.apply(init_weight)
203
+ self.out_net.apply(init_weight)
204
+
205
+ def forward(self, inputs):
206
+ inputs = inputs.permute(0, 2, 1)
207
+ outputs = self.main(inputs).permute(0, 2, 1)
208
+ return self.out_net(outputs)
209
+
210
+ class TextEncoderBiGRUCo(nn.Module):
211
+ def __init__(self, word_size, pos_size, hidden_size, output_size, device):
212
+ super(TextEncoderBiGRUCo, self).__init__()
213
+ self.device = device
214
+
215
+ self.pos_emb = nn.Linear(pos_size, word_size)
216
+ self.input_emb = nn.Linear(word_size, hidden_size)
217
+ self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
218
+ self.output_net = nn.Sequential(
219
+ nn.Linear(hidden_size * 2, hidden_size),
220
+ nn.LayerNorm(hidden_size),
221
+ nn.LeakyReLU(0.2, inplace=True),
222
+ nn.Linear(hidden_size, output_size)
223
+ )
224
+
225
+ self.input_emb.apply(init_weight)
226
+ self.pos_emb.apply(init_weight)
227
+ self.output_net.apply(init_weight)
228
+ self.hidden_size = hidden_size
229
+ self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True))
230
+
231
+ def forward(self, word_embs, pos_onehot, cap_lens):
232
+ num_samples = word_embs.shape[0]
233
+
234
+ pos_embs = self.pos_emb(pos_onehot)
235
+ inputs = word_embs + pos_embs
236
+ input_embs = self.input_emb(inputs)
237
+ hidden = self.hidden.repeat(1, num_samples, 1)
238
+
239
+ cap_lens = cap_lens.data.tolist()
240
+ emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True)
241
+
242
+ gru_seq, gru_last = self.gru(emb, hidden)
243
+
244
+ gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
245
+
246
+ return self.output_net(gru_last)
247
+
248
+
249
+ class MotionEncoderBiGRUCo(nn.Module):
250
+ def __init__(self, input_size, hidden_size, output_size, device):
251
+ super(MotionEncoderBiGRUCo, self).__init__()
252
+ self.device = device
253
+
254
+ self.input_emb = nn.Linear(input_size, hidden_size)
255
+ self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
256
+ self.output_net = nn.Sequential(
257
+ nn.Linear(hidden_size*2, hidden_size),
258
+ nn.LayerNorm(hidden_size),
259
+ nn.LeakyReLU(0.2, inplace=True),
260
+ nn.Linear(hidden_size, output_size)
261
+ )
262
+
263
+ self.input_emb.apply(init_weight)
264
+ self.output_net.apply(init_weight)
265
+ self.hidden_size = hidden_size
266
+ self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True))
267
+
268
+ def forward(self, inputs, m_lens):
269
+ num_samples = inputs.shape[0]
270
+
271
+ input_embs = self.input_emb(inputs)
272
+ hidden = self.hidden.repeat(1, num_samples, 1)
273
+
274
+ cap_lens = m_lens.data.tolist()
275
+ emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True)
276
+
277
+ gru_seq, gru_last = self.gru(emb, hidden)
278
+
279
+ gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
280
+
281
+ return self.output_net(gru_last)
282
+
283
+
284
+ class MotionEncoder(nn.Module):
285
+ def __init__(self, in_dim, latent_dim, ff_size, num_layers, num_heads, dropout, activation):
286
+ super().__init__()
287
+ self.input_feats = in_dim
288
+ self.latent_dim = latent_dim
289
+ self.ff_size = ff_size
290
+ self.num_layers = num_layers
291
+ self.num_heads = num_heads
292
+ self.dropout = dropout
293
+ self.activation = activation
294
+
295
+ self.query_token = nn.Parameter(torch.randn(1, self.latent_dim))
296
+
297
+ self.embed_motion = nn.Linear(self.input_feats, self.latent_dim)
298
+ self.sequence_pos_encoder = PositionalEncodingCLIP(self.latent_dim, self.dropout, max_len=2000)
299
+
300
+ seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim,
301
+ nhead=self.num_heads,
302
+ dim_feedforward=self.ff_size,
303
+ dropout=self.dropout,
304
+ activation=self.activation,)
305
+ self.transformer = nn.TransformerEncoder(seqTransEncoderLayer, num_layers=self.num_layers)
306
+ self.out_ln = nn.LayerNorm(self.latent_dim)
307
+ self.out = nn.Linear(self.latent_dim, 512)
308
+
309
+
310
+ def forward(self, motion, padding_mask):
311
+ B, T, D = motion.shape
312
+
313
+ x_emb = self.embed_motion(motion)
314
+
315
+ emb = torch.cat([self.query_token[torch.zeros(B, dtype=torch.long, device=motion.device)][:,None], x_emb], dim=1)
316
+
317
+ padding_mask = torch.cat([torch.zeros_like(padding_mask[:, 0:1]), padding_mask], dim=1)
318
+
319
+ h = self.sequence_pos_encoder(emb)
320
+ h = h.permute(1, 0, 2)
321
+ h = self.transformer(h, src_key_padding_mask=padding_mask)
322
+ h = h.permute(1, 0, 2)
323
+ h = self.out_ln(h)
324
+ motion_emb = self.out(h[:,0])
325
+
326
+ return motion_emb
327
+
328
+
329
+ class MotionCLIP(nn.Module):
330
+ def __init__(self, in_dim):
331
+ super().__init__()
332
+ self.motion_encoder = MotionEncoder(in_dim, 512, 1024, 8, 8, 0.2, 'gelu')
333
+ clip_model, _ = clip.load("ViT-B/16", device="cpu", jit=False)
334
+ self.token_embedding = clip_model.token_embedding
335
+ self.positional_embedding = clip_model.positional_embedding
336
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
337
+ no_grad(self.token_embedding)
338
+
339
+ textTransEncoderLayer = nn.TransformerEncoderLayer(
340
+ d_model=512,
341
+ nhead=8,
342
+ dim_feedforward=1024,
343
+ dropout=0.2,
344
+ activation="gelu",)
345
+ self.textTransEncoder = nn.TransformerEncoder(
346
+ textTransEncoderLayer,
347
+ num_layers=8)
348
+ self.text_ln = nn.LayerNorm(512)
349
+ self.out = nn.Linear(512, 512)
350
+
351
+ def encode_motion(self, motion, m_lens):
352
+ seq_len = motion.shape[1]
353
+ padding_mask = ~lengths_to_mask(m_lens, seq_len)
354
+ motion_embedding = self.motion_encoder(motion, padding_mask.to(motion.device))
355
+ return motion_embedding
356
+
357
+ def encode_text(self, text):
358
+ device = next(self.parameters()).device
359
+
360
+ with torch.no_grad():
361
+ text = clip.tokenize(text, truncate=True).to(device)
362
+ x = self.token_embedding(text).float()
363
+ pe_tokens = x + self.positional_embedding.float()
364
+ pe_tokens = pe_tokens.permute(1,0,2)
365
+ out = self.textTransEncoder(pe_tokens)
366
+ out = out.permute(1, 0, 2)
367
+ out = self.text_ln(out)
368
+
369
+ out = out[torch.arange(x.shape[0]), text.argmax(dim=-1)]
370
+ out = self.out(out)
371
+ return out
372
+
373
+ def forward(self, motion, m_lens, text):
374
+ motion_features = self.encode_motion(motion, m_lens)
375
+ text_features = self.encode_text(text)
376
+
377
+ motion_features = motion_features / motion_features.norm(dim=1, keepdim=True)
378
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
379
+
380
+ logit_scale = self.logit_scale.exp()
381
+ logits_per_motion = logit_scale * motion_features @ text_features.t()
382
+ logits_per_text = logits_per_motion.t()
383
+ return logits_per_motion, logits_per_text
384
+
385
+ def forward_loss(self, motion, m_lens, text):
386
+ logits_per_motion, logits_per_text = self.forward(motion, m_lens, text)
387
+ labels = torch.arange(len(logits_per_motion)).to(logits_per_motion.device)
388
+
389
+ image_loss = F.cross_entropy(logits_per_motion, labels)
390
+ text_loss = F.cross_entropy(logits_per_text, labels)
391
+ loss = (image_loss + text_loss) / 2
392
+ return loss
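
For reference, a minimal sketch of how the paired embeddings returned by `get_co_embeddings` are typically consumed downstream to score text-motion matching (matching distance and R-precision, as in the standard T2M evaluation protocol). The helper names below are illustrative and not part of this repository.

```python
# Hypothetical post-processing of (text_embedding, motion_embedding) pairs.
import numpy as np

def euclidean_distance_matrix(text_emb, motion_emb):
    # (B, D), (B, D) -> (B, B) pairwise Euclidean distances
    d1 = -2 * np.dot(text_emb, motion_emb.T)
    d2 = np.sum(np.square(text_emb), axis=1, keepdims=True)
    d3 = np.sum(np.square(motion_emb), axis=1)
    return np.sqrt(np.maximum(d1 + d2 + d3, 0.0))

def top_k_accuracy(dist_mat, top_k=3):
    # fraction of captions whose ground-truth motion is among the k nearest embeddings
    order = np.argsort(dist_mat, axis=1)
    correct = order == np.arange(len(dist_mat))[:, None]
    return np.array([correct[:, :k + 1].any(axis=1).mean() for k in range(top_k)])
```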
utils/glove.py ADDED
@@ -0,0 +1,83 @@
+ import numpy as np
+ import pickle
+ from os.path import join as pjoin
+
+ #################################################################################
+ #                                    GloVe                                      #
+ #################################################################################
+ POS_enumerator = {
+     'VERB': 0,
+     'NOUN': 1,
+     'DET': 2,
+     'ADP': 3,
+     'NUM': 4,
+     'AUX': 5,
+     'PRON': 6,
+     'ADJ': 7,
+     'ADV': 8,
+     'Loc_VIP': 9,
+     'Body_VIP': 10,
+     'Obj_VIP': 11,
+     'Act_VIP': 12,
+     'Desc_VIP': 13,
+     'OTHER': 14,
+ }
+
+ Loc_list = ('left', 'right', 'clockwise', 'counterclockwise', 'anticlockwise', 'forward', 'back', 'backward',
+             'up', 'down', 'straight', 'curve')
+
+ Body_list = ('arm', 'chin', 'foot', 'feet', 'face', 'hand', 'mouth', 'leg', 'waist', 'eye', 'knee', 'shoulder', 'thigh')
+
+ Obj_List = ('stair', 'dumbbell', 'chair', 'window', 'floor', 'car', 'ball', 'handrail', 'baseball', 'basketball')
+
+ Act_list = ('walk', 'run', 'swing', 'pick', 'bring', 'kick', 'put', 'squat', 'throw', 'hop', 'dance', 'jump', 'turn',
+             'stumble', 'dance', 'stop', 'sit', 'lift', 'lower', 'raise', 'wash', 'stand', 'kneel', 'stroll',
+             'rub', 'bend', 'balance', 'flap', 'jog', 'shuffle', 'lean', 'rotate', 'spin', 'spread', 'climb')
+
+ Desc_list = ('slowly', 'carefully', 'fast', 'careful', 'slow', 'quickly', 'happy', 'angry', 'sad', 'happily',
+              'angrily', 'sadly')
+
+ VIP_dict = {
+     'Loc_VIP': Loc_list,
+     'Body_VIP': Body_list,
+     'Obj_VIP': Obj_List,
+     'Act_VIP': Act_list,
+     'Desc_VIP': Desc_list,
+ }
+
+
+ class GloVe(object):
+     def __init__(self, meta_root, prefix):
+         vectors = np.load(pjoin(meta_root, '%s_data.npy' % prefix))
+         words = pickle.load(open(pjoin(meta_root, '%s_words.pkl' % prefix), 'rb'))
+         self.word2idx = pickle.load(open(pjoin(meta_root, '%s_idx.pkl' % prefix), 'rb'))
+         self.word2vec = {w: vectors[self.word2idx[w]] for w in words}
+
+     def _get_pos_ohot(self, pos):
+         pos_vec = np.zeros(len(POS_enumerator))
+         if pos in POS_enumerator:
+             pos_vec[POS_enumerator[pos]] = 1
+         else:
+             pos_vec[POS_enumerator['OTHER']] = 1
+         return pos_vec
+
+     def __len__(self):
+         return len(self.word2vec)
+
+     def __getitem__(self, item):
+         word, pos = item.split('/')
+         if word in self.word2vec:
+             word_vec = self.word2vec[word]
+             vip_pos = None
+             for key, values in VIP_dict.items():
+                 if word in values:
+                     vip_pos = key
+                     break
+             if vip_pos is not None:
+                 pos_vec = self._get_pos_ohot(vip_pos)
+             else:
+                 pos_vec = self._get_pos_ohot(pos)
+         else:
+             word_vec = self.word2vec['unk']
+             pos_vec = self._get_pos_ohot('OTHER')
+         return word_vec, pos_vec
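
A quick usage sketch for the `GloVe` wrapper above; the `meta_root` directory and the `our_vab` prefix are assumptions and must match the `<prefix>_data.npy` / `<prefix>_words.pkl` / `<prefix>_idx.pkl` files on disk.

```python
# Assumed layout: ./glove/our_vab_data.npy, ./glove/our_vab_words.pkl, ./glove/our_vab_idx.pkl
w_vectorizer = GloVe('./glove', 'our_vab')
word_vec, pos_vec = w_vectorizer['person/NOUN']   # GloVe vector and a POS one-hot
print(word_vec.shape, pos_vec.shape)              # pos_vec has len(POS_enumerator) == 15 entries
```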
utils/mesh_mean_std/t2m/mesh_mean.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:be690561cf95997a0a69fc8b00440d0e1af2eb5b0ac1c4f3c5a45a5b0666f3bf
+ size 140
utils/mesh_mean_std/t2m/mesh_std.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:19f56c7aaa752639b39eb76742b5ee1b79fb71931b8412dd84ca7312398293b4
+ size 140
utils/motion_process.py ADDED
@@ -0,0 +1,429 @@
1
+ import torch
2
+ import numpy as np
3
+ import matplotlib
4
+ import matplotlib.pyplot as plt
5
+ from matplotlib.animation import FuncAnimation, writers
6
+ from mpl_toolkits.mplot3d.art3d import Poly3DCollection
7
+ import mpl_toolkits.mplot3d.axes3d as p3
8
+ import os
9
+ import io
10
+ try:
11
+ from PIL import Image
12
+ except ImportError:
13
+ Image = None
14
+ try:
15
+ import imageio
16
+ except ImportError:
17
+ imageio = None
18
+
19
+ #################################################################################
20
+ # Data Params #
21
+ #################################################################################
22
+ kit_kinematic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]]
23
+ t2m_kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]]
24
+ t2m_left_hand_chain = [[20, 22, 23, 24], [20, 34, 35, 36], [20, 25, 26, 27], [20, 31, 32, 33], [20, 28, 29, 30]]
25
+ t2m_right_hand_chain = [[21, 43, 44, 45], [21, 46, 47, 48], [21, 40, 41, 42], [21, 37, 38, 39], [21, 49, 50, 51]]
26
+
27
+ kit_raw_offsets = np.array(
28
+ [[0, 0, 0], [0, 1, 0], [0, 1, 0], [0, 1, 0], [0, 1, 0],
29
+ [1, 0, 0], [0, -1, 0], [0, -1, 0], [-1, 0, 0], [0, -1, 0],
30
+ [0, -1, 0], [1, 0, 0], [0, -1, 0], [0, -1, 0], [0, 0, 1],
31
+ [0, 0, 1], [-1, 0, 0], [0, -1, 0], [0, -1, 0], [0, 0, 1],
32
+ [0, 0, 1]])
33
+ t2m_raw_offsets = np.array([[0,0,0], [1,0,0], [-1,0,0], [0,1,0], [0,-1,0],
34
+ [0,-1,0], [0,1,0], [0,-1,0], [0,-1,0], [0,1,0],
35
+ [0,0,1], [0,0,1], [0,1,0], [1,0,0], [-1,0,0],
36
+ [0,0,1], [0,-1,0], [0,-1,0], [0,-1,0], [0,-1,0],
37
+ [0,-1,0], [0,-1,0]])
38
+
39
+ #################################################################################
40
+ # Joints Revert #
41
+ #################################################################################
42
+ def qinv(q):
43
+ assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
44
+ mask = torch.ones_like(q)
45
+ mask[..., 1:] = -mask[..., 1:]
46
+ return q * mask
47
+
48
+
49
+ def qrot(q, v):
50
+ """
51
+ Rotate vector(s) v about the rotation described by quaternion(s) q.
52
+ Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
53
+ where * denotes any number of dimensions.
54
+ Returns a tensor of shape (*, 3).
55
+ """
56
+ assert q.shape[-1] == 4
57
+ assert v.shape[-1] == 3
58
+ assert q.shape[:-1] == v.shape[:-1]
59
+
60
+ original_shape = list(v.shape)
61
+ # print(q.shape)
62
+ q = q.contiguous().view(-1, 4)
63
+ v = v.contiguous().view(-1, 3)
64
+
65
+ qvec = q[:, 1:]
66
+ uv = torch.cross(qvec, v, dim=1)
67
+ uuv = torch.cross(qvec, uv, dim=1)
68
+ return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape)
69
+
70
+
71
+ def recover_root_rot_pos(data):
72
+ rot_vel = data[..., 0]
73
+ r_rot_ang = torch.zeros_like(rot_vel).to(data.device)
74
+ '''Get Y-axis rotation from rotation velocity'''
75
+ r_rot_ang[..., 1:] = rot_vel[..., :-1]
76
+ r_rot_ang = torch.cumsum(r_rot_ang, dim=-1)
77
+
78
+ r_rot_quat = torch.zeros(data.shape[:-1] + (4,)).to(data.device)
79
+ r_rot_quat[..., 0] = torch.cos(r_rot_ang)
80
+ r_rot_quat[..., 2] = torch.sin(r_rot_ang)
81
+
82
+ r_pos = torch.zeros(data.shape[:-1] + (3,)).to(data.device)
83
+ r_pos[..., 1:, [0, 2]] = data[..., :-1, 1:3]
84
+ '''Add Y-axis rotation to root position'''
85
+ r_pos = qrot(qinv(r_rot_quat), r_pos)
86
+
87
+ r_pos = torch.cumsum(r_pos, dim=-2)
88
+
89
+ r_pos[..., 1] = data[..., 3]
90
+ return r_rot_quat, r_pos
91
+
92
+
93
+ def recover_from_ric(data, joints_num):
94
+ r_rot_quat, r_pos = recover_root_rot_pos(data)
95
+ positions = data[..., 4:(joints_num - 1) * 3 + 4]
96
+ positions = positions.view(positions.shape[:-1] + (-1, 3))
97
+
98
+ '''Add Y-axis rotation to local joints'''
99
+ positions = qrot(qinv(r_rot_quat[..., None, :]).expand(positions.shape[:-1] + (4,)), positions)
100
+
101
+ '''Add root XZ to joints'''
102
+ positions[..., 0] += r_pos[..., 0:1]
103
+ positions[..., 2] += r_pos[..., 2:3]
104
+
105
+ '''Concate root and joints'''
106
+ positions = torch.cat([r_pos.unsqueeze(-2), positions], dim=-2)
107
+
108
+ return positions
109
+
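
A small sketch of how `recover_from_ric` is typically called. The 263-dimensional, 22-joint HumanML3D feature layout is an assumption based on the standard dataset, and `features` stands in for a denormalized model output.

```python
# features: (batch, frames, 263) HumanML3D-style motion features after de-normalization
features = torch.randn(1, 196, 263)
joints = recover_from_ric(features, joints_num=22)   # -> (1, 196, 22, 3) global joint positions
```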
110
+ #################################################################################
111
+ # Motion Plotting #
112
+ #################################################################################
113
+ def plot_3d_motion(save_path, kinematic_tree, joints, title, figsize=(10, 10), fps=120, radius=4, save_frames_dir=None):
114
+ # Ensure Agg backend is used (already set at module level, but ensure it's active)
115
+ if matplotlib.get_backend() != 'Agg':
116
+ matplotlib.use('Agg')
117
+
118
+ title_sp = title.split(' ')
119
+ if len(title_sp) > 20:
120
+ title = '\n'.join([' '.join(title_sp[:10]), ' '.join(title_sp[10:20]), ' '.join(title_sp[20:])])
121
+ elif len(title_sp) > 10:
122
+ title = '\n'.join([' '.join(title_sp[:10]), ' '.join(title_sp[10:])])
123
+
124
+ def init():
125
+ ax.set_xlim3d([-radius / 2, radius / 2])
126
+ ax.set_ylim3d([0, radius])
127
+ ax.set_zlim3d([0, radius])
128
+ fig.suptitle(title, fontsize=20)
129
+ ax.grid(False)
130
+
131
+ def plot_xzPlane(minx, maxx, miny, minz, maxz):
132
+ verts = [
133
+ [minx, miny, minz],
134
+ [minx, miny, maxz],
135
+ [maxx, miny, maxz],
136
+ [maxx, miny, minz]
137
+ ]
138
+ xz_plane = Poly3DCollection([verts])
139
+ xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5))
140
+ ax.add_collection3d(xz_plane)
141
+
142
+ # Ensure joints is in the correct format: (seq_len, num_joints, 3)
143
+ joints = np.array(joints)
144
+ if joints.ndim == 3:
145
+ # Already in (seq_len, num_joints, 3) format
146
+ data = joints.copy()
147
+ elif joints.ndim == 2:
148
+ # If 2D, reshape to (seq_len, num_joints, 3)
149
+ # Assume it's (seq_len * num_joints, 3) or (seq_len, num_joints * 3)
150
+ if joints.shape[1] == 3:
151
+ # (seq_len * num_joints, 3) - need to infer num_joints
152
+ # For t2m, we expect 22 joints
153
+ num_joints = 22
154
+ seq_len = joints.shape[0] // num_joints
155
+ data = joints.reshape(seq_len, num_joints, 3)
156
+ else:
157
+ # (seq_len, num_joints * 3)
158
+ num_joints = joints.shape[1] // 3
159
+ data = joints.reshape(len(joints), num_joints, 3)
160
+ else:
161
+ raise ValueError(f"Invalid joints shape: {joints.shape}, expected (seq_len, num_joints, 3)")
162
+
163
+ # Check if data is valid
164
+ if data.size == 0:
165
+ raise ValueError("Invalid motion data: data is empty")
166
+
167
+ fig = plt.figure(figsize=figsize)
168
+ ax = fig.add_subplot(111, projection='3d')
169
+ init()
170
+ MINS = data.min(axis=0).min(axis=0)
171
+ MAXS = data.max(axis=0).max(axis=0)
172
+ colors = ['red', 'blue', 'black', 'red', 'blue',
173
+ 'darkblue', 'darkblue', 'darkblue', 'darkblue', 'darkblue',
174
+ 'darkred', 'darkred', 'darkred', 'darkred', 'darkred']
175
+ frame_number = data.shape[0]
176
+
177
+ height_offset = MINS[1]
178
+ data[:, :, 1] -= height_offset
179
+ trajec = data[:, 0, [0, 2]]
180
+
181
+ data[..., 0] -= data[:, 0:1, 0]
182
+ data[..., 2] -= data[:, 0:1, 2]
183
+
184
+ # Recompute bounds after centering
185
+ MINS = data.min(axis=0).min(axis=0)
186
+ MAXS = data.max(axis=0).max(axis=0)
187
+ # Add some padding
188
+ center = (MINS + MAXS) / 2
189
+ ranges = MAXS - MINS
190
+ # Ensure we have a minimum range to avoid issues with very small or zero ranges
191
+ min_range = 0.1 # Minimum range for each axis
192
+ ranges = np.maximum(ranges, min_range)
193
+ max_range = max(ranges) * 1.2 # 20% padding
194
+ plot_mins = center - max_range / 2
195
+ plot_maxs = center + max_range / 2
196
+
197
+
198
+ def update(index):
199
+ # Clear axes properly
200
+ ax.cla()
201
+ # Reapply title
202
+ fig.suptitle(title, fontsize=20)
203
+ # Reapply view settings and limits based on actual data bounds
204
+ ax.set_xlim3d([plot_mins[0], plot_maxs[0]])
205
+ ax.set_ylim3d([plot_mins[1], plot_maxs[1]])
206
+ ax.set_zlim3d([plot_mins[2], plot_maxs[2]])
207
+ ax.view_init(elev=120, azim=-90)
208
+ ax.dist = 7.5
209
+ ax.grid(False)
210
+ plot_xzPlane(plot_mins[0] - trajec[index, 0], plot_maxs[0] - trajec[index, 0], 0,
211
+ plot_mins[2] - trajec[index, 1], plot_maxs[2] - trajec[index, 1])
212
+
213
+ if index > 1:
214
+ ax.plot3D(trajec[:index, 0] - trajec[index, 0], np.zeros_like(trajec[:index, 0]),
215
+ trajec[:index, 1] - trajec[index, 1], linewidth=1.0,
216
+ color='blue')
217
+
218
+ for i, (chain, color) in enumerate(zip(kinematic_tree, colors)):
219
+ if i < len(colors):
220
+ if i < 5:
221
+ linewidth = 4.0
222
+ else:
223
+ linewidth = 2.0
224
+ # Ensure chain indices are valid
225
+ valid_chain = [idx for idx in chain if idx < data.shape[1]]
226
+ if len(valid_chain) > 1:
227
+ ax.plot3D(data[index, valid_chain, 0], data[index, valid_chain, 1], data[index, valid_chain, 2],
228
+ linewidth=linewidth, color=color)
229
+
230
+ plt.axis('off')
231
+ ax.set_xticklabels([])
232
+ ax.set_yticklabels([])
233
+ ax.set_zticklabels([])
234
+
235
+ def render_frame(frame_idx):
236
+ """Render a single frame and return as PIL Image"""
237
+ # Create a fresh figure for each frame to avoid 3D projection issues
238
+ frame_fig = plt.figure(figsize=figsize, facecolor='white')
239
+ frame_ax = frame_fig.add_subplot(111, projection='3d')
240
+ frame_fig.suptitle(title, fontsize=20)
241
+
242
+ # Set limits and view
243
+ frame_ax.set_xlim3d([plot_mins[0], plot_maxs[0]])
244
+ frame_ax.set_ylim3d([plot_mins[1], plot_maxs[1]])
245
+ frame_ax.set_zlim3d([plot_mins[2], plot_maxs[2]])
246
+ frame_ax.view_init(elev=120, azim=-90)
247
+ frame_ax.dist = 7.5
248
+ frame_ax.grid(False)
249
+
250
+ # Plot ground plane
251
+ minx = plot_mins[0] - trajec[frame_idx, 0]
252
+ maxx = plot_maxs[0] - trajec[frame_idx, 0]
253
+ minz = plot_mins[2] - trajec[frame_idx, 1]
254
+ maxz = plot_maxs[2] - trajec[frame_idx, 1]
255
+ verts = [[minx, 0, minz], [minx, 0, maxz], [maxx, 0, maxz], [maxx, 0, minz]]
256
+ xz_plane = Poly3DCollection([verts])
257
+ xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5))
258
+ frame_ax.add_collection3d(xz_plane)
259
+
260
+ # Plot trajectory
261
+ if frame_idx > 1:
262
+ frame_ax.plot3D(trajec[:frame_idx, 0] - trajec[frame_idx, 0], np.zeros_like(trajec[:frame_idx, 0]),
263
+ trajec[:frame_idx, 1] - trajec[frame_idx, 1], linewidth=1.0,
264
+ color='blue')
265
+
266
+ # Plot skeleton
267
+ for i, (chain, color) in enumerate(zip(kinematic_tree, colors)):
268
+ if i < len(colors):
269
+ if i < 5:
270
+ linewidth = 4.0
271
+ else:
272
+ linewidth = 2.0
273
+ valid_chain = [idx for idx in chain if idx < data.shape[1]]
274
+ if len(valid_chain) > 1:
275
+ frame_ax.plot3D(data[frame_idx, valid_chain, 0], data[frame_idx, valid_chain, 1],
276
+ data[frame_idx, valid_chain, 2], linewidth=linewidth, color=color)
277
+
278
+ plt.axis('off')
279
+ frame_ax.set_xticklabels([])
280
+ frame_ax.set_yticklabels([])
281
+ frame_ax.set_zticklabels([])
282
+
283
+ # Convert to image - copy data so it's independent of the buffer
284
+ buf = io.BytesIO()
285
+ frame_fig.savefig(buf, format='png', dpi=100, bbox_inches='tight', facecolor='white', edgecolor='none')
286
+ buf.seek(0)
287
+ # Create image and convert to RGB for GIF compatibility
288
+ img = Image.open(buf)
289
+ if img.mode != 'RGB':
290
+ img = img.convert('RGB')
291
+ img_copy = img.copy() # Create a copy that's independent of the buffer
292
+ buf.close()
293
+ plt.close(frame_fig)
294
+ return img_copy
295
+
296
+ # Always use frame-by-frame rendering for reliability (works with both GIF and MP4)
297
+ # This ensures 3D plots render correctly
298
+ actual_path = save_path
299
+
300
+ # Note: We'll use frame-by-frame rendering for both MP4 and GIF
301
+ # imageio-ffmpeg will be used for MP4 if available
302
+
303
+ # Use frame-by-frame rendering (works reliably for GIF)
304
+ if Image is None:
305
+ raise RuntimeError("PIL/Pillow is required for GIF generation. Please install: pip install Pillow")
306
+
307
+ # Use provided frames directory or create one for debugging
308
+ frames_dir = save_frames_dir
309
+ if frames_dir is not None:
310
+ os.makedirs(frames_dir, exist_ok=True)
311
+ print(f"Saving individual frames to: {frames_dir}")
312
+
313
+ frames = []
314
+ print(f"Rendering {frame_number} frames...")
315
+ for i in range(frame_number):
316
+ if (i + 1) % 20 == 0:
317
+ print(f" Frame {i+1}/{frame_number}")
318
+ frame_img = render_frame(i)
319
+ frames.append(frame_img)
320
+
321
+ # Save individual frame as PNG for debugging
322
+ if frames_dir is not None:
323
+ frame_path = os.path.join(frames_dir, f"frame_{i:04d}.png")
324
+ frame_img.save(frame_path)
325
+ if i == 0 or i == frame_number - 1:
326
+ print(f" Saved frame {i} to {frame_path}")
327
+
328
+ # Save video - prefer MP4 if imageio-ffmpeg is available, otherwise GIF
329
+ if len(frames) > 0:
330
+ # Ensure all frames are in the same mode and size
331
+ frames_rgb = []
332
+ first_frame = frames[0]
333
+ if first_frame.mode != 'RGB':
334
+ first_frame = first_frame.convert('RGB')
335
+
336
+ # Get size from first frame
337
+ target_size = first_frame.size
338
+ frames_rgb.append(first_frame)
339
+
340
+ # Convert and resize all other frames to match
341
+ for frame in frames[1:]:
342
+ if frame.mode != 'RGB':
343
+ frame = frame.convert('RGB')
344
+ # Ensure all frames are the same size
345
+ if frame.size != target_size:
346
+ frame = frame.resize(target_size, Image.Resampling.LANCZOS)
347
+ frames_rgb.append(frame)
348
+
349
+ # Convert PIL Images to numpy arrays for imageio
350
+ frame_arrays = [np.array(frame) for frame in frames_rgb]
351
+
352
+ # Try to save as MP4 first (better quality and compatibility)
353
+ if actual_path.endswith('.mp4'):
354
+ if imageio is not None:
355
+ try:
356
+ # Use imageio-ffmpeg for MP4 (automatically uses ffmpeg if available)
357
+ imageio.mimsave(actual_path, frame_arrays, fps=fps, codec='libx264', quality=8)
358
+ print(f"Saved {len(frames_rgb)} frames to MP4 using imageio-ffmpeg: {actual_path}")
359
+ except Exception as e:
360
+ print(f"Error saving MP4 with imageio-ffmpeg: {e}")
361
+ print("Falling back to GIF format...")
362
+ # Fall back to GIF
363
+ base_path = os.path.splitext(actual_path)[0]
364
+ actual_path = base_path + '.gif'
365
+ if imageio is not None:
366
+ try:
367
+ imageio.mimsave(actual_path, frame_arrays, duration=1.0/fps, loop=0)
368
+ print(f"Saved {len(frames_rgb)} frames to GIF using imageio: {actual_path}")
369
+ except Exception as e2:
370
+ print(f"imageio GIF failed, using PIL: {e2}")
371
+ frames_rgb[0].save(
372
+ actual_path,
373
+ save_all=True,
374
+ append_images=frames_rgb[1:],
375
+ duration=int(1000 / fps),
376
+ loop=0,
377
+ optimize=False
378
+ )
379
+ else:
380
+ frames_rgb[0].save(
381
+ actual_path,
382
+ save_all=True,
383
+ append_images=frames_rgb[1:],
384
+ duration=int(1000 / fps),
385
+ loop=0,
386
+ optimize=False
387
+ )
388
+ else:
389
+ print("imageio not available, cannot save MP4. Falling back to GIF...")
390
+ base_path = os.path.splitext(actual_path)[0]
391
+ actual_path = base_path + '.gif'
392
+ frames_rgb[0].save(
393
+ actual_path,
394
+ save_all=True,
395
+ append_images=frames_rgb[1:],
396
+ duration=int(1000 / fps),
397
+ loop=0,
398
+ optimize=False
399
+ )
400
+ elif actual_path.endswith('.gif'):
401
+ # Save as GIF
402
+ if imageio is not None:
403
+ try:
404
+ imageio.mimsave(actual_path, frame_arrays, duration=1.0/fps, loop=0)
405
+ print(f"Saved {len(frames_rgb)} frames to GIF using imageio: {actual_path}")
406
+ except Exception as e:
407
+ print(f"imageio failed, using PIL: {e}")
408
+ frames_rgb[0].save(
409
+ actual_path,
410
+ save_all=True,
411
+ append_images=frames_rgb[1:],
412
+ duration=int(1000 / fps),
413
+ loop=0,
414
+ optimize=False
415
+ )
416
+ else:
417
+ frames_rgb[0].save(
418
+ actual_path,
419
+ save_all=True,
420
+ append_images=frames_rgb[1:],
421
+ duration=int(1000 / fps),
422
+ loop=0,
423
+ optimize=False
424
+ )
425
+
426
+ if frames_dir is not None:
427
+ print(f"Saved {len(frames)} individual frames to {frames_dir}")
428
+
429
+ return actual_path
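
And a sketch of rendering recovered joints with `plot_3d_motion`; the input/output paths and the caption are placeholders, and `fps=20` matches the HumanML3D frame rate.

```python
import numpy as np

joints = np.load('outputs/sample_joints.npy')        # (frames, 22, 3), e.g. saved from recover_from_ric(...)[0]
out_path = plot_3d_motion('outputs/sample.gif', t2m_kinematic_chain, joints,
                          title='a person walks forward and waves', fps=20)
print('animation written to', out_path)              # .gif here; .mp4 also works when imageio-ffmpeg is installed
```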
utils/quaternion.py ADDED
@@ -0,0 +1,519 @@
1
+ # Copyright (c) 2018-present, Facebook, Inc.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+
8
+ import torch
9
+ import numpy as np
10
+
11
+ _EPS4 = np.finfo(float).eps * 4.0
12
+
13
+ _FLOAT_EPS = np.finfo(np.float64).eps
14
+
15
+ # PyTorch-backed implementations
16
+ def qinv(q):
17
+ assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
18
+ mask = torch.ones_like(q)
19
+ mask[..., 1:] = -mask[..., 1:]
20
+ return q * mask
21
+
22
+
23
+ def qinv_np(q):
24
+ assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
25
+ return qinv(torch.from_numpy(q).float()).numpy()
26
+
27
+
28
+ def qnormalize(q):
29
+ assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
30
+ return q / torch.norm(q, dim=-1, keepdim=True)
31
+
32
+
33
+ def qmul(q, r):
34
+ """
35
+ Multiply quaternion(s) q with quaternion(s) r.
36
+ Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions.
37
+ Returns q*r as a tensor of shape (*, 4).
38
+ """
39
+ assert q.shape[-1] == 4
40
+ assert r.shape[-1] == 4
41
+
42
+ original_shape = q.shape
43
+
44
+ # Compute outer product
45
+ terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4))
46
+
47
+ w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3]
48
+ x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2]
49
+ y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1]
50
+ z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0]
51
+ return torch.stack((w, x, y, z), dim=1).view(original_shape)
52
+
53
+
54
+ def qrot(q, v):
55
+ """
56
+ Rotate vector(s) v about the rotation described by quaternion(s) q.
57
+ Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
58
+ where * denotes any number of dimensions.
59
+ Returns a tensor of shape (*, 3).
60
+ """
61
+ assert q.shape[-1] == 4
62
+ assert v.shape[-1] == 3
63
+ assert q.shape[:-1] == v.shape[:-1]
64
+
65
+ original_shape = list(v.shape)
66
+ # print(q.shape)
67
+ q = q.contiguous().view(-1, 4)
68
+ v = v.contiguous().view(-1, 3)
69
+
70
+ qvec = q[:, 1:]
71
+ uv = torch.cross(qvec, v, dim=1)
72
+ uuv = torch.cross(qvec, uv, dim=1)
73
+ return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape)
74
+
75
+
76
+ def qeuler(q, order, epsilon=0, deg=True, follow_order=True):
77
+ """
78
+ Convert quaternion(s) q to Euler angles.
79
+ Expects a tensor of shape (*, 4), where * denotes any number of dimensions.
80
+ Returns a tensor of shape (*, 3).
81
+ """
82
+ assert q.shape[-1] == 4
83
+
84
+ original_shape = list(q.shape)
85
+ original_shape[-1] = 3
86
+ q = q.view(-1, 4)
87
+
88
+ q0 = q[:, 0]
89
+ q1 = q[:, 1]
90
+ q2 = q[:, 2]
91
+ q3 = q[:, 3]
92
+
93
+ if order == 'xyz':
94
+ x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
95
+ y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon))
96
+ z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
97
+ elif order == 'yzx':
98
+ x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
99
+ y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
100
+ z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon))
101
+ elif order == 'zxy':
102
+ x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon))
103
+ y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
104
+ z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3))
105
+ elif order == 'xzy':
106
+ x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
107
+ y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
108
+ z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon))
109
+ elif order == 'yxz':
110
+ x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon))
111
+ y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2))
112
+ z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
113
+ elif order == 'zyx':
114
+ x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
115
+ y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon))
116
+ z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
117
+ else:
118
+ raise
119
+ resdict = {"x":x, "y":y, "z":z}
120
+
121
+ # print(order)
122
+ reslist = [resdict[order[i]] for i in range(len(order))] if follow_order else [x, y, z]
123
+ # print(reslist)
124
+ if deg:
125
+ return torch.stack(reslist, dim=1).view(original_shape) * 180 / np.pi
126
+ else:
127
+ return torch.stack(reslist, dim=1).view(original_shape)
128
+
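
A tiny numeric check of the conventions used by `qrot` and `qeuler` above (w-first quaternions, degrees by default); purely illustrative.

```python
import numpy as np
import torch

yaw90 = torch.tensor([[np.cos(np.pi / 4), 0.0, np.sin(np.pi / 4), 0.0]])  # (w, x, y, z): 90 deg about Y
v = torch.tensor([[0.0, 0.0, 1.0]])
print(qrot(yaw90, v))        # ~[[1., 0., 0.]]  unit Z rotated onto unit X
print(qeuler(yaw90, 'xyz'))  # ~[[0., 90., 0.]] in degrees
```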
129
+
130
+ # Numpy-backed implementations
131
+
132
+ def qmul_np(q, r):
133
+ q = torch.from_numpy(q).contiguous().float()
134
+ r = torch.from_numpy(r).contiguous().float()
135
+ return qmul(q, r).numpy()
136
+
137
+
138
+ def qrot_np(q, v):
139
+ q = torch.from_numpy(q).contiguous().float()
140
+ v = torch.from_numpy(v).contiguous().float()
141
+ return qrot(q, v).numpy()
142
+
143
+
144
+ def qeuler_np(q, order, epsilon=0, use_gpu=False):
145
+ if use_gpu:
146
+ q = torch.from_numpy(q).cuda().float()
147
+ return qeuler(q, order, epsilon).cpu().numpy()
148
+ else:
149
+ q = torch.from_numpy(q).contiguous().float()
150
+ return qeuler(q, order, epsilon).numpy()
151
+
152
+
153
+ def qfix(q):
154
+ """
155
+ Enforce quaternion continuity across the time dimension by selecting
156
+ the representation (q or -q) with minimal distance (or, equivalently, maximal dot product)
157
+ between two consecutive frames.
158
+
159
+ Expects a tensor of shape (L, J, 4), where L is the sequence length and J is the number of joints.
160
+ Returns a tensor of the same shape.
161
+ """
162
+ assert len(q.shape) == 3
163
+ assert q.shape[-1] == 4
164
+
165
+ result = q.copy()
166
+ dot_products = np.sum(q[1:] * q[:-1], axis=2)
167
+ mask = dot_products < 0
168
+ mask = (np.cumsum(mask, axis=0) % 2).astype(bool)
169
+ result[1:][mask] *= -1
170
+ return result
171
+
172
+
173
+ def euler2quat(e, order, deg=True):
174
+ """
175
+ Convert Euler angles to quaternions.
176
+ """
177
+ assert e.shape[-1] == 3
178
+
179
+ original_shape = list(e.shape)
180
+ original_shape[-1] = 4
181
+
182
+ e = e.view(-1, 3)
183
+
184
+ ## if euler angles in degrees
185
+ if deg:
186
+ e = e * np.pi / 180.
187
+
188
+ x = e[:, 0]
189
+ y = e[:, 1]
190
+ z = e[:, 2]
191
+
192
+ rx = torch.stack((torch.cos(x / 2), torch.sin(x / 2), torch.zeros_like(x), torch.zeros_like(x)), dim=1)
193
+ ry = torch.stack((torch.cos(y / 2), torch.zeros_like(y), torch.sin(y / 2), torch.zeros_like(y)), dim=1)
194
+ rz = torch.stack((torch.cos(z / 2), torch.zeros_like(z), torch.zeros_like(z), torch.sin(z / 2)), dim=1)
195
+
196
+ result = None
197
+ for coord in order:
198
+ if coord == 'x':
199
+ r = rx
200
+ elif coord == 'y':
201
+ r = ry
202
+ elif coord == 'z':
203
+ r = rz
204
+ else:
205
+ raise
206
+ if result is None:
207
+ result = r
208
+ else:
209
+ result = qmul(result, r)
210
+
211
+ # Reverse antipodal representation to have a non-negative "w"
212
+ if order in ['xyz', 'yzx', 'zxy']:
213
+ result *= -1
214
+
215
+ return result.view(original_shape)
216
+
217
+
218
+ def expmap_to_quaternion(e):
219
+ """
220
+ Convert axis-angle rotations (aka exponential maps) to quaternions.
221
+ Stable formula from "Practical Parameterization of Rotations Using the Exponential Map".
222
+ Expects a tensor of shape (*, 3), where * denotes any number of dimensions.
223
+ Returns a tensor of shape (*, 4).
224
+ """
225
+ assert e.shape[-1] == 3
226
+
227
+ original_shape = list(e.shape)
228
+ original_shape[-1] = 4
229
+ e = e.reshape(-1, 3)
230
+
231
+ theta = np.linalg.norm(e, axis=1).reshape(-1, 1)
232
+ w = np.cos(0.5 * theta).reshape(-1, 1)
233
+ xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e
234
+ return np.concatenate((w, xyz), axis=1).reshape(original_shape)
235
+
236
+
237
+ def euler_to_quaternion(e, order):
238
+ """
239
+ Convert Euler angles to quaternions.
240
+ """
241
+ assert e.shape[-1] == 3
242
+
243
+ original_shape = list(e.shape)
244
+ original_shape[-1] = 4
245
+
246
+ e = e.reshape(-1, 3)
247
+
248
+ x = e[:, 0]
249
+ y = e[:, 1]
250
+ z = e[:, 2]
251
+
252
+ rx = np.stack((np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1)
253
+ ry = np.stack((np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1)
254
+ rz = np.stack((np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1)
255
+
256
+ result = None
257
+ for coord in order:
258
+ if coord == 'x':
259
+ r = rx
260
+ elif coord == 'y':
261
+ r = ry
262
+ elif coord == 'z':
263
+ r = rz
264
+ else:
265
+ raise
266
+ if result is None:
267
+ result = r
268
+ else:
269
+ result = qmul_np(result, r)
270
+
271
+ # Reverse antipodal representation to have a non-negative "w"
272
+ if order in ['xyz', 'yzx', 'zxy']:
273
+ result *= -1
274
+
275
+ return result.reshape(original_shape)
276
+
277
+
278
+ def quaternion_to_matrix(quaternions):
279
+ """
280
+ Convert rotations given as quaternions to rotation matrices.
281
+ Args:
282
+ quaternions: quaternions with real part first,
283
+ as tensor of shape (..., 4).
284
+ Returns:
285
+ Rotation matrices as tensor of shape (..., 3, 3).
286
+ """
287
+ r, i, j, k = torch.unbind(quaternions, -1)
288
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
289
+
290
+ o = torch.stack(
291
+ (
292
+ 1 - two_s * (j * j + k * k),
293
+ two_s * (i * j - k * r),
294
+ two_s * (i * k + j * r),
295
+ two_s * (i * j + k * r),
296
+ 1 - two_s * (i * i + k * k),
297
+ two_s * (j * k - i * r),
298
+ two_s * (i * k - j * r),
299
+ two_s * (j * k + i * r),
300
+ 1 - two_s * (i * i + j * j),
301
+ ),
302
+ -1,
303
+ )
304
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
305
+
306
+
307
+ def quaternion_to_matrix_np(quaternions):
308
+ q = torch.from_numpy(quaternions).contiguous().float()
309
+ return quaternion_to_matrix(q).numpy()
310
+
311
+
312
+ def quaternion_to_cont6d_np(quaternions):
313
+ rotation_mat = quaternion_to_matrix_np(quaternions)
314
+ cont_6d = np.concatenate([rotation_mat[..., 0], rotation_mat[..., 1]], axis=-1)
315
+ return cont_6d
316
+
317
+
318
+ def quaternion_to_cont6d(quaternions):
319
+ rotation_mat = quaternion_to_matrix(quaternions)
320
+ cont_6d = torch.cat([rotation_mat[..., 0], rotation_mat[..., 1]], dim=-1)
321
+ return cont_6d
322
+
323
+
324
+ def cont6d_to_matrix(cont6d):
325
+ assert cont6d.shape[-1] == 6, "The last dimension must be 6"
326
+ x_raw = cont6d[..., 0:3]
327
+ y_raw = cont6d[..., 3:6]
328
+
329
+ x = x_raw / torch.norm(x_raw, dim=-1, keepdim=True)
330
+ z = torch.cross(x, y_raw, dim=-1)
331
+ z = z / torch.norm(z, dim=-1, keepdim=True)
332
+
333
+ y = torch.cross(z, x, dim=-1)
334
+
335
+ x = x[..., None]
336
+ y = y[..., None]
337
+ z = z[..., None]
338
+
339
+ mat = torch.cat([x, y, z], dim=-1)
340
+ return mat
341
+
342
+
343
+ def cont6d_to_matrix_np(cont6d):
344
+ q = torch.from_numpy(cont6d).contiguous().float()
345
+ return cont6d_to_matrix(q).numpy()
346
+
347
+
348
+ def qpow(q0, t, dtype=torch.float):
349
+ ''' q0 : tensor of quaternions
350
+ t: tensor of powers
351
+ '''
352
+ q0 = qnormalize(q0)
353
+ theta0 = torch.acos(q0[..., 0])
354
+
355
+ ## if theta0 is close to zero, add epsilon to avoid NaNs
356
+ mask = (theta0 <= 10e-10) * (theta0 >= -10e-10)
357
+ theta0 = (1 - mask) * theta0 + mask * 10e-10
358
+ v0 = q0[..., 1:] / torch.sin(theta0).view(-1, 1)
359
+
360
+ if isinstance(t, torch.Tensor):
361
+ q = torch.zeros(t.shape + q0.shape)
362
+ theta = t.view(-1, 1) * theta0.view(1, -1)
363
+ else: ## if t is a number
364
+ q = torch.zeros(q0.shape)
365
+ theta = t * theta0
366
+
367
+ q[..., 0] = torch.cos(theta)
368
+ q[..., 1:] = v0 * torch.sin(theta).unsqueeze(-1)
369
+
370
+ return q.to(dtype)
371
+
372
+
373
+ def qslerp(q0, q1, t):
374
+ '''
375
+ q0: starting quaternion
376
+ q1: ending quaternion
377
+ t: array of points along the way
378
+
379
+ Returns:
380
+ Tensor of Slerps: t.shape + q0.shape
381
+ '''
382
+
383
+ q0 = qnormalize(q0)
384
+ q1 = qnormalize(q1)
385
+ q_ = qpow(qmul(q1, qinv(q0)), t)
386
+
387
+ return qmul(q_,
388
+ q0.contiguous().view(torch.Size([1] * len(t.shape)) + q0.shape).expand(t.shape + q0.shape).contiguous())
389
+
390
+
391
+ def qbetween(v0, v1):
392
+ '''
393
+ find the quaternion used to rotate v0 to v1
394
+ '''
395
+ assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)'
396
+ assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)'
397
+
398
+ v = torch.cross(v0, v1)
399
+ w = torch.sqrt((v0 ** 2).sum(dim=-1, keepdim=True) * (v1 ** 2).sum(dim=-1, keepdim=True)) + (v0 * v1).sum(dim=-1,
400
+ keepdim=True)
401
+ return qnormalize(torch.cat([w, v], dim=-1))
402
+
403
+
404
+ def qbetween_np(v0, v1):
405
+ '''
406
+ find the quaternion used to rotate v0 to v1
407
+ '''
408
+ assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)'
409
+ assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)'
410
+
411
+ v0 = torch.from_numpy(v0).float()
412
+ v1 = torch.from_numpy(v1).float()
413
+ return qbetween(v0, v1).numpy()
414
+
415
+
416
+ def lerp(p0, p1, t):
417
+ if not isinstance(t, torch.Tensor):
418
+ t = torch.Tensor([t])
419
+
420
+ new_shape = t.shape + p0.shape
421
+ new_view_t = t.shape + torch.Size([1] * len(p0.shape))
422
+ new_view_p = torch.Size([1] * len(t.shape)) + p0.shape
423
+ p0 = p0.view(new_view_p).expand(new_shape)
424
+ p1 = p1.view(new_view_p).expand(new_shape)
425
+ t = t.view(new_view_t).expand(new_shape)
426
+
427
+ return p0 + t * (p1 - p0)
428
+
429
+ def matrix_to_quat(R) -> torch.Tensor:
430
+ '''
431
+ https://github.com/duolu/pyrotation/blob/master/pyrotation/pyrotation.py
432
+ Convert a rotation matrix to a unit quaternion.
433
+ This uses the Shepperd’s method for numerical stability.
434
+ '''
435
+
436
+ # The rotation matrix must be orthonormal
437
+
438
+ w2 = (1 + R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2])
439
+ x2 = (1 + R[..., 0, 0] - R[..., 1, 1] - R[..., 2, 2])
440
+ y2 = (1 - R[..., 0, 0] + R[..., 1, 1] - R[..., 2, 2])
441
+ z2 = (1 - R[..., 0, 0] - R[..., 1, 1] + R[..., 2, 2])
442
+
443
+ yz = (R[..., 1, 2] + R[..., 2, 1])
444
+ xz = (R[..., 2, 0] + R[..., 0, 2])
445
+ xy = (R[..., 0, 1] + R[..., 1, 0])
446
+
447
+ wx = (R[..., 2, 1] - R[..., 1, 2])
448
+ wy = (R[..., 0, 2] - R[..., 2, 0])
449
+ wz = (R[..., 1, 0] - R[..., 0, 1])
450
+
451
+ w = torch.empty_like(x2)
452
+ x = torch.empty_like(x2)
453
+ y = torch.empty_like(x2)
454
+ z = torch.empty_like(x2)
455
+
456
+ flagA = (R[..., 2, 2] < 0) * (R[..., 0, 0] > R[..., 1, 1])
457
+ flagB = (R[..., 2, 2] < 0) * (R[..., 0, 0] <= R[..., 1, 1])
458
+ flagC = (R[..., 2, 2] >= 0) * (R[..., 0, 0] < -R[..., 1, 1])
459
+ flagD = (R[..., 2, 2] >= 0) * (R[..., 0, 0] >= -R[..., 1, 1])
460
+
461
+ x[flagA] = torch.sqrt(x2[flagA])
462
+ w[flagA] = wx[flagA] / x[flagA]
463
+ y[flagA] = xy[flagA] / x[flagA]
464
+ z[flagA] = xz[flagA] / x[flagA]
465
+
466
+ y[flagB] = torch.sqrt(y2[flagB])
467
+ w[flagB] = wy[flagB] / y[flagB]
468
+ x[flagB] = xy[flagB] / y[flagB]
469
+ z[flagB] = yz[flagB] / y[flagB]
470
+
471
+ z[flagC] = torch.sqrt(z2[flagC])
472
+ w[flagC] = wz[flagC] / z[flagC]
473
+ x[flagC] = xz[flagC] / z[flagC]
474
+ y[flagC] = yz[flagC] / z[flagC]
475
+
476
+ w[flagD] = torch.sqrt(w2[flagD])
477
+ x[flagD] = wx[flagD] / w[flagD]
478
+ y[flagD] = wy[flagD] / w[flagD]
479
+ z[flagD] = wz[flagD] / w[flagD]
480
+
481
+ # if R[..., 2, 2] < 0:
482
+ #
483
+ # if R[..., 0, 0] > R[..., 1, 1]:
484
+ #
485
+ # x = torch.sqrt(x2)
486
+ # w = wx / x
487
+ # y = xy / x
488
+ # z = xz / x
489
+ #
490
+ # else:
491
+ #
492
+ # y = torch.sqrt(y2)
493
+ # w = wy / y
494
+ # x = xy / y
495
+ # z = yz / y
496
+ #
497
+ # else:
498
+ #
499
+ # if R[..., 0, 0] < -R[..., 1, 1]:
500
+ #
501
+ # z = torch.sqrt(z2)
502
+ # w = wz / z
503
+ # x = xz / z
504
+ # y = yz / z
505
+ #
506
+ # else:
507
+ #
508
+ # w = torch.sqrt(w2)
509
+ # x = wx / w
510
+ # y = wy / w
511
+ # z = wz / w
512
+
513
+ res = [w, x, y, z]
514
+ res = [z.unsqueeze(-1) for z in res]
515
+
516
+ return torch.cat(res, dim=-1) / 2
517
+
518
+ def cont6d_to_quat(cont6d):
519
+ return matrix_to_quat(cont6d_to_matrix(cont6d))
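
A minimal round-trip sketch over the representations defined above (quaternion to continuous 6D and back), assuming the functions are imported from `utils.quaternion`.

```python
import torch
from utils.quaternion import qnormalize, quaternion_to_cont6d, cont6d_to_quat

q = qnormalize(torch.randn(8, 4))       # random unit quaternions, w-first
six_d = quaternion_to_cont6d(q)         # (8, 6) continuous 6D rotation representation
q_back = cont6d_to_quat(six_d)          # (8, 4), should match q up to the q / -q sign ambiguity
print(torch.allclose(q_back.abs(), q.abs(), atol=1e-4))
```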
utils/skeleton.py ADDED
@@ -0,0 +1,194 @@
1
+ from utils.quaternion import *
2
+ import scipy.ndimage.filters as filters
3
+
4
+ class Skeleton(object):
5
+ def __init__(self, offset, kinematic_tree, device):
6
+ self.device = device
7
+ self._raw_offset_np = offset.numpy()
8
+ self._raw_offset = offset.clone().detach().to(device).float()
9
+ self._kinematic_tree = kinematic_tree
10
+ self._offset = None
11
+ self._parents = [0] * len(self._raw_offset)
12
+ self._parents[0] = -1
13
+ for chain in self._kinematic_tree:
14
+ for j in range(1, len(chain)):
15
+ self._parents[chain[j]] = chain[j-1]
16
+
17
+ def njoints(self):
18
+ return len(self._raw_offset)
19
+
20
+ def offset(self):
21
+ return self._offset
22
+
23
+ def set_offset(self, offsets):
24
+ self._offset = offsets.clone().detach().to(self.device).float()
25
+
26
+ def kinematic_tree(self):
27
+ return self._kinematic_tree
28
+
29
+ def parents(self):
30
+ return self._parents
31
+
32
+ # joints (batch_size, joints_num, 3)
33
+ def get_offsets_joints_batch(self, joints):
34
+ assert len(joints.shape) == 3
35
+ _offsets = self._raw_offset.expand(joints.shape[0], -1, -1).clone()
36
+ for i in range(1, self._raw_offset.shape[0]):
37
+ _offsets[:, i] = torch.norm(joints[:, i] - joints[:, self._parents[i]], p=2, dim=1)[:, None] * _offsets[:, i]
38
+
39
+ self._offset = _offsets.detach()
40
+ return _offsets
41
+
42
+ # joints (joints_num, 3)
43
+ def get_offsets_joints(self, joints):
44
+ assert len(joints.shape) == 2
45
+ _offsets = self._raw_offset.clone()
46
+ for i in range(1, self._raw_offset.shape[0]):
47
+ # print(joints.shape)
48
+ _offsets[i] = torch.norm(joints[i] - joints[self._parents[i]], p=2, dim=0) * _offsets[i]
49
+
50
+ self._offset = _offsets.detach()
51
+ return _offsets
52
+
53
+ # face_joint_idx should follow the order of right hip, left hip, right shoulder, left shoulder
54
+ # joints (batch_size, joints_num, 3)
55
+ def inverse_kinematics_np(self, joints, face_joint_idx, smooth_forward=False):
56
+ assert len(face_joint_idx) == 4
57
+ '''Get Forward Direction'''
58
+ l_hip, r_hip, sdr_r, sdr_l = face_joint_idx
59
+ across1 = joints[:, r_hip] - joints[:, l_hip]
60
+ across2 = joints[:, sdr_r] - joints[:, sdr_l]
61
+ across = across1 + across2
62
+ across = across / (np.sqrt((across**2).sum(axis=-1, keepdims=True)) + 1e-8)
63
+ # print(across1.shape, across2.shape)
64
+
65
+ # forward (batch_size, 3)
66
+ forward = np.cross(np.array([[0, 1, 0]]), across, axis=-1)
67
+ if smooth_forward:
68
+ forward = filters.gaussian_filter1d(forward, 20, axis=0, mode='nearest')
69
+ # forward (batch_size, 3)
70
+ forward = forward / (np.sqrt((forward**2).sum(axis=-1))[..., np.newaxis] + 1e-8)
71
+
72
+ '''Get Root Rotation'''
73
+ target = np.array([[0,0,1]]).repeat(len(forward), axis=0)
74
+ root_quat = qbetween_np(forward, target)
75
+
76
+ '''Inverse Kinematics'''
77
+ # quat_params (batch_size, joints_num, 4)
78
+ # print(joints.shape[:-1])
79
+ quat_params = np.zeros(joints.shape[:-1] + (4,))
80
+ # print(quat_params.shape)
81
+ root_quat[0] = np.array([[1.0, 0.0, 0.0, 0.0]])
82
+ quat_params[:, 0] = root_quat
83
+ # quat_params[0, 0] = np.array([[1.0, 0.0, 0.0, 0.0]])
84
+ for chain in self._kinematic_tree:
85
+ R = root_quat
86
+ for j in range(len(chain) - 1):
87
+ # (batch, 3)
88
+ u = self._raw_offset_np[chain[j+1]][np.newaxis,...].repeat(len(joints), axis=0)
89
+ # print(u.shape)
90
+ # (batch, 3)
91
+ v = joints[:, chain[j+1]] - joints[:, chain[j]]
92
+ v = v / (np.sqrt((v**2).sum(axis=-1, keepdims=True)) + 1e-8)
93
+ # print(u.shape, v.shape)
94
+ rot_u_v = qbetween_np(u, v)
95
+
96
+ R_loc = qmul_np(qinv_np(R), rot_u_v)
97
+
98
+ quat_params[:,chain[j + 1], :] = R_loc
99
+ R = qmul_np(R, R_loc)
100
+
101
+ return quat_params
102
+
103
+ # Be sure root joint is at the beginning of kinematic chains
104
+ def forward_kinematics(self, quat_params, root_pos, skel_joints=None, do_root_R=True):
105
+ # quat_params (batch_size, joints_num, 4)
106
+ # joints (batch_size, joints_num, 3)
107
+ # root_pos (batch_size, 3)
108
+ if skel_joints is not None:
109
+ offsets = self.get_offsets_joints_batch(skel_joints)
110
+ if len(self._offset.shape) == 2:
111
+ offsets = self._offset.expand(quat_params.shape[0], -1, -1)
112
+ joints = torch.zeros(quat_params.shape[:-1] + (3,)).to(self.device)
113
+ joints[:, 0] = root_pos
114
+ for chain in self._kinematic_tree:
115
+ if do_root_R:
116
+ R = quat_params[:, 0]
117
+ else:
118
+ R = torch.tensor([[1.0, 0.0, 0.0, 0.0]]).expand(len(quat_params), -1).detach().to(self.device)
119
+ for i in range(1, len(chain)):
120
+ R = qmul(R, quat_params[:, chain[i]])
121
+ offset_vec = offsets[:, chain[i]]
122
+ joints[:, chain[i]] = qrot(R, offset_vec) + joints[:, chain[i-1]]
123
+ return joints
124
+
125
+ # Be sure root joint is at the beginning of kinematic chains
126
+ def forward_kinematics_np(self, quat_params, root_pos, skel_joints=None, do_root_R=True):
127
+ # quat_params (batch_size, joints_num, 4)
128
+ # joints (batch_size, joints_num, 3)
129
+ # root_pos (batch_size, 3)
130
+ if skel_joints is not None:
131
+ skel_joints = torch.from_numpy(skel_joints)
132
+ offsets = self.get_offsets_joints_batch(skel_joints)
133
+ if len(self._offset.shape) == 2:
134
+ offsets = self._offset.expand(quat_params.shape[0], -1, -1)
135
+ offsets = offsets.numpy()
136
+ joints = np.zeros(quat_params.shape[:-1] + (3,))
137
+ joints[:, 0] = root_pos
138
+ for chain in self._kinematic_tree:
139
+ if do_root_R:
140
+ R = quat_params[:, 0]
141
+ else:
142
+ R = np.array([[1.0, 0.0, 0.0, 0.0]]).repeat(len(quat_params), axis=0)
143
+ for i in range(1, len(chain)):
144
+ R = qmul_np(R, quat_params[:, chain[i]])
145
+ offset_vec = offsets[:, chain[i]]
146
+ joints[:, chain[i]] = qrot_np(R, offset_vec) + joints[:, chain[i - 1]]
147
+ return joints
148
+
149
+ def forward_kinematics_cont6d_np(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True):
150
+ # cont6d_params (batch_size, joints_num, 6)
151
+ # joints (batch_size, joints_num, 3)
152
+ # root_pos (batch_size, 3)
153
+ if skel_joints is not None:
154
+ skel_joints = torch.from_numpy(skel_joints)
155
+ offsets = self.get_offsets_joints_batch(skel_joints)
156
+ if len(self._offset.shape) == 2:
157
+ offsets = self._offset.expand(cont6d_params.shape[0], -1, -1)
158
+ offsets = offsets.numpy()
159
+ joints = np.zeros(cont6d_params.shape[:-1] + (3,))
160
+ joints[:, 0] = root_pos
161
+ for chain in self._kinematic_tree:
162
+ if do_root_R:
163
+ matR = cont6d_to_matrix_np(cont6d_params[:, 0])
164
+ else:
165
+ matR = np.eye(3)[np.newaxis, :].repeat(len(cont6d_params), axis=0)
166
+ for i in range(1, len(chain)):
167
+ matR = np.matmul(matR, cont6d_to_matrix_np(cont6d_params[:, chain[i]]))
168
+ offset_vec = offsets[:, chain[i]][..., np.newaxis]
169
+ # print(matR.shape, offset_vec.shape)
170
+ joints[:, chain[i]] = np.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]]
171
+ return joints
172
+
173
+ def forward_kinematics_cont6d(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True):
174
+ # cont6d_params (batch_size, joints_num, 6)
175
+ # joints (batch_size, joints_num, 3)
176
+ # root_pos (batch_size, 3)
177
+ if skel_joints is not None:
178
+ # skel_joints = torch.from_numpy(skel_joints)
179
+ offsets = self.get_offsets_joints_batch(skel_joints)
180
+ if len(self._offset.shape) == 2:
181
+ offsets = self._offset.expand(cont6d_params.shape[0], -1, -1)
182
+ joints = torch.zeros(cont6d_params.shape[:-1] + (3,)).to(cont6d_params.device)
183
+ joints[..., 0, :] = root_pos
184
+ for chain in self._kinematic_tree:
185
+ if do_root_R:
186
+ matR = cont6d_to_matrix(cont6d_params[:, 0])
187
+ else:
188
+ matR = torch.eye(3).expand((len(cont6d_params), -1, -1)).detach().to(cont6d_params.device)
189
+ for i in range(1, len(chain)):
190
+ matR = torch.matmul(matR, cont6d_to_matrix(cont6d_params[:, chain[i]]))
191
+ offset_vec = offsets[:, chain[i]].unsqueeze(-1)
192
+ # print(matR.shape, offset_vec.shape)
193
+ joints[:, chain[i]] = torch.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]]
194
+ return joints
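
A sketch of the inverse-then-forward kinematics round trip with the `Skeleton` class. The face-joint indices `[2, 1, 17, 16]` follow the usual HumanML3D preprocessing, and the joint file path is hypothetical.

```python
import numpy as np
import torch
from utils.motion_process import t2m_raw_offsets, t2m_kinematic_chain
from utils.skeleton import Skeleton

skel = Skeleton(torch.from_numpy(t2m_raw_offsets), t2m_kinematic_chain, 'cpu')
joints = np.load('example_joints.npy').astype(np.float32)          # (frames, 22, 3), hypothetical file
quat_params = skel.inverse_kinematics_np(joints, face_joint_idx=[2, 1, 17, 16], smooth_forward=True)
skel.get_offsets_joints(torch.from_numpy(joints[0]))                # bone lengths taken from frame 0
rebuilt = skel.forward_kinematics_np(quat_params, joints[:, 0])     # approximately reproduces `joints`
```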
utils/train_utils.py ADDED
@@ -0,0 +1,105 @@
1
+ import torch
2
+ import math
3
+ import time
4
+ import torch.distributed as dist
5
+ import logging
6
+ #################################################################################
7
+ # DDP Functions #
8
+ #################################################################################
9
+ def cleanup():
10
+ dist.destroy_process_group()
11
+
12
+ #################################################################################
13
+ # Util Functions #
14
+ #################################################################################
15
+ def lengths_to_mask(lengths, max_len):
16
+ # max_len = max(lengths)
17
+ mask = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths.unsqueeze(1)
18
+ return mask #(b, len)
19
+
20
+
21
+ def get_mask_subset_prob(mask, prob):
22
+ subset_mask = torch.bernoulli(mask, p=prob) & mask
23
+ return subset_mask
24
+
25
+
26
+ def uniform(shape, device=None):
27
+ return torch.zeros(shape, device=device).float().uniform_(0, 1)
28
+
29
+
30
+ def cosine_schedule(t):
31
+ return torch.cos(t * math.pi * 0.5)
32
+
33
+
34
+ def update_ema(model, ema_model, ema_decay):
35
+ with torch.no_grad():
36
+ for ema_param, model_param in zip(ema_model.parameters(), model.parameters()):
37
+ ema_param.data.mul_(ema_decay).add_(model_param.data, alpha=(1 - ema_decay))
38
+
39
+ #################################################################################
40
+ # Logging Functions #
41
+ #################################################################################
42
+ def def_value():
43
+ return 0.0
44
+
45
+
46
+ def create_logger(logging_dir):
47
+ if dist.get_rank() == 0: # real logger
48
+ logging.basicConfig(
49
+ level=logging.INFO,
50
+ format='[\033[34m%(asctime)s\033[0m] %(message)s',
51
+ datefmt='%Y-%m-%d %H:%M:%S',
52
+ handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
53
+ )
54
+ logger = logging.getLogger(__name__)
55
+ else: # dummy logger (does nothing)
56
+ logger = logging.getLogger(__name__)
57
+ logger.addHandler(logging.NullHandler())
58
+ return logger
59
+
60
+
61
+ def update_lr_warm_up(nb_iter, warm_up_iter, optimizer, lr):
62
+ current_lr = lr * (nb_iter + 1) / (warm_up_iter + 1)
63
+ for param_group in optimizer.param_groups:
64
+ param_group["lr"] = current_lr
65
+ return current_lr
66
+
67
+
68
+ def save(file_name, ep, model, optimizer, scheduler, total_it, name, ema=None):
69
+ state = {
70
+ name: model.state_dict(),
71
+ f"opt_{name}": optimizer.state_dict(),
72
+ "scheduler": scheduler.state_dict(),
73
+ 'ep': ep,
74
+ 'total_it': total_it,
75
+ }
76
+ if ema is not None:
77
+ mardm_state_dict = model.state_dict()
78
+ ema_mardm_state_dict = ema.state_dict()
79
+ clip_weights = [e for e in mardm_state_dict.keys() if e.startswith('clip_model.')]
80
+ for e in clip_weights:
81
+ del mardm_state_dict[e]
82
+ del ema_mardm_state_dict[e]
83
+ state[name] = mardm_state_dict
84
+ state[f"ema_{name}"] = ema_mardm_state_dict
85
+ torch.save(state, file_name)
86
+
87
+
88
+ def print_current_loss(start_time, niter_state, total_niters, losses, epoch=None, sub_epoch=None,
89
+ inner_iter=None, tf_ratio=None, sl_steps=None):
90
+ def as_minutes(s):
91
+ m = math.floor(s / 60)
92
+ s -= m * 60
93
+ return '%dm %ds' % (m, s)
94
+ def time_since(since, percent):
95
+ now = time.time()
96
+ s = now - since
97
+ es = s / percent
98
+ rs = es - s
99
+ return '%s (- %s)' % (as_minutes(s), as_minutes(rs))
100
+ if epoch is not None:
101
+ print('ep/it:%2d-%4d niter:%6d' % (epoch, inner_iter, niter_state), end=" ")
102
+ message = ' %s completed:%3d%%)' % (time_since(start_time, niter_state / total_niters), niter_state / total_niters * 100)
103
+ for k, v in losses.items():
104
+ message += ' %s: %.4f ' % (k, v)
105
+ print(message)
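
As a quick, hypothetical sanity check (not part of this commit) of the two helpers the training loops lean on most — `lengths_to_mask` for padding masks over variable-length motions and `update_ema` for the exponential-moving-average copy of the model — assuming the repo root is on `PYTHONPATH` and using a toy `nn.Linear` as a stand-in for the real network:

```python
import torch
import torch.nn as nn
from utils.train_utils import lengths_to_mask, update_ema  # path as added in this commit

# Padding mask for three motions of length 3, 5, and 2, padded to 6 frames.
lengths = torch.tensor([3, 5, 2])
mask = lengths_to_mask(lengths, max_len=6)
print(mask.int())
# tensor([[1, 1, 1, 0, 0, 0],
#         [1, 1, 1, 1, 1, 0],
#         [1, 1, 0, 0, 0, 0]])

# EMA copy starts at the current weights, then tracks them with decay 0.999.
model = nn.Linear(4, 4)
ema_model = nn.Linear(4, 4)
ema_model.load_state_dict(model.state_dict())

with torch.no_grad():
    model.weight.add_(0.01)          # pretend an optimisation step moved the weights
update_ema(model, ema_model, ema_decay=0.999)
```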
utils/wandb_utils.py ADDED
@@ -0,0 +1,53 @@
+import wandb
+import torch
+from torchvision.utils import make_grid
+import torch.distributed as dist
+import argparse
+import hashlib
+import math
+
+
+def is_main_process():
+    return dist.get_rank() == 0
+
+def namespace_to_dict(namespace):
+    return {
+        k: namespace_to_dict(v) if isinstance(v, argparse.Namespace) else v
+        for k, v in vars(namespace).items()
+    }
+
+
+def generate_run_id(exp_name):
+    # https://stackoverflow.com/questions/16008670/how-to-hash-a-string-into-8-digits
+    return str(int(hashlib.sha256(exp_name.encode('utf-8')).hexdigest(), 16) % 10 ** 8)
+
+
+def initialize(args, entity, exp_name, project_name):
+    config_dict = namespace_to_dict(args)
+    # wandb.login(key=os.environ["WANDB_KEY"])
+    wandb.init(
+        entity=entity,
+        project=project_name,
+        name=exp_name,
+        config=config_dict,
+        id=generate_run_id(exp_name),
+        resume="allow",
+    )
+
+
+def log(stats, step=None):
+    if is_main_process():
+        wandb.log({k: v for k, v in stats.items()}, step=step)
+
+
+def log_image(sample, step=None):
+    if is_main_process():
+        sample = array2grid(sample)
+        wandb.log({f"samples": wandb.Image(sample), "train_step": step})
+
+
+def array2grid(x):
+    nrow = round(math.sqrt(x.size(0)))
+    x = make_grid(x, nrow=nrow, normalize=True, value_range=(-1,1))
+    x = x.mul(255).add_(0.5).clamp_(0,255).permute(1,2,0).to('cpu', torch.uint8).numpy()
+    return x
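
A rough sketch of how a training script might wire these helpers up; it is not taken from this commit: the single-process group setup stands in for a real torchrun launch, the entity and project names are placeholders, and `wandb.init` additionally requires a valid W&B login.

```python
import argparse
import torch.distributed as dist
from utils import wandb_utils  # path as added in this commit

# Single-process stand-in for torchrun; is_main_process() needs an initialised group.
dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500",
                        rank=0, world_size=1)

args = argparse.Namespace(lr=1e-4, batch_size=64)   # placeholder hyperparameters
if wandb_utils.is_main_process():
    wandb_utils.initialize(args, entity="my-team", exp_name="acmdm_debug",
                           project_name="acmdm")    # placeholder entity/project

for step in range(3):
    # log() checks is_main_process() internally, so every rank may call it.
    wandb_utils.log({"train/loss": 1.0 / (step + 1)}, step=step)

dist.destroy_process_group()
```

Because `generate_run_id` hashes the experiment name, re-running with the same `exp_name` reattaches to the same W&B run through `resume="allow"`.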