matbee committed
Commit 07823f7 · verified · 1 Parent(s): 0a290eb

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitattributes +10 -0
  2. README.md +66 -53
  3. fp16/dacvae_decoder.onnx +3 -0
  4. fp16/dacvae_decoder.onnx.data +3 -0
  5. fp16/dacvae_encoder.onnx +3 -0
  6. fp16/dacvae_encoder.onnx.data +3 -0
  7. fp16/dit_single_step.onnx +3 -0
  8. fp16/dit_single_step.onnx.data +3 -0
  9. fp16/t5_encoder.onnx +3 -0
  10. fp16/t5_encoder.onnx.data +3 -0
  11. fp16/tokenizer/special_tokens_map.json +107 -0
  12. fp16/tokenizer/spiece.model +3 -0
  13. fp16/tokenizer/tokenizer.json +0 -0
  14. fp16/tokenizer/tokenizer_config.json +939 -0
  15. fp16/tokenizer_config.json +7 -0
  16. fp16/vision_encoder.onnx +3 -0
  17. fp16/vision_encoder.onnx.data +3 -0
  18. fp32/dacvae_decoder.onnx +3 -0
  19. fp32/dacvae_decoder.onnx.data +3 -0
  20. fp32/dacvae_encoder.onnx +3 -0
  21. fp32/dacvae_encoder.onnx.data +3 -0
  22. fp32/dit_single_step.onnx +3 -0
  23. fp32/dit_single_step.onnx.data +3 -0
  24. fp32/t5_encoder.onnx +3 -0
  25. fp32/t5_encoder.onnx.data +3 -0
  26. fp32/tokenizer/special_tokens_map.json +107 -0
  27. fp32/tokenizer/spiece.model +3 -0
  28. fp32/tokenizer/tokenizer.json +0 -0
  29. fp32/tokenizer/tokenizer_config.json +939 -0
  30. fp32/tokenizer_config.json +7 -0
  31. fp32/vision_encoder.onnx +3 -0
  32. fp32/vision_encoder.onnx.data +3 -0
  33. onnx_export/__init__.py +1 -0
  34. onnx_export/__pycache__/__init__.cpython-312.pyc +0 -0
  35. onnx_export/__pycache__/export_dacvae.cpython-312.pyc +0 -0
  36. onnx_export/__pycache__/export_dit.cpython-312.pyc +0 -0
  37. onnx_export/__pycache__/export_peaframe.cpython-312.pyc +0 -0
  38. onnx_export/__pycache__/export_t5.cpython-312.pyc +0 -0
  39. onnx_export/__pycache__/export_vision.cpython-312.pyc +0 -0
  40. onnx_export/__pycache__/quantize_large_model.cpython-312.pyc +0 -0
  41. onnx_export/__pycache__/quantize_models.cpython-312.pyc +0 -0
  42. onnx_export/__pycache__/standalone_config.cpython-312.pyc +0 -0
  43. onnx_export/export_all.py +130 -0
  44. onnx_export/export_dacvae.py +427 -0
  45. onnx_export/export_dit.py +574 -0
  46. onnx_export/export_peaframe.py +288 -0
  47. onnx_export/export_t5.py +315 -0
  48. onnx_export/export_vision.py +113 -0
  49. onnx_export/quantize_large_model.py +115 -0
  50. onnx_export/quantize_models.py +286 -0
.gitattributes CHANGED
@@ -34,3 +34,13 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  vision_encoder.onnx.data filter=lfs diff=lfs merge=lfs -text
37
+ fp16/dacvae_decoder.onnx.data filter=lfs diff=lfs merge=lfs -text
38
+ fp16/dacvae_encoder.onnx.data filter=lfs diff=lfs merge=lfs -text
39
+ fp16/dit_single_step.onnx.data filter=lfs diff=lfs merge=lfs -text
40
+ fp16/t5_encoder.onnx.data filter=lfs diff=lfs merge=lfs -text
41
+ fp16/vision_encoder.onnx.data filter=lfs diff=lfs merge=lfs -text
42
+ fp32/dacvae_decoder.onnx.data filter=lfs diff=lfs merge=lfs -text
43
+ fp32/dacvae_encoder.onnx.data filter=lfs diff=lfs merge=lfs -text
44
+ fp32/dit_single_step.onnx.data filter=lfs diff=lfs merge=lfs -text
45
+ fp32/t5_encoder.onnx.data filter=lfs diff=lfs merge=lfs -text
46
+ fp32/vision_encoder.onnx.data filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,95 +1,108 @@
1
  # SAM-Audio ONNX (Large)
2
 
3
  ONNX-converted models for [SAM-Audio](https://github.com/facebookresearch/sam-audio) (facebook/sam-audio-large) - Meta's Semantic Audio Modeling for audio source separation.
4
 
5
- ## Model Files
6
 
7
- | File | Description | Size |
8
- |------|-------------|------|
9
- | `dacvae_encoder.onnx` | Audio encoder (48kHz latent) | ~110 MB |
10
- | `dacvae_decoder.onnx` | Audio decoder (latent → 48kHz) | ~320 MB |
11
- | `t5_encoder.onnx` | Text encoder (T5-base) | ~440 MB |
12
- | `dit_single_step.onnx` | DiT denoiser (single ODE step) | ~2 GB |
13
- | `vision_encoder.onnx` | Vision encoder (CLIP-based) | ~1.2 GB |
14
- | `tokenizer/` | SentencePiece tokenizer files | - |
 
 
15
 
16
  ## Installation
17
 
18
  ```bash
19
  pip install onnxruntime sentencepiece torchaudio torchvision torchcodec soundfile
20
- # For CUDA support:
21
  pip install onnxruntime-gpu
22
  ```
23
 
24
- ## Quick Start
25
-
26
- ```python
27
- import numpy as np
28
- import onnxruntime as ort
29
- from huggingface_hub import hf_hub_download
30
-
31
- # Download models
32
- model_dir = "sam-audio-large-onnx"
33
- for f in ["dacvae_encoder.onnx", "dacvae_decoder.onnx", "t5_encoder.onnx",
34
- "dit_single_step.onnx", "vision_encoder.onnx"]:
35
- hf_hub_download("matbee/sam-audio-large-onnx", f, local_dir=model_dir)
36
- if f != "vision_encoder.onnx": # vision encoder embeds weights
37
- hf_hub_download("matbee/sam-audio-large-onnx", f + ".data", local_dir=model_dir)
38
- ```
39
-
40
- ## Usage Examples
41
 
42
- ### Audio-Only Separation
43
  ```bash
44
  python onnx_inference.py \
45
- --audio input.wav \
46
  --text "a person speaking" \
47
- --output separated.wav
 
 
48
  ```
49
 
50
- ### Video-Guided Separation
51
  ```bash
52
  python onnx_inference.py \
53
  --video input.mp4 \
54
- --text "the sound of typing" \
55
- --output separated.wav
 
56
  ```
57
 
58
- ### Visual Prompting with SAM3 Mask
59
  ```bash
60
- # First generate a mask with SAM3 (see generate_sam3_mask.py)
61
  python onnx_inference.py \
62
- --video input.mp4 \
63
- --mask object_mask.mp4 \
64
- --text "" \
65
- --output isolated.wav \
66
- --output-video visualization.mp4
67
  ```
68
 
69
- ## Model Details
70
 
71
  - **Audio Sample Rate**: 48kHz
72
  - **Audio Hop Length**: 1536 samples
73
  - **Vision Input Size**: 336×336 pixels
74
  - **Text Encoder**: T5-base (768-dim)
75
  - **Vision Encoder**: PE-Core-L14-336 (1024-dim)
76
- - **ODE Solver**: Midpoint method (configurable steps)
 
77
 
78
- ## License
79
 
80
- SAM-Audio is released under the [CC-BY-NC 4.0 license](https://creativecommons.org/licenses/by-nc/4.0/).
81
-
82
- ## Citation
83
 
84
- ```bibtex
85
- @article{samaudio2024,
86
- title={SAM-Audio: Semantic Audio Modeling},
87
- author={Meta AI},
88
- year={2024}
89
- }
90
  ```
91
 
 
 
 
 
92
  ## Acknowledgments
93
 
94
  Original model by [Meta AI Research](https://github.com/facebookresearch/sam-audio).
95
- ONNX conversion by [@matbee](https://huggingface.co/matbee).
 
1
+ ---
2
+ license: other
3
+ base_model: facebook/sam-audio-large
4
+ tags:
5
+ - onnx
6
+ - audio
7
+ - sam-audio
8
+ - source-separation
9
+ - audio-visual
10
+ ---
11
+
12
  # SAM-Audio ONNX (Large)
13
 
14
  ONNX-converted models for [SAM-Audio](https://github.com/facebookresearch/sam-audio) (facebook/sam-audio-large) - Meta's Semantic Audio Modeling for audio source separation.
15
 
16
+ This repository contains both **FP32** and **FP16** versions of the models.
17
+
18
+ ## Model Variants
19
+
20
+ | Variant | DiT Size | Total Size | Notes |
21
+ |---------|----------|------------|-------|
22
+ | `fp32/` | 11.76 GB | ~13.9 GB | Full precision |
23
+ | `fp16/` | 5.88 GB | ~8.0 GB | Half precision (recommended) |
24
 
25
+ ## Model Files (per variant)
26
+
27
+ | File | Description | FP32 Size | FP16 Size |
28
+ |------|-------------|-----------|-----------|
29
+ | `dacvae_encoder.onnx` | Audio encoder (48kHz → latent) | 110 MB | 110 MB |
30
+ | `dacvae_decoder.onnx` | Audio decoder (latent → 48kHz) | 320 MB | 320 MB |
31
+ | `t5_encoder.onnx` | Text encoder (T5-base) | 440 MB | 440 MB |
32
+ | `dit_single_step.onnx` | DiT denoiser (3B params) | 11.76 GB | 5.88 GB |
33
+ | `vision_encoder.onnx` | Vision encoder (CLIP-based) | 1.27 GB | 1.27 GB |
34
+ | `tokenizer/` | SentencePiece tokenizer files | - | - |
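
For reference, here is a minimal download sketch with `huggingface_hub` (the repo id `matbee/sam-audio-large-onnx` is taken from the Quick Start that this revision removes; `allow_patterns` is used to pull a single precision variant):

```python
from huggingface_hub import snapshot_download

# Fetch only the fp16/ variant (~8 GB) instead of the whole repository.
# The pattern also covers fp16/tokenizer/*.
model_dir = snapshot_download(
    repo_id="matbee/sam-audio-large-onnx",
    allow_patterns=["fp16/*"],
    local_dir="sam-audio-large-onnx",
)
print(model_dir)
```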
35
 
36
  ## Installation
37
 
38
  ```bash
39
  pip install onnxruntime sentencepiece torchaudio torchvision torchcodec soundfile
40
+ # For CUDA support (recommended for the large model):
41
  pip install onnxruntime-gpu
42
  ```
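
A minimal sketch of opening one of the exported graphs with ONNX Runtime; provider availability depends on which `onnxruntime` package is installed, and the CPU provider acts as the fallback:

```python
import onnxruntime as ort

# Prefer CUDA when onnxruntime-gpu is installed, otherwise fall back to CPU.
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
session = ort.InferenceSession("fp16/t5_encoder.onnx", providers=providers)

# The external weights (t5_encoder.onnx.data) are loaded automatically as long
# as the .data file sits next to its .onnx file.
print([inp.name for inp in session.get_inputs()])
```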
43
 
44
+ ## Usage
45
 
46
+ ### Using FP16 Models (Recommended)
47
  ```bash
48
  python onnx_inference.py \
49
+ --video input.mp4 \
50
  --text "a person speaking" \
51
+ --model-dir fp16 \
52
+ --output target.wav \
53
+ --output-residual residual.wav
54
  ```
55
 
56
+ ### Using FP32 Models
57
  ```bash
58
  python onnx_inference.py \
59
  --video input.mp4 \
60
+ --text "keyboard typing" \
61
+ --model-dir fp32 \
62
+ --output target.wav
63
  ```
64
 
65
+ ### Audio-Only Mode
66
  ```bash
 
67
  python onnx_inference.py \
68
+ --audio input.wav \
69
+ --text "drums" \
70
+ --model-dir fp16 \
71
+ --output drums.wav
 
72
  ```
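
The text branch can also be exercised on its own. The sketch below runs `t5_encoder.onnx` with the bundled `tokenizer/` files; the 77-token padding follows `tokenizer_config.json`, while the tensor names are assumptions, so confirm them against `session.get_inputs()`:

```python
import numpy as np
import onnxruntime as ort
from transformers import T5Tokenizer

# SentencePiece tokenizer shipped with the repo (T5Tokenizer per its config).
tokenizer = T5Tokenizer.from_pretrained("fp16/tokenizer")
enc = tokenizer("a person speaking", return_tensors="np",
                padding="max_length", truncation=True, max_length=77)

session = ort.InferenceSession("fp16/t5_encoder.onnx",
                               providers=["CPUExecutionProvider"])
# Input names are assumptions; inspect session.get_inputs() to confirm.
feeds = {
    "input_ids": enc["input_ids"].astype(np.int64),
    "attention_mask": enc["attention_mask"].astype(np.int64),
}
text_emb = session.run(None, feeds)[0]  # expected shape (1, 77, 768) for T5-base
print(text_emb.shape)
```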
73
 
74
+ ## Model Specifications
75
 
76
  - **Audio Sample Rate**: 48kHz
77
  - **Audio Hop Length**: 1536 samples
78
  - **Vision Input Size**: 336×336 pixels
79
  - **Text Encoder**: T5-base (768-dim)
80
  - **Vision Encoder**: PE-Core-L14-336 (1024-dim)
81
+ - **DiT Parameters**: ~3 billion
82
+ - **ODE Solver**: Midpoint method (default 16 steps)
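
To make the single-step DiT concrete: `dit_single_step.onnx` evaluates the velocity field once, and the midpoint solver calls it twice per step. An illustrative loop, assuming integration over a unit time interval; the `dit(x, t, cond)` callable and the conditioning details are placeholders, not the actual interface of `onnx_inference.py`:

```python
def midpoint_solve(dit, x, cond, num_steps=16):
    """Explicit midpoint integration of x from t=0 to t=1.

    `dit` stands in for a wrapper around dit_single_step.onnx that returns
    the predicted velocity for (x, t, cond).
    """
    dt = 1.0 / num_steps
    t = 0.0
    for _ in range(num_steps):
        k1 = dit(x, t, cond)                             # velocity at step start
        k2 = dit(x + 0.5 * dt * k1, t + 0.5 * dt, cond)  # velocity at midpoint
        x = x + dt * k2
        t += dt
    return x
```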
83
 
84
+ ## Exporting Models
85
 
86
+ ### Export FP16 DiT (Recommended)
87
+ ```bash
88
+ python -m onnx_export.export_dit \
89
+ --output-dir ./my_models \
90
+ --model-id facebook/sam-audio-large \
91
+ --fp16 \
92
+ --device cuda
93
+ ```
94
 
95
+ ### Export Other Components
96
+ ```bash
97
+ python -m onnx_export.export_dacvae --output-dir ./my_models --model-id facebook/sam-audio-large
98
+ python -m onnx_export.export_t5 --output-dir ./my_models --model-id facebook/sam-audio-large
99
+ python -m onnx_export.export_vision --model facebook/sam-audio-large --output ./my_models
 
100
  ```
101
 
102
+ ## License
103
+
104
+ SAM-Audio is released under the [CC-BY-NC 4.0 license](https://creativecommons.org/licenses/by-nc/4.0/). See [original repository](https://huggingface.co/facebook/sam-audio-large) for full terms.
105
+
106
  ## Acknowledgments
107
 
108
  Original model by [Meta AI Research](https://github.com/facebookresearch/sam-audio).
 
fp16/dacvae_decoder.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:be13c0842945293ca0f5d77514783dc126007c5b490f00444c1c29f44f7035df
3
+ size 1103715
fp16/dacvae_decoder.onnx.data ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e1dba74364b1d721a4d7f921b18e58577c4f7375c3e2f82e46bf4898ecd61cba
3
+ size 320536576
fp16/dacvae_encoder.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:97a859c46e43a78dec7a974dcaaf433890e78bb2ac3b4dc5b498da414b86580b
3
+ size 866783
fp16/dacvae_encoder.onnx.data ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:68db13ee862cd7b5869176e5f080b2f88bc1a97711babd59d55ed93b26b047f6
3
+ size 110231552
fp16/dit_single_step.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:85f1260ac7773fd3efa231ace4ebdf063f57c4a52719b3db09f1871af5a5f59b
3
+ size 5698535
fp16/dit_single_step.onnx.data ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2d21fde5750792dd4f9d10f85ccfa397363ec9c3bd72b3c711bcbe5a6e8f48b6
3
+ size 5878317056
fp16/t5_encoder.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1776c8bce2eb4dbcd297bde499ea780253bc09677cfa8e18baef383859437e1c
3
+ size 1110394
fp16/t5_encoder.onnx.data ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9a91e63b28acfc81e85f659309625b54b5bf0c2e88161025d3a3b8580d4e20c8
3
+ size 438566912
fp16/tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,107 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<extra_id_0>",
4
+ "<extra_id_1>",
5
+ "<extra_id_2>",
6
+ "<extra_id_3>",
7
+ "<extra_id_4>",
8
+ "<extra_id_5>",
9
+ "<extra_id_6>",
10
+ "<extra_id_7>",
11
+ "<extra_id_8>",
12
+ "<extra_id_9>",
13
+ "<extra_id_10>",
14
+ "<extra_id_11>",
15
+ "<extra_id_12>",
16
+ "<extra_id_13>",
17
+ "<extra_id_14>",
18
+ "<extra_id_15>",
19
+ "<extra_id_16>",
20
+ "<extra_id_17>",
21
+ "<extra_id_18>",
22
+ "<extra_id_19>",
23
+ "<extra_id_20>",
24
+ "<extra_id_21>",
25
+ "<extra_id_22>",
26
+ "<extra_id_23>",
27
+ "<extra_id_24>",
28
+ "<extra_id_25>",
29
+ "<extra_id_26>",
30
+ "<extra_id_27>",
31
+ "<extra_id_28>",
32
+ "<extra_id_29>",
33
+ "<extra_id_30>",
34
+ "<extra_id_31>",
35
+ "<extra_id_32>",
36
+ "<extra_id_33>",
37
+ "<extra_id_34>",
38
+ "<extra_id_35>",
39
+ "<extra_id_36>",
40
+ "<extra_id_37>",
41
+ "<extra_id_38>",
42
+ "<extra_id_39>",
43
+ "<extra_id_40>",
44
+ "<extra_id_41>",
45
+ "<extra_id_42>",
46
+ "<extra_id_43>",
47
+ "<extra_id_44>",
48
+ "<extra_id_45>",
49
+ "<extra_id_46>",
50
+ "<extra_id_47>",
51
+ "<extra_id_48>",
52
+ "<extra_id_49>",
53
+ "<extra_id_50>",
54
+ "<extra_id_51>",
55
+ "<extra_id_52>",
56
+ "<extra_id_53>",
57
+ "<extra_id_54>",
58
+ "<extra_id_55>",
59
+ "<extra_id_56>",
60
+ "<extra_id_57>",
61
+ "<extra_id_58>",
62
+ "<extra_id_59>",
63
+ "<extra_id_60>",
64
+ "<extra_id_61>",
65
+ "<extra_id_62>",
66
+ "<extra_id_63>",
67
+ "<extra_id_64>",
68
+ "<extra_id_65>",
69
+ "<extra_id_66>",
70
+ "<extra_id_67>",
71
+ "<extra_id_68>",
72
+ "<extra_id_69>",
73
+ "<extra_id_70>",
74
+ "<extra_id_71>",
75
+ "<extra_id_72>",
76
+ "<extra_id_73>",
77
+ "<extra_id_74>",
78
+ "<extra_id_75>",
79
+ "<extra_id_76>",
80
+ "<extra_id_77>",
81
+ "<extra_id_78>",
82
+ "<extra_id_79>",
83
+ "<extra_id_80>",
84
+ "<extra_id_81>",
85
+ "<extra_id_82>",
86
+ "<extra_id_83>",
87
+ "<extra_id_84>",
88
+ "<extra_id_85>",
89
+ "<extra_id_86>",
90
+ "<extra_id_87>",
91
+ "<extra_id_88>",
92
+ "<extra_id_89>",
93
+ "<extra_id_90>",
94
+ "<extra_id_91>",
95
+ "<extra_id_92>",
96
+ "<extra_id_93>",
97
+ "<extra_id_94>",
98
+ "<extra_id_95>",
99
+ "<extra_id_96>",
100
+ "<extra_id_97>",
101
+ "<extra_id_98>",
102
+ "<extra_id_99>"
103
+ ],
104
+ "eos_token": "</s>",
105
+ "pad_token": "<pad>",
106
+ "unk_token": "<unk>"
107
+ }
fp16/tokenizer/spiece.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d60acb128cf7b7f2536e8f38a5b18a05535c9e14c7a355904270e15b0945ea86
3
+ size 791656
fp16/tokenizer/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
fp16/tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,939 @@
1
+ {
2
+ "add_prefix_space": null,
3
+ "added_tokens_decoder": {
4
+ "0": {
5
+ "content": "<pad>",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "1": {
13
+ "content": "</s>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "2": {
21
+ "content": "<unk>",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ },
28
+ "32000": {
29
+ "content": "<extra_id_99>",
30
+ "lstrip": false,
31
+ "normalized": false,
32
+ "rstrip": false,
33
+ "single_word": false,
34
+ "special": true
35
+ },
36
+ "32001": {
37
+ "content": "<extra_id_98>",
38
+ "lstrip": false,
39
+ "normalized": false,
40
+ "rstrip": false,
41
+ "single_word": false,
42
+ "special": true
43
+ },
44
+ "32002": {
45
+ "content": "<extra_id_97>",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false,
50
+ "special": true
51
+ },
52
+ "32003": {
53
+ "content": "<extra_id_96>",
54
+ "lstrip": false,
55
+ "normalized": false,
56
+ "rstrip": false,
57
+ "single_word": false,
58
+ "special": true
59
+ },
60
+ "32004": {
61
+ "content": "<extra_id_95>",
62
+ "lstrip": false,
63
+ "normalized": false,
64
+ "rstrip": false,
65
+ "single_word": false,
66
+ "special": true
67
+ },
68
+ "32005": {
69
+ "content": "<extra_id_94>",
70
+ "lstrip": false,
71
+ "normalized": false,
72
+ "rstrip": false,
73
+ "single_word": false,
74
+ "special": true
75
+ },
76
+ "32006": {
77
+ "content": "<extra_id_93>",
78
+ "lstrip": false,
79
+ "normalized": false,
80
+ "rstrip": false,
81
+ "single_word": false,
82
+ "special": true
83
+ },
84
+ "32007": {
85
+ "content": "<extra_id_92>",
86
+ "lstrip": false,
87
+ "normalized": false,
88
+ "rstrip": false,
89
+ "single_word": false,
90
+ "special": true
91
+ },
92
+ "32008": {
93
+ "content": "<extra_id_91>",
94
+ "lstrip": false,
95
+ "normalized": false,
96
+ "rstrip": false,
97
+ "single_word": false,
98
+ "special": true
99
+ },
100
+ "32009": {
101
+ "content": "<extra_id_90>",
102
+ "lstrip": false,
103
+ "normalized": false,
104
+ "rstrip": false,
105
+ "single_word": false,
106
+ "special": true
107
+ },
108
+ "32010": {
109
+ "content": "<extra_id_89>",
110
+ "lstrip": false,
111
+ "normalized": false,
112
+ "rstrip": false,
113
+ "single_word": false,
114
+ "special": true
115
+ },
116
+ "32011": {
117
+ "content": "<extra_id_88>",
118
+ "lstrip": false,
119
+ "normalized": false,
120
+ "rstrip": false,
121
+ "single_word": false,
122
+ "special": true
123
+ },
124
+ "32012": {
125
+ "content": "<extra_id_87>",
126
+ "lstrip": false,
127
+ "normalized": false,
128
+ "rstrip": false,
129
+ "single_word": false,
130
+ "special": true
131
+ },
132
+ "32013": {
133
+ "content": "<extra_id_86>",
134
+ "lstrip": false,
135
+ "normalized": false,
136
+ "rstrip": false,
137
+ "single_word": false,
138
+ "special": true
139
+ },
140
+ "32014": {
141
+ "content": "<extra_id_85>",
142
+ "lstrip": false,
143
+ "normalized": false,
144
+ "rstrip": false,
145
+ "single_word": false,
146
+ "special": true
147
+ },
148
+ "32015": {
149
+ "content": "<extra_id_84>",
150
+ "lstrip": false,
151
+ "normalized": false,
152
+ "rstrip": false,
153
+ "single_word": false,
154
+ "special": true
155
+ },
156
+ "32016": {
157
+ "content": "<extra_id_83>",
158
+ "lstrip": false,
159
+ "normalized": false,
160
+ "rstrip": false,
161
+ "single_word": false,
162
+ "special": true
163
+ },
164
+ "32017": {
165
+ "content": "<extra_id_82>",
166
+ "lstrip": false,
167
+ "normalized": false,
168
+ "rstrip": false,
169
+ "single_word": false,
170
+ "special": true
171
+ },
172
+ "32018": {
173
+ "content": "<extra_id_81>",
174
+ "lstrip": false,
175
+ "normalized": false,
176
+ "rstrip": false,
177
+ "single_word": false,
178
+ "special": true
179
+ },
180
+ "32019": {
181
+ "content": "<extra_id_80>",
182
+ "lstrip": false,
183
+ "normalized": false,
184
+ "rstrip": false,
185
+ "single_word": false,
186
+ "special": true
187
+ },
188
+ "32020": {
189
+ "content": "<extra_id_79>",
190
+ "lstrip": false,
191
+ "normalized": false,
192
+ "rstrip": false,
193
+ "single_word": false,
194
+ "special": true
195
+ },
196
+ "32021": {
197
+ "content": "<extra_id_78>",
198
+ "lstrip": false,
199
+ "normalized": false,
200
+ "rstrip": false,
201
+ "single_word": false,
202
+ "special": true
203
+ },
204
+ "32022": {
205
+ "content": "<extra_id_77>",
206
+ "lstrip": false,
207
+ "normalized": false,
208
+ "rstrip": false,
209
+ "single_word": false,
210
+ "special": true
211
+ },
212
+ "32023": {
213
+ "content": "<extra_id_76>",
214
+ "lstrip": false,
215
+ "normalized": false,
216
+ "rstrip": false,
217
+ "single_word": false,
218
+ "special": true
219
+ },
220
+ "32024": {
221
+ "content": "<extra_id_75>",
222
+ "lstrip": false,
223
+ "normalized": false,
224
+ "rstrip": false,
225
+ "single_word": false,
226
+ "special": true
227
+ },
228
+ "32025": {
229
+ "content": "<extra_id_74>",
230
+ "lstrip": false,
231
+ "normalized": false,
232
+ "rstrip": false,
233
+ "single_word": false,
234
+ "special": true
235
+ },
236
+ "32026": {
237
+ "content": "<extra_id_73>",
238
+ "lstrip": false,
239
+ "normalized": false,
240
+ "rstrip": false,
241
+ "single_word": false,
242
+ "special": true
243
+ },
244
+ "32027": {
245
+ "content": "<extra_id_72>",
246
+ "lstrip": false,
247
+ "normalized": false,
248
+ "rstrip": false,
249
+ "single_word": false,
250
+ "special": true
251
+ },
252
+ "32028": {
253
+ "content": "<extra_id_71>",
254
+ "lstrip": false,
255
+ "normalized": false,
256
+ "rstrip": false,
257
+ "single_word": false,
258
+ "special": true
259
+ },
260
+ "32029": {
261
+ "content": "<extra_id_70>",
262
+ "lstrip": false,
263
+ "normalized": false,
264
+ "rstrip": false,
265
+ "single_word": false,
266
+ "special": true
267
+ },
268
+ "32030": {
269
+ "content": "<extra_id_69>",
270
+ "lstrip": false,
271
+ "normalized": false,
272
+ "rstrip": false,
273
+ "single_word": false,
274
+ "special": true
275
+ },
276
+ "32031": {
277
+ "content": "<extra_id_68>",
278
+ "lstrip": false,
279
+ "normalized": false,
280
+ "rstrip": false,
281
+ "single_word": false,
282
+ "special": true
283
+ },
284
+ "32032": {
285
+ "content": "<extra_id_67>",
286
+ "lstrip": false,
287
+ "normalized": false,
288
+ "rstrip": false,
289
+ "single_word": false,
290
+ "special": true
291
+ },
292
+ "32033": {
293
+ "content": "<extra_id_66>",
294
+ "lstrip": false,
295
+ "normalized": false,
296
+ "rstrip": false,
297
+ "single_word": false,
298
+ "special": true
299
+ },
300
+ "32034": {
301
+ "content": "<extra_id_65>",
302
+ "lstrip": false,
303
+ "normalized": false,
304
+ "rstrip": false,
305
+ "single_word": false,
306
+ "special": true
307
+ },
308
+ "32035": {
309
+ "content": "<extra_id_64>",
310
+ "lstrip": false,
311
+ "normalized": false,
312
+ "rstrip": false,
313
+ "single_word": false,
314
+ "special": true
315
+ },
316
+ "32036": {
317
+ "content": "<extra_id_63>",
318
+ "lstrip": false,
319
+ "normalized": false,
320
+ "rstrip": false,
321
+ "single_word": false,
322
+ "special": true
323
+ },
324
+ "32037": {
325
+ "content": "<extra_id_62>",
326
+ "lstrip": false,
327
+ "normalized": false,
328
+ "rstrip": false,
329
+ "single_word": false,
330
+ "special": true
331
+ },
332
+ "32038": {
333
+ "content": "<extra_id_61>",
334
+ "lstrip": false,
335
+ "normalized": false,
336
+ "rstrip": false,
337
+ "single_word": false,
338
+ "special": true
339
+ },
340
+ "32039": {
341
+ "content": "<extra_id_60>",
342
+ "lstrip": false,
343
+ "normalized": false,
344
+ "rstrip": false,
345
+ "single_word": false,
346
+ "special": true
347
+ },
348
+ "32040": {
349
+ "content": "<extra_id_59>",
350
+ "lstrip": false,
351
+ "normalized": false,
352
+ "rstrip": false,
353
+ "single_word": false,
354
+ "special": true
355
+ },
356
+ "32041": {
357
+ "content": "<extra_id_58>",
358
+ "lstrip": false,
359
+ "normalized": false,
360
+ "rstrip": false,
361
+ "single_word": false,
362
+ "special": true
363
+ },
364
+ "32042": {
365
+ "content": "<extra_id_57>",
366
+ "lstrip": false,
367
+ "normalized": false,
368
+ "rstrip": false,
369
+ "single_word": false,
370
+ "special": true
371
+ },
372
+ "32043": {
373
+ "content": "<extra_id_56>",
374
+ "lstrip": false,
375
+ "normalized": false,
376
+ "rstrip": false,
377
+ "single_word": false,
378
+ "special": true
379
+ },
380
+ "32044": {
381
+ "content": "<extra_id_55>",
382
+ "lstrip": false,
383
+ "normalized": false,
384
+ "rstrip": false,
385
+ "single_word": false,
386
+ "special": true
387
+ },
388
+ "32045": {
389
+ "content": "<extra_id_54>",
390
+ "lstrip": false,
391
+ "normalized": false,
392
+ "rstrip": false,
393
+ "single_word": false,
394
+ "special": true
395
+ },
396
+ "32046": {
397
+ "content": "<extra_id_53>",
398
+ "lstrip": false,
399
+ "normalized": false,
400
+ "rstrip": false,
401
+ "single_word": false,
402
+ "special": true
403
+ },
404
+ "32047": {
405
+ "content": "<extra_id_52>",
406
+ "lstrip": false,
407
+ "normalized": false,
408
+ "rstrip": false,
409
+ "single_word": false,
410
+ "special": true
411
+ },
412
+ "32048": {
413
+ "content": "<extra_id_51>",
414
+ "lstrip": false,
415
+ "normalized": false,
416
+ "rstrip": false,
417
+ "single_word": false,
418
+ "special": true
419
+ },
420
+ "32049": {
421
+ "content": "<extra_id_50>",
422
+ "lstrip": false,
423
+ "normalized": false,
424
+ "rstrip": false,
425
+ "single_word": false,
426
+ "special": true
427
+ },
428
+ "32050": {
429
+ "content": "<extra_id_49>",
430
+ "lstrip": false,
431
+ "normalized": false,
432
+ "rstrip": false,
433
+ "single_word": false,
434
+ "special": true
435
+ },
436
+ "32051": {
437
+ "content": "<extra_id_48>",
438
+ "lstrip": false,
439
+ "normalized": false,
440
+ "rstrip": false,
441
+ "single_word": false,
442
+ "special": true
443
+ },
444
+ "32052": {
445
+ "content": "<extra_id_47>",
446
+ "lstrip": false,
447
+ "normalized": false,
448
+ "rstrip": false,
449
+ "single_word": false,
450
+ "special": true
451
+ },
452
+ "32053": {
453
+ "content": "<extra_id_46>",
454
+ "lstrip": false,
455
+ "normalized": false,
456
+ "rstrip": false,
457
+ "single_word": false,
458
+ "special": true
459
+ },
460
+ "32054": {
461
+ "content": "<extra_id_45>",
462
+ "lstrip": false,
463
+ "normalized": false,
464
+ "rstrip": false,
465
+ "single_word": false,
466
+ "special": true
467
+ },
468
+ "32055": {
469
+ "content": "<extra_id_44>",
470
+ "lstrip": false,
471
+ "normalized": false,
472
+ "rstrip": false,
473
+ "single_word": false,
474
+ "special": true
475
+ },
476
+ "32056": {
477
+ "content": "<extra_id_43>",
478
+ "lstrip": false,
479
+ "normalized": false,
480
+ "rstrip": false,
481
+ "single_word": false,
482
+ "special": true
483
+ },
484
+ "32057": {
485
+ "content": "<extra_id_42>",
486
+ "lstrip": false,
487
+ "normalized": false,
488
+ "rstrip": false,
489
+ "single_word": false,
490
+ "special": true
491
+ },
492
+ "32058": {
493
+ "content": "<extra_id_41>",
494
+ "lstrip": false,
495
+ "normalized": false,
496
+ "rstrip": false,
497
+ "single_word": false,
498
+ "special": true
499
+ },
500
+ "32059": {
501
+ "content": "<extra_id_40>",
502
+ "lstrip": false,
503
+ "normalized": false,
504
+ "rstrip": false,
505
+ "single_word": false,
506
+ "special": true
507
+ },
508
+ "32060": {
509
+ "content": "<extra_id_39>",
510
+ "lstrip": false,
511
+ "normalized": false,
512
+ "rstrip": false,
513
+ "single_word": false,
514
+ "special": true
515
+ },
516
+ "32061": {
517
+ "content": "<extra_id_38>",
518
+ "lstrip": false,
519
+ "normalized": false,
520
+ "rstrip": false,
521
+ "single_word": false,
522
+ "special": true
523
+ },
524
+ "32062": {
525
+ "content": "<extra_id_37>",
526
+ "lstrip": false,
527
+ "normalized": false,
528
+ "rstrip": false,
529
+ "single_word": false,
530
+ "special": true
531
+ },
532
+ "32063": {
533
+ "content": "<extra_id_36>",
534
+ "lstrip": false,
535
+ "normalized": false,
536
+ "rstrip": false,
537
+ "single_word": false,
538
+ "special": true
539
+ },
540
+ "32064": {
541
+ "content": "<extra_id_35>",
542
+ "lstrip": false,
543
+ "normalized": false,
544
+ "rstrip": false,
545
+ "single_word": false,
546
+ "special": true
547
+ },
548
+ "32065": {
549
+ "content": "<extra_id_34>",
550
+ "lstrip": false,
551
+ "normalized": false,
552
+ "rstrip": false,
553
+ "single_word": false,
554
+ "special": true
555
+ },
556
+ "32066": {
557
+ "content": "<extra_id_33>",
558
+ "lstrip": false,
559
+ "normalized": false,
560
+ "rstrip": false,
561
+ "single_word": false,
562
+ "special": true
563
+ },
564
+ "32067": {
565
+ "content": "<extra_id_32>",
566
+ "lstrip": false,
567
+ "normalized": false,
568
+ "rstrip": false,
569
+ "single_word": false,
570
+ "special": true
571
+ },
572
+ "32068": {
573
+ "content": "<extra_id_31>",
574
+ "lstrip": false,
575
+ "normalized": false,
576
+ "rstrip": false,
577
+ "single_word": false,
578
+ "special": true
579
+ },
580
+ "32069": {
581
+ "content": "<extra_id_30>",
582
+ "lstrip": false,
583
+ "normalized": false,
584
+ "rstrip": false,
585
+ "single_word": false,
586
+ "special": true
587
+ },
588
+ "32070": {
589
+ "content": "<extra_id_29>",
590
+ "lstrip": false,
591
+ "normalized": false,
592
+ "rstrip": false,
593
+ "single_word": false,
594
+ "special": true
595
+ },
596
+ "32071": {
597
+ "content": "<extra_id_28>",
598
+ "lstrip": false,
599
+ "normalized": false,
600
+ "rstrip": false,
601
+ "single_word": false,
602
+ "special": true
603
+ },
604
+ "32072": {
605
+ "content": "<extra_id_27>",
606
+ "lstrip": false,
607
+ "normalized": false,
608
+ "rstrip": false,
609
+ "single_word": false,
610
+ "special": true
611
+ },
612
+ "32073": {
613
+ "content": "<extra_id_26>",
614
+ "lstrip": false,
615
+ "normalized": false,
616
+ "rstrip": false,
617
+ "single_word": false,
618
+ "special": true
619
+ },
620
+ "32074": {
621
+ "content": "<extra_id_25>",
622
+ "lstrip": false,
623
+ "normalized": false,
624
+ "rstrip": false,
625
+ "single_word": false,
626
+ "special": true
627
+ },
628
+ "32075": {
629
+ "content": "<extra_id_24>",
630
+ "lstrip": false,
631
+ "normalized": false,
632
+ "rstrip": false,
633
+ "single_word": false,
634
+ "special": true
635
+ },
636
+ "32076": {
637
+ "content": "<extra_id_23>",
638
+ "lstrip": false,
639
+ "normalized": false,
640
+ "rstrip": false,
641
+ "single_word": false,
642
+ "special": true
643
+ },
644
+ "32077": {
645
+ "content": "<extra_id_22>",
646
+ "lstrip": false,
647
+ "normalized": false,
648
+ "rstrip": false,
649
+ "single_word": false,
650
+ "special": true
651
+ },
652
+ "32078": {
653
+ "content": "<extra_id_21>",
654
+ "lstrip": false,
655
+ "normalized": false,
656
+ "rstrip": false,
657
+ "single_word": false,
658
+ "special": true
659
+ },
660
+ "32079": {
661
+ "content": "<extra_id_20>",
662
+ "lstrip": false,
663
+ "normalized": false,
664
+ "rstrip": false,
665
+ "single_word": false,
666
+ "special": true
667
+ },
668
+ "32080": {
669
+ "content": "<extra_id_19>",
670
+ "lstrip": false,
671
+ "normalized": false,
672
+ "rstrip": false,
673
+ "single_word": false,
674
+ "special": true
675
+ },
676
+ "32081": {
677
+ "content": "<extra_id_18>",
678
+ "lstrip": false,
679
+ "normalized": false,
680
+ "rstrip": false,
681
+ "single_word": false,
682
+ "special": true
683
+ },
684
+ "32082": {
685
+ "content": "<extra_id_17>",
686
+ "lstrip": false,
687
+ "normalized": false,
688
+ "rstrip": false,
689
+ "single_word": false,
690
+ "special": true
691
+ },
692
+ "32083": {
693
+ "content": "<extra_id_16>",
694
+ "lstrip": false,
695
+ "normalized": false,
696
+ "rstrip": false,
697
+ "single_word": false,
698
+ "special": true
699
+ },
700
+ "32084": {
701
+ "content": "<extra_id_15>",
702
+ "lstrip": false,
703
+ "normalized": false,
704
+ "rstrip": false,
705
+ "single_word": false,
706
+ "special": true
707
+ },
708
+ "32085": {
709
+ "content": "<extra_id_14>",
710
+ "lstrip": false,
711
+ "normalized": false,
712
+ "rstrip": false,
713
+ "single_word": false,
714
+ "special": true
715
+ },
716
+ "32086": {
717
+ "content": "<extra_id_13>",
718
+ "lstrip": false,
719
+ "normalized": false,
720
+ "rstrip": false,
721
+ "single_word": false,
722
+ "special": true
723
+ },
724
+ "32087": {
725
+ "content": "<extra_id_12>",
726
+ "lstrip": false,
727
+ "normalized": false,
728
+ "rstrip": false,
729
+ "single_word": false,
730
+ "special": true
731
+ },
732
+ "32088": {
733
+ "content": "<extra_id_11>",
734
+ "lstrip": false,
735
+ "normalized": false,
736
+ "rstrip": false,
737
+ "single_word": false,
738
+ "special": true
739
+ },
740
+ "32089": {
741
+ "content": "<extra_id_10>",
742
+ "lstrip": false,
743
+ "normalized": false,
744
+ "rstrip": false,
745
+ "single_word": false,
746
+ "special": true
747
+ },
748
+ "32090": {
749
+ "content": "<extra_id_9>",
750
+ "lstrip": false,
751
+ "normalized": false,
752
+ "rstrip": false,
753
+ "single_word": false,
754
+ "special": true
755
+ },
756
+ "32091": {
757
+ "content": "<extra_id_8>",
758
+ "lstrip": false,
759
+ "normalized": false,
760
+ "rstrip": false,
761
+ "single_word": false,
762
+ "special": true
763
+ },
764
+ "32092": {
765
+ "content": "<extra_id_7>",
766
+ "lstrip": false,
767
+ "normalized": false,
768
+ "rstrip": false,
769
+ "single_word": false,
770
+ "special": true
771
+ },
772
+ "32093": {
773
+ "content": "<extra_id_6>",
774
+ "lstrip": false,
775
+ "normalized": false,
776
+ "rstrip": false,
777
+ "single_word": false,
778
+ "special": true
779
+ },
780
+ "32094": {
781
+ "content": "<extra_id_5>",
782
+ "lstrip": false,
783
+ "normalized": false,
784
+ "rstrip": false,
785
+ "single_word": false,
786
+ "special": true
787
+ },
788
+ "32095": {
789
+ "content": "<extra_id_4>",
790
+ "lstrip": false,
791
+ "normalized": false,
792
+ "rstrip": false,
793
+ "single_word": false,
794
+ "special": true
795
+ },
796
+ "32096": {
797
+ "content": "<extra_id_3>",
798
+ "lstrip": false,
799
+ "normalized": false,
800
+ "rstrip": false,
801
+ "single_word": false,
802
+ "special": true
803
+ },
804
+ "32097": {
805
+ "content": "<extra_id_2>",
806
+ "lstrip": false,
807
+ "normalized": false,
808
+ "rstrip": false,
809
+ "single_word": false,
810
+ "special": true
811
+ },
812
+ "32098": {
813
+ "content": "<extra_id_1>",
814
+ "lstrip": false,
815
+ "normalized": false,
816
+ "rstrip": false,
817
+ "single_word": false,
818
+ "special": true
819
+ },
820
+ "32099": {
821
+ "content": "<extra_id_0>",
822
+ "lstrip": false,
823
+ "normalized": false,
824
+ "rstrip": false,
825
+ "single_word": false,
826
+ "special": true
827
+ }
828
+ },
829
+ "additional_special_tokens": [
830
+ "<extra_id_0>",
831
+ "<extra_id_1>",
832
+ "<extra_id_2>",
833
+ "<extra_id_3>",
834
+ "<extra_id_4>",
835
+ "<extra_id_5>",
836
+ "<extra_id_6>",
837
+ "<extra_id_7>",
838
+ "<extra_id_8>",
839
+ "<extra_id_9>",
840
+ "<extra_id_10>",
841
+ "<extra_id_11>",
842
+ "<extra_id_12>",
843
+ "<extra_id_13>",
844
+ "<extra_id_14>",
845
+ "<extra_id_15>",
846
+ "<extra_id_16>",
847
+ "<extra_id_17>",
848
+ "<extra_id_18>",
849
+ "<extra_id_19>",
850
+ "<extra_id_20>",
851
+ "<extra_id_21>",
852
+ "<extra_id_22>",
853
+ "<extra_id_23>",
854
+ "<extra_id_24>",
855
+ "<extra_id_25>",
856
+ "<extra_id_26>",
857
+ "<extra_id_27>",
858
+ "<extra_id_28>",
859
+ "<extra_id_29>",
860
+ "<extra_id_30>",
861
+ "<extra_id_31>",
862
+ "<extra_id_32>",
863
+ "<extra_id_33>",
864
+ "<extra_id_34>",
865
+ "<extra_id_35>",
866
+ "<extra_id_36>",
867
+ "<extra_id_37>",
868
+ "<extra_id_38>",
869
+ "<extra_id_39>",
870
+ "<extra_id_40>",
871
+ "<extra_id_41>",
872
+ "<extra_id_42>",
873
+ "<extra_id_43>",
874
+ "<extra_id_44>",
875
+ "<extra_id_45>",
876
+ "<extra_id_46>",
877
+ "<extra_id_47>",
878
+ "<extra_id_48>",
879
+ "<extra_id_49>",
880
+ "<extra_id_50>",
881
+ "<extra_id_51>",
882
+ "<extra_id_52>",
883
+ "<extra_id_53>",
884
+ "<extra_id_54>",
885
+ "<extra_id_55>",
886
+ "<extra_id_56>",
887
+ "<extra_id_57>",
888
+ "<extra_id_58>",
889
+ "<extra_id_59>",
890
+ "<extra_id_60>",
891
+ "<extra_id_61>",
892
+ "<extra_id_62>",
893
+ "<extra_id_63>",
894
+ "<extra_id_64>",
895
+ "<extra_id_65>",
896
+ "<extra_id_66>",
897
+ "<extra_id_67>",
898
+ "<extra_id_68>",
899
+ "<extra_id_69>",
900
+ "<extra_id_70>",
901
+ "<extra_id_71>",
902
+ "<extra_id_72>",
903
+ "<extra_id_73>",
904
+ "<extra_id_74>",
905
+ "<extra_id_75>",
906
+ "<extra_id_76>",
907
+ "<extra_id_77>",
908
+ "<extra_id_78>",
909
+ "<extra_id_79>",
910
+ "<extra_id_80>",
911
+ "<extra_id_81>",
912
+ "<extra_id_82>",
913
+ "<extra_id_83>",
914
+ "<extra_id_84>",
915
+ "<extra_id_85>",
916
+ "<extra_id_86>",
917
+ "<extra_id_87>",
918
+ "<extra_id_88>",
919
+ "<extra_id_89>",
920
+ "<extra_id_90>",
921
+ "<extra_id_91>",
922
+ "<extra_id_92>",
923
+ "<extra_id_93>",
924
+ "<extra_id_94>",
925
+ "<extra_id_95>",
926
+ "<extra_id_96>",
927
+ "<extra_id_97>",
928
+ "<extra_id_98>",
929
+ "<extra_id_99>"
930
+ ],
931
+ "clean_up_tokenization_spaces": false,
932
+ "eos_token": "</s>",
933
+ "extra_ids": 100,
934
+ "extra_special_tokens": {},
935
+ "model_max_length": 1000000000000000019884624838656,
936
+ "pad_token": "<pad>",
937
+ "tokenizer_class": "T5Tokenizer",
938
+ "unk_token": "<unk>"
939
+ }
fp16/tokenizer_config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "model_name": "google-t5/t5-base",
3
+ "max_length": 77,
4
+ "vocab_size": 32100,
5
+ "pad_token_id": 0,
6
+ "eos_token_id": 1
7
+ }
fp16/vision_encoder.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ea8535c5bc55fbcc62fba78774e34c4ba59dc721275194efda73b26aee2eead9
3
+ size 3098779
fp16/vision_encoder.onnx.data ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e8b15b05f71b646b454bf77037c82cd1335917f39b2f847baaf5cb4d20880ee9
3
+ size 1268842496
fp32/dacvae_decoder.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:be13c0842945293ca0f5d77514783dc126007c5b490f00444c1c29f44f7035df
3
+ size 1103715
fp32/dacvae_decoder.onnx.data ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e1dba74364b1d721a4d7f921b18e58577c4f7375c3e2f82e46bf4898ecd61cba
3
+ size 320536576
fp32/dacvae_encoder.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:97a859c46e43a78dec7a974dcaaf433890e78bb2ac3b4dc5b498da414b86580b
3
+ size 866783
fp32/dacvae_encoder.onnx.data ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:68db13ee862cd7b5869176e5f080b2f88bc1a97711babd59d55ed93b26b047f6
3
+ size 110231552
fp32/dit_single_step.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:03568412927cffbef44e99fc666b7fc0348807ad7a4852fe8ec8c8f272846f83
3
+ size 5115331
fp32/dit_single_step.onnx.data ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cc5380da46b3962029b0d1c604f7ceff527ebe50e13190b7f674b91f29cf0072
3
+ size 11755978752
fp32/t5_encoder.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1776c8bce2eb4dbcd297bde499ea780253bc09677cfa8e18baef383859437e1c
3
+ size 1110394
fp32/t5_encoder.onnx.data ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9a91e63b28acfc81e85f659309625b54b5bf0c2e88161025d3a3b8580d4e20c8
3
+ size 438566912
fp32/tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,107 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<extra_id_0>",
4
+ "<extra_id_1>",
5
+ "<extra_id_2>",
6
+ "<extra_id_3>",
7
+ "<extra_id_4>",
8
+ "<extra_id_5>",
9
+ "<extra_id_6>",
10
+ "<extra_id_7>",
11
+ "<extra_id_8>",
12
+ "<extra_id_9>",
13
+ "<extra_id_10>",
14
+ "<extra_id_11>",
15
+ "<extra_id_12>",
16
+ "<extra_id_13>",
17
+ "<extra_id_14>",
18
+ "<extra_id_15>",
19
+ "<extra_id_16>",
20
+ "<extra_id_17>",
21
+ "<extra_id_18>",
22
+ "<extra_id_19>",
23
+ "<extra_id_20>",
24
+ "<extra_id_21>",
25
+ "<extra_id_22>",
26
+ "<extra_id_23>",
27
+ "<extra_id_24>",
28
+ "<extra_id_25>",
29
+ "<extra_id_26>",
30
+ "<extra_id_27>",
31
+ "<extra_id_28>",
32
+ "<extra_id_29>",
33
+ "<extra_id_30>",
34
+ "<extra_id_31>",
35
+ "<extra_id_32>",
36
+ "<extra_id_33>",
37
+ "<extra_id_34>",
38
+ "<extra_id_35>",
39
+ "<extra_id_36>",
40
+ "<extra_id_37>",
41
+ "<extra_id_38>",
42
+ "<extra_id_39>",
43
+ "<extra_id_40>",
44
+ "<extra_id_41>",
45
+ "<extra_id_42>",
46
+ "<extra_id_43>",
47
+ "<extra_id_44>",
48
+ "<extra_id_45>",
49
+ "<extra_id_46>",
50
+ "<extra_id_47>",
51
+ "<extra_id_48>",
52
+ "<extra_id_49>",
53
+ "<extra_id_50>",
54
+ "<extra_id_51>",
55
+ "<extra_id_52>",
56
+ "<extra_id_53>",
57
+ "<extra_id_54>",
58
+ "<extra_id_55>",
59
+ "<extra_id_56>",
60
+ "<extra_id_57>",
61
+ "<extra_id_58>",
62
+ "<extra_id_59>",
63
+ "<extra_id_60>",
64
+ "<extra_id_61>",
65
+ "<extra_id_62>",
66
+ "<extra_id_63>",
67
+ "<extra_id_64>",
68
+ "<extra_id_65>",
69
+ "<extra_id_66>",
70
+ "<extra_id_67>",
71
+ "<extra_id_68>",
72
+ "<extra_id_69>",
73
+ "<extra_id_70>",
74
+ "<extra_id_71>",
75
+ "<extra_id_72>",
76
+ "<extra_id_73>",
77
+ "<extra_id_74>",
78
+ "<extra_id_75>",
79
+ "<extra_id_76>",
80
+ "<extra_id_77>",
81
+ "<extra_id_78>",
82
+ "<extra_id_79>",
83
+ "<extra_id_80>",
84
+ "<extra_id_81>",
85
+ "<extra_id_82>",
86
+ "<extra_id_83>",
87
+ "<extra_id_84>",
88
+ "<extra_id_85>",
89
+ "<extra_id_86>",
90
+ "<extra_id_87>",
91
+ "<extra_id_88>",
92
+ "<extra_id_89>",
93
+ "<extra_id_90>",
94
+ "<extra_id_91>",
95
+ "<extra_id_92>",
96
+ "<extra_id_93>",
97
+ "<extra_id_94>",
98
+ "<extra_id_95>",
99
+ "<extra_id_96>",
100
+ "<extra_id_97>",
101
+ "<extra_id_98>",
102
+ "<extra_id_99>"
103
+ ],
104
+ "eos_token": "</s>",
105
+ "pad_token": "<pad>",
106
+ "unk_token": "<unk>"
107
+ }
fp32/tokenizer/spiece.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d60acb128cf7b7f2536e8f38a5b18a05535c9e14c7a355904270e15b0945ea86
3
+ size 791656
fp32/tokenizer/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
fp32/tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,939 @@
1
+ {
2
+ "add_prefix_space": null,
3
+ "added_tokens_decoder": {
4
+ "0": {
5
+ "content": "<pad>",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "1": {
13
+ "content": "</s>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "2": {
21
+ "content": "<unk>",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ },
28
+ "32000": {
29
+ "content": "<extra_id_99>",
30
+ "lstrip": false,
31
+ "normalized": false,
32
+ "rstrip": false,
33
+ "single_word": false,
34
+ "special": true
35
+ },
36
+ "32001": {
37
+ "content": "<extra_id_98>",
38
+ "lstrip": false,
39
+ "normalized": false,
40
+ "rstrip": false,
41
+ "single_word": false,
42
+ "special": true
43
+ },
44
+ "32002": {
45
+ "content": "<extra_id_97>",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false,
50
+ "special": true
51
+ },
52
+ "32003": {
53
+ "content": "<extra_id_96>",
54
+ "lstrip": false,
55
+ "normalized": false,
56
+ "rstrip": false,
57
+ "single_word": false,
58
+ "special": true
59
+ },
60
+ "32004": {
61
+ "content": "<extra_id_95>",
62
+ "lstrip": false,
63
+ "normalized": false,
64
+ "rstrip": false,
65
+ "single_word": false,
66
+ "special": true
67
+ },
68
+ "32005": {
69
+ "content": "<extra_id_94>",
70
+ "lstrip": false,
71
+ "normalized": false,
72
+ "rstrip": false,
73
+ "single_word": false,
74
+ "special": true
75
+ },
76
+ "32006": {
77
+ "content": "<extra_id_93>",
78
+ "lstrip": false,
79
+ "normalized": false,
80
+ "rstrip": false,
81
+ "single_word": false,
82
+ "special": true
83
+ },
84
+ "32007": {
85
+ "content": "<extra_id_92>",
86
+ "lstrip": false,
87
+ "normalized": false,
88
+ "rstrip": false,
89
+ "single_word": false,
90
+ "special": true
91
+ },
92
+ "32008": {
93
+ "content": "<extra_id_91>",
94
+ "lstrip": false,
95
+ "normalized": false,
96
+ "rstrip": false,
97
+ "single_word": false,
98
+ "special": true
99
+ },
100
+ "32009": {
101
+ "content": "<extra_id_90>",
102
+ "lstrip": false,
103
+ "normalized": false,
104
+ "rstrip": false,
105
+ "single_word": false,
106
+ "special": true
107
+ },
108
+ "32010": {
109
+ "content": "<extra_id_89>",
110
+ "lstrip": false,
111
+ "normalized": false,
112
+ "rstrip": false,
113
+ "single_word": false,
114
+ "special": true
115
+ },
116
+ "32011": {
117
+ "content": "<extra_id_88>",
118
+ "lstrip": false,
119
+ "normalized": false,
120
+ "rstrip": false,
121
+ "single_word": false,
122
+ "special": true
123
+ },
124
+ "32012": {
125
+ "content": "<extra_id_87>",
126
+ "lstrip": false,
127
+ "normalized": false,
128
+ "rstrip": false,
129
+ "single_word": false,
130
+ "special": true
131
+ },
132
+ "32013": {
133
+ "content": "<extra_id_86>",
134
+ "lstrip": false,
135
+ "normalized": false,
136
+ "rstrip": false,
137
+ "single_word": false,
138
+ "special": true
139
+ },
140
+ "32014": {
141
+ "content": "<extra_id_85>",
142
+ "lstrip": false,
143
+ "normalized": false,
144
+ "rstrip": false,
145
+ "single_word": false,
146
+ "special": true
147
+ },
148
+ "32015": {
149
+ "content": "<extra_id_84>",
150
+ "lstrip": false,
151
+ "normalized": false,
152
+ "rstrip": false,
153
+ "single_word": false,
154
+ "special": true
155
+ },
156
+ "32016": {
157
+ "content": "<extra_id_83>",
158
+ "lstrip": false,
159
+ "normalized": false,
160
+ "rstrip": false,
161
+ "single_word": false,
162
+ "special": true
163
+ },
164
+ "32017": {
165
+ "content": "<extra_id_82>",
166
+ "lstrip": false,
167
+ "normalized": false,
168
+ "rstrip": false,
169
+ "single_word": false,
170
+ "special": true
171
+ },
172
+ "32018": {
173
+ "content": "<extra_id_81>",
174
+ "lstrip": false,
175
+ "normalized": false,
176
+ "rstrip": false,
177
+ "single_word": false,
178
+ "special": true
179
+ },
180
+ "32019": {
181
+ "content": "<extra_id_80>",
182
+ "lstrip": false,
183
+ "normalized": false,
184
+ "rstrip": false,
185
+ "single_word": false,
186
+ "special": true
187
+ },
188
+ "32020": {
189
+ "content": "<extra_id_79>",
190
+ "lstrip": false,
191
+ "normalized": false,
192
+ "rstrip": false,
193
+ "single_word": false,
194
+ "special": true
195
+ },
196
+ "32021": {
197
+ "content": "<extra_id_78>",
198
+ "lstrip": false,
199
+ "normalized": false,
200
+ "rstrip": false,
201
+ "single_word": false,
202
+ "special": true
203
+ },
204
+ "32022": {
205
+ "content": "<extra_id_77>",
206
+ "lstrip": false,
207
+ "normalized": false,
208
+ "rstrip": false,
209
+ "single_word": false,
210
+ "special": true
211
+ },
212
+ "32023": {
213
+ "content": "<extra_id_76>",
214
+ "lstrip": false,
215
+ "normalized": false,
216
+ "rstrip": false,
217
+ "single_word": false,
218
+ "special": true
219
+ },
220
+ "32024": {
221
+ "content": "<extra_id_75>",
222
+ "lstrip": false,
223
+ "normalized": false,
224
+ "rstrip": false,
225
+ "single_word": false,
226
+ "special": true
227
+ },
228
+ "32025": {
229
+ "content": "<extra_id_74>",
230
+ "lstrip": false,
231
+ "normalized": false,
232
+ "rstrip": false,
233
+ "single_word": false,
234
+ "special": true
235
+ },
236
+ "32026": {
237
+ "content": "<extra_id_73>",
238
+ "lstrip": false,
239
+ "normalized": false,
240
+ "rstrip": false,
241
+ "single_word": false,
242
+ "special": true
243
+ },
244
+ "32027": {
245
+ "content": "<extra_id_72>",
246
+ "lstrip": false,
247
+ "normalized": false,
248
+ "rstrip": false,
249
+ "single_word": false,
250
+ "special": true
251
+ },
252
+ "32028": {
253
+ "content": "<extra_id_71>",
254
+ "lstrip": false,
255
+ "normalized": false,
256
+ "rstrip": false,
257
+ "single_word": false,
258
+ "special": true
259
+ },
260
+ "32029": {
261
+ "content": "<extra_id_70>",
262
+ "lstrip": false,
263
+ "normalized": false,
264
+ "rstrip": false,
265
+ "single_word": false,
266
+ "special": true
267
+ },
268
+ "32030": {
269
+ "content": "<extra_id_69>",
270
+ "lstrip": false,
271
+ "normalized": false,
272
+ "rstrip": false,
273
+ "single_word": false,
274
+ "special": true
275
+ },
276
+ "32031": {
277
+ "content": "<extra_id_68>",
278
+ "lstrip": false,
279
+ "normalized": false,
280
+ "rstrip": false,
281
+ "single_word": false,
282
+ "special": true
283
+ },
284
+ "32032": {
285
+ "content": "<extra_id_67>",
286
+ "lstrip": false,
287
+ "normalized": false,
288
+ "rstrip": false,
289
+ "single_word": false,
290
+ "special": true
291
+ },
292
+ "32033": {
293
+ "content": "<extra_id_66>",
294
+ "lstrip": false,
295
+ "normalized": false,
296
+ "rstrip": false,
297
+ "single_word": false,
298
+ "special": true
299
+ },
300
+ "32034": {
301
+ "content": "<extra_id_65>",
302
+ "lstrip": false,
303
+ "normalized": false,
304
+ "rstrip": false,
305
+ "single_word": false,
306
+ "special": true
307
+ },
308
+ "32035": {
309
+ "content": "<extra_id_64>",
310
+ "lstrip": false,
311
+ "normalized": false,
312
+ "rstrip": false,
313
+ "single_word": false,
314
+ "special": true
315
+ },
316
+ "32036": {
317
+ "content": "<extra_id_63>",
318
+ "lstrip": false,
319
+ "normalized": false,
320
+ "rstrip": false,
321
+ "single_word": false,
322
+ "special": true
323
+ },
324
+ "32037": {
325
+ "content": "<extra_id_62>",
326
+ "lstrip": false,
327
+ "normalized": false,
328
+ "rstrip": false,
329
+ "single_word": false,
330
+ "special": true
331
+ },
332
+ "32038": {
333
+ "content": "<extra_id_61>",
334
+ "lstrip": false,
335
+ "normalized": false,
336
+ "rstrip": false,
337
+ "single_word": false,
338
+ "special": true
339
+ },
340
+ "32039": {
341
+ "content": "<extra_id_60>",
342
+ "lstrip": false,
343
+ "normalized": false,
344
+ "rstrip": false,
345
+ "single_word": false,
346
+ "special": true
347
+ },
348
+ "32040": {
349
+ "content": "<extra_id_59>",
350
+ "lstrip": false,
351
+ "normalized": false,
352
+ "rstrip": false,
353
+ "single_word": false,
354
+ "special": true
355
+ },
356
+ "32041": {
357
+ "content": "<extra_id_58>",
358
+ "lstrip": false,
359
+ "normalized": false,
360
+ "rstrip": false,
361
+ "single_word": false,
362
+ "special": true
363
+ },
364
+ "32042": {
365
+ "content": "<extra_id_57>",
366
+ "lstrip": false,
367
+ "normalized": false,
368
+ "rstrip": false,
369
+ "single_word": false,
370
+ "special": true
371
+ },
372
+ "32043": {
373
+ "content": "<extra_id_56>",
374
+ "lstrip": false,
375
+ "normalized": false,
376
+ "rstrip": false,
377
+ "single_word": false,
378
+ "special": true
379
+ },
380
+ "32044": {
381
+ "content": "<extra_id_55>",
382
+ "lstrip": false,
383
+ "normalized": false,
384
+ "rstrip": false,
385
+ "single_word": false,
386
+ "special": true
387
+ },
388
+ "32045": {
389
+ "content": "<extra_id_54>",
390
+ "lstrip": false,
391
+ "normalized": false,
392
+ "rstrip": false,
393
+ "single_word": false,
394
+ "special": true
395
+ },
396
+ "32046": {
397
+ "content": "<extra_id_53>",
398
+ "lstrip": false,
399
+ "normalized": false,
400
+ "rstrip": false,
401
+ "single_word": false,
402
+ "special": true
403
+ },
404
+ "32047": {
405
+ "content": "<extra_id_52>",
406
+ "lstrip": false,
407
+ "normalized": false,
408
+ "rstrip": false,
409
+ "single_word": false,
410
+ "special": true
411
+ },
412
+ "32048": {
413
+ "content": "<extra_id_51>",
414
+ "lstrip": false,
415
+ "normalized": false,
416
+ "rstrip": false,
417
+ "single_word": false,
418
+ "special": true
419
+ },
420
+ "32049": {
421
+ "content": "<extra_id_50>",
422
+ "lstrip": false,
423
+ "normalized": false,
424
+ "rstrip": false,
425
+ "single_word": false,
426
+ "special": true
427
+ },
428
+ "32050": {
429
+ "content": "<extra_id_49>",
430
+ "lstrip": false,
431
+ "normalized": false,
432
+ "rstrip": false,
433
+ "single_word": false,
434
+ "special": true
435
+ },
436
+ "32051": {
437
+ "content": "<extra_id_48>",
438
+ "lstrip": false,
439
+ "normalized": false,
440
+ "rstrip": false,
441
+ "single_word": false,
442
+ "special": true
443
+ },
444
+ "32052": {
445
+ "content": "<extra_id_47>",
446
+ "lstrip": false,
447
+ "normalized": false,
448
+ "rstrip": false,
449
+ "single_word": false,
450
+ "special": true
451
+ },
452
+ "32053": {
453
+ "content": "<extra_id_46>",
454
+ "lstrip": false,
455
+ "normalized": false,
456
+ "rstrip": false,
457
+ "single_word": false,
458
+ "special": true
459
+ },
460
+ "32054": {
461
+ "content": "<extra_id_45>",
462
+ "lstrip": false,
463
+ "normalized": false,
464
+ "rstrip": false,
465
+ "single_word": false,
466
+ "special": true
467
+ },
468
+ "32055": {
469
+ "content": "<extra_id_44>",
470
+ "lstrip": false,
471
+ "normalized": false,
472
+ "rstrip": false,
473
+ "single_word": false,
474
+ "special": true
475
+ },
476
+ "32056": {
477
+ "content": "<extra_id_43>",
478
+ "lstrip": false,
479
+ "normalized": false,
480
+ "rstrip": false,
481
+ "single_word": false,
482
+ "special": true
483
+ },
484
+ "32057": {
485
+ "content": "<extra_id_42>",
486
+ "lstrip": false,
487
+ "normalized": false,
488
+ "rstrip": false,
489
+ "single_word": false,
490
+ "special": true
491
+ },
492
+ "32058": {
493
+ "content": "<extra_id_41>",
494
+ "lstrip": false,
495
+ "normalized": false,
496
+ "rstrip": false,
497
+ "single_word": false,
498
+ "special": true
499
+ },
500
+ "32059": {
501
+ "content": "<extra_id_40>",
502
+ "lstrip": false,
503
+ "normalized": false,
504
+ "rstrip": false,
505
+ "single_word": false,
506
+ "special": true
507
+ },
508
+ "32060": {
509
+ "content": "<extra_id_39>",
510
+ "lstrip": false,
511
+ "normalized": false,
512
+ "rstrip": false,
513
+ "single_word": false,
514
+ "special": true
515
+ },
516
+ "32061": {
517
+ "content": "<extra_id_38>",
518
+ "lstrip": false,
519
+ "normalized": false,
520
+ "rstrip": false,
521
+ "single_word": false,
522
+ "special": true
523
+ },
524
+ "32062": {
525
+ "content": "<extra_id_37>",
526
+ "lstrip": false,
527
+ "normalized": false,
528
+ "rstrip": false,
529
+ "single_word": false,
530
+ "special": true
531
+ },
532
+ "32063": {
533
+ "content": "<extra_id_36>",
534
+ "lstrip": false,
535
+ "normalized": false,
536
+ "rstrip": false,
537
+ "single_word": false,
538
+ "special": true
539
+ },
540
+ "32064": {
541
+ "content": "<extra_id_35>",
542
+ "lstrip": false,
543
+ "normalized": false,
544
+ "rstrip": false,
545
+ "single_word": false,
546
+ "special": true
547
+ },
548
+ "32065": {
549
+ "content": "<extra_id_34>",
550
+ "lstrip": false,
551
+ "normalized": false,
552
+ "rstrip": false,
553
+ "single_word": false,
554
+ "special": true
555
+ },
556
+ "32066": {
557
+ "content": "<extra_id_33>",
558
+ "lstrip": false,
559
+ "normalized": false,
560
+ "rstrip": false,
561
+ "single_word": false,
562
+ "special": true
563
+ },
564
+ "32067": {
565
+ "content": "<extra_id_32>",
566
+ "lstrip": false,
567
+ "normalized": false,
568
+ "rstrip": false,
569
+ "single_word": false,
570
+ "special": true
571
+ },
572
+ "32068": {
573
+ "content": "<extra_id_31>",
574
+ "lstrip": false,
575
+ "normalized": false,
576
+ "rstrip": false,
577
+ "single_word": false,
578
+ "special": true
579
+ },
580
+ "32069": {
581
+ "content": "<extra_id_30>",
582
+ "lstrip": false,
583
+ "normalized": false,
584
+ "rstrip": false,
585
+ "single_word": false,
586
+ "special": true
587
+ },
588
+ "32070": {
589
+ "content": "<extra_id_29>",
590
+ "lstrip": false,
591
+ "normalized": false,
592
+ "rstrip": false,
593
+ "single_word": false,
594
+ "special": true
595
+ },
596
+ "32071": {
597
+ "content": "<extra_id_28>",
598
+ "lstrip": false,
599
+ "normalized": false,
600
+ "rstrip": false,
601
+ "single_word": false,
602
+ "special": true
603
+ },
604
+ "32072": {
605
+ "content": "<extra_id_27>",
606
+ "lstrip": false,
607
+ "normalized": false,
608
+ "rstrip": false,
609
+ "single_word": false,
610
+ "special": true
611
+ },
612
+ "32073": {
613
+ "content": "<extra_id_26>",
614
+ "lstrip": false,
615
+ "normalized": false,
616
+ "rstrip": false,
617
+ "single_word": false,
618
+ "special": true
619
+ },
620
+ "32074": {
621
+ "content": "<extra_id_25>",
622
+ "lstrip": false,
623
+ "normalized": false,
624
+ "rstrip": false,
625
+ "single_word": false,
626
+ "special": true
627
+ },
628
+ "32075": {
629
+ "content": "<extra_id_24>",
630
+ "lstrip": false,
631
+ "normalized": false,
632
+ "rstrip": false,
633
+ "single_word": false,
634
+ "special": true
635
+ },
636
+ "32076": {
637
+ "content": "<extra_id_23>",
638
+ "lstrip": false,
639
+ "normalized": false,
640
+ "rstrip": false,
641
+ "single_word": false,
642
+ "special": true
643
+ },
644
+ "32077": {
645
+ "content": "<extra_id_22>",
646
+ "lstrip": false,
647
+ "normalized": false,
648
+ "rstrip": false,
649
+ "single_word": false,
650
+ "special": true
651
+ },
652
+ "32078": {
653
+ "content": "<extra_id_21>",
654
+ "lstrip": false,
655
+ "normalized": false,
656
+ "rstrip": false,
657
+ "single_word": false,
658
+ "special": true
659
+ },
660
+ "32079": {
661
+ "content": "<extra_id_20>",
662
+ "lstrip": false,
663
+ "normalized": false,
664
+ "rstrip": false,
665
+ "single_word": false,
666
+ "special": true
667
+ },
668
+ "32080": {
669
+ "content": "<extra_id_19>",
670
+ "lstrip": false,
671
+ "normalized": false,
672
+ "rstrip": false,
673
+ "single_word": false,
674
+ "special": true
675
+ },
676
+ "32081": {
677
+ "content": "<extra_id_18>",
678
+ "lstrip": false,
679
+ "normalized": false,
680
+ "rstrip": false,
681
+ "single_word": false,
682
+ "special": true
683
+ },
684
+ "32082": {
685
+ "content": "<extra_id_17>",
686
+ "lstrip": false,
687
+ "normalized": false,
688
+ "rstrip": false,
689
+ "single_word": false,
690
+ "special": true
691
+ },
692
+ "32083": {
693
+ "content": "<extra_id_16>",
694
+ "lstrip": false,
695
+ "normalized": false,
696
+ "rstrip": false,
697
+ "single_word": false,
698
+ "special": true
699
+ },
700
+ "32084": {
701
+ "content": "<extra_id_15>",
702
+ "lstrip": false,
703
+ "normalized": false,
704
+ "rstrip": false,
705
+ "single_word": false,
706
+ "special": true
707
+ },
708
+ "32085": {
709
+ "content": "<extra_id_14>",
710
+ "lstrip": false,
711
+ "normalized": false,
712
+ "rstrip": false,
713
+ "single_word": false,
714
+ "special": true
715
+ },
716
+ "32086": {
717
+ "content": "<extra_id_13>",
718
+ "lstrip": false,
719
+ "normalized": false,
720
+ "rstrip": false,
721
+ "single_word": false,
722
+ "special": true
723
+ },
724
+ "32087": {
725
+ "content": "<extra_id_12>",
726
+ "lstrip": false,
727
+ "normalized": false,
728
+ "rstrip": false,
729
+ "single_word": false,
730
+ "special": true
731
+ },
732
+ "32088": {
733
+ "content": "<extra_id_11>",
734
+ "lstrip": false,
735
+ "normalized": false,
736
+ "rstrip": false,
737
+ "single_word": false,
738
+ "special": true
739
+ },
740
+ "32089": {
741
+ "content": "<extra_id_10>",
742
+ "lstrip": false,
743
+ "normalized": false,
744
+ "rstrip": false,
745
+ "single_word": false,
746
+ "special": true
747
+ },
748
+ "32090": {
749
+ "content": "<extra_id_9>",
750
+ "lstrip": false,
751
+ "normalized": false,
752
+ "rstrip": false,
753
+ "single_word": false,
754
+ "special": true
755
+ },
756
+ "32091": {
757
+ "content": "<extra_id_8>",
758
+ "lstrip": false,
759
+ "normalized": false,
760
+ "rstrip": false,
761
+ "single_word": false,
762
+ "special": true
763
+ },
764
+ "32092": {
765
+ "content": "<extra_id_7>",
766
+ "lstrip": false,
767
+ "normalized": false,
768
+ "rstrip": false,
769
+ "single_word": false,
770
+ "special": true
771
+ },
772
+ "32093": {
773
+ "content": "<extra_id_6>",
774
+ "lstrip": false,
775
+ "normalized": false,
776
+ "rstrip": false,
777
+ "single_word": false,
778
+ "special": true
779
+ },
780
+ "32094": {
781
+ "content": "<extra_id_5>",
782
+ "lstrip": false,
783
+ "normalized": false,
784
+ "rstrip": false,
785
+ "single_word": false,
786
+ "special": true
787
+ },
788
+ "32095": {
789
+ "content": "<extra_id_4>",
790
+ "lstrip": false,
791
+ "normalized": false,
792
+ "rstrip": false,
793
+ "single_word": false,
794
+ "special": true
795
+ },
796
+ "32096": {
797
+ "content": "<extra_id_3>",
798
+ "lstrip": false,
799
+ "normalized": false,
800
+ "rstrip": false,
801
+ "single_word": false,
802
+ "special": true
803
+ },
804
+ "32097": {
805
+ "content": "<extra_id_2>",
806
+ "lstrip": false,
807
+ "normalized": false,
808
+ "rstrip": false,
809
+ "single_word": false,
810
+ "special": true
811
+ },
812
+ "32098": {
813
+ "content": "<extra_id_1>",
814
+ "lstrip": false,
815
+ "normalized": false,
816
+ "rstrip": false,
817
+ "single_word": false,
818
+ "special": true
819
+ },
820
+ "32099": {
821
+ "content": "<extra_id_0>",
822
+ "lstrip": false,
823
+ "normalized": false,
824
+ "rstrip": false,
825
+ "single_word": false,
826
+ "special": true
827
+ }
828
+ },
829
+ "additional_special_tokens": [
830
+ "<extra_id_0>",
831
+ "<extra_id_1>",
832
+ "<extra_id_2>",
833
+ "<extra_id_3>",
834
+ "<extra_id_4>",
835
+ "<extra_id_5>",
836
+ "<extra_id_6>",
837
+ "<extra_id_7>",
838
+ "<extra_id_8>",
839
+ "<extra_id_9>",
840
+ "<extra_id_10>",
841
+ "<extra_id_11>",
842
+ "<extra_id_12>",
843
+ "<extra_id_13>",
844
+ "<extra_id_14>",
845
+ "<extra_id_15>",
846
+ "<extra_id_16>",
847
+ "<extra_id_17>",
848
+ "<extra_id_18>",
849
+ "<extra_id_19>",
850
+ "<extra_id_20>",
851
+ "<extra_id_21>",
852
+ "<extra_id_22>",
853
+ "<extra_id_23>",
854
+ "<extra_id_24>",
855
+ "<extra_id_25>",
856
+ "<extra_id_26>",
857
+ "<extra_id_27>",
858
+ "<extra_id_28>",
859
+ "<extra_id_29>",
860
+ "<extra_id_30>",
861
+ "<extra_id_31>",
862
+ "<extra_id_32>",
863
+ "<extra_id_33>",
864
+ "<extra_id_34>",
865
+ "<extra_id_35>",
866
+ "<extra_id_36>",
867
+ "<extra_id_37>",
868
+ "<extra_id_38>",
869
+ "<extra_id_39>",
870
+ "<extra_id_40>",
871
+ "<extra_id_41>",
872
+ "<extra_id_42>",
873
+ "<extra_id_43>",
874
+ "<extra_id_44>",
875
+ "<extra_id_45>",
876
+ "<extra_id_46>",
877
+ "<extra_id_47>",
878
+ "<extra_id_48>",
879
+ "<extra_id_49>",
880
+ "<extra_id_50>",
881
+ "<extra_id_51>",
882
+ "<extra_id_52>",
883
+ "<extra_id_53>",
884
+ "<extra_id_54>",
885
+ "<extra_id_55>",
886
+ "<extra_id_56>",
887
+ "<extra_id_57>",
888
+ "<extra_id_58>",
889
+ "<extra_id_59>",
890
+ "<extra_id_60>",
891
+ "<extra_id_61>",
892
+ "<extra_id_62>",
893
+ "<extra_id_63>",
894
+ "<extra_id_64>",
895
+ "<extra_id_65>",
896
+ "<extra_id_66>",
897
+ "<extra_id_67>",
898
+ "<extra_id_68>",
899
+ "<extra_id_69>",
900
+ "<extra_id_70>",
901
+ "<extra_id_71>",
902
+ "<extra_id_72>",
903
+ "<extra_id_73>",
904
+ "<extra_id_74>",
905
+ "<extra_id_75>",
906
+ "<extra_id_76>",
907
+ "<extra_id_77>",
908
+ "<extra_id_78>",
909
+ "<extra_id_79>",
910
+ "<extra_id_80>",
911
+ "<extra_id_81>",
912
+ "<extra_id_82>",
913
+ "<extra_id_83>",
914
+ "<extra_id_84>",
915
+ "<extra_id_85>",
916
+ "<extra_id_86>",
917
+ "<extra_id_87>",
918
+ "<extra_id_88>",
919
+ "<extra_id_89>",
920
+ "<extra_id_90>",
921
+ "<extra_id_91>",
922
+ "<extra_id_92>",
923
+ "<extra_id_93>",
924
+ "<extra_id_94>",
925
+ "<extra_id_95>",
926
+ "<extra_id_96>",
927
+ "<extra_id_97>",
928
+ "<extra_id_98>",
929
+ "<extra_id_99>"
930
+ ],
931
+ "clean_up_tokenization_spaces": false,
932
+ "eos_token": "</s>",
933
+ "extra_ids": 100,
934
+ "extra_special_tokens": {},
935
+ "model_max_length": 1000000000000000019884624838656,
936
+ "pad_token": "<pad>",
937
+ "tokenizer_class": "T5Tokenizer",
938
+ "unk_token": "<unk>"
939
+ }
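The tokenizer folder above is a stock T5 sentinel-token setup (T5Tokenizer, 100 <extra_id_*> specials, <pad>/</s>/<unk>). A minimal sketch of loading it and sanity-checking the mapping, assuming transformers is installed and that the local folder path below matches where these files live in the repo (fp32/tokenizer is an assumption):

from transformers import AutoTokenizer

# Assumed local path to the tokenizer folder whose config is shown above.
tok = AutoTokenizer.from_pretrained("fp32/tokenizer")

print(tok.convert_tokens_to_ids("<extra_id_0>"))    # 32099, matching the added-tokens table
print(tok.pad_token, tok.eos_token, tok.unk_token)  # <pad> </s> <unk>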
fp32/tokenizer_config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "model_name": "google-t5/t5-base",
3
+ "max_length": 77,
4
+ "vocab_size": 32100,
5
+ "pad_token_id": 0,
6
+ "eos_token_id": 1
7
+ }
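This small sidecar config records the text settings the exported pipeline assumes: the google-t5/t5-base vocabulary (32100 tokens), 77-token prompts, pad id 0 and eos id 1. A hedged sketch of using it to prepare a prompt for the exported T5 encoder, assuming transformers is installed and the base tokenizer is downloadable or cached; the prompt string is purely illustrative:

import json
from transformers import AutoTokenizer

with open("fp32/tokenizer_config.json") as f:  # the 7-key config added above
    cfg = json.load(f)

tok = AutoTokenizer.from_pretrained(cfg["model_name"])  # google-t5/t5-base
enc = tok(
    "isolate the barking dog",     # illustrative text prompt
    padding="max_length",
    truncation=True,
    max_length=cfg["max_length"],  # 77
    return_tensors="np",
)
# input_ids / attention_mask are the tensors the exported T5 encoder consumes.
assert enc["input_ids"].shape == (1, cfg["max_length"])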
fp32/vision_encoder.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c7011e2c8f01fff26d89ba2515005f34a505d5be381ddc910d089b57cb90b605
3
+ size 2822352
fp32/vision_encoder.onnx.data ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e8b15b05f71b646b454bf77037c82cd1335917f39b2f847baaf5cb4d20880ee9
3
+ size 1268842496
onnx_export/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # ONNX Export utilities for SAM Audio
onnx_export/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (148 Bytes). View file
 
onnx_export/__pycache__/export_dacvae.cpython-312.pyc ADDED
Binary file (15.3 kB). View file
 
onnx_export/__pycache__/export_dit.cpython-312.pyc ADDED
Binary file (23.7 kB). View file
 
onnx_export/__pycache__/export_peaframe.cpython-312.pyc ADDED
Binary file (11.7 kB). View file
 
onnx_export/__pycache__/export_t5.cpython-312.pyc ADDED
Binary file (11.1 kB). View file
 
onnx_export/__pycache__/export_vision.cpython-312.pyc ADDED
Binary file (5.27 kB). View file
 
onnx_export/__pycache__/quantize_large_model.cpython-312.pyc ADDED
Binary file (5.3 kB). View file
 
onnx_export/__pycache__/quantize_models.cpython-312.pyc ADDED
Binary file (11.1 kB). View file
 
onnx_export/__pycache__/standalone_config.cpython-312.pyc ADDED
Binary file (5.93 kB). View file
 
onnx_export/export_all.py ADDED
@@ -0,0 +1,130 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Export all SAM Audio components to ONNX format.
4
+
5
+ This script exports:
6
+ 1. DACVAE encoder and decoder (audio codec)
7
+ 2. T5 text encoder
8
+ 3. DiT transformer (single-step for ODE solving)
9
+ 4. Vision encoder (CLIP-based, for video-guided separation)
10
+ Usage: python -m onnx_export.export_all --output-dir onnx_models --verify
11
+ """
12
+
13
+ import os
14
+ import argparse
15
+ import subprocess
16
+ import sys
17
+
18
+
19
+ def run_export(module: str, args: list[str]) -> bool:
20
+ """Run an export module with the given arguments."""
21
+ cmd = [sys.executable, "-m", module] + args
22
+ print(f"\n{'='*60}")
23
+ print(f"Running: {' '.join(cmd)}")
24
+ print(f"{'='*60}\n")
25
+
26
+ result = subprocess.run(cmd, cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
27
+ return result.returncode == 0
28
+
29
+
30
+ def main():
31
+ parser = argparse.ArgumentParser(description="Export all SAM Audio components to ONNX")
32
+ parser.add_argument(
33
+ "--output-dir",
34
+ type=str,
35
+ default="onnx_models",
36
+ help="Output directory for ONNX models",
37
+ )
38
+ parser.add_argument(
39
+ "--model",
40
+ type=str,
41
+ default="facebook/sam-audio-small",
42
+ help="SAM-Audio model ID (e.g., facebook/sam-audio-small, facebook/sam-audio-large, facebook/sam-audio-base-tv)",
43
+ )
44
+ parser.add_argument(
45
+ "--verify",
46
+ action="store_true",
47
+ help="Verify ONNX output matches PyTorch",
48
+ )
49
+ parser.add_argument(
50
+ "--skip-dacvae",
51
+ action="store_true",
52
+ help="Skip DACVAE export",
53
+ )
54
+ parser.add_argument(
55
+ "--skip-t5",
56
+ action="store_true",
57
+ help="Skip T5 export",
58
+ )
59
+ parser.add_argument(
60
+ "--skip-dit",
61
+ action="store_true",
62
+ help="Skip DiT export",
63
+ )
64
+ parser.add_argument(
65
+ "--skip-vision",
66
+ action="store_true",
67
+ help="Skip Vision encoder export",
68
+ )
69
+
70
+ args = parser.parse_args()
71
+
72
+ os.makedirs(args.output_dir, exist_ok=True)
73
+
74
+ results = {}
75
+
76
+ # Export DACVAE
77
+ if not args.skip_dacvae:
78
+ export_args = ["--output-dir", args.output_dir, "--model-id", args.model]
79
+ if args.verify:
80
+ export_args.append("--verify")
81
+ results["DACVAE"] = run_export("onnx_export.export_dacvae", export_args)
82
+
83
+ # Export T5 (always uses google-t5/t5-base, independent of SAM-Audio model)
84
+ if not args.skip_t5:
85
+ export_args = ["--output-dir", args.output_dir]
86
+ if args.verify:
87
+ export_args.append("--verify")
88
+ results["T5"] = run_export("onnx_export.export_t5", export_args)
89
+
90
+ # Export DiT
91
+ if not args.skip_dit:
92
+ export_args = ["--output-dir", args.output_dir, "--model-id", args.model]
93
+ if args.verify:
94
+ export_args.append("--verify")
95
+ results["DiT"] = run_export("onnx_export.export_dit", export_args)
96
+
97
+ # Export Vision Encoder
98
+ if not args.skip_vision:
99
+ export_args = ["--output", args.output_dir, "--model", args.model]
100
+ results["Vision"] = run_export("onnx_export.export_vision", export_args)
101
+
102
+ # Print summary
103
+ print(f"\n{'='*60}")
104
+ print("Export Summary")
105
+ print(f"{'='*60}")
106
+
107
+ all_success = True
108
+ for name, success in results.items():
109
+ status = "✓" if success else "✗"
110
+ print(f" {status} {name}")
111
+ if not success:
112
+ all_success = False
113
+
114
+ # List exported files
115
+ print(f"\nExported files in {args.output_dir}:")
116
+ for f in sorted(os.listdir(args.output_dir)):
117
+ path = os.path.join(args.output_dir, f)
118
+ if os.path.isfile(path):
119
+ size_mb = os.path.getsize(path) / (1024 * 1024)
120
+ print(f" {f}: {size_mb:.1f} MB")
121
+
122
+ if all_success:
123
+ print("\n✓ All exports completed successfully!")
124
+ else:
125
+ print("\n✗ Some exports failed")
126
+ sys.exit(1)
127
+
128
+
129
+ if __name__ == "__main__":
130
+ main()
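export_all.py is a thin driver over the per-component export modules. A small sketch of driving it programmatically, mirroring run_export() above; it assumes the repo root as the working directory with the export dependencies installed, and uses only flags defined in main():

import subprocess
import sys

# Export only T5 and the DiT, skipping the audio codec and vision encoder.
cmd = [
    sys.executable, "-m", "onnx_export.export_all",
    "--output-dir", "onnx_models",
    "--model", "facebook/sam-audio-small",
    "--skip-dacvae", "--skip-vision",
    "--verify",
]
subprocess.run(cmd, check=True)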
onnx_export/export_dacvae.py ADDED
@@ -0,0 +1,427 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Export DACVAE (audio codec) to ONNX format.
4
+
5
+ This exports the encoder and decoder separately:
6
+ - Encoder: audio waveform → latent features
7
+ - Decoder: latent features → audio waveform
8
+
9
+ Usage:
10
+ python -m onnx_export.export_dacvae --output-dir onnx_models --verify
11
+ """
12
+
13
+ import os
14
+ import argparse
15
+ import torch
16
+ import torch.nn as nn
17
+ import dacvae
18
+ from huggingface_hub import hf_hub_download
19
+
20
+
21
+ # Default DACVAE configuration (matches SAM Audio)
22
+ DEFAULT_CONFIG = {
23
+ "encoder_dim": 64,
24
+ "encoder_rates": [2, 8, 10, 12],
25
+ "latent_dim": 1024,
26
+ "decoder_dim": 1536,
27
+ "decoder_rates": [12, 10, 8, 2],
28
+ "n_codebooks": 16,
29
+ "codebook_size": 1024,
30
+ "codebook_dim": 128,
31
+ "quantizer_dropout": False,
32
+ "sample_rate": 48000,
33
+ }
34
+
35
+
36
+ class DACVAEEncoderWrapper(nn.Module):
37
+ """Wrapper for DACVAE encoder that outputs continuous latent features."""
38
+
39
+ def __init__(self, encoder, quantizer):
40
+ super().__init__()
41
+ self.encoder = encoder
42
+ self.in_proj = quantizer.in_proj
43
+
44
+ def forward(self, audio: torch.Tensor) -> torch.Tensor:
45
+ """
46
+ Encode audio to latent features.
47
+
48
+ Args:
49
+ audio: Input waveform, shape (batch, 1, samples)
50
+
51
+ Returns:
52
+ latent_features: Continuous latent mean, shape (batch, 128, time_steps)
53
+ """
54
+ x = self.encoder(audio)
55
+ # in_proj outputs 256 dim, chunk into mean and variance, use only mean
56
+ mean, _ = self.in_proj(x).chunk(2, dim=1)
57
+ return mean
58
+
59
+
60
+ class DACVAEDecoderWrapper(nn.Module):
61
+ """Wrapper for DACVAE decoder that takes continuous latent features."""
62
+
63
+ def __init__(self, decoder, quantizer):
64
+ super().__init__()
65
+ self.decoder = decoder
66
+ self.out_proj = quantizer.out_proj
67
+
68
+ def forward(self, latent_features: torch.Tensor) -> torch.Tensor:
69
+ """
70
+ Decode latent features to audio.
71
+
72
+ Args:
73
+ latent_features: Continuous latent, shape (batch, 128, time_steps)
74
+
75
+ Returns:
76
+ audio: Output waveform, shape (batch, 1, samples)
77
+ """
78
+ x = self.out_proj(latent_features)
79
+ return self.decoder(x)
80
+
81
+
82
+ def create_dacvae_model(model_id: str = "facebook/sam-audio-small") -> dacvae.DACVAE:
83
+ """
84
+ Create and load DACVAE model with weights from SAM Audio checkpoint.
85
+
86
+ This uses the standalone dacvae library, so we avoid loading the full SAM Audio
87
+ model and its dependencies (vision encoder, imagebind, etc.).
88
+ """
89
+ print(f"Creating DACVAE model...")
90
+
91
+ model = dacvae.DACVAE(
92
+ encoder_dim=DEFAULT_CONFIG["encoder_dim"],
93
+ encoder_rates=DEFAULT_CONFIG["encoder_rates"],
94
+ latent_dim=DEFAULT_CONFIG["latent_dim"],
95
+ decoder_dim=DEFAULT_CONFIG["decoder_dim"],
96
+ decoder_rates=DEFAULT_CONFIG["decoder_rates"],
97
+ n_codebooks=DEFAULT_CONFIG["n_codebooks"],
98
+ codebook_size=DEFAULT_CONFIG["codebook_size"],
99
+ codebook_dim=DEFAULT_CONFIG["codebook_dim"],
100
+ quantizer_dropout=DEFAULT_CONFIG["quantizer_dropout"],
101
+ sample_rate=DEFAULT_CONFIG["sample_rate"],
102
+ ).eval()
103
+
104
+ # Load weights from SAM Audio checkpoint
105
+ print(f"Downloading checkpoint from {model_id}...")
106
+ checkpoint_path = hf_hub_download(
107
+ repo_id=model_id,
108
+ filename="checkpoint.pt",
109
+ )
110
+
111
+ print("Loading DACVAE weights from checkpoint...")
112
+ state_dict = torch.load(
113
+ checkpoint_path,
114
+ map_location="cpu",
115
+ weights_only=True,
116
+ mmap=True, # Memory-efficient loading
117
+ )
118
+
119
+ # Extract only DACVAE weights (prefixed with "audio_codec.")
120
+ dacvae_state_dict = {}
121
+ for k, v in state_dict.items():
122
+ if k.startswith("audio_codec."):
123
+ new_key = k.replace("audio_codec.", "")
124
+ dacvae_state_dict[new_key] = v.clone()
125
+
126
+ # Load weights
127
+ model.load_state_dict(dacvae_state_dict, strict=False)
128
+
129
+ # Clear large checkpoint from memory
130
+ del state_dict
131
+
132
+ print(f" ✓ Loaded {len(dacvae_state_dict)} DACVAE weight tensors")
133
+
134
+ # Calculate hop_length for reference
135
+ import numpy as np
136
+ hop_length = int(np.prod(DEFAULT_CONFIG["encoder_rates"]))
137
+ model.hop_length = hop_length
138
+ model.sample_rate = DEFAULT_CONFIG["sample_rate"]
139
+
140
+ return model
141
+
142
+
143
+ def export_encoder(
144
+ dacvae_model: dacvae.DACVAE,
145
+ output_path: str,
146
+ opset_version: int = 21,
147
+ device: str = "cpu",
148
+ ) -> None:
149
+ """Export DACVAE encoder to ONNX."""
150
+ print(f"Exporting DACVAE encoder to {output_path}...")
151
+
152
+ wrapper = DACVAEEncoderWrapper(
153
+ dacvae_model.encoder,
154
+ dacvae_model.quantizer
155
+ ).eval().to(device)
156
+
157
+ # Sample input: 1 second of audio at 48kHz
158
+ sample_rate = DEFAULT_CONFIG["sample_rate"]
159
+ dummy_audio = torch.randn(1, 1, sample_rate, device=device)
160
+
161
+ torch.onnx.export(
162
+ wrapper,
163
+ (dummy_audio,),
164
+ output_path,
165
+ input_names=["audio"],
166
+ output_names=["latent_features"],
167
+ dynamic_axes={
168
+ "audio": {0: "batch", 2: "samples"},
169
+ "latent_features": {0: "batch", 2: "time_steps"},
170
+ },
171
+ opset_version=opset_version,
172
+ do_constant_folding=True,
173
+ dynamo=True,
174
+ external_data=True,
175
+ )
176
+
177
+ print(f" ✓ Encoder exported successfully")
178
+
179
+ # Validate
180
+ import onnx
181
+ # Load without external data to avoid OOM - we just need to validate structure
182
+ model = onnx.load(output_path, load_external_data=False)
183
+ onnx.checker.check_model(model, full_check=False)
184
+ print(f" ✓ ONNX model validation passed")
185
+
186
+
187
+ def export_decoder(
188
+ dacvae_model: dacvae.DACVAE,
189
+ output_path: str,
190
+ opset_version: int = 21,
191
+ device: str = "cpu",
192
+ ) -> None:
193
+ """Export DACVAE decoder to ONNX."""
194
+ print(f"Exporting DACVAE decoder to {output_path}...")
195
+
196
+ wrapper = DACVAEDecoderWrapper(
197
+ dacvae_model.decoder,
198
+ dacvae_model.quantizer
199
+ ).eval().to(device)
200
+
201
+ # Sample input: 25 time steps (1 second at 48kHz with hop_length=1920)
202
+ hop_length = int(__import__("numpy").prod(DEFAULT_CONFIG["encoder_rates"]))
203
+ time_steps = DEFAULT_CONFIG["sample_rate"] // hop_length
204
+ dummy_latent = torch.randn(1, 128, time_steps, device=device)
205
+
206
+ torch.onnx.export(
207
+ wrapper,
208
+ (dummy_latent,),
209
+ output_path,
210
+ input_names=["latent_features"],
211
+ output_names=["waveform"],
212
+ dynamic_axes={
213
+ "latent_features": {0: "batch", 2: "time_steps"},
214
+ "waveform": {0: "batch", 2: "samples"},
215
+ },
216
+ opset_version=opset_version,
217
+ do_constant_folding=True,
218
+ dynamo=True,
219
+ external_data=True,
220
+ )
221
+
222
+ print(f" ✓ Decoder exported successfully")
223
+
224
+ # Validate
225
+ import onnx
226
+ # Load without external data to avoid OOM - we just need to validate structure
227
+ model = onnx.load(output_path, load_external_data=False)
228
+ onnx.checker.check_model(model, full_check=False)
229
+ print(f" ✓ ONNX model validation passed")
230
+
231
+
232
+ def verify_encoder(
233
+ dacvae_model: dacvae.DACVAE,
234
+ onnx_path: str,
235
+ device: str = "cpu",
236
+ tolerance: float = 1e-4,
237
+ ) -> bool:
238
+ """Verify ONNX encoder output matches PyTorch."""
239
+ import onnxruntime as ort
240
+ import numpy as np
241
+
242
+ print("Verifying encoder output...")
243
+
244
+ wrapper = DACVAEEncoderWrapper(
245
+ dacvae_model.encoder,
246
+ dacvae_model.quantizer
247
+ ).eval().to(device)
248
+
249
+ # Test with random audio
250
+ sample_rate = DEFAULT_CONFIG["sample_rate"]
251
+ test_audio = torch.randn(1, 1, sample_rate * 2, device=device) # 2 seconds
252
+
253
+ # PyTorch output
254
+ with torch.no_grad():
255
+ pytorch_output = wrapper(test_audio).cpu().numpy()
256
+
257
+ # ONNX Runtime output
258
+ sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
259
+ onnx_output = sess.run(
260
+ ["latent_features"],
261
+ {"audio": test_audio.cpu().numpy()}
262
+ )[0]
263
+
264
+ # Compare
265
+ max_diff = np.abs(pytorch_output - onnx_output).max()
266
+ mean_diff = np.abs(pytorch_output - onnx_output).mean()
267
+
268
+ print(f" Max diff: {max_diff:.2e}, Mean diff: {mean_diff:.2e}")
269
+
270
+ if max_diff > tolerance:
271
+ print(f" ✗ Verification failed (tolerance: {tolerance})")
272
+ return False
273
+
274
+ print(f" ✓ Verification passed (tolerance: {tolerance})")
275
+ return True
276
+
277
+
278
+ def verify_decoder(
279
+ dacvae_model: dacvae.DACVAE,
280
+ onnx_path: str,
281
+ device: str = "cpu",
282
+ tolerance: float = 1e-3,
283
+ ) -> bool:
284
+ """Verify ONNX decoder output matches PyTorch."""
285
+ import onnxruntime as ort
286
+ import numpy as np
287
+
288
+ print("Verifying decoder output...")
289
+
290
+ wrapper = DACVAEDecoderWrapper(
291
+ dacvae_model.decoder,
292
+ dacvae_model.quantizer
293
+ ).eval().to(device)
294
+
295
+ # Test with random latent
296
+ hop_length = int(np.prod(DEFAULT_CONFIG["encoder_rates"]))
297
+ time_steps = DEFAULT_CONFIG["sample_rate"] // hop_length # 25 steps = 1 second
298
+ test_latent = torch.randn(1, 128, time_steps, device=device)
299
+
300
+ # PyTorch output
301
+ with torch.no_grad():
302
+ pytorch_output = wrapper(test_latent).cpu().numpy()
303
+
304
+ # ONNX Runtime output
305
+ sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
306
+ onnx_output = sess.run(
307
+ ["waveform"],
308
+ {"latent_features": test_latent.cpu().numpy()}
309
+ )[0]
310
+
311
+ # Compare
312
+ max_diff = np.abs(pytorch_output - onnx_output).max()
313
+ mean_diff = np.abs(pytorch_output - onnx_output).mean()
314
+
315
+ print(f" Max diff: {max_diff:.2e}, Mean diff: {mean_diff:.2e}")
316
+
317
+ if max_diff > tolerance:
318
+ print(f" ✗ Verification failed (tolerance: {tolerance})")
319
+ return False
320
+
321
+ print(f" ✓ Verification passed (tolerance: {tolerance})")
322
+ return True
323
+
324
+
325
+ def main():
326
+ parser = argparse.ArgumentParser(description="Export DACVAE to ONNX")
327
+ parser.add_argument(
328
+ "--model-id",
329
+ type=str,
330
+ default="facebook/sam-audio-small",
331
+ help="HuggingFace model ID (default: facebook/sam-audio-small)",
332
+ )
333
+ parser.add_argument(
334
+ "--output-dir",
335
+ type=str,
336
+ default="onnx_models",
337
+ help="Output directory for ONNX models",
338
+ )
339
+ parser.add_argument(
340
+ "--opset-version",
341
+ type=int,
342
+ default=18,
343
+ help="ONNX opset version (default: 18)",
344
+ )
345
+ parser.add_argument(
346
+ "--device",
347
+ type=str,
348
+ default="cpu",
349
+ help="Device to use for export (default: cpu)",
350
+ )
351
+ parser.add_argument(
352
+ "--verify",
353
+ action="store_true",
354
+ help="Verify ONNX output matches PyTorch",
355
+ )
356
+ parser.add_argument(
357
+ "--tolerance",
358
+ type=float,
359
+ default=1e-4,
360
+ help="Tolerance for verification (default: 1e-4)",
361
+ )
362
+ parser.add_argument(
363
+ "--encoder-only",
364
+ action="store_true",
365
+ help="Export only the encoder",
366
+ )
367
+ parser.add_argument(
368
+ "--decoder-only",
369
+ action="store_true",
370
+ help="Export only the decoder",
371
+ )
372
+
373
+ args = parser.parse_args()
374
+
375
+ # Create output directory
376
+ os.makedirs(args.output_dir, exist_ok=True)
377
+
378
+ # Load model
379
+ dacvae_model = create_dacvae_model(args.model_id)
380
+
381
+ print(f"\nDACVAE Configuration:")
382
+ print(f" Model: {args.model_id}")
383
+ print(f" Sample rate: {DEFAULT_CONFIG['sample_rate']} Hz")
384
+ print(f" Hop length: {int(__import__('numpy').prod(DEFAULT_CONFIG['encoder_rates']))}")
385
+ print(f" Latent dim: 128 (continuous)")
386
+
387
+ # Export encoder
388
+ if not args.decoder_only:
389
+ encoder_path = os.path.join(args.output_dir, "dacvae_encoder.onnx")
390
+ export_encoder(
391
+ dacvae_model,
392
+ encoder_path,
393
+ opset_version=args.opset_version,
394
+ device=args.device,
395
+ )
396
+
397
+ if args.verify:
398
+ verify_encoder(
399
+ dacvae_model,
400
+ encoder_path,
401
+ device=args.device,
402
+ tolerance=args.tolerance,
403
+ )
404
+
405
+ # Export decoder
406
+ if not args.encoder_only:
407
+ decoder_path = os.path.join(args.output_dir, "dacvae_decoder.onnx")
408
+ export_decoder(
409
+ dacvae_model,
410
+ decoder_path,
411
+ opset_version=args.opset_version,
412
+ device=args.device,
413
+ )
414
+
415
+ if args.verify:
416
+ verify_decoder(
417
+ dacvae_model,
418
+ decoder_path,
419
+ device=args.device,
420
+ tolerance=args.tolerance * 10, # Decoder has higher tolerance
421
+ )
422
+
423
+ print(f"\n✓ Export complete! Models saved to {args.output_dir}/")
424
+
425
+
426
+ if __name__ == "__main__":
427
+ main()
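After export, the two DACVAE graphs chain directly in onnxruntime for an encode→decode round trip. A sketch assuming the default 48 kHz configuration above and that both .onnx files (plus their .data sidecars) were written to onnx_models/:

import numpy as np
import onnxruntime as ort

enc = ort.InferenceSession("onnx_models/dacvae_encoder.onnx", providers=["CPUExecutionProvider"])
dec = ort.InferenceSession("onnx_models/dacvae_decoder.onnx", providers=["CPUExecutionProvider"])

audio = np.random.randn(1, 1, 48000).astype(np.float32)        # 1 s of mono audio at 48 kHz
latent = enc.run(["latent_features"], {"audio": audio})[0]     # (1, 128, ~25) with hop length 1920
recon = dec.run(["waveform"], {"latent_features": latent})[0]  # (1, 1, ~48000)
print(latent.shape, recon.shape)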
onnx_export/export_dit.py ADDED
@@ -0,0 +1,574 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Export DiT Transformer with unrolled ODE solver to ONNX format.
4
+
5
+ The DiT transformer is the core denoising model in SAM Audio. It uses a flow-based
6
+ generative model with an ODE solver. For ONNX export, we unroll the fixed-step
7
+ midpoint ODE solver into a static computation graph.
8
+
9
+ The default configuration uses:
10
+ - method: "midpoint"
11
+ - step_size: 2/32 (0.0625)
12
+ - integration range: [0, 1]
13
+ - total steps: 16
14
+
15
+ This creates a single ONNX model that performs the complete denoising process,
16
+ taking noise and conditioning as input and producing denoised audio features.
17
+
18
+ Usage:
19
+ python -m onnx_export.export_dit --output-dir onnx_models --verify
20
+ """
21
+
22
+ import os
23
+ import math
24
+ import argparse
25
+ import torch
26
+ import torch.nn as nn
27
+ from typing import Optional
28
+
29
+
30
+ class SinusoidalEmbedding(nn.Module):
31
+ """Sinusoidal timestep embedding (identical to SAMAudio implementation)."""
32
+
33
+ def __init__(self, dim, theta=10000):
34
+ super().__init__()
35
+ assert (dim % 2) == 0
36
+ half_dim = dim // 2
37
+ inv_freq = torch.exp(
38
+ -math.log(theta) * torch.arange(half_dim).float() / half_dim
39
+ )
40
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
41
+
42
+ def forward(self, x, pos=None):
43
+ if pos is None:
44
+ seq_len, device = x.shape[1], x.device
45
+ pos = torch.arange(seq_len, device=device)
46
+
47
+ emb = torch.einsum("i, j -> i j", pos, self.inv_freq)
48
+ emb = torch.cat((emb.cos(), emb.sin()), dim=-1)
49
+ return emb
50
+
51
+
52
+ class EmbedAnchors(nn.Module):
53
+ """Anchor embedding (identical to SAMAudio implementation)."""
54
+
55
+ def __init__(self, num_embeddings: int, embedding_dim: int, out_dim: int):
56
+ super().__init__()
57
+ self.embed = nn.Embedding(
58
+ num_embeddings + 1, embedding_dim, padding_idx=num_embeddings
59
+ )
60
+ self.gate = nn.Parameter(torch.tensor([0.0]))
61
+ self.proj = nn.Linear(embedding_dim, out_dim, bias=False)
62
+
63
+ def forward(
64
+ self,
65
+ x: torch.Tensor,
66
+ anchor_ids: Optional[torch.Tensor] = None,
67
+ anchor_alignment: Optional[torch.Tensor] = None,
68
+ ):
69
+ if anchor_ids is None:
70
+ return x
71
+
72
+ embs = self.embed(anchor_ids.gather(1, anchor_alignment))
73
+ proj = self.proj(embs)
74
+ return x + self.gate.tanh() * proj
75
+
76
+
77
+ class DiTSingleStepWrapper(nn.Module):
78
+ """
79
+ Wrapper for DiT that performs a single forward pass (one ODE evaluation).
80
+
81
+ This mirrors the SAMAudio.forward() method exactly.
82
+ """
83
+
84
+ def __init__(
85
+ self,
86
+ transformer: nn.Module,
87
+ proj: nn.Module,
88
+ align_masked_video: nn.Module,
89
+ embed_anchors: nn.Module,
90
+ timestep_emb: nn.Module,
91
+ memory_proj: nn.Module,
92
+ ):
93
+ super().__init__()
94
+ self.transformer = transformer
95
+ self.proj = proj
96
+ self.align_masked_video = align_masked_video
97
+ self.embed_anchors = embed_anchors
98
+ self.timestep_emb = timestep_emb
99
+ self.memory_proj = memory_proj
100
+
101
+ def forward(
102
+ self,
103
+ noisy_audio: torch.Tensor,
104
+ time: torch.Tensor,
105
+ audio_features: torch.Tensor,
106
+ text_features: torch.Tensor,
107
+ text_mask: torch.Tensor,
108
+ masked_video_features: torch.Tensor,
109
+ anchor_ids: torch.Tensor,
110
+ anchor_alignment: torch.Tensor,
111
+ audio_pad_mask: torch.Tensor,
112
+ ) -> torch.Tensor:
113
+ """
114
+ Single forward pass of the DiT (one ODE function evaluation).
115
+
116
+ This exactly mirrors SAMAudio.forward() method.
117
+ """
118
+ # Align inputs (concatenate noisy_audio with audio_features)
119
+ # Same as SAMAudio.align_inputs()
120
+ x = torch.cat(
121
+ [
122
+ noisy_audio,
123
+ torch.zeros_like(audio_features),
124
+ audio_features,
125
+ ],
126
+ dim=2,
127
+ )
128
+
129
+ projected = self.proj(x)
130
+ aligned = self.align_masked_video(projected, masked_video_features)
131
+ aligned = self.embed_anchors(aligned, anchor_ids, anchor_alignment)
132
+
133
+ # Timestep embedding and memory
134
+ # Same as SAMAudio.forward()
135
+ timestep_emb_val = self.timestep_emb(time, pos=time).unsqueeze(1)
136
+ memory = self.memory_proj(text_features) + timestep_emb_val
137
+
138
+ # Transformer forward
139
+ output = self.transformer(
140
+ aligned,
141
+ time,
142
+ padding_mask=audio_pad_mask,
143
+ memory=memory,
144
+ memory_padding_mask=text_mask,
145
+ )
146
+
147
+ return output
148
+
149
+
150
+ class UnrolledDiTWrapper(nn.Module):
151
+ """
152
+ DiT wrapper with unrolled midpoint ODE solver.
153
+
154
+ The midpoint method computes:
155
+ k1 = f(t, y)
156
+ k2 = f(t + h/2, y + h/2 * k1)
157
+ y_new = y + h * k2
158
+
159
+ With step_size=0.0625 and range [0,1], we have 16 steps.
160
+ """
161
+
162
+ def __init__(
163
+ self,
164
+ single_step: DiTSingleStepWrapper,
165
+ num_steps: int = 16,
166
+ ):
167
+ super().__init__()
168
+ self.single_step = single_step
169
+ self.num_steps = num_steps
170
+ self.step_size = 1.0 / num_steps
171
+
172
+ def forward(
173
+ self,
174
+ noise: torch.Tensor,
175
+ audio_features: torch.Tensor,
176
+ text_features: torch.Tensor,
177
+ text_mask: torch.Tensor,
178
+ masked_video_features: torch.Tensor,
179
+ anchor_ids: torch.Tensor,
180
+ anchor_alignment: torch.Tensor,
181
+ audio_pad_mask: torch.Tensor,
182
+ ) -> torch.Tensor:
183
+ """Complete denoising using unrolled midpoint ODE solver."""
184
+ B = noise.shape[0]
185
+ h = self.step_size
186
+ y = noise
187
+ t = torch.zeros(B, device=noise.device, dtype=noise.dtype)
188
+
189
+ for step in range(self.num_steps):
190
+ # k1 = f(t, y)
191
+ k1 = self.single_step(
192
+ y, t,
193
+ audio_features, text_features, text_mask,
194
+ masked_video_features, anchor_ids, anchor_alignment, audio_pad_mask
195
+ )
196
+
197
+ # k2 = f(t + h/2, y + h/2 * k1)
198
+ t_mid = t + h / 2
199
+ y_mid = y + (h / 2) * k1
200
+ k2 = self.single_step(
201
+ y_mid, t_mid,
202
+ audio_features, text_features, text_mask,
203
+ masked_video_features, anchor_ids, anchor_alignment, audio_pad_mask
204
+ )
205
+
206
+ # y = y + h * k2
207
+ y = y + h * k2
208
+ t = t + h
209
+
210
+ return y
211
+
212
+
213
+ def load_sam_audio_components(model_id: str = "facebook/sam-audio-small", device: str = "cpu"):
214
+ """
215
+ Load SAM Audio components needed for DiT export.
216
+
217
+ Since we can't load the full SAMAudio model (missing perception_models),
218
+ we construct the components directly and load weights from checkpoint.
219
+ """
220
+ import json
221
+ import sys
222
+ import types
223
+ import importlib.util
224
+ from huggingface_hub import hf_hub_download
225
+
226
+ print(f"Loading SAM Audio components from {model_id}...")
227
+
228
+ # Download config
229
+ config_path = hf_hub_download(repo_id=model_id, filename="config.json")
230
+ with open(config_path) as f:
231
+ config = json.load(f)
232
+
233
+ # Download checkpoint
234
+ checkpoint_path = hf_hub_download(repo_id=model_id, filename="checkpoint.pt")
235
+
236
+ # Use our standalone config that doesn't have 'core' dependencies
237
+ from onnx_export.standalone_config import TransformerConfig
238
+
239
+ sam_audio_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
240
+
241
+ # Create fake module hierarchy so transformer.py's relative imports work
242
+ if 'sam_audio' not in sys.modules:
243
+ sam_audio_pkg = types.ModuleType('sam_audio')
244
+ sam_audio_pkg.__path__ = [os.path.join(sam_audio_path, 'sam_audio')]
245
+ sys.modules['sam_audio'] = sam_audio_pkg
246
+
247
+ if 'sam_audio.model' not in sys.modules:
248
+ model_pkg = types.ModuleType('sam_audio.model')
249
+ model_pkg.__path__ = [os.path.join(sam_audio_path, 'sam_audio', 'model')]
250
+ sys.modules['sam_audio.model'] = model_pkg
251
+
252
+ # Register our standalone config as sam_audio.model.config
253
+ if 'sam_audio.model.config' not in sys.modules:
254
+ import onnx_export.standalone_config as standalone_config
255
+ sys.modules['sam_audio.model.config'] = standalone_config
256
+
257
+ # Now import transformer module - it will use our standalone config
258
+ transformer_spec = importlib.util.spec_from_file_location(
259
+ "sam_audio.model.transformer",
260
+ os.path.join(sam_audio_path, "sam_audio", "model", "transformer.py")
261
+ )
262
+ transformer_module = importlib.util.module_from_spec(transformer_spec)
263
+ sys.modules['sam_audio.model.transformer'] = transformer_module
264
+ transformer_spec.loader.exec_module(transformer_module)
265
+ DiT = transformer_module.DiT
266
+
267
+ # Import align module
268
+ align_spec = importlib.util.spec_from_file_location(
269
+ "sam_audio.model.align",
270
+ os.path.join(sam_audio_path, "sam_audio", "model", "align.py")
271
+ )
272
+ align_module = importlib.util.module_from_spec(align_spec)
273
+ sys.modules['sam_audio.model.align'] = align_module
274
+ align_spec.loader.exec_module(align_module)
275
+ AlignModalities = align_module.AlignModalities
276
+
277
+ # Create transformer
278
+ transformer_config = TransformerConfig(**config.get("transformer", {}))
279
+ transformer = DiT(transformer_config)
280
+
281
+ # Calculate dimensions
282
+ in_channels = config.get("in_channels", 768)
283
+ num_anchors = config.get("num_anchors", 3)
284
+ anchor_embedding_dim = config.get("anchor_embedding_dim", 128)
285
+
286
+ # Get vision encoder dim for align_masked_video
287
+ vision_config = config.get("vision_encoder", {})
288
+ vision_dim = vision_config.get("dim", 768)
289
+
290
+ # Create components exactly as SAMAudio does
291
+ proj = nn.Linear(in_channels, transformer_config.d_model)
292
+ align_masked_video = AlignModalities(vision_dim, transformer_config.d_model)
293
+ embed_anchors = EmbedAnchors(num_anchors, anchor_embedding_dim, transformer_config.d_model)
294
+ timestep_emb = SinusoidalEmbedding(transformer_config.d_model)
295
+
296
+ # Memory projection for text features
297
+ text_encoder_config = config.get("text_encoder", {})
298
+ text_encoder_dim = text_encoder_config.get("dim", 1024) # google/flan-t5-large
299
+ memory_proj = nn.Linear(text_encoder_dim, transformer_config.d_model)
300
+
301
+ # Load weights from checkpoint
302
+ print("Loading weights from checkpoint...")
303
+ state_dict = torch.load(checkpoint_path, map_location="cpu", mmap=True)
304
+
305
+ # Filter and load weights for each component
306
+ transformer_state = {}
307
+ proj_state = {}
308
+ align_state = {}
309
+ embed_anchors_state = {}
310
+ memory_proj_state = {}
311
+
312
+ for key, value in state_dict.items():
313
+ if key.startswith("transformer."):
314
+ new_key = key[len("transformer."):]
315
+ transformer_state[new_key] = value
316
+ elif key.startswith("proj."):
317
+ new_key = key[len("proj."):]
318
+ proj_state[new_key] = value
319
+ elif key.startswith("align_masked_video."):
320
+ new_key = key[len("align_masked_video."):]
321
+ align_state[new_key] = value
322
+ elif key.startswith("embed_anchors."):
323
+ new_key = key[len("embed_anchors."):]
324
+ embed_anchors_state[new_key] = value
325
+ elif key.startswith("memory_proj."):
326
+ new_key = key[len("memory_proj."):]
327
+ memory_proj_state[new_key] = value
328
+
329
+ transformer.load_state_dict(transformer_state)
330
+ proj.load_state_dict(proj_state)
331
+ align_masked_video.load_state_dict(align_state)
332
+ embed_anchors.load_state_dict(embed_anchors_state)
333
+ memory_proj.load_state_dict(memory_proj_state)
334
+
335
+ print(f" ✓ Loaded transformer weights ({len(transformer_state)} tensors)")
336
+ print(f" ✓ Loaded component weights")
337
+
338
+ # Create single step wrapper
339
+ single_step = DiTSingleStepWrapper(
340
+ transformer=transformer,
341
+ proj=proj,
342
+ align_masked_video=align_masked_video,
343
+ embed_anchors=embed_anchors,
344
+ timestep_emb=timestep_emb,
345
+ memory_proj=memory_proj,
346
+ ).eval().to(device)
347
+
348
+ return single_step, config
349
+
350
+
351
+ def create_sample_inputs(batch_size: int = 1, seq_len: int = 25, device: str = "cpu"):
352
+ """Create sample inputs for tracing."""
353
+ latent_dim = 128
354
+ text_dim = 768 # T5-base hidden size (SAM Audio was trained with 768-dim text)
355
+ vision_dim = 1024 # Vision encoder dim from config
356
+ text_len = 77
357
+
358
+ return {
359
+ "noisy_audio": torch.randn(batch_size, seq_len, 2 * latent_dim, device=device),
360
+ "time": torch.zeros(batch_size, device=device),
361
+ "audio_features": torch.randn(batch_size, seq_len, 2 * latent_dim, device=device),
362
+ "text_features": torch.randn(batch_size, text_len, text_dim, device=device),
363
+ "text_mask": torch.ones(batch_size, text_len, dtype=torch.bool, device=device),
364
+ "masked_video_features": torch.zeros(batch_size, vision_dim, seq_len, device=device),
365
+ "anchor_ids": torch.zeros(batch_size, seq_len, dtype=torch.long, device=device),
366
+ "anchor_alignment": torch.zeros(batch_size, seq_len, dtype=torch.long, device=device),
367
+ "audio_pad_mask": torch.ones(batch_size, seq_len, dtype=torch.bool, device=device),
368
+ }
369
+
370
+
371
+ def export_dit_single_step(
372
+ single_step: DiTSingleStepWrapper,
373
+ output_path: str,
374
+ opset_version: int = 21,
375
+ device: str = "cpu",
376
+ fp16: bool = False,
377
+ ):
378
+ """Export single-step DiT to ONNX (for runtime ODE solving)."""
379
+ import onnx
380
+
381
+ print(f"Exporting DiT single-step to {output_path}...")
382
+
383
+ # Convert to FP16 if requested
384
+ if fp16:
385
+ print(" Converting model to FP16...")
386
+ single_step = single_step.half()
387
+
388
+ sample_inputs = create_sample_inputs(device=device)
389
+
390
+ # Convert float inputs to FP16 if exporting in FP16
391
+ if fp16:
392
+ for key, value in sample_inputs.items():
393
+ if value.dtype == torch.float32:
394
+ sample_inputs[key] = value.half()
395
+
396
+ torch.onnx.export(
397
+ single_step,
398
+ tuple(sample_inputs.values()),
399
+ output_path,
400
+ input_names=list(sample_inputs.keys()),
401
+ output_names=["velocity"],
402
+ dynamic_axes={
403
+ "noisy_audio": {0: "batch_size", 1: "seq_len"},
404
+ "time": {0: "batch_size"},
405
+ "audio_features": {0: "batch_size", 1: "seq_len"},
406
+ "text_features": {0: "batch_size", 1: "text_len"},
407
+ "text_mask": {0: "batch_size", 1: "text_len"},
408
+ "masked_video_features": {0: "batch_size", 2: "seq_len"},
409
+ "anchor_ids": {0: "batch_size", 1: "seq_len"},
410
+ "anchor_alignment": {0: "batch_size", 1: "seq_len"},
411
+ "audio_pad_mask": {0: "batch_size", 1: "seq_len"},
412
+ "velocity": {0: "batch_size", 1: "seq_len"},
413
+ },
414
+ opset_version=opset_version,
415
+ do_constant_folding=True,
416
+ dynamo=True,
417
+ external_data=True,
418
+ )
419
+
420
+ print(" ✓ DiT single-step exported successfully")
421
+
422
+ # When using external_data=True, we can't run check_model on a model
423
+ # loaded without external data - the checker validates data references.
424
+ # Since torch.onnx.export with dynamo=True already validates the model,
425
+ # we just verify the files exist.
426
+ external_data_path = output_path + ".data"
427
+ if os.path.exists(external_data_path):
428
+ print(f" ✓ External data file exists ({os.path.getsize(external_data_path) / 1e9:.2f} GB)")
429
+ else:
430
+ raise RuntimeError(f"External data file missing: {external_data_path}")
431
+
432
+ # Verify the ONNX file structure is valid (without loading weights)
433
+ model = onnx.load(output_path, load_external_data=False)
434
+ print(f" ✓ ONNX model structure loaded ({len(model.graph.node)} nodes)")
435
+
436
+ return True
437
+
438
+
439
+ def verify_dit_single_step(
440
+ single_step: DiTSingleStepWrapper,
441
+ onnx_path: str,
442
+ device: str = "cpu",
443
+ tolerance: float = 1e-3,
444
+ ) -> bool:
445
+ """Verify single-step ONNX output matches PyTorch."""
446
+ import onnxruntime as ort
447
+ import numpy as np
448
+
449
+ print("Verifying DiT single-step output...")
450
+
451
+ sample_inputs = create_sample_inputs(device=device)
452
+
453
+ # PyTorch output
454
+ with torch.no_grad():
455
+ pytorch_output = single_step(**sample_inputs).cpu().numpy()
456
+
457
+ # ONNX Runtime output
458
+ sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
459
+
460
+ onnx_inputs = {}
461
+ for name, tensor in sample_inputs.items():
462
+ if tensor.dtype == torch.bool:
463
+ onnx_inputs[name] = tensor.cpu().numpy().astype(bool)
464
+ elif tensor.dtype == torch.long:
465
+ onnx_inputs[name] = tensor.cpu().numpy().astype(np.int64)
466
+ else:
467
+ onnx_inputs[name] = tensor.cpu().numpy().astype(np.float32)
468
+
469
+ onnx_output = sess.run(["velocity"], onnx_inputs)[0]
470
+
471
+ # Compare
472
+ max_diff = np.abs(pytorch_output - onnx_output).max()
473
+ mean_diff = np.abs(pytorch_output - onnx_output).mean()
474
+
475
+ print(f" Max difference: {max_diff:.2e}")
476
+ print(f" Mean difference: {mean_diff:.2e}")
477
+
478
+ if max_diff < tolerance:
479
+ print(f" ✓ Verification passed (tolerance: {tolerance})")
480
+ return True
481
+ else:
482
+ print(f" ✗ Verification failed (tolerance: {tolerance})")
483
+ return False
484
+
485
+
486
+ def main():
487
+ parser = argparse.ArgumentParser(description="Export DiT Transformer to ONNX")
488
+ parser.add_argument(
489
+ "--model-id",
490
+ type=str,
491
+ default="facebook/sam-audio-small",
492
+ help="SAM Audio model ID from HuggingFace",
493
+ )
494
+ parser.add_argument(
495
+ "--output-dir",
496
+ type=str,
497
+ default="onnx_models",
498
+ help="Output directory for ONNX models",
499
+ )
500
+ parser.add_argument(
501
+ "--num-steps",
502
+ type=int,
503
+ default=16,
504
+ help="Number of ODE solver steps (default: 16)",
505
+ )
506
+ parser.add_argument(
507
+ "--opset",
508
+ type=int,
509
+ default=21,
510
+ help="ONNX opset version (default: 21)",
511
+ )
512
+ parser.add_argument(
513
+ "--device",
514
+ type=str,
515
+ default="cpu",
516
+ help="Device to use for export (default: cpu)",
517
+ )
518
+ parser.add_argument(
519
+ "--verify",
520
+ action="store_true",
521
+ help="Verify ONNX output matches PyTorch",
522
+ )
523
+ parser.add_argument(
524
+ "--tolerance",
525
+ type=float,
526
+ default=1e-3,
527
+ help="Tolerance for verification (default: 1e-3)",
528
+ )
529
+ parser.add_argument(
530
+ "--fp16",
531
+ action="store_true",
532
+ help="Export model in FP16 precision (half the size)",
533
+ )
534
+
535
+ args = parser.parse_args()
536
+
537
+ # Create output directory
538
+ os.makedirs(args.output_dir, exist_ok=True)
539
+
540
+ # Load components
541
+ single_step, config = load_sam_audio_components(args.model_id, args.device)
542
+
543
+ print(f"\nDiT Configuration:")
544
+ print(f" Model: {args.model_id}")
545
+ print(f" ODE steps: {args.num_steps}")
546
+ print(f" Step size: {1.0/args.num_steps:.4f}")
547
+
548
+ # Export single-step model
549
+ single_step_path = os.path.join(args.output_dir, "dit_single_step.onnx")
550
+ export_dit_single_step(
551
+ single_step,
552
+ single_step_path,
553
+ opset_version=args.opset,
554
+ device=args.device,
555
+ fp16=args.fp16,
556
+ )
557
+
558
+ if args.fp16:
559
+ print(f" ✓ Model exported in FP16 precision")
560
+
561
+ # Verify single-step
562
+ if args.verify:
563
+ verify_dit_single_step(
564
+ single_step,
565
+ single_step_path,
566
+ device=args.device,
567
+ tolerance=args.tolerance,
568
+ )
569
+
570
+ print(f"\n✓ Export complete! Model saved to {args.output_dir}")
571
+
572
+
573
+ if __name__ == "__main__":
574
+ main()
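Only the single-step DiT graph is exported here, so the midpoint loop from UnrolledDiTWrapper has to be reproduced by the caller. A hedged sketch against an fp32 export with onnxruntime, using the input/output names and 16-step schedule defined above; the conditioning tensors are random placeholders that would normally come from the DACVAE, T5, and vision encoders:

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("onnx_models/dit_single_step.onnx", providers=["CPUExecutionProvider"])

B, T, L, TXT = 1, 25, 128, 77
cond = {
    "audio_features": np.random.randn(B, T, 2 * L).astype(np.float32),
    "text_features": np.random.randn(B, TXT, 768).astype(np.float32),
    "text_mask": np.ones((B, TXT), dtype=bool),
    "masked_video_features": np.zeros((B, 1024, T), dtype=np.float32),
    "anchor_ids": np.zeros((B, T), dtype=np.int64),
    "anchor_alignment": np.zeros((B, T), dtype=np.int64),
    "audio_pad_mask": np.ones((B, T), dtype=bool),
}

num_steps = 16
h = 1.0 / num_steps
y = np.random.randn(B, T, 2 * L).astype(np.float32)  # starting noise
t = np.zeros(B, dtype=np.float32)

for _ in range(num_steps):  # midpoint method, as in UnrolledDiTWrapper
    k1 = sess.run(["velocity"], {"noisy_audio": y, "time": t, **cond})[0]
    k2 = sess.run(["velocity"], {"noisy_audio": y + (h / 2) * k1, "time": t + h / 2, **cond})[0]
    y = y + h * k2
    t = t + h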
onnx_export/export_peaframe.py ADDED
@@ -0,0 +1,288 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Export PE-A-Frame (Perception Encoder Audio Frame) span predictor to ONNX.
4
+
5
+ The PE-A-Frame model is used for automatic anchor detection in SAM Audio.
6
+ It analyzes audio features and predicts which segments correspond to the
7
+ target audio source.
8
+
9
+ Usage:
10
+ python -m onnx_export.export_peaframe --output-dir onnx_models --verify
11
+ """
12
+
13
+ import os
14
+ import argparse
15
+ import torch
16
+ import torch.nn as nn
17
+ from typing import Optional
18
+
19
+
20
+ class PEAFrameWrapper(nn.Module):
21
+ """
22
+ Wrapper for PE-A-Frame model for ONNX export.
23
+
24
+ Exposes the forward pass that takes audio features and returns
25
+ frame-level predictions.
26
+ """
27
+
28
+ def __init__(self, model: nn.Module):
29
+ super().__init__()
30
+ self.model = model
31
+
32
+ def forward(
33
+ self,
34
+ audio_features: torch.Tensor,
35
+ audio_mask: Optional[torch.Tensor] = None,
36
+ ) -> torch.Tensor:
37
+ """
38
+ Forward pass for span prediction.
39
+
40
+ Args:
41
+ audio_features: Audio features [batch, seq_len, hidden_dim]
42
+ audio_mask: Optional attention mask [batch, seq_len]
43
+
44
+ Returns:
45
+ Frame-level predictions [batch, seq_len, num_classes]
46
+ """
47
+ return self.model(audio_features, attention_mask=audio_mask)
48
+
49
+
50
+ def load_peaframe_model(config_name: str = "pe-a-frame-large", device: str = "cpu"):
51
+ """Load the PE-A-Frame model from perception_models."""
52
+ from core.audio_visual_encoder.pe import PEAudioFrame
53
+
54
+ print(f"Loading PE-A-Frame model: {config_name}...")
55
+ model = PEAudioFrame.from_config(config_name, pretrained=True)
56
+ model = model.eval().to(device)
57
+
58
+ num_params = sum(p.numel() for p in model.parameters())
59
+ print(f" ✓ Model loaded: {num_params:,} parameters")
60
+
61
+ return model
62
+
63
+
64
+ def get_tokenizer(model):
65
+ """Get the text tokenizer from the model config."""
66
+ from transformers import AutoTokenizer
67
+
68
+ text_model_name = model.config.text_model._name_or_path
69
+ return AutoTokenizer.from_pretrained(text_model_name)
70
+
71
+
72
+ def create_sample_inputs(model, batch_size: int = 1, device: str = "cpu"):
73
+ """Create sample inputs for tracing."""
74
+ tokenizer = get_tokenizer(model)
75
+
76
+ # Sample text query
77
+ text = "a person speaking"
78
+ tokens = tokenizer(
79
+ [text] * batch_size,
80
+ return_tensors="pt",
81
+ padding=True,
82
+ truncation=True,
83
+ max_length=77,
84
+ )
85
+
86
+ # Sample audio (10 seconds at 16kHz)
87
+ # DAC encoder expects (batch, channels, samples) format
88
+ sample_rate = 16000
89
+ audio_len = sample_rate * 10
90
+ audio = torch.randn(batch_size, 1, audio_len, device=device) # Added channel dimension
91
+
92
+ return {
93
+ "input_ids": tokens["input_ids"].to(device),
94
+ "attention_mask": tokens["attention_mask"].to(device),
95
+ "input_values": audio,
96
+ }
97
+
98
+
99
+ def export_peaframe(
100
+ model: nn.Module,
101
+ output_path: str,
102
+ opset_version: int = 21,
103
+ device: str = "cpu",
104
+ ):
105
+ """Export PE-A-Frame to ONNX."""
106
+ import onnx
107
+
108
+ print(f"Exporting PE-A-Frame to {output_path}...")
109
+
110
+ sample_inputs = create_sample_inputs(model, device=device)
111
+
112
+ # Put model in eval mode
113
+ model = model.eval()
114
+
115
+ # Test forward pass first
116
+ with torch.no_grad():
117
+ try:
118
+ output = model(
119
+ input_ids=sample_inputs["input_ids"],
120
+ input_values=sample_inputs["input_values"],
121
+ attention_mask=sample_inputs["attention_mask"],
122
+ return_spans=False, # Disable span return for ONNX (list output)
123
+ )
124
+ print(f" Test forward pass: audio_embeds shape = {output.audio_embeds.shape}")
125
+ print(f" Test forward pass: text_embeds shape = {output.text_embeds.shape}")
126
+ except Exception as e:
127
+ print(f" Forward pass failed: {e}")
128
+ raise
129
+
130
+ # Create a wrapper that returns just the audio embeddings for simpler ONNX
131
+ class PEAFrameONNXWrapper(nn.Module):
132
+ def __init__(self, model):
133
+ super().__init__()
134
+ self.model = model
135
+
136
+ def forward(self, input_ids, input_values, attention_mask):
137
+ output = self.model(
138
+ input_ids=input_ids,
139
+ input_values=input_values,
140
+ attention_mask=attention_mask,
141
+ return_spans=False,
142
+ )
143
+ return output.audio_embeds, output.text_embeds
144
+
145
+ wrapper = PEAFrameONNXWrapper(model)
146
+ wrapper.eval()
147
+
148
+ torch.onnx.export(
149
+ wrapper,
150
+ (sample_inputs["input_ids"], sample_inputs["input_values"], sample_inputs["attention_mask"]),
151
+ output_path,
152
+ input_names=["input_ids", "input_values", "attention_mask"],
153
+ output_names=["audio_embeds", "text_embeds"],
154
+ dynamic_axes={
155
+ "input_ids": {0: "batch_size", 1: "seq_len"},
156
+ "input_values": {0: "batch_size", 2: "audio_len"},
157
+ "attention_mask": {0: "batch_size", 1: "seq_len"},
158
+ "audio_embeds": {0: "batch_size", 1: "num_frames"},
159
+ "text_embeds": {0: "batch_size"},
160
+ },
161
+ opset_version=opset_version,
162
+ do_constant_folding=True,
163
+ external_data=True,
164
+ )
165
+
166
+ print(" ✓ PE-A-Frame exported successfully")
167
+
168
+ # Load without external data to avoid OOM - we just need to validate structure
169
+ onnx_model = onnx.load(output_path, load_external_data=False)
170
+ onnx.checker.check_model(onnx_model, full_check=False)
171
+ print(" ✓ ONNX model validation passed")
172
+
173
+ return True
174
+
175
+
176
+ def verify_peaframe(
177
+ model: nn.Module,
178
+ onnx_path: str,
179
+ device: str = "cpu",
180
+ tolerance: float = 1e-3,
181
+ ) -> bool:
182
+ """Verify ONNX output matches PyTorch."""
183
+ import onnxruntime as ort
184
+ import numpy as np
185
+
186
+ print("Verifying PE-A-Frame output...")
187
+
188
+ sample_inputs = create_sample_inputs(model, device=device)
189
+
190
+ # PyTorch output
191
+ model = model.eval()
192
+ with torch.no_grad():
193
+ pytorch_output = model(
194
+ input_ids=sample_inputs["input_ids"],
195
+ input_values=sample_inputs["input_values"],
196
+ attention_mask=sample_inputs["attention_mask"],
197
+ return_spans=False,
198
+ )
199
+ pytorch_audio_embeds = pytorch_output.audio_embeds.cpu().numpy()
200
+ pytorch_text_embeds = pytorch_output.text_embeds.cpu().numpy()
201
+
202
+ # ONNX Runtime output
203
+ sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
204
+
205
+ onnx_inputs = {
206
+ "input_ids": sample_inputs["input_ids"].cpu().numpy().astype(np.int64),
207
+ "input_values": sample_inputs["input_values"].cpu().numpy().astype(np.float32),
208
+ "attention_mask": sample_inputs["attention_mask"].cpu().numpy().astype(np.int64),
209
+ }
210
+
211
+ onnx_outputs = sess.run(["audio_embeds", "text_embeds"], onnx_inputs)
212
+ onnx_audio_embeds = onnx_outputs[0]
213
+ onnx_text_embeds = onnx_outputs[1]
214
+
215
+ # Compare
216
+ audio_max_diff = np.abs(pytorch_audio_embeds - onnx_audio_embeds).max()
217
+ text_max_diff = np.abs(pytorch_text_embeds - onnx_text_embeds).max()
218
+
219
+ print(f" Audio embeds max diff: {audio_max_diff:.2e}")
220
+ print(f" Text embeds max diff: {text_max_diff:.2e}")
221
+
222
+ max_diff = max(audio_max_diff, text_max_diff)
223
+ if max_diff < tolerance:
224
+ print(f" ✓ Verification passed (tolerance: {tolerance})")
225
+ return True
226
+ else:
227
+ print(f" ✗ Verification failed (tolerance: {tolerance})")
228
+ return False
229
+
230
+
231
+ def main():
232
+ parser = argparse.ArgumentParser(description="Export PE-A-Frame to ONNX")
233
+ parser.add_argument(
234
+ "--config",
235
+ type=str,
236
+ default="pe-a-frame-large",
237
+ help="PE-A-Frame config name",
238
+ )
239
+ parser.add_argument(
240
+ "--output-dir",
241
+ type=str,
242
+ default="onnx_models",
243
+ help="Output directory for ONNX models",
244
+ )
245
+ parser.add_argument(
246
+ "--opset",
247
+ type=int,
248
+ default=18,
249
+ help="ONNX opset version",
250
+ )
251
+ parser.add_argument(
252
+ "--device",
253
+ type=str,
254
+ default="cpu",
255
+ help="Device to use",
256
+ )
257
+ parser.add_argument(
258
+ "--verify",
259
+ action="store_true",
260
+ help="Verify ONNX output",
261
+ )
262
+ parser.add_argument(
263
+ "--tolerance",
264
+ type=float,
265
+ default=1e-3,
266
+ help="Verification tolerance",
267
+ )
268
+
269
+ args = parser.parse_args()
270
+
271
+ os.makedirs(args.output_dir, exist_ok=True)
272
+
273
+ # Load model
274
+ model = load_peaframe_model(args.config, args.device)
275
+
276
+ # Export
277
+ output_path = os.path.join(args.output_dir, "peaframe.onnx")
278
+ export_peaframe(model, output_path, args.opset, args.device)
279
+
280
+ # Verify
281
+ if args.verify:
282
+ verify_peaframe(model, output_path, args.device, args.tolerance)
283
+
284
+ print(f"\n✓ Export complete! Model saved to {output_path}")
285
+
286
+
287
+ if __name__ == "__main__":
288
+ main()
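
For reference, a minimal sketch (not part of this commit) of how the exported peaframe.onnx could be driven from ONNX Runtime. The input and output names come from the export call above; the token IDs, sequence length, and audio length below are placeholder assumptions rather than values taken from the repo.

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("onnx_models/peaframe.onnx", providers=["CPUExecutionProvider"])

batch, seq_len, audio_len = 1, 16, 160_000  # hypothetical sizes
feeds = {
    "input_ids": np.random.randint(0, 1000, (batch, seq_len), dtype=np.int64),  # placeholder tokens
    "attention_mask": np.ones((batch, seq_len), dtype=np.int64),
    "input_values": np.random.randn(batch, audio_len).astype(np.float32),       # placeholder audio
}
audio_embeds, text_embeds = sess.run(["audio_embeds", "text_embeds"], feeds)
print(audio_embeds.shape, text_embeds.shape)

A real caller would tokenize the prompt and preprocess the audio exactly as create_sample_inputs does during export.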
onnx_export/export_t5.py ADDED
@@ -0,0 +1,315 @@
+#!/usr/bin/env python3
+"""
+Export T5 Text Encoder to ONNX format.
+
+The T5 encoder takes tokenized input_ids and attention_mask, and produces
+hidden states. For SAM Audio inference, the output hidden states and attention
+mask are used as conditioning for the DiT transformer.
+
+Usage:
+    python -m onnx_export.export_t5 --output-dir onnx_models --verify
+"""
+
+import os
+import argparse
+import torch
+import torch.nn as nn
+
+
+class T5EncoderWrapper(nn.Module):
+    """
+    Wrapper for T5EncoderModel that provides a clean interface for ONNX export.
+
+    The wrapper takes tokenized inputs (input_ids, attention_mask) and returns
+    the last hidden state. This matches how SAMAudio's T5TextEncoder uses the model.
+    """
+
+    def __init__(self, t5_model, max_length: int = 77):
+        super().__init__()
+        self.model = t5_model
+        self.max_length = max_length
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Args:
+            input_ids: Tokenized input IDs, shape (batch, seq_len)
+            attention_mask: Attention mask, shape (batch, seq_len)
+
+        Returns:
+            hidden_states: T5 encoder output, shape (batch, seq_len, hidden_dim)
+        """
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            output_hidden_states=True,
+        )
+        return outputs.last_hidden_state
+
+
+def load_t5_encoder(model_name: str = "google-t5/t5-base", device: str = "cuda"):
+    """
+    Load T5 encoder model and tokenizer.
+
+    SAM Audio's DiT was trained with T5-base (768-dim) text features.
+    """
+    from transformers import T5EncoderModel, AutoTokenizer
+
+    print(f"Loading T5 encoder: {model_name}...")
+
+    model = T5EncoderModel.from_pretrained(model_name)
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+    model = model.eval().to(device)
+
+    return model, tokenizer
+
+
+def export_t5_encoder(
+    t5_model,
+    tokenizer,
+    output_path: str,
+    opset_version: int = 21,
+    max_length: int = 77,
+    device: str = "cuda",
+):
+    """Export T5 encoder to ONNX format."""
+    import onnx
+
+    print(f"Exporting T5 encoder to {output_path}...")
+
+    wrapper = T5EncoderWrapper(t5_model, max_length=max_length).eval().to(device)
+
+    # Create sample input
+    sample_text = ["A dog barking loudly in the background"]
+    encoded = tokenizer(
+        sample_text,
+        truncation=True,
+        max_length=max_length,
+        padding="max_length",  # Pad to max_length for consistent shape
+        return_tensors="pt",
+    )
+
+    sample_input_ids = encoded["input_ids"].to(device)
+    sample_attention_mask = encoded["attention_mask"].to(device)
+
+    # Export using torch.onnx.export
+    torch.onnx.export(
+        wrapper,
+        (sample_input_ids, sample_attention_mask),
+        output_path,
+        input_names=["input_ids", "attention_mask"],
+        output_names=["hidden_states"],
+        dynamic_axes={
+            "input_ids": {0: "batch_size", 1: "sequence_length"},
+            "attention_mask": {0: "batch_size", 1: "sequence_length"},
+            "hidden_states": {0: "batch_size", 1: "sequence_length"},
+        },
+        opset_version=opset_version,
+        do_constant_folding=True,
+        dynamo=True,
+        external_data=True,  # T5-large is ~1GB
+    )
+
+    print("  ✓ T5 encoder exported successfully")
+
+    # Load without external data to avoid OOM - we just need to validate structure
+    model = onnx.load(output_path, load_external_data=False)
+    onnx.checker.check_model(model, full_check=False)
+    print("  ✓ ONNX model validation passed")
+
+    return True
+
+
+def verify_t5_encoder(
+    t5_model,
+    tokenizer,
+    onnx_path: str,
+    max_length: int = 77,
+    device: str = "cuda",
+    tolerance: float = 1e-4,
+) -> bool:
+    """Verify ONNX T5 encoder output matches PyTorch."""
+    import onnxruntime as ort
+    import numpy as np
+
+    print("Verifying T5 encoder output...")
+
+    wrapper = T5EncoderWrapper(t5_model, max_length=max_length).eval().to(device)
+
+    # Test with multiple texts
+    test_texts = [
+        "A dog barking in the distance",
+        "Piano music playing softly",
+        "Rain falling on a rooftop",
+    ]
+
+    for text in test_texts:
+        # Tokenize
+        encoded = tokenizer(
+            [text],
+            truncation=True,
+            max_length=max_length,
+            padding="max_length",
+            return_tensors="pt",
+        )
+
+        input_ids = encoded["input_ids"].to(device)
+        attention_mask = encoded["attention_mask"].to(device)
+
+        # PyTorch output
+        with torch.no_grad():
+            pytorch_output = wrapper(input_ids, attention_mask).cpu().numpy()
+
+        # ONNX Runtime output
+        sess = ort.InferenceSession(onnx_path, providers=["CUDAExecutionProvider"])
+        onnx_output = sess.run(
+            ["hidden_states"],
+            {
+                "input_ids": input_ids.cpu().numpy().astype(np.int64),
+                "attention_mask": attention_mask.cpu().numpy().astype(np.int64),
+            }
+        )[0]
+
+        # Compare
+        max_diff = np.abs(pytorch_output - onnx_output).max()
+        mean_diff = np.abs(pytorch_output - onnx_output).mean()
+
+        print(f"  Text: '{text[:30]}...'")
+        print(f"  Max diff: {max_diff:.2e}, Mean diff: {mean_diff:.2e}")
+
+        if max_diff > tolerance:
+            print(f"  ✗ Verification failed for text: {text}")
+            return False
+
+    print(f"  ✓ Verification passed (tolerance: {tolerance})")
+    return True
+
+
+def save_tokenizer_config(tokenizer, output_dir: str):
+    """
+    Save tokenizer vocabulary and configuration for runtime use.
+
+    This allows the ONNX runtime to perform tokenization without
+    needing the full transformers library.
+    """
+    import json
+
+    tokenizer_dir = os.path.join(output_dir, "tokenizer")
+    tokenizer.save_pretrained(tokenizer_dir)
+
+    # Also save a simple config for ONNX.js usage
+    config = {
+        "model_name": tokenizer.name_or_path,
+        "max_length": 77,
+        "vocab_size": tokenizer.vocab_size,
+        "pad_token_id": tokenizer.pad_token_id,
+        "eos_token_id": tokenizer.eos_token_id,
+    }
+
+    config_path = os.path.join(output_dir, "tokenizer_config.json")
+    with open(config_path, "w") as f:
+        json.dump(config, f, indent=2)
+
+    print(f"  ✓ Tokenizer saved to {tokenizer_dir}")
+    return tokenizer_dir
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Export T5 Text Encoder to ONNX")
+    parser.add_argument(
+        "--model-name",
+        type=str,
+        default="google-t5/t5-base",
+        help="T5 model name from HuggingFace (default: google-t5/t5-base)",
+    )
+    parser.add_argument(
+        "--output-dir",
+        type=str,
+        default="onnx_models",
+        help="Output directory for ONNX models",
+    )
+    parser.add_argument(
+        "--max-length",
+        type=int,
+        default=77,
+        help="Maximum token sequence length (default: 77)",
+    )
+    parser.add_argument(
+        "--opset",
+        type=int,
+        default=18,
+        help="ONNX opset version (default: 18)",
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda",
+        help="Device to use for export (default: cuda)",
+    )
+    parser.add_argument(
+        "--verify",
+        action="store_true",
+        help="Verify ONNX output matches PyTorch",
+    )
+    parser.add_argument(
+        "--tolerance",
+        type=float,
+        default=1e-4,
+        help="Tolerance for verification (default: 1e-4)",
+    )
+    parser.add_argument(
+        "--save-tokenizer",
+        action="store_true",
+        default=True,
+        help="Save tokenizer for runtime use (default: True)",
+    )
+
+    args = parser.parse_args()
+
+    # Create output directory
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    # Load T5
+    t5_model, tokenizer = load_t5_encoder(args.model_name, args.device)
+
+    print(f"\nT5 Configuration:")
+    print(f"  Model: {args.model_name}")
+    print(f"  Hidden size: {t5_model.config.d_model}")
+    print(f"  Max length: {args.max_length}")
+    print(f"  Vocab size: {tokenizer.vocab_size}")
+
+    # Export
+    encoder_path = os.path.join(args.output_dir, "t5_encoder.onnx")
+    export_t5_encoder(
+        t5_model,
+        tokenizer,
+        encoder_path,
+        opset_version=args.opset,
+        max_length=args.max_length,
+        device=args.device,
+    )
+
+    # Save tokenizer
+    if args.save_tokenizer:
+        save_tokenizer_config(tokenizer, args.output_dir)
+
+    # Verify
+    if args.verify:
+        verify_t5_encoder(
+            t5_model,
+            tokenizer,
+            encoder_path,
+            max_length=args.max_length,
+            device=args.device,
+            tolerance=args.tolerance,
+        )
+
+    print(f"\n✓ Export complete! Model saved to {encoder_path}")
+
+
+if __name__ == "__main__":
+    main()
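
As a usage note, a hedged sketch of consuming the exported encoder at runtime: tokenize with the tokenizer directory written by save_tokenizer_config, then feed t5_encoder.onnx through ONNX Runtime. The paths below are assumptions based on the script's defaults.

import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("onnx_models/tokenizer")
enc = tokenizer(
    ["Rain falling on a rooftop"],
    truncation=True, max_length=77, padding="max_length", return_tensors="np",
)

sess = ort.InferenceSession("onnx_models/t5_encoder.onnx", providers=["CPUExecutionProvider"])
(hidden_states,) = sess.run(
    ["hidden_states"],
    {
        "input_ids": enc["input_ids"].astype(np.int64),
        "attention_mask": enc["attention_mask"].astype(np.int64),
    },
)
print(hidden_states.shape)  # (1, 77, 768) for t5-base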
onnx_export/export_vision.py ADDED
@@ -0,0 +1,113 @@
+#!/usr/bin/env python3
+import os
+import torch
+import torch.nn as nn
+import onnx
+from sam_audio.model.vision_encoder import PerceptionEncoder
+from onnx_export.standalone_config import PerceptionEncoderConfig
+
+
+class VisionEncoderWrapper(nn.Module):
+    """
+    Wrapper for the Vision Encoder (CLIP visual backbone).
+    """
+
+    def __init__(self, vision_encoder):
+        super().__init__()
+        self.model = vision_encoder.model
+        self.normalize = vision_encoder.normalize_feature
+
+    def forward(self, x):
+        # x: (N, 3, H, W) where N is number of frames
+        # returns: (N, 1024) features
+        return self.model.encode_image(x, normalize=self.normalize)
+
+
+def export_vision_encoder(model_id="facebook/sam-audio-small", output_dir="onnx_models", device="cpu"):
+    """Export the vision encoder to ONNX."""
+    from transformers import AutoConfig
+    from huggingface_hub import hf_hub_download
+
+    print(f"Loading Vision Encoder from {model_id}...")
+
+    print("Fetching config...")
+    cfg_hf = AutoConfig.from_pretrained(model_id)
+    cfg_dict = cfg_hf.to_dict()
+
+    # Extract vision encoder config
+    v_cfg_dict = cfg_dict.get("vision_encoder", {})
+    v_cfg = PerceptionEncoderConfig(**v_cfg_dict)
+
+    print(f"Initializing PerceptionEncoder with name: {v_cfg.name}...")
+    vision_encoder = PerceptionEncoder(v_cfg)
+
+    # Load weights from checkpoint
+    print("Loading weights from SAM Audio checkpoint...")
+    checkpoint_path = hf_hub_download(repo_id=model_id, filename="checkpoint.pt")
+    state_dict = torch.load(checkpoint_path, map_location="cpu", mmap=True)
+
+    # Filter vision encoder weights
+    vision_state = {}
+    prefix = "vision_encoder."
+    for key, value in state_dict.items():
+        if key.startswith(prefix):
+            new_key = key[len(prefix):]
+            vision_state[new_key] = value
+
+    if vision_state:
+        print(f"  Loading {len(vision_state)} tensors into vision encoder...")
+        vision_encoder.load_state_dict(vision_state)
+        print("  ✓ Vision encoder weights loaded.")
+    else:
+        print("  WARNING: No 'vision_encoder' weights found in checkpoint. Using base weights.")
+
+    image_size = vision_encoder.image_size
+    print(f"  Image size: {image_size}")
+
+    wrapper = VisionEncoderWrapper(vision_encoder).eval().to(device)
+
+    # Create dummy input on device
+    dummy_input = torch.randn(1, 3, image_size, image_size, device=device)
+
+    output_path = os.path.join(output_dir, "vision_encoder.onnx")
+    os.makedirs(output_dir, exist_ok=True)
+
+    print(f"Exporting to {output_path}...")
+    input_names = ["video_frames"]
+    output_names = ["vision_features"]
+    opset_version = 18  # Use opset 18 for better CUDA compatibility
+    torch.onnx.export(
+        wrapper,
+        dummy_input,
+        output_path,
+        input_names=input_names,
+        output_names=output_names,
+        dynamic_axes={
+            "video_frames": {0: "num_frames"},
+            "vision_features": {0: "num_frames"},
+        },
+        opset_version=opset_version,
+        do_constant_folding=True,
+        dynamo=True,
+        external_data=True,
+    )
+
+    # Check if data was saved separately
+    data_path = output_path + ".data"
+    if os.path.exists(data_path):
+        print(f"  Large model detected, weights saved to {data_path}")
+
+    print("✓ Vision encoder export complete!")
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", type=str, default="facebook/sam-audio-small")
+    parser.add_argument("--output", type=str, default="onnx_models")
+    parser.add_argument("--device", type=str, default="cpu", help="Device for export (cpu or cuda)")
+    args = parser.parse_args()
+
+    export_vision_encoder(args.model, args.output, device=args.device)
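
A small sketch of calling the exported vision encoder (assumed path; real frames would need the same resize and normalization the PerceptionEncoder preprocessing applies, which is not shown here):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("onnx_models/vision_encoder.onnx", providers=["CPUExecutionProvider"])
image_size = sess.get_inputs()[0].shape[-1]  # H/W are static in this export; only the frame axis is dynamic
frames = np.random.rand(4, 3, image_size, image_size).astype(np.float32)  # 4 placeholder frames
(features,) = sess.run(["vision_features"], {"video_frames": frames})
print(features.shape)  # (num_frames, feature_dim)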
onnx_export/quantize_large_model.py ADDED
@@ -0,0 +1,115 @@
+#!/usr/bin/env python3
+"""
+Memory-efficient FP16 conversion for large ONNX models with external data.
+
+The model graph is loaded once together with its external weights, all FP32
+initializers are converted to FP16, and the result is written back out with
+the weights in a single external data file.
+
+Usage:
+    python -m onnx_export.quantize_large_model \
+        --input onnx_models_large/dit_single_step.onnx \
+        --output onnx_models_large_fp16/dit_single_step.onnx
+"""
+
+import os
+import argparse
+import numpy as np
+from pathlib import Path
+
+
+def convert_tensor_to_fp16(tensor_data: np.ndarray) -> np.ndarray:
+    """Convert tensor data to FP16 if it's FP32."""
+    if tensor_data.dtype == np.float32:
+        return tensor_data.astype(np.float16)
+    return tensor_data
+
+
+def quantize_large_model_fp16(input_path: str, output_path: str):
+    """
+    Convert a large ONNX model to FP16 using onnxruntime.transformers.
+
+    This properly updates both tensor data AND graph type annotations.
+    """
+    import onnx
+    from onnxruntime.transformers import float16
+    import gc
+
+    input_dir = os.path.dirname(os.path.abspath(input_path))
+    output_dir = os.path.dirname(os.path.abspath(output_path))
+    os.makedirs(output_dir, exist_ok=True)
+
+    print(f"Loading model from {input_path}...")
+    print("  (This may take a while for large models)")
+
+    # Load model with external data
+    model = onnx.load(input_path, load_external_data=False)
+    onnx.load_external_data_for_model(model, input_dir)
+
+    original_size = sum(
+        np.prod(tensor.dims) * 4  # Assuming FP32
+        for tensor in model.graph.initializer
+        if tensor.data_type == onnx.TensorProto.FLOAT
+    )
+
+    print(f"  Loaded model ({original_size / 1e9:.2f} GB of FP32 weights)")
+
+    print("Converting to FP16...")
+    model_fp16 = float16.convert_float_to_float16(
+        model,
+        keep_io_types=True,  # Keep inputs/outputs as FP32 for compatibility
+        disable_shape_infer=True,  # Skip shape inference for speed
+    )
+
+    # Free original model
+    del model
+    gc.collect()
+
+    # External data file for output
+    output_data_filename = os.path.basename(output_path) + ".data"
+
+    print(f"Saving to {output_path}...")
+    onnx.save(
+        model_fp16,
+        output_path,
+        save_as_external_data=True,
+        all_tensors_to_one_file=True,
+        location=output_data_filename,
+        size_threshold=0,  # Save all tensors externally
+    )
+
+    # Report results
+    output_data_path = os.path.join(output_dir, output_data_filename)
+    if os.path.exists(output_path) and os.path.exists(output_data_path):
+        output_size = os.path.getsize(output_data_path)
+        print("✓ Model saved successfully!")
+        print(f"  Graph: {os.path.getsize(output_path) / 1e6:.2f} MB")
+        print(f"  Weights: {output_size / 1e9:.2f} GB")
+        print(f"  Reduction: {(1 - output_size / original_size) * 100:.1f}%")
+    else:
+        raise RuntimeError("Output files were not created properly")
+
+    return True
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Memory-efficient FP16 conversion for large ONNX models")
+    parser.add_argument(
+        "--input",
+        type=str,
+        required=True,
+        help="Input ONNX model path",
+    )
+    parser.add_argument(
+        "--output",
+        type=str,
+        required=True,
+        help="Output ONNX model path",
+    )
+
+    args = parser.parse_args()
+
+    quantize_large_model_fp16(args.input, args.output)
+
+
+if __name__ == "__main__":
+    main()
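
Because keep_io_types=True is used, the converted graph computes in FP16 internally but still exposes FP32 inputs and outputs, so existing callers do not need to change dtypes. A quick sanity check along those lines (output path assumed from the usage example above):

import onnxruntime as ort

sess = ort.InferenceSession(
    "onnx_models_large_fp16/dit_single_step.onnx",
    providers=["CPUExecutionProvider"],
)
for inp in sess.get_inputs():
    print(inp.name, inp.type)  # expected: tensor(float), not tensor(float16)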
onnx_export/quantize_models.py ADDED
@@ -0,0 +1,286 @@
+#!/usr/bin/env python3
+"""
+Quantize ONNX models for SAM Audio to reduce size and improve inference speed.
+
+Supports:
+- FP16 quantization (recommended for audio models)
+- INT8 dynamic quantization (best size reduction)
+- INT8 static quantization (requires calibration data)
+
+Usage:
+    # Quantize all models to FP16
+    python -m onnx_export.quantize_models --model-dir onnx_models --output-dir onnx_models_fp16 --mode fp16
+
+    # Quantize to INT8 (dynamic)
+    python -m onnx_export.quantize_models --model-dir onnx_models --output-dir onnx_models_int8 --mode int8
+
+    # Quantize a specific model
+    python -m onnx_export.quantize_models --model-dir onnx_models --output-dir onnx_models_fp16 --mode fp16 --models dit
+"""
+
+import os
+import argparse
+import shutil
+from pathlib import Path
+
+
+def get_model_files(model_dir: str) -> dict:
+    """Find all ONNX model files in a directory."""
+    models = {}
+    model_names = {
+        "dit_single_step": "DiT Denoiser",
+        "dacvae_encoder": "DACVAE Encoder",
+        "dacvae_decoder": "DACVAE Decoder",
+        "t5_encoder": "T5 Text Encoder",
+        "vision_encoder": "Vision Encoder",
+    }
+
+    for name, display_name in model_names.items():
+        onnx_path = os.path.join(model_dir, f"{name}.onnx")
+        if os.path.exists(onnx_path):
+            models[name] = {
+                "path": onnx_path,
+                "display_name": display_name,
+                "has_external_data": os.path.exists(f"{onnx_path}.data"),
+            }
+
+    return models
+
+
+def quantize_fp16(input_path: str, output_path: str, has_external_data: bool = False):
+    """Convert a model to FP16 precision."""
+    import onnx
+    from onnxruntime.transformers import float16
+
+    print("  Loading model...")
+
+    # For models with external data, load everything into memory
+    if has_external_data:
+        model_dir = os.path.dirname(os.path.abspath(input_path))
+        model = onnx.load(input_path, load_external_data=False)
+        onnx.load_external_data_for_model(model, model_dir)
+    else:
+        model = onnx.load(input_path)
+
+    print("  Converting to FP16...")
+    model_fp16 = float16.convert_float_to_float16(
+        model,
+        keep_io_types=True,  # Keep inputs/outputs as FP32 for compatibility
+        disable_shape_infer=True,  # Skip shape inference (faster)
+    )
+
+    # Free original model memory
+    del model
+    import gc
+    gc.collect()
+
+    # Estimate the size of the FP16 model by serializing it; only fall back to
+    # external data if it exceeds the 2GB protobuf limit.
+    print(f"  Saving to {output_path}...")
+
+    try:
+        # Serialize to check size
+        model_bytes = model_fp16.SerializeToString()
+        model_size = len(model_bytes)
+
+        if model_size < 2 * 1024 * 1024 * 1024:  # Under 2GB
+            # Save as a self-contained file (no external data)
+            with open(output_path, 'wb') as f:
+                f.write(model_bytes)
+            print(f"  Saved as self-contained ONNX ({model_size / 1e6:.1f} MB)")
+        else:
+            # Too large, need external data
+            onnx.save(
+                model_fp16,
+                output_path,
+                save_as_external_data=True,
+                all_tensors_to_one_file=True,
+                location=os.path.basename(output_path) + ".data",
+                size_threshold=0,
+            )
+            print(f"  Saved with external data ({model_size / 1e9:.2f} GB)")
+    except Exception:
+        # If serialization fails (model too large for a single protobuf), use external data
+        print("  Model too large for in-memory serialization, saving with external data...")
+        onnx.save(
+            model_fp16,
+            output_path,
+            save_as_external_data=True,
+            all_tensors_to_one_file=True,
+            location=os.path.basename(output_path) + ".data",
+            size_threshold=0,
+        )
+
+    return True
+
+
+def quantize_int8_dynamic(input_path: str, output_path: str, has_external_data: bool = False):
+    """Quantize a model to INT8 using dynamic quantization."""
+    from onnxruntime.quantization import quantize_dynamic, QuantType
+    import onnx
+
+    print("  Loading model...")
+
+    # For models with external data, we need to load and re-save as a single file first
+    temp_path = None
+    if has_external_data:
+        model = onnx.load(input_path, load_external_data=True)
+        temp_path = input_path + ".temp.onnx"
+        onnx.save(model, temp_path)
+        input_path = temp_path
+
+    print("  Quantizing to INT8 (dynamic)...")
+
+    quantize_dynamic(
+        input_path,
+        output_path,
+        weight_type=QuantType.QInt8,
+        extra_options={
+            "EnableSubgraph": True,
+        }
+    )
+
+    # Clean up the temporary self-contained model
+    if temp_path is not None and os.path.exists(temp_path):
+        os.remove(temp_path)
+
+    return True
+
+
+def quantize_model(
+    name: str,
+    model_info: dict,
+    output_dir: str,
+    mode: str,
+) -> bool:
+    """Quantize a single model."""
+    input_path = model_info["path"]
+    output_path = os.path.join(output_dir, f"{name}.onnx")
+    has_external_data = model_info["has_external_data"]
+
+    print(f"\nQuantizing {model_info['display_name']}...")
+    print(f"  Input: {input_path}")
+    print(f"  Output: {output_path}")
+    print(f"  External data: {has_external_data}")
+
+    try:
+        if mode == "fp16":
+            success = quantize_fp16(input_path, output_path, has_external_data)
+        elif mode == "int8":
+            success = quantize_int8_dynamic(input_path, output_path, has_external_data)
+        else:
+            print(f"  ✗ Unknown quantization mode: {mode}")
+            return False
+
+        if success:
+            # Report size reduction
+            input_size = os.path.getsize(input_path)
+            if has_external_data:
+                input_size += os.path.getsize(input_path + ".data")
+
+            output_size = os.path.getsize(output_path)
+            if os.path.exists(output_path + ".data"):
+                output_size += os.path.getsize(output_path + ".data")
+
+            reduction = (1 - output_size / input_size) * 100
+            print(f"  ✓ Done! Size: {input_size / 1e9:.2f}GB → {output_size / 1e9:.2f}GB ({reduction:.1f}% reduction)")
+            return True
+
+    except Exception as e:
+        print(f"  ✗ Error: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
+
+    return False
+
+
+def copy_tokenizer(model_dir: str, output_dir: str):
+    """Copy tokenizer files to the output directory."""
+    tokenizer_dir = os.path.join(model_dir, "tokenizer")
+    tokenizer_config = os.path.join(model_dir, "tokenizer_config.json")
+
+    if os.path.exists(tokenizer_dir):
+        output_tokenizer_dir = os.path.join(output_dir, "tokenizer")
+        if not os.path.exists(output_tokenizer_dir):
+            shutil.copytree(tokenizer_dir, output_tokenizer_dir)
+            print("\n✓ Copied tokenizer directory")
+
+    if os.path.exists(tokenizer_config):
+        shutil.copy(tokenizer_config, os.path.join(output_dir, "tokenizer_config.json"))
+        print("✓ Copied tokenizer_config.json")
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Quantize ONNX models for SAM Audio")
+    parser.add_argument(
+        "--model-dir",
+        type=str,
+        default="onnx_models",
+        help="Directory containing ONNX models",
+    )
+    parser.add_argument(
+        "--output-dir",
+        type=str,
+        required=True,
+        help="Output directory for quantized models",
+    )
+    parser.add_argument(
+        "--mode",
+        type=str,
+        choices=["fp16", "int8"],
+        default="fp16",
+        help="Quantization mode: fp16 (recommended) or int8",
+    )
+    parser.add_argument(
+        "--models",
+        type=str,
+        nargs="+",
+        choices=["dit", "dacvae_encoder", "dacvae_decoder", "t5", "vision", "all"],
+        default=["all"],
+        help="Which models to quantize",
+    )
+
+    args = parser.parse_args()
+
+    # Create output directory
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    # Find models
+    models = get_model_files(args.model_dir)
+
+    if not models:
+        print(f"No ONNX models found in {args.model_dir}")
+        return
+
+    print(f"Found {len(models)} models in {args.model_dir}")
+    print(f"Quantization mode: {args.mode.upper()}")
+
+    # Filter models if specific ones were requested
+    if "all" not in args.models:
+        name_mapping = {
+            "dit": "dit_single_step",
+            "dacvae_encoder": "dacvae_encoder",
+            "dacvae_decoder": "dacvae_decoder",
+            "t5": "t5_encoder",
+            "vision": "vision_encoder",
+        }
+        selected = {name_mapping[m] for m in args.models if m in name_mapping}
+        models = {k: v for k, v in models.items() if k in selected}
+
+    # Quantize each model
+    success_count = 0
+    for name, model_info in models.items():
+        if quantize_model(name, model_info, args.output_dir, args.mode):
+            success_count += 1
+
+    # Copy tokenizer files
+    copy_tokenizer(args.model_dir, args.output_dir)
+
+    print(f"\n{'='*50}")
+    print(f"✓ Quantization complete! {success_count}/{len(models)} models processed")
+    print(f"  Output directory: {args.output_dir}")
+
+
+if __name__ == "__main__":
+    main()
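
After quantization, it can be worth spot-checking how far the FP16 outputs drift from the FP32 originals. A hedged sketch for the T5 encoder (directories assumed from the usage examples above; random token IDs stand in for real prompts):

import numpy as np
import onnxruntime as ort

feeds = {
    "input_ids": np.random.randint(0, 1000, (1, 77), dtype=np.int64),
    "attention_mask": np.ones((1, 77), dtype=np.int64),
}
ref = ort.InferenceSession("onnx_models/t5_encoder.onnx", providers=["CPUExecutionProvider"]).run(None, feeds)[0]
fp16 = ort.InferenceSession("onnx_models_fp16/t5_encoder.onnx", providers=["CPUExecutionProvider"]).run(None, feeds)[0]
print("max abs diff:", np.abs(ref - fp16).max())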