Gabriel Bibbó committed
Commit 0ea20e3 · 1 Parent(s): 96f8e9f

Hotfix: Restore basic functionality - fix AST saturation and PANNs execution

Files changed (1): app.py (+122, -60)
app.py CHANGED
@@ -141,12 +141,18 @@ class OptimizedSileroVAD:
             audio = audio.mean(axis=1)

         required_samples = 512
+        # Silero requires exactly 512 samples, handle this precisely
         if len(audio) != required_samples:
             if len(audio) > required_samples:
+                # Take center portion to avoid edge effects
                 start_idx = (len(audio) - required_samples) // 2
                 audio_chunk = audio[start_idx:start_idx + required_samples]
             else:
-                audio_chunk = np.pad(audio, (0, required_samples - len(audio)), 'constant')
+                # Pad symmetrically instead of just at the end
+                pad_total = required_samples - len(audio)
+                pad_left = pad_total // 2
+                pad_right = pad_total - pad_left
+                audio_chunk = np.pad(audio, (pad_left, pad_right), 'reflect')
         else:
             audio_chunk = audio

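
To see what the new padding branch does in isolation, here is a standalone sketch (not part of app.py; the helper name, the 480-sample chunk and the 16 kHz rate are illustrative):

```python
import numpy as np

REQUIRED_SAMPLES = 512  # Silero-VAD expects exactly 512 samples per call at 16 kHz

def to_silero_chunk(audio: np.ndarray) -> np.ndarray:
    """Center-crop or symmetrically reflect-pad a mono float array to 512 samples."""
    if len(audio) > REQUIRED_SAMPLES:
        start = (len(audio) - REQUIRED_SAMPLES) // 2
        return audio[start:start + REQUIRED_SAMPLES]
    if len(audio) < REQUIRED_SAMPLES:
        pad_total = REQUIRED_SAMPLES - len(audio)
        pad_left = pad_total // 2
        # 'reflect' mirrors the signal at its edges instead of inserting silence
        return np.pad(audio, (pad_left, pad_total - pad_left), mode='reflect')
    return audio

chunk = to_silero_chunk(np.random.randn(480).astype(np.float32))
print(chunk.shape)  # (512,)
```

Reflect padding avoids the hard step down to zeros that the old constant padding introduced at the chunk boundary.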
 
@@ -194,7 +200,9 @@ class OptimizedWebRTCVAD:
         if len(audio.shape) > 1:
             audio = audio.mean(axis=1)

-        audio_int16 = (audio * 32767).astype(np.int16)
+        # Properly convert to int16 with clipping to avoid saturation
+        audio_clipped = np.clip(audio, -1.0, 1.0)
+        audio_int16 = (audio_clipped * 32767).astype(np.int16)

         speech_frames = 0
         total_frames = 0
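
The saturation fix is simply a clip before the int16 cast; a standalone sketch with synthetic values shows the difference:

```python
import numpy as np

audio = np.array([0.5, 1.2, -1.4], dtype=np.float32)  # overshoots happen after resampling or gain

# Casting out-of-range floats straight to int16 overflows (typically wrapping to garbage values)
unclipped = (audio * 32767).astype(np.int16)

# Clipping first makes loud samples saturate cleanly at full scale instead
clipped = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)

print(unclipped)  # the 1.2 and -1.4 samples come out wrapped / implementation-defined
print(clipped)    # [ 16383  32767 -32767]
```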
@@ -237,8 +245,8 @@ class OptimizedEPANNs:
                                              orig_sr=16000,
                                              target_sr=self.sample_rate)

-            # Ensure minimum length (1 second) using wrap mode instead of zero padding
-            min_samples = self.sample_rate  # 1 second
+            # Ensure minimum length (6 seconds) using wrap mode instead of zero padding
+            min_samples = 6 * self.sample_rate  # 6 seconds
             if len(audio_resampled) < min_samples:
                 if LIBROSA_AVAILABLE:
                     audio_resampled = librosa.util.fix_length(audio_resampled, size=min_samples, mode='wrap')

@@ -327,8 +335,8 @@ class OptimizedPANNs:
                 audio
             )

-            # Ensure minimum length for PANNs (1 second) using wrap mode instead of zero padding
-            min_samples = self.sample_rate  # 1 second
+            # Ensure minimum length for PANNs (10 seconds) using wrap mode instead of zero padding
+            min_samples = 10 * self.sample_rate  # 10 seconds for optimal performance
             if len(audio_resampled) < min_samples:
                 if LIBROSA_AVAILABLE:
                     audio_resampled = librosa.util.fix_length(audio_resampled, size=min_samples, mode='wrap')
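
This hunk and the E-PANNs one above both lean on librosa.util.fix_length with mode='wrap', which tiles the clip instead of appending silence. A standalone sketch (synthetic 0.5 s clip, 16 kHz assumed):

```python
import numpy as np
import librosa

sr = 16000
clip = np.random.randn(sr // 2).astype(np.float32)   # 0.5 s of audio
min_samples = 10 * sr                                 # the PANNs path above uses 10 s

# mode='wrap' is forwarded to np.pad, so the clip is repeated rather than zero-padded
padded = librosa.util.fix_length(clip, size=min_samples, mode='wrap')
print(padded.shape)                                   # (160000,)

# Pure-NumPy equivalent if librosa is unavailable (mirrors the AST fallback further down)
tiled = np.tile(clip, int(np.ceil(min_samples / len(clip))))[:min_samples]
print(np.allclose(padded, tiled))                     # True: both give a periodic extension
```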
@@ -443,32 +451,37 @@ class OptimizedAST:
         if len(audio.shape) > 1:
             audio = audio.mean(axis=1)

-        # Use longer context for AST - preferably 4 seconds for better performance
-        if full_audio is not None and len(full_audio) >= 4 * self.sample_rate:
-            # Take 4-second window centered around current timestamp
+        # Use longer context for AST - preferably 6.4 seconds (1024 frames)
+        if full_audio is not None and len(full_audio) >= 6.4 * self.sample_rate:
+            # Take 6.4-second window centered around current timestamp
             center_pos = int(timestamp * self.sample_rate)
-            window_size = 2 * self.sample_rate  # 2 seconds each side
+            window_size = int(3.2 * self.sample_rate)  # 3.2 seconds each side

             start_pos = max(0, center_pos - window_size)
             end_pos = min(len(full_audio), center_pos + window_size)

-            # Ensure we have at least 4 seconds
-            if end_pos - start_pos < 4 * self.sample_rate:
-                end_pos = min(len(full_audio), start_pos + 4 * self.sample_rate)
-                if end_pos - start_pos < 4 * self.sample_rate:
-                    start_pos = max(0, end_pos - 4 * self.sample_rate)
+            # Ensure we have at least 6.4 seconds
+            if end_pos - start_pos < 6.4 * self.sample_rate:
+                end_pos = min(len(full_audio), start_pos + int(6.4 * self.sample_rate))
+                if end_pos - start_pos < 6.4 * self.sample_rate:
+                    start_pos = max(0, end_pos - int(6.4 * self.sample_rate))

             audio_for_ast = full_audio[start_pos:end_pos]
         else:
             audio_for_ast = audio

-        # Ensure minimum length for AST (4 seconds preferred, minimum 2 seconds)
-        min_samples = 4 * self.sample_rate  # 4 seconds for better performance
+        # Ensure minimum length for AST (6.4 seconds for 1024 frames)
+        min_samples = int(6.4 * self.sample_rate)  # 6.4 seconds
         if len(audio_for_ast) < min_samples:
-            audio_for_ast = np.pad(audio_for_ast, (0, min_samples - len(audio_for_ast)), 'constant')
+            if LIBROSA_AVAILABLE:
+                audio_for_ast = librosa.util.fix_length(audio_for_ast, size=min_samples, mode='wrap')
+            else:
+                # Fallback: repeat the signal
+                repeat_factor = int(np.ceil(min_samples / len(audio_for_ast)))
+                audio_for_ast = np.tile(audio_for_ast, repeat_factor)[:min_samples]

-        # Truncate if too long (AST can handle up to ~10s, but we'll use 5s max for efficiency)
-        max_samples = 5 * self.sample_rate
+        # Truncate if too long (AST can handle up to ~10s, but we'll use 8s max for efficiency)
+        max_samples = 8 * self.sample_rate
         if len(audio_for_ast) > max_samples:
            audio_for_ast = audio_for_ast[:max_samples]
@@ -557,20 +570,29 @@ class AudioProcessor:

         # Model-specific window sizes (each model gets appropriate context)
         self.model_windows = {
-            "Silero-VAD": 0.064,  # 64ms as required
+            "Silero-VAD": 0.032,  # 32ms exactly as required (512 samples)
             "WebRTC-VAD": 0.03,   # 30ms frames
-            "E-PANNs": 1.0,       # 1 second minimum
-            "PANNs": 1.0,         # 1 second minimum
-            "AST": 2.0            # 2 seconds for better performance
+            "E-PANNs": 6.0,       # 6 seconds minimum for reliable results
+            "PANNs": 10.0,        # 10 seconds for optimal performance
+            "AST": 6.4            # ~6.4 seconds (1024 frames * 6.25ms)
         }

         # Model-specific hop sizes for efficiency
         self.model_hop_sizes = {
-            "Silero-VAD": 0.032,
-            "WebRTC-VAD": 0.03,
-            "E-PANNs": 0.5,       # Process every 0.5s
-            "PANNs": 0.5,         # Process every 0.5s
-            "AST": 0.5            # Process every 0.5s
+            "Silero-VAD": 0.016,  # 16ms hop for Silero
+            "WebRTC-VAD": 0.01,   # 10ms hop for WebRTC
+            "E-PANNs": 1.0,       # Process every 1s but with 6s window
+            "PANNs": 2.0,         # Process every 2s but with 10s window
+            "AST": 1.0            # Process every 1s but with 6.4s window
+        }
+
+        # Model-specific thresholds for better detection
+        self.model_thresholds = {
+            "Silero-VAD": 0.5,
+            "WebRTC-VAD": 0.5,
+            "E-PANNs": 0.4,
+            "PANNs": 0.4,
+            "AST": 0.25
         }

         self.delay_compensation = 0.0
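
At the 16 kHz processing rate implied elsewhere in the diff, the new windows translate to fixed sample counts, and the per-model thresholds are looked up with a fallback to the global slider value (see the processing-loop hunk below). A quick standalone sketch (round() is used here to avoid float-truncation surprises; app.py itself uses a plain int() cast):

```python
SR = 16000  # assumed processing rate, consistent with the 512-sample / 32 ms Silero window

model_windows = {"Silero-VAD": 0.032, "WebRTC-VAD": 0.03, "E-PANNs": 6.0, "PANNs": 10.0, "AST": 6.4}
model_thresholds = {"Silero-VAD": 0.5, "WebRTC-VAD": 0.5, "E-PANNs": 0.4, "PANNs": 0.4, "AST": 0.25}

global_threshold = 0.5
for name, win_s in model_windows.items():
    window_samples = int(round(SR * win_s))
    threshold = model_thresholds.get(name, global_threshold)  # .get() keeps unknown models usable
    print(f"{name:11s} window={window_samples:6d} samples  threshold={threshold}")
# e.g. Silero-VAD window=   512 samples, AST window=102400 samples
```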
@@ -822,32 +844,52 @@ def create_realtime_plot(audio_data: np.ndarray, vad_results: List[VADResult],
     )

     if len(time_frames) > 0:
-        # Add threshold lines to both panels
-        fig.add_hline(
-            y=threshold,
+        # Add threshold lines using add_shape to avoid secondary axis bug
+        fig.add_shape(
+            type="line",
+            x0=time_frames[0], x1=time_frames[-1],
+            y0=threshold, y1=threshold,
             line=dict(color='cyan', width=2, dash='dash'),
-            annotation_text=f'Threshold: {threshold:.2f}',
-            annotation_position="top right",
-            row=1, col=1, secondary_y=True
+            row=1, col=1,
+            yref="y2"  # Reference to secondary y-axis
         )
-        fig.add_hline(
-            y=threshold,
+        fig.add_shape(
+            type="line",
+            x0=time_frames[0], x1=time_frames[-1],
+            y0=threshold, y1=threshold,
             line=dict(color='cyan', width=2, dash='dash'),
-            annotation_text=f'Threshold: {threshold:.2f}',
-            annotation_position="top right",
-            row=2, col=1, secondary_y=True
+            row=2, col=1,
+            yref="y4"  # Reference to secondary y-axis of second subplot
+        )
+
+        # Add threshold annotations
+        fig.add_annotation(
+            x=time_frames[-1] * 0.95, y=threshold,
+            text=f'Threshold: {threshold:.2f}',
+            showarrow=False,
+            font=dict(color='cyan', size=10),
+            row=1, col=1,
+            yref="y2"
+        )
+        fig.add_annotation(
+            x=time_frames[-1] * 0.95, y=threshold,
+            text=f'Threshold: {threshold:.2f}',
+            showarrow=False,
+            font=dict(color='cyan', size=10),
+            row=2, col=1,
+            yref="y4"
         )

     model_a_data = {'times': [], 'probs': []}
     model_b_data = {'times': [], 'probs': []}

     for result in vad_results:
-        # Fix model name filtering - remove suffixes like (cached), (fallback), (error)
-        model_base_name = result.model_name.split(' ')[0].split('(')[0]
-        if model_base_name == model_a or result.model_name.startswith(model_a):
+        # Fix model name filtering - remove suffixes properly and consistently
+        base_name = result.model_name.split('(')[0].strip()
+        if base_name == model_a:
             model_a_data['times'].append(result.timestamp)
             model_a_data['probs'].append(result.probability)
-        elif model_base_name == model_b or result.model_name.startswith(model_b):
+        elif base_name == model_b:
             model_b_data['times'].append(result.timestamp)
             model_b_data['probs'].append(result.probability)
@@ -881,8 +923,8 @@ def create_realtime_plot(audio_data: np.ndarray, vad_results: List[VADResult],
             row=2, col=1, secondary_y=True
         )

-    model_a_events = [e for e in onsets_offsets if e.model_name.startswith(model_a)]
-    model_b_events = [e for e in onsets_offsets if e.model_name.startswith(model_b)]
+    model_a_events = [e for e in onsets_offsets if e.model_name.split('(')[0].strip() == model_a]
+    model_b_events = [e for e in onsets_offsets if e.model_name.split('(')[0].strip() == model_b]

     for event in model_a_events:
         if event.onset_time >= 0 and event.onset_time <= time_frames[-1]:
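
The new filter reduces names like "PANNs (cached)" to their base model name before comparing; the example strings below are illustrative:

```python
for raw in ["PANNs", "PANNs (cached)", "AST (fallback)", "Silero-VAD (error)"]:
    print(repr(raw), "->", repr(raw.split('(')[0].strip()))
# 'PANNs' -> 'PANNs'
# 'PANNs (cached)' -> 'PANNs'
# 'AST (fallback)' -> 'AST'
# 'Silero-VAD (error)' -> 'Silero-VAD'
```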
@@ -1009,21 +1051,39 @@ class VADDemo:
             if model_name in self.models:
                 window_size = self.processor.model_windows[model_name]
                 hop_size = self.processor.model_hop_sizes[model_name]
+                model_threshold = self.processor.model_thresholds.get(model_name, threshold)

                 window_samples = int(self.processor.sample_rate * window_size)
                 hop_samples = int(self.processor.sample_rate * hop_size)

-                for i in range(0, len(processed_audio) - window_samples, hop_samples):
+                # For large models, ensure we have enough audio
+                if len(processed_audio) < window_samples:
+                    # If audio is too short, repeat it to reach minimum length
+                    repeat_factor = int(np.ceil(window_samples / len(processed_audio)))
+                    extended_audio = np.tile(processed_audio, repeat_factor)[:window_samples]
+                else:
+                    extended_audio = processed_audio
+
+                for i in range(0, len(extended_audio) - window_samples, hop_samples):
                     timestamp = i / self.processor.sample_rate
-                    chunk = processed_audio[i:i + window_samples]

-                    # Special handling for AST - pass full audio for context
+                    # Extract window centered around current position
+                    start_pos = max(0, i)
+                    end_pos = min(len(extended_audio), i + window_samples)
+                    chunk = extended_audio[start_pos:end_pos]
+
+                    # Ensure chunk has the right length
+                    if len(chunk) < window_samples:
+                        chunk = np.pad(chunk, (0, window_samples - len(chunk)), 'wrap')
+
+                    # Special handling for different models
                     if model_name == 'AST':
-                        result = self.models[model_name].predict(chunk, timestamp, full_audio=processed_audio)
+                        result = self.models[model_name].predict(chunk, timestamp, full_audio=extended_audio)
                     else:
                         result = self.models[model_name].predict(chunk, timestamp)

-                    result.is_speech = result.probability > threshold
+                    # Use model-specific threshold
+                    result.is_speech = result.probability > model_threshold
                     vad_results.append(result)

         delay_compensation = self.processor.estimate_delay_compensation(processed_audio, vad_results)
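
The reworked loop guarantees every chunk is exactly window_samples long even when the recording is shorter than one window: the whole clip is tiled first, and a short trailing chunk is 'wrap'-padded. A standalone sketch of those two mechanisms (values are synthetic; variable names mirror the hunk):

```python
import numpy as np

sr = 16000
window_samples = 10 * sr          # e.g. the PANNs window from the hunks above
processed_audio = np.random.randn(3 * sr).astype(np.float32)   # only 3 s recorded

# 1) Clip shorter than one window: tile it until it covers a full window
if len(processed_audio) < window_samples:
    repeat_factor = int(np.ceil(window_samples / len(processed_audio)))
    extended_audio = np.tile(processed_audio, repeat_factor)[:window_samples]
else:
    extended_audio = processed_audio
print(len(extended_audio) / sr)   # 10.0

# 2) A chunk that runs past the end: pad it cyclically back up to the window size
chunk = extended_audio[8 * sr:]   # last 2 s only
chunk = np.pad(chunk, (0, window_samples - len(chunk)), 'wrap')
print(len(chunk) == window_samples)   # True
```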
@@ -1045,30 +1105,32 @@ class VADDemo:
         # Simplified details
         model_summaries = {}
         for result in vad_results:
-            # Fix model name filtering - remove suffixes like (cached), (fallback)
-            name = result.model_name.split(' ')[0].split('(')[0]
-            if name not in model_summaries:
-                model_summaries[name] = {'probs': [], 'speech_chunks': 0, 'total_chunks': 0}
-            summary = model_summaries[name]
+            # Fix model name filtering - remove suffixes properly
+            base_name = result.model_name.split('(')[0].strip()
+            if base_name not in model_summaries:
+                model_summaries[base_name] = {'probs': [], 'speech_chunks': 0, 'total_chunks': 0}
+            summary = model_summaries[base_name]
             summary['probs'].append(result.probability)
             summary['total_chunks'] += 1
             if result.is_speech:
                 summary['speech_chunks'] += 1

-        details_lines = [f"**Analysis Results** (Threshold: {threshold:.2f})"]
+        details_lines = [f"**Analysis Results** (Global Threshold: {threshold:.2f})"]

         for model_name, summary in model_summaries.items():
             avg_prob = np.mean(summary['probs']) if summary['probs'] else 0
             speech_ratio = (summary['speech_chunks'] / summary['total_chunks']) if summary['total_chunks'] > 0 else 0
+            model_thresh = self.processor.model_thresholds.get(model_name, threshold)

             status_icon = "🟢" if speech_ratio > 0.5 else "🟡" if speech_ratio > 0.2 else "🔴"
-            details_lines.append(f"{status_icon} **{model_name}**: {avg_prob:.3f} avg prob, {speech_ratio*100:.1f}% speech")
+            details_lines.append(f"{status_icon} **{model_name}**: {avg_prob:.3f} avg prob, {speech_ratio*100:.1f}% speech (thresh: {model_thresh:.2f})")

         if onsets_offsets:
             details_lines.append(f"\n**Speech Events**: {len(onsets_offsets)} detected")
             for i, event in enumerate(onsets_offsets[:5]):  # Show first 5 only
                 duration = event.offset_time - event.onset_time if event.offset_time > event.onset_time else 0
-                details_lines.append(f"• {event.model_name}: {event.onset_time:.2f}s - {event.offset_time:.2f}s ({duration:.2f}s)")
+                event_model = event.model_name.split('(')[0].strip()
+                details_lines.append(f"• {event_model}: {event.onset_time:.2f}s - {event.offset_time:.2f}s ({duration:.2f}s)")

         details_text = "\n".join(details_lines)

@@ -1139,7 +1201,7 @@ def create_interface():

                 model_b = gr.Dropdown(
                     choices=["Silero-VAD", "WebRTC-VAD", "E-PANNs", "PANNs", "AST"],
-                    value="PANNs",
+                    value="E-PANNs",
                     label="Model B (Bottom Panel)"
                 )

 