Initial deployment: ACMDM Motion Generation

- .gitattributes +1 -0
- README.md +76 -5
- app.py +604 -0
- checkpoints/t2m/ACMDM_Flow_S_PatchSize22/model/latest.tar +3 -0
- checkpoints/t2m/AE_2D_Causal/AE_2D_Causal_Post_Mean.npy +3 -0
- checkpoints/t2m/AE_2D_Causal/AE_2D_Causal_Post_Std.npy +3 -0
- checkpoints/t2m/AE_2D_Causal/model/latest.tar +3 -0
- checkpoints/t2m/length_estimator/model/finest.tar +3 -0
- requirements.txt +30 -0
- utils/22x3_mean_std/t2m/22x3_mean.npy +3 -0
- utils/22x3_mean_std/t2m/22x3_std.npy +3 -0
- utils/__pycache__/back_process.cpython-310.pyc +0 -0
- utils/__pycache__/eval_utils.cpython-310.pyc +0 -0
- utils/__pycache__/evaluators.cpython-310.pyc +0 -0
- utils/__pycache__/glove.cpython-310.pyc +0 -0
- utils/__pycache__/motion_process.cpython-310.pyc +0 -0
- utils/__pycache__/motion_process.cpython-313.pyc +0 -0
- utils/__pycache__/quaternion.cpython-310.pyc +0 -0
- utils/__pycache__/skeleton.cpython-310.pyc +0 -0
- utils/__pycache__/train_utils.cpython-310.pyc +0 -0
- utils/back_process.py +255 -0
- utils/cal_ae_post_mean_std.py +64 -0
- utils/cal_mean_std.py +35 -0
- utils/cal_mesh_ae_post_mean_std.py +107 -0
- utils/datasets.py +452 -0
- utils/eval_mean_std/t2m/eval_mean.npy +3 -0
- utils/eval_mean_std/t2m/eval_std.npy +3 -0
- utils/eval_utils.py +928 -0
- utils/evaluators.py +392 -0
- utils/glove.py +83 -0
- utils/mesh_mean_std/t2m/mesh_mean.npy +3 -0
- utils/mesh_mean_std/t2m/mesh_std.npy +3 -0
- utils/motion_process.py +429 -0
- utils/quaternion.py +519 -0
- utils/skeleton.py +194 -0
- utils/train_utils.py +105 -0
- utils/wandb_utils.py +53 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+checkpoints/**/*.tar filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,83 @@
 ---
 title: ACMDM Motion Generation
-emoji:
-colorFrom:
-colorTo:
+emoji: 🎭
+colorFrom: blue
+colorTo: purple
 sdk: gradio
-sdk_version:
+sdk_version: 4.0.0
 app_file: app.py
 pinned: false
+license: mit
+hardware: gpu-t4-small
 ---
 
-
+# ACMDM Motion Generation
+
+Generate human motion animations from text descriptions using the ACMDM (Absolute Coordinates Make Motion Generation Easy) model.
+
+## 🎯 Features
+
+- **Text-to-Motion Generation**: Create realistic human motion from natural language descriptions
+- **Batch Processing**: Generate multiple motions at once
+- **Auto-Length Estimation**: The model automatically determines a suitable motion length
+- **Flexible Parameters**: Adjust CFG scale, motion length, and more
+- **Real-time Preview**: See your generated motions instantly
+
+## 🚀 Usage
+
+1. **Enter a text description** of the motion you want (e.g., "A person is running on a treadmill.")
+2. **Adjust parameters** (optional):
+   - Motion length (40-196 frames)
+   - CFG scale (controls text alignment)
+   - Auto-length estimation
+3. **Click "Generate Motion"**
+4. **View and download** your generated motion video
+
+## 📝 Example Prompts
+
+- "A person is running on a treadmill."
+- "Someone is doing jumping jacks."
+- "A person walks forward and then turns around."
+- "A person is dancing energetically."
+
+## ⚙️ Parameters
+
+- **Motion Length**: Number of frames (40-196). Automatically rounded to a multiple of 4.
+- **CFG Scale**: Classifier-free guidance scale (1.0-10.0). Higher values follow the text more closely; lower values give more diverse motions.
+- **Auto-length**: Let the model estimate a suitable motion length from your text.
+
+## 🔧 Technical Details
+
+This Space uses pre-trained ACMDM models:
+- **Autoencoder**: AE_2D_Causal
+- **Diffusion Model**: ACMDM_Flow_S_PatchSize22
+- **Dataset**: HumanML3D (t2m)
+
+## 📚 Paper
+
+[Absolute Coordinates Make Motion Generation Easy](https://arxiv.org/abs/2505.19377)
+
+## 🤝 Citation
+
+```bibtex
+@article{meng2025absolute,
+  title={Absolute Coordinates Make Motion Generation Easy},
+  author={Meng, Zichong and Han, Zeyu and Peng, Xiaogang and Xie, Yiming and Jiang, Huaizu},
+  journal={arXiv preprint arXiv:2505.19377},
+  year={2025}
+}
+```
+
+## ⚠️ Notes
+
+- First generation may take 30-60 seconds (model loading)
+- Subsequent generations are faster (5-15 seconds)
+- GPU recommended for best performance
+- Works on CPU, but slower
+
+## 🔗 Links
+
+- [GitHub Repository](https://github.com/neu-vi/ACMDM)
+- [Project Page](https://neu-vi.github.io/ACMDM/)
+- [Paper](https://arxiv.org/abs/2505.19377)
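The Parameters section above states that motion lengths are rounded to a multiple of 4 and kept within 40-196 frames. Below is a minimal sketch of that rule, mirroring the rounding expression used in app.py later in this commit; the helper name is illustrative and not part of the repository:

```python
def round_motion_length(frames: int) -> int:
    """Round to the nearest multiple of 4 and clamp to the supported 40-196 frame range."""
    frames = ((frames + 2) // 4) * 4  # same expression app.py applies to user-specified lengths
    return max(40, min(196, frames))

assert round_motion_length(118) == 120
assert round_motion_length(10) == 40    # clamped to the minimum
assert round_motion_length(300) == 196  # clamped to the maximum
```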
app.py ADDED
@@ -0,0 +1,604 @@
"""
Gradio Web Interface for ACMDM Motion Generation
Milestone 3: User-friendly web interface for text-to-motion generation
"""

import os
import sys
from os.path import join as pjoin
import torch
import numpy as np
import gradio as gr
from typing import Optional, Tuple, List
import tempfile
import random

# Import from sample.py
from models.AE_2D_Causal import AE_models
from models.ACMDM import ACMDM_models
from models.LengthEstimator import LengthEstimator
from utils.back_process import back_process
from utils.motion_process import plot_3d_motion, t2m_kinematic_chain

# Global variables for model caching
_models_cache = {
    'ae': None,
    'acmdm': None,
    'length_estimator': None,
    'stats': None,
    'device': None,
    'loaded': False
}


def set_seed(seed):
    """Set random seed for reproducibility"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.benchmark = False


def load_models_cached(
    gpu_id: int = 0,
    model_name: str = 'ACMDM_Flow_S_PatchSize22',
    ae_name: str = 'AE_2D_Causal',
    ae_model: str = 'AE_Model',
    model_type: str = 'ACMDM-Flow-S-PatchSize22',
    dataset_name: str = 't2m',
    checkpoints_dir: str = './checkpoints',
    use_length_estimator: bool = False
) -> Tuple[dict, str]:
    """
    Load models with caching to avoid reloading on every request.
    Returns (models_dict, status_message)
    """
    global _models_cache

    # Check if models are already loaded
    if _models_cache['loaded']:
        return _models_cache, "Models already loaded (using cache)"

    try:
        # Determine device
        if gpu_id >= 0 and torch.cuda.is_available():
            device = torch.device(f"cuda:{gpu_id}")
        else:
            device = torch.device("cpu")

        _models_cache['device'] = device
        status_messages = [f"Using device: {device}"]

        # Load AE
        status_messages.append("Loading AE model...")
        ae = AE_models[ae_model](input_width=3)
        ae_ckpt_path = pjoin(checkpoints_dir, dataset_name, ae_name, 'model', 'latest.tar')

        if not os.path.exists(ae_ckpt_path):
            return None, f"Error: AE checkpoint not found at {ae_ckpt_path}"

        ae_ckpt = torch.load(ae_ckpt_path, map_location='cpu')
        ae.load_state_dict(ae_ckpt['ae'])
        ae.eval()
        ae.to(device)
        _models_cache['ae'] = ae
        status_messages.append(f"✓ Loaded AE from {ae_ckpt_path}")

        # Load ACMDM
        status_messages.append("Loading ACMDM model...")
        acmdm = ACMDM_models[model_type](input_dim=ae.output_emb_width, cond_mode='text')
        acmdm_ckpt_path = pjoin(checkpoints_dir, dataset_name, model_name, 'model', 'latest.tar')

        if not os.path.exists(acmdm_ckpt_path):
            return None, f"Error: ACMDM checkpoint not found at {acmdm_ckpt_path}"

        acmdm_ckpt = torch.load(acmdm_ckpt_path, map_location='cpu')
        missing_keys, unexpected_keys = acmdm.load_state_dict(acmdm_ckpt['ema_acmdm'], strict=False)
        assert len(unexpected_keys) == 0
        assert all([k.startswith('clip_model.') for k in missing_keys])
        acmdm.eval()
        acmdm.to(device)
        _models_cache['acmdm'] = acmdm
        status_messages.append(f"✓ Loaded ACMDM from {acmdm_ckpt_path}")

        # Load LengthEstimator if needed
        length_estimator = None
        if use_length_estimator:
            status_messages.append("Loading LengthEstimator...")
            length_estimator = LengthEstimator(input_size=512, output_size=1)
            # NOTE: the checkpoint shipped in this commit is 'finest.tar'
            # (see checkpoints/t2m/length_estimator/model/), so load that file here.
            length_estimator_path = pjoin(checkpoints_dir, dataset_name, 'length_estimator', 'model', 'finest.tar')
            if os.path.exists(length_estimator_path):
                length_ckpt = torch.load(length_estimator_path, map_location='cpu')
                length_estimator.load_state_dict(length_ckpt['model'])
                length_estimator.eval()
                length_estimator.to(device)
                _models_cache['length_estimator'] = length_estimator
                status_messages.append(f"✓ Loaded LengthEstimator")
            else:
                status_messages.append(f"⚠ LengthEstimator not found, will use default length")

        # Load normalization stats
        status_messages.append("Loading normalization statistics...")
        after_mean = np.load(pjoin(checkpoints_dir, dataset_name, ae_name, 'AE_2D_Causal_Post_Mean.npy'))
        after_std = np.load(pjoin(checkpoints_dir, dataset_name, ae_name, 'AE_2D_Causal_Post_Std.npy'))
        joint_mean = np.load(f'utils/22x3_mean_std/{dataset_name}/22x3_mean.npy')
        joint_std = np.load(f'utils/22x3_mean_std/{dataset_name}/22x3_std.npy')
        eval_mean = np.load(f'utils/eval_mean_std/{dataset_name}/eval_mean.npy')
        eval_std = np.load(f'utils/eval_mean_std/{dataset_name}/eval_std.npy')

        _models_cache['stats'] = {
            'after_mean': after_mean,
            'after_std': after_std,
            'joint_mean': joint_mean,
            'joint_std': joint_std,
            'eval_mean': eval_mean,
            'eval_std': eval_std
        }

        _models_cache['loaded'] = True
        status_message = "\n".join(status_messages)
        return _models_cache, status_message

    except Exception as e:
        error_msg = f"Error loading models: {str(e)}"
        import traceback
        error_msg += f"\n\nTraceback:\n{traceback.format_exc()}"
        return None, error_msg


def estimate_motion_length(text: str, models_cache: dict) -> int:
    """Estimate motion length from text using LengthEstimator"""
    if models_cache['length_estimator'] is None:
        return None

    device = models_cache['device']
    acmdm = models_cache['acmdm']
    length_estimator = models_cache['length_estimator']

    with torch.no_grad():
        text_emb = acmdm.encode_text([text])
        pred_length = length_estimator(text_emb)
        pred_length = int(pred_length.item() * 4)
        pred_length = ((pred_length + 2) // 4) * 4
        pred_length = max(40, min(196, pred_length))
        return pred_length


def generate_motion_single(
    text: str,
    motion_length: Optional[int],
    cfg_scale: float,
    use_auto_length: bool,
    gpu_id: int,
    seed: int
) -> Tuple[Optional[str], str]:
    """
    Generate a single motion from text.
    Returns (video_path, status_message)
    """
    global _models_cache

    try:
        set_seed(seed)

        # Load models if not cached
        if not _models_cache['loaded']:
            models_cache, load_msg = load_models_cached(
                gpu_id=gpu_id,
                use_length_estimator=use_auto_length
            )
            if models_cache is None:
                return None, f"Failed to load models:\n{load_msg}"
        else:
            models_cache = _models_cache

        device = models_cache['device']
        ae = models_cache['ae']
        acmdm = models_cache['acmdm']
        stats = models_cache['stats']

        # Estimate length if needed
        if motion_length is None or (motion_length == 0 and use_auto_length):
            if use_auto_length and models_cache['length_estimator'] is not None:
                motion_length = estimate_motion_length(text, models_cache)
                status_msg = f"Estimated motion length: {motion_length} frames\n"
            else:
                motion_length = 120  # Default
                status_msg = f"Using default motion length: {motion_length} frames\n"
        else:
            # Round to multiple of 4
            motion_length = ((motion_length + 2) // 4) * 4
            status_msg = f"Using specified motion length: {motion_length} frames\n"

        status_msg += f"Generating motion for: '{text}'...\n"

        # Generate motion
        with torch.no_grad():
            latent_length = motion_length // 4
            m_lens = torch.tensor([latent_length], device=device)
            pred_latents = acmdm.generate([text], m_lens, cfg_scale)

            # Denormalize latents
            pred_latents_np = pred_latents.permute(0, 2, 3, 1).detach().cpu().numpy()
            pred_latents_np = pred_latents_np * stats['after_std'] + stats['after_mean']
            pred_latents_tensor = torch.from_numpy(pred_latents_np).to(device)

            # Decode through AE
            pred_motions = ae.decode(pred_latents_tensor.permute(0, 3, 1, 2))

            # Denormalize motions
            pred_motions_np = pred_motions.permute(0, 2, 3, 1).detach().cpu().numpy()
            if stats['joint_mean'].ndim == 1:
                pred_motions_np = pred_motions_np * stats['joint_std'][np.newaxis, np.newaxis, :, np.newaxis] + stats['joint_mean'][np.newaxis, np.newaxis, :, np.newaxis]
            else:
                pred_motions_np = pred_motions_np * stats['joint_std'][np.newaxis, ..., np.newaxis] + stats['joint_mean'][np.newaxis, ..., np.newaxis]

        # Back process to get RIC format, then recover joint positions
        from utils.motion_process import recover_from_ric

        motion = pred_motions_np[0]  # (22, 3, seq_len)
        motion = motion[:, :, :motion_length].transpose(2, 0, 1)  # (seq_len, 22, 3)
        ric_data = back_process(motion, is_mesh=False)
        ric_tensor = torch.from_numpy(ric_data).float()
        joints = recover_from_ric(ric_tensor, joints_num=22).numpy()  # (seq_len, 22, 3)

        # Create temporary file for video
        temp_dir = tempfile.mkdtemp()
        video_path = pjoin(temp_dir, 'motion.mp4')

        # Generate video
        plot_3d_motion(
            video_path,
            t2m_kinematic_chain,
            joints,
            title=text,
            fps=20,
            radius=4
        )

        status_msg += "✓ Motion generated successfully!"
        return video_path, status_msg

    except Exception as e:
        error_msg = f"Error generating motion: {str(e)}"
        import traceback
        error_msg += f"\n\nTraceback:\n{traceback.format_exc()}"
        return None, error_msg


def generate_motion_batch(
    text_file_content: str,
    cfg_scale: float,
    use_auto_length: bool,
    gpu_id: int,
    seed: int
) -> Tuple[List[Optional[str]], str]:
    """
    Generate motions from a batch of text prompts.
    Returns (list_of_video_paths, status_message)
    """
    # Parse text file
    text_prompts = []
    for line in text_file_content.strip().split('\n'):
        line = line.strip()
        if not line:
            continue
        if '#' in line:
            parts = line.split('#')
            text = parts[0].strip()
            length_str = parts[1].strip() if len(parts) > 1 else 'NA'
        else:
            text = line
            length_str = 'NA'

        if length_str.upper() == 'NA':
            motion_length = None if use_auto_length else 120
        else:
            try:
                motion_length = int(length_str)
                motion_length = ((motion_length + 2) // 4) * 4
            except:
                motion_length = None if use_auto_length else 120

        text_prompts.append((text, motion_length))

    if not text_prompts:
        return [], "No valid prompts found in file"

    status_msg = f"Processing {len(text_prompts)} prompts...\n"
    video_paths = []

    for idx, (text, motion_length) in enumerate(text_prompts):
        video_path, gen_msg = generate_motion_single(
            text, motion_length, cfg_scale, use_auto_length, gpu_id, seed + idx
        )
        video_paths.append(video_path)
        status_msg += f"\n[{idx+1}/{len(text_prompts)}] {gen_msg}\n"

    return video_paths, status_msg


# Gradio Interface
def create_interface():
    """Create and configure the Gradio interface"""

    with gr.Blocks(title="ACMDM Motion Generation", theme=gr.themes.Soft()) as app:
        gr.Markdown("""
        # 🎭 ACMDM Motion Generation

        Generate human motion from text descriptions using the ACMDM (Absolute Coordinates Make Motion Generation Easy) model.

        **How to use:**
        1. Enter a text description of the motion you want to generate
        2. Adjust motion length (or use auto-estimate)
        3. Click "Generate Motion" to create the animation
        4. View and download the generated video

        **Example prompts:**
        - "A person is running on a treadmill."
        - "Someone is doing jumping jacks."
        - "A person walks forward and then turns around."
        """)

        with gr.Tabs():
            # Single Generation Tab
            with gr.Tab("Single Motion Generation"):
                with gr.Row():
                    with gr.Column(scale=1):
                        text_input = gr.Textbox(
                            label="Motion Description",
                            placeholder="A person is running on a treadmill.",
                            lines=3,
                            value="A person is running on a treadmill."
                        )

                        with gr.Row():
                            motion_length = gr.Slider(
                                label="Motion Length (frames)",
                                minimum=40,
                                maximum=196,
                                value=120,
                                step=4,
                                info="Will be rounded to nearest multiple of 4"
                            )
                            use_auto_length = gr.Checkbox(
                                label="Auto-estimate length",
                                value=False,
                                info="Use length estimator (ignores manual length if checked)"
                            )

                        cfg_scale = gr.Slider(
                            label="CFG Scale",
                            minimum=1.0,
                            maximum=10.0,
                            value=3.0,
                            step=0.5,
                            info="Classifier-free guidance scale (higher = more aligned with text)"
                        )

                        with gr.Row():
                            gpu_id = gr.Number(
                                label="GPU ID",
                                value=0,
                                precision=0,
                                info="Use -1 for CPU"
                            )
                            seed = gr.Number(
                                label="Random Seed",
                                value=3407,
                                precision=0
                            )

                        generate_btn = gr.Button("Generate Motion", variant="primary", size="lg")

                    with gr.Column(scale=1):
                        video_output = gr.Video(
                            label="Generated Motion",
                            format="mp4"
                        )
                        status_output = gr.Textbox(
                            label="Status",
                            lines=10,
                            interactive=False
                        )

                # Update model status after generation
                def generate_and_update_status(text, motion_length, cfg_scale, use_auto_length, gpu_id, seed):
                    video_path, status_msg = generate_motion_single(
                        text, motion_length, cfg_scale, use_auto_length, gpu_id, seed
                    )
                    # Return video and status, plus trigger status update
                    return video_path, status_msg

                generate_btn.click(
                    fn=generate_and_update_status,
                    inputs=[text_input, motion_length, cfg_scale, use_auto_length, gpu_id, seed],
                    outputs=[video_output, status_output]
                )

            # Batch Generation Tab
            with gr.Tab("Batch Generation"):
                with gr.Row():
                    with gr.Column(scale=1):
                        batch_text_input = gr.Textbox(
                            label="Text Prompts (one per line, format: text#length or text#NA)",
                            placeholder="A person is running on a treadmill.#120\nSomeone is doing jumping jacks.#NA",
                            lines=10,
                            info="Each line: 'text#length' or 'text#NA' for auto-estimate"
                        )

                        batch_cfg_scale = gr.Slider(
                            label="CFG Scale",
                            minimum=1.0,
                            maximum=10.0,
                            value=3.0,
                            step=0.5
                        )

                        batch_use_auto_length = gr.Checkbox(
                            label="Auto-estimate length for NA",
                            value=True
                        )

                        batch_gpu_id = gr.Number(
                            label="GPU ID",
                            value=0,
                            precision=0
                        )

                        batch_seed = gr.Number(
                            label="Random Seed",
                            value=3407,
                            precision=0
                        )

                        batch_generate_btn = gr.Button("Generate Batch", variant="primary", size="lg")

                    with gr.Column(scale=1):
                        batch_status_output = gr.Textbox(
                            label="Batch Status",
                            lines=15,
                            interactive=False
                        )
                        batch_video_gallery = gr.Gallery(
                            label="Generated Motions",
                            show_label=True,
                            elem_id="gallery",
                            columns=2,
                            rows=2,
                            height="auto"
                        )

                batch_generate_btn.click(
                    fn=generate_motion_batch,
                    inputs=[batch_text_input, batch_cfg_scale, batch_use_auto_length, batch_gpu_id, batch_seed],
                    outputs=[batch_video_gallery, batch_status_output]
                )

            # Model Management Tab
            with gr.Tab("Model Management"):
                with gr.Row():
                    with gr.Column():
                        model_status = gr.Textbox(
                            label="Model Status",
                            lines=15,
                            interactive=False,
                            value="⏳ Models not loaded yet. They will be loaded automatically on first generation."
                        )
                        with gr.Row():
                            refresh_status_btn = gr.Button("🔄 Refresh Status", variant="primary")
                            reload_models_btn = gr.Button("🗑️ Clear Cache", variant="secondary")

                        gr.Markdown("""
                        **Model Configuration:**
                        - Model Name: `ACMDM_Flow_S_PatchSize22`
                        - AE Name: `AE_2D_Causal`
                        - Dataset: `t2m`
                        - Checkpoints Directory: `./checkpoints`
                        """)

                def check_model_status():
                    """Check and display current model status"""
                    global _models_cache
                    if _models_cache['loaded']:
                        device = _models_cache['device']
                        status_lines = [
                            "✓ MODELS LOADED AND READY",
                            "=" * 50,
                            f"📱 Device: {device}",
                            f"💾 CUDA Available: {torch.cuda.is_available()}",
                        ]

                        if device.type == 'cuda':
                            status_lines.append(f"🎮 GPU: {torch.cuda.get_device_name(device.index if device.index is not None else 0)}")
                            status_lines.append(f"💾 GPU Memory: {torch.cuda.get_device_properties(device.index if device.index is not None else 0).total_memory / 1e9:.2f} GB")

                        status_lines.extend([
                            "",
                            "📦 Loaded Models:",
                            "  ✓ Autoencoder (AE_2D_Causal)",
                            "  ✓ ACMDM Diffusion Model",
                        ])

                        if _models_cache['length_estimator'] is not None:
                            status_lines.append("  ✓ Length Estimator")
                        else:
                            status_lines.append("  ⚠ Length Estimator (not loaded)")

                        status_lines.extend([
                            "",
                            "📊 Statistics Loaded:",
                            "  ✓ Post-AE Mean/Std",
                            "  ✓ Joint Mean/Std",
                            "  ✓ Eval Mean/Std",
                            "",
                            "✨ Models are cached and ready for generation!",
                            "💡 Tip: Use 'Clear Cache' to force reload on next generation."
                        ])

                        return "\n".join(status_lines)
                    else:
                        return (
                            "⏳ Models not loaded yet. They will be loaded automatically on first generation.\n"
                            "📝 To load models now, go to 'Single Motion Generation' tab and click 'Generate Motion'.\n"
                            "⏱️ First generation will take 30-60 seconds (model loading time).\n"
                            "⚡ Subsequent generations will be much faster (5-15 seconds)."
                        )

                def reload_models():
                    """Clear model cache"""
                    global _models_cache
                    _models_cache['loaded'] = False
                    _models_cache['ae'] = None
                    _models_cache['acmdm'] = None
                    _models_cache['length_estimator'] = None
                    _models_cache['stats'] = None
                    _models_cache['device'] = None
                    return (
                        "🗑️ Model cache cleared. Models will be automatically reloaded on your next generation request.\n"
                        "💡 Click 'Refresh Status' after generating to see updated status."
                    )

                # Button callbacks
                refresh_status_btn.click(
                    fn=check_model_status,
                    outputs=model_status
                )

                reload_models_btn.click(
                    fn=reload_models,
                    outputs=model_status
                )

                # Update status on tab load
                app.load(
                    fn=check_model_status,
                    outputs=model_status
                )

        gr.Markdown("""
        ---
        **Note:** First generation may take longer as models need to be loaded. Subsequent generations will be faster.

        **Tips:**
        - Use descriptive text prompts for better results
        - Adjust CFG scale: higher values (3-5) for more text alignment, lower values (1-2) for more diversity
        - Motion length should be a multiple of 4 (automatically rounded)
        - For batch processing, use the format: `text description#length` or `text description#NA`
        """)

    return app


if __name__ == "__main__":
    # Create and launch the interface
    app = create_interface()

    # Launch with sharing option
    app.launch(
        server_name="0.0.0.0",  # Allow external connections
        server_port=7860,       # Default Gradio port
        share=False,            # Set to True for public link (requires ngrok)
        show_error=True
    )
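A quick way to exercise the pipeline above without the web UI is to call `generate_motion_single` directly. The sketch below is illustrative only, assuming the checkpoints listed next are present under `./checkpoints` and the `models/` package from the ACMDM repository is importable; it is not part of the committed files. For batch mode, `generate_motion_batch` accepts one prompt per line in the form `text#length` or `text#NA`.

```python
# smoke_test.py - illustrative local check of the generation pipeline (not part of this commit)
from app import generate_motion_single

video_path, status = generate_motion_single(
    text="A person is running on a treadmill.",
    motion_length=120,     # rounded internally to a multiple of 4
    cfg_scale=3.0,
    use_auto_length=False,
    gpu_id=-1,             # -1 selects CPU; use 0 for the first GPU
    seed=3407,
)
print(status)
print("Video written to:", video_path)
```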
checkpoints/t2m/ACMDM_Flow_S_PatchSize22/model/latest.tar ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0b24050576ffaa92657e2cc146d4b083579a248be6757bd5f7eeaf98a1c9dd3e
size 312829822

checkpoints/t2m/AE_2D_Causal/AE_2D_Causal_Post_Mean.npy ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:49354ec5e7cfc75121d3d4ddf323364debd86e8815e9a153176c9d538b1fbf17
size 144

checkpoints/t2m/AE_2D_Causal/AE_2D_Causal_Post_Std.npy ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b2177734e42141c5c27f6b262b0e379df12bd664ec4a3986fd35d1933ccce4b8
size 144

checkpoints/t2m/AE_2D_Causal/model/latest.tar ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7c9601098c69561fcc12a7e91762c3caa03e5ab03e2c0488f44b2c4edbd08ee3
size 204997782

checkpoints/t2m/length_estimator/model/finest.tar ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:90512e0e893186e3b09b24b006bfbf51f1ad71ac4c6626c9b1309373db675d12
size 1745581
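If Git LFS objects are not pulled when the Space is cloned, the files above will be small text pointers rather than real checkpoints, and model loading will fail. A short, illustrative check against the sizes recorded in the pointers above (paths and sizes are taken from this commit):

```python
import os

expected_sizes = {
    "checkpoints/t2m/ACMDM_Flow_S_PatchSize22/model/latest.tar": 312829822,
    "checkpoints/t2m/AE_2D_Causal/model/latest.tar": 204997782,
    "checkpoints/t2m/length_estimator/model/finest.tar": 1745581,
}

for path, expected in expected_sizes.items():
    actual = os.path.getsize(path)
    status = "OK" if actual == expected else f"unexpected size {actual} bytes (LFS pointer not pulled?)"
    print(f"{path}: {status}")
```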
requirements.txt ADDED
@@ -0,0 +1,30 @@
# Requirements for HuggingFace Spaces Deployment
# This file is used by HuggingFace Spaces to install dependencies

# Core ML libraries
torch>=2.0.0
torchvision>=0.15.0
numpy>=1.21.0
scipy>=1.7.0

# Gradio for web interface
gradio>=4.0.0

# CLIP for text encoding
git+https://github.com/openai/CLIP.git

# Additional dependencies
einops>=0.6.0
timm>=0.6.0
tqdm>=4.64.0
matplotlib>=3.5.0
opencv-python>=4.6.0
Pillow>=9.0.0

# For model loading and processing
h5py>=3.7.0
scikit-learn>=1.0.0

# Utilities
gdown>=4.6.0
utils/22x3_mean_std/t2m/22x3_mean.npy ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ebfe87efeb6828b33c69f9df88d7a0373b9af7ac546496a4c7701112a748f12f
size 140

utils/22x3_mean_std/t2m/22x3_std.npy ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4733059903169bc95993999fde322d8e84fb95f3e6ff1b277c283a0eb877d0c5
size 140

utils/__pycache__/back_process.cpython-310.pyc ADDED
Binary file (5.56 kB)

utils/__pycache__/eval_utils.cpython-310.pyc ADDED
Binary file (21.2 kB)

utils/__pycache__/evaluators.cpython-310.pyc ADDED
Binary file (12.5 kB)

utils/__pycache__/glove.cpython-310.pyc ADDED
Binary file (2.62 kB)

utils/__pycache__/motion_process.cpython-310.pyc ADDED
Binary file (11.2 kB)

utils/__pycache__/motion_process.cpython-313.pyc ADDED
Binary file (16.6 kB)

utils/__pycache__/quaternion.cpython-310.pyc ADDED
Binary file (13.2 kB)

utils/__pycache__/skeleton.cpython-310.pyc ADDED
Binary file (6.09 kB)

utils/__pycache__/train_utils.cpython-310.pyc ADDED
Binary file (3.73 kB)
utils/back_process.py ADDED
@@ -0,0 +1,255 @@
import numpy as np
import torch
import os

from utils.skeleton import Skeleton
from utils.quaternion import *
from utils.motion_process import t2m_kinematic_chain, t2m_raw_offsets

n_raw_offsets = torch.from_numpy(t2m_raw_offsets)
kinematic_chain = t2m_kinematic_chain
l_idx1, l_idx2 = 5, 8
face_joint_indx = [2, 1, 17, 16]

# Lazy loading of tgt_offsets - only needed when is_mesh=True
_tgt_offsets_cache = None

def _get_tgt_offsets():
    """Lazily load target offsets for mesh processing. Only called when is_mesh=True."""
    global _tgt_offsets_cache
    if _tgt_offsets_cache is None:
        example_data_path = os.path.join("datasets/HumanML3D/new_joints", "000021" + '.npy')
        if not os.path.exists(example_data_path):
            raise FileNotFoundError(
                f"Example data file not found: {example_data_path}\n"
                "This file is only needed for mesh-level motion generation (is_mesh=True).\n"
                "For regular text-to-motion generation, use is_mesh=False."
            )
        example_data = np.load(example_data_path)
        example_data = example_data.reshape(len(example_data), -1, 3)
        example_data = torch.from_numpy(example_data)
        tgt_skel = Skeleton(n_raw_offsets, kinematic_chain, 'cpu')
        _tgt_offsets_cache = tgt_skel.get_offsets_joints(example_data[0])
    return _tgt_offsets_cache

l_idx1, l_idx2 = 5, 8
fid_r, fid_l = [8, 11], [7, 10]
r_hip, l_hip = 2, 1
joints_num = 22

def uniform_skeleton(positions, target_offset):
    src_skel = Skeleton(n_raw_offsets, kinematic_chain, 'cpu')
    src_offset = src_skel.get_offsets_joints(torch.from_numpy(positions[0]))
    src_offset = src_offset.numpy()
    tgt_offset = target_offset.numpy()
    '''Calculate scale ratio as the ratio of leg lengths'''
    src_leg_len = np.abs(src_offset[l_idx1]).max() + np.abs(src_offset[l_idx2]).max()
    tgt_leg_len = np.abs(tgt_offset[l_idx1]).max() + np.abs(tgt_offset[l_idx2]).max()

    scale_rt = tgt_leg_len / src_leg_len
    src_root_pos = positions[:, 0]
    tgt_root_pos = src_root_pos * scale_rt

    '''Inverse Kinematics'''
    quat_params = src_skel.inverse_kinematics_np(positions, face_joint_indx)

    '''Forward Kinematics'''
    src_skel.set_offset(target_offset)
    new_joints = src_skel.forward_kinematics_np(quat_params, tgt_root_pos)
    return new_joints


def process_file(positions, feet_thre, is_mesh=False):
    # positions: (seq_len, joints_num, 3)
    if is_mesh:
        '''Uniform Skeleton'''
        tgt_offsets = _get_tgt_offsets()  # Lazy load only when needed
        positions = uniform_skeleton(positions, tgt_offsets)

    '''Put on Floor'''
    floor_height = positions.min(axis=0).min(axis=0)[1]
    positions[:, :, 1] -= floor_height

    '''XZ at origin'''
    root_pos_init = positions[0]
    root_pose_init_xz = root_pos_init[0] * np.array([1, 0, 1])
    positions = positions - root_pose_init_xz

    '''All initially face Z+'''
    r_hip, l_hip, sdr_r, sdr_l = face_joint_indx
    across1 = root_pos_init[r_hip] - root_pos_init[l_hip]
    across2 = root_pos_init[sdr_r] - root_pos_init[sdr_l]
    across = across1 + across2
    across = across / np.sqrt((across ** 2).sum(axis=-1))[..., np.newaxis]

    # forward (3,), rotate around y-axis
    forward_init = np.cross(np.array([[0, 1, 0]]), across, axis=-1)
    # forward (3,)
    forward_init = forward_init / np.sqrt((forward_init ** 2).sum(axis=-1))[..., np.newaxis]

    target = np.array([[0, 0, 1]])
    root_quat_init = qbetween_np(forward_init, target)
    root_quat_init = np.ones(positions.shape[:-1] + (4,)) * root_quat_init

    positions_b = positions.copy()

    positions = qrot_np(root_quat_init, positions)

    '''New ground truth positions'''
    global_positions = positions.copy()

    """ Get Foot Contacts """

    def foot_detect(positions, thres):
        velfactor, heightfactor = np.array([thres, thres]), np.array([3.0, 2.0])

        feet_l_x = (positions[1:, fid_l, 0] - positions[:-1, fid_l, 0]) ** 2
        feet_l_y = (positions[1:, fid_l, 1] - positions[:-1, fid_l, 1]) ** 2
        feet_l_z = (positions[1:, fid_l, 2] - positions[:-1, fid_l, 2]) ** 2
        feet_l = ((feet_l_x + feet_l_y + feet_l_z) < velfactor).astype(np.float32)

        feet_r_x = (positions[1:, fid_r, 0] - positions[:-1, fid_r, 0]) ** 2
        feet_r_y = (positions[1:, fid_r, 1] - positions[:-1, fid_r, 1]) ** 2
        feet_r_z = (positions[1:, fid_r, 2] - positions[:-1, fid_r, 2]) ** 2
        feet_r = (((feet_r_x + feet_r_y + feet_r_z) < velfactor)).astype(np.float32)
        return feet_l, feet_r

    feet_l, feet_r = foot_detect(positions, feet_thre)

    '''Quaternion and Cartesian representation'''
    r_rot = None

    def get_rifke(positions):
        '''Local pose'''
        positions[..., 0] -= positions[:, 0:1, 0]
        positions[..., 2] -= positions[:, 0:1, 2]
        '''All pose face Z+'''
        positions = qrot_np(np.repeat(r_rot[:, None], positions.shape[1], axis=1), positions)
        return positions

    def get_quaternion(positions):
        skel = Skeleton(n_raw_offsets, kinematic_chain, "cpu")
        # (seq_len, joints_num, 4)
        quat_params = skel.inverse_kinematics_np(positions, face_joint_indx, smooth_forward=False)

        '''Fix Quaternion Discontinuity'''
        quat_params = qfix(quat_params)
        # (seq_len, 4)
        r_rot = quat_params[:, 0].copy()
        '''Root Linear Velocity'''
        # (seq_len - 1, 3)
        velocity = (positions[1:, 0] - positions[:-1, 0]).copy()
        velocity = qrot_np(r_rot[1:], velocity)
        '''Root Angular Velocity'''
        # (seq_len - 1, 4)
        r_velocity = qmul_np(r_rot[1:], qinv_np(r_rot[:-1]))
        quat_params[1:, 0] = r_velocity
        # (seq_len, joints_num, 4)
        return quat_params, r_velocity, velocity, r_rot

    def get_cont6d_params(positions):
        skel = Skeleton(n_raw_offsets, kinematic_chain, "cpu")
        # (seq_len, joints_num, 4)
        quat_params = skel.inverse_kinematics_np(positions, face_joint_indx, smooth_forward=True)

        '''Quaternion to continuous 6D'''
        cont_6d_params = quaternion_to_cont6d_np(quat_params)
        # (seq_len, 4)
        r_rot = quat_params[:, 0].copy()
        '''Root Linear Velocity'''
        # (seq_len - 1, 3)
        velocity = (positions[1:, 0] - positions[:-1, 0]).copy()
        velocity = qrot_np(r_rot[1:], velocity)
        '''Root Angular Velocity'''
        # (seq_len - 1, 4)
        r_velocity = qmul_np(r_rot[1:], qinv_np(r_rot[:-1]))
        # (seq_len, joints_num, 4)
        return cont_6d_params, r_velocity, velocity, r_rot

    cont_6d_params, r_velocity, velocity, r_rot = get_cont6d_params(positions)
    positions = get_rifke(positions)

    '''Root height'''
    root_y = positions[:, 0, 1:2]

    '''Root rotation and linear velocity'''
    # (seq_len-1, 1) rotation velocity along y-axis
    # (seq_len-1, 2) linear velocity on xz plane
    r_velocity = np.arcsin(r_velocity[:, 2:3])
    l_velocity = velocity[:, [0, 2]]
    root_data = np.concatenate([r_velocity, l_velocity, root_y[:-1]], axis=-1)

    '''Get Joint Rotation Representation'''
    # (seq_len, (joints_num-1) * 6) continuous 6D rotation for skeleton joints
    rot_data = cont_6d_params[:, 1:].reshape(len(cont_6d_params), -1)

    '''Get Joint Rotation Invariant Position Representation'''
    # (seq_len, (joints_num-1)*3) local joint position
    ric_data = positions[:, 1:].reshape(len(positions), -1)

    '''Get Joint Velocity Representation'''
    # (seq_len-1, joints_num*3)
    local_vel = qrot_np(np.repeat(r_rot[:-1, None], global_positions.shape[1], axis=1),
                        global_positions[1:] - global_positions[:-1])
    local_vel = local_vel.reshape(len(local_vel), -1)

    data = root_data
    data = np.concatenate([data, ric_data[:-1]], axis=-1)
    data = np.concatenate([data, rot_data[:-1]], axis=-1)
    data = np.concatenate([data, local_vel], axis=-1)
    data = np.concatenate([data, feet_l, feet_r], axis=-1)

    return data, global_positions, positions, l_velocity


def back_process(data, is_mesh=False):
    data, ground_positions, positions, l_velocity = process_file(data, 0.002, is_mesh=is_mesh)
    return data[:, :67]
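As used in app.py above, `back_process` takes absolute joint positions of shape (seq_len, 22, 3) and returns the first 67 channels of the HumanML3D-style feature vector (4 root features plus 21 non-root joints × 3 local coordinates) for seq_len - 1 frames, which `recover_from_ric` then turns back into joint positions. Below is a rough shape check, assuming the repository's utils package is importable; the random input only exercises the shapes and does not represent a valid motion:

```python
import numpy as np
import torch
from utils.back_process import back_process
from utils.motion_process import recover_from_ric

joints = np.random.randn(64, 22, 3).astype(np.float32)  # (seq_len, 22, 3) absolute coordinates
feats = back_process(joints, is_mesh=False)              # (seq_len - 1, 67): 4 root + 21*3 local joint features
print(feats.shape)                                        # expected: (63, 67)

recovered = recover_from_ric(torch.from_numpy(feats).float(), joints_num=22)
print(recovered.shape)                                    # expected: (63, 22, 3)
```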
utils/cal_ae_post_mean_std.py ADDED
@@ -0,0 +1,64 @@
import numpy as np
import os
from os.path import join as pjoin
from tqdm import tqdm
import torch
import argparse
from models.AE_2D_Causal import AE_models

#################################################################################
#                       Calculate Post AE/VAE Mean Std                          #
#################################################################################

def mean_variance(data_dir, save_dir, ae, tp='AE'):
    file_list = os.listdir(data_dir)
    data_list = []
    mean = np.load(f'utils/22x3_mean_std/t2m/22x3_mean.npy')
    std = np.load(f'utils/22x3_mean_std/t2m/22x3_std.npy')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ae = ae.to(device)

    for file in tqdm(file_list):
        data = np.load(pjoin(data_dir, file))
        if len(data.shape) == 2:
            data = np.expand_dims(data, axis=0)
        if np.isnan(data).any():
            print(file)
            continue
        data = data[:(data.shape[0]//4)*4, :, :]
        if data.shape[0] == 0:
            continue
        data = (data - mean) / std
        data = torch.from_numpy(data).to(device)
        with torch.no_grad():
            data_list.append(ae.encode(data.unsqueeze(0)).squeeze(0).cpu().numpy())

    data = np.concatenate(data_list, axis=1)
    data = data.reshape(4, -1)
    print('Data Range:')
    print(data.min(), data.max())
    Mean = data.mean(axis=1)
    Std = data.std(axis=1)
    print('Mean/Std:')
    print(Mean, Std)

    np.save(pjoin(save_dir, f'{tp}_2D_Causal_Post_Mean.npy'), Mean)
    np.save(pjoin(save_dir, f'{tp}_2D_Causal_Post_Std.npy'), Std)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    data_dir = 'datasets/HumanML3D/new_joints/'

    parser.add_argument('--is_ae', action="store_true")
    parser.add_argument('--ae_name', type=str, default="AE_2D_Causal")
    args = parser.parse_args()

    if args.is_ae:
        ae = AE_models["AE_Model"](input_width=3)
    else:
        ae = AE_models["VAE_Model"](input_width=3)
    ckpt = torch.load(f'checkpoints/t2m/{args.ae_name}/model/latest.tar', map_location='cpu')
    ae.load_state_dict(ckpt['ae'])
    ae = ae.eval()
    save_dir = f'checkpoints/t2m/{args.ae_name}'
    mean_variance(data_dir, save_dir, ae, tp='AE' if args.is_ae else 'VAE')
utils/cal_mean_std.py ADDED
@@ -0,0 +1,35 @@
import numpy as np
import os
from os.path import join as pjoin
from tqdm import tqdm

#################################################################################
#                   Calculate Absolute Coordinate Mean Std                      #
#################################################################################
def mean_variance(data_dir, save_dir):
    file_list = os.listdir(data_dir)
    data_list = []

    for file in tqdm(file_list):
        data = np.load(pjoin(data_dir, file))
        if len(data.shape) == 2:
            data = np.expand_dims(data, axis=0)
        if np.isnan(data).any():
            print(file)
            continue
        data_list.append(data.reshape(-1, 3))

    data = np.concatenate(data_list, axis=0)
    print(data.shape)
    Mean = data.mean(axis=0)
    Std = data.std(axis=0)

    np.save(pjoin(save_dir, 'Mean_22x3.npy'), Mean)
    np.save(pjoin(save_dir, 'Std_22x3.npy'), Std)

    return Mean, Std

if __name__ == '__main__':
    data_dir1 = 'datasets/HumanML3D/new_joints/'
    save_dir1 = 'datasets/HumanML3D/'
    mean, std = mean_variance(data_dir1, save_dir1)
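The statistics computed here are per-axis: motions are reshaped to (-1, 3) before taking the mean and standard deviation, so the saved arrays have shape (3,) and broadcast over every joint and frame. Below is a minimal sketch of applying them for the Z-normalization used elsewhere in this commit; the example motion path is illustrative:

```python
import numpy as np

mean = np.load("utils/22x3_mean_std/t2m/22x3_mean.npy")  # shape (3,): per-axis mean over all joints and frames
std = np.load("utils/22x3_mean_std/t2m/22x3_std.npy")    # shape (3,): per-axis standard deviation

motion = np.load("datasets/HumanML3D/new_joints/000021.npy")  # (seq_len, 22, 3) absolute coordinates
normalized = (motion - mean) / std        # Z-normalization applied before AE encoding
restored = normalized * std + mean        # inverse transform applied after decoding
assert np.allclose(restored, motion)
```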
utils/cal_mesh_ae_post_mean_std.py
ADDED
|
@@ -0,0 +1,107 @@
# very hard coded will clean and refine later
import numpy as np
import os
from os.path import join as pjoin
from tqdm import tqdm
import torch
from models.AE_Mesh import AE_models
from human_body_prior.body_model.body_model import BodyModel

def downsample(data_dir, save_dir):
    ae = AE_models["AE_Model"](test_mode=True)
    ckpt = torch.load('checkpoints/t2m/AE_Mesh/model/latest.tar', map_location='cpu')
    model_key = 'ae'
    ae.load_state_dict(ckpt[model_key])
    ae = ae.eval()
    ae = ae.cuda()

    mean = np.load(f'utils/mesh_mean_std/t2m/mesh_mean.npy')  # or yours
    std = np.load(f'utils/mesh_mean_std/t2m/mesh_std.npy')  # or yours

    bm_path = './body_models/smplh/neutral/model.npz'
    dmpl_path = './body_models/dmpls/neutral/model.npz'
    num_betas = 10
    num_dmpls = 8
    bm = BodyModel(bm_fname=bm_path, num_betas=num_betas, num_dmpls=num_dmpls, dmpl_fname=dmpl_path).cuda()
    comp_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    file_list = os.listdir(data_dir)
    file_list.sort()
    for file in tqdm(file_list):
        data = np.load(pjoin(data_dir, file))

        if np.isnan(data).any():
            print(f"Skipping {file} due to NaN")
            continue

        body_parms = {
            'root_orient': torch.from_numpy(data[:, :3]).to(comp_device).float(),
            'pose_body': torch.from_numpy(data[:, 3:66]).to(comp_device).float(),
            'pose_hand': torch.from_numpy(data[:, 66:52*3]).to(comp_device).float(),
            'trans': torch.from_numpy(data[:, 52*3:53*3]).to(comp_device).float(),
            'betas': torch.from_numpy(data[:, 53*3:53*3+10]).to(comp_device).float(),
            'dmpls': torch.from_numpy(data[:, 53*3+10:]).to(comp_device).float()
        }

        with torch.no_grad():
            verts = bm(**body_parms).v
        verts[:, :, 1] -= verts[:, :, 1].min()
        verts = verts.detach().cpu().numpy()
        # Z Normalization
        vertss = (verts - mean) / std
        T = vertss.shape[0]
        data = torch.from_numpy(vertss).float().to(comp_device)
        if T % 16 != 0:
            pad_len = 16 - (T % 16)
            pad_data = torch.zeros((pad_len, 6890, 3), dtype=data.dtype, device=comp_device)
            data = torch.cat([data, pad_data], dim=0)

        outputs = []
        with torch.no_grad():
            for i in range(0, data.shape[0], 16):
                chunk = data[i:i+16].unsqueeze(0)
                out = ae.encode(chunk).squeeze(0).cpu().numpy()
                outputs.append(out)
        downsampled = np.concatenate(outputs, axis=0)[:T]

        np.save(pjoin(save_dir, f"{file}"), downsampled)

#################################################################################
#                              Calculate Mean Std                               #
#################################################################################
def mean_variance(data_dir, save_dir):
    file_list = os.listdir(data_dir)
    data_list = []

    for file in tqdm(file_list):
        data = np.load(pjoin(data_dir, file))
        if np.isnan(data).any():
            print(file)
            continue
        data_list.append(data.reshape(-1, 12))

    data = np.concatenate(data_list, axis=0)
    print(data.shape)
    Mean = data.mean(axis=0)
    Std = data.std(axis=0)

    np.save(pjoin(save_dir, 'AE_Mesh_Post_Mean.npy'), Mean)
    np.save(pjoin(save_dir, 'AE_Mesh_Post_Std.npy'), Std)
    print(data.min(), data.max())
    Mean2 = data.mean()
    Std2 = data.std()
    print(Mean, Std)
    print(Mean2, Std2)

    return Mean, Std

if __name__ == '__main__':
    data_dir = 'datasets/HumanML3D/meshes/'
    save_dir = 'datasets/HumanML3D/meshes_after_ae/'

    os.makedirs(save_dir, exist_ok=True)
    downsample(data_dir, save_dir)

    data_dir1 = 'datasets/HumanML3D/meshes_after_ae/'
    save_dir1 = 'checkpoints/t2m/AE_Mesh/'
    mean, std = mean_variance(data_dir1, save_dir1)
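Note (not part of the commit): `downsample()` above pads each sequence to a multiple of 16 frames, encodes it chunk by chunk, and trims back to the original length. The sketch below isolates that padding/chunking scheme with a stand-in encoder; `encode_fn` is an assumption used for illustration, not the real `ae.encode` API.

```python
import torch

def encode_in_chunks(data, encode_fn, chunk=16):
    """Pad T to a multiple of `chunk`, encode chunk-by-chunk, trim back to T.
    Mirrors the scheme used in downsample(); `encode_fn` is a placeholder."""
    T = data.shape[0]
    if T % chunk != 0:
        pad_len = chunk - (T % chunk)
        pad = torch.zeros((pad_len, *data.shape[1:]), dtype=data.dtype, device=data.device)
        data = torch.cat([data, pad], dim=0)

    outputs = []
    with torch.no_grad():
        for i in range(0, data.shape[0], chunk):
            outputs.append(encode_fn(data[i:i + chunk].unsqueeze(0)).squeeze(0))
    return torch.cat(outputs, dim=0)[:T]

# Toy usage with an identity "encoder": 50 frames -> padded to 64 -> trimmed to 50.
dummy = torch.randn(50, 6890, 3)
latents = encode_in_chunks(dummy, encode_fn=lambda x: x)
assert latents.shape[0] == 50
```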
utils/datasets.py
ADDED
|
@@ -0,0 +1,452 @@
| 1 |
+
from os.path import join as pjoin
|
| 2 |
+
from torch.utils import data
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
from torch.utils.data._utils.collate import default_collate
|
| 7 |
+
import random
|
| 8 |
+
import codecs as cs
|
| 9 |
+
from utils.glove import GloVe
|
| 10 |
+
from human_body_prior.body_model.body_model import BodyModel
|
| 11 |
+
|
| 12 |
+
#################################################################################
|
| 13 |
+
# Collate Function #
|
| 14 |
+
#################################################################################
|
| 15 |
+
def collate_fn(batch):
|
| 16 |
+
batch.sort(key=lambda x: x[3], reverse=True)
|
| 17 |
+
return default_collate(batch)
|
| 18 |
+
|
| 19 |
+
#################################################################################
|
| 20 |
+
# Datasets #
|
| 21 |
+
#################################################################################
|
| 22 |
+
class AEDataset(data.Dataset):
|
| 23 |
+
def __init__(self, mean, std, motion_dir, window_size, split_file):
|
| 24 |
+
self.data = []
|
| 25 |
+
self.lengths = []
|
| 26 |
+
id_list = []
|
| 27 |
+
with open(split_file, 'r') as f:
|
| 28 |
+
for line in f.readlines():
|
| 29 |
+
id_list.append(line.strip())
|
| 30 |
+
|
| 31 |
+
for name in tqdm(id_list):
|
| 32 |
+
try:
|
| 33 |
+
motion = np.load(pjoin(motion_dir, name + '.npy'))
|
| 34 |
+
if len(motion.shape) == 2: # B,L,J,3
|
| 35 |
+
motion = np.expand_dims(motion, axis=0)
|
| 36 |
+
if motion.shape[0] < window_size:
|
| 37 |
+
continue
|
| 38 |
+
self.lengths.append(motion.shape[0] - window_size)
|
| 39 |
+
self.data.append(motion)
|
| 40 |
+
except Exception as e:
|
| 41 |
+
pass
|
| 42 |
+
self.cumsum = np.cumsum([0] + self.lengths)
|
| 43 |
+
self.window_size = window_size
|
| 44 |
+
|
| 45 |
+
self.mean = mean
|
| 46 |
+
self.std = std
|
| 47 |
+
print("Total number of motions {}, snippets {}".format(len(self.data), self.cumsum[-1]))
|
| 48 |
+
|
| 49 |
+
def __len__(self):
|
| 50 |
+
return self.cumsum[-1]
|
| 51 |
+
|
| 52 |
+
def __getitem__(self, item):
|
| 53 |
+
if item != 0:
|
| 54 |
+
motion_id = np.searchsorted(self.cumsum, item) - 1
|
| 55 |
+
idx = item - self.cumsum[motion_id] - 1
|
| 56 |
+
else:
|
| 57 |
+
motion_id = 0
|
| 58 |
+
idx = 0
|
| 59 |
+
motion = self.data[motion_id][idx:idx + self.window_size]
|
| 60 |
+
"Z Normalization"
|
| 61 |
+
motion = (motion - self.mean) / self.std
|
| 62 |
+
|
| 63 |
+
return motion
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class AEMeshDataset(data.Dataset):
|
| 67 |
+
def __init__(self, mean, std, motion_dir, window_size, split_file):
|
| 68 |
+
self.data = []
|
| 69 |
+
self.lengths = []
|
| 70 |
+
id_list = []
|
| 71 |
+
with open(split_file, 'r') as f:
|
| 72 |
+
for line in f.readlines():
|
| 73 |
+
id_list.append(line.strip())
|
| 74 |
+
|
| 75 |
+
for name in id_list:
|
| 76 |
+
try:
|
| 77 |
+
motion = np.load(pjoin(motion_dir, name + '.npy'))
|
| 78 |
+
if motion.shape[0] < window_size:
|
| 79 |
+
continue
|
| 80 |
+
self.lengths.append(motion.shape[0] - motion.shape[0]+1)
|
| 81 |
+
self.data.append(motion)
|
| 82 |
+
except Exception as e:
|
| 83 |
+
pass
|
| 84 |
+
self.cumsum = np.cumsum([0] + self.lengths)
|
| 85 |
+
self.window_size = window_size
|
| 86 |
+
|
| 87 |
+
self.mean = mean
|
| 88 |
+
self.std = std
|
| 89 |
+
num_betas = 10
|
| 90 |
+
num_dmpls = 8
|
| 91 |
+
self.bm = BodyModel(bm_fname='./body_models/smplh/neutral/model.npz', num_betas=num_betas, num_dmpls=num_dmpls, dmpl_fname='./body_models/dmpls/neutral/model.npz')
|
| 92 |
+
print("Total number of motions {}, snippets {}".format(len(self.data), self.cumsum[-1]))
|
| 93 |
+
|
| 94 |
+
def __len__(self):
|
| 95 |
+
return self.cumsum[-1]
|
| 96 |
+
|
| 97 |
+
def __getitem__(self, item):
|
| 98 |
+
motion = self.data[item]
|
| 99 |
+
body_parms = {
|
| 100 |
+
'root_orient': torch.from_numpy(motion[:, :3]).float(),
|
| 101 |
+
'pose_body': torch.from_numpy(motion[:, 3:66]).float(),
|
| 102 |
+
'pose_hand': torch.from_numpy(motion[:, 66:52*3]).float(),
|
| 103 |
+
'trans': torch.from_numpy(motion[:, 52*3:53*3]).float(),
|
| 104 |
+
'betas': torch.from_numpy(motion[:, 53*3:53*3+10]).float(),
|
| 105 |
+
'dmpls': torch.from_numpy(motion[:, 53*3+10:]).float()
|
| 106 |
+
}
|
| 107 |
+
body_parms['betas']= torch.zeros_like(torch.from_numpy(motion[:, 53*3:53*3+10]).float())
|
| 108 |
+
with torch.no_grad():
|
| 109 |
+
verts = self.bm(**body_parms).v
|
| 110 |
+
verts[:, :, 1] -= verts[:, :, 1].min()
|
| 111 |
+
idx = random.randint(0, verts.shape[0] - 1)
|
| 112 |
+
verts = verts[idx:idx + self.window_size].detach().cpu().numpy()
|
| 113 |
+
"Z Normalization"
|
| 114 |
+
verts = (verts - self.mean) / self.std
|
| 115 |
+
|
| 116 |
+
return verts
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class Text2MotionDataset(data.Dataset):
|
| 121 |
+
def __init__(self, mean, std, split_file, dataset_name, motion_dir, text_dir, unit_length, max_motion_length,
|
| 122 |
+
max_text_length, evaluation=False, is_mesh=False):
|
| 123 |
+
self.evaluation = evaluation
|
| 124 |
+
self.max_length = 20
|
| 125 |
+
self.pointer = 0
|
| 126 |
+
self.max_motion_length = max_motion_length
|
| 127 |
+
self.max_text_len = max_text_length
|
| 128 |
+
self.unit_length = unit_length
|
| 129 |
+
min_motion_len = 40 if dataset_name =='t2m' else 24
|
| 130 |
+
|
| 131 |
+
data_dict = {}
|
| 132 |
+
id_list = []
|
| 133 |
+
with cs.open(split_file, 'r') as f:
|
| 134 |
+
for line in f.readlines():
|
| 135 |
+
id_list.append(line.strip())
|
| 136 |
+
|
| 137 |
+
new_name_list = []
|
| 138 |
+
length_list = []
|
| 139 |
+
for name in tqdm(id_list):
|
| 140 |
+
try:
|
| 141 |
+
motion = np.load(pjoin(motion_dir, name + '.npy'))
|
| 142 |
+
if len(motion.shape) == 2:
|
| 143 |
+
motion = np.expand_dims(motion, axis=0)
|
| 144 |
+
if is_mesh:
|
| 145 |
+
if (len(motion)) < min_motion_len:
|
| 146 |
+
continue
|
| 147 |
+
else:
|
| 148 |
+
if (len(motion)) < min_motion_len or (len(motion) >= 200):
|
| 149 |
+
continue
|
| 150 |
+
text_data = []
|
| 151 |
+
flag = False
|
| 152 |
+
with cs.open(pjoin(text_dir, name + '.txt')) as f:
|
| 153 |
+
for line in f.readlines():
|
| 154 |
+
text_dict = {}
|
| 155 |
+
line_split = line.strip().split('#')
|
| 156 |
+
caption = line_split[0]
|
| 157 |
+
tokens = line_split[1].split(' ')
|
| 158 |
+
f_tag = float(line_split[2])
|
| 159 |
+
to_tag = float(line_split[3])
|
| 160 |
+
f_tag = 0.0 if np.isnan(f_tag) else f_tag
|
| 161 |
+
to_tag = 0.0 if np.isnan(to_tag) else to_tag
|
| 162 |
+
|
| 163 |
+
text_dict['caption'] = caption
|
| 164 |
+
text_dict['tokens'] = tokens
|
| 165 |
+
if f_tag == 0.0 and to_tag == 0.0:
|
| 166 |
+
flag = True
|
| 167 |
+
text_data.append(text_dict)
|
| 168 |
+
else:
|
| 169 |
+
try:
|
| 170 |
+
n_motion = motion[int(f_tag*20) : int(to_tag*20)]
|
| 171 |
+
if (len(n_motion)) < min_motion_len or (len(n_motion) >= 200):
|
| 172 |
+
continue
|
| 173 |
+
new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
|
| 174 |
+
while new_name in data_dict:
|
| 175 |
+
new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
|
| 176 |
+
data_dict[new_name] = {'motion': n_motion,
|
| 177 |
+
'length': len(n_motion),
|
| 178 |
+
'text':[text_dict]}
|
| 179 |
+
new_name_list.append(new_name)
|
| 180 |
+
length_list.append(len(n_motion))
|
| 181 |
+
except:
|
| 182 |
+
print(line_split)
|
| 183 |
+
print(line_split[2], line_split[3], f_tag, to_tag, name)
|
| 184 |
+
|
| 185 |
+
if flag:
|
| 186 |
+
data_dict[name] = {'motion': motion,
|
| 187 |
+
'length': len(motion),
|
| 188 |
+
'text': text_data}
|
| 189 |
+
new_name_list.append(name)
|
| 190 |
+
length_list.append(len(motion))
|
| 191 |
+
except:
|
| 192 |
+
pass
|
| 193 |
+
if self.evaluation:
|
| 194 |
+
self.w_vectorizer = GloVe('./glove', 'our_vab')
|
| 195 |
+
name_list, length_list = zip(*sorted(zip(new_name_list, length_list), key=lambda x: x[1]))
|
| 196 |
+
else:
|
| 197 |
+
name_list, length_list = new_name_list, length_list
|
| 198 |
+
self.mean = mean
|
| 199 |
+
self.std = std
|
| 200 |
+
self.length_arr = np.array(length_list)
|
| 201 |
+
self.data_dict = data_dict
|
| 202 |
+
self.name_list = name_list
|
| 203 |
+
if self.evaluation:
|
| 204 |
+
self.reset_max_len(self.max_length)
|
| 205 |
+
|
| 206 |
+
def reset_max_len(self, length):
|
| 207 |
+
assert length <= self.max_motion_length
|
| 208 |
+
self.pointer = np.searchsorted(self.length_arr, length)
|
| 209 |
+
print("Pointer Pointing at %d"%self.pointer)
|
| 210 |
+
self.max_length = length
|
| 211 |
+
|
| 212 |
+
def transform(self, data, mean=None, std=None):
|
| 213 |
+
if mean is None and std is None:
|
| 214 |
+
return (data - self.mean) / self.std
|
| 215 |
+
else:
|
| 216 |
+
return (data - mean) / std
|
| 217 |
+
|
| 218 |
+
def inv_transform(self, data, mean=None, std=None):
|
| 219 |
+
if mean is None and std is None:
|
| 220 |
+
return data * self.std + self.mean
|
| 221 |
+
else:
|
| 222 |
+
return data * std + mean
|
| 223 |
+
|
| 224 |
+
def __len__(self):
|
| 225 |
+
return len(self.data_dict) - self.pointer
|
| 226 |
+
|
| 227 |
+
def __getitem__(self, item):
|
| 228 |
+
idx = self.pointer + item
|
| 229 |
+
data = self.data_dict[self.name_list[idx]]
|
| 230 |
+
motion, m_length, text_list = data['motion'], data['length'], data['text']
|
| 231 |
+
# Randomly select a caption
|
| 232 |
+
text_data = random.choice(text_list)
|
| 233 |
+
caption, tokens = text_data['caption'], text_data['tokens']
|
| 234 |
+
|
| 235 |
+
if self.evaluation:
|
| 236 |
+
if len(tokens) < self.max_text_len:
|
| 237 |
+
# pad with "unk"
|
| 238 |
+
tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
|
| 239 |
+
sent_len = len(tokens)
|
| 240 |
+
tokens = tokens + ['unk/OTHER'] * (self.max_text_len + 2 - sent_len)
|
| 241 |
+
else:
|
| 242 |
+
# crop
|
| 243 |
+
tokens = tokens[:self.max_text_len]
|
| 244 |
+
tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
|
| 245 |
+
sent_len = len(tokens)
|
| 246 |
+
pos_one_hots = []
|
| 247 |
+
word_embeddings = []
|
| 248 |
+
for token in tokens:
|
| 249 |
+
word_emb, pos_oh = self.w_vectorizer[token]
|
| 250 |
+
pos_one_hots.append(pos_oh[None, :])
|
| 251 |
+
word_embeddings.append(word_emb[None, :])
|
| 252 |
+
pos_one_hots = np.concatenate(pos_one_hots, axis=0)
|
| 253 |
+
word_embeddings = np.concatenate(word_embeddings, axis=0)
|
| 254 |
+
|
| 255 |
+
if self.unit_length < 10:
|
| 256 |
+
coin2 = np.random.choice(['single', 'single', 'double'])
|
| 257 |
+
else:
|
| 258 |
+
coin2 = 'single'
|
| 259 |
+
|
| 260 |
+
if coin2 == 'double':
|
| 261 |
+
m_length = (m_length // self.unit_length - 1) * self.unit_length
|
| 262 |
+
elif coin2 == 'single':
|
| 263 |
+
m_length = (m_length // self.unit_length) * self.unit_length
|
| 264 |
+
idx = random.randint(0, len(motion) - m_length)
|
| 265 |
+
motion = motion[idx:idx+m_length]
|
| 266 |
+
|
| 267 |
+
"Z Normalization"
|
| 268 |
+
motion = (motion - self.mean) / self.std
|
| 269 |
+
|
| 270 |
+
if m_length < self.max_motion_length:
|
| 271 |
+
motion = np.concatenate([motion,
|
| 272 |
+
np.zeros((self.max_motion_length - m_length, motion.shape[1], motion.shape[2]))
|
| 273 |
+
], axis=0)
|
| 274 |
+
elif m_length > self.max_motion_length:
|
| 275 |
+
idx = random.randint(0, m_length - self.max_motion_length)
|
| 276 |
+
motion = motion[idx:idx + self.max_motion_length]
|
| 277 |
+
if self.evaluation:
|
| 278 |
+
return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, '_'.join(tokens)
|
| 279 |
+
else:
|
| 280 |
+
return caption, motion, m_length
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
class Text2MotionDataset_Another_V(data.Dataset):
|
| 284 |
+
def __init__(self, mean, std, split_file, dataset_name, motion_dir, text_dir, unit_length, max_motion_length,
|
| 285 |
+
max_text_length, evaluation=False, is_mesh=False):
|
| 286 |
+
self.evaluation = evaluation
|
| 287 |
+
self.max_length = 20
|
| 288 |
+
self.pointer = 0
|
| 289 |
+
self.max_motion_length = max_motion_length
|
| 290 |
+
self.max_text_len = max_text_length
|
| 291 |
+
self.unit_length = unit_length
|
| 292 |
+
min_motion_len = 40 if dataset_name =='t2m' else 24
|
| 293 |
+
|
| 294 |
+
data_dict = {}
|
| 295 |
+
id_list = []
|
| 296 |
+
with cs.open(split_file, 'r') as f:
|
| 297 |
+
for line in f.readlines():
|
| 298 |
+
id_list.append(line.strip())
|
| 299 |
+
|
| 300 |
+
new_name_list = []
|
| 301 |
+
length_list = []
|
| 302 |
+
for name in tqdm(id_list):
|
| 303 |
+
try:
|
| 304 |
+
motion = np.load(pjoin(motion_dir, name + '.npy'))
|
| 305 |
+
if not self.evaluation:
|
| 306 |
+
if len(motion.shape) == 2:
|
| 307 |
+
motion = np.expand_dims(motion, axis=0)
|
| 308 |
+
if is_mesh:
|
| 309 |
+
if (len(motion)) < min_motion_len:
|
| 310 |
+
continue
|
| 311 |
+
else:
|
| 312 |
+
if (len(motion)) < min_motion_len or (len(motion) >= 200):
|
| 313 |
+
continue
|
| 314 |
+
text_data = []
|
| 315 |
+
flag = False
|
| 316 |
+
with cs.open(pjoin(text_dir, name + '.txt')) as f:
|
| 317 |
+
for line in f.readlines():
|
| 318 |
+
text_dict = {}
|
| 319 |
+
line_split = line.strip().split('#')
|
| 320 |
+
caption = line_split[0]
|
| 321 |
+
tokens = line_split[1].split(' ')
|
| 322 |
+
f_tag = float(line_split[2])
|
| 323 |
+
to_tag = float(line_split[3])
|
| 324 |
+
f_tag = 0.0 if np.isnan(f_tag) else f_tag
|
| 325 |
+
to_tag = 0.0 if np.isnan(to_tag) else to_tag
|
| 326 |
+
|
| 327 |
+
text_dict['caption'] = caption
|
| 328 |
+
text_dict['tokens'] = tokens
|
| 329 |
+
if f_tag == 0.0 and to_tag == 0.0:
|
| 330 |
+
flag = True
|
| 331 |
+
text_data.append(text_dict)
|
| 332 |
+
else:
|
| 333 |
+
try:
|
| 334 |
+
n_motion = motion[int(f_tag*20) : int(to_tag*20)]
|
| 335 |
+
if (len(n_motion)) < min_motion_len or (len(n_motion) >= 200):
|
| 336 |
+
continue
|
| 337 |
+
new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
|
| 338 |
+
while new_name in data_dict:
|
| 339 |
+
new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
|
| 340 |
+
data_dict[new_name] = {'motion': n_motion,
|
| 341 |
+
'length': len(n_motion),
|
| 342 |
+
'text':[text_dict]}
|
| 343 |
+
new_name_list.append(new_name)
|
| 344 |
+
length_list.append(len(n_motion))
|
| 345 |
+
except:
|
| 346 |
+
print(line_split)
|
| 347 |
+
print(line_split[2], line_split[3], f_tag, to_tag, name)
|
| 348 |
+
|
| 349 |
+
if flag:
|
| 350 |
+
data_dict[name] = {'motion': motion,
|
| 351 |
+
'length': len(motion),
|
| 352 |
+
'text': text_data}
|
| 353 |
+
new_name_list.append(name)
|
| 354 |
+
length_list.append(len(motion))
|
| 355 |
+
except:
|
| 356 |
+
pass
|
| 357 |
+
if self.evaluation:
|
| 358 |
+
self.w_vectorizer = GloVe('./glove', 'our_vab')
|
| 359 |
+
name_list, length_list = zip(*sorted(zip(new_name_list, length_list), key=lambda x: x[1]))
|
| 360 |
+
else:
|
| 361 |
+
name_list, length_list = new_name_list, length_list
|
| 362 |
+
self.mean = mean
|
| 363 |
+
self.std = std
|
| 364 |
+
self.length_arr = np.array(length_list)
|
| 365 |
+
self.data_dict = data_dict
|
| 366 |
+
self.name_list = name_list
|
| 367 |
+
if self.evaluation:
|
| 368 |
+
self.reset_max_len(self.max_length)
|
| 369 |
+
|
| 370 |
+
def reset_max_len(self, length):
|
| 371 |
+
assert length <= self.max_motion_length
|
| 372 |
+
self.pointer = np.searchsorted(self.length_arr, length)
|
| 373 |
+
print("Pointer Pointing at %d"%self.pointer)
|
| 374 |
+
self.max_length = length
|
| 375 |
+
|
| 376 |
+
def transform(self, data, mean=None, std=None):
|
| 377 |
+
if mean is None and std is None:
|
| 378 |
+
return (data - self.mean) / self.std
|
| 379 |
+
else:
|
| 380 |
+
return (data - mean) / std
|
| 381 |
+
|
| 382 |
+
def inv_transform(self, data, mean=None, std=None):
|
| 383 |
+
if mean is None and std is None:
|
| 384 |
+
return data * self.std + self.mean
|
| 385 |
+
else:
|
| 386 |
+
return data * std + mean
|
| 387 |
+
|
| 388 |
+
def __len__(self):
|
| 389 |
+
return len(self.data_dict) - self.pointer
|
| 390 |
+
|
| 391 |
+
def __getitem__(self, item):
|
| 392 |
+
idx = self.pointer + item
|
| 393 |
+
data = self.data_dict[self.name_list[idx]]
|
| 394 |
+
motion, m_length, text_list = data['motion'], data['length'], data['text']
|
| 395 |
+
# Randomly select a caption
|
| 396 |
+
text_data = random.choice(text_list)
|
| 397 |
+
caption, tokens = text_data['caption'], text_data['tokens']
|
| 398 |
+
|
| 399 |
+
if self.evaluation:
|
| 400 |
+
if len(tokens) < self.max_text_len:
|
| 401 |
+
# pad with "unk"
|
| 402 |
+
tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
|
| 403 |
+
sent_len = len(tokens)
|
| 404 |
+
tokens = tokens + ['unk/OTHER'] * (self.max_text_len + 2 - sent_len)
|
| 405 |
+
else:
|
| 406 |
+
# crop
|
| 407 |
+
tokens = tokens[:self.max_text_len]
|
| 408 |
+
tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
|
| 409 |
+
sent_len = len(tokens)
|
| 410 |
+
pos_one_hots = []
|
| 411 |
+
word_embeddings = []
|
| 412 |
+
for token in tokens:
|
| 413 |
+
word_emb, pos_oh = self.w_vectorizer[token]
|
| 414 |
+
pos_one_hots.append(pos_oh[None, :])
|
| 415 |
+
word_embeddings.append(word_emb[None, :])
|
| 416 |
+
pos_one_hots = np.concatenate(pos_one_hots, axis=0)
|
| 417 |
+
word_embeddings = np.concatenate(word_embeddings, axis=0)
|
| 418 |
+
|
| 419 |
+
if self.unit_length < 10:
|
| 420 |
+
coin2 = np.random.choice(['single', 'single', 'double'])
|
| 421 |
+
else:
|
| 422 |
+
coin2 = 'single'
|
| 423 |
+
|
| 424 |
+
if coin2 == 'double':
|
| 425 |
+
m_length = (m_length // self.unit_length - 1) * self.unit_length
|
| 426 |
+
elif coin2 == 'single':
|
| 427 |
+
m_length = (m_length // self.unit_length) * self.unit_length
|
| 428 |
+
idx = random.randint(0, len(motion) - m_length)
|
| 429 |
+
motion = motion[idx:idx+m_length]
|
| 430 |
+
|
| 431 |
+
"Z Normalization"
|
| 432 |
+
if self.evaluation:
|
| 433 |
+
motion = motion[:, :self.mean.shape[0]]
|
| 434 |
+
motion = (motion - self.mean) / self.std
|
| 435 |
+
|
| 436 |
+
if m_length < self.max_motion_length:
|
| 437 |
+
if self.evaluation:
|
| 438 |
+
motion = np.concatenate([motion,
|
| 439 |
+
np.zeros((self.max_motion_length - m_length, motion.shape[1]))
|
| 440 |
+
], axis=0)
|
| 441 |
+
else:
|
| 442 |
+
motion = np.concatenate([motion,
|
| 443 |
+
np.zeros((self.max_motion_length - m_length, motion.shape[1], motion.shape[2]))
|
| 444 |
+
], axis=0)
|
| 445 |
+
elif m_length > self.max_motion_length:
|
| 446 |
+
if not self.evaluation:
|
| 447 |
+
idx = random.randint(0, m_length - self.max_motion_length)
|
| 448 |
+
motion = motion[idx:idx + self.max_motion_length]
|
| 449 |
+
if self.evaluation:
|
| 450 |
+
return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, '_'.join(tokens)
|
| 451 |
+
else:
|
| 452 |
+
return caption, motion, m_length
|
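Note (not part of the commit): a sketch of wiring `Text2MotionDataset` into a `DataLoader` for evaluation, using the `collate_fn` defined at the top of `utils/datasets.py`. The statistics paths come from this repo's `utils/22x3_mean_std/t2m/` files; the split/text/motion directories, batch size, and the `unit_length`/`max_*` values are assumed, typical HumanML3D settings and should be adjusted to your local layout.

```python
import numpy as np
from torch.utils.data import DataLoader
from utils.datasets import Text2MotionDataset, collate_fn

# Statistics shipped with this Space; directories below are assumptions.
mean = np.load('utils/22x3_mean_std/t2m/22x3_mean.npy')
std = np.load('utils/22x3_mean_std/t2m/22x3_std.npy')

val_dataset = Text2MotionDataset(
    mean, std,
    split_file='datasets/HumanML3D/val.txt',        # assumed split-file location
    dataset_name='t2m',
    motion_dir='datasets/HumanML3D/new_joints',
    text_dir='datasets/HumanML3D/texts',             # assumed caption directory
    unit_length=4, max_motion_length=196, max_text_length=20,
    evaluation=True,  # returns GloVe word embeddings, sent_len, tokens, etc.
)

# collate_fn sorts each batch by sentence length (index 3 of the returned
# evaluation tuple), which the evaluator wrapper expects.
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=True,
                        num_workers=4, drop_last=True, collate_fn=collate_fn)
```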
utils/eval_mean_std/t2m/eval_mean.npy
ADDED
|
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1f866545af58e1d603b649fbf47680ffdcbf2856f8a3771c7e8aadb1098d560f
size 396
utils/eval_mean_std/t2m/eval_std.npy
ADDED
|
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7898ab84360e67ef06046df9e39001b02182d3aeae2ae4e3a3f46277bd748045
size 396
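Note (not part of the commit): the two `.npy` entries above are Git LFS pointers, not the arrays themselves. A minimal sketch of how they are typically loaded and applied before motions are fed to the evaluator, matching the `(x - mean) / std` transform that `dataset.transform(..., eval_mean, eval_std)` performs in `utils/eval_utils.py` below; the helper name is an illustration only.

```python
import numpy as np

# These files are tracked with Git LFS; run `git lfs pull` first so np.load
# sees the real arrays rather than the small pointer files.
eval_mean = np.load('utils/eval_mean_std/t2m/eval_mean.npy')
eval_std = np.load('utils/eval_mean_std/t2m/eval_std.npy')

def to_eval_space(motion_feats):
    # Same z-normalization that dataset.transform() applies with
    # eval_mean/eval_std before calling the evaluator wrapper.
    return (motion_feats - eval_mean) / eval_std
```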
utils/eval_utils.py
ADDED
|
@@ -0,0 +1,928 @@
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
from scipy import linalg
|
| 4 |
+
from scipy.ndimage import uniform_filter1d
|
| 5 |
+
import torch
|
| 6 |
+
from utils.motion_process import recover_from_ric
|
| 7 |
+
from utils.back_process import back_process
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
#################################################################################
|
| 11 |
+
# Eval Function Loops #
|
| 12 |
+
#################################################################################
|
| 13 |
+
@torch.no_grad()
|
| 14 |
+
def evaluation_ae(out_dir, val_loader, net, writer, ep, eval_wrapper, device, best_fid=1000, best_div=0,
|
| 15 |
+
best_top1=0, best_top2=0, best_top3=0, best_matching=100,
|
| 16 |
+
eval_mean=None, eval_std=None, save=True, draw=True):
|
| 17 |
+
net.eval()
|
| 18 |
+
|
| 19 |
+
motion_annotation_list = []
|
| 20 |
+
motion_pred_list = []
|
| 21 |
+
|
| 22 |
+
R_precision_real = 0
|
| 23 |
+
R_precision = 0
|
| 24 |
+
|
| 25 |
+
nb_sample = 0
|
| 26 |
+
matching_score_real = 0
|
| 27 |
+
matching_score_pred = 0
|
| 28 |
+
mpjpe = 0
|
| 29 |
+
num_poses = 0
|
| 30 |
+
|
| 31 |
+
for batch in tqdm(val_loader):
|
| 32 |
+
word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, token = batch
|
| 33 |
+
|
| 34 |
+
bs, seq = motion.shape[0], motion.shape[1]
|
| 35 |
+
gt = val_loader.dataset.inv_transform(motion.detach().cpu().numpy())
|
| 36 |
+
bgt = []
|
| 37 |
+
for j in range(bs):
|
| 38 |
+
bgt.append(back_process(gt[j], is_mesh=False))
|
| 39 |
+
bgt = np.stack(bgt, axis=0)
|
| 40 |
+
bgt = val_loader.dataset.transform(bgt, eval_mean, eval_std)
|
| 41 |
+
|
| 42 |
+
bgt = torch.from_numpy(bgt).to(device)
|
| 43 |
+
(et, em), (_, _) = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, caption, bgt, m_length-1)
|
| 44 |
+
|
| 45 |
+
motion = motion.float().to(device)
|
| 46 |
+
with torch.no_grad():
|
| 47 |
+
pred_pose_eval = net.forward(motion)
|
| 48 |
+
pred = val_loader.dataset.inv_transform(pred_pose_eval.detach().cpu().numpy())
|
| 49 |
+
bpred = []
|
| 50 |
+
for j in range(bs):
|
| 51 |
+
bpred.append(back_process(pred[j], is_mesh=False))
|
| 52 |
+
bpred = np.stack(bpred, axis=0)
|
| 53 |
+
bpred = val_loader.dataset.transform(bpred, eval_mean, eval_std)
|
| 54 |
+
|
| 55 |
+
(et_pred, em_pred), (_, _) = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, caption,
|
| 56 |
+
torch.from_numpy(bpred).to(device), m_length-1)
|
| 57 |
+
for i in range(bs):
|
| 58 |
+
gtt = torch.from_numpy(gt[i, :m_length[i]]).float().reshape(-1, 22, 3)
|
| 59 |
+
predd = torch.from_numpy(pred[i, :m_length[i]]).float().reshape(-1, 22, 3)
|
| 60 |
+
mpjpe += torch.sum(calculate_mpjpe(gtt, predd))
|
| 61 |
+
num_poses += gt.shape[0]
|
| 62 |
+
|
| 63 |
+
motion_pred_list.append(em_pred)
|
| 64 |
+
motion_annotation_list.append(em)
|
| 65 |
+
|
| 66 |
+
temp_R = calculate_R_precision(et.cpu().numpy(), em.cpu().numpy(), top_k=3, sum_all=True)
|
| 67 |
+
temp_match = euclidean_distance_matrix(et.cpu().numpy(), em.cpu().numpy()).trace()
|
| 68 |
+
R_precision_real += temp_R
|
| 69 |
+
matching_score_real += temp_match
|
| 70 |
+
temp_R = calculate_R_precision(et_pred.cpu().numpy(), em_pred.cpu().numpy(), top_k=3, sum_all=True)
|
| 71 |
+
temp_match = euclidean_distance_matrix(et_pred.cpu().numpy(), em_pred.cpu().numpy()).trace()
|
| 72 |
+
R_precision += temp_R
|
| 73 |
+
matching_score_pred += temp_match
|
| 74 |
+
|
| 75 |
+
nb_sample += bs
|
| 76 |
+
|
| 77 |
+
motion_annotation_np = torch.cat(motion_annotation_list, dim=0).cpu().numpy()
|
| 78 |
+
motion_pred_np = torch.cat(motion_pred_list, dim=0).cpu().numpy()
|
| 79 |
+
gt_mu, gt_cov = calculate_activation_statistics(motion_annotation_np)
|
| 80 |
+
mu, cov = calculate_activation_statistics(motion_pred_np)
|
| 81 |
+
|
| 82 |
+
diversity_real = calculate_diversity(motion_annotation_np, 300 if nb_sample > 300 else 100)
|
| 83 |
+
diversity = calculate_diversity(motion_pred_np, 300 if nb_sample > 300 else 100)
|
| 84 |
+
|
| 85 |
+
R_precision_real = R_precision_real / nb_sample
|
| 86 |
+
R_precision = R_precision / nb_sample
|
| 87 |
+
|
| 88 |
+
matching_score_real = matching_score_real / nb_sample
|
| 89 |
+
matching_score_pred = matching_score_pred / nb_sample
|
| 90 |
+
mpjpe = mpjpe / num_poses
|
| 91 |
+
|
| 92 |
+
fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)
|
| 93 |
+
|
| 94 |
+
msg = "--> \t Eva. Re %d:, FID. %.4f, Diversity Real. %.4f, Diversity. %.4f, R_precision_real. (%.4f, %.4f, %.4f), R_precision. (%.4f, %.4f, %.4f), matching_real. %.4f, matching_pred. %.4f, MPJPE. %.4f" % \
|
| 95 |
+
(ep, fid, diversity_real, diversity, R_precision_real[0], R_precision_real[1], R_precision_real[2],
|
| 96 |
+
R_precision[0], R_precision[1], R_precision[2], matching_score_real, matching_score_pred, mpjpe)
|
| 97 |
+
print(msg)
|
| 98 |
+
if draw:
|
| 99 |
+
writer.add_scalar('./Test/FID', fid, ep)
|
| 100 |
+
writer.add_scalar('./Test/Diversity', diversity, ep)
|
| 101 |
+
writer.add_scalar('./Test/top1', R_precision[0], ep)
|
| 102 |
+
writer.add_scalar('./Test/top2', R_precision[1], ep)
|
| 103 |
+
writer.add_scalar('./Test/top3', R_precision[2], ep)
|
| 104 |
+
writer.add_scalar('./Test/matching_score', matching_score_pred, ep)
|
| 105 |
+
|
| 106 |
+
if fid < best_fid:
|
| 107 |
+
msg = "--> --> \t FID Improved from %.5f to %.5f !!!" % (best_fid, fid)
|
| 108 |
+
if draw: print(msg)
|
| 109 |
+
best_fid = fid
|
| 110 |
+
if save:
|
| 111 |
+
torch.save({'ae': net.state_dict(), 'ep': ep}, os.path.join(out_dir, 'net_best_fid.tar'))
|
| 112 |
+
|
| 113 |
+
if abs(diversity_real - diversity) < abs(diversity_real - best_div):
|
| 114 |
+
msg = "--> --> \t Diversity Improved from %.5f to %.5f !!!"%(best_div, diversity)
|
| 115 |
+
if draw: print(msg)
|
| 116 |
+
best_div = diversity
|
| 117 |
+
|
| 118 |
+
if R_precision[0] > best_top1:
|
| 119 |
+
msg = "--> --> \t Top1 Improved from %.5f to %.5f !!!" % (best_top1, R_precision[0])
|
| 120 |
+
if draw: print(msg)
|
| 121 |
+
best_top1 = R_precision[0]
|
| 122 |
+
|
| 123 |
+
if R_precision[1] > best_top2:
|
| 124 |
+
msg = "--> --> \t Top2 Improved from %.5f to %.5f!!!" % (best_top2, R_precision[1])
|
| 125 |
+
if draw: print(msg)
|
| 126 |
+
best_top2 = R_precision[1]
|
| 127 |
+
|
| 128 |
+
if R_precision[2] > best_top3:
|
| 129 |
+
msg = "--> --> \t Top3 Improved from %.5f to %.5f !!!" % (best_top3, R_precision[2])
|
| 130 |
+
if draw: print(msg)
|
| 131 |
+
best_top3 = R_precision[2]
|
| 132 |
+
|
| 133 |
+
if matching_score_pred < best_matching:
|
| 134 |
+
msg = f"--> --> \t matching_score Improved from %.5f to %.5f !!!" % (best_matching, matching_score_pred)
|
| 135 |
+
if draw: print(msg)
|
| 136 |
+
best_matching = matching_score_pred
|
| 137 |
+
|
| 138 |
+
net.train()
|
| 139 |
+
return best_fid, best_div, best_top1, best_top2, best_top3, best_matching, mpjpe, writer
|
| 140 |
+
|
| 141 |
+
@torch.no_grad()
|
| 142 |
+
def evaluation_acmdm(out_dir, val_loader, ema_acmdm, ae, writer, ep, best_fid, best_div,
|
| 143 |
+
best_top1, best_top2, best_top3, best_matching, eval_wrapper, device, clip_score_old,
|
| 144 |
+
cond_scale=None, cal_mm=False, eval_mean=None, eval_std=None, after_mean=None, after_std=None, mesh_mean=None, mesh_std=None,
|
| 145 |
+
draw=True,
|
| 146 |
+
is_raw=False,
|
| 147 |
+
is_prefix=False,
|
| 148 |
+
is_control=False, index=[0], intensity=100,
|
| 149 |
+
is_mesh=False):
|
| 150 |
+
|
| 151 |
+
ema_acmdm.eval()
|
| 152 |
+
if not is_raw:
|
| 153 |
+
ae.eval()
|
| 154 |
+
|
| 155 |
+
save=False
|
| 156 |
+
|
| 157 |
+
motion_annotation_list = []
|
| 158 |
+
motion_pred_list = []
|
| 159 |
+
motion_multimodality = []
|
| 160 |
+
R_precision_real = 0
|
| 161 |
+
R_precision = 0
|
| 162 |
+
matching_score_real = 0
|
| 163 |
+
matching_score_pred = 0
|
| 164 |
+
multimodality = 0
|
| 165 |
+
if cond_scale is None:
|
| 166 |
+
if "kit" in out_dir:
|
| 167 |
+
cond_scale = 2.5
|
| 168 |
+
else:
|
| 169 |
+
cond_scale = 2.5
|
| 170 |
+
clip_score_real = 0
|
| 171 |
+
clip_score_gt = 0
|
| 172 |
+
skate_ratio_sum = 0
|
| 173 |
+
dist_sum = 0
|
| 174 |
+
traj_err = []
|
| 175 |
+
|
| 176 |
+
nb_sample = 0
|
| 177 |
+
if cal_mm:
|
| 178 |
+
num_mm_batch = 3
|
| 179 |
+
else:
|
| 180 |
+
num_mm_batch = 0
|
| 181 |
+
|
| 182 |
+
for i, batch in enumerate(tqdm(val_loader)):
|
| 183 |
+
word_embeddings, pos_one_hots, clip_text, sent_len, pose, m_length, token = batch
|
| 184 |
+
m_length = m_length.to(device)
|
| 185 |
+
|
| 186 |
+
bs, seq = pose.shape[:2]
|
| 187 |
+
if i < num_mm_batch:
|
| 188 |
+
motion_multimodality_batch = []
|
| 189 |
+
batch_clip_score_pred = 0
|
| 190 |
+
for _ in tqdm(range(30)):
|
| 191 |
+
pred_latents = ema_acmdm.generate(clip_text, m_length//4 if not is_raw else m_length, cond_scale)
|
| 192 |
+
|
| 193 |
+
if not is_raw:
|
| 194 |
+
pred_latents = val_loader.dataset.inv_transform(pred_latents.permute(0, 2, 3, 1).detach().cpu().numpy(),
|
| 195 |
+
after_mean, after_std)
|
| 196 |
+
pred_latents = torch.from_numpy(pred_latents).to(device)
|
| 197 |
+
with torch.no_grad():
|
| 198 |
+
pred_motions = ae.decode(pred_latents.permute(0,3,1,2))
|
| 199 |
+
else:
|
| 200 |
+
pred_motions = pred_latents.permute(0, 2, 3, 1)
|
| 201 |
+
pred_motions = val_loader.dataset.inv_transform(pred_motions.detach().cpu().numpy())
|
| 202 |
+
pred_motionss = []
|
| 203 |
+
for j in range(bs):
|
| 204 |
+
pred_motionss.append(back_process(pred_motions[j], is_mesh=is_mesh))
|
| 205 |
+
pred_motionss = np.stack(pred_motionss, axis=0)
|
| 206 |
+
pred_motions = val_loader.dataset.transform(pred_motionss, eval_mean, eval_std)
|
| 207 |
+
(et_pred, em_pred), (et_pred_clip, em_pred_clip) = eval_wrapper.get_co_embeddings(word_embeddings,
|
| 208 |
+
pos_one_hots,
|
| 209 |
+
sent_len,
|
| 210 |
+
clip_text,
|
| 211 |
+
torch.from_numpy(
|
| 212 |
+
pred_motions).to(
|
| 213 |
+
device),
|
| 214 |
+
m_length - 1)
|
| 215 |
+
motion_multimodality_batch.append(em_pred.unsqueeze(1))
|
| 216 |
+
motion_multimodality_batch = torch.cat(motion_multimodality_batch, dim=1) #(bs, 30, d)
|
| 217 |
+
motion_multimodality.append(motion_multimodality_batch)
|
| 218 |
+
for j in range(bs):
|
| 219 |
+
single_em = em_pred_clip[j]
|
| 220 |
+
single_et = et_pred_clip[j]
|
| 221 |
+
clip_score = (single_em @ single_et.T).item()
|
| 222 |
+
batch_clip_score_pred += clip_score
|
| 223 |
+
clip_score_real += batch_clip_score_pred
|
| 224 |
+
else:
|
| 225 |
+
if is_control:
|
| 226 |
+
pred_latents, mask_hint = ema_acmdm.generate_control(clip_text, m_length//4, pose.clone().float().to(device).permute(0,3,1,2), index, intensity,
|
| 227 |
+
cond_scale)
|
| 228 |
+
mask_hint = mask_hint.permute(0, 2, 3, 1).cpu().numpy()
|
| 229 |
+
elif is_prefix:
|
| 230 |
+
motion = pose.clone().float().to(device)
|
| 231 |
+
with torch.no_grad():
|
| 232 |
+
motion = ae.encode(motion)
|
| 233 |
+
amean = torch.from_numpy(after_mean).to(device)
|
| 234 |
+
astd = torch.from_numpy(after_std).to(device)
|
| 235 |
+
motion = ((motion.permute(0,2,3,1)-amean)/astd).permute(0,3,1,2)
|
| 236 |
+
pred_latents = ema_acmdm.generate(clip_text, m_length // 4, cond_scale, motion[:, :, :5, :]) # 20+40 style
|
| 237 |
+
else:
|
| 238 |
+
pred_latents = ema_acmdm.generate(clip_text, m_length//4 if not (is_raw or is_mesh) else m_length, cond_scale, j=22 if not is_mesh else 28)
|
| 239 |
+
if not is_raw:
|
| 240 |
+
pred_latents = val_loader.dataset.inv_transform(pred_latents.permute(0,2,3,1).detach().cpu().numpy(), after_mean, after_std)
|
| 241 |
+
pred_latents = torch.from_numpy(pred_latents).to(device)
|
| 242 |
+
with torch.no_grad():
|
| 243 |
+
if not is_mesh:
|
| 244 |
+
pred_latents = pred_latents.permute(0, 3, 1, 2)
|
| 245 |
+
pred_motions = ae.decode(pred_latents)
|
| 246 |
+
else:
|
| 247 |
+
pred_motions = pred_latents.permute(0, 2, 3, 1)
|
| 248 |
+
if not is_mesh:
|
| 249 |
+
pred_motions = val_loader.dataset.inv_transform(pred_motions.detach().cpu().numpy())
|
| 250 |
+
else:
|
| 251 |
+
pred_motions = val_loader.dataset.inv_transform(pred_motions.detach().cpu().numpy(), mesh_mean, mesh_std)
|
| 252 |
+
J_regressor = np.load('body_models/J_regressor.npy')
|
| 253 |
+
pred_motions = np.einsum('jk,btkc->btjc', J_regressor, pred_motions)[:, :, :22]
|
| 254 |
+
if is_control:
|
| 255 |
+
# foot skate
|
| 256 |
+
skate_ratio, skate_vel = calculate_skating_ratio(torch.from_numpy(pred_motions.transpose(0,2,3,1))) # [batch_size]
|
| 257 |
+
skate_ratio_sum += skate_ratio.sum()
|
| 258 |
+
# control errors
|
| 259 |
+
hint = val_loader.dataset.inv_transform(pose.clone().detach().cpu().numpy())
|
| 260 |
+
hint = hint * mask_hint
|
| 261 |
+
for i, (mot, h, mask) in enumerate(zip(pred_motions, hint, mask_hint)):
|
| 262 |
+
control_error = control_l2(np.expand_dims(mot, axis=0), np.expand_dims(h, axis=0),
|
| 263 |
+
np.expand_dims(mask, axis=0))
|
| 264 |
+
mean_error = control_error.sum() / mask.sum()
|
| 265 |
+
dist_sum += mean_error
|
| 266 |
+
control_error = control_error.reshape(-1)
|
| 267 |
+
mask = mask.reshape(-1)
|
| 268 |
+
err_np = calculate_trajectory_error(control_error, mean_error, mask)
|
| 269 |
+
traj_err.append(err_np)
|
| 270 |
+
|
| 271 |
+
pred_motionss = []
|
| 272 |
+
for j in range(bs):
|
| 273 |
+
pred_motionss.append(back_process(pred_motions[j], is_mesh=is_mesh))
|
| 274 |
+
pred_motionss = np.stack(pred_motionss, axis=0)
|
| 275 |
+
pred_motions = val_loader.dataset.transform(pred_motionss, eval_mean, eval_std)
|
| 276 |
+
(et_pred, em_pred), (et_pred_clip, em_pred_clip) = eval_wrapper.get_co_embeddings(word_embeddings,
|
| 277 |
+
pos_one_hots, sent_len,
|
| 278 |
+
clip_text,
|
| 279 |
+
torch.from_numpy(pred_motions).to(device),
|
| 280 |
+
m_length-1)
|
| 281 |
+
batch_clip_score_pred = 0
|
| 282 |
+
for j in range(bs):
|
| 283 |
+
single_em = em_pred_clip[j]
|
| 284 |
+
single_et = et_pred_clip[j]
|
| 285 |
+
clip_score = (single_em @ single_et.T).item()
|
| 286 |
+
batch_clip_score_pred += clip_score
|
| 287 |
+
clip_score_real += batch_clip_score_pred
|
| 288 |
+
|
| 289 |
+
pose = val_loader.dataset.inv_transform(pose.detach().cpu().numpy())
|
| 290 |
+
poses = []
|
| 291 |
+
for j in range(bs):
|
| 292 |
+
poses.append(back_process(pose[j], is_mesh=False))
|
| 293 |
+
poses = np.stack(poses, axis=0)
|
| 294 |
+
pose = val_loader.dataset.transform(poses, eval_mean, eval_std)
|
| 295 |
+
pose = torch.from_numpy(pose).cuda().float()
|
| 296 |
+
pose = pose.cuda().float()
|
| 297 |
+
(et, em), (et_clip, em_clip) = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, clip_text,
|
| 298 |
+
pose.clone(), m_length-1)
|
| 299 |
+
batch_clip_score = 0
|
| 300 |
+
for j in range(bs):
|
| 301 |
+
single_em = em_clip[j]
|
| 302 |
+
single_et = et_clip[j]
|
| 303 |
+
clip_score = (single_em @ single_et.T).item()
|
| 304 |
+
batch_clip_score += clip_score
|
| 305 |
+
clip_score_gt += batch_clip_score
|
| 306 |
+
motion_annotation_list.append(em)
|
| 307 |
+
motion_pred_list.append(em_pred)
|
| 308 |
+
|
| 309 |
+
temp_R = calculate_R_precision(et.cpu().numpy(), em.cpu().numpy(), top_k=3, sum_all=True)
|
| 310 |
+
temp_match = euclidean_distance_matrix(et.cpu().numpy(), em.cpu().numpy()).trace()
|
| 311 |
+
R_precision_real += temp_R
|
| 312 |
+
matching_score_real += temp_match
|
| 313 |
+
temp_R = calculate_R_precision(et_pred.cpu().numpy(), em_pred.cpu().numpy(), top_k=3, sum_all=True)
|
| 314 |
+
temp_match = euclidean_distance_matrix(et_pred.cpu().numpy(), em_pred.cpu().numpy()).trace()
|
| 315 |
+
R_precision += temp_R
|
| 316 |
+
matching_score_pred += temp_match
|
| 317 |
+
|
| 318 |
+
nb_sample += bs
|
| 319 |
+
|
| 320 |
+
motion_annotation_np = torch.cat(motion_annotation_list, dim=0).cpu().numpy()
|
| 321 |
+
motion_pred_np = torch.cat(motion_pred_list, dim=0).cpu().numpy()
|
| 322 |
+
gt_mu, gt_cov = calculate_activation_statistics(motion_annotation_np)
|
| 323 |
+
mu, cov = calculate_activation_statistics(motion_pred_np)
|
| 324 |
+
|
| 325 |
+
diversity_real = calculate_diversity(motion_annotation_np, 300 if nb_sample > 300 else 100)
|
| 326 |
+
diversity = calculate_diversity(motion_pred_np, 300 if nb_sample > 300 else 100)
|
| 327 |
+
|
| 328 |
+
R_precision_real = R_precision_real / nb_sample
|
| 329 |
+
R_precision = R_precision / nb_sample
|
| 330 |
+
|
| 331 |
+
clip_score_real = clip_score_real / nb_sample
|
| 332 |
+
clip_score_gt = clip_score_gt / nb_sample
|
| 333 |
+
|
| 334 |
+
matching_score_real = matching_score_real / nb_sample
|
| 335 |
+
matching_score_pred = matching_score_pred / nb_sample
|
| 336 |
+
if is_control:
|
| 337 |
+
# l2 dist
|
| 338 |
+
dist_mean = dist_sum / nb_sample
|
| 339 |
+
|
| 340 |
+
# Skating evaluation
|
| 341 |
+
skating_score = skate_ratio_sum / nb_sample
|
| 342 |
+
|
| 343 |
+
### For trajecotry evaluation from GMD ###
|
| 344 |
+
traj_err = np.stack(traj_err).mean(0)
|
| 345 |
+
|
| 346 |
+
if cal_mm:
|
| 347 |
+
motion_multimodality = torch.cat(motion_multimodality, dim=0).cpu().numpy()
|
| 348 |
+
multimodality = calculate_multimodality(motion_multimodality, 10)
|
| 349 |
+
|
| 350 |
+
fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)
|
| 351 |
+
|
| 352 |
+
msg = (f"--> \t Eva. Ep/Re {ep} :, FID. {fid:.4f}, Diversity Real. {diversity_real:.4f}, Diversity."
|
| 353 |
+
f" {diversity:.4f}, R_precision_real. {R_precision_real}, R_precision. {R_precision},"
|
| 354 |
+
f" matching_score_real. {matching_score_real}, matching_score_pred. {matching_score_pred}"
|
| 355 |
+
f" multimodality. {multimodality:.4f}, clip score. {clip_score_real}"
|
| 356 |
+
+ (f" foot skating. {skating_score:.4f}, traj error. {traj_err[1].item():.4f}, pos error. {traj_err[3].item():.4f}, avg error. {traj_err[4].item():.4f}"
|
| 357 |
+
if is_control else ""))
|
| 358 |
+
print(msg)
|
| 359 |
+
|
| 360 |
+
if draw:
|
| 361 |
+
writer.add_scalar('./Test/FID', fid, ep)
|
| 362 |
+
writer.add_scalar('./Test/Diversity', diversity, ep)
|
| 363 |
+
writer.add_scalar('./Test/top1', R_precision[0], ep)
|
| 364 |
+
writer.add_scalar('./Test/top2', R_precision[1], ep)
|
| 365 |
+
writer.add_scalar('./Test/top3', R_precision[2], ep)
|
| 366 |
+
writer.add_scalar('./Test/matching_score', matching_score_pred, ep)
|
| 367 |
+
writer.add_scalar('./Test/clip_score', clip_score_real, ep)
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
if fid < best_fid:
|
| 371 |
+
msg = f"--> --> \t FID Improved from {best_fid:.5f} to {fid:.5f} !!!"
|
| 372 |
+
if draw: print(msg)
|
| 373 |
+
best_fid, best_ep = fid, ep
|
| 374 |
+
save=True
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
if matching_score_pred < best_matching:
|
| 378 |
+
msg = f"--> --> \t matching_score Improved from {best_matching:.5f} to {matching_score_pred:.5f} !!!"
|
| 379 |
+
if draw: print(msg)
|
| 380 |
+
best_matching = matching_score_pred
|
| 381 |
+
|
| 382 |
+
if abs(diversity_real - diversity) < abs(diversity_real - best_div):
|
| 383 |
+
msg = f"--> --> \t Diversity Improved from {best_div:.5f} to {diversity:.5f} !!!"
|
| 384 |
+
if draw: print(msg)
|
| 385 |
+
best_div = diversity
|
| 386 |
+
|
| 387 |
+
if R_precision[0] > best_top1:
|
| 388 |
+
msg = f"--> --> \t Top1 Improved from {best_top1:.4f} to {R_precision[0]:.4f} !!!"
|
| 389 |
+
if draw: print(msg)
|
| 390 |
+
best_top1 = R_precision[0]
|
| 391 |
+
|
| 392 |
+
if R_precision[1] > best_top2:
|
| 393 |
+
msg = f"--> --> \t Top2 Improved from {best_top2:.4f} to {R_precision[1]:.4f} !!!"
|
| 394 |
+
if draw: print(msg)
|
| 395 |
+
best_top2 = R_precision[1]
|
| 396 |
+
|
| 397 |
+
if R_precision[2] > best_top3:
|
| 398 |
+
msg = f"--> --> \t Top3 Improved from {best_top3:.4f} to {R_precision[2]:.4f} !!!"
|
| 399 |
+
if draw: print(msg)
|
| 400 |
+
best_top3 = R_precision[2]
|
| 401 |
+
|
| 402 |
+
if clip_score_real > clip_score_old:
|
| 403 |
+
msg = f"--> --> \t CLIP-score Improved from {clip_score_old:.4f} to {clip_score_real:.4f} !!!"
|
| 404 |
+
if draw: print(msg)
|
| 405 |
+
clip_score_old = clip_score_real
|
| 406 |
+
|
| 407 |
+
if cal_mm:
|
| 408 |
+
return best_fid, best_div, best_top1, best_top2, best_top3, best_matching, multimodality, clip_score_old, writer, save
|
| 409 |
+
else:
|
| 410 |
+
if is_control:
|
| 411 |
+
return best_fid, best_div, best_top1, best_top2, best_top3, best_matching, 0, clip_score_old, writer, save, dist_mean, skating_score, traj_err
|
| 412 |
+
else:
|
| 413 |
+
return best_fid, best_div, best_top1, best_top2, best_top3, best_matching, 0, clip_score_old, writer, save, None, None, None
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
@torch.no_grad()
|
| 417 |
+
def evaluation_acmdm_another_v(out_dir, val_loader, ema_acmdm, ae, writer, ep, best_fid, best_div,
|
| 418 |
+
best_top1, best_top2, best_top3, best_matching, eval_wrapper, device, clip_score_old,
|
| 419 |
+
cond_scale=None, cal_mm=False, train_mean=None, train_std=None, after_mean=None, after_std=None,
|
| 420 |
+
draw=True,
|
| 421 |
+
is_raw=False,
|
| 422 |
+
is_prefix=False,
|
| 423 |
+
is_control=False, index=[0], intensity=100,
|
| 424 |
+
is_mesh=False):
|
| 425 |
+
|
| 426 |
+
ema_acmdm.eval()
|
| 427 |
+
if not is_raw:
|
| 428 |
+
ae.eval()
|
| 429 |
+
|
| 430 |
+
save=False
|
| 431 |
+
|
| 432 |
+
motion_annotation_list = []
|
| 433 |
+
motion_pred_list = []
|
| 434 |
+
motion_multimodality = []
|
| 435 |
+
R_precision_real = 0
|
| 436 |
+
R_precision = 0
|
| 437 |
+
matching_score_real = 0
|
| 438 |
+
matching_score_pred = 0
|
| 439 |
+
multimodality = 0
|
| 440 |
+
if cond_scale is None:
|
| 441 |
+
if "kit" in out_dir:
|
| 442 |
+
cond_scale = 2.5
|
| 443 |
+
else:
|
| 444 |
+
cond_scale = 2.5
|
| 445 |
+
clip_score_real = 0
|
| 446 |
+
clip_score_gt = 0
|
| 447 |
+
skate_ratio_sum = 0
|
| 448 |
+
dist_sum = 0
|
| 449 |
+
traj_err = []
|
| 450 |
+
|
| 451 |
+
nb_sample = 0
|
| 452 |
+
if cal_mm:
|
| 453 |
+
num_mm_batch = 3
|
| 454 |
+
else:
|
| 455 |
+
num_mm_batch = 0
|
| 456 |
+
|
| 457 |
+
for i, batch in enumerate(tqdm(val_loader)):
|
| 458 |
+
word_embeddings, pos_one_hots, clip_text, sent_len, pose, m_length, token = batch
|
| 459 |
+
m_length = m_length.to(device)
|
| 460 |
+
|
| 461 |
+
bs, seq = pose.shape[:2]
|
| 462 |
+
if i < num_mm_batch:
|
| 463 |
+
motion_multimodality_batch = []
|
| 464 |
+
batch_clip_score_pred = 0
|
| 465 |
+
for _ in tqdm(range(30)):
|
| 466 |
+
pred_latents = ema_acmdm.generate(clip_text, m_length//4 if not is_raw else m_length, cond_scale)
|
| 467 |
+
|
| 468 |
+
if not is_raw:
|
| 469 |
+
pred_latents = val_loader.dataset.inv_transform(pred_latents.permute(0, 2, 3, 1).detach().cpu().numpy(),
|
| 470 |
+
after_mean, after_std)
|
| 471 |
+
pred_latents = torch.from_numpy(pred_latents).to(device)
|
| 472 |
+
with torch.no_grad():
|
| 473 |
+
pred_motions = ae.decode(pred_latents.permute(0,3,1,2))
|
| 474 |
+
else:
|
| 475 |
+
pred_motions = pred_latents.permute(0, 2, 3, 1)
|
| 476 |
+
pred_motions = val_loader.dataset.inv_transform(pred_motions.detach().cpu().numpy(), train_mean, train_std)
|
| 477 |
+
pred_motionss = []
|
| 478 |
+
for j in range(bs):
|
| 479 |
+
pred_motionss.append(back_process(pred_motions[j], is_mesh=is_mesh))
|
| 480 |
+
pred_motionss = np.stack(pred_motionss, axis=0)
|
| 481 |
+
pred_motions = val_loader.dataset.transform(pred_motionss)
|
| 482 |
+
(et_pred, em_pred), (et_pred_clip, em_pred_clip) = eval_wrapper.get_co_embeddings(word_embeddings,
|
| 483 |
+
pos_one_hots,
|
| 484 |
+
sent_len,
|
| 485 |
+
clip_text,
|
| 486 |
+
torch.from_numpy(
|
| 487 |
+
pred_motions).to(
|
| 488 |
+
device),
|
| 489 |
+
m_length - 1)
|
| 490 |
+
motion_multimodality_batch.append(em_pred.unsqueeze(1))
|
| 491 |
+
motion_multimodality_batch = torch.cat(motion_multimodality_batch, dim=1) #(bs, 30, d)
|
| 492 |
+
motion_multimodality.append(motion_multimodality_batch)
|
| 493 |
+
for j in range(bs):
|
| 494 |
+
single_em = em_pred_clip[j]
|
| 495 |
+
single_et = et_pred_clip[j]
|
| 496 |
+
clip_score = (single_em @ single_et.T).item()
|
| 497 |
+
batch_clip_score_pred += clip_score
|
| 498 |
+
clip_score_real += batch_clip_score_pred
|
| 499 |
+
else:
|
| 500 |
+
if is_control:
|
| 501 |
+
bgt = val_loader.dataset.inv_transform(pose.clone())
|
| 502 |
+
motion_gt = []
|
| 503 |
+
for j in range(bs):
|
| 504 |
+
motion_gt.append(recover_from_ric(bgt[j].float(), 22).numpy())
|
| 505 |
+
motion_gt = np.stack(motion_gt, axis=0)
|
| 506 |
+
motion = val_loader.dataset.transform(motion_gt, train_mean, train_std)
|
| 507 |
+
motion = torch.from_numpy(motion).float().to(device)
|
| 508 |
+
pred_latents, mask_hint = ema_acmdm.generate_control(clip_text, m_length//4, motion.clone().permute(0,3,1,2), index, intensity,
|
| 509 |
+
cond_scale)
|
| 510 |
+
mask_hint = mask_hint.permute(0, 2, 3, 1).cpu().numpy()
|
| 511 |
+
elif is_prefix:
|
| 512 |
+
bgt = val_loader.dataset.inv_transform(pose.clone())
|
| 513 |
+
motion_gt = []
|
| 514 |
+
for j in range(bs):
|
| 515 |
+
motion_gt.append(recover_from_ric(bgt[j].float(), 22).numpy())
|
| 516 |
+
motion_gt = np.stack(motion_gt, axis=0)
|
| 517 |
+
motion = val_loader.dataset.transform(motion_gt, train_mean, train_std)
|
| 518 |
+
motion = torch.from_numpy(motion).float().to(device)
|
| 519 |
+
with torch.no_grad():
|
| 520 |
+
motion = ae.encode(motion)
|
| 521 |
+
amean = torch.from_numpy(after_mean).to(device)
|
| 522 |
+
astd = torch.from_numpy(after_std).to(device)
|
| 523 |
+
motion = ((motion.permute(0,2,3,1)-amean)/astd).permute(0,3,1,2)
|
| 524 |
+
pred_latents = ema_acmdm.generate(clip_text, m_length // 4, cond_scale, motion[:, :, :5, :]) # 20+40 style
|
| 525 |
+
else:
|
| 526 |
+
pred_latents = ema_acmdm.generate(clip_text, m_length//4 if not is_raw else m_length, cond_scale)
|
| 527 |
+
if not is_raw:
|
| 528 |
+
pred_latents = val_loader.dataset.inv_transform(pred_latents.permute(0,2,3,1).detach().cpu().numpy(), after_mean, after_std)
|
| 529 |
+
pred_latents = torch.from_numpy(pred_latents).to(device)
|
| 530 |
+
with torch.no_grad():
|
| 531 |
+
pred_motions = ae.decode(pred_latents.permute(0,3,1,2))
|
| 532 |
+
else:
|
| 533 |
+
pred_motions = pred_latents.permute(0, 2, 3, 1)
|
| 534 |
+
pred_motions = val_loader.dataset.inv_transform(pred_motions.detach().cpu().numpy(), train_mean, train_std)
|
| 535 |
+
if is_control:
|
| 536 |
+
# foot skate
|
| 537 |
+
skate_ratio, skate_vel = calculate_skating_ratio(torch.from_numpy(pred_motions.transpose(0,2,3,1))) # [batch_size]
|
| 538 |
+
skate_ratio_sum += skate_ratio.sum()
|
| 539 |
+
# control errors
|
| 540 |
+
hint = motion_gt * mask_hint
|
| 541 |
+
for i, (mot, h, mask) in enumerate(zip(pred_motions, hint, mask_hint)):
|
| 542 |
+
control_error = control_l2(np.expand_dims(mot, axis=0), np.expand_dims(h, axis=0),
|
| 543 |
+
np.expand_dims(mask, axis=0))
|
| 544 |
+
mean_error = control_error.sum() / mask.sum()
|
| 545 |
+
dist_sum += mean_error
|
| 546 |
+
control_error = control_error.reshape(-1)
|
| 547 |
+
mask = mask.reshape(-1)
|
| 548 |
+
err_np = calculate_trajectory_error(control_error, mean_error, mask)
|
| 549 |
+
traj_err.append(err_np)
|
| 550 |
+
|
| 551 |
+
pred_motionss = []
|
| 552 |
+
for j in range(bs):
|
| 553 |
+
pred_motionss.append(back_process(pred_motions[j], is_mesh=is_mesh))
|
| 554 |
+
pred_motionss = np.stack(pred_motionss, axis=0)
|
| 555 |
+
pred_motions = val_loader.dataset.transform(pred_motionss)
|
| 556 |
+
(et_pred, em_pred), (et_pred_clip, em_pred_clip) = eval_wrapper.get_co_embeddings(word_embeddings,
|
| 557 |
+
pos_one_hots, sent_len,
|
| 558 |
+
clip_text,
|
| 559 |
+
torch.from_numpy(pred_motions).to(device),
|
| 560 |
+
m_length-1)
|
| 561 |
+
batch_clip_score_pred = 0
|
| 562 |
+
for j in range(bs):
|
| 563 |
+
single_em = em_pred_clip[j]
|
| 564 |
+
single_et = et_pred_clip[j]
|
| 565 |
+
clip_score = (single_em @ single_et.T).item()
|
| 566 |
+
batch_clip_score_pred += clip_score
|
| 567 |
+
clip_score_real += batch_clip_score_pred
|
| 568 |
+
|
| 569 |
+
pose = pose.cuda().float()
|
| 570 |
+
(et, em), (et_clip, em_clip) = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, clip_text,
|
| 571 |
+
pose.clone(), m_length-1)
|
| 572 |
+
batch_clip_score = 0
|
| 573 |
+
for j in range(bs):
|
| 574 |
+
single_em = em_clip[j]
|
| 575 |
+
single_et = et_clip[j]
|
| 576 |
+
clip_score = (single_em @ single_et.T).item()
|
| 577 |
+
batch_clip_score += clip_score
|
| 578 |
+
clip_score_gt += batch_clip_score
|
| 579 |
+
motion_annotation_list.append(em)
|
| 580 |
+
motion_pred_list.append(em_pred)
|
| 581 |
+
|
| 582 |
+
temp_R = calculate_R_precision(et.cpu().numpy(), em.cpu().numpy(), top_k=3, sum_all=True)
|
| 583 |
+
temp_match = euclidean_distance_matrix(et.cpu().numpy(), em.cpu().numpy()).trace()
|
| 584 |
+
R_precision_real += temp_R
|
| 585 |
+
matching_score_real += temp_match
|
| 586 |
+
temp_R = calculate_R_precision(et_pred.cpu().numpy(), em_pred.cpu().numpy(), top_k=3, sum_all=True)
|
| 587 |
+
temp_match = euclidean_distance_matrix(et_pred.cpu().numpy(), em_pred.cpu().numpy()).trace()
|
| 588 |
+
R_precision += temp_R
|
| 589 |
+
matching_score_pred += temp_match
|
| 590 |
+
|
| 591 |
+
nb_sample += bs
|
| 592 |
+
|
| 593 |
+
motion_annotation_np = torch.cat(motion_annotation_list, dim=0).cpu().numpy()
|
| 594 |
+
motion_pred_np = torch.cat(motion_pred_list, dim=0).cpu().numpy()
|
| 595 |
+
gt_mu, gt_cov = calculate_activation_statistics(motion_annotation_np)
|
| 596 |
+
mu, cov = calculate_activation_statistics(motion_pred_np)
|
| 597 |
+
|
| 598 |
+
diversity_real = calculate_diversity(motion_annotation_np, 300 if nb_sample > 300 else 100)
|
| 599 |
+
diversity = calculate_diversity(motion_pred_np, 300 if nb_sample > 300 else 100)
|
| 600 |
+
|
| 601 |
+
R_precision_real = R_precision_real / nb_sample
|
| 602 |
+
R_precision = R_precision / nb_sample
|
| 603 |
+
|
| 604 |
+
clip_score_real = clip_score_real / nb_sample
|
| 605 |
+
clip_score_gt = clip_score_gt / nb_sample
|
| 606 |
+
|
| 607 |
+
matching_score_real = matching_score_real / nb_sample
|
| 608 |
+
matching_score_pred = matching_score_pred / nb_sample
|
| 609 |
+
if is_control:
|
| 610 |
+
# l2 dist
|
| 611 |
+
dist_mean = dist_sum / nb_sample
|
| 612 |
+
|
| 613 |
+
# Skating evaluation
|
| 614 |
+
skating_score = skate_ratio_sum / nb_sample
|
| 615 |
+
|
| 616 |
+
### For trajectory evaluation from GMD ###
|
| 617 |
+
traj_err = np.stack(traj_err).mean(0)
|
| 618 |
+
|
| 619 |
+
if cal_mm:
|
| 620 |
+
motion_multimodality = torch.cat(motion_multimodality, dim=0).cpu().numpy()
|
| 621 |
+
multimodality = calculate_multimodality(motion_multimodality, 10)
|
| 622 |
+
|
| 623 |
+
fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)
|
| 624 |
+
|
| 625 |
+
msg = (f"--> \t Eva. Ep/Re {ep} :, FID. {fid:.4f}, Diversity Real. {diversity_real:.4f}, Diversity."
|
| 626 |
+
f" {diversity:.4f}, R_precision_real. {R_precision_real}, R_precision. {R_precision},"
|
| 627 |
+
f" matching_score_real. {matching_score_real}, matching_score_pred. {matching_score_pred}"
|
| 628 |
+
f" multimodality. {multimodality:.4f}, clip score. {clip_score_real}"
|
| 629 |
+
+ (f" foot skating. {skating_score:.4f}, traj error. {traj_err[1].item():.4f}, loc error. {traj_err[3].item():.4f}, avg error. {traj_err[4].item():.4f}"
|
| 630 |
+
if is_control else ""))
|
| 631 |
+
print(msg)
|
| 632 |
+
|
| 633 |
+
if draw:
|
| 634 |
+
writer.add_scalar('./Test/FID', fid, ep)
|
| 635 |
+
writer.add_scalar('./Test/Diversity', diversity, ep)
|
| 636 |
+
writer.add_scalar('./Test/top1', R_precision[0], ep)
|
| 637 |
+
writer.add_scalar('./Test/top2', R_precision[1], ep)
|
| 638 |
+
writer.add_scalar('./Test/top3', R_precision[2], ep)
|
| 639 |
+
writer.add_scalar('./Test/matching_score', matching_score_pred, ep)
|
| 640 |
+
writer.add_scalar('./Test/clip_score', clip_score_real, ep)
|
| 641 |
+
|
| 642 |
+
|
| 643 |
+
if fid < best_fid:
|
| 644 |
+
msg = f"--> --> \t FID Improved from {best_fid:.5f} to {fid:.5f} !!!"
|
| 645 |
+
if draw: print(msg)
|
| 646 |
+
best_fid, best_ep = fid, ep
|
| 647 |
+
save=True
|
| 648 |
+
|
| 649 |
+
|
| 650 |
+
if matching_score_pred < best_matching:
|
| 651 |
+
msg = f"--> --> \t matching_score Improved from {best_matching:.5f} to {matching_score_pred:.5f} !!!"
|
| 652 |
+
if draw: print(msg)
|
| 653 |
+
best_matching = matching_score_pred
|
| 654 |
+
|
| 655 |
+
if abs(diversity_real - diversity) < abs(diversity_real - best_div):
|
| 656 |
+
msg = f"--> --> \t Diversity Improved from {best_div:.5f} to {diversity:.5f} !!!"
|
| 657 |
+
if draw: print(msg)
|
| 658 |
+
best_div = diversity
|
| 659 |
+
|
| 660 |
+
if R_precision[0] > best_top1:
|
| 661 |
+
msg = f"--> --> \t Top1 Improved from {best_top1:.4f} to {R_precision[0]:.4f} !!!"
|
| 662 |
+
if draw: print(msg)
|
| 663 |
+
best_top1 = R_precision[0]
|
| 664 |
+
|
| 665 |
+
if R_precision[1] > best_top2:
|
| 666 |
+
msg = f"--> --> \t Top2 Improved from {best_top2:.4f} to {R_precision[1]:.4f} !!!"
|
| 667 |
+
if draw: print(msg)
|
| 668 |
+
best_top2 = R_precision[1]
|
| 669 |
+
|
| 670 |
+
if R_precision[2] > best_top3:
|
| 671 |
+
msg = f"--> --> \t Top3 Improved from {best_top3:.4f} to {R_precision[2]:.4f} !!!"
|
| 672 |
+
if draw: print(msg)
|
| 673 |
+
best_top3 = R_precision[2]
|
| 674 |
+
|
| 675 |
+
if clip_score_real > clip_score_old:
|
| 676 |
+
msg = f"--> --> \t CLIP-score Improved from {clip_score_old:.4f} to {clip_score_real:.4f} !!!"
|
| 677 |
+
if draw: print(msg)
|
| 678 |
+
clip_score_old = clip_score_real
|
| 679 |
+
|
| 680 |
+
if cal_mm:
|
| 681 |
+
return best_fid, best_div, best_top1, best_top2, best_top3, best_matching, multimodality, clip_score_old, writer, save
|
| 682 |
+
else:
|
| 683 |
+
if is_control:
|
| 684 |
+
return best_fid, best_div, best_top1, best_top2, best_top3, best_matching, 0, clip_score_old, writer, save, dist_mean, skating_score, traj_err
|
| 685 |
+
else:
|
| 686 |
+
return best_fid, best_div, best_top1, best_top2, best_top3, best_matching, 0, clip_score_old, writer, save, None, None, None
|
| 687 |
+
|
| 688 |
+
#################################################################################
|
| 689 |
+
# Util Functions #
|
| 690 |
+
#################################################################################
|
| 691 |
+
def eval_decorator(fn):
|
| 692 |
+
def inner(model, *args, **kwargs):
|
| 693 |
+
was_training = model.training
|
| 694 |
+
model.eval()
|
| 695 |
+
out = fn(model, *args, **kwargs)
|
| 696 |
+
model.train(was_training)
|
| 697 |
+
return out
|
| 698 |
+
return inner
|
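A minimal usage sketch of eval_decorator (illustration only, not part of the committed file; `generate` is a hypothetical method): the wrapped function must take the model as its first argument, it runs with the model in eval() mode, and the model's previous train/eval state is restored afterwards.
@eval_decorator
def sample(model, text):
    # hypothetical call used only to illustrate the decorator; any model method works here
    return model.generate(text)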
| 699 |
+
|
| 700 |
+
#################################################################################
|
| 701 |
+
# Metrics #
|
| 702 |
+
#################################################################################
|
| 703 |
+
def calculate_mpjpe(gt_joints, pred_joints):
|
| 704 |
+
"""
|
| 705 |
+
gt_joints: num_poses x num_joints(22) x 3
|
| 706 |
+
pred_joints: num_poses x num_joints(22) x 3
|
| 707 |
+
(obtained from recover_from_ric())
|
| 708 |
+
"""
|
| 709 |
+
assert gt_joints.shape == pred_joints.shape, f"GT shape: {gt_joints.shape}, pred shape: {pred_joints.shape}"
|
| 710 |
+
|
| 711 |
+
# Align by root (pelvis)
|
| 712 |
+
pelvis = gt_joints[:, [0]].mean(1)
|
| 713 |
+
gt_joints = gt_joints - torch.unsqueeze(pelvis, dim=1)
|
| 714 |
+
pelvis = pred_joints[:, [0]].mean(1)
|
| 715 |
+
pred_joints = pred_joints - torch.unsqueeze(pelvis, dim=1)
|
| 716 |
+
|
| 717 |
+
# Compute MPJPE
|
| 718 |
+
mpjpe = torch.linalg.norm(pred_joints - gt_joints, dim=-1) # num_poses x num_joints=22
|
| 719 |
+
mpjpe_seq = mpjpe.mean(-1) # num_poses
|
| 720 |
+
|
| 721 |
+
return mpjpe_seq
|
| 722 |
+
|
| 723 |
+
# (X - X_train)*(X - X_train) = -2X*X_train + X*X + X_train*X_train
|
| 724 |
+
def euclidean_distance_matrix(matrix1, matrix2):
|
| 725 |
+
"""
|
| 726 |
+
Params:
|
| 727 |
+
-- matrix1: N1 x D
|
| 728 |
+
-- matrix2: N2 x D
|
| 729 |
+
Returns:
|
| 730 |
+
-- dist: N1 x N2
|
| 731 |
+
dist[i, j] == distance(matrix1[i], matrix2[j])
|
| 732 |
+
"""
|
| 733 |
+
assert matrix1.shape[1] == matrix2.shape[1]
|
| 734 |
+
d1 = -2 * np.dot(matrix1, matrix2.T) # shape (num_test, num_train)
|
| 735 |
+
d2 = np.sum(np.square(matrix1), axis=1, keepdims=True) # shape (num_test, 1)
|
| 736 |
+
d3 = np.sum(np.square(matrix2), axis=1) # shape (num_train, )
|
| 737 |
+
dists = np.sqrt(d1 + d2 + d3) # broadcasting
|
| 738 |
+
return dists
|
| 739 |
+
|
| 740 |
+
def calculate_top_k(mat, top_k):
|
| 741 |
+
size = mat.shape[0]
|
| 742 |
+
gt_mat = np.expand_dims(np.arange(size), 1).repeat(size, 1)
|
| 743 |
+
bool_mat = (mat == gt_mat)
|
| 744 |
+
correct_vec = False
|
| 745 |
+
top_k_list = []
|
| 746 |
+
for i in range(top_k):
|
| 747 |
+
# print(correct_vec, bool_mat[:, i])
|
| 748 |
+
correct_vec = (correct_vec | bool_mat[:, i])
|
| 749 |
+
# print(correct_vec)
|
| 750 |
+
top_k_list.append(correct_vec[:, None])
|
| 751 |
+
top_k_mat = np.concatenate(top_k_list, axis=1)
|
| 752 |
+
return top_k_mat
|
| 753 |
+
|
| 754 |
+
|
| 755 |
+
def calculate_R_precision(embedding1, embedding2, top_k, sum_all=False):
|
| 756 |
+
dist_mat = euclidean_distance_matrix(embedding1, embedding2)
|
| 757 |
+
argmax = np.argsort(dist_mat, axis=1)
|
| 758 |
+
top_k_mat = calculate_top_k(argmax, top_k)
|
| 759 |
+
if sum_all:
|
| 760 |
+
return top_k_mat.sum(axis=0)
|
| 761 |
+
else:
|
| 762 |
+
return top_k_mat
|
| 763 |
+
|
| 764 |
+
|
| 765 |
+
def calculate_matching_score(embedding1, embedding2, sum_all=False):
|
| 766 |
+
assert len(embedding1.shape) == 2
|
| 767 |
+
assert embedding1.shape[0] == embedding2.shape[0]
|
| 768 |
+
assert embedding1.shape[1] == embedding2.shape[1]
|
| 769 |
+
|
| 770 |
+
dist = linalg.norm(embedding1 - embedding2, axis=1)
|
| 771 |
+
if sum_all:
|
| 772 |
+
return dist.sum(axis=0)
|
| 773 |
+
else:
|
| 774 |
+
return dist
|
| 775 |
+
|
| 776 |
+
|
| 777 |
+
|
| 778 |
+
def calculate_activation_statistics(activations):
|
| 779 |
+
"""
|
| 780 |
+
Params:
|
| 781 |
+
-- activation: num_samples x dim_feat
|
| 782 |
+
Returns:
|
| 783 |
+
-- mu: dim_feat
|
| 784 |
+
-- sigma: dim_feat x dim_feat
|
| 785 |
+
"""
|
| 786 |
+
mu = np.mean(activations, axis=0)
|
| 787 |
+
cov = np.cov(activations, rowvar=False)
|
| 788 |
+
return mu, cov
|
| 789 |
+
|
| 790 |
+
|
| 791 |
+
def calculate_diversity(activation, diversity_times):
|
| 792 |
+
assert len(activation.shape) == 2
|
| 793 |
+
assert activation.shape[0] > diversity_times
|
| 794 |
+
num_samples = activation.shape[0]
|
| 795 |
+
|
| 796 |
+
first_indices = np.random.choice(num_samples, diversity_times, replace=False)
|
| 797 |
+
second_indices = np.random.choice(num_samples, diversity_times, replace=False)
|
| 798 |
+
dist = linalg.norm(activation[first_indices] - activation[second_indices], axis=1)
|
| 799 |
+
return dist.mean()
|
| 800 |
+
|
| 801 |
+
|
| 802 |
+
def calculate_multimodality(activation, multimodality_times):
|
| 803 |
+
assert len(activation.shape) == 3
|
| 804 |
+
assert activation.shape[1] > multimodality_times
|
| 805 |
+
num_per_sent = activation.shape[1]
|
| 806 |
+
|
| 807 |
+
first_dices = np.random.choice(num_per_sent, multimodality_times, replace=False)
|
| 808 |
+
second_dices = np.random.choice(num_per_sent, multimodality_times, replace=False)
|
| 809 |
+
dist = linalg.norm(activation[:, first_dices] - activation[:, second_dices], axis=2)
|
| 810 |
+
return dist.mean()
|
| 811 |
+
|
| 812 |
+
|
| 813 |
+
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
|
| 814 |
+
"""Numpy implementation of the Frechet Distance.
|
| 815 |
+
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
|
| 816 |
+
and X_2 ~ N(mu_2, C_2) is
|
| 817 |
+
d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
|
| 818 |
+
Stable version by Dougal J. Sutherland.
|
| 819 |
+
Params:
|
| 820 |
+
-- mu1 : Numpy array containing the activations of a layer of the
|
| 821 |
+
inception net (like returned by the function 'get_predictions')
|
| 822 |
+
for generated samples.
|
| 823 |
+
-- mu2 : The sample mean over activations, precalculated on an
|
| 824 |
+
representative data set.
|
| 825 |
+
-- sigma1: The covariance matrix over activations for generated samples.
|
| 826 |
+
-- sigma2: The covariance matrix over activations, precalculated on an
|
| 827 |
+
representative data set.
|
| 828 |
+
Returns:
|
| 829 |
+
-- : The Frechet Distance.
|
| 830 |
+
"""
|
| 831 |
+
|
| 832 |
+
mu1 = np.atleast_1d(mu1)
|
| 833 |
+
mu2 = np.atleast_1d(mu2)
|
| 834 |
+
|
| 835 |
+
sigma1 = np.atleast_2d(sigma1)
|
| 836 |
+
sigma2 = np.atleast_2d(sigma2)
|
| 837 |
+
|
| 838 |
+
assert mu1.shape == mu2.shape, \
|
| 839 |
+
'Training and test mean vectors have different lengths'
|
| 840 |
+
assert sigma1.shape == sigma2.shape, \
|
| 841 |
+
'Training and test covariances have different dimensions'
|
| 842 |
+
|
| 843 |
+
diff = mu1 - mu2
|
| 844 |
+
|
| 845 |
+
# Product might be almost singular
|
| 846 |
+
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
|
| 847 |
+
if not np.isfinite(covmean).all():
|
| 848 |
+
msg = ('fid calculation produces singular product; '
|
| 849 |
+
'adding %s to diagonal of cov estimates') % eps
|
| 850 |
+
print(msg)
|
| 851 |
+
offset = np.eye(sigma1.shape[0]) * eps
|
| 852 |
+
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
|
| 853 |
+
|
| 854 |
+
# Numerical error might give slight imaginary component
|
| 855 |
+
if np.iscomplexobj(covmean):
|
| 856 |
+
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
|
| 857 |
+
m = np.max(np.abs(covmean.imag))
|
| 858 |
+
raise ValueError('Imaginary component {}'.format(m))
|
| 859 |
+
covmean = covmean.real
|
| 860 |
+
|
| 861 |
+
tr_covmean = np.trace(covmean)
|
| 862 |
+
|
| 863 |
+
return (diff.dot(diff) + np.trace(sigma1) +
|
| 864 |
+
np.trace(sigma2) - 2 * tr_covmean)
|
| 865 |
+
|
| 866 |
+
|
| 867 |
+
# directly from omnicontrol
|
| 868 |
+
def calculate_skating_ratio(motions):
|
| 869 |
+
thresh_height = 0.05  # 5 cm
|
| 870 |
+
fps = 20.0
|
| 871 |
+
thresh_vel = 0.50 # 20 cm /s
|
| 872 |
+
avg_window = 5 # frames
|
| 873 |
+
|
| 874 |
+
batch_size = motions.shape[0]
|
| 875 |
+
# 10 left, 11 right foot. XZ plane, y up
|
| 876 |
+
# motions [bs, 22, 3, max_len]
|
| 877 |
+
verts_feet = motions[:, [10, 11], :, :].detach().cpu().numpy() # [bs, 2, 3, max_len]
|
| 878 |
+
verts_feet_plane_vel = np.linalg.norm(verts_feet[:, :, [0, 2], 1:] - verts_feet[:, :, [0, 2], :-1],
|
| 879 |
+
axis=2) * fps # [bs, 2, max_len-1]
|
| 880 |
+
# [bs, 2, max_len-1]
|
| 881 |
+
vel_avg = uniform_filter1d(verts_feet_plane_vel, axis=-1, size=avg_window, mode='constant', origin=0)
|
| 882 |
+
|
| 883 |
+
verts_feet_height = verts_feet[:, :, 1, :] # [bs, 2, max_len]
|
| 884 |
+
# If feet touch the ground in adjacent frames
|
| 885 |
+
feet_contact = np.logical_and((verts_feet_height[:, :, :-1] < thresh_height),
|
| 886 |
+
(verts_feet_height[:, :, 1:] < thresh_height)) # [bs, 2, max_len - 1]
|
| 887 |
+
# skate velocity
|
| 888 |
+
skate_vel = feet_contact * vel_avg
|
| 889 |
+
|
| 890 |
+
# the foot must be skating in the current frame
|
| 891 |
+
skating = np.logical_and(feet_contact, (verts_feet_plane_vel > thresh_vel))
|
| 892 |
+
# and also over the averaging window of frames
|
| 893 |
+
skating = np.logical_and(skating, (vel_avg > thresh_vel))
|
| 894 |
+
|
| 895 |
+
# Both feet slide
|
| 896 |
+
skating = np.logical_or(skating[:, 0, :], skating[:, 1, :]) # [bs, max_len -1]
|
| 897 |
+
skating_ratio = np.sum(skating, axis=1) / skating.shape[1]
|
| 898 |
+
|
| 899 |
+
return skating_ratio, skate_vel
|
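A hedged usage note (not part of the committed file): calculate_skating_ratio expects joints laid out as (batch, 22, 3, seq_len); the evaluation loop above reaches that layout by transposing its (batch, seq_len, 22, 3) predictions. A minimal sketch, assuming pred_joints is such a torch tensor:
skate_ratio, skate_vel = calculate_skating_ratio(pred_joints.permute(0, 2, 3, 1))  # (batch, seq_len, 22, 3) -> (batch, 22, 3, seq_len)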
| 900 |
+
|
| 901 |
+
# directly from omnicontrol
|
| 902 |
+
def control_l2(motion, hint, hint_mask):
|
| 903 |
+
# motion: b, seq, 22, 3
|
| 904 |
+
# hint: b, seq, 22, 1
|
| 905 |
+
loss = np.linalg.norm((motion - hint) * hint_mask, axis=-1)
|
| 906 |
+
return loss
|
| 907 |
+
|
| 908 |
+
# directly from omnicontrol
|
| 909 |
+
def calculate_trajectory_error(dist_error, mean_err_traj, mask, strict=True):
|
| 910 |
+
''' dist_error: error for each keyframe, in metres
|
| 911 |
+
Two threshold: 20 cm and 50 cm.
|
| 912 |
+
If the mean error over the sequence exceeds the threshold, it fails.
|
| 913 |
+
return: traj_fail(0.2), traj_fail(0.5), all_kps_fail(0.2), all_kps_fail(0.5), all_mean_err.
|
| 914 |
+
All metrics are already averaged.
|
| 915 |
+
'''
|
| 916 |
+
# mean_err_traj = dist_error.mean(1)
|
| 917 |
+
if strict:
|
| 918 |
+
# Traj fails if any of the key frame fails
|
| 919 |
+
traj_fail_02 = 1.0 - (dist_error <= 0.2).all()
|
| 920 |
+
traj_fail_05 = 1.0 - (dist_error <= 0.5).all()
|
| 921 |
+
else:
|
| 922 |
+
# Traj fails if the mean error of all keyframes more than the threshold
|
| 923 |
+
traj_fail_02 = (mean_err_traj > 0.2)
|
| 924 |
+
traj_fail_05 = (mean_err_traj > 0.5)
|
| 925 |
+
all_fail_02 = (dist_error > 0.2).sum() / mask.sum()
|
| 926 |
+
all_fail_05 = (dist_error > 0.5).sum() / mask.sum()
|
| 927 |
+
|
| 928 |
+
return np.array([traj_fail_02, traj_fail_05, all_fail_02, all_fail_05, dist_error.sum() / mask.sum()])
|
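A minimal sketch of how the statistics helpers above combine into the FID reported by the evaluation loop (illustration only; the random arrays stand in for 512-d evaluator embeddings):
import numpy as np

rng = np.random.default_rng(0)
gt_embeddings = rng.standard_normal((1000, 512))    # ground-truth motion embeddings, (num_samples, dim_feat)
pred_embeddings = rng.standard_normal((1000, 512))  # generated motion embeddings, same layout

gt_mu, gt_cov = calculate_activation_statistics(gt_embeddings)
mu, cov = calculate_activation_statistics(pred_embeddings)
fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)
print(f"FID: {fid:.4f}")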
utils/evaluators.py
ADDED
|
@@ -0,0 +1,392 @@
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import numpy as np
|
| 5 |
+
import math
|
| 6 |
+
from torch.nn.utils.rnn import pack_padded_sequence
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from utils.glove import POS_enumerator
|
| 9 |
+
import clip
|
| 10 |
+
|
| 11 |
+
#################################################################################
|
| 12 |
+
# Evaluators #
|
| 13 |
+
#################################################################################
|
| 14 |
+
def build_evaluators(dim_pose, dataset_name, dim_movement_enc_hidden, dim_movement_latent, dim_word, dim_pos_ohot, dim_text_hidden,
|
| 15 |
+
dim_coemb_hidden, dim_motion_hidden, checkpoints_dir, device):
|
| 16 |
+
movement_enc = MovementConvEncoder(dim_pose, dim_movement_enc_hidden, dim_movement_latent)
|
| 17 |
+
text_enc = TextEncoderBiGRUCo(word_size=dim_word,
|
| 18 |
+
pos_size=dim_pos_ohot,
|
| 19 |
+
hidden_size=dim_text_hidden,
|
| 20 |
+
output_size=dim_coemb_hidden,
|
| 21 |
+
device=device)
|
| 22 |
+
|
| 23 |
+
motion_enc = MotionEncoderBiGRUCo(input_size=dim_movement_latent,
|
| 24 |
+
hidden_size=dim_motion_hidden,
|
| 25 |
+
output_size=dim_coemb_hidden,
|
| 26 |
+
device=device)
|
| 27 |
+
contrast_model = MotionCLIP(dim_pose)
|
| 28 |
+
|
| 29 |
+
checkpoint = torch.load(os.path.join(checkpoints_dir, dataset_name, 'text_mot_match', 'model', 'finest.tar'),
|
| 30 |
+
map_location=device)
|
| 31 |
+
checkpoint_clip = torch.load(os.path.join(checkpoints_dir, dataset_name, 'text_mot_match_clip', 'model', 'finest.tar'),
|
| 32 |
+
map_location=device)
|
| 33 |
+
movement_enc.load_state_dict(checkpoint['movement_encoder'])
|
| 34 |
+
text_enc.load_state_dict(checkpoint['text_encoder'])
|
| 35 |
+
motion_enc.load_state_dict(checkpoint['motion_encoder'])
|
| 36 |
+
contrast_model.load_state_dict(checkpoint_clip['contrast_model'])
|
| 37 |
+
print('Loading Evaluators')
|
| 38 |
+
return text_enc, motion_enc, movement_enc, contrast_model
|
| 39 |
+
|
| 40 |
+
class Evaluators(object):
|
| 41 |
+
|
| 42 |
+
def __init__(self, dataset_name, device):
|
| 43 |
+
if dataset_name == 't2m':
|
| 44 |
+
dim_pose = 67
|
| 45 |
+
elif dataset_name == 'kit':
|
| 46 |
+
dim_pose = 64
|
| 47 |
+
else:
|
| 48 |
+
raise KeyError('Dataset not Recognized!!!')
|
| 49 |
+
|
| 50 |
+
dim_word = 300
|
| 51 |
+
dim_pos_ohot = len(POS_enumerator)
|
| 52 |
+
dim_motion_hidden = 1024
|
| 53 |
+
dim_movement_enc_hidden = 512
|
| 54 |
+
dim_movement_latent = 512
|
| 55 |
+
dim_text_hidden = 512
|
| 56 |
+
dim_coemb_hidden = 512
|
| 57 |
+
checkpoints_dir = 'checkpoints'
|
| 58 |
+
self.unit_length=4
|
| 59 |
+
|
| 60 |
+
self.text_encoder, self.motion_encoder, self.movement_encoder, self.contrast_model \
|
| 61 |
+
= build_evaluators(dim_pose, dataset_name, dim_movement_enc_hidden, dim_movement_latent, dim_word,
|
| 62 |
+
dim_pos_ohot, dim_text_hidden, dim_coemb_hidden, dim_motion_hidden, checkpoints_dir, device)
|
| 63 |
+
self.device = device
|
| 64 |
+
|
| 65 |
+
self.text_encoder.to(device)
|
| 66 |
+
self.motion_encoder.to(device)
|
| 67 |
+
self.movement_encoder.to(device)
|
| 68 |
+
self.contrast_model.to(device)
|
| 69 |
+
|
| 70 |
+
self.text_encoder.eval()
|
| 71 |
+
self.motion_encoder.eval()
|
| 72 |
+
self.movement_encoder.eval()
|
| 73 |
+
self.contrast_model.eval()
|
| 74 |
+
|
| 75 |
+
def get_co_embeddings(self, word_embs, pos_ohot, cap_lens, captions, motions, m_lens):
|
| 76 |
+
with torch.no_grad():
|
| 77 |
+
word_embs = word_embs.detach().to(self.device).float()
|
| 78 |
+
pos_ohot = pos_ohot.detach().to(self.device).float()
|
| 79 |
+
motions = motions.detach().to(self.device).float()
|
| 80 |
+
|
| 81 |
+
'''clip based'''
|
| 82 |
+
clip_em = self.contrast_model.encode_motion(motions.clone(), m_lens)
|
| 83 |
+
clip_et = self.contrast_model.encode_text(captions)
|
| 84 |
+
clip_em = clip_em / clip_em.norm(dim=1, keepdim=True)
|
| 85 |
+
clip_et = clip_et / clip_et.norm(dim=1, keepdim=True)
|
| 86 |
+
|
| 87 |
+
'''original architecture'''
|
| 88 |
+
align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
|
| 89 |
+
motions = motions[align_idx]
|
| 90 |
+
m_lens = m_lens[align_idx]
|
| 91 |
+
|
| 92 |
+
movements = self.movement_encoder(motions).detach()
|
| 93 |
+
m_lens = m_lens // self.unit_length
|
| 94 |
+
motion_embedding = self.motion_encoder(movements, m_lens)
|
| 95 |
+
|
| 96 |
+
text_embedding = self.text_encoder(word_embs, pos_ohot, cap_lens)
|
| 97 |
+
text_embedding = text_embedding[align_idx]
|
| 98 |
+
return (text_embedding, motion_embedding), (clip_et, clip_em)
|
| 99 |
+
|
| 100 |
+
def get_motion_embeddings(self, motions, m_lens):
|
| 101 |
+
with torch.no_grad():
|
| 102 |
+
motions = motions.detach().to(self.device).float()
|
| 103 |
+
'''clip based'''
|
| 104 |
+
clip_em = self.contrast_model.encode_motion(motions.clone(), m_lens)
|
| 105 |
+
clip_em = clip_em / clip_em.norm(dim=1, keepdim=True)
|
| 106 |
+
|
| 107 |
+
'''original architecture'''
|
| 108 |
+
align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
|
| 109 |
+
motions = motions[align_idx]
|
| 110 |
+
m_lens = m_lens[align_idx]
|
| 111 |
+
|
| 112 |
+
movements = self.movement_encoder(motions).detach()
|
| 113 |
+
m_lens = m_lens // self.unit_length
|
| 114 |
+
motion_embedding = self.motion_encoder(movements, m_lens)
|
| 115 |
+
return motion_embedding, clip_em
|
| 116 |
+
|
| 117 |
+
#################################################################################
|
| 118 |
+
# Inner Architectures #
|
| 119 |
+
#################################################################################
|
| 120 |
+
def init_weight(m):
|
| 121 |
+
if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose1d):
|
| 122 |
+
nn.init.xavier_normal_(m.weight)
|
| 123 |
+
if m.bias is not None:
|
| 124 |
+
nn.init.constant_(m.bias, 0)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class PositionalEncoding(nn.Module):
|
| 128 |
+
def __init__(self, d_model, max_len=300):
|
| 129 |
+
super(PositionalEncoding, self).__init__()
|
| 130 |
+
pe = torch.zeros(max_len, d_model)
|
| 131 |
+
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
| 132 |
+
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
| 133 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
| 134 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
| 135 |
+
self.register_buffer('pe', pe)
|
| 136 |
+
|
| 137 |
+
def forward(self, pos):
|
| 138 |
+
return self.pe[pos]
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class PositionalEncodingCLIP(nn.Module):
|
| 142 |
+
def __init__(self, d_model, dropout=0.0, max_len=5000):
|
| 143 |
+
super(PositionalEncodingCLIP, self).__init__()
|
| 144 |
+
self.dropout = nn.Dropout(p=dropout)
|
| 145 |
+
pe = torch.zeros(max_len, d_model)
|
| 146 |
+
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
| 147 |
+
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
|
| 148 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
| 149 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
| 150 |
+
self.register_buffer('pe', pe)
|
| 151 |
+
|
| 152 |
+
def forward(self, x):
|
| 153 |
+
x = x + self.pe[:x.shape[1], :].unsqueeze(0)
|
| 154 |
+
return self.dropout(x)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def no_grad(nets):
|
| 158 |
+
if not isinstance(nets, list):
|
| 159 |
+
nets = [nets]
|
| 160 |
+
for net in nets:
|
| 161 |
+
if net is not None:
|
| 162 |
+
for param in net.parameters():
|
| 163 |
+
param.requires_grad = False
|
| 164 |
+
|
| 165 |
+
def lengths_to_mask(lengths, max_len):
|
| 166 |
+
mask = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths.unsqueeze(1)
|
| 167 |
+
return mask
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
class MovementConvEncoder(nn.Module):
|
| 171 |
+
def __init__(self, input_size, hidden_size, output_size):
|
| 172 |
+
super(MovementConvEncoder, self).__init__()
|
| 173 |
+
self.main = nn.Sequential(
|
| 174 |
+
nn.Conv1d(input_size, hidden_size, 4, 2, 1),
|
| 175 |
+
nn.Dropout(0.2, inplace=True),
|
| 176 |
+
nn.LeakyReLU(0.2, inplace=True),
|
| 177 |
+
nn.Conv1d(hidden_size, output_size, 4, 2, 1),
|
| 178 |
+
nn.Dropout(0.2, inplace=True),
|
| 179 |
+
nn.LeakyReLU(0.2, inplace=True),
|
| 180 |
+
)
|
| 181 |
+
self.out_net = nn.Linear(output_size, output_size)
|
| 182 |
+
self.main.apply(init_weight)
|
| 183 |
+
self.out_net.apply(init_weight)
|
| 184 |
+
|
| 185 |
+
def forward(self, inputs):
|
| 186 |
+
inputs = inputs.permute(0, 2, 1)
|
| 187 |
+
outputs = self.main(inputs).permute(0, 2, 1)
|
| 188 |
+
return self.out_net(outputs)
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
class MovementConvDecoder(nn.Module):
|
| 192 |
+
def __init__(self, input_size, hidden_size, output_size):
|
| 193 |
+
super(MovementConvDecoder, self).__init__()
|
| 194 |
+
self.main = nn.Sequential(
|
| 195 |
+
nn.ConvTranspose1d(input_size, hidden_size, 4, 2, 1),
|
| 196 |
+
nn.LeakyReLU(0.2, inplace=True),
|
| 197 |
+
nn.ConvTranspose1d(hidden_size, output_size, 4, 2, 1),
|
| 198 |
+
nn.LeakyReLU(0.2, inplace=True),
|
| 199 |
+
)
|
| 200 |
+
self.out_net = nn.Linear(output_size, output_size)
|
| 201 |
+
|
| 202 |
+
self.main.apply(init_weight)
|
| 203 |
+
self.out_net.apply(init_weight)
|
| 204 |
+
|
| 205 |
+
def forward(self, inputs):
|
| 206 |
+
inputs = inputs.permute(0, 2, 1)
|
| 207 |
+
outputs = self.main(inputs).permute(0, 2, 1)
|
| 208 |
+
return self.out_net(outputs)
|
| 209 |
+
|
| 210 |
+
class TextEncoderBiGRUCo(nn.Module):
|
| 211 |
+
def __init__(self, word_size, pos_size, hidden_size, output_size, device):
|
| 212 |
+
super(TextEncoderBiGRUCo, self).__init__()
|
| 213 |
+
self.device = device
|
| 214 |
+
|
| 215 |
+
self.pos_emb = nn.Linear(pos_size, word_size)
|
| 216 |
+
self.input_emb = nn.Linear(word_size, hidden_size)
|
| 217 |
+
self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
|
| 218 |
+
self.output_net = nn.Sequential(
|
| 219 |
+
nn.Linear(hidden_size * 2, hidden_size),
|
| 220 |
+
nn.LayerNorm(hidden_size),
|
| 221 |
+
nn.LeakyReLU(0.2, inplace=True),
|
| 222 |
+
nn.Linear(hidden_size, output_size)
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
self.input_emb.apply(init_weight)
|
| 226 |
+
self.pos_emb.apply(init_weight)
|
| 227 |
+
self.output_net.apply(init_weight)
|
| 228 |
+
self.hidden_size = hidden_size
|
| 229 |
+
self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True))
|
| 230 |
+
|
| 231 |
+
def forward(self, word_embs, pos_onehot, cap_lens):
|
| 232 |
+
num_samples = word_embs.shape[0]
|
| 233 |
+
|
| 234 |
+
pos_embs = self.pos_emb(pos_onehot)
|
| 235 |
+
inputs = word_embs + pos_embs
|
| 236 |
+
input_embs = self.input_emb(inputs)
|
| 237 |
+
hidden = self.hidden.repeat(1, num_samples, 1)
|
| 238 |
+
|
| 239 |
+
cap_lens = cap_lens.data.tolist()
|
| 240 |
+
emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True)
|
| 241 |
+
|
| 242 |
+
gru_seq, gru_last = self.gru(emb, hidden)
|
| 243 |
+
|
| 244 |
+
gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
|
| 245 |
+
|
| 246 |
+
return self.output_net(gru_last)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
class MotionEncoderBiGRUCo(nn.Module):
|
| 250 |
+
def __init__(self, input_size, hidden_size, output_size, device):
|
| 251 |
+
super(MotionEncoderBiGRUCo, self).__init__()
|
| 252 |
+
self.device = device
|
| 253 |
+
|
| 254 |
+
self.input_emb = nn.Linear(input_size, hidden_size)
|
| 255 |
+
self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
|
| 256 |
+
self.output_net = nn.Sequential(
|
| 257 |
+
nn.Linear(hidden_size*2, hidden_size),
|
| 258 |
+
nn.LayerNorm(hidden_size),
|
| 259 |
+
nn.LeakyReLU(0.2, inplace=True),
|
| 260 |
+
nn.Linear(hidden_size, output_size)
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
self.input_emb.apply(init_weight)
|
| 264 |
+
self.output_net.apply(init_weight)
|
| 265 |
+
self.hidden_size = hidden_size
|
| 266 |
+
self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True))
|
| 267 |
+
|
| 268 |
+
def forward(self, inputs, m_lens):
|
| 269 |
+
num_samples = inputs.shape[0]
|
| 270 |
+
|
| 271 |
+
input_embs = self.input_emb(inputs)
|
| 272 |
+
hidden = self.hidden.repeat(1, num_samples, 1)
|
| 273 |
+
|
| 274 |
+
cap_lens = m_lens.data.tolist()
|
| 275 |
+
emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True)
|
| 276 |
+
|
| 277 |
+
gru_seq, gru_last = self.gru(emb, hidden)
|
| 278 |
+
|
| 279 |
+
gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
|
| 280 |
+
|
| 281 |
+
return self.output_net(gru_last)
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
class MotionEncoder(nn.Module):
|
| 285 |
+
def __init__(self, in_dim, latent_dim, ff_size, num_layers, num_heads, dropout, activation):
|
| 286 |
+
super().__init__()
|
| 287 |
+
self.input_feats = in_dim
|
| 288 |
+
self.latent_dim = latent_dim
|
| 289 |
+
self.ff_size = ff_size
|
| 290 |
+
self.num_layers = num_layers
|
| 291 |
+
self.num_heads = num_heads
|
| 292 |
+
self.dropout = dropout
|
| 293 |
+
self.activation = activation
|
| 294 |
+
|
| 295 |
+
self.query_token = nn.Parameter(torch.randn(1, self.latent_dim))
|
| 296 |
+
|
| 297 |
+
self.embed_motion = nn.Linear(self.input_feats, self.latent_dim)
|
| 298 |
+
self.sequence_pos_encoder = PositionalEncodingCLIP(self.latent_dim, self.dropout, max_len=2000)
|
| 299 |
+
|
| 300 |
+
seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim,
|
| 301 |
+
nhead=self.num_heads,
|
| 302 |
+
dim_feedforward=self.ff_size,
|
| 303 |
+
dropout=self.dropout,
|
| 304 |
+
activation=self.activation,)
|
| 305 |
+
self.transformer = nn.TransformerEncoder(seqTransEncoderLayer, num_layers=self.num_layers)
|
| 306 |
+
self.out_ln = nn.LayerNorm(self.latent_dim)
|
| 307 |
+
self.out = nn.Linear(self.latent_dim, 512)
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def forward(self, motion, padding_mask):
|
| 311 |
+
B, T, D = motion.shape
|
| 312 |
+
|
| 313 |
+
x_emb = self.embed_motion(motion)
|
| 314 |
+
|
| 315 |
+
emb = torch.cat([self.query_token[torch.zeros(B, dtype=torch.long, device=motion.device)][:,None], x_emb], dim=1)
|
| 316 |
+
|
| 317 |
+
padding_mask = torch.cat([torch.zeros_like(padding_mask[:, 0:1]), padding_mask], dim=1)
|
| 318 |
+
|
| 319 |
+
h = self.sequence_pos_encoder(emb)
|
| 320 |
+
h = h.permute(1, 0, 2)
|
| 321 |
+
h = self.transformer(h, src_key_padding_mask=padding_mask)
|
| 322 |
+
h = h.permute(1, 0, 2)
|
| 323 |
+
h = self.out_ln(h)
|
| 324 |
+
motion_emb = self.out(h[:,0])
|
| 325 |
+
|
| 326 |
+
return motion_emb
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
class MotionCLIP(nn.Module):
|
| 330 |
+
def __init__(self, in_dim):
|
| 331 |
+
super().__init__()
|
| 332 |
+
self.motion_encoder = MotionEncoder(in_dim, 512, 1024, 8, 8, 0.2, 'gelu')
|
| 333 |
+
clip_model, _ = clip.load("ViT-B/16", device="cpu", jit=False)
|
| 334 |
+
self.token_embedding = clip_model.token_embedding
|
| 335 |
+
self.positional_embedding = clip_model.positional_embedding
|
| 336 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
| 337 |
+
no_grad(self.token_embedding)
|
| 338 |
+
|
| 339 |
+
textTransEncoderLayer = nn.TransformerEncoderLayer(
|
| 340 |
+
d_model=512,
|
| 341 |
+
nhead=8,
|
| 342 |
+
dim_feedforward=1024,
|
| 343 |
+
dropout=0.2,
|
| 344 |
+
activation="gelu",)
|
| 345 |
+
self.textTransEncoder = nn.TransformerEncoder(
|
| 346 |
+
textTransEncoderLayer,
|
| 347 |
+
num_layers=8)
|
| 348 |
+
self.text_ln = nn.LayerNorm(512)
|
| 349 |
+
self.out = nn.Linear(512, 512)
|
| 350 |
+
|
| 351 |
+
def encode_motion(self, motion, m_lens):
|
| 352 |
+
seq_len = motion.shape[1]
|
| 353 |
+
padding_mask = ~lengths_to_mask(m_lens, seq_len)
|
| 354 |
+
motion_embedding = self.motion_encoder(motion, padding_mask.to(motion.device))
|
| 355 |
+
return motion_embedding
|
| 356 |
+
|
| 357 |
+
def encode_text(self, text):
|
| 358 |
+
device = next(self.parameters()).device
|
| 359 |
+
|
| 360 |
+
with torch.no_grad():
|
| 361 |
+
text = clip.tokenize(text, truncate=True).to(device)
|
| 362 |
+
x = self.token_embedding(text).float()
|
| 363 |
+
pe_tokens = x + self.positional_embedding.float()
|
| 364 |
+
pe_tokens = pe_tokens.permute(1,0,2)
|
| 365 |
+
out = self.textTransEncoder(pe_tokens)
|
| 366 |
+
out = out.permute(1, 0, 2)
|
| 367 |
+
out = self.text_ln(out)
|
| 368 |
+
|
| 369 |
+
out = out[torch.arange(x.shape[0]), text.argmax(dim=-1)]
|
| 370 |
+
out = self.out(out)
|
| 371 |
+
return out
|
| 372 |
+
|
| 373 |
+
def forward(self, motion, m_lens, text):
|
| 374 |
+
motion_features = self.encode_motion(motion, m_lens)
|
| 375 |
+
text_features = self.encode_text(text)
|
| 376 |
+
|
| 377 |
+
motion_features = motion_features / motion_features.norm(dim=1, keepdim=True)
|
| 378 |
+
text_features = text_features / text_features.norm(dim=1, keepdim=True)
|
| 379 |
+
|
| 380 |
+
logit_scale = self.logit_scale.exp()
|
| 381 |
+
logits_per_motion = logit_scale * motion_features @ text_features.t()
|
| 382 |
+
logits_per_text = logits_per_motion.t()
|
| 383 |
+
return logits_per_motion, logits_per_text
|
| 384 |
+
|
| 385 |
+
def forward_loss(self, motion, m_lens, text):
|
| 386 |
+
logits_per_motion, logits_per_text = self.forward(motion, m_lens, text)
|
| 387 |
+
labels = torch.arange(len(logits_per_motion)).to(logits_per_motion.device)
|
| 388 |
+
|
| 389 |
+
image_loss = F.cross_entropy(logits_per_motion, labels)
|
| 390 |
+
text_loss = F.cross_entropy(logits_per_text, labels)
|
| 391 |
+
loss = (image_loss + text_loss) / 2
|
| 392 |
+
return loss
|
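A hedged usage sketch of the Evaluators wrapper (assumptions: the pretrained evaluator checkpoints exist under checkpoints/t2m/text_mot_match*/model/finest.tar, and the placeholder tensors below mimic the shapes produced by the dataset collate used in the evaluation loop; random values are for illustration only):
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
evaluator = Evaluators(dataset_name='t2m', device=device)

bs, max_text_len, seq_len = 4, 20, 196
word_embeddings = torch.randn(bs, max_text_len, 300)               # placeholder GloVe word vectors
pos_one_hots = torch.randn(bs, max_text_len, len(POS_enumerator))  # placeholder for the 15-d POS one-hots
sent_len = torch.full((bs,), max_text_len, dtype=torch.long)
clip_text = ['a person walks forward'] * bs
motions = torch.randn(bs, seq_len, 67)                             # t2m pose dimension is 67
m_length = torch.full((bs,), seq_len, dtype=torch.long)

(et, em), (et_clip, em_clip) = evaluator.get_co_embeddings(
    word_embeddings, pos_one_hots, sent_len, clip_text, motions, m_length)
clip_scores = (em_clip * et_clip).sum(dim=-1)  # per-sample CLIP score, as accumulated in the loop above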
utils/glove.py
ADDED
|
@@ -0,0 +1,83 @@
|
| 1 |
+
import numpy as np
|
| 2 |
+
import pickle
|
| 3 |
+
from os.path import join as pjoin
|
| 4 |
+
|
| 5 |
+
#################################################################################
|
| 6 |
+
# GloVe #
|
| 7 |
+
#################################################################################
|
| 8 |
+
POS_enumerator = {
|
| 9 |
+
'VERB': 0,
|
| 10 |
+
'NOUN': 1,
|
| 11 |
+
'DET': 2,
|
| 12 |
+
'ADP': 3,
|
| 13 |
+
'NUM': 4,
|
| 14 |
+
'AUX': 5,
|
| 15 |
+
'PRON': 6,
|
| 16 |
+
'ADJ': 7,
|
| 17 |
+
'ADV': 8,
|
| 18 |
+
'Loc_VIP': 9,
|
| 19 |
+
'Body_VIP': 10,
|
| 20 |
+
'Obj_VIP': 11,
|
| 21 |
+
'Act_VIP': 12,
|
| 22 |
+
'Desc_VIP': 13,
|
| 23 |
+
'OTHER': 14,
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
Loc_list = ('left', 'right', 'clockwise', 'counterclockwise', 'anticlockwise', 'forward', 'back', 'backward',
|
| 27 |
+
'up', 'down', 'straight', 'curve')
|
| 28 |
+
|
| 29 |
+
Body_list = ('arm', 'chin', 'foot', 'feet', 'face', 'hand', 'mouth', 'leg', 'waist', 'eye', 'knee', 'shoulder', 'thigh')
|
| 30 |
+
|
| 31 |
+
Obj_List = ('stair', 'dumbbell', 'chair', 'window', 'floor', 'car', 'ball', 'handrail', 'baseball', 'basketball')
|
| 32 |
+
|
| 33 |
+
Act_list = ('walk', 'run', 'swing', 'pick', 'bring', 'kick', 'put', 'squat', 'throw', 'hop', 'dance', 'jump', 'turn',
|
| 34 |
+
'stumble', 'dance', 'stop', 'sit', 'lift', 'lower', 'raise', 'wash', 'stand', 'kneel', 'stroll',
|
| 35 |
+
'rub', 'bend', 'balance', 'flap', 'jog', 'shuffle', 'lean', 'rotate', 'spin', 'spread', 'climb')
|
| 36 |
+
|
| 37 |
+
Desc_list = ('slowly', 'carefully', 'fast', 'careful', 'slow', 'quickly', 'happy', 'angry', 'sad', 'happily',
|
| 38 |
+
'angrily', 'sadly')
|
| 39 |
+
|
| 40 |
+
VIP_dict = {
|
| 41 |
+
'Loc_VIP': Loc_list,
|
| 42 |
+
'Body_VIP': Body_list,
|
| 43 |
+
'Obj_VIP': Obj_List,
|
| 44 |
+
'Act_VIP': Act_list,
|
| 45 |
+
'Desc_VIP': Desc_list,
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class GloVe(object):
|
| 50 |
+
def __init__(self, meta_root, prefix):
|
| 51 |
+
vectors = np.load(pjoin(meta_root, '%s_data.npy'%prefix))
|
| 52 |
+
words = pickle.load(open(pjoin(meta_root, '%s_words.pkl'%prefix), 'rb'))
|
| 53 |
+
self.word2idx = pickle.load(open(pjoin(meta_root, '%s_idx.pkl'%prefix), 'rb'))
|
| 54 |
+
self.word2vec = {w: vectors[self.word2idx[w]] for w in words}
|
| 55 |
+
|
| 56 |
+
def _get_pos_ohot(self, pos):
|
| 57 |
+
pos_vec = np.zeros(len(POS_enumerator))
|
| 58 |
+
if pos in POS_enumerator:
|
| 59 |
+
pos_vec[POS_enumerator[pos]] = 1
|
| 60 |
+
else:
|
| 61 |
+
pos_vec[POS_enumerator['OTHER']] = 1
|
| 62 |
+
return pos_vec
|
| 63 |
+
|
| 64 |
+
def __len__(self):
|
| 65 |
+
return len(self.word2vec)
|
| 66 |
+
|
| 67 |
+
def __getitem__(self, item):
|
| 68 |
+
word, pos = item.split('/')
|
| 69 |
+
if word in self.word2vec:
|
| 70 |
+
word_vec = self.word2vec[word]
|
| 71 |
+
vip_pos = None
|
| 72 |
+
for key, values in VIP_dict.items():
|
| 73 |
+
if word in values:
|
| 74 |
+
vip_pos = key
|
| 75 |
+
break
|
| 76 |
+
if vip_pos is not None:
|
| 77 |
+
pos_vec = self._get_pos_ohot(vip_pos)
|
| 78 |
+
else:
|
| 79 |
+
pos_vec = self._get_pos_ohot(pos)
|
| 80 |
+
else:
|
| 81 |
+
word_vec = self.word2vec['unk']
|
| 82 |
+
pos_vec = self._get_pos_ohot('OTHER')
|
| 83 |
+
return word_vec, pos_vec
|
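A hedged lookup example (the meta_root and prefix below are assumptions about where the GloVe files live; the dataset loaders pass tokens formatted as "word/POS"):
glove = GloVe(meta_root='glove', prefix='our_vab')  # expects our_vab_data.npy / our_vab_words.pkl / our_vab_idx.pkl
word_vec, pos_vec = glove['walk/VERB']              # 300-d word vector plus 15-d POS one-hot ('walk' maps to Act_VIP)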
utils/mesh_mean_std/t2m/mesh_mean.npy
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:be690561cf95997a0a69fc8b00440d0e1af2eb5b0ac1c4f3c5a45a5b0666f3bf
|
| 3 |
+
size 140
|
utils/mesh_mean_std/t2m/mesh_std.npy
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:19f56c7aaa752639b39eb76742b5ee1b79fb71931b8412dd84ca7312398293b4
|
| 3 |
+
size 140
|
utils/motion_process.py
ADDED
|
@@ -0,0 +1,429 @@
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import matplotlib
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
from matplotlib.animation import FuncAnimation, writers
|
| 6 |
+
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
|
| 7 |
+
import mpl_toolkits.mplot3d.axes3d as p3
|
| 8 |
+
import os
|
| 9 |
+
import io
|
| 10 |
+
try:
|
| 11 |
+
from PIL import Image
|
| 12 |
+
except ImportError:
|
| 13 |
+
Image = None
|
| 14 |
+
try:
|
| 15 |
+
import imageio
|
| 16 |
+
except ImportError:
|
| 17 |
+
imageio = None
|
| 18 |
+
|
| 19 |
+
#################################################################################
|
| 20 |
+
# Data Params #
|
| 21 |
+
#################################################################################
|
| 22 |
+
kit_kinematic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]]
|
| 23 |
+
t2m_kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]]
|
| 24 |
+
t2m_left_hand_chain = [[20, 22, 23, 24], [20, 34, 35, 36], [20, 25, 26, 27], [20, 31, 32, 33], [20, 28, 29, 30]]
|
| 25 |
+
t2m_right_hand_chain = [[21, 43, 44, 45], [21, 46, 47, 48], [21, 40, 41, 42], [21, 37, 38, 39], [21, 49, 50, 51]]
|
| 26 |
+
|
| 27 |
+
kit_raw_offsets = np.array(
|
| 28 |
+
[[0, 0, 0], [0, 1, 0], [0, 1, 0], [0, 1, 0], [0, 1, 0],
|
| 29 |
+
[1, 0, 0], [0, -1, 0], [0, -1, 0], [-1, 0, 0], [0, -1, 0],
|
| 30 |
+
[0, -1, 0], [1, 0, 0], [0, -1, 0], [0, -1, 0], [0, 0, 1],
|
| 31 |
+
[0, 0, 1], [-1, 0, 0], [0, -1, 0], [0, -1, 0], [0, 0, 1],
|
| 32 |
+
[0, 0, 1]])
|
| 33 |
+
t2m_raw_offsets = np.array([[0,0,0], [1,0,0], [-1,0,0], [0,1,0], [0,-1,0],
|
| 34 |
+
[0,-1,0], [0,1,0], [0,-1,0], [0,-1,0], [0,1,0],
|
| 35 |
+
[0,0,1], [0,0,1], [0,1,0], [1,0,0], [-1,0,0],
|
| 36 |
+
[0,0,1], [0,-1,0], [0,-1,0], [0,-1,0], [0,-1,0],
|
| 37 |
+
[0,-1,0], [0,-1,0]])
|
| 38 |
+
|
| 39 |
+
#################################################################################
|
| 40 |
+
# Joints Revert #
|
| 41 |
+
#################################################################################
|
| 42 |
+
def qinv(q):
|
| 43 |
+
assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
|
| 44 |
+
mask = torch.ones_like(q)
|
| 45 |
+
mask[..., 1:] = -mask[..., 1:]
|
| 46 |
+
return q * mask
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def qrot(q, v):
|
| 50 |
+
"""
|
| 51 |
+
Rotate vector(s) v about the rotation described by quaternion(s) q.
|
| 52 |
+
Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
|
| 53 |
+
where * denotes any number of dimensions.
|
| 54 |
+
Returns a tensor of shape (*, 3).
|
| 55 |
+
"""
|
| 56 |
+
assert q.shape[-1] == 4
|
| 57 |
+
assert v.shape[-1] == 3
|
| 58 |
+
assert q.shape[:-1] == v.shape[:-1]
|
| 59 |
+
|
| 60 |
+
original_shape = list(v.shape)
|
| 61 |
+
# print(q.shape)
|
| 62 |
+
q = q.contiguous().view(-1, 4)
|
| 63 |
+
v = v.contiguous().view(-1, 3)
|
| 64 |
+
|
| 65 |
+
qvec = q[:, 1:]
|
| 66 |
+
uv = torch.cross(qvec, v, dim=1)
|
| 67 |
+
uuv = torch.cross(qvec, uv, dim=1)
|
| 68 |
+
return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def recover_root_rot_pos(data):
|
| 72 |
+
rot_vel = data[..., 0]
|
| 73 |
+
r_rot_ang = torch.zeros_like(rot_vel).to(data.device)
|
| 74 |
+
'''Get Y-axis rotation from rotation velocity'''
|
| 75 |
+
r_rot_ang[..., 1:] = rot_vel[..., :-1]
|
| 76 |
+
r_rot_ang = torch.cumsum(r_rot_ang, dim=-1)
|
| 77 |
+
|
| 78 |
+
r_rot_quat = torch.zeros(data.shape[:-1] + (4,)).to(data.device)
|
| 79 |
+
r_rot_quat[..., 0] = torch.cos(r_rot_ang)
|
| 80 |
+
r_rot_quat[..., 2] = torch.sin(r_rot_ang)
|
| 81 |
+
|
| 82 |
+
r_pos = torch.zeros(data.shape[:-1] + (3,)).to(data.device)
|
| 83 |
+
r_pos[..., 1:, [0, 2]] = data[..., :-1, 1:3]
|
| 84 |
+
'''Add Y-axis rotation to root position'''
|
| 85 |
+
r_pos = qrot(qinv(r_rot_quat), r_pos)
|
| 86 |
+
|
| 87 |
+
r_pos = torch.cumsum(r_pos, dim=-2)
|
| 88 |
+
|
| 89 |
+
r_pos[..., 1] = data[..., 3]
|
| 90 |
+
return r_rot_quat, r_pos
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def recover_from_ric(data, joints_num):
|
| 94 |
+
r_rot_quat, r_pos = recover_root_rot_pos(data)
|
| 95 |
+
positions = data[..., 4:(joints_num - 1) * 3 + 4]
|
| 96 |
+
positions = positions.view(positions.shape[:-1] + (-1, 3))
|
| 97 |
+
|
| 98 |
+
'''Add Y-axis rotation to local joints'''
|
| 99 |
+
positions = qrot(qinv(r_rot_quat[..., None, :]).expand(positions.shape[:-1] + (4,)), positions)
|
| 100 |
+
|
| 101 |
+
'''Add root XZ to joints'''
|
| 102 |
+
positions[..., 0] += r_pos[..., 0:1]
|
| 103 |
+
positions[..., 2] += r_pos[..., 2:3]
|
| 104 |
+
|
| 105 |
+
'''Concate root and joints'''
|
| 106 |
+
positions = torch.cat([r_pos.unsqueeze(-2), positions], dim=-2)
|
| 107 |
+
|
| 108 |
+
return positions
|
| 109 |
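A minimal sketch of recovering joints from features (assumption: features is a de-normalized HumanML3D-style feature tensor of shape (seq_len, 263) for the 22-joint skeleton; random values are used only as a placeholder):
import torch

features = torch.randn(60, 263)           # placeholder feature sequence
joints = recover_from_ric(features, 22)   # -> (60, 22, 3) global joint positions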
+
|
| 110 |
+
#################################################################################
|
| 111 |
+
# Motion Plotting #
|
| 112 |
+
#################################################################################
|
| 113 |
+
def plot_3d_motion(save_path, kinematic_tree, joints, title, figsize=(10, 10), fps=120, radius=4, save_frames_dir=None):
|
| 114 |
+
# Ensure Agg backend is used (already set at module level, but ensure it's active)
|
| 115 |
+
if matplotlib.get_backend() != 'Agg':
|
| 116 |
+
matplotlib.use('Agg')
|
| 117 |
+
|
| 118 |
+
title_sp = title.split(' ')
|
| 119 |
+
if len(title_sp) > 20:
|
| 120 |
+
title = '\n'.join([' '.join(title_sp[:10]), ' '.join(title_sp[10:20]), ' '.join(title_sp[20:])])
|
| 121 |
+
elif len(title_sp) > 10:
|
| 122 |
+
title = '\n'.join([' '.join(title_sp[:10]), ' '.join(title_sp[10:])])
|
| 123 |
+
|
| 124 |
+
def init():
|
| 125 |
+
ax.set_xlim3d([-radius / 2, radius / 2])
|
| 126 |
+
ax.set_ylim3d([0, radius])
|
| 127 |
+
ax.set_zlim3d([0, radius])
|
| 128 |
+
fig.suptitle(title, fontsize=20)
|
| 129 |
+
ax.grid(b=False)
|
| 130 |
+
|
| 131 |
+
def plot_xzPlane(minx, maxx, miny, minz, maxz):
|
| 132 |
+
verts = [
|
| 133 |
+
[minx, miny, minz],
|
| 134 |
+
[minx, miny, maxz],
|
| 135 |
+
[maxx, miny, maxz],
|
| 136 |
+
[maxx, miny, minz]
|
| 137 |
+
]
|
| 138 |
+
xz_plane = Poly3DCollection([verts])
|
| 139 |
+
xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5))
|
| 140 |
+
ax.add_collection3d(xz_plane)
|
| 141 |
+
|
| 142 |
+
# Ensure joints is in the correct format: (seq_len, num_joints, 3)
|
| 143 |
+
joints = np.array(joints)
|
| 144 |
+
if joints.ndim == 3:
|
| 145 |
+
# Already in (seq_len, num_joints, 3) format
|
| 146 |
+
data = joints.copy()
|
| 147 |
+
elif joints.ndim == 2:
|
| 148 |
+
# If 2D, reshape to (seq_len, num_joints, 3)
|
| 149 |
+
# Assume it's (seq_len * num_joints, 3) or (seq_len, num_joints * 3)
|
| 150 |
+
if joints.shape[1] == 3:
|
| 151 |
+
# (seq_len * num_joints, 3) - need to infer num_joints
|
| 152 |
+
# For t2m, we expect 22 joints
|
| 153 |
+
num_joints = 22
|
| 154 |
+
seq_len = joints.shape[0] // num_joints
|
| 155 |
+
data = joints.reshape(seq_len, num_joints, 3)
|
| 156 |
+
else:
|
| 157 |
+
# (seq_len, num_joints * 3)
|
| 158 |
+
num_joints = joints.shape[1] // 3
|
| 159 |
+
data = joints.reshape(len(joints), num_joints, 3)
|
| 160 |
+
else:
|
| 161 |
+
raise ValueError(f"Invalid joints shape: {joints.shape}, expected (seq_len, num_joints, 3)")
|
| 162 |
+
|
| 163 |
+
# Check if data is valid
|
| 164 |
+
if data.size == 0:
|
| 165 |
+
raise ValueError("Invalid motion data: data is empty")
|
| 166 |
+
|
| 167 |
+
fig = plt.figure(figsize=figsize)
|
| 168 |
+
ax = p3.Axes3D(fig)
|
| 169 |
+
init()
|
| 170 |
+
MINS = data.min(axis=0).min(axis=0)
|
| 171 |
+
MAXS = data.max(axis=0).max(axis=0)
|
| 172 |
+
colors = ['red', 'blue', 'black', 'red', 'blue',
|
| 173 |
+
'darkblue', 'darkblue', 'darkblue', 'darkblue', 'darkblue',
|
| 174 |
+
'darkred', 'darkred', 'darkred', 'darkred', 'darkred']
|
| 175 |
+
frame_number = data.shape[0]
|
| 176 |
+
|
| 177 |
+
height_offset = MINS[1]
|
| 178 |
+
data[:, :, 1] -= height_offset
|
| 179 |
+
trajec = data[:, 0, [0, 2]]
|
| 180 |
+
|
| 181 |
+
data[..., 0] -= data[:, 0:1, 0]
|
| 182 |
+
data[..., 2] -= data[:, 0:1, 2]
|
| 183 |
+
|
| 184 |
+
# Recompute bounds after centering
|
| 185 |
+
MINS = data.min(axis=0).min(axis=0)
|
| 186 |
+
MAXS = data.max(axis=0).max(axis=0)
|
| 187 |
+
# Add some padding
|
| 188 |
+
center = (MINS + MAXS) / 2
|
| 189 |
+
ranges = MAXS - MINS
|
| 190 |
+
# Ensure we have a minimum range to avoid issues with very small or zero ranges
|
| 191 |
+
min_range = 0.1 # Minimum range for each axis
|
| 192 |
+
ranges = np.maximum(ranges, min_range)
|
| 193 |
+
max_range = max(ranges) * 1.2 # 20% padding
|
| 194 |
+
plot_mins = center - max_range / 2
|
| 195 |
+
plot_maxs = center + max_range / 2
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def update(index):
|
| 199 |
+
# Clear axes properly
|
| 200 |
+
ax.cla()
|
| 201 |
+
# Reapply title
|
| 202 |
+
fig.suptitle(title, fontsize=20)
|
| 203 |
+
# Reapply view settings and limits based on actual data bounds
|
| 204 |
+
ax.set_xlim3d([plot_mins[0], plot_maxs[0]])
|
| 205 |
+
ax.set_ylim3d([plot_mins[1], plot_maxs[1]])
|
| 206 |
+
ax.set_zlim3d([plot_mins[2], plot_maxs[2]])
|
| 207 |
+
ax.view_init(elev=120, azim=-90)
|
| 208 |
+
ax.dist = 7.5
|
| 209 |
+
ax.grid(False)
|
| 210 |
+
plot_xzPlane(plot_mins[0] - trajec[index, 0], plot_maxs[0] - trajec[index, 0], 0,
|
| 211 |
+
plot_mins[2] - trajec[index, 1], plot_maxs[2] - trajec[index, 1])
|
| 212 |
+
|
| 213 |
+
if index > 1:
|
| 214 |
+
ax.plot3D(trajec[:index, 0] - trajec[index, 0], np.zeros_like(trajec[:index, 0]),
|
| 215 |
+
trajec[:index, 1] - trajec[index, 1], linewidth=1.0,
|
| 216 |
+
color='blue')
|
| 217 |
+
|
| 218 |
+
for i, (chain, color) in enumerate(zip(kinematic_tree, colors)):
|
| 219 |
+
if i < len(colors):
|
| 220 |
+
if i < 5:
|
| 221 |
+
linewidth = 4.0
|
| 222 |
+
else:
|
| 223 |
+
linewidth = 2.0
|
| 224 |
+
# Ensure chain indices are valid
|
| 225 |
+
valid_chain = [idx for idx in chain if idx < data.shape[1]]
|
| 226 |
+
if len(valid_chain) > 1:
|
| 227 |
+
ax.plot3D(data[index, valid_chain, 0], data[index, valid_chain, 1], data[index, valid_chain, 2],
|
| 228 |
+
linewidth=linewidth, color=color)
|
| 229 |
+
|
| 230 |
+
plt.axis('off')
|
| 231 |
+
ax.set_xticklabels([])
|
| 232 |
+
ax.set_yticklabels([])
|
| 233 |
+
ax.set_zticklabels([])
|
| 234 |
+
|
| 235 |
+
def render_frame(frame_idx):
|
| 236 |
+
"""Render a single frame and return as PIL Image"""
|
| 237 |
+
# Create a fresh figure for each frame to avoid 3D projection issues
|
| 238 |
+
frame_fig = plt.figure(figsize=figsize, facecolor='white')
|
| 239 |
+
frame_ax = frame_fig.add_subplot(111, projection='3d')
|
| 240 |
+
frame_fig.suptitle(title, fontsize=20)
|
| 241 |
+
|
| 242 |
+
# Set limits and view
|
| 243 |
+
frame_ax.set_xlim3d([plot_mins[0], plot_maxs[0]])
|
| 244 |
+
frame_ax.set_ylim3d([plot_mins[1], plot_maxs[1]])
|
| 245 |
+
frame_ax.set_zlim3d([plot_mins[2], plot_maxs[2]])
|
| 246 |
+
frame_ax.view_init(elev=120, azim=-90)
|
| 247 |
+
frame_ax.dist = 7.5
|
| 248 |
+
frame_ax.grid(False)
|
| 249 |
+
|
| 250 |
+
# Plot ground plane
|
| 251 |
+
minx = plot_mins[0] - trajec[frame_idx, 0]
|
| 252 |
+
maxx = plot_maxs[0] - trajec[frame_idx, 0]
|
| 253 |
+
minz = plot_mins[2] - trajec[frame_idx, 1]
|
| 254 |
+
maxz = plot_maxs[2] - trajec[frame_idx, 1]
|
| 255 |
+
verts = [[minx, 0, minz], [minx, 0, maxz], [maxx, 0, maxz], [maxx, 0, minz]]
|
| 256 |
+
xz_plane = Poly3DCollection([verts])
|
| 257 |
+
xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5))
|
| 258 |
+
frame_ax.add_collection3d(xz_plane)
|
| 259 |
+
|
| 260 |
+
# Plot trajectory
|
| 261 |
+
if frame_idx > 1:
|
| 262 |
+
frame_ax.plot3D(trajec[:frame_idx, 0] - trajec[frame_idx, 0], np.zeros_like(trajec[:frame_idx, 0]),
|
| 263 |
+
trajec[:frame_idx, 1] - trajec[frame_idx, 1], linewidth=1.0,
|
| 264 |
+
color='blue')
|
| 265 |
+
|
| 266 |
+
# Plot skeleton
|
| 267 |
+
for i, (chain, color) in enumerate(zip(kinematic_tree, colors)):
|
| 268 |
+
if i < len(colors):
|
| 269 |
+
if i < 5:
|
| 270 |
+
linewidth = 4.0
|
| 271 |
+
else:
|
| 272 |
+
linewidth = 2.0
|
| 273 |
+
valid_chain = [idx for idx in chain if idx < data.shape[1]]
|
| 274 |
+
if len(valid_chain) > 1:
|
| 275 |
+
frame_ax.plot3D(data[frame_idx, valid_chain, 0], data[frame_idx, valid_chain, 1],
|
| 276 |
+
data[frame_idx, valid_chain, 2], linewidth=linewidth, color=color)
|
| 277 |
+
|
| 278 |
+
plt.axis('off')
|
| 279 |
+
frame_ax.set_xticklabels([])
|
| 280 |
+
frame_ax.set_yticklabels([])
|
| 281 |
+
frame_ax.set_zticklabels([])
|
| 282 |
+
|
| 283 |
+
# Convert to image - copy data so it's independent of the buffer
|
| 284 |
+
buf = io.BytesIO()
|
| 285 |
+
frame_fig.savefig(buf, format='png', dpi=100, bbox_inches='tight', facecolor='white', edgecolor='none')
|
| 286 |
+
buf.seek(0)
|
| 287 |
+
# Create image and convert to RGB for GIF compatibility
|
| 288 |
+
img = Image.open(buf)
|
| 289 |
+
if img.mode != 'RGB':
|
| 290 |
+
img = img.convert('RGB')
|
| 291 |
+
img_copy = img.copy() # Create a copy that's independent of the buffer
|
| 292 |
+
buf.close()
|
| 293 |
+
plt.close(frame_fig)
|
| 294 |
+
return img_copy
|
| 295 |
+
|
| 296 |
+
# Always use frame-by-frame rendering for reliability (works with both GIF and MP4)
|
| 297 |
+
# This ensures 3D plots render correctly
|
| 298 |
+
actual_path = save_path
|
| 299 |
+
|
| 300 |
+
# Note: We'll use frame-by-frame rendering for both MP4 and GIF
|
| 301 |
+
# imageio-ffmpeg will be used for MP4 if available
|
| 302 |
+
|
| 303 |
+
# Use frame-by-frame rendering (works reliably for GIF)
|
| 304 |
+
if Image is None:
|
| 305 |
+
raise RuntimeError("PIL/Pillow is required for GIF generation. Please install: pip install Pillow")
|
| 306 |
+
|
| 307 |
+
# Use provided frames directory or create one for debugging
|
| 308 |
+
frames_dir = save_frames_dir
|
| 309 |
+
if frames_dir is not None:
|
| 310 |
+
os.makedirs(frames_dir, exist_ok=True)
|
| 311 |
+
print(f"Saving individual frames to: {frames_dir}")
|
| 312 |
+
|
| 313 |
+
frames = []
|
| 314 |
+
print(f"Rendering {frame_number} frames...")
|
| 315 |
+
for i in range(frame_number):
|
| 316 |
+
if (i + 1) % 20 == 0:
|
| 317 |
+
print(f" Frame {i+1}/{frame_number}")
|
| 318 |
+
frame_img = render_frame(i)
|
| 319 |
+
frames.append(frame_img)
|
| 320 |
+
|
| 321 |
+
# Save individual frame as PNG for debugging
|
| 322 |
+
if frames_dir is not None:
|
| 323 |
+
frame_path = os.path.join(frames_dir, f"frame_{i:04d}.png")
|
| 324 |
+
frame_img.save(frame_path)
|
| 325 |
+
if i == 0 or i == frame_number - 1:
|
| 326 |
+
print(f" Saved frame {i} to {frame_path}")
|
| 327 |
+
|
| 328 |
+
# Save video - prefer MP4 if imageio-ffmpeg is available, otherwise GIF
|
| 329 |
+
if len(frames) > 0:
|
| 330 |
+
# Ensure all frames are in the same mode and size
|
| 331 |
+
frames_rgb = []
|
| 332 |
+
first_frame = frames[0]
|
| 333 |
+
if first_frame.mode != 'RGB':
|
| 334 |
+
first_frame = first_frame.convert('RGB')
|
| 335 |
+
|
| 336 |
+
# Get size from first frame
|
| 337 |
+
target_size = first_frame.size
|
| 338 |
+
frames_rgb.append(first_frame)
|
| 339 |
+
|
| 340 |
+
# Convert and resize all other frames to match
|
| 341 |
+
for frame in frames[1:]:
|
| 342 |
+
if frame.mode != 'RGB':
|
| 343 |
+
frame = frame.convert('RGB')
|
| 344 |
+
# Ensure all frames are the same size
|
| 345 |
+
if frame.size != target_size:
|
| 346 |
+
frame = frame.resize(target_size, Image.Resampling.LANCZOS)
|
| 347 |
+
frames_rgb.append(frame)
|
| 348 |
+
|
| 349 |
+
# Convert PIL Images to numpy arrays for imageio
|
| 350 |
+
frame_arrays = [np.array(frame) for frame in frames_rgb]
|
| 351 |
+
|
| 352 |
+
# Try to save as MP4 first (better quality and compatibility)
|
| 353 |
+
if actual_path.endswith('.mp4'):
|
| 354 |
+
if imageio is not None:
|
| 355 |
+
try:
|
| 356 |
+
# Use imageio-ffmpeg for MP4 (automatically uses ffmpeg if available)
|
| 357 |
+
imageio.mimsave(actual_path, frame_arrays, fps=fps, codec='libx264', quality=8)
|
| 358 |
+
print(f"Saved {len(frames_rgb)} frames to MP4 using imageio-ffmpeg: {actual_path}")
|
| 359 |
+
except Exception as e:
|
| 360 |
+
print(f"Error saving MP4 with imageio-ffmpeg: {e}")
|
| 361 |
+
print("Falling back to GIF format...")
|
| 362 |
+
# Fall back to GIF
|
| 363 |
+
base_path = os.path.splitext(actual_path)[0]
|
| 364 |
+
actual_path = base_path + '.gif'
|
| 365 |
+
if imageio is not None:
|
| 366 |
+
try:
|
| 367 |
+
imageio.mimsave(actual_path, frame_arrays, duration=1.0/fps, loop=0)
|
| 368 |
+
print(f"Saved {len(frames_rgb)} frames to GIF using imageio: {actual_path}")
|
| 369 |
+
except Exception as e2:
|
| 370 |
+
print(f"imageio GIF failed, using PIL: {e2}")
|
| 371 |
+
frames_rgb[0].save(
|
| 372 |
+
actual_path,
|
| 373 |
+
save_all=True,
|
| 374 |
+
append_images=frames_rgb[1:],
|
| 375 |
+
duration=int(1000 / fps),
|
| 376 |
+
loop=0,
|
| 377 |
+
optimize=False
|
| 378 |
+
)
|
| 379 |
+
else:
|
| 380 |
+
frames_rgb[0].save(
|
| 381 |
+
actual_path,
|
| 382 |
+
save_all=True,
|
| 383 |
+
append_images=frames_rgb[1:],
|
| 384 |
+
duration=int(1000 / fps),
|
| 385 |
+
loop=0,
|
| 386 |
+
optimize=False
|
| 387 |
+
)
|
| 388 |
+
else:
|
| 389 |
+
print("imageio not available, cannot save MP4. Falling back to GIF...")
|
| 390 |
+
base_path = os.path.splitext(actual_path)[0]
|
| 391 |
+
actual_path = base_path + '.gif'
|
| 392 |
+
frames_rgb[0].save(
|
| 393 |
+
actual_path,
|
| 394 |
+
save_all=True,
|
| 395 |
+
append_images=frames_rgb[1:],
|
| 396 |
+
duration=int(1000 / fps),
|
| 397 |
+
loop=0,
|
| 398 |
+
optimize=False
|
| 399 |
+
)
|
| 400 |
+
elif actual_path.endswith('.gif'):
|
| 401 |
+
# Save as GIF
|
| 402 |
+
if imageio is not None:
|
| 403 |
+
try:
|
| 404 |
+
imageio.mimsave(actual_path, frame_arrays, duration=1.0/fps, loop=0)
|
| 405 |
+
print(f"Saved {len(frames_rgb)} frames to GIF using imageio: {actual_path}")
|
| 406 |
+
except Exception as e:
|
| 407 |
+
print(f"imageio failed, using PIL: {e}")
|
| 408 |
+
frames_rgb[0].save(
|
| 409 |
+
actual_path,
|
| 410 |
+
save_all=True,
|
| 411 |
+
append_images=frames_rgb[1:],
|
| 412 |
+
duration=int(1000 / fps),
|
| 413 |
+
loop=0,
|
| 414 |
+
optimize=False
|
| 415 |
+
)
|
| 416 |
+
else:
|
| 417 |
+
frames_rgb[0].save(
|
| 418 |
+
actual_path,
|
| 419 |
+
save_all=True,
|
| 420 |
+
append_images=frames_rgb[1:],
|
| 421 |
+
duration=int(1000 / fps),
|
| 422 |
+
loop=0,
|
| 423 |
+
optimize=False
|
| 424 |
+
)
|
| 425 |
+
|
| 426 |
+
if frames_dir is not None:
|
| 427 |
+
print(f"Saved {len(frames)} individual frames to {frames_dir}")
|
| 428 |
+
|
| 429 |
+
return actual_path
|
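For reference, a minimal standalone sketch of the frame-by-frame PIL GIF fallback used above; the figure content, frame count, and output path here are placeholders, not part of the repository:

```python
# Hypothetical smoke test of the PIL-based GIF fallback: render a few 3D frames
# off-screen and stitch them the same way the code above does.
import io
import numpy as np
import matplotlib
matplotlib.use('Agg')  # headless backend, as on a Space
import matplotlib.pyplot as plt
from PIL import Image

frames = []
for t in range(10):
    fig = plt.figure(figsize=(4, 4), facecolor='white')
    ax = fig.add_subplot(111, projection='3d')
    theta = np.linspace(0, 2 * np.pi, 60) + 0.3 * t
    ax.plot3D(np.cos(theta), np.sin(theta), np.zeros_like(theta), color='blue')
    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=100, facecolor='white')
    buf.seek(0)
    frames.append(Image.open(buf).convert('RGB').copy())  # copy so the buffer can be closed
    buf.close()
    plt.close(fig)

fps = 20
frames[0].save('demo.gif', save_all=True, append_images=frames[1:],
               duration=int(1000 / fps), loop=0, optimize=False)
```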
utils/quaternion.py
ADDED
|
@@ -0,0 +1,519 @@
|
| 1 |
+
# Copyright (c) 2018-present, Facebook, Inc.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
_EPS4 = np.finfo(float).eps * 4.0
|
| 12 |
+
|
| 13 |
+
_FLOAT_EPS = np.finfo(np.float64).eps
|
| 14 |
+
|
| 15 |
+
# PyTorch-backed implementations
|
| 16 |
+
def qinv(q):
|
| 17 |
+
assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
|
| 18 |
+
mask = torch.ones_like(q)
|
| 19 |
+
mask[..., 1:] = -mask[..., 1:]
|
| 20 |
+
return q * mask
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def qinv_np(q):
|
| 24 |
+
assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
|
| 25 |
+
return qinv(torch.from_numpy(q).float()).numpy()
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def qnormalize(q):
|
| 29 |
+
assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
|
| 30 |
+
return q / torch.norm(q, dim=-1, keepdim=True)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def qmul(q, r):
|
| 34 |
+
"""
|
| 35 |
+
Multiply quaternion(s) q with quaternion(s) r.
|
| 36 |
+
Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions.
|
| 37 |
+
Returns q*r as a tensor of shape (*, 4).
|
| 38 |
+
"""
|
| 39 |
+
assert q.shape[-1] == 4
|
| 40 |
+
assert r.shape[-1] == 4
|
| 41 |
+
|
| 42 |
+
original_shape = q.shape
|
| 43 |
+
|
| 44 |
+
# Compute outer product
|
| 45 |
+
terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4))
|
| 46 |
+
|
| 47 |
+
w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3]
|
| 48 |
+
x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2]
|
| 49 |
+
y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1]
|
| 50 |
+
z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0]
|
| 51 |
+
return torch.stack((w, x, y, z), dim=1).view(original_shape)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def qrot(q, v):
|
| 55 |
+
"""
|
| 56 |
+
Rotate vector(s) v about the rotation described by quaternion(s) q.
|
| 57 |
+
Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
|
| 58 |
+
where * denotes any number of dimensions.
|
| 59 |
+
Returns a tensor of shape (*, 3).
|
| 60 |
+
"""
|
| 61 |
+
assert q.shape[-1] == 4
|
| 62 |
+
assert v.shape[-1] == 3
|
| 63 |
+
assert q.shape[:-1] == v.shape[:-1]
|
| 64 |
+
|
| 65 |
+
original_shape = list(v.shape)
|
| 66 |
+
# print(q.shape)
|
| 67 |
+
q = q.contiguous().view(-1, 4)
|
| 68 |
+
v = v.contiguous().view(-1, 3)
|
| 69 |
+
|
| 70 |
+
qvec = q[:, 1:]
|
| 71 |
+
uv = torch.cross(qvec, v, dim=1)
|
| 72 |
+
uuv = torch.cross(qvec, uv, dim=1)
|
| 73 |
+
return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def qeuler(q, order, epsilon=0, deg=True, follow_order=True):
|
| 77 |
+
"""
|
| 78 |
+
Convert quaternion(s) q to Euler angles.
|
| 79 |
+
Expects a tensor of shape (*, 4), where * denotes any number of dimensions.
|
| 80 |
+
Returns a tensor of shape (*, 3).
|
| 81 |
+
"""
|
| 82 |
+
assert q.shape[-1] == 4
|
| 83 |
+
|
| 84 |
+
original_shape = list(q.shape)
|
| 85 |
+
original_shape[-1] = 3
|
| 86 |
+
q = q.view(-1, 4)
|
| 87 |
+
|
| 88 |
+
q0 = q[:, 0]
|
| 89 |
+
q1 = q[:, 1]
|
| 90 |
+
q2 = q[:, 2]
|
| 91 |
+
q3 = q[:, 3]
|
| 92 |
+
|
| 93 |
+
if order == 'xyz':
|
| 94 |
+
x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
|
| 95 |
+
y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon))
|
| 96 |
+
z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
|
| 97 |
+
elif order == 'yzx':
|
| 98 |
+
x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
|
| 99 |
+
y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
|
| 100 |
+
z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon))
|
| 101 |
+
elif order == 'zxy':
|
| 102 |
+
x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon))
|
| 103 |
+
y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
|
| 104 |
+
z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3))
|
| 105 |
+
elif order == 'xzy':
|
| 106 |
+
x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
|
| 107 |
+
y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
|
| 108 |
+
z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon))
|
| 109 |
+
elif order == 'yxz':
|
| 110 |
+
x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon))
|
| 111 |
+
y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2))
|
| 112 |
+
z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
|
| 113 |
+
elif order == 'zyx':
|
| 114 |
+
x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
|
| 115 |
+
y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon))
|
| 116 |
+
z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
|
| 117 |
+
else:
|
| 118 |
+
raise ValueError(f"Unknown Euler order: {order}")
|
| 119 |
+
resdict = {"x":x, "y":y, "z":z}
|
| 120 |
+
|
| 121 |
+
# print(order)
|
| 122 |
+
reslist = [resdict[order[i]] for i in range(len(order))] if follow_order else [x, y, z]
|
| 123 |
+
# print(reslist)
|
| 124 |
+
if deg:
|
| 125 |
+
return torch.stack(reslist, dim=1).view(original_shape) * 180 / np.pi
|
| 126 |
+
else:
|
| 127 |
+
return torch.stack(reslist, dim=1).view(original_shape)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
# Numpy-backed implementations
|
| 131 |
+
|
| 132 |
+
def qmul_np(q, r):
|
| 133 |
+
q = torch.from_numpy(q).contiguous().float()
|
| 134 |
+
r = torch.from_numpy(r).contiguous().float()
|
| 135 |
+
return qmul(q, r).numpy()
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def qrot_np(q, v):
|
| 139 |
+
q = torch.from_numpy(q).contiguous().float()
|
| 140 |
+
v = torch.from_numpy(v).contiguous().float()
|
| 141 |
+
return qrot(q, v).numpy()
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def qeuler_np(q, order, epsilon=0, use_gpu=False):
|
| 145 |
+
if use_gpu:
|
| 146 |
+
q = torch.from_numpy(q).cuda().float()
|
| 147 |
+
return qeuler(q, order, epsilon).cpu().numpy()
|
| 148 |
+
else:
|
| 149 |
+
q = torch.from_numpy(q).contiguous().float()
|
| 150 |
+
return qeuler(q, order, epsilon).numpy()
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def qfix(q):
|
| 154 |
+
"""
|
| 155 |
+
Enforce quaternion continuity across the time dimension by selecting
|
| 156 |
+
the representation (q or -q) with minimal distance (or, equivalently, maximal dot product)
|
| 157 |
+
between two consecutive frames.
|
| 158 |
+
|
| 159 |
+
Expects a tensor of shape (L, J, 4), where L is the sequence length and J is the number of joints.
|
| 160 |
+
Returns a tensor of the same shape.
|
| 161 |
+
"""
|
| 162 |
+
assert len(q.shape) == 3
|
| 163 |
+
assert q.shape[-1] == 4
|
| 164 |
+
|
| 165 |
+
result = q.copy()
|
| 166 |
+
dot_products = np.sum(q[1:] * q[:-1], axis=2)
|
| 167 |
+
mask = dot_products < 0
|
| 168 |
+
mask = (np.cumsum(mask, axis=0) % 2).astype(bool)
|
| 169 |
+
result[1:][mask] *= -1
|
| 170 |
+
return result
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def euler2quat(e, order, deg=True):
|
| 174 |
+
"""
|
| 175 |
+
Convert Euler angles to quaternions.
|
| 176 |
+
"""
|
| 177 |
+
assert e.shape[-1] == 3
|
| 178 |
+
|
| 179 |
+
original_shape = list(e.shape)
|
| 180 |
+
original_shape[-1] = 4
|
| 181 |
+
|
| 182 |
+
e = e.view(-1, 3)
|
| 183 |
+
|
| 184 |
+
## if euler angles in degrees
|
| 185 |
+
if deg:
|
| 186 |
+
e = e * np.pi / 180.
|
| 187 |
+
|
| 188 |
+
x = e[:, 0]
|
| 189 |
+
y = e[:, 1]
|
| 190 |
+
z = e[:, 2]
|
| 191 |
+
|
| 192 |
+
rx = torch.stack((torch.cos(x / 2), torch.sin(x / 2), torch.zeros_like(x), torch.zeros_like(x)), dim=1)
|
| 193 |
+
ry = torch.stack((torch.cos(y / 2), torch.zeros_like(y), torch.sin(y / 2), torch.zeros_like(y)), dim=1)
|
| 194 |
+
rz = torch.stack((torch.cos(z / 2), torch.zeros_like(z), torch.zeros_like(z), torch.sin(z / 2)), dim=1)
|
| 195 |
+
|
| 196 |
+
result = None
|
| 197 |
+
for coord in order:
|
| 198 |
+
if coord == 'x':
|
| 199 |
+
r = rx
|
| 200 |
+
elif coord == 'y':
|
| 201 |
+
r = ry
|
| 202 |
+
elif coord == 'z':
|
| 203 |
+
r = rz
|
| 204 |
+
else:
|
| 205 |
+
raise ValueError(f"Invalid coordinate '{coord}' in Euler order '{order}'")
|
| 206 |
+
if result is None:
|
| 207 |
+
result = r
|
| 208 |
+
else:
|
| 209 |
+
result = qmul(result, r)
|
| 210 |
+
|
| 211 |
+
# Reverse antipodal representation to have a non-negative "w"
|
| 212 |
+
if order in ['xyz', 'yzx', 'zxy']:
|
| 213 |
+
result *= -1
|
| 214 |
+
|
| 215 |
+
return result.view(original_shape)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def expmap_to_quaternion(e):
|
| 219 |
+
"""
|
| 220 |
+
Convert axis-angle rotations (aka exponential maps) to quaternions.
|
| 221 |
+
Stable formula from "Practical Parameterization of Rotations Using the Exponential Map".
|
| 222 |
+
Expects a tensor of shape (*, 3), where * denotes any number of dimensions.
|
| 223 |
+
Returns a tensor of shape (*, 4).
|
| 224 |
+
"""
|
| 225 |
+
assert e.shape[-1] == 3
|
| 226 |
+
|
| 227 |
+
original_shape = list(e.shape)
|
| 228 |
+
original_shape[-1] = 4
|
| 229 |
+
e = e.reshape(-1, 3)
|
| 230 |
+
|
| 231 |
+
theta = np.linalg.norm(e, axis=1).reshape(-1, 1)
|
| 232 |
+
w = np.cos(0.5 * theta).reshape(-1, 1)
|
| 233 |
+
xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e
|
| 234 |
+
return np.concatenate((w, xyz), axis=1).reshape(original_shape)
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def euler_to_quaternion(e, order):
|
| 238 |
+
"""
|
| 239 |
+
Convert Euler angles to quaternions.
|
| 240 |
+
"""
|
| 241 |
+
assert e.shape[-1] == 3
|
| 242 |
+
|
| 243 |
+
original_shape = list(e.shape)
|
| 244 |
+
original_shape[-1] = 4
|
| 245 |
+
|
| 246 |
+
e = e.reshape(-1, 3)
|
| 247 |
+
|
| 248 |
+
x = e[:, 0]
|
| 249 |
+
y = e[:, 1]
|
| 250 |
+
z = e[:, 2]
|
| 251 |
+
|
| 252 |
+
rx = np.stack((np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1)
|
| 253 |
+
ry = np.stack((np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1)
|
| 254 |
+
rz = np.stack((np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1)
|
| 255 |
+
|
| 256 |
+
result = None
|
| 257 |
+
for coord in order:
|
| 258 |
+
if coord == 'x':
|
| 259 |
+
r = rx
|
| 260 |
+
elif coord == 'y':
|
| 261 |
+
r = ry
|
| 262 |
+
elif coord == 'z':
|
| 263 |
+
r = rz
|
| 264 |
+
else:
|
| 265 |
+
raise ValueError(f"Invalid coordinate '{coord}' in Euler order '{order}'")
|
| 266 |
+
if result is None:
|
| 267 |
+
result = r
|
| 268 |
+
else:
|
| 269 |
+
result = qmul_np(result, r)
|
| 270 |
+
|
| 271 |
+
# Reverse antipodal representation to have a non-negative "w"
|
| 272 |
+
if order in ['xyz', 'yzx', 'zxy']:
|
| 273 |
+
result *= -1
|
| 274 |
+
|
| 275 |
+
return result.reshape(original_shape)
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def quaternion_to_matrix(quaternions):
|
| 279 |
+
"""
|
| 280 |
+
Convert rotations given as quaternions to rotation matrices.
|
| 281 |
+
Args:
|
| 282 |
+
quaternions: quaternions with real part first,
|
| 283 |
+
as tensor of shape (..., 4).
|
| 284 |
+
Returns:
|
| 285 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
| 286 |
+
"""
|
| 287 |
+
r, i, j, k = torch.unbind(quaternions, -1)
|
| 288 |
+
two_s = 2.0 / (quaternions * quaternions).sum(-1)
|
| 289 |
+
|
| 290 |
+
o = torch.stack(
|
| 291 |
+
(
|
| 292 |
+
1 - two_s * (j * j + k * k),
|
| 293 |
+
two_s * (i * j - k * r),
|
| 294 |
+
two_s * (i * k + j * r),
|
| 295 |
+
two_s * (i * j + k * r),
|
| 296 |
+
1 - two_s * (i * i + k * k),
|
| 297 |
+
two_s * (j * k - i * r),
|
| 298 |
+
two_s * (i * k - j * r),
|
| 299 |
+
two_s * (j * k + i * r),
|
| 300 |
+
1 - two_s * (i * i + j * j),
|
| 301 |
+
),
|
| 302 |
+
-1,
|
| 303 |
+
)
|
| 304 |
+
return o.reshape(quaternions.shape[:-1] + (3, 3))
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def quaternion_to_matrix_np(quaternions):
|
| 308 |
+
q = torch.from_numpy(quaternions).contiguous().float()
|
| 309 |
+
return quaternion_to_matrix(q).numpy()
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def quaternion_to_cont6d_np(quaternions):
|
| 313 |
+
rotation_mat = quaternion_to_matrix_np(quaternions)
|
| 314 |
+
cont_6d = np.concatenate([rotation_mat[..., 0], rotation_mat[..., 1]], axis=-1)
|
| 315 |
+
return cont_6d
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def quaternion_to_cont6d(quaternions):
|
| 319 |
+
rotation_mat = quaternion_to_matrix(quaternions)
|
| 320 |
+
cont_6d = torch.cat([rotation_mat[..., 0], rotation_mat[..., 1]], dim=-1)
|
| 321 |
+
return cont_6d
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def cont6d_to_matrix(cont6d):
|
| 325 |
+
assert cont6d.shape[-1] == 6, "The last dimension must be 6"
|
| 326 |
+
x_raw = cont6d[..., 0:3]
|
| 327 |
+
y_raw = cont6d[..., 3:6]
|
| 328 |
+
|
| 329 |
+
x = x_raw / torch.norm(x_raw, dim=-1, keepdim=True)
|
| 330 |
+
z = torch.cross(x, y_raw, dim=-1)
|
| 331 |
+
z = z / torch.norm(z, dim=-1, keepdim=True)
|
| 332 |
+
|
| 333 |
+
y = torch.cross(z, x, dim=-1)
|
| 334 |
+
|
| 335 |
+
x = x[..., None]
|
| 336 |
+
y = y[..., None]
|
| 337 |
+
z = z[..., None]
|
| 338 |
+
|
| 339 |
+
mat = torch.cat([x, y, z], dim=-1)
|
| 340 |
+
return mat
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def cont6d_to_matrix_np(cont6d):
|
| 344 |
+
q = torch.from_numpy(cont6d).contiguous().float()
|
| 345 |
+
return cont6d_to_matrix(q).numpy()
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def qpow(q0, t, dtype=torch.float):
|
| 349 |
+
''' q0 : tensor of quaternions
|
| 350 |
+
t: tensor of powers
|
| 351 |
+
'''
|
| 352 |
+
q0 = qnormalize(q0)
|
| 353 |
+
theta0 = torch.acos(q0[..., 0])
|
| 354 |
+
|
| 355 |
+
## if theta0 is close to zero, add epsilon to avoid NaNs
|
| 356 |
+
mask = (theta0 <= 10e-10) * (theta0 >= -10e-10)
|
| 357 |
+
theta0 = (1 - mask) * theta0 + mask * 10e-10
|
| 358 |
+
v0 = q0[..., 1:] / torch.sin(theta0).view(-1, 1)
|
| 359 |
+
|
| 360 |
+
if isinstance(t, torch.Tensor):
|
| 361 |
+
q = torch.zeros(t.shape + q0.shape)
|
| 362 |
+
theta = t.view(-1, 1) * theta0.view(1, -1)
|
| 363 |
+
else: ## if t is a number
|
| 364 |
+
q = torch.zeros(q0.shape)
|
| 365 |
+
theta = t * theta0
|
| 366 |
+
|
| 367 |
+
q[..., 0] = torch.cos(theta)
|
| 368 |
+
q[..., 1:] = v0 * torch.sin(theta).unsqueeze(-1)
|
| 369 |
+
|
| 370 |
+
return q.to(dtype)
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
def qslerp(q0, q1, t):
|
| 374 |
+
'''
|
| 375 |
+
q0: starting quaternion
|
| 376 |
+
q1: ending quaternion
|
| 377 |
+
t: array of points along the way
|
| 378 |
+
|
| 379 |
+
Returns:
|
| 380 |
+
Tensor of Slerps: t.shape + q0.shape
|
| 381 |
+
'''
|
| 382 |
+
|
| 383 |
+
q0 = qnormalize(q0)
|
| 384 |
+
q1 = qnormalize(q1)
|
| 385 |
+
q_ = qpow(qmul(q1, qinv(q0)), t)
|
| 386 |
+
|
| 387 |
+
return qmul(q_,
|
| 388 |
+
q0.contiguous().view(torch.Size([1] * len(t.shape)) + q0.shape).expand(t.shape + q0.shape).contiguous())
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
def qbetween(v0, v1):
|
| 392 |
+
'''
|
| 393 |
+
find the quaternion used to rotate v0 to v1
|
| 394 |
+
'''
|
| 395 |
+
assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)'
|
| 396 |
+
assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)'
|
| 397 |
+
|
| 398 |
+
v = torch.cross(v0, v1)
|
| 399 |
+
w = torch.sqrt((v0 ** 2).sum(dim=-1, keepdim=True) * (v1 ** 2).sum(dim=-1, keepdim=True)) + (v0 * v1).sum(dim=-1,
|
| 400 |
+
keepdim=True)
|
| 401 |
+
return qnormalize(torch.cat([w, v], dim=-1))
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
def qbetween_np(v0, v1):
|
| 405 |
+
'''
|
| 406 |
+
find the quaternion used to rotate v0 to v1
|
| 407 |
+
'''
|
| 408 |
+
assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)'
|
| 409 |
+
assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)'
|
| 410 |
+
|
| 411 |
+
v0 = torch.from_numpy(v0).float()
|
| 412 |
+
v1 = torch.from_numpy(v1).float()
|
| 413 |
+
return qbetween(v0, v1).numpy()
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
def lerp(p0, p1, t):
|
| 417 |
+
if not isinstance(t, torch.Tensor):
|
| 418 |
+
t = torch.Tensor([t])
|
| 419 |
+
|
| 420 |
+
new_shape = t.shape + p0.shape
|
| 421 |
+
new_view_t = t.shape + torch.Size([1] * len(p0.shape))
|
| 422 |
+
new_view_p = torch.Size([1] * len(t.shape)) + p0.shape
|
| 423 |
+
p0 = p0.view(new_view_p).expand(new_shape)
|
| 424 |
+
p1 = p1.view(new_view_p).expand(new_shape)
|
| 425 |
+
t = t.view(new_view_t).expand(new_shape)
|
| 426 |
+
|
| 427 |
+
return p0 + t * (p1 - p0)
|
| 428 |
+
|
| 429 |
+
def matrix_to_quat(R) -> torch.Tensor:
|
| 430 |
+
'''
|
| 431 |
+
https://github.com/duolu/pyrotation/blob/master/pyrotation/pyrotation.py
|
| 432 |
+
Convert a rotation matrix to a unit quaternion.
|
| 433 |
+
This uses Shepperd's method for numerical stability.
|
| 434 |
+
'''
|
| 435 |
+
|
| 436 |
+
# The rotation matrix must be orthonormal
|
| 437 |
+
|
| 438 |
+
w2 = (1 + R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2])
|
| 439 |
+
x2 = (1 + R[..., 0, 0] - R[..., 1, 1] - R[..., 2, 2])
|
| 440 |
+
y2 = (1 - R[..., 0, 0] + R[..., 1, 1] - R[..., 2, 2])
|
| 441 |
+
z2 = (1 - R[..., 0, 0] - R[..., 1, 1] + R[..., 2, 2])
|
| 442 |
+
|
| 443 |
+
yz = (R[..., 1, 2] + R[..., 2, 1])
|
| 444 |
+
xz = (R[..., 2, 0] + R[..., 0, 2])
|
| 445 |
+
xy = (R[..., 0, 1] + R[..., 1, 0])
|
| 446 |
+
|
| 447 |
+
wx = (R[..., 2, 1] - R[..., 1, 2])
|
| 448 |
+
wy = (R[..., 0, 2] - R[..., 2, 0])
|
| 449 |
+
wz = (R[..., 1, 0] - R[..., 0, 1])
|
| 450 |
+
|
| 451 |
+
w = torch.empty_like(x2)
|
| 452 |
+
x = torch.empty_like(x2)
|
| 453 |
+
y = torch.empty_like(x2)
|
| 454 |
+
z = torch.empty_like(x2)
|
| 455 |
+
|
| 456 |
+
flagA = (R[..., 2, 2] < 0) * (R[..., 0, 0] > R[..., 1, 1])
|
| 457 |
+
flagB = (R[..., 2, 2] < 0) * (R[..., 0, 0] <= R[..., 1, 1])
|
| 458 |
+
flagC = (R[..., 2, 2] >= 0) * (R[..., 0, 0] < -R[..., 1, 1])
|
| 459 |
+
flagD = (R[..., 2, 2] >= 0) * (R[..., 0, 0] >= -R[..., 1, 1])
|
| 460 |
+
|
| 461 |
+
x[flagA] = torch.sqrt(x2[flagA])
|
| 462 |
+
w[flagA] = wx[flagA] / x[flagA]
|
| 463 |
+
y[flagA] = xy[flagA] / x[flagA]
|
| 464 |
+
z[flagA] = xz[flagA] / x[flagA]
|
| 465 |
+
|
| 466 |
+
y[flagB] = torch.sqrt(y2[flagB])
|
| 467 |
+
w[flagB] = wy[flagB] / y[flagB]
|
| 468 |
+
x[flagB] = xy[flagB] / y[flagB]
|
| 469 |
+
z[flagB] = yz[flagB] / y[flagB]
|
| 470 |
+
|
| 471 |
+
z[flagC] = torch.sqrt(z2[flagC])
|
| 472 |
+
w[flagC] = wz[flagC] / z[flagC]
|
| 473 |
+
x[flagC] = xz[flagC] / z[flagC]
|
| 474 |
+
y[flagC] = yz[flagC] / z[flagC]
|
| 475 |
+
|
| 476 |
+
w[flagD] = torch.sqrt(w2[flagD])
|
| 477 |
+
x[flagD] = wx[flagD] / w[flagD]
|
| 478 |
+
y[flagD] = wy[flagD] / w[flagD]
|
| 479 |
+
z[flagD] = wz[flagD] / w[flagD]
|
| 480 |
+
|
| 481 |
+
# if R[..., 2, 2] < 0:
|
| 482 |
+
#
|
| 483 |
+
# if R[..., 0, 0] > R[..., 1, 1]:
|
| 484 |
+
#
|
| 485 |
+
# x = torch.sqrt(x2)
|
| 486 |
+
# w = wx / x
|
| 487 |
+
# y = xy / x
|
| 488 |
+
# z = xz / x
|
| 489 |
+
#
|
| 490 |
+
# else:
|
| 491 |
+
#
|
| 492 |
+
# y = torch.sqrt(y2)
|
| 493 |
+
# w = wy / y
|
| 494 |
+
# x = xy / y
|
| 495 |
+
# z = yz / y
|
| 496 |
+
#
|
| 497 |
+
# else:
|
| 498 |
+
#
|
| 499 |
+
# if R[..., 0, 0] < -R[..., 1, 1]:
|
| 500 |
+
#
|
| 501 |
+
# z = torch.sqrt(z2)
|
| 502 |
+
# w = wz / z
|
| 503 |
+
# x = xz / z
|
| 504 |
+
# y = yz / z
|
| 505 |
+
#
|
| 506 |
+
# else:
|
| 507 |
+
#
|
| 508 |
+
# w = torch.sqrt(w2)
|
| 509 |
+
# x = wx / w
|
| 510 |
+
# y = wy / w
|
| 511 |
+
# z = wz / w
|
| 512 |
+
|
| 513 |
+
res = [w, x, y, z]
|
| 514 |
+
res = [z.unsqueeze(-1) for z in res]
|
| 515 |
+
|
| 516 |
+
return torch.cat(res, dim=-1) / 2
|
| 517 |
+
|
| 518 |
+
def cont6d_to_quat(cont6d):
|
| 519 |
+
return matrix_to_quat(cont6d_to_matrix(cont6d))
|
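A quick sanity check for the quaternion helpers above (w-first convention); the quaternion and vector values are arbitrary:

```python
# Checks that qrot agrees with the 3x3 matrix form and that q * q^{-1}
# is the identity rotation [1, 0, 0, 0].
import torch
from utils.quaternion import qnormalize, qrot, qmul, qinv, quaternion_to_matrix

q = qnormalize(torch.tensor([[0.9, 0.1, 0.3, -0.2]]))  # (1, 4), real part first
v = torch.tensor([[1.0, 2.0, 3.0]])                    # (1, 3)

r1 = qrot(q, v)
r2 = (quaternion_to_matrix(q) @ v.unsqueeze(-1)).squeeze(-1)
print(torch.allclose(r1, r2, atol=1e-5))  # True

print(qmul(q, qinv(q)))  # ~[[1., 0., 0., 0.]]
```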
utils/skeleton.py
ADDED
|
@@ -0,0 +1,194 @@
|
| 1 |
+
from utils.quaternion import *
|
| 2 |
+
import scipy.ndimage as filters  # the scipy.ndimage.filters namespace is deprecated; gaussian_filter1d lives in scipy.ndimage
|
| 3 |
+
|
| 4 |
+
class Skeleton(object):
|
| 5 |
+
def __init__(self, offset, kinematic_tree, device):
|
| 6 |
+
self.device = device
|
| 7 |
+
self._raw_offset_np = offset.numpy()
|
| 8 |
+
self._raw_offset = offset.clone().detach().to(device).float()
|
| 9 |
+
self._kinematic_tree = kinematic_tree
|
| 10 |
+
self._offset = None
|
| 11 |
+
self._parents = [0] * len(self._raw_offset)
|
| 12 |
+
self._parents[0] = -1
|
| 13 |
+
for chain in self._kinematic_tree:
|
| 14 |
+
for j in range(1, len(chain)):
|
| 15 |
+
self._parents[chain[j]] = chain[j-1]
|
| 16 |
+
|
| 17 |
+
def njoints(self):
|
| 18 |
+
return len(self._raw_offset)
|
| 19 |
+
|
| 20 |
+
def offset(self):
|
| 21 |
+
return self._offset
|
| 22 |
+
|
| 23 |
+
def set_offset(self, offsets):
|
| 24 |
+
self._offset = offsets.clone().detach().to(self.device).float()
|
| 25 |
+
|
| 26 |
+
def kinematic_tree(self):
|
| 27 |
+
return self._kinematic_tree
|
| 28 |
+
|
| 29 |
+
def parents(self):
|
| 30 |
+
return self._parents
|
| 31 |
+
|
| 32 |
+
# joints (batch_size, joints_num, 3)
|
| 33 |
+
def get_offsets_joints_batch(self, joints):
|
| 34 |
+
assert len(joints.shape) == 3
|
| 35 |
+
_offsets = self._raw_offset.expand(joints.shape[0], -1, -1).clone()
|
| 36 |
+
for i in range(1, self._raw_offset.shape[0]):
|
| 37 |
+
_offsets[:, i] = torch.norm(joints[:, i] - joints[:, self._parents[i]], p=2, dim=1)[:, None] * _offsets[:, i]
|
| 38 |
+
|
| 39 |
+
self._offset = _offsets.detach()
|
| 40 |
+
return _offsets
|
| 41 |
+
|
| 42 |
+
# joints (joints_num, 3)
|
| 43 |
+
def get_offsets_joints(self, joints):
|
| 44 |
+
assert len(joints.shape) == 2
|
| 45 |
+
_offsets = self._raw_offset.clone()
|
| 46 |
+
for i in range(1, self._raw_offset.shape[0]):
|
| 47 |
+
# print(joints.shape)
|
| 48 |
+
_offsets[i] = torch.norm(joints[i] - joints[self._parents[i]], p=2, dim=0) * _offsets[i]
|
| 49 |
+
|
| 50 |
+
self._offset = _offsets.detach()
|
| 51 |
+
return _offsets
|
| 52 |
+
|
| 53 |
+
# face_joint_idx should follow the order of right hip, left hip, right shoulder, left shoulder
|
| 54 |
+
# joints (batch_size, joints_num, 3)
|
| 55 |
+
def inverse_kinematics_np(self, joints, face_joint_idx, smooth_forward=False):
|
| 56 |
+
assert len(face_joint_idx) == 4
|
| 57 |
+
'''Get Forward Direction'''
|
| 58 |
+
l_hip, r_hip, sdr_r, sdr_l = face_joint_idx
|
| 59 |
+
across1 = joints[:, r_hip] - joints[:, l_hip]
|
| 60 |
+
across2 = joints[:, sdr_r] - joints[:, sdr_l]
|
| 61 |
+
across = across1 + across2
|
| 62 |
+
across = across / (np.sqrt((across**2).sum(axis=-1, keepdims=True)) + 1e-8)
|
| 63 |
+
# print(across1.shape, across2.shape)
|
| 64 |
+
|
| 65 |
+
# forward (batch_size, 3)
|
| 66 |
+
forward = np.cross(np.array([[0, 1, 0]]), across, axis=-1)
|
| 67 |
+
if smooth_forward:
|
| 68 |
+
forward = filters.gaussian_filter1d(forward, 20, axis=0, mode='nearest')
|
| 69 |
+
# forward (batch_size, 3)
|
| 70 |
+
forward = forward / (np.sqrt((forward**2).sum(axis=-1))[..., np.newaxis] + 1e-8)
|
| 71 |
+
|
| 72 |
+
'''Get Root Rotation'''
|
| 73 |
+
target = np.array([[0,0,1]]).repeat(len(forward), axis=0)
|
| 74 |
+
root_quat = qbetween_np(forward, target)
|
| 75 |
+
|
| 76 |
+
'''Inverse Kinematics'''
|
| 77 |
+
# quat_params (batch_size, joints_num, 4)
|
| 78 |
+
# print(joints.shape[:-1])
|
| 79 |
+
quat_params = np.zeros(joints.shape[:-1] + (4,))
|
| 80 |
+
# print(quat_params.shape)
|
| 81 |
+
root_quat[0] = np.array([[1.0, 0.0, 0.0, 0.0]])
|
| 82 |
+
quat_params[:, 0] = root_quat
|
| 83 |
+
# quat_params[0, 0] = np.array([[1.0, 0.0, 0.0, 0.0]])
|
| 84 |
+
for chain in self._kinematic_tree:
|
| 85 |
+
R = root_quat
|
| 86 |
+
for j in range(len(chain) - 1):
|
| 87 |
+
# (batch, 3)
|
| 88 |
+
u = self._raw_offset_np[chain[j+1]][np.newaxis,...].repeat(len(joints), axis=0)
|
| 89 |
+
# print(u.shape)
|
| 90 |
+
# (batch, 3)
|
| 91 |
+
v = joints[:, chain[j+1]] - joints[:, chain[j]]
|
| 92 |
+
v = v / (np.sqrt((v**2).sum(axis=-1, keepdims=True)) + 1e-8)
|
| 93 |
+
# print(u.shape, v.shape)
|
| 94 |
+
rot_u_v = qbetween_np(u, v)
|
| 95 |
+
|
| 96 |
+
R_loc = qmul_np(qinv_np(R), rot_u_v)
|
| 97 |
+
|
| 98 |
+
quat_params[:,chain[j + 1], :] = R_loc
|
| 99 |
+
R = qmul_np(R, R_loc)
|
| 100 |
+
|
| 101 |
+
return quat_params
|
| 102 |
+
|
| 103 |
+
# Be sure root joint is at the beginning of kinematic chains
|
| 104 |
+
def forward_kinematics(self, quat_params, root_pos, skel_joints=None, do_root_R=True):
|
| 105 |
+
# quat_params (batch_size, joints_num, 4)
|
| 106 |
+
# joints (batch_size, joints_num, 3)
|
| 107 |
+
# root_pos (batch_size, 3)
|
| 108 |
+
if skel_joints is not None:
|
| 109 |
+
offsets = self.get_offsets_joints_batch(skel_joints)
|
| 110 |
+
if len(self._offset.shape) == 2:
|
| 111 |
+
offsets = self._offset.expand(quat_params.shape[0], -1, -1)
|
| 112 |
+
joints = torch.zeros(quat_params.shape[:-1] + (3,)).to(self.device)
|
| 113 |
+
joints[:, 0] = root_pos
|
| 114 |
+
for chain in self._kinematic_tree:
|
| 115 |
+
if do_root_R:
|
| 116 |
+
R = quat_params[:, 0]
|
| 117 |
+
else:
|
| 118 |
+
R = torch.tensor([[1.0, 0.0, 0.0, 0.0]]).expand(len(quat_params), -1).detach().to(self.device)
|
| 119 |
+
for i in range(1, len(chain)):
|
| 120 |
+
R = qmul(R, quat_params[:, chain[i]])
|
| 121 |
+
offset_vec = offsets[:, chain[i]]
|
| 122 |
+
joints[:, chain[i]] = qrot(R, offset_vec) + joints[:, chain[i-1]]
|
| 123 |
+
return joints
|
| 124 |
+
|
| 125 |
+
# Be sure root joint is at the beginning of kinematic chains
|
| 126 |
+
def forward_kinematics_np(self, quat_params, root_pos, skel_joints=None, do_root_R=True):
|
| 127 |
+
# quat_params (batch_size, joints_num, 4)
|
| 128 |
+
# joints (batch_size, joints_num, 3)
|
| 129 |
+
# root_pos (batch_size, 3)
|
| 130 |
+
if skel_joints is not None:
|
| 131 |
+
skel_joints = torch.from_numpy(skel_joints)
|
| 132 |
+
offsets = self.get_offsets_joints_batch(skel_joints)
|
| 133 |
+
if len(self._offset.shape) == 2:
|
| 134 |
+
offsets = self._offset.expand(quat_params.shape[0], -1, -1)
|
| 135 |
+
offsets = offsets.numpy()
|
| 136 |
+
joints = np.zeros(quat_params.shape[:-1] + (3,))
|
| 137 |
+
joints[:, 0] = root_pos
|
| 138 |
+
for chain in self._kinematic_tree:
|
| 139 |
+
if do_root_R:
|
| 140 |
+
R = quat_params[:, 0]
|
| 141 |
+
else:
|
| 142 |
+
R = np.array([[1.0, 0.0, 0.0, 0.0]]).repeat(len(quat_params), axis=0)
|
| 143 |
+
for i in range(1, len(chain)):
|
| 144 |
+
R = qmul_np(R, quat_params[:, chain[i]])
|
| 145 |
+
offset_vec = offsets[:, chain[i]]
|
| 146 |
+
joints[:, chain[i]] = qrot_np(R, offset_vec) + joints[:, chain[i - 1]]
|
| 147 |
+
return joints
|
| 148 |
+
|
| 149 |
+
def forward_kinematics_cont6d_np(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True):
|
| 150 |
+
# cont6d_params (batch_size, joints_num, 6)
|
| 151 |
+
# joints (batch_size, joints_num, 3)
|
| 152 |
+
# root_pos (batch_size, 3)
|
| 153 |
+
if skel_joints is not None:
|
| 154 |
+
skel_joints = torch.from_numpy(skel_joints)
|
| 155 |
+
offsets = self.get_offsets_joints_batch(skel_joints)
|
| 156 |
+
if len(self._offset.shape) == 2:
|
| 157 |
+
offsets = self._offset.expand(cont6d_params.shape[0], -1, -1)
|
| 158 |
+
offsets = offsets.numpy()
|
| 159 |
+
joints = np.zeros(cont6d_params.shape[:-1] + (3,))
|
| 160 |
+
joints[:, 0] = root_pos
|
| 161 |
+
for chain in self._kinematic_tree:
|
| 162 |
+
if do_root_R:
|
| 163 |
+
matR = cont6d_to_matrix_np(cont6d_params[:, 0])
|
| 164 |
+
else:
|
| 165 |
+
matR = np.eye(3)[np.newaxis, :].repeat(len(cont6d_params), axis=0)
|
| 166 |
+
for i in range(1, len(chain)):
|
| 167 |
+
matR = np.matmul(matR, cont6d_to_matrix_np(cont6d_params[:, chain[i]]))
|
| 168 |
+
offset_vec = offsets[:, chain[i]][..., np.newaxis]
|
| 169 |
+
# print(matR.shape, offset_vec.shape)
|
| 170 |
+
joints[:, chain[i]] = np.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]]
|
| 171 |
+
return joints
|
| 172 |
+
|
| 173 |
+
def forward_kinematics_cont6d(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True):
|
| 174 |
+
# cont6d_params (batch_size, joints_num, 6)
|
| 175 |
+
# joints (batch_size, joints_num, 3)
|
| 176 |
+
# root_pos (batch_size, 3)
|
| 177 |
+
if skel_joints is not None:
|
| 178 |
+
# skel_joints = torch.from_numpy(skel_joints)
|
| 179 |
+
offsets = self.get_offsets_joints_batch(skel_joints)
|
| 180 |
+
if len(self._offset.shape) == 2:
|
| 181 |
+
offsets = self._offset.expand(cont6d_params.shape[0], -1, -1)
|
| 182 |
+
joints = torch.zeros(cont6d_params.shape[:-1] + (3,)).to(cont6d_params.device)
|
| 183 |
+
joints[..., 0, :] = root_pos
|
| 184 |
+
for chain in self._kinematic_tree:
|
| 185 |
+
if do_root_R:
|
| 186 |
+
matR = cont6d_to_matrix(cont6d_params[:, 0])
|
| 187 |
+
else:
|
| 188 |
+
matR = torch.eye(3).expand((len(cont6d_params), -1, -1)).detach().to(cont6d_params.device)
|
| 189 |
+
for i in range(1, len(chain)):
|
| 190 |
+
matR = torch.matmul(matR, cont6d_to_matrix(cont6d_params[:, chain[i]]))
|
| 191 |
+
offset_vec = offsets[:, chain[i]].unsqueeze(-1)
|
| 192 |
+
# print(matR.shape, offset_vec.shape)
|
| 193 |
+
joints[:, chain[i]] = torch.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]]
|
| 194 |
+
return joints
|
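A minimal forward-kinematics smoke test for the Skeleton class above, using a made-up three-joint chain (not the real HumanML3D skeleton):

```python
# With identity rotations, forward_kinematics_np should reproduce the rest pose.
import numpy as np
import torch
from utils.skeleton import Skeleton

raw_offsets = torch.tensor([[0.0, 0.0, 0.0],   # root
                            [0.0, 1.0, 0.0],   # unit direction root -> joint 1
                            [0.0, 1.0, 0.0]])  # unit direction joint 1 -> joint 2
kinematic_chain = [[0, 1, 2]]
skel = Skeleton(raw_offsets, kinematic_chain, device='cpu')

rest_pose = np.array([[[0.0, 0.0, 0.0],
                       [0.0, 1.0, 0.0],
                       [0.0, 2.0, 0.0]]])      # (batch=1, joints=3, 3)

quat_params = np.zeros((1, 3, 4))
quat_params[..., 0] = 1.0                      # identity quaternions, w first
root_pos = np.zeros((1, 3))

joints = skel.forward_kinematics_np(quat_params, root_pos, skel_joints=rest_pose)
print(joints[0])                               # expected: the rest-pose joint positions
```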
utils/train_utils.py
ADDED
|
@@ -0,0 +1,105 @@
|
| 1 |
+
import torch
|
| 2 |
+
import math
|
| 3 |
+
import time
|
| 4 |
+
import torch.distributed as dist
|
| 5 |
+
import logging
|
| 6 |
+
#################################################################################
|
| 7 |
+
# DDP Functions #
|
| 8 |
+
#################################################################################
|
| 9 |
+
def cleanup():
|
| 10 |
+
dist.destroy_process_group()
|
| 11 |
+
|
| 12 |
+
#################################################################################
|
| 13 |
+
# Util Functions #
|
| 14 |
+
#################################################################################
|
| 15 |
+
def lengths_to_mask(lengths, max_len):
|
| 16 |
+
# max_len = max(lengths)
|
| 17 |
+
mask = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths.unsqueeze(1)
|
| 18 |
+
return mask #(b, len)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_mask_subset_prob(mask, prob):
|
| 22 |
+
subset_mask = torch.bernoulli(mask, p=prob) & mask
|
| 23 |
+
return subset_mask
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def uniform(shape, device=None):
|
| 27 |
+
return torch.zeros(shape, device=device).float().uniform_(0, 1)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def cosine_schedule(t):
|
| 31 |
+
return torch.cos(t * math.pi * 0.5)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def update_ema(model, ema_model, ema_decay):
|
| 35 |
+
with torch.no_grad():
|
| 36 |
+
for ema_param, model_param in zip(ema_model.parameters(), model.parameters()):
|
| 37 |
+
ema_param.data.mul_(ema_decay).add_(model_param.data, alpha=(1 - ema_decay))
|
| 38 |
+
|
| 39 |
+
#################################################################################
|
| 40 |
+
# Logging Functions #
|
| 41 |
+
#################################################################################
|
| 42 |
+
def def_value():
|
| 43 |
+
return 0.0
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def create_logger(logging_dir):
|
| 47 |
+
if dist.get_rank() == 0: # real logger
|
| 48 |
+
logging.basicConfig(
|
| 49 |
+
level=logging.INFO,
|
| 50 |
+
format='[\033[34m%(asctime)s\033[0m] %(message)s',
|
| 51 |
+
datefmt='%Y-%m-%d %H:%M:%S',
|
| 52 |
+
handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
|
| 53 |
+
)
|
| 54 |
+
logger = logging.getLogger(__name__)
|
| 55 |
+
else: # dummy logger (does nothing)
|
| 56 |
+
logger = logging.getLogger(__name__)
|
| 57 |
+
logger.addHandler(logging.NullHandler())
|
| 58 |
+
return logger
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def update_lr_warm_up(nb_iter, warm_up_iter, optimizer, lr):
|
| 62 |
+
current_lr = lr * (nb_iter + 1) / (warm_up_iter + 1)
|
| 63 |
+
for param_group in optimizer.param_groups:
|
| 64 |
+
param_group["lr"] = current_lr
|
| 65 |
+
return current_lr
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def save(file_name, ep, model, optimizer, scheduler, total_it, name, ema=None):
|
| 69 |
+
state = {
|
| 70 |
+
name: model.state_dict(),
|
| 71 |
+
f"opt_{name}": optimizer.state_dict(),
|
| 72 |
+
"scheduler": scheduler.state_dict(),
|
| 73 |
+
'ep': ep,
|
| 74 |
+
'total_it': total_it,
|
| 75 |
+
}
|
| 76 |
+
if ema is not None:
|
| 77 |
+
mardm_state_dict = model.state_dict()
|
| 78 |
+
ema_mardm_state_dict = ema.state_dict()
|
| 79 |
+
clip_weights = [e for e in mardm_state_dict.keys() if e.startswith('clip_model.')]
|
| 80 |
+
for e in clip_weights:
|
| 81 |
+
del mardm_state_dict[e]
|
| 82 |
+
del ema_mardm_state_dict[e]
|
| 83 |
+
state[name] = mardm_state_dict
|
| 84 |
+
state[f"ema_{name}"] = ema_mardm_state_dict
|
| 85 |
+
torch.save(state, file_name)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def print_current_loss(start_time, niter_state, total_niters, losses, epoch=None, sub_epoch=None,
|
| 89 |
+
inner_iter=None, tf_ratio=None, sl_steps=None):
|
| 90 |
+
def as_minutes(s):
|
| 91 |
+
m = math.floor(s / 60)
|
| 92 |
+
s -= m * 60
|
| 93 |
+
return '%dm %ds' % (m, s)
|
| 94 |
+
def time_since(since, percent):
|
| 95 |
+
now = time.time()
|
| 96 |
+
s = now - since
|
| 97 |
+
es = s / percent
|
| 98 |
+
rs = es - s
|
| 99 |
+
return '%s (- %s)' % (as_minutes(s), as_minutes(rs))
|
| 100 |
+
if epoch is not None:
|
| 101 |
+
print('ep/it:%2d-%4d niter:%6d' % (epoch, inner_iter, niter_state), end=" ")
|
| 102 |
+
message = ' %s (completed: %3d%%)' % (time_since(start_time, niter_state / total_niters), niter_state / total_niters * 100)
|
| 103 |
+
for k, v in losses.items():
|
| 104 |
+
message += ' %s: %.4f ' % (k, v)
|
| 105 |
+
print(message)
|
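Illustrative usage of the masking and EMA helpers above (values are arbitrary); the DDP-dependent logger is left out since it needs an initialized process group:

```python
# Arbitrary lengths and a tiny model pair to exercise the helpers.
import torch
from utils.train_utils import lengths_to_mask, update_ema

lengths = torch.tensor([3, 5])
print(lengths_to_mask(lengths, max_len=6))
# tensor([[ True,  True,  True, False, False, False],
#         [ True,  True,  True,  True,  True, False]])

model = torch.nn.Linear(4, 4)
ema_model = torch.nn.Linear(4, 4)
ema_model.load_state_dict(model.state_dict())
update_ema(model, ema_model, ema_decay=0.999)  # ema <- 0.999 * ema + 0.001 * model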
utils/wandb_utils.py
ADDED
|
@@ -0,0 +1,53 @@
|
| 1 |
+
import wandb
|
| 2 |
+
import torch
|
| 3 |
+
from torchvision.utils import make_grid
|
| 4 |
+
import torch.distributed as dist
|
| 5 |
+
import argparse
|
| 6 |
+
import hashlib
|
| 7 |
+
import math
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def is_main_process():
|
| 11 |
+
return dist.get_rank() == 0
|
| 12 |
+
|
| 13 |
+
def namespace_to_dict(namespace):
|
| 14 |
+
return {
|
| 15 |
+
k: namespace_to_dict(v) if isinstance(v, argparse.Namespace) else v
|
| 16 |
+
for k, v in vars(namespace).items()
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def generate_run_id(exp_name):
|
| 21 |
+
# https://stackoverflow.com/questions/16008670/how-to-hash-a-string-into-8-digits
|
| 22 |
+
return str(int(hashlib.sha256(exp_name.encode('utf-8')).hexdigest(), 16) % 10 ** 8)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def initialize(args, entity, exp_name, project_name):
|
| 26 |
+
config_dict = namespace_to_dict(args)
|
| 27 |
+
# wandb.login(key=os.environ["WANDB_KEY"])
|
| 28 |
+
wandb.init(
|
| 29 |
+
entity=entity,
|
| 30 |
+
project=project_name,
|
| 31 |
+
name=exp_name,
|
| 32 |
+
config=config_dict,
|
| 33 |
+
id=generate_run_id(exp_name),
|
| 34 |
+
resume="allow",
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def log(stats, step=None):
|
| 39 |
+
if is_main_process():
|
| 40 |
+
wandb.log({k: v for k, v in stats.items()}, step=step)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def log_image(sample, step=None):
|
| 44 |
+
if is_main_process():
|
| 45 |
+
sample = array2grid(sample)
|
| 46 |
+
wandb.log({f"samples": wandb.Image(sample), "train_step": step})
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def array2grid(x):
|
| 50 |
+
nrow = round(math.sqrt(x.size(0)))
|
| 51 |
+
x = make_grid(x, nrow=nrow, normalize=True, value_range=(-1,1))
|
| 52 |
+
x = x.mul(255).add_(0.5).clamp_(0,255).permute(1,2,0).to('cpu', torch.uint8).numpy()
|
| 53 |
+
return x
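A hypothetical single-process way to exercise these logging helpers locally; it assumes an offline W&B run and spins up a one-rank gloo group so `dist.get_rank()` works:

```python
# Hypothetical local smoke test: one-rank gloo group + offline W&B run.
import os
import argparse
import torch.distributed as dist
from utils import wandb_utils

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("WANDB_MODE", "offline")  # no API key needed
dist.init_process_group("gloo", rank=0, world_size=1)

args = argparse.Namespace(lr=1e-4, batch_size=64)  # stand-in hyperparameters
wandb_utils.initialize(args, entity="my-team", exp_name="acmdm_debug", project_name="ACMDM")
wandb_utils.log({"loss": 0.123}, step=0)
dist.destroy_process_group()
```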
|