bharathkumarK committed · verified
Commit 2881cc1 · 1 Parent(s): 520a5f4

Create vllm_streaming_inference.py

Files changed (1):
  1. vllm_streaming_inference.py +565 -0
vllm_streaming_inference.py ADDED
@@ -0,0 +1,565 @@
"""
Maya-1-Voice vLLM Streaming Inference - Standalone Reference Implementation

This is a complete, self-contained example of running the Maya-1-Voice TTS model with vLLM and SNAC.
It demonstrates streaming audio generation with a sliding-window approach for smooth playback.

Requirements:
    pip install vllm transformers torch snac numpy

Usage:
    python vllm_streaming_inference.py

Author: Maya-1-Voice Team
License: MIT
"""

import asyncio
import time
import uuid
from typing import List, Optional, AsyncGenerator

import numpy as np
import torch
from transformers import AutoTokenizer
from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams
from snac import SNAC


# ============================================================================
# CONSTANTS
# ============================================================================

# Special control tokens
CODE_START_TOKEN_ID = 128257  # Start of Speech (SOS)
CODE_END_TOKEN_ID = 128258    # End of Speech (EOS) - stop token for audio
CODE_TOKEN_OFFSET = 128266    # Start of SNAC codes

# SNAC token range (7 tokens per frame, 4096 codes per level)
SNAC_MIN_ID = 128266
SNAC_MAX_ID = 156937  # 128266 + (7 * 4096) - 1

# SNAC configuration
SNAC_MODEL_NAME = "hubertsiuzdak/snac_24khz"
SNAC_SAMPLE_RATE = 24000
SNAC_TOKENS_PER_FRAME = 7

# Generation parameters
DEFAULT_TEMPERATURE = 0.4
DEFAULT_TOP_P = 0.9
DEFAULT_MAX_TOKENS = 2000
DEFAULT_MIN_TOKENS = 28  # At least 4 SNAC frames
DEFAULT_REPETITION_PENALTY = 1.1

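# Worked example (illustrative values, not from a real generation): a generated
# token id of 130000 falls inside [SNAC_MIN_ID, SNAC_MAX_ID], and its underlying
# SNAC code is recovered as (130000 - CODE_TOKEN_OFFSET) % 4096 = 1734. Seven
# such tokens form one frame, which the decoder below expands into 1 coarse,
# 2 medium, and 4 fine codes.
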
# ============================================================================
# SNAC DECODER
# ============================================================================

class SNACDecoder:
    """
    Decodes SNAC tokens (7-token frames) to audio waveforms.

    The unpacking logic converts flat 7-token frames back to hierarchical
    3-level SNAC codes (matching the training preprocessing exactly).
    """

    def __init__(self, device: str = "cuda"):
        """Initialize SNAC decoder with the 24kHz model."""
        self.device = device
        print(f"🎵 Loading SNAC 24kHz model to {device}...")
        self.snac_model = SNAC.from_pretrained(SNAC_MODEL_NAME).eval().to(device)
        print("✅ SNAC decoder initialized")

    def unpack_snac_from_7(self, vocab_ids: List[int]) -> List[List[int]]:
        """
        Unpack 7-token SNAC frames to 3 hierarchical levels.

        This is the EXACT INVERSE of the training preprocessing.

        Frame structure (7 tokens per frame):
            [slot0, slot1, slot2, slot3, slot4, slot5, slot6]

        Unpacking to [L1, L2, L3]:
            - slot0 -> L1[i]     (coarse: 1x rate)
            - slot1 -> L2[2*i]   (medium: 2x rate, even)
            - slot2 -> L3[4*i+0] (fine: 4x rate)
            - slot3 -> L3[4*i+1]
            - slot4 -> L2[2*i+1] (medium: odd)
            - slot5 -> L3[4*i+2]
            - slot6 -> L3[4*i+3]

        Args:
            vocab_ids: List of SNAC token IDs (128266-156937), length divisible by 7

        Returns:
            [L1, L2, L3] where L1 has n, L2 has 2n, and L3 has 4n elements
        """
        # Remove EOS token if present
        if vocab_ids and vocab_ids[-1] == CODE_END_TOKEN_ID:
            vocab_ids = vocab_ids[:-1]

        # Ensure complete frames
        frames = len(vocab_ids) // SNAC_TOKENS_PER_FRAME
        vocab_ids = vocab_ids[:frames * SNAC_TOKENS_PER_FRAME]

        if frames == 0:
            return [[], [], []]

        l1, l2, l3 = [], [], []

        for i in range(frames):
            slots = vocab_ids[i * 7:(i + 1) * 7]

            # Subtract offset and mod 4096 to get original SNAC codes
            l1.append((slots[0] - CODE_TOKEN_OFFSET) % 4096)
            l2.extend([
                (slots[1] - CODE_TOKEN_OFFSET) % 4096,  # Even
                (slots[4] - CODE_TOKEN_OFFSET) % 4096,  # Odd
            ])
            l3.extend([
                (slots[2] - CODE_TOKEN_OFFSET) % 4096,
                (slots[3] - CODE_TOKEN_OFFSET) % 4096,
                (slots[5] - CODE_TOKEN_OFFSET) % 4096,
                (slots[6] - CODE_TOKEN_OFFSET) % 4096,
            ])

        return [l1, l2, l3]

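    # Example (hypothetical codes, for illustration only): a single frame whose
    # slot codes are [c0, c1, c2, c3, c4, c5, c6] unpacks to
    #     L1 = [c0], L2 = [c1, c4], L3 = [c2, c3, c5, c6]
    # i.e. unpack_snac_from_7 on 7 tokens yields lists of length 1, 2, and 4.
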
    @torch.inference_mode()
    def decode(
        self,
        snac_tokens: List[int],
        use_sliding_window: bool = False
    ) -> Optional[np.ndarray]:
        """
        Decode SNAC tokens to an audio waveform.

        Args:
            snac_tokens: List of SNAC token IDs (7*n tokens)
            use_sliding_window: If True, return only the middle 2048 samples
                (for smooth streaming without pops/clicks)

        Returns:
            Audio waveform as a float32 numpy array, 24kHz mono
        """
        if len(snac_tokens) < SNAC_TOKENS_PER_FRAME:
            return None

        # Unpack to 3 hierarchical levels
        levels = self.unpack_snac_from_7(snac_tokens)

        if not levels[0]:
            return None

        # Convert to tensors
        codes = [
            torch.tensor(level, dtype=torch.long, device=self.device).unsqueeze(0)
            for level in levels
        ]

        # Decode through SNAC quantizer + decoder
        z_q = self.snac_model.quantizer.from_codes(codes)
        audio = self.snac_model.decoder(z_q)

        # Extract audio: [batch, 1, samples] -> [samples]
        audio = audio[0, 0].cpu().numpy()

        # Sliding window mode: keep the middle 2048 samples only.
        # This eliminates popping/cracking in streaming by overlapping windows.
        if use_sliding_window and len(audio) >= 4096:
            audio = audio[2048:4096]

        return audio

    def decode_to_bytes(
        self,
        snac_tokens: List[int],
        use_sliding_window: bool = False
    ) -> Optional[bytes]:
        """
        Decode SNAC tokens to audio bytes (int16 PCM).

        Args:
            snac_tokens: List of SNAC token IDs
            use_sliding_window: Use sliding window for smooth streaming

        Returns:
            Audio as bytes (int16 PCM, 24kHz mono)
        """
        audio = self.decode(snac_tokens, use_sliding_window=use_sliding_window)

        if audio is None:
            return None

        # Convert float32 to int16 PCM (clip first to avoid integer overflow)
        audio_int16 = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)
        return audio_int16.tobytes()

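# Note on chunk size: in sliding-window mode each decode covers 4 frames
# (28 tokens) and only the middle 2048 samples are kept, so every streamed
# chunk is 2048 samples of 24 kHz audio, i.e. roughly 85 ms (2048 / 24000 s).
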
# ============================================================================
# CUSTOM LOGITS PROCESSOR
# ============================================================================

class OnlyAudioAfterSOS:
    """
    Restricts the vocabulary to SNAC codes + EOS after the SOS token.

    This prevents the model from generating text tokens during the audio phase,
    which would cause "hallucination" where the model repeats description text
    instead of generating proper audio codes.
    """

    def __init__(self):
        self._seen_sos = False

    def __call__(
        self,
        prompt_token_ids: List[int],
        generated_token_ids: List[int],
        logits: torch.Tensor,
    ) -> torch.Tensor:
        """
        Apply the constraint: after SOS, only allow SNAC codes + EOS.

        Args:
            prompt_token_ids: Original prompt token IDs
            generated_token_ids: Tokens generated so far
            logits: Logits for the next token [vocab_size]

        Returns:
            Modified logits with masked tokens
        """
        # Check if SOS has been generated
        if not self._seen_sos:
            all_token_ids = prompt_token_ids + generated_token_ids
            if CODE_START_TOKEN_ID in all_token_ids:
                self._seen_sos = True
            else:
                return logits  # No constraint yet

        # Apply constraint: mask all tokens except SNAC codes + EOS
        mask = torch.full_like(logits, float('-inf'))
        mask[SNAC_MIN_ID:SNAC_MAX_ID + 1] = 0  # Allow SNAC codes
        mask[CODE_END_TOKEN_ID] = 0            # Allow EOS

        return logits + mask

    def reset(self):
        """Reset state for reuse across generations."""
        self._seen_sos = False

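# Expected shape of a generation (as implied by the constants and the processor
# above): the model emits the Start-of-Speech token (128257), then a multiple of
# 7 SNAC code tokens (128266-156937), and finally the End-of-Speech token
# (128258), which is also the stop token passed to vLLM below.
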
# ============================================================================
# MAYA-1-VOICE MODEL
# ============================================================================

class Maya1VoiceModel:
    """
    Maya-1-Voice TTS model with the vLLM inference engine.

    Handles model loading, tokenizer initialization, and vLLM engine setup.
    """

    def __init__(
        self,
        model_path: str,
        dtype: str = "bfloat16",
        max_model_len: int = 8192,
        gpu_memory_utilization: float = 0.85,
    ):
        """
        Initialize the Maya-1-Voice model with vLLM.

        Args:
            model_path: Path to the model checkpoint (local or Hugging Face)
            dtype: Model precision (bfloat16 recommended)
            max_model_len: Maximum sequence length
            gpu_memory_utilization: GPU memory fraction to use (0.0-1.0)
        """
        self.model_path = model_path

        print("🚀 Initializing Maya-1-Voice model")
        print(f"📁 Model: {model_path}")
        print(f"🔢 Dtype: {dtype}")

        # Load tokenizer (must be from the checkpoint with emotion tags)
        print("📝 Loading tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
        )
        print(f"✅ Tokenizer loaded: {len(self.tokenizer)} tokens")

        # Initialize vLLM async engine
        print("🔧 Initializing vLLM engine...")
        engine_args = AsyncEngineArgs(
            model=model_path,
            tokenizer=model_path,
            dtype=dtype,
            max_model_len=max_model_len,
            gpu_memory_utilization=gpu_memory_utilization,
            trust_remote_code=True,
        )

        self.engine = AsyncLLMEngine.from_engine_args(engine_args)
        print("✅ vLLM engine ready")

    def build_prompt(self, description: str, text: str) -> str:
        """
        Build a prompt in the Maya-1-Voice format.

        Format: <description="..."> text

        The model expects:
            1. A description of the voice/character
            2. The text to synthesize (optionally with <emotion> tags)

        Args:
            description: Voice description
                Example: "Realistic male voice in the 30s age with american accent.
                          Normal pitch, warm timbre, conversational pacing."
            text: Text to synthesize
                Example: "Hello world! <excited> This is amazing!"

        Returns:
            Formatted prompt string
        """
        return f'<description="{description}"> {text}'

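# For the example arguments in the build_prompt docstring, the formatted prompt is:
#   <description="Realistic male voice in the 30s age with american accent.
#   Normal pitch, warm timbre, conversational pacing."> Hello world! <excited> This is amazing!
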
# ============================================================================
# STREAMING PIPELINE
# ============================================================================

class Maya1VoiceStreamingPipeline:
    """
    Streaming TTS pipeline using a sliding-window approach.

    This generates smooth audio by:
        1. Streaming tokens from vLLM as they are generated
        2. Every 7 tokens, decoding the last 28 tokens (4 frames) - a sliding window
        3. Keeping only the middle 2048 samples from each decode
        4. Creating natural overlap between chunks for artifact-free playback
    """

    def __init__(self, model: Maya1VoiceModel, snac_decoder: SNACDecoder):
        """Initialize the streaming pipeline."""
        self.model = model
        self.snac_decoder = snac_decoder
        print("🌊 Maya-1-Voice streaming pipeline initialized")

    async def generate_speech_stream(
        self,
        description: str,
        text: str,
        temperature: float = DEFAULT_TEMPERATURE,
        top_p: float = DEFAULT_TOP_P,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
    ) -> AsyncGenerator[bytes, None]:
        """
        Generate speech audio with streaming.

        Args:
            description: Voice/character description
            text: Text to synthesize (with optional <emotion> tags)
            temperature: Sampling temperature (lower = more stable)
            top_p: Nucleus sampling threshold
            max_tokens: Maximum number of SNAC tokens to generate
            repetition_penalty: Penalty to prevent repetition loops

        Yields:
            Audio chunks as bytes (int16 PCM, 24kHz mono)
        """
        print("\n🌊 Starting streaming generation")
        print(f"📝 Description: {description[:80]}...")
        print(f"💬 Text: {text}")

        # Build prompt
        prompt = self.model.build_prompt(description, text)

        # Configure sampling with the custom logits processor
        logits_processor = OnlyAudioAfterSOS()

        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            min_tokens=DEFAULT_MIN_TOKENS,
            repetition_penalty=repetition_penalty,
            stop_token_ids=[CODE_END_TOKEN_ID],    # Stop on audio EOS
            logits_processors=[logits_processor],  # Constrain to audio tokens
        )

        print(f"🎲 Sampling: temp={temperature}, top_p={top_p}, max_tokens={max_tokens}")

        # Token buffer for the sliding window
        token_buffer = []
        total_tokens = 0
        total_chunks = 0

        # Generate with vLLM
        request_id = f"maya1voice-{uuid.uuid4().hex[:8]}-{int(time.time() * 1000000)}"

        results_generator = self.model.engine.generate(
            prompt=prompt,
            sampling_params=sampling_params,
            request_id=request_id,
        )

        # Stream tokens with sliding-window decoding
        async for request_output in results_generator:
            generated_ids = request_output.outputs[0].token_ids

            # Process only new tokens
            new_tokens = generated_ids[total_tokens:]
            total_tokens = len(generated_ids)

            # Filter and buffer SNAC tokens only
            for token_id in new_tokens:
                if SNAC_MIN_ID <= token_id <= SNAC_MAX_ID:
                    token_buffer.append(token_id)

                    # Sliding window: process every 7 tokens once the buffer holds > 27.
                    # Take the last 28 tokens (4 frames) for smooth overlap.
                    if len(token_buffer) % 7 == 0 and len(token_buffer) > 27:
                        window_tokens = token_buffer[-28:]

                        # Decode with sliding window (returns the middle 2048 samples)
                        audio_bytes = self.snac_decoder.decode_to_bytes(
                            window_tokens,
                            use_sliding_window=True
                        )

                        if audio_bytes:
                            total_chunks += 1
                            if total_chunks == 1:
                                print(f"🎵 First chunk decoded ({len(audio_bytes)} bytes)")
                            yield audio_bytes

        print(f"✅ Streaming complete: {total_tokens} tokens -> {total_chunks} chunks")

        # Reset logits processor for the next generation
        logits_processor.reset()

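# Sliding-window timeline: once the buffer reaches 28, 35, 42, ... tokens, the
# trailing 28 tokens are decoded, so consecutive windows overlap by 21 tokens
# (3 frames). Emitting only the middle quarter of each decoded waveform means
# chunk boundaries always come from well-contextualized audio, which is what
# avoids pops and clicks at the seams.
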
# ============================================================================
# MAIN EXAMPLE
# ============================================================================

async def main():
    """
    Example usage of Maya-1-Voice streaming inference.

    This demonstrates:
        1. Model initialization
        2. SNAC decoder setup
        3. Streaming generation
        4. Audio chunk handling
    """

    # Configuration
    MODEL_PATH = "maya-research/maya-1-voice"  # Update with your model
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    print("=" * 80)
    print("Maya-1-Voice vLLM Streaming Inference Example")
    print("=" * 80)

    # Initialize model
    model = Maya1VoiceModel(
        model_path=MODEL_PATH,
        dtype="bfloat16",
        max_model_len=8192,
        gpu_memory_utilization=0.85,
    )

    # Initialize SNAC decoder
    snac_decoder = SNACDecoder(device=DEVICE)

    # Create pipeline
    pipeline = Maya1VoiceStreamingPipeline(model, snac_decoder)

    # Example 1: Professional voice
    description = (
        "Realistic male voice in the 30s age with american accent. "
        "Normal pitch, warm timbre, conversational pacing, neutral tone delivery at med intensity."
    )
    text = "Hello! This is a test of the Maya-1-Voice text-to-speech system."

    print(f"\n{'='*80}")
    print("Example 1: Professional Voice")
    print(f"{'='*80}")

    audio_chunks = []
    async for chunk in pipeline.generate_speech_stream(
        description=description,
        text=text,
        temperature=0.4,
        max_tokens=500,
    ):
        audio_chunks.append(chunk)
        print(f"📦 Received chunk {len(audio_chunks)}: {len(chunk)} bytes")

    # Combine chunks
    full_audio = b''.join(audio_chunks)
    print(f"\n✅ Total audio: {len(full_audio)} bytes ({len(full_audio)//2} samples, {len(full_audio)/2/24000:.2f}s)")

    # Save audio (the wave module is part of the Python standard library)
    try:
        import wave
        output_file = "output_example1.wav"
        with wave.open(output_file, 'wb') as wav:
            wav.setnchannels(1)      # Mono
            wav.setsampwidth(2)      # 16-bit
            wav.setframerate(24000)  # 24kHz
            wav.writeframes(full_audio)
        print(f"💾 Saved to {output_file}")
    except OSError as e:
        print(f"⚠️ Could not save audio: {e}")

    # Example 2: Character voice with emotions
    print(f"\n{'='*80}")
    print("Example 2: Character Voice with Emotions")
    print(f"{'='*80}")

    description = (
        "Creative, dark_villain character. Male voice in their 40s with british accent. "
        "Low pitch, gravelly timbre, slow pacing, angry tone at high intensity."
    )
    text = "The darkness isn't coming... <angry> it's already here!"

    audio_chunks = []
    async for chunk in pipeline.generate_speech_stream(
        description=description,
        text=text,
        temperature=0.5,
        max_tokens=800,
    ):
        audio_chunks.append(chunk)
        print(f"📦 Received chunk {len(audio_chunks)}: {len(chunk)} bytes")

    full_audio = b''.join(audio_chunks)
    print(f"\n✅ Total audio: {len(full_audio)} bytes ({len(full_audio)//2} samples, {len(full_audio)/2/24000:.2f}s)")

    # Save audio
    try:
        import wave
        output_file = "output_example2.wav"
        with wave.open(output_file, 'wb') as wav:
            wav.setnchannels(1)
            wav.setsampwidth(2)
            wav.setframerate(24000)
            wav.writeframes(full_audio)
        print(f"💾 Saved to {output_file}")
    except OSError:
        pass

    print(f"\n{'='*80}")
    print("🎉 Examples complete!")
    print(f"{'='*80}")


if __name__ == "__main__":
    # Run the async example
    asyncio.run(main())
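
# A possible extension (not part of this script): playing chunks live instead of
# collecting them into a list. A minimal sketch, assuming the optional
# `sounddevice` package is installed (`pip install sounddevice`):
#
#     import sounddevice as sd
#
#     async def play_live(pipeline, description, text):
#         # A 24 kHz mono int16 output stream matches the bytes yielded by the pipeline
#         with sd.RawOutputStream(samplerate=SNAC_SAMPLE_RATE, channels=1, dtype='int16') as stream:
#             async for chunk in pipeline.generate_speech_stream(description, text):
#                 stream.write(chunk)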