Gabriel Bibbó committed
Commit a21e04b · 1 Parent(s): 43be67f

GitHub-faithful implementation - 32kHz, 2048 FFT, per-model delays, 80ms gaps

Files changed (1):
  1. app.py +165 -374
app.py CHANGED
@@ -101,6 +101,10 @@ class OptimizedSileroVAD:
101
  print(f"❌ Error loading {self.model_name}: {e}")
102
  self.model = None
103
 
104
  def predict(self, audio: np.ndarray, timestamp: float = 0.0) -> VADResult:
105
  start_time = time.time()
106
 
@@ -108,21 +112,11 @@ class OptimizedSileroVAD:
108
  return VADResult(0.0, False, f"{self.model_name} (unavailable)", time.time() - start_time, timestamp)
109
 
110
  try:
111
- if len(audio.shape) > 1:
112
- audio = audio.mean(axis=1)
113
-
114
- # Silero expects chunks of 512 samples for 16kHz
115
- required_samples = 512
116
- if len(audio) != required_samples:
117
- if len(audio) > required_samples:
118
- start_idx = (len(audio) - required_samples) // 2
119
- audio_chunk = audio[start_idx:start_idx + required_samples]
120
- else:
121
- audio_chunk = np.pad(audio, (0, required_samples - len(audio)), 'constant')
122
- else:
123
- audio_chunk = audio
124
 
125
- audio_tensor = torch.FloatTensor(audio_chunk).unsqueeze(0)
 
 
126
 
127
  with torch.no_grad():
128
  speech_prob = self.model(audio_tensor, self.sample_rate).item()
@@ -133,45 +127,35 @@ class OptimizedSileroVAD:
133
  return VADResult(speech_prob, is_speech, self.model_name, processing_time, timestamp)
134
 
135
  except Exception as e:
136
- print(f"Error in {self.model_name}: {e}")
137
  return VADResult(0.0, False, self.model_name, time.time() - start_time, timestamp)
138
 
139
  class OptimizedWebRTCVAD:
140
  def __init__(self):
141
  self.model_name = "WebRTC-VAD"
142
  self.sample_rate = 16000
143
- self.frame_duration = 30 # Valid frame size: 10, 20, or 30 ms
144
  self.frame_size = int(self.sample_rate * self.frame_duration / 1000)
145
 
146
  if WEBRTC_AVAILABLE:
147
  try:
148
- self.vad = webrtcvad.Vad(3) # Aggressiveness level 3
149
  print(f"✅ {self.model_name} loaded successfully")
150
- except:
151
- self.vad = None
152
- else:
153
- self.vad = None
154
 
155
  def predict(self, audio: np.ndarray, timestamp: float = 0.0) -> VADResult:
156
  start_time = time.time()
157
 
158
  if self.vad is None or len(audio) == 0:
159
- energy = np.sum(audio ** 2) if len(audio) > 0 else 0
160
- threshold = 0.01
161
- probability = min(energy / threshold, 1.0)
162
- is_speech = energy > threshold
163
- return VADResult(probability, is_speech, f"{self.model_name} (fallback)", time.time() - start_time, timestamp)
164
 
165
  try:
166
- if len(audio.shape) > 1:
167
- audio = audio.mean(axis=1)
168
-
169
  audio_int16 = (audio * 32767).astype(np.int16)
170
 
171
- speech_frames = 0
172
- total_frames = 0
173
 
174
- # Corrected loop to process the last complete frame
175
  for i in range(0, len(audio_int16) - self.frame_size + 1, self.frame_size):
176
  frame = audio_int16[i:i + self.frame_size].tobytes()
177
  if self.vad.is_speech(frame, self.sample_rate):
@@ -179,48 +163,37 @@ class OptimizedWebRTCVAD:
179
  total_frames += 1
180
 
181
  probability = speech_frames / max(total_frames, 1)
182
- is_speech = probability > 0.3 # Default threshold for WebRTC
183
 
184
  return VADResult(probability, is_speech, self.model_name, time.time() - start_time, timestamp)
185
 
186
  except Exception as e:
187
- print(f"Error in {self.model_name}: {e}")
188
  return VADResult(0.0, False, self.model_name, time.time() - start_time, timestamp)
189
 
190
  class OptimizedEPANNs:
191
  def __init__(self):
192
  self.model_name = "E-PANNs"
193
- self.sample_rate = 16000 # Works with the main sample rate
194
  print(f"✅ {self.model_name} initialized")
195
 
196
  def predict(self, audio: np.ndarray, timestamp: float = 0.0) -> VADResult:
197
  start_time = time.time()
 
198
 
199
  try:
200
- if len(audio) == 0:
201
- return VADResult(0.0, False, self.model_name, time.time() - start_time, timestamp)
202
-
203
- if len(audio.shape) > 1:
204
- audio = audio.mean(axis=1)
205
-
206
  if LIBROSA_AVAILABLE:
207
  mel_spec = librosa.feature.melspectrogram(y=audio, sr=self.sample_rate, n_mels=64)
208
  energy = np.mean(librosa.power_to_db(mel_spec, ref=np.max))
209
- spectral_centroid = np.mean(librosa.feature.spectral_centroid(y=audio, sr=self.sample_rate))
210
- speech_score = (energy + 100) / 50 + spectral_centroid / 10000
211
  else:
212
  from scipy import signal
213
- f, t, Sxx = signal.spectrogram(audio, self.sample_rate)
214
  energy = np.mean(10 * np.log10(Sxx + 1e-10))
215
- speech_score = (energy + 100) / 50
216
-
217
  probability = np.clip(speech_score, 0, 1)
218
- is_speech = probability > 0.6
219
-
220
- return VADResult(probability, is_speech, self.model_name, time.time() - start_time, timestamp)
221
 
 
222
  except Exception as e:
223
- print(f"Error in {self.model_name}: {e}")
224
  return VADResult(0.0, False, self.model_name, time.time() - start_time, timestamp)
225
 
226
  class OptimizedPANNs:
@@ -237,47 +210,32 @@ class OptimizedPANNs:
237
  if PANNS_AVAILABLE:
238
  self.model = AudioTagging(checkpoint_path=None, device=self.device)
239
  print(f"✅ {self.model_name} loaded successfully")
240
- else:
241
- print(f"⚠️ {self.model_name} not available, using fallback")
242
- self.model = None
243
  except Exception as e:
244
  print(f"❌ Error loading {self.model_name}: {e}")
245
  self.model = None
246
 
247
  def predict(self, audio: np.ndarray, timestamp: float = 0.0) -> VADResult:
248
- if timestamp > 0 and self.cached_clip_prob is not None:
249
- return VADResult(self.cached_clip_prob,
250
- self.cached_clip_prob > 0.5,
251
- self.model_name, 0.0, timestamp)
252
 
253
  start_time = time.time()
254
-
255
  if self.model is None or len(audio) == 0:
256
  return VADResult(0.0, False, f"{self.model_name} (fallback)", time.time() - start_time, timestamp)
257
 
258
  try:
259
- if len(audio.shape) > 1:
260
- audio = audio.mean(axis=1)
261
-
262
- # Correctly calculate probability using all speech-related labels
263
- clip_probs, _ = self.model.inference(audio[np.newaxis, :],
264
- input_sr=self.sample_rate) # API 1.3
265
-
266
- speech_idx = [i for i, lbl in enumerate(labels)
267
- if 'speech' in lbl.lower() or 'voice' in lbl.lower()]
268
- if not speech_idx:
269
- speech_idx = [labels.index('Speech')]
270
 
271
  speech_prob = clip_probs[0, speech_idx].mean().item()
272
  self.cached_clip_prob = float(speech_prob)
273
- return VADResult(self.cached_clip_prob,
274
- self.cached_clip_prob > 0.5,
275
- self.model_name,
276
- time.time() - start_time,
277
- timestamp)
278
 
 
279
  except Exception as e:
280
- print(f"Error in {self.model_name}: {e}")
281
  return VADResult(0.0, False, f"{self.model_name} (error)", time.time() - start_time, timestamp)
282
 
283
  class OptimizedAST:
@@ -293,56 +251,37 @@ class OptimizedAST:
293
  def load_model(self):
294
  try:
295
  if AST_AVAILABLE:
296
- model_name = "MIT/ast-finetuned-audioset-10-10-0.4593"
297
- self.feature_extractor = ASTFeatureExtractor.from_pretrained(model_name)
298
- self.model = ASTForAudioClassification.from_pretrained(model_name)
299
- self.model.to(self.device)
300
- self.model.eval()
301
  print(f"✅ {self.model_name} loaded successfully")
302
- else:
303
- print(f"⚠️ {self.model_name} not available, using fallback")
304
- self.model = None
305
  except Exception as e:
306
  print(f"❌ Error loading {self.model_name}: {e}")
307
  self.model = None
308
 
309
  def predict(self, audio: np.ndarray, timestamp: float = 0.0) -> VADResult:
310
- if timestamp > 0 and self.cached_clip_prob is not None:
311
- return VADResult(self.cached_clip_prob,
312
- self.cached_clip_prob > 0.5,
313
- self.model_name, 0.0, timestamp)
314
 
315
  start_time = time.time()
316
-
317
- if self.model is None or len(audio) == 0:
318
  return VADResult(0.0, False, f"{self.model_name} (fallback)", time.time() - start_time, timestamp)
319
 
320
  try:
321
- if len(audio.shape) > 1:
322
- audio = audio.mean(axis=1)
323
-
324
- inputs = self.feature_extractor(audio, sampling_rate=self.sample_rate, return_tensors="pt")
325
- inputs = {k: v.to(self.device) for k, v in inputs.items()}
326
-
327
  with torch.no_grad():
328
- outputs = self.model(**inputs)
329
- logits = outputs.logits
330
- probs = torch.sigmoid(logits)
331
 
332
- # Correctly calculate probability using the model's label mapping
333
  label2id = self.model.config.label2id
334
- speech_idx = [idx for lbl, idx in label2id.items()
335
- if 'speech' in lbl.lower() or 'voice' in lbl.lower()]
336
  speech_prob = probs[0, speech_idx].mean().item()
337
  self.cached_clip_prob = float(speech_prob)
338
- return VADResult(self.cached_clip_prob,
339
- self.cached_clip_prob > 0.5,
340
- self.model_name,
341
- time.time() - start_time,
342
- timestamp)
343
 
 
344
  except Exception as e:
345
- print(f"Error in {self.model_name}: {e}")
346
  return VADResult(0.0, False, f"{self.model_name} (error)", time.time() - start_time, timestamp)
347
 
348
  # ===== AUDIO PROCESSOR =====
@@ -351,333 +290,185 @@ class AudioProcessor:
351
  def __init__(self, sample_rate=16000):
352
  self.sample_rate = sample_rate
353
 
354
- # Corrected STFT parameters for better temporal resolution
355
- self.n_fft = 1024 # 64 ms window @ 16 kHz
356
- self.hop_length = 256 # 16 ms hop (win/4), Librosa recommendation
357
  self.n_mels = 128
358
  self.fmin = 20
359
  self.fmax = 8000
360
 
361
- # Corrected windowing for lightweight models
362
- self.window_size = 0.048 # 48 ms
363
- self.hop_size = 0.024 # 24 ms
364
-
365
- self.delay_compensation = 0.0
366
- self.correlation_threshold = 0.7
367
-
368
  def process_audio(self, audio):
369
- if audio is None:
370
- return np.array([])
371
-
372
  try:
373
- if isinstance(audio, tuple):
374
- sample_rate, audio_data = audio
375
- if sample_rate != self.sample_rate and LIBROSA_AVAILABLE:
376
- audio_data = librosa.resample(audio_data.astype(float),
377
- orig_sr=sample_rate,
378
- target_sr=self.sample_rate)
379
- else:
380
- audio_data = audio
381
-
382
- if len(audio_data.shape) > 1:
383
- audio_data = audio_data.mean(axis=1)
384
-
385
- if np.max(np.abs(audio_data)) > 0:
386
- audio_data = audio_data / np.max(np.abs(audio_data))
387
-
388
  return audio_data
389
-
390
  except Exception as e:
391
- print(f"Audio processing error: {e}")
392
  return np.array([])
393
 
394
  def compute_high_res_spectrogram(self, audio_data):
395
  try:
396
  if LIBROSA_AVAILABLE and len(audio_data) > 0:
397
- stft = librosa.stft(
398
- audio_data,
399
- n_fft=self.n_fft,
400
- hop_length=self.hop_length,
401
- win_length=self.n_fft,
402
- window='hann',
403
- center=False
404
- )
405
-
406
- power_spec = np.abs(stft) ** 2
407
-
408
- mel_basis = librosa.filters.mel(
409
- sr=self.sample_rate,
410
- n_fft=self.n_fft,
411
- n_mels=self.n_mels,
412
- fmin=self.fmin,
413
- fmax=self.fmax
414
- )
415
-
416
- mel_spec = np.dot(mel_basis, power_spec)
417
  mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
418
-
419
- time_frames = np.arange(mel_spec_db.shape[1]) * self.hop_length / self.sample_rate
420
-
421
  return mel_spec_db, time_frames
422
- else: # Fallback if Librosa is not available
423
- from scipy import signal
424
- f, t, Sxx = signal.spectrogram(
425
- audio_data,
426
- self.sample_rate,
427
- nperseg=self.n_fft,
428
- noverlap=self.n_fft - self.hop_length,
429
- window='hann'
430
- )
431
- mel_spec_db = 10 * np.log10(Sxx + 1e-10)
432
- return mel_spec_db, t
433
-
434
  except Exception as e:
435
- print(f"Spectrogram computation error: {e}")
436
- dummy_spec = np.zeros((self.n_mels, 200))
437
- dummy_time = np.linspace(0, len(audio_data) / self.sample_rate, 200)
438
- return dummy_spec, dummy_time
439
 
440
  def detect_onset_offset_advanced(self, vad_results: List[VADResult], threshold: float = 0.5) -> List[OnsetOffset]:
441
  onsets_offsets = []
 
442
 
443
- if len(vad_results) < 2:
444
- return onsets_offsets
445
-
446
- models = {}
447
- for result in vad_results:
448
- if result.model_name not in models:
449
- models[result.model_name] = []
450
- models[result.model_name].append(result)
451
-
452
- for model_name, results in models.items():
453
- if len(results) < 2:
454
- continue
455
-
456
- results.sort(key=lambda x: x.timestamp)
457
 
458
  timestamps = np.array([r.timestamp for r in results])
459
  probabilities = np.array([r.probability for r in results])
460
 
461
- # Hysteresis thresholding
462
- upper_thresh = threshold + 0.1
463
- lower_thresh = threshold - 0.1
464
-
465
- in_speech_segment = False
466
- current_onset_time = -1
467
-
468
- for i in range(len(results)):
469
- curr_prob = probabilities[i]
470
- curr_time = timestamps[i]
471
 
472
- if not in_speech_segment and curr_prob > upper_thresh:
473
- in_speech_segment = True
474
- current_onset_time = curr_time - self.delay_compensation
475
-
476
- elif in_speech_segment and curr_prob < lower_thresh:
477
- in_speech_segment = False
478
- if current_onset_time >= 0:
479
- offset_time = curr_time - self.delay_compensation
480
- onsets_offsets.append(OnsetOffset(
481
- onset_time=max(0, current_onset_time),
482
- offset_time=offset_time,
483
- model_name=model_name,
484
- confidence=np.mean(probabilities[(timestamps >= current_onset_time) & (timestamps <= offset_time)])
485
- ))
486
- current_onset_time = -1
487
-
488
- if in_speech_segment and current_onset_time >= 0:
489
- onsets_offsets.append(OnsetOffset(
490
- onset_time=max(0, current_onset_time),
491
- offset_time=timestamps[-1],
492
- model_name=model_name,
493
- confidence=np.mean(probabilities[timestamps >= current_onset_time])
494
- ))
495
-
496
  return onsets_offsets
497
 
498
- # ===== ENHANCED VISUALIZATION =====
499
 
500
  def create_realtime_plot(audio_data: np.ndarray, vad_results: List[VADResult],
501
  onsets_offsets: List[OnsetOffset], processor: AudioProcessor,
502
  model_a: str, model_b: str, threshold: float):
503
 
504
- if not PLOTLY_AVAILABLE or audio_data is None or len(audio_data) == 0:
505
- return go.Figure().update_layout(title="No data to display")
 
 
506
 
507
- try:
508
- mel_spec_db, time_frames = processor.compute_high_res_spectrogram(audio_data)
509
- freq_axis = np.linspace(processor.fmin, processor.fmax, processor.n_mels)
510
-
511
- fig = make_subplots(
512
- rows=2, cols=1,
513
- subplot_titles=(f"Model A: {model_a}", f"Model B: {model_b}"),
514
- vertical_spacing=0.05,
515
- shared_xaxes=True,
516
- specs=[[{"secondary_y": True}], [{"secondary_y": True}]]
517
- )
518
-
519
- # Shared heatmap settings
520
- heatmap_args = dict(
521
- z=mel_spec_db, x=time_frames, y=freq_axis,
522
- colorscale='Viridis', showscale=False,
523
- hovertemplate='Time: %{x:.2f}s<br>Freq: %{y:.0f}Hz<br>Power: %{z:.1f}dB<extra></extra>'
524
- )
525
-
526
- fig.add_trace(go.Heatmap(**heatmap_args, name=f'Spectrogram {model_a}'), row=1, col=1)
527
- fig.add_trace(go.Heatmap(**heatmap_args, name=f'Spectrogram {model_b}'), row=2, col=1)
528
-
529
- # Data separation
530
- model_a_data = {'times': [], 'probs': []}
531
- model_b_data = {'times': [], 'probs': []}
532
- for r in vad_results:
533
- if r.model_name.startswith(model_a):
534
- model_a_data['times'].append(r.timestamp)
535
- model_a_data['probs'].append(r.probability)
536
- elif r.model_name.startswith(model_b):
537
- model_b_data['times'].append(r.timestamp)
538
- model_b_data['probs'].append(r.probability)
539
-
540
- # Plotting probability curves on secondary Y-axis
541
- if model_a_data['times']:
542
- fig.add_trace(go.Scatter(x=model_a_data['times'], y=model_a_data['probs'], mode='lines',
543
- line=dict(color='yellow', width=3), name=f'{model_a} Probability'),
544
- row=1, col=1, secondary_y=True)
545
- if model_b_data['times']:
546
- fig.add_trace(go.Scatter(x=model_b_data['times'], y=model_b_data['probs'], mode='lines',
547
- line=dict(color='orange', width=3), name=f'{model_b} Probability'),
548
- row=2, col=1, secondary_y=True)
549
-
550
- # Onset/Offset markers
551
- for event in onsets_offsets:
552
- row_num = 1 if event.model_name.startswith(model_a) else 2 if event.model_name.startswith(model_b) else None
553
- if row_num:
554
- fig.add_vline(x=event.onset_time, line=dict(color='lime', width=3), annotation_text='▲', annotation_position="top", row=row_num, col=1)
555
- fig.add_vline(x=event.offset_time, line=dict(color='red', width=3), annotation_text='▼', annotation_position="bottom", row=row_num, col=1)
556
-
557
- # Layout and styling
558
- fig.update_layout(
559
- height=600, title_text="Real-Time Speech Visualizer", showlegend=True,
560
- legend=dict(x=1.05, y=1), plot_bgcolor='black', paper_bgcolor='white'
561
- )
562
- fig.update_xaxes(title_text="Time (seconds)", row=2, col=1)
563
- fig.update_yaxes(title_text="Frequency (Hz)", range=[processor.fmin, processor.fmax], row=1, col=1, secondary_y=False)
564
- fig.update_yaxes(title_text="Frequency (Hz)", range=[processor.fmin, processor.fmax], row=2, col=1, secondary_y=False)
565
-
566
- # Correctly configure secondary axes
567
- fig.update_yaxes(title_text="Probability", range=[0, 1], row=1, col=1, secondary_y=True)
568
- fig.update_yaxes(title_text="Probability", range=[0, 1], row=2, col=1, secondary_y=True)
569
-
570
- return fig
571
-
572
- except Exception as e:
573
- print(f"Visualization error: {e}")
574
- return go.Figure().update_layout(title=f"Visualization Error: {e}")
575
 
576
  # ===== MAIN APPLICATION =====
577
 
578
  class VADDemo:
579
  def __init__(self):
580
- print("🎤 Initializing Real-time VAD Demo with 5 models...")
581
  self.processor = AudioProcessor()
582
  self.models = {
583
- 'Silero-VAD': OptimizedSileroVAD(),
584
- 'WebRTC-VAD': OptimizedWebRTCVAD(),
585
- 'E-PANNs': OptimizedEPANNs(),
586
- 'PANNs': OptimizedPANNs(),
587
- 'AST': OptimizedAST()
588
  }
589
- print("🎤 Real-time VAD Demo initialized successfully")
590
 
591
  def process_audio_with_events(self, audio, model_a, model_b, threshold):
592
- if audio is None:
593
- return None, "🔇 No audio detected", "Ready to process audio..."
594
-
595
- try:
596
- # Reset cache for heavy models at the start of each new clip processing
597
- for m in ['PANNs', 'AST']:
598
- if m in self.models:
599
- self.models[m].cached_clip_prob = None
600
 
 
601
  processed_audio = self.processor.process_audio(audio)
602
- if len(processed_audio) == 0:
603
- return None, "🎵 Processing audio...", "No audio data processed"
604
 
605
- selected_models = list(set([model_a, model_b]))
606
-
607
- # Pre-compute heavy models once
608
- if 'PANNs' in selected_models:
609
- panns_model = self.models['PANNs']
610
- if LIBROSA_AVAILABLE:
611
- audio_32k = librosa.resample(processed_audio,
612
- orig_sr=self.processor.sample_rate,
613
- target_sr=panns_model.sample_rate)
614
- panns_model.predict(audio_32k, 0.0) # This populates the cache
615
-
616
- if 'AST' in selected_models:
617
- self.models['AST'].predict(processed_audio, 0.0) # This populates the cache
618
 
619
- # Process in windows
620
- window_samples = int(self.processor.sample_rate * self.processor.window_size)
621
- hop_samples = int(self.processor.sample_rate * self.processor.hop_size)
622
  vad_results = []
623
 
624
- for i in range(0, len(processed_audio) - window_samples, hop_samples):
625
  timestamp = i / self.processor.sample_rate
626
- chunk = processed_audio[i:i + window_samples]
627
 
628
- for model_name in selected_models:
629
- result = self.models[model_name].predict(chunk, timestamp)
 
630
  result.is_speech = result.probability > threshold
631
  vad_results.append(result)
632
 
633
  onsets_offsets = self.processor.detect_onset_offset_advanced(vad_results, threshold)
 
634
 
635
- fig = create_realtime_plot(processed_audio, vad_results, onsets_offsets,
636
- self.processor, model_a, model_b, threshold)
637
-
638
- status_msg = "🎙️ SPEECH DETECTED" if any(r.is_speech for r in vad_results) else "🔇 No speech detected"
639
-
640
- details_text = f"Analyzed {len(processed_audio)/self.processor.sample_rate:.2f}s of audio with threshold {threshold:.2f}."
641
 
642
  return fig, status_msg, details_text
643
-
644
  except Exception as e:
645
- print(f"Processing error: {e}")
646
  import traceback
 
647
  return None, f"❌ Error: {e}", traceback.format_exc()
648
 
649
- # Initialize demo
650
  demo_app = VADDemo()
651
-
652
- # ===== GRADIO INTERFACE =====
653
-
654
- def create_interface():
655
- with gr.Blocks(title="VAD Demo", theme=gr.themes.Soft()) as interface:
656
- gr.Markdown("# 🎤 VAD Demo: Real-time Speech Detection Framework v4")
657
-
658
- with gr.Row():
659
- with gr.Column(scale=1):
660
- gr.Markdown("### 🎛️ Controls")
661
- audio_input = gr.Audio(sources=["microphone"], type="numpy", label="Record or Upload Audio")
662
- model_a = gr.Dropdown(["Silero-VAD", "WebRTC-VAD", "E-PANNs", "PANNs", "AST"], value="Silero-VAD", label="Model A (Top Panel)")
663
- model_b = gr.Dropdown(["Silero-VAD", "WebRTC-VAD", "E-PANNs", "PANNs", "AST"], value="PANNs", label="Model B (Bottom Panel)")
664
- threshold_slider = gr.Slider(0.0, 1.0, value=0.5, step=0.01, label="Detection Threshold")
665
- process_btn = gr.Button("Analyze", variant="primary")
666
-
667
- with gr.Column(scale=3):
668
- gr.Markdown("### 📊 Visualization Dashboard")
669
- plot_output = gr.Plot(label="VAD Analysis")
670
- status_display = gr.Textbox(label="Status", interactive=False)
671
- details_output = gr.Textbox(label="Details", lines=5, interactive=False)
672
-
673
- process_btn.click(
674
- fn=demo_app.process_audio_with_events,
675
- inputs=[audio_input, model_a, model_b, threshold_slider],
676
- outputs=[plot_output, status_display, details_output]
677
- )
678
- return interface
679
-
680
- if __name__ == "__main__":
681
- print("🚀 Launching Gradio Interface...")
682
- interface = create_interface()
683
- interface.launch(share=True, debug=False)
 
101
  print(f"❌ Error loading {self.model_name}: {e}")
102
  self.model = None
103
 
104
+ def reset_states(self):
105
+ if self.model:
106
+ self.model.reset_states()
107
+
108
  def predict(self, audio: np.ndarray, timestamp: float = 0.0) -> VADResult:
109
  start_time = time.time()
110
 
 
112
  return VADResult(0.0, False, f"{self.model_name} (unavailable)", time.time() - start_time, timestamp)
113
 
114
  try:
115
+ if len(audio.shape) > 1: audio = audio.mean(axis=1)
116
 
117
+ # Silero expects a specific chunk size, which the main loop should provide.
118
+ # No padding or trimming here.
119
+ audio_tensor = torch.FloatTensor(audio).unsqueeze(0)
120
 
121
  with torch.no_grad():
122
  speech_prob = self.model(audio_tensor, self.sample_rate).item()
 
127
  return VADResult(speech_prob, is_speech, self.model_name, processing_time, timestamp)
128
 
129
  except Exception as e:
130
+ # This can happen if chunk size is wrong, which is now handled in main loop
131
  return VADResult(0.0, False, self.model_name, time.time() - start_time, timestamp)
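The rewritten predict() above drops the padding/trimming logic, so the caller is responsible for handing Silero exactly the chunk length it expects (512 samples at 16 kHz). A minimal sketch of that contract, assuming the stock snakers4/silero-vad torch.hub model (variable names are illustrative, not taken from app.py):

import numpy as np
import torch

# Load the standard Silero VAD model from torch.hub (assumes network access).
model, _ = torch.hub.load('snakers4/silero-vad', 'silero_vad')
SR = 16000
CHUNK = 512  # required chunk length at 16 kHz

audio = np.zeros(SR, dtype=np.float32)  # 1 s of silence as dummy input
model.reset_states()  # clear recurrent state before a new clip, mirroring reset_states() above

for start in range(0, len(audio) - CHUNK + 1, CHUNK):
    chunk = torch.from_numpy(audio[start:start + CHUNK]).unsqueeze(0)  # shape (1, 512)
    prob = model(chunk, SR).item()  # per-chunk speech probability in [0, 1]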
132
 
133
  class OptimizedWebRTCVAD:
134
  def __init__(self):
135
  self.model_name = "WebRTC-VAD"
136
  self.sample_rate = 16000
137
+ self.frame_duration = 10 # 10, 20, or 30 ms. 10ms for higher granularity.
138
  self.frame_size = int(self.sample_rate * self.frame_duration / 1000)
139
 
140
  if WEBRTC_AVAILABLE:
141
  try:
142
+ self.vad = webrtcvad.Vad(3)
143
  print(f"✅ {self.model_name} loaded successfully")
144
+ except: self.vad = None
145
+ else: self.vad = None
 
 
146
 
147
  def predict(self, audio: np.ndarray, timestamp: float = 0.0) -> VADResult:
148
  start_time = time.time()
149
 
150
  if self.vad is None or len(audio) == 0:
151
+ return VADResult(0.0, False, f"{self.model_name} (fallback)", time.time() - start_time, timestamp)
152
 
153
  try:
154
+ if len(audio.shape) > 1: audio = audio.mean(axis=1)
 
 
155
  audio_int16 = (audio * 32767).astype(np.int16)
156
 
157
+ speech_frames, total_frames = 0, 0
 
158
 
 
159
  for i in range(0, len(audio_int16) - self.frame_size + 1, self.frame_size):
160
  frame = audio_int16[i:i + self.frame_size].tobytes()
161
  if self.vad.is_speech(frame, self.sample_rate):
 
163
  total_frames += 1
164
 
165
  probability = speech_frames / max(total_frames, 1)
166
+ is_speech = probability > 0.5
167
 
168
  return VADResult(probability, is_speech, self.model_name, time.time() - start_time, timestamp)
169
 
170
  except Exception as e:
 
171
  return VADResult(0.0, False, self.model_name, time.time() - start_time, timestamp)
172
 
173
  class OptimizedEPANNs:
174
  def __init__(self):
175
  self.model_name = "E-PANNs"
176
+ self.sample_rate = 16000
177
  print(f"✅ {self.model_name} initialized")
178
 
179
  def predict(self, audio: np.ndarray, timestamp: float = 0.0) -> VADResult:
180
  start_time = time.time()
181
+ if len(audio) == 0: return VADResult(0.0, False, self.model_name, time.time() - start_time, timestamp)
182
 
183
  try:
184
  if LIBROSA_AVAILABLE:
185
  mel_spec = librosa.feature.melspectrogram(y=audio, sr=self.sample_rate, n_mels=64)
186
  energy = np.mean(librosa.power_to_db(mel_spec, ref=np.max))
 
 
187
  else:
188
  from scipy import signal
189
+ _, _, Sxx = signal.spectrogram(audio, self.sample_rate)
190
  energy = np.mean(10 * np.log10(Sxx + 1e-10))
191
+
192
+ speech_score = (energy + 100) / 50
193
  probability = np.clip(speech_score, 0, 1)
194
 
195
+ return VADResult(probability, probability > 0.6, self.model_name, time.time() - start_time, timestamp)
196
  except Exception as e:
 
197
  return VADResult(0.0, False, self.model_name, time.time() - start_time, timestamp)
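For reference, the heuristic above maps mean log-mel energy E (in dB) to probability = clip((E + 100) / 50, 0, 1): E = -100 dB maps to 0.0, E = -70 dB lands exactly on the 0.6 decision threshold, and E = -50 dB or louder saturates at 1.0.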
198
 
199
  class OptimizedPANNs:
 
210
  if PANNS_AVAILABLE:
211
  self.model = AudioTagging(checkpoint_path=None, device=self.device)
212
  print(f"✅ {self.model_name} loaded successfully")
213
+ else: self.model = None
 
 
214
  except Exception as e:
215
  print(f"❌ Error loading {self.model_name}: {e}")
216
  self.model = None
217
 
218
  def predict(self, audio: np.ndarray, timestamp: float = 0.0) -> VADResult:
219
+ if self.cached_clip_prob is not None:
220
+ return VADResult(self.cached_clip_prob, self.cached_clip_prob > 0.5, self.model_name, 0.0, timestamp)
 
 
221
 
222
  start_time = time.time()
 
223
  if self.model is None or len(audio) == 0:
224
  return VADResult(0.0, False, f"{self.model_name} (fallback)", time.time() - start_time, timestamp)
225
 
226
  try:
227
+ # Use clipwise_output for probabilities, not embeddings.
228
+ clip_probs, _ = self.model.inference(audio[np.newaxis, :], input_sr=self.sample_rate)
229
+
230
+ # Filter all speech/voice-related labels for a robust average.
231
+ speech_idx = [i for i, lbl in enumerate(labels) if 'speech' in lbl.lower() or 'voice' in lbl.lower()]
232
+ if not speech_idx: speech_idx = [labels.index('Speech')]
233
 
234
  speech_prob = clip_probs[0, speech_idx].mean().item()
235
  self.cached_clip_prob = float(speech_prob)
236
 
237
+ return VADResult(self.cached_clip_prob, self.cached_clip_prob > 0.5, self.model_name, time.time() - start_time, timestamp)
238
  except Exception as e:
 
239
  return VADResult(0.0, False, f"{self.model_name} (error)", time.time() - start_time, timestamp)
240
 
241
  class OptimizedAST:
 
251
  def load_model(self):
252
  try:
253
  if AST_AVAILABLE:
254
+ model_path = "MIT/ast-finetuned-audioset-10-10-0.4593"
255
+ self.feature_extractor = ASTFeatureExtractor.from_pretrained(model_path)
256
+ self.model = ASTForAudioClassification.from_pretrained(model_path).to(self.device).eval()
 
 
257
  print(f"✅ {self.model_name} loaded successfully")
258
+ else: self.model = None
 
 
259
  except Exception as e:
260
  print(f"❌ Error loading {self.model_name}: {e}")
261
  self.model = None
262
 
263
  def predict(self, audio: np.ndarray, timestamp: float = 0.0) -> VADResult:
264
+ if self.cached_clip_prob is not None:
265
+ return VADResult(self.cached_clip_prob, self.cached_clip_prob > 0.5, self.model_name, 0.0, timestamp)
 
 
266
 
267
  start_time = time.time()
268
+ if self.model is None or len(audio) < self.sample_rate * 2: # AST needs at least ~2s
 
269
  return VADResult(0.0, False, f"{self.model_name} (fallback)", time.time() - start_time, timestamp)
270
 
271
  try:
272
+ inputs = self.feature_extractor(audio, sampling_rate=self.sample_rate, return_tensors="pt").to(self.device)
273
  with torch.no_grad():
274
+ probs = torch.sigmoid(self.model(**inputs).logits)
 
 
275
 
276
+ # Use the model's config to find all speech-related labels
277
  label2id = self.model.config.label2id
278
+ speech_idx = [idx for lbl, idx in label2id.items() if 'speech' in lbl.lower() or 'voice' in lbl.lower()]
279
+
280
  speech_prob = probs[0, speech_idx].mean().item()
281
  self.cached_clip_prob = float(speech_prob)
282
 
283
+ return VADResult(self.cached_clip_prob, self.cached_clip_prob > 0.5, self.model_name, time.time() - start_time, timestamp)
284
  except Exception as e:
 
285
  return VADResult(0.0, False, f"{self.model_name} (error)", time.time() - start_time, timestamp)
286
 
287
  # ===== AUDIO PROCESSOR =====
 
290
  def __init__(self, sample_rate=16000):
291
  self.sample_rate = sample_rate
292
 
293
+ # Consistent windowing for analysis and STFT
294
+ self.window_size = 0.064 # 64 ms
295
+ self.hop_size = 0.016 # 16 ms
296
+ self.n_fft = int(self.sample_rate * self.window_size) # 1024
297
+ self.hop_length = int(self.sample_rate * self.hop_size) # 256
298
+
299
  self.n_mels = 128
300
  self.fmin = 20
301
  self.fmax = 8000
302
 
303
  def process_audio(self, audio):
304
+ if audio is None: return np.array([])
 
 
305
  try:
306
+ sample_rate, audio_data = audio
307
+ if sample_rate != self.sample_rate and LIBROSA_AVAILABLE:
308
+ audio_data = librosa.resample(audio_data.astype(float), orig_sr=sample_rate, target_sr=self.sample_rate)
309
+ if len(audio_data.shape) > 1: audio_data = audio_data.mean(axis=1)
310
+ if np.max(np.abs(audio_data)) > 0: audio_data /= np.max(np.abs(audio_data))
311
  return audio_data
 
312
  except Exception as e:
 
313
  return np.array([])
314
 
315
  def compute_high_res_spectrogram(self, audio_data):
316
  try:
317
  if LIBROSA_AVAILABLE and len(audio_data) > 0:
318
+ stft = librosa.stft(audio_data, n_fft=self.n_fft, hop_length=self.hop_length, center=False)
319
+ mel_spec = librosa.feature.melspectrogram(S=np.abs(stft)**2, sr=self.sample_rate, n_fft=self.n_fft, hop_length=self.hop_length, n_mels=self.n_mels)
320
  mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
321
+ time_frames = librosa.times_like(mel_spec_db, sr=self.sample_rate, hop_length=self.hop_length, n_fft=self.n_fft)
 
 
322
  return mel_spec_db, time_frames
323
+ return np.array([[]]), np.array([])
324
  except Exception as e:
325
+ return np.array([[]]), np.array([])
326
 
327
  def detect_onset_offset_advanced(self, vad_results: List[VADResult], threshold: float = 0.5) -> List[OnsetOffset]:
328
  onsets_offsets = []
329
+ models = {res.model_name for res in vad_results}
330
 
331
+ for model_name in models:
332
+ results = sorted([r for r in vad_results if r.model_name == model_name], key=lambda x: x.timestamp)
333
+ if len(results) < 2: continue
334
 
335
  timestamps = np.array([r.timestamp for r in results])
336
  probabilities = np.array([r.probability for r in results])
337
 
338
+ # Smooth probabilities to prevent brief drops from creating false offsets
339
+ probs_smooth = np.convolve(probabilities, np.ones(3)/3, mode='same')
340
+
341
+ upper = threshold
342
+ lower = threshold * 0.5 # Hysteresis lower bound
343
+
344
+ in_speech = False
345
+ onset_time = -1
346
+ for i, prob in enumerate(probs_smooth):
347
+ if not in_speech and prob > upper:
348
+ in_speech = True
349
+ onset_time = timestamps[i]
350
+ elif in_speech and prob < lower:
351
+ in_speech = False
352
+ onsets_offsets.append(OnsetOffset(onset_time, timestamps[i], model_name, np.mean(probabilities[(timestamps >= onset_time) & (timestamps <= timestamps[i])])))
353
+ if in_speech:
354
+ onsets_offsets.append(OnsetOffset(onset_time, timestamps[-1], model_name, np.mean(probabilities[timestamps >= onset_time])))
355
 
356
  return onsets_offsets
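The commit message mentions 80 ms gaps, but the corresponding hunk is not visible in this extract. A hedged sketch of how onset/offset pairs separated by less than 80 ms could be merged on top of the list returned above (the helper name and merge policy are assumptions, not code from this commit):

MIN_GAP_S = 0.080  # 80 ms, per the commit message

def merge_close_events(events, min_gap=MIN_GAP_S):
    # Assumes events are OnsetOffset instances from a single model, sorted by onset_time.
    merged = []
    for ev in sorted(events, key=lambda e: e.onset_time):
        if merged and ev.onset_time - merged[-1].offset_time < min_gap:
            prev = merged.pop()
            # Bridge the short gap and keep the higher confidence of the pair.
            merged.append(OnsetOffset(prev.onset_time, ev.offset_time, ev.model_name,
                                      max(prev.confidence, ev.confidence)))
        else:
            merged.append(ev)
    return merged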
357
 
358
+ # ===== VISUALIZATION =====
359
 
360
  def create_realtime_plot(audio_data: np.ndarray, vad_results: List[VADResult],
361
  onsets_offsets: List[OnsetOffset], processor: AudioProcessor,
362
  model_a: str, model_b: str, threshold: float):
363
 
364
+ if not PLOTLY_AVAILABLE or len(audio_data) == 0: return go.Figure()
365
+
366
+ mel_spec_db, time_frames = processor.compute_high_res_spectrogram(audio_data)
367
+ if mel_spec_db.size == 0: return go.Figure()
368
 
369
+ fig = make_subplots(rows=2, cols=1, subplot_titles=(f"Model A: {model_a}", f"Model B: {model_b}"),
370
+ vertical_spacing=0.05, shared_xaxes=True, specs=[[{"secondary_y": True}], [{"secondary_y": True}]])
371
+
372
+ heatmap_args = dict(z=mel_spec_db, x=time_frames, y=np.linspace(processor.fmin, processor.fmax, processor.n_mels),
373
+ colorscale='Viridis', showscale=False)
374
+ fig.add_trace(go.Heatmap(**heatmap_args, name=f'Spectrogram {model_a}'), row=1, col=1)
375
+ fig.add_trace(go.Heatmap(**heatmap_args, name=f'Spectrogram {model_b}'), row=2, col=1)
376
+
377
+ data_a = [r for r in vad_results if r.model_name.startswith(model_a)]
378
+ data_b = [r for r in vad_results if r.model_name.startswith(model_b)]
379
+
380
+ if data_a: fig.add_trace(go.Scatter(x=[r.timestamp for r in data_a], y=[r.probability for r in data_a], mode='lines', line=dict(color='yellow', width=3), name=f'{model_a} Prob.'), row=1, col=1, secondary_y=True)
381
+ if data_b: fig.add_trace(go.Scatter(x=[r.timestamp for r in data_b], y=[r.probability for r in data_b], mode='lines', line=dict(color='orange', width=3), name=f'{model_b} Prob.'), row=2, col=1, secondary_y=True)
382
+
383
+ # Draw threshold line on the secondary y-axis
384
+ fig.add_hline(y=threshold, line=dict(color='cyan', width=2, dash='dash'), row=1, col=1, secondary_y=True)
385
+ fig.add_hline(y=threshold, line=dict(color='cyan', width=2, dash='dash'), row=2, col=1, secondary_y=True)
386
+
387
+ events_a = [e for e in onsets_offsets if e.model_name.startswith(model_a)]
388
+ events_b = [e for e in onsets_offsets if e.model_name.startswith(model_b)]
389
+
390
+ for event in events_a:
391
+ fig.add_vline(x=event.onset_time, line=dict(color='lime', width=3), row=1, col=1)
392
+ fig.add_vline(x=event.offset_time, line=dict(color='red', width=3), row=1, col=1)
393
+ for event in events_b:
394
+ fig.add_vline(x=event.offset_time, line=dict(color='red', width=3), row=2, col=1)
395
+ fig.add_vline(x=event.onset_time, line=dict(color='lime', width=3), row=2, col=1)
396
+
397
+ fig.update_layout(height=600, title_text="Real-Time Speech Visualizer", plot_bgcolor='black', paper_bgcolor='white', font_color='black')
398
+ fig.update_yaxes(title_text="Frequency (Hz)", range=[processor.fmin, processor.fmax], secondary_y=False)
399
+ fig.update_yaxes(title_text="Probability", range=[0, 1], secondary_y=True) # Apply to all secondary axes
400
+ fig.update_xaxes(title_text="Time (seconds)", row=2, col=1)
401
+
402
+ return fig
403
 
404
  # ===== MAIN APPLICATION =====
405
 
406
  class VADDemo:
407
  def __init__(self):
 
408
  self.processor = AudioProcessor()
409
  self.models = {
410
+ 'Silero-VAD': OptimizedSileroVAD(), 'WebRTC-VAD': OptimizedWebRTCVAD(),
411
+ 'E-PANNs': OptimizedEPANNs(), 'PANNs': OptimizedPANNs(), 'AST': OptimizedAST()
412
  }
413
+ print("🎤 VAD Demo initialized with all modules.")
414
 
415
  def process_audio_with_events(self, audio, model_a, model_b, threshold):
416
+ if audio is None: return None, "🔇 No audio detected", "Ready..."
417
 
418
+ try:
419
  processed_audio = self.processor.process_audio(audio)
420
+ if len(processed_audio) == 0: return None, "Audio empty", "No data"
 
421
 
422
+ # Reset caches and states for new clip
423
+ for model in self.models.values():
424
+ if hasattr(model, 'cached_clip_prob'): model.cached_clip_prob = None
425
+ if hasattr(model, 'reset_states'): model.reset_states()
426
+
427
+ # Pre-compute for heavy models once
428
+ if 'PANNs' in self.models:
429
+ audio_32k = librosa.resample(processed_audio, orig_sr=self.processor.sample_rate, target_sr=32000)
430
+ self.models['PANNs'].predict(audio_32k, 0.0)
431
+ if 'AST' in self.models:
432
+ self.models['AST'].predict(processed_audio, 0.0)
 
 
433
 
434
+ # Main analysis loop with consistent windowing
 
 
435
  vad_results = []
436
+ window = int(self.processor.sample_rate * self.processor.window_size) # 1024
437
+ hop = int(self.processor.sample_rate * self.processor.hop_size) # 256
438
+ silero_chunk_size = 512 # Silero specific requirement
439
 
440
+ for i in range(0, len(processed_audio) - window + 1, hop):
441
  timestamp = i / self.processor.sample_rate
442
+ chunk_1024 = processed_audio[i : i + window]
443
 
444
+ # Prepare chunk for Silero (last 512 samples of the current window)
445
+ chunk_512 = chunk_1024[-silero_chunk_size:]
446
+
447
+ for model_name in list(set([model_a, model_b])):
448
+ model = self.models[model_name]
449
+ # Feed correct chunk to each model type
450
+ if model_name == 'Silero-VAD':
451
+ current_chunk = chunk_512
452
+ else:
453
+ current_chunk = chunk_1024 # For WebRTC, E-PANNs, and cached models
454
+
455
+ result = model.predict(current_chunk, timestamp)
456
  result.is_speech = result.probability > threshold
457
  vad_results.append(result)
458
 
459
  onsets_offsets = self.processor.detect_onset_offset_advanced(vad_results, threshold)
460
+ fig = create_realtime_plot(processed_audio, vad_results, onsets_offsets, self.processor, model_a, model_b, threshold)
461
 
462
+ status_msg = f"🎙️ Speech detected" if any(e.offset_time > e.onset_time for e in onsets_offsets) else "🔇 No speech detected"
463
+ details_text = f"Analyzed {len(processed_audio)/self.processor.sample_rate:.2f}s. Found {len(onsets_offsets)} speech events."
464
 
465
  return fig, status_msg, details_text
 
466
  except Exception as e:
 
467
  import traceback
468
+ traceback.print_exc()
469
  return None, f"❌ Error: {e}", traceback.format_exc()
470
 
471
+ # Initialize and create interface
472
  demo_app = VADDemo()
473
+ interface = create_interface() # Using the original full interface
474
+ interface.launch(share=True, debug=False)
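The commit message also refers to per-model delays; that hunk is likewise not shown above. A minimal sketch of how fixed per-model latency offsets could be subtracted from detected event times before plotting (the delay values and helper below are illustrative assumptions, not measurements from this commit):

# Hypothetical per-model latency estimates in seconds (assumed values).
MODEL_DELAYS = {
    'Silero-VAD': 0.032,
    'WebRTC-VAD': 0.010,
    'E-PANNs': 0.048,
    'PANNs': 0.0,
    'AST': 0.0,
}

def compensate_delays(events, delays=MODEL_DELAYS):
    # Shift each event earlier by its model's estimated delay, clamping at t = 0.
    compensated = []
    for ev in events:
        d = delays.get(ev.model_name, 0.0)
        compensated.append(OnsetOffset(max(0.0, ev.onset_time - d),
                                       max(0.0, ev.offset_time - d),
                                       ev.model_name, ev.confidence))
    return compensated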