Commit 2967cdb · Parent: 96bf80c (committed by Luigi)

Restore complete zipvoice package with all source files

Files changed (48)
  1. zipvoice/__init__.py +7 -0
  2. zipvoice/bin/__init__.py +1 -0
  3. zipvoice/bin/compute_fbank.py +272 -0
  4. zipvoice/bin/generate_averaged_model.py +229 -0
  5. zipvoice/bin/infer_zipvoice.py +877 -0
  6. zipvoice/bin/infer_zipvoice_dialog.py +1286 -0
  7. zipvoice/bin/infer_zipvoice_onnx.py +924 -0
  8. zipvoice/bin/onnx_export.py +429 -0
  9. zipvoice/bin/prepare_dataset.py +274 -0
  10. zipvoice/bin/prepare_tokens.py +103 -0
  11. zipvoice/bin/train_zipvoice.py +1130 -0
  12. zipvoice/bin/train_zipvoice_dialog.py +980 -0
  13. zipvoice/bin/train_zipvoice_dialog_stereo.py +963 -0
  14. zipvoice/bin/train_zipvoice_distill.py +1158 -0
  15. zipvoice/dataset/datamodule.py +347 -0
  16. zipvoice/dataset/dataset.py +105 -0
  17. zipvoice/eval/models/ecapa_tdnn_wavllm.py +357 -0
  18. zipvoice/eval/models/ecapa_tdnn_wavlm.py +357 -0
  19. zipvoice/eval/models/utmos.py +354 -0
  20. zipvoice/eval/mos/utmos.py +174 -0
  21. zipvoice/eval/speaker_similarity/cpsim.py +411 -0
  22. zipvoice/eval/speaker_similarity/sim.py +229 -0
  23. zipvoice/eval/utils.py +62 -0
  24. zipvoice/eval/wer/dialog.py +493 -0
  25. zipvoice/eval/wer/hubert.py +285 -0
  26. zipvoice/eval/wer/seedtts.py +298 -0
  27. zipvoice/models/__init__.py +1 -0
  28. zipvoice/models/modules/__init__.py +1 -0
  29. zipvoice/models/modules/scaling.py +1590 -0
  30. zipvoice/models/modules/solver.py +281 -0
  31. zipvoice/models/modules/zipformer.py +1680 -0
  32. zipvoice/models/modules/zipformer_two_stream.py +264 -0
  33. zipvoice/models/zipvoice.py +534 -0
  34. zipvoice/models/zipvoice_dialog.py +358 -0
  35. zipvoice/models/zipvoice_distill.py +94 -0
  36. zipvoice/tokenizer/__init__.py +1 -0
  37. zipvoice/tokenizer/normalizer.py +170 -0
  38. zipvoice/tokenizer/tokenizer.py +648 -0
  39. zipvoice/utils/__init__.py +1 -0
  40. zipvoice/utils/checkpoint.py +570 -0
  41. zipvoice/utils/common.py +670 -0
  42. zipvoice/utils/diagnostics.py +723 -0
  43. zipvoice/utils/feature.py +120 -0
  44. zipvoice/utils/hooks.py +111 -0
  45. zipvoice/utils/infer.py +414 -0
  46. zipvoice/utils/lr_scheduler.py +245 -0
  47. zipvoice/utils/optim.py +868 -0
  48. zipvoice/utils/scaling_converter.py +105 -0
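The two inference scripts in this commit (zipvoice/bin/infer_zipvoice.py and zipvoice/bin/infer_zipvoice_dialog.py) read their batch inputs from a tab-separated test list, one utterance per line, in the column format described in their docstrings. A minimal sketch of assembling such a file is shown below; it is not part of the committed sources, and the wav names, paths, and texts are purely illustrative.

    # Build a test.tsv for zipvoice.bin.infer_zipvoice.
    # Columns: wav_name \t prompt_transcription \t prompt_wav \t text
    rows = [
        ("utt_0001", "I am a prompt.", "prompts/speaker_a.wav", "Hello from ZipVoice."),
        ("utt_0002", "I am a prompt.", "prompts/speaker_a.wav", "A second sentence to synthesize."),
    ]
    with open("test.tsv", "w", encoding="utf-8") as f:
        for wav_name, prompt_text, prompt_wav, text in rows:
            f.write(f"{wav_name}\t{prompt_text}\t{prompt_wav}\t{text}\n")

The resulting file is then passed to the scripts via --test-list test.tsv --res-dir results, as in the usage examples inside the scripts.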
zipvoice/__init__.py ADDED
@@ -0,0 +1,7 @@
+ import warnings
+
+ warnings.filterwarnings(
+     "ignore",
+     category=UserWarning,
+     message="pkg_resources is deprecated as an API.*",
+ )
zipvoice/bin/__init__.py ADDED
@@ -0,0 +1 @@
+ # ZipVoice bin package
zipvoice/bin/compute_fbank.py ADDED
@@ -0,0 +1,272 @@
+ #!/usr/bin/env python3
+ # Copyright 2024-2025 Xiaomi Corp. (authors: Wei Kang
+ #                                            Han Zhu)
+ #
+ # See ../../../../LICENSE for clarification regarding multiple authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ Usage:
+ python3 -m zipvoice.bin.compute_fbank \
+     --source-dir data/manifests \
+     --dest-dir data/fbank \
+     --dataset libritts \
+     --subset dev-other \
+     --sampling-rate 24000 \
+     --num-jobs 20
+
+ The input would be data/manifests/libritts_cuts_dev-other.jsonl.gz or
+ (libritts_supervisions_dev-other.jsonl.gz and libritts_recordings_dev-other.jsonl.gz)
+
+ The output would be data/fbank/libritts_cuts_dev-other.jsonl.gz
+ """
+
+
+ import argparse
+ import logging
+ from concurrent.futures import ProcessPoolExecutor as Pool
+ from pathlib import Path
+
+ import lhotse
+ import torch
+ from lhotse import CutSet, LilcomChunkyWriter, load_manifest_lazy
+
+ from zipvoice.utils.common import str2bool
+ from zipvoice.utils.feature import VocosFbank
+
+ # Torch's multithreaded behavior needs to be disabled, or
+ # it wastes a lot of CPU and slows things down.
+ # Do this outside of main() in case it needs to take effect
+ # even when we are not invoking the main (e.g. when spawning subprocesses).
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+
+ lhotse.set_audio_duration_mismatch_tolerance(0.1)
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument(
+         "--sampling-rate",
+         type=int,
+         default=24000,
+         help="The target sampling rate; the audio will be resampled to it.",
+     )
+
+     parser.add_argument(
+         "--type",
+         type=str,
+         default="vocos",
+         help="fbank type",
+     )
+
+     parser.add_argument(
+         "--dataset",
+         type=str,
+         help="Dataset name.",
+     )
+
+     parser.add_argument(
+         "--subset",
+         type=str,
+         help="The subset of the dataset.",
+     )
+
+     parser.add_argument(
+         "--source-dir",
+         type=str,
+         default="data/manifests",
+         help="The source directory of manifest files.",
+     )
+
+     parser.add_argument(
+         "--dest-dir",
+         type=str,
+         default="data/fbank",
+         help="The destination directory of manifest files.",
+     )
+
+     parser.add_argument(
+         "--split-cuts",
+         type=str2bool,
+         default=False,
+         help="Whether to use split cuts.",
+     )
+
+     parser.add_argument(
+         "--split-begin",
+         type=int,
+         help="Start idx of split cuts.",
+     )
+
+     parser.add_argument(
+         "--split-end",
+         type=int,
+         help="End idx of split cuts.",
+     )
+
+     parser.add_argument(
+         "--batch-duration",
+         type=int,
+         default=1000,
+         help="The batch duration when computing the features.",
+     )
+
+     parser.add_argument(
+         "--num-jobs",
+         type=int,
+         default=20,
+         help="The number of extractor workers.",
+     )
+
+     return parser.parse_args()
+
+
+ def compute_fbank_split_single(params, idx):
+     logging.info(
+         f"Computing features for {idx}-th split of "
+         f"{params.dataset} dataset {params.subset} subset"
+     )
+     lhotse.set_audio_duration_mismatch_tolerance(0.1)  # for emilia
+     src_dir = Path(params.source_dir)
+     output_dir = Path(params.dest_dir)
+
+     if not src_dir.exists():
+         logging.error(f"{src_dir} does not exist")
+         return
+
+     if not output_dir.exists():
+         output_dir.mkdir(parents=True, exist_ok=True)
+
+     num_digits = 8
+     if params.type == "vocos":
+         extractor = VocosFbank()
+     else:
+         raise NotImplementedError(f"{params.type} is not supported")
+
+     prefix = params.dataset
+     subset = params.subset
+     suffix = "jsonl.gz"
+
+     idx = f"{idx}".zfill(num_digits)
+     cuts_filename = f"{prefix}_cuts_{subset}.{idx}.{suffix}"
+
+     if (src_dir / cuts_filename).is_file():
+         logging.info(f"Loading manifests {src_dir / cuts_filename}")
+         cut_set = load_manifest_lazy(src_dir / cuts_filename)
+     else:
+         logging.warning(f"Raw {cuts_filename} does not exist, skipping")
+         return
+
+     cut_set = cut_set.resample(params.sampling_rate)
+
+     if (output_dir / cuts_filename).is_file():
+         logging.info(f"{cuts_filename} already exists - skipping.")
+         return
+
+     logging.info(f"Processing {subset}.{idx} of {prefix}")
+
+     cut_set = cut_set.compute_and_store_features_batch(
+         extractor=extractor,
+         storage_path=f"{output_dir}/{prefix}_feats_{subset}_{idx}",
+         num_workers=4,
+         batch_duration=params.batch_duration,
+         storage_type=LilcomChunkyWriter,
+         overwrite=True,
+     )
+     logging.info(f"Saving file to {output_dir / cuts_filename}")
+     cut_set.to_file(output_dir / cuts_filename)
+
+
+ def compute_fbank_split(params):
+     if params.split_end < params.split_begin:
+         logging.warning(
+             f"Split begin should be smaller than split end, given "
+             f"{params.split_begin} -> {params.split_end}."
+         )
+
+     with Pool(max_workers=params.num_jobs) as pool:
+         futures = [
+             pool.submit(compute_fbank_split_single, params, i)
+             for i in range(params.split_begin, params.split_end)
+         ]
+         for f in futures:
+             f.result()
+             f.done()
+
+
+ def compute_fbank(params):
+     logging.info(
+         f"Computing features for {params.dataset} dataset {params.subset} subset"
+     )
+     src_dir = Path(params.source_dir)
+     output_dir = Path(params.dest_dir)
+     num_jobs = params.num_jobs
+     if not output_dir.exists():
+         output_dir.mkdir(parents=True, exist_ok=True)
+
+     prefix = params.dataset
+     subset = params.subset
+     suffix = "jsonl.gz"
+
+     cut_set_name = f"{prefix}_cuts_{subset}.{suffix}"
+
+     if (src_dir / cut_set_name).is_file():
+         logging.info(f"Loading manifests {src_dir / cut_set_name}")
+         cut_set = load_manifest_lazy(src_dir / cut_set_name)
+     else:
+         recordings = load_manifest_lazy(
+             src_dir / f"{prefix}_recordings_{subset}.{suffix}"
+         )
+         supervisions = load_manifest_lazy(
+             src_dir / f"{prefix}_supervisions_{subset}.{suffix}"
+         )
+         cut_set = CutSet.from_manifests(
+             recordings=recordings,
+             supervisions=supervisions,
+         )
+
+     cut_set = cut_set.resample(params.sampling_rate)
+     if params.type == "vocos":
+         extractor = VocosFbank()
+     else:
+         raise NotImplementedError(f"{params.type} is not supported")
+
+     cuts_filename = f"{prefix}_cuts_{subset}.{suffix}"
+     if (output_dir / cuts_filename).is_file():
+         logging.info(f"{prefix} {subset} already exists - skipping.")
+         return
+     logging.info(f"Processing {subset} of {prefix}")
+
+     cut_set = cut_set.compute_and_store_features(
+         extractor=extractor,
+         storage_path=f"{output_dir}/{prefix}_feats_{subset}",
+         num_jobs=num_jobs,
+         storage_type=LilcomChunkyWriter,
+     )
+     logging.info(f"Saving file to {output_dir / cuts_filename}")
+     cut_set.to_file(output_dir / cuts_filename)
+
+
+ if __name__ == "__main__":
+     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+     logging.basicConfig(format=formatter, level=logging.INFO, force=True)
+
+     args = get_args()
+     logging.info(vars(args))
+     if args.split_cuts:
+         compute_fbank_split(params=args)
+     else:
+         compute_fbank(params=args)
+     logging.info("Done!")
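When --split-cuts is used, compute_fbank_split_single above looks for pre-split manifests whose names carry an eight-digit, zero-padded split index (num_digits = 8 with zfill). The sketch below only mirrors that naming scheme; the dataset and subset values are illustrative.

    # Mirrors the f-strings in compute_fbank_split_single above.
    def split_cuts_filename(dataset: str, subset: str, idx: int) -> str:
        return f"{dataset}_cuts_{subset}.{str(idx).zfill(8)}.jsonl.gz"

    print(split_cuts_filename("emilia", "train", 3))
    # -> emilia_cuts_train.00000003.jsonl.gz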
zipvoice/bin/generate_averaged_model.py ADDED
@@ -0,0 +1,229 @@
+ #!/usr/bin/env python3
+ #
+ # Copyright 2021-2022 Xiaomi Corporation
+ #
+ # See ../../../../LICENSE for clarification regarding multiple authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ Usage:
+ This script loads checkpoints and averages them.
+
+ python3 -m zipvoice.bin.generate_averaged_model \
+     --epoch 11 \
+     --avg 4 \
+     --model-name zipvoice \
+     --exp-dir exp/zipvoice
+
+ It will generate a file `epoch-11-avg-4.pt` in the given `exp_dir`.
+ You can later load it by `torch.load("epoch-11-avg-4.pt")`.
+ """
+
+ import argparse
+ import json
+ import logging
+ from pathlib import Path
+
+ import torch
+
+ from zipvoice.models.zipvoice import ZipVoice
+ from zipvoice.models.zipvoice_dialog import ZipVoiceDialog, ZipVoiceDialogStereo
+ from zipvoice.models.zipvoice_distill import ZipVoiceDistill
+ from zipvoice.tokenizer.tokenizer import SimpleTokenizer
+ from zipvoice.utils.checkpoint import (
+     average_checkpoints_with_averaged_model,
+     find_checkpoints,
+ )
+ from zipvoice.utils.common import AttributeDict
+
+
+ def get_parser():
+     parser = argparse.ArgumentParser(
+         formatter_class=argparse.ArgumentDefaultsHelpFormatter
+     )
+
+     parser.add_argument(
+         "--epoch",
+         type=int,
+         default=11,
+         help="""It specifies the checkpoint to use for decoding.
+         Note: Epoch counts from 1.
+         You can specify --avg to use more checkpoints for model averaging.""",
+     )
+
+     parser.add_argument(
+         "--iter",
+         type=int,
+         default=0,
+         help="""If positive, --epoch is ignored and it
+         will use the checkpoint exp_dir/checkpoint-iter.pt.
+         You can specify --avg to use more checkpoints for model averaging.
+         """,
+     )
+
+     parser.add_argument(
+         "--avg",
+         type=int,
+         default=4,
+         help="Number of checkpoints to average. Automatically select "
+         "consecutive checkpoints before the checkpoint specified by "
+         "'--epoch' or '--iter'",
+     )
+
+     parser.add_argument(
+         "--exp-dir",
+         type=str,
+         default="exp/zipvoice",
+         help="The experiment dir",
+     )
+
+     parser.add_argument(
+         "--model-name",
+         type=str,
+         default="zipvoice",
+         choices=[
+             "zipvoice",
+             "zipvoice_distill",
+             "zipvoice_dialog",
+             "zipvoice_dialog_stereo",
+         ],
+         help="The model type to be averaged.",
+     )
+
+     return parser
+
+
+ @torch.no_grad()
+ def main():
+     parser = get_parser()
+     args = parser.parse_args()
+     params = AttributeDict()
+     params.update(vars(args))
+     params.exp_dir = Path(params.exp_dir)
+
+     with open(params.exp_dir / "model.json", "r") as f:
+         model_config = json.load(f)
+
+     # Any tokenizer can be used here.
+     # Use SimpleTokenizer for simplicity.
+     tokenizer = SimpleTokenizer(token_file=params.exp_dir / "tokens.txt")
+     if params.model_name in ["zipvoice", "zipvoice_distill"]:
+         tokenizer_config = {
+             "vocab_size": tokenizer.vocab_size,
+             "pad_id": tokenizer.pad_id,
+         }
+     elif params.model_name in ["zipvoice_dialog", "zipvoice_dialog_stereo"]:
+         tokenizer_config = {
+             "vocab_size": tokenizer.vocab_size,
+             "pad_id": tokenizer.pad_id,
+             "spk_a_id": tokenizer.spk_a_id,
+             "spk_b_id": tokenizer.spk_b_id,
+         }
+
+     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+     logging.info("Script started")
+
+     params.device = torch.device("cpu")
+     logging.info(f"Device: {params.device}")
+
+     logging.info("About to create model")
+     if params.model_name == "zipvoice":
+         model = ZipVoice(
+             **model_config["model"],
+             **tokenizer_config,
+         )
+     elif params.model_name == "zipvoice_distill":
+         model = ZipVoiceDistill(
+             **model_config["model"],
+             **tokenizer_config,
+         )
+     elif params.model_name == "zipvoice_dialog":
+         model = ZipVoiceDialog(
+             **model_config["model"],
+             **tokenizer_config,
+         )
+     elif params.model_name == "zipvoice_dialog_stereo":
+         model = ZipVoiceDialogStereo(
+             **model_config["model"],
+             **tokenizer_config,
+         )
+     else:
+         raise ValueError(f"Unknown model name: {params.model_name}")
+
+     if params.iter > 0:
+         filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+             : params.avg + 1
+         ]
+         if len(filenames) == 0:
+             raise ValueError(
+                 f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
+             )
+         elif len(filenames) < params.avg + 1:
+             raise ValueError(
+                 f"Not enough checkpoints ({len(filenames)}) found for"
+                 f" --iter {params.iter}, --avg {params.avg}"
+             )
+         filename_start = filenames[-1]
+         filename_end = filenames[0]
+         logging.info(
+             "Calculating the averaged model over iteration checkpoints"
+             f" from {filename_start} (excluded) to {filename_end}"
+         )
+         model.to(params.device)
+         model.load_state_dict(
+             average_checkpoints_with_averaged_model(
+                 filename_start=filename_start,
+                 filename_end=filename_end,
+                 device=params.device,
+             ),
+             strict=True,
+         )
+     else:
+         assert params.avg > 0, params.avg
+         start = params.epoch - params.avg
+         assert start >= 1, start
+         filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+         filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+         logging.info(
+             f"Calculating the averaged model over epoch range from "
+             f"{start} (excluded) to {params.epoch}"
+         )
+         model.to(params.device)
+         model.load_state_dict(
+             average_checkpoints_with_averaged_model(
+                 filename_start=filename_start,
+                 filename_end=filename_end,
+                 device=params.device,
+             ),
+             strict=True,
+         )
+     if params.iter > 0:
+         filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt"
+     else:
+         filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt"
+
+     logging.info(f"Saving the averaged checkpoint to {filename}")
+     torch.save({"model": model.state_dict()}, filename)
+
+     num_param = sum([p.numel() for p in model.parameters()])
+     logging.info(f"Number of model parameters: {num_param}")
+
+     logging.info("Done!")
+
+
+ if __name__ == "__main__":
+     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+     logging.basicConfig(format=formatter, level=logging.INFO, force=True)
+
+     main()
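The log messages above describe the average as taken over checkpoints "from {start} (excluded) to {epoch}". The helper average_checkpoints_with_averaged_model lives in zipvoice/utils/checkpoint.py (added later in this commit) and is not reproduced here; the sketch below only illustrates the usual way an interval average is recovered from two cumulative running averages, as an assumption about how that helper behaves rather than a copy of it.

    def interval_average(avg_start, avg_end, n_start, n_end):
        # avg_k holds the cumulative average of the first k checkpoints
        # (a dict of tensors or floats, like a state_dict). The sum over the
        # interval (n_start, n_end] is n_end * avg_end - n_start * avg_start,
        # so the interval average divides that by (n_end - n_start).
        return {
            name: (n_end * avg_end[name] - n_start * avg_start[name])
            / (n_end - n_start)
            for name in avg_end
        }

    # e.g. with a scalar "parameter": average over checkpoints 8..11 given
    # the cumulative averages after 7 and after 11 checkpoints.
    print(interval_average({"w": 2.0}, {"w": 3.0}, 7, 11))  # {'w': 4.75}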
zipvoice/bin/infer_zipvoice.py ADDED
@@ -0,0 +1,877 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """
19
+ This script generates speech with our pre-trained ZipVoice or
20
+ ZipVoice-Distill models. If no local model is specified,
21
+ the required files will be automatically downloaded from HuggingFace.
22
+
23
+ Usage:
24
+
25
+ Note: If you have trouble connecting to HuggingFace,
26
+ try switching endpoint to mirror site:
27
+ export HF_ENDPOINT=https://hf-mirror.com
28
+
29
+ (1) Inference of a single sentence:
30
+
31
+ python3 -m zipvoice.bin.infer_zipvoice \
32
+ --model-name zipvoice \
33
+ --prompt-wav prompt.wav \
34
+ --prompt-text "I am a prompt." \
35
+ --text "I am a sentence." \
36
+ --res-wav-path result.wav
37
+
38
+ (2) Inference of a list of sentences:
39
+
40
+ python3 -m zipvoice.bin.infer_zipvoice \
41
+ --model-name zipvoice \
42
+ --test-list test.tsv \
43
+ --res-dir results
44
+
45
+ `--model-name` can be `zipvoice` or `zipvoice_distill`,
46
+ which are the models before and after distillation, respectively.
47
+
48
+ Each line of `test.tsv` is in the format of
49
+ `{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}`.
50
+ """
51
+
52
+ import argparse
53
+ import datetime as dt
54
+ import json
55
+ import logging
56
+ import os
57
+ from pathlib import Path
58
+ from typing import Optional
59
+
60
+ import numpy as np
61
+ import safetensors.torch
62
+ import torch
63
+ import torchaudio
64
+ from huggingface_hub import hf_hub_download
65
+ from lhotse.utils import fix_random_seed
66
+ from vocos import Vocos
67
+
68
+ from zipvoice.models.zipvoice import ZipVoice
69
+ from zipvoice.models.zipvoice_distill import ZipVoiceDistill
70
+ from zipvoice.tokenizer.tokenizer import (
71
+ EmiliaTokenizer,
72
+ EspeakTokenizer,
73
+ LibriTTSTokenizer,
74
+ SimpleTokenizer,
75
+ )
76
+ from zipvoice.utils.checkpoint import load_checkpoint
77
+ from zipvoice.utils.common import AttributeDict, str2bool
78
+ from zipvoice.utils.feature import VocosFbank
79
+ from zipvoice.utils.infer import (
80
+ add_punctuation,
81
+ batchify_tokens,
82
+ chunk_tokens_punctuation,
83
+ cross_fade_concat,
84
+ load_prompt_wav,
85
+ remove_silence,
86
+ rms_norm,
87
+ )
88
+
89
+ HUGGINGFACE_REPO = "k2-fsa/ZipVoice"
90
+ MODEL_DIR = {
91
+ "zipvoice": "zipvoice",
92
+ "zipvoice_distill": "zipvoice_distill",
93
+ }
94
+
95
+
96
+ def get_parser():
97
+ parser = argparse.ArgumentParser(
98
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
99
+ )
100
+
101
+ parser.add_argument(
102
+ "--model-name",
103
+ type=str,
104
+ default="zipvoice",
105
+ choices=["zipvoice", "zipvoice_distill"],
106
+ help="The model used for inference",
107
+ )
108
+
109
+ parser.add_argument(
110
+ "--model-dir",
111
+ type=str,
112
+ default=None,
113
+ help="The model directory that contains model checkpoint, configuration "
114
+ "file model.json, and tokens file tokens.txt. Will download pre-trained "
115
+ "checkpoint from huggingface if not specified.",
116
+ )
117
+
118
+ parser.add_argument(
119
+ "--checkpoint-name",
120
+ type=str,
121
+ default="model.pt",
122
+ help="The name of model checkpoint.",
123
+ )
124
+
125
+ parser.add_argument(
126
+ "--vocoder-path",
127
+ type=str,
128
+ default=None,
129
+ help="The vocoder checkpoint. "
130
+ "Will download pre-trained vocoder from huggingface if not specified.",
131
+ )
132
+
133
+ parser.add_argument(
134
+ "--tokenizer",
135
+ type=str,
136
+ default="emilia",
137
+ choices=["emilia", "libritts", "espeak", "simple"],
138
+ help="Tokenizer type.",
139
+ )
140
+
141
+ parser.add_argument(
142
+ "--lang",
143
+ type=str,
144
+ default="en-us",
145
+ help="Language identifier, used when tokenizer type is espeak. see"
146
+ "https://github.com/rhasspy/espeak-ng/blob/master/docs/languages.md",
147
+ )
148
+
149
+ parser.add_argument(
150
+ "--test-list",
151
+ type=str,
152
+ default=None,
153
+ help="The list of prompt speech, prompt_transcription, "
154
+ "and text to synthesize, in the format of "
155
+ "'{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}'.",
156
+ )
157
+
158
+ parser.add_argument(
159
+ "--prompt-wav",
160
+ type=str,
161
+ default=None,
162
+ help="The prompt wav to mimic",
163
+ )
164
+
165
+ parser.add_argument(
166
+ "--prompt-text",
167
+ type=str,
168
+ default=None,
169
+ help="The transcription of the prompt wav",
170
+ )
171
+
172
+ parser.add_argument(
173
+ "--text",
174
+ type=str,
175
+ default=None,
176
+ help="The text to synthesize",
177
+ )
178
+
179
+ parser.add_argument(
180
+ "--res-dir",
181
+ type=str,
182
+ default="results",
183
+ help="""
184
+ Path name of the generated wavs dir,
185
+ used when test-list is not None
186
+ """,
187
+ )
188
+
189
+ parser.add_argument(
190
+ "--res-wav-path",
191
+ type=str,
192
+ default="result.wav",
193
+ help="""
194
+ Path name of the generated wav path,
195
+ used when test-list is None
196
+ """,
197
+ )
198
+
199
+ parser.add_argument(
200
+ "--guidance-scale",
201
+ type=float,
202
+ default=None,
203
+ help="The scale of classifier-free guidance during inference.",
204
+ )
205
+
206
+ parser.add_argument(
207
+ "--num-step",
208
+ type=int,
209
+ default=None,
210
+ help="The number of sampling steps.",
211
+ )
212
+
213
+ parser.add_argument(
214
+ "--feat-scale",
215
+ type=float,
216
+ default=0.1,
217
+ help="The scale factor of fbank feature",
218
+ )
219
+
220
+ parser.add_argument(
221
+ "--speed",
222
+ type=float,
223
+ default=1.0,
224
+ help="Control speech speed, 1.0 means normal, >1.0 means speed up",
225
+ )
226
+
227
+ parser.add_argument(
228
+ "--t-shift",
229
+ type=float,
230
+ default=0.5,
231
+ help="Shift t to smaller ones if t_shift < 1.0",
232
+ )
233
+
234
+ parser.add_argument(
235
+ "--target-rms",
236
+ type=float,
237
+ default=0.1,
238
+ help="Target speech normalization rms value, set to 0 to disable normalization",
239
+ )
240
+
241
+ parser.add_argument(
242
+ "--seed",
243
+ type=int,
244
+ default=666,
245
+ help="Random seed",
246
+ )
247
+
248
+ parser.add_argument(
249
+ "--num-thread",
250
+ type=int,
251
+ default=1,
252
+ help="Number of threads to use for PyTorch on CPU.",
253
+ )
254
+
255
+ parser.add_argument(
256
+ "--raw-evaluation",
257
+ type=str2bool,
258
+ default=False,
259
+ help="Whether to use the 'raw' evaluation mode where provided "
260
+ "prompts and text are fed to the model without pre-processing",
261
+ )
262
+
263
+ parser.add_argument(
264
+ "--max-duration",
265
+ type=float,
266
+ default=100,
267
+ help="Maximum duration (seconds) in a single batch, including "
268
+ "durations of the prompt and generated wavs. You can reduce it "
269
+ "if it causes CUDA OOM.",
270
+ )
271
+
272
+ parser.add_argument(
273
+ "--remove-long-sil",
274
+ type=str2bool,
275
+ default=False,
276
+ help="Whether to remove long silences in the middle of the generated "
277
+ "speech (edge silences will be removed by default).",
278
+ )
279
+ return parser
280
+
281
+
282
+ def get_vocoder(vocos_local_path: Optional[str] = None):
283
+ if vocos_local_path:
284
+ vocoder = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
285
+ state_dict = torch.load(
286
+ f"{vocos_local_path}/pytorch_model.bin",
287
+ weights_only=True,
288
+ map_location="cpu",
289
+ )
290
+ vocoder.load_state_dict(state_dict)
291
+ else:
292
+ vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")
293
+ return vocoder
294
+
295
+
296
+ def generate_sentence_raw_evaluation(
297
+ save_path: str,
298
+ prompt_text: str,
299
+ prompt_wav: str,
300
+ text: str,
301
+ model: torch.nn.Module,
302
+ vocoder: torch.nn.Module,
303
+ tokenizer: EmiliaTokenizer,
304
+ feature_extractor: VocosFbank,
305
+ device: torch.device,
306
+ num_step: int = 16,
307
+ guidance_scale: float = 1.0,
308
+ speed: float = 1.0,
309
+ t_shift: float = 0.5,
310
+ target_rms: float = 0.1,
311
+ feat_scale: float = 0.1,
312
+ sampling_rate: int = 24000,
313
+ ):
314
+ """
315
+ Generate waveform of a text based on a given prompt waveform and its transcription,
316
+ this function directly feed the prompt_text, prompt_wav and text to the model.
317
+ It is not efficient and can have poor results for some inappropriate inputs.
318
+ (e.g., prompt wav contains long silence, text to be generated is too long)
319
+ This function can be used to evaluate the "raw" performance of the model.
320
+
321
+ Args:
322
+ save_path (str): Path to save the generated wav.
323
+ prompt_text (str): Transcription of the prompt wav.
324
+ prompt_wav (str): Path to the prompt wav file.
325
+ text (str): Text to be synthesized into a waveform.
326
+ model (torch.nn.Module): The model used for generation.
327
+ vocoder (torch.nn.Module): The vocoder used to convert features to waveforms.
328
+ tokenizer (EmiliaTokenizer): The tokenizer used to convert text to tokens.
329
+ feature_extractor (VocosFbank): The feature extractor used to
330
+ extract acoustic features.
331
+ device (torch.device): The device on which computations are performed.
332
+ num_step (int, optional): Number of steps for decoding. Defaults to 16.
333
+ guidance_scale (float, optional): Scale for classifier-free guidance.
334
+ Defaults to 1.0.
335
+ speed (float, optional): Speed control. Defaults to 1.0.
336
+ t_shift (float, optional): Time shift. Defaults to 0.5.
337
+ target_rms (float, optional): Target RMS for waveform normalization.
338
+ Defaults to 0.1.
339
+ feat_scale (float, optional): Scale for features.
340
+ Defaults to 0.1.
341
+ sampling_rate (int, optional): Sampling rate for the waveform.
342
+ Defaults to 24000.
343
+ Returns:
344
+ metrics (dict): Dictionary containing time and real-time
345
+ factor metrics for processing.
346
+ """
347
+
348
+ # Load and process prompt wav
349
+ prompt_wav = load_prompt_wav(prompt_wav, sampling_rate=sampling_rate)
350
+ prompt_wav, prompt_rms = rms_norm(prompt_wav, target_rms)
351
+
352
+ # Extract features from prompt wav
353
+ prompt_features = feature_extractor.extract(
354
+ prompt_wav, sampling_rate=sampling_rate
355
+ ).to(device)
356
+
357
+ prompt_features = prompt_features.unsqueeze(0) * feat_scale
358
+ prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)
359
+
360
+ # Convert text to tokens
361
+ tokens = tokenizer.texts_to_token_ids([text])
362
+ prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])
363
+
364
+ # Start timing
365
+ start_t = dt.datetime.now()
366
+
367
+ # Generate features
368
+ (
369
+ pred_features,
370
+ pred_features_lens,
371
+ pred_prompt_features,
372
+ pred_prompt_features_lens,
373
+ ) = model.sample(
374
+ tokens=tokens,
375
+ prompt_tokens=prompt_tokens,
376
+ prompt_features=prompt_features,
377
+ prompt_features_lens=prompt_features_lens,
378
+ speed=speed,
379
+ t_shift=t_shift,
380
+ duration="predict",
381
+ num_step=num_step,
382
+ guidance_scale=guidance_scale,
383
+ )
384
+
385
+ # Postprocess predicted features
386
+ pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T)
387
+
388
+ # Start vocoder processing
389
+ start_vocoder_t = dt.datetime.now()
390
+ wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)
391
+
392
+ # Calculate processing times and real-time factors
393
+ t = (dt.datetime.now() - start_t).total_seconds()
394
+ t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
395
+ t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
396
+ wav_seconds = wav.shape[-1] / sampling_rate
397
+ rtf = t / wav_seconds
398
+ rtf_no_vocoder = t_no_vocoder / wav_seconds
399
+ rtf_vocoder = t_vocoder / wav_seconds
400
+ metrics = {
401
+ "t": t,
402
+ "t_no_vocoder": t_no_vocoder,
403
+ "t_vocoder": t_vocoder,
404
+ "wav_seconds": wav_seconds,
405
+ "rtf": rtf,
406
+ "rtf_no_vocoder": rtf_no_vocoder,
407
+ "rtf_vocoder": rtf_vocoder,
408
+ }
409
+
410
+ # Adjust wav volume if necessary
411
+ if prompt_rms < target_rms:
412
+ wav = wav * prompt_rms / target_rms
413
+ torchaudio.save(save_path, wav.cpu(), sample_rate=sampling_rate)
414
+
415
+ return metrics
416
+
417
+
418
+ def generate_sentence(
419
+ save_path: str,
420
+ prompt_text: str,
421
+ prompt_wav: str,
422
+ text: str,
423
+ model: torch.nn.Module,
424
+ vocoder: torch.nn.Module,
425
+ tokenizer: EmiliaTokenizer,
426
+ feature_extractor: VocosFbank,
427
+ device: torch.device,
428
+ num_step: int = 16,
429
+ guidance_scale: float = 1.0,
430
+ speed: float = 1.0,
431
+ t_shift: float = 0.5,
432
+ target_rms: float = 0.1,
433
+ feat_scale: float = 0.1,
434
+ sampling_rate: int = 24000,
435
+ max_duration: float = 100,
436
+ remove_long_sil: bool = False,
437
+ ):
438
+ """
439
+ Generate waveform of a text based on a given prompt waveform and its transcription,
440
+ this function will do the following to improve the generation quality:
441
+ 1. chunk the text according to punctuations.
442
+ 2. process chunked texts in batches.
443
+ 3. remove long silences in the prompt audio.
444
+ 4. add punctuation to the end of prompt text and text if there is not.
445
+
446
+ Args:
447
+ save_path (str): Path to save the generated wav.
448
+ prompt_text (str): Transcription of the prompt wav.
449
+ prompt_wav (str): Path to the prompt wav file.
450
+ text (str): Text to be synthesized into a waveform.
451
+ model (torch.nn.Module): The model used for generation.
452
+ vocoder (torch.nn.Module): The vocoder used to convert features to waveforms.
453
+ tokenizer (EmiliaTokenizer): The tokenizer used to convert text to tokens.
454
+ feature_extractor (VocosFbank): The feature extractor used to
455
+ extract acoustic features.
456
+ device (torch.device): The device on which computations are performed.
457
+ num_step (int, optional): Number of steps for decoding. Defaults to 16.
458
+ guidance_scale (float, optional): Scale for classifier-free guidance.
459
+ Defaults to 1.0.
460
+ speed (float, optional): Speed control. Defaults to 1.0.
461
+ t_shift (float, optional): Time shift. Defaults to 0.5.
462
+ target_rms (float, optional): Target RMS for waveform normalization.
463
+ Defaults to 0.1.
464
+ feat_scale (float, optional): Scale for features.
465
+ Defaults to 0.1.
466
+ sampling_rate (int, optional): Sampling rate for the waveform.
467
+ Defaults to 24000.
468
+ max_duration (float, optional): The maximum duration to process in each
469
+ batch. Used to control memory consumption when generating long audios.
470
+ remove_long_sil (bool, optional): Whether to remove long silences in the
471
+ middle of the generated speech (edge silences will be removed by default).
472
+ Returns:
473
+ metrics (dict): Dictionary containing time and real-time
474
+ factor metrics for processing.
475
+ """
476
+
477
+ # Load and process prompt wav
478
+ prompt_wav = load_prompt_wav(prompt_wav, sampling_rate=sampling_rate)
479
+
480
+ # Remove edge and long silences in the prompt wav.
481
+ # Add 0.2s trailing silence to avoid leaking prompt to generated speech.
482
+ prompt_wav = remove_silence(
483
+ prompt_wav, sampling_rate, only_edge=False, trail_sil=200
484
+ )
485
+
486
+ prompt_wav, prompt_rms = rms_norm(prompt_wav, target_rms)
487
+
488
+ prompt_duration = prompt_wav.shape[-1] / sampling_rate
489
+
490
+ if prompt_duration > 20:
491
+ logging.warning(
492
+ f"Given prompt wav is too long ({prompt_duration}s). "
493
+ f"Please provide a shorter one (1-3 seconds is recommended)."
494
+ )
495
+ elif prompt_duration > 10:
496
+ logging.warning(
497
+ f"Given prompt wav is long ({prompt_duration}s). "
498
+ f"It will lead to slower inference speed and possibly worse speech quality."
499
+ )
500
+
501
+ # Extract features from prompt wav
502
+ prompt_features = feature_extractor.extract(
503
+ prompt_wav, sampling_rate=sampling_rate
504
+ ).to(device)
505
+
506
+ prompt_features = prompt_features.unsqueeze(0) * feat_scale
507
+
508
+ # Add punctuation in the end if there is not
509
+ text = add_punctuation(text)
510
+ prompt_text = add_punctuation(prompt_text)
511
+
512
+ # Tokenize text (str tokens), punctuations will be preserved.
513
+ tokens_str = tokenizer.texts_to_tokens([text])[0]
514
+ prompt_tokens_str = tokenizer.texts_to_tokens([prompt_text])[0]
515
+
516
+ # chunk text so that each len(prompt wav + generated wav) is around 25 seconds.
517
+ token_duration = (prompt_wav.shape[-1] / sampling_rate) / (
518
+ len(prompt_tokens_str) * speed
519
+ )
520
+ max_tokens = int((25 - prompt_duration) / token_duration)
521
+ chunked_tokens_str = chunk_tokens_punctuation(tokens_str, max_tokens=max_tokens)
522
+
523
+ # Tokenize text (int tokens)
524
+ chunked_tokens = tokenizer.tokens_to_token_ids(chunked_tokens_str)
525
+ prompt_tokens = tokenizer.tokens_to_token_ids([prompt_tokens_str])
526
+
527
+ # Batchify chunked texts for faster processing
528
+ tokens_batches, chunked_index = batchify_tokens(
529
+ chunked_tokens, max_duration, prompt_duration, token_duration
530
+ )
531
+
532
+ # Start predicting features
533
+ chunked_features = []
534
+ start_t = dt.datetime.now()
535
+
536
+ for batch_tokens in tokens_batches:
537
+ batch_prompt_tokens = prompt_tokens * len(batch_tokens)
538
+
539
+ batch_prompt_features = prompt_features.repeat(len(batch_tokens), 1, 1)
540
+ batch_prompt_features_lens = torch.full(
541
+ (len(batch_tokens),), prompt_features.size(1), device=device
542
+ )
543
+
544
+ # Generate features
545
+ (
546
+ pred_features,
547
+ pred_features_lens,
548
+ pred_prompt_features,
549
+ pred_prompt_features_lens,
550
+ ) = model.sample(
551
+ tokens=batch_tokens,
552
+ prompt_tokens=batch_prompt_tokens,
553
+ prompt_features=batch_prompt_features,
554
+ prompt_features_lens=batch_prompt_features_lens,
555
+ speed=speed,
556
+ t_shift=t_shift,
557
+ duration="predict",
558
+ num_step=num_step,
559
+ guidance_scale=guidance_scale,
560
+ )
561
+
562
+ # Postprocess predicted features
563
+ pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T)
564
+ chunked_features.append((pred_features, pred_features_lens))
565
+
566
+ # Start vocoder processing
567
+ chunked_wavs = []
568
+ start_vocoder_t = dt.datetime.now()
569
+
570
+ for pred_features, pred_features_lens in chunked_features:
571
+ batch_wav = []
572
+ for i in range(pred_features.size(0)):
573
+
574
+ wav = (
575
+ vocoder.decode(pred_features[i][None, :, : pred_features_lens[i]])
576
+ .squeeze(1)
577
+ .clamp(-1, 1)
578
+ )
579
+ # Adjust wav volume if necessary
580
+ if prompt_rms < target_rms:
581
+ wav = wav * prompt_rms / target_rms
582
+ batch_wav.append(wav)
583
+ chunked_wavs.extend(batch_wav)
584
+
585
+ # Finish model generation
586
+ t = (dt.datetime.now() - start_t).total_seconds()
587
+
588
+ # Merge chunked wavs
589
+ indexed_chunked_wavs = [
590
+ (index, wav) for index, wav in zip(chunked_index, chunked_wavs)
591
+ ]
592
+ sequential_indexed_chunked_wavs = sorted(indexed_chunked_wavs, key=lambda x: x[0])
593
+ sequential_chunked_wavs = [
594
+ sequential_indexed_chunked_wavs[i][1]
595
+ for i in range(len(sequential_indexed_chunked_wavs))
596
+ ]
597
+ final_wav = cross_fade_concat(
598
+ sequential_chunked_wavs, fade_duration=0.1, sample_rate=sampling_rate
599
+ )
600
+ final_wav = remove_silence(
601
+ final_wav, sampling_rate, only_edge=(not remove_long_sil), trail_sil=0
602
+ )
603
+
604
+ # Calculate processing time metrics
605
+ t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
606
+ t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
607
+ wav_seconds = final_wav.shape[-1] / sampling_rate
608
+ rtf = t / wav_seconds
609
+ rtf_no_vocoder = t_no_vocoder / wav_seconds
610
+ rtf_vocoder = t_vocoder / wav_seconds
611
+ metrics = {
612
+ "t": t,
613
+ "t_no_vocoder": t_no_vocoder,
614
+ "t_vocoder": t_vocoder,
615
+ "wav_seconds": wav_seconds,
616
+ "rtf": rtf,
617
+ "rtf_no_vocoder": rtf_no_vocoder,
618
+ "rtf_vocoder": rtf_vocoder,
619
+ }
620
+
621
+ torchaudio.save(save_path, final_wav.cpu(), sample_rate=sampling_rate)
622
+ return metrics
623
+
624
+
625
+ def generate_list(
626
+ res_dir: str,
627
+ test_list: str,
628
+ model: torch.nn.Module,
629
+ vocoder: torch.nn.Module,
630
+ tokenizer: EmiliaTokenizer,
631
+ feature_extractor: VocosFbank,
632
+ device: torch.device,
633
+ num_step: int = 16,
634
+ guidance_scale: float = 1.0,
635
+ speed: float = 1.0,
636
+ t_shift: float = 0.5,
637
+ target_rms: float = 0.1,
638
+ feat_scale: float = 0.1,
639
+ sampling_rate: int = 24000,
640
+ raw_evaluation: bool = False,
641
+ max_duration: float = 100,
642
+ remove_long_sil: bool = False,
643
+ ):
644
+ total_t = []
645
+ total_t_no_vocoder = []
646
+ total_t_vocoder = []
647
+ total_wav_seconds = []
648
+
649
+ with open(test_list, "r") as fr:
650
+ lines = fr.readlines()
651
+
652
+ for i, line in enumerate(lines):
653
+ wav_name, prompt_text, prompt_wav, text = line.strip().split("\t")
654
+ save_path = f"{res_dir}/{wav_name}.wav"
655
+
656
+ common_params = {
657
+ "save_path": save_path,
658
+ "prompt_text": prompt_text,
659
+ "prompt_wav": prompt_wav,
660
+ "text": text,
661
+ "model": model,
662
+ "vocoder": vocoder,
663
+ "tokenizer": tokenizer,
664
+ "feature_extractor": feature_extractor,
665
+ "device": device,
666
+ "num_step": num_step,
667
+ "guidance_scale": guidance_scale,
668
+ "speed": speed,
669
+ "t_shift": t_shift,
670
+ "target_rms": target_rms,
671
+ "feat_scale": feat_scale,
672
+ "sampling_rate": sampling_rate,
673
+ }
674
+
675
+ if raw_evaluation:
676
+ metrics = generate_sentence_raw_evaluation(**common_params)
677
+ else:
678
+ metrics = generate_sentence(
679
+ **common_params,
680
+ max_duration=max_duration,
681
+ remove_long_sil=remove_long_sil,
682
+ )
683
+ logging.info(f"[Sentence: {i}] Saved to: {save_path}")
684
+ logging.info(f"[Sentence: {i}] RTF: {metrics['rtf']:.4f}")
685
+ total_t.append(metrics["t"])
686
+ total_t_no_vocoder.append(metrics["t_no_vocoder"])
687
+ total_t_vocoder.append(metrics["t_vocoder"])
688
+ total_wav_seconds.append(metrics["wav_seconds"])
689
+
690
+ logging.info(f"Average RTF: {np.sum(total_t) / np.sum(total_wav_seconds):.4f}")
691
+ logging.info(
692
+ f"Average RTF w/o vocoder: "
693
+ f"{np.sum(total_t_no_vocoder) / np.sum(total_wav_seconds):.4f}"
694
+ )
695
+ logging.info(
696
+ f"Average RTF vocoder: "
697
+ f"{np.sum(total_t_vocoder) / np.sum(total_wav_seconds):.4f}"
698
+ )
699
+
700
+
701
+ @torch.inference_mode()
702
+ def main():
703
+ parser = get_parser()
704
+ args = parser.parse_args()
705
+
706
+ torch.set_num_threads(args.num_thread)
707
+ torch.set_num_interop_threads(args.num_thread)
708
+
709
+ params = AttributeDict()
710
+ params.update(vars(args))
711
+ fix_random_seed(params.seed)
712
+
713
+ model_defaults = {
714
+ "zipvoice": {
715
+ "num_step": 16,
716
+ "guidance_scale": 1.0,
717
+ },
718
+ "zipvoice_distill": {
719
+ "num_step": 8,
720
+ "guidance_scale": 3.0,
721
+ },
722
+ }
723
+
724
+ model_specific_defaults = model_defaults.get(params.model_name, {})
725
+
726
+ for param, value in model_specific_defaults.items():
727
+ if getattr(params, param) is None:
728
+ setattr(params, param, value)
729
+ logging.info(f"Setting {param} to default value: {value}")
730
+
731
+ assert (params.test_list is not None) ^ (
732
+ (params.prompt_wav and params.prompt_text and params.text) is not None
733
+ ), (
734
+ "For inference, please provide prompts and text with either '--test-list'"
735
+ " or '--prompt-wav, --prompt-text and --text'."
736
+ )
737
+
738
+ if params.model_dir is not None:
739
+ params.model_dir = Path(params.model_dir)
740
+ if not params.model_dir.is_dir():
741
+ raise FileNotFoundError(f"{params.model_dir} does not exist")
742
+ for filename in [params.checkpoint_name, "model.json", "tokens.txt"]:
743
+ if not (params.model_dir / filename).is_file():
744
+ raise FileNotFoundError(f"{params.model_dir / filename} does not exist")
745
+ model_ckpt = params.model_dir / params.checkpoint_name
746
+ model_config = params.model_dir / "model.json"
747
+ token_file = params.model_dir / "tokens.txt"
748
+ logging.info(
749
+ f"Using {params.model_name} in local model dir {params.model_dir}, "
750
+ f"checkpoint {params.checkpoint_name}"
751
+ )
752
+ else:
753
+ logging.info(f"Using pretrained {params.model_name} model from the Huggingface")
754
+ model_ckpt = hf_hub_download(
755
+ HUGGINGFACE_REPO, filename=f"{MODEL_DIR[params.model_name]}/model.pt"
756
+ )
757
+ model_config = hf_hub_download(
758
+ HUGGINGFACE_REPO, filename=f"{MODEL_DIR[params.model_name]}/model.json"
759
+ )
760
+
761
+ token_file = hf_hub_download(
762
+ HUGGINGFACE_REPO, filename=f"{MODEL_DIR[params.model_name]}/tokens.txt"
763
+ )
764
+
765
+ if params.tokenizer == "emilia":
766
+ tokenizer = EmiliaTokenizer(token_file=token_file)
767
+ elif params.tokenizer == "libritts":
768
+ tokenizer = LibriTTSTokenizer(token_file=token_file)
769
+ elif params.tokenizer == "espeak":
770
+ tokenizer = EspeakTokenizer(token_file=token_file, lang=params.lang)
771
+ else:
772
+ assert params.tokenizer == "simple"
773
+ tokenizer = SimpleTokenizer(token_file=token_file)
774
+
775
+ tokenizer_config = {"vocab_size": tokenizer.vocab_size, "pad_id": tokenizer.pad_id}
776
+
777
+ with open(model_config, "r") as f:
778
+ model_config = json.load(f)
779
+
780
+ if params.model_name == "zipvoice":
781
+ model = ZipVoice(
782
+ **model_config["model"],
783
+ **tokenizer_config,
784
+ )
785
+ else:
786
+ assert params.model_name == "zipvoice_distill"
787
+ model = ZipVoiceDistill(
788
+ **model_config["model"],
789
+ **tokenizer_config,
790
+ )
791
+
792
+ if str(model_ckpt).endswith(".safetensors"):
793
+ safetensors.torch.load_model(model, model_ckpt)
794
+ elif str(model_ckpt).endswith(".pt"):
795
+ load_checkpoint(filename=model_ckpt, model=model, strict=True)
796
+ else:
797
+ raise NotImplementedError(f"Unsupported model checkpoint format: {model_ckpt}")
798
+
799
+ if torch.cuda.is_available():
800
+ params.device = torch.device("cuda", 0)
801
+ elif torch.backends.mps.is_available():
802
+ params.device = torch.device("mps")
803
+ else:
804
+ params.device = torch.device("cpu")
805
+ logging.info(f"Device: {params.device}")
806
+
807
+ model = model.to(params.device)
808
+ model.eval()
809
+
810
+ vocoder = get_vocoder(params.vocoder_path)
811
+ vocoder = vocoder.to(params.device)
812
+ vocoder.eval()
813
+
814
+ if model_config["feature"]["type"] == "vocos":
815
+ feature_extractor = VocosFbank()
816
+ else:
817
+ raise NotImplementedError(
818
+ f"Unsupported feature type: {model_config['feature']['type']}"
819
+ )
820
+ params.sampling_rate = model_config["feature"]["sampling_rate"]
821
+
822
+ logging.info("Start generating...")
823
+ if params.test_list:
824
+ res_dir = params.res_dir
825
+ os.makedirs(res_dir, exist_ok=True)
826
+ generate_list(
827
+ res_dir=params.res_dir,
828
+ test_list=params.test_list,
829
+ model=model,
830
+ vocoder=vocoder,
831
+ tokenizer=tokenizer,
832
+ feature_extractor=feature_extractor,
833
+ device=params.device,
834
+ num_step=params.num_step,
835
+ guidance_scale=params.guidance_scale,
836
+ speed=params.speed,
837
+ t_shift=params.t_shift,
838
+ target_rms=params.target_rms,
839
+ feat_scale=params.feat_scale,
840
+ sampling_rate=params.sampling_rate,
841
+ raw_evaluation=params.raw_evaluation,
842
+ max_duration=params.max_duration,
843
+ remove_long_sil=params.remove_long_sil,
844
+ )
845
+ else:
846
+ assert (
847
+ not params.raw_evaluation
848
+ ), "Raw evaluation is only valid with --test-list"
849
+ generate_sentence(
850
+ save_path=params.res_wav_path,
851
+ prompt_text=params.prompt_text,
852
+ prompt_wav=params.prompt_wav,
853
+ text=params.text,
854
+ model=model,
855
+ vocoder=vocoder,
856
+ tokenizer=tokenizer,
857
+ feature_extractor=feature_extractor,
858
+ device=params.device,
859
+ num_step=params.num_step,
860
+ guidance_scale=params.guidance_scale,
861
+ speed=params.speed,
862
+ t_shift=params.t_shift,
863
+ target_rms=params.target_rms,
864
+ feat_scale=params.feat_scale,
865
+ sampling_rate=params.sampling_rate,
866
+ max_duration=params.max_duration,
867
+ remove_long_sil=params.remove_long_sil,
868
+ )
869
+ logging.info(f"Saved to: {params.res_wav_path}")
870
+ logging.info("Done")
871
+
872
+
873
+ if __name__ == "__main__":
874
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
875
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
876
+
877
+ main()
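generate_sentence above stitches the per-chunk waveforms back together with cross_fade_concat(..., fade_duration=0.1). That helper is defined in zipvoice/utils/infer.py and is not shown in this hunk; the snippet below is only a rough sketch of linear cross-fading between consecutive (1, T) chunks, under the assumption that the real helper works along these lines.

    import torch

    def linear_cross_fade(chunks, fade_duration=0.1, sample_rate=24000):
        # Overlap-add consecutive waveforms with a linear fade-out/fade-in.
        fade_len = int(fade_duration * sample_rate)
        out = chunks[0]
        for nxt in chunks[1:]:
            n = min(fade_len, out.shape[-1], nxt.shape[-1])
            if n == 0:
                out = torch.cat([out, nxt], dim=-1)
                continue
            fade_out = torch.linspace(1.0, 0.0, n)
            fade_in = torch.linspace(0.0, 1.0, n)
            overlap = out[..., -n:] * fade_out + nxt[..., :n] * fade_in
            out = torch.cat([out[..., :-n], overlap, nxt[..., n:]], dim=-1)
        return out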
zipvoice/bin/infer_zipvoice_dialog.py ADDED
@@ -0,0 +1,1286 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """
19
+ This script generates speech with our pre-trained ZipVoice-Dialog or
20
+ ZipVoice-Dialog-Stereo models. If no local model is specified,
21
+ the required files will be automatically downloaded from HuggingFace.
22
+
23
+ Usage:
24
+
25
+ Note: If you have trouble connecting to HuggingFace,
26
+ try switching endpoint to mirror site:
27
+ export HF_ENDPOINT=https://hf-mirror.com
28
+
29
+ python3 -m zipvoice.bin.infer_zipvoice_dialog \
30
+ --model-name zipvoice_dialog \
31
+ --test-list test.tsv \
32
+ --res-dir results
33
+
34
+ `--model-name` can be `zipvoice_dialog` or `zipvoice_dialog_stereo`,
35
+ which generate mono and stereo dialogues, respectively.
36
+
37
+ Each line of `test.tsv` is in the format of a merged conversation:
38
+ '{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}'
39
+ or a split conversation:
40
+ '{wav_name}\t{spk1_prompt_transcription}\t{spk2_prompt_transcription}
41
+ \t{spk1_prompt_wav}\t{spk2_prompt_wav}\t{text}'
42
+ """
43
+
44
+ import argparse
45
+ import datetime as dt
46
+ import json
47
+ import logging
48
+ import os
49
+ from pathlib import Path
50
+ from typing import List, Optional, Union
51
+
52
+ import numpy as np
53
+ import safetensors.torch
54
+ import torch
55
+ import torchaudio
56
+ from huggingface_hub import hf_hub_download
57
+ from lhotse.utils import fix_random_seed
58
+ from vocos import Vocos
59
+
60
+ from zipvoice.models.zipvoice_dialog import ZipVoiceDialog, ZipVoiceDialogStereo
61
+ from zipvoice.tokenizer.tokenizer import DialogTokenizer
62
+ from zipvoice.utils.checkpoint import load_checkpoint
63
+ from zipvoice.utils.common import AttributeDict, str2bool
64
+ from zipvoice.utils.feature import VocosFbank
65
+ from zipvoice.utils.infer import (
66
+ add_punctuation,
67
+ batchify_tokens,
68
+ chunk_tokens_dialog,
69
+ cross_fade_concat,
70
+ load_prompt_wav,
71
+ remove_silence,
72
+ rms_norm,
73
+ )
74
+
75
+ HUGGINGFACE_REPO = "k2-fsa/ZipVoice"
76
+ MODEL_DIR = {
77
+ "zipvoice_dialog": "zipvoice_dialog",
78
+ "zipvoice_dialog_stereo": "zipvoice_dialog_stereo",
79
+ }
80
+
81
+
82
+ def get_parser():
83
+ parser = argparse.ArgumentParser(
84
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
85
+ )
86
+
87
+ parser.add_argument(
88
+ "--model-name",
89
+ type=str,
90
+ default="zipvoice_dialog",
91
+ choices=["zipvoice_dialog", "zipvoice_dialog_stereo"],
92
+ help="The model used for inference",
93
+ )
94
+
95
+ parser.add_argument(
96
+ "--model-dir",
97
+ type=str,
98
+ default=None,
99
+ help="The model directory that contains model checkpoint, configuration "
100
+ "file model.json, and tokens file tokens.txt. Will download pre-trained "
101
+ "checkpoint from huggingface if not specified.",
102
+ )
103
+
104
+ parser.add_argument(
105
+ "--checkpoint-name",
106
+ type=str,
107
+ default="model.pt",
108
+ help="The name of model checkpoint.",
109
+ )
110
+
111
+ parser.add_argument(
112
+ "--vocoder-path",
113
+ type=str,
114
+ default=None,
115
+ help="The vocoder checkpoint. "
116
+ "Will download pre-trained vocoder from huggingface if not specified.",
117
+ )
118
+
119
+ parser.add_argument(
120
+ "--test-list",
121
+ type=str,
122
+ default=None,
123
+ help="The list of prompt speech, prompt_transcription, "
124
+ "and text to synthesize, in the format of a merged conversation: "
125
+ "'{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}' "
126
+ "or a split conversation: "
127
+ "'{wav_name}\t{spk1_prompt_transcription}\t{spk2_prompt_transcription}"
128
+ "\t{spk1_prompt_wav}\t{spk2_prompt_wav}\t{text}'.",
129
+ )
130
+
131
+ parser.add_argument(
132
+ "--res-dir",
133
+ type=str,
134
+ default="results",
135
+ help="""
136
+ Path name of the generated wavs dir,
137
+ used when test-list is not None
138
+ """,
139
+ )
140
+
141
+ parser.add_argument(
142
+ "--guidance-scale",
143
+ type=float,
144
+ default=1.5,
145
+ help="The scale of classifier-free guidance during inference.",
146
+ )
147
+
148
+ parser.add_argument(
149
+ "--num-step",
150
+ type=int,
151
+ default=16,
152
+ help="The number of sampling steps.",
153
+ )
154
+
155
+ parser.add_argument(
156
+ "--feat-scale",
157
+ type=float,
158
+ default=0.1,
159
+ help="The scale factor of fbank feature",
160
+ )
161
+
162
+ parser.add_argument(
163
+ "--speed",
164
+ type=float,
165
+ default=1.0,
166
+ help="Control speech speed, 1.0 means normal, >1.0 means speed up",
167
+ )
168
+
169
+ parser.add_argument(
170
+ "--t-shift",
171
+ type=float,
172
+ default=0.5,
173
+ help="Shift t to smaller ones if t_shift < 1.0",
174
+ )
175
+
176
+ parser.add_argument(
177
+ "--target-rms",
178
+ type=float,
179
+ default=0.1,
180
+ help="Target speech normalization rms value, set to 0 to disable normalization",
181
+ )
182
+
183
+ parser.add_argument(
184
+ "--seed",
185
+ type=int,
186
+ default=666,
187
+ help="Random seed",
188
+ )
189
+
190
+ parser.add_argument(
191
+ "--silence-wav",
192
+ type=str,
193
+ default="assets/silence.wav",
194
+ help="Path of the silence wav file, used in two-channel generation "
195
+ "with single-channel prompts",
196
+ )
197
+
198
+ parser.add_argument(
199
+ "--num-thread",
200
+ type=int,
201
+ default=1,
202
+ help="Number of threads to use for PyTorch on CPU.",
203
+ )
204
+
205
+ parser.add_argument(
206
+ "--raw-evaluation",
207
+ type=str2bool,
208
+ default=False,
209
+ help="Whether to use the 'raw' evaluation mode where provided "
210
+ "prompts and text are fed to the model without pre-processing",
211
+ )
212
+
213
+ parser.add_argument(
214
+ "--max-duration",
215
+ type=float,
216
+ default=100,
217
+ help="Maximum duration (seconds) in a single batch, including "
218
+ "durations of the prompt and generated wavs. You can reduce it "
219
+ "if it causes CUDA OOM.",
220
+ )
221
+
222
+ parser.add_argument(
223
+ "--remove-long-sil",
224
+ type=str2bool,
225
+ default=False,
226
+ help="Whether to remove long silences in the middle of the generated "
227
+ "speech (edge silences will be removed by default).",
228
+ )
229
+
230
+ return parser
231
+
232
+
233
+ def get_vocoder(vocos_local_path: Optional[str] = None):
234
+ if vocos_local_path:
235
+ vocoder = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
236
+ state_dict = torch.load(
237
+ f"{vocos_local_path}/pytorch_model.bin",
238
+ weights_only=True,
239
+ map_location="cpu",
240
+ )
241
+ vocoder.load_state_dict(state_dict)
242
+ else:
243
+ vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")
244
+ return vocoder
245
+
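+ # Sketch of how the returned Vocos vocoder is used further below (shapes are
+ # illustrative; the real features come from the flow-matching model):
+ #     mel = torch.randn(1, 100, 240)          # (batch, feat_dim, num_frames)
+ #     wav = vocoder.decode(mel).clamp(-1, 1)  # (batch, num_samples) at 24 kHz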
246
+
247
+ def generate_sentence_raw_evaluation(
248
+ save_path: str,
249
+ prompt_text: str,
250
+ prompt_wav: Union[str, List[str]],
251
+ text: str,
252
+ model: torch.nn.Module,
253
+ vocoder: torch.nn.Module,
254
+ tokenizer: DialogTokenizer,
255
+ feature_extractor: VocosFbank,
256
+ device: torch.device,
257
+ num_step: int = 16,
258
+ guidance_scale: float = 1.0,
259
+ speed: float = 1.0,
260
+ t_shift: float = 0.5,
261
+ target_rms: float = 0.1,
262
+ feat_scale: float = 0.1,
263
+ sampling_rate: int = 24000,
264
+ ):
265
+ """
266
+ Generate the waveform of a text based on a given prompt waveform and its transcription;
267
+ this function directly feeds the prompt_text, prompt_wav, and text to the model.
268
+ It is not efficient and can produce poor results for some inappropriate inputs
269
+ (e.g., the prompt wav contains long silences, or the text to be generated is too long).
270
+ This function can be used to evaluate the "raw" performance of the model.
271
+
272
+ Args:
273
+ save_path (str): Path to save the generated wav.
274
+ prompt_text (str): Transcription of the prompt wav.
275
+ prompt_wav (Union[str, List[str]]): Path to the prompt wav file; can be
276
+ one or two wav files, which correspond to a merged conversational
277
+ speech or two separate speakers' speech, respectively.
278
+ text (str): Text to be synthesized into a waveform.
279
+ model (torch.nn.Module): The model used for generation.
280
+ vocoder (torch.nn.Module): The vocoder used to convert features to waveforms.
281
+ tokenizer (DialogTokenizer): The tokenizer used to convert text to tokens.
282
+ feature_extractor (VocosFbank): The feature extractor used to
283
+ extract acoustic features.
284
+ device (torch.device): The device on which computations are performed.
285
+ num_step (int, optional): Number of steps for decoding. Defaults to 16.
286
+ guidance_scale (float, optional): Scale for classifier-free guidance.
287
+ Defaults to 1.0.
288
+ speed (float, optional): Speed control. Defaults to 1.0.
289
+ t_shift (float, optional): Time shift. Defaults to 0.5.
290
+ target_rms (float, optional): Target RMS for waveform normalization.
291
+ Defaults to 0.1.
292
+ feat_scale (float, optional): Scale for features.
293
+ Defaults to 0.1.
294
+ sampling_rate (int, optional): Sampling rate for the waveform.
295
+ Defaults to 24000.
296
+ Returns:
297
+ metrics (dict): Dictionary containing time and real-time
298
+ factor metrics for processing.
299
+ """
300
+
301
+ # Load and preprocess prompt wav
302
+ if isinstance(prompt_wav, str):
303
+ prompt_wav = [
304
+ prompt_wav,
305
+ ]
306
+ else:
307
+ assert len(prompt_wav) == 2 and isinstance(prompt_wav[0], str)
308
+
309
+ loaded_prompt_wavs = prompt_wav
310
+ for i in range(len(prompt_wav)):
311
+ loaded_prompt_wavs[i] = load_prompt_wav(
312
+ loaded_prompt_wavs[i], sampling_rate=sampling_rate
313
+ )
314
+ if loaded_prompt_wavs[i].size(0) != 1:
315
+ loaded_prompt_wavs[i] = loaded_prompt_wavs[i].mean(0, keepdim=True)
316
+
317
+ if len(loaded_prompt_wavs) == 1:
318
+ prompt_wav = loaded_prompt_wavs[0]
319
+ else:
320
+ prompt_wav = torch.cat(loaded_prompt_wavs, dim=1)
321
+
322
+ prompt_wav, prompt_rms = rms_norm(prompt_wav, target_rms)
323
+
324
+ # Extract features from prompt wav
325
+ prompt_features = feature_extractor.extract(
326
+ prompt_wav, sampling_rate=sampling_rate
327
+ ).to(device)
328
+
329
+ prompt_features = prompt_features.unsqueeze(0) * feat_scale
330
+ prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)
331
+
332
+ # Convert text to tokens
333
+ tokens = tokenizer.texts_to_token_ids([text])
334
+ prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])
335
+
336
+ # Start timing
337
+ start_t = dt.datetime.now()
338
+
339
+ # Generate features
340
+ (
341
+ pred_features,
342
+ pred_features_lens,
343
+ pred_prompt_features,
344
+ pred_prompt_features_lens,
345
+ ) = model.sample(
346
+ tokens=tokens,
347
+ prompt_tokens=prompt_tokens,
348
+ prompt_features=prompt_features,
349
+ prompt_features_lens=prompt_features_lens,
350
+ speed=speed,
351
+ t_shift=t_shift,
352
+ duration="predict",
353
+ num_step=num_step,
354
+ guidance_scale=guidance_scale,
355
+ )
356
+
357
+ # Postprocess predicted features
358
+ pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T)
359
+
360
+ # Start vocoder processing
361
+ start_vocoder_t = dt.datetime.now()
362
+ wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)
363
+
364
+ # Calculate processing times and real-time factors
365
+ t = (dt.datetime.now() - start_t).total_seconds()
366
+ t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
367
+ t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
368
+ wav_seconds = wav.shape[-1] / sampling_rate
369
+ rtf = t / wav_seconds
370
+ rtf_no_vocoder = t_no_vocoder / wav_seconds
371
+ rtf_vocoder = t_vocoder / wav_seconds
372
+ metrics = {
373
+ "t": t,
374
+ "t_no_vocoder": t_no_vocoder,
375
+ "t_vocoder": t_vocoder,
376
+ "wav_seconds": wav_seconds,
377
+ "rtf": rtf,
378
+ "rtf_no_vocoder": rtf_no_vocoder,
379
+ "rtf_vocoder": rtf_vocoder,
380
+ }
381
+
382
+ # Adjust wav volume if necessary
383
+ if prompt_rms < target_rms:
384
+ wav = wav * prompt_rms / target_rms
385
+ torchaudio.save(save_path, wav.cpu(), sample_rate=sampling_rate)
386
+
387
+ return metrics
388
+
389
+
390
+ def generate_sentence(
391
+ save_path: str,
392
+ prompt_text: str,
393
+ prompt_wav: Union[str, List[str]],
394
+ text: str,
395
+ model: torch.nn.Module,
396
+ vocoder: torch.nn.Module,
397
+ tokenizer: DialogTokenizer,
398
+ feature_extractor: VocosFbank,
399
+ device: torch.device,
400
+ num_step: int = 16,
401
+ guidance_scale: float = 1.0,
402
+ speed: float = 1.0,
403
+ t_shift: float = 0.5,
404
+ target_rms: float = 0.1,
405
+ feat_scale: float = 0.1,
406
+ sampling_rate: int = 24000,
407
+ max_duration: float = 100,
408
+ remove_long_sil: bool = False,
409
+ ):
410
+ """
411
+ Generate the waveform of a text based on a given prompt waveform and its transcription;
412
+ this function does the following to improve the generation quality:
413
+ 1. chunk the text according to the speaker-turn symbol [S1].
414
+ 2. process chunked texts in batches.
415
+ 3. remove long silences in the prompt audio.
416
+ 4. add punctuation to the end of the prompt text and the text if it is missing.
417
+
418
+ Args:
419
+ save_path (str): Path to save the generated wav.
420
+ prompt_text (str): Transcription of the prompt wav.
421
+ prompt_wav (Union[str, List[str]]): Path to the prompt wav file; can be
422
+ one or two wav files, which correspond to a merged conversational
423
+ speech or two separate speakers' speech, respectively.
424
+ text (str): Text to be synthesized into a waveform.
425
+ model (torch.nn.Module): The model used for generation.
426
+ vocoder (torch.nn.Module): The vocoder used to convert features to waveforms.
427
+ tokenizer (DialogTokenizer): The tokenizer used to convert text to tokens.
428
+ feature_extractor (VocosFbank): The feature extractor used to
429
+ extract acoustic features.
430
+ device (torch.device): The device on which computations are performed.
431
+ num_step (int, optional): Number of steps for decoding. Defaults to 16.
432
+ guidance_scale (float, optional): Scale for classifier-free guidance.
433
+ Defaults to 1.0.
434
+ speed (float, optional): Speed control. Defaults to 1.0.
435
+ t_shift (float, optional): Time shift. Defaults to 0.5.
436
+ target_rms (float, optional): Target RMS for waveform normalization.
437
+ Defaults to 0.1.
438
+ feat_scale (float, optional): Scale for features.
439
+ Defaults to 0.1.
440
+ sampling_rate (int, optional): Sampling rate for the waveform.
441
+ Defaults to 24000.
442
+ max_duration (float, optional): The maximum duration to process in each
443
+ batch. Used to control memory consumption when generating long audios.
444
+ remove_long_sil (bool, optional): Whether to remove long silences in the
445
+ middle of the generated speech (edge silences will be removed by default).
446
+ Returns:
447
+ metrics (dict): Dictionary containing time and real-time
448
+ factor metrics for processing.
449
+ """
450
+
451
+ # Load and preprocess prompt wav
452
+ if isinstance(prompt_wav, str):
453
+ prompt_wav = [
454
+ prompt_wav,
455
+ ]
456
+ else:
457
+ assert len(prompt_wav) == 2 and isinstance(prompt_wav[0], str)
458
+
459
+ loaded_prompt_wavs = prompt_wav
460
+ for i in range(len(prompt_wav)):
461
+ loaded_prompt_wavs[i] = load_prompt_wav(
462
+ loaded_prompt_wavs[i], sampling_rate=sampling_rate
463
+ )
464
+ if loaded_prompt_wavs[i].size(0) != 1:
465
+ loaded_prompt_wavs[i] = loaded_prompt_wavs[i].mean(0, keepdim=True)
466
+
467
+ if len(loaded_prompt_wavs) == 1:
468
+ prompt_wav = loaded_prompt_wavs[0]
469
+ else:
470
+ prompt_wav = torch.cat(loaded_prompt_wavs, dim=1)
471
+
472
+ # Remove edge and long silences in the prompt wav.
473
+ # Add 0.2s trailing silence to avoid leaking prompt to generated speech.
474
+ prompt_wav = remove_silence(
475
+ prompt_wav, sampling_rate, only_edge=False, trail_sil=200
476
+ )
477
+
478
+ prompt_wav, prompt_rms = rms_norm(prompt_wav, target_rms)
479
+
480
+ prompt_duration = prompt_wav.shape[-1] / sampling_rate
481
+
482
+ if prompt_duration > 40:
483
+ logging.warning(
484
+ f"Given prompt wav is too long ({prompt_duration}s). "
485
+ f"Please provide a shorter one (prompt shorter than 10 "
486
+ f"seconds is recommended)."
487
+ )
488
+ elif prompt_duration > 20:
489
+ logging.warning(
490
+ f"Given prompt wav is long ({prompt_duration}s). "
491
+ f"It will lead to slower inference speed and possibly worse speech quality."
492
+ )
493
+
494
+ # Extract features from prompt wav
495
+ prompt_features = feature_extractor.extract(
496
+ prompt_wav, sampling_rate=sampling_rate
497
+ ).to(device)
498
+
499
+ prompt_features = prompt_features.unsqueeze(0) * feat_scale
500
+
501
+ # Add punctuation in the end if there is not
502
+ text = add_punctuation(text)
503
+ prompt_text = add_punctuation(prompt_text)
504
+
505
+ # Tokenize text (str tokens), punctuations will be preserved.
506
+ tokens_str = tokenizer.texts_to_tokens([text])[0]
507
+ prompt_tokens_str = tokenizer.texts_to_tokens([prompt_text])[0]
508
+
509
+ # chunk text so that each len(prompt wav + generated wav) is around 40 seconds.
510
+ token_duration = (prompt_wav.shape[-1] / sampling_rate) / (
511
+ len(prompt_tokens_str) * speed
512
+ )
513
+ max_tokens = int((40 - prompt_duration) / token_duration)
514
+ chunked_tokens_str = chunk_tokens_dialog(tokens_str, max_tokens=max_tokens)
515
+
516
+ # Tokenize text (int tokens)
517
+ chunked_tokens = tokenizer.tokens_to_token_ids(chunked_tokens_str)
518
+ prompt_tokens = tokenizer.tokens_to_token_ids([prompt_tokens_str])
519
+
520
+ # Batchify chunked texts for faster processing
521
+ tokens_batches, chunked_index = batchify_tokens(
522
+ chunked_tokens, max_duration, prompt_duration, token_duration
523
+ )
524
+
525
+ # Start predicting features
526
+ chunked_features = []
527
+ start_t = dt.datetime.now()
528
+
529
+ for batch_tokens in tokens_batches:
530
+ batch_prompt_tokens = prompt_tokens * len(batch_tokens)
531
+
532
+ batch_prompt_features = prompt_features.repeat(len(batch_tokens), 1, 1)
533
+ batch_prompt_features_lens = torch.full(
534
+ (len(batch_tokens),), prompt_features.size(1), device=device
535
+ )
536
+
537
+ # Generate features
538
+ (
539
+ pred_features,
540
+ pred_features_lens,
541
+ pred_prompt_features,
542
+ pred_prompt_features_lens,
543
+ ) = model.sample(
544
+ tokens=batch_tokens,
545
+ prompt_tokens=batch_prompt_tokens,
546
+ prompt_features=batch_prompt_features,
547
+ prompt_features_lens=batch_prompt_features_lens,
548
+ speed=speed,
549
+ t_shift=t_shift,
550
+ duration="predict",
551
+ num_step=num_step,
552
+ guidance_scale=guidance_scale,
553
+ )
554
+
555
+ # Postprocess predicted features
556
+ pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T)
557
+ chunked_features.append((pred_features, pred_features_lens))
558
+
559
+ # Start vocoder processing
560
+ chunked_wavs = []
561
+ start_vocoder_t = dt.datetime.now()
562
+
563
+ for pred_features, pred_features_lens in chunked_features:
564
+ batch_wav = []
565
+ for i in range(pred_features.size(0)):
566
+
567
+ wav = (
568
+ vocoder.decode(pred_features[i][None, :, : pred_features_lens[i]])
569
+ .squeeze(1)
570
+ .clamp(-1, 1)
571
+ )
572
+ # Adjust wav volume if necessary
573
+ if prompt_rms < target_rms:
574
+ wav = wav * prompt_rms / target_rms
575
+ batch_wav.append(wav)
576
+ chunked_wavs.extend(batch_wav)
577
+
578
+ # Finish model generation
579
+ t = (dt.datetime.now() - start_t).total_seconds()
580
+
581
+ # Merge chunked wavs
582
+ indexed_chunked_wavs = [
583
+ (index, wav) for index, wav in zip(chunked_index, chunked_wavs)
584
+ ]
585
+ sequential_indexed_chunked_wavs = sorted(indexed_chunked_wavs, key=lambda x: x[0])
586
+ sequential_chunked_wavs = [
587
+ sequential_indexed_chunked_wavs[i][1]
588
+ for i in range(len(sequential_indexed_chunked_wavs))
589
+ ]
590
+ final_wav = cross_fade_concat(
591
+ sequential_chunked_wavs, fade_duration=0.1, sample_rate=sampling_rate
592
+ )
593
+ final_wav = remove_silence(
594
+ final_wav, sampling_rate, only_edge=(not remove_long_sil), trail_sil=0
595
+ )
596
+
597
+ # Calculate processing time metrics
598
+ t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
599
+ t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
600
+ wav_seconds = final_wav.shape[-1] / sampling_rate
601
+ rtf = t / wav_seconds
602
+ rtf_no_vocoder = t_no_vocoder / wav_seconds
603
+ rtf_vocoder = t_vocoder / wav_seconds
604
+ metrics = {
605
+ "t": t,
606
+ "t_no_vocoder": t_no_vocoder,
607
+ "t_vocoder": t_vocoder,
608
+ "wav_seconds": wav_seconds,
609
+ "rtf": rtf,
610
+ "rtf_no_vocoder": rtf_no_vocoder,
611
+ "rtf_vocoder": rtf_vocoder,
612
+ }
613
+
614
+ torchaudio.save(save_path, final_wav.cpu(), sample_rate=sampling_rate)
615
+ return metrics
616
+
617
+
618
+ def generate_sentence_stereo_raw_evaluation(
619
+ save_path: str,
620
+ prompt_text: str,
621
+ prompt_wav: Union[str, List[str]],
622
+ text: str,
623
+ model: torch.nn.Module,
624
+ vocoder: torch.nn.Module,
625
+ tokenizer: DialogTokenizer,
626
+ feature_extractor: VocosFbank,
627
+ device: torch.device,
628
+ num_step: int = 16,
629
+ guidance_scale: float = 1.0,
630
+ speed: float = 1.0,
631
+ t_shift: float = 0.5,
632
+ target_rms: float = 0.1,
633
+ feat_scale: float = 0.1,
634
+ sampling_rate: int = 24000,
635
+ silence_wav: Optional[str] = None,
636
+ ):
637
+ """
638
+ Generate the waveform of a text based on a given prompt waveform and its transcription;
639
+ this function directly feeds the prompt_text, prompt_wav, and text to the model.
640
+ It is not efficient and can produce poor results for some inappropriate inputs
641
+ (e.g., the prompt wav contains long silences, or the text to be generated is too long).
642
+ This function can be used to evaluate the "raw" performance of the model.
643
+
644
+ Args:
645
+ save_path (str): Path to save the generated wav.
646
+ prompt_text (str): Transcription of the prompt wav.
647
+ prompt_wav (Union[str, List[str]]): Path to the prompt wav file; can be
648
+ one or two wav files, which correspond to a merged conversational
649
+ speech or two separate speakers' speech, respectively.
650
+ text (str): Text to be synthesized into a waveform.
651
+ model (torch.nn.Module): The model used for generation.
652
+ vocoder (torch.nn.Module): The vocoder used to convert features to waveforms.
653
+ tokenizer (DialogTokenizer): The tokenizer used to convert text to tokens.
654
+ feature_extractor (VocosFbank): The feature extractor used to
655
+ extract acoustic features.
656
+ device (torch.device): The device on which computations are performed.
657
+ num_step (int, optional): Number of steps for decoding. Defaults to 16.
658
+ guidance_scale (float, optional): Scale for classifier-free guidance.
659
+ Defaults to 1.0.
660
+ speed (float, optional): Speed control. Defaults to 1.0.
661
+ t_shift (float, optional): Time shift. Defaults to 0.5.
662
+ target_rms (float, optional): Target RMS for waveform normalization.
663
+ Defaults to 0.1.
664
+ feat_scale (float, optional): Scale for features.
665
+ Defaults to 0.1.
666
+ sampling_rate (int, optional): Sampling rate for the waveform.
667
+ Defaults to 24000.
668
+ silence_wav (str): Path of the silence wav file, used in two-channel
669
+ generation with single-channel prompts.
670
+ Returns:
671
+ metrics (dict): Dictionary containing time and real-time
672
+ factor metrics for processing.
673
+ """
674
+
675
+ # Load and preprocess prompt wav
676
+ if isinstance(prompt_wav, str):
677
+ prompt_wav = [
678
+ prompt_wav,
679
+ ]
680
+ else:
681
+ assert len(prompt_wav) == 2 and isinstance(prompt_wav[0], str)
682
+
683
+ loaded_prompt_wavs = prompt_wav
684
+ for i in range(len(prompt_wav)):
685
+ loaded_prompt_wavs[i] = load_prompt_wav(
686
+ loaded_prompt_wavs[i], sampling_rate=sampling_rate
687
+ )
688
+
689
+ if len(loaded_prompt_wavs) == 1:
690
+ assert (
691
+ loaded_prompt_wavs[0].size(0) == 2
692
+ ), "Merged prompt wav must be stereo for stereo dialogue generation"
693
+ prompt_wav = loaded_prompt_wavs[0]
694
+
695
+ else:
696
+ assert len(loaded_prompt_wavs) == 2
697
+ if loaded_prompt_wavs[0].size(0) == 2:
698
+ prompt_wav = torch.cat(loaded_prompt_wavs, dim=1)
699
+ else:
700
+ assert loaded_prompt_wavs[0].size(0) == 1
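+ # Build a stereo prompt from two mono prompts: channel 0 carries speaker 1's
+ # prompt at the start, channel 1 carries speaker 2's prompt right after it,
+ # and the remaining samples on each channel are taken from the silence wav.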
701
+ silence_wav, silence_sampling_rate = torchaudio.load(silence_wav)
702
+ assert silence_sampling_rate == sampling_rate
703
+ prompt_wav = silence_wav[
704
+ :, : loaded_prompt_wavs[0].size(1) + loaded_prompt_wavs[1].size(1)
705
+ ]
706
+ prompt_wav[0, : loaded_prompt_wavs[0].size(1)] = loaded_prompt_wavs[0]
707
+ prompt_wav[1, loaded_prompt_wavs[0].size(1) :] = loaded_prompt_wavs[1]
708
+
709
+ prompt_wav, prompt_rms = rms_norm(prompt_wav, target_rms)
710
+
711
+ # Extract features from prompt wav
712
+ prompt_features = feature_extractor.extract(
713
+ prompt_wav, sampling_rate=sampling_rate
714
+ ).to(device)
715
+
716
+ prompt_features = prompt_features.unsqueeze(0) * feat_scale
717
+ prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)
718
+
719
+ # Convert text to tokens
720
+ tokens = tokenizer.texts_to_token_ids([text])
721
+ prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])
722
+
723
+ # Start timing
724
+ start_t = dt.datetime.now()
725
+
726
+ # Generate features
727
+ (
728
+ pred_features,
729
+ pred_features_lens,
730
+ pred_prompt_features,
731
+ pred_prompt_features_lens,
732
+ ) = model.sample(
733
+ tokens=tokens,
734
+ prompt_tokens=prompt_tokens,
735
+ prompt_features=prompt_features,
736
+ prompt_features_lens=prompt_features_lens,
737
+ speed=speed,
738
+ t_shift=t_shift,
739
+ duration="predict",
740
+ num_step=num_step,
741
+ guidance_scale=guidance_scale,
742
+ )
743
+
744
+ # Postprocess predicted features
745
+ pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T)
746
+
747
+ # Start vocoder processing
748
+ start_vocoder_t = dt.datetime.now()
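+ # The stereo model stacks the two channels' fbank features along the feature
+ # axis, so the first feat_dim rows are decoded as the left channel and the
+ # next feat_dim rows as the right channel.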
749
+ feat_dim = pred_features.size(1) // 2
750
+ wav_left = vocoder.decode(pred_features[:, :feat_dim]).squeeze(1).clamp(-1, 1)
751
+ wav_right = (
752
+ vocoder.decode(pred_features[:, feat_dim : feat_dim * 2])
753
+ .squeeze(1)
754
+ .clamp(-1, 1)
755
+ )
756
+
757
+ wav = torch.cat([wav_left, wav_right], dim=0)
758
+
759
+ # Calculate processing times and real-time factors
760
+ t = (dt.datetime.now() - start_t).total_seconds()
761
+ t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
762
+ t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
763
+ wav_seconds = wav.shape[-1] / sampling_rate
764
+ rtf = t / wav_seconds
765
+ rtf_no_vocoder = t_no_vocoder / wav_seconds
766
+ rtf_vocoder = t_vocoder / wav_seconds
767
+ metrics = {
768
+ "t": t,
769
+ "t_no_vocoder": t_no_vocoder,
770
+ "t_vocoder": t_vocoder,
771
+ "wav_seconds": wav_seconds,
772
+ "rtf": rtf,
773
+ "rtf_no_vocoder": rtf_no_vocoder,
774
+ "rtf_vocoder": rtf_vocoder,
775
+ }
776
+
777
+ # Adjust wav volume if necessary
778
+ if prompt_rms < target_rms:
779
+ wav = wav * prompt_rms / target_rms
780
+ torchaudio.save(save_path, wav.cpu(), sample_rate=sampling_rate)
781
+
782
+ return metrics
783
+
784
+
785
+ def generate_sentence_stereo(
786
+ save_path: str,
787
+ prompt_text: str,
788
+ prompt_wav: Union[str, List[str]],
789
+ text: str,
790
+ model: torch.nn.Module,
791
+ vocoder: torch.nn.Module,
792
+ tokenizer: DialogTokenizer,
793
+ feature_extractor: VocosFbank,
794
+ device: torch.device,
795
+ num_step: int = 16,
796
+ guidance_scale: float = 1.0,
797
+ speed: float = 1.0,
798
+ t_shift: float = 0.5,
799
+ target_rms: float = 0.1,
800
+ feat_scale: float = 0.1,
801
+ sampling_rate: int = 24000,
802
+ silence_wav: Optional[str] = None,
803
+ max_duration: float = 100,
804
+ remove_long_sil: bool = False,
805
+ ):
806
+ """
807
+ Generate the waveform of a text based on a given prompt waveform and its transcription;
808
+ this function does the following to improve the generation quality:
809
+ 1. chunk the text according to the speaker-turn symbol [S1].
810
+ 2. process chunked texts in batches.
811
+ 3. remove long silences in the prompt audio.
812
+ 4. add punctuation to the end of the prompt text and the text if it is missing.
813
+
814
+ Args:
815
+ save_path (str): Path to save the generated wav.
816
+ prompt_text (str): Transcription of the prompt wav.
817
+ prompt_wav (Union[str, List[str]]): Path to the prompt wav file; can be
818
+ one or two wav files, which correspond to a merged conversational
819
+ speech or two separate speakers' speech, respectively.
820
+ text (str): Text to be synthesized into a waveform.
821
+ model (torch.nn.Module): The model used for generation.
822
+ vocoder (torch.nn.Module): The vocoder used to convert features to waveforms.
823
+ tokenizer (DialogTokenizer): The tokenizer used to convert text to tokens.
824
+ feature_extractor (VocosFbank): The feature extractor used to
825
+ extract acoustic features.
826
+ device (torch.device): The device on which computations are performed.
827
+ num_step (int, optional): Number of steps for decoding. Defaults to 16.
828
+ guidance_scale (float, optional): Scale for classifier-free guidance.
829
+ Defaults to 1.0.
830
+ speed (float, optional): Speed control. Defaults to 1.0.
831
+ t_shift (float, optional): Time shift. Defaults to 0.5.
832
+ target_rms (float, optional): Target RMS for waveform normalization.
833
+ Defaults to 0.1.
834
+ feat_scale (float, optional): Scale for features.
835
+ Defaults to 0.1.
836
+ sampling_rate (int, optional): Sampling rate for the waveform.
837
+ Defaults to 24000.
838
+ silence_wav (str): Path of the silence wav file, used in two-channel
839
+ generation with single-channel prompts.
840
+ max_duration (float, optional): The maximum duration to process in each
841
+ batch. Used to control memory consumption when generating long audios.
842
+ remove_long_sil (bool, optional): Whether to remove long silences in the
843
+ middle of the generated speech (edge silences will be removed by default).
844
+ Returns:
845
+ metrics (dict): Dictionary containing time and real-time
846
+ factor metrics for processing.
847
+ """
848
+
849
+ # Load and preprocess prompt wav
850
+ if isinstance(prompt_wav, str):
851
+ prompt_wav = [
852
+ prompt_wav,
853
+ ]
854
+ else:
855
+ assert len(prompt_wav) == 2 and isinstance(prompt_wav[0], str)
856
+
857
+ loaded_prompt_wavs = prompt_wav
858
+ for i in range(len(prompt_wav)):
859
+ loaded_prompt_wavs[i] = load_prompt_wav(
860
+ loaded_prompt_wavs[i], sampling_rate=sampling_rate
861
+ )
862
+
863
+ if len(loaded_prompt_wavs) == 1:
864
+ assert (
865
+ loaded_prompt_wavs[0].size(0) == 2
866
+ ), "Merged prompt wav must be stereo for stereo dialogue generation"
867
+ prompt_wav = loaded_prompt_wavs[0]
868
+
869
+ else:
870
+ assert len(loaded_prompt_wavs) == 2
871
+ if loaded_prompt_wavs[0].size(0) == 2:
872
+ prompt_wav = torch.cat(loaded_prompt_wavs, dim=1)
873
+ else:
874
+ assert loaded_prompt_wavs[0].size(0) == 1
875
+ silence_wav, silence_sampling_rate = torchaudio.load(silence_wav)
876
+ assert silence_sampling_rate == sampling_rate
877
+ prompt_wav = silence_wav[
878
+ :, : loaded_prompt_wavs[0].size(1) + loaded_prompt_wavs[1].size(1)
879
+ ]
880
+ prompt_wav[0, : loaded_prompt_wavs[0].size(1)] = loaded_prompt_wavs[0]
881
+ prompt_wav[1, loaded_prompt_wavs[0].size(1) :] = loaded_prompt_wavs[1]
882
+
883
+ # Remove edge and long silences in the prompt wav.
884
+ # Add 0.2s trailing silence to avoid leaking prompt to generated speech.
885
+ prompt_wav = remove_silence(
886
+ prompt_wav, sampling_rate, only_edge=False, trail_sil=200
887
+ )
888
+
889
+ prompt_wav, prompt_rms = rms_norm(prompt_wav, target_rms)
890
+
891
+ prompt_duration = prompt_wav.shape[-1] / sampling_rate
892
+
893
+ if prompt_duration > 40:
894
+ logging.warning(
895
+ f"Given prompt wav is too long ({prompt_duration}s). "
896
+ f"Please provide a shorter one (prompt shorter than 10 "
897
+ f"seconds is recommended)."
898
+ )
899
+ elif prompt_duration > 20:
900
+ logging.warning(
901
+ f"Given prompt wav is long ({prompt_duration}s). "
902
+ f"It will lead to slower inference speed and possibly worse speech quality."
903
+ )
904
+
905
+ # Extract features from prompt wav
906
+ prompt_features = feature_extractor.extract(
907
+ prompt_wav, sampling_rate=sampling_rate
908
+ ).to(device)
909
+
910
+ prompt_features = prompt_features.unsqueeze(0) * feat_scale
911
+
912
+ # Add punctuation at the end if there is none
913
+ text = add_punctuation(text)
914
+ prompt_text = add_punctuation(prompt_text)
915
+
916
+ # Tokenize text (str tokens); punctuation will be preserved.
917
+ tokens_str = tokenizer.texts_to_tokens([text])[0]
918
+ prompt_tokens_str = tokenizer.texts_to_tokens([prompt_text])[0]
919
+
920
+ # Chunk the text so that each (prompt wav + generated wav) segment is around 40 seconds.
921
+ token_duration = (prompt_wav.shape[-1] / sampling_rate) / (
922
+ len(prompt_tokens_str) * speed
923
+ )
924
+ max_tokens = int((40 - prompt_duration) / token_duration)
925
+ chunked_tokens_str = chunk_tokens_dialog(tokens_str, max_tokens=max_tokens)
926
+
927
+ # Tokenize text (int tokens)
928
+ chunked_tokens = tokenizer.tokens_to_token_ids(chunked_tokens_str)
929
+ prompt_tokens = tokenizer.tokens_to_token_ids([prompt_tokens_str])
930
+
931
+ # Batchify chunked texts for faster processing
932
+ tokens_batches, chunked_index = batchify_tokens(
933
+ chunked_tokens, max_duration, prompt_duration, token_duration
934
+ )
935
+
936
+ # Start predicting features
937
+ chunked_features = []
938
+ start_t = dt.datetime.now()
939
+
940
+ for batch_tokens in tokens_batches:
941
+ batch_prompt_tokens = prompt_tokens * len(batch_tokens)
942
+
943
+ batch_prompt_features = prompt_features.repeat(len(batch_tokens), 1, 1)
944
+ batch_prompt_features_lens = torch.full(
945
+ (len(batch_tokens),), prompt_features.size(1), device=device
946
+ )
947
+
948
+ # Generate features
949
+ (
950
+ pred_features,
951
+ pred_features_lens,
952
+ pred_prompt_features,
953
+ pred_prompt_features_lens,
954
+ ) = model.sample(
955
+ tokens=batch_tokens,
956
+ prompt_tokens=batch_prompt_tokens,
957
+ prompt_features=batch_prompt_features,
958
+ prompt_features_lens=batch_prompt_features_lens,
959
+ speed=speed,
960
+ t_shift=t_shift,
961
+ duration="predict",
962
+ num_step=num_step,
963
+ guidance_scale=guidance_scale,
964
+ )
965
+
966
+ # Postprocess predicted features
967
+ pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T)
968
+ chunked_features.append((pred_features, pred_features_lens))
969
+
970
+ # Start vocoder processing
971
+ chunked_wavs = []
972
+ start_vocoder_t = dt.datetime.now()
973
+
974
+ for pred_features, pred_features_lens in chunked_features:
975
+ batch_wav = []
976
+ for i in range(pred_features.size(0)):
977
+
978
+ feat_dim = pred_features.size(1) // 2
979
+ wav_left = (
980
+ vocoder.decode(
981
+ pred_features[i][None, :feat_dim, : pred_features_lens[i]]
982
+ )
983
+ .squeeze(1)
984
+ .clamp(-1, 1)
985
+ )
986
+ wav_right = (
987
+ vocoder.decode(
988
+ pred_features[i][
989
+ None, feat_dim : feat_dim * 2, : pred_features_lens[i]
990
+ ]
991
+ )
992
+ .squeeze(1)
993
+ .clamp(-1, 1)
994
+ )
995
+ wav = torch.cat([wav_left, wav_right], dim=0)
996
+
997
+ # Adjust wav volume if necessary
998
+ if prompt_rms < target_rms:
999
+ wav = wav * prompt_rms / target_rms
1000
+ batch_wav.append(wav)
1001
+ chunked_wavs.extend(batch_wav)
1002
+
1003
+ # Finish model generation
1004
+ t = (dt.datetime.now() - start_t).total_seconds()
1005
+
1006
+ # Merge chunked wavs
1007
+ indexed_chunked_wavs = [
1008
+ (index, wav) for index, wav in zip(chunked_index, chunked_wavs)
1009
+ ]
1010
+ sequential_indexed_chunked_wavs = sorted(indexed_chunked_wavs, key=lambda x: x[0])
1011
+ sequential_chunked_wavs = [
1012
+ sequential_indexed_chunked_wavs[i][1]
1013
+ for i in range(len(sequential_indexed_chunked_wavs))
1014
+ ]
1015
+ final_wav = cross_fade_concat(
1016
+ sequential_chunked_wavs, fade_duration=0.1, sample_rate=sampling_rate
1017
+ )
1018
+ final_wav = remove_silence(
1019
+ final_wav, sampling_rate, only_edge=(not remove_long_sil), trail_sil=0
1020
+ )
1021
+
1022
+ # Calculate processing time metrics
1023
+ t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
1024
+ t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
1025
+ wav_seconds = final_wav.shape[-1] / sampling_rate
1026
+ rtf = t / wav_seconds
1027
+ rtf_no_vocoder = t_no_vocoder / wav_seconds
1028
+ rtf_vocoder = t_vocoder / wav_seconds
1029
+ metrics = {
1030
+ "t": t,
1031
+ "t_no_vocoder": t_no_vocoder,
1032
+ "t_vocoder": t_vocoder,
1033
+ "wav_seconds": wav_seconds,
1034
+ "rtf": rtf,
1035
+ "rtf_no_vocoder": rtf_no_vocoder,
1036
+ "rtf_vocoder": rtf_vocoder,
1037
+ }
1038
+
1039
+ torchaudio.save(save_path, final_wav.cpu(), sample_rate=sampling_rate)
1040
+ return metrics
1041
+
1042
+
1043
+ def generate_list(
1044
+ model_name: str,
1045
+ res_dir: str,
1046
+ test_list: str,
1047
+ model: torch.nn.Module,
1048
+ vocoder: torch.nn.Module,
1049
+ tokenizer: DialogTokenizer,
1050
+ feature_extractor: VocosFbank,
1051
+ device: torch.device,
1052
+ num_step: int = 16,
1053
+ guidance_scale: float = 1.5,
1054
+ speed: float = 1.0,
1055
+ t_shift: float = 0.5,
1056
+ target_rms: float = 0.1,
1057
+ feat_scale: float = 0.1,
1058
+ sampling_rate: int = 24000,
1059
+ silence_wav: Optional[str] = None,
1060
+ raw_evaluation: bool = False,
1061
+ max_duration: float = 100,
1062
+ remove_long_sil: bool = False,
1063
+ ):
1064
+ total_t = []
1065
+ total_t_no_vocoder = []
1066
+ total_t_vocoder = []
1067
+ total_wav_seconds = []
1068
+
1069
+ with open(test_list, "r") as fr:
1070
+ lines = fr.readlines()
1071
+
1072
+ for i, line in enumerate(lines):
1073
+ items = line.strip().split("\t")
1074
+ if len(items) == 6:
1075
+ (
1076
+ wav_name,
1077
+ prompt_text_1,
1078
+ prompt_text_2,
1079
+ prompt_wav_1,
1080
+ prompt_wav_2,
1081
+ text,
1082
+ ) = items
1083
+ prompt_text = f"[S1]{prompt_text_1}[S2]{prompt_text_2}"
1084
+ prompt_wav = [prompt_wav_1, prompt_wav_2]
1085
+ elif len(items) == 4:
1086
+ wav_name, prompt_text, prompt_wav, text = items
1087
+ else:
1088
+ raise ValueError(f"Invalid line: {line}")
1089
+ assert text.startswith("[S1]")
1090
+
1091
+ save_path = f"{res_dir}/{wav_name}.wav"
1092
+
1093
+ common_params = {
1094
+ "save_path": save_path,
1095
+ "prompt_text": prompt_text,
1096
+ "prompt_wav": prompt_wav,
1097
+ "text": text,
1098
+ "model": model,
1099
+ "vocoder": vocoder,
1100
+ "tokenizer": tokenizer,
1101
+ "feature_extractor": feature_extractor,
1102
+ "device": device,
1103
+ "num_step": num_step,
1104
+ "guidance_scale": guidance_scale,
1105
+ "speed": speed,
1106
+ "t_shift": t_shift,
1107
+ "target_rms": target_rms,
1108
+ "feat_scale": feat_scale,
1109
+ "sampling_rate": sampling_rate,
1110
+ }
1111
+
1112
+ if model_name == "zipvoice_dialog":
1113
+ if raw_evaluation:
1114
+ metrics = generate_sentence_raw_evaluation(**common_params)
1115
+ else:
1116
+ metrics = generate_sentence(
1117
+ **common_params,
1118
+ max_duration=max_duration,
1119
+ remove_long_sil=remove_long_sil,
1120
+ )
1121
+ else:
1122
+ assert model_name == "zipvoice_dialog_stereo"
1123
+ if raw_evaluation:
1124
+ metrics = generate_sentence_stereo_raw_evaluation(
1125
+ **common_params,
1126
+ silence_wav=silence_wav,
1127
+ )
1128
+ else:
1129
+ metrics = generate_sentence_stereo(
1130
+ **common_params,
1131
+ silence_wav=silence_wav,
1132
+ max_duration=max_duration,
1133
+ remove_long_sil=remove_long_sil,
1134
+ )
1135
+ logging.info(f"[Sentence: {i}] Saved to: {save_path}")
1136
+ logging.info(f"[Sentence: {i}] RTF: {metrics['rtf']:.4f}")
1137
+ total_t.append(metrics["t"])
1138
+ total_t_no_vocoder.append(metrics["t_no_vocoder"])
1139
+ total_t_vocoder.append(metrics["t_vocoder"])
1140
+ total_wav_seconds.append(metrics["wav_seconds"])
1141
+
1142
+ logging.info(f"Average RTF: {np.sum(total_t) / np.sum(total_wav_seconds):.4f}")
1143
+ logging.info(
1144
+ f"Average RTF w/o vocoder: "
1145
+ f"{np.sum(total_t_no_vocoder) / np.sum(total_wav_seconds):.4f}"
1146
+ )
1147
+ logging.info(
1148
+ f"Average RTF vocoder: "
1149
+ f"{np.sum(total_t_vocoder) / np.sum(total_wav_seconds):.4f}"
1150
+ )
1151
+
1152
+
1153
+ @torch.inference_mode()
1154
+ def main():
1155
+ parser = get_parser()
1156
+ args = parser.parse_args()
1157
+
1158
+ torch.set_num_threads(args.num_thread)
1159
+ torch.set_num_interop_threads(args.num_thread)
1160
+
1161
+ params = AttributeDict()
1162
+ params.update(vars(args))
1163
+ fix_random_seed(params.seed)
1164
+
1165
+ assert (
1166
+ params.test_list is not None
1167
+ ), "For inference, please provide prompts and text with '--test-list'"
1168
+
1169
+ if params.model_dir is not None:
1170
+ params.model_dir = Path(params.model_dir)
1171
+ if not params.model_dir.is_dir():
1172
+ raise FileNotFoundError(f"{params.model_dir} does not exist")
1173
+ for filename in [params.checkpoint_name, "model.json", "tokens.txt"]:
1174
+ if not (params.model_dir / filename).is_file():
1175
+ raise FileNotFoundError(f"{params.model_dir / filename} does not exist")
1176
+ model_ckpt = params.model_dir / params.checkpoint_name
1177
+ model_config = params.model_dir / "model.json"
1178
+ token_file = params.model_dir / "tokens.txt"
1179
+ logging.info(
1180
+ f"Using {params.model_name} in local model dir {params.model_dir}, "
1181
+ f"checkpoint {params.checkpoint_name}"
1182
+ )
1183
+ else:
1184
+ logging.info(f"Using pretrained {params.model_name} model from the Huggingface")
1185
+ model_ckpt = hf_hub_download(
1186
+ HUGGINGFACE_REPO, filename=f"{MODEL_DIR[params.model_name]}/model.pt"
1187
+ )
1188
+ model_config = hf_hub_download(
1189
+ HUGGINGFACE_REPO, filename=f"{MODEL_DIR[params.model_name]}/model.json"
1190
+ )
1191
+
1192
+ token_file = hf_hub_download(
1193
+ HUGGINGFACE_REPO, filename=f"{MODEL_DIR[params.model_name]}/tokens.txt"
1194
+ )
1195
+
1196
+ tokenizer = DialogTokenizer(token_file=token_file)
1197
+
1198
+ tokenizer_config = {
1199
+ "vocab_size": tokenizer.vocab_size,
1200
+ "pad_id": tokenizer.pad_id,
1201
+ "spk_a_id": tokenizer.spk_a_id,
1202
+ "spk_b_id": tokenizer.spk_b_id,
1203
+ }
1204
+
1205
+ with open(model_config, "r") as f:
1206
+ model_config = json.load(f)
1207
+
1208
+ if params.model_name == "zipvoice_dialog":
1209
+ model = ZipVoiceDialog(
1210
+ **model_config["model"],
1211
+ **tokenizer_config,
1212
+ )
1213
+ else:
1214
+ assert params.model_name == "zipvoice_dialog_stereo"
1215
+ model = ZipVoiceDialogStereo(
1216
+ **model_config["model"],
1217
+ **tokenizer_config,
1218
+ )
1219
+
1220
+ if str(model_ckpt).endswith(".safetensors"):
1221
+ safetensors.torch.load_model(model, model_ckpt)
1222
+ elif str(model_ckpt).endswith(".pt"):
1223
+ load_checkpoint(filename=model_ckpt, model=model, strict=True)
1224
+ else:
1225
+ raise NotImplementedError(f"Unsupported model checkpoint format: {model_ckpt}")
1226
+
1227
+ if torch.cuda.is_available():
1228
+ params.device = torch.device("cuda", 0)
1229
+ elif torch.backends.mps.is_available():
1230
+ params.device = torch.device("mps")
1231
+ else:
1232
+ params.device = torch.device("cpu")
1233
+ logging.info(f"Device: {params.device}")
1234
+
1235
+ model = model.to(params.device)
1236
+ model.eval()
1237
+
1238
+ vocoder = get_vocoder(params.vocoder_path)
1239
+ vocoder = vocoder.to(params.device)
1240
+ vocoder.eval()
1241
+
1242
+ if model_config["feature"]["type"] == "vocos":
1243
+ if params.model_name == "zipvoice_dialog":
1244
+ num_channels = 1
1245
+ else:
1246
+ assert params.model_name == "zipvoice_dialog_stereo"
1247
+ num_channels = 2
1248
+ feature_extractor = VocosFbank(num_channels=num_channels)
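+ # With num_channels=2, the extracted fbank covers both channels of the stereo
+ # signal; the features are split back per channel at the vocoder stage above.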
1249
+ else:
1250
+ raise NotImplementedError(
1251
+ f"Unsupported feature type: {model_config['feature']['type']}"
1252
+ )
1253
+ params.sampling_rate = model_config["feature"]["sampling_rate"]
1254
+
1255
+ logging.info("Start generating...")
1256
+ os.makedirs(params.res_dir, exist_ok=True)
1257
+ generate_list(
1258
+ model_name=params.model_name,
1259
+ res_dir=params.res_dir,
1260
+ test_list=params.test_list,
1261
+ model=model,
1262
+ vocoder=vocoder,
1263
+ tokenizer=tokenizer,
1264
+ feature_extractor=feature_extractor,
1265
+ device=params.device,
1266
+ num_step=params.num_step,
1267
+ guidance_scale=params.guidance_scale,
1268
+ speed=params.speed,
1269
+ t_shift=params.t_shift,
1270
+ target_rms=params.target_rms,
1271
+ feat_scale=params.feat_scale,
1272
+ sampling_rate=params.sampling_rate,
1273
+ silence_wav=params.silence_wav,
1274
+ raw_evaluation=params.raw_evaluation,
1275
+ max_duration=params.max_duration,
1276
+ remove_long_sil=params.remove_long_sil,
1277
+ )
1278
+ logging.info("Done")
1279
+
1280
+
1281
+ if __name__ == "__main__":
1282
+
1283
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
1284
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
1285
+
1286
+ main()
zipvoice/bin/infer_zipvoice_onnx.py ADDED
@@ -0,0 +1,924 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2025 Xiaomi Corp. (authors: Han Zhu,
3
+ # Zengwei Yao)
4
+ #
5
+ # See ../../../../LICENSE for clarification regarding multiple authors
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+ """
20
+ This script generates speech with our pre-trained ZipVoice or ZipVoice-Distill
21
+ ONNX models. If no local model is specified,
22
+ the required files will be automatically downloaded from HuggingFace.
23
+
24
+ Usage:
25
+
26
+ Note: If you have trouble connecting to HuggingFace,
27
+ try switching the endpoint to the mirror site:
28
+ export HF_ENDPOINT=https://hf-mirror.com
29
+
30
+ (1) Inference of a single sentence:
31
+
32
+ python3 -m zipvoice.bin.infer_zipvoice_onnx \
33
+ --onnx-int8 False \
34
+ --model-name zipvoice \
35
+ --prompt-wav prompt.wav \
36
+ --prompt-text "I am a prompt." \
37
+ --text "I am a sentence." \
38
+ --res-wav-path result.wav
39
+
40
+ (2) Inference of a list of sentences:
41
+ python3 -m zipvoice.bin.infer_zipvoice_onnx \
42
+ --onnx-int8 False \
43
+ --model-name zipvoice \
44
+ --test-list test.tsv \
45
+ --res-dir results
46
+
47
+ `--model-name` can be `zipvoice` or `zipvoice_distill`,
48
+ which are the models before and after distillation, respectively.
49
+
50
+ Each line of `test.tsv` is in the format of
51
+ `{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}`.
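+ For illustration only (names are hypothetical), one such line could be:
+ 'utt_001\tI am a prompt.\tprompt.wav\tI am a sentence to synthesize.'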
52
+
53
+ Set `--onnx-int8 True` to use the int8-quantized ONNX model.
54
+ The quantized model is faster but produces lower-quality speech.
55
+ """
56
+
57
+ import argparse
58
+ import datetime as dt
59
+ import json
60
+ import logging
61
+ import os
62
+ from pathlib import Path
63
+ from typing import List, Tuple
64
+
65
+ import numpy as np
66
+ import onnxruntime as ort
67
+ import torch
68
+ import torchaudio
69
+ from huggingface_hub import hf_hub_download
70
+ from lhotse.utils import fix_random_seed
71
+ from torch import Tensor, nn
72
+
73
+ from zipvoice.bin.infer_zipvoice import get_vocoder
74
+ from zipvoice.models.modules.solver import get_time_steps
75
+ from zipvoice.tokenizer.tokenizer import (
76
+ EmiliaTokenizer,
77
+ EspeakTokenizer,
78
+ LibriTTSTokenizer,
79
+ SimpleTokenizer,
80
+ )
81
+ from zipvoice.utils.common import AttributeDict, str2bool
82
+ from zipvoice.utils.feature import VocosFbank
83
+ from zipvoice.utils.infer import (
84
+ add_punctuation,
85
+ chunk_tokens_punctuation,
86
+ cross_fade_concat,
87
+ load_prompt_wav,
88
+ remove_silence,
89
+ rms_norm,
90
+ )
91
+
92
+ HUGGINGFACE_REPO = "k2-fsa/ZipVoice"
93
+ MODEL_DIR = {
94
+ "zipvoice": "zipvoice",
95
+ "zipvoice_distill": "zipvoice_distill",
96
+ }
97
+
98
+
99
+ def get_parser():
100
+ parser = argparse.ArgumentParser(
101
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
102
+ )
103
+
104
+ parser.add_argument(
105
+ "--onnx-int8",
106
+ type=str2bool,
107
+ default=False,
108
+ help="Whether to use the int8 model",
109
+ )
110
+
111
+ parser.add_argument(
112
+ "--model-name",
113
+ type=str,
114
+ default="zipvoice",
115
+ choices=["zipvoice", "zipvoice_distill"],
116
+ help="The model used for inference",
117
+ )
118
+
119
+ parser.add_argument(
120
+ "--model-dir",
121
+ type=str,
122
+ default=None,
123
+ help="The path to the local onnx model. "
124
+ "Will download pre-trained checkpoint from huggingface if not specified.",
125
+ )
126
+
127
+ parser.add_argument(
128
+ "--vocoder-path",
129
+ type=str,
130
+ default=None,
131
+ help="The vocoder checkpoint. "
132
+ "Will download pre-trained vocoder from huggingface if not specified.",
133
+ )
134
+
135
+ parser.add_argument(
136
+ "--tokenizer",
137
+ type=str,
138
+ default="emilia",
139
+ choices=["emilia", "libritts", "espeak", "simple"],
140
+ help="Tokenizer type.",
141
+ )
142
+
143
+ parser.add_argument(
144
+ "--lang",
145
+ type=str,
146
+ default="en-us",
147
+ help="Language identifier, used when tokenizer type is espeak. see"
148
+ "https://github.com/rhasspy/espeak-ng/blob/master/docs/languages.md",
149
+ )
150
+
151
+ parser.add_argument(
152
+ "--test-list",
153
+ type=str,
154
+ default=None,
155
+ help="The list of prompt speech, prompt_transcription, "
156
+ "and text to synthesizein the format of "
157
+ "'{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}'.",
158
+ )
159
+
160
+ parser.add_argument(
161
+ "--prompt-wav",
162
+ type=str,
163
+ default=None,
164
+ help="The prompt wav to mimic",
165
+ )
166
+
167
+ parser.add_argument(
168
+ "--prompt-text",
169
+ type=str,
170
+ default=None,
171
+ help="The transcription of the prompt wav",
172
+ )
173
+
174
+ parser.add_argument(
175
+ "--text",
176
+ type=str,
177
+ default=None,
178
+ help="The text to synthesize",
179
+ )
180
+
181
+ parser.add_argument(
182
+ "--res-dir",
183
+ type=str,
184
+ default="results",
185
+ help="""
186
+ Path name of the generated wavs dir,
187
+ used when test-list is not None
188
+ """,
189
+ )
190
+
191
+ parser.add_argument(
192
+ "--res-wav-path",
193
+ type=str,
194
+ default="result.wav",
195
+ help="""
196
+ Path name of the generated wav path,
197
+ used when test-list is None
198
+ """,
199
+ )
200
+
201
+ parser.add_argument(
202
+ "--guidance-scale",
203
+ type=float,
204
+ default=None,
205
+ help="The scale of classifier-free guidance during inference.",
206
+ )
207
+
208
+ parser.add_argument(
209
+ "--num-step",
210
+ type=int,
211
+ default=None,
212
+ help="The number of sampling steps.",
213
+ )
214
+
215
+ parser.add_argument(
216
+ "--feat-scale",
217
+ type=float,
218
+ default=0.1,
219
+ help="The scale factor of fbank feature",
220
+ )
221
+
222
+ parser.add_argument(
223
+ "--speed",
224
+ type=float,
225
+ default=1.0,
226
+ help="Control speech speed, 1.0 means normal, >1.0 means speed up",
227
+ )
228
+
229
+ parser.add_argument(
230
+ "--t-shift",
231
+ type=float,
232
+ default=0.5,
233
+ help="Shift t to smaller ones if t_shift < 1.0",
234
+ )
235
+
236
+ parser.add_argument(
237
+ "--target-rms",
238
+ type=float,
239
+ default=0.1,
240
+ help="Target speech normalization rms value, set to 0 to disable normalization",
241
+ )
242
+
243
+ parser.add_argument(
244
+ "--seed",
245
+ type=int,
246
+ default=666,
247
+ help="Random seed",
248
+ )
249
+
250
+ parser.add_argument(
251
+ "--num-thread",
252
+ type=int,
253
+ default=1,
254
+ help="Number of threads to use for ONNX Runtime and PyTorch.",
255
+ )
256
+
257
+ parser.add_argument(
258
+ "--raw-evaluation",
259
+ type=str2bool,
260
+ default=False,
261
+ help="Whether to use the 'raw' evaluation mode where provided "
262
+ "prompts and text are fed to the model without pre-processing",
263
+ )
264
+
265
+ parser.add_argument(
266
+ "--remove-long-sil",
267
+ type=str2bool,
268
+ default=False,
269
+ help="Whether to remove long silences in the middle of the generated "
270
+ "speech (edge silences will be removed by default).",
271
+ )
272
+ return parser
273
+
274
+
275
+ class OnnxModel:
276
+ def __init__(
277
+ self,
278
+ text_encoder_path: str,
279
+ fm_decoder_path: str,
280
+ num_thread: int = 1,
281
+ ):
282
+ session_opts = ort.SessionOptions()
283
+ session_opts.inter_op_num_threads = num_thread
284
+ session_opts.intra_op_num_threads = num_thread
285
+
286
+ self.session_opts = session_opts
287
+
288
+ self.init_text_encoder(text_encoder_path)
289
+ self.init_fm_decoder(fm_decoder_path)
290
+
291
+ def init_text_encoder(self, model_path: str):
292
+ self.text_encoder = ort.InferenceSession(
293
+ model_path,
294
+ sess_options=self.session_opts,
295
+ providers=["CPUExecutionProvider"],
296
+ )
297
+
298
+ def init_fm_decoder(self, model_path: str):
299
+ self.fm_decoder = ort.InferenceSession(
300
+ model_path,
301
+ sess_options=self.session_opts,
302
+ providers=["CPUExecutionProvider"],
303
+ )
304
+ meta = self.fm_decoder.get_modelmeta().custom_metadata_map
305
+ self.feat_dim = int(meta["feat_dim"])
306
+
307
+ def run_text_encoder(
308
+ self,
309
+ tokens: Tensor,
310
+ prompt_tokens: Tensor,
311
+ prompt_features_len: Tensor,
312
+ speed: Tensor,
313
+ ) -> Tuple[Tensor, Tensor]:
314
+ out = self.text_encoder.run(
315
+ [
316
+ self.text_encoder.get_outputs()[0].name,
317
+ ],
318
+ {
319
+ self.text_encoder.get_inputs()[0].name: tokens.numpy(),
320
+ self.text_encoder.get_inputs()[1].name: prompt_tokens.numpy(),
321
+ self.text_encoder.get_inputs()[2].name: prompt_features_len.numpy(),
322
+ self.text_encoder.get_inputs()[3].name: speed.numpy(),
323
+ },
324
+ )
325
+ return torch.from_numpy(out[0])
326
+
327
+ def run_fm_decoder(
328
+ self,
329
+ t: Tensor,
330
+ x: Tensor,
331
+ text_condition: Tensor,
332
+ speech_condition: torch.Tensor,
333
+ guidance_scale: Tensor,
334
+ ) -> Tensor:
335
+ out = self.fm_decoder.run(
336
+ [
337
+ self.fm_decoder.get_outputs()[0].name,
338
+ ],
339
+ {
340
+ self.fm_decoder.get_inputs()[0].name: t.numpy(),
341
+ self.fm_decoder.get_inputs()[1].name: x.numpy(),
342
+ self.fm_decoder.get_inputs()[2].name: text_condition.numpy(),
343
+ self.fm_decoder.get_inputs()[3].name: speech_condition.numpy(),
344
+ self.fm_decoder.get_inputs()[4].name: guidance_scale.numpy(),
345
+ },
346
+ )
347
+ return torch.from_numpy(out[0])
348
+
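+ # Illustrative construction (paths are hypothetical; the actual ONNX file
+ # names come from the exported or downloaded model directory):
+ #     model = OnnxModel(
+ #         text_encoder_path="exp/text_encoder.onnx",
+ #         fm_decoder_path="exp/fm_decoder.onnx",
+ #         num_thread=1,
+ #     )
+ #     # model.feat_dim is read from the fm_decoder's ONNX metadata.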
349
+
350
+ def sample(
351
+ model: OnnxModel,
352
+ tokens: List[List[int]],
353
+ prompt_tokens: List[List[int]],
354
+ prompt_features: Tensor,
355
+ speed: float = 1.0,
356
+ t_shift: float = 0.5,
357
+ guidance_scale: float = 1.0,
358
+ num_step: int = 16,
359
+ ) -> torch.Tensor:
360
+ """
361
+ Generate acoustic features, given text tokens, prompts feature and prompt
362
+ transcription's text tokens.
363
+
364
+ Args:
365
+ tokens: a list of list of text tokens.
366
+ prompt_tokens: a list of list of prompt tokens.
367
+ prompt_features: the prompt feature with the shape
368
+ (batch_size, seq_len, feat_dim).
369
+ speed : speed control.
370
+ t_shift: time shift.
371
+ guidance_scale: the guidance scale for classifier-free guidance.
372
+ num_step: the number of steps to use in the ODE solver.
373
+ """
374
+ # Run text encoder
375
+ assert len(tokens) == len(prompt_tokens) == 1
376
+ tokens = torch.tensor(tokens, dtype=torch.int64)
377
+ prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.int64)
378
+ prompt_features_len = torch.tensor(prompt_features.size(1), dtype=torch.int64)
379
+ speed = torch.tensor(speed, dtype=torch.float32)
380
+
381
+ text_condition = model.run_text_encoder(
382
+ tokens, prompt_tokens, prompt_features_len, speed
383
+ )
384
+
385
+ batch_size, num_frames, _ = text_condition.shape
386
+ assert batch_size == 1
387
+ feat_dim = model.feat_dim
388
+
389
+ # Run flow matching model
390
+ timesteps = get_time_steps(
391
+ t_start=0.0,
392
+ t_end=1.0,
393
+ num_step=num_step,
394
+ t_shift=t_shift,
395
+ )
396
+ x = torch.randn(batch_size, num_frames, feat_dim)
397
+ speech_condition = torch.nn.functional.pad(
398
+ prompt_features, (0, 0, 0, num_frames - prompt_features.shape[1])
399
+ ) # (B, T, F)
400
+ guidance_scale = torch.tensor(guidance_scale, dtype=torch.float32)
401
+
402
+ for step in range(num_step):
403
+ v = model.run_fm_decoder(
404
+ t=timesteps[step],
405
+ x=x,
406
+ text_condition=text_condition,
407
+ speech_condition=speech_condition,
408
+ guidance_scale=guidance_scale,
409
+ )
410
+ x = x + v * (timesteps[step + 1] - timesteps[step])
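+ # Each iteration is one explicit Euler step of the flow-matching ODE:
+ # x_{k+1} = x_k + v(x_k, t_k) * (t_{k+1} - t_k), where v is the velocity
+ # predicted by the fm_decoder and t_k comes from get_time_steps (with t_shift).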
411
+
412
+ x = x[:, prompt_features_len.item() :, :]
413
+ return x
414
+
415
+
416
+ # Copied from zipvoice/bin/infer_zipvoice.py, but call an external sample function
417
+ def generate_sentence_raw_evaluation(
418
+ save_path: str,
419
+ prompt_text: str,
420
+ prompt_wav: str,
421
+ text: str,
422
+ model: OnnxModel,
423
+ vocoder: nn.Module,
424
+ tokenizer: EmiliaTokenizer,
425
+ feature_extractor: VocosFbank,
426
+ num_step: int = 16,
427
+ guidance_scale: float = 1.0,
428
+ speed: float = 1.0,
429
+ t_shift: float = 0.5,
430
+ target_rms: float = 0.1,
431
+ feat_scale: float = 0.1,
432
+ sampling_rate: int = 24000,
433
+ ):
434
+ """
435
+ Generate the waveform of a text based on a given prompt waveform and its transcription;
436
+ this function directly feeds the prompt_text, prompt_wav, and text to the model.
437
+ It is not efficient and can produce poor results for some inappropriate inputs
438
+ (e.g., the prompt wav contains long silences, or the text to be generated is too long).
439
+ This function can be used to evaluate the "raw" performance of the model.
440
+
441
+ Args:
442
+ save_path (str): Path to save the generated wav.
443
+ prompt_text (str): Transcription of the prompt wav.
444
+ prompt_wav (str): Path to the prompt wav file.
445
+ text (str): Text to be synthesized into a waveform.
446
+ model (torch.nn.Module): The model used for generation.
447
+ vocoder (torch.nn.Module): The vocoder used to convert features to waveforms.
448
+ tokenizer (EmiliaTokenizer): The tokenizer used to convert text to tokens.
449
+ feature_extractor (VocosFbank): The feature extractor used to
450
+ extract acoustic features.
451
+ num_step (int, optional): Number of steps for decoding. Defaults to 16.
452
+ guidance_scale (float, optional): Scale for classifier-free guidance.
453
+ Defaults to 1.0.
454
+ speed (float, optional): Speed control. Defaults to 1.0.
455
+ t_shift (float, optional): Time shift. Defaults to 0.5.
456
+ target_rms (float, optional): Target RMS for waveform normalization.
457
+ Defaults to 0.1.
458
+ feat_scale (float, optional): Scale for features.
459
+ Defaults to 0.1.
460
+ sampling_rate (int, optional): Sampling rate for the waveform.
461
+ Defaults to 24000.
462
+ Returns:
463
+ metrics (dict): Dictionary containing time and real-time
464
+ factor metrics for processing.
465
+ """
466
+
467
+ # Load and process prompt wav
468
+ prompt_wav = load_prompt_wav(prompt_wav, sampling_rate=sampling_rate)
469
+ prompt_wav, prompt_rms = rms_norm(prompt_wav, target_rms)
470
+
471
+ # Extract features from prompt wav
472
+ prompt_features = feature_extractor.extract(prompt_wav, sampling_rate=sampling_rate)
473
+
474
+ prompt_features = prompt_features.unsqueeze(0) * feat_scale
475
+
476
+ # Convert text to tokens
477
+ tokens = tokenizer.texts_to_token_ids([text])
478
+ prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])
479
+
480
+ # Start timing
481
+ start_t = dt.datetime.now()
482
+
483
+ # Generate features
484
+ pred_features = sample(
485
+ model=model,
486
+ tokens=tokens,
487
+ prompt_tokens=prompt_tokens,
488
+ prompt_features=prompt_features,
489
+ speed=speed,
490
+ t_shift=t_shift,
491
+ guidance_scale=guidance_scale,
492
+ num_step=num_step,
493
+ )
494
+
495
+ # Postprocess predicted features
496
+ pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T)
497
+
498
+ # Start vocoder processing
499
+ start_vocoder_t = dt.datetime.now()
500
+ wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)
501
+
502
+ # Calculate processing times and real-time factors
503
+ t = (dt.datetime.now() - start_t).total_seconds()
504
+ t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
505
+ t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
506
+ wav_seconds = wav.shape[-1] / sampling_rate
507
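+ # Real-time factor (RTF) = processing time / duration of the generated audio.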
+ rtf = t / wav_seconds
508
+ rtf_no_vocoder = t_no_vocoder / wav_seconds
509
+ rtf_vocoder = t_vocoder / wav_seconds
510
+ metrics = {
511
+ "t": t,
512
+ "t_no_vocoder": t_no_vocoder,
513
+ "t_vocoder": t_vocoder,
514
+ "wav_seconds": wav_seconds,
515
+ "rtf": rtf,
516
+ "rtf_no_vocoder": rtf_no_vocoder,
517
+ "rtf_vocoder": rtf_vocoder,
518
+ }
519
+
520
+ # Adjust wav volume if necessary
521
+ if prompt_rms < target_rms:
522
+ wav = wav * prompt_rms / target_rms
523
+ torchaudio.save(save_path, wav.cpu(), sample_rate=sampling_rate)
524
+
525
+ return metrics
526
+
527
+
528
+ def generate_sentence(
529
+ save_path: str,
530
+ prompt_text: str,
531
+ prompt_wav: str,
532
+ text: str,
533
+ model: OnnxModel,
534
+ vocoder: nn.Module,
535
+ tokenizer: EmiliaTokenizer,
536
+ feature_extractor: VocosFbank,
537
+ num_step: int = 16,
538
+ guidance_scale: float = 1.0,
539
+ speed: float = 1.0,
540
+ t_shift: float = 0.5,
541
+ target_rms: float = 0.1,
542
+ feat_scale: float = 0.1,
543
+ sampling_rate: int = 24000,
544
+ remove_long_sil: bool = False,
545
+ ):
546
+ """
547
+ Generate the waveform of a text from a given prompt waveform and its transcription.
548
+ This function does the following to improve the generation quality:
549
+ 1. chunk the text according to punctuation.
550
+ 2. process the chunked texts sequentially.
551
+ 3. remove long silences in the prompt audio.
552
+ 4. add punctuation to the end of the prompt text and text if there is none.
553
+
554
+ Args:
555
+ save_path (str): Path to save the generated wav.
556
+ prompt_text (str): Transcription of the prompt wav.
557
+ prompt_wav (str): Path to the prompt wav file.
558
+ text (str): Text to be synthesized into a waveform.
559
+ model (torch.nn.Module): The model used for generation.
560
+ vocoder (torch.nn.Module): The vocoder used to convert features to waveforms.
561
+ tokenizer (EmiliaTokenizer): The tokenizer used to convert text to tokens.
562
+ feature_extractor (VocosFbank): The feature extractor used to
563
+ extract acoustic features.
564
+ num_step (int, optional): Number of steps for decoding. Defaults to 16.
565
+ guidance_scale (float, optional): Scale for classifier-free guidance.
566
+ Defaults to 1.0.
567
+ speed (float, optional): Speed control. Defaults to 1.0.
568
+ t_shift (float, optional): Time shift. Defaults to 0.5.
569
+ target_rms (float, optional): Target RMS for waveform normalization.
570
+ Defaults to 0.1.
571
+ feat_scale (float, optional): Scale for features.
572
+ Defaults to 0.1.
573
+ sampling_rate (int, optional): Sampling rate for the waveform.
574
+ Defaults to 24000.
575
+ remove_long_sil (bool, optional): Whether to remove long silences in the
576
+ middle of the generated speech (edge silences will be removed by default).
577
+ Returns:
578
+ metrics (dict): Dictionary containing time and real-time
579
+ factor metrics for processing.
580
+ """
581
+
582
+ # Load and process prompt wav
583
+ prompt_wav = load_prompt_wav(prompt_wav, sampling_rate=sampling_rate)
584
+
585
+ # Remove edge and long silences in the prompt wav.
586
+ # Add 0.2s trailing silence to avoid leaking prompt to generated speech.
587
+ prompt_wav = remove_silence(
588
+ prompt_wav, sampling_rate, only_edge=False, trail_sil=200
589
+ )
590
+
591
+ prompt_wav, prompt_rms = rms_norm(prompt_wav, target_rms)
592
+
593
+ prompt_duration = prompt_wav.shape[-1] / sampling_rate
594
+
595
+ if prompt_duration > 20:
596
+ logging.warning(
597
+ f"Given prompt wav is too long ({prompt_duration}s). "
598
+ f"Please provide a shorter one (1-3 seconds is recommended)."
599
+ )
600
+ elif prompt_duration > 10:
601
+ logging.warning(
602
+ f"Given prompt wav is long ({prompt_duration}s). "
603
+ f"It will lead to slower inference speed and possibly worse speech quality."
604
+ )
605
+
606
+ # Extract features from prompt wav
607
+ prompt_features = feature_extractor.extract(prompt_wav, sampling_rate=sampling_rate)
608
+
609
+ prompt_features = prompt_features.unsqueeze(0) * feat_scale
610
+
611
+ # Add punctuation in the end if there is not
612
+ text = add_punctuation(text)
613
+ prompt_text = add_punctuation(prompt_text)
614
+
615
+ # Tokenize text (str tokens), punctuations will be preserved.
616
+ tokens_str = tokenizer.texts_to_tokens([text])[0]
617
+ prompt_tokens_str = tokenizer.texts_to_tokens([prompt_text])[0]
618
+
619
+ # chunk text so that each len(prompt wav + generated wav) is around 25 seconds.
620
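+ # Estimate the per-token duration from the prompt's speaking rate, adjusted by the speed factor.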
+ token_duration = (prompt_wav.shape[-1] / sampling_rate) / (
621
+ len(prompt_tokens_str) * speed
622
+ )
623
+ max_tokens = int((25 - prompt_duration) / token_duration)
624
+ chunked_tokens_str = chunk_tokens_punctuation(tokens_str, max_tokens=max_tokens)
625
+ logging.debug(f"Number of text chunks: {len(chunked_tokens_str)}")
626
+ logging.debug(f"Chunked tokens: {chunked_tokens_str}")
627
+
628
+ # Tokenize text (int tokens)
629
+ chunked_tokens = tokenizer.tokens_to_token_ids(chunked_tokens_str)
630
+ prompt_tokens = tokenizer.tokens_to_token_ids([prompt_tokens_str])
631
+
632
+ # Start predicting features
633
+ chunked_features = []
634
+ start_t = dt.datetime.now()
635
+ for tokens in chunked_tokens:
636
+
637
+ # Generate features
638
+ pred_features = sample(
639
+ model=model,
640
+ tokens=[tokens],
641
+ prompt_tokens=prompt_tokens,
642
+ prompt_features=prompt_features,
643
+ speed=speed,
644
+ t_shift=t_shift,
645
+ guidance_scale=guidance_scale,
646
+ num_step=num_step,
647
+ )
648
+
649
+ # Postprocess predicted features
650
+ pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T)
651
+ chunked_features.append(pred_features)
652
+
653
+ # Start vocoder processing
654
+ chunked_wavs = []
655
+ start_vocoder_t = dt.datetime.now()
656
+
657
+ for pred_features in chunked_features:
658
+ wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)
659
+ # Adjust wav volume if necessary
660
+ if prompt_rms < target_rms:
661
+ wav = wav * prompt_rms / target_rms
662
+ chunked_wavs.append(wav)
663
+
664
+ # Finish model generation
665
+ t = (dt.datetime.now() - start_t).total_seconds()
666
+
667
+ # Merge chunked wavs
668
+ final_wav = cross_fade_concat(
669
+ chunked_wavs, fade_duration=0.1, sample_rate=sampling_rate
670
+ )
671
+ final_wav = remove_silence(
672
+ final_wav, sampling_rate, only_edge=(not remove_long_sil), trail_sil=0
673
+ )
674
+
675
+ # Calculate processing time metrics
676
+ t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
677
+ t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
678
+ wav_seconds = final_wav.shape[-1] / sampling_rate
679
+ rtf = t / wav_seconds
680
+ rtf_no_vocoder = t_no_vocoder / wav_seconds
681
+ rtf_vocoder = t_vocoder / wav_seconds
682
+ metrics = {
683
+ "t": t,
684
+ "t_no_vocoder": t_no_vocoder,
685
+ "t_vocoder": t_vocoder,
686
+ "wav_seconds": wav_seconds,
687
+ "rtf": rtf,
688
+ "rtf_no_vocoder": rtf_no_vocoder,
689
+ "rtf_vocoder": rtf_vocoder,
690
+ }
691
+
692
+ torchaudio.save(save_path, final_wav.cpu(), sample_rate=sampling_rate)
693
+ return metrics
694
+
695
+
696
+ def generate_list(
697
+ res_dir: str,
698
+ test_list: str,
699
+ model: OnnxModel,
700
+ vocoder: nn.Module,
701
+ tokenizer: EmiliaTokenizer,
702
+ feature_extractor: VocosFbank,
703
+ num_step: int = 16,
704
+ guidance_scale: float = 1.0,
705
+ speed: float = 1.0,
706
+ t_shift: float = 0.5,
707
+ target_rms: float = 0.1,
708
+ feat_scale: float = 0.1,
709
+ sampling_rate: int = 24000,
710
+ raw_evaluation: bool = False,
711
+ remove_long_sil: bool = False,
712
+ ):
713
+ total_t = []
714
+ total_t_no_vocoder = []
715
+ total_t_vocoder = []
716
+ total_wav_seconds = []
717
+
718
+ with open(test_list, "r") as fr:
719
+ lines = fr.readlines()
720
+
721
+ for i, line in enumerate(lines):
722
+ wav_name, prompt_text, prompt_wav, text = line.strip().split("\t")
723
+ save_path = f"{res_dir}/{wav_name}.wav"
724
+
725
+ common_params = {
726
+ "save_path": save_path,
727
+ "prompt_text": prompt_text,
728
+ "prompt_wav": prompt_wav,
729
+ "text": text,
730
+ "model": model,
731
+ "vocoder": vocoder,
732
+ "tokenizer": tokenizer,
733
+ "feature_extractor": feature_extractor,
734
+ "num_step": num_step,
735
+ "guidance_scale": guidance_scale,
736
+ "speed": speed,
737
+ "t_shift": t_shift,
738
+ "target_rms": target_rms,
739
+ "feat_scale": feat_scale,
740
+ "sampling_rate": sampling_rate,
741
+ }
742
+
743
+ if raw_evaluation:
744
+ metrics = generate_sentence_raw_evaluation(**common_params)
745
+ else:
746
+ metrics = generate_sentence(
747
+ **common_params,
748
+ remove_long_sil=remove_long_sil,
749
+ )
750
+ logging.info(f"[Sentence: {i}] Saved to: {save_path}")
751
+ logging.info(f"[Sentence: {i}] RTF: {metrics['rtf']:.4f}")
752
+ total_t.append(metrics["t"])
753
+ total_t_no_vocoder.append(metrics["t_no_vocoder"])
754
+ total_t_vocoder.append(metrics["t_vocoder"])
755
+ total_wav_seconds.append(metrics["wav_seconds"])
756
+
757
+ logging.info(f"Average RTF: {np.sum(total_t) / np.sum(total_wav_seconds):.4f}")
758
+ logging.info(
759
+ f"Average RTF w/o vocoder: "
760
+ f"{np.sum(total_t_no_vocoder) / np.sum(total_wav_seconds):.4f}"
761
+ )
762
+ logging.info(
763
+ f"Average RTF vocoder: "
764
+ f"{np.sum(total_t_vocoder) / np.sum(total_wav_seconds):.4f}"
765
+ )
766
+
767
+
768
+ @torch.inference_mode()
769
+ def main():
770
+ parser = get_parser()
771
+ args = parser.parse_args()
772
+
773
+ torch.set_num_threads(args.num_thread)
774
+ torch.set_num_interop_threads(args.num_thread)
775
+
776
+ params = AttributeDict()
777
+ params.update(vars(args))
778
+ fix_random_seed(params.seed)
779
+
780
+ model_defaults = {
781
+ "zipvoice": {
782
+ "num_step": 16,
783
+ "guidance_scale": 1.0,
784
+ },
785
+ "zipvoice_distill": {
786
+ "num_step": 8,
787
+ "guidance_scale": 3.0,
788
+ },
789
+ }
790
+
791
+ model_specific_defaults = model_defaults.get(params.model_name, {})
792
+
793
+ for param, value in model_specific_defaults.items():
794
+ if getattr(params, param) is None:
795
+ setattr(params, param, value)
796
+ logging.info(f"Setting {param} to default value: {value}")
797
+
798
+ assert (params.test_list is not None) ^ (
799
+ (params.prompt_wav and params.prompt_text and params.text) is not None
800
+ ), (
801
+ "For inference, please provide prompts and text with either '--test-list'"
802
+ " or '--prompt-wav, --prompt-text and --text'."
803
+ )
804
+
805
+ if params.onnx_int8:
806
+ text_encoder_name = "text_encoder_int8.onnx"
807
+ fm_decoder_name = "fm_decoder_int8.onnx"
808
+ else:
809
+ text_encoder_name = "text_encoder.onnx"
810
+ fm_decoder_name = "fm_decoder.onnx"
811
+
812
+ if params.model_dir is not None:
813
+ params.model_dir = Path(params.model_dir)
814
+ if not params.model_dir.is_dir():
815
+ raise FileNotFoundError(f"{params.model_dir} does not exist")
816
+
817
+ for filename in [
818
+ text_encoder_name,
819
+ fm_decoder_name,
820
+ "model.json",
821
+ "tokens.txt",
822
+ ]:
823
+ if not (params.model_dir / filename).is_file():
824
+ raise FileNotFoundError(f"{params.model_dir / filename} does not exist")
825
+ text_encoder_path = params.model_dir / text_encoder_name
826
+ fm_decoder_path = params.model_dir / fm_decoder_name
827
+ model_config = params.model_dir / "model.json"
828
+ token_file = params.model_dir / "tokens.txt"
829
+ logging.info(f"Using local model dir {params.model_dir}.")
830
+ else:
831
+ logging.info("Using pretrained model from the Huggingface")
832
+ text_encoder_path = hf_hub_download(
833
+ HUGGINGFACE_REPO,
834
+ filename=f"{MODEL_DIR[params.model_name]}/{text_encoder_name}",
835
+ )
836
+ fm_decoder_path = hf_hub_download(
837
+ HUGGINGFACE_REPO,
838
+ filename=f"{MODEL_DIR[params.model_name]}/{fm_decoder_name}",
839
+ )
840
+ model_config = hf_hub_download(
841
+ HUGGINGFACE_REPO, filename=f"{MODEL_DIR[params.model_name]}/model.json"
842
+ )
843
+
844
+ token_file = hf_hub_download(
845
+ HUGGINGFACE_REPO, filename=f"{MODEL_DIR[params.model_name]}/tokens.txt"
846
+ )
847
+
848
+ if params.tokenizer == "emilia":
849
+ tokenizer = EmiliaTokenizer(token_file=token_file)
850
+ elif params.tokenizer == "libritts":
851
+ tokenizer = LibriTTSTokenizer(token_file=token_file)
852
+ elif params.tokenizer == "espeak":
853
+ tokenizer = EspeakTokenizer(token_file=token_file, lang=params.lang)
854
+ else:
855
+ assert params.tokenizer == "simple"
856
+ tokenizer = SimpleTokenizer(token_file=token_file)
857
+
858
+ with open(model_config, "r") as f:
859
+ model_config = json.load(f)
860
+
861
+ model = OnnxModel(text_encoder_path, fm_decoder_path, num_thread=args.num_thread)
862
+
863
+ vocoder = get_vocoder(params.vocoder_path)
864
+ vocoder.eval()
865
+
866
+ if model_config["feature"]["type"] == "vocos":
867
+ feature_extractor = VocosFbank()
868
+ else:
869
+ raise NotImplementedError(
870
+ f"Unsupported feature type: {model_config['feature']['type']}"
871
+ )
872
+ params.sampling_rate = model_config["feature"]["sampling_rate"]
873
+
874
+ logging.info("Start generating...")
875
+ if params.test_list:
876
+ os.makedirs(params.res_dir, exist_ok=True)
877
+ generate_list(
878
+ res_dir=params.res_dir,
879
+ test_list=params.test_list,
880
+ model=model,
881
+ vocoder=vocoder,
882
+ tokenizer=tokenizer,
883
+ feature_extractor=feature_extractor,
884
+ num_step=params.num_step,
885
+ guidance_scale=params.guidance_scale,
886
+ speed=params.speed,
887
+ t_shift=params.t_shift,
888
+ target_rms=params.target_rms,
889
+ feat_scale=params.feat_scale,
890
+ sampling_rate=params.sampling_rate,
891
+ raw_evaluation=params.raw_evaluation,
892
+ remove_long_sil=params.remove_long_sil,
893
+ )
894
+ else:
895
+ assert (
896
+ not params.raw_evaluation
897
+ ), "Raw evaluation is only valid with --test-list"
898
+ generate_sentence(
899
+ save_path=params.res_wav_path,
900
+ prompt_text=params.prompt_text,
901
+ prompt_wav=params.prompt_wav,
902
+ text=params.text,
903
+ model=model,
904
+ vocoder=vocoder,
905
+ tokenizer=tokenizer,
906
+ feature_extractor=feature_extractor,
907
+ num_step=params.num_step,
908
+ guidance_scale=params.guidance_scale,
909
+ speed=params.speed,
910
+ t_shift=params.t_shift,
911
+ target_rms=params.target_rms,
912
+ feat_scale=params.feat_scale,
913
+ sampling_rate=params.sampling_rate,
914
+ remove_long_sil=params.remove_long_sil,
915
+ )
916
+ logging.info(f"Saved to: {params.res_wav_path}")
917
+ logging.info("Done")
918
+
919
+
920
+ if __name__ == "__main__":
921
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
922
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
923
+
924
+ main()
zipvoice/bin/onnx_export.py ADDED
@@ -0,0 +1,429 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2025 Xiaomi Corp. (authors: Zengwei Yao)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """
19
+ This script exports a pre-trained ZipVoice or ZipVoice-Distill model from PyTorch to
20
+ ONNX.
21
+
22
+ Usage:
23
+
24
+ python3 -m zipvoice.bin.onnx_export \
25
+ --model-name zipvoice \
26
+ --model-dir exp/zipvoice \
27
+ --checkpoint-name epoch-11-avg-4.pt \
28
+ --onnx-model-dir exp/zipvoice
29
+
30
+ `--model-name` can be `zipvoice` or `zipvoice_distill`,
31
+ which are the models before and after distillation, respectively.
32
+ """
33
+
34
+
35
+ import argparse
36
+ import json
37
+ import logging
38
+ from pathlib import Path
39
+ from typing import Dict
40
+
41
+ import onnx
42
+ import safetensors.torch
43
+ import torch
44
+ from onnxruntime.quantization import QuantType, quantize_dynamic
45
+ from torch import Tensor, nn
46
+
47
+ from zipvoice.models.zipvoice import ZipVoice
48
+ from zipvoice.models.zipvoice_distill import ZipVoiceDistill
49
+ from zipvoice.tokenizer.tokenizer import SimpleTokenizer
50
+ from zipvoice.utils.checkpoint import load_checkpoint
51
+ from zipvoice.utils.common import AttributeDict
52
+ from zipvoice.utils.scaling_converter import convert_scaled_to_non_scaled
53
+
54
+
55
+ def get_parser():
56
+ parser = argparse.ArgumentParser(
57
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
58
+ )
59
+
60
+ parser.add_argument(
61
+ "--onnx-model-dir",
62
+ type=str,
63
+ default="exp",
64
+ help="Dir to the exported models",
65
+ )
66
+
67
+ parser.add_argument(
68
+ "--model-name",
69
+ type=str,
70
+ default="zipvoice",
71
+ choices=["zipvoice", "zipvoice_distill"],
72
+ help="The model used for inference",
73
+ )
74
+
75
+ parser.add_argument(
76
+ "--model-dir",
77
+ type=str,
78
+ default=None,
79
+ help="The model directory that contains model checkpoint, configuration "
80
+ "file model.json, and tokens file tokens.txt. Will download pre-trained "
81
+ "checkpoint from huggingface if not specified.",
82
+ )
83
+
84
+ parser.add_argument(
85
+ "--checkpoint-name",
86
+ type=str,
87
+ default="model.pt",
88
+ help="The name of model checkpoint.",
89
+ )
90
+
91
+ return parser
92
+
93
+
94
+ def add_meta_data(filename: str, meta_data: Dict[str, str]):
95
+ """Add meta data to an ONNX model. It is changed in-place.
96
+
97
+ Args:
98
+ filename:
99
+ Filename of the ONNX model to be changed.
100
+ meta_data:
101
+ Key-value pairs.
102
+ """
103
+ model = onnx.load(filename)
104
+ for key, value in meta_data.items():
105
+ meta = model.metadata_props.add()
106
+ meta.key = key
107
+ meta.value = value
108
+
109
+ onnx.save(model, filename)
110
+
111
+
112
+ class OnnxTextModel(nn.Module):
113
+ def __init__(self, model: nn.Module):
114
+ """A wrapper for ZipVoice text encoder."""
115
+ super().__init__()
116
+ self.embed = model.embed
117
+ self.text_encoder = model.text_encoder
118
+ self.pad_id = model.pad_id
119
+
120
+ def forward(
121
+ self,
122
+ tokens: Tensor,
123
+ prompt_tokens: Tensor,
124
+ prompt_features_len: Tensor,
125
+ speed: Tensor,
126
+ ) -> Tensor:
127
+ cat_tokens = torch.cat([prompt_tokens, tokens], dim=1)
128
+ cat_tokens = nn.functional.pad(cat_tokens, (0, 1), value=self.pad_id)
129
+ tokens_len = cat_tokens.shape[1] - 1
130
+ padding_mask = (torch.arange(tokens_len + 1) == tokens_len).unsqueeze(0)
131
+
132
+ embed = self.embed(cat_tokens)
133
+ embed = self.text_encoder(x=embed, t=None, padding_mask=padding_mask)
134
+
135
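+ # Estimate the total number of feature frames from the prompt's frames-per-token ratio, scaled by the number of tokens and the speed factor.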
+ features_len = torch.ceil(
136
+ (prompt_features_len / prompt_tokens.shape[1] * tokens_len / speed)
137
+ ).to(dtype=torch.int64)
138
+
139
+ token_dur = torch.div(features_len, tokens_len, rounding_mode="floor").to(
140
+ dtype=torch.int64
141
+ )
142
+
143
+ # If you pass a scalar tensor, ONNX may infer the shape as [1] (rank-1 tensor),
144
+ # but sometimes expects an actual scalar (rank-0).
145
+ # When exporting, ONNX may generate a model where Concat expects inputs of the
146
+ # same rank, but receives [1] and [].
147
+ # In PyTorch this is usually fine, but in ONNX Runtime (C++) it causes an error like
148
+ # "Ranks of input data are different, cannot concatenate them. expected rank: 1 got: 2".
149
+ # Using x.item() would break the dynamic link, so an input mismatch error can occur at inference.
150
+ # Use reshape(()) to convert a rank-1 tensor to a rank-0 tensor.
151
+
152
+ token_dur = token_dur.reshape(())
153
+ features_len = features_len.reshape(())
154
+
155
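+ # Repeat each token embedding token_dur times to obtain a frame-level text condition.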
+ text_condition = embed[:, :-1, :].unsqueeze(2).expand(-1, -1, token_dur, -1)
156
+ text_condition = text_condition.reshape(embed.shape[0], -1, embed.shape[2])
157
+
158
+ text_condition = torch.cat(
159
+ [
160
+ text_condition,
161
+ embed[:, -1:, :].expand(-1, features_len - text_condition.shape[1], -1),
162
+ ],
163
+ dim=1,
164
+ )
165
+
166
+ return text_condition
167
+
168
+
169
+ class OnnxFlowMatchingModel(nn.Module):
170
+ def __init__(self, model: nn.Module, distill: bool = False):
171
+ """A wrapper for ZipVoice flow-matching decoder."""
172
+ super().__init__()
173
+ self.distill = distill
174
+ self.fm_decoder = model.fm_decoder
175
+ self.model_func = getattr(model, "forward_fm_decoder")
176
+ self.feat_dim = model.feat_dim
177
+
178
+ def forward(
179
+ self,
180
+ t: Tensor,
181
+ x: Tensor,
182
+ text_condition: Tensor,
183
+ speech_condition: torch.Tensor,
184
+ guidance_scale: Tensor,
185
+ ) -> Tensor:
186
+ if self.distill:
187
+ return self.model_func(
188
+ t=t,
189
+ xt=x,
190
+ text_condition=text_condition,
191
+ speech_condition=speech_condition,
192
+ guidance_scale=guidance_scale,
193
+ )
194
+ else:
195
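+ # Classifier-free guidance: run the unconditional and conditional branches in a single batch, then combine the two outputs.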
+ x = x.repeat(2, 1, 1)
196
+ text_condition = torch.cat(
197
+ [torch.zeros_like(text_condition), text_condition], dim=0
198
+ )
199
+ speech_condition = torch.cat(
200
+ [
201
+ torch.where(
202
+ t > 0.5, torch.zeros_like(speech_condition), speech_condition
203
+ ),
204
+ speech_condition,
205
+ ],
206
+ dim=0,
207
+ )
208
+ guidance_scale = torch.where(t > 0.5, guidance_scale, guidance_scale * 2.0)
209
+ data_uncond, data_cond = self.model_func(
210
+ t=t,
211
+ xt=x,
212
+ text_condition=text_condition,
213
+ speech_condition=speech_condition,
214
+ ).chunk(2, dim=0)
215
+ v = (1 + guidance_scale) * data_cond - guidance_scale * data_uncond
216
+ return v
217
+
218
+
219
+ def export_text_encoder(
220
+ model: OnnxTextModel,
221
+ filename: str,
222
+ opset_version: int = 13,
223
+ ) -> None:
224
+ """Export the text encoder model to ONNX format.
225
+
226
+ Args:
227
+ model:
228
+ The input model
229
+ filename:
230
+ The filename to save the exported ONNX model.
231
+ opset_version:
232
+ The opset version to use.
233
+ """
234
+ tokens = torch.tensor([[2, 3, 4, 5]], dtype=torch.int64)
235
+ prompt_tokens = torch.tensor([[0, 1]], dtype=torch.int64)
236
+ prompt_features_len = torch.tensor(10, dtype=torch.int64)
237
+ speed = torch.tensor(1.0, dtype=torch.float32)
238
+
239
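+ # Trace the wrapper with example inputs so the ONNX exporter works from a static (traced) graph.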
+ model = torch.jit.trace(model, (tokens, prompt_tokens, prompt_features_len, speed))
240
+
241
+ torch.onnx.export(
242
+ model,
243
+ (tokens, prompt_tokens, prompt_features_len, speed),
244
+ filename,
245
+ verbose=False,
246
+ opset_version=opset_version,
247
+ input_names=["tokens", "prompt_tokens", "prompt_features_len", "speed"],
248
+ output_names=["text_condition"],
249
+ dynamic_axes={
250
+ "tokens": {0: "N", 1: "T"},
251
+ "prompt_tokens": {0: "N", 1: "T"},
252
+ "text_condition": {0: "N", 1: "T"},
253
+ },
254
+ )
255
+
256
+ meta_data = {
257
+ "version": "1",
258
+ "model_author": "k2-fsa",
259
+ "comment": "ZipVoice text encoder",
260
+ "use_espeak": "1",
261
+ "use_pinyin": "1",
262
+ }
263
+ logging.info(f"meta_data: {meta_data}")
264
+ add_meta_data(filename=filename, meta_data=meta_data)
265
+
266
+ logging.info(f"Exported to {filename}")
267
+
268
+
269
+ def export_fm_decoder(
270
+ model: OnnxFlowMatchingModel,
271
+ filename: str,
272
+ opset_version: int = 13,
273
+ ) -> None:
274
+ """Export the flow matching decoder model to ONNX format.
275
+
276
+ Args:
277
+ model:
278
+ The input model
279
+ filename:
280
+ The filename to save the exported ONNX model.
281
+ opset_version:
282
+ The opset version to use.
283
+ """
284
+ feat_dim = model.feat_dim
285
+ seq_len = 200
286
+ t = torch.tensor(0.5, dtype=torch.float32)
287
+ x = torch.randn(1, seq_len, feat_dim, dtype=torch.float32)
288
+ text_condition = torch.randn(1, seq_len, feat_dim, dtype=torch.float32)
289
+ speech_condition = torch.randn(1, seq_len, feat_dim, dtype=torch.float32)
290
+ guidance_scale = torch.tensor(1.0, dtype=torch.float32)
291
+
292
+ model = torch.jit.trace(
293
+ model, (t, x, text_condition, speech_condition, guidance_scale)
294
+ )
295
+
296
+ torch.onnx.export(
297
+ model,
298
+ (t, x, text_condition, speech_condition, guidance_scale),
299
+ filename,
300
+ verbose=False,
301
+ opset_version=opset_version,
302
+ input_names=["t", "x", "text_condition", "speech_condition", "guidance_scale"],
303
+ output_names=["v"],
304
+ dynamic_axes={
305
+ "x": {0: "N", 1: "T"},
306
+ "text_condition": {0: "N", 1: "T"},
307
+ "speech_condition": {0: "N", 1: "T"},
308
+ "v": {0: "N", 1: "T"},
309
+ },
310
+ )
311
+
312
+ meta_data = {
313
+ "version": "1",
314
+ "model_author": "k2-fsa",
315
+ "comment": "ZipVoice flow-matching decoder",
316
+ "feat_dim": str(feat_dim),
317
+ "sample_rate": "24000",
318
+ "n_fft": "1024",
319
+ "hop_length": "256",
320
+ "window_length": "1024",
321
+ "num_mels": "100",
322
+ }
323
+ logging.info(f"meta_data: {meta_data}")
324
+ add_meta_data(filename=filename, meta_data=meta_data)
325
+
326
+ logging.info(f"Exported to {filename}")
327
+
328
+
329
+ @torch.no_grad()
330
+ def main():
331
+ parser = get_parser()
332
+ args = parser.parse_args()
333
+
334
+ params = AttributeDict()
335
+ params.update(vars(args))
336
+
337
+ params.model_dir = Path(params.model_dir)
338
+ if not params.model_dir.is_dir():
339
+ raise FileNotFoundError(f"{params.model_dir} does not exist")
340
+ for filename in [params.checkpoint_name, "model.json", "tokens.txt"]:
341
+ if not (params.model_dir / filename).is_file():
342
+ raise FileNotFoundError(f"{params.model_dir / filename} does not exist")
343
+ model_ckpt = params.model_dir / params.checkpoint_name
344
+ model_config = params.model_dir / "model.json"
345
+ token_file = params.model_dir / "tokens.txt"
346
+
347
+ logging.info(f"Loading model from {params.model_dir}")
348
+
349
+ tokenizer = SimpleTokenizer(token_file)
350
+ tokenizer_config = {"vocab_size": tokenizer.vocab_size, "pad_id": tokenizer.pad_id}
351
+
352
+ with open(model_config, "r") as f:
353
+ model_config = json.load(f)
354
+
355
+ if params.model_name == "zipvoice":
356
+ model = ZipVoice(
357
+ **model_config["model"],
358
+ **tokenizer_config,
359
+ )
360
+ distill = False
361
+ else:
362
+ assert params.model_name == "zipvoice_distill"
363
+ model = ZipVoiceDistill(
364
+ **model_config["model"],
365
+ **tokenizer_config,
366
+ )
367
+ distill = True
368
+
369
+ if str(model_ckpt).endswith(".safetensors"):
370
+ safetensors.torch.load_model(model, model_ckpt)
371
+ elif str(model_ckpt).endswith(".pt"):
372
+ load_checkpoint(filename=model_ckpt, model=model, strict=True)
373
+ else:
374
+ raise NotImplementedError(f"Unsupported model checkpoint format: {model_ckpt}")
375
+
376
+ device = torch.device("cpu")
377
+ model = model.to(device)
378
+ model.eval()
379
+
380
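+ # Convert the scaled (training-time) modules to standard PyTorch modules so they can be traced and exported.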
+ convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)
381
+
382
+ logging.info("Exporting model")
383
+ onnx_model_dir = Path(params.onnx_model_dir)
384
+ onnx_model_dir.mkdir(parents=True, exist_ok=True)
385
+ opset_version = 13
386
+
387
+ text_encoder = OnnxTextModel(model=model)
388
+ text_encoder_file = onnx_model_dir / "text_encoder.onnx"
389
+ export_text_encoder(
390
+ model=text_encoder,
391
+ filename=text_encoder_file,
392
+ opset_version=opset_version,
393
+ )
394
+
395
+ fm_decoder = OnnxFlowMatchingModel(model=model, distill=distill)
396
+ fm_decoder_file = onnx_model_dir / "fm_decoder.onnx"
397
+ export_fm_decoder(
398
+ model=fm_decoder,
399
+ filename=fm_decoder_file,
400
+ opset_version=opset_version,
401
+ )
402
+
403
+ logging.info("Generate int8 quantization models")
404
+
405
+ text_encoder_int8_file = onnx_model_dir / "text_encoder_int8.onnx"
406
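+ # Dynamic int8 quantization: MatMul weights are stored as int8, while activations are quantized on the fly at runtime.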
+ quantize_dynamic(
407
+ model_input=text_encoder_file,
408
+ model_output=text_encoder_int8_file,
409
+ op_types_to_quantize=["MatMul"],
410
+ weight_type=QuantType.QInt8,
411
+ )
412
+
413
+ fm_decoder_int8_file = onnx_model_dir / "fm_decoder_int8.onnx"
414
+ quantize_dynamic(
415
+ model_input=fm_decoder_file,
416
+ model_output=fm_decoder_int8_file,
417
+ op_types_to_quantize=["MatMul"],
418
+ weight_type=QuantType.QInt8,
419
+ )
420
+
421
+ logging.info("Done!")
422
+
423
+
424
+ if __name__ == "__main__":
425
+
426
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
427
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
428
+
429
+ main()
zipvoice/bin/prepare_dataset.py ADDED
@@ -0,0 +1,274 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """
19
+ This script generates lhotse manifest files from TSV files for custom datasets.
20
+
21
+ Each line of the TSV files should be in one of the following formats:
22
+ 1. "{uniq_id}\t{text}\t{wav_path}" if the text corresponds to the full wav",
23
+ 2. "{uniq_id}\t{text}\t{wav_path}\t{start_time}\t{end_time} if text corresponds
24
+ to part of the wav. The start_time and end_time specify the start and end
25
+ times of the text within the wav, which should be in seconds.
26
+
27
+ Note: {uniq_id} must be unique for each line.
28
+
29
+ Usage:
30
+
31
+ Suppose you have two TSV files: "custom_train.tsv" and "custom_dev.tsv",
32
+ where "custom" is your dataset name, "train"/"dev" are used for training and
33
+ validation respectively.
34
+
35
+ (1) Prepare the training data
36
+
37
+ python3 -m zipvoice.bin.prepare_dataset \
38
+ --tsv-path data/raw/custom_train.tsv \
39
+ --prefix "custom" \
40
+ --subset "train" \
41
+ --num-jobs 20 \
42
+ --output-dir "data/manifests"
43
+
44
+ The output file would be "data/manifests/custom_cuts_train.jsonl.gz".
45
+
46
+ (2) Prepare the validation data
47
+
48
+ python3 -m zipvoice.bin.prepare_dataset \
49
+ --tsv-path data/raw/custom_dev.tsv \
50
+ --prefix "custom" \
51
+ --subset "dev" \
52
+ --num-jobs 1 \
53
+ --output-dir "data/manifests"
54
+
55
+ The output file would be "data/manifests/custom_cuts_dev.jsonl.gz".
56
+
57
+ """
58
+
59
+ import argparse
60
+ import logging
61
+ import re
62
+ from concurrent.futures import ThreadPoolExecutor
63
+ from pathlib import Path
64
+ from typing import List, Optional, Tuple
65
+
66
+ from lhotse import CutSet, validate_recordings_and_supervisions
67
+ from lhotse.audio import Recording, RecordingSet
68
+ from lhotse.qa import fix_manifests
69
+ from lhotse.supervision import SupervisionSegment, SupervisionSet
70
+ from lhotse.utils import Pathlike
71
+ from tqdm.auto import tqdm
72
+
73
+
74
+ def get_args():
75
+ parser = argparse.ArgumentParser()
76
+
77
+ parser.add_argument(
78
+ "--tsv-path",
79
+ type=str,
80
+ help="The path of the tsv file. Each line should be in the format: "
81
+ "{uniq_id}\t{text}\t{wav_path}\t{start_time}\t{end_time} "
82
+ "if text corresponds to part of the wav or {uniq_id}\t{text}\t{wav_path} "
83
+ "if the text corresponds to the full wav",
84
+ )
85
+ parser.add_argument(
86
+ "--prefix",
87
+ type=str,
88
+ default="custom",
89
+ help="Prefix of the output manifest file.",
90
+ )
91
+
92
+ parser.add_argument(
93
+ "--subset",
94
+ type=str,
95
+ default="train",
96
+ help="Subset name manifest file, typically train or dev.",
97
+ )
98
+
99
+ parser.add_argument(
100
+ "--num-jobs",
101
+ type=int,
102
+ default=20,
103
+ help="Number of jobs to processing.",
104
+ )
105
+
106
+ parser.add_argument(
107
+ "--output-dir",
108
+ type=str,
109
+ default="data/manifests",
110
+ help="The destination directory of manifest files.",
111
+ )
112
+ parser.add_argument(
113
+ "--sampling-rate",
114
+ type=int,
115
+ default=24000,
116
+ help="The target sampling rate.",
117
+ )
118
+ return parser.parse_args()
119
+
120
+
121
+ def _parse_recording(
122
+ wav_path: str,
123
+ ) -> Tuple[Recording, str]:
124
+ """
125
+ :param wav_path: Path to the audio file
126
+ :return: a tuple of "recording" and "recording_id"
127
+ """
128
+
129
+ recording_id = wav_path.replace("/", "_").replace(".", "_")
130
+ recording = Recording.from_file(path=wav_path, recording_id=recording_id)
131
+
132
+ return recording, recording_id
133
+
134
+
135
+ def _parse_supervision(
136
+ supervision: List, recording_dict: dict
137
+ ) -> Optional[SupervisionSegment]:
138
+ """
139
+ :param line: A line from the TSV file
140
+ :param recording_dict: Dictionary mapping recording IDs to Recording objects
141
+ :return: A SupervisionSegment object
142
+ """
143
+
144
+ uniq_id, text, wav_path, start, end = supervision
145
+ try:
146
+ recording_id = wav_path.replace("/", "_").replace(".", "_")
147
+
148
+ recording = recording_dict[recording_id]
149
+ duration = end - start if end is not None else recording.duration
150
+ assert duration <= recording.duration, f"Duration {duration} is greater than "
151
+ f"recording duration {recording.duration}"
152
+
153
+ text = re.sub("_", " ", text) # "_" is treated as padding symbol
154
+ text = re.sub(r"\s+", " ", text) # remove extra whitespace
155
+
156
+ return SupervisionSegment(
157
+ id=f"{uniq_id}",
158
+ recording_id=recording.id,
159
+ start=start,
160
+ duration=duration,
161
+ channel=recording.channel_ids,
162
+ text=text.strip(),
163
+ )
164
+ except Exception as e:
165
+ logging.warning(f"Error processing line: {e}")
166
+ return None
167
+
168
+
169
+ def prepare_dataset(
170
+ tsv_path: Pathlike,
171
+ prefix: str,
172
+ subset: str,
173
+ sampling_rate: int,
174
+ num_jobs: int,
175
+ output_dir: Pathlike,
176
+ ):
177
+ """
178
+ Returns the manifests which consist of the Recordings and Supervisions
179
+
180
+ :param tsv_path: Path to the TSV file
181
+ :param output_dir: Path where to write the manifests
182
+ :param num_jobs: Number of processes for parallel processing
183
+ :return: The CutSet containing the data
184
+ """
185
+ logging.info(f"Preparing {prefix} dataset {subset} subset.")
186
+ output_dir = Path(output_dir)
187
+ output_dir.mkdir(parents=True, exist_ok=True)
188
+ file_name = f"{prefix}_cuts_{subset}.jsonl.gz"
189
+ if (output_dir / file_name).is_file():
190
+ logging.info(f"{file_name} exists, skipping.")
191
+ return
192
+
193
+ # Step 1: Read all unique recording paths
194
+ recordings_path_set = set()
195
+ supervision_list = list()
196
+ with open(tsv_path, "r") as fr:
197
+ for line in fr:
198
+ items = line.strip().split("\t")
199
+ if len(items) == 3:
200
+ uniq_id, text, wav_path = items
201
+ start, end = 0, None
202
+ elif len(items) == 5:
203
+ uniq_id, text, wav_path, start, end = items
204
+ start, end = float(start), float(end)
205
+ else:
206
+ raise ValueError(
207
+ f"Invalid line format: {line},"
208
+ "requries to be 3 columns or 5 columns"
209
+ )
210
+ recordings_path_set.add(wav_path)
211
+ supervision_list.append((uniq_id, text, wav_path, start, end))
212
+
213
+ logging.info("Starting to process recordings...")
214
+ # Step 2: Process recordings
215
+ futures = []
216
+ recording_dict = {}
217
+ with ThreadPoolExecutor(max_workers=num_jobs) as ex:
218
+ for wav_path in tqdm(recordings_path_set, desc="Submitting jobs"):
219
+ futures.append(ex.submit(_parse_recording, wav_path))
220
+
221
+ for future in tqdm(futures, desc="Processing recordings"):
222
+ try:
223
+ recording, recording_id = future.result()
224
+ recording_dict[recording_id] = recording
225
+ except Exception as e:
226
+ logging.warning(
227
+ f"Error processing recording {recording_id} with error: {e}"
228
+ )
229
+
230
+ recording_set = RecordingSet.from_recordings(recording_dict.values())
231
+
232
+ logging.info("Starting to process supervisions...")
233
+ # Step 3: Process supervisions
234
+ supervisions = []
235
+ for supervision in tqdm(supervision_list, desc="Processing supervisions"):
236
+ seg = _parse_supervision(supervision, recording_dict)
237
+ if seg is not None:
238
+ supervisions.append(seg)
239
+
240
+ logging.info("Processing Cuts...")
241
+
242
+ # Step 4: Create and validate manifests
243
+ supervision_set = SupervisionSet.from_segments(supervisions)
244
+
245
+ recording_set, supervision_set = fix_manifests(recording_set, supervision_set)
246
+ validate_recordings_and_supervisions(recording_set, supervision_set)
247
+
248
+ cut_set = CutSet.from_manifests(
249
+ recordings=recording_set, supervisions=supervision_set
250
+ )
251
+ cut_set = cut_set.sort_by_recording_id()
252
+ cut_set = cut_set.resample(sampling_rate)
253
+ cut_set = cut_set.trim_to_supervisions(keep_overlapping=False)
254
+
255
+ logging.info(f"Saving file to {output_dir / file_name}")
256
+ # Step 5: Write manifests to disk
257
+ cut_set.to_file(output_dir / file_name)
258
+ logging.info("Done!")
259
+
260
+
261
+ if __name__ == "__main__":
262
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
263
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
264
+
265
+ args = get_args()
266
+
267
+ prepare_dataset(
268
+ tsv_path=args.tsv_path,
269
+ prefix=args.prefix,
270
+ subset=args.subset,
271
+ sampling_rate=args.sampling_rate,
272
+ num_jobs=args.num_jobs,
273
+ output_dir=args.output_dir,
274
+ )
zipvoice/bin/prepare_tokens.py ADDED
@@ -0,0 +1,103 @@
1
+ """
2
+ This script reads the texts in the given manifest and saves the new cuts with prepared tokens.
3
+ """
4
+
5
+ import argparse
6
+ import logging
7
+ from functools import partial
8
+ from pathlib import Path
9
+
10
+ from lhotse import load_manifest, split_parallelize_combine
11
+
12
+ from zipvoice.tokenizer.tokenizer import add_tokens
13
+
14
+
15
+ def get_args():
16
+ parser = argparse.ArgumentParser()
17
+
18
+ parser.add_argument(
19
+ "--input-file",
20
+ type=str,
21
+ help="Input manifest without tokens",
22
+ )
23
+
24
+ parser.add_argument(
25
+ "--output-file",
26
+ type=str,
27
+ help="Output manifest with tokens.",
28
+ )
29
+
30
+ parser.add_argument(
31
+ "--num-jobs",
32
+ type=int,
33
+ default=20,
34
+ help="Number of jobs to run in parallel.",
35
+ )
36
+
37
+ parser.add_argument(
38
+ "--tokenizer",
39
+ type=str,
40
+ default="emilia",
41
+ choices=["emilia", "espeak", "dialog", "libritts", "simple"],
42
+ help="The destination directory of manifest files.",
43
+ )
44
+
45
+ parser.add_argument(
46
+ "--lang",
47
+ type=str,
48
+ default="en-us",
49
+ help="Language identifier, used when tokenizer type is espeak. see"
50
+ "https://github.com/rhasspy/espeak-ng/blob/master/docs/languages.md",
51
+ )
52
+
53
+ return parser.parse_args()
54
+
55
+
56
+ def prepare_tokens(
57
+ input_file: Path,
58
+ output_file: Path,
59
+ num_jobs: int,
60
+ tokenizer: str,
61
+ lang: str = "en-us",
62
+ ):
63
+ logging.info(f"Processing {input_file}")
64
+ if output_file.is_file():
65
+ logging.info(f"{output_file} exists, skipping.")
66
+ return
67
+ logging.info(f"loading manifest from {input_file}")
68
+ cut_set = load_manifest(input_file)
69
+
70
+ _add_tokens = partial(add_tokens, tokenizer=tokenizer, lang=lang)
71
+
72
+ logging.info("Adding tokens")
73
+
74
+ cut_set = split_parallelize_combine(
75
+ num_jobs=num_jobs, manifest=cut_set, fn=_add_tokens
76
+ )
77
+
78
+ logging.info(f"Saving file to {output_file}")
79
+ cut_set.to_file(output_file)
80
+
81
+
82
+ if __name__ == "__main__":
83
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
84
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
85
+
86
+ args = get_args()
87
+ input_file = Path(args.input_file)
88
+ output_file = Path(args.output_file)
89
+ num_jobs = args.num_jobs
90
+ tokenizer = args.tokenizer
91
+ lang = args.lang
92
+
93
+ output_file.parent.mkdir(parents=True, exist_ok=True)
94
+
95
+ prepare_tokens(
96
+ input_file=input_file,
97
+ output_file=output_file,
98
+ num_jobs=num_jobs,
99
+ tokenizer=tokenizer,
100
+ lang=lang,
101
+ )
102
+
103
+ logging.info("Done!")
zipvoice/bin/train_zipvoice.py ADDED
@@ -0,0 +1,1130 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2024-2025 Xiaomi Corp. (authors: Wei Kang,
3
+ # Han Zhu)
4
+ #
5
+ # See ../../../../LICENSE for clarification regarding multiple authors
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+ """
20
+ This script trains a ZipVoice model with the flow-matching loss.
21
+
22
+ Usage:
23
+
24
+ python3 -m zipvoice.bin.train_zipvoice \
25
+ --world-size 8 \
26
+ --use-fp16 1 \
27
+ --num-epochs 11 \
28
+ --max-duration 500 \
29
+ --lr-hours 30000 \
30
+ --model-config conf/zipvoice_base.json \
31
+ --tokenizer emilia \
32
+ --token-file "data/tokens_emilia.txt" \
33
+ --dataset emilia \
34
+ --manifest-dir data/fbank \
35
+ --exp-dir exp/zipvoice
36
+ """
37
+
38
+ import argparse
39
+ import copy
40
+ import json
41
+ import logging
42
+ import os
43
+ from functools import partial
44
+ from pathlib import Path
45
+ from shutil import copyfile
46
+ from typing import List, Optional, Tuple, Union
47
+
48
+ import torch
49
+ import torch.multiprocessing as mp
50
+ import torch.nn as nn
51
+ from lhotse.cut import Cut, CutSet
52
+ from lhotse.utils import fix_random_seed
53
+ from torch import Tensor
54
+ from torch.nn.parallel import DistributedDataParallel as DDP
55
+ from torch.optim import Optimizer
56
+ from torch.utils.tensorboard import SummaryWriter
57
+
58
+ import zipvoice.utils.diagnostics as diagnostics
59
+ from zipvoice.dataset.datamodule import TtsDataModule
60
+ from zipvoice.models.zipvoice import ZipVoice
61
+ from zipvoice.tokenizer.tokenizer import (
62
+ EmiliaTokenizer,
63
+ EspeakTokenizer,
64
+ LibriTTSTokenizer,
65
+ SimpleTokenizer,
66
+ )
67
+ from zipvoice.utils.checkpoint import (
68
+ load_checkpoint,
69
+ remove_checkpoints,
70
+ resume_checkpoint,
71
+ save_checkpoint,
72
+ save_checkpoint_with_global_batch_idx,
73
+ update_averaged_model,
74
+ )
75
+ from zipvoice.utils.common import (
76
+ AttributeDict,
77
+ GradScaler,
78
+ MetricsTracker,
79
+ cleanup_dist,
80
+ create_grad_scaler,
81
+ get_adjusted_batch_count,
82
+ get_env_info,
83
+ get_parameter_groups_with_lrs,
84
+ prepare_input,
85
+ set_batch_count,
86
+ setup_dist,
87
+ setup_logger,
88
+ str2bool,
89
+ torch_autocast,
90
+ )
91
+ from zipvoice.utils.hooks import register_inf_check_hooks
92
+ from zipvoice.utils.lr_scheduler import Eden, FixedLRScheduler, LRScheduler
93
+ from zipvoice.utils.optim import ScaledAdam
94
+
95
+ LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, LRScheduler]
96
+
97
+
98
+ def get_parser():
99
+ parser = argparse.ArgumentParser(
100
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
101
+ )
102
+
103
+ parser.add_argument(
104
+ "--world-size",
105
+ type=int,
106
+ default=1,
107
+ help="Number of GPUs for DDP training.",
108
+ )
109
+
110
+ parser.add_argument(
111
+ "--master-port",
112
+ type=int,
113
+ default=12356,
114
+ help="Master port to use for DDP training.",
115
+ )
116
+
117
+ parser.add_argument(
118
+ "--tensorboard",
119
+ type=str2bool,
120
+ default=True,
121
+ help="Should various information be logged in tensorboard.",
122
+ )
123
+
124
+ parser.add_argument(
125
+ "--num-epochs",
126
+ type=int,
127
+ default=11,
128
+ help="Number of epochs to train.",
129
+ )
130
+
131
+ parser.add_argument(
132
+ "--num-iters",
133
+ type=int,
134
+ default=0,
135
+ help="Number of iter to train, will ignore num_epochs if > 0.",
136
+ )
137
+
138
+ parser.add_argument(
139
+ "--start-epoch",
140
+ type=int,
141
+ default=1,
142
+ help="""Resume training from this epoch. It should be positive.
143
+ If larger than 1, it will load checkpoint from
144
+ exp-dir/epoch-{start_epoch-1}.pt
145
+ """,
146
+ )
147
+
148
+ parser.add_argument(
149
+ "--checkpoint",
150
+ type=str,
151
+ default=None,
152
+ help="""Checkpoints of pre-trained models, will load it if not None
153
+ """,
154
+ )
155
+
156
+ parser.add_argument(
157
+ "--exp-dir",
158
+ type=str,
159
+ default="exp/zipvoice",
160
+ help="""The experiment dir.
161
+ It specifies the directory where all training related
162
+ files, e.g., checkpoints, log, etc, are saved
163
+ """,
164
+ )
165
+
166
+ parser.add_argument(
167
+ "--base-lr", type=float, default=0.02, help="The base learning rate."
168
+ )
169
+
170
+ parser.add_argument(
171
+ "--lr-batches",
172
+ type=float,
173
+ default=7500,
174
+ help="""Number of steps that affects how rapidly the learning rate
175
+ decreases. We suggest not to change this.""",
176
+ )
177
+
178
+ parser.add_argument(
179
+ "--lr-epochs",
180
+ type=float,
181
+ default=10,
182
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
183
+ """,
184
+ )
185
+
186
+ parser.add_argument(
187
+ "--lr-hours",
188
+ type=float,
189
+ default=0,
190
+ help="""If positive, --epoch is ignored and it specifies the number of hours
191
+ that affects how rapidly the learning rate decreases.
192
+ """,
193
+ )
194
+
195
+ parser.add_argument(
196
+ "--ref-duration",
197
+ type=float,
198
+ default=50,
199
+ help="""Reference batch duration for purposes of adjusting batch counts for"
200
+ setting various schedules inside the model".
201
+ """,
202
+ )
203
+
204
+ parser.add_argument(
205
+ "--finetune",
206
+ type=str2bool,
207
+ default=False,
208
+ help="Whether to use the fine-tuning mode, will used a fixed learning rate "
209
+ "schedule and skip the large dropout phase.",
210
+ )
211
+
212
+ parser.add_argument(
213
+ "--seed",
214
+ type=int,
215
+ default=42,
216
+ help="The seed for random generators intended for reproducibility",
217
+ )
218
+
219
+ parser.add_argument(
220
+ "--print-diagnostics",
221
+ type=str2bool,
222
+ default=False,
223
+ help="Accumulate stats on activations, print them and exit.",
224
+ )
225
+
226
+ parser.add_argument(
227
+ "--scan-oom",
228
+ type=str2bool,
229
+ default=False,
230
+ help="Scan pessimistic batches to see whether they cause OOMs.",
231
+ )
232
+
233
+ parser.add_argument(
234
+ "--inf-check",
235
+ type=str2bool,
236
+ default=False,
237
+ help="Add hooks to check for infinite module outputs and gradients.",
238
+ )
239
+
240
+ parser.add_argument(
241
+ "--save-every-n",
242
+ type=int,
243
+ default=5000,
244
+ help="""Save checkpoint after processing this number of batches"
245
+ periodically. We save checkpoint to exp-dir/ whenever
246
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
247
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
248
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
249
+ end of each epoch where `xxx` is the epoch number counting from 1.
250
+ """,
251
+ )
252
+
253
+ parser.add_argument(
254
+ "--valid-by-epoch",
255
+ type=str2bool,
256
+ default=False,
257
+ help="""Whether to validate after each epoch. If False, will validate
258
+ after every save_every_n iterations.
259
+ """,
260
+ )
261
+
262
+ parser.add_argument(
263
+ "--keep-last-k",
264
+ type=int,
265
+ default=30,
266
+ help="""Only keep this number of checkpoints on disk.
267
+ For instance, if it is 3, there are only 3 checkpoints
268
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
269
+ It does not affect checkpoints with name `epoch-xxx.pt`.
270
+ """,
271
+ )
272
+
273
+ parser.add_argument(
274
+ "--average-period",
275
+ type=int,
276
+ default=200,
277
+ help="""Update the averaged model, namely `model_avg`, after processing
278
+ this number of batches. `model_avg` is a separate version of model,
279
+ in which each floating-point parameter is the average of all the
280
+ parameters from the start of training. Each time we take the average,
281
+ we do: `model_avg = model * (average_period / batch_idx_train) +
282
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
283
+ """,
284
+ )
285
+
286
+ parser.add_argument(
287
+ "--use-fp16",
288
+ type=str2bool,
289
+ default=True,
290
+ help="Whether to use half precision training.",
291
+ )
292
+
293
+ parser.add_argument(
294
+ "--feat-scale",
295
+ type=float,
296
+ default=0.1,
297
+ help="The scale factor of fbank feature",
298
+ )
299
+
300
+ parser.add_argument(
301
+ "--condition-drop-ratio",
302
+ type=float,
303
+ default=0.2,
304
+ help="The drop rate of text condition during training.",
305
+ )
306
+
307
+ parser.add_argument(
308
+ "--dataset",
309
+ type=str,
310
+ default="emilia",
311
+ choices=["emilia", "libritts", "custom"],
312
+ help="The used training dataset",
313
+ )
314
+
315
+ parser.add_argument(
316
+ "--train-manifest",
317
+ type=str,
318
+ help="Path of the training manifest",
319
+ )
320
+
321
+ parser.add_argument(
322
+ "--dev-manifest",
323
+ type=str,
324
+ help="Path of the validation manifest",
325
+ )
326
+
327
+ parser.add_argument(
328
+ "--min-len",
329
+ type=float,
330
+ default=1.0,
331
+ help="The minimum audio length used for training",
332
+ )
333
+
334
+ parser.add_argument(
335
+ "--max-len",
336
+ type=float,
337
+ default=30.0,
338
+ help="The maximum audio length used for training",
339
+ )
340
+
341
+ parser.add_argument(
342
+ "--model-config",
343
+ type=str,
344
+ default="conf/zipvoice_base.json",
345
+ help="The model configuration file.",
346
+ )
347
+
348
+ parser.add_argument(
349
+ "--tokenizer",
350
+ type=str,
351
+ default="emilia",
352
+ choices=["emilia", "libritts", "espeak", "simple"],
353
+ help="Tokenizer type.",
354
+ )
355
+
356
+ parser.add_argument(
357
+ "--lang",
358
+ type=str,
359
+ default="en-us",
360
+ help="Language identifier, used when tokenizer type is espeak. see"
361
+ "https://github.com/rhasspy/espeak-ng/blob/master/docs/languages.md",
362
+ )
363
+
364
+ parser.add_argument(
365
+ "--token-file",
366
+ type=str,
367
+ default="data/tokens_emilia.txt",
368
+ help="The file that contains information that maps tokens to ids,"
369
+ "which is a text file with '{token}\t{token_id}' per line.",
370
+ )
371
+
372
+ return parser
373
+
374
+
375
+ def get_params() -> AttributeDict:
376
+ """Return a dict containing training parameters.
377
+
378
+ All training related parameters that are not passed from the commandline
379
+ are saved in the variable `params`.
380
+
381
+ Commandline options are merged into `params` after they are parsed, so
382
+ you can also access them via `params`.
383
+
384
+ Explanation of options saved in `params`:
385
+
386
+ - best_train_loss: Best training loss so far. It is used to select
387
+ the model that has the lowest training loss. It is
388
+ updated during the training.
389
+
390
+ - best_valid_loss: Best validation loss so far. It is used to select
391
+ the model that has the lowest validation loss. It is
392
+ updated during the training.
393
+
394
+ - best_train_epoch: It is the epoch that has the best training loss.
395
+
396
+ - best_valid_epoch: It is the epoch that has the best validation loss.
397
+
398
+ - batch_idx_train: Used to write statistics to tensorboard. It
399
+ contains the number of batches trained so far across
400
+ epochs.
401
+
402
+ - log_interval: Print training loss if `batch_idx % log_interval` is 0
403
+
404
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
405
+
406
+ - env_info: A dict containing information about the environment.
407
+
408
+ """
409
+ params = AttributeDict(
410
+ {
411
+ "best_train_loss": float("inf"),
412
+ "best_valid_loss": float("inf"),
413
+ "best_train_epoch": -1,
414
+ "best_valid_epoch": -1,
415
+ "batch_idx_train": 0,
416
+ "log_interval": 50,
417
+ "reset_interval": 200,
418
+ "env_info": get_env_info(),
419
+ }
420
+ )
421
+
422
+ return params
423
+
424
+
425
+ def compute_fbank_loss(
426
+ params: AttributeDict,
427
+ model: Union[nn.Module, DDP],
428
+ features: Tensor,
429
+ features_lens: Tensor,
430
+ tokens: List[List[int]],
431
+ is_training: bool,
432
+ ) -> Tuple[Tensor, MetricsTracker]:
433
+ """
434
+ Compute loss given the model and its inputs.
435
+
436
+ Args:
437
+ params:
438
+ Parameters for training. See :func:`get_params`.
439
+ model:
440
+ The model for training.
441
+ features:
442
+ The target acoustic feature.
443
+ features_lens:
444
+ The number of frames of each utterance.
445
+ tokens:
446
+ Input tokens that represent the transcripts.
447
+ is_training:
448
+ True for training. False for validation. When it is True, this
449
+ function enables autograd during computation; when it is False, it
450
+ disables autograd.
451
+ """
452
+
453
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
454
+
455
+ batch_size, num_frames, _ = features.shape
456
+
457
+ noise = torch.randn_like(features) # (B, T, F)
458
+
459
+ # Sampling t from uniform distribution
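+ # Training draws t i.i.d. from U[0, 1) per utterance, while validation uses
+ # a fixed, evenly spaced grid, presumably so that successive validation
+ # losses are directly comparable.  For example, with batch_size = 4 the
+ # validation branch below yields t = [0.00, 0.25, 0.50, 0.75], reshaped to
+ # (4, 1, 1) so it broadcasts against the (B, T, F) features.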
460
+ if is_training:
461
+ t = torch.rand(batch_size, 1, 1, device=device)
462
+ else:
463
+ t = (
464
+ (torch.arange(batch_size, device=device) / batch_size)
465
+ .unsqueeze(1)
466
+ .unsqueeze(2)
467
+ )
468
+ with torch.set_grad_enabled(is_training):
469
+
470
+ loss = model(
471
+ tokens=tokens,
472
+ features=features,
473
+ features_lens=features_lens,
474
+ noise=noise,
475
+ t=t,
476
+ condition_drop_ratio=params.condition_drop_ratio,
477
+ )
478
+
479
+ assert loss.requires_grad == is_training
480
+ info = MetricsTracker()
481
+ num_frames = features_lens.sum().item()
482
+ info["frames"] = num_frames
483
+ info["loss"] = loss.detach().cpu().item() * num_frames
484
+
485
+ return loss, info
486
+
487
+
488
+ def train_one_epoch(
489
+ params: AttributeDict,
490
+ model: Union[nn.Module, DDP],
491
+ optimizer: Optimizer,
492
+ scheduler: LRSchedulerType,
493
+ train_dl: torch.utils.data.DataLoader,
494
+ valid_dl: torch.utils.data.DataLoader,
495
+ scaler: GradScaler,
496
+ model_avg: Optional[nn.Module] = None,
497
+ tb_writer: Optional[SummaryWriter] = None,
498
+ world_size: int = 1,
499
+ rank: int = 0,
500
+ ) -> None:
501
+ """Train the model for one epoch.
502
+
503
+ The training loss from the mean of all frames is saved in
504
+ `params.train_loss`. It runs the validation process every
505
+ `params.valid_interval` batches or at the start of every epoch.
506
+
507
+ Args:
508
+ params:
509
+ It is returned by :func:`get_params`.
510
+ model:
511
+ The model for training.
512
+ optimizer:
513
+ The optimizer.
514
+ scheduler:
515
+ The learning rate scheduler, we call step() every epoch.
516
+ train_dl:
517
+ Dataloader for the training dataset.
518
+ valid_dl:
519
+ Dataloader for the validation dataset.
520
+ scaler:
521
+ The scaler used for mixed precision training.
522
+ tb_writer:
523
+ Writer to write log messages to tensorboard.
524
+ world_size:
525
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
526
+ rank:
527
+ The rank of the node in DDP training. If no DDP is used, it should
528
+ be set to 0.
529
+ """
530
+ model.train()
531
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
532
+
533
+ # used to track the stats over iterations in one epoch
534
+ tot_loss = MetricsTracker()
535
+
536
+ saved_bad_model = False
537
+
538
+ def save_bad_model(suffix: str = ""):
539
+ save_checkpoint(
540
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
541
+ model=model,
542
+ model_avg=model_avg,
543
+ params=params,
544
+ optimizer=optimizer,
545
+ scheduler=scheduler,
546
+ sampler=train_dl.sampler,
547
+ scaler=scaler,
548
+ rank=0,
549
+ )
550
+
551
+ for batch_idx, batch in enumerate(train_dl):
552
+
553
+ if batch_idx % 10 == 0:
554
+ if params.finetune:
555
+ set_batch_count(model, get_adjusted_batch_count(params) + 100000)
556
+ else:
557
+ set_batch_count(model, get_adjusted_batch_count(params))
558
+
559
+ if (
560
+ params.valid_by_epoch and batch_idx == 0 and not params.print_diagnostics
561
+ ) or (
562
+ not params.valid_by_epoch
563
+ and params.batch_idx_train % params.valid_interval == 0
564
+ and not params.print_diagnostics
565
+ ):
566
+ logging.info("Computing validation loss")
567
+ valid_info = compute_validation_loss(
568
+ params=params,
569
+ model=model,
570
+ valid_dl=valid_dl,
571
+ world_size=world_size,
572
+ )
573
+ model.train()
574
+ logging.info(
575
+ f"Epoch {params.cur_epoch}, global_batch_idx: {params.batch_idx_train},"
576
+ f" validation: {valid_info}"
577
+ )
578
+ logging.info(
579
+ f"Maximum memory allocated so far is "
580
+ f"{torch.cuda.max_memory_allocated() // 1000000}MB"
581
+ )
582
+ if tb_writer is not None:
583
+ valid_info.write_summary(
584
+ tb_writer, "train/valid_", params.batch_idx_train
585
+ )
586
+
587
+ params.batch_idx_train += 1
588
+
589
+ batch_size = len(batch["text"])
590
+
591
+ tokens, features, features_lens = prepare_input(
592
+ params=params,
593
+ batch=batch,
594
+ device=device,
595
+ return_tokens=True,
596
+ return_feature=True,
597
+ )
598
+
599
+ try:
600
+ with torch_autocast(dtype=torch.float16, enabled=params.use_fp16):
601
+ loss, loss_info = compute_fbank_loss(
602
+ params=params,
603
+ model=model,
604
+ features=features,
605
+ features_lens=features_lens,
606
+ tokens=tokens,
607
+ is_training=True,
608
+ )
609
+
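+ # `tot_loss` below is an exponentially weighted running sum: the previous
+ # value is decayed by (1 - 1 / reset_interval) before the new batch stats
+ # are added, so it effectively reflects roughly the last `reset_interval`
+ # (200 by default) batches.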
610
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
611
+
612
+ scaler.scale(loss).backward()
613
+
614
+ scheduler.step_batch(params.batch_idx_train)
615
+ # Use the number of hours of speech to adjust the learning rate
616
+ if params.lr_hours > 0:
617
+ scheduler.step_epoch(
618
+ params.batch_idx_train
619
+ * params.max_duration
620
+ * params.world_size
621
+ / 3600
622
+ )
623
+ scaler.step(optimizer)
624
+ scaler.update()
625
+ optimizer.zero_grad()
626
+ except Exception as e:
627
+ logging.info(f"Caught exception : {e}.")
628
+ save_bad_model()
629
+ raise
630
+
631
+ if params.print_diagnostics and batch_idx == 5:
632
+ return
633
+
634
+ if (
635
+ rank == 0
636
+ and params.batch_idx_train > 0
637
+ and params.batch_idx_train % params.average_period == 0
638
+ ):
639
+ update_averaged_model(
640
+ params=params,
641
+ model_cur=model,
642
+ model_avg=model_avg,
643
+ )
644
+
645
+ if (
646
+ params.batch_idx_train > 0
647
+ and params.batch_idx_train % params.save_every_n == 0
648
+ ):
649
+ save_checkpoint_with_global_batch_idx(
650
+ out_dir=params.exp_dir,
651
+ global_batch_idx=params.batch_idx_train,
652
+ model=model,
653
+ model_avg=model_avg,
654
+ params=params,
655
+ optimizer=optimizer,
656
+ scheduler=scheduler,
657
+ sampler=train_dl.sampler,
658
+ scaler=scaler,
659
+ rank=rank,
660
+ )
661
+ remove_checkpoints(
662
+ out_dir=params.exp_dir,
663
+ topk=params.keep_last_k,
664
+ rank=rank,
665
+ )
666
+ if params.num_iters > 0 and params.batch_idx_train > params.num_iters:
667
+ break
668
+ if params.batch_idx_train % 100 == 0 and params.use_fp16:
669
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
670
+ # of the grad scaler is configurable, but we can't configure it to have
671
+ # different behavior depending on the current grad scale.
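+ # Concretely: the scale is doubled while it is below 1024 (and, every 400
+ # batches, while below 4096).  A scale that keeps shrinking is usually a
+ # sign of inf/nan gradients, so below 0.01 a `bad-model` checkpoint is
+ # saved for debugging and below 1e-5 training is aborted.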
672
+ cur_grad_scale = scaler._scale.item()
673
+
674
+ if cur_grad_scale < 1024.0 or (
675
+ cur_grad_scale < 4096.0 and params.batch_idx_train % 400 == 0
676
+ ):
677
+ scaler.update(cur_grad_scale * 2.0)
678
+ if cur_grad_scale < 0.01:
679
+ if not saved_bad_model:
680
+ save_bad_model(suffix="-first-warning")
681
+ saved_bad_model = True
682
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
683
+ if cur_grad_scale < 1.0e-05:
684
+ save_bad_model()
685
+ raise RuntimeError(
686
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
687
+ )
688
+
689
+ if params.batch_idx_train % params.log_interval == 0:
690
+ cur_lr = max(scheduler.get_last_lr())
691
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
692
+
693
+ logging.info(
694
+ f"Epoch {params.cur_epoch}, batch {batch_idx}, "
695
+ f"global_batch_idx: {params.batch_idx_train}, "
696
+ f"batch size: {batch_size}, "
697
+ f"loss[{loss_info}], tot_loss[{tot_loss}], "
698
+ f"cur_lr: {cur_lr:.2e}, "
699
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
700
+ )
701
+
702
+ if tb_writer is not None:
703
+ tb_writer.add_scalar(
704
+ "train/learning_rate", cur_lr, params.batch_idx_train
705
+ )
706
+ loss_info.write_summary(
707
+ tb_writer, "train/current_", params.batch_idx_train
708
+ )
709
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
710
+ if params.use_fp16:
711
+ tb_writer.add_scalar(
712
+ "train/grad_scale",
713
+ cur_grad_scale,
714
+ params.batch_idx_train,
715
+ )
716
+
717
+ loss_value = tot_loss["loss"]
718
+ params.train_loss = loss_value
719
+ if params.train_loss < params.best_train_loss:
720
+ params.best_train_epoch = params.cur_epoch
721
+ params.best_train_loss = params.train_loss
722
+
723
+
724
+ def compute_validation_loss(
725
+ params: AttributeDict,
726
+ model: Union[nn.Module, DDP],
727
+ valid_dl: torch.utils.data.DataLoader,
728
+ world_size: int = 1,
729
+ ) -> MetricsTracker:
730
+ """Run the validation process."""
731
+
732
+ model.eval()
733
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
734
+
735
+ # used to summarize the stats over iterations
736
+ tot_loss = MetricsTracker()
737
+
738
+ for batch_idx, batch in enumerate(valid_dl):
739
+ tokens, features, features_lens = prepare_input(
740
+ params=params,
741
+ batch=batch,
742
+ device=device,
743
+ return_tokens=True,
744
+ return_feature=True,
745
+ )
746
+
747
+ loss, loss_info = compute_fbank_loss(
748
+ params=params,
749
+ model=model,
750
+ features=features,
751
+ features_lens=features_lens,
752
+ tokens=tokens,
753
+ is_training=False,
754
+ )
755
+ assert loss.requires_grad is False
756
+ tot_loss = tot_loss + loss_info
757
+
758
+ if world_size > 1:
759
+ tot_loss.reduce(loss.device)
760
+
761
+ loss_value = tot_loss["loss"]
762
+ if loss_value < params.best_valid_loss:
763
+ params.best_valid_epoch = params.cur_epoch
764
+ params.best_valid_loss = loss_value
765
+
766
+ return tot_loss
767
+
768
+
769
+ def display_and_save_batch(
770
+ batch: dict,
771
+ params: AttributeDict,
772
+ ) -> None:
773
+ """Display the batch statistics and save the batch to disk.
774
+
775
+ Args:
776
+ batch:
777
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
778
+ for the content in it.
779
+ params:
780
+ Parameters for training. See :func:`get_params`.
783
+ """
784
+ from lhotse.utils import uuid4
785
+
786
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
787
+ logging.info(f"Saving batch to {filename}")
788
+ torch.save(batch, filename)
789
+
790
+ features = batch["features"]
791
+ tokens = batch["tokens"]
792
+
793
+ logging.info(f"features shape: {features.shape}")
794
+ num_tokens = sum(len(i) for i in tokens)
795
+ logging.info(f"num tokens: {num_tokens}")
796
+
797
+
798
+ def scan_pessimistic_batches_for_oom(
799
+ model: Union[nn.Module, DDP],
800
+ train_dl: torch.utils.data.DataLoader,
801
+ optimizer: torch.optim.Optimizer,
802
+ params: AttributeDict,
803
+ ):
804
+ from lhotse.dataset import find_pessimistic_batches
805
+
806
+ logging.info(
807
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
808
+ )
809
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
810
+
811
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
812
+ for criterion, cuts in batches.items():
813
+ batch = train_dl.dataset[cuts]
814
+ tokens, features, features_lens = prepare_input(
815
+ params=params,
816
+ batch=batch,
817
+ device=device,
818
+ return_tokens=True,
819
+ return_feature=True,
820
+ )
821
+ try:
822
+ with torch_autocast(dtype=torch.float16, enabled=params.use_fp16):
823
+
824
+ loss, loss_info = compute_fbank_loss(
825
+ params=params,
826
+ model=model,
827
+ features=features,
828
+ features_lens=features_lens,
829
+ tokens=tokens,
830
+ is_training=True,
831
+ )
832
+ loss.backward()
833
+ optimizer.zero_grad()
834
+ except Exception as e:
835
+ if "CUDA out of memory" in str(e):
836
+ logging.error(
837
+ "Your GPU ran out of memory with the current "
838
+ "max_duration setting. We recommend decreasing "
839
+ "max_duration and trying again.\n"
840
+ f"Failing criterion: {criterion} "
841
+ f"(={crit_values[criterion]}) ..."
842
+ )
843
+ display_and_save_batch(batch, params=params)
844
+ raise
845
+ logging.info(
846
+ f"Maximum memory allocated so far is "
847
+ f"{torch.cuda.max_memory_allocated() // 1000000}MB"
848
+ )
849
+
850
+
851
+ def tokenize_text(c: Cut, tokenizer):
852
+ if hasattr(c.supervisions[0], "tokens"):
853
+ tokens = tokenizer.tokens_to_token_ids([c.supervisions[0].tokens])
854
+ else:
855
+ tokens = tokenizer.texts_to_token_ids([c.supervisions[0].text])
856
+ c.supervisions[0].tokens = tokens[0]
857
+ return c
858
+
859
+
860
+ def run(rank, world_size, args):
861
+ """
862
+ Args:
863
+ rank:
864
+ It is a value between 0 and `world_size-1`, which is
865
+ passed automatically by `mp.spawn()` in :func:`main`.
866
+ The node with rank 0 is responsible for saving checkpoint.
867
+ world_size:
868
+ Number of GPUs for DDP training.
869
+ args:
870
+ The return value of get_parser().parse_args()
871
+ """
872
+ params = get_params()
873
+ params.update(vars(args))
874
+ params.valid_interval = params.save_every_n
875
+ # Set epoch to a large number to ignore it.
876
+ if params.num_iters > 0:
877
+ params.num_epochs = 1000000
878
+ with open(params.model_config, "r") as f:
879
+ model_config = json.load(f)
880
+ params.update(model_config["model"])
881
+ params.update(model_config["feature"])
882
+
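+ # The model config JSON is only assumed here to have two top-level
+ # sections: "model" (kwargs forwarded to the ZipVoice constructor below)
+ # and "feature" (feature-extraction settings merged into `params`), i.e.
+ # roughly {"model": {...}, "feature": {...}}.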
883
+ fix_random_seed(params.seed)
884
+ if world_size > 1:
885
+ setup_dist(rank, world_size, params.master_port)
886
+
887
+ os.makedirs(f"{params.exp_dir}", exist_ok=True)
888
+ copyfile(src=params.model_config, dst=f"{params.exp_dir}/model.json")
889
+ copyfile(src=params.token_file, dst=f"{params.exp_dir}/tokens.txt")
890
+ setup_logger(f"{params.exp_dir}/log/log-train")
891
+
892
+ if args.tensorboard and rank == 0:
893
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
894
+ else:
895
+ tb_writer = None
896
+
897
+ if torch.cuda.is_available():
898
+ params.device = torch.device("cuda", rank)
899
+ else:
900
+ params.device = torch.device("cpu")
901
+ logging.info(f"Device: {params.device}")
902
+
903
+ if params.tokenizer == "emilia":
904
+ tokenizer = EmiliaTokenizer(token_file=params.token_file)
905
+ elif params.tokenizer == "libritts":
906
+ tokenizer = LibriTTSTokenizer(token_file=params.token_file)
907
+ elif params.tokenizer == "espeak":
908
+ tokenizer = EspeakTokenizer(token_file=params.token_file, lang=params.lang)
909
+ else:
910
+ assert params.tokenizer == "simple"
911
+ tokenizer = SimpleTokenizer(token_file=params.token_file)
912
+
913
+ tokenizer_config = {"vocab_size": tokenizer.vocab_size, "pad_id": tokenizer.pad_id}
914
+ params.update(tokenizer_config)
915
+
916
+ logging.info(params)
917
+
918
+ logging.info("About to create model")
919
+
920
+ model = ZipVoice(
921
+ **model_config["model"],
922
+ **tokenizer_config,
923
+ )
924
+
925
+ if params.checkpoint is not None:
926
+ logging.info(f"Loading pre-trained model from {params.checkpoint}")
927
+ _ = load_checkpoint(filename=params.checkpoint, model=model, strict=True)
928
+ num_param = sum([p.numel() for p in model.parameters()])
929
+ logging.info(f"Number of parameters : {num_param}")
930
+
931
+ model_avg: Optional[nn.Module] = None
932
+ if rank == 0:
933
+ # model_avg is only used with rank 0
934
+ model_avg = copy.deepcopy(model).to(torch.float64)
935
+
936
+ assert params.start_epoch > 0, params.start_epoch
937
+ if params.start_epoch > 1:
938
+ checkpoints = resume_checkpoint(params=params, model=model, model_avg=model_avg)
939
+
940
+ model = model.to(params.device)
941
+ if world_size > 1:
942
+ logging.info("Using DDP")
943
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
944
+
945
+ optimizer = ScaledAdam(
946
+ get_parameter_groups_with_lrs(
947
+ model,
948
+ lr=params.base_lr,
949
+ include_names=True,
950
+ ),
951
+ lr=params.base_lr, # should have no effect
952
+ clipping_scale=2.0,
953
+ )
954
+
955
+ assert params.lr_hours >= 0
956
+
957
+ if params.finetune:
958
+ scheduler = FixedLRScheduler(optimizer)
959
+ elif params.lr_hours > 0:
960
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_hours)
961
+ else:
962
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
963
+
964
+ scaler = create_grad_scaler(enabled=params.use_fp16)
965
+
966
+ if params.start_epoch > 1 and checkpoints is not None:
967
+ # load state_dict for optimizers
968
+ if "optimizer" in checkpoints:
969
+ logging.info("Loading optimizer state dict")
970
+ optimizer.load_state_dict(checkpoints["optimizer"])
971
+
972
+ # load state_dict for schedulers
973
+ if "scheduler" in checkpoints:
974
+ logging.info("Loading scheduler state dict")
975
+ scheduler.load_state_dict(checkpoints["scheduler"])
976
+
977
+ if "grad_scaler" in checkpoints:
978
+ logging.info("Loading grad scaler state dict")
979
+ scaler.load_state_dict(checkpoints["grad_scaler"])
980
+
981
+ if params.print_diagnostics:
982
+ opts = diagnostics.TensorDiagnosticOptions(
983
+ 512
984
+ ) # allow 4 megabytes per sub-module
985
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
986
+
987
+ if params.inf_check:
988
+ register_inf_check_hooks(model)
989
+
990
+ def remove_short_and_long_utt(c: Cut, min_len: float, max_len: float):
991
+ if c.duration < min_len or c.duration > max_len:
992
+ return False
993
+ return True
994
+
995
+ _remove_short_and_long_utt = partial(
996
+ remove_short_and_long_utt, min_len=params.min_len, max_len=params.max_len
997
+ )
998
+
999
+ datamodule = TtsDataModule(args)
1000
+ if params.dataset == "emilia":
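+ # CutSet.mux interleaves the two manifests by sampling with the given
+ # relative weights; the 46000/49000 split below roughly mirrors the
+ # relative sizes of the Emilia EN and ZH subsets.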
1001
+ train_cuts = CutSet.mux(
1002
+ datamodule.train_emilia_EN_cuts(),
1003
+ datamodule.train_emilia_ZH_cuts(),
1004
+ weights=[46000, 49000],
1005
+ )
1006
+ train_cuts = train_cuts.filter(_remove_short_and_long_utt)
1007
+ dev_cuts = CutSet.mux(
1008
+ datamodule.dev_emilia_EN_cuts(),
1009
+ datamodule.dev_emilia_ZH_cuts(),
1010
+ weights=[0.5, 0.5],
1011
+ )
1012
+ elif params.dataset == "libritts":
1013
+ train_cuts = datamodule.train_libritts_cuts()
1014
+ train_cuts = train_cuts.filter(_remove_short_and_long_utt)
1015
+ dev_cuts = datamodule.dev_libritts_cuts()
1016
+ else:
1017
+ assert params.dataset == "custom"
1018
+ train_cuts = datamodule.train_custom_cuts(params.train_manifest)
1019
+ train_cuts = train_cuts.filter(_remove_short_and_long_utt)
1020
+ dev_cuts = datamodule.dev_custom_cuts(params.dev_manifest)
1021
+ # To avoid OOM issues due to too long dev cuts
1022
+ dev_cuts = dev_cuts.filter(_remove_short_and_long_utt)
1023
+
1024
+ if params.tokenizer in ["emilia", "espeak", "dialog"]:
1025
+ if not hasattr(train_cuts[0].supervisions[0], "tokens") or not hasattr(
1026
+ dev_cuts[0].supervisions[0], "tokens"
1027
+ ):
1028
+ logging.warning(
1029
+ f"Using {params.tokenizer} tokenizer but tokens are not prepared, "
1030
+ f"will tokenize on-the-fly, which can slow down training significantly."
1031
+ )
1032
+ _tokenize_text = partial(tokenize_text, tokenizer=tokenizer)
1033
+ train_cuts = train_cuts.map(_tokenize_text)
1034
+ dev_cuts = dev_cuts.map(_tokenize_text)
1035
+
1036
+ train_dl = datamodule.train_dataloaders(train_cuts)
1037
+
1038
+ valid_dl = datamodule.dev_dataloaders(dev_cuts)
1039
+
1040
+ if params.scan_oom:
1041
+ scan_pessimistic_batches_for_oom(
1042
+ model=model,
1043
+ train_dl=train_dl,
1044
+ optimizer=optimizer,
1045
+ params=params,
1046
+ )
1047
+
1048
+ logging.info("Training started")
1049
+
1050
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
1051
+ logging.info(f"Start epoch {epoch}")
1052
+
1053
+ if params.lr_hours == 0:
1054
+ scheduler.step_epoch(epoch - 1)
1055
+ fix_random_seed(params.seed + epoch - 1)
1056
+ train_dl.sampler.set_epoch(epoch - 1)
1057
+
1058
+ params.cur_epoch = epoch
1059
+
1060
+ if tb_writer is not None:
1061
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
1062
+
1063
+ train_one_epoch(
1064
+ params=params,
1065
+ model=model,
1066
+ model_avg=model_avg,
1067
+ optimizer=optimizer,
1068
+ scheduler=scheduler,
1069
+ train_dl=train_dl,
1070
+ valid_dl=valid_dl,
1071
+ scaler=scaler,
1072
+ tb_writer=tb_writer,
1073
+ world_size=world_size,
1074
+ rank=rank,
1075
+ )
1076
+
1077
+ if params.num_iters > 0 and params.batch_idx_train > params.num_iters:
1078
+ break
1079
+
1080
+ if params.print_diagnostics:
1081
+ diagnostic.print_diagnostics()
1082
+ break
1083
+
1084
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
1085
+ save_checkpoint(
1086
+ filename=filename,
1087
+ params=params,
1088
+ model=model,
1089
+ model_avg=model_avg,
1090
+ optimizer=optimizer,
1091
+ scheduler=scheduler,
1092
+ sampler=train_dl.sampler,
1093
+ scaler=scaler,
1094
+ rank=rank,
1095
+ )
1096
+
1097
+ if rank == 0:
1098
+ if params.best_train_epoch == params.cur_epoch:
1099
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
1100
+ copyfile(src=filename, dst=best_train_filename)
1101
+
1102
+ if params.best_valid_epoch == params.cur_epoch:
1103
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
1104
+ copyfile(src=filename, dst=best_valid_filename)
1105
+
1106
+ logging.info("Done!")
1107
+
1108
+ if world_size > 1:
1109
+ torch.distributed.barrier()
1110
+ cleanup_dist()
1111
+
1112
+
1113
+ def main():
1114
+ parser = get_parser()
1115
+ TtsDataModule.add_arguments(parser)
1116
+ args = parser.parse_args()
1117
+ args.exp_dir = Path(args.exp_dir)
1118
+
1119
+ world_size = args.world_size
1120
+ assert world_size >= 1
1121
+ if world_size > 1:
1122
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
1123
+ else:
1124
+ run(rank=0, world_size=1, args=args)
1125
+
1126
+
1127
+ if __name__ == "__main__":
1128
+ torch.set_num_threads(1)
1129
+ torch.set_num_interop_threads(1)
1130
+ main()
zipvoice/bin/train_zipvoice_dialog.py ADDED
@@ -0,0 +1,980 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """
19
+ This script trains a ZipVoice-Dialog model.
20
+
21
+ Usage:
22
+
23
+ python3 -m zipvoice.bin.train_zipvoice_dialog \
24
+ --world-size 8 \
25
+ --use-fp16 1 \
26
+ --base-lr 0.0001 \
27
+ --max-duration 500 \
28
+ --checkpoint download/zipvoice/model.pt \
29
+ --model-config conf/zipvoice_base.json \
30
+ --token-file "data/tokens_dialog.txt" \
31
+ --dataset opendialog \
32
+ --manifest-dir data/fbank \
33
+ --exp-dir exp/zipvoice_dialog
34
+ """
35
+
36
+ import argparse
37
+ import copy
38
+ import json
39
+ import logging
40
+ import os
41
+ from functools import partial
42
+ from pathlib import Path
43
+ from shutil import copyfile
44
+ from typing import List, Optional, Tuple, Union
45
+
46
+ import torch
47
+ import torch.multiprocessing as mp
48
+ import torch.nn as nn
49
+ from lhotse.cut import Cut, CutSet
50
+ from lhotse.utils import fix_random_seed
51
+ from torch import Tensor
52
+ from torch.nn.parallel import DistributedDataParallel as DDP
53
+ from torch.optim import Optimizer
54
+ from torch.utils.tensorboard import SummaryWriter
55
+
56
+ import zipvoice.utils.diagnostics as diagnostics
57
+ from zipvoice.bin.train_zipvoice import (
58
+ display_and_save_batch,
59
+ get_params,
60
+ tokenize_text,
61
+ )
62
+ from zipvoice.dataset.datamodule import TtsDataModule
63
+ from zipvoice.models.zipvoice_dialog import ZipVoiceDialog
64
+ from zipvoice.tokenizer.tokenizer import DialogTokenizer
65
+ from zipvoice.utils.checkpoint import (
66
+ load_checkpoint,
67
+ load_checkpoint_extend_vocab_size,
68
+ remove_checkpoints,
69
+ resume_checkpoint,
70
+ save_checkpoint,
71
+ save_checkpoint_with_global_batch_idx,
72
+ update_averaged_model,
73
+ )
74
+ from zipvoice.utils.common import (
75
+ AttributeDict,
76
+ GradScaler,
77
+ MetricsTracker,
78
+ cleanup_dist,
79
+ create_grad_scaler,
80
+ get_adjusted_batch_count,
81
+ get_parameter_groups_with_lrs,
82
+ prepare_input,
83
+ set_batch_count,
84
+ setup_dist,
85
+ setup_logger,
86
+ str2bool,
87
+ torch_autocast,
88
+ )
89
+ from zipvoice.utils.hooks import register_inf_check_hooks
90
+ from zipvoice.utils.lr_scheduler import FixedLRScheduler, LRScheduler
91
+ from zipvoice.utils.optim import ScaledAdam
92
+
93
+ LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, LRScheduler]
94
+
95
+
96
+ def get_parser():
97
+ parser = argparse.ArgumentParser(
98
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
99
+ )
100
+
101
+ parser.add_argument(
102
+ "--world-size",
103
+ type=int,
104
+ default=1,
105
+ help="Number of GPUs for DDP training.",
106
+ )
107
+
108
+ parser.add_argument(
109
+ "--master-port",
110
+ type=int,
111
+ default=12356,
112
+ help="Master port to use for DDP training.",
113
+ )
114
+
115
+ parser.add_argument(
116
+ "--tensorboard",
117
+ type=str2bool,
118
+ default=True,
119
+ help="Should various information be logged in tensorboard.",
120
+ )
121
+
122
+ parser.add_argument(
123
+ "--num-epochs",
124
+ type=int,
125
+ default=8,
126
+ help="Number of epochs to train.",
127
+ )
128
+
129
+ parser.add_argument(
130
+ "--num-iters",
131
+ type=int,
132
+ default=60000,
133
+ help="Number of iterations to train; num_epochs is ignored if > 0.",
134
+ )
135
+
136
+ parser.add_argument(
137
+ "--start-epoch",
138
+ type=int,
139
+ default=1,
140
+ help="""Resume training from this epoch. It should be positive.
141
+ If larger than 1, it will load checkpoint from
142
+ exp-dir/epoch-{start_epoch-1}.pt
143
+ """,
144
+ )
145
+
146
+ parser.add_argument(
147
+ "--checkpoint",
148
+ type=str,
149
+ required=True,
150
+ help="""Checkpoints of pre-trained models, either a ZipVoice model or a
151
+ ZipVoice-Dialog model.
152
+ """,
153
+ )
154
+
155
+ parser.add_argument(
156
+ "--exp-dir",
157
+ type=str,
158
+ default="exp/zipvoice_dialog",
159
+ help="""The experiment dir.
160
+ It specifies the directory where all training related
161
+ files, e.g., checkpoints, log, etc, are saved
162
+ """,
163
+ )
164
+
165
+ parser.add_argument(
166
+ "--base-lr", type=float, default=0.0001, help="The base learning rate."
167
+ )
168
+
169
+ parser.add_argument(
170
+ "--ref-duration",
171
+ type=float,
172
+ default=50,
173
+ help="""Reference batch duration for purposes of adjusting batch counts for
174
+ setting various schedules inside the model.
175
+ """,
176
+ )
177
+
178
+ parser.add_argument(
179
+ "--finetune",
180
+ type=str2bool,
181
+ default=False,
182
+ help="Whether to fine-tune from our pre-trained ZipVoice-Dialog model. "
183
+ "False means to fine-tune from a pre-trained ZipVoice model.",
184
+ )
185
+
186
+ parser.add_argument(
187
+ "--seed",
188
+ type=int,
189
+ default=42,
190
+ help="The seed for random generators intended for reproducibility",
191
+ )
192
+
193
+ parser.add_argument(
194
+ "--print-diagnostics",
195
+ type=str2bool,
196
+ default=False,
197
+ help="Accumulate stats on activations, print them and exit.",
198
+ )
199
+
200
+ parser.add_argument(
201
+ "--scan-oom",
202
+ type=str2bool,
203
+ default=False,
204
+ help="Scan pessimistic batches to see whether they cause OOMs.",
205
+ )
206
+
207
+ parser.add_argument(
208
+ "--inf-check",
209
+ type=str2bool,
210
+ default=False,
211
+ help="Add hooks to check for infinite module outputs and gradients.",
212
+ )
213
+
214
+ parser.add_argument(
215
+ "--save-every-n",
216
+ type=int,
217
+ default=5000,
218
+ help="""Save checkpoint after processing this number of batches
219
+ periodically. We save checkpoint to exp-dir/ whenever
220
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
221
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
222
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
223
+ end of each epoch where `xxx` is the epoch number counting from 1.
224
+ """,
225
+ )
226
+
227
+ parser.add_argument(
228
+ "--keep-last-k",
229
+ type=int,
230
+ default=30,
231
+ help="""Only keep this number of checkpoints on disk.
232
+ For instance, if it is 3, there are only 3 checkpoints
233
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
234
+ It does not affect checkpoints with name `epoch-xxx.pt`.
235
+ """,
236
+ )
237
+
238
+ parser.add_argument(
239
+ "--average-period",
240
+ type=int,
241
+ default=200,
242
+ help="""Update the averaged model, namely `model_avg`, after processing
243
+ this number of batches. `model_avg` is a separate version of model,
244
+ in which each floating-point parameter is the average of all the
245
+ parameters from the start of training. Each time we take the average,
246
+ we do: `model_avg = model * (average_period / batch_idx_train) +
247
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
248
+ """,
249
+ )
250
+
251
+ parser.add_argument(
252
+ "--use-fp16",
253
+ type=str2bool,
254
+ default=True,
255
+ help="Whether to use half precision training.",
256
+ )
257
+
258
+ parser.add_argument(
259
+ "--feat-scale",
260
+ type=float,
261
+ default=0.1,
262
+ help="The scale factor of fbank feature",
263
+ )
264
+
265
+ parser.add_argument(
266
+ "--condition-drop-ratio",
267
+ type=float,
268
+ default=0.2,
269
+ help="The drop rate of text condition during training.",
270
+ )
271
+
272
+ parser.add_argument(
273
+ "--dataset",
274
+ type=str,
275
+ default="opendialog",
276
+ choices=["opendialog", "custom"],
277
+ help="The training dataset to use",
278
+ )
279
+
280
+ parser.add_argument(
281
+ "--train-manifest",
282
+ type=str,
283
+ help="Path of the training manifest",
284
+ )
285
+
286
+ parser.add_argument(
287
+ "--dev-manifest",
288
+ type=str,
289
+ help="Path of the validation manifest",
290
+ )
291
+
292
+ parser.add_argument(
293
+ "--min-len",
294
+ type=float,
295
+ default=1.0,
296
+ help="The minimum audio length used for training",
297
+ )
298
+
299
+ parser.add_argument(
300
+ "--max-len",
301
+ type=float,
302
+ default=30.0,
303
+ help="The maximum audio length used for training",
304
+ )
305
+
306
+ parser.add_argument(
307
+ "--model-config",
308
+ type=str,
309
+ default="zipvoice_base.json",
310
+ help="The model configuration file.",
311
+ )
312
+
313
+ parser.add_argument(
314
+ "--token-file",
315
+ type=str,
316
+ default="data/tokens_dialog.txt",
317
+ help="The file that contains information that maps tokens to ids,"
318
+ "which is a text file with '{token}\t{token_id}' per line.",
319
+ )
320
+
321
+ return parser
322
+
323
+
324
+ def compute_fbank_loss(
325
+ params: AttributeDict,
326
+ model: Union[nn.Module, DDP],
327
+ features: Tensor,
328
+ features_lens: Tensor,
329
+ tokens: List[List[int]],
330
+ is_training: bool,
331
+ ) -> Tuple[Tensor, MetricsTracker]:
332
+ """
333
+ Compute loss given the model and its inputs.
334
+
335
+ Args:
336
+ params:
337
+ Parameters for training. See :func:`get_params`.
338
+ model:
339
+ The model for training.
340
+ features:
341
+ The target acoustic feature.
342
+ features_lens:
343
+ The number of frames of each utterance.
344
+ tokens:
345
+ Input tokens that represent the transcripts.
346
+ is_training:
347
+ True for training. False for validation. When it is True, this
348
+ function enables autograd during computation; when it is False, it
349
+ disables autograd.
350
+ """
351
+
352
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
353
+
354
+ batch_size, num_frames, _ = features.shape
355
+
356
+ noise = torch.randn_like(features) # (B, T, F)
357
+
358
+ # Sampling t from uniform distribution
359
+ if is_training:
360
+ t = torch.rand(batch_size, 1, 1, device=device)
361
+ else:
362
+ t = (
363
+ (torch.arange(batch_size, device=device) / batch_size)
364
+ .unsqueeze(1)
365
+ .unsqueeze(2)
366
+ )
367
+ with torch.set_grad_enabled(is_training):
368
+
369
+ loss = model(
370
+ tokens=tokens,
371
+ features=features,
372
+ features_lens=features_lens,
373
+ noise=noise,
374
+ t=t,
375
+ condition_drop_ratio=params.condition_drop_ratio,
376
+ )
377
+
378
+ assert loss.requires_grad == is_training
379
+ info = MetricsTracker()
380
+ num_frames = features_lens.sum().item()
381
+ info["frames"] = num_frames
382
+ info["loss"] = loss.detach().cpu().item() * num_frames
383
+
384
+ return loss, info
385
+
386
+
387
+ def train_one_epoch(
388
+ params: AttributeDict,
389
+ model: Union[nn.Module, DDP],
390
+ optimizer: Optimizer,
391
+ scheduler: LRSchedulerType,
392
+ train_dl: torch.utils.data.DataLoader,
393
+ valid_dl: torch.utils.data.DataLoader,
394
+ scaler: GradScaler,
395
+ model_avg: Optional[nn.Module] = None,
396
+ tb_writer: Optional[SummaryWriter] = None,
397
+ world_size: int = 1,
398
+ rank: int = 0,
399
+ ) -> None:
400
+ """Train the model for one epoch.
401
+
402
+ The training loss from the mean of all frames is saved in
403
+ `params.train_loss`. It runs the validation process every
404
+ `params.valid_interval` batches.
405
+
406
+ Args:
407
+ params:
408
+ It is returned by :func:`get_params`.
409
+ model:
410
+ The model for training.
411
+ optimizer:
412
+ The optimizer.
413
+ scheduler:
414
+ The learning rate scheduler, we call step() every epoch.
415
+ train_dl:
416
+ Dataloader for the training dataset.
417
+ valid_dl:
418
+ Dataloader for the validation dataset.
419
+ scaler:
420
+ The scaler used for mixed precision training.
421
+ tb_writer:
422
+ Writer to write log messages to tensorboard.
423
+ world_size:
424
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
425
+ rank:
426
+ The rank of the node in DDP training. If no DDP is used, it should
427
+ be set to 0.
428
+ """
429
+ model.train()
430
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
431
+
432
+ # used to track the stats over iterations in one epoch
433
+ tot_loss = MetricsTracker()
434
+
435
+ saved_bad_model = False
436
+
437
+ def save_bad_model(suffix: str = ""):
438
+ save_checkpoint(
439
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
440
+ model=model,
441
+ model_avg=model_avg,
442
+ params=params,
443
+ optimizer=optimizer,
444
+ scheduler=scheduler,
445
+ sampler=train_dl.sampler,
446
+ scaler=scaler,
447
+ rank=0,
448
+ )
449
+
450
+ for batch_idx, batch in enumerate(train_dl):
451
+
452
+ if batch_idx % 10 == 0:
453
+ set_batch_count(model, get_adjusted_batch_count(params) + 100000)
454
+
455
+ if (
456
+ params.batch_idx_train > 0
457
+ and params.batch_idx_train % params.valid_interval == 0
458
+ and not params.print_diagnostics
459
+ ):
460
+ logging.info("Computing validation loss")
461
+ valid_info = compute_validation_loss(
462
+ params=params,
463
+ model=model,
464
+ valid_dl=valid_dl,
465
+ world_size=world_size,
466
+ )
467
+ model.train()
468
+ logging.info(
469
+ f"Epoch {params.cur_epoch}, global_batch_idx: {params.batch_idx_train},"
470
+ f" validation: {valid_info}"
471
+ )
472
+ logging.info(
473
+ f"Maximum memory allocated so far is "
474
+ f"{torch.cuda.max_memory_allocated() // 1000000}MB"
475
+ )
476
+ if tb_writer is not None:
477
+ valid_info.write_summary(
478
+ tb_writer, "train/valid_", params.batch_idx_train
479
+ )
480
+
481
+ params.batch_idx_train += 1
482
+
483
+ batch_size = len(batch["text"])
484
+
485
+ tokens, features, features_lens = prepare_input(
486
+ params=params,
487
+ batch=batch,
488
+ device=device,
489
+ return_tokens=True,
490
+ return_feature=True,
491
+ )
492
+
493
+ try:
494
+ with torch_autocast(dtype=torch.float16, enabled=params.use_fp16):
495
+ loss, loss_info = compute_fbank_loss(
496
+ params=params,
497
+ model=model,
498
+ features=features,
499
+ features_lens=features_lens,
500
+ tokens=tokens,
501
+ is_training=True,
502
+ )
503
+
504
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
505
+
506
+ scaler.scale(loss).backward()
507
+
508
+ scheduler.step_batch(params.batch_idx_train)
509
+ scaler.step(optimizer)
510
+ scaler.update()
511
+ optimizer.zero_grad()
512
+ except Exception as e:
513
+ logging.info(f"Caught exception : {e}.")
514
+ save_bad_model()
515
+ raise
516
+
517
+ if params.print_diagnostics and batch_idx == 5:
518
+ return
519
+
520
+ if (
521
+ rank == 0
522
+ and params.batch_idx_train > 0
523
+ and params.batch_idx_train % params.average_period == 0
524
+ ):
525
+ update_averaged_model(
526
+ params=params,
527
+ model_cur=model,
528
+ model_avg=model_avg,
529
+ )
530
+
531
+ if (
532
+ params.batch_idx_train > 0
533
+ and params.batch_idx_train % params.save_every_n == 0
534
+ ):
535
+ save_checkpoint_with_global_batch_idx(
536
+ out_dir=params.exp_dir,
537
+ global_batch_idx=params.batch_idx_train,
538
+ model=model,
539
+ model_avg=model_avg,
540
+ params=params,
541
+ optimizer=optimizer,
542
+ scheduler=scheduler,
543
+ sampler=train_dl.sampler,
544
+ scaler=scaler,
545
+ rank=rank,
546
+ )
547
+ remove_checkpoints(
548
+ out_dir=params.exp_dir,
549
+ topk=params.keep_last_k,
550
+ rank=rank,
551
+ )
552
+ if params.num_iters > 0 and params.batch_idx_train > params.num_iters:
553
+ break
554
+ if params.batch_idx_train % 100 == 0 and params.use_fp16:
555
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
556
+ # of the grad scaler is configurable, but we can't configure it to have
557
+ # different behavior depending on the current grad scale.
558
+ cur_grad_scale = scaler._scale.item()
559
+
560
+ if cur_grad_scale < 1024.0 or (
561
+ cur_grad_scale < 4096.0 and params.batch_idx_train % 400 == 0
562
+ ):
563
+ scaler.update(cur_grad_scale * 2.0)
564
+ if cur_grad_scale < 0.01:
565
+ if not saved_bad_model:
566
+ save_bad_model(suffix="-first-warning")
567
+ saved_bad_model = True
568
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
569
+ if cur_grad_scale < 1.0e-05:
570
+ save_bad_model()
571
+ raise RuntimeError(
572
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
573
+ )
574
+
575
+ if params.batch_idx_train % params.log_interval == 0:
576
+ cur_lr = max(scheduler.get_last_lr())
577
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
578
+
579
+ logging.info(
580
+ f"Epoch {params.cur_epoch}, batch {batch_idx}, "
581
+ f"global_batch_idx: {params.batch_idx_train}, "
582
+ f"batch size: {batch_size}, "
583
+ f"loss[{loss_info}], tot_loss[{tot_loss}], "
584
+ f"cur_lr: {cur_lr:.2e}, "
585
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
586
+ )
587
+
588
+ if tb_writer is not None:
589
+ tb_writer.add_scalar(
590
+ "train/learning_rate", cur_lr, params.batch_idx_train
591
+ )
592
+ loss_info.write_summary(
593
+ tb_writer, "train/current_", params.batch_idx_train
594
+ )
595
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
596
+ if params.use_fp16:
597
+ tb_writer.add_scalar(
598
+ "train/grad_scale",
599
+ cur_grad_scale,
600
+ params.batch_idx_train,
601
+ )
602
+
603
+ loss_value = tot_loss["loss"]
604
+ params.train_loss = loss_value
605
+ if params.train_loss < params.best_train_loss:
606
+ params.best_train_epoch = params.cur_epoch
607
+ params.best_train_loss = params.train_loss
608
+
609
+
610
+ def compute_validation_loss(
611
+ params: AttributeDict,
612
+ model: Union[nn.Module, DDP],
613
+ valid_dl: torch.utils.data.DataLoader,
614
+ world_size: int = 1,
615
+ ) -> MetricsTracker:
616
+ """Run the validation process."""
617
+
618
+ model.eval()
619
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
620
+
621
+ # used to summarize the stats over iterations
622
+ tot_loss = MetricsTracker()
623
+
624
+ for batch_idx, batch in enumerate(valid_dl):
625
+ tokens, features, features_lens = prepare_input(
626
+ params=params,
627
+ batch=batch,
628
+ device=device,
629
+ return_tokens=True,
630
+ return_feature=True,
631
+ )
632
+
633
+ loss, loss_info = compute_fbank_loss(
634
+ params=params,
635
+ model=model,
636
+ features=features,
637
+ features_lens=features_lens,
638
+ tokens=tokens,
639
+ is_training=False,
640
+ )
641
+ assert loss.requires_grad is False
642
+ tot_loss = tot_loss + loss_info
643
+
644
+ if world_size > 1:
645
+ tot_loss.reduce(loss.device)
646
+
647
+ loss_value = tot_loss["loss"]
648
+ if loss_value < params.best_valid_loss:
649
+ params.best_valid_epoch = params.cur_epoch
650
+ params.best_valid_loss = loss_value
651
+
652
+ return tot_loss
653
+
654
+
655
+ def scan_pessimistic_batches_for_oom(
656
+ model: Union[nn.Module, DDP],
657
+ train_dl: torch.utils.data.DataLoader,
658
+ optimizer: torch.optim.Optimizer,
659
+ params: AttributeDict,
660
+ ):
661
+ from lhotse.dataset import find_pessimistic_batches
662
+
663
+ logging.info(
664
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
665
+ )
666
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
667
+
668
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
669
+ for criterion, cuts in batches.items():
670
+ batch = train_dl.dataset[cuts]
671
+ tokens, features, features_lens = prepare_input(
672
+ params=params,
673
+ batch=batch,
674
+ device=device,
675
+ return_tokens=True,
676
+ return_feature=True,
677
+ )
678
+ try:
679
+ with torch_autocast(dtype=torch.float16, enabled=params.use_fp16):
680
+
681
+ loss, loss_info = compute_fbank_loss(
682
+ params=params,
683
+ model=model,
684
+ features=features,
685
+ features_lens=features_lens,
686
+ tokens=tokens,
687
+ is_training=True,
688
+ )
689
+ loss.backward()
690
+ optimizer.zero_grad()
691
+ except Exception as e:
692
+ if "CUDA out of memory" in str(e):
693
+ logging.error(
694
+ "Your GPU ran out of memory with the current "
695
+ "max_duration setting. We recommend decreasing "
696
+ "max_duration and trying again.\n"
697
+ f"Failing criterion: {criterion} "
698
+ f"(={crit_values[criterion]}) ..."
699
+ )
700
+ display_and_save_batch(batch, params=params)
701
+ raise
702
+ logging.info(
703
+ f"Maximum memory allocated so far is "
704
+ f"{torch.cuda.max_memory_allocated() // 1000000}MB"
705
+ )
706
+
707
+
708
+ def run(rank, world_size, args):
709
+ """
710
+ Args:
711
+ rank:
712
+ It is a value between 0 and `world_size-1`, which is
713
+ passed automatically by `mp.spawn()` in :func:`main`.
714
+ The node with rank 0 is responsible for saving checkpoint.
715
+ world_size:
716
+ Number of GPUs for DDP training.
717
+ args:
718
+ The return value of get_parser().parse_args()
719
+ """
720
+ params = get_params()
721
+ params.update(vars(args))
722
+ params.valid_interval = params.save_every_n
723
+ # Set epoch to a large number to ignore it.
724
+ if params.num_iters > 0:
725
+ params.num_epochs = 1000000
726
+ with open(params.model_config, "r") as f:
727
+ model_config = json.load(f)
728
+ params.update(model_config["model"])
729
+ params.update(model_config["feature"])
730
+
731
+ fix_random_seed(params.seed)
732
+ if world_size > 1:
733
+ setup_dist(rank, world_size, params.master_port)
734
+
735
+ os.makedirs(f"{params.exp_dir}", exist_ok=True)
736
+ copyfile(src=params.model_config, dst=f"{params.exp_dir}/model.json")
737
+ copyfile(src=params.token_file, dst=f"{params.exp_dir}/tokens.txt")
738
+ setup_logger(f"{params.exp_dir}/log/log-train")
739
+
740
+ if args.tensorboard and rank == 0:
741
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
742
+ else:
743
+ tb_writer = None
744
+
745
+ if torch.cuda.is_available():
746
+ params.device = torch.device("cuda", rank)
747
+ else:
748
+ params.device = torch.device("cpu")
749
+ logging.info(f"Device: {params.device}")
750
+
751
+ tokenizer = DialogTokenizer(token_file=params.token_file)
752
+ tokenizer_config = {
753
+ "vocab_size": tokenizer.vocab_size,
754
+ "pad_id": tokenizer.pad_id,
755
+ "spk_a_id": tokenizer.spk_a_id,
756
+ "spk_b_id": tokenizer.spk_b_id,
757
+ }
758
+ params.update(tokenizer_config)
759
+
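+ # Besides vocab_size / pad_id, the dialog tokenizer exposes the ids of the
+ # two speaker-turn tokens (spk_a_id / spk_b_id), which ZipVoiceDialog takes
+ # as extra constructor kwargs, presumably to locate speaker changes in the
+ # transcript.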
760
+ logging.info(params)
761
+
762
+ logging.info("About to create model")
763
+
764
+ model = ZipVoiceDialog(
765
+ **model_config["model"],
766
+ **tokenizer_config,
767
+ )
768
+
769
+ assert params.checkpoint is not None, (
770
+ "require a pre-trained checkpoint, as training from random initialization "
771
+ "leads to unintelligible dialogue speech"
772
+ )
773
+ logging.info(f"Loading pre-trained model from {params.checkpoint}")
774
+
775
+ if params.finetune:
776
+ # load a pre-trained ZipVoice-Dialog model
777
+ _ = load_checkpoint(filename=params.checkpoint, model=model, strict=True)
778
+ else:
779
+ # load a pre-trained ZipVoice model, extend the vocab size for additional tokens
780
+ _ = load_checkpoint_extend_vocab_size(
781
+ filename=params.checkpoint,
782
+ extend_size=28,
783
+ model=model,
784
+ strict=True,
785
+ )
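+ # The extend_size=28 above presumably matches the number of dialog-specific
+ # tokens appended to the base ZipVoice vocabulary in data/tokens_dialog.txt
+ # (speaker-turn tags etc.); vocabulary-sized weights are extended while the
+ # remaining parameters are loaded as-is.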
786
+ num_param = sum([p.numel() for p in model.parameters()])
787
+ logging.info(f"Number of parameters : {num_param}")
788
+
789
+ model_avg: Optional[nn.Module] = None
790
+ if rank == 0:
791
+ # model_avg is only used with rank 0
792
+ model_avg = copy.deepcopy(model).to(torch.float64)
793
+
794
+ assert params.start_epoch > 0, params.start_epoch
795
+ if params.start_epoch > 1:
796
+ checkpoints = resume_checkpoint(params=params, model=model, model_avg=model_avg)
797
+
798
+ model = model.to(params.device)
799
+ if world_size > 1:
800
+ logging.info("Using DDP")
801
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
802
+
803
+ optimizer = ScaledAdam(
804
+ get_parameter_groups_with_lrs(
805
+ model,
806
+ lr=params.base_lr,
807
+ include_names=True,
808
+ ),
809
+ lr=params.base_lr, # should have no effect
810
+ clipping_scale=2.0,
811
+ )
812
+
813
+ scheduler = FixedLRScheduler(optimizer)
814
+
815
+ scaler = create_grad_scaler(enabled=params.use_fp16)
816
+
817
+ if params.start_epoch > 1 and checkpoints is not None:
818
+ # load state_dict for optimizers
819
+ if "optimizer" in checkpoints:
820
+ logging.info("Loading optimizer state dict")
821
+ optimizer.load_state_dict(checkpoints["optimizer"])
822
+
823
+ # load state_dict for schedulers
824
+ if "scheduler" in checkpoints:
825
+ logging.info("Loading scheduler state dict")
826
+ scheduler.load_state_dict(checkpoints["scheduler"])
827
+
828
+ if "grad_scaler" in checkpoints:
829
+ logging.info("Loading grad scaler state dict")
830
+ scaler.load_state_dict(checkpoints["grad_scaler"])
831
+
832
+ if params.print_diagnostics:
833
+ opts = diagnostics.TensorDiagnosticOptions(
834
+ 512
835
+ ) # allow 4 megabytes per sub-module
836
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
837
+
838
+ if params.inf_check:
839
+ register_inf_check_hooks(model)
840
+
841
+ def remove_short_and_long_utt(c: Cut, min_len: float, max_len: float):
842
+ if c.duration < min_len or c.duration > max_len:
843
+ return False
844
+ return True
845
+
846
+ _remove_short_and_long_utt = partial(
847
+ remove_short_and_long_utt, min_len=params.min_len, max_len=params.max_len
848
+ )
849
+
850
+ datamodule = TtsDataModule(args)
851
+ if params.dataset == "opendialog":
852
+ train_opendialog_en_cuts = datamodule.train_opendialog_en_cuts()
853
+ train_opendialog_zh_cuts = datamodule.train_opendialog_zh_cuts().repeat(2)
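+ # The ZH cuts are repeated twice, and the mux weights are the cut counts of
+ # the two sources, so sampling is roughly proportional to the size of each
+ # source.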
854
+
855
+ train_cuts = CutSet.mux(
856
+ train_opendialog_en_cuts,
857
+ train_opendialog_zh_cuts,
858
+ weights=[
859
+ len(train_opendialog_en_cuts),
860
+ len(train_opendialog_zh_cuts),
861
+ ],
862
+ )
863
+ train_cuts = train_cuts.filter(_remove_short_and_long_utt)
864
+
865
+ dev_cuts = CutSet.mux(
866
+ datamodule.dev_opendialog_en_cuts(),
867
+ datamodule.dev_opendialog_zh_cuts(),
868
+ )
869
+ else:
870
+ assert params.dataset == "custom"
871
+ train_cuts = datamodule.train_custom_cuts(params.train_manifest)
872
+ train_cuts = train_cuts.filter(_remove_short_and_long_utt)
873
+ dev_cuts = datamodule.dev_custom_cuts(params.dev_manifest)
874
+ # To avoid OOM issues due to too long dev cuts
875
+ dev_cuts = dev_cuts.filter(_remove_short_and_long_utt)
876
+
877
+ if not hasattr(train_cuts[0].supervisions[0], "tokens") or not hasattr(
878
+ dev_cuts[0].supervisions[0], "tokens"
879
+ ):
880
+ logging.warning(
881
+ "Tokens are not prepared, will tokenize on-the-fly, "
882
+ "which can slow down training significantly."
883
+ )
884
+ _tokenize_text = partial(tokenize_text, tokenizer=tokenizer)
885
+ train_cuts = train_cuts.map(_tokenize_text)
886
+ dev_cuts = dev_cuts.map(_tokenize_text)
887
+
888
+ train_dl = datamodule.train_dataloaders(train_cuts)
889
+
890
+ valid_dl = datamodule.dev_dataloaders(dev_cuts)
891
+
892
+ if params.scan_oom:
893
+ scan_pessimistic_batches_for_oom(
894
+ model=model,
895
+ train_dl=train_dl,
896
+ optimizer=optimizer,
897
+ params=params,
898
+ )
899
+
900
+ logging.info("Training started")
901
+
902
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
903
+ logging.info(f"Start epoch {epoch}")
904
+ scheduler.step_epoch(epoch - 1)
905
+ fix_random_seed(params.seed + epoch - 1)
906
+ train_dl.sampler.set_epoch(epoch - 1)
907
+
908
+ params.cur_epoch = epoch
909
+
910
+ if tb_writer is not None:
911
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
912
+
913
+ train_one_epoch(
914
+ params=params,
915
+ model=model,
916
+ model_avg=model_avg,
917
+ optimizer=optimizer,
918
+ scheduler=scheduler,
919
+ train_dl=train_dl,
920
+ valid_dl=valid_dl,
921
+ scaler=scaler,
922
+ tb_writer=tb_writer,
923
+ world_size=world_size,
924
+ rank=rank,
925
+ )
926
+
927
+ if params.num_iters > 0 and params.batch_idx_train > params.num_iters:
928
+ break
929
+
930
+ if params.print_diagnostics:
931
+ diagnostic.print_diagnostics()
932
+ break
933
+
934
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
935
+ save_checkpoint(
936
+ filename=filename,
937
+ params=params,
938
+ model=model,
939
+ model_avg=model_avg,
940
+ optimizer=optimizer,
941
+ scheduler=scheduler,
942
+ sampler=train_dl.sampler,
943
+ scaler=scaler,
944
+ rank=rank,
945
+ )
946
+
947
+ if rank == 0:
948
+ if params.best_train_epoch == params.cur_epoch:
949
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
950
+ copyfile(src=filename, dst=best_train_filename)
951
+
952
+ if params.best_valid_epoch == params.cur_epoch:
953
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
954
+ copyfile(src=filename, dst=best_valid_filename)
955
+
956
+ logging.info("Done!")
957
+
958
+ if world_size > 1:
959
+ torch.distributed.barrier()
960
+ cleanup_dist()
961
+
962
+
963
+ def main():
964
+ parser = get_parser()
965
+ TtsDataModule.add_arguments(parser)
966
+ args = parser.parse_args()
967
+ args.exp_dir = Path(args.exp_dir)
968
+
969
+ world_size = args.world_size
970
+ assert world_size >= 1
971
+ if world_size > 1:
972
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
973
+ else:
974
+ run(rank=0, world_size=1, args=args)
975
+
976
+
977
+ if __name__ == "__main__":
978
+ torch.set_num_threads(1)
979
+ torch.set_num_interop_threads(1)
980
+ main()
zipvoice/bin/train_zipvoice_dialog_stereo.py ADDED
@@ -0,0 +1,963 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """
19
+ This script trains a ZipVoice-Dialog-Stereo model.
20
+
21
+ Usage:
22
+
23
+ python3 -m zipvoice.bin.train_zipvoice_dialog_stereo \
24
+ --world-size 8 \
25
+ --use-fp16 1 \
26
+ --base-lr 0.002 \
27
+ --max-duration 500 \
28
+ --model-config conf/zipvoice_base.json \
29
+ --token-file "data/tokens_dialog.txt" \
30
+ --manifest-dir data/fbank \
31
+ --exp-dir exp/zipvoice_dialog_stereo
32
+ """
33
+
34
+ import argparse
35
+ import copy
36
+ import json
37
+ import logging
38
+ import os
39
+ from functools import partial
40
+ from pathlib import Path
41
+ from shutil import copyfile
42
+ from typing import List, Optional, Tuple, Union
43
+
44
+ import torch
45
+ import torch.multiprocessing as mp
46
+ import torch.nn as nn
47
+ from lhotse.cut import Cut
48
+ from lhotse.utils import fix_random_seed
49
+ from torch import Tensor
50
+ from torch.nn.parallel import DistributedDataParallel as DDP
51
+ from torch.optim import Optimizer
52
+ from torch.utils.tensorboard import SummaryWriter
53
+
54
+ import zipvoice.utils.diagnostics as diagnostics
55
+ from zipvoice.bin.train_zipvoice import (
56
+ display_and_save_batch,
57
+ get_params,
58
+ tokenize_text,
59
+ )
60
+ from zipvoice.dataset.datamodule import TtsDataModule
61
+ from zipvoice.models.zipvoice_dialog import ZipVoiceDialogStereo
62
+ from zipvoice.tokenizer.tokenizer import DialogTokenizer
63
+ from zipvoice.utils.checkpoint import (
64
+ load_checkpoint,
65
+ load_checkpoint_copy_proj_three_channel_alter,
66
+ remove_checkpoints,
67
+ resume_checkpoint,
68
+ save_checkpoint,
69
+ save_checkpoint_with_global_batch_idx,
70
+ update_averaged_model,
71
+ )
72
+ from zipvoice.utils.common import (
73
+ AttributeDict,
74
+ GradScaler,
75
+ MetricsTracker,
76
+ cleanup_dist,
77
+ create_grad_scaler,
78
+ get_adjusted_batch_count,
79
+ get_parameter_groups_with_lrs,
80
+ prepare_input,
81
+ set_batch_count,
82
+ setup_dist,
83
+ setup_logger,
84
+ str2bool,
85
+ torch_autocast,
86
+ )
87
+ from zipvoice.utils.hooks import register_inf_check_hooks
88
+ from zipvoice.utils.lr_scheduler import FixedLRScheduler, LRScheduler
89
+ from zipvoice.utils.optim import ScaledAdam
90
+
91
+ LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, LRScheduler]
92
+
93
+
94
+ def get_parser():
95
+ parser = argparse.ArgumentParser(
96
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
97
+ )
98
+
99
+ parser.add_argument(
100
+ "--world-size",
101
+ type=int,
102
+ default=1,
103
+ help="Number of GPUs for DDP training.",
104
+ )
105
+
106
+ parser.add_argument(
107
+ "--master-port",
108
+ type=int,
109
+ default=12356,
110
+ help="Master port to use for DDP training.",
111
+ )
112
+
113
+ parser.add_argument(
114
+ "--tensorboard",
115
+ type=str2bool,
116
+ default=True,
117
+ help="Should various information be logged in tensorboard.",
118
+ )
119
+
120
+ parser.add_argument(
121
+ "--num-epochs",
122
+ type=int,
123
+ default=8,
124
+ help="Number of epochs to train.",
125
+ )
126
+
127
+ parser.add_argument(
128
+ "--num-iters",
129
+ type=int,
130
+ default=25000,
131
+ help="Number of iterations to train; will ignore num_epochs if > 0.",
132
+ )
133
+
134
+ parser.add_argument(
135
+ "--start-epoch",
136
+ type=int,
137
+ default=1,
138
+ help="""Resume training from this epoch. It should be positive.
139
+ If larger than 1, it will load checkpoint from
140
+ exp-dir/epoch-{start_epoch-1}.pt
141
+ """,
142
+ )
143
+
144
+ parser.add_argument(
145
+ "--checkpoint",
146
+ type=str,
147
+ required=True,
148
+ help="""Checkpoint of a pre-trained model, either a ZipVoice-Dialog model or a
149
+ ZipVoice-Dialog-Stereo model.
150
+ """,
151
+ )
152
+
153
+ parser.add_argument(
154
+ "--exp-dir",
155
+ type=str,
156
+ default="exp/zipvoice_dialog",
157
+ help="""The experiment dir.
158
+ It specifies the directory where all training related
159
+ files, e.g., checkpoints, log, etc, are saved
160
+ """,
161
+ )
162
+
163
+ parser.add_argument(
164
+ "--base-lr", type=float, default=0.002, help="The base learning rate."
165
+ )
166
+
167
+ parser.add_argument(
168
+ "--ref-duration",
169
+ type=float,
170
+ default=50,
171
+ help="""Reference batch duration for purposes of adjusting batch counts for
172
+ setting various schedules inside the model.
173
+ """,
174
+ )
175
+
176
+ parser.add_argument(
177
+ "--finetune",
178
+ type=str2bool,
179
+ default=False,
180
+ help="Whether to fine-tune from our pre-trained ZipVoice-Dialog-Stereo model."
181
+ "False means to fine-tune from a pre-trained ZipVoice-Dialog model.",
182
+ )
183
+
184
+ parser.add_argument(
185
+ "--seed",
186
+ type=int,
187
+ default=42,
188
+ help="The seed for random generators intended for reproducibility",
189
+ )
190
+
191
+ parser.add_argument(
192
+ "--print-diagnostics",
193
+ type=str2bool,
194
+ default=False,
195
+ help="Accumulate stats on activations, print them and exit.",
196
+ )
197
+
198
+ parser.add_argument(
199
+ "--scan-oom",
200
+ type=str2bool,
201
+ default=False,
202
+ help="Scan pessimistic batches to see whether they cause OOMs.",
203
+ )
204
+
205
+ parser.add_argument(
206
+ "--inf-check",
207
+ type=str2bool,
208
+ default=False,
209
+ help="Add hooks to check for infinite module outputs and gradients.",
210
+ )
211
+
212
+ parser.add_argument(
213
+ "--save-every-n",
214
+ type=int,
215
+ default=5000,
216
+ help="""Save checkpoint after processing this number of batches
217
+ periodically. We save checkpoint to exp-dir/ whenever
218
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
219
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
220
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
221
+ end of each epoch where `xxx` is the epoch number counting from 1.
222
+ """,
223
+ )
224
+
225
+ parser.add_argument(
226
+ "--keep-last-k",
227
+ type=int,
228
+ default=30,
229
+ help="""Only keep this number of checkpoints on disk.
230
+ For instance, if it is 3, there are only 3 checkpoints
231
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
232
+ It does not affect checkpoints with name `epoch-xxx.pt`.
233
+ """,
234
+ )
235
+
236
+ parser.add_argument(
237
+ "--average-period",
238
+ type=int,
239
+ default=200,
240
+ help="""Update the averaged model, namely `model_avg`, after processing
241
+ this number of batches. `model_avg` is a separate version of model,
242
+ in which each floating-point parameter is the average of all the
243
+ parameters from the start of training. Each time we take the average,
244
+ we do: `model_avg = model * (average_period / batch_idx_train) +
245
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
246
+ """,
247
+ )
248
+
249
+ parser.add_argument(
250
+ "--use-fp16",
251
+ type=str2bool,
252
+ default=True,
253
+ help="Whether to use half precision training.",
254
+ )
255
+
256
+ parser.add_argument(
257
+ "--feat-scale",
258
+ type=float,
259
+ default=0.1,
260
+ help="The scale factor of fbank feature",
261
+ )
262
+
263
+ parser.add_argument(
264
+ "--condition-drop-ratio",
265
+ type=float,
266
+ default=0.2,
267
+ help="The drop rate of text condition during training.",
268
+ )
269
+
270
+ parser.add_argument(
271
+ "--train-manifest",
272
+ type=str,
273
+ help="Path of the training manifest",
274
+ )
275
+
276
+ parser.add_argument(
277
+ "--dev-manifest",
278
+ type=str,
279
+ help="Path of the validation manifest",
280
+ )
281
+
282
+ parser.add_argument(
283
+ "--min-len",
284
+ type=float,
285
+ default=1.0,
286
+ help="The minimum audio length used for training",
287
+ )
288
+
289
+ parser.add_argument(
290
+ "--max-len",
291
+ type=float,
292
+ default=60.0,
293
+ help="The maximum audio length used for training",
294
+ )
295
+
296
+ parser.add_argument(
297
+ "--model-config",
298
+ type=str,
299
+ default="zipvoice_base.json",
300
+ help="The model configuration file.",
301
+ )
302
+
303
+ parser.add_argument(
304
+ "--token-file",
305
+ type=str,
306
+ default="data/tokens_dialog.txt",
307
+ help="The file that contains information that maps tokens to ids, "
308
+ "which is a text file with '{token}\t{token_id}' per line.",
309
+ )
310
+
311
+ return parser
312
+
313
+
314
+ def compute_fbank_loss(
315
+ params: AttributeDict,
316
+ model: Union[nn.Module, DDP],
317
+ features: Tensor,
318
+ features_lens: Tensor,
319
+ tokens: List[List[int]],
320
+ is_training: bool,
321
+ use_two_channel: bool,
322
+ ) -> Tuple[Tensor, MetricsTracker]:
323
+ """
324
+ Compute loss given the model and its inputs.
325
+
326
+ Args:
327
+ params:
328
+ Parameters for training. See :func:`get_params`.
329
+ model:
330
+ The model for training.
331
+ features:
332
+ The target acoustic feature.
333
+ features_lens:
334
+ The number of frames of each utterance.
335
+ tokens:
336
+ Input tokens representing the transcripts.
337
+ is_training:
338
+ True for training. False for validation. When it is True, this
339
+ function enables autograd during computation; when it is False, it
340
+ disables autograd.
341
+ use_two_channel:
342
+ True for using two channel features, False for using one channel features.
343
+ """
344
+
345
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
346
+
347
+ batch_size, num_frames, _ = features.shape
348
+
349
+ assert (
350
+ features.size(2) == 3 * params.feat_dim
351
+ ), "we assume three channel features, the last channel is the mixed-channel feature"
352
+ if use_two_channel:
353
+ features = features[:, :, : params.feat_dim * 2]
354
+ else:
355
+ features = features[:, :, params.feat_dim * 2 :]
356
+
357
+ noise = torch.randn_like(features) # (B, T, F)
358
+
359
+ # Sampling t from uniform distribution
360
+ if is_training:
361
+ t = torch.rand(batch_size, 1, 1, device=device)
362
+ else:
363
+ t = (
364
+ (torch.arange(batch_size, device=device) / batch_size)
365
+ .unsqueeze(1)
366
+ .unsqueeze(2)
367
+ )
368
+ with torch.set_grad_enabled(is_training):
369
+
370
+ loss = model(
371
+ tokens=tokens,
372
+ features=features,
373
+ features_lens=features_lens,
374
+ noise=noise,
375
+ t=t,
376
+ condition_drop_ratio=params.condition_drop_ratio,
377
+ se_weight=1 if use_two_channel else 0,
378
+ )
379
+
380
+ assert loss.requires_grad == is_training
381
+ info = MetricsTracker()
382
+ num_frames = features_lens.sum().item()
383
+ info["frames"] = num_frames
384
+ info["loss"] = loss.detach().cpu().item() * num_frames
385
+
386
+ return loss, info
387
+
388
+
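# A minimal illustration (not part of the committed file) of the channel slicing that
# compute_fbank_loss() above relies on, assuming a hypothetical feat_dim of 100: the
# dataloader provides (B, T, 3 * feat_dim) features, where the first 2 * feat_dim columns
# are the two per-speaker fbanks and the last feat_dim columns are the mixed-channel
# fbank; use_two_channel=(batch_idx % 2 == 1) alternates between the two views.
import torch

feat_dim = 100                                   # illustrative; the real value comes from model.json
features = torch.randn(4, 50, 3 * feat_dim)      # (B, T, 3 * feat_dim)
stereo_target = features[:, :, : feat_dim * 2]   # use_two_channel=True branch
mixed_target = features[:, :, feat_dim * 2 :]    # use_two_channel=False branch
assert stereo_target.shape[-1] == 2 * feat_dim and mixed_target.shape[-1] == feat_dim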
389
+ def train_one_epoch(
390
+ params: AttributeDict,
391
+ model: Union[nn.Module, DDP],
392
+ optimizer: Optimizer,
393
+ scheduler: LRSchedulerType,
394
+ train_dl: torch.utils.data.DataLoader,
395
+ valid_dl: torch.utils.data.DataLoader,
396
+ scaler: GradScaler,
397
+ model_avg: Optional[nn.Module] = None,
398
+ tb_writer: Optional[SummaryWriter] = None,
399
+ world_size: int = 1,
400
+ rank: int = 0,
401
+ ) -> None:
402
+ """Train the model for one epoch.
403
+
404
+ The training loss from the mean of all frames is saved in
405
+ `params.train_loss`. It runs the validation process every
406
+ `params.valid_interval` batches.
407
+
408
+ Args:
409
+ params:
410
+ It is returned by :func:`get_params`.
411
+ model:
412
+ The model for training.
413
+ optimizer:
414
+ The optimizer.
415
+ scheduler:
416
+ The learning rate scheduler, we call step() every epoch.
417
+ train_dl:
418
+ Dataloader for the training dataset.
419
+ valid_dl:
420
+ Dataloader for the validation dataset.
421
+ scaler:
422
+ The scaler used for mixed precision training.
423
+ tb_writer:
424
+ Writer to write log messages to tensorboard.
425
+ world_size:
426
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
427
+ rank:
428
+ The rank of the node in DDP training. If no DDP is used, it should
429
+ be set to 0.
430
+ """
431
+ model.train()
432
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
433
+
434
+ # used to track the stats over iterations in one epoch
435
+ tot_loss = MetricsTracker()
436
+
437
+ saved_bad_model = False
438
+
439
+ def save_bad_model(suffix: str = ""):
440
+ save_checkpoint(
441
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
442
+ model=model,
443
+ model_avg=model_avg,
444
+ params=params,
445
+ optimizer=optimizer,
446
+ scheduler=scheduler,
447
+ sampler=train_dl.sampler,
448
+ scaler=scaler,
449
+ rank=0,
450
+ )
451
+
452
+ for batch_idx, batch in enumerate(train_dl):
453
+
454
+ if batch_idx % 10 == 0:
455
+ set_batch_count(model, get_adjusted_batch_count(params) + 100000)
456
+
457
+ if (
458
+ params.batch_idx_train > 0
459
+ and params.batch_idx_train % params.valid_interval == 0
460
+ and not params.print_diagnostics
461
+ ):
462
+ logging.info("Computing validation loss")
463
+ valid_info = compute_validation_loss(
464
+ params=params,
465
+ model=model,
466
+ valid_dl=valid_dl,
467
+ world_size=world_size,
468
+ )
469
+ model.train()
470
+ logging.info(
471
+ f"Epoch {params.cur_epoch}, global_batch_idx: {params.batch_idx_train},"
472
+ f" validation: {valid_info}"
473
+ )
474
+ logging.info(
475
+ f"Maximum memory allocated so far is "
476
+ f"{torch.cuda.max_memory_allocated() // 1000000}MB"
477
+ )
478
+ if tb_writer is not None:
479
+ valid_info.write_summary(
480
+ tb_writer, "train/valid_", params.batch_idx_train
481
+ )
482
+
483
+ params.batch_idx_train += 1
484
+
485
+ batch_size = len(batch["text"])
486
+
487
+ tokens, features, features_lens = prepare_input(
488
+ params=params,
489
+ batch=batch,
490
+ device=device,
491
+ return_tokens=True,
492
+ return_feature=True,
493
+ )
494
+
495
+ try:
496
+ with torch_autocast(dtype=torch.float16, enabled=params.use_fp16):
497
+ loss, loss_info = compute_fbank_loss(
498
+ params=params,
499
+ model=model,
500
+ features=features,
501
+ features_lens=features_lens,
502
+ tokens=tokens,
503
+ is_training=True,
504
+ use_two_channel=(batch_idx % 2 == 1),
505
+ )
506
+
507
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
508
+
509
+ scaler.scale(loss).backward()
510
+
511
+ scheduler.step_batch(params.batch_idx_train)
512
+ scaler.step(optimizer)
513
+ scaler.update()
514
+ optimizer.zero_grad()
515
+ except Exception as e:
516
+ logging.info(f"Caught exception : {e}.")
517
+ save_bad_model()
518
+ raise
519
+
520
+ if params.print_diagnostics and batch_idx == 5:
521
+ return
522
+
523
+ if (
524
+ rank == 0
525
+ and params.batch_idx_train > 0
526
+ and params.batch_idx_train % params.average_period == 0
527
+ ):
528
+ update_averaged_model(
529
+ params=params,
530
+ model_cur=model,
531
+ model_avg=model_avg,
532
+ )
533
+
534
+ if (
535
+ params.batch_idx_train > 0
536
+ and params.batch_idx_train % params.save_every_n == 0
537
+ ):
538
+ save_checkpoint_with_global_batch_idx(
539
+ out_dir=params.exp_dir,
540
+ global_batch_idx=params.batch_idx_train,
541
+ model=model,
542
+ model_avg=model_avg,
543
+ params=params,
544
+ optimizer=optimizer,
545
+ scheduler=scheduler,
546
+ sampler=train_dl.sampler,
547
+ scaler=scaler,
548
+ rank=rank,
549
+ )
550
+ remove_checkpoints(
551
+ out_dir=params.exp_dir,
552
+ topk=params.keep_last_k,
553
+ rank=rank,
554
+ )
555
+ if params.num_iters > 0 and params.batch_idx_train > params.num_iters:
556
+ break
557
+ if params.batch_idx_train % 100 == 0 and params.use_fp16:
558
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
559
+ # of the grad scaler is configurable, but we can't configure it to have
560
+ # different behavior depending on the current grad scale.
561
+ cur_grad_scale = scaler._scale.item()
562
+
563
+ if cur_grad_scale < 1024.0 or (
564
+ cur_grad_scale < 4096.0 and params.batch_idx_train % 400 == 0
565
+ ):
566
+ scaler.update(cur_grad_scale * 2.0)
567
+ if cur_grad_scale < 0.01:
568
+ if not saved_bad_model:
569
+ save_bad_model(suffix="-first-warning")
570
+ saved_bad_model = True
571
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
572
+ if cur_grad_scale < 1.0e-05:
573
+ save_bad_model()
574
+ raise RuntimeError(
575
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
576
+ )
577
+
578
+ if params.batch_idx_train % params.log_interval == 0:
579
+ cur_lr = max(scheduler.get_last_lr())
580
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
581
+
582
+ logging.info(
583
+ f"Epoch {params.cur_epoch}, batch {batch_idx}, "
584
+ f"global_batch_idx: {params.batch_idx_train}, "
585
+ f"batch size: {batch_size}, "
586
+ f"loss[{loss_info}], tot_loss[{tot_loss}], "
587
+ f"cur_lr: {cur_lr:.2e}, "
588
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
589
+ )
590
+
591
+ if tb_writer is not None:
592
+ tb_writer.add_scalar(
593
+ "train/learning_rate", cur_lr, params.batch_idx_train
594
+ )
595
+ loss_info.write_summary(
596
+ tb_writer, "train/current_", params.batch_idx_train
597
+ )
598
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
599
+ if params.use_fp16:
600
+ tb_writer.add_scalar(
601
+ "train/grad_scale",
602
+ cur_grad_scale,
603
+ params.batch_idx_train,
604
+ )
605
+
606
+ loss_value = tot_loss["loss"]
607
+ params.train_loss = loss_value
608
+ if params.train_loss < params.best_train_loss:
609
+ params.best_train_epoch = params.cur_epoch
610
+ params.best_train_loss = params.train_loss
611
+
612
+
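# A standalone sketch (not part of the committed file) of the fp16 grad-scale policy in
# train_one_epoch() above: every 100 batches the scale is doubled while it is below 1024,
# and additionally on 400-batch boundaries while it is below 4096; training aborts once
# the scale collapses below 1e-5.
def should_double_grad_scale(cur_grad_scale: float, batch_idx_train: int) -> bool:
    # Assumes it is only called on batches where batch_idx_train % 100 == 0,
    # exactly as in the loop above.
    return cur_grad_scale < 1024.0 or (
        cur_grad_scale < 4096.0 and batch_idx_train % 400 == 0
    )

assert should_double_grad_scale(512.0, 100)        # small scale: grows on every check
assert not should_double_grad_scale(2048.0, 100)   # mid scale: only on 400-batch boundaries
assert should_double_grad_scale(2048.0, 400)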
613
+ def compute_validation_loss(
614
+ params: AttributeDict,
615
+ model: Union[nn.Module, DDP],
616
+ valid_dl: torch.utils.data.DataLoader,
617
+ world_size: int = 1,
618
+ ) -> MetricsTracker:
619
+ """Run the validation process."""
620
+
621
+ model.eval()
622
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
623
+
624
+ # used to summarize the stats over iterations
625
+ tot_loss = MetricsTracker()
626
+
627
+ for batch_idx, batch in enumerate(valid_dl):
628
+ tokens, features, features_lens = prepare_input(
629
+ params=params,
630
+ batch=batch,
631
+ device=device,
632
+ return_tokens=True,
633
+ return_feature=True,
634
+ )
635
+
636
+ loss, loss_info = compute_fbank_loss(
637
+ params=params,
638
+ model=model,
639
+ features=features,
640
+ features_lens=features_lens,
641
+ tokens=tokens,
642
+ is_training=False,
643
+ use_two_channel=True,
644
+ )
645
+ assert loss.requires_grad is False
646
+ tot_loss = tot_loss + loss_info
647
+
648
+ if world_size > 1:
649
+ tot_loss.reduce(loss.device)
650
+
651
+ loss_value = tot_loss["loss"]
652
+ if loss_value < params.best_valid_loss:
653
+ params.best_valid_epoch = params.cur_epoch
654
+ params.best_valid_loss = loss_value
655
+
656
+ return tot_loss
657
+
658
+
659
+ def scan_pessimistic_batches_for_oom(
660
+ model: Union[nn.Module, DDP],
661
+ train_dl: torch.utils.data.DataLoader,
662
+ optimizer: torch.optim.Optimizer,
663
+ params: AttributeDict,
664
+ ):
665
+ from lhotse.dataset import find_pessimistic_batches
666
+
667
+ logging.info(
668
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
669
+ )
670
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
671
+
672
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
673
+ for criterion, cuts in batches.items():
674
+ batch = train_dl.dataset[cuts]
675
+ tokens, features, features_lens = prepare_input(
676
+ params=params,
677
+ batch=batch,
678
+ device=device,
679
+ return_tokens=True,
680
+ return_feature=True,
681
+ )
682
+ try:
683
+ with torch_autocast(dtype=torch.float16, enabled=params.use_fp16):
684
+
685
+ loss, loss_info = compute_fbank_loss(
686
+ params=params,
687
+ model=model,
688
+ features=features,
689
+ features_lens=features_lens,
690
+ tokens=tokens,
691
+ is_training=True,
692
+ use_two_channel=True,
693
+ )
694
+ loss.backward()
695
+ optimizer.zero_grad()
696
+ except Exception as e:
697
+ if "CUDA out of memory" in str(e):
698
+ logging.error(
699
+ "Your GPU ran out of memory with the current "
700
+ "max_duration setting. We recommend decreasing "
701
+ "max_duration and trying again.\n"
702
+ f"Failing criterion: {criterion} "
703
+ f"(={crit_values[criterion]}) ..."
704
+ )
705
+ display_and_save_batch(batch, params=params)
706
+ raise
707
+ logging.info(
708
+ f"Maximum memory allocated so far is "
709
+ f"{torch.cuda.max_memory_allocated() // 1000000}MB"
710
+ )
711
+
712
+
713
+ def run(rank, world_size, args):
714
+ """
715
+ Args:
716
+ rank:
717
+ It is a value between 0 and `world_size-1`, which is
718
+ passed automatically by `mp.spawn()` in :func:`main`.
719
+ The node with rank 0 is responsible for saving checkpoint.
720
+ world_size:
721
+ Number of GPUs for DDP training.
722
+ args:
723
+ The return value of get_parser().parse_args()
724
+ """
725
+ params = get_params()
726
+ params.update(vars(args))
727
+ params.valid_interval = params.save_every_n
728
+ # Set epoch to a large number to ignore it.
729
+ if params.num_iters > 0:
730
+ params.num_epochs = 1000000
731
+ with open(params.model_config, "r") as f:
732
+ model_config = json.load(f)
733
+ params.update(model_config["model"])
734
+ params.update(model_config["feature"])
735
+
736
+ fix_random_seed(params.seed)
737
+ if world_size > 1:
738
+ setup_dist(rank, world_size, params.master_port)
739
+
740
+ os.makedirs(f"{params.exp_dir}", exist_ok=True)
741
+ copyfile(src=params.model_config, dst=f"{params.exp_dir}/model.json")
742
+ copyfile(src=params.token_file, dst=f"{params.exp_dir}/tokens.txt")
743
+ setup_logger(f"{params.exp_dir}/log/log-train")
744
+
745
+ if args.tensorboard and rank == 0:
746
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
747
+ else:
748
+ tb_writer = None
749
+
750
+ if torch.cuda.is_available():
751
+ params.device = torch.device("cuda", rank)
752
+ else:
753
+ params.device = torch.device("cpu")
754
+ logging.info(f"Device: {params.device}")
755
+
756
+ tokenizer = DialogTokenizer(token_file=params.token_file)
757
+ tokenizer_config = {
758
+ "vocab_size": tokenizer.vocab_size,
759
+ "pad_id": tokenizer.pad_id,
760
+ "spk_a_id": tokenizer.spk_a_id,
761
+ "spk_b_id": tokenizer.spk_b_id,
762
+ }
763
+ params.update(tokenizer_config)
764
+
765
+ logging.info(params)
766
+
767
+ logging.info("About to create model")
768
+
769
+ model = ZipVoiceDialogStereo(
770
+ **model_config["model"],
771
+ **tokenizer_config,
772
+ )
773
+
774
+ assert params.checkpoint is not None
775
+ logging.info(f"Loading pre-trained model from {params.checkpoint}")
776
+
777
+ if params.finetune:
778
+ # load a pre-trained ZipVoice-Dialog-Stereo model
779
+ _ = load_checkpoint(filename=params.checkpoint, model=model, strict=True)
780
+ else:
781
+ # load a pre-trained ZipVoice-Dialog model, duplicate the proj layers
782
+ load_checkpoint_copy_proj_three_channel_alter(
783
+ filename=params.checkpoint,
784
+ in_proj_key="fm_decoder.in_proj",
785
+ out_proj_key="fm_decoder.out_proj",
786
+ dim=params.feat_dim,
787
+ model=model,
788
+ )
789
+ num_param = sum([p.numel() for p in model.parameters()])
790
+ logging.info(f"Number of parameters : {num_param}")
791
+
792
+ model_avg: Optional[nn.Module] = None
793
+ if rank == 0:
794
+ # model_avg is only used with rank 0
795
+ model_avg = copy.deepcopy(model).to(torch.float64)
796
+
797
+ assert params.start_epoch > 0, params.start_epoch
798
+ if params.start_epoch > 1:
799
+ checkpoints = resume_checkpoint(params=params, model=model, model_avg=model_avg)
800
+
801
+ model = model.to(params.device)
802
+ if world_size > 1:
803
+ logging.info("Using DDP")
804
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
805
+
806
+ optimizer = ScaledAdam(
807
+ get_parameter_groups_with_lrs(
808
+ model,
809
+ lr=params.base_lr,
810
+ include_names=True,
811
+ ),
812
+ lr=params.base_lr, # should have no effect
813
+ clipping_scale=2.0,
814
+ )
815
+
816
+ scheduler = FixedLRScheduler(optimizer)
817
+
818
+ scaler = create_grad_scaler(enabled=params.use_fp16)
819
+
820
+ if params.start_epoch > 1 and checkpoints is not None:
821
+ # load state_dict for optimizers
822
+ if "optimizer" in checkpoints:
823
+ logging.info("Loading optimizer state dict")
824
+ optimizer.load_state_dict(checkpoints["optimizer"])
825
+
826
+ # load state_dict for schedulers
827
+ if "scheduler" in checkpoints:
828
+ logging.info("Loading scheduler state dict")
829
+ scheduler.load_state_dict(checkpoints["scheduler"])
830
+
831
+ if "grad_scaler" in checkpoints:
832
+ logging.info("Loading grad scaler state dict")
833
+ scaler.load_state_dict(checkpoints["grad_scaler"])
834
+
835
+ if params.print_diagnostics:
836
+ opts = diagnostics.TensorDiagnosticOptions(
837
+ 512
838
+ ) # allow 4 megabytes per sub-module
839
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
840
+
841
+ if params.inf_check:
842
+ register_inf_check_hooks(model)
843
+
844
+ def remove_short_and_long_utt(c: Cut, min_len: float, max_len: float):
845
+ if c.duration < min_len or c.duration > max_len:
846
+ return False
847
+ return True
848
+
849
+ _remove_short_and_long_utt = partial(
850
+ remove_short_and_long_utt, min_len=params.min_len, max_len=params.max_len
851
+ )
852
+
853
+ datamodule = TtsDataModule(args)
854
+ train_cuts = datamodule.train_custom_cuts(params.train_manifest)
855
+ train_cuts = train_cuts.filter(_remove_short_and_long_utt)
856
+ dev_cuts = datamodule.dev_custom_cuts(params.dev_manifest)
857
+ # To avoid OOM issues due to too long dev cuts
858
+ dev_cuts = dev_cuts.filter(_remove_short_and_long_utt)
859
+
860
+ if not hasattr(train_cuts[0].supervisions[0], "tokens") or not hasattr(
861
+ dev_cuts[0].supervisions[0], "tokens"
862
+ ):
863
+ logging.warning(
864
+ "Tokens are not prepared, will tokenize on-the-fly, "
865
+ "which can slow down training significantly."
866
+ )
867
+ _tokenize_text = partial(tokenize_text, tokenizer=tokenizer)
868
+ train_cuts = train_cuts.map(_tokenize_text)
869
+ dev_cuts = dev_cuts.map(_tokenize_text)
870
+
871
+ train_dl = datamodule.train_dataloaders(train_cuts)
872
+
873
+ valid_dl = datamodule.dev_dataloaders(dev_cuts)
874
+
875
+ if params.scan_oom:
876
+ scan_pessimistic_batches_for_oom(
877
+ model=model,
878
+ train_dl=train_dl,
879
+ optimizer=optimizer,
880
+ params=params,
881
+ )
882
+
883
+ logging.info("Training started")
884
+
885
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
886
+ logging.info(f"Start epoch {epoch}")
887
+ scheduler.step_epoch(epoch - 1)
888
+ fix_random_seed(params.seed + epoch - 1)
889
+ train_dl.sampler.set_epoch(epoch - 1)
890
+
891
+ params.cur_epoch = epoch
892
+
893
+ if tb_writer is not None:
894
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
895
+
896
+ train_one_epoch(
897
+ params=params,
898
+ model=model,
899
+ model_avg=model_avg,
900
+ optimizer=optimizer,
901
+ scheduler=scheduler,
902
+ train_dl=train_dl,
903
+ valid_dl=valid_dl,
904
+ scaler=scaler,
905
+ tb_writer=tb_writer,
906
+ world_size=world_size,
907
+ rank=rank,
908
+ )
909
+
910
+ if params.num_iters > 0 and params.batch_idx_train > params.num_iters:
911
+ break
912
+
913
+ if params.print_diagnostics:
914
+ diagnostic.print_diagnostics()
915
+ break
916
+
917
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
918
+ save_checkpoint(
919
+ filename=filename,
920
+ params=params,
921
+ model=model,
922
+ model_avg=model_avg,
923
+ optimizer=optimizer,
924
+ scheduler=scheduler,
925
+ sampler=train_dl.sampler,
926
+ scaler=scaler,
927
+ rank=rank,
928
+ )
929
+
930
+ if rank == 0:
931
+ if params.best_train_epoch == params.cur_epoch:
932
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
933
+ copyfile(src=filename, dst=best_train_filename)
934
+
935
+ if params.best_valid_epoch == params.cur_epoch:
936
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
937
+ copyfile(src=filename, dst=best_valid_filename)
938
+
939
+ logging.info("Done!")
940
+
941
+ if world_size > 1:
942
+ torch.distributed.barrier()
943
+ cleanup_dist()
944
+
945
+
946
+ def main():
947
+ parser = get_parser()
948
+ TtsDataModule.add_arguments(parser)
949
+ args = parser.parse_args()
950
+ args.exp_dir = Path(args.exp_dir)
951
+
952
+ world_size = args.world_size
953
+ assert world_size >= 1
954
+ if world_size > 1:
955
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
956
+ else:
957
+ run(rank=0, world_size=1, args=args)
958
+
959
+
960
+ if __name__ == "__main__":
961
+ torch.set_num_threads(1)
962
+ torch.set_num_interop_threads(1)
963
+ main()
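As a usage note for the script above: `run()` copies `model.json` and `tokens.txt` into `exp-dir`, so a saved `epoch-*.pt` checkpoint can be reloaded with the same construction pattern it uses. A minimal sketch under that assumption (the experiment directory and epoch number below are placeholders):

import json
from zipvoice.models.zipvoice_dialog import ZipVoiceDialogStereo
from zipvoice.tokenizer.tokenizer import DialogTokenizer
from zipvoice.utils.checkpoint import load_checkpoint

exp_dir = "exp/zipvoice_dialog_stereo"   # placeholder path
with open(f"{exp_dir}/model.json") as f:
    model_config = json.load(f)
tokenizer = DialogTokenizer(token_file=f"{exp_dir}/tokens.txt")
model = ZipVoiceDialogStereo(
    **model_config["model"],
    vocab_size=tokenizer.vocab_size,
    pad_id=tokenizer.pad_id,
    spk_a_id=tokenizer.spk_a_id,
    spk_b_id=tokenizer.spk_b_id,
)
load_checkpoint(filename=f"{exp_dir}/epoch-8.pt", model=model, strict=True)
model.eval()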
zipvoice/bin/train_zipvoice_distill.py ADDED
@@ -0,0 +1,1158 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2024 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+
19
+ """
20
+ This script trains a ZipVoice-Distill model starting from a ZipVoice model.
21
+ It has two distillation stages.
22
+
23
+ Usage:
24
+
25
+ (1) The first distillation stage with a fixed ZipVoice model as the teacher.
26
+
27
+ python3 -m zipvoice.bin.train_zipvoice_distill \
28
+ --world-size 8 \
29
+ --use-fp16 1 \
30
+ --num-iters 60000 \
31
+ --max-duration 500 \
32
+ --base-lr 0.0005 \
33
+ --tokenizer emilia \
34
+ --token-file data/tokens_emilia.txt \
35
+ --dataset emilia \
36
+ --manifest-dir data/fbank \
37
+ --teacher-model exp/zipvoice/epoch-11-avg-4.pt \
38
+ --distill-stage first \
39
+ --exp-dir exp/zipvoice_distill_1stage
40
+
41
+ (2) The second distillation stage with an EMA model as the teacher.
42
+ python3 -m zipvoice.bin.train_zipvoice_distill \
43
+ --world-size 8 \
44
+ --use-fp16 1 \
45
+ --num-iters 2000 \
46
+ --save-every-n 1000 \
47
+ --max-duration 500 \
48
+ --base-lr 0.0001 \
49
+ --model-config conf/zipvoice_base.json \
50
+ --tokenizer emilia \
51
+ --token-file data/tokens_emilia.txt \
52
+ --dataset emilia \
53
+ --manifest-dir data/fbank \
54
+ --teacher-model exp/zipvoice_distill_1stage/iter-60000-avg-7.pt \
55
+ --distill-stage second \
56
+ --exp-dir exp/zipvoice_distill
57
+ """
58
+
59
+ import argparse
60
+ import copy
61
+ import json
62
+ import logging
63
+ import os
64
+ import random
65
+ from functools import partial
66
+ from pathlib import Path
67
+ from shutil import copyfile
68
+ from typing import List, Optional, Tuple, Union
69
+
70
+ import torch
71
+ import torch.multiprocessing as mp
72
+ import torch.nn as nn
73
+ from lhotse.cut import Cut, CutSet
74
+ from lhotse.utils import fix_random_seed
75
+ from torch import Tensor
76
+ from torch.nn.parallel import DistributedDataParallel as DDP
77
+ from torch.optim import Optimizer
78
+ from torch.utils.tensorboard import SummaryWriter
79
+
80
+ import zipvoice.utils.diagnostics as diagnostics
81
+ from zipvoice.bin.train_zipvoice import (
82
+ display_and_save_batch,
83
+ get_params,
84
+ tokenize_text,
85
+ )
86
+ from zipvoice.dataset.datamodule import TtsDataModule
87
+ from zipvoice.models.zipvoice import ZipVoice
88
+ from zipvoice.models.zipvoice_distill import ZipVoiceDistill
89
+ from zipvoice.tokenizer.tokenizer import (
90
+ EmiliaTokenizer,
91
+ EspeakTokenizer,
92
+ LibriTTSTokenizer,
93
+ SimpleTokenizer,
94
+ )
95
+ from zipvoice.utils.checkpoint import (
96
+ load_checkpoint,
97
+ remove_checkpoints,
98
+ resume_checkpoint,
99
+ save_checkpoint,
100
+ save_checkpoint_with_global_batch_idx,
101
+ update_averaged_model,
102
+ )
103
+ from zipvoice.utils.common import (
104
+ AttributeDict,
105
+ GradScaler,
106
+ MetricsTracker,
107
+ cleanup_dist,
108
+ condition_time_mask,
109
+ create_grad_scaler,
110
+ get_adjusted_batch_count,
111
+ get_parameter_groups_with_lrs,
112
+ make_pad_mask,
113
+ prepare_input,
114
+ set_batch_count,
115
+ setup_dist,
116
+ setup_logger,
117
+ str2bool,
118
+ torch_autocast,
119
+ )
120
+ from zipvoice.utils.hooks import register_inf_check_hooks
121
+ from zipvoice.utils.lr_scheduler import FixedLRScheduler, LRScheduler
122
+ from zipvoice.utils.optim import ScaledAdam
123
+
124
+ LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, LRScheduler]
125
+
126
+
127
+ def get_parser():
128
+ parser = argparse.ArgumentParser(
129
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
130
+ )
131
+
132
+ parser.add_argument(
133
+ "--world-size",
134
+ type=int,
135
+ default=1,
136
+ help="Number of GPUs for DDP training.",
137
+ )
138
+
139
+ parser.add_argument(
140
+ "--master-port",
141
+ type=int,
142
+ default=12356,
143
+ help="Master port to use for DDP training.",
144
+ )
145
+
146
+ parser.add_argument(
147
+ "--tensorboard",
148
+ type=str2bool,
149
+ default=True,
150
+ help="Should various information be logged in tensorboard.",
151
+ )
152
+
153
+ parser.add_argument(
154
+ "--num-epochs",
155
+ type=int,
156
+ default=1,
157
+ help="Number of epochs to train.",
158
+ )
159
+
160
+ parser.add_argument(
161
+ "--num-iters",
162
+ type=int,
163
+ default=0,
164
+ help="Number of iterations to train; will ignore num_epochs if > 0.",
165
+ )
166
+
167
+ parser.add_argument(
168
+ "--start-epoch",
169
+ type=int,
170
+ default=1,
171
+ help="""Resume training from this epoch. It should be positive.
172
+ If larger than 1, it will load checkpoint from
173
+ exp-dir/epoch-{start_epoch-1}.pt
174
+ """,
175
+ )
176
+
177
+ parser.add_argument(
178
+ "--teacher-model",
179
+ type=str,
180
+ help="""Checkpoint of the pre-trained teacher model""",
181
+ )
182
+
183
+ parser.add_argument(
184
+ "--exp-dir",
185
+ type=str,
186
+ default="exp/zipvoice_distill",
187
+ help="""The experiment dir.
188
+ It specifies the directory where all training related
189
+ files, e.g., checkpoints, log, etc, are saved
190
+ """,
191
+ )
192
+
193
+ parser.add_argument(
194
+ "--base-lr", type=float, default=0.001, help="The base learning rate."
195
+ )
196
+
197
+ parser.add_argument(
198
+ "--ref-duration",
199
+ type=float,
200
+ default=50,
201
+ help="Reference batch duration for purposes of adjusting batch counts for "
202
+ "setting various schedules inside the model",
203
+ )
204
+
205
+ parser.add_argument(
206
+ "--seed",
207
+ type=int,
208
+ default=42,
209
+ help="The seed for random generators intended for reproducibility",
210
+ )
211
+
212
+ parser.add_argument(
213
+ "--print-diagnostics",
214
+ type=str2bool,
215
+ default=False,
216
+ help="Accumulate stats on activations, print them and exit.",
217
+ )
218
+
219
+ parser.add_argument(
220
+ "--scan-oom",
221
+ type=str2bool,
222
+ default=False,
223
+ help="Scan pessimistic batches to see whether they cause OOMs.",
224
+ )
225
+
226
+ parser.add_argument(
227
+ "--inf-check",
228
+ type=str2bool,
229
+ default=False,
230
+ help="Add hooks to check for infinite module outputs and gradients.",
231
+ )
232
+
233
+ parser.add_argument(
234
+ "--save-every-n",
235
+ type=int,
236
+ default=1000,
237
+ help="""Save checkpoint after processing this number of batches
238
+ periodically. We save checkpoint to exp-dir/ whenever
239
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
240
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
241
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
242
+ end of each epoch where `xxx` is the epoch number counting from 1.
243
+ """,
244
+ )
245
+
246
+ parser.add_argument(
247
+ "--keep-last-k",
248
+ type=int,
249
+ default=30,
250
+ help="""Only keep this number of checkpoints on disk.
251
+ For instance, if it is 3, there are only 3 checkpoints
252
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
253
+ It does not affect checkpoints with name `epoch-xxx.pt`.
254
+ """,
255
+ )
256
+
257
+ parser.add_argument(
258
+ "--average-period",
259
+ type=int,
260
+ default=200,
261
+ help="""Update the averaged model, namely `model_avg`, after processing
262
+ this number of batches. `model_avg` is a separate version of model,
263
+ in which each floating-point parameter is the average of all the
264
+ parameters from the start of training. Each time we take the average,
265
+ we do: `model_avg = model * (average_period / batch_idx_train) +
266
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
267
+ """,
268
+ )
269
+
270
+ parser.add_argument(
271
+ "--use-fp16",
272
+ type=str2bool,
273
+ default=True,
274
+ help="Whether to use half precision training.",
275
+ )
276
+
277
+ parser.add_argument(
278
+ "--feat-scale",
279
+ type=float,
280
+ default=0.1,
281
+ help="The scale factor of fbank feature",
282
+ )
283
+
284
+ parser.add_argument(
285
+ "--ema-decay",
286
+ type=float,
287
+ default=0.9999,
288
+ help="The EMA decay factor of target model in distillation.",
289
+ )
290
+ parser.add_argument(
291
+ "--distill-stage",
292
+ type=str,
293
+ choices=["first", "second"],
294
+ help="The stage of distillation.",
295
+ )
296
+
297
+ parser.add_argument(
298
+ "--dataset",
299
+ type=str,
300
+ default="emilia",
301
+ choices=["emilia", "libritts", "custom"],
302
+ help="The used training dataset",
303
+ )
304
+
305
+ parser.add_argument(
306
+ "--train-manifest",
307
+ type=str,
308
+ help="Path of the training manifest",
309
+ )
310
+
311
+ parser.add_argument(
312
+ "--dev-manifest",
313
+ type=str,
314
+ help="Path of the validation manifest",
315
+ )
316
+
317
+ parser.add_argument(
318
+ "--min-len",
319
+ type=float,
320
+ default=1.0,
321
+ help="The minimum audio length used for training",
322
+ )
323
+
324
+ parser.add_argument(
325
+ "--max-len",
326
+ type=float,
327
+ default=30.0,
328
+ help="The maximum audio length used for training",
329
+ )
330
+
331
+ parser.add_argument(
332
+ "--model-config",
333
+ type=str,
334
+ default="conf/zipvoice_base.json",
335
+ help="The model configuration file.",
336
+ )
337
+
338
+ parser.add_argument(
339
+ "--tokenizer",
340
+ type=str,
341
+ default="emilia",
342
+ choices=["emilia", "libritts", "espeak", "simple"],
343
+ help="Tokenizer type.",
344
+ )
345
+
346
+ parser.add_argument(
347
+ "--lang",
348
+ type=str,
349
+ default="en-us",
350
+ help="Language identifier, used when tokenizer type is espeak. see"
351
+ "https://github.com/rhasspy/espeak-ng/blob/master/docs/languages.md",
352
+ )
353
+
354
+ parser.add_argument(
355
+ "--token-file",
356
+ type=str,
357
+ default="data/tokens_emilia.txt",
358
+ help="The file that contains information that maps tokens to ids, "
359
+ "which is a text file with '{token}\t{token_id}' per line.",
360
+ )
361
+
362
+ return parser
363
+
364
+
365
+ def ema(new_model, ema_model, decay):
366
+ if isinstance(new_model, DDP):
367
+ new_model = new_model.module
368
+ if isinstance(ema_model, DDP):
369
+ ema_model = ema_model.module
370
+ new_model_dict = new_model.state_dict()
371
+ ema_model_dict = ema_model.state_dict()
372
+ for key in new_model_dict.keys():
373
+ ema_model_dict[key].data.copy_(
374
+ ema_model_dict[key].data * decay + new_model_dict[key].data * (1 - decay)
375
+ )
376
+
377
+
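# A quick standalone check (not part of the committed file) of the ema() update above:
# with the default decay of 0.9999 each call moves the teacher weights only 0.01% of the
# way toward the student, which is what makes the stage-two teacher a slowly tracking EMA.
import torch
teacher_w = torch.tensor([1.0])
student_w = torch.tensor([0.0])
decay = 0.9999
teacher_w = teacher_w * decay + student_w * (1 - decay)
assert torch.isclose(teacher_w, torch.tensor([0.9999]))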
378
+ def compute_fbank_loss(
379
+ params: AttributeDict,
380
+ model: Union[nn.Module, DDP],
381
+ teacher_model: Union[nn.Module, DDP],
382
+ features: Tensor,
383
+ features_lens: Tensor,
384
+ tokens: List[List[int]],
385
+ is_training: bool,
386
+ ) -> Tuple[Tensor, MetricsTracker]:
387
+ """
388
+ Compute loss given the model and its inputs.
389
+
390
+ Args:
391
+ params:
392
+ Parameters for training. See :func:`get_params`.
393
+ model:
394
+ The model for training.
395
+ teacher_model:
396
+ The teacher model for distillation.
397
+ features:
398
+ The target acoustic feature.
399
+ features_lens:
400
+ The number of frames of each utterance.
401
+ tokens:
402
+ Input tokens representing the transcripts.
403
+ is_training:
404
+ True for training. False for validation. When it is True, this
405
+ function enables autograd during computation; when it is False, it
406
+ disables autograd.
407
+ """
408
+
409
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
410
+
411
+ batch_size, num_frames, _ = features.shape
412
+
413
+ noise = torch.randn_like(features) # (B, T, F)
414
+
415
+ # Sampling t and guidance_scale from uniform distribution
416
+
417
+ t_value = random.random()
418
+ t = torch.ones(batch_size, 1, 1, device=device) * t_value
419
+ if params.distill_stage == "first":
420
+ guidance_scale = torch.rand(batch_size, 1, 1, device=device) * 2
421
+ else:
422
+ guidance_scale = torch.rand(batch_size, 1, 1, device=device) * 2 + 1
423
+ xt = features * t + noise * (1 - t)
424
+ t_delta_fix = random.uniform(0.0, min(0.3, 1 - t_value))
425
+ t_delta_ema = random.uniform(0.0, min(0.3, 1 - t_value - t_delta_fix))
426
+ t_dest = t_value + t_delta_fix + t_delta_ema
427
+
428
+ with torch.no_grad():
429
+ speech_condition_mask = condition_time_mask(
430
+ features_lens=features_lens,
431
+ mask_percent=(0.7, 1.0),
432
+ max_len=features.size(1),
433
+ )
434
+
435
+ if params.distill_stage == "first":
436
+ teacher_x_t_mid, _ = teacher_model.sample_intermediate(
437
+ tokens=tokens,
438
+ features=features,
439
+ features_lens=features_lens,
440
+ noise=xt,
441
+ speech_condition_mask=speech_condition_mask,
442
+ t_start=t_value,
443
+ t_end=t_value + t_delta_fix,
444
+ num_step=1,
445
+ guidance_scale=guidance_scale,
446
+ )
447
+
448
+ target_x1, _ = teacher_model.sample_intermediate(
449
+ tokens=tokens,
450
+ features=features,
451
+ features_lens=features_lens,
452
+ noise=teacher_x_t_mid,
453
+ speech_condition_mask=speech_condition_mask,
454
+ t_start=t_value + t_delta_fix,
455
+ t_end=t_dest,
456
+ num_step=1,
457
+ guidance_scale=guidance_scale,
458
+ )
459
+ else:
460
+ teacher_x_t_mid, _ = teacher_model(
461
+ tokens=tokens,
462
+ features=features,
463
+ features_lens=features_lens,
464
+ noise=xt,
465
+ speech_condition_mask=speech_condition_mask,
466
+ t_start=t_value,
467
+ t_end=t_value + t_delta_fix,
468
+ num_step=1,
469
+ guidance_scale=guidance_scale,
470
+ )
471
+
472
+ target_x1, _ = teacher_model(
473
+ tokens=tokens,
474
+ features=features,
475
+ features_lens=features_lens,
476
+ noise=teacher_x_t_mid,
477
+ speech_condition_mask=speech_condition_mask,
478
+ t_start=t_value + t_delta_fix,
479
+ t_end=t_dest,
480
+ num_step=1,
481
+ guidance_scale=guidance_scale,
482
+ )
483
+
484
+ with torch.set_grad_enabled(is_training):
485
+
486
+ pred_x1, _ = model(
487
+ tokens=tokens,
488
+ features=features,
489
+ features_lens=features_lens,
490
+ noise=xt,
491
+ speech_condition_mask=speech_condition_mask,
492
+ t_start=t_value,
493
+ t_end=t_dest,
494
+ num_step=1,
495
+ guidance_scale=guidance_scale,
496
+ )
497
+ pred_v = (pred_x1 - xt) / (t_dest - t)
498
+
499
+ padding_mask = make_pad_mask(features_lens, max_len=num_frames) # (B, T)
500
+ loss_mask = speech_condition_mask & (~padding_mask)
501
+
502
+ target_v = (target_x1 - xt) / (t_dest - t)
503
+ loss = torch.mean((pred_v[loss_mask] - target_v[loss_mask]) ** 2)
504
+
505
+ ut = features - noise # (B, T, F)
506
+
507
+ ref_loss = torch.mean((pred_v[loss_mask] - ut[loss_mask]) ** 2)
508
+
509
+ assert loss.requires_grad == is_training
510
+ info = MetricsTracker()
511
+ num_frames = features_lens.sum().item()
512
+ info["frames"] = num_frames
513
+ info["loss"] = loss.detach().cpu().item() * num_frames
514
+ info["ref_loss"] = ref_loss.detach().cpu().item() * num_frames
515
+ return loss, info
516
+
517
+
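# A small consistency check (not from the committed file) for the distillation target in
# compute_fbank_loss() above: with xt = t * x1 + (1 - t) * x0, the target velocity
# (target_x1 - xt) / (t_dest - t) reduces to the flow-matching velocity ut = x1 - x0
# whenever the teacher prediction is exact (target_x1 == x1) and t_dest == 1.
import torch
x1 = torch.randn(2, 5, 8)    # stands in for `features`
x0 = torch.randn(2, 5, 8)    # stands in for `noise`
t, t_dest = 0.3, 1.0
xt = x1 * t + x0 * (1 - t)
target_v = (x1 - xt) / (t_dest - t)
assert torch.allclose(target_v, x1 - x0, atol=1e-5)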
518
+ def train_one_epoch(
519
+ params: AttributeDict,
520
+ model: Union[nn.Module, DDP],
521
+ teacher_model: Union[nn.Module, DDP],
522
+ optimizer: Optimizer,
523
+ scheduler: LRSchedulerType,
524
+ train_dl: torch.utils.data.DataLoader,
525
+ valid_dl: torch.utils.data.DataLoader,
526
+ scaler: GradScaler,
527
+ model_avg: Optional[nn.Module] = None,
528
+ tb_writer: Optional[SummaryWriter] = None,
529
+ world_size: int = 1,
530
+ rank: int = 0,
531
+ ) -> None:
532
+ """Train the model for one epoch.
533
+
534
+ The training loss from the mean of all frames is saved in
535
+ `params.train_loss`. It runs the validation process every
536
+ `params.valid_interval` batches.
537
+
538
+ Args:
539
+ params:
540
+ It is returned by :func:`get_params`.
541
+ model:
542
+ The model for training.
543
+ teacher_model:
544
+ The teacher model for distillation: a fixed ZipVoice model in the
545
+ first stage, or the EMA model in the second stage.
546
+ optimizer:
547
+ The optimizer.
548
+ scheduler:
549
+ The learning rate scheduler, we call step() every epoch.
550
+ train_dl:
551
+ Dataloader for the training dataset.
552
+ valid_dl:
553
+ Dataloader for the validation dataset.
554
+ scaler:
555
+ The scaler used for mixed precision training.
556
+ tb_writer:
557
+ Writer to write log messages to tensorboard.
558
+ world_size:
559
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
560
+ rank:
561
+ The rank of the node in DDP training. If no DDP is used, it should
562
+ be set to 0.
563
+ """
564
+ model.train()
565
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
566
+
567
+ # used to track the stats over iterations in one epoch
568
+ tot_loss = MetricsTracker()
569
+
570
+ saved_bad_model = False
571
+
572
+ def save_bad_model(suffix: str = ""):
573
+ save_checkpoint(
574
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
575
+ model=model,
576
+ model_avg=model_avg,
577
+ model_ema=teacher_model,
578
+ params=params,
579
+ optimizer=optimizer,
580
+ scheduler=scheduler,
581
+ sampler=train_dl.sampler,
582
+ scaler=scaler,
583
+ rank=0,
584
+ )
585
+
586
+ for batch_idx, batch in enumerate(train_dl):
587
+
588
+ if batch_idx % 10 == 0:
589
+ set_batch_count(model, get_adjusted_batch_count(params) + 100000)
590
+
591
+ if (
592
+ params.batch_idx_train % params.valid_interval == 0
593
+ and not params.print_diagnostics
594
+ ):
595
+ logging.info("Computing validation loss")
596
+ valid_info = compute_validation_loss(
597
+ params=params,
598
+ model=model,
599
+ teacher_model=teacher_model,
600
+ valid_dl=valid_dl,
601
+ world_size=world_size,
602
+ )
603
+ model.train()
604
+ logging.info(
605
+ f"Epoch {params.cur_epoch}, global_batch_idx: {params.batch_idx_train},"
606
+ f" validation: {valid_info}"
607
+ )
608
+ logging.info(
609
+ f"Maximum memory allocated so far is "
610
+ f"{torch.cuda.max_memory_allocated() // 1000000}MB"
611
+ )
612
+ if tb_writer is not None:
613
+ valid_info.write_summary(
614
+ tb_writer, "train/valid_", params.batch_idx_train
615
+ )
616
+
617
+ params.batch_idx_train += 1
618
+
619
+ batch_size = len(batch["text"])
620
+
621
+ tokens, features, features_lens = prepare_input(
622
+ params=params,
623
+ batch=batch,
624
+ device=device,
625
+ return_tokens=True,
626
+ return_feature=True,
627
+ )
628
+
629
+ try:
630
+ with torch_autocast(dtype=torch.float16, enabled=params.use_fp16):
631
+ loss, loss_info = compute_fbank_loss(
632
+ params=params,
633
+ model=model,
634
+ teacher_model=teacher_model,
635
+ features=features,
636
+ features_lens=features_lens,
637
+ tokens=tokens,
638
+ is_training=True,
639
+ )
640
+
641
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
642
+
643
+ scaler.scale(loss).backward()
644
+
645
+ scheduler.step_batch(params.batch_idx_train)
646
+ scaler.step(optimizer)
647
+ scaler.update()
648
+ optimizer.zero_grad()
649
+ if params.distill_stage == "second":
650
+ ema(model, teacher_model, params.ema_decay)
651
+ except Exception as e:
652
+ logging.info(f"Caught exception : {e}.")
653
+ save_bad_model()
654
+ raise
655
+
656
+ if params.print_diagnostics and batch_idx == 5:
657
+ return
658
+
659
+ if (
660
+ rank == 0
661
+ and params.batch_idx_train > 0
662
+ and params.batch_idx_train % params.average_period == 0
663
+ ):
664
+ update_averaged_model(
665
+ params=params,
666
+ model_cur=model,
667
+ model_avg=model_avg,
668
+ )
669
+
670
+ if (
671
+ params.batch_idx_train > 0
672
+ and params.batch_idx_train % params.save_every_n == 0
673
+ ):
674
+ save_checkpoint_with_global_batch_idx(
675
+ out_dir=params.exp_dir,
676
+ global_batch_idx=params.batch_idx_train,
677
+ model=model,
678
+ model_avg=model_avg,
679
+ params=params,
680
+ optimizer=optimizer,
681
+ scheduler=scheduler,
682
+ sampler=train_dl.sampler,
683
+ scaler=scaler,
684
+ rank=rank,
685
+ )
686
+ remove_checkpoints(
687
+ out_dir=params.exp_dir,
688
+ topk=params.keep_last_k,
689
+ rank=rank,
690
+ )
691
+ if params.num_iters > 0 and params.batch_idx_train > params.num_iters:
692
+ break
693
+ if params.batch_idx_train % 100 == 0 and params.use_fp16:
694
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
695
+ # of the grad scaler is configurable, but we can't configure it to have
696
+ # different behavior depending on the current grad scale.
697
+ cur_grad_scale = scaler._scale.item()
698
+
699
+ if cur_grad_scale < 1024.0 or (
700
+ cur_grad_scale < 4096.0 and params.batch_idx_train % 400 == 0
701
+ ):
702
+ scaler.update(cur_grad_scale * 2.0)
703
+ if cur_grad_scale < 0.01:
704
+ if not saved_bad_model:
705
+ save_bad_model(suffix="-first-warning")
706
+ saved_bad_model = True
707
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
708
+ if cur_grad_scale < 1.0e-05:
709
+ save_bad_model()
710
+ raise RuntimeError(
711
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
712
+ )
713
+
714
+ if params.batch_idx_train % params.log_interval == 0:
715
+ cur_lr = max(scheduler.get_last_lr())
716
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
717
+
718
+ logging.info(
719
+ f"Epoch {params.cur_epoch}, batch {batch_idx}, "
720
+ f"global_batch_idx: {params.batch_idx_train}, "
721
+ f"batch size: {batch_size}, "
722
+ f"loss[{loss_info}], tot_loss[{tot_loss}], "
723
+ f"cur_lr: {cur_lr:.2e}, "
724
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
725
+ )
726
+
727
+ if tb_writer is not None:
728
+ tb_writer.add_scalar(
729
+ "train/learning_rate", cur_lr, params.batch_idx_train
730
+ )
731
+ loss_info.write_summary(
732
+ tb_writer, "train/current_", params.batch_idx_train
733
+ )
734
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
735
+ if params.use_fp16:
736
+ tb_writer.add_scalar(
737
+ "train/grad_scale",
738
+ cur_grad_scale,
739
+ params.batch_idx_train,
740
+ )
741
+
742
+ loss_value = tot_loss["loss"]
743
+ params.train_loss = loss_value
744
+ if params.train_loss < params.best_train_loss:
745
+ params.best_train_epoch = params.cur_epoch
746
+ params.best_train_loss = params.train_loss
747
+
748
+
749
+ def compute_validation_loss(
750
+ params: AttributeDict,
751
+ model: Union[nn.Module, DDP],
752
+ teacher_model: Optional[nn.Module],
753
+ valid_dl: torch.utils.data.DataLoader,
754
+ world_size: int = 1,
755
+ ) -> MetricsTracker:
756
+ """Run the validation process."""
757
+
758
+ model.eval()
759
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
760
+
761
+ # used to summarize the stats over iterations
762
+ tot_loss = MetricsTracker()
763
+
764
+ for batch_idx, batch in enumerate(valid_dl):
765
+ tokens, features, features_lens = prepare_input(
766
+ params=params,
767
+ batch=batch,
768
+ device=device,
769
+ return_tokens=True,
770
+ return_feature=True,
771
+ )
772
+
773
+ loss, loss_info = compute_fbank_loss(
774
+ params=params,
775
+ model=model,
776
+ teacher_model=teacher_model,
777
+ features=features,
778
+ features_lens=features_lens,
779
+ tokens=tokens,
780
+ is_training=False,
781
+ )
782
+ assert loss.requires_grad is False
783
+ tot_loss = tot_loss + loss_info
784
+
785
+ if world_size > 1:
786
+ tot_loss.reduce(loss.device)
787
+
788
+ loss_value = tot_loss["loss"]
789
+ if loss_value < params.best_valid_loss:
790
+ params.best_valid_epoch = params.cur_epoch
791
+ params.best_valid_loss = loss_value
792
+
793
+ return tot_loss
794
+
795
+
796
+ def scan_pessimistic_batches_for_oom(
797
+ model: Union[nn.Module, DDP],
798
+ teacher_model: nn.Module,
799
+ train_dl: torch.utils.data.DataLoader,
800
+ optimizer: torch.optim.Optimizer,
801
+ params: AttributeDict,
802
+ ):
803
+ from lhotse.dataset import find_pessimistic_batches
804
+
805
+ logging.info(
806
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
807
+ )
808
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
809
+
810
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
811
+ for criterion, cuts in batches.items():
812
+ batch = train_dl.dataset[cuts]
813
+ tokens, features, features_lens = prepare_input(
814
+ params=params,
815
+ batch=batch,
816
+ device=device,
817
+ return_tokens=True,
818
+ return_feature=True,
819
+ )
820
+ try:
821
+ with torch_autocast(dtype=torch.float16, enabled=params.use_fp16):
822
+
823
+ loss, loss_info = compute_fbank_loss(
824
+ params=params,
825
+ model=model,
826
+ teacher_model=teacher_model,
827
+ features=features,
828
+ features_lens=features_lens,
829
+ tokens=tokens,
830
+ is_training=True,
831
+ )
832
+ loss.backward()
833
+ optimizer.zero_grad()
834
+ except Exception as e:
835
+ if "CUDA out of memory" in str(e):
836
+ logging.error(
837
+ "Your GPU ran out of memory with the current "
838
+ "max_duration setting. We recommend decreasing "
839
+ "max_duration and trying again.\n"
840
+ f"Failing criterion: {criterion} "
841
+ f"(={crit_values[criterion]}) ..."
842
+ )
843
+ display_and_save_batch(batch, params=params)
844
+ raise
845
+ logging.info(
846
+ f"Maximum memory allocated so far is "
847
+ f"{torch.cuda.max_memory_allocated() // 1000000}MB"
848
+ )
849
+
850
+
851
+ def run(rank, world_size, args):
852
+ """
853
+ Args:
854
+ rank:
855
+ It is a value between 0 and `world_size-1`, which is
856
+ passed automatically by `mp.spawn()` in :func:`main`.
857
+ The node with rank 0 is responsible for saving checkpoint.
858
+ world_size:
859
+ Number of GPUs for DDP training.
860
+ args:
861
+ The return value of get_parser().parse_args()
862
+ """
863
+ params = get_params()
864
+ params.update(vars(args))
865
+ params.valid_interval = params.save_every_n
866
+ # Set epoch to a large number to ignore it.
867
+ if params.num_iters > 0:
868
+ params.num_epochs = 1000000
869
+ with open(params.model_config, "r") as f:
870
+ model_config = json.load(f)
871
+ params.update(model_config["model"])
872
+ params.update(model_config["feature"])
873
+
874
+ fix_random_seed(params.seed)
875
+ if world_size > 1:
876
+ setup_dist(rank, world_size, params.master_port)
877
+
878
+ os.makedirs(f"{params.exp_dir}", exist_ok=True)
879
+ copyfile(src=params.model_config, dst=f"{params.exp_dir}/model.json")
880
+ copyfile(src=params.token_file, dst=f"{params.exp_dir}/tokens.txt")
881
+ setup_logger(f"{params.exp_dir}/log/log-train")
882
+
883
+ if args.tensorboard and rank == 0:
884
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
885
+ else:
886
+ tb_writer = None
887
+
888
+ if torch.cuda.is_available():
889
+ params.device = torch.device("cuda", rank)
890
+ else:
891
+ params.device = torch.device("cpu")
892
+ logging.info(f"Device: {params.device}")
893
+
894
+ if params.tokenizer == "emilia":
895
+ tokenizer = EmiliaTokenizer(token_file=params.token_file)
896
+ elif params.tokenizer == "libritts":
897
+ tokenizer = LibriTTSTokenizer(token_file=params.token_file)
898
+ elif params.tokenizer == "espeak":
899
+ tokenizer = EspeakTokenizer(token_file=params.token_file, lang=params.lang)
900
+ else:
901
+ assert params.tokenizer == "simple"
902
+ tokenizer = SimpleTokenizer(token_file=params.token_file)
903
+
904
+ tokenizer_config = {"vocab_size": tokenizer.vocab_size, "pad_id": tokenizer.pad_id}
905
+ params.update(tokenizer_config)
906
+
907
+ logging.info(params)
908
+
909
+ logging.info("About to create model")
910
+
911
+ assert params.teacher_model is not None
912
+ logging.info(f"Loading pre-trained model from {params.teacher_model}")
913
+ model = ZipVoiceDistill(
914
+ **model_config["model"],
915
+ **tokenizer_config,
916
+ )
917
+ _ = load_checkpoint(
918
+ filename=params.teacher_model,
919
+ model=model,
920
+ strict=(params.distill_stage == "second"),
921
+ )
922
+
923
+ if params.distill_stage == "first":
924
+ teacher_model = ZipVoice(
925
+ **model_config["model"],
926
+ **tokenizer_config,
927
+ )
928
+ _ = load_checkpoint(
929
+ filename=params.teacher_model, model=teacher_model, strict=True
930
+ )
931
+ else:
932
+ teacher_model = copy.deepcopy(model)
933
+
934
+ num_param = sum([p.numel() for p in model.parameters()])
935
+ logging.info(f"Number of parameters : {num_param}")
936
+
937
+ model_avg: Optional[nn.Module] = None
938
+ if rank == 0:
939
+ # model_avg is only used with rank 0
940
+ model_avg = copy.deepcopy(model).to(torch.float64)
941
+ assert params.start_epoch > 0, params.start_epoch
942
+ if params.start_epoch > 1:
943
+ logging.info(f"Resuming from epoch {params.start_epoch}")
944
+ if params.distill_stage == "first":
945
+ checkpoints = resume_checkpoint(
946
+ params=params, model=model, model_avg=model_avg
947
+ )
948
+ else:
949
+ checkpoints = resume_checkpoint(
950
+ params=params,
951
+ model=model,
952
+ model_avg=model_avg,
953
+ model_ema=teacher_model,
954
+ )
955
+
956
+ model = model.to(params.device)
957
+ teacher_model.to(params.device)
958
+ teacher_model.eval()
959
+
960
+ if world_size > 1:
961
+ logging.info("Using DDP")
962
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
963
+
964
+ # only update the fm_decoder
965
+ num_trainable = 0
966
+ for name, p in model.named_parameters():
967
+ if "fm_decoder" in name:
968
+ p.requires_grad = True
969
+ num_trainable += p.numel()
970
+ else:
971
+ p.requires_grad = False
972
+
973
+ logging.info(
974
+ "A total of {} trainable parameters ({:.3f}% of the whole model)".format(
975
+ num_trainable, num_trainable / num_param * 100
976
+ )
977
+ )
978
+
979
+ optimizer = ScaledAdam(
980
+ get_parameter_groups_with_lrs(
981
+ model,
982
+ lr=params.base_lr,
983
+ include_names=True,
984
+ ),
985
+ lr=params.base_lr, # should have no effect
986
+ clipping_scale=2.0,
987
+ )
988
+
989
+ scheduler = FixedLRScheduler(optimizer)
990
+
991
+ scaler = create_grad_scaler(enabled=params.use_fp16)
992
+
993
+ if params.start_epoch > 1 and checkpoints is not None:
994
+ # load state_dict for optimizers
995
+ if "optimizer" in checkpoints:
996
+ logging.info("Loading optimizer state dict")
997
+ optimizer.load_state_dict(checkpoints["optimizer"])
998
+
999
+ # load state_dict for schedulers
1000
+ if "scheduler" in checkpoints:
1001
+ logging.info("Loading scheduler state dict")
1002
+ scheduler.load_state_dict(checkpoints["scheduler"])
1003
+
1004
+ if "grad_scaler" in checkpoints:
1005
+ logging.info("Loading grad scaler state dict")
1006
+ scaler.load_state_dict(checkpoints["grad_scaler"])
1007
+
1008
+ if params.print_diagnostics:
1009
+ opts = diagnostics.TensorDiagnosticOptions(
1010
+ 512
1011
+ ) # allow 4 megabytes per sub-module
1012
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
1013
+
1014
+ if params.inf_check:
1015
+ register_inf_check_hooks(model)
1016
+
1017
+ def remove_short_and_long_utt(c: Cut, min_len: float, max_len: float):
1018
+ if c.duration < min_len or c.duration > max_len:
1019
+ return False
1020
+ return True
1021
+
1022
+ _remove_short_and_long_utt = partial(
1023
+ remove_short_and_long_utt, min_len=params.min_len, max_len=params.max_len
1024
+ )
1025
+
1026
+ datamodule = TtsDataModule(args)
1027
+ if params.dataset == "emilia":
1028
+ train_cuts = CutSet.mux(
1029
+ datamodule.train_emilia_EN_cuts(),
1030
+ datamodule.train_emilia_ZH_cuts(),
1031
+ weights=[46000, 49000],
1032
+ )
1033
+ train_cuts = train_cuts.filter(_remove_short_and_long_utt)
1034
+ dev_cuts = CutSet.mux(
1035
+ datamodule.dev_emilia_EN_cuts(),
1036
+ datamodule.dev_emilia_ZH_cuts(),
1037
+ weights=[0.5, 0.5],
1038
+ )
1039
+ elif params.dataset == "libritts":
1040
+ train_cuts = datamodule.train_libritts_cuts()
1041
+ train_cuts = train_cuts.filter(_remove_short_and_long_utt)
1042
+ dev_cuts = datamodule.dev_libritts_cuts()
1043
+ else:
1044
+ assert params.dataset == "custom"
1045
+ train_cuts = datamodule.train_custom_cuts(params.train_manifest)
1046
+ train_cuts = train_cuts.filter(_remove_short_and_long_utt)
1047
+ dev_cuts = datamodule.dev_custom_cuts(params.dev_manifest)
1048
+ # To avoid OOM issues due to too long dev cuts
1049
+ dev_cuts = dev_cuts.filter(_remove_short_and_long_utt)
1050
+
1051
+ if params.tokenizer in ["emilia", "espeak", "dialog"]:
1052
+ if not hasattr(train_cuts[0].supervisions[0], "tokens") or not hasattr(
1053
+ dev_cuts[0].supervisions[0], "tokens"
1054
+ ):
1055
+ logging.warning(
1056
+ f"Using {params.tokenizer} tokenizer but tokens are not prepared,"
1057
+ f"will tokenize on-the-fly, which can slow down training significantly."
1058
+ )
1059
+ _tokenize_text = partial(tokenize_text, tokenizer=tokenizer)
1060
+ train_cuts = train_cuts.map(_tokenize_text)
1061
+ dev_cuts = dev_cuts.map(_tokenize_text)
1062
+
1063
+ train_dl = datamodule.train_dataloaders(train_cuts)
1064
+
1065
+ valid_dl = datamodule.dev_dataloaders(dev_cuts)
1066
+
1067
+ if params.scan_oom:
1068
+ scan_pessimistic_batches_for_oom(
1069
+ model=model,
1070
+ teacher_model=teacher_model,
1071
+ train_dl=train_dl,
1072
+ optimizer=optimizer,
1073
+ params=params,
1074
+ )
1075
+ logging.info("Training started")
1076
+
1077
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
1078
+ logging.info(f"Start epoch {epoch}")
1079
+
1080
+ scheduler.step_epoch(epoch - 1)
1081
+ fix_random_seed(params.seed + epoch - 1)
1082
+ train_dl.sampler.set_epoch(epoch - 1)
1083
+
1084
+ params.cur_epoch = epoch
1085
+
1086
+ if tb_writer is not None:
1087
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
1088
+
1089
+ train_one_epoch(
1090
+ params=params,
1091
+ model=model,
1092
+ model_avg=model_avg,
1093
+ teacher_model=teacher_model,
1094
+ optimizer=optimizer,
1095
+ scheduler=scheduler,
1096
+ train_dl=train_dl,
1097
+ valid_dl=valid_dl,
1098
+ scaler=scaler,
1099
+ tb_writer=tb_writer,
1100
+ world_size=world_size,
1101
+ rank=rank,
1102
+ )
1103
+
1104
+ if params.num_iters > 0 and params.batch_idx_train > params.num_iters:
1105
+ break
1106
+
1107
+ if params.print_diagnostics:
1108
+ diagnostic.print_diagnostics()
1109
+ break
1110
+
1111
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
1112
+ save_checkpoint(
1113
+ filename=filename,
1114
+ params=params,
1115
+ model=model,
1116
+ model_avg=model_avg,
1117
+ model_ema=teacher_model,
1118
+ optimizer=optimizer,
1119
+ scheduler=scheduler,
1120
+ sampler=train_dl.sampler,
1121
+ scaler=scaler,
1122
+ rank=rank,
1123
+ )
1124
+
1125
+ if rank == 0:
1126
+ if params.best_train_epoch == params.cur_epoch:
1127
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
1128
+ copyfile(src=filename, dst=best_train_filename)
1129
+
1130
+ if params.best_valid_epoch == params.cur_epoch:
1131
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
1132
+ copyfile(src=filename, dst=best_valid_filename)
1133
+
1134
+ logging.info("Done!")
1135
+
1136
+ if world_size > 1:
1137
+ torch.distributed.barrier()
1138
+ cleanup_dist()
1139
+
1140
+
1141
+ def main():
1142
+ parser = get_parser()
1143
+ TtsDataModule.add_arguments(parser)
1144
+ args = parser.parse_args()
1145
+ args.exp_dir = Path(args.exp_dir)
1146
+
1147
+ world_size = args.world_size
1148
+ assert world_size >= 1
1149
+ if world_size > 1:
1150
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
1151
+ else:
1152
+ run(rank=0, world_size=1, args=args)
1153
+
1154
+
1155
+ if __name__ == "__main__":
1156
+ torch.set_num_threads(1)
1157
+ torch.set_num_interop_threads(1)
1158
+ main()
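
A minimal stand-alone sketch of the selective fine-tuning pattern used in run() above, where only parameters whose names contain "fm_decoder" stay trainable; the toy module names below are hypothetical and only illustrate the name-matching logic:

    import torch.nn as nn

    def freeze_all_but_fm_decoder(model: nn.Module) -> int:
        """Freeze every parameter whose name lacks 'fm_decoder'; return the trainable count."""
        num_trainable = 0
        for name, p in model.named_parameters():
            if "fm_decoder" in name:
                p.requires_grad = True
                num_trainable += p.numel()
            else:
                p.requires_grad = False
        return num_trainable

    # Example with a toy model (module names are hypothetical):
    toy = nn.ModuleDict({"text_encoder": nn.Linear(8, 8), "fm_decoder": nn.Linear(8, 8)})
    print(freeze_all_but_fm_decoder(toy))  # -> 72 (8*8 weights + 8 biases)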
zipvoice/dataset/datamodule.py ADDED
@@ -0,0 +1,347 @@
1
+ # Copyright 2021 Piotr Żelasko
2
+ # Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo,
3
+ # Zengwei Yao,
4
+ # Zengrui Jin,
5
+ # Han Zhu,
6
+ # Wei Kang)
7
+ #
8
+ # See ../../../../LICENSE for clarification regarding multiple authors
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+
23
+ import argparse
24
+ import logging
25
+ from functools import lru_cache
26
+ from pathlib import Path
27
+ from typing import Any, Dict, Optional
28
+
29
+ import torch
30
+ from lhotse import CutSet, load_manifest_lazy
31
+ from lhotse.dataset import DynamicBucketingSampler, SimpleCutSampler
32
+ from lhotse.dataset.input_strategies import OnTheFlyFeatures, PrecomputedFeatures
33
+ from lhotse.utils import fix_random_seed
34
+ from torch.utils.data import DataLoader
35
+
36
+ from zipvoice.dataset.dataset import SpeechSynthesisDataset
37
+ from zipvoice.utils.common import str2bool
38
+ from zipvoice.utils.feature import VocosFbank
39
+
40
+
41
+ class _SeedWorkers:
42
+ def __init__(self, seed: int):
43
+ self.seed = seed
44
+
45
+ def __call__(self, worker_id: int):
46
+ fix_random_seed(self.seed + worker_id)
47
+
48
+
49
+ SAMPLING_RATE = 24000
50
+
51
+
52
+ class TtsDataModule:
53
+ """
54
+ DataModule for TTS experiments.
55
+ It assumes there is always one train and one valid dataloader,
56
+ but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
57
+ and test-other).
58
+
59
+ It contains all the common data pipeline modules used in TTS
60
+ experiments, e.g.:
61
+ - dynamic batch size,
62
+ - bucketing samplers,
63
+ - cut concatenation,
64
+ - on-the-fly feature extraction
65
+
66
+ This class should be derived for specific corpora used in TTS tasks.
67
+ """
68
+
69
+ def __init__(self, args: argparse.Namespace):
70
+ self.args = args
71
+
72
+ @classmethod
73
+ def add_arguments(cls, parser: argparse.ArgumentParser):
74
+ group = parser.add_argument_group(
75
+ title="TTS data related options",
76
+ description="These options are used for the preparation of "
77
+ "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
78
+ "effective batch sizes, sampling strategies, applied data "
79
+ "augmentations, etc.",
80
+ )
81
+ group.add_argument(
82
+ "--manifest-dir",
83
+ type=Path,
84
+ default=Path("data/fbank"),
85
+ help="Path to directory with train/valid/test cuts.",
86
+ )
87
+ group.add_argument(
88
+ "--max-duration",
89
+ type=int,
90
+ default=200.0,
91
+ help="Maximum pooled recordings duration (seconds) in a "
92
+ "single batch. You can reduce it if it causes CUDA OOM.",
93
+ )
94
+ group.add_argument(
95
+ "--bucketing-sampler",
96
+ type=str2bool,
97
+ default=True,
98
+ help="When enabled, the batches will come from buckets of "
99
+ "similar duration (saves padding frames).",
100
+ )
101
+ group.add_argument(
102
+ "--num-buckets",
103
+ type=int,
104
+ default=30,
105
+ help="The number of buckets for the DynamicBucketingSampler"
106
+ "(you might want to increase it for larger datasets).",
107
+ )
108
+
109
+ group.add_argument(
110
+ "--on-the-fly-feats",
111
+ type=str2bool,
112
+ default=False,
113
+ help="When enabled, use on-the-fly cut mixing and feature "
114
+ "extraction. Will drop existing precomputed feature manifests "
115
+ "if available.",
116
+ )
117
+ group.add_argument(
118
+ "--shuffle",
119
+ type=str2bool,
120
+ default=True,
121
+ help="When enabled (=default), the examples will be "
122
+ "shuffled for each epoch.",
123
+ )
124
+ group.add_argument(
125
+ "--drop-last",
126
+ type=str2bool,
127
+ default=True,
128
+ help="Whether to drop last batch. Used by sampler.",
129
+ )
130
+ group.add_argument(
131
+ "--return-cuts",
132
+ type=str2bool,
133
+ default=False,
134
+ help="When enabled, each batch will have the "
135
+ "field: batch['cut'] with the cuts that "
136
+ "were used to construct it.",
137
+ )
138
+ group.add_argument(
139
+ "--num-workers",
140
+ type=int,
141
+ default=8,
142
+ help="The number of training dataloader workers that "
143
+ "collect the batches.",
144
+ )
145
+
146
+ group.add_argument(
147
+ "--input-strategy",
148
+ type=str,
149
+ default="PrecomputedFeatures",
150
+ help="AudioSamples or PrecomputedFeatures",
151
+ )
152
+
153
+ def train_dataloaders(
154
+ self,
155
+ cuts_train: CutSet,
156
+ sampler_state_dict: Optional[Dict[str, Any]] = None,
157
+ ) -> DataLoader:
158
+ """
159
+ Args:
160
+ cuts_train:
161
+ CutSet for training.
162
+ sampler_state_dict:
163
+ The state dict for the training sampler.
164
+ """
165
+ logging.info("About to create train dataset")
166
+
167
+ train = SpeechSynthesisDataset(
168
+ return_text=True,
169
+ return_tokens=True,
170
+ return_spk_ids=True,
171
+ feature_input_strategy=OnTheFlyFeatures(VocosFbank())
172
+ if self.args.on_the_fly_feats
173
+ else PrecomputedFeatures(),
174
+ return_cuts=self.args.return_cuts,
175
+ )
176
+
177
+ if self.args.bucketing_sampler:
178
+ logging.info("Using DynamicBucketingSampler.")
179
+ train_sampler = DynamicBucketingSampler(
180
+ cuts_train,
181
+ max_duration=self.args.max_duration,
182
+ shuffle=self.args.shuffle,
183
+ num_buckets=self.args.num_buckets,
184
+ buffer_size=self.args.num_buckets * 2000,
185
+ shuffle_buffer_size=self.args.num_buckets * 5000,
186
+ drop_last=self.args.drop_last,
187
+ )
188
+ else:
189
+ logging.info("Using SimpleCutSampler.")
190
+ train_sampler = SimpleCutSampler(
191
+ cuts_train,
192
+ max_duration=self.args.max_duration,
193
+ shuffle=self.args.shuffle,
194
+ )
195
+ logging.info("About to create train dataloader")
196
+
197
+ if sampler_state_dict is not None:
198
+ logging.info("Loading sampler state dict")
199
+ train_sampler.load_state_dict(sampler_state_dict)
200
+
201
+ # 'seed' is derived from the current random state, which will have
202
+ # previously been set in the main process.
203
+ seed = torch.randint(0, 100000, ()).item()
204
+ worker_init_fn = _SeedWorkers(seed)
205
+
206
+ train_dl = DataLoader(
207
+ train,
208
+ sampler=train_sampler,
209
+ batch_size=None,
210
+ num_workers=self.args.num_workers,
211
+ persistent_workers=False,
212
+ worker_init_fn=worker_init_fn,
213
+ )
214
+
215
+ return train_dl
216
+
217
+ def dev_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
218
+ logging.info("About to create dev dataset")
219
+ validate = SpeechSynthesisDataset(
220
+ return_text=True,
221
+ return_tokens=True,
222
+ return_spk_ids=True,
223
+ feature_input_strategy=OnTheFlyFeatures(VocosFbank())
224
+ if self.args.on_the_fly_feats
225
+ else PrecomputedFeatures(),
226
+ return_cuts=self.args.return_cuts,
227
+ )
228
+ dev_sampler = DynamicBucketingSampler(
229
+ cuts_valid,
230
+ max_duration=self.args.max_duration,
231
+ shuffle=False,
232
+ )
233
+ logging.info("About to create valid dataloader")
234
+ dev_dl = DataLoader(
235
+ validate,
236
+ sampler=dev_sampler,
237
+ batch_size=None,
238
+ num_workers=2,
239
+ persistent_workers=False,
240
+ )
241
+
242
+ return dev_dl
243
+
244
+ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
245
+ logging.info("About to create test dataset")
246
+ test = SpeechSynthesisDataset(
247
+ return_text=True,
248
+ return_tokens=True,
249
+ return_spk_ids=True,
250
+ feature_input_strategy=OnTheFlyFeatures(VocosFbank())
251
+ if self.args.on_the_fly_feats
252
+ else PrecomputedFeatures(),
253
+ return_cuts=self.args.return_cuts,
254
+ return_audio=True,
255
+ )
256
+ test_sampler = DynamicBucketingSampler(
257
+ cuts,
258
+ max_duration=self.args.max_duration,
259
+ shuffle=False,
260
+ )
261
+ logging.info("About to create test dataloader")
262
+ test_dl = DataLoader(
263
+ test,
264
+ batch_size=None,
265
+ sampler=test_sampler,
266
+ num_workers=self.args.num_workers,
267
+ )
268
+ return test_dl
269
+
270
+ @lru_cache()
271
+ def train_custom_cuts(self, manifest_file) -> CutSet:
272
+ logging.info(f"About to get the custom training cuts {manifest_file}")
273
+ return load_manifest_lazy(manifest_file)
274
+
275
+ @lru_cache()
276
+ def dev_custom_cuts(self, manifest_file) -> CutSet:
277
+ logging.info(f"About to get the custom validation cuts {manifest_file}")
278
+ return load_manifest_lazy(manifest_file)
279
+
280
+ @lru_cache()
281
+ def train_emilia_EN_cuts(self) -> CutSet:
282
+ logging.info("About to get train the EN subset")
283
+ return load_manifest_lazy(self.args.manifest_dir / "emilia_cuts_EN.jsonl.gz")
284
+
285
+ @lru_cache()
286
+ def train_emilia_ZH_cuts(self) -> CutSet:
287
+ logging.info("About to get train the ZH subset")
288
+ return load_manifest_lazy(self.args.manifest_dir / "emilia_cuts_ZH.jsonl.gz")
289
+
290
+ @lru_cache()
291
+ def dev_emilia_EN_cuts(self) -> CutSet:
292
+ logging.info("About to get dev the EN subset")
293
+ return load_manifest_lazy(
294
+ self.args.manifest_dir / "emilia_cuts_EN-dev.jsonl.gz"
295
+ )
296
+
297
+ @lru_cache()
298
+ def dev_emilia_ZH_cuts(self) -> CutSet:
299
+ logging.info("About to get dev the ZH subset")
300
+ return load_manifest_lazy(
301
+ self.args.manifest_dir / "emilia_cuts_ZH-dev.jsonl.gz"
302
+ )
303
+
304
+ @lru_cache()
305
+ def train_libritts_cuts(self) -> CutSet:
306
+ logging.info(
307
+ "About to get the shuffled train-clean-100, \
308
+ train-clean-360 and train-other-500 cuts"
309
+ )
310
+ return load_manifest_lazy(
311
+ self.args.manifest_dir / "libritts_cuts_train-all-shuf.jsonl.gz"
312
+ )
313
+
314
+ @lru_cache()
315
+ def dev_libritts_cuts(self) -> CutSet:
316
+ logging.info("About to get dev-clean cuts")
317
+ return load_manifest_lazy(
318
+ self.args.manifest_dir / "libritts_cuts_dev-clean.jsonl.gz"
319
+ )
320
+
321
+ @lru_cache()
322
+ def train_opendialog_en_cuts(self) -> CutSet:
323
+ logging.info("About to ge the EN train subset of OpenDialog")
324
+ return load_manifest_lazy(
325
+ self.args.manifest_dir / "opendialog_cuts_EN-train.jsonl.gz"
326
+ )
327
+
328
+ @lru_cache()
329
+ def train_opendialog_zh_cuts(self) -> CutSet:
330
+ logging.info("About to get the ZH train subset of OpenDialog")
331
+ return load_manifest_lazy(
332
+ self.args.manifest_dir / "opendialog_cuts_ZH-train.jsonl.gz"
333
+ )
334
+
335
+ @lru_cache()
336
+ def dev_opendialog_en_cuts(self) -> CutSet:
337
+ logging.info("About to ge the EN dev subset of OpenDialog")
338
+ return load_manifest_lazy(
339
+ self.args.manifest_dir / "opendialog_cuts_EN-dev.jsonl.gz"
340
+ )
341
+
342
+ @lru_cache()
343
+ def dev_opendialog_zh_cuts(self) -> CutSet:
344
+ logging.info("About to get the ZH dev subset of OpenDialog")
345
+ return load_manifest_lazy(
346
+ self.args.manifest_dir / "opendialog_cuts_ZH-dev.jsonl.gz"
347
+ )
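
A minimal usage sketch for TtsDataModule with precomputed fbank features; the manifest path below is a placeholder, not a file shipped in this commit, and the overridden options are just examples:

    import argparse

    from zipvoice.dataset.datamodule import TtsDataModule

    parser = argparse.ArgumentParser()
    TtsDataModule.add_arguments(parser)
    # All data options have defaults; override a couple for illustration.
    args = parser.parse_args(["--max-duration", "100", "--num-workers", "4"])

    datamodule = TtsDataModule(args)
    train_cuts = datamodule.train_custom_cuts("data/fbank/custom_cuts_train.jsonl.gz")
    train_dl = datamodule.train_dataloaders(train_cuts)

    for batch in train_dl:
        print(batch["features"].shape, batch["features_lens"])
        break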
zipvoice/dataset/dataset.py ADDED
@@ -0,0 +1,105 @@
1
+ from typing import Callable, Dict, List, Sequence, Union
2
+
3
+ import torch
4
+ from lhotse import CutSet, validate
5
+ from lhotse.dataset import PrecomputedFeatures
6
+ from lhotse.dataset.collation import collate_audio
7
+ from lhotse.dataset.input_strategies import BatchIO
8
+ from lhotse.utils import ifnone
9
+
10
+
11
+ class SpeechSynthesisDataset(torch.utils.data.Dataset):
12
+ """
13
+ The PyTorch Dataset for the speech synthesis task.
14
+ Each item in this dataset is a dict of:
15
+
16
+ .. code-block::
17
+
18
+ {
19
+ 'audio': (B x NumSamples) float tensor # when return_audio=True
20
+ 'features': (B x NumFrames x NumFeatures) float tensor
21
+ 'audio_lens': (B, ) int tensor # when return_audio=True
22
+ 'features_lens': (B, ) int tensor
23
+ 'text': List[str] of len B # when return_text=True
24
+ 'tokens': List[List[str]] # when return_tokens=True
25
+ 'speakers': List[str] of len B # when return_spk_ids=True
26
+ 'cut': List of Cuts # when return_cuts=True
27
+ }
28
+ """
29
+
30
+ def __init__(
31
+ self,
32
+ cut_transforms: List[Callable[[CutSet], CutSet]] = None,
33
+ feature_input_strategy: BatchIO = PrecomputedFeatures(),
34
+ feature_transforms: Union[Sequence[Callable], Callable] = None,
35
+ return_text: bool = True,
36
+ return_tokens: bool = False,
37
+ return_spk_ids: bool = False,
38
+ return_cuts: bool = False,
39
+ return_audio: bool = False,
40
+ ) -> None:
41
+ super().__init__()
42
+
43
+ self.cut_transforms = ifnone(cut_transforms, [])
44
+ self.feature_input_strategy = feature_input_strategy
45
+
46
+ self.return_text = return_text
47
+ self.return_tokens = return_tokens
48
+ self.return_spk_ids = return_spk_ids
49
+ self.return_cuts = return_cuts
50
+ self.return_audio = return_audio
51
+
52
+ if feature_transforms is None:
53
+ feature_transforms = []
54
+ elif not isinstance(feature_transforms, Sequence):
55
+ feature_transforms = [feature_transforms]
56
+
57
+ assert all(
58
+ isinstance(transform, Callable) for transform in feature_transforms
59
+ ), "Feature transforms must be Callable"
60
+ self.feature_transforms = feature_transforms
61
+
62
+ def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]:
63
+ validate_for_tts(cuts)
64
+
65
+ for transform in self.cut_transforms:
66
+ cuts = transform(cuts)
67
+
68
+ features, features_lens = self.feature_input_strategy(cuts)
69
+
70
+ for transform in self.feature_transforms:
71
+ features = transform(features)
72
+
73
+ batch = {
74
+ "features": features,
75
+ "features_lens": features_lens,
76
+ }
77
+
78
+ if self.return_audio:
79
+ audio, audio_lens = collate_audio(cuts)
80
+ batch["audio"] = audio
81
+ batch["audio_lens"] = audio_lens
82
+
83
+ if self.return_text:
84
+ text = [cut.supervisions[0].text for cut in cuts]
85
+ batch["text"] = text
86
+
87
+ if self.return_tokens:
88
+ tokens = [cut.supervisions[0].tokens for cut in cuts]
89
+ batch["tokens"] = tokens
90
+
91
+ if self.return_spk_ids:
92
+ batch["speakers"] = [cut.supervisions[0].speaker for cut in cuts]
93
+
94
+ if self.return_cuts:
95
+ batch["cut"] = [cut for cut in cuts]
96
+
97
+ return batch
98
+
99
+
100
+ def validate_for_tts(cuts: CutSet) -> None:
101
+ validate(cuts)
102
+ for cut in cuts:
103
+ assert (
104
+ len(cut.supervisions) == 1
105
+ ), "Only the Cuts with single supervision are supported."
zipvoice/eval/models/ecapa_tdnn_wavllm.py ADDED
@@ -0,0 +1,357 @@
1
+ import os
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ class ECAPA_TDNN_WAVLLM(nn.Module):
9
+ def __init__(
10
+ self,
11
+ feat_dim=80,
12
+ channels=512,
13
+ emb_dim=192,
14
+ global_context_att=False,
15
+ sr=16000,
16
+ ssl_model_path=None,
17
+ ):
18
+ super().__init__()
19
+ self.sr = sr
20
+
21
+ if ssl_model_path is None:
22
+ self.feature_extract = torch.hub.load("s3prl/s3prl", "wavlm_large")
23
+ else:
24
+ self.feature_extract = torch.hub.load(
25
+ os.path.dirname(ssl_model_path),
26
+ "wavlm_local",
27
+ source="local",
28
+ ckpt=ssl_model_path,
29
+ )
30
+
31
+ if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
32
+ self.feature_extract.model.encoder.layers[23].self_attn,
33
+ "fp32_attention",
34
+ ):
35
+ self.feature_extract.model.encoder.layers[
36
+ 23
37
+ ].self_attn.fp32_attention = False
38
+ if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
39
+ self.feature_extract.model.encoder.layers[11].self_attn,
40
+ "fp32_attention",
41
+ ):
42
+ self.feature_extract.model.encoder.layers[
43
+ 11
44
+ ].self_attn.fp32_attention = False
45
+
46
+ self.feat_num = self.get_feat_num()
47
+ self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
48
+
49
+ self.instance_norm = nn.InstanceNorm1d(feat_dim)
50
+ # self.channels = [channels] * 4 + [channels * 3]
51
+ self.channels = [channels] * 4 + [1536]
52
+
53
+ self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
54
+ self.layer2 = SE_Res2Block(
55
+ self.channels[0],
56
+ self.channels[1],
57
+ kernel_size=3,
58
+ stride=1,
59
+ padding=2,
60
+ dilation=2,
61
+ scale=8,
62
+ se_bottleneck_dim=128,
63
+ )
64
+ self.layer3 = SE_Res2Block(
65
+ self.channels[1],
66
+ self.channels[2],
67
+ kernel_size=3,
68
+ stride=1,
69
+ padding=3,
70
+ dilation=3,
71
+ scale=8,
72
+ se_bottleneck_dim=128,
73
+ )
74
+ self.layer4 = SE_Res2Block(
75
+ self.channels[2],
76
+ self.channels[3],
77
+ kernel_size=3,
78
+ stride=1,
79
+ padding=4,
80
+ dilation=4,
81
+ scale=8,
82
+ se_bottleneck_dim=128,
83
+ )
84
+
85
+ # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
86
+ cat_channels = channels * 3
87
+ self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
88
+ self.pooling = AttentiveStatsPool(
89
+ self.channels[-1],
90
+ attention_channels=128,
91
+ global_context_att=global_context_att,
92
+ )
93
+ self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
94
+ self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
95
+
96
+ def get_feat_num(self):
97
+ self.feature_extract.eval()
98
+ wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
99
+ with torch.no_grad():
100
+ features = self.feature_extract(wav)
101
+ select_feature = features["hidden_states"]
102
+ if isinstance(select_feature, (list, tuple)):
103
+ return len(select_feature)
104
+ else:
105
+ return 1
106
+
107
+ def get_feat(self, x):
108
+ with torch.no_grad():
109
+ x = self.feature_extract([sample for sample in x])
110
+
111
+ x = x["hidden_states"]
112
+ if isinstance(x, (list, tuple)):
113
+ x = torch.stack(x, dim=0)
114
+ else:
115
+ x = x.unsqueeze(0)
116
+ norm_weights = (
117
+ F.softmax(self.feature_weight, dim=-1)
118
+ .unsqueeze(-1)
119
+ .unsqueeze(-1)
120
+ .unsqueeze(-1)
121
+ )
122
+ x = (norm_weights * x).sum(dim=0)
123
+ x = torch.transpose(x, 1, 2) + 1e-6
124
+
125
+ x = self.instance_norm(x)
126
+ return x
127
+
128
+ def forward(self, x):
129
+ x = self.get_feat(x)
130
+
131
+ out1 = self.layer1(x)
132
+ out2 = self.layer2(out1)
133
+ out3 = self.layer3(out2)
134
+ out4 = self.layer4(out3)
135
+
136
+ out = torch.cat([out2, out3, out4], dim=1)
137
+ out = F.relu(self.conv(out))
138
+ out = self.bn(self.pooling(out))
139
+ out = self.linear(out)
140
+
141
+ return out
142
+
143
+
144
+ # part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
145
+
146
+ """ Res2Conv1d + BatchNorm1d + ReLU
147
+ """
148
+
149
+
150
+ class Res2Conv1dReluBn(nn.Module):
151
+ """
152
+ in_channels == out_channels == channels
153
+ """
154
+
155
+ def __init__(
156
+ self,
157
+ channels,
158
+ kernel_size=1,
159
+ stride=1,
160
+ padding=0,
161
+ dilation=1,
162
+ bias=True,
163
+ scale=4,
164
+ ):
165
+ super().__init__()
166
+ assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
167
+ self.scale = scale
168
+ self.width = channels // scale
169
+ self.nums = scale if scale == 1 else scale - 1
170
+
171
+ self.convs = []
172
+ self.bns = []
173
+ for i in range(self.nums):
174
+ self.convs.append(
175
+ nn.Conv1d(
176
+ self.width,
177
+ self.width,
178
+ kernel_size,
179
+ stride,
180
+ padding,
181
+ dilation,
182
+ bias=bias,
183
+ )
184
+ )
185
+ self.bns.append(nn.BatchNorm1d(self.width))
186
+ self.convs = nn.ModuleList(self.convs)
187
+ self.bns = nn.ModuleList(self.bns)
188
+
189
+ def forward(self, x):
190
+ out = []
191
+ spx = torch.split(x, self.width, 1)
192
+ for i in range(self.nums):
193
+ if i == 0:
194
+ sp = spx[i]
195
+ else:
196
+ sp = sp + spx[i]
197
+ # Order: conv -> relu -> bn
198
+ sp = self.convs[i](sp)
199
+ sp = self.bns[i](F.relu(sp))
200
+ out.append(sp)
201
+ if self.scale != 1:
202
+ out.append(spx[self.nums])
203
+ out = torch.cat(out, dim=1)
204
+
205
+ return out
206
+
207
+
208
+ """ Conv1d + BatchNorm1d + ReLU
209
+ """
210
+
211
+
212
+ class Conv1dReluBn(nn.Module):
213
+ def __init__(
214
+ self,
215
+ in_channels,
216
+ out_channels,
217
+ kernel_size=1,
218
+ stride=1,
219
+ padding=0,
220
+ dilation=1,
221
+ bias=True,
222
+ ):
223
+ super().__init__()
224
+ self.conv = nn.Conv1d(
225
+ in_channels,
226
+ out_channels,
227
+ kernel_size,
228
+ stride,
229
+ padding,
230
+ dilation,
231
+ bias=bias,
232
+ )
233
+ self.bn = nn.BatchNorm1d(out_channels)
234
+
235
+ def forward(self, x):
236
+ return self.bn(F.relu(self.conv(x)))
237
+
238
+
239
+ """ The SE connection of 1D case.
240
+ """
241
+
242
+
243
+ class SE_Connect(nn.Module):
244
+ def __init__(self, channels, se_bottleneck_dim=128):
245
+ super().__init__()
246
+ self.linear1 = nn.Linear(channels, se_bottleneck_dim)
247
+ self.linear2 = nn.Linear(se_bottleneck_dim, channels)
248
+
249
+ def forward(self, x):
250
+ out = x.mean(dim=2)
251
+ out = F.relu(self.linear1(out))
252
+ out = torch.sigmoid(self.linear2(out))
253
+ out = x * out.unsqueeze(2)
254
+
255
+ return out
256
+
257
+
258
+ """ SE-Res2Block of the ECAPA-TDNN architecture.
259
+ """
260
+
261
+
262
+ # def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
263
+ # return nn.Sequential(
264
+ # Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
265
+ # Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
266
+ # Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
267
+ # SE_Connect(channels)
268
+ # )
269
+
270
+
271
+ class SE_Res2Block(nn.Module):
272
+ def __init__(
273
+ self,
274
+ in_channels,
275
+ out_channels,
276
+ kernel_size,
277
+ stride,
278
+ padding,
279
+ dilation,
280
+ scale,
281
+ se_bottleneck_dim,
282
+ ):
283
+ super().__init__()
284
+ self.Conv1dReluBn1 = Conv1dReluBn(
285
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
286
+ )
287
+ self.Res2Conv1dReluBn = Res2Conv1dReluBn(
288
+ out_channels, kernel_size, stride, padding, dilation, scale=scale
289
+ )
290
+ self.Conv1dReluBn2 = Conv1dReluBn(
291
+ out_channels, out_channels, kernel_size=1, stride=1, padding=0
292
+ )
293
+ self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)
294
+
295
+ self.shortcut = None
296
+ if in_channels != out_channels:
297
+ self.shortcut = nn.Conv1d(
298
+ in_channels=in_channels,
299
+ out_channels=out_channels,
300
+ kernel_size=1,
301
+ )
302
+
303
+ def forward(self, x):
304
+ residual = x
305
+ if self.shortcut:
306
+ residual = self.shortcut(x)
307
+
308
+ x = self.Conv1dReluBn1(x)
309
+ x = self.Res2Conv1dReluBn(x)
310
+ x = self.Conv1dReluBn2(x)
311
+ x = self.SE_Connect(x)
312
+
313
+ return x + residual
314
+
315
+
316
+ """ Attentive weighted mean and standard deviation pooling.
317
+ """
318
+
319
+
320
+ class AttentiveStatsPool(nn.Module):
321
+ def __init__(self, in_dim, attention_channels=128, global_context_att=False):
322
+ super().__init__()
323
+ self.global_context_att = global_context_att
324
+
325
+ # Use Conv1d with stride == 1 rather than Linear,
326
+ # then we don't need to transpose inputs.
327
+ if global_context_att:
328
+ self.linear1 = nn.Conv1d(
329
+ in_dim * 3, attention_channels, kernel_size=1
330
+ ) # equals W and b in the paper
331
+ else:
332
+ self.linear1 = nn.Conv1d(
333
+ in_dim, attention_channels, kernel_size=1
334
+ ) # equals W and b in the paper
335
+ self.linear2 = nn.Conv1d(
336
+ attention_channels, in_dim, kernel_size=1
337
+ ) # equals V and k in the paper
338
+
339
+ def forward(self, x):
340
+
341
+ if self.global_context_att:
342
+ context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
343
+ context_std = torch.sqrt(
344
+ torch.var(x, dim=-1, keepdim=True) + 1e-10
345
+ ).expand_as(x)
346
+ x_in = torch.cat((x, context_mean, context_std), dim=1)
347
+ else:
348
+ x_in = x
349
+
350
+ # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
351
+ alpha = torch.tanh(self.linear1(x_in))
352
+ # alpha = F.relu(self.linear1(x_in))
353
+ alpha = torch.softmax(self.linear2(alpha), dim=2)
354
+ mean = torch.sum(alpha * x, dim=2)
355
+ residuals = torch.sum(alpha * (x**2), dim=2) - mean**2
356
+ std = torch.sqrt(residuals.clamp(min=1e-9))
357
+ return torch.cat([mean, std], dim=1)
zipvoice/eval/models/ecapa_tdnn_wavlm.py ADDED
@@ -0,0 +1,357 @@
1
+ import os
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ class ECAPA_TDNN_WAVLM(nn.Module):
9
+ def __init__(
10
+ self,
11
+ feat_dim=80,
12
+ channels=512,
13
+ emb_dim=192,
14
+ global_context_att=False,
15
+ sr=16000,
16
+ ssl_model_path=None,
17
+ ):
18
+ super().__init__()
19
+ self.sr = sr
20
+
21
+ if ssl_model_path is None:
22
+ self.feature_extract = torch.hub.load("s3prl/s3prl", "wavlm_large")
23
+ else:
24
+ self.feature_extract = torch.hub.load(
25
+ os.path.dirname(ssl_model_path),
26
+ "wavlm_local",
27
+ source="local",
28
+ ckpt=os.path.join(ssl_model_path, "wavlm_large.pt"),
29
+ )
30
+
31
+ if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
32
+ self.feature_extract.model.encoder.layers[23].self_attn,
33
+ "fp32_attention",
34
+ ):
35
+ self.feature_extract.model.encoder.layers[
36
+ 23
37
+ ].self_attn.fp32_attention = False
38
+ if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
39
+ self.feature_extract.model.encoder.layers[11].self_attn,
40
+ "fp32_attention",
41
+ ):
42
+ self.feature_extract.model.encoder.layers[
43
+ 11
44
+ ].self_attn.fp32_attention = False
45
+
46
+ self.feat_num = self.get_feat_num()
47
+ self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
48
+
49
+ self.instance_norm = nn.InstanceNorm1d(feat_dim)
50
+ # self.channels = [channels] * 4 + [channels * 3]
51
+ self.channels = [channels] * 4 + [1536]
52
+
53
+ self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
54
+ self.layer2 = SE_Res2Block(
55
+ self.channels[0],
56
+ self.channels[1],
57
+ kernel_size=3,
58
+ stride=1,
59
+ padding=2,
60
+ dilation=2,
61
+ scale=8,
62
+ se_bottleneck_dim=128,
63
+ )
64
+ self.layer3 = SE_Res2Block(
65
+ self.channels[1],
66
+ self.channels[2],
67
+ kernel_size=3,
68
+ stride=1,
69
+ padding=3,
70
+ dilation=3,
71
+ scale=8,
72
+ se_bottleneck_dim=128,
73
+ )
74
+ self.layer4 = SE_Res2Block(
75
+ self.channels[2],
76
+ self.channels[3],
77
+ kernel_size=3,
78
+ stride=1,
79
+ padding=4,
80
+ dilation=4,
81
+ scale=8,
82
+ se_bottleneck_dim=128,
83
+ )
84
+
85
+ # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
86
+ cat_channels = channels * 3
87
+ self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
88
+ self.pooling = AttentiveStatsPool(
89
+ self.channels[-1],
90
+ attention_channels=128,
91
+ global_context_att=global_context_att,
92
+ )
93
+ self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
94
+ self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
95
+
96
+ def get_feat_num(self):
97
+ self.feature_extract.eval()
98
+ wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
99
+ with torch.no_grad():
100
+ features = self.feature_extract(wav)
101
+ select_feature = features["hidden_states"]
102
+ if isinstance(select_feature, (list, tuple)):
103
+ return len(select_feature)
104
+ else:
105
+ return 1
106
+
107
+ def get_feat(self, x):
108
+ with torch.no_grad():
109
+ x = self.feature_extract([sample for sample in x])
110
+
111
+ x = x["hidden_states"]
112
+ if isinstance(x, (list, tuple)):
113
+ x = torch.stack(x, dim=0)
114
+ else:
115
+ x = x.unsqueeze(0)
116
+ norm_weights = (
117
+ F.softmax(self.feature_weight, dim=-1)
118
+ .unsqueeze(-1)
119
+ .unsqueeze(-1)
120
+ .unsqueeze(-1)
121
+ )
122
+ x = (norm_weights * x).sum(dim=0)
123
+ x = torch.transpose(x, 1, 2) + 1e-6
124
+
125
+ x = self.instance_norm(x)
126
+ return x
127
+
128
+ def forward(self, x):
129
+ x = self.get_feat(x)
130
+
131
+ out1 = self.layer1(x)
132
+ out2 = self.layer2(out1)
133
+ out3 = self.layer3(out2)
134
+ out4 = self.layer4(out3)
135
+
136
+ out = torch.cat([out2, out3, out4], dim=1)
137
+ out = F.relu(self.conv(out))
138
+ out = self.bn(self.pooling(out))
139
+ out = self.linear(out)
140
+
141
+ return out
142
+
143
+
144
+ # part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
145
+
146
+ """ Res2Conv1d + BatchNorm1d + ReLU
147
+ """
148
+
149
+
150
+ class Res2Conv1dReluBn(nn.Module):
151
+ """
152
+ in_channels == out_channels == channels
153
+ """
154
+
155
+ def __init__(
156
+ self,
157
+ channels,
158
+ kernel_size=1,
159
+ stride=1,
160
+ padding=0,
161
+ dilation=1,
162
+ bias=True,
163
+ scale=4,
164
+ ):
165
+ super().__init__()
166
+ assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
167
+ self.scale = scale
168
+ self.width = channels // scale
169
+ self.nums = scale if scale == 1 else scale - 1
170
+
171
+ self.convs = []
172
+ self.bns = []
173
+ for i in range(self.nums):
174
+ self.convs.append(
175
+ nn.Conv1d(
176
+ self.width,
177
+ self.width,
178
+ kernel_size,
179
+ stride,
180
+ padding,
181
+ dilation,
182
+ bias=bias,
183
+ )
184
+ )
185
+ self.bns.append(nn.BatchNorm1d(self.width))
186
+ self.convs = nn.ModuleList(self.convs)
187
+ self.bns = nn.ModuleList(self.bns)
188
+
189
+ def forward(self, x):
190
+ out = []
191
+ spx = torch.split(x, self.width, 1)
192
+ for i in range(self.nums):
193
+ if i == 0:
194
+ sp = spx[i]
195
+ else:
196
+ sp = sp + spx[i]
197
+ # Order: conv -> relu -> bn
198
+ sp = self.convs[i](sp)
199
+ sp = self.bns[i](F.relu(sp))
200
+ out.append(sp)
201
+ if self.scale != 1:
202
+ out.append(spx[self.nums])
203
+ out = torch.cat(out, dim=1)
204
+
205
+ return out
206
+
207
+
208
+ """ Conv1d + BatchNorm1d + ReLU
209
+ """
210
+
211
+
212
+ class Conv1dReluBn(nn.Module):
213
+ def __init__(
214
+ self,
215
+ in_channels,
216
+ out_channels,
217
+ kernel_size=1,
218
+ stride=1,
219
+ padding=0,
220
+ dilation=1,
221
+ bias=True,
222
+ ):
223
+ super().__init__()
224
+ self.conv = nn.Conv1d(
225
+ in_channels,
226
+ out_channels,
227
+ kernel_size,
228
+ stride,
229
+ padding,
230
+ dilation,
231
+ bias=bias,
232
+ )
233
+ self.bn = nn.BatchNorm1d(out_channels)
234
+
235
+ def forward(self, x):
236
+ return self.bn(F.relu(self.conv(x)))
237
+
238
+
239
+ """ The SE connection of 1D case.
240
+ """
241
+
242
+
243
+ class SE_Connect(nn.Module):
244
+ def __init__(self, channels, se_bottleneck_dim=128):
245
+ super().__init__()
246
+ self.linear1 = nn.Linear(channels, se_bottleneck_dim)
247
+ self.linear2 = nn.Linear(se_bottleneck_dim, channels)
248
+
249
+ def forward(self, x):
250
+ out = x.mean(dim=2)
251
+ out = F.relu(self.linear1(out))
252
+ out = torch.sigmoid(self.linear2(out))
253
+ out = x * out.unsqueeze(2)
254
+
255
+ return out
256
+
257
+
258
+ """ SE-Res2Block of the ECAPA-TDNN architecture.
259
+ """
260
+
261
+
262
+ # def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
263
+ # return nn.Sequential(
264
+ # Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
265
+ # Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
266
+ # Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
267
+ # SE_Connect(channels)
268
+ # )
269
+
270
+
271
+ class SE_Res2Block(nn.Module):
272
+ def __init__(
273
+ self,
274
+ in_channels,
275
+ out_channels,
276
+ kernel_size,
277
+ stride,
278
+ padding,
279
+ dilation,
280
+ scale,
281
+ se_bottleneck_dim,
282
+ ):
283
+ super().__init__()
284
+ self.Conv1dReluBn1 = Conv1dReluBn(
285
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
286
+ )
287
+ self.Res2Conv1dReluBn = Res2Conv1dReluBn(
288
+ out_channels, kernel_size, stride, padding, dilation, scale=scale
289
+ )
290
+ self.Conv1dReluBn2 = Conv1dReluBn(
291
+ out_channels, out_channels, kernel_size=1, stride=1, padding=0
292
+ )
293
+ self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)
294
+
295
+ self.shortcut = None
296
+ if in_channels != out_channels:
297
+ self.shortcut = nn.Conv1d(
298
+ in_channels=in_channels,
299
+ out_channels=out_channels,
300
+ kernel_size=1,
301
+ )
302
+
303
+ def forward(self, x):
304
+ residual = x
305
+ if self.shortcut:
306
+ residual = self.shortcut(x)
307
+
308
+ x = self.Conv1dReluBn1(x)
309
+ x = self.Res2Conv1dReluBn(x)
310
+ x = self.Conv1dReluBn2(x)
311
+ x = self.SE_Connect(x)
312
+
313
+ return x + residual
314
+
315
+
316
+ """ Attentive weighted mean and standard deviation pooling.
317
+ """
318
+
319
+
320
+ class AttentiveStatsPool(nn.Module):
321
+ def __init__(self, in_dim, attention_channels=128, global_context_att=False):
322
+ super().__init__()
323
+ self.global_context_att = global_context_att
324
+
325
+ # Use Conv1d with stride == 1 rather than Linear,
326
+ # then we don't need to transpose inputs.
327
+ if global_context_att:
328
+ self.linear1 = nn.Conv1d(
329
+ in_dim * 3, attention_channels, kernel_size=1
330
+ ) # equals W and b in the paper
331
+ else:
332
+ self.linear1 = nn.Conv1d(
333
+ in_dim, attention_channels, kernel_size=1
334
+ ) # equals W and b in the paper
335
+ self.linear2 = nn.Conv1d(
336
+ attention_channels, in_dim, kernel_size=1
337
+ ) # equals V and k in the paper
338
+
339
+ def forward(self, x):
340
+
341
+ if self.global_context_att:
342
+ context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
343
+ context_std = torch.sqrt(
344
+ torch.var(x, dim=-1, keepdim=True) + 1e-10
345
+ ).expand_as(x)
346
+ x_in = torch.cat((x, context_mean, context_std), dim=1)
347
+ else:
348
+ x_in = x
349
+
350
+ # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
351
+ alpha = torch.tanh(self.linear1(x_in))
352
+ # alpha = F.relu(self.linear1(x_in))
353
+ alpha = torch.softmax(self.linear2(alpha), dim=2)
354
+ mean = torch.sum(alpha * x, dim=2)
355
+ residuals = torch.sum(alpha * (x**2), dim=2) - mean**2
356
+ std = torch.sqrt(residuals.clamp(min=1e-9))
357
+ return torch.cat([mean, std], dim=1)
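A hedged sketch of using ECAPA_TDNN_WAVLM for speaker-similarity scoring in the spirit of zipvoice/eval/speaker_similarity/sim.py; the constructor arguments, checkpoint path and wav paths are assumptions (the real values live in that script), and constructing the model with ssl_model_path=None downloads WavLM-large via torch.hub:

    import torch
    import torchaudio

    from zipvoice.eval.models.ecapa_tdnn_wavlm import ECAPA_TDNN_WAVLM

    # feat_dim must match the WavLM hidden size (1024 for WavLM-large).
    model = ECAPA_TDNN_WAVLM(feat_dim=1024, channels=512, emb_dim=256, sr=16000)
    state = torch.load("path/to/speaker_verification_checkpoint.pt", map_location="cpu")
    model.load_state_dict(state.get("model", state), strict=False)
    model.eval()

    def embed(path: str) -> torch.Tensor:
        wav, sr = torchaudio.load(path)                       # (channels, samples)
        wav = torchaudio.functional.resample(wav, sr, 16000)  # the extractor expects 16 kHz
        wav = wav.mean(dim=0, keepdim=True)                   # mono, shape (1, T)
        with torch.no_grad():
            return model(wav)                                 # (1, emb_dim)

    sim = torch.nn.functional.cosine_similarity(embed("ref.wav"), embed("gen.wav"))
    print(float(sim))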
zipvoice/eval/models/utmos.py ADDED
@@ -0,0 +1,354 @@
1
+ """
2
+ UTMOS strong model.
3
+ Implementation from https://github.com/tarepan/SpeechMOS
4
+
5
+ """
6
+
7
+ import math
8
+ from typing import List, Optional, Tuple
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ import torchaudio # pyright: ignore [reportMissingTypeStubs]
13
+ from torch import Tensor, nn
14
+
15
+
16
+ class UTMOS22Strong(nn.Module):
17
+ """Saeki_2022 paper's `UTMOS strong learner` inference model
18
+ (w/o Phoneme encoder)."""
19
+
20
+ def __init__(self):
21
+ """Init."""
22
+
23
+ super().__init__() # pyright: ignore [reportUnknownMemberType]
24
+
25
+ feat_ssl, feat_domain_emb, feat_judge_emb, feat_rnn_h, feat_proj_h = (
26
+ 768,
27
+ 128,
28
+ 128,
29
+ 512,
30
+ 2048,
31
+ )
32
+ feat_cat = feat_ssl + feat_domain_emb + feat_judge_emb
33
+
34
+ # SSL/DataDomainEmb/JudgeIdEmb/BLSTM/Projection
35
+ self.wav2vec2 = Wav2Vec2Model()
36
+ self.domain_emb = nn.Parameter(
37
+ data=torch.empty(1, feat_domain_emb), requires_grad=False
38
+ )
39
+ self.judge_emb = nn.Parameter(
40
+ data=torch.empty(1, feat_judge_emb), requires_grad=False
41
+ )
42
+ self.blstm = nn.LSTM(
43
+ input_size=feat_cat,
44
+ hidden_size=feat_rnn_h,
45
+ batch_first=True,
46
+ bidirectional=True,
47
+ )
48
+ self.projection = nn.Sequential(
49
+ nn.Linear(feat_rnn_h * 2, feat_proj_h), nn.ReLU(), nn.Linear(feat_proj_h, 1)
50
+ )
51
+
52
+ def forward(self, wave: Tensor, sr: int) -> Tensor: # pylint: disable=invalid-name
53
+ """wave-to-score :: (B, T) -> (B,)"""
54
+
55
+ # Feature extraction :: (B, T) -> (B, Frame, Feat)
56
+ unit_series = self.wav2vec2(wave)
57
+ bsz, frm, _ = unit_series.size()
58
+
59
+ # DataDomain/JudgeId Embedding's Batch/Time expansion ::
60
+ # (B=1, Feat) -> (B=bsz, Frame=frm, Feat)
61
+ domain_series = self.domain_emb.unsqueeze(1).expand(bsz, frm, -1)
62
+ judge_series = self.judge_emb.unsqueeze(1).expand(bsz, frm, -1)
63
+
64
+ # Feature concatenation :: (B, Frame, Feat=f1) + (B, Frame, Feat=f2) +
65
+ # (B, Frame, Feat=f3) -> (B, Frame, Feat=f1+f2+f3)
66
+ cat_series = torch.cat([unit_series, domain_series, judge_series], dim=2)
67
+
68
+ # Frame-scale score estimation :: (B, Frame, Feat) -> (B, Frame, Feat)
69
+ # -> (B, Frame, Feat=1) - BLSTM/Projection
70
+ feat_series = self.blstm(cat_series)[0]
71
+ score_series = self.projection(feat_series)
72
+
73
+ # Utterance-scale score :: (B, Frame, Feat=1) -> (B, Feat=1)
74
+ # -> (B,) - Time averaging
75
+ utter_score = score_series.mean(dim=1).squeeze(1) * 2 + 3
76
+
77
+ return utter_score
78
+
79
+
80
+ class Wav2Vec2Model(nn.Module):
81
+ """Wav2Vev2."""
82
+
83
+ def __init__(self):
84
+ super().__init__() # pyright: ignore [reportUnknownMemberType]
85
+
86
+ feat_h1, feat_h2 = 512, 768
87
+ feature_enc_layers = (
88
+ [(feat_h1, 10, 5)] + [(feat_h1, 3, 2)] * 4 + [(feat_h1, 2, 2)] * 2
89
+ )
90
+
91
+ self.feature_extractor = ConvFeatureExtractionModel(
92
+ conv_layers=feature_enc_layers
93
+ ) # pyright: ignore [reportGeneralTypeIssues]
94
+ self.layer_norm = nn.LayerNorm(feat_h1)
95
+ self.post_extract_proj = nn.Linear(feat_h1, feat_h2)
96
+ self.dropout_input = nn.Dropout(0.1)
97
+ self.encoder = TransformerEncoder(feat_h2)
98
+
99
+ # Remnants
100
+ self.mask_emb = nn.Parameter(torch.FloatTensor(feat_h2))
101
+
102
+ def forward(self, source: Tensor):
103
+ """FeatureEncoder + ContextTransformer"""
104
+
105
+ # Feature encoding
106
+ features = self.feature_extractor(source)
107
+ features = features.transpose(1, 2)
108
+ features = self.layer_norm(features)
109
+ features = self.post_extract_proj(features)
110
+
111
+ # Context transformer
112
+ x = self.encoder(features)
113
+
114
+ return x
115
+
116
+
117
+ class ConvFeatureExtractionModel(nn.Module):
118
+ """Feature Encoder."""
119
+
120
+ def __init__(self, conv_layers: List[Tuple[int, int, int]]):
121
+ super().__init__() # pyright: ignore [reportUnknownMemberType]
122
+
123
+ def block(
124
+ n_in: int, n_out: int, k: int, stride: int, is_group_norm: bool = False
125
+ ):
126
+ if is_group_norm:
127
+ return nn.Sequential(
128
+ nn.Conv1d(n_in, n_out, k, stride=stride, bias=False),
129
+ nn.Dropout(p=0.0),
130
+ nn.GroupNorm(dim, dim, affine=True),
131
+ nn.GELU(),
132
+ )
133
+ else:
134
+ return nn.Sequential(
135
+ nn.Conv1d(n_in, n_out, k, stride=stride, bias=False),
136
+ nn.Dropout(p=0.0),
137
+ nn.GELU(),
138
+ )
139
+
140
+ in_d = 1
141
+ self.conv_layers = nn.ModuleList()
142
+ for i, params in enumerate(conv_layers):
143
+ (dim, k, stride) = params
144
+ self.conv_layers.append(block(in_d, dim, k, stride, is_group_norm=i == 0))
145
+ in_d = dim
146
+
147
+ def forward(self, series: Tensor) -> Tensor:
148
+ """:: (B, T) -> (B, Feat, Frame)"""
149
+
150
+ series = series.unsqueeze(1)
151
+ for conv in self.conv_layers:
152
+ series = conv(series)
153
+
154
+ return series
155
+
156
+
157
+ class TransformerEncoder(nn.Module):
158
+ """Transformer."""
159
+
160
+ def build_encoder_layer(self, feat: int):
161
+ """Layer builder."""
162
+ return TransformerSentenceEncoderLayer(
163
+ embedding_dim=feat,
164
+ ffn_embedding_dim=3072,
165
+ num_attention_heads=12,
166
+ activation_fn="gelu",
167
+ dropout=0.1,
168
+ attention_dropout=0.1,
169
+ activation_dropout=0.0,
170
+ layer_norm_first=False,
171
+ )
172
+
173
+ def __init__(self, feat: int):
174
+ super().__init__() # pyright: ignore [reportUnknownMemberType]
175
+
176
+ self.required_seq_len_multiple = 2
177
+
178
+ self.pos_conv = nn.Sequential(
179
+ *[
180
+ nn.utils.weight_norm(
181
+ nn.Conv1d(feat, feat, kernel_size=128, padding=128 // 2, groups=16),
182
+ name="weight",
183
+ dim=2,
184
+ ),
185
+ SamePad(128),
186
+ nn.GELU(),
187
+ ]
188
+ )
189
+ self.layer_norm = nn.LayerNorm(feat)
190
+ self.layers = nn.ModuleList([self.build_encoder_layer(feat) for _ in range(12)])
191
+
192
+ def forward(self, x: Tensor) -> Tensor:
193
+
194
+ x_conv = self.pos_conv(x.transpose(1, 2)).transpose(1, 2)
195
+ x = x + x_conv
196
+
197
+ x = self.layer_norm(x)
198
+
199
+ # pad to the sequence length dimension
200
+ x, pad_length = pad_to_multiple(
201
+ x, self.required_seq_len_multiple, dim=-2, value=0
202
+ )
203
+ if pad_length > 0:
204
+ padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
205
+ padding_mask[:, -pad_length:] = True
206
+ else:
207
+ padding_mask, _ = pad_to_multiple(
208
+ None, self.required_seq_len_multiple, dim=-1, value=True
209
+ )
210
+
211
+ # :: (B, T, Feat) -> (T, B, Feat)
212
+ x = x.transpose(0, 1)
213
+ for layer in self.layers:
214
+ x = layer(x, padding_mask)
215
+ # :: (T, B, Feat) -> (B, T, Feat)
216
+ x = x.transpose(0, 1)
217
+
218
+ # undo paddding
219
+ if pad_length > 0:
220
+ x = x[:, :-pad_length]
221
+
222
+ return x
223
+
224
+
225
+ class SamePad(nn.Module):
226
+ """Tail inverse padding."""
227
+
228
+ def __init__(self, kernel_size: int):
229
+ super().__init__() # pyright: ignore [reportUnknownMemberType]
230
+ assert kernel_size % 2 == 0, "`SamePad` now support only even kernel."
231
+
232
+ def forward(self, x: Tensor) -> Tensor:
233
+ return x[:, :, :-1]
234
+
235
+
236
+ def pad_to_multiple(
237
+ x: Optional[Tensor], multiple: int, dim: int = -1, value: float = 0
238
+ ) -> Tuple[Optional[Tensor], int]:
239
+ """Tail padding."""
240
+ if x is None:
241
+ return None, 0
242
+ tsz = x.size(dim)
243
+ m = tsz / multiple
244
+ remainder = math.ceil(m) * multiple - tsz
245
+ if m.is_integer():
246
+ return x, 0
247
+ pad_offset = (0,) * (-1 - dim) * 2
248
+
249
+ return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder
250
+
251
+
252
+ class TransformerSentenceEncoderLayer(nn.Module):
253
+ """Transformer Encoder Layer used in BERT/XLM style pre-trained models."""
254
+
255
+ def __init__(
256
+ self,
257
+ embedding_dim: int,
258
+ ffn_embedding_dim: int,
259
+ num_attention_heads: int,
260
+ activation_fn: str,
261
+ dropout: float,
262
+ attention_dropout: float,
263
+ activation_dropout: float,
264
+ layer_norm_first: bool,
265
+ ) -> None:
266
+ super().__init__() # pyright: ignore [reportUnknownMemberType]
267
+
268
+ assert layer_norm_first is False, "`layer_norm_first` is fixed to `False`"
269
+ assert activation_fn == "gelu", "`activation_fn` is fixed to `gelu`"
270
+
271
+ feat = embedding_dim
272
+
273
+ self.self_attn = MultiheadAttention(
274
+ feat, num_attention_heads, attention_dropout
275
+ )
276
+ self.dropout1 = nn.Dropout(dropout)
277
+ self.dropout2 = nn.Dropout(activation_dropout)
278
+ self.dropout3 = nn.Dropout(dropout)
279
+ self.fc1 = nn.Linear(feat, ffn_embedding_dim)
280
+ self.fc2 = nn.Linear(ffn_embedding_dim, feat)
281
+ self.self_attn_layer_norm = nn.LayerNorm(feat)
282
+ self.final_layer_norm = nn.LayerNorm(feat)
283
+
284
+ def forward(self, x: Tensor, self_attn_padding_mask: Optional[Tensor]):
285
+ # Res[Attn-Do]-LN
286
+ residual = x
287
+ x = self.self_attn(x, x, x, self_attn_padding_mask)
288
+ x = self.dropout1(x)
289
+ x = residual + x
290
+ x = self.self_attn_layer_norm(x)
291
+
292
+ # Res[SegFC-GELU-Do-SegFC-Do]-LN
293
+ residual = x
294
+ x = F.gelu(self.fc1(x)) # pyright: ignore [reportUnknownMemberType]
295
+ x = self.dropout2(x)
296
+ x = self.fc2(x)
297
+ x = self.dropout3(x)
298
+ x = residual + x
299
+ x = self.final_layer_norm(x)
300
+
301
+ return x
302
+
303
+
304
+ class MultiheadAttention(nn.Module):
305
+ """Multi-headed attention."""
306
+
307
+ def __init__(self, embed_dim: int, num_heads: int, dropout: float):
308
+ super().__init__() # pyright: ignore [reportUnknownMemberType]
309
+
310
+ self.embed_dim, self.num_heads, self.p_dropout = embed_dim, num_heads, dropout
311
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
312
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True)
313
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
314
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
315
+
316
+ def forward(
317
+ self,
318
+ query: Tensor,
319
+ key: Tensor,
320
+ value: Tensor,
321
+ key_padding_mask: Optional[Tensor],
322
+ ) -> Tensor:
323
+ """
324
+ Args:
325
+ query :: (T, B, Feat)
326
+ key_padding_mask :: (B, src_len) - mask to exclude keys that are pads
327
+ , where padding elements are indicated by 1s.
328
+ """
329
+ return F.multi_head_attention_forward(
330
+ query=query,
331
+ key=key,
332
+ value=value,
333
+ embed_dim_to_check=self.embed_dim,
334
+ num_heads=self.num_heads,
335
+ in_proj_weight=torch.empty([0]),
336
+ in_proj_bias=torch.cat(
337
+ (self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)
338
+ ),
339
+ bias_k=None,
340
+ bias_v=None,
341
+ add_zero_attn=False,
342
+ dropout_p=self.p_dropout,
343
+ out_proj_weight=self.out_proj.weight,
344
+ out_proj_bias=self.out_proj.bias,
345
+ training=False,
346
+ key_padding_mask=key_padding_mask.bool()
347
+ if key_padding_mask is not None
348
+ else None,
349
+ need_weights=False,
350
+ use_separate_proj_weight=True,
351
+ q_proj_weight=self.q_proj.weight,
352
+ k_proj_weight=self.k_proj.weight,
353
+ v_proj_weight=self.v_proj.weight,
354
+ )[0]
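
A minimal usage sketch (not part of the committed file) of the `pad_to_multiple` helper defined above; the tensor shapes are illustrative assumptions and the import location assumes the helper lives in the model module shown in this diff:

import torch

from zipvoice.eval.models.utmos import pad_to_multiple  # assumed import location

# Pad the time dimension (dim=-2) of a (batch, time, feat) tensor up to a
# multiple of 4, as done at the start of the encoder `forward` above.
x = torch.randn(2, 10, 4)
padded, pad_length = pad_to_multiple(x, multiple=4, dim=-2, value=0)
assert padded.shape == (2, 12, 4) and pad_length == 2
# The appended frames are masked during attention and sliced off again after
# the encoder layers ("undo padding").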
zipvoice/eval/mos/utmos.py ADDED
@@ -0,0 +1,174 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """
19
+ Calculate UTMOS score with automatic Mean Opinion Score (MOS) prediction system
20
+ """
21
+ import argparse
22
+ import logging
23
+ import os
24
+ from typing import List
25
+
26
+ import numpy as np
27
+ import torch
28
+ from tqdm import tqdm
29
+
30
+ from zipvoice.eval.models.utmos import UTMOS22Strong
31
+ from zipvoice.eval.utils import load_waveform
32
+
33
+
34
+ def get_parser() -> argparse.ArgumentParser:
35
+ parser = argparse.ArgumentParser(
36
+ description="Calculate UTMOS score using UTMOS22Strong model."
37
+ )
38
+
39
+ parser.add_argument(
40
+ "--wav-path",
41
+ type=str,
42
+ required=True,
43
+ help="Path to the directory containing evaluated speech files.",
44
+ )
45
+ parser.add_argument(
46
+ "--model-dir",
47
+ type=str,
48
+ required=True,
49
+ help="Local path of our evaluatioin model repository."
50
+ "Download from https://huggingface.co/k2-fsa/TTS_eval_models."
51
+ "Will use 'tts_eval_models/mos/utmos22_strong_step7459_v1.pt'"
52
+ " in this script",
53
+ )
54
+
55
+ parser.add_argument(
56
+ "--extension",
57
+ type=str,
58
+ default="wav",
59
+ help="Extension of the speech files. Default: wav",
60
+ )
61
+ return parser
62
+
63
+
64
+ class UTMOSScore:
65
+ """Predicting UTMOS score for each audio clip."""
66
+
67
+ def __init__(self, model_path: str):
68
+ """
69
+ Initializes the UTMOS score evaluator with the specified model.
70
+
71
+ Args:
72
+ model_path (str): Path of the UTMOS model checkpoint.
73
+ """
74
+ self.sample_rate = 16000
75
+ self.device = (
76
+ torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
77
+ )
78
+ logging.info(f"Using device: {self.device}")
79
+
80
+ # Initialize and load the model
81
+ self.model = UTMOS22Strong()
82
+ try:
83
+ state_dict = torch.load(
84
+ model_path, map_location=lambda storage, loc: storage
85
+ )
86
+ self.model.load_state_dict(state_dict)
87
+ except Exception as e:
88
+ logging.error(f"Failed to load model from {model_path}: {e}")
89
+ raise
90
+
91
+ self.model.to(self.device)
92
+ self.model.eval()
93
+
94
+ @torch.no_grad()
95
+ def score_files(self, wav_paths: List[str]) -> List[float]:
96
+ """
97
+ Computes UTMOS scores for a list of audio files.
98
+
99
+ Args:
100
+ wav_paths (List[str]): List of paths to audio files.
101
+
102
+ Returns:
103
+ List[float]: List of UTMOS scores.
104
+ """
105
+ scores = []
106
+ for wav_path in tqdm(wav_paths, desc="Scoring audio files"):
107
+ # Load and preprocess waveform
108
+ speech = load_waveform(wav_path, self.sample_rate, device=self.device)
109
+ # Compute score
110
+ score = self.model(speech.unsqueeze(0), self.sample_rate)
111
+ scores.append(score.item())
112
+
113
+ return scores
114
+
115
+ def score_dir(self, dir_path: str, extension: str) -> float:
116
+ """
117
+ Computes the average UTMOS score for all files in a directory.
118
+
119
+ Args:
120
+ dir_path (str): Path to the directory containing audio files.
121
+
122
+ Returns:
123
+ float: Average UTMOS score for the directory.
124
+ """
125
+ logging.info(f"Calculating UTMOS score for {dir_path}")
126
+
127
+ # Get list of wav files
128
+ wav_files = [
129
+ os.path.join(dir_path, f)
130
+ for f in os.listdir(dir_path)
131
+ if f.lower().endswith(extension)
132
+ ]
133
+
134
+ if not wav_files:
135
+ raise ValueError(f"No audio files found in {dir_path}")
136
+
137
+ # Compute scores
138
+ scores = self.score_files(wav_files)
139
+
140
+ return float(np.mean(scores))
141
+
142
+
143
+ if __name__ == "__main__":
144
+
145
+ torch.set_num_threads(1)
146
+ torch.set_num_interop_threads(1)
147
+
148
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
149
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
150
+
151
+ parser = get_parser()
152
+ args = parser.parse_args()
153
+
154
+ # Validate input path
155
+ if not os.path.isdir(args.wav_path):
156
+ logging.error(f"Invalid directory: {args.wav_path}")
157
+ exit(1)
158
+
159
+ # Initialize evaluator
160
+ model_path = os.path.join(args.model_dir, "mos/utmos22_strong_step7459_v1.pt")
161
+ if not os.path.exists(model_path):
162
+ logging.error(
163
+ "Please download evaluation models from "
164
+ "https://huggingface.co/k2-fsa/TTS_eval_models"
165
+ " and pass this dir with --model-dir"
166
+ )
167
+ exit(1)
168
+ utmos_evaluator = UTMOSScore(model_path)
169
+
170
+ # Compute UTMOS score
171
+ score = utmos_evaluator.score_dir(args.wav_path, args.extension)
172
+ print("-" * 50)
173
+ logging.info(f"UTMOS score: {score:.2f}")
174
+ print("-" * 50)
zipvoice/eval/speaker_similarity/cpsim.py ADDED
@@ -0,0 +1,411 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """
19
+ Computes concatenated maximum permutation speaker similarity (cpSIM) scores using:
20
+ - A WavLM-based ECAPA-TDNN model for speaker embedding extraction.
21
+ - A pyannote pipeline for speaker diarization (segmenting speakers).
22
+ """
23
+ import argparse
24
+ import logging
25
+ import os
26
+ import warnings
27
+ from typing import List, Tuple
28
+
29
+ import numpy as np
30
+ import torch
31
+ import torch.nn.functional as F
32
+ from pyannote.audio import Pipeline
33
+ from tqdm import tqdm
34
+
35
+ from zipvoice.eval.models.ecapa_tdnn_wavlm import ECAPA_TDNN_WAVLM
36
+ from zipvoice.eval.utils import load_waveform
37
+
38
+ warnings.filterwarnings("ignore")
39
+
40
+
41
+ def get_parser() -> argparse.ArgumentParser:
42
+ parser = argparse.ArgumentParser(
43
+ description="Calculate concatenated maximum permutation speaker "
44
+ "similarity (cpSIM) score."
45
+ )
46
+ parser.add_argument(
47
+ "--wav-path",
48
+ type=str,
49
+ required=True,
50
+ help="Path to the directory containing evaluated speech files.",
51
+ )
52
+ parser.add_argument(
53
+ "--test-list",
54
+ type=str,
55
+ help="Path to the tsv file for speaker splitted prompts. "
56
+ "Each line contains (audio_name, prompt_text_1, prompt_text_2, "
57
+ "prompt_audio_1, prompt_audio_2, text) separated by tabs.",
58
+ )
59
+
60
+ parser.add_argument(
61
+ "--test-list-merge",
62
+ type=str,
63
+ help="Path to the tsv file for merged dialogue prompts. "
64
+ "Each line contains (audio_name, prompt_text_dialogue, "
65
+ "prompt_audio_dialogue, text) separated by tabs.",
66
+ )
67
+ parser.add_argument(
68
+ "--model-dir",
69
+ type=str,
70
+ required=True,
71
+ help="Local path of our evaluatioin model repository."
72
+ "Download from https://huggingface.co/k2-fsa/TTS_eval_models."
73
+ "Will use 'tts_eval_models/speaker_similarity/wavlm_large_finetune.pth'"
74
+ ", 'tts_eval_models/speaker_similarity/wavlm_large/' and "
75
+ "tts_eval_models/speaker_similarity/pyannote/ in this script",
76
+ )
77
+
78
+ parser.add_argument(
79
+ "--extension",
80
+ type=str,
81
+ default="wav",
82
+ help="Extension of the speech files. Default: wav",
83
+ )
84
+ return parser
85
+
86
+
87
+ class CpSpeakerSimilarity:
88
+ """
89
+ Computes concatenated maximum permutation speaker similarity (cpSIM) scores using:
90
+ - A WavLM-based ECAPA-TDNN model for speaker embedding extraction.
91
+ - A pyannote pipeline for speaker diarization (segmenting speakers).
92
+ """
93
+
94
+ def __init__(
95
+ self,
96
+ sv_model_path: str = "speaker_similarity/wavlm_large_finetune.pth",
97
+ ssl_model_path: str = "speaker_similarity/wavlm_large/",
98
+ pyannote_model_path: str = "speaker_similarity/pyannote/",
99
+ ):
100
+ """
101
+ Initializes the cpSIM evaluator with the specified models.
102
+
103
+ Args:
104
+ sv_model_path (str): Path of the wavlm-based ECAPA-TDNN model checkpoint.
105
+ ssl_model_path (str): Path of the wavlm SSL model directory.
106
+ pyannote_model_path (str): Path of the pyannote diarization model directory.
107
+ """
108
+ self.sample_rate = 16000
109
+ self.device = (
110
+ torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
111
+ )
112
+ logging.info(f"Using device: {self.device}")
113
+
114
+ # Initialize speaker verification model
115
+ self.sv_model = ECAPA_TDNN_WAVLM(
116
+ feat_dim=1024,
117
+ channels=512,
118
+ emb_dim=256,
119
+ sr=self.sample_rate,
120
+ ssl_model_path=ssl_model_path,
121
+ )
122
+ state_dict = torch.load(
123
+ sv_model_path, map_location=lambda storage, loc: storage
124
+ )
125
+ self.sv_model.load_state_dict(state_dict["model"], strict=False)
126
+ self.sv_model.to(self.device)
127
+ self.sv_model.eval()
128
+
129
+ # Initialize diarization pipeline
130
+ self.diarization_pipeline = Pipeline.from_pretrained(
131
+ os.path.join(pyannote_model_path, "pyannote_diarization_config.yaml")
132
+ )
133
+ self.diarization_pipeline.to(self.device)
134
+
135
+ @torch.no_grad()
136
+ def get_embeddings_with_diarization(
137
+ self, audio_paths: List[str]
138
+ ) -> List[List[torch.Tensor]]:
139
+ """
140
+ Extracts speaker embeddings from audio files
141
+ with speaker diarization (for 2-speaker conversations).
142
+
143
+ Args:
144
+ audio_paths: List of paths to audio files (each containing 2 speakers).
145
+
146
+ Returns:
147
+ List of embedding pairs, where each pair is
148
+ [embedding_speaker1, embedding_speaker2].
149
+ """
150
+
151
+ embeddings_list = []
152
+ for audio_path in tqdm(
153
+ audio_paths, desc="Extracting embeddings with diarization"
154
+ ):
155
+ # Load audio waveform
156
+ speech = load_waveform(
157
+ audio_path, self.sample_rate, device=self.device, max_seconds=120
158
+ )
159
+
160
+ # Perform speaker diarization (assumes 2 speakers)
161
+ diarization = self.diarization_pipeline(
162
+ {"waveform": speech.unsqueeze(0), "sample_rate": self.sample_rate},
163
+ num_speakers=2,
164
+ )
165
+
166
+ # Collect speech chunks for each speaker
167
+ speaker1_chunks = []
168
+ speaker2_chunks = []
169
+ for turn, _, speaker in diarization.itertracks(yield_label=True):
170
+ start_frame = int(turn.start * self.sample_rate)
171
+ end_frame = int(turn.end * self.sample_rate)
172
+ chunk = speech[start_frame:end_frame]
173
+
174
+ if speaker == "SPEAKER_00":
175
+ speaker1_chunks.append(chunk)
176
+ elif speaker == "SPEAKER_01":
177
+ speaker2_chunks.append(chunk)
178
+
179
+ # Handle cases where diarization fails to detect 2 speakers
180
+ if not (speaker1_chunks and speaker2_chunks):
181
+ logging.debug(
182
+ f"Insufficient speaker chunks in {audio_path} "
183
+ f"using full audio for both speakers"
184
+ )
185
+ speaker1_speech = speech
186
+ speaker2_speech = speech
187
+ else:
188
+ speaker1_speech = torch.cat(speaker1_chunks, dim=0)
189
+ speaker2_speech = torch.cat(speaker2_chunks, dim=0)
190
+
191
+ # Extract embeddings with no gradient computation
192
+ try:
193
+ emb_speaker1 = self.sv_model([speaker1_speech])
194
+ emb_speaker2 = self.sv_model([speaker2_speech])
195
+ except Exception as e:
196
+ logging.debug(
197
+ f"Encountered an error {e} when extracting embeddings with "
198
+ f"segmented speech, will use full audio for both speakers."
199
+ )
200
+ emb_speaker1 = self.sv_model([speech])
201
+ emb_speaker2 = self.sv_model([speech])
202
+
203
+ embeddings_list.append([emb_speaker1, emb_speaker2])
204
+
205
+ return embeddings_list
206
+
207
+ @torch.no_grad()
208
+ def get_embeddings_from_pairs(
209
+ self, audio_pairs: List[Tuple[str, str]]
210
+ ) -> List[List[torch.Tensor]]:
211
+ """
212
+ Extracts speaker embeddings from pairs of single-speaker audio files.
213
+
214
+ Args:
215
+ audio_pairs: List of tuples (path_speaker1, path_speaker2).
216
+
217
+ Returns:
218
+ List of embedding pairs, where each pair is
219
+ [embedding_speaker1, embedding_speaker2].
220
+ """
221
+ embeddings_list = []
222
+ for (path1, path2) in tqdm(
223
+ audio_pairs, desc="Extracting embeddings from pairs"
224
+ ):
225
+ # Load audio for each speaker
226
+ speech1 = load_waveform(path1, self.sample_rate, device=self.device)
227
+ speech2 = load_waveform(path2, self.sample_rate, device=self.device)
228
+
229
+ # Extract embeddings
230
+ emb_speaker1 = self.sv_model([speech1])
231
+ emb_speaker2 = self.sv_model([speech2])
232
+
233
+ embeddings_list.append([emb_speaker1, emb_speaker2])
234
+
235
+ return embeddings_list
236
+
237
+ def score(
238
+ self,
239
+ wav_path: str,
240
+ extension: str,
241
+ test_list: str,
242
+ prompt_mode: str,
243
+ ) -> float:
244
+ """
245
+ Computes the cpSIM score by comparing embeddings of prompt and evaluated speech.
246
+
247
+ Args:
248
+ wav_path: Directory containing evaluated speech files.
249
+ test_list: Path to test list file mapping evaluated files to prompts.
250
+ prompt_mode: Either "merge" (2-speaker prompt) or "split"
251
+ (two single-speaker prompts).
252
+
253
+ Returns:
254
+ Average cpSIM score across all test pairs.
255
+ """
256
+ logging.info(f"Calculating cpSIM score for {wav_path} (mode: {prompt_mode})")
257
+
258
+ # Load and parse test list
259
+ try:
260
+ with open(test_list, "r", encoding="utf-8") as f:
261
+ lines = [line.strip() for line in f if line.strip()]
262
+ except Exception as e:
263
+ logging.error(f"Failed to read test list {test_list}: {e}")
264
+ raise
265
+
266
+ if not lines:
267
+ raise ValueError(f"Test list {test_list} is empty")
268
+
269
+ # Collect valid prompt-eval audio pairs
270
+ prompt_audios = [] # For "merge": [path]; for "split": [(path1, path2)]
271
+ eval_audios = []
272
+
273
+ for line_num, line in enumerate(lines, 1):
274
+ parts = line.split("\t")
275
+ if prompt_mode == "merge":
276
+ if len(parts) != 4:
277
+ raise ValueError(f"Expected 4 columns, got {len(parts)}")
278
+ audio_name, prompt_text, prompt_audio, text = parts
279
+ eval_audio_path = os.path.join(wav_path, f"{audio_name}.{extension}")
280
+ prompt_audios.append(prompt_audio)
281
+
282
+ elif prompt_mode == "split":
283
+ if len(parts) != 6:
284
+ raise ValueError(f"Expected 6 columns, got {len(parts)}")
285
+ (
286
+ audio_name,
287
+ prompt_text1,
288
+ prompt_text2,
289
+ prompt_audio_1,
290
+ prompt_audio_2,
291
+ text,
292
+ ) = parts
293
+ eval_audio_path = os.path.join(wav_path, f"{audio_name}.{extension}")
294
+ prompt_audios.append((prompt_audio_1, prompt_audio_2))
295
+
296
+ else:
297
+ raise ValueError(f"Invalid prompt_mode: {prompt_mode}")
298
+
299
+ # Validate file existence
300
+ if not os.path.exists(eval_audio_path):
301
+ raise FileNotFoundError(f"Evaluated file not found: {eval_audio_path}")
302
+
303
+ if prompt_mode == "merge":
304
+ if not os.path.exists(prompt_audio):
305
+ raise FileNotFoundError(
306
+ f"Prompt merge file not found: {prompt_audio}"
307
+ )
308
+ else:
309
+ if not (
310
+ os.path.exists(prompt_audio_1) and os.path.exists(prompt_audio_2)
311
+ ):
312
+ raise FileNotFoundError(
313
+ f"One or more prompt files missing in {prompt_audio_1}, "
314
+ f"{prompt_audio_2}"
315
+ )
316
+
317
+ eval_audios.append(eval_audio_path)
318
+
319
+ if not prompt_audios or not eval_audios:
320
+ raise ValueError(f"No valid prompt-eval pairs found in {test_list}")
321
+
322
+ logging.info(f"Processing {len(prompt_audios)} valid test pairs")
323
+
324
+ # Extract embeddings for prompts and evaluations
325
+ if prompt_mode == "merge":
326
+ prompt_embeddings = self.get_embeddings_with_diarization(prompt_audios)
327
+ else:
328
+ prompt_embeddings = self.get_embeddings_from_pairs(prompt_audios)
329
+
330
+ eval_embeddings = self.get_embeddings_with_diarization(eval_audios)
331
+
332
+ if len(prompt_embeddings) != len(eval_embeddings):
333
+ raise RuntimeError(
334
+ f"Mismatch: {len(prompt_embeddings)} prompt vs "
335
+ f" {len(eval_embeddings)} eval embeddings"
336
+ )
337
+
338
+ # Calculate maximum permutation similarity scores
339
+ scores = []
340
+ for prompt_embs, eval_embs in zip(prompt_embeddings, eval_embeddings):
341
+ # Prompt and eval each have 2 embeddings: [emb1, emb2]
342
+ sim1 = F.cosine_similarity(
343
+ prompt_embs[0], eval_embs[0], dim=-1
344
+ ) + F.cosine_similarity(prompt_embs[1], eval_embs[1], dim=-1)
345
+ sim2 = F.cosine_similarity(
346
+ prompt_embs[0], eval_embs[1], dim=-1
347
+ ) + F.cosine_similarity(prompt_embs[1], eval_embs[0], dim=-1)
348
+ max_sim = torch.max(sim1, sim2).item() / 2 # Average the sum
349
+ scores.append(max_sim)
350
+
351
+ return float(np.mean(scores))
352
+
353
+
354
+ if __name__ == "__main__":
355
+
356
+ torch.set_num_threads(1)
357
+ torch.set_num_interop_threads(1)
358
+
359
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
360
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
361
+
362
+ parser = get_parser()
363
+ args = parser.parse_args()
364
+
365
+ # Validate test list arguments
366
+ if not (args.test_list or args.test_list_merge):
367
+ raise ValueError("Either --test-list or --test-list-merge must be provided")
368
+ if args.test_list and args.test_list_merge:
369
+ raise ValueError(
370
+ "Only one of --test-list-split or --test-list-merge can be provided"
371
+ )
372
+ # Determine mode and test list
373
+ if args.test_list:
374
+ prompt_mode = "split"
375
+ test_list = args.test_list
376
+ else:
377
+ prompt_mode = "merge"
378
+ test_list = args.test_list_merge
379
+
380
+ # Initialize evaluator
381
+ sv_model_path = os.path.join(
382
+ args.model_dir, "speaker_similarity/wavlm_large_finetune.pth"
383
+ )
384
+ ssl_model_path = os.path.join(args.model_dir, "speaker_similarity/wavlm_large/")
385
+ pyannote_model_path = os.path.join(args.model_dir, "speaker_similarity/pyannote/")
386
+ if (
387
+ not os.path.exists(sv_model_path)
388
+ or not os.path.exists(ssl_model_path)
389
+ or not os.path.exists(pyannote_model_path)
390
+ ):
391
+ logging.error(
392
+ "Please download evaluation models from "
393
+ "https://huggingface.co/k2-fsa/TTS_eval_models"
394
+ " and pass this dir with --model-dir"
395
+ )
396
+ exit(1)
397
+ cp_sim = CpSpeakerSimilarity(
398
+ sv_model_path=sv_model_path,
399
+ ssl_model_path=ssl_model_path,
400
+ pyannote_model_path=pyannote_model_path,
401
+ )
402
+ # Compute similarity score
403
+ score = cp_sim.score(
404
+ wav_path=args.wav_path,
405
+ extension=args.extension,
406
+ test_list=test_list,
407
+ prompt_mode=prompt_mode,
408
+ )
409
+ print("-" * 50)
410
+ logging.info(f"cpSIM score: {score:.3f}")
411
+ print("-" * 50)
zipvoice/eval/speaker_similarity/sim.py ADDED
@@ -0,0 +1,229 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2025 Xiaomi Corp. (authors: Han Zhu
3
+ # Wei Kang)
4
+ #
5
+ # See ../../../../LICENSE for clarification regarding multiple authors
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+ """
20
+ Computes speaker similarity (SIM-o) using a WavLM-based
21
+ ECAPA-TDNN speaker verification model.
22
+ """
23
+ import argparse
24
+ import logging
25
+ import os
26
+ import warnings
27
+ from typing import List
28
+
29
+ import numpy as np
30
+ import torch
31
+ from tqdm import tqdm
32
+
33
+ from zipvoice.eval.models.ecapa_tdnn_wavlm import ECAPA_TDNN_WAVLM
34
+ from zipvoice.eval.utils import load_waveform
35
+
36
+ warnings.filterwarnings("ignore")
37
+
38
+
39
+ def get_parser() -> argparse.ArgumentParser:
40
+ parser = argparse.ArgumentParser(
41
+ description="Calculate speaker similarity (SIM-o) score."
42
+ )
43
+
44
+ parser.add_argument(
45
+ "--wav-path",
46
+ type=str,
47
+ required=True,
48
+ help="Path to the directory containing evaluated speech files.",
49
+ )
50
+ parser.add_argument(
51
+ "--test-list",
52
+ type=str,
53
+ required=True,
54
+ help="Path to the file list that contains the correspondence between prompts "
55
+ "and evaluated speech. Each line contains (audio_name, prompt_text_1, "
56
+ "prompt_text_2, prompt_audio_1, prompt_audio_2, text) separated by tabs.",
57
+ )
58
+ parser.add_argument(
59
+ "--model-dir",
60
+ type=str,
61
+ required=True,
62
+ help="Local path of our evaluatioin model repository."
63
+ "Download from https://huggingface.co/k2-fsa/TTS_eval_models."
64
+ "Will use 'tts_eval_models/speaker_similarity/wavlm_large_finetune.pth'"
65
+ "and 'tts_eval_models/speaker_similarity/wavlm_large/' in this script",
66
+ )
67
+
68
+ parser.add_argument(
69
+ "--extension",
70
+ type=str,
71
+ default="wav",
72
+ help="Extension of the speech files. Default: wav",
73
+ )
74
+ return parser
75
+
76
+
77
+ class SpeakerSimilarity:
78
+ """
79
+ Computes speaker similarity (SIM-o) using a WavLM-based
80
+ ECAPA-TDNN speaker verification model.
81
+ """
82
+
83
+ def __init__(
84
+ self,
85
+ sv_model_path: str = "speaker_similarity/wavlm_large_finetune.pth",
86
+ ssl_model_path: str = "speaker_similarity/wavlm_large/",
87
+ ):
88
+ """
89
+ Initializes the speaker similarity evaluator with the specified models.
90
+
91
+ Args:
92
+ sv_model_path (str): Path of the wavlm-based ECAPA-TDNN model checkpoint.
93
+ ssl_model_path (str): Path of the wavlm SSL model directory.
94
+ """
95
+ self.sample_rate = 16000
96
+ self.device = (
97
+ torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
98
+ )
99
+ logging.info(f"Using device: {self.device}")
100
+ self.model = ECAPA_TDNN_WAVLM(
101
+ feat_dim=1024,
102
+ channels=512,
103
+ emb_dim=256,
104
+ sr=self.sample_rate,
105
+ ssl_model_path=ssl_model_path,
106
+ )
107
+ state_dict = torch.load(
108
+ sv_model_path, map_location=lambda storage, loc: storage
109
+ )
110
+ self.model.load_state_dict(state_dict["model"], strict=False)
111
+ self.model.to(self.device)
112
+ self.model.eval()
113
+
114
+ @torch.no_grad()
115
+ def get_embeddings(self, wav_paths: List[str]) -> List[torch.Tensor]:
116
+ """
117
+ Extracts speaker embeddings from a list of audio files.
118
+
119
+ Args:
120
+ wav_paths (List[str]): List of paths to audio files.
121
+
122
+ Returns:
123
+ List[torch.Tensor]: List of speaker embeddings.
124
+ """
125
+ embeddings = []
126
+ for wav_path in tqdm(wav_paths, desc="Extracting speaker embeddings"):
127
+ # Load and preprocess waveform
128
+ speech = load_waveform(
129
+ wav_path, self.sample_rate, device=self.device, max_seconds=120
130
+ )
131
+ # Extract embedding
132
+ embedding = self.model([speech])
133
+ embeddings.append(embedding)
134
+
135
+ return embeddings
136
+
137
+ def score(self, wav_path: str, extension: str, test_list: str) -> float:
138
+ """
139
+ Computes the Speaker Similarity (SIM-o) score between reference and
140
+ evaluated speech.
141
+
142
+ Args:
143
+ wav_path (str): Path to the directory containing evaluated speech files.
144
+ test_list (str): Path to the test list file mapping evaluated files
145
+ to reference prompts.
146
+
147
+ Returns:
148
+ float: Average similarity score between reference and evaluated embeddings.
149
+ """
150
+ logging.info(f"Calculating Speaker Similarity (SIM-o) score for {wav_path}")
151
+ # Read test pairs
152
+ try:
153
+ with open(test_list, "r", encoding="utf-8") as f:
154
+ lines = [line.strip().split("\t") for line in f if line.strip()]
155
+ except Exception as e:
156
+ logging.error(f"Failed to read test list: {e}")
157
+ raise
158
+
159
+ if not lines:
160
+ raise ValueError(f"Test list {test_list} is empty or malformed")
161
+ # Parse test pairs
162
+ prompt_wavs = []
163
+ eval_wavs = []
164
+ for line in lines:
165
+ if len(line) != 4:
166
+ raise ValueError(f"Invalid line: {line}")
167
+ wav_name, prompt_text, prompt_wav, text = line
168
+ eval_wav_path = os.path.join(wav_path, f"{wav_name}.{extension}")
169
+ # Validate file existence
170
+ if not os.path.exists(prompt_wav):
171
+ raise FileNotFoundError(f"Prompt file not found: {prompt_wav}")
172
+ if not os.path.exists(eval_wav_path):
173
+ raise FileNotFoundError(f"Evaluated file not found: {eval_wav_path}")
174
+ prompt_wavs.append(prompt_wav)
175
+ eval_wavs.append(eval_wav_path)
176
+ logging.info(f"Found {len(prompt_wavs)} valid test pairs")
177
+ # Extract embeddings
178
+
179
+ prompt_embeddings = self.get_embeddings(prompt_wavs)
180
+ eval_embeddings = self.get_embeddings(eval_wavs)
181
+
182
+ if len(prompt_embeddings) != len(eval_embeddings):
183
+ raise RuntimeError(
184
+ f"Mismatch: {len(prompt_embeddings)} prompt vs "
185
+ f" {len(eval_embeddings)} eval embeddings"
186
+ )
187
+
188
+ # Calculate similarity scores
189
+ scores = []
190
+ for prompt_emb, eval_emb in zip(prompt_embeddings, eval_embeddings):
191
+ # Compute cosine similarity
192
+ similarity = torch.nn.functional.cosine_similarity(
193
+ prompt_emb, eval_emb, dim=-1
194
+ )
195
+ scores.append(similarity.item())
196
+
197
+ return float(np.mean(scores))
198
+
199
+
200
+ if __name__ == "__main__":
201
+
202
+ torch.set_num_threads(1)
203
+ torch.set_num_interop_threads(1)
204
+
205
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
206
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
207
+
208
+ parser = get_parser()
209
+ args = parser.parse_args()
210
+ # Initialize evaluator
211
+ sv_model_path = os.path.join(
212
+ args.model_dir, "speaker_similarity/wavlm_large_finetune.pth"
213
+ )
214
+ ssl_model_path = os.path.join(args.model_dir, "speaker_similarity/wavlm_large/")
215
+ if not os.path.exists(sv_model_path) or not os.path.exists(ssl_model_path):
216
+ logging.error(
217
+ "Please download evaluation models from "
218
+ "https://huggingface.co/k2-fsa/TTS_eval_models"
219
+ " and pass this dir with --model-dir"
220
+ )
221
+ exit(1)
222
+ sim_evaluator = SpeakerSimilarity(
223
+ sv_model_path=sv_model_path, ssl_model_path=ssl_model_path
224
+ )
225
+ # Compute similarity score
226
+ score = sim_evaluator.score(args.wav_path, args.extension, args.test_list)
227
+ print("-" * 50)
228
+ logging.info(f"SIM-o score: {score:.3f}")
229
+ print("-" * 50)
zipvoice/eval/utils.py ADDED
@@ -0,0 +1,62 @@
1
+ import logging
2
+
3
+ import librosa
4
+ import soundfile as sf
5
+ import torch
6
+
7
+
8
+ def load_waveform(
9
+ fname: str,
10
+ sample_rate: int,
11
+ dtype: str = "float32",
12
+ device: torch.device = torch.device("cpu"),
13
+ return_numpy: bool = False,
14
+ max_seconds: float = None,
15
+ ) -> torch.Tensor:
16
+ """
17
+ Load an audio file, preprocess it, and convert to a PyTorch tensor.
18
+
19
+ Args:
20
+ fname (str): Path to the audio file.
21
+ sample_rate (int): Target sample rate for resampling.
22
+ dtype (str, optional): Data type to load audio as (default: "float32").
23
+ device (torch.device, optional): Device to place the resulting tensor
24
+ on (default: CPU).
25
+ return_numpy (bool): If True, returns a NumPy array instead of a
26
+ PyTorch tensor.
27
+ max_seconds (float, optional): Maximum length (in seconds) of the audio.
28
+ If the audio is longer than this, it will be truncated.
29
+
30
+ Returns:
31
+ torch.Tensor: Processed audio waveform as a PyTorch tensor,
32
+ with shape (num_samples,).
33
+
34
+ Notes:
35
+ - If the audio is stereo, it will be converted to mono by averaging channels.
36
+ - If the audio's sample rate differs from the target, it will be resampled.
37
+ """
38
+ # Load audio file with specified data type
39
+ wav_data, sr = sf.read(fname, dtype=dtype)
40
+
41
+ # Convert stereo to mono if necessary
42
+ if len(wav_data.shape) == 2:
43
+ wav_data = wav_data.mean(1)
44
+
45
+ # Resample to target sample rate if needed
46
+ if sr != sample_rate:
47
+ wav_data = librosa.resample(wav_data, orig_sr=sr, target_sr=sample_rate)
48
+
49
+ if max_seconds is not None:
50
+ # Trim to max length
51
+ max_length = sample_rate * max_seconds
52
+ if len(wav_data) > max_length:
53
+ wav_data = wav_data[:max_length]
54
+ logging.warning(
55
+ f"Wav file {fname} is longer than 2 minutes, "
56
+ f"truncated to 2 minutes to avoid OOM."
57
+ )
58
+ if return_numpy:
59
+ return wav_data
60
+ else:
61
+ wav_data = torch.from_numpy(wav_data)
62
+ return wav_data.to(device)
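
A small usage sketch of `load_waveform` above; "sample.wav" is a placeholder file name:

import torch

from zipvoice.eval.utils import load_waveform

wav = load_waveform(
    "sample.wav",                 # placeholder path
    sample_rate=16000,            # mono output, resampled if needed
    device=torch.device("cpu"),
    max_seconds=120,              # longer clips are truncated with a warning
)
print(wav.shape)                  # 1-D tensor of shape (num_samples,)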
zipvoice/eval/wer/dialog.py ADDED
@@ -0,0 +1,493 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """
19
+ Computes WER or cpWER for English dialogue speech with WhisperD
20
+ or WER for Chinese speech with Paraformer.
21
+ """
22
+
23
+ import argparse
24
+ import logging
25
+ import os
26
+ import re
27
+ import string
28
+ from typing import List, Tuple
29
+
30
+ import numpy as np
31
+ import torch
32
+ import zhconv
33
+ from funasr import AutoModel
34
+ from jiwer import compute_measures
35
+ from tqdm import tqdm
36
+ from transformers import (
37
+ WhisperForConditionalGeneration,
38
+ WhisperProcessor,
39
+ WhisperTokenizer,
40
+ pipeline,
41
+ )
42
+ from zhon.hanzi import punctuation
43
+
44
+ from zipvoice.eval.utils import load_waveform
45
+
46
+
47
+ def get_parser():
48
+ parser = argparse.ArgumentParser(
49
+ description="Computes WER or cpWER for English dialogue speech"
50
+ " with WhisperD or compute WER for Chinese with Paraformer.",
51
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
52
+ )
53
+
54
+ parser.add_argument(
55
+ "--wav-path",
56
+ type=str,
57
+ required=True,
58
+ help="Path to the directory containing speech files.",
59
+ )
60
+
61
+ parser.add_argument(
62
+ "--extension",
63
+ type=str,
64
+ default="wav",
65
+ help="Extension of the speech files. Default: wav",
66
+ )
67
+
68
+ parser.add_argument(
69
+ "--decode-path",
70
+ type=str,
71
+ default=None,
72
+ help="Path to the output file where WER information will be saved. "
73
+ "If not provided, results are only printed to console.",
74
+ )
75
+ parser.add_argument(
76
+ "--model-dir",
77
+ type=str,
78
+ required=True,
79
+ help="Local path of evaluation models repository. "
80
+ "Download from https://huggingface.co/k2-fsa/TTS_eval_models. "
81
+ "This script expects 'tts_eval_models/wer/whisper-d-v1a/' for English "
82
+ "and 'tts_eval_models/wer/paraformer-zh/' for Chinese within this directory.",
83
+ )
84
+ parser.add_argument(
85
+ "--test-list",
86
+ type=str,
87
+ default="test.tsv",
88
+ help="Path to the tsv file for speaker splitted prompts. "
89
+ "Each line contains (audio_name, prompt_text_1, prompt_text_2, "
90
+ "prompt_audio_1, prompt_audio_2, text) separated by tabs.",
91
+ )
92
+ parser.add_argument(
93
+ "--lang",
94
+ type=str,
95
+ choices=["zh", "en"],
96
+ required=True,
97
+ help="Language of the audio and transcripts for "
98
+ "decoding ('zh' for Chinese or 'en' for English).",
99
+ )
100
+ parser.add_argument(
101
+ "--cpwer",
102
+ action="store_true",
103
+ help="whether to compute the cpWER",
104
+ )
105
+ return parser
106
+
107
+
108
+ def load_en_model(model_dir, device):
109
+ model_path = os.path.join(model_dir, "wer/whisper-d-v1a/")
110
+ if not os.path.exists(model_path):
111
+ logging.error(
112
+ f"Error: Whisper model not found at {model_path}. "
113
+ "Please download evaluation modelss from "
114
+ "https://huggingface.co/k2-fsa/TTS_eval_models "
115
+ "and pass this directory with --model-dir."
116
+ )
117
+ exit(1)
118
+ logging.info(f"Loading Whisper model from: {model_path}")
119
+ processor = WhisperProcessor.from_pretrained(model_path)
120
+ tokenizer = WhisperTokenizer.from_pretrained(model_path)
121
+ model = WhisperForConditionalGeneration.from_pretrained(
122
+ model_path, torch_dtype=torch.float16
123
+ )
124
+
125
+ model.generation_config.suppress_tokens = None
126
+ model.generation_config.forced_decoder_ids = None
127
+ # Use a pipeline to handle long audio inputs
128
+ pipe = pipeline(
129
+ "automatic-speech-recognition",
130
+ model=model,
131
+ tokenizer=tokenizer,
132
+ feature_extractor=processor.feature_extractor,
133
+ chunk_length_s=30,
134
+ device=device,
135
+ )
136
+ return pipe
137
+
138
+
139
+ def load_zh_model(model_dir):
140
+ model_path = os.path.join(model_dir, "wer/paraformer-zh/")
141
+ if not os.path.exists(model_path):
142
+ logging.error(
143
+ f"Error: Paraformer model not found at {model_path}. "
144
+ "Please download evaluation modelss from "
145
+ "https://huggingface.co/k2-fsa/TTS_eval_models "
146
+ "and pass this directory with --model-dir."
147
+ )
148
+ exit(1)
149
+ logging.info(f"Loading Paraformer model from: {model_path}")
150
+ model = AutoModel(model=model_path, disable_update=True)
151
+ return model
152
+
153
+
154
+ def post_process(text: str, lang: str) -> str:
155
+ """
156
+ Cleans and normalizes text for WER calculation.
157
+ Args:
158
+ text (str): The input text to be processed.
159
+ lang (str): The language of the input text.
160
+
161
+ Returns:
162
+ str: The cleaned and normalized text.
163
+ """
164
+ punctuation_all = punctuation + string.punctuation
165
+ text = re.sub(r"\[.*?\]|<.*?>|\(.*?\)", "", text)
166
+ for x in punctuation_all:
167
+ if x == "'":
168
+ continue
169
+ text = text.replace(x, "")
170
+ text = re.sub(r"\s+", " ", text).strip()
171
+ if lang == "zh":
172
+ text = " ".join([x for x in text])
173
+ elif lang == "en":
174
+ text = text.lower()
175
+ else:
176
+ raise NotImplementedError
177
+ return text
178
+
179
+
180
+ def process_one(hypothesis: str, truth: str, lang: str) -> tuple:
181
+ """
182
+ Computes WER and related metrics for a single hypothesis-truth pair.
183
+
184
+ Args:
185
+ hypothesis (str): The transcribed text from the ASR model.
186
+ truth (str): The ground truth transcript.
187
+
188
+ Returns:
189
+ tuple: A tuple containing:
190
+ - truth (str): Post-processed ground truth text.
191
+ - hypothesis (str): Post-processed hypothesis text.
192
+ - wer (float): Word Error Rate.
193
+ - substitutions (int): Number of substitutions.
194
+ - deletions (int): Number of deletions.
195
+ - insertions (int): Number of insertions.
196
+ - word_num (int): Number of words in the post-processed ground truth.
197
+ """
198
+ truth_processed = post_process(truth, lang)
199
+ hypothesis_processed = post_process(hypothesis, lang)
200
+
201
+ measures = compute_measures(truth_processed, hypothesis_processed)
202
+ word_num = len(truth_processed.split(" "))
203
+
204
+ return (
205
+ truth_processed,
206
+ hypothesis_processed,
207
+ measures["wer"],
208
+ measures["substitutions"],
209
+ measures["deletions"],
210
+ measures["insertions"],
211
+ word_num,
212
+ )
213
+
214
+
215
+ def process_one_cpwer(hypothesis: str, truth: str, lang: str) -> tuple:
216
+ """
217
+ Computes cpWER and related metrics for a single hypothesis-truth pair.
218
+
219
+ Args:
220
+ hypothesis (str): The transcribed text from the ASR model.
221
+ truth (str): The ground truth transcript.
222
+
223
+ Returns:
224
+ tuple: A tuple containing:
225
+ - truth (str): Post-processed ground truth text.
226
+ - hypothesis (str): Post-processed hypothesis text.
227
+ - wer (float): Word Error Rate.
228
+ - substitutions (int): Number of substitutions.
229
+ - deletions (int): Number of deletions.
230
+ - insertions (int): Number of insertions.
231
+ - word_num (int): Number of words in the post-processed ground truth.
232
+ """
233
+ assert lang == "en"
234
+ truths = split_dialogue(truth)
235
+ hypotheses = split_dialogue(hypothesis)
236
+ for i in range(2):
237
+ truths[i] = post_process(truths[i], lang)
238
+ hypotheses[i] = post_process(hypotheses[i], lang)
239
+
240
+ measures_1 = compute_measures(
241
+ f"{truths[0]} {truths[1]}", f"{hypotheses[0]} {hypotheses[1]}"
242
+ )
243
+ measures_2 = compute_measures(
244
+ f"{truths[0]} {truths[1]}", f"{hypotheses[1]} {hypotheses[0]}"
245
+ )
246
+ truth = f"[S1] {truths[0]} [S2] {truths[1]}"
247
+ if measures_1["wer"] < measures_2["wer"]:
248
+ measures = measures_1
249
+ hypothesis = f"[S1] {hypotheses[0]} [S2] {hypotheses[1]}"
250
+ else:
251
+ measures = measures_2
252
+ hypothesis = f"[S1] {hypotheses[1]} [S2] {hypotheses[0]}"
253
+ truth = re.sub(r"\s+", " ", truth)
254
+ hypothesis = re.sub(r"\s+", " ", hypothesis)
255
+ word_num = len(truth.split(" ")) - 2
256
+ return (
257
+ truth,
258
+ hypothesis,
259
+ measures["wer"],
260
+ measures["substitutions"],
261
+ measures["deletions"],
262
+ measures["insertions"],
263
+ word_num,
264
+ )
265
+
266
+
267
+ def split_dialogue(text):
268
+ segments = re.split(r"\[S[1-9]\]", text)
269
+ segments = [segment.strip() for segment in segments]
270
+ spk1_texts = " ".join(segments[::2])
271
+ spk2_texts = " ".join(segments[1::2])
272
+ return [spk1_texts, spk2_texts]
273
+
274
+
275
+ class SpeechEvalDataset(torch.utils.data.Dataset):
276
+ """
277
+ A PyTorch Dataset for loading speech waveforms and their transcripts
278
+ for evaluation. Will only keep shorter-than-30s waveforms if in `cpwer` mode.
279
+ """
280
+
281
+ def __init__(
282
+ self, wav_transcript_path_pair: List[Tuple[str, str]], cpwer: bool = False
283
+ ):
284
+ super().__init__()
285
+ if cpwer:
286
+ self.wav_transcript_path_pair = []
287
+ for wav_path, transcript in wav_transcript_path_pair:
288
+ waveform = load_waveform(
289
+ wav_path,
290
+ sample_rate=16000,
291
+ )
292
+ if len(waveform) / 16000 <= 30:
293
+ self.wav_transcript_path_pair.append((wav_path, transcript))
294
+ else:
295
+ self.wav_transcript_path_pair = wav_transcript_path_pair
296
+
297
+ def __len__(self):
298
+ return len(self.wav_transcript_path_pair)
299
+
300
+ def __getitem__(self, index: int):
301
+ waveform = load_waveform(
302
+ self.wav_transcript_path_pair[index][0],
303
+ sample_rate=16000,
304
+ return_numpy=True,
305
+ )
306
+ item = {
307
+ "array": waveform,
308
+ "sampling_rate": 16000,
309
+ "reference": self.wav_transcript_path_pair[index][1],
310
+ "wav_path": self.wav_transcript_path_pair[index][0],
311
+ }
312
+ return item
313
+
314
+
315
+ def main(test_list, wav_dir, extension, model_dir, decode_path, lang, cpwer, device):
316
+ logging.info(f"Calculating WER for {wav_dir} (cpwer={cpwer})")
317
+ if lang == "en":
318
+ model = load_en_model(model_dir, device=device)
319
+ elif lang == "zh":
320
+ model = load_zh_model(model_dir)
321
+ params = []
322
+ for line in open(test_list).readlines():
323
+ line = line.strip()
324
+ assert len(line.split("\t")) == 6
325
+ items = line.split("\t")
326
+ wav_name, text_ref = items[0], items[-1]
327
+ file_path = os.path.join(wav_dir, wav_name + "." + extension)
328
+ assert os.path.exists(file_path), f"{file_path}"
329
+ params.append((file_path, text_ref))
330
+
331
+ if decode_path:
332
+ # Ensure the output directory exists
333
+ decode_dir = os.path.dirname(decode_path)
334
+ if decode_dir and not os.path.exists(decode_dir):
335
+ os.makedirs(decode_dir)
336
+ fout = open(decode_path, "w", encoding="utf8")
337
+ logging.info(f"Saving detailed WER results to: {decode_path}")
338
+ fout.write(
339
+ "Name\tWER\tTruth\tHypothesis\tInsertions\tDeletions\tSubstitutions\n"
340
+ )
341
+
342
+ # Initialize metrics for overall WER calculation
343
+ wers = []
344
+ inses = []
345
+ deles = []
346
+ subses = []
347
+ word_nums = 0
348
+ if cpwer:
349
+ cp_wers = []
350
+ cp_inses = []
351
+ cp_deles = []
352
+ cp_subses = []
353
+ cp_word_nums = 0
354
+ # Note: `fout` was already opened above (with its header) when decode_path
355
+ # is set; reopening it here would truncate the file.
356
+ if lang == "zh":
357
+ for wav_path, text_ref in tqdm(params):
358
+ res = model.generate(input=wav_path, batch_size_s=300, disable_pbar=True)
359
+ transcription = res[0]["text"]
360
+ transcription = zhconv.convert(transcription, "zh-cn")
361
+
362
+ truth, hypo, wer, subs, dele, inse, word_num = process_one(
363
+ transcription, text_ref, lang
364
+ )
365
+ if decode_path:
366
+ fout.write(
367
+ f"{wav_path}\t{wer}\t{truth}\t{hypo}\t{inse}\t{dele}\t{subs}\n"
368
+ )
369
+ wers.append(float(wer))
370
+ inses.append(float(inse))
371
+ deles.append(float(dele))
372
+ subses.append(float(subs))
373
+ word_nums += word_num
374
+ elif lang == "en":
375
+ dataset = SpeechEvalDataset(params, cpwer)
376
+ bar = tqdm(
377
+ model(
378
+ dataset,
379
+ generate_kwargs={"language": lang, "task": "transcribe"},
380
+ batch_size=16,
381
+ ),
382
+ total=len(dataset),
383
+ )
384
+ for out in bar:
385
+ transcription = out["text"]
386
+ text_ref = out["reference"][0]
387
+ wav_path = out["wav_path"][0]
388
+ if cpwer:
389
+ (
390
+ cp_truth,
391
+ cp_hypo,
392
+ cp_wer,
393
+ cp_subs,
394
+ cp_dele,
395
+ cp_inse,
396
+ cp_word_num,
397
+ ) = process_one_cpwer(transcription, text_ref, lang)
398
+ if decode_path:
399
+ fout.write(
400
+ f"{wav_path}\t{cp_wer}\t{cp_truth}\t"
401
+ f"{cp_hypo}\t{cp_inse}\t{cp_dele}\t{cp_subs}\n"
402
+ )
403
+ cp_wers.append(float(cp_wer))
404
+ cp_inses.append(float(cp_inse))
405
+ cp_deles.append(float(cp_dele))
406
+ cp_subses.append(float(cp_subs))
407
+ cp_word_nums += cp_word_num
408
+ truth, hypo, wer, subs, dele, inse, word_num = process_one(
409
+ transcription, text_ref, lang
410
+ )
411
+ if decode_path:
412
+ fout.write(
413
+ f"{wav_path}\t{wer}\t{truth}\t{hypo}\t{inse}\t{dele}\t{subs}\n"
414
+ )
415
+ wers.append(float(wer))
416
+ inses.append(float(inse))
417
+ deles.append(float(dele))
418
+ subses.append(float(subs))
419
+ word_nums += word_num
420
+ if cpwer:
421
+ assert (
422
+ word_num == cp_word_num
423
+ ), f"{wav_path} has {word_num} words, but {cp_word_num} cp words"
424
+
425
+ print("-" * 50)
426
+ if cpwer:
427
+ cp_wer = round(
428
+ (np.sum(cp_subses) + np.sum(cp_deles) + np.sum(cp_inses))
429
+ / cp_word_nums
430
+ * 100,
431
+ 2,
432
+ )
433
+ cp_inse = np.sum(cp_inses)
434
+ cp_dele = np.sum(cp_deles)
435
+ cp_subs = np.sum(cp_subses)
436
+ logging.info(f"cpWER = {cp_wer}%")
437
+ logging.info(
438
+ f"Errors: {cp_inse} insertions, {cp_dele} deletions, {cp_subs} "
439
+ f"substitutions, over {cp_word_nums} reference words"
440
+ )
441
+ if decode_path:
442
+ fout.write(f"cpWER = {cp_wer}%\n")
443
+ fout.write(
444
+ f"Errors: {cp_inse} insertions, {cp_dele} deletions, {cp_subs} "
445
+ f"substitutions, over {cp_word_nums} reference words\n"
446
+ )
447
+ wer = round((np.sum(subses) + np.sum(deles) + np.sum(inses)) / word_nums * 100, 2)
448
+ inse = np.sum(inses)
449
+ dele = np.sum(deles)
450
+ subs = np.sum(subses)
451
+
452
+ logging.info(f"WER = {wer}%")
453
+ logging.info(
454
+ f"Errors: {inse} insertions, {dele} deletions, {subs} substitutions, "
455
+ f"over {word_nums} reference words"
456
+ )
457
+ print("-" * 50)
458
+
459
+ if decode_path:
460
+ fout.write(f"WER = {wer}%\n")
461
+ fout.write(
462
+ f"Errors: {inse} insertions, {dele} deletions, {subs} substitutions, "
463
+ f"over {word_nums} reference words\n"
464
+ )
465
+ fout.flush()
466
+
467
+
468
+ if __name__ == "__main__":
469
+
470
+ torch.set_num_threads(1)
471
+ torch.set_num_interop_threads(1)
472
+
473
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
474
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
475
+
476
+ parser = get_parser()
477
+ args = parser.parse_args()
478
+ if torch.cuda.is_available():
479
+ device = torch.device("cuda", 0)
480
+ else:
481
+ device = torch.device("cpu")
482
+ if args.cpwer:
483
+ assert args.lang == "en", "Only English is supported for cpWER"
484
+ main(
485
+ args.test_list,
486
+ args.wav_path,
487
+ args.extension,
488
+ args.model_dir,
489
+ args.decode_path,
490
+ args.lang,
491
+ args.cpwer,
492
+ device,
493
+ )
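
A compact sketch of the dialogue-splitting step used by `process_one_cpwer` above; because cpWER scores both speaker permutations, which group ends up first does not matter. The example strings are made up, and importing the module requires its dependencies (funasr, zhconv, jiwer, transformers) to be installed:

from zipvoice.eval.wer.dialog import split_dialogue

# Turns tagged [S1]/[S2] are regrouped per speaker before cpWER is computed.
text = "[S1] hello there [S2] hi [S1] how are you [S2] fine"
side_a, side_b = split_dialogue(text)
# side_a / side_b each hold all turns of one speaker, e.g. roughly
# "hi fine" and "hello there how are you" (up to leading whitespace).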
zipvoice/eval/wer/hubert.py ADDED
@@ -0,0 +1,285 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2025 Xiaomi Corp. (authors: Han Zhu,
3
+ # Wei Kang)
4
+ #
5
+ # See ../../../../LICENSE for clarification regarding multiple authors
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+ """
20
+ Computes word error rate (WER) with Hubert models for LibriSpeech test sets.
21
+ """
22
+ import argparse
23
+ import logging
24
+ import os
25
+ import re
26
+ from pathlib import Path
27
+
28
+ import numpy as np
29
+ import torch
30
+ from jiwer import compute_measures
31
+ from tqdm import tqdm
32
+ from transformers import pipeline
33
+
34
+ from zipvoice.eval.utils import load_waveform
35
+
36
+
37
+ def get_parser():
38
+ parser = argparse.ArgumentParser(
39
+ description="Computes WER with Hubert models.",
40
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
41
+ )
42
+
43
+ parser.add_argument(
44
+ "--wav-path",
45
+ type=str,
46
+ required=True,
47
+ help="Path to the directory containing speech files.",
48
+ )
49
+
50
+ parser.add_argument(
51
+ "--extension",
52
+ type=str,
53
+ default="wav",
54
+ help="Extension of the speech files. Default: wav",
55
+ )
56
+
57
+ parser.add_argument(
58
+ "--decode-path",
59
+ type=str,
60
+ default=None,
61
+ help="Path to the output file where WER information will be saved. "
62
+ "If not provided, results are only printed to console.",
63
+ )
64
+ parser.add_argument(
65
+ "--model-dir",
66
+ type=str,
67
+ required=True,
68
+ help="Local path of our evaluatioin model repository."
69
+ "Download from https://huggingface.co/k2-fsa/TTS_eval_models."
70
+ "Will use 'tts_eval_models/wer/hubert-large-ls960-ft/'"
71
+ " in this script",
72
+ )
73
+ parser.add_argument(
74
+ "--test-list",
75
+ type=str,
76
+ default="transcript.tsv",
77
+ help="path of the tsv file. Each line is in the format:"
78
+ "(audio_name, text) separated by tabs.",
79
+ )
80
+ parser.add_argument(
81
+ "--batch-size",
82
+ type=int,
83
+ default=16,
84
+ help="Batch size for decoding with the Hugging Face pipeline.",
85
+ )
86
+ return parser
87
+
88
+
89
+ def post_process(text: str) -> str:
90
+ """
91
+ Cleans and normalizes text for WER calculation.
92
+ Args:
93
+ text (str): The input text to be processed.
94
+
95
+ Returns:
96
+ str: The cleaned and normalized text.
97
+ """
98
+ text = text.replace("‘", "'").replace("’", "'")
99
+ text = re.sub(r"[^a-zA-Z0-9']", " ", text.lower())
100
+ text = re.sub(r"\s+", " ", text).strip()
101
+ return text
102
+
103
+
104
+ def process_one(hypothesis: str, truth: str) -> tuple:
105
+ """
106
+ Computes WER and related metrics for a single hypothesis-truth pair.
107
+
108
+ Args:
109
+ hypothesis (str): The transcribed text from the ASR model.
110
+ truth (str): The ground truth transcript.
111
+
112
+ Returns:
113
+ tuple: A tuple containing:
114
+ - truth (str): Post-processed ground truth text.
115
+ - hypothesis (str): Post-processed hypothesis text.
116
+ - wer (float): Word Error Rate.
117
+ - substitutions (int): Number of substitutions.
118
+ - deletions (int): Number of deletions.
119
+ - insertions (int): Number of insertions.
120
+ - word_num (int): Number of words in the post-processed ground truth.
121
+ """
122
+ truth_processed = post_process(truth)
123
+ hypothesis_processed = post_process(hypothesis)
124
+
125
+ measures = compute_measures(truth_processed, hypothesis_processed)
126
+ word_num = len(truth_processed.split(" "))
127
+
128
+ return (
129
+ truth_processed,
130
+ hypothesis_processed,
131
+ measures["wer"],
132
+ measures["substitutions"],
133
+ measures["deletions"],
134
+ measures["insertions"],
135
+ word_num,
136
+ )
137
+
138
+
139
+ class SpeechEvalDataset(torch.utils.data.Dataset):
140
+ """
141
+ A PyTorch Dataset for loading speech waveforms and their transcripts
142
+ for evaluation.
143
+ """
144
+
145
+ def __init__(self, wav_path: str, test_list: str, extension: str = "wav"):
146
+ """
147
+ Initializes the dataset.
148
+
149
+ Args:
150
+ wav_path (str): Path to the directory containing speech files.
151
+ test_list (str): Path to the TSV file with speech file names and
152
+ transcripts.
153
+ """
154
+ super().__init__()
155
+ self.wav_names = []
156
+ self.wav_paths = []
157
+ self.transcripts = []
158
+ with Path(test_list).open("r", encoding="utf8") as f:
159
+ meta = [item.split("\t") for item in f.read().rstrip().split("\n")]
160
+ for item in meta:
161
+ self.wav_names.append(item[0])
162
+ self.wav_paths.append(Path(wav_path, item[0] + "." + extension))
163
+ self.transcripts.append(item[-1])
164
+
165
+ def __len__(self):
166
+ return len(self.wav_paths)
167
+
168
+ def __getitem__(self, index: int):
169
+ waveform = load_waveform(
170
+ self.wav_paths[index],
171
+ sample_rate=16000,
172
+ return_numpy=True,
173
+ )
174
+ item = {
175
+ "array": waveform,
176
+ "sampling_rate": 16000,
177
+ "reference": self.transcripts[index],
178
+ "wav_name": self.wav_names[index],
179
+ }
180
+ return item
181
+
182
+
183
+ def main(test_list, wav_path, extension, model_dir, decode_path, batch_size, device):
184
+ logging.info(f"Calculating WER for {wav_path}")
185
+ model_path = os.path.join(model_dir, "wer/hubert-large-ls960-ft/")
186
+ if not os.path.exists(model_path):
187
+ logging.error(
188
+ "Please download evaluation models from "
189
+ "https://huggingface.co/k2-fsa/TTS_eval_models"
190
+ " and pass this dir with --model-dir"
191
+ )
192
+ exit(1)
193
+
194
+ asr_pipeline = pipeline(
195
+ "automatic-speech-recognition",
196
+ model=model_path,
197
+ device=device,
198
+ tokenizer=model_path,
199
+ )
200
+
201
+ dataset = SpeechEvalDataset(wav_path, test_list, extension)
202
+
203
+ transcription_results = tqdm(
204
+ asr_pipeline(
205
+ dataset,
206
+ generate_kwargs={"language": "english", "task": "transcribe"},
207
+ batch_size=batch_size,
208
+ ),
209
+ total=len(dataset),
210
+ )
211
+
212
+ # Initialize metrics for overall WER calculation
213
+ wers = []
214
+ inses = []
215
+ deles = []
216
+ subses = []
217
+ word_nums = 0
218
+ if decode_path:
219
+ # Ensure the output directory exists
220
+ decode_dir = os.path.dirname(decode_path)
221
+ if decode_dir and not os.path.exists(decode_dir):
222
+ os.makedirs(decode_dir)
223
+ fout = open(decode_path, "w", encoding="utf8")
224
+ logging.info(f"Saving detailed WER results to: {decode_path}")
225
+ fout.write(
226
+ "Name\tWER\tTruth\tHypothesis\tInsertions\tDeletions\tSubstitutions\n"
227
+ )
228
+ for out in transcription_results:
229
+ wav_name = out["wav_name"][0]
230
+ transcription = out["text"].strip()
231
+ text_ref = out["reference"][0].strip()
232
+ truth, hypo, wer, subs, dele, inse, word_num = process_one(
233
+ transcription, text_ref
234
+ )
235
+ if decode_path:
236
+ fout.write(f"{wav_name}\t{wer}\t{truth}\t{hypo}\t{inse}\t{dele}\t{subs}\n")
237
+ wers.append(float(wer))
238
+ inses.append(float(inse))
239
+ deles.append(float(dele))
240
+ subses.append(float(subs))
241
+ word_nums += word_num
242
+
243
+ wer = round((np.sum(subses) + np.sum(deles) + np.sum(inses)) / word_nums * 100, 2)
244
+ inse = np.sum(inses)
245
+ dele = np.sum(deles)
246
+ subs = np.sum(subses)
247
+ print("-" * 50)
248
+ logging.info(f"WER = {wer}%")
249
+ logging.info(
250
+ f"Errors: {inse} insertions, {dele} deletions, {subs} substitutions, "
251
+ f"over {word_nums} reference words"
252
+ )
253
+ print("-" * 50)
254
+ if decode_path:
255
+ fout.write(f"WER = {wer}%\n")
256
+ fout.write(
257
+ f"Errors: {inse} insertions, {dele} deletions, {subs} substitutions, "
258
+ f"over {word_nums} reference words\n"
259
+ )
260
+ fout.flush()
261
+
262
+
263
+ if __name__ == "__main__":
264
+
265
+ torch.set_num_threads(1)
266
+ torch.set_num_interop_threads(1)
267
+
268
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
269
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
270
+
271
+ parser = get_parser()
272
+ args = parser.parse_args()
273
+ if torch.cuda.is_available():
274
+ device = torch.device("cuda", 0)
275
+ else:
276
+ device = torch.device("cpu")
277
+ main(
278
+ args.test_list,
279
+ args.wav_path,
280
+ args.extension,
281
+ args.model_dir,
282
+ args.decode_path,
283
+ args.batch_size,
284
+ device,
285
+ )
zipvoice/eval/wer/seedtts.py ADDED
@@ -0,0 +1,298 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2025 Xiaomi Corp. (authors: Han Zhu,
3
+ # Wei Kang)
4
+ #
5
+ # See ../../../../LICENSE for clarification regarding multiple authors
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+ """
20
+ Computes word error rate (WER) with Whisper-large-v3 for English and
21
+ Paraformer for Chinese. Intended to evaluate WERs on Seed-TTS test sets.
22
+ """
23
+
24
+ import argparse
25
+ import logging
26
+ import os
27
+ import string
28
+
29
+ import numpy as np
30
+ import scipy
31
+ import soundfile as sf
32
+ import torch
33
+ import zhconv
34
+ from funasr import AutoModel
35
+ from jiwer import compute_measures
36
+ from tqdm import tqdm
37
+ from transformers import WhisperForConditionalGeneration, WhisperProcessor
38
+ from zhon.hanzi import punctuation
39
+
40
+
41
+ def get_parser():
42
+ parser = argparse.ArgumentParser(
43
+ description="Computes WER with Whisper and Paraformer models, "
44
+ "following Seed-TTS.",
45
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
46
+ )
47
+
48
+ parser.add_argument(
49
+ "--wav-path",
50
+ type=str,
51
+ required=True,
52
+ help="Path to the directory containing speech files.",
53
+ )
54
+
55
+ parser.add_argument(
56
+ "--extension",
57
+ type=str,
58
+ default="wav",
59
+ help="Extension of the speech files. Default: wav",
60
+ )
61
+
62
+ parser.add_argument(
63
+ "--decode-path",
64
+ type=str,
65
+ default=None,
66
+ help="Path to the output file where WER information will be saved. "
67
+ "If not provided, results are only printed to console.",
68
+ )
69
+ parser.add_argument(
70
+ "--model-dir",
71
+ type=str,
72
+ required=True,
73
+ help="Local path of evaluation models repository. "
74
+ "Download from https://huggingface.co/k2-fsa/TTS_eval_models. "
75
+ "This script expects 'tts_eval_models/wer/whisper-large-v3/' for English "
76
+ "and 'tts_eval_models/wer/paraformer-zh/' for Chinese within this directory.",
77
+ )
78
+ parser.add_argument(
79
+ "--test-list",
80
+ type=str,
81
+ default="test.tsv",
82
+ help="Path of the TSV file. Each line is in the format: "
83
+ "(audio_name, prompt_text, prompt_audio, text), separated by tabs."
84
+ )
85
+ parser.add_argument(
86
+ "--lang",
87
+ type=str,
88
+ choices=["zh", "en"],
89
+ required=True,
90
+ help="Language of the audio and transcripts for "
91
+ "decoding ('zh' for Chinese or 'en' for English).",
92
+ )
93
+ return parser
94
+
95
+
96
+ def load_en_model(model_dir):
97
+ model_path = os.path.join(model_dir, "wer/whisper-large-v3/")
98
+ if not os.path.exists(model_path):
99
+ logging.error(
100
+ f"Error: Whisper model not found at {model_path}. "
101
+ "Please download evaluation models from "
102
+ "https://huggingface.co/k2-fsa/TTS_eval_models "
103
+ "and pass this directory with --model-dir."
104
+ )
105
+ exit(1)
106
+ logging.info(f"Loading Whisper model from: {model_path}")
107
+ processor = WhisperProcessor.from_pretrained(model_path)
108
+ model = WhisperForConditionalGeneration.from_pretrained(model_path)
109
+ return processor, model
110
+
111
+
112
+ def load_zh_model(model_dir):
113
+ model_path = os.path.join(model_dir, "wer/paraformer-zh/")
114
+ if not os.path.exists(model_path):
115
+ logging.error(
116
+ f"Error: Paraformer model not found at {model_path}. "
117
+ "Please download evaluation models from "
118
+ "https://huggingface.co/k2-fsa/TTS_eval_models "
119
+ "and pass this directory with --model-dir."
120
+ )
121
+ exit(1)
122
+ logging.info(f"Loading Paraformer model from: {model_path}")
123
+ model = AutoModel(model=model_path, disable_update=True)
124
+ return model
125
+
126
+
127
+ def post_process(text: str, lang: str) -> str:
128
+ """
129
+ Cleans and normalizes text for WER calculation.
130
+ Args:
131
+ text (str): The input text to be processed.
132
+ lang (str): The language of the input text.
133
+
134
+ Returns:
135
+ str: The cleaned and normalized text.
136
+ """
137
+ punctuation_all = punctuation + string.punctuation
138
+ for x in punctuation_all:
139
+ if x == "'":
140
+ continue
141
+ text = text.replace(x, "")
142
+
143
+ text = text.replace(" ", " ")
144
+
145
+ if lang == "zh":
146
+ text = " ".join([x for x in text])
147
+ elif lang == "en":
148
+ text = text.lower()
149
+ else:
150
+ raise NotImplementedError
151
+ return text
152
+
153
+
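+ # Illustrative example (sketch): punctuation (except "'") is removed, and Chinese
+ # text is split into space-separated characters, so the "WER" for lang="zh" is
+ # effectively a character error rate, e.g. post_process("你好!", "zh") -> "你 好".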
154
+ def process_one(hypothesis: str, truth: str, lang: str) -> tuple:
155
+ """
156
+ Computes WER and related metrics for a single hypothesis-truth pair.
157
+
158
+ Args:
159
+ hypothesis (str): The transcribed text from the ASR model.
160
+ truth (str): The ground truth transcript.
161
+
162
+ Returns:
163
+ tuple: A tuple containing:
164
+ - truth (str): Post-processed ground truth text.
165
+ - hypothesis (str): Post-processed hypothesis text.
166
+ - wer (float): Word Error Rate.
167
+ - substitutions (int): Number of substitutions.
168
+ - deletions (int): Number of deletions.
169
+ - insertions (int): Number of insertions.
170
+ - word_num (int): Number of words in the post-processed ground truth.
171
+ """
172
+ truth_processed = post_process(truth, lang)
173
+ hypothesis_processed = post_process(hypothesis, lang)
174
+
175
+ measures = compute_measures(truth_processed, hypothesis_processed)
176
+ word_num = len(truth_processed.split(" "))
177
+
178
+ return (
179
+ truth_processed,
180
+ hypothesis_processed,
181
+ measures["wer"],
182
+ measures["substitutions"],
183
+ measures["deletions"],
184
+ measures["insertions"],
185
+ word_num,
186
+ )
187
+
188
+
189
+ def main(test_list, wav_path, extension, model_path, decode_path, lang, device):
190
+ logging.info(f"Calculating WER for {wav_path}")
191
+ if lang == "en":
192
+ processor, model = load_en_model(model_path)
193
+ model.to(device)
194
+ elif lang == "zh":
195
+ model = load_zh_model(model_path)
196
+ params = []
197
+ for line in open(test_list).readlines():
198
+ line = line.strip()
199
+ items = line.split("\t")
200
+ wav_name, text_ref = items[0], items[-1]
201
+ file_path = os.path.join(wav_path, wav_name + "." + extension)
202
+ assert os.path.exists(file_path), f"{file_path}"
203
+
204
+ params.append((file_path, text_ref))
205
+ # Initialize metrics for overall WER calculation
206
+ wers = []
207
+ inses = []
208
+ deles = []
209
+ subses = []
210
+ word_nums = 0
211
+ if decode_path:
212
+ # Ensure the output directory exists
213
+ decode_dir = os.path.dirname(decode_path)
214
+ if decode_dir and not os.path.exists(decode_dir):
215
+ os.makedirs(decode_dir)
216
+ fout = open(decode_path, "w")
217
+ for wav_path, text_ref in tqdm(params):
218
+ if lang == "en":
219
+ wav, sr = sf.read(wav_path)
220
+ if sr != 16000:
221
+ wav = scipy.signal.resample(wav, int(len(wav) * 16000 / sr))
222
+ input_features = processor(
223
+ wav, sampling_rate=16000, return_tensors="pt"
224
+ ).input_features
225
+ input_features = input_features.to(device)
226
+ forced_decoder_ids = processor.get_decoder_prompt_ids(
227
+ language="english", task="transcribe"
228
+ )
229
+ predicted_ids = model.generate(
230
+ input_features, forced_decoder_ids=forced_decoder_ids
231
+ )
232
+ transcription = processor.batch_decode(
233
+ predicted_ids, skip_special_tokens=True
234
+ )[0]
235
+ elif lang == "zh":
236
+ res = model.generate(input=wav_path, batch_size_s=300, disable_pbar=True)
237
+ transcription = res[0]["text"]
238
+ transcription = zhconv.convert(transcription, "zh-cn")
239
+
240
+ truth, hypo, wer, subs, dele, inse, word_num = process_one(
241
+ transcription, text_ref, lang
242
+ )
243
+ if decode_path:
244
+ fout.write(f"{wav_path}\t{wer}\t{truth}\t{hypo}\t{inse}\t{dele}\t{subs}\n")
245
+ wers.append(float(wer))
246
+ inses.append(float(inse))
247
+ deles.append(float(dele))
248
+ subses.append(float(subs))
249
+ word_nums += word_num
250
+
251
+ wer_avg = round(np.mean(wers) * 100, 2)
252
+ wer = round((np.sum(subses) + np.sum(deles) + np.sum(inses)) / word_nums * 100, 2)
253
+ inse = np.sum(inses)
254
+ dele = np.sum(deles)
255
+ subs = np.sum(subses)
256
+ print("-" * 50)
257
+ # The official evaluation code of Seed-TTS uses the average of per-utterance WERs
258
+ # instead of the weighted average of WERs.
259
+ logging.info(f"Seed-TTS WER: {wer_avg}%\n")
260
+ logging.info(f"WER: {wer}%\n")
261
+ logging.info(
262
+ f"Errors: {inse} insertions, {dele} deletions, {subs} substitutions, "
263
+ f"over {word_nums} reference words"
264
+ )
265
+ print("-" * 50)
266
+ if decode_path:
267
+ fout.write(f"SeedTTS WER: {wer_avg}%\n")
268
+ fout.write(f"WER: {wer}%\n")
269
+ fout.write(
270
+ f"Errors: {inse} insertions, {dele} deletions, {subs} substitutions, "
271
+ f"over {word_nums} reference words\n"
272
+ )
273
+ fout.flush()
274
+
275
+
276
+ if __name__ == "__main__":
277
+
278
+ torch.set_num_threads(1)
279
+ torch.set_num_interop_threads(1)
280
+
281
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
282
+ logging.basicConfig(format=formatter, level=logging.INFO, force=True)
283
+
284
+ parser = get_parser()
285
+ args = parser.parse_args()
286
+ if torch.cuda.is_available():
287
+ device = torch.device("cuda", 0)
288
+ else:
289
+ device = torch.device("cpu")
290
+ main(
291
+ args.test_list,
292
+ args.wav_path,
293
+ args.extension,
294
+ args.model_dir,
295
+ args.decode_path,
296
+ args.lang,
297
+ device,
298
+ )
zipvoice/models/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # ZipVoice models package
zipvoice/models/modules/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # ZipVoice models modules package
zipvoice/models/modules/scaling.py ADDED
@@ -0,0 +1,1590 @@
1
+ # Copyright 2022-2025 Xiaomi Corp. (authors: Daniel Povey
2
+ # Wei Kang)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+
19
+ import logging
20
+ import math
21
+ import random
22
+ import sys
23
+ from typing import Optional, Tuple, Union
24
+
25
+ try:
26
+ import k2
27
+ except Exception as e:
28
+ logging.warning(
29
+ f"Failed to import k2 with error {e}. Swoosh functions will fall back to the PyTorch"
30
+ f" implementation, leading to slower speed and higher memory consumption."
31
+ )
32
+ import torch
33
+ import torch.nn as nn
34
+ from torch import Tensor
35
+
36
+
37
+ def custom_amp_decorator(dec, cuda_amp_deprecated):
38
+ def decorator(func):
39
+ return dec(func) if not cuda_amp_deprecated else dec(device_type="cuda")(func)
40
+
41
+ return decorator
42
+
43
+
44
+ if hasattr(torch.amp, "custom_fwd"):
45
+ deprecated = True
46
+ from torch.amp import custom_bwd, custom_fwd
47
+ else:
48
+ deprecated = False
49
+ from torch.cuda.amp import custom_bwd, custom_fwd
50
+
51
+ custom_fwd = custom_amp_decorator(custom_fwd, deprecated)
52
+ custom_bwd = custom_amp_decorator(custom_bwd, deprecated)
53
+
54
+
55
+ def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor:
56
+ max_value = torch.max(x, y)
57
+ diff = torch.abs(x - y)
58
+ return max_value + torch.log1p(torch.exp(-diff))
59
+
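+ # Note (illustrative): the function above uses the standard numerically stable identity
+ #   log(exp(x) + exp(y)) = max(x, y) + log(1 + exp(-|x - y|)),
+ # written with ops (max, abs, exp, log1p) that the ONNX exporter can handle.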
60
+
61
+ # RuntimeError: Exporting the operator logaddexp to ONNX opset version
62
+ # 14 is not supported. Please feel free to request support or submit
63
+ # a pull request on PyTorch GitHub.
64
+ #
65
+ # The following function is to solve the above error when exporting
66
+ # models to ONNX via torch.jit.trace()
67
+ def logaddexp(x: Tensor, y: Tensor) -> Tensor:
68
+ # Caution(fangjun): Put torch.jit.is_scripting() before
69
+ # torch.onnx.is_in_onnx_export();
70
+ # otherwise, it will cause errors for torch.jit.script().
71
+ #
72
+ # torch.logaddexp() works for both torch.jit.script() and
73
+ # torch.jit.trace() but it causes errors for ONNX export.
74
+ #
75
+ if torch.jit.is_scripting():
76
+ # Note: We cannot use torch.jit.is_tracing() here as it also
77
+ # matches torch.onnx.export().
78
+ return torch.logaddexp(x, y)
79
+ elif torch.onnx.is_in_onnx_export():
80
+ return logaddexp_onnx(x, y)
81
+ else:
82
+ # for torch.jit.trace()
83
+ return torch.logaddexp(x, y)
84
+
85
+
86
+ class PiecewiseLinear(object):
87
+ """
88
+ Piecewise linear function, from float to float, specified as nonempty list of (x,y)
89
+ pairs with the x values in order. x values <[initial x] or >[final x] are mapped to
90
+ [initial y], [final y] respectively.
91
+ """
92
+
93
+ def __init__(self, *args):
94
+ assert len(args) >= 1, len(args)
95
+ if len(args) == 1 and isinstance(args[0], PiecewiseLinear):
96
+ self.pairs = list(args[0].pairs)
97
+ else:
98
+ self.pairs = [(float(x), float(y)) for x, y in args]
99
+ for x, y in self.pairs:
100
+ assert isinstance(x, (float, int)), type(x)
101
+ assert isinstance(y, (float, int)), type(y)
102
+
103
+ for i in range(len(self.pairs) - 1):
104
+ assert self.pairs[i + 1][0] > self.pairs[i][0], (
105
+ i,
106
+ self.pairs[i],
107
+ self.pairs[i + 1],
108
+ )
109
+
110
+ def __str__(self):
111
+ # e.g. 'PiecewiseLinear((0., 10.), (100., 0.))'
112
+ return f"PiecewiseLinear({str(self.pairs)[1:-1]})"
113
+
114
+ def __call__(self, x):
115
+ if x <= self.pairs[0][0]:
116
+ return self.pairs[0][1]
117
+ elif x >= self.pairs[-1][0]:
118
+ return self.pairs[-1][1]
119
+ else:
120
+ cur_x, cur_y = self.pairs[0]
121
+ for i in range(1, len(self.pairs)):
122
+ next_x, next_y = self.pairs[i]
123
+ if x >= cur_x and x <= next_x:
124
+ return cur_y + (next_y - cur_y) * (x - cur_x) / (next_x - cur_x)
125
+ cur_x, cur_y = next_x, next_y
126
+ assert False
127
+
128
+ def __mul__(self, alpha):
129
+ return PiecewiseLinear(*[(x, y * alpha) for x, y in self.pairs])
130
+
131
+ def __add__(self, x):
132
+ if isinstance(x, (float, int)):
133
+ return PiecewiseLinear(*[(p[0], p[1] + x) for p in self.pairs])
134
+ s, x = self.get_common_basis(x)
135
+ return PiecewiseLinear(
136
+ *[(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)]
137
+ )
138
+
139
+ def max(self, x):
140
+ if isinstance(x, (float, int)):
141
+ x = PiecewiseLinear((0, x))
142
+ s, x = self.get_common_basis(x, include_crossings=True)
143
+ return PiecewiseLinear(
144
+ *[(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]
145
+ )
146
+
147
+ def min(self, x):
148
+ if isinstance(x, float) or isinstance(x, int):
149
+ x = PiecewiseLinear((0, x))
150
+ s, x = self.get_common_basis(x, include_crossings=True)
151
+ return PiecewiseLinear(
152
+ *[(sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]
153
+ )
154
+
155
+ def __eq__(self, other):
156
+ return self.pairs == other.pairs
157
+
158
+ def get_common_basis(self, p: "PiecewiseLinear", include_crossings: bool = False):
159
+ """
160
+ Returns (self_mod, p_mod) which are equivalent piecewise linear
161
+ functions to self and p, but with the same x values.
162
+
163
+ p: the other piecewise linear function
164
+ include_crossings: if true, include in the x values positions
165
+ where the functions indicated by this and p cross.
166
+ """
167
+ assert isinstance(p, PiecewiseLinear), type(p)
168
+
169
+ # get sorted x-values without repetition.
170
+ x_vals = sorted(set([x for x, _ in self.pairs] + [x for x, _ in p.pairs]))
171
+ y_vals1 = [self(x) for x in x_vals]
172
+ y_vals2 = [p(x) for x in x_vals]
173
+
174
+ if include_crossings:
175
+ extra_x_vals = []
176
+ for i in range(len(x_vals) - 1):
177
+ if (y_vals1[i] > y_vals2[i]) != (y_vals1[i + 1] > y_vals2[i + 1]):
178
+ # if the two lines in this subsegment potentially cross each other..
179
+ diff_cur = abs(y_vals1[i] - y_vals2[i])
180
+ diff_next = abs(y_vals1[i + 1] - y_vals2[i + 1])
181
+ # `pos`, between 0 and 1, gives the relative x position,
182
+ # with 0 being x_vals[i] and 1 being x_vals[i+1].
183
+ pos = diff_cur / (diff_cur + diff_next)
184
+ extra_x_val = x_vals[i] + pos * (x_vals[i + 1] - x_vals[i])
185
+ extra_x_vals.append(extra_x_val)
186
+ if len(extra_x_vals) > 0:
187
+ x_vals = sorted(set(x_vals + extra_x_vals))
188
+ y_vals1 = [self(x) for x in x_vals]
189
+ y_vals2 = [p(x) for x in x_vals]
190
+ return (
191
+ PiecewiseLinear(*zip(x_vals, y_vals1)),
192
+ PiecewiseLinear(*zip(x_vals, y_vals2)),
193
+ )
194
+
195
+
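+ # Illustrative usage (sketch):
+ #   f = PiecewiseLinear((0.0, 1.0), (100.0, 0.5))
+ #   f(50.0)  == 0.75   # linear interpolation between the two points
+ #   f(-10.0) == 1.0    # x below the first point clamps to the first y
+ #   f(200.0) == 0.5    # x above the last point clamps to the last y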
196
+ class ScheduledFloat(torch.nn.Module):
197
+ """
198
+ This object is a torch.nn.Module only because we want it to show up in
199
+ [top_level module].modules(); it does not have a working forward() function.
200
+ You are supposed to cast it to float, as in, float(parent_module.whatever), and use
201
+ it as something like a dropout prob.
202
+
203
+ It is a floating point value whose value changes depending on the batch count of the
204
+ training loop. It is a piecewise linear function where you specify the (x,y) pairs
205
+ in sorted order on x; x corresponds to the batch index. For batch-index values
206
+ before the first x or after the last x, we just use the first or last y value.
207
+
208
+ Example:
209
+ self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
210
+
211
+ `default` is used when self.batch_count is not set or not in training mode or in
212
+ torch.jit scripting mode.
213
+ """
214
+
215
+ def __init__(self, *args, default: float = 0.0):
216
+ super().__init__()
217
+ # self.batch_count and self.name will be written to in the training loop.
218
+ self.batch_count = None
219
+ self.name = None
220
+ self.default = default
221
+ self.schedule = PiecewiseLinear(*args)
222
+
223
+ def extra_repr(self) -> str:
224
+ return (
225
+ f"batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}"
226
+ )
227
+
228
+ def __float__(self):
229
+ batch_count = self.batch_count
230
+ if (
231
+ batch_count is None
232
+ or not self.training
233
+ or torch.jit.is_scripting()
234
+ or torch.jit.is_tracing()
235
+ ):
236
+ return float(self.default)
237
+ else:
238
+ ans = self.schedule(self.batch_count)
239
+ if random.random() < 0.0002:
240
+ logging.debug(
241
+ f"ScheduledFloat: name={self.name}, "
242
+ f"batch_count={self.batch_count}, ans={ans}"
243
+ )
244
+ return ans
245
+
246
+ def __add__(self, x):
247
+ if isinstance(x, float) or isinstance(x, int):
248
+ return ScheduledFloat(self.schedule + x, default=self.default)
249
+ else:
250
+ return ScheduledFloat(
251
+ self.schedule + x.schedule, default=self.default + x.default
252
+ )
253
+
254
+ def max(self, x):
255
+ if isinstance(x, float) or isinstance(x, int):
256
+ return ScheduledFloat(self.schedule.max(x), default=self.default)
257
+ else:
258
+ return ScheduledFloat(
259
+ self.schedule.max(x.schedule),
260
+ default=max(self.default, x.default),
261
+ )
262
+
263
+
264
+ FloatLike = Union[float, ScheduledFloat]
265
+
266
+
267
+ class CutoffEstimator:
268
+ """
269
+ Estimates cutoffs of an arbitrary numerical quantity such that a specified
270
+ proportion of items will be above the cutoff on average.
271
+
272
+ p is the proportion of items that should be above the cutoff.
273
+ """
274
+
275
+ def __init__(self, p: float):
276
+ self.p = p
277
+ # total count of items
278
+ self.count = 0
279
+ # total count of items that were above the cutoff
280
+ self.count_above = 0
281
+ # initial cutoff value
282
+ self.cutoff = 0
283
+
284
+ def __call__(self, x: float) -> bool:
285
+ """
286
+ Returns true if x is above the cutoff.
287
+ """
288
+ ans = x > self.cutoff
289
+ self.count += 1
290
+ if ans:
291
+ self.count_above += 1
292
+ cur_p = self.count_above / self.count
293
+ delta_p = cur_p - self.p
294
+ if (delta_p > 0) == ans:
295
+ q = abs(delta_p)
296
+ self.cutoff = x * q + self.cutoff * (1 - q)
297
+ return ans
298
+
299
+
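+ # Illustrative usage (sketch): keep roughly the top 5% of observed values,
+ # adapting the threshold online as values arrive:
+ #   est = CutoffEstimator(0.05)
+ #   selected = [v for v in values if est(v)]   # `values` is any iterable of floats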
300
+ class SoftmaxFunction(torch.autograd.Function):
301
+ """
302
+ Tries to handle half-precision derivatives in a randomized way that should
303
+ be more accurate for training than the default behavior.
304
+ """
305
+
306
+ @staticmethod
307
+ def forward(ctx, x: Tensor, dim: int):
308
+ ans = x.softmax(dim=dim)
309
+ # if x dtype is float16, x.softmax() returns a float32 because
310
+ # (presumably) that op does not support float16, and autocast
311
+ # is enabled.
312
+ if torch.is_autocast_enabled():
313
+ ans = ans.to(torch.float16)
314
+ ctx.save_for_backward(ans)
315
+ ctx.x_dtype = x.dtype
316
+ ctx.dim = dim
317
+ return ans
318
+
319
+ @staticmethod
320
+ def backward(ctx, ans_grad: Tensor):
321
+ (ans,) = ctx.saved_tensors
322
+ with torch.amp.autocast("cuda", enabled=False):
323
+ ans_grad = ans_grad.to(torch.float32)
324
+ ans = ans.to(torch.float32)
325
+ x_grad = ans_grad * ans
326
+ x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
327
+ return x_grad, None
328
+
329
+
330
+ def softmax(x: Tensor, dim: int):
331
+ if not x.requires_grad or torch.jit.is_scripting() or torch.jit.is_tracing():
332
+ return x.softmax(dim=dim)
333
+
334
+ return SoftmaxFunction.apply(x, dim)
335
+
336
+
337
+ class BiasNormFunction(torch.autograd.Function):
338
+ # This computes:
339
+ # scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp()
340
+ # return x * scales
341
+ # (after unsqueezing the bias), but it does it in a memory-efficient way so that
342
+ # it can just store the returned value (chances are, this will also be needed for
343
+ # some other reason, related to the next operation, so we can save memory).
344
+ @staticmethod
345
+ def forward(
346
+ ctx,
347
+ x: Tensor,
348
+ bias: Tensor,
349
+ log_scale: Tensor,
350
+ channel_dim: int,
351
+ store_output_for_backprop: bool,
352
+ ) -> Tensor:
353
+ assert bias.ndim == 1
354
+ if channel_dim < 0:
355
+ channel_dim = channel_dim + x.ndim
356
+ ctx.store_output_for_backprop = store_output_for_backprop
357
+ ctx.channel_dim = channel_dim
358
+ for _ in range(channel_dim + 1, x.ndim):
359
+ bias = bias.unsqueeze(-1)
360
+ scales = (
361
+ torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5
362
+ ) * log_scale.exp()
363
+ ans = x * scales
364
+ ctx.save_for_backward(
365
+ ans.detach() if store_output_for_backprop else x,
366
+ scales.detach(),
367
+ bias.detach(),
368
+ log_scale.detach(),
369
+ )
370
+ return ans
371
+
372
+ @staticmethod
373
+ def backward(ctx, ans_grad: Tensor) -> Tensor:
374
+ ans_or_x, scales, bias, log_scale = ctx.saved_tensors
375
+ if ctx.store_output_for_backprop:
376
+ x = ans_or_x / scales
377
+ else:
378
+ x = ans_or_x
379
+ x = x.detach()
380
+ x.requires_grad = True
381
+ bias.requires_grad = True
382
+ log_scale.requires_grad = True
383
+ with torch.enable_grad():
384
+ # recompute scales from x, bias and log_scale.
385
+ scales = (
386
+ torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5
387
+ ) * log_scale.exp()
388
+ ans = x * scales
389
+ ans.backward(gradient=ans_grad)
390
+ return x.grad, bias.grad.flatten(), log_scale.grad, None, None
391
+
392
+
393
+ class BiasNorm(torch.nn.Module):
394
+ """
395
+ This is intended to be a simpler, and hopefully cheaper, replacement for
396
+ LayerNorm. The observation this is based on, is that Transformer-type
397
+ networks, especially with pre-norm, sometimes seem to set one of the
398
+ feature dimensions to a large constant value (e.g. 50), which "defeats"
399
+ the LayerNorm because the output magnitude is then not strongly dependent
400
+ on the other (useful) features. Presumably the weight and bias of the
401
+ LayerNorm are required to allow it to do this.
402
+
403
+ Instead, we give the BiasNorm a trainable bias that it can use when
404
+ computing the scale for normalization. We also give it a (scalar)
405
+ trainable scale on the output.
406
+
407
+
408
+ Args:
409
+ num_channels: the number of channels, e.g. 512.
410
+ channel_dim: the axis/dimension corresponding to the channel,
411
+ interpreted as an offset from the input's ndim if negative.
412
+ This is NOT the num_channels; it should typically be one of
413
+ {-2, -1, 0, 1, 2, 3}.
414
+ log_scale: the initial log-scale that we multiply the output by; this
415
+ is learnable.
416
+ log_scale_min: FloatLike, minimum allowed value of log_scale
417
+ log_scale_max: FloatLike, maximum allowed value of log_scale
418
+ store_output_for_backprop: only possibly affects memory use; recommend
419
+ to set to True if you think the output of this module is more likely
420
+ than the input of this module to be required to be stored for the
421
+ backprop.
422
+ """
423
+
424
+ def __init__(
425
+ self,
426
+ num_channels: int,
427
+ channel_dim: int = -1, # CAUTION: see documentation.
428
+ log_scale: float = 1.0,
429
+ log_scale_min: float = -1.5,
430
+ log_scale_max: float = 1.5,
431
+ store_output_for_backprop: bool = False,
432
+ ) -> None:
433
+ super(BiasNorm, self).__init__()
434
+ self.num_channels = num_channels
435
+ self.channel_dim = channel_dim
436
+ self.log_scale = nn.Parameter(torch.tensor(log_scale))
437
+ self.bias = nn.Parameter(torch.zeros(num_channels))
438
+
439
+ self.log_scale_min = log_scale_min
440
+ self.log_scale_max = log_scale_max
441
+
442
+ self.store_output_for_backprop = store_output_for_backprop
443
+
444
+ def forward(self, x: Tensor) -> Tensor:
445
+ assert x.shape[self.channel_dim] == self.num_channels
446
+
447
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
448
+ channel_dim = self.channel_dim
449
+ if channel_dim < 0:
450
+ channel_dim += x.ndim
451
+ bias = self.bias
452
+ for _ in range(channel_dim + 1, x.ndim):
453
+ bias = bias.unsqueeze(-1)
454
+ scales = (
455
+ torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5
456
+ ) * self.log_scale.exp()
457
+ return x * scales
458
+
459
+ log_scale = limit_param_value(
460
+ self.log_scale,
461
+ min=float(self.log_scale_min),
462
+ max=float(self.log_scale_max),
463
+ training=self.training,
464
+ )
465
+
466
+ return BiasNormFunction.apply(
467
+ x,
468
+ self.bias,
469
+ log_scale,
470
+ self.channel_dim,
471
+ self.store_output_for_backprop,
472
+ )
473
+
474
+
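+ # Illustrative usage (sketch), as a replacement for LayerNorm over the last dimension:
+ #   norm = BiasNorm(num_channels=512)        # channel_dim defaults to -1
+ #   y = norm(torch.randn(20, 4, 512))        # output has the same shape as the input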
475
+ def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
476
+ """
477
+ Behaves like a constructor of a modified version of nn.Linear
478
+ that gives an easy way to set the default initial parameter scale.
479
+
480
+ Args:
481
+ Accepts the standard args and kwargs that nn.Linear accepts
482
+ e.g. in_features, out_features, bias=False.
483
+
484
+ initial_scale: you can override this if you want to increase
485
+ or decrease the initial magnitude of the module's output
486
+ (affects the initialization of weight_scale and bias_scale).
487
+ Another option, if you want to do something like this, is
488
+ to re-initialize the parameters.
489
+ """
490
+ ans = nn.Linear(*args, **kwargs)
491
+ with torch.no_grad():
492
+ ans.weight[:] *= initial_scale
493
+ if ans.bias is not None:
494
+ torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
495
+ return ans
496
+
497
+
498
+ class BalancerFunction(torch.autograd.Function):
499
+ @staticmethod
500
+ def forward(
501
+ ctx,
502
+ x: Tensor,
503
+ min_mean: float,
504
+ max_mean: float,
505
+ min_rms: float,
506
+ max_rms: float,
507
+ grad_scale: float,
508
+ channel_dim: int,
509
+ ) -> Tensor:
510
+ if channel_dim < 0:
511
+ channel_dim += x.ndim
512
+ ctx.channel_dim = channel_dim
513
+ ctx.save_for_backward(x)
514
+ ctx.config = (
515
+ min_mean,
516
+ max_mean,
517
+ min_rms,
518
+ max_rms,
519
+ grad_scale,
520
+ channel_dim,
521
+ )
522
+ return x
523
+
524
+ @staticmethod
525
+ def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]:
526
+ (x,) = ctx.saved_tensors
527
+ (
528
+ min_mean,
529
+ max_mean,
530
+ min_rms,
531
+ max_rms,
532
+ grad_scale,
533
+ channel_dim,
534
+ ) = ctx.config
535
+
536
+ try:
537
+ with torch.enable_grad():
538
+ with torch.amp.autocast("cuda", enabled=False):
539
+ x = x.to(torch.float32)
540
+ x = x.detach()
541
+ x.requires_grad = True
542
+ mean_dims = [i for i in range(x.ndim) if i != channel_dim]
543
+ uncentered_var = (x**2).mean(dim=mean_dims, keepdim=True)
544
+ mean = x.mean(dim=mean_dims, keepdim=True)
545
+ stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt()
546
+ rms = uncentered_var.clamp(min=1.0e-20).sqrt()
547
+
548
+ m = mean / stddev
549
+ # part of loss that relates to mean / stddev
550
+ m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs()
551
+
552
+ # put a much larger scale on the RMS-max-limit loss, so that if both
553
+ # it and the m_loss are violated we fix the RMS loss first.
554
+ rms_clamped = rms.clamp(min=min_rms, max=max_rms)
555
+ r_loss = (rms_clamped / rms).log().abs()
556
+
557
+ loss = m_loss + r_loss
558
+
559
+ loss.backward(gradient=torch.ones_like(loss))
560
+ loss_grad = x.grad
561
+ loss_grad_rms = (
562
+ (loss_grad**2)
563
+ .mean(dim=mean_dims, keepdim=True)
564
+ .sqrt()
565
+ .clamp(min=1.0e-20)
566
+ )
567
+
568
+ loss_grad = loss_grad * (grad_scale / loss_grad_rms)
569
+
570
+ x_grad_float = x_grad.to(torch.float32)
571
+ # scale each element of loss_grad by the absolute value of the
572
+ # corresponding element of x_grad, which we view as a noisy estimate
573
+ # of its magnitude for that (frame and dimension). later we can
574
+ # consider factored versions.
575
+ x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad)
576
+ x_grad = x_grad_mod.to(x_grad.dtype)
577
+ except Exception as e:
578
+ logging.info(
579
+ f"Caught exception in Balancer backward: {e}, "
580
+ f"size={list(x_grad.shape)}, will continue."
581
+ )
582
+
583
+ return x_grad, None, None, None, None, None, None
584
+
585
+
586
+ class Balancer(torch.nn.Module):
587
+ """
588
+ Modifies the backpropped derivatives of a function to try to encourage, for
589
+ each channel, that it is positive at least a proportion `threshold` of the
590
+ time. It does this by multiplying negative derivative values by up to
591
+ (1+max_factor), and positive derivative values by up to (1-max_factor),
592
+ interpolated from 1 at the threshold to those extremal values when none
593
+ of the inputs are positive.
594
+
595
+ Args:
596
+ num_channels: the number of channels
597
+ channel_dim: the dimension/axis corresponding to the channel, e.g.
598
+ -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
599
+ min_positive: the minimum, per channel, of the proportion of the time
600
+ that (x > 0), below which we start to modify the derivatives.
601
+ max_positive: the maximum, per channel, of the proportion of the time
602
+ that (x > 0), above which we start to modify the derivatives.
603
+ scale_gain_factor: determines the 'gain' with which we increase the
604
+ change in gradient once the constraints on min_abs and max_abs
605
+ are violated.
606
+ min_abs: the minimum average-absolute-value difference from the mean
607
+ value per channel, which we allow, before we start to modify
608
+ the derivatives to prevent this.
609
+ max_abs: the maximum average-absolute-value difference from the mean
610
+ value per channel, which we allow, before we start to modify
611
+ the derivatives to prevent this.
612
+ prob: determines the minimum probability with which we modify the
613
+ gradients for the {min,max}_positive and {min,max}_abs constraints,
614
+ on each forward(). This is done randomly to prevent all layers
615
+ from doing it at the same time.
616
+ """
617
+
618
+ def __init__(
619
+ self,
620
+ num_channels: int,
621
+ channel_dim: int,
622
+ min_positive: FloatLike = 0.05,
623
+ max_positive: FloatLike = 0.95,
624
+ min_abs: FloatLike = 0.2,
625
+ max_abs: FloatLike = 100.0,
626
+ grad_scale: FloatLike = 0.04,
627
+ prob: Optional[FloatLike] = None,
628
+ ):
629
+ super().__init__()
630
+
631
+ if prob is None:
632
+ prob = ScheduledFloat((0.0, 0.5), (8000.0, 0.125), default=0.4)
633
+ self.prob = prob
634
+ # 5% of the time we will return and do nothing because memory usage is
635
+ # too high.
636
+ self.mem_cutoff = CutoffEstimator(0.05)
637
+
638
+ # actually self.num_channels is no longer needed except for an assertion.
639
+ self.num_channels = num_channels
640
+ self.channel_dim = channel_dim
641
+ self.min_positive = min_positive
642
+ self.max_positive = max_positive
643
+ self.min_abs = min_abs
644
+ self.max_abs = max_abs
645
+ self.grad_scale = grad_scale
646
+
647
+ def forward(self, x: Tensor) -> Tensor:
648
+ if (
649
+ torch.jit.is_scripting()
650
+ or not x.requires_grad
651
+ or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))
652
+ ):
653
+ return _no_op(x)
654
+
655
+ prob = float(self.prob)
656
+ if random.random() < prob:
657
+ # The following inner-functions convert from the way we historically
658
+ # specified these limitations, as limits on the absolute value and the
659
+ # proportion of positive values, to limits on the RMS value and
660
+ # the (mean / stddev).
661
+ def _abs_to_rms(x):
662
+ # for normally distributed data, if the expected absolute value is x,
663
+ # the expected rms value will be sqrt(pi/2) * x.
664
+ return 1.25331413732 * x
665
+
666
+ def _proportion_positive_to_mean(x):
667
+ def _atanh(x):
668
+ eps = 1.0e-10
669
+ # eps is to prevent crashes if x is exactly 0 or 1.
670
+ # we'll just end up returning a fairly large value.
671
+ return (math.log(1 + x + eps) - math.log(1 - x + eps)) / 2.0
672
+
673
+ def _approx_inverse_erf(x):
674
+ # 1 / (sqrt(pi) * ln(2)),
675
+ # see https://math.stackexchange.com/questions/321569/
676
+ # approximating-the-error-function-erf-by-analytical-functions
677
+ # this approximation is extremely crude and gets progressively worse
678
+ # for x very close to -1 or +1, but we mostly care about the
679
+ # "middle" region
680
+ # e.g. _approx_inverse_erf(0.05) = 0.0407316414078772,
681
+ # and math.erf(0.0407316414078772) = 0.045935330944660666,
682
+ # which is pretty close to 0.05.
683
+ return 0.8139535143 * _atanh(x)
684
+
685
+ # first convert x from the range 0..1 to the range -1..1 which the error
686
+ # function returns
687
+ x = -1 + (2 * x)
688
+ return _approx_inverse_erf(x)
689
+
690
+ min_mean = _proportion_positive_to_mean(float(self.min_positive))
691
+ max_mean = _proportion_positive_to_mean(float(self.max_positive))
692
+ min_rms = _abs_to_rms(float(self.min_abs))
693
+ max_rms = _abs_to_rms(float(self.max_abs))
694
+ grad_scale = float(self.grad_scale)
695
+
696
+ assert x.shape[self.channel_dim] == self.num_channels
697
+
698
+ return BalancerFunction.apply(
699
+ x,
700
+ min_mean,
701
+ max_mean,
702
+ min_rms,
703
+ max_rms,
704
+ grad_scale,
705
+ self.channel_dim,
706
+ )
707
+ else:
708
+ return _no_op(x)
709
+
710
+
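+ # Illustrative usage (sketch): the module is an identity in the forward pass and
+ # only nudges gradients in backprop:
+ #   balancer = Balancer(num_channels=256, channel_dim=-1,
+ #                       min_positive=0.05, max_positive=0.95)
+ #   x = balancer(x)   # be sure to use the returned tensor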
711
+ def penalize_abs_values_gt(
712
+ x: Tensor, limit: float, penalty: float, name: str = None
713
+ ) -> Tensor:
714
+ """
715
+ Returns x unmodified, but in backprop will put a penalty for the excess of
716
+ the absolute values of elements of x over the limit "limit". E.g. if
717
+ limit == 10.0, then if x has any values over 10 it will get a penalty.
718
+
719
+ Caution: the value of this penalty will be affected by grad scaling used
720
+ in automatic mixed precision training. For the purposes we use this for,
721
+ it shouldn't really matter, or may even be helpful; we just use this
722
+ to disallow really implausible values of scores to be given to softmax.
723
+
724
+ The name is for randomly printed debug info.
725
+ """
726
+ x_sign = x.sign()
727
+ over_limit = (x.abs() - limit) > 0
728
+ # The following is a memory efficient way to penalize the absolute values of
729
+ # x that's over the limit. (The memory efficiency comes when you think
730
+ # about which items torch needs to cache for the autograd, and which ones it
731
+ # can throw away). The numerical value of aux_loss as computed here will
732
+ # actually be larger than it should be, by limit * over_limit.sum(), but it
733
+ # has the same derivative as the real aux_loss which is penalty * (x.abs() -
734
+ # limit).relu().
735
+ aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
736
+ # note: we don't do sum() here on aux_loss, but it's as if we had done
737
+ # sum() due to how with_loss() works.
738
+ x = with_loss(x, aux_loss, name)
739
+ # you must use x for something, or this will be ineffective.
740
+ return x
741
+
742
+
743
+ def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
744
+ if x.ndim == 2:
745
+ return x.diag()
746
+ else:
747
+ (batch, dim, dim) = x.shape
748
+ x = x.reshape(batch, dim * dim)
749
+ x = x[:, :: dim + 1]
750
+ assert x.shape == (batch, dim)
751
+ return x
752
+
753
+
754
+ def _whitening_metric(x: Tensor, num_groups: int):
755
+ """
756
+ Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of
757
+ the centered feature covariance are the same within each group's covariance
758
+ matrix and also between groups.
759
+ Args:
760
+ x: a Tensor of shape (*, num_channels)
761
+ num_groups: the number of groups of channels, a number >=1 that divides
762
+ num_channels
763
+ Returns:
764
+ Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and
765
+ greater than 1.0 otherwise.
766
+ """
767
+ assert x.dtype != torch.float16
768
+ x = x.reshape(-1, x.shape[-1])
769
+ (num_frames, num_channels) = x.shape
770
+ assert num_channels % num_groups == 0
771
+ channels_per_group = num_channels // num_groups
772
+ x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1)
773
+ # x now has shape (num_groups, num_frames, channels_per_group)
774
+ # subtract the mean so we use the centered, not uncentered, covariance.
775
+ # My experience has been that when we "mess with the gradients" like this,
776
+ # it's better not to do anything that tries to move the mean around, because
777
+ # that can easily cause instability.
778
+ x = x - x.mean(dim=1, keepdim=True)
779
+ # x_covar: (num_groups, channels_per_group, channels_per_group)
780
+ x_covar = torch.matmul(x.transpose(1, 2), x)
781
+ x_covar_mean_diag = _diag(x_covar).mean()
782
+ # the following expression is what we'd get if we took the matrix product
783
+ # of each covariance and measured the mean of its trace, i.e.
784
+ # the same as _diag(torch.matmul(x_covar, x_covar)).mean().
785
+ x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group)
786
+ # this metric will be >= 1.0; the larger it is, the less 'white' the data was.
787
+ metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20)
788
+ return metric
789
+
790
+
791
+ class WhiteningPenaltyFunction(torch.autograd.Function):
792
+ @staticmethod
793
+ def forward(ctx, x: Tensor, module: nn.Module) -> Tensor:
794
+ ctx.save_for_backward(x)
795
+ ctx.module = module
796
+ return x
797
+
798
+ @staticmethod
799
+ def backward(ctx, x_grad: Tensor):
800
+ (x_orig,) = ctx.saved_tensors
801
+ w = ctx.module
802
+
803
+ try:
804
+ with torch.enable_grad():
805
+ with torch.amp.autocast("cuda", enabled=False):
806
+ x_detached = x_orig.to(torch.float32).detach()
807
+ x_detached.requires_grad = True
808
+
809
+ metric = _whitening_metric(x_detached, w.num_groups)
810
+
811
+ if random.random() < 0.005 or __name__ == "__main__":
812
+ logging.debug(
813
+ f"Whitening: name={w.name}, num_groups={w.num_groups},"
814
+ f"num_channels={x_orig.shape[-1]}, "
815
+ f"metric={metric.item():.2f}"
816
+ f" vs. limit={float(w.whitening_limit)}"
817
+ )
818
+
819
+ if metric < float(w.whitening_limit):
820
+ w.prob = w.min_prob
821
+ return x_grad, None
822
+ else:
823
+ w.prob = w.max_prob
824
+ metric.backward()
825
+ penalty_grad = x_detached.grad
826
+ scale = w.grad_scale * (
827
+ x_grad.to(torch.float32).norm()
828
+ / (penalty_grad.norm() + 1.0e-20)
829
+ )
830
+ penalty_grad = penalty_grad * scale
831
+ return x_grad + penalty_grad.to(x_grad.dtype), None
832
+ except Exception as e:
833
+ logging.info(
834
+ f"Caught exception in Whiten backward: {e}, "
835
+ f"size={list(x_grad.shape)}, will continue."
836
+ )
837
+ return x_grad, None
838
+
839
+
840
+ class Whiten(nn.Module):
841
+ def __init__(
842
+ self,
843
+ num_groups: int,
844
+ whitening_limit: FloatLike,
845
+ prob: Union[float, Tuple[float, float]],
846
+ grad_scale: FloatLike,
847
+ ):
848
+ """
849
+ Args:
850
+ num_groups: the number of groups to divide the channel dim into before
851
+ whitening. We will attempt to make the feature covariance
852
+ within each group, after mean subtraction, as "white" as possible,
853
+ while having the same trace across all groups.
854
+ whitening_limit: a value greater than 1.0, that dictates how much
855
+ freedom we have to violate the constraints. 1.0 would mean perfectly
856
+ white, with exactly the same trace across groups; larger values
857
+ give more freedom. E.g. 2.0.
858
+ prob: the probability with which we apply the gradient modification
859
+ (also affects the grad scale). May be supplied as a float,
860
+ or as a pair (min_prob, max_prob)
861
+
862
+ grad_scale: determines the scale on the gradient term from this object,
863
+ relative to the rest of the gradient on the attention weights.
864
+ E.g. 0.02 (you may want to use smaller values than this if prob is large)
865
+ """
866
+ super(Whiten, self).__init__()
867
+ assert num_groups >= 1
868
+ assert float(whitening_limit) >= 1
869
+ assert grad_scale >= 0
870
+ self.num_groups = num_groups
871
+ self.whitening_limit = whitening_limit
872
+ self.grad_scale = grad_scale
873
+
874
+ if isinstance(prob, float):
875
+ prob = (prob, prob)
876
+ (self.min_prob, self.max_prob) = prob
877
+ assert 0 < self.min_prob <= self.max_prob <= 1
878
+ self.prob = self.max_prob
879
+ self.name = None # will be set in training loop
880
+
881
+ def forward(self, x: Tensor) -> Tensor:
882
+ """
883
+ In the forward pass, this function just returns the input unmodified.
884
+ In the backward pass, it will modify the gradients to ensure that the
885
+ distribution in each group has close to (lambda times I) as the covariance
886
+ after mean subtraction, with the same lambda across groups.
887
+ For whitening_limit > 1, there will be more freedom to violate this
888
+ constraint.
889
+
890
+ Args:
891
+ x: the input of shape (*, num_channels)
892
+
893
+ Returns:
894
+ x, unmodified. You should make sure
895
+ you use the returned value, or the graph will be freed
896
+ and nothing will happen in backprop.
897
+ """
898
+ grad_scale = float(self.grad_scale)
899
+ if not x.requires_grad or random.random() > self.prob or grad_scale == 0:
900
+ return _no_op(x)
901
+ else:
902
+ return WhiteningPenaltyFunction.apply(x, self)
903
+
904
+
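+ # Illustrative usage (sketch):
+ #   whiten = Whiten(num_groups=1, whitening_limit=2.0,
+ #                   prob=(0.025, 0.25), grad_scale=0.01)
+ #   x = whiten(x)   # identity in forward; backprop nudges features toward a whiter covariance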
905
+ class WithLoss(torch.autograd.Function):
906
+ @staticmethod
907
+ def forward(ctx, x: Tensor, y: Tensor, name: str):
908
+ ctx.y_shape = y.shape
909
+ if random.random() < 0.002 and name is not None:
910
+ loss_sum = y.sum().item()
911
+ logging.debug(f"WithLoss: name={name}, loss-sum={loss_sum:.3e}")
912
+ return x
913
+
914
+ @staticmethod
915
+ def backward(ctx, ans_grad: Tensor):
916
+ return (
917
+ ans_grad,
918
+ torch.ones(ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device),
919
+ None,
920
+ )
921
+
922
+
923
+ def with_loss(x, y, name):
924
+ # returns x but adds y.sum() to the loss function.
925
+ return WithLoss.apply(x, y, name)
926
+
927
+
928
+ class LimitParamValue(torch.autograd.Function):
929
+ @staticmethod
930
+ def forward(ctx, x: Tensor, min: float, max: float):
931
+ ctx.save_for_backward(x)
932
+ assert max >= min
933
+ ctx.min = min
934
+ ctx.max = max
935
+ return x
936
+
937
+ @staticmethod
938
+ def backward(ctx, x_grad: Tensor):
939
+ (x,) = ctx.saved_tensors
940
+ # where x < ctx.min, ensure all grads are negative (this will tend to make
941
+ # x more positive).
942
+ x_grad = x_grad * torch.where(
943
+ torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0
944
+ )
945
+ # where x > ctx.max, ensure all grads are positive (this will tend to make
946
+ # x more negative).
947
+ x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0)
948
+ return x_grad, None, None
949
+
950
+
951
+ def limit_param_value(
952
+ x: Tensor, min: float, max: float, prob: float = 0.6, training: bool = True
953
+ ):
954
+ # You apply this to (typically) an nn.Parameter during training to ensure that its
955
+ # (elements mostly) stays within a supplied range. This is done by modifying the
956
+ # gradients in backprop.
957
+ # It's not necessary to do this on every batch: do it only some of the time,
958
+ # to save a little time.
959
+ if training and random.random() < prob:
960
+ return LimitParamValue.apply(x, min, max)
961
+ else:
962
+ return x
963
+
964
+
965
+ def _no_op(x: Tensor) -> Tensor:
966
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
967
+ return x
968
+ else:
969
+ # a no-op function that will have a node in the autograd graph,
970
+ # to avoid certain bugs relating to backward hooks
971
+ return x.chunk(1, dim=-1)[0]
972
+
973
+
974
+ # Identity more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
975
+ class Identity(torch.nn.Module):
976
+ def __init__(self):
977
+ super(Identity, self).__init__()
978
+
979
+ def forward(self, x):
980
+ return _no_op(x)
981
+
982
+
983
+ # Dropout2 is just like normal dropout, except it supports schedules
984
+ # on the dropout rates.
985
+ class Dropout2(nn.Module):
986
+ def __init__(self, p: FloatLike):
987
+ super().__init__()
988
+ self.p = p
989
+
990
+ def forward(self, x: Tensor) -> Tensor:
991
+ return torch.nn.functional.dropout(x, p=float(self.p), training=self.training)
992
+
993
+
994
+ class MulForDropout3(torch.autograd.Function):
995
+ # returns (x * y * alpha) where alpha is a float and y doesn't require
996
+ # grad and is zero-or-one.
997
+ @staticmethod
998
+ @custom_fwd
999
+ def forward(ctx, x, y, alpha):
1000
+ assert not y.requires_grad
1001
+ ans = x * y * alpha
1002
+ ctx.save_for_backward(ans)
1003
+ ctx.alpha = alpha
1004
+ return ans
1005
+
1006
+ @staticmethod
1007
+ @custom_bwd
1008
+ def backward(ctx, ans_grad):
1009
+ (ans,) = ctx.saved_tensors
1010
+ x_grad = ctx.alpha * ans_grad * (ans != 0)
1011
+ return x_grad, None, None
1012
+
1013
+
1014
+ # Dropout3 is just like normal dropout, except it supports schedules on the dropout
1015
+ # rates, and it lets you choose one dimension to share the dropout mask over
1016
+ class Dropout3(nn.Module):
1017
+ def __init__(self, p: FloatLike, shared_dim: int):
1018
+ super().__init__()
1019
+ self.p = p
1020
+ self.shared_dim = shared_dim
1021
+
1022
+ def forward(self, x: Tensor) -> Tensor:
1023
+ p = float(self.p)
1024
+ if not self.training or p == 0:
1025
+ return _no_op(x)
1026
+ scale = 1.0 / (1 - p)
1027
+ rand_shape = list(x.shape)
1028
+ rand_shape[self.shared_dim] = 1
1029
+ mask = torch.rand(*rand_shape, device=x.device) > p
1030
+ ans = MulForDropout3.apply(x, mask, scale)
1031
+ return ans
1032
+
1033
+
1034
+ class SwooshLFunction(torch.autograd.Function):
1035
+ """
1036
+ swoosh_l(x) = log(1 + exp(x-4)) - 0.08*x - 0.035
1037
+ """
1038
+
1039
+ @staticmethod
1040
+ def forward(ctx, x: Tensor) -> Tensor:
1041
+ requires_grad = x.requires_grad
1042
+ if x.dtype == torch.float16:
1043
+ x = x.to(torch.float32)
1044
+
1045
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
1046
+
1047
+ coeff = -0.08
1048
+
1049
+ with torch.amp.autocast("cuda", enabled=False):
1050
+ with torch.enable_grad():
1051
+ x = x.detach()
1052
+ x.requires_grad = True
1053
+ y = torch.logaddexp(zero, x - 4.0) + coeff * x - 0.035
1054
+
1055
+ if not requires_grad:
1056
+ return y
1057
+
1058
+ y.backward(gradient=torch.ones_like(y))
1059
+
1060
+ grad = x.grad
1061
+ floor = coeff
1062
+ ceil = 1.0 + coeff + 0.005
1063
+
1064
+ d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
1065
+ grad
1066
+ )
1067
+ if __name__ == "__main__":
1068
+ # for self-testing only.
1069
+ assert d_scaled.min() >= 0.0
1070
+ assert d_scaled.max() < 256.0
1071
+
1072
+ d_int = d_scaled.to(torch.uint8)
1073
+ ctx.save_for_backward(d_int)
1074
+ if x.dtype == torch.float16 or torch.is_autocast_enabled():
1075
+ y = y.to(torch.float16)
1076
+ return y
1077
+
1078
+ @staticmethod
1079
+ def backward(ctx, y_grad: Tensor) -> Tensor:
1080
+ (d,) = ctx.saved_tensors
1081
+ # the same constants as used in forward pass.
1082
+ coeff = -0.08
1083
+ floor = coeff
1084
+ ceil = 1.0 + coeff + 0.005
1085
+ d = d * ((ceil - floor) / 255.0) + floor
1086
+ return y_grad * d
1087
+
1088
+
1089
+ class SwooshL(torch.nn.Module):
1090
+ def forward(self, x: Tensor) -> Tensor:
1091
+ """Return Swoosh-L activation."""
1092
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
1093
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
1094
+ return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
1095
+ elif "k2" not in sys.modules:
1096
+ return SwooshLFunction.apply(x)
1097
+ else:
1098
+ if not x.requires_grad:
1099
+ return k2.swoosh_l_forward(x)
1100
+ else:
1101
+ return k2.swoosh_l(x)
1102
+
1103
+
1104
+ class SwooshLOnnx(torch.nn.Module):
1105
+ def forward(self, x: Tensor) -> Tensor:
1106
+ """Return Swoosh-L activation."""
1107
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
1108
+ return logaddexp_onnx(zero, x - 4.0) - 0.08 * x - 0.035
1109
+
1110
+
1111
+ class SwooshRFunction(torch.autograd.Function):
1112
+ """
1113
+ swoosh_r(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687
1114
+
1115
+ derivatives are between -0.08 and 0.92.
1116
+ """
1117
+
1118
+ @staticmethod
1119
+ def forward(ctx, x: Tensor) -> Tensor:
1120
+ requires_grad = x.requires_grad
1121
+
1122
+ if x.dtype == torch.float16:
1123
+ x = x.to(torch.float32)
1124
+
1125
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
1126
+
1127
+ with torch.amp.autocast("cuda", enabled=False):
1128
+ with torch.enable_grad():
1129
+ x = x.detach()
1130
+ x.requires_grad = True
1131
+ y = torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687
1132
+
1133
+ if not requires_grad:
1134
+ return y
1135
+ y.backward(gradient=torch.ones_like(y))
1136
+
1137
+ grad = x.grad
1138
+ floor = -0.08
1139
+ ceil = 0.925
1140
+
1141
+ d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
1142
+ grad
1143
+ )
1144
+ if __name__ == "__main__":
1145
+ # for self-testing only.
1146
+ assert d_scaled.min() >= 0.0
1147
+ assert d_scaled.max() < 256.0
1148
+
1149
+ d_int = d_scaled.to(torch.uint8)
1150
+ ctx.save_for_backward(d_int)
1151
+ if x.dtype == torch.float16 or torch.is_autocast_enabled():
1152
+ y = y.to(torch.float16)
1153
+ return y
1154
+
1155
+ @staticmethod
1156
+ def backward(ctx, y_grad: Tensor) -> Tensor:
1157
+ (d,) = ctx.saved_tensors
1158
+ # the same constants as used in forward pass.
1159
+ floor = -0.08
1160
+ ceil = 0.925
1161
+ d = d * ((ceil - floor) / 255.0) + floor
1162
+ return y_grad * d
1163
+
1164
+
1165
+ class SwooshR(torch.nn.Module):
1166
+ def forward(self, x: Tensor) -> Tensor:
1167
+ """Return Swoosh-R activation."""
1168
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
1169
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
1170
+ return logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687
1171
+ elif "k2" not in sys.modules:
1172
+ return SwooshRFunction.apply(x)
1173
+ else:
1174
+ if not x.requires_grad:
1175
+ return k2.swoosh_r_forward(x)
1176
+ else:
1177
+ return k2.swoosh_r(x)
1178
+
1179
+
1180
+ class SwooshROnnx(torch.nn.Module):
1181
+ def forward(self, x: Tensor) -> Tensor:
1182
+ """Return Swoosh-R activation."""
1183
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
1184
+ return logaddexp_onnx(zero, x - 1.0) - 0.08 * x - 0.313261687
1185
+
1186
+
1187
+ # simple version of SwooshL that does not redefine the backprop, used in
1188
+ # ActivationDropoutAndLinearFunction.
1189
+ def SwooshLForward(x: Tensor):
1190
+ with torch.amp.autocast("cuda", enabled=False):
1191
+ x = x.to(torch.float32)
1192
+ x_offset = x - 4.0
1193
+ log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
1194
+ log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
1195
+ return log_sum - 0.08 * x - 0.035
1196
+
1197
+
1198
+ # simple version of SwooshR that does not redefine the backprop, used in
1199
+ # ActivationDropoutAndLinearFunction.
1200
+ def SwooshRForward(x: Tensor):
1201
+ with torch.amp.autocast("cuda", enabled=False):
1202
+ x = x.to(torch.float32)
1203
+ x_offset = x - 1.0
1204
+ log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
1205
+ log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
1206
+ return log_sum - 0.08 * x - 0.313261687
1207
+
1208
+
1209
+ class ActivationDropoutAndLinearFunction(torch.autograd.Function):
1210
+ @staticmethod
1211
+ @custom_fwd
1212
+ def forward(
1213
+ ctx,
1214
+ x: Tensor,
1215
+ weight: Tensor,
1216
+ bias: Optional[Tensor],
1217
+ activation: str,
1218
+ dropout_p: float,
1219
+ dropout_shared_dim: Optional[int],
1220
+ ):
1221
+ if dropout_p != 0.0:
1222
+ dropout_shape = list(x.shape)
1223
+ if dropout_shared_dim is not None:
1224
+ dropout_shape[dropout_shared_dim] = 1
1225
+ # else it won't be very memory efficient.
1226
+ dropout_mask = (1.0 / (1.0 - dropout_p)) * (
1227
+ torch.rand(*dropout_shape, device=x.device, dtype=x.dtype) > dropout_p
1228
+ )
1229
+ else:
1230
+ dropout_mask = None
1231
+
1232
+ ctx.save_for_backward(x, weight, bias, dropout_mask)
1233
+
1234
+ ctx.activation = activation
1235
+
1236
+ forward_activation_dict = {
1237
+ "SwooshL": k2.swoosh_l_forward,
1238
+ "SwooshR": k2.swoosh_r_forward,
1239
+ }
1240
+ # it will raise a KeyError if this fails. This will be an error. We let it
1241
+ # propagate to the user.
1242
+ activation_func = forward_activation_dict[activation]
1243
+ x = activation_func(x)
1244
+ if dropout_mask is not None:
1245
+ x = x * dropout_mask
1246
+ x = torch.nn.functional.linear(x, weight, bias)
1247
+ return x
1248
+
1249
+ @staticmethod
1250
+ @custom_bwd
1251
+ def backward(ctx, ans_grad: Tensor):
1252
+ saved = ctx.saved_tensors
1253
+ (x, weight, bias, dropout_mask) = saved
1254
+
1255
+ forward_and_deriv_activation_dict = {
1256
+ "SwooshL": k2.swoosh_l_forward_and_deriv,
1257
+ "SwooshR": k2.swoosh_r_forward_and_deriv,
1258
+ }
1259
+ # the following lines raise a KeyError if the activation is unrecognized.
1260
+ # This will be an error. We let it propagate to the user.
1261
+ func = forward_and_deriv_activation_dict[ctx.activation]
1262
+
1263
+ y, func_deriv = func(x)
1264
+ if dropout_mask is not None:
1265
+ y = y * dropout_mask
1266
+ # now compute derivative of y w.r.t. weight and bias..
1267
+ # y: (..., in_channels), ans_grad: (..., out_channels),
1268
+ (out_channels, in_channels) = weight.shape
1269
+
1270
+ in_channels = y.shape[-1]
1271
+ g = ans_grad.reshape(-1, out_channels)
1272
+ weight_deriv = torch.matmul(g.t(), y.reshape(-1, in_channels))
1273
+ y_deriv = torch.matmul(ans_grad, weight)
1274
+ bias_deriv = None if bias is None else g.sum(dim=0)
1275
+ x_deriv = y_deriv * func_deriv
1276
+ if dropout_mask is not None:
1277
+ # order versus func_deriv does not matter
1278
+ x_deriv = x_deriv * dropout_mask
1279
+
1280
+ return x_deriv, weight_deriv, bias_deriv, None, None, None
1281
+
1282
+
1283
+ class ActivationDropoutAndLinear(torch.nn.Module):
1284
+ """
1285
+ This merges an activation function followed by dropout and then a nn.Linear module;
1286
+ it does so in a memory efficient way so that it only stores the input to the whole
1287
+ module. If activation == SwooshL and dropout_shared_dim != None, this will be
1288
+ equivalent to:
1289
+ nn.Sequential(SwooshL(),
1290
+ Dropout3(dropout_p, shared_dim=dropout_shared_dim),
1291
+ ScaledLinear(in_channels, out_channels, bias=bias,
1292
+ initial_scale=initial_scale))
1293
+ If dropout_shared_dim is None, the dropout would be equivalent to
1294
+ Dropout2(dropout_p). Note: Dropout3 will be more memory efficient as the dropout
1295
+ mask is smaller.
1296
+
1297
+ Args:
1298
+ in_channels: number of input channels, e.g. 256
1299
+ out_channels: number of output channels, e.g. 256
1300
+ bias: if true, have a bias
1301
+ activation: the activation function, for now just support SwooshL.
1302
+ dropout_p: the dropout probability or schedule (happens after nonlinearity).
1303
+ dropout_shared_dim: the dimension, if any, across which the dropout mask is
1304
+ shared (e.g. the time dimension). If None, this may be less memory
1305
+ efficient if there are modules before this one that cache the input
1306
+ for their backprop (e.g. Balancer or Whiten).
1307
+ """
1308
+
1309
+ def __init__(
1310
+ self,
1311
+ in_channels: int,
1312
+ out_channels: int,
1313
+ bias: bool = True,
1314
+ activation: str = "SwooshL",
1315
+ dropout_p: FloatLike = 0.0,
1316
+ dropout_shared_dim: Optional[int] = -1,
1317
+ initial_scale: float = 1.0,
1318
+ ):
1319
+ super().__init__()
1320
+ # create a temporary module of nn.Linear that we'll steal the
1321
+ # weights and bias from
1322
+ l = ScaledLinear(
1323
+ in_channels, out_channels, bias=bias, initial_scale=initial_scale
1324
+ )
1325
+
1326
+ self.weight = l.weight
1327
+ # register_parameter properly handles making it a parameter when l.bias
1328
+ # is None. I think there is some reason for doing it this way rather
1329
+ # than just setting it to None but I don't know what it is, maybe
1330
+ # something to do with exporting the module..
1331
+ self.register_parameter("bias", l.bias)
1332
+
1333
+ self.activation = activation
1334
+ self.dropout_p = dropout_p
1335
+ self.dropout_shared_dim = dropout_shared_dim
1336
+
1337
+ def forward(self, x: Tensor):
1338
+ if (
1339
+ torch.jit.is_scripting()
1340
+ or torch.jit.is_tracing()
1341
+ or "k2" not in sys.modules
1342
+ ):
1343
+ if self.activation == "SwooshL":
1344
+ x = SwooshLForward(x)
1345
+ elif self.activation == "SwooshR":
1346
+ x = SwooshRForward(x)
1347
+ else:
1348
+ assert False, self.activation
1349
+ return torch.nn.functional.linear(x, self.weight, self.bias)
1350
+
1351
+ return ActivationDropoutAndLinearFunction.apply(
1352
+ x,
1353
+ self.weight,
1354
+ self.bias,
1355
+ self.activation,
1356
+ float(self.dropout_p),
1357
+ self.dropout_shared_dim,
1358
+ )
1359
+
1360
+
1361
+ def _test_whiten():
1362
+ for proportion in [0.1, 0.5, 10.0]:
1363
+ logging.info(f"_test_whiten(): proportion = {proportion}")
1364
+ x = torch.randn(100, 128)
1365
+ direction = torch.randn(128)
1366
+ coeffs = torch.randn(100, 1)
1367
+ x += proportion * direction * coeffs
1368
+
1369
+ x.requires_grad = True
1370
+
1371
+ m = Whiten(
1372
+ 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit,
1373
+ ) # grad_scale
1374
+
1375
+ for _ in range(4):
1376
+ y = m(x)
1377
+
1378
+ y_grad = torch.randn_like(x)
1379
+ y.backward(gradient=y_grad)
1380
+
1381
+ if proportion < 0.2:
1382
+ assert torch.allclose(x.grad, y_grad)
1383
+ elif proportion > 1.0:
1384
+ assert not torch.allclose(x.grad, y_grad)
1385
+
1386
+
1387
+ def _test_balancer_sign():
1388
+ probs = torch.arange(0, 1, 0.01)
1389
+ N = 1000
1390
+ x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0)
1391
+ x = x.detach()
1392
+ x.requires_grad = True
1393
+ m = Balancer(
1394
+ probs.numel(),
1395
+ channel_dim=0,
1396
+ min_positive=0.05,
1397
+ max_positive=0.95,
1398
+ min_abs=0.0,
1399
+ prob=1.0,
1400
+ )
1401
+
1402
+ y_grad = torch.sign(torch.randn(probs.numel(), N))
1403
+
1404
+ y = m(x)
1405
+ y.backward(gradient=y_grad)
1406
+ print("_test_balancer_sign: x = ", x)
1407
+ print("_test_balancer_sign: y grad = ", y_grad)
1408
+ print("_test_balancer_sign: x grad = ", x.grad)
1409
+
1410
+
1411
+ def _test_balancer_magnitude():
1412
+ magnitudes = torch.arange(0, 1, 0.01)
1413
+ N = 1000
1414
+ x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
1415
+ x = x.detach()
1416
+ x.requires_grad = True
1417
+ m = Balancer(
1418
+ magnitudes.numel(),
1419
+ channel_dim=0,
1420
+ min_positive=0.0,
1421
+ max_positive=1.0,
1422
+ min_abs=0.2,
1423
+ max_abs=0.7,
1424
+ prob=1.0,
1425
+ )
1426
+
1427
+ y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
1428
+
1429
+ y = m(x)
1430
+ y.backward(gradient=y_grad)
1431
+ print("_test_balancer_magnitude: x = ", x)
1432
+ print("_test_balancer_magnitude: y grad = ", y_grad)
1433
+ print("_test_balancer_magnitude: x grad = ", x.grad)
1434
+
1435
+
1436
+ def _test_swooshl_deriv():
1437
+ x = torch.randn(10, 12, dtype=torch.double) * 3.0
1438
+ x.requires_grad = True
1439
+ m = SwooshL()
1440
+
1441
+ tol = 1.0 / 255.0
1442
+ torch.autograd.gradcheck(m, x, atol=tol, eps=0.01)
1443
+
1444
+ # for self-test.
1445
+ x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
1446
+ x.requires_grad = True
1447
+ y = m(x)
1448
+ return y
1449
+
1450
+
1451
+ def _test_swooshr_deriv():
1452
+ x = torch.randn(10, 12, dtype=torch.double) * 3.0
1453
+ x.requires_grad = True
1454
+ m = SwooshR()
1455
+
1456
+ tol = 1.0 / 255.0
1457
+ torch.autograd.gradcheck(m, x, atol=tol, eps=0.01)
1458
+
1459
+ # for self-test.
1460
+ x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
1461
+ x.requires_grad = True
1462
+ y = m(x)
1463
+ return y
1464
+
1465
+
1466
+ def _test_softmax():
1467
+ a = torch.randn(2, 10, dtype=torch.float64)
1468
+ b = a.clone()
1469
+ a.requires_grad = True
1470
+ b.requires_grad = True
1471
+ a.softmax(dim=1)[:, 0].sum().backward()
1472
+ print("a grad = ", a.grad)
1473
+ softmax(b, dim=1)[:, 0].sum().backward()
1474
+ print("b grad = ", b.grad)
1475
+ assert torch.allclose(a.grad, b.grad)
1476
+
1477
+
1478
+ def _test_piecewise_linear():
1479
+ p = PiecewiseLinear((0, 10.0))
1480
+ for x in [-100, 0, 100]:
1481
+ assert p(x) == 10.0
1482
+ p = PiecewiseLinear((0, 10.0), (1, 0.0))
1483
+ for x, y in [(-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0)]:
1484
+ print("x, y = ", x, y)
1485
+ assert p(x) == y, (x, p(x), y)
1486
+
1487
+ q = PiecewiseLinear((0.5, 15.0), (0.6, 1.0))
1488
+ x_vals = [-1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0]
1489
+ pq = p.max(q)
1490
+ for x in x_vals:
1491
+ y1 = max(p(x), q(x))
1492
+ y2 = pq(x)
1493
+ assert abs(y1 - y2) < 0.001
1494
+ pq = p.min(q)
1495
+ for x in x_vals:
1496
+ y1 = min(p(x), q(x))
1497
+ y2 = pq(x)
1498
+ assert abs(y1 - y2) < 0.001
1499
+ pq = p + q
1500
+ for x in x_vals:
1501
+ y1 = p(x) + q(x)
1502
+ y2 = pq(x)
1503
+ assert abs(y1 - y2) < 0.001
1504
+
1505
+
1506
+ def _test_activation_dropout_and_linear():
1507
+ in_channels = 20
1508
+ out_channels = 30
1509
+
1510
+ for bias in [True, False]:
1511
+ # actually we don't test for dropout_p != 0.0 because forward functions will
1512
+ # different answers. This is because we are using the k2 implementation of
1513
+ # swoosh_l an swoosh_r inside SwooshL() and SwooshR(), and they call randn()
1514
+ # internally, messing up the random state.
1515
+ for dropout_p in [0.0]:
1516
+ for activation in ["SwooshL", "SwooshR"]:
1517
+ m1 = nn.Sequential(
1518
+ SwooshL() if activation == "SwooshL" else SwooshR(),
1519
+ Dropout3(p=dropout_p, shared_dim=-1),
1520
+ ScaledLinear(
1521
+ in_channels, out_channels, bias=bias, initial_scale=0.5
1522
+ ),
1523
+ )
1524
+ m2 = ActivationDropoutAndLinear(
1525
+ in_channels,
1526
+ out_channels,
1527
+ bias=bias,
1528
+ initial_scale=0.5,
1529
+ activation=activation,
1530
+ dropout_p=dropout_p,
1531
+ )
1532
+ with torch.no_grad():
1533
+ m2.weight[:] = m1[2].weight
1534
+ if bias:
1535
+ m2.bias[:] = m1[2].bias
1536
+ # make sure forward gives same result.
1537
+ x1 = torch.randn(10, in_channels)
1538
+ x1.requires_grad = True
1539
+
1540
+ # TEMP.
1541
+ assert torch.allclose(
1542
+ SwooshRFunction.apply(x1), SwooshRForward(x1), atol=1.0e-03
1543
+ )
1544
+
1545
+ x2 = x1.clone().detach()
1546
+ x2.requires_grad = True
1547
+ seed = 10
1548
+ torch.manual_seed(seed)
1549
+ y1 = m1(x1)
1550
+ y_grad = torch.randn_like(y1)
1551
+ y1.backward(gradient=y_grad)
1552
+ torch.manual_seed(seed)
1553
+ y2 = m2(x2)
1554
+ y2.backward(gradient=y_grad)
1555
+
1556
+ print(
1557
+ f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}"
1558
+ )
1559
+ print("y1 = ", y1)
1560
+ print("y2 = ", y2)
1561
+ assert torch.allclose(y1, y2, atol=0.02)
1562
+ assert torch.allclose(m1[2].weight.grad, m2.weight.grad, atol=1.0e-05)
1563
+ if bias:
1564
+ assert torch.allclose(m1[2].bias.grad, m2.bias.grad, atol=1.0e-05)
1565
+ print("x1.grad = ", x1.grad)
1566
+ print("x2.grad = ", x2.grad)
1567
+
1568
+ def isclose(a, b):
1569
+ # return true if cosine similarity is > 0.9.
1570
+ return (a * b).sum() > 0.9 * (
1571
+ (a**2).sum() * (b**2).sum()
1572
+ ).sqrt()
1573
+
1574
+ # the SwooshL() implementation has a noisy gradient due to 1-byte
1575
+ # storage of it.
1576
+ assert isclose(x1.grad, x2.grad)
1577
+
1578
+
1579
+ if __name__ == "__main__":
1580
+ logging.getLogger().setLevel(logging.DEBUG)
1581
+ torch.set_num_threads(1)
1582
+ torch.set_num_interop_threads(1)
1583
+ _test_piecewise_linear()
1584
+ _test_softmax()
1585
+ _test_whiten()
1586
+ _test_balancer_sign()
1587
+ _test_balancer_magnitude()
1588
+ _test_swooshr_deriv()
1589
+ _test_swooshl_deriv()
1590
+ _test_activation_dropout_and_linear()
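
For reference, the Swoosh activations defined above are plain elementwise functions; the custom autograd Functions exist only so the backward pass can store the bounded derivative in one uint8 per element. Below is a minimal, self-contained PyTorch sketch (illustrative only, independent of this commit's scaling.py and of k2) that reproduces the two formulas and checks the derivative bounds that the uint8 gradient storage relies on:

```python
import torch

def swoosh_l(x: torch.Tensor) -> torch.Tensor:
    # swoosh_l(x) = log(1 + exp(x - 4)) - 0.08 * x - 0.035
    return torch.logaddexp(torch.zeros_like(x), x - 4.0) - 0.08 * x - 0.035

def swoosh_r(x: torch.Tensor) -> torch.Tensor:
    # swoosh_r(x) = log(1 + exp(x - 1)) - 0.08 * x - 0.313261687
    return torch.logaddexp(torch.zeros_like(x), x - 1.0) - 0.08 * x - 0.313261687

x = torch.linspace(-20.0, 20.0, 4001, requires_grad=True)
for fn in (swoosh_l, swoosh_r):
    (grad,) = torch.autograd.grad(fn(x).sum(), x)
    # Both derivatives are sigmoid(x - offset) - 0.08, i.e. bounded in
    # (-0.08, 0.92); SwooshLFunction/SwooshRFunction quantize this range to
    # uint8 using floor = -0.08 and ceil = 0.925.
    assert grad.min() >= -0.08 and grad.max() <= 0.92
```

The two activations differ only in the offset (4 vs. 1) and in the additive constant, which shifts each curve so it passes close to zero at x = 0.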
zipvoice/models/modules/solver.py ADDED
@@ -0,0 +1,281 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2024 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ from typing import Optional, Union
19
+
20
+ import torch
21
+
22
+
23
+ class DiffusionModel(torch.nn.Module):
24
+ """A wrapper of diffusion models for inference.
25
+ Args:
26
+ model: The diffusion model.
27
+ func_name: The function name to call.
28
+ """
29
+
30
+ def __init__(
31
+ self,
32
+ model: torch.nn.Module,
33
+ func_name: str = "forward_fm_decoder",
34
+ ):
35
+ super().__init__()
36
+ self.model = model
37
+ self.func_name = func_name
38
+ self.model_func = getattr(self.model, func_name)
39
+
40
+ def forward(
41
+ self,
42
+ t: torch.Tensor,
43
+ x: torch.Tensor,
44
+ text_condition: torch.Tensor,
45
+ speech_condition: torch.Tensor,
46
+ padding_mask: Optional[torch.Tensor] = None,
47
+ guidance_scale: Union[float, torch.Tensor] = 0.0,
48
+ **kwargs
49
+ ) -> torch.Tensor:
50
+ """
51
+ Forward function that handles classifier-free guidance.
52
+ Args:
53
+ t: The current timestep, a tensor of a single float.
54
+ x: The initial value, with the shape (batch, seq_len, emb_dim).
55
+ text_condition: The text_condition of the diffusion model, with
56
+ the shape (batch, seq_len, emb_dim).
57
+ speech_condition: The speech_condition of the diffusion model, with the
58
+ shape (batch, seq_len, emb_dim).
59
+ padding_mask: The mask for padding; True means masked position, with the
60
+ shape (batch, seq_len).
61
+ guidance_scale: The scale of classifier-free guidance, a float or a tensor
62
+ of shape (batch, 1, 1).
63
+ Return:
64
+ The prediction with the shape (batch, seq_len, emb_dim).
65
+ """
66
+ if not torch.is_tensor(guidance_scale):
67
+ guidance_scale = torch.tensor(
68
+ guidance_scale, dtype=t.dtype, device=t.device
69
+ )
70
+
71
+ if (guidance_scale == 0.0).all():
72
+ return self.model_func(
73
+ t=t,
74
+ xt=x,
75
+ text_condition=text_condition,
76
+ speech_condition=speech_condition,
77
+ padding_mask=padding_mask,
78
+ **kwargs
79
+ )
80
+ else:
81
+ assert t.dim() == 0
82
+
83
+ x = torch.cat([x] * 2, dim=0)
84
+ padding_mask = torch.cat([padding_mask] * 2, dim=0)
85
+
86
+ text_condition = torch.cat(
87
+ [torch.zeros_like(text_condition), text_condition], dim=0
88
+ )
89
+
90
+ if t > 0.5:
91
+ speech_condition = torch.cat(
92
+ [torch.zeros_like(speech_condition), speech_condition], dim=0
93
+ )
94
+ else:
95
+ guidance_scale = guidance_scale * 2
96
+ speech_condition = torch.cat(
97
+ [speech_condition, speech_condition], dim=0
98
+ )
99
+
100
+ data_uncond, data_cond = self.model_func(
101
+ t=t,
102
+ xt=x,
103
+ text_condition=text_condition,
104
+ speech_condition=speech_condition,
105
+ padding_mask=padding_mask,
106
+ **kwargs
107
+ ).chunk(2, dim=0)
108
+
109
+ res = (1 + guidance_scale) * data_cond - guidance_scale * data_uncond
110
+ return res
111
+
112
+
113
+ class DistillDiffusionModel(DiffusionModel):
114
+ """A wrapper of distilled diffusion models for inference.
115
+ Args:
116
+ model: The distilled diffusion model.
117
+ func_name: The function name to call.
118
+ """
119
+
120
+ def __init__(
121
+ self,
122
+ model: torch.nn.Module,
123
+ func_name: str = "forward_fm_decoder",
124
+ ):
125
+ super().__init__(model=model, func_name=func_name)
126
+
127
+ def forward(
128
+ self,
129
+ t: torch.Tensor,
130
+ x: torch.Tensor,
131
+ text_condition: torch.Tensor,
132
+ speech_condition: torch.Tensor,
133
+ padding_mask: Optional[torch.Tensor] = None,
134
+ guidance_scale: Union[float, torch.Tensor] = 0.0,
135
+ **kwargs
136
+ ) -> torch.Tensor:
137
+ """
138
+ Forward function that handles classifier-free guidance.
139
+ Args:
140
+ t: The current timestep, a tensor of a single float.
141
+ x: The initial value, with the shape (batch, seq_len, emb_dim).
142
+ text_condition: The text_condition of the diffusion model, with
143
+ the shape (batch, seq_len, emb_dim).
144
+ speech_condition: The speech_condition of the diffusion model, with the
145
+ shape (batch, seq_len, emb_dim).
146
+ padding_mask: The mask for padding; True means masked position, with the
147
+ shape (batch, seq_len).
148
+ guidance_scale: The scale of classifier-free guidance, a float or a tensor
149
+ of shape (batch, 1, 1).
150
+ Return:
151
+ The prediction with the shape (batch, seq_len, emb_dim).
152
+ """
153
+ if not torch.is_tensor(guidance_scale):
154
+ guidance_scale = torch.tensor(
155
+ guidance_scale, dtype=t.dtype, device=t.device
156
+ )
157
+ return self.model_func(
158
+ t=t,
159
+ xt=x,
160
+ text_condition=text_condition,
161
+ speech_condition=speech_condition,
162
+ padding_mask=padding_mask,
163
+ guidance_scale=guidance_scale,
164
+ **kwargs
165
+ )
166
+
167
+
168
+ class EulerSolver:
169
+ def __init__(
170
+ self,
171
+ model: torch.nn.Module,
172
+ func_name: str = "forward_fm_decoder",
173
+ ):
174
+ """Construct a Euler Solver
175
+ Args:
176
+ model: The diffusion model.
177
+ func_name: The function name to call.
178
+ """
179
+
180
+ self.model = DiffusionModel(model, func_name=func_name)
181
+
182
+ def sample(
183
+ self,
184
+ x: torch.Tensor,
185
+ text_condition: torch.Tensor,
186
+ speech_condition: torch.Tensor,
187
+ padding_mask: torch.Tensor,
188
+ num_step: int = 10,
189
+ guidance_scale: Union[float, torch.Tensor] = 0.0,
190
+ t_start: float = 0.0,
191
+ t_end: float = 1.0,
192
+ t_shift: float = 1.0,
193
+ **kwargs
194
+ ) -> torch.Tensor:
195
+ """
196
+ Compute the sample at time `t_end` by Euler Solver.
197
+ Args:
198
+ x: The initial value at time `t_start`, with the shape (batch, seq_len,
199
+ emb_dim).
200
+ text_condition: The text condition of the diffusion model, with the
201
+ shape (batch, seq_len, emb_dim).
202
+ speech_condition: The speech condition of the diffusion model, with the
203
+ shape (batch, seq_len, emb_dim).
204
+ padding_mask: The mask for padding; True means masked position, with the
205
+ shape (batch, seq_len).
206
+ num_step: The number of ODE steps.
207
+ guidance_scale: The scale for classifier-free guidance, which is
208
+ a float or a tensor with the shape (batch, 1, 1).
209
+ t_start: the start timestep in the range of [0, 1].
210
+ t_end: the end time_step in the range of [0, 1].
211
+ t_shift: shift the t toward smaller numbers so that the sampling
212
+ will emphasize low SNR region. Should be in the range of (0, 1].
213
+ The shifting will be more significant when the number is smaller.
214
+
215
+ Returns:
216
+ The approximated solution at time `t_end`.
217
+ """
218
+ device = x.device
219
+ assert isinstance(t_start, float) and isinstance(t_end, float)
220
+
221
+ timesteps = get_time_steps(
222
+ t_start=t_start,
223
+ t_end=t_end,
224
+ num_step=num_step,
225
+ t_shift=t_shift,
226
+ device=device,
227
+ )
228
+
229
+ for step in range(num_step):
230
+ v = self.model(
231
+ t=timesteps[step],
232
+ x=x,
233
+ text_condition=text_condition,
234
+ speech_condition=speech_condition,
235
+ padding_mask=padding_mask,
236
+ guidance_scale=guidance_scale,
237
+ **kwargs
238
+ )
239
+ x = x + v * (timesteps[step + 1] - timesteps[step])
240
+ return x
241
+
242
+
243
+ class DistillEulerSolver(EulerSolver):
244
+ def __init__(
245
+ self,
246
+ model: torch.nn.Module,
247
+ func_name: str = "forward_fm_decoder",
248
+ ):
249
+ """Construct a Euler Solver for distilled diffusion models.
250
+ Args:
251
+ model: The diffusion model.
252
+ """
253
+ self.model = DistillDiffusionModel(model, func_name=func_name)
254
+
255
+
256
+ def get_time_steps(
257
+ t_start: float = 0.0,
258
+ t_end: float = 1.0,
259
+ num_step: int = 10,
260
+ t_shift: float = 1.0,
261
+ device: torch.device = torch.device("cpu"),
262
+ ) -> torch.Tensor:
263
+ """Compute the intermediate time steps for sampling.
264
+
265
+ Args:
266
+ t_start: The starting time of the sampling (default is 0).
267
+ t_end: The ending time of the sampling (default is 1).
268
+ num_step: The number of sampling steps.
269
+ t_shift: shift the t toward smaller numbers so that the sampling
270
+ will emphasize low SNR region. Should be in the range of (0, 1].
271
+ The shifting will be more significant when the number is smaller.
272
+ device: A torch device.
273
+ Returns:
274
+ The time step with the shape (num_step + 1,).
275
+ """
276
+
277
+ timesteps = torch.linspace(t_start, t_end, num_step + 1).to(device)
278
+
279
+ timesteps = t_shift * timesteps / (1 + (t_shift - 1) * timesteps)
280
+
281
+ return timesteps
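
To make the two formulas in this file concrete, here is a small, self-contained PyTorch sketch of the t_shift warping used in get_time_steps and of one Euler step with the classifier-free-guidance combination applied in DiffusionModel.forward. The helper names below are illustrative only and are not part of the commit:

```python
import torch

def shifted_time_steps(num_step: int, t_shift: float) -> torch.Tensor:
    # Same warping as get_time_steps: t <- s*t / (1 + (s - 1)*t).
    # It keeps t = 0 and t = 1 fixed; for 0 < s < 1 it pulls intermediate
    # steps toward 0, spending more steps in the low-SNR region.
    t = torch.linspace(0.0, 1.0, num_step + 1)
    return t_shift * t / (1 + (t_shift - 1) * t)

print(shifted_time_steps(4, 1.0))  # tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])
print(shifted_time_steps(4, 0.5))  # tensor([0.0000, 0.1429, 0.3333, 0.6000, 1.0000])

def euler_cfg_step(x, v_cond, v_uncond, guidance_scale, t_cur, t_next):
    # Classifier-free guidance followed by one explicit Euler update,
    # mirroring DiffusionModel.forward and the loop in EulerSolver.sample.
    v = (1 + guidance_scale) * v_cond - guidance_scale * v_uncond
    return x + v * (t_next - t_cur)
```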
zipvoice/models/modules/zipformer.py ADDED
@@ -0,0 +1,1680 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2022-2024 Xiaomi Corp. (authors: Daniel Povey,
3
+ # Zengwei Yao,
4
+ # Wei Kang
5
+ # Han Zhu)
6
+ #
7
+ # See ../../../../LICENSE for clarification regarding multiple authors
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+
21
+ import copy
22
+ import logging
23
+ import math
24
+ import random
25
+ from typing import Optional, Tuple, Union
26
+
27
+ import torch
28
+ from torch import Tensor, nn
29
+
30
+ from zipvoice.models.modules.scaling import (
31
+ ActivationDropoutAndLinear,
32
+ Balancer,
33
+ BiasNorm,
34
+ Dropout2,
35
+ FloatLike,
36
+ Identity,
37
+ ScaledLinear,
38
+ ScheduledFloat,
39
+ SwooshR,
40
+ Whiten,
41
+ limit_param_value,
42
+ penalize_abs_values_gt,
43
+ softmax,
44
+ )
45
+
46
+
47
+ def timestep_embedding(timesteps, dim, max_period=10000):
48
+ """Create sinusoidal timestep embeddings.
49
+
50
+ :param timesteps: shape of (N) or (N, T)
51
+ :param dim: the dimension of the output.
52
+ :param max_period: controls the minimum frequency of the embeddings.
53
+ :return: a Tensor of positional embeddings, of shape (N, dim) or (T, N, dim)
54
+ """
55
+ half = dim // 2
56
+ freqs = torch.exp(
57
+ -math.log(max_period)
58
+ * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device)
59
+ / half
60
+ )
61
+
62
+ if timesteps.dim() == 2:
63
+ timesteps = timesteps.transpose(0, 1) # (N, T) -> (T, N)
64
+
65
+ args = timesteps[..., None].float() * freqs[None]
66
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
67
+ if dim % 2:
68
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[..., :1])], dim=-1)
69
+ return embedding
70
+
71
+
72
+ class TTSZipformer(nn.Module):
73
+ """
74
+ Args:
75
+
76
+ Note: all "int or Tuple[int]" arguments below will be treated as lists of the same
77
+ length as downsampling_factor if they are single ints or one-element tuples.
78
+ The length of downsampling_factor defines the number of stacks.
79
+
80
+ downsampling_factor (Tuple[int]): downsampling factor for each encoder stack.
81
+ Note: this is in addition to the downsampling factor of 2 that is applied in
82
+ the frontend (self.encoder_embed).
83
+ encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks,
84
+ one per encoder stack.
85
+ num_encoder_layers (int or Tuple[int]): number of encoder layers for each stack
86
+ query_head_dim (int or Tuple[int]): dimension of query and key per attention
87
+ head; per stack, if a tuple.
88
+ pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection
89
+ per attention head
90
+ value_head_dim (int or Tuple[int]): dimension of value in each attention head
91
+ num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism.
92
+ Must be at least 4.
93
+ feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules
94
+ cnn_module_kernel (int or Tuple[int]): Kernel size of convolution module
95
+
96
+ pos_dim (int): the dimension of each positional-encoding vector prior to
97
+ projection, e.g. 128.
98
+
99
+ dropout (float): dropout rate
100
+ warmup_batches (float): number of batches to warm up over; this controls
101
+ dropout of encoder layers.
102
+ use_time_embed: (bool): if True, take time embedding as an additional input.
103
+ time_embed_dim: (int): the dimension of the time embedding.
104
+ use_guidance_scale_embed (bool): if True, take guidance scale embedding as
105
+ an additional input.
106
+ guidance_scale_embed_dim: (int): the dimension of the guidance scale embedding.
107
+ """
108
+
109
+ def __init__(
110
+ self,
111
+ in_dim: int,
112
+ out_dim: int,
113
+ downsampling_factor: Union[int, Tuple[int]] = (2, 4),
114
+ num_encoder_layers: Union[int, Tuple[int]] = 4,
115
+ cnn_module_kernel: Union[int, Tuple[int]] = 31,
116
+ encoder_dim: int = 384,
117
+ query_head_dim: int = 24,
118
+ pos_head_dim: int = 4,
119
+ value_head_dim: int = 12,
120
+ num_heads: int = 8,
121
+ feedforward_dim: int = 1536,
122
+ pos_dim: int = 192,
123
+ dropout: FloatLike = None, # see code below for default
124
+ warmup_batches: float = 4000.0,
125
+ use_time_embed: bool = True,
126
+ time_embed_dim: int = 192,
127
+ use_guidance_scale_embed: bool = False,
128
+ guidance_scale_embed_dim: int = 192,
129
+ use_conv: bool = True,
130
+ ) -> None:
131
+ super(TTSZipformer, self).__init__()
132
+
133
+ if dropout is None:
134
+ dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1))
135
+ if isinstance(downsampling_factor, int):
136
+ downsampling_factor = (downsampling_factor,)
137
+
138
+ def _to_tuple(x):
139
+ """Converts a single int or a 1-tuple of an int to a tuple with the same
140
+ length as downsampling_factor"""
141
+ if isinstance(x, int):
142
+ x = (x,)
143
+ if len(x) == 1:
144
+ x = x * len(downsampling_factor)
145
+ else:
146
+ assert len(x) == len(downsampling_factor) and isinstance(x[0], int)
147
+ return x
148
+
149
+ def _assert_downsampling_factor(factors):
150
+ """assert downsampling_factor follows u-net style"""
151
+ assert factors[0] == 1 and factors[-1] == 1
152
+
153
+ for i in range(1, len(factors) // 2 + 1):
154
+ assert factors[i] == factors[i - 1] * 2
155
+
156
+ for i in range(len(factors) // 2 + 1, len(factors)):
157
+ assert factors[i] * 2 == factors[i - 1]
158
+
159
+ _assert_downsampling_factor(downsampling_factor)
160
+ self.downsampling_factor = downsampling_factor # tuple
161
+ num_encoder_layers = _to_tuple(num_encoder_layers)
162
+ self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)
163
+ self.encoder_dim = encoder_dim
164
+ self.num_encoder_layers = num_encoder_layers
165
+ self.query_head_dim = query_head_dim
166
+ self.value_head_dim = value_head_dim
167
+ self.num_heads = num_heads
168
+
169
+ self.use_time_embed = use_time_embed
170
+ self.use_guidance_scale_embed = use_guidance_scale_embed
171
+
172
+ self.time_embed_dim = time_embed_dim
173
+ if self.use_time_embed:
174
+ assert time_embed_dim != -1
175
+ else:
176
+ time_embed_dim = -1
177
+ self.guidance_scale_embed_dim = guidance_scale_embed_dim
178
+
179
+ self.in_proj = nn.Linear(in_dim, encoder_dim)
180
+ self.out_proj = nn.Linear(encoder_dim, out_dim)
181
+
182
+ # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder
183
+ encoders = []
184
+
185
+ num_encoders = len(downsampling_factor)
186
+ for i in range(num_encoders):
187
+ encoder_layer = Zipformer2EncoderLayer(
188
+ embed_dim=encoder_dim,
189
+ pos_dim=pos_dim,
190
+ num_heads=num_heads,
191
+ query_head_dim=query_head_dim,
192
+ pos_head_dim=pos_head_dim,
193
+ value_head_dim=value_head_dim,
194
+ feedforward_dim=feedforward_dim,
195
+ use_conv=use_conv,
196
+ cnn_module_kernel=cnn_module_kernel[i],
197
+ dropout=dropout,
198
+ )
199
+
200
+ # For the segment of the warmup period, we let the Conv2dSubsampling
201
+ # layer learn something. Then we start to warm up the other encoders.
202
+ encoder = Zipformer2Encoder(
203
+ encoder_layer,
204
+ num_encoder_layers[i],
205
+ embed_dim=encoder_dim,
206
+ time_embed_dim=time_embed_dim,
207
+ pos_dim=pos_dim,
208
+ warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
209
+ warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
210
+ final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
211
+ )
212
+
213
+ if downsampling_factor[i] != 1:
214
+ encoder = DownsampledZipformer2Encoder(
215
+ encoder,
216
+ dim=encoder_dim,
217
+ downsample=downsampling_factor[i],
218
+ )
219
+
220
+ encoders.append(encoder)
221
+
222
+ self.encoders = nn.ModuleList(encoders)
223
+ if self.use_time_embed:
224
+ self.time_embed = nn.Sequential(
225
+ nn.Linear(time_embed_dim, time_embed_dim * 2),
226
+ SwooshR(),
227
+ nn.Linear(time_embed_dim * 2, time_embed_dim),
228
+ )
229
+ else:
230
+ self.time_embed = None
231
+
232
+ if self.use_guidance_scale_embed:
233
+ self.guidance_scale_embed = ScaledLinear(
234
+ guidance_scale_embed_dim,
235
+ time_embed_dim,
236
+ bias=False,
237
+ initial_scale=0.1,
238
+ )
239
+ else:
240
+ self.guidance_scale_embed = None
241
+
242
+ def forward(
243
+ self,
244
+ x: Tensor,
245
+ t: Optional[Tensor] = None,
246
+ padding_mask: Optional[Tensor] = None,
247
+ guidance_scale: Optional[Tensor] = None,
248
+ ) -> Tensor:
249
+ """
250
+ Args:
251
+ x:
252
+ The input tensor. Its shape is (batch_size, seq_len, feature_dim).
253
+ t:
254
+ A t tensor of shape (batch_size,) or (batch_size, seq_len)
255
+ padding_mask:
256
+ The mask for padding, of shape (batch_size, seq_len); True means
257
+ masked position. May be None.
258
+ guidance_scale:
259
+ The guidance scale in classifier-free guidance of distillation model.
260
+ Returns:
261
+ Return the output embeddings. Its shape is
262
+ (batch_size, output_seq_len, encoder_dim)
263
+ """
264
+ x = x.permute(1, 0, 2)
265
+ x = self.in_proj(x)
266
+
267
+ if t is not None:
268
+ assert t.dim() == 1 or t.dim() == 2, t.shape
269
+ time_emb = timestep_embedding(t, self.time_embed_dim)
270
+ if guidance_scale is not None:
271
+ assert (
272
+ guidance_scale.dim() == 1 or guidance_scale.dim() == 2
273
+ ), guidance_scale.shape
274
+ guidance_scale_emb = self.guidance_scale_embed(
275
+ timestep_embedding(guidance_scale, self.guidance_scale_embed_dim)
276
+ )
277
+ time_emb = time_emb + guidance_scale_emb
278
+ time_emb = self.time_embed(time_emb)
279
+ else:
280
+ time_emb = None
281
+
282
+ attn_mask = None
283
+
284
+ for i, module in enumerate(self.encoders):
285
+ x = module(
286
+ x,
287
+ time_emb=time_emb,
288
+ src_key_padding_mask=padding_mask,
289
+ attn_mask=attn_mask,
290
+ )
291
+ x = self.out_proj(x)
292
+ x = x.permute(1, 0, 2)
293
+ return x
294
+
295
+
296
+ def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat:
297
+ return ScheduledFloat((0.0, x), (20000.0, ratio * x), default=x)
298
+
299
+
300
+ class Zipformer2EncoderLayer(nn.Module):
301
+ """
302
+ Args:
303
+ embed_dim: the number of expected features in the input (required).
304
+ nhead: the number of heads in the multiheadattention models (required).
305
+ feedforward_dim: the dimension of the feedforward network model (required).
306
+ dropout: the dropout value (default=0.1).
307
+ cnn_module_kernel (int): Kernel size of convolution module (default=31).
308
+
309
+ Examples::
310
+ >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8)
311
+ >>> src = torch.rand(10, 32, 512)
312
+ >>> pos_emb = torch.rand(32, 19, 512)
313
+ >>> out = encoder_layer(src, pos_emb)
314
+ """
315
+
316
+ def __init__(
317
+ self,
318
+ embed_dim: int,
319
+ pos_dim: int,
320
+ num_heads: int,
321
+ query_head_dim: int,
322
+ pos_head_dim: int,
323
+ value_head_dim: int,
324
+ feedforward_dim: int,
325
+ dropout: FloatLike = 0.1,
326
+ cnn_module_kernel: int = 31,
327
+ use_conv: bool = True,
328
+ attention_skip_rate: FloatLike = ScheduledFloat(
329
+ (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0
330
+ ),
331
+ conv_skip_rate: FloatLike = ScheduledFloat(
332
+ (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0
333
+ ),
334
+ const_attention_rate: FloatLike = ScheduledFloat(
335
+ (0.0, 0.25), (4000.0, 0.025), default=0
336
+ ),
337
+ ff2_skip_rate: FloatLike = ScheduledFloat(
338
+ (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)
339
+ ),
340
+ ff3_skip_rate: FloatLike = ScheduledFloat(
341
+ (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)
342
+ ),
343
+ bypass_skip_rate: FloatLike = ScheduledFloat(
344
+ (0.0, 0.5), (4000.0, 0.02), default=0
345
+ ),
346
+ ) -> None:
347
+ super(Zipformer2EncoderLayer, self).__init__()
348
+ self.embed_dim = embed_dim
349
+
350
+ # self.bypass implements layer skipping as well as bypass.
351
+ self.bypass = BypassModule(
352
+ embed_dim, skip_rate=bypass_skip_rate, straight_through_rate=0
353
+ )
354
+ # bypass_mid is bypass used in the middle of the layer.
355
+ self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0)
356
+
357
+ # skip probability for dynamic modules (meaning: anything but feedforward).
358
+ self.attention_skip_rate = copy.deepcopy(attention_skip_rate)
359
+ # an additional skip probability that applies to ConvModule to stop it from
360
+ # contributing too much early on.
361
+ self.conv_skip_rate = copy.deepcopy(conv_skip_rate)
362
+
363
+ # ff2_skip_rate is to prevent the ff2 module from having output that's too big
364
+ # compared to its residual.
365
+ self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate)
366
+ self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate)
367
+
368
+ self.const_attention_rate = copy.deepcopy(const_attention_rate)
369
+
370
+ self.self_attn_weights = RelPositionMultiheadAttentionWeights(
371
+ embed_dim,
372
+ pos_dim=pos_dim,
373
+ num_heads=num_heads,
374
+ query_head_dim=query_head_dim,
375
+ pos_head_dim=pos_head_dim,
376
+ dropout=0.0,
377
+ )
378
+
379
+ self.self_attn1 = SelfAttention(embed_dim, num_heads, value_head_dim)
380
+
381
+ self.self_attn2 = SelfAttention(embed_dim, num_heads, value_head_dim)
382
+
383
+ self.feed_forward1 = FeedforwardModule(
384
+ embed_dim, (feedforward_dim * 3) // 4, dropout
385
+ )
386
+
387
+ self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout)
388
+
389
+ self.feed_forward3 = FeedforwardModule(
390
+ embed_dim, (feedforward_dim * 5) // 4, dropout
391
+ )
392
+
393
+ self.nonlin_attention = NonlinAttention(
394
+ embed_dim, hidden_channels=3 * embed_dim // 4
395
+ )
396
+
397
+ self.use_conv = use_conv
398
+
399
+ if self.use_conv:
400
+ self.conv_module1 = ConvolutionModule(embed_dim, cnn_module_kernel)
401
+
402
+ self.conv_module2 = ConvolutionModule(embed_dim, cnn_module_kernel)
403
+
404
+ self.norm = BiasNorm(embed_dim)
405
+
406
+ self.balancer1 = Balancer(
407
+ embed_dim,
408
+ channel_dim=-1,
409
+ min_positive=0.45,
410
+ max_positive=0.55,
411
+ min_abs=0.2,
412
+ max_abs=4.0,
413
+ )
414
+
415
+ # balancer for output of NonlinAttentionModule
416
+ self.balancer_na = Balancer(
417
+ embed_dim,
418
+ channel_dim=-1,
419
+ min_positive=0.3,
420
+ max_positive=0.7,
421
+ min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)),
422
+ prob=0.05, # out of concern for memory usage
423
+ )
424
+
425
+ # balancer for output of feedforward2, prevent it from staying too
426
+ # small. give this a very small probability, even at the start of
427
+ # training, it's to fix a rare problem and it's OK to fix it slowly.
428
+ self.balancer_ff2 = Balancer(
429
+ embed_dim,
430
+ channel_dim=-1,
431
+ min_positive=0.3,
432
+ max_positive=0.7,
433
+ min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0),
434
+ max_abs=2.0,
435
+ prob=0.05,
436
+ )
437
+
438
+ self.balancer_ff3 = Balancer(
439
+ embed_dim,
440
+ channel_dim=-1,
441
+ min_positive=0.3,
442
+ max_positive=0.7,
443
+ min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0),
444
+ max_abs=4.0,
445
+ prob=0.05,
446
+ )
447
+
448
+ self.whiten = Whiten(
449
+ num_groups=1,
450
+ whitening_limit=_whitening_schedule(4.0, ratio=3.0),
451
+ prob=(0.025, 0.25),
452
+ grad_scale=0.01,
453
+ )
454
+
455
+ self.balancer2 = Balancer(
456
+ embed_dim,
457
+ channel_dim=-1,
458
+ min_positive=0.45,
459
+ max_positive=0.55,
460
+ min_abs=0.1,
461
+ max_abs=4.0,
462
+ )
463
+
464
+ def get_sequence_dropout_mask(
465
+ self, x: Tensor, dropout_rate: float
466
+ ) -> Optional[Tensor]:
467
+ if (
468
+ dropout_rate == 0.0
469
+ or not self.training
470
+ or torch.jit.is_scripting()
471
+ or torch.jit.is_tracing()
472
+ ):
473
+ return None
474
+ batch_size = x.shape[1]
475
+ mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype)
476
+ return mask
477
+
478
+ def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor:
479
+ """
480
+ Apply sequence-level dropout to x.
481
+ x shape: (seq_len, batch_size, embed_dim)
482
+ """
483
+ dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate)
484
+ if dropout_mask is None:
485
+ return x
486
+ else:
487
+ return x * dropout_mask
488
+
489
+ def forward(
490
+ self,
491
+ src: Tensor,
492
+ pos_emb: Tensor,
493
+ time_emb: Optional[Tensor] = None,
494
+ attn_mask: Optional[Tensor] = None,
495
+ src_key_padding_mask: Optional[Tensor] = None,
496
+ ) -> Tensor:
497
+ """
498
+ Pass the input through the encoder layer.
499
+ Args:
500
+ src: the sequence to the encoder (required):
501
+ shape (seq_len, batch_size, embedding_dim).
502
+ pos_emb: (1, 2*seq_len-1, pos_emb_dim) or
503
+ (batch_size, 2*seq_len-1, pos_emb_dim)
504
+ time_emb: the embedding representing the current timestep
505
+ shape (batch_size, embedding_dim) or (seq_len, batch_size, embedding_dim).
506
+ attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len)
507
+ or (seq_len, seq_len), interpreted as (batch_size, tgt_seq_len, src_seq_len)
508
+ or (tgt_seq_len, src_seq_len). True means masked position. May be None.
509
+ src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len);
510
+ True means masked position. May be None.
511
+
512
+ Returns:
513
+ A tensor which has the same shape as src
514
+ """
515
+ src_orig = src
516
+
517
+ # dropout rate for non-feedforward submodules
518
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
519
+ attention_skip_rate = 0.0
520
+ else:
521
+ attention_skip_rate = (
522
+ float(self.attention_skip_rate) if self.training else 0.0
523
+ )
524
+
525
+ # attn_weights: (num_heads, batch_size, seq_len, seq_len)
526
+ attn_weights = self.self_attn_weights(
527
+ src,
528
+ pos_emb=pos_emb,
529
+ attn_mask=attn_mask,
530
+ key_padding_mask=src_key_padding_mask,
531
+ )
532
+ if time_emb is not None:
533
+
534
+ src = src + time_emb
535
+
536
+ src = src + self.feed_forward1(src)
537
+
538
+ self_attn_dropout_mask = self.get_sequence_dropout_mask(
539
+ src, attention_skip_rate
540
+ )
541
+
542
+ selected_attn_weights = attn_weights[0:1]
543
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
544
+ pass
545
+ elif self.training and random.random() < float(self.const_attention_rate):
546
+ # Make attention weights constant. The intention is to
547
+ # encourage these modules to do something similar to an
548
+ # averaging-over-time operation.
549
+ # only need the mask, can just use the 1st one and expand later
550
+ selected_attn_weights = selected_attn_weights[0:1]
551
+ selected_attn_weights = (selected_attn_weights > 0.0).to(
552
+ selected_attn_weights.dtype
553
+ )
554
+ selected_attn_weights = selected_attn_weights * (
555
+ 1.0 / selected_attn_weights.sum(dim=-1, keepdim=True)
556
+ )
557
+
558
+ na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights))
559
+
560
+ src = src + (
561
+ na if self_attn_dropout_mask is None else na * self_attn_dropout_mask
562
+ )
563
+
564
+ self_attn = self.self_attn1(src, attn_weights)
565
+
566
+ src = src + (
567
+ self_attn
568
+ if self_attn_dropout_mask is None
569
+ else self_attn * self_attn_dropout_mask
570
+ )
571
+
572
+ if self.use_conv:
573
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
574
+ conv_skip_rate = 0.0
575
+ else:
576
+ conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
577
+
578
+ if time_emb is not None:
579
+ src = src + time_emb
580
+
581
+ src = src + self.sequence_dropout(
582
+ self.conv_module1(
583
+ src,
584
+ src_key_padding_mask=src_key_padding_mask,
585
+ ),
586
+ conv_skip_rate,
587
+ )
588
+
589
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
590
+ ff2_skip_rate = 0.0
591
+ else:
592
+ ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0
593
+ src = src + self.sequence_dropout(
594
+ self.balancer_ff2(self.feed_forward2(src)), ff2_skip_rate
595
+ )
596
+
597
+ # bypass in the middle of the layer.
598
+ src = self.bypass_mid(src_orig, src)
599
+
600
+ self_attn = self.self_attn2(src, attn_weights)
601
+
602
+ src = src + (
603
+ self_attn
604
+ if self_attn_dropout_mask is None
605
+ else self_attn * self_attn_dropout_mask
606
+ )
607
+
608
+ if self.use_conv:
609
+
610
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
611
+ conv_skip_rate = 0.0
612
+ else:
613
+ conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
614
+
615
+ if time_emb is not None:
616
+ src = src + time_emb
617
+
618
+ src = src + self.sequence_dropout(
619
+ self.conv_module2(
620
+ src,
621
+ src_key_padding_mask=src_key_padding_mask,
622
+ ),
623
+ conv_skip_rate,
624
+ )
625
+
626
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
627
+ ff3_skip_rate = 0.0
628
+ else:
629
+ ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0
630
+ src = src + self.sequence_dropout(
631
+ self.balancer_ff3(self.feed_forward3(src)), ff3_skip_rate
632
+ )
633
+
634
+ src = self.balancer1(src)
635
+ src = self.norm(src)
636
+
637
+ src = self.bypass(src_orig, src)
638
+
639
+ src = self.balancer2(src)
640
+ src = self.whiten(src)
641
+
642
+ return src
643
+
644
+
645
+ class Zipformer2Encoder(nn.Module):
646
+ r"""Zipformer2Encoder is a stack of N encoder layers
647
+
648
+ Args:
649
+ encoder_layer: an instance of the Zipformer2EncoderLayer() class (required).
650
+ num_layers: the number of sub-encoder-layers in the encoder (required).
651
+ pos_dim: the dimension for the relative positional encoding
652
+
653
+ Examples::
654
+ >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8)
655
+ >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6)
656
+ >>> src = torch.rand(10, 32, 512)
657
+ >>> out = zipformer_encoder(src)
658
+ """
659
+
660
+ def __init__(
661
+ self,
662
+ encoder_layer: nn.Module,
663
+ num_layers: int,
664
+ embed_dim: int,
665
+ time_embed_dim: int,
666
+ pos_dim: int,
667
+ warmup_begin: float,
668
+ warmup_end: float,
669
+ initial_layerdrop_rate: float = 0.5,
670
+ final_layerdrop_rate: float = 0.05,
671
+ ) -> None:
672
+ super().__init__()
673
+ self.encoder_pos = CompactRelPositionalEncoding(
674
+ pos_dim, dropout_rate=0.15, length_factor=1.0
675
+ )
676
+ if time_embed_dim != -1:
677
+ self.time_emb = nn.Sequential(
678
+ SwooshR(),
679
+ nn.Linear(time_embed_dim, embed_dim),
680
+ )
681
+ else:
682
+ self.time_emb = None
683
+
684
+ self.layers = nn.ModuleList(
685
+ [copy.deepcopy(encoder_layer) for i in range(num_layers)]
686
+ )
687
+ self.num_layers = num_layers
688
+
689
+ assert 0 <= warmup_begin <= warmup_end
690
+
691
+ delta = (1.0 / num_layers) * (warmup_end - warmup_begin)
692
+ cur_begin = warmup_begin # interpreted as a training batch index
693
+ for i in range(num_layers):
694
+ cur_end = cur_begin + delta
695
+ self.layers[i].bypass.skip_rate = ScheduledFloat(
696
+ (cur_begin, initial_layerdrop_rate),
697
+ (cur_end, final_layerdrop_rate),
698
+ default=0.0,
699
+ )
700
+ cur_begin = cur_end
701
+
702
+ def forward(
703
+ self,
704
+ src: Tensor,
705
+ time_emb: Optional[Tensor] = None,
706
+ attn_mask: Optional[Tensor] = None,
707
+ src_key_padding_mask: Optional[Tensor] = None,
708
+ ) -> Tensor:
709
+ r"""Pass the input through the encoder layers in turn.
710
+
711
+ Args:
712
+ src: the sequence to the encoder (required):
713
+ shape (seq_len, batch_size, embedding_dim).
714
+ time_emb: the embedding representing the current timestep:
715
+ shape (batch_size, embedding_dim)
716
+ or (seq_len, batch_size, embedding_dim) .
717
+ attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len)
718
+ or (seq_len, seq_len), interpreted as
719
+ (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
720
+ True means masked position. May be None.
721
+ src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len);
722
+ True means masked position. May be None.
723
+
724
+ Returns: a Tensor with the same shape as src.
725
+ """
726
+ pos_emb = self.encoder_pos(src)
727
+ if self.time_emb is not None:
728
+ assert time_emb is not None
729
+ time_emb = self.time_emb(time_emb)
730
+ else:
731
+ assert time_emb is None
732
+
733
+ output = src
734
+
735
+ for i, mod in enumerate(self.layers):
736
+ output = mod(
737
+ output,
738
+ pos_emb,
739
+ time_emb=time_emb,
740
+ attn_mask=attn_mask,
741
+ src_key_padding_mask=src_key_padding_mask,
742
+ )
743
+
744
+ return output
745
+
746
+
747
+ class BypassModule(nn.Module):
748
+ """
749
+ An nn.Module that implements a learnable bypass scale, and also randomized
750
+ per-sequence layer-skipping. The bypass is limited during early stages of training
751
+ to be close to "straight-through", i.e. to not do the bypass operation much
752
+ initially, in order to force all the modules to learn something.
753
+ """
754
+
755
+ def __init__(
756
+ self,
757
+ embed_dim: int,
758
+ skip_rate: FloatLike = 0.0,
759
+ straight_through_rate: FloatLike = 0.0,
760
+ scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0),
761
+ scale_max: FloatLike = 1.0,
762
+ ):
763
+ super().__init__()
764
+ self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
765
+ self.skip_rate = copy.deepcopy(skip_rate)
766
+ self.straight_through_rate = copy.deepcopy(straight_through_rate)
767
+ self.scale_min = copy.deepcopy(scale_min)
768
+ self.scale_max = copy.deepcopy(scale_max)
769
+
770
+ def _get_bypass_scale(self, batch_size: int):
771
+ # returns bypass-scale of shape (num_channels,),
772
+ # or (batch_size, num_channels,). This is actually the
773
+ # scale on the non-residual term, so 0 corresponds to bypassing
774
+ # this module.
775
+ if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
776
+ return self.bypass_scale
777
+ else:
778
+ ans = limit_param_value(
779
+ self.bypass_scale,
780
+ min=float(self.scale_min),
781
+ max=float(self.scale_max),
782
+ )
783
+ skip_rate = float(self.skip_rate)
784
+ if skip_rate != 0.0:
785
+ mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate
786
+ ans = ans * mask
787
+ # now ans is of shape (batch_size, num_channels), and is zero for
788
+ # sequences on which we have randomly chosen to do layer-skipping.
789
+ straight_through_rate = float(self.straight_through_rate)
790
+ if straight_through_rate != 0.0:
791
+ mask = (
792
+ torch.rand((batch_size, 1), device=ans.device)
793
+ < straight_through_rate
794
+ )
795
+ ans = torch.maximum(ans, mask.to(ans.dtype))
796
+ return ans
797
+
798
+ def forward(self, src_orig: Tensor, src: Tensor):
799
+ """
800
+ Args: src_orig and src are both of shape (seq_len, batch_size, num_channels)
801
+ Returns: something with the same shape as src and src_orig
802
+ """
803
+ bypass_scale = self._get_bypass_scale(src.shape[1])
804
+ return src_orig + (src - src_orig) * bypass_scale
805
+
806
+
807
+ class DownsampledZipformer2Encoder(nn.Module):
808
+ r"""
809
+ DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame
810
+ rate, after convolutional downsampling, and then upsampled again at the output, and
811
+ combined with the original input, so that the output has the same shape as the input.
812
+ """
813
+
814
+ def __init__(self, encoder: nn.Module, dim: int, downsample: int):
815
+ super(DownsampledZipformer2Encoder, self).__init__()
816
+ self.downsample_factor = downsample
817
+ self.downsample = SimpleDownsample(downsample)
818
+ self.num_layers = encoder.num_layers
819
+ self.encoder = encoder
820
+ self.upsample = SimpleUpsample(downsample)
821
+ self.out_combiner = BypassModule(dim, straight_through_rate=0)
822
+
823
+ def forward(
824
+ self,
825
+ src: Tensor,
826
+ time_emb: Optional[Tensor] = None,
827
+ attn_mask: Optional[Tensor] = None,
828
+ src_key_padding_mask: Optional[Tensor] = None,
829
+ ) -> Tensor:
830
+ r"""Downsample, go through encoder, upsample.
831
+
832
+ Args:
833
+ src: the sequence to the encoder (required):
834
+ shape (seq_len, batch_size, embedding_dim).
835
+ time_emb: the embedding representing the current timestep:
836
+ shape (batch_size, embedding_dim)
837
+ or (seq_len, batch_size, embedding_dim) .
838
+ feature_mask: something that broadcasts with src, that we'll multiply `src`
839
+ by at every layer: if a Tensor, likely of shape
840
+ (seq_len, batch_size, embedding_dim)
841
+ attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len)
842
+ or (seq_len, seq_len), interpreted as
843
+ (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
844
+ True means masked position. May be None.
845
+ src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len);
846
+ True means masked position. May be None.
847
+
848
+ Returns: a Tensor with the same shape as src.
849
+ """
850
+ src_orig = src
851
+ src = self.downsample(src)
852
+ ds = self.downsample_factor
853
+ if time_emb is not None and time_emb.dim() == 3:
854
+ time_emb = time_emb[::ds]
855
+ if attn_mask is not None:
856
+ attn_mask = attn_mask[::ds, ::ds]
857
+ if src_key_padding_mask is not None:
858
+ src_key_padding_mask = src_key_padding_mask[..., ::ds]
859
+
860
+ src = self.encoder(
861
+ src,
862
+ time_emb=time_emb,
863
+ attn_mask=attn_mask,
864
+ src_key_padding_mask=src_key_padding_mask,
865
+ )
866
+ src = self.upsample(src)
867
+ # remove any extra frames that are not a multiple of downsample_factor
868
+ src = src[: src_orig.shape[0]]
869
+
870
+ return self.out_combiner(src_orig, src)
871
+
872
+
873
+ class SimpleDownsample(torch.nn.Module):
874
+ """
875
+ Does downsampling with attention, by weighted sum.
876
+ """
877
+
878
+ def __init__(self, downsample: int):
879
+ super(SimpleDownsample, self).__init__()
880
+
881
+ self.bias = nn.Parameter(torch.zeros(downsample))
882
+
883
+ self.name = None # will be set from training code
884
+
885
+ self.downsample = downsample
886
+
887
+ def forward(self, src: Tensor) -> Tensor:
888
+ """
889
+ x: (seq_len, batch_size, in_channels)
890
+ Returns a tensor of shape
891
+ ( (seq_len+downsample-1)//downsample, batch_size, channels)
892
+ """
893
+ (seq_len, batch_size, in_channels) = src.shape
894
+ ds = self.downsample
895
+ d_seq_len = (seq_len + ds - 1) // ds
896
+
897
+ # Pad to an exact multiple of self.downsample
898
+ # right-pad src, repeating the last element.
899
+ pad = d_seq_len * ds - seq_len
900
+ src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2])
901
+ src = torch.cat((src, src_extra), dim=0)
902
+ assert src.shape[0] == d_seq_len * ds
903
+
904
+ src = src.reshape(d_seq_len, ds, batch_size, in_channels)
905
+
906
+ weights = self.bias.softmax(dim=0)
907
+ # weights: (downsample, 1, 1)
908
+ weights = weights.unsqueeze(-1).unsqueeze(-1)
909
+
910
+ # ans is the weighted sum over the ds frames: (d_seq_len, batch_size, in_channels)
911
+ ans = (src * weights).sum(dim=1)
912
+
913
+ return ans
914
+
915
+
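A standalone sketch of the weighted-sum downsampling above (torch only); `bias` stands in for the module's learned parameter:

import torch

seq_len, batch_size, in_channels, ds = 7, 2, 3, 4
src = torch.randn(seq_len, batch_size, in_channels)

d_seq_len = (seq_len + ds - 1) // ds                 # ceil(7 / 4) == 2
pad = d_seq_len * ds - seq_len                       # right-pad by repeating the last frame
src = torch.cat([src, src[-1:].expand(pad, batch_size, in_channels)], dim=0)

bias = torch.zeros(ds)                               # learned nn.Parameter in the module
weights = bias.softmax(dim=0).view(ds, 1, 1)         # convex weights over each group of ds frames
out = (src.reshape(d_seq_len, ds, batch_size, in_channels) * weights).sum(dim=1)
assert out.shape == (d_seq_len, batch_size, in_channels)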
916
+ class SimpleUpsample(torch.nn.Module):
917
+ """
918
+ A very simple form of upsampling that just repeats the input.
919
+ """
920
+
921
+ def __init__(self, upsample: int):
922
+ super(SimpleUpsample, self).__init__()
923
+ self.upsample = upsample
924
+
925
+ def forward(self, src: Tensor) -> Tensor:
926
+ """
927
+ x: (seq_len, batch_size, num_channels)
928
+ Returns a tensor of shape
929
+ ( (seq_len*upsample), batch_size, num_channels)
930
+ """
931
+ upsample = self.upsample
932
+ (seq_len, batch_size, num_channels) = src.shape
933
+ src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels)
934
+ src = src.reshape(seq_len * upsample, batch_size, num_channels)
935
+ return src
936
+
937
+
938
+ class CompactRelPositionalEncoding(torch.nn.Module):
939
+ """
940
+ Relative positional encoding module. This version is "compact" meaning it is able
941
+ to encode the important information about the relative position in a relatively
942
+ small number of dimensions. The goal is to make it so that small differences between
943
+ large relative offsets (e.g. 1000 vs. 1001) make very little difference to the
944
+ embedding. Such differences were potentially important when encoding absolute
945
+ position, but not important when encoding relative position because there is now no
946
+ need to compare two large offsets with each other.
947
+
948
+ Our embedding works by projecting the interval [-infinity,infinity] to a finite
949
+ interval using the atan() function, before doing the Fourier transform of that fixed
950
+ interval. The atan() function would compress the "long tails" too small, making it
951
+ hard to distinguish between different magnitudes of large offsets, so we use a
952
+ logarithmic function to compress large offsets to a smaller range before applying
953
+ atan(). Scalings are chosen in such a way that the embedding can clearly distinguish
954
+ individual offsets as long as they are quite close to the origin, e.g. abs(offset)
955
+ <= about sqrt(embedding_dim)
956
+
957
+
958
+ Args:
959
+ embed_dim: Embedding dimension.
960
+ dropout_rate: Dropout rate.
961
+ max_len: Maximum input length: just a heuristic for initialization.
962
+ length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives
963
+ less weight to small differences of offset near the origin.
964
+ """
965
+
966
+ def __init__(
967
+ self,
968
+ embed_dim: int,
969
+ dropout_rate: FloatLike,
970
+ max_len: int = 1000,
971
+ length_factor: float = 1.0,
972
+ ) -> None:
973
+ """Construct a CompactRelPositionalEncoding object."""
974
+ super(CompactRelPositionalEncoding, self).__init__()
975
+ self.embed_dim = embed_dim
976
+ assert embed_dim % 2 == 0, embed_dim
977
+ self.dropout = Dropout2(dropout_rate)
978
+ self.pe = None
979
+ assert length_factor >= 1.0, length_factor
980
+ self.length_factor = length_factor
981
+ self.extend_pe(torch.tensor(0.0).expand(max_len))
982
+
983
+ def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None:
984
+ """Reset the positional encodings."""
985
+ T = x.size(0) + left_context_len
986
+
987
+ if self.pe is not None:
988
+ # self.pe contains both positive and negative parts
989
+ # the length of self.pe is 2 * input_len - 1
990
+ if self.pe.size(0) >= T * 2 - 1:
991
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
992
+ return
993
+
994
+ # if T == 4, x would contain [ -3, -2, -1, 0, 1, 2, 3 ]
995
+ x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1)
996
+
997
+ freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device)
998
+
999
+ # `compression_length` this is arbitrary/heuristic, if it is larger we have more
1000
+ # resolution for small time offsets but less resolution for large time offsets.
1001
+ compression_length = self.embed_dim**0.5
1002
+ # x_compressed, like X, goes from -infinity to infinity as T goes from -infinity
1003
+ # to infinity; but it does so more slowly than T for large absolute values of T.
1004
+ # The formula is chosen so that d(x_compressed )/dx is 1 around x == 0, which is
1005
+ # important.
1006
+ x_compressed = (
1007
+ compression_length
1008
+ * x.sign()
1009
+ * ((x.abs() + compression_length).log() - math.log(compression_length))
1010
+ )
1011
+
1012
+ # if self.length_factor == 1.0, then length_scale is chosen so that the
1013
+ # FFT can exactly separate points close to the origin (T == 0). So this
1014
+ # part of the formulation is not really heuristic.
1015
+ # But empirically, for ASR at least, length_factor > 1.0 seems to work better.
1016
+ length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi)
1017
+
1018
+ # note for machine implementations: if atan is not available, we can use:
1019
+ # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2)
1020
+ # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 ,
1021
+ # atan(x))
1022
+ x_atan = (x_compressed / length_scale).atan() # results between -pi/2 and pi/2
1023
+
1024
+ cosines = (x_atan * freqs).cos()
1025
+ sines = (x_atan * freqs).sin()
1026
+
1027
+ pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device)
1028
+ pe[:, 0::2] = cosines
1029
+ pe[:, 1::2] = sines
1030
+ pe[:, -1] = 1.0 # for bias.
1031
+
1032
+ self.pe = pe.to(dtype=x.dtype)
1033
+
1034
+ def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor:
1035
+ """Create positional encoding.
1036
+
1037
+ Args:
1038
+ x (Tensor): Input tensor (time, batch, `*`).
1039
+ left_context_len: (int): Length of cached left context.
1040
+
1041
+ Returns:
1042
+ positional embedding, of shape (batch, left_context_len + 2*time-1, `*`).
1043
+ """
1044
+ self.extend_pe(x, left_context_len)
1045
+ x_size_left = x.size(0) + left_context_len
1046
+ # length of positive side: x.size(0) + left_context_len
1047
+ # length of negative side: x.size(0)
1048
+ pos_emb = self.pe[
1049
+ self.pe.size(0) // 2
1050
+ - x_size_left
1051
+ + 1 : self.pe.size(0) // 2 # noqa E203
1052
+ + x.size(0),
1053
+ :,
1054
+ ]
1055
+ pos_emb = pos_emb.unsqueeze(0)
1056
+ return self.dropout(pos_emb)
1057
+
1058
+
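A sketch of the offset-compression described in the docstring above (torch only); it reproduces just the mapping from a relative offset to the angle fed into cos/sin:

import math
import torch

embed_dim, length_factor, T = 192, 1.0, 50
x = torch.arange(-(T - 1), T).float().unsqueeze(1)   # relative offsets -(T-1)..(T-1)

compression_length = embed_dim ** 0.5
x_compressed = (
    compression_length * x.sign()
    * ((x.abs() + compression_length).log() - math.log(compression_length))
)
length_scale = length_factor * embed_dim / (2.0 * math.pi)
x_atan = (x_compressed / length_scale).atan()        # bounded in (-pi/2, pi/2)

freqs = 1 + torch.arange(embed_dim // 2)
pe = torch.zeros(x.shape[0], embed_dim)
pe[:, 0::2] = (x_atan * freqs).cos()
pe[:, 1::2] = (x_atan * freqs).sin()
pe[:, -1] = 1.0                                      # bias term, as in the module
assert pe.shape == (2 * T - 1, embed_dim)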
1059
+ class RelPositionMultiheadAttentionWeights(nn.Module):
1060
+ r"""Module that computes multi-head attention weights with relative position
1061
+ encoding. Various other modules consume the resulting attention weights:
1062
+ see, for example, the SimpleAttention module which allows you to compute
1063
+ conventional attention.
1064
+
1065
+ This is quite heavily modified from: "Transformer-XL: Attentive Language
1066
+ Models Beyond a Fixed-Length Context",
1067
+ we have to write up the differences.
1068
+
1069
+
1070
+ Args:
1071
+ embed_dim: number of channels at the input to this module, e.g. 256
1072
+ pos_dim: dimension of the positional encoding vectors, e.g. 128.
1073
+ num_heads: number of heads to compute weights for, e.g. 8
1074
+ query_head_dim: dimension of the query (and key), per head. e.g. 24.
1075
+ pos_head_dim: dimension of the projected positional encoding per head, e.g. 4.
1076
+ dropout: dropout probability for attn_output_weights. Default: 0.0.
1077
+ pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on
1078
+ any given call to forward(), in training time.
1079
+ """
1080
+
1081
+ def __init__(
1082
+ self,
1083
+ embed_dim: int,
1084
+ pos_dim: int,
1085
+ num_heads: int,
1086
+ query_head_dim: int,
1087
+ pos_head_dim: int,
1088
+ dropout: float = 0.0,
1089
+ pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)),
1090
+ ) -> None:
1091
+ super().__init__()
1092
+ self.embed_dim = embed_dim
1093
+ self.num_heads = num_heads
1094
+ self.query_head_dim = query_head_dim
1095
+ self.pos_head_dim = pos_head_dim
1096
+ self.dropout = dropout
1097
+ self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
1098
+ self.name = None # will be overwritten in training code; for diagnostics.
1099
+
1100
+ key_head_dim = query_head_dim
1101
+ in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads
1102
+
1103
+ # the initial_scale is supposed to take over the "scaling" factor of
1104
+ # head_dim ** -0.5 that has been used in previous forms of attention,
1105
+ # dividing it between the query and key. Note: this module is intended
1106
+ # to be used with the ScaledAdam optimizer; with most other optimizers,
1107
+ # it would be necessary to apply the scaling factor in the forward function.
1108
+ self.in_proj = ScaledLinear(
1109
+ embed_dim,
1110
+ in_proj_dim,
1111
+ bias=True,
1112
+ initial_scale=query_head_dim**-0.25,
1113
+ )
1114
+
1115
+ self.whiten_keys = Whiten(
1116
+ num_groups=num_heads,
1117
+ whitening_limit=_whitening_schedule(3.0),
1118
+ prob=(0.025, 0.25),
1119
+ grad_scale=0.025,
1120
+ )
1121
+
1122
+ # add a balancer for the keys that runs with very small probability, and
1123
+ # tries to enforce that all dimensions have mean around zero. The
1124
+ # weights produced by this module are invariant to adding a constant to
1125
+ # the keys, so the derivative of the bias is mathematically zero; but
1126
+ # due to how Adam/ScaledAdam work, it can learn a fairly large nonzero
1127
+ # bias because the small numerical roundoff tends to have a non-random
1128
+ # sign. This module is intended to prevent that. Use a very small
1129
+ # probability; that should be sufficient to fix the problem.
1130
+ self.balance_keys = Balancer(
1131
+ key_head_dim * num_heads,
1132
+ channel_dim=-1,
1133
+ min_positive=0.4,
1134
+ max_positive=0.6,
1135
+ min_abs=0.0,
1136
+ max_abs=100.0,
1137
+ prob=0.025,
1138
+ )
1139
+
1140
+ # linear transformation for positional encoding.
1141
+ self.linear_pos = ScaledLinear(
1142
+ pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05
1143
+ )
1144
+
1145
+ # the following are for diagnostics only, see --print-diagnostics option
1146
+ self.copy_pos_query = Identity()
1147
+ self.copy_query = Identity()
1148
+
1149
+ def forward(
1150
+ self,
1151
+ x: Tensor,
1152
+ pos_emb: Tensor,
1153
+ key_padding_mask: Optional[Tensor] = None,
1154
+ attn_mask: Optional[Tensor] = None,
1155
+ ) -> Tensor:
1156
+ r"""
1157
+ Args:
1158
+ x: input of shape (seq_len, batch_size, embed_dim)
1159
+ pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
1160
+ key_padding_mask: a bool tensor of shape (batch_size, seq_len).
1161
+ Positions that are True in this mask will be ignored as sources in the
1162
+ attention weighting.
1163
+ attn_mask: mask of shape (seq_len, seq_len) or
1164
+ (batch_size, seq_len, seq_len), interpreted as
1165
+ ([batch_size,] tgt_seq_len, src_seq_len)
1166
+ saying which positions are allowed to attend to which other positions.
1167
+ Returns:
1168
+ a tensor of attention weights, of
1169
+ shape (num_heads, batch_size, seq_len, seq_len)
1170
+ interpreted as (num_heads, batch_size, tgt_seq_len, src_seq_len).
1171
+ """
1172
+ x = self.in_proj(x)
1173
+ query_head_dim = self.query_head_dim
1174
+ pos_head_dim = self.pos_head_dim
1175
+ num_heads = self.num_heads
1176
+
1177
+ seq_len, batch_size, _ = x.shape
1178
+
1179
+ query_dim = query_head_dim * num_heads
1180
+
1181
+ # self-attention
1182
+ q = x[..., 0:query_dim]
1183
+ k = x[..., query_dim : 2 * query_dim]
1184
+ # p is the position-encoding query
1185
+ p = x[..., 2 * query_dim :]
1186
+ assert p.shape[-1] == num_heads * pos_head_dim, (
1187
+ p.shape[-1],
1188
+ num_heads,
1189
+ pos_head_dim,
1190
+ )
1191
+
1192
+ q = self.copy_query(q) # for diagnostics only, does nothing.
1193
+ k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass.
1194
+ p = self.copy_pos_query(p) # for diagnostics only, does nothing.
1195
+
1196
+ q = q.reshape(seq_len, batch_size, num_heads, query_head_dim)
1197
+ p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim)
1198
+ k = k.reshape(seq_len, batch_size, num_heads, query_head_dim)
1199
+
1200
+ # time1 refers to target, time2 refers to source.
1201
+ q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim)
1202
+ p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim)
1203
+ k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2)
1204
+
1205
+ attn_scores = torch.matmul(q, k)
1206
+
1207
+ use_pos_scores = False
1208
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
1209
+ # We can't put random.random() in the same line
1210
+ use_pos_scores = True
1211
+ elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
1212
+ use_pos_scores = True
1213
+
1214
+ if use_pos_scores:
1215
+ pos_emb = self.linear_pos(pos_emb)
1216
+ seq_len2 = 2 * seq_len - 1
1217
+ pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
1218
+ 2, 0, 3, 1
1219
+ )
1220
+ # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
1221
+
1222
+ # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head,
1223
+ # batch, time1, seq_len2) [where seq_len2 represents relative position.]
1224
+ pos_scores = torch.matmul(p, pos_emb)
1225
+ # the following .as_strided() expression converts the last axis of
1226
+ # pos_scores from relative to absolute position. I don't know whether I
1227
+ # might have got the time-offsets backwards or not, but let this code define
1228
+ # which way round it is supposed to be.
1229
+ if torch.jit.is_tracing():
1230
+ (num_heads, batch_size, time1, n) = pos_scores.shape
1231
+ rows = torch.arange(start=time1 - 1, end=-1, step=-1)
1232
+ cols = torch.arange(seq_len)
1233
+ rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
1234
+ indexes = rows + cols
1235
+ pos_scores = pos_scores.reshape(-1, n)
1236
+ pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
1237
+ pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len)
1238
+ else:
1239
+ pos_scores = pos_scores.as_strided(
1240
+ (num_heads, batch_size, seq_len, seq_len),
1241
+ (
1242
+ pos_scores.stride(0),
1243
+ pos_scores.stride(1),
1244
+ pos_scores.stride(2) - pos_scores.stride(3),
1245
+ pos_scores.stride(3),
1246
+ ),
1247
+ storage_offset=pos_scores.stride(3) * (seq_len - 1),
1248
+ )
1249
+
1250
+ attn_scores = attn_scores + pos_scores
1251
+
1252
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
1253
+ pass
1254
+ elif self.training and random.random() < 0.1:
1255
+ # This is a harder way of limiting the attention scores to not be
1256
+ # too large. It incurs a penalty if any of them has an absolute
1257
+ # value greater than 25.0. This should be outside the normal range
1258
+ # of the attention scores. We use this mechanism instead of, say,
1259
+ # something added to the loss function involving the entropy,
1260
+ # because once the entropy gets very small gradients through the
1261
+ # softmax can become very small, and we'd get zero derivatives. The
1262
+ # choices of 1.0e-04 as the scale on the penalty makes this
1263
+ # mechanism vulnerable to the absolute scale of the loss function,
1264
+ # but we view this as a failsafe to avoid "implausible" parameter
1265
+ # values rather than a regularization method that should be active
1266
+ # under normal circumstances.
1267
+ attn_scores = penalize_abs_values_gt(
1268
+ attn_scores, limit=25.0, penalty=1.0e-04, name=self.name
1269
+ )
1270
+
1271
+ assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len)
1272
+
1273
+ if attn_mask is not None:
1274
+ assert attn_mask.dtype == torch.bool
1275
+ # use -1000 to avoid nan's where attn_mask and key_padding_mask make
1276
+ # all scores zero. It's important that this be large enough that exp(-1000)
1277
+ # is exactly zero, for reasons related to const_attention_rate, it
1278
+ # compares the final weights with zero.
1279
+ attn_scores = attn_scores.masked_fill(attn_mask, -1000)
1280
+
1281
+ if key_padding_mask is not None:
1282
+ assert key_padding_mask.shape == (
1283
+ batch_size,
1284
+ seq_len,
1285
+ ), key_padding_mask.shape
1286
+ attn_scores = attn_scores.masked_fill(
1287
+ key_padding_mask.unsqueeze(1),
1288
+ -1000,
1289
+ )
1290
+
1291
+ # We use our own version of softmax, defined in scaling.py, which should
1292
+ # save a little of the memory used in backprop: if we are in
1293
+ # automatic mixed precision mode (amp / autocast), it only stores the
1294
+ # half-precision output for backprop purposes.
1295
+ attn_weights = softmax(attn_scores, dim=-1)
1296
+
1297
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
1298
+ pass
1299
+ elif random.random() < 0.001 and not self.training:
1300
+ self._print_attn_entropy(attn_weights)
1301
+
1302
+ attn_weights = nn.functional.dropout(
1303
+ attn_weights, p=self.dropout, training=self.training
1304
+ )
1305
+
1306
+ return attn_weights
1307
+
1308
+ def _print_attn_entropy(self, attn_weights: Tensor):
1309
+ # attn_weights: (num_heads, batch_size, seq_len, seq_len)
1310
+ (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape
1311
+
1312
+ with torch.no_grad():
1313
+ with torch.amp.autocast("cuda", enabled=False):
1314
+ attn_weights = attn_weights.to(torch.float32)
1315
+ attn_weights_entropy = (
1316
+ -((attn_weights + 1.0e-20).log() * attn_weights)
1317
+ .sum(dim=-1)
1318
+ .mean(dim=(1, 2))
1319
+ )
1320
+ logging.debug(
1321
+ f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}"
1322
+ )
1323
+
1324
+
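A small self-contained check of the relative-to-absolute position indexing used in the tracing branch above (torch only; the head/batch/seq sizes are illustrative):

import torch

num_heads, batch_size, seq_len = 2, 1, 4
n = 2 * seq_len - 1                                   # relative offsets -(T-1)..(T-1)
# Fill pos_scores so that index r on the last axis holds the offset r - (seq_len - 1).
offsets = torch.arange(n) - (seq_len - 1)
pos_scores = offsets.view(1, 1, 1, n).expand(num_heads, batch_size, seq_len, n).float()

rows = torch.arange(start=seq_len - 1, end=-1, step=-1)
cols = torch.arange(seq_len)
indexes = rows.repeat(batch_size * num_heads).unsqueeze(-1) + cols
out = torch.gather(pos_scores.reshape(-1, n), dim=1, index=indexes)
out = out.reshape(num_heads, batch_size, seq_len, seq_len)

# Entry (t, s) now holds the relative offset s - t, i.e. key position minus query position.
t_idx = torch.arange(seq_len).view(-1, 1)
s_idx = torch.arange(seq_len).view(1, -1)
assert torch.equal(out[0, 0], (s_idx - t_idx).float())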
1325
+ class SelfAttention(nn.Module):
1326
+ """
1327
+ The simplest possible attention module. This one works with already-computed
1328
+ attention weights, e.g. as computed by RelPositionMultiheadAttentionWeights.
1329
+
1330
+ Args:
1331
+ embed_dim: the input and output embedding dimension
1332
+ num_heads: the number of attention heads
1333
+ value_head_dim: the value dimension per head
1334
+ """
1335
+
1336
+ def __init__(
1337
+ self,
1338
+ embed_dim: int,
1339
+ num_heads: int,
1340
+ value_head_dim: int,
1341
+ ) -> None:
1342
+ super().__init__()
1343
+ self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim, bias=True)
1344
+
1345
+ self.out_proj = ScaledLinear(
1346
+ num_heads * value_head_dim,
1347
+ embed_dim,
1348
+ bias=True,
1349
+ initial_scale=0.05,
1350
+ )
1351
+
1352
+ self.whiten = Whiten(
1353
+ num_groups=1,
1354
+ whitening_limit=_whitening_schedule(7.5, ratio=3.0),
1355
+ prob=(0.025, 0.25),
1356
+ grad_scale=0.01,
1357
+ )
1358
+
1359
+ def forward(
1360
+ self,
1361
+ x: Tensor,
1362
+ attn_weights: Tensor,
1363
+ ) -> Tensor:
1364
+ """
1365
+ Args:
1366
+ x: input tensor, of shape (seq_len, batch_size, embed_dim)
1367
+ attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len),
1368
+ with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect
1369
+ attn_weights.sum(dim=-1) == 1.
1370
+ Returns:
1371
+ a tensor with the same shape as x.
1372
+ """
1373
+ (seq_len, batch_size, embed_dim) = x.shape
1374
+ num_heads = attn_weights.shape[0]
1375
+ assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)
1376
+
1377
+ x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim)
1378
+ x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
1379
+ # now x: (num_heads, batch_size, seq_len, value_head_dim)
1380
+ value_head_dim = x.shape[-1]
1381
+
1382
+ # todo: see whether there is benefit in overriding matmul
1383
+ x = torch.matmul(attn_weights, x)
1384
+ # v: (num_heads, batch_size, seq_len, value_head_dim)
1385
+
1386
+ x = (
1387
+ x.permute(2, 1, 0, 3)
1388
+ .contiguous()
1389
+ .view(seq_len, batch_size, num_heads * value_head_dim)
1390
+ )
1391
+
1392
+ # returned value is of shape (seq_len, batch_size, embed_dim), like the input.
1393
+ x = self.out_proj(x)
1394
+ x = self.whiten(x)
1395
+
1396
+ return x
1397
+
1398
+
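A shape-level sketch of applying precomputed attention weights to projected values, mirroring the matmul above (torch only; dimensions are illustrative):

import torch

num_heads, batch_size, seq_len, value_head_dim = 4, 2, 6, 12
attn_weights = torch.softmax(torch.randn(num_heads, batch_size, seq_len, seq_len), dim=-1)
v = torch.randn(seq_len, batch_size, num_heads * value_head_dim)

v = v.reshape(seq_len, batch_size, num_heads, value_head_dim).permute(2, 1, 0, 3)
out = torch.matmul(attn_weights, v)            # (num_heads, batch, seq_len, value_head_dim)
out = out.permute(2, 1, 0, 3).contiguous().view(seq_len, batch_size, num_heads * value_head_dim)
assert out.shape == (seq_len, batch_size, num_heads * value_head_dim)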
1399
+ class FeedforwardModule(nn.Module):
1400
+ """Feedforward module in TTSZipformer model."""
1401
+
1402
+ def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike):
1403
+ super(FeedforwardModule, self).__init__()
1404
+ self.in_proj = nn.Linear(embed_dim, feedforward_dim)
1405
+
1406
+ self.hidden_balancer = Balancer(
1407
+ feedforward_dim,
1408
+ channel_dim=-1,
1409
+ min_positive=0.3,
1410
+ max_positive=1.0,
1411
+ min_abs=0.75,
1412
+ max_abs=5.0,
1413
+ )
1414
+
1415
+ # shared_dim=0 means we share the dropout mask along the time axis
1416
+ self.out_proj = ActivationDropoutAndLinear(
1417
+ feedforward_dim,
1418
+ embed_dim,
1419
+ activation="SwooshL",
1420
+ dropout_p=dropout,
1421
+ dropout_shared_dim=0,
1422
+ bias=True,
1423
+ initial_scale=0.1,
1424
+ )
1425
+
1426
+ self.out_whiten = Whiten(
1427
+ num_groups=1,
1428
+ whitening_limit=_whitening_schedule(7.5),
1429
+ prob=(0.025, 0.25),
1430
+ grad_scale=0.01,
1431
+ )
1432
+
1433
+ def forward(self, x: Tensor):
1434
+ x = self.in_proj(x)
1435
+ x = self.hidden_balancer(x)
1436
+ # out_proj contains SwooshL activation, then dropout, then linear.
1437
+ x = self.out_proj(x)
1438
+ x = self.out_whiten(x)
1439
+ return x
1440
+
1441
+
1442
+ class NonlinAttention(nn.Module):
1443
+ """This is like the ConvolutionModule, but refactored so that we use multiplication
1444
+ by attention weights (borrowed from the attention module) in place of actual
1445
+ convolution. We also took out the second nonlinearity, the one after the
1446
+ attention mechanism.
1447
+
1448
+ Args:
1449
+ channels (int): The number of channels of conv layers.
1450
+ """
1451
+
1452
+ def __init__(
1453
+ self,
1454
+ channels: int,
1455
+ hidden_channels: int,
1456
+ ) -> None:
1457
+ super().__init__()
1458
+
1459
+ self.hidden_channels = hidden_channels
1460
+
1461
+ self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True)
1462
+
1463
+ # balancer that goes before the sigmoid. Have quite a large min_abs value, at
1464
+ # 2.0, because we noticed that well-trained instances of this module have
1465
+ # abs-value before the sigmoid starting from about 3, and poorly-trained
1466
+ # instances of the module have smaller abs values before the sigmoid.
1467
+ self.balancer = Balancer(
1468
+ hidden_channels,
1469
+ channel_dim=-1,
1470
+ min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)),
1471
+ max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)),
1472
+ min_abs=0.5,
1473
+ max_abs=5.0,
1474
+ )
1475
+ self.tanh = nn.Tanh()
1476
+
1477
+ self.identity1 = Identity() # for diagnostics.
1478
+ self.identity2 = Identity() # for diagnostics.
1479
+ self.identity3 = Identity() # for diagnostics.
1480
+
1481
+ self.out_proj = ScaledLinear(
1482
+ hidden_channels, channels, bias=True, initial_scale=0.05
1483
+ )
1484
+
1485
+ self.whiten1 = Whiten(
1486
+ num_groups=1,
1487
+ whitening_limit=_whitening_schedule(5.0),
1488
+ prob=(0.025, 0.25),
1489
+ grad_scale=0.01,
1490
+ )
1491
+
1492
+ self.whiten2 = Whiten(
1493
+ num_groups=1,
1494
+ whitening_limit=_whitening_schedule(5.0, ratio=3.0),
1495
+ prob=(0.025, 0.25),
1496
+ grad_scale=0.01,
1497
+ )
1498
+
1499
+ def forward(
1500
+ self,
1501
+ x: Tensor,
1502
+ attn_weights: Tensor,
1503
+ ) -> Tensor:
1504
+ """.
1505
+ Args:
1506
+ x: a Tensor of shape (seq_len, batch_size, num_channels)
1507
+ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
1508
+ Returns:
1509
+ a Tensor with the same shape as x
1510
+ """
1511
+ x = self.in_proj(x)
1512
+
1513
+ (seq_len, batch_size, _) = x.shape
1514
+ hidden_channels = self.hidden_channels
1515
+
1516
+ s, x, y = x.chunk(3, dim=2)
1517
+
1518
+ # s will go through tanh.
1519
+
1520
+ s = self.balancer(s)
1521
+ s = self.tanh(s)
1522
+
1523
+ s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels)
1524
+ x = self.whiten1(x)
1525
+ x = x * s
1526
+ x = self.identity1(x) # diagnostics only, it's the identity.
1527
+
1528
+ (seq_len, batch_size, embed_dim) = x.shape
1529
+ num_heads = attn_weights.shape[0]
1530
+ assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)
1531
+
1532
+ x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
1533
+ # now x: (num_heads, batch_size, seq_len, head_dim)
1534
+ x = torch.matmul(attn_weights, x)
1535
+ # now x: (num_heads, batch_size, seq_len, head_dim)
1536
+ x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
1537
+
1538
+ y = self.identity2(y)
1539
+ x = x * y
1540
+ x = self.identity3(x)
1541
+
1542
+ x = self.out_proj(x)
1543
+ x = self.whiten2(x)
1544
+ return x
1545
+
1546
+
1547
+ class ConvolutionModule(nn.Module):
1548
+ """ConvolutionModule in Zipformer2 model.
1549
+
1550
+ Args:
1551
+ channels (int): The number of channels of conv layers.
1552
+ kernel_size (int): Kernel size of conv layers.
1553
+ bias (bool): Whether to use bias in conv layers (default=True).
1554
+
1555
+ """
1556
+
1557
+ def __init__(
1558
+ self,
1559
+ channels: int,
1560
+ kernel_size: int,
1561
+ ) -> None:
1562
+ """Construct a ConvolutionModule object."""
1563
+ super(ConvolutionModule, self).__init__()
1564
+ # kernel_size should be an odd number for 'SAME' padding
1565
+ assert (kernel_size - 1) % 2 == 0
1566
+
1567
+ bottleneck_dim = channels
1568
+
1569
+ self.in_proj = nn.Linear(
1570
+ channels,
1571
+ 2 * bottleneck_dim,
1572
+ )
1573
+ # the gradients on in_proj are a little noisy, likely to do with the
1574
+ # sigmoid in glu.
1575
+
1576
+ # after in_proj we put x through a gated linear unit (nn.functional.glu). For
1577
+ # most layers the normal rms value of channels of x seems to be in the range 1
1578
+ # to 4, but sometimes, for some reason, for layer 0 the rms ends up being very
1579
+ # large, between 50 and 100 for different channels. This will cause very peaky
1580
+ # and sparse derivatives for the sigmoid gating function, which will tend to
1581
+ # make the loss function not learn effectively. (for most layers the average
1582
+ # absolute values are in the range 0.5..9.0, and the average p(x>0), i.e.
1583
+ # positive proportion, at the output of pointwise_conv1.output is around 0.35 to
1584
+ # 0.45 for different layers, which likely breaks down as 0.5 for the "linear"
1585
+ # half and 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that
1586
+ # if we constrain the rms values to a reasonable range via a constraint of
1587
+ # max_abs=10.0, it will be in a better position to start learning something,
1588
+ # i.e. to latch onto the correct range.
1589
+ self.balancer1 = Balancer(
1590
+ bottleneck_dim,
1591
+ channel_dim=-1,
1592
+ min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)),
1593
+ max_positive=1.0,
1594
+ min_abs=1.5,
1595
+ max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0),
1596
+ )
1597
+
1598
+ self.activation1 = Identity() # for diagnostics
1599
+
1600
+ self.sigmoid = nn.Sigmoid()
1601
+
1602
+ self.activation2 = Identity() # for diagnostics
1603
+
1604
+ assert kernel_size % 2 == 1
1605
+
1606
+ self.depthwise_conv = nn.Conv1d(
1607
+ in_channels=bottleneck_dim,
1608
+ out_channels=bottleneck_dim,
1609
+ groups=bottleneck_dim,
1610
+ kernel_size=kernel_size,
1611
+ padding=kernel_size // 2,
1612
+ )
1613
+
1614
+ self.balancer2 = Balancer(
1615
+ bottleneck_dim,
1616
+ channel_dim=1,
1617
+ min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
1618
+ max_positive=1.0,
1619
+ min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)),
1620
+ max_abs=10.0,
1621
+ )
1622
+
1623
+ self.whiten = Whiten(
1624
+ num_groups=1,
1625
+ whitening_limit=_whitening_schedule(7.5),
1626
+ prob=(0.025, 0.25),
1627
+ grad_scale=0.01,
1628
+ )
1629
+
1630
+ self.out_proj = ActivationDropoutAndLinear(
1631
+ bottleneck_dim,
1632
+ channels,
1633
+ activation="SwooshR",
1634
+ dropout_p=0.0,
1635
+ initial_scale=0.05,
1636
+ )
1637
+
1638
+ def forward(
1639
+ self,
1640
+ x: Tensor,
1641
+ src_key_padding_mask: Optional[Tensor] = None,
1642
+ ) -> Tensor:
1643
+ """Compute convolution module.
1644
+
1645
+ Args:
1646
+ x: Input tensor (#time, batch, channels).
1647
+ src_key_padding_mask: the mask for the src keys per batch (optional):
1648
+ (batch, #time), contains True in masked positions.
1649
+
1650
+ Returns:
1651
+ Tensor: Output tensor (#time, batch, channels).
1652
+
1653
+ """
1654
+
1655
+ x = self.in_proj(x) # (time, batch, 2*channels)
1656
+
1657
+ x, s = x.chunk(2, dim=2)
1658
+ s = self.balancer1(s)
1659
+ s = self.sigmoid(s)
1660
+ x = self.activation1(x) # identity.
1661
+ x = x * s
1662
+ x = self.activation2(x) # identity
1663
+
1664
+ # (time, batch, channels)
1665
+
1666
+ # exchange the temporal dimension and the feature dimension
1667
+ x = x.permute(1, 2, 0) # (#batch, channels, time).
1668
+
1669
+ if src_key_padding_mask is not None:
1670
+ x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
1671
+
1672
+ x = self.depthwise_conv(x)
1673
+
1674
+ x = self.balancer2(x)
1675
+ x = x.permute(2, 0, 1) # (time, batch, channels)
1676
+
1677
+ x = self.whiten(x) # (time, batch, channels)
1678
+ x = self.out_proj(x) # (time, batch, channels)
1679
+
1680
+ return x
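A reduced sketch of the gating plus depthwise convolution above (torch only); it drops the balancers and whitening, which act as identities at inference, and uses a plain Linear in place of the SwooshR output projection:

import torch
import torch.nn as nn

channels, kernel_size, seq_len, batch_size = 8, 7, 20, 2
in_proj = nn.Linear(channels, 2 * channels)
depthwise = nn.Conv1d(channels, channels, kernel_size, groups=channels, padding=kernel_size // 2)
out_proj = nn.Linear(channels, channels)

x = torch.randn(seq_len, batch_size, channels)   # (time, batch, channels)
x, s = in_proj(x).chunk(2, dim=2)
x = x * s.sigmoid()                              # GLU-style gate
x = depthwise(x.permute(1, 2, 0))                # (batch, channels, time); 'SAME' length
x = out_proj(x.permute(2, 0, 1))                 # back to (time, batch, channels)
assert x.shape == (seq_len, batch_size, channels)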
zipvoice/models/modules/zipformer_two_stream.py ADDED
@@ -0,0 +1,264 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import math
19
+ from typing import Optional, Tuple, Union
20
+
21
+ import torch
22
+ from torch import Tensor, nn
23
+
24
+ from zipvoice.models.modules.scaling import FloatLike, ScheduledFloat, SwooshR
25
+ from zipvoice.models.modules.zipformer import (
26
+ DownsampledZipformer2Encoder,
27
+ TTSZipformer,
28
+ Zipformer2Encoder,
29
+ Zipformer2EncoderLayer,
30
+ )
31
+
32
+
33
+ def timestep_embedding(timesteps, dim, max_period=10000):
34
+ """Create sinusoidal timestep embeddings.
35
+
36
+ :param timesteps: shape of (N) or (N, T)
37
+ :param dim: the dimension of the output.
38
+ :param max_period: controls the minimum frequency of the embeddings.
39
+ :return: an Tensor of positional embeddings. shape of (N, dim) or (T, N, dim)
40
+ """
41
+ half = dim // 2
42
+ freqs = torch.exp(
43
+ -math.log(max_period)
44
+ * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device)
45
+ / half
46
+ )
47
+
48
+ if timesteps.dim() == 2:
49
+ timesteps = timesteps.transpose(0, 1) # (N, T) -> (T, N)
50
+
51
+ args = timesteps[..., None].float() * freqs[None]
52
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
53
+ if dim % 2:
54
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[..., :1])], dim=-1)
55
+ return embedding
56
+
57
+
58
+ class TTSZipformerTwoStream(TTSZipformer):
59
+ """
60
+ Args:
61
+
62
+ Note: all "int or Tuple[int]" arguments below will be treated as lists of the same
63
+ length as downsampling_factor if they are single ints or one-element tuples.
64
+ The length of downsampling_factor defines the number of stacks.
65
+
66
+ downsampling_factor (Tuple[int]): downsampling factor for each encoder stack.
67
+ Note: this is in addition to the downsampling factor of 2 that is applied in
68
+ the frontend (self.encoder_embed).
69
+ encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks,
70
+ one per encoder stack.
71
+ num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack
72
+ query_head_dim (int or Tuple[int]): dimension of query and key per attention
73
+ head: per stack, if a tuple..
74
+ pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection
75
+ per attention head
76
+ value_head_dim (int or Tuple[int]): dimension of value in each attention head
77
+ num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism.
78
+ Must be at least 4.
79
+ feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules
80
+ cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module
81
+
82
+ pos_dim (int): the dimension of each positional-encoding vector prior to
83
+ projection, e.g. 128.
84
+
85
+ dropout (float): dropout rate
86
+ warmup_batches (float): number of batches to warm up over; this controls
87
+ dropout of encoder layers.
88
+ use_time_embed: (bool): if True, do not take time embedding as additional input.
89
+ time_embed_dim: (int): the dimension of the time embedding.
90
+ """
91
+
92
+ def __init__(
93
+ self,
94
+ in_dim: Tuple[int],
95
+ out_dim: Tuple[int],
96
+ downsampling_factor: Tuple[int] = (2, 4),
97
+ num_encoder_layers: Union[int, Tuple[int]] = 4,
98
+ cnn_module_kernel: Union[int, Tuple[int]] = 31,
99
+ encoder_dim: int = 384,
100
+ query_head_dim: int = 24,
101
+ pos_head_dim: int = 4,
102
+ value_head_dim: int = 12,
103
+ num_heads: int = 8,
104
+ feedforward_dim: int = 1536,
105
+ pos_dim: int = 192,
106
+ dropout: FloatLike = None, # see code below for default
107
+ warmup_batches: float = 4000.0,
108
+ use_time_embed: bool = True,
109
+ time_embed_dim: int = 192,
110
+ use_conv: bool = True,
111
+ ) -> None:
112
+ nn.Module.__init__(self)
113
+
114
+ if dropout is None:
115
+ dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1))
116
+ if isinstance(downsampling_factor, int):
117
+ downsampling_factor = (downsampling_factor,)
118
+
119
+ def _to_tuple(x):
120
+ """Converts a single int or a 1-tuple of an int to a tuple with the same
121
+ length as downsampling_factor"""
122
+ if isinstance(x, int):
123
+ x = (x,)
124
+ if len(x) == 1:
125
+ x = x * len(downsampling_factor)
126
+ else:
127
+ assert len(x) == len(downsampling_factor) and isinstance(x[0], int)
128
+ return x
129
+
130
+ def _assert_downsampling_factor(factors):
131
+ """assert downsampling_factor follows u-net style"""
132
+ assert factors[0] == 1 and factors[-1] == 1
133
+
134
+ for i in range(1, len(factors) // 2 + 1):
135
+ assert factors[i] == factors[i - 1] * 2
136
+
137
+ for i in range(len(factors) // 2 + 1, len(factors)):
138
+ assert factors[i] * 2 == factors[i - 1]
139
+
140
+ _assert_downsampling_factor(downsampling_factor)
141
+ self.downsampling_factor = downsampling_factor # tuple
142
+ num_encoder_layers = _to_tuple(num_encoder_layers)
143
+ self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)
144
+ self.encoder_dim = encoder_dim
145
+ self.num_encoder_layers = num_encoder_layers
146
+ self.query_head_dim = query_head_dim
147
+ self.value_head_dim = value_head_dim
148
+ self.num_heads = num_heads
149
+
150
+ self.use_time_embed = use_time_embed
151
+
152
+ self.time_embed_dim = time_embed_dim
153
+ if self.use_time_embed:
154
+ assert time_embed_dim != -1
155
+ else:
156
+ time_embed_dim = -1
157
+
158
+ assert len(in_dim) == len(out_dim) == 2
159
+
160
+ self.in_dim = in_dim
161
+ self.in_proj = nn.ModuleList(
162
+ [nn.Linear(in_dim[0], encoder_dim), nn.Linear(in_dim[1], encoder_dim)]
163
+ )
164
+ self.out_dim = out_dim
165
+ self.out_proj = nn.ModuleList(
166
+ [nn.Linear(encoder_dim, out_dim[0]), nn.Linear(encoder_dim, out_dim[1])]
167
+ )
168
+
169
+ # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder
170
+ encoders = []
171
+
172
+ num_encoders = len(downsampling_factor)
173
+ for i in range(num_encoders):
174
+ encoder_layer = Zipformer2EncoderLayer(
175
+ embed_dim=encoder_dim,
176
+ pos_dim=pos_dim,
177
+ num_heads=num_heads,
178
+ query_head_dim=query_head_dim,
179
+ pos_head_dim=pos_head_dim,
180
+ value_head_dim=value_head_dim,
181
+ feedforward_dim=feedforward_dim,
182
+ use_conv=use_conv,
183
+ cnn_module_kernel=cnn_module_kernel[i],
184
+ dropout=dropout,
185
+ )
186
+
187
+ # For the segment of the warmup period, we let the Conv2dSubsampling
188
+ # layer learn something. Then we start to warm up the other encoders.
189
+ encoder = Zipformer2Encoder(
190
+ encoder_layer,
191
+ num_encoder_layers[i],
192
+ embed_dim=encoder_dim,
193
+ time_embed_dim=time_embed_dim,
194
+ pos_dim=pos_dim,
195
+ warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
196
+ warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
197
+ final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
198
+ )
199
+
200
+ if downsampling_factor[i] != 1:
201
+ encoder = DownsampledZipformer2Encoder(
202
+ encoder,
203
+ dim=encoder_dim,
204
+ downsample=downsampling_factor[i],
205
+ )
206
+
207
+ encoders.append(encoder)
208
+
209
+ self.encoders = nn.ModuleList(encoders)
210
+ if self.use_time_embed:
211
+ self.time_embed = nn.Sequential(
212
+ nn.Linear(time_embed_dim, time_embed_dim * 2),
213
+ SwooshR(),
214
+ nn.Linear(time_embed_dim * 2, time_embed_dim),
215
+ )
216
+ else:
217
+ self.time_embed = None
218
+
219
+ def forward(
220
+ self,
221
+ x: Tensor,
222
+ t: Optional[Tensor] = None,
223
+ padding_mask: Optional[Tensor] = None,
224
+ ) -> Tuple[Tensor, Tensor]:
225
+ """
226
+ Args:
227
+ x:
228
+ The input tensor. Its shape is (batch_size, seq_len, feature_dim).
229
+ t:
230
+ A t tensor of shape (batch_size,) or (batch_size, seq_len)
231
+ padding_mask:
232
+ The mask for padding, of shape (batch_size, seq_len); True means
233
+ masked position. May be None.
234
+ Returns:
235
+ Return the output embeddings. its shape is
236
+ (batch_size, output_seq_len, encoder_dim)
237
+ """
238
+ assert x.size(2) in self.in_dim, f"{x.size(2)} in {self.in_dim}"
239
+ if x.size(2) == self.in_dim[0]:
240
+ index = 0
241
+ else:
242
+ index = 1
243
+ x = x.permute(1, 0, 2)
244
+ x = self.in_proj[index](x)
245
+
246
+ if t is not None:
247
+ assert t.dim() == 1 or t.dim() == 2, t.shape
248
+ time_emb = timestep_embedding(t, self.time_embed_dim)
249
+ time_emb = self.time_embed(time_emb)
250
+ else:
251
+ time_emb = None
252
+
253
+ attn_mask = None
254
+
255
+ for i, module in enumerate(self.encoders):
256
+ x = module(
257
+ x,
258
+ time_emb=time_emb,
259
+ src_key_padding_mask=padding_mask,
260
+ attn_mask=attn_mask,
261
+ )
262
+ x = self.out_proj[index](x)
263
+ x = x.permute(1, 0, 2)
264
+ return x
zipvoice/models/zipvoice.py ADDED
@@ -0,0 +1,534 @@
1
+ # Copyright 2024 Xiaomi Corp. (authors: Wei Kang
2
+ # Han Zhu)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ from typing import List, Optional
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ from torch.nn.parallel import DistributedDataParallel as DDP
23
+
24
+ from zipvoice.models.modules.solver import EulerSolver
25
+ from zipvoice.models.modules.zipformer import TTSZipformer
26
+ from zipvoice.utils.common import (
27
+ condition_time_mask,
28
+ get_tokens_index,
29
+ make_pad_mask,
30
+ pad_labels,
31
+ prepare_avg_tokens_durations,
32
+ )
33
+
34
+
35
+ class ZipVoice(nn.Module):
36
+ """The ZipVoice model."""
37
+
38
+ def __init__(
39
+ self,
40
+ fm_decoder_downsampling_factor: List[int] = [1, 2, 4, 2, 1],
41
+ fm_decoder_num_layers: List[int] = [2, 2, 4, 4, 4],
42
+ fm_decoder_cnn_module_kernel: List[int] = [31, 15, 7, 15, 31],
43
+ fm_decoder_feedforward_dim: int = 1536,
44
+ fm_decoder_num_heads: int = 4,
45
+ fm_decoder_dim: int = 512,
46
+ text_encoder_num_layers: int = 4,
47
+ text_encoder_feedforward_dim: int = 512,
48
+ text_encoder_cnn_module_kernel: int = 9,
49
+ text_encoder_num_heads: int = 4,
50
+ text_encoder_dim: int = 192,
51
+ time_embed_dim: int = 192,
52
+ text_embed_dim: int = 192,
53
+ query_head_dim: int = 32,
54
+ value_head_dim: int = 12,
55
+ pos_head_dim: int = 4,
56
+ pos_dim: int = 48,
57
+ feat_dim: int = 100,
58
+ vocab_size: int = 26,
59
+ pad_id: int = 0,
60
+ ):
61
+ """
62
+ Initialize the model with specified configuration parameters.
63
+
64
+ Args:
65
+ fm_decoder_downsampling_factor: List of downsampling factors for each layer
66
+ in the flow-matching decoder.
67
+ fm_decoder_num_layers: List of the number of layers for each block in the
68
+ flow-matching decoder.
69
+ fm_decoder_cnn_module_kernel: List of kernel sizes for CNN modules in the
70
+ flow-matching decoder.
71
+ fm_decoder_feedforward_dim: Dimension of the feedforward network in the
72
+ flow-matching decoder.
73
+ fm_decoder_num_heads: Number of attention heads in the flow-matching
74
+ decoder.
75
+ fm_decoder_dim: Hidden dimension of the flow-matching decoder.
76
+ text_encoder_num_layers: Number of layers in the text encoder.
77
+ text_encoder_feedforward_dim: Dimension of the feedforward network in the
78
+ text encoder.
79
+ text_encoder_cnn_module_kernel: Kernel size for the CNN module in the
80
+ text encoder.
81
+ text_encoder_num_heads: Number of attention heads in the text encoder.
82
+ text_encoder_dim: Hidden dimension of the text encoder.
83
+ time_embed_dim: Dimension of the time embedding.
84
+ text_embed_dim: Dimension of the text embedding.
85
+ query_head_dim: Dimension of the query attention head.
86
+ value_head_dim: Dimension of the value attention head.
87
+ pos_head_dim: Dimension of the position attention head.
88
+ pos_dim: Dimension of the positional encoding.
89
+ feat_dim: Dimension of the acoustic features.
90
+ vocab_size: Size of the vocabulary.
91
+ pad_id: ID used for padding tokens.
92
+ """
93
+ super().__init__()
94
+
95
+ self.fm_decoder = TTSZipformer(
96
+ in_dim=feat_dim * 3,
97
+ out_dim=feat_dim,
98
+ downsampling_factor=fm_decoder_downsampling_factor,
99
+ num_encoder_layers=fm_decoder_num_layers,
100
+ cnn_module_kernel=fm_decoder_cnn_module_kernel,
101
+ encoder_dim=fm_decoder_dim,
102
+ feedforward_dim=fm_decoder_feedforward_dim,
103
+ num_heads=fm_decoder_num_heads,
104
+ query_head_dim=query_head_dim,
105
+ pos_head_dim=pos_head_dim,
106
+ value_head_dim=value_head_dim,
107
+ pos_dim=pos_dim,
108
+ use_time_embed=True,
109
+ time_embed_dim=time_embed_dim,
110
+ )
111
+
112
+ self.text_encoder = TTSZipformer(
113
+ in_dim=text_embed_dim,
114
+ out_dim=feat_dim,
115
+ downsampling_factor=1,
116
+ num_encoder_layers=text_encoder_num_layers,
117
+ cnn_module_kernel=text_encoder_cnn_module_kernel,
118
+ encoder_dim=text_encoder_dim,
119
+ feedforward_dim=text_encoder_feedforward_dim,
120
+ num_heads=text_encoder_num_heads,
121
+ query_head_dim=query_head_dim,
122
+ pos_head_dim=pos_head_dim,
123
+ value_head_dim=value_head_dim,
124
+ pos_dim=pos_dim,
125
+ use_time_embed=False,
126
+ )
127
+
128
+ self.feat_dim = feat_dim
129
+ self.text_embed_dim = text_embed_dim
130
+ self.pad_id = pad_id
131
+
132
+ self.embed = nn.Embedding(vocab_size, text_embed_dim)
133
+ self.solver = EulerSolver(self, func_name="forward_fm_decoder")
134
+
135
+ def forward_fm_decoder(
136
+ self,
137
+ t: torch.Tensor,
138
+ xt: torch.Tensor,
139
+ text_condition: torch.Tensor,
140
+ speech_condition: torch.Tensor,
141
+ padding_mask: Optional[torch.Tensor] = None,
142
+ guidance_scale: Optional[torch.Tensor] = None,
143
+ ) -> torch.Tensor:
144
+ """Compute velocity.
145
+ Args:
146
+ t: A tensor of shape (N, 1, 1) or a tensor of a float,
147
+ in the range of (0, 1).
148
+ xt: the input of the current timestep, including condition
149
+ embeddings and noisy acoustic features.
150
+ text_condition: the text condition embeddings, with the
151
+ shape (batch, seq_len, emb_dim).
152
+ speech_condition: the speech condition embeddings, with the
153
+ shape (batch, seq_len, emb_dim).
154
+ padding_mask: The mask for padding, True means masked
155
+ position, with the shape (N, T).
156
+ guidance_scale: The guidance scale in classifier-free guidance,
157
+ which is a tensor of shape (N, 1, 1) or a tensor of a float.
158
+
159
+ Returns:
160
+ predicted velocity, with the shape (batch, seq_len, emb_dim).
161
+ """
162
+
163
+ xt = torch.cat([xt, text_condition, speech_condition], dim=2)
164
+
165
+ assert t.dim() in (0, 3)
166
+ # Handle t with the shape (N, 1, 1):
167
+ # squeeze the last dimension if its size is 1.
168
+ while t.dim() > 1 and t.size(-1) == 1:
169
+ t = t.squeeze(-1)
170
+ # Handle t with a single value: expand it to the batch size.
171
+ if t.dim() == 0:
172
+ t = t.repeat(xt.shape[0])
173
+
174
+ if guidance_scale is not None:
175
+ while guidance_scale.dim() > 1 and guidance_scale.size(-1) == 1:
176
+ guidance_scale = guidance_scale.squeeze(-1)
177
+ if guidance_scale.dim() == 0:
178
+ guidance_scale = guidance_scale.repeat(xt.shape[0])
179
+
180
+ vt = self.fm_decoder(
181
+ x=xt, t=t, padding_mask=padding_mask, guidance_scale=guidance_scale
182
+ )
183
+ else:
184
+ vt = self.fm_decoder(x=xt, t=t, padding_mask=padding_mask)
185
+ return vt
186
+
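A standalone sketch of the timestep normalization above (torch only); `normalize_t` is an illustrative helper name, not part of the model's API:

import torch

def normalize_t(t: torch.Tensor, batch_size: int) -> torch.Tensor:
    # Reduce a scalar t or a (N, 1, 1) tensor to shape (N,), as forward_fm_decoder does.
    while t.dim() > 1 and t.size(-1) == 1:
        t = t.squeeze(-1)
    if t.dim() == 0:
        t = t.repeat(batch_size)
    return t

assert normalize_t(torch.tensor(0.3), 4).shape == (4,)
assert normalize_t(torch.rand(4, 1, 1), 4).shape == (4,)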
187
+ def forward_text_embed(
188
+ self,
189
+ tokens: List[List[int]],
190
+ ):
191
+ """
192
+ Get the text embeddings.
193
+ Args:
194
+ tokens: a list of list of token ids.
195
+ Returns:
196
+ embed: the text embeddings, shape (batch, seq_len, emb_dim).
197
+ tokens_lens: the length of each token sequence, shape (batch,).
198
+ """
199
+ device = (
200
+ self.device if isinstance(self, DDP) else next(self.parameters()).device
201
+ )
202
+ tokens_padded = pad_labels(tokens, pad_id=self.pad_id, device=device) # (B, S)
203
+ embed = self.embed(tokens_padded) # (B, S, C)
204
+ tokens_lens = torch.tensor(
205
+ [len(token) for token in tokens], dtype=torch.int64, device=device
206
+ )
207
+ tokens_padding_mask = make_pad_mask(tokens_lens, embed.shape[1]) # (B, S)
208
+
209
+ embed = self.text_encoder(
210
+ x=embed, t=None, padding_mask=tokens_padding_mask
211
+ ) # (B, S, C)
212
+ return embed, tokens_lens
213
+
214
+ def forward_text_condition(
215
+ self,
216
+ embed: torch.Tensor,
217
+ tokens_lens: torch.Tensor,
218
+ features_lens: torch.Tensor,
219
+ ):
220
+ """
221
+ Get the text condition with the same length of the acoustic feature.
222
+ Args:
223
+ embed: the text embeddings, shape (batch, token_seq_len, emb_dim).
224
+ tokens_lens: the length of each token sequence, shape (batch,).
225
+ features_lens: the length of each acoustic feature sequence,
226
+ shape (batch,).
227
+ Returns:
228
+ text_condition: the text condition, shape
229
+ (batch, feature_seq_len, emb_dim).
230
+ padding_mask: the padding mask of text condition, shape
231
+ (batch, feature_seq_len).
232
+ """
233
+
234
+ num_frames = int(features_lens.max())
235
+
236
+ padding_mask = make_pad_mask(features_lens, max_len=num_frames) # (B, T)
237
+
238
+ tokens_durations = prepare_avg_tokens_durations(features_lens, tokens_lens)
239
+
240
+ tokens_index = get_tokens_index(tokens_durations, num_frames).to(
241
+ embed.device
242
+ ) # (B, T)
243
+
244
+ text_condition = torch.gather(
245
+ embed,
246
+ dim=1,
247
+ index=tokens_index.unsqueeze(-1).expand(
248
+ embed.size(0), num_frames, embed.size(-1)
249
+ ),
250
+ ) # (B, T, F)
251
+ return text_condition, padding_mask
252
+
253
+ def forward_text_train(
254
+ self,
255
+ tokens: List[List[int]],
256
+ features_lens: torch.Tensor,
257
+ ):
258
+ """
259
+ Process text for training, given text tokens and real feature lengths.
260
+ """
261
+ embed, tokens_lens = self.forward_text_embed(tokens)
262
+ text_condition, padding_mask = self.forward_text_condition(
263
+ embed, tokens_lens, features_lens
264
+ )
265
+ return (
266
+ text_condition,
267
+ padding_mask,
268
+ )
269
+
270
+ def forward_text_inference_gt_duration(
271
+ self,
272
+ tokens: List[List[int]],
273
+ features_lens: torch.Tensor,
274
+ prompt_tokens: List[List[int]],
275
+ prompt_features_lens: torch.Tensor,
276
+ ):
277
+ """
278
+ Process text for inference, given text tokens, real feature lengths and prompts.
279
+ """
280
+ tokens = [
281
+ prompt_token + token for prompt_token, token in zip(prompt_tokens, tokens)
282
+ ]
283
+ features_lens = prompt_features_lens + features_lens
284
+ embed, tokens_lens = self.forward_text_embed(tokens)
285
+ text_condition, padding_mask = self.forward_text_condition(
286
+ embed, tokens_lens, features_lens
287
+ )
288
+ return text_condition, padding_mask
289
+
290
+ def forward_text_inference_ratio_duration(
291
+ self,
292
+ tokens: List[List[int]],
293
+ prompt_tokens: List[List[int]],
294
+ prompt_features_lens: torch.Tensor,
295
+ speed: float,
296
+ ):
297
+ """
298
+ Process text for inference, given text tokens and prompts,
299
+ feature lengths are predicted using the prompt's frames-per-token ratio.
300
+ """
301
+ device = (
302
+ self.device if isinstance(self, DDP) else next(self.parameters()).device
303
+ )
304
+
305
+ cat_tokens = [
306
+ prompt_token + token for prompt_token, token in zip(prompt_tokens, tokens)
307
+ ]
308
+
309
+ prompt_tokens_lens = torch.tensor(
310
+ [len(token) for token in prompt_tokens],
311
+ dtype=torch.int64,
312
+ device=device,
313
+ )
314
+
315
+ tokens_lens = torch.tensor(
316
+ [len(token) for token in tokens],
317
+ dtype=torch.int64,
318
+ device=device,
319
+ )
320
+
321
+ cat_embed, cat_tokens_lens = self.forward_text_embed(cat_tokens)
322
+
323
+ features_lens = prompt_features_lens + torch.ceil(
324
+ (prompt_features_lens / prompt_tokens_lens * tokens_lens / speed)
325
+ ).to(dtype=torch.int64)
326
+
327
+ text_condition, padding_mask = self.forward_text_condition(
328
+ cat_embed, cat_tokens_lens, features_lens
329
+ )
330
+ return text_condition, padding_mask
331
+
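A numeric sketch of the duration prediction above (torch only; the lengths are illustrative): the target length is extrapolated from the prompt's frames-per-token ratio and divided by `speed`:

import torch

prompt_features_lens = torch.tensor([200])   # prompt frames
prompt_tokens_lens = torch.tensor([50])      # prompt tokens -> 4 frames per token
tokens_lens = torch.tensor([80])             # target tokens
speed = 1.0

features_lens = prompt_features_lens + torch.ceil(
    prompt_features_lens / prompt_tokens_lens * tokens_lens / speed
).to(dtype=torch.int64)
assert features_lens.item() == 200 + 320     # 4 frames/token * 80 target tokens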
332
+ def forward(
333
+ self,
334
+ tokens: List[List[int]],
335
+ features: torch.Tensor,
336
+ features_lens: torch.Tensor,
337
+ noise: torch.Tensor,
338
+ t: torch.Tensor,
339
+ condition_drop_ratio: float = 0.0,
340
+ ) -> torch.Tensor:
341
+ """Forward pass of the model for training.
342
+ Args:
343
+ tokens: a list of list of token ids.
344
+ features: the acoustic features, with the shape (batch, seq_len, feat_dim).
345
+ features_lens: the length of each acoustic feature sequence, shape (batch,).
346
+ noise: the initial noise, with the shape (batch, seq_len, feat_dim).
347
+ t: the time step, with the shape (batch, 1, 1).
348
+ condition_drop_ratio: the ratio of dropped text condition.
349
+ Returns:
350
+ fm_loss: the flow-matching loss.
351
+ """
352
+
353
+ (text_condition, padding_mask,) = self.forward_text_train(
354
+ tokens=tokens,
355
+ features_lens=features_lens,
356
+ )
357
+
358
+ speech_condition_mask = condition_time_mask(
359
+ features_lens=features_lens,
360
+ mask_percent=(0.7, 1.0),
361
+ max_len=features.size(1),
362
+ )
363
+ speech_condition = torch.where(speech_condition_mask.unsqueeze(-1), 0, features)
364
+
365
+ if condition_drop_ratio > 0.0:
366
+ drop_mask = (
367
+ torch.rand(text_condition.size(0), 1, 1).to(text_condition.device)
368
+ > condition_drop_ratio
369
+ )
370
+ text_condition = text_condition * drop_mask
371
+
372
+ xt = features * t + noise * (1 - t)
373
+ ut = features - noise # (B, T, F)
374
+
375
+ vt = self.forward_fm_decoder(
376
+ t=t,
377
+ xt=xt,
378
+ text_condition=text_condition,
379
+ speech_condition=speech_condition,
380
+ padding_mask=padding_mask,
381
+ )
382
+
383
+ loss_mask = speech_condition_mask & (~padding_mask)
384
+ fm_loss = torch.mean((vt[loss_mask] - ut[loss_mask]) ** 2)
385
+
386
+ return fm_loss
387
+
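A minimal sketch of the conditional flow-matching target used above (torch only): xt linearly interpolates noise and data, and the regression target is their difference:

import torch

batch, seq_len, feat_dim = 2, 12, 100
features = torch.randn(batch, seq_len, feat_dim)   # x1: data
noise = torch.randn(batch, seq_len, feat_dim)      # x0: Gaussian noise
t = torch.rand(batch, 1, 1)                        # timestep in (0, 1)

xt = features * t + noise * (1 - t)                # point on the straight path x0 -> x1
ut = features - noise                              # constant velocity along that path

# With a perfect model vt == ut and the loss is zero; here a random stand-in prediction:
vt = torch.randn_like(ut)
loss_mask = torch.ones(batch, seq_len, dtype=torch.bool)
fm_loss = torch.mean((vt[loss_mask] - ut[loss_mask]) ** 2)
assert fm_loss.ndim == 0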
388
+ def sample(
389
+ self,
390
+ tokens: List[List[int]],
391
+ prompt_tokens: List[List[int]],
392
+ prompt_features: torch.Tensor,
393
+ prompt_features_lens: torch.Tensor,
394
+ features_lens: Optional[torch.Tensor] = None,
395
+ speed: float = 1.0,
396
+ t_shift: float = 1.0,
397
+ duration: str = "predict",
398
+ num_step: int = 5,
399
+ guidance_scale: float = 0.5,
400
+ ) -> torch.Tensor:
401
+ """
402
+ Generate acoustic features, given text tokens, prompts feature
403
+ and prompt transcription's text tokens.
404
+ Args:
405
+ tokens: a list of list of text tokens.
406
+ prompt_tokens: a list of list of prompt tokens.
407
+ prompt_features: the prompt feature with the shape
408
+ (batch_size, seq_len, feat_dim).
409
+ prompt_features_lens: the length of each prompt feature,
410
+ with the shape (batch_size,).
411
+ features_lens: the length of the predicted feature, with the
412
+ shape (batch_size,). It is used only when duration is "real".
413
+ duration: "real" or "predict". If "real", the predicted
414
+ feature length is given by features_lens.
415
+ num_step: the number of steps to use in the ODE solver.
416
+ guidance_scale: the guidance scale for classifier-free guidance.
417
+ """
418
+
419
+ assert duration in ["real", "predict"]
420
+
421
+ if duration == "predict":
422
+ (
423
+ text_condition,
424
+ padding_mask,
425
+ ) = self.forward_text_inference_ratio_duration(
426
+ tokens=tokens,
427
+ prompt_tokens=prompt_tokens,
428
+ prompt_features_lens=prompt_features_lens,
429
+ speed=speed,
430
+ )
431
+ else:
432
+ assert features_lens is not None
433
+ text_condition, padding_mask = self.forward_text_inference_gt_duration(
434
+ tokens=tokens,
435
+ features_lens=features_lens,
436
+ prompt_tokens=prompt_tokens,
437
+ prompt_features_lens=prompt_features_lens,
438
+ )
439
+ batch_size, num_frames, _ = text_condition.shape
440
+
441
+ speech_condition = torch.nn.functional.pad(
442
+ prompt_features, (0, 0, 0, num_frames - prompt_features.size(1))
443
+ ) # (B, T, F)
444
+
445
+ # False means speech condition positions.
446
+ speech_condition_mask = make_pad_mask(prompt_features_lens, num_frames)
447
+ speech_condition = torch.where(
448
+ speech_condition_mask.unsqueeze(-1),
449
+ torch.zeros_like(speech_condition),
450
+ speech_condition,
451
+ )
452
+
453
+ x0 = torch.randn(
454
+ batch_size,
455
+ num_frames,
456
+ prompt_features.size(-1),
457
+ device=text_condition.device,
458
+ )
459
+
460
+ x1 = self.solver.sample(
461
+ x=x0,
462
+ text_condition=text_condition,
463
+ speech_condition=speech_condition,
464
+ padding_mask=padding_mask,
465
+ num_step=num_step,
466
+ guidance_scale=guidance_scale,
467
+ t_shift=t_shift,
468
+ )
469
+ x1_wo_prompt_lens = (~padding_mask).sum(-1) - prompt_features_lens
470
+ x1_prompt = torch.zeros(
471
+ x1.size(0), prompt_features_lens.max(), x1.size(2), device=x1.device
472
+ )
473
+ x1_wo_prompt = torch.zeros(
474
+ x1.size(0), x1_wo_prompt_lens.max(), x1.size(2), device=x1.device
475
+ )
476
+ for i in range(x1.size(0)):
477
+ x1_wo_prompt[i, : x1_wo_prompt_lens[i], :] = x1[
478
+ i,
479
+ prompt_features_lens[i] : prompt_features_lens[i]
480
+ + x1_wo_prompt_lens[i],
481
+ ]
482
+ x1_prompt[i, : prompt_features_lens[i], :] = x1[
483
+ i, : prompt_features_lens[i]
484
+ ]
485
+
486
+ return x1_wo_prompt, x1_wo_prompt_lens, x1_prompt, prompt_features_lens
487
+
488
+ def sample_intermediate(
489
+ self,
490
+ tokens: List[List[int]],
491
+ features: torch.Tensor,
492
+ features_lens: torch.Tensor,
493
+ noise: torch.Tensor,
494
+ speech_condition_mask: torch.Tensor,
495
+ t_start: float,
496
+ t_end: float,
497
+ num_step: int = 1,
498
+ guidance_scale: torch.Tensor = None,
499
+ ) -> torch.Tensor:
500
+ """
501
+ Generate acoustic features in intermediate timesteps.
502
+ Args:
503
+ tokens: List of list of token ids.
504
+ features: The acoustic features, with the shape (batch, seq_len, feat_dim).
505
+ features_lens: The length of each acoustic feature sequence,
506
+ with the shape (batch,).
507
+ noise: The initial noise, with the shape (batch, seq_len, feat_dim).
508
+ speech_condition_mask: The mask for speech condition, True means
509
+ non-condition positions, with the shape (batch, seq_len).
510
+ t_start: The start timestep.
511
+ t_end: The end timestep.
512
+ num_step: The number of steps for sampling.
513
+ guidance_scale: The scale for classifier-free guidance inference,
514
+ with the shape (batch, 1, 1).
515
+ """
516
+ (text_condition, padding_mask,) = self.forward_text_train(
517
+ tokens=tokens,
518
+ features_lens=features_lens,
519
+ )
520
+
521
+ speech_condition = torch.where(speech_condition_mask.unsqueeze(-1), 0, features)
522
+
523
+ x_t_end = self.solver.sample(
524
+ x=noise,
525
+ text_condition=text_condition,
526
+ speech_condition=speech_condition,
527
+ padding_mask=padding_mask,
528
+ num_step=num_step,
529
+ guidance_scale=guidance_scale,
530
+ t_start=t_start,
531
+ t_end=t_end,
532
+ )
533
+ x_t_end_lens = (~padding_mask).sum(-1)
534
+ return x_t_end, x_t_end_lens
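
For orientation, here is a hedged usage sketch of how `sample()` could be driven for zero-shot generation. The bare `ZipVoice()` construction, the token ids, and the random prompt features below are placeholders (a real run would load trained weights and use the project's own inference scripts), so treat it as an illustration of the argument shapes rather than a working pipeline.

```python
# Hypothetical usage sketch for ZipVoice.sample(); ids, shapes and values are
# illustrative only, and a real run would first load trained weights.
import torch

from zipvoice.models.zipvoice import ZipVoice

model = ZipVoice()  # assumes the default constructor arguments are usable
model.eval()

# Token ids of the prompt transcription and of the text to synthesize (placeholders).
prompt_tokens = [[4, 8, 15]]
tokens = [[16, 23, 4, 2, 7]]

# Prompt acoustic features: (batch, prompt_len, feat_dim), with feat_dim=100 by default.
prompt_features = torch.randn(1, 120, 100)
prompt_features_lens = torch.tensor([120])

with torch.no_grad():
    feats, feats_lens, prompt_feats, prompt_feats_lens = model.sample(
        tokens=tokens,
        prompt_tokens=prompt_tokens,
        prompt_features=prompt_features,
        prompt_features_lens=prompt_features_lens,
        duration="predict",   # infer length from the prompt's token/frame ratio
        num_step=16,          # ODE solver steps (illustrative value)
        guidance_scale=1.0,   # classifier-free guidance strength (illustrative value)
    )
# feats: generated features without the prompt, shape (1, T, 100); feats_lens: their lengths.
```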
zipvoice/models/zipvoice_dialog.py ADDED
@@ -0,0 +1,358 @@
1
+ # Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
2
+ #
3
+ # See ../../../../LICENSE for clarification regarding multiple authors
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from typing import List
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ from torch.nn.parallel import DistributedDataParallel as DDP
22
+
23
+ from zipvoice.models.modules.zipformer_two_stream import TTSZipformerTwoStream
24
+ from zipvoice.models.zipvoice import ZipVoice
25
+ from zipvoice.utils.common import condition_time_mask_suffix, make_pad_mask, pad_labels
26
+
27
+
28
+ class ZipVoiceDialog(ZipVoice):
29
+ """The ZipVoice-Dialog model."""
30
+
31
+ def __init__(
32
+ self,
33
+ fm_decoder_downsampling_factor: List[int] = [1, 2, 4, 2, 1],
34
+ fm_decoder_num_layers: List[int] = [2, 2, 4, 4, 4],
35
+ fm_decoder_cnn_module_kernel: List[int] = [31, 15, 7, 15, 31],
36
+ fm_decoder_feedforward_dim: int = 1536,
37
+ fm_decoder_num_heads: int = 4,
38
+ fm_decoder_dim: int = 512,
39
+ text_encoder_num_layers: int = 4,
40
+ text_encoder_feedforward_dim: int = 512,
41
+ text_encoder_cnn_module_kernel: int = 9,
42
+ text_encoder_num_heads: int = 4,
43
+ text_encoder_dim: int = 192,
44
+ time_embed_dim: int = 192,
45
+ text_embed_dim: int = 192,
46
+ query_head_dim: int = 32,
47
+ value_head_dim: int = 12,
48
+ pos_head_dim: int = 4,
49
+ pos_dim: int = 48,
50
+ feat_dim: int = 100,
51
+ vocab_size: int = 26,
52
+ pad_id: int = 0,
53
+ spk_a_id: int = 360,
54
+ spk_b_id: int = 361,
55
+ ):
56
+ """
57
+ Initialize the model with specified configuration parameters.
58
+
59
+ Args:
60
+ fm_decoder_downsampling_factor: List of downsampling factors for each layer
61
+ in the flow-matching decoder.
62
+ fm_decoder_num_layers: List of the number of layers for each block in the
63
+ flow-matching decoder.
64
+ fm_decoder_cnn_module_kernel: List of kernel sizes for CNN modules in the
65
+ flow-matching decoder.
66
+ fm_decoder_feedforward_dim: Dimension of the feedforward network in the
67
+ flow-matching decoder.
68
+ fm_decoder_num_heads: Number of attention heads in the flow-matching
69
+ decoder.
70
+ fm_decoder_dim: Hidden dimension of the flow-matching decoder.
71
+ text_encoder_num_layers: Number of layers in the text encoder.
72
+ text_encoder_feedforward_dim: Dimension of the feedforward network in the
73
+ text encoder.
74
+ text_encoder_cnn_module_kernel: Kernel size for the CNN module in the
75
+ text encoder.
76
+ text_encoder_num_heads: Number of attention heads in the text encoder.
77
+ text_encoder_dim: Hidden dimension of the text encoder.
78
+ time_embed_dim: Dimension of the time embedding.
79
+ text_embed_dim: Dimension of the text embedding.
80
+ query_head_dim: Dimension of the query attention head.
81
+ value_head_dim: Dimension of the value attention head.
82
+ pos_head_dim: Dimension of the position attention head.
83
+ pos_dim: Dimension of the positional encoding.
84
+ feat_dim: Dimension of the acoustic features.
85
+ vocab_size: Size of the vocabulary.
86
+ pad_id: ID used for padding tokens.
87
+ spk_a_id: ID of speaker A / [S1].
88
+ spk_b_id: ID of speaker B / [S2].
89
+ """
90
+ super().__init__(
91
+ fm_decoder_downsampling_factor=fm_decoder_downsampling_factor,
92
+ fm_decoder_num_layers=fm_decoder_num_layers,
93
+ fm_decoder_cnn_module_kernel=fm_decoder_cnn_module_kernel,
94
+ fm_decoder_feedforward_dim=fm_decoder_feedforward_dim,
95
+ fm_decoder_num_heads=fm_decoder_num_heads,
96
+ fm_decoder_dim=fm_decoder_dim,
97
+ text_encoder_num_layers=text_encoder_num_layers,
98
+ text_encoder_feedforward_dim=text_encoder_feedforward_dim,
99
+ text_encoder_cnn_module_kernel=text_encoder_cnn_module_kernel,
100
+ text_encoder_num_heads=text_encoder_num_heads,
101
+ text_encoder_dim=text_encoder_dim,
102
+ time_embed_dim=time_embed_dim,
103
+ text_embed_dim=text_embed_dim,
104
+ query_head_dim=query_head_dim,
105
+ value_head_dim=value_head_dim,
106
+ pos_head_dim=pos_head_dim,
107
+ pos_dim=pos_dim,
108
+ feat_dim=feat_dim,
109
+ vocab_size=vocab_size,
110
+ pad_id=pad_id,
111
+ )
112
+
113
+ self.spk_a_id = spk_a_id
114
+ self.spk_b_id = spk_b_id
115
+ self.spk_embed = nn.Embedding(2, feat_dim)
116
+ torch.nn.init.normal_(self.spk_embed.weight, mean=0, std=0.1)
117
+
118
+ def extract_spk_indices(self, tensor):
119
+ turn_mask = ((tensor == self.spk_a_id) | (tensor == self.spk_b_id)).long()
120
+ turn_counts = turn_mask.cumsum(dim=1)
121
+ spk_mask = turn_counts % 2
122
+ spk_mask = torch.where(tensor == self.pad_id, -1, spk_mask)
123
+ spk_a_indices = torch.where(spk_mask == 0)
124
+ spk_b_indices = torch.where(spk_mask == 1)
125
+ return spk_a_indices, spk_b_indices
126
+
127
+ def forward_text_embed(
128
+ self,
129
+ tokens: List[List[int]],
130
+ ):
131
+ """
132
+ Get the text embeddings.
133
+ Args:
134
+ tokens: a list of list of token ids.
135
+ Returns:
136
+ embed: the text embeddings, shape (batch, seq_len, emb_dim).
137
+ tokens_lens: the length of each token sequence, shape (batch,).
138
+ """
139
+ device = (
140
+ self.device if isinstance(self, DDP) else next(self.parameters()).device
141
+ )
142
+ tokens_padded = pad_labels(tokens, pad_id=self.pad_id, device=device) # (B, S)
143
+ embed = self.embed(tokens_padded) # (B, S, C)
144
+ spk_a_indices, spk_b_indices = self.extract_spk_indices(tokens_padded)
145
+ tokens_lens = torch.tensor(
146
+ [len(token) for token in tokens], dtype=torch.int64, device=device
147
+ )
148
+ tokens_padding_mask = make_pad_mask(tokens_lens, embed.shape[1]) # (B, S)
149
+
150
+ embed = self.text_encoder(
151
+ x=embed, t=None, padding_mask=tokens_padding_mask
152
+ ) # (B, S, C)
153
+ embed[spk_a_indices] += self.spk_embed(torch.tensor(0, device=device)).to(
154
+ embed.dtype
155
+ )
156
+ embed[spk_b_indices] += self.spk_embed(torch.tensor(1, device=device)).to(
157
+ embed.dtype
158
+ )
159
+ return embed, tokens_lens
160
+
161
+ def forward(
162
+ self,
163
+ tokens: List[List[int]],
164
+ features: torch.Tensor,
165
+ features_lens: torch.Tensor,
166
+ noise: torch.Tensor,
167
+ t: torch.Tensor,
168
+ condition_drop_ratio: float = 0.0,
169
+ ) -> torch.Tensor:
170
+ """Forward pass of the model for training.
171
+ Args:
172
+ tokens: a list of list of token ids.
173
+ features: the acoustic features, with the shape (batch, seq_len, feat_dim).
174
+ features_lens: the length of each acoustic feature sequence, shape (batch,).
175
+ noise: the intitial noise, with the shape (batch, seq_len, feat_dim).
176
+ t: the time step, with the shape (batch, 1, 1).
177
+ condition_drop_ratio: the ratio of dropped text condition.
178
+ Returns:
179
+ fm_loss: the flow-matching loss.
180
+ """
181
+
182
+ (text_condition, padding_mask,) = self.forward_text_train(
183
+ tokens=tokens,
184
+ features_lens=features_lens,
185
+ )
186
+
187
+ speech_condition_mask = condition_time_mask_suffix(
188
+ features_lens=features_lens,
189
+ mask_percent=(0.5, 1.0),
190
+ max_len=features.size(1),
191
+ )
192
+ speech_condition = torch.where(speech_condition_mask.unsqueeze(-1), 0, features)
193
+
194
+ if condition_drop_ratio > 0.0:
195
+ drop_mask = (
196
+ torch.rand(text_condition.size(0), 1, 1).to(text_condition.device)
197
+ > condition_drop_ratio
198
+ )
199
+ text_condition = text_condition * drop_mask
200
+
201
+ xt = features * t + noise * (1 - t)
202
+ ut = features - noise # (B, T, F)
203
+
204
+ vt = self.forward_fm_decoder(
205
+ t=t,
206
+ xt=xt,
207
+ text_condition=text_condition,
208
+ speech_condition=speech_condition,
209
+ padding_mask=padding_mask,
210
+ )
211
+
212
+ loss_mask = speech_condition_mask & (~padding_mask)
213
+ fm_loss = torch.mean((vt[loss_mask] - ut[loss_mask]) ** 2)
214
+
215
+ return fm_loss
216
+
217
+
218
+ class ZipVoiceDialogStereo(ZipVoiceDialog):
219
+ def __init__(self, *args, **kwargs):
220
+ super().__init__(*args, **kwargs)
221
+
222
+ required_params = {
223
+ "feat_dim",
224
+ "fm_decoder_downsampling_factor",
225
+ "fm_decoder_num_layers",
226
+ "fm_decoder_cnn_module_kernel",
227
+ "fm_decoder_dim",
228
+ "fm_decoder_feedforward_dim",
229
+ "fm_decoder_num_heads",
230
+ "query_head_dim",
231
+ "pos_head_dim",
232
+ "value_head_dim",
233
+ "pos_dim",
234
+ "time_embed_dim",
235
+ }
236
+
237
+ missing = [p for p in required_params if p not in kwargs]
238
+ if missing:
239
+ raise ValueError(f"Missing required parameters: {', '.join(missing)}")
240
+
241
+ self.fm_decoder = TTSZipformerTwoStream(
242
+ in_dim=(kwargs["feat_dim"] * 5, kwargs["feat_dim"] * 3),
243
+ out_dim=(kwargs["feat_dim"] * 2, kwargs["feat_dim"]),
244
+ downsampling_factor=kwargs["fm_decoder_downsampling_factor"],
245
+ num_encoder_layers=kwargs["fm_decoder_num_layers"],
246
+ cnn_module_kernel=kwargs["fm_decoder_cnn_module_kernel"],
247
+ encoder_dim=kwargs["fm_decoder_dim"],
248
+ feedforward_dim=kwargs["fm_decoder_feedforward_dim"],
249
+ num_heads=kwargs["fm_decoder_num_heads"],
250
+ query_head_dim=kwargs["query_head_dim"],
251
+ pos_head_dim=kwargs["pos_head_dim"],
252
+ value_head_dim=kwargs["value_head_dim"],
253
+ pos_dim=kwargs["pos_dim"],
254
+ use_time_embed=True,
255
+ time_embed_dim=kwargs["time_embed_dim"],
256
+ )
257
+
258
+ def forward(
259
+ self,
260
+ tokens: List[List[int]],
261
+ features: torch.Tensor,
262
+ features_lens: torch.Tensor,
263
+ noise: torch.Tensor,
264
+ t: torch.Tensor,
265
+ condition_drop_ratio: float = 0.0,
266
+ se_weight: float = 1.0,
267
+ ) -> torch.Tensor:
268
+ """Forward pass of the model for training.
269
+ Args:
270
+ tokens: a list of list of token ids.
271
+ features: the acoustic features, with the shape (batch, seq_len, feat_dim).
272
+ features_lens: the length of each acoustic feature sequence, shape (batch,).
273
+ noise: the intitial noise, with the shape (batch, seq_len, feat_dim).
274
+ t: the time step, with the shape (batch, 1, 1).
275
+ condition_drop_ratio: the ratio of dropped text condition.
276
+ se_weight: the weight of the speaker exclusive loss.
277
+ Returns:
278
+ fm_loss: the flow-matching loss.
279
+ """
280
+
281
+ (text_condition, padding_mask,) = self.forward_text_train(
282
+ tokens=tokens,
283
+ features_lens=features_lens,
284
+ )
285
+
286
+ speech_condition_mask = condition_time_mask_suffix(
287
+ features_lens=features_lens,
288
+ mask_percent=(0.5, 1.0),
289
+ max_len=features.size(1),
290
+ )
291
+ speech_condition = torch.where(speech_condition_mask.unsqueeze(-1), 0, features)
292
+
293
+ if condition_drop_ratio > 0.0:
294
+ drop_mask = (
295
+ torch.rand(text_condition.size(0), 1, 1).to(text_condition.device)
296
+ > condition_drop_ratio
297
+ )
298
+ text_condition = text_condition * drop_mask
299
+
300
+ xt = features * t + noise * (1 - t)
301
+ ut = features - noise # (B, T, F)
302
+
303
+ vt = self.forward_fm_decoder(
304
+ t=t,
305
+ xt=xt,
306
+ text_condition=text_condition,
307
+ speech_condition=speech_condition,
308
+ padding_mask=padding_mask,
309
+ )
310
+
311
+ loss_mask = speech_condition_mask & (~padding_mask)
312
+ fm_loss = torch.mean((vt[loss_mask] - ut[loss_mask]) ** 2)
313
+
314
+ if se_weight > 0:
315
+ target = xt + vt * (1 - t)
316
+ fbank_1 = target[:, :, : self.feat_dim]
317
+ fbank_2 = target[:, :, self.feat_dim :]
318
+ energy_loss = torch.mean(
319
+ self.energy_based_loss(fbank_1, fbank_2, features)[loss_mask]
320
+ )
321
+ loss = fm_loss + energy_loss * se_weight
322
+ else:
323
+ loss = fm_loss
324
+
325
+ return loss
326
+
327
+ def energy_based_loss(self, fbank1, fbank2, gt_fbank):
328
+ energy1 = self.energy(fbank1)
329
+ energy2 = self.energy(fbank2)
330
+
331
+ energy_thresholds = self.adaptive_threshold_from_gt(
332
+ torch.cat(
333
+ [
334
+ gt_fbank[:, :, : self.feat_dim],
335
+ gt_fbank[:, :, self.feat_dim :],
336
+ ],
337
+ dim=1,
338
+ )
339
+ )
340
+
341
+ both_speaking = (
342
+ (energy1 > energy_thresholds) & (energy2 > energy_thresholds)
343
+ ).float()
344
+
345
+ penalty = (
346
+ both_speaking
347
+ * (energy1 - energy_thresholds)
348
+ * (energy2 - energy_thresholds)
349
+ )
350
+ return penalty
351
+
352
+ def energy(self, fbank):
353
+ return torch.mean(fbank, dim=-1)
354
+
355
+ def adaptive_threshold_from_gt(self, gt_fbank, percentile=50):
356
+ frame_energies = self.energy(gt_fbank)
357
+ thresholds = torch.quantile(frame_energies, q=percentile / 100, dim=1)
358
+ return thresholds.unsqueeze(1)
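
To make the speaker-exclusive penalty concrete, the following self-contained sketch reproduces the `energy` / `adaptive_threshold_from_gt` / `energy_based_loss` logic above with plain random tensors; the batch size, frame count, and `feat_dim` are arbitrary stand-ins, not values taken from the training recipe.

```python
# Standalone sketch of the "both speaking" penalty used by ZipVoiceDialogStereo;
# tensor shapes are illustrative.
import torch

feat_dim = 100
B, T = 2, 200

# Predicted per-channel fbanks and the stereo ground truth (channels concatenated).
fbank_1 = torch.randn(B, T, feat_dim)
fbank_2 = torch.randn(B, T, feat_dim)
gt_fbank = torch.randn(B, T, feat_dim * 2)

def energy(fbank: torch.Tensor) -> torch.Tensor:
    # Frame-level energy proxy: mean over mel bins -> (B, T)
    return fbank.mean(dim=-1)

# Per-utterance threshold: median frame energy over both ground-truth channels.
gt_frames = torch.cat([gt_fbank[:, :, :feat_dim], gt_fbank[:, :, feat_dim:]], dim=1)
thresholds = torch.quantile(energy(gt_frames), q=0.5, dim=1, keepdim=True)  # (B, 1)

e1, e2 = energy(fbank_1), energy(fbank_2)
both_speaking = ((e1 > thresholds) & (e2 > thresholds)).float()

# Penalize frames where both predicted channels are above threshold at the same time.
penalty = both_speaking * (e1 - thresholds) * (e2 - thresholds)  # (B, T)
print(penalty.mean())
```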
zipvoice/models/zipvoice_distill.py ADDED
@@ -0,0 +1,94 @@
1
+ # Copyright 2024 Xiaomi Corp. (authors: Wei Kang
2
+ # Han Zhu)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ from typing import List
19
+
20
+ import torch
21
+
22
+ from zipvoice.models.modules.solver import DistillEulerSolver
23
+ from zipvoice.models.modules.zipformer import TTSZipformer
24
+ from zipvoice.models.zipvoice import ZipVoice
25
+
26
+
27
+ class ZipVoiceDistill(ZipVoice):
28
+ """ZipVoice-Distill model."""
29
+
30
+ def __init__(self, *args, **kwargs):
31
+ super().__init__(*args, **kwargs)
32
+
33
+ required_params = {
34
+ "feat_dim",
35
+ "fm_decoder_downsampling_factor",
36
+ "fm_decoder_num_layers",
37
+ "fm_decoder_cnn_module_kernel",
38
+ "fm_decoder_dim",
39
+ "fm_decoder_feedforward_dim",
40
+ "fm_decoder_num_heads",
41
+ "query_head_dim",
42
+ "pos_head_dim",
43
+ "value_head_dim",
44
+ "pos_dim",
45
+ "time_embed_dim",
46
+ }
47
+
48
+ missing = [p for p in required_params if p not in kwargs]
49
+ if missing:
50
+ raise ValueError(f"Missing required parameters: {', '.join(missing)}")
51
+
52
+ self.fm_decoder = TTSZipformer(
53
+ in_dim=kwargs["feat_dim"] * 3,
54
+ out_dim=kwargs["feat_dim"],
55
+ downsampling_factor=kwargs["fm_decoder_downsampling_factor"],
56
+ num_encoder_layers=kwargs["fm_decoder_num_layers"],
57
+ cnn_module_kernel=kwargs["fm_decoder_cnn_module_kernel"],
58
+ encoder_dim=kwargs["fm_decoder_dim"],
59
+ feedforward_dim=kwargs["fm_decoder_feedforward_dim"],
60
+ num_heads=kwargs["fm_decoder_num_heads"],
61
+ query_head_dim=kwargs["query_head_dim"],
62
+ pos_head_dim=kwargs["pos_head_dim"],
63
+ value_head_dim=kwargs["value_head_dim"],
64
+ pos_dim=kwargs["pos_dim"],
65
+ use_time_embed=True,
66
+ time_embed_dim=kwargs["time_embed_dim"],
67
+ use_guidance_scale_embed=True,
68
+ )
69
+ self.solver = DistillEulerSolver(self, func_name="forward_fm_decoder")
70
+
71
+ def forward(
72
+ self,
73
+ tokens: List[List[int]],
74
+ features: torch.Tensor,
75
+ features_lens: torch.Tensor,
76
+ noise: torch.Tensor,
77
+ speech_condition_mask: torch.Tensor,
78
+ t_start: float,
79
+ t_end: float,
80
+ num_step: int = 1,
81
+ guidance_scale: torch.Tensor = None,
82
+ ) -> torch.Tensor:
83
+
84
+ return self.sample_intermediate(
85
+ tokens=tokens,
86
+ features=features,
87
+ features_lens=features_lens,
88
+ noise=noise,
89
+ speech_condition_mask=speech_condition_mask,
90
+ t_start=t_start,
91
+ t_end=t_end,
92
+ num_step=num_step,
93
+ guidance_scale=guidance_scale,
94
+ )
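
All the training `forward()` methods in these model files share the same conditional flow-matching recipe: interpolate between noise and the target features, and regress the predicted velocity on masked, non-padded frames. Below is a minimal, model-free sketch of those few lines, with random tensors standing in for real features and for the decoder output.

```python
# Minimal sketch of the flow-matching quantities used in the forward() methods above.
import torch

B, T, F = 2, 150, 100
features = torch.randn(B, T, F)          # target acoustic features x1
noise = torch.randn(B, T, F)             # initial noise x0
t = torch.rand(B, 1, 1)                  # sampled timestep in [0, 1)

xt = features * t + noise * (1 - t)      # point on the straight path from x0 to x1
ut = features - noise                    # ground-truth velocity along that path

# In the real model, vt comes from forward_fm_decoder(t, xt, conditions, ...);
# here a random tensor stands in for the predicted velocity.
vt = torch.randn(B, T, F)

# Only masked (to-be-generated) frames that are not padding contribute to the loss.
speech_condition_mask = torch.rand(B, T) > 0.3   # True = frames the model must generate
padding_mask = torch.zeros(B, T, dtype=torch.bool)
loss_mask = speech_condition_mask & (~padding_mask)

fm_loss = torch.mean((vt[loss_mask] - ut[loss_mask]) ** 2)
print(fm_loss)
```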
zipvoice/tokenizer/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # ZipVoice tokenizer package
zipvoice/tokenizer/normalizer.py ADDED
@@ -0,0 +1,170 @@
1
+ import re
2
+ from abc import ABC, abstractmethod
3
+
4
+ import cn2an
5
+ import inflect
6
+
7
+
8
+ class TextNormalizer(ABC):
9
+ """Abstract base class for text normalization, defining common interface."""
10
+
11
+ @abstractmethod
12
+ def normalize(self, text: str) -> str:
13
+ """Normalize text."""
14
+ raise NotImplementedError
15
+
16
+
17
+ class EnglishTextNormalizer(TextNormalizer):
18
+ """
19
+ A class to handle preprocessing of English text including normalization. Following:
20
+ https://github.com/espnet/espnet_tts_frontend/blob/master/tacotron_cleaner/cleaners.py
21
+ """
22
+
23
+ def __init__(self):
24
+ # List of (regular expression, replacement) pairs for abbreviations:
25
+ self._abbreviations = [
26
+ (re.compile("\\b%s\\b" % x[0], re.IGNORECASE), x[1])
27
+ for x in [
28
+ ("mrs", "misess"),
29
+ ("mr", "mister"),
30
+ ("dr", "doctor"),
31
+ ("st", "saint"),
32
+ ("co", "company"),
33
+ ("jr", "junior"),
34
+ ("maj", "major"),
35
+ ("gen", "general"),
36
+ ("drs", "doctors"),
37
+ ("rev", "reverend"),
38
+ ("lt", "lieutenant"),
39
+ ("hon", "honorable"),
40
+ ("sgt", "sergeant"),
41
+ ("capt", "captain"),
42
+ ("esq", "esquire"),
43
+ ("ltd", "limited"),
44
+ ("col", "colonel"),
45
+ ("ft", "fort"),
46
+ ("etc", "et cetera"),
47
+ ("btw", "by the way"),
48
+ ]
49
+ ]
50
+
51
+ self._inflect = inflect.engine()
52
+ self._comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
53
+ self._decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
54
+ self._percent_number_re = re.compile(r"([0-9\.\,]*[0-9]+%)")
55
+ self._pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
56
+ self._dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
57
+ self._fraction_re = re.compile(r"([0-9]+)/([0-9]+)")
58
+ self._ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
59
+ self._number_re = re.compile(r"[0-9]+")
60
+ self._whitespace_re = re.compile(r"\s+")
61
+
62
+ def normalize(self, text: str) -> str:
63
+ """Custom pipeline for English text,
64
+ including number and abbreviation expansion."""
65
+ text = self.expand_abbreviations(text)
66
+ text = self.normalize_numbers(text)
67
+
68
+ return text
69
+
70
+ def fraction_to_words(self, numerator, denominator):
71
+ if numerator == 1 and denominator == 2:
72
+ return " one half "
73
+ if numerator == 1 and denominator == 4:
74
+ return " one quarter "
75
+ if denominator == 2:
76
+ return " " + self._inflect.number_to_words(numerator) + " halves "
77
+ if denominator == 4:
78
+ return " " + self._inflect.number_to_words(numerator) + " quarters "
79
+ return (
80
+ " "
81
+ + self._inflect.number_to_words(numerator)
82
+ + " "
83
+ + self._inflect.ordinal(self._inflect.number_to_words(denominator))
84
+ + " "
85
+ )
86
+
87
+ def _remove_commas(self, m):
88
+ return m.group(1).replace(",", "")
89
+
90
+ def _expand_dollars(self, m):
91
+ match = m.group(1)
92
+ parts = match.split(".")
93
+ if len(parts) > 2:
94
+ return " " + match + " dollars " # Unexpected format
95
+ dollars = int(parts[0]) if parts[0] else 0
96
+ cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
97
+ if dollars and cents:
98
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
99
+ cent_unit = "cent" if cents == 1 else "cents"
100
+ return " %s %s, %s %s " % (dollars, dollar_unit, cents, cent_unit)
101
+ elif dollars:
102
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
103
+ return " %s %s " % (dollars, dollar_unit)
104
+ elif cents:
105
+ cent_unit = "cent" if cents == 1 else "cents"
106
+ return " %s %s " % (cents, cent_unit)
107
+ else:
108
+ return " zero dollars "
109
+
110
+ def _expand_fraction(self, m):
111
+ numerator = int(m.group(1))
112
+ denominator = int(m.group(2))
113
+ return self.fraction_to_words(numerator, denominator)
114
+
115
+ def _expand_decimal_point(self, m):
116
+ return m.group(1).replace(".", " point ")
117
+
118
+ def _expand_percent(self, m):
119
+ return m.group(1).replace("%", " percent ")
120
+
121
+ def _expand_ordinal(self, m):
122
+ return " " + self._inflect.number_to_words(m.group(0)) + " "
123
+
124
+ def _expand_number(self, m):
125
+ num = int(m.group(0))
126
+ if num > 1000 and num < 3000:
127
+ if num == 2000:
128
+ return " two thousand "
129
+ elif num > 2000 and num < 2010:
130
+ return " two thousand " + self._inflect.number_to_words(num % 100) + " "
131
+ elif num % 100 == 0:
132
+ return " " + self._inflect.number_to_words(num // 100) + " hundred "
133
+ else:
134
+ return (
135
+ " "
136
+ + self._inflect.number_to_words(
137
+ num, andword="", zero="oh", group=2
138
+ ).replace(", ", " ")
139
+ + " "
140
+ )
141
+ else:
142
+ return " " + self._inflect.number_to_words(num, andword="") + " "
143
+
144
+ def normalize_numbers(self, text):
145
+ text = re.sub(self._comma_number_re, self._remove_commas, text)
146
+ text = re.sub(self._pounds_re, r"\1 pounds", text)
147
+ text = re.sub(self._dollars_re, self._expand_dollars, text)
148
+ text = re.sub(self._fraction_re, self._expand_fraction, text)
149
+ text = re.sub(self._decimal_number_re, self._expand_decimal_point, text)
150
+ text = re.sub(self._percent_number_re, self._expand_percent, text)
151
+ text = re.sub(self._ordinal_re, self._expand_ordinal, text)
152
+ text = re.sub(self._number_re, self._expand_number, text)
153
+ return text
154
+
155
+ def expand_abbreviations(self, text):
156
+ for regex, replacement in self._abbreviations:
157
+ text = re.sub(regex, replacement, text)
158
+ return text
159
+
160
+
161
+ class ChineseTextNormalizer(TextNormalizer):
162
+ """
163
+ A class to handle preprocessing of Chinese text including normalization.
164
+ """
165
+
166
+ def normalize(self, text: str) -> str:
167
+ """Normalize text."""
168
+ # Convert numbers to Chinese
169
+ text = cn2an.transform(text, "an2cn")
170
+ return text
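
A quick way to see what these normalizers do is to run them on a sentence that mixes abbreviations, currency, fractions, and ordinals. The sketch below assumes the package and its `inflect` and `cn2an` dependencies are installed; the outputs shown in comments are approximate, since exact spacing depends on the regex pipeline.

```python
# Usage sketch for the text normalizers; expected outputs are approximate.
from zipvoice.tokenizer.normalizer import ChineseTextNormalizer, EnglishTextNormalizer

en = EnglishTextNormalizer()
print(en.normalize("Dr Smith paid $5.50 for 3/4 of the pizza on the 2nd of May."))
# -> roughly "doctor Smith paid 5 dollars, 50 cents for three quarters of the pizza
#    on the second of May."

zh = ChineseTextNormalizer()
print(zh.normalize("我有250个苹果"))  # Arabic numerals are rewritten as Chinese numerals
```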
zipvoice/tokenizer/tokenizer.py ADDED
@@ -0,0 +1,648 @@
1
+ # Copyright 2023-2024 Xiaomi Corp. (authors: Zengwei Yao
2
+ # Han Zhu,
3
+ # Wei Kang)
4
+ #
5
+ # See ../../LICENSE for clarification regarding multiple authors
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+ import logging
20
+ import re
21
+ from abc import ABC, abstractmethod
22
+ from functools import reduce
23
+ from typing import Dict, List, Optional, Tuple
24
+
25
+ import jieba
26
+ from lhotse import CutSet
27
+ from pypinyin import Style, lazy_pinyin
28
+ from pypinyin.contrib.tone_convert import to_finals_tone3, to_initials
29
+
30
+ from zipvoice.tokenizer.normalizer import ChineseTextNormalizer, EnglishTextNormalizer
31
+
32
+ try:
33
+ from piper_phonemize import phonemize_espeak
34
+ except Exception as ex:
35
+ raise RuntimeError(
36
+ f"{ex}\nPlease run\n"
37
+ "pip install piper_phonemize -f \
38
+ https://k2-fsa.github.io/icefall/piper_phonemize.html"
39
+ )
40
+
41
+ jieba.default_logger.setLevel(logging.INFO)
42
+
43
+
44
+ class Tokenizer(ABC):
45
+ """Abstract base class for tokenizers, defining common interface."""
46
+
47
+ @abstractmethod
48
+ def texts_to_token_ids(self, texts: List[str]) -> List[List[int]]:
49
+ """Convert list of texts to list of token id sequences."""
50
+ raise NotImplementedError
51
+
52
+ @abstractmethod
53
+ def texts_to_tokens(self, texts: List[str]) -> List[List[str]]:
54
+ """Convert list of texts to list of token sequences."""
55
+ raise NotImplementedError
56
+
57
+ @abstractmethod
58
+ def tokens_to_token_ids(self, tokens: List[List[str]]) -> List[List[int]]:
59
+ """Convert list of token sequences to list of token id sequences."""
60
+ raise NotImplementedError
61
+
62
+
63
+ class SimpleTokenizer(Tokenizer):
64
+ """The simplest tokenizer: treats every character as a token,
65
+ without text normalization.
66
+ """
67
+
68
+ def __init__(self, token_file: Optional[str] = None):
69
+ """
70
+ Args:
71
+ token_file: the file that contains information that maps tokens to ids,
72
+ which is a text file with '{token}\t{token_id}' per line.
73
+ """
74
+ # Parse token file
75
+ self.has_tokens = False
76
+ if token_file is None:
77
+ logging.debug(
78
+ "Initialize Tokenizer without tokens file, \
79
+ will fail when map to ids."
80
+ )
81
+ return
82
+ self.token2id: Dict[str, int] = {}
83
+ with open(token_file, "r", encoding="utf-8") as f:
84
+ for line in f.readlines():
85
+ info = line.rstrip().split("\t")
86
+ token, id = info[0], int(info[1])
87
+ assert token not in self.token2id, token
88
+ self.token2id[token] = id
89
+ self.pad_id = self.token2id["_"] # padding
90
+ self.vocab_size = len(self.token2id)
91
+ self.has_tokens = True
92
+
93
+ def texts_to_token_ids(
94
+ self,
95
+ texts: List[str],
96
+ ) -> List[List[int]]:
97
+ return self.tokens_to_token_ids(self.texts_to_tokens(texts))
98
+
99
+ def texts_to_tokens(
100
+ self,
101
+ texts: List[str],
102
+ ) -> List[List[str]]:
103
+ tokens_list = [list(texts[i]) for i in range(len(texts))]
104
+ return tokens_list
105
+
106
+ def tokens_to_token_ids(
107
+ self,
108
+ tokens_list: List[List[str]],
109
+ ) -> List[List[int]]:
110
+ assert self.has_tokens, "Please initialize Tokenizer with a tokens file."
111
+
112
+ token_ids_list = []
113
+
114
+ for tokens in tokens_list:
115
+ token_ids = []
116
+ for t in tokens:
117
+ if t not in self.token2id:
118
+ logging.debug(f"Skip OOV {t}")
119
+ continue
120
+ token_ids.append(self.token2id[t])
121
+
122
+ token_ids_list.append(token_ids)
123
+
124
+ return token_ids_list
125
+
126
+
127
+ class EspeakTokenizer(Tokenizer):
128
+ """A simple tokenizer with Espeak g2p function."""
129
+
130
+ def __init__(self, token_file: Optional[str] = None, lang: str = "en-us"):
131
+ """
132
+ Args:
133
+ token_file: the file that contains information that maps tokens to ids,
134
+ which is a text file with '{token}\t{token_id}' per line.
135
+ lang: the language identifier, see
136
+ https://github.com/rhasspy/espeak-ng/blob/master/docs/languages.md
137
+ """
138
+ # Parse token file
139
+ self.has_tokens = False
140
+ self.lang = lang
141
+ if token_file is None:
142
+ logging.debug(
143
+ "Initialize Tokenizer without tokens file, \
144
+ will fail when map to ids."
145
+ )
146
+ return
147
+ self.token2id: Dict[str, int] = {}
148
+ with open(token_file, "r", encoding="utf-8") as f:
149
+ for line in f.readlines():
150
+ info = line.rstrip().split("\t")
151
+ token, id = info[0], int(info[1])
152
+ assert token not in self.token2id, token
153
+ self.token2id[token] = id
154
+ self.pad_id = self.token2id["_"] # padding
155
+ self.vocab_size = len(self.token2id)
156
+ self.has_tokens = True
157
+
158
+ def g2p(self, text: str) -> List[str]:
159
+ try:
160
+ tokens = phonemize_espeak(text, self.lang)
161
+ tokens = reduce(lambda x, y: x + y, tokens)
162
+ return tokens
163
+ except Exception as ex:
164
+ logging.warning(f"Tokenization of {self.lang} texts failed: {ex}")
165
+ return []
166
+
167
+ def texts_to_token_ids(
168
+ self,
169
+ texts: List[str],
170
+ ) -> List[List[int]]:
171
+ return self.tokens_to_token_ids(self.texts_to_tokens(texts))
172
+
173
+ def texts_to_tokens(
174
+ self,
175
+ texts: List[str],
176
+ ) -> List[List[str]]:
177
+ tokens_list = [self.g2p(texts[i]) for i in range(len(texts))]
178
+ return tokens_list
179
+
180
+ def tokens_to_token_ids(
181
+ self,
182
+ tokens_list: List[List[str]],
183
+ ) -> List[List[int]]:
184
+ assert self.has_tokens, "Please initialize Tokenizer with a tokens file."
185
+
186
+ token_ids_list = []
187
+
188
+ for tokens in tokens_list:
189
+ token_ids = []
190
+ for t in tokens:
191
+ if t not in self.token2id:
192
+ logging.debug(f"Skip OOV {t}")
193
+ continue
194
+ token_ids.append(self.token2id[t])
195
+
196
+ token_ids_list.append(token_ids)
197
+
198
+ return token_ids_list
199
+
200
+
201
+ class EmiliaTokenizer(Tokenizer):
202
+ def __init__(self, token_file: Optional[str] = None, token_type="phone"):
203
+ """
204
+ Args:
205
+ token_file: the file that contains information that maps tokens to ids,
206
+ which is a text file with '{token}\t{token_id}' per line.
207
+ """
208
+ assert (
209
+ token_type == "phone"
210
+ ), f"Only the phone tokenizer is supported for Emilia, but got {token_type}."
211
+
212
+ self.english_normalizer = EnglishTextNormalizer()
213
+ self.chinese_normalizer = ChineseTextNormalizer()
214
+
215
+ self.has_tokens = False
216
+ if token_file is None:
217
+ logging.debug(
218
+ "Initialize Tokenizer without tokens file, \
219
+ will fail when map to ids."
220
+ )
221
+ return
222
+ self.token2id: Dict[str, int] = {}
223
+ with open(token_file, "r", encoding="utf-8") as f:
224
+ for line in f.readlines():
225
+ info = line.rstrip().split("\t")
226
+ token, id = info[0], int(info[1])
227
+ assert token not in self.token2id, token
228
+ self.token2id[token] = id
229
+ self.pad_id = self.token2id["_"] # padding
230
+
231
+ self.vocab_size = len(self.token2id)
232
+ self.has_tokens = True
233
+
234
+ def texts_to_token_ids(
235
+ self,
236
+ texts: List[str],
237
+ ) -> List[List[int]]:
238
+ return self.tokens_to_token_ids(self.texts_to_tokens(texts))
239
+
240
+ def preprocess_text(
241
+ self,
242
+ text: str,
243
+ ) -> str:
244
+ return self.map_punctuations(text)
245
+
246
+ def texts_to_tokens(
247
+ self,
248
+ texts: List[str],
249
+ ) -> List[List[str]]:
250
+ for i in range(len(texts)):
251
+ # Text normalization
252
+ texts[i] = self.preprocess_text(texts[i])
253
+
254
+ phoneme_list = []
255
+ for text in texts:
256
+ # now only en and ch
257
+ segments = self.get_segment(text)
258
+ all_phoneme = []
259
+ for index in range(len(segments)):
260
+ seg = segments[index]
261
+ if seg[1] == "zh":
262
+ phoneme = self.tokenize_ZH(seg[0])
263
+ elif seg[1] == "en":
264
+ phoneme = self.tokenize_EN(seg[0])
265
+ elif seg[1] == "pinyin":
266
+ phoneme = self.tokenize_pinyin(seg[0])
267
+ elif seg[1] == "tag":
268
+ phoneme = [seg[0]]
269
+ else:
270
+ logging.warning(
271
+ f"No English or Chinese characters found, \
272
+ skipping segment of unknown language: {seg}"
273
+ )
274
+ continue
275
+ all_phoneme += phoneme
276
+ phoneme_list.append(all_phoneme)
277
+ return phoneme_list
278
+
279
+ def tokens_to_token_ids(
280
+ self,
281
+ tokens_list: List[List[str]],
282
+ ) -> List[List[int]]:
283
+ assert self.has_tokens, "Please initialize Tokenizer with a tokens file."
284
+ token_ids_list = []
285
+
286
+ for tokens in tokens_list:
287
+ token_ids = []
288
+ for t in tokens:
289
+ if t not in self.token2id:
290
+ logging.debug(f"Skip OOV {t}")
291
+ continue
292
+ token_ids.append(self.token2id[t])
293
+
294
+ token_ids_list.append(token_ids)
295
+
296
+ return token_ids_list
297
+
298
+ def tokenize_ZH(self, text: str) -> List[str]:
299
+ try:
300
+ text = self.chinese_normalizer.normalize(text)
301
+ segs = list(jieba.cut(text))
302
+ full = lazy_pinyin(
303
+ segs,
304
+ style=Style.TONE3,
305
+ tone_sandhi=True,
306
+ neutral_tone_with_five=True,
307
+ )
308
+ phones = []
309
+ for x in full:
310
+ # valid pinyin (in tone3 style) is alphabet + 1 number in [1-5].
311
+ if not (x[0:-1].isalpha() and x[-1] in ("1", "2", "3", "4", "5")):
312
+ phones.append(x)
313
+ continue
314
+ else:
315
+ phones.extend(self.seperate_pinyin(x))
316
+ return phones
317
+ except Exception as ex:
318
+ logging.warning(f"Tokenization of Chinese texts failed: {ex}")
319
+ return []
320
+
321
+ def tokenize_EN(self, text: str) -> List[str]:
322
+ try:
323
+ text = self.english_normalizer.normalize(text)
324
+ tokens = phonemize_espeak(text, "en-us")
325
+ tokens = reduce(lambda x, y: x + y, tokens)
326
+ return tokens
327
+ except Exception as ex:
328
+ logging.warning(f"Tokenization of English texts failed: {ex}")
329
+ return []
330
+
331
+ def tokenize_pinyin(self, text: str) -> List[str]:
332
+ try:
333
+ assert text.startswith("<") and text.endswith(">")
334
+ text = text.lstrip("<").rstrip(">")
335
+ # valid pinyin (in tone3 style) is alphabet + 1 number in [1-5].
336
+ if not (text[0:-1].isalpha() and text[-1] in ("1", "2", "3", "4", "5")):
337
+ logging.warning(
338
+ f"Strings enclosed with <> should be pinyin, \
339
+ but got: {text}. Skipped it. "
340
+ )
341
+ return []
342
+ else:
343
+ return self.seperate_pinyin(text)
344
+ except Exception as ex:
345
+ logging.warning(f"Tokenize pinyin failed: {ex}")
346
+ return []
347
+
348
+ def seperate_pinyin(self, text: str) -> List[str]:
349
+ """
350
+ Separate pinyin into initial and final
351
+ """
352
+ pinyins = []
353
+ initial = to_initials(text, strict=False)
354
+ # don't want to share tokens with espeak tokens,
355
+ # so use tone3 style
356
+ final = to_finals_tone3(
357
+ text,
358
+ strict=False,
359
+ neutral_tone_with_five=True,
360
+ )
361
+ if initial != "":
362
+ # don't want to share tokens with espeak tokens,
363
+ # so add a '0' after each initial
364
+ pinyins.append(initial + "0")
365
+ if final != "":
366
+ pinyins.append(final)
367
+ return pinyins
368
+
369
+ def map_punctuations(self, text):
370
+ text = text.replace(",", ",")
371
+ text = text.replace("。", ".")
372
+ text = text.replace("!", "!")
373
+ text = text.replace("?", "?")
374
+ text = text.replace(";", ";")
375
+ text = text.replace(":", ":")
376
+ text = text.replace("、", ",")
377
+ text = text.replace("‘", "'")
378
+ text = text.replace("“", '"')
379
+ text = text.replace("”", '"')
380
+ text = text.replace("’", "'")
381
+ text = text.replace("⋯", "…")
382
+ text = text.replace("···", "…")
383
+ text = text.replace("・・・", "…")
384
+ text = text.replace("...", "…")
385
+ return text
386
+
387
+ def get_segment(self, text: str) -> List[Tuple[str, str]]:
388
+ """
389
+ Split a text into segments based on language types
390
+ (Chinese, English, Pinyin, tags, etc.)
391
+
392
+ Args:
393
+ text (str): Input text to be segmented
394
+
395
+ Returns:
396
+ List[Tuple[str, str]]: segmented text parts paired with their language types
397
+
398
+ Example:
399
+ Input: 我们是小米人,是吗? Yes I think so!霍...啦啦啦
400
+ Output: [('我们是小米人,是吗? ', 'zh'),
401
+ ('Yes I think so!', 'en'), ('霍...啦啦啦', 'zh')]
402
+ """
403
+ # Stores the final segmented parts and their language types
404
+ segments = []
405
+ # Stores the language type of each character in the input text
406
+ types = []
407
+ temp_seg = ""
408
+ temp_lang = ""
409
+
410
+ # Each part is a character, or a special string enclosed in <> and []
411
+ # <> denotes pinyin string, [] denotes other special strings.
412
+ _part_pattern = re.compile(r"[<[].*?[>\]]|.")
413
+ text = _part_pattern.findall(text)
414
+
415
+ for i, part in enumerate(text):
416
+ if self.is_chinese(part) or self.is_pinyin(part):
417
+ types.append("zh")
418
+ elif self.is_alphabet(part):
419
+ types.append("en")
420
+ else:
421
+ types.append("other")
422
+
423
+ assert len(types) == len(text)
424
+
425
+ for i in range(len(types)):
426
+ # find the first char of the seg
427
+ if i == 0:
428
+ temp_seg += text[i]
429
+ temp_lang = types[i]
430
+ else:
431
+ if temp_lang == "other":
432
+ temp_seg += text[i]
433
+ temp_lang = types[i]
434
+ else:
435
+ if types[i] in [temp_lang, "other"]:
436
+ temp_seg += text[i]
437
+ else:
438
+ segments.append((temp_seg, temp_lang))
439
+ temp_seg = text[i]
440
+ temp_lang = types[i]
441
+
442
+ segments.append((temp_seg, temp_lang))
443
+
444
+ # Handle "pinyin" and "tag" types
445
+ segments = self.split_segments(segments)
446
+ return segments
447
+
448
+ def split_segments(self, segments):
449
+ """
450
+ split segments into smaller parts if special strings enclosed by [] or <>
451
+ are found, where <> denotes pinyin strings, [] denotes other special strings.
452
+
453
+ Args:
454
+ segments (list): A list of tuples where each tuple contains:
455
+ - temp_seg (str): The text segment to be split.
456
+ - temp_lang (str): The language code associated with the segment.
457
+
458
+ Returns:
459
+ list: A list of smaller segments.
460
+ """
461
+ result = []
462
+ for temp_seg, temp_lang in segments:
463
+ parts = re.split(r"([<[].*?[>\]])", temp_seg)
464
+ for part in parts:
465
+ if not part:
466
+ continue
467
+ if self.is_pinyin(part):
468
+ result.append((part, "pinyin"))
469
+ elif self.is_tag(part):
470
+ result.append((part, "tag"))
471
+ else:
472
+ result.append((part, temp_lang))
473
+ return result
474
+
475
+ def is_chinese(self, char: str) -> bool:
476
+ if char >= "\u4e00" and char <= "\u9fa5":
477
+ return True
478
+ else:
479
+ return False
480
+
481
+ def is_alphabet(self, char: str) -> bool:
482
+ if (char >= "\u0041" and char <= "\u005a") or (
483
+ char >= "\u0061" and char <= "\u007a"
484
+ ):
485
+ return True
486
+ else:
487
+ return False
488
+
489
+ def is_pinyin(self, part: str) -> bool:
490
+ if part.startswith("<") and part.endswith(">"):
491
+ return True
492
+ else:
493
+ return False
494
+
495
+ def is_tag(self, part: str) -> bool:
496
+ if part.startswith("[") and part.endswith("]"):
497
+ return True
498
+ else:
499
+ return False
500
+
501
+
502
+ class DialogTokenizer(EmiliaTokenizer):
503
+ def __init__(self, token_file: Optional[str] = None, token_type="phone"):
504
+ super().__init__(token_file=token_file, token_type=token_type)
505
+ if token_file:
506
+ self.spk_a_id = self.token2id["[S1]"]
507
+ self.spk_b_id = self.token2id["[S2]"]
508
+
509
+ def preprocess_text(
510
+ self,
511
+ text: str,
512
+ ) -> str:
513
+ text = re.sub(r"\s*(\[S[12]\])\s*", r"\1", text)
514
+ text = self.map_punctuations(text)
515
+ return text
516
+
517
+
518
+ class LibriTTSTokenizer(Tokenizer):
519
+ def __init__(self, token_file: Optional[str] = None, token_type="char"):
520
+ """
521
+ Args:
522
+ token_type: the type of tokenizer, e.g., bpe, char, phone.
523
+ token_file: the file that contains information that maps tokens to ids,
524
+ which is a text file with '{token}\t{token_id}' per line if type is
525
+ char or phone, otherwise it is a bpe_model file.
526
+ """
527
+ self.type = token_type
528
+ assert token_type in ["bpe", "char", "phone"]
529
+ try:
530
+ import tacotron_cleaner.cleaners
531
+ except Exception as ex:
532
+ raise RuntimeError(f"{ex}\nPlease run\n" "pip install espnet_tts_frontend")
533
+
534
+ self.normalize = tacotron_cleaner.cleaners.custom_english_cleaners
535
+
536
+ self.has_tokens = False
537
+ if token_file is None:
538
+ logging.debug(
539
+ "Initialize Tokenizer without tokens file, \
540
+ will fail when map to ids."
541
+ )
542
+ return
543
+ if token_type == "bpe":
544
+ import sentencepiece as spm
545
+
546
+ self.sp = spm.SentencePieceProcessor()
547
+ self.sp.load(token_file)
548
+ self.pad_id = self.sp.piece_to_id("<pad>")
549
+ self.vocab_size = self.sp.get_piece_size()
550
+ else:
551
+ self.token2id: Dict[str, int] = {}
552
+ with open(token_file, "r", encoding="utf-8") as f:
553
+ for line in f.readlines():
554
+ info = line.rstrip().split("\t")
555
+ token, id = info[0], int(info[1])
556
+ assert token not in self.token2id, token
557
+ self.token2id[token] = id
558
+ self.pad_id = self.token2id["_"] # padding
559
+ self.vocab_size = len(self.token2id)
560
+ self.has_tokens = True
561
+
562
+ def texts_to_token_ids(
563
+ self,
564
+ texts: List[str],
565
+ ) -> List[List[int]]:
566
+ if self.type == "bpe":
567
+ for i in range(len(texts)):
568
+ texts[i] = self.normalize(texts[i])
569
+ return self.sp.encode(texts)
570
+ else:
571
+ return self.tokens_to_token_ids(self.texts_to_tokens(texts))
572
+
573
+ def texts_to_tokens(
574
+ self,
575
+ texts: List[str],
576
+ ) -> List[List[str]]:
577
+ for i in range(len(texts)):
578
+ texts[i] = self.normalize(texts[i])
579
+
580
+ if self.type == "char":
581
+ tokens_list = [list(texts[i]) for i in range(len(texts))]
582
+ elif self.type == "phone":
583
+ tokens_list = [
584
+ phonemize_espeak(texts[i].lower(), "en-us") for i in range(len(texts))
585
+ ]
586
+ elif self.type == "bpe":
587
+ tokens_list = self.sp.encode(texts, out_type=str)
588
+
589
+ return tokens_list
590
+
591
+ def tokens_to_token_ids(
592
+ self,
593
+ tokens_list: List[List[str]],
594
+ ) -> List[List[int]]:
595
+ assert self.has_tokens, "Please initialize Tokenizer with a tokens file."
596
+
597
+ assert self.type != "bpe", "BPE tokenizer does not support this function."
598
+
599
+ token_ids_list = []
600
+
601
+ for tokens in tokens_list:
602
+ token_ids = []
603
+ for t in tokens:
604
+ if t not in self.token2id:
605
+ logging.debug(f"Skip OOV {t}")
606
+ continue
607
+ token_ids.append(self.token2id[t])
608
+
609
+ token_ids_list.append(token_ids)
610
+
611
+ return token_ids_list
612
+
613
+
614
+ def add_tokens(cut_set: CutSet, tokenizer: str, lang: str):
615
+ if tokenizer == "emilia":
616
+ tokenizer = EmiliaTokenizer()
617
+ elif tokenizer == "espeak":
618
+ tokenizer = EspeakTokenizer(lang=lang)
619
+ elif tokenizer == "dialog":
620
+ tokenizer = DialogTokenizer()
621
+ elif tokenizer == "libritts":
622
+ tokenizer = LibriTTSTokenizer()
623
+ elif tokenizer == "simple":
624
+ tokenizer = SimpleTokenizer()
625
+ else:
626
+ raise ValueError(f"Unsupported tokenizer: {tokenizer}.")
627
+
628
+ def _prepare_cut(cut):
629
+ # Each cut only contains one supervision
630
+ assert len(cut.supervisions) == 1, (len(cut.supervisions), cut)
631
+ text = cut.supervisions[0].text
632
+ tokens = tokenizer.texts_to_tokens([text])[0]
633
+ cut.supervisions[0].tokens = tokens
634
+ return cut
635
+
636
+ cut_set = cut_set.map(_prepare_cut)
637
+ return cut_set
638
+
639
+
640
+ if __name__ == "__main__":
641
+ text = (
642
+ "我们是5年小米人,是吗? Yes I think so! "
643
+ "mr king, 5 years, from 2019 to 2024."
644
+ "霍...啦啦啦超过90%的人<le5>...?!9204"
645
+ )
646
+ tokenizer = EmiliaTokenizer()
647
+ tokens = tokenizer.texts_to_tokens([text])
648
+ print(f"tokens: {'|'.join(tokens[0])}")
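
The dialogue tokenizer keeps the speaker-turn tags `[S1]`/`[S2]` as single tokens and otherwise falls back on the bilingual phone pipeline. Here is a small usage sketch: no token file is passed, so only token strings (not ids) are produced, and `piper_phonemize`, `jieba`, and `pypinyin` must be installed for the import to succeed.

```python
# Usage sketch for the dialogue tokenizer; requires piper_phonemize, jieba and pypinyin.
from zipvoice.tokenizer.tokenizer import DialogTokenizer

tokenizer = DialogTokenizer()  # no token file: tokens_to_token_ids would fail, token strings are fine

text = "[S1] Nice to meet you. [S2] 很高兴认识你!"
tokens = tokenizer.texts_to_tokens([text])[0]

# The speaker tags survive as single "[S1]" / "[S2]" tokens, English becomes espeak
# phones, and Chinese is split into pinyin initials (suffixed with "0") and tone3 finals.
print("|".join(tokens))
```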
zipvoice/utils/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # ZipVoice utils package
zipvoice/utils/checkpoint.py ADDED
@@ -0,0 +1,570 @@
1
+ # Copyright 2021-2025 Xiaomi Corporation (authors: Fangjun Kuang,
2
+ # Zengwei Yao)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import glob
19
+ import logging
20
+ import os
21
+ import re
22
+ from pathlib import Path
23
+ from typing import Any, Dict, List, Optional, Union
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+ from lhotse.dataset.sampling.base import CutSampler
28
+ from torch.nn.parallel import DistributedDataParallel as DDP
29
+ from torch.optim import Optimizer
30
+
31
+ from zipvoice.utils.common import AttributeDict, GradScaler
32
+
33
+ # use duck typing for LRScheduler since we have different possibilities, see
34
+ # our class LRScheduler.
35
+ LRSchedulerType = object
36
+
37
+
38
+ def save_checkpoint(
39
+ filename: Path,
40
+ model: Union[nn.Module, DDP],
41
+ model_avg: Optional[nn.Module] = None,
42
+ model_ema: Optional[nn.Module] = None,
43
+ params: Optional[Dict[str, Any]] = None,
44
+ optimizer: Optional[Optimizer] = None,
45
+ scheduler: Optional[LRSchedulerType] = None,
46
+ scaler: Optional[GradScaler] = None,
47
+ sampler: Optional[CutSampler] = None,
48
+ rank: int = 0,
49
+ ) -> None:
50
+ """Save training information to a file.
51
+
52
+ Args:
53
+ filename:
54
+ The checkpoint filename.
55
+ model:
56
+ The model to be saved. We only save its `state_dict()`.
57
+ model_avg:
58
+ The stored model averaged from the start of training.
59
+ model_ema:
60
+ The EMA version of model.
61
+ params:
62
+ User defined parameters, e.g., epoch, loss.
63
+ optimizer:
64
+ The optimizer to be saved. We only save its `state_dict()`.
65
+ scheduler:
66
+ The scheduler to be saved. We only save its `state_dict()`.
67
+ scalar:
68
+ The GradScaler to be saved. We only save its `state_dict()`.
69
+ sampler:
70
+ The sampler used in the labeled training dataset. We only
71
+ save its `state_dict()`.
72
+ rank:
73
+ Used in DDP. We save checkpoint only for the node whose
74
+ rank is 0.
75
+ Returns:
76
+ Return None.
77
+ """
78
+ if rank != 0:
79
+ return
80
+
81
+ logging.info(f"Saving checkpoint to {filename}")
82
+
83
+ if isinstance(model, DDP):
84
+ model = model.module
85
+
86
+ checkpoint = {
87
+ "model": model.state_dict(),
88
+ "optimizer": optimizer.state_dict() if optimizer is not None else None,
89
+ "scheduler": scheduler.state_dict() if scheduler is not None else None,
90
+ "grad_scaler": scaler.state_dict() if scaler is not None else None,
91
+ "sampler": sampler.state_dict() if sampler is not None else None,
92
+ }
93
+
94
+ if model_avg is not None:
95
+ checkpoint["model_avg"] = model_avg.to(torch.float32).state_dict()
96
+ if model_ema is not None:
97
+ checkpoint["model_ema"] = model_ema.to(torch.float32).state_dict()
98
+
99
+ if params:
100
+ for k, v in params.items():
101
+ assert k not in checkpoint
102
+ checkpoint[k] = v
103
+
104
+ torch.save(checkpoint, filename)
105
+
106
+
107
+ def load_checkpoint(
108
+ filename: Path,
109
+ model: Optional[nn.Module] = None,
110
+ model_avg: Optional[nn.Module] = None,
111
+ model_ema: Optional[nn.Module] = None,
112
+ strict: bool = False,
113
+ ) -> Dict[str, Any]:
114
+ logging.info(f"Loading checkpoint from {filename}")
115
+ checkpoint = torch.load(filename, map_location="cpu", weights_only=False)
116
+
117
+ if model is not None:
118
+
119
+ if next(iter(checkpoint["model"])).startswith("module."):
120
+ logging.debug("Loading checkpoint saved by DDP")
121
+ dst_state_dict = model.state_dict()
122
+ src_state_dict = checkpoint["model"]
123
+ for key in dst_state_dict.keys():
124
+ src_key = "{}.{}".format("module", key)
125
+ dst_state_dict[key] = src_state_dict.pop(src_key)
126
+ assert len(src_state_dict) == 0
127
+ model.load_state_dict(dst_state_dict, strict=strict)
128
+ else:
129
+ logging.debug("Loading checkpoint")
130
+ model.load_state_dict(checkpoint["model"], strict=strict)
131
+
132
+ checkpoint.pop("model")
133
+
134
+ if model_avg is not None and "model_avg" in checkpoint:
135
+ logging.info("Loading averaged model")
136
+ model_avg.load_state_dict(checkpoint["model_avg"], strict=strict)
137
+ checkpoint.pop("model_avg")
138
+
139
+ if model_ema is not None and "model_ema" in checkpoint:
140
+ logging.info("Loading ema model")
141
+ model_ema.load_state_dict(checkpoint["model_ema"], strict=strict)
142
+ checkpoint.pop("model_ema")
143
+
144
+ return checkpoint
145
+
146
+
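
As a reference for how `save_checkpoint` and `load_checkpoint` pair up, here is a hedged round-trip sketch. The toy `nn.Linear` stands in for the real model, and the filename and stored params are placeholders, not values used by the training scripts.

```python
# Round-trip sketch for save_checkpoint / load_checkpoint with a toy model.
from pathlib import Path

import torch.nn as nn

from zipvoice.utils.checkpoint import load_checkpoint, save_checkpoint

model = nn.Linear(4, 4)          # stand-in for the real ZipVoice model
filename = Path("checkpoint-1.pt")

save_checkpoint(
    filename=filename,
    model=model,
    params={"epoch": 1, "best_valid_loss": 0.123},
    rank=0,                      # only rank 0 writes in DDP runs
)

# Later (or in another process): restore the weights and read back the params.
new_model = nn.Linear(4, 4)
checkpoint = load_checkpoint(filename, model=new_model)
print(checkpoint["epoch"], checkpoint["best_valid_loss"])
```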
147
+ def load_checkpoint_extend_vocab_size(
148
+ filename: Path, extend_size: int, model: nn.Module, strict: bool = True
149
+ ) -> None:
150
+ logging.info(f"Loading checkpoint from {filename}")
151
+ checkpoint = torch.load(filename, map_location="cpu", weights_only=False)
152
+
153
+ if model is not None:
154
+ if next(iter(checkpoint["model"])).startswith("module."):
155
+ logging.info("Loading checkpoint saved by DDP")
156
+ dst_state_dict = model.state_dict()
157
+ src_state_dict = checkpoint["model"]
158
+ for key in dst_state_dict.keys():
159
+ src_key = "{}.{}".format("module", key)
160
+ dst_state_dict[key] = src_state_dict.pop(src_key)
161
+ assert len(src_state_dict) == 0
162
+ else:
163
+ logging.info("Loading checkpoint")
164
+ dst_state_dict = checkpoint["model"]
165
+ dst_state_dict["spk_embed.weight"] = model.state_dict()["spk_embed.weight"]
166
+ embed_weight = model.state_dict()["embed.weight"]
167
+ embed_weight[:-extend_size, :] = dst_state_dict["embed.weight"]
168
+ dst_state_dict["embed.weight"] = embed_weight
169
+
170
+ model.load_state_dict(dst_state_dict, strict=strict)
171
+
172
+
173
+ def load_checkpoint_copy_proj_three_channel_alter(
174
+ filename: Path,
175
+ in_proj_key: str,
176
+ out_proj_key: str,
177
+ dim: int,
178
+ model: nn.Module,
179
+ ) -> Dict[str, Any]:
180
+ logging.info(f"Loading checkpoint from {filename}")
181
+ checkpoint = torch.load(filename, map_location="cpu", weights_only=False)
182
+
183
+ if model is not None:
184
+ if next(iter(checkpoint["model"])).startswith("module."):
185
+ logging.info("Loading checkpoint saved by DDP")
186
+
187
+ dst_state_dict = dict()
188
+ src_state_dict = checkpoint["model"]
189
+ for key in src_state_dict.keys():
190
+ dst_state_dict[key[len("module.") :]] = src_state_dict.pop(key)
191
+ assert len(src_state_dict) == 0
192
+ else:
193
+ logging.info("Loading checkpoint")
194
+ dst_state_dict = checkpoint["model"]
195
+ keys = list(dst_state_dict.keys())
196
+ for key in keys:
197
+ if in_proj_key in key:
198
+ if "weight" in key:
199
+ weight = dst_state_dict.pop(key)
200
+ dst_state_dict[key.replace("weight", "0.weight")] = torch.cat(
201
+ [
202
+ weight[:, :dim] / 2,
203
+ weight[:, :dim] / 2,
204
+ weight[:, dim : dim * 2],
205
+ weight[:, dim * 2 :] / 2,
206
+ weight[:, dim * 2 :] / 2,
207
+ ],
208
+ dim=-1,
209
+ )
210
+ dst_state_dict[key.replace("weight", "1.weight")] = weight
211
+ if "bias" in key:
212
+ bias = dst_state_dict.pop(key)
213
+ dst_state_dict[key.replace("bias", "0.bias")] = bias
214
+ dst_state_dict[key.replace("bias", "1.bias")] = bias
215
+ if out_proj_key in key:
216
+ if "weight" in key:
217
+ weight = dst_state_dict.pop(key)
218
+ dst_state_dict[key.replace("weight", "0.weight")] = torch.cat(
219
+ [weight, weight], dim=0
220
+ )
221
+ dst_state_dict[key.replace("weight", "1.weight")] = weight
222
+ elif "bias" in key:
223
+ bias = dst_state_dict.pop(key)
224
+ dst_state_dict[key.replace("bias", "0.bias")] = torch.cat(
225
+ [bias, bias], dim=0
226
+ )
227
+ dst_state_dict[key.replace("bias", "1.bias")] = bias
228
+
229
+ model.load_state_dict(dst_state_dict, strict=True)
230
+
231
+
232
+ def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
233
+ """Find all available checkpoints in a directory.
234
+
235
+ The checkpoint filenames have the form: `checkpoint-xxx.pt`
236
+ where xxx is a numerical value.
237
+
238
+ Assume you have the following checkpoints in the folder `foo`:
239
+
240
+ - checkpoint-1.pt
241
+ - checkpoint-20.pt
242
+ - checkpoint-300.pt
243
+ - checkpoint-4000.pt
244
+
245
+ Case 1 (Return all checkpoints)::
246
+
247
+ find_checkpoints(out_dir='foo')
248
+
249
+ Case 2 (Return checkpoints newer than checkpoint-20.pt, i.e.,
250
+ checkpoint-4000.pt, checkpoint-300.pt, and checkpoint-20.pt)
251
+
252
+ find_checkpoints(out_dir='foo', iteration=20)
253
+
254
+ Case 3 (Return checkpoints older than checkpoint-20.pt, i.e.,
255
+ checkpoint-20.pt, checkpoint-1.pt)::
256
+
257
+ find_checkpoints(out_dir='foo', iteration=-20)
258
+
259
+ Args:
260
+ out_dir:
261
+ The directory where to search for checkpoints.
262
+ iteration:
263
+ If it is 0, return all available checkpoints.
264
+ If it is positive, return the checkpoints whose iteration number is
265
+ greater than or equal to `iteration`.
266
+ If it is negative, return the checkpoints whose iteration number is
267
+ less than or equal to `-iteration`.
268
+ Returns:
269
+ Return a list of checkpoint filenames, sorted in descending
270
+ order by the numerical value in the filename.
271
+ """
272
+ checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt"))
273
+ pattern = re.compile(r"checkpoint-([0-9]+).pt")
274
+ iter_checkpoints = []
275
+ for c in checkpoints:
276
+ result = pattern.search(c)
277
+ if not result:
278
+ logging.warn(f"Invalid checkpoint filename {c}")
279
+ continue
280
+
281
+ iter_checkpoints.append((int(result.group(1)), c))
282
+
283
+ # iter_checkpoints is a list of tuples. Each tuple contains
284
+ # two elements: (iteration_number, checkpoint-iteration_number.pt)
285
+
286
+ iter_checkpoints = sorted(iter_checkpoints, reverse=True, key=lambda x: x[0])
287
+ if iteration >= 0:
288
+ ans = [ic[1] for ic in iter_checkpoints if ic[0] >= iteration]
289
+ else:
290
+ ans = [ic[1] for ic in iter_checkpoints if ic[0] <= -iteration]
291
+
292
+ return ans
293
+
294
+
295
+ def average_checkpoints_with_averaged_model(
296
+ filename_start: str,
297
+ filename_end: str,
298
+ device: torch.device = torch.device("cpu"),
299
+ ) -> Dict[str, torch.Tensor]:
300
+ """Average model parameters over the range with given
301
+ start model (excluded) and end model.
302
+
303
+ Let start = batch_idx_train of model-start;
304
+ end = batch_idx_train of model-end;
305
+ interval = end - start.
306
+ Then the average model over range from start (excluded) to end is
307
+ (1) avg = (model_end * end - model_start * start) / interval.
308
+ It can be written as
309
+ (2) avg = model_end * weight_end + model_start * weight_start,
310
+ where weight_end = end / interval,
311
+ weight_start = -start / interval = 1 - weight_end.
312
+ Since the terms `weight_end` and `weight_start` would be large
313
+ if the model has been trained for lots of batches, which would cause
314
+ overflow when multiplying the model parameters.
315
+ To avoid this, we rewrite (2) as:
316
+ (3) avg = (model_end + model_start * (weight_start / weight_end))
317
+ * weight_end
318
+
319
+ The model index could be epoch number or iteration number.
320
+
321
+ Args:
322
+ filename_start:
323
+ Checkpoint filename of the start model. We assume it
324
+ is saved by :func:`save_checkpoint`.
325
+ filename_end:
326
+ Checkpoint filename of the end model. We assume it
327
+ is saved by :func:`save_checkpoint`.
328
+ device:
329
+ Move checkpoints to this device before averaging.
330
+ """
331
+ state_dict_start = torch.load(
332
+ filename_start, map_location=device, weights_only=False
333
+ )
334
+ state_dict_end = torch.load(filename_end, map_location=device, weights_only=False)
335
+
336
+ average_period = state_dict_start["average_period"]
337
+
338
+ batch_idx_train_start = state_dict_start["batch_idx_train"]
339
+ batch_idx_train_start = (batch_idx_train_start // average_period) * average_period
340
+ batch_idx_train_end = state_dict_end["batch_idx_train"]
341
+ batch_idx_train_end = (batch_idx_train_end // average_period) * average_period
342
+ interval = batch_idx_train_end - batch_idx_train_start
343
+ assert interval > 0, interval
344
+ weight_end = batch_idx_train_end / interval
345
+ weight_start = 1 - weight_end
346
+
347
+ model_end = state_dict_end["model_avg"]
348
+ model_start = state_dict_start["model_avg"]
349
+ avg = model_end
350
+
351
+ # scale the weight to avoid overflow
352
+ average_state_dict(
353
+ state_dict_1=avg,
354
+ state_dict_2=model_start,
355
+ weight_1=1.0,
356
+ weight_2=weight_start / weight_end,
357
+ scaling_factor=weight_end,
358
+ )
359
+
360
+ return avg
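# Worked example (illustrative numbers): if the start checkpoint was saved after
# 1000 batches and the end checkpoint after 5000, then interval = 4000,
# weight_end = 5000 / 4000 = 1.25 and weight_start = 1 - 1.25 = -0.25, so
#     avg = 1.25 * model_end - 0.25 * model_start,
# which equals (model_end * 5000 - model_start * 1000) / 4000 from form (1);
# the call above computes it as (model_end + model_start * (-0.2)) * 1.25.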
361
+
362
+
363
+ def remove_checkpoints(
364
+ out_dir: Path,
365
+ topk: int,
366
+ rank: int = 0,
367
+ ):
368
+ """Remove checkpoints from the given directory.
369
+
370
+ We assume that checkpoint filename has the form `checkpoint-xxx.pt`
371
+ where xxx is a number, representing the number of processed batches
372
+ when saving that checkpoint. We sort checkpoints by filename and keep
373
+ only the `topk` checkpoints with the highest `xxx`.
374
+
375
+ Args:
376
+ out_dir:
377
+ The directory containing checkpoints to be removed.
378
+ topk:
379
+ Number of checkpoints to keep.
380
+ rank:
381
+ If using DDP for training, it is the rank of the current node.
382
+ Use 0 if no DDP is used for training.
383
+ """
384
+ assert topk >= 1, topk
385
+ if rank != 0:
386
+ return
387
+ checkpoints = find_checkpoints(out_dir)
388
+
389
+ if len(checkpoints) == 0:
390
+ logging.warn(f"No checkpoints found in {out_dir}")
391
+ return
392
+
393
+ if len(checkpoints) <= topk:
394
+ return
395
+
396
+ to_remove = checkpoints[topk:]
397
+ for c in to_remove:
398
+ os.remove(c)
399
+
400
+
401
+ def resume_checkpoint(
402
+ params: AttributeDict,
403
+ model: nn.Module,
404
+ model_avg: nn.Module,
405
+ model_ema: Optional[nn.Module] = None,
406
+ ) -> Optional[Dict[str, Any]]:
407
+ """Load checkpoint from file.
408
+
409
+ If params.start_epoch is larger than 1, it will load the checkpoint from
410
+ `params.start_epoch - 1`.
411
+
412
+ Apart from loading state dict for `model` and `optimizer` it also updates
413
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
414
+ and `best_valid_loss` in `params`.
415
+
416
+ Args:
417
+ params:
418
+ The return value of :func:`get_params`.
419
+ model:
420
+ The training model.
421
+ Returns:
422
+ Return a dict containing previously saved training info.
423
+ """
424
+ filename = params.exp_dir / f"epoch-{params.start_epoch - 1}.pt"
425
+
426
+ assert filename.is_file(), f"{filename} does not exist!"
427
+
428
+ saved_params = load_checkpoint(
429
+ filename,
430
+ model=model,
431
+ model_avg=model_avg,
432
+ model_ema=model_ema,
433
+ strict=True,
434
+ )
435
+
436
+ if params.start_epoch > 1:
437
+ keys = [
438
+ "best_train_epoch",
439
+ "best_valid_epoch",
440
+ "batch_idx_train",
441
+ "best_train_loss",
442
+ "best_valid_loss",
443
+ ]
444
+ for k in keys:
445
+ params[k] = saved_params[k]
446
+
447
+ return saved_params
448
+
449
+
450
+ def average_state_dict(
451
+ state_dict_1: Dict[str, torch.Tensor],
452
+ state_dict_2: Dict[str, torch.Tensor],
453
+ weight_1: float,
454
+ weight_2: float,
455
+ scaling_factor: float = 1.0,
456
+ ) -> Dict[str, torch.Tensor]:
457
+ """Average two state_dict with given weights:
458
+ state_dict_1 = (state_dict_1 * weight_1 + state_dict_2 * weight_2)
459
+ * scaling_factor
460
+ It is an in-place operation on state_dict_1 itself.
461
+ """
462
+ # Identify shared parameters. Two parameters are said to be shared
463
+ # if they have the same data_ptr
464
+ uniqued: Dict[int, str] = dict()
465
+ for k, v in state_dict_1.items():
466
+ v_data_ptr = v.data_ptr()
467
+ if v_data_ptr in uniqued:
468
+ continue
469
+ uniqued[v_data_ptr] = k
470
+
471
+ uniqued_names = list(uniqued.values())
472
+ for k in uniqued_names:
473
+ v = state_dict_1[k]
474
+ if torch.is_floating_point(v):
475
+ v *= weight_1
476
+ v += state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
477
+ v *= scaling_factor
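# Illustrative sketch (hypothetical, not part of the committed file):
# `average_state_dict` modifies its first argument in place; toy values below.
def _average_state_dict_example() -> None:
    import torch

    sd1 = {"w": torch.tensor([2.0, 4.0])}
    sd2 = {"w": torch.tensor([0.0, 2.0])}
    average_state_dict(sd1, sd2, weight_1=0.5, weight_2=0.5)
    print(sd1["w"])  # tensor([1., 3.])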
478
+
479
+
480
+ def update_averaged_model(
481
+ params: Dict[str, torch.Tensor],
482
+ model_cur: Union[nn.Module, DDP],
483
+ model_avg: nn.Module,
484
+ ) -> None:
485
+ """Update the averaged model:
486
+ model_avg = model_cur * (average_period / batch_idx_train)
487
+ + model_avg * ((batch_idx_train - average_period) / batch_idx_train)
488
+
489
+ Args:
490
+ params:
491
+ User defined parameters, e.g., epoch, loss.
492
+ model_cur:
493
+ The current model.
494
+ model_avg:
495
+ The averaged model to be updated.
496
+ """
497
+ weight_cur = params.average_period / params.batch_idx_train
498
+ weight_avg = 1 - weight_cur
499
+
500
+ if isinstance(model_cur, DDP):
501
+ model_cur = model_cur.module
502
+
503
+ cur = model_cur.state_dict()
504
+ avg = model_avg.state_dict()
505
+
506
+ average_state_dict(
507
+ state_dict_1=avg,
508
+ state_dict_2=cur,
509
+ weight_1=weight_avg,
510
+ weight_2=weight_cur,
511
+ )
512
+
513
+
514
+ def save_checkpoint_with_global_batch_idx(
515
+ out_dir: Path,
516
+ global_batch_idx: int,
517
+ model: Union[nn.Module, DDP],
518
+ model_avg: Optional[nn.Module] = None,
519
+ params: Optional[Dict[str, Any]] = None,
520
+ optimizer: Optional[Optimizer] = None,
521
+ scheduler: Optional[LRSchedulerType] = None,
522
+ scaler: Optional[GradScaler] = None,
523
+ sampler: Optional[CutSampler] = None,
524
+ rank: int = 0,
525
+ ):
526
+ """Save training info after processing given number of batches.
527
+
528
+ Args:
529
+ out_dir:
530
+ The directory to save the checkpoint.
531
+ global_batch_idx:
532
+ The number of batches processed so far from the very start of the
533
+ training. The saved checkpoint will have the following filename:
534
+
535
+ f'out_dir / checkpoint-{global_batch_idx}.pt'
536
+ model:
537
+ The neural network model whose `state_dict` will be saved in the
538
+ checkpoint.
539
+ model_avg:
540
+ The stored model averaged from the start of training.
541
+ params:
542
+ A dict of training configurations to be saved.
543
+ optimizer:
544
+ The optimizer used in the training. Its `state_dict` will be saved.
545
+ scheduler:
546
+ The learning rate scheduler used in the training. Its `state_dict` will
547
+ be saved.
548
+ scaler:
549
+ The scaler used for mix precision training. Its `state_dict` will
550
+ be saved.
551
+ sampler:
552
+ The sampler used in the training dataset.
553
+ rank:
554
+ The rank ID used in DDP training of the current node. Set it to 0
555
+ if DDP is not used.
556
+ """
557
+ out_dir = Path(out_dir)
558
+ out_dir.mkdir(parents=True, exist_ok=True)
559
+ filename = out_dir / f"checkpoint-{global_batch_idx}.pt"
560
+ save_checkpoint(
561
+ filename=filename,
562
+ model=model,
563
+ model_avg=model_avg,
564
+ params=params,
565
+ optimizer=optimizer,
566
+ scheduler=scheduler,
567
+ scaler=scaler,
568
+ sampler=sampler,
569
+ rank=rank,
570
+ )
zipvoice/utils/common.py ADDED
@@ -0,0 +1,670 @@
1
+ import argparse
2
+ import collections
3
+ import json
4
+ import logging
5
+ import os
6
+ import socket
7
+ import subprocess
8
+ import sys
9
+ import warnings
10
+ from collections import defaultdict
11
+ from contextlib import contextmanager
12
+ from datetime import datetime
13
+ from pathlib import Path
14
+ from typing import Any, Dict, List, Tuple, Union
15
+
16
+ import torch
17
+ from packaging import version
18
+ from torch import distributed as dist
19
+ from torch import nn
20
+ from torch.nn.parallel import DistributedDataParallel as DDP
21
+ from torch.utils.tensorboard import SummaryWriter
22
+
23
+
24
+ if hasattr(torch.amp, "GradScaler"):
25
+ from torch.amp import GradScaler
26
+ else:
27
+ from torch.cuda.amp import GradScaler
28
+
29
+ Pathlike = Union[str, Path]
30
+
31
+
32
+ class AttributeDict(dict):
33
+ def __getattr__(self, key):
34
+ if key in self:
35
+ return self[key]
36
+ raise AttributeError(f"No such attribute '{key}'")
37
+
38
+ def __setattr__(self, key, value):
39
+ self[key] = value
40
+
41
+ def __delattr__(self, key):
42
+ if key in self:
43
+ del self[key]
44
+ return
45
+ raise AttributeError(f"No such attribute '{key}'")
46
+
47
+ def __str__(self, indent: int = 2):
48
+ tmp = {}
49
+ for k, v in self.items():
50
+ # PosixPath is not JSON serializable
51
+ if isinstance(v, (Path, torch.device, torch.dtype)):
52
+ v = str(v)
53
+ tmp[k] = v
54
+ return json.dumps(tmp, indent=indent, sort_keys=True)
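# Illustrative sketch (hypothetical, not part of the committed file): AttributeDict
# is an ordinary dict that also supports attribute-style access; keys are made up.
def _attribute_dict_example() -> None:
    params = AttributeDict({"lr": 0.01, "exp_dir": "exp"})
    params.num_epochs = 10            # equivalent to params["num_epochs"] = 10
    print(params.lr, params["num_epochs"])
    print(params)                     # JSON-formatted with sorted keys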
55
+
56
+
57
+ class MetricsTracker(collections.defaultdict):
58
+ def __init__(self):
59
+ # Passing the type 'int' to the base-class constructor
60
+ # makes undefined items default to int() which is zero.
61
+ # This class will play a role as metrics tracker.
62
+ # It can record many metrics, including but not limited to loss.
63
+ super(MetricsTracker, self).__init__(int)
64
+
65
+ def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
66
+ ans = MetricsTracker()
67
+ for k, v in self.items():
68
+ ans[k] = v
69
+ for k, v in other.items():
70
+ if v - v == 0:  # skip values that are inf or nan
71
+ ans[k] = ans[k] + v
72
+ return ans
73
+
74
+ def __mul__(self, alpha: float) -> "MetricsTracker":
75
+ ans = MetricsTracker()
76
+ for k, v in self.items():
77
+ ans[k] = v * alpha
78
+ return ans
79
+
80
+ def __str__(self) -> str:
81
+ ans_frames = ""
82
+ ans_utterances = ""
83
+ for k, v in self.norm_items():
84
+ norm_value = "%.4g" % v
85
+ if "utt_" not in k:
86
+ ans_frames += str(k) + "=" + str(norm_value) + ", "
87
+ else:
88
+ ans_utterances += str(k) + "=" + str(norm_value)
89
+ if k == "utt_duration":
90
+ ans_utterances += " frames, "
91
+ elif k == "utt_pad_proportion":
92
+ ans_utterances += ", "
93
+ else:
94
+ raise ValueError(f"Unexpected key: {k}")
95
+ frames = "%.2f" % self["frames"]
96
+ ans_frames += "over " + str(frames) + " frames. "
97
+ if ans_utterances != "":
98
+ utterances = "%.2f" % self["utterances"]
99
+ ans_utterances += "over " + str(utterances) + " utterances."
100
+
101
+ return ans_frames + ans_utterances
102
+
103
+ def norm_items(self) -> List[Tuple[str, float]]:
104
+ """
105
+ Returns a list of pairs, like:
106
+ [('ctc_loss', 0.1), ('att_loss', 0.07)]
107
+ """
108
+ num_frames = self["frames"] if "frames" in self else 1
109
+ num_utterances = self["utterances"] if "utterances" in self else 1
110
+ ans = []
111
+ for k, v in self.items():
112
+ if k == "frames" or k == "utterances":
113
+ continue
114
+ norm_value = (
115
+ float(v) / num_frames if "utt_" not in k else float(v) / num_utterances
116
+ )
117
+ ans.append((k, norm_value))
118
+ return ans
119
+
120
+ def reduce(self, device):
121
+ """
122
+ Reduce using torch.distributed, which I believe ensures that
123
+ all processes get the total.
124
+ """
125
+ keys = sorted(self.keys())
126
+ s = torch.tensor([float(self[k]) for k in keys], device=device)
127
+ dist.all_reduce(s, op=dist.ReduceOp.SUM)
128
+ for k, v in zip(keys, s.cpu().tolist()):
129
+ self[k] = v
130
+
131
+ def write_summary(
132
+ self,
133
+ tb_writer: SummaryWriter,
134
+ prefix: str,
135
+ batch_idx: int,
136
+ ) -> None:
137
+ """Add logging information to a TensorBoard writer.
138
+
139
+ Args:
140
+ tb_writer: a TensorBoard writer
141
+ prefix: a prefix for the name of the loss, e.g. "train/valid_",
142
+ or "train/current_"
143
+ batch_idx: The current batch index, used as the x-axis of the plot.
144
+ """
145
+ for k, v in self.norm_items():
146
+ tb_writer.add_scalar(prefix + k, v, batch_idx)
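# Illustrative sketch (hypothetical, not part of the committed file): MetricsTracker
# sums per-batch statistics; `norm_items` / `__str__` normalize them by "frames".
def _metrics_tracker_example() -> None:
    info = MetricsTracker()
    info["frames"] = 200
    info["loss"] = 50.0
    other = MetricsTracker()
    other["frames"] = 100
    other["loss"] = 30.0
    total = info + other
    print(total.norm_items())  # [('loss', 0.2666...)], i.e. 80 / 300
    print(str(total))          # roughly "loss=0.2667, over 300.00 frames. "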
147
+
148
+
149
+ @contextmanager
150
+ def torch_autocast(device_type="cuda", **kwargs):
151
+ """
152
+ To fix the following warnings:
153
+ FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated.
154
+ Please use `torch.amp.autocast('cuda', args...)` instead.
155
+ with torch.cuda.amp.autocast(enabled=False):
156
+ """
157
+ if version.parse(torch.__version__) >= version.parse("2.3.0"):
158
+ # Use new unified API
159
+ with torch.amp.autocast(device_type=device_type, **kwargs):
160
+ yield
161
+ else:
162
+ # Suppress deprecation warning and use old CUDA-specific autocast
163
+ with warnings.catch_warnings():
164
+ warnings.simplefilter("ignore", category=FutureWarning)
165
+ with torch.cuda.amp.autocast(**kwargs):
166
+ yield
167
+
168
+
169
+ def create_grad_scaler(device="cuda", **kwargs):
170
+ """
171
+ Creates a GradScaler compatible with both torch < 2.3.0 and >= 2.3.0.
172
+ Accepts all kwargs like: enabled, init_scale, growth_factor, etc.
173
+
174
+ FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated.
175
+ Please use `torch.amp.GradScaler('cuda', args...)` instead.
176
+ """
177
+ if version.parse(torch.__version__) >= version.parse("2.3.0"):
178
+ from torch.amp import GradScaler
179
+
180
+ return GradScaler(device=device, **kwargs)
181
+ else:
182
+ with warnings.catch_warnings():
183
+ warnings.simplefilter("ignore", category=FutureWarning)
184
+ return torch.cuda.amp.GradScaler(**kwargs)
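# Illustrative sketch (hypothetical, not part of the committed file): a typical
# mixed-precision training step using the two wrappers above; `model`, `batch`
# and `optimizer` are placeholders supplied by the caller.
def _amp_step_example(model, batch, optimizer) -> None:
    scaler = create_grad_scaler(enabled=True)
    with torch_autocast(enabled=True):
        loss = model(batch).sum()
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()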
185
+
186
+
187
+ def setup_dist(
188
+ rank=None,
189
+ world_size=None,
190
+ master_port=None,
191
+ use_ddp_launch=False,
192
+ master_addr=None,
193
+ ):
194
+ """
195
+ rank and world_size are used only if use_ddp_launch is False.
196
+ """
197
+ if "MASTER_ADDR" not in os.environ:
198
+ os.environ["MASTER_ADDR"] = (
199
+ "localhost" if master_addr is None else str(master_addr)
200
+ )
201
+
202
+ if "MASTER_PORT" not in os.environ:
203
+ os.environ["MASTER_PORT"] = "12354" if master_port is None else str(master_port)
204
+
205
+ if use_ddp_launch is False:
206
+ dist.init_process_group("nccl", rank=rank, world_size=world_size)
207
+ torch.cuda.set_device(rank)
208
+ else:
209
+ dist.init_process_group("nccl")
210
+
211
+
212
+ def cleanup_dist():
213
+ dist.destroy_process_group()
214
+
215
+
216
+ def prepare_input(
217
+ params: AttributeDict,
218
+ batch: dict,
219
+ device: torch.device,
220
+ return_tokens: bool = True,
221
+ return_feature: bool = True,
222
+ return_audio: bool = False,
223
+ ):
224
+ """
225
+ Parse the features and targets of the current batch.
226
+ Args:
227
+ params:
228
+ It is returned by :func:`get_params`.
229
+ batch:
230
+ It is the return value from iterating
231
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
232
+ for the format of the `batch`.
233
+ device:
234
+ The device of Tensor.
235
+ """
236
+ return_list = []
237
+
238
+ if return_tokens:
239
+ return_list += [batch["tokens"]]
240
+
241
+ if return_feature:
242
+ features = batch["features"].to(device)
243
+ features_lens = batch["features_lens"].to(device)
244
+ return_list += [features * params.feat_scale, features_lens]
245
+
246
+ if return_audio:
247
+ return_list += [batch["audio"], batch["audio_lens"]]
248
+
249
+ return return_list
250
+
251
+
252
+ def prepare_avg_tokens_durations(features_lens, tokens_lens):
253
+ tokens_durations = []
254
+ for i in range(len(features_lens)):
255
+ utt_duration = features_lens[i]
256
+ avg_token_duration = utt_duration // tokens_lens[i]
257
+ tokens_durations.append([avg_token_duration] * tokens_lens[i])
258
+ return tokens_durations
259
+
260
+
261
+ def pad_labels(y: List[List[int]], pad_id: int, device: torch.device):
262
+ """
263
+ Pad the transcripts to the same length with zeros.
264
+
265
+ Args:
266
+ y: the transcripts, which is a list of a list
267
+
268
+ Returns:
269
+ Return a Tensor of padded transcripts.
270
+ """
271
+ y = [token_ids + [pad_id] for token_ids in y]
272
+ length = max([len(token_ids) for token_ids in y])
273
+ y = [token_ids + [pad_id] * (length - len(token_ids)) for token_ids in y]
274
+ return torch.tensor(y, dtype=torch.int64, device=device)
275
+
276
+
277
+ def get_tokens_index(durations: List[List[int]], num_frames: int) -> torch.Tensor:
278
+ """
279
+ Gets position in the transcript for each frame, i.e. the position
280
+ in the symbol-sequence to look up.
281
+
282
+ Args:
283
+ durations:
284
+ Duration of each token in transcripts.
285
+ num_frames:
286
+ The maximum frame length of the current batch.
287
+
288
+ Returns:
289
+ Return a Tensor of shape (batch_size, num_frames)
290
+ """
291
+ durations = [x + [num_frames - sum(x)] for x in durations]
292
+ batch_size = len(durations)
293
+ ans = torch.zeros(batch_size, num_frames, dtype=torch.int64)
294
+ for b in range(batch_size):
295
+ this_dur = durations[b]
296
+ cur_frame = 0
297
+ for i, d in enumerate(this_dur):
298
+ ans[b, cur_frame : cur_frame + d] = i
299
+ cur_frame += d
300
+ assert cur_frame == num_frames, (cur_frame, num_frames)
301
+ return ans
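# Worked example (illustrative values): with per-token durations [2, 3] and
# num_frames = 7, the 2 leftover frames go to an implicit trailing index:
#     get_tokens_index([[2, 3]], num_frames=7)
#     -> tensor([[0, 0, 1, 1, 1, 2, 2]])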
302
+
303
+
304
+ def to_int_tuple(s: Union[str, int]):
305
+ if isinstance(s, int):
306
+ return (s,)
307
+ return tuple(map(int, s.split(",")))
308
+
309
+
310
+ def get_adjusted_batch_count(params: AttributeDict) -> float:
311
+ # returns the number of batches we would have used so far if we had used the
312
+ # reference duration. This is for purposes of set_batch_count().
313
+ return (
314
+ params.batch_idx_train
315
+ * (params.max_duration * params.world_size)
316
+ / params.ref_duration
317
+ )
318
+
319
+
320
+ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
321
+ if isinstance(model, DDP):
322
+ # get underlying nn.Module
323
+ model = model.module
324
+ for name, module in model.named_modules():
325
+ if hasattr(module, "batch_count"):
326
+ module.batch_count = batch_count
327
+ if hasattr(module, "name"):
328
+ module.name = name
329
+
330
+
331
+ def condition_time_mask(
332
+ features_lens: torch.Tensor,
333
+ mask_percent: Tuple[float, float],
334
+ max_len: int = 0,
335
+ ) -> torch.Tensor:
336
+ """
337
+ Apply Time masking.
338
+ Args:
339
+ features_lens:
340
+ input tensor of shape ``(B)``
341
+ mask_percent:
342
+ the (min, max) range of the fraction of the utterance to mask.
343
+ max_len:
344
+ the maximum length of the mask.
345
+ Returns:
346
+ Return a 2-D bool tensor (B, T), where masked positions
347
+ are filled with `True` and non-masked positions are
348
+ filled with `False`.
349
+ """
350
+ mask_size = (
351
+ torch.zeros_like(features_lens, dtype=torch.float32).uniform_(*mask_percent)
352
+ * features_lens
353
+ ).to(torch.int64)
354
+ mask_starts = (
355
+ torch.rand_like(mask_size, dtype=torch.float32) * (features_lens - mask_size)
356
+ ).to(torch.int64)
357
+ mask_ends = mask_starts + mask_size
358
+ max_len = max(max_len, features_lens.max())
359
+ seq_range = torch.arange(0, max_len, device=features_lens.device)
360
+ mask = (seq_range[None, :] >= mask_starts[:, None]) & (
361
+ seq_range[None, :] < mask_ends[:, None]
362
+ )
363
+ return mask
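# Illustrative sketch (hypothetical, not part of the committed file): mask a random
# 30%-70% span of each utterance; the lengths below are made-up toy values.
def _condition_time_mask_example() -> None:
    import torch

    features_lens = torch.tensor([80, 100])
    mask = condition_time_mask(features_lens, mask_percent=(0.3, 0.7))
    print(mask.shape)     # torch.Size([2, 100])
    print(mask[0].sum())  # roughly 24..56 masked frames for the first utterance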
364
+
365
+
366
+ def condition_time_mask_suffix(
367
+ features_lens: torch.Tensor,
368
+ mask_percent: Tuple[float, float],
369
+ max_len: int = 0,
370
+ ) -> torch.Tensor:
371
+ """
372
+ Apply Time masking, mask from the end time index.
373
+ Args:
374
+ features_lens:
375
+ input tensor of shape ``(B)``
376
+ mask_percent:
377
+ the (min, max) range of the fraction of the utterance to mask.
378
+ max_len:
379
+ the maximum length of the mask.
380
+ Returns:
381
+ Return a 2-D bool tensor (B, T), where masked positions
382
+ are filled with `True` and non-masked positions are
383
+ filled with `False`.
384
+ """
385
+ mask_size = (
386
+ torch.zeros_like(features_lens, dtype=torch.float32).uniform_(*mask_percent)
387
+ * features_lens
388
+ ).to(torch.int64)
389
+ mask_starts = (
390
+ torch.ones_like(mask_size, dtype=torch.float32) * (features_lens - mask_size)
391
+ ).to(torch.int64)
392
+ mask_ends = mask_starts + mask_size
393
+ max_len = max(max_len, features_lens.max())
394
+ seq_range = torch.arange(0, max_len, device=features_lens.device)
395
+ mask = (seq_range[None, :] >= mask_starts[:, None]) & (
396
+ seq_range[None, :] < mask_ends[:, None]
397
+ )
398
+ return mask
399
+
400
+
401
+ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
402
+ """
403
+ Args:
404
+ lengths:
405
+ A 1-D tensor containing sentence lengths.
406
+ max_len:
407
+ The length of masks.
408
+ Returns:
409
+ Return a 2-D bool tensor, where masked positions
410
+ are filled with `True` and non-masked positions are
411
+ filled with `False`.
412
+
413
+ >>> lengths = torch.tensor([1, 3, 2, 5])
414
+ >>> make_pad_mask(lengths)
415
+ tensor([[False, True, True, True, True],
416
+ [False, False, False, True, True],
417
+ [False, False, True, True, True],
418
+ [False, False, False, False, False]])
419
+ """
420
+ assert lengths.ndim == 1, lengths.ndim
421
+ max_len = max(max_len, lengths.max())
422
+ n = lengths.size(0)
423
+ seq_range = torch.arange(0, max_len, device=lengths.device)
424
+ expanded_lengths = seq_range.unsqueeze(0).expand(n, max_len)
425
+
426
+ return expanded_lengths >= lengths.unsqueeze(-1)
427
+
428
+
429
+ def str2bool(v):
430
+ """Used in argparse.ArgumentParser.add_argument to indicate
431
+ that a type is a bool type and user can enter
432
+
433
+ - yes, true, t, y, 1, to represent True
434
+ - no, false, f, n, 0, to represent False
435
+
436
+ See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse # noqa
437
+ """
438
+ if isinstance(v, bool):
439
+ return v
440
+ if v.lower() in ("yes", "true", "t", "y", "1"):
441
+ return True
442
+ elif v.lower() in ("no", "false", "f", "n", "0"):
443
+ return False
444
+ else:
445
+ raise argparse.ArgumentTypeError("Boolean value expected.")
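# Illustrative sketch (hypothetical, not part of the committed file): str2bool is
# intended as the `type` of an argparse flag; the flag name below is made up.
def _str2bool_example() -> None:
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--use-fp16", type=str2bool, default=True)
    args = parser.parse_args(["--use-fp16", "false"])
    print(args.use_fp16)  # False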
446
+
447
+
448
+ def setup_logger(
449
+ log_filename: Pathlike,
450
+ log_level: str = "info",
451
+ use_console: bool = True,
452
+ ) -> None:
453
+ """Setup log level.
454
+
455
+ Args:
456
+ log_filename:
457
+ The filename to save the log.
458
+ log_level:
459
+ The log level to use, e.g., "debug", "info", "warning", "error",
460
+ "critical"
461
+ use_console:
462
+ True to also print logs to console.
463
+ """
464
+ now = datetime.now()
465
+ date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
466
+ if dist.is_available() and dist.is_initialized():
467
+ world_size = dist.get_world_size()
468
+ rank = dist.get_rank()
469
+ formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s" # noqa
470
+ log_filename = f"{log_filename}-{date_time}-{rank}"
471
+ else:
472
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
473
+ log_filename = f"{log_filename}-{date_time}"
474
+
475
+ os.makedirs(os.path.dirname(log_filename), exist_ok=True)
476
+
477
+ level = logging.ERROR
478
+ if log_level == "debug":
479
+ level = logging.DEBUG
480
+ elif log_level == "info":
481
+ level = logging.INFO
482
+ elif log_level == "warning":
483
+ level = logging.WARNING
484
+ elif log_level == "critical":
485
+ level = logging.CRITICAL
486
+
487
+ logging.basicConfig(
488
+ filename=log_filename,
489
+ format=formatter,
490
+ level=level,
491
+ filemode="w",
492
+ force=True,
493
+ )
494
+ if use_console:
495
+ console = logging.StreamHandler()
496
+ console.setLevel(level)
497
+ console.setFormatter(logging.Formatter(formatter))
498
+ logging.getLogger("").addHandler(console)
499
+
500
+
501
+ def get_git_sha1():
502
+ try:
503
+ git_commit = (
504
+ subprocess.run(
505
+ ["git", "rev-parse", "--short", "HEAD"],
506
+ check=True,
507
+ stdout=subprocess.PIPE,
508
+ )
509
+ .stdout.decode()
510
+ .rstrip("\n")
511
+ .strip()
512
+ )
513
+ dirty_commit = (
514
+ len(
515
+ subprocess.run(
516
+ ["git", "diff", "--shortstat"],
517
+ check=True,
518
+ stdout=subprocess.PIPE,
519
+ )
520
+ .stdout.decode()
521
+ .rstrip("\n")
522
+ .strip()
523
+ )
524
+ > 0
525
+ )
526
+ git_commit = git_commit + "-dirty" if dirty_commit else git_commit + "-clean"
527
+ except: # noqa
528
+ return None
529
+
530
+ return git_commit
531
+
532
+
533
+ def get_git_date():
534
+ try:
535
+ git_date = (
536
+ subprocess.run(
537
+ ["git", "log", "-1", "--format=%ad", "--date=local"],
538
+ check=True,
539
+ stdout=subprocess.PIPE,
540
+ )
541
+ .stdout.decode()
542
+ .rstrip("\n")
543
+ .strip()
544
+ )
545
+ except: # noqa
546
+ return None
547
+
548
+ return git_date
549
+
550
+
551
+ def get_git_branch_name():
552
+ try:
553
+ git_branch = (
554
+ subprocess.run(
555
+ ["git", "rev-parse", "--abbrev-ref", "HEAD"],
556
+ check=True,
557
+ stdout=subprocess.PIPE,
558
+ )
559
+ .stdout.decode()
560
+ .rstrip("\n")
561
+ .strip()
562
+ )
563
+ except: # noqa
564
+ return None
565
+
566
+ return git_branch
567
+
568
+
569
+ def get_env_info() -> Dict[str, Any]:
570
+ """Get the environment information."""
571
+ return {
572
+ "torch-version": str(torch.__version__),
573
+ "torch-cuda-available": torch.cuda.is_available(),
574
+ "torch-cuda-version": torch.version.cuda,
575
+ "python-version": sys.version[:4],
576
+ "zipvoice-git-branch": get_git_branch_name(),
577
+ "zipvoice-git-sha1": get_git_sha1(),
578
+ "zipvoice-git-date": get_git_date(),
579
+ "zipvoice-path": str(Path(__file__).resolve().parent.parent),
580
+ "hostname": socket.gethostname(),
581
+ "IP address": socket.gethostbyname(socket.gethostname()),
582
+ }
583
+
584
+
585
+ def get_parameter_groups_with_lrs(
586
+ model: nn.Module,
587
+ lr: float,
588
+ include_names: bool = False,
589
+ freeze_modules: List[str] = [],
590
+ unfreeze_modules: List[str] = [],
591
+ ) -> List[dict]:
592
+ """
593
+ This is for use with the ScaledAdam optimizers (more recent versions that accept
594
+ lists of named-parameters; we can, if needed, create a version without the names).
595
+
596
+ It provides a way to specify learning-rate scales inside the module, so that if
597
+ any nn.Module in the hierarchy has a floating-point parameter 'lr_scale', it will
598
+ scale the LR of any parameters inside that module or its submodules. Note: you
599
+ can set module parameters outside the __init__ function, e.g.:
600
+ >>> a = nn.Linear(10, 10)
601
+ >>> a.lr_scale = 0.5
602
+
603
+ Returns: a list of dicts, of the following form:
604
+ if include_names == False:
605
+ [ { 'params': [ tensor1, tensor2, ... ], 'lr': 0.01 },
606
+ { 'params': [ tensor3, tensor4, ... ], 'lr': 0.005 },
607
+ ... ]
608
+ if include_names == true:
609
+ [ { 'named_params': [ (name1, tensor1, (name2, tensor2), ... ], 'lr': 0.01 },
610
+ { 'named_params': [ (name3, tensor3), (name4, tensor4), ... ], 'lr': 0.005 },
611
+ ... ]
612
+
613
+ """
614
+ # Use freeze_modules or unfreeze_modules to freeze or unfreeze modules
615
+ assert not (len(freeze_modules) and len(unfreeze_modules))
616
+
617
+ # flat_lr_scale just contains the lr_scale explicitly specified
618
+ # for each prefix of the name, e.g. 'encoder.layers.3', these need
619
+ # to be multiplied for all prefix of the name of any given parameter.
620
+ flat_lr_scale = defaultdict(lambda: 1.0)
621
+ names = []
622
+ for name, m in model.named_modules():
623
+ names.append(name)
624
+ if hasattr(m, "lr_scale"):
625
+ flat_lr_scale[name] = m.lr_scale
626
+
627
+ # lr_to_params is a dict from learning rate (floating point) to: if
628
+ # include_names == True, a list of (name, parameter) for that learning rate;
629
+ # otherwise a list of parameters for that learning rate.
630
+ lr_to_params = defaultdict(list)
631
+
632
+ for name, parameter in model.named_parameters():
633
+ if not parameter.requires_grad:
634
+ logging.info(f"Remove {name} from parameter")
635
+ continue
636
+ split_name = name.split(".")
637
+ # caution: as a special case, if the name is '', split_name will be [ '' ].
638
+ prefix = split_name[0]
639
+ if len(freeze_modules) > 0:
640
+ if prefix == "module": # DDP
641
+ module_name = split_name[1]
642
+ if module_name in freeze_modules:
643
+ logging.info(f"Remove {name} from parameters")
644
+ continue
645
+ else:
646
+ if prefix in freeze_modules:
647
+ logging.info(f"Remove {name} from parameters")
648
+ continue
649
+ elif len(unfreeze_modules) > 0:
650
+ if prefix == "module": # DDP
651
+ module_name = split_name[1]
652
+ if module_name not in unfreeze_modules:
653
+ logging.info(f"Remove {name} from parameters")
654
+ continue
655
+ else:
656
+ if prefix not in unfreeze_modules:
657
+ logging.info(f"Remove {name} from parameters")
658
+ continue
659
+ cur_lr = lr * flat_lr_scale[prefix]
660
+ if prefix != "":
661
+ cur_lr *= flat_lr_scale[""]
662
+ for part in split_name[1:]:
663
+ prefix = ".".join([prefix, part])
664
+ cur_lr *= flat_lr_scale[prefix]
665
+ lr_to_params[cur_lr].append((name, parameter) if include_names else parameter)
666
+
667
+ if include_names:
668
+ return [{"named_params": pairs, "lr": lr} for lr, pairs in lr_to_params.items()]
669
+ else:
670
+ return [{"params": params, "lr": lr} for lr, params in lr_to_params.items()]
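# Illustrative sketch (hypothetical, not part of the committed file): a per-module
# `lr_scale` attribute scales the learning rate of every parameter under that module.
def _parameter_groups_example() -> None:
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 4))
    model[1].lr_scale = 0.5  # the second layer trains at half the base LR
    groups = get_parameter_groups_with_lrs(model, lr=0.01)
    print(sorted(g["lr"] for g in groups))  # [0.005, 0.01]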
zipvoice/utils/diagnostics.py ADDED
@@ -0,0 +1,723 @@
1
+ # Copyright 2022-2024 Xiaomi Corp. (authors: Daniel Povey
2
+ # Zengwei Yao
3
+ # Mingshuang Luo,
4
+ # Zengrui Jin,)
5
+ #
6
+ # See ../LICENSE for clarification regarding multiple authors
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+
20
+ import logging
21
+ import random
22
+ from dataclasses import dataclass
23
+ from typing import Optional, Tuple
24
+
25
+ import torch
26
+ from torch import Tensor, nn
27
+
28
+
29
+ class TensorDiagnosticOptions(object):
30
+ """Options object for tensor diagnostics:
31
+
32
+ Args:
33
+ max_eig_dim:
34
+ The maximum dimension for which we print out eigenvalues
35
+ (limited for speed reasons).
36
+ """
37
+
38
+ def __init__(self, max_eig_dim: int = 512):
39
+ self.max_eig_dim = max_eig_dim
40
+
41
+ def dim_is_summarized(self, size: int):
42
+ return size > 10 and size != 31
43
+
44
+
45
+ def get_tensor_stats(
46
+ x: Tensor,
47
+ dim: int,
48
+ stats_type: str,
49
+ ) -> Tuple[Tensor, int]:
50
+ """
51
+ Returns the specified transformation of the Tensor (either x or x.abs()
52
+ or (x > 0)), summed over all but the index `dim`.
53
+
54
+ Args:
55
+ x:
56
+ Tensor, tensor to be analyzed
57
+ dim:
58
+ Dimension with 0 <= dim < x.ndim
59
+ stats_type:
60
+ The stats_type includes several types:
61
+ "abs" -> take abs() before summing
62
+ "positive" -> take (x > 0) before summing
63
+ "rms" -> square before summing, we'll take sqrt later
64
+ "value" -> just sum x itself
65
+ "max", "min" -> take the maximum or minimum [over all other dims but dim]
66
+ instead of summing
67
+ "rms-sort" -> this is a bit different than the others, it's based on computing
68
+ the rms over the specified dim and returning percentiles of the result
69
+ (11 of them).
70
+ Returns:
71
+ stats: a Tensor of shape (x.shape[dim],).
72
+ count: an integer saying how many items were counted in each element
73
+ of stats.
74
+ """
75
+
76
+ if stats_type == "rms-sort":
77
+ rms = (x**2).mean(dim=dim).sqrt()
78
+ rms = rms.flatten()
79
+ rms = rms.sort()[0]
80
+ rms = rms[(torch.arange(11) * rms.numel() // 10).clamp(max=rms.numel() - 1)]
81
+ count = 1.0
82
+ return rms, count
83
+
84
+ count = x.numel() // x.shape[dim]
85
+
86
+ if stats_type == "eigs":
87
+ x = x.transpose(dim, -1)
88
+ x = x.reshape(-1, x.shape[-1])
89
+ # shape of returned tensor: (s, s),
90
+ # where s is size of dimension `dim` of original x.
91
+ return torch.matmul(x.transpose(0, 1), x), count
92
+ elif stats_type == "abs":
93
+ x = x.abs()
94
+ elif stats_type == "rms":
95
+ x = x**2
96
+ elif stats_type == "positive":
97
+ x = (x > 0).to(dtype=torch.float)
98
+ else:
99
+ assert stats_type in ["value", "max", "min"]
100
+
101
+ sum_dims = [d for d in range(x.ndim) if d != dim]
102
+ if len(sum_dims) > 0:
103
+ if stats_type == "max":
104
+ for dim in reversed(sum_dims):
105
+ x = torch.max(x, dim=dim)[0]
106
+ elif stats_type == "min":
107
+ for dim in reversed(sum_dims):
108
+ x = torch.min(x, dim=dim)[0]
109
+ else:
110
+ x = torch.sum(x, dim=sum_dims)
111
+ x = x.flatten().clone()
112
+ return x, count
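# Worked example (illustrative values): summing |x| over every dim except dim=1
# of a (2, 3) tensor:
#     x = torch.tensor([[1., -2., 3.], [-4., 5., -6.]])
#     get_tensor_stats(x, dim=1, stats_type="abs")
#     -> (tensor([5., 7., 9.]), 2)   # count = x.numel() // x.shape[1] = 2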
113
+
114
+
115
+ @dataclass
116
+ class TensorAndCount:
117
+ tensor: Tensor
118
+ count: int
119
+
120
+
121
+ class TensorDiagnostic(object):
122
+ """This class is not directly used by the user, it is responsible for
123
+ collecting diagnostics for a module or parameter tensor of a torch.nn.Module.
124
+
125
+ Args:
126
+ opts:
127
+ Options object.
128
+ name:
129
+ The name associated with this diagnostics object, will probably be
130
+ {module_name}.X where X is "output" or "grad", or {parameter_name}.
131
+ Y where Y is param_value or param_grad.
132
+ """
133
+
134
+ def __init__(self, opts: TensorDiagnosticOptions, name: str):
135
+ self.opts = opts
136
+ self.name = name
137
+ self.class_name = None # will assign in accumulate()
138
+
139
+ self.stats = None # we'll later assign a list to self.stats.
140
+ # It's a list of dicts, indexed by dim (i.e. by the
141
+ # axis of the tensor). The dicts, in turn, are
142
+ # indexed by `stats-type` which are strings in
143
+ # ["abs", "max", "min", "positive", "value", "rms"].
144
+
145
+ # scalar_stats contains some analysis of the activations and gradients,
146
+ self.scalar_stats = None
147
+
148
+ # the keys into self.stats[dim] are strings, whose values can be
149
+ # "abs", "max", "min" ,"value", "positive", "rms", "value".
150
+ # The values e.g. self.stats[dim]["rms"] are lists of dataclass TensorAndCount,
151
+ # containing a tensor and its associated count (which is the sum of the other
152
+ # dims that we aggregated over, e.g. the number of frames and/or batch elements
153
+ # and/or channels).
154
+ # ... we actually accumulate the Tensors / counts any time we have the same-dim
155
+ # tensor, only adding a new element to the list if there was a different dim.
156
+ # if the string in the key is "eigs", if we detect a length mismatch we put None
157
+ # as the value.
158
+
159
+ def accumulate(self, x, class_name: Optional[str] = None):
160
+ """
161
+ Accumulate tensors.
162
+ """
163
+ if class_name is not None:
164
+ self.class_name = class_name
165
+ if isinstance(x, Tuple):
166
+ x = x[0]
167
+ if not isinstance(x, Tensor):
168
+ return
169
+ if x.numel() == 0: # for empty tensor
170
+ return
171
+ x = x.detach().clone()
172
+ if x.ndim == 0:
173
+ x = x.unsqueeze(0)
174
+ ndim = x.ndim
175
+ if self.stats is None:
176
+ self.stats = [dict() for _ in range(ndim)]
177
+
178
+ for dim in range(ndim):
179
+ this_dim_stats = self.stats[dim]
180
+ if ndim > 1:
181
+ # rms-sort is different from the others, it's based on summing over just
182
+ # this dim, then sorting and returning the percentiles.
183
+ stats_types = [
184
+ "abs",
185
+ "max",
186
+ "min",
187
+ "positive",
188
+ "value",
189
+ "rms",
190
+ "rms-sort",
191
+ ]
192
+ if x.shape[dim] <= self.opts.max_eig_dim:
193
+ stats_types.append("eigs")
194
+ else:
195
+ stats_types = ["value", "abs", "max", "min"]
196
+
197
+ for stats_type in stats_types:
198
+ stats, count = get_tensor_stats(x, dim, stats_type)
199
+ if stats_type not in this_dim_stats:
200
+ this_dim_stats[stats_type] = [] # list of TensorAndCount
201
+
202
+ done = False
203
+ if this_dim_stats[stats_type] is None:
204
+ # we can reach here if we detected for stats_type "eigs" that
205
+ # there was more than one different size for this dim. Then we
206
+ # disable accumulating this stats type, as it uses too much memory.
207
+ continue
208
+ for s in this_dim_stats[stats_type]:
209
+ if s.tensor.shape == stats.shape:
210
+ if stats_type == "max":
211
+ s.tensor = torch.maximum(s.tensor, stats)
212
+
213
+ elif stats_type == "min":
214
+ s.tensor = torch.minimum(s.tensor, stats)
215
+ else:
216
+ assert stats_type != "max"
217
+ s.tensor += stats
218
+ s.count += count
219
+ done = True
220
+ break
221
+ if not done:
222
+ if this_dim_stats[stats_type] != [] and stats_type == "eigs":
223
+ # >1 size encountered on this dim, e.g. it's a batch or time
224
+ # dimension, don't accumulate the "eigs" stats type, it uses too much
225
+ # memory
226
+ this_dim_stats[stats_type] = None
227
+ else:
228
+ this_dim_stats[stats_type].append(TensorAndCount(stats, count))
229
+
230
+ def print_diagnostics(self):
231
+ """Print diagnostics for each dimension of the tensor."""
232
+ if self.stats is None:
233
+ print(f"Warning: the stats of {self.name} is None.")
234
+ return
235
+ for dim, this_dim_stats in enumerate(self.stats):
236
+ if "rms" in this_dim_stats and "value" in this_dim_stats:
237
+ # produce "stddev" stats, which is centered RMS.
238
+ rms_stats_list = this_dim_stats["rms"]
239
+ value_stats_list = this_dim_stats["value"]
240
+ if len(rms_stats_list) == len(value_stats_list):
241
+ stddev_stats_list = []
242
+ for r, v in zip(rms_stats_list, value_stats_list):
243
+ stddev_stats_list.append(
244
+ # r.count and v.count should be the same, but we don't check
245
+ # this.
246
+ TensorAndCount(
247
+ r.tensor - v.tensor * v.tensor / (v.count + 1.0e-20),
248
+ r.count,
249
+ )
250
+ )
251
+ this_dim_stats["stddev"] = stddev_stats_list
252
+
253
+ for stats_type, stats_list in this_dim_stats.items():
254
+ # stats_type could be "rms", "value", "abs", "eigs", "positive", "min"
255
+ # or "max". "stats_list" could be a list of TensorAndCount (one list per
256
+ # distinct tensor shape of the stats), or None
257
+ if stats_list is None:
258
+ assert stats_type == "eigs"
259
+ continue
260
+
261
+ def get_count(count):
262
+ return 1 if stats_type in ["max", "min"] else count
263
+
264
+ if len(stats_list) == 1:
265
+ stats = stats_list[0].tensor / get_count(stats_list[0].count)
266
+ else:
267
+ # a dimension that has variable size in different nnet
268
+ # forwards, e.g. a time dimension in an ASR model.
269
+ stats = torch.cat(
270
+ [x.tensor / get_count(x.count) for x in stats_list], dim=0
271
+ )
272
+
273
+ if stats_type == "eigs":
274
+ try:
275
+ if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
276
+ eigs, _ = torch.linalg.eigh(stats)
277
+ else:
278
+ eigs, _ = torch.symeig(stats)
279
+ stats = eigs.abs().sqrt()
280
+ except: # noqa
281
+ print("Error getting eigenvalues, trying another method.")
282
+ if hasattr(torch, "linalg") and hasattr(torch.linalg, "eig"):
283
+ eigs, _ = torch.linalg.eig(stats)
284
+ eigs = eigs.abs()
285
+ else:
286
+ eigs, _ = torch.eig(stats)
287
+ eigs = eigs.norm(dim=1)
288
+ stats = eigs.sqrt()
289
+ # sqrt so it reflects data magnitude, like stddev- not variance
290
+
291
+ if stats_type in ["rms", "stddev"]:
292
+ # we stored the square; after aggregation we need to take sqrt.
293
+ stats = stats.sqrt()
294
+
295
+ # if `summarize` we print percentiles of the stats; else,
296
+ # we print out individual elements.
297
+ summarize = (len(stats_list) > 1) or self.opts.dim_is_summarized(
298
+ stats.numel()
299
+ )
300
+ if summarize: # usually `summarize` will be true
301
+ # print out percentiles.
302
+ stats = stats.sort()[0]
303
+ num_percentiles = 10
304
+ size = stats.numel()
305
+ percentiles = []
306
+ for i in range(num_percentiles + 1):
307
+ index = (i * (size - 1)) // num_percentiles
308
+ percentiles.append(stats[index].item())
309
+ percentiles = ["%.2g" % x for x in percentiles]
310
+ percentiles = " ".join(percentiles)
311
+ ans = f"percentiles: [{percentiles}]"
312
+ else:
313
+ ans = stats.tolist()
314
+ ans = ["%.2g" % x for x in ans]
315
+ ans = "[" + " ".join(ans) + "]"
316
+ if stats_type in ["value", "rms", "stddev", "eigs"]:
317
+ # This norm is useful because it is strictly less than the largest
318
+ # sqrt(eigenvalue) of the variance, which we print out, and shows,
319
+ # speaking in an approximate way, how much of that largest
320
+ # eigenvalue can be attributed to the mean of the distribution.
321
+ norm = (stats**2).sum().sqrt().item()
322
+ ans += f", norm={norm:.2g}"
323
+ mean = stats.mean().item()
324
+ rms = (stats**2).mean().sqrt().item()
325
+ ans += f", mean={mean:.3g}, rms={rms:.3g}"
326
+
327
+ # OK, "ans" contains the actual stats, e.g.
328
+ # ans = "percentiles: \
329
+ # [0.43 0.46 0.48 0.49 0.49 0.5 0.51 0.52 0.53 0.54 0.59], \
330
+ # mean=0.5, rms=0.5"
331
+
332
+ sizes = [x.tensor.shape[0] for x in stats_list]
333
+ size_str = (
334
+ f"{sizes[0]}" if len(sizes) == 1 else f"{min(sizes)}..{max(sizes)}"
335
+ )
336
+ maybe_class_name = (
337
+ f" type={self.class_name}," if self.class_name is not None else ""
338
+ )
339
+ print(
340
+ f"module={self.name},{maybe_class_name} dim={dim}, size={size_str}, "
341
+ f"{stats_type} {ans}"
342
+ )
343
+
344
+
345
+ class ScalarDiagnostic(object):
346
+ """This class is not directly used by the user, it is responsible for
347
+ collecting diagnostics for a single module (subclass of torch.nn.Module) that
348
+ represents some kind of nonlinearity, e.g. ReLU, sigmoid, etc.
349
+ """
350
+
351
+ def __init__(self, opts: TensorDiagnosticOptions, name: str):
352
+ self.opts = opts
353
+ self.name = name
354
+ self.class_name = None # will assign in accumulate()
355
+ self.is_forward_pass = True
356
+
357
+ self.tick_scale = None
358
+
359
+ self.saved_inputs = []
360
+ self.is_ok = True
361
+
362
+ self.counts = None
363
+ self.sum_grad = None
364
+ self.sum_gradsq = None
365
+ self.sum_abs_grad = None
366
+
367
+ def accumulate_input(self, x: Tensor, class_name: Optional[str] = None):
368
+ """
369
+ Called in forward pass.
370
+ """
371
+ if not self.is_forward_pass:
372
+ # in case we did a forward pass without a backward pass, for some reason.
373
+ self.saved_inputs = []
374
+ self.is_forward_pass = True
375
+
376
+ if class_name is not None:
377
+ self.class_name = class_name
378
+ if not self.is_ok:
379
+ return
380
+
381
+ limit = 10
382
+ if len(self.saved_inputs) > limit:
383
+ print(
384
+ f"ERROR: forward pass called for this module over {limit} times "
385
+ f"with no backward pass. Will not accumulate scalar stats."
386
+ )
387
+ self.is_ok = False
388
+ return
389
+ self.saved_inputs.append(x)
390
+
391
+ def accumulate_output_grad(self, grad: Tensor):
392
+ if not self.is_ok:
393
+ return
394
+ if self.is_forward_pass:
395
+ self.is_forward_pass = False
396
+
397
+ last_shape = (
398
+ "n/a" if len(self.saved_inputs) == 0 else self.saved_inputs[-1].shape
399
+ )
400
+ if len(self.saved_inputs) == 0 or grad.shape != last_shape:
401
+ print(
402
+ f"ERROR: shape mismatch or no forward activation present when backward "
403
+ f"pass called: grad shape ={tuple(grad.shape)}"
404
+ f", num-saved-inputs={len(self.saved_inputs)}"
405
+ f", shape-of-last-saved-input={last_shape}"
406
+ )
407
+ self.is_ok = False
408
+ return
409
+
410
+ x = self.saved_inputs.pop()
411
+ self.process_input_and_grad(x, grad)
412
+
413
+ def process_input_and_grad(self, x: Tensor, grad: Tensor):
414
+ assert x.shape == grad.shape
415
+ x = x.flatten()
416
+ grad = grad.flatten()
417
+
418
+ num_ticks_per_side = 256
419
+
420
+ if self.tick_scale is None:
421
+ x_abs_sorted = x.abs().sort()[0]
422
+ # take the 98th percentile as the largest value we count separately.
423
+ index = int(x.numel() * 0.98)
424
+ self.tick_scale = float(x_abs_sorted[index] / num_ticks_per_side)
425
+
426
+ # integerize from tick * (-num ticks_per_side .. num_ticks_per_side - 1]
427
+ self.counts = torch.zeros(
428
+ 2 * num_ticks_per_side, dtype=torch.long, device=x.device
429
+ )
430
+ self.sum_grad = torch.zeros(
431
+ 2 * num_ticks_per_side, dtype=torch.double, device=x.device
432
+ )
433
+ # sum_gradsq is for getting error bars.
434
+ self.sum_gradsq = torch.zeros(
435
+ 2 * num_ticks_per_side, dtype=torch.double, device=x.device
436
+ )
437
+ self.sum_abs_grad = torch.zeros(
438
+ 2 * num_ticks_per_side, dtype=torch.double, device=x.device
439
+ )
440
+
441
+ # this will round down.
442
+ x = (x / self.tick_scale).to(torch.long)
443
+ x = x.clamp_(min=-num_ticks_per_side, max=num_ticks_per_side - 1)
444
+ x = x + num_ticks_per_side
445
+
446
+ self.counts.index_add_(dim=0, index=x, source=torch.ones_like(x))
447
+ self.sum_grad.index_add_(dim=0, index=x, source=grad.to(torch.double))
448
+ self.sum_gradsq.index_add_(
449
+ dim=0, index=x, source=(grad * grad).to(torch.double)
450
+ )
451
+ self.sum_abs_grad.index_add_(dim=0, index=x, source=grad.abs().to(torch.double))
452
+
453
+ def print_diagnostics(self):
454
+ """Print diagnostics."""
455
+ if self.is_ok is False or self.counts is None:
456
+ print(f"Warning: no stats accumulated for {self.name}, is_ok={self.is_ok}")
457
+ return
458
+
459
+ counts = self.counts.to("cpu")
460
+ sum_grad = self.sum_grad.to(device="cpu", dtype=torch.float32)
461
+ sum_gradsq = self.sum_gradsq.to(device="cpu", dtype=torch.float32)
462
+ sum_abs_grad = self.sum_abs_grad.to(device="cpu", dtype=torch.float32)
463
+
464
+ counts_cumsum = counts.cumsum(dim=0)
465
+ counts_tot = counts_cumsum[-1]
466
+
467
+ # subdivide the distribution up into `num_bins` intervals for analysis, for
468
+ # greater statistical significance. each bin corresponds to multiple of the
469
+ # original 'tick' intervals.
470
+ num_bins = 20
471
+
472
+ # integer division
473
+ counts_per_bin = (counts_tot // num_bins) + 1
474
+ bin_indexes = counts_cumsum // counts_per_bin
475
+ bin_indexes = bin_indexes.clamp(min=0, max=num_bins).to(torch.long)
476
+
477
+ bin_counts = torch.zeros(num_bins, dtype=torch.long)
478
+ bin_counts.index_add_(dim=0, index=bin_indexes, source=counts)
479
+ bin_grad = torch.zeros(num_bins)
480
+ bin_grad.index_add_(dim=0, index=bin_indexes, source=sum_grad)
481
+ bin_gradsq = torch.zeros(num_bins)
482
+ bin_gradsq.index_add_(dim=0, index=bin_indexes, source=sum_gradsq)
483
+ bin_abs_grad = torch.zeros(num_bins)
484
+ bin_abs_grad.index_add_(dim=0, index=bin_indexes, source=sum_abs_grad)
485
+
486
+ bin_boundary_counts = (
487
+ torch.arange(num_bins + 1, dtype=torch.long) * counts_per_bin
488
+ )
489
+ bin_tick_indexes = torch.searchsorted(counts_cumsum, bin_boundary_counts)
490
+ # boundaries are the "x" values between the bins, e.g. corresponding to the
491
+ # locations of percentiles of the distribution.
492
+ num_ticks_per_side = counts.numel() // 2
493
+ bin_boundaries = (bin_tick_indexes - num_ticks_per_side) * self.tick_scale
494
+
495
+ bin_grad = bin_grad / (bin_counts + 1)
496
+ bin_conf_interval = bin_gradsq.sqrt() / (
497
+ bin_counts + 1
498
+ ) # consider this a standard deviation.
499
+ # bin_grad / bin_abs_grad will give us a sense for how important in a practical
500
+ # sense, the gradients are.
501
+ bin_abs_grad = bin_abs_grad / (bin_counts + 1)
502
+
503
+ bin_rel_grad = bin_grad / (bin_abs_grad + 1.0e-20)
504
+ bin_conf = bin_grad / (bin_conf_interval + 1.0e-20)
505
+
506
+ def tensor_to_str(x: Tensor):
507
+ x = ["%.2g" % f for f in x]
508
+ x = "[" + " ".join(x) + "]"
509
+ return x
510
+
511
+ maybe_class_name = (
512
+ f" type={self.class_name}," if self.class_name is not None else ""
513
+ )
514
+
515
+ print(
516
+ f"module={self.name},{maybe_class_name} "
517
+ f"bin-boundaries={tensor_to_str(bin_boundaries)}, "
518
+ f"rel_grad={tensor_to_str(bin_rel_grad)}, "
519
+ f"grad_conf={tensor_to_str(bin_conf)}"
520
+ )
521
+
522
+
523
+ class ModelDiagnostic(object):
524
+ """This class stores diagnostics for all tensors in the torch.nn.Module.
525
+
526
+ Args:
527
+ opts:
528
+ Options object.
529
+ """
530
+
531
+ def __init__(self, opts: Optional[TensorDiagnosticOptions] = None):
532
+ # In this dictionary, the keys are tensor names and the values
533
+ # are corresponding TensorDiagnostic objects.
534
+ if opts is None:
535
+ self.opts = TensorDiagnosticOptions()
536
+ else:
537
+ self.opts = opts
538
+ self.diagnostics = dict()
539
+
540
+ def __getitem__(self, name: str):
541
+ T = ScalarDiagnostic if name[-7:] == ".scalar" else TensorDiagnostic
542
+ if name not in self.diagnostics:
543
+ self.diagnostics[name] = T(self.opts, name)
544
+ return self.diagnostics[name]
545
+
546
+ def print_diagnostics(self):
547
+ """Print diagnostics for each tensor."""
548
+ for k in sorted(self.diagnostics.keys()):
549
+ self.diagnostics[k].print_diagnostics()
550
+
551
+
552
+ def get_class_name(module: nn.Module):
553
+ ans = type(module).__name__
554
+ # we put the below in try blocks in case anyone is using a different version of
555
+ # these modules that might have different member names.
556
+ if ans == "Balancer" or ans == "ActivationBalancer":
557
+ try:
558
+ ans += f"[{float(module.min_positive)},{float(module.max_positive)}," \
559
+ f"{float(module.min_abs)},{float(module.max_abs)}]"
560
+ except:
561
+ pass
562
+ elif ans == "AbsValuePenalizer":
563
+ try:
564
+ ans += f"[{module.limit}]"
565
+ except:
566
+ pass
567
+ return ans
568
+
569
+
570
+ def attach_diagnostics(
571
+ model: nn.Module, opts: Optional[TensorDiagnosticOptions] = None
572
+ ) -> ModelDiagnostic:
573
+ """Attach a ModelDiagnostic object to the model by
574
+ 1) registering forward hook and backward hook on each module, to accumulate
575
+ its output tensors and gradient tensors, respectively;
576
+ 2) registering backward hook on each module parameter, to accumulate its
577
+ values and gradients.
578
+
579
+ Args:
580
+ model:
581
+ the model to be analyzed.
582
+ opts:
583
+ Options object.
584
+
585
+ Returns:
586
+ The ModelDiagnostic object attached to the model.
587
+ """
588
+
589
+ ans = ModelDiagnostic(opts)
590
+ for name, module in model.named_modules():
591
+ if name == "":
592
+ name = "<top-level>"
593
+
594
+ # Setting model_diagnostic=ans and n=name below, instead of trying to
595
+ # capture the variables, ensures that we use the current values.
596
+ # (this matters for `name`, since the variable gets overwritten).
597
+ # These closures don't really capture by value, only by
598
+ # "the final value the variable got in the function" :-(
599
+ def forward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
600
+ if isinstance(_output, tuple) and len(_output) == 1:
601
+ _output = _output[0]
602
+
603
+ if isinstance(_output, Tensor) and _output.dtype in (
604
+ torch.float32,
605
+ torch.float16,
606
+ torch.float64,
607
+ ):
608
+ _model_diagnostic[f"{_name}.output"].accumulate(
609
+ _output, class_name=get_class_name(_module)
610
+ )
611
+ elif isinstance(_output, tuple):
612
+ for i, o in enumerate(_output):
613
+ if isinstance(o, Tensor) and o.dtype in (
614
+ torch.float32,
615
+ torch.float16,
616
+ torch.float64,
617
+ ):
618
+ _model_diagnostic[f"{_name}.output[{i}]"].accumulate(
619
+ o, class_name=get_class_name(_module)
620
+ )
621
+
622
+ def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
623
+ if isinstance(_output, tuple) and len(_output) == 1:
624
+ _output = _output[0]
625
+ if isinstance(_output, Tensor) and _output.dtype in (
626
+ torch.float32,
627
+ torch.float16,
628
+ torch.float64,
629
+ ):
630
+ _model_diagnostic[f"{_name}.grad"].accumulate(
631
+ _output, class_name=get_class_name(_module)
632
+ )
633
+ elif isinstance(_output, tuple):
634
+ for i, o in enumerate(_output):
635
+ if isinstance(o, Tensor) and o.dtype in (
636
+ torch.float32,
637
+ torch.float16,
638
+ torch.float64,
639
+ ):
640
+ _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(
641
+ o, class_name=get_class_name(_module)
642
+ )
643
+
644
+ module.register_forward_hook(forward_hook)
645
+ module.register_backward_hook(backward_hook)
646
+
647
+ if type(module).__name__ in [
648
+ "Sigmoid",
649
+ "Tanh",
650
+ "ReLU",
651
+ "TanSwish",
652
+ "Swish",
653
+ "DoubleSwish",
654
+ "Swoosh",
655
+ ]:
656
+ # For these specific module types, accumulate some additional diagnostics
657
+ # that can help us improve the activation function. These require a lot of
658
+ # memory, to save the forward activations, so limit this to some select
659
+ # classes. Note: this will not work correctly for all model types.
660
+ def scalar_forward_hook(
661
+ _module, _input, _output, _model_diagnostic=ans, _name=name
662
+ ):
663
+ if isinstance(_input, tuple):
664
+ (_input,) = _input
665
+ assert isinstance(_input, Tensor)
666
+ _model_diagnostic[f"{_name}.scalar"].accumulate_input(
667
+ _input, class_name=get_class_name(_module)
668
+ )
669
+
670
+ def scalar_backward_hook(
671
+ _module, _input, _output, _model_diagnostic=ans, _name=name
672
+ ):
673
+ if isinstance(_output, tuple):
674
+ (_output,) = _output
675
+ assert isinstance(_output, Tensor)
676
+ _model_diagnostic[f"{_name}.scalar"].accumulate_output_grad(_output)
677
+
678
+ module.register_forward_hook(scalar_forward_hook)
679
+ module.register_backward_hook(scalar_backward_hook)
680
+
681
+ for name, parameter in model.named_parameters():
682
+
683
+ def param_backward_hook(
684
+ grad, _parameter=parameter, _model_diagnostic=ans, _name=name
685
+ ):
686
+ _model_diagnostic[f"{_name}.param_value"].accumulate(_parameter)
687
+ _model_diagnostic[f"{_name}.param_grad"].accumulate(grad)
688
+
689
+ try:
690
+ parameter.register_hook(param_backward_hook)
691
+ except:
692
+ logging.warning(
693
+ f"Warning: could not register backward hook for parameter {name}, "
694
+ f"it might not be differentiable."
695
+ )
696
+
697
+ return ans
698
+
699
+
700
+ def _test_tensor_diagnostic():
701
+ opts = TensorDiagnosticOptions(512)
702
+
703
+ diagnostic = TensorDiagnostic(opts, "foo")
704
+
705
+ for _ in range(10):
706
+ diagnostic.accumulate(torch.randn(50, 100) * 10.0)
707
+
708
+ diagnostic.print_diagnostics()
709
+
710
+ model = nn.Sequential(nn.Linear(100, 50), nn.ReLU(), nn.Linear(50, 80))
711
+
712
+ diagnostic = attach_diagnostics(model, opts)
713
+ for _ in range(10):
714
+ T = random.randint(200, 300)
715
+ x = torch.randn(T, 100)
716
+ y = model(x)
717
+ y.sum().backward()
718
+
719
+ diagnostic.print_diagnostics()
720
+
721
+
722
+ if __name__ == "__main__":
723
+ _test_tensor_diagnostic()
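A minimal usage sketch of the diagnostics utilities above (assuming they are importable as zipvoice.utils.diagnostics, per this commit's file list): attach the hooks before training, run a few batches, then print the accumulated statistics. Keys such as `<name>.output`, `<name>.grad`, `<name>.param_value` and `<name>.param_grad` are filled in automatically by the registered hooks.

import torch
import torch.nn as nn

from zipvoice.utils.diagnostics import attach_diagnostics  # assumed import path

model = nn.Sequential(nn.Linear(80, 64), nn.ReLU(), nn.Linear(64, 10))
diagnostic = attach_diagnostics(model)  # uses default TensorDiagnosticOptions

for _ in range(5):
    x = torch.randn(32, 80)
    model(x).sum().backward()
    model.zero_grad()

diagnostic.print_diagnostics()  # one line of statistics per accumulated tensor name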
zipvoice/utils/feature.py ADDED
@@ -0,0 +1,120 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2024 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ from dataclasses import dataclass
19
+ from typing import Union
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torchaudio
24
+ from lhotse.features.base import FeatureExtractor, register_extractor
25
+ from lhotse.utils import Seconds, compute_num_frames
26
+
27
+
28
+ @dataclass
29
+ class VocosFbankConfig:
30
+ sampling_rate: int = 24000
31
+ n_mels: int = 100
32
+ n_fft: int = 1024
33
+ hop_length: int = 256
34
+
35
+
36
+ @register_extractor
37
+ class VocosFbank(FeatureExtractor):
38
+
39
+ name = "VocosFbank"
40
+ config_type = VocosFbankConfig
41
+
42
+ def __init__(self, num_channels: int = 1):
43
+ config = VocosFbankConfig()
44
+ super().__init__(config=config)
45
+ assert num_channels in (1, 2)
46
+ self.num_channels = num_channels
47
+ self.fbank = torchaudio.transforms.MelSpectrogram(
48
+ sample_rate=self.config.sampling_rate,
49
+ n_fft=self.config.n_fft,
50
+ hop_length=self.config.hop_length,
51
+ n_mels=self.config.n_mels,
52
+ center=True,
53
+ power=1,
54
+ )
55
+
56
+ def _feature_fn(self, sample):
57
+ mel = self.fbank(sample)
58
+ logmel = mel.clamp(min=1e-7).log()
59
+
60
+ return logmel
61
+
62
+ @property
63
+ def device(self) -> Union[str, torch.device]:
64
+ return self.config.device
65
+
66
+ def feature_dim(self, sampling_rate: int) -> int:
67
+ return self.config.n_mels
68
+
69
+ def extract(
70
+ self,
71
+ samples: Union[np.ndarray, torch.Tensor],
72
+ sampling_rate: int,
73
+ ) -> Union[np.ndarray, torch.Tensor]:
74
+ # Check for sampling rate compatibility.
75
+ expected_sr = self.config.sampling_rate
76
+ assert sampling_rate == expected_sr, (
77
+ f"Mismatched sampling rate: extractor expects {expected_sr}, "
78
+ f"got {sampling_rate}"
79
+ )
80
+ is_numpy = False
81
+ if not isinstance(samples, torch.Tensor):
82
+ samples = torch.from_numpy(samples)
83
+ is_numpy = True
84
+
85
+ if len(samples.shape) == 1:
86
+ samples = samples.unsqueeze(0)
87
+ else:
88
+ assert samples.ndim == 2, samples.shape
89
+
90
+ if self.num_channels == 1:
91
+ if samples.shape[0] == 2:
92
+ samples = samples.mean(dim=0, keepdims=True)
93
+ else:
94
+ assert samples.shape[0] == 2, samples.shape
95
+
96
+ mel = self._feature_fn(samples)
97
+ # (1, n_mels, time) or (2, n_mels, time)
98
+ mel = mel.reshape(-1, mel.shape[-1]).t()
99
+ # (time, n_mels) or (time, 2 * n_mels)
100
+
101
+ num_frames = compute_num_frames(
102
+ samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate
103
+ )
104
+
105
+ if mel.shape[0] > num_frames:
106
+ mel = mel[:num_frames]
107
+ elif mel.shape[0] < num_frames:
108
+ mel = mel.unsqueeze(0)
109
+ mel = torch.nn.functional.pad(
110
+ mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate"
111
+ ).squeeze(0)
112
+
113
+ if is_numpy:
114
+ return mel.cpu().numpy()
115
+ else:
116
+ return mel
117
+
118
+ @property
119
+ def frame_shift(self) -> Seconds:
120
+ return self.config.hop_length / self.config.sampling_rate
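A hedged usage sketch for the extractor above, assuming a mono waveform already at the configured 24 kHz rate (the import path follows this commit's layout):

import torch

from zipvoice.utils.feature import VocosFbank  # assumed import path

extractor = VocosFbank(num_channels=1)
samples = torch.randn(1, 24000)  # one second of fake mono audio at 24 kHz

feats = extractor.extract(samples, sampling_rate=24000)
# hop_length=256 and n_mels=100, so feats has shape (num_frames, 100)
print(feats.shape, extractor.frame_shift)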
zipvoice/utils/hooks.py ADDED
@@ -0,0 +1,111 @@
1
+ # Copyright 2021-2024 Xiaomi Corporation (authors: Zengwei Yao,
2
+ # Daniel Povey,
3
+ # Zengrui Jin,)
4
+ #
5
+ # See ../../LICENSE for clarification regarding multiple authors
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+ import logging
20
+ import random
21
+
22
+ import torch
23
+ from torch import Tensor, nn
24
+
25
+
26
+ def register_inf_check_hooks(model: nn.Module) -> None:
27
+ """Registering forward hook on each module, to check
28
+ whether its output tensors are finite.
29
+
30
+ Args:
31
+ model:
32
+ the model to be analyzed.
33
+ """
34
+
35
+ for name, module in model.named_modules():
36
+ if name == "":
37
+ name = "<top-level>"
38
+
39
+ # default param _name is a way to capture the current value of the variable
40
+ # "name".
41
+ def forward_hook(_module, _input, _output, _name=name):
42
+ if isinstance(_output, Tensor):
43
+ try:
44
+ if not torch.isfinite(_output.to(torch.float32).sum()):
45
+ logging.warning(f"The sum of {_name}.output is not finite")
46
+ except RuntimeError: # e.g. CUDA out of memory
47
+ pass
48
+ elif isinstance(_output, tuple):
49
+ for i, o in enumerate(_output):
50
+ if isinstance(o, tuple):
51
+ o = o[0]
52
+ if not isinstance(o, Tensor):
53
+ continue
54
+ try:
55
+ if not torch.isfinite(o.to(torch.float32).sum()):
56
+ logging.warning(
57
+ f"The sum of {_name}.output[{i}] is not finite"
58
+ )
59
+ except RuntimeError: # e.g. CUDA out of memory
60
+ pass
61
+
62
+ # default param _name is a way to capture the current value of the variable
63
+ # "name".
64
+ def backward_hook(_module, _input, _output, _name=name):
65
+ if isinstance(_output, Tensor):
66
+ try:
67
+ if not torch.isfinite(_output.to(torch.float32).sum()):
68
+ logging.warning(f"The sum of {_name}.grad is not finite")
69
+ except RuntimeError: # e.g. CUDA out of memory
70
+ pass
71
+
72
+ elif isinstance(_output, tuple):
73
+ for i, o in enumerate(_output):
74
+ if isinstance(o, tuple):
75
+ o = o[0]
76
+ if not isinstance(o, Tensor):
77
+ continue
78
+ if not torch.isfinite(o.to(torch.float32).sum()):
79
+ logging.warning(f"The sum of {_name}.grad[{i}] is not finite")
80
+
81
+ module.register_forward_hook(forward_hook)
82
+ module.register_backward_hook(backward_hook)
83
+
84
+ for name, parameter in model.named_parameters():
85
+
86
+ def param_backward_hook(grad, _name=name):
87
+ if not torch.isfinite(grad.to(torch.float32).sum()):
88
+ logging.warning(f"The sum of {_name}.param_grad is not finite")
89
+
90
+ try:
91
+ parameter.register_hook(param_backward_hook)
92
+ except Exception as e:
93
+ logging.warning(
94
+ f"Warning: could not register backward hook for parameter {name}"
95
+ f" with error {e}, it might not be differentiable."
96
+ )
97
+
98
+
99
+ def _test_inf_check_hooks():
100
+ model = nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 80))
101
+
102
+ register_inf_check_hooks(model)
103
+ for _ in range(10):
104
+ T = random.randint(200, 300)
105
+ x = torch.randn(T, 100) + float("inf") * (T % 2)
106
+ y = model(x)
107
+ y.sum().backward()
108
+
109
+
110
+ if __name__ == "__main__":
111
+ _test_inf_check_hooks()
zipvoice/utils/infer.py ADDED
@@ -0,0 +1,414 @@
1
+ from typing import List
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torchaudio
6
+ from pydub import AudioSegment
7
+ from pydub.silence import detect_leading_silence, split_on_silence
8
+
9
+ punctuation = {";", ":", ",", ".", "!", "?", "；", "：", "，", "。", "！", "？"}
10
+
11
+
12
+ def chunk_tokens_punctuation(tokens_list: List[str], max_tokens: int = 100):
13
+ """
14
+ Splits the input tokens list into chunks according to punctuations,
15
+ each with a maximum number of tokens.
16
+
17
+ Args:
18
+ tokens_list (list of str): The list of tokens to be split.
19
+ max_tokens (int): The maximum number of tokens per chunk.
20
+
21
+ Returns:
22
+ List[List[str]]: A list of token chunks, each a list of tokens.
23
+ """
24
+
25
+ # 1. Split the tokens according to punctuations.
26
+ sentences = []
27
+ current_sentence = []
28
+ for token in tokens_list:
29
+ # If the first token of current sentence is punctuation or blank,
30
+ # append it to the end of the previous sentence.
31
+ if (
32
+ len(current_sentence) == 0
33
+ and len(sentences) != 0
34
+ and (token in punctuation or token == " ")
35
+ ):
36
+ sentences[-1].append(token)
37
+ # Otherwise, append the current token to the current sentence.
38
+ else:
39
+ current_sentence.append(token)
40
+ # Split the sentence in positions of punctuations.
41
+ if token in punctuation:
42
+ sentences.append(current_sentence)
43
+ current_sentence = []
44
+ # Assume the last few tokens are also a sentence
45
+ if len(current_sentence) != 0:
46
+ sentences.append(current_sentence)
47
+
48
+ # 2. Merge short sentences.
49
+ chunks = []
50
+ current_chunk = []
51
+ for sentence in sentences:
52
+ if len(current_chunk) + len(sentence) <= max_tokens:
53
+ current_chunk.extend(sentence)
54
+ else:
55
+ if len(current_chunk) > 0:
56
+ chunks.append(current_chunk)
57
+ current_chunk = sentence
58
+
59
+ if len(current_chunk) > 0:
60
+ chunks.append(current_chunk)
61
+
62
+ return chunks
63
+
64
+
65
+ def chunk_tokens_dialog(tokens_list: List[str], max_tokens: int = 100):
66
+ """
67
+ Splits the input tokens list into chunks according to speaker-turn
68
+ symbol [S1], each with a maximum number of tokens.
69
+
70
+ Args:
71
+ tokens_list (list of str): The list of tokens to be split.
72
+ max_tokens (int): The maximum number of tokens per chunk.
73
+
74
+ Returns:
75
+ List[List[str]]: A list of token chunks, each a list of tokens.
76
+ """
77
+
78
+ # 1. Split the tokens according to speaker-turn symbol [S1].
79
+ dialogs = []
80
+ current_dialog = []
81
+ for token in tokens_list:
82
+ if token == "[S1]":
83
+ if len(current_dialog) != 0:
84
+ dialogs.append(current_dialog)
85
+ current_dialog = []
86
+ current_dialog.append(token)
87
+ # Assume the last few tokens are also a dialog
88
+ if len(current_dialog) != 0:
89
+ dialogs.append(current_dialog)
90
+
91
+ # 2. Merge short dialogs.
92
+ chunks = []
93
+ current_chunk = []
94
+ for dialog in dialogs:
95
+ if len(current_chunk) + len(dialog) <= max_tokens:
96
+ current_chunk.extend(dialog)
97
+ else:
98
+ if len(current_chunk) > 0:
99
+ chunks.append(current_chunk)
100
+ current_chunk = dialog
101
+
102
+ if len(current_chunk) > 0:
103
+ chunks.append(current_chunk)
104
+
105
+ return chunks
106
+
107
+
108
+ def batchify_tokens(
109
+ tokens_list: List[List[int]],
110
+ max_duration: float,
111
+ prompt_duration: float,
112
+ token_duration: float,
113
+ ):
114
+ """
115
+ Sort and group the input list of token sequences into batches, where each batch's
116
+ total duration does not exceed the maximum.
117
+
118
+ Args:
119
+ tokens_list (List[List[int]]): A list of token sequences, where each inner
120
+ list represents a sequence of tokens.
121
+ max_duration (float): The maximum allowed total duration for each batch.
122
+ prompt_duration (float): The duration cost per prompt in the batch.
123
+ token_duration (float): The duration cost per token.
124
+
125
+ Returns:
126
+ batches: List[List[List[int]]]: A list of batches, where each batch is a list of
127
+ token sequences that fit within the max duration.
128
+ index: List[int]: The original index of each sentence, used to recover the
129
+ sequential order in the future.
130
+ """
131
+ # Create index for each sentence
132
+ indexed_tokens = list(enumerate(tokens_list))
133
+
134
+ # Sort according to sentence length (for less padding)
135
+ indexed_sorted_tokens = sorted(indexed_tokens, key=lambda x: len(x[1]))
136
+ index = [indexed_sorted_tokens[i][0] for i in range(len(indexed_sorted_tokens))]
137
+ sorted_tokens = [
138
+ indexed_sorted_tokens[i][1] for i in range(len(indexed_sorted_tokens))
139
+ ]
140
+
141
+ batches = []
142
+ batch = []
143
+ batch_size = 0 # Total number of tokens in current batch
144
+
145
+ for tokens in sorted_tokens:
146
+ # Calculate if adding current token sequence would exceed max duration
147
+ # Formula considers: existing tokens' duration + existing
148
+ # prompts' duration + new tokens' duration
149
+ if (
150
+ batch_size * token_duration
151
+ + len(batch) * prompt_duration
152
+ + len(tokens) * token_duration
153
+ <= max_duration
154
+ ):
155
+ # Add to current batch if within duration limit
156
+ batch.append(tokens)
157
+ batch_size += len(tokens)
158
+ else:
159
+ # If exceeding limit, finalize current batch (if not empty)
160
+ if len(batch) > 0:
161
+ batches.append(batch)
162
+ # Start new batch with current token sequence
163
+ batch = [tokens]
164
+ batch_size = len(tokens)
165
+
166
+ # Add the last batch if it's not empty
167
+ if len(batch) > 0:
168
+ batches.append(batch)
169
+
170
+ return batches, index
171
+
172
+
173
+ def cross_fade_concat(
174
+ chunks: List[torch.Tensor], fade_duration: float = 0.1, sample_rate: int = 24000
175
+ ) -> torch.Tensor:
176
+ """
177
+ Concatenates audio chunks with cross-fading between consecutive chunks.
178
+
179
+ Args:
180
+ chunks: List of audio tensors, each with shape (C, T) where
181
+ C = number of channels, T = time dimension (samples)
182
+ fade_duration: Duration of cross-fade in seconds
183
+ sample_rate: Audio sample rate in Hz
184
+
185
+ Returns:
186
+ Concatenated audio tensor with shape (C, T_total)
187
+ """
188
+ # Handle edge cases: empty input or single chunk
189
+ if len(chunks) <= 1:
190
+ return chunks[0] if chunks else torch.tensor([])
191
+
192
+ # Calculate total fade samples from duration and sample rate
193
+ fade_samples = int(fade_duration * sample_rate)
194
+
195
+ # Use simple concatenation if fade duration is non-positive
196
+ if fade_samples <= 0:
197
+ return torch.cat(chunks, dim=-1)
198
+
199
+ # Initialize final tensor with the first chunk
200
+ final = chunks[0]
201
+
202
+ # Iterate through remaining chunks to apply cross-fading
203
+ for next_chunk in chunks[1:]:
204
+ # Calculate safe fade length (cannot exceed either chunk's duration)
205
+ k = min(fade_samples, final.shape[-1], next_chunk.shape[-1])
206
+
207
+ # Fall back to simple concatenation if safe fade length is invalid
208
+ if k <= 0:
209
+ final = torch.cat([final, next_chunk], dim=-1)
210
+ continue
211
+
212
+ # Create fade curve (1 -> 0) with shape (1, k) for broadcasting
213
+ fade = torch.linspace(1, 0, k, device=final.device)[None]
214
+
215
+ # Concatenate three parts:
216
+ # 1. Non-overlapping part of previous audio
217
+ # 2. Cross-faded overlapping region
218
+ # 3. Non-overlapping part of next audio
219
+ final = torch.cat(
220
+ [
221
+ final[..., :-k], # All samples except last k from previous
222
+ final[..., -k:] * fade
223
+ + next_chunk[..., :k] * (1 - fade), # Cross-fade region
224
+ next_chunk[..., k:], # All samples except first k from next
225
+ ],
226
+ dim=-1,
227
+ )
228
+
229
+ return final
230
+
231
+
232
+ def add_punctuation(text: str):
233
+ """Add a trailing period if the text does not end with punctuation."""
234
+ text = text.strip()
235
+ if text[-1] not in punctuation:
236
+ text += "."
237
+ return text
238
+
239
+
240
+ def load_prompt_wav(prompt_wav: str, sampling_rate: int):
241
+ """
242
+ Load the waveform with torchaudio and resampling if needed.
243
+
244
+ Parameters:
245
+ prompt_wav: path of the prompt wav.
246
+ sampling_rate: target sampling rate.
247
+
248
+ Returns:
249
+ Loaded prompt waveform with target sampling rate,
250
+ PyTorch tensor of shape (C, T)
251
+ """
252
+ prompt_wav, prompt_sampling_rate = torchaudio.load(prompt_wav)
253
+
254
+ if prompt_sampling_rate != sampling_rate:
255
+ resampler = torchaudio.transforms.Resample(
256
+ orig_freq=prompt_sampling_rate, new_freq=sampling_rate
257
+ )
258
+ prompt_wav = resampler(prompt_wav)
259
+ return prompt_wav
260
+
261
+
262
+ def rms_norm(prompt_wav: torch.Tensor, target_rms: float):
263
+ """
264
+ Normalize the rms of prompt_wav if it is smaller than the target rms.
265
+
266
+ Parameters:
267
+ prompt_wav: PyTorch tensor with shape (C, T).
268
+ target_rms: target rms value
269
+
270
+ Returns:
271
+ prompt_wav: normalized prompt wav with shape (C, T).
272
+ promt_rms: rms of original prompt wav. Will be used to
273
+ re-normalize the generated wav.
274
+ """
275
+ prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
276
+ if prompt_rms < target_rms:
277
+ prompt_wav = prompt_wav * target_rms / prompt_rms
278
+ return prompt_wav, prompt_rms
279
+
280
+
281
+ def remove_silence(
282
+ audio: torch.Tensor,
283
+ sampling_rate: int,
284
+ only_edge: bool = False,
285
+ trail_sil: float = 0,
286
+ ):
287
+ """
288
+ Remove silences longer than 1 second, and edge silences longer than 0.1 seconds
289
+
290
+ Parameters:
291
+ audio: PyTorch tensor with shape (C, T).
292
+ sampling_rate: sampling rate of the audio.
293
+ only_edge: If true, only remove edge silences.
294
+ trail_sil: the duration of added trailing silence in ms.
295
+
296
+ Returns:
297
+ PyTorch tensor with shape (C, T), where C is number of channels
298
+ and T is number of audio samples
299
+ """
300
+ # Load audio file
301
+ wave = tensor_to_audiosegment(audio, sampling_rate)
302
+
303
+ if not only_edge:
304
+ # Split audio using silences longer than 1 second
305
+ non_silent_segs = split_on_silence(
306
+ wave,
307
+ min_silence_len=1000, # Silences longer than 1 second (1000ms)
308
+ silence_thresh=-50,
309
+ keep_silence=1000, # Keep 1.0 second of silence around segments
310
+ seek_step=10,
311
+ )
312
+
313
+ # Concatenate all non-silent segments
314
+ wave = AudioSegment.silent(duration=0)
315
+ for seg in non_silent_segs:
316
+ wave += seg
317
+
318
+ # Remove silence longer than 0.1 seconds at the beginning and end of wave
319
+ wave = remove_silence_edges(wave, 100, -50)
320
+
321
+ # Add trailing silence to avoid leaking prompt to generated speech.
322
+ wave = wave + AudioSegment.silent(duration=trail_sil)
323
+
324
+ # Convert to PyTorch tensor
325
+ return audiosegment_to_tensor(wave)
326
+
327
+
328
+ def remove_silence_edges(
329
+ audio: AudioSegment, keep_silence: int = 100, silence_threshold: float = -50
330
+ ):
331
+ """
332
+ Remove edge silences longer than `keep_silence` ms.
333
+
334
+ Parameters:
335
+ audio: an AudioSegment object.
336
+ keep_silence: amount of silence (in ms) to keep at each edge.
337
+ silence_threshold: the silence threshold (in dBFS) used to detect
338
+ leading and trailing silence.
339
+
340
+ Returns:
341
+ An AudioSegment object
342
+ """
343
+ # Remove leading silence
344
+ start_idx = detect_leading_silence(audio, silence_threshold=silence_threshold)
345
+ start_idx = max(0, start_idx - keep_silence)
346
+ audio = audio[start_idx:]
347
+
348
+ # Remove trailing silence
349
+ audio = audio.reverse()
350
+ start_idx = detect_leading_silence(audio, silence_threshold=silence_threshold)
351
+ start_idx = max(0, start_idx - keep_silence)
352
+ audio = audio[start_idx:]
353
+ audio = audio.reverse()
354
+
355
+ return audio
356
+
357
+
358
+ def audiosegment_to_tensor(aseg):
359
+ """
360
+ Convert a pydub.AudioSegment to PyTorch audio tensor
361
+ """
362
+ audio_data = np.array(aseg.get_array_of_samples())
363
+
364
+ # Convert to float32 and normalize to [-1, 1] range
365
+ audio_data = audio_data.astype(np.float32) / 32768.0
366
+
367
+ # Handle channels
368
+ if aseg.channels == 1:
369
+ # Mono channel: add channel dimension (T) -> (1, T)
370
+ tensor_data = torch.from_numpy(audio_data).unsqueeze(0)
371
+ else:
372
+ # Multi-channel: reshape to (C, T)
373
+ tensor_data = torch.from_numpy(audio_data.reshape(-1, aseg.channels).T)
374
+
375
+ return tensor_data
376
+
377
+
378
+ def tensor_to_audiosegment(tensor, sample_rate):
379
+ """
380
+ Convert a PyTorch audio tensor to pydub.AudioSegment
381
+
382
+ Parameters:
383
+ tensor: Tensor with shape (C, T), where C is the number of channels
384
+ and T is the time steps
385
+ sample_rate: Audio sample rate
386
+ """
387
+ # Convert tensor to numpy array
388
+ audio_np = tensor.cpu().numpy()
389
+
390
+ # Add channel dimension if single channel
391
+ if audio_np.ndim == 1:
392
+ audio_np = audio_np[np.newaxis, :]
393
+
394
+ # Convert to int16 type (common format for pydub)
395
+ # Assumes tensor values are in [-1, 1] range as floating point
396
+ audio_np = (audio_np * 32768.0).clip(-32768, 32767).astype(np.int16)
397
+
398
+ # Convert to byte stream
399
+ # For multi-channel audio, pydub requires interleaved format
400
+ # (e.g., left-right-left-right)
401
+ if audio_np.shape[0] > 1:
402
+ # Convert to interleaved format
403
+ audio_np = audio_np.transpose(1, 0).flatten()
404
+ audio_bytes = audio_np.tobytes()
405
+
406
+ # Create AudioSegment
407
+ audio_segment = AudioSegment(
408
+ data=audio_bytes,
409
+ sample_width=2,
410
+ frame_rate=sample_rate,
411
+ channels=tensor.shape[0],
412
+ )
413
+
414
+ return audio_segment
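A minimal sketch of how the chunking and cross-fade helpers above fit together; `synthesize` below is a hypothetical placeholder for the actual TTS call, included only to make the example self-contained:

import torch

from zipvoice.utils.infer import chunk_tokens_punctuation, cross_fade_concat  # assumed import path

tokens = list("Hello world. This is a test! Another sentence?")
chunks = chunk_tokens_punctuation(tokens, max_tokens=20)

def synthesize(chunk):
    # Placeholder: return a short silent waveform of shape (1, T) per chunk.
    return torch.zeros(1, 24000)

wavs = [synthesize(chunk) for chunk in chunks]

# Overlap-add consecutive chunks with a 0.1 s linear cross-fade at 24 kHz.
audio = cross_fade_concat(wavs, fade_duration=0.1, sample_rate=24000)
print(audio.shape)  # (1, total_samples)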
zipvoice/utils/lr_scheduler.py ADDED
@@ -0,0 +1,245 @@
1
+ # Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
2
+ #
3
+ # See ../LICENSE for clarification regarding multiple authors
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import logging
18
+ from typing import List, Optional, Union
19
+
20
+ import torch
21
+ from torch.optim import Optimizer
22
+
23
+
24
+ class LRScheduler(object):
25
+ """
26
+ Base-class for learning rate schedulers where the learning-rate depends on both the
27
+ batch and the epoch.
28
+ """
29
+
30
+ def __init__(self, optimizer: Optimizer, verbose: bool = False):
31
+ # Attach optimizer
32
+ if not isinstance(optimizer, Optimizer):
33
+ raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__))
34
+ self.optimizer = optimizer
35
+ self.verbose = verbose
36
+
37
+ for group in optimizer.param_groups:
38
+ group.setdefault("base_lr", group["lr"])
39
+
40
+ self.base_lrs = [group["base_lr"] for group in optimizer.param_groups]
41
+
42
+ self.epoch = 0
43
+ self.batch = 0
44
+
45
+ def state_dict(self):
46
+ """Returns the state of the scheduler as a :class:`dict`.
47
+
48
+ It contains an entry for every variable in self.__dict__ which
49
+ is not the optimizer.
50
+ """
51
+ return {
52
+ # the user might try to override the base_lr, so don't include this in the
53
+ # state. previously they were included.
54
+ # "base_lrs": self.base_lrs,
55
+ "epoch": self.epoch,
56
+ "batch": self.batch,
57
+ }
58
+
59
+ def load_state_dict(self, state_dict):
60
+ """Loads the schedulers state.
61
+
62
+ Args:
63
+ state_dict (dict): scheduler state. Should be an object returned
64
+ from a call to :meth:`state_dict`.
65
+ """
66
+ # the things with base_lrs are a work-around for a previous problem
67
+ # where base_lrs were written with the state dict.
68
+ base_lrs = self.base_lrs
69
+ self.__dict__.update(state_dict)
70
+ self.base_lrs = base_lrs
71
+
72
+ def get_last_lr(self) -> List[float]:
73
+ """Return last computed learning rate by current scheduler.
74
+ Will be a list of float."""
75
+ return self._last_lr
76
+
77
+ def get_lr(self):
78
+ # Compute list of learning rates from self.epoch and self.batch and
79
+ # self.base_lrs; this must be overloaded by the user.
80
+ # e.g. return [some_formula(self.batch, self.epoch, base_lr)
81
+ # for base_lr in self.base_lrs ]
82
+ raise NotImplementedError
83
+
84
+ def step_batch(self, batch: Optional[int] = None) -> None:
85
+ # Step the batch index, or just set it. If `batch` is specified, it
86
+ # must be the batch index from the start of training, i.e. summed over
87
+ # all epochs.
88
+ # You can call this in any order; if you don't provide 'batch', it should
89
+ # of course be called once per batch.
90
+ if batch is not None:
91
+ self.batch = batch
92
+ else:
93
+ self.batch = self.batch + 1
94
+ self._set_lrs()
95
+
96
+ def step_epoch(self, epoch: Optional[int] = None):
97
+ # Step the epoch index, or just set it. If you provide the 'epoch' arg, you
98
+ # should call this at the start of the epoch; if you don't provide the 'epoch'
99
+ # arg, you should call it at the end of the epoch.
100
+ if epoch is not None:
101
+ self.epoch = epoch
102
+ else:
103
+ self.epoch = self.epoch + 1
104
+ self._set_lrs()
105
+
106
+ def _set_lrs(self):
107
+ values = self.get_lr()
108
+ assert len(values) == len(self.optimizer.param_groups)
109
+
110
+ for i, data in enumerate(zip(self.optimizer.param_groups, values)):
111
+ param_group, lr = data
112
+ param_group["lr"] = lr
113
+ self.print_lr(self.verbose, i, lr)
114
+ self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
115
+
116
+ def print_lr(self, is_verbose, group, lr):
117
+ """Display the current learning rate."""
118
+ if is_verbose:
119
+ logging.warning(
120
+ f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
121
+ f" of group {group} to {lr:.4e}."
122
+ )
123
+
124
+
125
+ class Eden(LRScheduler):
126
+ """
127
+ Eden scheduler.
128
+ The basic formula (before warmup) is:
129
+ lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 *
130
+ (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup
131
+ where `warmup` increases linearly from 0.5 to 1 over `warmup_batches` batches
132
+ and then stays constant at 1.
133
+
134
+ If you don't have the concept of epochs, or one epoch takes a very long time,
135
+ you can replace the notion of 'epoch' with some measure of the amount of data
136
+ processed, e.g. hours of data or frames of data, with 'lr_epochs' being set to
137
+ some measure representing "quite a lot of data": say, one fifth or one third
138
+ of an entire training run, but it doesn't matter much. You could also use
139
+ Eden2 which has only the notion of batches.
140
+
141
+ We suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam
142
+
143
+ Args:
144
+ optimizer: the optimizer to change the learning rates on
145
+ lr_batches: the number of batches after which we start significantly
146
+ decreasing the learning rate, suggest 5000.
147
+ lr_epochs: the number of epochs after which we start significantly
148
+ decreasing the learning rate, suggest 6 if you plan to do e.g.
149
+ 20 to 40 epochs, but may need smaller number if dataset is huge
150
+ and you will do few epochs.
151
+ """
152
+
153
+ def __init__(
154
+ self,
155
+ optimizer: Optimizer,
156
+ lr_batches: Union[int, float],
157
+ lr_epochs: Union[int, float],
158
+ warmup_batches: Union[int, float] = 500.0,
159
+ warmup_start: float = 0.5,
160
+ verbose: bool = False,
161
+ ):
162
+ super(Eden, self).__init__(optimizer, verbose)
163
+ self.lr_batches = lr_batches
164
+ self.lr_epochs = lr_epochs
165
+ self.warmup_batches = warmup_batches
166
+
167
+ assert 0.0 <= warmup_start <= 1.0, warmup_start
168
+ self.warmup_start = warmup_start
169
+
170
+ def get_lr(self):
171
+ factor = (
172
+ (self.batch**2 + self.lr_batches**2) / self.lr_batches**2
173
+ ) ** -0.25 * (
174
+ ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25
175
+ )
176
+ warmup_factor = (
177
+ 1.0
178
+ if self.batch >= self.warmup_batches
179
+ else self.warmup_start
180
+ + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
181
+ # else 0.5 + 0.5 * (self.batch / self.warmup_batches)
182
+ )
183
+
184
+ return [x * factor * warmup_factor for x in self.base_lrs]
185
+
186
+
187
+ class FixedLRScheduler(LRScheduler):
188
+ """
189
+ Fixed learning rate scheduler.
190
+
191
+ Args:
192
+ optimizer: the optimizer to change the learning rates on
193
+ """
194
+
195
+ def __init__(
196
+ self,
197
+ optimizer: Optimizer,
198
+ verbose: bool = False,
199
+ ):
200
+ super(FixedLRScheduler, self).__init__(optimizer, verbose)
201
+
202
+ def get_lr(self):
203
+
204
+ return [x for x in self.base_lrs]
205
+
206
+
207
+ def _test_eden():
208
+ m = torch.nn.Linear(100, 100)
209
+ from zipvoice.utils.optim import ScaledAdam
210
+
211
+ optim = ScaledAdam(m.parameters(), lr=0.03)
212
+
213
+ scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True)
214
+
215
+ for epoch in range(10):
216
+ scheduler.step_epoch(epoch) # sets epoch to `epoch`
217
+
218
+ for step in range(20):
219
+ x = torch.randn(200, 100).detach()
220
+ x.requires_grad = True
221
+ y = m(x)
222
+ dy = torch.randn(200, 100).detach()
223
+ f = (y * dy).sum()
224
+ f.backward()
225
+
226
+ optim.step()
227
+ scheduler.step_batch()
228
+ optim.zero_grad()
229
+
230
+ logging.info(f"last lr = {scheduler.get_last_lr()}")
231
+ logging.info(f"state dict = {scheduler.state_dict()}")
232
+
233
+
234
+ if __name__ == "__main__":
235
+ torch.set_num_threads(1)
236
+ torch.set_num_interop_threads(1)
237
+ logging.getLogger().setLevel(logging.INFO)
238
+ import subprocess
239
+
240
+ s = subprocess.check_output(
241
+ "git status -uno .; git log -1; git diff HEAD .", shell=True
242
+ )
243
+ logging.info(s)
244
+
245
+ _test_eden()
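For intuition, the Eden schedule above can be evaluated by hand; a small sketch that mirrors the decay and warmup terms from the docstring (with the default warmup_start=0.5), not tied to any optimizer:

def eden_lr(base_lr, batch, epoch, lr_batches=5000, lr_epochs=6,
            warmup_batches=500, warmup_start=0.5):
    batch_factor = ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25
    epoch_factor = ((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25
    # Warmup rises linearly from warmup_start to 1.0 over warmup_batches batches.
    warmup = 1.0 if batch >= warmup_batches else (
        warmup_start + (1.0 - warmup_start) * batch / warmup_batches
    )
    return base_lr * batch_factor * epoch_factor * warmup

print(eden_lr(0.04, batch=0, epoch=0))       # 0.02: only the 0.5 warmup applies
print(eden_lr(0.04, batch=5000, epoch=2))    # both decay terms start to bite
print(eden_lr(0.04, batch=50000, epoch=20))  # much smaller late in training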
zipvoice/utils/optim.py ADDED
@@ -0,0 +1,868 @@
1
+ # Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
2
+ #
3
+ # See ../LICENSE for clarification regarding multiple authors
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import contextlib
18
+ import logging
19
+ from collections import defaultdict
20
+ from typing import Dict, List, Tuple
21
+
22
+ import torch
23
+ from lhotse.utils import fix_random_seed
24
+ from torch import Tensor
25
+ from torch.optim import Optimizer
26
+
27
+
28
+ class BatchedOptimizer(Optimizer):
29
+ """
30
+ This class adds to class Optimizer the capability to optimize parameters in batches:
31
+ it will stack the parameters and their grads for you so the optimizer can work
32
+ on tensors with an extra leading dimension. This is intended for speed with GPUs,
33
+ as it reduces the number of kernels launched in the optimizer.
34
+
35
+ Args:
36
+ params:
37
+ """
38
+
39
+ def __init__(self, params, defaults):
40
+ super(BatchedOptimizer, self).__init__(params, defaults)
41
+
42
+ @contextlib.contextmanager
43
+ def batched_params(self, param_group, group_params_names):
44
+ """
45
+ This function returns (technically, yields) a list of
46
+ tuples (p, state), where
47
+ p is a `fake` parameter that is stacked (over axis 0) from real parameters
48
+ that share the same shape, and its gradient is also stacked;
49
+ `state` is the state corresponding to this batch of parameters
50
+ (it will be physically located in the "state" for one of the real
51
+ parameters, the last one that has any particular shape and dtype).
52
+
53
+ This function is decorated as a context manager so that it can
54
+ write parameters back to their "real" locations.
55
+
56
+ The idea is, instead of doing:
57
+ <code>
58
+ for p in group["params"]:
59
+ state = self.state[p]
60
+ ...
61
+ </code>
62
+ you can do:
63
+ <code>
64
+ with self.batched_params(group["params"]) as batches:
65
+ for p, state, p_names in batches:
66
+ ...
67
+ </code>
68
+
69
+ Args:
70
+ group: a parameter group, which is a list of parameters; should be
71
+ one of self.param_groups.
72
+ group_params_names: name for each parameter in group,
73
+ which is List[str].
74
+ """
75
+ batches = defaultdict(
76
+ list
77
+ ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
78
+ batches_names = defaultdict(
79
+ list
80
+ ) # `batches` maps from tuple (dtype_as_str,*shape) to list of str
81
+
82
+ assert len(param_group) == len(group_params_names)
83
+ for p, named_p in zip(param_group, group_params_names):
84
+ key = (str(p.dtype), *p.shape)
85
+ batches[key].append(p)
86
+ batches_names[key].append(named_p)
87
+
88
+ batches_names_keys = list(batches_names.keys())
89
+ sorted_idx = sorted(
90
+ range(len(batches_names)), key=lambda i: batches_names_keys[i]
91
+ )
92
+ batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx]
93
+ batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
94
+
95
+ stacked_params_dict = dict()
96
+
97
+ # turn batches into a list, in deterministic order.
98
+ # tuples will contain tuples of (stacked_param, state, stacked_params_names),
99
+ # one for each batch in `batches`.
100
+ tuples = []
101
+
102
+ for batch, batch_names in zip(batches, batches_names):
103
+ p = batch[0]
104
+ # we arbitrarily store the state in the
105
+ # state corresponding to the 1st parameter in the
106
+ # group. class Optimizer will take care of saving/loading state.
107
+ state = self.state[p]
108
+ p_stacked = torch.stack(batch)
109
+ grad = torch.stack(
110
+ [torch.zeros_like(p) if p.grad is None else p.grad for p in batch]
111
+ )
112
+ p_stacked.grad = grad
113
+ stacked_params_dict[key] = p_stacked
114
+ tuples.append((p_stacked, state, batch_names))
115
+
116
+ yield tuples # <-- calling code will do the actual optimization here!
117
+
118
+ for (stacked_params, _state, _names), batch in zip(tuples, batches):
119
+ for i, p in enumerate(batch): # batch is list of Parameter
120
+ p.copy_(stacked_params[i])
121
+
122
+
123
+ def basic_step(group, p, state, grad):
124
+ # computes basic Adam update using beta2 (dividing by gradient stddev) only. no
125
+ # momentum yet.
126
+ lr = group["lr"]
127
+ if p.numel() == p.shape[0]:
128
+ lr = lr * group["scalar_lr_scale"]
129
+ beta2 = group["betas"][1]
130
+ eps = group["eps"]
131
+ # p shape: (batch_size,) or (batch_size, 1, [1,..])
132
+ try:
133
+ exp_avg_sq = state[
134
+ "exp_avg_sq"
135
+ ] # shape: (batch_size,) or (batch_size, 1, [1,..])
136
+ except KeyError:
137
+ exp_avg_sq = torch.zeros(*p.shape, device=p.device, dtype=torch.float)
138
+ state["exp_avg_sq"] = exp_avg_sq
139
+
140
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
141
+
142
+ # bias_correction2 is like in Adam.
143
+ # slower update at the start will help stability anyway.
144
+ bias_correction2 = 1 - beta2 ** (state["step"] + 1)
145
+ if bias_correction2 < 0.99:
146
+ # note: not in-place.
147
+ exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
148
+ denom = exp_avg_sq.sqrt().add_(eps)
149
+
150
+ return -lr * grad / denom
151
+
152
+
153
+ def scaling_step(group, p, state, grad):
154
+ delta = basic_step(group, p, state, grad)
155
+ if p.numel() == p.shape[0]:
156
+ return delta
157
+ # there is no scaling for scalar parameters.
158
+ # (p.shape[0] is the batch of parameters.)
159
+
160
+ step = state["step"]
161
+ size_update_period = group["size_update_period"]
162
+
163
+ try:
164
+ param_rms = state["param_rms"]
165
+ scale_grads = state["scale_grads"]
166
+ scale_exp_avg_sq = state["scale_exp_avg_sq"]
167
+ except KeyError:
168
+ # we know p.ndim > 1 because we'd have returned above if not, so don't worry
169
+ # about the special case of dim=[] that pytorch treats inconsistently.
170
+ param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
171
+ param_rms = param_rms.to(torch.float)
172
+ scale_exp_avg_sq = torch.zeros_like(param_rms)
173
+ scale_grads = torch.zeros(
174
+ size_update_period,
175
+ *param_rms.shape,
176
+ dtype=torch.float,
177
+ device=p.device,
178
+ )
179
+ state["param_rms"] = param_rms
180
+ state["scale_grads"] = scale_grads
181
+ state["scale_exp_avg_sq"] = scale_exp_avg_sq
182
+
183
+ # on every step, update the gradient w.r.t. the scale of the parameter, we
184
+ # store these as a batch and periodically update the size (for speed only, to
185
+ # avoid too many operations).
186
+ scale_grads[step % size_update_period] = (p * grad).sum(
187
+ dim=list(range(1, p.ndim)), keepdim=True
188
+ )
189
+
190
+ # periodically recompute the value of param_rms.
191
+ if step % size_update_period == size_update_period - 1:
192
+ param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt())
193
+
194
+ param_min_rms = group["param_min_rms"]
195
+
196
+ # scale the step size by param_rms. This is the most important "scaling" part of
197
+ # ScaledAdam
198
+ delta *= param_rms.clamp(min=param_min_rms)
199
+
200
+ if step % size_update_period == size_update_period - 1 and step > 0:
201
+ # This block updates the size of parameter by adding a step ("delta") value in
202
+ # the direction of either shrinking or growing it.
203
+ beta2 = group["betas"][1]
204
+ size_lr = group["lr"] * group["scalar_lr_scale"]
205
+ param_max_rms = group["param_max_rms"]
206
+ eps = group["eps"]
207
+ # correct beta2 for the size update period: we will have
208
+ # faster decay at this level.
209
+ beta2_corr = beta2**size_update_period
210
+ scale_exp_avg_sq.mul_(beta2_corr).add_(
211
+ (scale_grads**2).mean(dim=0), # mean over dim `size_update_period`
212
+ alpha=1 - beta2_corr,
213
+ ) # shape is (batch_size, 1, 1, ...)
214
+
215
+ # The 1st time we reach here is when size_step == 1.
216
+ size_step = (step + 1) // size_update_period
217
+ bias_correction2 = 1 - beta2_corr**size_step
218
+
219
+ denom = scale_exp_avg_sq.sqrt() + eps
220
+
221
+ scale_step = (
222
+ -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom
223
+ )
224
+
225
+ is_too_small = param_rms < param_min_rms
226
+
227
+ # when the param gets too small, just don't shrink it any further.
228
+ scale_step.masked_fill_(is_too_small, 0.0)
229
+
230
+ # The following may help prevent instability: don't allow the scale step to be
231
+ # too large in either direction.
232
+ scale_step.clamp_(min=-0.1, max=0.1)
233
+
234
+ # and ensure the parameter rms after update never exceeds param_max_rms.
235
+ # We have to look at the trained model for parameters at or around the
236
+ # param_max_rms, because sometimes they can indicate a problem with the
237
+ # topology or settings.
238
+ scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms)
239
+
240
+ delta.add_(p * scale_step)
241
+
242
+ return delta
243
+
244
+
245
+ def momentum_step(group, p, state, grad):
246
+ delta = scaling_step(group, p, state, grad)
247
+ beta1 = group["betas"][0]
248
+ try:
249
+ stored_delta = state["delta"]
250
+ except KeyError:
251
+ stored_delta = torch.zeros(*p.shape, device=p.device, dtype=torch.float)
252
+ state["delta"] = stored_delta
253
+ stored_delta.mul_(beta1)
254
+ stored_delta.add_(delta, alpha=(1 - beta1))
255
+ # we don't bother doing the "bias correction" part of Adam for beta1 because this is
256
+ # just an edge effect that affects the first 10 or so batches; and the effect of not
257
+ # doing it is just to do a slower update for the first few batches, which will help
258
+ # stability.
259
+ return stored_delta
260
+
261
+
262
+ class ScaledAdam(BatchedOptimizer):
263
+ """
264
+ Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
265
+ proportional to the norm of that parameter; and also learn the scale of the
266
+ parameter, in log space, subject to upper and lower limits (as if we had factored
267
+ each parameter as param = underlying_param * log_scale.exp())
268
+
269
+
270
+ Args:
271
+ params: The parameters or param_groups to optimize (like other Optimizer
272
+ subclasses) Unlike common optimizers, which accept
273
+ model.parameters() or groups of parameters(), this optimizer
274
+ could accept model.named_parameters() or groups of
275
+ named_parameters(). See comments of function
276
+ _get_names_of_parameters for its 4 possible cases.
277
+ lr: The learning rate. We will typically use a learning rate schedule
278
+ that starts at 0.03 and decreases over time, i.e. much higher
279
+ than other common optimizers.
280
+ clipping_scale: (e.g. 2.0)
281
+ A scale for gradient-clipping: if specified, the normalized gradients
282
+ over the whole model will be clipped to have 2-norm equal to
283
+ `clipping_scale` times the median 2-norm over the most recent period
284
+ of `clipping_update_period` minibatches. By "normalized gradients",
285
+ we mean after multiplying by the rms parameter value for this tensor
286
+ [for non-scalars]; this is appropriate because our update is scaled
287
+ by this quantity.
288
+ betas: beta1,beta2 are momentum constants for regular momentum, and moving
289
+ sum-sq grad. Must satisfy 0 < beta <= beta2 < 1.
290
+ scalar_lr_scale: A scaling factor on the learning rate, that we use to update the
291
+ scale of each parameter tensor and scalar parameters of the mode..
292
+ If each parameter were decomposed as p * p_scale.exp(),
293
+ where (p**2).mean().sqrt() == 1.0, scalar_lr_scale would be a the
294
+ scaling factor on the learning rate of p_scale.
295
+ eps: A general-purpose epsilon to prevent division by zero
296
+ param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
297
+ learning the scale on the parameters (we'll constrain the rms of
298
+ each non-scalar parameter tensor to be >= this value)
299
+ param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
300
+ learning the scale on the parameters (we'll constrain the rms of
301
+ each non-scalar parameter tensor to be <= this value)
302
+ scalar_max: Maximum absolute value for scalar parameters (applicable if your
303
+ model has any parameters with numel() == 1).
304
+ size_update_period: The periodicity, in steps, with which we update the size (scale)
305
+ of the parameter tensor. This is provided to save a little time
306
+ in the update.
307
+ clipping_update_period: if clipping_scale is specified, this is the period
308
+ """
309
+
310
+ def __init__(
311
+ self,
312
+ params,
313
+ lr=3e-02,
314
+ clipping_scale=None,
315
+ betas=(0.9, 0.98),
316
+ scalar_lr_scale=0.1,
317
+ eps=1.0e-08,
318
+ param_min_rms=1.0e-05,
319
+ param_max_rms=3.0,
320
+ scalar_max=10.0,
321
+ size_update_period=4,
322
+ clipping_update_period=100,
323
+ ):
324
+
325
+ defaults = dict(
326
+ lr=lr,
327
+ clipping_scale=clipping_scale,
328
+ betas=betas,
329
+ scalar_lr_scale=scalar_lr_scale,
330
+ eps=eps,
331
+ param_min_rms=param_min_rms,
332
+ param_max_rms=param_max_rms,
333
+ scalar_max=scalar_max,
334
+ size_update_period=size_update_period,
335
+ clipping_update_period=clipping_update_period,
336
+ )
337
+
338
+ # If params only contains parameters or group of parameters,
339
+ # i.e. when parameter names are not given,
340
+ # this flag will be set to False in function _get_names_of_parameters.
341
+ self.show_dominant_parameters = True
342
+ param_groups, parameters_names = self._get_names_of_parameters(params)
343
+ super(ScaledAdam, self).__init__(param_groups, defaults)
344
+ assert len(self.param_groups) == len(parameters_names)
345
+ self.parameters_names = parameters_names
346
+
347
+ def _get_names_of_parameters(
348
+ self, params_or_named_params
349
+ ) -> Tuple[List[Dict], List[List[str]]]:
350
+ """
351
+ Args:
352
+ params_or_named_params: according to the way ScaledAdam is initialized
353
+ in train.py, this argument could be one of following 4 cases,
354
+ case 1, a generator of parameter, e.g.:
355
+ optimizer = ScaledAdam(model.parameters(), lr=params.base_lr,
356
+ clipping_scale=3.0)
357
+
358
+ case 2, a list of parameter groups with different config, e.g.:
359
+ model_param_groups = [
360
+ {'params': model.encoder.parameters(), 'lr': 0.05},
361
+ {'params': model.decoder.parameters(), 'lr': 0.01},
362
+ {'params': model.joiner.parameters(), 'lr': 0.03},
363
+ ]
364
+ optimizer = ScaledAdam(model_param_groups, lr=params.base_lr,
365
+ clipping_scale=3.0)
366
+
367
+ case 3, a generator of named_parameter, e.g.:
368
+ optimizer = ScaledAdam(model.named_parameters(), lr=params.base_lr,
369
+ clipping_scale=3.0)
370
+
371
+ case 4, a list of named_parameter groups with different config, e.g.:
372
+ model_named_param_groups = [
373
+ {'named_params': model.encoder.named_parameters(), 'lr': 0.05},
374
+ {'named_params': model.decoder.named_parameters(), 'lr': 0.01},
375
+ {'named_params': model.joiner.named_parameters(), 'lr': 0.03},
376
+ ]
377
+ optimizer = ScaledAdam(model_named_param_groups, lr=params.base_lr,
378
+ clipping_scale=3.0)
379
+
380
+ For case 1 and case 2, input params is used to initialize the underlying
381
+ torch.optimizer.
382
+ For case 3 and case 4, firstly, names and params are extracted from input
383
+ named_params, then, these extracted params are used to initialize the
384
+ underlying torch.optimizer, and these extracted names are mainly used by
385
+ function `_show_gradient_dominating_parameter`
386
+
387
+ Returns:
388
+ Returns a tuple containing 2 elements:
389
+ - `param_groups` with type List[Dict], each Dict element is a parameter
390
+ group. An example of `param_groups` could be:
391
+ [
392
+ {'params': `one iterable of Parameter`, 'lr': 0.05},
393
+ {'params': `another iterable of Parameter`, 'lr': 0.08},
394
+ {'params': `a third iterable of Parameter`, 'lr': 0.1},
395
+ ]
396
+ - `param_groups_names` with type List[List[str]],
397
+ each `List[str]` is for a group['params'] in param_groups,
398
+ and each `str` is the name of a parameter.
399
+ A dummy name "foo" is related to each parameter,
400
+ if input are params without names, i.e. case 1 or case 2.
401
+ """
402
+ # variable naming convention in this function:
403
+ # p is short for param.
404
+ # np is short for named_param.
405
+ # p_or_np is short for param_or_named_param.
406
+ # cur is short for current.
407
+ # group is a dict,
408
+ # e.g. {'params': iterable of parameter, 'lr': 0.05, other fields}.
409
+ # groups is a List[group]
410
+
411
+ iterable_or_groups = list(params_or_named_params)
412
+ if len(iterable_or_groups) == 0:
413
+ raise ValueError("optimizer got an empty parameter list")
414
+
415
+ # The first value of returned tuple. A list of dicts containing at
416
+ # least 'params' as a key.
417
+ param_groups = []
418
+
419
+ # The second value of returned tuple,
420
+ # a List[List[str]], each sub-List is for a group.
421
+ param_groups_names = []
422
+
423
+ if not isinstance(iterable_or_groups[0], dict):
424
+ # case 1 or case 3,
425
+ # the input is an iterable of parameter or named parameter.
426
+ param_iterable_cur_group = []
427
+ param_names_cur_group = []
428
+ for p_or_np in iterable_or_groups:
429
+ if isinstance(p_or_np, tuple):
430
+ # case 3
431
+ name, param = p_or_np
432
+ else:
433
+ # case 1
434
+ assert isinstance(p_or_np, torch.Tensor)
435
+ param = p_or_np
436
+ # Assign a dummy name as a placeholder
437
+ name = "foo"
438
+ self.show_dominant_parameters = False
439
+ param_iterable_cur_group.append(param)
440
+ param_names_cur_group.append(name)
441
+ param_groups.append({"params": param_iterable_cur_group})
442
+ param_groups_names.append(param_names_cur_group)
443
+ else:
444
+ # case 2 or case 4
445
+ # the input is groups of parameter or named parameter.
446
+ for cur_group in iterable_or_groups:
447
+ if "named_params" in cur_group:
448
+ name_list = [x[0] for x in cur_group["named_params"]]
449
+ p_list = [x[1] for x in cur_group["named_params"]]
450
+ del cur_group["named_params"]
451
+ cur_group["params"] = p_list
452
+ else:
453
+ assert "params" in cur_group
454
+ name_list = ["foo" for _ in cur_group["params"]]
455
+ param_groups.append(cur_group)
456
+ param_groups_names.append(name_list)
457
+
458
+ return param_groups, param_groups_names
459
+
460
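To make the four accepted input forms concrete, here is a minimal sketch of how an optimizer would typically be constructed (the import path is assumed from this file's location, zipvoice/utils/optim.py); passing named_parameters() keeps the real parameter names for the clipping diagnostics, while passing plain parameters() falls back to the dummy name "foo":

import torch

from zipvoice.utils.optim import ScaledAdam  # import path assumed from this file's location

toy = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 4))

# case 3: named parameters -> real names are kept for the clipping diagnostics
opt_named = ScaledAdam(toy.named_parameters(), lr=0.03, clipping_scale=2.0)

# case 1: plain parameters -> every name falls back to the placeholder "foo"
# and the dominant-parameter diagnostics are disabled
opt_plain = ScaledAdam(toy.parameters(), lr=0.03, clipping_scale=2.0)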
+ def __setstate__(self, state):
461
+ super(ScaledAdam, self).__setstate__(state)
462
+
463
+ @torch.no_grad()
464
+ def step(self, closure=None):
465
+ """Performs a single optimization step.
466
+
467
+ Arguments:
468
+ closure (callable, optional): A closure that reevaluates the model
469
+ and returns the loss.
470
+ """
471
+ loss = None
472
+ if closure is not None:
473
+ with torch.enable_grad():
474
+ loss = closure()
475
+
476
+ for group, group_params_names in zip(self.param_groups, self.parameters_names):
477
+
478
+ with self.batched_params(group["params"], group_params_names) as batches:
479
+
480
+ # batches is list of pairs (stacked_param, state). stacked_param is
481
+ # like a regular parameter, and will have a .grad, but the 1st dim
482
+ # corresponds to a stacking dim, it is not a real dim.
483
+
484
+ if (
485
+ len(batches[0][1]) == 0
486
+ ): # if len(first state) == 0: not yet initialized
487
+ clipping_scale = 1
488
+ else:
489
+ clipping_scale = self._get_clipping_scale(group, batches)
490
+
491
+ for p, state, _ in batches:
492
+ # Perform optimization step.
493
+ # grad is not going to be None, we handled that when creating the
494
+ # batches.
495
+ grad = p.grad
496
+ if grad.is_sparse:
497
+ raise RuntimeError(
498
+ "ScaledAdam optimizer does not support sparse gradients"
499
+ )
500
+
501
+ try:
502
+ cur_step = state["step"]
503
+ except KeyError:
504
+ state["step"] = 0
505
+ cur_step = 0
506
+
507
+ grad = (
508
+ p.grad if clipping_scale == 1.0 else p.grad.mul_(clipping_scale)
509
+ )
510
+ p += momentum_step(group, p.detach(), state, grad)
511
+
512
+ if p.numel() == p.shape[0]: # scalar parameter
513
+ scalar_max = group["scalar_max"]
514
+ p.clamp_(min=-scalar_max, max=scalar_max)
515
+
516
+ state["step"] = cur_step + 1
517
+
518
+ return loss
519
+
520
+ def _get_clipping_scale(
521
+ self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]
522
+ ) -> float:
523
+ """
524
+ Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will
525
+ scale the gradients by this amount before applying the rest of the update.
526
+
527
+ Args:
528
+ group: the parameter group, an item in self.param_groups
529
+ tuples: a list of tuples of (param, state, param_names)
530
+ where param is a batched set of parameters,
531
+ with a .grad (1st dim is batch dim)
532
+ and state is the state-dict where optimization parameters are kept.
533
+ param_names is a List[str], where each str is the name of a parameter
534
+ in the batched set of parameters "param".
535
+ """
536
+ assert len(tuples) >= 1
537
+ clipping_scale = group["clipping_scale"]
538
+ (first_p, first_state, _) = tuples[0]
539
+ step = first_state["step"]
540
+ if clipping_scale is None or step == 0:
541
+ # no clipping. return early on step == 0 because the other
542
+ # parameters' state won't have been initialized yet.
543
+ return 1.0
544
+ clipping_update_period = group["clipping_update_period"]
545
+ scalar_lr_scale = group["scalar_lr_scale"]
546
+
547
+ tot_sumsq = torch.tensor(0.0, device=first_p.device)
548
+ for p, state, param_names in tuples:
549
+ grad = p.grad
550
+ if grad.is_sparse:
551
+ raise RuntimeError(
552
+ "ScaledAdam optimizer does not support sparse gradients"
553
+ )
554
+ if p.numel() == p.shape[0]: # a batch of scalars
555
+ tot_sumsq += (grad**2).sum() * (
556
+ scalar_lr_scale**2
557
+ ) # sum() to change shape [1] to []
558
+ else:
559
+ tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()
560
+
561
+ tot_norm = tot_sumsq.sqrt()
562
+ if "model_norms" not in first_state:
563
+ first_state["model_norms"] = torch.zeros(
564
+ clipping_update_period, device=p.device
565
+ )
566
+ first_state["model_norms"][step % clipping_update_period] = tot_norm
567
+
568
+ irregular_estimate_steps = [
569
+ i for i in [10, 20, 40] if i < clipping_update_period
570
+ ]
571
+ if step % clipping_update_period == 0 or step in irregular_estimate_steps:
572
+ # Print some stats.
573
+ # We don't reach here if step == 0 because we would have returned
574
+ # above.
575
+ sorted_norms = first_state["model_norms"].sort()[0].to("cpu")
576
+ if step in irregular_estimate_steps:
577
+ sorted_norms = sorted_norms[-step:]
578
+ num_norms = sorted_norms.numel()
579
+ quartiles = []
580
+ for n in range(0, 5):
581
+ index = min(num_norms - 1, (num_norms // 4) * n)
582
+ quartiles.append(sorted_norms[index].item())
583
+
584
+ median = quartiles[2]
585
+ if median - median != 0:
586
+ raise RuntimeError("Too many grads were not finite")
587
+ threshold = clipping_scale * median
588
+ if step in irregular_estimate_steps:
589
+ # use larger thresholds on first few steps of estimating threshold,
590
+ # as norm may be changing rapidly.
591
+ threshold = threshold * 2.0
592
+ first_state["model_norm_threshold"] = threshold
593
+ percent_clipped = (
594
+ first_state["num_clipped"] * 100.0 / num_norms
595
+ if "num_clipped" in first_state
596
+ else 0.0
597
+ )
598
+ first_state["num_clipped"] = 0
599
+ quartiles = " ".join(["%.3e" % x for x in quartiles])
600
+ logging.warning(
601
+ f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
602
+ f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
603
+ )
604
+
605
+ try:
606
+ model_norm_threshold = first_state["model_norm_threshold"]
607
+ except KeyError:
608
+ return 1.0 # threshold has not yet been set.
609
+
610
+ ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
611
+ if ans != ans: # e.g. ans is nan
612
+ ans = 0.0
613
+ if ans < 1.0:
614
+ first_state["num_clipped"] += 1
615
+ if ans < 0.5:
616
+ logging.debug(
617
+ f"Scaling gradients by {ans}, "
618
+ f"model_norm_threshold={model_norm_threshold}"
619
+ )
620
+ if self.show_dominant_parameters:
621
+ assert p.shape[0] == len(param_names)
622
+ self._show_gradient_dominating_parameter(
623
+ tuples, tot_sumsq, group["scalar_lr_scale"]
624
+ )
625
+ self._show_param_with_unusual_grad(tuples)
626
+
627
+ if ans == 0.0:
628
+ for p, state, param_names in tuples:
629
+ p.grad.zero_() # get rid of infinity()
630
+
631
+ return ans
632
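In short, the clipping rule above keeps a rolling history of recent gradient norms, takes clipping_scale times (roughly) their median as a threshold, and scales this step's gradient down whenever its norm exceeds that threshold. A small standalone sketch of just that rule, with a hypothetical helper that ignores state handling, the scalar_lr_scale weighting, and the periodic logging:

import torch

def clipping_factor(norm_history, tot_norm, clipping_scale):
    # hypothetical helper: threshold = clipping_scale * median of recent grad norms
    threshold = clipping_scale * norm_history.median().item()
    # never scale up; only scale down when the current norm exceeds the threshold
    return min(1.0, threshold / (tot_norm + 1.0e-20))

history = torch.tensor([0.9, 0.95, 1.0, 1.1, 1.2])
print(clipping_factor(history, tot_norm=3.0, clipping_scale=2.0))  # ~0.67 -> gradient is clipped
print(clipping_factor(history, tot_norm=1.0, clipping_scale=2.0))  # 1.0  -> no clipping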
+
633
+ def _show_param_with_unusual_grad(
634
+ self,
635
+ tuples: List[Tuple[Tensor, dict, List[str]]],
636
+ ):
637
+ """
638
+ Print information about the parameters that have the largest ratio of
639
+ grad-on-this-batch divided by the typical grad size.
640
+ tuples: a list of tuples of (param, state, param_names)
641
+ where param is a batched set of parameters,
642
+ with a .grad (1st dim is batch dim)
643
+ and state is the state-dict where optimization parameters are kept.
644
+ param_names is a List[str], where each str is the name of a parameter
645
+ in the batched set of parameters "param".
646
+ """
647
+ # ratios_names is a list of 3-tuples: (grad_ratio, param_name, tensor)
648
+ ratios_names = []
649
+ for p, state, batch_param_names in tuples:
650
+ dims = list(range(1, p.ndim))
651
+
652
+ def mean(x):
653
+ # workaround for bad interface of torch's "mean" for when dims is the
654
+ # empty list.
655
+ if len(dims) > 0:
656
+ return x.mean(dim=dims)
657
+ else:
658
+ return x
659
+
660
+ grad_ratio = (
661
+ (mean(p.grad**2) / state["exp_avg_sq"].mean(dim=dims))
662
+ .sqrt()
663
+ .to("cpu")
664
+ )
665
+
666
+ ratios_names += zip(
667
+ grad_ratio.tolist(), batch_param_names, p.grad.unbind(dim=0)
668
+ )
669
+
670
+ ratios_names = sorted(ratios_names, reverse=True)
671
+ ratios_names = ratios_names[:10]
672
+ ratios_names = [
673
+ (ratio, name, largest_index(tensor))
674
+ for (ratio, name, tensor) in ratios_names
675
+ ]
676
+
677
+ logging.debug(
678
+ f"Parameters with the most larger-than-usual grads, with ratios, "
679
+ f"are: {ratios_names}"
680
+ )
681
+
682
+ def _show_gradient_dominating_parameter(
683
+ self,
684
+ tuples: List[Tuple[Tensor, dict, List[str]]],
685
+ tot_sumsq: Tensor,
686
+ scalar_lr_scale: float,
687
+ ):
688
+ """
689
+ Show information about the parameter that dominates tot_sumsq.
690
+
691
+ Args:
692
+ tuples: a list of tuples of (param, state, param_names)
693
+ where param is a batched set of parameters,
694
+ with a .grad (1st dim is batch dim)
695
+ and state is the state-dict where optimization parameters are kept.
696
+ param_names is a List[str], where each str is the name of a parameter
697
+ in the batched set of parameters "param".
698
+ tot_sumsq: sumsq of all parameters. Though it could be calculated
699
+ from tuples, we still pass it to save some time.
700
+ """
701
+ all_sumsq_orig = {}
702
+ for p, state, batch_param_names in tuples:
703
+ # p is a stacked batch parameters.
704
+ batch_grad = p.grad
705
+ if p.numel() == p.shape[0]: # a batch of scalars
706
+ # Dummy values used by following `zip` statement.
707
+ batch_rms_orig = torch.full(
708
+ p.shape, scalar_lr_scale, device=batch_grad.device
709
+ )
710
+ else:
711
+ batch_rms_orig = state["param_rms"]
712
+ batch_sumsq_orig = (batch_grad * batch_rms_orig) ** 2
713
+ if batch_grad.ndim > 1:
714
+ # need to guard it with if-statement because sum() sums over
715
+ # all dims if dim == ().
716
+ batch_sumsq_orig = batch_sumsq_orig.sum(
717
+ dim=list(range(1, batch_grad.ndim))
718
+ )
719
+ for name, sumsq_orig, rms, grad in zip(
720
+ batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
721
+ ):
722
+
723
+ proportion_orig = sumsq_orig / tot_sumsq
724
+ all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
725
+
726
+ sorted_by_proportion = {
727
+ k: v
728
+ for k, v in sorted(
729
+ all_sumsq_orig.items(),
730
+ key=lambda item: item[1][0],
731
+ reverse=True,
732
+ )
733
+ }
734
+ dominant_param_name = next(iter(sorted_by_proportion))
735
+ (
736
+ dominant_proportion,
737
+ dominant_sumsq,
738
+ dominant_rms,
739
+ dominant_grad,
740
+ ) = sorted_by_proportion[dominant_param_name]
741
+ logging.debug(
742
+ f"Parameter dominating tot_sumsq {dominant_param_name}"
743
+ f" with proportion {dominant_proportion:.2f},"
744
+ f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
745
+ f"={dominant_sumsq:.3e},"
746
+ f" grad_sumsq={(dominant_grad**2).sum():.3e},"
747
+ f" orig_rms_sq={(dominant_rms**2).item():.3e}"
748
+ )
749
+
750
+
751
+ def largest_index(x: Tensor):
752
+ x = x.contiguous()
753
+ argmax = x.abs().argmax().item()
754
+ return [(argmax // x.stride(i)) % x.size(i) for i in range(x.ndim)]
755
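largest_index recovers the multi-dimensional position of the largest-magnitude entry from the flat argmax via the tensor's strides, which for a contiguous tensor is the same as unravelling the flat index. A tiny check, run inside this module where largest_index is defined:

import torch

x = torch.zeros(2, 3, 4)
x[1, 2, 3] = -5.0          # entry with the largest magnitude
print(largest_index(x))    # -> [1, 2, 3]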
+
756
+
757
+ def _test_scaled_adam(hidden_dim: int):
758
+ import timeit
759
+
760
+ from zipvoice.models.modules.scaling import ScaledLinear
761
+ from zipvoice.utils.lr_scheduler import Eden
762
+
763
+ E = 100
764
+ B = 4
765
+ T = 2
766
+ logging.info("in _test_scaled_adam")
767
+ # device = torch.device('cuda')
768
+ device = torch.device("cpu")
769
+ dtype = torch.float32
770
+
771
+ fix_random_seed(42)
772
+ # these input_magnitudes and output_magnitudes are to test that
773
+ # Abel is working as we expect and is able to adjust scales of
774
+ # different dims differently.
775
+ input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
776
+ output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
777
+
778
+ fix_random_seed(42)
779
+ Linear = ScaledLinear
780
+
781
+ m = torch.nn.Sequential(
782
+ Linear(E, hidden_dim),
783
+ torch.nn.PReLU(),
784
+ Linear(hidden_dim, hidden_dim),
785
+ torch.nn.PReLU(),
786
+ Linear(hidden_dim, E),
787
+ ).to(device)
788
+
789
+ train_pairs = [
790
+ (
791
+ 100.0 * torch.randn(B, T, E, device=device, dtype=dtype) * input_magnitudes,
792
+ torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes,
793
+ )
794
+ for _ in range(20)
795
+ ]
796
+ optim = ScaledAdam(m.named_parameters(), lr=0.03, clipping_scale=2.0)
797
+ scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)
798
+
799
+ start = timeit.default_timer()
800
+ avg_loss = 0.0
801
+ for epoch in range(180):
802
+ scheduler.step_epoch()
803
+ # if epoch == 100 and iter in [2,3]:
804
+ # optim.reset_speedup() # check it doesn't crash.
805
+
806
+ # if epoch == 130:
807
+ # opts = diagnostics.TensorDiagnosticOptions(
808
+ # 512
809
+ # ) # allow 4 megabytes per sub-module
810
+ # diagnostic = diagnostics.attach_diagnostics(m, opts)
811
+
812
+ for n, (x, y) in enumerate(train_pairs):
813
+ y_out = m(x)
814
+ loss = ((y_out - y) ** 2).mean() * 100.0
815
+ if epoch == 0 and n == 0:
816
+ avg_loss = loss.item()
817
+ else:
818
+ avg_loss = 0.98 * avg_loss + 0.02 * loss.item()
819
+ if n == 0 and epoch % 5 == 0:
820
+ # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
821
+ # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
822
+ # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
823
+ # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item()
824
+ # scale1 = '%.2e' % (m[0].weight_scale.exp().item())
825
+ # scale1b = '%.2e' % (m[0].bias_scale.exp().item())
826
+ # scale2 = '%.2e' % (m[2].weight_scale.exp().item())
827
+ # scale2b = '%.2e' % (m[2].bias_scale.exp().item())
828
+ lr = scheduler.get_last_lr()[0]
829
+ logging.info(
830
+ f"Epoch {epoch}, batch {n}, "
831
+ f"avg_loss {avg_loss:.4g}, lr={lr:.4e}"
832
+ ) # , norms={norm1,norm1b,norm2,norm2b}")
833
+ # scales={scale1,scale1b,scale2,scale2b}
834
+ loss.log().backward()
835
+ optim.step()
836
+ optim.zero_grad()
837
+ scheduler.step_batch()
838
+
839
+ # diagnostic.print_diagnostics()
840
+
841
+ stop = timeit.default_timer()
842
+ logging.info(f"Time taken: {stop - start}")
843
+
844
+ logging.info(f"last lr = {scheduler.get_last_lr()}")
845
+ # logging.info("state dict = ", scheduler.state_dict())
846
+ # logging.info("optim state_dict = ", optim.state_dict())
847
+ logging.info(f"input_magnitudes = {input_magnitudes}")
848
+ logging.info(f"output_magnitudes = {output_magnitudes}")
849
+
850
+
851
+ if __name__ == "__main__":
852
+ torch.set_num_threads(1)
853
+ torch.set_num_interop_threads(1)
854
+ logging.getLogger().setLevel(logging.INFO)
855
+ import subprocess
856
+
857
+ s = subprocess.check_output(
858
+ "git status -uno .; git log -1; git diff HEAD .", shell=True
859
+ )
860
+ logging.info(s)
861
+ import sys
862
+
863
+ if len(sys.argv) > 1:
864
+ hidden_dim = int(sys.argv[1])
865
+ else:
866
+ hidden_dim = 200
867
+
868
+ _test_scaled_adam(hidden_dim)
zipvoice/utils/scaling_converter.py ADDED
@@ -0,0 +1,105 @@
1
+ # Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang,
2
+ # Zengwei Yao)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """
19
+ This file replaces various modules in a model.
20
+ Specifically, Balancer, Dropout3 and Whiten are replaced with identity operators;
21
+ for ONNX export, SwooshR and SwooshL are replaced with their ONNX-friendly
22
+ variants and CompactRelPositionalEncoding is wrapped with torch.jit.script().
23
+ """
24
+
25
+ import copy
26
+ from typing import List
27
+
28
+ import torch
29
+ import torch.nn as nn
30
+
31
+ from zipvoice.models.modules.scaling import (
32
+ Balancer,
33
+ Dropout3,
34
+ SwooshL,
35
+ SwooshLOnnx,
36
+ SwooshR,
37
+ SwooshROnnx,
38
+ Whiten,
39
+ )
40
+ from zipvoice.models.modules.zipformer import CompactRelPositionalEncoding
41
+
42
+
43
+ # Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa
44
+ # get_submodule was added to nn.Module at v1.9.0
45
+ def get_submodule(model, target):
46
+ if target == "":
47
+ return model
48
+ atoms: List[str] = target.split(".")
49
+ mod: torch.nn.Module = model
50
+ for item in atoms:
51
+ if not hasattr(mod, item):
52
+ raise AttributeError(
53
+ mod._get_name() + " has no " "attribute `" + item + "`"
54
+ )
55
+ mod = getattr(mod, item)
56
+ if not isinstance(mod, torch.nn.Module):
57
+ raise AttributeError("`" + item + "` is not " "an nn.Module")
58
+ return mod
59
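get_submodule resolves a dotted path into the corresponding submodule; the toy model below is purely illustrative:

import torch.nn as nn

toy = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.ReLU(), nn.Linear(4, 2)))
print(get_submodule(toy, "1.1"))  # -> Linear(in_features=4, out_features=2, bias=True)
print(get_submodule(toy, ""))     # an empty target returns the model itself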
+
60
+
61
+ def convert_scaled_to_non_scaled(
62
+ model: nn.Module,
63
+ inplace: bool = False,
64
+ is_pnnx: bool = False,
65
+ is_onnx: bool = False,
66
+ ):
67
+ """
68
+ Args:
69
+ model:
70
+ The model to be converted.
71
+ inplace:
72
+ If True, the input model is modified inplace.
73
+ If False, the input model is copied and we modify the copied version.
74
+ is_pnnx:
75
+ True if we are going to export the model for PNNX.
76
+ is_onnx:
77
+ True if we are going to export the model for ONNX.
78
+ Return:
79
+ Return a model without scaled layers.
80
+ """
81
+ if not inplace:
82
+ model = copy.deepcopy(model)
83
+
84
+ d = {}
85
+ for name, m in model.named_modules():
86
+ if isinstance(m, (Balancer, Dropout3, Whiten)):
87
+ d[name] = nn.Identity()
88
+ elif is_onnx and isinstance(m, SwooshR):
89
+ d[name] = SwooshROnnx()
90
+ elif is_onnx and isinstance(m, SwooshL):
91
+ d[name] = SwooshLOnnx()
92
+ elif is_onnx and isinstance(m, CompactRelPositionalEncoding):
93
+ # We want to recreate the positional encoding vector when
94
+ # the input changes, so we have to use torch.jit.script()
95
+ # to replace torch.jit.trace()
96
+ d[name] = torch.jit.script(m)
97
+
98
+ for k, v in d.items():
99
+ if "." in k:
100
+ parent, child = k.rsplit(".", maxsplit=1)
101
+ setattr(get_submodule(model, parent), child, v)
102
+ else:
103
+ setattr(model, k, v)
104
+
105
+ return model
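A typical use is to run this conversion on a model right before ONNX export. The toy module below is only a stand-in to show the effect; a real call would receive the trained ZipVoice model, as in the project's own ONNX export script:

import torch.nn as nn

from zipvoice.models.modules.scaling import SwooshR
from zipvoice.utils.scaling_converter import convert_scaled_to_non_scaled

toy = nn.Sequential(nn.Linear(8, 8), SwooshR())
# inplace=False copies the model first; is_onnx=True swaps SwooshR/SwooshL for
# their ONNX-friendly variants and turns Balancer/Dropout3/Whiten into Identity.
export_model = convert_scaled_to_non_scaled(toy, inplace=False, is_onnx=True)
print(export_model)  # the SwooshR module is now a SwooshROnnx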