Gabriel Bibbó committed
Commit 96f8e9f · 1 Parent(s): 5bbaead

Hotfix: Restore basic functionality - fix AST saturation and PANNs execution

Files changed (1)
  1. app.py +106 -63
app.py CHANGED
@@ -237,6 +237,16 @@ class OptimizedEPANNs:
                                                orig_sr=16000,
                                                target_sr=self.sample_rate)
 
+            # Ensure minimum length (1 second) using wrap mode instead of zero padding
+            min_samples = self.sample_rate # 1 second
+            if len(audio_resampled) < min_samples:
+                if LIBROSA_AVAILABLE:
+                    audio_resampled = librosa.util.fix_length(audio_resampled, size=min_samples, mode='wrap')
+                else:
+                    # Fallback: repeat the signal
+                    repeat_factor = int(np.ceil(min_samples / len(audio_resampled)))
+                    audio_resampled = np.tile(audio_resampled, repeat_factor)[:min_samples]
+
             mel_spec = librosa.feature.melspectrogram(y=audio_resampled, sr=self.sample_rate, n_mels=64)
             energy = np.mean(librosa.power_to_db(mel_spec, ref=np.max))
             spectral_centroid = np.mean(librosa.feature.spectral_centroid(y=audio_resampled, sr=self.sample_rate))
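Note on the hunk above: wrap-mode padding repeats the clip instead of appending silence, so a very short chunk keeps roughly its original spectral statistics rather than being diluted by zeros. A standalone sketch of the two code paths, assuming a 16 kHz rate and a 0.5 s chunk (values chosen purely for illustration):

    import numpy as np
    import librosa

    sample_rate = 16000
    min_samples = sample_rate                     # 1 second target
    chunk = np.random.randn(sample_rate // 2)     # 0.5 s chunk, too short

    # Librosa path: mode='wrap' is forwarded to np.pad, so the clip is cycled
    wrapped = librosa.util.fix_length(chunk, size=min_samples, mode='wrap')

    # NumPy fallback path: tile the chunk and trim to the target length
    repeat_factor = int(np.ceil(min_samples / len(chunk)))
    tiled = np.tile(chunk, repeat_factor)[:min_samples]

    np.testing.assert_allclose(wrapped, tiled)    # both repeat the same samples
    print(len(wrapped) / sample_rate)             # 1.0 second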
@@ -317,10 +327,15 @@ class OptimizedPANNs:
                audio
            )
 
-            # Ensure minimum length for PANNs (need at least 1 second)
+            # Ensure minimum length for PANNs (1 second) using wrap mode instead of zero padding
            min_samples = self.sample_rate # 1 second
            if len(audio_resampled) < min_samples:
-                audio_resampled = np.pad(audio_resampled, (0, min_samples - len(audio_resampled)), 'constant')
+                if LIBROSA_AVAILABLE:
+                    audio_resampled = librosa.util.fix_length(audio_resampled, size=min_samples, mode='wrap')
+                else:
+                    # Fallback: repeat the signal
+                    repeat_factor = int(np.ceil(min_samples / len(audio_resampled)))
+                    audio_resampled = np.tile(audio_resampled, repeat_factor)[:min_samples]
 
            clip_probs, _ = self.model.inference(audio_resampled[np.newaxis, :],
                                                 input_sr=self.sample_rate)
@@ -373,8 +388,15 @@ class OptimizedAST:
                self.feature_extractor = ASTFeatureExtractor.from_pretrained(model_name)
                self.model = ASTForAudioClassification.from_pretrained(model_name)
                self.model.to(self.device)
+
+                # Use FP16 for faster inference on GPU
+                if self.device.type == 'cuda':
+                    self.model = self.model.half()
+                    print(f"✅ {self.model_name} loaded with FP16 optimization")
+                else:
+                    print(f"✅ {self.model_name} loaded successfully")
+
                self.model.eval()
-                print(f"✅ {self.model_name} loaded successfully")
            else:
                print(f"⚠️ {self.model_name} not available, using fallback")
                self.model = None
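Note on the hunk above: half precision only pays off when the weights and the inputs end up in the same dtype on the GPU; the later inference hunk converts the inputs accordingly. A minimal sketch of the same pattern using a hypothetical TinyNet module as a stand-in for the AST model (CUDA optional, falls back to FP32 on CPU):

    import torch

    class TinyNet(torch.nn.Module):          # hypothetical stand-in for the AST model
        def __init__(self):
            super().__init__()
            self.proj = torch.nn.Linear(128, 2)

        def forward(self, x):
            return self.proj(x)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = TinyNet().to(device)
    if device.type == 'cuda':
        model = model.half()                 # weights -> float16
    model.eval()

    x = torch.randn(1, 128, device=device)   # features arrive as float32
    if device.type == 'cuda':
        x = x.half() if x.dtype == torch.float32 else x   # match the model's dtype

    with torch.no_grad():
        out = model(x)
    print(out.dtype)                          # torch.float16 on GPU, torch.float32 on CPU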
@@ -421,45 +443,50 @@ class OptimizedAST:
            if len(audio.shape) > 1:
                audio = audio.mean(axis=1)
 
-            # Use longer context for AST - preferably 2 seconds
-            if full_audio is not None and len(full_audio) >= 2 * self.sample_rate:
-                # Take 2-second window centered around current timestamp
+            # Use longer context for AST - preferably 4 seconds for better performance
+            if full_audio is not None and len(full_audio) >= 4 * self.sample_rate:
+                # Take 4-second window centered around current timestamp
                center_pos = int(timestamp * self.sample_rate)
-                window_size = self.sample_rate # 1 second each side
+                window_size = 2 * self.sample_rate # 2 seconds each side
 
                start_pos = max(0, center_pos - window_size)
                end_pos = min(len(full_audio), center_pos + window_size)
 
-                # Ensure we have at least 2 seconds
-                if end_pos - start_pos < 2 * self.sample_rate:
-                    end_pos = min(len(full_audio), start_pos + 2 * self.sample_rate)
-                    if end_pos - start_pos < 2 * self.sample_rate:
-                        start_pos = max(0, end_pos - 2 * self.sample_rate)
+                # Ensure we have at least 4 seconds
+                if end_pos - start_pos < 4 * self.sample_rate:
+                    end_pos = min(len(full_audio), start_pos + 4 * self.sample_rate)
+                    if end_pos - start_pos < 4 * self.sample_rate:
+                        start_pos = max(0, end_pos - 4 * self.sample_rate)
 
                audio_for_ast = full_audio[start_pos:end_pos]
            else:
                audio_for_ast = audio
 
-            # Ensure minimum length for AST (2 seconds preferred, minimum 1 second)
-            min_samples = 2 * self.sample_rate # 2 seconds
+            # Ensure minimum length for AST (4 seconds preferred, minimum 2 seconds)
+            min_samples = 4 * self.sample_rate # 4 seconds for better performance
            if len(audio_for_ast) < min_samples:
                audio_for_ast = np.pad(audio_for_ast, (0, min_samples - len(audio_for_ast)), 'constant')
 
-            # Truncate if too long (AST can handle up to ~10s, but we'll use 3s max for efficiency)
-            max_samples = 3 * self.sample_rate
+            # Truncate if too long (AST can handle up to ~10s, but we'll use 5s max for efficiency)
+            max_samples = 5 * self.sample_rate
            if len(audio_for_ast) > max_samples:
                audio_for_ast = audio_for_ast[:max_samples]
 
-            # Feature extraction with proper AST parameters
+            # Feature extraction with proper AST parameters (closer to 1024 frames)
            inputs = self.feature_extractor(
                audio_for_ast,
                sampling_rate=self.sample_rate,
                return_tensors="pt",
                max_length=1024, # Proper AST context
+                padding="max_length", # Ensure consistent length
                truncation=True
            )
 
+            # Move inputs to correct device and dtype
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
+            if self.device.type == 'cuda' and hasattr(self.model, 'half'):
+                # Convert inputs to FP16 if model is in FP16
+                inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in inputs.items()}
 
            with torch.no_grad():
                outputs = self.model(**inputs)
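Note on the hunk above: the window logic clamps a 4-second span around the current timestamp to the ends of the recording, then grows it toward whichever side still has room. A worked example of that indexing in plain NumPy, assuming 16 kHz audio and a timestamp near the start of a 10 s clip:

    import numpy as np

    sample_rate = 16000
    full_audio = np.zeros(10 * sample_rate)        # 10 s recording
    timestamp = 0.25                               # near the start, so the window must clamp

    center_pos = int(timestamp * sample_rate)
    window_size = 2 * sample_rate                  # 2 s on each side, 4 s total

    start_pos = max(0, center_pos - window_size)
    end_pos = min(len(full_audio), center_pos + window_size)

    # If clamping shortened the window, extend it toward whichever side has room
    if end_pos - start_pos < 4 * sample_rate:
        end_pos = min(len(full_audio), start_pos + 4 * sample_rate)
        if end_pos - start_pos < 4 * sample_rate:
            start_pos = max(0, end_pos - 4 * sample_rate)

    audio_for_ast = full_audio[start_pos:end_pos]
    print(start_pos, end_pos, len(audio_for_ast) / sample_rate)   # 0 64000 4.0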
@@ -477,21 +504,23 @@ class OptimizedAST:
 
            if speech_indices:
                speech_prob = probs[0, speech_indices].mean().item()
-                # Apply more reasonable thresholding for AST
-                if speech_prob < 0.1 and np.sum(audio_for_ast ** 2) > 0.001:
-                    speech_prob = min(speech_prob * 3, 0.7) # Moderate boost, cap at 0.7
+                # Apply more reasonable thresholding for AST with lower threshold
+                if speech_prob < 0.15 and np.sum(audio_for_ast ** 2) > 0.001:
+                    speech_prob = min(speech_prob * 2.5, 0.6) # Moderate boost, cap at 0.6
            else:
                # Fallback to energy-based detection with higher threshold
                energy = np.sum(audio_for_ast ** 2) / len(audio_for_ast) # Normalize by length
                speech_prob = min(energy * 50, 1.0)
 
-            result = VADResult(float(speech_prob), speech_prob > 0.4, self.model_name, time.time()-start_time, timestamp)
+            # Use lower threshold specifically for AST (0.25 instead of 0.4)
+            is_speech_ast = speech_prob > 0.25
+            result = VADResult(float(speech_prob), is_speech_ast, self.model_name, time.time()-start_time, timestamp)
 
            # Cache the result
            self.prediction_cache[cache_key] = result
 
-            # Clean old cache entries (keep only last 10 seconds)
-            cache_keys_to_remove = [k for k in self.prediction_cache.keys() if k < cache_key - 10]
+            # Clean old cache entries (keep only last 30 seconds for longer sessions)
+            cache_keys_to_remove = [k for k in self.prediction_cache.keys() if k < cache_key - 30]
            for k in cache_keys_to_remove:
                del self.prediction_cache[k]
 
@@ -523,16 +552,25 @@ class AudioProcessor:
        self.fmin = 20
        self.fmax = 8000
 
-        self.window_size = 0.064
-        self.hop_size = 0.032
+        self.base_window = 0.064
+        self.base_hop = 0.032
+
+        # Model-specific window sizes (each model gets appropriate context)
+        self.model_windows = {
+            "Silero-VAD": 0.064, # 64ms as required
+            "WebRTC-VAD": 0.03, # 30ms frames
+            "E-PANNs": 1.0, # 1 second minimum
+            "PANNs": 1.0, # 1 second minimum
+            "AST": 2.0 # 2 seconds for better performance
+        }
 
        # Model-specific hop sizes for efficiency
        self.model_hop_sizes = {
            "Silero-VAD": 0.032,
            "WebRTC-VAD": 0.03,
-            "E-PANNs": 1.0,
-            "PANNs": 1.0,
-            "AST": 1.0 # Process AST only once per second
+            "E-PANNs": 0.5, # Process every 0.5s
+            "PANNs": 0.5, # Process every 0.5s
+            "AST": 0.5 # Process every 0.5s
        }
 
        self.delay_compensation = 0.0
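Note on the hunk above: the two tables jointly set how much context each model sees per prediction (window) and how often it runs (hop). A quick arithmetic check at 16 kHz, with the values copied from the dictionaries above:

    sample_rate = 16000
    model_windows   = {"Silero-VAD": 0.064, "WebRTC-VAD": 0.03, "E-PANNs": 1.0, "PANNs": 1.0, "AST": 2.0}
    model_hop_sizes = {"Silero-VAD": 0.032, "WebRTC-VAD": 0.03, "E-PANNs": 0.5, "PANNs": 0.5, "AST": 0.5}

    for name, window in model_windows.items():
        hop = model_hop_sizes[name]
        print(f"{name}: {int(window * sample_rate)} samples per chunk, "
              f"{1 / hop:.1f} predictions per second")
    # e.g. Silero-VAD: 1024 samples per chunk, 31.2 predictions per second
    #      AST: 32000 samples per chunk, 2.0 predictions per second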
@@ -697,8 +735,8 @@ class AudioProcessor:
        if len(audio_data) == 0 or len(vad_results) == 0:
            return 0.0
 
-        window_size = int(self.sample_rate * self.window_size)
-        hop_size = int(self.sample_rate * self.hop_size)
+        window_size = int(self.sample_rate * self.base_window)
+        hop_size = int(self.sample_rate * self.base_hop)
 
        energy_signal = []
        for i in range(0, len(audio_data) - window_size, hop_size):
@@ -715,14 +753,14 @@ class AudioProcessor:
        vad_times = np.array([r.timestamp for r in vad_results])
        vad_probs = np.array([r.probability for r in vad_results])
 
-        energy_times = np.arange(len(energy_signal)) * self.hop_size
+        energy_times = np.arange(len(energy_signal)) * self.base_hop
        vad_interp = np.interp(energy_times, vad_times, vad_probs)
        vad_interp = (vad_interp - np.mean(vad_interp)) / (np.std(vad_interp) + 1e-8)
 
        if len(energy_signal) > 10 and len(vad_interp) > 10:
            correlation = np.correlate(energy_signal, vad_interp, mode='full')
            delay_samples = np.argmax(correlation) - len(vad_interp) + 1
-            delay_seconds = delay_samples * self.hop_size
+            delay_seconds = delay_samples * self.base_hop
 
            max_corr = np.max(correlation) / (len(vad_interp) * np.std(energy_signal) * np.std(vad_interp))
            if max_corr > self.correlation_threshold:
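Note on the hunk above: with np.correlate in 'full' mode, the peak index minus (len(v) - 1) is the lag of the first argument relative to the second, which is exactly the delay_samples expression. A synthetic check with an impulse shifted by a known number of frames (hop of 32 ms assumed, matching base_hop):

    import numpy as np

    base_hop = 0.032                     # seconds per energy frame, as in AudioProcessor
    true_lag = 5                         # energy track trails the VAD track by 5 frames

    vad_interp = np.zeros(200)
    vad_interp[50] = 1.0                 # speech onset in the VAD probability track
    energy_signal = np.zeros(200)
    energy_signal[50 + true_lag] = 1.0   # same onset, seen 5 frames later in the energy track

    correlation = np.correlate(energy_signal, vad_interp, mode='full')
    delay_samples = np.argmax(correlation) - len(vad_interp) + 1
    delay_seconds = delay_samples * base_hop
    print(delay_samples, round(delay_seconds, 3))   # 5 0.16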
@@ -804,20 +842,23 @@ def create_realtime_plot(audio_data: np.ndarray, vad_results: List[VADResult],
    model_b_data = {'times': [], 'probs': []}
 
    for result in vad_results:
-        if result.model_name.startswith(model_a):
+        # Fix model name filtering - remove suffixes like (cached), (fallback), (error)
+        model_base_name = result.model_name.split(' ')[0].split('(')[0]
+        if model_base_name == model_a or result.model_name.startswith(model_a):
            model_a_data['times'].append(result.timestamp)
            model_a_data['probs'].append(result.probability)
-        elif result.model_name.startswith(model_b):
+        elif model_base_name == model_b or result.model_name.startswith(model_b):
            model_b_data['times'].append(result.timestamp)
            model_b_data['probs'].append(result.probability)
 
-    if len(model_a_data['times']) > 1:
+    if len(model_a_data['times']) > 0:
        fig.add_trace(
            go.Scatter(
                x=model_a_data['times'],
                y=model_a_data['probs'],
-                mode='lines',
+                mode='lines+markers', # Add markers to show single points
                line=dict(color='yellow', width=3),
+                marker=dict(size=6, color='yellow'),
                name=f'{model_a} Probability',
                hovertemplate='Time: %{x:.2f}s<br>Probability: %{y:.3f}<extra></extra>',
                showlegend=True
@@ -825,13 +866,14 @@ def create_realtime_plot(audio_data: np.ndarray, vad_results: List[VADResult],
        row=1, col=1, secondary_y=True
    )
 
-    if len(model_b_data['times']) > 1:
+    if len(model_b_data['times']) > 0:
        fig.add_trace(
            go.Scatter(
                x=model_b_data['times'],
                y=model_b_data['probs'],
-                mode='lines',
+                mode='lines+markers', # Add markers to show single points
                line=dict(color='orange', width=3),
+                marker=dict(size=6, color='orange'),
                name=f'{model_b} Probability',
                hovertemplate='Time: %{x:.2f}s<br>Probability: %{y:.3f}<extra></extra>',
                showlegend=True
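Note on the two plotting hunks above: split(' ') strips spaced suffixes and the extra split('(') strips unspaced ones, so decorated names such as "AST (cached)" or "PANNs(fallback)" no longer break the trace filtering. A quick illustration with made-up name strings:

    def base_name(model_name: str) -> str:
        # Drop decorations such as " (cached)", " (fallback)" or "(error)"
        return model_name.split(' ')[0].split('(')[0]

    for raw in ["AST", "AST (cached)", "PANNs(fallback)", "E-PANNs (error)"]:
        print(f"{raw!r:24} -> {base_name(raw)}")   # every variant maps to the bare model name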
@@ -959,30 +1001,30 @@ class VADDemo:
        if len(processed_audio) == 0:
            return None, "🎵 Processing audio...", "No audio data processed"
 
-        window_samples = int(self.processor.sample_rate * self.processor.window_size)
-        hop_samples = int(self.processor.sample_rate * self.processor.hop_size)
        vad_results = []
-
        selected_models = list(set([model_a, model_b]))
 
-        # Process each window with model-specific hop sizes for efficiency
-        for i in range(0, len(processed_audio) - window_samples, hop_samples):
-            timestamp = i / self.processor.sample_rate
-            chunk = processed_audio[i:i + window_samples]
-
-            for model_name in selected_models:
-                if model_name in self.models:
-                    # Check if we should process this model at this timestamp
-                    model_hop = self.processor.model_hop_sizes.get(model_name, self.processor.hop_size)
-                    if i % int(model_hop * self.processor.sample_rate) == 0:
-                        # Special handling for AST - pass full audio for context
-                        if model_name == 'AST':
-                            result = self.models[model_name].predict(chunk, timestamp, full_audio=processed_audio)
-                        else:
-                            result = self.models[model_name].predict(chunk, timestamp)
-
-                        result.is_speech = result.probability > threshold
-                        vad_results.append(result)
+        # Process each model with its specific window and hop size
+        for model_name in selected_models:
+            if model_name in self.models:
+                window_size = self.processor.model_windows[model_name]
+                hop_size = self.processor.model_hop_sizes[model_name]
+
+                window_samples = int(self.processor.sample_rate * window_size)
+                hop_samples = int(self.processor.sample_rate * hop_size)
+
+                for i in range(0, len(processed_audio) - window_samples, hop_samples):
+                    timestamp = i / self.processor.sample_rate
+                    chunk = processed_audio[i:i + window_samples]
+
+                    # Special handling for AST - pass full audio for context
+                    if model_name == 'AST':
+                        result = self.models[model_name].predict(chunk, timestamp, full_audio=processed_audio)
+                    else:
+                        result = self.models[model_name].predict(chunk, timestamp)
+
+                    result.is_speech = result.probability > threshold
+                    vad_results.append(result)
 
        delay_compensation = self.processor.estimate_delay_compensation(processed_audio, vad_results)
        onsets_offsets = self.processor.detect_onset_offset_advanced(vad_results, threshold)
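Note on the hunk above: the rewritten loop iterates models first and time second, so each model advances by its own hop, and only the AST branch receives the full clip for context. A dry run of that schedule with a stub predictor (the stub and the 4 s silent clip are purely illustrative):

    import numpy as np

    sample_rate = 16000
    processed_audio = np.zeros(4 * sample_rate)    # 4 s of silence as a stand-in clip
    model_windows = {"Silero-VAD": 0.064, "AST": 2.0}
    model_hop_sizes = {"Silero-VAD": 0.032, "AST": 0.5}

    def stub_predict(name, chunk, timestamp, full_audio=None):
        # Placeholder for self.models[name].predict(...): just records the call shape
        return (name, round(timestamp, 3), len(chunk), full_audio is not None)

    calls = []
    for model_name in model_windows:
        window_samples = int(sample_rate * model_windows[model_name])
        hop_samples = int(sample_rate * model_hop_sizes[model_name])
        for i in range(0, len(processed_audio) - window_samples, hop_samples):
            timestamp = i / sample_rate
            chunk = processed_audio[i:i + window_samples]
            if model_name == 'AST':
                calls.append(stub_predict(model_name, chunk, timestamp, full_audio=processed_audio))
            else:
                calls.append(stub_predict(model_name, chunk, timestamp))

    print(sum(c[0] == 'Silero-VAD' for c in calls), "Silero-VAD calls,",
          sum(c[0] == 'AST' for c in calls), "AST calls")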
@@ -993,17 +1035,18 @@ class VADDemo:
        )
 
        speech_detected = any(result.is_speech for result in vad_results)
-        total_speech_time = sum(1 for r in vad_results if r.is_speech) * self.processor.hop_size
+        total_speech_chunks = sum(1 for r in vad_results if r.is_speech)
 
        if speech_detected:
-            status_msg = f"🎙️ SPEECH DETECTED - {total_speech_time:.1f}s total"
+            status_msg = f"🎙️ SPEECH DETECTED - {total_speech_chunks} active chunks"
        else:
            status_msg = f"🔇 No speech detected"
 
        # Simplified details
        model_summaries = {}
        for result in vad_results:
-            name = result.model_name.split(' ')[0]
+            # Fix model name filtering - remove suffixes like (cached), (fallback)
+            name = result.model_name.split(' ')[0].split('(')[0]
            if name not in model_summaries:
                model_summaries[name] = {'probs': [], 'speech_chunks': 0, 'total_chunks': 0}
            summary = model_summaries[name]
@@ -1096,7 +1139,7 @@ def create_interface():
 
        model_b = gr.Dropdown(
            choices=["Silero-VAD", "WebRTC-VAD", "E-PANNs", "PANNs", "AST"],
-            value="AST",
+            value="PANNs",
            label="Model B (Bottom Panel)"
        )
 